-
Notifications
You must be signed in to change notification settings - Fork 301
Add BASNet to keras hub #1984
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add BASNet to keras hub #1984
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
efd1042
adding initial basnet files
laxmareddyp 9c3d8d8
run api_gen.sh
laxmareddyp 2f5dce1
Merge branch 'keras-team:master' into laxma_basnet
laxmareddyp e51a891
Addressing Matt minor changes except backbone class changes
laxmareddyp 4d5dc6c
Fixing lint errors and run api_gen
laxmareddyp dc9619c
Merge branch 'keras-team:master' into laxma_basnet
laxmareddyp f018bbf
separate backbone, add compute_loss
laxmareddyp 756e8a1
reverting unwanted changes to branch
laxmareddyp 8e4163a
separate backbone, compute_loss, fix tests
laxmareddyp b2e9127
Merge branch 'keras-team:master' into laxma_basnet
laxmareddyp ac961c5
Fix format issues, removed presets file
laxmareddyp 2ccfde9
adding deleted presets file in previous commit
laxmareddyp 7665914
Fix format issues, removed presets, newline issues
laxmareddyp a32e2f3
Fix for docstrings
laxmareddyp 7963b47
Add basnet conversion script
laxmareddyp ffc8469
Fix format issue in conversion script
laxmareddyp c64d589
Fix format issue
laxmareddyp 3233586
Fix pytorch GPU test,removed weight conversion script
laxmareddyp File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone | ||
from keras_hub.src.models.basnet.basnet_presets import basnet_presets | ||
from keras_hub.src.utils.preset_utils import register_presets | ||
|
||
register_presets(basnet_presets, BASNetBackbone) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import keras | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone | ||
from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor | ||
from keras_hub.src.models.image_segmenter import ImageSegmenter | ||
|
||
|
||
@keras_hub_export("keras_hub.models.BASNetImageSegmenter") | ||
class BASNetImageSegmenter(ImageSegmenter): | ||
"""BASNet image segmentation task. | ||
|
||
Args: | ||
backbone: A `keras_hub.models.BASNetBackbone` instance. | ||
preprocessor: `None`, a `keras_hub.models.Preprocessor` instance, | ||
a `keras.Layer` instance, or a callable. If `None` no preprocessing | ||
will be applied to the inputs. | ||
|
||
Example: | ||
```python | ||
import keras_hub | ||
|
||
images = np.ones(shape=(1, 288, 288, 3)) | ||
labels = np.zeros(shape=(1, 288, 288, 1)) | ||
|
||
image_encoder = keras_hub.models.ResNetBackbone.from_preset( | ||
"resnet_18_imagenet", | ||
load_weights=False | ||
) | ||
backbone = keras_hub.models.BASNetBackbone( | ||
image_encoder, | ||
num_classes=1, | ||
image_shape=[288, 288, 3] | ||
) | ||
model = keras_hub.models.BASNetImageSegmenter(backbone) | ||
|
||
# Evaluate the model | ||
pred_labels = model(images) | ||
|
||
# Train the model | ||
model.compile( | ||
optimizer="adam", | ||
loss=keras.losses.BinaryCrossentropy(from_logits=False), | ||
metrics=["accuracy"], | ||
) | ||
model.fit(images, labels, epochs=3) | ||
``` | ||
""" | ||
|
||
backbone_cls = BASNetBackbone | ||
preprocessor_cls = BASNetPreprocessor | ||
|
||
def __init__( | ||
self, | ||
backbone, | ||
preprocessor=None, | ||
**kwargs, | ||
): | ||
# === Functional Model === | ||
x = backbone.input | ||
outputs = backbone(x) | ||
# only return the refinement module's output as final prediction | ||
outputs = outputs["refine_out"] | ||
super().__init__(inputs=x, outputs=outputs, **kwargs) | ||
|
||
# === Config === | ||
self.backbone = backbone | ||
self.preprocessor = preprocessor | ||
|
||
def compute_loss(self, x, y, y_pred, *args, **kwargs): | ||
# train BASNet's prediction and refinement module outputs against the | ||
# same ground truth data | ||
outputs = self.backbone(x) | ||
losses = [] | ||
for output in outputs.values(): | ||
losses.append(super().compute_loss(x, y, output, *args, **kwargs)) | ||
return keras.ops.sum(losses, axis=0) | ||
|
||
def compile( | ||
self, | ||
optimizer="auto", | ||
loss="auto", | ||
metrics="auto", | ||
**kwargs, | ||
): | ||
"""Configures the `BASNet` task for training. | ||
|
||
`BASNet` extends the default compilation signature | ||
of `keras.Model.compile` with defaults for `optimizer` and `loss`. To | ||
override these defaults, pass any value to these arguments during | ||
compilation. | ||
|
||
Args: | ||
optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` | ||
instance. Defaults to `"auto"`, which uses the default | ||
optimizer for `BASNet`. See `keras.Model.compile` and | ||
`keras.optimizers` for more info on possible `optimizer` | ||
values. | ||
loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. | ||
Defaults to `"auto"`, in which case the default loss | ||
computation of `BASNet` will be applied. | ||
See `keras.Model.compile` and `keras.losses` for more info on | ||
possible `loss` values. | ||
metrics: `"auto"`, or a list of metrics to be evaluated by | ||
the model during training and testing. Defaults to `"auto"`, | ||
where a `keras.metrics.Accuracy` will be applied to track the | ||
accuracy of the model during training. | ||
See `keras.Model.compile` and `keras.metrics` for | ||
more info on possible `metrics` values. | ||
**kwargs: See `keras.Model.compile` for a full list of arguments | ||
supported by the compile method. | ||
""" | ||
if loss == "auto": | ||
loss = keras.losses.BinaryCrossentropy() | ||
if metrics == "auto": | ||
metrics = [keras.metrics.Accuracy()] | ||
super().compile( | ||
optimizer=optimizer, | ||
loss=loss, | ||
metrics=metrics, | ||
**kwargs, | ||
) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.