In [None]:
import time
import torch
import requests
import numpy as np
from PIL import Image
from transformers import ViTModel, ViTConfig, AutoImageProcessor

from vit.vit import VIT
from vit.utils import transfer_pretrained_weights

In [None]:
# Fixed seed so that any random initialisation done by torch is repeatable
# across kernel restarts.
SEED = 0
torch.manual_seed(SEED)
# Optional stricter setting, left disabled:
# torch.use_deterministic_algorithms(True)
# Turn off the cuDNN benchmark setting for this session.
torch.backends.cudnn.benchmark = False

In [None]:
def diff(a, b):
    """Return the largest element-wise absolute difference between ``a`` and ``b``.

    The result is whatever ``.max()`` yields on ``torch.abs(a - b)``, i.e. a
    zero-dimensional tensor when both arguments are tensors.
    """
    delta = a - b
    return torch.abs(delta).max()

# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs end to end on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
dtype = torch.float32
model_id = 'google/vit-base-patch16-224'
# Load the configuration that belongs to the checkpoint. Calling
# `ViTConfig(model_id)` would instead bind the string to the first positional
# parameter (hidden_size) and leave every other field at its library default.
vit_config = ViTConfig.from_pretrained(model_id)

**Loading weights**

In [None]:
# Architecture hyper-parameters taken from the Hugging Face config object.
height, width, channels = vit_config.image_size, vit_config.image_size, vit_config.num_channels
patch_size = vit_config.patch_size
# Kept as a literal; expected to equal the checkpoint's hidden size -- TODO confirm.
hidden_dim = 768
num_heads = vit_config.num_attention_heads
num_layers = vit_config.num_hidden_layers

# Custom implementation under test.
model = VIT(
    height=height,
    width=width,
    channels=channels,
    patch_size=patch_size,
    hidden_dim=hidden_dim,
    num_heads=num_heads,
    num_layers=num_layers
)
model.to(device, dtype)

# Reference implementation; the pooling head is not part of the comparison.
pretrained_model = ViTModel.from_pretrained(model_id, add_pooling_layer=False)
pretrained_model.to(device, dtype)
pretrained_model.eval()

# Copy the reference weights into the custom model.
model = transfer_pretrained_weights(
    pretrained_model=pretrained_model,
    custom_model=model
)
# Put the custom model in inference mode as well, so that any train-only
# behaviour (e.g. dropout, if VIT uses it) cannot make the layer-by-layer
# comparison non-deterministic.
model.eval()

# Displayed as the cell output: parameter counts of (reference, custom).
sum(p.numel() for p in pretrained_model.parameters()), sum(p.numel() for p in model.parameters())

In [None]:
# Display the reference (Hugging Face) model so its module names can be read off.
pretrained_model

In [None]:
# Display the custom model so its module names can be read off.
model

**Verifying layer-by-layer outputs**

In [None]:
# Load the image processor associated with the same checkpoint id as the
# reference model (presumably its resize / rescale / normalise settings --
# confirm against the checkpoint's preprocessor config).
image_processor = AutoImageProcessor.from_pretrained(model_id)

In [None]:
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'

# Fail fast on network problems instead of hanging or decoding an error page.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()

# Force three channels so grayscale / RGBA inputs would also work.
image_pil = Image.open(response.raw).convert('RGB')
# PIL's resize expects (width, height), not (height, width).
image_pil = image_pil.resize((width, height))

# Raw pixel values as a tensor; only used for the shape report below.
image = torch.as_tensor(np.array(image_pil), dtype=dtype, device=device)

# Preprocess from the PIL image on the host and move only the final
# `pixel_values` tensor to the target device.
inputs = image_processor(images=image_pil, return_tensors="pt")
inputs = inputs['pixel_values'].to(device, dtype)


print(f'Input image shape: {image.shape}')

In [None]:
# Per-model dictionaries: module name -> output of that module's last forward call.
store_pretrained, store_custom = {}, {}


