# Visual Scene Generation on Google Colab

This notebook sets up and runs the autoregressive visual scene generation system on Google Colab.

## 1. Clone the Repository

In [None]:
# Clone the repository
!git clone https://github.com/jtooates/visual-scene-generation.git
%cd visual-scene-generation

## 2. Install Dependencies

In [None]:
# Install required packages (most are pre-installed in Colab)
!pip install -q tqdm matplotlib scikit-learn

## 3. Check GPU Availability

In [None]:
import torch
import subprocess

# Check if GPU is available
if torch.cuda.is_available():
    device = torch.cuda.get_device_name(0)
    print(f"GPU Available: {device}")
    !nvidia-smi
else:
    print("No GPU available. Using CPU.")
    print("Go to Runtime > Change runtime type > GPU for better performance")

## 4. (Optional) Mount Google Drive for Persistent Storage

**Recommended**: Mount Drive to save your checkpoints permanently. If you skip this, checkpoints will be lost when the session ends!

In [None]:
# Mount Google Drive and redirect checkpoints/ and logs/ to it.
#
# The filesystem work is done with os/shutil instead of `!` shell commands:
# a failing shell command only prints an error and never raises, so the
# `except` fallback below would not run and the success message would be
# printed even if a symlink had not been created.
import os
import shutil

LOCAL_PROJECT_DIR = '/content/visual-scene-generation'
DRIVE_PROJECT_DIR = '/content/drive/MyDrive/visual-scene-generation'
PERSISTENT_SUBDIRS = ('checkpoints', 'logs')

try:
    from google.colab import drive
    drive.mount('/content/drive')

    for subdir in PERSISTENT_SUBDIRS:
        drive_path = os.path.join(DRIVE_PROJECT_DIR, subdir)
        local_path = os.path.join(LOCAL_PROJECT_DIR, subdir)

        # Create the directory on Google Drive first
        os.makedirs(drive_path, exist_ok=True)

        # Remove whatever currently sits at the local path. A symlink is only
        # unlinked (its target on Drive is left alone); a real directory is
        # deleted together with its contents, as `rm -rf` did.
        if os.path.islink(local_path):
            os.unlink(local_path)
        elif os.path.isdir(local_path):
            shutil.rmtree(local_path)
        elif os.path.lexists(local_path):
            os.remove(local_path)

        # Create the symlink FROM local TO Drive (so writes go to Drive)
        os.symlink(drive_path, local_path)

    # Verify setup
    print("✅ Google Drive mounted! Checkpoints will be saved persistently.")
    print(f"📁 Checkpoints: {DRIVE_PROJECT_DIR}/checkpoints")
    print(f"📁 Logs: {DRIVE_PROJECT_DIR}/logs")
    print("\nVerifying symlinks:")
    for subdir in PERSISTENT_SUBDIRS:
        local_path = os.path.join(LOCAL_PROJECT_DIR, subdir)
        print(f"  {local_path} -> {os.readlink(local_path)}")
except Exception as e:
    print(f"⚠️ Google Drive setup failed: {e}")
    print("   Creating local directories - checkpoints will be temporary!")
    for subdir in PERSISTENT_SUBDIRS:
        local_path = os.path.join(LOCAL_PROJECT_DIR, subdir)
        # A dangling link (e.g. left over from a session where Drive was
        # mounted) would make makedirs fail, so drop it first.
        if os.path.islink(local_path) and not os.path.exists(local_path):
            os.unlink(local_path)
        os.makedirs(local_path, exist_ok=True)

## 5. Verify Checkpoint and Log Directories

In [None]:
import os


def describe_output_dir(name):
    """Return the report lines describing the state of directory `name`.

    `name` is a path relative to the current working directory. Four states
    are distinguished:

    * symlink whose target exists   -> link target + resolved path
    * symlink whose target is gone  -> warning (os.path.exists() alone would
      misreport this as "does not exist", since it follows the link)
    * regular directory             -> warning that it is not linked to Drive
    * nothing at that path          -> "does not exist yet"
    """
    if os.path.islink(name):
        target = os.readlink(name)
        if not os.path.exists(name):
            return [
                f"⚠️ {name}/ is a symlink to {target}, "
                "but that target does not exist (is Drive mounted?)"
            ]
        lines = [f"✅ {name}/ is a symlink to: {target}"]
    elif os.path.exists(name):
        lines = [f"⚠️ {name}/ is a regular directory (not linked to Drive)"]
    else:
        return [f"❌ {name}/ does not exist yet"]

    lines.append(f"   Resolved path: {os.path.realpath(name)}")
    return lines


print("Current working directory:", os.getcwd())

# Same check for both output directories; the heading is the only difference.
for heading, dir_name in (("Checkpoint", "checkpoints"), ("Logs", "logs")):
    print(f"\n{heading} directory check:")
    print("\n".join(describe_output_dir(dir_name)))

