In [10]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [11]:
L = 4 #lenght of intput : My name is Krishna
d_k, d_v = 8,8 #size of k,v vectors 

In [12]:
#randomly initializing the vectors
q = np.random.randn(L,d_k)
k = np.random.randn(L,d_k)
v = np.random.randn(L,d_v)

In [13]:
print("q:", q ,"k:", k, "v:", v)

q: [[ 0.71637793  2.122557    0.65611945 -2.79270637 -0.13957608 -0.42654572
   1.33000972  0.67839715]
 [ 0.16970322  1.18782568 -0.34402368 -1.42346335 -0.17658202 -1.05895873
   0.2645851   1.05979683]
 [ 1.4003189  -0.79588714  1.16545465  0.09623789  0.11615151 -0.09831192
   1.820415    1.20911123]
 [-2.03504234 -0.14572441 -0.0876285   2.4955127   1.06987593  0.72536285
   0.75068959 -0.93609233]] k: [[-0.43292757 -0.48195108 -0.22097563  0.66661556  1.3611114   0.74620559
  -0.68710278 -0.10632383]
 [ 0.99990842  1.24237737  0.88310711  0.95270578  0.59066807 -0.7662061
   1.06595574  0.39012535]
 [-0.56981871 -0.23800165 -0.01020816 -0.37970047 -0.19352451  1.57137181
   0.63146991  1.69130594]
 [-0.66874421 -1.35599212  0.35196259 -0.28288742  0.99172998  0.50563394
   1.08359416 -0.25984852]] v: [[-0.11139139 -0.0846177  -0.86453517 -1.32807291  0.5046783  -0.16446757
  -1.143422   -0.08288812]
 [ 1.24250396  1.93482167 -0.30446374  1.00781966 -0.66378703  0.73967111
  -0.92

Self Attention = $\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V$


In [14]:
np.matmul(q,k.T)

array([[-4.83400888,  3.19889556,  1.48430424, -1.42548065],
       [-2.84385259,  1.38803279,  0.49427024, -2.14182483],
       [-1.71067669,  4.08842709,  2.36060739,  2.24961988],
       [ 4.21538751,  0.59537599,  0.07122192,  3.30621399]])

In [15]:
scaled_var = np.matmul(q,k.T)/ math.sqrt(d_k)  ##scaled to minimize the variance

$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}}$


In [16]:
def softmax(x):
    return(np.exp(x).T / np.sum(np.exp(x))).T


### Masking required in decoder but not in encoder

In [17]:
mask = np.tril(np.ones((L,L)))
mask[mask == 0] = -math.inf
mask[mask == 1] = 0
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [18]:
attention = softmax(scaled_var)

In [19]:
attention

array([[0.00636106, 0.10888093, 0.05938581, 0.02122736],
       [0.01285613, 0.05739851, 0.04184723, 0.01647798],
       [0.01919134, 0.14912007, 0.08095353, 0.07783843],
       [0.15596617, 0.04337017, 0.03603373, 0.11309156]])

In [20]:
def attention_fn(q,k,v,mask = None):
    d_k = q.shape[-1]
    scaled = np.matmul(q,k.T)/ math.sqrt(d_k)
    if mask is not None:
        scaled = scaled+mask
    attention = softmax(scaled)
    output = np.matmul(attention,v)
    return output,attention

In [21]:
new_values, attention = attention_fn(q,k,v)

### Multi Head Attention

In [22]:
sequence_len = 4  ##My name is Krishna
batch_size = 1
input_dim = 512
out_dim = 512
x = torch.randn((batch_size,sequence_len, input_dim)) ##input to the multi head attention block
x.shape

torch.Size([1, 4, 512])

In [23]:
qkv_layer = nn.Linear(input_dim, 3*out_dim)  ##q,k,v vectors all concatenated and all have the 8 attention heads

In [24]:
qkv = qkv_layer(x)
qkv.shape

