# Visual Scene Generation on Google Colab

This notebook sets up and runs the autoregressive visual scene generation system on Google Colab.

## 1. Clone the Repository

In [None]:
# Clone the repository
!git clone https://github.com/jtooates/visual-scene-generation.git
%cd visual-scene-generation

## 2. Install Dependencies

In [None]:
# Install required packages (most are pre-installed in Colab)
!pip install -q tqdm matplotlib scikit-learn

## 3. Check GPU Availability

In [None]:
import torch
import subprocess

# Check if GPU is available
if torch.cuda.is_available():
    device = torch.cuda.get_device_name(0)
    print(f"GPU Available: {device}")
    !nvidia-smi
else:
    print("No GPU available. Using CPU.")
    print("Go to Runtime > Change runtime type > GPU for better performance")

## 4. Quick Test Run (Small Dataset)

In [None]:
# Run a quick test with small dataset to verify everything works
!python train.py \
    --epochs 5 \
    --batch_size 16 \
    --num_samples 1000 \
    --use_vae \
    --log_interval 5

## 5. Full Training Run

In [None]:
# Full training with recommended settings
!python train.py \
    --epochs 50 \
    --batch_size 32 \
    --num_samples 10000 \
    --use_vae \
    --use_amp \
    --lr 0.001 \
    --d_model 512 \
    --hidden_dim 256 \
    --z_dim 128 \
    --lambda_consistency 1.0 \
    --lambda_spatial 0.1 \
    --lambda_kl 0.001

## 6. Visualize Training Results

In [None]:
import matplotlib.pyplot as plt
from IPython.display import Image, display
import os

# Each entry pairs an image that training may write with the message shown
# when that image does not exist yet. They are displayed in this order.
result_images = [
    ('logs/training_curves.png', "Training curves not yet generated"),
    ('sample_generations.png', "Sample generations not yet created"),
]

for image_path, missing_message in result_images:
    if not os.path.exists(image_path):
        print(missing_message)
        continue
    display(Image(image_path))

## 7. Interactive Scene Generation

In [None]:
import os
import torch
from models import AutoregressiveLanguageModel, SceneDecoder, CaptionNetwork
from data_utils import SceneDescriptionDataset
from visualization import visualize_scenes
import matplotlib.pyplot as plt


