# Setup

In [68]:
# This is my first transformer project
# The idea is to create a transformer that can take as imput a list (N mod p1, N mod p2, ..., N mod pn)
# where N is a number (integer, rational, etc) and p1,p2,...,pn are prime numbers
# and returns the number N

# %% 
# Let us start by importing the necessary libraries
import torch  # Main framework for defining and training the transformer
import torch.nn as nn  # Neural network module
import torch.optim as optim  # Optimization functions
import numpy as np  # For numerical operations
import random  # For generating random numbers
import itertools  # (Optional) For generating structured datasets
import math  # For mathematical operations

import matplotlib.pyplot as plt  # (Optional) For visualization
from torch.utils.data import Dataset, DataLoader  # To handle training data efficiently

import time # For timing the training process

import json # For saving and loading the model

from torch.nn.utils.rnn import pad_sequence # For padding sequences to the same length

In [69]:
# Load configuration from a JSON file
with open("config_T_1.json", "r") as f:
    config = json.load(f)

# Access dataset parameters:
primes_list = config["dataset_parameters"]["primes_list"]  # List of prime numbers
number_samples = config["dataset_parameters"]["number_samples"]  # Number of samples to generate

# Access model parameters:
model_dimension = config["model_parameters"]["model_dimension"]
number_heads = config["model_parameters"]["number_heads"]
number_encoder_layers = config["model_parameters"]["number_encoder_layers"]
number_decoder_layers = config["model_parameters"]["number_decoder_layers"]
dimension_feedforward = config["model_parameters"]["dimension_feedforward"]
dropout_rate = config["model_parameters"]["dropout_rate"]
max_length = config["model_parameters"]["positional_encoding_maximum_length"]  # Maximum sequence length for positional encoding
# source and target vocabulary
# For instance, suppose our source vocabulary (mod values) is 0..11 and target vocabulary (digits) is 0..9 plus a special token.
src_vocab_size = config["model_parameters"]["source_vocab_size"]   # digits 0-9 plus 4 special tokens (SOS, EOS, SEP, PAD)
tgt_vocab_size = config["model_parameters"]["target_vocab_size"]   # digits 0-9 plus 2 special tokens (SOS, EOS)

learning_rate = config["training_parameters"]["learning_rate"]
batch_size = config["training_parameters"]["batch_size"]
number_epochs = config["training_parameters"]["number_epochs"]

print("Loaded configuration:")
print(config)

Loaded configuration:
{'dataset_parameters': {'primes_list': [3, 5, 7, 11], 'number_samples': 5}, 'model_parameters': {'model': 'Seq2SeqTransformer', 'model_dimension': 64, 'number_heads': 4, 'number_encoder_layers': 2, 'number_decoder_layers': 2, 'dimension_feedforward': 128, 'dropout_rate': 0.1, 'source_vocab_size': 14, 'target_vocab_size': 12, 'positional_encoding_maximum_length': 500}, 'training_parameters': {'learning_rate': 0.001, 'batch_size': 32, 'number_epochs': 100, 'optimizer': 'Adam'}, 'log_params': {'experiment_name': 'experiment_001', 'notes': 'First experiment with Seq2SeqTransformer'}}


# Dataset

In [70]:
# Define special tokens
SOS_TOKEN = 10   # start-of-sequence for target
EOS_TOKEN = 11   # end-of-sequence for target
SEP_TOKEN = 12   # separator token for input moduli

def tokenize_moduli(N, primes):
    tokens = [SOS_TOKEN] # Start with the SOS token
    for p in primes:
        remainder = N % p
        # Convert the remainder into its constituent digits
        tokens.extend([int(d) for d in str(remainder)])
        # Append a separator token after each remainder
        tokens.append(SEP_TOKEN)
    # Remove the final separator since it's not needed
    if tokens:
        tokens = tokens[:-1]
    # Append the EOS token at the end
    tokens.append(EOS_TOKEN)
    return tokens

class TranslationDataset(Dataset):
    def __init__(self, num_samples=number_samples, primes=primes_list):
        self.primes = primes
        # Calculate the product of primes for range of N
        self.P = 1
        for p in primes:
            self.P *= p
        self.samples = []
        for _ in range(num_samples):
            # Generate a random integer N in [0, P)
            N = torch.randint(0, self.P, (1,)).item()
            # Tokenize the input: each remainder becomes a sequence of digits with separators SEP
            input_tokens = tokenize_moduli(N, primes)
            # Prepare the target: add <SOS> at the beginning and <EOS> at the end
            output_tokens = [SOS_TOKEN] + [int(d) for d in str(N)] + [EOS_TOKEN]
            self.samples.append((torch.tensor(input_tokens, dtype=torch.long),
                                 torch.tensor(output_tokens, dtype=torch.long)))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        return self.samples[idx]

# Quick test of the updated dataset:
dataset = TranslationDataset(num_samples=number_samples)
print("Sample from dataset:", dataset[0])



Sample from dataset: (tensor([10,  0, 12,  0, 12,  0, 12,  0, 11]), tensor([10,  0, 11]))


