# VQ-VAE: Learning to Compress Minecraft Structures

## What This Notebook Does

This notebook trains a **Vector Quantized Variational AutoEncoder (VQ-VAE)** to learn a compressed representation of Minecraft structures. After training, the model can:

1. **Encode** any 32×32×32 structure into a small grid of discrete codes
2. **Decode** those codes back into a full structure

This compressed representation will later be used for text-to-structure generation.

## The Pipeline So Far

1. **Phase 1 (Done)**: Prepare training data - 4,462 Minecraft builds as 3D arrays
2. **Phase 2 (Done)**: Train Block2Vec - learned 32-dim embeddings for each block type
3. **Phase 3 (This Notebook)**: Train VQ-VAE - learn to compress/decompress structures
4. **Phase 4**: Connect text descriptions to the VQ-VAE
5. **Phase 5**: Generate new structures from text!

---

# Part 1: Understanding AutoEncoders

## What is an AutoEncoder?

An **autoencoder** is a neural network that learns to compress and reconstruct data:

```
Input → [Encoder] → Latent Code → [Decoder] → Reconstructed Input
  ↑                                                    ↓
  └──────────── Should be as similar as possible ──────┘
```

For Minecraft structures:
- **Input**: 32×32×32 blocks (32,768 positions)
- **Latent Code**: 4×4×4 = 64 vectors (much smaller!)
- **Output**: Reconstructed 32×32×32 structure
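
To make the encoder → latent → decoder loop concrete, here is a toy PyTorch autoencoder on flat vectors. It is only an illustration: the `TinyAutoEncoder` name, the layer sizes and the random data are made up for this sketch, and it is far simpler than the 3D model built later. It still has the same ingredients: a narrow latent code and a loss that compares the reconstruction with the input.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyAutoEncoder(nn.Module):
    """Hypothetical toy model: 64 input features -> 4-dim latent -> 64 outputs."""

    def __init__(self, in_dim: int = 64, latent_dim: int = 4):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(in_dim, 16), nn.ReLU(), nn.Linear(16, latent_dim))
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 16), nn.ReLU(), nn.Linear(16, in_dim))

    def forward(self, x: torch.Tensor):
        z = self.encoder(x)       # compress:    [B, 64] -> [B, 4]
        x_hat = self.decoder(z)   # reconstruct: [B, 4] -> [B, 64]
        return z, x_hat


torch.manual_seed(0)
toy_model = TinyAutoEncoder()
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-2)
x = torch.randn(32, 64)  # 32 random stand-in "inputs"

losses = []
for step in range(200):
    z, x_hat = toy_model(x)
    loss = F.mse_loss(x_hat, x)   # the reconstruction should match the input
    toy_optimizer.zero_grad()
    loss.backward()
    toy_optimizer.step()
    losses.append(loss.item())

print(z.shape, x_hat.shape)  # torch.Size([32, 4]) torch.Size([32, 64])
print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

On purely random data the loss cannot reach zero, because a 4-number code cannot carry 64 independent values; real structures have much more regularity for the bottleneck to exploit.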

## Why Compress?

1. **Efficient Generation**: Generating 64 codes is much easier than 32,768 blocks
2. **Learning Patterns**: The encoder learns to recognize building patterns
3. **Denoising**: Details too rare to be worth a slot in the bottleneck tend to be dropped, so reconstructions smooth out noise and one-off inconsistencies

## The Bottleneck Forces Learning

The latent code is MUCH smaller than the input. This forces the encoder to:
- Identify what's important
- Discard irrelevant details
- Learn efficient representations

If we just copied the input, we wouldn't learn anything useful!

---

# Part 2: The "VQ" in VQ-VAE - Vector Quantization

## Continuous vs Discrete

A regular autoencoder has **continuous** latent codes - any real number is allowed.

VQ-VAE uses **discrete** codes - each position must pick from a fixed set of options.

Think of it like this:
- **Continuous**: Any shade of color (infinite options)
- **Discrete**: Pick from a palette of 512 specific colors

## The Codebook

VQ-VAE maintains a **codebook** - a learned lookup table of vectors:

```
Codebook (512 vectors, each 256-dimensional):
┌─────────────────────────────────────┐
│ Code 0:   [0.2, -0.1, 0.5, ...]     │ → Maybe represents "stone walls"
│ Code 1:   [0.8,  0.3, -0.2, ...]    │ → Maybe represents "wooden floors"
│ Code 2:   [-0.4, 0.7, 0.1, ...]     │ → Maybe represents "glass windows"
│ ...                                  │
│ Code 511: [0.1, 0.0, 0.3, ...]      │ → Maybe represents "empty air"
└─────────────────────────────────────┘
```

## How Quantization Works

1. **Encoder outputs** a continuous vector at each position
2. **Find nearest** codebook entry for each position
3. **Replace** encoder output with the codebook entry
4. **Decoder** receives the codebook vectors

```
Encoder output: [0.21, -0.08, 0.52, ...]  (continuous)
                        ↓
              Find nearest in codebook
                        ↓
Quantized:      [0.2, -0.1, 0.5, ...]     (Code 0)
                        ↓
                    Decoder
```
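
The four steps above fit in a few lines of PyTorch. This sketch uses a made-up 3-entry, 3-dimensional codebook (the real one is 512 × 256) and `torch.cdist` for the distances; the notebook's `VectorQuantizer` computes the same nearest entry using the expanded form ‖z‖² + ‖c‖² − 2z·c.

```python
import torch

# Hypothetical tiny codebook: 3 entries, 3 dims each (the notebook uses 512 x 256)
codebook = torch.tensor([
    [0.2, -0.1, 0.5],   # code 0
    [0.8,  0.3, -0.2],  # code 1
    [-0.4, 0.7, 0.1],   # code 2
])

# Step 1: continuous encoder outputs at two positions
z_e = torch.tensor([
    [0.21, -0.08, 0.52],  # close to code 0
    [-0.35, 0.60, 0.05],  # close to code 2
])

distances = torch.cdist(z_e, codebook)   # [2 positions, 3 codes], Euclidean distances
indices = distances.argmin(dim=1)        # step 2: nearest codebook entry per position
z_q = codebook[indices]                  # step 3: replace with the codebook vectors

print(indices.tolist())  # [0, 2]
print(z_q)               # step 4: this is what the decoder receives
```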

## Why Discrete is Better for Generation

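1. **Finite choices**: With 512 codes at 64 positions, there are 512^64 = 2^576 possible latent grids. That is far too many to enumerate, but a generator now makes 64 categorical choices instead of 32,768.
2. **Less blurry outputs**: Each position commits to a single code, so the decoder is not asked to decode an average of several plausible latents.
3. **Simple sampling interface**: A structure is just a choice of one code (0-511) at each of the 64 positions. Picking codes uniformly at random is unlikely to decode to a coherent build, which is why a learned prior (the later phases) chooses them.
4. **Works with language models**: Text models naturally output discrete tokens, and a grid of code indices is exactly that.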
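
To make the last point concrete: a 4×4×4 grid of code indices can be flattened into a 64-token sequence for a sequence model and folded back into a grid for the decoder. The raster order used here (`reshape`, last axis varying fastest) is an assumption for illustration; any fixed order works as long as encoding and decoding agree. The random `grid` is a stand-in for real quantizer output.

```python
import torch

batch, side, num_codes = 2, 4, 512

# Stand-in for the quantizer's `indices` output: [B, 4, 4, 4] integers in [0, 512)
torch.manual_seed(0)
grid = torch.randint(0, num_codes, (batch, side, side, side))

tokens = grid.reshape(batch, side ** 3)             # [B, 64]: one token per latent position
restored = tokens.reshape(batch, side, side, side)  # fold back to [B, 4, 4, 4] for the decoder

# Size of the space: 512^64 = 2^(64 * 9) grids; (n - 1).bit_length() = bits to index n items
bits_per_grid = side ** 3 * (num_codes - 1).bit_length()

print(tokens.shape)                 # torch.Size([2, 64])
print(torch.equal(grid, restored))  # True
print(bits_per_grid)                # 576
```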

---

# Part 3: The Straight-Through Estimator

## The Problem: Argmin is Not Differentiable

Finding the nearest codebook entry uses `argmin` (which index has minimum distance?)

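But `argmin` returns an integer index. It is piecewise constant, so its gradient is **zero almost everywhere** (and undefined where the nearest entry switches); in PyTorch it carries no gradient at all. Neural networks learn through gradients.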

```
Forward:  Encoder output → argmin → Codebook entry → Decoder
Backward: How should encoder change? No gradient flows through argmin!
```
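
You can see the break directly in PyTorch: the indices returned by `argmin` are an integer tensor with no autograd history, so a loss computed from the looked-up vectors reaches the codebook but never the encoder output. The tensor sizes below are arbitrary stand-ins.

```python
import torch

torch.manual_seed(0)
z_e = torch.randn(5, 8, requires_grad=True)        # stand-in encoder output: 5 positions
codebook = torch.randn(16, 8, requires_grad=True)  # stand-in codebook: 16 entries

distances = torch.cdist(z_e, codebook)  # still differentiable w.r.t. z_e
indices = distances.argmin(dim=1)       # integer indices: the chain breaks here
z_q = codebook[indices]                 # depends on z_e only through the indices

z_q.sum().backward()                    # any loss computed from z_q

print(distances.requires_grad)              # True
print(indices.dtype, indices.requires_grad) # torch.int64 False
print(codebook.grad is not None)            # True: the codebook gets a gradient
print(z_e.grad)                             # None: the encoder output gets nothing
```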

## The Solution: Copy the Gradients

We use a clever trick called the **straight-through estimator**:

- **Forward pass**: Use the quantized codebook vector
- **Backward pass**: Pretend quantization didn't happen, copy gradients directly

In code:
```python
# Forward: z_q is used (quantized)
# Backward: gradients go to z_e (as if z_q = z_e)
z_q = z_e + (z_q - z_e).detach()
```

The `.detach()` stops gradients through `(z_q - z_e)`, so:
- Forward: `z_e + (z_q - z_e) = z_q` ✓
- Backward: Only `z_e` receives gradients ✓
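
A quick numerical check of both claims, using random stand-ins for `z_e` and `z_q` and an arbitrary weighted-sum loss (`weights` is made up for the example): the forward value equals `z_q`, and the gradient that reaches `z_e` is exactly the gradient of the loss with respect to the straight-through output.

```python
import torch

torch.manual_seed(0)
z_e = torch.randn(4, 3, requires_grad=True)  # stand-in encoder output
z_q = torch.randn(4, 3)                      # stand-in quantized vectors
weights = torch.randn(4, 3)                  # arbitrary loss: sum(weights * output)

z_st = z_e + (z_q - z_e).detach()            # straight-through estimator

loss = (weights * z_st).sum()
loss.backward()

print(torch.allclose(z_st, z_q))          # True: the forward pass sees z_q
print(torch.allclose(z_e.grad, weights))  # True: dL/dz_st is copied onto z_e unchanged
```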

## Why This Works

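The straight-through estimator treats the gradient with respect to `z_q` as a stand-in for the gradient with respect to `z_e`. That approximation is only reasonable while `z_e` stays close to its codebook entry, which is what the commitment loss (Part 4) enforces.
The decoder provides the feedback: "If you had output a slightly different vector, the reconstruction would have been better."
The encoder adjusts, the codebook loss pulls codebook entries along with it, and over time encoder outputs cluster around useful codebook entries.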

---

# Part 4: The Three VQ-VAE Losses

## 1. Reconstruction Loss

**Goal**: The output should match the input.

For each position, we predict which block should be there (out of 3,717 options).
This is **cross-entropy loss** - the standard loss for classification.

```
True block: oak_planks (token 153)
Model predicts: [0.01, 0.02, ..., 0.85 (token 153), ..., 0.01]
Loss = -log(0.85) = 0.16  (low loss = good!)

If model predicted 0.01 for token 153:
Loss = -log(0.01) = 4.6  (high loss = bad!)
```
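
The numbers above can be reproduced with `F.cross_entropy`, which takes raw logits and applies log-softmax internally. Here the logits are chosen as `log(p)` of made-up 4-class distributions, so the softmax gives back exactly those probabilities; the real model has 3,717 classes.

```python
import torch
import torch.nn.functional as F

target = torch.tensor([1])  # the true block is class 1

# Confident, correct prediction: p(class 1) = 0.85
good_probs = torch.tensor([[0.05, 0.85, 0.05, 0.05]])
good_loss = F.cross_entropy(good_probs.log(), target)

# Wrong prediction: p(class 1) = 0.01
bad_probs = torch.tensor([[0.97, 0.01, 0.01, 0.01]])
bad_loss = F.cross_entropy(bad_probs.log(), target)

print(f"{good_loss.item():.2f}")  # 0.16
print(f"{bad_loss.item():.2f}")   # 4.61
```

For the full model the same function also accepts class-first logits of shape `[B, 3717, 32, 32, 32]` with integer targets of shape `[B, 32, 32, 32]`, averaging the loss over every position.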

## 2. Codebook Loss

**Goal**: Move codebook vectors toward encoder outputs.

```python
codebook_loss = MSE(codebook_vectors, encoder_output.detach())
```

The `.detach()` means: "Update the codebook, but don't change the encoder."

This makes codebook vectors migrate toward where the encoder is pointing.

## 3. Commitment Loss

**Goal**: Keep encoder outputs close to codebook vectors.

```python
commitment_loss = MSE(encoder_output, codebook_vectors.detach())
```

This prevents the encoder from wandering too far from any codebook entry.

## Total Loss

```
total_loss = reconstruction_loss + codebook_loss + β × commitment_loss
```

Where β (beta) is typically 0.25. The commitment loss is weighted lower because we want the codebook to "follow" the encoder more than vice versa.
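
A small sketch of the two VQ terms, checking where each one sends its gradient. The β = 0.25 matches the text; the tensor sizes and random values are made up. With `F.mse_loss` both terms have the same value in the forward pass; they differ only in which side `.detach()` shields from the gradient.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
beta = 0.25                                        # commitment weight from the text
z_e = torch.randn(6, 8, requires_grad=True)        # stand-in encoder outputs
codebook = torch.randn(16, 8, requires_grad=True)  # stand-in codebook (16 x 8)

with torch.no_grad():
    indices = torch.cdist(z_e, codebook).argmin(dim=1)
z_q = codebook[indices]                            # differentiable w.r.t. the codebook

codebook_loss = F.mse_loss(z_q, z_e.detach())      # moves codebook toward encoder
commitment_loss = F.mse_loss(z_e, z_q.detach())    # keeps encoder near codebook

# Same forward value; only the gradient routing differs
print(torch.allclose(codebook_loss, commitment_loss))  # True

codebook_loss.backward()
print(codebook.grad is not None, z_e.grad is None)     # True True

codebook.grad = None
commitment_loss.backward()
print(z_e.grad is not None, codebook.grad is None)     # True True

vq_loss = codebook_loss.item() + beta * commitment_loss.item()
print(f"vq_loss = {vq_loss:.4f}")
```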

---

# Part 5: Architecture for Minecraft

## The Full Pipeline

```
Input: 32×32×32 block IDs
        ↓
Block Embedding: Convert each ID to 32-dim vector (using Block2Vec!)
        ↓
Shape: 32×32×32×32 (spatial × embedding_dim)
        ↓
Encoder: 3D Convolutions that downsample
  - 32→16→8→4 in each spatial dimension
  - 32→64→128→256 channels
        ↓
Shape: 4×4×4×256
        ↓
Vector Quantizer: Each of 64 positions picks from 512 codebook entries
        ↓
Shape: 4×4×4×256 (same shape, but every vector is now one of the 512 codebook entries)
        ↓
Decoder: 3D Transposed Convolutions that upsample
  - 4→8→16→32 in each spatial dimension
  - 256→256→128→64 channels, then a final conv to 3717 (one logit per block type)
        ↓
Output: 32×32×32×3717 (logits for each block type at each position)
```
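
The spatial sizes in this diagram follow from the layer settings used in the code below: a `Conv3d` with kernel 4, stride 2, padding 1 halves each dimension ((n + 2 − 4) / 2 + 1 = n/2), and a `ConvTranspose3d` with the same settings doubles it ((n − 1)·2 − 2 + 4 = 2n). A quick shape check, using 4 channels instead of the real widths to keep it cheap:

```python
import torch
import torch.nn as nn

x = torch.zeros(1, 4, 32, 32, 32)  # [batch, channels, D, H, W]; 4 channels for the demo

down = nn.Conv3d(4, 4, kernel_size=4, stride=2, padding=1)
up = nn.ConvTranspose3d(4, 4, kernel_size=4, stride=2, padding=1)

sizes = [x.shape[-1]]
with torch.no_grad():
    for _ in range(3):   # encoder: 32 -> 16 -> 8 -> 4
        x = down(x)
        sizes.append(x.shape[-1])
    for _ in range(3):   # decoder: 4 -> 8 -> 16 -> 32
        x = up(x)
        sizes.append(x.shape[-1])

print(sizes)  # [32, 16, 8, 4, 8, 16, 32]
```

The final logits tensor is large: at batch size 8 it holds 8 × 3717 × 32³ ≈ 9.7 × 10⁸ values, roughly 2 GB in fp16, which is likely why the configuration keeps the batch size small.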

## Compression Ratio

- **Input**: 32×32×32 = 32,768 positions, each needs log2(3717) ≈ 12 bits
- **Latent**: 4×4×4 = 64 positions, each needs log2(512) = 9 bits
- **Compression**: 32,768 × 12 / (64 × 9) ≈ **680:1**

That's a massive compression! The model must learn efficient patterns.
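
The arithmetic above, spelled out with the per-item cost rounded up to whole bits as the text does. `(n - 1).bit_length()` is the number of bits needed to index `n` items (ceil(log2 n) for n ≥ 2), computed exactly in integers.

```python
vocab_size, num_codes = 3717, 512
input_positions = 32 ** 3   # 32,768 block positions
latent_positions = 4 ** 3   # 64 code positions

bits_per_block = (vocab_size - 1).bit_length()  # log2(3717) ~ 11.86 -> 12 bits
bits_per_code = (num_codes - 1).bit_length()    # log2(512) = 9 bits

input_bits = input_positions * bits_per_block   # 393,216 bits
latent_bits = latent_positions * bits_per_code  # 576 bits
ratio = input_bits / latent_bits

print(f"{input_bits:,} / {latent_bits:,} = {ratio:.1f}:1")  # 393,216 / 576 = 682.7:1
```

This counts only the discrete indices, which is what a generator in the later phases has to produce; it says nothing about how faithfully the decoder can rebuild the structure from them.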

---

# Part 6: Let's Start Coding!

Now that you understand the concepts, let's implement it.

In [None]:
# ============================================================
# CELL 1: Imports and Setup
# ============================================================

import json
import random
import time
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

# Check if GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# ============================================================
# CELL 2: Configuration
# ============================================================

# === Data Paths ===
DATA_DIR = "/kaggle/input/minecraft-schematics/minecraft_splits/splits/train"
VAL_DIR = "/kaggle/input/minecraft-schematics/minecraft_splits/splits/val"
VOCAB_PATH = "/kaggle/input/minecraft-schematics/tok2block.json"  # At root level!
EMBEDDINGS_PATH = "/kaggle/input/block2vec-embeddings/block_embeddings.npy"  # From Phase 2!
OUTPUT_DIR = "/kaggle/working"

# === Model Architecture ===
BLOCK_EMBEDDING_DIM = 32    # Must match Block2Vec embeddings
HIDDEN_DIMS = [64, 128, 256]  # Encoder channel progression
LATENT_DIM = 256            # Codebook vector dimension
NUM_CODEBOOK_ENTRIES = 512  # Number of codes in codebook
COMMITMENT_COST = 0.25      # Beta for commitment loss

# === Training Hyperparameters ===
EPOCHS = 25
BATCH_SIZE = 8              # Reduced for memory (using mixed precision)
LEARNING_RATE = 1e-4        # Lower LR for stable training
WEIGHT_DECAY = 1e-5
USE_AMP = (device == "cuda")  # Automatic Mixed Precision (fp16): saves memory, GPU only

# === Other ===
SEED = 42
NUM_WORKERS = 2             # Parallel data loading

print("Configuration loaded!")
print(f"  Latent grid: 4×4×4 = 64 positions")
print(f"  Codebook: {NUM_CODEBOOK_ENTRIES} entries × {LATENT_DIM} dims")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Epochs: {EPOCHS}")
print(f"  Mixed Precision (AMP): {USE_AMP}")

In [None]:
# ============================================================
# CELL 3: Load Vocabulary and Pre-trained Embeddings
# ============================================================

# Load vocabulary
with open(VOCAB_PATH, 'r') as f:
    tok2block = {int(k): v for k, v in json.load(f).items()}

VOCAB_SIZE = len(tok2block)
print(f"Vocabulary size: {VOCAB_SIZE} block types")

# Load pre-trained Block2Vec embeddings
pretrained_embeddings = np.load(EMBEDDINGS_PATH)
print(f"Loaded embeddings: {pretrained_embeddings.shape}")

# Verify dimensions match
assert pretrained_embeddings.shape == (VOCAB_SIZE, BLOCK_EMBEDDING_DIM), \
    f"Embedding shape mismatch! Expected ({VOCAB_SIZE}, {BLOCK_EMBEDDING_DIM})"

print("\nEmbeddings will be used to convert block IDs → 32-dim vectors")

---

# Part 7: The Dataset

For VQ-VAE, we simply load complete 32×32×32 structures.
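
The dataset below can optionally augment each structure with a random quarter-turn in the plane of axes 0 and 2 (`np.rot90(..., axes=(0, 2))`) plus a mirror flip along axis 2. Treating axis 1 as the vertical axis is an assumption about how the H5 arrays are laid out. A quick NumPy check, on a small random stand-in array, that these transforms only rearrange blocks, keep each horizontal layer intact, and that four quarter-turns give back the original:

```python
import numpy as np

rng = np.random.default_rng(0)
structure = rng.integers(0, 3717, size=(8, 8, 8))  # small stand-in for a 32^3 build

rotated = np.rot90(structure, k=1, axes=(0, 2))    # quarter-turn about axis 1
flipped = np.flip(rotated, axis=2)                 # mirror along axis 2

# Same multiset of block IDs: augmentation rearranges blocks, it never invents them
same_blocks = np.array_equal(np.sort(flipped, axis=None), np.sort(structure, axis=None))

# Four quarter-turns are the identity
full_turn = np.rot90(structure, k=4, axes=(0, 2))

# Axis 1 (assumed vertical) is untouched: each layer keeps its own blocks
layer_ok = all(
    np.array_equal(np.sort(flipped[:, y, :], axis=None), np.sort(structure[:, y, :], axis=None))
    for y in range(structure.shape[1])
)

print(same_blocks, np.array_equal(full_turn, structure), layer_ok)  # True True True
```

One caveat worth checking: if block tokens encode orientation (for example, which way a stair faces), rotating or flipping the array moves those blocks without updating their facing.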

In [None]:
# ============================================================
# CELL 4: Dataset Class
# ============================================================

class VQVAEDataset(Dataset):
    """
    Dataset that loads complete 32×32×32 Minecraft structures.
    
    Unlike Block2VecDataset (which yielded individual block pairs),
    this returns whole structures for reconstruction learning.
    """
    
    def __init__(
        self,
        data_dir: str,
        augment: bool = False,
        seed: int = 42,
    ):
        self.data_dir = Path(data_dir)
        self.augment = augment
        self.rng = random.Random(seed)
        
        # Find all H5 files
        self.h5_files = sorted(self.data_dir.glob("*.h5"))
        if not self.h5_files:
            raise ValueError(f"No H5 files found in {data_dir}")
        
        print(f"Found {len(self.h5_files)} structures in {data_dir}")
    
    def __len__(self) -> int:
        return len(self.h5_files)
    
    def __getitem__(self, idx: int) -> torch.Tensor:
        """Load and return a single structure."""
        h5_path = self.h5_files[idx]
        
        with h5py.File(h5_path, 'r') as f:
            key = list(f.keys())[0]
            structure = f[key][:]
        
        structure = structure.astype(np.int64)
        
        # Optional augmentation: random rotations around Y axis
        if self.augment:
            k = self.rng.randint(0, 3)  # 0, 90, 180, or 270 degrees
            if k > 0:
                structure = np.rot90(structure, k=k, axes=(0, 2))
            
            # Random horizontal flip
            if self.rng.random() > 0.5:
                structure = np.flip(structure, axis=2)
            
            structure = np.ascontiguousarray(structure)
        
        return torch.from_numpy(structure).long()


# Create datasets
train_dataset = VQVAEDataset(DATA_DIR, augment=True, seed=SEED)
val_dataset = VQVAEDataset(VAL_DIR, augment=False, seed=SEED)

print(f"\nTrain: {len(train_dataset)} structures")
print(f"Val: {len(val_dataset)} structures")

In [None]:
# ============================================================
# CELL 5: Create DataLoaders
# ============================================================

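def seed_worker(worker_id: int) -> None:
    """Give each DataLoader worker's copy of the dataset its own RNG seed.

    Each worker receives a copy of train_dataset whose `rng` was never advanced
    in the main process, so without this every worker (and, with the default
    non-persistent workers, every epoch) would replay the same augmentation
    sequence. Inside a worker, torch.initial_seed() differs per worker and epoch.
    """
    worker_info = torch.utils.data.get_worker_info()
    worker_info.dataset.rng.seed(torch.initial_seed())


train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=(device == "cuda"),
    worker_init_fn=seed_worker,
)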

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=(device == "cuda"),
)

print(f"Train batches per epoch: {len(train_loader)}")
print(f"Val batches per epoch: {len(val_loader)}")

---

# Part 8: The VQ-VAE Model

Now let's define the neural network. This is more complex than Block2Vec!

In [None]:
# ============================================================
# CELL 6: Vector Quantizer
# ============================================================

class VectorQuantizer(nn.Module):
    """
    Vector Quantization layer with straight-through gradient estimator.
    
    This is the heart of VQ-VAE - it discretizes the latent space.
    
    Args:
        num_embeddings: Number of codebook entries (K=512)
        embedding_dim: Dimension of each codebook vector (D=256)
        commitment_cost: Weight for commitment loss (beta=0.25)
    """
    
    def __init__(
        self,
        num_embeddings: int = 512,
        embedding_dim: int = 256,
        commitment_cost: float = 0.25,
    ):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        
        # The codebook: K vectors of dimension D
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        
        # Initialize with small uniform values
        self.codebook.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)
    
    def forward(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Quantize encoder outputs to nearest codebook entries.
        
        Args:
            z_e: Encoder output [batch, channels, depth, height, width]
                 channels should equal embedding_dim
        
        Returns:
            z_q: Quantized output (same shape as z_e)
            vq_loss: Codebook + commitment loss
            indices: Which codebook entry was selected [batch, D, H, W]
        """
        # z_e shape: [B, C, D, H, W] where C = embedding_dim
        
        # Reshape to [B*D*H*W, C] for distance calculation
        z_e_permuted = z_e.permute(0, 2, 3, 4, 1).contiguous()  # [B, D, H, W, C]
        flat_z_e = z_e_permuted.view(-1, self.embedding_dim)    # [B*D*H*W, C]
        
        # Compute distances using the identity:
        # ||z - c||² = ||z||² + ||c||² - 2*z·c
        z_e_sq = (flat_z_e ** 2).sum(dim=1, keepdim=True)          # [N, 1]
        codebook_sq = (self.codebook.weight ** 2).sum(dim=1).unsqueeze(0)  # [1, K]
        dot_product = torch.mm(flat_z_e, self.codebook.weight.t())  # [N, K]
        distances = z_e_sq + codebook_sq - 2 * dot_product          # [N, K]
        
        # Find nearest codebook entry
        indices = distances.argmin(dim=1)  # [N]
        
        # Look up codebook vectors
        z_q_flat = self.codebook(indices)  # [N, C]
        z_q_permuted = z_q_flat.view(z_e_permuted.shape)  # [B, D, H, W, C]
        
        # === Compute VQ Loss ===
        # Codebook loss: move codebook toward encoder (stop gradient on z_e)
        codebook_loss = F.mse_loss(z_q_permuted, z_e_permuted.detach())
        
        # Commitment loss: keep encoder close to codebook (stop gradient on z_q)
        commitment_loss = F.mse_loss(z_e_permuted, z_q_permuted.detach())
        
        vq_loss = codebook_loss + self.commitment_cost * commitment_loss
        
        # === Straight-Through Estimator ===
        # Forward: use z_q | Backward: gradients flow to z_e
        z_q_st = z_e_permuted + (z_q_permuted - z_e_permuted).detach()
        
        # Permute back to [B, C, D, H, W]
        z_q = z_q_st.permute(0, 4, 1, 2, 3).contiguous()
        
        # Reshape indices to spatial grid
        indices = indices.view(z_e_permuted.shape[:-1])  # [B, D, H, W]
        
        return z_q, vq_loss, indices


print("VectorQuantizer defined!")

In [None]:
# ============================================================
# CELL 7: Residual Block
# ============================================================

class ResidualBlock3D(nn.Module):
    """
    3D Residual block for encoder/decoder.
    
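    Residual connections help gradients flow through deep networks.
    This block uses the pre-activation ordering (BatchNorm → ReLU → conv, twice):
        output = conv2(relu(bn2(conv1(relu(bn1(x)))))) + skip(x)
    where skip is the identity, or a 1×1×1 conv when the channel count changes.
    
    If the residual branch learns to output zeros, the block reduces to the skip
    path, which makes deep stacks of these blocks easier to train.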
    """
    
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(in_channels)
        self.bn2 = nn.BatchNorm3d(out_channels)
        
        # If dimensions change, need a 1x1 conv for the skip connection
        if in_channels != out_channels:
            self.skip = nn.Conv3d(in_channels, out_channels, kernel_size=1)
        else:
            self.skip = nn.Identity()
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = self.skip(x)
        
        out = self.bn1(x)
        out = F.relu(out)
        out = self.conv1(out)
        
        out = self.bn2(out)
        out = F.relu(out)
        out = self.conv2(out)
        
        return out + identity  # The residual connection!


print("ResidualBlock3D defined!")

In [None]:
# ============================================================
# CELL 8: Encoder
# ============================================================

class Encoder(nn.Module):
    """
    3D CNN Encoder that compresses 32×32×32 to 4×4×4.
    
    Uses strided convolutions to downsample by 2× at each layer.
    
    Architecture:
        32×32×32 (32ch) → 16×16×16 (64ch) → 8×8×8 (128ch) → 4×4×4 (256ch)
    """
    
    def __init__(
        self,
        in_channels: int = 32,
        hidden_dims: List[int] = None,
        latent_dim: int = 256,
    ):
        super().__init__()
        
        if hidden_dims is None:
            hidden_dims = [64, 128, 256]
        
        layers = []
        current_channels = in_channels
        
        for hidden_dim in hidden_dims:
            layers.extend([
                # Strided conv to downsample 2×
                nn.Conv3d(current_channels, hidden_dim, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm3d(hidden_dim),
                nn.ReLU(inplace=True),
                # Residual block for more capacity
                ResidualBlock3D(hidden_dim, hidden_dim),
            ])
            current_channels = hidden_dim
        
        # Final projection to latent dimension
        layers.append(nn.Conv3d(current_channels, latent_dim, kernel_size=3, padding=1))
        
        self.encoder = nn.Sequential(*layers)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)


print("Encoder defined!")

In [None]:
# ============================================================
# CELL 9: Decoder
# ============================================================

class Decoder(nn.Module):
    """
    3D CNN Decoder that expands 4×4×4 back to 32×32×32.
    
    Uses transposed convolutions to upsample by 2× at each layer.
    
    Architecture:
        4×4×4 (256ch) → 8×8×8 (256ch) → 16×16×16 (128ch) → 32×32×32 (64ch) → 32×32×32 (vocab_size logits)
    """
    
    def __init__(
        self,
        latent_dim: int = 256,
        hidden_dims: List[int] = None,
        num_blocks: int = 3717,
    ):
        super().__init__()
        
        if hidden_dims is None:
            hidden_dims = [256, 128, 64]
        
        layers = []
        current_channels = latent_dim
        
        for hidden_dim in hidden_dims:
            layers.extend([
                ResidualBlock3D(current_channels, current_channels),
                # Transposed conv to upsample 2×
                nn.ConvTranspose3d(current_channels, hidden_dim, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm3d(hidden_dim),
                nn.ReLU(inplace=True),
            ])
            current_channels = hidden_dim
        
        # Final prediction layer: output logits for each block type
        layers.append(nn.Conv3d(current_channels, num_blocks, kernel_size=3, padding=1))
        
        self.decoder = nn.Sequential(*layers)
    
    def forward(self, z_q: torch.Tensor) -> torch.Tensor:
        return self.decoder(z_q)


print("Decoder defined!")

In [None]:
# ============================================================
# CELL 10: Full VQ-VAE Model
# ============================================================

class VQVAE(nn.Module):
    """
    Vector Quantized Variational AutoEncoder for Minecraft structures.
    
    Full pipeline:
        1. Embed input blocks using pre-trained Block2Vec
        2. Encode to compressed latent grid
        3. Quantize each position to nearest codebook entry
        4. Decode to predict block at each position
    """
    
    def __init__(
        self,
        vocab_size: int,
        block_embedding_dim: int,
        hidden_dims: List[int],
        latent_dim: int,
        num_codebook_entries: int,
        commitment_cost: float,
        pretrained_embeddings: np.ndarray,
    ):
        super().__init__()
        
        self.vocab_size = vocab_size
        self.latent_dim = latent_dim
        self.num_codebook_entries = num_codebook_entries
        
        # Block embedding layer (using pre-trained Block2Vec!)
        self.block_embeddings = nn.Embedding(vocab_size, block_embedding_dim)
        self.block_embeddings.weight.data.copy_(torch.from_numpy(pretrained_embeddings))
        self.block_embeddings.weight.requires_grad = False  # Freeze - already trained!
        
        # Encoder
        self.encoder = Encoder(
            in_channels=block_embedding_dim,
            hidden_dims=hidden_dims,
            latent_dim=latent_dim,
        )
        
        # Vector Quantizer
        self.quantizer = VectorQuantizer(
            num_embeddings=num_codebook_entries,
            embedding_dim=latent_dim,
            commitment_cost=commitment_cost,
        )
        
        # Decoder
        self.decoder = Decoder(
            latent_dim=latent_dim,
            hidden_dims=list(reversed(hidden_dims)),
            num_blocks=vocab_size,
        )
    
    def forward(self, block_ids: torch.Tensor) -> Dict[str, Any]:
        """
        Full forward pass.
        
        Args:
            block_ids: Block token IDs [batch, 32, 32, 32]
        
        Returns:
            Dictionary with logits, vq_loss, indices, etc.
        """
        # Step 1: Embed blocks
        # [B, 32, 32, 32] → [B, 32, 32, 32, 32] (last dim is embedding)
        embedded = self.block_embeddings(block_ids)
        
        # Permute to channel-first layout for Conv3d:
        # [B, D, H, W, C] → [B, C, D, H, W], where D = H = W = 32 and C = 32
        embedded = embedded.permute(0, 4, 1, 2, 3).contiguous()
        
        # Step 2: Encode
        z_e = self.encoder(embedded)  # [B, 256, 4, 4, 4]
        
        # Step 3: Quantize
        z_q, vq_loss, indices = self.quantizer(z_e)
        
        # Step 4: Decode
        logits = self.decoder(z_q)  # [B, vocab_size, 32, 32, 32]
        
        return {
            "logits": logits,
            "vq_loss": vq_loss,
            "indices": indices,
            "z_e": z_e,
            "z_q": z_q,
        }
    
    def compute_loss(
        self,
        block_ids: torch.Tensor,
        ignore_index: int = -100,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute training loss.
        
        Returns:
            Dictionary with loss, reconstruction_loss, vq_loss, accuracy
        """
        outputs = self(block_ids)
        
        # Reconstruction loss (cross-entropy)
        # F.cross_entropy accepts class-first logits [B, vocab_size, 32, 32, 32]
        # with targets [B, 32, 32, 32] directly, so there is no need to permute
        # (and copy) the very large logits tensor.
        logits = outputs["logits"]
        
        reconstruction_loss = F.cross_entropy(
            logits,
            block_ids,
            ignore_index=ignore_index,
        )
        
        # Total loss
        total_loss = reconstruction_loss + outputs["vq_loss"]
        
        # Compute accuracy
        with torch.no_grad():
            predictions = logits.argmax(dim=1)  # [B, 32, 32, 32]
            if ignore_index >= 0:
                mask = block_ids != ignore_index
                correct = (predictions[mask] == block_ids[mask]).float().sum()
                total = mask.sum()
            else:
                correct = (predictions == block_ids).float().sum()
                total = block_ids.numel()
            accuracy = correct / total if total > 0 else torch.tensor(0.0)
        
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "vq_loss": outputs["vq_loss"],
            "accuracy": accuracy,
            "indices": outputs["indices"],
        }


print("VQVAE model defined!")

In [None]:
# ============================================================
# CELL 11: Create Model
# ============================================================

# Set random seeds
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Create model
model = VQVAE(
    vocab_size=VOCAB_SIZE,
    block_embedding_dim=BLOCK_EMBEDDING_DIM,
    hidden_dims=HIDDEN_DIMS,
    latent_dim=LATENT_DIM,
    num_codebook_entries=NUM_CODEBOOK_ENTRIES,
    commitment_cost=COMMITMENT_COST,
    pretrained_embeddings=pretrained_embeddings,
)

model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model on {device}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Frozen (Block2Vec): {total_params - trainable_params:,}")

In [None]:
# ============================================================
# CELL 12: Test Forward Pass
# ============================================================

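# Quick test to make sure everything works.
# eval() stops BatchNorm from updating its running statistics with this random
# batch; train_epoch() switches the model back to train mode.
print("Testing forward pass...")
model.eval()

with torch.no_grad():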
    # Create a dummy batch
    test_batch = torch.randint(0, VOCAB_SIZE, (2, 32, 32, 32)).to(device)
    
    # Forward pass
    outputs = model.compute_loss(test_batch)
    
    print(f"  Input shape: {test_batch.shape}")
    print(f"  Loss: {outputs['loss'].item():.4f}")
    print(f"  Reconstruction loss: {outputs['reconstruction_loss'].item():.4f}")
    print(f"  VQ loss: {outputs['vq_loss'].item():.4f}")
    print(f"  Accuracy: {outputs['accuracy'].item():.4f}")
    print(f"  Indices shape: {outputs['indices'].shape}")

print("\nForward pass successful!")

In [None]:
# ============================================================
# CELL 13: Create Optimizer, Scheduler, and Scaler
# ============================================================

# AdamW optimizer
optimizer = optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),  # Only trainable params
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

# Learning rate scheduler: reduce LR when validation loss plateaus
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,      # Multiply LR by 0.5 when triggered
    patience=5,       # Wait 5 epochs before reducing
)

# GradScaler for mixed precision training
# This scales gradients to prevent underflow in fp16
scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP)

print(f"Optimizer: AdamW (lr={LEARNING_RATE})")
print(f"Scheduler: ReduceLROnPlateau (factor=0.5, patience=5)")
print(f"Mixed Precision: {'Enabled' if USE_AMP else 'Disabled'}")

---

# Part 9: Training Loop

In [None]:
# ============================================================
# CELL 14: Training Loop (with Mixed Precision)
# ============================================================

def train_epoch(model, loader, optimizer, scaler, device, use_amp=True):
    """Train for one epoch with optional mixed precision."""
    model.train()
    total_loss = 0
    total_recon = 0
    total_vq = 0
    total_acc = 0
    num_batches = 0
    
    # Track codebook usage
    all_indices = []
    
    for batch in tqdm(loader, desc="Training", leave=False):
        batch = batch.to(device)
        
        # Forward pass with automatic mixed precision
        with torch.amp.autocast('cuda', enabled=use_amp):
            outputs = model.compute_loss(batch)
            loss = outputs["loss"]
        
        # Backward pass with gradient scaling
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        
        # Unscale gradients for clipping
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        # Optimizer step with scaler
        scaler.step(optimizer)
        scaler.update()
        
        # Track metrics (these are already fp32)
        total_loss += loss.item()
        total_recon += outputs["reconstruction_loss"].item()
        total_vq += outputs["vq_loss"].item()
        total_acc += outputs["accuracy"].item()
        num_batches += 1
        
        # Track codebook usage
        all_indices.append(outputs["indices"].cpu())
    
    # Compute codebook utilization
    all_indices = torch.cat([idx.view(-1) for idx in all_indices])
    unique_codes = len(torch.unique(all_indices))
    
    return {
        "loss": total_loss / num_batches,
        "recon_loss": total_recon / num_batches,
        "vq_loss": total_vq / num_batches,
        "accuracy": total_acc / num_batches,
        "codebook_usage": unique_codes / NUM_CODEBOOK_ENTRIES,
    }


@torch.no_grad()
def validate(model, loader, device, use_amp=True):
    """Validate the model with optional mixed precision."""
    model.eval()
    total_loss = 0
    total_recon = 0
    total_acc = 0
    num_batches = 0
    
    for batch in tqdm(loader, desc="Validating", leave=False):
        batch = batch.to(device)
        
        with torch.amp.autocast('cuda', enabled=use_amp):
            outputs = model.compute_loss(batch)
        
        total_loss += outputs["loss"].item()
        total_recon += outputs["reconstruction_loss"].item()
        total_acc += outputs["accuracy"].item()
        num_batches += 1
    
    return {
        "loss": total_loss / num_batches,
        "recon_loss": total_recon / num_batches,
        "accuracy": total_acc / num_batches,
    }


print("Training functions defined (with mixed precision support)!")

In [None]:
# ============================================================
# CELL 15: Main Training Loop
# ============================================================

print("=" * 60)
print("Starting Training")
print("=" * 60)

# Track metrics
history = {
    "train_loss": [],
    "train_recon": [],
    "train_vq": [],
    "train_acc": [],
    "val_loss": [],
    "val_recon": [],
    "val_acc": [],
    "codebook_usage": [],
    "lr": [],
}

best_val_loss = float("inf")
start_time = time.time()

for epoch in range(EPOCHS):
    epoch_start = time.time()
    
    # Train (mixed precision when USE_AMP is True)
    train_metrics = train_epoch(model, train_loader, optimizer, scaler, device, use_amp=USE_AMP)
    
    # Validate (mixed precision when USE_AMP is True)
    val_metrics = validate(model, val_loader, device, use_amp=USE_AMP)
    
    # Update scheduler
    scheduler.step(val_metrics["loss"])
    current_lr = optimizer.param_groups[0]["lr"]
    
    # Track history
    history["train_loss"].append(train_metrics["loss"])
    history["train_recon"].append(train_metrics["recon_loss"])
    history["train_vq"].append(train_metrics["vq_loss"])
    history["train_acc"].append(train_metrics["accuracy"])
    history["val_loss"].append(val_metrics["loss"])
    history["val_recon"].append(val_metrics["recon_loss"])
    history["val_acc"].append(val_metrics["accuracy"])
    history["codebook_usage"].append(train_metrics["codebook_usage"])
    history["lr"].append(current_lr)
    
    # Save best model
    if val_metrics["loss"] < best_val_loss:
        best_val_loss = val_metrics["loss"]
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_best.pt")
    
    # Print progress
    epoch_time = time.time() - epoch_start
    print(
        f"Epoch {epoch+1:3d}/{EPOCHS} | "
        f"Train: {train_metrics['loss']:.4f} (acc {train_metrics['accuracy']:.3f}) | "
        f"Val: {val_metrics['loss']:.4f} (acc {val_metrics['accuracy']:.3f}) | "
        f"CB: {train_metrics['codebook_usage']:.1%} | "
        f"LR: {current_lr:.2e} | "
        f"{epoch_time:.0f}s"
    )

total_time = time.time() - start_time
print("\n" + "=" * 60)
print(f"Training complete in {total_time/60:.1f} minutes")
print(f"Best validation loss: {best_val_loss:.4f}")
print("=" * 60)
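The loop calls `scheduler.step(val_metrics["loss"])` with a metric. That matches the interface of `torch.optim.lr_scheduler.ReduceLROnPlateau`, but the scheduler is created in an earlier cell not shown here, so treat that as an assumption. The toy sketch below shows how that scheduler reacts once validation loss stops improving. The `factor` and `patience` values are illustrative, not the notebook's settings.

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

# A single dummy parameter is enough to build an optimizer.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1.0)

# Illustrative settings: halve the LR once more than 1 epoch passes without improvement.
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=1)

val_losses = [1.0, 0.8, 0.8, 0.8, 0.8, 0.8]  # improves once, then plateaus
lrs = []
for val_loss in val_losses:
    scheduler.step(val_loss)  # same call pattern as the training loop
    lrs.append(optimizer.param_groups[0]["lr"])
print(lrs)  # [1.0, 1.0, 1.0, 0.5, 0.5, 0.25]
```

The LR drops only after the bad-epoch counter exceeds `patience`. The counter then resets, which is why the second halving comes two epochs after the first.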

In [None]:
# ============================================================
# CELL 16: Save Results
# ============================================================

# Save final model
torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_final.pt")
print(f"Final model saved to {OUTPUT_DIR}/vqvae_final.pt")

# Save codebook
codebook = model.quantizer.codebook.weight.data.cpu().numpy()
np.save(f"{OUTPUT_DIR}/codebook.npy", codebook)
print(f"Codebook saved: {codebook.shape}")

# Save training history
with open(f"{OUTPUT_DIR}/training_history.json", "w") as f:
    json.dump(history, f, indent=2)
print("Training history saved")
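This cell saves both `vqvae_best.pt` and `vqvae_final.pt`. The visualization cells below, however, keep using the in-memory `model`, which holds the final-epoch weights. To plot the best-validation checkpoint instead, reload it first. Here is a sketch of that round trip. It uses a stand-in module because the full VQ-VAE class is defined earlier; `TinyModel` is hypothetical, so substitute the notebook's model instance.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the VQ-VAE; only the state_dict round trip matters."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)


with tempfile.TemporaryDirectory() as output_dir:
    best_path = os.path.join(output_dir, "vqvae_best.pt")

    trained = TinyModel()
    torch.save(trained.state_dict(), best_path)

    # map_location lets a GPU-saved checkpoint load on any device;
    # weights_only=True restricts unpickling to tensors and plain containers.
    restored = TinyModel()
    state = torch.load(best_path, map_location="cpu", weights_only=True)
    restored.load_state_dict(state)
    restored.eval()

    same = all(
        torch.equal(a, b)
        for a, b in zip(trained.state_dict().values(), restored.state_dict().values())
    )
print(same)  # True
```

In the notebook the equivalent would be `model.load_state_dict(torch.load(f"{OUTPUT_DIR}/vqvae_best.pt", map_location=device))`.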

---

# Part 10: Visualizing Results

In [None]:
# ============================================================
# CELL 17: Plot Training Curves
# ============================================================

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Loss
ax = axes[0, 0]
ax.plot(history["train_loss"], label="Train", linewidth=2)
ax.plot(history["val_loss"], label="Val", linewidth=2)
ax.set_xlabel("Epoch")
ax.set_ylabel("Total Loss")
ax.set_title("Training and Validation Loss")
ax.legend()
ax.grid(True, alpha=0.3)

# Accuracy
ax = axes[0, 1]
ax.plot(history["train_acc"], label="Train", linewidth=2)
ax.plot(history["val_acc"], label="Val", linewidth=2)
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy")
ax.set_title("Block Prediction Accuracy")
ax.legend()
ax.grid(True, alpha=0.3)

# Reconstruction vs VQ Loss
ax = axes[1, 0]
ax.plot(history["train_recon"], label="Reconstruction", linewidth=2)
ax.plot(history["train_vq"], label="VQ", linewidth=2)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Loss Components (Training)")
ax.legend()
ax.grid(True, alpha=0.3)

# Codebook Usage
ax = axes[1, 1]
ax.plot(history["codebook_usage"], linewidth=2, color="green")
ax.axhline(y=1.0, color="red", linestyle="--", label="100% usage")
ax.set_xlabel("Epoch")
ax.set_ylabel("Fraction Used")
ax.set_title("Codebook Utilization")
ax.set_ylim(0, 1.1)
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/training_curves.png", dpi=150)
plt.show()

print(f"\nFinal metrics:")
print(f"  Train accuracy: {history['train_acc'][-1]:.3f}")
print(f"  Val accuracy: {history['val_acc'][-1]:.3f}")
print(f"  Codebook usage: {history['codebook_usage'][-1]:.1%}")
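The final-epoch numbers printed above are not necessarily the ones behind `vqvae_best.pt`. That checkpoint was saved at the epoch with the lowest validation loss. The sketch below recovers that epoch from `history`. The sample values are made up purely to exercise the function.

```python
def best_epoch_summary(history):
    """Return the 1-based epoch with the lowest val_loss and its metrics.

    min() keeps the earliest minimum on ties, matching the training loop's
    strict `<` comparison when deciding whether to save a new best model.
    """
    val_losses = history["val_loss"]
    best_idx = min(range(len(val_losses)), key=val_losses.__getitem__)
    return {
        "epoch": best_idx + 1,  # the training loop prints epochs 1-based
        "val_loss": val_losses[best_idx],
        "val_acc": history["val_acc"][best_idx],
        "codebook_usage": history["codebook_usage"][best_idx],
    }


# Made-up history, only to show the call; use the notebook's `history` dict instead.
example_history = {
    "val_loss": [0.90, 0.55, 0.61],
    "val_acc": [0.80, 0.91, 0.90],
    "codebook_usage": [0.40, 0.62, 0.65],
}
print(best_epoch_summary(example_history))
# {'epoch': 2, 'val_loss': 0.55, 'val_acc': 0.91, 'codebook_usage': 0.62}
```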

In [None]:
# ============================================================
# CELL 18: Visualize Reconstructions
# ============================================================

def visualize_reconstruction(model, dataset, device, idx=0):
    """Visualize original vs reconstructed structure."""
    model.eval()
    
    # Get a sample
    original = dataset[idx].unsqueeze(0).to(device)
    
    with torch.no_grad():
        outputs = model(original)
        reconstructed = outputs["logits"].argmax(dim=1)
    
    original = original.cpu().numpy()[0]
    reconstructed = reconstructed.cpu().numpy()[0]
    
    # Compare center slices along each axis. Share one color range across all
    # panels: imshow otherwise rescales each panel to its own min/max, so the same
    # block ID could get different colors in the original and the reconstruction.
    # (tab20 has only 20 colors, so distinct block IDs can still share a color.)
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    slice_idx = 16  # center of a 32-wide grid
    vmin = min(original.min(), reconstructed.min())
    vmax = max(original.max(), reconstructed.max())
    
    slicers = {
        "X": lambda v: v[slice_idx, :, :],
        "Y": lambda v: v[:, slice_idx, :],
        "Z": lambda v: v[:, :, slice_idx],
    }
    rows = [("Original", original), ("Reconstructed", reconstructed)]
    for col, (axis_name, take_slice) in enumerate(slicers.items()):
        for row, (label, volume) in enumerate(rows):
            ax = axes[row, col]
            ax.imshow(take_slice(volume), cmap='tab20', vmin=vmin, vmax=vmax)
            ax.set_title(f'{label} ({axis_name} slice {slice_idx})')
            ax.axis('off')
    
    # Voxel-level accuracy for this sample (every position counts, including empty space)
    accuracy = (original == reconstructed).mean()
    plt.suptitle(f'Reconstruction Accuracy: {accuracy:.1%}', fontsize=14)
    
    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/reconstruction_{idx}.png", dpi=150)
    plt.show()
    
    return accuracy


# Visualize a few samples
print("Visualizing reconstructions...")
for i in range(3):
    acc = visualize_reconstruction(model, val_dataset, device, idx=i)
    print(f"Sample {i}: {acc:.1%} accuracy")
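The per-sample accuracy above counts every voxel. As a result, a structure that is mostly empty space can score high even when the actual blocks are reconstructed poorly. The sketch below computes a complementary metric that only scores positions where the original has a non-air block. `AIR_ID = 0` is an assumption about the Phase 1 block vocabulary, so check it against your own mapping.

```python
import numpy as np

AIR_ID = 0  # assumption: index of air in the block vocabulary


def non_air_accuracy(original, reconstructed, air_id=AIR_ID):
    """Fraction of non-air voxels in `original` that are reproduced exactly."""
    mask = original != air_id
    if not mask.any():
        return float("nan")  # structure is all air: metric undefined
    return float((original[mask] == reconstructed[mask]).mean())


# Toy 4x4x4 example: 60 air voxels, 4 solid voxels, decoder recovers 2 of them.
orig = np.zeros((4, 4, 4), dtype=np.int64)
orig[0, 0, :] = [5, 5, 7, 7]
recon = orig.copy()
recon[0, 0, 2:] = AIR_ID  # the two "7" blocks were reconstructed as air

print((orig == recon).mean())         # 0.96875 (62 of 64 voxels match)
print(non_air_accuracy(orig, recon))  # 0.5 (2 of 4 solid blocks match)
```

Inside `visualize_reconstruction`, `original` and `reconstructed` are already NumPy arrays of block IDs, so the function could be called on them directly.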

In [None]:
# ============================================================
# CELL 19: Analyze Codebook Usage
# ============================================================

@torch.no_grad()
def analyze_codebook(model, loader, device):
    """Analyze how the codebook is being used."""
    model.eval()
    
    all_indices = []
    
    for batch in tqdm(loader, desc="Analyzing codebook"):
        batch = batch.to(device)
        outputs = model(batch)
        all_indices.append(outputs["indices"].cpu().view(-1))
    
    all_indices = torch.cat(all_indices)
    
    # Count usage of each code
    usage = torch.bincount(all_indices, minlength=NUM_CODEBOOK_ENTRIES)
    usage = usage.float() / usage.sum()
    
    return usage.numpy()


# Analyze
codebook_usage = analyze_codebook(model, val_loader, device)

# Plot
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Histogram
ax = axes[0]
ax.bar(range(NUM_CODEBOOK_ENTRIES), sorted(codebook_usage, reverse=True))
ax.set_xlabel("Codebook Entry (sorted by usage)")
ax.set_ylabel("Usage Frequency")
ax.set_title("Codebook Usage Distribution")
ax.set_yscale("log")

# Stats
ax = axes[1]
used_codes = (codebook_usage > 0).sum()
top10_usage = sorted(codebook_usage, reverse=True)[:10]

stats_text = f"""
Codebook Statistics:

Total codes: {NUM_CODEBOOK_ENTRIES}
Used codes: {used_codes} ({used_codes/NUM_CODEBOOK_ENTRIES:.1%})
Dead codes: {NUM_CODEBOOK_ENTRIES - used_codes}

Top 10 codes account for: {sum(top10_usage):.1%}

Max usage: {max(codebook_usage):.3%}
Min usage (non-zero): {min(u for u in codebook_usage if u > 0):.6%}
"""

ax.text(0.1, 0.5, stats_text, transform=ax.transAxes, fontsize=12,
        verticalalignment='center', fontfamily='monospace')
ax.axis('off')

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/codebook_analysis.png", dpi=150)
plt.show()
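The statistics panel reports how many codes are used at all, but a code hit once counts the same as a code hit constantly. A common single-number summary in VQ-VAE work is **codebook perplexity**, the exponential of the entropy of the usage distribution. It equals the number of codes when usage is uniform and 1 when every position maps to the same code. The sketch below works on a normalized usage array like the `codebook_usage` returned by `analyze_codebook`.

```python
import numpy as np


def codebook_perplexity(usage):
    """exp(entropy) of a normalized usage distribution (unused codes contribute 0)."""
    p = np.asarray(usage, dtype=np.float64)
    p = p[p > 0]  # treat 0 * log(0) as 0
    entropy = -(p * np.log(p)).sum()
    return float(np.exp(entropy))


uniform = np.full(8, 1 / 8)                   # all 8 codes used equally
collapsed = np.array([1.0] + [0.0] * 7)       # every position maps to one code
skewed = np.array([0.5, 0.5, 0, 0, 0, 0, 0, 0])  # only 2 codes in use

print(codebook_perplexity(uniform))    # ≈ 8
print(codebook_perplexity(collapsed))  # 1.0
print(codebook_perplexity(skewed))     # ≈ 2
```

Perplexity reads as "the effective number of codes in use", which makes it easier to compare across runs with different codebook sizes than a raw dead-code count.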

---

# Part 11: Summary and Next Steps

## What We Learned

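1. **AutoEncoders** compress and reconstruct data through a bottleneck
2. **Vector Quantization** replaces each latent vector with its nearest codebook entry, so the latent space is discrete (a finite set of options)
3. **The Codebook** is learned alongside the encoder and decoder, so its entries come to represent patterns that recur in the training data
4. **The Straight-Through Estimator** lets gradients reach the encoder despite the non-differentiable nearest-code lookup, by passing the decoder's input gradient directly to the encoder's output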
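Items 2 and 4 can be made concrete in a few lines. The sketch below implements the standard nearest-neighbor quantizer with a straight-through estimator, plus the codebook/commitment loss from the original VQ-VAE formulation. It is not necessarily how this notebook's `quantizer` is written, and `beta = 0.25` is an illustrative value.

```python
import torch
import torch.nn.functional as F


def vector_quantize(z, codebook, beta=0.25):
    """Quantize encoder outputs z (N, D) against a codebook (K, D).

    Returns straight-through quantized vectors, chosen indices, and the VQ loss.
    """
    distances = torch.cdist(z, codebook)  # (N, K) Euclidean distances
    indices = distances.argmin(dim=1)     # nearest code per vector (not differentiable)
    z_q = codebook[indices]               # (N, D) selected code vectors

    # Codebook term pulls codes toward encoder outputs; commitment term keeps
    # encoder outputs close to their chosen codes.
    vq_loss = F.mse_loss(z_q, z.detach()) + beta * F.mse_loss(z, z_q.detach())

    # Straight-through estimator: the forward value is z_q, but the gradient of
    # the output with respect to z is the identity, so decoder gradients reach the encoder.
    z_q_st = z + (z_q - z).detach()
    return z_q_st, indices, vq_loss


codebook = torch.tensor([[0.0, 0.0], [1.0, 1.0]], requires_grad=True)
z = torch.tensor([[0.1, 0.2], [0.9, 0.8]], requires_grad=True)

z_q, indices, vq_loss = vector_quantize(z, codebook)
print(indices.tolist())  # [0, 1]
z_q.sum().backward()     # stand-in for a decoder loss
print(z.grad)            # all ones: the gradient passed straight through
```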

## What's Next?

With the VQ-VAE trained, we can now:

1. **Encode** any structure into 64 discrete codes (4×4×4 grid)
2. **Decode** those codes back to a full structure

In Phase 4, we'll train a model to predict these codes from text descriptions!
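Encoding a structure as 64 codes means a 512:1 reduction in the number of positions (32³ / 4³). Measured in bits, the ratio also depends on how many bits a block ID and a code index each need. The worked sketch below makes that explicit. The vocabulary and codebook sizes are hypothetical placeholders, not the notebook's values, so substitute the Phase 1 block vocabulary size and `NUM_CODEBOOK_ENTRIES`.

```python
import math


def compression_ratios(vocab_size, codebook_size, grid=32, latent=4):
    """Position-count and bit-level compression of a grid**3 structure into latent**3 codes."""
    positions_in = grid ** 3   # 32**3 = 32,768 voxels
    positions_out = latent ** 3  # 4**3 = 64 codes
    bits_per_block = math.ceil(math.log2(vocab_size))
    bits_per_code = math.ceil(math.log2(codebook_size))
    return {
        "positions": positions_in / positions_out,
        "bits": (positions_in * bits_per_block) / (positions_out * bits_per_code),
    }


# Hypothetical sizes: 12-bit block IDs (vocab up to 4,096) and a 512-entry codebook (9 bits).
print(compression_ratios(vocab_size=4096, codebook_size=512))
# {'positions': 512.0, 'bits': 682.6666666666666}
```

Under these particular assumptions the bit-level ratio lands near 683:1. Whether that matches the summary printed below depends on the real vocabulary and codebook sizes. The ratio also ignores the one-time cost of storing the codebook itself.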

In [None]:
# ============================================================
# CELL 20: Final Summary
# ============================================================

print("=" * 60)
print("VQ-VAE TRAINING COMPLETE!")
print("=" * 60)

print(f"\nModel Architecture:")
print(f"  Input: 32×32×32 block structure")
print(f"  Latent: 4×4×4 = 64 discrete codes")
print(f"  Codebook: {NUM_CODEBOOK_ENTRIES} entries × {LATENT_DIM} dims")
print(f"  Compression: ~680:1")

print(f"\nTraining:")
print(f"  Epochs: {EPOCHS}")
print(f"  Final train accuracy: {history['train_acc'][-1]:.3f}")
print(f"  Final val accuracy: {history['val_acc'][-1]:.3f}")
print(f"  Codebook usage: {history['codebook_usage'][-1]:.1%}")

print(f"\nOutput files in {OUTPUT_DIR}:")
print(f"  - vqvae_best.pt (best validation loss)")
print(f"  - vqvae_final.pt (final epoch)")
print(f"  - codebook.npy")
print(f"  - training_history.json")
print(f"  - training_curves.png")
print(f"  - reconstruction_*.png")
print(f"  - codebook_analysis.png")

print("\n" + "=" * 60)