In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import pickle
import pandas as pd 
torch.manual_seed(2)  # seed PyTorch's global RNG so the randomly initialised layers below give reproducible outputs

<torch._C.Generator at 0x10c8bac70>

Let's define the main parameters of the self-attention mechanism.

In [2]:
# embedding dimension; how many features we use for each token
embed_dim = 4

# number of attention heads
num_heads = 2

# how many tokens each sequence has (including padding positions)
seq_length = 4

# how many sequences (data points) are in the batch -- not a number of batches
batch_size = 2

# how many features per head; the embedding must split evenly across the heads,
# otherwise the integer division below would silently drop features
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
head_dim = embed_dim // num_heads

In [3]:
# define the embeddings: a lookup table holding one learnable embed_dim-sized
# vector per distinct token id (ids 0..4). padding_idx=0 marks token 0 as the
# padding token: its vector is initialised to zeros and is not updated by gradients.
embed = torch.nn.Embedding(num_embeddings=5, embedding_dim=embed_dim, padding_idx=0)

In [4]:
# Input token ids for a batch of two sequences; id 0 is the padding token.
#   sequence 0:  A B _ _  -> two real tokens (ids 1, 2) followed by two pads
#   sequence 1:  A B C D  -> four real tokens (ids 1..4)
s = torch.tensor([
    [1, 2, 0, 0],
    [1, 2, 3, 4],
])

In [5]:
s  # token ids, shape [batch_size, seq_length]

tensor([[1, 2, 0, 0],
        [1, 2, 3, 4]])

In [6]:
# embed the input features: every token id is replaced by its embed_dim-sized
# vector, giving a tensor of shape [batch_size, seq_length, embed_dim]
e = embed(s)

In [7]:
e  # padding positions (token id 0) come out as all-zero rows

tensor([[[ 0.0299, -0.0498,  1.0651,  0.8860],
         [-0.8110,  0.6737, -1.1233, -0.0919],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0299, -0.0498,  1.0651,  0.8860],
         [-0.8110,  0.6737, -1.1233, -0.0919],
         [ 0.1405,  1.1191,  0.3152,  1.7528],
         [-0.7396, -1.2425, -0.1752,  0.6990]]], grad_fn=<EmbeddingBackward0>)

Now each token has been represented by a 4-element vector (`embed_dim = 4`); the padding token 0 is mapped to the all-zero vector.

Below, we apply another transformation to obtain the Q, K and V vectors: a single linear layer maps each embedding to a vector three times as long (3 × `embed_dim` = 12).
This output will later be split into three parts, which become the Q, K and V tensors.

In [26]:

# a single projection produces q, k and v together: embed_dim -> 3*embed_dim.
# With bias=False a zero input maps to a zero output, so the all-zero padding
# embeddings project to all-zero q/k/v.
qkv_proj = nn.Linear(in_features=embed_dim, out_features=3 * embed_dim, bias=False)

In [27]:
qkv = qkv_proj(e)  # [batch_size, seq_length, embed_dim] -> [batch_size, seq_length, 3*embed_dim]

In [29]:
qkv.shape  # bare last expression: let Jupyter display the result instead of print()

torch.Size([2, 4, 12])


Now let's reshape qkv to accommodate the heads.

In [30]:
# split the last dimension into (num_heads, 3*head_dim). The leading
# (batch, seq) sizes are read from the tensor itself instead of the
# batch_size/seq_length config, so the cell keeps working if `s` changes shape,
# and re-running it after the permute below fails loudly instead of scrambling data.
qkv = qkv.reshape(*qkv.shape[:2], num_heads, 3 * head_dim)

In [31]:
qkv.size()

torch.Size([2, 4, 2, 6])

We have two data points, each containing 4 tokens. For each of the two heads, every token is now described by 6 numbers (3 × `head_dim`): its query, key and value vectors, each of size 2, concatenated together.

In [32]:
# qkv for the first data point: [seq_length, num_heads, 3*head_dim]
qkv[0]

tensor([[[ 0.1592, -0.3586,  0.5629, -0.5026, -0.3295,  0.6718],
         [ 0.1544, -0.4201,  0.1823, -0.1312,  0.1394, -0.2156]],

        [[-0.2941,  0.1607, -0.5823,  0.5015,  0.6178, -0.5147],
         [-0.0365,  0.5209,  0.3704,  0.3071, -0.3399,  0.4653]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],
       grad_fn=<SelectBackward0>)

In [33]:
qkv[0].size()

torch.Size([4, 2, 6])

In [34]:
# swap the sequence and head axes:
# [batch_size, seq_length, num_heads, 3*head_dim] -> [batch_size, num_heads, seq_length, 3*head_dim]
qkv = qkv.transpose(1, 2)

In [35]:
qkv.size()

torch.Size([2, 2, 4, 6])

In [36]:
qkv[0]  # first data point after the permute: [num_heads, seq_length, 3*head_dim]

tensor([[[ 0.1592, -0.3586,  0.5629, -0.5026, -0.3295,  0.6718],
         [-0.2941,  0.1607, -0.5823,  0.5015,  0.6178, -0.5147],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.1544, -0.4201,  0.1823, -0.1312,  0.1394, -0.2156],
         [-0.0365,  0.5209,  0.3704,  0.3071, -0.3399,  0.4653],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],
       grad_fn=<SelectBackward0>)

Separate q, k and v: each head's 6 features are split into three chunks of size `head_dim` = 2.

In [38]:
# cut the last axis into three equal parts: query, key, value (each head_dim wide)
q, k, v = torch.chunk(qkv, 3, dim=-1)

In [39]:
q.size()

torch.Size([2, 2, 4, 2])

In [40]:
q  # queries: [batch_size, num_heads, seq_length, head_dim]

tensor([[[[ 0.1592, -0.3586],
          [-0.2941,  0.1607],
          [ 0.0000,  0.0000],
          [ 0.0000,  0.0000]],

         [[ 0.1544, -0.4201],
          [-0.0365,  0.5209],
          [ 0.0000,  0.0000],
          [ 0.0000,  0.0000]]],


        [[[ 0.1592, -0.3586],
          [-0.2941,  0.1607],
          [ 0.1441, -0.9701],
          [ 0.2768,  0.3215]],

         [[ 0.1544, -0.4201],
          [-0.0365,  0.5209],
          [-0.4737, -0.3240],
          [-0.0634,  0.1860]]]], grad_fn=<SplitBackward0>)

Obtain the attention logits, i.e. the scaled dot products $QK^\top / \sqrt{d_k}$.

In [41]:
# scaled dot-product scores: (q @ k^T) / sqrt(d_k); rows index queries, columns index keys
d_k = q.shape[-1]
scale = math.sqrt(d_k)
attn_logits = q @ k.transpose(-1, -2) / scale

In [42]:
attn_logits.size()

torch.Size([2, 2, 4, 4])

Take care of the padding tokens: build a mask that is `True` wherever the token id is 0.

In [43]:
mask = s == 0  # True at padding positions

In [44]:
mask  # True where the token id is 0 (padding), shape [batch_size, seq_length]

tensor([[False, False,  True,  True],
        [False, False, False, False]])

In [45]:
# reshape to [batch_size, 1, 1, seq_length] so the mask broadcasts over heads and
# query positions, i.e. it hides padding *keys* from every query.
# s == 0 is already a bool tensor, so no dtype cast is needed. Using reshape with
# sizes read from `mask` keeps the cell idempotent: re-running it leaves the 4-D
# mask unchanged instead of stacking extra singleton dims (which would break masked_fill).
mask = mask.reshape(mask.size(0), 1, 1, mask.size(-1))

In [46]:
mask  # shape [batch_size, 1, 1, seq_length]: broadcasts over heads and query positions

tensor([[[[False, False,  True,  True]]],


        [[[False, False, False, False]]]])

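Apply the mask: fill the logits at padding key positions with a huge negative value so that softmax assigns them zero weight. Note: this cell's execution count (In [338]) is higher than those of the cells after it, so the logits and attention weights printed below were recorded *before* the mask was applied. That is why the padded columns of the first sequence show a logit of 0 and a nonzero attention weight (≈0.25). On a fresh Restart & Run All, those weights become 0.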

In [338]:
# hide padding keys: fill their logits with the most negative finite value of the
# dtype. Softmax then gives them zero weight exactly as -inf would, but a sequence
# consisting entirely of padding yields uniform weights instead of a row of NaNs.
attn_logits = attn_logits.masked_fill(mask, torch.finfo(attn_logits.dtype).min)

In [47]:
attn_logits  # [batch_size, num_heads, seq_length, seq_length]; rows = queries, columns = keys

tensor([[[[ 1.9084e-01, -1.9272e-01,  0.0000e+00,  0.0000e+00],
          [-1.7417e-01,  1.7806e-01,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],

         [[ 5.8889e-02, -5.0800e-02,  0.0000e+00,  0.0000e+00],
          [-5.3038e-02,  1.0355e-01,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]],


        [[[ 1.9084e-01, -1.9272e-01,  2.7438e-01, -1.7331e-01],
          [-1.7417e-01,  1.7806e-01, -2.2569e-01,  1.0025e-01],
          [ 4.0216e-01, -4.0335e-01,  6.1004e-01, -4.3979e-01],
          [-4.0753e-03,  2.6458e-05, -5.2432e-02,  1.1283e-01]],

         [[ 5.8889e-02, -5.0800e-02, -1.3966e-01,  2.9043e-02],
          [-5.3038e-02,  1.0355e-01,  2.7758e-01, -6.9283e-02],
          [-3.1021e-02, -1.9444e-01, -5.0737e-01,  1.4975e-01],
          [-2.5426e-02,  2.3800e

In [48]:
attention = attn_logits.softmax(dim=-1)  # normalise over the key axis

In [49]:
attention  # softmax over the last (key) axis, so each row sums to 1

tensor([[[[0.2999, 0.2044, 0.2478, 0.2478],
          [0.2082, 0.2961, 0.2478, 0.2478],
          [0.2500, 0.2500, 0.2500, 0.2500],
          [0.2500, 0.2500, 0.2500, 0.2500]],

         [[0.2644, 0.2370, 0.2493, 0.2493],
          [0.2337, 0.2733, 0.2465, 0.2465],
          [0.2500, 0.2500, 0.2500, 0.2500],
          [0.2500, 0.2500, 0.2500, 0.2500]]],


        [[[0.2887, 0.1968, 0.3139, 0.2006],
          [0.2133, 0.3034, 0.2026, 0.2807],
          [0.3217, 0.1437, 0.3960, 0.1386],
          [0.2450, 0.2461, 0.2335, 0.2754]],

         [[0.2713, 0.2431, 0.2224, 0.2633],
          [0.2200, 0.2573, 0.3062, 0.2165],
          [0.2726, 0.2315, 0.1693, 0.3266],
          [0.2406, 0.2527, 0.2634, 0.2433]]]], grad_fn=<SoftmaxBackward0>)

In [50]:
# weighted sum of the value vectors: [Batch, Head, SeqLen, Dims]
values = torch.matmul(attention, v)
values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]
# concatenate the heads back into one embed_dim-sized vector per token; the
# leading sizes come from the tensor itself, not from the batch_size/seq_length config
values = values.reshape(*values.shape[:2], embed_dim)

In [51]:
values  # attention output with heads concatenated: [batch_size, seq_length, embed_dim]

tensor([[[ 0.0274,  0.0963, -0.0437,  0.0533],
         [ 0.1143, -0.0125, -0.0603,  0.0768],
         [ 0.0721,  0.0393, -0.0501,  0.0624],
         [ 0.0721,  0.0393, -0.0501,  0.0624]],

        [[ 0.1522,  0.3589,  0.1203,  0.2496],
         [ 0.2690,  0.1288,  0.1019,  0.1842],
         [ 0.0879,  0.5008,  0.1444,  0.3327],
         [ 0.2238,  0.2107,  0.1110,  0.2233]]], grad_fn=<UnsafeViewBackward0>)