## Unit testing the self-attention and multi-head attention modules

PyTorch's built-in implementations are treated as the ground truth.

In [None]:
import torch
from torch import nn
import torch.nn.functional as F

from vit.vit import SelfAttention

In [None]:
# Problem sizes used by every test in this notebook.
batch_size, seq_len = 2, 100
# Input feature size and attention embedding size (equal here).
input_dim = embed_dim = 128

# All tensors and modules are created on this device with this dtype.
device, dtype = 'cuda:0', torch.float32

In [None]:
# Projection matrices
#
# Small integer values (stored as `dtype`) are used for the weights, the
# biases and the input, so differences between implementations are easy
# to read off.

# Q, K, V weights of shape (input_dim, embed_dim), drawn in that order.
q_proj, k_proj, v_proj = (
    nn.Parameter(torch.randint(1, 5, (input_dim, embed_dim), device=device, dtype=dtype))
    for _ in range(3)
)

# Q, K, V biases: created as (1, embed_dim) and squeezed to (embed_dim,).
qb_proj, kb_proj, vb_proj = (
    nn.Parameter(torch.randint(1, 3, (1, embed_dim), device=device, dtype=dtype))
    .squeeze(0)
    .contiguous()
    for _ in range(3)
)

# Input batch of shape (batch_size, seq_len, input_dim).
input = torch.randint(1, 10, (batch_size, seq_len, input_dim), device=device, dtype=dtype)
input.shape

### Self attention

In [None]:
# Preparing Q, K, V for torch
#
# Each tensor is the input multiplied by its projection matrix plus the
# matching bias, computed outside any attention module.
q1, k1, v1 = (
    input @ weight + bias
    for weight, bias in (
        (q_proj, qb_proj),
        (k_proj, kb_proj),
        (v_proj, vb_proj),
    )
)

print(q1.shape, k1.shape, v1.shape)

In [None]:
# Preparing Q, K, V for our custom implementation

sa = SelfAttention(d_in=input_dim, d_out=embed_dim, num_heads=1)
sa.to(device=device, dtype=dtype)

# The custom module exposes one fused `qkv` layer, so the three separate
# projections are concatenated along the last axis in Q, K, V order.
fused_weight = torch.cat((q_proj, k_proj, v_proj), dim=-1).contiguous()
print(fused_weight.shape)
with torch.no_grad():
    # copy_ writes the values in place; no autograd history is recorded.
    sa.qkv.weight.copy_(fused_weight)

fused_bias = torch.cat((qb_proj, kb_proj, vb_proj), dim=-1).contiguous()
print(fused_bias.shape)
with torch.no_grad():
    sa.qkv.bias.copy_(fused_bias)

In [None]:
# o1: torch's scaled dot-product attention on the pre-projected Q, K, V.
o1 = F.scaled_dot_product_attention(query=q1, key=k1, value=v1)
# o2: the custom module, which projects the raw input itself.
o2 = sa(input)

(o1.shape, o2.shape)

In [None]:
import math

# Hand-written scaled dot-product attention, used as a second reference:
#   softmax(Q @ K^T / sqrt(d)) @ V
# The scale is taken from Q's actual feature size rather than a hard-coded
# 128, so this stays correct if `embed_dim` changes.
i = torch.matmul(q1, k1.transpose(-2, -1)) / math.sqrt(q1.shape[-1])
o = torch.matmul(torch.softmax(i, dim=-1), v1)

# Largest absolute deviation of (torch SDPA, custom module) from the
# manual computation.
torch.max(torch.abs(o1 - o)), torch.max(torch.abs(o2 - o))

In [None]:
# Measuring diff b/w the two
diff = torch.abs(o1 - o2)
# argmax without `dim` returns an index into the flattened tensor.
max_diff_index = torch.argmax(diff)

