# Modified Transformer Architecture

Sources:

- [correct `transformer implementation` from scratch in `pytorch`](https://towardsdatascience.com/build-your-own-transformer-from-scratch-using-pytorch-84c850470dcb) (in-depth tutorial from same author: [part1](https://towardsdatascience.com/all-you-need-to-know-about-attention-and-transformers-in-depth-understanding-part-1-552f0b41d021), [part2](https://towardsdatascience.com/all-you-need-to-know-about-attention-and-transformers-in-depth-understanding-part-2-bf2403804ada))
- [nice visuals for understanding `multi-head attention`](http://jalammar.github.io/illustrated-transformer/)
- [`positional encoding`](https://machinelearningmastery.com/a-gentle-introduction-to-positional-encoding-in-transformer-models-part-1/)
- [kaggle transformer code: ❗Contains mistakes (see comments), but nice overall explanation](https://www.kaggle.com/code/arunmohan003/transformer-from-scratch-using-pytorch)



In [1]:
!python --version

Python 3.9.13


In [2]:
import os
import math

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
seed = 42
torch.manual_seed(seed)
print(f"torch version: {torch.__version__}")

torch version: 2.4.1+cpu


# Dataset

For details about the dataset see `data_handling.ipynb`

In [3]:
SAVE_FOLDER = "processed_dataset"
sl = 128
data_file_path = os.path.join(SAVE_FOLDER, f"data_{sl}.pkl")
meta_file_path = os.path.join(SAVE_FOLDER, f"metadata_{sl}.csv")

df = pd.read_pickle(data_file_path)
df.head()

Unnamed: 0_level_0,x_dominant_hand_0,x_dominant_hand_1,x_dominant_hand_2,x_dominant_hand_3,x_dominant_hand_4,x_dominant_hand_5,x_dominant_hand_6,x_dominant_hand_7,x_dominant_hand_8,x_dominant_hand_9,...,z_dominant_hand_11,z_dominant_hand_12,z_dominant_hand_13,z_dominant_hand_14,z_dominant_hand_15,z_dominant_hand_16,z_dominant_hand_17,z_dominant_hand_18,z_dominant_hand_19,z_dominant_hand_20
sequence_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1816796431,0.408832,0.519912,0.612159,0.707576,0.797313,0.494709,0.532817,0.553556,0.566219,0.391196,...,-0.245855,-0.269148,-0.129743,-0.251501,-0.278687,-0.26653,-0.152852,-0.257519,-0.275822,-0.266876
1816796431,0.398663,0.523662,0.638807,0.744236,0.832567,0.538486,0.564302,0.581011,0.597674,0.441541,...,-0.37077,-0.408097,-0.185217,-0.325494,-0.343373,-0.328294,-0.203126,-0.315719,-0.326104,-0.314282
1816796431,0.41929,0.509726,0.593165,0.685492,0.777913,0.483669,0.510993,0.53641,0.564583,0.393016,...,-0.28577,-0.318548,-0.155317,-0.274822,-0.312119,-0.316411,-0.181363,-0.286298,-0.316182,-0.322671
1816796431,0.398764,0.498118,0.583356,0.677779,0.775966,0.481279,0.491659,0.524974,0.571944,0.412262,...,-0.235725,-0.267054,-0.14138,-0.219369,-0.256553,-0.27369,-0.170996,-0.240285,-0.266193,-0.27811
1816796431,0.420213,0.49565,0.57179,0.659049,0.74974,0.485707,0.47593,0.501727,0.53915,0.438294,...,-0.186706,-0.217181,-0.10774,-0.165642,-0.201059,-0.222898,-0.131329,-0.183113,-0.208774,-0.225284


In [4]:
metadata_df = pd.read_csv(meta_file_path, header=0)

max_phrase_len = max([len(it) for it in metadata_df.phrase.values])
possible_characters = sorted(set.union(*[set(p) for p in metadata_df.phrase.values]))
token_map = {c: i+3 for i, c in enumerate(possible_characters)}
PADDING = 'P'
SOS = '<'
EOS = '>'
token_map[PADDING] = 0 # padding
token_map[SOS] = 1 # SOS
token_map[EOS] = 2 # EOS
metadata_df.phrase = metadata_df.phrase.apply(lambda it: np.array([token_map[c] for c in '<'+it+'>'+('P'*(max_phrase_len-len(it)))], dtype=np.int32))
max_phrase_len_with_sequence_control_tokens = max_phrase_len + 2 # SOS, EOS

# Shuffle dataset
metadata_df = metadata_df.sample(frac=1).reset_index(drop=True)

index = {row[0]: {"phrase": row[1], "signer_id": row[2]} for row in metadata_df.to_numpy()}
str(index)[:100] + "..."

"{607853238: {'phrase': array([ 1,  5,  9,  3, 36,  3, 28, 25, 17,  3, 24, 18, 38, 32, 33, 28, 27,\n  ..."

In [28]:
settings = {
    # Data
    "d_model": 32,
    "seq_len": 256,
    "padding_token": token_map[PADDING],

    # Training
    "batch_size": 32,
    "epochs": 1,

    # Transformer
    "num_nodes_per_hand": 21,
    "num_features_per_node": 3,
    "tgt_vocab_size": len(token_map),
    "num_heads": 2,
    "num_enc_layer": 2,
    "num_dec_layer": 4,
    "dff": 32,
    "dropout": 0.1,
    "phrase_length": max_phrase_len_with_sequence_control_tokens,

    # LandMarkEmbedding
    "num_conv_layers": 3,
    "filter_size": 11,

    # GraphEmbedding
    "hidden_dim": 128,
}

In [6]:
class TransformerDataset(torch.utils.data.Dataset):
    def __init__(self, df, meta_data, seq_len=256, padding_value=0.0):
        self.df = df
        self.meta_data = meta_data
        self.sequence_ids = df.index.unique()
        self.padding_value = padding_value
        self.seq_len = seq_len

    def __len__(self):
        return len(self.sequence_ids)

    def __getitem__(self, idx):
        sequence_id = self.sequence_ids[idx]
        x_values = torch.tensor(self.df.loc[sequence_id].values, dtype=torch.float32)

        x_mask = torch.ones(self.seq_len, dtype=torch.float32)

        # Apply padding if the sequence is shorter than seq_len
        if x_values.shape[0] < self.seq_len:
            x_mask[x_values.shape[0]:] = settings["padding_token"]
            padding_size = self.seq_len - x_values.shape[0]
            padding = torch.full((padding_size, x_values.shape[1]), self.padding_value)
            x_values = torch.cat([x_values, padding], dim=0)
        elif x_values.shape[0] > self.seq_len:
            # Truncate the sequence if it's longer than seq_len
            x_values = x_values[:self.seq_len]

        y_phrase = self.meta_data[sequence_id]['phrase']
        return x_values, x_mask, y_phrase

dataset = TransformerDataset(df, index)

In [7]:
dataset.__getitem__(0)

(tensor([[ 0.4088,  0.5199,  0.6122,  ..., -0.2575, -0.2758, -0.2669],
         [ 0.3987,  0.5237,  0.6388,  ..., -0.3157, -0.3261, -0.3143],
         [ 0.4193,  0.5097,  0.5932,  ..., -0.2863, -0.3162, -0.3227],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]),
 tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

In [8]:
from torch.utils.data import random_split, DataLoader

# Split lengths for train (80%), test (10%), valid (10%)
train_size = int(0.8 * len(dataset))
valid_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - valid_size

train_dataset, valid_dataset, test_dataset = random_split(dataset, [train_size, valid_size, test_size])

# NOTE Keep num_workers=0
train_loader = DataLoader(train_dataset, batch_size=settings["batch_size"], shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=settings["batch_size"], shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=settings["batch_size"], shuffle=False)

for it in train_loader:
    print("Shape of x:")
    print(it[0].shape)
    print("\nMask for input:")
    print(it[1].shape)
    print(it[1])
    print("\nTarget:")
    print(it[2].shape)
    print(it[2]) # NOTE dec inp: it[1][:, :-1], target: it[1][:, 1:]
    break

Shape of x:
torch.Size([32, 256, 63])

Mask for input:
torch.Size([32, 256])
tensor([[1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        ...,
        [1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.]])

Target:
torch.Size([32, 33])
tensor([[ 1,  9, 12,  ...,  0,  0,  0],
        [ 1, 14, 17,  ...,  0,  0,  0],
        [ 1,  5,  5,  ...,  0,  0,  0],
        ...,
        [ 1, 33, 14,  ...,  0,  0,  0],
        [ 1, 10,  8,  ...,  0,  0,  0],
        [ 1, 10,  5,  ...,  0,  0,  0]], dtype=torch.int32)


In [9]:
%timeit next(iter(train_loader))

62.7 ms ± 5.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


# Constructing the Transformer

## General Components

In [10]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output

In [11]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len, n=10000):
        super(PositionalEncoding, self).__init__()

        assert d_model % 2 == 0, "d_model must be even"
        
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(n) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # NOTE
        # register buffer in Pytorch ->
        # If you have parameters in your model, which should be saved and restored in the state_dict,
        # but not trained by the optimizer, you should register them as buffers.
        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        # Broadcasting mechanism (automatically works for multiple batches even when shapes don't match along batch dim)
        return x + self.pe[:, :x.size(1)]

In [12]:
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

## Encoder Embeddings

In [13]:
class LandmarkEmbedding(nn.Module):
    def __init__(self, d_model, num_features, num_conv_layers, filter_size):
        super(LandmarkEmbedding, self).__init__()

        self.d_model = d_model

        first_conv = nn.Conv1d(in_channels=num_features, out_channels=d_model, kernel_size=filter_size, padding='same')
        rest_of_convs = [
            nn.Conv1d(in_channels=d_model, out_channels=d_model, kernel_size=filter_size, padding='same')
            for _ in range(num_conv_layers-1)
        ]
        
        self.conv_block = nn.Sequential(
            first_conv, *rest_of_convs
        )

    def forward(self, x):
        # x is expected to be of shape (batch_size, seq_len, num_of_features)
        x = x.permute(0, 2, 1) # (batch_size, num_of_features, seq_len)
        x = self.conv_block(x)
        # TODO experiment with scaling the output
        # x *= math.sqrt(self.d_model)
        x = x.permute(0, 2, 1) # (batch_size, seq_len, d_model)
        return x

In [14]:
lm_embedding = LandmarkEmbedding(
    d_model=settings["d_model"],
    num_features=settings["num_nodes_per_hand"]*settings["num_features_per_node"],
    num_conv_layers=settings["num_conv_layers"],
    filter_size=settings["filter_size"]
)
x = next(iter(train_loader))[0]
%timeit lm_embedding(x)

10.4 ms ± 585 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
for name, param in lm_embedding.named_parameters():
    if param.requires_grad:
        print(f"Layer: {name} | Size: {param.size()}")

Layer: conv_block.0.weight | Size: torch.Size([32, 63, 11])
Layer: conv_block.0.bias | Size: torch.Size([32])
Layer: conv_block.1.weight | Size: torch.Size([32, 32, 11])
Layer: conv_block.1.bias | Size: torch.Size([32])
Layer: conv_block.2.weight | Size: torch.Size([32, 32, 11])
Layer: conv_block.2.bias | Size: torch.Size([32])


In [16]:
from torch_geometric.data import Data, Batch
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn import GCNConv, GATConv, GATv2Conv

class GraphEmbedding(nn.Module):
    def __init__(
            self,
            num_nodes: int,
            num_features_per_node: int,
            d_model: int,
            hidden_dim: int,
            seq_len: int,
            batch_size: int
        ):
        super(GraphEmbedding, self).__init__()
        self.d_model = d_model
        self.num_nodes = num_nodes
        self.num_features_per_node = num_features_per_node
        self.batch_size = batch_size
        self.seq_len = seq_len

        # self.conv1 = GATConv(num_features_per_node, hidden_dim, heads=4, concat=False)
        # self.conv2 = GATConv(hidden_dim, d_model, heads=4, concat=False)

        self.conv1 = GCNConv(num_features_per_node, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, d_model)

        # Precompute edge_index and batch_info for batch
        # Based on: https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/loader/dataloader.py
        connections = [
            (0,1), (0,5), (0,17), (1,2), (2,3), (3,4),
            (5,6), (5,9), (6,7), (7,8), (9,10), (9,13),
            (10,11), (11,12), (13,14), (13,17), (14,15),
            (15,16), (17,18), (18,19), (19,20)
        ]
        edges = []
        for a, b in connections:
            edges.append([a, b])
            edges.append([b, a])  # Add the reverse connection
        single_graph_edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

        # NOTE Edge index calculation can be done without using Batch.from_data_list(...) -> improve performance
        # For a one time run this is good enough.
        example_graph_features = torch.zeros(self.num_nodes, self.num_features_per_node)
        data_list = [Data(x=example_graph_features, edge_index=single_graph_edge_index) for _ in range(self.batch_size*self.seq_len)]
        mini_batch = Batch.from_data_list(data_list)

        self.register_buffer('edge_index', mini_batch.edge_index)
        self.register_buffer('batch_info', mini_batch.batch)

    def forward(self, x):
        # x ~ (batch_size, seq_len, num_features)

        x = x.reshape(-1, self.num_features_per_node) # (batch_size*seq_len*num_nodes, num_features_per_node)

        # Obtain node embeddings
        x = self.conv1(x, self.edge_index)
        x = x.relu()
        x = self.conv2(x, self.edge_index)
        x = x.relu()

        # Readout layer
        x = global_mean_pool(x, self.batch_info)  # (batch_size*sequence_len, d_model)

        x = x.reshape(self.batch_size, self.seq_len, self.d_model)
        return x

In [17]:
graph_embedding = GraphEmbedding(
    num_nodes=settings["num_nodes_per_hand"],
    num_features_per_node=settings["num_features_per_node"],
    d_model=settings["d_model"],
    hidden_dim=settings["hidden_dim"],
    seq_len=settings["seq_len"],
    batch_size=settings["batch_size"]
)
x = next(iter(train_loader))[0]
# NOTE
# Running the Graph-based layers on each graph separately is more than 10x slower than running on the mini batch
# To make it faster:
# - run on GPU
# - rewrite GNN layers using the fact that all graphs will have the same structure
%timeit graph_embedding(x)

189 ms ± 9.62 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [18]:
for name, param in graph_embedding.named_parameters():
    if param.requires_grad:
        print(f"Layer: {name} | Size: {param.size()}")

Layer: conv1.bias | Size: torch.Size([128])
Layer: conv1.lin.weight | Size: torch.Size([128, 3])
Layer: conv2.bias | Size: torch.Size([32])
Layer: conv2.lin.weight | Size: torch.Size([32, 128])


## Rest of the Transformer

In [19]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

In [20]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.masked_self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output = self.masked_self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

In [21]:
class Transformer(nn.Module):
    def __init__(
            self,
            encoder_embedding: nn.Module,
            tgt_vocab_size: int,
            d_model: int,
            num_heads: int,
            num_enc_layers: int,
            num_dec_layers: int,
            d_ff: int,
            max_seq_length: int,
            dropout: float
        ):
        super(Transformer, self).__init__()

        self.encoder_embedding = encoder_embedding
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_enc_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_dec_layers)])

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src_mask, tgt):
        src_mask = (src_mask != 0).unsqueeze(1).unsqueeze(2) # NOTE Here the src_mask contains ones and zeros
        tgt_mask = (tgt != settings["padding_token"]).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src, x_mask, tgt):
        src_mask, tgt_mask = self.generate_mask(x_mask, tgt)
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = self.fc(dec_output)
        return output

