# Attention is all you need!

In this notebook, we set up and explain an autoencoder (encoder + decoder) that uses the transformer architecture from [Attention Is All You Need](https://arxiv.org/abs/1706.03762) for the tiny shakespeare dataset. We set a few rules

1. We are only allowed to use the paper
2. We are only allowed to use PyTorch documentation

Personal note - when writing this notebook, I allowed myself the following leeway 

If we are not able to solve something after spending 30 minutes on it then

3. We can look at the annotated version of Attention Is All You Need by Sasha Rush

### Global variables and imports

In [13]:
import torch
import torch.nn as nn
import math
import copy

The model requires a lot of hyperparameters. Their purpose is not immediately straightforward at this point (we haven't even looked at the model yet!). We will constantly mark where each parameter appears for the first time in the notebook. For now the reader can just run the cell below 

In [189]:
# MODEL HYPERS --------------------------
vocab_size = 65 # For character level tokenisation. Actual vocab size used in the paper is 320000
context_length = 8
dmodel = 32
n_embd = dmodel
mlp_feature_dim = 32
h = 2
N = 2
p = 0.5
# ---------------------------------

## Implementing transformers from scratch

For any model, it is important to first define what the input looks like. Tranformers use a sequence of token ids as the input. The length of the sequence defines the `context_length` and denotes the number of tokens (context) that the network looks at in a given instance. Since we are dealing with character level tokenisation, the vocab size is 65 (jump ahead to [Data](##-Data) to see why).

We also need a target output. This, following the paper, will be the input with all tokens shifted left and a new token. So for example if the target text is abcd, the input is abc and the output is bcd (because given abc we want the network to predict d)

In [190]:
# Produce a random context_length sized list of token ids
Xtr = torch.randint(0, vocab_size,(context_length,))
Ytr = torch.cat((Xtr[1:], torch.randint(0,vocab_size,(1,))))
Xtr, Ytr

(tensor([23, 18,  2, 23, 27, 52, 24, 53]),
 tensor([18,  2, 23, 27, 52, 24, 53, 56]))

In [191]:
Xtr.shape

torch.Size([8])

We now have our dataset (which is just a dummy variable that we randomly initialised). Let us now try to understand how to implement the transformer. We are not going to keep it modular in this section. We are going to do everything in a dumb procedural way. After we figure out everything from the paper in this section is when we set up the PyTorch-ified version of transformers in the next section

### Embeddings

The first step is to embed these characters using an embedding matrix. This introduces a new hyperparamater `n_embd`.

In [192]:
C = torch.randn((vocab_size, n_embd))

In [193]:
emb = C[Xtr]
emb.shape

torch.Size([8, 32])

This part was easy because we have already done this. Now the next complication is the positional encoding. 

### Positional encoding

A word can often take different meanings on the basis of where it appears in a sentence. This further extends to position of a word in a paragraph if we consider multiple sentences. Thus it helps to add information about the position of a word inside the context window.

The paper uses the following way of including information about the position.

We have for each example a 16 by 32 matrix of embeddings. To these embeddings, we add a 16 by 32 positional encoding matrix. The values the matrix takes are as follows
$$
pos_{(n,2m)} = \sin\left(\frac{n}{10000^{2m/32}}\right) \\ 
pos_{(n,2m+1)} = \cos\left(\frac{n}{10000^{2m/32}}\right)
$$

This is a really weird way of adding positional encoding in my opinion but for now we just stay with this. I would personally train an entire matrix of positional encodings as a trainable parameter of the model. However, the paper mentions that they tried training a positional encoding and the performance was comparable to using the above encoding

In [194]:
pos = torch.zeros((context_length, n_embd))
position = torch.arange(0,context_length)
divisionFactor = 10000**(-torch.arange(0, n_embd, 2)/32)
pos[:, ::2] = torch.sin(position.view(-1,1) * divisionFactor)
pos[:,1::2] = torch.cos(position.view(-1,1) * divisionFactor)

In [195]:
(emb + pos).shape

torch.Size([8, 32])

### Attention heads

Now time for the main meat of the paper, which is the attention head. 

The gist of what the attention head does is that it treats the entire context as a phonebook, each token as an entity (a person or a company for example) and wants to give each token the power to choose which phone number it wants to focus on based on certain characteristics of the entity.

Put in simple words, we want to give each token the power to understand what role it plays. So for example let's take the sentence "The fox ate the sheep" and assume for a moment that the tokenisation we use is word-level. We want to give the token `sheep` the ability to choose which tokens to focus on to figure out its role in the sentence. We want to let each token ask a question (or query). For simplicity we assume that the question lies in a $d_k$ dimensional space. Each token also emits an answer to each of those questions (the key). Obviously, these two need to live in the same number of dimensions (you cannot have questions with no answers, and you cannot have answers to non-existent questions). Now each token also emits a phone number (value) which we assume to lie in a $d_v$ dimensional space. Based on the questions it asks and the answers it gets, it chooses to either ignore the phone number or give it the most importance. In other words, we want a weighted sum of all the values emitted by all the tokens in the context.

Summarising, for each character we have a $d_k$ dimensional query vector $q$, a $d_k$ dimensional key vector $k$, and  a $d_v$ dimensional value vector $v$. We package them together into the matrices $Q$ of dimension $\text{context_length}\times d_k$, $K$ of dimension $\text{context_length}\times d_k$ and $V$ of dimension $\text{context_length}\times d_v$. Then the attention is calculated as 
$$
\text{Attention} = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) V
$$

It is evident that the above is a sensible operation because $QK^T$ produces a $\text{context_length}\times\text{context_length}$ matrix, which can be multiplied in a sensible way with $V$

Each Q, K, V is learnt from the data itself. So then we need three linear layers that map from the data to Q, K, V and then calculate the attention. We follow the paper and take them all to be `n_embd`

In [196]:
Q = torch.randn((n_embd,n_embd))
K = torch.randn((n_embd,n_embd))
V = torch.randn((n_embd,n_embd))

def oldattention(x): # This was redefined in later parts of the notebook to make MHA easier to write
    query = x @ Q
    key = x @ K
    value = x @ V
    score = torch.softmax(query @ key.transpose(-2,-1), -2) / math.sqrt(n_embd) # The -2 in softmax tells PyTorch to take a softmax in the second last dimension
    attention = score @ value
    return attention

In [197]:
attentionOutput = oldattention(emb + pos) + (emb + pos)

Now that we have the attention head, we need to pass it through a MLP. We work with the architecture of the original paper of having a single hidden layer, followed by a ReLU, and then a second layer that brings it back to n_embd. This introduces a new hyperparameter which controls the dimensions of the hidden layer

In [198]:
W1 = torch.randn((n_embd, hidden_dim))
b1 = torch.randn((hidden_dim,))
W2 = torch.randn((hidden_dim, n_embd))
b2 = torch.randn((n_embd))

In [199]:
relu = torch.nn.ReLU()

In [200]:
firstLayer = relu(attentionOutput @ W1 + b1)

In [201]:
secondLayer = firstLayer @ W2 + b2

In [202]:
secondLayer.shape

torch.Size([8, 32])

I believe we have completely implemented one transformer block. What we have not figured out is the following - 

1. How to implement multi-head attention layer
2. What we have done so far is take a 16 context length by 32 embedding dimension and done a few things on it. Now how do we convert the output to something that we can evaluate the loss on?

### Multi-head attention

I am not sure I understand why multi-head attention makes calculation of attention more cost-effective. I need to take out a paper pad and do complexity calculations

First let's start by a simple calculation of the complexity of computing a product of two matrices. Let's say I have a matrix $A$ of size $n\times k$ and a matrix $B$ of size $k \times m$. We can take an example to make the calculation easier to understand.


$$
A = 
\begin{pmatrix}
8 & 6 \\
5 & 2 \\ 
9 & 2 \\
\end{pmatrix}
$$

and 

$$
B = 
\begin{pmatrix}
1 & 2 & 3 & 1 \\ 
4 & 5 & 6 & 3 \\
\end{pmatrix}
$$

Now let's say that every access to an element in a matrix is $O(1)$ and every multiplication of two floats is $O(1)$. So how many accesses and how many multiplications am I doing in $AB$?

If you work through it, it is $O(m\times n \times k)$.

Are there quicker multiplication algorithms?

I just checked the answer to that and it depends. If you have $n\times k$ and $k\times m$ then the only thing you can do is the complexity I calculated, which is $O(m\times n \times k)$. If it is a square matrix then we can bring it down to roughly $O(n^{2.37})$ (check Coppersmith-Winograd algorithm).

So now let us work to understand the complexity of the self-attention heads. 

What are the dimensions of $Q$ and $K$? It is $n \times d_k$ where $n$ is the context length. And so to multiply $Q$ and $K^T$ would be $O(n^2 d_k)$. Softmax and division are $O(n^2)$ because there are $n^2$ elements in $QK^T$. And now multiplying this to $V$ is $O(n^2 d_v)$. 

I think the point of multi-head attention is the following. Taking $h$ heads of $d_k/h$ and $d_v/h$ does not improve time complexity because you have to do $h$ number of $O(n^2 d_k/h)$ matrix multiplications so the time complexity is the exact same. What makes it more efficient is that you can parallelise these $h$ matrix multiplications so thus in-effect the time complexity is $O(n^2 d_k/h)$.

Before we start writing the code for MultiHeadAttention we need to modify a few things in the definition for attention so that we can make it more modular. We are going to define it to take the query, key, and value, and return the attention directly. Which means, we are going to do `x @ Q` outside the attention function

In [203]:
def attention(query, key, value):
    dk = key.shape[-1]
    assert query.shape[-1] == key.shape[-1]
    assert key.shape[-2] == value.shape[-2]
    score = torch.softmax(query @ key.transpose(-2,-1), -2) / math.sqrt(dk)
    attention = score @ value
    return attention

Now we are going to first implement an attention layer which takes in $d_k$ and $d_v$ and does vanilla attention without any multihead business. The only difference now is that we are going to initiate with `nn.Linear`

In [204]:
class AttentionHead(nn.Module):
    def __init__(self, context_length, dmodel):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(dmodel, context_length, bias=False) for _ in range(3)])
    def forward(self, x):
        query, key, value = [lin(x) for lin in self.linears]
        return attention(query, key, value)

In [205]:
attentionhead = AttentionHead(context_length, n_embd)
x = emb+pos
attentionhead(x)

tensor([[ 0.0142, -0.0229, -0.0261,  0.0085,  0.0872, -0.0189,  0.0293,  0.0028],
        [ 0.0562, -0.0150, -0.0367,  0.0365,  0.3662, -0.0319,  0.1328,  0.0268],
        [ 0.0440, -0.0378, -0.0066,  0.0503,  0.4968, -0.0609,  0.1498,  0.0540],
        [-0.0045, -0.0285, -0.0168,  0.0075,  0.0808, -0.0218,  0.0104, -0.0178],
        [ 0.0756, -0.0193,  0.0133,  0.0619,  0.4654, -0.0774,  0.1321,  0.0845],
        [-0.0539, -0.0739, -0.0428,  0.0136,  0.2315, -0.0687,  0.0344, -0.1187],
        [ 0.1744, -0.0834, -0.1005,  0.0647,  0.4040, -0.0753,  0.1209,  0.1623],
        [ 0.0743, -0.0170, -0.1167,  0.0211,  0.2525,  0.0036,  0.1400, -0.0052]],
       grad_fn=<MmBackward0>)

Now that we have got this down, we need to replicate this to make the multihead attention class. Let's implement for the case where the input and output sequence are of the same length

In [212]:
emb.shape

torch.Size([8, 32])

In [206]:
class MultiHeadAttention(nn.Module):
    def __init__(self, context_length, dmodel, h):
        super().__init__()
        assert dmodel % h == 0
        self.context_length = context_length
        self.dmodel = dmodel
        self.h = h
        self.dk = self.dmodel // self.h
        self.linears = nn.ModuleList([nn.Linear(dmodel, context_length) for _ in range(4)]) # One extra because of final linear in MHA
    def forward(self, x):
        query, key, value = [lin(x).view(-1, self.h, self.context_length, self.dk) for lin in self.linears[:-1]]
        attentionScore = attention(query, key, value).view(-1, self.context_length, self.dmodel)
        return self.linears[-1](attentionScore)

In [207]:
multiheadattention = MultiHeadAttention(context_length, dmodel, h)

In [218]:
multiheadattention.linears[2](emb).shape

torch.Size([8, 8])

We have now pretty much absorbed most of the intricacies of the model. Now it is time to implemet the encoder and the decoder blocks as subclasses of the module class and then put them all together and match the number of parameters with the number of paramters in the paper. If that matches then we are golden

In [209]:
class MultiHeadAttention(nn.Module):
    def __init__(self, context_length, dmodel, h):
        super().__init__()
        assert dmodel % h == 0
        self.context_length = context_length
        self.dmodel = dmodel
        self.h = h
        self.dk = self.dmodel // self.h
        self.linears = nn.ModuleList([nn.Linear(dmodel, context_length) for _ in range(4)]) # One extra because of final linear in MHA
    def attention(self, query, key, value):
        dk = key.shape[-1]
        assert query.shape[-1] == key.shape[-1]
        assert key.shape[-2] == value.shape[-2]
        score = torch.softmax(query @ key.transpose(-2,-1), -2) / math.sqrt(dk)
        attention = score @ value
        return attention
    def forward(self, x):
        query, key, value = [lin(x).view(-1, self.h, self.context_length, self.dk) for lin in self.linears[:-1]]
        attentionScore = self.attention(query, key, value).view(-1, self.context_length, self.dmodel)
        return self.linears[-1](attentionScore)

# class EncodeBlock(nn.Module):# This is currently un-normalised
#     def __init__(self, context_length, dmodel, h):
#         self.dmodel = dmodel
#         self.h = h
#         self.layers = [
#             MultiHeadAttention(context_length, dmodel, h),
#             nn.Linear(dmodel, mlp_feature_dim, bias=True),
#             nn.ReLU(),
#             nn.Linear(mlp_feature_dim, dmodel, bias=True)
#         ]
#     def forward(self, x):
#         for layer in self.layers:
#             x = layer(x)
#         return x

In [210]:
encoder = EncodeBlock(10, 512, 1)

TypeError: __init__() missing 1 required positional argument: 'p'

In [None]:
encoder.layers

One key missing/pain point here is that I have not implemented a residual connection. This needs handling. One way of implementing it is to define a sublayer that is initialised using a module as an input

In [None]:
class Sublayer(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.norm = None # Needs changing 
    def forward(self, x):
        return x + module(x)

So now we change the encode block in the following way

In [140]:
class MLP(nn.Module):
    def __init__(self, dmodel, mlp_feature_dim):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(dmodel, mlp_feature_dim, bias=True),
            nn.ReLU(),
            nn.Linear(mlp_feature_dim, dmodel, bias=True)
                      ])
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class EncodeBlock(nn.Module):# This is currently un-normalised
    def __init__(self, context_length, dmodel, h):
        super().__init__()
        self.dmodel = dmodel
        self.h = h
        self.layers = nn.ModuleList([
            Sublayer(MultiHeadAttention(context_length, dmodel, h)),
            Sublayer(MLP(dmodel, mlp_feature_dim))
        ])
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

