# Visual Scene Generation on Google Colab

This notebook sets up and runs the autoregressive visual scene generation system on Google Colab.

## 1. Setup Environment

**IMPORTANT**: If you're re-running the notebook or get nested directories, run this first:

In [None]:
# Clean up any existing setup (run this if you get nested directories)
import os

# Go to content root
%cd /content

# Remove old directory if it exists
!rm -rf visual-scene-generation

print("✅ Environment reset. Proceed to next cell.")

## 2. Clone Repository

## 3. Install Dependencies

## 4. Check GPU Availability
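
A minimal sketch of a GPU check, assuming PyTorch is installed (Colab ships with it). `describe_gpu` is a hypothetical helper, not part of the repository. The import is guarded so the cell still runs where PyTorch is missing.

```python
import importlib.util


def describe_gpu():
    """Return a one-line summary of the device PyTorch would train on."""
    # Guard the import so the cell runs even without PyTorch installed
    if importlib.util.find_spec("torch") is None:
        return "PyTorch is not installed"
    import torch

    if not torch.cuda.is_available():
        return "No GPU detected: enable one via Runtime → Change runtime type"
    props = torch.cuda.get_device_properties(0)
    total_gib = props.total_memory / 1024**3
    return f"GPU: {props.name} ({total_gib:.1f} GiB)"


print(describe_gpu())
```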

## 5. Mount Google Drive for Persistent Storage

**Recommended**: Mount Drive to save your checkpoints permanently. If you skip this step, checkpoints will be lost when the Colab session ends.
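
A minimal sketch using Colab's built-in `google.colab.drive.mount`, which prompts you to authorize access. `mount_drive` is a hypothetical helper. The directory it creates is the Drive location that `load_models_from_checkpoint` searches. This assumes your training run is configured to write checkpoints there.

```python
import os


def mount_drive(mount_point="/content/drive"):
    """Mount Google Drive in Colab and return True if MyDrive is reachable."""
    try:
        from google.colab import drive  # only importable inside Colab
    except ImportError:
        print("⚠️ Not running in Colab; Google Drive was not mounted.")
        return False
    drive.mount(mount_point)
    return os.path.isdir(os.path.join(mount_point, "MyDrive"))


DRIVE_AVAILABLE = mount_drive()

if DRIVE_AVAILABLE:
    # The Drive directory that the model-loading cell searches for checkpoints
    drive_ckpt_dir = "/content/drive/MyDrive/visual-scene-generation/checkpoints"
    os.makedirs(drive_ckpt_dir, exist_ok=True)
    print(f"✅ Drive checkpoint directory ready: {drive_ckpt_dir}")
```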

## 6. Verify Checkpoint Setup

Run this to check where checkpoints will be saved:
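
A minimal sketch of such a check. It assumes training writes `.pt` files to one of the two directories that `load_models_from_checkpoint` searches: `checkpoints` locally, or the Drive copy. `report_checkpoint_dirs` is a hypothetical helper.

```python
import os

# The two locations the model-loading cell searches, in the same order
LOCAL_CHECKPOINT_DIR = "checkpoints"
DRIVE_CHECKPOINT_DIR = "/content/drive/MyDrive/visual-scene-generation/checkpoints"


def report_checkpoint_dirs(dirs=(LOCAL_CHECKPOINT_DIR, DRIVE_CHECKPOINT_DIR)):
    """Print the status of each directory and return {dir: sorted .pt names}."""
    found = {}
    for d in dirs:
        if not os.path.isdir(d):
            print(f"⚠️ {d}: does not exist")
            continue
        found[d] = sorted(f for f in os.listdir(d) if f.endswith(".pt"))
        print(f"✅ {d}: {len(found[d])} checkpoint file(s)")
    return found


found = report_checkpoint_dirs()
if DRIVE_CHECKPOINT_DIR not in found:
    print("Drive directory missing: checkpoints may not persist after the session ends.")
```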

## 7. Quick Test Run (Small Dataset)

## 8. Full Training Run

## 9. Load Trained Models from Checkpoint

In [None]:
import torch
import os
from models import AutoregressiveLanguageModel, SceneDecoder, CaptionNetwork
from data_utils import SceneDescriptionDataset
import matplotlib.pyplot as plt

def load_models_from_checkpoint(checkpoint_path=None):
    """
    Load trained models from checkpoint.
    If no checkpoint specified, finds the latest one.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Find the most recent checkpoint if none was specified
    if checkpoint_path is None:
        # Search the local directory first, then the Google Drive copy
        checkpoint_dirs = ['checkpoints', '/content/drive/MyDrive/visual-scene-generation/checkpoints']
        
        for checkpoint_dir in checkpoint_dirs:
            if os.path.exists(checkpoint_dir):
                checkpoints = [os.path.join(checkpoint_dir, f)
                               for f in os.listdir(checkpoint_dir) if f.endswith('.pt')]
                if checkpoints:
                    # Newest file by modification time; sorting names alphabetically
                    # would rank 'epoch_10.pt' before 'epoch_9.pt'
                    checkpoint_path = max(checkpoints, key=os.path.getmtime)
                    break
        
        if checkpoint_path is None:
            print("❌ No checkpoint found! Train the model first.")
            return None, None, None, None, device
    
    print(f"Loading checkpoint: {checkpoint_path}")
    
    # Load the checkpoint onto whichever device is available
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    # The vocabulary is not read from the checkpoint here, so rebuild the dataset
    # (or read the vocab from the checkpoint if your training run saved it).
    # This assumes the same num_samples and seed reproduce the training vocabulary.
    dataset = SceneDescriptionDataset(num_samples=1000, seed=42)
    vocab_size = dataset.vocab_size
    
    print(f"Vocabulary size: {vocab_size}")
    
    # Rebuild the models with the same hyperparameters used in training;
    # any mismatch makes load_state_dict fail with missing-key or size-mismatch errors
    ar_model = AutoregressiveLanguageModel(
        vocab_size=vocab_size,
        d_model=512,
        n_heads=8,
        n_layers=6
    ).to(device)
    
    scene_decoder = SceneDecoder(
        embedding_dim=512,
        hidden_dim=256,
        use_vae=True,
        z_dim=128
    ).to(device)
    
    caption_network = CaptionNetwork(
        vocab_size=vocab_size,
        embedding_dim=512,
        hidden_dim=256
    ).to(device)
    
    # Load state dicts
    try:
        ar_model.load_state_dict(checkpoint['models']['ar_model'])
        scene_decoder.load_state_dict(checkpoint['models']['scene_decoder'])
        caption_network.load_state_dict(checkpoint['models']['caption_network'])
        
        # Set to eval mode
        ar_model.eval()
        scene_decoder.eval()
        caption_network.eval()
        
        print(f"✅ Models loaded successfully from epoch {checkpoint['epoch']}")
        print(f"   Training loss: {checkpoint['loss']:.4f}")
        
    except Exception as e:
        print(f"❌ Error loading models: {e}")
        return None, None, None, None, device
    
    return ar_model, scene_decoder, caption_network, dataset, device

# Load the models
ar_model, scene_decoder, caption_network, dataset, device = load_models_from_checkpoint()

if ar_model is not None:
    print("\n✅ Ready for scene generation!")
else:
    print("\n⚠️ Please train the model first before running generation.")
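
The loader above expects each `.pt` file to hold a dict with a `models` entry (state dicts keyed `ar_model`, `scene_decoder` and `caption_network`) plus `epoch` and `loss`. If you save checkpoints from your own training loop, a minimal sketch that produces that layout might look like this. `build_checkpoint` and `save_checkpoint` are hypothetical helpers, and the repository's own training script may store additional entries.

```python
import os


def build_checkpoint(ar_model, scene_decoder, caption_network, epoch, loss):
    """Assemble the dict layout that load_models_from_checkpoint reads."""
    return {
        "models": {
            "ar_model": ar_model.state_dict(),
            "scene_decoder": scene_decoder.state_dict(),
            "caption_network": caption_network.state_dict(),
        },
        "epoch": epoch,
        "loss": float(loss),  # the loader prints this with :.4f
    }


def save_checkpoint(path, ar_model, scene_decoder, caption_network, epoch, loss):
    """Write a checkpoint with torch.save, creating the directory if needed."""
    import torch  # imported here so build_checkpoint works without PyTorch

    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    checkpoint = build_checkpoint(ar_model, scene_decoder, caption_network, epoch, loss)
    torch.save(checkpoint, path)
```

For example, `save_checkpoint('/content/drive/MyDrive/visual-scene-generation/checkpoints/epoch_5.pt', ar_model, scene_decoder, caption_network, epoch=5, loss=avg_loss)` writes to the Drive directory the loader searches. Here `avg_loss` is a placeholder for whatever loss value your loop tracks.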

## 10. Interactive Scene Generation

In [None]:
# Generate a scene from your own text prompt.
# generate_scene_from_text must already be defined in an earlier cell.
# Change this to whatever you want:
custom_text = "a tiny red sphere on the right"

if ar_model is not None:
    print(f"Generating scene for: '{custom_text}'\n")
    scene, caption = generate_scene_from_text(
        custom_text, 
        ar_model, 
        scene_decoder, 
        caption_network, 
        dataset, 
        device
    )
else:
    print("⚠️ Load models first (run the loading cell in section 9)")

## 11. Resume Training from Checkpoint

## Tips for Colab

1. **Enable GPU**: Go to Runtime → Change runtime type → Hardware accelerator → GPU
2. **Prevent Disconnection**: Keep the tab active or use Colab Pro for longer sessions
3. **Save Progress**: Regularly save checkpoints to Google Drive
4. **Monitor Memory**: Reduce the batch size if you encounter out-of-memory (OOM) errors
5. **Use Mixed Precision**: Add the `--use_amp` flag to the training command for faster training with lower memory use
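
To act on tip 4, you can print how much GPU memory PyTorch is using between training steps. The following is a minimal sketch built on `torch.cuda.memory_allocated`, `torch.cuda.memory_reserved` and `torch.cuda.max_memory_allocated`. `gpu_memory_report` is a hypothetical helper.

```python
import importlib.util


def gpu_memory_report(device=0):
    """Return current, peak and total GPU memory in GiB as a printable string."""
    # Guard the import so the cell runs even without PyTorch installed
    if importlib.util.find_spec("torch") is None:
        return "PyTorch is not installed"
    import torch

    if not torch.cuda.is_available():
        return "No GPU available"
    gib = 1024**3
    allocated = torch.cuda.memory_allocated(device) / gib
    reserved = torch.cuda.memory_reserved(device) / gib
    peak = torch.cuda.max_memory_allocated(device) / gib
    total = torch.cuda.get_device_properties(device).total_memory / gib
    return (f"allocated {allocated:.2f} GiB | reserved {reserved:.2f} GiB | "
            f"peak {peak:.2f} GiB | total {total:.2f} GiB")


print(gpu_memory_report())
```

If the peak value approaches the total, lower the batch size or enable mixed precision before the next run.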