In [1]:
from functools import partial

import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import checkpoint_seq
from timm.models.layers import trunc_normal_, DropPath
from timm.models.registry import register_model
#from timm.models.layers.helpers import to_2tuple
from timm.models.layers import to_2tuple


class InceptionDWConv2d(nn.Module):
    def __init__(self, in_channels, square_kernel_size=3, band_kernel_size=11, branch_ratio=0.125):
        super().__init__()
        
        gc = int(in_channels * branch_ratio) # channel numbers of a convolution branch
        self.dwconv_hw = nn.Conv2d(gc, gc, square_kernel_size, padding=square_kernel_size//2, groups=gc)
        self.dwconv_w = nn.Conv2d(gc, gc, kernel_size=(1, band_kernel_size), padding=(0, band_kernel_size//2), groups=gc)
        self.dwconv_h = nn.Conv2d(gc, gc, kernel_size=(band_kernel_size, 1), padding=(band_kernel_size//2, 0), groups=gc)
        self.split_indexes = (in_channels - 3 * gc, gc, gc, gc)
        
    def forward(self, x):
        x_id, x_hw, x_w, x_h = torch.split(x, self.split_indexes, dim=1)
        return torch.cat(
            (x_id, self.dwconv_hw(x_hw), self.dwconv_w(x_w), self.dwconv_h(x_h)), 
            dim=1,
        )


class ConvMlp(nn.Module):
    def __init__(
            self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU,
            norm_layer=None, bias=True, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)

        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0])
        self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
        self.act = act_layer()
        self.drop = nn.Dropout(drop)
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x


class MlpHead(nn.Module):
    def __init__(self, dim, num_classes=1000, mlp_ratio=3, act_layer=nn.GELU,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), drop=0., bias=True):
        super().__init__()
        hidden_features = int(mlp_ratio * dim)
        self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
        self.act = act_layer()
        self.norm = norm_layer(hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = x.mean((2, 3)) # global average pooling
        x = self.fc1(x)
        x = self.act(x)
        x = self.norm(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x


class MetaNeXtBlock(nn.Module):

    def __init__(
            self,
            dim,
            token_mixer=nn.Identity,
            norm_layer=nn.BatchNorm2d,
            mlp_layer=ConvMlp,
            mlp_ratio=4,
            act_layer=nn.GELU,
            ls_init_value=1e-6,
            drop_path=0.,
            
    ):
        super().__init__()
        self.token_mixer = token_mixer(dim)
        self.norm = norm_layer(dim)
        self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=act_layer)
        self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.token_mixer(x)
        x = self.norm(x)
        x = self.mlp(x)
        if self.gamma is not None:
            x = x.mul(self.gamma.reshape(1, -1, 1, 1))
        x = self.drop_path(x) + shortcut
        return x


class MetaNeXtStage(nn.Module):
    def __init__(
            self,
            in_chs,
            out_chs,
            ds_stride=2,
            depth=2,
            drop_path_rates=None,
            ls_init_value=1.0,
            token_mixer=nn.Identity,
            act_layer=nn.GELU,
            norm_layer=None,
            mlp_ratio=4,
    ):
        super().__init__()
        self.grad_checkpointing = False
        if ds_stride > 1:
            self.downsample = nn.Sequential(
                norm_layer(in_chs),
                nn.Conv2d(in_chs, out_chs, kernel_size=ds_stride, stride=ds_stride),
            )
        else:
            self.downsample = nn.Identity()

        drop_path_rates = drop_path_rates or [0.] * depth
        stage_blocks = []
        for i in range(depth):
            stage_blocks.append(MetaNeXtBlock(
                dim=out_chs,
                drop_path=drop_path_rates[i],
                ls_init_value=ls_init_value,
                token_mixer=token_mixer,
                act_layer=act_layer,
                norm_layer=norm_layer,
                mlp_ratio=mlp_ratio,
            ))
            in_chs = out_chs
        self.blocks = nn.Sequential(*stage_blocks)

    def forward(self, x):
        x = self.downsample(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x


class MetaNeXt(nn.Module):
    def __init__(self, in_chans=3, num_classes=1000, depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), token_mixers=InceptionDWConv2d, norm_layer=nn.BatchNorm2d, act_layer=nn.GELU, mlp_ratios=(4, 4, 4, 3), head_fn=MlpHead, drop_rate=0., drop_path_rate=0., ls_init_value=1e-6, **kwargs):
        super().__init__()
        # ... (other initialization code as provided)
        self.stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            norm_layer(dims[0])
        )
        self.stages = nn.Sequential()
        dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
        stages = []
        prev_chs = dims[0]
        for i in range(len(depths)):
            out_chs = dims[i]
            stages.append(MetaNeXtStage(
                prev_chs,
                out_chs,
                ds_stride=2 if i > 0 else 1,
                depth=depths[i],
                drop_path_rates=dp_rates[i],
                ls_init_value=ls_init_value,
                act_layer=act_layer,
                token_mixer=token_mixers,
                norm_layer=norm_layer,
                mlp_ratio=mlp_ratios[i],
            ))
            prev_chs = out_chs
        self.stages = nn.Sequential(*stages)
        self.num_features = prev_chs
        self.head = head_fn(self.num_features, num_classes, drop=drop_rate)
        self.apply(self._init_weights)

    def forward_features(self, x):
        x = self.stem(x)
        x = self.stages(x)
        return x

    def forward_head(self, x):
        x = self.head(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

def inceptionnext_tiny(**kwargs):
    model_torch = MetaNeXt(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), token_mixers=InceptionDWConv2d, **kwargs)
    My_model = "/kaggle/input/inceptionnext/pytorch/inceptionnext/1/inceptionnext_tiny.pth"
    checkpoint = torch.load(My_model, map_location=torch.device('cuda'))
    model_torch.load_state_dict(checkpoint.get("model", checkpoint))
    return model_torch

def inceptionnext_small(**kwargs):
    model_torch = MetaNeXt(depths=(3, 3, 27, 3), dims=(96, 192, 384, 768), token_mixers=InceptionDWConv2d, **kwargs)
    My_model = "/kaggle/input/inceptionnext/pytorch/inceptionnext/1/inceptionnext_small.pth"
    checkpoint = torch.load(My_model, map_location=torch.device('cuda'))
    model_torch.load_state_dict(checkpoint.get("model", checkpoint))
    return model_torch

def inceptionnext_base(**kwargs):
    model_torch = MetaNeXt(depths=(3, 3, 27, 3), dims=(128, 256, 512, 1024), token_mixers=InceptionDWConv2d, **kwargs)
    My_model = "/kaggle/input/inceptionnext/pytorch/inceptionnext/1/inceptionnext_base.pth"
    checkpoint = torch.load(My_model, map_location=torch.device('cuda'))
    model_torch.load_state_dict(checkpoint.get("model", checkpoint))
    return model_torch

def inceptionnext_base_384(**kwargs):
    model_torch = MetaNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], mlp_ratios=[4, 4, 4, 3], token_mixers=InceptionDWConv2d, **kwargs)
    My_model = "/kaggle/input/inceptionnext/pytorch/inceptionnext/1/inceptionnext_base_384.pth"
    checkpoint = torch.load(My_model, map_location=torch.device('cuda'))
    model_torch.load_state_dict(checkpoint.get("model", checkpoint))
    return model_torch