In [None]:
encoder = EncodeBlock(10, 512, 1)

In [None]:
encoder.layers

Now this defines properly a full encoder block. We need to finish the file step of this, which is to write the full encoder stack and then verify everything is working as intended. For this we need to define now a function that clones the encode block N times for us

In [None]:
def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

In [None]:
class Encoder(nn.Module):
    def __init__(self, N, context_length, dmodel, h):
        super().__init__()
        self.layers = clones(EncodeBlock(context_length, dmodel, h), N)
        
    def forward(self,x):
        for layer in self.layers:
            x = layer(x)
        return x

In [None]:
encoder = Encoder(2,3,32,4)

In [None]:
encoder.layers

This I believe completes the implementation of the full encoder block. Now we collect them all together as a recap

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, context_length, dmodel, h):
        super().__init__()
        assert dmodel % h == 0
        self.context_length = context_length
        self.dmodel = dmodel
        self.h = h
        self.dk = self.dmodel // self.h
        self.linears = nn.ModuleList([nn.Linear(dmodel, context_length) for _ in range(4)]) # One extra because of final linear in MHA
    def attention(self, query, key, value):
        dk = key.shape[-1]
        assert query.shape[-1] == key.shape[-1]
        assert key.shape[-2] == value.shape[-2]
        score = torch.softmax(query @ key.transpose(-2,-1), -2) / math.sqrt(dk)
        attention = score @ value
        return attention
    def forward(self, x):
        query, key, value = [lin(x).view(-1, self.h, self.context_length, self.dk) for lin in self.linears[:-1]]
        attentionScore = self.attention(query, key, value).view(-1, self.context_length, self.dmodel)
        return self.linears[-1](attentionScore)

