In [1]:
# Imports
import keras
import keras_hub
import math
import numpy as np
import PIL
import timm
import torch
import torch.nn as nn

In [2]:
# Define presets
PRESET_MAP = {
    "enet_b0_ra": "timm/efficientnet_b0.ra_in1k",
    "enet_b1_ft": "timm/efficientnet_b1.ft_in1k",
    "enet_b1_pruned": "timm/efficientnet_b1_pruned.in1k",
    "enet_b2_ra": "timm/efficientnet_b2.ra_in1k",
    "enet_b2_pruned": "timm/efficientnet_b2_pruned.in1k",
    "enet_b3_ra2": "timm/efficientnet_b3.ra2_in1k",
    "enet_b3_pruned": "timm/efficientnet_b3_pruned.in1k",
    "enet_b4_ra2": "timm/efficientnet_b4.ra2_in1k",
    "enet_b5_sw": "timm/efficientnet_b5.sw_in12k",
    "enet_b5_sw_ft": "timm/efficientnet_b5.sw_in12k_ft_in1k",
    "enet_el_ra": "timm/efficientnet_el.ra_in1k",
    "enet_el_pruned": "timm/efficientnet_el_pruned.in1k",
    "enet_em_ra2": "timm/efficientnet_em.ra2_in1k",
    "enet_es_ra": "timm/efficientnet_es.ra_in1k",
    "enet_es_pruned": "timm/efficientnet_es_pruned.in1k",
    "enet_b0_ra4_e3600_r224": "timm/efficientnet_b0.ra4_e3600_r224_in1k",
    "enet_b1_ra4_e3600_r240": "timm/efficientnet_b1.ra4_e3600_r240_in1k",
    "enet2_rw_m_agc": "timm/efficientnetv2_rw_m.agc_in1k",
    "enet2_rw_s_ra2": "timm/efficientnetv2_rw_s.ra2_in1k",
    "enet2_rw_t_ra2": "timm/efficientnetv2_rw_t.ra2_in1k",
}

In [3]:
stackwise_kernel_sizes = [3, 3, 5, 3, 5, 5, 3]
stackwise_num_repeats = [1, 2, 2, 3, 3, 4, 1]
stackwise_input_filters = [32, 16, 24, 40, 80, 112, 192]
stackwise_output_filters = [16, 24, 40, 80, 112, 192, 320]
stackwise_expansion_ratios = [1, 6, 6, 6, 6, 6, 6]
stackwise_strides = [1, 2, 2, 2, 1, 2, 1]
stackwise_squeeze_and_excite_ratios = [0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25]

In [4]:
# Convenience functions
def channel_first_to_last(x):
    return keras.ops.transpose(x, axes=(0, 2, 3, 1))

def channel_last_to_first(x):
    return keras.ops.transpose(x, axes=(0, 3, 1, 2))

def keras_compute(x, layer, sublayer_names):
    cur_res = False
    for name in sublayer_names:
        sub_layer = layer.get_layer(name)
        if name.endswith("se_squeeze"):
            se = sub_layer(x)
        elif name.endswith("se_reshape") or name.endswith("se_reduce") or name.endswith("se_expand"):
            se = sub_layer(se)
        elif name.endswith("se_excite"):
            x = sub_layer([x, se])
        elif name.endswith("add"):
            x = sub_layer([x, res])
            cur_res = False
        # For potential residual computations
        elif name.endswith("expand_conv"):
            res = x
            cur_res = True
            x = sub_layer(x)
        elif name.endswith("dwconv_pad"):
            if not cur_res:
                res = x
                cur_res = True
            x = sub_layer(x)
        else:
            x = sub_layer(x)
    return x

def pt_compute(x, modules):
    for module in modules:
        x = module(x)
    return x

def compare_tensors(keras_tensor, timm_tensor, atol=1e-8):
    return np.allclose(keras_tensor, channel_first_to_last(timm_tensor.detach().numpy()), atol)

def compare_conv2D_kernels(keras_conv2D, pt_conv2D, atol=1e-8):
    return np.allclose(np.transpose(pt_conv2D.weight.detach().numpy(), (2, 3, 1, 0)), keras_conv2D.get_weights()[0], atol)

