# Modified Transformer Architecture

Sources:

- [correct `transformer implementation` from scratch in `pytorch`](https://towardsdatascience.com/build-your-own-transformer-from-scratch-using-pytorch-84c850470dcb) (in-depth tutorial from same author: [part1](https://towardsdatascience.com/all-you-need-to-know-about-attention-and-transformers-in-depth-understanding-part-1-552f0b41d021), [part2](https://towardsdatascience.com/all-you-need-to-know-about-attention-and-transformers-in-depth-understanding-part-2-bf2403804ada))
- [nice visuals for understanding `multi-head attention`](http://jalammar.github.io/illustrated-transformer/)
- [`positional encoding`](https://machinelearningmastery.com/a-gentle-introduction-to-positional-encoding-in-transformer-models-part-1/)
- [kaggle transformer code: ❗Contains mistakes (see comments), but nice overall explanation](https://www.kaggle.com/code/arunmohan003/transformer-from-scratch-using-pytorch)



In [1]:
!python --version

Python 3.9.13


In [2]:
import os
import math

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
seed = 42
torch.manual_seed(seed)
print(f"torch version: {torch.__version__}")

torch version: 2.4.1+cpu


# Dataset

For details about the dataset see `data_handling.ipynb`

In [29]:
SAVE_FOLDER = "processed_dataset"

df = pd.read_pickle(os.path.join(SAVE_FOLDER, "data.pkl"))
df.head()

Unnamed: 0_level_0,x_dominant_hand_0,x_dominant_hand_1,x_dominant_hand_2,x_dominant_hand_3,x_dominant_hand_4,x_dominant_hand_5,x_dominant_hand_6,x_dominant_hand_7,x_dominant_hand_8,x_dominant_hand_9,...,z_dominant_hand_11,z_dominant_hand_12,z_dominant_hand_13,z_dominant_hand_14,z_dominant_hand_15,z_dominant_hand_16,z_dominant_hand_17,z_dominant_hand_18,z_dominant_hand_19,z_dominant_hand_20
sequence_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1816796431,0.408832,0.519912,0.612159,0.707576,0.797313,0.494709,0.532817,0.553556,0.566219,0.391196,...,-0.245855,-0.269148,-0.129743,-0.251501,-0.278687,-0.26653,-0.152852,-0.257519,-0.275822,-0.266876
1816796431,0.398663,0.523662,0.638807,0.744236,0.832567,0.538486,0.564302,0.581011,0.597674,0.441541,...,-0.37077,-0.408097,-0.185217,-0.325494,-0.343373,-0.328294,-0.203126,-0.315719,-0.326104,-0.314282
1816796431,0.41929,0.509726,0.593165,0.685492,0.777913,0.483669,0.510993,0.53641,0.564583,0.393016,...,-0.28577,-0.318548,-0.155317,-0.274822,-0.312119,-0.316411,-0.181363,-0.286298,-0.316182,-0.322671
1816796431,0.398764,0.498118,0.583356,0.677779,0.775966,0.481279,0.491659,0.524974,0.571944,0.412262,...,-0.235725,-0.267054,-0.14138,-0.219369,-0.256553,-0.27369,-0.170996,-0.240285,-0.266193,-0.27811
1816796431,0.420213,0.49565,0.57179,0.659049,0.74974,0.485707,0.47593,0.501727,0.53915,0.438294,...,-0.186706,-0.217181,-0.10774,-0.165642,-0.201059,-0.222898,-0.131329,-0.183113,-0.208774,-0.225284


In [40]:
metadata_df = pd.read_csv(os.path.join(SAVE_FOLDER, "metadata.csv"), header=0)

max_phrase_len = max([len(it) for it in metadata_df.phrase.values])
possible_characters = sorted(set.union(*[set(p) for p in metadata_df.phrase.values]))
token_map = {c: i+3 for i, c in enumerate(possible_characters)}
token_map['P'] = 0 # padding
token_map['<'] = 1 # SOS
token_map['>'] = 2 # EOS
metadata_df.phrase = metadata_df.phrase.apply(lambda it: np.array([token_map[c] for c in '<'+it+'>'+('P'*(max_phrase_len-len(it)))], dtype=np.int32))

index = {row[0]: {"phrase": row[1], "signer_id": row[2]} for row in metadata_df.values}
str(index)[:100] + "..."

"{1816796431: {'phrase': array([ 1,  7,  3, 16, 31, 18, 18, 24, 21, 28, 34, 32, 18,  2,  0,  0,  0,\n ..."

In [5]:
# from sklearn.model_selection import train_test_split

# sequence_ids = df.index.unique()

# train_ids, temp_ids = train_test_split(sequence_ids, test_size=0.3, random_state=seed)
# valid_ids, test_ids = train_test_split(temp_ids, test_size=0.5, random_state=seed)

# train_df = df[df.index.isin(train_ids)]
# valid_df = df[df.index.isin(valid_ids)]
# test_df = df[df.index.isin(test_ids)]

In [41]:
import torch

class TransformerDataset(torch.utils.data.Dataset):
    # NOTE don't change the padding value as the Transformer still relies on 0
    def __init__(self, df, meta_data, seq_len=256, padding_value=0):
        self.df = df
        self.meta_data = meta_data
        self.sequence_ids = df.index.unique()
        self.padding_value = padding_value
        self.seq_len = seq_len

    def __len__(self):
        return len(self.sequence_ids)

    def __getitem__(self, idx):
        sequence_id = self.sequence_ids[idx]
        x_values = torch.tensor(self.df.loc[sequence_id].values, dtype=torch.float32)

        # Apply padding if the sequence is shorter than seq_len
        if x_values.shape[0] < self.seq_len:
            padding_size = self.seq_len - x_values.shape[0]
            padding = torch.full((padding_size, x_values.shape[1]), self.padding_value)
            x_values = torch.cat([x_values, padding], dim=0)
        elif x_values.shape[0] > self.seq_len:
            # Truncate the sequence if it's longer than seq_len
            x_values = x_values[:self.seq_len]

        y_phrase = self.meta_data[sequence_id]['phrase']
        return x_values, y_phrase

dataset = TransformerDataset(df, index)

In [84]:
from torch.utils.data import DataLoader

# NOTE Refactor DataSet (move huge objects outside) before increasing number of workers
dataloader = DataLoader(dataset, batch_size=32, num_workers=0)

for it in dataloader:
    print(it[0].shape)
    print(it[1])
    break

torch.Size([32, 256, 63])
tensor([[ 1,  7,  3,  ...,  0,  0,  0],
        [ 1,  5,  7,  ...,  0,  0,  0],
        [ 1, 13, 12,  ...,  0,  0,  0],
        ...,
        [ 1,  8,  5,  ...,  0,  0,  0],
        [ 1,  8,  5,  ...,  0,  0,  0],
        [ 1, 26, 14,  ...,  0,  0,  0]], dtype=torch.int32)


# Constructing the Transformer

In [86]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output

In [87]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len, n=10000):
        super(PositionalEncoding, self).__init__()

        assert d_model % 2 == 0, "d_model must be even"
        
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(n) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # NOTE
        # register buffer in Pytorch ->
        # If you have parameters in your model, which should be saved and restored in the state_dict,
        # but not trained by the optimizer, you should register them as buffers.
        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        # Broadcasting mechanism (automatically works for multiple batches even when shapes don't match along batch dim)
        return x + self.pe[:, :x.size(1)]

In [None]:
class PositionalLandmarkEmbedding(nn.Module):
    def __init__(self, len_of_seq, d_model, num_features, num_conv_layers=3, filter_size=11):
        super(PositionalLandmarkEmbedding, self).__init__()
        self.d_model = d_model
        self.len_of_seq = len_of_seq

        first_conv = nn.Conv1d(in_channels=num_features, out_channels=d_model, kernel_size=filter_size, padding='same')
        rest_of_convs = [
            nn.Conv1d(in_channels=d_model, out_channels=d_model, kernel_size=filter_size, padding='same')
            for _ in range(num_conv_layers-1)
        ]
        
        self.conv_block = nn.Sequential(
            first_conv, *rest_of_convs
        )
        
        self.pos_encoding = PositionalEncoding(d_model, len_of_seq)

    def forward(self, x):
        # x is expected to be of shape (batch_size, seq_len, num_of_features)
        # Permute to (batch_size, num_of_features, seq_len) for Conv1D
        x = x.permute(0, 2, 1)
        # Apply convolutional layers
        x = self.conv_block(x)
        # Scale the output
        x *= math.sqrt(self.d_model)
        # Permute back to (batch_size, seq_len, d_model)
        x = x.permute(0, 2, 1)
        # Add positional encoding
        x = self.pos_encoding(x)
        return x

In [None]:
from torch_geometric.nn import GCNConv, GATConv, GATv2Conv
from torch_geometric.nn import global_mean_pool

class PositionalGraphEmbedding(nn.Module):
    def __init__(self, len_of_seq, d_model, num_features_per_node):
        super(PositionalLandmarkEmbedding, self).__init__()
        self.d_model = d_model

        # TODO experiment with hyperparameters
        hidden_dim = 256
        self.conv1 = GATConv(num_features_per_node, hidden_dim, heads=1, concat=False)
        self.conv2 = GATConv(hidden_dim, d_model, heads=1, concat=False)
        
        self.pos_encoding = PositionalEncoding(d_model, len_of_seq)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        return x

In [88]:
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

In [89]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

In [90]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.masked_self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output = self.masked_self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

In [91]:
class Transformer(nn.Module):
    def __init__(self, len_of_seq, num_inp_features, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.encoder_embedding = PositionalLandmarkEmbedding(
            len_of_seq=len_of_seq,
            d_model=d_model,
            num_features=num_inp_features,
            num_conv_layers=3,
            filter_size=11
        )
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = self.fc(dec_output)
        return output

In [92]:
src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

# Generate random sample data
src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)

In [80]:
src_data.shape

torch.Size([64, 100])

In [81]:
tgt_data.shape

torch.Size([64, 100])

In [72]:
# # Optional lr scheduling as in Attention Is All You Need
# class LinearWarmupInverseSquarerootDecay(torch.optim.lr_scheduler.LambdaLR):
#     def __init__(self, d_model, warmup_steps=4000, optimizer=None):
#         self.d_model = d_model
#         self.warmup_steps = warmup_steps
#         super(LinearWarmupInverseSquarerootDecay, self).__init__(optimizer, self.lr_lambda)

#     def lr_lambda(self, step_num):
#         return (self.d_model ** -0.5) * min(step_num ** -0.5 if step_num != 0 else 1e20, step_num * (self.warmup_steps ** -1.5))
    
# steps = np.arange(0, 10000)
# learning_rates = [scheduler.lr_lambda(step) for step in steps]

# plt.figure(figsize=(6, 3))
# plt.plot(steps, learning_rates, label='Learning Rate')
# plt.title('Learning Rate Schedule')
# plt.xlabel('Step Number')
# plt.ylabel('Learning Rate')
# plt.grid(True)
# plt.legend()
# plt.show()

# scheduler = LinearWarmupInverseSquarerootDecay(d_model=d_model, optimizer=optimizer)

In [82]:
criterion = nn.CrossEntropyLoss(ignore_index=0) # ignoring padding
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9, weight_decay=0.0)
# optimizer = torch.optim.Adam(transformer.parameters())

transformer.train()

for epoch in range(100):
    optimizer.zero_grad()
    output = transformer(src_data, tgt_data[:, :-1])
    loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")