In [12]:
import math
import torch
import torch.nn.functional as F

torch.__version__

'1.7.0+cu101'

In [43]:
N_TOKENS = 2
KEY_DIM  = 3
VAL_DIM  = 4

queries = torch.randn(N_TOKENS, KEY_DIM)
keys    = torch.randn(N_TOKENS, KEY_DIM)
values  = torch.randn(N_TOKENS, VAL_DIM)

In [44]:
queries

tensor([[-0.3362, -0.0215, -0.1524],
        [-0.1292,  0.1794,  0.1826]])

In [48]:
queries.shape, keys.T.shape

(torch.Size([2, 3]), torch.Size([3, 2]))

In [49]:
scores = (queries @ keys.T )

scores.shape, values.shape

(torch.Size([2, 2]), torch.Size([2, 4]))

In [50]:
scores

tensor([[ 0.3926, -0.2439],
        [-0.3123, -0.0934]])

In [51]:
F.softmax(scores, dim=-1)

tensor([[0.6540, 0.3460],
        [0.4455, 0.5545]])

In [57]:
def attention(queries, keys, values, mask=None):
    KEY_DIM = queries.size()[-1]
    scores = torch.matmul(queries, keys.transpose(-1, -2)) # Before: queries @ keys.T
    scores = scores /  math.sqrt(KEY_DIM)
    scores = F.softmax(scores, dim=-1)
        
    output = torch.matmul(scores, values)
    return output

attention(queries, keys, values)

tensor([[[-0.7839, -0.4311, -0.0485,  1.3585],
         [-0.8765, -0.5418, -0.1353,  1.4346]],

        [[-0.5609,  0.3538, -0.8545,  0.1455],
         [-0.5722,  0.2888, -0.8354,  0.0571]]])

In [56]:
BATCH_SISE = 2
N_TOKENS = 2
KEY_DIM  = 3
VAL_DIM  = 4

queries = torch.randn(BATCH_SISE, N_TOKENS, KEY_DIM)
keys    = torch.randn(BATCH_SISE, N_TOKENS, KEY_DIM)
values  = torch.randn(BATCH_SISE, N_TOKENS, VAL_DIM)

attention2(queries, keys, values)

tensor([[[-0.7839, -0.4311, -0.0485,  1.3585],
         [-0.8765, -0.5418, -0.1353,  1.4346]],

        [[-0.5609,  0.3538, -0.8545,  0.1455],
         [-0.5722,  0.2888, -0.8354,  0.0571]]])

In [None]:
class MultiHeadAttention(nn.Module):

    def __init__(self, heads, d_model, dropout = 0.1):
        super().__init__()
        
        self.d_model = d_model
        self.d_k     = d_model // heads
        self.h       = heads
        
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)
    
    def forward(self, q, k, v, mask=None):
        
        bs = q.size(0)
        
        # perform linear operation and split into h heads
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
        
        # transpose to get dimensions bs * h * sl * d_model
       
        k = k.transpose(1,2)
        q = q.transpose(1,2)
        v = v.transpose(1,2)
# calculate attention using function we will define next
        scores = attention(q, k, v, self.d_k, mask, self.dropout)
        
        # concatenate heads and put through final linear layer
        concat = scores.transpose(1,2).contiguous()\
        .view(bs, -1, self.d_model)
        
        output = self.out(concat)
    
        return output

In [58]:
BS = 2
SEQUENCE_DIM = 4
HEAD_DIM = 3
VALUE_DIM = 5

t = torch.randn(BS, SEQUENCE_DIM, HEAD_DIM, VALUE_DIM)
t.shape

torch.Size([2, 4, 3, 5])

In [99]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def attention(queries, keys, values, mask=None):

    # queries:[BS, SEQ_LEN, KEY_LEN]
    # keys:   [BS, SEQ_LEN, KEY_LEN]
    # values: [BS, SEQ_LEN, VAL_LEN]
    # output: [BS, SEQ_LEN, VAL_LEN]

    KEY_DIM = queries.size()[-1]
    scores = torch.matmul(queries, keys.transpose(-1, -2)) # Before: queries @ keys.T
    scores = scores / math.sqrt(KEY_DIM)
    if mask is not None:
        scores = scores.masked_fill(mask==0, -1e9)
    scores = F.softmax(scores, dim=-1)
        
    output = torch.matmul(scores, values)
    return output


