# VQ-VAE v8-B Training - Self-Contained Version

**✅ This notebook is completely self-contained** - all model code is included directly!

No external file imports needed. Just upload your data to Google Drive and run!

## Changes from v6-freq

| Change | v6-freq | v8-B |
|--------|---------|------|
| Latent resolution | 8×8×8 (512 positions) | **16×16×16 (4,096 positions)** |
| Compression ratio | 64:1 | **8:1** |
| Downsampling stages | 2 (32→16→8) | **1 (32→16)** |
| ResBlocks at latent | 2 | **6 (more capacity)** |
| Volume penalty | No | **Yes (targets the 1.68x over-prediction)** |
| Perceptual loss | No | **Yes (latent spatial smoothness)** |

## Goals

| Metric | v6-freq | v8-B Target |
|--------|---------|-------------|
| Building Accuracy | 49.2% | **60-65%** |
| Volume Ratio | 1.68x | **1.1-1.2x** |

## 1. Setup - Mount Google Drive

In [None]:
# Mount Google Drive when running in Colab. Outside Colab the import fails,
# the mount is skipped, and the paths below must already exist locally.
import os
from pathlib import Path

try:
    from google.colab import drive
except ImportError:  # not running inside Google Colab
    drive = None
    print("google.colab not available - skipping Drive mount")
else:
    drive.mount('/content/drive')

# Set paths - CHANGE THIS to match your Google Drive structure
DRIVE_BASE = '/content/drive/MyDrive/minecraft_ai'
# Derived from DRIVE_BASE so changing the base path moves the outputs too.
OUTPUT_DIR = f'{DRIVE_BASE}/vqvae_v8b'

os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Output directory: {OUTPUT_DIR}")
print(f"Data directory: {DRIVE_BASE}")

## 2. Imports

In [None]:
import json
import random
import time
from pathlib import Path
from typing import Dict, List, Tuple, Any, Set, Optional
from collections import Counter

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

# Select the compute device: CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

if device == "cuda":
    # Report the GPU model and its total memory (bytes converted to GB).
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 3. Model Code - FSQ and RFSQ Quantization

All model code is included inline - no external imports!

In [None]:
# ============================================================================
# FSQ (Finite Scalar Quantization)
# ============================================================================

