In [None]:
# Make CUDA kernel launches synchronous so device-side errors are reported at the call
# that caused them. `!export` only sets the variable in a throwaway subshell, so set it in
# the kernel's own environment instead (it must be set before CUDA is initialised).
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [None]:
import torch
from transformers import ViTModel, ViTConfig

from vit.vit import VIT
from vit.load_weights import transfer_pretrained_weights

from vit.utils import timed

In [None]:
# Reproducibility settings for the random inputs and kernel selection.
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False  # benchmark-based cuDNN algorithm choice can vary run to run

# Set True to force deterministic algorithms (ops without one raise an error; on CUDA this
# also needs CUBLAS_WORKSPACE_CONFIG=:4096:8 in the environment). False is PyTorch's default.
DETERMINISTIC = False
torch.use_deterministic_algorithms(DETERMINISTIC)

def diff(a, b):
    """Mean absolute element-wise difference between tensors `a` and `b` (broadcasting applies)."""
    delta = a - b
    return delta.abs().mean()

device = 'cuda:0'
dtype = torch.float32
model_id = 'google/vit-base-patch16-224'
vit_config = ViTConfig(model_id)

**Loading weights**

In [None]:
height, width, channels = vit_config.image_size, vit_config.image_size, vit_config.num_channels
patch_size = vit_config.patch_size
hidden_dim = 768
num_heads=vit_config.num_attention_heads
num_layers=vit_config.num_hidden_layers

# Custom ViT built with the same geometry as the pretrained checkpoint.
model = VIT(
    height=height,
    width=width,
    channels=channels,
    patch_size=patch_size,
    hidden_dim=hidden_dim,
    num_heads=num_heads,
    num_layers=num_layers
)
model.to(device, dtype)

# Reference Hugging Face model, without the pooling head, in inference mode.
pretrained_model = ViTModel.from_pretrained(model_id, add_pooling_layer=False)
pretrained_model.to(device, dtype)
pretrained_model.eval()

model = transfer_pretrained_weights(
    pretrained_model=pretrained_model,
    custom_model=model
)
# Put the custom model in the same state as the reference (device/dtype + eval mode, which
# disables dropout-style layers). Applied to the returned object, since the helper's result
# is what `model` now refers to.
model.to(device, dtype)
model.eval()

def count_params(module):
    """Total number of scalar elements across all of `module`'s parameters."""
    return sum(param.numel() for param in module.parameters())

# Parameter counts as (pretrained, custom); equal counts are expected if every weight was ported — TODO confirm.
count_params(pretrained_model), count_params(model)

**Verifying layer-by-layer outputs**

In [None]:
# Record every submodule's forward output, keyed by its qualified module name, so the
# two models can be compared layer by layer after a forward pass.
store_pretrained = {}
store_custom = {}

def hook(module, input, output, name, store):
    """Forward hook: save the `output` of the module registered as `name` into `store`."""
    store[name] = output

# Re-running this cell would otherwise stack another set of hooks on every module, so
# detach the ones registered by a previous run first. Hooks stay attached after this
# cell; call `.remove()` on the entries of `hook_handles` to detach them.
for handle in globals().get('hook_handles', []):
    handle.remove()
hook_handles = []

# `name=name` binds the current module name at definition time; a plain closure would
# see only the last value of the loop variable.
for name, layer in pretrained_model.named_modules():
    hook_handles.append(layer.register_forward_hook(
        lambda layer, input, output, name=name: hook(layer, input, output, name, store_pretrained)
    ))

for name, layer in model.named_modules():
    hook_handles.append(layer.register_forward_hook(
        lambda layer, input, output, name=name: hook(layer, input, output, name, store_custom)
    ))

In [None]:
# Random batch of shape (batch, channels, height, width), sized from the model config
# instead of hard-coded numbers. Generated on CPU and then moved, as before.
batch_size = 4
inputs = torch.randn((batch_size, channels, height, width)).to(device, dtype)

# One forward pass through each model; the registered hooks fill store_pretrained / store_custom.
with torch.no_grad():
    pretrained_out = pretrained_model(inputs)[0]
    custom_out = model(inputs)

In [None]:
# Time one forward pass of each model on the same batch. `timed` comes from vit.utils;
# from its use here it presumably returns (output, elapsed_time) — TODO confirm.
# NOTE(review): the forward hooks registered earlier are still attached to both models,
# so these timings include hook overhead.
o1 = timed(pretrained_model, inputs)
o2 = timed(model, inputs)

pretrained_time = o1[1]
custom_time = o2[1]
print(custom_time / pretrained_time, custom_time, pretrained_time)  # custom/pretrained ratio, then raw times
# o1[0] is the HF model output; its first element is compared with the custom output.
print(diff(o1[0][0], o2[0]))

In [None]:
# Maps HF module names -> custom module names whose hooked forward outputs are compared
# (despite the variable name, these are module outputs, not weights). None means the HF
# module has no counterpart and is skipped. HF query/key/value all map to the custom
# fused `qkv` module, so their shapes are expected to differ.
weight_mapping = {
    'embeddings.patch_embeddings.projection': 'embeddings.projection',
    'embeddings.patch_embeddings': None,
    'embeddings.dropout': None,
    'embeddings': 'embeddings',
    'layernorm': 'layernorm',
}

