# VQ-VAE Embedding Validation - Google Colab Version

## Setup Instructions

1. Upload the following to your Google Drive under `minecraft_ai/`:
   - `splits/train/` - Training H5 files
   - `splits/val/` - Validation H5 files  
   - `tok2block.json` - Vocabulary file
   - `block_embeddings.npy` - V1 embeddings
   - `block_embeddings_v3.npy` - V3 embeddings

2. Run all cells in order

3. **Keep the browser tab open** (but you can minimize it)

## Idle Disconnect Prevention

Cell 0 injects a JavaScript keep-alive script that clicks Colab's connect button every 60 seconds so the session is not dropped for inactivity. You still need to keep the tab open, but you don't need to interact with it.

## Configuration

- **Epochs**: 12
- **Seeds**: 1 (single run per embedding, ~5 hours total)
- **Structure weight**: 10x (to counter 80% air imbalance)
- **Embeddings**: V1, V3, Random

## Runtime Limits

Google doesn't publish exact limits. Commonly reported:
- Free tier: ~12 hours max, ~90 min idle timeout
- Check official FAQ: https://research.google.com/colaboratory/faq.html

In [None]:
# ============================================================
# CELL 0: Mount Google Drive & Prevent Idle Disconnect
# ============================================================

import os

from google.colab import drive
from IPython.display import display, Javascript

# Mount Google Drive so the project folder is reachable under /content/drive.
drive.mount('/content/drive')

# Sanity-check that the expected project folder is present after mounting.
print("Drive mounted. Checking for minecraft_ai folder...")
drive_path = "/content/drive/MyDrive/minecraft_ai"
if not os.path.exists(drive_path):
    print(f"WARNING: {drive_path} not found!")
    print("Please create this folder and upload your data.")
else:
    print(f"Found: {drive_path}")
    print(f"Contents: {os.listdir(drive_path)}")

# --- Idle-disconnect workaround ---
# The injected script looks up Colab's connect button (two selectors are tried)
# and clicks it once a minute from the browser side.
keep_alive_js = Javascript('''
function KeepAlive() {
    console.log("Keep-alive ping at " + new Date().toLocaleTimeString());
    // Try multiple selectors for different Colab versions
    var buttons = document.querySelectorAll("colab-connect-button, colab-toolbar-button#connect");
    buttons.forEach(function(btn) {
        if (btn) btn.click();
    });
}
// Run every 60 seconds
setInterval(KeepAlive, 60000);
console.log("Keep-alive script activated - will ping every 60 seconds");
''')

display(keep_alive_js)
print("\nKeep-alive script activated to prevent idle disconnect.")

In [None]:
# ============================================================
# CELL 1: Imports
# ============================================================

import json
import random
import time
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional
from collections import defaultdict, Counter

import h5py
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from scipy import stats

# Default to the CPU and switch to the GPU only when PyTorch can see one.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

# Report which GPU was picked up and how much memory it has.
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# ============================================================
# CELL 2: Configuration
# ============================================================

# === Google Drive Paths ===
# Everything is read from / written to this folder on the mounted Drive.
DRIVE_BASE = "/content/drive/MyDrive/minecraft_ai"

# Directories scanned for "*.h5" files by VQVAEDataset, and the token -> block-name JSON.
DATA_DIR = f"{DRIVE_BASE}/splits/train"
VAL_DIR = f"{DRIVE_BASE}/splits/val"
VOCAB_PATH = f"{DRIVE_BASE}/tok2block.json"

# Embeddings
# .npy matrices loaded in Cell 4 (one row per token is assumed — verify against the files).
V1_EMBEDDINGS_PATH = f"{DRIVE_BASE}/block_embeddings.npy"
V3_EMBEDDINGS_PATH = f"{DRIVE_BASE}/block_embeddings_v3.npy"

# Output to Drive so results persist
# NOTE: the directory is created as a side effect of running this cell.
OUTPUT_DIR = f"{DRIVE_BASE}/output/vqvae_validation"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# === Model Architecture ===
# One stride-2 encoder stage (and one mirrored decoder stage) is built per entry.
HIDDEN_DIMS = [64, 128, 256]
# Channel width of the encoder output / codebook vectors.
LATENT_DIM = 256
# Number of vectors in the VQ codebook.
NUM_CODEBOOK_ENTRIES = 512

# === VQ-VAE Settings ===
# Multiplier on the commitment (encoder-to-code MSE) loss.
COMMITMENT_COST = 0.5
# Decay used for the EMA codebook statistics.
EMA_DECAY = 0.99
# Codes assigned fewer than this many vectors in a training batch are re-seeded.
DEAD_CODE_THRESHOLD = 2
# Cross-entropy weight applied to non-air voxels (air voxels keep weight 1).
STRUCTURE_WEIGHT = 10.0

# === Training ===
EPOCHS = 12
BATCH_SIZE = 4
LEARNING_RATE = 3e-4
# Enables autocast and the gradient scaler in the training/validation loops.
USE_AMP = True
# Micro-batches accumulated per optimizer step (effective batch = BATCH_SIZE * GRAD_ACCUM_STEPS).
GRAD_ACCUM_STEPS = 4

# === Statistical Rigor ===
# Presumably consumed by the experiment loop further down; in the visible cells they are only printed.
NUM_SEEDS = 1
BASE_SEEDS = [42]

# === Diagnostics ===
# NOTE(review): not referenced in the visible cells; presumably used by later reporting cells — confirm.
NUM_RECONSTRUCTION_SAMPLES = 5
TOP_K_BLOCKS = 20

# Worker processes passed to the DataLoaders.
NUM_WORKERS = 2

# Echo the main settings so a run's log records what it was configured with.
for summary_line in (
    "Colab Validation Configuration:",
    f"  Drive base: {DRIVE_BASE}",
    f"  Embeddings: V1, V3, Random",
    f"  Seeds: {BASE_SEEDS} ({NUM_SEEDS} runs per embedding)",
    f"  Epochs: {EPOCHS}",
    f"  Output: {OUTPUT_DIR}",
):
    print(summary_line)

# Verify paths exist
# This only reports a status per input; a missing path does not stop the cell.
required_inputs = {
    "train": DATA_DIR,
    "val": VAL_DIR,
    "vocab": VOCAB_PATH,
    "V1 emb": V1_EMBEDDINGS_PATH,
    "V3 emb": V3_EMBEDDINGS_PATH,
}
for name, path in required_inputs.items():
    if os.path.exists(path):
        status = "OK"
    else:
        status = "MISSING"
    print(f"  {name}: {status}")

In [None]:
# ============================================================
# CELL 3: Load Vocabulary and Categorize Blocks
# ============================================================

# Read the token -> block-name table; JSON object keys are strings, so they are
# converted back to integer token ids.
with open(VOCAB_PATH, 'r') as f:
    raw_vocab = json.load(f)
tok2block = dict((int(key), name) for key, name in raw_vocab.items())

