# Imports

In [1]:
import torch 
from torch import nn 
import pandas as pd
from torch import optim 
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms 
from sklearn.model_selection import train_test_split 
import matplotlib.pyplot as plt
import numpy as np 
import random 
import timeit
from tqdm import tqdm
import os

# Hyperparameters

In [2]:
## Added Params (for training and testing)
RANDOM_SEED = 42
BATCH_SIZE = 256
# Fine-tuning a pre-trained transformer often needs only a few epochs; this
# model starts from random weights, hence the larger budget.
EPOCHS = 40


# Image params
PATCH_SIZE = 4    # side length (pixels) of each square, non-overlapping patch
IMAGE_SIZE = 56   # side length (pixels) of the square canvas fed to the model
                  # (a grid of MNIST digits, not a single 28×28 digit)
IN_CHANNELS = 1   # MNIST is grayscale (RGB would be 3 channels)
NUM_CLASSES = 10  # digits 0-9 (used by the plain ViT classifier)
if IMAGE_SIZE % PATCH_SIZE:
    raise ValueError(
        f"IMAGE_SIZE ({IMAGE_SIZE}) must be divisible by PATCH_SIZE ({PATCH_SIZE})")
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2  # (56 // 4) ** 2 = 196


# Training Params
LEARNING_RATE = 1e-4
DROPOUT = 0.001            # near zero: dropout is effectively disabled
ADAM_WEIGHT_DECAY = 0      # ViT paper uses 0.1; 0 disables decay (PyTorch AdamW's own default is 0.01)
ADAM_BETAS = (0.9, 0.999)  # as in the ViT paper
ACTIVATION = "gelu"        # NOTE(review): not referenced by the visible code; the MLP hard-codes NewGELUActivation

# Encoder-Decoder Params
NUM_HEADS = 4
NUM_ENCODER = 3   # number of stacked encoder blocks
EMBED_DIM = 64    # token width D; must be divisible by NUM_HEADS
HIDDEN_DIM = 128  # only feeds the MLP class's default width (HIDDEN_DIM * 4); Block passes its own width



# Set our random initialisation seeds

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False


device = "cuda" if torch.cuda.is_available() else "cpu"

# Patch Embedding

In [4]:
# Creating CLS Tokens and merging with Positional Embeddings 

