# Visual Scene Generation on Google Colab

This notebook sets up and runs the autoregressive visual scene generation system on Google Colab.

## 1. Clone the Repository

In [None]:
# Clone the repository
!git clone https://github.com/jtooates/visual-scene-generation.git
%cd visual-scene-generation

## 2. Install Dependencies

In [None]:
# Install required packages (most are pre-installed in Colab)
!pip install -q tqdm matplotlib scikit-learn

## 3. Check GPU Availability

In [None]:
import torch

# Select the device used by the models later in the notebook
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"GPU Available: {torch.cuda.get_device_name(0)}")
    !nvidia-smi
else:
    device = torch.device("cpu")
    print("No GPU available. Using CPU.")
    print("Go to Runtime > Change runtime type > GPU for better performance")

## 4. (Optional) Mount Google Drive for Persistent Storage

**Recommended**: Mount Drive to save your checkpoints permanently. If you skip this, checkpoints will be lost when the session ends!

In [None]:
try:
    from google.colab import drive
    drive.mount('/content/drive')

    # Create the Drive folders that the symlinks will point to
    !mkdir -p /content/drive/MyDrive/visual-scene-generation/checkpoints
    !mkdir -p /content/drive/MyDrive/visual-scene-generation/logs

    # Symlink the local checkpoints/ and logs/ folders to Drive so files are written there directly
    # (-n replaces an existing symlink instead of creating a new link inside the folder it points to)
    !ln -sfn /content/drive/MyDrive/visual-scene-generation/checkpoints /content/visual-scene-generation/checkpoints
    !ln -sfn /content/drive/MyDrive/visual-scene-generation/logs /content/visual-scene-generation/logs

    print("✅ Google Drive mounted! Checkpoints will be saved persistently.")
    print("📁 Checkpoints: /content/drive/MyDrive/visual-scene-generation/checkpoints")
    print("📁 Logs: /content/drive/MyDrive/visual-scene-generation/logs")
except Exception as e:
    print(f"⚠️ Google Drive not mounted ({e}). Checkpoints will be temporary!")
    print("   They will be lost when the session ends.")
    !mkdir -p checkpoints logs
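
The symlink setup above can also be written in pure Python, which makes it easier to handle the case where `checkpoints/` already exists as a real folder (`ln -sf` would then create the link *inside* that folder rather than replacing it). The helper below is a hypothetical sketch, not part of the repository; it assumes an existing *empty* local folder can safely be replaced by the link, and it refuses to touch a non-empty one.

```python
import os
from pathlib import Path


def link_to_persistent_dir(local_dir, target_dir):
    """Make local_dir a symlink to target_dir, creating target_dir if needed.

    Hypothetical helper: an existing symlink is replaced, an existing empty
    directory is removed first, and a non-empty directory raises an error so
    no files are silently lost.
    """
    local = Path(local_dir)
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)

    # Check for a symlink first: a link to a directory also reports is_dir() == True
    if local.is_symlink():
        local.unlink()
    elif local.is_dir():
        if any(local.iterdir()):
            raise FileExistsError(f"{local} is a non-empty directory; move its files first")
        local.rmdir()
    elif local.exists():
        raise FileExistsError(f"{local} exists and is not a directory")

    local.symlink_to(target, target_is_directory=True)
    return local.resolve()
```

For example, after mounting Drive you could call `link_to_persistent_dir("/content/visual-scene-generation/checkpoints", "/content/drive/MyDrive/visual-scene-generation/checkpoints")`, and do the same for `logs`.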

## 5. Quick Test Run (Small Dataset)

In [ ]:
# Run a quick test with a small dataset to verify everything works
!python train.py \
    --epochs 5 \
    --batch_size 16 \
    --num_samples 1000 \
    --use_vae \
    --log_interval 5 \
    --lr 0.0001 \
    --lambda_kl 0.0001

## 6. Full Training Run

In [ ]:
# Full training with recommended settings
!python train.py \
    --epochs 50 \
    --batch_size 32 \
    --num_samples 10000 \
    --use_vae \
    --use_amp \
    --lr 0.0001 \
    --d_model 512 \
    --hidden_dim 256 \
    --z_dim 128 \
    --lambda_consistency 1.0 \
    --lambda_spatial 0.1 \
    --lambda_kl 0.0001
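
To estimate the length of a run, it helps to count optimizer steps (one per batch). A minimal sketch, assuming `train.py` iterates over all `--num_samples` examples each epoch with a PyTorch-style data loader that keeps the final partial batch (`drop_last=False`); if the script holds out a validation split, the real counts are lower. `optimizer_steps` is a hypothetical helper.

```python
import math


def optimizer_steps(num_samples, batch_size, epochs, drop_last=False):
    """Number of optimizer steps for a run, assuming one step per batch."""
    if drop_last:
        per_epoch = num_samples // batch_size
    else:
        per_epoch = math.ceil(num_samples / batch_size)
    return per_epoch * epochs


# Quick test run (section 5): ceil(1000 / 16) = 63 steps/epoch x 5 epochs
print(optimizer_steps(1000, 16, 5))    # 315
# Full run (section 6): ceil(10000 / 32) = 313 steps/epoch x 50 epochs
print(optimizer_steps(10000, 32, 50))  # 15650
```

Timing a few steps of the quick run and multiplying by these counts gives a rough estimate of how long the full run will keep the Colab session busy.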

## 7. Visualize Training Results
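
A minimal sketch for plotting a loss curve, assuming you have collected loss values into a Python list yourself (for example, by copying them from the printed training output). The repository's log format is not shown here, so `moving_average` and `plot_losses` are hypothetical helpers rather than part of the project.

```python
def moving_average(values, window=10):
    """Trailing moving average; the first points average over fewer values."""
    if window < 1:
        raise ValueError("window must be >= 1")
    smoothed = []
    running = 0.0
    for i, v in enumerate(values):
        running += v
        if i >= window:
            running -= values[i - window]  # drop the value that left the window
        smoothed.append(running / min(i + 1, window))
    return smoothed


def plot_losses(losses, window=10, title="Training loss"):
    """Plot raw and smoothed loss values (matplotlib is pre-installed in Colab)."""
    import matplotlib.pyplot as plt

    plt.figure(figsize=(8, 4))
    plt.plot(losses, alpha=0.3, label="raw")
    plt.plot(moving_average(losses, window), label=f"moving average ({window})")
    plt.xlabel("step")
    plt.ylabel("loss")
    plt.title(title)
    plt.legend()
    plt.show()
```

Smoothing matters here because per-batch losses from a VAE-style model are noisy; the moving average makes the overall trend easier to judge. Call it as `plot_losses(my_losses)`, where `my_losses` is your own list of floats.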

## 8. Interactive Scene Generation

In [None]:
import matplotlib.pyplot as plt
import torch

# Assumes ar_model, scene_decoder, caption_network, and dataset are already created
# or loaded (e.g. from a checkpoint), and that device is a torch.device.

def generate_scene_from_text(text, ar_model, scene_decoder, caption_network, dataset, device):
    """Generate a scene from custom text, caption it, and show the result."""
    ar_model.eval()
    scene_decoder.eval()
    caption_network.eval()

    with torch.no_grad():
        # Tokenize input text; unknown words map to <UNK>
        unk_id = dataset.vocab['<UNK>']
        tokens = [dataset.vocab.get(word, unk_id) for word in text.lower().split()]
        tokens = [dataset.vocab['<SOS>']] + tokens + [dataset.vocab['<EOS>']]
        input_ids = torch.tensor([tokens], dtype=torch.long, device=device)

        # Encode the text into an embedding
        ar_outputs = ar_model(input_ids, return_embeddings=True)
        text_embedding = ar_outputs['embeddings']

        # Decode the embedding into a scene image (batch of 1, C x H x W)
        scene_outputs = scene_decoder(text_embedding)
        scene = scene_outputs['scene']

        # Caption the generated scene to check the text -> scene -> text round trip
        generated_caption, _ = caption_network.generate_caption(scene)
        reconstructed_text = dataset.decode_tokens(generated_caption[0])

    # Convert CHW tensor to an HWC array; clamp to [0, 1], the float range imshow accepts
    scene_np = scene[0].clamp(0, 1).cpu().permute(1, 2, 0).numpy()

    # Show the generated scene with both the input text and the reconstructed caption
    plt.figure(figsize=(5, 5))
    plt.imshow(scene_np)
    plt.title(f"Input: {text}\nCaption: {reconstructed_text}")
    plt.axis('off')
    plt.tight_layout()
    plt.show()

# Test with custom text
test_texts = [
    "a red ball in the center",
    "blue cube on the left",
    "yellow sphere floating",
    "large green triangle"
]

for text in test_texts:
    generate_scene_from_text(text, ar_model, scene_decoder, caption_network, dataset, device)
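
The tokenization step inside `generate_scene_from_text` wraps the word IDs in `<SOS>`/`<EOS>` and maps unseen words to `<UNK>`. It can be tried on its own with a toy vocabulary. The vocabulary, the `<PAD>` token, and the `build_vocab`/`encode`/`decode` helpers below are hypothetical stand-ins for `dataset.vocab` and `dataset.decode_tokens`, whose real contents are not shown here.

```python
SPECIAL_TOKENS = ("<PAD>", "<SOS>", "<EOS>", "<UNK>")  # assumed special tokens


def build_vocab(words):
    """Hypothetical toy vocabulary: special tokens first, then words in order."""
    vocab = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
    for w in words:
        vocab.setdefault(w, len(vocab))
    return vocab


def encode(text, vocab):
    """Same scheme as the notebook: lowercase, split on spaces, add <SOS>/<EOS>."""
    ids = [vocab.get(word, vocab["<UNK>"]) for word in text.lower().split()]
    return [vocab["<SOS>"]] + ids + [vocab["<EOS>"]]


def decode(ids, vocab):
    """Map IDs back to words, stopping at <EOS> and dropping <SOS>/<PAD>."""
    id_to_word = {i: w for w, i in vocab.items()}
    words = []
    for i in ids:
        word = id_to_word.get(i, "<UNK>")
        if word == "<EOS>":
            break
        if word not in ("<SOS>", "<PAD>"):
            words.append(word)
    return " ".join(words)


vocab = build_vocab("a red ball in the center".split())
ids = encode("A red cube in the center", vocab)
print(ids)                 # [1, 4, 5, 3, 7, 8, 9, 2]
print(decode(ids, vocab))  # a red <UNK> in the center
```

This also explains why prompts with words outside the training vocabulary (such as "cube" here) can produce unexpected scenes: the model only sees `<UNK>` in their place.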

## 9. Save Models to Google Drive (Optional)
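
If Google Drive was not linked in step 4, the checkpoints only exist on the Colab VM. A minimal sketch for copying them to Drive before the session ends, using only the standard library (`shutil.copytree` with `dirs_exist_ok` needs Python 3.8+). `backup_to_drive` is a hypothetical helper, and the folder names mirror step 4 as an assumption; if the step 4 symlinks are in place, the files are already on Drive and are skipped.

```python
import shutil
from pathlib import Path


def backup_to_drive(src_dirs, drive_root):
    """Copy each local folder into drive_root, overwriting files with the same name."""
    drive_root = Path(drive_root)
    copied = []
    for src in map(Path, src_dirs):
        if not src.is_dir():
            print(f"Skipping {src}: not found")
            continue
        dest = drive_root / src.name
        # A folder symlinked to Drive in step 4 would be copied onto itself
        if src.resolve() == dest.resolve():
            print(f"Skipping {src}: already stored on Drive")
            continue
        shutil.copytree(src, dest, dirs_exist_ok=True)
        copied.append(dest)
    return copied


# Example (after drive.mount('/content/drive')):
# backup_to_drive(["checkpoints", "logs"], "/content/drive/MyDrive/visual-scene-generation")
```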

## 10. Resume Training from Checkpoint
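
How `train.py` resumes from a checkpoint is not shown here, so check its options with `!python train.py --help` (argparse-based scripts, which the `--flag` style suggests, print their options this way). The first step, finding the newest checkpoint, can be sketched with the standard library. `latest_checkpoint` is a hypothetical helper, and the `*.pt` pattern is an assumption about how checkpoints are named.

```python
from pathlib import Path


def latest_checkpoint(checkpoint_dir="checkpoints", pattern="*.pt"):
    """Return the most recently modified file matching pattern, or None if there is none."""
    directory = Path(checkpoint_dir)
    if not directory.is_dir():
        return None
    candidates = [p for p in directory.glob(pattern) if p.is_file()]
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


# Example:
# ckpt = latest_checkpoint()
# print(ckpt)  # then load it, e.g. torch.load(ckpt, map_location=device)
```

Sorting by modification time rather than by file name avoids relying on a particular naming scheme such as zero-padded epoch numbers.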

## Tips for Colab

1. **Enable GPU**: Go to Runtime → Change runtime type → Hardware accelerator → GPU
2. **Prevent Disconnection**: Keep the tab active or use Colab Pro for longer sessions
3. **Save Progress**: Regularly save checkpoints to Google Drive
4. **Monitor Memory**: Use smaller batch sizes if you encounter OOM errors
5. **Use Mixed Precision**: Add the `--use_amp` flag (as in the full training run) for faster training and lower memory use