# VQ-VAE Embedding Validation V2

## Improvements over V1 Validation

| Issue | V1 Validation | V2 Validation (this notebook) |
|-------|---------------|-------------------------------|
| Codebook Learning | Gradient-based (unstable) | **EMA updates (stable)** |
| Dead Code Reset | None | **Reset underutilized codes** |
| Structure Weighting | 1x (equal) | **10x weight on structure** |
| Commitment Cost | 0.25 | **0.5** |
| Epochs | 10 | **20** |
| Architecture | [32, 64, 128] | **[64, 128, 256]** |
| Codebook Monitoring | None | **Track usage per epoch** |

## Embeddings Tested

| Embedding | Dimensions | Description |
|-----------|------------|-------------|
| V1 | 32 | Skip-gram (co-occurrence) |
| V2 | 32 | Hybrid + state collapsing |
| V3 | 40 | Compositional (material + shape + properties) |
| Random | 32 | Random Gaussian baseline (same std as V1) |

## Key Metric: Structure Accuracy

Roughly 80% of voxels are air, so overall accuracy is dominated by air and says little about how well structures are reconstructed. We therefore track **structure accuracy**: accuracy over non-air blocks only.

In [None]:
# ============================================================
# CELL 1: Imports and Setup
# ============================================================

import json
import random
import time
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

# Select the compute device: the first CUDA GPU when one is visible, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

if device == "cuda":
    # Report which GPU will be used and its total memory in GB.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# ============================================================
# CELL 2: Configuration
# ============================================================

# === Data Paths ===
DATA_DIR = "/kaggle/input/minecraft-schematics/minecraft_splits/splits/train"
VAL_DIR = "/kaggle/input/minecraft-schematics/minecraft_splits/splits/val"
VOCAB_PATH = "/kaggle/input/minecraft-schematics/tok2block.json"

# Embeddings paths
V1_EMBEDDINGS_PATH = "/kaggle/input/block2vec-embeddings/block_embeddings.npy"
V2_EMBEDDINGS_PATH = "/kaggle/input/block2vec-v2/block_embeddings_v2.npy"
V2_MAPPING_PATH = "/kaggle/input/block2vec-v2/original_to_collapsed.json"
V3_EMBEDDINGS_PATH = "/kaggle/input/block2vec-v3/block_embeddings_v3.npy"

OUTPUT_DIR = "/kaggle/working"

# === Model Architecture (IMPROVED) ===
HIDDEN_DIMS = [64, 128, 256]  # Larger than v1 validation
LATENT_DIM = 256
NUM_CODEBOOK_ENTRIES = 512

# === VQ-VAE Fixes (from vqvae_training_analysis.md) ===
COMMITMENT_COST = 0.5       # Was 0.25 - too low
EMA_DECAY = 0.99            # NEW: EMA codebook updates
DEAD_CODE_THRESHOLD = 2     # NEW: Reset codes used less than this
STRUCTURE_WEIGHT = 10.0     # NEW: Weight structure blocks 10x

# === Training (IMPROVED) ===
EPOCHS = 20                 # Was 10 - not enough
BATCH_SIZE = 4
LEARNING_RATE = 3e-4
# Mixed precision is only meaningful on CUDA. Without a GPU,
# torch.amp.autocast('cuda') / GradScaler('cuda') merely warn and disable
# themselves (autocast warns on every batch), so only enable AMP with CUDA.
USE_AMP = torch.cuda.is_available()
GRAD_ACCUM_STEPS = 4        # Micro-batches accumulated per optimizer step

# === Other ===
SEED = 42
NUM_WORKERS = 2

print("VQ-VAE Validation V2 Configuration:")
print(f"  Hidden dims: {HIDDEN_DIMS}")
print(f"  Latent dim: {LATENT_DIM}")
print(f"  Epochs: {EPOCHS}")
print(f"  Batch size: {BATCH_SIZE} (effective: {BATCH_SIZE * GRAD_ACCUM_STEPS})")
print(f"  Commitment cost: {COMMITMENT_COST}")
print(f"  EMA decay: {EMA_DECAY}")
print(f"  Structure weight: {STRUCTURE_WEIGHT}x")
print(f"  Dead code threshold: {DEAD_CODE_THRESHOLD}")
print(f"  Mixed precision (AMP): {USE_AMP}")

In [None]:
# ============================================================
# CELL 3: Load Vocabulary and Find Air Tokens
# ============================================================

import re

# Load vocabulary: JSON object mapping token id (string key) -> block name
with open(VOCAB_PATH, 'r') as f:
    tok2block = {int(k): v for k, v in json.load(f).items()}

VOCAB_SIZE = len(tok2block)
print(f"Vocabulary size: {VOCAB_SIZE} block types")


def is_air_block(block: str) -> bool:
    """Return True if *block* names an air variant (e.g. air, cave_air, void_air).

    The namespace prefix (anything up to the last ':') and any block-state
    suffix (from the first '[') are stripped. The remaining name is split into
    words on non-alphanumeric characters, and the block counts as air when
    one of those words is exactly "air". Whole-word matching avoids the false
    positives of a plain substring test ("stairs", "chair", ...).
    """
    name = block.lower().split('[', 1)[0].rsplit(':', 1)[-1]
    return 'air' in re.split(r'[^a-z0-9]+', name)


# Find ALL air tokens (critical bug fix from v1)
# Token 0 is UNKNOWN_BLOCK, not air!
AIR_TOKENS = set()
for tok, block in tok2block.items():
    if is_air_block(block):
        AIR_TOKENS.add(tok)
        print(f"  Found air token: {tok} = '{block}'")

if not AIR_TOKENS:
    # Without air tokens, structure accuracy silently equals overall accuracy.
    print("  WARNING: no air tokens found - structure accuracy will equal overall accuracy!")

AIR_TOKENS_LIST = sorted(AIR_TOKENS)
AIR_TOKENS_TENSOR = torch.tensor(AIR_TOKENS_LIST, dtype=torch.long)
print(f"\nAir tokens: {AIR_TOKENS_LIST}")