if __name__ == '__main__':
    model_torch = inceptionnext_base()
    n_parameters = sum(p.numel() for p in model_torch.parameters())
    print(n_parameters)

86672136


In [2]:
checkpoint = torch.load('/kaggle/input/inceptionnext/pytorch/inceptionnext/1/inceptionnext_base.pth', map_location=torch.device('cuda'))
model_torch.load_state_dict(checkpoint)

<All keys matched successfully>

# **Predict Image**

In [3]:
import requests
from PIL import Image
import torchvision.models as models
from torchvision import transforms

model = inceptionnext_base()
model.load_state_dict(torch.load('/kaggle/input/inceptionnext/pytorch/inceptionnext/1/inceptionnext_base.pth'))
model.eval()


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

image_path = '/kaggle/input/catndog/CuteCat.jpg'
image = Image.open(image_path).convert('RGB')
image = transform(image)
image = image.unsqueeze(0)

url = 'https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json'
response = requests.get(url)
class_idx = response.json()

idx_to_class = {int(key): value[1] for key, value in class_idx.items()}

with torch.no_grad():
    output = model(image)
    _, predicted = torch.max(output, 1)
    class_id = predicted.item()
    class_name = idx_to_class[class_id]

print(f'Predicted class ID: {class_id}, Class name: {class_name}')

