In [None]:
import tensorflow as tf

In [None]:
tf.keras.mixed_precision.set_global_policy('mixed_float16')

backbone = tf.keras.applications.ResNet50(include_top=True, weights='imagenet', input_shape=(224, 224, 3))
backbone.summary()

In [3]:
model = tf.keras.Model(
    inputs=backbone.input,
    outputs={
        "logits": backbone.output,
        "embeddings": backbone.get_layer("avg_pool").output,
    },
)

In [24]:
class Classifier(tf.Module):

    def __init__(self, model: tf.keras.Model, input_size: tuple[int, int], jit_compile: bool = False):
        super(Classifier, self).__init__()
        self.model = model
        self.jit_compile = jit_compile
        self.input_size = input_size

    def get_model_predict_fn(self):
        if not self.jit_compile:
            return self.model
        return tf.function(
            self.model,
            jit_compile=True,
            input_signature=[tf.TensorSpec([None, *self.input_size, 3], tf.float32)],
        )

    def postprocess(self, predictions: dict[str, tf.Tensor]) -> dict[str, tf.Tensor]:

        logits = tf.cast(predictions["logits"], tf.float32)
        probs = tf.nn.softmax(logits, axis=-1)
        scores = tf.reduce_max(probs, axis=-1)
        classes = tf.argmax(probs, axis=-1)

        outputs = {
            "scores": scores,
            "classes": classes,
            "embeddings": tf.cast(predictions["embeddings"], tf.float32),
        }
        return outputs

    @tf.function(input_signature=[tf.TensorSpec([None, None, None, 3], tf.uint8)])
    def predict_images(self, images: tf.Tensor) -> dict[str, tf.Tensor]:
        images = tf.cast(images, tf.float32) / 255.0
        images = tf.image.resize(images, self.input_size)
        predictor = self.get_model_predict_fn()
        predictions = predictor(images)
        return self.postprocess(predictions)

    @tf.function(input_signature=[tf.TensorSpec([], tf.string)])
    def predict_jpeg(self, jpeg_image: tf.Tensor) -> dict[str, tf.Tensor]:
        images = tf.image.decode_jpeg(jpeg_image, channels=3)
        images = tf.expand_dims(images, axis=0)
        return self.predict_images(images)

    def export(self, save_dir: str) -> str:
        signatures = {
            "predict_images": self.predict_images.get_concrete_function(),
            "predict_jpeg": self.predict_jpeg.get_concrete_function(),
        }
        return tf.saved_model.save(self, save_dir, signatures=signatures)

module = Classifier(model, (224, 224))

In [20]:
images = tf.random.uniform((1, 224, 224, 3), minval=0, maxval=255, dtype=tf.int32)
images = tf.cast(images, tf.uint8)
module.predict_images(images)

{'scores': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.001121], dtype=float32)>,
 'classes': <tf.Tensor: shape=(1,), dtype=int64, numpy=array([111])>,
 'embeddings': <tf.Tensor: shape=(1, 2048), dtype=float32, numpy=
 array([[0.       , 0.       , 0.       , ..., 1.5585938, 0.       ,
         0.       ]], dtype=float32)>}

In [21]:
%timeit module.predict_images(images)

2.62 ms ± 261 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [22]:
jpeg_image = tf.image.encode_jpeg(images[0])
module.predict_jpeg(jpeg_image)

{'scores': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.00111377], dtype=float32)>,
 'classes': <tf.Tensor: shape=(1,), dtype=int64, numpy=array([111])>,
 'embeddings': <tf.Tensor: shape=(1, 2048), dtype=float32, numpy=
 array([[0.       , 0.       , 0.       , ..., 1.5390625, 0.       ,
         0.       ]], dtype=float32)>}

In [23]:
%timeit module.predict_jpeg(jpeg_image)

2.67 ms ± 231 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [26]:
def build_and_export_classifier(
    backbone_class: tf.keras.Model,
    save_dir: str,
    jit_compile: bool = False,
    mixed_precision: bool = False,
    input_size: tuple[int, int] = (224, 224),
):
    if mixed_precision:
        tf.keras.mixed_precision.set_global_policy('mixed_float16')
    else:
        tf.keras.mixed_precision.set_global_policy('float32')

    backbone = backbone_class(
        include_top=True, weights="imagenet", input_shape=(*input_size, 3)
    )
    model = tf.keras.Model(
        inputs=backbone.input,
        outputs={
            "logits": backbone.output,
            "embeddings": backbone.get_layer("avg_pool").output,
        },
    )
    module = Classifier(model, input_size, jit_compile=jit_compile)
    return module.export(save_dir)