class PatchEmbedding(nn.Module):
    """Turn images into a token sequence: [CLS] + one embedding per patch,
    plus a learned positional embedding.

    Args:
        embedding_dim: width D of every output token.
        patch_size: side length (pixels) of the square, non-overlapping patches.
        num_patches: patches per image; must equal
            (H // patch_size) * (W // patch_size) for the images given to forward.
        dropout: dropout probability applied after adding positions.
        in_channels: number of image channels (1 for MNIST, 3 for RGB).
    """

    def __init__(self, embedding_dim, patch_size, num_patches, dropout, in_channels):
        super().__init__()

        # A Conv2d whose kernel and stride both equal patch_size applies one
        # shared linear projection to every non-overlapping patch.
        # Flatten(start_dim=2) then turns the (D, H/P, W/P) grid into
        # (D, num_patches). Output: (batch, embedding_dim, num_patches).
        self.patcher = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=embedding_dim,
                kernel_size=patch_size,
                stride=patch_size,
            ),
            nn.Flatten(start_dim=2))

        # One learnable [CLS] token shared by all images. It is a single
        # sequence element, so its shape is (1, 1, D) whatever in_channels is
        # (sizing it by in_channels produced several CLS tokens for RGB input).
        self.cls_token = nn.Parameter(torch.randn(size=(1, 1, embedding_dim)), requires_grad=True)

        # Learnable positional embedding for [CLS] + every patch (hence +1).
        self.position_embedding = nn.Parameter(
            torch.randn(size=(1, num_patches + 1, embedding_dim)), requires_grad=True)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Embed a batch of images.

        x: (B, in_channels, H, W)
        returns: (B, num_patches + 1, embedding_dim), with [CLS] at position 0.
        """
        # Broadcast the single CLS token across the batch (no copy).
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        # (B, D, num_patches) -> (B, num_patches, D): one row per patch token.
        x = self.patcher(x).permute(0, 2, 1)
        x = torch.cat([cls_token, x], dim=1)  # CLS to the left of the patches
        x = self.position_embedding + x
        return self.dropout(x)


# #always test model after you define it    
# model = PatchEmbedding(EMBED_DIM, PATCH_SIZE, NUM_PATCHES, DROPOUT, IN_CHANNELS).to(device)  
# x = torch.randn(512, 1, 56, 56).to(device)   #create dummy image of batch size 512, channels 1, and dimensions 28x28 
# print(model(x).shape) #expect (512, 50, 16) where batch size 512, 50 is number of tokens we feed transformer (correct because we have 49 patches + CLS token), 16 is size of patches (embedding dimension)

# Encoder 

In [None]:
import math
# === Helper activation ========================================================
class NewGELUActivation(nn.Module):
    """Tanh approximation of GELU (same formula as Hugging Face's NewGELUActivation):

        gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
    """

    def forward(self, x):
        # The scale is sqrt(2/pi) ≈ 0.798, not sqrt(2) ≈ 1.414. A Python float
        # also avoids allocating a CPU tensor on every call.
        return 0.5 * x * (1.0 + torch.tanh(
            math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
        ))

# === 1. One attention head ====================================================
class AttentionHead(nn.Module):
    """A single scaled dot-product attention head.

    Projects queries, keys and values from ``hidden_size`` down to ``head_dim``
    and returns ``softmax(q k^T / sqrt(head_dim)) v``.

    Args:
        hidden_size: feature width of the incoming q/k/v tensors.
        head_dim: feature width of this head's projections and output.
        dropout: dropout probability applied to the attention weights.
        bias: whether the q/k/v projections carry a bias term.
    """

    def __init__(self, hidden_size: int, head_dim: int, dropout: float,
                 bias: bool = True):
        super().__init__()
        self.head_dim = head_dim
        self.q_proj = nn.Linear(hidden_size, head_dim, bias=bias)
        self.k_proj = nn.Linear(hidden_size, head_dim, bias=bias)
        self.v_proj = nn.Linear(hidden_size, head_dim, bias=bias)
        self.drop = nn.Dropout(dropout)

    def forward(self, q_in, k_in, v_in, mask=None):
        """Attend from ``q_in`` (B, Seq_q, D) over ``k_in`` / ``v_in`` (B, Seq_k, D).

        mask: optional boolean tensor, True = allowed (query, key) pair. Accepted
              shapes are (Seq_q, Seq_k), (B or 1, Seq_q, Seq_k), or the same with a
              singleton head axis, (B or 1, 1, Seq_q, Seq_k).
        returns: (B, Seq_q, head_dim)
        """
        q = self.q_proj(q_in)  # (B, Seq_q, d_h)
        k = self.k_proj(k_in)  # (B, Seq_k, d_h)
        v = self.v_proj(v_in)  # (B, Seq_k, d_h)

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, Seq_q, Seq_k)

        if mask is not None:
            # Scores here are per-head and 3-D, so drop a singleton head axis.
            # Otherwise the out-of-place masked_fill broadcasts the scores up to
            # (1, B, Seq_q, Seq_k), and that phantom axis leaks into every output.
            if mask.dim() == scores.dim() + 1:
                mask = mask.squeeze(-3)
            scores = scores.masked_fill(~mask, float("-inf"))

        attn = self.drop(scores.softmax(dim=-1))  # (B, Seq_q, Seq_k)
        return attn @ v                           # (B, Seq_q, d_h)

# === 2. Multi-head self-attention =============================================
class MultiHeadAttention(nn.Module):
    """Multi-head attention assembled from independent ``AttentionHead`` modules.

    Every head projects to ``hidden_size // num_heads`` features. Their outputs
    are concatenated back to ``hidden_size``, linearly projected, then dropped out.
    """

    def __init__(self, hidden_size: int = EMBED_DIM,
                 num_heads: int = NUM_HEADS,
                 dropout: float = DROPOUT,
                 qkv_bias: bool = True):
        super().__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        per_head = hidden_size // num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(hidden_size, per_head, dropout, qkv_bias) for _ in range(num_heads)]
        )
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.drop = nn.Dropout(dropout)

    def forward(self, q, k=None, v=None, mask=None):
        """Self-attention when only ``q`` is given; cross-attention otherwise.

        q    : (B, Seq_q, D)
        k    : (B, Seq_k, D), falls back to ``q`` when None
        v    : (B, Seq_k, D), falls back to ``k`` when None
        mask : optional boolean mask (True = allowed), handed unchanged to
               every head and applied to its (B, Seq_q, Seq_k) scores
        returns: (B, Seq_q, D)
        """
        if k is None:
            k = q
        if v is None:
            v = k
        per_head_outputs = [attend(q, k, v, mask) for attend in self.heads]
        merged = torch.cat(per_head_outputs, dim=-1)  # (B, Seq_q, D)
        projected = self.out_proj(merged)
        return self.drop(projected)

# === 3. Position-wise feed-forward (MLP) ======================================
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> NewGELUActivation -> Linear -> Dropout.

    Maps (B, S, hidden_size) to (B, S, hidden_size) through an
    ``intermediate_size``-wide hidden layer.
    """

    def __init__(self, hidden_size: int = EMBED_DIM,
                 intermediate_size: int = HIDDEN_DIM * 4,
                 dropout: float = DROPOUT):
        super().__init__()
        expand = nn.Linear(hidden_size, intermediate_size)
        contract = nn.Linear(intermediate_size, hidden_size)
        self.net = nn.Sequential(expand, NewGELUActivation(), contract, nn.Dropout(dropout))

    def forward(self, x):
        """Apply the feed-forward stack token-wise."""
        out = self.net(x)
        return out

# === 4. Transformer block =====================================================
class Block(nn.Module):
    """Pre-norm transformer encoder layer.

    x <- x + MultiHeadAttention(LayerNorm(x))
    x <- x + MLP(LayerNorm(x))
    Shapes are preserved: (B, S, hidden_size) in and out.
    """

    def __init__(self, hidden_size: int = EMBED_DIM,
                 num_heads: int = NUM_HEADS,
                 mlp_ratio: int = 4,
                 dropout: float = DROPOUT):
        super().__init__()
        self.ln1 = nn.LayerNorm(hidden_size)
        self.attn = MultiHeadAttention(hidden_size, num_heads, dropout)
        self.ln2 = nn.LayerNorm(hidden_size)
        self.mlp = MLP(hidden_size, hidden_size * mlp_ratio, dropout)

    def forward(self, x):
        attended = self.attn(self.ln1(x))  # self-attention on the normalised input
        x = x + attended                   # residual
        return x + self.mlp(self.ln2(x))   # feed-forward + residual

# === 5. Encoder = N stacked blocks ============================================
class Encoder(nn.Module):
    """A stack of ``depth`` encoder ``Block``s followed by a final LayerNorm.

    (B, S, hidden_size) in, (B, S, hidden_size) out.
    """

    def __init__(self, depth: int = NUM_ENCODER,
                 hidden_size: int = EMBED_DIM,
                 num_heads: int = NUM_HEADS,
                 dropout: float = DROPOUT):
        super().__init__()
        layers = [Block(hidden_size, num_heads, dropout=dropout) for _ in range(depth)]
        self.blocks = nn.ModuleList(layers)
        self.ln_final = nn.LayerNorm(hidden_size)

    def forward(self, x):
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.ln_final(hidden)

# === 6. ViT classifier head (uses Encoder) ====================================
class ViT(nn.Module):
    """Vision Transformer classifier.

    PatchEmbedding -> Encoder -> linear head applied to the [CLS] token
    (sequence position 0). Input (B, C, H, W); output (B, num_classes) logits.
    """

    def __init__(self,
                 num_patches: int = NUM_PATCHES,
                 num_classes: int = NUM_CLASSES,
                 patch_size: int = PATCH_SIZE,
                 embed_dim: int = EMBED_DIM,
                 depth: int = NUM_ENCODER,
                 num_heads: int = NUM_HEADS,
                 dropout: float = DROPOUT,
                 in_channels: int = IN_CHANNELS):
        super().__init__()
        self.embed = PatchEmbedding(embed_dim, patch_size, num_patches, dropout, in_channels)
        self.encoder = Encoder(depth, embed_dim, num_heads, dropout)
        self.cls_head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        tokens = self.encoder(self.embed(x))  # (B, S, D)
        cls_repr = tokens[:, 0]               # (B, D)
        return self.cls_head(cls_repr)        # (B, num_classes)


# Smoke test: a dummy batch must map to one logit vector per image.
model = ViT(NUM_PATCHES, NUM_CLASSES, PATCH_SIZE, EMBED_DIM, NUM_ENCODER, NUM_HEADS, DROPOUT, IN_CHANNELS).to(device)
# The input must live on the same device as the model (a CPU tensor fails on CUDA).
x = torch.randn(512, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE, device=device)
with torch.no_grad():
    print(model(x).shape)  # expect torch.Size([512, 10]): 512 images x 10 class logits

# Decoder

In [6]:
# ─────────────────────────────────────────────────────────────────────────────
#  Token vocabulary for the synthetic MNIST-grid dataset (encoder-decoder project)
#  Digits keep their own value as id; three special tokens follow them.
#  Canvas: IMAGE_SIZE px square, PatchSize: PATCH_SIZE → NUM_PATCHES patch tokens
# ─────────────────────────────────────────────────────────────────────────────
VOCAB = {
    **{str(digit): digit for digit in range(10)},
    '<start>': 10,
    '<finish>': 11,
    '<pad>': 12,
}
PAD_IDX      = VOCAB['<pad>']
SEQ_MAX_LEN  = 18                    #  <start> + up to 16 digits (4×4 grid) + <finish>
VOCAB_INV = {idx: token for token, idx in VOCAB.items()}   # id -> token string

In [7]:
#causal mask 
def causal_mask(seq_len: int, device=None):
    """Build a causal (look-back only) attention mask.

    Returns a bool tensor of shape (1, 1, seq_len, seq_len) in which entry
    [..., i, j] is True when key position j may be attended from query
    position i, i.e. when j <= i. The two leading singleton axes are there
    for broadcasting.
    """
    lower_triangle = torch.ones(seq_len, seq_len, device=device).tril().bool()
    return lower_triangle.unsqueeze(0).unsqueeze(0)

class DecoderBlock(nn.Module):
    """Pre-norm transformer decoder layer.

    1. masked self-attention over the decoder tokens
    2. cross-attention from the decoder tokens to the encoder output
    3. position-wise MLP
    Each sub-layer reads a LayerNorm-ed input and is added back residually.
    """

    def __init__(self,
                 hidden_size: int = EMBED_DIM,
                 num_heads: int = NUM_HEADS,
                 dropout: float = DROPOUT):
        super().__init__()
        self.ln1 = nn.LayerNorm(hidden_size)
        self.self_attn = MultiHeadAttention(hidden_size, num_heads, dropout)

        self.ln2 = nn.LayerNorm(hidden_size)
        self.cross_attn = MultiHeadAttention(hidden_size, num_heads, dropout)

        self.ln3 = nn.LayerNorm(hidden_size)
        self.mlp = MLP(hidden_size, hidden_size * 4, dropout)

    def forward(self, x, enc_out, mask):
        """
        x       : (B, T, D) decoder tokens so far
        enc_out : (B, S, D) encoder output used as keys and values for cross-attention
        mask    : causal mask for the self-attention (cross-attention is unmasked)
        returns : (B, T, D)
        """
        normed = self.ln1(x)
        x = x + self.self_attn(normed, normed, normed, mask)
        x = x + self.cross_attn(self.ln2(x), enc_out, enc_out)
        return x + self.mlp(self.ln3(x))


class DigitDecoder(nn.Module):
    """Autoregressive token decoder conditioned on an encoder output.

    A learned start vector is prepended to the embedded input tokens, learned
    positional embeddings are added, and the sequence runs through ``depth``
    causally masked ``DecoderBlock``s, a LayerNorm and a vocabulary head.
    """

    def __init__(self,
                 hidden_size: int = EMBED_DIM,
                 num_heads: int = NUM_HEADS,
                 dropout: float = DROPOUT,
                 depth: int = 3,  # number of decoder blocks stacked sequentially
                 vocab_size: int = len(VOCAB)
                 ):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, hidden_size)
        self.pos_embed = nn.Parameter(torch.randn(1, SEQ_MAX_LEN, hidden_size))
        self.start_token = nn.Parameter(torch.randn(1, 1, hidden_size))

        self.blocks = nn.ModuleList([
            DecoderBlock(hidden_size, num_heads, dropout)
            for _ in range(depth)
        ])
        self.ln_final = nn.LayerNorm(hidden_size)
        self.head = nn.Linear(hidden_size, vocab_size)

    def forward(self, targets, enc_out):
        """
        targets : (B, N) integer token ids fed to the decoder, NOT including
                  <start> (the learned ``start_token`` plays that role).
                  N may be anything from 0 up to SEQ_MAX_LEN - 1.
        enc_out : (B, S, D) encoder output
        returns
        logits  : (B, N + 1, vocab); position i has seen the start vector
                  and targets[:, :i] only (causal mask)

        Raises ValueError if N + 1 exceeds the number of positional embeddings.
        """
        B, n_tokens = targets.shape
        T = n_tokens + 1  # sequence length including the start vector
        if T > self.pos_embed.size(1):
            raise ValueError(
                f"decoder input of {n_tokens} tokens exceeds the "
                f"{self.pos_embed.size(1) - 1} supported by pos_embed")

        start_vec = self.start_token.expand(B, -1, -1)  # (B, 1, D)
        tgt_vecs = self.token_embed(targets)            # (B, N, D)
        x = torch.cat([start_vec, tgt_vecs], dim=1)     # (B, T, D)

        # Slice the positional table to the actual length, so shorter prefixes
        # (e.g. step-by-step decoding) broadcast correctly.
        x = x + self.pos_embed[:, :T]

        mask = causal_mask(T, x.device)  # (1, 1, T, T)

        for blk in self.blocks:  # apply the stacked blocks in order
            x = blk(x, enc_out, mask)

        return self.head(self.ln_final(x))

    
    