# Reverse lookup (block name -> token id).
block2tok = dict((name, tok_id) for tok_id, name in tok2block.items())
VOCAB_SIZE = len(tok2block)
print(f"Vocabulary size: {VOCAB_SIZE}")

# Find air tokens
# A block counts as air when its name contains "air" but not "stair"
# (the second test keeps stair blocks out of the set).
AIR_TOKENS = set()
for tok, block in tok2block.items():
    lowered_name = block.lower()
    if 'air' not in lowered_name or 'stair' in lowered_name:
        continue
    AIR_TOKENS.add(tok)
    print(f"  Air token: {tok} = {block}")

# Sorted list and long tensor forms of the same token set.
AIR_TOKENS_LIST = sorted(AIR_TOKENS)
AIR_TOKENS_TENSOR = torch.tensor(AIR_TOKENS_LIST, dtype=torch.long)

# Categorize blocks for analysis
BLOCK_CATEGORIES = {
    'wood': [], 'stone': [], 'metal': [], 'glass': [],
    'wool': [], 'concrete': [], 'terracotta': [],
    'stairs': [], 'slabs': [], 'walls': [], 'fences': [],
    'doors': [], 'plants': [], 'redstone': [], 'light': [],
    'other': []
}

# Keyword families; every keyword is matched as a substring of the lower-cased block name.
_WOOD_WORDS = ('oak', 'spruce', 'birch', 'jungle', 'acacia', 'dark_oak',
               'mangrove', 'cherry', 'bamboo', 'crimson', 'warped')
_STONE_WORDS = ('stone', 'cobble', 'brick', 'granite', 'diorite', 'andesite',
                'deepslate', 'tuff')
# Names that contain "stone" as a substring but have their own category below.
# Without this exclusion the stone branch claims them first and the 'redstone'
# and 'glowstone' keywords further down can never match.
_NOT_STONE_WORDS = ('redstone', 'glowstone')
_METAL_WORDS = ('iron', 'gold', 'copper', 'netherite')
_REDSTONE_WORDS = ('redstone', 'piston', 'observer', 'comparator', 'repeater',
                   'lever', 'button')
_LIGHT_WORDS = ('torch', 'lantern', 'lamp', 'glowstone', 'sea_lantern', 'shroomlight')
_PLANT_WORDS = ('flower', 'grass', 'fern', 'leaves', 'sapling', 'vine', 'moss')

# Shape sub-categories checked, in order, inside the wood and stone families.
_WOOD_SHAPES = (('stair', 'stairs'), ('slab', 'slabs'), ('fence', 'fences'), ('door', 'doors'))
_STONE_SHAPES = (('stair', 'stairs'), ('slab', 'slabs'), ('wall', 'walls'))


def _contains_any(text, words):
    """Return True when any of ``words`` occurs as a substring of ``text``."""
    return any(word in text for word in words)


def _shape_or_default(block_lower, shapes, default):
    """Return the first shape category whose keyword occurs in the name, else ``default``."""
    for keyword, category in shapes:
        if keyword in block_lower:
            return category
    return default


def _categorize_block(block_lower):
    """Map a lower-cased block name to one key of ``BLOCK_CATEGORIES``.

    Families are tried in a fixed order (wood, stone, metal, glass, wool,
    concrete, terracotta, redstone, light, plants) and the first match wins;
    names that match nothing fall into ``'other'``.
    """
    if _contains_any(block_lower, _WOOD_WORDS):
        return _shape_or_default(block_lower, _WOOD_SHAPES, 'wood')
    if (_contains_any(block_lower, _STONE_WORDS)
            and not _contains_any(block_lower, _NOT_STONE_WORDS)):
        return _shape_or_default(block_lower, _STONE_SHAPES, 'stone')
    if _contains_any(block_lower, _METAL_WORDS):
        return 'metal'
    for keyword in ('glass', 'wool', 'concrete', 'terracotta'):
        if keyword in block_lower:
            return keyword
    if _contains_any(block_lower, _REDSTONE_WORDS):
        return 'redstone'
    if _contains_any(block_lower, _LIGHT_WORDS):
        return 'light'
    if _contains_any(block_lower, _PLANT_WORDS):
        return 'plants'
    return 'other'


# Air tokens are left out of every category.
for tok, block in tok2block.items():
    if tok in AIR_TOKENS:
        continue
    block_lower = block.lower()
    BLOCK_CATEGORIES[_categorize_block(block_lower)].append(tok)

print("\nBlock categories:")
for cat, toks in BLOCK_CATEGORIES.items():
    print(f"  {cat}: {len(toks)} blocks")

In [None]:
# ============================================================
# CELL 4: Load Embeddings
# ============================================================

def _load_embedding_matrix(path):
    """Load a ``.npy`` embedding table as float32 and require it to be 2-D.

    Raises:
        ValueError: if the stored array is not a matrix.
    """
    matrix = np.load(path).astype(np.float32)
    if matrix.ndim != 2:
        raise ValueError(f"Expected a 2-D embedding matrix in {path}, got shape {matrix.shape}")
    return matrix


# V1: Skip-gram embeddings
v1_embeddings = _load_embedding_matrix(V1_EMBEDDINGS_PATH)
print(f"V1: {v1_embeddings.shape}")

# V3: Compositional embeddings
v3_embeddings = _load_embedding_matrix(V3_EMBEDDINGS_PATH)
print(f"V3: {v3_embeddings.shape}")

# Random (will create fresh per seed)
# Only the overall standard deviation of V1 is recorded here.
v1_std = v1_embeddings.std()
print(f"V1 std (for random init): {v1_std:.4f}")

# 'dim' is read from the loaded matrix instead of being hard-coded, so it
# always agrees with the width of the array that is handed to the model.
EMBEDDINGS_BASE = {
    'V1': {'embeddings': v1_embeddings, 'dim': int(v1_embeddings.shape[1])},
    'V3': {'embeddings': v3_embeddings, 'dim': int(v3_embeddings.shape[1])},
}

print(f"\nEmbeddings to test: {list(EMBEDDINGS_BASE.keys())} + Random")

In [None]:
# ============================================================
# CELL 5: Dataset
# ============================================================

class VQVAEDataset(Dataset):
    """Dataset with one item per ``*.h5`` file found directly inside a directory.

    Items are ``(structure, idx)`` pairs, where ``structure`` is the first
    dataset stored in the file converted to an int64 tensor and ``idx`` is the
    index that was requested.
    """

    def __init__(self, data_dir: str):
        """Collect the H5 files in ``data_dir`` in sorted order.

        Raises:
            ValueError: if the directory contains no ``*.h5`` file.
        """
        self.data_dir = Path(data_dir)
        found = list(self.data_dir.glob("*.h5"))
        found.sort()
        self.h5_files = found
        if len(self.h5_files) == 0:
            raise ValueError(f"No H5 files in {data_dir}")
        print(f"Found {len(self.h5_files)} structures in {data_dir}")

    def __len__(self):
        """Number of H5 files found at construction time."""
        return len(self.h5_files)

    def __getitem__(self, idx):
        """Read file ``idx`` and return ``(int64 tensor, idx)``."""
        file_path = self.h5_files[idx]
        with h5py.File(file_path, 'r') as f:
            # Only the first key of the file is read.
            dataset_names = list(f.keys())
            voxels = f[dataset_names[0]][:]
        structure = voxels.astype(np.int64)
        return torch.from_numpy(structure).long(), idx

