In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import math
import numpy  as np

In [1]:
list("महान")

['म', 'ह', 'ा', 'न']

In [3]:
# parameters

batch_size  = 30
num_heads = 8
emb_dim = 512
token_inp = 50

In [10]:
data = torch.rand(batch_size,token_inp,emb_dim)
data.size()
B,T,C=data.size() #B:Batch len, T:number Token, C:dimension of each token

In [12]:
B,T,C

(30, 50, 512)

In [5]:
input_linear = nn.Linear(in_features=emb_dim,out_features=3*emb_dim)
out_linear = nn.Linear(in_features=emb_dim,out_features=emb_dim)

In [180]:
new = input_linear(data)
new.size()

torch.Size([30, 50, 1536])

In [181]:
qkv = input_linear(data).split(emb_dim,dim=-1)
q,k,v = qkv

In [182]:
q = q.view(B,T,num_heads,C//num_heads)

In [183]:
q.size()

torch.Size([30, 50, 8, 64])

In [184]:
q = q.transpose(1,2) # or q = q.permute(0,2,1,3)
q.shape

torch.Size([30, 8, 50, 64])

In [185]:
k = k.view(B,T,num_heads,C//num_heads).transpose(1,2)
v = v.view(B,T,num_heads,C//num_heads).permute(0,2,1,3)

In [186]:
q.shape,k.shape,v.shape

(torch.Size([30, 8, 50, 64]),
 torch.Size([30, 8, 50, 64]),
 torch.Size([30, 8, 50, 64]))

## Self Attention for multiple heads

For a single head:
$$
\text{self attention} = softmax\bigg(\frac{Q.K^T}{\sqrt{d_k}}+M\bigg)
$$

$$
\text{new V} = \text{self attention}.V
$$ 

In [71]:
d_k = q.size()[-1]
scalar_product = F.softmax((q@k.transpose(-2,-1))/ math.sqrt(d_k))

  scalar_product = F.softmax((q@k.transpose(-2,-1))/ math.sqrt(d_k))


In [72]:
q.shape, k.transpose(-2,-1).shape,scalar_product.shape,v.shape

(torch.Size([30, 8, 50, 64]),
 torch.Size([30, 8, 64, 50]),
 torch.Size([30, 8, 50, 50]),
 torch.Size([30, 8, 50, 64]))

In [73]:
self_attention = scalar_product@v

In [75]:
self_attention.shape

torch.Size([30, 8, 50, 64])

In [92]:
scaled = ((q@k.transpose(-2,-1))/ math.sqrt(d_k))
mask = torch.triu(torch.full(scaled.size(),float('-inf')),diagonal=1)

In [94]:
scaled[0][0][1],mask[0][0]

(tensor([ 0.1086,  0.0241,  0.0190,  0.0170,  0.0250,  0.1070, -0.0152,  0.0375,
          0.0707,  0.0511,  0.1340, -0.0173,  0.0381,  0.0345,  0.0416,  0.0333,
          0.0140,  0.0426,  0.1132, -0.0325,  0.0618,  0.0323,  0.0649,  0.0240,
          0.0122,  0.0853,  0.0192,  0.0625,  0.0602, -0.0682,  0.0943,  0.0454,
          0.0668,  0.0624,  0.0006,  0.0155,  0.1139,  0.1363,  0.0148, -0.0051,
          0.1329, -0.0170,  0.0492, -0.0213, -0.0133, -0.0218, -0.0062,  0.1501,
         -0.0045,  0.0240], grad_fn=<SelectBackward0>),
 tensor([[0., -inf, -inf,  ..., -inf, -inf, -inf],
         [0., 0., -inf,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         ...,
         [0., 0., 0.,  ..., 0., -inf, -inf],
         [0., 0., 0.,  ..., 0., 0., -inf],
         [0., 0., 0.,  ..., 0., 0., 0.]]))

In [106]:
self_attention_masked = F.softmax((scaled+mask),dim=-1)@v


In [107]:
self_attention_masked.shape,self_attention_masked[0][0]


(torch.Size([30, 8, 50, 64]),
 tensor([[ 0.5031,  0.3435, -0.1777,  ..., -0.5685,  0.1213,  0.3096],
         [ 0.3170,  0.2586, -0.3077,  ..., -0.4127,  0.2261,  0.1800],
         [ 0.3054,  0.0597, -0.4125,  ..., -0.4866,  0.2588,  0.1979],
         ...,
         [ 0.2566,  0.1146, -0.4509,  ..., -0.4326,  0.0508,  0.1586],
         [ 0.2589,  0.1135, -0.4532,  ..., -0.4314,  0.0588,  0.1593],
         [ 0.2608,  0.1130, -0.4591,  ..., -0.4315,  0.0578,  0.1530]],
        grad_fn=<SelectBackward0>))

tensor([[0., -inf, -inf,  ..., -inf, -inf, -inf],
        [0., 0., -inf,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        ...,
        [0., 0., 0.,  ..., 0., -inf, -inf],
        [0., 0., 0.,  ..., 0., 0., -inf],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [191]:
def CrossProduct(q,k,v,mask=None):
    print(f"{q.shape}{k.transpose(-2,-1).shape}{v.shape}")
    d_k = q.size()[-1]
    scalar = (q@k.transpose(-2,-1))/math.sqrt(d_k)
    if mask is not None:
        mask = torch.full(scalar.size(),float('-inf'))
        mask = torch.triu(mask)
        scalar += mask
    attention = F.softmax(scalar,dim=-1)@v
    return attention
        


In [205]:
num = CrossProduct(q,k,v)
num.shape

torch.Size([30, 8, 50, 64])torch.Size([30, 8, 64, 50])torch.Size([30, 8, 50, 64])


torch.Size([30, 8, 50, 64])

In [206]:
num = num.transpose(1,2)
num.shape

torch.Size([30, 50, 8, 64])

In [207]:
B,T,C

(30, 50, 512)

In [217]:

# num = num.reshape(B,T,num_heads*C)
num = num.contiguous().view(B,T,C)
num.shape


torch.Size([30, 50, 512])

In [214]:
class MultiHeadAttention(nn.Module):
    def __init__(self,emb_dim,num_heads):
        super().__init__()
        # self.data = data
        self.emb_dim = emb_dim
        self.num_heads  = num_heads
        assert self.emb_dim%self.num_heads==0
        self.input_linear = nn.Linear(in_features=self.emb_dim,out_features=3*self.emb_dim)
        self.output_linear = nn.Linear(in_features=self.emb_dim,out_features=self.emb_dim)

    def forward(self,data):
        self.data = data
        B,T,C = self.data.size() 
        q,k,v = self.input_linear(self.data).split(self.emb_dim,dim=-1) # in--> 30,50,512 out-->30,50,1536
        q = q.view(B,T,self.num_heads,C//self.num_heads).transpose(1,2)
        k = k.view(B,T,self.num_heads,C//self.num_heads).transpose(1,2)
        v = v.view(B,T,self.num_heads,C//self.num_heads).transpose(1,2)
        attention = CrossProduct(q,k,v)
        attention = attention.transpose(1,2).contiguous().view(B,T,C)
        out = out_linear(attention)
        return out
    



In [215]:
number = MultiHeadAttention(emb_dim=emb_dim,num_heads=num_heads)

In [225]:
numb=number.forward(data)
numb.shape,numb[0][0]

torch.Size([30, 8, 50, 64])torch.Size([30, 8, 64, 50])torch.Size([30, 8, 50, 64])


(torch.Size([30, 50, 512]),
 tensor([-2.5665e-01,  4.0416e-02,  8.0793e-02,  5.0079e-02, -1.8421e-03,
          1.9777e-01, -2.4611e-01, -2.7946e-01, -3.6623e-01,  3.4265e-02,
          1.8036e-02,  1.0174e-01, -4.3216e-02,  3.7045e-02, -1.9604e-01,
          8.1235e-02, -1.2366e-01,  8.8899e-02,  2.6143e-01,  1.0775e-01,
          4.2346e-02, -1.8797e-02, -3.0577e-02,  2.9665e-01, -5.0877e-02,
         -4.3885e-02,  9.9424e-02,  1.0587e-01, -1.5881e-01, -1.1248e-01,
          4.3572e-02,  1.6446e-01, -2.4672e-01,  9.6683e-02, -1.5970e-01,
         -2.2575e-01,  5.0873e-02, -1.3085e-01, -8.2951e-03,  2.4432e-01,
         -1.9944e-01,  3.3726e-02,  1.0078e-01,  7.7097e-02,  1.0899e-01,
         -7.3702e-03, -3.9956e-02, -9.7962e-02, -6.9033e-02,  7.4464e-02,
         -1.7607e-01, -3.1354e-02, -1.5993e-01,  2.7830e-01,  2.6122e-01,
         -8.6033e-02,  6.3355e-02, -8.3642e-02, -1.0748e-02, -1.0637e-01,
         -1.6083e-01,  3.9753e-01,  3.5859e-01,  1.7571e-01, -2.7137e-01,
         -

In [137]:
number

MultiHeadAttention(
  (input_linear): Linear(in_features=512, out_features=1536, bias=True)
  (output_linear): Linear(in_features=512, out_features=512, bias=True)
)