In [None]:
# Cell 1: COMPLETE PREPROCESSING - Dataset Loading + Vocab Creation + Label Encoding

import json
import os
from collections import Counter

import cv2
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Expose only physical GPU 1 to this process; this has to happen before CUDA
# is initialised to have any effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# ============================================================
# CONFIGURATION
# ============================================================
DATASET_ROOT = '/home/ie643_errorcode500/errorcode500-working/Mathwritting-1000'
MAX_SEQ_LEN = 128   # every encoded label is padded/truncated to this many ids
BATCH_SIZE = 32
NUM_WORKERS = 0     # 0 = load batches in the main process

print(f"\n{'=' * 60}\nPREPROCESSING PIPELINE\n{'=' * 60}\n")

# ============================================================
# STEP 1: LaTeX Tokenization Function
# ============================================================
def custom_latex_tokenize(latex_str):
    """
    Split a LaTeX string into tokens.

    The input is stripped of outer whitespace, then scanned left to right:
      * a backslash followed by letters is one command token (``\\frac``)
      * a backslash followed by one non-letter is a two-character token
        (``\\{``, ``\\\\``, ``\\%``)
      * a backslash as the very last character is kept on its own
      * a digit starts a run of digits and dots (``3.14``)
      * whitespace is discarded
      * anything else becomes a single-character token

    Examples:
        "\\frac{a}{b}" -> ['\\frac', '{', 'a', '}', '{', 'b', '}']
        "2x + 3" -> ['2', 'x', '+', '3']
    """
    text = latex_str.strip()
    n = len(text)

    def run_end(start, keep):
        # First index at or after `start` whose character fails `keep`.
        end = start
        while end < n and keep(text[end]):
            end += 1
        return end

    out = []
    pos = 0
    while pos < n:
        ch = text[pos]
        if ch.isspace():
            pos += 1
            continue

        if ch == '\\':
            if pos + 1 >= n:
                end = pos + 1                       # lone trailing backslash
            elif text[pos + 1].isalpha():
                end = run_end(pos + 1, str.isalpha)  # \command
            else:
                end = pos + 2                       # \{ \} \\ \% ...
        elif ch.isdigit():
            end = run_end(pos, lambda c: c.isdigit() or c == '.')
        else:
            end = pos + 1

        out.append(text[pos:end])
        pos = end

    return out

print("✓ Step 1: LaTeX tokenization function loaded")

# ============================================================
# STEP 2: Build Vocabulary from ALL splits
# ============================================================
print("\n" + "=" * 60, "STEP 2: BUILDING VOCABULARY", "=" * 60, sep="\n")

# Gather the raw LaTeX labels of every split that exists on disk.
csv_files = ['train_database.csv', 'val_database.csv', 'test_database.csv']
all_labels = []
for csv_file in csv_files:
    csv_path = os.path.join(DATASET_ROOT, csv_file)
    if not os.path.exists(csv_path):
        print(f"  ⚠️  Warning: {csv_path} not found")
        continue
    df = pd.read_csv(csv_path)
    # astype(str) turns missing labels (NaN) into the literal text 'nan'
    all_labels += df['normalized_label'].astype(str).tolist()
    print(f"  Loaded {len(df):,} samples from {csv_file}")

print(f"\nTotal labels loaded: {len(all_labels):,}")

# Count how often each token occurs across all labels.
special_tokens = ['<PAD>', '<SOS>', '<EOS>']
token_counter = Counter()

print("\nTokenizing all LaTeX sequences...")
for label in tqdm(all_labels, desc="Tokenizing"):
    token_counter.update(custom_latex_tokenize(label))

# Ids 0, 1, 2 go to <PAD>, <SOS>, <EOS>; every other token follows in
# descending frequency (Counter.most_common keeps first-seen order on ties).
vocab_tokens = [*special_tokens, *(tok for tok, _freq in token_counter.most_common())]
vocab_size = len(vocab_tokens)

# Forward and reverse lookup tables.
token2idx = {tok: pos for pos, tok in enumerate(vocab_tokens)}
idx2token = {pos: tok for tok, pos in token2idx.items()}

print("\n✓ Vocabulary created:")
print(f"  Vocab size: {vocab_size:,}")
print(f"  Most common tokens: {token_counter.most_common(10)}")
print(f"  Special tokens: {special_tokens}")
print(vocab_tokens)

# ============================================================
# STEP 3: Encoding Function
# ============================================================
def encode_label_tokens(label, max_len=MAX_SEQ_LEN):
    """
    Encode a LaTeX string as a fixed-length list of token indices.

    Layout: [<SOS>, tok_1, ..., tok_k, <EOS>, <PAD>, ...] with exactly
    ``max_len`` entries. When the label has more than ``max_len - 2`` tokens
    the *content* is truncated before <EOS> is appended, so (for
    max_len >= 2) every encoded sequence still ends with an end marker the
    decoder can learn to emit. Tokens missing from ``token2idx`` are
    reported and skipped.

    Args:
        label: LaTeX string. Non-strings (e.g. a NaN read by pandas) are
            converted with ``str()``, matching how the vocabulary was built.
        max_len: total encoded length; defaults to ``MAX_SEQ_LEN``.

    Returns:
        list[int] of length ``max_len``.
    """
    sos_idx = token2idx['<SOS>']
    eos_idx = token2idx['<EOS>']
    pad_idx = token2idx['<PAD>']

    body = []
    for tok in custom_latex_tokenize(str(label)):
        idx = token2idx.get(tok)
        if idx is None:
            print(f"⚠️  Unknown token '{tok}' - skipping")
            continue
        body.append(idx)

    # Reserve two slots for <SOS>/<EOS> so truncation can never drop <EOS>.
    body = body[:max(max_len - 2, 0)]
    token_indices = [sos_idx, *body, eos_idx]
    token_indices += [pad_idx] * (max_len - len(token_indices))
    # Final slice only matters for degenerate max_len < 2.
    return token_indices[:max_len]

# ============================================================
# STEP 4: Pre-encode ALL datasets and save to disk
# ============================================================
print(
    "\n✓ Step 3: Encoding function loaded",
    "\n" + "=" * 60,
    "STEP 4: PRE-ENCODING ALL DATASETS",
    "=" * 60,
    sep="\n",
)

