In [1]:
# Imports
import keras
import keras_hub
import math
import numpy as np
import PIL
import timm
import torch
import torch.nn as nn

In [2]:
# Define presets
PRESET_MAP = {
    "enet_b0_ra": "timm/efficientnet_b0.ra_in1k",
    "enet_b1_ft": "timm/efficientnet_b1.ft_in1k",
    "enet_b1_pruned": "timm/efficientnet_b1_pruned.in1k",
    "enet_b2_ra": "timm/efficientnet_b2.ra_in1k",
    "enet_b2_pruned": "timm/efficientnet_b2_pruned.in1k",
    "enet_b3_ra2": "timm/efficientnet_b3.ra2_in1k",
    "enet_b3_pruned": "timm/efficientnet_b3_pruned.in1k",
    "enet_b4_ra2": "timm/efficientnet_b4.ra2_in1k",
    "enet_b5_sw": "timm/efficientnet_b5.sw_in12k",
    "enet_b5_sw_ft": "timm/efficientnet_b5.sw_in12k_ft_in1k",
    "enet_el_ra": "timm/efficientnet_el.ra_in1k",
    "enet_el_pruned": "timm/efficientnet_el_pruned.in1k",
    "enet_em_ra2": "timm/efficientnet_em.ra2_in1k",
    "enet_es_ra": "timm/efficientnet_es.ra_in1k",
    "enet_es_pruned": "timm/efficientnet_es_pruned.in1k",
    "enet_b0_ra4_e3600_r224": "timm/efficientnet_b0.ra4_e3600_r224_in1k",
    "enet_b1_ra4_e3600_r240": "timm/efficientnet_b1.ra4_e3600_r240_in1k",
    "enet2_rw_m_agc": "timm/efficientnetv2_rw_m.agc_in1k",
    "enet2_rw_s_ra2": "timm/efficientnetv2_rw_s.ra2_in1k",
    "enet2_rw_t_ra2": "timm/efficientnetv2_rw_t.ra2_in1k",
}

In [3]:
stackwise_kernel_sizes = [3, 3, 5, 3, 5, 5, 3]
stackwise_num_repeats = [1, 2, 2, 3, 3, 4, 1]
stackwise_input_filters = [32, 16, 24, 40, 80, 112, 192]
stackwise_output_filters = [16, 24, 40, 80, 112, 192, 320]
stackwise_expansion_ratios = [1, 6, 6, 6, 6, 6, 6]
stackwise_strides = [1, 2, 2, 2, 1, 2, 1]
stackwise_squeeze_and_excite_ratios = [0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25]

In [4]:
# Convenience functions
def channel_first_to_last(x):
    return keras.ops.transpose(x, axes=(0, 2, 3, 1))

def channel_last_to_first(x):
    return keras.ops.transpose(x, axes=(0, 3, 1, 2))

def keras_compute(x, layer, sublayer_names):
    cur_res = False
    for name in sublayer_names:
        sub_layer = layer.get_layer(name)
        if name.endswith("se_squeeze"):
            se = sub_layer(x)
        elif name.endswith("se_reshape") or name.endswith("se_reduce") or name.endswith("se_expand"):
            se = sub_layer(se)
        elif name.endswith("se_excite"):
            x = sub_layer([x, se])
        elif name.endswith("add"):
            x = sub_layer([x, res])
            cur_res = False
        # For potential residual computations
        elif name.endswith("expand_conv"):
            res = x
            cur_res = True
            x = sub_layer(x)
        elif name.endswith("dwconv_pad"):
            if not cur_res:
                res = x
                cur_res = True
            x = sub_layer(x)
        else:
            x = sub_layer(x)
    return x

def pt_compute(x, modules):
    for module in modules:
        x = module(x)
    return x

def compare_tensors(keras_tensor, timm_tensor, atol=1e-8):
    return np.allclose(keras_tensor, channel_first_to_last(timm_tensor.detach().numpy()), atol)

def compare_conv2D_kernels(keras_conv2D, pt_conv2D, atol=1e-8):
    return np.allclose(np.transpose(pt_conv2D.weight.detach().numpy(), (2, 3, 1, 0)), keras_conv2D.get_weights()[0], atol)

def produce_layer_names(stack, block, expand_ratio, se_ratio, strides, filters_in, filters_out, dropout):
    letter_identifier = chr(block + 97)
    block_prefix = f"block{stack+1}{letter_identifier}_"
    out = []
    
    if expand_ratio != 1:
        out.append(block_prefix + "expand_conv")
        out.append(block_prefix + "expand_bn")
        out.append(block_prefix + "expand_activation")
        
    out.append(block_prefix + "dwconv_pad")
    out.append(block_prefix + "dwconv")
    out.append(block_prefix + "dwconv_bn")
    out.append(block_prefix + "dwconv_activation")

    if 0 < se_ratio <= 1:
        out.append(block_prefix + "se_squeeze")
        out.append(block_prefix + "se_reshape")
        out.append(block_prefix + "se_reduce")
        out.append(block_prefix + "se_expand")
        out.append(block_prefix + "se_excite")

    out.append(block_prefix + "project")
    out.append(block_prefix + "project_bn")
    out.append(block_prefix + "project_activation")

    if strides == 1 and filters_in == filters_out:
        if dropout > 0:
            out.append(block_prefix + "drop")
        out.append(block_prefix + "add")

    return out