In [71]:
#simple test to check the tokenization function
testnum = 1014  # Get a random sample from the dataset
print(testnum)
print([testnum % p for p in primes_list])
print(tokenize_moduli(testnum, primes_list))

1014
[0, 4, 6, 2]
[10, 0, 12, 4, 12, 6, 12, 2, 11]


In [72]:

PAD_TOKEN = 13  # Define a PAD token index (adjust your vocab sizes accordingly)

def collate_fn(batch):
    # Each batch element is a tuple: (src, tgt)
    src_batch, tgt_batch = zip(*batch)
    # Pad the sequences so that all sequences in the batch have equal length
    src_batch = pad_sequence(src_batch, batch_first=True, padding_value=PAD_TOKEN)
    tgt_batch = pad_sequence(tgt_batch, batch_first=True, padding_value=PAD_TOKEN)
    return src_batch, tgt_batch

# Create a DataLoader using the collate function:
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)


# Positional encoding + masking

In [84]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=500):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)  # Create a (max_len, d_model) matrix.
        position = torch.arange(0, max_len).unsqueeze(1)  # Shape: (max_len, 1) with positions 0,1,2,...
        # Compute a scaling factor for each even dimension.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # For even indices: use sine; for odd indices: use cosine.
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)  # Shape becomes (max_len, 1, d_model) for easy broadcasting.
        self.register_buffer('pe', pe)  # Register as a buffer so it’s part of the module but not a parameter.

    def forward(self, x):
        # x shape: (seq_len, batch_size, d_model)
        x = x + self.pe[:x.size(0)]  # Add positional encoding to each token embedding.
        return self.dropout(x)


In [85]:
# Test parameters
d_model = 4      # Dimensionality of embeddings/positional encodings
seq_len = 1      # Sequence length
batch_size = 4   # Batch size

# Create a dummy token matrix (for example, all zeros)
dummy_tokens = torch.zeros(seq_len, batch_size, d_model)
print("Original token matrix:")
print(dummy_tokens)

# Instantiate PositionalEncoding with no dropout for clarity
pos_enc = PositionalEncoding(d_model, dropout=0.0, max_len=10)

print("\nPositional encodings (first 4 positions):")
# The positional encoding matrix has shape (max_len, 1, d_model)
# We'll print the first 4 positions, which correspond to our sequence length.
print(pos_enc.pe[:seq_len])

# Add positional encoding to the dummy tokens
tokens_with_pe = pos_enc(dummy_tokens)
print("\nToken matrix after adding positional encoding:")
print(tokens_with_pe)

Original token matrix:
tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]])

Positional encodings (first 4 positions):
tensor([[[0., 1., 0., 1.]]])

Token matrix after adding positional encoding:
tensor([[[0., 1., 0., 1.],
         [0., 1., 0., 1.],
         [0., 1., 0., 1.],
         [0., 1., 0., 1.]]])


In [86]:
# OLD float mask
#def generate_square_subsequent_mask(sz):
#    # Create an upper-triangular matrix filled with ones
#    mask = torch.triu(torch.ones(sz, sz), diagonal=1)
#    # Replace 1's with -infinity and 0's with 0.0 so that the softmax later ignores the future positions.
#    mask = mask.masked_fill(mask == 1, float('-inf')).masked_fill(mask == 0, float(0.0))
#    return mask

# NEW boolean mask
def generate_square_subsequent_mask(sz, device=None):
    """
    Returns a boolean matrix of shape (sz, sz) where
    `True`  = block attention (upper-triangle, i.e. future positions)
    `False` = allow attention (diagonal & lower-triangle)
    """
    return torch.triu(
        torch.ones(sz, sz, dtype=torch.bool, device=device),
        diagonal=1
    )

# Example usage:
tgt_seq_len = 5  # suppose our target sequence length is 5
tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
print("Target Mask:\n", tgt_mask)

Target Mask:
 tensor([[False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [False, False, False, False, False]])


# Transformer model

In [87]:
# Seq2Seq Transformer model
class Seq2SeqTransformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=model_dimension, nhead=number_heads,
                 num_encoder_layers=number_encoder_layers, num_decoder_layers=number_decoder_layers, dim_feedforward=dimension_feedforward, dropout=dropout_rate):
        super(Seq2SeqTransformer, self).__init__()
        self.d_model = d_model
        # Embedding layers for source (moduli) and target (digits).
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        # Positional encodings for source and target.
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.pos_decoder = PositionalEncoding(d_model, dropout)
        # Transformer module from PyTorch.
        self.transformer = nn.Transformer(d_model, nhead,
                                          num_encoder_layers, num_decoder_layers,
                                          dim_feedforward, dropout)
        # Final linear layer maps transformer output to target vocabulary logits.
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # Expect src and tgt shapes: (batch_size, seq_len)
        # Transpose to shape: (seq_len, batch_size) as required by the transformer.
        src = src.transpose(0, 1)
        tgt = tgt.transpose(0, 1)
        # Obtain token embeddings and apply scaling.
        src_emb = self.src_embedding(src) * math.sqrt(self.d_model)
        tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        # Add positional encodings.
        src_emb = self.pos_encoder(src_emb)
        tgt_emb = self.pos_decoder(tgt_emb)
        # Forward pass through the transformer.
        outs = self.transformer(src_emb, tgt_emb, src_mask=src_mask, tgt_mask=tgt_mask, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask)
        # Project transformer outputs to logits and transpose back to (batch_size, seq_len, vocab_size).
        logits = self.fc_out(outs)
        return logits.transpose(0, 1)


