In [5]:
 # Cell 0: FAST MULTIPROCESSING PREPROCESSING
import os
import numpy as np
import cv2
import pandas as pd
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
import pickle

def get_directional_kernels():
    """Build the eight 3x3 Sobel-style kernels used for the edge channels.

    The base kernel responds to a top-to-bottom drop in intensity; the other
    entries are rotations and mirror images of it.

    NOTE(review): the base kernel is left-right symmetric, so several entries
    coincide (index 4 equals index 0, index 5 equals index 2, indices 6 and 7
    equal index 3). Only four distinct kernels are therefore produced. The
    count and order are kept because the caller expects exactly eight maps.

    Returns:
        A list of eight ``(3, 3)`` integer numpy arrays.
    """
    base = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
    quarter_turn = np.rot90(base, 1)
    three_quarter_turn = np.rot90(base, 3)
    return [
        base,
        quarter_turn,
        np.rot90(base, 2),
        three_quarter_turn,
        np.fliplr(base),
        np.flipud(base),
        np.fliplr(quarter_turn),
        np.flipud(three_quarter_turn),
    ]

def get_directional_maps(image):
    """Filter ``image`` with every directional kernel and stack the responses.

    Each response is cast to float32, divided by 255 and clipped to [0, 1],
    so negative responses are discarded.

    NOTE(review): the division by 255 implies the input is expected on a
    0-255 intensity scale; an input already scaled to [0, 1] yields responses
    no larger than about 4/255 - confirm against the caller.

    Args:
        image: 2-D grayscale image accepted by ``cv2.filter2D``.

    Returns:
        float32 array of shape ``(8, H, W)`` with values in [0, 1].
    """
    responses = []
    for kern in get_directional_kernels():
        # ddepth=-1 keeps the output depth equal to the input depth.
        filtered = cv2.filter2D(image, -1, kern)
        scaled = filtered.astype(np.float32) / 255.0
        responses.append(np.clip(scaled, 0, 1))
    return np.stack(responses, axis=0)

def process_single_image(args):
    """Convert one grayscale image into a 9-channel array and save it as ``.npz``.

    Written as a ``multiprocessing.Pool`` worker: it takes a single tuple and
    reports failures through its return value instead of raising.

    Channel 0 is the grayscale image scaled to [0, 1]; channels 1-8 are the
    directional edge maps.

    Args:
        args: Tuple ``(idx, filename, label, dataset_root, split, output_dir)``.
            ``idx`` is unpacked but not otherwise used here.

    Returns:
        ``(record, None)`` on success, where ``record`` has the keys
        ``'filename'`` (name of the saved ``.npz``), ``'normalized_label'`` and
        ``'original_filename'``; ``(None, message)`` on failure.
    """
    idx, filename, label, dataset_root, split, output_dir = args

    try:
        # Construct paths
        img_path = os.path.join(dataset_root, split, filename)
        # Replace only the real extension. np.savez_compressed appends '.npz'
        # when it is missing, so the recorded name must always end in '.npz'
        # or the mapping CSV would point at a file that does not exist.
        output_filename = os.path.splitext(filename)[0] + '.npz'
        output_path = os.path.join(output_dir, split, output_filename)

        # Read image (0-255 intensities)
        raw = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if raw is None:
            return None, f"Could not read {img_path}"

        H, W = raw.shape

        # Compute 9 channels
        channels = np.zeros((9, H, W), dtype=np.float32)
        channels[0] = raw.astype(np.float32) / 255.0
        # get_directional_maps divides by 255 itself, so it must receive the
        # unscaled image; passing the [0, 1] image would normalise twice and
        # squash every edge response to at most ~4/255.
        # NOTE: arrays produced before this fix have much smaller edge
        # channels - regenerate the dataset rather than mixing old and new files.
        channels[1:] = get_directional_maps(raw)

        # Create subdirectories and save
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        np.savez_compressed(output_path, channels=channels)

        return {
            'filename': output_filename,
            'normalized_label': label,
            'original_filename': filename
        }, None

    except Exception as e:
        # Deliberate catch-all: one bad file must not kill the whole pool run.
        return None, f"Error processing {filename}: {str(e)}"