def hook(module, input, output, name, store):
    """Record ``output`` in ``store`` under ``name``.

    ``module`` and ``input`` are accepted because forward hooks receive them,
    but they are not used.
    """
    store[name] = output


def _attach_recorders(network, store):
    """Register a recording forward hook on every module reported by ``named_modules()``."""
    for module_name, submodule in network.named_modules():
        # Bind the current name as a default argument so each hook keeps its own.
        submodule.register_forward_hook(
            lambda mod, args, result, name=module_name: hook(mod, args, result, name, store)
        )


_attach_recorders(pretrained_model, store_pretrained)
_attach_recorders(model, store_custom)

In [None]:
# Time each forward pass separately. The original second measurement
# subtracted the first *duration* from an absolute timestamp, which produced a
# meaningless value; each model now gets its own start time.
# NOTE(review): the first pass may include one-off start-up cost, and both
# passes include the overhead of the recording hooks.
with torch.no_grad():
    start_time = time.perf_counter()
    pretrained_out = pretrained_model(inputs)
    if inputs.is_cuda:
        # CUDA kernels run asynchronously; wait for them before reading the clock.
        torch.cuda.synchronize()
    m1_time = time.perf_counter() - start_time

    start_time = time.perf_counter()
    custom_out = model(inputs)
    if inputs.is_cuda:
        torch.cuda.synchronize()
    m2_time = time.perf_counter() - start_time

In [None]:
# Display the two timing values recorded in the forward-pass cell
# (reference model first, custom model second).
m1_time, m2_time

In [None]:
# Maps a module name in the reference model to its counterpart(s) in the
# custom model:
#   * str  -> a single custom module whose output is compared directly
#   * list -> several custom modules whose outputs are concatenated along the
#             last dimension before comparison
#   * None -> no comparison for this reference module
weight_mapping = {
    'embeddings.patch_embeddings.projection': 'embeddings.projection',
    'embeddings.patch_embeddings': None,
    'embeddings.dropout': None,
    'embeddings': 'embeddings',
    'layernorm': 'layernorm',
}

# One group of entries per encoder *layer* (the original looped over
# `num_heads`, which only worked because both counts happen to be equal).
for i in range(num_layers):
    layer_prefix = f'encoder.layer.{i}'
    attn_prefix = f'{layer_prefix}.attention.attention'
    # Custom per-head module names, generated from `num_heads` instead of a
    # hand-written list of twelve entries.
    head_prefixes = [f'{attn_prefix}.{h}' for h in range(num_heads)]

    weight_mapping.update({
        f'{layer_prefix}.layernorm_before': f'{layer_prefix}.layernorm_before',
        f'{attn_prefix}.query': [f'{head}.query' for head in head_prefixes],
        f'{attn_prefix}.key': [f'{head}.key' for head in head_prefixes],
        f'{attn_prefix}.value': [f'{head}.value' for head in head_prefixes],
        f'{attn_prefix}.dropout': None,
        f'{attn_prefix}': list(head_prefixes),
        f'{layer_prefix}.attention.output.dense': None,
        f'{layer_prefix}.attention.output.dropout': None,
        f'{layer_prefix}.attention.output': f'{layer_prefix}.attention.output',
        f'{layer_prefix}.attention': f'{layer_prefix}.attention',
        f'{layer_prefix}.layernorm_after': f'{layer_prefix}.layernorm_after',
        f'{layer_prefix}.intermediate.dense': f'{layer_prefix}.intermediate',
        f'{layer_prefix}.intermediate.intermediate_act_fn': None,
        f'{layer_prefix}.intermediate': None,
        f'{layer_prefix}.output.dense': f'{layer_prefix}.output',
        f'{layer_prefix}.output.dropout': None,
        f'{layer_prefix}.output': None,
        f'{layer_prefix}': f'{layer_prefix}',
    })

