Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dtype support for SegmentAnythingModel #2207

Merged

Conversation

tirthasheshpatel
Copy link
Contributor

What does this PR do?

Segment Anything model can use keras.mixed_precision.set_dype_policy for quick optimizations. This PR fixed the model so that it can be run with any dtype policy set. Also added a test for float32 (default in keras), mixed_float16, and bfloat16.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue? Please add a link
    to it if that's the case.
  • Did you write any new necessary tests?
  • If this adds a new model, can you run a few training steps on TPU in Colab to ensure that no XLA incompatible OP are used?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@divyashreepathihalli
Copy link
Collaborator

Thanks for the PR! LGTM!

Copy link
Member

@mattdangerw mattdangerw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm!

def test_end_to_end_model_predict(self, dtype_policy):
import threading

with threading.Lock():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's with this? are we running our cv testing multi-processed ever?

Copy link
Contributor Author

@tirthasheshpatel tirthasheshpatel Dec 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can be multi-processed with the -n <num_threads> argumment in pytest. PyTest uses multi-processing and not multi-threading so locking should not be necessary here. I just added this as a safeguard if anyone ever tries to run these tests using Python threads.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Long term, we could move towards Model(dtype=policy) support, so that these tests can run effectively without mutating global state.

# Check the number of parameters
num_parameters = np.sum([np.prod(x.shape) for x in model.weights])
self.assertEqual(num_parameters, 89_670_912 + 6_476 + 4_058_340)
@parameterized.named_parameters(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this test marked as large? (just for my own learning)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The model initialized here is a ViT Base model with 130M parameters. Creating and evaluating it takes about 15-20 seconds which is significantly more than small unit tests in KerasCV.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha, thanks! No need on this PR, but in general it will be good to separate small checks (like dtype stuff) into fast running tests, and keep the large test only for the things that must inherently by large parameter count and slow (like preset tests).

Did a big rewrite of KerasNLP backbones to this effect a bit ago. e.g. https://github.com/keras-team/keras-nlp/blob/a05f411a27eab437e71a1651c97e9addf26298ef/keras_nlp/models/bert/bert_backbone_test.py#L38-L80

@tirthasheshpatel
Copy link
Contributor Author

@divyashreepathihalli This is ready from my side. Feel free to merge if everything looks good to you now!

@divyashreepathihalli divyashreepathihalli merged commit 37ffac0 into keras-team:master Dec 4, 2023
6 of 9 checks passed
@tirthasheshpatel tirthasheshpatel deleted the fix-sam-dtype branch December 4, 2023 23:30
sampathweb pushed a commit that referenced this pull request Dec 6, 2023
* Fix dtype support for SAM

* Update keras_cv/models/segmentation/segment_anything/sam_test.py

* Fix Keras 2 failures

* Fix F401 lint error; remove unused import
sampathweb added a commit that referenced this pull request Dec 6, 2023
* Fix Keras 3 version check (#2191)

* Fix Keras 3 version check

* Fix Keras 3 version check

* Fix Keras 3 version check

* Raise error if Keras is not compatible with TF

* Fix bug when upranking passthrough inputs to RandAugment (#2194)

- RandAugment sometimes will choose a "no augmentation" option and
  passthrough inputs unaltered.
- Preprocessing normalization routines were not making copies of inputs
  and sometimes mutating layer input directly (mutating the input
  dict to cast dtypes and uprank tensors).
- RandAugment under the passthrough option would return these inputs
  directly.

The net effect was sometimes attempting to uprank during a passthrough
call, breaking tf.map_fn

* fix stable diffusion rank error (#2208)

* Simplify running KerasCV with Keras 3 (#2179)

* remove keras_core dependency

* update init

* update readme

* fix model None error (#2176) (#2177)

* Update pycoco_callback.py

* Update waymo_evaluation_callback.py

* fix model None error (#2176) (#2178)

* Update pycoco_callback.py

* Update waymo_evaluation_callback.py

* update readme and conftest

* update readme

* update citation list

* fix mix transformer tests

* fix lint error

* fix all failing tests

* Fix dtype support for SegmentAnythingModel (#2207)

* Fix dtype support for SAM

* Update keras_cv/models/segmentation/segment_anything/sam_test.py

* Fix Keras 2 failures

* Fix F401 lint error; remove unused import

* Version bump to r0.7.2.dev0

---------

Co-authored-by: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
Co-authored-by: Divyashree Sreepathihalli <divyashreepathihalli@gmail.com>
Co-authored-by: Tirth Patel <tirthasheshpatel@gmail.com>
yuvraj-wale pushed a commit to yuvraj-wale/keras-cv that referenced this pull request Feb 8, 2024
* Fix dtype support for SAM

* Update keras_cv/models/segmentation/segment_anything/sam_test.py

* Fix Keras 2 failures

* Fix F401 lint error; remove unused import
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants