In [None]:
import numpy as np
import torch
import torchvision.models as models
from tensorflow import keras
from tensorflow.keras import layers

from resnet import resnet18

In [None]:
# When True, both frameworks are switched to float64 for a more precise comparison
use_float64 = True
# Forward-pass mode used for both models; flip it to verify batch norm in train as well as eval
training = False

In [None]:
# Build torchvision's ResNet-18 with pretrained weights
torch_model = models.resnet18(pretrained=True)
if use_float64:
    # Cast parameters and buffers to double precision
    torch_model = torch_model.double()
# train(mode=True) selects training mode, train(mode=False) is what eval() does.
# The trailing semicolon keeps the notebook from echoing the returned module.
torch_model.train(mode=bool(training));

In [None]:
# Create tf2 model
# Set the Keras default float type explicitly in BOTH directions. Setting it only
# when use_float64 is True would leave a live kernel on float64 after the flag is
# flipped back to False, so the Keras model would no longer match the dtype of the
# PyTorch model (which is only cast to double when use_float64 is True).
keras.backend.set_floatx('float64' if use_float64 else 'float32')
# Height and width are left as None so the model is not tied to one input size;
# the last axis holds the 3 input channels (channels-last layout).
inputs = keras.Input(shape=(None, None, 3))
outputs = resnet18(inputs)
model = keras.Model(inputs, outputs)

In [None]:
# Copy the PyTorch parameters into the Keras layers.
# Each Keras layer name is used as the prefix of its entries in the PyTorch state dict.
state_dict = torch_model.state_dict()
for layer in model.layers:
    prefix = layer.name

    if isinstance(layer, layers.Conv2D):
        # Reorder the kernel axes from PyTorch's (out, in, kh, kw) to Keras' (kh, kw, in, out).
        # Only a kernel is set, so the conv layers are presumably built without a bias.
        kernel = state_dict[f'{prefix}.weight'].numpy()
        layer.set_weights([kernel.transpose((2, 3, 1, 0))])
        continue

    if isinstance(layer, layers.Dense):
        # PyTorch stores the linear weight as (out, in); Keras wants (in, out)
        kernel = state_dict[f'{prefix}.weight'].numpy().transpose()
        bias = state_dict[f'{prefix}.bias'].numpy()
        layer.set_weights([kernel, bias])
        continue

    if isinstance(layer, layers.BatchNormalization):
        # Listed in the order Keras expects: gamma, beta, moving mean, moving variance
        keys = ['weight', 'bias', 'running_mean', 'running_var']
        bn_weights = []
        for key in keys:
            bn_weights.append(state_dict[f'{prefix}.{key}'].numpy())
        layer.set_weights(bn_weights)

In [None]:
# Compare outputs
# Draw the random input from a seeded generator so the batch, and therefore the
# printed error, is identical on every run (the original used the unseeded global RNG).
input_batch = np.random.default_rng(0).random((4, 256, 256, 3)).astype(model.dtype)
output = model(input_batch, training=training).numpy()
with torch.no_grad():
    # The Keras input is channels-last; move the channel axis to position 1 for PyTorch
    torch_output = torch_model(torch.tensor(input_batch.transpose((0, 3, 1, 2)))).numpy()
# Largest absolute element-wise difference between the two frameworks' outputs
print(np.abs(output - torch_output).max())