In [1]:
import timm
import torch
import torch.nn as nn
import numpy as np

In [2]:
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
model = timm.create_model('vit_base_patch16_224', pretrained=True)

config = resolve_data_config({}, model=model)
transform = create_transform(**config)

In [3]:
# dc = model.named_children()
# dc = list(dc)
# x = torch.arange(4*3*224*224, dtype=torch.float32).reshape(4, 3, 224, 224)
# x.shape

## block WEIGHT binaryfile 생성


In [4]:
for layer in range(12):
    block = dict(dict(dict(model.named_children())['blocks'].named_children())[str(layer)].named_children())
    ls = [block['norm1'],list(block['attn'].children())[0], list(block['attn'].children())[4],block['norm2'],list(block['mlp'].children())[0], list(block['mlp'].children())[4]]

    weight_ls = list()
    for i in range(len(ls)):
        weight_ls.append(ls[i].weight.clone().detach().numpy().T)
        weight_ls.append(ls[i].bias.clone().detach().numpy())

    blk_weight = np.array(weight_ls[0].flatten())

    for i in range(1, len(weight_ls)):
        blk_weight = np.concatenate((blk_weight, weight_ls[i].flatten()))
    blk_weight.tofile('./pre_weights/'+str(layer)+'_newblock'+'.bin')

### QKV WEIGHT BINFILE

In [5]:
layer = 0
dd = dict(list(list(model.named_children())[4][1].children())[layer].named_children())
blk = dict(dd['attn'].named_children())

## dummy input check

In [6]:
torch.manual_seed(10)
dummy_input = torch.randn((4, 196, 768))
qkv_w = dict(dd['attn'].named_children())['qkv'].weight.clone().detach().T

In [7]:
# dummy_input.numpy().tofile('./pre_weights/dummy_input_4_196_768.bin')

In [8]:
dummy_input.shape

torch.Size([4, 196, 768])

In [9]:
blocks = list()
for layer in range(12):
    block = dict(dict(dict(model.named_children())['blocks'].named_children())[str(layer)].named_children())
    ls = [list(block['attn'].children())[0], list(block['attn'].children())[4],list(block['mlp'].children())[0], list(block['mlp'].children())[4]]

    weight_ls = list()
    for i in range(len(ls)):
        weight_ls.append(ls[i].weight.clone().detach().T)
        weight_ls.append(ls[i].bias.clone().detach())
    blocks.append(weight_ls)

    

for i in blocks[0]:
    print(i.shape)
    

torch.Size([768, 2304])
torch.Size([2304])
torch.Size([768, 768])
torch.Size([768])
torch.Size([768, 3072])
torch.Size([3072])
torch.Size([3072, 768])
torch.Size([768])


# FLASH ATTENTION VALUE CHECK

In [10]:
block = dict(dict(dict(model.named_children())['blocks'].named_children())[str(layer)].named_children())
# Fetch correct qkv weights and biases from the block
qkv_layer = dict(block['attn'].named_children())['qkv']
qkv_weights = qkv_layer.weight  # Shape: (out_features, in_features)
qkv_bias = qkv_layer.bias

# Transpose qkv_weights to shape (in_features, out_features) for matrix multiplication
qkv_weights = qkv_weights.T  # Now shape: (in_features, out_features)

# Apply the QKV linear projection with the correct transposed weights and biases
dQKV = torch.add(torch.matmul(dummy_input, qkv_weights), qkv_bias)

# Reshape and permute to get the correct multi-head attention format
dQKV = dQKV.view(4, 196, 12, 3, 64)  # (batch, seq_len, num_heads, qkv, head_dim)
dQKV = torch.einsum('b p h k d -> b h k p d', dQKV)  # Permute dimensions

# Calculate attention manually for each head
Os = list()
for i in range(12):
    Q = dQKV[0][i][0]  # Query for head i
    K = dQKV[0][i][1]  # Key for head i
    V = dQKV[0][i][2]  # Value for head i
    
    # Apply scaled dot-product attention with numerical stability
    attn_weights = torch.matmul(Q, K.T) / (Q.size(-1) ** 0.5)  # Scale by sqrt(d_k)
    attn_weights = attn_weights - torch.max(attn_weights, dim=-1, keepdim=True)[0]  # Numerical stability
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
    
    O = torch.matmul(attn_weights, V)  # Multiply attention weights by V
    Os.append(O)

# Concatenate outputs from all heads and apply the projection layer
concat_Os = torch.cat(Os, dim=1)  # Concatenate along the feature dimension (head outputs)
proj_output = dict(block['attn'].named_children())['proj'](concat_Os)[:8, :8]  # Apply projection layer
proj_output

tensor([[-0.2055, -0.1736,  0.4550,  0.2223, -0.0058, -0.1951,  0.1081, -0.0950],
        [-0.2403, -0.1827,  0.3924,  0.2354, -0.0299, -0.1278,  0.1057, -0.0935],
        [-0.2598, -0.1673,  0.4965,  0.2337, -0.0930, -0.1645,  0.0652, -0.0784],
        [-0.2632, -0.1369,  0.4431,  0.2356,  0.0452, -0.1379,  0.0822, -0.0350],
        [-0.2474, -0.1771,  0.3949,  0.1860,  0.0159, -0.1896,  0.0057, -0.0729],
        [-0.2212, -0.1546,  0.4233,  0.1712, -0.0062, -0.1302,  0.0796, -0.0667],
        [-0.2667, -0.1787,  0.4048,  0.2312,  0.0124, -0.1626,  0.1034, -0.1029],
        [-0.2447, -0.1572,  0.4262,  0.2408,  0.0147, -0.1515,  0.0910, -0.0736]],
       grad_fn=<SliceBackward0>)