# Full Encoder-Decoder Model

In [8]:
class GridTranscriber(nn.Module):
    """Image-to-sequence model.

    A ViT-style encoder (PatchEmbedding + Encoder) reads the canvas, and a
    DigitDecoder produces per-position token logits from the encoder output.
    """

    def __init__(self):
        super().__init__()
        # Same building blocks as the ViT classifier.
        self.embed = PatchEmbedding(EMBED_DIM, PATCH_SIZE, NUM_PATCHES, DROPOUT, IN_CHANNELS)
        self.encoder = Encoder(NUM_ENCODER, EMBED_DIM, NUM_HEADS, DROPOUT)
        self.decoder = DigitDecoder(EMBED_DIM, NUM_HEADS, depth=3)

    def forward(self, images, targets):
        memory = self.encoder(self.embed(images))  # (B, S, D)
        return self.decoder(targets, memory)       # (B, T, vocab)

In [None]:
B = 2
dummy_img = torch.randn(B, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)  # batch of fake canvases
# Full-length decoder input: the decoder prepends its own start vector, and
# its positional table covers exactly SEQ_MAX_LEN positions.
dummy_tgt = torch.randint(0, 10, (B, SEQ_MAX_LEN - 1))            # random digits 0-9
model = GridTranscriber()
out = model(dummy_img, dummy_tgt)
print("logits shape :", out.shape)  # expect (2, SEQ_MAX_LEN, len(VOCAB)) = (2, 18, 13)


