In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
import h5py

In [13]:
import math

In [33]:
import matplotlib.pyplot       as plt
import matplotlib.colors       as mcolors
import matplotlib.patches      as mpatches
import matplotlib.transforms   as mtransforms
import matplotlib.font_manager as font_manager
%matplotlib inline

In [3]:
# Run on the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [4]:
from poorman_transformer.modeling.transformer import Transformer, TransformerBlock

In [5]:
from sparix.trans import Pad

In [6]:
# Pretrained checkpoint (google/vit-base-patch16-224-in21k, per the file name).
path_model_weight = "chkpts/google.vit-base-patch16-224-in21k"
# map_location="cpu" lets a checkpoint whose tensors were saved from a GPU load on a CPU-only
# machine; the tensors are copied into each module's parameters later by load_state_dict.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
model_weight = torch.load(path_model_weight, map_location="cpu")

In [50]:
# Pull the parameter tensors out of the checkpoint dict. Indexing (instead of .get) raises a
# KeyError naming 'model_state_dict' if the checkpoint layout differs, rather than binding None
# here and failing later with an unrelated "'NoneType' object is not subscriptable".
model_weight_dict = model_weight['model_state_dict']

In [39]:
# Print the name of every tensor stored in the checkpoint's state dict, one per line.
for key in model_weight.get('model_state_dict'):
    print(key)

vit.embeddings.cls_token
vit.embeddings.position_embeddings
vit.embeddings.patch_embeddings.projection.weight
vit.embeddings.patch_embeddings.projection.bias
vit.encoder.layer.0.attention.attention.query.weight
vit.encoder.layer.0.attention.attention.query.bias
vit.encoder.layer.0.attention.attention.key.weight
vit.encoder.layer.0.attention.attention.key.bias
vit.encoder.layer.0.attention.attention.value.weight
vit.encoder.layer.0.attention.attention.value.bias
vit.encoder.layer.0.attention.output.dense.weight
vit.encoder.layer.0.attention.output.dense.bias
vit.encoder.layer.0.intermediate.dense.weight
vit.encoder.layer.0.intermediate.dense.bias
vit.encoder.layer.0.output.dense.weight
vit.encoder.layer.0.output.dense.bias
vit.encoder.layer.0.layernorm_before.weight
vit.encoder.layer.0.layernorm_before.bias
vit.encoder.layer.0.layernorm_after.weight
vit.encoder.layer.0.layernorm_after.bias
vit.encoder.layer.1.attention.attention.query.weight
vit.encoder.layer.1.attention.attention.query

In [41]:
# Shape of the pretrained position-embedding table. Presumably (1, 1 + 14 * 14, hidden):
# one [CLS] slot plus a 14 x 14 patch grid for a 224-px input at patch size 16, as the
# checkpoint name suggests.
model_weight_dict["vit.embeddings.position_embeddings"].size()

torch.Size([1, 197, 768])

In [8]:
# A single custom encoder block that will receive the layer-0 weights of the pretrained ViT.
token_lib_size = 56      # not used by TransformerBlock
embd_size = 768          # must equal the checkpoint hidden size (768, see the position-embedding shape)
context_length = 32
num_blocks = 1           # not used by TransformerBlock
# ViT-Base (the source checkpoint) uses 12 attention heads of width 768 / 12 = 64.
# The q/k/v/output projections are 768 x 768 whatever the head count, so they would also load
# with num_heads = 1 -- but the block would then compute single-head attention, which is not
# the function those weights were trained to implement.
# NOTE(review): assumes poorman_transformer splits heads into contiguous 64-wide slices of the
# projection output, as the Hugging Face ViT does -- confirm.
num_heads  = 12
encoder = TransformerBlock(
    embd_size,
    context_length,
    num_heads,
    uses_causal_mask    = False,   # bidirectional attention, like the ViT encoder
    attention_dropout   = 0.0,
    residual_dropout    = 0.0,
    feedforward_dropout = 0.0,
)

