In [1]:
# Imports
import keras
import keras_hub
import math
import numpy as np
import PIL
import timm
import torch
import torch.nn as nn

In [2]:
# Define presets
PRESET_MAP = {
    "enet_b0_ra": "timm/efficientnet_b0.ra_in1k",
    "enet_b1_ft": "timm/efficientnet_b1.ft_in1k",
    "enet_b1_pruned": "timm/efficientnet_b1_pruned.in1k",
    "enet_b2_ra": "timm/efficientnet_b2.ra_in1k",
    "enet_b2_pruned": "timm/efficientnet_b2_pruned.in1k",
    "enet_b3_ra2": "timm/efficientnet_b3.ra2_in1k",
    "enet_b3_pruned": "timm/efficientnet_b3_pruned.in1k",
    "enet_b4_ra2": "timm/efficientnet_b4.ra2_in1k",
    "enet_b5_sw": "timm/efficientnet_b5.sw_in12k",
    "enet_b5_sw_ft": "timm/efficientnet_b5.sw_in12k_ft_in1k",
    "enet_el_ra": "timm/efficientnet_el.ra_in1k",
    "enet_el_pruned": "timm/efficientnet_el_pruned.in1k",
    "enet_em_ra2": "timm/efficientnet_em.ra2_in1k",
    "enet_es_ra": "timm/efficientnet_es.ra_in1k",
    "enet_es_pruned": "timm/efficientnet_es_pruned.in1k",
    "enet_b0_ra4_e3600_r224": "timm/efficientnet_b0.ra4_e3600_r224_in1k",
    "enet_b1_ra4_e3600_r240": "timm/efficientnet_b1.ra4_e3600_r240_in1k",
    "enet2_rw_m_agc": "timm/efficientnetv2_rw_m.agc_in1k",
    "enet2_rw_s_ra2": "timm/efficientnetv2_rw_s.ra2_in1k",
    "enet2_rw_t_ra2": "timm/efficientnetv2_rw_t.ra2_in1k",
}

In [3]:
# Convenience functions
def channel_first_to_last(x):
    return keras.ops.transpose(x, axes=(0, 2, 3, 1))

def channel_last_to_first(x):
    return keras.ops.transpose(x, axes=(0, 3, 1, 2))

def keras_compute(x, layer, sublayer_names):
    for name in sublayer_names:
        x = layer.get_layer(name)(x)
    return x

def pt_compute(x, modules):
    for module in modules:
        x = module(x)
    return x

def compare_tensors(keras_tensor, timm_tensor, atol=1e-8):
    return np.allclose(keras_tensor, channel_first_to_last(timm_tensor.detach().numpy()), atol)

def compare_conv2D_kernels(keras_conv2D, pt_conv2D, atol=1e-8):
    return np.allclose(np.transpose(pt_conv2D.weight.detach().numpy(), (2, 3, 1, 0)), keras_conv2D.get_weights()[0], atol)

In [4]:
# Create timm model & convert to keras_hub model
timm_name = PRESET_MAP["enet_b0_ra"]
timm_model = timm.create_model(timm_name, pretrained=True)
timm_model = timm_model.eval()
keras_model = keras_hub.models.ImageClassifier.from_preset(
    "hf://" + timm_name,
)

In [5]:
# Load example image & Preprocess
file = keras.utils.get_file(
    origin=(
        "https://storage.googleapis.com/keras-cv/"
        "models/paligemma/cow_beach_1.png"
    )
)
image = PIL.Image.open(file)
batch = np.array([image])

# Preprocessing.
timm_batch = keras_model.preprocessor(batch)
timm_batch = keras.ops.transpose(timm_batch, axes=(0, 3, 1, 2))
timm_batch = keras.ops.cast(timm_batch, dtype="float32")
timm_batch = torch.from_numpy(np.array(timm_batch))

batch = keras.ops.cast(batch, dtype="float32")

In [6]:
# Check input tensor is the same
np.allclose(batch, channel_first_to_last(timm_batch))

True

In [7]:
# Inference full forward pass
timm_outputs = timm_model(timm_batch).detach().numpy()
timm_label = np.argmax(timm_outputs[0])

keras_outputs = keras_model.predict(batch)
keras_label = np.argmax(keras_outputs[0])

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 719ms/step


In [8]:
# Results
print("🔶 Keras output:", keras_outputs[0, :10])
print("🔶 TIMM output:", timm_outputs[0, :10])
print("🔶 Difference:", np.mean(np.abs(keras_outputs - timm_outputs)))
print("🔶 Keras label:", keras_label)
print("🔶 TIMM label:", timm_label)