def produce_layer_names(stack, block, expand_ratio, se_ratio, strides, filters_in, filters_out, dropout):
    letter_identifier = chr(block + 97)
    block_prefix = f"block{stack+1}{letter_identifier}_"
    out = []
    
    if expand_ratio != 1:
        out.append(block_prefix + "expand_conv")
        out.append(block_prefix + "expand_bn")
        out.append(block_prefix + "expand_activation")
        
    out.append(block_prefix + "dwconv_pad")
    out.append(block_prefix + "dwconv")
    out.append(block_prefix + "dwconv_bn")
    out.append(block_prefix + "dwconv_activation")

    if 0 < se_ratio <= 1:
        out.append(block_prefix + "se_squeeze")
        out.append(block_prefix + "se_reshape")
        out.append(block_prefix + "se_reduce")
        out.append(block_prefix + "se_expand")
        out.append(block_prefix + "se_excite")

    out.append(block_prefix + "project")
    out.append(block_prefix + "project_bn")
    out.append(block_prefix + "project_activation")

    if strides == 1 and filters_in == filters_out:
        if dropout > 0:
            out.append(block_prefix + "drop")
        out.append(block_prefix + "add")

    return out


In [5]:
def round_filters(
    filters,
    width_coefficient,
    min_depth,
    depth_divisor,
    use_depth_divisor_as_min_depth,
    cap_round_filter_decrease,
):
    """Round number of filters based on depth multiplier.

    Args:
        filters: int, number of filters for Conv layer
        width_coefficient: float, denotes the scaling coefficient of network
            width
        depth_divisor: int, a unit of network width
        use_depth_divisor_as_min_depth: bool, whether to use depth_divisor as
            the minimum depth instead of min_depth (as per v1)
        max_round_filter_decrease: bool, whether to cap the decrease in the
            number of filters this process produces (as per v1)

    Returns:
        int, new rounded filters value for Conv layer
    """
    filters *= width_coefficient

    if use_depth_divisor_as_min_depth:
        min_depth = depth_divisor

    new_filters = max(
        min_depth,
        int(filters + depth_divisor / 2) // depth_divisor * depth_divisor,
    )

    if cap_round_filter_decrease:
        # Make sure that round down does not go down by more than 10%.
        if new_filters < 0.9 * filters:
            new_filters += depth_divisor

    return int(new_filters)

def round_repeats(repeats, depth_coefficient):
    """Round number of repeats based on depth multiplier.

    Args:
        repeats: int, number of repeats of efficientnet block
        depth_coefficient: float, denotes the scaling coefficient of network
            depth

    Returns:
        int, rounded repeats
    """
    return int(math.ceil(depth_coefficient * repeats))


width_coefficient = 1.0
depth_coefficient = 1.0
min_depth = None,
depth_divisor = 8
use_depth_divisor_as_min_depth = True
cap_round_filter_decrease = True


for i in range(len(stackwise_kernel_sizes)):
    num_repeats = stackwise_num_repeats[i]
    input_filters = stackwise_input_filters[i]
    output_filters = stackwise_output_filters[i]

    input_filters = round_filters(
        filters=input_filters,
        width_coefficient=width_coefficient,
        min_depth=min_depth,
        depth_divisor=depth_divisor,
        use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth,
        cap_round_filter_decrease=cap_round_filter_decrease,
    )
    output_filters = round_filters(
        filters=output_filters,
        width_coefficient=width_coefficient,
        min_depth=min_depth,
        depth_divisor=depth_divisor,
        use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth,
        cap_round_filter_decrease=cap_round_filter_decrease,
    )

    repeats = round_repeats(
        repeats=num_repeats,
        depth_coefficient=depth_coefficient,
    )
    strides = stackwise_strides[i]
    squeeze_and_excite_ratio = stackwise_squeeze_and_excite_ratios[i]

    for j in range(repeats):
        
        if j > 0:
            strides = 1
            input_filters = output_filters

        print(produce_layer_names(i, j, stackwise_expansion_ratios[i], stackwise_squeeze_and_excite_ratios[i],
                            strides, input_filters, output_filters, 0))