# Save for reproducibility
air_info = {
    "air_tokens": AIR_TOKENS_LIST,
    "note": "These tokens are excluded from structure accuracy calculation"
}
with open(f"{OUTPUT_DIR}/air_tokens_used.json", 'w') as f:
    json.dump(air_info, f, indent=2)
print(f"Saved air tokens to {OUTPUT_DIR}/air_tokens_used.json")

In [None]:
# ============================================================
# CELL 4: Load and Prepare All Embeddings
# ============================================================

# V1: Skip-gram embeddings
v1_embeddings = np.load(V1_EMBEDDINGS_PATH).astype(np.float32)
print(f"V1 embeddings: {v1_embeddings.shape}")

# V2: Hybrid embeddings defined over a collapsed vocabulary, plus the
# original-token -> collapsed-token mapping used to expand them.
v2_collapsed = np.load(V2_EMBEDDINGS_PATH).astype(np.float32)
with open(V2_MAPPING_PATH, 'r') as f:
    original_to_collapsed = {int(k): int(v) for k, v in json.load(f).items()}
print(f"V2 collapsed embeddings: {v2_collapsed.shape}")

# Expand V2 to the full vocabulary. The width is taken from the file rather
# than hard-coded. Tokens without a valid collapsed id keep a zero vector and
# are counted, so silent gaps in the mapping become visible.
v2_embeddings = np.zeros((VOCAB_SIZE, v2_collapsed.shape[1]), dtype=np.float32)
v2_unmapped = 0
for orig_tok in range(VOCAB_SIZE):
    collapsed_tok = original_to_collapsed.get(orig_tok)
    # Negative ids would silently index from the end, so treat them as unmapped.
    if collapsed_tok is not None and 0 <= collapsed_tok < len(v2_collapsed):
        v2_embeddings[orig_tok] = v2_collapsed[collapsed_tok]
    else:
        v2_unmapped += 1
print(f"V2 expanded embeddings: {v2_embeddings.shape}")
if v2_unmapped:
    print(f"  WARNING: {v2_unmapped} tokens have no valid V2 mapping (left as zero vectors)")

# V3: Compositional embeddings
v3_embeddings = np.load(V3_EMBEDDINGS_PATH).astype(np.float32)
print(f"V3 embeddings: {v3_embeddings.shape}")

# Random: Gaussian with the same width and overall std as V1
np.random.seed(SEED)
v1_std = v1_embeddings.std()
random_embeddings = (
    np.random.randn(VOCAB_SIZE, v1_embeddings.shape[1]).astype(np.float32) * v1_std
)
print(f"Random embeddings: {random_embeddings.shape} (std={v1_std:.4f})")

# Store all variants with their dimensions (read from the arrays themselves)
EMBEDDINGS = {
    'V1': {'embeddings': v1_embeddings, 'dim': int(v1_embeddings.shape[1])},
    'V2': {'embeddings': v2_embeddings, 'dim': int(v2_embeddings.shape[1])},
    'V3': {'embeddings': v3_embeddings, 'dim': int(v3_embeddings.shape[1])},
    'Random': {'embeddings': random_embeddings, 'dim': int(random_embeddings.shape[1])},
}

print("\nAll embeddings loaded:")
for name, data in EMBEDDINGS.items():
    print(f"  {name}: {data['embeddings'].shape} ({data['dim']}-dim)")
    # The model's nn.Embedding needs exactly one row per vocabulary token.
    if data['embeddings'].shape[0] != VOCAB_SIZE:
        raise ValueError(
            f"{name} embeddings have {data['embeddings'].shape[0]} rows, "
            f"expected VOCAB_SIZE={VOCAB_SIZE}"
        )

In [None]:
# ============================================================
# CELL 5: Dataset
# ============================================================

class VQVAEDataset(Dataset):
    """Map-style dataset over a directory of ``*.h5`` voxel structures.

    One file is one sample. ``__getitem__`` reads the first dataset stored in
    the file and returns it as an int64 (``long``) tensor.

    Args:
        data_dir: Directory scanned (non-recursively) for ``*.h5`` files.
        seed: Accepted for call-site compatibility; not used.

    Raises:
        ValueError: If ``data_dir`` contains no ``*.h5`` files.
    """

    def __init__(self, data_dir: str, seed: int = 42):
        self.data_dir = Path(data_dir)
        # Sorted so the sample order is deterministic across runs.
        self.h5_files = sorted(self.data_dir.glob("*.h5"))
        if len(self.h5_files) == 0:
            raise ValueError(f"No H5 files found in {data_dir}")
        print(f"Found {len(self)} structures in {data_dir}")

    def __len__(self):
        return len(self.h5_files)

    def __getitem__(self, idx):
        path = self.h5_files[idx]
        with h5py.File(path, 'r') as h5:
            dataset_name = list(h5.keys())[0]
            voxels = h5[dataset_name][:].astype(np.int64)
        return torch.from_numpy(voxels).long()


train_dataset = VQVAEDataset(DATA_DIR, seed=SEED)
val_dataset = VQVAEDataset(VAL_DIR, seed=SEED)

# Pinned host memory only speeds up host->GPU copies. On CPU-only runs
# DataLoader would just warn that pin_memory has no effect.
PIN_MEMORY = device == "cuda"

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

In [None]:
# ============================================================
# CELL 6: Improved VQ-VAE Model with EMA and Dead Code Reset
# ============================================================

