# Setup

In [167]:
# This is my first transformer project
# The idea is to create a transformer that can take as imput a list (N mod p1, N mod p2, ..., N mod pn)
# where N is a number (integer, rational, etc) and p1,p2,...,pn are prime numbers
# and returns the number N

# %% 
# Let us start by importing the necessary libraries
import torch  # Main framework for defining and training the transformer
import torch.nn as nn  # Neural network module
import torch.optim as optim  # Optimization functions
import numpy as np  # For numerical operations
import random  # For generating random numbers
import itertools  # (Optional) For generating structured datasets
import math  # For mathematical operations

import matplotlib.pyplot as plt  # (Optional) For visualization
from torch.utils.data import Dataset, DataLoader  # To handle training data efficiently

import time # For timing the training process

import json # For saving and loading the model

from torch.nn.utils.rnn import pad_sequence # For padding sequences to the same length

In [168]:
# Define the device to use for training
#OLD: device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.backends.mps.is_available():
    device = torch.device("mps")          # Apple-silicon GPU
elif torch.cuda.is_available():
    device = torch.device("cuda")         # NVIDIA / AMD GPU
else:
    device = torch.device("cpu")

print("Using device:", device)

Using device: mps


In [169]:
# Load configuration from a JSON file
with open("config_T_2.json", "r") as f:
    config = json.load(f)

# Access dataset parameters:
primes_list = config["dataset_parameters"]["primes_list"]  # List of prime numbers
number_samples = config["dataset_parameters"]["number_samples"]  # Number of samples to generate

# Access model parameters:
model_dimension = config["model_parameters"]["model_dimension"]
number_heads = config["model_parameters"]["number_heads"]
number_encoder_layers = config["model_parameters"]["number_encoder_layers"]
number_decoder_layers = config["model_parameters"]["number_decoder_layers"]
dimension_feedforward = config["model_parameters"]["dimension_feedforward"]
dropout_rate = config["model_parameters"]["dropout_rate"]
max_length = config["model_parameters"]["positional_encoding_maximum_length"]  # Maximum sequence length for positional encoding
# source and target vocabulary
# For instance, suppose our source vocabulary (mod values) is 0..11 and target vocabulary (digits) is 0..9 plus a special token.
src_vocab_size = config["model_parameters"]["source_vocab_size"]   # digits 0-9 plus 4 special tokens (SOS, EOS, SEP, PAD)
tgt_vocab_size = config["model_parameters"]["target_vocab_size"]   # digits 0-9 plus 2 special tokens (SOS, EOS)

learning_rate = config["training_parameters"]["learning_rate"]
batch_size = config["training_parameters"]["batch_size"]
number_epochs = config["training_parameters"]["number_epochs"]

print("Loaded configuration:")
print(config)

Loaded configuration:
{'dataset_parameters': {'primes_list': [11, 13, 17, 23], 'number_samples': 10000}, 'model_parameters': {'model': 'Seq2SeqTransformer', 'model_dimension': 128, 'number_heads': 8, 'number_encoder_layers': 4, 'number_decoder_layers': 4, 'dimension_feedforward': 512, 'dropout_rate': 0.1, 'source_vocab_size': 14, 'target_vocab_size': 13, 'positional_encoding_maximum_length': 500}, 'training_parameters': {'learning_rate': 0.001, 'batch_size': 32, 'number_epochs': 100, 'optimizer': 'Adam'}, 'log_params': {'experiment_name': 'experiment_001', 'notes': 'First experiment with Seq2SeqTransformer'}, 'special_tokens': {'PAD_TOKEN': 10, 'SOS_TOKEN': 11, 'EOS_TOKEN': 12, 'SEP_TOKEN': 13}}


In [170]:
print("MPS available:", torch.backends.mps.is_available())

MPS available: True


# Dataset

In [171]:
# Define special tokens
SOS_TOKEN = config["special_tokens"]["SOS_TOKEN"]   # start-of-sequence for target
EOS_TOKEN = config["special_tokens"]["EOS_TOKEN"]   # end-of-sequence for target
SEP_TOKEN = config["special_tokens"]["SEP_TOKEN"]   # separator token for input moduli