In [None]:
# Alternative container: a full custom Transformer with one block. The prefixed mapping
# further down assumes its block parameters are named "transformer_block.0.*".
token_lib_size = 56
embd_size = 768          # must equal the checkpoint hidden size (768, see the position-embedding shape)
context_length = 32
num_blocks = 1           # only layer 0 of the checkpoint is mapped
# ViT-Base uses 12 attention heads of width 64. The 768 x 768 projections load with any head
# count, but num_heads = 1 would compute a different attention than the pretrained layer.
# NOTE(review): assumes poorman_transformer splits heads into contiguous 64-wide slices like
# the Hugging Face ViT -- confirm.
num_heads  = 12
encoder = Transformer(
    token_lib_size,
    embd_size,
    context_length,
    num_blocks,
    num_heads,
    uses_causal_mask    = False,
    attention_dropout   = 0.0,
    residual_dropout    = 0.0,
    feedforward_dropout = 0.0,
)

In [7]:
# Snapshot the custom encoder's state dict (it also contains non-Linear entries such as
# multi_head_att_layer.mask) and list its names for pairing with the checkpoint keys.
encoder_state_dict = encoder.state_dict()
for key in encoder_state_dict:
    print(key)

multi_head_att_layer.mask
multi_head_att_layer.proj_q.weight
multi_head_att_layer.proj_q.bias
multi_head_att_layer.proj_k.weight
multi_head_att_layer.proj_k.bias
multi_head_att_layer.proj_v.weight
multi_head_att_layer.proj_v.bias
multi_head_att_layer.proj_linear.weight
multi_head_att_layer.proj_linear.bias
ff_layer.ff_layer.0.weight
ff_layer.ff_layer.0.bias
ff_layer.ff_layer.2.weight
ff_layer.ff_layer.2.bias
layer_norm_pre_multi_head.weight
layer_norm_pre_multi_head.bias
layer_norm_pre_feedforward.weight
layer_norm_pre_feedforward.bias


In [8]:
# Custom TransformerBlock parameter name -> ViT checkpoint parameter name, for encoder layer 0.
# Despite the variable name, the KEYS are the custom names and the VALUES are the checkpoint
# names. Every listed module contributes one "weight" and one "bias" entry.
google_to_custom_dict = {
    f"{custom_module}.{param}": f"vit.encoder.layer.0.{google_module}.{param}"
    for custom_module, google_module in (
        ("multi_head_att_layer.proj_q",      "attention.attention.query"),
        ("multi_head_att_layer.proj_k",      "attention.attention.key"),
        ("multi_head_att_layer.proj_v",      "attention.attention.value"),
        ("multi_head_att_layer.proj_linear", "attention.output.dense"),
        ("ff_layer.ff_layer.0",              "intermediate.dense"),
        ("ff_layer.ff_layer.2",              "output.dense"),
        ("layer_norm_pre_multi_head",        "layernorm_before"),
        ("layer_norm_pre_feedforward",       "layernorm_after"),
    )
    for param in ("weight", "bias")
}

In [None]:
# Same layer-0 mapping, but keyed by the parameter names of the full custom Transformer, whose
# first block lives under the "transformer_block.0." prefix. Keys are custom names, values are
# checkpoint names; each module contributes a "weight" and a "bias" entry.
google_to_custom_dict = {
    f"transformer_block.0.{custom_module}.{param}": f"vit.encoder.layer.0.{google_module}.{param}"
    for custom_module, google_module in (
        ("multi_head_att_layer.proj_q",      "attention.attention.query"),
        ("multi_head_att_layer.proj_k",      "attention.attention.key"),
        ("multi_head_att_layer.proj_v",      "attention.attention.value"),
        ("multi_head_att_layer.proj_linear", "attention.output.dense"),
        ("ff_layer.ff_layer.0",              "intermediate.dense"),
        ("ff_layer.ff_layer.2",              "output.dense"),
        ("layer_norm_pre_multi_head",        "layernorm_before"),
        ("layer_norm_pre_feedforward",       "layernorm_after"),
    )
    for param in ("weight", "bias")
}

