# Unit-testing all elements of a Transformer
The pupose of this notebook is to unit-test: 
 * Multi-head attention
 * Point-wise feed-forward network via conv1d
 * Layer nomalization
 * Learnable positional embeddings  
 
All that items are building blocks for the transfomer

In [129]:
%config Completer.use_jedi = False

## Multihead attention 
[docs](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html)  
how to create a proper attention mask to attend only to past events and not to the future

Credit for the image to https://jalammar.github.io/illustrated-transformer/  
![](./transformer_multi-headed_self-attention-recap.png)

In [28]:
import torch.nn as nn
import torch

In [71]:
# parameters for a toy model
d_model = 6
num_heads = 1
seq_len = 4
batch_size = 2

In [256]:
# torch.nn.MultiheadAttention(embed_dim, 
#                             num_heads, 
#                             dropout=0.0, 
#                             bias=True, 
#                             add_bias_kv=False, 
#                             add_zero_attn=False, 
#                             kdim=None, 
#                             vdim=None, 
#                             batch_first=False, 
#                             device=None, 
#                             dtype=None)

multihead_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

In [257]:
# input minibatch
x = torch.randn((batch_size, seq_len, d_model))

In [258]:
x.shape

torch.Size([2, 4, 6])

In [75]:
# compute a forward pass
with torch.no_grad():
    attn_output, attn_output_weights = multihead_attn.forward(x, x, x, key_padding_mask=None, need_weights=True, attn_mask=None)

In [76]:
attn_output.shape, attn_output_weights.shape

(torch.Size([2, 4, 6]), torch.Size([2, 4, 4]))

In [77]:
# output weights must sum to one
attn_output_weights.sum(dim=-1)

tensor([[1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000]], grad_fn=<SumBackward1>)

In [78]:
attn_output_weights

tensor([[[0.1936, 0.2610, 0.2356, 0.3098],
         [0.3099, 0.3128, 0.1453, 0.2319],
         [0.2948, 0.1864, 0.3057, 0.2130],
         [0.3823, 0.3377, 0.0763, 0.2036]],

        [[0.3052, 0.2019, 0.2533, 0.2396],
         [0.4544, 0.1016, 0.2112, 0.2327],
         [0.3193, 0.2103, 0.2090, 0.2614],
         [0.4490, 0.0918, 0.2630, 0.1963]]], grad_fn=<DivBackward0>)

In [260]:
# structure and number of parameters
for n, v in multihead_attn.named_parameters() :
    print(f"{v.numel():<4} {n:>15} - {v.shape}")

108   in_proj_weight - torch.Size([18, 6])
18      in_proj_bias - torch.Size([18])
36   out_proj.weight - torch.Size([6, 6])
6      out_proj.bias - torch.Size([6])


### Make a masked attention  
for sequntial recommendation so `j` can attend to `i` only if `j>=i` -> this removes looking into future items `j<i` 

In [231]:
# final attention mask matrix is this one
# rows are target seq and columns are source seq
# how to read the last line - last output item attends to all source/input items
# how to read the first line - first output element attends ONLY to itself
# when an element is True - it is excluded from attention
# True in mask meaning that this element is excluded from the attention
attn_mask = torch.ones((seq_len, seq_len), dtype=torch.bool)
attn_mask = torch.triu(attn_mask, diagonal=1)
attn_mask

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

In [82]:
with torch.no_grad():
    attn_output, attn_output_weights = multihead_attn.forward(x, x, x, key_padding_mask=None, need_weights=True, attn_mask=attn_mask)

In [83]:
attn_output_weights.sum(dim=-1)

tensor([[1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000]])

In [84]:
attn_output_weights

tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4977, 0.5023, 0.0000, 0.0000],
         [0.3746, 0.2369, 0.3885, 0.0000],
         [0.3823, 0.3377, 0.0763, 0.2036]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.8173, 0.1827, 0.0000, 0.0000],
         [0.4323, 0.2848, 0.2829, 0.0000],
         [0.4490, 0.0918, 0.2630, 0.1963]]])

## PointWiseFeedForward via Conv1d
in the simplest case, the output value of the layer with input size (N,Cin,L) and output (N,Cout,Lout)

In [141]:
# note that prints are the same
print(*x.shape) # batch-seq_len-d_model
print(batch_size, seq_len, d_model)

2 4 6
2 4 6


In [142]:
x