In [88]:

# Instantiate the model.
model = Seq2SeqTransformer(src_vocab_size, tgt_vocab_size, d_model=model_dimension, nhead=number_heads,
                           num_encoder_layers=number_encoder_layers, num_decoder_layers=number_decoder_layers, dim_feedforward=dimension_feedforward, dropout=dropout_rate)
print(model)

Seq2SeqTransformer(
  (src_embedding): Embedding(14, 64)
  (tgt_embedding): Embedding(12, 64)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (pos_decoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-1): 2 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
          )
          (linear1): Linear(in_features=64, out_features=128, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=128, out_features=64, bias=True)
          (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=Fals



# Checks

In [90]:
#define a simple test dataset
batch_size = 4
primes = primes_list
sample_ds = TranslationDataset(num_samples=batch_size, primes=primes)
src, tgt = zip(*sample_ds)                 # list of tensors
src = pad_sequence(src, batch_first=True, padding_value=PAD_TOKEN)
tgt = pad_sequence(tgt, batch_first=True, padding_value=PAD_TOKEN)
print(src)
print(tgt)

tensor([[10,  2, 12,  4, 12,  0, 12,  4, 11, 13],
        [10,  2, 12,  0, 12,  3, 12,  4, 11, 13],
        [10,  1, 12,  3, 12,  2, 12,  4, 11, 13],
        [10,  2, 12,  1, 12,  5, 12,  1,  0, 11]])
tensor([[10,  2,  2,  4, 11],
        [10,  2,  9,  0, 11],
        [10,  2,  6,  8, 11],
        [10,  1,  3,  1, 11]])


In [91]:
# Teacher-forcing split
# i.e. return the target sequence without the last token and the target sequence without the first token
# these two cases are used as part of the transformer architecture
tgt_in  = tgt[:, :-1]
tgt_lab = tgt[:, 1:]
print("Target input (tgt_in):", tgt_in)
print("Target labels (tgt_lab):", tgt_lab)

Target input (tgt_in): tensor([[10,  2,  2,  4],
        [10,  2,  9,  0],
        [10,  2,  6,  8],
        [10,  1,  3,  1]])
Target labels (tgt_lab): tensor([[ 2,  2,  4, 11],
        [ 2,  9,  0, 11],
        [ 2,  6,  8, 11],
        [ 1,  3,  1, 11]])


In [92]:
# Generate masks for the source and target sequences
# Note masks will only be True if there is a padding token in the sequence
tgt_mask = generate_square_subsequent_mask(tgt_in.size(1),device=src.device)
src_kpm = (src == PAD_TOKEN)
tgt_kpm = (tgt_in == PAD_TOKEN)
print("Source key padding mask (src_kpm):", src_kpm)
print("Target key padding mask (tgt_kpm):", tgt_kpm)

Source key padding mask (src_kpm): tensor([[False, False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False]])
Target key padding mask (tgt_kpm): tensor([[False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False]])


In [None]:
# Forward pass through the model, check no errors
logits = model(
    src, tgt_in,
    src_mask=None,
    tgt_mask=tgt_mask,
    src_key_padding_mask=src_kpm,
    tgt_key_padding_mask=tgt_kpm,
    memory_key_padding_mask=src_kpm
)

assert logits.shape == (batch_size, tgt_in.size(1), tgt_vocab_size)
print("✓ forward pass shape OK")

✓ forward pass shape OK


In [94]:
# Compute the loss using CrossEntropyLoss
# check everything works as expected
criterion = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_lab.reshape(-1))

loss.backward()        # should succeed without NaNs/Infs
print("✓ backward gradient computed, loss =", float(loss))
model.zero_grad()

✓ backward gradient computed, loss = 2.72538423538208


In [95]:
# count trainable parameters
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total trainable parameters:", total_params)
# ≈ 420 k for 2-layer enc/dec, d_model 64, FF 128, vocab ≤ 14

Total trainable parameters: 170124


In [99]:
# Test the encoder respects the source padding mask
with torch.no_grad():
    src_masked = src.clone()
    src_masked[::2, 0] = PAD_TOKEN          # force PAD at position 0 of every 2nd sample
    src_kpm2 = (src_masked == PAD_TOKEN)

    memory = model.transformer.encoder(
        model.pos_encoder(model.src_embedding(src_masked.T) * math.sqrt(model.d_model)),
        src_key_padding_mask=src_kpm2
    )                                       # shape (seq, batch, 64)

    pad_norm  = memory[0, ::2].pow(2).sum(-1).sqrt().mean().item()
    real_norm = memory[0, 1::2].pow(2).sum(-1).sqrt().mean().item()
    assert pad_norm < 0.3 * real_norm, "Padding positions carry too much signal"
print("✓ encoder respects src padding mask")

AssertionError: Padding positions carry too much signal

# Training loop