def precompute_9channel_dataset_parallel(csv_file, dataset_root, split, output_dir, num_workers=30):
    """
    Precompute the 9-channel arrays for one split using a process pool.

    Reads ``csv_file`` (columns ``filename`` and ``normalized_label``), runs
    ``process_single_image`` on every row in parallel and writes
    ``<output_dir>/<split>_mapping.csv`` listing the files that were saved.

    Args:
        csv_file: Path of the split's CSV database.
        dataset_root: Directory holding one sub-directory of images per split.
        split: Split name, e.g. 'train', 'val' or 'test'.
        output_dir: Root directory for the generated ``.npz`` files and the
            mapping CSV.
        num_workers: Number of worker processes; ``None`` selects
            ``cpu_count() - 2`` (at least 1).

    Returns:
        Tuple ``(num_succeeded, num_failed)``.
    """
    print(f"\n{'='*60}")
    print(f"PREPROCESSING {split.upper()} SPLIT (PARALLEL)")
    print(f"{'='*60}")

    # Create output directory
    split_output_dir = os.path.join(output_dir, split)
    os.makedirs(split_output_dir, exist_ok=True)

    # Load CSV
    df = pd.read_csv(csv_file)
    total_images = len(df)
    print(f"Found {total_images:,} images to process")

    # Determine number of workers
    if num_workers is None:
        num_workers = max(1, cpu_count() - 2)  # Leave 2 cores free
    print(f"Using {num_workers} CPU cores")

    # Prepare arguments for multiprocessing. Iterating the two columns directly
    # avoids a per-row `df.iloc[idx]` lookup, which is very slow on large frames.
    args_list = [
        (idx, filename, label, dataset_root, split, output_dir)
        for idx, (filename, label) in enumerate(
            zip(df['filename'], df['normalized_label'])
        )
    ]

    processed_data = []
    failed_count = 0
    reported_errors = []

    # Results must be consumed inside the `with` block: leaving it terminates
    # the pool. imap yields results in submission order.
    with Pool(processes=num_workers) as pool:
        outcomes = pool.imap(process_single_image, args_list)
        for result, error in tqdm(outcomes, total=total_images,
                                  desc=f"Processing {split}"):
            if result is not None:
                processed_data.append(result)
                continue
            failed_count += 1
            if error and failed_count <= 10:  # Keep first 10 errors
                reported_errors.append(error)

    # Printed after the progress bar has finished so they do not garble it.
    for error in reported_errors:
        print(f"\n{error}")

    # Save mapping CSV. Explicit columns keep the header present even when
    # nothing succeeded, so the file stays readable by pd.read_csv.
    mapping_df = pd.DataFrame(
        processed_data,
        columns=['filename', 'normalized_label', 'original_filename'],
    )
    mapping_csv = os.path.join(output_dir, f'{split}_mapping.csv')
    mapping_df.to_csv(mapping_csv, index=False)

    # Guard against an empty CSV (would otherwise divide by zero).
    success_rate = len(processed_data) / total_images * 100 if total_images else 0.0

    print(f"\n{'='*60}")
    print(f"PREPROCESSING {split.upper()} COMPLETE")
    print(f"{'='*60}")
    print(f"Successfully processed: {len(processed_data):,} / {total_images:,}")
    print(f"Failed: {failed_count:,}")
    print(f"Success rate: {success_rate:.2f}%")
    print(f"Saved to: {split_output_dir}")
    print(f"Mapping saved to: {mapping_csv}")
    print(f"{'='*60}\n")

    return len(processed_data), failed_count

# ==============================
# RUN PREPROCESSING
# ==============================
if __name__ == '__main__':  # Important for Windows multiprocessing
    import time

    # Source images/CSVs and destination for the precomputed arrays.
    DATASET_ROOT = '/home/ie643_errorcode500/errorcode500-working/ProcessedFullMathwriting'
    PROCESSED_ROOT = '/home/ie643_errorcode500/errorcode500-working/FinalMath'

    # Pool size handed to every split (the original note suggests cpu_count() - 2).
    NUM_WORKERS = 30

    print("Starting PARALLEL dataset preprocessing...")
    print(f"Source: {DATASET_ROOT}")
    print(f"Output: {PROCESSED_ROOT}")
    print(f"Available CPUs: {cpu_count()}")
    print("\nThis will use multiprocessing for 10-20x speedup!")
    print("Estimated time for 454k images: 2-3 hours (vs 20+ hours sequential)\n")

    # Explicit confirmation before starting the long-running job.
    input("Press Enter to continue...")

    start_time = time.time()

    # Each split is processed only if its CSV database exists.
    for split in ['train', 'val', 'test']:
        csv_file = os.path.join(DATASET_ROOT, f'{split}_database.csv')
        if not os.path.exists(csv_file):
            print(f"Warning: {csv_file} not found, skipping {split} split")
            continue
        success, failed = precompute_9channel_dataset_parallel(
            csv_file, DATASET_ROOT, split, PROCESSED_ROOT,
            num_workers=NUM_WORKERS
        )

    total_time = time.time() - start_time

    print(
        "\n" + "="*60,
        "ALL PREPROCESSING COMPLETE!",
        "="*60,
        f"Total time: {total_time/3600:.2f} hours",
        "You can now use the Fast9ChDataset for training",
        "="*60,
        sep="\n",
    )

Starting PARALLEL dataset preprocessing...
Source: /home/ie643_errorcode500/errorcode500-working/ProcessedFullMathwriting
Output: /home/ie643_errorcode500/errorcode500-working/FinalMath
Available CPUs: 32

This will use multiprocessing for 10-20x speedup!
Estimated time for 454k images: 2-3 hours (vs 20+ hours sequential)