# One dataset per split; each constructor prints how many files it found and
# raises ValueError when its directory has no H5 files.
train_dataset, val_dataset = VQVAEDataset(DATA_DIR), VQVAEDataset(VAL_DIR)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

In [None]:
# ============================================================
# CELL 6: Enhanced VQ-VAE with Diagnostics
# ============================================================

class ResidualBlock3D(nn.Module):
    """Two 3x3x3 conv + batch-norm layers wrapped in an identity skip connection.

    Channel count and spatial size are preserved (kernel 3, padding 1).
    """

    def __init__(self, channels: int):
        super().__init__()
        # Sub-modules are registered in the order conv1, conv2, bn1, bn2.
        self.conv1 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(channels)
        self.bn2 = nn.BatchNorm3d(channels)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        # The skip connection is added before the final activation.
        return F.relu(hidden + x)


class VectorQuantizerEMA(nn.Module):
    """Vector quantizer whose codebook is maintained by exponential moving averages.

    The codebook and all statistics are buffers (no gradients flow into them).
    Two usage counters are kept: ``code_usage`` (cleared by
    ``reset_epoch_stats``) and ``code_usage_total`` (never cleared here).
    Both are updated on every forward pass, in training and evaluation mode.

    Args:
        num_codes: number of codebook vectors.
        latent_dim: size of each codebook vector; must equal the channel
            dimension of the tensor passed to ``forward``.
        commitment_cost: multiplier on the commitment loss.
        ema_decay: decay of the EMA cluster-size / embedding-sum statistics.
        dead_threshold: codes assigned fewer than this many vectors in a
            training batch are re-seeded from that batch.
    """

    def __init__(self, num_codes, latent_dim, commitment_cost=0.5, ema_decay=0.99, dead_threshold=2):
        super().__init__()
        self.num_codes = num_codes
        self.latent_dim = latent_dim
        self.commitment_cost = commitment_cost
        self.ema_decay = ema_decay
        self.dead_threshold = dead_threshold

        # The randn values are overwritten by the uniform init on the next
        # line; the call is kept so the RNG stream consumed by construction
        # stays the same.
        self.register_buffer('codebook', torch.randn(num_codes, latent_dim))
        self.codebook.data.uniform_(-1/num_codes, 1/num_codes)
        self.register_buffer('ema_cluster_size', torch.zeros(num_codes))
        self.register_buffer('ema_embed_sum', torch.zeros(num_codes, latent_dim))
        self.register_buffer('code_usage', torch.zeros(num_codes))
        self.register_buffer('code_usage_total', torch.zeros(num_codes))

    def reset_epoch_stats(self):
        """Clear the per-epoch usage counter (the running total is kept)."""
        self.code_usage.zero_()

    def get_usage_fraction(self):
        """Fraction of codes selected at least once since the last reset."""
        return (self.code_usage > 0).float().mean().item()

    def _usage_entropy(self):
        """Entropy (nats) of the usage distribution, or None when nothing was counted."""
        total = self.code_usage.sum()
        if total == 0:
            return None
        probs = self.code_usage / total
        probs = probs[probs > 0]  # unused codes contribute nothing and would give log(0)
        return -(probs * probs.log()).sum()

    def get_perplexity(self):
        """exp(entropy) of code usage since the last reset; 0.0 when nothing was counted."""
        entropy = self._usage_entropy()
        return 0.0 if entropy is None else entropy.exp().item()

    def get_entropy(self):
        """Entropy (nats) of code usage since the last reset; 0.0 when nothing was counted."""
        entropy = self._usage_entropy()
        return 0.0 if entropy is None else entropy.item()

    def forward(self, z_e):
        """Quantize ``z_e`` of shape (B, latent_dim, D, H, W).

        Returns:
            ``(z_q, vq_loss, indices)``: the straight-through quantized tensor
            with the same shape as ``z_e``, the weighted commitment loss, and
            the selected code index per position with shape (B, D, H, W).
        """
        z_e_perm = z_e.permute(0, 2, 3, 4, 1).contiguous()
        flat = z_e_perm.view(-1, self.latent_dim)

        # Cast to float32 for codebook operations (AMP fix)
        flat_f32 = flat.float()
        codebook_f32 = self.codebook.float()

        # Squared Euclidean distance from every latent vector to every code.
        d = (flat_f32.pow(2).sum(1, keepdim=True)
             + codebook_f32.pow(2).sum(1)
             - 2 * flat_f32 @ codebook_f32.t())

        indices = d.argmin(dim=1)

        with torch.no_grad():
            # One histogram replaces a Python loop over the unique indices
            # (which ran one device-side reduction per code).
            counts = torch.bincount(indices, minlength=self.num_codes)
            self.code_usage += counts
            self.code_usage_total += counts

        # Codes are looked up before the EMA update below, so the returned
        # z_q uses the codebook as it was at the start of this call.
        z_q_flat = self.codebook[indices]
        z_q_perm = z_q_flat.view(z_e_perm.shape)

        if self.training:
            with torch.no_grad():
                encodings = F.one_hot(indices, self.num_codes).float()
                batch_counts = counts.float()  # same values as encodings.sum(0)

                self.ema_cluster_size = self.ema_decay * self.ema_cluster_size + (1 - self.ema_decay) * batch_counts
                batch_sum = encodings.t() @ flat_f32
                self.ema_embed_sum = self.ema_decay * self.ema_embed_sum + (1 - self.ema_decay) * batch_sum

                # Additive smoothing keeps the divisor non-zero for unused codes.
                n = self.ema_cluster_size.sum()
                smoothed = (self.ema_cluster_size + 1e-5) / (n + self.num_codes * 1e-5) * n
                self.codebook.data = self.ema_embed_sum / smoothed.unsqueeze(1)

                # Dead code reset - use float32
                # NOTE(review): "dead" is judged from this batch's counts only,
                # not from the EMA statistics. When a batch has fewer latent
                # vectors than there are codes, many codes are re-seeded on
                # every step — confirm this is intended.
                dead = batch_counts < self.dead_threshold
                if dead.any() and flat_f32.size(0) > 0:
                    n_dead = dead.sum().item()
                    rand_idx = torch.randint(0, flat_f32.size(0), (n_dead,), device=flat_f32.device)
                    self.codebook.data[dead] = flat_f32[rand_idx]
                    self.ema_cluster_size[dead] = 1
                    self.ema_embed_sum[dead] = flat_f32[rand_idx]

        # Only the encoder is trained through this loss (codes are detached).
        commitment_loss = F.mse_loss(z_e_perm, z_q_perm.detach())
        vq_loss = self.commitment_cost * commitment_loss

        # Straight-through estimator: forward value is z_q, gradient flows to z_e.
        z_q_st = z_e_perm + (z_q_perm - z_e_perm).detach()
        z_q = z_q_st.permute(0, 4, 1, 2, 3).contiguous()

        return z_q, vq_loss, indices.view(z_e_perm.shape[:-1])


