In [None]:
import torch
from transformers import ViTModel

from vit.vit import VIT, SelfAttention

In [None]:
# Load the reference Hugging Face checkpoint and snapshot its weights.
pretrained_model = ViTModel.from_pretrained("google/vit-base-patch16-224")
pretrained_state_dict = pretrained_model.state_dict()

# Hyper-parameters for the hand-written ViT. Patch size and image size follow
# the checkpoint name (patch16, 224); the hidden size / head count / depth are
# presumably the ViT-Base values -- TODO confirm against pretrained_model.config.
vit_config = {
    "height": 224,
    "width": 224,
    "channels": 3,
    "patch_size": 16,
    "hidden_dim": 768,
    "num_heads": 12,
    "num_layers": 12,
}
custom_vit_model = VIT(**vit_config)

In [None]:
# List every parameter/buffer of the pretrained model with its shape.
# state_dict() is built once and the tensor yielded by .items() is used
# directly, instead of rebuilding the whole dict on every iteration.
for name, tensor in pretrained_model.state_dict().items():
    print(f'{name}\t{tensor.shape}')

In [None]:
# List every parameter/buffer of the custom model with its shape, for
# side-by-side comparison with the pretrained listing.
# state_dict() is built once and the tensor yielded by .items() is used
# directly, instead of rebuilding the whole dict on every iteration.
for name, tensor in custom_vit_model.state_dict().items():
    print(f'{name}\t{tensor.shape}')

Useful reference (PyTorch tutorial on saving and loading models across devices): https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-across-devices

In [None]:
# Transfer selected weights from the pretrained Hugging Face model into the
# custom ViT, then move the custom model to the compute device.

# Mapping from pretrained state-dict keys to custom-model state-dict keys.
# The correspondence has to be established by hand from the two key listings.
key_mapping = {
    'embeddings.patch_embeddings.projection.weight': 'projection',
    'embeddings.position_embeddings': 'positional_embedding',
    # Add more mappings as required
}

# Start from the custom model's own state dict so that every key we do not
# map keeps its current (freshly initialised) value.
custom_state_dict = custom_vit_model.state_dict()

# Overwrite the mapped entries with copies of the pretrained tensors.
for key, value in pretrained_model.state_dict().items():
    mapped_key = key_mapping.get(key)
    if not mapped_key or mapped_key not in custom_state_dict:
        continue

    # load_state_dict() raises on a size mismatch even with strict=False,
    # which would abort the whole transfer. Report and skip such entries
    # so the offending mapping is easy to identify and fix.
    target_shape = custom_state_dict[mapped_key].shape
    if value.shape != target_shape:
        print(
            f"Skipping {mapped_key}: pretrained shape {tuple(value.shape)} "
            f"!= custom shape {tuple(target_shape)}"
        )
        continue

    print(f"Transferring weight for {mapped_key}")
    custom_state_dict[mapped_key] = value.clone()

custom_vit_model.load_state_dict(custom_state_dict, strict=False)

# Move the model to the first GPU when one is available; fall back to the CPU
# so the notebook still runs end-to-end on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
custom_vit_model.to(device)