In [9]:
# Copy each mapped checkpoint tensor into the custom encoder's state dict.
# Iterate over the MAPPING rather than the encoder keys, so that a naming mismatch (e.g. the
# "transformer_block.0." mapping used with a bare TransformerBlock) raises here instead of
# silently copying nothing. A silent miss would otherwise go unnoticed: encoder_state_dict was
# produced by encoder.state_dict(), so load_state_dict always reports that all keys matched.
for k, k_google in google_to_custom_dict.items():
    if k not in encoder_state_dict:
        raise KeyError(f"{k!r} is not a key of the custom encoder's state dict")
    encoder_state_dict[k] = model_weight_dict[k_google]

In [10]:
# Write the edited tensors back into the encoder. strict=True (the default) only checks that
# the key sets agree; since encoder_state_dict came from encoder.state_dict() that check always
# passes, so it does not by itself prove that checkpoint tensors were copied in.
encoder.load_state_dict(encoder_state_dict, strict=True)

<All keys matched successfully>

In [11]:
# Display the module tree to check that the block's layer sizes line up with the loaded
# layer-0 weights.
# NOTE(review): the LayerNorms here use eps=1e-05, while Hugging Face ViT configs typically use
# layer_norm_eps=1e-12 -- confirm against the checkpoint's config; a different eps slightly
# changes the normalised activations even with identical weights.
encoder