# Dataset

In [10]:
#Download MNIST dataset 

transform = transforms.ToTensor()

# Train split first, then test split, both cached under ./data.
train_dataset, test_dataset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)



In [12]:


GRID_SIZE    = 4                    # 4×4 cells → up to 16 digits per canvas
NUM_CELLS    = GRID_SIZE * GRID_SIZE

# The canvas must match the model input (IMAGE_SIZE px, i.e. NUM_PATCHES
# tokens); a 4×4 grid of native 28-px digits would be 112 px and break
# PatchEmbedding. So the cell size is derived from IMAGE_SIZE, and digits are
# resized to fit. To keep native 28-px digits instead, set
# IMAGE_SIZE = GRID_SIZE * 28 in the hyperparameters.
if IMAGE_SIZE % GRID_SIZE:
    raise ValueError(f"IMAGE_SIZE ({IMAGE_SIZE}) must be divisible by GRID_SIZE ({GRID_SIZE})")
CELL_PIX     = IMAGE_SIZE // GRID_SIZE
CANVAS_PIX   = GRID_SIZE * CELL_PIX

grid_transforms = transforms.Compose([
    transforms.ToTensor(),                  # uint8 (H,W) → float (1,H,W) in 0-1
    transforms.Normalize([0.5], [0.5])      # 0-1 → −1..1
])


