Add the Segment Anything Model to KerasCV #1987

Merged: 15 commits, Sep 19, 2023

Conversation

@tirthasheshpatel (Contributor) commented Jul 28, 2023:

What does this PR do?

This PR implements the Segment Anything Model in multi-backend Keras.

Fixes #1679
See also #1933

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue? Please add a link
    to it if that's the case.
  • Did you write any new necessary tests?
  • If this adds a new model, can you run a few training steps on TPU in Colab to ensure that no XLA-incompatible ops are used?

Who can review?

@ianstenbit @DavidLandup0

@ianstenbit (Contributor) left a comment:

Thanks for the PR!

This is super exciting 🎊

Just a few comments as I took a quick look



@keras.utils.register_keras_serializable(package="keras_cv")
class MLPBlock(keras.layers.Layer):

Contributor:

It seems a bit heavy to make this a class since we can just make this a pair of dense layers in a sequential wherever it's used.

@tirthasheshpatel (Contributor, Author) commented Jul 31, 2023:

It is actually used a few times in the mask decoder, so inlining the dense layers would just duplicate a lot of code. Is there any side effect of having this? If not, I'd prefer keeping it.

Contributor:

If it's re-used in many places then I am alright with it -- it looked to me like it was only used once or twice but I probably missed some uses



@keras.utils.register_keras_serializable(package="keras_cv")
class SAMLayerNormalization(keras.layers.Layer):

Contributor:

Is there no way to parameterize keras.layers.LayerNormalization to achieve this?

Contributor (Author):

There is keras.layers.LayerNormalization(epsilon=1e-6). I will push this in the next batch of commits.

Contributor:

We should probably double-check this (don't take my word for it), but I'm not sure the numerics are the same. keras.layers.LayerNormalization does:

        # Compute the batch normalization.
        inv = 1 / ops.sqrt(variance + self.epsilon)
        if scale is not None:
            scale = ops.cast(scale, inputs.dtype)
            inv = inv * scale
        x = -mean * inv
        if offset is not None:
            offset = ops.cast(offset, inputs.dtype)
            x = offset + x
        outputs = inputs * ops.cast(inv, inputs.dtype) + ops.cast(
            x, inputs.dtype
        )

        outputs = ops.cast(outputs, input_dtype)

        # If some components of the shape got lost due to adjustments, fix that.
        outputs = ops.reshape(outputs, ops.shape(inputs))

For SAM, they call it LayerNorm2d() in the official implementation, but the official impl is taken directly from Detectron2, whose BatchNorm2D and LayerNorm are in turn taken from ConvNeXt: https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119

Technically, the LayerNorm re-implementation shared by ConvNeXt, SAM, and Detectron2 shouldn't be numerically identical to PyTorch's LayerNorm.
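
For intuition, here's a minimal NumPy sketch of what that LayerNorm2d variant computes on channels-last inputs (an illustration under my reading of the linked ConvNeXt code, not the KerasCV implementation): each spatial position is normalized over the channel axis only, then a per-channel scale and bias is applied.

import numpy as np

# Hypothetical standalone helper, not the KerasCV layer.
def layer_norm_2d(x, gamma, beta, epsilon=1e-6):
    # x: (batch, height, width, channels) in channels-last layout.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased variance, as in layer norm
    x_hat = (x - mean) / np.sqrt(var + epsilon)
    return gamma * x_hat + beta  # per-channel scale and bias

x = np.random.rand(1, 8, 8, 3).astype(np.float32)
out = layer_norm_2d(x, gamma=np.ones(3, np.float32), beta=np.zeros(3, np.float32))

On channels-last tensors this is exactly a last-axis layer normalization, which is what the equivalence test below ends up confirming.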

More in these issues:

@tirthasheshpatel (Contributor, Author) commented Jul 31, 2023:

Thanks for linking these issues, @DavidLandup0; I was also wondering why this was reimplemented. After some testing, I can confirm that keras.layers.LayerNormalization(epsilon=1e-6) is numerically equivalent to SAMLayerNormalization() for Segment Anything. Here's the code I used to test:

import os
os.environ["KERAS_BACKEND"] = "torch"

import numpy as np
import torch
import keras_core as keras
from keras_cv.models.segmentation.segment_anything import sam_layers

sam_ln = sam_layers.SAMLayerNormalization()
ln = keras.layers.LayerNormalization(epsilon=1e-6)

sam_ln.build((1, 512, 512, 3))
ln.build((1, 512, 512, 3))

# Share the same gamma/beta so any difference comes from the math alone.
sam_ln.set_weights(ln.weights)

x_np = np.random.randint(0, 256, size=(1, 512, 512, 3), dtype=np.uint8)
x_np = x_np.astype(np.float32)
x = torch.tensor(x_np, requires_grad=True)
x_sam = torch.tensor(x_np, requires_grad=True)

x_out_sam = sam_ln(x_sam)
x_out = ln(x)

# Backprop a tensor of ones so input and weight gradients can be compared too.
x_out_sam.backward(torch.ones_like(x_out_sam))
x_out.backward(torch.ones_like(x_out))

# Forward outputs match.
np.testing.assert_allclose(
    x_out_sam.detach().numpy(),
    x_out.detach().numpy(),
    rtol=8e-5
)
# Gradients w.r.t. gamma and beta match.
np.testing.assert_allclose(
    ln.weights[0].value.grad.detach().numpy(),
    sam_ln.weights[2].value.grad.detach().numpy(),
    rtol=6e-7
)
np.testing.assert_allclose(
    ln.weights[1].value.grad.detach().numpy(),
    sam_ln.weights[3].value.grad.detach().numpy()
)
# Gradients w.r.t. the input match.
np.testing.assert_allclose(
    x_sam.grad.detach().numpy(),
    x.grad.detach().numpy(),
    atol=3e-7
)

Contributor:

Awesome, thanks for checking! That simplifies things a lot :D

image_pe,
sparse_prompt_embeddings,
dense_prompt_embeddings,
multimask_output,

Contributor:

should this have a default?

@tirthasheshpatel (Contributor, Author) commented Jul 31, 2023:

I think it should. I haven't set any sensible defaults yet. I will update all the layers with some default values that make sense.

@@ -0,0 +1,13 @@
# Copyright 2023 The KerasCV Authors

Contributor:

I think we should probably have a top-level SegmentAnything model which takes an ImageEncoder as a backbone and subclasses Task. Then the high-level workflows can live on that model.

Then we can also include a preset which includes your ported weights!

Let's also include a reference to the paper and original implementation

Contributor:

The image encoder for SAM is a near 1:1 copy of Detectron2's ViTDet.
IMO it makes sense to have ViTDet as a standalone class/network rather than a SAM-only encoder.
That way we can train it from scratch, use it as a standalone object detection model or as a backbone for SAM, and reuse the same code across all of that.

Contributor (Author):

I agree @DavidLandup0, I will move the layer to keras_cv/layers.

@tirthasheshpatel (Contributor, Author) commented Jul 31, 2023:

> I think we should probably have a top-level SegmentAnything model which takes an ImageEncoder as a backbone and subclasses Task. Then the high-level workflows can live on that model.
>
> Then we can also include a preset which includes your ported weights!

That's the plan. I will add a Task model in the keras_cv/models/segmentation/segment_anything/sam.py file. I am not yet sure how exactly the training step would be implemented with the Task API; we could just raise a NotImplementedError for now and write a train step as a follow-up.
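
A minimal sketch of that deferral (class name and import path assumed, not the merged code):

from keras_cv.models.task import Task  # import path assumed

class SegmentAnythingModel(Task):
    """Task model tying together the image encoder, prompt encoder, and mask decoder."""

    def train_step(self, data):
        # Training support is deferred to a follow-up PR.
        raise NotImplementedError(
            "Training is not supported yet for the Segment Anything model."
        )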

@tirthasheshpatel (Contributor, Author) commented Aug 2, 2023:

Update: I have moved the image encoder to a standalone backbone. Let me know if that looks good to you. Thanks for the suggestion @DavidLandup0!

(I will add a Task model next)

@ianstenbit (Contributor) left a comment:

Great progress!

from keras_cv.models.segmentation.segment_anything.sam_layers import MLPBlock


def get_rel_pos(query_size, key_size, rel_pos):

Contributor:

Let's move the helper functions to the bottom of the file

Contributor (Author):

Addressed in ac7f30e

x_out = ops.convert_to_numpy(attention_with_rel_pe(x))
self.assertEqual(x_out.shape, (1, 64, 64, 1280))

def test_windowed_transformer_encoder(self):

Contributor:

Let's add a test for ViTDetPatchingAndEmbedding as well

Contributor (Author):

Addressed in ac7f30e

from keras_cv.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_cv.models")

Contributor:

Sorry that this changed while this PR is in-flight, but if you sync to master this should now be

from keras_cv.api_export import keras_cv_export

@keras_cv_export("keras_cv.models.ViTDetBackbone")

(Same for all new public API symbols)

Contributor (Author):

Addressed in ac7f30e


def __init__(
self,
img_size=1024,

Contributor:

We've standardized on input_shape for this in other backbones (accepting a tuple of (height, width, channels))

Contributor (Author):

I tried to add an input_shape argument, but since I am not using the Functional syntax to build the model, Keras doesn't let me add that attribute.

In case you didn't notice, I am using the call method to specify the computations in the model instead of passing a symbolic input through each layer/operation in the __init__ method.

I found it's just easier not to deal with symbolic inputs in Keras Core. One of the main reasons Keras Core struggles with symbolic inputs is that it doesn't do shape inference. For example, in Keras, this works:

import tensorflow as tf
from tensorflow import keras

x = keras.Input([2, 3])
tf.shape(x)[0] * 10  # Note that even though the shape at axis 0 is None,
                     # TensorFlow returns a symbolic tensor, making computations
                     # like these valid instead of throwing an exception.

but Keras Core fails, since x.shape[0] is None.

I think it should not be difficult to convert the implementation to fully use the Functional syntax, but we would have to manually check whether the shapes we get are None or not. So, for example, we could do this in Keras Core:

import keras_core
from keras_core import ops

x = keras_core.Input([2, 3])
if x.shape[0] is not None:
    x = ops.reshape(x, (x.shape[0] * 2, 3))
else:
    x = ops.reshape(x, (None, 3))

or use some other operation.
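
Another hedged option, assuming the non-batch dimensions are fully specified, is to let reshape infer the symbolic batch dimension with a -1 wildcard so no explicit None check is needed:

import keras_core
from keras_core import ops

x = keras_core.Input([2, 3])  # symbolic shape (None, 2, 3)
y = ops.reshape(x, (-1, 3))   # batch * 2 rows, inferred at run time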

Sorry for not highlighting these details properly beforehand! I will add a bunch of comments where I do something unintuitive so it's easier for future reviewers.

Contributor (Author):

OK, it turns out the shape errors were not a problem here, so I ported the model to the Functional syntax and it should now be consistent with the other backbones. I have added include_rescaling and input_tensor arguments along with the input_shape argument. I also verified that the weights port correctly and can be saved/loaded in any backend. Let me know if this resolves the consistency issues.



@keras.utils.register_keras_serializable(package="keras_cv.models")
class ViTDetBackbone(Backbone):

Contributor:

I assume that this backbone doesn't produce pyramid-level outputs since it's a transformer architecture -- let's call this out in the docstring, and maybe even create an @property for self.pyramid_level_inputs which throws a nice NotImplementedError
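
Something along these lines, with the base-class import path assumed (a sketch, not the merged code):

from keras_cv.models.backbones.backbone import Backbone  # path assumed

class ViTDetBackbone(Backbone):
    @property
    def pyramid_level_inputs(self):
        # A plain (non-hierarchical) ViT emits a single-scale feature map,
        # so there are no pyramid-level outputs to expose.
        raise NotImplementedError(
            "ViTDetBackbone does not produce pyramid-level features."
        )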

Contributor (Author):

Addressed in ac7f30e

Contributor:

The backbone itself outputs feature maps of a single shape, but the authors do build a feature pyramid on top of it: https://arxiv.org/pdf/2203.16527.pdf

TL;DR of the paper: the simple feature pyramid (on the right in the figure) turned out to be the most performant for them.
[figure from the ViTDet paper comparing FPN variants against the simple feature pyramid]

Contributor (Author):

Hi @DavidLandup0, thanks for pointing out the paper! I didn't know the authors also proposed an FPN! I did look into the paper but don't know how the pyramid-level inputs would fit into the backbone here. Given this PR has already blown up a bit, I'd prefer to do this as a follow-up. Maybe you can take it up if you have time :)

Contributor:

Doing this as a follow-up sgtm



@keras.utils.register_keras_serializable(package="keras_cv")
class MLP(keras.layers.Layer):

Contributor:

Is this different from MLPBlock?

Contributor (Author):

We could unify those. The only difference is that MLPBlock has the architecture embedding_dim -> mlp_dim -> embedding_dim, while MLP has input_dim -> [hidden_dim] * (num_layers - 1) -> output_dim. Looks like low-hanging fruit; I'll address it in the next commit.
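
A rough sketch of the unification (hypothetical helper, not the final layer):

import keras_core as keras

def build_mlp(hidden_dim, output_dim, num_layers, activation="relu"):
    # num_layers - 1 hidden Dense layers followed by an output projection.
    layers = [
        keras.layers.Dense(hidden_dim, activation=activation)
        for _ in range(num_layers - 1)
    ]
    layers.append(keras.layers.Dense(output_dim))
    return keras.Sequential(layers)

# MLPBlock(embedding_dim, mlp_dim) is then the two-layer special case:
# build_mlp(hidden_dim=mlp_dim, output_dim=embedding_dim, num_layers=2)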

Contributor (Author):

Done.

from keras_cv.tests.test_case import TestCase


class TestSAM(TestCase):

Contributor:

nit: SAMTest

Contributor (Author):

Done.

@ianstenbit (Contributor):

@tirthasheshpatel LMK when you're ready for another review on this 😄

- Use `keras_cv.export_api.keras_cv_export` instead of `keras.saving.register_keras_serializable`.
- Add a `SerializableSequential` class to address the saving bug with the `Sequential` model.
- Push the helper functions in `keras_cv/layers/detectron2_layers.py` to the bottom of the file.
- Add the detectron2 layers to the `keras_cv/layers/__init__.py` file.
- Add a test for the `ViTDetPatchingAndEmbedding` layer.

@tirthasheshpatel (Contributor, Author):

Hi @ianstenbit, thank you very much for your reviews so far! Very helpful!

> LMK when you're ready for another review on this

I have a Task API for the SAM model ready; I just have to make some more changes and add tests before pushing it. I will mark the PR ready for review once that's done. But if you have time, feel free to review!

@@ -17,6 +17,10 @@
from tensorflow.keras.layers import RandomWidth

from keras_cv.layers.augmenter import Augmenter
from keras_cv.layers.detectron2_layers import AddPositionalEmbedding

Contributor:

Do we want to put these under a detectron2 namespace?

Contributor (Author):

Done.

@@ -17,6 +17,10 @@
from tensorflow.keras.layers import RandomWidth

from keras_cv.layers.augmenter import Augmenter
from keras_cv.layers.detectron2_layers import AddPositionalEmbedding

Contributor:

Since this would be exported as part of the public API: we have PatchingAndEmbedding, which does patching with a Conv2D and then adds embeddings in this same form. Do we want to update that layer to use AddPositionalEmbedding as well, for consistency?

Contributor (Author):

Wouldn't adding a new layer in a pre-existing class invalidate the weights set for the ViT model? Also, since PatchingAndEmbedding is still a TensorFlow Keras layer, I think, for the time being, it'd be easier to keep the two separate.

Though I'd add a TODO comment about this so we don't forget to do it in the future. What do you think?

Contributor (Author):

Done.



@keras_cv_export("keras_cv.layers.MultiHeadAttentionWithRelativePE")
class MultiHeadAttentionWithRelativePE(keras.layers.Layer):

Contributor:

Perhaps it would make sense to do an AddRelativePositionalEmbedding class for consistency with the aforementioned AddPositionalEmbedding?

Contributor (Author):

Done.

)

if self.use_rel_pos:
attention_map = add_decomposed_rel_pos(

Contributor:

This should probably be a private method as part of a layer

Contributor (Author):

Done.

if self.window_size > 0:
H, W = x.shape[1], x.shape[2]

x, HW_padded = window_partition(x, self.window_size)

Contributor:

What do you think about doing this as a layer instead of a method?

I.e. https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/window_partitioning.py

Contributor (Author):

Done. Instead of creating two classes, one for partitioning and one for unpartitioning, I handled both in a single class. Let me know if that looks good.
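
For reference, a NumPy sketch of the partition/unpartition round trip that such a layer wraps (shape logic only; hypothetical helpers, not the merged layer):

import numpy as np

def window_partition(x, window_size):
    # (B, H, W, C) -> (B * num_windows, window_size, window_size, C),
    # padding H and W up to multiples of window_size first.
    B, H, W, C = x.shape
    pad_h, pad_w = (-H) % window_size, (-W) % window_size
    x = np.pad(x, ((0, 0), (0, pad_h), (0, pad_w), (0, 0)))
    Hp, Wp = H + pad_h, W + pad_w
    x = x.reshape(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
    return windows, (Hp, Wp)

def window_unpartition(windows, window_size, padded_hw, hw):
    # Inverse of window_partition; crops away the padding at the end.
    (Hp, Wp), (H, W) = padded_hw, hw
    B = windows.shape[0] // ((Hp // window_size) * (Wp // window_size))
    x = windows.reshape(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.transpose(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
    return x[:, :H, :W, :]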



@keras_cv_export("keras_cv.layers.ViTDetPatchingAndEmbedding")
class ViTDetPatchingAndEmbedding(keras.layers.Layer):

Contributor:

This is the same as the ViT patching and embedding, but without the positional embedding.
I'm torn between adding a flag to turn off the PE in the default layer and having a new layer for this...

Contributor (Author):

Are you referring to the PatchingAndEmbedding class for the ViT model in KerasCV? I addressed that here: #1987 (comment)

return config


def get_rel_pos(query_size, key_size, rel_pos):

Contributor:

This should probably be a private method or turned into a public layer.
I.e.: https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/decomposed_relative_positional_embedding.py

Contributor (Author):

Done.

return ops.take(rel_pos_resized, relative_coordinates, 0)


def add_decomposed_rel_pos(

Contributor:

Same comment as above

Contributor (Author):

Done.

# This only happens when the `build` method is called in the `__init__`
# step.
@keras_cv_export("keras_cv.layers.SerializableSequential")
class SerializableSequential(keras.layers.Layer):

Contributor:

Is this still an issue in Keras Core?

@tirthasheshpatel (Contributor, Author) commented Aug 24, 2023:

The bug has been addressed in Keras Core v0.1.5, but the latest TensorFlow Keras release still has it. So, weights won't load in TensorFlow Keras until the fix lands in its next release.

We can either:

  1. Temporarily drop TensorFlow Keras support just for this model, with a note in the docs. Once the next TF Keras release fixes the bug, we can remove the note.
  2. Keep the simple replication of the class until the bug is resolved in TensorFlow Keras.

I am leaning more towards option 2 but I don't have a strong opinion. What do you think @ianstenbit @DavidLandup0?

@tirthasheshpatel (Contributor, Author) commented Sep 15, 2023:

I ended up removing it since the legacy weights load in all backends in Keras Core and also in TF Keras. I think until some saving issues are addressed with the new .weights.h5 format, we should just use the legacy weights. Let me know what you both think!

@@ -43,6 +43,18 @@
from keras_cv.models.backbones.densenet.densenet_backbone import (
DenseNetBackbone,
)
from keras_cv.models.backbones.detectron2.detectron2_aliases import (

Contributor:

Afaik, the backbone is basically the same as the official ViTDet, so there may not be a need to call it a SAM{name}Backbone

Contributor (Author):

Sounds good, thanks for looking into it!

Contributor (Author):

Done.

@@ -166,5 +178,8 @@
YOLOV8Detector,
)
from keras_cv.models.segmentation import DeepLabV3Plus
from keras_cv.models.segmentation import MaskDecoder

Contributor:

May be better as SAMMaskDecoder for clarity

Contributor (Author):

Done.

Contributor:

Probably left by accident?

Contributor (Author):

I have added this intentionally. This is used in the tests to verify that the model weights are loaded correctly and that the forward pass in all backends yields the same result.

"""Dictionary of preset names and configurations."""
return copy.deepcopy(backbone_presets)

# @classproperty

Contributor:

stray comment?

Contributor (Author):

This method loads the presets with weights. I will uncomment it later once the model layers are finalized and the final weights are uploaded.



@keras_cv_export("keras_cv.layers.MLP")
class MLP(keras.layers.Layer):

Contributor:

This is a public class here -- it should probably be private, especially since there was an MLP with the same name in a related layer, iirc

Contributor (Author):

Done. I have removed the export.

@@ -0,0 +1,230 @@
# Copyright 2023 The KerasCV Authors

Contributor:

I'd probably put this layer under sam layers

Contributor (Author):

Done.


@keras_cv_export("keras_cv.models.MaskDecoder")
class MaskDecoder(keras.models.Model):
"""Mask decoder for the segment anything model.

Contributor:

Nit: "Segment Anything (SAM)"

Contributor (Author):

Done.



@keras_cv_export("keras_cv.models.MaskDecoder")
class MaskDecoder(keras.models.Model):

Contributor:

As mentioned before, to avoid confusion, probably best if this is called SAMMaskDecoder or something along those lines

Contributor (Author):

Done.

network. Defaults to "gelu".

References:
- [Segment Anything](https://arxiv.org/abs/2304.02643)

Contributor:

We probably want a code reference as well

@tirthasheshpatel marked this pull request as ready for review on September 9, 2023 23:13
@tirthasheshpatel changed the title from "[WIP] Add the Segment Anything Model to KerasCV" to "Add the Segment Anything Model to KerasCV" on Sep 9, 2023

@tirthasheshpatel (Contributor, Author):

I think the PR is almost ready for some thorough reviews except for a few TODOs:

  1. I ported the weights to the Keras Core model and can load them in any backend, but weight loading is broken between Keras Core and TensorFlow Keras (xref keras-core#855: Saving broken between tf.keras and Keras Core)
  2. I need to add more docs (especially examples) for the internal layers that are exported.
  3. Need to upload weights and add them as presets here. I think we need to first finalize a few blockers (point 1 above and this discussion) before uploading the final weights set.

Let me know if you have any other major points @DavidLandup0 @ianstenbit. And thanks for the reviews so far, super helpful!

@tirthasheshpatel (Contributor, Author) commented Sep 14, 2023:

> 1. I ported the weights to the Keras Core model and am able to load in any backend but loading weights is broken between Keras Core and TensorFlow Keras (xref keras-core#855)

An update on this: the legacy weights *.h5 load in both TF Keras and Keras Core (all backends). I think we can just use that until the saving fixes are available in TF Keras. Also, SerializableSequential is no longer needed when using legacy weights.

@ianstenbit (Contributor) left a comment:

The only thing left that I see is adding a preset for the pre-trained version of SAM.

Thanks for your great work!



@keras.utils.register_keras_serializable(package="keras_cv.models")
class ViTDetBackbone(Backbone):

Contributor:

Doing this as a follow-up sgtm

@ianstenbit (Contributor):

/gcbrun

@ianstenbit (Contributor) left a comment:

This is awesome -- thanks Tirth!

Just one little fix to make GCBRun happy

keras_cv/models/segmentation/segment_anything/sam_test.py (outdated; resolved)

@ianstenbit (Contributor):

/gcbrun

@ianstenbit merged commit bc80fbb into keras-team:master on Sep 19, 2023
8 of 9 checks passed

@tirthasheshpatel (Contributor, Author):

Thanks, @DavidLandup0 @ianstenbit for your reviews! This was fun to work on. Excited to have this in KerasCV!

The next steps are to add some guides to use and train the model. It would also be nice to have some benchmarks. On it now! But I will also create a tracking issue in case the community wants to take over some of these tasks.

@ianstenbit (Contributor):

> Thanks, @DavidLandup0 @ianstenbit for your reviews! This was fun to work on. Excited to have this in KerasCV!
>
> The next steps are to add some guides to use and train the model. It would also be nice to have some benchmarks. On it now! But I will also create a tracking issue in case the community wants to take over some of these tasks.

Thank you Tirth for your outstanding work on this -- we really appreciate it!

I think our long-term goal should be to add support for text prompts. There are some community projects out there which demonstrate the feasibility of this, and I think it would be a great step for us.

But I 100% agree that some guides and training are the right place to start!

@tirthasheshpatel deleted the add-sam branch September 19, 2023 03:33
@tirthasheshpatel restored the add-sam branch September 19, 2023 06:29
@tirthasheshpatel deleted the add-sam branch September 19, 2023 06:29
ghost pushed a commit to y-vectorfield/keras-cv that referenced this pull request Nov 16, 2023
* Start adding components for the segment anything model

* SAMLayerNormalization -> keras.layers.LayerNormalization

They both behave exactly the same when moving_mean and moving_variance are None and epsilon is 1e-6

* Move the image encoder to detectron2 backbone and fix for tf.keras backend

* Address review comments and address saving bug

- Use `keras_cv.export_api.keras_cv_export` instead of `keras.saving.register_keras_serializable`.
- Add a `SerializableSequential` class to address the saving bug with the `Sequential` model.
- Push the helper functions in `keras_cv/layers/detectron2_layers.py` to the bottom of the file.
- Add the detectron2 layers to the `keras_cv/layers/__init__.py` file.
- Add a test for the `ViTDetPatchingAndEmbedding` layer.

* Make the backbone functional; unite MLP and MLPBlock

* Address David's review comments

* Add SAM Task model; make MaskDecoder and PromptEncoder XLA compatible

* Remove a stray file

* Add docs for the Task model

* Add more references

[skip ci]

* Remove SerializableSequential layer

* detectron2 -> vit_det; add SAM presets; fix ViTDet presets

* Increase test tolerance for GCB Run
yuvraj-wale pushed a commit to yuvraj-wale/keras-cv that referenced this pull request Feb 8, 2024