def load_trained_models(checkpoint_path='checkpoints/checkpoint_epoch_49.pt',
                        d_model=512, hidden_dim=256, use_vae=True):
    """Rebuild the language, scene and caption models and load trained weights.

    Args:
        checkpoint_path: Checkpoint file to load. It must contain a ``'models'``
            dict with ``'ar_model'``, ``'scene_decoder'`` and
            ``'caption_network'`` state dicts.
        d_model: Embedding width used for the AR model's ``d_model``, the scene
            decoder's ``embedding_dim`` and the caption network's
            ``embedding_dim``. Should equal the ``--d_model`` used in training.
        hidden_dim: Scene decoder hidden size; should equal ``--hidden_dim``.
        use_vae: Passed to ``SceneDecoder``; should match whether training ran
            with ``--use_vae``.

    Returns:
        Tuple ``(ar_model, scene_decoder, caption_network, dataset, device)``.
        If ``checkpoint_path`` does not exist, the models keep their random
        initialization and a message naming the missing path is printed.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Built for its vocabulary: vocab_size sizes the models here, and the
    # returned dataset is what callers use to tokenize/decode text.
    # NOTE(review): if the vocabulary is derived from the sampled descriptions,
    # a 100-sample dataset could yield a different vocab than the 10k-sample
    # training run, which would break load_state_dict on embedding layers.
    # Confirm.
    dataset = SceneDescriptionDataset(num_samples=100)

    ar_model = AutoregressiveLanguageModel(
        vocab_size=dataset.vocab_size,
        d_model=d_model
    ).to(device)

    # NOTE(review): training passes --z_dim 128, but no z_dim is given here.
    # This relies on SceneDecoder's default latent size matching. Confirm.
    scene_decoder = SceneDecoder(
        embedding_dim=d_model,
        hidden_dim=hidden_dim,
        use_vae=use_vae
    ).to(device)

    caption_network = CaptionNetwork(
        vocab_size=dataset.vocab_size,
        embedding_dim=d_model
    ).to(device)

    if os.path.exists(checkpoint_path):
        # NOTE(review): torch >= 2.6 defaults to weights_only=True. If the
        # checkpoint also stores non-tensor objects (e.g. training args), this
        # load may need weights_only=False. Only do that for trusted files.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        state_dicts = checkpoint['models']
        ar_model.load_state_dict(state_dicts['ar_model'])
        scene_decoder.load_state_dict(state_dicts['scene_decoder'])
        caption_network.load_state_dict(state_dicts['caption_network'])
        print(f"Loaded checkpoint from {checkpoint_path}")
    else:
        print(f"No checkpoint found at {checkpoint_path}, using random initialization")

    return ar_model, scene_decoder, caption_network, dataset, device


# Load models
ar_model, scene_decoder, caption_network, dataset, device = load_trained_models()

In [None]:
def generate_scene_from_text(text, ar_model, scene_decoder, caption_network, dataset, device):
    """Generate a scene from custom text, caption it back, and plot the result.

    Pipeline:
    1. The text is lower-cased, split on whitespace and mapped through
       ``dataset.vocab``; unknown words become ``<UNK>``.
    2. The token ids are wrapped in ``<SOS>`` ... ``<EOS>``.
    3. The AR model produces a text embedding.
    4. The scene decoder renders a scene from that embedding.
    5. The caption network describes the rendered scene.

    Args:
        text: Free-form description, e.g. "a red ball in the center".
        ar_model, scene_decoder, caption_network: Models to run. All three are
            switched to eval mode.
        dataset: Provides ``vocab`` (word -> id) and ``decode_tokens``.
        device: Device the input ids are moved to; should match the models.

    Returns:
        The caption text generated from the rendered scene.
    """
    ar_model.eval()
    scene_decoder.eval()
    caption_network.eval()

    vocab = dataset.vocab
    unk_id = vocab['<UNK>']

    with torch.no_grad():
        token_ids = [vocab['<SOS>']]
        token_ids += [vocab.get(word, unk_id) for word in text.lower().split()]
        token_ids.append(vocab['<EOS>'])
        input_ids = torch.tensor([token_ids], dtype=torch.long).to(device)

        text_embedding = ar_model(input_ids, return_embeddings=True)['embeddings']
        scene = scene_decoder(text_embedding)['scene']

        generated_caption, _ = caption_network.generate_caption(scene)
        reconstructed_text = dataset.decode_tokens(generated_caption[0])

        # First item of the batch, channels moved last for imshow.
        # Assumes scene is (batch, C, H, W). TODO confirm.
        scene_np = scene[0].cpu().permute(1, 2, 0).numpy()

    # There is only one image: the scene rendered from `text`. The
    # "reconstruction" is the caption text, so show both strings on a single
    # panel instead of drawing the same image twice.
    # NOTE(review): imshow clips float RGB data to [0, 1]. Confirm the
    # decoder's output range.
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(scene_np)
    ax.set_title(f"Input: {text}\nCaption: {reconstructed_text}")
    ax.axis('off')
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # avoid accumulating open figures when called in a loop

    return reconstructed_text

# A few hand-written prompts to exercise the text -> scene -> caption loop
test_texts = [
    "a red ball in the center",
    "blue cube on the left",
    "yellow sphere floating",
    "large green triangle",
]

for prompt in test_texts:
    generate_scene_from_text(
        prompt, ar_model, scene_decoder, caption_network, dataset, device
    )

## 8. Save Models to Google Drive (Optional)

In [None]:
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# Copy checkpoints to Drive
!cp -r checkpoints /content/drive/MyDrive/visual-scene-generation-checkpoints
!cp -r logs /content/drive/MyDrive/visual-scene-generation-logs
print("Models saved to Google Drive!")

## 9. Resume Training from Checkpoint

In [None]:
# Resume training from a saved checkpoint
!python train.py \
    --resume checkpoints/checkpoint_epoch_49.pt \
    --epochs 100 \
    --batch_size 32 \
    --num_samples 10000 \
    --use_vae \
    --use_amp

## Tips for Colab

1. **Enable GPU**: Go to Runtime → Change runtime type → Hardware accelerator → GPU
2. **Prevent Disconnection**: Keep the tab active or use Colab Pro for longer sessions
3. **Save Progress**: Colab's local disk is wiped when the runtime is recycled, so copy checkpoints to Google Drive regularly (see Section 8)
4. **Monitor Memory**: If you hit CUDA out-of-memory (OOM) errors, lower `--batch_size`
5. **Use Mixed Precision**: Add the `--use_amp` flag (automatic mixed precision) for faster GPU training with lower memory use