Press Enter to continue...

PREPROCESSING TRAIN SPLIT (PARALLEL)
Found 454,437 images to process
Using 30 CPU cores


Processing train:   0%|▏                                                        | 1758/454437 [01:01<4:23:06, 28.67it/s]Process ForkPoolWorker-29:
Process ForkPoolWorker-25:
Process ForkPoolWorker-34:
Process ForkPoolWorker-26:
Process ForkPoolWorker-41:
Process ForkPoolWorker-39:
Process ForkPoolWorker-38:
Process ForkPoolWorker-43:
Process ForkPoolWorker-37:
Process ForkPoolWorker-31:
Process ForkPoolWorker-22:
Process ForkPoolWorker-23:
Traceback (most recent call last):
Process ForkPoolWorker-24:
Traceback (most recent call last):
Process ForkPoolWorker-42:
Traceback (most recent call last):
Process ForkPoolWorker-20:
Process ForkPoolWorker-40:
  File "/opt/anaconda/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/opt/anaconda/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/opt/anaconda/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self

KeyboardInterrupt: 

In [4]:
import os
from multiprocessing import cpu_count

# Inspect the machine before choosing a pool size
total_cores = cpu_count()
print(f"Total CPU cores available: {total_cores}")

# Heuristic for the worker count (a starting point, not a measurement):
#   - keep 2-4 cores free for the system and other tasks
#   - budget RAM per worker (the original note estimated ~300-500MB each)
#   - disk throughput (SSD vs HDD) may cap the useful number of workers
recommended_workers = max(total_cores - 2, 1)
print(f"Recommended workers: {recommended_workers}")

Total CPU cores available: 32
Recommended workers: 30


In [None]:
# Cell 1: FAST Dataset Loading (loads precomputed 9-channel arrays)
import pandas as pd
import numpy as np

import os
from torch.utils.data import Dataset, DataLoader
import torch

# Expose only physical GPU 1 to this process; inside the process it is then
# addressed as device index 0.
# NOTE(review): this runs after `import torch`. It presumably still takes
# effect as long as nothing has initialised CUDA yet - confirm, or move the
# assignment above the torch import.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

