In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from sparix.trans import Pad
from sparix.modeling.transformer import Transformer
from sparix.data  import FrameDataset

In [3]:
import h5py
import math
import random

In [4]:
# Run on the first visible CUDA GPU when there is one; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [5]:
# Load the pretrained ViT checkpoint (presumably a torch-saved dict converted
# from google/vit-base-patch16-224-in21k -- TODO confirm provenance).
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# a CPU-only machine (the `device` fallback above); load_state_dict later copies
# the tensors onto whatever device the model lives on.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
path_model_weight = "chkpts/google.vit-base-patch16-224-in21k"
model_weight = torch.load(path_model_weight, map_location="cpu")

# Fail here with a clear message instead of letting a `None` from `.get`
# surface later as an opaque TypeError when the weights are indexed.
if "model_state_dict" not in model_weight:
    raise KeyError(
        f"checkpoint {path_model_weight!r} has no 'model_state_dict' entry; "
        f"found keys: {list(model_weight)}"
    )
model_weight_dict = model_weight["model_state_dict"]

In [6]:
# Hyper-parameters of the custom Transformer. Its transformer blocks are
# initialised from ViT-Base (hidden size 768, 12 attention heads), so
# `embd_size` and `num_heads` must match that architecture for the loaded
# attention weights to behave as they were trained.
Hp, Wp = 16, 16                 # patch size; each token carries Hp * Wp values
num_frame_in_context = 4
num_patch = 400                 # patches per frame

tok_size            = Hp * Wp
embd_size           = 768    # must equal ViT-Base hidden size (google's pretrained ViT)
context_length      = num_frame_in_context * num_patch
num_blocks          = 4      # ViT-Base has 12 encoder layers; the first num_blocks are reused
num_heads           = 12     # must equal ViT-Base's 12 heads so Q/K/V split into the 64-dim heads they were trained with
# NOTE(review): ViT was pretrained with bidirectional attention; the causal
# mask is a deliberate change for this model -- confirm it is intended.
uses_causal_mask    = True
attention_dropout   = 0.1
residual_dropout    = 0.1
feedforward_dropout = 0.1

# The token/position embeddings are not taken from the checkpoint and start
# from random init, so seed torch to make the notebook reproducible.
seed = 0
torch.manual_seed(seed)

model = Transformer(tok_size            = tok_size,
                    embd_size           = embd_size,
                    context_length      = context_length,
                    num_blocks          = num_blocks,
                    num_heads           = num_heads,
                    uses_causal_mask    = uses_causal_mask,
                    attention_dropout   = attention_dropout,
                    residual_dropout    = residual_dropout,
                    feedforward_dropout = feedforward_dropout,)
model.to(device)