tensor([[[-0.4513,  0.0353,  0.1270,  0.4740,  0.8944,  1.1747],
         [-1.3230,  0.1927, -0.1822,  0.0926,  1.0287, -0.4540],
         [ 0.6806,  1.1146,  0.4576, -0.6504, -1.3897, -0.1947],
         [-1.2598, -0.0638,  0.6872, -2.4544,  1.2562, -0.5079]],

        [[ 1.2511, -0.1029, -0.9956,  0.1884, -0.5661, -0.4276],
         [-0.4554,  0.5926, -3.0739,  1.1600, -0.6636, -1.0234],
         [-0.7723,  0.8463, -0.3214, -0.7221, -0.2912, -1.7728],
         [ 1.5076, -0.5536, -3.4129,  0.1765, -0.0653,  0.5672]]])

In [143]:
# now we transpose the inner matrix so we would have a column per embedding
# this is due to conv1d - it mixes data between rows, i.e. samples
x.swapaxes(-1, -2)

tensor([[[-0.4513, -1.3230,  0.6806, -1.2598],
         [ 0.0353,  0.1927,  1.1146, -0.0638],
         [ 0.1270, -0.1822,  0.4576,  0.6872],
         [ 0.4740,  0.0926, -0.6504, -2.4544],
         [ 0.8944,  1.0287, -1.3897,  1.2562],
         [ 1.1747, -0.4540, -0.1947, -0.5079]],

        [[ 1.2511, -0.4554, -0.7723,  1.5076],
         [-0.1029,  0.5926,  0.8463, -0.5536],
         [-0.9956, -3.0739, -0.3214, -3.4129],
         [ 0.1884,  1.1600, -0.7221,  0.1765],
         [-0.5661, -0.6636, -0.2912, -0.0653],
         [-0.4276, -1.0234, -1.7728,  0.5672]]])

In [144]:
def init_weight(m):
    """
    this is the init weight function that sets all convolution parameters to 1 but the second matrix to 2
    """
    from functools import reduce
    # l = reduce(lambda x,y: x*y, m.weight.data.shape)
    l = m.weight.data.numel() # just the number of elements
    if type(m) == nn.Conv1d:
        m.weight.data = torch.ones(l).to(torch.float).reshape(m.weight.data.shape)
        m.weight.data[1] += m.weight.data[1]

In [145]:
# we take in the size of d_model and output the same size (that's the stuff in paper SASrec)
# in attention is all you need intermediate dim is 4 times larger than d_model
# kernel size = 1 makes it a point-wise transformation - remember we have item embeddings in each column and kernel size corresponds to the number of columns
# that we aggregate in summation of convolution
layer_conv1d = nn.Conv1d(in_channels=d_model, out_channels=d_model, kernel_size=1, bias=False)
layer_conv1d.apply(init_weight)
layer_conv1d.weight.data.shape

torch.Size([6, 6, 1])

In [128]:
layer_conv1d.weight.data

tensor([[[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.]],

        [[2.],
         [2.],
         [2.],
         [2.],
         [2.],
         [2.]],

        [[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.]],

        [[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.]],

        [[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.]],

        [[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.]]])

In [146]:
conv_1d_out = layer_conv1d(x.swapaxes(-1,-2))
conv_1d_out.shape

torch.Size([2, 6, 4])

In [147]:
# change back order of axes and move item embeddings to row orientation (to last dim)
conv_1d_out.swapaxes(-1,-2)

tensor([[[ 2.2543,  4.5085,  2.2543,  2.2543,  2.2543,  2.2543],
         [-0.6453, -1.2905, -0.6453, -0.6453, -0.6453, -0.6453],
         [ 0.0181,  0.0361,  0.0181,  0.0181,  0.0181,  0.0181],
         [-2.3424, -4.6847, -2.3424, -2.3424, -2.3424, -2.3424]],

        [[-0.6529, -1.3058, -0.6529, -0.6529, -0.6529, -0.6529],
         [-3.4638, -6.9277, -3.4638, -3.4638, -3.4638, -3.4638],
         [-3.0335, -6.0670, -3.0335, -3.0335, -3.0335, -3.0335],
         [-1.7805, -3.5610, -1.7805, -1.7805, -1.7805, -1.7805]]],
       grad_fn=<TransposeBackward0>)

In [149]:
# and take a note that now for each embedding we have the same pattern <sum(orig_embedding), twice the sum, sum, ..., sum>
# that is due to all conv1d weights are 1 except for the 2-nd channel wich is set to 2
x