class FSQ(nn.Module):
    """Finite Scalar Quantization.

    Each of the ``len(levels)`` channels in the last axis is squashed with
    ``tanh`` and rounded onto ``levels[i]`` evenly spaced values in [-1, 1].
    The flat code index is the mixed-radix number formed by the per-channel
    level indices, with the first channel most significant.

    Odd level counts use the integer grid {-(L-1)/2, ..., (L-1)/2}. Even level
    counts use the half-integer grid {-(L-1)/2, ..., -1/2, 1/2, ..., (L-1)/2},
    so that exactly ``L`` values (and ``L`` distinct indices) are produced.

    Args:
        levels: Number of quantization levels per channel.
        eps: Stored for API compatibility; not used by the quantizer.
    """

    def __init__(self, levels: List[int], eps: float = 1e-3):
        super().__init__()
        self.levels = levels
        self.dim = len(levels)
        self.eps = eps
        self.codebook_size = int(np.prod(levels))

        self.register_buffer('_levels', torch.tensor(levels, dtype=torch.float32))

        # Mixed-radix place values; the last channel varies fastest.
        basis = []
        acc = 1
        for L in reversed(levels):
            basis.append(acc)
            acc *= L
        self.register_buffer('_basis', torch.tensor(list(reversed(basis)), dtype=torch.long))

        half_levels = [(L - 1) / 2 for L in levels]
        self.register_buffer('_half_levels', torch.tensor(half_levels, dtype=torch.float32))

    @property
    def num_codes(self) -> int:
        return self.codebook_size

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Quantize ``z`` (channels in the last axis, size ``len(levels)``).

        Returns:
            ``(z_q, indices)``. ``z_q`` has the shape of ``z`` and passes
            gradients straight through to ``tanh(z)``. ``indices`` has shape
            ``z.shape[:-1]`` and dtype long.
        """
        z_bounded = torch.tanh(z)
        z_q = self._quantize(z_bounded)
        z_q = z_bounded + (z_q - z_bounded).detach()  # Straight-through
        indices = self._to_indices(z_q)
        return z_q, indices

    def _quantize(self, z: torch.Tensor) -> torch.Tensor:
        """Round each channel of ``z`` (values in [-1, 1]) to its level grid."""
        z_q_list = []
        for i, num_levels in enumerate(self.levels):
            half_L = self._half_levels[i]
            z_i = z[..., i] * half_L
            if int(num_levels) % 2 == 0:
                # Even L: grid points sit on half-integers, so shift by 1/2
                # before rounding and shift back afterwards.
                z_i = torch.round(z_i - 0.5) + 0.5
            else:
                z_i = torch.round(z_i)
            z_i = torch.clamp(z_i, -half_L, half_L)
            z_q_list.append(z_i / half_L)
        return torch.stack(z_q_list, dim=-1)

    def _to_indices(self, z_q: torch.Tensor) -> torch.Tensor:
        """Convert quantized values in [-1, 1] to flat codebook indices."""
        indices = torch.zeros(z_q.shape[:-1], dtype=torch.long, device=z_q.device)
        for i, num_levels in enumerate(self.levels):
            half_L = self._half_levels[i]
            # Map the grid value back to an integer level id 0 .. L-1.
            level_idx = ((z_q[..., i] * half_L) + half_L).round().long()
            level_idx = torch.clamp(level_idx, 0, int(num_levels) - 1)
            indices = indices + level_idx * self._basis[i]
        return indices

    def get_codebook_usage(self, indices: torch.Tensor) -> Tuple[float, float]:
        """Return ``(usage, perplexity)`` of the codes in ``indices``.

        ``usage`` is the fraction of codebook entries hit at least once.
        ``perplexity`` is ``exp`` of the entropy of the empirical code
        distribution.
        """
        flat_indices = indices.flatten()
        counts = torch.bincount(flat_indices, minlength=self.codebook_size).float()
        usage = (counts > 0).float().mean().item()
        probs = counts / counts.sum()
        probs = probs[probs > 0]
        entropy = -(probs * torch.log(probs)).sum()
        perplexity = torch.exp(entropy).item()
        return usage, perplexity


# ============================================================================
# RFSQ (Residual FSQ with LayerNorm)
# ============================================================================

class InvertibleLayerNorm(nn.Module):
    """Per-sample, per-channel normalization that can be undone exactly.

    ``forward`` normalizes a 5-D tensor over its three spatial axes, then
    applies a learned per-channel affine. It caches the mean/std it used so
    ``inverse`` can map normalized values back to the original scale. The
    layout is inferred from the last dimension: when it equals
    ``num_features`` the tensor is treated as channels-last
    ``[B, X, Y, Z, C]``, otherwise as channels-first ``[B, C, X, Y, Z]``.
    """

    def __init__(self, num_features: int, eps: float = 1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        # Statistics from the latest forward(); not saved in the state_dict.
        self.register_buffer('stored_mean', None, persistent=False)
        self.register_buffer('stored_std', None, persistent=False)

    def _layout(self, x: torch.Tensor):
        """Return (spatial reduction axes, broadcastable weight, broadcastable bias)."""
        if x.shape[-1] == self.num_features:
            return (1, 2, 3), self.weight, self.bias
        channel_first_shape = (1, -1, 1, 1, 1)
        return (
            (2, 3, 4),
            self.weight.view(channel_first_shape),
            self.bias.view(channel_first_shape),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        axes, scale, shift = self._layout(x)
        self.stored_mean = x.mean(dim=axes, keepdim=True)
        self.stored_std = x.std(dim=axes, keepdim=True) + self.eps
        return (x - self.stored_mean) / self.stored_std * scale + shift

    def inverse(self, x_norm: torch.Tensor) -> torch.Tensor:
        """Undo the most recent ``forward`` using its cached statistics.

        Raises:
            RuntimeError: if ``forward`` has not been called yet.
        """
        if self.stored_mean is None or self.stored_std is None:
            raise RuntimeError("Must call forward() before inverse()")
        _, scale, shift = self._layout(x_norm)
        return (x_norm - shift) / scale * self.stored_std + self.stored_mean


class RFSQStage(nn.Module):
    """One residual-quantization step: normalize, FSQ-quantize, de-normalize.

    The residual is normalized with an ``InvertibleLayerNorm``, quantized by
    ``FSQ`` in the normalized space, and mapped back with the same
    normalization statistics. The stage returns the quantized contribution,
    the residual still left to explain, and the FSQ code indices.
    """

    def __init__(self, levels: List[int]):
        super().__init__()
        self.levels = levels
        self.fsq = FSQ(levels)
        self.layernorm = InvertibleLayerNorm(len(levels))

    @property
    def codebook_size(self) -> int:
        return self.fsq.codebook_size

    def forward(self, residual: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        quantized_norm, codes = self.fsq(self.layernorm(residual))
        quantized = self.layernorm.inverse(quantized_norm)
        return quantized, residual - quantized, codes


class RFSQ(nn.Module):
    """Robust Residual FSQ with multiple stages.

    Each stage quantizes the residual left by the previous stages, and the
    stage outputs are summed. Code usage is tracked per stage as running
    histograms of length ``codes_per_stage`` until ``reset_usage`` is called.
    Memory therefore stays constant no matter how many batches are seen. An
    earlier design stored every batch's index tensors for the whole epoch.
    """

    def __init__(self, levels_per_stage: List[int], num_stages: int = 2):
        super().__init__()
        self.levels_per_stage = levels_per_stage
        self.num_stages = num_stages
        self.dim = len(levels_per_stage)

        self.stages = nn.ModuleList([
            RFSQStage(levels_per_stage) for _ in range(num_stages)
        ])

        codes_per_stage = int(np.prod(levels_per_stage))
        self.codebook_size = codes_per_stage ** num_stages
        self.codes_per_stage = codes_per_stage
        # One code histogram per stage; empty until the first forward() after
        # construction or reset_usage().
        self._usage_counts: List[torch.Tensor] = []

    @property
    def num_codes(self) -> int:
        return self.codebook_size

    def reset_usage(self):
        """Discard the accumulated code-usage histograms."""
        self._usage_counts = []

    def _record_usage(self, all_indices: List[torch.Tensor]) -> None:
        """Add this batch's per-stage code counts to the running histograms."""
        with torch.no_grad():
            batch_counts = [
                torch.bincount(idx.flatten(), minlength=self.codes_per_stage)
                for idx in all_indices
            ]
        if not self._usage_counts:
            self._usage_counts = batch_counts
        else:
            self._usage_counts = [
                running + new for running, new in zip(self._usage_counts, batch_counts)
            ]

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Quantize ``z`` through all stages.

        Returns:
            ``(z_q_sum, all_indices)``: the sum of the stage outputs, and one
            index tensor per stage.
        """
        residual = z
        z_q_sum = torch.zeros_like(z)
        all_indices = []

        for stage in self.stages:
            z_q, residual, indices = stage(residual)
            z_q_sum = z_q_sum + z_q
            all_indices.append(indices)

        self._record_usage(all_indices)
        return z_q_sum, all_indices

    def get_usage_stats(self) -> Dict[str, Tuple[float, float]]:
        """Return ``{'stage{i}': (usage, perplexity)}`` since the last reset.

        ``usage`` is the fraction of a stage's codes hit at least once.
        ``perplexity`` is ``exp`` of the entropy of the stage's code
        distribution; this is the same formula as ``FSQ.get_codebook_usage``.
        Returns ``{}`` if nothing has been recorded.
        """
        if not self._usage_counts:
            return {}
        stats = {}
        for stage_idx, stage_counts in enumerate(self._usage_counts):
            counts = stage_counts.float()
            usage = (counts > 0).float().mean().item()
            probs = counts / counts.sum()
            probs = probs[probs > 0]
            entropy = -(probs * torch.log(probs)).sum()
            stats[f'stage{stage_idx}'] = (usage, torch.exp(entropy).item())
        return stats

print("✓ FSQ and RFSQ modules defined")

## 4. Model Code - VQ-VAE v8-B Architecture

In [None]:
class ResidualBlock3D(nn.Module):
    """Two 3x3x3 conv + BatchNorm layers with a skip connection, then ReLU.

    The skip path is the identity when the channel counts match, and a 1x1x1
    convolution projecting ``in_channels`` to ``out_channels`` otherwise.
    Spatial size is preserved (padding 1).
    """

    def __init__(self, in_channels: int, out_channels: int = None):
        super().__init__()
        out_channels = in_channels if out_channels is None else out_channels

        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.shortcut = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv3d(in_channels, out_channels, kernel_size=1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch = F.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return F.relu(branch + self.shortcut(x))


class EncoderV8B(nn.Module):
    """Encoder: 32x32x32 -> 16x16x16 latent.

    Pipeline on channels-first 5-D tensors:
    a 3x3x3 stem lifting ``in_channels`` to ``hidden_dim``, one stride-2
    convolution that halves every spatial axis, ``num_resblocks_per_stage``
    then ``num_resblocks_latent`` residual blocks at the reduced resolution,
    and a 3x3x3 projection to ``rfsq_dim`` channels.
    """

    def __init__(
        self,
        in_channels: int = 40,
        hidden_dim: int = 192,
        rfsq_dim: int = 4,
        num_resblocks_per_stage: int = 2,
        num_resblocks_latent: int = 6,
        dropout: float = 0.1,
    ):
        super().__init__()

        # Full-resolution stem.
        self.initial = nn.Sequential(
            nn.Conv3d(in_channels, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm3d(hidden_dim),
            nn.ReLU(inplace=True),
        )
        # Kernel 4 / stride 2 / padding 1 halves X, Y and Z.
        self.downsample = nn.Sequential(
            nn.Conv3d(hidden_dim, hidden_dim, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout3d(dropout),
        )
        self.stage_blocks = self._residual_stack(hidden_dim, num_resblocks_per_stage)
        self.latent_blocks = self._residual_stack(hidden_dim, num_resblocks_latent)
        # Project down to the quantizer's channel count.
        self.latent_proj = nn.Conv3d(hidden_dim, rfsq_dim, kernel_size=3, padding=1)

    @staticmethod
    def _residual_stack(channels: int, count: int) -> nn.Sequential:
        """Return ``count`` channel-preserving residual blocks in sequence."""
        return nn.Sequential(*(ResidualBlock3D(channels) for _ in range(count)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for module in (self.initial, self.downsample, self.stage_blocks, self.latent_blocks):
            x = module(x)
        return self.latent_proj(x)


class DecoderV8B(nn.Module):
    """Decoder: 16x16x16 latent -> 32x32x32 output.

    Pipeline on channels-first 5-D tensors:
    a 3x3x3 stem lifting ``rfsq_dim`` to ``hidden_dim``, ``num_resblocks_latent``
    then ``num_resblocks_per_stage`` residual blocks at latent resolution, one
    stride-2 transposed convolution that doubles every spatial axis, and a
    3x3x3 convolution producing ``num_blocks`` logits per voxel.
    """

    def __init__(
        self,
        rfsq_dim: int = 4,
        hidden_dim: int = 192,
        num_blocks: int = 3717,
        num_resblocks_per_stage: int = 2,
        num_resblocks_latent: int = 6,
        dropout: float = 0.1,
    ):
        super().__init__()

        # Latent-resolution stem.
        self.initial = nn.Sequential(
            nn.Conv3d(rfsq_dim, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm3d(hidden_dim),
            nn.ReLU(inplace=True),
        )
        self.latent_blocks = self._residual_stack(hidden_dim, num_resblocks_latent)
        self.stage_blocks = self._residual_stack(hidden_dim, num_resblocks_per_stage)
        # Kernel 4 / stride 2 / padding 1 transposed conv doubles X, Y and Z.
        self.upsample = nn.Sequential(
            nn.ConvTranspose3d(hidden_dim, hidden_dim, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout3d(dropout),
        )
        # Per-voxel class logits.
        self.final = nn.Conv3d(hidden_dim, num_blocks, kernel_size=3, padding=1)

    @staticmethod
    def _residual_stack(channels: int, count: int) -> nn.Sequential:
        """Return ``count`` channel-preserving residual blocks in sequence."""
        return nn.Sequential(*(ResidualBlock3D(channels) for _ in range(count)))

    def forward(self, z_q: torch.Tensor) -> torch.Tensor:
        hidden = z_q
        for module in (self.initial, self.latent_blocks, self.stage_blocks, self.upsample):
            hidden = module(hidden)
        return self.final(hidden)


class VQVAEv8B(nn.Module):
    """VQ-VAE v8-B with a 16×16×16 latent grid (for 32×32×32 inputs).

    Pipeline: token ids -> ``nn.Embedding`` -> ``EncoderV8B`` -> ``RFSQ`` ->
    ``DecoderV8B`` -> per-voxel logits over ``vocab_size`` classes.

    Args:
        vocab_size: Number of block tokens (embedding rows / output classes).
        emb_dim: Embedding width fed to the encoder.
        hidden_dim: Channel width inside encoder and decoder.
        rfsq_levels: FSQ levels per channel for each RFSQ stage
            (default ``[5, 5, 5, 5]``).
        num_stages: Number of residual quantization stages.
        dropout: ``Dropout3d`` probability in the resampling layers.
        pretrained_embeddings: Optional ``[vocab_size, emb_dim]`` tensor copied
            into the embedding table, which is then frozen.
        num_resblocks_per_stage: Residual blocks adjacent to the resampling
            step in encoder and decoder (default 2, as before).
        num_resblocks_latent: Residual blocks at latent resolution in encoder
            and decoder (default 6, as before).

    Raises:
        ValueError: if ``pretrained_embeddings`` does not have shape
            ``(vocab_size, emb_dim)``. Without this check ``copy_`` would
            silently broadcast, for example, a single row.
    """

    def __init__(
        self,
        vocab_size: int = 3717,
        emb_dim: int = 40,
        hidden_dim: int = 192,
        rfsq_levels: Optional[List[int]] = None,
        num_stages: int = 2,
        dropout: float = 0.1,
        pretrained_embeddings: Optional[torch.Tensor] = None,
        num_resblocks_per_stage: int = 2,
        num_resblocks_latent: int = 6,
    ):
        super().__init__()

        if rfsq_levels is None:
            rfsq_levels = [5, 5, 5, 5]

        self.rfsq_dim = len(rfsq_levels)

        self.block_emb = nn.Embedding(vocab_size, emb_dim)
        if pretrained_embeddings is not None:
            expected_shape = (vocab_size, emb_dim)
            if tuple(pretrained_embeddings.shape) != expected_shape:
                raise ValueError(
                    f"pretrained_embeddings has shape {tuple(pretrained_embeddings.shape)}, "
                    f"expected {expected_shape} (vocab_size, emb_dim)"
                )
            self.block_emb.weight.data.copy_(pretrained_embeddings)
            self.block_emb.weight.requires_grad = False

        self.encoder = EncoderV8B(
            in_channels=emb_dim,
            hidden_dim=hidden_dim,
            rfsq_dim=self.rfsq_dim,
            num_resblocks_per_stage=num_resblocks_per_stage,
            num_resblocks_latent=num_resblocks_latent,
            dropout=dropout
        )
        self.quantizer = RFSQ(
            levels_per_stage=rfsq_levels,
            num_stages=num_stages
        )
        self.decoder = DecoderV8B(
            rfsq_dim=self.rfsq_dim,
            hidden_dim=hidden_dim,
            num_blocks=vocab_size,
            num_resblocks_per_stage=num_resblocks_per_stage,
            num_resblocks_latent=num_resblocks_latent,
            dropout=dropout
        )

    def forward(self, block_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """Return ``(logits [B, vocab, X, Y, Z], channels-last z_q, per-stage indices)``."""
        z_e = self.encode(block_ids)
        z_q, indices = self.quantize(z_e)
        logits = self.decode(z_q)
        return logits, z_q, indices

    def encode(self, block_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``[B, X, Y, Z]`` token ids and encode to a channels-last latent."""
        x = self.block_emb(block_ids)
        x = x.permute(0, 4, 1, 2, 3)  # channels-first for Conv3d
        z_e = self.encoder(x)
        z_e = z_e.permute(0, 2, 3, 4, 1)  # back to channels-last for RFSQ
        return z_e

    def quantize(self, z_e: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        z_q, indices = self.quantizer(z_e)
        return z_q, indices

    def decode(self, z_q: torch.Tensor) -> torch.Tensor:
        """Decode a channels-last latent to channels-first voxel logits."""
        z_q = z_q.permute(0, 4, 1, 2, 3)
        logits = self.decoder(z_q)
        return logits

print("✓ VQ-VAE v8-B architecture defined")

## 5. Model Code - Loss Functions

In [None]:
class FrequencyWeightedLoss(nn.Module):
    """Frequency-weighted CE loss with BALANCED volume control.

    This loss aims for:
    - volume_ratio approaching 1.0 (no over-prediction)
    - recall staying high (no structure erasure)

    The key: Volume loss and False Air loss work TOGETHER.
    - Volume loss: "don't predict too many blocks overall"
    - False air loss: "don't erase existing structure blocks"

    Both are DIFFERENTIABLE through softmax!

    Inputs to ``forward``:
        logits: ``[B, vocab_size, X, Y, Z]`` decoder logits.
        target: ``[B, X, Y, Z]`` ground-truth token ids.
        z_q: channels-last quantized latent ``[B, X', Y', Z', C]``, which is
            the layout ``VQVAEv8B.forward`` returns. The smoothness term
            differences along the spatial axes 1-3, never along channels.

    Args:
        frequency_weights: Per-class CE weights, clamped to ``frequency_cap``.
        frequency_cap: Upper bound applied to ``frequency_weights``.
        volume_penalty_weight: Weight of the squared soft-volume-ratio error.
        false_air_penalty_weight: Weight of the mean air probability at
            ground-truth structure voxels.
        perceptual_weight: Weight of the latent spatial-smoothness term.
        air_tokens: Token ids treated as empty space (default {102, 576, 3352}).
    """

    def __init__(
        self,
        frequency_weights: torch.Tensor,
        frequency_cap: float = 5.0,
        volume_penalty_weight: float = 10.0,      # Prevent over-prediction
        false_air_penalty_weight: float = 5.0,    # Protect recall
        perceptual_weight: float = 0.1,
        air_tokens: Optional[Set[int]] = None,
    ):
        super().__init__()
        clamped_weights = frequency_weights.clamp(max=frequency_cap)
        self.register_buffer('freq_weights', clamped_weights)
        self.volume_penalty_weight = volume_penalty_weight
        self.false_air_penalty_weight = false_air_penalty_weight
        self.perceptual_weight = perceptual_weight
        if air_tokens is None:
            air_tokens = {102, 576, 3352}
        self.air_tokens = list(air_tokens)

    def forward(
        self,
        logits: torch.Tensor,
        target: torch.Tensor,
        z_q: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        vocab_size = logits.shape[1]

        # 1. Frequency-weighted CE loss
        logits_flat = logits.permute(0, 2, 3, 4, 1).reshape(-1, vocab_size)
        target_flat = target.reshape(-1)
        ce_loss = F.cross_entropy(logits_flat, target_flat, weight=self.freq_weights, reduction='mean')

        # 2. DIFFERENTIABLE Volume Loss
        # Get softmax probabilities (gradients flow through!)
        probs = F.softmax(logits, dim=1)  # [B, vocab_size, X, Y, Z]

        # Sum probability of predicting air tokens (ids outside the vocab are ignored)
        air_prob = torch.zeros_like(probs[:, 0, :, :, :])  # [B, X, Y, Z]
        for air_tok in self.air_tokens:
            if air_tok < probs.shape[1]:
                air_prob = air_prob + probs[:, air_tok, :, :, :]

        # Soft volume = sum of (1 - air_probability)
        non_air_prob = 1.0 - air_prob
        pred_volume_soft = non_air_prob.sum()

        # Ground truth volume
        air_tokens_tensor = torch.tensor(self.air_tokens, device=target.device, dtype=target.dtype)
        gt_is_air = torch.isin(target, air_tokens_tensor)
        gt_volume = (~gt_is_air).float().sum()

        # Volume ratio and loss (differentiable!)
        # NOTE(review): if the batch has no structure voxels, gt_volume is 0 and
        # this ratio (and volume_loss) becomes very large.
        volume_ratio_soft = pred_volume_soft / (gt_volume + 1e-6)
        volume_loss = (volume_ratio_soft - 1.0) ** 2

        # 3. FALSE AIR PENALTY - Protects recall!
        # Penalize predicting air where ground truth has structure
        gt_is_structure = ~gt_is_air
        if gt_is_structure.any():
            # High air_prob at structure locations = bad = high penalty
            false_air_loss = air_prob[gt_is_structure].mean()
        else:
            false_air_loss = torch.tensor(0.0, device=logits.device)

        # 4. Perceptual loss (spatial smoothness of the latent).
        # z_q is channels-last [B, X, Y, Z, C], so the spatial axes are 1, 2, 3.
        # Differencing along axis 4 would compare adjacent quantizer channels
        # instead of neighbouring latent positions.
        diff_x = (z_q[:, 1:, :, :, :] - z_q[:, :-1, :, :, :]).abs().mean()
        diff_y = (z_q[:, :, 1:, :, :] - z_q[:, :, :-1, :, :]).abs().mean()
        diff_z = (z_q[:, :, :, 1:, :] - z_q[:, :, :, :-1, :]).abs().mean()
        perceptual_loss = (diff_x + diff_y + diff_z) / 3.0

        # Combined loss
        total_loss = (
            ce_loss +
            self.volume_penalty_weight * volume_loss +
            self.false_air_penalty_weight * false_air_loss +
            self.perceptual_weight * perceptual_loss
        )

        # Compute metrics for logging (non-differentiable)
        with torch.no_grad():
            pred_hard = torch.argmax(logits, dim=1)
            pred_is_air_hard = torch.isin(pred_hard, air_tokens_tensor)
            pred_volume_hard = (~pred_is_air_hard).float().sum()
            volume_ratio_hard = pred_volume_hard / (gt_volume + 1e-6)

            # Recall = structure preserved / total structure
            structure_preserved = gt_is_structure & (~pred_is_air_hard)
            recall = structure_preserved.float().sum() / (gt_is_structure.float().sum() + 1e-6)

        return {
            'loss': total_loss,
            'ce_loss': ce_loss.detach(),
            'volume_loss': volume_loss.detach(),
            'false_air_loss': false_air_loss.detach() if torch.is_tensor(false_air_loss) else false_air_loss,
            'perceptual_loss': perceptual_loss.detach(),
            'volume_ratio': volume_ratio_hard,
            'recall': recall,
        }


def compute_frequency_weights(
    block_ids: torch.Tensor,
    vocab_size: int,
    smoothing: float = 0.5,
) -> torch.Tensor:
    """Compute inverse-frequency class weights ``(total / count) ** smoothing``.

    Counts of zero are raised to 1 before normalizing, so unseen tokens get the
    largest weight instead of a division by zero. ``total`` is the sum of
    these clamped counts.

    Args:
        block_ids: Integer tensor of token ids (any shape).
        vocab_size: Number of classes; the result has exactly this length.
        smoothing: Exponent applied to the inverse frequency (0.5 = sqrt).

    Returns:
        Float tensor of shape ``[vocab_size]``.

    Raises:
        ValueError: if any id is ``>= vocab_size``. Otherwise ``bincount``
            would return a longer tensor that later mismatches the
            cross-entropy weight size.
    """
    flat_ids = block_ids.flatten()
    if flat_ids.numel() > 0:
        max_id = int(flat_ids.max())
        if max_id >= vocab_size:
            raise ValueError(
                f"block id {max_id} is out of range for vocab_size={vocab_size}"
            )
    counts = torch.bincount(flat_ids, minlength=vocab_size).clamp(min=1)
    total = counts.sum()
    weights = (total.float() / counts.float()) ** smoothing
    return weights

# Status banner confirming that the model and loss code cells have run.
print("[OK] BALANCED loss function defined!")
print("  - Volume penalty (differentiable): prevents over-prediction")
print("  - False air penalty (differentiable): protects recall")
print("="*70)
print("ALL MODEL CODE LOADED - Ready to train!")
print("="*70)

## 6. Configuration

In [None]:
# === Data Paths ===
# All paths are built from DRIVE_BASE (set in the setup cell).
DATA_DIR = f"{DRIVE_BASE}/splits/train"
VAL_DIR = f"{DRIVE_BASE}/splits/val"
VOCAB_PATH = f"{DRIVE_BASE}/vocabulary/tok2block.json"
V3_EMBEDDINGS_PATH = f"{DRIVE_BASE}/vocabulary/block_embeddings_v3.npy"

# Verify paths
# Report-only: a missing path is printed but does not stop the notebook; later
# cells fail when they try to read it.
print("Checking paths...")
for name, path in [('DATA_DIR', DATA_DIR), ('VAL_DIR', VAL_DIR), 
                    ('VOCAB_PATH', VOCAB_PATH), ('V3_EMBEDDINGS_PATH', V3_EMBEDDINGS_PATH)]:
    exists = Path(path).exists()
    print(f"  {name}: {'[OK]' if exists else '[NOT FOUND]'}")
    if not exists:
        print(f"    Path: {path}")

# === V8-B Architecture ===
HIDDEN_DIM = 192
# FSQ levels per channel for each RFSQ stage: 5^4 = 625 codes per stage.
RFSQ_LEVELS_PER_STAGE = [5, 5, 5, 5]
NUM_STAGES = 2
DROPOUT = 0.1
# NOTE(review): NUM_RESBLOCKS_LATENT is not passed when the model is built in
# section 10, so changing it here does not change the architecture.
NUM_RESBLOCKS_LATENT = 6

# === Loss Weights (BALANCED) ===
FREQUENCY_WEIGHT_CAP = 5.0
VOLUME_PENALTY_WEIGHT = 10.0        # Prevent over-prediction
FALSE_AIR_PENALTY_WEIGHT = 5.0      # Protect recall (CRITICAL!)
PERCEPTUAL_WEIGHT = 0.1

# === Training ===
TOTAL_EPOCHS = 35
BATCH_SIZE = 2
BASE_LR = 2e-4
# Mixed precision via torch.amp autocast + GradScaler.
USE_AMP = True
# The optimizer steps once every GRAD_ACCUM_STEPS batches, so the effective
# batch size is BATCH_SIZE * GRAD_ACCUM_STEPS (printed below).
GRAD_ACCUM_STEPS = 8
SEED = 42
NUM_WORKERS = 2

print(f"\nConfiguration:")
print(f"  Latent: 16x16x16 (4,096 positions)")
print(f"  Hidden dim: {HIDDEN_DIM}")
print(f"  Batch size: {BATCH_SIZE} (effective: {BATCH_SIZE * GRAD_ACCUM_STEPS})")
print(f"  Epochs: {TOTAL_EPOCHS}")
print(f"  Learning rate: {BASE_LR}")
print(f"\nLoss weights:")
print(f"  Volume penalty: {VOLUME_PENALTY_WEIGHT} (prevents over-prediction)")
print(f"  False air penalty: {FALSE_AIR_PENALTY_WEIGHT} (protects recall)")

## 7. Load Vocabulary and Embeddings

In [None]:
# Load vocabulary
# tok2block.json maps token id -> block name. JSON object keys are strings, so
# they are converted back to int here.
with open(VOCAB_PATH, 'r') as f:
    tok2block = {int(k): v for k, v in json.load(f).items()}

# NOTE(review): this assumes token ids are contiguous 0..N-1. Otherwise len()
# is smaller than max id + 1, and embedding lookups can go out of range.
VOCAB_SIZE = len(tok2block)
print(f"Vocabulary size: {VOCAB_SIZE}")

# Find air tokens
# Any block name containing "air" (case-insensitive) counts as air, except names
# containing "stair". Presumably this also catches variants such as cave_air /
# void_air; verify the printed list against the vocabulary.
AIR_TOKENS: Set[int] = set()
for tok, block in tok2block.items():
    if 'air' in block.lower() and 'stair' not in block.lower():
        AIR_TOKENS.add(tok)
        print(f"  Air token: {tok} = {block}")

# Sorted long-tensor form of AIR_TOKENS; presumably passed as `air_tokens` to the
# metric functions in the training loop (confirm).
AIR_TOKENS_TENSOR = torch.tensor(sorted(AIR_TOKENS), dtype=torch.long)

# Load embeddings
# Pretrained per-token embeddings. EMBEDDING_DIM becomes the model's emb_dim,
# and the row count must equal VOCAB_SIZE for the copy into nn.Embedding.
v3_embeddings = np.load(V3_EMBEDDINGS_PATH).astype(np.float32)
EMBEDDING_DIM = v3_embeddings.shape[1]
print(f"\nV3 embeddings: {v3_embeddings.shape} (dim={EMBEDDING_DIM})")

## 8. Compute Frequency Weights

In [None]:
print("Computing block frequencies from training data...")
print("(This may take a few minutes)")

# Gather every voxel id of the training split into one flat int64 tensor.
# NOTE(review): this holds the whole split in memory, and torch.cat makes a
# second copy. For large splits, accumulating torch.bincount per file would
# keep memory constant.
all_block_ids = []
train_files = sorted(Path(DATA_DIR).glob("*.h5"))

for h5_file in tqdm(train_files, desc="Scanning"):
    with h5py.File(h5_file, 'r') as f:
        # Uses the first dataset in each file (same convention as VQVAEDataset).
        key = list(f.keys())[0]
        structure = f[key][:].flatten()
        all_block_ids.append(torch.from_numpy(structure).long())

# NOTE(review): torch.cat raises on an empty list, i.e. if DATA_DIR has no .h5 files.
all_block_ids = torch.cat(all_block_ids)
print(f"\nTotal blocks scanned: {len(all_block_ids):,}")

# Inverse-frequency class weights (total / count) ** 0.5, capped at
# FREQUENCY_WEIGHT_CAP. FrequencyWeightedLoss applies the same cap again.
FREQUENCY_WEIGHT_TENSOR = compute_frequency_weights(
    all_block_ids,
    vocab_size=VOCAB_SIZE,
    smoothing=0.5
).clamp(max=FREQUENCY_WEIGHT_CAP)

print(f"Frequency weights computed (cap={FREQUENCY_WEIGHT_CAP}x)")

## 9. Dataset

In [None]:
class VQVAEDataset(Dataset):
    """Map-style dataset over every ``*.h5`` file in a directory.

    Files are sorted by path. Each item is the first dataset stored in the
    file, returned as a ``torch.long`` tensor.

    Raises:
        ValueError: if the directory contains no ``.h5`` files.
    """

    def __init__(self, data_dir: str):
        self.data_dir = Path(data_dir)
        self.h5_files = sorted(self.data_dir.glob("*.h5"))
        if len(self.h5_files) == 0:
            raise ValueError(f"No H5 files in {data_dir}")
        print(f"Found {len(self.h5_files)} structures in {data_dir}")

    def __len__(self):
        return len(self.h5_files)

    def __getitem__(self, idx):
        file_path = self.h5_files[idx]
        with h5py.File(file_path, 'r') as h5:
            first_key = list(h5.keys())[0]
            voxels = h5[first_key][:].astype(np.int64)
        return torch.from_numpy(voxels).long()

# Each constructor prints its file count and raises ValueError if its directory
# has no .h5 files.
train_dataset = VQVAEDataset(DATA_DIR)
val_dataset = VQVAEDataset(VAL_DIR)

## 10. Create Model

In [None]:
# Seed the torch, NumPy and Python RNGs before building the model so that
# weight initialization is reproducible.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# The V3 embeddings are copied into the model's embedding table, which is then
# frozen (requires_grad=False inside VQVAEv8B).
model = VQVAEv8B(
    vocab_size=VOCAB_SIZE,
    emb_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    rfsq_levels=RFSQ_LEVELS_PER_STAGE,
    num_stages=NUM_STAGES,
    dropout=DROPOUT,
    pretrained_embeddings=torch.from_numpy(v3_embeddings),
).to(device)

# Counts all parameters, including the frozen embedding table.
total_params = sum(p.numel() for p in model.parameters())
print(f"Model: {total_params:,} parameters")

criterion = FrequencyWeightedLoss(
    frequency_weights=FREQUENCY_WEIGHT_TENSOR,
    frequency_cap=FREQUENCY_WEIGHT_CAP,
    volume_penalty_weight=VOLUME_PENALTY_WEIGHT,
    false_air_penalty_weight=FALSE_AIR_PENALTY_WEIGHT,
    perceptual_weight=PERCEPTUAL_WEIGHT,
    air_tokens=AIR_TOKENS,
).to(device)

# The frozen embedding weight never receives a gradient, so AdamW skips it.
optimizer = optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-4)
# NOTE(review): T_max=TOTAL_EPOCHS assumes scheduler.step() is called once per
# epoch in the training loop; confirm.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=TOTAL_EPOCHS, eta_min=1e-5)
# Loss scaler for mixed precision; it does nothing when USE_AMP is False.
scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP)

