# VQ-VAE v7 Training - U-Net Skip Connections + Volume Penalty

## Key Changes from v6-freq

| Change | v6-freq | v7 |
|--------|---------|-----|
| Skip connections | None | **Dual U-Net (16x16x16 + 32x32x32)** |
| Volume penalty | None | **MSE loss on vol_ratio** |
| Frequency cap | 10x | **5x** |
| Epochs | 25 | **40** |
| LR schedule | Constant | **Cosine annealing** |
| Early stopping | None | **On building_f1** |
| Metrics | Missing precision/F1 | **Complete set** |

## Goals

| Metric | v6-freq | v7 Target |
|--------|---------|----------|
| Building Accuracy | 49.2% | **55-62%** |
| Building Precision | ~58% (est) | **70-80%** |
| Building F1 | ~73% (est) | **75-85%** |
| Volume Ratio | 1.68x | **1.0-1.15x** |
| Rare Block Recall | 94.9% | **85-92%** |

## Setup - Mount Google Drive

In [None]:
# Mount Google Drive at /content/drive; OUTPUT_DIR below lives under it.
from google.colab import drive
drive.mount('/content/drive')

# Ensure the v7 output directory exists (exist_ok=True makes re-running this cell safe).
import os
OUTPUT_DIR = os.path.join('/content/drive/MyDrive/minecraft_ai', 'vqvae_v7')
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Output will be saved to: {OUTPUT_DIR}")

## Cell 1: Imports

In [None]:
import json
import random
import time
from pathlib import Path
from typing import Dict, List, Tuple, Any, Set, Optional
from collections import Counter

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

# Select the compute device: CUDA when available, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

# Report GPU 0's name and total memory (decimal GB) when running on CUDA.
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## Cell 2: Configuration

In [None]:
# === Data Paths (UPDATE THESE FOR YOUR DRIVE) ===
DRIVE_BASE = '/content/drive/MyDrive/minecraft_ai'
VERSION = 'v7'

DATA_DIR = f"{DRIVE_BASE}/splits/train"
VAL_DIR = f"{DRIVE_BASE}/splits/val"
VOCAB_PATH = f"{DRIVE_BASE}/vocabulary/tok2block.json"
V3_EMBEDDINGS_PATH = f"{DRIVE_BASE}/embeddings/block_embeddings_v3.npy"
OUTPUT_DIR = f"{DRIVE_BASE}/vqvae_{VERSION}"

# === Model architecture ===
HIDDEN_DIMS = [96, 192]
RFSQ_LEVELS_PER_STAGE = [5, 5, 5, 5]
NUM_STAGES = 2
DROPOUT = 0.1

# === Loss weighting ===
USE_FREQUENCY_WEIGHTING = True
FREQUENCY_WEIGHT_CAP = 5.0      # v6-freq used 10.0
VOLUME_PENALTY_WEIGHT = 0.1     # new in v7
STRUCTURE_WEIGHT = 50.0
# Per-category CE weights: terrain / building / air.
TERRAIN_WEIGHT, BUILDING_WEIGHT, AIR_WEIGHT = 0.2, 1.0, 0.1

# === Optimization (40 epochs, cosine-annealed LR from BASE_LR down to MIN_LR) ===
TOTAL_EPOCHS = 40
BATCH_SIZE = 4
BASE_LR, MIN_LR = 3e-4, 1e-5
USE_AMP = True
GRAD_ACCUM_STEPS = 4

# === Early stopping (new in v7) ===
EARLY_STOPPING_PATIENCE = 10
EARLY_STOPPING_METRIC = 'building_f1'

SEED = 42
NUM_WORKERS = 2

# Codes per RFSQ stage, and the implicit code space across all stages.
CODES_PER_STAGE = int(np.prod(RFSQ_LEVELS_PER_STAGE))
TOTAL_IMPLICIT_CODES = CODES_PER_STAGE ** NUM_STAGES

print("\n".join([
    "VQ-VAE v7 Configuration:",
    f"  RFSQ: {NUM_STAGES} stages × {CODES_PER_STAGE:,} codes",
    f"  Total implicit codes: {TOTAL_IMPLICIT_CODES:,}",
    "\nKEY CHANGES:",
    "  Skip connections: Dual U-Net (16x16x16 + 32x32x32)",
    f"  Volume penalty weight: {VOLUME_PENALTY_WEIGHT}",
    f"  Frequency cap: {FREQUENCY_WEIGHT_CAP}x (down from 10x)",
    f"  Epochs: {TOTAL_EPOCHS} (up from 25)",
    f"  LR schedule: Cosine annealing ({BASE_LR} -> {MIN_LR})",
    f"  Early stopping: patience={EARLY_STOPPING_PATIENCE} on {EARLY_STOPPING_METRIC}",
]))

## Cell 3: Load Vocabulary and Embeddings

In [None]:
# tok2block.json maps token id (stored as a JSON string key) -> block name.
with open(VOCAB_PATH, 'r') as f:
    tok2block = {int(key): name for key, name in json.load(f).items()}

# Reverse lookup: block name -> token id.
block2tok = {name: key for key, name in tok2block.items()}
VOCAB_SIZE = len(tok2block)
print(f"Vocabulary size: {VOCAB_SIZE}")

# Air tokens: any block whose name contains "air", excluding stairs.
AIR_TOKENS: Set[int] = set()
for tok, block in tok2block.items():
    lowered = block.lower()
    if 'air' not in lowered or 'stair' in lowered:
        continue
    AIR_TOKENS.add(tok)
    print(f"  Air token: {tok} = {block}")

AIR_TOKENS_LIST = sorted(AIR_TOKENS)
AIR_TOKENS_TENSOR = torch.tensor(AIR_TOKENS_LIST, dtype=torch.long)

# Pretrained V3 block embeddings; the second axis is the embedding width.
v3_embeddings = np.load(V3_EMBEDDINGS_PATH).astype(np.float32)
EMBEDDING_DIM = v3_embeddings.shape[1]
print(f"V3 embeddings: {v3_embeddings.shape} (dim={EMBEDDING_DIM})")

## Cell 4: Compute Block Frequencies

In [None]:
print("Computing block frequencies from training data...")
print("(This may take a few minutes)")

block_counts = Counter()
train_files = sorted(Path(DATA_DIR).glob("*.h5"))
if not train_files:
    # Fail with a clear message rather than the opaque
    # "max() arg is an empty sequence" that max_count below would raise.
    raise ValueError(f"No H5 files in {DATA_DIR}")

for h5_file in tqdm(train_files, desc="Scanning training data"):
    with h5py.File(h5_file, 'r') as f:
        key = list(f.keys())[0]
        structure = f[key][:]
    # Count per file in NumPy instead of pushing every voxel through Counter
    # as a Python int. np.unique flattens its input; Counter.update with a
    # mapping adds the counts.
    token_ids, token_counts = np.unique(structure, return_counts=True)
    block_counts.update(dict(zip(token_ids.tolist(), token_counts.tolist())))

total_blocks = sum(block_counts.values())
print(f"\nTotal blocks scanned: {total_blocks:,}")
print(f"Unique block types: {len(block_counts)}")

# Inverse-frequency weight per token: max_count / count, capped at
# FREQUENCY_WEIGHT_CAP. Tokens never seen in training are treated as count 1,
# so they receive the cap. When USE_FREQUENCY_WEIGHTING is False every weight
# is 1.0, so the flag actually disables the feature.
# NOTE(review): max_count is the most frequent token's count (likely air). If
# so, most non-air tokens will saturate at the cap; check the printed weights.
max_count = max(block_counts.values())
frequency_weights = {}
for tok in range(VOCAB_SIZE):
    if USE_FREQUENCY_WEIGHTING:
        count = block_counts.get(tok, 1)
        weight = min(FREQUENCY_WEIGHT_CAP, max_count / count)
    else:
        weight = 1.0
    frequency_weights[tok] = weight

FREQUENCY_WEIGHT_TENSOR = torch.tensor(
    [frequency_weights[i] for i in range(VOCAB_SIZE)],
    dtype=torch.float32
)

# Show top 10 rarest blocks
print(f"\nTop 10 RAREST blocks (highest weights, capped at {FREQUENCY_WEIGHT_CAP}x):")
sorted_by_weight = sorted(frequency_weights.items(), key=lambda x: x[1], reverse=True)
for tok, weight in sorted_by_weight[:10]:
    block_name = tok2block.get(tok, f"UNKNOWN_{tok}")
    count = block_counts.get(tok, 0)
    print(f"  {block_name}: weight={weight:.1f}x (count={count:,})")

# Identify rare block categories.
# Keywords are matched against whole underscore-separated words of the block's
# base name, with the namespace ("minecraft:") and any "[state]" suffix
# stripped. Plain substring matching let 'bed' match 'minecraft:bedrock', which
# is listed as terrain. With word matching, 'minecraft:red_bed' still matches
# 'bed' and 'oak_trapdoor' matches via its own 'trapdoor' keyword.
RARE_BLOCK_KEYWORDS = ['chest', 'door', 'fence', 'trapdoor', 'carpet', 'bed', 'button', 'lever']
RARE_BLOCK_TOKENS = set()
for tok, block in tok2block.items():
    base_name = block.split('[')[0].split(':')[-1].lower()
    if not set(base_name.split('_')).isdisjoint(RARE_BLOCK_KEYWORDS):
        RARE_BLOCK_TOKENS.add(tok)

RARE_BLOCK_TOKENS_TENSOR = torch.tensor(sorted(RARE_BLOCK_TOKENS), dtype=torch.long)
print(f"\nRare block tokens identified: {len(RARE_BLOCK_TOKENS)}")

## Cell 5: Terrain Detection

In [None]:
# Base block names treated as natural terrain. Voxels whose target is one of
# these get TERRAIN_WEIGHT in the loss and are excluded from "building" metrics.
# NOTE(review): cobblestone, mossy_cobblestone and the terracottas are also
# common building materials; listing them here down-weights them in builds
# too. Confirm this is intended.
TERRAIN_BLOCKS: Set[str] = {
    # soils
    'minecraft:dirt', 'minecraft:grass_block', 'minecraft:coarse_dirt',
    'minecraft:podzol', 'minecraft:mycelium', 'minecraft:rooted_dirt',
    'minecraft:dirt_path', 'minecraft:farmland', 'minecraft:mud',
    # stone
    'minecraft:stone', 'minecraft:cobblestone', 'minecraft:mossy_cobblestone',
    'minecraft:bedrock', 'minecraft:deepslate', 'minecraft:tuff',
    'minecraft:granite', 'minecraft:diorite', 'minecraft:andesite',
    # loose / sediment
    'minecraft:sand', 'minecraft:red_sand', 'minecraft:gravel', 'minecraft:clay',
    # fluids
    'minecraft:water', 'minecraft:lava',
    # badlands
    'minecraft:terracotta', 'minecraft:white_terracotta', 'minecraft:orange_terracotta',
    'minecraft:brown_terracotta', 'minecraft:red_terracotta',
    # nether / end
    'minecraft:netherrack', 'minecraft:soul_sand', 'minecraft:soul_soil',
    'minecraft:end_stone',
    # frozen
    'minecraft:snow_block', 'minecraft:ice', 'minecraft:packed_ice',
}

# A token is terrain when its name, minus any "[...]" block-state suffix, is listed above.
TERRAIN_TOKENS: Set[int] = {
    tok for tok, block in tok2block.items()
    if block.split('[')[0] in TERRAIN_BLOCKS
}

TERRAIN_TOKENS_TENSOR = torch.tensor(sorted(TERRAIN_TOKENS), dtype=torch.long)
print(f"Terrain tokens: {len(TERRAIN_TOKENS)}")

## Cell 6: Dataset

In [None]:
class VQVAEDataset(Dataset):
    """Map-style dataset with one voxel structure per ``*.h5`` file.

    Files are discovered once, sorted by path, at construction time. Item
    ``idx`` is the first dataset (by key order) of the ``idx``-th file,
    returned as an int64 tensor.

    Raises:
        ValueError: if ``data_dir`` contains no ``*.h5`` files.
    """

    def __init__(self, data_dir: str):
        found = sorted(Path(data_dir).glob("*.h5"))
        self.data_dir = Path(data_dir)
        self.h5_files = found
        if not found:
            raise ValueError(f"No H5 files in {data_dir}")
        print(f"Found {len(found)} structures in {data_dir}")

    def __len__(self):
        return len(self.h5_files)

    def __getitem__(self, idx):
        path = self.h5_files[idx]
        with h5py.File(path, 'r') as handle:
            first_key = list(handle.keys())[0]
            voxels = handle[first_key][:]
        return torch.from_numpy(voxels.astype(np.int64)).long()

# Build train and validation datasets (train first, matching the original order of construction).
train_dataset, val_dataset = VQVAEDataset(DATA_DIR), VQVAEDataset(VAL_DIR)
print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

## Cell 7: FSQ Base Module

In [None]:
class FSQ(nn.Module):
    """Finite Scalar Quantization (FSQ) over the last dimension of the input.

    Channel ``i`` of the last axis (``len(levels)`` channels) goes through four steps:
    squash with ``tanh``, scale by ``(levels[i] - 1) / 2``, round and clamp to
    that range, then rescale back into ``[-1, 1]``. Gradients pass straight
    through the rounding.

    Each position also gets a mixed-radix code index built from its
    per-channel level indices, so indices lie in ``[0, prod(levels))``.

    NOTE(review): for an even ``levels[i]`` the half-level is non-integer.
    Rounding then reaches fewer than ``levels[i]`` values and index collisions
    are possible, so levels are expected to be odd.

    Every forward call, in train or eval mode, adds to a per-code usage
    histogram (``_usage``) until ``reset_usage`` is called.
    """

    def __init__(self, levels: List[int], eps: float = 1e-3):
        super().__init__()
        self.levels = levels
        self.dim = len(levels)
        self.eps = eps  # stored but not used by this class
        self.codebook_size = int(np.prod(levels))

        self.register_buffer('_levels', torch.tensor(levels, dtype=torch.float32))
        # Mixed-radix place values: last channel varies fastest.
        basis = []
        acc = 1
        for L in reversed(levels):
            basis.append(acc)
            acc *= L
        self.register_buffer('_basis', torch.tensor(list(reversed(basis)), dtype=torch.long))
        half_levels = [(L - 1) / 2 for L in levels]
        self.register_buffer('_half_levels', torch.tensor(half_levels, dtype=torch.float32))
        self.register_buffer('_usage', torch.zeros(self.codebook_size))

    def reset_usage(self):
        """Zero the accumulated per-code usage histogram."""
        self._usage.zero_()

    def get_usage_stats(self) -> Tuple[float, float]:
        """Return (fraction of codes used at least once, perplexity of the usage distribution)."""
        usage = (self._usage > 0).float().mean().item()
        if self._usage.sum() == 0:
            return usage, 0.0
        probs = self._usage / self._usage.sum()
        probs = probs[probs > 0]
        entropy = -(probs * probs.log()).sum()
        perplexity = entropy.exp().item()
        return usage, perplexity

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Quantize ``z``, whose last dimension must equal ``len(levels)``.

        Returns:
            z_q: quantized values in ``[-1, 1]``, same shape as ``z``. The
                forward pass uses the quantized values; the gradient flows to
                ``tanh(z)``.
            indices: int64 code index per position, shape ``z.shape[:-1]``.
        """
        z_bounded = torch.tanh(z)
        z_q_list = []
        # Per-channel loop with 0-d scale tensors (kept deliberately: it
        # preserves z's dtype under autocast, unlike broadcasting a 1-D buffer).
        for i in range(self.dim):
            half_L = self._half_levels[i]
            z_i = z_bounded[..., i] * half_L
            z_i = torch.clamp(torch.round(z_i), -half_L, half_L)
            z_q_list.append(z_i / half_L)
        z_q = torch.stack(z_q_list, dim=-1)
        # Straight-through estimator.
        z_q = z_bounded + (z_q - z_bounded).detach()

        indices = torch.zeros(z_q.shape[:-1], dtype=torch.long, device=z_q.device)
        for i in range(self.dim):
            L = self._levels[i].long()
            half_L = self._half_levels[i]
            level_idx = ((z_q[..., i] * half_L) + half_L).round().long()
            level_idx = torch.clamp(level_idx, 0, L - 1)
            indices = indices + level_idx * self._basis[i]

        with torch.no_grad():
            # A single bincount replaces the old loop over unique codes, which
            # forced a device->host sync on every iteration. The clamp above
            # keeps indices in [0, codebook_size); the slice is only a guard.
            counts = torch.bincount(indices.reshape(-1), minlength=self.codebook_size)
            self._usage += counts[: self.codebook_size].to(self._usage.dtype)
        return z_q, indices

## Cell 8: RFSQ Module

In [None]:
class InvertibleLayerNorm(nn.Module):
    """Per-sample normalization over dims (1, 2, 3) that can be undone.

    ``forward`` standardizes each sample per last-axis feature, then applies an
    affine ``weight``/``bias`` of size ``num_features``. The ``num_features``
    size must match the last dimension. The mean and std (std + eps) of the
    most recent call are kept on the module, so ``inverse`` maps normalized
    values back to the original scale of that call's input.
    """

    def __init__(self, num_features: int, eps: float = 1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.stored_mean = None
        self.stored_std = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        reduce_dims = (1, 2, 3)
        mu = x.mean(dim=reduce_dims, keepdim=True)
        sigma = x.std(dim=reduce_dims, keepdim=True) + self.eps
        self.stored_mean, self.stored_std = mu, sigma
        return self.weight * ((x - mu) / sigma) + self.bias

    def inverse(self, x_norm: torch.Tensor) -> torch.Tensor:
        unscaled = (x_norm - self.bias) / self.weight
        return self.stored_std * unscaled + self.stored_mean


class RFSQStage(nn.Module):
    """One residual-FSQ stage: normalize, quantize with FSQ, de-normalize.

    ``forward(residual)`` returns ``(z_q, residual - z_q, indices)``. Here
    ``z_q`` is the quantized approximation in the residual's original scale,
    and the second item is what remains for the next stage.
    """

    def __init__(self, levels: List[int]):
        super().__init__()
        self.fsq = FSQ(levels)
        self.layernorm = InvertibleLayerNorm(len(levels))

    @property
    def codebook_size(self) -> int:
        return self.fsq.codebook_size

    def reset_usage(self):
        self.fsq.reset_usage()

    def get_usage_stats(self):
        return self.fsq.get_usage_stats()

    def forward(self, residual):
        quantized_norm, indices = self.fsq(self.layernorm(residual))
        quantized = self.layernorm.inverse(quantized_norm)
        return quantized, residual - quantized, indices


class RFSQ(nn.Module):
    """Residual FSQ: ``num_stages`` RFSQStage modules applied to successive residuals.

    Each stage quantizes whatever the previous stages left over. The output is
    the sum of all stages' quantized values, plus one index tensor per stage.
    The implicit codebook size is ``prod(levels) ** num_stages``.
    """

    def __init__(self, levels_per_stage: List[int], num_stages: int = 2):
        super().__init__()
        self.num_stages = num_stages
        self.dim = len(levels_per_stage)
        self.stages = nn.ModuleList(RFSQStage(levels_per_stage) for _ in range(num_stages))
        self.codes_per_stage = int(np.prod(levels_per_stage))
        self.codebook_size = self.codes_per_stage ** num_stages

    def reset_usage(self):
        for stage in self.stages:
            stage.reset_usage()

    def get_usage_stats(self):
        stats = {}
        for i, stage in enumerate(self.stages):
            stats[f'stage{i}'] = stage.get_usage_stats()
        return stats

    def _run_stages(self, z, norms):
        """Quantize ``z`` through all stages.

        If ``norms`` is a list, append the residual norm before each stage and
        after the last one.
        """
        remaining = z
        total = torch.zeros_like(z)
        indices_per_stage = []
        for stage in self.stages:
            if norms is not None:
                norms.append(remaining.norm().item())
            z_q, remaining, indices = stage(remaining)
            total = total + z_q
            indices_per_stage.append(indices)
        if norms is not None:
            norms.append(remaining.norm().item())
        return total, indices_per_stage

    def forward(self, z):
        return self._run_stages(z, None)

    def forward_with_norms(self, z):
        norms = []
        z_q_sum, all_indices = self._run_stages(z, norms)
        return z_q_sum, all_indices, norms

print("RFSQ module defined")

## Cell 9: VQ-VAE v7 Architecture with U-Net Skip Connections

Key architectural changes:
1. **EncoderV7**: Returns z_e + skip_16 + skip_32
2. **DecoderV7**: Accepts and concatenates skip connections
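3. **VolumeRatioPenalty**: Penalizes deviation of the predicted non-air volume from ground truth (squared error on the volume ratio, so both over- and under-prediction)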

In [None]:
class ResidualBlock3D(nn.Module):
    """Two 3x3x3 conv + BatchNorm layers with an identity shortcut.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``. Channel count and
    spatial size are preserved (kernel 3, padding 1).
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv3d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv3d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm3d(channels)
        self.bn2 = nn.BatchNorm3d(channels)

    def forward(self, x):
        h = self.bn1(self.conv1(x)).relu()
        h = self.bn2(self.conv2(h))
        return (h + x).relu()


class EncoderV7(nn.Module):
    """Encoder with dual skip connection outputs for a U-Net-style decoder.

    Two stride-2 stages (kernel 4, padding 1) halve the spatial size twice,
    e.g. 32 -> 16 -> 8. A final residual stack then projects to ``rfsq_dim``
    channels.

    Args:
        in_channels: channels of the input volume (the block-embedding width).
        hidden_dims: channel widths of stage 1 and stage 2.
        rfsq_dim: output channels of ``z_e`` (one per FSQ level).
        dropout: Dropout3d probability after each downsampling conv.

    Returns (from forward):
        z_e: Final latent [B, rfsq_dim, S/4, S/4, S/4] (8^3 for a 32^3 input)
        skip_16: Stage-1 features [B, hidden_dims[0], S/2, S/2, S/2]
        skip_32: The input ``x`` itself, unchanged [B, in_channels, S, S, S]

    NOTE(review): skip_16 and skip_32 are not quantized. If the decoder
    consumes them, it receives input information that bypasses the latent
    bottleneck. Confirm this is intended.
    """
    
    def __init__(self, in_channels: int, hidden_dims: List[int], rfsq_dim: int, dropout: float = 0.1):
        super().__init__()
        
        # Stage 1: 32 -> 16 (output used for skip_16)
        self.stage1 = nn.Sequential(
            nn.Conv3d(in_channels, hidden_dims[0], 4, stride=2, padding=1),
            nn.BatchNorm3d(hidden_dims[0]),
            nn.ReLU(inplace=True),
            nn.Dropout3d(dropout),
            ResidualBlock3D(hidden_dims[0]),
        )
        
        # Stage 2: 16 -> 8
        self.stage2 = nn.Sequential(
            nn.Conv3d(hidden_dims[0], hidden_dims[1], 4, stride=2, padding=1),
            nn.BatchNorm3d(hidden_dims[1]),
            nn.ReLU(inplace=True),
            nn.Dropout3d(dropout),
            ResidualBlock3D(hidden_dims[1]),
        )
        
        # Final processing at 8x8x8, then project to the RFSQ channel count
        self.final = nn.Sequential(
            ResidualBlock3D(hidden_dims[1]),
            ResidualBlock3D(hidden_dims[1]),
            nn.Conv3d(hidden_dims[1], rfsq_dim, 3, padding=1),
        )
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Keep the raw input as the full-resolution skip (no processing applied)
        skip_32 = x  # [B, in_channels, 32, 32, 32] assuming a 32^3 input
        
        # Stage 1: capture 16x16x16 features for skip_16
        skip_16 = self.stage1(x)  # [B, hidden_dims[0], 16, 16, 16] (96 with HIDDEN_DIMS=[96, 192])
        
        # Stage 2: compress to 8x8x8
        x = self.stage2(skip_16)  # [B, hidden_dims[1], 8, 8, 8]
        
        # Final projection to RFSQ dimension
        z_e = self.final(x)  # [B, rfsq_dim, 8, 8, 8]
        
        return z_e, skip_16, skip_32


class DecoderV7(nn.Module):
    """Decoder with dual U-Net skip connections.

    Two transposed convs (kernel 4, stride 2, padding 1) double the spatial
    size twice, e.g. 8 -> 16 -> 32. After each upsample the matching skip
    tensor is concatenated on the channel axis and projected back with a 1x1
    conv.

    Args:
        rfsq_dim: channels of the quantized latent ``z_q``.
        hidden_dims: expected to be the encoder's widths reversed (e.g. [192, 96]).
        num_blocks: number of output classes (vocabulary size).
        emb_dim: channels of ``skip_32``.
        dropout: Dropout3d probability after the first upsample.

    Skip connections:
        - skip_16: concatenated after the first upsample (8->16); must have
          hidden_dims[1] channels (skip1_proj expects hidden_dims[1] * 2 inputs).
        - skip_32: concatenated after the second upsample (16->32); must have
          emb_dim channels.

    NOTE(review): the forward comment labels skip_32 as the input embeddings.
    If so, the output head sees the input directly and can bypass the
    quantized latent. Confirm this is intended.
    """
    
    def __init__(self, rfsq_dim: int, hidden_dims: List[int], num_blocks: int, 
                 emb_dim: int, dropout: float = 0.1):
        super().__init__()
        # hidden_dims is expected to be [192, 96] (encoder's dims reversed by the caller)
        
        # Initial projection from RFSQ dim
        self.initial = nn.Sequential(
            nn.Conv3d(rfsq_dim, hidden_dims[0], 3, padding=1),
            nn.BatchNorm3d(hidden_dims[0]),
            nn.ReLU(inplace=True),
            ResidualBlock3D(hidden_dims[0]),
            ResidualBlock3D(hidden_dims[0]),
        )
        
        # Upsample 8 -> 16
        self.up1 = nn.Sequential(
            ResidualBlock3D(hidden_dims[0]),
            nn.ConvTranspose3d(hidden_dims[0], hidden_dims[1], 4, stride=2, padding=1),
            nn.BatchNorm3d(hidden_dims[1]),
            nn.ReLU(inplace=True),
            nn.Dropout3d(dropout),
        )
        
        # Skip connection 1 projection: hidden_dims[1] (up1) + hidden_dims[1] (skip_16) -> hidden_dims[1]
        # (96 + 96 = 192 -> 96 with the default widths)
        self.skip1_proj = nn.Sequential(
            nn.Conv3d(hidden_dims[1] * 2, hidden_dims[1], 1),
            nn.BatchNorm3d(hidden_dims[1]),
            nn.ReLU(inplace=True),
        )
        
        # Process after skip 1
        self.post_skip1 = nn.Sequential(
            ResidualBlock3D(hidden_dims[1]),
            ResidualBlock3D(hidden_dims[1]),
        )
        
        # Upsample 16 -> 32
        self.up2 = nn.Sequential(
            ResidualBlock3D(hidden_dims[1]),
            nn.ConvTranspose3d(hidden_dims[1], hidden_dims[1], 4, stride=2, padding=1),
            nn.BatchNorm3d(hidden_dims[1]),
            nn.ReLU(inplace=True),
        )
        
        # Skip connection 2 projection: hidden_dims[1] (up2) + emb_dim (skip_32) -> hidden_dims[1]
        # (e.g. 96 + 32 = 128 -> 96 if emb_dim == 32 — emb_dim comes from the loaded embeddings)
        self.skip2_proj = nn.Sequential(
            nn.Conv3d(hidden_dims[1] + emb_dim, hidden_dims[1], 1),
            nn.BatchNorm3d(hidden_dims[1]),
            nn.ReLU(inplace=True),
        )
        
        # Process after skip 2
        self.post_skip2 = ResidualBlock3D(hidden_dims[1])
        
        # Final per-voxel class logits
        self.final = nn.Conv3d(hidden_dims[1], num_blocks, 3, padding=1)
    
    def forward(self, z_q: torch.Tensor, skip_16: torch.Tensor, skip_32: torch.Tensor) -> torch.Tensor:
        # Shapes below assume an 8^3 latent and hidden_dims=[192, 96].
        # Initial processing at 8x8x8
        x = self.initial(z_q)  # [B, 192, 8, 8, 8]
        
        # Upsample to 16x16x16
        x = self.up1(x)  # [B, 96, 16, 16, 16]
        
        # Skip connection 1: concat with encoder's 16x16x16 features
        x = torch.cat([x, skip_16], dim=1)  # [B, 192, 16, 16, 16]
        x = self.skip1_proj(x)  # [B, 96, 16, 16, 16]
        x = self.post_skip1(x)  # [B, 96, 16, 16, 16]
        
        # Upsample to 32x32x32
        x = self.up2(x)  # [B, 96, 32, 32, 32]
        
        # Skip connection 2: concat with input embeddings
        x = torch.cat([x, skip_32], dim=1)  # [B, 96 + emb_dim, 32, 32, 32]
        x = self.skip2_proj(x)  # [B, 96, 32, 32, 32]
        x = self.post_skip2(x)  # [B, 96, 32, 32, 32]
        
        # Predict block types
        logits = self.final(x)  # [B, num_blocks, 32, 32, 32]
        
        return logits


class FrequencyWeightedLoss(nn.Module):
    """Weighted-mean cross-entropy combining a category weight and a per-token frequency weight.

    Each voxel's weight is ``category_weight * frequency_weights[target]``.
    The category weight is ``air_weight`` where ``air_mask`` is set. Otherwise
    it is ``terrain_weight`` where ``terrain_mask`` is set, and
    ``building_weight`` everywhere else; air wins if both masks are set. The
    result is ``sum(ce * w) / sum(w)``.
    """

    def __init__(self, frequency_weights: torch.Tensor,
                 terrain_weight: float = 0.2, building_weight: float = 1.0, air_weight: float = 0.1):
        super().__init__()
        self.register_buffer('frequency_weights', frequency_weights)
        self.terrain_weight = terrain_weight
        self.building_weight = building_weight
        self.air_weight = air_weight

    def forward(self, logits: torch.Tensor, targets: torch.Tensor,
                terrain_mask: torch.Tensor, air_mask: torch.Tensor) -> torch.Tensor:
        per_voxel = F.cross_entropy(logits, targets, reduction='none')
        category = (
            torch.full_like(per_voxel, self.building_weight)
            .masked_fill(terrain_mask, self.terrain_weight)
            .masked_fill(air_mask, self.air_weight)
        )
        weights = category * self.frequency_weights[targets]
        return (per_voxel * weights).sum() / weights.sum()


class VolumeRatioPenalty(nn.Module):
    """Squared-error penalty on the predicted / ground-truth non-air volume ratio.

    Loss = weight * (vol_ratio - 1.0)^2

    Because the error is squared, both over- and under-prediction are
    penalized. The penalty was motivated by v6-freq's 1.68x over-prediction.
    "Volume" is the number of voxels whose token is not in ``air_tokens``.

    The ratio computed from hard ``preds`` (integer argmax ids) has no
    gradient, so a penalty built only from it cannot train the model. Pass
    ``logits`` to make the penalty differentiable. The predicted volume then
    becomes the expected non-air count ``sum(1 - P(air))`` under the softmax
    over dim 1. The returned ``vol_ratio`` is always the hard-prediction
    ratio, for reporting.
    """
    
    def __init__(self, weight: float = 0.1):
        super().__init__()
        self.weight = weight
    
    def forward(self, preds: torch.Tensor, targets: torch.Tensor, 
                air_tokens: torch.Tensor,
                logits: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the volume penalty.

        Args:
            preds: predicted token ids.
            targets: ground-truth token ids.
            air_tokens: 1-D tensor of air token ids.
            logits: optional class scores with classes on dim 1 and one row
                per target voxel (e.g. ``[N, C]`` for flattened targets). When
                given, the penalty is computed from the softmax expected
                volume and is differentiable w.r.t. ``logits``.

        Returns:
            (penalty, vol_ratio): ``penalty`` is ``weight * (ratio - 1)^2``
            (soft ratio if ``logits`` is given, otherwise hard). ``vol_ratio``
            is the detached hard-prediction ratio.
        """
        gt_air = torch.isin(targets, air_tokens)
        pred_air = torch.isin(preds, air_tokens)
        
        gt_volume = (~gt_air).sum().float()
        pred_volume = (~pred_air).sum().float()
        
        vol_ratio = pred_volume / (gt_volume + 1e-6)

        if logits is None:
            ratio_for_loss = vol_ratio
        else:
            air_idx = air_tokens.to(logits.device)
            # log P(air) per voxel = logsumexp over air classes - logsumexp over all classes.
            log_p_air = (torch.logsumexp(logits.index_select(1, air_idx), dim=1)
                         - torch.logsumexp(logits, dim=1))
            # clamp guards against tiny positive log-probs from low-precision rounding
            p_air = log_p_air.float().clamp(max=0.0).exp()
            soft_volume = (1.0 - p_air).sum()
            ratio_for_loss = soft_volume / (gt_volume + 1e-6)

        volume_penalty = self.weight * (ratio_for_loss - 1.0) ** 2
        
        return volume_penalty, vol_ratio.detach()


print("\n".join([
    "V7 architecture components defined!",
    "  - EncoderV7: Returns z_e, skip_16, skip_32",
    "  - DecoderV7: Accepts dual skip connections",
    "  - VolumeRatioPenalty: Penalizes over-prediction",
]))

## Cell 10: Complete Metrics Computation

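Computes every metric from the CLAUDE.md reference, including the new ones: building precision and F1, air accuracy and precision, false-block rate, and rare-block precision.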

In [None]:
def compute_all_metrics(preds: torch.Tensor, targets: torch.Tensor,
                        air_tokens: torch.Tensor, terrain_tokens: torch.Tensor,
                        rare_tokens: torch.Tensor, 
                        block_emb: nn.Embedding) -> Dict[str, torch.Tensor]:
    """Compute the full VQ-VAE v7 metric set from predicted and target token ids.

    ``preds`` and ``targets`` have the same shape. Voxel categories come from
    the *target*:
    - air: target in ``air_tokens``.
    - terrain: target in ``terrain_tokens`` and not air.
    - building: neither air nor terrain.
    - rare: target in ``rare_tokens``.
    A prediction is "correct" only when it is the exact same token id.

    Metrics (each a 0-d tensor):
        overall_acc: exact-match rate over all voxels.
        building_acc / terrain_acc: exact-match rate over building / terrain voxels.
        building_recall: fraction of building voxels predicted as any non-air token.
        building_false_air: fraction of building voxels predicted as air.
        building_precision: exact-match hits among *all* non-air predictions
            (the denominator also includes predictions at terrain voxels).
        building_f1: harmonic mean of building_precision and building_recall.
            NOTE(review): this mixes exact-match precision with presence-only
            recall. Keep that in mind when using it for early stopping.
        air_acc: exact-match rate over air voxels.
        false_block_rate: fraction of air voxels predicted as non-air.
        air_precision: exact-match hits among air predictions.
        vol_ratio: predicted non-air count / target non-air count.
        struct_recall: fraction of non-air targets predicted as non-air.
        rare_acc / rare_recall / rare_precision: the same ideas restricted to
            rare targets (precision is over rare predictions).
        error_similarity: mean cosine similarity between the predicted and
            target embeddings over mispredicted building voxels.

    When a metric's subset is empty, it falls back to a constant: 1.0 for
    air_acc and struct_recall, 0.0 otherwise. All rare metrics are 0.0 when
    there are no rare targets.
    """
    dev = preds.device

    def const(value: float) -> torch.Tensor:
        return torch.tensor(value, device=dev)

    def mean_over(values: torch.Tensor, mask: torch.Tensor, default: float) -> torch.Tensor:
        # Mean of `values` restricted to `mask`, or `default` for an empty mask.
        return values[mask].float().mean() if mask.any() else const(default)

    def hit_rate(hits: torch.Tensor, selected: torch.Tensor, default: float) -> torch.Tensor:
        # Count of `hits` divided by count of `selected`, or `default` if nothing is selected.
        return hits.sum().float() / selected.sum().float() if selected.any() else const(default)

    gt_air = torch.isin(targets, air_tokens)
    pred_air = torch.isin(preds, air_tokens)
    gt_terrain = torch.isin(targets, terrain_tokens) & ~gt_air
    gt_building = ~(gt_air | gt_terrain)
    gt_rare = torch.isin(targets, rare_tokens)
    correct = preds == targets
    pred_building = ~pred_air

    m: Dict[str, torch.Tensor] = {}
    m['overall_acc'] = correct.float().mean()

    m['building_acc'] = mean_over(correct, gt_building, 0.0)
    m['building_recall'] = mean_over(pred_building, gt_building, 0.0)
    m['building_false_air'] = mean_over(pred_air, gt_building, 0.0)
    m['terrain_acc'] = mean_over(correct, gt_terrain, 0.0)

    m['building_precision'] = hit_rate(correct & pred_building & ~gt_air, pred_building, 0.0)
    prec, rec = m['building_precision'], m['building_recall']
    m['building_f1'] = 2 * (prec * rec) / (prec + rec) if (prec + rec) > 0 else const(0.0)

    m['air_acc'] = mean_over(correct, gt_air, 1.0)
    m['false_block_rate'] = mean_over(pred_building, gt_air, 0.0)
    m['air_precision'] = hit_rate(correct & pred_air & gt_air, pred_air, 0.0)

    gt_struct, pred_struct = ~gt_air, ~pred_air
    m['vol_ratio'] = pred_struct.sum().float() / (gt_struct.sum().float() + 1e-6)
    m['struct_recall'] = mean_over(pred_struct, gt_struct, 1.0)

    if gt_rare.any():
        pred_rare = torch.isin(preds, rare_tokens)
        m['rare_acc'] = mean_over(correct, gt_rare, 0.0)
        m['rare_recall'] = mean_over(pred_building, gt_rare, 0.0)
        m['rare_precision'] = hit_rate(correct & pred_rare & gt_rare, pred_rare, 0.0)
    else:
        for name in ('rare_acc', 'rare_recall', 'rare_precision'):
            m[name] = const(0.0)

    wrong_building = gt_building & ~correct
    if wrong_building.any():
        m['error_similarity'] = F.cosine_similarity(
            block_emb.weight[preds[wrong_building]],
            block_emb.weight[targets[wrong_building]],
            dim=-1,
        ).mean()
    else:
        m['error_similarity'] = const(0.0)

    return m


print("Complete metrics function defined!")
print("NEW metrics: building_precision, building_f1, air_acc, air_precision, false_block_rate, rare_precision")

## Cell 11: VQ-VAE v7 Full Model

In [None]:
class VQVAEv7(nn.Module):
    """VQ-VAE v7: frozen block embeddings -> EncoderV7 -> RFSQ -> DecoderV7.
    
    Key changes relative to v6-freq:
    1. Dual skip connections (16x16x16 + 32x32x32) from encoder to decoder.
    2. A differentiable volume-ratio penalty on the predicted non-air volume.
    3. Reduced frequency weight cap (10x -> 5x), supplied via ``frequency_weights``.
    4. Complete metrics (precision/F1/air metrics) returned by ``compute_loss``.

    NOTE(review): ``skip_32`` is the embedded input itself and ``skip_16`` is
    the first encoder stage. Both reach the decoder without passing through
    the RFSQ bottleneck. Reconstruction quality may therefore overstate how
    much information the codes carry, and ``forward`` cannot decode from codes
    alone. Confirm this is intended.
    """
    
    def __init__(self, vocab_size: int, emb_dim: int, hidden_dims: List[int],
                 rfsq_levels: List[int], num_stages: int, pretrained_emb: np.ndarray,
                 frequency_weights: torch.Tensor,
                 terrain_weight: float = 0.2, building_weight: float = 1.0,
                 air_weight: float = 0.1, volume_penalty_weight: float = 0.1,
                 dropout: float = 0.1):
        super().__init__()
        
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.rfsq_dim = len(rfsq_levels)
        self.num_stages = num_stages
        
        # Block embeddings (frozen; copied from the pretrained array)
        self.block_emb = nn.Embedding(vocab_size, emb_dim)
        self.block_emb.weight.data.copy_(torch.from_numpy(pretrained_emb))
        self.block_emb.weight.requires_grad = False
        
        # Encoder with dual skip outputs
        self.encoder = EncoderV7(emb_dim, hidden_dims, self.rfsq_dim, dropout)
        
        # RFSQ quantization (unchanged from v6)
        self.rfsq = RFSQ(rfsq_levels, num_stages)
        
        # Decoder with dual skip inputs (widths mirrored from the encoder)
        self.decoder = DecoderV7(
            self.rfsq_dim, 
            list(reversed(hidden_dims)), 
            vocab_size, 
            emb_dim,  # For skip_32 projection
            dropout
        )
        
        # Loss functions. Only volume_penalty.weight is used by compute_loss
        # (see the volume-penalty comment there).
        self.loss_fn = FrequencyWeightedLoss(frequency_weights, terrain_weight, building_weight, air_weight)
        self.volume_penalty = VolumeRatioPenalty(weight=volume_penalty_weight)
    
    def forward(self, block_ids: torch.Tensor, return_norms: bool = False) -> Dict[str, Any]:
        """Encode, quantize and decode a batch of voxel structures.

        Args:
            block_ids: integer token ids of rank 4 (the permute below requires
                ``[B, X, Y, Z]``; 32^3 per this notebook).
            return_norms: also return per-stage RFSQ residual norms (host syncs).

        Returns:
            Dict with ``logits`` ``[B, vocab, X, Y, Z]``, ``all_indices`` (one
            index tensor per RFSQ stage), ``z_e`` (channels-last), ``z_q``
            (channels-first), ``skip_16``, ``skip_32`` and, if requested,
            ``residual_norms``.
        """
        # Embed blocks
        x = self.block_emb(block_ids)  # [B, X, Y, Z, emb_dim]
        x = x.permute(0, 4, 1, 2, 3).contiguous()  # [B, emb_dim, X, Y, Z]
        
        # Encode with dual skip connections
        z_e, skip_16, skip_32 = self.encoder(x)
        
        # Permute for RFSQ (expects channels last)
        z_e = z_e.permute(0, 2, 3, 4, 1).contiguous()  # [B, 8, 8, 8, rfsq_dim]
        
        # Quantize
        if return_norms:
            z_q, all_indices, residual_norms = self.rfsq.forward_with_norms(z_e)
        else:
            z_q, all_indices = self.rfsq(z_e)
            residual_norms = None
        
        # Permute back for decoder
        z_q = z_q.permute(0, 4, 1, 2, 3).contiguous()  # [B, rfsq_dim, 8, 8, 8]
        
        # Decode with dual skip connections
        logits = self.decoder(z_q, skip_16, skip_32)  # [B, vocab, X, Y, Z]
        
        result = {
            'logits': logits,
            'all_indices': all_indices,
            'z_e': z_e,
            'z_q': z_q,
            'skip_16': skip_16,
            'skip_32': skip_32,
        }
        
        if residual_norms is not None:
            result['residual_norms'] = residual_norms
        
        return result
    
    def compute_loss(self, block_ids: torch.Tensor, air_tokens: torch.Tensor,
                     terrain_tokens: torch.Tensor, rare_tokens: torch.Tensor,
                     structure_weight: float = 50.0, return_norms: bool = False) -> Dict[str, torch.Tensor]:
        """Run the model and return the training loss plus all metrics.

        loss = frequency/category-weighted CE + differentiable volume penalty.

        Args:
            block_ids: target token ids; also the model input.
            air_tokens, terrain_tokens, rare_tokens: 1-D token-id tensors,
                moved to the targets' device here.
            structure_weight: accepted for call compatibility; not used by
                this method.
            return_norms: forwarded to ``forward``.

        Returns:
            Dict with ``loss``, ``ce_loss``, ``vol_penalty``, every entry from
            ``compute_all_metrics`` and optionally ``residual_norms``.
        """
        out = self(block_ids, return_norms=return_norms)
        logits = out['logits']
        
        C = logits.shape[1]
        logits_flat = logits.permute(0, 2, 3, 4, 1).reshape(-1, C)
        targets_flat = block_ids.view(-1)
        
        device = targets_flat.device
        air_dev = air_tokens.to(device)
        terrain_dev = terrain_tokens.to(device)
        rare_dev = rare_tokens.to(device)
        
        is_air = torch.isin(targets_flat, air_dev)
        is_terrain = torch.isin(targets_flat, terrain_dev) & ~is_air
        
        # Frequency-weighted CE loss
        ce_loss = self.loss_fn(logits_flat, targets_flat, is_terrain, is_air)
        
        # Hard predictions: used only for metrics (argmax carries no gradient).
        with torch.no_grad():
            preds = logits_flat.argmax(dim=1)
        
        # Volume penalty. A ratio built from argmax predictions is constant
        # w.r.t. the parameters, so it would add nothing to the gradient.
        # Instead use the expected non-air voxel count under the softmax,
        # sum_v (1 - P(air at v)), which is differentiable in the logits.
        log_p_air = (torch.logsumexp(logits_flat[:, air_dev], dim=1)
                     - torch.logsumexp(logits_flat, dim=1))
        # clamp guards against tiny positive log-probs from fp16 rounding
        p_air = log_p_air.float().clamp(max=0.0).exp()
        soft_pred_volume = (1.0 - p_air).sum()
        gt_volume = (~is_air).sum().float()
        soft_vol_ratio = soft_pred_volume / (gt_volume + 1e-6)
        vol_penalty = self.volume_penalty.weight * (soft_vol_ratio - 1.0) ** 2
        
        # Total loss
        loss = ce_loss + vol_penalty
        
        # Compute all metrics (the reported vol_ratio is the hard-prediction ratio)
        with torch.no_grad():
            metrics = compute_all_metrics(
                preds, targets_flat,
                air_dev, terrain_dev, rare_dev,
                self.block_emb
            )
        
        result = {
            'loss': loss,
            'ce_loss': ce_loss,
            'vol_penalty': vol_penalty,
            **metrics,
        }
        
        if 'residual_norms' in out:
            result['residual_norms'] = out['residual_norms']
        
        return result


print("\n".join([
    "VQ-VAE v7 model defined!",
    "  - U-Net dual skip connections",
    "  - Volume ratio penalty",
    "  - Complete metrics (precision, F1, air metrics)",
]))

## Cell 12: Training Functions with Early Stopping

In [None]:
# All per-batch keys averaged by the training loop for v7, grouped by theme:
# losses, overall/terrain accuracy, building, air, volume/similarity, rare blocks.
METRIC_KEYS = (
    ['loss', 'ce_loss', 'vol_penalty']
    + ['overall_acc', 'terrain_acc']
    + ['building_acc', 'building_recall', 'building_precision', 'building_f1', 'building_false_air']
    + ['air_acc', 'air_precision', 'false_block_rate']
    + ['struct_recall', 'vol_ratio', 'error_similarity']
    + ['rare_acc', 'rare_recall', 'rare_precision']
)


def train_epoch(model, loader, optimizer, scaler, device,
                air_tokens, terrain_tokens, rare_tokens, structure_weight):
    """Run one training epoch with AMP and gradient accumulation.

    Each batch's loss is divided by GRAD_ACCUM_STEPS and backpropagated. The
    optimizer steps (after unscaling and clipping the gradient norm to 1.0)
    every GRAD_ACCUM_STEPS batches, and once more after the loop when the batch
    count is not a multiple of GRAD_ACCUM_STEPS, so trailing batches still
    contribute an update instead of being zeroed at the next epoch's start.

    Args:
        model: Model exposing ``compute_loss(...)`` and an ``rfsq`` quantizer
            with ``reset_usage()`` / ``get_usage_stats()``.
        loader: Iterable of input batches (each moved with ``.to(device)``).
        optimizer: Optimizer updated every accumulation window.
        scaler: ``torch.amp.GradScaler`` used for scaled backward/step.
        device: Device each batch is moved to.
        air_tokens, terrain_tokens, rare_tokens, structure_weight: Forwarded
            unchanged to ``model.compute_loss``.

    Returns:
        Dict containing:
          * the per-batch mean of every METRIC_KEYS entry present in the
            ``compute_loss`` output;
          * ``<stage>_usage`` / ``<stage>_perplexity`` for every stage reported
            by ``model.rfsq.get_usage_stats()``, unaveraged;
          * ``grad_norm``: mean pre-clip gradient norm over optimizer steps
            whose norm was finite (0.0 if there were none);
          * ``residual_decay``: last/first element of the averaged sampled
            residual norms (1.0 when unavailable or the first is not > 0).

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    model.rfsq.reset_usage()

    sums = {k: 0.0 for k in METRIC_KEYS}
    grad_norms = []
    all_residual_norms = []
    n = 0

    def _optimizer_step():
        """Unscale, clip, record the grad norm, step through the scaler, reset grads."""
        # Unscale first so clipping operates on the true gradient magnitudes.
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # An AMP overflow step yields an inf/nan norm (scaler.step skips that
        # update); keeping it out of the list keeps the logged average finite.
        if torch.isfinite(grad_norm):
            grad_norms.append(grad_norm.item())
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    optimizer.zero_grad()

    for batch_idx, batch in enumerate(tqdm(loader, desc="Train", leave=False)):
        batch = batch.to(device)
        # Residual norms are only requested on every 100th batch.
        return_norms = (batch_idx % 100 == 0)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            out = model.compute_loss(batch, air_tokens, terrain_tokens, rare_tokens,
                                     structure_weight, return_norms=return_norms)
            loss = out['loss'] / GRAD_ACCUM_STEPS

        if return_norms and 'residual_norms' in out:
            all_residual_norms.append(out['residual_norms'])

        scaler.scale(loss).backward()

        if (batch_idx + 1) % GRAD_ACCUM_STEPS == 0:
            _optimizer_step()

        for k in sums:
            if k in out:
                sums[k] += out[k].item() if torch.is_tensor(out[k]) else out[k]
        n += 1

    if n == 0:
        raise ValueError("train_epoch: loader yielded no batches")

    # Flush a trailing partial accumulation window. Its losses were still scaled
    # by 1/GRAD_ACCUM_STEPS, so this update is proportionally smaller.
    if n % GRAD_ACCUM_STEPS != 0:
        _optimizer_step()

    epoch_metrics = {k: v / n for k, v in sums.items()}

    # Codebook usage/perplexity accumulate since reset_usage() above and are
    # reported as-is, for however many stages the quantizer reports.
    for stage_name, (usage, perp) in model.rfsq.get_usage_stats().items():
        epoch_metrics[f'{stage_name}_usage'] = usage
        epoch_metrics[f'{stage_name}_perplexity'] = perp

    epoch_metrics['grad_norm'] = sum(grad_norms) / len(grad_norms) if grad_norms else 0.0

    if all_residual_norms:
        avg_norms = np.mean(all_residual_norms, axis=0)
        epoch_metrics['residual_decay'] = avg_norms[-1] / avg_norms[0] if avg_norms[0] > 0 else 1.0
    else:
        epoch_metrics['residual_decay'] = 1.0

    return epoch_metrics


@torch.no_grad()
def validate(model, loader, device, air_tokens, terrain_tokens, rare_tokens, structure_weight):
    """Evaluate ``model`` on ``loader`` without gradients and return epoch metrics.

    Args:
        model: Model exposing ``compute_loss(...)`` and an ``rfsq`` quantizer
            with ``reset_usage()`` / ``get_usage_stats()``.
        loader: Iterable of input batches (each moved with ``.to(device)``).
        device: Device each batch is moved to.
        air_tokens, terrain_tokens, rare_tokens, structure_weight: Forwarded
            unchanged to ``model.compute_loss``.

    Returns:
        Dict containing:
          * the per-batch mean of every METRIC_KEYS entry present in the
            ``compute_loss`` output;
          * ``<stage>_usage`` / ``<stage>_perplexity`` for every stage reported
            by ``model.rfsq.get_usage_stats()``, unaveraged;
          * ``residual_decay``: last/first element of the averaged sampled
            residual norms (1.0 when unavailable or the first is not > 0).

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    model.rfsq.reset_usage()

    sums = {k: 0.0 for k in METRIC_KEYS}
    all_residual_norms = []
    n = 0

    for batch_idx, batch in enumerate(tqdm(loader, desc="Val", leave=False)):
        batch = batch.to(device)
        # Residual norms are only requested on every 50th batch.
        return_norms = (batch_idx % 50 == 0)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            out = model.compute_loss(batch, air_tokens, terrain_tokens, rare_tokens,
                                     structure_weight, return_norms=return_norms)

        if return_norms and 'residual_norms' in out:
            all_residual_norms.append(out['residual_norms'])

        for k in sums:
            if k in out:
                sums[k] += out[k].item() if torch.is_tensor(out[k]) else out[k]
        n += 1

    if n == 0:
        raise ValueError("validate: loader yielded no batches")

    val_metrics = {k: v / n for k, v in sums.items()}

    # Codebook usage/perplexity accumulate since reset_usage() above and are
    # reported as-is, for however many stages the quantizer reports.
    for stage_name, (usage, perp) in model.rfsq.get_usage_stats().items():
        val_metrics[f'{stage_name}_usage'] = usage
        val_metrics[f'{stage_name}_perplexity'] = perp

    if all_residual_norms:
        avg_norms = np.mean(all_residual_norms, axis=0)
        val_metrics['residual_decay'] = avg_norms[-1] / avg_norms[0] if avg_norms[0] > 0 else 1.0
    else:
        val_metrics['residual_decay'] = 1.0

    return val_metrics


class EarlyStopping:
    """Signal when a monitored validation metric has stopped improving.

    Call the instance once per epoch with the metric's current value. It returns
    True once ``patience`` consecutive calls have failed to improve on the best
    value seen so far.

    Args:
        patience: Consecutive non-improving calls tolerated before stopping.
        metric: Name of the monitored metric. Stored for the caller (to look the
            value up and for log messages); ``__call__`` receives the value itself.
        mode: ``'max'`` if larger values are better, ``'min'`` if smaller are better.
        min_delta: Margin by which a value must beat the current best to count
            as an improvement. The default 0.0 accepts any strict improvement.

    Raises:
        ValueError: If ``mode`` is not ``'max'``/``'min'`` or ``min_delta`` is negative.
    """

    def __init__(self, patience: int = 10, metric: str = 'building_f1', mode: str = 'max',
                 min_delta: float = 0.0):
        # Reject typos such as 'maximize' instead of silently treating them as 'min'.
        if mode not in ('max', 'min'):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        if min_delta < 0:
            raise ValueError(f"min_delta must be >= 0, got {min_delta!r}")
        self.patience = patience
        self.metric = metric
        self.mode = mode
        self.min_delta = min_delta
        self.best_value = float('-inf') if mode == 'max' else float('inf')
        self.counter = 0
        self.best_epoch = 0

    def __call__(self, value: float, epoch: int) -> bool:
        """Record ``value`` for ``epoch``; return True when training should stop.

        NaN values never count as improvements (all comparisons with NaN are False).
        """
        if self.mode == 'max':
            is_better = value > self.best_value + self.min_delta
        else:
            is_better = value < self.best_value - self.min_delta

        if is_better:
            self.best_value = value
            self.best_epoch = epoch
            self.counter = 0
            return False  # Improvement: keep training

        self.counter += 1
        return self.counter >= self.patience  # Stop once patience is exhausted


print("Training functions defined with early stopping!")

## Cell 13: Create Model and Optimizer with Cosine Annealing

In [None]:
# --- Reproducibility: seed torch, NumPy and Python RNGs before anything draws ---
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.empty_cache()

# --- Data loaders: the training shuffle order comes from its own seeded generator ---
g = torch.Generator().manual_seed(SEED)
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, generator=g, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# --- Model ---
model = VQVAEv7(
    vocab_size=VOCAB_SIZE,
    emb_dim=EMBEDDING_DIM,
    hidden_dims=HIDDEN_DIMS,
    rfsq_levels=RFSQ_LEVELS_PER_STAGE,
    num_stages=NUM_STAGES,
    pretrained_emb=v3_embeddings,
    frequency_weights=FREQUENCY_WEIGHT_TENSOR,
    terrain_weight=TERRAIN_WEIGHT,
    building_weight=BUILDING_WEIGHT,
    air_weight=AIR_WEIGHT,
    volume_penalty_weight=VOLUME_PENALTY_WEIGHT,
    dropout=DROPOUT,
).to(device)

# Only parameters with requires_grad are counted as trainable and optimized.
trainable = [p for p in model.parameters() if p.requires_grad]
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in trainable)
print(f"Total params: {total_params:,}")
print(f"Trainable params: {trainable_params:,}")
print(f"RFSQ total codes: {model.rfsq.codebook_size:,}")

# --- Optimization: AdamW, per-epoch cosine decay BASE_LR -> MIN_LR, AMP scaler ---
optimizer = optim.AdamW(trainable, lr=BASE_LR, weight_decay=1e-5)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=TOTAL_EPOCHS, eta_min=MIN_LR)
scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP)

# --- Early stopping (higher metric value = better) ---
early_stopper = EarlyStopping(patience=EARLY_STOPPING_PATIENCE, metric=EARLY_STOPPING_METRIC, mode='max')

print(f"\nOptimizer: AdamW, LR={BASE_LR}")
print(f"LR Schedule: Cosine annealing ({BASE_LR} -> {MIN_LR})")
print(f"Early stopping: patience={EARLY_STOPPING_PATIENCE} on {EARLY_STOPPING_METRIC}")

## Cell 14: Training Loop

In [None]:
print("="*70)
print("VQ-VAE V7 TRAINING - U-NET + VOLUME PENALTY")
print("="*70)
print(f"Key features:")
print(f"  - U-Net dual skip connections (16x16x16 + 32x32x32)")
print(f"  - Volume ratio penalty (weight={VOLUME_PENALTY_WEIGHT})")
print(f"  - Frequency cap: {FREQUENCY_WEIGHT_CAP}x (down from 10x)")
print(f"  - Cosine annealing LR: {BASE_LR} -> {MIN_LR}")
print(f"  - Early stopping: {EARLY_STOPPING_PATIENCE} epochs on {EARLY_STOPPING_METRIC}")
print(f"  - NEW metrics: precision, F1, air_acc")
print()

# Per-epoch history. Keys prefixed 'train_'/'val_' are filled from the metric
# dicts returned by train_epoch/validate (missing metrics are recorded as 0).
history = {
    # Core metrics
    'train_loss': [], 'train_ce_loss': [], 'train_vol_penalty': [],
    'train_building_acc': [], 'train_building_recall': [], 'train_building_precision': [], 'train_building_f1': [],
    'train_terrain_acc': [], 'train_struct_recall': [], 'train_building_false_air': [],
    'val_loss': [], 'val_ce_loss': [], 'val_vol_penalty': [],
    'val_building_acc': [], 'val_building_recall': [], 'val_building_precision': [], 'val_building_f1': [],
    'val_terrain_acc': [], 'val_struct_recall': [], 'val_building_false_air': [],
    # Air metrics (NEW)
    'train_air_acc': [], 'train_air_precision': [], 'train_false_block_rate': [],
    'val_air_acc': [], 'val_air_precision': [], 'val_false_block_rate': [],
    # Volume and error
    'train_vol_ratio': [], 'val_vol_ratio': [],
    'train_error_similarity': [], 'val_error_similarity': [],
    # RFSQ metrics
    'train_stage0_usage': [], 'train_stage0_perplexity': [],
    'train_stage1_usage': [], 'train_stage1_perplexity': [],
    'val_stage0_usage': [], 'val_stage0_perplexity': [],
    'val_stage1_usage': [], 'val_stage1_perplexity': [],
    'train_residual_decay': [], 'val_residual_decay': [],
    # Rare block metrics
    'train_rare_acc': [], 'val_rare_acc': [],
    'train_rare_recall': [], 'val_rare_recall': [],
    'train_rare_precision': [], 'val_rare_precision': [],
    # Training diagnostics
    'train_grad_norm': [], 'learning_rate': [],
}

best_building_f1 = 0
best_building_acc = 0
best_epoch = 0
start_time = time.time()

for epoch in range(TOTAL_EPOCHS):
    # Record the LR actually used for this epoch (before the scheduler steps).
    current_lr = optimizer.param_groups[0]['lr']
    history['learning_rate'].append(current_lr)

    train_m = train_epoch(model, train_loader, optimizer, scaler, device,
                          AIR_TOKENS_TENSOR, TERRAIN_TOKENS_TENSOR, RARE_BLOCK_TOKENS_TENSOR, STRUCTURE_WEIGHT)
    val_m = validate(model, val_loader, device,
                     AIR_TOKENS_TENSOR, TERRAIN_TOKENS_TENSOR, RARE_BLOCK_TOKENS_TENSOR, STRUCTURE_WEIGHT)

    # Cosine annealing advances once per epoch.
    scheduler.step()

    # Record all metrics
    for key in history:
        if key == 'learning_rate':
            continue
        if key.startswith('train_'):
            metric_name = key[6:]
            history[key].append(train_m.get(metric_name, 0))
        elif key.startswith('val_'):
            metric_name = key[4:]
            history[key].append(val_m.get(metric_name, 0))

    # Checkpoint selection: keep the weights with the best val building F1.
    # (Early stopping tracks its own best value independently, below.)
    if val_m['building_f1'] > best_building_f1:
        best_building_f1 = val_m['building_f1']
        best_building_acc = val_m['building_acc']
        best_epoch = epoch + 1
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v7_best.pt")

    # Print progress with NEW metrics
    print(f"Epoch {epoch+1:2d} | LR: {current_lr:.2e} | "
          f"Build: {train_m['building_acc']:.1%}/{val_m['building_acc']:.1%} | "
          f"Prec: {val_m['building_precision']:.1%} | F1: {val_m['building_f1']:.1%} | "
          f"Vol: {val_m['vol_ratio']:.2f}")

    # Early stopping monitors the metric the stopper was configured with
    # (EARLY_STOPPING_METRIC), not a hard-coded key, so changing the config
    # actually changes what is monitored.
    if early_stopper(val_m[early_stopper.metric], epoch + 1):
        print(f"\nEarly stopping triggered! No improvement for {EARLY_STOPPING_PATIENCE} epochs.")
        print(f"Best {EARLY_STOPPING_METRIC}: {early_stopper.best_value:.1%} at epoch {early_stopper.best_epoch}")
        break

train_time = time.time() - start_time
print(f"\nTraining complete in {train_time/60:.1f} minutes")
print(f"Best val building F1: {best_building_f1:.1%} at epoch {best_epoch}")
print(f"Best val building accuracy: {best_building_acc:.1%}")

## Cell 15: Plot Training Curves

In [None]:
actual_epochs = len(history['train_loss'])
epochs = range(1, actual_epochs + 1)

fig, axes = plt.subplots(5, 4, figsize=(20, 20))


def _plot_train_val(ax, metric, title, title_kw=None, labels=('Train', 'Val'), hline=None):
    """Plot history['train_<metric>'] (solid blue) and history['val_<metric>'] (dashed red).

    An optional horizontal reference line (``hline``: keyword args for
    ``ax.axhline``) is drawn after the curves, then the title (styled with
    ``title_kw``), the legend and a light grid.
    """
    train_label, val_label = labels
    ax.plot(epochs, history[f'train_{metric}'], 'b-', label=train_label)
    ax.plot(epochs, history[f'val_{metric}'], 'r--', label=val_label)
    if hline is not None:
        ax.axhline(**hline)
    ax.set_title(title, **(title_kw or {}))
    ax.legend()
    ax.grid(True, alpha=0.3)


_BOLD = {'fontweight': 'bold'}
_BOLD_GREEN = {'fontweight': 'bold', 'color': 'green'}

# Train-vs-val panels, listed in drawing order as ((row, col), panel spec).
train_val_panels = [
    # Row 1: core building metrics
    ((0, 0), dict(metric='building_acc', title='Building Accuracy (KEY)', title_kw=_BOLD)),
    ((0, 1), dict(metric='building_precision', title='Building PRECISION (NEW)', title_kw=_BOLD_GREEN)),
    ((0, 2), dict(metric='building_f1', title='Building F1 (NEW)', title_kw=_BOLD_GREEN)),
    ((0, 3), dict(metric='building_recall', title='Building Recall')),
    # Row 2: air and volume metrics
    ((1, 0), dict(metric='air_acc', title='Air Accuracy (NEW)', title_kw=_BOLD_GREEN)),
    ((1, 1), dict(metric='false_block_rate', title='False Block Rate (NEW)',
                  title_kw={'fontweight': 'bold', 'color': 'orange'})),
    ((1, 2), dict(metric='vol_ratio', title='Volume Ratio (TARGET: ~1.0)', title_kw=_BOLD,
                  hline=dict(y=1.0, color='g', linestyle='--', alpha=0.7, label='Target 1.0'))),
    ((1, 3), dict(metric='vol_penalty', title='Volume Penalty Loss (NEW)')),
    # Row 3: RFSQ metrics and terrain
    ((2, 0), dict(metric='stage0_perplexity', title='Stage 0 Perplexity', labels=('Train S0', 'Val S0'))),
    ((2, 1), dict(metric='stage1_perplexity', title='Stage 1 Perplexity', labels=('Train S1', 'Val S1'))),
    ((2, 2), dict(metric='residual_decay', title='Residual Decay',
                  hline=dict(y=0.5, color='orange', linestyle='--', alpha=0.5, label='<50% target'))),
    ((2, 3), dict(metric='terrain_acc', title='Terrain Accuracy')),
    # Row 4: rare blocks and total loss
    ((3, 0), dict(metric='rare_acc', title='Rare Block Accuracy')),
    ((3, 1), dict(metric='rare_recall', title='Rare Block Recall')),
    ((3, 2), dict(metric='loss', title='Total Loss')),
]
for (row, col), panel in train_val_panels:
    _plot_train_val(axes[row, col], **panel)

# Learning-rate schedule (single curve, no legend)
ax = axes[3, 3]
ax.plot(epochs, history['learning_rate'], 'g-')
ax.set_title('Learning Rate (Cosine Annealing)')
ax.grid(True, alpha=0.3)

# Row 5: error similarity, comparison with v6-freq, final metrics, text summary
ax = axes[4, 0]
ax.plot(epochs, history['val_error_similarity'], 'b-')
ax.axhline(y=0.5, color='orange', linestyle='--', alpha=0.5)
ax.set_title('Error Similarity (Val)')
ax.grid(True, alpha=0.3)

# Grouped bars: hard-coded v6-freq reference values next to this run's final epoch
ax = axes[4, 1]
v6_freq = {'Build\nAcc': 0.492, 'Build\nPrec': 0.58, 'Build\nF1': 0.73, 'Vol\nRatio': 1.68}
v7_final = {
    'Build\nAcc': history['val_building_acc'][-1],
    'Build\nPrec': history['val_building_precision'][-1],
    'Build\nF1': history['val_building_f1'][-1],
    'Vol\nRatio': history['val_vol_ratio'][-1],
}
x = np.arange(len(v6_freq))
width = 0.35
ax.bar(x - width/2, list(v6_freq.values()), width, label='v6-freq', color='gray')
ax.bar(x + width/2, list(v7_final.values()), width, label='v7', color='green')
ax.set_xticks(x)
ax.set_xticklabels(v6_freq.keys())
ax.set_title('v6-freq vs v7')
ax.legend()
ax.grid(True, alpha=0.3)

# Final-epoch validation metrics with value labels above each bar
ax = axes[4, 2]
final = {
    'Build\nAcc': history['val_building_acc'][-1],
    'Build\nPrec': history['val_building_precision'][-1],
    'Build\nF1': history['val_building_f1'][-1],
    'Vol\nRatio': history['val_vol_ratio'][-1],
    'Air\nAcc': history['val_air_acc'][-1],
}
colors = ['green', 'blue', 'purple', 'orange', 'cyan']
bars = ax.bar(final.keys(), final.values(), color=colors)
ax.set_title('Final Val Metrics')
ax.set_ylim(0, max(1.5, max(final.values()) * 1.1))
for bar, val in zip(bars, final.values()):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
            f'{val:.2f}', ha='center', fontsize=9)

# Text-only summary panel
ax = axes[4, 3]
ax.axis('off')
summary = f"""VQ-VAE v7 Results
──────────────────────────────
Best Building F1: {best_building_f1:.1%} (epoch {best_epoch})
Best Building Acc: {best_building_acc:.1%}

Final Metrics:
  Building Accuracy: {history['val_building_acc'][-1]:.1%}
  Building Precision: {history['val_building_precision'][-1]:.1%}
  Building F1: {history['val_building_f1'][-1]:.1%}
  Volume Ratio: {history['val_vol_ratio'][-1]:.2f} (target ~1.0)
  Air Accuracy: {history['val_air_acc'][-1]:.1%}
  False Block Rate: {history['val_false_block_rate'][-1]:.1%}

Changes from v6-freq:
  - Dual U-Net skip connections
  - Volume penalty: {VOLUME_PENALTY_WEIGHT}
  - Freq cap: {FREQUENCY_WEIGHT_CAP}x
  - Cosine annealing LR

Training Time: {train_time/60:.1f} min"""
ax.text(0.05, 0.95, summary, transform=ax.transAxes, fontsize=10,
        verticalalignment='top', fontfamily='monospace')

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/vqvae_v7_training.png", dpi=150, bbox_inches='tight')
plt.show()

# Text table comparing this run's final-epoch validation metrics with the
# hard-coded v6-freq reference numbers ('~' marks estimates, 'N/A' = not tracked).
print(f"{'Metric':<25} {'v6-freq':<15} {'v7':<15} {'Change':<15}")
print("-" * 70)

_new_marker = f"{'(NEW)':<15}"
comparison_rows = [
    ('Building Accuracy', '49.2%', f"{history['val_building_acc'][-1]:.1%}",
     f"{(history['val_building_acc'][-1] - 0.492) * 100:+.1f}%"),
    ('Building Precision', '~58%', f"{history['val_building_precision'][-1]:.1%}", _new_marker),
    ('Building F1', '~73%', f"{history['val_building_f1'][-1]:.1%}", _new_marker),
    ('Volume Ratio', '1.68', f"{history['val_vol_ratio'][-1]:.2f}",
     f"{history['val_vol_ratio'][-1] - 1.68:+.2f}"),
    ('Air Accuracy', 'N/A', f"{history['val_air_acc'][-1]:.1%}", _new_marker),
    ('False Block Rate', 'N/A', f"{history['val_false_block_rate'][-1]:.1%}", _new_marker),
]
for row_label, v6_text, v7_text, change_text in comparison_rows:
    print(f"{row_label:<25} {v6_text:<15} {v7_text:<15} {change_text}")

## Cell 16: Save Results

In [None]:
# Summary of config + final/best metrics + full per-epoch history, as JSON.
results = {
    'config': {
        'version': 'v7',
        'changes_from_v6_freq': [
            'Dual U-Net skip connections (16x16x16 + 32x32x32)',
            f'Volume ratio penalty (weight={VOLUME_PENALTY_WEIGHT})',
            f'Reduced frequency cap ({FREQUENCY_WEIGHT_CAP}x from 10x)',
            'Cosine annealing LR schedule',
            f'Early stopping (patience={EARLY_STOPPING_PATIENCE} on {EARLY_STOPPING_METRIC})',
            'New metrics: precision, F1, air_acc, false_block_rate',
        ],
        'hidden_dims': HIDDEN_DIMS,
        'rfsq_levels_per_stage': RFSQ_LEVELS_PER_STAGE,
        'num_stages': NUM_STAGES,
        'frequency_weight_cap': FREQUENCY_WEIGHT_CAP,
        'volume_penalty_weight': VOLUME_PENALTY_WEIGHT,
        'total_epochs': TOTAL_EPOCHS,
        'actual_epochs': actual_epochs,
        'batch_size': BATCH_SIZE,
        'base_lr': BASE_LR,
        'min_lr': MIN_LR,
        'early_stopping_patience': EARLY_STOPPING_PATIENCE,
        'early_stopping_metric': EARLY_STOPPING_METRIC,
        'seed': SEED,
    },
    'results': {
        'best_building_f1': float(best_building_f1),
        'best_building_acc': float(best_building_acc),
        'best_epoch': best_epoch,
        'final_building_acc': float(history['val_building_acc'][-1]),
        'final_building_precision': float(history['val_building_precision'][-1]),
        'final_building_f1': float(history['val_building_f1'][-1]),
        'final_building_recall': float(history['val_building_recall'][-1]),
        'final_vol_ratio': float(history['val_vol_ratio'][-1]),
        'final_air_acc': float(history['val_air_acc'][-1]),
        'final_air_precision': float(history['val_air_precision'][-1]),
        'final_false_block_rate': float(history['val_false_block_rate'][-1]),
        'final_rare_acc': float(history['val_rare_acc'][-1]),
        'final_rare_recall': float(history['val_rare_recall'][-1]),
        'final_stage0_perplexity': float(history['val_stage0_perplexity'][-1]) if history['val_stage0_perplexity'] else 0,
        'final_stage1_perplexity': float(history['val_stage1_perplexity'][-1]) if history['val_stage1_perplexity'] else 0,
        'training_time_min': float(train_time / 60),
    },
    'history': {k: [float(x) for x in v] for k, v in history.items()},
}

with open(f"{OUTPUT_DIR}/vqvae_v7_results.json", 'w') as f:
    json.dump(results, f, indent=2)

# The in-memory model holds the FINAL epoch's weights, which differ from the
# best ones whenever training ran past best_epoch. The "best" checkpoint must
# therefore use the weights the training loop saved at best_epoch. best_epoch > 0
# means that file was (re)written during this run, so it cannot be a stale
# file from an earlier run.
best_weights_path = f"{OUTPUT_DIR}/vqvae_v7_best.pt"
if best_epoch > 0:
    best_state_dict = torch.load(best_weights_path, map_location='cpu', weights_only=True)
else:
    print("WARNING: no best weights were saved this run; best checkpoint will hold final weights")
    best_state_dict = model.state_dict()

checkpoint = {
    'model_state_dict': best_state_dict,
    'config': {
        'version': 'v7',
        'vocab_size': VOCAB_SIZE,
        'emb_dim': EMBEDDING_DIM,
        'hidden_dims': HIDDEN_DIMS,
        'rfsq_levels': RFSQ_LEVELS_PER_STAGE,
        'num_stages': NUM_STAGES,
        'frequency_weight_cap': FREQUENCY_WEIGHT_CAP,
        'volume_penalty_weight': VOLUME_PENALTY_WEIGHT,
    },
    'air_tokens': AIR_TOKENS_LIST,
    'terrain_tokens': sorted(TERRAIN_TOKENS),
    'rare_tokens': sorted(RARE_BLOCK_TOKENS),
    'best_building_f1': float(best_building_f1),
    'best_building_acc': float(best_building_acc),
    'best_epoch': best_epoch,
}

torch.save(checkpoint, f"{OUTPUT_DIR}/vqvae_v7_best_checkpoint.pt")
# Final-epoch weights are kept separately.
torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v7_final.pt")

print("\nResults saved:")
print(f"  - {OUTPUT_DIR}/vqvae_v7_results.json")
print(f"  - {OUTPUT_DIR}/vqvae_v7_best_checkpoint.pt")
print(f"  - {OUTPUT_DIR}/vqvae_v7_final.pt")
print(f"  - {OUTPUT_DIR}/vqvae_v7_training.png")

print("\n" + "="*70)
print("FINAL RESULTS - VQ-VAE v7 (U-Net + Volume Penalty)")
print("="*70)
print(f"Best building F1: {best_building_f1:.1%} at epoch {best_epoch}")
print(f"Best building accuracy: {best_building_acc:.1%}")
print(f"\nKEY IMPROVEMENTS:")
print(f"  Volume ratio: {history['val_vol_ratio'][-1]:.2f} (v6-freq was 1.68)")
print(f"  Building precision: {history['val_building_precision'][-1]:.1%}")
print(f"  False block rate: {history['val_false_block_rate'][-1]:.1%}")
print(f"\nTraining time: {train_time/60:.1f} minutes")