# Convert a TensorFlow checkpoint to a PyTorch model

In [None]:
import tensorflow as tf
from efficientnet_pytorch import EfficientNet
from torch import nn
import torch

ModuleNotFoundError: No module named 'tensorflow'

In [None]:
# Load the TensorFlow checkpoint.
# `checkpoint_path` is reused by the cells below (tf.train.list_variables and
# tf.train.load_checkpoint), so keep it pointing at the checkpoint folder.
checkpoint_path = "../pretrained/checkpoint_EfficientNetB0/checkpoint"  # path to your checkpoint folder

# tf.train.latest_checkpoint returns None (instead of raising) when it finds no
# checkpoint in the folder, and restoring from None loads nothing. Fail loudly
# here so a wrong path is caught immediately rather than much later.
latest_ckpt = tf.train.latest_checkpoint(checkpoint_path)
if latest_ckpt is None:
    raise FileNotFoundError(f"No TensorFlow checkpoint found in {checkpoint_path!r}")

# NOTE(review): this Checkpoint tracks no objects, so restore() puts no weights
# anywhere usable; presumably it only serves as a sanity check that the
# checkpoint opens. The weights are actually read later via tf.train.load_checkpoint.
ckpt = tf.train.Checkpoint()
ckpt.restore(latest_ckpt).expect_partial()

In [None]:
# Print each (variable name, shape) pair stored in the checkpoint. The names
# printed here are the keys to use in the TF -> PyTorch mapping further down.
for name, shape in tf.train.list_variables(checkpoint_path):
    print((name, shape))

In [None]:
# Build an EfficientNet-B0 architecture (from_name: no pretrained weights) and
# replace its classifier head with a fresh 19-output linear layer.
model = EfficientNet.from_name('efficientnet-b0')
num_classes = 19  # presumably the class count of the TF model's output layer - TODO confirm
head_in_features = model._fc.in_features
model._fc = nn.Linear(head_in_features, num_classes)

In [None]:
# Map TensorFlow variable names (as printed by tf.train.list_variables above)
# to PyTorch state_dict keys.
# Example placeholder for mapping (you need to complete this based on your variables)
# NOTE(review): check every target against model.state_dict().keys(); if the
# depthwise conv was built without a bias, the '.bias' key below does not exist
# and the loop raises KeyError.
tf_to_torch = {
    'blocks_0/conv2d/kernel': 'blocks.0._depthwise_conv.weight',
    'blocks_0/conv2d/bias': 'blocks.0._depthwise_conv.bias',
    # ...continue mapping
}


def _tf_to_torch_layout(array, target_shape):
    """Convert a TF weight array into a torch tensor shaped like ``target_shape``.

    TensorFlow/Keras and PyTorch store kernels in different axis orders.
    Copying a TF kernel verbatim therefore either fails the shape check or,
    for square shapes, silently loads transposed weights. The conversions
    (assuming Keras-style layouts in the checkpoint) are:

    * rank-4 conv kernel:      TF (H, W, in, out)  -> torch (out, in, H, W)
    * rank-4 depthwise kernel: TF (H, W, in, mult) -> torch (in * mult, 1, H, W)
    * rank-2 dense kernel:     TF (in, out)        -> torch nn.Linear weight (out, in)
    * any other rank (biases, BatchNorm statistics, ...) is copied unchanged.

    Args:
        array: array-like returned by ``CheckpointReader.get_tensor``.
        target_shape: shape of the destination entry in the PyTorch state_dict.

    Returns:
        A contiguous tensor with ``target_shape``, or ``None`` if none of the
        candidate layouts produces that shape.
    """
    tensor = torch.tensor(array)
    if tensor.dim() == 4:
        height, width, in_ch, out_ch = tensor.shape
        candidates = [
            # Regular convolution: HWIO -> OIHW.
            tensor.permute(3, 2, 0, 1),
            # Depthwise: output channel k * mult + q comes from input channel k,
            # which matches PyTorch's grouped-conv ordering (groups = in_ch).
            tensor.reshape(height, width, in_ch * out_ch).permute(2, 0, 1).unsqueeze(1),
        ]
    elif tensor.dim() == 2:
        # Dense: (in, out) -> (out, in). Always transpose; a square kernel
        # would otherwise "match" and load silently transposed.
        candidates = [tensor.t()]
    else:
        candidates = [tensor]

    for candidate in candidates:
        if candidate.shape == target_shape:
            return candidate.contiguous()
    return None


# Load TF variables
reader = tf.train.load_checkpoint(checkpoint_path)

# Start from the model's current state dict so unmapped entries keep their values.
state_dict = model.state_dict()

# Replace weights
skipped = []
for tf_name, torch_name in tf_to_torch.items():
    if torch_name not in state_dict:
        raise KeyError(
            f"{torch_name!r} (mapped from {tf_name!r}) is not a key of model.state_dict()"
        )

    tf_array = reader.get_tensor(tf_name)
    target_shape = state_dict[torch_name].shape

    # Convert the TensorFlow layout to the PyTorch layout expected by the target.
    tensor = _tf_to_torch_layout(tf_array, target_shape)
    if tensor is None:
        print(f"Shape mismatch for {torch_name}: expected {target_shape}, got {torch.Size(tf_array.shape)}")
        skipped.append(torch_name)
        continue

    state_dict[torch_name] = tensor

# Make partial loading visible: skipped or unmapped entries keep the model's
# existing values (random init for a model built with from_name).
loaded_count = len(tf_to_torch) - len(skipped)
unmapped_count = len(state_dict) - len(set(tf_to_torch.values()))
print(
    f"Copied {loaded_count}/{len(tf_to_torch)} mapped tensors "
    f"({len(skipped)} skipped); {unmapped_count} state_dict entries have no TF mapping."
)

# Load modified state dict back into model
model.load_state_dict(state_dict)

In [None]:
# Smoke test: switch the model to eval mode, push one random (1, 3, 224, 224)
# batch through it without autograd tracking, and print the output shape.
# With the 19-output head set up earlier, this should print torch.Size([1, 19]).
dummy_input = torch.randn(1, 3, 224, 224)
model.eval()
with torch.no_grad():
    output = model(dummy_input)
print(output.shape)