def preprocess_and_save_dataset(csv_file, split_name):
    """
    Encode every label of one split and write ``<name>_encoded.csv`` beside it.

    Args:
        csv_file: CSV filename relative to ``DATASET_ROOT``; must contain a
            ``normalized_label`` column.
        split_name: display name used in progress messages.

    Returns:
        Path of the written ``*_encoded.csv`` file, or ``None`` when the input
        CSV does not exist.
    """
    csv_path = os.path.join(DATASET_ROOT, csv_file)

    if not os.path.exists(csv_path):
        print(f"⚠️  Skipping {csv_file} - not found")
        return None

    df = pd.read_csv(csv_path)
    print(f"\nProcessing {split_name} set ({len(df):,} samples)...")

    # .astype(str) mirrors the vocabulary build, so a missing label (NaN)
    # becomes 'nan' instead of crashing the tokenizer on a float.
    labels = df['normalized_label'].astype(str)
    df['encoded_label'] = [
        str(encode_label_tokens(label, MAX_SEQ_LEN))  # stored as text in the CSV
        for label in tqdm(labels, desc=f"Encoding {split_name}")
    ]

    # Swap only the extension: str.replace('.csv', ...) would also rewrite a
    # '.csv' appearing anywhere in the directory part of the path.
    root, ext = os.path.splitext(csv_path)
    output_path = f"{root}_encoded{ext}"
    df.to_csv(output_path, index=False)

    print(f"✓ Saved to: {output_path}")
    return output_path

# Pre-encode all splits (train, val, test - in that order); each path is None
# when the corresponding source CSV is missing.
train_encoded_path, val_encoded_path, test_encoded_path = (
    preprocess_and_save_dataset(f'{split}_database.csv', split.upper())
    for split in ('train', 'val', 'test')
)

# ============================================================
# STEP 5: Dataset Class with Pre-Encoded Labels
# ============================================================
print("\n" + "=" * 60, "STEP 5: CREATING DATASET WITH PRE-ENCODED LABELS", "=" * 60, sep="\n")

class MathEquationEncodedDataset(Dataset):
    """
    Image/label dataset backed by a CSV written by ``preprocess_and_save_dataset``.

    Required CSV columns:
        filename       image path relative to ``<dataset_root>/<split>/``
        encoded_label  token-index list serialised with ``str(list)``

    Items are dicts:
        'image': float32 tensor [1, H, W], grayscale scaled to [0, 1]
        'label': int64 tensor of token indices
    """

    def __init__(self, csv_file, dataset_root, split='train', transform=None):
        """Read the CSV and decode every pre-encoded label once, up front.

        Raises:
            ValueError: if the CSV has no ``encoded_label`` column.
        """
        self.data_frame = pd.read_csv(csv_file)
        self.dataset_root = dataset_root
        self.split = split
        self.transform = transform

        if 'encoded_label' not in self.data_frame.columns:
            raise ValueError(f"CSV {csv_file} missing 'encoded_label' column!")

        # The column holds text such as "[1, 5, 2, 0]" (valid JSON). Parse it
        # once here instead of on every __getitem__, and with json.loads rather
        # than eval(): a tampered CSV can then only fail to parse, never run code.
        self.encoded_labels = [json.loads(s) for s in self.data_frame['encoded_label']]

        print(f"  ✓ Loaded {split} dataset: {len(self.data_frame):,} samples")

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        # Resolve <dataset_root>/<split>/<filename>, normalised to forward slashes
        filename = self.data_frame['filename'].iloc[idx]
        img_path = os.path.join(self.dataset_root, self.split, filename)
        img_path = os.path.normpath(img_path).replace('\\', '/')

        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if image is None:  # cv2.imread returns None for missing/unreadable files
            raise FileNotFoundError(f"Image not found: {img_path}")

        # [H, W] uint8 -> [1, H, W] float32 in [0, 1]
        channels = (image.astype(np.float32) / 255.0)[np.newaxis, :, :]

        sample = {
            'image': torch.from_numpy(channels),
            'label': torch.tensor(self.encoded_labels[idx], dtype=torch.long),
        }

        if self.transform:
            sample['image'] = self.transform(sample['image'])

        return sample

# ============================================================
# STEP 6: Create DataLoaders
# ============================================================
print("\n" + "="*60)
print("STEP 6: CREATING DATALOADERS")
print("="*60)

# preprocess_and_save_dataset() returns None when train_database.csv is
# missing; stop here with a clear message instead of failing inside
# pd.read_csv(None).
if train_encoded_path is None:
    raise FileNotFoundError(
        f"train_database.csv not found in {DATASET_ROOT}; cannot build the train dataset"
    )

train_dataset = MathEquationEncodedDataset(
    train_encoded_path,
    DATASET_ROOT,
    split='train'
)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # Page-locked host memory only speeds up host->GPU copies; request it
    # only when a GPU is actually present.
    pin_memory=torch.cuda.is_available(),
    persistent_workers=False
)

print(f"\n✓ Train DataLoader created:")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Batches per epoch: {len(train_loader)}")

# Smoke-test one batch (best effort: failures are reported, not raised).
# NOTE(review): the default collate_fn stacks images, so this presumably
# fails if the images in a batch do not all share one H x W - confirm.
try:
    test_batch = next(iter(train_loader))
    print(f"\n✓ Test batch loaded successfully:")
    print(f"  Image shape: {test_batch['image'].shape}")
    print(f"  Label shape: {test_batch['label'].shape}")
    print(f"  Label dtype: {test_batch['label'].dtype}")
    print(f"  Sample label (first 10 tokens): {test_batch['label'][0][:10].tolist()}")
except Exception as e:
    print(f"\n❌ Error loading test batch: {e}")

# ============================================================
# PREPROCESSING COMPLETE
# ============================================================
print("\n" + "="*60)
print("PREPROCESSING COMPLETE!")
print("="*60)
print(f"✓ Vocabulary size: {vocab_size:,}")
print(f"✓ Train samples: {len(train_dataset):,}")
print(f"✓ Max sequence length: {MAX_SEQ_LEN}")
print(f"✓ Batch size: {BATCH_SIZE}")
print(f"✓ All labels pre-encoded and cached to disk")
print("="*60 + "\n")