def _resize_digits(images: np.ndarray, size: int) -> np.ndarray:
    """Resize square (N, S, S) uint8 digit images to (N, size, size).

    Uses block averaging when S is a multiple of ``size``, pixel repetition
    when ``size`` is a multiple of S, and returns ``images`` unchanged when
    S == size.

    Raises:
        ValueError: if neither side length divides the other.
    """
    src = images.shape[-1]
    if src == size:
        return images
    if src % size == 0:
        f = src // size
        blocks = images.reshape(images.shape[0], size, f, size, f)
        return blocks.mean(axis=(2, 4)).round().astype(np.uint8)
    if size % src == 0:
        k = size // src
        return images.repeat(k, axis=1).repeat(k, axis=2)
    raise ValueError(f"cannot resize {src}-px digits to {size}-px cells: "
                     "one side length must divide the other")


class GridMNIST(Dataset):
    """On-the-fly dataset of GRID_SIZE×GRID_SIZE canvases tiled with MNIST digits.

    Each item fills a non-empty subset of cells with random digits and returns
    the canvas together with the target token sequence
    ``<start> odd digits ascending, even digits descending <finish> <pad>…``.

    Args:
        base_images: (N, S, S) uint8 MNIST images (S = 28 for MNIST).
        base_labels: (N,) digit labels aligned with ``base_images``.
        epoch_size: number of items reported by ``__len__``.
        rng: seed (or anything accepted by ``np.random.default_rng``).
    """

    def __init__(self, base_images, base_labels,
                 epoch_size=60_000, rng=None):
        if NUM_CELLS + 2 > SEQ_MAX_LEN:
            raise ValueError(f"SEQ_MAX_LEN ({SEQ_MAX_LEN}) cannot hold "
                             f"{NUM_CELLS} digits plus <start>/<finish>")
        self.base_images = _resize_digits(base_images, CELL_PIX)  # (N, CELL_PIX, CELL_PIX) uint8
        self.base_labels = base_labels
        self.epoch_size  = epoch_size
        # NOTE(review): each DataLoader worker receives its own copy of this
        # generator in the same state, so workers draw correlated digit
        # sequences; a worker_init_fn that reseeds per worker would fix that.
        self.rng = np.random.default_rng(rng)

        # indices per digit for fast balanced sampling
        self.per_digit = {d: np.where(base_labels == d)[0] for d in range(10)}

        # Every non-empty cell-occupancy pattern as a bitmask in 1 .. 2**NUM_CELLS - 1.
        # Shuffle BEFORE truncating: truncating the sorted list first kept only
        # the smallest masks, so e.g. a 10k-item set never used the last cells.
        patterns = np.arange(1, 1 << NUM_CELLS, dtype=np.int64)
        reps = math.ceil(epoch_size / len(patterns))
        self.pattern_pool = self.rng.permutation(np.tile(patterns, reps))[:epoch_size]

    def __len__(self):
        return self.epoch_size

    @staticmethod
    def _cells_from_mask(mask: int):
        """Row-major indices of the cells whose bit is set in ``mask``."""
        return [i for i in range(NUM_CELLS) if (mask >> i) & 1]

    def __getitem__(self, idx):
        mask     = int(self.pattern_pool[idx])
        cell_ids = self._cells_from_mask(mask)

        canvas = np.zeros((CANVAS_PIX, CANVAS_PIX), dtype=np.uint8)
        odd_list, even_list = [], []

        for cell in cell_ids:
            d = int(self.rng.integers(0, 10))
            img_idx = self.rng.choice(self.per_digit[d])
            digit_img = self.base_images[img_idx]

            row, col = divmod(cell, GRID_SIZE)
            top, left = row * CELL_PIX, col * CELL_PIX
            canvas[top:top + CELL_PIX, left:left + CELL_PIX] = digit_img

            # bucket into odd/even lists
            (odd_list if d % 2 else even_list).append(d)

        # Target: <start>, odd digits ascending, even digits descending, <finish>
        odd_list.sort()
        even_list.sort(reverse=True)
        seq = ([VOCAB['<start>']]
               + [VOCAB[str(d)] for d in odd_list]
               + [VOCAB[str(d)] for d in even_list]
               + [VOCAB['<finish>']])

        # pad to fixed length
        length = len(seq)
        seq += [PAD_IDX] * (SEQ_MAX_LEN - length)

        return {
            'image'  : grid_transforms(canvas),              # (1, CANVAS_PIX, CANVAS_PIX) float
            'target' : torch.tensor(seq, dtype=torch.long),  # (SEQ_MAX_LEN,)
            'length' : length
        }