class ResidualBlock3D(nn.Module):
    """Channel- and size-preserving 3D residual block.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``. Both convolutions
    use 3x3x3 kernels with padding 1, so the output shape equals the input shape.
    """
    def __init__(self, channels: int):
        super().__init__()
        # Registration order (conv1, conv2, bn1, bn2) is kept so the
        # state_dict layout and seeded initialisation are unchanged.
        self.conv1 = nn.Conv3d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv3d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm3d(channels)
        self.bn2 = nn.BatchNorm3d(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        return F.relu(branch + x)


class VectorQuantizerEMA(nn.Module):
    """
    Vector Quantizer with EMA codebook updates and dead code reset.

    Key improvements over gradient-based VQ:
    1. The codebook is a buffer updated from exponential moving averages of
       the encoder outputs assigned to each code, not by gradient descent.
    2. Codes whose EMA-smoothed assignment count falls below
       ``dead_code_threshold`` are restarted from random encoder outputs.
    3. ``code_usage`` records, per code, in how many forward passes it was
       selected, for monitoring codebook utilisation.

    Why the dead-code test uses the EMA count rather than the current batch's
    count: when a batch holds N latent vectors and N < num_codes, at least
    num_codes - N codes receive zero assignments in *every* batch. A
    per-batch test therefore restarts a large part of the codebook on every
    step, including well-used codes that merely missed one batch.

    Args:
        num_codes: Number of codebook entries.
        latent_dim: Size of each codebook vector; must equal the channel
            dimension of the encoder output.
        commitment_cost: Weight of the commitment loss.
        ema_decay: Decay factor for the EMA statistics.
        dead_code_threshold: A code is restarted when its EMA-smoothed
            per-batch assignment count is below this value.
    """
    def __init__(
        self,
        num_codes: int,
        latent_dim: int,
        commitment_cost: float = 0.5,
        ema_decay: float = 0.99,
        dead_code_threshold: int = 2,
    ):
        super().__init__()
        self.num_codes = num_codes
        self.latent_dim = latent_dim
        self.commitment_cost = commitment_cost
        self.ema_decay = ema_decay
        self.dead_code_threshold = dead_code_threshold

        # Codebook (a buffer: updated via EMA, never by the optimizer)
        self.register_buffer('codebook', torch.randn(num_codes, latent_dim))
        self.codebook.data.uniform_(-1/num_codes, 1/num_codes)

        # EMA tracking
        self.register_buffer('ema_cluster_size', torch.zeros(num_codes))
        self.register_buffer('ema_embed_sum', torch.zeros(num_codes, latent_dim))

        # Usage tracking (for monitoring)
        self.register_buffer('code_usage', torch.zeros(num_codes))

    def reset_usage_stats(self):
        """Reset usage stats at start of each epoch."""
        self.code_usage.zero_()

    def get_codebook_usage(self) -> float:
        """Return fraction of codes used since the last reset_usage_stats()."""
        return (self.code_usage > 0).float().mean().item()

    def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Quantize encoder output to its nearest codebook entries.

        Args:
            z_e: Encoder output [B, C, D, H, W] with C == latent_dim

        Returns:
            z_q: Quantized output [B, C, D, H, W]; gradients pass straight
                through to z_e
            vq_loss: commitment_cost * MSE(z_e, stop_grad(z_q)). The codebook
                is updated via EMA (training mode only), not by this loss
            indices: Code indices [B, D, H, W]
        """
        # Reshape: [B, C, D, H, W] -> [B*D*H*W, C]
        z_e_perm = z_e.permute(0, 2, 3, 4, 1).contiguous()
        flat = z_e_perm.view(-1, self.latent_dim)

        # Squared Euclidean distances ||x||^2 + ||c||^2 - 2 x.c to every code
        d = (
            flat.pow(2).sum(1, keepdim=True)
            + self.codebook.pow(2).sum(1)
            - 2 * flat @ self.codebook.t()
        )

        # Find nearest codes
        indices = d.argmin(dim=1)

        # Update usage tracking (each code counted at most once per call)
        with torch.no_grad():
            self.code_usage[indices.unique()] += 1

        # Look up codes BEFORE the EMA update. Advanced indexing copies, so the
        # in-place codebook update below does not change z_q for this batch.
        z_q_perm = self.codebook[indices].view(z_e_perm.shape)

        # EMA codebook update (only during training)
        if self.training:
            self._ema_update(flat.detach(), indices)

        # Commitment loss (encoder should commit to codes)
        commitment_loss = F.mse_loss(z_e_perm, z_q_perm.detach())
        vq_loss = self.commitment_cost * commitment_loss

        # Straight-through estimator
        z_q_st = z_e_perm + (z_q_perm - z_e_perm).detach()
        z_q = z_q_st.permute(0, 4, 1, 2, 3).contiguous()

        return z_q, vq_loss, indices.view(z_e_perm.shape[:-1])

    @torch.no_grad()
    def _ema_update(self, flat: torch.Tensor, indices: torch.Tensor) -> None:
        """EMA-update cluster statistics and codebook from one batch, then restart dead codes.

        Args:
            flat: Encoder outputs [N, latent_dim] (detached).
            indices: Assigned code index for each row of ``flat`` [N].
        """
        if flat.size(0) == 0:
            return  # nothing to learn from; avoids 0/0 in the smoothing below

        decay = self.ema_decay

        # One-hot encoding of assignments: [N, num_codes]
        encodings = F.one_hot(indices, self.num_codes).float()
        batch_cluster_size = encodings.sum(0)
        batch_embed_sum = encodings.t() @ flat

        self.ema_cluster_size.mul_(decay).add_(batch_cluster_size, alpha=1 - decay)
        self.ema_embed_sum.mul_(decay).add_(batch_embed_sum, alpha=1 - decay)

        # Laplace smoothing to avoid division by zero for never-used codes
        eps = 1e-5
        n = self.ema_cluster_size.sum()
        smoothed_cluster_size = (
            (self.ema_cluster_size + eps)
            / (n + self.num_codes * eps)
            * n
        )
        self.codebook.copy_(self.ema_embed_sum / smoothed_cluster_size.unsqueeze(1))

        # Dead code reset, judged on the EMA-smoothed usage (see class docstring)
        dead_codes = self.ema_cluster_size < self.dead_code_threshold
        n_dead = int(dead_codes.sum().item())
        if n_dead > 0:
            random_indices = torch.randint(0, flat.size(0), (n_dead,), device=flat.device)
            samples = flat[random_indices].to(self.codebook.dtype)
            # Restarted codes get a count equal to the threshold (at least 1)
            # and a matching embed sum, so codebook == embed_sum / count holds.
            reset_size = max(float(self.dead_code_threshold), 1.0)
            self.codebook[dead_codes] = samples
            self.ema_cluster_size[dead_codes] = reset_size
            self.ema_embed_sum[dead_codes] = samples * reset_size


class ImprovedVQVAE(nn.Module):
    """
    3D VQ-VAE over voxel grids of block ids, with the improvements from
    vqvae_training_analysis.md:
    - EMA codebook updates and dead code reset (VectorQuantizerEMA)
    - Structure-weighted reconstruction loss (see ``compute_loss``)
    - Frozen pretrained block embeddings of any dimension

    Args:
        vocab_size: Number of block types (embedding rows / decoder output channels).
        embedding_dim: Width of each pretrained block embedding.
        hidden_dims: Channels of the stride-2 encoder stages (mirrored by the
            stride-2 transposed-conv decoder stages).
        latent_dim: Channels of the encoder output and codebook vectors.
        num_codes: Codebook size.
        commitment_cost: Forwarded to VectorQuantizerEMA.
        ema_decay: Forwarded to VectorQuantizerEMA.
        dead_code_threshold: Forwarded to VectorQuantizerEMA.
        pretrained_embeddings: Array of shape (vocab_size, embedding_dim).

    Raises:
        ValueError: If ``pretrained_embeddings`` is not (vocab_size, embedding_dim).
            Without this check ``copy_`` would broadcast some wrong shapes
            (e.g. a single row) silently.
    """
    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        hidden_dims: List[int],
        latent_dim: int,
        num_codes: int,
        commitment_cost: float,
        ema_decay: float,
        dead_code_threshold: int,
        pretrained_embeddings: np.ndarray,
    ):
        super().__init__()
        expected_shape = (vocab_size, embedding_dim)
        if tuple(pretrained_embeddings.shape) != expected_shape:
            raise ValueError(
                f"pretrained_embeddings has shape {tuple(pretrained_embeddings.shape)}, "
                f"expected {expected_shape}"
            )

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.latent_dim = latent_dim

        # Block embeddings (frozen)
        self.block_emb = nn.Embedding(vocab_size, embedding_dim)
        self.block_emb.weight.data.copy_(torch.from_numpy(pretrained_embeddings))
        self.block_emb.weight.requires_grad = False

        # Encoder - note: first layer adapts to embedding_dim
        enc_layers = []
        in_ch = embedding_dim
        for h_dim in hidden_dims:
            enc_layers.extend([
                nn.Conv3d(in_ch, h_dim, 4, stride=2, padding=1),
                nn.BatchNorm3d(h_dim),
                nn.ReLU(inplace=True),
                ResidualBlock3D(h_dim),
            ])
            in_ch = h_dim
        enc_layers.append(nn.Conv3d(in_ch, latent_dim, 3, padding=1))
        self.encoder = nn.Sequential(*enc_layers)

        # Vector Quantizer with EMA
        self.quantizer = VectorQuantizerEMA(
            num_codes=num_codes,
            latent_dim=latent_dim,
            commitment_cost=commitment_cost,
            ema_decay=ema_decay,
            dead_code_threshold=dead_code_threshold,
        )

        # Decoder
        dec_layers = []
        in_ch = latent_dim
        for h_dim in reversed(hidden_dims):
            dec_layers.extend([
                ResidualBlock3D(in_ch),
                nn.ConvTranspose3d(in_ch, h_dim, 4, stride=2, padding=1),
                nn.BatchNorm3d(h_dim),
                nn.ReLU(inplace=True),
            ])
            in_ch = h_dim
        dec_layers.append(nn.Conv3d(in_ch, vocab_size, 3, padding=1))
        self.decoder = nn.Sequential(*dec_layers)

    def forward(self, block_ids: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Embed, encode, quantize and decode ``block_ids`` [B, D, H, W] into per-voxel logits."""
        # Embed blocks: [B, D, H, W] -> [B, D, H, W, emb_dim]
        x = self.block_emb(block_ids)
        x = x.permute(0, 4, 1, 2, 3).contiguous()  # [B, emb_dim, D, H, W]

        # Encode
        z_e = self.encoder(x)

        # Quantize
        z_q, vq_loss, indices = self.quantizer(z_e)

        # Decode: logits [B, vocab, D, H, W]
        logits = self.decoder(z_q)

        return {'logits': logits, 'vq_loss': vq_loss, 'indices': indices}

    def compute_loss(
        self,
        block_ids: torch.Tensor,
        air_tokens: torch.Tensor,
        structure_weight: float = 10.0,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the structure-weighted loss and accuracy metrics.

        Voxels whose target is not in ``air_tokens`` ("structure") are weighted
        ``structure_weight`` times as much as air voxels in the cross-entropy,
        to counter the heavy air imbalance.

        Returns:
            Dict with 'loss' (recon + vq), 'recon_loss', 'vq_loss', 'accuracy',
            'air_accuracy', 'struct_accuracy', 'air_percentage' and 'indices'.
            An accuracy over an empty subset (e.g. no air in the batch) is 0.
        """
        out = self(block_ids)

        # Reshape logits for loss computation: [B, D, H, W, vocab]
        logits = out['logits'].permute(0, 2, 3, 4, 1).contiguous()
        logits_flat = logits.view(-1, self.vocab_size)
        targets_flat = block_ids.view(-1)

        # Compute per-element cross entropy
        ce_loss = F.cross_entropy(logits_flat, targets_flat, reduction='none')

        # Apply structure weighting
        air_tokens_device = air_tokens.to(targets_flat.device)
        is_air = torch.isin(targets_flat, air_tokens_device)
        is_structure = ~is_air

        # Weight: air=1, structure=structure_weight
        weights = torch.ones_like(ce_loss).masked_fill_(is_structure, structure_weight)

        # Weighted mean
        recon_loss = (weights * ce_loss).sum() / weights.sum()

        total_loss = recon_loss + out['vq_loss']

        # Compute accuracy metrics
        with torch.no_grad():
            preds_flat = logits_flat.argmax(dim=-1)
            correct = (preds_flat == targets_flat).float()
            zero = correct.new_zeros(())  # same device/dtype as the metrics

            acc = correct.mean()
            air_acc = correct[is_air].mean() if is_air.any() else zero
            # Structure accuracy (KEY METRIC)
            struct_acc = correct[is_structure].mean() if is_structure.any() else zero
            air_pct = is_air.float().mean()

        return {
            'loss': total_loss,
            'recon_loss': recon_loss,
            'vq_loss': out['vq_loss'],
            'accuracy': acc,
            'air_accuracy': air_acc,
            'struct_accuracy': struct_acc,
            'air_percentage': air_pct,
            'indices': out['indices'],
        }


print("ImprovedVQVAE with EMA codebook and structure weighting defined!")

In [None]:
# ============================================================
# CELL 7: Training Functions
# ============================================================

def _optimizer_step(
    model: nn.Module,
    optimizer: optim.Optimizer,
    scaler: torch.amp.GradScaler,
) -> None:
    """Unscale, clip (max norm 1.0), step, update the scaler and clear gradients."""
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()


def train_epoch(
    model: ImprovedVQVAE,
    loader: DataLoader,
    optimizer: optim.Optimizer,
    scaler: torch.amp.GradScaler,
    device: str,
    air_tokens: torch.Tensor,
    structure_weight: float,
) -> Dict[str, float]:
    """Train ``model`` for one pass over ``loader`` with gradient accumulation.

    Gradients are accumulated over GRAD_ACCUM_STEPS micro-batches (global
    setting) before each clipped optimizer step. If the number of batches is
    not a multiple of GRAD_ACCUM_STEPS, the trailing partial group is still
    stepped. Its losses are divided by the group's actual size, so every step
    uses the mean gradient of its group instead of a down-scaled one.

    Returns:
        Per-batch means of loss/recon/vq/acc/air_acc/struct_acc/air_pct, plus
        'codebook_usage': fraction of codes used at least once this epoch.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    model.quantizer.reset_usage_stats()

    metrics = {
        'loss': 0, 'recon': 0, 'vq': 0,
        'acc': 0, 'air_acc': 0, 'struct_acc': 0, 'air_pct': 0
    }
    n = 0

    n_batches = len(loader)
    # Batches with index >= full_groups_end belong to a trailing partial group.
    full_groups_end = n_batches - n_batches % GRAD_ACCUM_STEPS

    optimizer.zero_grad()

    for batch_idx, batch in enumerate(tqdm(loader, desc="Train", leave=False)):
        batch = batch.to(device)
        if batch_idx < full_groups_end:
            group_size = GRAD_ACCUM_STEPS
        else:
            group_size = n_batches - full_groups_end

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            out = model.compute_loss(batch, air_tokens, structure_weight)
            loss = out['loss'] / group_size

        scaler.scale(loss).backward()

        # Step at the end of every group, including a trailing partial one.
        if (batch_idx + 1) % GRAD_ACCUM_STEPS == 0 or batch_idx + 1 == n_batches:
            _optimizer_step(model, optimizer, scaler)

        metrics['loss'] += out['loss'].item()
        metrics['recon'] += out['recon_loss'].item()
        metrics['vq'] += out['vq_loss'].item()
        metrics['acc'] += out['accuracy'].item()
        metrics['air_acc'] += out['air_accuracy'].item()
        metrics['struct_acc'] += out['struct_accuracy'].item()
        metrics['air_pct'] += out['air_percentage'].item()
        n += 1

    if n == 0:
        raise ValueError("train_epoch(): loader yielded no batches")

    # Add codebook usage (already a fraction, so it is not averaged)
    metrics['codebook_usage'] = model.quantizer.get_codebook_usage()

    return {k: v / n if k != 'codebook_usage' else v for k, v in metrics.items()}


@torch.no_grad()
def validate(
    model: ImprovedVQVAE,
    loader: DataLoader,
    device: str,
    air_tokens: torch.Tensor,
    structure_weight: float,
) -> Dict[str, float]:
    """Evaluate ``model`` on ``loader`` in eval mode (no weight or codebook updates).

    Returns:
        Per-batch means of loss/recon/acc/air_acc/struct_acc/air_pct, plus
        'codebook_usage': fraction of codes selected at least once on ``loader``.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    model.quantizer.reset_usage_stats()

    metrics = {
        'loss': 0, 'recon': 0,
        'acc': 0, 'air_acc': 0, 'struct_acc': 0, 'air_pct': 0
    }
    n = 0

    for batch in tqdm(loader, desc="Val", leave=False):
        batch = batch.to(device)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            out = model.compute_loss(batch, air_tokens, structure_weight)

        metrics['loss'] += out['loss'].item()
        metrics['recon'] += out['recon_loss'].item()
        metrics['acc'] += out['accuracy'].item()
        metrics['air_acc'] += out['air_accuracy'].item()
        metrics['struct_acc'] += out['struct_accuracy'].item()
        metrics['air_pct'] += out['air_percentage'].item()
        n += 1

    if n == 0:
        raise ValueError("validate(): loader yielded no batches")

    # Codebook usage is already a fraction, so it is not averaged.
    metrics['codebook_usage'] = model.quantizer.get_codebook_usage()

    return {k: v / n if k != 'codebook_usage' else v for k, v in metrics.items()}


print("Training functions defined!")

In [None]:
# ============================================================
# CELL 8: Run Experiment for One Embedding Type
# ============================================================

def run_experiment(
    name: str,
    embeddings: np.ndarray,
    embedding_dim: int,
    air_tokens: torch.Tensor,
) -> Dict[str, Any]:
    """Train a fresh ImprovedVQVAE on the given frozen embeddings and summarise it.

    All RNGs are re-seeded with SEED so the embedding is the only variable
    between runs. The model trains for EPOCHS epochs on the global
    ``train_loader`` and is validated on ``val_loader`` after every epoch.
    Progress and two sanity checks are printed, then the model is freed.

    Returns:
        Dict with final and best validation metrics, final training codebook
        usage, mean validation air fraction, training wall time (seconds),
        and the per-epoch ``history`` lists.
    """
    banner = '=' * 70
    print(f"\n{banner}")
    print(f"Training with {name} embeddings ({embedding_dim}-dim)")
    print(banner)

    cuda_ok = torch.cuda.is_available()

    # Release cached GPU memory left over from the previous experiment
    if cuda_ok:
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    # Identical seeding for every experiment
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)
    if cuda_ok:
        torch.cuda.manual_seed(SEED)

    model = ImprovedVQVAE(
        vocab_size=VOCAB_SIZE,
        embedding_dim=embedding_dim,
        hidden_dims=HIDDEN_DIMS,
        latent_dim=LATENT_DIM,
        num_codes=NUM_CODEBOOK_ENTRIES,
        commitment_cost=COMMITMENT_COST,
        ema_decay=EMA_DECAY,
        dead_code_threshold=DEAD_CODE_THRESHOLD,
        pretrained_embeddings=embeddings,
    ).to(device)

    n_total = sum(p.numel() for p in model.parameters())
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {n_total:,}")
    print(f"Trainable parameters: {n_trainable:,}")

    if cuda_ok:
        print(f"GPU memory: {torch.cuda.memory_allocated()/1e9:.2f} GB allocated")

    # The frozen block embedding is excluded from the optimizer
    optimizer = optim.AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=LEARNING_RATE,
    )
    scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP)

    # history['<split>_<metric>'] holds one value per epoch
    tracked = ('loss', 'acc', 'air_acc', 'struct_acc', 'air_pct', 'codebook_usage')
    history = {f'{split}_{key}': [] for split in ('train', 'val') for key in tracked}

    best_val_struct_acc = 0.0
    start_time = time.time()

    for epoch in range(EPOCHS):
        epoch_metrics = {
            'train': train_epoch(
                model, train_loader, optimizer, scaler, device,
                air_tokens, STRUCTURE_WEIGHT
            ),
            'val': validate(
                model, val_loader, device, air_tokens, STRUCTURE_WEIGHT
            ),
        }

        for split, split_metrics in epoch_metrics.items():
            for key in tracked:
                history[f'{split}_{key}'].append(split_metrics[key])

        train_m = epoch_metrics['train']
        val_m = epoch_metrics['val']
        best_val_struct_acc = max(best_val_struct_acc, val_m['struct_acc'])

        print(
            f"Epoch {epoch+1:2d}/{EPOCHS} | "
            f"Loss: {train_m['loss']:.3f} | "
            f"Struct: {train_m['struct_acc']:.1%} | "
            f"Val Struct: {val_m['struct_acc']:.1%} | "
            f"CB Usage: {train_m['codebook_usage']:.1%}"
        )

    train_time = time.time() - start_time
    print(f"Training time: {train_time/60:.1f} minutes")

    # Sanity check 1: with ~80% air, structure and overall accuracy should differ
    final_struct = history['val_struct_acc'][-1]
    final_overall = history['val_acc'][-1]
    if abs(final_struct - final_overall) < 0.01:
        print("WARNING: Structure and overall accuracy too similar - air detection may be broken!")
    else:
        print(f"OK: Structure ({final_struct:.1%}) differs from overall ({final_overall:.1%})")

    # Sanity check 2: codebook collapse
    final_codebook_usage = history['train_codebook_usage'][-1]
    if final_codebook_usage < 0.2:
        print(f"WARNING: Low codebook usage ({final_codebook_usage:.1%}) - possible collapse!")
    else:
        print(f"OK: Codebook usage ({final_codebook_usage:.1%}) is healthy")

    results = {
        'name': name,
        'embedding_dim': embedding_dim,
        'final_val_loss': history['val_loss'][-1],
        'final_val_acc': final_overall,
        'final_val_struct_acc': final_struct,
        'final_val_air_acc': history['val_air_acc'][-1],
        'best_val_struct_acc': best_val_struct_acc,
        'best_val_loss': min(history['val_loss']),
        'best_val_acc': max(history['val_acc']),
        'final_codebook_usage': final_codebook_usage,
        'avg_air_pct': np.mean(history['val_air_pct']),
        'training_time': train_time,
        'history': history,
    }

    # Drop GPU-resident objects before releasing the cache
    del model, optimizer, scaler
    if cuda_ok:
        torch.cuda.empty_cache()

    return results


print("Experiment function defined!")

In [None]:
# ============================================================
# CELL 9: Run All Experiments
# ============================================================

print("=" * 70)
print("VQ-VAE EMBEDDING VALIDATION V2")
print("=" * 70)
# One multi-line print; stdout is identical to one print per line.
print(
    "\nImprovements over V1 validation:\n"
    f"  - EMA codebook (decay={EMA_DECAY})\n"
    f"  - Dead code reset (threshold={DEAD_CODE_THRESHOLD})\n"
    f"  - Structure weighting ({STRUCTURE_WEIGHT}x)\n"
    f"  - Commitment cost: {COMMITMENT_COST}\n"
    f"  - Epochs: {EPOCHS}\n"
    f"  - Architecture: {HIDDEN_DIMS}"
)
print(f"\nAir tokens: {AIR_TOKENS_LIST}")
print("")

# One training run per embedding variant, in EMBEDDINGS order
all_results = {}
for name, data in EMBEDDINGS.items():
    all_results[name] = results = run_experiment(
        name=name,
        embeddings=data['embeddings'],
        embedding_dim=data['dim'],
        air_tokens=AIR_TOKENS_TENSOR,
    )

print("\n" + "=" * 70)
print("ALL EXPERIMENTS COMPLETE")
print("=" * 70)

In [None]:
# ============================================================
# CELL 10: Compare Results
# ============================================================

print("\n" + "="*70)
print("RESULTS COMPARISON")
print("="*70)

# Mean fraction of air voxels in validation, averaged over all runs
avg_air_pct = np.mean([r['avg_air_pct'] for r in all_results.values()])
print(f"\nAverage air percentage: {avg_air_pct:.1%}")
print("This is why STRUCTURE accuracy is the key metric!\n")

# Results table. Loss/accuracy columns are each run's best value over all
# epochs (CB Usage is the final epoch), so one row's columns may come from
# different epochs.
header = "{:<10} {:>8} {:>10} {:>12} {:>12} {:>10} {:>10}".format(
    "Embedding", "Dim", "Val Loss", "Overall Acc", "STRUCT Acc", "CB Usage", "Time"
)
print(header)
print("-" * len(header))

# Sort by structure accuracy
sorted_names = sorted(
    all_results.keys(),
    key=lambda x: all_results[x]['best_val_struct_acc'],
    reverse=True
)

for name in sorted_names:
    r = all_results[name]
    # Time is 9 digits + 'm' so it lines up with the 10-wide header column
    print("{:<10} {:>8} {:>10.4f} {:>12.1%} {:>12.1%} {:>10.1%} {:>9.1f}m".format(
        name,
        r['embedding_dim'],
        r['best_val_loss'],
        r['best_val_acc'],
        r['best_val_struct_acc'],
        r['final_codebook_usage'],
        r['training_time']/60,
    ))

# Improvement over random
print("\n" + "="*70)
print("IMPROVEMENT OVER RANDOM BASELINE")
print("="*70)

if 'Random' not in all_results:
    print("\nNo 'Random' run found; skipping baseline comparison.")
else:
    random_struct = all_results['Random']['best_val_struct_acc']
    # Compare every non-baseline run that actually completed, in run order
    for name in [n for n in all_results if n != 'Random']:
        struct = all_results[name]['best_val_struct_acc']
        print(f"\n{name} vs Random:")
        print(f"  Structure Acc: {struct:.1%} vs {random_struct:.1%}")
        if random_struct > 0:
            improvement = (struct - random_struct) / random_struct * 100
            print(f"  Improvement: {improvement:+.1f}%")
        else:
            print("  Improvement: n/a (random baseline structure accuracy is 0)")

In [None]:
# ============================================================
# CELL 11: Plot Results
# ============================================================

fig, axes = plt.subplots(2, 4, figsize=(20, 10))

colors = {'V1': 'blue', 'V2': 'green', 'V3': 'purple', 'Random': 'red'}
# Plot only the runs that actually exist (avoids KeyError if one was skipped),
# in the fixed colour-table order.
names = [n for n in colors if n in all_results]


def _plot_history(ax, key, title, ylabel, hline=None, **title_kwargs):
    """Plot history[key] per epoch for every run in ``names`` on ``ax``.

    An optional dashed reference line ``hline`` (labelled 'Min healthy') is
    drawn before the legend so that it appears in it.
    """
    for run_name in names:
        ax.plot(all_results[run_name]['history'][key], label=run_name,
                color=colors[run_name], linewidth=2)
    ax.set_title(title, fontsize=12, **title_kwargs)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    if hline is not None:
        ax.axhline(y=hline, color='gray', linestyle='--', alpha=0.5, label='Min healthy')
    ax.legend()
    ax.grid(True, alpha=0.3)


# Row 1: Validation metrics and codebook usage
_plot_history(axes[0, 0], 'val_loss', 'Validation Loss', 'Loss')
_plot_history(axes[0, 1], 'val_acc', 'Overall Accuracy (includes ~80% air)', 'Accuracy')
_plot_history(axes[0, 2], 'val_struct_acc', 'STRUCTURE Accuracy (KEY METRIC)',
              'Structure Accuracy', fontweight='bold')
# NOTE(review): run_experiment() only warns below 20% usage, while this panel
# marks 30% as the healthy minimum - reconcile the two thresholds.
_plot_history(axes[0, 3], 'train_codebook_usage', 'Codebook Usage (should be >30%)',
              'Usage %', hline=0.3)

# Row 2: More metrics
_plot_history(axes[1, 0], 'val_air_acc', 'Air Block Accuracy', 'Air Accuracy')
_plot_history(axes[1, 1], 'train_loss', 'Training Loss', 'Loss')

# Train vs Val Structure Accuracy (check overfitting)
ax = axes[1, 2]
for name in names:
    ax.plot(all_results[name]['history']['train_struct_acc'],
            label=f'{name} train', color=colors[name], linewidth=2, linestyle='--')
    ax.plot(all_results[name]['history']['val_struct_acc'],
            label=f'{name} val', color=colors[name], linewidth=2)
ax.set_title('Train vs Val Structure Accuracy', fontsize=12)
ax.set_xlabel('Epoch')
ax.set_ylabel('Structure Accuracy')
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)

# Bar chart comparison
ax = axes[1, 3]
struct_accs = [all_results[name]['best_val_struct_acc'] for name in names]
bars = ax.bar(names, struct_accs, color=[colors[name] for name in names],
              edgecolor='black', linewidth=1.5)
ax.set_title('Best Structure Accuracy', fontsize=12, fontweight='bold')
ax.set_ylabel('Accuracy')
# Avoid a degenerate (0, 0) y-range when every accuracy is 0 or no runs exist
top_acc = max(struct_accs, default=0.0)
ax.set_ylim(0, top_acc * 1.2 if top_acc > 0 else 1.0)

# Add value labels
for bar, acc in zip(bars, struct_accs):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
            f'{acc:.1%}', ha='center', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/embedding_comparison_v2.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"\nSaved plot to {OUTPUT_DIR}/embedding_comparison_v2.png")

In [None]:
# ============================================================
# CELL 12: Save Results
# ============================================================


def _write_json(payload, filename):
    """Serialize ``payload`` as indented JSON to ``{OUTPUT_DIR}/{filename}``; return that path."""
    target = f"{OUTPUT_DIR}/{filename}"
    with open(target, 'w') as handle:
        json.dump(payload, handle, indent=2)
    return target


# Save summary: one flat record of final/best metrics per embedding.
# Values are cast to plain Python floats so json can serialize them.
summary = {
    emb_name: {
        'embedding_dim': res['embedding_dim'],
        'final_val_loss': float(res['final_val_loss']),
        'final_val_acc': float(res['final_val_acc']),
        'final_val_struct_acc': float(res['final_val_struct_acc']),
        'final_val_air_acc': float(res['final_val_air_acc']),
        'best_val_loss': float(res['best_val_loss']),
        'best_val_acc': float(res['best_val_acc']),
        'best_val_struct_acc': float(res['best_val_struct_acc']),
        'final_codebook_usage': float(res['final_codebook_usage']),
        'avg_air_percentage': float(res['avg_air_pct']),
        'training_time_minutes': float(res['training_time'] / 60),
    }
    for emb_name, res in all_results.items()
}
summary_path = _write_json(summary, "embedding_validation_v2_results.json")
print(f"Summary saved to {summary_path}")

# Save full history: every per-epoch series, converted element-wise to float.
full_results = {
    emb_name: {
        'embedding_dim': res['embedding_dim'],
        'history': {metric: list(map(float, series)) for metric, series in res['history'].items()},
        'training_time': float(res['training_time']),
        'best_val_struct_acc': float(res['best_val_struct_acc']),
        'best_val_loss': float(res['best_val_loss']),
    }
    for emb_name, res in all_results.items()
}
full_path = _write_json(full_results, "embedding_validation_v2_full.json")
print(f"Full history saved to {full_path}")

# Save config for reproducibility (hyperparameters defined in the config cell).
config = dict(
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    hidden_dims=HIDDEN_DIMS,
    latent_dim=LATENT_DIM,
    num_codebook_entries=NUM_CODEBOOK_ENTRIES,
    commitment_cost=COMMITMENT_COST,
    ema_decay=EMA_DECAY,
    dead_code_threshold=DEAD_CODE_THRESHOLD,
    structure_weight=STRUCTURE_WEIGHT,
    seed=SEED,
    air_tokens=AIR_TOKENS_LIST,
)
config_path = _write_json(config, "validation_config_v2.json")
print(f"Config saved to {config_path}")

In [None]:
# ============================================================
# CELL 13: Final Summary and Conclusion
# ============================================================


def _relative_change_pct(value, baseline):
    """Return the percent change of ``value`` relative to ``baseline``.

    A zero baseline would otherwise raise ZeroDivisionError, for example when
    a run's structure accuracy collapses to 0. In that case the function
    returns 0.0 if ``value`` is also zero, and +/-inf otherwise.
    """
    if baseline == 0:
        if value == 0:
            return 0.0
        return float('inf') if value > 0 else float('-inf')
    return (value - baseline) / baseline * 100


print("\n" + "="*70)
print("VQ-VAE EMBEDDING VALIDATION V2 - FINAL SUMMARY")
print("="*70)

print("\nConfiguration:")
print(f"  Epochs: {EPOCHS}")
print(f"  Architecture: {HIDDEN_DIMS}")
print(f"  EMA decay: {EMA_DECAY}")
print(f"  Structure weight: {STRUCTURE_WEIGHT}x")
print(f"  Commitment cost: {COMMITMENT_COST}")

print(f"\nAir tokens excluded: {AIR_TOKENS_LIST}")
print(f"Average air percentage: {avg_air_pct:.1%}")

print("\n" + "="*70)
print("RESULTS (sorted by structure accuracy)")
print("="*70)

print("\n{:<12} {:>12} {:>15} {:>15} {:>12}".format(
    "", "Val Loss", "Overall Acc", "STRUCT Acc", "CB Usage"
))
print("-"*70)

# NOTE(review): sorted_names is defined in an earlier cell; the heading above
# assumes it is ordered by best_val_struct_acc, descending — confirm.
for name in sorted_names:
    r = all_results[name]
    print("{:<12} {:>12.4f} {:>15.1%} {:>15.1%} {:>12.1%}".format(
        name,
        r['best_val_loss'],
        r['best_val_acc'],
        r['best_val_struct_acc'],
        r['final_codebook_usage'],
    ))

# Winner: the first entry of sorted_names, compared against the random baseline
winner = sorted_names[0]
winner_struct = all_results[winner]['best_val_struct_acc']
random_struct = all_results['Random']['best_val_struct_acc']
improvement = _relative_change_pct(winner_struct, random_struct)

print("\n" + "="*70)
print("CONCLUSION")
print("="*70)

print(f"\nWINNER: {winner} with {winner_struct:.1%} structure accuracy")
print(f"Improvement over random: {improvement:+.1f}%")

# Compare V1 vs V3 (relative difference, in percent of V3's accuracy)
v1_struct = all_results['V1']['best_val_struct_acc']
v3_struct = all_results['V3']['best_val_struct_acc']
v1_vs_v3 = _relative_change_pct(v1_struct, v3_struct)

print(f"\nV1 vs V3: {v1_vs_v3:+.1f}%")

# Within +/-5% relative difference the two embeddings are treated as a tie.
if abs(v1_vs_v3) < 5:
    print("  -> V1 and V3 are comparable. Either could work.")
elif v1_vs_v3 > 0:
    print("  -> V1 skip-gram is better. Use V1 for VQ-VAE.")
else:
    print("  -> V3 compositional is better! Use V3 for VQ-VAE.")

# Codebook health check: same 30% threshold as the dashed line in the usage plot
print("\nCodebook Health:")
for name in names:
    usage = all_results[name]['final_codebook_usage']
    status = "OK" if usage > 0.3 else "WARNING - possible collapse"
    print(f"  {name}: {usage:.1%} ({status})")

print("\n" + "="*70)
print("Files saved:")
print(f"  - {OUTPUT_DIR}/embedding_comparison_v2.png")
print(f"  - {OUTPUT_DIR}/embedding_validation_v2_results.json")
print(f"  - {OUTPUT_DIR}/embedding_validation_v2_full.json")
print(f"  - {OUTPUT_DIR}/validation_config_v2.json")
# NOTE(review): air_tokens_used.json is not written by the save-results cell;
# presumably an earlier cell writes it — confirm it actually exists.
print(f"  - {OUTPUT_DIR}/air_tokens_used.json")
print("="*70)