Port DeepLabV3 segmentation model from keras_cv.models.legacy to the Task API #1831

Merged 70 commits into keras-team:master on Jul 12, 2023

Conversation

soumik12345 (Contributor) commented on May 27, 2023:

What does this PR do?

This PR moves the DeepLabV3 segmentation model from keras_cv.models.legacy and aligns it with the Task API.

A Kaggle TPU notebook demonstrating a few training steps: https://www.kaggle.com/code/soumikrakshit/deeplabv3-tpu/

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue? Please add a link
    to it if that's the case.
  • Did you write any new necessary tests?
  • If this adds a new model, can you run a few training steps on TPU in Colab to ensure that no XLA-incompatible ops are used?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@soumik12345 soumik12345 marked this pull request as draft May 27, 2023 07:50
@soumik12345 soumik12345 marked this pull request as ready for review May 28, 2023 23:43
soumik12345 (Contributor, Author) commented:

cc: @jbischof

@soumik12345 soumik12345 changed the title Port DeepLabV3 segmentation model from keras_cv.models.legacy Port DeepLabV3 segmentation model from keras_cv.models.legacy to Task API May 28, 2023
@soumik12345 soumik12345 changed the title Port DeepLabV3 segmentation model from keras_cv.models.legacy to Task API Port DeepLabV3 segmentation model from keras_cv.models.legacy to the Task API May 28, 2023
ianstenbit (Contributor) left a comment:

Thanks for the awesome PR!

Resolved review threads on:
  • keras_cv/models/segmentation/deeplab_v3/deeplab_v3.py
  • keras_cv/models/segmentation/__init__.py
  • keras_cv/layers/segmentation/segmentation_head.py
soumik12345 (Contributor, Author) commented:

Hi @ianstenbit
Made the changes requested and updated the notebook linked at #1831 (comment)

ianstenbit (Contributor) left a comment:

High-level question:

Should we just directly offer DeepLabV3Plus? I see from the papers that this is a relatively straightforward addition on top of DeepLabV3, and I see that you wrote a Keras tutorial on this: https://keras.io/examples/vision/deeplabv3_plus/

Unless there's a reason to offer both, I'd suggest that we update this to implement DeepLabV3Plus, and then we can verify the forward pass numerics against your KerasIO tutorial.
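
One way to run that kind of numerics check, as a minimal sketch (hypothetical helper, not part of this PR): put identical weights into both implementations and compare their predictions on the same input.

```python
import numpy as np
import tensorflow as tf


def assert_same_forward_pass(model_a, model_b, input_shape=(2, 512, 512, 3)):
    """Check that two implementations compute the same function.

    Assumes both models have structurally identical variables so that
    weights can be copied across with `set_weights`.
    """
    model_b.set_weights(model_a.get_weights())
    x = tf.random.uniform(input_shape)
    out_a = model_a(x, training=False).numpy()
    out_b = model_b(x, training=False).numpy()
    np.testing.assert_allclose(out_a, out_b, atol=1e-5)
```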

jbischof (Contributor) left a comment:

Great work!


Review thread on this snippet from the PR diff:

    inputs = backbone.input

    final_backbone_pyramid_output = backbone.get_layer(
A contributor commented:

Is it safe to assume that this dictionary is sorted? We may need to extract the max manually.

Also: why not use the output of the backbone as the max layer?
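
For illustration, a minimal sketch (the helper name is hypothetical) of extracting the highest level manually, assuming `pyramid_level_inputs` maps names like "P5" to backbone layer names as described in this PR's docstring:

```python
def highest_pyramid_feature(backbone):
    # Pick the largest "P<n>" key explicitly instead of relying on dict order.
    max_level = max(backbone.pyramid_level_inputs, key=lambda level: int(level[1:]))
    layer_name = backbone.pyramid_level_inputs[max_level]
    return backbone.get_layer(layer_name).output
```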

soumik12345 (Contributor, Author) replied:

This was done as per this request.

A contributor replied:

I think in general we probably can just use the backbone output here.

I brought this up originally because I wasn't clear on whether DeepLabV3Plus always wants to use the highest-available P-level here or if it always wants to use e.g. P4. It seems like it may want the highest P-level always, which I think should mean it always matches up with the backbone output.

jbischof (Contributor) commented on Jun 15, 2023:

I don't think the last P-level is exactly the same as the output, although it is the same shape in the width and height dimensions. There are often extra normalization and activations (and even extra conv layers!) afterward that we omit for 1:1 comparability with the previous levels (e.g., ResNetV2, MobileNet)

soumik12345 (Contributor, Author) replied:

@jbischof
I attempted to train with the output of the backbone as the features, and it seems that the outputs are significantly worse in that case, compared to the P5 features.

Here's a comparison: https://wandb.ai/geekyrakshit/deeplabv3-keras-cv/reports/Backbone-Outputs-vs-P5-Features--Vmlldzo0NjQ5NDUw

A contributor replied:

I'm having trouble reading that chart. Which color is P5? Which is backbone?

"backbone": self.backbone,
"spatial_pyramid_pooling": self.spatial_pyramid_pooling,
"projection_filters": self.projection_filters,
"segmentation_head": self.segmentation_head,
A contributor commented:

I still struggle with serialization. Do we need to explicitly serialize and deserialize a submodel like this for saving to work properly? I see a saving test but I know there are gotchas.
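
As a rough illustration of what explicit submodel (de)serialization looks like (hypothetical class, not the actual KerasCV code), nested models typically round-trip through `keras.utils.serialize_keras_object` / `deserialize_keras_object`:

```python
from tensorflow import keras


class MySegmenter(keras.Model):  # hypothetical stand-in for the Task model
    def __init__(self, backbone, projection_filters=48, **kwargs):
        super().__init__(**kwargs)
        self.backbone = backbone
        self.projection_filters = projection_filters

    def get_config(self):
        # Nested models are converted to plain config dicts for saving...
        return {
            "backbone": keras.utils.serialize_keras_object(self.backbone),
            "projection_filters": self.projection_filters,
        }

    @classmethod
    def from_config(cls, config):
        # ...and rebuilt from those dicts at load time.
        config["backbone"] = keras.utils.deserialize_keras_object(config["backbone"])
        return cls(**config)
```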

Review thread on this snippet from the PR diff:

    used as a feature extractor for the DeepLabV3+ Encoder. Should
    either be a `keras_cv.models.backbones.backbone.Backbone` or a
    `tf.keras.Model` that implements the `pyramid_level_inputs`
    property with keys "P2", "P3", "P4", and "P5" and layer names as
A contributor commented:

Do we need all five keys? It just seems like "P2" now
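
For context, a hedged sketch (the backbone choice is illustrative) of the two feature taps the DeepLabV3+ decoder and encoder pull from `pyramid_level_inputs`: a low-level feature such as "P2" for the decoder, and the highest level for the spatial pyramid pooling encoder.

```python
import keras_cv

backbone = keras_cv.models.ResNet50V2Backbone()
# Low-level feature for the decoder branch.
low_level = backbone.get_layer(backbone.pyramid_level_inputs["P2"]).output
# High-level feature feeding the spatial pyramid pooling (ASPP) encoder.
high_level = backbone.get_layer(backbone.pyramid_level_inputs["P5"]).output
```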

jbischof (Contributor) commented on Jun 15, 2023:

Do we have a reference implementation for this model where we can check the forward pass and maybe grab some pretrained weights?

Also: don't forget to set up presets!

Thanks for your hard work 🚀

ianstenbit (Contributor) commented:

@soumik12345 do you think you'll be able to finish up this PR soon?

If not I may push some changes to your branch so that we can get this finished -- thank you!

soumik12345 (Contributor, Author) commented:

> @soumik12345 do you think you'll be able to finish up this PR soon?
>
> If not I may push some changes to your branch so that we can get this finished -- thank you!

Hi @ianstenbit
There are probably no changes left to make from an architectural perspective. I can train the model on the Cityscapes fine-annotation set (without extra data) and add some presets.

Please let me know if there are any immediate changes that I can make.

ianstenbit (Contributor) commented:

> @soumik12345 do you think you'll be able to finish up this PR soon?
> If not I may push some changes to your branch so that we can get this finished -- thank you!
>
> Hi @ianstenbit There are probably no changes left to make from an architectural perspective. I can train the model on the Cityscapes fine-annotation set (without extra data) and add some presets.
>
> Please let me know if there are any immediate changes that I can make.

Yeah I think the architecture + tests are looking good -- just want to verify the performance by training it against a benchmark 😄

"pixel-level semantic labeling task, which consists of fine "
"annotations for train and val sets (3475 annotated images) "
"and 34 classes. This model achieves a final Mean IoU of "
"0.36 on the validation set for fine annotations."
A contributor commented:

How does this compare to the expected performance of this model on this dataset?

ianstenbit (Contributor) commented:

@soumik12345 we'd like to go ahead and merge this for now (without worrying about verifying the performance just yet, and without porting to Keras Core).

Would you be willing to:

  • Fix the lint error
  • Remove the preset for now

and then I can merge it?

I tried to push to your fork to make the changes myself but I don't have access.

soumik12345 (Contributor, Author) commented:

> @soumik12345 we'd like to go ahead and merge this for now (without worrying about verifying the performance just yet, and without porting to Keras Core).
>
> Would you be willing to:
>
>   • Fix the lint error
>   • Remove the preset for now
>
> and then I can merge it?
>
> I tried to push to your fork to make the changes myself but I don't have access.

Hi @ianstenbit, removed the presets and fixed linting.
Immense thanks to you and @jbischof for guiding me through this PR!

ianstenbit (Contributor) commented:

> @soumik12345 we'd like to go ahead and merge this for now (without worrying about verifying the performance just yet, and without porting to Keras Core).
> Would you be willing to:
>
>   • Fix the lint error
>   • Remove the preset for now
>
> and then I can merge it?
> I tried to push to your fork to make the changes myself but I don't have access.
>
> Hi @ianstenbit, removed the presets and fixed linting. Immense thanks to you and @jbischof for guiding me through this PR!

Wonderful, thank you for the contribution! If the CI is happy then I will go ahead and merge this 😄

ianstenbit (Contributor) commented:

Looks like the linter is still unhappy.

./keras_cv/models/segmentation/deeplab_v3_plus/deeplab_v3_plus.py:28:81: E501 line too long (127 > 80 characters)
./keras_cv/models/segmentation/deeplab_v3_plus/deeplab_v3_plus.py:30:81: E501 line too long (107 > 80 characters)
./keras_cv/models/segmentation/deeplab_v3_plus/deeplab_v3_plus.py:48:81: E501 line too long (132 > 80 characters)

You'll need to add # noqa: E501 to the end of the docstring for DeepLabV3Plus (See YOLOV8Detector as an example of how to do this)
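
Roughly what that looks like (the docstring body here is illustrative, not the actual DeepLabV3Plus docstring), with the noqa marker appended after the closing triple quotes as suggested:

```python
class DeepLabV3Plus:
    """DeepLabV3+ architecture for semantic segmentation.

    References:
        - [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611)
    """  # noqa: E501
```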

ianstenbit (Contributor) commented:

Keras Core (incl. accelerator) testing is not yet relevant for this, so this is good to go

@ianstenbit ianstenbit dismissed jbischof’s stale review July 12, 2023 21:17

will fix during port to KerasCore

@ianstenbit ianstenbit merged commit 3b97525 into keras-team:master Jul 12, 2023
6 of 9 checks passed
ghost pushed a commit to y-vectorfield/keras-cv that referenced this pull request Nov 16, 2023
Port DeepLabV3 segmentation model from keras_cv.models.legacy to the `Task` API (keras-team#1831)

* add: DeepLabV3 backbone + preset placeholders

* add: presents for DeepLabV3 backbone

* update: minor changes

* update: minor changes

* add: license info

* update: refactored DeepLabV3 as a single Task

* add: deeplab_v3 module inside segmentation module

* update: refactor segmentation head

* add: serialization test for SegmentationHead

* remove: legacy segmentation models dir

* remove: legacy segmentation imports

* update: added weight_decay in compile()

* add: docstring for deeplabv3

* update: remove compile() and train_step() from DeepLabV3

* update: imports

* update: copyright info

* remove: Segmentation Head layer implementation + serialization test

* update: refactor DeepLabV3 segmentation head as a keras Sequential

* update: made abstractions to implement encoder-decoder architecture

* update: _make_deeplabv3_encoder

* refactor: docstrings for DeepLabV3

* refactor: docstrings for DeepLabV3

* refactor: DeepLabV3

* update: refactored DeepLabV3 to implement DeepLabV3+

* update: docstring

* update: docstring

* update: removed segmentation_head_activation

* update: removed input_shape and input_tensor

* update: remove outdated comment

* update: docstring

* update: replaced low_level_feature_layer_name with low_level_feature_pyramid_level

* update: docastring

* update: DeepLabV3

* update: remove unused imports

* update: changed DeepLabV3 to DeepLabV3+

* update: remove build

* update: docstring

* update: add docstring for dropout

* update: made encoder output upsampling dynamic

* add: tests for DeepLabV3+

* update: order of parameters

* update: example in docstring

* update: renamed feature_map

* update: docs for dropout

* update: renamed tests

* update: update tests

* update: removed test_no_nans

* update: fix linting

* update: remove cleanup_global_session

* update: tests

* update: fix linting

* update: incorporated feedback

* update: incorporated feedback

* update: incorporated feedback

* add: preset for deeplabv3+

* update: remove presets

* fix: linting

* update: make CI happy