print("Model, criterion, optimizer created")

## 11. Data Loaders

In [None]:
# A seeded generator makes the training shuffle order reproducible.
g = torch.Generator().manual_seed(SEED)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, 
                          num_workers=NUM_WORKERS, pin_memory=True, generator=g)
# Validation order is fixed (no shuffling).
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, 
                        num_workers=NUM_WORKERS, pin_memory=True)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

## 12. Training Functions

In [None]:
def compute_metrics(logits, targets, air_tokens):
    """Compute accuracy, recall and volume metrics for one batch.

    Args:
        logits: ``[B, C, X, Y, Z]`` class scores.
        targets: ``[B, X, Y, Z]`` ground-truth token ids. Non-contiguous
            tensors are fine because ``reshape`` is used.
        air_tokens: 1-D tensor of air token ids (moved to ``logits.device``).

    Returns:
        Dict of 0-d tensors:
        ``overall_acc``; ``building_acc`` (accuracy on non-air targets);
        ``building_recall`` / ``struct_recall`` (the same value, the fraction
        of non-air targets predicted as non-air); and ``volume_ratio``
        (predicted non-air count / true non-air count, 1.0 when there is no
        structure). With no structure both accuracy and recall are 0.
    """
    device = logits.device
    num_classes = logits.shape[1]

    preds = logits.permute(0, 2, 3, 4, 1).reshape(-1, num_classes).argmax(dim=1)
    targets_flat = targets.reshape(-1)

    air_dev = air_tokens.to(device)
    is_building = ~torch.isin(targets_flat, air_dev)
    pred_building = ~torch.isin(preds, air_dev)
    correct = (preds == targets_flat).float()

    metrics = {'overall_acc': correct.mean()}

    if is_building.any():
        metrics['building_acc'] = correct[is_building].mean()
        metrics['building_recall'] = (is_building & pred_building).sum().float() / is_building.sum()
    else:
        metrics['building_acc'] = torch.tensor(0.0, device=device)
        metrics['building_recall'] = torch.tensor(0.0, device=device)

    # "Structure" and "building" are the same mask (non-air targets); this key
    # is kept for existing logging.
    metrics['struct_recall'] = metrics['building_recall']

    orig_vol = is_building.sum().float()
    pred_vol = pred_building.sum().float()
    metrics['volume_ratio'] = pred_vol / orig_vol if orig_vol > 0 else torch.tensor(1.0, device=device)

    return metrics