class DiagnosticVQVAE(nn.Module):
    """3D VQ-VAE over grids of block ids, with a frozen pretrained block embedding.

    Pipeline: embedding lookup -> strided conv encoder -> EMA vector quantizer
    -> transposed-conv decoder producing per-voxel logits over the vocabulary.

    Args:
        vocab_size: number of block tokens (rows of the embedding table and
            number of output classes).
        emb_dim: width of the block embedding.
        hidden_dims: channel widths of the encoder stages; each stage halves
            the spatial size and the decoder mirrors them in reverse.
        latent_dim: channel width of the quantized latent.
        num_codes, commitment_cost, ema_decay, dead_threshold: forwarded to
            ``VectorQuantizerEMA``.
        pretrained_emb: NumPy array of shape ``(vocab_size, emb_dim)`` copied
            into the embedding table, which is then frozen.

    Raises:
        ValueError: if ``pretrained_emb`` does not have shape
            ``(vocab_size, emb_dim)``.
    """

    def __init__(self, vocab_size, emb_dim, hidden_dims, latent_dim, num_codes,
                 commitment_cost, ema_decay, dead_threshold, pretrained_emb):
        super().__init__()
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.latent_dim = latent_dim

        # copy_() broadcasts its source, so a mis-shaped table such as
        # (1, emb_dim) would otherwise be accepted silently and give every
        # token the same vector.
        expected_shape = (vocab_size, emb_dim)
        actual_shape = tuple(pretrained_emb.shape)
        if actual_shape != expected_shape:
            raise ValueError(
                f"pretrained_emb has shape {actual_shape}, expected {expected_shape}"
            )

        self.block_emb = nn.Embedding(vocab_size, emb_dim)
        self.block_emb.weight.data.copy_(torch.from_numpy(pretrained_emb))
        self.block_emb.weight.requires_grad = False

        # Encoder: one stride-2 stage per hidden width, then a 3x3x3 projection.
        enc = []
        in_ch = emb_dim
        for h in hidden_dims:
            enc.extend([
                nn.Conv3d(in_ch, h, 4, stride=2, padding=1),
                nn.BatchNorm3d(h),
                nn.ReLU(inplace=True),
                ResidualBlock3D(h),
            ])
            in_ch = h
        enc.append(nn.Conv3d(in_ch, latent_dim, 3, padding=1))
        self.encoder = nn.Sequential(*enc)

        self.quantizer = VectorQuantizerEMA(num_codes, latent_dim, commitment_cost, ema_decay, dead_threshold)

        # Decoder: mirrored stages that double the spatial size, then a
        # projection to one logit per vocabulary entry.
        dec = []
        in_ch = latent_dim
        for h in reversed(hidden_dims):
            dec.extend([
                ResidualBlock3D(in_ch),
                nn.ConvTranspose3d(in_ch, h, 4, stride=2, padding=1),
                nn.BatchNorm3d(h),
                nn.ReLU(inplace=True),
            ])
            in_ch = h
        dec.append(nn.Conv3d(in_ch, vocab_size, 3, padding=1))
        self.decoder = nn.Sequential(*dec)

    def forward(self, block_ids):
        """Encode, quantize and decode a batch of block-id grids.

        Args:
            block_ids: long tensor of shape (B, D, H, W).

        Returns:
            dict with ``logits`` (B, vocab_size, ...), ``vq_loss``,
            ``indices`` (selected code per latent position) and ``z_e``
            (the pre-quantization encoder output).
        """
        # Embedding appends a channel axis last; convs expect it at dim 1.
        x = self.block_emb(block_ids).permute(0, 4, 1, 2, 3).contiguous()
        z_e = self.encoder(x)
        z_q, vq_loss, indices = self.quantizer(z_e)
        logits = self.decoder(z_q)
        return {'logits': logits, 'vq_loss': vq_loss, 'indices': indices, 'z_e': z_e}

    def compute_loss_with_diagnostics(self, block_ids, air_tokens, structure_weight):
        """Run a forward pass and return the loss together with accuracy diagnostics.

        Args:
            block_ids: long tensor of shape (B, D, H, W); also used as targets.
            air_tokens: 1-D tensor of token ids counted as air.
            structure_weight: cross-entropy weight for non-air voxels
                (air voxels have weight 1).

        Returns:
            dict with ``loss`` (weighted reconstruction + VQ loss),
            ``recon_loss``, ``vq_loss``, the scalar tensors ``accuracy``,
            ``air_accuracy``, ``struct_accuracy`` and ``air_pct``, plus the
            flattened ``predictions``, ``targets`` and ``is_structure`` mask
            and the quantizer ``indices``.
        """
        out = self(block_ids)

        logits = out['logits'].permute(0, 2, 3, 4, 1).contiguous()
        logits_flat = logits.view(-1, self.vocab_size)
        # reshape (not view) so non-contiguous id tensors are accepted too.
        targets_flat = block_ids.reshape(-1)

        ce = F.cross_entropy(logits_flat, targets_flat, reduction='none')

        air_dev = air_tokens.to(targets_flat.device)
        is_air = torch.isin(targets_flat, air_dev)
        is_struct = ~is_air

        # Weighted mean: structure voxels count structure_weight times as much.
        weights = torch.ones_like(ce)
        weights[is_struct] = structure_weight
        recon_loss = (weights * ce).sum() / weights.sum()

        total_loss = recon_loss + out['vq_loss']

        with torch.no_grad():
            preds_flat = logits_flat.argmax(-1)
            correct = (preds_flat == targets_flat).float()

            # Fallback used when a batch has no air / no structure voxels;
            # created on the same device as the other metrics.
            zero = correct.new_zeros(())
            acc = correct.mean()
            air_acc = correct[is_air].mean() if is_air.any() else zero
            struct_acc = correct[is_struct].mean() if is_struct.any() else zero
            air_pct = is_air.float().mean()

        return {
            'loss': total_loss,
            'recon_loss': recon_loss,
            'vq_loss': out['vq_loss'],
            'accuracy': acc,
            'air_accuracy': air_acc,
            'struct_accuracy': struct_acc,
            'air_pct': air_pct,
            'predictions': preds_flat,
            'targets': targets_flat,
            'is_structure': is_struct,
            'indices': out['indices'],
        }

print("DiagnosticVQVAE defined!")

In [None]:
# ============================================================
# CELL 7: Diagnostic Tracking Classes
# ============================================================