Predicted class ID: 285, Class name: Egyptian_cat


In [4]:
def extract_params(model_torch):
    params_dict = {}
    for name, param in model_torch.named_parameters():
        np_param = param.detach().cpu().numpy()
        params_dict[name] = np_param
    return params_dict

# Extract the parameters from the model
params = extract_params(model_torch)

In [5]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Conv2D, LayerNormalization, Dropout, Activation, Layer, Concatenate, GlobalAveragePooling1D, GlobalAveragePooling2D
from tensorflow.keras import Model
import numpy as np
import collections.abc
from itertools import repeat

# Helper function to handle n-tuple
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse

to_2tuple = _ntuple(2)

class Identity(Layer):
    def __init__(self):
        super(Identity, self).__init__()

    def call(self, x):
        return x

class DropPath(Layer):
    def __init__(self, drop_prob=0., scale_by_keep=True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.keep_prob = 1 - drop_prob
        self.scale_by_keep = scale_by_keep

    def call(self, x, training=None):
        if self.drop_prob == 0. or not training:
            return x
        keep_prob = 1 - self.drop_prob
        random_tensor = tf.random.uniform(shape=tf.shape(x))
        random_tensor = tf.floor(random_tensor + keep_prob)
        if self.scale_by_keep:
            random_tensor = random_tensor / keep_prob
        return x * random_tensor

class InceptionDWConv2d(Layer):
    def __init__(self, in_channels, square_kernel_size=3, band_kernel_size=11, branch_ratio=0.125):
        super(InceptionDWConv2d, self).__init__()
        gc = int(in_channels * branch_ratio)
        self.dwconv_hw = Conv2D(gc, square_kernel_size, padding='same', groups=gc)
        self.dwconv_w = Conv2D(gc, (1, band_kernel_size), padding='same', groups=gc)
        self.dwconv_h = Conv2D(gc, (band_kernel_size, 1), padding='same', groups=gc)
        self.split_indexes = (in_channels - 3 * gc, gc, gc, gc)

    def call(self, x):
        x_id, x_hw, x_w, x_h = tf.split(x, self.split_indexes, axis=-1)
        return Concatenate(axis=-1)([x_id, self.dwconv_hw(x_hw), self.dwconv_w(x_w), self.dwconv_h(x_h)])

class ConvMlp(Layer):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer='relu', norm_layer=None, bias=True, drop=0.):
        super(ConvMlp, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = Conv2D(hidden_features, 1, use_bias=bias)
        self.norm = norm_layer() if norm_layer else Identity()
        self.act = Activation(act_layer)
        self.drop = Dropout(drop)
        self.fc2 = Conv2D(out_features, 1, use_bias=bias)

    def call(self, x):
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class MlpHead(Layer):
    def __init__(self, dim, num_classes=1000, mlp_ratio=3, act_layer='gelu', norm_layer=LayerNormalization, drop=0., bias=True):
        super(MlpHead, self).__init__()
        hidden_features = int(mlp_ratio * dim)
        self.fc1 = Dense(hidden_features, use_bias=bias)
        self.act = Activation(act_layer)
        self.norm = norm_layer(epsilon=1e-6)
        self.fc2 = Dense(num_classes, use_bias=bias)
        self.drop = Dropout(drop)

    def call(self, x):
        x = tf.reduce_mean(x, axis=[1, 2])  # global average pooling
        x = self.fc1(x)
        x = self.act(x)
        x = self.norm(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x

class MetaNeXtBlock(Layer):
    def __init__(self, dim, token_mixer_class, norm_layer=LayerNormalization, mlp_layer=ConvMlp, mlp_ratio=4, act_layer='relu', ls_init_value=1e-6, drop_path=0.):
        super(MetaNeXtBlock, self).__init__()
        self.token_mixer = token_mixer_class(dim)
        self.norm = norm_layer()
        self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=act_layer)
        self.gamma = self.add_weight(shape=(dim,), initializer=tf.keras.initializers.Constant(ls_init_value), trainable=True) if ls_init_value else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()

    def call(self, x):
        shortcut = x
        x = self.token_mixer(x)
        x = self.norm(x)
        x = self.mlp(x)
        if self.gamma is not None:
            x = x * self.gamma[None, None, None, :]
        x = self.drop_path(x) + shortcut
        return x

class MetaNeXtStage(Layer):
    def __init__(self, in_chs, out_chs, ds_stride=2, depth=2, drop_path_rates=None, ls_init_value=1.0, token_mixer_class=InceptionDWConv2d, act_layer='relu', norm_layer=LayerNormalization, mlp_ratio=4):
        super(MetaNeXtStage, self).__init__()
        if ds_stride > 1:
            self.downsample = tf.keras.Sequential([
                norm_layer(),
                Conv2D(out_chs, kernel_size=ds_stride, strides=ds_stride)
            ])
        else:
            self.downsample = Identity()

        drop_path_rates = drop_path_rates or [0.] * depth
        self.blocks = [MetaNeXtBlock(dim=out_chs, drop_path=drop_path_rates[i], ls_init_value=ls_init_value, token_mixer_class=token_mixer_class, act_layer=act_layer, norm_layer=norm_layer, mlp_ratio=mlp_ratio) for i in range(depth)]

    def call(self, x):
        x = self.downsample(x)
        for block in self.blocks:
            x = block(x)
        return x

class MetaNeXt(Model):
    def __init__(self, in_chans=3, num_classes=1000, depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), token_mixers=InceptionDWConv2d, norm_layer=LayerNormalization, act_layer='relu', mlp_ratios=(4, 4, 4, 3), head_fn=MlpHead, drop_rate=0., drop_path_rate=0., ls_init_value=1e-6):
        super(MetaNeXt, self).__init__()

        num_stage = len(depths)
        if not isinstance(token_mixers, (list, tuple)):
            token_mixers = [token_mixers] * num_stage
        if not isinstance(mlp_ratios, (list, tuple)):
            mlp_ratios = [mlp_ratios] * num_stage

        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.stem = tf.keras.Sequential([
            Conv2D(dims[0], kernel_size=4, strides=4),
            norm_layer()
        ])

        dp_rates = [x.numpy().tolist() for x in tf.split(tf.linspace(0., drop_path_rate, sum(depths)), num_or_size_splits=depths)]
        self.stages = []
        prev_chs = dims[0]
        for i in range(num_stage):
            out_chs = dims[i]
            self.stages.append(MetaNeXtStage(prev_chs, out_chs, ds_stride=2 if i > 0 else 1, depth=depths[i], drop_path_rates=dp_rates[i], ls_init_value=ls_init_value, act_layer=act_layer, token_mixer_class=token_mixers[i], norm_layer=norm_layer, mlp_ratio=mlp_ratios[i]))
            prev_chs = out_chs
        self.num_features = prev_chs
        self.head = head_fn(self.num_features, num_classes, drop=drop_rate)

    def call(self, x):
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
        x = self.head(x)
        return x
    
    