class MLP(nn.Module):
    def __init__(self, dmodel, mlp_feature_dim):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(dmodel, mlp_feature_dim, bias=True),
            nn.ReLU(),
            nn.Linear(mlp_feature_dim, dmodel, bias=True)
                      ])
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class Sublayer(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.norm = None # Needs changing 
    def forward(self, x):
        return x + self.module(x)
    
class EncodeBlock(nn.Module):# This is currently un-normalised
    def __init__(self, context_length, dmodel, h):
        super().__init__()
        self.dmodel = dmodel
        self.h = h
        self.layers = nn.ModuleList([
            Sublayer(MultiHeadAttention(context_length, dmodel, h)),
            Sublayer(MLP(dmodel, mlp_feature_dim))
        ])
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
    
def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class Encoder(nn.Module):
    def __init__(self, N, context_length, dmodel, h):
        super().__init__()
        self.layers = clones(EncodeBlock(context_length, dmodel, h), N)
    def forward(self,x):
        for layer in self.layers:
            x = layer(x)
        return x

First we need to make sure that whatever we have done here makes total sense. We are going to start with a dummy set of data and parameters and make sure one by one that we get what we should be getting. 

In [None]:
# HYPERS --------------------------
context_length = 4
n_embd = 4
dmodel = 4
mlp_feature_dim = 32
h = 2
N = 2
# ---------------------------------

First we start with what the input is supposed to look like. We have a context length of `context_length` and a vocab size of `v_size`. The encoder block does not deal with this directly. Instead, it deals with the embedded version of it, and thus what we want as an input is a tensor of dim `(context_length, n_embd)`

In [None]:
x = torch.randn([context_length, n_embd])
x.shape

Now we need to decide on the attention architecture. We are going to pick the dummy `dmodel = 4`. So the dimension of our query and key space is `dmodel`. The value space dimension is also `dmodel`. Here one thing is to note. The authors are working in a situation where `dmodel = n_embd` and we maintain that status quo. In case one wanted to work in a different setting, the query and key will be `dmodel` while the value space has to be `n_embd` for it to make sense. Remember, the query and key are intertwined and are kind of a lookup in the yellow pages, while the value space is the actual phone number. This **has** to be `n_embd`.

In [None]:
encoder = Encoder(N, context_length, dmodel, h)

For now we work under the assumption that the random x we generated has already been masked with the positional encoding. This is something that needs to be implemented in the architecture *outside* the encoder block, and thus for this testing we assume this has been achieved.

Now onto calculating the result of `encoder` and making sure it all works well.

In [None]:
encoderoutput = encoder(x)

The first immediate and obvious check is to make sure the dimension remains the same

In [None]:
encoderoutput.shape

The above output is expected since we specifically set `view(-1, self.context_length, self.dmodel)` in the multi-head attention block, which adds the batch dimension. This will be useful and important when we have a non-trivial batch dimension

Now we want to verify that the module is doing what we want it to do. Let's look at x and the components again

In [None]:
x

In [None]:
encoder.layers

In [None]:
firstAttentionModule = encoder.layers[0].layers[0].module
encoder.layers[0].layers[0](x) == x + firstAttentionModule(x)

This achieves the required x + attention(x). We are currently not applying a layer norm. We need to do that. Now onto the MLP

In [None]:
firstMLPModule = encoder.layers[0].layers[1].module
attentionOutput = x + firstAttentionModule(x)
encoder.layers[0].layers[1](attentionOutput) == attentionOutput + firstMLPModule(attentionOutput)

This also makes sense. As a final check, let's make sure that the MLP calculates precisely what we want

In [None]:
print(firstMLPModule.layers[0](x).shape)
print(firstMLPModule.layers[1](firstMLPModule.layers[0](x)) >= 0)
firstMLPModule.layers[2](firstMLPModule.layers[1](firstMLPModule.layers[0](x)))

So finally we need to make sure that the full block calculates what we want

In [None]:
encoder.layers[0](x) == attentionOutput + firstMLPModule(attentionOutput)

Now if this block works then we can be assured that the full encoder, which is just these encoder blocks copied `N` times, should work. Thus the final output of the un-normalised encoder without dropouts is just `encoderoutput`

In [None]:
encoderoutput

Now onto implementing the layer norm and dropouts. Layernorm is just like batchnorm but you normalise over the layer instead of normalising over the batch. The layernorm module in PyTorch takes the output dimensions (called normalised_shape in the documentation). This is a little tricky to understand from the documentation so let's try to do it ourselves and see what it spits out

In [None]:
layerNormTrial = nn.LayerNorm(dmodel)

In [None]:
x.sum(-1, keepdims=True)/4

In [None]:
layerNormTrial(x).sum(-1, keepdims=True)/4

So this is normalising in the last dimension, which is the right thing to do. However this does raise a documentation question. How do I normalise instead in the context instead of the embedding space?

Anyway, we need to add this normalisation to each sublayer. That is easily achieved by only changing the sublayer code while keeping the other parts unchanged. We did already anticipate that if you notice we had a `self.norm = None` in the definition of `Sublayer`

In [None]:
class Sublayer(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.norm = nn.LayerNorm(dmodel) 
    def forward(self, x):
        return self.norm(x + self.module(x))

In [None]:
encoder = Encoder(N, context_length, dmodel, h)
encoder.layers

There is also a dropout that is employed. Let's have a look at that. The idea there is that at the forward pass, some of the elements of the input tensor are randomly set to zero with a probability that you control. So, time for a new hyperparameter. The paper sets as p = 0.1. A 10% dropout rate seems a little high to me but we can test the performance when we actually do a training run later.

The position of the dropout is slightly unclear to me. They say "We apply dropout to the output of each sub-layer, before it is added to the sub-layer input and normalized." My interpretation is the following code. Notice I apply dropout to `self.module(x)`

In [None]:
class Sublayer(nn.Module):
    def __init__(self, module, p):
        super().__init__()
        self.module = module
        self.dropout = nn.Dropout(p)
        self.norm = nn.LayerNorm(dmodel) 
    def forward(self, x):
        return self.norm(x + self.dropout(self.module(x)))

The full code now needs to be modified to incorporate this addition `p`. This is the result

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, context_length, dmodel, h):
        super().__init__()
        assert dmodel % h == 0
        self.context_length = context_length
        self.dmodel = dmodel
        self.h = h
        self.dk = self.dmodel // self.h
        self.linears = nn.ModuleList([nn.Linear(dmodel, context_length) for _ in range(4)]) # One extra because of final linear in MHA
    def attention(self, query, key, value):
        dk = key.shape[-1]
        assert query.shape[-1] == key.shape[-1]
        assert key.shape[-2] == value.shape[-2]
        score = torch.softmax(query @ key.transpose(-2,-1), -2) / math.sqrt(dk)
        attention = score @ value
        return attention
    def forward(self, x, mem = None):
        query, key, value = [lin(x).view(-1, self.h, self.context_length, self.dk) for lin in self.linears[:-1]]
        attentionScore = self.attention(query, key, value).view(-1, self.context_length, self.dmodel)
        return self.linears[-1](attentionScore)

class MLP(nn.Module):
    def __init__(self, dmodel, mlp_feature_dim):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(dmodel, mlp_feature_dim, bias=True),
            nn.ReLU(),
            nn.Linear(mlp_feature_dim, dmodel, bias=True)
                      ])
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class Sublayer(nn.Module):
    def __init__(self, module, p):
        super().__init__()
        self.module = module
        self.dropout = nn.Dropout(p)
        self.norm = nn.LayerNorm(dmodel) 
    def forward(self, x):
        return self.norm(x + self.dropout(self.module(x)))
    