print("\n" + "="*60)
print("If directories are symlinked to Drive, checkpoints will persist!")
print("="*60)

## 6. Quick Test Run (Small Dataset)

In [ ]:
# Smoke test: run train.py with a small configuration (5 epochs, batch size 16,
# 1000 samples, VAE enabled) to verify that everything works before starting a
# long training run.
# NOTE(review): the exact meaning of each flag is defined by train.py's argument
# parser, which is not shown in this notebook - confirm there before changing values.
!python train.py \
    --epochs 5 \
    --batch_size 16 \
    --num_samples 1000 \
    --use_vae \
    --log_interval 5 \
    --lr 0.0001 \
    --lambda_kl 0.0001

## 7. Full Training Run

## 8. Visualize Training Results

## 9. Interactive Scene Generation

## 8. Interactive Scene Generation

In [None]:
import torch
import os
from models import AutoregressiveLanguageModel, SceneDecoder, CaptionNetwork
from data_utils import SceneDescriptionDataset
import matplotlib.pyplot as plt

def _checkpoint_sort_key(filename):
    """Sort key that compares numbers inside a filename numerically.

    Plain string sorting puts 'checkpoint_epoch_10.pt' before
    'checkpoint_epoch_9.pt'; with this key, epoch 10 sorts after epoch 9.
    """
    import re  # local import so this cell does not depend on another cell's imports

    # re.split with a capturing group alternates text / digits / text / ...,
    # so odd positions are always digit runs and even positions always text.
    parts = re.split(r'(\d+)', filename)
    return [int(part) if index % 2 else part for index, part in enumerate(parts)]


def _find_latest_checkpoint(checkpoint_dirs):
    """Return the highest-sorting '.pt' file in the first directory that has one, else None."""
    for checkpoint_dir in checkpoint_dirs:
        if not os.path.isdir(checkpoint_dir):
            continue
        candidates = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pt')]
        if candidates:
            return os.path.join(checkpoint_dir, max(candidates, key=_checkpoint_sort_key))
    return None


