## Example of loading the weights of a trained timm model into Keras

#### Import libraries

In [None]:
import os
# Lower TensorFlow's native (C++) log verbosity. The variable is set here,
# before the TensorFlow-dependent imports below, so that it is already in
# the environment when TensorFlow is first loaded.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import timm
import tfimm
import tensorflow as tf
import torch

#### Define the timm model parameters
I will use a ViT model, but you can use any model supported by tfimm (the list of supported models is <a href='https://github.com/martinsbruveris/tensorflow-image-models#models'>here</a>).

In [None]:
# Side length (in pixels) of the square sample images used to compare the models.
IMG_DIM = 224

# Number of output classes requested from both the timm and the Keras model.
N_CLASSES = 5

# Architecture identifier passed to timm and tfimm.
TIMM_MODEL_NAME = 'vit_tiny_patch16_224'

I will create a model using timm. You can also load your own custom weights into it.

In [None]:
# Instantiate the PyTorch (timm) model with pretrained weights and a
# classification head sized for N_CLASSES.
timm_model = timm.create_model(
    model_name=TIMM_MODEL_NAME,
    pretrained=True,
    num_classes=N_CLASSES,
)

#### Creating a Keras model using the weights from the pretrained timm model


In [None]:
# Build the Keras counterpart of the same architecture, handing over the
# timm model created above as the source of the weights.
keras_model = tfimm.load_timm_model(
    TIMM_MODEL_NAME,
    nb_classes=N_CLASSES,
    pt_model=timm_model,
)

#### Testing models

In [None]:
# Seed the PyTorch RNG so that "Restart & Run All" always produces the same
# sample batch (and therefore the same printed outputs).
torch.manual_seed(0)

# Random batch of 2 images in the PyTorch layout: (batch, channels, height, width).
sample_input = torch.rand((2, 3, IMG_DIM, IMG_DIM))

# Freshly created PyTorch modules are in training mode. Switch to inference
# mode so layers such as Dropout / BatchNorm behave deterministically;
# otherwise the two frameworks' outputs are not comparable for many
# architectures.
timm_model.eval()

# No gradients are needed for a forward-only comparison; this also lets the
# result be converted to NumPy directly.
with torch.no_grad():
    torch_output = timm_model(sample_input).numpy()

# Move the channel axis last, (batch, height, width, channels), before
# feeding the same data to the Keras model.
keras_output = keras_model.predict(sample_input.permute(0, 2, 3, 1).numpy())


In [None]:
# Show both outputs, labelled, so they can be told apart at a glance.
print('timm (PyTorch) output:')
print(torch_output)
print('Keras output:')
print(keras_output)

# Eyeballing two float arrays is error-prone: also report the largest
# element-wise deviation between the two frameworks' outputs.
print('max abs difference:', abs(torch_output - keras_output).max())

#### Save keras model

In [None]:
# Write the converted Keras model to disk under a path derived from the
# timm model name, using the 'tf' save format.
keras_model.save(
    filepath=f'keras_{TIMM_MODEL_NAME}',
    save_format='tf',
)