class EncodeBlock(nn.Module):# This is currently un-normalised
    def __init__(self, context_length, dmodel, h, p):
        super().__init__()
        self.dmodel = dmodel
        self.h = h
        self.layers = nn.ModuleList([
            Sublayer(MultiHeadAttention(context_length, dmodel, h), p),
            Sublayer(MLP(dmodel, mlp_feature_dim), p)
        ])
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
    
def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class Encoder(nn.Module):
    def __init__(self, N, context_length, dmodel, h, p):
        super().__init__()
        self.layers = clones(EncodeBlock(context_length, dmodel, h, p), N)
    def forward(self,x):
        for layer in self.layers:
            x = layer(x)
        return x

In [None]:
# HYPERS --------------------------
context_length = 4
n_embd = 4
dmodel = 4
mlp_feature_dim = 32
h = 2
N = 2
pdropout = 0.1
# ---------------------------------

In [None]:
encoder = Encoder(N, context_length, dmodel, h, pdropout)
encoder.layers

In [None]:
encoder(x).shape

Now it is time to write the decoder stack. The only change in this is the addition of one more sublayer where instead of having an attention block and a MLP, there are two attention blocks and an MLP. The first attention block does exactly what you expect. The second one is a little weird. It performs the attention on the output of the encoder. To be more precise, it takes computes the query and key from the output of the encoder and applies it to the values computed from the input to the decoder

