# Text-to-Image Distillation Training (C1)

Train a neural network to generate 64×64 images from text captions.

**Architecture:** 4-layer Transformer encoder + FiLM-conditioned CNN decoder  
**Parameters:** ~11M total (2.4M encoder + 8.5M decoder)  
**Training time:** ~6-8 hours for 100k steps on GPU (the configuration in section 4 defaults to a shorter 30k-step run)

## 1. Setup

In [None]:
# Report the installed PyTorch version and whether a CUDA device is usable.
import torch

has_cuda = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {has_cuda}")
if has_cuda:
    # Device index 0 is the one queried for its name.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Clone repo
import os
if not os.path.exists('learning_to_see'):
    !git clone https://github.com/jtooates/learning_to_see.git
%cd learning_to_see

In [None]:
# Install dependencies
!pip install -q pillow numpy regex tqdm pyyaml matplotlib ipywidgets pytest
import sys
sys.path.insert(0, '/content/learning_to_see')
print("✓ Dependencies installed")

## 2. Generate Data

In [None]:
# Generate scenes (adjust --n for dataset size)
# Runs the repo's data.gen module as a script. Output is written to
# data/scenes, the same path the training config uses as data_dir.
# --seed 42 is passed so the run uses a fixed seed.
!python -m data.gen --out_dir data/scenes --n 2000 --split_strategy random --seed 42

In [None]:
# Visualize samples
!python visualize_samples.py --data_dir data/scenes --n 16 --save_path samples.png
import os
from IPython.display import Image, display
if os.path.exists('samples.png'):
    display(Image('samples.png'))

## 3. Run Tests

In [None]:
# Run the distillation test file with pytest; -v lists each test individually.
!python -m pytest distill_c1/tests_distill.py -v

## 4. Training Configuration

In [None]:
# Training settings; the training, evaluation and export cells read from this dict.
config = dict(
    data_dir='data/scenes',
    save_dir='runs/distill_c1',
    steps=30000,       # Lower for a faster run (full run: 100000)
    batch=96,          # Tune to the GPU (T4: 64, A100: 192)
    eval_every=2000,
    lr=3e-4,
    seed=1337,
)

# Echo every setting, one per line, with the name padded to 15 columns.
for name, value in config.items():
    print(f"{name:15s}: {value}")

## 5. Quick Test (Optional)

In [None]:
# Quick sanity check (~5 min)
# A short 500-step run with literal arguments (it does not read the config
# dict). It saves to runs/test, so it does not write into the save_dir used
# by the full training run.
!python -m distill_c1.train_distill   --data_dir data/scenes   --save_dir runs/test   --steps 500   --batch 64   --eval_every 250   --seed 42

## 6. Full Training

In [None]:
# Train model
# Each {config[...]} expression is substituted by IPython from the config dict
# before the shell command runs, so the configuration cell must be executed
# first. --use_amp is passed unconditionally (presumably enables automatic
# mixed precision - confirm in distill_c1.train_distill).
!python -m distill_c1.train_distill   --data_dir {config['data_dir']}   --save_dir {config['save_dir']}   --steps {config['steps']}   --batch {config['batch']}   --eval_every {config['eval_every']}   --lr {config['lr']}   --seed {config['seed']}   --use_amp

## 7. Visualize Training

In [None]:
# Plot the PSNR and SSIM values recorded in the run's log.json.
import json
import matplotlib.pyplot as plt
from pathlib import Path

with open(Path(config['save_dir']) / 'log.json') as f:
    log = json.load(f)

# Each log entry supplies a step number and a metrics dict.
steps = [e['step'] for e in log]
psnr = [e['metrics']['psnr'] for e in log]
ssim = [e['metrics']['ssim'] for e in log]

if not log:
    # With no entries there is nothing to draw, and psnr[-1] in the summary
    # line below would raise IndexError.
    print("log.json contains no entries yet - nothing to plot")
else:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # (axis, series, y-axis label, target value, extra plot kwargs)
    panels = [
        (ax1, psnr, 'PSNR (dB)', 24, {}),
        (ax2, ssim, 'SSIM', 0.92, {'color': 'orange'}),
    ]
    for ax, series, ylabel, target, plot_kwargs in panels:
        ax.plot(steps, series, **plot_kwargs)
        # Dashed red line marks the target value for this metric.
        ax.axhline(target, color='r', linestyle='--', label='Target')
        ax.set_xlabel('Step')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print(f"Final - PSNR: {psnr[-1]:.2f} dB, SSIM: {ssim[-1]:.4f}")

## 8. Evaluation