def tokenize_moduli(N, primes):
    tokens = [SOS_TOKEN] # Start with the SOS token
    for p in primes:
        remainder = N % p
        # Convert the remainder into its constituent digits
        tokens.extend([int(d) for d in str(remainder)])
        # Append a separator token after each remainder
        tokens.append(SEP_TOKEN)
    # Remove the final separator since it's not needed
    if tokens:
        tokens = tokens[:-1]
    # Append the EOS token at the end
    tokens.append(EOS_TOKEN)
    return tokens

class TranslationDataset(Dataset):
    def __init__(self, num_samples=number_samples, primes=primes_list):
        self.primes = primes
        # Calculate the product of primes for range of N
        self.P = 1
        for p in primes:
            self.P *= p
        self.samples = []
        for _ in range(num_samples):
            # Generate a random integer N in [0, P)
            N = torch.randint(0, self.P, (1,)).item()
            # Tokenize the input: each remainder becomes a sequence of digits with separators SEP
            input_tokens = tokenize_moduli(N, primes)
            # Prepare the target: add <SOS> at the beginning and <EOS> at the end
            output_tokens = [SOS_TOKEN] + [int(d) for d in str(N)] + [EOS_TOKEN]
            self.samples.append((torch.tensor(input_tokens, dtype=torch.long),
                                 torch.tensor(output_tokens, dtype=torch.long)))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        return self.samples[idx]

# Quick test of the updated dataset:
dataset = TranslationDataset(num_samples=number_samples)
print("Sample from dataset:", dataset[0])



Sample from dataset: (tensor([11,  1,  0, 13,  4, 13,  1,  3, 13,  0, 12]), tensor([11,  1,  2,  6,  2,  7, 12]))


In [172]:
#simple test to check the tokenization function
testnum = 1014  # Get a random sample from the dataset
print(testnum)
print([testnum % p for p in primes_list])
print(tokenize_moduli(testnum, primes_list))

1014
[2, 0, 11, 2]
[11, 2, 13, 0, 13, 1, 1, 13, 2, 12]


In [173]:

PAD_TOKEN = config["special_tokens"]["PAD_TOKEN"]  # Define a PAD token index (adjust your vocab sizes accordingly)

def collate_fn(batch):
    # Each batch element is a tuple: (src, tgt)
    src_batch, tgt_batch = zip(*batch)
    # Pad the sequences so that all sequences in the batch have equal length
    src_batch = pad_sequence(src_batch, batch_first=True, padding_value=PAD_TOKEN)
    tgt_batch = pad_sequence(tgt_batch, batch_first=True, padding_value=PAD_TOKEN)
    return src_batch, tgt_batch

# Create a DataLoader using the collate function:
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)


# Positional encoding + masking

In [174]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=500):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)  # Create a (max_len, d_model) matrix.
        position = torch.arange(0, max_len).unsqueeze(1)  # Shape: (max_len, 1) with positions 0,1,2,...
        # Compute a scaling factor for each even dimension.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # For even indices: use sine; for odd indices: use cosine.
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)  # Shape becomes (max_len, 1, d_model) for easy broadcasting.
        self.register_buffer('pe', pe)  # Register as a buffer so it’s part of the module but not a parameter.

    def forward(self, x):
        # x shape: (seq_len, batch_size, d_model)
        x = x + self.pe[:x.size(0)]  # Add positional encoding to each token embedding.
        return self.dropout(x)


In [175]:
# Test parameters
d_model = 4      # Dimensionality of embeddings/positional encodings
seq_len = 1      # Sequence length
batch_size = 4   # Batch size

# Create a dummy token matrix (for example, all zeros)
dummy_tokens = torch.zeros(seq_len, batch_size, d_model)
print("Original token matrix:")
print(dummy_tokens)

# Instantiate PositionalEncoding with no dropout for clarity
pos_enc = PositionalEncoding(d_model, dropout=0.0, max_len=10)

