# Training a Segmentation Model

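This notebook walks through training a semantic segmentation model with the VLM evaluation framework: it builds a `simple_cnn` encoder with a `linear_probe` head, trains both end to end on the framework's `dummy` dataset, plots the training curves, visualizes a few predictions, and saves a checkpoint.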

## Setup

In [None]:
import sys
from pathlib import Path
sys.path.insert(0, str(Path.cwd().parent))

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

from vlm_eval.core import EncoderRegistry, HeadRegistry, DatasetRegistry

## Create Model and Dataset

In [None]:
# Configuration for this cell. Keeping the values in one place stops the
# encoder head and the two datasets from drifting apart.
SEED = 42
NUM_CLASSES = 21
BATCH_SIZE = 8

# Seed PyTorch's global RNG before anything random happens so that
# Restart Kernel -> Run All shuffles the training data the same way each time.
# NOTE(review): this should also pin weight initialisation (pretrained=False),
# provided the registry builds the model with torch's RNG -- confirm.
torch.manual_seed(SEED)

# Create model
encoder = EncoderRegistry.get("simple_cnn", variant="small", pretrained=False)
model = HeadRegistry.get(
    "linear_probe",
    encoder=encoder,
    num_classes=NUM_CLASSES,
    freeze_encoder=False  # Train the whole model
)

# Create datasets
train_dataset = DatasetRegistry.get("dummy", num_samples=200, num_classes=NUM_CLASSES)
val_dataset = DatasetRegistry.get("dummy", num_samples=50, num_classes=NUM_CLASSES)

# Create dataloaders (only the training split is shuffled)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Model parameters: {model.get_num_parameters():,}")
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

## Training Setup

In [None]:
# Prefer a GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device)

# Optimiser and objective. Mask pixels labelled 255 are skipped by the loss.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=255)

# Echo the effective configuration so the run is self-describing.
print(
    f"Device: {device}",
    f"Optimizer: {optimizer.__class__.__name__}",
    f"Learning rate: {optimizer.param_groups[0]['lr']}",
    sep="\n",
)

## Training Loop

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module exposing an ``encoder`` attribute; the model itself is
            called on the encoder's output to produce logits.
        loader: Iterable of dict batches with ``"image"`` and ``"mask"`` tensors.
            It does not need to implement ``__len__``.
        criterion: Callable ``(logits, masks) -> scalar loss tensor``.
        optimizer: Optimiser stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        float: Unweighted mean of the per-batch loss values.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    # Count batches while iterating instead of calling len(loader), so loaders
    # without __len__ (e.g. plain iterators) are supported.
    num_batches = 0

    for batch in tqdm(loader, desc="Training"):
        images = batch["image"].to(device)
        masks = batch["mask"].to(device)

        # Forward pass: encoder features first, then the model on those features
        features = model.encoder(images)
        logits = model(features)

        # Compute loss
        loss = criterion(logits, masks)

        # Backward pass (gradients are cleared first so they do not accumulate)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # Fail with a clear message instead of an opaque ZeroDivisionError.
        raise ValueError("train_epoch() received a loader that yielded no batches")

    return total_loss / num_batches

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` and return ``(mean_loss, pixel_accuracy)``.

    Pixels whose label equals ``criterion.ignore_index`` (when the criterion
    has that attribute, as ``nn.CrossEntropyLoss`` does) are left out of the
    accuracy, so the accuracy covers the same pixels the loss is computed on.

    Args:
        model: Module exposing an ``encoder`` attribute; the model itself is
            called on the encoder's output to produce logits.
        loader: Iterable of dict batches with ``"image"`` and ``"mask"`` tensors.
            It does not need to implement ``__len__``.
        criterion: Callable ``(logits, masks) -> scalar loss tensor``.
        device: Device the batch tensors are moved to.

    Returns:
        tuple[float, float]: Unweighted mean of the per-batch losses, and the
        fraction of non-ignored pixels whose argmax prediction matches the
        mask (``0.0`` when there are no such pixels).

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    # Criteria without an ignore_index attribute count every pixel.
    ignore_index = getattr(criterion, "ignore_index", None)

    total_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validation"):
            images = batch["image"].to(device)
            masks = batch["mask"].to(device)

            # Forward pass: encoder features first, then the model on those features
            features = model.encoder(images)
            logits = model(features)

            # Compute loss
            total_loss += criterion(logits, masks).item()
            num_batches += 1

            # Compute accuracy; argmax over dim 1 picks the highest-scoring class
            hits = logits.argmax(dim=1) == masks
            if ignore_index is None:
                correct += hits.sum().item()
                total += masks.numel()
            else:
                # Ignored pixels can never be "correct"; counting them in the
                # denominator would bias the accuracy downwards.
                valid = masks != ignore_index
                correct += (hits & valid).sum().item()
                total += valid.sum().item()

    if num_batches == 0:
        # Fail with a clear message instead of an opaque ZeroDivisionError.
        raise ValueError("validate() received a loader that yielded no batches")

    avg_loss = total_loss / num_batches
    accuracy = correct / total if total else 0.0

    return avg_loss, accuracy

## Train the Model

In [None]:
num_epochs = 5
# Per-epoch history; the plotting and checkpoint cells read these lists.
train_losses, val_losses, val_accuracies = [], [], []

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # Optimise on the training split and record the result straight away.
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    train_losses += [train_loss]

    # Score the updated weights on the held-out split.
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    val_losses += [val_loss]
    val_accuracies += [val_acc]

    print(
        f"Train Loss: {train_loss:.4f}",
        f"Val Loss: {val_loss:.4f}",
        f"Val Accuracy: {val_acc:.4f}",
        sep="\n",
    )

print("\n✓ Training complete!")

## Plot Training Curves

In [None]:
def _epoch_axis(values):
    """Return 1-based epoch numbers matching the length of ``values``."""
    return range(1, len(values) + 1)


fig, axes = plt.subplots(1, 2, figsize=(15, 5))
loss_ax, acc_ax = axes

# Loss: plotted against 1-based epoch numbers so the curves line up with the
# "Epoch k/N" lines printed during training.
loss_ax.plot(_epoch_axis(train_losses), train_losses, label='Train Loss', marker='o')
loss_ax.plot(_epoch_axis(val_losses), val_losses, label='Val Loss', marker='s')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.legend()
loss_ax.grid(True)

# Accuracy
acc_ax.plot(_epoch_axis(val_accuracies), val_accuracies, label='Val Accuracy', marker='o', color='green')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy')
acc_ax.set_title('Validation Accuracy')
acc_ax.legend()
acc_ax.grid(True)

# Epochs are whole numbers: stop the default locator from choosing
# fractional ticks such as 1.5.
for ax in axes:
    ax.xaxis.get_major_locator().set_params(integer=True)

plt.tight_layout()
plt.show()

## Visualize Predictions

In [None]:
model.eval()
batch = next(iter(val_loader))
images = batch["image"].to(device)
masks = batch["mask"].to(device)

with torch.no_grad():
    features = model.encoder(images)
    logits = model(features)
    predictions = logits.argmax(dim=1)

# Show up to 4 samples; a smaller batch simply yields fewer rows.
num_shown = min(4, images.shape[0])

# The argmax above runs over dim 1, so that dimension holds one score per class.
n_classes = logits.shape[1]

# Give ground truth and prediction the SAME colour scale. Without explicit
# vmin/vmax each imshow call normalises to its own min/max, so one class id
# could be drawn in different colours in the two panels.
# NOTE(review): label values above n_classes - 1 (e.g. an ignore label such as
# 255) are clipped to the top colour of the scale.
mask_style = dict(cmap='tab20', vmin=0, vmax=n_classes - 1)

# squeeze=False keeps `axes` two-dimensional even when only one row is drawn.
fig, axes = plt.subplots(num_shown, 3, figsize=(12, 4 * num_shown), squeeze=False)

for i in range(num_shown):
    # Image: moved to channels-last for imshow.
    # NOTE(review): assumes each image is (C, H, W) with values in a
    # displayable range -- confirm against the dataset.
    axes[i, 0].imshow(images[i].cpu().permute(1, 2, 0).numpy())
    axes[i, 0].set_title("Input Image")
    axes[i, 0].axis('off')

    # Ground truth
    axes[i, 1].imshow(masks[i].cpu().numpy(), **mask_style)
    axes[i, 1].set_title("Ground Truth")
    axes[i, 1].axis('off')

    # Prediction
    axes[i, 2].imshow(predictions[i].cpu().numpy(), **mask_style)
    axes[i, 2].set_title("Prediction")
    axes[i, 2].axis('off')

plt.tight_layout()
plt.show()

## Save Model

In [None]:
# Bundle the weights, the optimiser state, both component configs and the
# final metrics into one dictionary, then write it to disk.
checkpoint = dict(
    model_state_dict=model.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
    encoder_config=model.encoder.get_config(),
    head_config=model.get_config(),
    epoch=num_epochs,
    val_accuracy=val_accuracies[-1],  # accuracy after the last epoch
)

torch.save(checkpoint, 'model_checkpoint.pth')
print("✓ Model saved to model_checkpoint.pth")

## Summary

You've successfully:
1. ✅ Created a model with encoder and head
2. ✅ Set up training and validation datasets
3. ✅ Implemented training and validation loops
4. ✅ Trained the model for multiple epochs
5. ✅ Visualized training curves and predictions
6. ✅ Saved the trained model

This demonstrates the complete training workflow using the VLM evaluation framework!