In [None]:
# Evaluate model
# Points --ckpt at ema_best.pt inside save_dir and passes report.json and
# eval_images (also inside save_dir) as the --report and --save_images
# targets; the following cells read report.json and eval_images/grid.png
# from those locations. {config[...]} is substituted by IPython from the
# config dict before the shell command runs.
!python -m distill_c1.eval_distill   --data_dir {config['data_dir']}   --ckpt {config['save_dir']}/ema_best.pt   --report {config['save_dir']}/report.json   --save_images {config['save_dir']}/eval_images   --counterfactual

In [None]:
# Print the metrics from report.json and mark whether each one meets its target.
# Path and json come from the metrics-plot cell's imports.
report_path = Path(config['save_dir']) / 'report.json'
with open(report_path) as f:
    report = json.load(f)

report_metrics = report['metrics']
report_psnr = report_metrics['psnr']
report_ssim = report_metrics['ssim']

# Targets: PSNR of at least 24 dB and SSIM of at least 0.92.
psnr_mark = '✓' if report_psnr >= 24 else '✗'
ssim_mark = '✓' if report_ssim >= 0.92 else '✗'

print(f"PSNR: {report_psnr:.2f} dB")
print(f"SSIM: {report_ssim:.4f}")
print(f"PSNR >= 24: {psnr_mark}")
print(f"SSIM >= 0.92: {ssim_mark}")

In [None]:
# Show eval images
from IPython.display import Image, display

eval_dir = Path(config['save_dir']) / 'eval_images'
grid_path = eval_dir / 'grid.png'

# Same existence guard the samples cell uses: if the evaluation step has not
# produced the image, print a hint instead of failing inside display().
if grid_path.exists():
    display(Image(str(grid_path), width=800))
else:
    print(f"No image found at {grid_path}; run the evaluation cell first")

## 9. Interactive Generation

In [None]:
# Rebuild the text encoder and decoder, then restore their weights from the
# ema_best.pt checkpoint in save_dir.
import torch
import numpy as np
from distill_c1.text_encoder import build_text_encoder
from distill_c1.decoder import build_decoder
from dsl.tokens import Vocab

# Prefer the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

vocab = Vocab()
text_encoder = build_text_encoder(vocab_size=len(vocab), pad_id=vocab.pad_id)
decoder = build_decoder()

ckpt_path = Path(config['save_dir']) / 'ema_best.pt'
ckpt = torch.load(ckpt_path, map_location=device)

# The checkpoint keeps one state dict per network under these two keys.
for module, state_key in ((text_encoder, 'text_encoder'), (decoder, 'decoder')):
    module.load_state_dict(ckpt[state_key])
    # Move to the chosen device and switch to eval mode for generation.
    module.to(device).eval()

print(f"✓ Model loaded on {device}")

In [None]:
# Generate function
def generate(text):
    """Run the loaded text encoder and decoder on ``text`` and return an image array.

    The caption is encoded with the notebook's ``vocab`` (special tokens
    added), wrapped as a batch of one, and passed through ``text_encoder``
    and ``decoder`` with gradients disabled.

    Returns:
        A NumPy array for the first item of the batch, with its axes
        reordered by ``permute(1, 2, 0)`` and its values clipped to
        ``[0, 1]``.
    """
    ids = vocab.encode(text, add_special_tokens=True)
    batch = torch.tensor([ids], device=device)  # batch of one caption

    with torch.no_grad():
        embedding = text_encoder(batch, vocab.pad_id)
        output = decoder(embedding)

    first = output[0].cpu()
    # (x + 1) / 2 maps [-1, 1] onto [0, 1]; presumably the decoder emits
    # values in [-1, 1] - confirm against the decoder's output activation.
    rescaled = (first + 1) / 2
    # Reorder axes for imshow; assumes channel-first decoder output - TODO confirm.
    reordered = rescaled.permute(1, 2, 0).numpy()
    return np.clip(reordered, 0, 1)

def show(text):
    """Generate the image for ``text`` and display it, using the caption as the title."""
    plt.figure(figsize=(5, 5))
    image = generate(text)
    plt.imshow(image)
    plt.title(text, fontsize=12)
    plt.axis('off')  # hide ticks and frame around the image
    plt.show()

print("✓ Ready to generate")

In [None]:
# Example generations: render one figure per caption, in list order.
prompts = [
    "There is one red ball.",
    "There are two green cubes.",
    "The blue ball is left of the red cube.",
    "The green block is on the yellow ball.",
]

for prompt in prompts:
    show(prompt)

## 10. Download Results

In [None]:
# Create zip
import shutil

# Archive the contents of the run directory as results.zip in the current
# working directory.
shutil.make_archive('results', 'zip', config['save_dir'])

# Only the Colab import is treated as optional. A bare `except:` here would
# also swallow KeyboardInterrupt and genuine download errors.
try:
    from google.colab import files
except ImportError:
    # Not running on Colab: the archive stays on disk.
    print("✓ results.zip created")
else:
    files.download('results.zip')
    print("✓ Download started")