In [3]:
#https://www.youtube.com/watch?v=U0s0f995w14&list=PLHTWUxPL1q6llge6cVWIf4SpZiJd3cpfw&index=11&t=678s
import ipyplot
images_list = ["img/att1.png" , "img/att2-eq.png" , "img/att3.png"]
ipyplot.plot_images(images_list, max_images=10, img_width=300)



In [4]:
import torch
import torch.nn as nn

In [3]:
class SelfAttention(nn.Module):
    def __init__ (self, embed_size, heads):
        super(SelfAttention,self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads #integer division
        
        assert (self.head_dim * heads == embed_size  ), "Embedding size needs to be divisible by heads"
        
        #Value, Keys and Queries
        self.values = nn.Linear(self.head_dim,self.head_dim, bias = False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias = False)
        self.queries = nn.Linear(self.head_dim , self.head_dim , bias = False)
        self.fc_out = nn.Linear(self.heads * self.head_dim , self.embed_size) #heads x head dims -> embed size. this is to make more clear
        
    def forward(self, value,key,queries,mask):
        N = queries.shape[0] # how many example we will send at same time
        val_len , key_len , query_len = value.shape[1] , key.shape[1] , queries.shape[1]
        
        # split embeding into self.heads pieces
        value = value.reshape(N , val_len , self.heads , self.head_dim)
        key = key.reshape(N , key_len , self.heads , self.head_dim)
        queries = queries.reshape(N , query_len , self.heads , self.head_dim)
        
        energy = torch.einsum("NQHD , NKHD -> NHQK" , [queries,key])
        #energy shape -> (N, heads, query_len, key_len)
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20")) #if elemnt in mask is zero shut that off 
            #; mask is gone y trianggular matrix ; in paper they use minus inf. this code we set to very samll value
        attention = torch.softmax(energy  / ((self.embed_size)**(1/2)) , dim = 3 ) #dim 3 -> mean targeting key_len -> so if softmax = 0.8 we pay attention on source sentece / key len
        #attention shape (N, heads, query_len, key_len) #value shape : (N, val_len, heads, head_dim)
        # out shape -> (N, query len , heads, head_dim)
        out = torch.einsum("NHQL , NLHD -> NQHD" , [attention,value] ) #key len and val len is match
        out = out.reshape(N, query_len , self.heads*self.head_dim)
        out = self.fc_out(out)
        
        return out

In [4]:
class TransformersBlock(nn.Module):
    def __init__ (self, embed_size, heads, dropout, forward_expansion):
        super(TransformersBlock,self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size) #like batch norm, but it take the avg for every single example rather than atch
        self.norm2 = nn.LayerNorm(embed_size) 
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size , forward_expansion*embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion*embed_size , embed_size))
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        x = self.dropout(self.norm1(query + attention)) #add skip connection
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward +x)) #add skip connection 2
        return out
        
        

In [5]:
class Encoder(nn.Module):
    def __init__(self, src_vocab_size, embed_size, num_layers, heads, device, forward_expansion, dropout, max_length):
        super(Encoder,self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size , embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)
        
        self.layers = nn.ModuleList([
            TransformersBlock(
                embed_size,
                heads,
                dropout = dropout,
                forward_expansion = forward_expansion)])
        self.dropout = nn.Dropout(dropout)
    
    def forward(self,x ,mask):
        N, seq_length = x.shape
        positions = torch.arange(0 , seq_length).expand(N, seq_length).to(device)
        out = self.dropout(self.position_embedding(positions) + self.word_embedding(x))
        
        for layer in self.layers:
            out = layer(out,out,out,mask)
        return out
        

In [6]:
class DecoderBlock(nn.Module):
    def __init__(self, embed_size , heads, forward_expansion, dropout, device):
        super(DecoderBlock,self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformersBlock(embed_size, heads, dropout, forward_expansion)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x , value, key, src_mask , trg_mask):
        
        attention = self.attention(x,x,x,trg_mask)
        
        query = self.dropout(self.norm(attention+x))
        out = self.transformer_block(value, key ,query , src_mask)
        return out
        
        

In [7]:
class Decoder(nn.Module):
    def __init__(self, trg_vocab_size, embed_size, num_layers, heads, 
                 forward_expansion, dropout,device,max_length):
        super(Decoder,self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)
        
        self.db = DecoderBlock(embed_size , heads, forward_expansion, dropout, device)
        self.layers = nn.ModuleList([self.db for _ in range(num_layers)])
        
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout()
        
    def forward(self,x, enc_out, src_mask, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0,seq_length).expand(N,seq_length).to(device)
        x = self.dropout(self.position_embedding(positions) + self.word_embedding(x))
        
        for layer in self.layers:
            x = layer(x,enc_out,enc_out,src_mask,trg_mask)
        out = self.fc_out(x)
        return out

In [8]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, trg_vocab_size, 
                 src_pad_idx, trg_pad_idx, embed_size=256 , num_layers= 6, 
                forward_expansion = 4, heads = 8 , dropout = 0,
                device = "cuda" , max_length = 1000):
        super(Transformer,self).__init__()
        self.encoder = Encoder(src_vocab_size, embed_size, num_layers, heads, device, forward_expansion, dropout, max_length)
        self.decoder = Decoder(trg_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length)
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device
    def make_src_mask(self, src):
        src_mask = (src !=self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask.to(self.device)
    
    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len) )).expand(N,1,trg_len,trg_len)
    
    def forward(self,src,trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_out = self.encoder(src,src_mask)
        out = self.decoder(trg,enc_out , src_mask , trg_mask)
        return out
        
        

In [9]:
device = "cuda"
x = torch.tensor( [[1, 5, 6, 4, 3, 9, 5, 2, 0],[1, 8, 7, 3, 4,5, 6, 7, 2]]).to(device)
trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0,],[1, 5, 6, 2, 4, 7, 6, 2]]).to(device)
src_pad_idx = 0
trg_pad_idx = 0
src_vocab_size = 10
trg_vocab_size = 10
model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx).to(device)
#print(model)
print(trg[:, :-1])
out = model(x, trg[:, :-1])

tensor([[1, 7, 4, 3, 5, 9, 2],
        [1, 5, 6, 2, 4, 7, 6]], device='cuda:0')


In [10]:
import numpy
numpy.shape(out)

torch.Size([2, 7, 10])

In [7]:
device = "cuda"
def make_src_mask(src,src_pad_idx ):
    src_mask = (src !=src_pad_idx).unsqueeze(1).unsqueeze(2)
    return src_mask.to(device)
x = torch.tensor( [[1, 5, 6, 4, 3, 9, 5, 2, 0],[1, 8, 7, 3, 4,5, 6, 7, 2]]).to(device)

In [8]:
make_src_mask(x,10)

AttributeError: 'Tensor' object has no attribute 'src_pad_idx'