# RQ-VAE Colab Setup & Testing

This notebook sets up the RQ-VAE project in Google Colab for iterative testing.

## 1. Setup: Install Dependencies

In [None]:
# Clone repository (if not already cloned)
!git clone https://github.com/YOUR_USERNAME/rq_vae.git
%cd rq_vae

In [None]:
# Install dependencies
!pip install -q torch>=2.0.0
!pip install -q transformers datasets wandb einops omegaconf tqdm accelerate
!pip install -q "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

print("✓ Dependencies installed!")

## 2. Import and Verify Setup

In [None]:
# Put the current directory (the repository root after the setup cell) at the
# front of the import path, then import the project components under test.
import sys

import torch

sys.path.insert(0, '.')

from src.model import SwiGLU, SwiGLUTransformerLayer, TextDecoder, TextEncoder

# Query CUDA availability once and reuse the answer for the report below.
has_cuda = torch.cuda.is_available()

print("✓ Imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {has_cuda}")
if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 3. Test Encoder

In [None]:
print("Creating encoder...")

# Constructor arguments gathered in one place so they are easy to tweak.
encoder_kwargs = dict(
    model_name="Qwen/Qwen3-0.6B",
    latent_dim=512,
    compression_factor=4,
    freeze_backbone=True,
    num_latent_layers=2,
)
encoder = TextEncoder(**encoder_kwargs)

# Prefer the GPU when one is visible; `device` is reused by later cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = encoder.to(device)

# Summarise the configuration as reported by the encoder instance itself.
encoder_summary = [
    f"✓ Encoder created on {device}",
    f"  Model: {encoder.model_name}",
    f"  Hidden size: {encoder.hidden_size}",
    f"  Latent dim: {encoder.latent_dim}",
    f"  Compression: {encoder.compression_factor}x",
    f"  Latent layers: {encoder.num_latent_layers}",
]
print("\n".join(encoder_summary))

In [None]:
# Smoke-test the encoder on a batch of random token ids with a full attention mask.
print("Testing encoder forward pass...")

batch_size, seq_len = 2, 128
input_ids = torch.randint(0, 1000, (batch_size, seq_len)).to(device)
# All-ones mask with the same shape and device as the ids (no padding positions).
attention_mask = torch.ones_like(input_ids, dtype=torch.long)

# No gradients are needed for a shape check.
with torch.no_grad():
    latents = encoder(input_ids, attention_mask)

num_latent_tokens = latents.shape[1]
print("✓ Encoder forward pass successful!")
print(f"  Input shape:  {input_ids.shape}")
print(f"  Output shape: {latents.shape}")
print(f"  Compression:  {seq_len} -> {num_latent_tokens} tokens ({seq_len // num_latent_tokens}x)")

## 4. Test Decoder

In [None]:
print("Creating decoder...")

# Constructor arguments gathered in one place so they are easy to tweak.
decoder_kwargs = dict(
    model_name="Qwen/Qwen3-0.6B",
    latent_dim=512,
    compression_factor=4,
    freeze_backbone=True,
    num_latent_layers=2,
)
# `device` comes from the encoder setup cell, which must have been run first.
decoder = TextDecoder(**decoder_kwargs)
decoder = decoder.to(device)

# Summarise the configuration as reported by the decoder instance itself.
decoder_summary = [
    f"✓ Decoder created on {device}",
    f"  Model: {decoder.model_name}",
    f"  Hidden size: {decoder.hidden_size}",
    f"  Latent dim: {decoder.latent_dim}",
    f"  Vocab size: {decoder.vocab_size}",
    f"  Compression: {decoder.compression_factor}x",
    f"  Latent layers: {decoder.num_latent_layers}",
]
print("\n".join(decoder_summary))

In [None]:
# Smoke-test the decoder on the latents produced by the encoder forward-pass
# cell; `latents` and `seq_len` must still be defined from that cell.
print("Testing decoder forward pass...")

with torch.no_grad():
    logits = decoder(latents, target_len=seq_len)

latent_tokens, output_tokens = latents.shape[1], logits.shape[1]
print("✓ Decoder forward pass successful!")
print(f"  Input shape:  {latents.shape}")
print(f"  Output shape: {logits.shape}")
print(f"  Expansion:    {latent_tokens} -> {output_tokens} tokens ({output_tokens // latent_tokens}x)")

## 5. Test Full Pipeline (Encode -> Decode)

In [None]:
print("Testing full encode -> decode pipeline...")

# Fresh random batch, this time drawn from the decoder's whole vocabulary range.
batch_size, seq_len = 2, 256
input_ids = torch.randint(0, decoder.vocab_size, (batch_size, seq_len)).to(device)
attention_mask = torch.ones_like(input_ids, dtype=torch.long)

# Encode, decode, then take the highest-scoring token at every position.
with torch.no_grad():
    latents = encoder(input_ids, attention_mask)
    reconstructed_logits = decoder(latents, target_len=seq_len)
    predicted_tokens = torch.argmax(reconstructed_logits, dim=-1)

print("✓ Full pipeline successful!")
print(f"  Input tokens:    {input_ids.shape}")
print(f"  Compressed:      {latents.shape}")
print(f"  Reconstructed:   {reconstructed_logits.shape}")
print(f"  Predicted:       {predicted_tokens.shape}")

# Fraction of positions where the prediction equals the input token.
# The models are untrained, so this number is only a sanity check.
matches = predicted_tokens.eq(input_ids)
accuracy = matches.float().mean().item()
print(f"  Token accuracy:  {accuracy:.2%} (untrained, just a sanity check)")

## 6. Memory Usage

In [None]:
# Report CUDA memory counters, converted from bytes by dividing by 1024**3.
# On a CPU-only runtime there is nothing to report.
if not torch.cuda.is_available():
    print("Running on CPU")
else:
    bytes_per_unit = 1024**3
    print("GPU Memory Usage:")
    print(f"  Allocated: {torch.cuda.memory_allocated() / bytes_per_unit:.2f} GB")
    print(f"  Reserved:  {torch.cuda.memory_reserved() / bytes_per_unit:.2f} GB")
    print(f"  Max allocated: {torch.cuda.max_memory_allocated() / bytes_per_unit:.2f} GB")

## 7. Test SwiGLU Layers (Optional)

In [None]:
print("Testing SwiGLU components...")

# Standalone SwiGLU module applied to a random (2, 32, 512) tensor.
swiglu = SwiGLU(dim=512).to(device)
x = torch.randn((2, 32, 512)).to(device)
with torch.no_grad():
    out = swiglu(x)
input_shape, output_shape = x.shape, out.shape
print(f"✓ SwiGLU: {input_shape} -> {output_shape}")

# The transformer layer is fed the same input tensor.
layer = SwiGLUTransformerLayer(d_model=512, nhead=8).to(device)
with torch.no_grad():
    out = layer(x)
output_shape = out.shape
print(f"✓ SwiGLUTransformerLayer: {input_shape} -> {output_shape}")

print("\n✓ All components working!")

## 8. Quick Training Test (Optional)

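Run a single forward/backward/optimizer step to check that gradients flow through both models. This is a smoke test rather than real training, but note that the one optimizer step does slightly modify the model weights.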

In [None]:
# Gradient-flow smoke test: one forward/backward/optimizer step through the
# encoder and decoder with their backbones temporarily unfrozen.
print("Testing gradient flow...")

# Unfreeze models for gradient test
encoder.unfreeze_backbone()
decoder.unfreeze_backbone()

# try/finally guarantees the backbones are re-frozen even if the forward or
# backward pass raises (for example on out-of-memory), so later cells never
# see the models left in the unfrozen state.
try:
    # Create optimizer over every parameter of both models
    optimizer = torch.optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=1e-4,
    )

    # Forward pass with gradients
    input_ids = torch.randint(0, decoder.vocab_size, (2, 64)).to(device)
    attention_mask = torch.ones(2, 64, dtype=torch.long).to(device)

    latents = encoder(input_ids, attention_mask)
    logits = decoder(latents, target_len=64)

    # Dummy loss: reconstruct the input tokens from the decoder logits
    loss = torch.nn.functional.cross_entropy(
        logits.reshape(-1, decoder.vocab_size),
        input_ids.reshape(-1),
    )

    # Backward pass
    optimizer.zero_grad()
    loss.backward()

    # Actually verify the claim printed below: the loss must be finite and at
    # least one parameter in EACH model must have received a gradient. Some
    # parameters may legitimately be unused, so only "more than zero" is required.
    assert torch.isfinite(loss).item(), f"Loss is not finite: {loss.item()}"
    encoder_grad_count = sum(p.grad is not None for p in encoder.parameters())
    decoder_grad_count = sum(p.grad is not None for p in decoder.parameters())
    assert encoder_grad_count > 0, "No encoder parameter received a gradient"
    assert decoder_grad_count > 0, "No decoder parameter received a gradient"

    # NOTE: this step slightly modifies the model weights.
    optimizer.step()

    print("✓ Gradients flow correctly!")
    print(f"  Loss: {loss.item():.4f}")
    print(f"  Encoder params with grad: {encoder_grad_count}")
    print(f"  Decoder params with grad: {decoder_grad_count}")
finally:
    # Release gradient buffers so they do not stay allocated once the models
    # are frozen again, then re-freeze for future tests.
    encoder.zero_grad(set_to_none=True)
    decoder.zero_grad(set_to_none=True)
    encoder.freeze_backbone()
    decoder.freeze_backbone()

## ✅ All Tests Complete!

The RQ-VAE encoder and decoder passed these smoke tests in Colab. You can now:
- Train the models
- Test with real data
- Experiment with hyperparameters