In [None]:
import torch 
import torch.nn as nn  

In [None]:
# Self-attention
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The embedding dimension is split into ``heads`` equally sized chunks of
    ``head_dim = embed_size // heads`` features (e.g. ``embed_size=256`` with
    ``heads=8`` gives ``head_dim=32``). Each head attends independently and
    the heads are concatenated and mixed by a final linear layer.

    Args:
        embed_size: Size of the last dimension of the inputs and of the output.
        heads: Number of attention heads; must divide ``embed_size``.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide
            ``embed_size``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        if heads < 1:
            raise ValueError("heads must be a positive integer")

        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads  # integer division
        if self.head_dim * heads != embed_size:
            raise ValueError("Embed size needs to be divisibile by heads")

        # Per-head projections: the same (head_dim -> head_dim) weights are
        # shared by every head because they act on the last dimension only.
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        # Mixes the concatenated heads back into the embedding space.
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        """Compute attention of ``query`` over ``keys``/``values``.

        Args:
            values: Tensor reshaped here to (N, value_len, heads, head_dim),
                so its last dimension must hold ``embed_size`` features.
            keys: Tensor reshaped here to (N, key_len, heads, head_dim).
                ``key_len`` must equal ``value_len`` because the attention
                weights over the keys are contracted with the values.
            query: Tensor reshaped here to (N, query_len, heads, head_dim).
            mask: ``None`` or a tensor broadcastable to
                (N, heads, query_len, key_len); positions where it equals 0
                are excluded from attention.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = query.shape[0]  # number of examples processed at once
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads pieces, then project each piece.
        values = self.values(values.reshape(N, value_len, self.heads, self.head_dim))
        keys = self.keys(keys.reshape(N, key_len, self.heads, self.head_dim))
        queries = self.queries(query.reshape(N, query_len, self.heads, self.head_dim))

        # queries: (N, query_len, heads, head_dim)
        # keys:    (N, key_len, heads, head_dim)
        # energy:  (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        if mask is not None:
            # A very large negative score becomes a zero weight after softmax.
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale by sqrt(d_k), the per-head key dimension, and normalize over
        # the key axis (dim 3 of energy).
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)

        # attention: (N, heads, query_len, key_len)
        # values:    (N, value_len, heads, head_dim)
        # "l" is the shared key/value length that is summed over.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values)

        # Flatten the last two dimensions back into the embedding size.
        out = out.reshape(N, query_len, self.heads * self.head_dim)

        return self.fc_out(out)


In [None]:
class TransformerBlock(nn.Module):
    """Attention sub-layer followed by a position-wise feed-forward sub-layer.

    Each sub-layer is wrapped as ``dropout(layer_norm(sublayer(x) + x))``,
    i.e. a residual (skip) connection, then layer normalization, then dropout.

    Args:
        embed_size: Feature size of the inputs and of the output.
        heads: Number of attention heads passed to ``SelfAttention``.
        dropout: Dropout probability applied after each normalization.
        forward_expansion: The hidden layer of the feed-forward network has
            ``forward_expansion * embed_size`` units.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        self.attention = SelfAttention(embed_size, heads)

        # Two layer normalizations, one per sub-layer. Unlike batch norm,
        # LayerNorm normalizes over the features of each example.
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            # Expand embed_size -> forward_expansion * embed_size ...
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            # ... and project back to the original dimension.
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Run the block.

        Args:
            value: Values passed to the attention layer.
            key: Keys passed to the attention layer.
            query: Queries passed to the attention layer; also the input of
                the first skip connection.
            mask: Attention mask or ``None``, forwarded to the attention layer.

        Returns:
            Tensor with the same shape as ``query``.
        """
        attention = self.attention(value, key, query, mask)
        # First skip connection: add the query back to the attention output.
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        # Second skip connection: add the feed-forward input back.
        out = self.dropout(self.norm2(forward + x))
        return out



In [None]:
class Encoder(nn.Module):
    """Transformer encoder: token + learned positional embeddings followed by
    a stack of ``num_layers`` ``TransformerBlock`` modules.

    Args:
        src_vocab_size: Number of rows of the token embedding table.
        embed_size: Embedding / model dimension.
        num_layers: Number of stacked transformer blocks.
        heads: Number of attention heads in every block.
        device: Stored on the instance as ``self.device``.
        forward_expansion: Expansion factor of each block's feed-forward net.
        dropout: Dropout probability.
        max_lenght: Number of rows of the positional embedding table, i.e.
            the longest sequence that can be encoded.
    """

    def __init__(self,
                 src_vocab_size,
                 embed_size,
                 num_layers,
                 heads,
                 device,
                 forward_expansion,
                 dropout,
                 max_lenght,
                 ):
        super().__init__()
        self.embed_size = embed_size
        self.device = device
        self.num_layers = num_layers
        self.max_length = max_lenght
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        # One learned vector per position in [0, max_lenght).
        self.positional_embedding = nn.Embedding(max_lenght, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size=embed_size,
                    heads=heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode a batch of token-index sequences.

        Args:
            x: 2-D integer tensor of shape (N, seq_length) holding token
                indices.
            mask: Attention mask or ``None``, forwarded to every block.

        Returns:
            Tensor of shape (N, seq_length, embed_size).

        Raises:
            ValueError: If ``seq_length`` exceeds ``max_lenght``.
        """
        N, seq_length = x.shape
        if seq_length > self.max_length:
            raise ValueError(
                f"Sequence length {seq_length} exceeds max_lenght {self.max_length}"
            )

        # Position indices 0..seq_length-1 for every example, created on the
        # same device as the token indices so both embedding lookups agree.
        positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)
        out = self.dropout(self.word_embedding(x) + self.positional_embedding(positions))

        for layer in self.layers:
            # Self-attention: value, key and query are all the same tensor.
            out = layer(out, out, out, mask)

        return out
        



In [None]:
class Decoder(nn.Module):
    """Decoder block: masked self-attention followed by a ``TransformerBlock``
    that attends over externally supplied values and keys.

    Args:
        embed_size: Embedding / model dimension.
        heads: Number of attention heads.
        forward_expansion: Expansion factor of the feed-forward network.
        dropout: Dropout probability.
        device: Accepted for interface compatibility; not used by this block.
    """

    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super().__init__()
        self.attention = SelfAttention(embed_size=embed_size, heads=heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(
            embed_size=embed_size,
            heads=heads,
            dropout=dropout,
            forward_expansion=forward_expansion,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        """Run the decoder block.

        Args:
            x: Decoder input; used as value, key and query of the first
                attention layer.
            value: Values for the second attention (inside the transformer
                block).
            key: Keys for the second attention.
            src_mask: Mask forwarded to the transformer block's attention.
            trg_mask: Mask applied to the self-attention over ``x``.

        Returns:
            Output of the transformer block.
        """
        attention = self.attention(x, x, x, trg_mask)
        # Skip connection around the self-attention; the result becomes the
        # query of the second attention.
        query = self.dropout(self.norm(attention + x))
        out = self.transformer_block(value, key, query, src_mask)

        return out