## Welcome to the tutorial on Transformer Models
As mentioned in the README, this tutorial demonstrates a PyTorch-based implementation of the Transformer model. I've done my best to keep the amount of prior code knowledge to a minimum, but the code is best understood with a moderate understanding of object-oriented programming and some familiarity with PyTorch. Many thanks to [Yongrae Jo](https://github.com/dreamgonfly) for his working PyTorch implementation, on which much of this code is based.

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
import random
import os
from collections import defaultdict

### Part 1: Loading the dataset
The dataset used in this tutorial is from [OpenNMT](https://github.com/OpenNMT/OpenNMT-py) and consists of source data in English and target data in German. Because preprocessing is beyond the scope of the tutorial, we have provided preprocessed data in the `data` directory and custom PyTorch datasets in the `helpers` directory. The preprocessing results in encoded German and English sentences, where each unique word is mapped to one number (e.g. "the" = 1, "horse" = 2).

In [2]:
## Helper function for DataLoader: this function makes sure all the source/input/target encodings in a batch
## have the same lengths by padding them with a pad index (0).
## `batch` is a list of (sources, inputs, targets) tuples, one per sentence pair.
def input_target_collate_fn(batch: list):
    PAD_INDEX = 0
    sources_lengths = [len(sources) for sources, inputs, targets in batch]
    inputs_lengths = [len(inputs) for sources, inputs, targets in batch]
    targets_lengths = [len(targets) for sources, inputs, targets in batch]

    sources_max_length = max(sources_lengths)
    inputs_max_length = max(inputs_lengths)
    targets_max_length = max(targets_lengths)

    sources_padded = [sources + [PAD_INDEX] * (sources_max_length - len(sources)) for sources, inputs, targets in batch]
    inputs_padded = [inputs + [PAD_INDEX] * (inputs_max_length - len(inputs)) for sources, inputs, targets in batch]
    targets_padded = [targets + [PAD_INDEX] * (targets_max_length - len(targets)) for sources, inputs, targets in batch]

    sources_tensor = torch.tensor(sources_padded)
    inputs_tensor = torch.tensor(inputs_padded)
    targets_tensor = torch.tensor(targets_padded)

    return sources_tensor, inputs_tensor, targets_tensor
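
To see what the padding does, here is a minimal sketch on a made-up batch. The helper name `pad_to_longest` and the toy indices are my own illustration, not part of the helper files; it only mirrors the padding step used in the collate function above.

```python
import torch

PAD_INDEX = 0  # same pad index the collate function above uses

def pad_to_longest(sequences):
    """Right-pad each list of token indices to the length of the longest one."""
    max_length = max(len(seq) for seq in sequences)
    return [seq + [PAD_INDEX] * (max_length - len(seq)) for seq in sequences]

# Toy encoded sentences of different lengths (the indices are made up).
toy_batch = [[5, 8, 2], [7, 2], [4, 9, 6, 3, 2]]
padded = torch.tensor(pad_to_longest(toy_batch))

print(padded.shape)  # one row per sentence, one column per position
print(padded)
```

Without the padding, `torch.tensor` would refuse to build a tensor from lists of different lengths.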

In [3]:
from helpers.datasets import IndexedInputTargetTranslationDataset
DATA_DIR = 'data/example/processed'

train_ds      = IndexedInputTargetTranslationDataset(data_dir=DATA_DIR, phase='train')
validation_ds = IndexedInputTargetTranslationDataset(data_dir=DATA_DIR, phase='val')

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=20, shuffle=True, 
                                           collate_fn=input_target_collate_fn)
valid_loader = torch.utils.data.DataLoader(validation_ds, batch_size=20, shuffle=True, 
                                           collate_fn=input_target_collate_fn)

# Our sentences have been encoded to indexed inputs and outputs (where each unique word corresponds to a number) 
# Uncomment the below line to see the first tensor in the first batch
#print(next(iter(train_loader))[0])

### Part 2: Building A Transformer, One Block at a Time
#### Scaled Dot Product Attention
The smallest building block of the transformer model presented in *Attention Is All You Need* is scaled dot-product attention. It takes the heads of the queries and keys, multiplies them together, and scales the result by the square root of the head dimension each time it's called.
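
As a rough sketch of the whole operation from the paper (scores, softmax, then a weighted sum of the values), here is a single-head version on toy tensors. The function name and shapes are my own choices for illustration; in this tutorial the softmax and the multiplication with the values happen in the `forward` method of the multi-head module rather than in the scaling function itself.

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value):
    # query: (query_len, d_k), key: (key_len, d_k), value: (key_len, d_v)
    d_k = query.size(-1)
    scores = query @ key.transpose(-2, -1) / math.sqrt(d_k)  # (query_len, key_len)
    weights = F.softmax(scores, dim=-1)                      # each row sums to 1
    return weights @ value, weights

torch.manual_seed(0)
q = torch.randn(3, 4)  # 3 query positions
k = torch.randn(5, 4)  # 5 key positions
v = torch.randn(5, 4)  # one value vector per key position

context, weights = scaled_dot_product_attention(q, k, v)
print(context.shape, weights.shape)
```

Each row of `weights` says how much one query position draws from each key position.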

#### Multi-Head Attention
While our scaled dot-product attention works well for single tensors, we really want to apply that concept to multiple attention heads at once in parallel! Enter: the MultiHeadAttention module! This module does a lot of things at once. Given a tensor of multiple queries, keys, and values it:
1. Splits the inputs into separate blocks for each attention head
2. Passes the split queries and keys into the scaled dot-product function
3. Normalizes the resulting attention weights with a softmax, applies dropout to them, and uses them to combine the values

Both of these steps are implemented in one class for simplicity. I've gone outside the usual order of defining Modules by putting the scaled attention function first instead of `forward`. PyTorch uses the `forward` method of its Modules much like `__call__` in pure Python (while keeping track of all the fine details like gradients in the background).
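
If the relationship between calling a module and its `forward` method is new to you, this tiny sketch may help. The `Doubler` class is a made-up example, not part of the transformer.

```python
import torch
import torch.nn as nn

class Doubler(nn.Module):
    """A toy module: all of its work happens in forward."""
    def forward(self, x):
        return 2 * x

doubler = Doubler()
x = torch.tensor([1.0, 2.0, 3.0])

# Calling the module like a function runs forward (plus any registered hooks).
out = doubler(x)
print(out)
```

This is why the classes below only define `forward` and are then called directly, as in `encoder_layer(sources, mask)`.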

In [4]:
class MultiHeadAttention(nn.Module):

    def __init__(self, heads_count: int, 
                 d_model: int, 
                 dropout_prob: float, 
                 mode='self-attention'):
        super().__init__()

        self.d_head = d_model // heads_count
        self.heads_count = heads_count
        self.mode = mode
        self.query_projection = nn.Linear(d_model, heads_count * self.d_head)
        self.key_projection = nn.Linear(d_model, heads_count * self.d_head)
        self.value_projection = nn.Linear(d_model, heads_count * self.d_head)
        self.final_projection = nn.Linear(heads_count * self.d_head, d_model)  # Map merged heads back to d_model
        self.dropout = nn.Dropout(dropout_prob)
        self.softmax = nn.Softmax(dim=3)

        self.attention = None
        self.key_projected = None
        self.value_projected = None
        
    def scaled_dot_product(self, query_heads: torch.Tensor, key_heads: torch.Tensor):
        """
        Args:
             query_heads: (batch_size, heads_count, query_len, d_head)
             key_heads: (batch_size, heads_count, key_len, d_head)
        """
        key_heads_transposed = key_heads.transpose(2, 3)  # Required for matrix multiplication
        dot_product = torch.matmul(query_heads, key_heads_transposed)
        attention_weights = dot_product / np.sqrt(self.d_head)  # Apply the scaling in the paper: sqrt(d_head)
        return attention_weights

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: (batch_size, query_len, model_dim)
            key: (batch_size, key_len, model_dim)
            value: (batch_size, value_len, model_dim)
            mask: (batch_size, query_len, key_len), True where attention should be blocked
        """
        batch_size, query_len, d_model = query.size()

        d_head = d_model // self.heads_count

        query_projected = self.query_projection(query)
       
        key_projected = self.key_projection(key)
        value_projected = self.value_projection(value)
       
        # For cache
        self.key_projected = key_projected
        self.value_projected = value_projected

        batch_size, key_len, d_model = key_projected.size()
        batch_size, value_len, d_model = value_projected.size()
        
        # Tensor.view reshapes a tensor without copying it; we use it to split d_model into heads_count 
        # separate heads of size d_head
        query_heads = query_projected.view(batch_size, query_len, self.heads_count, d_head).transpose(1, 2) 
        key_heads = key_projected.view(batch_size, key_len, self.heads_count, d_head).transpose(1, 2) 
        value_heads = value_projected.view(batch_size, value_len, self.heads_count, d_head).transpose(1, 2) 

        attention_weights = self.scaled_dot_product(query_heads, key_heads) 
        
        if mask is not None:
            mask_expanded = mask.unsqueeze(1).expand_as(attention_weights)
            attention_weights = attention_weights.masked_fill(mask_expanded, -1e18)

        self.attention = self.softmax(attention_weights)  # Save attention to the object
        attention_dropped = self.dropout(self.attention)
        
        context_heads = torch.matmul(attention_dropped, value_heads)    # Weight the values by attention,
        context_sequence = context_heads.transpose(1, 2).contiguous()   # put heads next to d_head again,
        context = context_sequence.view(batch_size, query_len, d_model) # merge the heads back together,
        final_output = self.final_projection(context)                   # and project!
        return final_output
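
The `view`/`transpose` juggling in `forward` is easier to follow on its own. Here is a small sketch, with toy sizes I picked for illustration, showing that splitting into heads and merging back gives the original tensor.

```python
import torch

batch_size, seq_len, heads_count, d_head = 2, 5, 4, 8
d_model = heads_count * d_head

x = torch.randn(batch_size, seq_len, d_model)

# Split: (batch, seq, d_model) -> (batch, heads, seq, d_head)
heads = x.view(batch_size, seq_len, heads_count, d_head).transpose(1, 2)

# Merge: undo the transpose, make the memory contiguous, then flatten the heads
merged = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)

print(heads.shape)
print(torch.equal(merged, x))
```

The call to `contiguous()` is needed because `view` cannot be applied to a tensor whose memory layout was changed by `transpose`.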

#### Another small building block
The only other small block of a transformer model is the position-wise feed forward network - a simple
sequential model with fully connected layers and ReLU activation. 

In [5]:
class PointwiseFeedForwardNetwork(nn.Module):
    def __init__(self, d_ff, d_model, dropout_prob):
        super(PointwiseFeedForwardNetwork, self).__init__()

        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.Dropout(dropout_prob),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout_prob),
        )

    def forward(self, x):
        return self.feed_forward(x)
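
"Position-wise" means the same small network is applied to every position independently. A quick sketch to check that claim (dropout is left out here so the comparison is deterministic; the sizes are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 8, 16
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

x = torch.randn(2, 5, d_model)  # (batch_size, seq_len, d_model)

whole = ffn(x)                  # run on every position at once
one_position = ffn(x[:, 3, :])  # run on position 3 only

print(torch.allclose(whole[:, 3, :], one_position, atol=1e-5))
```

No information moves between positions in this block; that mixing is left entirely to attention.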

#### Layers, Layers, and more Layers
We now need to build up our encoder and decoder from our multi-head attention! The only thing standing between us and that is the setup of the individual layers of the decoders and encoders.

In [6]:
# In the paper, after each layer of the model they apply normalization. This is where we do that.
class LayerNormalization(nn.Module):

    def __init__(self, features_count, epsilon=1e-6):
        super(LayerNormalization, self).__init__()

        self.gain = nn.Parameter(torch.ones(features_count))
        self.bias = nn.Parameter(torch.zeros(features_count))
        self.epsilon = epsilon

    def forward(self, x):

        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)

        return self.gain * (x - mean) / (std + self.epsilon) + self.bias

class Sublayer(nn.Module):
    '''Wraps a sublayer with a residual connection followed by layer normalization'''
    def __init__(self, sublayer, d_model):
        super().__init__()

        self.sublayer = sublayer
        self.layer_normalization = LayerNormalization(d_model)

    def forward(self, *args):
        x = args[0]
        x = self.sublayer(*args) + x
        return self.layer_normalization(x)
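
To get a feel for what the normalization does, this sketch applies the same formula as `LayerNormalization` above with `gain` fixed to 1 and `bias` fixed to 0, on a toy tensor that is deliberately shifted and scaled.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 5, 8) * 3 + 10  # deliberately not zero-mean / unit-variance

epsilon = 1e-6
mean = x.mean(dim=-1, keepdim=True)
std = x.std(dim=-1, keepdim=True)          # torch.std uses the unbiased estimate by default
normalized = (x - mean) / (std + epsilon)  # same formula as above with gain=1, bias=0

print(normalized.mean(dim=-1).abs().max())  # very close to 0
print(normalized.std(dim=-1))               # very close to 1 at every position
```

Every position's feature vector ends up on the same scale, whatever the sublayer produced.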

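Below, we define an individual encoder and decoder layer. The encoder layer runs one self-attention layer, one dropout layer, and one feedforward layer in sequence! The decoder layer has no separate dropout step; instead, between self-attention and the feedforward layer, it has a memory attention layer that attends over the encoder's output.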

In [7]:
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, heads_count, d_ff, dropout_prob):
        super().__init__()

        self.self_attention_layer = Sublayer(MultiHeadAttention(heads_count, d_model, dropout_prob), d_model)
        self.pointwise_feedforward_layer = Sublayer(PointwiseFeedForwardNetwork(d_ff, d_model, dropout_prob), d_model)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, sources, sources_mask):
        sources = self.self_attention_layer(sources, sources, sources, sources_mask)
        sources = self.dropout(sources)
        sources = self.pointwise_feedforward_layer(sources)

        return sources

class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, heads_count, d_ff, dropout_prob):
        super().__init__()
        self.self_attention_layer = Sublayer(MultiHeadAttention(heads_count, d_model, dropout_prob), d_model)
        self.memory_attention_layer = Sublayer(MultiHeadAttention(heads_count, d_model, dropout_prob), d_model)
        self.pointwise_feedforward_layer = Sublayer(PointwiseFeedForwardNetwork(d_ff, d_model, dropout_prob), d_model)

    def forward(self, inputs, memory, memory_mask, inputs_mask):
        inputs = self.self_attention_layer(inputs, inputs, inputs, inputs_mask)
        # The decoder gets to remember the encoder!
        inputs = self.memory_attention_layer(inputs, memory, memory, memory_mask) 
        inputs = self.pointwise_feedforward_layer(inputs)
        return inputs


#### Encoder and Decoder, Separately!
With those layers defined, we can finally build our Encoder and Decoder with as many layers as we like!

In [8]:
class TransformerEncoder(nn.Module):

    def __init__(self, layers_count, d_model, heads_count, d_ff, dropout_prob, embedding):
        super().__init__()

        self.d_model = d_model
        self.embedding = embedding
        self.encoder_layers = nn.ModuleList(
            [TransformerEncoderLayer(d_model, heads_count, d_ff, dropout_prob) for _ in range(layers_count)]  
        )  # We loop through, stacking encoder layers on top of each other.

    def forward(self, sources, mask):
        """
        args:
           sources: token indices, (batch_size, seq_len); embedded below to (batch_size, seq_len, d_model)
        """
        sources = self.embedding(sources)

        for encoder_layer in self.encoder_layers:
            sources = encoder_layer(sources, mask)

        return sources


class TransformerDecoder(nn.Module):

    def __init__(self, layers_count, d_model, heads_count, d_ff, dropout_prob, embedding):
        super().__init__()

        self.d_model = d_model
        self.embedding = embedding
        self.decoder_layers = nn.ModuleList(
            [TransformerDecoderLayer(d_model, heads_count, d_ff, dropout_prob) for _ in range(layers_count)]
        )
        # The generator turns decoder outputs into a score for every word in the vocabulary. It shares
        # (ties) its weight matrix with the embedding instead of learning a separate one.
        self.generator = nn.Linear(embedding.embedding_dim, embedding.num_embeddings)
        self.generator.weight = self.embedding.weight

    def forward(self, inputs, memory, memory_mask, inputs_mask=None, state=None):
        inputs = self.embedding(inputs)

        for layer_index, decoder_layer in enumerate(self.decoder_layers):
            inputs = decoder_layer(inputs, memory, memory_mask, inputs_mask)
            
        generated = self.generator(inputs)  
        return generated, state
    
    def init_decoder_state(self, **args):
        return DecoderState()
    
class DecoderState:

    def __init__(self):
        self.previous_inputs = torch.tensor([])
        self.layer_caches = defaultdict(lambda: {'self-attention': None, 'memory-attention': None})

    def update_state(self, layer_index, layer_mode, key_projected, value_projected):
        self.layer_caches[layer_index][layer_mode] = {
            'key_projected': key_projected,
            'value_projected': value_projected
        }

    def beam_update(self, positions):
        for layer_index in self.layer_caches:
            for mode in ('self-attention', 'memory-attention'):
                if self.layer_caches[layer_index][mode] is not None:
                    for projection in self.layer_caches[layer_index][mode]:
                        cache = self.layer_caches[layer_index][mode][projection]
                        if cache is not None:
                            cache.data.copy_(cache.data.index_select(0, positions))

#### Encoder and Decoder, Together - The Transformer Module
At long last, we can finally put our transformer together! We use some helper functions to generate our masks so that we aren't paying attention to the pad tokens (0s), and then we can build our transformer from our encoder and decoder!

In [9]:
# Helpers for our masks
PAD_TOKEN_INDEX = 0

def pad_masking(x, target_len):
    # x: (batch_size, seq_len)
    batch_size, seq_len = x.size()
    padded_positions = x == PAD_TOKEN_INDEX  # (batch_size, seq_len)
    pad_mask = padded_positions.unsqueeze(1).expand(batch_size, target_len, seq_len)
    return pad_mask


def subsequent_masking(x):
    batch_size, seq_len = x.size()
    subsequent_mask = np.triu(np.ones(shape=(seq_len, seq_len)), k=1).astype('bool')
    subsequent_mask = torch.tensor(subsequent_mask).to(x.device)
    subsequent_mask = subsequent_mask.unsqueeze(0).expand(batch_size, seq_len, seq_len)
    return subsequent_mask
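
Here is a small sketch of what the two masks look like for one toy sentence with two pad tokens. It rebuilds the masks inline with plain torch calls (the token values are made up), so you can see how they combine for the decoder.

```python
import torch

PAD_TOKEN_INDEX = 0
tokens = torch.tensor([[4, 7, 9, 0, 0]])  # one sentence, two pad positions
seq_len = tokens.size(1)

# True where the key position is padding; repeated for every query position.
pad_mask = (tokens == PAD_TOKEN_INDEX).unsqueeze(1).expand(-1, seq_len, -1)

# True strictly above the diagonal: position i may not look at positions j > i.
subsequent_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

# True means "do not attend here".
combined = subsequent_mask.unsqueeze(0) | pad_mask
print(combined[0].int())
```

Row `i` of the printed matrix shows which positions word `i` is blocked from seeing: everything after it, plus the padding.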

In [10]:
class Transformer(nn.Module):

    def __init__(self, encoder, decoder):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

    def forward(self, sources, inputs):
        batch_size, sources_len = sources.size()
        batch_size, inputs_len = inputs.size()

        sources_mask = pad_masking(sources, sources_len)
        memory_mask = pad_masking(sources, inputs_len)
        inputs_mask = subsequent_masking(inputs) | pad_masking(inputs, inputs_len)

        memory = self.encoder(sources, sources_mask)  # (batch_size, sources_len, d_model)
        outputs, state = self.decoder(inputs, memory, memory_mask, inputs_mask)  # (batch_size, inputs_len, vocabulary_size)
        return outputs

### Part 3: Training the Model

Since pytorch training is outside the scope of the tutorial, I've written the training loop and accuracy metrics in the helper files, and included comments about what they do as I import them. We start with some constants that we'll use to define our model.

In [11]:
## Constants - feel free to change to see how they affect the model!
### Training
num_epochs = 10             # How many times to loop through all train/validation data
d_model = 128               # Number of dimensions for the models
layers_count = 1            # Number of layers in the encoders/decoders
heads_count = 2             # Number of attention heads (must divide d_model evenly)
d_ff = 128                  # feed-forward dimensions
dropout_prob = 0.1          # how much dropout
label_smoothing = 0.1       # how much smoothing to apply to our labels. Higher = smoother
clip_grads = True           # Clip the gradients (to prevent from exploding)
lr = 0.001                  # Learning rate for Adam optimizer
seed = 3621                 # Let our model be deterministic (set to None if you want to have an unseeded model)

### Logging
output_dir = "./models/"    # Where to save the model
save_mode = 'best'          # Save only the best model (can be changed to 'all' to save all models)

Set up all our random number generators, logging directories, and GPU (if we have one)

In [12]:
if seed is not None:
    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = False  
    np.random.seed(seed)
    random.seed(seed)

os.makedirs(output_dir, exist_ok=True)  # Makes if not exists
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Load in our data to our custom dictionary (to allow the model to access it efficiently)

In [13]:
from helpers.dictionary import IndexDictionary

source_dictionary = IndexDictionary.load(DATA_DIR, mode='source')
target_dictionary = IndexDictionary.load(DATA_DIR, mode='target')

Use our sinusoidal positional encoding as they do in the paper

In [14]:
from helpers.embeddings import PositionalEncoding

source_embedding = PositionalEncoding(
    num_embeddings=source_dictionary.vocabulary_size,
    embedding_dim=d_model,
    dim=d_model)

target_embedding = PositionalEncoding(
    num_embeddings=target_dictionary.vocabulary_size,
    embedding_dim=d_model,
    dim=d_model)
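
The `PositionalEncoding` helper lives in the helper files, so its internals aren't shown here. As a rough sketch of the sinusoidal table described in the paper (the function name and `max_len` argument are my own, and I assume an even `d_model`):

```python
import math
import torch

def sinusoidal_table(max_len, d_model):
    """Build a (max_len, d_model) table of sine/cosine position signals."""
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
    # 1 / 10000^(2i / d_model) for each pair of dimensions
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
    )
    table = torch.zeros(max_len, d_model)
    table[:, 0::2] = torch.sin(position * div_term)  # even dimensions
    table[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
    return table

table = sinusoidal_table(max_len=50, d_model=128)
print(table.shape)
```

In the paper, a table like this is added to the word embeddings so the model can tell positions apart.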

Build our encoders and decoders, and then combine them into our transformer model!

In [15]:
encoder = TransformerEncoder(
    layers_count=layers_count,
    d_model=d_model,
    heads_count=heads_count,
    d_ff=d_ff,
    dropout_prob=dropout_prob,
    embedding=source_embedding
)

decoder = TransformerDecoder(
    layers_count=layers_count,
    d_model=d_model,
    heads_count=heads_count,
    d_ff=d_ff,
    dropout_prob=dropout_prob,
    embedding=target_embedding
)

model = Transformer(encoder, decoder)

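We use a label smoothing loss, which moves a little of the target probability from the correct word onto the other words so the model isn't pushed toward overconfident predictions, and a metric that counts the number of correctly predicted words, leaving out our pad index.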

In [16]:
from helpers.loss import LabelSmoothingLoss
from helpers.Accuracy import AccuracyMetric

loss_fn = LabelSmoothingLoss(label_smoothing, vocabulary_size=target_dictionary.vocabulary_size)
accuracy_metric = AccuracyMetric()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # Adam optimizer (very famous ML optimizer)
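
The helper classes aren't shown in this notebook, so here is a rough sketch of the two ideas on toy data. The function names are hypothetical and the details (how the smoothing mass is shared, using KL divergence as the loss) are my assumptions, not necessarily what `LabelSmoothingLoss` and `AccuracyMetric` do. For brevity the sketch does not exclude pad positions from the loss.

```python
import torch
import torch.nn.functional as F

def smoothed_targets(targets, vocabulary_size, smoothing):
    # Put (1 - smoothing) on the true word and share `smoothing` among the others.
    dist = torch.full((targets.size(0), vocabulary_size), smoothing / (vocabulary_size - 1))
    dist.scatter_(1, targets.unsqueeze(1), 1.0 - smoothing)
    return dist

def non_pad_accuracy(logits, targets, pad_index=0):
    # Fraction of correct predictions, skipping positions whose target is padding.
    predictions = logits.argmax(dim=-1)
    keep = targets != pad_index
    return (predictions[keep] == targets[keep]).float().mean().item()

targets = torch.tensor([2, 1, 0])  # the last position is padding
logits = torch.tensor([[0.1, 0.2, 3.0, 0.0],
                       [0.0, 0.5, 2.0, 0.1],
                       [1.0, 0.0, 0.0, 0.0]])

dist = smoothed_targets(targets, vocabulary_size=4, smoothing=0.1)
loss = F.kl_div(F.log_softmax(logits, dim=-1), dist, reduction='batchmean')
accuracy = non_pad_accuracy(logits, targets)
print(dist[0], loss.item(), accuracy)
```

Here the first word is predicted correctly, the second is not, and the padded third position is ignored by the accuracy.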

The `Trainer` class runs the training and validation loops for us.

In [17]:
from helpers.Trainer import Trainer


training_config = {
    "print_every": 1,
    "save_every": 1,
    "device": device,
    "clip_grads": clip_grads  # Use the constant we defined above
}

trainer = Trainer(
    model=model,
    train_dataloader = train_loader,
    val_dataloader = valid_loader,
    loss_function = loss_fn, 
    metric_function = accuracy_metric,
    optimizer=optimizer,
    logger=None, 
    run_name='tutorial', 
    save_config=None, 
    save_checkpoint=f'{output_dir}/model.ckpt', 
    config = training_config
) 
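
The `Trainer` itself is in the helper files. As a rough idea of what a single training step with gradient clipping looks like in PyTorch, here is a sketch on a toy linear model standing in for the transformer. The `max_norm=1.0` value and the `toy_` names are assumptions for illustration only, not taken from the helper.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy_model = nn.Linear(4, 3)  # stand-in for the transformer
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=0.001)
toy_loss_fn = nn.CrossEntropyLoss()

features = torch.randn(8, 4)
labels = torch.randint(0, 3, (8,))

toy_model.train()                # enable training behavior such as dropout
toy_optimizer.zero_grad()        # clear gradients from the previous step
loss = toy_loss_fn(toy_model(features), labels)
loss.backward()                  # compute gradients

# Rescale the gradients in place if their overall norm exceeds max_norm.
torch.nn.utils.clip_grad_norm_(toy_model.parameters(), max_norm=1.0)
clipped_norm = torch.norm(torch.stack([p.grad.norm() for p in toy_model.parameters()]))

toy_optimizer.step()             # update the weights
print(loss.item(), clipped_norm.item())
```

Clipping only changes the size of the update, not its direction, which is what keeps an occasional huge gradient from derailing training.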

In [18]:
# trainer.run(num_epochs)

### Part 4: Testing the Model
We still have our test data, so let's see the model's predictions on some of the test set! To save time, I ran the training loop myself earlier for 10 epochs.

In [19]:
from helpers.beam import Beam

In [20]:
class Predictor:
    def __init__(self, preprocess, postprocess, model, checkpoint_filepath, max_length=30, beam_size=8):
        self.preprocess = preprocess
        self.postprocess = postprocess
        self.model = model
        self.max_length = max_length
        self.beam_size = beam_size

        self.model.eval()
        checkpoint = torch.load(checkpoint_filepath, map_location='cpu')
        self.model.load_state_dict(checkpoint)

    def predict_one(self, source, num_candidates=5):
        source_preprocessed = self.preprocess(source)
        source_tensor = torch.tensor(source_preprocessed).unsqueeze(0)  # unsqueeze adds a batch dimension: (1, seq_len)
        length_tensor = torch.tensor(len(source_preprocessed)).unsqueeze(0)

        sources_mask = pad_masking(source_tensor, source_tensor.size(1))
        memory_mask = pad_masking(source_tensor, 1)
        memory = self.model.encoder(source_tensor, sources_mask)

        decoder_state = self.model.decoder.init_decoder_state()

        # Repeat beam_size times
        memory_beam = memory.detach().repeat(self.beam_size, 1, 1)  # (beam_size, seq_len, hidden_size)

        beam = Beam(beam_size=self.beam_size, min_length=0, n_top=num_candidates, ranker=None)

        for _ in range(self.max_length):

            new_inputs = beam.get_current_state().unsqueeze(1)  # (beam_size, seq_len=1)
            decoder_outputs, decoder_state = self.model.decoder(new_inputs, memory_beam,
                                                                            memory_mask,
                                                                            state=decoder_state)

            attention = self.model.decoder.decoder_layers[-1].memory_attention_layer.sublayer.attention
            beam.advance(decoder_outputs.squeeze(1), attention)

            beam_current_origin = beam.get_current_origin()  # (beam_size, )
            decoder_state.beam_update(beam_current_origin)

            if beam.done():
                break

        scores, ks = beam.sort_finished(minimum=num_candidates)
        hypothesises, attentions = [], []
        for i, (times, k) in enumerate(ks[:num_candidates]):
            hypothesis, attention = beam.get_hypothesis(times, k)
            hypothesises.append(hypothesis)
            attentions.append(attention)

        self.attentions = attentions
        self.hypothesises = [[token.item() for token in h] for h in hypothesises]
        hs = [self.postprocess(h) for h in self.hypothesises]
        return list(reversed(hs))
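
The `Beam` helper hides the core bookkeeping of beam search, so here is a rough sketch of a single step on toy numbers. This is my own illustration of the general technique, not the helper's code: every hypothesis on the beam is extended by every word, the best `beam_size` candidates are kept, and we recover which hypothesis each survivor came from (the "origin" that `beam_update` uses to reorder the cache).

```python
import torch

beam_size, vocabulary_size = 3, 5

# Running log-probability of each hypothesis currently on the beam (made up).
beam_scores = torch.tensor([-0.5, -1.0, -2.0])

torch.manual_seed(0)
# Next-word log-probabilities, one row per hypothesis (random stand-in for the decoder).
next_log_probs = torch.log_softmax(torch.randn(beam_size, vocabulary_size), dim=-1)

# Score of every (hypothesis, next word) pair, flattened into one long vector.
candidate_scores = (beam_scores.unsqueeze(1) + next_log_probs).view(-1)
top_scores, top_ids = candidate_scores.topk(beam_size)

# Undo the flattening: which old hypothesis is extended, and with which word.
origin = torch.div(top_ids, vocabulary_size, rounding_mode='floor')
word = top_ids % vocabulary_size

print(origin, word, top_scores)
```

Repeating this step until every surviving hypothesis has produced an end-of-sentence token (or `max_length` is reached) gives the candidate translations.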

In [21]:
predictor = Predictor(
    preprocess=IndexedInputTargetTranslationDataset.preprocess(source_dictionary),
    postprocess=lambda x: ' '.join([token for token in target_dictionary.tokenify_indexes(x) if token != '<EndSent>']),
    model=model,
    checkpoint_filepath='models/model.ckpt'
)

In [22]:
source = "There is an imbalance here ."

In [24]:
!python helpers/source/predict.py --source={source} --config=helpers/source/chekpoints/

There is an imbalance here .