# ─── build train / val / test loaders ───────────────────────────────────────
train_grid = GridMNIST(train_dataset.data.numpy(),
                       train_dataset.targets.numpy(),
                       epoch_size=60_000, rng=RANDOM_SEED)

# NOTE: val and test both compose digits from the MNIST *test* split (only the
# seeds differ), so the two sets share underlying digit images.
val_grid   = GridMNIST(test_dataset.data.numpy(),
                       test_dataset.targets.numpy(),
                       epoch_size=10_000, rng=RANDOM_SEED+1)

test_grid  = GridMNIST(test_dataset.data.numpy(),
                       test_dataset.targets.numpy(),
                       epoch_size=10_000, rng=RANDOM_SEED+2)

# ——————————————————————————————————————————————————————————————
# Decide how many CPU cores to devote to data loading.
# Four to eight usually keeps the GPU fed without wasting resources.
# os.cpu_count() may return None, so fall back to one worker.
# ——————————————————————————————————————————————————————————————
NUM_WORKERS = min(8, os.cpu_count() or 1)

loader_kwargs = dict(
    batch_size  = BATCH_SIZE,
    num_workers = NUM_WORKERS,
    pin_memory  = torch.cuda.is_available(),  # pinned host memory only helps CUDA copies
)
# DataLoader rejects these two options when num_workers == 0 (e.g. while debugging).
if NUM_WORKERS > 0:
    loader_kwargs.update(
        persistent_workers = True,  # keep workers alive across epochs
        prefetch_factor    = 4,     # each worker holds 4 batches ready
    )

train_dataloader = DataLoader(train_grid, shuffle=True, **loader_kwargs)
val_dataloader   = DataLoader(val_grid, shuffle=False, **loader_kwargs)
test_dataloader  = DataLoader(test_grid, shuffle=False, **loader_kwargs)


In [None]:
# Little smoke test: render one freshly generated canvas and its target sequence.
sample = GridMNIST(train_dataset.data.numpy(),
                   train_dataset.targets.numpy(),
                   epoch_size=1, rng=0)[0]
plt.imshow(sample['image'].squeeze(), cmap='gray'); plt.axis('off')
# Convert to plain ints: VOCAB_INV is keyed by int, and iterating a tensor
# yields 0-d tensors that never match those keys.
target_ids = sample['target'].tolist()
print("target ids :", target_ids)
print("target txt :", ' '.join(VOCAB_INV[i] for i in target_ids
                               if i != PAD_IDX))


In [None]:
import multiprocessing as mp
# "fork" lets DataLoader workers inherit notebook-defined classes such as
# GridMNIST. It does not exist on Windows, so only request it where available.
if "fork" in mp.get_all_start_methods():
    mp.set_start_method("fork", force=True)

# quick sanity-check: show a random training sample
sample = next(iter(train_dataloader))
img_grid = sample['image'][0].squeeze().cpu()   # (H, W)
target   = sample['target'][0].tolist()
seq_txt  = ' '.join([str(VOCAB_INV[t]) for t in target if t != PAD_IDX])

plt.imshow(img_grid, cmap='gray')
plt.title(f"target sequence: {seq_txt}")
plt.axis('off')
plt.show()


