In [None]:
import torch
from transformers import ViTModel

from vit.vit import VIT, SelfAttention

In [None]:
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs on machines without a GPU. Half precision is only selected on the GPU;
# the CPU fallback uses float32.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
dtype = torch.float16 if device.startswith('cuda') else torch.float32

pretrained_model = ViTModel.from_pretrained("google/vit-base-patch16-224")

custom_vit_model = VIT(
    height=224,
    width=224,
    channels=3,
    patch_size=16,
    hidden_dim=768,
    num_heads=12,
    num_layers=12,
)

pretrained_model.to(device, dtype)
# Snapshot the state dict only after the device/dtype conversion: a state dict
# taken before `.to(...)` keeps referencing the old (pre-conversion) tensors.
pretrained_state_dict = pretrained_model.state_dict()

# Last expression of the cell: moves the custom model and displays its repr.
custom_vit_model.to(device, dtype)

In [None]:
# Bare expression: the notebook renders the pretrained model's repr as this cell's output.
pretrained_model

In [None]:
# List every parameter/buffer name of the pretrained model with its shape.
# The tensor yielded by items() is used directly, so the state dict is built
# once instead of once per printed line.
for name, tensor in pretrained_model.state_dict().items():
    print(f'{name}\t{tensor.shape}')

In [None]:
# List the custom model's parameter/buffer names and shapes, skipping every
# entry whose name contains 'query', 'key' or 'value' (those are handled
# separately further down). The state dict is built once and the tensor from
# items() is reused rather than looked up again per line.
for name, tensor in custom_vit_model.state_dict().items():
    if any(token in name for token in ('query', 'key', 'value')):
        continue
    print(f'{name}\t{tensor.shape}')

- useful ref: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-across-devices
- Most of the layers are named consistently between the two models, but the query, key and value (q, k, v) tensors need to be split per attention head and loaded into the matching head of the custom model.

In [None]:
# State dict of the custom model: its entries are overwritten by the cells
# below and then loaded back into the model with load_state_dict.
custom_state_dict = custom_vit_model.state_dict()
# State dict of the pretrained model: the source of the copied tensors.
pretrained_state_dict = pretrained_model.state_dict()

In [None]:
# Number of encoder layers; drives both the attention prefixes and the mapping below.
num_layers = 12
# Key prefixes of the per-layer attention projections in the source model.
attention_layers = [f'encoder.layer.{k}.attention.attention' for k in range(num_layers)]

# Mapping dictionary from source model to destination model
weight_mapping = {
    'embeddings.position_embeddings': 'embeddings.position_embeddings',
    # 'embeddings.patch_embeddings.projection.weight': 'embeddings.projection.weight',
    # 'embeddings.patch_embeddings.projection.bias': 'embeddings.projection.bias'
}

# (source suffix, destination suffix) pairs that repeat for every encoder layer.
# The source model wraps its linear layers in a `.dense` attribute, the
# destination names drop it; the layer norms keep the same name on both sides.
encoder_suffix_pairs = [
    ('output.dense.weight', 'output.weight'),
    ('output.dense.bias', 'output.bias'),
    ('intermediate.dense.weight', 'intermediate.weight'),
    ('intermediate.dense.bias', 'intermediate.bias'),
    ('attention.output.dense.weight', 'attention.output.weight'),
    ('attention.output.dense.bias', 'attention.output.bias'),
    ('layernorm_before.weight', 'layernorm_before.weight'),
    ('layernorm_before.bias', 'layernorm_before.bias'),
    ('layernorm_after.weight', 'layernorm_after.weight'),
    ('layernorm_after.bias', 'layernorm_after.bias'),
]

# Adding mappings for each encoder layer's output and intermediate dense layers.
# Iterates over `num_layers` (not a second hard-coded 12) so the prefixes and
# the mapping cannot drift apart.
for i in range(num_layers):
    for src_suffix, dst_suffix in encoder_suffix_pairs:
        weight_mapping[f'encoder.layer.{i}.{src_suffix}'] = f'encoder.layer.{i}.{dst_suffix}'