class Fast9ChDataset(Dataset):
    """
    Dataset that loads precomputed 9-channel numpy arrays from ``.npz`` files.

    No ``cv2.filter2D`` calls happen at training time; each item is read from
    ``<processed_root>/<split>/<filename>`` as listed in
    ``<processed_root>/<split>_mapping.csv``.
    """
    def __init__(self, processed_root, split='train'):
        """
        Args:
            processed_root: Root directory of preprocessed data
            split: 'train', 'val', or 'test'

        Raises:
            FileNotFoundError: If the mapping CSV for ``split`` does not exist.
        """
        self.processed_root = processed_root
        self.split = split

        # Load mapping CSV
        mapping_csv = os.path.join(processed_root, f'{split}_mapping.csv')
        if not os.path.exists(mapping_csv):
            raise FileNotFoundError(
                f"Mapping file not found: {mapping_csv}\n"
                f"Please run the preprocessing step first (Cell 0)"
            )

        self.data_frame = pd.read_csv(mapping_csv)
        print(f"Loaded {len(self.data_frame)} preprocessed {split} samples")

    def __len__(self):
        """Return the number of rows in the mapping CSV."""
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with ``'image'`` (float32 tensor built from the stored
            ``channels`` array) and ``'label'`` (always a ``str``).

        Raises:
            FileNotFoundError: If the ``.npz`` file cannot be loaded for any
                reason; the original exception is attached as ``__cause__``.
        """
        row = self.data_frame.iloc[idx]  # single row lookup for both fields
        filename = row['filename']
        # pandas may parse labels as numbers or NaN; force str so batches
        # collate uniformly and match the vocabulary cell's `.astype(str)`.
        label = str(row['normalized_label'])

        file_path = os.path.join(self.processed_root, self.split, filename)

        try:
            # The context manager closes the underlying zip file; without it
            # every sample leaks an open file handle in the loader workers.
            with np.load(file_path) as data:
                channels = data['channels']  # Already [9, H, W]
        except Exception as e:
            raise FileNotFoundError(f"Could not load {file_path}: {str(e)}") from e

        return {
            'image': torch.from_numpy(channels).float(),
            'label': label
        }

# ==============================
# Dataset Configuration
# ==============================
PROCESSED_ROOT = r'C:\Users\kani1\Desktop\IE643\custom-dataset\ProccessMathwritting-exercpt-9ch'

# Training split backed by the precomputed .npz files
train_dataset = Fast9ChDataset(PROCESSED_ROOT, split='train')

# DataLoader tuned for throughput: large batches, many workers, pinned memory.
# NOTE(review): if samples are 9x480x1600 float32 (the shape used by the model
# smoke test), one batch of 64 is roughly 1.8 GB, and 12 workers x prefetch 4
# can keep up to 48 batches in flight - confirm the host has enough RAM.
train_loader = DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=64,
    num_workers=12,
    prefetch_factor=4,        # batches prefetched per worker
    persistent_workers=True,  # keep workers alive between epochs
    pin_memory=True,
)

print(
    "\n" + "=" * 60,
    "DATASET READY FOR TRAINING",
    "=" * 60,
    f"Training samples: {len(train_dataset):,}",
    "Batch size: 64",
    f"Batches per epoch: {len(train_loader):,}",
    "Workers: 12",
    "=" * 60 + "\n",
    sep="\n",
)

# Smoke test: pull a single batch and report its contents
try:
    for batch in train_loader:
        images = batch['image']
        labels = batch['label']
        print("Test batch loaded successfully!")
        print(f"Image shape: {images.shape}")
        print(f"First 3 labels: {labels[:3]}")
        break
except Exception as e:
    print(f"Error loading batch: {str(e)}")

In [None]:
import pandas as pd
from collections import Counter

DATASET_ROOT = r"C:\Users\kani1\Desktop\IE643\custom-dataset\ProcessedFullMathwriting"
# Label CSVs of all three splits; the vocabulary is built over their union
csv_files = ['train_database.csv', 'val_database.csv', 'test_database.csv']

# Gather every label as a string
all_labels = []
for csv_file in csv_files:
    df = pd.read_csv(os.path.join(DATASET_ROOT, csv_file))
    all_labels += df['normalized_label'].astype(str).tolist()

# Character-level vocabulary: special tokens first, then the sorted characters
special_tokens = ['<PAD>', '<SOS>', '<EOS>']
char_counter = Counter()
for label in all_labels:
    char_counter.update(label)

vocab = special_tokens + sorted(char_counter)
char2idx = {ch: position for position, ch in enumerate(vocab)}
idx2char = dict(enumerate(vocab))

print(f"Vocabulary size: {len(vocab)}")
print("First 20 tokens:", vocab[:20])

# Encode a label string to indices
def encode_label(label, max_len=128):
    """Encode a label as a fixed-length list of vocabulary indices.

    The sequence is ``<SOS>``, the label's characters, ``<EOS>``, then
    ``<PAD>`` up to ``max_len``. Labels that are too long have their
    characters truncated so that ``<EOS>`` is still present (previously the
    terminator was cut off together with the overflow).

    Args:
        label: Label text; non-string values are converted with ``str``,
            matching how the vocabulary was built.
        max_len: Exact length of the returned list.

    Returns:
        List of ``max_len`` integer token indices.

    Raises:
        KeyError: If a kept character is not in ``char2idx``.
    """
    label = str(label)
    sos, eos, pad = char2idx['<SOS>'], char2idx['<EOS>'], char2idx['<PAD>']

    # Reserve two slots for <SOS> and <EOS>.
    kept = label[:max(max_len - 2, 0)]
    tokens = [sos] + [char2idx[ch] for ch in kept] + [eos]

    # Only matters for max_len < 2, where even the two markers do not fit.
    tokens = tokens[:max_len]
    tokens += [pad] * (max_len - len(tokens))
    return tokens

# Example usage: encode the first loaded label and show the first 20 indices
# (the encoded list is padded/truncated to encode_label's default max_len).
sample_label = all_labels[0]
encoded = encode_label(sample_label)
print("Original label:", sample_label)
print("Encoded:", encoded[:20])

# To return encoded labels from a dataset class, its __getitem__ could do:
# label_indices = encode_label(label)
# sample = {'image': image_tensor, 'label': label_indices}


Vocabulary size: 91
First 20 tokens: ['<PAD>', '<SOS>', '<EOS>', ' ', '!', '#', '&', '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4']
Original label: \vartheta=-\frac{log\frac{\phi_{\varsigma_{1}}}{\phi_{\varsigma_{2}}}}{log\frac{\varsigma_{1}}{\varsigma_{2}}}
Encoded: [1, 58, 83, 62, 79, 81, 69, 66, 81, 62, 28, 12, 58, 67, 79, 62, 64, 88, 73, 76]


In [None]:
# Cell 3: Model Architecture
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    """Stack of ``num_layers`` Conv3x3 -> BatchNorm -> ReLU stages.

    Every convolution uses stride 1 and padding 1, so the spatial size is
    preserved. When ``dropout_p`` is positive a ``Dropout2d`` follows each ReLU.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by every stage.
        num_layers: Number of stages.
        dropout_p: Channel-dropout probability; 0 disables dropout.
    """

    def __init__(self, in_channels, out_channels, num_layers=4, dropout_p=0.0):
        super().__init__()
        stages = []
        width_in = in_channels  # only the first stage sees the block's input width
        for _ in range(num_layers):
            stages += [
                nn.Conv2d(width_in, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            if dropout_p > 0:
                stages.append(nn.Dropout2d(p=dropout_p))
            width_in = out_channels
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply all stages to ``x``."""
        return self.block(x)