TransformerBlock(
  (multi_head_att_layer): MultiHeadAttention(
    (proj_q): Linear(in_features=768, out_features=768, bias=True)
    (proj_k): Linear(in_features=768, out_features=768, bias=True)
    (proj_v): Linear(in_features=768, out_features=768, bias=True)
    (proj_linear): Linear(in_features=768, out_features=768, bias=True)
    (attention_dropout): Dropout(p=0.0, inplace=False)
    (residual_dropout): Dropout(p=0.0, inplace=False)
  )
  (ff_layer): FeedForward(
    (ff_layer): Sequential(
      (0): Linear(in_features=768, out_features=3072, bias=True)
      (1): GELU()
      (2): Linear(in_features=3072, out_features=768, bias=True)
      (3): Dropout(p=0.0, inplace=False)
    )
  )
  (layer_norm_pre_multi_head): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (layer_norm_pre_feedforward): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

In [14]:
# Synthetic stand-in for a video: Tx frames of shape (Cx, Hx, Wx) filled with standard-normal noise.
# Allocate directly on `device`: the tensor is ~333 MB of float32, and building it on the CPU
# before copying it with .to(device) needs a second full-size buffer and a host->device transfer.
# NOTE: 129 is not a multiple of the 16-px patch, so a stride-16 conv covers only 8 * 16 = 128
# pixels per side and ignores the last row/column.
Tx, Cx, Hx, Wx = 5000, 1, 129, 129
video_clip = torch.randn(Tx, Cx, Hx, Wx, device=device)

In [16]:
# Non-overlapping 16 x 16 patchifier for single-channel frames: each patch -> embd_size features.
patch_embd_layer = nn.Conv2d(
    in_channels = 1,
    out_channels = embd_size,
    kernel_size = 16,
    stride = 16,
).to(device)

In [17]:
# Patchify every frame of the synthetic clip in a single batched conv call; the result is
# (Tx, embd_size, Hx // 16, Wx // 16).
patch_embd = patch_embd_layer(video_clip)
patch_embd.size()

torch.Size([5000, 768, 8, 8])

In [18]:
# Pixels covered per side by the stride-16 conv on a 129-px frame: 8 patches x 16 px = 128,
# i.e. one row/column of the input is not seen by any patch.
8 * 16

128

In [9]:
# Read the first `num_samples` entries of the "intensities" dataset into memory.
# fh["intensities"] raises a KeyError naming the missing dataset, whereas fh.get(...) would
# return None and fail with an opaque "'NoneType' object is not subscriptable".
path_h5 = "3IYF.Fibonacci.h5"
num_samples = 100
with h5py.File(path_h5, "r") as fh:
    data = fh["intensities"][:num_samples]

In [11]:
# Unpack the 4-D layout of the loaded array (raises if it is not 4-D).
[B, C, H, W] = data.shape

In [28]:
# Round each spatial size up to the next multiple of the patch size, so that the stride-16
# patch conv tiles the padded image exactly instead of dropping edge pixels.
H_patch, W_patch = 16, 16
H_padded = -(-H // H_patch) * H_patch    # ceil(H / H_patch) * H_patch in integer arithmetic
W_padded = -(-W // W_patch) * W_patch
H_padded, W_padded

(160, 160)

In [30]:
# Build the padding transform that brings each frame up to H_padded x W_padded (multiples of
# the 16-px patch size). The next cell feeds it an (N, H, W) stack and reshapes the result to
# (B, C, H_padded, W_padded).
# NOTE(review): where the padding is placed (centred vs. one-sided) is decided by
# sparix.trans.Pad -- confirm if patch alignment matters.
pad = Pad(H_padded, W_padded)

In [32]:
# Pad receives a stack of 2-D frames here: fold the channel axis into the batch axis,
# pad every frame, then unfold back to (B, C, H_padded, W_padded).
data_padded = (
    pad(data.reshape(B * C, H, W))
    .reshape(B, C, H_padded, W_padded)
)
data_padded.shape

(100, 1, 160, 160)

In [69]:
# Fresh single-channel 16 x 16 / stride-16 patch-embedding conv; its state dict is grabbed so the
# pretrained ViT projection can be written into it.
patch_embd_layer = nn.Conv2d(
    in_channels = 1,
    out_channels = embd_size,
    kernel_size = 16,
    stride = 16,
).to(device)
patch_embd_layer_state_dict = patch_embd_layer.state_dict()

In [77]:
# Map the stand-alone conv's parameter names onto the ViT patch-embedding projection.
google_to_embd_dict = {
    "weight" : "vit.embeddings.patch_embeddings.projection.weight",
    "bias"   : "vit.embeddings.patch_embeddings.projection.bias",
}

# Iterate over the mapping so a missing target key raises instead of being skipped silently.
for k, k_google in google_to_embd_dict.items():
    if k not in patch_embd_layer_state_dict:
        raise KeyError(f"{k!r} is not a key of the patch-embedding conv's state dict")
    w_google = model_weight_dict[k_google]
    if w_google.ndim == 4:
        # Conv2d weights are laid out (out_channels, in_channels, kH, kW). The checkpoint's
        # projection takes several input channels (presumably RGB), while patch_embd_layer takes
        # one, so collapse dim 1. Summing -- not averaging -- makes the response to a 1-channel
        # image equal the pretrained response to that image replicated on every input channel,
        # keeping the weights on the same scale as the bias, which is copied unchanged.
        w_google = w_google.sum(dim = 1, keepdim = True)
    # Any other rank (the 1-D bias) is copied as-is; load_state_dict rejects shape mismatches.
    patch_embd_layer_state_dict[k] = w_google

In [78]:
# Write the adapted ViT projection into the conv; strict=True (the default) requires the
# state dict to cover exactly the layer's weight and bias.
patch_embd_layer.load_state_dict(patch_embd_layer_state_dict, strict=True)

<All keys matched successfully>

In [79]:
# Patchify the padded patterns with the ViT-initialised conv.
# torch.as_tensor converts straight onto `device` (one copy, instead of torch.tensor's CPU copy
# followed by .to(device)). It also accepts data_padded if it is already a tensor, without the
# copy-construct warning, and casts to the conv weight dtype so a float64 array cannot trigger
# a dtype-mismatch error.
patch_embd = patch_embd_layer(
    torch.as_tensor(data_padded, dtype=patch_embd_layer.weight.dtype, device=device)
)
patch_embd.shape

torch.Size([100, 768, 10, 10])

In [80]:
# Turn the conv output into tokens: flatten each image's Hp x Wp patch grid, move the embedding
# axis last, then concatenate the tokens of all B images into one (1, B * Hp * Wp, E) sequence.
# NOTE(review): that sequence is B * Hp * Wp long (10000 for the shapes above), far more than
# the context_length = 32 the encoder was built with -- confirm the encoder does not size any
# mask or positional buffer by context_length before feeding it this tensor.
B, E, Hp, Wp = patch_embd.shape
patch_embd = (
    patch_embd
    .flatten(start_dim = 2)        # (B, E, Hp * Wp)
    .transpose(1, 2)               # (B, Hp * Wp, E)
    .reshape(1, B * Hp * Wp, E)    # (1, B * Hp * Wp, E), contiguous copy
)