# One set of entries per encoder layer (previously iterated over num_heads by mistake).
for i in range(num_layers):
    weight_mapping.update({
        f'encoder.layer.{i}.layernorm_before': f'encoder.layer.{i}.layernorm_before',
        f'encoder.layer.{i}.attention.attention.query': f'encoder.layer.{i}.attention.attention.qkv',
        f'encoder.layer.{i}.attention.attention.key': f'encoder.layer.{i}.attention.attention.qkv',
        f'encoder.layer.{i}.attention.attention.value': f'encoder.layer.{i}.attention.attention.qkv',
        f'encoder.layer.{i}.attention.attention.dropout': None,
        f'encoder.layer.{i}.attention.attention': f'encoder.layer.{i}.attention.attention',
        f'encoder.layer.{i}.attention.output.dense': None,
        f'encoder.layer.{i}.attention.output.dropout': None,
        f'encoder.layer.{i}.attention.output': f'encoder.layer.{i}.attention.output',
        f'encoder.layer.{i}.attention': f'encoder.layer.{i}.attention',
        f'encoder.layer.{i}.layernorm_after': f'encoder.layer.{i}.layernorm_after',
        f'encoder.layer.{i}.intermediate.dense': f'encoder.layer.{i}.intermediate',
        f'encoder.layer.{i}.intermediate.intermediate_act_fn': None,
        f'encoder.layer.{i}.intermediate': None,
        f'encoder.layer.{i}.output.dense': f'encoder.layer.{i}.output',
        f'encoder.layer.{i}.output.dropout': None,
        f'encoder.layer.{i}.output': None,
        f'encoder.layer.{i}': f'encoder.layer.{i}',
    })

In [None]:
# Compare each mapped pair of captured outputs. Pairs whose shapes differ are listed but
# not compared numerically.
# Tolerances for reporting a layer as matching; tune as needed (the original computed
# allclose with atol=1 and then discarded the result).
match_rtol, match_atol = 1e-4, 1e-4


def _first(x):
    """Unwrap a tuple output (as HF modules often return) to its first element."""
    return x[0] if isinstance(x, tuple) else x


for k, v in weight_mapping.items():
    if not (k and v):
        continue  # no counterpart in the custom model

    # Report names the hooks never captured instead of aborting the whole loop on KeyError.
    custom_names = v if isinstance(v, list) else [v]
    missing = [n for n in custom_names if n not in store_custom]
    if k not in store_pretrained:
        missing.append(k)
    if missing:
        print(f'{k}: no captured output for {missing}')
        continue

    val1 = _first(store_pretrained[k])
    if isinstance(v, list):
        # Several custom modules together correspond to one HF module: join along features.
        val2 = torch.cat([_first(store_custom[e]) for e in v], dim=-1)
    else:
        val2 = _first(store_custom[v])

    if val1.shape != val2.shape:
        print(f'{k}: {val1.shape}\t\t\t\t{v}: {val2.shape}')
        continue

    max_abs = (val1 - val2).abs().max().item()
    match = torch.allclose(val1, val2, rtol=match_rtol, atol=match_atol)
    print(f'{k}: {val1.shape}\t\t\t\t{v}: {val2.shape}\t\t\t\t{max_abs:.3e}\t\tmatch={match}')

**Rough check: layer-0 key projection vs. the key slice of the fused QKV projection**

In [None]:
i1 = store_pretrained['encoder.layer.0.layernorm_before']
i2 = store_custom['encoder.layer.0.layernorm_before']

w1 = pretrained_model.encoder.layer[0].attention.attention.key.weight
b1 = pretrained_model.encoder.layer[0].attention.attention.key.bias

w2 = model.encoder.layer[0].attention.attention.qkv.weight
b2 = model.encoder.layer[0].attention.attention.qkv.bias

z1 = torch.matmul(i1, w1.T)
z2 = torch.matmul(i2, w2)
z2 = z2[:, :, 768:1536]

v1 = store_pretrained['encoder.layer.0.attention.attention.key']
v2 = store_custom['encoder.layer.0.attention.attention.qkv']
v2 = v2[:, :, 768:1536]

In [None]:
diff(i1, i2), diff(w1.T, w2[:, 768:1536]), diff(b1, b2[768:1536])

In [None]:
diff(z1, z2), diff(v1, v2), diff(z1, v1), diff(z2, v2)

**Assigning constant (all-ones) weights for debugging**

In [None]:
# Name -> tensor maps of each model's parameters and buffers. NOTE: state_dict() tensors
# share storage with the live parameters, so they change if the weights are overwritten later.
pretrained_state_dict, custom_state_dict = pretrained_model.state_dict(), model.state_dict()

In [None]:
def _ones_if_float(t):
    """All-ones tensor shaped like `t` (same device/dtype) for floating-point tensors;
    other dtypes (e.g. integer index buffers) are copied unchanged so they stay valid."""
    return torch.ones_like(t) if t.is_floating_point() else t.clone()

# ones_like already keeps each tensor's device and dtype, so no extra .to() is needed; the
# previous .to(device, dtype) would also have cast any integer buffers to float.
pretrained_state_dict_new = {k: _ones_if_float(v) for k, v in pretrained_state_dict.items()}
custom_state_dict_new = {k: _ones_if_float(v) for k, v in custom_state_dict.items()}

In [None]:
# Overwrite both models with the constant weights. Destructive: the pretrained weights are
# gone until the weight-loading cells are re-run. The displayed value is the pretrained
# model's load result (missing/unexpected keys).
custom_load_result = model.load_state_dict(custom_state_dict_new)
pretrained_load_result = pretrained_model.load_state_dict(pretrained_state_dict_new)
pretrained_load_result