In [22]:
encoder_embedding = LandmarkEmbedding(
    d_model=settings["d_model"],
    num_features=settings["num_nodes_per_hand"]*settings["num_features_per_node"],
    num_conv_layers=settings["num_conv_layers"],
    filter_size=settings["filter_size"]
)
# encoder_embedding = GraphEmbedding(
#     num_nodes=settings["num_nodes_per_hand"],
#     num_features_per_node=settings["num_features_per_node"],
#     d_model=settings["d_model"],
#     hidden_dim=settings["hidden_dim"],
#     seq_len=settings["seq_len"],
#     batch_size=settings["batch_size"]
# )

transformer = Transformer(
    encoder_embedding=encoder_embedding,
    tgt_vocab_size = settings["tgt_vocab_size"],
    d_model=settings["d_model"],
    num_heads=settings["num_heads"],
    num_enc_layers=settings["num_enc_layer"],
    num_dec_layers=settings["num_dec_layer"],
    d_ff=settings["dff"],
    max_seq_length=max(settings["phrase_length"], settings["seq_len"]),
    dropout=settings["dropout"]
)

# Training

In [23]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [24]:
transformer.to(device)

Transformer(
  (encoder_embedding): LandmarkEmbedding(
    (conv_block): Sequential(
      (0): Conv1d(63, 32, kernel_size=(11,), stride=(1,), padding=same)
      (1): Conv1d(32, 32, kernel_size=(11,), stride=(1,), padding=same)
      (2): Conv1d(32, 32, kernel_size=(11,), stride=(1,), padding=same)
    )
  )
  (decoder_embedding): Embedding(40, 32)
  (positional_encoding): PositionalEncoding()
  (encoder_layers): ModuleList(
    (0-1): 2 x EncoderLayer(
      (self_attn): MultiHeadAttention(
        (W_q): Linear(in_features=32, out_features=32, bias=True)
        (W_k): Linear(in_features=32, out_features=32, bias=True)
        (W_v): Linear(in_features=32, out_features=32, bias=True)
        (W_o): Linear(in_features=32, out_features=32, bias=True)
      )
      (feed_forward): FeedForward(
        (fc1): Linear(in_features=32, out_features=32, bias=True)
        (fc2): Linear(in_features=32, out_features=32, bias=True)
        (relu): ReLU()
      )
      (norm1): LayerNorm((32,), 

In [25]:
# Source: https://github.com/jamfromouterspace/levenshtein

def levenshtein_distance(seq1, seq2):
    len_1, len_2 = len(seq1), len(seq2)
    dp = [[0] * (len_2 + 1) for _ in range(len_1 + 1)]

    for i in range(len_1 + 1):
        dp[i][0] = i
    for j in range(len_2 + 1):
        dp[0][j] = j

    for i in range(1, len_1 + 1):
        for j in range(1, len_2 + 1):
            cost = 0 if seq1[i - 1] == seq2[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,
                           dp[i][j - 1] + 1,
                           dp[i - 1][j - 1] + cost)

    return dp[len_1][len_2]

indices_to_ignore = torch.tensor([token_map[" "], token_map["<"], token_map["P"]]) # End of sequence is important

def masked_levenshtein(output_tokens, target_tokens):
    batch_size = output_tokens.size(0)
    total_distance = 0

    for i in range(batch_size):
        pred_seq = output_tokens[i].tolist()
        target_seq = target_tokens[i].tolist()

        # Find first occurrence of EOS in both sequences
        pred_eos_idx = next((idx for idx, token in enumerate(pred_seq) if token == token_map[EOS]), len(pred_seq))
        target_eos_idx = next((idx for idx, token in enumerate(target_seq) if token == token_map[EOS]), len(target_seq))

        # Trim the sequences at the first EOS index
        pred_trimmed = pred_seq[:pred_eos_idx]
        target_trimmed = target_seq[:target_eos_idx]

        # Compute Levenshtein distance for the trimmed sequences
        total_distance += levenshtein_distance(pred_trimmed, target_trimmed)

    # NOTE not sum of distances are returned!
    return total_distance


In [26]:
# # Optional lr scheduling as in Attention Is All You Need
# class LinearWarmupInverseSquarerootDecay(torch.optim.lr_scheduler.LambdaLR):
#     def __init__(self, d_model, warmup_steps=4000, optimizer=None):
#         self.d_model = d_model
#         self.warmup_steps = warmup_steps
#         super(LinearWarmupInverseSquarerootDecay, self).__init__(optimizer, self.lr_lambda)

#     def lr_lambda(self, step_num):
#         return (self.d_model ** -0.5) * min(step_num ** -0.5 if step_num != 0 else 1e20, step_num * (self.warmup_steps ** -1.5))
    
# steps = np.arange(0, 10000)
# learning_rates = [scheduler.lr_lambda(step) for step in steps]

# plt.figure(figsize=(6, 3))
# plt.plot(steps, learning_rates, label='Learning Rate')
# plt.title('Learning Rate Schedule')
# plt.xlabel('Step Number')
# plt.ylabel('Learning Rate')
# plt.grid(True)
# plt.legend()
# plt.show()

# scheduler = LinearWarmupInverseSquarerootDecay(d_model=d_model, optimizer=optimizer)

In [42]:
criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignoring padding
#optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9, weight_decay=0.0)
optimizer = torch.optim.Adam(transformer.parameters())

def run_epoch(loader, is_train=True):
    epoch_loss = 0
    total_correct = 0
    total_tokens = 0
    total_levenshtein_distance = 0

    if is_train:
        transformer.train()
    else:
        transformer.eval()

    for batch in loader:
        src_data, x_mask, tgt_data = batch
        src_data = src_data.to(device)
        x_mask = x_mask.to(device)
        tgt_data = tgt_data.to(device)

        optimizer.zero_grad()

        # Disable gradient calculation during validation
        with torch.set_grad_enabled(is_train):
            output = transformer(src_data, x_mask, tgt_data[:, :-1])
            loss = criterion(output.contiguous().view(-1, settings["tgt_vocab_size"]), 
                             tgt_data[:, 1:].contiguous().view(-1).long())

            if is_train:
                loss.backward()
                optimizer.step()

        epoch_loss += loss.item()

        # Calculate masked accuracy
        output_tokens = output.argmax(dim=-1)  # Shape: (batch_size, seq_len)
        non_pad_mask = tgt_data[:, 1:] != settings["padding_token"]  # Ignore padding tokens (mask for target data)
        correct = (output_tokens == tgt_data[:, 1:]) & non_pad_mask  # Compare predictions to targets and ignore padding
        total_correct += correct.sum().item()
        total_tokens += non_pad_mask.sum().item()

        # Calculate masked Levenshtein distance
        total_levenshtein_distance += masked_levenshtein(output_tokens, tgt_data[:, 1:])

    avg_loss = epoch_loss / len(loader)
    accuracy = total_correct / total_tokens
    return avg_loss, accuracy, total_levenshtein_distance/(settings["batch_size"]*len(loader))

training_metrics = []
for epoch in range(settings["epochs"]):
    train_loss, train_accuracy, train_levenshtein = run_epoch(train_loader, is_train=True)
    valid_loss, valid_accuracy, valid_levenshtein = run_epoch(valid_loader, is_train=False)
    
    display_text = f"Epoch {epoch+1}"
    display_text += f" | Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Train Levenshtein: {train_levenshtein:.4f}"
    display_text += f" | Valid Loss: {valid_loss:.4f}, Valid Accuracy: {valid_accuracy:.4f}, Valid Levenshtein: {valid_levenshtein:.4f}"
    print(display_text)

    training_metrics.append([epoch+1, train_loss, train_accuracy, train_levenshtein, valid_loss, valid_accuracy, valid_levenshtein])


Epoch 1 | Train Loss: 2.5188, Train Accuracy: 0.2377, Train Levenshtein: 12.6226 | Valid Loss: 2.4518, Valid Accuracy: 0.2567, Valid Levenshtein: 12.5116


In [52]:
import json

with open("settings.json", "w") as settings_file:
    settings_file.write(json.dumps(settings, indent=4))

torch.save(transformer.state_dict(), 'model.pth')

training_metrics_df = pd.DataFrame(
    training_metrics,
    columns=['epoch', 'train_loss', 'train_accuracy', 'train_levenshtein', 'valid_loss', 'valid_accuracy', 'valid_levenshtein']
)
training_metrics_df.to_csv("training_metrics.csv", index=False)

In [44]:
test_loss, test_accuracy, test_levenshtein = run_epoch(test_loader, is_train=False)

width = 34
print("\n" + "=" * width)
print(" Evaluation Results ")
print("=" * width)
print(f" Test Loss: {test_loss:.4f}")
print(f" Test Accuracy: {100*test_accuracy:.2f}%")
print(f" Test Levenshtein Distance: {test_levenshtein:.2f}")
print("=" * width)


 Evaluation Results 
 Test Loss: 2.4384
 Test Accuracy: 25.69%
 Test Levenshtein Distance: 12.15
