<a href="https://colab.research.google.com/github/jtooates/latent/blob/main/colab_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Latent Canvas Painter - GPU Training on Colab

This notebook sets up and trains the DRAW-style canvas painter model on Google Colab with GPU acceleration.

## Features:
- Automatic GPU detection and usage
- Google Drive integration for checkpoint persistence
- Optimized batch size for GPU training
- Resume training from saved checkpoints

## 1. Setup Environment

In [None]:
# Check GPU availability
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("WARNING: No GPU detected. Training will be slow.")

In [None]:
# Mount Google Drive for checkpoint persistence
from google.colab import drive
drive.mount('/content/drive')

# Create directory for this project
import os
project_dir = '/content/drive/MyDrive/latent_training'
os.makedirs(project_dir, exist_ok=True)
print(f"Project directory: {project_dir}")
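The cell above assumes it runs inside Colab. If you also want to run the notebook outside Colab, one option is a small fallback that uses the Drive path only when `google.colab` can be imported. This is a sketch; `in_colab`, `resolve_project_dir`, and any local directory you pass in are hypothetical and not part of the repository.

```python
import os


def in_colab() -> bool:
    """Return True when the google.colab package is importable (i.e. running on Colab)."""
    try:
        import google.colab  # noqa: F401
    except ImportError:
        return False
    return True


def resolve_project_dir(colab_dir: str, local_dir: str) -> str:
    """Hypothetical helper: choose the Drive-backed dir on Colab, else a local one, and create it.

    Assumes Drive has already been mounted when running on Colab.
    """
    path = colab_dir if in_colab() else local_dir
    os.makedirs(path, exist_ok=True)
    return path
```

For example, `project_dir = resolve_project_dir('/content/drive/MyDrive/latent_training', os.path.expanduser('~/latent_training'))`, where the local path is only an illustration.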

In [None]:
# Clone the repository (skipped if it is already present from an earlier run)
%cd /content
![ -d latent ] || git clone https://github.com/jtooates/latent.git
%cd /content/latent

In [None]:
# Install dependencies (Pillow for image visualization)
!pip install -q Pillow
print("Dependencies installed!")

## 2. Configure Training Parameters

In [None]:
# Training configuration
EPOCHS = 100
BATCH_SIZE = 128  # Larger batch size for GPU
LEARNING_RATE = 0.001
LATENT_SIZE = 48  # Larger canvas (48x48)
PATCH_SIZE = 7  # Larger patches (7x7)
NUM_STEPS = 0  # 0 = adaptive: one painting step per input token

# Checkpoint settings
CHECKPOINT_DIR = f"{project_dir}/checkpoints"
RESUME_CHECKPOINT = None  # To resume, set to a checkpoint path, e.g. f"{CHECKPOINT_DIR}/best_model.pt"

# Display settings
print("Training Configuration:")
print(f"  Epochs: {EPOCHS}")
print(f"  Batch Size: {BATCH_SIZE}")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"  Canvas Size: {LATENT_SIZE}x{LATENT_SIZE}")
print(f"  Patch Size: {PATCH_SIZE}x{PATCH_SIZE}")
print(f"  Painting Steps: {'adaptive' if NUM_STEPS == 0 else NUM_STEPS}")
print(f"  Checkpoint Dir: {CHECKPOINT_DIR}")
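Resuming (section 4) is safest with the same model settings as the original run, so it can help to record them next to the checkpoints. A minimal standard-library sketch; the `run_config.json` filename and both helper names are hypothetical, and `train.py` does not read this file.

```python
import json
import os


def save_run_config(checkpoint_dir: str, config: dict, filename: str = "run_config.json") -> str:
    """Hypothetical helper: write the training settings as JSON next to the checkpoints."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, filename)
    with open(path, "w") as f:
        json.dump(config, f, indent=2, sort_keys=True)
    return path


def load_run_config(checkpoint_dir: str, filename: str = "run_config.json") -> dict:
    """Hypothetical helper: read back settings written by save_run_config."""
    with open(os.path.join(checkpoint_dir, filename)) as f:
        return json.load(f)
```

In the notebook you could call `save_run_config(CHECKPOINT_DIR, {"latent_size": LATENT_SIZE, "patch_size": PATCH_SIZE, "num_steps": NUM_STEPS, "batch_size": BATCH_SIZE, "lr": LEARNING_RATE})` right after the configuration cell.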

## 3. Train Model

This cell launches `train.py` with the settings above. Training on a GPU runtime should be much faster than on CPU.

**Expected speeds:**
- CPU: ~2-5 minutes per epoch
- GPU (T4): ~10-30 seconds per epoch
- GPU (V100/A100): ~5-15 seconds per epoch
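To turn the per-epoch figures above into a rough wall-clock budget, multiply by the number of epochs. The sketch below does that arithmetic and compares the result with the ~12 hour free-tier session limit mentioned in the tips; the timings fed into it are the rough ranges quoted above, not measurements, and the helper names are hypothetical.

```python
def training_time_hours(epochs: int, sec_per_epoch_low: float, sec_per_epoch_high: float) -> tuple[float, float]:
    """Return the (low, high) estimate of total training time in hours."""
    return (epochs * sec_per_epoch_low / 3600, epochs * sec_per_epoch_high / 3600)


def fits_in_session(epochs: int, sec_per_epoch: float, session_hours: float = 12.0) -> bool:
    """True if `epochs` at `sec_per_epoch` seconds each finishes within one session."""
    return epochs * sec_per_epoch <= session_hours * 3600


# 100 epochs on a T4 at the quoted ~10-30 s/epoch (rough figures, not measurements)
low, high = training_time_hours(100, 10, 30)
print(f"T4 estimate: {low:.2f}-{high:.2f} h")
```

By the same arithmetic, 100 epochs on CPU at 2-5 minutes per epoch comes to roughly 3.3 to 8.3 hours, which is close to the free-tier session limit.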

In [None]:
# Build training command
cmd = f"""python train.py \
    --data training_data.json \
    --config training_config.json \
    --use-canvas-painter \
    --epochs {EPOCHS} \
    --batch-size {BATCH_SIZE} \
    --lr {LEARNING_RATE} \
    --latent-size {LATENT_SIZE} \
    --painter-patch-size {PATCH_SIZE} \
    --painter-num-steps {NUM_STEPS} \
    --output-dir {CHECKPOINT_DIR}"""

# Add resume flag if checkpoint specified
if RESUME_CHECKPOINT:
    cmd += f" --resume {RESUME_CHECKPOINT}"

print("Starting training...\n")
!{cmd}
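The cell above passes one shell string to `!`, which breaks if a path ever contains spaces or shell-special characters. A sketch of an alternative that hands the arguments to `subprocess` as a list, so no shell quoting is involved. The flag names are copied from the cell above; `build_train_args` and `run_training` are hypothetical helpers.

```python
import subprocess
import sys
from typing import Optional


def build_train_args(config: dict, resume: Optional[str] = None) -> list:
    """Hypothetical helper: assemble the train.py command as an argument list.

    Flag names mirror the training cell; values come from `config`.
    """
    args = [
        sys.executable, "train.py",
        "--data", "training_data.json",
        "--config", "training_config.json",
        "--use-canvas-painter",
        "--epochs", str(config["epochs"]),
        "--batch-size", str(config["batch_size"]),
        "--lr", str(config["lr"]),
        "--latent-size", str(config["latent_size"]),
        "--painter-patch-size", str(config["patch_size"]),
        "--painter-num-steps", str(config["num_steps"]),
        "--output-dir", config["output_dir"],
    ]
    if resume:
        args += ["--resume", resume]
    return args


def run_training(config: dict, resume: Optional[str] = None) -> int:
    """Run train.py without a shell, echo its output line by line, and return the exit code."""
    proc = subprocess.Popen(
        build_train_args(config, resume),
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr so errors also appear in the cell
        text=True,
    )
    for line in proc.stdout:
        print(line, end="")
    return proc.wait()
```

Usage in the notebook: `run_training({"epochs": EPOCHS, "batch_size": BATCH_SIZE, "lr": LEARNING_RATE, "latent_size": LATENT_SIZE, "patch_size": PATCH_SIZE, "num_steps": NUM_STEPS, "output_dir": CHECKPOINT_DIR}, RESUME_CHECKPOINT)`. Output is echoed explicitly because child-process output is not always forwarded to the notebook otherwise.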

## 4. Resume Training (Optional)

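If training was interrupted, you can resume from the saved checkpoint. Note that `best_model.pt` holds the model with the best validation loss, which may be older than the last completed epoch.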

In [None]:
# Check whether a saved checkpoint exists
import os
checkpoint_path = f"{CHECKPOINT_DIR}/best_model.pt"

if os.path.exists(checkpoint_path):
    print(f"Found checkpoint: {checkpoint_path}")
    print("\nTo resume training, run the cell below, or set RESUME_CHECKPOINT above and rerun the training cell.")
else:
    print("No checkpoint found yet. Train first!")
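The check above only looks for `best_model.pt`. If the training run also writes other `.pt` files into the checkpoint directory (an assumption; check what your run actually produced), a sketch like the one below lists them and picks the most recently modified one, which is the natural choice for continuing an interrupted run. Both helper names are hypothetical.

```python
import glob
import os
from typing import Optional


def list_checkpoints(checkpoint_dir: str) -> list[str]:
    """Hypothetical helper: all .pt files in checkpoint_dir, oldest first by modification time."""
    return sorted(glob.glob(os.path.join(checkpoint_dir, "*.pt")), key=os.path.getmtime)


def latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Hypothetical helper: the most recently modified .pt file, or None if there are none."""
    checkpoints = list_checkpoints(checkpoint_dir)
    return checkpoints[-1] if checkpoints else None
```

For example, `RESUME_CHECKPOINT = latest_checkpoint(CHECKPOINT_DIR)` before rerunning the training cell.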

In [None]:
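# Quick resume command: continue for 50 more epochs.
# --epochs is treated as the total epoch count (EPOCHS + 50), and the model flags
# match the training cell so the model is rebuilt with the checkpoint's settings.
!python train.py \
    --data training_data.json \
    --config training_config.json \
    --use-canvas-painter \
    --resume {CHECKPOINT_DIR}/best_model.pt \
    --epochs {EPOCHS + 50} \
    --batch-size {BATCH_SIZE} \
    --lr {LEARNING_RATE} \
    --latent-size {LATENT_SIZE} \
    --painter-patch-size {PATCH_SIZE} \
    --painter-num-steps {NUM_STEPS} \
    --output-dir {CHECKPOINT_DIR}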

## 5. Visualize Results

In [None]:
# Display some latent images
from IPython.display import Image, display
import glob

# Get the most recently written images
image_dir = f"{CHECKPOINT_DIR}/latent_images"
if os.path.exists(image_dir):
    images = sorted(glob.glob(f"{image_dir}/*.png"), key=os.path.getmtime)[-8:]  # 8 newest images
    print(f"Displaying {len(images)} latent images:\n")
    for img_path in images:
        print(os.path.basename(img_path))
        display(Image(img_path, width=256))
else:
    print("No images generated yet. Check the --save-images-every option of train.py.")
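If you would rather order images by the epoch number in their file names than by modification time or plain alphabetical order (which puts `epoch_100` before `epoch_20`), a natural sort on the digits does that. This sketch assumes the names embed an epoch number; `natural_key` and `latest_images` are hypothetical helpers.

```python
import os
import re


def natural_key(path: str) -> list:
    """Split the file name into text and integer chunks so digit runs compare numerically.

    re.split with a capturing group alternates text (even index) and digits (odd index).
    """
    parts = re.split(r"(\d+)", os.path.basename(path))
    return [int(tok) if i % 2 else tok.lower() for i, tok in enumerate(parts)]


def latest_images(paths: list[str], n: int = 8) -> list[str]:
    """Hypothetical helper: the last n paths in natural (numeric-aware) order."""
    return sorted(paths, key=natural_key)[-n:]
```

In the cell above this would replace the sort: `images = latest_images(glob.glob(f"{image_dir}/*.png"))`.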

In [None]:
# Generate saliency visualization for a test sentence
test_sentence = "red circle on blue square"
output_path = f"{project_dir}/saliency_test.png"

!python visualize_saliency.py \
    --checkpoint {CHECKPOINT_DIR}/best_model.pt \
    --data training_data.json \
    --sentence "{test_sentence}" \
    --output {output_path}

# Display result
if os.path.exists(output_path):
    print(f"\nSaliency map for: '{test_sentence}'")
    display(Image(output_path, width=512))
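The cell above always writes to `saliency_test.png`, so trying a second sentence overwrites the first result. One way to compare several sentences is to derive a distinct file name from each one. `saliency_path` below is a hypothetical helper; the `visualize_saliency.py` flags stay exactly as in the cell above.

```python
import os
import re


def saliency_path(project_dir: str, sentence: str) -> str:
    """Hypothetical helper: map a sentence to a filesystem-safe PNG path.

    Runs of characters other than a-z/0-9 become single underscores.
    """
    slug = re.sub(r"[^a-z0-9]+", "_", sentence.lower()).strip("_")
    return os.path.join(project_dir, f"saliency_{slug or 'empty'}.png")
```

In the notebook you could loop over a list of sentences, set `output_path = saliency_path(project_dir, test_sentence)` for each, and rerun the same `!python visualize_saliency.py ...` command inside the loop.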

## 6. Download Trained Model

In [None]:
# Download the best model to your local machine
from google.colab import files

checkpoint_file = f"{CHECKPOINT_DIR}/best_model.pt"
if os.path.exists(checkpoint_file):
    print(f"Downloading checkpoint: {checkpoint_file}")
    files.download(checkpoint_file)
else:
    print("No checkpoint found!")
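`files.download` sends one file at a time. To fetch the whole checkpoint directory (checkpoints, logs, latent images) in one download, you could archive it first with the standard library's `shutil.make_archive` and pass the resulting `.zip` to `files.download`. `archive_directory` is a hypothetical wrapper.

```python
import os
import shutil


def archive_directory(src_dir: str, archive_base: str) -> str:
    """Hypothetical helper: zip the contents of src_dir into archive_base + '.zip'.

    Returns the path of the created archive.
    """
    if not os.path.isdir(src_dir):
        raise FileNotFoundError(f"No such directory: {src_dir}")
    return shutil.make_archive(archive_base, "zip", root_dir=src_dir)
```

On Colab you might then call `files.download(archive_directory(CHECKPOINT_DIR, "/content/checkpoints"))`. Writing the archive to `/content` rather than inside the checkpoint directory keeps the zip from including itself and avoids an extra copy in Drive.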

## 7. Tips & Troubleshooting

### GPU Memory Issues
If you run out of GPU memory, try:
- Reduce `BATCH_SIZE` (try 64 or 32)
- Reduce `LATENT_SIZE` (try 32 instead of 48)
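To get a feel for how much each change helps, assume (an assumption about this model, not a measurement) that canvas-related activation memory scales with batch size × canvas area. The relative footprint of a new setting is then a simple ratio. Fixed costs such as model weights and optimizer state do not shrink this way. `relative_canvas_memory` is a hypothetical helper.

```python
def relative_canvas_memory(batch: int, latent: int, base_batch: int = 128, base_latent: int = 48) -> float:
    """Ratio of batch * latent^2 to the baseline (defaults match BATCH_SIZE=128, LATENT_SIZE=48).

    Assumes activation memory is proportional to batch size times canvas area.
    """
    return (batch * latent ** 2) / (base_batch * base_latent ** 2)


for batch, latent in [(64, 48), (32, 48), (128, 32), (64, 32)]:
    print(f"batch={batch:3d} latent={latent}: {relative_canvas_memory(batch, latent):.2f}x")
```

Under this assumption, halving the batch halves the term, while moving from a 48x48 to a 32x32 canvas cuts it to (32/48)² ≈ 0.44 of the original.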

### Speed Optimization
- Use larger batch sizes on GPU (128-256)
- Enable mixed precision training (add `--amp` flag)
- Use persistent data-loading workers (`persistent_workers=True` on the PyTorch `DataLoader`, which requires `num_workers > 0`)

### Monitoring Training
- Check `{CHECKPOINT_DIR}/training.log` for detailed logs
- Latent images are saved every 10 epochs by default
- Best model is saved automatically based on validation loss
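To check progress without opening Drive, you can print the end of the log from a notebook cell. A minimal sketch, assuming `training.log` is a plain text file at the path given above; `tail` is a hypothetical helper.

```python
from collections import deque


def tail(path: str, n: int = 20) -> list[str]:
    """Hypothetical helper: the last n lines of a text file, read without loading it all at once."""
    with open(path, "r", errors="replace") as f:
        return [line.rstrip("\n") for line in deque(f, maxlen=n)]
```

For example, `print("\n".join(tail(f"{CHECKPOINT_DIR}/training.log")))`. A shell one-liner such as `!tail -n 20 {CHECKPOINT_DIR}/training.log` does the same thing in a notebook cell.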

### Checkpoints in Google Drive
All checkpoints are saved to Google Drive, so:
- They persist across Colab sessions
- You can resume training anytime
- You can download them to your local machine

### Colab Runtime Limits
- Free tier: ~12 hour session limit
- Pro tier: ~24 hour session limit
- If disconnected, rerun the setup cells (mount Drive, clone the repository) and resume from the checkpoint saved in Drive.