In [18]:
blk1 = blocks[0]
torch.manual_seed(10)
blk1[0] = blk1[0].to(torch.float32)
blk1[1] = blk1[1].to(torch.float32)

dummy_input = torch.randn((4, 196, 768), dtype=torch.float32)
print(torch.matmul(dummy_input,blk1[0])[0][0][1])
dQKV = torch.add(torch.matmul(dummy_input,blk1[0]), blk1[1])
dQKV = dQKV.view(4, 196, 3, 12, 64)
dQKV = torch.einsum('b p k h d -> k b h p d',dQKV)
dQKV.shape

tensor(-1.1279)


torch.Size([3, 4, 12, 196, 64])

In [23]:
Qs, Ks, Vs = dQKV.unbind(0)

4 : batch
12 : Multi head
3 : QKV

In [24]:
Q = Qs[0][0]# batch head
K = Ks[0][0]
V = Vs[0][0]

In [25]:
# torch.matmul(torch.nn.functional.softmax(torch.matmul(Q, K.T), dim=1),V)[188:196,8:16]#일반 ATTN
torch.matmul(torch.nn.functional.softmax(torch.matmul(Q, K.T), dim=1),V).shape#일반 ATTN

torch.Size([196, 64])

In [29]:
Os = list()
for i in range(12):
    Q = Qs[0][i]# batch head
    K = Ks[0][i]
    V = Vs[0][i]
    O = torch.matmul(torch.nn.functional.softmax(torch.matmul(Q, K.T)/8, dim=1),V)
    Os.append(O)
torch.concat(Os, dim=1)[:8,:8]
dict(block['attn'].named_children())['proj'](torch.concat(Os, dim=1))[:8,:8]


tensor([[-0.4365, -0.0053, -2.0785, -0.1310,  0.3362, -0.7637, -0.1959, -0.0355],
        [ 0.1909, -0.0482, -0.0667, -0.1483,  0.0187, -0.6523, -0.2404,  0.1926],
        [-0.4567,  0.0283,  0.0388, -0.0864,  0.1838, -0.6500, -0.1512,  0.2088],
        [ 0.0561,  0.0061,  0.0934, -0.1666,  0.1272, -0.5512, -0.3549,  0.0698],
        [-0.0255,  0.0720, -0.7790, -0.0651,  0.1271, -0.1607, -0.1948,  0.0879],
        [-0.3870,  0.0909,  0.0525, -0.0342,  0.0777, -0.5152, -0.1509,  0.0721],
        [-0.1256,  0.1418,  0.0562, -0.1245, -0.0947, -0.1670, -0.3574,  0.1665],
        [ 0.1702,  0.2194, -0.7297, -0.1137, -0.2793,  0.2847, -0.5775,  0.2026]],
       grad_fn=<SliceBackward0>)

In [30]:
layer = 0
block = dict(dict(dict(model.named_children())['blocks'].named_children())[str(layer)].named_children())
# list(block['attn'].children())[0]
ls = [list(block['attn'].children())[0], list(block['attn'].children())[4],list(block['mlp'].children())[0], list(block['mlp'].children())[4]]
ls[0]
block['attn']
print((block['attn'](dummy_input))[1,:8,:8])
# torch.round(dict(dict(model.named_children())['blocks'].named_children())[str(0)](dummy_input)[0,:8,:8]*100)/100
# dict(dict(model.named_children())['blocks'].named_children())[str(0)]

tensor([[-0.4365, -0.0053, -2.0785, -0.1310,  0.3362, -0.7637, -0.1959, -0.0355],
        [ 0.1909, -0.0482, -0.0667, -0.1483,  0.0187, -0.6523, -0.2404,  0.1926],
        [-0.4567,  0.0283,  0.0388, -0.0864,  0.1838, -0.6500, -0.1512,  0.2088],
        [ 0.0561,  0.0061,  0.0934, -0.1666,  0.1272, -0.5512, -0.3549,  0.0698],
        [-0.0255,  0.0720, -0.7790, -0.0651,  0.1271, -0.1607, -0.1948,  0.0879],
        [-0.3870,  0.0909,  0.0525, -0.0342,  0.0777, -0.5152, -0.1509,  0.0721],
        [-0.1256,  0.1418,  0.0562, -0.1245, -0.0947, -0.1670, -0.3574,  0.1665],
        [ 0.1702,  0.2194, -0.7297, -0.1137, -0.2793,  0.2847, -0.5775,  0.2026]],
       grad_fn=<SliceBackward0>)


In [28]:
block['attn']

Attention(
  (qkv): Linear(in_features=768, out_features=2304, bias=True)
  (q_norm): Identity()
  (k_norm): Identity()
  (attn_drop): Dropout(p=0.0, inplace=False)
  (proj): Linear(in_features=768, out_features=768, bias=True)
  (proj_drop): Dropout(p=0.0, inplace=False)
)