class PerBlockTracker:
    """Accumulates per-token prediction counts to report per-block accuracy."""

    def __init__(self, vocab_size):
        # Both arrays are indexed by token id.
        self.correct = np.zeros(vocab_size)
        self.total = np.zeros(vocab_size)

    def update(self, preds, targets):
        """Add one batch of predictions.

        Args:
            preds: integer tensor of predicted token ids.
            targets: integer tensor of true token ids, same shape as ``preds``.

        Raises:
            IndexError: if a target id is not smaller than the vocabulary size.
        """
        preds_np = preds.cpu().numpy().reshape(-1)
        targets_np = targets.cpu().numpy().reshape(-1)

        # Two histograms replace a loop that rebuilt a full-length mask for
        # every distinct target token.
        size = len(self.total)
        seen = np.bincount(targets_np, minlength=size)
        if len(seen) > size:
            raise IndexError(
                f"target token {len(seen) - 1} is out of range for vocab size {size}"
            )
        hits = np.bincount(targets_np[preds_np == targets_np], minlength=size)

        self.total += seen
        self.correct += hits

    def get_accuracy(self, min_count=10):
        """Return ``{token: accuracy}`` for tokens seen at least ``min_count`` times."""
        return {
            t: self.correct[t] / self.total[t]
            for t in range(len(self.total))
            if self.total[t] >= min_count
        }

    def get_top_bottom(self, tok2block, k=20, min_count=10):
        """Return the ``k`` best and ``k`` worst blocks by accuracy.

        Each entry is ``(block_name, accuracy, count)``; tokens missing from
        ``tok2block`` are named ``tok_<id>``. Both lists are in descending
        accuracy order and overlap when fewer than ``2 * k`` tokens qualify.
        """
        accs = self.get_accuracy(min_count)
        sorted_toks = sorted(accs.keys(), key=lambda t: accs[t], reverse=True)

        def describe(t):
            return (tok2block.get(t, f'tok_{t}'), accs[t], int(self.total[t]))

        top = [describe(t) for t in sorted_toks[:k]]
        bottom = [describe(t) for t in sorted_toks[-k:]]
        return top, bottom


class ConfusionTracker:
    """Counts (true token, predicted token) pairs for misclassified structure voxels."""

    def __init__(self, max_track=1000):
        self.confusions = Counter()
        # Stored for callers; the methods of this class do not read it.
        self.max_track = max_track

    def update(self, preds, targets, is_structure):
        """Record every position that is a structure voxel and was predicted wrongly."""
        pred_arr = preds.cpu().numpy()
        true_arr = targets.cpu().numpy()
        struct_arr = is_structure.cpu().numpy()

        mistakes = (pred_arr != true_arr) & struct_arr
        self.confusions.update(
            (int(true_tok), int(pred_tok))
            for true_tok, pred_tok in zip(true_arr[mistakes], pred_arr[mistakes])
        )

    def get_top_confusions(self, tok2block, k=20):
        """Return the ``k`` most frequent confusions as ``(true_name, pred_name, count)``.

        Tokens missing from ``tok2block`` are named ``tok_<id>``.
        """
        return [
            (
                tok2block.get(true_tok, f'tok_{true_tok}'),
                tok2block.get(pred_tok, f'tok_{pred_tok}'),
                count,
            )
            for (true_tok, pred_tok), count in self.confusions.most_common(k)
        ]


class GradientTracker:
    """Keeps a history of the global gradient L2 norm of the encoder and the decoder."""

    def __init__(self):
        self.encoder_norms = []
        self.decoder_norms = []

    def compute_norms(self, model):
        """Append the current encoder / decoder gradient norms to the history.

        A parameter belongs to the encoder when its name contains "encoder",
        otherwise to the decoder when its name contains "decoder"; parameters
        without a gradient or matching neither name are ignored.
        """
        squared = {'encoder': 0.0, 'decoder': 0.0}

        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            if 'encoder' in name:
                part = 'encoder'
            elif 'decoder' in name:
                part = 'decoder'
            else:
                continue
            squared[part] += param.grad.norm().item() ** 2

        # Global norm = square root of the summed squared per-parameter norms.
        self.encoder_norms.append(squared['encoder'] ** 0.5)
        self.decoder_norms.append(squared['decoder'] ** 0.5)

print("Diagnostic trackers defined!")

In [None]:
# ============================================================
# CELL 8: Training Functions
# ============================================================

def train_epoch_diagnostic(model, loader, optimizer, scaler, device, air_tokens, 
                           structure_weight, grad_tracker=None):
    """Train ``model`` for one epoch with AMP and gradient accumulation.

    Gradients are accumulated over ``GRAD_ACCUM_STEPS`` micro-batches; each
    optimizer step unscales the gradients, clips them to a global norm of 1.0
    and then steps through the gradient scaler. A trailing group with fewer
    than ``GRAD_ACCUM_STEPS`` micro-batches is stepped the same way.

    Args:
        model: module exposing ``compute_loss_with_diagnostics`` and a
            ``quantizer`` with usage statistics.
        loader: iterable of ``(batch, index)`` pairs.
        optimizer: optimizer over the model's trainable parameters.
        scaler: gradient scaler used for mixed precision.
        device: device the batches are moved to.
        air_tokens: tensor of air token ids, forwarded to the loss.
        structure_weight: loss weight for non-air voxels, forwarded to the loss.
        grad_tracker: optional tracker whose ``compute_norms(model)`` is
            called right before every optimizer step.

    Returns:
        dict of per-batch means (``loss``, ``recon``, ``vq``, ``acc``,
        ``air_acc``, ``struct_acc``, ``air_pct``) plus the quantizer's
        ``codebook_usage``, ``perplexity`` and ``entropy`` for the epoch.
    """
    model.train()
    model.quantizer.reset_epoch_stats()

    metrics = {'loss': 0, 'recon': 0, 'vq': 0, 'acc': 0, 'air_acc': 0, 'struct_acc': 0, 'air_pct': 0}
    n = 0

    def apply_accumulated_step():
        # Gradients must be unscaled before clipping, otherwise the clip
        # threshold would be compared against scaled values.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        if grad_tracker is not None:
            grad_tracker.compute_norms(model)

        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    optimizer.zero_grad()

    for batch_idx, (batch, _) in enumerate(tqdm(loader, desc="Train", leave=False)):
        batch = batch.to(device)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            out = model.compute_loss_with_diagnostics(batch, air_tokens, structure_weight)
            # Scale so the accumulated gradient is the mean over the group.
            loss = out['loss'] / GRAD_ACCUM_STEPS

        scaler.scale(loss).backward()

        if (batch_idx + 1) % GRAD_ACCUM_STEPS == 0:
            apply_accumulated_step()

        metrics['loss'] += out['loss'].item()
        metrics['recon'] += out['recon_loss'].item()
        metrics['vq'] += out['vq_loss'].item()
        metrics['acc'] += out['accuracy'].item()
        metrics['air_acc'] += out['air_accuracy'].item()
        metrics['struct_acc'] += out['struct_accuracy'].item()
        metrics['air_pct'] += out['air_pct'].item()
        n += 1

    # Flush the last, incomplete accumulation group. It goes through the same
    # clipped step as every other update and leaves no stale gradients behind.
    # Its micro-batch losses were still divided by GRAD_ACCUM_STEPS, so this
    # update is proportionally smaller than a full one.
    if n % GRAD_ACCUM_STEPS != 0:
        apply_accumulated_step()

    metrics['codebook_usage'] = model.quantizer.get_usage_fraction()
    metrics['perplexity'] = model.quantizer.get_perplexity()
    metrics['entropy'] = model.quantizer.get_entropy()

    # Summed metrics become per-batch means; quantizer statistics already
    # describe the whole epoch.
    return {k: v/n if k not in ['codebook_usage', 'perplexity', 'entropy'] else v for k, v in metrics.items()}