torch.Size([1, 4, 1536])

In [25]:
num_heads = 8
head_dim = out_dim//num_heads
qkv = qkv.reshape(batch_size, sequence_len, num_heads, 3*head_dim)
qkv = qkv.permute(0,2,1,3)
qkv.shape


torch.Size([1, 8, 4, 192])

In [26]:
q,k,v = qkv.chunk(3,dim=-1)
q.shape,k.shape,v.shape

(torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]))

In [27]:
d_k = q.size()[-1]
scaled = torch.matmul(q,k.transpose(-2,-1)/math.sqrt(d_k))
scaled.shape

torch.Size([1, 8, 4, 4])

In [28]:
k.shape,k.T.shape, k.transpose(-2,-1).shape

  k.shape,k.T.shape, k.transpose(-2,-1).shape


(torch.Size([1, 8, 4, 64]),
 torch.Size([64, 4, 8, 1]),
 torch.Size([1, 8, 64, 4]))

In [29]:
mask = torch.full(scaled.size(), float('-inf'))
mask = torch.tril(mask, diagonal=1)
mask[0][1]

tensor([[-inf, -inf, 0., 0.],
        [-inf, -inf, -inf, 0.],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf]])

In [30]:
(scaled+mask)[0][1]

tensor([[   -inf,    -inf, -0.3362,  0.1631],
        [   -inf,    -inf,    -inf,  0.7193],
        [   -inf,    -inf,    -inf,    -inf],
        [   -inf,    -inf,    -inf,    -inf]], grad_fn=<SelectBackward0>)

In [31]:
scaled += mask
attention = F.softmax(scaled, dim=1)

In [32]:
values = torch.matmul(attention,v)
values.shape

torch.Size([1, 8, 4, 64])

In [33]:
def multi_attention_fn(q,k,v, mask = None):
    d_k = q.size()[-1]
    scaled = torch.matmul(q,k.transpose(-2,-1)/math.sqrt(d_k))
    if mask is not None:
        scaled += mask
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention,v)
    return values,attention

In [34]:
values, attention = multi_attention_fn(q,k,v,mask=mask)
attention.shape

torch.Size([1, 8, 4, 4])

In [35]:
attention[0][0]

tensor([[0.0000, 0.0000, 0.5634, 0.4366],
        [0.0000, 0.0000, 0.0000, 1.0000],
        [   nan,    nan,    nan,    nan],
        [   nan,    nan,    nan,    nan]], grad_fn=<SelectBackward0>)

In [36]:
def attention_dot_product(q,k,v, mask = None):
    d_k = q.size()[-1]
    scaled = torch.matmul(q,k.transpose(-2,-1)/math.sqrt(d_k))
    if mask is not None:
        scaled += mask
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention,v)
    return values,attention

class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim, n_heads, d_model):
        super().__init__()
        self.input_dim = input_dim
        self.n_heads = n_heads
        self.d_model = d_model
        self.head_dim = d_model//n_heads
        self.qkv_layer = nn.Linear(input_dim,3*d_model)
        self.linear_layer = nn.Linear(d_model,d_model)

    def forward(self,x,mask=None):
        batch_size, seq_len, input_dim = x.size()
        qkv = self.qkv_layer(x)
        qkv = qkv.reshape(batch_size, seq_len, self.n_heads, 3*self.head_dim)
        qkv = qkv.permute(0,2,1,3)
        q,k,v = qkv.chunk(3, dim=-1)
        values,attention = attention_dot_product(q,k,v,mask)
        values = values.reshape(batch_size,seq_len,self.n_heads*self.head_dim)
        output = self.linear_layer(values)
        return output


In [40]:
input_dim = 1024
d_model = 512
n_heads = 8

batch_size = 30
seq_len = 5
x = torch.randn( (batch_size, seq_len, input_dim) )

model = MultiHeadAttention(input_dim, n_heads, d_model)
out = model.forward(x)