In [None]:
# Compare the recorded output of every mapped reference module with the
# output(s) of its custom counterpart and report the maximum absolute error.
mismatched = []

for k, v in weight_mapping.items():
    # Entries mapped to None have no custom counterpart to compare against.
    if not v:
        continue

    if isinstance(v, list):
        # Several custom modules: stitch their outputs together on the last dim.
        val2 = torch.cat([store_custom[e] for e in v], dim=-1)
    else:
        val2 = store_custom[v]

    val1 = store_pretrained[k]
    # Some reference modules return a tuple; the first element is compared.
    if isinstance(val1, tuple):
        val1 = val1[0]

    max_abs_diff = torch.abs(val1 - val2).max()
    # NOTE(review): atol=1 is a very loose tolerance -- confirm it is intended.
    match = torch.allclose(val1, val2, rtol=0, atol=1)
    if not match:
        mismatched.append(k)

    # The original computed `match` but never reported it.
    print(f'{k}\t{val1.shape}, {v}\t{val2.shape}\t{max_abs_diff}\tmatch={match}')

    print('\n')

print(f'{len(mismatched)} mismatched module(s): {mismatched}')

In [None]:
# Display the custom model's structure for reference next to the comparison output.
model

In [None]:
# Display the reference model's structure for reference next to the comparison output.
pretrained_model

In [None]:
# Encoder layer to inspect in isolation; change this one value to probe another layer.
probe_layer = 4

# i*: recorded output of the intermediate projection (reference / custom).
# q*: the module that consumes it, i.e. the layer's output projection.
i1 = store_pretrained[f'encoder.layer.{probe_layer}.intermediate.dense']
q1 = pretrained_model.encoder.layer[probe_layer].output.dense
i2 = store_custom[f'encoder.layer.{probe_layer}.intermediate']
q2 = model.encoder.layer[probe_layer].output

In [None]:
# Max absolute difference when each output module is applied to its own
# model's recorded intermediate activation.
diff(q1(i1), q2(i2))

In [None]:
# Apply every output module to every recorded input so that input differences
# and module differences can be separated. These are diagnostics only, so run
# them without autograd, consistent with the main forward-pass cell.
with torch.no_grad():
    q1i1 = q1(i1)
    q1i2 = q1(i2)
    q2i1 = q2(i1)
    q2i2 = q2(i2)

In [None]:
# Max absolute difference between the two recorded intermediate activations.
diff(i1, i2)

In [None]:
# Same module (q1), different inputs: shows how much the input gap alone changes the output.
diff(q1i1, q1i2)

In [None]:
# Same input (i1), different modules: shows how much q1 and q2 themselves disagree.
diff(q1i1, q2i1)

In [None]:
# Same input (i2), different modules: the module comparison repeated on the custom activation.
diff(q1i2, q2i2)

**Assigning identical (all-ones) weights to both models**

In [None]:
# Grab the current state dict of the reference model and of the custom model.
pretrained_state_dict, custom_state_dict = pretrained_model.state_dict(), model.state_dict()

In [None]:
def _all_ones_like(state_dict):
    """Build a dict with the same keys where every tensor is replaced by ones.

    Each replacement has the shape of the original entry and is moved to the
    notebook's ``device`` / ``dtype``.
    """
    return {
        key: torch.ones_like(tensor).to(device, dtype)
        for key, tensor in state_dict.items()
    }


pretrained_state_dict_new = _all_ones_like(pretrained_state_dict)
custom_state_dict_new = _all_ones_like(custom_state_dict)

In [None]:
# Overwrite both models in place with the all-ones state dicts built above.
# After this cell the transferred / pretrained weights are gone; reload the
# models before re-running the earlier comparison cells.
model.load_state_dict(custom_state_dict_new)
# Last expression of the cell, so its return value is what the notebook displays.
pretrained_model.load_state_dict(pretrained_state_dict_new)