In [28]:
build_and_export_classifier(tf.keras.applications.ResNet50, "../data/resnet50-xla-amp/1", jit_compile=True, mixed_precision=True)
build_and_export_classifier(tf.keras.applications.ResNet50, "../data/resnet50-no-opt/1", jit_compile=False, mixed_precision=False)

INFO:tensorflow:Assets written to: ../data/resnet50-xla-amp/1/assets


INFO:tensorflow:Assets written to: ../data/resnet50-xla-amp/1/assets


INFO:tensorflow:Assets written to: ../data/resnet50-no-opt/1/assets


INFO:tensorflow:Assets written to: ../data/resnet50-no-opt/1/assets


In [34]:
build_and_export_classifier(tf.keras.applications.EfficientNetB0, "../data/efficientnetb0-xla-amp/1", jit_compile=True, mixed_precision=True)
build_and_export_classifier(tf.keras.applications.EfficientNetB0, "../data/efficientnetb0-no-opt/1", jit_compile=False, mixed_precision=False)

INFO:tensorflow:Assets written to: ../data/efficientnetb0-xla-amp/1/assets


INFO:tensorflow:Assets written to: ../data/efficientnetb0-xla-amp/1/assets


INFO:tensorflow:Assets written to: ../data/efficientnetb0-no-opt/1/assets


INFO:tensorflow:Assets written to: ../data/efficientnetb0-no-opt/1/assets


In [None]:
import tensorflow as tf
import ops

In [3]:
models = [
    "../data/resnet50-no-opt/1",
    "../data/resnet50-xla-amp/1",
    "../data/efficientnetb0-no-opt/1",
    "../data/efficientnetb0-xla-amp/1",
]


for model in models:
    print(f"Loading model: {model}")
    predictor = tf.saved_model.load(model)
    # warmup
    images = tf.random.uniform((100, 224, 224, 3), minval=0, maxval=255, dtype=tf.int32)
    images = tf.cast(images, tf.uint8)
    predictor.predict_images(images)
    # benchmark
    images = tf.random.uniform((100, 224, 224, 3), minval=0, maxval=255, dtype=tf.int32)
    images = tf.cast(images, tf.uint8)

    def predict_fn(_):
        predictor.predict_images(images)

    _ = ops.thread_imap1(predict_fn, list(range(500)), num_workers=1, desc=model)

Loading model: ../data/resnet50-no-opt/1


../data/resnet50-no-opt/1: 100%|██████████| 500/500 [00:57<00:00,  8.66it/s]


Loading model: ../data/resnet50-xla-amp/1


2024-04-07 08:19:40.445034: W tensorflow/core/grappler/costs/op_level_cost_estimator.cc:693] Error in PredictCost() for the op: op: "Softmax" attr { key: "T" value { type: DT_FLOAT } } inputs { dtype: DT_FLOAT shape { unknown_rank: true } } device { type: "GPU" vendor: "NVIDIA" model: "NVIDIA RTX A4000 Laptop GPU" frequency: 1680 num_cores: 40 environment { key: "architecture" value: "8.6" } environment { key: "cuda" value: "11080" } environment { key: "cudnn" value: "8600" } num_registers: 65536 l1_cache_size: 24576 l2_cache_size: 4194304 shared_memory_size_per_multiprocessor: 102400 memory_size: 5849612288 bandwidth: 384064000 } outputs { dtype: DT_FLOAT shape { unknown_rank: true } }
../data/resnet50-xla-amp/1: 100%|██████████| 500/500 [00:20<00:00, 24.49it/s]


Loading model: ../data/efficientnetb0-no-opt/1


../data/efficientnetb0-no-opt/1: 100%|██████████| 500/500 [00:52<00:00,  9.55it/s]


Loading model: ../data/efficientnetb0-xla-amp/1


2024-04-07 08:21:03.270982: W tensorflow/core/grappler/costs/op_level_cost_estimator.cc:693] Error in PredictCost() for the op: op: "Softmax" attr { key: "T" value { type: DT_FLOAT } } inputs { dtype: DT_FLOAT shape { unknown_rank: true } } device { type: "GPU" vendor: "NVIDIA" model: "NVIDIA RTX A4000 Laptop GPU" frequency: 1680 num_cores: 40 environment { key: "architecture" value: "8.6" } environment { key: "cuda" value: "11080" } environment { key: "cudnn" value: "8600" } num_registers: 65536 l1_cache_size: 24576 l2_cache_size: 4194304 shared_memory_size_per_multiprocessor: 102400 memory_size: 5849612288 bandwidth: 384064000 } outputs { dtype: DT_FLOAT shape { unknown_rank: true } }
../data/efficientnetb0-xla-amp/1: 100%|██████████| 500/500 [00:13<00:00, 38.08it/s]