def load_models_from_checkpoint(checkpoint_path=None):
    """
    Load trained models from checkpoint.
    If no checkpoint specified, finds the latest one.

    Args:
        checkpoint_path: Path to a '.pt' checkpoint. When None, the local
            'checkpoints' directory and then the Google Drive checkpoint
            directory are searched, and the file whose name sorts last
            (numbers compared numerically) is used.

    Returns:
        (ar_model, scene_decoder, caption_network, dataset, device).
        When no checkpoint is found, or the state dicts cannot be loaded,
        the first four entries are None and only `device` is set.

    Raises:
        Whatever torch.load raises if the checkpoint file cannot be read.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Find checkpoint if not specified
    if checkpoint_path is None:
        # Check both local and Drive paths
        checkpoint_path = _find_latest_checkpoint(
            ['checkpoints', '/content/drive/MyDrive/visual-scene-generation/checkpoints']
        )
        if checkpoint_path is None:
            print("❌ No checkpoint found! Train the model first.")
            return None, None, None, None, device

    print(f"Loading checkpoint: {checkpoint_path}")

    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # We need to create a dataset to get vocab (or extract from checkpoint if saved)
    # For now, recreate dataset - it will have same vocab if same seed
    # NOTE(review): assumes training used a dataset that yields the same
    # vocabulary as num_samples=1000, seed=42 - confirm against train.py.
    dataset = SceneDescriptionDataset(num_samples=1000, seed=42)
    vocab_size = dataset.vocab_size

    print(f"Vocabulary size: {vocab_size}")

    # Initialize models with same architecture as training.
    # NOTE(review): these hyperparameters are hardcoded here; they must match
    # the configuration the checkpoint was trained with, otherwise
    # load_state_dict below fails and the function returns None models.
    ar_model = AutoregressiveLanguageModel(
        vocab_size=vocab_size,
        d_model=512,  # Use default or match your training config
        n_heads=8,
        n_layers=6
    ).to(device)

    scene_decoder = SceneDecoder(
        embedding_dim=512,
        hidden_dim=256,
        use_vae=True,
        z_dim=128
    ).to(device)

    caption_network = CaptionNetwork(
        vocab_size=vocab_size,
        embedding_dim=512,
        hidden_dim=256
    ).to(device)

    # Load state dicts
    try:
        ar_model.load_state_dict(checkpoint['models']['ar_model'])
        scene_decoder.load_state_dict(checkpoint['models']['scene_decoder'])
        caption_network.load_state_dict(checkpoint['models']['caption_network'])

        # Set to eval mode
        ar_model.eval()
        scene_decoder.eval()
        caption_network.eval()

        print(f"✅ Models loaded successfully from epoch {checkpoint['epoch']}")
        print(f"   Training loss: {checkpoint['loss']:.4f}")

    except Exception as e:
        print(f"❌ Error loading models: {e}")
        return None, None, None, None, device

    return ar_model, scene_decoder, caption_network, dataset, device

# Restore the trained networks (falls back to the newest checkpoint on disk)
ar_model, scene_decoder, caption_network, dataset, device = load_models_from_checkpoint()

# Tell the reader whether the generation cells below can be used
if ar_model is None:
    print("\n‚ö†Ô∏è Please train the model first before running generation.")
else:
    print("\n‚úÖ Ready for scene generation!")

In [ ]:
def generate_scene_from_text(text, ar_model, scene_decoder, caption_network, dataset, device):
    """Generate a scene from custom text input.

    The text is lower-cased, split on whitespace and mapped through
    `dataset.vocab` (unknown words become '<UNK>'), wrapped in '<SOS>'/'<EOS>',
    embedded by `ar_model`, decoded into a scene by `scene_decoder`, and
    captioned again by `caption_network`. The scene is shown twice side by
    side: once titled with the input text and once with the generated caption.

    Args:
        text: Scene description to visualise.
        ar_model, scene_decoder, caption_network: Loaded models; if `ar_model`
            is None, nothing is generated.
        dataset: Object providing `vocab` (word -> id mapping) and
            `decode_tokens`.
        device: Torch device the input tensor is moved to.

    Returns:
        (scene_np, reconstructed_text). When the models are not loaded this is
        (None, None), so callers can always unpack two values.
    """
    if ar_model is None:
        print("❌ Models not loaded. Run the previous cell first!")
        return None, None

    ar_model.eval()
    scene_decoder.eval()
    caption_network.eval()

    with torch.no_grad():
        # Tokenize input text (the <UNK> lookup is loop-invariant, so do it once)
        vocab = dataset.vocab
        unk_id = vocab['<UNK>']
        word_ids = [vocab.get(word, unk_id) for word in text.lower().split()]
        tokens = [vocab['<SOS>']] + word_ids + [vocab['<EOS>']]
        input_ids = torch.tensor([tokens], dtype=torch.long).to(device)

        # Generate embedding
        ar_outputs = ar_model(input_ids, return_embeddings=True)
        text_embedding = ar_outputs['embeddings']

        # Generate scene
        scene_outputs = scene_decoder(text_embedding)
        scene = scene_outputs['scene']

        # Generate caption from scene
        generated_caption, _ = caption_network.generate_caption(scene)
        reconstructed_text = dataset.decode_tokens(generated_caption[0])

        # First item of the batch, moved to CPU with its first axis moved last
        # (presumably channels-first -> channels-last for imshow; confirm
        # against SceneDecoder's output layout).
        scene_np = scene[0].cpu().permute(1, 2, 0).numpy()

    # Visualize: same image in both panels, titled with input text and caption
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    panel_titles = (f"Original: {text}", f"Reconstructed: {reconstructed_text}")
    for ax, panel_title in zip(axes, panel_titles):
        ax.imshow(scene_np)
        ax.set_title(panel_title, fontsize=12, wrap=True)
        ax.axis('off')

    fig.tight_layout()
    plt.show()

    return scene_np, reconstructed_text

# Demo: render a handful of sample descriptions (requires the models to be loaded)
if ar_model is None:
    print("‚ö†Ô∏è Skip this cell - models not loaded yet")
else:
    test_texts = [
        "a red ball in the center",
        "blue cube on the left",
        "yellow sphere floating",
        "large green triangle",
    ]

    print("Generating scenes for test inputs...\n")
    for text in test_texts:
        print(f"Input: {text}")
        generate_scene_from_text(
            text, ar_model, scene_decoder, caption_network, dataset, device
        )
        # Visual separator between consecutive samples
        print("-" * 50)

In [None]:
# Try a description of your own - edit the string below and re-run the cell.
custom_text = "a tiny red sphere on the right"

if ar_model is None:
    print("‚ö†Ô∏è Load models first (run cell 8)")
else:
    print(f"Generating scene for: '{custom_text}'\n")
    # Keep the rendered image and the caption around for further inspection
    scene, caption = generate_scene_from_text(
        custom_text, ar_model, scene_decoder, caption_network, dataset, device
    )

## 9. Resume Training from Checkpoint

## Tips for Colab

1. **Enable GPU**: Go to Runtime → Change runtime type → Hardware accelerator → GPU
2. **Prevent Disconnection**: Keep the tab active or use Colab Pro for longer sessions
3. **Save Progress**: Regularly save checkpoints to Google Drive
4. **Monitor Memory**: Use smaller batch sizes if you encounter OOM errors
5. **Use Mixed Precision**: Add `--use_amp` flag for faster training with less memory