def train_epoch(model, criterion, loader, optimizer, scaler, device, air_tokens):
    """Run one training epoch with gradient accumulation and AMP.

    Each batch's loss is divided by ``GRAD_ACCUM_STEPS`` and backpropagated
    through the AMP ``scaler``. Every ``GRAD_ACCUM_STEPS`` batches the
    gradients are unscaled, clipped to norm 1.0 and applied. The last batch
    of the epoch also triggers a step, so a trailing group smaller than
    ``GRAD_ACCUM_STEPS`` is applied instead of being discarded by the next
    ``zero_grad``.

    Returns:
        Dict of batch-averaged loss/metric values, plus per-stage quantizer
        ``stage{i}_usage`` and ``stage{i}_perplexity``.
    """
    model.train()
    model.quantizer.reset_usage()

    # Include ALL metrics from loss function
    metrics_sum = {
        'loss': 0.0, 'ce_loss': 0.0, 'volume_loss': 0.0,
        'false_air_loss': 0.0, 'perceptual_loss': 0.0,
        'overall_acc': 0.0, 'building_acc': 0.0, 'building_recall': 0.0,
        'struct_recall': 0.0, 'volume_ratio': 0.0, 'recall': 0.0
    }
    # 'volume_ratio' is intentionally NOT read from loss_dict. compute_metrics()
    # reports an equivalent hard volume ratio under the same key, and summing
    # both doubled the logged value.
    loss_keys = ('loss', 'ce_loss', 'volume_loss', 'false_air_loss', 'perceptual_loss', 'recall')
    num_batches = len(loader)
    n = 0

    optimizer.zero_grad()

    for batch_idx, batch in enumerate(tqdm(loader, desc="Train", leave=False)):
        batch = batch.to(device)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            logits, z_q, _ = model(batch)
            loss_dict = criterion(logits, batch, z_q)
            loss = loss_dict['loss'] / GRAD_ACCUM_STEPS

        scaler.scale(loss).backward()

        end_of_group = (batch_idx + 1) % GRAD_ACCUM_STEPS == 0
        end_of_epoch = (batch_idx + 1) == num_batches
        if end_of_group or end_of_epoch:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        with torch.no_grad():
            # Collect loss metrics
            for k in loss_keys:
                if k in loss_dict:
                    val = loss_dict[k]
                    metrics_sum[k] += val.item() if torch.is_tensor(val) else val

            batch_metrics = compute_metrics(logits, batch, air_tokens)
            for k, v in batch_metrics.items():
                metrics_sum[k] += v.item()

        n += 1

    metrics = {k: v / n for k, v in metrics_sum.items()}

    stage_stats = model.quantizer.get_usage_stats()
    for stage_name, (usage, perp) in stage_stats.items():
        metrics[f'{stage_name}_usage'] = usage
        metrics[f'{stage_name}_perplexity'] = perp

    return metrics


