# Text-to-Image Distillation Training (Colab)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jtooates/learning_to_see/blob/main/train_distill_colab.ipynb)

This notebook trains a neural network to generate 64×64 images from text captions by distilling knowledge from a procedural renderer.

**What this tests:**
- Text encoder: Converts DSL captions → 512-d embeddings
- Image decoder: Generates 64×64 RGB images from embeddings
- Training pipeline: Data generation, training loop, evaluation
- Visualization: Side-by-side comparison of renderer (teacher) vs model (student)

**Expected results:**
- Model learns to generate images matching the renderer
- PSNR ≥ 24 dB, SSIM ≥ 0.92 after training
- Visual quality: Clear colors, shapes, and spatial relationships
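
For reference, PSNR is derived from the mean squared error between two images. The sketch below assumes both images are scaled to [0, 1] (`data_range=1.0`). The `psnr` helper is hypothetical, and the repository's evaluation code may use a different data range or implementation.

```python
import numpy as np

def psnr(reference, estimate, data_range=1.0):
    """Peak signal-to-noise ratio in dB (hypothetical helper, not the repo's metric)."""
    reference = np.asarray(reference, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)
    mse = np.mean((reference - estimate) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return 10.0 * np.log10(data_range ** 2 / mse)

# A uniform error of 0.1 on a [0, 1] image gives MSE = 0.01, i.e. about 20 dB
teacher = np.zeros((64, 64, 3))
student = np.full((64, 64, 3), 0.1)
print(f"{psnr(teacher, student):.1f} dB")
```

Under the same [0, 1] assumption, the 24 dB target corresponds to a root-mean-square pixel error of roughly 0.063.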

## Setup: Install Dependencies and Clone Repository

In [None]:
# Install required packages
!pip install -q torch torchvision tqdm scikit-learn matplotlib seaborn pillow

In [None]:
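# Clone the repository, or pull the latest changes if it is already cloned
import os
import sys

# Re-running this cell from inside the repo must not clone a nested copy
in_repo = os.path.basename(os.getcwd()) == 'learning_to_see'

if not in_repo and not os.path.exists('learning_to_see'):
    !git clone https://github.com/jtooates/learning_to_see.git
    %cd learning_to_see
else:
    if not in_repo:
        %cd learning_to_see
    !git pull  # Get latest changes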
    
    # If modules are already imported, restart runtime to pick up changes
    if 'distill_c1.trainer' in sys.modules:
        print("\n⚠️  Code updated! Please restart the runtime to load the latest changes.")
        print("   Go to: Runtime → Restart runtime")
        print("   Then run all cells again from the beginning.")

## Mount Google Drive (Optional - for persistent storage)

This allows you to save checkpoints and data to your Google Drive so they persist across Colab sessions.

In [None]:
from google.colab import drive
import os

# Mount Google Drive
drive.mount('/content/drive')

# Create directories in Google Drive for persistent storage
DRIVE_DIR = '/content/drive/MyDrive/learning_to_see_colab'
os.makedirs(DRIVE_DIR, exist_ok=True)
os.makedirs(f'{DRIVE_DIR}/data', exist_ok=True)
os.makedirs(f'{DRIVE_DIR}/runs', exist_ok=True)

print(f"Data and checkpoints will be saved to: {DRIVE_DIR}")

## Configuration

In [None]:
# Training configuration
CONFIG = {
    # Data
    'n_samples': 3000,  # Number of training samples (use 1000 for quick test, 6000+ for full training)
    'data_seed': 42,
    
    # Training
    'steps': 20000,  # Training steps (use 5000 for quick test, 100000 for full training)
    'batch_size': 128,  # Reduce if OOM (try 64 or 32)
    'lr': 1e-3,
    
    # Loss weights (balanced for plausible images, not perfect reconstruction)
    'pixel_weight': 0.1,   # Low weight on pixel matching (L1+L2)
    'tv_weight': 1e-4,     # Smoothness regularization
    'perc_weight': 1.0,    # High weight on perceptual/structure matching
    'use_perc': True,      # Enable perceptual loss
    
    'use_amp': True,  # Automatic mixed precision
    'seed': 1337,
    
    # Evaluation
    'eval_every': 1000,  # Evaluate every N steps
    
    # Paths (will use Google Drive if mounted, otherwise local)
    # DRIVE_DIR only exists if the optional Drive cell above was run
    'use_drive': 'DRIVE_DIR' in globals() and os.path.exists(DRIVE_DIR),
}

# Set paths based on whether Drive is mounted
if CONFIG['use_drive']:
    CONFIG['data_dir'] = f"{DRIVE_DIR}/data/scenes"
    CONFIG['save_dir'] = f"{DRIVE_DIR}/runs/distill_test"
else:
    CONFIG['data_dir'] = "/content/data/scenes"
    CONFIG['save_dir'] = "/content/runs/distill_test"

print("Configuration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

## Step 1: Generate Synthetic Data

Generate scene graphs, render images, and create text captions.
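
As a rough illustration of the caption side of this step, the sketch below turns a toy "count" scene into a caption in the same style as the test captions used later in this notebook. Everything here (`sample_count_scene`, `caption_for`, the dictionary layout) is hypothetical; the actual `data.gen` module defines its own scene graph format and grammar.

```python
import random

# Vocabulary taken from the captions and checklist used later in this notebook
COLORS = ["red", "blue", "green", "yellow"]
SHAPES = ["ball", "cube", "block"]
NUMBER_WORDS = {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"}

def sample_count_scene(rng):
    """Sample a toy scene: N objects sharing one color and one shape."""
    return {
        "count": rng.randint(1, 5),
        "color": rng.choice(COLORS),
        "shape": rng.choice(SHAPES),
    }

def caption_for(scene):
    """Render a toy scene as an English caption with number agreement."""
    n = scene["count"]
    verb = "is" if n == 1 else "are"
    noun = scene["shape"] + ("" if n == 1 else "s")
    return f"There {verb} {NUMBER_WORDS[n]} {scene['color']} {noun}."

rng = random.Random(42)  # fixed seed so the samples are reproducible
for _ in range(3):
    print(caption_for(sample_count_scene(rng)))
```

In the real pipeline the same scene description also drives the renderer, so each caption is paired with an image of the scene it describes.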

In [None]:
import sys
import torch
import random
import numpy as np
from pathlib import Path

# Set random seeds
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(CONFIG['data_seed'])

# Check if data already exists
data_path = Path(CONFIG['data_dir'])
if data_path.exists() and (data_path / 'train_shard_0.pt').exists():
    print(f"✓ Data already exists at {CONFIG['data_dir']}")
    print("  Skipping data generation. Delete the directory to regenerate.")
else:
    print(f"Generating {CONFIG['n_samples']} samples...")
    print(f"Output directory: {CONFIG['data_dir']}")
    
    # Run data generation
    !python -m data.gen \
        --out_dir {CONFIG['data_dir']} \
        --n {CONFIG['n_samples']} \
        --split_strategy random \
        --seed {CONFIG['data_seed']}
    
    print("\n✓ Data generation complete!")

## Step 2: Visualize Sample Data

Let's look at some examples from the generated dataset to verify quality.

In [None]:
import matplotlib.pyplot as plt
from data.dataset import SceneDataset
from dsl.tokens import Vocab

# Load dataset
vocab = Vocab()
dataset = SceneDataset(
    data_dir=CONFIG['data_dir'],
    split='train',
    vocab=vocab,
)

print(f"Dataset size: {len(dataset)} samples")
print(f"Vocabulary size: {len(vocab)} tokens")

# Visualize first 8 samples
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

for i in range(8):
    sample = dataset[i]
    image = sample['image']  # Shape: (3, 64, 64), range [-1, 1]
    text = sample['text']
    
    # Convert to displayable format
    img_display = (image.permute(1, 2, 0).numpy() + 1) / 2  # [-1,1] -> [0,1]
    img_display = np.clip(img_display, 0, 1)
    
    axes[i].imshow(img_display)
    axes[i].set_title(text, fontsize=10)
    axes[i].axis('off')

plt.tight_layout()
plt.savefig('/content/sample_data.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n✓ Sample visualization saved to /content/sample_data.png")

## Step 3: Initialize Models

Build the text encoder and image decoder.

In [None]:
from distill_c1.text_encoder import build_text_encoder
from distill_c1.decoder import build_decoder
from dsl.tokens import Vocab

# Initialize vocabulary
vocab = Vocab()
print(f"Vocabulary: {len(vocab)} tokens")
print(f"Special tokens: BOS={vocab.bos_id}, EOS={vocab.eos_id}, PAD={vocab.pad_id}")

# Build models
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"\nDevice: {device}")

text_encoder = build_text_encoder(vocab_size=len(vocab), pad_id=vocab.pad_id).to(device)
decoder = build_decoder().to(device)

# Count parameters
def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

encoder_params = count_params(text_encoder)
decoder_params = count_params(decoder)
total_params = encoder_params + decoder_params

print(f"\nModel parameters:")
print(f"  Text encoder: {encoder_params:,} ({encoder_params/1e6:.2f}M)")
print(f"  Image decoder: {decoder_params:,} ({decoder_params/1e6:.2f}M)")
print(f"  Total: {total_params:,} ({total_params/1e6:.2f}M)")

# Test forward pass
print("\nTesting forward pass...")
test_tokens = torch.randint(1, len(vocab), (2, 20)).to(device)
with torch.no_grad():
    embeddings = text_encoder(test_tokens, pad_id=vocab.pad_id)
    images = decoder(embeddings)

print(f"  Input shape: {test_tokens.shape}")
print(f"  Embedding shape: {embeddings.shape}")
print(f"  Output shape: {images.shape}")
print(f"  Output range: [{images.min():.3f}, {images.max():.3f}]")
print("\n✓ Models initialized successfully!")

## Step 4: Train the Model

Train the text-to-image model with distillation from the renderer.
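
The loss weights in `CONFIG` suggest an objective that combines a pixel term (L1 + L2), a total-variation smoothness term, and a perceptual term. The NumPy sketch below shows one way such a weighted sum could be formed. It is an assumption about the general shape of the loss, not the actual `DistillTrainer` code; the function names are hypothetical, and the perceptual term is passed in as a plain number because its definition lives in the repository.

```python
import numpy as np

def pixel_loss(student, teacher):
    """L1 + L2 (mean squared error) between student and teacher images."""
    diff = student - teacher
    return np.mean(np.abs(diff)) + np.mean(diff ** 2)

def tv_loss(image):
    """Total variation of a (C, H, W) image: mean absolute difference
    between vertically and horizontally adjacent pixels."""
    dh = np.abs(image[:, 1:, :] - image[:, :-1, :]).mean()
    dw = np.abs(image[:, :, 1:] - image[:, :, :-1]).mean()
    return dh + dw

def total_loss(student, teacher, perc=0.0,
               pixel_weight=0.1, tv_weight=1e-4, perc_weight=1.0):
    """Weighted sum of the three terms; `perc` is a precomputed placeholder."""
    return (pixel_weight * pixel_loss(student, teacher)
            + tv_weight * tv_loss(student)
            + perc_weight * perc)

rng = np.random.default_rng(0)
teacher = rng.uniform(-1, 1, size=(3, 64, 64))
student = teacher + 0.05  # student that is uniformly too bright
print(total_loss(student, teacher))
```

With the default weights, the perceptual term dominates unless the pixel error is large, which matches the comment in `CONFIG` about favoring plausible images over exact reconstruction.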

In [None]:
from torch.utils.data import DataLoader
from distill_c1.trainer import DistillTrainer, DistillDataset, distill_collate_fn

# Set training seed
set_seed(CONFIG['seed'])

# Create datasets
train_dataset = DistillDataset(
    data_dir=CONFIG['data_dir'],
    split='train',
    vocab=vocab,
)

val_dataset = DistillDataset(
    data_dir=CONFIG['data_dir'],
    split='val',
    vocab=vocab,
)

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    collate_fn=distill_collate_fn,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    collate_fn=distill_collate_fn,
)

# Create trainer
trainer = DistillTrainer(
    text_encoder=text_encoder,
    decoder=decoder,
    train_loader=train_loader,
    val_loader=val_loader,
    vocab=vocab,
    device=device,
    lr=CONFIG['lr'],
    save_dir=CONFIG['save_dir'],
    pixel_weight=CONFIG['pixel_weight'],
    tv_weight=CONFIG['tv_weight'],
    perc_weight=CONFIG['perc_weight'],
    use_perc=CONFIG['use_perc'],
    use_amp=CONFIG['use_amp'],
    eval_every=CONFIG['eval_every'],
)

print(f"\nStarting training for {CONFIG['steps']} steps...")
print(f"Checkpoints will be saved to: {CONFIG['save_dir']}")
print("\nTraining metrics will be displayed below:")
print("-" * 80)

# Train - pass total_steps to the train() method
trainer.train(total_steps=CONFIG['steps'])

print("\n" + "=" * 80)
print("✓ Training complete!")
print(f"  Best checkpoint: {CONFIG['save_dir']}/ema_best.pt")
print(f"  Latest checkpoint: {CONFIG['save_dir']}/last.pt")
print(f"  Training log: {CONFIG['save_dir']}/log.json")

## Step 5: Visualize Training Results

Plot the validation metrics from the training log and display the most recent sample images saved during training.
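
The cell below reads `log.json` as a list of entries, each with a `step` and a `metrics` dictionary. As a small sketch of working with that structure, the hypothetical helper below finds the first evaluation step at which both targets are met. The example entries are made-up placeholder values for illustration, not real training results.

```python
def first_step_meeting_targets(log_data, psnr_target=24.0, ssim_target=0.92):
    """Return the earliest step whose metrics meet both targets, or None."""
    for entry in sorted(log_data, key=lambda e: e['step']):
        metrics = entry['metrics']
        if metrics['psnr'] >= psnr_target and metrics['ssim'] >= ssim_target:
            return entry['step']
    return None

# Placeholder entries for illustration only (not measured results)
example_log = [
    {'step': 1000, 'metrics': {'psnr': 18.0, 'ssim': 0.80, 'total': 0.50}},
    {'step': 2000, 'metrics': {'psnr': 23.5, 'ssim': 0.93, 'total': 0.30}},
    {'step': 3000, 'metrics': {'psnr': 24.6, 'ssim': 0.94, 'total': 0.25}},
]
print(first_step_meeting_targets(example_log))
```

Step 2000 is skipped in this example because only SSIM meets its target there; both conditions must hold at the same step.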

In [None]:
from PIL import Image
from pathlib import Path
import json

save_dir = Path(CONFIG['save_dir'])
samples_dir = save_dir / 'samples'

# Display training curve
log_path = save_dir / 'log.json'
if log_path.exists():
    with open(log_path, 'r') as f:
        log_data = json.load(f)  # Load as JSON array, not line-by-line
    
    # Extract metrics
    steps = [entry['step'] for entry in log_data]
    psnr_vals = [entry['metrics']['psnr'] for entry in log_data]
    ssim_vals = [entry['metrics']['ssim'] for entry in log_data]
    loss_vals = [entry['metrics']['total'] for entry in log_data]
    
    # Plot
    fig, axes = plt.subplots(1, 3, figsize=(18, 4))
    
    axes[0].plot(steps, loss_vals, marker='o')
    axes[0].set_xlabel('Step')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Validation Loss')
    axes[0].grid(True, alpha=0.3)
    
    axes[1].plot(steps, psnr_vals, marker='o')
    axes[1].axhline(y=24, color='r', linestyle='--', alpha=0.5, label='Target: 24 dB')
    axes[1].set_xlabel('Step')
    axes[1].set_ylabel('PSNR (dB)')
    axes[1].set_title('Validation PSNR')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    axes[2].plot(steps, ssim_vals, marker='o')
    axes[2].axhline(y=0.92, color='r', linestyle='--', alpha=0.5, label='Target: 0.92')
    axes[2].set_xlabel('Step')
    axes[2].set_ylabel('SSIM')
    axes[2].set_title('Validation SSIM')
    axes[2].legend()
    axes[2].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('/content/training_curves.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    # Print final metrics
    if psnr_vals:
        print(f"\nFinal validation metrics:")
        print(f"  PSNR: {psnr_vals[-1]:.2f} dB (target: ≥24 dB)")
        print(f"  SSIM: {ssim_vals[-1]:.4f} (target: ≥0.92)")
        
        if psnr_vals[-1] >= 24 and ssim_vals[-1] >= 0.92:
            print("\n✓ Model meets acceptance criteria!")
        else:
            print("\n⚠ Model needs more training to meet acceptance criteria.")
            print("  Consider increasing 'steps' in CONFIG and re-running training.")

# Display latest generated images
if samples_dir.exists():
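    # Sort by numeric step so that e.g. step_10000 comes after step_9000
    sample_files = sorted(
        samples_dir.glob('step_*.png'),
        key=lambda p: int(''.join(c for c in p.stem if c.isdigit()) or 0),
    )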
    if sample_files:
        latest_sample = sample_files[-1]
        print(f"\nLatest generated samples ({latest_sample.name}):")
        img = Image.open(latest_sample)
        plt.figure(figsize=(14, 10))
        plt.imshow(img)
        plt.axis('off')
        plt.title('Generated Images (Top: Teacher/Renderer, Bottom: Student/Model)', fontsize=14)
        plt.tight_layout()
        plt.savefig('/content/latest_samples.png', dpi=150, bbox_inches='tight')
        plt.show()
        print("\nNote: Top row = Teacher (renderer), Bottom row = Student (model)")
        print("The model should learn to match the renderer's output.")

## Step 6: Test the Trained Model

Generate images from custom text captions to verify the model works correctly.

In [None]:
# Load best model
checkpoint_path = Path(CONFIG['save_dir']) / 'ema_best.pt'

if checkpoint_path.exists():
    print(f"Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    text_encoder.load_state_dict(checkpoint['text_encoder'])
    decoder.load_state_dict(checkpoint['decoder'])
    print("✓ Model loaded successfully")
else:
    print(f"⚠ Checkpoint not found: {checkpoint_path}")
    print("Using current model state (may not be optimal)")

# Switch to inference mode in either case (affects dropout / batch norm)
text_encoder.eval()
decoder.eval()

# Test captions
test_captions = [
    "There is one red ball.",
    "There are three blue cubes.",
    "There are five yellow blocks.",
    "The red ball is left of the blue cube.",
    "The green block is on the yellow ball.",
    "The blue cube is in front of the red block.",
    "There are two green balls.",
    "There are four red blocks.",
]

print(f"\nGenerating images for {len(test_captions)} test captions...")

# Generate images
generated_images = []
with torch.no_grad():
    for caption in test_captions:
        # Tokenize
        token_ids = vocab.encode(caption, add_special=True)
        token_ids = torch.tensor(token_ids).unsqueeze(0).to(device)  # (1, seq_len)
        
        # Generate image
        embedding = text_encoder(token_ids, pad_id=vocab.pad_id)
        image = decoder(embedding)
        
        # Convert to displayable format
        image = image.cpu().squeeze(0)  # (3, 64, 64)
        image = (image.permute(1, 2, 0).numpy() + 1) / 2  # [-1,1] -> [0,1]
        image = np.clip(image, 0, 1)
        
        generated_images.append(image)

# Visualize results
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

for i, (caption, image) in enumerate(zip(test_captions, generated_images)):
    axes[i].imshow(image)
    axes[i].set_title(caption, fontsize=10)
    axes[i].axis('off')

plt.tight_layout()
plt.savefig('/content/test_generations.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n✓ Test generation complete!")
print("\nWhat to look for:")
print("  ✓ Correct colors (red, blue, green, yellow)")
print("  ✓ Correct shapes (ball=circle, cube/block=square)")
print("  ✓ Correct counts (one, two, three, four, five)")
print("  ✓ Correct spatial relations (left of, on, in front of)")
print("  ✓ No checkerboard artifacts or blurriness")

## Step 7: Compare with Ground Truth (Renderer)

Generate side-by-side comparisons with the renderer to evaluate quality.
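
One detail worth knowing when plotting an error map: matplotlib's `imshow` ignores `cmap` when it is given an RGB(A) array, so a colormap only takes effect on a single-channel image. The sketch below collapses the per-channel absolute error into one channel. It assumes both images are `(H, W, 3)` arrays in [0, 1], and `error_map` is a hypothetical helper name.

```python
import numpy as np

def error_map(teacher, student):
    """Per-pixel mean absolute error across color channels.
    Inputs are (H, W, 3) arrays in [0, 1]; output is (H, W)."""
    return np.abs(teacher - student).mean(axis=-1)

teacher = np.zeros((64, 64, 3))
student = np.zeros((64, 64, 3))
student[16:32, 16:32, 0] = 1.0  # a red square the teacher does not contain

err = error_map(teacher, student)
print(err.shape, err.max())
```

The resulting 2-D array can be passed to `imshow` with `cmap='hot'`, where darker pixels mean lower error and brighter pixels mean higher error.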

In [None]:
from dsl.parser import SceneParser
from render.renderer import SceneRenderer

# Initialize renderer
parser = SceneParser()
renderer = SceneRenderer(seed=42)

print("Generating comparisons with ground truth renderer...\n")

# Use subset of test captions
comparison_captions = test_captions[:4]

fig, axes = plt.subplots(len(comparison_captions), 3, figsize=(12, 4 * len(comparison_captions)))
if len(comparison_captions) == 1:
    axes = axes.reshape(1, -1)

with torch.no_grad():
    for i, caption in enumerate(comparison_captions):
        # Generate with renderer (teacher)
        scene_graph = parser.parse(caption)
        teacher_image, _ = renderer.render(scene_graph)
        teacher_array = np.array(teacher_image) / 255.0  # [0, 1]
        
        # Generate with model (student)
        token_ids = vocab.encode(caption, add_special=True)
        token_ids = torch.tensor(token_ids).unsqueeze(0).to(device)
        embedding = text_encoder(token_ids, pad_id=vocab.pad_id)
        student_image = decoder(embedding)
        student_array = student_image.cpu().squeeze(0).permute(1, 2, 0).numpy()
        student_array = (student_array + 1) / 2  # [-1,1] -> [0,1]
        student_array = np.clip(student_array, 0, 1)
        
        # Compute difference
        diff = np.abs(teacher_array - student_array)
        
        # Plot
        axes[i, 0].imshow(teacher_array)
        axes[i, 0].set_title(f'Teacher (Renderer)\n{caption}', fontsize=10)
        axes[i, 0].axis('off')
        
        axes[i, 1].imshow(student_array)
        axes[i, 1].set_title(f'Student (Model)\n{caption}', fontsize=10)
        axes[i, 1].axis('off')
        
        # Average over channels: imshow ignores cmap for RGB arrays
        axes[i, 2].imshow(diff.mean(axis=-1), cmap='hot', vmin=0, vmax=1)
        axes[i, 2].set_title(f'Absolute Difference\nMean: {diff.mean():.4f}', fontsize=10)
        axes[i, 2].axis('off')

plt.tight_layout()
plt.savefig('/content/teacher_student_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n✓ Comparison complete!")
print("\nInterpretation:")
print("  • Left column: Ground truth from procedural renderer")
print("  • Middle column: Generated by trained model")
print("  • Right column: Pixel-wise difference (brighter = higher error)")
print("\nThe model should closely match the renderer with minimal difference.")

## Summary

This notebook tested the text-to-image distillation pipeline:

1. ✓ Generated synthetic data (scenes + captions + images)
2. ✓ Trained text encoder + image decoder
3. ✓ Visualized training progress
4. ✓ Generated images from test captions
5. ✓ Compared model output with ground truth renderer

### Next Steps

**If results look good:**
- Train on more data (increase `n_samples` to 6000+)
- Train longer (increase `steps` to 100,000)
- Evaluate on compositional generalization (holdout splits)

**If results need improvement:**
- Check data quality (visualizations in Step 2)
- Verify metrics (PSNR ≥ 24 dB, SSIM ≥ 0.92)
- Adjust hyperparameters (learning rate, loss weights)

**Files saved:**
- `/content/sample_data.png` - Training data samples
- `/content/training_curves.png` - Training metrics over time
- `/content/latest_samples.png` - Most recent samples generated during training
- `/content/test_generations.png` - Model outputs on test captions
- `/content/teacher_student_comparison.png` - Side-by-side comparison

All checkpoints and logs are saved to your specified save directory (Google Drive if mounted).