class MultiHeadAttention(nn.Module):

    def __init__(self, inputEmb_dim, outputEmb_dim, num_heads, key_dim, value_dim):

        super(MultiHeadAttention, self).__init__()

        self.dim_heads = num_heads
        self.dim_keys  = key_dim
        self.dim_vals  = value_dim

        self.linear_q = nn.Linear(inputEmb_dim, num_heads * key_dim)
        self.linear_k = nn.Linear(inputEmb_dim, num_heads * key_dim)
        self.linear_v = nn.Linear(inputEmb_dim, num_heads * value_dim)
        self.linear_o = nn.Linear(num_heads * value_dim, outputEmb_dim)

    def forward(self, x, mask=None):

        # x: [BS, SEQUENCE_DIM, INPUT_EMBEDING_DIM]
        dim_bs  = x.size(0)
        dim_seq = x.size(1)

        q = self.linear_q(x) # [BS, SEQUENCE_DIM, HEAD_DIM * KEY_DIM]
        k = self.linear_k(x) # [BS, SEQUENCE_DIM, HEAD_DIM * KEY_DIM]
        v = self.linear_v(x) # [BS, SEQUENCE_DIM, HEAD_DIM * VALUE_DIM]

        q = q.view(dim_bs, dim_seq, self.dim_heads, self.dim_keys) # [BS, SEQUENCE_DIM, HEAD_DIM, KEY_DIM]
        k = k.view(dim_bs, dim_seq, self.dim_heads, self.dim_keys) # [BS, SEQUENCE_DIM, HEAD_DIM, KEY_DIM]
        v = v.view(dim_bs, dim_seq, self.dim_heads, self.dim_vals) # [BS, SEQUENCE_DIM, HEAD_DIM, VALUE_DIM]

        k = k.transpose(1,2) # [BS, HEAD_DIM, SEQUENCE_DIM, KEY_DIM]
        q = q.transpose(1,2) # [BS, HEAD_DIM, SEQUENCE_DIM, KEY_DIM]
        v = v.transpose(1,2) # [BS, HEAD_DIM, SEQUENCE_DIM, VALUE_DIM]

        out = attention(q, k, v, mask) # [BS, HEAD_DIM, SEQUENCE_DIM, VALUE_DIM]

        out = out.transpose(1,2) # [BS, SEQUENCE_DIM, HEAD_DIM, VALUE_DIM]
        out = out.contiguous()
        out = out.view(dim_bs, dim_seq, self.dim_heads * self.dim_vals) # [BS, SEQUENCE_DIM, HEAD_DIM * VALUE_DIM]
        out = self.linear_o(out) # [BS, SEQUENCE_DIM, OUTPUT_EMBEDING_DIM]

        return out


In [100]:
mha = MultiHeadAttention(inputEmb_dim=8,
                         outputEmb_dim=8,
                         num_heads=4,
                         key_dim=3,
                         value_dim=5)
BS = 64
SEQUENCE_DIM = 100
INPUT_EMBEDING_DIM = 8
x = torch.randn(BS, SEQUENCE_DIM, INPUT_EMBEDING_DIM) 

In [101]:
mha

MultiHeadAttention(
  (linear_q): Linear(in_features=8, out_features=12, bias=True)
  (linear_k): Linear(in_features=8, out_features=12, bias=True)
  (linear_v): Linear(in_features=8, out_features=20, bias=True)
  (linear_o): Linear(in_features=20, out_features=8, bias=True)
)

In [102]:
mha(x)

