Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scripts/keras_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,8 @@
from sparseml.keras.utils import (
LossesAndMetricsLoggingCallback,
ModelExporter,
keras,
TensorBoardLogger,
keras,
)
from sparseml.utils import create_dirs

Expand Down
1 change: 1 addition & 0 deletions src/sparseml/keras/datasets/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@
# flake8: noqa

from .imagefolder import *
from .imagenet import *
from .imagenette import *
20 changes: 11 additions & 9 deletions src/sparseml/keras/datasets/classification/imagefolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
def imagenet_normalizer(img: tensorflow.Tensor, mode: str):
"""
Normalize an image using mean and std of the imagenet dataset

:param img: The input image to normalize
:param mode: either "tf", "caffe", "torch"
:return: The normalized image
Expand All @@ -73,7 +72,7 @@ def imagenet_normalizer(img: tensorflow.Tensor, mode: str):

def default_imagenet_normalizer():
def normalizer(img: tensorflow.Tensor):
# Default to the same preprocessing used by ResNet
# Default to the same preprocessing used by Keras Applications ResNet
return imagenet_normalizer(img, "caffe")

return normalizer
Expand Down Expand Up @@ -109,7 +108,7 @@ def __init__(
self,
root: str,
train: bool,
image_size: Union[int, Tuple[int, int]] = 224,
image_size: Union[None, int, Tuple[int, int]] = 224,
pre_resize_transforms: Union[SplitsTransforms, None] = SplitsTransforms(
train=(
random_scaling_crop(),
Expand All @@ -126,9 +125,14 @@ def __init__(
if not os.path.exists(self._root):
raise ValueError("Data set folder {} must exist".format(self._root))
self._train = train
self._image_size = (
image_size if isinstance(image_size, tuple) else (image_size, image_size)
)
if image_size is not None:
self._image_size = (
image_size
if isinstance(image_size, tuple)
else (image_size, image_size)
)
else:
self._image_size = None
self._pre_resize_transforms = pre_resize_transforms
self._post_resize_transforms = post_resize_transforms

Expand Down Expand Up @@ -199,7 +203,6 @@ def processor(self, file_path: tensorflow.Tensor, label: tensorflow.Tensor):
"""
img = tensorflow.io.read_file(file_path)
img = tensorflow.image.decode_jpeg(img, channels=3)

if self.pre_resize_transforms:
transforms = (
self.pre_resize_transforms.train
Expand All @@ -209,7 +212,7 @@ def processor(self, file_path: tensorflow.Tensor, label: tensorflow.Tensor):
if transforms:
for trans in transforms:
img = trans(img)
if self._image_size:
if self._image_size is not None:
img = tensorflow.image.resize(img, self.image_size)

if self.post_resize_transforms:
Expand All @@ -221,7 +224,6 @@ def processor(self, file_path: tensorflow.Tensor, label: tensorflow.Tensor):
if transforms:
for trans in transforms:
img = trans(img)

return img, label

def creator(self):
Expand Down
130 changes: 130 additions & 0 deletions src/sparseml/keras/datasets/classification/imagenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Imagenet dataset implementations for the image classification field in computer vision.
More info for the dataset can be found `here <http://www.image-net.org/>`__.
"""

import random
from typing import Tuple, Union

import tensorflow as tf

from sparseml.keras.datasets.classification import (
ImageFolderDataset,
SplitsTransforms,
imagenet_normalizer,
)
from sparseml.keras.datasets.helpers import random_scaling_crop
from sparseml.keras.datasets.registry import DatasetRegistry
from sparseml.keras.utils import keras
from sparseml.utils import clean_path
from sparseml.utils.datasets import (
IMAGENET_RGB_MEANS,
IMAGENET_RGB_STDS,
default_dataset_path,
)


__all__ = ["ImageNetDataset"]


def torch_imagenet_normalizer():
def normalizer(image: tf.Tensor):
return imagenet_normalizer(image, "torch")

return normalizer


def imagenet_pre_resize_processor():
def processor(image: tf.Tensor):
image_batch = tf.expand_dims(image, axis=0)

# Resize the image the following way to match torchvision's Resize
# transform used by Pytorch code path for Imagenet:
# torchvision.transforms.Resize(256)
# which resize the smaller side of images to 256 and the other one based
# on the aspect ratio
shape = tf.shape(image)
h, w = shape[0], shape[1]
if h > w:
new_h, new_w = tf.cast(256 * h / w, dtype=tf.uint16), tf.constant(
256, dtype=tf.uint16
)
else:
new_h, new_w = tf.constant(256, dtype=tf.uint16), tf.cast(
256 * w / h, dtype=tf.uint16
)
resizer = keras.layers.experimental.preprocessing.Resizing(new_h, new_w)
image_batch = tf.cast(resizer(image_batch), dtype=tf.uint8)

# Center crop
center_cropper = keras.layers.experimental.preprocessing.CenterCrop(224, 224)
image_batch = tf.cast(center_cropper(image_batch), dtype=tf.uint8)

return image_batch[0, :]

return processor


@DatasetRegistry.register(
key=["imagenet"],
attributes={
"num_classes": 1000,
"transform_means": IMAGENET_RGB_MEANS,
"transform_stds": IMAGENET_RGB_STDS,
},
)
class ImageNetDataset(ImageFolderDataset):
"""
Wrapper for the ImageNet dataset to apply standard transforms.

:param root: The root folder to find the dataset at
:param train: True if this is for the training distribution,
False for the validation
:param rand_trans: True to apply RandomCrop and RandomHorizontalFlip to the data,
False otherwise
:param image_size: the size of the image to output from the dataset
"""

def __init__(
self,
root: str = default_dataset_path("imagenet"),
train: bool = True,
rand_trans: bool = False,
image_size: Union[None, int, Tuple[int, int]] = 224,
pre_resize_transforms=SplitsTransforms(
train=(
random_scaling_crop(),
tf.image.random_flip_left_right,
),
val=(imagenet_pre_resize_processor(),),
),
post_resize_transforms=SplitsTransforms(
train=(torch_imagenet_normalizer(),), val=(torch_imagenet_normalizer(),)
),
):
root = clean_path(root)
super().__init__(
root,
train,
image_size=image_size,
pre_resize_transforms=pre_resize_transforms,
post_resize_transforms=post_resize_transforms,
)

if train:
# make sure we don't preserve the folder structure class order
random.shuffle(self.samples)
1 change: 0 additions & 1 deletion src/sparseml/keras/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def build(
) -> tensorflow.data.Dataset:
"""
Create the dataset in the current graph using tensorflow.data APIs
:param batch_size: the batch size to create the dataset for
:param repeat_count: the number of times to repeat the dataset,
if unset or None, will repeat indefinitely
Expand Down
1 change: 0 additions & 1 deletion src/sparseml/keras/datasets/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def random_scaling_crop(
"""
Random crop implementation which also randomly scales the crop taken
as well as the aspect ratio of the crop.
:param scale_range: the (min, max) of the crop scales to take from the orig image
:param ratio_range: the (min, max) of the aspect ratios to take from the orig image
:return: the callable function for random scaling crop op,
Expand Down
1 change: 1 addition & 0 deletions src/sparseml/keras/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@

# flake8: noqa

from .classification import *
from .external import *
from .registry import *
15 changes: 15 additions & 0 deletions src/sparseml/keras/models/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .resnet import *
Loading