print("\nPositional encodings (first 4 positions):")
# The positional encoding matrix has shape (max_len, 1, d_model)
# We'll print the first 4 positions, which correspond to our sequence length.
print(pos_enc.pe[:seq_len])

# Add positional encoding to the dummy tokens
tokens_with_pe = pos_enc(dummy_tokens)
print("\nToken matrix after adding positional encoding:")
print(tokens_with_pe)

Original token matrix:
tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]])

Positional encodings (first 4 positions):
tensor([[[0., 1., 0., 1.]]])

Token matrix after adding positional encoding:
tensor([[[0., 1., 0., 1.],
         [0., 1., 0., 1.],
         [0., 1., 0., 1.],
         [0., 1., 0., 1.]]])


In [176]:
# OLD float mask
#def generate_square_subsequent_mask(sz):
#    # Create an upper-triangular matrix filled with ones
#    mask = torch.triu(torch.ones(sz, sz), diagonal=1)
#    # Replace 1's with -infinity and 0's with 0.0 so that the softmax later ignores the future positions.
#    mask = mask.masked_fill(mask == 1, float('-inf')).masked_fill(mask == 0, float(0.0))
#    return mask

# NEW boolean mask
def generate_square_subsequent_mask(sz, device=None):
    """
    Returns a boolean matrix of shape (sz, sz) where
    `True`  = block attention (upper-triangle, i.e. future positions)
    `False` = allow attention (diagonal & lower-triangle)
    """
    return torch.triu(
        torch.ones(sz, sz, dtype=torch.bool, device=device),
        diagonal=1
    )

# Example usage:
tgt_seq_len = 5  # suppose our target sequence length is 5
tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
print("Target Mask:\n", tgt_mask)

Target Mask:
 tensor([[False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [False, False, False, False, False]])


# Transformer model

In [177]:
# Seq2Seq Transformer model
class Seq2SeqTransformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=model_dimension, nhead=number_heads,
                 num_encoder_layers=number_encoder_layers, num_decoder_layers=number_decoder_layers, dim_feedforward=dimension_feedforward, dropout=dropout_rate):
        super(Seq2SeqTransformer, self).__init__()
        self.d_model = d_model
        # Embedding layers for source (moduli) and target (digits).
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        # Positional encodings for source and target.
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len=max_length)
        self.pos_decoder = PositionalEncoding(d_model, dropout, max_len=max_length)
        # Transformer module from PyTorch.
        self.transformer = nn.Transformer(d_model, nhead,
                                          num_encoder_layers, num_decoder_layers,
                                          dim_feedforward, dropout)
        # Final linear layer maps transformer output to target vocabulary logits.
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # Expect src and tgt shapes: (batch_size, seq_len)
        # Transpose to shape: (seq_len, batch_size) as required by the transformer.
        src = src.transpose(0, 1)
        tgt = tgt.transpose(0, 1)
        # Obtain token embeddings and apply scaling.
        src_emb = self.src_embedding(src) * math.sqrt(self.d_model)
        tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        # Add positional encodings.
        src_emb = self.pos_encoder(src_emb)
        tgt_emb = self.pos_decoder(tgt_emb)
        # Forward pass through the transformer.
        outs = self.transformer(src_emb, tgt_emb, src_mask=src_mask, tgt_mask=tgt_mask, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask)
        # Project transformer outputs to logits and transpose back to (batch_size, seq_len, vocab_size).
        logits = self.fc_out(outs)
        return logits.transpose(0, 1)


In [178]:

# Instantiate the model.
model = Seq2SeqTransformer(src_vocab_size, tgt_vocab_size, d_model=model_dimension, nhead=number_heads,
                           num_encoder_layers=number_encoder_layers, num_decoder_layers=number_decoder_layers, dim_feedforward=dimension_feedforward, dropout=dropout_rate)
# Move the model to the appropriate device (GPU or CPU).
model.to(device)
# Print the model architecture
print(model)