In [5]:
def round_filters(
    filters,
    width_coefficient,
    min_depth,
    depth_divisor,
    use_depth_divisor_as_min_depth,
    cap_round_filter_decrease,
):
    """Round number of filters based on depth multiplier.

    Args:
        filters: int, number of filters for Conv layer
        width_coefficient: float, denotes the scaling coefficient of network
            width
        depth_divisor: int, a unit of network width
        use_depth_divisor_as_min_depth: bool, whether to use depth_divisor as
            the minimum depth instead of min_depth (as per v1)
        max_round_filter_decrease: bool, whether to cap the decrease in the
            number of filters this process produces (as per v1)

    Returns:
        int, new rounded filters value for Conv layer
    """
    filters *= width_coefficient

    if use_depth_divisor_as_min_depth:
        min_depth = depth_divisor

    new_filters = max(
        min_depth,
        int(filters + depth_divisor / 2) // depth_divisor * depth_divisor,
    )

    if cap_round_filter_decrease:
        # Make sure that round down does not go down by more than 10%.
        if new_filters < 0.9 * filters:
            new_filters += depth_divisor

    return int(new_filters)

def round_repeats(repeats, depth_coefficient):
    """Round number of repeats based on depth multiplier.

    Args:
        repeats: int, number of repeats of efficientnet block
        depth_coefficient: float, denotes the scaling coefficient of network
            depth

    Returns:
        int, rounded repeats
    """
    return int(math.ceil(depth_coefficient * repeats))


width_coefficient = 1.0
depth_coefficient = 1.0
min_depth = None,
depth_divisor = 8
use_depth_divisor_as_min_depth = True
cap_round_filter_decrease = True


for i in range(len(stackwise_kernel_sizes)):
    num_repeats = stackwise_num_repeats[i]
    input_filters = stackwise_input_filters[i]
    output_filters = stackwise_output_filters[i]

    input_filters = round_filters(
        filters=input_filters,
        width_coefficient=width_coefficient,
        min_depth=min_depth,
        depth_divisor=depth_divisor,
        use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth,
        cap_round_filter_decrease=cap_round_filter_decrease,
    )
    output_filters = round_filters(
        filters=output_filters,
        width_coefficient=width_coefficient,
        min_depth=min_depth,
        depth_divisor=depth_divisor,
        use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth,
        cap_round_filter_decrease=cap_round_filter_decrease,
    )

    repeats = round_repeats(
        repeats=num_repeats,
        depth_coefficient=depth_coefficient,
    )
    strides = stackwise_strides[i]
    squeeze_and_excite_ratio = stackwise_squeeze_and_excite_ratios[i]

    for j in range(repeats):
        
        if j > 0:
            strides = 1
            input_filters = output_filters

        print(produce_layer_names(i, j, stackwise_expansion_ratios[i], stackwise_squeeze_and_excite_ratios[i],
                            strides, input_filters, output_filters, 0))