@torch.no_grad()
def validate(model, criterion, loader, device, air_tokens):
    """Evaluate ``model`` on ``loader`` in eval mode without gradients.

    Returns:
        Dict of batch-averaged loss/metric values, plus per-stage quantizer
        ``stage{i}_usage`` and ``stage{i}_perplexity``.
    """
    model.eval()
    model.quantizer.reset_usage()

    # Include ALL metrics from loss function
    metrics_sum = {
        'loss': 0.0, 'ce_loss': 0.0, 'volume_loss': 0.0,
        'false_air_loss': 0.0, 'perceptual_loss': 0.0,
        'overall_acc': 0.0, 'building_acc': 0.0, 'building_recall': 0.0,
        'struct_recall': 0.0, 'volume_ratio': 0.0, 'recall': 0.0
    }
    # 'volume_ratio' is intentionally NOT read from loss_dict. compute_metrics()
    # reports an equivalent hard volume ratio under the same key, and summing
    # both doubled the logged value.
    loss_keys = ('loss', 'ce_loss', 'volume_loss', 'false_air_loss', 'perceptual_loss', 'recall')
    n = 0

    for batch in tqdm(loader, desc="Val", leave=False):
        batch = batch.to(device)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            logits, z_q, _ = model(batch)
            loss_dict = criterion(logits, batch, z_q)

        # Collect loss metrics
        for k in loss_keys:
            if k in loss_dict:
                val = loss_dict[k]
                metrics_sum[k] += val.item() if torch.is_tensor(val) else val

        batch_metrics = compute_metrics(logits, batch, air_tokens)
        for k, v in batch_metrics.items():
            metrics_sum[k] += v.item()

        n += 1

    metrics = {k: v / n for k, v in metrics_sum.items()}

    stage_stats = model.quantizer.get_usage_stats()
    for stage_name, (usage, perp) in stage_stats.items():
        metrics[f'{stage_name}_usage'] = usage
        metrics[f'{stage_name}_perplexity'] = perp

    return metrics