Seq2SeqTransformer(
  (src_embedding): Embedding(14, 128)
  (tgt_embedding): Embedding(13, 128)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (pos_decoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-3): 4 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=512, out_features=128, bias=True)
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inpl



# Checks

First examine sample dataset

In [183]:
#define a simple test dataset
batch_size = 4
primes = primes_list
sample_ds = TranslationDataset(num_samples=batch_size, primes=primes)
src, tgt = zip(*sample_ds)                 # list of tensors
src = pad_sequence(src, batch_first=True, padding_value=PAD_TOKEN)
tgt = pad_sequence(tgt, batch_first=True, padding_value=PAD_TOKEN)
src, tgt = src.to(device), tgt.to(device)
print(src)
print(tgt)

tensor([[11,  3, 13,  2, 13,  0, 13,  9, 12, 10, 10],
        [11,  2, 13,  8, 13,  1,  2, 13,  3, 12, 10],
        [11,  1, 13,  1,  1, 13,  2, 13,  1,  9, 12],
        [11,  8, 13,  1, 13,  7, 13,  2,  1, 12, 10]], device='mps:0')
tensor([[11,  1,  5,  8,  1,  0, 12],
        [11,  8,  8,  3,  5, 12, 10],
        [11,  3,  7,  5,  5,  5, 12],
        [11,  1,  3,  1,  3,  1, 12]], device='mps:0')


In [184]:
# Teacher-forcing split
# i.e. return the target sequence without the last token and the target sequence without the first token
# these two cases are used as part of the transformer architecture
tgt_in  = tgt[:, :-1]
tgt_lab = tgt[:, 1:]
print("Target input (tgt_in):", tgt_in)
print("Target labels (tgt_lab):", tgt_lab)

Target input (tgt_in): tensor([[11,  1,  5,  8,  1,  0],
        [11,  8,  8,  3,  5, 12],
        [11,  3,  7,  5,  5,  5],
        [11,  1,  3,  1,  3,  1]], device='mps:0')
Target labels (tgt_lab): tensor([[ 1,  5,  8,  1,  0, 12],
        [ 8,  8,  3,  5, 12, 10],
        [ 3,  7,  5,  5,  5, 12],
        [ 1,  3,  1,  3,  1, 12]], device='mps:0')


Then generate and print target mask and source + target padding masks

In [185]:
# Generate masks for the source and target sequences
# Note masks will only be True if there is a padding token in the sequence
tgt_mask = generate_square_subsequent_mask(tgt_in.size(1),device=src.device)
src_kpm = (src == PAD_TOKEN)
tgt_kpm = (tgt_in == PAD_TOKEN)
print("Source key padding mask (src_kpm):", src_kpm)
print("Target key padding mask (tgt_kpm):", tgt_kpm)
print("Target mask (tgt_mask):", tgt_mask)

Source key padding mask (src_kpm): tensor([[False, False, False, False, False, False, False, False, False,  True,
          True],
        [False, False, False, False, False, False, False, False, False, False,
          True],
        [False, False, False, False, False, False, False, False, False, False,
         False],
        [False, False, False, False, False, False, False, False, False, False,
          True]], device='mps:0')
Target key padding mask (tgt_kpm): tensor([[False, False, False, False, False, False],
        [False, False, False, False, False, False],
        [False, False, False, False, False, False],
        [False, False, False, False, False, False]], device='mps:0')