['block1a_dwconv_pad', 'block1a_dwconv', 'block1a_dwconv_bn', 'block1a_dwconv_activation', 'block1a_se_squeeze', 'block1a_se_reshape', 'block1a_se_reduce', 'block1a_se_expand', 'block1a_se_excite', 'block1a_project', 'block1a_project_bn', 'block1a_project_activation']
['block2a_expand_conv', 'block2a_expand_bn', 'block2a_expand_activation', 'block2a_dwconv_pad', 'block2a_dwconv', 'block2a_dwconv_bn', 'block2a_dwconv_activation', 'block2a_se_squeeze', 'block2a_se_reshape', 'block2a_se_reduce', 'block2a_se_expand', 'block2a_se_excite', 'block2a_project', 'block2a_project_bn', 'block2a_project_activation']
['block2b_expand_conv', 'block2b_expand_bn', 'block2b_expand_activation', 'block2b_dwconv_pad', 'block2b_dwconv', 'block2b_dwconv_bn', 'block2b_dwconv_activation', 'block2b_se_squeeze', 'block2b_se_reshape', 'block2b_se_reduce', 'block2b_se_expand', 'block2b_se_excite', 'block2b_project', 'block2b_project_bn', 'block2b_project_activation', 'block2b_add']
['block3a_expand_conv', 'block3a

In [6]:
# Create timm model & convert to keras_hub model
timm_name = PRESET_MAP["enet_b0_ra"]
timm_model = timm.create_model(timm_name, pretrained=True)
timm_model = timm_model.eval()
keras_model = keras_hub.models.ImageClassifier.from_preset(
    "hf://" + timm_name,
)

In [None]:
"https://storage.googleapis.com/keras-cv/"
        "models/paligemma/cow_beach_1.png"
"https://github.com/pytorch/hub/raw/master/images/dog.jpg"
"https://i.imgur.com/mtbl1cr.jpeg"

In [8]:
# Load example image & Preprocess
file = keras.utils.get_file(
        origin=(
            "https://storage.googleapis.com/keras-cv/"
            "models/paligemma/cow_beach_1.png"
        )
    )
image = PIL.Image.open(file)
batch = np.array([image])

data_config = timm.data.resolve_model_data_config(timm_model)
data_config["crop_pct"] = 1.0  # Stop timm from cropping.
transforms = timm.data.create_transform(**data_config, is_training=False)
timm_preprocessed = transforms(image)
timm_preprocessed = keras.ops.transpose(timm_preprocessed, axes=(1, 2, 0))
timm_preprocessed = keras.ops.expand_dims(timm_preprocessed, 0)

# Preprocess with Keras.
batch = keras.ops.cast(batch, "float32")
keras_preprocessed = keras_model.preprocessor(batch)

# Call with Timm. Use the keras preprocessed image so we can keep modeling
# and preprocessing comparisons independent.
timm_batch = keras.ops.transpose(keras_preprocessed, axes=(0, 3, 1, 2))
timm_batch = torch.from_numpy(np.array(timm_batch))
timm_outputs = timm_model(timm_batch).detach().numpy()
timm_label = np.argmax(timm_outputs[0])

# Call with Keras.
keras_outputs = keras_model.predict(batch)
keras_label = np.argmax(keras_outputs[0])

print("🔶 Keras output:", keras_outputs[0, :10])
print("🔶 TIMM output:", timm_outputs[0, :10])
print("🔶 Keras label:", keras_label)
print("🔶 TIMM label:", timm_label)
modeling_diff = np.mean(np.abs(keras_outputs - timm_outputs))
print("🔶 Modeling difference:", modeling_diff)
preprocessing_diff = np.mean(np.abs(keras_preprocessed - timm_preprocessed))
print("🔶 Preprocessing difference:", preprocessing_diff)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 658ms/step
🔶 Keras output: [-0.00153291  0.01410532 -0.09853543  0.19247465  0.2943132  -0.71479553
  0.41454506  0.74578226  0.8629832   0.8470797 ]
🔶 TIMM output: [-0.3285119  -0.07352898 -0.81079614  1.2667956   0.49385935  0.13255954
  0.78108     0.25159395  0.11376506 -1.6937063 ]
🔶 Keras label: 313
🔶 TIMM label: 345
🔶 Modeling difference: 0.7089012
🔶 Preprocessing difference: 1.1485839e-07


In [13]:
keras_model.summary()

TypeError: 'NoneType' object is not iterable

In [15]:
sum(p.numel() for p in timm_model.parameters())

5288548

In [18]:
# Disection time
# Take the input tensors and only go through the first layers
keras_inter_tensor = keras_compute(keras_preprocessed, keras_model.backbone, ["stem_conv_pad", "stem_conv"])
timm_inter_tensor = pt_compute(timm_batch, [timm_model.conv_stem])

In [23]:
np.testing.assert_allclose(keras_inter_tensor, channel_first_to_last(timm_inter_tensor.detach().numpy()), atol=1e-6)

AssertionError: 
Not equal to tolerance rtol=1e-07, atol=1e-06

Mismatched elements: 4133 / 401408 (1.03%)
Max absolute difference: 6.198883e-06
Max relative difference: 0.08169387
 x: array([[[[-1.568060e+00, -2.120042e+00, -3.686755e-01, ...,
           7.063465e-01,  7.547081e-02, -5.708362e+00],
         [-1.712825e+00, -8.447712e-03, -5.333350e-01, ...,...
 y: array([[[[-1.568060e+00, -2.120042e+00, -3.686756e-01, ...,
           7.063465e-01,  7.547063e-02, -5.708363e+00],
         [-1.712824e+00, -8.447831e-03, -5.333350e-01, ...,...

In [9]:
# Example if you want to grab a different intermediate tensor
keras_inter_tensor = keras_compute(batch, keras_model.backbone, ["stem_conv_pad", "stem_conv", "stem_bn", "stem_activation",
    "block1a_dwconv_pad", "block1a_dwconv", "block1a_dwconv_bn", "block1a_dwconv_activation", "block1a_se_squeeze", "block1a_se_reshape",
    "block1a_se_reduce", "block1a_se_expand", "block1a_se_excite", "block1a_project", "block1a_project_bn", "block1a_project_activation"])
timm_inter_tensor = pt_compute(timm_batch, [timm_model.conv_stem, timm_model.bn1, timm_model.blocks[0][0]])