<a href="https://colab.research.google.com/github/haifeng-jin/keras-benchmarking/blob/main/batch_norm_op_jax_after.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

See details in [this pull request](https://github.com/keras-team/keras/pull/18793).

This notebook benchmarks the HEAD commit of the PR (`167724293105a12cfe62f9321cb72ef75f07b3b0`, checked out in the setup cell below).

Runtime: V100

In [1]:
# %%script false --no-raise-error
# ^ Presumably a toggle: uncommenting the first line turns the cell into a
#   no-op so the install steps can be skipped on re-runs — confirm before use.

# System package that provides Python venv support.
!apt install -qq python3-venv
# Python dependency installed ahead of the build (presumably required by
# keras/pip_build.py — confirm).
!pip install -q namex
# Fetch the fork that hosts the commit under test.
!git clone --quiet https://github.com/haifeng-jin/keras.git
# Pin the working tree to the exact commit being benchmarked. The `cd` only
# applies to this one shell line, which is why the next command uses the
# `keras/` prefix again.
!cd keras && git checkout 167724293105a12cfe62f9321cb72ef75f07b3b0
# Run the repository's build script with --install; its stdout is discarded.
!python keras/pip_build.py --install > /dev/null

The following additional packages will be installed:
  python3-pip-whl python3-setuptools-whl python3.10-venv
The following NEW packages will be installed:
  python3-pip-whl python3-setuptools-whl python3-venv python3.10-venv
0 upgraded, 4 newly installed, 0 to remove and 9 not upgraded.
Need to get 2,474 kB of archives.
After this operation, 2,890 kB of additional disk space will be used.
Selecting previously unselected package python3-pip-whl.
(Reading database ... 120880 files and directories currently installed.)
Preparing to unpack .../python3-pip-whl_22.0.2+dfsg-1ubuntu0.4_all.deb ...
Unpacking python3-pip-whl (22.0.2+dfsg-1ubuntu0.4) ...
Selecting previously unselected package python3-setuptools-whl.
Preparing to unpack .../python3-setuptools-whl_59.6.0-1.2ubuntu0.22.04.1_all.deb ...
Unpacking python3-setuptools-whl (59.6.0-1.2ubuntu0.22.04.1) ...
Selecting previously unselected package python3.10-venv.
Preparing to unpack .../python3.10-venv_3.10.12-1~22.04.2_amd64.deb ...
Unpa

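Some useful files (note: the path below refers to the `keras_core` torch backend, while this notebook uses the `keras` package with the JAX backend, so it may be stale):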

/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/torch/core.py

In [2]:
import os

# Select the backend through the environment; this cell runs before the cell
# that imports keras.
os.environ.update({"KERAS_BACKEND": "jax"})

In [3]:
import keras

# Show which file the imported `keras` package was loaded from.
keras_module_path = keras.__file__
print(keras_module_path)

/usr/local/lib/python3.10/dist-packages/keras/__init__.py


In [4]:
import cProfile
import pstats


def start_profile():
    """Create a cProfile profiler, start collecting, and return it.

    Returns:
        cProfile.Profile: the profiler, already enabled. Pass it to
        `end_profile` to stop collection and print the report.
    """
    profiler = cProfile.Profile()
    profiler.enable()
    # Hand the profiler back to the caller: without this the enabled profiler
    # is unreachable and can never be stopped or inspected.
    return profiler


def end_profile(profiler):
    """Stop `profiler` and print its statistics ordered by cumulative time.

    Args:
        profiler: a profiler object exposing `disable()` that
            `pstats.Stats` can consume (e.g. a `cProfile.Profile`).
    """
    profiler.disable()
    report = pstats.Stats(profiler)
    # "cumtime" orders entries by cumulative time spent in each function.
    report.sort_stats("cumtime")
    report.print_stats()

In [5]:
import time
import tensorflow as tf
import numpy as np


class BenchmarkMetricsCallback(keras.callbacks.Callback):
    """Measure batch throughput of `fit` / `predict` over a window of batches.

    Timing starts when batch `start_batch` begins and stops when batch
    `stop_batch` ends (both inclusive). Results are stored in `self.state`:

    * ``"benchmark_begin"`` / ``"benchmark_end"``: `time.perf_counter()`
      readings; only their difference is meaningful.
    * ``"throughput"``: batches per second over the measured window.

    Args:
        start_batch: index of the first batch included in the window.
            Defaults to 1, which leaves batch 0 out of the measurement.
        stop_batch: index of the last batch included in the window. When it
            is None (or never reached) no throughput is recorded.
    """

    def __init__(self, start_batch=1, stop_batch=None):
        # Let the base class set up its own attributes before adding ours.
        super().__init__()
        self.start_batch = start_batch
        self.stop_batch = stop_batch

        self.state = {}

    def _begin_window(self, batch):
        """Record the start reading if `batch` opens the measured window."""
        if batch == self.start_batch:
            # perf_counter is monotonic, so the interval is unaffected by
            # system clock adjustments during the run.
            self.state["benchmark_begin"] = time.perf_counter()

    def _end_window(self, batch):
        """Record the end reading and throughput if `batch` closes the window."""
        if batch == self.stop_batch:
            self.state["benchmark_end"] = time.perf_counter()
            # +1 because both start_batch and stop_batch are inside the window.
            measured_batches = self.stop_batch - self.start_batch + 1
            elapsed = self.state["benchmark_end"] - self.state["benchmark_begin"]
            self.state["throughput"] = measured_batches / elapsed

    def on_train_batch_begin(self, batch, logs=None):
        self._begin_window(batch)

    def on_train_batch_end(self, batch, logs=None):
        self._end_window(batch)

    def on_predict_batch_begin(self, batch, logs=None):
        self._begin_window(batch)

    def on_predict_batch_end(self, batch, logs=None):
        self._end_window(batch)


# --- Benchmark configuration -------------------------------------------------
batch_size = 128
SEED = 0  # fixed so every run feeds the model the same synthetic data
NUM_CLASSES = 1000  # labels are drawn from [0, NUM_CLASSES)
STOP_BATCH = 100  # last batch index included in the timed window
# Batch indices start at 0, so STOP_BATCH + 1 batches are needed for the
# callback to ever see batch STOP_BATCH.
NUM_BATCHES = STOP_BATCH + 1

model = keras.applications.resnet50.ResNet50()

# --- Synthetic data ----------------------------------------------------------
# One batch worth of random samples, repeated so the dataset yields exactly
# NUM_BATCHES batches of `batch_size` samples.
rng = np.random.default_rng(SEED)
images = rng.standard_normal((batch_size, 224, 224, 3))
labels = rng.integers(0, NUM_CLASSES, (batch_size,))
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).repeat(NUM_BATCHES)
dataset = dataset.batch(batch_size)

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer="adam",
)

# --- Training benchmark ------------------------------------------------------
# throughput is in steps/second, so 1000 / throughput is milliseconds/step.
callback = BenchmarkMetricsCallback(stop_batch=STOP_BATCH)
model.fit(dataset, epochs=1, callbacks=[callback])
print(f"training: {1000.0 / callback.state['throughput']:.0f} ms/step")

# --- Inference benchmark -----------------------------------------------------
# A fresh callback is used so the training measurements are not reused.
callback = BenchmarkMetricsCallback(stop_batch=STOP_BATCH)
model.predict(dataset, callbacks=[callback])
print(f"inferencing: {1000.0 / callback.state['throughput']:.0f} ms/step")

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels.h5
[1m102967424/102967424[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 0us/step
[1m101/101[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m142s[0m 942ms/step - loss: 0.4892
training: 456 ms/step
[1m101/101[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 172ms/step
inferencing: 171 ms/step