# Convert the flat index back to (batch, seq, feature) coordinates. The
# strides come from diff's own shape instead of hard-coded 100 / 128, so
# this stays valid if seq_len or embed_dim change.
seq_size, feat_size = diff.shape[-2:]
max_diff_index_multi = (
    max_diff_index // (seq_size * feat_size),
    (max_diff_index % (seq_size * feat_size)) // feat_size,
    max_diff_index % feat_size,
)

print(f'Diff b/w both the implementations: {torch.max(diff)}')

### Multi head attention

In [None]:
from vit.vit import MultiHeadAttention

In [None]:
# Head count shared by both implementations. Sizes are derived from the
# notebook's config (`input_dim`, `embed_dim`) rather than repeated as
# literals, so the two modules cannot drift apart from the tensors above.
num_heads = 4

# Reference implementation; batch_first=True matches the
# (batch, seq, feature) layout of the tensors built earlier.
mha_torch = nn.MultiheadAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    bias=False,
    device=device,
    dtype=dtype,
    batch_first=True
)

# Custom implementation under test.
# NOTE(review): d_out is passed as the per-head width
# (embed_dim // num_heads == 32 with the current config) - confirm against
# MultiHeadAttention's definition.
mha_triton = MultiHeadAttention(
    num_heads=num_heads, d_in=input_dim, d_out=embed_dim // num_heads
)
mha_triton.to(device, dtype)

In [None]:
# Output projection weight with small integer values, shape (input_dim, embed_dim).
o_proj = nn.Parameter(
    torch.randint(low=3, high=5, size=(input_dim, embed_dim), device=device, dtype=dtype)
)
# All-zero output bias: created as (1, embed_dim), squeezed to (embed_dim,).
ob_proj = nn.Parameter(
    torch.zeros(1, embed_dim, device=device, dtype=dtype)
).squeeze(0).contiguous()

In [None]:
# Load the shared Q/K/V projections into the custom module's fused QKV
# layer (concatenated along the last axis in Q, K, V order).
temp = torch.cat([q_proj, k_proj, v_proj], dim=-1).contiguous()
print(temp.shape)
with torch.no_grad():
    mha_triton.attention.qkv.weight.copy_(temp)

temp = torch.cat([qb_proj, kb_proj, vb_proj], dim=-1).contiguous()
print(temp.shape)
with torch.no_grad():
    mha_triton.attention.qkv.bias.copy_(temp)

with torch.no_grad():
    # nn.MultiheadAttention always applies its own input projection
    # (`in_proj_weight`, shape (3 * embed_dim, embed_dim)) to query/key/value.
    # The reference path feeds it q1/k1/v1, which are ALREADY projected, so
    # the randomly initialised in-projection must be replaced with stacked
    # identity matrices; otherwise the "ground truth" is projected twice and
    # can never match the custom module.
    identity = torch.eye(mha_torch.embed_dim, device=device, dtype=dtype)
    mha_torch.in_proj_weight.copy_(identity.repeat(3, 1))

    # Same output projection for both modules. The custom layer receives the
    # transpose, which suggests it stores weights as (in, out) while torch
    # stores (out, in) - confirm against the vit implementation.
    mha_torch.out_proj.weight.copy_(o_proj)
    temp = o_proj.t().contiguous()
    mha_triton.output.weight.copy_(temp)

    # No output bias is loaded: mha_torch is built with bias=False.
    # NOTE(review): if mha_triton.output has a bias, it is left at its
    # initial value here - confirm it is zero or disabled.

In [None]:
# torch reference: called on the pre-projected tensors; the second return
# value (attention weights) is discarded.
o1, _ = mha_torch(query=q1, key=k1, value=v1)
# Custom implementation: called on the raw input.
o2 = mha_triton(input)

In [None]:
# Measuring diff b/w the two
diff = torch.abs(o1 - o2)
max_diff_index = torch.argmax(diff)
max_diff_index_multi = (max_diff_index // (100 * 128), (max_diff_index % (100 * 128)) // 128, max_diff_index % 128)

print(torch.max(diff))

In [None]:
# Display both outputs (torch reference, custom module) for visual inspection.
o1, o2