# Visual Scene Generation on Google Colab

This notebook sets up and runs the autoregressive visual scene generation system.

## 1. (Optional) Reset Environment

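Run this cell if you are re-running the notebook or see nested `visual-scene-generation` directories. It deletes the local clone under `/content` so the next section starts from a clean checkout: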

In [None]:
# Clean up any existing setup
import os
%cd /content
!rm -rf visual-scene-generation
print("‚úÖ Environment reset")

## 2. Clone Repository

In [None]:
# Clone repository to specific location
import os

if not os.path.exists('/content/visual-scene-generation'):
    print("Cloning repository...")
    !git clone https://github.com/jtooates/visual-scene-generation.git /content/visual-scene-generation
else:
    print("Repository exists, pulling latest...")
    !cd /content/visual-scene-generation && git pull

%cd /content/visual-scene-generation
print(f"\n‚úÖ Working directory: {os.getcwd()}")
!ls -1

## 3. Install Dependencies

In [None]:
!pip install -q tqdm matplotlib scikit-learn

## 4. Check GPU

In [None]:
import torch

if torch.cuda.is_available():
    print(f"‚úÖ GPU: {torch.cuda.get_device_name(0)}")
    !nvidia-smi
else:
    print("‚ö†Ô∏è No GPU! Go to Runtime ‚Üí Change runtime type ‚Üí GPU")

## 5. Mount Google Drive (CRITICAL for saving checkpoints!)

In [None]:
# Persist checkpoints/ and logs/ to Google Drive by replacing the local
# directories with symlinks into Drive. Filesystem work is done with os/shutil
# instead of `!` shell commands: shell commands never raise, so failures would
# bypass the except branch and the success message would print regardless.
import os
import shutil

repo_dir = '/content/visual-scene-generation'
drive_root = '/content/drive/MyDrive/visual-scene-generation'
persisted_dirs = ('checkpoints', 'logs')

# Without the clone there is nothing to link into; fail loudly rather than
# creating a stray directory that the clone cell would mistake for a checkout.
if not os.path.isdir(repo_dir):
    raise FileNotFoundError(
        f"{repo_dir} not found - run the 'Clone Repository' cell first"
    )

try:
    from google.colab import drive
    drive.mount('/content/drive')

    for name in persisted_dirs:
        drive_path = os.path.join(drive_root, name)
        local_path = os.path.join(repo_dir, name)

        # Create directories on Google Drive
        os.makedirs(drive_path, exist_ok=True)

        # Remove whatever currently occupies the local path. A symlink left by
        # an earlier run is unlinked only; its Drive target is untouched.
        if os.path.islink(local_path) or os.path.isfile(local_path):
            os.remove(local_path)
        elif os.path.isdir(local_path):
            shutil.rmtree(local_path)

        # Create symlinks: local -> Drive
        os.symlink(drive_path, local_path)

    print("‚úÖ Google Drive mounted and symlinks created!")
    print("üìÅ Checkpoints ‚Üí /content/drive/MyDrive/visual-scene-generation/checkpoints")
    print("üìÅ Logs ‚Üí /content/drive/MyDrive/visual-scene-generation/logs")

except Exception as e:
    print(f"‚ö†Ô∏è Drive mount failed: {e}")
    print("Creating local directories (will be lost when session ends)")
    for name in persisted_dirs:
        local_path = os.path.join(repo_dir, name)
        # A dangling link (Drive target unavailable) would block directory
        # creation, so drop it before falling back to a plain directory.
        if os.path.islink(local_path) and not os.path.exists(local_path):
            os.remove(local_path)
        os.makedirs(local_path, exist_ok=True)

## 6. Verify Setup

In [None]:
import os

def describe_persisted_dir(name):
    """Return a one-line status message for directory ``name`` in the cwd.

    Distinguishes four cases: a symlink whose target exists, a symlink whose
    target is missing (for example Drive is not mounted), a regular directory,
    and a path that does not exist at all.
    """
    if os.path.islink(name):
        target = os.readlink(name)
        # os.path.exists follows the link, so it is False for a dangling link.
        if os.path.exists(name):
            return f"‚úÖ {name}/ ‚Üí {target}"
        return f"‚ö†Ô∏è {name}/ ‚Üí {target} (link target is missing - is Drive mounted?)"
    if os.path.exists(name):
        return f"‚ö†Ô∏è {name}/ is a regular directory (not linked to Drive!)"
    return f"‚ùå {name}/ doesn't exist (will be created by training script)"


print(f"Current directory: {os.getcwd()}\n")

# Check checkpoints and logs with the same logic
for persisted_name in ('checkpoints', 'logs'):
    print(describe_persisted_dir(persisted_name))

print("\n" + "="*60)
print("If symlinks are shown, checkpoints WILL persist to Drive!")
print("="*60)

## 7. Quick Training (5 epochs, ~5-10 min)

In [None]:
# Quick run of train.py with the small settings below (5 epochs, 1000 samples)
# to check the pipeline end to end before the longer run.
# NOTE(review): no model-size flags are passed, so train.py's own defaults
# apply; the load_checkpoint() cell later in this notebook builds models with
# d_model=512 / hidden_dim=256 - confirm those defaults match.
!python train.py \
    --epochs 5 \
    --batch_size 16 \
    --num_samples 1000 \
    --use_vae \
    --lr 0.0001 \
    --lambda_kl 0.001 \
    --log_interval 5

## 8. Full Training (50 epochs, ~30-60 min)

In [None]:
# Longer run of train.py: 50 epochs over 10000 samples, with the model sizes
# (d_model, hidden_dim, z_dim) and loss weights (lambda_*) set explicitly.
# NOTE(review): the load_checkpoint() cell later in this notebook hard-codes
# d_model=512 / hidden_dim=256 and rebuilds the vocabulary from 1000 samples -
# presumably compatible with these settings; verify against train.py.
!python train.py \
    --epochs 50 \
    --batch_size 16 \
    --num_samples 10000 \
    --use_vae \
    --lr 0.00005 \
    --d_model 512 \
    --hidden_dim 256 \
    --z_dim 128 \
    --lambda_consistency 1.0 \
    --lambda_spatial 0.1 \
    --lambda_kl 0.001

## 9. View Results

In [None]:
from IPython.display import Image, display
import os

if os.path.exists('logs/training_curves.png'):
    print("Training Curves:")
    display(Image('logs/training_curves.png'))

if os.path.exists('sample_generations.png'):
    print("\nSample Generations:")
    display(Image('sample_generations.png'))

print("\nüìÅ Checkpoints:")
!ls -lh checkpoints/ 2>/dev/null || echo "No checkpoints yet"

print("\nüìÅ Logs:")
!ls -lh logs/ 2>/dev/null || echo "No logs yet"

## 10. Load Models and Generate Scenes

In [None]:
import torch
import os
from models import AutoregressiveLanguageModel, SceneDecoder, CaptionNetwork
from data_utils import SceneDescriptionDataset
import matplotlib.pyplot as plt

def _find_latest_checkpoint(checkpoint_dirs):
    """Return the most recently modified ``.pt`` file, or None if there is none.

    Directories are tried in order and the first one that contains at least
    one ``.pt`` file wins. Within that directory the newest file is chosen by
    modification time: sorting file names lexicographically would rank
    ``..._9.pt`` above ``..._10.pt``.
    """
    for cp_dir in checkpoint_dirs:
        if not os.path.isdir(cp_dir):
            continue
        candidates = [
            os.path.join(cp_dir, name)
            for name in os.listdir(cp_dir)
            if name.endswith('.pt')
        ]
        if candidates:
            # The path is a tie-breaker so the choice is deterministic.
            return max(candidates, key=lambda path: (os.path.getmtime(path), path))
    return None


def load_checkpoint(checkpoint_dirs=None):
    """Load the newest checkpoint and rebuild the three models in eval mode.

    Args:
        checkpoint_dirs: Optional iterable of directories to search, in order.
            Defaults to the local ``checkpoints`` directory followed by the
            Google Drive checkpoint directory.

    Returns:
        ``(ar_model, scene_decoder, caption_network, dataset, device)``.
        When no ``.pt`` file is found, the first four entries are ``None`` and
        only ``device`` is populated.

    The checkpoint is read as a dict with a ``'models'`` entry holding the
    ``'ar_model'``, ``'scene_decoder'`` and ``'caption_network'`` state dicts,
    plus an ``'epoch'`` entry.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if checkpoint_dirs is None:
        checkpoint_dirs = (
            'checkpoints',
            '/content/drive/MyDrive/visual-scene-generation/checkpoints',
        )

    # Find latest checkpoint
    checkpoint_path = _find_latest_checkpoint(checkpoint_dirs)
    if not checkpoint_path:
        print("‚ùå No checkpoint found! Train first.")
        return None, None, None, None, device

    print(f"Loading: {checkpoint_path}")
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Recreate dataset for vocab
    # NOTE(review): presumably the vocabulary does not depend on num_samples,
    # so this matches the one used during training - confirm in data_utils.
    dataset = SceneDescriptionDataset(num_samples=1000, seed=42)

    # Initialize models
    # NOTE(review): the sizes below are hard-coded and must match the values
    # used for training, otherwise load_state_dict will fail on shape mismatch.
    ar_model = AutoregressiveLanguageModel(
        vocab_size=dataset.vocab_size, d_model=512
    ).to(device)

    scene_decoder = SceneDecoder(
        embedding_dim=512, hidden_dim=256, use_vae=True
    ).to(device)

    caption_network = CaptionNetwork(
        vocab_size=dataset.vocab_size, embedding_dim=512
    ).to(device)

    # Load weights
    ar_model.load_state_dict(checkpoint['models']['ar_model'])
    scene_decoder.load_state_dict(checkpoint['models']['scene_decoder'])
    caption_network.load_state_dict(checkpoint['models']['caption_network'])

    # Inference only from here on.
    ar_model.eval()
    scene_decoder.eval()
    caption_network.eval()

    print(f"‚úÖ Loaded from epoch {checkpoint['epoch']}")
    return ar_model, scene_decoder, caption_network, dataset, device

# Load once at cell level so the following cells can use these as notebook
# globals; the four model/dataset values are None when no checkpoint was found.
ar_model, scene_decoder, caption_network, dataset, device = load_checkpoint()

## 11. Generate Custom Scenes

In [None]:
def generate_scene(text):
    """Decode ``text`` into a scene and plot it with the caption read back from it.

    Uses the notebook globals ``ar_model``, ``scene_decoder``,
    ``caption_network``, ``dataset`` and ``device``. Prints a hint and returns
    without plotting when no model has been loaded.

    Args:
        text: Whitespace-separated description; it is lower-cased and words
            missing from ``dataset.vocab`` are mapped to ``<UNK>``.
    """
    if ar_model is None:
        print("‚ùå Load models first!")
        return

    vocab = dataset.vocab
    with torch.no_grad():
        # Tokenize: <SOS> + word ids + <EOS>, as a batch containing one sequence.
        word_ids = [vocab.get(word, vocab['<UNK>']) for word in text.lower().split()]
        token_ids = [vocab['<SOS>'], *word_ids, vocab['<EOS>']]
        input_ids = torch.tensor([token_ids], dtype=torch.long).to(device)

        # Generate: text -> embeddings -> scene -> caption.
        embeddings = ar_model(input_ids, return_embeddings=True)['embeddings']
        scene = scene_decoder(embeddings)['scene']
        caption_tokens, _ = caption_network.generate_caption(scene)
        reconstructed = dataset.decode_tokens(caption_tokens[0])

        # Move the channel axis last, which is the layout imshow expects.
        scene_np = scene[0].cpu().permute(1, 2, 0).numpy()

    # Visualize: the same scene twice, titled with the input and the caption.
    _, axes = plt.subplots(1, 2, figsize=(12, 5))
    panel_titles = (f"Input: {text}", f"Reconstructed: {reconstructed}")
    for ax, panel_title in zip(axes, panel_titles):
        ax.imshow(scene_np)
        ax.set_title(panel_title)
        ax.axis('off')
    plt.show()

# Try it! Runs only when a checkpoint was loaded.
demo_prompts = ("red ball in center", "blue cube on left", "yellow sphere floating")
if ar_model is not None:
    for text in demo_prompts:
        generate_scene(text)