# Confirm the cell ran and list the extra metrics the loops now track.
print(
    "Training functions defined",
    "  - Tracking: volume_ratio, recall, false_air_loss",
    sep="\n",
)

## 13. Training Loop

In [None]:
print("="*70)
print("VQ-VAE V8-B TRAINING - BALANCED LOSS")
print("="*70)
print("Target: 60-65% building accuracy")
print("Target: volume_ratio ~1.0x (not 2x!)")
print("Target: recall >90% (preserve structures)")
print()


def _recall_of(m):
    """Return the recall metric from an epoch-metrics dict.

    Prefers ``'recall'``, then falls back to ``'struct_recall'``, then to 0
    when neither key is present.
    """
    return m.get('recall', m.get('struct_recall', 0))


# Per-epoch curves, consumed later by the plotting and results cells.
history = {
    'train_loss': [], 'train_building_acc': [], 'train_vol_ratio': [], 'train_recall': [],
    'val_loss': [], 'val_building_acc': [], 'val_vol_ratio': [], 'val_recall': [],
    'train_stage0_perplexity': [], 'val_stage0_perplexity': [],
    'learning_rate': [],
}

best_building_acc = 0
best_epoch = 0
start_time = time.time()

for epoch in range(TOTAL_EPOCHS):
    train_m = train_epoch(model, criterion, train_loader, optimizer, scaler, device, AIR_TOKENS_TENSOR)
    val_m = validate(model, criterion, val_loader, device, AIR_TOKENS_TENSOR)

    # The LR schedule advances once per epoch.
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]

    # Compute recall once per split; it is both recorded and printed.
    train_recall = _recall_of(train_m)
    val_recall = _recall_of(val_m)

    # Record metrics
    history['train_loss'].append(train_m['loss'])
    history['train_building_acc'].append(train_m['building_acc'])
    history['train_vol_ratio'].append(train_m['volume_ratio'])
    history['train_recall'].append(train_recall)
    history['val_loss'].append(val_m['loss'])
    history['val_building_acc'].append(val_m['building_acc'])
    history['val_vol_ratio'].append(val_m['volume_ratio'])
    history['val_recall'].append(val_recall)
    history['train_stage0_perplexity'].append(train_m.get('stage0_perplexity', 0))
    history['val_stage0_perplexity'].append(val_m.get('stage0_perplexity', 0))
    history['learning_rate'].append(current_lr)

    # Keep the weights with the highest validation building accuracy so far.
    if val_m['building_acc'] > best_building_acc:
        best_building_acc = val_m['building_acc']
        best_epoch = epoch + 1
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v8b_best.pt")

    # Save checkpoint every 5 epochs
    if (epoch + 1) % 5 == 0:
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v8b_epoch{epoch+1}.pt")

    # Print recall alongside accuracy to monitor shape preservation.
    print(f"Epoch {epoch+1:2d} | "
          f"Build: {train_m['building_acc']:.1%}/{val_m['building_acc']:.1%} | "
          f"Vol: {val_m['volume_ratio']:.2f}x | "
          f"Recall: {val_recall:.1%} | "
          f"LR: {current_lr:.2e}")