tensor([[[-0.4513,  0.0353,  0.1270,  0.4740,  0.8944,  1.1747],
         [-1.3230,  0.1927, -0.1822,  0.0926,  1.0287, -0.4540],
         [ 0.6806,  1.1146,  0.4576, -0.6504, -1.3897, -0.1947],
         [-1.2598, -0.0638,  0.6872, -2.4544,  1.2562, -0.5079]],

        [[ 1.2511, -0.1029, -0.9956,  0.1884, -0.5661, -0.4276],
         [-0.4554,  0.5926, -3.0739,  1.1600, -0.6636, -1.0234],
         [-0.7723,  0.8463, -0.3214, -0.7221, -0.2912, -1.7728],
         [ 1.5076, -0.5536, -3.4129,  0.1765, -0.0653,  0.5672]]])

In [148]:
x.sum(dim=-1)

tensor([[ 2.2543, -0.6453,  0.0181, -2.3424],
        [-0.6529, -3.4638, -3.0335, -1.7805]])

In [130]:
torch.nn.functional.relu(conv_1d_out.swapaxes(-1,-2))

tensor([[[2.2543, 4.5085, 2.2543, 2.2543, 2.2543, 2.2543],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0181, 0.0361, 0.0181, 0.0181, 0.0181, 0.0181],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
       grad_fn=<ReluBackward0>)

## LayerNorm  
[docs](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html)  
What it does - it normalises data in each sample and multiplies each dimention by `Gamma`(each dim has it's own gamma) and adds `Bias` (again specific bias for each dim) and those a learnable parameters

In [150]:
from torch.nn import LayerNorm
embedding = torch.randn(batch_size, seq_len, d_model)
layer_norm = nn.LayerNorm(d_model)

In [203]:
for n, v in layer_norm.named_parameters() :
    print(f"{v.numel():<4} {n:>15} - {v.shape}")

6             weight - torch.Size([6])
6               bias - torch.Size([6])


In [159]:
layer_norm.weight.data, layer_norm.bias.data

(tensor([1., 1., 1., 1., 1., 1.]), tensor([0., 0., 0., 0., 0., 0.]))

In [199]:
embedding[0]

tensor([[ 1.2505,  0.2956,  0.5981, -1.6835,  1.3008, -1.9311],
        [-1.1474,  0.2873, -0.5039, -0.5448, -1.5466, -0.7329],
        [ 1.7583,  1.6505, -0.2722, -0.8209, -0.9714,  0.9023],
        [ 1.1841,  0.7555,  0.3598, -0.6615,  0.1244, -0.7045]])

In [205]:
# Activate module
with torch.no_grad():
    layer_out = layer_norm(embedding)[0]
layer_out

tensor([[ 0.9779,  0.2477,  0.4790, -1.2658,  1.0164, -1.4552],
        [-0.7871,  1.7261,  0.3401,  0.2684, -1.4865, -0.0611],
        [ 1.2393,  1.1429, -0.5791, -1.0705, -1.2054,  0.4728],
        [ 1.4583,  0.8382,  0.2655, -1.2124, -0.0751, -1.2746]])

####  Calculate same stuff manually

In [206]:
embedding.mean(dim=-1)[0], embedding.var(dim=-1)[0]

(tensor([-0.0282, -0.6981,  0.3744,  0.1763]),
 tensor([2.0519, 0.3910, 1.4961, 0.5731]))

In [195]:
embedding.mean(dim=-1)[0].unsqueeze(0).T

tensor([[-0.0282],
        [-0.6981],
        [ 0.3744],
        [ 0.1763]])

In [211]:
# same result as from layer norm
manual_layer_norm = (embedding[0] - embedding.mean(dim=-1)[0].unsqueeze(0).T)/(torch.sqrt(embedding.var(dim=-1, unbiased=False)[0].unsqueeze(0).T + layer_norm.eps))

In [225]:
manual_layer_norm - layer_out < 0.0001

tensor([[True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True]])

In [230]:
~torch.tril(torch.ones((5, 5), dtype=torch.bool))

tensor([[False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [False, False, False, False, False]])

In [None]:
torch.triu

In [233]:
import numpy as np

In [253]:
positions = torch.tile(torch.arange(0,seq_len), (batch_size,1))

In [254]:
positions

tensor([[0, 1, 2, 3],
        [0, 1, 2, 3]])

In [251]:
positions.expand

tensor([[0, 1, 2, 3],
        [0, 1, 2, 3]])

In [252]:
positions.expand_as

AttributeError: 'numpy.ndarray' object has no attribute 'expand_as'