In [16]:
# Hand cached-but-unused GPU memory back to the driver (nothing to do without CUDA).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Training Loop for Encoder-Decoder Model

In [None]:
# ─────────────────────────────────────────────────────────────────────────────
#  1. helper: exact-match metric
# ─────────────────────────────────────────────────────────────────────────────
def sequence_exact_match(logits, targets, pad_idx=None):
    """Fraction of sequences whose every non-pad token is predicted exactly.

    logits  : (B, T, vocab) raw scores, aligned position-by-position with targets
    targets : (B, T) gold token ids
    pad_idx : gold id to ignore; defaults to the notebook's PAD_IDX
    Returns a Python float in [0, 1].
    """
    if pad_idx is None:
        pad_idx = PAD_IDX
    preds = logits.argmax(dim=-1)                      # (B, T) top-scoring id
    # a position counts as correct if ids are equal OR the gold token is padding
    match = (preds == targets) | (targets == pad_idx)  # bool (B, T)
    # reduce over the sequence (last) axis so samples are never mixed together
    seq_match = match.all(dim=-1)                      # bool (B,)
    return seq_match.float().mean().item()

# ─────────────────────────────────────────────────────────────────────────────
#  2. model, loss and optimiser
# ─────────────────────────────────────────────────────────────────────────────
model = GridTranscriber().to(device)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)  # <pad> targets contribute no loss
optimizer = optim.AdamW(model.parameters(),
                        lr=LEARNING_RATE,
                        weight_decay=ADAM_WEIGHT_DECAY,
                        betas=ADAM_BETAS)


def _teacher_forced_batch(batch):
    """Run one teacher-forced forward pass on a batch from the grid loaders.

    Returns (loss, next_logits, gold), where next_logits[..., i, :] scores
    gold[:, i], the token that FOLLOWS the i-th decoder input.
    """
    imgs    = batch['image'].to(device)   # (B, C, H, W)
    targets = batch['target'].to(device)  # (B, SEQ_MAX_LEN): <start> d… <finish> <pad>…

    # The decoder prepends its own learned start vector, so it is fed the
    # tokens after <start>. Output position i has seen start, t1 … ti (the
    # causal mask includes the diagonal), so it must predict t(i+1). Scoring it
    # against t(i) instead lets every position copy its own input, which
    # teaches nothing usable for autoregressive decoding.
    logits      = model(imgs, targets[:, 1:])  # (B, SEQ_MAX_LEN, V)
    next_logits = logits[..., :-1, :]          # the last position has no next token
    gold        = targets[:, 1:]               # (B, SEQ_MAX_LEN - 1)

    loss = criterion(next_logits.reshape(-1, next_logits.size(-1)),
                     gold.reshape(-1))
    return loss, next_logits, gold


# ─────────────────────────────────────────────────────────────────────────────
#  3. training / validation loop
# ─────────────────────────────────────────────────────────────────────────────
start_time = timeit.default_timer()

