# Training Vision Captioner on Synthetic Scenes

This notebook trains a ConvNeXt-Tiny + GRU captioner with FSM-constrained decoding on synthetic scene data.

**Hardware**: Designed for Google Colab with an A100 GPU

**Architecture**:
- Encoder: ConvNeXt-Tiny (8 blocks, 256-dim output)
- Decoder: GRU with Bahdanau attention (512-dim hidden)
- Training: AdamW + AMP + OneCycleLR + Scheduled Sampling
- Decoding: FSM-constrained to ensure grammar compliance

## Setup and Installation

First, check the GPU, install dependencies, mount Google Drive (used to save checkpoints), and clone the repository.

In [None]:
# Check GPU
# Lists the attached GPU(s), driver and memory usage. The training config below
# uses batch_size=128 with a comment that an A100 can handle it, so confirm the
# runtime actually has a large GPU before training.
!nvidia-smi

In [None]:
# Install dependencies
!pip install torch torchvision tqdm pillow scikit-learn matplotlib seaborn ipywidgets -q

In [None]:
# Mount Google Drive (optional): only needed so the export step at the end can
# copy the best checkpoint and metrics to Drive. Colab-only import.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Clone the repository from GitHub
import os
import sys

GITHUB_REPO = "https://github.com/jtooates/learning_to_see.git"

if not os.path.exists('/content/learning_to_see'):
    print("üì• Cloning repository...")
    !git clone {GITHUB_REPO}
    print("‚úÖ Repository cloned!")
else:
    print("‚úÖ Repository already exists")

# Change to project directory
%cd /content/learning_to_see

# Add to Python path
sys.path.insert(0, '/content/learning_to_see')

print(f"\nüìÇ Current directory: {os.getcwd()}")

## Generate Training Data

Generate synthetic scenes with images and captions.

In [None]:
import random

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm.notebook import tqdm

# Seed every RNG used in this notebook: later cells draw random samples with
# np.random.choice / np.random.randint, and model weights are initialised by
# torch. Without a seed, re-running the notebook shows different samples and
# trains from a different initialisation each time.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Generate training data
# This creates 6000 total samples, split 80/10/10 for train/val/test
!python -m data.gen \
    --out_dir ./data/scenes \
    --n 6000 \
    --split_strategy random \
    --seed 42

print("\n‚úÖ Data generation complete!")

## Visualize Sample Data

Let's look at some examples from the dataset.

In [None]:
from data.dataset import SceneDataset
from dsl.tokens import Vocab
import os

# Fail fast with an actionable message if data generation was skipped or failed.
# The error text names the section instead of a cell number, which shifts
# whenever cells are added or removed.
data_path = './data/scenes'
if not os.path.exists(data_path):
    print(f"⚠️  ERROR: Data directory not found at {data_path}")
    print("   Please run the data generation cell above first!")
    raise FileNotFoundError(
        f"Data directory {data_path} does not exist. "
        "Run the cell in the 'Generate Training Data' section to generate data."
    )

if not os.path.exists(f'{data_path}/splits.json'):
    print("⚠️  ERROR: Data not fully generated")
    print("   The data generation command may have failed.")
    print("   Please check the output of the 'Generate Training Data' cell for errors.")
    raise FileNotFoundError(f"splits.json not found in {data_path}. Data generation incomplete.")

# Load vocab and dataset
vocab = Vocab()
train_dataset = SceneDataset(
    data_dir=data_path,
    split='train',
    vocab=vocab
)

print(f"✓ Training samples: {len(train_dataset)}")
print(f"✓ Vocabulary size: {vocab.vocab_size}")

In [None]:
# Visualize samples
def show_samples(dataset, num_samples=8, cols=4):
    """Plot a grid of random dataset images, each titled with its caption.

    Args:
        dataset: Indexable dataset whose items are dicts with an 'image'
            tensor laid out channels-first (it is permuted to H, W, C for
            display) and a 'text' caption.
        num_samples: Number of distinct random samples to show. Capped at
            ``len(dataset)`` because sampling is without replacement.
        cols: Number of grid columns.
    """
    num_samples = min(num_samples, len(dataset))
    if num_samples == 0:
        print("Dataset is empty; nothing to show.")
        return

    rows = (num_samples + cols - 1) // cols
    # squeeze=False always returns a 2-D array of Axes, so flatten() also works
    # for a 1x1 grid (otherwise plt.subplots returns a bare Axes object).
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3), squeeze=False)
    axes = axes.flatten()

    indices = np.random.choice(len(dataset), num_samples, replace=False)

    for idx, ax in zip(indices, axes):
        sample = dataset[idx]
        ax.imshow(sample['image'].permute(1, 2, 0).numpy())
        ax.set_title(sample['text'], fontsize=10, wrap=True)
        ax.axis('off')

    # Blank out the unused cells of the last row.
    for ax in axes[num_samples:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()


show_samples(train_dataset, num_samples=8)

## Visualize Data Augmentations

Show how augmentations transform the images.

In [None]:
from captioner.augmentations import get_train_augmentation

# Use the first training example as the reference image for every panel.
sample = train_dataset[0]
image, text = sample['image'], sample['text']

# Strong training-time augmentation at 64px; every call draws a fresh random view.
aug = get_train_augmentation(image_size=64, strong=True)

fig, axes = plt.subplots(2, 4, figsize=(12, 6))
axes = axes.flatten()

for panel, ax in enumerate(axes):
    if panel == 0:
        # First panel: the untouched original.
        ax.imshow(image.permute(1, 2, 0).numpy())
        ax.set_title('Original', fontweight='bold')
    else:
        # Remaining panels: independent augmented views of the same image.
        ax.imshow(aug(image).permute(1, 2, 0).numpy())
        ax.set_title(f'Augmented {panel}')
    ax.axis('off')

fig.suptitle(f'Caption: "{text}"', fontsize=12, y=1.02)
plt.tight_layout()
plt.show()

## Build Model

Create the captioner model and inspect its architecture.

In [None]:
from captioner import build_captioner

# Captioner hyperparameters. The training-setup section rebuilds the model with
# these same values, so keep the two in sync when tuning.
model_hparams = dict(
    vocab_size=vocab.vocab_size,
    embed_dim=256,
    hidden_dim=512,
    encoder_dim=256,
    attention_dim=256,
    dropout=0.3,
    drop_path_rate=0.1,
    label_smoothing=0.1,
)
model = build_captioner(**model_hparams)

# Count parameters
def count_parameters(model):
    """Return ``(total, trainable)`` element counts over ``model.parameters()``.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        n_elems = param.numel()
        total += n_elems
        if param.requires_grad:
            trainable += n_elems
    return total, trainable

total_params, trainable_params = count_parameters(model)
# Rough in-memory size assuming 4 bytes per float32 parameter.
fp32_megabytes = total_params * 4 / 1024 / 1024

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"\nModel size: ~{fp32_megabytes:.1f} MB (FP32)")

In [None]:
# Print model architecture: a header line followed by the module tree.
print("\n=== MODEL ARCHITECTURE ===", model, sep="\n")

## Test Forward Pass

Verify the model works with a small batch.

In [None]:
# Test forward pass
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

# Get a tiny batch (batch_size=4) just to smoke-test shapes and the loss.
# These loaders are replaced by the full-size ones in the training-setup section.
from data.dataset import create_dataloaders

train_loader, val_loader, test_loader = create_dataloaders(
    data_dir='./data/scenes',
    vocab=vocab,
    batch_size=4,
    num_workers=2
)

batch = next(iter(train_loader))
images = batch['image'].to(device)
targets = batch['input_ids'].to(device)

print(f"Images shape: {images.shape}")
print(f"Targets shape: {targets.shape}")

# Forward pass in eval mode without gradients; teacher_forcing_ratio=1.0
# (presumably full teacher forcing — the model consumes `targets` as inputs).
model.eval()
with torch.no_grad():
    logits, loss = model(images, targets, teacher_forcing_ratio=1.0)

print(f"\nLogits shape: {logits.shape}")
print(f"Loss: {loss.item():.4f}")
# A NaN/inf loss here points at a wiring problem that training would only hide.
assert torch.isfinite(loss).all(), f"Non-finite loss from forward pass: {loss.item()}"
print("\n✓ Forward pass successful!")

## Training Setup

Configure training parameters and create trainer.

In [None]:
from captioner.train import CaptionerTrainer

# Hyperparameters for the full training run; batch_size=128 is sized for an A100.
config = dict(
    lr=3e-4,
    weight_decay=0.01,
    max_epochs=50,
    warmup_epochs=5,
    batch_size=128,
    use_amp=True,
    checkpoint_dir='./checkpoints',
    log_interval=20,
)

print("Training Configuration:")
print("\n".join(f"  {key}: {value}" for key, value in config.items()))

In [None]:
# Build the real loaders at the training batch size (replacing the smoke-test ones).
train_loader, val_loader, test_loader = create_dataloaders(
    data_dir='./data/scenes',
    vocab=vocab,
    batch_size=config['batch_size'],
    num_workers=4,
    pin_memory=True,
)

for split_label, split_loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"{split_label} batches: {len(split_loader)}")
print(f"Total training steps: {len(train_loader) * config['max_epochs']}")

In [None]:
# Rebuild model for training: a freshly initialised model, so the smoke-test
# forward pass above leaves no trace. These hyperparameters must match the ones
# used in the "Build Model" section.
model = build_captioner(
    vocab_size=vocab.vocab_size,
    embed_dim=256,
    hidden_dim=512,
    encoder_dim=256,
    attention_dim=256,
    dropout=0.3,
    drop_path_rate=0.1,
    label_smoothing=0.1
)

# Create trainer from the config dict defined in the training-setup section
trainer = CaptionerTrainer(
    model=model,
    vocab=vocab,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    lr=config['lr'],
    weight_decay=config['weight_decay'],
    max_epochs=config['max_epochs'],
    warmup_epochs=config['warmup_epochs'],
    use_amp=config['use_amp'],
    checkpoint_dir=config['checkpoint_dir'],
    log_interval=config['log_interval']
)

print("✓ Trainer created successfully!")

## Training Loop

Train the model and track metrics.

In [None]:
# Train the model
# This will take ~30-60 minutes on A100 for 50 epochs
# Runs the trainer built in the previous cell. The following cells expect the
# trainer to have written checkpoint_epoch_<N>_metrics.json files and
# best_model.pt into config['checkpoint_dir'].
trainer.train()

## Visualize Training Progress

Plot the per-epoch validation metrics saved alongside each checkpoint.

In [None]:
import json
from glob import glob

# Load the per-epoch validation metrics written next to each checkpoint
checkpoint_dir = Path(config['checkpoint_dir'])


def _epoch_from_metrics_file(path):
    """Return N from a 'checkpoint_epoch_<N>_metrics.json' path."""
    # stem 'checkpoint_epoch_<N>_metrics' -> ['checkpoint', 'epoch', '<N>', 'metrics']
    return int(path.stem.split('_')[2])


# Sort numerically by epoch. A plain sorted() on the paths is lexicographic: it
# puts epoch_10 before epoch_2, so the curves would not follow epoch order and
# the "final" epoch printed below would be epoch 9 instead of the last one.
metric_files = sorted(
    checkpoint_dir.glob('checkpoint_epoch_*_metrics.json'),
    key=_epoch_from_metrics_file,
)
if not metric_files:
    print(f"⚠️  No checkpoint_epoch_*_metrics.json files found in {checkpoint_dir}")

epochs = []
val_losses = []
exact_matches = []
token_accuracies = []
color_f1s = []
shape_f1s = []

for f in metric_files:
    with open(f) as fp:
        metrics = json.load(fp)

    epochs.append(_epoch_from_metrics_file(f))
    val_losses.append(metrics['val_loss'])
    exact_matches.append(metrics['exact_match'])
    token_accuracies.append(metrics['token_accuracy'])
    color_f1s.append(metrics['color_f1'])
    shape_f1s.append(metrics['shape_f1'])


def _style_axis(ax, ylabel, title, unit_range=True):
    """Apply the epoch x-label, y-label, title, optional [0, 1] y-limits and grid."""
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if unit_range:
        ax.set_ylim([0, 1])
    ax.grid(True, alpha=0.3)


# Plot metrics
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
line_style = dict(linewidth=2, markersize=4)

# Loss (unbounded, so no [0, 1] limits)
axes[0, 0].plot(epochs, val_losses, marker='o', **line_style)
_style_axis(axes[0, 0], 'Validation Loss', 'Validation Loss', unit_range=False)

# Exact Match
axes[0, 1].plot(epochs, exact_matches, marker='o', color='green', **line_style)
_style_axis(axes[0, 1], 'Exact Match Accuracy', 'Exact Match Accuracy')

# Token Accuracy
axes[1, 0].plot(epochs, token_accuracies, marker='o', color='blue', **line_style)
_style_axis(axes[1, 0], 'Token Accuracy', 'Token Accuracy')

# Per-Attribute F1
axes[1, 1].plot(epochs, color_f1s, marker='o', label='Color F1', **line_style)
axes[1, 1].plot(epochs, shape_f1s, marker='s', label='Shape F1', **line_style)
_style_axis(axes[1, 1], 'F1 Score', 'Per-Attribute F1 Scores')
axes[1, 1].legend()

plt.tight_layout()
plt.show()

# Print final metrics (last entry is the highest epoch thanks to the numeric sort)
if epochs:
    print(f"\nFinal Metrics (Epoch {epochs[-1]}):")
    print(f"  Val Loss: {val_losses[-1]:.4f}")
    print(f"  Exact Match: {exact_matches[-1]:.4f}")
    print(f"  Token Accuracy: {token_accuracies[-1]:.4f}")
    print(f"  Color F1: {color_f1s[-1]:.4f}")
    print(f"  Shape F1: {shape_f1s[-1]:.4f}")

## Test Model Predictions

Load the best model and visualize predictions.

In [None]:
# Load best model
best_checkpoint = checkpoint_dir / 'best_model.pt'

if best_checkpoint.exists():
    # The checkpoint is produced by this notebook's own trainer (trusted), and it
    # holds metadata next to the weights (e.g. 'epoch', 'best_exact_match').
    # Passing weights_only=False explicitly keeps loading working on torch >= 2.6,
    # where the default flipped to True and rejects non-tensor objects such as
    # numpy scalars. Never do this for checkpoints from untrusted sources.
    checkpoint = torch.load(best_checkpoint, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"✓ Loaded best model from epoch {checkpoint['epoch']}")
    print(f"  Best exact match: {checkpoint['best_exact_match']:.4f}")
else:
    print("⚠️  Best model checkpoint not found; using the weights currently held by `model`")

# Put the model on the notebook device and in inference mode whichever branch ran,
# so the evaluation cells below behave the same either way.
model = model.to(device)
model.eval()

In [None]:
from captioner.decode import greedy_decode

def visualize_predictions(model, dataset, num_samples=8, use_constraints=True, max_length=32):
    """Decode random samples from ``dataset`` and plot predictions next to ground truth.

    Titles are green when the predicted caption exactly matches the ground
    truth and red otherwise; the exact-match rate over the shown samples is
    printed afterwards. Uses the notebook-level ``vocab`` and ``device``.

    Args:
        model: Captioner to decode with (switched to eval mode).
        dataset: Indexable dataset whose items are dicts with 'image' and 'text'.
        num_samples: Number of distinct random samples (capped at len(dataset)).
        use_constraints: Whether greedy decoding applies the FSM constraints.
        max_length: Maximum number of decoding steps.
    """
    model.eval()

    # Sampling is without replacement, so never ask for more items than exist.
    num_samples = min(num_samples, len(dataset))
    if num_samples == 0:
        print("Dataset is empty; nothing to visualize.")
        return

    indices = np.random.choice(len(dataset), num_samples, replace=False)
    samples = [dataset[idx] for idx in indices]
    images = [sample['image'] for sample in samples]
    ground_truths = [sample['text'] for sample in samples]

    # Decode all selected images in a single batch.
    images_tensor = torch.stack(images).to(device)
    with torch.no_grad():
        _, predictions = greedy_decode(
            model=model,
            images=images_tensor,
            vocab=vocab,
            max_length=max_length,
            use_constraints=use_constraints
        )

    cols = 4
    rows = (num_samples + cols - 1) // cols
    # squeeze=False keeps `axes` 2-D so flatten() works for any grid size.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), squeeze=False)
    axes = axes.flatten()

    for ax, img, gt, pred in zip(axes, images, ground_truths, predictions):
        ax.imshow(img.permute(1, 2, 0).cpu().numpy())
        color = 'green' if gt == pred else 'red'
        ax.set_title(f"GT: {gt}\nPred: {pred}", fontsize=9, color=color, fontweight='bold')
        ax.axis('off')

    # Hide extra axes in the last row
    for ax in axes[num_samples:]:
        ax.axis('off')

    constraint_str = "WITH" if use_constraints else "WITHOUT"
    fig.suptitle(f"Model Predictions ({constraint_str} FSM Constraints)",
                 fontsize=14, fontweight='bold', y=1.00)
    plt.tight_layout()
    plt.show()

    # Exact-match rate over the displayed samples
    correct = sum(1 for gt, pred in zip(ground_truths, predictions) if gt == pred)
    accuracy = correct / len(ground_truths)
    print(f"\nAccuracy on these samples: {correct}/{len(ground_truths)} = {accuracy:.2%}")

# Visualize with constraints on validation images: the model was trained on the
# training images, so accuracy measured on them would overstate generalization.
print("=== Predictions WITH FSM Constraints ===")
visualize_predictions(model, val_loader.dataset, num_samples=8, use_constraints=True)

In [None]:
# Compare with and without constraints, again on held-out validation images
# (training images were seen during training and would overstate accuracy).
# Samples are drawn at random on every call, so the images differ from any
# previous visualization.
print("\n=== Predictions WITHOUT FSM Constraints ===")
visualize_predictions(model, val_loader.dataset, num_samples=8, use_constraints=False)

## Test on Validation Set

Evaluate on the full validation set.

In [None]:
from captioner.metrics import evaluate_model

# Full pass over the validation loader with FSM-constrained decoding.
print("Evaluating on validation set...")
decode_options = dict(use_constraints=True, max_length=32)
metrics = evaluate_model(
    model=model,
    dataloader=val_loader,
    vocab=vocab,
    device=device,
    **decode_options,
)

# Print results
metrics.print_summary()

## Visualize Attention Weights

Show what the model is attending to during generation.

In [None]:
def visualize_attention(model, image, vocab, device, max_length=32):
    """Greedily decode one image and plot per-token attention over encoder positions.

    Decoding is plain argmax (no FSM constraints), driven step by step through
    ``model.encode``, ``model.init_decoder_state`` and ``model.decode_step``.

    Args:
        model: Captioner exposing encode / init_decoder_state / decode_step.
        image: Single image tensor laid out channels-first (it is unsqueezed to
            a batch of one for the model and permuted to H, W, C for display).
        vocab: Vocabulary providing bos_id, eos_id, decode() and id_to_token.
        device: Device to run the model on.
        max_length: Maximum number of decoding steps.
    """
    model.eval()

    with torch.no_grad():
        # Encode
        image_tensor = image.unsqueeze(0).to(device)
        grid_tokens, pooled = model.encode(image_tensor)

        # Initialize decoder
        hidden = model.init_decoder_state(pooled)

        # Generate tokens and collect attention
        input_token = torch.tensor([vocab.bos_id], device=device)
        tokens = [vocab.bos_id]
        attention_weights = []

        for _ in range(max_length):
            logits, hidden, attn = model.decode_step(
                input_token=input_token,
                hidden=hidden,
                encoder_out=grid_tokens
            )

            next_token = logits.argmax(dim=1).item()
            tokens.append(next_token)
            attention_weights.append(attn.squeeze(0).cpu().numpy())

            if next_token == vocab.eos_id:
                break

            input_token = torch.tensor([next_token], device=device)

    caption = vocab.decode(tokens)

    # attention_weights[i] is the attention used to predict tokens[i + 1].
    # Drop the final step only if it actually produced EOS: when decoding stops
    # at max_length without EOS, the last generated token is a real word.
    generated = tokens[1:]
    step_attention = attention_weights
    if generated and generated[-1] == vocab.eos_id:
        generated = generated[:-1]
        step_attention = step_attention[:-1]

    if not generated:
        # Nothing to plot: an empty list would become a 1-D array that imshow rejects.
        print(f'Model produced no tokens before EOS (caption: "{caption}"); nothing to plot.')
        return

    token_strs = [vocab.id_to_token.get(t, f'<{t}>') for t in generated]
    attention_matrix = np.array(step_attention)

    # Describe the x-axis from the actual attention width instead of assuming a grid size.
    num_positions = attention_matrix.shape[1]
    side = int(round(np.sqrt(num_positions)))
    if side * side == num_positions:
        position_label = f'Spatial Position ({side}×{side} grid, flattened)'
    else:
        position_label = f'Spatial Position ({num_positions} positions, flattened)'

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Show image
    axes[0].imshow(image.permute(1, 2, 0).cpu().numpy())
    axes[0].set_title('Input Image', fontsize=12, fontweight='bold')
    axes[0].axis('off')

    # Show attention heatmap: one row per generated token
    im = axes[1].imshow(attention_matrix, cmap='viridis', aspect='auto')
    axes[1].set_yticks(range(len(token_strs)))
    axes[1].set_yticklabels(token_strs, fontsize=10)
    axes[1].set_xlabel(position_label, fontsize=10)
    axes[1].set_ylabel('Generated Token', fontsize=10)
    axes[1].set_title('Attention Weights', fontsize=12, fontweight='bold')
    plt.colorbar(im, ax=axes[1], label='Attention Weight')

    fig.suptitle(f'Caption: "{caption}"', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Visualize attention for a few random samples
for _ in range(3):
    idx = np.random.randint(len(train_dataset))
    sample = train_dataset[idx]
    visualize_attention(model, sample['image'], vocab, device)

## Error Analysis

Find and visualize common failure cases.

In [None]:
def analyze_errors(model, dataset, num_samples=100, batch_size=32, max_show=8):
    """Decode a random subset of ``dataset`` and display exact-match failures.

    Uses FSM-constrained greedy decoding with the notebook-level ``vocab`` and
    ``device``.

    Args:
        model: Captioner to evaluate (switched to eval mode).
        dataset: Indexable dataset whose items are dicts with 'image' and 'text'.
        num_samples: Number of distinct random samples (capped at len(dataset)).
        batch_size: Number of images decoded per greedy_decode call.
        max_show: Maximum number of error cases to plot.

    Returns:
        List of dicts with keys 'image', 'gt', 'pred' and 'idx', one per
        sample whose predicted caption differs from the ground truth.
    """
    model.eval()

    num_samples = min(num_samples, len(dataset))
    if num_samples == 0:
        print("Dataset is empty; nothing to analyze.")
        return []

    indices = np.random.choice(len(dataset), num_samples, replace=False)
    errors = []

    # Decode in mini-batches rather than one image per call; greedy_decode
    # already accepts a stacked batch of images.
    for start in tqdm(range(0, num_samples, batch_size), desc="Analyzing errors"):
        batch_indices = indices[start:start + batch_size]
        samples = [dataset[idx] for idx in batch_indices]
        images = torch.stack([sample['image'] for sample in samples]).to(device)

        with torch.no_grad():
            _, preds = greedy_decode(model, images, vocab, use_constraints=True)

        for idx, sample, pred in zip(batch_indices, samples, preds):
            if pred != sample['text']:
                errors.append({
                    'image': sample['image'],
                    'gt': sample['text'],
                    'pred': pred,
                    'idx': idx
                })

    print(f"\nFound {len(errors)} errors out of {num_samples} samples")
    print(f"Accuracy: {(num_samples - len(errors)) / num_samples:.2%}")

    # Show some errors
    if errors:
        num_show = min(max_show, len(errors))
        cols = 4
        rows = (num_show + cols - 1) // cols
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), squeeze=False)
        axes = axes.flatten()

        # Hide every panel first so unused ones do not render as empty frames.
        for ax in axes:
            ax.axis('off')

        for ax, error in zip(axes, errors[:num_show]):
            ax.imshow(error['image'].permute(1, 2, 0).cpu().numpy())
            ax.set_title(f"GT: {error['gt']}\nPred: {error['pred']}", fontsize=9, color='red')

        # Add the title before tight_layout so the layout accounts for it.
        fig.suptitle('Error Cases', fontsize=14, fontweight='bold', y=1.00)
        plt.tight_layout()
        plt.show()

    return errors

# Analyze errors on validation images; on training images the model has
# already seen the answers, so few real failure modes would show up.
errors = analyze_errors(model, val_loader.dataset, num_samples=100)

## Test Compositional Generalization

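Test on holdout splits to evaluate compositional understanding. Caveat: the model evaluated here was trained on the random split of `./data/scenes`, not on the train partitions of these holdout datasets. The scores therefore reflect compositional generalization only to the extent that the held-out combinations are absent from that training data.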

In [None]:
# Generate holdout test sets
print("Generating compositional holdout test sets...")

# Color-shape holdout
!python -m data.gen \
    --out_dir ./data/scenes_color_holdout \
    --n 500 \
    --split_strategy color_shape \
    --seed 42

# Relation holdout
!python -m data.gen \
    --out_dir ./data/scenes_relation_holdout \
    --n 500 \
    --split_strategy relation \
    --seed 42

print("\n‚úÖ Holdout datasets generated!")

In [None]:
# Evaluate on holdout splits
from data.dataset import SceneDataset
from torch.utils.data import DataLoader

def evaluate_on_split(model, data_dir, split_name, batch_size=64, max_length=32):
    """Evaluate ``model`` on the 'test' split of the dataset stored in ``data_dir``.

    Uses FSM-constrained decoding with the same max_length (32) as the
    validation-set evaluation, prints a short summary labelled ``split_name``,
    and returns the dict produced by ``metrics.compute()``. Relies on the
    notebook-level ``vocab``, ``device`` and ``evaluate_model``.
    """
    dataset = SceneDataset(data_dir, 'test', vocab)
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=2)

    # Keyword arguments and max_length mirror the validation-set evaluation so
    # the IID and holdout numbers are computed the same way.
    metrics = evaluate_model(
        model=model,
        dataloader=loader,
        vocab=vocab,
        device=device,
        use_constraints=True,
        max_length=max_length
    )
    results = metrics.compute()

    print(f"\n=== {split_name} ===")
    print(f"  Exact Match: {results['exact_match']:.4f}")
    print(f"  Token Accuracy: {results['token_accuracy']:.4f}")
    print(f"  Color F1: {results['color_f1']:.4f}")
    print(f"  Shape F1: {results['shape_f1']:.4f}")

    return results

# Evaluate on different splits
results_iid = evaluate_on_split(model, './data/scenes', 'IID (Random Split)')
results_color = evaluate_on_split(model, './data/scenes_color_holdout', 'Color-Shape Holdout')
results_relation = evaluate_on_split(model, './data/scenes_relation_holdout', 'Relation Holdout')

# Compare results. Distinct names keep the per-epoch `exact_matches` list from
# the training-curve cell intact.
split_labels = ['IID', 'Color Holdout', 'Relation Holdout']
split_results = [results_iid, results_color, results_relation]
split_exact_matches = [r['exact_match'] for r in split_results]
split_token_accs = [r['token_accuracy'] for r in split_results]

x = np.arange(len(split_labels))
width = 0.35

fig, ax = plt.subplots(1, 1, figsize=(10, 6))
ax.bar(x - width / 2, split_exact_matches, width, label='Exact Match', alpha=0.8)
ax.bar(x + width / 2, split_token_accs, width, label='Token Accuracy', alpha=0.8)

ax.set_ylabel('Accuracy')
ax.set_title('Compositional Generalization: Performance on Different Splits')
ax.set_xticks(x)
ax.set_xticklabels(split_labels)
ax.legend()
ax.set_ylim([0, 1])
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

## Export Model

Save the trained model for later use.

In [None]:
# Save the best checkpoint and final metrics to Google Drive (when mounted)
import json
import shutil
from pathlib import Path

DRIVE_ROOT = Path('/content/drive')

if DRIVE_ROOT.is_mount():
    drive_checkpoint_dir = DRIVE_ROOT / 'MyDrive' / 'learning_to_see_checkpoints'
else:
    # Creating folders under an unmounted /content/drive would only write to the
    # runtime's local disk and can make a later drive.mount('/content/drive')
    # fail because the mountpoint is no longer empty. Export locally instead.
    drive_checkpoint_dir = Path(checkpoint_dir) / 'export'
    print(f"⚠️  Google Drive is not mounted; exporting to {drive_checkpoint_dir} instead")
drive_checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Copy best model
best_model_path = checkpoint_dir / 'best_model.pt'
if best_model_path.exists():
    shutil.copy(best_model_path, drive_checkpoint_dir)
    print(f"✓ Saved best model to: {drive_checkpoint_dir}/best_model.pt")
else:
    print(f"⚠️  {best_model_path} not found; skipping model export")


def _to_jsonable(value):
    """json.dump fallback: numpy/torch values via .tolist(), anything else via str()."""
    return value.tolist() if hasattr(value, 'tolist') else str(value)


# Save final metrics
final_metrics = {
    'iid': results_iid,
    'color_holdout': results_color,
    'relation_holdout': results_relation
}

with open(drive_checkpoint_dir / 'final_metrics.json', 'w') as f:
    # Metric values may be numpy/torch scalars, which json cannot serialise natively.
    json.dump(final_metrics, f, indent=2, default=_to_jsonable)

print(f"✓ Saved final metrics to {drive_checkpoint_dir}")
print("\n🎉 Training complete!")