Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion keras_hub/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageConverter
from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter
from keras_hub.src.models.whisper.whisper_audio_converter import (
WhisperAudioConverter,
)
2 changes: 1 addition & 1 deletion keras_hub/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@
from keras_hub.src.models.text_to_image import TextToImage
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier
from keras_hub.src.models.vgg.vgg_image_classifier import (
from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import (
VGGImageClassifierPreprocessor,
)
from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone
Expand Down
4 changes: 4 additions & 0 deletions keras_hub/src/models/vgg/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
from keras_hub.src.models.vgg.vgg_presets import backbone_presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, VGGBackbone)
19 changes: 4 additions & 15 deletions keras_hub/src/models/vgg/vgg_image_classifier.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,12 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.models.image_classifier import ImageClassifier
from keras_hub.src.models.image_classifier_preprocessor import (
ImageClassifierPreprocessor,
)
from keras_hub.src.models.task import Task
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone


@keras_hub_export("keras_hub.layers.VGGImageConverter")
class VGGImageConverter(ImageConverter):
backbone_cls = VGGBackbone


@keras_hub_export("keras_hub.models.VGGImageClassifierPreprocessor")
class VGGImageClassifierPreprocessor(ImageClassifierPreprocessor):
backbone_cls = VGGBackbone
image_converter_cls = VGGImageConverter
from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import (
VGGImageClassifierPreprocessor,
)


@keras_hub_export("keras_hub.models.VGGImageClassifier")
Expand Down Expand Up @@ -211,6 +199,7 @@ def __init__(
self.pooling = pooling
self.pooling_hidden_dim = pooling_hidden_dim
self.dropout = dropout
self.preprocessor = preprocessor

def get_config(self):
# Backbone serialized in `super`
Expand Down
12 changes: 12 additions & 0 deletions keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.image_classifier_preprocessor import (
ImageClassifierPreprocessor,
)
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter


@keras_hub_export("keras_hub.models.VGGImageClassifierPreprocessor")
class VGGImageClassifierPreprocessor(ImageClassifierPreprocessor):
backbone_cls = VGGBackbone
image_converter_cls = VGGImageConverter
14 changes: 10 additions & 4 deletions keras_hub/src/models/vgg/vgg_image_classifier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,40 @@

from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier
from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import (
VGGImageClassifierPreprocessor,
)
from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter
from keras_hub.src.tests.test_case import TestCase


class VGGImageClassifierTest(TestCase):
def setUp(self):
# Setup model.
self.images = np.ones((2, 8, 8, 3), dtype="float32")
self.labels = [0, 3]
self.labels = [0, 1]
self.backbone = VGGBackbone(
stackwise_num_repeats=[2, 4, 4],
stackwise_num_filters=[2, 16, 16],
image_shape=(8, 8, 3),
)
image_converter = VGGImageConverter(image_size=(8, 8))
self.preprocessor = VGGImageClassifierPreprocessor(
image_converter=image_converter,
)
self.init_kwargs = {
"backbone": self.backbone,
"num_classes": 2,
"activation": "softmax",
"pooling": "flatten",
"preprocessor": self.preprocessor,
}
self.train_data = (
self.images,
self.labels,
)

def test_classifier_basics(self):
pytest.skip(
reason="TODO: enable after preprocessor flow is figured out"
)
self.run_task_test(
cls=VGGImageClassifier,
init_kwargs=self.init_kwargs,
Expand Down
8 changes: 8 additions & 0 deletions keras_hub/src/models/vgg/vgg_image_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone


@keras_hub_export("keras_hub.layers.VGGImageConverter")
class VGGImageConverter(ImageConverter):
backbone_cls = VGGBackbone
Loading