for epoch in range(EPOCHS):
    # ── TRAIN ────────────────────────────────────────────────────────────
    model.train()
    train_loss = 0.0
    train_em   = 0.0

    for batch in tqdm(train_dataloader, desc=f"train {epoch+1}/{EPOCHS}"):
        loss, next_logits, gold = _teacher_forced_batch(batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_em   += sequence_exact_match(next_logits.detach(), gold)

    train_loss /= len(train_dataloader)
    train_em   /= len(train_dataloader)

    # ── VALIDATE ─────────────────────────────────────────────────────────
    model.eval()
    val_loss = 0.0
    val_em   = 0.0

    with torch.no_grad():
        for batch in tqdm(val_dataloader, desc="valid"):
            loss, next_logits, gold = _teacher_forced_batch(batch)
            val_loss += loss.item()
            val_em   += sequence_exact_match(next_logits, gold)

    val_loss /= len(val_dataloader)
    val_em   /= len(val_dataloader)

    print("─" * 60)
    print(f"Epoch {epoch+1:02d}/{EPOCHS}")
    print(f"  train loss {train_loss:.4f}   EM {train_em:.4f}")
    print(f"  val   loss {val_loss:.4f}   EM {val_em:.4f}")

stop_time = timeit.default_timer()
print(f"Total training time: {stop_time - start_time:.1f} seconds")


In [11]:
# Teacher-forced sanity check on one validation sample (not autoregressive):
# shows the arg-max token at every decoder output position.
sample = next(iter(val_dataloader))
img_grid = sample['image'][0].squeeze().cpu()
gold_seq = sample['target'][0].tolist()
gold_txt = ' '.join([VOCAB_INV[t] for t in gold_seq if t != PAD_IDX])

model.eval()
with torch.no_grad():
    logits = model(sample['image'].to(device),
                   sample['target'][:, 1:].to(device))
pred_seq = logits.argmax(dim=-1)[0].cpu().tolist()

# Tokens after the first <finish> are noise; drop them and any <pad>.
pred_tokens = []
for t in pred_seq:
    if t == PAD_IDX:
        continue
    pred_tokens.append(VOCAB_INV[t])
    if t == VOCAB['<finish>']:
        break
pred_txt = ' '.join(pred_tokens)

plt.imshow(img_grid, cmap='gray'); plt.axis('off')
plt.title(f"gold: {gold_txt}     pred: {pred_txt}")
plt.show()

In [None]:
# ─────────────────────────────────────────────────────────────────────────────
#  AUTOREGRESSIVE INFERENCE ON THE TEST SET
#  • Greedy decoding (arg-max at each step)
#  • Stops once every sequence has produced <finish>, or after
#    SEQ_MAX_LEN - 1 generated tokens
# ─────────────────────────────────────────────────────────────────────────────
@torch.no_grad()
def greedy_decode(model, images):
    """Greedily decode token sequences for a batch of canvases.

    Assumes the model was trained with next-token targets: the decoder is fed
    the tokens AFTER <start> (it prepends its own learned start vector), and
    output position i predicts token i + 1.

    images : (B, C, H, W) float tensor on the model's device
    returns
    pred_seq : (B, SEQ_MAX_LEN) long tensor laid out like the dataset targets:
               <start> d … <finish> <pad> …
    """
    B = images.size(0)
    device = images.device
    start_id, finish_id = VOCAB['<start>'], VOCAB['<finish>']
    max_new = SEQ_MAX_LEN - 1

    # encode once
    enc_seq = model.encoder(model.embed(images))  # (B, S, D)

    # Fixed-length, PAD-filled buffer of generated tokens. The decoder's mask
    # is causal, so placeholders at index >= step (decoder positions > step)
    # cannot influence the logits read at position `step`.
    tokens = torch.full((B, max_new), PAD_IDX, dtype=torch.long, device=device)
    finished = torch.zeros(B, dtype=torch.bool, device=device)

    for step in range(max_new):
        logits = model.decoder(tokens, enc_seq)      # (B, SEQ_MAX_LEN, V)
        next_token = logits[:, step].argmax(dim=-1)  # prediction for token step + 1
        # sequences that already finished keep emitting <pad>
        next_token = next_token.masked_fill(finished, PAD_IDX)
        tokens[:, step] = next_token

        finished |= next_token == finish_id
        if finished.all():
            break

    start_col = torch.full((B, 1), start_id, dtype=torch.long, device=device)
    return torch.cat([start_col, tokens], dim=1)  # (B, SEQ_MAX_LEN)

# ─── evaluate on the whole test set ──────────────────────────────────────────
model.eval()
exact_matches = 0
total_samples = 0

with torch.no_grad():  # inference only: do not build autograd graphs
    for batch in tqdm(test_dataloader, desc="test"):
        imgs   = batch['image'].to(device)
        gold   = batch['target'].to(device)           # (B, SEQ_MAX_LEN)

        pred = greedy_decode(model, imgs)             # (B, SEQ_MAX_LEN)

        # a sequence counts only if every non-pad gold position matches
        match = ((pred == gold) | (gold == PAD_IDX)).all(dim=1)
        exact_matches += match.sum().item()
        total_samples += gold.size(0)

test_em = exact_matches / max(total_samples, 1)
print(f"\nExact-match accuracy on test set : {test_em:.4f}")

# ─── visualise a few samples ────────────────────────────────────────────────
VOCAB_INV = {v: k for k, v in VOCAB.items()}

def seq_to_text(seq):
    """Render token ids as text, skipping <pad> and stopping after <finish>.

    seq : iterable of int ids, or a 1-D tensor. Tensors are converted with
          .tolist(), because iterating a tensor yields 0-d tensors that are
          not VOCAB_INV keys.
    """
    if torch.is_tensor(seq):
        seq = seq.tolist()
    words = []
    for idx in seq:
        if idx == PAD_IDX:
            continue
        words.append(VOCAB_INV[idx])
        if words[-1] == '<finish>':
            break
    return ' '.join(words)

n_rows, n_cols = 2, 4
plt.figure(figsize=(n_cols*2.5, n_rows*2.5))

sample_iter = iter(test_dataloader)
batch = next(sample_iter)
imgs = batch['image'].to(device)
gold = batch['target']
with torch.no_grad():
    pred = greedy_decode(model, imgs).cpu()

# never index past a (possibly short) batch
for i in range(min(n_rows * n_cols, imgs.size(0))):
    ax = plt.subplot(n_rows, n_cols, i+1)
    ax.imshow(imgs[i].cpu().squeeze(), cmap='gray')
    ax.axis('off')
    # pass plain int lists: VOCAB_INV lookups fail on 0-d tensors
    ax.set_title(f"G: {seq_to_text(gold[i].tolist())}\nP: {seq_to_text(pred[i].tolist())}",
                 fontsize=8)

plt.tight_layout()
plt.show()