def inceptionnext_tiny(**kwargs):
    model = MetaNeXt(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768),
                     token_mixers=InceptionDWConv2d,
                      **kwargs
    )
    return model

def inceptionnext_small(**kwargs):
    model = MetaNeXt(depths=(3, 3, 27, 3), dims=(96, 192, 384, 768),
                     token_mixers=InceptionDWConv2d,
                      **kwargs
    )
    return model

def inceptionnext_base(**kwargs):
    model = MetaNeXt(depths=(3, 3, 27, 3), dims=(128, 256, 512, 1024),
                     token_mixers=InceptionDWConv2d,
                      **kwargs
    )
    return model

def inceptionnext_base_384(**kwargs):
    model = MetaNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024],
                     mlp_ratios=[4, 4, 4, 3],
                     token_mixers=InceptionDWConv2d,
                      **kwargs
    )
    return model

model_tf = inceptionnext_base()
input_tensor = tf.random.normal([1, 224, 224, 3])
output = model_tf(input_tensor)
n_parameters = np.sum([np.prod(v.shape) for v in model_tf.trainable_weights])
print(n_parameters)

2024-06-20 10:28:24.270011: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-20 10:28:24.270075: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-20 10:28:24.271522: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
I0000 00:00:1718879323.297156    3989 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