tensor([[[ 0.0735,  0.1203,  0.1862,  ...,  0.0435, -0.0371, -0.2604],
         [ 0.2182,  0.0490,  0.0103,  ...,  0.0568, -0.0927, -0.1459],
         [ 0.1954,  0.0210,  0.0602,  ...,  0.0114, -0.0521, -0.2118],
         ...,
         [ 0.1868,  0.0354,  0.0185,  ...,  0.0102, -0.0067, -0.2195],
         [ 0.1770,  0.0434,  0.0842,  ...,  0.0193, -0.0479, -0.2201],
         [ 0.0283,  0.0499,  0.1299,  ...,  0.0041, -0.0264, -0.2475]],

        [[ 0.2088,  0.0162,  0.0898,  ...,  0.1180, -0.0918, -0.1002],
         [ 0.3187,  0.0293,  0.0811,  ...,  0.0752, -0.1491, -0.1509],
         [ 0.2766, -0.0804,  0.0940,  ...,  0.0539, -0.1444, -0.1465],
         ...,
         [ 0.2179, -0.0348,  0.1274,  ...,  0.0461, -0.0990, -0.1675],
         [ 0.1575, -0.0888,  0.0826,  ...,  0.0265,  0.0364, -0.2286],
         [ 0.1361, -0.0101,  0.1384,  ...,  0.0723, -0.0191, -0.1932]],

        [[ 0.1610, -0.0536,  0.1465,  ..., -0.0775, -0.0967, -0.2957],
         [ 0.1623,  0.0085,  0.1551,  ..., -0

In [103]:
BS = 64
SEQUENCE_DIM = 100
INPUT_EMBEDING_DIM = 8
x = torch.randn(BS, SEQUENCE_DIM, INPUT_EMBEDING_DIM) 
x.shape


torch.Size([64, 100, 8])

In [104]:
mean = x.mean(dim=-1, keepdim=True)
mean.shape

torch.Size([64, 100, 1])

In [107]:
(x - mean).shape

torch.Size([64, 100, 8])

In [111]:
lm = nn.LayerNorm(-1)
lm(x).shape

RuntimeError: ignored

In [112]:
class LayerNorm(nn.Module):

    def __init__(self, emb_dim, eps=1e-6):
        super().__init__()
    
        self.emb_dim = emb_dim

        # create two learnable parameters to calibrate normalisation
        self.alpha = nn.Parameter(torch.ones(self.emb_dim))
        self.bias  = nn.Parameter(torch.zeros(self.emb_dim))
        self.eps   = eps

    def forward(self, x):

        # x: [BS, SEQUENCE_DIM, EMBEDING_DIM]

        mean = x.mean(dim=-1, keepdim=True) [BS, SEQUENCE_DIM, 1]
        std  = x.std(dim=-1,  keepdim=True)

        return self.alpha * (x - mean) / (std + self.eps) + self.bias

In [122]:
x * torch.Tensor([1,2,12,1,1,1,1,1])

tensor([[[ 5.6180e-01, -1.1600e+00,  2.1425e+01,  ..., -1.2426e+00,
          -1.4187e+00, -5.6282e-02],
         [-5.3771e-01,  4.8430e-01, -9.1559e-01,  ..., -1.1818e+00,
           8.8488e-01,  1.2458e+00],
         [ 2.7389e+00,  3.0028e+00,  1.0331e+01,  ..., -3.6169e-01,
           1.6268e+00,  2.3974e-01],
         ...,
         [-1.6074e+00, -1.1349e+00,  1.2201e+01,  ..., -1.6317e+00,
          -2.8649e-01,  3.6544e-02],
         [ 2.5783e-01,  2.9038e+00, -5.9686e+00,  ...,  2.3704e-01,
          -1.7103e-01,  2.2783e-01],
         [-2.4159e+00, -1.2550e+00,  2.1654e+00,  ..., -8.1385e-01,
           5.7956e-01, -1.8479e-01]],

        [[-2.4769e-01, -4.6142e-01, -8.9455e+00,  ...,  6.1784e-01,
           7.5373e-02,  1.5667e+00],
         [ 3.5687e-01, -8.1378e-02,  1.1038e+01,  ..., -9.8647e-02,
          -2.3364e-01,  7.1643e-01],
         [ 1.4991e+00,  1.3932e+00,  1.0280e+01,  ...,  7.6585e-01,
           7.4912e-02, -1.6023e+00],
         ...,
         [ 1.5412e+00,  2

In [115]:
ln = LayerNorm(8)
ln(x).shape

torch.Size([64, 100, 8])

In [119]:
ln(x).std(-1)#.shape

tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       grad_fn=<StdBackward1>)