In order to achieve this, we introduce a new `mem` variable in forward. We will have to write the `DecodeBlock` in a slightly different way where we cannot use `x = lin(x) for lin in self.layers`

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, context_length, dmodel, h):
        super().__init__()
        assert dmodel % h == 0
        self.context_length = context_length
        self.dmodel = dmodel
        self.h = h
        self.dk = self.dmodel // self.h
        self.linears = nn.ModuleList([nn.Linear(dmodel, context_length) for _ in range(4)]) # One extra because of final linear in MHA
    def attention(self, query, key, value):
        dk = key.shape[-1]
        assert query.shape[-1] == key.shape[-1]
        assert key.shape[-2] == value.shape[-2]
        score = torch.softmax(query @ key.transpose(-2,-1), -2) / math.sqrt(dk)
        attention = score @ value
        return attention
    def forward(self, x, mem = None):
        if mem != None:
            query, key = [lin(mem).view(-1, self.h, self.context_length, self.dk) for lin in self.linears[:2]]
            value = self.linears[2](x).view(-1, self.h, self.context_length, self.dk)
        else:
            query, key, value = [lin(x).view(-1, self.h, self.context_length, self.dk) for lin in self.linears[:-1]]
        attentionScore = self.attention(query, key, value).view(-1, self.context_length, self.dmodel)
        return self.linears[-1](attentionScore)