86672136


# **Predict Image**

In [6]:
import tensorflow as tf
from tensorflow.keras.preprocessing import image
import numpy as np
import json
import requests

json_url = "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"


response = requests.get(json_url)
class_index = response.json()

def get_class_label(index):
    return class_index[str(index)][1]

def load_and_preprocess_image(img_path, target_size=(384, 384)):
    img = tf.keras.utils.load_img(img_path, target_size=target_size)
    img_array = tf.keras.utils.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = img_array / 255.0
    return img_array

def predict_image_class(img_path):
    preprocessed_image = load_and_preprocess_image(img_path)
    predictions = model_tf(preprocessed_image)
    predicted_class = np.argmax(predictions, axis=-1)[0]
    predicted_label = get_class_label(predicted_class)
    return predicted_class, predicted_label

img_path = '/kaggle/input/catndog/CuteCat.jpg'
predicted_class, predicted_label = predict_image_class(img_path)
print(f"Predicted class index: {predicted_class}, label: {predicted_label}")

Predicted class index: 45, label: Gila_monster


In [7]:
# for i, layer in enumerate(model_tf.layers):
#     print(i, layer.name, [w.shape for w in layer.get_weights()])

In [8]:
import numpy as np

def map_weights_to_tf(tf_model, params):
    # Function to set weights for dense layers
    def set_dense_weights(tf_layer, pt_weight_key, pt_bias_key):
        dense_weight = params[pt_weight_key]
        dense_weight_reshaped = np.transpose(dense_weight, (1, 0))  # Convert from PyTorch to TensorFlow format
        dense_bias = params[pt_bias_key]
        tf_layer.set_weights([dense_weight_reshaped, dense_bias])
    
    # Function to set weights for conv layers
    def set_conv_weights(tf_layer, pt_weight_key, pt_bias_key=None):
        conv_weight = params[pt_weight_key]
        conv_weight_reshaped = np.transpose(conv_weight, (2, 3, 1, 0))  # Convert from PyTorch to TensorFlow format
        if pt_bias_key and pt_bias_key in params:
            conv_bias = params[pt_bias_key]
            tf_layer.set_weights([conv_weight_reshaped, conv_bias])
        else:
            tf_layer.set_weights([conv_weight_reshaped])

    # Function to set weights for batch norm layers
    def set_bn_weights(tf_layer, pt_weight_key, pt_bias_key, pt_running_mean_key=None, pt_running_var_key=None):
        bn_weight = params[pt_weight_key]
        bn_bias = params[pt_bias_key]
        if pt_running_mean_key and pt_running_var_key and pt_running_mean_key in params and pt_running_var_key in params:
            bn_running_mean = params[pt_running_mean_key]
            bn_running_var = params[pt_running_var_key]
            tf_layer.set_weights([bn_weight, bn_bias, bn_running_mean, bn_running_var])
        else:
            tf_layer.set_weights([bn_weight, bn_bias])
    
    # Function to set weights for layer norm layers
    def set_ln_weights(tf_layer, pt_weight_key, pt_bias_key):
        ln_weight = params[pt_weight_key]
        ln_bias = params[pt_bias_key]
        tf_layer.set_weights([ln_weight, ln_bias])
    
    # Set weights for the stem layer
    set_conv_weights(tf_model.stem.layers[0], 'stem.0.weight', 'stem.0.bias')
    set_bn_weights(tf_model.stem.layers[1], 'stem.1.weight', 'stem.1.bias', 'stem.1.running_mean', 'stem.1.running_var')

    # Set weights for each stage
    for i, stage in enumerate(tf_model.stages):
        for j, block in enumerate(stage.blocks):
            block_prefix = f'stages.{i}.blocks.{j}'
            set_conv_weights(block.token_mixer.dwconv_hw, f'{block_prefix}.token_mixer.dwconv_hw.weight', f'{block_prefix}.token_mixer.dwconv_hw.bias')
            set_conv_weights(block.token_mixer.dwconv_w, f'{block_prefix}.token_mixer.dwconv_w.weight', f'{block_prefix}.token_mixer.dwconv_w.bias')
            set_conv_weights(block.token_mixer.dwconv_h, f'{block_prefix}.token_mixer.dwconv_h.weight', f'{block_prefix}.token_mixer.dwconv_h.bias')
            set_ln_weights(block.norm, f'{block_prefix}.norm.weight', f'{block_prefix}.norm.bias')
            set_conv_weights(block.mlp.fc1, f'{block_prefix}.mlp.fc1.weight', f'{block_prefix}.mlp.fc1.bias')
            set_conv_weights(block.mlp.fc2, f'{block_prefix}.mlp.fc2.weight', f'{block_prefix}.mlp.fc2.bias')
        
        if hasattr(stage, 'downsample') and isinstance(stage.downsample, tf.keras.Sequential):
            downsample_prefix = f'stages.{i}.downsample'
            set_conv_weights(stage.downsample.layers[1], f'{downsample_prefix}.1.weight', f'{downsample_prefix}.1.bias')
            set_bn_weights(stage.downsample.layers[0], f'{downsample_prefix}.0.weight', f'{downsample_prefix}.0.bias', f'{downsample_prefix}.0.running_mean', f'{downsample_prefix}.0.running_var')
    
    # Set weights for the head layer
    set_dense_weights(tf_model.head.fc1, 'head.fc1.weight', 'head.fc1.bias')
    set_ln_weights(tf_model.head.norm, 'head.norm.weight', 'head.norm.bias')
    set_dense_weights(tf_model.head.fc2, 'head.fc2.weight', 'head.fc2.bias')