Transformer(
  (tok_embd_layer): Linear(in_features=256, out_features=768, bias=True)
  (pos_embd_layer): Embedding(1600, 768)
  (transformer_block): Sequential(
    (0): TransformerBlock(
      (multi_head_att_layer): MultiHeadAttention(
        (proj_q): Linear(in_features=768, out_features=768, bias=True)
        (proj_k): Linear(in_features=768, out_features=768, bias=True)
        (proj_v): Linear(in_features=768, out_features=768, bias=True)
        (proj_linear): Linear(in_features=768, out_features=768, bias=True)
        (attention_dropout): Dropout(p=0.1, inplace=False)
        (residual_dropout): Dropout(p=0.1, inplace=False)
      )
      (ff_layer): FeedForward(
        (ff_layer): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
          (3): Dropout(p=0.1, inplace=False)
        )
      )
      (layer_norm_pre_multi_head): LayerNorm((768,), eps=1e-0

In [9]:
# Print every parameter/buffer name of the custom model, one per line; these
# are the names the ViT -> custom mapping has to target.
for param_name in model.state_dict().keys():
    print(param_name)

pos_indices
tok_embd_layer.weight
tok_embd_layer.bias
pos_embd_layer.weight
transformer_block.0.multi_head_att_layer.mask
transformer_block.0.multi_head_att_layer.proj_q.weight
transformer_block.0.multi_head_att_layer.proj_q.bias
transformer_block.0.multi_head_att_layer.proj_k.weight
transformer_block.0.multi_head_att_layer.proj_k.bias
transformer_block.0.multi_head_att_layer.proj_v.weight
transformer_block.0.multi_head_att_layer.proj_v.bias
transformer_block.0.multi_head_att_layer.proj_linear.weight
transformer_block.0.multi_head_att_layer.proj_linear.bias
transformer_block.0.ff_layer.ff_layer.0.weight
transformer_block.0.ff_layer.ff_layer.0.bias
transformer_block.0.ff_layer.ff_layer.2.weight
transformer_block.0.ff_layer.ff_layer.2.bias
transformer_block.0.layer_norm_pre_multi_head.weight
transformer_block.0.layer_norm_pre_multi_head.bias
transformer_block.0.layer_norm_pre_feedforward.weight
transformer_block.0.layer_norm_pre_feedforward.bias
transformer_block.1.multi_head_att_layer.m

In [10]:
# Parameter-name mapping used to initialise the custom Transformer from the
# pretrained ViT checkpoint.
# NOTE: despite the variable name, keys are *custom* model names and values
# are the corresponding *google/HF ViT* names (custom -> google).
# The mapping is generated from `num_blocks`, so it stays in sync with the
# model definition: custom block i takes ViT encoder layer i.
_vit_suffix_by_custom_suffix = {
    "multi_head_att_layer.proj_q.weight"      : "attention.attention.query.weight",
    "multi_head_att_layer.proj_q.bias"        : "attention.attention.query.bias",
    "multi_head_att_layer.proj_k.weight"      : "attention.attention.key.weight",
    "multi_head_att_layer.proj_k.bias"        : "attention.attention.key.bias",
    "multi_head_att_layer.proj_v.weight"      : "attention.attention.value.weight",
    "multi_head_att_layer.proj_v.bias"        : "attention.attention.value.bias",
    "multi_head_att_layer.proj_linear.weight" : "attention.output.dense.weight",
    "multi_head_att_layer.proj_linear.bias"   : "attention.output.dense.bias",
    "ff_layer.ff_layer.0.weight"              : "intermediate.dense.weight",
    "ff_layer.ff_layer.0.bias"                : "intermediate.dense.bias",
    "ff_layer.ff_layer.2.weight"              : "output.dense.weight",
    "ff_layer.ff_layer.2.bias"                : "output.dense.bias",
    "layer_norm_pre_multi_head.weight"        : "layernorm_before.weight",
    "layer_norm_pre_multi_head.bias"          : "layernorm_before.bias",
    "layer_norm_pre_feedforward.weight"       : "layernorm_after.weight",
    "layer_norm_pre_feedforward.bias"         : "layernorm_after.bias",
}

google_to_custom_dict = {
    f"transformer_block.{block_idx}.{custom_suffix}": f"vit.encoder.layer.{block_idx}.{vit_suffix}"
    for block_idx in range(num_blocks)
    for custom_suffix, vit_suffix in _vit_suffix_by_custom_suffix.items()
}

In [12]:
# Start from the custom model's own state dict and overwrite every entry that
# has a ViT counterpart in `google_to_custom_dict`. Entries without a mapping
# (token/position embeddings and buffers such as `pos_indices` and the
# attention masks) keep their current values; they are collected and printed
# so a missing mapping (e.g. after changing `num_blocks`) is visible instead of
# silently leaving randomly initialised weights.
encoder_state_dict = model.state_dict()
unmapped_keys = []
for k in encoder_state_dict.keys():
    if k not in google_to_custom_dict:
        unmapped_keys.append(k)
        continue
    k_google = google_to_custom_dict[k]
    w_google = model_weight_dict[k_google]  # KeyError if the checkpoint lacks this layer
    # load_state_dict would also reject this, but naming both keys makes the
    # mismatch (e.g. wrong embd_size) easy to diagnose.
    if w_google.shape != encoder_state_dict[k].shape:
        raise ValueError(
            f"shape mismatch for {k!r} <- {k_google!r}: "
            f"model {tuple(encoder_state_dict[k].shape)} vs checkpoint {tuple(w_google.shape)}"
        )
    encoder_state_dict[k] = w_google

print("Not initialised from the pretrained checkpoint:", *unmapped_keys, sep="\n  ")

In [13]:
# Copy the merged weights into the model. strict=True (PyTorch's default) raises
# if any key is missing or unexpected; the returned _IncompatibleKeys object is
# left as the cell's last expression so its summary is displayed.
model.load_state_dict(encoder_state_dict, strict=True)

<All keys matched successfully>