# Text-to-Image Distillation Training (C1)

Train a neural network to generate 64×64 images from text captions.

**Architecture:** 4-layer Transformer encoder + FiLM-conditioned CNN decoder  
**Parameters:** ~11M total (2.4M encoder + 8.5M decoder)  
**Training time:** ~6-8 hours for 100k steps on GPU
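
FiLM (feature-wise linear modulation) conditioning lets the text embedding predict a per-channel scale and shift for each decoder feature map. The PyTorch block below illustrates the idea. It is a hypothetical sketch, not the repo's `distill_c1/decoder.py`: the class name `FiLMConvBlock`, the channel sizes, the GroupNorm/SiLU choices, the `(1 + gamma)` form, and the assumption of one pooled embedding vector per caption are all illustrative.

```python
import torch
import torch.nn as nn


class FiLMConvBlock(nn.Module):
    """Conv block whose feature maps are modulated by a text embedding (FiLM).

    Hypothetical sketch: layer sizes, GroupNorm/SiLU, and the (1 + gamma)
    form are illustrative choices, not the repo's decoder.
    """

    def __init__(self, in_ch: int, out_ch: int, text_dim: int):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm = nn.GroupNorm(8, out_ch)  # out_ch must be divisible by 8
        # One linear layer predicts a per-channel scale (gamma) and shift (beta)
        self.film = nn.Linear(text_dim, 2 * out_ch)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        h = self.norm(self.conv(x))                  # (B, C, H, W)
        gamma, beta = self.film(e).chunk(2, dim=-1)  # each (B, C)
        # Broadcast over H and W: h' = (1 + gamma) * h + beta
        h = (1 + gamma[:, :, None, None]) * h + beta[:, :, None, None]
        return self.act(h)


demo_block = FiLMConvBlock(in_ch=64, out_ch=64, text_dim=256)
demo_x = torch.randn(2, 64, 16, 16)  # decoder feature map
demo_e = torch.randn(2, 256)         # pooled text embedding, one per caption
print(demo_block(demo_x, demo_e).shape)  # torch.Size([2, 64, 16, 16])
```

Predicting `1 + gamma` rather than `gamma` means a FiLM layer with all-zero weights leaves the normalized features unchanged.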

**Note:** This notebook uses Google Drive for persistent storage.

## 1. Setup and Mount Drive

In [None]:
# Mount Google Drive for persistent storage
from google.colab import drive
drive.mount('/content/drive')
print("✓ Drive mounted")

In [None]:
# Check GPU
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Clone repo to Drive (persistent) or local (temporary)
import os

# Use Drive for persistence (recommended)
WORK_DIR = '/content/drive/MyDrive/learning_to_see'

# Or use local disk for faster I/O (data is lost when the runtime disconnects)
# WORK_DIR = '/content/learning_to_see'

if not os.path.exists(WORK_DIR):
    !git clone https://github.com/jtooates/learning_to_see.git {WORK_DIR}
    print(f"✓ Cloned to {WORK_DIR}")
else:
    print(f"✓ Found existing repo at {WORK_DIR}")

%cd {WORK_DIR}
!git pull  # Get latest changes

In [None]:
# Install dependencies
!pip install -q pillow numpy regex tqdm pyyaml matplotlib ipywidgets pytest

# Add to Python path
import sys
sys.path.insert(0, WORK_DIR)
print("✓ Dependencies installed")

## 2. Generate Data

**Note:** Data is saved inside `WORK_DIR`: on Google Drive it persists across sessions; on local disk it is lost when the runtime disconnects.  
The first run takes ~10-20 minutes to generate 2000 scenes.

In [None]:
# Check if data already exists
import os
data_path = 'data/scenes'

if os.path.exists(data_path) and os.path.exists(f'{data_path}/manifest.json'):
    print(f"✓ Found existing data at {data_path}")
    print("Skipping generation. Delete the folder to regenerate.")
else:
    print("Generating new dataset...")
    !python -m data.gen --out_dir {data_path} --n 2000 --split_strategy random --seed 42
    print("✓ Data generation complete")

In [None]:
# Visualize samples
!python visualize_samples.py --data_dir data/scenes --n 16 --save_path samples.png

import os
from IPython.display import Image, display
if os.path.exists('samples.png'):
    display(Image('samples.png', width=800))
else:
    print("Note: Preview not available")

## 3. Run Tests (Optional but Recommended)

In [None]:
# Verify all components work
!python -m pytest distill_c1/tests_distill.py -v --tb=short

print("\n✓ Test run finished; check the pytest summary above for failures.")

## 4. Training Configuration

**Adjust batch size based on your GPU:**
- T4 (free Colab): batch=64 or batch=96
- A100 (Colab Pro): batch=192
- V100: batch=128

**Training time estimates:**
- 30k steps: ~2.5 hours
- 50k steps: ~4 hours  
- 100k steps: ~8 hours
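
The batch-size and timing guidance above can be encoded as two small helpers. Both are hypothetical and not part of the repo: the GPU-name substrings and the ~12,000 steps/hour rate (the same rate the configuration cell uses for its estimate) are rules of thumb, and real throughput depends on your GPU, batch size, and AMP.

```python
# Hypothetical helpers encoding the rules of thumb above (not part of the repo).
BATCH_BY_GPU = {"A100": 192, "V100": 128, "T4": 96}


def suggest_batch(gpu_name: str, default: int = 64) -> int:
    """Pick a batch size from a CUDA device name such as 'Tesla T4'."""
    for key, batch in BATCH_BY_GPU.items():
        if key in gpu_name:
            return batch
    return default  # unknown GPU: fall back to the smallest listed size


def estimate_hours(steps: int, steps_per_hour: float = 12_000) -> float:
    """Rough wall-clock estimate, assuming ~12k training steps per hour."""
    return steps / steps_per_hour


print(suggest_batch("Tesla T4"))              # 96
print(f"{estimate_hours(50_000):.1f} hours")  # 4.2 hours
```

On Colab, `suggest_batch(torch.cuda.get_device_name(0))` would pick a value for the current runtime.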

In [None]:
# Configure training
config = {
    'data_dir': 'data/scenes',
    'save_dir': 'runs/distill_c1',  # Saved in your working directory
    'steps': 30000,      # 100000 for a full run; fewer steps finish sooner
    'batch': 96,         # Adjust for GPU (T4: 64-96, A100: 192)
    'eval_every': 2000,
    'lr': 3e-4,
    'seed': 1337,
}

print("Training Configuration:")
print("=" * 40)
for k, v in config.items():
    print(f"{k:15s}: {v}")
print("=" * 40)
print(f"\nEstimated time: ~{config['steps'] / 12000:.1f} hours (assuming ~12k steps/hour)")
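
Instead of typing each flag by hand, the `config` dict can be rendered as a command-line string. The helper `config_to_args` is hypothetical: it assumes every key is a `train_distill` flag of the same name (true for the keys above) and that boolean `True` values map to value-less flags such as `--use_amp`.

```python
import shlex


def config_to_args(cfg: dict) -> str:
    """Render a config dict as CLI flags, e.g. {'steps': 500} -> '--steps 500'.

    Hypothetical helper: assumes each key is a train_distill flag of the same
    name and that True booleans are value-less flags such as --use_amp.
    """
    parts = []
    for key, value in cfg.items():
        if isinstance(value, bool):
            if value:
                parts.append(f"--{key}")  # False booleans are simply omitted
        else:
            # shlex.quote protects paths that contain spaces
            parts.append(f"--{key} {shlex.quote(str(value))}")
    return " ".join(parts)


print(config_to_args({"steps": 500, "batch": 64, "use_amp": True}))
# --steps 500 --batch 64 --use_amp
```

IPython expands Python expressions inside braces in `!` commands, so `!python -m distill_c1.train_distill {config_to_args(config)} --use_amp` would build the same command as Section 6, minus the resume flag.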

## 5. Quick Sanity Check (Optional)

Run a very short training (500 steps, ~5 minutes) to verify everything works.

In [None]:
# Quick sanity check
!python -m distill_c1.train_distill \
    --data_dir data/scenes \
    --save_dir runs/sanity_check \
    --steps 500 \
    --batch 64 \
    --eval_every 250 \
    --seed 42

print("\n✓ Sanity check complete!")

## 6. Full Training

**Important:** This will take several hours, and the Colab runtime must stay connected.  
Progress is saved every `eval_every` steps; if the runtime disconnects, rerun this cell and it resumes from `last.pt` when that file exists.

In [None]:
# Check for existing checkpoint to resume
import os
resume_ckpt = f"{config['save_dir']}/last.pt"
resume_flag = f"--resume {resume_ckpt}" if os.path.exists(resume_ckpt) else ""

if resume_flag:
    print(f"✓ Found checkpoint, will resume from {resume_ckpt}")
else:
    print("Starting fresh training")

# Train model
!python -m distill_c1.train_distill \
    --data_dir {config['data_dir']} \
    --save_dir {config['save_dir']} \
    --steps {config['steps']} \
    --batch {config['batch']} \
    --eval_every {config['eval_every']} \
    --lr {config['lr']} \
    --seed {config['seed']} \
    --use_amp \
    {resume_flag}

print("\n✓ Training complete!")
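
Resuming works because the training script periodically writes `last.pt`. The sketch below shows the general save/resume pattern in PyTorch with a toy model. The function names and checkpoint keys (`model`, `optimizer`, `step`) are assumptions for illustration and do not describe the format `train_distill` actually writes. Writing to a temporary file and renaming it with `os.replace` avoids leaving a truncated checkpoint if the runtime dies mid-save.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path: str, model: nn.Module, optimizer, step: int) -> None:
    """Save model/optimizer state and the step counter."""
    tmp = path + ".tmp"
    torch.save({"model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "step": step}, tmp)
    os.replace(tmp, path)  # rename only after the write has finished


def load_checkpoint(path: str, model: nn.Module, optimizer, device="cpu") -> int:
    """Restore state if a checkpoint exists; return the step to resume from."""
    if not os.path.exists(path):
        return 0  # no checkpoint: start fresh
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["step"]


# Toy demo: a tiny model stands in for the encoder/decoder
demo_ckpt = os.path.join(tempfile.gettempdir(), "demo_last.pt")
demo_model = nn.Linear(4, 2)
demo_opt = torch.optim.Adam(demo_model.parameters(), lr=3e-4)
save_checkpoint(demo_ckpt, demo_model, demo_opt, step=2000)
print(load_checkpoint(demo_ckpt, demo_model, demo_opt))  # 2000
```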

## 7. Visualize Training Progress

In [None]:
# Plot training metrics
import json
import matplotlib.pyplot as plt
from pathlib import Path

log_path = Path(config['save_dir']) / 'log.json'

if not log_path.exists():
    print("No training log found. Train the model first (Section 6).")
else:
    with open(log_path) as f:
        log = json.load(f)
    
    steps = [e['step'] for e in log]
    psnr = [e['metrics']['psnr'] for e in log]
    ssim = [e['metrics']['ssim'] for e in log]
    loss = [e['metrics']['total'] for e in log]
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    # PSNR
    axes[0].plot(steps, psnr, linewidth=2)
    axes[0].axhline(24, color='r', linestyle='--', label='Target (24 dB)', alpha=0.7)
    axes[0].set_xlabel('Step', fontsize=12)
    axes[0].set_ylabel('PSNR (dB)', fontsize=12)
    axes[0].set_title('Peak Signal-to-Noise Ratio', fontsize=13)
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # SSIM
    axes[1].plot(steps, ssim, linewidth=2, color='orange')
    axes[1].axhline(0.92, color='r', linestyle='--', label='Target (0.92)', alpha=0.7)
    axes[1].set_xlabel('Step', fontsize=12)
    axes[1].set_ylabel('SSIM', fontsize=12)
    axes[1].set_title('Structural Similarity', fontsize=13)
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    # Loss
    axes[2].plot(steps, loss, linewidth=2, color='green')
    axes[2].set_xlabel('Step', fontsize=12)
    axes[2].set_ylabel('Loss', fontsize=12)
    axes[2].set_title('Training Loss', fontsize=13)
    axes[2].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\nFinal Metrics (Step {steps[-1]}):")
    print(f"  PSNR: {psnr[-1]:.2f} dB")
    print(f"  SSIM: {ssim[-1]:.4f}")
    print(f"  Loss: {loss[-1]:.6f}")
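
PSNR is a log-scaled mean squared error, which makes the 24 dB target easier to interpret. A minimal NumPy sketch, assuming images scaled to [0, 1]. The repo's evaluation may use a different range (for example [-1, 1], i.e. `data_range=2`), which shifts PSNR by 20·log10(2) ≈ 6.02 dB for the same images.

```python
import numpy as np


def psnr(pred: np.ndarray, target: np.ndarray, data_range: float = 1.0) -> float:
    """PSNR in dB: 10 * log10(data_range**2 / MSE)."""
    mse = np.mean((pred.astype(np.float64) - target.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(data_range ** 2 / mse))


# Invert the formula: which MSE does the 24 dB target correspond to?
target_mse = 10 ** (-24 / 10)
print(f"MSE at 24 dB: {target_mse:.5f}, RMSE: {np.sqrt(target_mse):.4f}")
# MSE at 24 dB: 0.00398, RMSE: 0.0631
```

For [0, 1] images, 24 dB therefore corresponds to an RMSE of about 0.063, roughly 16 of 255 gray levels per pixel.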

In [None]:
# View generated samples during training
from IPython.display import Image, display
from pathlib import Path
import glob

sample_dir = Path(config['save_dir']) / 'samples'
sample_files = sorted(glob.glob(str(sample_dir / 'step_*.png')))

if not sample_files:
    print("No sample images found yet.")
else:
    print(f"Found {len(sample_files)} sample grids\n")
    print("Showing samples at different stages:\n")
    
    # Show first, middle, and last (the set avoids repeats when there are < 3 grids)
    indices = sorted({0, len(sample_files) // 2, len(sample_files) - 1})
    for idx in indices:
        file = sample_files[idx]
        step = Path(file).stem.split('_')[1]
        print(f"Step {step}:")
        display(Image(file, width=800))
        print()

## 8. Evaluation

In [None]:
# Run comprehensive evaluation
!python -m distill_c1.eval_distill \
    --data_dir {config['data_dir']} \
    --ckpt {config['save_dir']}/ema_best.pt \
    --report {config['save_dir']}/report.json \
    --save_images {config['save_dir']}/eval_images \
    --counterfactual

print("\n✓ Evaluation complete")

In [None]:
# Display evaluation report
from pathlib import Path
import json

report_path = Path(config['save_dir']) / 'report.json'

if report_path.exists():
    with open(report_path) as f:
        report = json.load(f)
    
    print("Evaluation Report")
    print("=" * 50)
    print(f"Split: {report['split']}")
    print(f"Samples: {report['num_samples']}")
    print("\nMetrics:")
    print(f"  PSNR: {report['metrics']['psnr']:.2f} dB")
    print(f"  SSIM: {report['metrics']['ssim']:.4f}")
    
    psnr_pass = report['metrics']['psnr'] >= 24.0
    ssim_pass = report['metrics']['ssim'] >= 0.92
    
    print("\nAcceptance Criteria:")
    print(f"  PSNR >= 24 dB: {'✓ PASS' if psnr_pass else '✗ FAIL'}")
    print(f"  SSIM >= 0.92:  {'✓ PASS' if ssim_pass else '✗ FAIL'}")
    
    if 'counterfactual' in report and report['counterfactual']:
        print("\nCounterfactual Sensitivity:")
        for edit_type, results in report['counterfactual'].items():
            print(f"  {edit_type}: ΔL2 = {results['avg_delta_l2']:.6f}")
    
    print("=" * 50)
else:
    print("No report found. Run evaluation first (Section 8).")

In [None]:
# Show evaluation visualizations
from IPython.display import Image, display
from pathlib import Path

eval_dir = Path(config['save_dir']) / 'eval_images'

if (eval_dir / 'grid.png').exists():
    print("Teacher vs Student vs Difference:\n")
    display(Image(str(eval_dir / 'grid.png'), width=900))
    
    if (eval_dir / 'comparison.png').exists():
        print("\n\nSide-by-Side Comparison:\n")
        display(Image(str(eval_dir / 'comparison.png'), width=600))
else:
    print("No evaluation images found. Run evaluation first (Section 8).")

## 9. Interactive Generation

In [None]:
# Load trained model
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from distill_c1.text_encoder import build_text_encoder
from distill_c1.decoder import build_decoder
from dsl.tokens import Vocab

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load vocabulary
vocab = Vocab()

# Build models
text_encoder = build_text_encoder(vocab_size=len(vocab), pad_id=vocab.pad_id)
decoder = build_decoder()

# Load checkpoint
ckpt_path = Path(config['save_dir']) / 'ema_best.pt'
if not ckpt_path.exists():
    print(f"Checkpoint not found: {ckpt_path}")
    print("Train the model first (Section 6).")
else:
    checkpoint = torch.load(ckpt_path, map_location=device)
    text_encoder.load_state_dict(checkpoint['text_encoder'])
    decoder.load_state_dict(checkpoint['decoder'])
    
    text_encoder.to(device).eval()
    decoder.to(device).eval()
    
    print("✓ Model loaded successfully")
    print(f"Device: {device}")
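
To check the ~11M parameter figure (2.4M encoder + 8.5M decoder) against the model you just loaded, you can count parameters directly. `count_params` is a small hypothetical helper built on the standard `nn.Module.parameters()` API.

```python
import torch.nn as nn


def count_params(module: nn.Module, trainable_only: bool = False) -> int:
    """Number of scalar parameters, optionally only those with requires_grad."""
    return sum(p.numel() for p in module.parameters()
               if p.requires_grad or not trainable_only)


print(count_params(nn.Linear(10, 5)))  # 55 (50 weights + 5 biases)
```

For example, `print(f"Encoder: {count_params(text_encoder) / 1e6:.1f}M, decoder: {count_params(decoder) / 1e6:.1f}M")` reports both halves in millions.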

In [None]:
# Generation functions
def generate(text):
    """Generate image from text."""
    tokens = vocab.encode(text, add_special_tokens=True)
    token_ids = torch.tensor([tokens], dtype=torch.long, device=device)
    
    with torch.no_grad():
        e = text_encoder(token_ids, vocab.pad_id)
        img = decoder(e)
    
    # Convert to numpy
    img = img.squeeze(0).cpu()
    img = (img + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    img = img.permute(1, 2, 0).numpy()
    return np.clip(img, 0, 1)

def show(text):
    """Generate and display."""
    plt.figure(figsize=(5, 5))
    plt.imshow(generate(text))
    plt.title(text, fontsize=12, pad=10)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

print("✓ Generation functions ready")
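
`generate` encodes one prompt at a time. To run several prompts in one forward pass, the token lists must first be padded to a common length. The hypothetical helper below right-pads with a given pad id. It assumes the text encoder masks padded positions (suggested by `pad_id` being passed to it, but not verified here) and that right padding matches how training batches were built.

```python
import torch


def pad_batch(token_lists, pad_id: int) -> torch.Tensor:
    """Right-pad lists of token ids into a single (B, T) LongTensor."""
    max_len = max(len(toks) for toks in token_lists)
    batch = torch.full((len(token_lists), max_len), pad_id, dtype=torch.long)
    for i, toks in enumerate(token_lists):
        batch[i, :len(toks)] = torch.tensor(toks, dtype=torch.long)
    return batch


print(pad_batch([[5, 6, 7], [8, 9]], pad_id=0).tolist())
# [[5, 6, 7], [8, 9, 0]]
```

With the vocabulary loaded, `pad_batch([vocab.encode(p, add_special_tokens=True) for p in prompts], vocab.pad_id).to(device)` would produce a batch to feed `text_encoder` in place of the single-prompt tensor.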

In [None]:
# Example generations
prompts = [
    "There is one red ball.",
    "There are two green cubes.",
    "There are three yellow blocks.",
    "The blue ball is left of the red cube.",
    "The green block is on the yellow ball.",
    "The red cube is in front of the blue block.",
]

print("Generating images from prompts...\n")

for prompt in prompts:
    print(f"Prompt: {prompt}")
    show(prompt)
    print()

In [None]:
# Interactive generation with custom prompts
from ipywidgets import interact, Text

def generate_interactive(prompt):
    if prompt.strip():
        show(prompt)
    else:
        print("Enter a prompt in DSL format:")
        print('  "There is one red ball."')
        print('  "There are two green cubes."')
        print('  "The blue ball is left of the red cube."')

interact(generate_interactive, 
         prompt=Text(value="There is one red ball.", 
                    description="Prompt:",
                    style={'description_width': 'initial'}))

## 10. Counterfactual Analysis

In [None]:
# Compare how text changes affect images
def compare(prompt1, prompt2):
    """Compare two prompts."""
    img1 = generate(prompt1)
    img2 = generate(prompt2)
    diff = np.abs(img1 - img2)
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    axes[0].imshow(img1)
    axes[0].set_title(prompt1, fontsize=10)
    axes[0].axis('off')
    
    axes[1].imshow(img2)
    axes[1].set_title(prompt2, fontsize=10)
    axes[1].axis('off')
    
    # imshow ignores cmap for RGB arrays, so average channels into a 2D heatmap
    axes[2].imshow(diff.mean(axis=-1), cmap='hot')
    axes[2].set_title('Difference (hotter = more change)', fontsize=10)
    axes[2].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    l2_diff = np.mean(diff ** 2)  # mean squared pixel difference
    print(f"L2 Difference: {l2_diff:.6f}\n")

# Color change
print("Color Change:")
compare("There is one red ball.", "There is one blue ball.")

# Shape change
print("Shape Change:")
compare("There is one red ball.", "There is one red cube.")

# Number change
print("Number Change:")
compare("There are two green cubes.", "There are three green cubes.")

# Relation change
print("Relation Change:")
compare("The blue ball is left of the red cube.", 
        "The blue ball is right of the red cube.")
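
A single prompt pair can be unrepresentative. A steadier measure averages over several pairs that share one edit type, in the spirit of the `avg_delta_l2` values in the evaluation report. This is a hypothetical reimplementation for exploration; the report's exact definition (pixel range, choice of prompts, mean vs. sum) may differ.

```python
import numpy as np


def avg_delta_l2(pairs, gen_fn) -> float:
    """Average, over prompt pairs, of the mean squared pixel difference."""
    deltas = [np.mean((gen_fn(a) - gen_fn(b)) ** 2) for a, b in pairs]
    return float(np.mean(deltas))


# Stand-in generator for a quick check; with the model loaded, pass `generate`
fake_images = {"a": np.zeros((2, 2, 3)), "b": np.full((2, 2, 3), 0.5)}
print(avg_delta_l2([("a", "b")], fake_images.__getitem__))  # 0.25
```

With the trained model, `avg_delta_l2(color_pairs, generate)` would score a list such as `color_pairs = [("There is one red ball.", "There is one blue ball."), ("There is one green cube.", "There is one yellow cube.")]`.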

## Summary

✅ **Completed:**
1. Data generation with DSL and renderer
2. Model training with FiLM-conditioned decoder
3. Evaluation with PSNR/SSIM metrics
4. Interactive image generation
5. Counterfactual sensitivity analysis

**Key Numbers:**
- Total parameters: ~11M (2.4M encoder + 8.5M decoder)
- Target metrics: PSNR ≥ 24 dB, SSIM ≥ 0.92
- Training time: ~6-8 hours for 100k steps

**Next Steps:**
- Phase C2: Fine-tune with caption loss only
- Phase C3: Add adversarial losses
- Experiment with different compositions

**Data Location:**
All outputs are saved in your working directory (Drive or local).