print("Ready for training! All encoding is done.")
print("Training loop will use pre-encoded labels directly.\n")

In [None]:
# Replace your WatcherFCN cell with this LIGHTWEIGHT version:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    """
    ``num_layers`` repetitions of: 3x3 conv (stride 1, padding 1) ->
    BatchNorm2d -> ReLU, with a Dropout2d after each ReLU when
    ``dropout_p > 0``.

    Spatial size is preserved. The first conv maps in_channels ->
    out_channels; later convs stay at out_channels.
    """

    def __init__(self, in_channels, out_channels, num_layers=2, dropout_p=0.0):  # default lowered from 4 to 2 layers
        super().__init__()
        modules = []
        channels_in = in_channels
        for _ in range(num_layers):
            modules += [
                nn.Conv2d(channels_in, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            if dropout_p > 0:
                modules.append(nn.Dropout2d(p=dropout_p))
            channels_in = out_channels
        # Sequential indices become state_dict keys (block.0.weight, ...):
        # keep this layer order or saved checkpoints will no longer load.
        self.block = nn.Sequential(*modules)

    def forward(self, x):
        return self.block(x)

class WatcherFCN(nn.Module):
    """
    Lightweight convolutional encoder ("watcher").

    Three ConvBlock stages (in_channels -> 16 -> 32 -> 64 channels, two convs
    each, Dropout2d p=0.2 in the last stage), each followed by 2x2 max
    pooling, then adaptive average pooling to a fixed 10 x 50 grid. Inputs
    large enough to survive three 2x poolings all map to [batch, 64, 10, 50],
    i.e. 500 positions of 64-dim features once flattened.
    """

    def __init__(self, in_channels=1):
        super().__init__()
        # Attribute names are state_dict prefixes - keep them stable.
        self.block1 = ConvBlock(in_channels, 16, num_layers=2)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.block2 = ConvBlock(16, 32, num_layers=2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.block3 = ConvBlock(32, 64, num_layers=2, dropout_p=0.2)
        self.pool3 = nn.MaxPool2d(2, 2)
        # Fixed output grid independent of the input resolution.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((10, 50))

    def forward(self, x):
        # [B, C, H, W] -> [B, 16, H/2, W/2] -> [B, 32, H/4, W/4] -> [B, 64, H/8, W/8]
        stages = (
            (self.block1, self.pool1),
            (self.block2, self.pool2),
            (self.block3, self.pool3),
        )
        for conv, pool in stages:
            x = pool(conv(x))
        # -> [B, 64, 10, 50]
        return self.adaptive_pool(x)


# ============================================================
# Test the new architecture
# ============================================================
print("\n" + "="*60)
print("TESTING LIGHTWEIGHT WATCHERFCN")
print("="*60)

model = WatcherFCN(in_channels=1)
dummy_input = torch.randn(2, 1, 480, 1600)
# Shape check only: run without autograd so the activations of this large
# (2 x 1 x 480 x 1600) forward pass are not kept alive for a backward pass
# that never happens.
with torch.no_grad():
    output = model(dummy_input)

print(f"\nInput shape:  {dummy_input.shape}")
print(f"Output shape: {output.shape}")

batch_size, channels, height, width = output.shape
# [B, C, H, W] -> [B, H*W, C]: the sequence the decoder attends over
encoder_outputs = output.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)

print(f"\nEncoder outputs shape: {encoder_outputs.shape}")
print(f"  Sequence length: {height * width} (reduced from 3000!)")
print(f"  Channel dim:     {channels} (reduced from 128!)")

# Calculate memory savings. The previous encoder emitted 3000 positions x 128
# float32 channels; the new figure is measured from the actual tensor so it
# stays correct if the architecture changes again.
old_memory = batch_size * 3000 * 128 * 4 / 1e6
new_memory = encoder_outputs.numel() * encoder_outputs.element_size() / 1e6
print(f"\nMemory comparison:")
print(f"  OLD: {old_memory:.2f} MB per batch")
print(f"  NEW: {new_memory:.2f} MB per batch")
print(f"  Savings: {old_memory/new_memory:.1f}× less memory!")

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"\nModel parameters: {total_params:,}")
print("="*60 + "\n")

torch.Size([2, 128, 30, 100])


In [9]:
batch_size, channels, height, width = output.shape
encoder_outputs = output.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)
# encoder_outputs: [batch, height * width, channels]; with WatcherFCN's
# [2, 64, 10, 50] output this is [2, 500, 64]
encoder_outputs.shape

torch.Size([2, 500, 64])

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CoverageAttention(nn.Module):
    """
    Additive attention with a coverage term.

    For encoder position i:
        score_i = v^T tanh(W_a h + U_a e_i + U_f c_i)
        alpha   = softmax(score) over positions
        context = sum_i alpha_i * e_i
    where h is the decoder state, e_i the encoder feature and c_i the
    coverage vector at position i.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim, coverage_dim):
        super().__init__()
        # Registration order (W_a, U_a, U_f, v) fixes both the state_dict key
        # order and which random draws initialise which layer.
        self.W_a = nn.Linear(decoder_dim, attention_dim)
        self.U_a = nn.Linear(encoder_dim, attention_dim)
        self.U_f = nn.Linear(coverage_dim, attention_dim)
        self.v = nn.Linear(attention_dim, 1)

    def forward(self, encoder_outputs, decoder_hidden, coverage):
        """
        Args:
            encoder_outputs: [batch, L, encoder_dim]
            decoder_hidden:  [batch, decoder_dim]
            coverage:        [batch, L, coverage_dim]

        Returns:
            (context [batch, encoder_dim], alpha [batch, L])
        """
        query = self.W_a(decoder_hidden)[:, None, :]           # [batch, 1, att_dim]
        energy = torch.tanh(query + self.U_a(encoder_outputs) + self.U_f(coverage))
        alpha = self.v(energy).squeeze(-1).softmax(dim=1)       # [batch, L]
        context = (encoder_outputs * alpha[..., None]).sum(dim=1)
        return context, alpha


# In your decoder cell, update the class:

class ParserGRUDecoder(nn.Module):
    """
    GRU decoder with coverage attention over flattened encoder features.

    Each step embeds the previous token, attends over the encoder outputs
    (conditioned on the previous hidden state and the running coverage),
    feeds [embedding; context] to a GRUCell, and projects
    [embedding; hidden; context] to vocabulary logits.
    """

    def __init__(self, vocab_size, encoder_dim=64, embed_dim=128, decoder_dim=128, attention_dim=128, coverage_dim=1, sos_idx=None):
        """
        Args:
            vocab_size: number of output classes / embedding rows.
            encoder_dim: feature size of each encoder position (64 for WatcherFCN).
            embed_dim, decoder_dim, attention_dim: layer widths.
            coverage_dim: width of the coverage input to attention. forward()
                accumulates coverage as one scalar per position, so only 1
                is consistent with it.
            sos_idx: start-token index fed at step 0. ``None`` (default) looks
                up the notebook-global ``token2idx['<SOS>']`` at forward time.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRUCell(embed_dim + encoder_dim, decoder_dim)
        self.attention = CoverageAttention(encoder_dim, decoder_dim, attention_dim, coverage_dim)

        # Simpler output layer
        self.out = nn.Sequential(
            nn.Linear(decoder_dim + encoder_dim + embed_dim, decoder_dim),
            nn.Tanh(),
            nn.Linear(decoder_dim, vocab_size)
        )
        self.decoder_dim = decoder_dim
        self.sos_idx = sos_idx  # plain attribute: not part of the state_dict

    def forward(self, encoder_outputs, targets, max_len):
        """
        Unroll the decoder for ``max_len`` steps.

        Args:
            encoder_outputs: [batch, L, encoder_dim]
            targets: [batch, T] token indices, or None. While t < T the input
                to step t+1 is ``targets[:, t]`` (teacher forcing); otherwise
                it is the argmax of step t's logits.
            max_len: number of decoding steps.

        Returns:
            Logits of shape [batch, max_len, vocab_size].
        """
        batch_size, L, _ = encoder_outputs.size()
        device = encoder_outputs.device
        coverage = torch.zeros(batch_size, L, 1, device=device)
        hidden = torch.zeros(batch_size, self.decoder_dim, device=device)

        sos = self.sos_idx if self.sos_idx is not None else token2idx['<SOS>']
        inputs = torch.full((batch_size,), sos, dtype=torch.long, device=device)

        outputs = []
        for t in range(max_len):
            embedded = self.embedding(inputs)
            context, alpha = self.attention(encoder_outputs, hidden, coverage)
            hidden = self.gru(torch.cat([embedded, context], dim=1), hidden)
            logits = self.out(torch.cat([embedded, hidden, context], dim=1))
            outputs.append(logits)

            if targets is not None and t < targets.size(1):
                inputs = targets[:, t]
            else:
                inputs = logits.argmax(dim=1)

            # Coverage = running sum of past attention weights.
            coverage = coverage + alpha.unsqueeze(-1)

        # No per-step `del` / torch.cuda.empty_cache(): during training autograd
        # still references these tensors, so dropping the Python names frees
        # nothing, while empty_cache() returns cached blocks to the driver on
        # every call and only slows the step down.
        return torch.stack(outputs, dim=1)

# Example usage:
# encoder_outputs: [batch, L, encoder_dim] (flatten the WatcherFCN output [batch, 64, 10, 50] to [batch, 500, 64])
# targets: [batch, max_len] (token indices)
# decoder = ParserGRUDecoder(vocab_size=vocab_size)
# outputs = decoder(encoder_outputs, targets, max_len)  # -> [batch, max_len, vocab_size] logits


In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import gc
import os

# ============================================================
# Memory Configuration
# ============================================================
# NOTE(review): the CUDA caching allocator reads PYTORCH_CUDA_ALLOC_CONF when
# it is first initialised, so this presumably has no effect if an earlier cell
# (e.g. a pin_memory DataLoader) already touched CUDA in this kernel - confirm.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
torch.cuda.empty_cache()
gc.collect()

# ============================================================
# Device Setup
# ============================================================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Total Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    # Ask the driver for truly free memory. total - memory_allocated() ignores
    # blocks held by PyTorch's cache and by other processes, so it over-reports.
    free_bytes, _total_bytes = torch.cuda.mem_get_info(0)
    print(f"Free Memory: {free_bytes / 1e9:.2f} GB")

# ============================================================
# Initialize Models
# ============================================================
# Construction order (watcher, then decoder) decides which random draws
# initialise which weights - keep it.
watcher = WatcherFCN(in_channels=1).to(device)
decoder = ParserGRUDecoder(
    vocab_size=vocab_size,
    encoder_dim=64,      # must equal WatcherFCN's 64 output channels
    embed_dim=128,
    decoder_dim=128,
    attention_dim=128,
).to(device)

# Loss: <PAD> targets are ignored; 10% label smoothing
pad_idx = token2idx['<PAD>']
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx, label_smoothing=0.1)

# One AdamW over both networks (watcher parameters first, then decoder's)
trainable_params = [*watcher.parameters(), *decoder.parameters()]
optimizer = optim.AdamW(
    trainable_params,
    lr=1e-3,
    betas=(0.9, 0.98),
    eps=1e-9,
    weight_decay=1e-4,
)

# Cosine annealing with warm restarts; each cycle is twice as long as the last
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=5, T_mult=2, eta_min=1e-6
)

# Training config
num_epochs = 10
max_len = MAX_SEQ_LEN  # decoder unroll length == padded label length
best_loss = float('inf')
scaler = torch.cuda.amp.GradScaler()

total_params = sum(p.numel() for p in trainable_params)
print(
    f"\n{'=' * 60}",
    "TRAINING CONFIGURATION",
    '=' * 60,
    f"Vocabulary size: {vocab_size:,}",
    f"Max sequence length: {max_len}",
    f"Batch size: {train_loader.batch_size}",
    f"Epochs: {num_epochs}",
    f"Model parameters: {total_params:,}",
    f"GPU Memory (after load): {torch.cuda.memory_allocated(0)/1e9:.2f}GB",
    f"{'=' * 60}\n",
    sep="\n",
)

# ============================================================
# TRAINING LOOP - NO ENCODING!
# ============================================================
# Trains watcher + decoder for `num_epochs` epochs with CUDA mixed precision.
# - A RuntimeError containing "out of memory" skips the current batch.
# - Any other error, or Ctrl-C, ends training early.
# - The `finally` clause always writes final_model.pth.
# - best_model.pth keeps the epoch with the lowest mean *training* loss
#   (no validation split is used for model selection here).
try:
    for epoch in range(num_epochs):
        watcher.train()
        decoder.train()
        total_loss = 0
        batch_count = 0

        torch.cuda.empty_cache()
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

        for batch_idx, batch in enumerate(pbar):
            try:
                # ⭐ NO ENCODING! Labels are already encoded tensors
                images = batch['image'].to(device, non_blocking=True)
                labels = batch['label'].to(device, non_blocking=True)  # Already torch.Tensor!

                optimizer.zero_grad(set_to_none=True)

                with torch.cuda.amp.autocast():
                    # Forward pass
                    watcher_output = watcher(images)
                    b, c, h, w = watcher_output.shape
                    # [b, c, h, w] -> [b, h*w, c]: one c-dim feature per spatial position
                    encoder_outputs = watcher_output.permute(0, 2, 3, 1).reshape(b, h*w, c)

                    # Teacher-forced unroll for max_len steps. Decoder step t is
                    # scored against labels[:, t] (position 0 holds <SOS>), so
                    # max_len must equal the padded label length (MAX_SEQ_LEN)
                    # for the two flattened tensors below to line up.
                    outputs = decoder(encoder_outputs, labels, max_len)
                    outputs = outputs.view(-1, vocab_size)
                    labels_flat = labels.view(-1)

                    # <PAD> positions are excluded via ignore_index=pad_idx
                    loss = criterion(outputs, labels_flat)

                # Backward pass
                scaler.scale(loss).backward()
                # Unscale first so gradient clipping sees the true gradient norm
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(
                    list(watcher.parameters()) + list(decoder.parameters()), 
                    max_norm=1.0
                )
                # GradScaler skips optimizer.step() if the gradients contain inf/NaN
                scaler.step(optimizer)
                scaler.update()

                total_loss += loss.item()
                batch_count += 1

                # Update progress
                pbar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'lr': f'{optimizer.param_groups[0]["lr"]:.2e}',
                    'gpu': f'{torch.cuda.memory_allocated(0)/1e9:.1f}GB'
                })

                # Cleanup
                del watcher_output, encoder_outputs, outputs, labels_flat, loss

                # Periodically release cached allocator blocks (trades speed for headroom)
                if batch_idx % 5 == 0:
                    torch.cuda.empty_cache()

            except RuntimeError as e:
                if "out of memory" in str(e):
                    print(f"\n⚠️  OOM at batch {batch_idx}")
                    print(f"    Memory: {torch.cuda.memory_allocated(0)/1e9:.2f}GB / "
                          f"{torch.cuda.memory_reserved(0)/1e9:.2f}GB")

                    # Cleanup
                    # NOTE: this runs at notebook (module) scope, where locals()
                    # is the globals dict, so the names really are deleted.
                    # Inside a function this loop would silently do nothing.
                    # NOTE(review): if the OOM hits after scaler.unscale_() but
                    # before scaler.update(), the next batch's unscale_() will
                    # likely raise "unscale_() has already been called", which
                    # is re-raised below and ends training - confirm.
                    for var in ['images', 'labels', 'watcher_output', 'encoder_outputs', 
                                'outputs', 'labels_flat', 'loss']:
                        if var in locals():
                            del locals()[var]

                    torch.cuda.empty_cache()
                    gc.collect()
                    continue
                else:
                    raise e

        # End of epoch: mean loss over the batches that completed
        avg_loss = total_loss / batch_count if batch_count > 0 else float('inf')

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"  Loss: {avg_loss:.4f}")
        print(f"  LR: {optimizer.param_groups[0]['lr']:.2e}")
        print(f"  GPU: {torch.cuda.memory_allocated(0)/1e9:.2f}GB / "
              f"{torch.cuda.memory_reserved(0)/1e9:.2f}GB")

        # One scheduler step per epoch (T_0 / T_mult are counted in epochs)
        scheduler.step()

        # Save best model (selected on training loss)
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                'epoch': epoch,
                'watcher_state_dict': watcher.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': best_loss,
                'vocab_size': vocab_size,
                'token2idx': token2idx,
                'idx2token': idx2token,
            }, 'best_model.pth')
            print(f"  ✓ Saved best model (loss: {best_loss:.4f})")

        torch.cuda.empty_cache()
        gc.collect()

except KeyboardInterrupt:
    print("\n⚠️  Training interrupted by user")
except Exception as e:
    print(f"\n❌ Training error: {e}")
    import traceback
    traceback.print_exc()
finally:
    # Save final model - runs on normal completion, interrupt, or error.
    # 'loss' is the mean over the batches of the *current* (possibly partial)
    # epoch, since total_loss / batch_count are reset every epoch.
    torch.save({
        'watcher_state_dict': watcher.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': total_loss / batch_count if 'total_loss' in locals() and batch_count > 0 else None,
        'vocab_size': vocab_size,
        'token2idx': token2idx,
        'idx2token': idx2token,
    }, 'final_model.pth')
    print("\n✓ Final model saved")

    torch.cuda.empty_cache()
    gc.collect()

print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)

In [None]:
# Cell: UPDATED Beam Search Decoding for Token-Based Vocabulary
import torch
import torch.nn.functional as F

@torch.no_grad()
def beam_search_decode(watcher, decoder, image, beam_width=5, max_len=128, length_penalty=0.7):
    """
    Decode a single image into a LaTeX string with beam search.

    Each hypothesis accumulates a raw score: the sum of token
    log-probabilities minus a coverage penalty that discourages attending
    again to positions that are already covered. Hypotheses are ranked by
    the length-normalised score ``raw_score / len(seq) ** length_penalty``,
    both when pruning and when the final answer is chosen.

    Args:
        watcher: encoder; its [1, C, H, W] output is flattened to [1, H*W, C].
        decoder: ParserGRUDecoder whose embedding / attention / gru / out
            sub-modules are stepped one token at a time.
        image: [1, H, W] (a batch dim is added) or [1, 1, H, W] tensor.
        beam_width: live hypotheses kept per step. Decoding also stops once
            this many hypotheses have emitted <EOS>.
        max_len: maximum number of decoding steps.
        length_penalty: normalisation exponent (0 disables normalisation).

    Returns:
        The best hypothesis as a string. <SOS>/<PAD> are dropped, the
        sequence is cut at the first <EOS>, and the remaining tokens are
        concatenated with no separator.
    """
    watcher.eval()
    decoder.eval()

    # Follow the model's device rather than a notebook-global `device`.
    model_device = next(watcher.parameters()).device

    # Encode image
    if image.dim() == 3:
        image = image.unsqueeze(0)
    image = image.to(model_device)

    encoder_out = watcher(image)
    batch_size, channels, height, width = encoder_out.shape
    encoder_outputs = encoder_out.permute(0, 2, 3, 1).reshape(
        batch_size, height * width, channels
    )

    start_token = token2idx['<SOS>']
    end_token = token2idx['<EOS>']
    pad_token = token2idx['<PAD>']

    # Tokens exempt from the "no immediate repetition" rule: the special tokens
    # (as before) plus brackets, which legitimately repeat in LaTeX. For example
    # \frac{1}{x^{2}} ends in "}}", which could never be produced if back-to-back
    # "}" were blocked.
    repeat_ok = {start_token, end_token, pad_token}
    repeat_ok.update(
        token2idx[t] for t in ('{', '}', '(', ')', '[', ']', '|') if t in token2idx
    )

    def normalised(raw_score, length):
        return raw_score / (length ** length_penalty)

    coverage = torch.zeros(1, encoder_outputs.size(1), 1, device=model_device)
    hidden = torch.zeros(1, decoder.decoder_dim, device=model_device)

    # Live hypotheses: (token ids, hidden state, raw score, coverage)
    beams = [(torch.tensor([start_token], device=model_device), hidden, 0.0, coverage)]
    # Finished hypotheses: (token ids ending in <EOS>, raw score, normalised score)
    completed_beams = []

    for _step in range(max_len):
        candidates = []

        for seq, h, score, cov in beams:
            # One decoder step, mirroring ParserGRUDecoder.forward
            embedded = decoder.embedding(seq[-1].unsqueeze(0))
            context, alpha = decoder.attention(encoder_outputs, h, cov)
            new_hidden = decoder.gru(torch.cat([embedded, context], dim=1), h)
            output = decoder.out(torch.cat([embedded, new_hidden, context], dim=1))
            log_probs = F.log_softmax(output, dim=1)

            # Clamp k so a small vocabulary cannot make topk() raise.
            k = min(beam_width * 2, log_probs.size(1))
            topk_probs, topk_idx = log_probs.topk(k, dim=1)

            # Coverage penalty
            coverage_penalty = 0.5 * torch.sum(torch.min(cov.squeeze(-1), alpha)).item()
            new_cov = cov + alpha.unsqueeze(-1)
            last_token = seq[-1].item()

            for j in range(k):
                next_token = topk_idx[0, j].item()
                if next_token == pad_token:
                    continue
                if next_token == last_token and next_token not in repeat_ok:
                    continue

                next_seq = torch.cat([seq, topk_idx[0, j].unsqueeze(0)])
                raw_score = score + topk_probs[0, j].item() - coverage_penalty
                candidates.append(
                    (next_seq, new_hidden, raw_score, new_cov,
                     normalised(raw_score, next_seq.size(0)))
                )

        if not candidates:
            break

        # Keep the best `beam_width` candidates. Those that just emitted <EOS>
        # are retired immediately, so they cannot be lost when the loop ends
        # (previously, EOS beams still in `beams` at exit were discarded).
        candidates.sort(key=lambda c: c[4], reverse=True)
        beams = []
        for next_seq, new_hidden, raw_score, new_cov, norm_score in candidates[:beam_width]:
            if next_seq[-1].item() == end_token:
                completed_beams.append((next_seq, raw_score, norm_score))
            else:
                beams.append((next_seq, new_hidden, raw_score, new_cov))

        if len(completed_beams) >= beam_width or not beams:
            break

    if not completed_beams:
        # Nothing reached <EOS> within max_len: fall back to unfinished beams.
        completed_beams = [
            (seq, score, normalised(score, seq.size(0))) for seq, _, score, _ in beams
        ]

    # Pick the winner with the same length-normalised score used for pruning;
    # the raw log-prob sum is biased towards short outputs.
    best_seq = max(completed_beams, key=lambda b: b[2])[0]

    # Decode token ids to a LaTeX string
    decoded_tokens = []
    for token_idx in best_seq.tolist():
        if token_idx == end_token:
            break
        if token_idx in (start_token, pad_token):
            continue
        if token_idx in idx2token:
            decoded_tokens.append(idx2token[token_idx])

    # NOTE(review): tokens are joined with no separator, so a command followed
    # by a letter (e.g. '\alpha' then 'x') becomes '\alphax'. Kept as-is because
    # the ground-truth decoding used for comparison may rely on the same rule.
    return ''.join(decoded_tokens)


print(
    "Beam search function loaded successfully",
    "  Using token-based vocabulary",
    f"  Vocab size: {vocab_size}",
    sep="\n",
)

In [None]:
# Cell: FIXED Evaluation - Uses Pre-Encoded Labels

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import editdistance

# Load models
print("Loading models...")
checkpoint = torch.load('best_model.pth', map_location=device)

# Validate the vocabulary BEFORE loading weights:
# - a size mismatch would otherwise surface as an opaque load_state_dict
#   shape error;
# - a same-size but re-ordered vocabulary would load cleanly and then decode
#   to the wrong tokens.
# Raise rather than assert so the check also runs under `python -O`.
if 'vocab_size' in checkpoint:
    if checkpoint['vocab_size'] != vocab_size:
        raise ValueError(
            f"Vocabulary size mismatch! checkpoint={checkpoint['vocab_size']}, current={vocab_size}"
        )
    print(f"✓ Vocabulary size verified: {vocab_size}")
if 'token2idx' in checkpoint and checkpoint['token2idx'] != token2idx:
    raise ValueError("Vocabulary mapping mismatch! Checkpoint token2idx differs from the current one.")

watcher = WatcherFCN(in_channels=1)
decoder = ParserGRUDecoder(vocab_size=vocab_size)
watcher.load_state_dict(checkpoint['watcher_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])

watcher = watcher.to(device).eval()
decoder = decoder.to(device).eval()

print("✓ Model loaded successfully!")
print(f"  Best training loss: {checkpoint['loss']:.4f}")

def calculate_wer(reference, hypothesis):
    """
    Token error rate between two LaTeX strings.

    Both strings are split with ``custom_latex_tokenize``. The result is the
    token-level edit distance divided by the number of reference tokens, so
    it can exceed 1.0 when the hypothesis has many insertions. Two empty
    sequences score 0.0; an empty reference with a non-empty hypothesis
    scores 1.0.
    """
    ref = custom_latex_tokenize(reference)
    hyp = custom_latex_tokenize(hypothesis)
    if not ref:
        return 1.0 if hyp else 0.0
    return editdistance.eval(ref, hyp) / len(ref)

def _pct(count, total):
    """Return ``count`` as a percentage of ``total`` (0.0 when ``total`` is 0)."""
    return count / total * 100 if total else 0.0


def _print_metrics_summary(accuracy, avg_wer, total_correct, total_samples):
    """Print the headline exact-match accuracy and average-WER block."""
    print(f"\n{'='*60}")
    print("METRICS SUMMARY")
    print(f"{'='*60}")
    print(f"Exact Match Accuracy: {accuracy:.2%} ({total_correct}/{total_samples})")
    print(f"Average WER:          {avg_wer:.4f}")
    print(f"WER Percentage:       {avg_wer*100:.2f}%")
    print(f"{'='*60}\n")


def _print_sample_predictions(samples, limit=10):
    """Print the first ``limit`` predictions, with token diagnostics for misses."""
    print(f"Sample Predictions (first {limit}):")
    print("-" * 60)
    for i, sample in enumerate(samples[:limit], 1):
        status = '✓' if sample['correct'] else '✗'
        print(f"\n{i}. {status} [WER: {sample['wer']:.4f}]")
        print(f"   True:      {sample['true']}")
        print(f"   Predicted: {sample['pred']}")

        if sample['correct']:
            continue

        true_tokens = custom_latex_tokenize(sample['true'])
        pred_tokens = custom_latex_tokenize(sample['pred'])
        print(f"   True tokens:  {true_tokens[:10]}{'...' if len(true_tokens) > 10 else ''}")
        print(f"   Pred tokens:  {pred_tokens[:10]}{'...' if len(pred_tokens) > 10 else ''}")

        distance = editdistance.eval(true_tokens, pred_tokens)
        print(f"   Token count (true/pred): {len(true_tokens)}/{len(pred_tokens)}")
        print(f"   Edit distance: {distance}")

    print("-" * 60)


def _print_wer_distribution(samples):
    """Print per-bucket WER counts; percentages are 0.0 when ``samples`` is empty."""
    print("\nWER Distribution:")
    n = len(samples)
    wers = [s['wer'] for s in samples]
    wer_perfect = sum(1 for w in wers if w == 0.0)
    wer_low = sum(1 for w in wers if 0.0 < w <= 0.1)
    wer_medium = sum(1 for w in wers if 0.1 < w <= 0.5)
    wer_high = sum(1 for w in wers if w > 0.5)

    print(f"  Perfect (WER = 0.0):      {wer_perfect:3d} ({_pct(wer_perfect, n):.1f}%)")
    print(f"  Low (0.0 < WER ≤ 0.1):    {wer_low:3d} ({_pct(wer_low, n):.1f}%)")
    print(f"  Medium (0.1 < WER ≤ 0.5): {wer_medium:3d} ({_pct(wer_medium, n):.1f}%)")
    print(f"  High (WER > 0.5):         {wer_high:3d} ({_pct(wer_high, n):.1f}%)")
    print("-" * 60)


@torch.no_grad()
def evaluate_model(watcher, decoder, data_loader, max_samples=50):
    """
    Evaluate the watcher/decoder pair against pre-encoded labels and print a report.

    Each batch from ``data_loader`` is expected to be a dict with ``'image'``
    (a tensor indexed by sample along dim 0) and ``'label'`` (encoded token
    indices per sample). Labels are decoded back to LaTeX with
    ``decode_tokens_to_latex``. Every image is decoded individually with
    ``beam_search_decode`` (beam width 5).

    Args:
        watcher: encoder model; switched to eval mode.
        decoder: decoder model; switched to eval mode.
        data_loader: iterable of batch dicts as described above.
        max_samples: stop after this many samples have been scored.

    Returns:
        ``(accuracy, avg_wer, samples)``: exact-match rate, mean token-level
        WER (both 0.0 when nothing was scored), and a list of dicts with keys
        ``'pred'``, ``'true'``, ``'correct'``, ``'wer'``.
    """
    watcher.eval()
    decoder.eval()

    total_correct = 0
    total_samples = 0
    total_wer = 0.0
    samples = []

    for batch in tqdm(data_loader, desc='Evaluating'):
        images = batch['image'].to(device, non_blocking=True)
        # Labels arrive as encoded token-index tensors (presumably [batch, 128]);
        # they are decoded back to LaTeX so predictions can be compared as strings.
        encoded_labels = batch['label']

        for single_img, encoded_label in zip(images, encoded_labels):
            if total_samples >= max_samples:
                break

            true_text = decode_tokens_to_latex(encoded_label)
            pred_text = beam_search_decode(watcher, decoder, single_img, beam_width=5)

            is_correct = pred_text == true_text
            wer = calculate_wer(true_text, pred_text)

            total_correct += is_correct
            total_wer += wer
            total_samples += 1
            samples.append({
                'pred': pred_text,
                'true': true_text,
                'correct': is_correct,
                'wer': wer,
            })

        # Periodically release cached allocator blocks between batches.
        if total_samples % 20 == 0:
            torch.cuda.empty_cache()

        if total_samples >= max_samples:
            break

    accuracy = total_correct / total_samples if total_samples else 0.0
    avg_wer = total_wer / total_samples if total_samples else 0.0

    _print_metrics_summary(accuracy, avg_wer, total_correct, total_samples)
    _print_sample_predictions(samples)
    # Safe on an empty sample list (previously raised ZeroDivisionError).
    _print_wer_distribution(samples)

    return accuracy, avg_wer, samples


# HELPER: Decode token indices back to LaTeX string
def decode_tokens_to_latex(encoded_label, token2idx_map=None, idx2token_map=None):
    """
    Convert a sequence of token indices back into a LaTeX string.

    Decoding stops at the first ``<EOS>`` or ``<PAD>`` index and skips
    ``<SOS>``. Indices missing from the index->token map are ignored. Tokens
    are concatenated with no separator.

    Args:
        encoded_label: iterable of token indices, e.g. a 1-D ``torch.Tensor``
            (the evaluation loop passes one row of the batch label tensor) or a
            plain list of ints.
        token2idx_map: token -> index mapping; defaults to the global ``token2idx``.
        idx2token_map: index -> token mapping; defaults to the global ``idx2token``.

    Returns:
        The decoded LaTeX string, without special tokens or padding.
    """
    if token2idx_map is None:
        token2idx_map = token2idx
    if idx2token_map is None:
        idx2token_map = idx2token

    # Resolve the special-token ids once instead of on every iteration.
    stop_ids = {token2idx_map['<EOS>'], token2idx_map['<PAD>']}
    sos_id = token2idx_map['<SOS>']

    tokens = []
    for idx in encoded_label:
        # int() accepts both 0-d tensors and plain Python ints.
        token_idx = int(idx)

        if token_idx in stop_ids:
            break
        if token_idx == sos_id:
            continue
        if token_idx in idx2token_map:
            tokens.append(idx2token_map[token_idx])

    return ''.join(tokens)


# Evaluate on the validation and test splits using MathEquationEncodedDataset
# (labels are pre-encoded token indices), then print a side-by-side summary.
EVAL_BATCH_SIZE = 16
EVAL_NUM_WORKERS = 4  # reduced from 8
EVAL_MAX_SAMPLES = 50


def _make_eval_loader(csv_path, split):
    """Build the dataset and a non-shuffled DataLoader for one pre-encoded split."""
    dataset = MathEquationEncodedDataset(csv_path, DATASET_ROOT, split=split)
    loader = DataLoader(
        dataset,
        batch_size=EVAL_BATCH_SIZE,
        shuffle=False,
        num_workers=EVAL_NUM_WORKERS,
        # Pinned host memory only helps host->GPU copies; skip it without CUDA.
        pin_memory=torch.cuda.is_available(),
    )
    return dataset, loader


print("\n" + "="*60)
print("VALIDATION SET EVALUATION")
print("="*60)

VAL_CSV_ENCODED = os.path.join(DATASET_ROOT, 'val_database_encoded.csv')  # pre-encoded labels
val_dataset, val_loader = _make_eval_loader(VAL_CSV_ENCODED, 'val')
val_accuracy, val_wer, val_samples = evaluate_model(
    watcher, decoder, val_loader, max_samples=EVAL_MAX_SAMPLES
)


print("\n" + "="*60)
print("TEST SET EVALUATION")
print("="*60)

TEST_CSV_ENCODED = os.path.join(DATASET_ROOT, 'test_database_encoded.csv')  # pre-encoded labels
test_dataset, test_loader = _make_eval_loader(TEST_CSV_ENCODED, 'test')
test_accuracy, test_wer, test_samples = evaluate_model(
    watcher, decoder, test_loader, max_samples=EVAL_MAX_SAMPLES
)


# Final summary
print("\n" + "="*60)
print("FINAL EVALUATION SUMMARY")
print("="*60)
print(f"{'Metric':<25} {'Validation':<15} {'Test':<15}")
print("-" * 60)
print(f"{'Exact Match Accuracy':<25} {val_accuracy:>14.2%} {test_accuracy:>14.2%}")
print(f"{'Average WER':<25} {val_wer:>14.4f} {test_wer:>14.4f}")
print(f"{'WER Percentage':<25} {val_wer*100:>13.2f}% {test_wer*100:>13.2f}%")
print(f"{'Vocabulary Size':<25} {vocab_size:>14,}")
print(f"{'GPU Memory (Reserved)':<25} {torch.cuda.memory_reserved(0)/1e9:>13.2f}GB")
print(f"{'GPU Memory (Allocated)':<25} {torch.cuda.memory_allocated(0)/1e9:>13.2f}GB")
print("="*60)

torch.cuda.empty_cache()

Model loaded successfully!
Best training loss: 2.0836
Evaluating on Validation Set...


Evaluating: 100%|██████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.80s/it]



Accuracy: 0.00% (0/10)

Sample 1 ✗
Predicted: \tilde{n}(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(
True     : f_{\omega+1}(f_{\omega}(3))-2
------------------------------------------------------------
Sample 2 ✗
Predicted: \tilde{n}(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(
True     : s\notin\alpha,t\notin\gamma
------------------------------------------------------------
Sample 3 ✗
Predicted: \tilde{n}(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(
True     : \int f(x)dx
------------------------------------------------------------
Sample 4 ✗
Predicted: \tilde{n}(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(
True     : \hat{y}(f)
------------------------------------------------------------
Sample 5 ✗
Predicted: \tilde{n}(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(
True  

Evaluating: 100%|██████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.37s/it]


Accuracy: 0.00% (0/10)

Sample 1 ✗
Predicted: \tilde{n}(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(
True     : M,w\models I^{\alpha}(e)
------------------------------------------------------------
Sample 2 ✗
Predicted: \tilde{n}(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(
True     : dX=\frac{\partial X}{\partial x}dx=F^{-1}dx=HdxordX_{M}=\frac{\partial X_{M}}{\partial x_{n}}dx_{n}
------------------------------------------------------------
Sample 3 ✗
Predicted: \tilde{n}(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(
True     : TU=\sqrt{\frac{DU^{3}}{G*M}}
------------------------------------------------------------
Sample 4 ✗
Predicted: \tilde{n}(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(x)(
True     : (\frac{4}{7}-9)^{204\cdot\sqrt{5}}
------------------------------------------------------------
Sample 5 ✗
Predict