class WatcherFCN(nn.Module):
    """Fully convolutional encoder: four ConvBlocks, each followed by 2x2 max pooling.

    Channel widths go ``in_channels -> 32 -> 64 -> 64 -> 128``; only the last
    block uses dropout (p=0.2). The four poolings halve height and width four
    times.
    """

    def __init__(self, in_channels=9):
        super().__init__()
        # Smaller architecture for faster training.
        # (stage number, input width, output width, dropout probability)
        plan = [
            (1, in_channels, 32, 0.0),
            (2, 32, 64, 0.0),
            (3, 64, 64, 0.0),
            (4, 64, 128, 0.2),
        ]
        # Registered as block1/pool1 ... block4/pool4, in that order.
        for stage, width_in, width_out, drop in plan:
            setattr(self, f'block{stage}', ConvBlock(width_in, width_out, dropout_p=drop))
            setattr(self, f'pool{stage}', nn.MaxPool2d(2, 2))

    def forward(self, x):
        """Run the four block/pool stages in sequence and return the feature map."""
        for stage in range(1, 5):
            x = getattr(self, f'block{stage}')(x)
            x = getattr(self, f'pool{stage}')(x)
        return x

# Test model: push a random batch of two 9-channel 480x1600 inputs through the
# encoder. The four 2x2 poolings divide each spatial side by 16, so the result
# is [2, 128, 30, 100].
model = WatcherFCN(in_channels=9)
dummy_input = torch.randn(2, 9, 480, 1600)
output = model(dummy_input)
print(f"Output shape: {output.shape}")

# Count parameters
def count_parameters(model):
    """Return the total number of elements in the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report the trainable parameter count of the encoder built above.
print(f"WatcherFCN parameters: {count_parameters(model):,}")

torch.Size([2, 128, 30, 100])


In [7]:
# Flatten the spatial grid into a sequence: [B, C, H, W] -> [B, H*W, C]
batch_size, channels, height, width = output.shape
encoder_outputs = output.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)
# encoder_outputs: [batch, 3000, 128]  (30 * 100 positions, 128 channels each)
encoder_outputs.shape

torch.Size([2, 3000, 128])

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CoverageAttention(nn.Module):
    """Additive attention with a coverage term.

    For every encoder position ``i`` the score is
    ``v(tanh(W_a h + U_a a_i + U_f c_i))`` where ``h`` is the decoder state,
    ``a_i`` the encoder feature and ``c_i`` the coverage at that position.
    Scores are softmax-normalised over positions.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim, coverage_dim):
        super().__init__()
        self.W_a = nn.Linear(decoder_dim, attention_dim)   # projects decoder state
        self.U_a = nn.Linear(encoder_dim, attention_dim)   # projects encoder features
        self.U_f = nn.Linear(coverage_dim, attention_dim)  # projects coverage
        self.v = nn.Linear(attention_dim, 1)                # scalar score per position

    def forward(self, encoder_outputs, decoder_hidden, coverage):
        """Compute the context vector and the attention weights.

        Args:
            encoder_outputs: [batch, L, encoder_dim]
            decoder_hidden: [batch, decoder_dim]
            coverage: [batch, L, coverage_dim]

        Returns:
            ``(context, alpha)`` with shapes [batch, encoder_dim] and [batch, L].
        """
        query_term = self.W_a(decoder_hidden).unsqueeze(1)  # [batch, 1, att_dim]
        key_term = self.U_a(encoder_outputs)                # [batch, L, att_dim]
        coverage_term = self.U_f(coverage)                  # [batch, L, att_dim]

        energy = torch.tanh(query_term + key_term + coverage_term)
        scores = self.v(energy).squeeze(-1)                 # [batch, L]
        alpha = F.softmax(scores, dim=1)                    # [batch, L]

        # Attention-weighted sum of the encoder features over the L positions.
        weighted = encoder_outputs * alpha.unsqueeze(-1)
        context = torch.sum(weighted, dim=1)                # [batch, encoder_dim]
        return context, alpha

# class ParserGRUDecoder(nn.Module):
#     def __init__(self, vocab_size, encoder_dim=128, embed_dim=256, decoder_dim=256, attention_dim=256, coverage_dim=1):
#         super().__init__()
#         self.embedding = nn.Embedding(vocab_size, embed_dim)
#         self.gru = nn.GRUCell(embed_dim + encoder_dim, decoder_dim)
#         self.attention = CoverageAttention(encoder_dim, decoder_dim, attention_dim, coverage_dim)
#         self.fc = nn.Linear(decoder_dim + encoder_dim, vocab_size)