In [None]:
def map_attn_layers(source_layer_num, source_proj, source_type, source_tensor, dest_state_dict, head_dim=64):
    """Copy one fused attention tensor into the matching per-head destination entries.

    Destination keys are split on '.', and a key is treated as a match when
    its third component equals ``source_layer_num``, its second-to-last
    component equals ``source_proj`` and its last component equals
    ``source_type``. For a matching key, the sixth component is parsed as the
    head index and the corresponding ``head_dim``-wide slice of
    ``source_tensor`` is cloned into ``dest_state_dict`` under that key.

    Args:
        source_layer_num: Layer index as a string (compared against a key component).
        source_proj: Projection name to match (second-to-last key component).
        source_type: Parameter type to match (last key component); the value
            'weight' selects 2-D slicing, anything else selects 1-D slicing.
        source_tensor: Fused tensor covering all heads.
        dest_state_dict: Mapping of destination keys to tensors; matching
            entries are replaced in place.
        head_dim: Width of one head's slice. Defaults to 64, the value that
            was previously hard-coded.

    Returns:
        None. ``dest_state_dict`` is modified in place.
    """
    for layer, weight in dest_state_dict.items():
        # Split once and reuse the parts instead of re-splitting per field.
        parts = layer.split('.')
        if len(parts) <= 2:
            continue

        # `param_type` rather than `type`, so the builtin is not shadowed.
        layer_num, proj, param_type = parts[2], parts[-2], parts[-1]
        if (layer_num, proj, param_type) != (source_layer_num, source_proj, source_type):
            continue

        num_head = int(parts[5])
        head_slice = slice(num_head * head_dim, (num_head + 1) * head_dim)

        if param_type == 'weight':
            # NOTE(review): weights are sliced along the LAST dimension. If the
            # source is a torch.nn.Linear weight laid out as
            # (out_features, in_features), one head's outputs live in ROWS
            # head_slice, i.e. source_tensor[head_slice, :] (transposed for an
            # (in, out) destination). Confirm this column slice matches the
            # destination model's convention.
            src = source_tensor[:, head_slice]
        else:
            src = source_tensor[head_slice]

        print(f'Mapping to destination\tLayer: {layer}\tSource tensor shape: {src.shape}\tDestination tensor shape{weight.shape}')

        # Replacing the value of an existing key does not resize the dict, so
        # it is safe while iterating over items().
        dest_state_dict[layer] = src.clone()

In [None]:
# Walk the pretrained state dict and, for every tensor that belongs to one of
# the attention prefixes, distribute it across the custom model's per-head entries.
for layer_name, weight in pretrained_state_dict.items():
    if not any(attn_layer in layer_name for attn_layer in attention_layers):
        continue

    name_parts = layer_name.split('.')
    # `param_type` instead of `type`: assigning to `type` at cell level would
    # rebind the builtin for the rest of the kernel session.
    layer_num, proj, param_type = name_parts[2], name_parts[-2], name_parts[-1]
    print(f'Mapping from source\tLayer number: {layer_num} \tProjection: {proj} \tType: {param_type}')
    map_attn_layers(layer_num, proj, param_type, weight, custom_state_dict)
    print('\n')

In [None]:
# Copy every pretrained tensor that has an entry in `weight_mapping` (and whose
# destination key exists in the custom state dict); report everything else.
for src_key, src_tensor in pretrained_state_dict.items():
    dst_key = weight_mapping.get(src_key)

    if not dst_key or dst_key not in custom_state_dict:
        print(f'Not transferring weight for: {src_key}')
        continue

    # Non-attention 'output' / 'intermediate' entries are stored transposed in
    # the destination; everything else is copied as-is.
    needs_transpose = 'attention' not in dst_key and (
        'output' in dst_key or 'intermediate' in dst_key
    )
    if needs_transpose:
        custom_state_dict[dst_key] = src_tensor.t().clone()
    else:
        custom_state_dict[dst_key] = src_tensor.clone()

In [None]:
# Load the edited state dict back into the custom model. With strict=False,
# missing or unexpected keys are reported in the returned result (shown as the
# cell output) instead of raising.
custom_vit_model.load_state_dict(custom_state_dict, strict=False)

In [None]:
# Re-read the state dict from the model and print the name of every entry
# whose elements are all exactly zero.
custom_state_dict = custom_vit_model.state_dict()
for name, tensor in custom_state_dict.items():
    if (tensor == 0).all():
        print(name)