🔶 Keras output: [ 1620.1848   2247.4805  -7034.3467  -3782.0728  -9877.553    1429.7053
  2474.2737   9480.087    8330.119     782.66644]
🔶 TIMM output: [ 754.8181  5853.823   3473.4465  4661.4194  2204.8643  1050.2249
 -368.87488  461.20947 -236.12665 -935.4804 ]
🔶 Difference: 4179.8276
🔶 Keras label: 512
🔶 TIMM label: 391


In [9]:
# Disection time
# Take the input tensors and only go through the first layers
keras_inter_tensor = keras_compute(batch, keras_model.backbone, ["stem_conv_pad", "stem_conv"])
timm_inter_tensor = pt_compute(timm_batch, [timm_model.conv_stem])

In [10]:
# Are the kernels the same?
compare_conv2D_kernels(keras_model.backbone.get_layer("stem_conv"), timm_model.conv_stem)

True

In [11]:
# Are the intermediate tensors the same?
compare_tensors(keras_inter_tensor, timm_inter_tensor)

ValueError: operands could not be broadcast together with shapes (1,113,113,32) (1,112,112,32) 

In [12]:
# different shape but close... let's look at them
keras_inter_tensor

<tf.Tensor: shape=(1, 113, 113, 32), dtype=float32, numpy=
array([[[[-6.90355652e+02, -9.22706909e+02, -2.55646973e+02, ...,
           3.65845520e+02, -2.06023769e+01, -2.87541412e+02],
         [-7.27479492e+02, -1.38209457e+01, -2.59368225e+02, ...,
           3.78148956e+02,  1.32786036e+00, -3.44223114e+02],
         [-7.28762451e+02, -1.35537834e+01, -2.60155396e+02, ...,
           3.78620239e+02,  4.03245783e+00, -3.46049500e+02],
         ...,
         [-5.42898560e+02, -9.79613781e+00, -1.91869919e+02, ...,
           2.74462280e+02, -3.88873196e+00, -4.65800751e+02],
         [-5.41980225e+02, -1.18839827e+01, -1.91713654e+02, ...,
           2.74291321e+02,  4.37592506e-01, -4.65262207e+02],
         [-2.51835957e+01,  6.79348877e+02, -5.76185751e+00, ...,
           7.05634069e+00,  2.23309612e+01, -8.57044830e+01]],

        [[-3.22593069e+00, -9.86635864e+02, -9.67449093e+00, ...,
           5.15525696e+02, -2.16028042e+01, -3.41197510e+02],
         [-2.59661555e+00,  4

In [15]:
# and the timm one
channel_first_to_last(timm_inter_tensor.detach().numpy())

<tf.Tensor: shape=(1, 112, 112, 32), dtype=float32, numpy=
array([[[[-6.90355591e+02, -9.22706848e+02, -2.55646988e+02, ...,
           3.65845551e+02, -2.06023769e+01, -2.87541412e+02],
         [-7.27479553e+02, -1.38209381e+01, -2.59368225e+02, ...,
           3.78149017e+02,  1.32786036e+00, -3.44223114e+02],
         [-7.28762573e+02, -1.35537720e+01, -2.60155396e+02, ...,
           3.78620270e+02,  4.03246546e+00, -3.46049469e+02],
         ...,
         [-5.49251343e+02, -1.55145330e+01, -1.94512436e+02, ...,
           2.77907166e+02, -2.19626427e-01, -4.65813965e+02],
         [-5.42898621e+02, -9.79614162e+00, -1.91869934e+02, ...,
           2.74462250e+02, -3.88874722e+00, -4.65800751e+02],
         [-5.41980225e+02, -1.18840094e+01, -1.91713654e+02, ...,
           2.74291321e+02,  4.37561989e-01, -4.65262207e+02]],

        [[-3.22582960e+00, -9.86635864e+02, -9.67446423e+00, ...,
           5.15525757e+02, -2.16027431e+01, -3.41197510e+02],
         [-2.59651923e+00,  4

In [15]:
# Looks the same but different shape... likely due to conversion from PT Conv2D parameters not aligning with Keras Conv2D parameters
# Specifically padding = (1, 1) in PT does not align exactly with padding = "same" or "valid" with the manual "stem_conv_pad" added.
# There should be a general form of padding = (x, x) in PT -> Keras parameters for padding and potentially a manual zero pad, depending on
# parameters (like stride, input shape and the padding values) <-- may not be exhaustive but these are the likely culprits.

False