class MLP(nn.Module):
    def __init__(self, dmodel, mlp_feature_dim):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(dmodel, mlp_feature_dim, bias=True),
            nn.ReLU(),
            nn.Linear(mlp_feature_dim, dmodel, bias=True)
                      ])
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class Sublayer(nn.Module):
    def __init__(self, module, p):
        super().__init__()
        self.module = module
        self.dropout = nn.Dropout(p)
        self.norm = nn.LayerNorm(dmodel) 
    def forward(self, x, mem = None):
        if mem != None:
            self.norm(x + self.dropout(self.module(x, mem)))
        return self.norm(x + self.dropout(self.module(x)))
    
class EncodeBlock(nn.Module):# This is currently un-normalised
    def __init__(self, context_length, dmodel, h, p):
        super().__init__()
        self.dmodel = dmodel
        self.h = h
        self.layers = nn.ModuleList([
            Sublayer(MultiHeadAttention(context_length, dmodel, h), p),
            Sublayer(MLP(dmodel, mlp_feature_dim), p)
        ])
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
    
def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class Encoder(nn.Module):
    def __init__(self, N, context_length, dmodel, h, p):
        super().__init__()
        self.layers = clones(EncodeBlock(context_length, dmodel, h, p), N)
    def forward(self,x):
        for layer in self.layers:
            x = layer(x)
        return x

