# Contents
[Transformer Top Level](#transformertop)

> [Embedding and Positional Encoding](#emb)

> [Encoder](#enc)

> [Decoder](#dec)

> [Multi-Head Attention](#mha)

> [Inference Functions](#inffunc)

[Data Management](#data)

> [Multi30k Dataset](#m30k)

> [Vocabularies](#vocab)

> [Datapipe Collation and Masking](#mask)

[Training Loop](#train)

> [Top Level](#epochs)

> [Optimization](#opt)

> [Manual Loss](#loss)

[Run Model](#run)

> [Initialization](#init)

> [Training](#modeltrain)

> [Inference](#modelinf)

[Evaluation](#eval)

In [None]:
import copy
import datetime
import math
import os
import string
import time
from functools import partial

import spacy
import torch

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchtext.data.metrics import bleu_score
from torchtext.vocab import build_vocab_from_iterator

# TRANSFORMER TOP LEVEL<a class="anchor" id="transformertop"></a>

Note: Docstrings will use the following dimensional nomenclature throughout this notebook

B  = batch size <br>
L  = sequence length (Ls for source length and Lt for target where applicable) <br>
E  = embedding dimension <br>
H  = number of attention heads <br>
V  = vocabulary size (Vs for the source vocabulary and Vt for the target vocabulary where applicable)<br>

In [None]:
class Transformer(nn.Module):
    """
    The top-level class defining a transformer

    Attributes:
        dmodel (int): The dimensionality of the transformer's embeddings
        pedrop (torch Module): A dropout layer following positional encoding
        embedding_en (torch Module): The input embedding layer for English sentences to be translated
        embedding_de (torch Module): The input embedding layer for German translations
        encoder (torch Module): The top-level class implementing the encoder stack
        decoder (torch Module): The top-level class implementing the decoder stack
        generator (torch Module): The linear layer following the decoder; outputs to softmax. Its weight is tied to embedding_de
    """

    def __init__(self, vocab_en, vocab_de, dmodel=512, dff=2048, num_heads=8, num_layers=6, dropout=0.1):
        """
        Transformer initialization function
        Default values of input arguments follow "Attention is All You Need" (AIAYN)

        Args:
            vocab_en: The English vocabulary; only its length is used here, to size the source embedding
            vocab_de: The German vocabulary; only its length is used here, to size the target embedding and the generator
            dmodel (int): The dimensionality of the transformer's embeddings (must be even)
            dff (int): The dimensionality of feed-forward layers within the encoder and decoder
            num_heads (int): The number of attention heads used in both self- and cross-attention layers
            num_layers (int): The number of times encoder and decoder layers are repeated within their top-level stacks
            dropout (float): The value used for dropout layers throughout the transformer
        """

        super().__init__()

        assert dmodel % 2 == 0, 'The embedding dimension must be an even number.'

        self.dmodel = dmodel

        self.pedrop = nn.Dropout(dropout) # For use after the positional encoding
        self.embedding_en = ScaledEmbedding(len(vocab_en), dmodel)
        self.embedding_de = ScaledEmbedding(len(vocab_de), dmodel)
        self.encoder = Encoder(dmodel, dff, num_heads, num_layers, dropout)
        self.decoder = Decoder(dmodel, dff, num_heads, num_layers, dropout)
        self.generator = nn.Linear(dmodel, len(vocab_de))
        self.generator.weight = self.embedding_de.embedding.weight # Ties the final linear layer's weights to those of the target embedding layer

    def forward(self, batch_en, mask_en, batch_de, mask_de):
        """
        Transformer forward step

        Args:
            batch_en (torch Tensor): A 2D tensor of dim = (B,Ls) made up of English sentences, tokenized and indexed to a vocabulary
            mask_en (torch Tensor): A 4D training mask of dim = (B,1,1,Ls) used to ignore sentence padding in attention calculations, or None for no masking
            batch_de (torch Tensor): A 2D tensor of dim = (B,Lt) made up of German translations, tokenized and indexed to a vocabulary
            mask_de (torch Tensor): A 4D training mask of dim = (B,1,Lt,Lt) used in attention calculations, made up of a padding mask and an upper triangular look-ahead mask, or None for no masking

        Returns:
            A tensor of dim (B,Lt,Vt) containing the pre-softmax outputs of the final linear layer
        """

        seq_len_en = batch_en.size(1)
        seq_len_de = batch_de.size(1)

        x_enc = self.embedding_en(batch_en) # Convert incoming English batch to embedding
        # The positional encoding is created on the default device, so move it next to the embeddings before adding
        x_enc = x_enc + positional_encoding_matrix(seq_len_en, self.dmodel).to(x_enc.device)
        x_enc = self.pedrop(x_enc)
        x_enc = self.encoder(x_enc, mask_en) # Pass embedding with positional encoding and mask to encoder

        x = self.embedding_de(batch_de) # Convert incoming German batch to embedding
        x = x + positional_encoding_matrix(seq_len_de, self.dmodel).to(x.device)
        x = self.pedrop(x)
        x = self.decoder(x, x_enc, mask_en, mask_de) # Pass embedding with positional encoding and mask to decoder along with encoder output

        return self.generator(x)

    def translate(self, sentence, tokenizer_en, vocab_en, tokenizer_de, vocab_de, max_sentence_len=100):
        """
        Translates a single English sentence with greedy decoding

        The target sequence starts as <SoS> and grows by one token per forward pass until <EoS> is
        generated or max_sentence_len tokens have been produced. The module is put in eval mode for
        the duration of the call and its previous train/eval mode is restored afterwards.

        Currently, this function only supports string-to-string translation, though I may add list handling for batches of strings in the future

        Args:
            sentence (str): The English sentence to be translated
            tokenizer_en: An instance of spacy.lang.en.English
            vocab_en (torchtext Vocab): The English language vocabulary
            tokenizer_de: An instance of spacy.lang.de.German
            vocab_de (torchtext Vocab): The German language vocabulary
            max_sentence_len (int): The maximum number of tokens generated for the translation

        Returns:
            The English sentence's German translation, with <SoS> and <EoS> tokens removed ('' if nothing else was generated)
        """

        was_training = self.training # Remembered so the caller's mode is restored on every exit path
        self.eval()

        device = self.generator.weight.device
        eos_index = vocab_de(['<EoS>'])[0]

        source = pre_process(sentence, tokenizer_en, vocab_en).to(device)
        output = pre_process('', tokenizer_de, vocab_de, source=False).to(device) # Just the <SoS> token

        try:
            with torch.no_grad(): # Inference only, no need to track gradients
                for _ in range(max_sentence_len):

                    fwdpass = self.forward(source, None, output, None)

                    log_probs = F.log_softmax(fwdpass, dim=2)

                    next_token = gen_next_token(log_probs) # Find the token with the highest log probability

                    output = torch.cat((output, next_token), 1)

                    # Break loop if <EoS> is generated
                    if next_token[0, 0].item() == eos_index:
                        break
        finally:
            self.train(was_training)

        itos = vocab_de.get_itos() # Built once rather than once per generated token
        tokens_str = [itos[idx] for idx in output[0].tolist()] # Look up tokens in vocabulary

        # Remove <SoS> and <EoS> tokens and convert the list to a string
        return ' '.join(tok for tok in tokens_str if tok not in ('<SoS>', '<EoS>'))

    def reset(self):
        """
        Re-initializes every parameter with more than one dimension using Xavier uniform initialization

        One-dimensional parameters (such as biases and LayerNorm weights) are left untouched
        """

        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

## Embedding and Positional Encoding <a class="anchor" id="emb"></a>

In [None]:
class ScaledEmbedding(nn.Module):
    """
    Token embedding whose output is multiplied by sqrt(dmodel), as described in AIAYN

    Attributes:
        dmodel (int): The dimensionality of the embedding
        embedding (torch Module): The underlying PyTorch Embedding layer
    """

    def __init__(self, vocab_len, dmodel):
        """
        Builds the embedding table

        Args:
            vocab_len (int): The number of entries in the vocabulary being embedded
            dmodel (int): The dimensionality of each embedding vector
        """

        super().__init__()

        self.embedding = nn.Embedding(num_embeddings=vocab_len, embedding_dim=dmodel)
        self.dmodel = dmodel

    def forward(self, x):
        """
        Looks up the embeddings and applies the sqrt(dmodel) multiplier

        Args:
            x (torch Tensor): Batch of sentences, tokenized and indexed to the vocabulary, dim = (B,L)

        Returns:
            A tensor of dim (B,L,E) containing the scaled embedding of the input batch
        """

        scale = math.sqrt(self.dmodel)
        looked_up = self.embedding(x)

        return looked_up * scale

In [None]:
def positional_encoding_matrix(seq_len, dmodel, denom_base=10000., device=None):
    """
    Creates a sinusoidal positional encoding matrix to be added to an input sequence

    Even embedding indices hold sin(pos / denom_base^(i/dmodel)) and odd indices hold the matching
    cosine, where i is the even index of the pair. An odd dmodel is supported: the last (even)
    column then holds a sine with no cosine partner.

    Args:
        seq_len (int): The length of the input sequence
        dmodel (int): The dimensionality of the word embedding
        denom_base (float): A number that appears as the base of an exponential in the denominator of the encoding angle
        device (torch device, optional): Where to create the encoding; None uses the default device

    Returns:
        The positional encoding as a float tensor of dim=(1,L,E); the leading dimension broadcasts across the batch
    """

    positions = torch.arange(seq_len, dtype=torch.float32, device=device).unsqueeze(1) # dim=(L,1)

    # One inverse frequency per sin/cos pair, dim=(ceil(E/2),)
    exponents = torch.arange(0, dmodel, 2, dtype=torch.float32, device=device) / dmodel
    inv_freq = 1.0 / torch.pow(denom_base, exponents)

    angles = positions * inv_freq # dim=(L, ceil(E/2))

    encmat = torch.zeros(seq_len, dmodel, device=device)
    encmat[:, 0::2] = torch.sin(angles)
    encmat[:, 1::2] = torch.cos(angles[:, :dmodel // 2]) # One fewer cosine column than sine columns when dmodel is odd

    return encmat.unsqueeze(0)

## ENCODER<a class="anchor" id="enc"></a>

In [None]:
class Encoder(nn.Module):
    """
    The encoder stack: a sequence of independently-initialized EncoderLayer modules

    Attributes:
        encoder_layers (torch ModuleList): A PyTorch ModuleList containing 'num_layers' EncoderLayer modules
    """

    def __init__(self, dmodel, dff, num_heads, num_layers, dropout):
        """
        Builds the stack of encoder layers

        Args:
            dmodel (int): The dimensionality of the transformer's embeddings
            dff (int): The dimensionality of encoder feed-forward layers
            num_heads (int): The number of heads used in self-attention layers
            num_layers (int): The number of encoder layers in the stack
            dropout (float): The probability provided to each encoder layer's dropout sublayer
        """

        super().__init__()

        # Each layer is constructed separately, so no two layers share parameters
        stack = [EncoderLayer(dmodel, dff, num_heads, dropout) for _ in range(num_layers)]
        self.encoder_layers = nn.ModuleList(stack)

    def forward(self, x, mask):
        """
        Runs the input through every encoder layer in order

        Args:
            x (torch Tensor): The transformer's English source input tensor, dim=(B,Ls,E)
            mask (torch Tensor): The training mask to be applied; expected to be a padding mask of dim = (B,1,1,Ls) where 1=padding

        Returns:
            The encoder output, a tensor of dim=(B,Ls,E) to be used in decoder cross-attention calculations
        """

        hidden = x
        for enc_layer in self.encoder_layers:
            hidden = enc_layer(hidden, mask)

        return hidden

In [None]:
class EncoderLayer(nn.Module):
    """
    One encoder layer (self-attention followed by a position-wise feed-forward network), used by the Encoder class

    Attributes:
        MHA (torch Module): The multi-head self-attention layer
        ff1 (torch Module): The first of the position-wise feed-forward layers (expands dmodel -> dff)
        ff2 (torch Module): The second of the position-wise feed-forward layers (reduces dff -> dmodel)
        drop (torch Module): The dropout layer applied to each sublayer's output before the residual addition
        lnormMHA (torch Module): A PyTorch LayerNorm module applied to the input of the attention sublayer
        lnormff (torch Module): A PyTorch LayerNorm module applied to the input of the feed-forward sublayer
    """

    def __init__(self, dmodel, dff, num_heads, dropout):
        """
        Builds the sublayers

        Args:
            dmodel (int): The dimensionality of the transformer's embeddings
            dff (int): The dimensionality of encoder feed-forward layers
            num_heads (int): The number of heads used in self-attention layers
            dropout (float): The dropout probability provided to self.drop
        """

        super().__init__()

        self.MHA = ManualMHA(dmodel, num_heads)
        self.ff1 = nn.Linear(dmodel, dff)
        self.ff2 = nn.Linear(dff, dmodel)
        self.drop = nn.Dropout(dropout)
        self.lnormMHA = nn.LayerNorm(dmodel)
        self.lnormff = nn.LayerNorm(dmodel)

    def forward(self, x, mask):
        """
        Forward step

        Each sublayer normalizes its input first and adds the residual afterwards (norm -> sublayer -> dropout -> add).
        This differs from the add-then-norm sequence of AIAYN; the correct ordering seems to be a matter of open
        discussion and this is what worked well empirically.

        Args:
            x (torch Tensor): Input - English source tensor or the output of a previous encoder layer, dim=(B,Ls,E)
            mask (torch Tensor): The training mask to be applied; expected to be a padding mask of dim = (B,1,1,Ls) where 1=padding

        Returns:
            Encoder sublayer output - a tensor of dim=(B,Ls,E) to be sent to the following encoder layer or to the decoder
        """

        # Self-attention sublayer
        attn_in = self.lnormMHA(x)
        attn_out = self.MHA(attn_in, attn_in, attn_in, mask=mask)
        x = x + self.drop(attn_out)

        # Position-wise feed-forward sublayer
        ff_hidden = F.relu(self.ff1(self.lnormff(x)))
        ff_out = self.ff2(ff_hidden)

        return x + self.drop(ff_out)

## DECODER<a class="anchor" id="dec"></a>

In [None]:
class Decoder(nn.Module):
    """
    The decoder stack: a sequence of independently-initialized DecoderLayer modules

    Attributes:
        decoder_layers (torch ModuleList): A PyTorch ModuleList containing 'num_layers' DecoderLayer modules
    """

    def __init__(self, dmodel, dff, num_heads, num_layers, dropout):
        """
        Builds the stack of decoder layers

        Args:
            dmodel (int): The dimensionality of the transformer's embeddings
            dff (int): The dimensionality of decoder feed-forward layers
            num_heads (int): The number of heads used in self- and cross-attention layers
            num_layers (int): The number of decoder layers in the stack
            dropout (float): The probability provided to each decoder layer's dropout sublayer
        """

        super().__init__()

        # Each layer is constructed separately, so no two layers share parameters
        stack = [DecoderLayer(dmodel, dff, num_heads, dropout) for _ in range(num_layers)]
        self.decoder_layers = nn.ModuleList(stack)

    def forward(self, x, output_enc, mask_enc, mask_dec):
        """
        Runs the target input through every decoder layer in order; each layer receives the same encoder output

        Args:
            x (torch Tensor): The transformer's German target input tensor, dim=(B,Lt,E)
            output_enc (torch Tensor): The output of the encoder, to be used in cross-attention, dim=(B,Ls,E)
            mask_enc (torch Tensor): The mask to be applied to the encoder output, expected to be a padding mask of dim = (B,1,1,Ls) where 1=padding
            mask_dec (torch Tensor): The mask to be applied to the target, expected to be of dim = (B,1,Lt,Lt) where 1=masked

        Returns:
            The decoder output, a tensor of dim=(B,Lt,E) to be passed to the final linear and softmax layers
        """

        hidden = x
        for dec_layer in self.decoder_layers:
            hidden = dec_layer(hidden, output_enc, mask_enc, mask_dec)

        return hidden

In [None]:
class DecoderLayer(nn.Module):
    """
    One decoder layer (self-attention, encoder cross-attention, then a position-wise feed-forward network), used by the Decoder class

    Attributes:
        MHA_self (torch Module): The multi-head self-attention layer
        MHA_cross (torch Module): The multi-head encoder cross-attention layer
        ff1 (torch Module): The first of the position-wise feed-forward layers (expands dmodel -> dff)
        ff2 (torch Module): The second of the position-wise feed-forward layers (reduces dff -> dmodel)
        drop (torch Module): The dropout layer applied to each sublayer's output before the residual addition
        lnormMHA (torch Module): A PyTorch LayerNorm module applied to the input of both attention sublayers
        lnormff (torch Module): A PyTorch LayerNorm module applied to the input of the feed-forward sublayer
    """

    def __init__(self, dmodel, dff, num_heads, dropout):
        """
        Builds the sublayers

        Args:
            dmodel (int): The dimensionality of the transformer's embeddings
            dff (int): The dimensionality of decoder feed-forward layers
            num_heads (int): The number of heads used in self- and cross-attention layers
            dropout (float): The dropout probability provided to self.drop
        """

        super().__init__()

        self.MHA_self = ManualMHA(dmodel, num_heads)
        self.MHA_cross = ManualMHA(dmodel, num_heads)
        self.ff1 = nn.Linear(dmodel, dff)
        self.ff2 = nn.Linear(dff, dmodel)
        self.drop = nn.Dropout(dropout)
        self.lnormMHA = nn.LayerNorm(dmodel)
        self.lnormff = nn.LayerNorm(dmodel)

    def forward(self, x, output_enc, mask_enc, mask_dec):
        """
        Forward step; every sublayer follows the sequence norm -> sublayer -> dropout -> residual addition

        Args:
            x: Input - German target tensor or the output of a previous decoder layer, dim=(B,Lt,E)
            output_enc: The output of the encoder, used as keys and values in cross-attention, dim=(B,Ls,E)
            mask_enc: The mask to be applied to the encoder output, expected to be a padding mask of dim = (B,1,1,Ls) where 1=padding
            mask_dec: The mask to be applied to the target, expected to be of dim = (B,1,Lt,Lt) where 1=masked

        Returns:
            Decoder sublayer output - a tensor of dim=(B,Lt,E) sent to the following decoder layer or to the final linear and softmax layers
        """

        # Masked self-attention over the target sequence
        self_in = self.lnormMHA(x)
        self_out = self.MHA_self(self_in, self_in, self_in, mask=mask_dec)
        x = x + self.drop(self_out)

        # Cross-attention: queries come from the decoder, keys and values from the encoder output
        # NOTE(review): the same lnormMHA instance normalizes the input of both attention sublayers -
        # confirm whether this sharing is intended before introducing a separate LayerNorm (it would change state_dict keys)
        cross_in = self.lnormMHA(x)
        cross_out = self.MHA_cross(cross_in, output_enc, output_enc, mask=mask_enc)
        x = x + self.drop(cross_out)

        # Position-wise feed-forward sublayer
        ff_hidden = F.relu(self.ff1(self.lnormff(x)))
        ff_out = self.ff2(ff_hidden)

        return x + self.drop(ff_out)

## MULTI-HEAD ATTENTION<a class="anchor" id="mha"></a>

In [None]:
class ManualMHA(nn.Module):
    """
    Implements Multi-head Attention (MHA), used in both self- and cross-attention calculations within the encoder and decoder

    Attributes:
        dmodel (int): The dimensionality of the transformer's embeddings
        num_heads (int): The number of attention heads
        wQ (torch Module): A PyTorch Linear module with weights mapping from input to queries, to be split across heads
        wK (torch Module): A PyTorch Linear module with weights mapping from input to keys, to be split across heads
        wV (torch Module): A PyTorch Linear module with weights mapping from input to values, to be split across heads
        wO (torch Module): The matrix mapping from the concatenated outputs of all attention heads to the final MHA output value
    """

    def __init__(self, dmodel, num_heads):
        """
        Initialization function

        Args:
            dmodel (int): The dimensionality of the transformer's embeddings; must be divisible by num_heads
            num_heads (int): The number of attention heads
        """

        super().__init__()

        assert dmodel % num_heads == 0, 'The embedding dimension must be divisible by the number of heads.'

        self.dmodel = dmodel
        self.num_heads = num_heads

        self.wQ = nn.Linear(dmodel, dmodel, bias=False) # The concatenation of QKV mappings across all heads, to be split up in the forward pass
        self.wK = nn.Linear(dmodel, dmodel, bias=False)
        self.wV = nn.Linear(dmodel, dmodel, bias=False)

        self.wO = nn.Linear(dmodel, dmodel, bias=False) # The concatenation of attention mappings from all heads, to be provided to the output

    def forward(self, xQ, xK, xV, mask=None):
        """
        Forward step, performs the scaled dot-product attention calculation
        If xQ == xK == xV, this is self-attention

        Args:
            xQ: The input to be multiplied by wQ to create the query tensor, dim=(B,Lq,E)
            xK: The input to be multiplied by wK to create the key tensor, dim=(B,Lkv,E)
            xV: The input to be multiplied by wV to create the value tensor, dim=(B,Lkv,E)
            mask: Optional tensor broadcastable to (B,H,Lq,Lkv) where 1 marks positions to ignore; it is scaled by 1e9 and
                subtracted from the softmax argument. Dimensionality differs according to self- and cross-attention

        Returns:
            A tensor of attention values to be sent to the next sublayer of the encoder or decoder, dim=(B,Lq,E)
        """

        batch_size = xQ.size(0)
        len_q = xQ.size(1)
        len_kv = xK.size(1) # Equal to len_q for self-attention calculations, but not necessarily for cross-attention

        d_head = self.dmodel // self.num_heads

        # Project, then split the embedding dimension across heads: (B,L,E) -> (B,H,L,E/H)
        Q = self.wQ(xQ).reshape(batch_size, len_q, self.num_heads, d_head).permute(0, 2, 1, 3)
        K = self.wK(xK).reshape(batch_size, len_kv, self.num_heads, d_head).permute(0, 2, 1, 3)
        V = self.wV(xV).reshape(batch_size, len_kv, self.num_heads, d_head).permute(0, 2, 1, 3)

        sm_arg = torch.matmul(Q, K.transpose(2, 3)) / math.sqrt(d_head) # Argument to softmax function, dim=(B,H,Lq,Lkv)

        if mask is not None:
            # Tensor.to returns a new tensor rather than moving in place, so the result must be kept
            mask = mask.to(sm_arg.device)
            sm_arg = sm_arg - mask * 1e9 # Masked positions end up with (numerically) zero weight after the softmax

        sm = F.softmax(sm_arg, dim=-1) # Takes softmax across rows of the Q, K^T matrix product

        attn = torch.matmul(sm, V) # Attention matrix, split across heads, dim=(B,H,Lq,E/H)

        attn_cat = torch.flatten(attn.permute(0, 2, 1, 3), start_dim=-2) # Concatenate the values for all heads, dim=(B,Lq,E)

        return self.wO(attn_cat)

## INFERENCE FUNCTIONS<a class="anchor" id="inffunc"></a>

In [None]:
def pre_process(sentence, tokenizer, vocab, source=True):
    """
    Converts an incoming sentence to a set of vocabulary-indexed tokens for use in inference.

    Args:
        sentence (str): The source sentence to be translated; ignored (and may be None) if source=False
        tokenizer: An instance of spacy.lang.en.English or spacy.lang.de.German; only its .tokenizer attribute is used, and only when source=True
        vocab (torchtext Vocab): The vocabulary for use in indexing; called with a list of token strings
        source (boolean): If True, the sentence is stripped of punctuation, lower-cased, tokenized, wrapped in <SoS>/<EoS> and indexed. If False, only the <SoS> token is indexed.

    Returns:
        An integer tensor of dim=(1,L) holding the indexed token values
    """

    if source:
        pstrip = str.maketrans('', '', string.punctuation) # Used to strip punctuation from incoming sentences
        cleaned = sentence.rstrip('\n').translate(pstrip).lower()
        tokens = ['<SoS>'] + [tok.text for tok in tokenizer.tokenizer(cleaned)] + ['<EoS>']
    else:
        # A target sequence being decoded starts out as just the start-of-sentence token
        tokens = ['<SoS>']

    tokens_indexed = vocab(tokens)

    # Add a leading batch dimension of size 1 for agreement with the transformer's input format
    return torch.tensor(tokens_indexed, dtype=torch.int).unsqueeze(0)

In [None]:
def gen_next_token(probs):
    """
    Greedy decoding step: picks the highest-scoring vocabulary index at the last sequence position

    Args:
        probs (torch Tensor): The transformer's output scores for a single sentence, dim=(1,L,Vt)

    Returns:
        A tensor of dim=(1,1) holding the index of the highest-scoring next token
    """

    last_position = probs[0, -1, :]
    best_index = last_position.argmax()

    # Shape the scalar index as (1,1) for agreement with the transformer's (B,L) token format
    return best_index.reshape(1, 1)

# DATA MANAGEMENT<a class="anchor" id="data"></a>

I'll be using Multi30k for English-to-German translation. This isn't a very large dataset and, since the aim of this project is familiarizing myself with transformer structure, I'll just load it all into memory rather than trying to set up lazy loading. I may revisit in the future.

There is a pre-built PyTorch Multi30k datapipe, but for the sake of understanding I'll create and organize the dataset myself in the ManualMulti30k class.


## MULTI30K DATASET<a class="anchor" id="m30k"></a>

In [None]:
class ManualMulti30k(Dataset):
    """
    Creates a Dataset from a subset (train, validation, or test) of Multi30k data.

    Sentences are loaded fully into memory, stripped of ASCII punctuation and lower-cased.

    Attributes:
        data_en (list): A list of English sentences
        data_de (list): A list of German translations that match the sentences of data_en
    """

    def __init__(self, data_dir, data_spec):
        """
        Given a source directory and a specified subset of data, initializes a ManualMulti30k object

        Args:
            data_dir (str): A directory containing the plaintext files '<data_spec>.en' and '<data_spec>.de', with one sentence per line and matching line order
            data_spec (str): Accepts 'train', 'val', or 'test' (case-insensitive)
        """

        data_spec = data_spec.lower()
        assert data_spec in ('train', 'val', 'test'), "data_spec must be 'train', 'val', or 'test'."

        path_en = os.path.join(data_dir, data_spec + '.en') # This follows the file naming convention of the data used by quest.dcs.shef.ac.uk/wmt16_files_mmt/
        path_de = os.path.join(data_dir, data_spec + '.de')

        self.data_en = self._read_sentences(path_en)
        self.data_de = self._read_sentences(path_de)

        assert len(self.data_en) == len(self.data_de), 'The English and German files must contain the same number of lines.'

    @staticmethod
    def _read_sentences(path):
        """
        Reads one sentence per line from path, stripping the trailing newline and punctuation and lower-casing
        """

        pstrip = str.maketrans('', '', string.punctuation) # Used to strip punctuation from incoming sentences

        # An explicit encoding keeps non-ASCII characters (e.g. German umlauts) from depending on the platform default
        with open(path, encoding='utf-8') as file:
            return [line.rstrip('\n').translate(pstrip).lower() for line in file]

    def __len__(self):
        """
        Returns the number of sentence pairs
        """

        return len(self.data_en)

    def __getitem__(self, idx):
        """
        Returns a pair of English and German sentences

        Args:
            idx (int): The index of the desired sentence pair

        Returns:
            A tuple of (English sentence, German translation)
        """

        return self.data_en[idx], self.data_de[idx]

## VOCABULARIES<a class="anchor" id="vocab"></a>

In [None]:
def build_vocabularies(dataset, tokenizer_en, idx_en, tokenizer_de, idx_de):
    """
    Builds the English and German vocabularies from a dataset of sentence pairs

    Tokens must appear at least twice to be included; any other token maps to <unk>.

    Args:
        dataset (torch ConcatDataset): A dataset containing all tuples of (English sentence, German translation) to be included in the vocabulary
        tokenizer_en: An instance of spacy.lang.en.English
        idx_en (int): The index of the English sentence within the tuple
        tokenizer_de: An instance of spacy.lang.de.German
        idx_de (int): The index of the German translation within the tuple

    Returns:
        A tuple of (English, German) torchtext Vocabularies
    """

    def _vocab_for(tokenizer, language_idx):
        # A fresh pass over the dataset is made for each language
        vocab = build_vocab_from_iterator(
            token_gen(iter(dataset), tokenizer, language_idx),
            min_freq=2,
            specials=["<pad>", "<SoS>", "<EoS>", "<unk>"],
        )
        vocab.set_default_index(vocab["<unk>"]) # Out-of-vocabulary tokens are looked up as <unk>
        return vocab

    english_vocab = _vocab_for(tokenizer_en, idx_en)
    german_vocab = _vocab_for(tokenizer_de, idx_de)

    return english_vocab, german_vocab

In [None]:
def token_gen(data_iter, tokenizer, language_idx):
    """
    Lazily tokenizes one language's sentence from each sentence pair

    Args:
        data_iter (iterator): An iterator over a ManualMulti30k dataset
        tokenizer: A Spacy tokenizer (spacy.lang.en.English or spacy.lang.de.German)
        language_idx (int): The element of the tuple (English sentence, German translation) to be tokenized

    Yields:
        A list of token strings for the selected sentence of each pair
    """

    yield from (
        [piece.text for piece in tokenizer.tokenizer(pair[language_idx])]
        for pair in data_iter
    )

## DATAPIPE COLLATION AND MASKING<a class="anchor" id="mask"></a>

In [None]:
def _numericalize_sentence(sentence, tokenizer, vocab):
    """
    Tokenizes a sentence, wraps it in <SoS>/<EoS> and returns the list of vocabulary indices
    """

    tokens = ["<SoS>"] + [str(token) for token in tokenizer(sentence)] + ["<EoS>"]
    return vocab(tokens)


def _pad_to_matrix(rows):
    """
    Stacks variable-length index lists into one zero-padded integer tensor; returns (tensor, padded width)
    """

    width = max((len(row) for row in rows), default=1) # An empty batch keeps a width of 1
    width = max(width, 1)

    padded = torch.zeros(len(rows), width, dtype=torch.int) # Zero-filled, so untouched cells act as padding
    for row_idx, row in enumerate(rows):
        padded[row_idx, :len(row)] = torch.tensor(row, dtype=torch.int)

    return padded, width


def collate(batch, tokenizer_en, vocab_en, vocab_de, tokenizer_de):
    """
    A DataLoader collate_fn that tokenizes, numericalizes and pads the incoming batch, and builds its attention masks

    Note the parameter order: vocab_de comes before tokenizer_de.

    Args:
        batch (list): A list of (English, German) sentence pair tuples as provided by the ManualMulti30k class
        tokenizer_en: An instance of spacy.lang.en.English
        vocab_en (torchtext Vocab): The English language vocabulary
        vocab_de (torchtext Vocab): The German language vocabulary
        tokenizer_de: An instance of spacy.lang.de.German

    Returns:
        A tuple (en_out, mask_en, de_out, mask_de):
            en_out - integer tensor of dim=(B,Ls) holding the indexed English sentences, padded with 0
            mask_en - the padding mask built from en_out, dim=(B,1,1,Ls), 1=padding
            de_out - integer tensor of dim=(B,Lt) holding the indexed German sentences, padded with 0
            mask_de - integer mask of dim=(B,1,Lt,Lt) combining the German padding mask and the look-ahead mask, 1=masked
    """

    # Numericalize every sentence first so each output tensor can be allocated once at its final width
    # NOTE(review): this calls the tokenizer objects directly, whereas pre_process and token_gen call their
    # .tokenizer attribute - confirm both produce the same tokens for the pipelines in use
    en_rows = [_numericalize_sentence(en_sentence, tokenizer_en, vocab_en) for (en_sentence, _) in batch]
    de_rows = [_numericalize_sentence(de_sentence, tokenizer_de, vocab_de) for (_, de_sentence) in batch]

    en_out, _ = _pad_to_matrix(en_rows)
    de_out, seq_len_de = _pad_to_matrix(de_rows)

    mask_en = create_padding_mask(en_out)

    mask_de = create_padding_mask(de_out) + create_lookahead_mask(seq_len_de)
    mask_de = (mask_de != 0).int() # Correction for the previous line, which introduced values >1 in areas of overlap

    return (en_out, mask_en, de_out, mask_de)

In [None]:
def create_padding_mask(input_tens):
    """
    Builds a padding mask to be used with the input to a multi-head attention's softmax function
    
    Args:
        input_tens (torch Tensor): The tensor of token indices the mask is derived from, dim=(B,L)
    
    Returns:
        mask: An int tensor of dim=(B,1,1,L) holding 1 wherever input_tens equals 0 (padding) and 0 elsewhere
    """

    is_padding = torch.eq(input_tens, 0).int()
    
    # Insert two singleton axes right after the batch axis: one broadcasts across heads,
    # the other across the positions of the second sequence in the attention scores
    return is_padding[:, None, None]

In [None]:
def create_lookahead_mask(seq_len):
    """
    Builds the look-ahead mask that stops a decoder from attending to later words in a sequence
    
    Args:
        seq_len (int): The length of the longest sequence in the batch to be masked
    
    Returns:
        mask: A float tensor of dim=(1,1,Lt,Lt) whose last two dimensions form a strictly upper triangular matrix (masking=1).
              The two leading singleton dimensions broadcast across batches and heads.
    """
    
    all_ones = torch.ones(seq_len, seq_len)
    
    # Entry (i, j) is 1 only when j > i, i.e. when position j lies in the future of position i
    future_positions = all_ones.triu(diagonal=1)
    
    return future_positions[None, None]

# TRAINING LOOP<a class="anchor" id="train"></a>

## TOP LEVEL<a class="anchor" id="epochs"></a>

In [None]:
def train_model(transformer, train_dl, valid_dl, criterion, optimizer, scheduler, num_epochs, batch_reporting=10, clip_value=100., monitoring=False):
    """
    Runs the full training process for our transformer model: one training pass and one
    validation pass per epoch, saving the model's state dict after every epoch
    
    Args:
        transformer (torch Module): The transformer to be trained
        train_dl (torch DataLoader): The training data's DataLoader
        valid_dl (torch DataLoader): The validation data's DataLoader
        criterion (torch Module): A module storing the loss function and associated information
        optimizer (torch Adam): The Adam optimizer
        scheduler (torch lr_scheduler.LambdaLR): The learning rate scheduler
        num_epochs (int): The number of epochs to execute 
        batch_reporting (int): The number of batches to execute within an epoch before printing a status update
        clip_value (float): The max norm value to which all gradients will be clipped 
        monitoring (boolean): If True, a sample source, target, and model output will be printed as part of the batch_reporting status update
    """
    
    # Checkpoints go into a directory named after the run's start time (minute resolution)
    
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M")
    savedir = f"Training_Run_{stamp}"
    
    if not os.path.exists(savedir):
        os.makedirs(savedir)
    
    for epoch in range(num_epochs):
        epoch_start = time.time()
        
        # Training pass: weights are updated
        print(f"Epoch {epoch}, beginning training run")
        transformer.train()
        train_loss = single_epoch(
            transformer, train_dl, criterion, optimizer, scheduler, batch_reporting,
            clip_value=clip_value, train_mode=True, monitoring=monitoring,
        )
        
        # Validation pass: no gradients are tracked and no weights are updated
        print(f"Epoch {epoch}: beginning validation run")
        transformer.eval()
        with torch.no_grad():
            valid_loss = single_epoch(
                transformer, valid_dl, criterion, optimizer, scheduler, batch_reporting,
                monitoring=monitoring,
            )
        
        elapsed = time.time() - epoch_start
        
        print(f"Epoch {epoch}: Training loss={train_loss}, Validation loss={valid_loss}, Epoch Runtime={elapsed}")
        print()
        
        checkpoint_path = os.path.join(savedir, f"Epoch_{epoch}_Statedict.pt")
        torch.save(transformer.state_dict(), checkpoint_path)
        
    print("Training complete.")

In [None]:
def single_epoch(transformer, data_dl, criterion, optimizer, scheduler, batch_reporting, train_mode=False, clip_value=100, monitoring=False):
    """
    Handles execution of a single epoch.
    
    Args:
        transformer (torch Module): The transformer to be trained
        data_dl (torch DataLoader): The training or validation data's DataLoader; each batch is indexed as
            (English tokens, English mask, German tokens, German mask)
        criterion (torch Module): A module storing the loss function and associated information
        optimizer (torch Adam): The Adam optimizer
        scheduler (torch lr_scheduler.LambdaLR): The learning rate scheduler
        batch_reporting (int): The number of batches to execute within an epoch before printing a status update 
        train_mode (boolean): If True, the transformer will update its weights during the epoch
        clip_value (float): The max norm value to which all gradients will be clipped 
        monitoring (boolean): If True, a sample source, target, and model output will be printed as part of the batch_reporting status update
    
    Returns:
        The mean of the per-batch losses for the epoch as a float (nan if data_dl yields no batches)
    """
    
    total_loss = 0.0
    num_batches = 0
    
    tstart = time.time()
    
    for batchnum, batch in enumerate(data_dl):
        
        batch_en, mask_en, batch_de, mask_de = batch[0], batch[1], batch[2], batch[3]

        # Teacher forcing: the output at position t is scored against the target token at t+1,
        # so the last output step and the first target token are dropped
        outputs = transformer(batch_en, mask_en, batch_de, mask_de)
        outputs = outputs[:,:-1,:]
        target = batch_de[:,1:]
        
        vocab_size = outputs.size(-1)
        outputs_to_loss = outputs.reshape(-1, vocab_size) # Flatten into a 2D matrix that can be passed to CrossEntropyLoss
        target_to_loss = target.reshape(-1)

        loss = criterion(outputs_to_loss, target_to_loss.long())

        if train_mode:
            optimizer.zero_grad()
            loss.backward()
            # Clip the parameters of the transformer passed in, not a notebook-level global
            nn.utils.clip_grad_norm_(transformer.parameters(), max_norm=clip_value) # I found that this worked well in later-stage training
            optimizer.step()
            scheduler.step()
        
        total_loss += loss.item()
        num_batches += 1
        
        if batchnum % batch_reporting == 0:
            trun = time.time() - tstart
            print(f"Batch {batchnum} complete, runtime {trun} sec")

            if monitoring:
                # NOTE: vocab_en and vocab_de are read from the notebook's global namespace
                itos_en = vocab_en.get_itos() # Fetch each index-to-token table once rather than once per token
                itos_de = vocab_de.get_itos()
                
                source_sentence_tokens = [itos_en[x] for x in batch_en[0].tolist()] # Look up tokens in English vocabulary
                target_sentence_tokens = [itos_de[x] for x in batch_de[0].tolist()] # Look up tokens in German vocabulary
                
                log_probs = F.log_softmax(outputs, dim=2) # Conversion from model output to log-probabilities
                next_tokens = torch.argmax(log_probs,2) # Most likely next token values
                
                output_tokens = [itos_de[x] for x in next_tokens[0,:].tolist()]
                target_tokens = [itos_de[x] for x in target[0,:].tolist()]
                
                print()
                print(f" - Monitoring Enabled - ")
                if train_mode:
                    # Parameters that took no part in the backward pass have no gradient and are skipped
                    grad_norms = [param.grad.detach().norm(2).item() for param in transformer.parameters() if param.grad is not None]
                    print(f"Max Gradient Norm: {max(grad_norms, default=0.0)}")
                print(f"Source from data: {source_sentence_tokens}")
                print(f"Taget from data: {target_sentence_tokens}")
                print(f"Model output: {output_tokens}")
                print(f"Target tokens: {target_tokens}")
                print()
            
            tstart = time.time()
        
    # This is a mean over batches rather than over tokens, so the train loss can wind up being continually higher than the validation loss.
    if num_batches == 0:
        return float("nan")
    
    return total_loss/num_batches

## OPTIMIZATION<a class="anchor" id="opt"></a>

In [None]:
def gen_scheduler(transformer, dmodel, start_step=0, warmup_steps=4000, betas=(0.9,0.98), eps=1e-9, optimizer=None):
    """
    Learning rate scheduler for the Adam optimizer
        
    Args:
        transformer (torch Module): The transformer linked to the optimizer (unused; kept for interface compatibility)
        dmodel (int): The dimensionality of the transformer's embeddings
        start_step (int): The step at which to start the scheduler; added to the scheduler's own step count
        warmup_steps (int): The step at which the schedule shifts from a linearly increasing learning rate to a decreasing one
        betas (tuple): Coefficients of the Adam algorithm (unused; kept for interface compatibility)
        eps (float): Small value used in an Adam calculation's denominator for numerical stability (unused; kept for interface compatibility)
        optimizer (torch Optimizer): The optimizer whose learning rate is scheduled. If None, the module-level
            variable named `optimizer` is used, which is what this function always relied on previously

    Returns:
        An instance of torch.optim.lr_scheduler.LambdaLR

    Raises:
        ValueError: If no optimizer is passed and no module-level `optimizer` exists
    """
    
    if optimizer is None:
        # Backward compatibility: existing callers do not pass the optimizer and expect the
        # scheduler to attach to the one created at module level
        optimizer = globals().get("optimizer")
        if optimizer is None:
            raise ValueError("gen_scheduler requires an optimizer: pass optimizer=... or define a module-level 'optimizer' first")
    
    def lr_lambda(step):
        # LambdaLR multiplies the optimizer's base learning rate by this value at every step
        return learning_rate(dmodel, step + start_step, warmup_steps)
    
    return opt.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

In [None]:
def learning_rate(dmodel, step_num, warmup_steps):
    """
    Computes the AIAYN learning rate for the Adam optimizer at a given step
        
    Args:
        dmodel (int): The dimensionality of the transformer's embeddings
        step_num (int): The current step's index
        warmup_steps (int): The step at which the schedule shifts from a linearly increasing learning rate to a decreasing one

    Returns:
        The rate value calculated based on the current step
        
    """
    scale = dmodel ** -0.5
    warmup_term = step_num * warmup_steps ** -1.5 # Linear ramp used during warmup

    # Step 0 is handled separately because 0 ** -0.5 is undefined; only the warmup term applies there
    if step_num == 0:
        return scale * warmup_term

    decay_term = step_num ** -0.5 # Inverse square root decay used after warmup

    return scale * min(decay_term, warmup_term)

In [None]:
def gen_optimizer(transformer, betas=(0.9,0.98), eps=1e-9):
    """
    Builds the Adam optimizer for a transformer
    
    Args:
        transformer (torch Module): The transformer whose parameters the optimizer will update
        betas (tuple): Coefficients of the Adam algorithm 
        eps (float): Small value used in an Adam calculation's denominator for numerical stability
        
    Returns:
        An instance of torch.optim.Adam
    """
    
    # A base learning rate of 1 means the rate applied at each step is exactly the
    # multiplier produced by the scheduler attached to this optimizer
    adam_settings = {"lr": 1, "betas": betas, "eps": eps}
    
    return opt.Adam(transformer.parameters(), **adam_settings)

## MANUAL LOSS<a class="anchor" id="loss"></a>

In [None]:
def ManualCEL(logits, target, smoothing=0.1, epsilon=1e-9):
    """
    An explicitly defined cross-entropy loss, unused in this script but previously compared to the PyTorch implementation for debugging.
    The result is summed over every row of the input; no padding index is ignored and no averaging is applied.
        
    Args:
        logits (torch Tensor): A tensor of dim=(B*Lt,Vt) containing a pre-softmax transformer output
        target (torch Tensor): A tensor of dim=(B*Lt) containing vocabulary indices for each token in a batch of sentences
        smoothing (float): The amount of label smoothing to apply to the calculation
        epsilon (float): Probabilities are clamped to [epsilon, 1 - epsilon] before the log is taken, for numerical stability

    Returns:
        The summed cross-entropy loss as a 0-dim torch Tensor. It is computed from detached probabilities,
        so it carries no gradient.
    """
    
    vocab_size = logits.size(-1)
    
    # Label smoothing: the true class gets probability 1 - smoothing and the remaining
    # smoothing mass is spread evenly over the other vocab_size - 1 classes
    off_value = smoothing / (vocab_size - 1)
    one_hot = F.one_hot(target, num_classes=vocab_size).to(logits.dtype)
    target_smooth = one_hot * (1 - smoothing - off_value) + off_value
    
    probs = F.softmax(logits, dim=-1).detach() # Conversion of logits to probabilities
    
    # Keep log() finite. clamp is vectorized and works on any device.
    probs = probs.clamp(min=epsilon, max=1 - epsilon)
    
    crossentropy = -target_smooth * torch.log(probs)

    return torch.sum(crossentropy)

# RUN MODEL<a class="anchor" id="run"></a>

## INITIALIZATION<a class="anchor" id="init"></a>

In [None]:
import subprocess
import sys

en_idx = 0 # Track the ordering of the sentence-pair tuples
de_idx = 1

# Create train, val, and test Datasets.

userdir = os.path.expanduser('~')
datadir = os.path.join(userdir, '.data/')

train_dataset = ManualMulti30k(datadir,'train')
val_dataset = ManualMulti30k(datadir,'val')
test_dataset = ManualMulti30k(datadir,'test')

combined_dataset = ConcatDataset([train_dataset, val_dataset, test_dataset]) # For building vocabularies

# Import tokenizers from Spacy


def _load_spacy_model(name):
    """
    Loads a spaCy pipeline, downloading it first if it is not installed
    
    Args:
        name (str): The name of the spaCy pipeline package
    
    Returns:
        The loaded spaCy pipeline
    """
    try:
        return spacy.load(name)
    except OSError:
        # Download with the interpreter running this notebook so the package lands in the
        # same environment. check=True surfaces a failed download here instead of letting
        # the second load fail with a less helpful error.
        subprocess.run([sys.executable, '-m', 'spacy', 'download', name], check=True)
        return spacy.load(name)


tokenizer_en = _load_spacy_model('en_core_web_sm')
tokenizer_de = _load_spacy_model('de_core_news_sm')

# Create English and German vocabularies
    
print("Preparing Data")
vocab_en, vocab_de = build_vocabularies(combined_dataset, tokenizer_en, en_idx, tokenizer_de, de_idx)

vocab_len_en = len(vocab_en)
vocab_len_de = len(vocab_de)

# Create DataLoaders

# Train and validation batches are tokenized, numericalized, padded, and masked by the same collate function
collate_batch = partial(collate, tokenizer_en=tokenizer_en, vocab_en=vocab_en, tokenizer_de=tokenizer_de, vocab_de=vocab_de)

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True, collate_fn=collate_batch)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False) # Default collation: raw sentence pairs, one per batch

print("Data Preparation Complete")
print(f"English vocabulary length: {vocab_len_en}")
print(f"German vocabulary length: {vocab_len_de}")

In [None]:
# Initialize Transformer

weight_import = "Trained_Weights.pt" # Change value to None to initialize a new model

model = Transformer(vocab_en, vocab_de)

if weight_import is not None:
    # NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine;
    # load_state_dict then copies the values into the model's existing parameters.
    model.load_state_dict(torch.load(weight_import, map_location="cpu"))

## TRAINING<a class="anchor" id="modeltrain"></a>

In [None]:
# Apart from the gradient clip value, the numbers below correspond to the AIAYN defaults. In practice I had to do some experimenting to get good training results

num_epochs, warmup_steps, start_step = 100, 4000, 0
clip_value = 10

# Index 0 is the padding token, so padded positions do not contribute to the loss
criterion = nn.CrossEntropyLoss(label_smoothing=0.1, ignore_index=0)
optimizer = gen_optimizer(model)
scheduler = gen_scheduler(
    model,
    model.dmodel,
    start_step=start_step,
    warmup_steps=warmup_steps,
)

train_model(
    model,
    train_dataloader,
    val_dataloader,
    criterion,
    optimizer,
    scheduler,
    num_epochs,
    batch_reporting=50,
    clip_value=clip_value,
    monitoring=True,
)

## INFERENCE<a class="anchor" id="modelinf"></a>

In [None]:
# Translate a single English sentence as a quick sanity check of the model
test_input = "The cow jumped over the moon."

test_output = model.translate(
    test_input,
    tokenizer_en,
    vocab_en,
    tokenizer_de,
    vocab_de,
)

print(test_output)

# EVALUATION<a class="anchor" id="eval"></a>

Note that using the default Multi30k test set to calculate BLEU score can take a long time. You might consider using just a subset.

In [None]:
# Computes the BLEU score for transformer output over the Multi30k test Dataset - this takes a few minutes to run

translations = []
references = []

pstrip = str.maketrans('', '', string.punctuation) # Used to strip punctuation from incoming sentences
num_sentences = len(test_dataloader)

model.eval() # Put layers such as dropout into inference mode in case the model was not just trained

print('Calculating BLEU score')
for idx, sentence_pair in enumerate(test_dataloader, start=1):
    if idx % 10 == 0:
        print(f'{idx}/{num_sentences} sentences evaluated')

    try:
        # Each batch holds one (English, German) pair; both sides are stripped of punctuation and lowercased
        en_sentence = sentence_pair[0][0].translate(pstrip).lower()
        de_translation = model.translate(en_sentence, tokenizer_en, vocab_en, tokenizer_de, vocab_de)
        de_translation_tokens = [str(token) for token in tokenizer_de(de_translation)]

        de_reference = sentence_pair[1][0].translate(pstrip).lower()
        de_reference_tokens = [str(token) for token in tokenizer_de(de_reference)]
    except Exception as e:
        # Best effort: report the failure and skip this pair. Ctrl-C still interrupts the loop.
        print(f"Translation Error: {e}")
        continue

    # Append both together so the candidate and reference corpora stay aligned
    translations.append(de_translation_tokens)
    references.append([de_reference_tokens]) # bleu_score expects a list of references per candidate
    
print(f'Finished. BLEU score = {bleu_score(translations,references)}')