In [2]:
import torch
import torch.nn as nn

In [None]:


# some notes for learning
# basically this code provides the class for the components within the transformer and covers a few key details
# the first step is calling the constructor of the parent class (nn.Module) and then storing the two constructor arguments, embed_size and heads
# embed size refers to the vector representation of the embedded token
# heads refers to the number of attention heads; the embedding vector of each token is split into this many equal pieces, one piece per head
# head_dim refers to embed_size / heads, so embed_size must be exactly divisible by heads.



## the following texts will describe a single head of attention or self attention head

# keys, values and queries are important for the attention mechanisms
# query is like each token asking "which token is important for me to pay attention to?"
# the key represents each tokens "identity card", holding information about itself that it can share with others
# the value holds the actual content or information each token has to offer once the model decides to focus on it.

# to create these queries, keys and values vectors, we have to pass each embedded vector for the tokens through three separate linear transformations. the next stage is creating linear transformations for the key, queries and values
# each of these transformations takes the embedding vector and outputs a new vector for each word. 


# as an explanation difference between embeddings and the attention mechanism, imagine you have the word "mole" that has a different context depending on the sentence or the words surrounding it
# the embeddings vector representation would be the same across the different instances of the word or token
# however, attention is able to pull information from around the word using keys, queries and value calculations to shift the embedded vector representation into a new direction such that it can learn the new representations.

# each query is compared with every key by taking their dot product, which gives a matrix of scores (one row per query, one column per key). A large score means that key aligns well with that query.
# you want each row of scores to be a probability distribution, so every weight must lie between 0 and 1 and the weights in a row must sum to 1. This is done using softmax to normalise.

# masking is used where tokens are allowed to attend only to earlier positions in the sequence but not later ones, therefore, the later tokens are masked out. This is done by setting later tokens a very low attention score. 
# masking is applied after the query-key dot product but before the softmax: the raw scores of the later tokens are overwritten with a very large negative number, so the softmax gives them a weight of practically 0 and no information flows in from future tokens
# the reason for masking is for training, so that the model does not cheat and use future tokens to check on its accuracy. 


# the value matrix: If you have 2 related words from the query key matrix, for example "fluffy creature", you multiply the first embedding vector by the value matrix to get a value vector. 
# this value vector is added to the embedding of the second word to shift the embedding into the high dimensional space given the "fluffy" context. 
# that's a simplification; in reality, many tokens can influence the token you are looking at. Each token's value vector is multiplied by its attention weight (the softmax of the query-key score) and the weighted value vectors are added up
# the ones that have little relevance to the specific token you're looking at will have a weight close to 0
# the vectors are added up and the added value is the "direction" that the new head is going towards. 


# these heads of attention are then run in parallel into "multi headed attention" with its own distinct keys queries and values. 


In [None]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The last dimension of each input (``embed_size``) is split into ``heads``
    chunks of ``head_dim`` features. Every chunk is projected by the shared
    value/key/query linear layers, attention is computed independently per
    head, and the concatenated head outputs go through ``fc_out``.

    Args:
        embed_size: Size of the last dimension of the inputs and the output.
        heads: Number of attention heads; must divide ``embed_size`` exactly.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        # Earlier revisions stored the head count only under ``head`` while
        # forward() read ``heads``; keep the old name as an alias.
        self.head = heads
        self.head_dim = embed_size // heads

        assert(self.head_dim * heads == embed_size), "Embed size needs to be div by heads"

        # The same head_dim -> head_dim projection is shared by all heads,
        # because it is applied to the last axis after the per-head split.
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` positions to ``keys``/``values`` positions.

        Args:
            values: Tensor of shape ``(N, value_len, embed_size)``.
            keys: Tensor of shape ``(N, key_len, embed_size)``; ``key_len``
                must equal ``value_len``.
            query: Tensor of shape ``(N, query_len, embed_size)``.
            mask: ``None``, or a tensor broadcastable to
                ``(N, heads, query_len, key_len)``. Positions where the mask
                equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``(N, query_len, embed_size)``.
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads pieces of head_dim features.
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        # The query has its own length; it may differ from key_len.
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        # Learned projections (nn.Linear acts on the last axis).
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Dot product of every query with every key, per head:
        # (N, q, h, d) x (N, k, h, d) -> (N, h, q, k)
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        if mask is not None:
            # A huge negative score becomes a ~0 weight after the softmax.
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale by sqrt(d_k), the per-head key dimension, then normalise
        # over the key axis so each query's weights sum to 1.
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)

        # Weighted sum of values, then concatenate the heads:
        # (N, h, q, l) x (N, l, h, d) -> (N, q, h, d) -> (N, q, h * d)
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )

        return self.fc_out(out)



In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)

        self.norm1 = nn.