train_time = time.time() - start_time
print(f"\nTraining complete in {train_time/60:.1f} minutes")
print(f"Best val building accuracy: {best_building_acc:.1%} at epoch {best_epoch}")
# history is empty when TOTAL_EPOCHS == 0; indexing [-1] would raise IndexError.
if history['val_vol_ratio']:
    print(f"Final volume ratio: {history['val_vol_ratio'][-1]:.2f}x")
    print(f"Final recall: {history['val_recall'][-1]:.1%}")

## 14. Plot Results

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
# Use the number of epochs actually recorded rather than TOTAL_EPOCHS, so the
# x-axis always matches the history lists (e.g. after an interrupted run).
epochs = range(1, len(history['val_loss']) + 1)

# Building accuracy
ax = axes[0, 0]
ax.plot(epochs, history['train_building_acc'], 'b-', label='Train')
ax.plot(epochs, history['val_building_acc'], 'r--', label='Val')
ax.axhline(y=0.492, color='g', linestyle=':', alpha=0.5, label='v6-freq (49.2%)')
ax.axhline(y=0.60, color='orange', linestyle='--', alpha=0.5, label='Target (60%)')
ax.set_title('Building Accuracy', fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

# Volume ratio
ax = axes[0, 1]
ax.plot(epochs, history['train_vol_ratio'], 'b-', label='Train')
ax.plot(epochs, history['val_vol_ratio'], 'r--', label='Val')
ax.axhline(y=1.0, color='g', linestyle='--', alpha=0.5, label='Target (1.0x)')
ax.axhline(y=1.68, color='orange', linestyle=':', alpha=0.5, label='v6-freq (1.68x)')
ax.set_title('Volume Ratio', fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

# Loss
ax = axes[1, 0]
ax.plot(epochs, history['train_loss'], 'b-', label='Train')
ax.plot(epochs, history['val_loss'], 'r--', label='Val')
ax.set_title('Loss')
ax.legend()
ax.grid(True, alpha=0.3)

# RFSQ perplexity
ax = axes[1, 1]
ax.plot(epochs, history['train_stage0_perplexity'], 'b-', label='Train')
ax.plot(epochs, history['val_stage0_perplexity'], 'r--', label='Val')
ax.set_title('RFSQ Stage 0 Perplexity')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/vqvae_v8b_training.png", dpi=150, bbox_inches='tight')
plt.show()

val_vol_history = history['val_vol_ratio']
final_vol_ratio = val_vol_history[-1] if val_vol_history else float('nan')
# Volume ratio of the epoch whose weights were saved as vqvae_v8b_best.pt.
# The stage gate judges accuracy and volume on that same checkpoint instead of
# mixing best-epoch accuracy with final-epoch volume.
best_vol_ratio = (val_vol_history[best_epoch - 1]
                  if 0 < best_epoch <= len(val_vol_history) else float('nan'))

print("\n" + "="*70)
print("RESULTS")
print("="*70)
print(f"Best building accuracy: {best_building_acc:.1%} (epoch {best_epoch})")
print(f"Volume ratio at best epoch: {best_vol_ratio:.2f}x")
print(f"Final volume ratio: {final_vol_ratio:.2f}x")
print(f"Training time: {train_time/60:.1f} minutes")
print()
# NaN compares False, so a run with no best checkpoint never reports success.
if best_building_acc >= 0.60 and best_vol_ratio <= 1.3:
    print("[SUCCESS] Stage 1 targets met! Ready for Stage 2 (v8-C)")
elif best_building_acc >= 0.55:
    print("[CLOSE] Consider Stage 2 with caution")
else:
    print("[BELOW TARGET] Analyze before Stage 2")

## 15. Save Results

In [None]:
# Summary values derived from history; None when no epochs were recorded.
_val_vol = history['val_vol_ratio']
_val_acc = history['val_building_acc']
final_volume_ratio = float(_val_vol[-1]) if _val_vol else None
final_building_acc = float(_val_acc[-1]) if _val_acc else None
# Volume ratio of the best-accuracy epoch (the weights in vqvae_v8b_best.pt).
best_volume_ratio = (float(_val_vol[best_epoch - 1])
                     if 0 < best_epoch <= len(_val_vol) else None)

results = {
    'config': {
        'version': 'v8-B',
        'latent_resolution': '16x16x16',
        'hidden_dim': HIDDEN_DIM,
        'total_epochs': TOTAL_EPOCHS,
        'batch_size': BATCH_SIZE,
        'base_lr': BASE_LR,
        'seed': SEED,
    },
    'results': {
        'best_building_acc': float(best_building_acc),
        'best_epoch': best_epoch,
        'final_building_acc': final_building_acc,
        'final_volume_ratio': final_volume_ratio,
        'best_volume_ratio': best_volume_ratio,
        'training_time_min': float(train_time / 60),
        'target_60pct_met': bool(best_building_acc >= 0.60),
        # Existing key keeps its final-epoch semantics for downstream readers.
        'volume_target_met': final_volume_ratio is not None and final_volume_ratio <= 1.3,
        'volume_target_met_at_best': best_volume_ratio is not None and best_volume_ratio <= 1.3,
    },
    'history': {k: [float(x) for x in v] for k, v in history.items()},
}

with open(f"{OUTPUT_DIR}/vqvae_v8b_results.json", 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2)

torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v8b_final.pt")

print("\nResults saved to:")
print(f"  - {OUTPUT_DIR}/vqvae_v8b_results.json")
# The best checkpoint is only written once some epoch beats the initial best.
if best_epoch > 0:
    print(f"  - {OUTPUT_DIR}/vqvae_v8b_best.pt")
print(f"  - {OUTPUT_DIR}/vqvae_v8b_final.pt")
print(f"  - {OUTPUT_DIR}/vqvae_v8b_training.png")
print("\nDone!")