# VQ-VAE v2: Fixed Training with EMA Codebook

## Key Fixes from v1:

1. **EMA Codebook Updates**: Update the codebook with an exponential moving average (EMA) of the encoder outputs assigned to each code, instead of gradient descent
2. **Dead Code Reset**: Reinitialize codebook entries that are never used
3. **Weighted Loss**: Non-air blocks weighted 10x higher (combats class imbalance)
4. **Higher Commitment Cost**: beta=0.5 instead of 0.25
5. **Structure Accuracy Tracking**: Separate metrics for air vs non-air blocks

## Why v1 Failed (Codebook Collapse)

In v1, only 29 out of 512 codes were used because:
- Air blocks dominate (~90% of voxels), so the model learned to optimize mainly for predicting air
- Unused codebook entries drifted away and were never recovered
- A low commitment cost let encoder outputs "hover" between codes instead of committing to one

In [None]:
# ============================================================
# CELL 1: Imports and Setup
# ============================================================

import json
import random
import time
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

# Pick the compute device: CUDA when a GPU is visible, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

if device == "cuda":
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_mem_gb:.1f} GB")

In [None]:
# ============================================================
# CELL 2: Configuration (UPDATED)
# ============================================================

# --- Input / output locations (Kaggle dataset mounts and working dir) ---
DATA_DIR = "/kaggle/input/minecraft-schematics/minecraft_splits/splits/train"
VAL_DIR = "/kaggle/input/minecraft-schematics/minecraft_splits/splits/val"
VOCAB_PATH = "/kaggle/input/minecraft-schematics/tok2block.json"
EMBEDDINGS_PATH = "/kaggle/input/block2vec-embeddings/block_embeddings.npy"
OUTPUT_DIR = "/kaggle/working"

# --- Network shape ---
BLOCK_EMBEDDING_DIM = 32        # width of the pretrained block embeddings
HIDDEN_DIMS = [64, 128, 256]    # encoder stage widths (decoder uses them reversed)
LATENT_DIM = 256                # width of each quantized latent vector
NUM_CODEBOOK_ENTRIES = 512      # codebook size; EMA updates + dead-code reset aim to use all of it

# --- v2 anti-collapse settings ---
COMMITMENT_COST = 0.5           # raised from 0.25 in v1
EMA_DECAY = 0.99                # decay of the EMA codebook statistics
STRUCTURE_WEIGHT = 10.0         # loss weight of non-air voxels relative to air
USE_FOCAL_LOSS = False          # NOTE(review): not read by any cell in this notebook

# --- Optimisation ---
EPOCHS = 30
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
USE_AMP = True

# --- Misc ---
SEED = 42
NUM_WORKERS = 2
AIR_INDEX = 0                   # default air token; the vocabulary cell may override it

_config_banner = [
    "Configuration loaded (v2 with fixes)!",
    "  Key changes:",
    f"    - EMA codebook updates (decay={EMA_DECAY})",
    "    - Dead code reset",
    f"    - Structure weight: {STRUCTURE_WEIGHT}x for non-air blocks",
    f"    - Commitment cost: {COMMITMENT_COST}",
]
for _banner_line in _config_banner:
    print(_banner_line)

In [None]:
# ============================================================
# CELL 3: Load Vocabulary and Pre-trained Embeddings
# ============================================================

with open(VOCAB_PATH, 'r') as f:
    tok2block = {int(k): v for k, v in json.load(f).items()}

VOCAB_SIZE = len(tok2block)
print(f"Vocabulary size: {VOCAB_SIZE} block types")


def _find_air_token(vocab: Dict[int, str]) -> Optional[int]:
    """Return the vocabulary token of plain air, or None if none is found.

    An exact name match ("air" / "minecraft:air", ignoring any "[...]" block
    state suffix) is preferred. Otherwise variants such as "minecraft:cave_air"
    could be picked merely because they appear earlier in the vocabulary. The
    original substring heuristic ("air" but not "stair") is kept as a fallback.
    """
    for tok, block in vocab.items():
        if block.lower().split("[", 1)[0] in ("air", "minecraft:air"):
            return tok
    for tok, block in vocab.items():
        name = block.lower()
        if 'air' in name and 'stair' not in name:
            return tok
    return None


_air_tok = _find_air_token(tok2block)
if _air_tok is None:
    # Keep the configured default, but say so instead of failing silently.
    print(f"WARNING: no air block found in vocabulary; keeping AIR_INDEX = {AIR_INDEX}")
else:
    AIR_INDEX = _air_tok
    print(f"Air block: token {AIR_INDEX} = '{tok2block[AIR_INDEX]}'")

pretrained_embeddings = np.load(EMBEDDINGS_PATH)
print(f"Loaded embeddings: {pretrained_embeddings.shape}")
# The model copies this array into an nn.Embedding(VOCAB_SIZE, BLOCK_EMBEDDING_DIM);
# fail here with a clear message rather than later inside copy_().
if pretrained_embeddings.shape != (VOCAB_SIZE, BLOCK_EMBEDDING_DIM):
    raise ValueError(
        f"Embeddings shape {pretrained_embeddings.shape} does not match "
        f"(VOCAB_SIZE, BLOCK_EMBEDDING_DIM) = ({VOCAB_SIZE}, {BLOCK_EMBEDDING_DIM})"
    )

In [None]:
# ============================================================
# CELL 4: Dataset Class (unchanged)
# ============================================================

class VQVAEDataset(Dataset):
    """Dataset of voxel structures, one ``*.h5`` file per structure.

    Each item is the first dataset stored in the H5 file, returned as an int64
    tensor. With ``augment=True``, a random rotation by k*90 degrees in the
    plane of axes (0, 2) and a random flip along axis 2 are applied.
    """

    def __init__(self, data_dir: str, augment: bool = False, seed: int = 42):
        self.data_dir = Path(data_dir)
        self.augment = augment
        self.seed = seed
        self.rng = random.Random(seed)
        # Worker seed that self.rng was last derived from (None = main process / not yet re-seeded).
        self._worker_seed: Optional[int] = None
        self.h5_files = sorted(self.data_dir.glob("*.h5"))
        if not self.h5_files:
            raise ValueError(f"No H5 files found in {data_dir}")
        print(f"Found {len(self.h5_files)} structures in {data_dir}")

    def __len__(self) -> int:
        return len(self.h5_files)

    def _augmentation_rng(self) -> random.Random:
        """Return the augmentation RNG, re-seeded once inside each DataLoader worker.

        Every worker process receives its own copy of this dataset. Without
        re-seeding, all workers would replay the same ``random.Random(seed)``
        sequence, and so would every epoch when workers are re-created (no
        ``persistent_workers``). ``worker_info.seed`` is distinct per worker,
        and PyTorch derives it from the main process RNG, so runs stay
        reproducible under ``torch.manual_seed``.
        """
        info = torch.utils.data.get_worker_info()
        if info is not None and info.seed != self._worker_seed:
            self.rng = random.Random(f"{self.seed}:{info.seed}")
            self._worker_seed = info.seed
        return self.rng

    def __getitem__(self, idx: int) -> torch.Tensor:
        h5_path = self.h5_files[idx]
        with h5py.File(h5_path, 'r') as f:
            key = list(f.keys())[0]
            structure = f[key][:].astype(np.int64)

        if self.augment:
            rng = self._augmentation_rng()
            k = rng.randint(0, 3)
            if k > 0:
                structure = np.rot90(structure, k=k, axes=(0, 2))
            if rng.random() > 0.5:
                structure = np.flip(structure, axis=2)
            # rot90/flip return strided views; torch.from_numpy needs contiguous memory.
            structure = np.ascontiguousarray(structure)

        return torch.from_numpy(structure).long()

# Augment only the training split; validation data stays untouched.
train_dataset, val_dataset = (
    VQVAEDataset(DATA_DIR, augment=True, seed=SEED),
    VQVAEDataset(VAL_DIR, augment=False, seed=SEED),
)
print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

In [None]:
# ============================================================
# CELL 5: Create DataLoaders
# ============================================================

# Settings shared by both loaders; pinned host memory only helps CUDA transfers.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=(device == "cuda"),
)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

---
# Part 1: EMA Vector Quantizer (NEW!)

The key fix: instead of learning the codebook via gradient descent, we update it using **Exponential Moving Average**:

```
For each codebook entry i:
   N_i = decay * N_i + (1-decay) * (count of assignments to i)
   m_i = decay * m_i + (1-decay) * (sum of encoder outputs assigned to i)
   codebook[i] = m_i / N_i    (N_i is Laplace-smoothed with a small epsilon to avoid division by zero)
```

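This is more stable than learning the codebook by gradient descent and, together with dead-code resets, helps prevent codebook collapse.

We also **reset dead codes**: if a code's moving-average usage drops below 1% of the average usage across all codes, it is reinitialized to a randomly chosen encoder output from the current batch.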

In [None]:
# ============================================================
# CELL 6: EMA Vector Quantizer (KEY FIX!)
# ============================================================

class VectorQuantizerEMA(nn.Module):
    """
    Vector Quantization with EMA codebook updates.

    Key improvements over gradient-based VQ:
    1. EMA updates are more stable
    2. Dead code reset prevents codebook collapse
    3. No gradient on codebook - only commitment loss on encoder

    All quantization math runs in float32, also when the module is called
    inside an ``autocast`` region, so it is safe under mixed-precision training.
    """

    def __init__(
        self,
        num_embeddings: int = 512,
        embedding_dim: int = 256,
        commitment_cost: float = 0.5,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        dead_code_threshold: float = 0.01,
    ):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.decay = decay
        self.epsilon = epsilon
        self.dead_code_threshold = dead_code_threshold

        # Codebook (not a learnable parameter - updated via EMA)
        self.register_buffer("codebook", torch.randn(num_embeddings, embedding_dim))
        self.register_buffer("ema_cluster_size", torch.zeros(num_embeddings))
        self.register_buffer("ema_embed_sum", torch.randn(num_embeddings, embedding_dim))
        self.register_buffer("initialized", torch.tensor(False))
        # EMA of per-code assignment counts, used for dead-code detection and usage stats.
        self.register_buffer("usage_count", torch.zeros(num_embeddings))

    def _init_codebook(self, flat_z_e: torch.Tensor):
        """Initialize codebook from first batch of encoder outputs."""
        n_samples = flat_z_e.shape[0]
        if n_samples >= self.num_embeddings:
            indices = torch.randperm(n_samples, device=flat_z_e.device)[:self.num_embeddings]
            self.codebook.data.copy_(flat_z_e[indices])
        else:
            # Too few vectors: only the first n_samples rows are replaced; the rest stay random.
            self.codebook.data[:n_samples].copy_(flat_z_e)

        self.ema_cluster_size.fill_(1.0)
        self.ema_embed_sum.data.copy_(self.codebook.data)
        self.initialized.fill_(True)
        print("Codebook initialized from encoder outputs!")

    def _reset_dead_codes(self, flat_z_e: torch.Tensor, encoding_indices: torch.Tensor):
        """Reset codebook entries whose EMA usage is below threshold * mean usage.

        Dead entries are overwritten with randomly chosen rows of ``flat_z_e``
        (at most one per available row). Returns the number of dead entries
        detected, which can exceed the number actually reset.
        """
        batch_usage = torch.bincount(
            encoding_indices.view(-1), minlength=self.num_embeddings
        ).float()

        self.usage_count.data.mul_(self.decay).add_(batch_usage, alpha=1 - self.decay)

        avg_usage = self.usage_count.sum() / self.num_embeddings
        dead_mask = self.usage_count < (avg_usage * self.dead_code_threshold)
        n_dead = dead_mask.sum().item()

        if n_dead > 0 and flat_z_e.shape[0] > 0:
            n_samples = min(int(n_dead), flat_z_e.shape[0])
            indices = torch.randperm(flat_z_e.shape[0], device=flat_z_e.device)[:n_samples]
            # Indexed assignment (index_put_) requires matching dtypes, so match the buffers.
            samples = flat_z_e[indices].to(self.codebook.dtype)
            dead_indices = torch.where(dead_mask)[0][:n_samples]

            self.codebook.data[dead_indices] = samples
            self.ema_cluster_size.data[dead_indices] = 1.0
            self.ema_embed_sum.data[dead_indices] = samples
            self.usage_count.data[dead_indices] = avg_usage

        return n_dead

    def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Quantize ``z_e`` against the codebook.

        Args:
            z_e: channel-first 5-D tensor (B, C, D, H, W) with C == embedding_dim
                (the permute in ``_quantize`` relies on this layout).

        Returns:
            (z_q, vq_loss, encoding_indices):
            - z_q: straight-through quantized tensor in z_e's layout, float32.
            - vq_loss: commitment loss scaled by ``commitment_cost``.
            - encoding_indices: chosen code per position, shape (B, D, H, W).
        """
        # Under AMP the encoder emits float16. Quantizing in half precision both
        # degrades the ||z||^2 + ||e||^2 - 2 z.e distance expansion and makes the
        # dead-code reset's indexed assignment into float32 buffers fail. So
        # autocast is disabled here and the input is promoted to float32.
        with torch.autocast(device_type=z_e.device.type, enabled=False):
            return self._quantize(z_e.float())

    def _quantize(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """float32 implementation of ``forward`` (nearest code, EMA update, losses)."""
        # Reshape to [N, C]
        z_e_permuted = z_e.permute(0, 2, 3, 4, 1).contiguous()
        flat_z_e = z_e_permuted.view(-1, self.embedding_dim)

        # Initialize codebook from first batch
        if not self.initialized and self.training:
            self._init_codebook(flat_z_e)

        # Find nearest codebook entry via ||z||^2 + ||e||^2 - 2 z.e
        z_e_sq = (flat_z_e ** 2).sum(dim=1, keepdim=True)
        codebook_sq = (self.codebook ** 2).sum(dim=1, keepdim=True).t()
        dot_product = torch.mm(flat_z_e, self.codebook.t())
        distances = z_e_sq + codebook_sq - 2 * dot_product
        encoding_indices = distances.argmin(dim=1)

        # Looked up from the pre-update codebook.
        z_q_flat = F.embedding(encoding_indices, self.codebook)

        # EMA update (only during training)
        if self.training:
            encodings = F.one_hot(encoding_indices, self.num_embeddings).float()

            batch_cluster_size = encodings.sum(0)
            self.ema_cluster_size.data.mul_(self.decay).add_(
                batch_cluster_size, alpha=1 - self.decay
            )

            batch_embed_sum = encodings.t() @ flat_z_e
            self.ema_embed_sum.data.mul_(self.decay).add_(
                batch_embed_sum, alpha=1 - self.decay
            )

            # Laplace smoothing keeps every cluster size > 0 (no division by zero).
            n = self.ema_cluster_size.sum()
            smoothed_cluster_size = (
                (self.ema_cluster_size + self.epsilon) /
                (n + self.num_embeddings * self.epsilon) * n
            )

            self.codebook.data.copy_(
                self.ema_embed_sum / smoothed_cluster_size.unsqueeze(1)
            )

            self._reset_dead_codes(flat_z_e, encoding_indices)

        z_q_permuted = z_q_flat.view(z_e_permuted.shape)

        # Only commitment loss (codebook updated via EMA)
        commitment_loss = F.mse_loss(z_e_permuted, z_q_permuted.detach())
        vq_loss = self.commitment_cost * commitment_loss

        # Straight-through: forward uses z_q, gradients flow to z_e unchanged.
        z_q_st = z_e_permuted + (z_q_permuted - z_e_permuted).detach()
        z_q = z_q_st.permute(0, 4, 1, 2, 3).contiguous()

        encoding_indices = encoding_indices.view(z_e_permuted.shape[:-1])

        return z_q, vq_loss, encoding_indices

    def get_usage_stats(self) -> Tuple[int, float]:
        """Return (number, fraction) of codes whose EMA usage exceeds the dead threshold."""
        avg_usage = self.usage_count.sum() / self.num_embeddings
        used_mask = self.usage_count > (avg_usage * self.dead_code_threshold)
        num_used = used_mask.sum().item()
        return int(num_used), num_used / self.num_embeddings

print("VectorQuantizerEMA defined!")

In [None]:
# ============================================================
# CELL 7: Residual Block (unchanged)
# ============================================================

class ResidualBlock3D(nn.Module):
    """Pre-activation 3D residual block: two (BatchNorm -> ReLU -> 3x3x3 Conv) steps plus a shortcut.

    The shortcut is a 1x1x1 convolution when the channel count changes and an
    identity otherwise. Both 3x3x3 convolutions use stride 1 and padding 1, so
    spatial size is preserved.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Creation order fixes parameter init under a seed and state_dict key names.
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(in_channels)
        self.bn2 = nn.BatchNorm3d(out_channels)
        if in_channels == out_channels:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        shortcut = self.skip(x)
        h = self.conv1(F.relu(self.bn1(x)))
        h = self.conv2(F.relu(self.bn2(h)))
        return h + shortcut

In [None]:
# ============================================================
# CELL 8: Encoder (unchanged)
# ============================================================

class Encoder(nn.Module):
    """Downsampling 3D encoder from block embeddings to continuous latents.

    Every entry of ``hidden_dims`` adds one stage,
    Conv3d(k=4, s=2, p=1) -> BatchNorm3d -> ReLU -> ResidualBlock3D,
    which halves each (even) spatial dimension. A final 3x3x3 convolution
    projects to ``latent_dim`` channels.
    """

    def __init__(self, in_channels=32, hidden_dims=None, latent_dim=256):
        super().__init__()
        widths = [64, 128, 256] if hidden_dims is None else hidden_dims

        stages = []
        channels = in_channels
        for width in widths:
            stages += [
                nn.Conv3d(channels, width, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm3d(width),
                nn.ReLU(inplace=True),
                ResidualBlock3D(width, width),
            ]
            channels = width
        stages.append(nn.Conv3d(channels, latent_dim, kernel_size=3, padding=1))
        # The attribute name prefixes the state_dict keys ("encoder.<i>...").
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        return self.encoder(x)

In [None]:
# ============================================================
# CELL 9: Decoder (unchanged)
# ============================================================

class Decoder(nn.Module):
    """Upsampling 3D decoder from quantized latents to per-voxel block logits.

    Every entry of ``hidden_dims`` adds one stage,
    ResidualBlock3D -> ConvTranspose3d(k=4, s=2, p=1) -> BatchNorm3d -> ReLU,
    which doubles each spatial dimension. A final 3x3x3 convolution emits
    ``num_blocks`` logit channels.
    """

    def __init__(self, latent_dim=256, hidden_dims=None, num_blocks=3717):
        super().__init__()
        widths = [256, 128, 64] if hidden_dims is None else hidden_dims

        stages = []
        channels = latent_dim
        for width in widths:
            stages += [
                ResidualBlock3D(channels, channels),
                nn.ConvTranspose3d(channels, width, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm3d(width),
                nn.ReLU(inplace=True),
            ]
            channels = width
        stages.append(nn.Conv3d(channels, num_blocks, kernel_size=3, padding=1))
        # The attribute name prefixes the state_dict keys ("decoder.<i>...").
        self.decoder = nn.Sequential(*stages)

    def forward(self, z_q):
        return self.decoder(z_q)

In [None]:
# ============================================================
# CELL 10: Full VQ-VAE Model (UPDATED with weighted loss)
# ============================================================

class VQVAE(nn.Module):
    """3D VQ-VAE over voxel grids of block ids.

    Pipeline: frozen pretrained block embeddings -> Encoder -> EMA vector
    quantizer -> Decoder, which produces per-voxel logits over ``vocab_size``
    block types.
    """

    def __init__(
        self,
        vocab_size: int,
        block_embedding_dim: int,
        hidden_dims: List[int],
        latent_dim: int,
        num_codebook_entries: int,
        commitment_cost: float,
        ema_decay: float,
        pretrained_embeddings: np.ndarray,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.latent_dim = latent_dim
        self.num_codebook_entries = num_codebook_entries

        # Block embeddings are copied from the pretrained table and frozen.
        self.block_embeddings = nn.Embedding(vocab_size, block_embedding_dim)
        with torch.no_grad():
            self.block_embeddings.weight.copy_(torch.from_numpy(pretrained_embeddings))
        self.block_embeddings.weight.requires_grad_(False)

        # Sub-module creation order fixes parameter init under a seed; keep it.
        self.encoder = Encoder(
            in_channels=block_embedding_dim,
            hidden_dims=hidden_dims,
            latent_dim=latent_dim,
        )
        self.quantizer = VectorQuantizerEMA(
            num_embeddings=num_codebook_entries,
            embedding_dim=latent_dim,
            commitment_cost=commitment_cost,
            decay=ema_decay,
        )
        # The decoder mirrors the encoder: same stage widths, reversed.
        self.decoder = Decoder(
            latent_dim=latent_dim,
            hidden_dims=list(reversed(hidden_dims)),
            num_blocks=vocab_size,
        )

    def forward(self, block_ids: torch.Tensor) -> Dict[str, Any]:
        # (B, D, H, W) ids -> (B, D, H, W, E) embeddings -> channel-first (B, E, D, H, W)
        features = self.block_embeddings(block_ids).permute(0, 4, 1, 2, 3).contiguous()
        z_q, vq_loss, indices = self.quantizer(self.encoder(features))
        return {"logits": self.decoder(z_q), "vq_loss": vq_loss, "indices": indices}

    def compute_loss(
        self,
        block_ids: torch.Tensor,
        air_index: int = 0,
        structure_weight: float = 10.0,
    ) -> Dict[str, torch.Tensor]:
        """Weighted reconstruction loss plus VQ loss, with overall/air/structure accuracies.

        Each voxel's cross-entropy is weighted 1 for air and ``structure_weight``
        for any other block, then averaged over all voxels. When a batch has no
        air (or no structure) voxels, the corresponding accuracy is reported as 0.
        """
        outputs = self(block_ids)
        vq_loss = outputs["vq_loss"]

        targets = block_ids.view(-1)
        flat_logits = (
            outputs["logits"]
            .permute(0, 2, 3, 4, 1)
            .contiguous()
            .view(-1, self.vocab_size)
        )

        air_mask = targets == air_index
        structure_mask = ~air_mask

        # air -> 1, structure -> structure_weight
        voxel_weights = 1.0 + structure_mask.float() * (structure_weight - 1)
        per_voxel_ce = F.cross_entropy(flat_logits, targets, reduction='none')
        reconstruction_loss = (voxel_weights * per_voxel_ce).mean()

        total_loss = reconstruction_loss + vq_loss

        with torch.no_grad():
            hits = (flat_logits.argmax(dim=1) == targets).float()
            accuracy = hits.mean()
            if air_mask.any():
                air_accuracy = hits[air_mask].mean()
            else:
                air_accuracy = torch.tensor(0.0)
            if structure_mask.any():
                structure_accuracy = hits[structure_mask].mean()
            else:
                structure_accuracy = torch.tensor(0.0)

        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "vq_loss": vq_loss,
            "accuracy": accuracy,
            "air_accuracy": air_accuracy,
            "structure_accuracy": structure_accuracy,
            "indices": outputs["indices"],
        }

print("VQVAE model defined (with EMA and weighted loss)!")

In [None]:
# ============================================================
# CELL 11: Create Model
# ============================================================

# Seed torch, NumPy and Python's RNGs (same order as before) for reproducible init.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)

model = VQVAE(
    vocab_size=VOCAB_SIZE,
    block_embedding_dim=BLOCK_EMBEDDING_DIM,
    hidden_dims=HIDDEN_DIMS,
    latent_dim=LATENT_DIM,
    num_codebook_entries=NUM_CODEBOOK_ENTRIES,
    commitment_cost=COMMITMENT_COST,
    ema_decay=EMA_DECAY,
    pretrained_embeddings=pretrained_embeddings,
).to(device)

_model_params = list(model.parameters())
total_params = sum(p.numel() for p in _model_params)
trainable_params = sum(p.numel() for p in _model_params if p.requires_grad)
print(f"Total params: {total_params:,}, Trainable: {trainable_params:,}")

In [None]:
# ============================================================
# CELL 12: Test Forward Pass
# ============================================================

print("Testing forward pass...")
# Smoke-test in eval mode. In train mode (the default right after construction)
# this random batch would:
#   - make the quantizer initialize its codebook from random-input encodings and
#     set `initialized`, so real data would never re-initialize it,
#   - update the EMA/usage statistics,
#   - fold random-input statistics into BatchNorm running estimates.
_was_training = model.training
model.eval()
try:
    with torch.no_grad():
        test_batch = torch.randint(0, VOCAB_SIZE, (2, 32, 32, 32)).to(device)
        outputs = model.compute_loss(test_batch, air_index=AIR_INDEX, structure_weight=STRUCTURE_WEIGHT)
        print(f"  Loss: {outputs['loss'].item():.4f}")
        print(f"  Structure accuracy: {outputs['structure_accuracy'].item():.4f}")
finally:
    model.train(_was_training)
print("Forward pass successful!")

In [None]:
# ============================================================
# CELL 13: Create Optimizer and Scaler
# ============================================================

# Only trainable parameters are optimized (the block embeddings are frozen).
_trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(_trainable_params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Halve the LR when validation loss has not improved for 5 epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5
)
scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP)

print(f"Optimizer: AdamW, Scheduler: ReduceLROnPlateau")

In [None]:
# ============================================================
# CELL 14: Training Functions (UPDATED)
# ============================================================

def train_epoch(model, loader, optimizer, scaler, device, air_index, structure_weight, use_amp=True):
    """Train ``model`` for one pass over ``loader`` (AMP forward, clipped gradients).

    Args:
        model: module exposing ``compute_loss`` and ``quantizer.num_embeddings``.
        loader: iterable of block-id batches.
        optimizer: optimizer stepped once per batch.
        scaler: ``torch.amp.GradScaler`` used for loss scaling.
        device: device each batch is moved to.
        air_index, structure_weight: forwarded to ``model.compute_loss``.
        use_amp: enable CUDA autocast for the forward pass.

    Returns:
        Dict with per-batch means of loss, recon, vq, acc, air_acc and
        struct_acc, plus ``codebook_usage``: the fraction of the quantizer's
        codes selected at least once during the epoch.
    """
    model.train()
    metrics = {"loss": 0.0, "recon": 0.0, "vq": 0.0, "acc": 0.0, "air_acc": 0.0, "struct_acc": 0.0}
    # Use the quantizer's own size so usage stays correct if it differs from the global config.
    num_codes = model.quantizer.num_embeddings
    code_counts = torch.zeros(num_codes, dtype=torch.long)
    num_batches = 0

    for batch in tqdm(loader, desc="Training", leave=False):
        batch = batch.to(device)

        with torch.amp.autocast('cuda', enabled=use_amp):
            outputs = model.compute_loss(batch, air_index=air_index, structure_weight=structure_weight)
            loss = outputs["loss"]

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        # Unscale first so the clipping threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        metrics["loss"] += loss.item()
        metrics["recon"] += outputs["reconstruction_loss"].item()
        metrics["vq"] += outputs["vq_loss"].item()
        metrics["acc"] += outputs["accuracy"].item()
        metrics["air_acc"] += outputs["air_accuracy"].item()
        metrics["struct_acc"] += outputs["structure_accuracy"].item()
        # Running histogram instead of keeping every index tensor for the whole epoch.
        code_counts += torch.bincount(outputs["indices"].view(-1).cpu(), minlength=num_codes)
        num_batches += 1

    # Guard against an empty loader (avoids ZeroDivisionError).
    for k in metrics:
        metrics[k] /= max(num_batches, 1)

    metrics["codebook_usage"] = (code_counts > 0).sum().item() / num_codes

    return metrics


@torch.no_grad()
def validate(model, loader, device, air_index, structure_weight, use_amp=True):
    """Evaluate ``model`` on ``loader`` without gradients; return per-batch mean metrics.

    ``structure_weight`` is passed to ``compute_loss`` as in training, so the
    reported loss is on the same weighted scale as the training loss.
    """
    model.eval()
    # metric name in the result -> key in compute_loss's output
    sources = {
        "loss": "loss",
        "recon": "reconstruction_loss",
        "acc": "accuracy",
        "air_acc": "air_accuracy",
        "struct_acc": "structure_accuracy",
    }
    totals = dict.fromkeys(sources, 0)

    for batch in tqdm(loader, desc="Validating", leave=False):
        with torch.amp.autocast('cuda', enabled=use_amp):
            outputs = model.compute_loss(
                batch.to(device), air_index=air_index, structure_weight=structure_weight
            )
        for name, src in sources.items():
            totals[name] += outputs[src].item()

    num_batches = len(loader)
    return {name: total / num_batches for name, total in totals.items()}

print("Training functions defined!")

In [None]:
# ============================================================
# CELL 15: Main Training Loop (UPDATED)
# ============================================================

_rule = "=" * 60
print(_rule, "Starting Training (v2 with EMA + weighted loss)", _rule, sep="\n")

# Metric names recorded per epoch (history keys are "train_<name>" / "val_<name>").
_TRAIN_KEYS = ("loss", "recon", "vq", "acc", "air_acc", "struct_acc")
_VAL_KEYS = ("loss", "recon", "acc", "air_acc", "struct_acc")

history = {f"train_{name}": [] for name in _TRAIN_KEYS}
history.update({f"val_{name}": [] for name in _VAL_KEYS})
history.update(codebook_usage=[], lr=[])

best_val_loss = float("inf")
start_time = time.time()

for epoch in range(EPOCHS):
    epoch_start = time.time()

    train_metrics = train_epoch(
        model, train_loader, optimizer, scaler, device,
        air_index=AIR_INDEX, structure_weight=STRUCTURE_WEIGHT, use_amp=USE_AMP
    )
    val_metrics = validate(
        model, val_loader, device,
        air_index=AIR_INDEX, structure_weight=STRUCTURE_WEIGHT, use_amp=USE_AMP
    )

    # Plateau scheduler keys off the (weighted) validation loss.
    scheduler.step(val_metrics["loss"])
    current_lr = optimizer.param_groups[0]["lr"]

    for name in _TRAIN_KEYS:
        history[f"train_{name}"].append(train_metrics[name])
    for name in _VAL_KEYS:
        history[f"val_{name}"].append(val_metrics[name])
    history["codebook_usage"].append(train_metrics["codebook_usage"])
    history["lr"].append(current_lr)

    # Checkpoint whenever validation loss improves.
    is_best = val_metrics["loss"] < best_val_loss
    if is_best:
        best_val_loss = val_metrics["loss"]
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_best.pt")

    epoch_time = time.time() - epoch_start
    print(
        f"Epoch {epoch+1:3d}/{EPOCHS} | "
        f"Loss: {train_metrics['loss']:.3f} | "
        f"Struct: {train_metrics['struct_acc']:.1%} | "
        f"Val: {val_metrics['struct_acc']:.1%} | "
        f"CB: {train_metrics['codebook_usage']:.1%} | "
        f"{epoch_time:.0f}s"
    )

total_time = time.time() - start_time
print(f"\nTraining complete in {total_time/60:.1f} minutes")

In [None]:
# ============================================================
# CELL 16: Save Results
# ============================================================

# Weights after the last epoch (the best-validation checkpoint is written during training).
torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_final.pt")

# Export the EMA codebook buffer, shape (num_codebook_entries, latent_dim), as NumPy.
codebook = model.quantizer.codebook.cpu().numpy()
np.save(f"{OUTPUT_DIR}/codebook.npy", codebook)

# Per-epoch metric curves.
with open(f"{OUTPUT_DIR}/training_history.json", "w") as f:
    json.dump(history, f, indent=2)

print("Results saved!")

In [None]:
# ============================================================
# CELL 17: Plot Training Curves (UPDATED)
# ============================================================

def _label_panel(ax, xlabel, ylabel, title):
    """Set the axis labels and the title of one training-curve panel."""
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)


fig, axes = plt.subplots(2, 2, figsize=(14, 10))
loss_ax, acc_ax, parts_ax, usage_ax = axes.flat

# Total loss: train vs validation
loss_ax.plot(history["train_loss"], label="Train")
loss_ax.plot(history["val_loss"], label="Val")
_label_panel(loss_ax, "Epoch", "Total Loss", "Training and Validation Loss")

# Structure accuracy is the headline metric; train air accuracy is shown for contrast
acc_ax.plot(history["train_struct_acc"], label="Train Structure")
acc_ax.plot(history["val_struct_acc"], label="Val Structure")
acc_ax.plot(history["train_air_acc"], label="Train Air", linestyle="--", alpha=0.5)
_label_panel(acc_ax, "Epoch", "Accuracy", "Block Prediction Accuracy (Structure vs Air)")

# Components of the training loss
parts_ax.plot(history["train_recon"], label="Reconstruction")
parts_ax.plot(history["train_vq"], label="VQ (commitment)")
_label_panel(parts_ax, "Epoch", "Loss", "Loss Components")

# Fraction of codebook entries used per epoch
usage_ax.plot(history["codebook_usage"], color="green")
usage_ax.axhline(y=1.0, color="red", linestyle="--", label="100% usage")
_label_panel(usage_ax, "Epoch", "Fraction Used", "Codebook Utilization")
usage_ax.set_ylim(0, 1.1)

# Legends are added after every line is drawn so each panel lists all of its series.
for panel in axes.flat:
    panel.legend()
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/training_curves.png", dpi=150)
plt.show()

print(f"\nFinal Structure Accuracy: Train {history['train_struct_acc'][-1]:.1%}, Val {history['val_struct_acc'][-1]:.1%}")
print(f"Codebook Usage: {history['codebook_usage'][-1]:.1%}")

In [None]:
# ============================================================
# CELL 18: Visualize Reconstructions
# ============================================================

def visualize_reconstruction(model, dataset, device, idx=0, slice_idx=None):
    """Plot orthogonal slices of one structure and its reconstruction.

    Args:
        model: model whose forward returns a dict with "logits" (classes on dim 1).
        dataset: indexable dataset returning a 3-D tensor of block ids.
        device: device the model runs on.
        idx: dataset index to visualize.
        slice_idx: slice index used along every axis. ``None`` (default) uses
            the centre of each axis, which is 16 for 32^3 volumes.

    Returns:
        Fraction of voxels (air included) whose predicted block id matches.
    """
    model.eval()
    volume = dataset[idx].unsqueeze(0).to(device)

    with torch.no_grad():
        predicted = model(volume)["logits"].argmax(dim=1)

    original = volume.cpu().numpy()[0]
    reconstructed = predicted.cpu().numpy()[0]

    # One colour scale for both rows: imshow otherwise normalizes each image to
    # its own min/max, so one block id could get different colours in the two rows.
    vmin = int(min(original.min(), reconstructed.min()))
    vmax = int(max(original.max(), reconstructed.max()))

    if slice_idx is None:
        centres = [dim // 2 for dim in original.shape]
    else:
        centres = [slice_idx] * original.ndim

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    rows = ((axes[0], original, "Original"), (axes[1], reconstructed, "Reconstructed"))
    for ax_row, data, label in rows:
        for axis, (ax, axis_name) in enumerate(zip(ax_row, "XYZ")):
            # np.take(data, s, axis=a) is data[s,:,:] / data[:,s,:] / data[:,:,s]
            ax.imshow(
                np.take(data, centres[axis], axis=axis),
                cmap='tab20', vmin=vmin, vmax=vmax, interpolation='nearest',
            )
            ax.set_title(f'{label} ({axis_name} slice {centres[axis]})')
            ax.axis('off')

    accuracy = (original == reconstructed).mean()
    plt.suptitle(f'Reconstruction Accuracy: {accuracy:.1%}', fontsize=14)
    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/reconstruction_{idx}.png", dpi=150)
    plt.show()
    return accuracy

for i in range(3):
    acc = visualize_reconstruction(model, val_dataset, device, idx=i)
    print(f"Sample {i}: {acc:.1%}")

In [None]:
# ============================================================
# CELL 19: Analyze Codebook
# ============================================================

@torch.no_grad()
def analyze_codebook(model, loader, device):
    """Return the normalized frequency of each codebook entry over ``loader``.

    Runs the model in eval mode (no EMA updates) and histograms the selected
    code indices; the result is a NumPy array of length NUM_CODEBOOK_ENTRIES
    that sums to 1.
    """
    model.eval()
    flat_codes = torch.cat([
        model(batch.to(device))["indices"].cpu().view(-1)
        for batch in tqdm(loader, desc="Analyzing")
    ])
    counts = torch.bincount(flat_codes, minlength=NUM_CODEBOOK_ENTRIES)
    frequencies = counts.float() / counts.sum()
    return frequencies.numpy()

codebook_usage = analyze_codebook(model, val_loader, device)

# Frequencies ordered from most to least used code.
usage_sorted = sorted(codebook_usage, reverse=True)
used_codes = (codebook_usage > 0).sum()
top10 = sum(usage_sorted[:10])

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
dist_ax, text_ax = axes

# Left: sorted usage histogram on a log scale
dist_ax.bar(range(NUM_CODEBOOK_ENTRIES), usage_sorted)
dist_ax.set_xlabel("Codebook Entry (sorted)")
dist_ax.set_ylabel("Usage Frequency")
dist_ax.set_title("Codebook Usage Distribution")
dist_ax.set_yscale("log")

# Right: summary statistics as text
stats = f"""Codebook Statistics:

Total codes: {NUM_CODEBOOK_ENTRIES}
Used codes: {used_codes} ({used_codes/NUM_CODEBOOK_ENTRIES:.1%})
Dead codes: {NUM_CODEBOOK_ENTRIES - used_codes}

Top 10 codes: {top10:.1%}
Max usage: {max(codebook_usage):.3%}
"""
text_ax.text(
    0.1, 0.5, stats,
    transform=text_ax.transAxes,
    fontsize=12,
    verticalalignment='center',
    fontfamily='monospace',
)
text_ax.axis('off')

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/codebook_analysis.png", dpi=150)
plt.show()

In [None]:
# ============================================================
# CELL 20: Summary
# ============================================================

_summary_rule = "=" * 60
_summary_lines = [
    _summary_rule,
    "VQ-VAE v2 TRAINING COMPLETE!",
    _summary_rule,
    "\nKey improvements:",
    f"  - EMA codebook updates (decay={EMA_DECAY})",
    "  - Dead code reset",
    f"  - Weighted loss (structure={STRUCTURE_WEIGHT}x)",
    "\nResults:",
    f"  - Structure accuracy: {history['val_struct_acc'][-1]:.1%}",
    f"  - Codebook usage: {history['codebook_usage'][-1]:.1%}",
    "\nOutput files: vqvae_best.pt, vqvae_final.pt, codebook.npy",
]
print("\n".join(_summary_lines))