# Text-to-Image Distillation Training (C1: Renderer Distillation)

This notebook trains a neural network to generate 64×64 images from text captions by distilling knowledge from a procedural renderer.

**Architecture:**
- Text Encoder: 4-layer Transformer (2.4M params)
- Image Decoder: FiLM-conditioned CNN with ResBlocks (8.5M params)
- Training: Pixel losses + TV + random perceptual loss
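In a FiLM-conditioned decoder, the text embedding predicts a per-channel scale γ and shift β, which then modulate intermediate feature maps. The sketch below shows the general FiLM technique in plain Python, with lists standing in for tensors. The helper names are hypothetical and this is not the repository's `distill_c1.decoder` code. In a real model, γ and β would come from learned layers.

```python
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]  # list of rows

# Assumption: illustrative FiLM sketch, not the repository's decoder implementation.


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    """Affine map y = W x + b (stand-in for a learned linear layer)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def film_params(text_emb: Vector, w_gamma: Matrix, b_gamma: Vector,
                w_beta: Matrix, b_beta: Vector) -> Tuple[Vector, Vector]:
    """Predict one scale (gamma) and one shift (beta) per feature channel."""
    return linear(text_emb, w_gamma, b_gamma), linear(text_emb, w_beta, b_beta)


def apply_film(features: List[Matrix], gamma: Vector, beta: Vector) -> List[Matrix]:
    """Modulate a [C][H][W] feature map channel-wise: out[c] = gamma[c] * h[c] + beta[c]."""
    return [
        [[g * v + b for v in row] for row in channel]
        for channel, g, b in zip(features, gamma, beta)
    ]


# Toy example: 2-d text embedding, 2 feature channels of size 1x2
gamma, beta = film_params([1.0, 0.0],
                          w_gamma=[[2.0, 0.0], [0.5, 0.0]], b_gamma=[0.0, 0.0],
                          w_beta=[[1.0, 0.0], [0.0, 0.0]], b_beta=[0.0, -1.0])
out = apply_film([[[1.0, 2.0]], [[4.0, 8.0]]], gamma, beta)
print(gamma, beta, out)  # [2.0, 0.5] [1.0, -1.0] [[[3.0, 5.0]], [[1.0, 3.0]]]
```

Because the same γ and β apply to every spatial position of a channel, the caption can steer what each channel represents without the decoder needing a separate cross-attention path.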

**Expected Results:**
- PSNR: 26-28 dB
- SSIM: 0.93-0.95
- Training time: ~6-8 hours for 100k steps
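For reference, PSNR = 10·log10(MAX² / MSE). The sketch below assumes pixel values in [0, 1], so MAX = 1. This is an assumption about how the evaluator scales images. Under that assumption, the 24 dB acceptance threshold corresponds to an MSE of about 0.0040, and the expected 26-28 dB range corresponds to roughly 0.0025-0.0016.

```python
import math
from typing import Sequence

# Assumption: images are flattened and scaled to [0, 1] (max_val = 1.0).


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared error between two equally sized, flattened images."""
    if len(a) != len(b) or len(a) == 0:
        raise ValueError("images must be non-empty and the same size")
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def psnr(a: Sequence[float], b: Sequence[float], max_val: float = 1.0) -> float:
    """PSNR in dB: 10 * log10(max_val^2 / MSE); infinite for identical images."""
    err = mse(a, b)
    return math.inf if err == 0 else 10.0 * math.log10(max_val ** 2 / err)


def mse_for_psnr(db: float, max_val: float = 1.0) -> float:
    """Invert the formula: the MSE that yields a given PSNR."""
    return max_val ** 2 / 10 ** (db / 10.0)


for target in (24, 26, 28):
    print(f"{target} dB  <->  MSE {mse_for_psnr(target):.5f}")
```

Each additional 10 dB corresponds to a 10× lower MSE, so the step from 24 dB to 28 dB is a reduction in squared error of about 2.5×.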

## 1. Setup and Installation

In [None]:
# Check GPU availability
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

In [None]:
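# Clone repository (if not already cloned) and move into it.
# Safe to re-run: after %cd we are already inside the repo, so we skip clone/cd.
import os
REPO_DIR = 'learning_to_see'
if os.path.basename(os.getcwd()) != REPO_DIR:
    if not os.path.exists(REPO_DIR):
        !git clone https://github.com/YOUR_USERNAME/learning_to_see.git
    %cd {REPO_DIR}
!git pull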

In [None]:
# Install dependencies
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q pillow numpy regex tqdm pyyaml matplotlib ipywidgets

print("✓ Dependencies installed")

## 2. Generate Training Data

Generate synthetic scene images with captions using the DSL and procedural renderer.

In [None]:
# Generate data (adjust --n based on desired dataset size)
# For quick testing: --n 1000
# For full training: --n 6000

!python -m data.gen \
  --out_dir data/scenes \
  --n 3000 \
  --split_strategy random \
  --seed 42

print("\n✓ Data generation complete")
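The flags `--split_strategy random --seed 42` suggest that scenes are assigned to train and val at random, but reproducibly. The sketch below illustrates that idea under stated assumptions. The `val_frac` value and the index-based split are hypothetical, and `data.gen` may implement the split differently.

```python
import random
from typing import List, Tuple

# Assumption: hypothetical index-based split; val_frac is illustrative, not from data.gen.


def random_split(n: int, val_frac: float = 0.1, seed: int = 42) -> Tuple[List[int], List[int]]:
    """Reproducibly partition sample indices 0..n-1 into (train, val) lists."""
    if not 0.0 <= val_frac <= 1.0:
        raise ValueError("val_frac must be in [0, 1]")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # private RNG: same seed -> same split, no global side effects
    n_val = round(n * val_frac)
    return sorted(indices[n_val:]), sorted(indices[:n_val])


train_idx, val_idx = random_split(3000, val_frac=0.1, seed=42)
print(len(train_idx), len(val_idx))
```

Using a dedicated `random.Random(seed)` instance keeps the split stable even if other code reseeds or consumes the global RNG.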

In [None]:
# Visualize sample data
!python visualize_samples.py --data_dir data/scenes --n 16 --save_path samples_preview.png

from IPython.display import Image, display
display(Image('samples_preview.png'))

## 3. Run Tests

Verify that all components work correctly.

In [None]:
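# Run unit tests; only report success if pytest actually exits with code 0
import subprocess, sys
result = subprocess.run([sys.executable, '-m', 'pytest', 'distill_c1/tests_distill.py', '-v', '--tb=short'],
                        capture_output=True, text=True)
print(result.stdout, result.stderr)
print("\n✓ All tests passed!" if result.returncode == 0
      else f"\n✗ Tests failed (pytest exit code {result.returncode})")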

## 4. Training Configuration

Configure training parameters. Adjust based on your needs and available compute.

In [None]:
# Training configuration
config = {
    # Data
    'data_dir': 'data/scenes',
    'save_dir': 'runs/distill_c1_colab',
    
    # Training schedule
    'steps': 50000,      # Reduce for faster training (full: 100000)
    'batch': 128,        # Adjust based on GPU memory (192 for A100, 64 for T4)
    'eval_every': 2000,
    
    # Optimizer
    'lr': 3e-4,
    'wd': 0.05,
    'warmup': 1000,
    'grad_clip': 1.0,
    
    # Loss weights
    'tv': 1e-5,
    'perc': 1e-3,
    'use_perc': True,
    
    # Model
    'emb_dim': 512,
    'base_ch': 256,
    'attn_heads': 4,
    
    # Training
    'use_amp': True,
    'ema': 0.999,
    'num_workers': 2,
    'seed': 1337,
}

# Display config
print("Training Configuration:")
print("=" * 50)
for key, value in config.items():
    print(f"{key:20s}: {value}")
print("=" * 50)
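Together, `lr`, `warmup`, and `steps` define the learning-rate schedule. This notebook does not show which schedule `train_distill` uses after warmup. The sketch below assumes a common choice, linear warmup followed by cosine decay, purely to make `warmup=1000` concrete. Its defaults mirror the config above.

```python
import math

# Assumption: linear warmup + cosine decay; train_distill's actual schedule may differ.


def lr_at(step: int, base_lr: float = 3e-4, warmup: int = 1000,
          total_steps: int = 50000, min_lr: float = 0.0) -> float:
    """Linear warmup to base_lr, then cosine decay to min_lr."""
    if step < warmup:
        return base_lr * (step + 1) / warmup
    progress = min(1.0, (step - warmup) / max(1, total_steps - warmup))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for s in (0, 500, 999, 1000, 25500, 50000):
    print(f"step {s:>6}: lr = {lr_at(s):.2e}")
```

Warmup keeps early updates small while AdamW-style optimizer statistics are still poorly estimated. The decay phase then lowers the step size as the model converges.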

## 5. Quick Sanity Check (Optional)

Run a very short training run to verify everything works before the full training.

In [None]:
# Quick sanity check (500 steps, ~5 minutes)
!python -m distill_c1.train_distill \
  --data_dir data/scenes \
  --save_dir runs/sanity_check \
  --steps 500 \
  --batch 64 \
  --eval_every 250 \
  --seed 42

print("\n✓ Sanity check complete")

## 6. Full Training

Train the full model. This will take several hours.

In [None]:
# Build the training command as a single line. (A multi-line string ending in a
# newline would make the shell run appended flags like --use_amp as a separate command.)
arg_names = ['data_dir', 'save_dir', 'steps', 'batch', 'lr', 'wd', 'warmup',
             'grad_clip', 'tv', 'perc', 'eval_every', 'emb_dim', 'base_ch',
             'attn_heads', 'ema', 'num_workers', 'seed']
train_cmd = "python -m distill_c1.train_distill " + " ".join(
    f"--{name} {config[name]}" for name in arg_names)

if config['use_amp']:
    train_cmd += " --use_amp"
if not config['use_perc']:
    train_cmd += " --no_perc"

print("Starting training...")
# Rough estimate assuming ~12k steps/hour; real throughput depends on GPU and batch size
print(f"Expected time: ~{config['steps'] / 12000:.1f} hours\n")

!{train_cmd}

print("\n✓ Training complete!")
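The `--grad_clip 1.0` flag presumably caps the global gradient norm before each optimizer step. In PyTorch this is usually done with `torch.nn.utils.clip_grad_norm_`, but whether `train_distill` does exactly this is an assumption. The arithmetic works as follows: if the L2 norm over all gradients exceeds the threshold, every gradient is multiplied by threshold / norm. This shrinks the update while preserving its direction. A plain-Python sketch:

```python
import math
from typing import List, Tuple

# Assumption: global-norm clipping, mirroring the usual clip_grad_norm_ formula.


def clip_by_global_norm(grads: List[List[float]], max_norm: float = 1.0,
                        eps: float = 1e-6) -> Tuple[List[List[float]], float]:
    """Scale all gradients by one shared factor so their joint L2 norm is at most ~max_norm.

    Returns the clipped gradients and the norm measured *before* clipping.
    """
    total = math.sqrt(sum(g * g for group in grads for g in group))
    scale = min(1.0, max_norm / (total + eps))  # never scales gradients *up*
    return [[g * scale for g in group] for group in grads], total


clipped, norm_before = clip_by_global_norm([[3.0], [4.0]], max_norm=1.0)
print(norm_before, clipped)
```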

## 7. Visualize Training Progress

Plot training metrics and view generated samples.

In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

# Load training log
log_path = Path(config['save_dir']) / 'log.json'
with open(log_path, 'r') as f:
    log = json.load(f)

# Extract metrics
steps = [entry['step'] for entry in log]
psnr = [entry['metrics']['psnr'] for entry in log]
ssim = [entry['metrics']['ssim'] for entry in log]
loss = [entry['metrics']['total'] for entry in log]

# Plot
fig, axes = plt.subplots(1, 3, figsize=(18, 4))

# PSNR
axes[0].plot(steps, psnr, linewidth=2)
axes[0].axhline(y=24, color='r', linestyle='--', label='Target (24 dB)')
axes[0].set_xlabel('Step')
axes[0].set_ylabel('PSNR (dB)')
axes[0].set_title('Peak Signal-to-Noise Ratio')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# SSIM
axes[1].plot(steps, ssim, linewidth=2, color='orange')
axes[1].axhline(y=0.92, color='r', linestyle='--', label='Target (0.92)')
axes[1].set_xlabel('Step')
axes[1].set_ylabel('SSIM')
axes[1].set_title('Structural Similarity Index')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Loss
axes[2].plot(steps, loss, linewidth=2, color='green')
axes[2].set_xlabel('Step')
axes[2].set_ylabel('Loss')
axes[2].set_title('Training Loss')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_metrics.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nFinal Metrics:")
print(f"  PSNR: {psnr[-1]:.2f} dB")
print(f"  SSIM: {ssim[-1]:.4f}")
print(f"  Loss: {loss[-1]:.6f}")

In [None]:
# View generated samples from different training steps
from IPython.display import Image, display
import glob

sample_dir = Path(config['save_dir']) / 'samples'
# Sort numerically by step (plain string sorting would put step_10000 before step_2000);
# assumes files are named step_<N>.png
sample_files = sorted(glob.glob(str(sample_dir / 'step_*.png')),
                      key=lambda p: int(Path(p).stem.split('_')[1]))

print("Generated Samples During Training:")
print("=" * 50)

if not sample_files:
    print(f"No sample images found in {sample_dir}")
else:
    # Show samples at ~0%, 25%, 50%, 75%, and 100% of training (duplicates removed)
    n = len(sample_files)
    for idx in sorted({0, n // 4, n // 2, 3 * n // 4, n - 1}):
        file = sample_files[idx]
        step = Path(file).stem.split('_')[1]
        print(f"\nStep {step}:")
        display(Image(file, width=800))

## 8. Evaluation

Evaluate the best EMA checkpoint (`ema_best.pt`) on the validation split, reporting PSNR, SSIM, and counterfactual sensitivity.

In [None]:
# Run evaluation
eval_cmd = f"""
python -m distill_c1.eval_distill \
  --data_dir {config['data_dir']} \
  --ckpt {config['save_dir']}/ema_best.pt \
  --report {config['save_dir']}/report.json \
  --save_images {config['save_dir']}/eval_images \
  --split val \
  --counterfactual
"""

!{eval_cmd}

print("\n✓ Evaluation complete")
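Evaluation loads `ema_best.pt`, which holds weights tracked as an exponential moving average with `ema = 0.999`. The repository's implementation isn't shown here. Assuming the standard EMA update, each shadow weight moves a fraction 1 − 0.999 = 0.001 toward the live weight after every step. It therefore averages over roughly the last 1 / (1 − 0.999) = 1000 steps, which smooths out optimizer noise. A scalar sketch:

```python
from typing import Dict

# Assumption: standard EMA update applied once per optimizer step.


def ema_update(shadow: Dict[str, float], params: Dict[str, float], decay: float = 0.999) -> None:
    """In place: shadow <- decay * shadow + (1 - decay) * params."""
    for name, value in params.items():
        shadow[name] = decay * shadow[name] + (1.0 - decay) * value


# A parameter that jumps from 0 to 1: the EMA catches up gradually
shadow = {"w": 0.0}
for _ in range(1000):
    ema_update(shadow, {"w": 1.0})
print(f"{shadow['w']:.4f}")
```

After 1000 updates, the shadow value has covered about 63% of the jump (1 − 0.999¹⁰⁰⁰ ≈ 0.632).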

In [None]:
# Display evaluation report
report_path = Path(config['save_dir']) / 'report.json'
with open(report_path, 'r') as f:
    report = json.load(f)

print("Evaluation Report")
print("=" * 50)
print(f"Split: {report['split']}")
print(f"Samples: {report['num_samples']}")
print(f"\nMetrics:")
print(f"  PSNR: {report['metrics']['psnr']:.2f} dB")
print(f"  SSIM: {report['metrics']['ssim']:.4f}")

# Acceptance criteria
psnr_pass = report['metrics']['psnr'] >= 24.0
ssim_pass = report['metrics']['ssim'] >= 0.92

print(f"\nAcceptance Criteria:")
print(f"  PSNR ≥ 24 dB: {'✓ PASS' if psnr_pass else '✗ FAIL'}")
print(f"  SSIM ≥ 0.92:  {'✓ PASS' if ssim_pass else '✗ FAIL'}")

if 'counterfactual' in report and report['counterfactual']:
    print(f"\nCounterfactual Sensitivity:")
    for edit_type, results in report['counterfactual'].items():
        print(f"  {edit_type}: ΔL2 = {results['avg_delta_l2']:.6f}")

print("=" * 50)
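The exact definition of `avg_delta_l2` lives in `eval_distill` and isn't shown here. One plausible reading, sketched below as an assumption, is the mean squared pixel difference between the image generated for a caption and the image for its edited version, averaged over all pairs of the same edit type. Larger values mean the model responds more strongly to that kind of edit. A value near zero would suggest the model ignores it.

```python
from collections import defaultdict
from typing import Dict, Iterable, Sequence, Tuple

# Assumption: hypothetical reconstruction of a per-edit-type average ΔL2.


def mean_sq_diff(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared difference between two flattened images of equal size."""
    if len(a) != len(b) or len(a) == 0:
        raise ValueError("images must be non-empty and the same size")
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def avg_delta_l2(pairs: Iterable[Tuple[str, Sequence[float], Sequence[float]]]) -> Dict[str, float]:
    """Average mean_sq_diff(original, edited) per edit type."""
    totals: Dict[str, float] = defaultdict(float)
    counts: Dict[str, int] = defaultdict(int)
    for edit_type, img_orig, img_edit in pairs:
        totals[edit_type] += mean_sq_diff(img_orig, img_edit)
        counts[edit_type] += 1
    return {k: totals[k] / counts[k] for k in totals}


result = avg_delta_l2([
    ("color", [0.0, 0.0], [1.0, 1.0]),
    ("color", [0.2, 0.2], [0.2, 0.2]),
    ("shape", [0.5], [0.0]),
])
print(result)
```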

In [None]:
# Display evaluation visualizations
eval_img_dir = Path(config['save_dir']) / 'eval_images'

print("Teacher vs Student (3 rows: Teacher, Student, Difference):")
display(Image(str(eval_img_dir / 'grid.png'), width=900))

print("\nSide-by-Side Comparison:")
display(Image(str(eval_img_dir / 'comparison.png'), width=600))

## 9. Interactive Generation

Generate images from custom text prompts.

In [None]:
# Load trained model
import torch
from distill_c1.text_encoder import build_text_encoder
from distill_c1.decoder import build_decoder
from dsl.tokens import Vocab
import matplotlib.pyplot as plt
import numpy as np

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load vocabulary
vocab = Vocab()

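# Build models. NOTE: the builders' defaults must describe the same architecture
# that was trained (config: emb_dim, base_ch, attn_heads). If you changed those
# values, pass matching arguments here (if the builders accept them); otherwise
# load_state_dict below will likely fail with shape mismatches.
text_encoder = build_text_encoder(vocab_size=len(vocab), pad_id=vocab.pad_id)
decoder = build_decoder()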

# Load checkpoint
ckpt_path = Path(config['save_dir']) / 'ema_best.pt'
checkpoint = torch.load(ckpt_path, map_location=device)
text_encoder.load_state_dict(checkpoint['text_encoder'])
decoder.load_state_dict(checkpoint['decoder'])

text_encoder = text_encoder.to(device)
decoder = decoder.to(device)

text_encoder.eval()
decoder.eval()

print("✓ Model loaded successfully")
print(f"Device: {device}")

In [None]:
def generate_image(text: str):
    """Generate image from text prompt."""
    # Tokenize
    tokens = vocab.encode(text, add_special_tokens=True)
    token_ids = torch.tensor([tokens], dtype=torch.long, device=device)
    
    # Generate
    with torch.no_grad():
        e = text_encoder(token_ids, pad_id=vocab.pad_id)
        img = decoder(e)
    
    # Convert to numpy
    img = img.squeeze(0).cpu()
    img = (img + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    img = img.permute(1, 2, 0).numpy()
    img = np.clip(img, 0, 1)
    
    return img

def visualize_generation(text: str):
    """Generate and display image."""
    img = generate_image(text)
    
    plt.figure(figsize=(6, 6))
    plt.imshow(img)
    plt.title(text, fontsize=14, pad=10)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

print("✓ Generation functions ready")
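`generate_image` encodes one prompt per forward pass. Rendering several prompts at once would require padding the token lists to a common length with the pad id. The helper below is a hypothetical sketch. It assumes `vocab.encode` returns a plain list of ints and that right-padding is acceptable. That seems likely, since the encoder is given `pad_id`, but the masking behavior isn't shown in this notebook.

```python
from typing import List

# Assumption: right-padding with pad_id is what the text encoder expects.


def pad_batch(token_lists: List[List[int]], pad_id: int) -> List[List[int]]:
    """Right-pad token id lists to a common length so they can form one [B, T] tensor."""
    if not token_lists:
        return []
    max_len = max(len(tokens) for tokens in token_lists)
    return [tokens + [pad_id] * (max_len - len(tokens)) for tokens in token_lists]


print(pad_batch([[5, 6, 7], [8]], pad_id=0))
```

The padded lists could then be wrapped as `torch.tensor(padded, dtype=torch.long, device=device)` and passed through `text_encoder` and `decoder` in the same way as in `generate_image`.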

In [None]:
# Example generations
prompts = [
    "There is one red ball.",
    "There are two green cubes.",
    "There are three yellow blocks.",
    "The blue ball is left of the red cube.",
    "The green block is on the yellow ball.",
    "The red cube is in front of the blue block.",
    "There are five red balls.",
    "The yellow cube is right of the green ball.",
]

print("Generating images from prompts...\n")

for prompt in prompts:
    print(f"Prompt: {prompt}")
    visualize_generation(prompt)
    print()

In [None]:
# Interactive: Generate from custom prompt
from ipywidgets import interact, Text

def generate_interactive(prompt):
    if prompt.strip():
        visualize_generation(prompt)
    else:
        print("Please enter a prompt in DSL format.")
        print("Examples:")
        print('  "There is one red ball."')
        print('  "There are two green cubes."')
        print('  "The blue ball is left of the red cube."')

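# continuous_update=False: regenerate on Enter / focus loss instead of on every keystroke
interact(generate_interactive,
         prompt=Text(value="There is one red ball.", description="Prompt:", continuous_update=False))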

## 10. Counterfactual Analysis

Visualize how small changes in text affect the generated images.

In [None]:
def compare_prompts(prompt1: str, prompt2: str):
    """Compare images from two prompts."""
    img1 = generate_image(prompt1)
    img2 = generate_image(prompt2)
    diff = np.abs(img1 - img2)
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    axes[0].imshow(img1)
    axes[0].set_title(prompt1, fontsize=10)
    axes[0].axis('off')
    
    axes[1].imshow(img2)
    axes[1].set_title(prompt2, fontsize=10)
    axes[1].axis('off')
    
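    # |diff| is an RGB array and imshow ignores cmap for RGB input,
    # so average over channels to get a single-channel heatmap
    axes[2].imshow(diff.mean(axis=2), cmap='hot')
    axes[2].set_title('Difference (hotter = more change)', fontsize=10)
    axes[2].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # Mean squared pixel difference over all pixels and channels
    l2_diff = np.mean(diff ** 2)
    print(f"L2 Difference (MSE): {l2_diff:.6f}")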

# Color changes
print("Color Change:")
compare_prompts(
    "There is one red ball.",
    "There is one blue ball."
)

# Shape changes
print("\nShape Change:")
compare_prompts(
    "There is one red ball.",
    "There is one red cube."
)

# Number changes
print("\nNumber Change:")
compare_prompts(
    "There are two green cubes.",
    "There are three green cubes."
)

# Relation changes
print("\nRelation Change:")
compare_prompts(
    "The blue ball is left of the red cube.",
    "The blue ball is right of the red cube."
)
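The pairs above are written by hand. A hypothetical helper can enumerate single-attribute edits automatically by swapping one word for another from the same category. Each returned pair can be passed straight to `compare_prompts`. The word lists are taken only from prompts in this notebook, so the DSL's vocabulary may be larger. Number and relation edits would need their own lists, and number edits also change singular/plural wording, so this sketch covers only color and shape words.

```python
import re
from typing import List, Tuple

# Assumption: word lists come from this notebook's prompts; the DSL vocabulary may be larger.
COLORS = ["red", "green", "blue", "yellow"]
SHAPES = ["ball", "cube", "block"]


def single_word_edits(prompt: str, category: List[str]) -> List[Tuple[str, str]]:
    """Return (original, edited) pairs that swap exactly one category word.

    Whole-word matching means plurals such as "cubes" are left untouched.
    """
    pattern = re.compile(r"\b(" + "|".join(map(re.escape, category)) + r")\b")
    edits = []
    for match in pattern.finditer(prompt):
        for alt in category:
            if alt != match.group(1):
                edited = prompt[:match.start()] + alt + prompt[match.end():]
                edits.append((prompt, edited))
    return edits


for original, edited in single_word_edits("The blue ball is left of the red cube.", COLORS):
    print(f"{original!r} -> {edited!r}")
```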

## 11. Download Results

Download trained models and results for later use.

In [None]:
# Create zip file with results
import shutil

# Files to include
save_dir = Path(config['save_dir'])
zip_name = 'distill_c1_results'

# Create archive
shutil.make_archive(zip_name, 'zip', save_dir)

print(f"✓ Created {zip_name}.zip")
print(f"\nContents:")
print(f"  - best.pt (best training checkpoint)")
print(f"  - ema_best.pt (best EMA checkpoint - recommended)")
print(f"  - last.pt (latest checkpoint)")
print(f"  - log.json (training metrics)")
print(f"  - report.json (evaluation results)")
print(f"  - samples/ (generated samples during training)")
print(f"  - eval_images/ (evaluation visualizations)")

# Download in Colab
try:
    from google.colab import files
    files.download(f'{zip_name}.zip')
    print(f"\n✓ Download started")
except ImportError:  # not running in Colab
    print(f"\n✓ Zip file ready at: {zip_name}.zip")
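The contents list above describes what a complete run produces. If a step was skipped (evaluation, for example), some entries will simply be absent from the archive. The standard-library `zipfile` module can check for this using the names listed above. The helper itself is hypothetical.

```python
import os
import zipfile
from typing import BinaryIO, Iterable, List, Union

# Assumption: expected names mirror the contents list printed above.


def missing_from_zip(archive: Union[str, BinaryIO], expected: Iterable[str]) -> List[str]:
    """List expected entries absent from a zip; names ending in '/' are treated as directories."""
    with zipfile.ZipFile(archive) as zf:
        names = set(zf.namelist())
    missing = []
    for entry in expected:
        if entry.endswith("/"):
            present = any(name.startswith(entry) for name in names)
        else:
            present = entry in names
        if not present:
            missing.append(entry)
    return missing


expected = ["best.pt", "ema_best.pt", "last.pt", "log.json", "report.json",
            "samples/", "eval_images/"]
if os.path.exists("distill_c1_results.zip"):
    print("Missing entries:", missing_from_zip("distill_c1_results.zip", expected) or "none")
```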

## Summary

This notebook demonstrated:

1. ✅ Data generation from DSL and procedural renderer
2. ✅ Model training with FiLM-conditioned decoder
3. ✅ Training visualization and metrics tracking
4. ✅ Model evaluation with PSNR/SSIM
5. ✅ Interactive image generation from text
6. ✅ Counterfactual sensitivity analysis

### Next Steps

- **Phase C2**: Fine-tune with caption loss only (strict regime)
- **Phase C3**: Add adversarial losses for realism
- **Experiment**: Try different DSL compositions
- **Scale up**: Train on larger datasets or longer

### Key Results

- Total parameters: ~11M (2.4M encoder + 8.5M decoder)
- Expected PSNR: 26-28 dB (target: ≥24 dB)
- Expected SSIM: 0.93-0.95 (target: ≥0.92)
- Training time: ~6-8 hours for 100k steps on GPU