In [1]:
# importing required libraries
import torch.nn as nn
import torch
import torch.nn.functional as F
import math,copy,re
import warnings
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt


# Learned word-embedding table covering the whole vocabulary.
# Turns each input token id into a vector; the Encoder applies it once, before its first layer.
class Embedding(nn.Module):
    """Learned token-embedding lookup.

    Args:
        vocab_size: number of distinct token ids.
        embed_dim: length of each embedding vector.
    """

    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        # (vocab_size, embed_dim) table; row i is the vector for token id i
        self.embed = nn.Embedding(vocab_size, embed_dim)

    def forward(self, x):
        """Map integer token ids to vectors; the output gains a trailing embed_dim axis."""
        vectors = self.embed(x)
        return vectors
    
    
# generate and add positional encoding vectors to embedding vectors    
class PositionalEncoding(nn.Module):
    """Scale embeddings by sqrt(d) and add fixed sinusoidal position encodings.

    Uses the encoding from "Attention Is All You Need" (Vaswani et al., 2017):
        PE[pos, 2i]   = sin(pos / 10000^(2i / d))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d))

    Args:
        max_seq_len: longest sequence this module can encode.
        embed_model_dim: embedding width d; must equal the input's last dim.
    """

    def __init__(self, max_seq_len, embed_model_dim):
        super(PositionalEncoding, self).__init__()
        self.embed_dim = embed_model_dim

        # (max_seq_len, 1) column of positions
        positions = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)
        # 2i for every sin/cos pair; both columns of a pair share one frequency.
        # (The old loop used the raw column index in the exponent and swapped sin/cos.)
        pair_index = torch.arange(0, embed_model_dim, 2, dtype=torch.float32)
        angles = positions / (10000.0 ** (pair_index / embed_model_dim))

        pe = torch.zeros(max_seq_len, embed_model_dim)
        pe[:, 0::2] = torch.sin(angles)
        # an odd embed_model_dim has one fewer cos column than sin columns
        pe[:, 1::2] = torch.cos(angles[:, : embed_model_dim // 2])

        # (1, max_seq_len, d) so it broadcasts over the batch; as a buffer it
        # follows the module across devices but is not a trainable parameter
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x * sqrt(d) + PE[:seq_len]``.

        Args:
            x: embeddings; dim 1 is treated as the sequence axis and the last
                dim must be embed_model_dim (assumes (batch, seq_len, d)).

        Raises:
            ValueError: if x is longer than max_seq_len.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {self.pe.size(1)}"
            )
        # the buffer never requires grad, so no (deprecated) Variable wrapper is needed
        return x * math.sqrt(self.embed_dim) + self.pe[:, :seq_len]
    
    


In [2]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        embed_dim: model width; split evenly across ``heads``.
        heads: number of attention heads; must divide ``embed_dim``.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``heads``.
    """

    def __init__(self, embed_dim=512, heads=8):
        super(SelfAttention, self).__init__()
        # fail early with a clear message instead of an opaque reshape error later
        if embed_dim % heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by heads ({heads})"
            )
        self.embed_dim = embed_dim
        self.heads = heads
        self.head_dim = embed_dim // heads

        # learned projections for values, keys and queries
        self.values = nn.Linear(embed_dim, embed_dim)
        self.keys = nn.Linear(embed_dim, embed_dim)
        self.queries = nn.Linear(embed_dim, embed_dim)

        # output projection applied after the heads are concatenated
        self.fully_connected_layer = nn.Linear(embed_dim, embed_dim)

    def generate_tensors(self, values, keys, query, N):
        """Project the inputs and split them into heads.

        Args:
            values, keys, query: inputs whose last dim is embed_dim and whose
                dim 1 is the sequence length.
            N: batch size.

        Returns:
            (values, keys, queries), each (N, seq_len, heads, head_dim).
        """
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(query)

        values = values.reshape(N, values.shape[1], self.heads, self.head_dim)
        keys = keys.reshape(N, keys.shape[1], self.heads, self.head_dim)
        queries = queries.reshape(N, query.shape[1], self.heads, self.head_dim)

        return values, keys, queries

    def mat_mult(self, values, keys, query, mask, queries, N):
        """Attend with head-split tensors and merge the heads.

        Args:
            values, keys, queries: outputs of ``generate_tensors``.
            query: the raw query input; only its length (dim 1) is used.
            mask: optional; score positions where ``mask == 0`` are excluded.
                Must broadcast to (N, heads, query_len, key_len).
            N: batch size.

        Returns:
            Tensor of shape (N, query_len, embed_dim).
        """
        # (N, heads, query_len, key_len) similarity scores
        q_k = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        # masked positions get a huge negative score so softmax gives them ~0 weight
        if mask is not None:
            q_k = q_k.masked_fill(mask == 0, float("-1e20"))

        # scale by sqrt(d_k) with d_k the per-head width (was sqrt(embed_dim)),
        # then normalise over the key axis so each query's weights sum to 1
        attention = torch.softmax(q_k / math.sqrt(self.head_dim), dim=3)

        # weighted sum of the values for every query position and head
        attention = torch.einsum("nhql,nlhd->nqhd", [attention, values])

        # concatenate the heads back into one embed_dim vector per position
        attention = attention.reshape(N, query.shape[1], self.heads * self.head_dim)

        return self.fully_connected_layer(attention)

    def forward(self, values, keys, query, mask=None):
        """Attend from ``query`` over ``keys``/``values``; returns (N, query_len, embed_dim)."""
        values, keys, queries = self.generate_tensors(values, keys, query, query.shape[0])
        return self.mat_mult(values, keys, query, mask, queries, query.shape[0])

In [3]:
# Transformer block: connects the attention, normalization and feed-forward steps
class TransformerBlock(nn.Module):
    """Attention sub-layer followed by a position-wise feed-forward sub-layer.

    Each sub-layer is wrapped as residual -> LayerNorm -> dropout.

    Args:
        embed_dim: model width.
        expansion_factor: hidden width of the feed-forward net, as a multiple of embed_dim.
        heads: number of attention heads.
        dropout: dropout probability applied after each sub-layer.
    """

    def __init__(self, embed_dim, expansion_factor=4, heads=8, dropout=0.2):
        super(TransformerBlock, self).__init__()

        self.attention = SelfAttention(embed_dim, heads)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        # position-wise feed-forward network: expand, ReLU, project back
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, expansion_factor * embed_dim),
            nn.ReLU(),
            nn.Linear(expansion_factor * embed_dim, embed_dim),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, key, query, value, mask=None):
        """Apply the attention and feed-forward sub-layers.

        The inputs are passed positionally to ``self.attention`` as
        ``(key, query, value, mask)``. ``SelfAttention.forward`` is declared as
        ``(values, keys, query, mask)``, so ``key`` supplies the attention values,
        ``query`` the keys and ``value`` the queries. The attention output is
        therefore aligned with ``value``'s positions. In the encoder all three
        inputs are the same tensor.

        Returns:
            Tensor of shape (batch, len(value), embed_dim).
        """
        attention_out = self.attention(key, query, value, mask)
        # residual with the tensor the attention output is aligned to (the one used
        # as attention queries); adding `query` mismatched positions when inputs differ
        x = self.dropout(self.norm1(attention_out + value))

        fwd = self.feed_forward(x)
        # normalise the feed-forward residual; previously norm2(x) was returned,
        # discarding the feed-forward output entirely
        return self.dropout(self.norm2(fwd + x))




# connect transformer blocks to embedding/positional encoding to create the encoder
class Encoder(nn.Module):
    """Token embedding and positional encoding followed by a stack of TransformerBlocks.

    Args:
        seq_len: maximum sequence length for the positional encoding.
        vocab_size: number of distinct source tokens.
        embed_dim: model width.
        num_layers: number of stacked TransformerBlocks.
        expansion_factor: feed-forward hidden-size multiplier in each block.
        heads: attention heads per block.
    """

    def __init__(self, seq_len, vocab_size, embed_dim, num_layers=2, expansion_factor=4, heads=8):
        super().__init__()

        self.embedding_layer = Embedding(vocab_size, embed_dim)
        self.positional_encoder = PositionalEncoding(seq_len, embed_dim)

        blocks = [TransformerBlock(embed_dim, expansion_factor, heads) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, mask=None):
        """Encode token ids ``x``; each block receives the running output as all three inputs."""
        hidden = self.positional_encoder(self.embedding_layer(x))

        # only pass a mask when one was supplied, otherwise rely on each block's default
        extra_args = () if mask is None else (mask,)
        for block in self.layers:
            hidden = block(hidden, hidden, hidden, *extra_args)

        return hidden


In [4]:
# Decoder components, containerized.
# DecoderBlock builds on TransformerBlock for simplicity.
class DecoderBlock(nn.Module):
    """Decoder layer: masked self-attention (+ residual, LayerNorm, dropout), then a TransformerBlock.

    Args:
        embed_dim: model width.
        expansion_factor: feed-forward hidden-size multiplier for the inner TransformerBlock.
        heads: number of attention heads, used by both attention stages.
        dropout: dropout probability, used here and in the inner TransformerBlock.
    """

    def __init__(self, embed_dim, expansion_factor=4, heads=8, dropout=0.2):
        super(DecoderBlock, self).__init__()
        # honour `heads` (previously hard-coded to 8, silently ignoring the argument)
        self.attention = SelfAttention(embed_dim, heads=heads)
        self.norm = nn.LayerNorm(embed_dim)

        # forward `dropout` so the whole block uses one dropout probability
        self.transformer_block = TransformerBlock(embed_dim, expansion_factor, heads, dropout)

        self.dropout = nn.Dropout(dropout)

    def forward(self, key, query, x, mask=None):
        """Run masked self-attention on ``x``, then the inner TransformerBlock.

        Args:
            key, query: passed unchanged as the first two positional arguments of
                ``self.transformer_block``; presumably the encoder output, so confirm
                against the caller.
            x: sequence for self-attention. The causal ``mask`` is applied only here,
                so x is expected to be the target-side sequence.
            mask: optional attention mask for the self-attention step.

        Returns:
            Output of ``self.transformer_block(key, query, normalized_x)``.
        """
        attention = self.attention(x, x, x, mask=mask)

        # residual connection, then LayerNorm and dropout
        x = self.dropout(self.norm(attention + x))

        return self.transformer_block(key, query, x)

  
class Decoder(nn.Module):
    """Target embedding and positional encoding, a stack of DecoderBlocks, and a vocabulary projection.

    Args:
        target_vocab_size: number of distinct target tokens (size of the output distribution).
        embed_dim: model width.
        seq_len: maximum target length for the positional encoding.
        num_layers: number of DecoderBlocks.
        expansion_factor: feed-forward hidden-size multiplier.
        heads: attention heads per block.
        dropout: dropout probability for the embeddings and every DecoderBlock.
    """

    def __init__(self, target_vocab_size, embed_dim, seq_len, num_layers=2, expansion_factor=4, heads=8, dropout=0.2):
        super(Decoder, self).__init__()

        # target-side embedding, separate from the encoder's
        self.word_embedding = nn.Embedding(target_vocab_size, embed_dim)
        self.position_embedding = PositionalEncoding(seq_len, embed_dim)

        # forward `dropout` instead of letting each block fall back to its default
        self.layers = nn.ModuleList([
            DecoderBlock(embed_dim, expansion_factor=expansion_factor, heads=heads, dropout=dropout)
            for _ in range(num_layers)
        ])

        # projects every position to vocabulary logits
        self.fc = nn.Linear(embed_dim, target_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, mask=None):
        """Decode target token ids ``x`` against the encoder output ``enc_out``.

        Args:
            x: target token ids (assumes (batch, tgt_len)).
            enc_out: encoder output.
            mask: optional causal mask for the decoder self-attention.

        Returns:
            Probabilities whose last dim is target_vocab_size; softmax is taken over
            that last (vocabulary) dim, so every position sums to 1.
        """
        # embed, add positions, then dropout before the decoder blocks
        x = self.position_embedding(self.word_embedding(x))
        x = self.dropout(x)

        for layer in self.layers:
            # DecoderBlock.forward(key, query, x, mask) runs the masked self-attention
            # on its third argument, so the target-side states go there and the
            # encoder output fills the first two (previously enc_out was passed third)
            x = layer(enc_out, enc_out, x, mask)

        # normalise over the vocabulary, not over the batch (dim=0 was a bug)
        return torch.nn.functional.softmax(self.fc(x), dim=-1)

    

# connect it all together finally    
class Transformer(nn.Module):
    """Encoder-decoder Transformer built from the Encoder and Decoder classes.

    Args:
        embed_dim: model width shared by encoder and decoder.
        src_vocab_size: number of distinct source tokens.
        target_vocab_size: number of distinct target tokens.
        seq_length: maximum sequence length for both positional encodings.
        num_layers, expansion_factor, heads: forwarded to Encoder and Decoder.
    """

    def __init__(self, embed_dim, src_vocab_size, target_vocab_size, seq_length, num_layers=2, expansion_factor=4, heads=8):
        super(Transformer, self).__init__()

        self.encoder = Encoder(seq_length, src_vocab_size, embed_dim, num_layers=num_layers, expansion_factor=expansion_factor, heads=heads)
        self.decoder = Decoder(target_vocab_size, embed_dim, seq_length, num_layers=num_layers, expansion_factor=expansion_factor, heads=heads)

    def generate_mask(self, target):
        """Build a causal mask for ``target`` of shape (N, L).

        Returns:
            (N, 1, L, L) lower-triangular tensor of ones/zeros: position q may
            attend to positions k <= q. Created on ``target``'s device.
        """
        N, l = target.shape
        causal = torch.tril(torch.ones((l, l), device=target.device))
        return causal.expand(N, 1, l, l)

    def run_transformer(self, source, target):
        """Greedily predict one token per source position.

        Args:
            source: source token ids, (1, src_len).
            target: initial decoder input, e.g. ``[[sos]]``, (1, tgt_len).

        Returns:
            List of ``source.shape[1]`` predicted token ids (Python ints).

        Only batch size 1 is supported (``Tensor.item`` is used). After the first
        step, only the previously predicted token is fed back to the decoder.
        NOTE(review): standard autoregressive decoding would append each token to
        the prefix instead; that needs positional capacity for the longer input.
        """
        predictions = []
        # inference only: no autograd graph is needed to pick argmax tokens
        with torch.no_grad():
            encoded = self.encoder(source)
            out = target
            # iterate over `source` (previously the undefined name `src`)
            for _ in range(source.shape[1]):
                # rebuild the mask for the current decoder input; once `out` becomes
                # (1, 1), a mask built from the original `target` no longer fits
                mask = self.generate_mask(out)
                probs = self.decoder(out, encoded, mask)
                # most likely token at the last position, shape (N,)
                next_token = probs[:, -1, :].argmax(-1)
                predictions.append(next_token.item())
                # (N,) -> (N, 1): decode it as a length-1 sequence next round
                out = next_token.unsqueeze(1)
        return predictions

    def forward(self, source, target):
        """Teacher-forced pass: encode ``source`` and decode all of ``target`` under a causal mask."""
        encoded = self.encoder(source)
        mask = self.generate_mask(target)
        return self.decoder(target, encoded, mask)




In [5]:
# example run through
import random

# seed both RNGs so the example tokens and the model's initial weights are reproducible
random.seed(0)
torch.manual_seed(0)

# let 0 be sos token and 1 be eos token; ids 2..VOCAB_SIZE-1 are ordinary words
SOS_TOKEN, EOS_TOKEN = 0, 1
BATCH_SIZE, SEQ_LEN, VOCAB_SIZE = 2, 12, 11


def random_sentence():
    """Return [SOS, SEQ_LEN - 2 random word ids in [2, VOCAB_SIZE - 1], EOS]."""
    words = [random.randint(2, VOCAB_SIZE - 1) for _ in range(SEQ_LEN - 2)]
    return [SOS_TOKEN] + words + [EOS_TOKEN]


source = torch.tensor([random_sentence() for _ in range(BATCH_SIZE)])
target = torch.tensor([random_sentence() for _ in range(BATCH_SIZE)])

# create transformer
model = Transformer(embed_dim=512, src_vocab_size=VOCAB_SIZE,
                    target_vocab_size=VOCAB_SIZE, seq_length=SEQ_LEN,
                    num_layers=6, expansion_factor=4, heads=8)

out = model(source, target)
# one distribution over the target vocabulary per batch item and position
assert out.shape == (BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
model




Transformer(
  (encoder): Encoder(
    (embedding_layer): Embedding(
      (embed): Embedding(11, 512)
    )
    (positional_encoder): PositionalEncoding()
    (layers): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): SelfAttention(
          (values): Linear(in_features=512, out_features=512, bias=True)
          (keys): Linear(in_features=512, out_features=512, bias=True)
          (queries): Linear(in_features=512, out_features=512, bias=True)
          (fully_connected_layer): Linear(in_features=512, out_features=512, bias=True)
        )
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (feed_forward): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): ReLU()
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (dropout): Dropout(p=0.2, inplace=False)
      )
    )
  )
  (decoder): Dec