In [None]:
class DecodeBlock(nn.Module):# This is currently un-normalised
    def __init__(self, context_length, dmodel, h, p):
        super().__init__()
        self.dmodel = dmodel
        self.h = h
        self.layers = nn.ModuleList([
            Sublayer(MultiHeadAttention(context_length, dmodel, h), p),
            Sublayer(MultiHeadAttention(context_length, dmodel, h), p),
            Sublayer(MLP(dmodel, mlp_feature_dim), p)
        ])
    def forward(self, x, mem):
        x = self.layers[0](x)      # The attention layer with everything on decoder input
        x = self.layers[1](x, mem) # The attention layer with Q and K on encoder output
        x = self.layers[2](x)      # FFNet
        return x
    
class Decoder(nn.Module):
    def __init__(self, N, context_length, dmodel, h, p):
        super().__init__()
        self.layers = clones(DecodeBlock(context_length, dmodel, h, p), N)
    def forward(self, x, mem):
        for layer in self.layers:
            x = layer(x, mem)
        return x

In [None]:
decoder = Decoder(N, context_length, dmodel, h, pdropout)

In [None]:
decoder(x, encoder(x))

We now have both, an encoder and a decoder. Time to put all of these together. We still have to implement as mask in the decoder. We can come back to it later.

In [None]:
decoder.layers