@torch.no_grad()
def validate_diagnostic(model, loader, device, air_tokens, structure_weight,
                        block_tracker=None, confusion_tracker=None):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns per-batch means of ``loss``, ``recon``, ``acc``, ``air_acc``,
    ``struct_acc`` and ``air_pct`` together with the quantizer's
    ``codebook_usage``, ``perplexity`` and ``entropy`` measured over this
    pass. When given, ``block_tracker`` and ``confusion_tracker`` are fed
    the flattened predictions of every batch.
    """
    model.eval()
    model.quantizer.reset_epoch_stats()

    # Reported metric name -> key in the dict returned by the model.
    metric_sources = {
        'loss': 'loss',
        'recon': 'recon_loss',
        'acc': 'accuracy',
        'air_acc': 'air_accuracy',
        'struct_acc': 'struct_accuracy',
        'air_pct': 'air_pct',
    }
    totals = dict.fromkeys(metric_sources, 0)
    num_batches = 0

    for batch, _ in tqdm(loader, desc="Val", leave=False):
        batch = batch.to(device)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            out = model.compute_loss_with_diagnostics(batch, air_tokens, structure_weight)

        for metric_name, out_key in metric_sources.items():
            totals[metric_name] += out[out_key].item()
        num_batches += 1

        if block_tracker:
            block_tracker.update(out['predictions'], out['targets'])
        if confusion_tracker:
            confusion_tracker.update(out['predictions'], out['targets'], out['is_structure'])

    codebook_stats = {
        'codebook_usage': model.quantizer.get_usage_fraction(),
        'perplexity': model.quantizer.get_perplexity(),
        'entropy': model.quantizer.get_entropy(),
    }

    # Averages first, then the quantizer statistics, which are not averaged.
    result = {metric_name: total / num_batches for metric_name, total in totals.items()}
    result.update(codebook_stats)
    return result

print("Training functions defined!")

In [None]:
# ============================================================
# CELL 9: Visualization Functions
# ============================================================

@torch.no_grad()
def get_reconstruction_samples(model, dataset, device, air_tokens, n_samples=5):
    """Reconstruct a random selection of dataset items and score them.

    Args:
        model: callable module returning a dict with per-voxel ``logits``
            of shape (1, vocab, D, H, W) for a (1, D, H, W) id tensor.
        dataset: indexable collection of ``(structure, index)`` pairs.
        device: device the structures are moved to.
        air_tokens: tensor of token ids treated as air.
        n_samples: number of distinct items to draw; capped at
            ``len(dataset)`` so small datasets do not raise.

    Returns:
        list of dicts with ``original`` and ``reconstructed`` (NumPy arrays),
        ``overall_acc``, ``struct_acc`` (0 when the item has no non-air
        voxel) and the dataset ``idx`` as a plain int.
    """
    model.eval()
    samples = []

    # Sampling without replacement cannot return more items than exist.
    n_pick = min(n_samples, len(dataset))
    indices = np.random.choice(len(dataset), n_pick, replace=False)

    air_dev = air_tokens.to(device)

    for idx in indices:
        original, _ = dataset[idx]
        original = original.unsqueeze(0).to(device)

        out = model(original)
        recon = out['logits'].argmax(dim=1).squeeze(0)

        orig_flat = original.view(-1)
        recon_flat = recon.view(-1)

        is_struct = ~torch.isin(orig_flat, air_dev)
        matches = (orig_flat == recon_flat).float()

        overall_acc = matches.mean().item()
        struct_acc = matches[is_struct].mean().item() if is_struct.any() else 0

        samples.append({
            'original': original.squeeze(0).cpu().numpy(),
            'reconstructed': recon.cpu().numpy(),
            'overall_acc': overall_acc,
            'struct_acc': struct_acc,
            'idx': int(idx),  # plain int rather than a NumPy scalar
        })

    return samples


def visualize_slice(original, reconstructed, z_slice, tok2block, air_tokens, ax_orig, ax_recon):
    """Draw one slice of a structure and of its reconstruction side by side.

    Args:
        original, reconstructed: 3-D arrays of token ids; the slice is taken
            along the last axis.
        z_slice: index of the slice to draw.
        tok2block: accepted for interface compatibility; not used here.
        air_tokens: container of token ids drawn in white.
        ax_orig, ax_recon: axes objects providing ``imshow``, ``set_title``
            and ``axis``.

    Every non-air token gets a fixed pseudo-random colour derived from its
    id. The colours come from a private generator seeded per token, so the
    global NumPy random state is left untouched.
    """
    def color_for(tok):
        if tok in air_tokens:
            return [1, 1, 1]
        # Same values as seeding the global generator with `tok` and calling
        # np.random.rand(3), without reseeding that global generator.
        return np.random.RandomState(tok).rand(3)

    def render(plane):
        plane = np.asarray(plane)
        rgb = np.zeros((*plane.shape, 3))
        # Colour once per distinct token instead of once per voxel.
        for tok in np.unique(plane):
            rgb[plane == tok] = color_for(int(tok))
        return rgb

    orig_rgb = render(original[:, :, z_slice])
    recon_rgb = render(reconstructed[:, :, z_slice])

    for ax, image, title in ((ax_orig, orig_rgb, 'Original'),
                             (ax_recon, recon_rgb, 'Reconstructed')):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis('off')


def compute_category_accuracy(block_tracker, categories, air_tokens):
    """Average per-block accuracy within each block category.

    Only tokens for which ``block_tracker.get_accuracy(min_count=5)`` reports
    a value, and which are not in ``air_tokens``, are included. Categories
    with no such token are left out of the result.
    """
    accuracy_by_token = block_tracker.get_accuracy(min_count=5)

    result = {}
    for category, tokens in categories.items():
        scores = [
            accuracy_by_token[tok]
            for tok in tokens
            if tok in accuracy_by_token and tok not in air_tokens
        ]
        if not scores:
            continue
        # Unweighted mean over blocks: every block counts equally.
        result[category] = np.mean(scores)

    return result

print("Visualization functions defined!")

In [None]:
# ============================================================
# CELL 10: Main Experiment Runner
# ============================================================

def run_full_experiment(name, embeddings, emb_dim, seed, air_tokens, train_ds, val_ds):
    """Train one DiagnosticVQVAE with the given embeddings and collect diagnostics.

    The function seeds the Python, NumPy and torch RNGs, builds fresh data
    loaders and a fresh model, trains for ``EPOCHS`` epochs while recording
    per-epoch metrics, then runs one more validation pass with per-block and
    confusion trackers attached and gathers reconstruction samples.

    Args:
        name: Label for this run; printed and stored in the result.
        embeddings: Embedding table passed to the model as ``pretrained_emb``.
        emb_dim: Embedding dimension passed to the model.
        seed: Seed applied to every RNG and to the training shuffle generator.
        air_tokens: Air token ids forwarded to the train/validation/sample
            helpers.
        train_ds: Training dataset.
        val_ds: Validation dataset.

    Returns:
        Dict with the run label, seed, embedding dim, best and final structure
        accuracy, final loss/perplexity/codebook usage, training time in
        seconds, the per-epoch ``history``, top/bottom blocks, top confusions,
        per-category accuracy and reconstruction samples.
    """
    print(f"\n{'='*70}")
    print(f"{name} (seed={seed}, dim={emb_dim})")
    print(f"{'='*70}")

    # Seed every RNG before anything random (model init, shuffling) happens.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.empty_cache()

    # A dedicated generator makes the training shuffle order depend on the seed.
    shuffle_rng = torch.Generator()
    shuffle_rng.manual_seed(seed)
    loader_opts = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
    train_loader = DataLoader(train_ds, shuffle=True, generator=shuffle_rng, **loader_opts)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_opts)

    model = DiagnosticVQVAE(
        vocab_size=VOCAB_SIZE,
        emb_dim=emb_dim,
        hidden_dims=HIDDEN_DIMS,
        latent_dim=LATENT_DIM,
        num_codes=NUM_CODEBOOK_ENTRIES,
        commitment_cost=COMMITMENT_COST,
        ema_decay=EMA_DECAY,
        dead_threshold=DEAD_CODE_THRESHOLD,
        pretrained_emb=embeddings,
    ).to(device)

    # Only parameters that require gradients are counted and optimised.
    learnable = [p for p in model.parameters() if p.requires_grad]
    trainable = sum(p.numel() for p in learnable)
    print(f"Trainable params: {trainable:,}")

    optimizer = optim.AdamW(learnable, lr=LEARNING_RATE)
    scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP)

    grad_tracker = GradientTracker()

    # Metric names copied from the epoch result dicts into the history lists
    # named 'train_<metric>' / 'val_<metric>'.
    train_metrics = ('loss', 'recon', 'vq', 'acc', 'air_acc', 'struct_acc',
                     'codebook_usage', 'perplexity', 'entropy')
    val_metrics = ('loss', 'acc', 'air_acc', 'struct_acc',
                   'codebook_usage', 'perplexity', 'entropy')

    history = {
        'train_loss': [], 'train_recon': [], 'train_vq': [],
        'train_acc': [], 'train_air_acc': [], 'train_struct_acc': [],
        'train_codebook_usage': [], 'train_perplexity': [], 'train_entropy': [],
        'val_loss': [], 'val_acc': [], 'val_air_acc': [], 'val_struct_acc': [],
        'val_codebook_usage': [], 'val_perplexity': [], 'val_entropy': [],
        'encoder_grad_norm': [], 'decoder_grad_norm': [],
    }

    best_struct_acc = 0
    start = time.time()

    for epoch in range(EPOCHS):
        train_m = train_epoch_diagnostic(model, train_loader, optimizer, scaler, device,
                                         air_tokens, STRUCTURE_WEIGHT, grad_tracker)
        val_m = validate_diagnostic(model, val_loader, device, air_tokens, STRUCTURE_WEIGHT)

        for metric in train_metrics:
            history[f'train_{metric}'].append(train_m[metric])
        for metric in val_metrics:
            history[f'val_{metric}'].append(val_m[metric])

        # Gradient norms are summarised over the 100 most recent recorded values.
        if grad_tracker.encoder_norms:
            recent_encoder = grad_tracker.encoder_norms[-100:]
            recent_decoder = grad_tracker.decoder_norms[-100:]
            history['encoder_grad_norm'].append(np.mean(recent_encoder))
            history['decoder_grad_norm'].append(np.mean(recent_decoder))

        best_struct_acc = max(best_struct_acc, val_m['struct_acc'])

        print(f"Epoch {epoch+1:2d} | Struct: {train_m['struct_acc']:.1%} | "
              f"Val: {val_m['struct_acc']:.1%} | Perp: {train_m['perplexity']:.0f} | "
              f"CB: {train_m['codebook_usage']:.1%}")

    train_time = time.time() - start

    # One extra validation pass, this time with the detailed trackers attached.
    print("\nFinal detailed validation...")
    block_tracker = PerBlockTracker(VOCAB_SIZE)
    confusion_tracker = ConfusionTracker()

    final_val = validate_diagnostic(model, val_loader, device, air_tokens, STRUCTURE_WEIGHT,
                                    block_tracker, confusion_tracker)

    top_blocks, bottom_blocks = block_tracker.get_top_bottom(tok2block, k=TOP_K_BLOCKS)
    top_confusions = confusion_tracker.get_top_confusions(tok2block, k=20)
    # Uses the module-level AIR_TOKENS rather than the air_tokens argument.
    category_acc = compute_category_accuracy(block_tracker, BLOCK_CATEGORIES, AIR_TOKENS)

    samples = get_reconstruction_samples(model, val_ds, device, air_tokens, NUM_RECONSTRUCTION_SAMPLES)

    results = {
        'name': name,
        'seed': seed,
        'emb_dim': emb_dim,
        'best_struct_acc': best_struct_acc,
        'final_struct_acc': final_val['struct_acc'],
        'final_val_loss': final_val['loss'],
        'final_perplexity': final_val['perplexity'],
        'final_codebook_usage': final_val['codebook_usage'],
        'training_time': train_time,
        'history': history,
        'top_blocks': top_blocks,
        'bottom_blocks': bottom_blocks,
        'top_confusions': top_confusions,
        'category_accuracy': category_acc,
        'reconstruction_samples': samples,
    }

    # Drop the references to the model and optimiser state before the next run.
    del model, optimizer, scaler
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return results

# Confirmation printed once this cell has run and run_full_experiment is defined.
print("Experiment runner defined!")

In [None]:
# ============================================================
# CELL 11: Run All Experiments
# ============================================================

# Banner with the run configuration.
print("="*70)
print("VQ-VAE EMBEDDING VALIDATION - COLAB")
print("="*70)
print(f"Seeds: {BASE_SEEDS}")
print(f"Epochs: {EPOCHS}")
print(f"Embeddings: V1, V3, Random")
print()

# Maps embedding name -> list of result dicts, one per seed. Results are
# appended as each run finishes, so completed runs stay available if the cell
# is interrupted part-way through.
all_results = {}

# Pretrained embeddings: each EMBEDDINGS_BASE entry supplies the table under
# 'embeddings' and its width under 'dim'; one training run per (embedding, seed).
for emb_name, emb_data in EMBEDDINGS_BASE.items():
    all_results[emb_name] = []
    for seed in BASE_SEEDS:
        result = run_full_experiment(
            name=emb_name,
            embeddings=emb_data['embeddings'],
            emb_dim=emb_data['dim'],
            seed=seed,
            air_tokens=AIR_TOKENS_TENSOR,
            train_ds=train_dataset,
            val_ds=val_dataset,
        )
        all_results[emb_name].append(result)

# Random baseline: a fresh Gaussian table is drawn for every seed.
all_results['Random'] = []
for seed in BASE_SEEDS:
    # Seed right before drawing so the random table depends only on the seed.
    np.random.seed(seed)
    # Width 32 must match the emb_dim=32 passed below. v1_std is defined in an
    # earlier cell (presumably the std of the V1 embeddings; confirm).
    # NOTE(review): the product's dtype depends on v1_std's type; confirm it
    # stays float32 if that matters downstream.
    rand_emb = np.random.randn(VOCAB_SIZE, 32).astype(np.float32) * v1_std
    
    result = run_full_experiment(
        name='Random',
        embeddings=rand_emb,
        emb_dim=32,
        seed=seed,
        air_tokens=AIR_TOKENS_TENSOR,
        train_ds=train_dataset,
        val_ds=val_dataset,
    )
    all_results['Random'].append(result)

print("\n" + "="*70)
print("ALL EXPERIMENTS COMPLETE")
print("="*70)

In [None]:
# ============================================================
# CELL 12: Results Summary
# ============================================================

print("\n" + "="*70)
print("RESULTS SUMMARY")
print("="*70)

# Aggregate every embedding's runs: mean/std of the best validation structure
# accuracy, plus mean final perplexity and mean final codebook usage.
summary_stats = {}
for name, runs in all_results.items():
    struct_accs = [run['best_struct_acc'] for run in runs]
    summary_stats[name] = {
        'struct_acc_mean': np.mean(struct_accs),
        'struct_acc_std': np.std(struct_accs),
        'struct_acc_all': struct_accs,
        'perplexity_mean': np.mean([run['final_perplexity'] for run in runs]),
        'cb_usage_mean': np.mean([run['final_codebook_usage'] for run in runs]),
    }

# Highest mean structure accuracy first.
sorted_names = sorted(
    summary_stats,
    key=lambda emb: summary_stats[emb]['struct_acc_mean'],
    reverse=True,
)

print("\n{:<10} {:>15} {:>15} {:>15}".format("Embedding", "Struct Acc", "Perplexity", "CB Usage"))
print("-"*60)

row_template = "{:<10} {:>15.1%} {:>15.0f} {:>15.1%}"
for name in sorted_names:
    stats = summary_stats[name]
    print(row_template.format(
        name, stats['struct_acc_mean'], stats['perplexity_mean'], stats['cb_usage_mean']
    ))

# The first entry of the sorted list is the best-scoring embedding.
winner = sorted_names[0]
print(f"\nWINNER: {winner} with {summary_stats[winner]['struct_acc_mean']:.1%} structure accuracy")

In [None]:
# ============================================================
# CELL 13: Save Results
# ============================================================

# Summary file: per-embedding aggregates, coerced to plain floats so the json
# module can encode them.
summary_output = {
    emb_name: {
        'struct_acc_mean': float(emb_stats['struct_acc_mean']),
        'struct_acc_std': float(emb_stats['struct_acc_std']),
        'perplexity_mean': float(emb_stats['perplexity_mean']),
        'cb_usage_mean': float(emb_stats['cb_usage_mean']),
    }
    for emb_name, emb_stats in summary_stats.items()
}

with open(f"{OUTPUT_DIR}/validation_summary.json", 'w') as f:
    json.dump(summary_output, f, indent=2)

# Detailed file: one record per run with its seed, accuracies, training time,
# full per-epoch history and per-category accuracy (all numbers as floats).
detailed_output = {
    emb_name: [
        {
            'seed': run['seed'],
            'best_struct_acc': float(run['best_struct_acc']),
            'final_struct_acc': float(run['final_struct_acc']),
            'training_time': float(run['training_time']),
            'history': {
                metric: list(map(float, values))
                for metric, values in run['history'].items()
            },
            'category_accuracy': {
                category: float(acc)
                for category, acc in run['category_accuracy'].items()
            },
        }
        for run in runs
    ]
    for emb_name, runs in all_results.items()
}

with open(f"{OUTPUT_DIR}/validation_detailed.json", 'w') as f:
    json.dump(detailed_output, f, indent=2)

print(f"Results saved to {OUTPUT_DIR}/")
print(f"  - validation_summary.json")
print(f"  - validation_detailed.json")

In [None]:
# ============================================================
# CELL 14: Training Curves Plot
# ============================================================

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

colors = {'V1': 'blue', 'V3': 'purple', 'Random': 'red'}
names = ['V1', 'V3', 'Random']

# The three line-plot panels differ only in the history series they show and a
# few cosmetic options, so they are drawn from one table:
# (axes, history key, title, extra plot kwargs, extra title kwargs, y-label).
curve_panels = [
    (axes[0, 0], 'val_struct_acc', 'Validation Structure Accuracy',
     {'linewidth': 2}, {'fontweight': 'bold'}, 'Accuracy'),
    (axes[0, 1], 'train_loss', 'Training Loss', {}, {}, None),
    (axes[1, 0], 'train_codebook_usage', 'Codebook Usage', {}, {}, None),
]

for ax, metric, title, line_kwargs, title_kwargs, y_label in curve_panels:
    for name in names:
        # Only the first run of each embedding is plotted.
        series = all_results[name][0]['history'][metric]
        ax.plot(series, label=name, color=colors[name], **line_kwargs)
    ax.set_title(title, **title_kwargs)
    ax.set_xlabel('Epoch')
    if y_label is not None:
        ax.set_ylabel(y_label)
    ax.legend()
    ax.grid(True, alpha=0.3)

# Bar chart of the mean structure accuracy from summary_stats, one bar per
# embedding, with the value written just above each bar.
ax = axes[1, 1]
means = [summary_stats[name]['struct_acc_mean'] for name in names]
bars = ax.bar(names, means, color=[colors[name] for name in names], edgecolor='black')
ax.set_title('Final Structure Accuracy')
ax.set_ylabel('Accuracy')
for bar, mean in zip(bars, means):
    x_center = bar.get_x() + bar.get_width()/2
    y_above = bar.get_height() + 0.01
    ax.text(x_center, y_above, f'{mean:.1%}', ha='center', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/training_curves.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"Plot saved to {OUTPUT_DIR}/training_curves.png")