In [9]:
map_weights_to_tf(model_tf, params)

In [10]:
model_tf.save_weights('model.weights.h5')

In [11]:
model_tf.load_weights('/kaggle/working/model.weights.h5')

In [12]:
n_parameters = np.sum([np.prod(v.shape) for v in model_tf.trainable_weights])
print(n_parameters)

86672136


In [13]:
import os

def preprocess_image(img_path, target_size=(224, 224)):
    img = image.load_img(img_path, target_size=target_size)
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = tf.keras.applications.imagenet_utils.preprocess_input(img_array)
    return img_array

class_index_path = 'imagenet_class_index.json'
if not os.path.exists(class_index_path):
    url = 'https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json'
    response = requests.get(url)
    with open(class_index_path, 'wb') as f:
        f.write(response.content)

# Load the ImageNet class index
with open(class_index_path) as f:
    class_index = json.load(f)

img_path = '/kaggle/input/catndog/CuteCat.jpg'

img_array = preprocess_image(img_path)

predictions = model_tf.predict(img_array)
predicted_class_index = np.argmax(predictions, axis=-1)[0]

predicted_class_name = class_index[str(predicted_class_index)][1]

print(f'Predicted class index: {predicted_class_index}, class name: {predicted_class_name}')

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 16s/step
Predicted class index: 163, class name: bloodhound