In [None]:
class PositionalEncoding:
    def __init__(self, context_length, n_embd):
        self.pos = torch.zeros((context_length, n_embd))
        self.position = torch.arange(0,context_length)
        self.divisionFactor = 10000**(-torch.arange(0, n_embd, 2)/32)
        self.pos[:, ::2] = torch.sin(self.position.view(-1,1) * self.divisionFactor)
        self.pos[:,1::2] = torch.cos(self.position.view(-1,1) * self.divisionFactor)
class Transformer(nn.Module):
    def __init__(self, context_length, vocab_dim, n_embd, dmodel, mlp_feature_dim, h, N, pdropout):
        super().__init__()
        self.penc = PositionalEncoding(context_length, n_embd)
        self.encoder = Encoder(N, context_length, dmodel, h, pdropout)
        self.decoder = Decoder(N, context_length, dmodel, h, pdropout)
        self.linear = nn.Linear(n_embd, vocab_dim)
        self.softmax = nn.Softmax()

In [None]:
vocab_dim = 128
transformer = Transformer(context_length, vocab_dim, n_embd, dmodel, mlp_feature_dim, h, N, pdropout)

In [None]:
tranformer.linear

Now let us write down the actual hyperparameters from the paper and see if the parameters are of the same order of magnitude as their paper

In [None]:
# HYPERS --------------------------
vocab_dim = 32000
context_length = 1024
n_embd = 512
dmodel = 512
mlp_feature_dim = 2048
h = 8
N = 6
pdropout = 0.1
# ---------------------------------

In [None]:
transformer = Transformer(context_length, vocab_dim, n_embd, dmodel, mlp_feature_dim, h, N, pdropout)

In [None]:
sum(p.numel() for p in transformer.parameters())

This is significantly higher than the expected 65 million. What's going wrong?

## Data