['block1a_dwconv_pad', 'block1a_dwconv', 'block1a_dwconv_bn', 'block1a_dwconv_activation', 'block1a_se_squeeze', 'block1a_se_reshape', 'block1a_se_reduce', 'block1a_se_expand', 'block1a_se_excite', 'block1a_project', 'block1a_project_bn', 'block1a_project_activation']
['block2a_expand_conv', 'block2a_expand_bn', 'block2a_expand_activation', 'block2a_dwconv_pad', 'block2a_dwconv', 'block2a_dwconv_bn', 'block2a_dwconv_activation', 'block2a_se_squeeze', 'block2a_se_reshape', 'block2a_se_reduce', 'block2a_se_expand', 'block2a_se_excite', 'block2a_project', 'block2a_project_bn', 'block2a_project_activation']
['block2b_expand_conv', 'block2b_expand_bn', 'block2b_expand_activation', 'block2b_dwconv_pad', 'block2b_dwconv', 'block2b_dwconv_bn', 'block2b_dwconv_activation', 'block2b_se_squeeze', 'block2b_se_reshape', 'block2b_se_reduce', 'block2b_se_expand', 'block2b_se_excite', 'block2b_project', 'block2b_project_bn', 'block2b_project_activation', 'block2b_add']
['block3a_expand_conv', 'block3a

In [6]:
# Create timm model & convert to keras_hub model
timm_name = PRESET_MAP["enet_b0_ra"]
timm_model = timm.create_model(timm_name, pretrained=True)
timm_model = timm_model.eval()
keras_model = keras_hub.models.ImageClassifier.from_preset(
    "hf://" + timm_name,
)

In [None]:
"https://storage.googleapis.com/keras-cv/"
        "models/paligemma/cow_beach_1.png"
"https://github.com/pytorch/hub/raw/master/images/dog.jpg"
"https://i.imgur.com/mtbl1cr.jpeg"

In [7]:
# Load example image & Preprocess
file = "mtbl1cr.jpeg"
image = PIL.Image.open(file)
batch = np.array([image])

# Preprocessing.
timm_batch = keras_model.preprocessor(batch)
batch = keras_model.preprocessor(batch)
timm_batch = keras.ops.transpose(timm_batch, axes=(0, 3, 1, 2))
timm_batch = keras.ops.cast(timm_batch, dtype="float32")
timm_batch = torch.from_numpy(np.array(timm_batch))
timm_batch = timm_batch / 255.0

batch = keras.ops.cast(batch, dtype="float32")
batch = batch / 255.0

In [8]:
batch.shape

TensorShape([1, 224, 224, 3])

In [9]:
timm_batch.shape

torch.Size([1, 3, 224, 224])

In [8]:
# Check input tensor is the same
np.allclose(batch, channel_first_to_last(timm_batch), atol=1e-16)

True

In [9]:
# Inference full forward pass
timm_outputs = timm_model(timm_batch).detach().numpy()
timm_label = np.argmax(timm_outputs[0])

keras_outputs = keras_model.predict(batch)
keras_label = np.argmax(keras_outputs[0])

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 705ms/step


In [10]:
# Results
print("🔶 Keras output:", keras_outputs[0, 500:510])
print("🔶 TIMM output:", timm_outputs[0, 500:510])
print("🔶 Difference:", np.mean(np.abs(keras_outputs - timm_outputs)))
print("🔶 Keras label:", keras_label)
print("🔶 TIMM label:", timm_label)

🔶 Keras output: [ 0.5226594   0.4194867   0.44462386 -0.23206796 -0.70875645 -0.58344996
  0.9328016   0.13332373 -0.44932097  0.09161342]
🔶 TIMM output: [ 0.9023841  -1.1280954  -1.6994343  -0.751282   -0.48456857  0.18079609
  0.59988517 -0.8264805   1.0880274  -0.47144806]
🔶 Difference: 0.843581
🔶 Keras label: 730
🔶 TIMM label: 287


In [13]:
sum(p.numel() for p in timm_model.parameters())

5288548

In [38]:
# Disection time
# Take the input tensors and only go through the first layers
keras_inter_tensor = keras_compute(batch, keras_model.backbone, ["stem_conv_pad", "stem_conv"])
timm_inter_tensor = pt_compute(timm_batch, [timm_model.conv_stem, timm_model.bn1, timm_model.blocks[0][0].conv_dw, 
                                            timm_model.blocks[0][0].bn1])

In [18]:
timm_inter_tensor = pt_compute(timm_batch, [timm_model.conv_stem, timm_model.bn1, timm_model.blocks[0][0].conv_dw, 
                                            timm_model.blocks[0][0].bn1])

In [20]:
timm_inter_tensor.shape

torch.Size([1, 32, 112, 112])

In [42]:
np.testing.assert_allclose(keras_inter_tensor, channel_first_to_last(timm_inter_tensor.detach().numpy()), atol=1e-6)

AssertionError: 
Not equal to tolerance rtol=1e-07, atol=1e-06

Mismatched elements: 952 / 401408 (0.237%)
Max absolute difference: 2.9802322e-06
Max relative difference: 0.34615067
 x: array([[[[-2.707277e+00, -3.618458e+00, -1.002537e+00, ...,
           1.434689e+00, -8.079384e-02, -1.127613e+00],
         [-2.852861e+00, -5.419985e-02, -1.017130e+00, ...,...
 y: array([[[[-2.707277e+00, -3.618458e+00, -1.002537e+00, ...,
           1.434689e+00, -8.079372e-02, -1.127614e+00],
         [-2.852861e+00, -5.419981e-02, -1.017130e+00, ...,...

In [75]:
keras_inter_tensor = keras_compute(batch, keras_model.backbone, ["stem_conv_pad", "stem_conv", "stem_bn", "stem_activation"])
timm_inter_tensor = pt_compute(timm_batch, [timm_model.conv_stem, timm_model.bn1])

In [9]:
keras_inter_tensor = keras_compute(batch, keras_model.backbone, ["stem_conv_pad", "stem_conv", "stem_bn", "stem_activation",
    "block1a_dwconv_pad", "block1a_dwconv", "block1a_dwconv_bn", "block1a_dwconv_activation", "block1a_se_squeeze", "block1a_se_reshape",
    "block1a_se_reduce", "block1a_se_expand", "block1a_se_excite", "block1a_project", "block1a_project_bn", "block1a_project_activation"])
timm_inter_tensor = pt_compute(timm_batch, [timm_model.conv_stem, timm_model.bn1, timm_model.blocks[0][0]])

In [12]:
keras_inter_tensor = keras_compute(batch, keras_model.backbone, ["stem_conv_pad", "stem_conv", "stem_bn", "stem_activation",
    "block1a_dwconv_pad", "block1a_dwconv", "block1a_dwconv_bn", "block1a_dwconv_activation", "block1a_se_squeeze", "block1a_se_reshape",
    "block1a_se_reduce", "block1a_se_expand", "block1a_se_excite", "block1a_project", "block1a_project_bn", "block1a_project_activation",
    'block2a_expand_conv', 'block2a_expand_bn', 'block2a_expand_activation', 'block2a_dwconv_pad', 'block2a_dwconv', 'block2a_dwconv_bn', 'block2a_dwconv_activation', 'block2a_se_squeeze', 'block2a_se_reshape', 'block2a_se_reduce', 'block2a_se_expand', 'block2a_se_excite', 'block2a_project', 'block2a_project_bn', 'block2a_project_activation'])
timm_inter_tensor = pt_compute(timm_batch, [timm_model.conv_stem, timm_model.bn1, timm_model.blocks[0][0], timm_model.blocks[1][0]])

ValueError: Input 0 of layer "block2a_expand_bn" is incompatible with the layer: expected axis 3 of input shape to have value 96, but received input with shape (1, 110, 110, 16)

In [14]:
keras_inter_tensor = keras_compute(batch, keras_model.backbone, ["stem_conv_pad", "stem_conv", "stem_bn", "stem_activation",
    "block1a_dwconv_pad", "block1a_dwconv", "block1a_dwconv_bn", "block1a_dwconv_activation", "block1a_se_squeeze", "block1a_se_reshape",
    "block1a_se_reduce", "block1a_se_expand", "block1a_se_excite", "block1a_project", "block1a_project_bn", "block1a_project_activation",
    'block2a_expand_conv', 'block2a_expand_bn', 'block2a_expand_activation', 'block2a_dwconv_pad', 'block2a_dwconv', 'block2a_dwconv_bn', 'block2a_dwconv_activation', 'block2a_se_squeeze', 'block2a_se_reshape', 'block2a_se_reduce', 'block2a_se_expand', 'block2a_se_excite', 'block2a_project', 'block2a_project_bn', 'block2a_project_activation',
    'block2b_expand_conv', 'block2b_expand_bn', 'block2b_expand_activation', 'block2b_dwconv_pad', 'block2b_dwconv', 'block2b_dwconv_bn', 'block2b_dwconv_activation', 'block2b_se_squeeze', 'block2b_se_reshape', 'block2b_se_reduce', 'block2b_se_expand', 'block2b_se_excite', 'block2b_project', 'block2b_project_bn', 'block2b_project_activation', 'block2b_add'])
timm_inter_tensor = pt_compute(timm_batch, [timm_model.conv_stem, timm_model.bn1, timm_model.blocks[0], timm_model.blocks[1]])

In [17]:
keras_inter_tensor = keras_compute(batch, keras_model.backbone, ["stem_conv_pad", "stem_conv", "stem_bn", "stem_activation",
    "block1a_dwconv_pad", "block1a_dwconv", "block1a_dwconv_bn", "block1a_dwconv_activation", "block1a_se_squeeze", "block1a_se_reshape",
    "block1a_se_reduce", "block1a_se_expand", "block1a_se_excite", "block1a_project", "block1a_project_bn", "block1a_project_activation",
    'block2a_expand_conv', 'block2a_expand_bn', 'block2a_expand_activation', 'block2a_dwconv_pad', 'block2a_dwconv', 'block2a_dwconv_bn', 'block2a_dwconv_activation', 'block2a_se_squeeze', 'block2a_se_reshape', 'block2a_se_reduce', 'block2a_se_expand', 'block2a_se_excite', 'block2a_project', 'block2a_project_bn', 'block2a_project_activation',
    'block2b_expand_conv', 'block2b_expand_bn', 'block2b_expand_activation', 'block2b_dwconv_pad', 'block2b_dwconv', 'block2b_dwconv_bn', 'block2b_dwconv_activation', 'block2b_se_squeeze', 'block2b_se_reshape', 'block2b_se_reduce', 'block2b_se_expand', 'block2b_se_excite', 'block2b_project', 'block2b_project_bn', 'block2b_project_activation', 'block2b_add',
    'block3a_expand_conv', 'block3a_expand_bn', 'block3a_expand_activation', 'block3a_dwconv_pad', 'block3a_dwconv', 'block3a_dwconv_bn', 'block3a_dwconv_activation', 'block3a_se_squeeze', 'block3a_se_reshape', 'block3a_se_reduce', 'block3a_se_expand', 'block3a_se_excite', 'block3a_project', 'block3a_project_bn', 'block3a_project_activation',
    'block3b_expand_conv', 'block3b_expand_bn', 'block3b_expand_activation', 'block3b_dwconv_pad', 'block3b_dwconv', 'block3b_dwconv_bn', 'block3b_dwconv_activation', 'block3b_se_squeeze', 'block3b_se_reshape', 'block3b_se_reduce', 'block3b_se_expand', 'block3b_se_excite', 'block3b_project', 'block3b_project_bn', 'block3b_project_activation', 'block3b_add'])
timm_inter_tensor = pt_compute(timm_batch, [timm_model.conv_stem, timm_model.bn1, timm_model.blocks[0], timm_model.blocks[1], timm_model.blocks[2]])

In [26]:
keras_inter_tensor = keras_compute(batch, keras_model.backbone, ["stem_conv_pad", "stem_conv", "stem_bn", "stem_activation",
    "block1a_dwconv_pad", "block1a_dwconv", "block1a_dwconv_bn", "block1a_dwconv_activation", "block1a_se_squeeze", "block1a_se_reshape",
    "block1a_se_reduce", "block1a_se_expand", "block1a_se_excite", "block1a_project", "block1a_project_bn", "block1a_project_activation",
    'block2a_expand_conv', 'block2a_expand_bn', 'block2a_expand_activation', 'block2a_dwconv_pad', 'block2a_dwconv', 'block2a_dwconv_bn', 'block2a_dwconv_activation', 'block2a_se_squeeze', 'block2a_se_reshape', 'block2a_se_reduce', 'block2a_se_expand', 'block2a_se_excite', 'block2a_project', 'block2a_project_bn', 'block2a_project_activation',
    'block2b_expand_conv', 'block2b_expand_bn', 'block2b_expand_activation', 'block2b_dwconv_pad', 'block2b_dwconv', 'block2b_dwconv_bn', 'block2b_dwconv_activation', 'block2b_se_squeeze', 'block2b_se_reshape', 'block2b_se_reduce', 'block2b_se_expand', 'block2b_se_excite', 'block2b_project', 'block2b_project_bn', 'block2b_project_activation', 'block2b_add',
    'block3a_expand_conv', 'block3a_expand_bn', 'block3a_expand_activation', 'block3a_dwconv_pad', 'block3a_dwconv', 'block3a_dwconv_bn', 'block3a_dwconv_activation', 'block3a_se_squeeze', 'block3a_se_reshape', 'block3a_se_reduce', 'block3a_se_expand', 'block3a_se_excite', 'block3a_project', 'block3a_project_bn', 'block3a_project_activation',
    'block3b_expand_conv', 'block3b_expand_bn', 'block3b_expand_activation', 'block3b_dwconv_pad', 'block3b_dwconv', 'block3b_dwconv_bn', 'block3b_dwconv_activation', 'block3b_se_squeeze', 'block3b_se_reshape', 'block3b_se_reduce', 'block3b_se_expand', 'block3b_se_excite', 'block3b_project', 'block3b_project_bn', 'block3b_project_activation', 'block3b_add',
    'block4a_expand_conv', 'block4a_expand_bn', 'block4a_expand_activation', 'block4a_dwconv_pad', 'block4a_dwconv', 'block4a_dwconv_bn', 'block4a_dwconv_activation', 'block4a_se_squeeze', 'block4a_se_reshape', 'block4a_se_reduce', 'block4a_se_expand', 'block4a_se_excite', 'block4a_project', 'block4a_project_bn', 'block4a_project_activation',
    'block4b_expand_conv', 'block4b_expand_bn', 'block4b_expand_activation', 'block4b_dwconv_pad', 'block4b_dwconv', 'block4b_dwconv_bn', 'block4b_dwconv_activation', 'block4b_se_squeeze', 'block4b_se_reshape', 'block4b_se_reduce', 'block4b_se_expand', 'block4b_se_excite', 'block4b_project', 'block4b_project_bn', 'block4b_project_activation', 'block4b_add',
    'block4c_expand_conv', 'block4c_expand_bn', 'block4c_expand_activation', 'block4c_dwconv_pad', 'block4c_dwconv', 'block4c_dwconv_bn', 'block4c_dwconv_activation', 'block4c_se_squeeze', 'block4c_se_reshape', 'block4c_se_reduce', 'block4c_se_expand', 'block4c_se_excite', 'block4c_project', 'block4c_project_bn', 'block4c_project_activation', 'block4c_add',
    'block5a_expand_conv', 'block5a_expand_bn', 'block5a_expand_activation', 'block5a_dwconv_pad', 'block5a_dwconv', 'block5a_dwconv_bn', 'block5a_dwconv_activation', 'block5a_se_squeeze', 'block5a_se_reshape', 'block5a_se_reduce', 'block5a_se_expand', 'block5a_se_excite', 'block5a_project', 'block5a_project_bn', 'block5a_project_activation',
    'block5b_expand_conv', 'block5b_expand_bn', 'block5b_expand_activation', 'block5b_dwconv_pad', 'block5b_dwconv', 'block5b_dwconv_bn', 'block5b_dwconv_activation', 'block5b_se_squeeze', 'block5b_se_reshape', 'block5b_se_reduce', 'block5b_se_expand', 'block5b_se_excite', 'block5b_project', 'block5b_project_bn', 'block5b_project_activation', 'block5b_add',
    'block5c_expand_conv', 'block5c_expand_bn', 'block5c_expand_activation', 'block5c_dwconv_pad', 'block5c_dwconv', 'block5c_dwconv_bn', 'block5c_dwconv_activation', 'block5c_se_squeeze', 'block5c_se_reshape', 'block5c_se_reduce', 'block5c_se_expand', 'block5c_se_excite', 'block5c_project', 'block5c_project_bn', 'block5c_project_activation', 'block5c_add',
    'block6a_expand_conv', 'block6a_expand_bn', 'block6a_expand_activation', 'block6a_dwconv_pad', 'block6a_dwconv', 'block6a_dwconv_bn', 'block6a_dwconv_activation', 'block6a_se_squeeze', 'block6a_se_reshape', 'block6a_se_reduce', 'block6a_se_expand', 'block6a_se_excite', 'block6a_project', 'block6a_project_bn', 'block6a_project_activation',
    'block6b_expand_conv', 'block6b_expand_bn', 'block6b_expand_activation', 'block6b_dwconv_pad', 'block6b_dwconv', 'block6b_dwconv_bn', 'block6b_dwconv_activation', 'block6b_se_squeeze', 'block6b_se_reshape', 'block6b_se_reduce', 'block6b_se_expand', 'block6b_se_excite', 'block6b_project', 'block6b_project_bn', 'block6b_project_activation', 'block6b_add',
    'block6c_expand_conv', 'block6c_expand_bn', 'block6c_expand_activation', 'block6c_dwconv_pad', 'block6c_dwconv', 'block6c_dwconv_bn', 'block6c_dwconv_activation', 'block6c_se_squeeze', 'block6c_se_reshape', 'block6c_se_reduce', 'block6c_se_expand', 'block6c_se_excite', 'block6c_project', 'block6c_project_bn', 'block6c_project_activation', 'block6c_add',
    'block6d_expand_conv', 'block6d_expand_bn', 'block6d_expand_activation', 'block6d_dwconv_pad', 'block6d_dwconv', 'block6d_dwconv_bn', 'block6d_dwconv_activation', 'block6d_se_squeeze', 'block6d_se_reshape', 'block6d_se_reduce', 'block6d_se_expand', 'block6d_se_excite', 'block6d_project', 'block6d_project_bn', 'block6d_project_activation', 'block6d_add',
    'block7a_expand_conv', 'block7a_expand_bn', 'block7a_expand_activation', 'block7a_dwconv_pad', 'block7a_dwconv', 'block7a_dwconv_bn', 'block7a_dwconv_activation', 'block7a_se_squeeze', 'block7a_se_reshape', 'block7a_se_reduce', 'block7a_se_expand', 'block7a_se_excite', 'block7a_project', 'block7a_project_bn', 'block7a_project_activation',
    'top_conv', 'top_bn', 'top_activation',])
timm_inter_tensor = pt_compute(timm_batch, [timm_model.conv_stem, timm_model.bn1, timm_model.blocks, timm_model.conv_head, timm_model.bn2])

In [29]:
keras_model.num_classes

1000

In [27]:
keras_inter_tensor.shape

TensorShape([1, 7, 7, 1280])

In [28]:
timm_inter_tensor.shape

torch.Size([1, 1280, 7, 7])

In [14]:
keras_model.summary()

In [15]:
sum(p.numel() for p in timm_model.parameters())

5288548

In [25]:
print(timm_model)

EfficientNet(
  (conv_stem): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNormAct2d(
    32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
    (drop): Identity()
    (act): SiLU(inplace=True)
  )
  (blocks): Sequential(
    (0): Sequential(
      (0): DepthwiseSeparableConv(
        (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (bn1): BatchNormAct2d(
          32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (aa): Identity()
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pw): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2

In [24]:
# Are the kernels the same?
channel_first_to_last(timm_inter_tensor.detach().numpy())

<tf.Tensor: shape=(1, 7, 7, 320), dtype=float32, numpy=
array([[[[ 1.3716862 , -1.1939676 , -0.37741017, ...,  2.4644077 ,
          -1.6410608 ,  0.7925929 ],
         [ 0.5655327 , -1.1748612 ,  0.24523261, ...,  2.0986621 ,
          -1.1746744 ,  0.3095982 ],
         [ 0.41059875, -1.820539  , -0.08460848, ...,  2.0553904 ,
          -0.99730897,  0.19595492],
         ...,
         [ 0.7994732 , -1.7311807 , -0.32693085, ...,  1.8780891 ,
          -1.2545584 ,  0.48745966],
         [ 0.41935104, -1.0593461 ,  0.3290768 , ...,  2.1181998 ,
          -1.1987295 ,  0.4043243 ],
         [ 1.1867679 , -0.88511103, -0.11355655, ...,  1.737643  ,
          -1.4012647 ,  0.922706  ]],

        [[ 1.4016988 , -1.8750598 ,  0.28465202, ...,  2.9077697 ,
          -2.099918  ,  0.6227735 ],
         [ 1.1538404 , -3.0511618 ,  1.1347332 , ...,  2.032656  ,
          -2.190263  , -0.44623667],
         [ 1.64216   , -4.021599  ,  1.6149508 , ...,  2.0009499 ,
          -4.177692  ,  0.708

In [93]:
np.testing.assert_allclose(keras_inter_tensor, channel_first_to_last(timm_inter_tensor.detach().numpy()), atol=1)

AssertionError: 
Not equal to tolerance rtol=1e-07, atol=1

Mismatched elements: 39652 / 200704 (19.8%)
Max absolute difference: 16.5599
Max relative difference: 3584.6614
 x: array([[[[ 1.556794e+01, -7.245570e-02, -9.366754e-02, ...,
          -2.784230e-01, -1.548447e-01, -2.063109e-01],
         [-3.931458e-02, -2.276724e-01, -1.623400e-01, ...,...
 y: array([[[[ 1.557224e+01, -4.166960e+00, -5.130501e-01, ...,
          -9.678993e-01, -2.793869e+00, -2.740242e+00],
         [-1.269636e-01, -9.002581e-01, -3.110859e+00, ...,...

In [82]:
np.mean(np.abs(keras_inter_tensor - channel_first_to_last(timm_inter_tensor.detach().numpy())))

4.6444835e-05

In [34]:
# different shape but close... let's look at them
np.argmax(keras_inter_tensor)

120227

In [33]:
# and the timm one
np.max(channel_first_to_last(timm_inter_tensor.detach().numpy()))

1633.6973

In [17]:
print(timm_model)

EfficientNet(
  (conv_stem): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNormAct2d(
    32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
    (drop): Identity()
    (act): SiLU(inplace=True)
  )
  (blocks): Sequential(
    (0): Sequential(
      (0): DepthwiseSeparableConv(
        (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (bn1): BatchNormAct2d(
          32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (aa): Identity()
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pw): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2

In [14]:
keras_model.backbone(batch)

<tf.Tensor: shape=(1, 7, 7, 1280), dtype=float32, numpy=
array([[[[-6.46060658e-33,  4.42028656e+01, -0.00000000e+00, ...,
          -0.00000000e+00, -0.00000000e+00, -0.00000000e+00],
         [-1.06582133e-36,  5.92849388e+01, -0.00000000e+00, ...,
          -0.00000000e+00, -0.00000000e+00, -0.00000000e+00],
         [-0.00000000e+00,  6.53860703e+01, -0.00000000e+00, ...,
          -0.00000000e+00, -0.00000000e+00, -0.00000000e+00],
         ...,
         [-0.00000000e+00,  6.09511795e+01, -0.00000000e+00, ...,
          -0.00000000e+00, -0.00000000e+00, -0.00000000e+00],
         [-9.93342045e-23,  3.84074516e+01, -0.00000000e+00, ...,
          -0.00000000e+00, -0.00000000e+00, -0.00000000e+00],
         [-1.58727821e-03,  1.92612305e+01, -0.00000000e+00, ...,
          -0.00000000e+00, -0.00000000e+00, -0.00000000e+00]],

        [[-0.00000000e+00,  5.76001701e+01, -0.00000000e+00, ...,
          -0.00000000e+00, -0.00000000e+00, -0.00000000e+00],
         [-0.00000000e+00,  7.1

In [22]:
timm_tensor.shape

torch.Size([1, 320, 7, 7])