# Line-by-Line Classification with Vision Transformer

A clean implementation of sequential scanline classification using a Vision Transformer (ViT) architecture.

## What This Notebook Does

This notebook trains a transformer model to classify STM (Scanning Tunneling Microscope) images **while they're still being scanned**. Instead of waiting for the full image, the model can make predictions after seeing just a few scanlines.

**Real-world use case**: Stop bad scans early to save time!

## Architecture Overview

Our approach treats each scanline as a single token:
- **Input**: Partial scan with j scanlines (j can range from 1 to 128; training in this notebook samples j from 10 to 128)
- **Architecture**: Vision Transformer that treats each full line as one token
- **Output**: Classification prediction based on lines seen so far

## Key Components

1. **Line Embedding**: Project each 128-pixel scanline to a 256-d embedding
2. **Positional Encoding**: Tell the model which line comes first, second, etc.
3. **Transformer Encoder**: Let all lines "talk to each other" via self-attention
4. **Classification Head**: Predict the class from the CLS token

## What to Expect

- **Setup**: ~10 seconds
- **Data loading**: ~5 seconds  
- **Model creation**: Instant
- **Training (2 epochs)**: ~5-10 minutes on CPU, ~30 seconds on GPU
- **Full training (50 epochs)**: ~2-3 hours on CPU, ~20 minutes on GPU

## 1. Setup

**What this does**: Import libraries and configure PyTorch

**Key libraries**:
- `torch`: The deep learning framework we use
- `nn` (neural network): Building blocks for our model
- `Dataset/DataLoader`: Handle batching and feeding data to the model
- `sklearn.metrics`: Evaluate model performance (accuracy, AUROC, etc.)
- `tqdm`: Show progress bars during training

**What to look for**: After running, check if you're using GPU or CPU. GPU is ~10-20x faster!

In [16]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, balanced_accuracy_score, confusion_matrix
import seaborn as sns
from tqdm.auto import tqdm

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

Using device: cpu


## 2. Dataset

**What this section does**: Load and prepare STM images for training

**Dataset format**:
- Each image: [128 scanlines, 128 pixels per line]  
- Each scanline: One row of the STM scan
- Labels: Integer class labels (0, 1, 2, 3, 4, 5)

**The key trick**: We create multiple training samples per image by randomly selecting different numbers of scanlines (e.g., 10 lines, 50 lines, 100 lines). This teaches the model to work with partial scans.

**Why normalize each line?**: STM data has intensity variations due to instrumental drift. By normalizing each line independently (zero mean, unit variance), we help the model focus on patterns rather than absolute intensity values.
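
To make the effect concrete, here is a small NumPy sketch of per-line normalization on a toy scan with simulated drift. The array shapes, the drift model, and the helper name `normalize_lines` are illustrative assumptions; the notebook's dataset class does the same job with an explicit loop.

```python
import numpy as np

def normalize_lines(images, eps=1e-8):
    """Zero-mean, unit-variance normalization of every scanline (last axis)."""
    images = np.asarray(images, dtype=np.float64)
    mean = images.mean(axis=-1, keepdims=True)
    std = images.std(axis=-1, keepdims=True)
    # Flat lines (std ~ 0) are only mean-centered, to avoid dividing by zero.
    safe_std = np.where(std > eps, std, 1.0)
    return (images - mean) / safe_std

rng = np.random.default_rng(0)
pattern = np.sin(np.linspace(0, 4 * np.pi, 128))   # same surface feature on every line
drift = np.linspace(0.0, 5.0, 8)[:, None]          # offset that grows line by line (assumed drift model)
scan = pattern[None, :] + drift + 0.01 * rng.normal(size=(8, 128))

normalized = normalize_lines(scan[None])[0]         # add, then remove, a batch axis
print(scan.mean(axis=1).round(2))         # raw line means follow the drift
print(normalized.mean(axis=1).round(6))   # close to 0 for every line
print(normalized.std(axis=1).round(6))    # close to 1 for every line
```

After normalization the drift offset is gone, so every line presents the same pattern to the model regardless of when it was scanned.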

In [17]:
class LineByLineDataset(Dataset):
    """
    Dataset for line-by-line classification.
    
    For each image, we create multiple samples by varying the number of visible lines.
    This simulates the real-time scanning scenario.
    """
    
    def __init__(self, images, labels, min_lines=5, max_lines=128, samples_per_image=10):
        """
        Args:
            images: Array of shape [N, max_lines, line_width]
            labels: Array of shape [N] with class labels
            min_lines: Minimum number of lines to use
            max_lines: Maximum number of lines (full image)
            samples_per_image: How many random cutoffs to generate per image
        """
        self.images = images
        self.labels = labels
        self.min_lines = min_lines
        self.max_lines = max_lines
        self.samples_per_image = samples_per_image
        
        # Normalize each scanline independently
        self._normalize_scanlines()
        
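    def _normalize_scanlines(self):
        """
        Normalize each scanline to zero mean and unit variance.

        Note: this writes into self.images in place, so the array passed to
        the constructor is normalized as well. Later cells rely on this when
        they pass test_images directly to the model.
        """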
        for i in range(len(self.images)):
            for j in range(self.images.shape[1]):
                line = self.images[i, j, :]
                mean = line.mean()
                std = line.std()
                if std > 1e-8:
                    self.images[i, j, :] = (line - mean) / std
                else:
                    self.images[i, j, :] = line - mean
    
    def __len__(self):
        return len(self.images) * self.samples_per_image
    
    def __getitem__(self, idx):
        # Determine which image and which random cutoff
        img_idx = idx // self.samples_per_image
        
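        # Random number of lines between min and max (inclusive).
        # The cutoff is re-drawn on every access, so validation and test
        # samples also change from epoch to epoch.
        num_lines = np.random.randint(self.min_lines, self.max_lines + 1)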
        
        # Get partial image
        partial_image = self.images[img_idx, :num_lines, :]
        label = self.labels[img_idx]
        
        return {
            'image': torch.FloatTensor(partial_image),
            'label': torch.tensor(label, dtype=torch.long),
            'num_lines': num_lines
        }


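def read_dataset():
    """
    Load images and labels from data/processed_data.npz.

    Assumes the first array stored in the file holds the images
    and the second holds the labels.
    """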
    data = np.load('data/processed_data.npz')

    images = data[data.files[0]]
    labels = data[data.files[1]]
    return images, labels

In [18]:
# Load the real dataset
print("Loading dataset from processed_data.npz...")
data = np.load('data/processed_data.npz')
all_images = data[data.files[0]]
all_labels = data[data.files[1]]

print(f"Total dataset: {all_images.shape[0]} images")
print(f"Image size: {all_images.shape[1]} x {all_images.shape[2]}")
print(f"Number of classes: {len(np.unique(all_labels))}")
print(f"Class distribution:")
for cls in np.unique(all_labels):
    count = np.sum(all_labels == cls)
    print(f"  Class {cls}: {count} samples ({100*count/len(all_labels):.2f}%)")

# Split into train/val/test (70/15/15)
from sklearn.model_selection import train_test_split

# First split: separate test set
train_val_images, test_images, train_val_labels, test_labels = train_test_split(
    all_images, all_labels, test_size=0.15, random_state=42, stratify=all_labels
)

# Second split: separate validation set
train_images, val_images, train_labels, val_labels = train_test_split(
    train_val_images, train_val_labels, test_size=0.176, random_state=42, stratify=train_val_labels
)  # 0.176 * 0.85 ≈ 0.15 of total

print(f"\nTrain set: {train_images.shape[0]} images")
print(f"Val set: {val_images.shape[0]} images")
print(f"Test set: {test_images.shape[0]} images")

Loading dataset from processed_data.npz...
Total dataset: 18932 images
Image size: 128 x 128
Number of classes: 6
Class distribution:
  Class 0: 1664 samples (8.79%)
  Class 1: 1482 samples (7.83%)
  Class 2: 6830 samples (36.08%)
  Class 3: 3202 samples (16.91%)
  Class 4: 1332 samples (7.04%)
  Class 5: 4422 samples (23.36%)

Train set: 13259 images
Val set: 2833 images
Test set: 2840 images
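
The split sizes above can be checked by hand. The sketch below assumes that scikit-learn rounds the test part up and gives the remainder to the other part; the helper `split_sizes` is hypothetical and only reproduces the arithmetic.

```python
import math

def split_sizes(n_samples, test_fraction):
    """Return (n_rest, n_test), assuming the test part is rounded up."""
    n_test = math.ceil(n_samples * test_fraction)
    return n_samples - n_test, n_test

n_total = 18932
n_train_val, n_test = split_sizes(n_total, 0.15)
n_train, n_val = split_sizes(n_train_val, 0.176)

print(n_train, n_val, n_test)
print(f"val fraction of total: {n_val / n_total:.4f}")
```

The second split uses 0.176 because it is applied to the remaining 85% of the data, and 0.176 × 0.85 is roughly 0.15 of the total.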


## 3. Model Architecture

**What this section does**: Build the transformer model piece by piece

The transformer has 3 main components:
1. **Line Embedding**: Convert raw pixel values to learnable representations
2. **Positional Encoding**: Add position information (which line is which)
3. **Transformer Layers**: Process the sequence with self-attention

**Why transformers?** They excel at sequence data because they can:
- Look at ALL scanlines simultaneously (not just nearby ones)
- Learn which parts of the sequence are important via attention
- Handle variable-length inputs (10 lines or 100 lines - doesn't matter!)

### 3.1 Line Embedding

**What it does**: Projects each 128-pixel scanline to a 256-dimensional embedding

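Think of it as a learned re-description of the line: 128 raw pixel values → 256 learned "feature" numbers. Note that this expands the line rather than compressing it; the extra dimensions give the model room to represent patterns in the line.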

**Why 256 dimensions?** This is a hyperparameter. Larger = more capacity but slower training.
Common values: 128, 256, 512, 768

In [19]:
class PatchEmbedding(nn.Module):
    """
    Convert scanlines into embeddings for the transformer.
    
    This module treats each full scanline as a single token and projects it
    into a higher-dimensional embedding space. This is more natural for 1D
    sequential data than subdividing lines into patches.
    
    For example, with line_width=128:
    - Each full 128-pixel scanline becomes one token
    - Each line is then projected to embed_dim dimensions
    - This results in num_lines tokens (much more efficient than patch-based approach)
    
    Args:
        line_width: Number of pixels in each scanline (default: 128)
        embed_dim: Dimensionality of the embedding space (default: 256)
    """
    
    def __init__(self, line_width=128, embed_dim=256):
        super().__init__()
        self.line_width = line_width
        self.embed_dim = embed_dim
        
        # Linear projection from full line to embed_dim
        self.projection = nn.Linear(line_width, embed_dim)
        
    def forward(self, x):
        """
        Convert scanlines to embeddings.
        
        Args:
            x: [batch_size, num_lines, line_width]
        
        Returns:
            embeddings: [batch_size, num_lines, embed_dim]
            where each line is one token
        """
        batch_size, num_lines, line_width = x.shape
        
        # Project each line directly to embedding dimension
        # Shape: (batch_size, num_lines, line_width) -> (batch_size, num_lines, embed_dim)
        embeddings = self.projection(x)
        
        return embeddings

### 3.2 Positional Encoding

**What it does**: Tells the model "this is line 1, this is line 2, etc."

**Why needed?** Transformers process all tokens in parallel. Without positional encoding, the model can't tell if it's seeing lines in order [1,2,3] or scrambled [3,1,2].

**How it works**: We use sinusoidal functions (sin/cos) at different frequencies:
- Low frequencies: Capture "rough" position (early vs late in scan)
- High frequencies: Capture "precise" position (line 42 vs line 43)

**Benefits of sinusoidal encoding**:
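- Defined by a formula, so it can be computed for any position (this implementation precomputes `max_lines + 1` positions, so longer scans need a larger `max_lines`)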
- Model can learn to attend to relative positions ("5 lines ago")
- No additional parameters to learn
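
The relative-position property can be checked numerically. The sketch below rebuilds the same kind of sinusoidal table in NumPy (the function name `sinusoidal_table` is an assumption for illustration) and shows that the dot product between two position vectors depends only on their offset, not on where they sit in the scan.

```python
import numpy as np

def sinusoidal_table(max_len, embed_dim):
    """Sinusoidal position table: sin on even columns, cos on odd columns."""
    position = np.arange(max_len)[:, None]
    div_term = np.exp(np.arange(0, embed_dim, 2) * (-np.log(10000.0) / embed_dim))
    pe = np.zeros((max_len, embed_dim))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return pe

pe = sinusoidal_table(max_len=129, embed_dim=256)

offset = 5
sim_early = pe[10] @ pe[10 + offset]     # lines 10 and 15
sim_late = pe[100] @ pe[100 + offset]    # lines 100 and 105
print(np.isclose(sim_early, sim_late))
```

Each sin/cos pair contributes cos(offset × frequency) to the dot product, which is why the absolute position cancels out.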

In [20]:
class PositionalEncoding(nn.Module):
    """
    Add positional information to line embeddings using sinusoidal encoding.
    
    Transformers have no inherent notion of sequence order, so we add positional
    encodings to give the model information about the position of each scanline.
    This is crucial for STM data where the temporal order of scanlines matters.
    
    We use the sinusoidal encoding from "Attention is All You Need" (Vaswani et al.):
    - PE(pos, 2i) = sin(pos / 10000^(2i/embed_dim))
    - PE(pos, 2i+1) = cos(pos / 10000^(2i/embed_dim))
    
    This encoding allows the model to learn relative positions. The table is
    precomputed for max_lines + 1 positions, so longer sequences need a larger max_lines.
    
    Args:
        embed_dim: Dimensionality of the embeddings
        max_lines: Maximum number of scanlines expected
        dropout: Dropout rate applied after adding positional encoding
    """
    
    def __init__(self, embed_dim, max_lines=128, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.dropout = nn.Dropout(p=dropout)
        
        # Maximum sequence length: max_lines + 1 (for CLS token)
        max_len = max_lines + 1
        
        # Initialize positional encoding matrix
        pe = torch.zeros(max_len, embed_dim)
        
        # Create position indices: [0, 1, 2, ..., max_len-1]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        # Calculate the division term for sinusoidal functions
        # This creates different frequencies for different dimensions
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-np.log(10000.0) / embed_dim))
        
        # Apply sine to even dimensions
        pe[:, 0::2] = torch.sin(position * div_term)
        
        # Apply cosine to odd dimensions
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # Register as buffer (not a parameter, but should be moved to GPU)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """
        Add positional encoding to input embeddings.
        
        Args:
            x: [batch_size, seq_len, embed_dim]
        
        Returns:
            x with positional encoding added
        """
        x = x + self.pe[:x.size(1), :]
        return self.dropout(x)

### 3.3 Vision Transformer (The Complete Model)

**What it does**: Combines everything into a complete classification model

**Architecture flow**:
```
Input: [batch, num_lines, 128]
  ↓ Line Embedding
[batch, num_lines, 256]
  ↓ Add CLS token
[batch, 1+num_lines, 256]  ← CLS token prepended
  ↓ Add Positional Encoding
[batch, 1+num_lines, 256]
  ↓ Transformer Layers (6 layers)
[batch, 1+num_lines, 256]
  ↓ Extract CLS token
[batch, 256]
  ↓ Classification Head
[batch, 6]  ← Logits for 6 classes
```

**Key design choices**:
- **CLS token**: A learnable "summary" token that aggregates information from all lines
- **6 transformer layers**: Each layer refines the representations using self-attention
- **8 attention heads**: Each head can focus on different patterns
- **Pre-normalization**: Normalizes before (not after) each sub-layer for better training stability

**What happens during training**: The model learns to:
1. Embed lines in a meaningful way (similar lines → similar embeddings)
2. Attend to relevant parts of the sequence (which lines matter for classification?)
3. Aggregate information into the CLS token
4. Map the CLS token to class logits (softmax turns these into probabilities)

In [21]:
class VisionTransformer(nn.Module):
    """
    Vision Transformer adapted for line-by-line STM image classification.
    
    This model follows the ViT architecture but treats each full scanline as a token:
    1. Line Embedding: Convert full scanlines to embeddings (one token per line)
    2. CLS Token: Add a learnable classification token at the start of sequence
    3. Positional Encoding: Add position information to embeddings
    4. Transformer Encoder: Process the sequence with self-attention
    5. Classification Head: Use the CLS token output for classification
    
    The model can handle variable-length inputs (different numbers of scanlines)
    thanks to the CLS token approach and positional encodings.
    
    This architecture is more efficient than patch-based approaches:
    - num_lines tokens instead of num_lines × patches_per_line tokens
    - Quadratically reduced attention computation
    - More natural for 1D sequential data
    
    Args:
        line_width: Number of pixels per scanline (default: 128)
        num_classes: Number of output classes (default: 4; this notebook passes 6)
        embed_dim: Embedding dimension (default: 256)
        depth: Number of transformer encoder layers (default: 6)
        num_heads: Number of attention heads (default: 8)
        mlp_ratio: Ratio of MLP hidden dim to embedding dim (default: 4)
        dropout: Dropout rate (default: 0.1)
        max_lines: Maximum number of scanlines (default: 128)
    """
    
    def __init__(
        self,
        line_width=128,
        num_classes=4,
        embed_dim=256,
        depth=6,
        num_heads=8,
        mlp_ratio=4,
        dropout=0.1,
        max_lines=128
    ):
        super().__init__()
        
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.line_width = line_width
        
        # Convert scanlines to embeddings (one token per line)
        self.patch_embed = PatchEmbedding(line_width, embed_dim)
        
        # Learnable classification token (similar to BERT's [CLS] token)
        # This token's output will be used for classification
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # Add positional information to embeddings
        self.pos_encoding = PositionalEncoding(
            embed_dim, 
            max_lines=max_lines, 
            dropout=dropout
        )
        
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * mlp_ratio,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True  # Pre-norm architecture
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        
        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        """Initialize weights with a truncated normal (std=0.02), as is common in ViT implementations."""
        # Initialize CLS token
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        
        # Initialize linear layers
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
    
    def forward(self, x):
        """
        Forward pass through the Vision Transformer.
        
        Args:
            x: [batch_size, num_lines, line_width]
        
        Returns:
            logits: [batch_size, num_classes]
        """
        batch_size = x.shape[0]
        
        # Convert scanlines to line embeddings
        # Shape: [batch_size, num_lines, embed_dim]
        x = self.patch_embed(x)
        
        # Prepend CLS token: [batch, num_lines + 1, embed_dim]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        
        # Add positional encoding
        x = self.pos_encoding(x)
        
        # Transformer encoding
        x = self.transformer(x)
        
        # Take CLS token output
        cls_output = x[:, 0]
        
        # Normalize and classify
        cls_output = self.norm(cls_output)
        logits = self.head(cls_output)
        
        return logits

In [22]:
# Create model
model = VisionTransformer(
    line_width=128,
    num_classes=6,  # Updated to 6 classes based on the dataset
    embed_dim=256,
    depth=6,
    num_heads=8,
    mlp_ratio=4,
    dropout=0.1,
    max_lines=128
).to(device)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model created with {num_params:,} trainable parameters")

# Test forward pass
dummy_input = torch.randn(2, 50, 128).to(device)
output = model(dummy_input)
print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")

Model created with 4,773,894 trainable parameters
Input shape: torch.Size([2, 50, 128])
Output shape: torch.Size([2, 6])
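
The parameter count reported above can be reproduced by hand, which is a useful sanity check on the architecture. The breakdown below assumes the standard layout of `nn.TransformerEncoderLayer` (a fused query/key/value projection, an output projection, a two-layer MLP, and two LayerNorms per layer); the variable and helper names are illustrative.

```python
line_width, embed_dim, num_classes = 128, 256, 6
depth, mlp_ratio = 6, 4
hidden = embed_dim * mlp_ratio

def linear(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

layer_norm = 2 * embed_dim  # scale and shift

per_layer = (
    linear(embed_dim, 3 * embed_dim)   # fused Q, K, V projection
    + linear(embed_dim, embed_dim)     # attention output projection
    + linear(embed_dim, hidden)        # MLP, first layer
    + linear(hidden, embed_dim)        # MLP, second layer
    + 2 * layer_norm                   # two LayerNorms per encoder layer
)

total = (
    linear(line_width, embed_dim)      # line embedding
    + embed_dim                        # CLS token
    + depth * per_layer                # transformer encoder
    + layer_norm                       # final LayerNorm
    + linear(embed_dim, num_classes)   # classification head
)
print(f"{total:,}")
```

The encoder layers hold almost all of the parameters. The positional encoding adds none, because it is registered as a buffer rather than a parameter.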




## 4. Training Setup

**What this section does**: Configure how the model learns

This is where you define:
1. **How to measure error** (loss function)
2. **How to update weights** (optimizer)
3. **How to adjust learning rate** (scheduler)
4. **How to batch data** (dataloaders)

### Key Hyperparameters (feel free to experiment!)

**samples_per_image**: How many random line-count variations per image
- Current: 2 (fast prototyping)
- For real training: 5-10
- Higher = more data augmentation, longer training

**batch_size**: How many samples to process before updating weights
- Current: 8 (uses less memory)
- Typical: 16-32
- Larger = more stable gradients, faster training, more memory

**Learning rate (lr)**: How big are weight updates?
- Default: 1e-4 (0.0001)
- Too high: Training diverges
- Too low: Training is very slow
- Usually the most influential hyperparameter, so tune it first

**num_epochs**: How many times to see the full dataset
- Current: 2 (quick test)
- For real training: 50-100
- More epochs = better performance (until overfitting)
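
How these settings translate into work per epoch can be estimated with a little arithmetic. The sketch below uses the training-set size printed in the dataset section; `steps_per_epoch` is a hypothetical helper.

```python
import math

def steps_per_epoch(n_images, samples_per_image, batch_size):
    """Number of batches (optimizer steps) in one pass over the dataset."""
    return math.ceil(n_images * samples_per_image / batch_size)

n_train = 13259
quick = steps_per_epoch(n_train, samples_per_image=2, batch_size=8)
fuller = steps_per_epoch(n_train, samples_per_image=5, batch_size=32)
print(quick, fuller)
```

A larger batch size means fewer steps per epoch, although each step processes more samples and needs more memory.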

In [23]:
def collate_fn(batch):
    """
    Custom collate function to handle variable-length sequences.
    Pads to the maximum length in the batch.
    """
    max_lines = max(item['num_lines'] for item in batch)
    batch_size = len(batch)
    line_width = batch[0]['image'].shape[1]
    
    # Create padded tensors
    images = torch.zeros(batch_size, max_lines, line_width)
    labels = torch.zeros(batch_size, dtype=torch.long)
    
    for i, item in enumerate(batch):
        num_lines = item['num_lines']
        images[i, :num_lines] = item['image']
        labels[i] = item['label']
    
    return {'images': images, 'labels': labels}


# Create datasets
# For real training: samples_per_image=5
train_dataset = LineByLineDataset(train_images, train_labels, min_lines=10, samples_per_image=2)
val_dataset = LineByLineDataset(val_images, val_labels, min_lines=10, samples_per_image=2)
test_dataset = LineByLineDataset(test_images, test_labels, min_lines=10, samples_per_image=2)

# Create dataloaders
# For real training: batch_size=32
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

Training batches: 3315
Validation batches: 709
Test batches: 710
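
One thing to be aware of: `collate_fn` pads shorter scans with zero lines, and the model as written attends to those padded lines as if they were real. A possible refinement, sketched here as an assumption rather than something the notebook does, is to pass a key padding mask to the encoder through the `src_key_padding_mask` argument of `nn.TransformerEncoder`. The helper `make_padding_mask` and the tiny encoder are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def make_padding_mask(num_lines, max_lines):
    """True marks padded positions; column 0 is the CLS token (never padded)."""
    lengths = torch.as_tensor(num_lines).unsqueeze(1)        # [B, 1]
    positions = torch.arange(max_lines).unsqueeze(0)         # [1, L]
    line_mask = positions >= lengths                         # [B, L]
    cls_col = torch.zeros(len(num_lines), 1, dtype=torch.bool)
    return torch.cat([cls_col, line_mask], dim=1)            # [B, 1 + L]

# Tiny encoder for illustration; dropout is 0 so the output is deterministic.
layer = nn.TransformerEncoderLayer(
    d_model=16, nhead=2, dim_feedforward=32,
    dropout=0.0, batch_first=True, norm_first=True,
)
encoder = nn.TransformerEncoder(layer, num_layers=2, enable_nested_tensor=False)

num_lines = [3, 5]                        # real lines per sample; batch is padded to 5
mask = make_padding_mask(num_lines, max_lines=5)

tokens = torch.randn(2, 6, 16)            # [batch, CLS + 5 lines, embed_dim]
tokens_b = tokens.clone()
tokens_b[0, 4:] = torch.randn(2, 16)      # change only the padded tokens of sample 0

with torch.no_grad():
    out_a = encoder(tokens, src_key_padding_mask=mask)
    out_b = encoder(tokens_b, src_key_padding_mask=mask)

# With the mask, the CLS output does not depend on what the padding contains.
cls_unchanged = torch.allclose(out_a[:, 0], out_b[:, 0], atol=1e-6)
print(cls_unchanged)
```

To use this in the notebook, `collate_fn` would need to return the mask alongside the images, and `VisionTransformer.forward` would need to accept it and hand it to the encoder.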


In [24]:
# Training configuration
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)

# Learning rate scheduler
# For real training: num_epochs=50
num_epochs = 2
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

print("Training configuration:")
print("- Loss function: CrossEntropyLoss (standard for classification)")
print("- Optimizer: AdamW (Adam with decoupled weight decay)")
print("- Learning rate: 1e-4 (0.0001)")
print("- Weight decay: 0.05 (decoupled from the gradient update; helps limit overfitting)")
print("- LR Scheduler: CosineAnnealing (gradually reduces LR)")
print(f"- Epochs: {num_epochs}")

Training configuration:
- Loss function: CrossEntropyLoss (standard for classification)
- Optimizer: AdamW (Adam with decoupled weight decay)
- Learning rate: 1e-4 (0.0001)
- Weight decay: 0.05 (decoupled from the gradient update; helps limit overfitting)
- LR Scheduler: CosineAnnealing (gradually reduces LR)
- Epochs: 2
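
To see what the cosine schedule does with so few epochs, the learning rate used in each epoch can be computed in closed form. The sketch assumes the scheduler's default minimum learning rate of 0 and one scheduler step per epoch; `cosine_lr` is a hypothetical helper, not a PyTorch function.

```python
import math

def cosine_lr(epoch, base_lr=1e-4, t_max=2, eta_min=0.0):
    """Closed-form cosine annealing learning rate used during a given epoch."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

for epoch in range(2):
    print(f"epoch {epoch + 1}: lr = {cosine_lr(epoch):.2e}")
```

With `num_epochs = 2`, the second epoch already runs at half the initial rate. With 50 epochs the decay is far more gradual.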


## 5. Training Loop

**What this section does**: Actually train the model!

### The Training Process (what happens each epoch):

1. **Training phase**: 
   - Process all training batches
   - For each batch: forward pass → compute loss → backward pass → update weights
   - Track loss and accuracy
   
2. **Validation phase**:
   - Process all validation batches (no weight updates!)
   - Measure performance on unseen data
   - Track loss, balanced accuracy, and AUROC

3. **Save best model**:
   - If validation AUROC improves, save checkpoint
   - This limits the effect of overfitting: the saved model comes from the best epoch, not the last one

### What to Watch For:

**Good signs**:
- Training loss decreases steadily
- Validation loss decreases (or stays stable)
- Validation accuracy improves
- Gap between train/val is small

**Bad signs**:
- Training loss decreases but validation loss increases = **overfitting**
- Both losses stay high = **underfitting** (need bigger model or more epochs)
- Loss becomes NaN = learning rate too high or numerical instability

**Metrics explained**:
- **Loss**: How wrong the predictions are (lower = better)
- **Accuracy**: % of correct predictions (higher = better)
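- **Balanced Accuracy**: Mean of per-class recall, so every class counts equally regardless of its size. `evaluate` reports this for validation, while `train_epoch` reports plain accuracy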
- **AUROC**: Area under ROC curve (0.5 = random, 1.0 = perfect)
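
The difference between accuracy and balanced accuracy is easiest to see on an imbalanced toy example. The labels below are made up for illustration; the computation follows the definition of balanced accuracy as the mean of per-class recall.

```python
import numpy as np

# 90 samples of class 0, 10 of class 1; the "model" always predicts class 0.
y_true = np.array([0] * 90 + [1] * 10)
y_pred = np.zeros_like(y_true)

accuracy = np.mean(y_true == y_pred)

# Recall per class: fraction of that class's samples predicted correctly.
recalls = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
balanced_accuracy = np.mean(recalls)

print(f"accuracy: {accuracy:.2f}, balanced accuracy: {balanced_accuracy:.2f}")
```

Plain accuracy rewards the majority-class guess, while balanced accuracy drops to chance level. That matters here because the largest class holds about 36% of the images and the smallest about 7%.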

In [25]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    pbar = tqdm(loader, desc='Training')
    for batch in pbar:
        images = batch['images'].to(device)
        labels = batch['labels'].to(device)
        
        # Forward pass
        logits = model(images)
        loss = criterion(logits, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Metrics
        total_loss += loss.item()
        predictions = logits.argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)
        
        # Update progress bar
        pbar.set_postfix({'loss': loss.item(), 'acc': correct / total})
    
    return total_loss / len(loader), correct / total


def evaluate(model, loader, criterion, device):
    """Evaluate model."""
    model.eval()
    total_loss = 0
    all_labels = []
    all_preds = []
    all_probs = []
    
    with torch.no_grad():
        for batch in tqdm(loader, desc='Evaluating'):
            images = batch['images'].to(device)
            labels = batch['labels'].to(device)
            
            logits = model(images)
            loss = criterion(logits, labels)
            
            total_loss += loss.item()
            probs = F.softmax(logits, dim=1)
            preds = logits.argmax(dim=1)
            
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    
    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)
    all_probs = np.array(all_probs)
    
    # Calculate metrics
    accuracy = balanced_accuracy_score(all_labels, all_preds)
    auroc = roc_auc_score(all_labels, all_probs, multi_class='ovr', average='macro')
    
    return total_loss / len(loader), accuracy, auroc

In [None]:
# Training loop
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': [],
    'val_auroc': []
}

best_val_auroc = 0

print("Starting training...")

for epoch in range(num_epochs):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch + 1}/{num_epochs}")
    print(f"{'='*60}")
    
    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # Validate
    val_loss, val_acc, val_auroc = evaluate(model, val_loader, criterion, device)
    
    # Update scheduler
    scheduler.step()
    
    # Record history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_auroc'].append(val_auroc)
    
    # Print summary
    print(f"\nEpoch {epoch + 1} Summary:")
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val AUROC: {val_auroc:.4f}")
    
    # Save best model
    if val_auroc > best_val_auroc:
        best_val_auroc = val_auroc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_auroc': float(val_auroc),  # plain float, so the checkpoint holds no NumPy objects
        }, 'best_model.pt')
        print(f"Saved best model with AUROC: {val_auroc:.4f}")
    else:
        print(f" No improvement. Best AUROC: {best_val_auroc:.4f}")

print(f"\n{'='*60}")
print("Training complete!")
print(f"Best validation AUROC: {best_val_auroc:.4f}")
print(f"{'='*60}")

Starting training...
Watch the progress bars and metrics below.


Epoch 1/2



## 6. Visualization

**What this does**: Plot training curves to visualize learning progress

**How to interpret the plots**:

1. **Loss plot** (left):
   - Should decrease over time for both train and val
   - If val loss increases while train decreases: **overfitting**
   - If both stay flat: **underfitting** (need bigger model or more epochs)

2. **Accuracy plot** (middle):
   - Should increase over time
   - The validation curve is balanced accuracy (robust to class imbalance); the train curve is plain accuracy
   - Because the two curves use different metrics, treat the gap between them only as a rough indicator of generalization

3. **AUROC plot** (right):
   - Most important metric for imbalanced classification
   - 0.5 = random guessing
   - 1.0 = perfect classification
   - As a rough rule of thumb for this 6-class problem, 0.8+ is good and 0.9+ is excellent; what counts as acceptable depends on the application

**What good training looks like**:
- Smooth curves (not too noisy)
- Train and val curves close together
- Steady improvement over time
- Val metrics plateau towards the end (learning saturated)

In [None]:
# Plot training history
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Loss
axes[0].plot(history['train_loss'], label='Train')
axes[0].plot(history['val_loss'], label='Validation')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy
axes[1].plot(history['train_acc'], label='Train (plain accuracy)')
axes[1].plot(history['val_acc'], label='Validation (balanced accuracy)')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Training and Validation Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# AUROC
axes[2].plot(history['val_auroc'], label='Validation', color='green')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('AUROC')
axes[2].set_title('Validation AUROC')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_history.png', dpi=300, bbox_inches='tight')
plt.show()

## 7. Evaluation: Performance vs. Number of Lines

**This is the key analysis!** 

We want to answer: "How many scanlines do we need to classify accurately?"

**What this does**:
- Evaluates the model with 10, 20, 30, ..., 128 lines
- Shows how performance improves as we see more of the image
- Helps decide: "When can we stop scanning early?"

**What to expect**:
- Performance improves as we see more lines (more information)
- Early lines might give ~60-70% accuracy
- Full image (128 lines) might give ~90%+ accuracy
- The curve usually shows diminishing returns (plateaus after some point)

**Practical implications**:
- If 50 lines gives 85% accuracy and 128 lines gives 88%, maybe stop at 50?
- Saves 60% of scan time with minimal accuracy loss!
- This is the value proposition of early stopping
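
One way to turn the curve into a stopping rule is to pick the smallest line count whose accuracy is within a tolerance of the full-scan accuracy. The sketch below uses made-up numbers in the spirit of the example above (not measured results), and `earliest_stop` and `tolerance` are hypothetical names.

```python
def earliest_stop(num_lines, accuracy, tolerance=0.04):
    """Smallest line count whose accuracy is within `tolerance` of the final one."""
    target = accuracy[-1] - tolerance
    for lines, acc in zip(num_lines, accuracy):
        if acc >= target:
            return lines
    return num_lines[-1]

# Illustrative values only, not results from this notebook.
num_lines = [10, 30, 50, 90, 128]
accuracy = [0.62, 0.78, 0.85, 0.87, 0.88]

stop_at = earliest_stop(num_lines, accuracy)
saved = 1 - stop_at / num_lines[-1]
print(f"stop at {stop_at} lines, saving {saved:.0%} of scan time")
```

The tolerance encodes how much accuracy you are willing to trade for scan time, so it should be chosen per application.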

In [None]:
def evaluate_vs_num_lines(model, images, labels, device, line_steps=None):
    """
    Evaluate model performance as a function of number of scanlines.
    
    Args:
        model: Trained model
        images: Full images [N, max_lines, line_width]
        labels: Labels [N]
        device: Device to run on
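        line_steps: List of line counts to evaluate at
            (default: [10, 20, ..., 120] plus the full image height)
    
    Returns:
        Dictionary with performance metrics at each line count
    """
    if line_steps is None:
        # Every 10 lines, plus the full image (range alone would stop at 120)
        max_lines = images.shape[1]
        line_steps = list(range(10, max_lines, 10)) + [max_lines]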
    
    model.eval()
    results = {'num_lines': [], 'accuracy': [], 'auroc': []}
    
    for num_lines in tqdm(line_steps, desc='Evaluating vs. num lines'):
        all_labels = []
        all_probs = []
        all_preds = []
        
        with torch.no_grad():
            for i in range(0, len(images), 32):  # Batch size 32
                batch_images = images[i:i+32, :num_lines, :]
                batch_labels = labels[i:i+32]
                
                batch_images = torch.FloatTensor(batch_images).to(device)
                
                logits = model(batch_images)
                probs = F.softmax(logits, dim=1)
                preds = logits.argmax(dim=1)
                
                all_labels.extend(batch_labels)
                all_probs.extend(probs.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())
        
        all_labels = np.array(all_labels)
        all_probs = np.array(all_probs)
        all_preds = np.array(all_preds)
        
        accuracy = balanced_accuracy_score(all_labels, all_preds)
        auroc = roc_auc_score(all_labels, all_probs, multi_class='ovr', average='macro')
        
        results['num_lines'].append(num_lines)
        results['accuracy'].append(accuracy)
        results['auroc'].append(auroc)
    
    return results

In [None]:
# Load best model (map_location lets a GPU-saved checkpoint load on a CPU-only machine)
checkpoint = torch.load('best_model.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Evaluate on test set
results = evaluate_vs_num_lines(model, test_images, test_labels, device)

# Plot results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# AUROC vs. number of lines
ax1.plot(results['num_lines'], results['auroc'], marker='o', linewidth=2, markersize=6)
ax1.set_xlabel('Number of Scanlines')
ax1.set_ylabel('AUROC')
ax1.set_title('AUROC vs. Number of Scanlines')
ax1.grid(True, alpha=0.3)
ax1.set_ylim([0, 1])

# Balanced accuracy vs. number of lines
ax2.plot(results['num_lines'], results['accuracy'], marker='o', linewidth=2, markersize=6, color='green')
ax2.set_xlabel('Number of Scanlines')
ax2.set_ylabel('Balanced Accuracy')
ax2.set_title('Balanced Accuracy vs. Number of Scanlines')
ax2.grid(True, alpha=0.3)
ax2.set_ylim([0, 1])

plt.tight_layout()
plt.savefig('performance_vs_lines.png', dpi=300, bbox_inches='tight')
plt.show()

print("\nPerformance Summary:")
for i, num_lines in enumerate(results['num_lines']):
    print(f"Lines: {num_lines:3d} | AUROC: {results['auroc'][i]:.4f} | Balanced Accuracy: {results['accuracy'][i]:.4f}")

## 8. Confusion Matrix

**What this does**: Shows which classes the model confuses

**How to read it**:
- Rows = True labels (what the class actually is)
- Columns = Predicted labels (what the model thinks)
- Diagonal = Correct predictions
- Off-diagonal = Mistakes

**What to look for**:
- Strong diagonal = good performance
- Large off-diagonal numbers = systematic confusion
- Example: If row 2, column 3 is large → model often predicts class 3 when truth is class 2

**Why this matters**:
- Helps debug: "Why is class X performing poorly?"
- May reveal that classes are inherently similar
- Can guide data collection: "Need more examples of confused classes"
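
As a rough illustration of the reading guide above, the sketch below computes per-class recall and the largest off-diagonal confusions from a confusion matrix. It uses a made-up 3-class matrix rather than this notebook's results, and the helper name `top_confusions` is hypothetical; the same function could be applied to the `cm` array computed in the next cell.

```python
import numpy as np

def top_confusions(cm, k=3):
    """Return the k largest off-diagonal cells as (true_class, predicted_class, count)."""
    off_diag = cm.copy()
    np.fill_diagonal(off_diag, 0)  # ignore correct predictions
    # Sort flattened indices by count, largest first
    flat_order = np.argsort(off_diag, axis=None)[::-1][:k]
    rows, cols = np.unravel_index(flat_order, cm.shape)
    return [(int(r), int(c), int(off_diag[r, c])) for r, c in zip(rows, cols)]

# Made-up 3-class confusion matrix: rows = true labels, columns = predictions
cm_example = np.array([
    [50,  2,  1],
    [ 4, 40, 12],
    [ 0,  3, 45],
])

# Per-class recall = diagonal / row sum (fraction of each true class predicted correctly)
recall = np.diag(cm_example) / cm_example.sum(axis=1)
print("Per-class recall:", np.round(recall, 3))

for true_cls, pred_cls, count in top_confusions(cm_example, k=2):
    print(f"True class {true_cls} predicted as class {pred_cls}: {count} times")
```

In this toy matrix the largest off-diagonal cell is row 1, column 2, so true class 1 is most often mistaken for class 2, which also explains why class 1 has the lowest recall.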

In [None]:
# Get predictions on full test set
model.eval()
all_labels = []
all_preds = []

with torch.no_grad():
    for batch in test_loader:
        images = batch['images'].to(device)
        labels = batch['labels']
        
        logits = model(images)
        preds = logits.argmax(dim=1)
        
        all_labels.extend(labels.numpy())
        all_preds.extend(preds.cpu().numpy())

# Compute confusion matrix (fixing labels keeps it 6x6 even if a class never appears)
cm = confusion_matrix(all_labels, all_preds, labels=list(range(6)))

# Plot
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=[f'Class {i}' for i in range(6)],
            yticklabels=[f'Class {i}' for i in range(6)])
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.show()

## 9. Attention Visualization (Optional)

TODO: Extract and visualize attention maps to understand what the model focuses on.

In [None]:
# TODO: Modify model to return attention weights
# TODO: Visualize which scanlines the model attends to
# This helps interpret model decisions

## 10. Model Export

Export model for deployment or further analysis.

In [None]:
# Save model architecture and weights
import json

model_config = {
    'line_width': 128,
    'num_classes': 6,  # Number of output classes
    'embed_dim': 256,
    'depth': 6,
    'num_heads': 8,
    'mlp_ratio': 4,
    'dropout': 0.1,
    'max_lines': 128
}

with open('model_config.json', 'w') as f:
    json.dump(model_config, f, indent=2)

# Save final model
torch.save({
    'model_config': model_config,
    'model_state_dict': model.state_dict(),
    'best_val_auroc': best_val_auroc
}, 'final_model.pt')

print("Model saved successfully!")
print(f"Best validation AUROC: {best_val_auroc:.4f}")