#     def forward(self, encoder_outputs, targets, max_len):
#         batch_size, L, encoder_dim = encoder_outputs.size()
#         device = encoder_outputs.device
#         coverage = torch.zeros(batch_size, L, 1, device=device)
#         inputs = torch.full((batch_size,), 1, dtype=torch.long, device=device)  # <SOS> token index
#         hidden = torch.zeros(batch_size, 256, device=device)
#         outputs = []
#         for t in range(max_len):
#             embedded = self.embedding(inputs)  # [batch, embed_dim]
#             context, alpha = self.attention(encoder_outputs, hidden, coverage)
#             gru_input = torch.cat([embedded, context], dim=1)
#             hidden = self.gru(gru_input, hidden)
#             output = self.fc(torch.cat([hidden, context], dim=1))
#             outputs.append(output)
#             # Teacher forcing: use ground truth if available
#             if targets is not None and t < targets.size(1):
#                 inputs = targets[:, t]
#             else:
#                 inputs = output.argmax(dim=1)
#             coverage = coverage + alpha.unsqueeze(-1)
#         outputs = torch.stack(outputs, dim=1)  # [batch, max_len, vocab_size]
#         return outputs


#Modified ParserGRUDecoder

class ParserGRUDecoder(nn.Module):
    """GRU decoder with coverage attention over the flattened encoder output.

    At each step the current token embedding and the attention context feed a
    ``GRUCell``; the logits come from an MLP over the concatenated embedding,
    new hidden state and context.
    """

    def __init__(self, vocab_size, encoder_dim=128, embed_dim=256, decoder_dim=256, attention_dim=256, coverage_dim=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # The GRU input is the token embedding concatenated with the context.
        self.gru = nn.GRUCell(embed_dim + encoder_dim, decoder_dim)
        self.attention = CoverageAttention(encoder_dim, decoder_dim, attention_dim, coverage_dim)
        # Output head sees hidden state, context and embedding together.
        self.out = nn.Sequential(
            nn.Linear(decoder_dim + encoder_dim + embed_dim, decoder_dim),
            nn.Tanh(),
            nn.Linear(decoder_dim, vocab_size)
        )
        self.decoder_dim = decoder_dim

    def forward(self, encoder_outputs, targets, max_len):
        """Decode ``max_len`` steps.

        Args:
            encoder_outputs: [batch, L, encoder_dim]
            targets: [batch, T] token indices used for teacher forcing, or
                ``None`` to feed back the model's own argmax predictions.
            max_len: Number of decoding steps.

        Returns:
            Logits of shape [batch, max_len, vocab_size].
        """
        n_batch, n_positions, _ = encoder_outputs.size()
        dev = encoder_outputs.device

        coverage = torch.zeros(n_batch, n_positions, 1, device=dev)
        # Index 1 is the <SOS> token.
        token = torch.full((n_batch,), 1, dtype=torch.long, device=dev)
        state = torch.zeros(n_batch, self.decoder_dim, device=dev)

        # Steps below this bound take the ground-truth token as the next input.
        forced_steps = targets.size(1) if targets is not None else 0

        step_logits = []
        for step in range(max_len):
            emb = self.embedding(token)  # [batch, embed_dim]
            context, alpha = self.attention(encoder_outputs, state, coverage)
            state = self.gru(torch.cat([emb, context], dim=1), state)

            logits = self.out(torch.cat([emb, state, context], dim=1))
            step_logits.append(logits)

            if step < forced_steps:
                token = targets[:, step]
            else:
                token = logits.argmax(dim=1)

            # Coverage is the running sum of all attention maps so far.
            coverage = coverage + alpha.unsqueeze(-1)

        return torch.stack(step_logits, dim=1)  # [batch, max_len, vocab_size]

# Example usage:
# encoder_outputs: [batch, L, encoder_dim] (flatten FCN output to [batch, L, 128])
# targets: [batch, max_len] (token indices)
# decoder = ParserGRUDecoder(vocab_size=len(vocab))
# outputs = decoder(encoder_outputs, targets, max_len)


This was the main training code that we used previously. We now use the improved training code that follows this commented-out cell.


In [None]:
# Cell: OPTIMIZED TRAINING for A6000
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import torch.nn.functional as F
import time

# ==============================
# Device Setup
# ==============================
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

if device.type == 'cuda':
    # Index 0 is the first device visible to this process.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    # Allow TF32 matmuls/convolutions and let cuDNN auto-tune its kernels.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True

# ==============================
# Initialize Models
# ==============================
watcher = WatcherFCN(in_channels=9).to(device)
decoder = ParserGRUDecoder(vocab_size=len(vocab)).to(device)

# Helper for the size report below
def count_parameters(model):
    """Return the total number of elements in the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

print(
    "\nModel Parameters:",
    f"  WatcherFCN: {count_parameters(watcher):,}",
    f"  Decoder: {count_parameters(decoder):,}",
    f"  Total: {count_parameters(watcher) + count_parameters(decoder):,}",
    sep="\n",
)

# ==============================
# Training Configuration
# ==============================
# Schedule length and sequence length are defined first so that everything
# below derives from a single value.
num_epochs = 10
max_len = 128
best_loss = float('inf')

# Padding positions do not contribute to the loss.
pad_idx = vocab.index('<PAD>')
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
optimizer = optim.AdamW(  # AdamW instead of Adadelta
    list(watcher.parameters()) + list(decoder.parameters()),
    lr=1e-3,
    weight_decay=0.01
)
# OneCycleLR is stepped once per batch. Its total step count must match the
# training loop, so it uses num_epochs rather than a separate literal:
# stepping past the configured total raises an error.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-3,
    epochs=num_epochs,
    steps_per_epoch=len(train_loader),
    pct_start=0.1
)

# Mixed precision scaler
scaler = torch.cuda.amp.GradScaler()

# ==============================
# Helper Functions
# ==============================
@torch.no_grad()
def apply_weight_noise(model, std=0.01):
    """Perturb every trainable parameter in place with Gaussian noise.

    Noise is drawn from N(0, std^2) per element; parameters with
    ``requires_grad=False`` are left untouched. Gradients are not tracked.
    """
    trainable = (param for param in model.parameters() if param.requires_grad)
    for param in trainable:
        noise = torch.randn_like(param).mul_(std)
        param.add_(noise)

# ==============================
# Training Loop
# ==============================
print(f"\n{'='*60}")
print("STARTING TRAINING")
print(f"{'='*60}\n")

training_start = time.time()

# Defined before the try block so the `finally` clause can always compute the
# final loss, even when training fails before the first batch.
total_loss = 0
batch_count = 0

try:
    for epoch in range(num_epochs):
        epoch_start = time.time()
        watcher.train()
        decoder.train()
        total_loss = 0
        batch_count = 0

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

        for batch_idx, batch in enumerate(pbar):
            # Non-blocking transfer
            images = batch['image'].to(device, non_blocking=True)
            labels = [encode_label(lbl, max_len) for lbl in batch['label']]
            labels = torch.tensor(labels, dtype=torch.long, device=device)

            optimizer.zero_grad(set_to_none=True)

            try:
                # Mixed precision forward pass
                with torch.cuda.amp.autocast():
                    watcher_output = watcher(images)
                    batch_size, channels, height, width = watcher_output.shape
                    # [B, C, H, W] -> [B, H*W, C]: one feature vector per position
                    encoder_outputs = watcher_output.permute(0, 2, 3, 1).reshape(
                        batch_size, height * width, channels
                    )

                    outputs = decoder(encoder_outputs, labels, max_len)
                    outputs_flat = outputs.view(-1, outputs.size(-1))
                    labels_flat = labels.view(-1)

                    loss = criterion(outputs_flat, labels_flat)

                # Scaled backpropagation; gradients are unscaled before
                # clipping so the norm threshold applies to the true gradients.
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(
                    list(watcher.parameters()) + list(decoder.parameters()),
                    max_norm=5.0
                )
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()

                total_loss += loss.item()
                batch_count += 1

                # Update progress bar
                pbar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'gpu': f'{torch.cuda.memory_allocated(0)/1e9:.1f}GB',
                    'lr': f'{scheduler.get_last_lr()[0]:.2e}'
                })

            except RuntimeError as e:
                # Best effort: report the failing batch, free cached memory
                # and carry on with the next one.
                print(f"\nError in batch {batch_idx}: {str(e)}")
                torch.cuda.empty_cache()
                continue

        # Epoch summary
        epoch_time = time.time() - epoch_start
        # If every batch was skipped there is nothing to average; inf keeps
        # the summary printable and can never count as a new best loss.
        avg_loss = total_loss / batch_count if batch_count > 0 else float('inf')
        imgs_per_sec = len(train_dataset) / epoch_time

        print(f"\n{'='*60}")
        print(f"Epoch {epoch+1}/{num_epochs} Summary")
        print(f"{'='*60}")
        print(f"Average Loss: {avg_loss:.4f}")
        print(f"Time: {epoch_time/60:.1f} minutes")
        print(f"Speed: {imgs_per_sec:.0f} images/sec")
        print(f"GPU Memory: {torch.cuda.memory_allocated(0)/1e9:.2f}GB")
        print(f"{'='*60}\n")

        # Save best model BEFORE any weight noise is applied, so the
        # checkpoint holds the weights that actually produced avg_loss.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                'epoch': epoch,
                'watcher_state_dict': watcher.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': best_loss,
            }, 'best_model.pth')
            print(f"✓ Saved best model (loss: {best_loss:.4f})\n")

        # Apply weight noise (every second epoch, skipping the first)
        if epoch > 0 and epoch % 2 == 0:
            apply_weight_noise(watcher, std=0.01)
            apply_weight_noise(decoder, std=0.01)
            print("Applied weight noise regularization\n")

        # Clear cache
        torch.cuda.empty_cache()

except KeyboardInterrupt:
    print("\n\nTraining interrupted by user")
except Exception as e:
    print(f"\n\nError during training: {str(e)}")
    import traceback
    traceback.print_exc()
finally:
    # Save final model (always, including after an interrupt or error)
    torch.save({
        'watcher_state_dict': watcher.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': total_loss / batch_count if batch_count > 0 else None,
    }, 'final_model.pth')

    training_time = time.time() - training_start
    print(f"\n{'='*60}")
    print("TRAINING COMPLETE")
    print(f"{'='*60}")
    print(f"Total time: {training_time/3600:.2f} hours")
    print(f"Best loss: {best_loss:.4f}")
    print(f"Models saved: best_model.pth, final_model.pth")
    print(f"{'='*60}\n")

    torch.cuda.empty_cache()

In [None]:
# Cell: EVALUATION
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

@torch.no_grad()
def evaluate_model(watcher, decoder, data_loader, split_name, max_samples=50):
    """Greedy-decode batches from ``data_loader`` and report exact-match accuracy.

    Decoding runs without teacher forcing for 128 steps. Each prediction is
    cut at the first ``<EOS>``; tokens produced after it are ignored because
    padded positions are excluded from the training loss and are therefore
    unconstrained.

    Evaluation stops after the first batch that brings the number of scored
    samples to at least ``max_samples``, so this is a quick check on the
    beginning of the split rather than a full pass.

    Args:
        watcher: Encoder module.
        decoder: Decoder module, called as ``decoder(encoder_outputs, None, max_len=128)``.
        data_loader: Yields dicts with ``'image'`` and ``'label'``.
        split_name: Name used in progress and result messages.
        max_samples: Number of samples to keep for inspection, and the
            threshold after which evaluation stops.

    Returns:
        ``(accuracy, samples)`` where ``samples`` is a list of dicts with keys
        ``'pred'``, ``'true'`` and ``'correct'``. Accuracy is 0.0 when the
        loader yields no samples.
    """
    watcher.eval()
    decoder.eval()

    skipped_tokens = {'<PAD>', '<SOS>'}

    total_correct = 0
    total_samples = 0
    samples = []

    print(f"\nEvaluating on {split_name}...")
    for batch in tqdm(data_loader, desc=f'{split_name} Evaluation'):
        images = batch['image'].to(device, non_blocking=True)
        labels = batch['label']

        with torch.cuda.amp.autocast():
            watcher_output = watcher(images)
            batch_size, channels, height, width = watcher_output.shape
            encoder_outputs = watcher_output.permute(0, 2, 3, 1).reshape(
                batch_size, height * width, channels
            )
            outputs = decoder(encoder_outputs, None, max_len=128)

        # One device-to-host transfer for the whole batch instead of one
        # `.item()` call per token.
        predicted = outputs.argmax(dim=-1).tolist()

        for i, token_ids in enumerate(predicted):
            chars = []
            for token_id in token_ids:
                token = idx2char[token_id]
                if token == '<EOS>':
                    break  # everything after the terminator is noise
                if token not in skipped_tokens:
                    chars.append(token)
            pred_text = ''.join(chars)
            true_text = labels[i]

            is_correct = (pred_text == true_text)
            total_correct += is_correct
            total_samples += 1

            if len(samples) < max_samples:
                samples.append({
                    'pred': pred_text,
                    'true': true_text,
                    'correct': is_correct
                })

        if total_samples >= max_samples:
            break

    # An empty loader would otherwise divide by zero.
    accuracy = total_correct / total_samples if total_samples else 0.0
    print(f"\n{'='*60}")
    print(f"{split_name.upper()} RESULTS")
    print(f"{'='*60}")
    print(f"Accuracy: {accuracy:.2%} ({total_correct}/{total_samples})")
    print(f"{'='*60}\n")

    return accuracy, samples

# Load best model
checkpoint = torch.load('best_model.pth', map_location=device)

# Rebuild both networks and restore the weights stored in the checkpoint
watcher = WatcherFCN(in_channels=9).to(device)
decoder = ParserGRUDecoder(vocab_size=len(vocab)).to(device)
watcher.load_state_dict(checkpoint['watcher_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])

print(f"Loaded model from epoch {checkpoint['epoch']} (loss: {checkpoint['loss']:.4f})")

# Both evaluation loaders share the same settings
eval_loader_kwargs = dict(batch_size=64, shuffle=False, num_workers=8, pin_memory=True)

# Validation
val_dataset = Fast9ChDataset(PROCESSED_ROOT, split='val')
val_loader = DataLoader(val_dataset, **eval_loader_kwargs)
val_acc, val_samples = evaluate_model(
    watcher, decoder, val_loader, 'Validation', max_samples=50
)

# Test
test_dataset = Fast9ChDataset(PROCESSED_ROOT, split='test')
test_loader = DataLoader(test_dataset, **eval_loader_kwargs)
test_acc, test_samples = evaluate_model(
    watcher, decoder, test_loader, 'Test', max_samples=50
)

torch.cuda.empty_cache()