Target mask (tgt_mask): tensor([[False,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True],
        [False, False, False, False,  True,  True],
        [False, False, False, False, False,  True],
        [False, Fals

Then forward pass through the untrained model and check this behaves as expected

In [186]:
# Forward pass through the model, check no errors
logits = model(
    src, tgt_in,
    src_mask=None,
    tgt_mask=tgt_mask,
    src_key_padding_mask=src_kpm,
    tgt_key_padding_mask=tgt_kpm,
    memory_key_padding_mask=src_kpm
)

assert logits.shape == (batch_size, tgt_in.size(1), tgt_vocab_size)
print("✓ forward pass shape OK")

✓ forward pass shape OK


Compute loss and check this behaves as expected

In [187]:
# Compute the loss using CrossEntropyLoss
# check everything works as expected
criterion = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_lab.reshape(-1))

loss.backward()        # should succeed without NaNs/Infs
print("✓ backward gradient computed, loss =", float(loss))
model.zero_grad()

✓ backward gradient computed, loss = 2.519232749938965


Count parameters

In [188]:
# count trainable parameters
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total trainable parameters:", total_params)
# ≈ 420 k for 2-layer enc/dec, d_model 64, FF 128, vocab ≤ 14

Total trainable parameters: 1857037


Check that the attention weights in the first encoder layer are zero for padding tokens

In [190]:
# 1) Build a two-sample batch and inject PAD at position 0 of sample 0
ds = TranslationDataset(num_samples=2)
src_batch, _ = zip(*ds)
src_batch = pad_sequence(src_batch, batch_first=True, padding_value=PAD_TOKEN)
src_batch[0, 0] = PAD_TOKEN                         # <- forced PAD
src_batch = src_batch.to(device)                  # move to device

# 2) Grab that first encoder self-attention module
mha = model.transformer.encoder.layers[0].self_attn

# 3) Prepare the same embeddings the encoder would see
with torch.no_grad():
    x = model.pos_encoder(                          # add positions
            model.src_embedding(src_batch.T) *      # embed IDs ➜ 64-D
            math.sqrt(model.d_model)                # √d scaling
        )                                           # shape (seq, batch, 64)

    # 4) Run *just* the attention, asking for weights
    attn_out, attn_w = mha(                         # attn_w shape: (batch, heads, seq, seq)
        x, x, x,
        need_weights=True,
        average_attn_weights=False,
        key_padding_mask=(src_batch == PAD_TOKEN)   # boolean PAD mask
    )

# 5) Average over heads, then print two key columns
W = attn_w.mean(1)              # (batch, seq, seq)

# here we look at the attention weights for the first sample
# first we look at columnn 0, ie the attention weights referring to the PAD token
# then we look at column 1, ie the attention weights referring to the real key
print("Weights onto PAD key (col-0): ",  W[0, :, 0])
print("Weights onto real key (col-1):",  W[0, :, 1])


Weights onto PAD key (col-0):  tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='mps:0')
Weights onto real key (col-1): tensor([0.1389, 0.1373, 0.1389, 0.1388, 0.2778, 0.2778, 0.2776, 0.2778, 0.1389,
        0.2778], device='mps:0')


Now run a test training with a single sample to make sure we can overfit

In [191]:
model.train()                              # ensure dropout is ON (helps test realism)

# 0. For reproducibility (optional)
torch.manual_seed(0);  np.random.seed(0);  random.seed(0)

# 1. Grab *one* random (src, tgt) pair
sample_src, sample_tgt = TranslationDataset(num_samples=1)[0]
sample_src = sample_src.unsqueeze(0)       # shape (1, src_len)
sample_tgt = sample_tgt.unsqueeze(0)       # shape (1, tgt_len)
sample_src, sample_tgt = sample_src.to(device), sample_tgt.to(device)  # move to device

# 2. Optimiser & loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion  = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)

# 3. Training loop
for step in range(301):
    # --- teacher-forcing split --------------------------------------------
    tgt_in  = sample_tgt[:, :-1]           # BOS … last-1
    tgt_lab = sample_tgt[:, 1:]            # next token (labels)

    # --- boolean masks ----------------------------------------------------
    src_kpm = (sample_src == PAD_TOKEN)               # shape (1, src_len)
    tgt_kpm = (tgt_in    == PAD_TOKEN)               # shape (1, tgt_len-1)
    tgt_mask = generate_square_subsequent_mask(
                  tgt_in.size(1), device=sample_src.device)

    # --- forward ----------------------------------------------------------
    logits = model(
        sample_src, tgt_in,
        src_mask=None,
        tgt_mask=tgt_mask,
        src_key_padding_mask=src_kpm,
        tgt_key_padding_mask=tgt_kpm,
        memory_key_padding_mask=src_kpm
    )

    loss = criterion(
        logits.reshape(-1, logits.size(-1)),         # (tokens ×  vocab)
        tgt_lab.reshape(-1)                          # (tokens)
    )

    # --- backward & step --------------------------------------------------
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # --- monitor ----------------------------------------------------------
    if step % 50 == 0 or step == 299:
        print(f"step {step:3d} | loss {loss.item():.4f}")

model.eval()                               # back to eval mode afterwards
# -------------------------------------------------------------------------

step   0 | loss 2.6668
step  50 | loss 0.1900
step 100 | loss 0.1483
step 150 | loss 0.0311
step 200 | loss 0.6515
step 250 | loss 0.0505
step 299 | loss 0.0676
step 300 | loss 0.0548


Seq2SeqTransformer(
  (src_embedding): Embedding(14, 128)
  (tgt_embedding): Embedding(13, 128)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (pos_decoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-3): 4 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=512, out_features=128, bias=True)
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inpl

# Training loop

In [192]:
# ── reproducibility (optional)
torch.manual_seed(0); np.random.seed(0); random.seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ── (re-)instantiate model & move to device
model = Seq2SeqTransformer(src_vocab_size, tgt_vocab_size,
                           d_model=model_dimension, nhead=number_heads,
                           num_encoder_layers=number_encoder_layers,
                           num_decoder_layers=number_decoder_layers,
                           dim_feedforward=dimension_feedforward,
                           dropout=dropout_rate).to(device)

# ── criterion & optimiser
criterion  = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
optimizer   = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler   = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

# ---------------- 1. TRAINING LOOP ----------------
for epoch in range(1, number_epochs + 1):
    model.train()
    running_loss, n_tokens = 0.0, 0

    for src, tgt in data_loader:
        src, tgt = src.to(device), tgt.to(device)

        # Teacher forcing split
        tgt_in  = tgt[:, :-1]          # BOS … last-1
        tgt_lab = tgt[:, 1:]           # next token

        # Boolean masks
        src_kpm = (src == PAD_TOKEN)                   # (batch, src_len)
        tgt_kpm = (tgt_in == PAD_TOKEN)                # (batch, tgt_len-1)
        tgt_mask = generate_square_subsequent_mask(
                       tgt_in.size(1), device=device)  # (tgt_len-1, tgt_len-1)

        # Forward
        logits = model(src, tgt_in,
                       src_mask=None,
                       tgt_mask=tgt_mask,
                       src_key_padding_mask=src_kpm,
                       tgt_key_padding_mask=tgt_kpm,
                       memory_key_padding_mask=src_kpm)

        loss = criterion(logits.reshape(-1, logits.size(-1)),
                         tgt_lab.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # optional
        optimizer.step()

        running_loss += loss.item() * tgt_lab.numel()
        n_tokens     += tgt_lab.numel()

    scheduler.step()
    avg_loss = running_loss / n_tokens
    print(f"Epoch {epoch:3d}/{number_epochs} | avg tok-loss {avg_loss:.4f}")

    # ---- quick sanity decode every 10 epochs ----
    if epoch % 10 == 0:
        model.eval()
        with torch.no_grad():
            sample_src, _ = dataset[random.randrange(len(dataset))]
            sample_src = sample_src.unsqueeze(0).to(device)
            src_kpm = (sample_src == PAD_TOKEN)

            # start with BOS
            generated = [SOS_TOKEN]
            for _ in range(20):                          # max 20 digits
                tgt_in  = torch.tensor([generated], device=device)
                tgt_mask = generate_square_subsequent_mask(
                               tgt_in.size(1), device=device)
                logits = model(sample_src, tgt_in,
                               src_mask=None, tgt_mask=tgt_mask,
                               src_key_padding_mask=src_kpm,
                               tgt_key_padding_mask=(tgt_in == PAD_TOKEN),
                               memory_key_padding_mask=src_kpm)
                next_tok = logits[0, -1].argmax(-1).item()
                generated.append(next_tok)
                if next_tok == EOS_TOKEN:
                    break

            # remove BOS/EOS and print digits
            digits = [str(t) for t in generated[1:-1]]
            print("  sample decode →", "".join(digits))


Using device: cpu




Epoch   1/100 | avg tok-loss 2.0884
Epoch   2/100 | avg tok-loss 1.9896
Epoch   3/100 | avg tok-loss 1.9637
Epoch   4/100 | avg tok-loss 1.9499
Epoch   5/100 | avg tok-loss 1.9476
Epoch   6/100 | avg tok-loss 1.9351
Epoch   7/100 | avg tok-loss 1.9280
Epoch   8/100 | avg tok-loss 1.9214
Epoch   9/100 | avg tok-loss 1.9179
Epoch  10/100 | avg tok-loss 2.0375
  sample decode → 
Epoch  11/100 | avg tok-loss 2.3518
Epoch  12/100 | avg tok-loss 2.3514
Epoch  13/100 | avg tok-loss 2.3510
Epoch  14/100 | avg tok-loss 2.3564
Epoch  15/100 | avg tok-loss 2.0761
Epoch  16/100 | avg tok-loss 1.9927
Epoch  17/100 | avg tok-loss 1.9772
Epoch  18/100 | avg tok-loss 1.9690
Epoch  19/100 | avg tok-loss 1.9692
Epoch  20/100 | avg tok-loss 2.0897
  sample decode → 
Epoch  21/100 | avg tok-loss 2.2697
Epoch  22/100 | avg tok-loss 2.2477
Epoch  23/100 | avg tok-loss 2.2027
Epoch  24/100 | avg tok-loss 2.1843
Epoch  25/100 | avg tok-loss 2.1742
Epoch  26/100 | avg tok-loss 2.1696
Epoch  27/100 | avg tok-lo

# Testing

Define function to convert a list of moduli into an appropriate PyTorch source tensor

In [150]:
def moduli_to_tensor(mod_list, device, add_batch_dim=True):
    """
    mod_list:  [r1, r2, …, rn]  in the *same prime order* you used for training.
    Returns:   tensor(shape=(1, src_len)) on the chosen device.
    """
    tokens = [SOS_TOKEN]
    for r in mod_list:
        tokens.extend([int(d) for d in str(r)])   # decimal digits
        tokens.append(SEP_TOKEN)
    tokens[-1] = EOS_TOKEN                        # replace last SEP with EOS
    t = torch.tensor(tokens, dtype=torch.long, device=device)
    return t.unsqueeze(0) if add_batch_dim else t


Pick a list of moduli and pass through the model to predict integer

In [159]:
test_moduli = [2,0,0,2]  # Example moduli for testing
print("Test moduli:", test_moduli)

model.eval()
with torch.no_grad():
    sample_src = moduli_to_tensor(test_moduli, device)
    src_kpm = (sample_src == PAD_TOKEN)

    # start with SOS
    generated = [SOS_TOKEN]
    for _ in range(20):                          # max 20 digits
        tgt_in  = torch.tensor([generated], device=device)
        tgt_mask = generate_square_subsequent_mask(tgt_in.size(1), device=device)
        logits = model(sample_src, tgt_in,
                        src_mask=None, tgt_mask=tgt_mask,
                        src_key_padding_mask=src_kpm,
                        tgt_key_padding_mask=(tgt_in == PAD_TOKEN),
                        memory_key_padding_mask=src_kpm)
        next_tok = logits[0, -1].argmax(-1).item()
        generated.append(next_tok)
        if next_tok == EOS_TOKEN:
            break

    # remove BOS/EOS and print digits
    digits = [str(t) for t in generated[1:-1]]
    print("  sample decode →", "".join(digits))

Test moduli: [2, 0, 0, 2]
  sample decode → 1080


In [160]:
print([1080 % p for p in primes_list])

[0, 0, 2, 2]


In [163]:
print(11*13*17*23)

55913
