# Shape Scene Generator - Training in Google Colab

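This notebook trains a Conditional VAE (CVAE) for caption-to-image generation on synthetic shape scenes, then visualizes the training curves, reconstructions, and images generated from new captions.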

## Setup: Clone Repository and Install Dependencies

In [None]:
# Clone the repository
!git clone https://github.com/jtooates/learning_to_see.git
%cd learning_to_see

In [None]:
# Install dependencies
!pip install -q torch torchvision Pillow numpy matplotlib tqdm

In [None]:
# Select the compute device: CPU by default, CUDA when torch reports it usable
import torch

device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
print(f"Using device: {device}")

# Report the name of GPU 0 when running on CUDA
if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Test Data Generation

In [None]:
from src.data.dataset import ShapeSceneDataset
import matplotlib.pyplot as plt

# Build a tiny 10-sample dataset just to confirm generation runs end to end
dataset = ShapeSceneDataset(size=10, canvas_size=256, seed=42)

# Show samples 0-5 on a 2x3 grid, each panel titled with its caption
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(12, 8))
axes = axes.ravel()

for i, panel in enumerate(axes):
    image_array, caption = dataset.get_raw_sample(i)
    panel.imshow(image_array)
    panel.set_title(caption, fontsize=8, wrap=True)
    panel.axis('off')

fig.tight_layout()
plt.show()

print("✓ Data generation working!")

## Train the Model

In [None]:
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from src.data.dataset import ShapeSceneDataset, collate_fn
from src.generation.models.cvae import ConditionalVAE
from src.generation.training.losses import VAELoss
from src.generation.training.trainer import Trainer
from src.generation.utils.tokenizer import CaptionTokenizer

# Hyperparameters
SEED = 42  # one seed for dataset generation, the train/val split and torch's default RNG
BATCH_SIZE = 32
NUM_EPOCHS = 30  # Reduced for Colab
LEARNING_RATE = 1e-4
DATASET_SIZE = 5000  # Smaller for faster training in Colab
IMAGE_SIZE = 256
LATENT_DIM = 128
CAPTION_DIM = 256
VAL_FRACTION = 0.1  # share of the dataset held out for validation

# Seed torch's default generator so that later steps drawing from it
# (e.g. DataLoader shuffling) repeat across Restart & Run All.
# NOTE(review): some CUDA kernels are non-deterministic, so GPU runs may still
# differ slightly between runs.
torch.manual_seed(SEED)

print("Creating dataset...")
full_dataset = ShapeSceneDataset(size=DATASET_SIZE, canvas_size=IMAGE_SIZE, seed=SEED)

# Split into train/val; a dedicated seeded generator keeps the split identical
# on every run regardless of how much the default RNG has been used.
val_size = int(VAL_FRACTION * len(full_dataset))
train_size = len(full_dataset) - val_size
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED)
)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

In [None]:
# Build tokenizer
print("Building tokenizer...")
tokenizer = CaptionTokenizer(max_length=32)

# Fit on at most the first 1000 samples; element 1 of each sample is what this
# notebook treats as the caption.
num_fit_samples = min(1000, len(full_dataset))
sample_captions = [full_dataset[idx][1] for idx in range(num_fit_samples)]
tokenizer.fit(sample_captions)

print(f"Vocabulary size: {tokenizer.get_vocab_size()}")

In [None]:
# Create data loaders
# Both loaders share batch size, worker count and collate function; only the
# training loader shuffles.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, collate_fn=collate_fn)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Create model
print("Creating model...")

# Vocabulary size comes from the fitted tokenizer; the remaining sizes come
# from the hyperparameter constants.
model_config = {
    "vocab_size": tokenizer.get_vocab_size(),
    "image_size": IMAGE_SIZE,
    "latent_dim": LATENT_DIM,
    "caption_dim": CAPTION_DIM,
}
model = ConditionalVAE(**model_config)

# Total number of elements across all parameter tensors
num_params = sum(param.numel() for param in model.parameters())
print(f"Parameters: {num_params:,}")

In [None]:
# Setup training
# Loss: MSE reconstruction with a KL term (kl_weight=0.001) and KL annealing
# enabled (kl_annealing_epochs=10).
criterion = VAELoss(
    kl_annealing_epochs=10,
    kl_annealing=True,
    kl_weight=0.001,
    reconstruction_loss="mse",
)

# Adam over all model parameters; StepLR scales the learning rate by `gamma`
# every `step_size` scheduler steps.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.5, step_size=15)

# The trainer receives the tokenizer's `encode` method as `tokenize_fn`
trainer_kwargs = dict(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    device=device,
    tokenize_fn=tokenizer.encode,
    checkpoint_dir="checkpoints",
)
trainer = Trainer(**trainer_kwargs)

In [None]:
# Train!
# Runs NUM_EPOCHS epochs; `save_every=10` is passed through to the trainer.
print("Starting training...")
trainer.train(save_every=10, num_epochs=NUM_EPOCHS)

## Visualize Results

In [None]:
from src.generation.utils.visualization import plot_training_curves

# Plot the loss histories recorded on the trainer and save the figure as a PNG
plot_training_curves(trainer.train_losses, trainer.val_losses, save_path="training_curves.png")

In [None]:
from src.generation.utils.visualization import visualize_reconstruction

# Visualize reconstructions
# Take the first validation batch and keep its first 8 image/caption pairs
model.eval()
val_images, val_captions = next(iter(val_loader))
val_images, val_captions = val_images[:8].to(device), val_captions[:8]

# Encode every caption as a long tensor and stack them on `device`
# (stacking assumes `encode` yields equal-length sequences — TODO confirm)
token_rows = [
    torch.tensor(tokenizer.encode(text), dtype=torch.long)
    for text in val_captions
]
val_tokens = torch.stack(token_rows).to(device)

# Inference only, so gradient tracking is switched off
with torch.no_grad():
    reconstructed = model.reconstruct(val_images, val_tokens)

visualize_reconstruction(
    val_images,
    reconstructed,
    val_captions,
    num_samples=8,
    save_path="reconstruction.png",
)

In [None]:
from src.generation.utils.visualization import visualize_generation

# Generate from text
# NOTE(review): the tokenizer was fitted on dataset captions only; confirm that
# every word used below (e.g. the colour names) occurs in that vocabulary.
test_captions = [
    "a large blue square above 3 small red circles",
    "2 medium green triangles",
    "a small yellow rectangle left of a large purple square",
    "4 orange circles",
    "a large pink triangle below 2 small cyan squares",
    "3 medium brown rectangles",
]

# Encode every caption as a long tensor and stack them on `device`
test_tokens = torch.stack([
    torch.tensor(tokenizer.encode(caption), dtype=torch.long)
    for caption in test_captions
]).to(device)

# Put the model in evaluation mode here rather than relying on an earlier cell
# having done so; `eval()` is idempotent, so repeating it is harmless.
model.eval()
with torch.no_grad():
    generated = model.generate(test_tokens, num_samples=1)

visualize_generation(
    generated, test_captions,
    save_path="generated_samples.png"
)

## Save Model to Google Drive (Optional)

In [None]:
# Optional: uncomment every line below to save the results to Google Drive.
# The commands mount Drive at /content/drive, then copy the `checkpoints`
# directory and the PNG files in the working directory into
# MyDrive/learning_to_see_checkpoints.
# from google.colab import drive
# drive.mount('/content/drive')

# !cp -r checkpoints /content/drive/MyDrive/learning_to_see_checkpoints
# !cp *.png /content/drive/MyDrive/learning_to_see_checkpoints/
# print("Saved to Google Drive!")