<a href="https://colab.research.google.com/github/hbprosper/AIMS/blob/main/Labs/10.Transformer/code/tutorial_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Tutorial: Transformer Neural Networks (TNN)
> Created Aug. 2024 for the FSU Course: *Machine Learning in Physics* <br>
> Update Nov 1 2025: Improve explanations<br>
> Update Nov 2 2025: Restructure code<br>
> H. B. Prosper<br>

Based on project by former FSU student Alex Judge<br>
Florida State University, Spring 2023 (closely follows the Annotated Transformer[1])<br>
Updated: July 4, 2023 for Terascale 2023, DESY, Hamburg, Germany<br>
Updated: March 31, 2024 HBP: load all data onto computational device<br>
Updated: November 19, 2024 HBP: for *Machine Learning in Physics course*<br>
Updated: October 27, 2025 HBP: for *Machine Learning in Physics course*

## Introduction

This tutorial describes a sequence to sequence (**seq2seq**) neural network, called the **transformer**[1], which translates one sequence of tokens to another. The tutorial follows closely the excellent description in the Annotated Transformer[2].

In natural language translation, a word, for example $\texttt{the}$, or part of a word, for example $\texttt{ly}$, could be a token. In symbolic mathematics, a token might be a mathematical function, e.g., $\texttt{sin}$. The set of tokens forms a **vocabulary** from which sequences of tokens are constructed. Typically, the vocabulary of the input sequences (the source) differs from that of the output sequences (the target). For example, when translating from English to French it makes sense to have a different vocabulary for each language.
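As a concrete illustration, a vocabulary can be built by collecting the distinct tokens of a set of tokenized sequences and assigning each an integer code. The sketch below is hypothetical: the tutorial's `SequenceData` class builds its vocabularies internally with its own tokenizer, which is not shown here, and the sequences below are tokenized by hand. It reserves the codes 0, 1, 2 for `<pad>`, `<sos>`, `<eos>` and sorts the remaining tokens, an ordering that appears consistent with the vocabularies printed later in this notebook.

```python
from typing import Dict, Iterable, List

# special tokens come first, so they receive the codes 0, 1, 2
SPECIALS = ["<pad>", "<sos>", "<eos>"]

def build_vocab(token_sequences: Iterable[List[str]]) -> Dict[str, int]:
    """Assign an integer code to every distinct token (hypothetical helper)."""
    tokens = sorted({tok for seq in token_sequences for tok in seq} - set(SPECIALS))
    return {tok: code for code, tok in enumerate(SPECIALS + tokens)}

# hand-tokenized sources, e.g. "sin(a*x)" -> sin ( a * x )
sources = [["sin", "(", "a", "*", "x", ")"],
           ["exp", "(", "b", "*", "x", ")"]]
src_vocab = build_vocab(sources)
print(src_vocab)
# {'<pad>': 0, '<sos>': 1, '<eos>': 2, '(': 3, ')': 4, '*': 5, 'a': 6, 'b': 7, 'exp': 8, 'sin': 9, 'x': 10}
```

A separate call on the tokenized targets would give the target vocabulary, which in general contains different tokens.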

The seq2seq model consists of three parts:

  1. The embedding layers, which encode the tokens and their positions within sequences. An input (i.e., source) sequence of tokens is thus mapped to a point cloud in a vector space.
  1. The transformer layers[1], which implement the syntactic and semantic analysis.
  1. The output layer, which computes a weight for every token in the output vocabulary. The weights are converted to probabilistic predictions for the next token in the output sequence, given the input sequence and the current output sequence.

__Tensor Convention__
We follow the convention used in the Annotated Transformer[2] in which the batch is the first dimension in all tensors.

## Sequence to Sequence Model

### Introduction
A transformer-based seq2seq model comprises an **encoder** and a **decoder**. The encoder embeds every token of the source sequence $\boldsymbol{x}$, together with its ordinal value (position), in a vector space. The vectors are processed by a chain of algorithms called **attention**, and the transformed vectors, together with the current target sequence $\boldsymbol{t}$ (during training) or the current predicted output sequence $\boldsymbol{y}$ (during prediction), are sent to the decoder, which embeds the target tokens in the same vector space. The target vectors are likewise processed by a chain of attention algorithms, and the target vectors and those from the encoder are processed together by another attention algorithm. Finally, the decoder assigns a weight to every token in the target vocabulary. Using a greedy strategy, one chooses the next output token to be the one with the largest weight, that is, the most probable token. The model is **autoregressive**: the predicted token is appended to the existing predicted output sequence and the model is called again with the same source and the updated output. The procedure repeats until either the maximum output sequence length is reached or the end-of-sequence (EOS) token is predicted as the most probable token.
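The greedy, autoregressive loop just described can be sketched in plain Python. This is an illustration, not the tutorial's implementation: `greedy_decode` and the toy `copy_model` (which stands in for a trained transformer) are hypothetical names, and the codes 1 and 2 for `<sos>` and `<eos>` are taken from the vocabularies printed later in this notebook.

```python
from typing import Callable, List

SOS, EOS = 1, 2   # codes of <sos> and <eos> in this tutorial's vocabularies

def greedy_decode(next_token_weights: Callable[[List[int], List[int]], List[float]],
                  src: List[int], max_len: int) -> List[int]:
    """Greedy autoregressive decoding (sketch).

    next_token_weights(src, out) stands in for the trained model: it returns one
    weight per target-vocabulary token, given the source and the output so far.
    max_len counts the <sos> token.
    """
    out = [SOS]                                   # start with <sos>
    while len(out) < max_len:
        weights = next_token_weights(src, out)    # call the model again
        # greedy choice: index of the largest weight
        next_tok = max(range(len(weights)), key=weights.__getitem__)
        out.append(next_tok)                      # extend the output sequence
        if next_tok == EOS:                       # stop at <eos>
            break
    return out

def copy_model(src, out, vocab_size=5):
    """Toy 'model' that puts all its weight on the source token at position len(out)."""
    weights = [0.0] * vocab_size
    weights[src[len(out)] if len(out) < len(src) else EOS] = 1.0
    return weights

print(greedy_decode(copy_model, [1, 3, 4, 2], max_len=10))   # [1, 3, 4, 2]
```

With `max_len=3` the same call stops at `[1, 3, 4]`, illustrating the second stopping condition.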


### Attention

When we translate from one sequence of symbols to another sequence of symbols, for example from one natural language to another, the meaning of the sequences is encoded in the symbols, their relative order, and the degree to which a given symbol is related to the other symbols. Consider the phrases "the white house" and "la maison blanche". In order to obtain a correct translation it is important for the model to encode the fact that "la" and "maison" are strongly related, while "the" and "house" are less so. It is also important for the model to encode the strong relationship between "the" and "la", between "house" and "maison", and between "white" and "blanche". That is, the model needs to *pay attention to* grammatical and semantic facts. At least as far as we can tell, that is what humans do.

The need for the model to pay attention to relevant linguistic facts is the basis of the so-called [attention mechanism](https://nlp.seas.harvard.edu/annotated-transformer/). In the encoding stage, the model associates with every token a vector that tries to capture the strength of that token's relationship to the other tokens. Since this association operates within a single sequence (that is, within the same point cloud in the vector space in which the sequence is embedded), it is referred to as **self attention**. Ideally, self attention will capture the strong coupling between "la" and "maison", as well as the strong coupling between the relative positions of "maison" and "blanche" and between those of "white" and "house". In the decoding stage, in addition to self attention over the target sequence, another attention mechanism should capture the strong coupling between "the" and "la", "house" and "maison", and "white" and "blanche". At a minimum, therefore, we expect a successful seq2seq model to model self attention in both the encoding and decoding phases and source-to-target attention in the decoding phase. The optimal way to implement this is not known, but the transformer implements an attention mechanism, described below, that empirically appears to be highly effective.


### Prediction
As noted above, a transformer is used *autoregressively*. Consider a source (i.e., input) sequence $\boldsymbol{x} = x_0, x_1,\cdots, x_k, x_{k+1}$ of length $k+2$ tokens, where $x_0 \equiv \text{<sos>}$ denotes the **start of sequence** token and $x_{k+1} \equiv \text{<eos>}$ denotes the **end of sequence** token, and the current output sequence $\boldsymbol{y}_{\lt l} = y_0, y_1,\cdots, y_{l-1}$ of length $l$ tokens. For every predicted target sub-sequence $\boldsymbol{y}_{\lt l}$, the model approximates a discrete conditional probability distribution
\begin{align}
p_{l} & \equiv p(t_l \in v_t| \boldsymbol{x}, \boldsymbol{y}_{\lt l}),
\end{align}
over the target vocabulary $v_t = \text{<sos>}, \text{<eos>}, v_1, \cdots, v_{m}$, where $m$ is the size of the vocabulary excluding the start and end tokens. This probability distribution is used to pick the next token $y_l$, which is appended to the current predicted output sequence $\boldsymbol{y}_{\lt l}$; the procedure is repeated until $y_l = \text{<eos>}$ or the maximum allowed output sequence length is reached.

For a vocabulary of size $m$ and a sequence of length $k$ (omitting the delimiters), every position in the sequence can be filled in $m$ ways. Therefore, there are $m^k$ possible sequences, of which we want the most probable. Alas, we have a bit of a computational problem. For example, for a sequence of length $k=85$ tokens and a target vocabulary of size $m = 28$ tokens, there are $28^{85} \sim 10^{123}$ possible sequences. Even at a trillion ($10^{12}$) probability calculations per second, an exhaustive search would take $\sim 10^{111}$ s, vastly longer than the current age of the universe ($\sim 4 \times 10^{17}$ s), so it would be an utterly futile undertaking. Obviously, we have no choice but to use a **heuristic strategy**.
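The numbers in this estimate are easy to check. Working with base-10 logarithms avoids manipulating a 124-digit integer; the rate of $10^{12}$ evaluations per second is the illustrative figure used in the text.

```python
import math

m, k = 28, 85                      # target vocabulary size, sequence length
log10_n = k * math.log10(m)        # log10 of m**k
rate = 1e12                        # probability calculations per second
age_universe = 4e17                # seconds (approximate)

log10_seconds = log10_n - math.log10(rate)
print(f"m**k ~ 10^{log10_n:.0f} possible sequences")
print(f"exhaustive search ~ 10^{log10_seconds:.0f} s,"
      f" ~ 10^{log10_seconds - math.log10(age_universe):.0f} times the age of the universe")
# m**k ~ 10^123 possible sequences
# exhaustive search ~ 10^111 s, ~ 10^93 times the age of the universe
```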

The simplest such strategy is the **greedy search** in which we choose the most probable token as the next token at position $l$.
A potentially better strategy is **beam search** in which at each prediction stage we keep track of the $n$ most probable sequences so far. At the end we pick the most probable sequence among the $n$.
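The difference between the two strategies can be seen on a toy problem. The sketch below is a simple beam search over a hand-made next-token distribution; `toy_log_probs` stands in for the trained model, and all names are hypothetical. With `n_beams=1` it reduces to greedy search, which commits to the locally more probable first token and ends with a less probable sequence; with `n_beams=2` it finds the most probable sequence.

```python
import math
from typing import Callable, List, Tuple

SOS, EOS = 1, 2   # codes of <sos> and <eos>

def beam_search(log_probs: Callable[[List[int]], List[float]],
                n_beams: int, max_len: int) -> Tuple[List[int], float]:
    """Simple beam search (sketch).

    log_probs(seq) stands in for the model: it returns log p(next token | seq)
    for every token in the target vocabulary. A sequence's score is the sum of
    the log-probabilities of its tokens, i.e. the log of its probability.
    """
    beams = [([SOS], 0.0)]          # active (sequence, score) pairs
    finished = []                   # sequences that ended with <eos>
    for _ in range(max_len - 1):
        # extend every active sequence by every possible next token
        candidates = [(seq + [tok], score + lp)
                      for seq, score in beams
                      for tok, lp in enumerate(log_probs(seq))]
        # keep only the n_beams most probable extensions
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates[:n_beams]:
            if seq[-1] == EOS:
                finished.append((seq, score))
            else:
                beams.append((seq, score))
        if not beams:               # every kept sequence has ended
            break
    return max(finished + beams, key=lambda c: c[1])

def toy_log_probs(seq):
    """Toy model. Vocabulary: 0 <pad>, 1 <sos>, 2 <eos>, 3 'A', 4 'B'."""
    table = {(1,):   [0.0, 0.0, 0.0, 0.6, 0.4],
             (1, 3): [0.0, 0.0, 0.4, 0.3, 0.3],
             (1, 4): [0.0, 0.0, 0.9, 0.05, 0.05]}
    probs = table.get(tuple(seq), [0.0, 0.0, 1.0, 0.0, 0.0])  # otherwise: end
    return [math.log(p) if p > 0 else -math.inf for p in probs]

seq, score = beam_search(toy_log_probs, n_beams=1, max_len=4)
print("greedy:", seq, round(math.exp(score), 2))   # greedy: [1, 3, 2] 0.24
seq, score = beam_search(toy_log_probs, n_beams=2, max_len=4)
print("beam:", seq, round(math.exp(score), 2))     # beam: [1, 4, 2] 0.36
```

This simple version lets the beam shrink as sequences finish and does not normalize scores by sequence length; practical implementations usually refill the beam and apply some length normalization, since otherwise shorter sequences tend to be favoured.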


### References
  1.  [Attention is all you need](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
  1. [Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/)

## Installation of `mlinphysics`

### Local installation
  ```bash
      git clone https://github.com/hbprosper/mlinphysics
      cd mlinphysics
      pip install -e .
  ```

## Running on Google Colab
If you are running on Google Colab (https://colab.research.google.com), execute the cell below.

In [1]:
try:
    import google.colab
    url = "https://raw.githubusercontent.com/hbprosper/mlinphysics/refs/heads/main"
    !wget -q {url}/clone2colab.ipynb -O clone2colab.ipynb
    %run clone2colab.ipynb
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

/content
	1. uninstall mlinphysics
	2. sparse clone mlinphysics

Cloning into 'mlinphysics'...
remote: Enumerating objects: 46, done.
remote: Counting objects: 100% (46/46), done.
remote: Compressing objects: 100% (35/35), done.
Receiving objects: 100% (46/46), 5.41 KiB | 5.41 MiB/s, done.
remote: Total 46 (delta 0), reused 38 (delta 0), pack-reused 0 (from 0)
remote: Enumerating objects: 5, done.
remote: Counting objects: 100% (5/5), done.
remote: Compressing objects: 100% (5/5), done.
Receiving objects: 100% (5/5), 4.16 KiB | 4.16 MiB/s, done.
remote: Total 5 (delta 0), reused 1 (delta 0), pack-reused 0 (from 0)
/content/mlinphysics
remote: Enumerating objects: 9, done.
remote: Counting objects: 100% (9/9), done.
remote: Compressing objects: 100% (7/7), done.
remote: Total 9 (delta 0), reused 4 (delta 0), pack-reused 0 (from 0)
Receiving objects: 100% (9/9), 27.17 KiB | 3.88 MiB/s, done.
Already on 'main'
Your branch is up to date with 'origin/main

In [2]:
import os, sys
import numpy as np
import importlib
import shutil
import random
import matplotlib as mp
import matplotlib.pyplot as plt
from tqdm import tqdm

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F

# ML in physics module
import mlinphysics.nn as mlp
import mlinphysics.utils.data as dat
import mlinphysics.utils.monitor as mon
import mlinphysics.utils.tutorials as tut
import mlinphysics.utils.transformer as tnm

# update fonts

plt.rcParams.update({
  "text.usetex": shutil.which('latex') is not None,
  "font.family": "sans-serif",
  "font.sans-serif": "Helvetica",
  "font.size": 14
  })

## Computational device

In [3]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'\nComputational device: {str(DEVICE):s}')


Computational device: cpu


In [4]:
# SEED = 42
# random.seed(SEED)
# np.random.seed(SEED)
# torch.manual_seed(SEED)
# torch.cuda.manual_seed(SEED)
# torch.backends.cudnn.deterministic = True

## Constants

In [12]:
short_tutorial = True

if short_tutorial:

    if IN_COLAB:
        DATAFILE = 'seq2seq_series_2terms.txt'
        dat.download(DATAFILE)
    else:
        DATAFILE = '../data/seq2seq_series_2terms.txt'

    MAX_SEQ_LEN= 85

    # model hyperparameters
    ENC_EMB_DIM= 64    # dimension of embedding vector space
    ENC_LAYERS = 2     # number of encoder layers
    ENC_HEADS  = 8     # number of attention heads
    ENC_FF_DIM = 128   # "hidden" dimension of feed-forward network
    ENC_DROPOUT= 0.1

    DEC_EMB_DIM= 64    # dimension of embedding vector space
    DEC_LAYERS = 2     # number of decoder layers
    DEC_HEADS  = 8     # number of decoder heads
    DEC_FF_DIM = 128
    DEC_DROPOUT= 0.1

    # training hyperparameters
    BATCH_SIZE    = 32
    LEARNING_RATE = 2e-4
    NITERATIONS   = 400_000
    STEP          =    100

else:
    DATAFILE = '../data/seq2seq_series.txt'
    MAX_SEQ_LEN= 200

    # model hyperparameters
    ENC_EMB_DIM= 256   # dimension of embedding vector space
    ENC_LAYERS = 4     # number of encoder layers
    ENC_HEADS  = 8
    ENC_FF_DIM = 1024  # "hidden" dimension of feed-forward network
    ENC_DROPOUT= 0.1

    DEC_EMB_DIM= 256   # dimension of embedding vector space
    DEC_LAYERS = 4
    DEC_HEADS  = 8
    DEC_FF_DIM = 1024
    DEC_DROPOUT= 0.1

    BATCH_SIZE    = 64
    DROPOUT       = 0.1
    LEARNING_RATE = 2e-4
    NITERATIONS   = 400_000
    STEP          =    100

## Read Sequence Data

The file **seq2seq_series_2terms.txt** contains (source, target) pairs where the targets are the Taylor series expansions of the corresponding sources up to an error term of ${\cal O}(x^6)$ and the sources are functions built from one or two terms randomly sampled from the set `{exp, sin, cos, tan, sinh, cosh, tanh}`. Since the source sequences are reasonably simple functions, it is possible to train a transformer model to predict their Taylor series expansions in under an hour on a GPU. The more complicated functions in the file **seq2seq_series.txt** require more time.

In [13]:
importlib.reload(tnm)

seqdata = tnm.SequenceData(DATAFILE, max_seq_len=MAX_SEQ_LEN)

	reading sequences

	sample size: 14367

0


cosh(a*x)**3 + tanh(b*x)

1 + b*x - b**3*x**3/3 + 2*b**5*x**5/15 + 3*a**2*x**2/2 + 7*a**4*x**4/8 + O(x**6)


4500


exp(a*x)*cosh(h*x)**2

1 + x**2*(a**2/2 + h**2) + x**3*(a**3/6 + a*h**2) + x**4*(a**4/24 + a**2*h**2/2 + h**4/3) + x**5*(a**5/120 + a**3*h**2/6 + a*h**4/3) + a*x + O(x**6)


9000


tan(d*x)/tanh(c*x)

d/c + x**2*(c*d/3 + d**3/(3*c)) + x**4*(-c**3*d/45 + c*d**3/9 + 2*d**5/(15*c)) + O(x**6)


13500


sinh(c*x) - cosh(m*x)

-1 - m**2*x**2/2 - m**4*x**4/24 + c*x + c**3*x**3/6 + c**5*x**5/120 + O(x**6)


	building source vocabulary
{'<pad>': 0, '<sos>': 1, '<eos>': 2, '(': 3, ')': 4, '*': 5, '**': 6, '+': 7, '-': 8, '/': 9, '0': 10, '1': 11, '2': 12, '3': 13, '4': 14, '5': 15, '6': 16, '7': 17, '8': 18, '9': 19, 'a': 20, 'b': 21, 'c': 22, 'cos': 23, 'cosh': 24, 'd': 25, 'exp': 26, 'f': 27, 'g': 28, 'h': 29, 'm': 30, 'n': 31, 'sin': 32, 'sinh': 33, 'tan': 34, 'tanh': 35, 'x': 36}

	building target vocabulary
{'<pad>': 0, '<sos>': 1, '<eos>': 2, '(': 3, ')': 4, '*': 5, '**': 6, '+': 7, '-': 8, '/': 9, '0': 10, '1': 11, '2': 12, '3': 13, '4': 14, '5': 15, '6': 16, '7': 17, '8': 18, '9': 19, 'O(x**6)': 20, 'a': 21, 'b': 22, 'c': 23, 'd': 24, 'f': 25, 'g': 26, 'h': 27, 'm': 28, 'x': 29}

	tokenize sources
 14000
	tokenize targets
 14000
	pad sequences and bracket with <sos> and <eos>

Summary
 sample size: 11850
   source sequence length:       22
   source vocabulary size:       37

   target sequence length:       85
   target vocabulary size:       30



In [None]:
seqdata.sources[0], seqdata.targets[0]

(array([ 1, 24,  3, 20,  5, 36,  4,  6, 13,  7, 35,  3, 21,  5, 36,  4,  0,
         0,  0,  0,  0,  2]),
 array([ 1, 11,  7, 22,  5, 29,  7, 13,  5, 21,  6, 12,  5, 29,  6, 12,  9,
        12,  8, 22,  6, 13,  5, 29,  6, 13,  9, 13,  7, 17,  5, 21,  6, 14,
         5, 29,  6, 14,  9, 18,  7, 12,  5, 22,  6, 15,  5, 29,  6, 15,  9,
        11, 15,  7, 20,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2]))
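Because each position holds an integer code, it helps to map codes back to tokens when inspecting the data. The following sketch inverts the two vocabularies printed above by listing their tokens in code order, then decodes the first source sequence; `SRC_TOKENS`, `TRG_TOKENS`, and `decode` are hypothetical names, not part of `mlinphysics`. Note that each digit is a separate token, so a denominator such as 15 is the two tokens `1` and `5`.

```python
import string

# token lists copied from the source and target vocabularies printed above
# (position in the list = integer code)
SRC_TOKENS = (["<pad>", "<sos>", "<eos>", "(", ")", "*", "**", "+", "-", "/"]
              + list(string.digits)
              + ["a", "b", "c", "cos", "cosh", "d", "exp", "f", "g", "h",
                 "m", "n", "sin", "sinh", "tan", "tanh", "x"])
TRG_TOKENS = (["<pad>", "<sos>", "<eos>", "(", ")", "*", "**", "+", "-", "/"]
              + list(string.digits)
              + ["O(x**6)", "a", "b", "c", "d", "f", "g", "h", "m", "x"])

SPECIAL = {0, 1, 2}  # codes of <pad>, <sos>, <eos>

def decode(codes, tokens):
    """Convert a sequence of integer codes back to a string, dropping special tokens."""
    return "".join(tokens[c] for c in codes if c not in SPECIAL)

src = [1, 24, 3, 20, 5, 36, 4, 6, 13, 7, 35, 3, 21, 5, 36, 4, 0, 0, 0, 0, 0, 2]
print(decode(src, SRC_TOKENS))   # cosh(a*x)**3+tanh(b*x)
```

Applied to the target array above, the same function gives `1+b*x+3*a**2*x**2/2-b**3*x**3/3+7*a**4*x**4/8+2*b**5*x**5/15+O(x**6)`: the terms are ordered by powers of $x$, unlike the order in which the expansion was printed earlier. In the notebook, `decode(seqdata.targets[0], TRG_TOKENS)` should work the same way, since NumPy integers can index a Python list.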

## Configuration

In [None]:
ndata = len(seqdata.sources)
train_size = 11_000
test_size  =    750
val_size = ndata - train_size - test_size
val_size

100

In [None]:
importlib.reload(mlp)

# name of model
# -----------------------------------------
name = 'TNN'

# choose whether to create or load a configuration file
load_existing_config = False

if load_existing_config:
    config = mlp.Config(f'{name}.yaml')
else:
    # create new configuration
    config = mlp.Config(name, dirname='test')

    # ----------------------------------------
    # training configuration
    # ----------------------------------------
    config('train_size',  train_size)   # training dataset size
    config('val_size',    val_size)
    config('test_size',   test_size)
    config('batch_size',  BATCH_SIZE)   # number of sequences per batch
    config('monitor_step',STEP)         # monitor training every STEP (=100) iterations
    config('delete', True)              # if True recreate losses file before training
    config('frac', 0.01)                # save model if average loss decreases by...
                                        # ...more than a fraction "frac"
    # ----------------------------------------
    # optimizer / scheduler configuration
    # ----------------------------------------
    # a step comprises a given number of iterations
    config('n_steps', 1)                # number of training steps
    config('n_iterations', NITERATIONS)
    config('n_iters_per_step', int(config('n_iterations') / config('n_steps')))
    config('base_lr', LEARNING_RATE)    # initial learning rate
    config('gamma', 0.5)                # learning rate scale factor

    # ----------------------------------------
    # data
    # ----------------------------------------
    config('DATAFILE',    DATAFILE)
    config('MAX_SEQ_LEN', MAX_SEQ_LEN)

    # ----------------------------------------
    # model specification
    # ----------------------------------------
    # ENCODER
    config('ENC_EMB_DIM', ENC_EMB_DIM)  # dimension of embedding vector space
    config('ENC_LAYERS',  ENC_LAYERS)   # number of encoder layers
    config('ENC_HEADS',   ENC_HEADS)    # number of attention heads
    config('ENC_FF_DIM',  ENC_FF_DIM)   # "hidden" dimension of feed-forward network
    config('ENC_DROPOUT', ENC_DROPOUT)

    # DECODER
    config('DEC_EMB_DIM', DEC_EMB_DIM)  # dimension of embedding vector space
    config('DEC_LAYERS',  DEC_LAYERS)   # number of decoder layers
    config('DEC_HEADS',   DEC_HEADS)    # number of attention heads
    config('DEC_FF_DIM',  DEC_FF_DIM)   # "hidden" dimension of feed-forward network
    config('DEC_DROPOUT', DEC_DROPOUT)


config('SRC_SEQ_LEN',      seqdata.SRC_SEQ_LEN)
config('SRC_VOCAB_SIZE',   seqdata.SRC_VOCAB_SIZE)

config('TRG_SEQ_LEN',      seqdata.TRG_SEQ_LEN)
config('TRG_VOCAB_SIZE',   seqdata.TRG_VOCAB_SIZE)

config('PAD_CODE',    seqdata.PAD)
config('SOS_CODE',    seqdata.SOS)
config('EOS_CODE',    seqdata.EOS)

print('\n\tconfiguration\n')
print(config)

print(f'\nSave configuration to file {config.cfg_filename}\n')

config.save()


	configuration

name: TNN
file:
  losses: runs/test/TNN_losses.csv
  params: runs/test/TNN_params.pth
  init_params: runs/test/TNN_init_params.pth
  plots: runs/test/TNN_plots.png
train_size: 11000
val_size: 100
test_size: 750
batch_size: 32
monitor_step: 100
delete: true
frac: 0.01
n_steps: 1
n_iterations: 400000
n_iters_per_step: 400000
base_lr: 0.0002
gamma: 0.5
DATAFILE: ../data/seq2seq_series_2terms.txt
MAX_SEQ_LEN: 85
ENC_EMB_DIM: 64
ENC_LAYERS: 2
ENC_HEADS: 8
ENC_FF_DIM: 128
ENC_DROPOUT: 0.1
DEC_EMB_DIM: 64
DEC_LAYERS: 2
DEC_HEADS: 8
DEC_FF_DIM: 128
DEC_DROPOUT: 0.1
SRC_SEQ_LEN: 22
SRC_VOCAB_SIZE: 37
TRG_SEQ_LEN: 85
TRG_VOCAB_SIZE: 30
PAD_CODE: 0
SOS_CODE: 1
EOS_CODE: 2


Save configuration to file runs/test/TNN_config.yaml



## Datasets

In [None]:
importlib.reload(dat)

train_size = config('train_size')
val_size   = config('val_size')
test_size  = config('test_size')

# training dataset (this defines the empirical risk to be minimized)
print('training data')
train_data = dat.Dataset(seqdata.sources,
                         start=0,
                         end=train_size,
                         targets=seqdata.targets)

# a random subset of the training data to check for overtraining
# by comparing with the empirical risk from the validation set
print('training data for validation')
train_data_val = dat.Dataset(seqdata.sources,
                             start=0,
                             end=train_size,
                             targets=seqdata.targets,
                             random_sample_size=val_size)

# validation dataset (for monitoring training)
print('validation data')
val_data = dat.Dataset(seqdata.sources,
                       start=train_size,
                       end=train_size + val_size,
                       targets=seqdata.targets)

# test dataset
print('test data')
test_data= dat.Dataset(seqdata.sources,
                       start=train_size + val_size,
                       end=train_size + val_size + test_size,
                       targets=seqdata.targets)

training data
Dataset
  shape of x: torch.Size([11000, 22])
  shape of y: torch.Size([11000, 85])

training data for validation
Dataset
  shape of x: torch.Size([100, 22])
  shape of y: torch.Size([100, 85])

validation data
Dataset
  shape of x: torch.Size([100, 22])
  shape of y: torch.Size([100, 85])

test data
Dataset
  shape of x: torch.Size([750, 22])
  shape of y: torch.Size([750, 85])



In [None]:
def printme(text, data, ii):
    print(text)
    print(' source:', data[ii][0])
    print(' target:', data[ii][1])
    print()

ii = 0
printme('TRAIN_DATA', train_data, ii)
printme('TRAIN_DATA_VAL', train_data_val, ii)
printme('VAL_DATA', val_data, ii)
printme('TEST_DATA', test_data, ii)

TRAIN_DATA
 source: tensor([ 1, 24,  3, 20,  5, 36,  4,  6, 13,  7, 35,  3, 21,  5, 36,  4,  0,  0,
         0,  0,  0,  2])
 target: tensor([ 1, 11,  7, 22,  5, 29,  7, 13,  5, 21,  6, 12,  5, 29,  6, 12,  9, 12,
         8, 22,  6, 13,  5, 29,  6, 13,  9, 13,  7, 17,  5, 21,  6, 14,  5, 29,
         6, 14,  9, 18,  7, 12,  5, 22,  6, 15,  5, 29,  6, 15,  9, 11, 15,  7,
        20,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2])

TRAIN_DATA_VAL
 source: tensor([ 1, 32,  3, 21,  5, 36,  4,  5, 23,  3, 20,  5, 36,  4,  0,  0,  0,  0,
         0,  0,  0,  2])
 target: tensor([ 1, 22,  5, 29,  7, 29,  6, 13,  5,  3,  8, 21,  6, 12,  5, 22,  9, 12,
         8, 22,  6, 13,  9, 16,  4,  7, 29,  6, 15,  5,  3, 21,  6, 14,  5, 22,
         9, 12, 14,  7, 21,  6, 12,  5, 22,  6, 13,  9, 11, 12,  7, 22,  6, 15,
         9, 11, 12, 10,  4,  7, 20,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0

## DataLoaders

In [None]:
importlib.reload(dat)

print('train data loader')
train_loader = dat.DataLoader(train_data,
                              batch_size=config('batch_size'),
                              num_iterations=config('n_iterations'))

print('train data loader for validation')
train_loader_val = dat.DataLoader(train_data_val,
                                  batch_size=len(train_data_val))

print('validation data loader')
val_loader = dat.DataLoader(val_data,
                            batch_size=len(val_data))

print('test data loader')
test_loader = dat.DataLoader(test_data,
                             batch_size=1)

train data loader
DataLoader
  Number of iterations has been specified
  maxiter:          400000
  batch_size:           32
  shuffle_step:        343

train data loader for validation
DataLoader
  maxiter:               1
  batch_size:          100
  shuffle_step:          1

validation data loader
DataLoader
  maxiter:               1
  batch_size:          100
  shuffle_step:          1

test data loader
DataLoader
  maxiter:             750
  batch_size:            1
  shuffle_step:        750



## The Model

The transformer comprises an **encoder** and **decoder**, each of which consists of one or more processing layers.

### Encoder

The encoder does the following:
 1. Each token in the source (input) sequence is encoded as a vector $\boldsymbol{t}$ in a space of $d =$ **emb_dim** dimensions. A sequence is therefore represented as a point cloud in the vector space.
 1. The position of each token is also encoded as a vector $\boldsymbol{p}$ in a vector space of the same dimension as $\boldsymbol{t}$. We can think of these vectors $\boldsymbol{p}$ as residing in the same vector space as the vectors $\boldsymbol{t}$.  Both the token and position embeddings are trainable.
 1. Each token is associated with a third vector, $\boldsymbol{v} = \lambda \boldsymbol{t} + \boldsymbol{p}$. In the original transformer the scale factor is fixed at $\lambda = \sqrt{d}$; in this tutorial, $\lambda$ is initialized to $\sqrt{d}$ and treated as a trainable parameter.

The vectors $\boldsymbol{v}$ are processed through $N$ *encoder layers*.

Since the source sequences are **padded** so that they all have the same length, the `<pad>` tokens must be ignored in all calculations. This is done using **masks**.
The source mask, `src_mask`, has value 0 where the source token is a `<pad>` token and 1 otherwise. There is also a target mask; in the decoder, it prevents each target position from attending to later positions, so that the decoder cannot look ahead at tokens it has yet to predict.
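A minimal sketch of how the two masks can be built in PyTorch, assuming the Annotated Transformer convention that `True` (1) means "attend" and `False` (0) means "ignore". The function names `make_src_mask` and `make_trg_mask` are hypothetical; this tutorial may construct its masks elsewhere. The shape `[batch_size, 1, 1, src_len]` matches the shape of `src_mask` documented in `Encoder.forward` below, and the pad code 0 matches `PAD_CODE` in the configuration above.

```python
import torch

PAD = 0  # code of the <pad> token (PAD_CODE in the configuration)

def make_src_mask(src: torch.Tensor) -> torch.Tensor:
    # src: [batch_size, src_len] -> mask: [batch_size, 1, 1, src_len]
    # True for real tokens, False for <pad>; the two singleton dimensions
    # let the mask broadcast over attention heads and query positions.
    return (src != PAD).unsqueeze(1).unsqueeze(2)

def make_trg_mask(trg: torch.Tensor) -> torch.Tensor:
    # trg: [batch_size, trg_len] -> mask: [batch_size, 1, trg_len, trg_len]
    trg_len = trg.shape[1]
    pad_mask = (trg != PAD).unsqueeze(1).unsqueeze(2)            # [B, 1, 1, T]
    # lower-triangular matrix: position i may attend only to positions j <= i
    causal = torch.tril(torch.ones(trg_len, trg_len,
                                   device=trg.device)).bool()    # [T, T]
    return pad_mask & causal                                    # [B, 1, T, T]

src = torch.tensor([[1, 24, 3, 20, 5, 36, 4, 0, 0, 2]])  # <sos> cosh ( a * x ) <pad> <pad> <eos>
print(make_src_mask(src).int().tolist())
# [[[[1, 1, 1, 1, 1, 1, 1, 0, 0, 1]]]]
print(make_trg_mask(torch.tensor([[1, 11, 0, 2]]))[0, 0].int().tolist())
# [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 0, 1]]
```

Row $i$ of the target mask lists the positions that position $i$ may attend to: never a later position, and never a `<pad>`.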

In [None]:
class Encoder(nn.Module):

    def __init__(self,
                 vocab_size,      # vocabulary size (of source)
                 max_len,         # maximum number of tokens per sequence
                 emb_dim,         # dimension of token embedding space
                 n_layers,        # number of encoding layers
                 n_heads,         # number of attention heads per encoding layer
                 ff_dim,          # dimension of feed-forward network
                 dropout,         # dropout probability
                 device):         # computational device

        super().__init__()

        print(f'''
    Encoder
    -------
      vocabulary size:       {vocab_size:10d}
      sequence length:       {max_len:10d}
      embedding dimension:   {emb_dim:10d}
      number of layers:      {n_layers:10d}
      number of heads:       {n_heads:10d}
      hidden dim. of FFN:    {ff_dim:10d}
        ''')

        # cache computational device
        self.device = device

        # represent each of the 'vocab_size' tokens by a vector
        # of size d = emb_dim. nn.Embedding "learns" a simple
        # lookup table that maps the code for each token in the
        # vocabulary to a vector of size emb_dim.
        self.tok_embedding = nn.Embedding(vocab_size, emb_dim)

        # represent the position of each token by a vector of
        # size d = emb_dim.
        # 'max_len' is the maximum length of a sequence.
        self.pos_embedding = nn.Embedding(max_len, emb_dim)

        # create 'n_layers' encoding layers
        self.layers = nn.ModuleList([EncoderLayer(emb_dim,
                                                  n_heads,
                                                  ff_dim,
                                                  dropout,
                                                  device)
                                     for _ in range(n_layers)])

        # during training, randomly zero elements of the input
        # tensor (activations, not weights) with probability 'dropout'.
        # dropout is thought to mitigate over-training
        self.dropout= nn.Dropout(dropout)

        # factor by which to scale token embedding vectors.
        # use nn.Parameter to tell PyTorch that this is a
        # tunable parameter.
        self.scale = nn.Parameter(torch.sqrt(torch.FloatTensor([emb_dim])))

    def forward(self, src, src_mask):
        # src      : [batch_size, src_len]         (shape of src)
        # src_mask : [batch_size, 1, 1, src_len]   (shape of src_mask)

        batch_size, src_len = src.shape

        # ---------------------------------------
        # token embedding
        # ---------------------------------------
        src = self.tok_embedding(src)
        # src: [batch_size, src_len, emb_dim]

        # ---------------------------------------
        # token position embedding
        # ---------------------------------------
        # create a row tensor, pos, with entries [0, 1,..., src_len-1]
        pos = torch.arange(0, src_len)
        # pos: [src_len]

        # 1. unsqueeze(0) inserts a dimension at position 0 (for the
        #    batch), so that pos has shape [1, src_len].
        # 2. repeat this row of integers batch_size times, once per
        #    row, so that we obtain
        #    pos = |p|
        #          |p|
        #           :
        #          |p|
        #    where p = [0, 1,..., src_len-1].
        # 3. send to the computational device.
        once_per_row = 1
        pos = pos.unsqueeze(0).repeat(batch_size, once_per_row).to(self.device)
        # pos: [batch_size, src_len]

        # the embedding maps every token ordinal value (position) to a vector
        # in the embedding space.
        pos = self.pos_embedding(pos)
        # pos: [batch_size, src_len, emb_dim]

        # linearly combine token and token position embeddings.
        # (perhaps this could be replaced by an MLP?)
        src = src * self.scale + pos
        # src: [batch_size, src_len, emb_dim]

        # it is not clear how much this helps, but let's keep it.
        src = self.dropout(src)

        # now pass embedded vectors through encoding layers.
        # Note: every token in the sequence src is processed
        # simultaneously.
        for layer in self.layers:
            src = layer(src, src_mask)
            # src: [batch_size, src_len, emb_dim]

        # return the vectors representing the processed tokens.
        # the tensor src will be fed into the decoder along with
        # the target tensor.
        return src

### Encoder Layer

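 1. Pass the source tensor and its mask to the **multi-head attention** layer (self attention).
 1. Apply dropout, a residual connection, and [Layer Normalization](https://arxiv.org/abs/1607.06450).
 1. Apply a position-wise feed-forward network, that is, the same small network applied to each token vector independently.
 1. Apply dropout, a residual connection, and layer normalization.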
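The feed-forward step applies the same small network to every token vector. The `Feedforward` class used in the code below is not defined in this part of the notebook, so the following is a sketch of the standard position-wise network from the original transformer paper (linear, ReLU, dropout, linear) under the hypothetical name `PositionwiseFFN`, together with the "add & norm" step. The tutorial's own class may differ in detail, for example in its activation function.

```python
import torch
import torch.nn as nn

class PositionwiseFFN(nn.Module):
    """Hypothetical stand-in for the tutorial's Feedforward class: the same
    two-layer MLP is applied independently to every token vector."""

    def __init__(self, emb_dim: int, ff_dim: int, dropout: float):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(emb_dim, ff_dim),   # expand to the "hidden" dimension
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, emb_dim),   # project back to the embedding dimension
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, seq_len, emb_dim]; nn.Linear acts on the last dimension
        return self.net(x)

def residual_norm(x: torch.Tensor, sublayer_out: torch.Tensor,
                  norm: nn.LayerNorm) -> torch.Tensor:
    # "add & norm": residual connection followed by layer normalization
    return norm(x + sublayer_out)

x = torch.randn(2, 22, 64)             # [batch_size, seq_len, emb_dim]
ffn = PositionwiseFFN(emb_dim=64, ff_dim=128, dropout=0.1)
out = residual_norm(x, ffn(x), nn.LayerNorm(64))
print(out.shape)                       # torch.Size([2, 22, 64])
```

Because `nn.Linear` acts only on the last dimension, the output for one token does not depend on the other tokens; all mixing between tokens happens in the attention layer.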

In [None]:
class EncoderLayer(nn.Module):

    def __init__(self,
                 emb_dim, # token embedding dimension
                 n_heads, # number of attention "heads"
                 ff_dim,  # dimension of feed-forward network
                 dropout, # dropout probability
                 device): # computational device

        super().__init__()

        self.self_attention       = MultiHeadAttention(emb_dim,
                                                       n_heads,
                                                       dropout,
                                                       device)

        self.self_attention_norm  = nn.LayerNorm(emb_dim)

        self.feedforward          = Feedforward(emb_dim, ff_dim, dropout)

        self.feedforward_norm     = nn.LayerNorm(emb_dim)

        self.dropout              = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        # src      : [batch_size, src_len, emb_dim]
        # src_mask : [batch_size, 1, 1, src_len]

        # ------------------------------------------
        # self attention over embedded source tensor
        # ------------------------------------------
        # distinguish between src and src_ as the
        # former is needed later for a residual connection.
        # the output of the self_attention layer is a tensor of
        # vectors that incorporate semantic and syntactic information
        # about the input tokens.
        #
        # Note, for self attention:
        #   Q = src
        #   K = src
        #   V = src
        src_ = self.self_attention(src, src, src, src_mask)
        # src_: [batch_size, src_len, emb_dim]

        # how useful is this?
        src_ = self.dropout(src_)

        # ------------------------------------------
        # add a residual connection, followed by
        # layer normalization.
        # ------------------------------------------
        # store the normalized sum in src, because it is needed
        # for the residual connection after the feed-forward
        # network.
        src  = self.self_attention_norm(src + src_)
        # src: [batch_size, src_len, emb_dim]

        # apply a feed-forward network
        src_ = self.feedforward(src)
        # src_: [batch_size, src_len, emb_dim]

        src_ = self.dropout(src_)

        # add residual connection and layer normalization
        src  = self.feedforward_norm(src + src_)
        # src: [batch_size, src_len, emb_dim]

        return src

## Multi-Head Attention Layer

Attention is the key to the transformer model.  Attention, in this model, is defined by the matrix expression

\begin{align}
    \texttt{Attention}(Q, K, V) & = \texttt{Softmax}\left(\frac{Q K^T}{\sqrt{d}} \right) V,
\end{align}

where $Q$ is called the `query`, $K$ the `key`, $V$ the `value`, and $d =$ **emb_dim** is the dimension of the vectors that represent the tokens. The Google researchers found that it is better to split each vector representing a token into **n_heads** smaller vectors each of size
$$\textrm{\bf head\_dim} = d / \textrm{\bf n\_heads}.$$
The integer **n_heads** is the number of so-called **attention heads**. It is claimed, with some justification in the Google paper, that each head pays attention to different aspects of a sequence. It is certainly plausible that this splitting procedure enhances the flexibility of the model; however, given our current level of understanding of how functions with millions of parameters truly work, such claims should be taken with a liberal pinch of salt.

In self attention, the query, key, and value tensors are derived from the *same* tensor, either the source or target tensor, via separate linear transformations of that tensor (see *Attention Algorithm* below). The coefficients of the linear functions are free parameters to be determined by the training algorithm.  The number of rows in $Q$, $K$, and $V$, namely, **query_len**,  **key_len**, and **value_len**, respectively, is equal to the sequence length **seq_len**. For target/source attention, the query is a linear function of the target tensor while the key and value tensors are linear functions of the source tensor, where, again, the coefficients are free parameters to be fitted during training.

We first describe the attention mechanism mathematically and then follow with an algorithmic description that closely follows
the description in the [Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/). It is to be understood that every operation described below is performed for a batch of sequences. Therefore, when we refer to a matrix we really mean a batch of matrices.

Because each vector has been split into **n_heads** sub-vectors of dimension **head_dim**, the calculations below are applied separately to each head. At the end, the sub-vectors are concatenated back into vectors of dimension **emb_dim**. First consider the matrix product $Q K^T$ in component form, where summation over repeated indices (the Einstein convention) is implied,
\begin{align}
A_{qk}
& = Q_{q h} \, [K^T]_{hk}, \nonumber\\
& \quad q=1,\cdots, \text{query\_len}, \,\, h = 1, \cdots, \text{head\_dim}, \,\, k = 1, \cdots, \text{key\_len} .
\end{align}
When the matrix $A$ is scaled and a softmax function is applied to each row, that is, along the key-length dimension, the result is another matrix $W$ whose rows, by construction, each sum to unity. The matrix $W$, which is a matrix of normalized weights, is then multiplied by $V$ to yield the matrix of sub-vectors
\begin{align}
    \text{Attention}_{qh}  
    & = W_{qk} V_{kh},
\end{align}
which encodes information about the degree of association between the smaller vectors, each associated with a token.
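To make the component formulas concrete, here is a minimal pure-Python sketch of the computation for a single head and a single sequence (no batch dimension, no mask). The matrices are nested lists. The numbers, like the choice $d = 4$ (as if **emb_dim** = 4 and **n_heads** = 2), are hypothetical.

```python
import math

def softmax(row):
    # subtracting the row maximum avoids overflow and leaves the result unchanged
    m = max(row)
    exps = [math.exp(a - m) for a in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V, d):
    # Q: query_len rows; K and V: key_len rows; each row has length head_dim
    # A_qk = Q_qh K_kh / sqrt(d)  (sum over h)
    A = [[sum(qh * kh for qh, kh in zip(q, k)) / math.sqrt(d) for k in K] for q in Q]
    # softmax along each row, i.e., along the key_len dimension
    W = [softmax(row) for row in A]
    # Attention_qh = W_qk V_kh  (sum over k)
    head_dim = len(V[0])
    out = [[sum(w * v[h] for w, v in zip(w_row, V)) for h in range(head_dim)]
           for w_row in W]
    return W, out

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # query_len = 3, head_dim = 2
K = [[1.0, 0.0], [0.0, 2.0]]               # key_len = 2
V = [[10.0, 0.0], [0.0, 10.0]]             # value_len = key_len = 2
W, out = attention(Q, K, V, d=4)           # d = emb_dim = 4 (hypothetical)
for w_row, o in zip(W, out):
    print([round(w, 3) for w in w_row], [round(x, 3) for x in o])
```

Every row of `W` sums to one, so every output row is a weighted average of the rows of `V`.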

Since tokens are represented by vectors, it is instructive to think of the attention computation geometrically.   Each row, $i$, of $Q$, $K$, and $V$ can be regarded as the vectors $\boldsymbol{q}_i$, $\boldsymbol{k}_i$, and $\boldsymbol{v}_i$, respectively, associated with token $i$, where each vector (really sub-vector) is of dimension head_dim.  Consider, for example, a sequence with **seq_len** = 2. We can write $Q$, $K$, and $V$ as

\begin{align}
Q & = \left[\begin{matrix} \boldsymbol{q}_1 \\ \boldsymbol{q}_2 \end{matrix}\right], \\
K & = \left[\begin{matrix} \boldsymbol{k}_1 \\ \boldsymbol{k}_2 \end{matrix}\right], \text{ and} \\
V & = \left[\begin{matrix} \boldsymbol{v}_1 \\ \boldsymbol{v}_2 \end{matrix}\right] ,
\end{align}

and $A = Q K^T$ as the matrix of pairwise dot products

\begin{align}
A & = \left[\begin{matrix} \boldsymbol{q}_1 \\ \boldsymbol{q}_2 \end{matrix}\right]
\left[\begin{matrix} \boldsymbol{k}_1^T & \boldsymbol{k}_2^T \end{matrix}\right] ,
\nonumber\\
& = \left[
\begin{matrix}
\boldsymbol{q}_1\cdot\boldsymbol{k}_1 & \boldsymbol{q}_1\cdot \boldsymbol{k}_2 \\
\boldsymbol{q}_2\cdot\boldsymbol{k}_1 & \boldsymbol{q}_2\cdot \boldsymbol{k}_2
\end{matrix}
\right] .
\end{align}

The matrix $A$ can be interpreted as a measure of the degree to which the $\boldsymbol{q}$ and $\boldsymbol{k}$ vectors are aligned. Presumably, the more aligned the two vectors, the stronger the relationship between the tokens they represent. Because the dot product is used, the degree of alignment depends both on the angle between the vectors and on their magnitudes. Consequently, the dot product of two different vectors can exceed the dot product of a vector with itself!
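A quick check of the last remark, using hypothetical vectors. The dot product of $\boldsymbol{a}$ with a longer vector $\boldsymbol{b}$ exceeds $\boldsymbol{a}\cdot\boldsymbol{a}$, even though the two vectors point in different directions.

```python
import math

def dot(x, y):
    return sum(xi * yi for xi, yi in zip(x, y))

a = [1.0, 0.0]          # hypothetical unit vector
b = [3.0, 1.0]          # a longer vector pointing in a different direction

cos_ab = dot(a, b) / math.sqrt(dot(a, a) * dot(b, b))
print(dot(a, a))            # 1.0
print(dot(a, b))            # 3.0
print(round(cos_ab, 4))     # 0.9487: the angle between a and b is not zero
```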

After the scaling and softmax operations on $A$, tokens 1 and 2 become associated with vectors $\boldsymbol{w}_1 =  (w_{11}, w_{12})$ and $\boldsymbol{w}_2 =  (w_{21}, w_{22})$, respectively, where
\begin{align}
    w_{ij} & = \frac{\exp\left(\boldsymbol{q}_i \cdot \boldsymbol{k}_j \, / \, \sqrt{d}\right)}
    {\sum_{k = 1}^2 \exp\left(\boldsymbol{q}_i \cdot \boldsymbol{k}_k \, / \, \sqrt{d}\right)} .
\end{align}

These (weight) vectors lie in the line segment $[\boldsymbol{p}_1, \boldsymbol{p}_2]$ depicted in the figure below. The line segment is a simplex (here, a 1-simplex) that is embedded in a vector space of dimension **seq_len**.  In this vector space, tokens 1 and 2 are represented by the orthogonal unit vectors $\boldsymbol{u}_1$ and $\boldsymbol{u}_2$, respectively. For a sequence of length $n$, the vectors $\boldsymbol{w}_i$, $i = 1,\cdots, n$ lie in the $(n-1)$-simplex and, again, each coordinate unit vector $\boldsymbol{u}_i$ represents a token.  
<img src="https://github.com/hbprosper/mlinphysics/blob/main/Labs/10.Transformer/code/simplex.png?raw=1" align="left" width="250px"/>
The vector for token $i$ is the weighted average
\begin{align}
    \text{Attention}_i & = w_{i1}  \boldsymbol{v}_1 + w_{i2} \boldsymbol{v}_2
\end{align}
of the so-called value vectors $\boldsymbol{v}_1$ and $\boldsymbol{v}_2$. In the encoder self attention and in the target-to-source attention, the value vectors are derived from the source sequence. In the decoder self attention, they are derived from the target sequence.

The upshot of this construction is that the attention operation moves the token vectors about in the embedding space, in complex ways. Their relative positions within that space then encode information about the degree and nature of the associations between the tokens.
<br clear="left"/>
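The following sketch uses made-up query, key, and value sub-vectors for **seq_len** = 2 and computes $\boldsymbol{w}_i$ and $\text{Attention}_i$ for both tokens. Each $\boldsymbol{w}_i$ is non-negative and its components sum to one, so it is a point on the 1-simplex. It follows that each $\text{Attention}_i$ lies on the line segment joining $\boldsymbol{v}_1$ and $\boldsymbol{v}_2$.

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(a - m) for a in row]
    s = sum(exps)
    return [e / s for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# hypothetical sub-vectors (head_dim = 2) for a sequence of 2 tokens
q = [[1.0, 2.0], [0.5, -1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[2.0, 0.0], [0.0, 4.0]]
d = 4.0   # this tutorial uses d = emb_dim in the scale factor sqrt(d)

results = []
for i in range(2):
    # w_i = (w_i1, w_i2): a point on the 1-simplex (non-negative, sums to one)
    w = softmax([dot(q[i], k[j]) / math.sqrt(d) for j in range(2)])
    # Attention_i = w_i1 v_1 + w_i2 v_2: a point on the segment joining v_1 and v_2
    att = [w[0] * v[0][h] + w[1] * v[1][h] for h in range(2)]
    results.append((w, att))
    print(f'token {i + 1}: w = {[round(x, 3) for x in w]}, '
          f'attention = {[round(x, 3) for x in att]}')
```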


### Attention Algorithm

Now we describe the transformer attention mechanism algorithmically, following closely the description in the [Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/), but with some notational changes.

<img src="https://github.com/hbprosper/mlinphysics/blob/main/Labs/10.Transformer/code/transformer.png?raw=1" align="left" width="500px"/>

#### Step 1
As noted, the attention mechanism starts with three tensors, $Q_\text{in}$, $K_\text{in}$, and $V_\text{in}$, of shapes **[batch_size,query_len,emb_dim]**, **[batch_size,key_len,emb_dim]**, and **[batch_size,value_len,emb_dim]**, respectively, with **value_len = key_len**. (emb_dim is called d_model in the [Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/).) For self attention, $Q_\text{in}$, $K_\text{in}$, and $V_\text{in}$ are the *same* tensor. For target-to-source attention, $Q_\text{in}$ is associated with the target tensor, and $K_\text{in}$ and $V_\text{in}$ with the source tensor.

Three trainable linear layers, $f_Q$, $f_K$, and $f_V$, are defined, each of shape **[emb_dim,emb_dim]**, which yield the so-called `query`, `key`, and `value` tensors
\begin{align}
    Q & = f_Q(\boldsymbol{Q_\text{in}}), \\
    K & = f_K(\boldsymbol{K_\text{in}}), \text{ and} \\
    V & = f_V(\boldsymbol{V_\text{in}}).
\end{align}
The tensors $Q$, $K$, and $V$ have the same shapes as $Q_\text{in}$, $K_\text{in}$, and $V_\text{in}$, respectively.

#### Step 2
Tensors $Q$, $K$, and $V$ are reshaped by splitting the embedding dimension, **emb_dim**, into **n_heads** blocks of size **head_dim = emb_dim / n_heads**. Their shapes become **[batch_size, seq_len, n_heads, head_dim]**, where **seq_len** stands for **query_len**, **key_len**, or **value_len**, as appropriate, and is determined at runtime. (As will become clear below, this is why the source mask has shape [batch_size, 1, 1, src_len] and the target mask has shape [batch_size, 1, trg_len, trg_len]. The singleton dimensions are broadcast over the attention heads and, for the source mask, also over the query positions.)

#### Step 3
Dimensions 1 and 2 of the tensors $Q$, $K$, and $V$ are permuted (`Tensor.permute(0, 2, 1, 3)`) so that we now have **[batch_size, n_heads, seq_len, head_dim]**. Tensor $K$ is further permuted (`Tensor.permute(0, 1, 3, 2)`) to shape **[batch_size, n_heads, head_dim, seq_len]** so that it represents $K^T$.

#### Step 4
Tensor $A = Q K^T$ is computed using `torch.matmul(Q, K^T)` and scaled by $1 \, / \, \sqrt{d}$. If a mask is supplied, the masked entries of $A$ are set to a large negative number ($-10^{10}$) so that they receive zero weight. A softmax is then applied to the last dimension of $A$, that is, the **key_len** dimension, yielding the tensor $W$ of shape **[batch_size, n_heads, query_len, key_len]**.
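The effect of the mask can be sketched in pure Python for one row of $A$ (the scores and the mask are hypothetical). Filling the masked entries with $-10^{10}$, as `masked_fill` does in `MultiHeadAttention.forward`, drives their softmax weights to zero. The remaining weights equal the softmax of the unmasked scores alone.

```python
import math

def masked_softmax(scores, mask, fill=-1e10):
    # mimic A.masked_fill(mask == 0, -1e10) followed by softmax along key_len
    filled = [a if m else fill for a, m in zip(scores, mask)]
    top = max(filled)
    exps = [math.exp(a - top) for a in filled]   # exp(-1e10 - top) underflows to 0.0
    total = sum(exps)
    return [e / total for e in exps]

scores = [0.3, 1.2, -0.4, 2.0, 0.7]   # one row of QK^T / sqrt(d), key_len = 5
mask   = [1, 1, 1, 0, 0]              # the last two keys are <pad> tokens
w = masked_softmax(scores, mask)
print([round(x, 4) for x in w])
```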

#### Step 5
$\text{Attention} = W V$ is computed, yielding a tensor of shape
**[batch_size, n_heads, query_len, head_dim]**.

#### Step 6
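The n_heads and query_len dimensions of `Attention` are transposed (`Tensor.permute(0, 2, 1, 3)`) to shape **[batch_size, query_len, n_heads, head_dim]**. The tensor is then made contiguous in memory (`contiguous()`), because `Tensor.view` in Step 7 cannot, in general, be applied to the permuted (non-contiguous) tensor.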

#### Step 7
The **n_heads** vectors of dimension **head_dim** are concatenated using `Attention.view(batch_size, query_len, emb_dim)`. This merges the attention heads into a single multi-head attention tensor of shape **[batch_size, query_len, emb_dim]**.
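Steps 2 and 3 split each token vector into heads, and Steps 6 and 7 undo that split. Below is a minimal pure-Python sketch of the round trip for one sequence, without the batch dimension. The helper names `split_heads` and `merge_heads` are illustrative and are not part of the `MultiHeadAttention` class.

```python
def split_heads(x, n_heads):
    # x: [seq_len][emb_dim] -> [n_heads][seq_len][head_dim]
    # (the .view(...) followed by .permute(0, 2, 1, 3) of Steps 2 and 3)
    emb_dim = len(x[0])
    head_dim = emb_dim // n_heads
    return [[row[h * head_dim:(h + 1) * head_dim] for row in x] for h in range(n_heads)]

def merge_heads(heads):
    # heads: [n_heads][seq_len][head_dim] -> [seq_len][emb_dim]
    # (the .permute(0, 2, 1, 3) and .view(...) of Steps 6 and 7)
    seq_len = len(heads[0])
    return [[val for head in heads for val in head[s]] for s in range(seq_len)]

x = [[10 * s + e for e in range(6)] for s in range(3)]   # seq_len = 3, emb_dim = 6
heads = split_heads(x, n_heads=2)                        # head_dim = 3
print(heads[0])                 # [[0, 1, 2], [10, 11, 12], [20, 21, 22]]
print(heads[1])                 # [[3, 4, 5], [13, 14, 15], [23, 24, 25]]
print(merge_heads(heads) == x)  # True
```

Each head sees the same block of components from every token vector. This is why the heads can be processed independently and then concatenated.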

#### Step 8
Finally, the merged multi-head attention tensor is pushed through a trainable linear layer of shape **[emb_dim, emb_dim]** to output a tensor of shape **[batch_size, query_len, emb_dim]**.
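Every linear layer maps **emb_dim** to **emb_dim**. As a result, the number of trainable parameters in the four linear layers ($f_Q$, $f_K$, $f_V$, and the output layer) does not depend on **n_heads**. The quick count below assumes standard `nn.Linear` layers with biases and counts only those four layers.

```python
def mha_param_count(emb_dim):
    # four nn.Linear(emb_dim, emb_dim) layers, each with an
    # emb_dim x emb_dim weight matrix and an emb_dim bias vector
    return 4 * (emb_dim * emb_dim + emb_dim)

for emb_dim in (64, 256, 512):
    print(emb_dim, mha_param_count(emb_dim))
```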

### Comments
As noted above, it is claimed that the algorithm above captures the notion of "paying attention to" token-token associations, both within the same sequence and across sequences, and that each attention head "pays attention to" a different aspect of the sequences. All such claims should be taken with a pinch of salt, for at least two reasons.
First, it is not at all obvious that this computation aligns with our intuitive understanding of that notion. Second, the computation is nested through multiple attention layers. Therefore, whatever the attention layers are doing is distributed over multiple layers in a highly non-linear, non-local way.

It is, however, undeniable that the transformer has yielded amazing results. Therefore, we are forced to concede that, in practice, whatever is going on in the attention layers, the algorithm works wonders!


In [None]:
class MultiHeadAttention(nn.Module):

    def __init__(self, emb_dim, n_heads, dropout, device):

        super().__init__()

        # emb_dim must be a multiple of n_heads
        assert emb_dim % n_heads == 0

        self.emb_dim  = emb_dim
        self.n_heads  = n_heads
        self.head_dim = emb_dim // n_heads

        self.linear_Q = nn.Linear(emb_dim, emb_dim)
        self.linear_K = nn.Linear(emb_dim, emb_dim)
        self.linear_V = nn.Linear(emb_dim, emb_dim)
        self.linear_O = nn.Linear(emb_dim, emb_dim)

        self.dropout = nn.Dropout(dropout)

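        # sqrt(emb_dim) is a fixed constant: registering it as a buffer moves it
        # to the device with the module without making it a trainable parameter
        self.register_buffer('scale', torch.sqrt(torch.FloatTensor([emb_dim])))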

    def forward(self, query, key, value, mask=None):
        # query  : [batch_size, query_len, emb_dim]
        # key    : [batch_size, key_len,   emb_dim]
        # value  : [batch_size, value_len, emb_dim]

        batch_size, _, emb_dim = query.shape
        assert emb_dim == self.emb_dim

        Q = self.linear_Q(query)
        # Q: [batch_size, query_len, emb_dim]

        K = self.linear_K(key)
        # K: [batch_size, key_len,   emb_dim]

        V = self.linear_V(value)
        # V: [batch_size, value_len, emb_dim]

        # split vectors of size emb_dim into 'n_heads' vectors each
        # of size 'head_dim' and then permute dimensions 1 and 2
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        # Q: [batch_size, n_heads, query_len, head_dim]

        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        # K: [batch_size, n_heads, key_len,   head_dim]

        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        # V: [batch_size, n_heads, value_len, head_dim]

        # transpose K (by permuting key_len and head_dim), then
        # compute QK^T/scale
        A = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        # A: [batch_size, n_heads, query_len, key_len]

        # apply (optional) mask to ensure that pad tokens (and, in the
        # decoder, subsequent tokens) do not contribute to the attention
        # calculation.
        if mask is not None:
            A = A.masked_fill(mask == 0, -1e10)

        # apply softmax to the last dimension (i.e., to key_len)
        # WARNING: W is referred to as 'attention' in Annotated Transformer!
        W = torch.softmax(A, dim=-1)
        # W: [batch_size, n_heads, query_len, key_len]

        # dropout on the attention weights acts as a regularizer
        # during training (it is inactive in eval mode)
        W = self.dropout(W)

        # compute attention: softmax(QK^T/scale) V
        attention = torch.matmul(W, V)
        # attention: [batch_size, n_heads, query_len, head_dim]

        # permute n_heads and query len and make sure the tensor
        # is contiguous in memory...
        attention = attention.permute(0, 2, 1, 3).contiguous()
        # attention: [batch_size, query_len, n_heads, head_dim]

        # ... and concatenate the n heads into a single multi-head
        # attention tensor.
        # if attention is being applied to source sequences, then
        # query_len = src_len. If applied to output sequences, then
        # query_len = trg_len.
        attention = attention.view(batch_size, -1, self.emb_dim)
        # attention: [batch_size, query_len, emb_dim]

        output = self.linear_O(attention)
        # output: [batch_size, query_len, emb_dim]

        return output

### Feedforward Layer

In [None]:
class Feedforward(nn.Module):

    def __init__(self, emb_dim, ff_dim, dropout):

        super().__init__()

        self.linear_1 = nn.Linear(emb_dim, ff_dim)

        self.linear_2 = nn.Linear(ff_dim, emb_dim)

        self.dropout  = nn.Dropout(dropout)

    def forward(self, x):
        # x: [batch_size, seq_len, emb_dim]

        x = self.linear_1(x)
        x = torch.relu(x)
        x = self.dropout(x)
        # x: [batch_size, seq_len, ff_dim]

        x = self.linear_2(x)
        # x: [batch_size, seq_len, emb_dim]

        return x

### Decoder

The decoder takes two inputs. The first is the encoded representation of the source sequence, which is a point cloud in a vector space of dimension **emb_dim**. The second is the target sequence (during training) or the current predicted output sequence (during translation). For every position of that sequence, the decoder computes a vector of weights (logits) over the target vocabulary, which can be converted into a probability distribution for the next output token.

The decoder has two multi-head attention layers: a *masked multi-head attention layer* over the target sequence, and a multi-head attention layer which uses the decoder representation as the query and the encoder representation as the key and value. The mask, as discussed below, is in addition to the one that masks out the pad tokens.

**Note**: In PyTorch, the softmax operation, which converts the output weights to probabilities, is contained within the cross-entropy loss function. `nn.CrossEntropyLoss` expects raw logits and applies a log-softmax internally, so the decoder does not have a softmax layer.
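For a single position, the quantity such a loss evaluates from raw logits is $-\log \texttt{Softmax}(\ell)_t$, where $t$ is the true token. The pure-Python sketch below uses hypothetical logits. It computes the loss with the log-sum-exp trick and checks the result against the softmax-then-log route.

```python
import math

def cross_entropy_from_logits(logits, target):
    # -log softmax(logits)[target], computed stably via the log-sum-exp trick
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

logits = [2.0, 0.5, -1.0, 0.0]   # hypothetical logits over a target vocabulary of size 4
loss = cross_entropy_from_logits(logits, target=0)

# the same number computed the long way: softmax first, then -log
probs = [math.exp(z) / sum(math.exp(y) for y in logits) for z in logits]
print(round(loss, 4), round(-math.log(probs[0]), 4))
```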

In [None]:
class Decoder(nn.Module):

    def __init__(self,
                 vocab_size,   # size of target vocabulary
                 max_len,      # maximum output sequence length
                 emb_dim,      # dimension of embedding vector space
                 n_layers,     # number of decoder layers
                 n_heads,      # number of attention heads
                 ff_dim,       # hidden dimension of feed-forward network
                 dropout,      # dropout probability
                 device):      # computational device

        super().__init__()

        print(f'''
    Decoder
    -------
      vocabulary size:       {vocab_size:10d}
      sequence length:       {max_len:10d}
      embedding dimension:   {emb_dim:10d}
      number of layers:      {n_layers:10d}
      number of heads:       {n_heads:10d}
      hidden dim. of FFN:    {ff_dim:10d}
        ''')
        self.device = device

        self.tok_embedding = nn.Embedding(vocab_size, emb_dim)

        self.pos_embedding = nn.Embedding(max_len, emb_dim)

        self.layers  = nn.ModuleList([DecoderLayer(emb_dim,
                                                   n_heads,
                                                   ff_dim,
                                                   dropout,
                                                   device)
                                     for _ in range(n_layers)])

        self.linear  = nn.Linear(emb_dim, vocab_size)

        self.dropout = nn.Dropout(dropout)

        # fixed scale factor sqrt(emb_dim), stored as a (non-trainable) buffer
        self.register_buffer('scale', torch.sqrt(torch.FloatTensor([emb_dim])))

    def forward(self, trg, src, trg_mask, src_mask):
        # trg      : [batch_size, trg_len]
        # src      : [batch_size, src_len, emb_dim]
        # trg_mask : [batch_size, 1, trg_len, trg_len]
        # src_mask : [batch_size, 1, 1, src_len]

        batch_size, trg_len = trg.shape

        # see Encoder for comments
        pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        # pos: [batch_size, trg_len]

        trg = self.tok_embedding(trg) * self.scale + self.pos_embedding(pos)
        # trg: [batch_size, trg_len, emb_dim]

        trg = self.dropout(trg)

        # send the *same* source tensor to every decoding layer. however,
        # analogously to the Encoder, in the Decoder the target tensor is
        # processed through a sequence of layers.
        for layer in self.layers:
            trg = layer(trg, src, trg_mask, src_mask)
            # trg: [batch_size, trg_len, emb_dim]

        # for each output token, output 'vocab_size' weights,
        # which later will be converted to probabilities.
        output = self.linear(trg)
        # output: [batch_size, trg_len, vocab_size]

        return output

### Decoder Layer

The decoder layer has two multi-head attention layers, `self_attention` and `attention`. The former applies the attention algorithm to the target sequences, while the latter applies the algorithm between the target and source sequences.
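`DecoderLayer` has three sublayers: self attention, target-to-source attention, and the feed-forward network. Each is wrapped in the same post-norm residual pattern, $x \leftarrow \texttt{LayerNorm}(x + \texttt{Dropout}(\texttt{Sublayer}(x)))$. The sketch below shows the pattern for one token vector. It omits dropout and the learned gain and bias of `nn.LayerNorm`, and the toy sublayer is hypothetical.

```python
import math

def layer_norm(x, eps=1e-5):
    # normalize one emb_dim vector to zero mean and (nearly) unit variance
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def sublayer_block(x, sublayer):
    # residual connection followed by layer normalization (dropout omitted)
    y = sublayer(x)
    return layer_norm([a + b for a, b in zip(x, y)])

x = [1.0, 2.0, 3.0, 4.0]
out = sublayer_block(x, lambda v: [0.5 * a for a in v])   # toy "sublayer"
print([round(v, 3) for v in out])
```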

In [None]:
class DecoderLayer(nn.Module):

    def __init__(self,
                 emb_dim,
                 n_heads,
                 ff_dim,
                 dropout,
                 device):

        super().__init__()

        # attention within the target sequences
        self.self_attention      = MultiHeadAttention(emb_dim, n_heads, dropout, device)

        self.self_attention_norm = nn.LayerNorm(emb_dim)

        # attention between the source and target sequences
        self.attention           = MultiHeadAttention(emb_dim, n_heads, dropout, device)

        self.attention_norm      = nn.LayerNorm(emb_dim)

        self.feedforward         = Feedforward(emb_dim, ff_dim, dropout)

        self.feedforward_norm    = nn.LayerNorm(emb_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, src, trg_mask, src_mask):
        # trg      : [batch_size, trg_len, emb_dim]
        # src      : [batch_size, src_len, emb_dim]
        # trg_mask : [batch_size, 1, trg_len, trg_len]
        # src_mask : [batch_size, 1, 1, src_len]

        # compute attention over embedded target sequences.
        # distinguish between trg and trg_, since the former
        # is needed later for residual connections.
        #                          Q    K    V
        trg_ = self.self_attention(trg, trg, trg, trg_mask)
        # trg_: [batch_size, trg_len, emb_dim]

        trg_ = self.dropout(trg_)

        # residual connection and layer norm
        trg  = self.self_attention_norm(trg + trg_)
        # trg: [batch_size, trg_len, emb_dim]

        # target to source attention
        #                     Q    K    V
        trg_ = self.attention(trg, src, src, src_mask)
        # trg_: [batch_size, trg_len, emb_dim]

        trg_ = self.dropout(trg_)

        # residual connection and layer norm
        trg  = self.attention_norm(trg + trg_)
        # trg: [batch_size, trg_len, emb_dim]

        trg_ = self.feedforward(trg)
        # trg_: [batch_size, trg_len, emb_dim]

        trg_ = self.dropout(trg_)

        # residual and layer norm
        trg  = self.feedforward_norm(trg + trg_)
        # trg: [batch_size, trg_len, emb_dim]

        return trg

## The `transformer` Model

The `transformer` model encapsulates the encoder and decoder and handles the creation of the source and target masks.

The source mask, as described above, masks out `<pad>` tokens: the mask is 0 where the token is a `<pad>` token and 1 otherwise. The mask is reshaped so that it can be broadcast to tensors of shape **[batch_size, n_heads, query_len, key_len]**, which appear in the multi-head attention calculation.
The target mask also includes a mask for the `<pad>` tokens.

Consider a target sequence $\text{<sos>}, t_1,\cdots, t_{k}, \text{<eos>}$ of length $k+2$. It is constructed with tokens, $t_i$, from the target vocabulary and delimited by the special tokens $t_0 \equiv \text{<sos>}$ and $t_{k+1} \equiv \text{<eos>}$, the start-of-sequence and end-of-sequence tokens, respectively. During training, ideally, we would like to assess the predictions for all sub-sequences *simultaneously*. For example, given the sub-sequences $\text{<sos>}$ and $\text{<sos>}, t_1$, we would like to check simultaneously the prediction $\text{<sos>} \rightarrow \ell_1$ against $t_1$ and the prediction $\text{<sos>}, t_1 \rightarrow \ell_2$ against $t_2$, and so on. Here $\ell_1$ and $\ell_2$ are the model predictions and $t_1$ and $t_2$ are the true target tokens.

In practice, the $\ell_i$ are vectors of weights, called **logits**, of dimension equal to the size $|v_t|$ of the target vocabulary. The logits for each sub-sequence are converted into a discrete probability distribution
\begin{align}
p_{l} & \equiv p(t_l \in v_t| \boldsymbol{x}, \boldsymbol{y}_{\lt l}),
\end{align}
over the target vocabulary $v_t$
and an algorithm, e.g., greedy or beam search, uses the probability distribution $p_l$ to predict the next token $y_l$ for each sub-sequence.

During training, the transformer works on an entire source-target sequence pair. Therefore, it can compute the logits for all sub-sequences simultaneously. This is achieved with a simple, clever trick: the so-called **subsequent mask**, `trg_sub_mask`. The mask is created by applying the function $\texttt{torch.tril}$ to a square matrix of ones. The result is a lower-triangular matrix whose elements above the diagonal are zero and whose elements on and below the diagonal are unity. For example, for a target sequence $\boldsymbol{t} = \text{<sos>}, t_1, t_2, t_{3}, \text{<eos>}$ comprising 5 tokens, the `trg_sub_mask` looks like this:

$$\begin{pmatrix}
1 & 0 & 0 & 0 & 0\\
1 & 1 & 0 & 0 & 0\\
1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 1 & 0\\
1 & 1 & 1 & 1 & 1\\
\end{pmatrix}.$$
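The same pattern can be generated without PyTorch. This short sketch reproduces what `torch.tril` applied to a matrix of ones gives for a 5-token target.

```python
def subsequent_mask(n):
    # 1 on and below the diagonal, 0 above it (same pattern as torch.tril(torch.ones((n, n))))
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]

for row in subsequent_mask(5):
    print(row)
```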

When applied to the target sequence, $\boldsymbol{t}$, the subsequent mask ensures that the prediction made at each position depends only on the target tokens at that position and earlier, never on later tokens. For example, the first row of the subsequent mask is **[1, 0, 0, 0, 0]**. Therefore, given the source sequence, $\boldsymbol{x}$, only the `<sos>` token of the target sequence is available for predicting the second token, $t_1$. The second row of the subsequent mask is **[1, 1, 0, 0, 0]**. In this case, the target tokens `<sos>` and $t_1$ are available to predict the third token, $t_2$, and so on. During training, the
subsequent mask makes it possible to compute the losses for all sub-sequences simultaneously,
\begin{align}
  \boldsymbol{x}, \text{<sos>}  \rightarrow \ell_1 & \rightarrow loss(\ell_1, t_1),\\
  \boldsymbol{x}, \text{<sos>}, t_1   \rightarrow \ell_2 &\rightarrow loss(\ell_2, t_2), \\
  \vdots \quad & \quad \vdots \\
  \boldsymbol{x}, \text{<sos>}, t_1,\cdots, t_{k}   \rightarrow \ell_{k+1} &\rightarrow loss(\ell_{k+1}, \text{<eos>}) .
\end{align}
Notice that this requires the target sequence fed into the decoder to be stripped of its end-of-sequence token. The start-of-sequence token, by contrast, must be stripped from the copy of the target sequence against which the losses are computed.
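The following sketch pairs each row of the subsequent mask with two things: the prefix of the decoder input that the row exposes, and the token against which the corresponding logits are compared. The tokens are placeholder strings, not real vocabulary entries.

```python
# a 5-token target sequence written with placeholder strings
trg = ['<sos>', 't1', 't2', 't3', '<eos>']
dec_input = trg[:-1]    # fed to the decoder: <eos> stripped
dec_target = trg[1:]    # compared with the logits: <sos> stripped

n = len(dec_input)
mask = [[1 if j <= i else 0 for j in range(n)] for i in range(n)]   # subsequent mask

pairs = []
for i, row in enumerate(mask):
    visible = [tok for tok, m in zip(dec_input, row) if m]
    pairs.append((visible, dec_target[i]))
    print(f'{visible} -> predict {dec_target[i]}')
```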
    
In evaluation mode, the model is used autoregressively. The predicted output sequence is initialized to the token $\text{<sos>}$, which, together with the source sequence $\boldsymbol{x}$, is used to predict the logits $\ell_1$. The logits are converted to a discrete probability distribution over the target vocabulary, which is then used to predict token $y_1$. Next, the source sequence and the updated predicted sequence $\text{<sos>}, y_1$ are fed into the transformer to predict $y_2$, and so on. The process stops when either the token $\text{<eos>}$ is predicted or the maximum allowed target sequence length is reached, whichever comes first.
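The autoregressive loop can be sketched independently of the model. Here `toy_next_token_probs` is a hypothetical stand-in for running the decoder and applying a softmax to the last logits; it merely copies the source. The integer codes for `<sos>` and `<eos>` are also assumptions. Apart from the special handling of `<pad>`, the loop mirrors the greedy strategy of `translate` below.

```python
SOS, EOS = 1, 2   # hypothetical integer codes for <sos> and <eos>

def toy_next_token_probs(src, trg):
    # hypothetical stand-in for the decoder plus softmax: it copies the
    # source tokens one by one and then predicts <eos>
    vocab_size = 10
    probs = [0.0] * vocab_size
    pos = len(trg) - 1                 # number of tokens generated so far
    nxt = src[pos] if pos < len(src) else EOS
    probs[nxt] = 1.0
    return probs

def greedy_decode(src, next_token_probs, max_len):
    trg = [SOS]
    for _ in range(max_len):
        probs = next_token_probs(src, trg)
        code = max(range(len(probs)), key=probs.__getitem__)   # most probable token
        trg.append(code)
        if code == EOS:
            break
    return trg

print(greedy_decode([5, 7, 3], toy_next_token_probs, max_len=10))   # [1, 5, 7, 3, 2]
```

Beam search would keep several candidate sequences at each step instead of only the most probable one.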

The target mask is the elementwise logical AND of the target pad mask and the subsequent mask.
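In `make_trg_mask` below, a pad mask of shape [batch_size, 1, 1, trg_len] is broadcast against the [trg_len, trg_len] subsequent mask. Entry $(i, j)$ is therefore 1 only when key position $j$ is not a pad token and $j \le i$. The pure-Python sketch below works on one sequence and assumes the `<pad>` code is 0.

```python
PAD = 0   # hypothetical integer code of the <pad> token

def target_mask(trg):
    # trg: list of integer codes for one sequence
    n = len(trg)
    pad_mask = [tok != PAD for tok in trg]                      # one entry per key position
    sub_mask = [[j <= i for j in range(n)] for i in range(n)]   # subsequent mask
    # query position i may attend to key j only if j <= i and token j is not <pad>
    return [[sub_mask[i][j] and pad_mask[j] for j in range(n)] for i in range(n)]

trg = [1, 5, 7, 0, 0]   # <sos>, two tokens, then two <pad> tokens
for row in target_mask(trg):
    print([int(m) for m in row])
```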

In [None]:
importlib.reload(tnm)

class Transformer(mlp.Model):

    def __init__(self, config,
                 device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
                 debug=False):

        super().__init__()

        if not tnm.config_complete(config):
            pass

        self.pad    = config('PAD_CODE')
        self.sos    = config('SOS_CODE')
        self.eos    = config('EOS_CODE')
        self.device = device
        self.debug  = debug
        self.max_len= config('TRG_SEQ_LEN')

        self.encoder = Encoder(
            config('SRC_VOCAB_SIZE'),
            config('SRC_SEQ_LEN'),
            config('ENC_EMB_DIM'),
            config('ENC_LAYERS'),
            config('ENC_HEADS'),
            config('ENC_FF_DIM'),
            config('ENC_DROPOUT'),
            device)

        self.decoder = Decoder(
            config('TRG_VOCAB_SIZE'),
            config('TRG_SEQ_LEN'),
            config('DEC_EMB_DIM'),
            config('DEC_LAYERS'),
            config('DEC_HEADS'),
            config('DEC_FF_DIM'),
            config('DEC_DROPOUT'),
            device)

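        # initialize weights: apply Xavier uniform initialization to every
        # parameter with more than one dimension (i.e., the weight matrices).
        # (the Transformer itself has no 'weight' attribute, so the check
        # must be made parameter by parameter.)
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)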

    def make_src_mask(self, src):
        # src: [batch_size, src_len]

        src_mask = (src != self.pad).unsqueeze(1).unsqueeze(2)
        # src_mask: [batch_size, 1, 1, src_len]

        return src_mask

    def make_trg_mask(self, trg):
        # trg: [batch_size, trg_len]

        _, trg_len = trg.shape

        trg_pad_mask = (trg != self.pad).unsqueeze(1).unsqueeze(2)
        # trg_pad_mask: [batch_size, 1, 1, trg_len]

        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len),
                                             device=self.device)).bool()
        # trg_sub_mask: [trg_len, trg_len]

        # logical AND of the two masks
        trg_mask = trg_pad_mask & trg_sub_mask
        # trg_mask: [batch_size, 1, trg_len, trg_len]

        return trg_mask

    def forward(self, src, trg):
        # src: [batch_size, src_len]
        # trg: [batch_size, trg_len]

        if self.debug:
            print('BEGIN')
            print(' model - src.shape(on input):', src.shape)
            print('       - ', src[:2])
            print(' model - trg.shape(on input):', trg.shape)
            print('       - ', trg[:2])

        src_mask = self.make_src_mask(src)
        # src_mask: [batch_size, 1, 1, src_len]

        trg_mask = self.make_trg_mask(trg)
        # trg_mask: [batch_size, 1, trg_len, trg_len]

        src      = self.encoder(src, src_mask)
        # src: [batch_size, src_len, emb_dim]
        if self.debug:
            print(' model - src.shape(after encoder):', src.shape)
            print('       - ', src[:2])

        # the decoder will encode the target sequences
        # before applying the attention layers.
        logits   = self.decoder(trg, src, trg_mask, src_mask)
        # logits: [batch_size, trg_len, trg_vocab_size]

        if self.debug:
            print(' model - logits.shape(on exit):', logits.shape)
            print('       - ', logits[:2])
            print('END')
        return logits

    # see section "Using the Model" below
    # -----------------------------------
    def translate(self, src):
        if src.ndim != 2:
            sys.exit(f'''
    translate: src must be of shape [batch_size, src_seq_len] with
    batch_size = 1.
            ''')

        pad = self.pad
        sos = self.sos
        eos = self.eos
        max_len= self.max_len
        device = self.device

        def execute(trg, src_, src_mask, top_k=3):
            # trg:  list of integer codes
            # src_: [batch_size, src_seq_len, emb_dim], batch_size = 1
            if src_.ndim != 3:
                sys.exit(f'''
    translate.execute: src_ must be of shape [batch_size, src_seq_len, emb_dim] with
    batch_size = 1.
            ''')

            trg_ = torch.tensor(trg).unsqueeze(0).to(device)
            # trg_: [batch_size, trg_seq_len], batch_size = 1

            trg_mask = self.make_trg_mask(trg_)
            # trg_mask: [batch_size, 1, trg_seq_len, trg_seq_len]

            with torch.no_grad(): # no need to compute gradients
                # defining y0 = <sos>, given the current predicted sequence,
                # trg_ = (y0,..yi) and the source sequence, src_, compute
                # logit vectors of size trg_vocab_size logits for the next
                # token for each sub-sequence, (y0), (y0, y1), (y0,...,yi).
                #
                logits = self.decoder(trg_, src_, trg_mask, src_mask)
                # logits: [batch_size, trg_seq_len, trg_vocab_size]
                #
                # ...and convert the logits to probabilities by applying
                # a softmax to the trg_vocab_size dimension (dim=-1, i.e.,
                # horizontally) of the last series of logits (logits[:, -1, :])
                # Note: logits[:, -1, :] is of shape [batch_size, trg_vocab_size]
                logits_for_final_sub_sequence = logits[:, -1, :]
                probs = torch.softmax(logits_for_final_sub_sequence, dim=-1)

            # return the top_k token codes with the largest probabilities
            token_probs, token_codes = torch.topk(probs, k=top_k)
            token_probs = token_probs.t() # transpose: [batch_size, top_k] => [top_k, batch_size]
            token_codes = token_codes.t()

            return token_probs, token_codes

        # -------------------------------------
        # Start autoregressive translation
        # -------------------------------------
        self.eval()

        src = src.to(device) # move to device (src already has a batch dimension)
        # src: [batch_size, src_seq_len], batch_size = 1

        src_mask = self.make_src_mask(src)
        # src_mask: [batch_size, 1, 1, src_seq_len]

        # encode (i.e., embed and analyze) source sequence
        src_ = self.encoder(src, src_mask)
        # src_: [batch_size, src_seq_len, emb_dim]

        # initialize output sequence with the start-of-sequence token <sos>.
        # the decoder takes in the encoded source sequence and for each
        # current output sub-sequence computes weights for the next token.
        # the logits for the next token of the last sub-sequence are converted
        # to probabilities.
        #
        # using a greedy strategy, the most probable token is chosen as the
        # next token, which is appended to the current output sequence.
        # the algorithm repeats and stops when either the <eos> token is
        # predicted or the maximum output sequence length is reached.
        trg = [sos]
        for i in range(max_len):
            # Note: must pass encoded source (src_) to execute
            probs, codes = execute(trg, src_, src_mask)

            code = codes[0, -1].item() # pick most probable next token (as a Python int)
            if code == pad:
                continue

            trg.append(code)
            if code == eos:
                break

        return trg

## Training the `transformer` Model

Our model is minuscule compared with the transformer models used today (of which there are many variants), but it is small enough to be trained on a single GPU in less than an hour.

As noted in Section *Transformer Model* above, given the entire source sequence, $\boldsymbol{x}$, and all target sub-sequences, the model predicts the logits for the next token of every target sub-sequence simultaneously. Consider, again, a target sequence of size $k = 5$. The last token the model must predict is `<eos>`, and no token follows it, so we slice `<eos>` off the end of the target sequence,
\begin{align}
\text{trg} &= [\text{<sos>}, t_1, t_2, t_3, \text{<eos>}]\\
\text{trg[:-1]} &= [\text{<sos>}, t_1, t_2, t_3],
\end{align}
where the $t_i$ denote target sequence tokens other than `<sos>` and `<eos>`. The sliced target is fed into the model, which simultaneously predicts the logit vectors $\ell_1,\cdots, \ell_4$ for the sub-sequences $[\text{<sos>}]$, $[\text{<sos>}, t_1]$, $[\text{<sos>}, t_1, t_2]$, and $[\text{<sos>}, t_1, t_2, t_3]$, respectively. Recall that the dimension of each logit vector $\ell_i$ is the cardinality $|v_t|$ (size) of the target vocabulary. The loss, $loss(\ell_i, t_i)$, for each sub-sequence, where $t_4 \equiv \text{<eos>}$, is computed using the data
\begin{align}
\text{model outputs} &= [\ell_1, \ell_2, \ell_3, \ell_4]\\
\text{trg[1:]} &= [t_1, t_2, t_3, \text{<eos>}] ,
\end{align}
that is, the logits and the original target tensor with the `<sos>` token stripped away.
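The bookkeeping above can be sketched with plain Python lists standing in for tensors. The token codes (`PAD`, `SOS`, `EOS`, and the values of the $t_i$) are hypothetical; for a batch tensor `t`, the same slices are `t[:, :-1]` and `t[:, 1:]`, as in `TNNObjective` below. The last few lines show the lower-triangular ("subsequent") mask that, in the Annotated Transformer[2], lets the decoder handle all sub-sequences at once; we assume, without showing it here, that this model's target mask works the same way.

```python
# Hypothetical token codes, chosen only for illustration.
PAD, SOS, EOS = 0, 1, 2

# A batch of two coded target sequences; the second is padded after <eos>.
trg = [
    [SOS, 7, 8, 9, EOS],
    [SOS, 5, 6, EOS, PAD],
]

# Decoder input, trg[:-1]: drop the last column (<eos>, or <pad> if padded).
trg_in = [row[:-1] for row in trg]
# Loss targets, trg[1:]: drop the first column (<sos>).
trg_out = [row[1:] for row in trg]

# Pair each sub-sequence (prefix of trg_in) with the token it must predict.
for row_in, row_out in zip(trg_in, trg_out):
    for j, target in enumerate(row_out):
        note = '  (ignored by the loss)' if target == PAD else ''
        print(row_in[: j + 1], '->', target, note)

# A lower-triangular "subsequent" mask: position j may attend only to
# positions 0..j, so every prefix is processed in a single pass.
n = len(trg_in[0])
subsequent_mask = [[k <= j for k in range(n)] for j in range(n)]
for mask_row in subsequent_mask:
    print(''.join('1' if allowed else '0' for allowed in mask_row))
```

Each printed pair corresponds to one $(\ell_i, t_i)$ term in the loss; the final pair of the padded sequence has a `<pad>` target and contributes nothing.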

### Loss function
A transformer is an autoregressive multi-class classifier: every time it is invoked with a source sequence and the current predicted output sequence, it computes, for every sub-sequence of the current output sequence, a probability distribution over the target vocabulary for the next token. These probability distributions are used to decide upon (that is, to classify) the next output token. Therefore, like all multi-class classifiers, a transformer is trained using the **cross entropy loss**. Given the data $(\ell_i, t_i)$, the loss is given by
\begin{align}
    loss(\ell_i, t_i) & = - \log p_k, \quad\text{ where } k \equiv t_i, \\
        p_k & = \texttt{softmax}_k(\ell_i) \equiv \frac{\exp(\ell_{i, k})}{\sum_{j} \exp(\ell_{i,j})},
\end{align}
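and $\ell_{i, k}$ denotes component $k$ of logit vector $\ell_i$. The losses are averaged over the output sub-sequences and over the batch of sequences, excluding positions whose target is the `<pad>` token (hence `ignore_index=config('PAD_CODE')` in the loss function defined below).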
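As a check on the formula, the sketch below computes the averaged cross entropy by hand, in plain Python, for a hypothetical target vocabulary of three codes with code 0 playing the role of `<pad>`. Skipping the `<pad>` positions before averaging is meant to mirror `nn.CrossEntropyLoss(ignore_index=...)` with its default `reduction='mean'`, as used below; the logits and target codes are made up.

```python
import math

def softmax(logits):
    # subtract the largest logit for numerical stability; p_k is unchanged
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def cross_entropy(logit_vectors, targets, ignore_index=None):
    """Mean of -log softmax_k(l_i) with k = t_i, skipping ignored targets."""
    losses = []
    for logits, k in zip(logit_vectors, targets):
        if k == ignore_index:
            continue                      # <pad> targets do not contribute
        losses.append(-math.log(softmax(logits)[k]))
    return sum(losses) / len(losses)

# Hypothetical example: vocabulary of size 3, code 0 is <pad>.
PAD = 0
logits  = [[2.0, 0.5, -1.0],   # l_1
           [0.1, 0.2,  3.0],   # l_2
           [1.0, 1.0,  1.0]]   # l_3 (its target is <pad>)
targets = [1, 2, PAD]

loss = cross_entropy(logits, targets, ignore_index=PAD)
print(f'{loss:.4f}')
```

Subtracting the largest logit before exponentiating does not change $p_k$, because the common factor cancels between numerator and denominator, but it avoids overflow when logits are large.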

In [None]:
def train(objective, optimizer,
          train_loader, train_small_loader, val_loader,
          config):

    # ------------------------------------------
    # enter training loop
    # ------------------------------------------

    # get configuration info
    lossfile   = config('file/losses')
    paramsfile = config('file/params')
    monstep    = config('monitor_step')
    delete     = config('delete')
    frac       = config('frac')
    niterations= config('n_iterations')
    base_lr    = config('base_lr')

    # instantiate object that saves average losses to
    # a csv file for realtime monitoring

    losswriter = mon.LossWriter(niterations,
                                lossfile,
                                step=monstep,
                                delete=delete,
                                frac=frac,
                                model=objective.model,
                                paramsfile=paramsfile)

    # -----------------------------
    # training loop
    # -----------------------------

    for ii, (src, trg) in enumerate(train_loader):

        objective.train()

        R = objective(src, trg)

        optimizer.zero_grad()     # zero gradients

        R.backward()              # compute gradients

        # optional: clip gradients to stabilize training
        #torch.nn.utils.clip_grad_norm_(objective.model.parameters(), 1)

        optimizer.step()          # take a single step in parameter space

        if (ii % monstep == 0) or (ii == niterations-1):

            # set mode to evaluation so that training-specific
            # operations, such as dropout, are disabled, and skip
            # gradient tracking since no backward pass is needed.
            objective.eval()

            with torch.no_grad():
                src, trg = next(iter(train_small_loader))
                t_loss   = objective(src, trg).item()

                src, trg = next(iter(val_loader))
                v_loss   = objective(src, trg).item()

            # update loss file
            losswriter(ii, t_loss, v_loss, base_lr)

In [None]:
importlib.reload(tnm)

model = Transformer(config, device=DEVICE).to(DEVICE)
print(model)
print(f'The model has {mlp.number_of_parameters(model):,} trainable parameters')

optimizer = torch.optim.Adam(model.parameters(), lr=config('base_lr'))

avgloss   = nn.CrossEntropyLoss(ignore_index=config('PAD_CODE'))

# make a specialized objective from mlp.Objective
class TNNObjective(mlp.Objective):
    def __init__(self, model, avgloss):
        super().__init__(model, avgloss)
        self.loss_fn = avgloss   # explicit handle used in forward()

    def forward(self, x, t):
        # x[batch_size, src_seq_len]
        # t[batch_size, trg_seq_len]

        # Slice off the last token from all targets in the batch
        # (<eos>, or <pad> if the target was padded).
        # For every sub-sequence, the model computes a vector
        # of logits of size trg_vocab_size. The calculations
        # for all sub-sequences are done simultaneously.
        t_in = t[:, :-1]
        logits = self.model(x, t_in)
        # logits[batch_size, trg_seq_len-1, trg_vocab_size]

        # reshape logits by flattening the first 2 dimensions for
        # the subsequent loss calculation
        trg_vocab_size = logits.shape[-1] # get target vocabulary size
        logits_out = logits.reshape(-1, trg_vocab_size)
        # logits_out[batch_size * (trg_seq_len-1), trg_vocab_size]

        # For each sub-sequence, the model predicts a logit vector
        # from which the next token is predicted; for the last
        # sub-sequence, that token should be <eos>.
        # Before we compute losses, we strip away the <sos> token
        # from the targets to arrive at t_1, t_2, ..., <eos>.
        t_out = t[:, 1:].reshape(-1) # strip away <sos> and flatten
        # t_out[batch_size * (trg_seq_len-1)]

        # the loss already averages over all non-<pad> targets
        return self.loss_fn(logits_out, t_out)

objective = TNNObjective(model, avgloss)


    Encoder
    -------
      vocabulary size:               37
      sequence length:               22
      embedding dimension:           64
      number of layers:               2
      number of heads:                8
      hidden dim. of FFN:           128
        

    Decoder
    -------
      vocabulary size:               30
      sequence length:               85
      embedding dimension:           64
      number of layers:               2
      number of heads:                8
      hidden dim. of FFN:           128
        
Transformer(
  (encoder): Encoder(
    (tok_embedding): Embedding(37, 64)
    (pos_embedding): Embedding(22, 64)
    (layers): ModuleList(
      (0-1): 2 x EncoderLayer(
        (self_attention): MultiHeadAttention(
          (linear_Q): Linear(in_features=64, out_features=64, bias=True)
          (linear_K): Linear(in_features=64, out_features=64, bias=True)
          (linear_V): Linear(in_features=64, out_features=64, bias=True)
          (linear

In [None]:
TRAIN = False

if TRAIN:
    train(objective, optimizer,
          train_loader, train_loader_val, val_loader,
          config)

    monitor = mon.Monitor(config('file/losses'))
    monitor.plot()

## Using the Model

The test data are already tokenized, coded, and bracketed with the `<sos>` and `<eos>` codes. The translation steps, implemented in `model.translate`, are as follows:

  1. move the coded source tokens, `src`, to the computational device; `src` already has a batch dimension (at dimension 0), so it has the correct shape, namely, `[batch_size, src_len]`, with `batch_size = 1`;
  1. create the source mask `src_mask` to mask out pad tokens;
  1. feed the source `src` and its mask `src_mask` into the encoder, which returns the encoded source `src_`;
  1. for the predicted output sequence, create a list initialized with the `<sos>` token;
  1. *repeat* the sub-steps below until the model predicts the `<eos>` token or the maximum output length is reached:
     1. convert the current output list `trg` into the tensor `trg_` and add a batch dimension at dimension 0 so that, like the source tensor, the output tensor has the shape `[batch_size, trg_len]` with `batch_size = 1`;
     1. create the target mask `trg_mask` to mask out pad tokens;
     1. feed the current output `trg_`, the encoder output `src_`, and the source and target masks into the decoder;
     1. convert the decoder logits for the last sub-sequence into probabilities and pick the most probable token;
     1. append the predicted token to the current output list;
  1. finally, outside `translate`, convert the output sequence from codes to a string.
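The repeated step above is greedy autoregressive decoding. A minimal, framework-free sketch follows; `next_token_logits(prefix)` is a hypothetical callable standing in for the encoder and decoder calls that `model.translate` makes, and `toy_logits` is a made-up stand-in that always "knows" one answer.

```python
from typing import Callable, List, Sequence

def greedy_decode(next_token_logits: Callable[[List[int]], Sequence[float]],
                  sos: int, eos: int, max_len: int) -> List[int]:
    """Greedy decoding: repeatedly append the most probable next token.

    next_token_logits(prefix) must return one logit per target-vocabulary
    code for the token that follows `prefix`.
    """
    out = [sos]                            # start with <sos>
    for _ in range(max_len):               # at most max_len new tokens
        logits = next_token_logits(out)
        # arg-max of the logits equals arg-max of their softmax
        code = max(range(len(logits)), key=lambda k: logits[k])
        out.append(code)
        if code == eos:                    # stop once <eos> is predicted
            break
    return out

# Hypothetical stand-in for the decoder: it puts a large logit on the
# token that follows the current prefix in a fixed answer.
SOS, EOS, VOCAB = 1, 2, 8
ANSWER = [SOS, 5, 6, 7, EOS]

def toy_logits(prefix):
    logits = [0.0] * VOCAB
    logits[ANSWER[len(prefix)]] = 5.0
    return logits

print(greedy_decode(toy_logits, SOS, EOS, max_len=10))
```

Because the softmax is monotonic, computing probabilities is not needed to pick the greedy token; `translate` computes them anyway because `execute` also returns the `top_k` probabilities. Unlike `translate`, this sketch does not skip `<pad>` predictions.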

In [None]:
PRINT_MISTAKES = False

# load best model
model.load(config('file/params'))

M = 0
F = 0.0
for i, (src, trg) in enumerate(test_loader):
    # src: [batch_size, src_seq_len], batch_size = 1
    # trg: [batch_size, trg_seq_len], batch_size = 1

    # translate: src => out
    out = model.translate(src)
    # out: list of integer codes

    # convert sequence of target codes to a string (skipping <sos>,
    # <eos>, and <pad> tokens)
    trg  = trg.squeeze() # get rid of batch dimension for stringify
    trg_ = tnm.stringify(trg[1:-1], seqdata.trg_code2token)

    # convert predicted sequence of target codes to a string
    # (skipping <sos>, <eos>, and <pad> tokens)
    out_ = tnm.stringify(out[1:-1], seqdata.trg_code2token)

    # count how often we're right
    if out_ == trg_:
        M += 1
    elif PRINT_MISTAKES:
        print()
        print(trg_)
        print()
        print(out_)
        print()
        print('-'*91)

    # update the running accuracy after every example, right or wrong
    F = M / (i+1)
    print(f'\r{i:8d}\taccuracy: {F:8.3f}', end='')

N  = len(test_data)
dF = np.sqrt(F*(1-F)/N)
print()
print(f'Accuracy: {F:8.3f} +/- {dF:.3f}')

     749	accuracy:    0.979
Accuracy:    0.979 +/- 0.005
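The quoted uncertainty is the binomial standard error of a proportion, $\delta F = \sqrt{F(1-F)/N}$, which treats each test sequence as an independent trial that is either translated exactly right or not. Assuming the test set holds $N = 750$ sequences (the last printed index is 749), a quick check reproduces the printed value:

```python
import math

def binomial_std_error(accuracy: float, n: int) -> float:
    """Standard error of a proportion estimated from n independent trials."""
    return math.sqrt(accuracy * (1.0 - accuracy) / n)

# Assumption: N = 750 test sequences, inferred from the printed index 749.
F, N = 0.979, 750
dF = binomial_std_error(F, N)
print(f'Accuracy: {F:8.3f} +/- {dF:.3f}')
```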
