# VQ-VAE v6-freq Training - RFSQ + Frequency-Based Block Weighting

## Changes from v6 (RFSQ-only)

| Change | v6 | v6-freq |
|--------|-----|--------|
| Block weighting | Terrain/Building/Air only | **+ Inverse frequency weights** |
| Weight cap | N/A | **10x max for rare blocks** |
| Target | General accuracy | **Rare block reconstruction** |

## Why Frequency Weighting?

v5.1 showed that rare blocks (chests, doors, fences, trapdoors, carpet) **NEVER** reconstruct correctly.
These blocks are vastly outnumbered by common blocks (planks, stone, air).

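Frequency weighting gives rare blocks up to 10x higher loss weight, where `max_count` is the count of the most common block in the training data and `block_count` is the count of the block being weighted: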
```python
weight = min(10.0, max_count / block_count)
```

## Goals

| Metric | v5.1 Result | v6-freq Target |
|--------|-------------|---------------|
| Building Accuracy | 45.6% | **>55%** |
| Rare Block Recall | ~0% | **>20%** |
| Chest Recall | 0% | **>10%** |
| Door Recall | 0% | **>10%** |
| Fence Recall | 0% | **>10%** |

## Setup - Mount Google Drive

In [None]:
import os

from google.colab import drive

# Attach Google Drive at /content/drive so the output folder below is on Drive.
drive.mount('/content/drive')

# This run gets its own folder: vqvae_v6_freq, not the plain v6 one.
OUTPUT_DIR = '/content/drive/MyDrive/minecraft_ai/vqvae_v6_freq'
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Output will be saved to: {OUTPUT_DIR}")

## Cell 1: Imports

In [None]:
import json
import random
import time
from pathlib import Path
from typing import Dict, List, Tuple, Any, Set
from collections import Counter

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

# Pick the compute device once; report GPU name and memory when CUDA is present.
if torch.cuda.is_available():
    device = "cuda"
    print(f"Using device: {device}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = "cpu"
    print(f"Using device: {device}")

## Cell 2: Configuration

In [None]:
# === Data Paths (UPDATE THESE FOR YOUR DRIVE) ===
DRIVE_BASE = '/content/drive/MyDrive/minecraft_ai'

# Train/val directories are globbed for "*.h5" files, one structure per file.
DATA_DIR = f"{DRIVE_BASE}/splits/train"
VAL_DIR = f"{DRIVE_BASE}/splits/val"
# JSON mapping of token id (string key) -> block name.
VOCAB_PATH = f"{DRIVE_BASE}/tok2block.json"
# .npy array of pretrained block embeddings; axis 1 is the embedding dimension.
V3_EMBEDDINGS_PATH = f"{DRIVE_BASE}/block_embeddings_v3.npy"

# Resolves to the same folder the setup cell creates.
OUTPUT_DIR = f"{DRIVE_BASE}/vqvae_v6_freq"  # Different from v6!

# === V6-freq RFSQ Configuration ===
# One stride-2 encoder stage per entry (the decoder receives the reversed list).
HIDDEN_DIMS = [96, 192]
# Quantization levels per latent dimension; the list length is the latent
# dimension and the product is the number of codes per stage.
RFSQ_LEVELS_PER_STAGE = [5, 5, 5, 5]
# Number of residual quantization stages.
NUM_STAGES = 2
# Dropout3d probability used in the encoder and decoder.
DROPOUT = 0.1

# === FREQUENCY WEIGHTING (NEW!) ===
# NOTE(review): in the code visible here this flag is only printed; the
# frequency weights are computed and applied regardless of its value.
USE_FREQUENCY_WEIGHTING = True
FREQUENCY_WEIGHT_CAP = 10.0  # Max weight for rare blocks

# === Structure weights ===
# NOTE(review): compute_loss accepts a structure_weight argument, but the loss
# visible in this file never uses it - confirm whether this is still needed.
STRUCTURE_WEIGHT = 50.0

# === TERRAIN SETTINGS ===
# Base per-voxel loss weights by category; multiplied by the frequency weight.
TERRAIN_WEIGHT = 0.2
BUILDING_WEIGHT = 1.0
AIR_WEIGHT = 0.1

# === Training ===
TOTAL_EPOCHS = 25
BATCH_SIZE = 4
BASE_LR = 3e-4
# Toggles torch.amp.autocast in the train/validation loops.
USE_AMP = True
# The optimizer steps once every GRAD_ACCUM_STEPS batches; each batch loss is
# divided by this value before backward().
GRAD_ACCUM_STEPS = 4

SEED = 42
NUM_WORKERS = 2

# Derived sizes: codes per stage = product of levels; the implicit codebook is
# every combination of per-stage codes.
CODES_PER_STAGE = int(np.prod(RFSQ_LEVELS_PER_STAGE))
TOTAL_IMPLICIT_CODES = CODES_PER_STAGE ** NUM_STAGES

print("VQ-VAE v6-freq (RFSQ + Frequency Weighting) Configuration:")
print(f"  RFSQ: {NUM_STAGES} stages × {CODES_PER_STAGE:,} codes")
print(f"  Total implicit codes: {TOTAL_IMPLICIT_CODES:,}")
print(f"\nFREQUENCY WEIGHTING:")
print(f"  Enabled: {USE_FREQUENCY_WEIGHTING}")
print(f"  Weight cap: {FREQUENCY_WEIGHT_CAP}x")

## Cell 3: Load Vocabulary and Embeddings

In [None]:
# The vocabulary JSON stores token ids as string keys; convert them to ints.
with open(VOCAB_PATH, 'r') as vocab_file:
    raw_vocab = json.load(vocab_file)
tok2block = {int(tok_str): name for tok_str, name in raw_vocab.items()}

# Reverse lookup: block name -> token id.
block2tok = {name: tok_id for tok_id, name in tok2block.items()}
VOCAB_SIZE = len(tok2block)
print(f"Vocabulary size: {VOCAB_SIZE}")

# Collect every token whose name contains "air", skipping names that contain
# "stair" (which would otherwise match on the substring).
AIR_TOKENS: Set[int] = set()
for tok, block in tok2block.items():
    lowered = block.lower()
    if 'stair' in lowered or 'air' not in lowered:
        continue
    AIR_TOKENS.add(tok)
    print(f"  Air token: {tok} = {block}")

AIR_TOKENS_LIST = sorted(AIR_TOKENS)
AIR_TOKENS_TENSOR = torch.tensor(AIR_TOKENS_LIST, dtype=torch.long)

# Pretrained V3 block embeddings, cast to float32; axis 1 is the embedding size.
v3_embeddings = np.load(V3_EMBEDDINGS_PATH).astype(np.float32)
EMBEDDING_DIM = v3_embeddings.shape[1]
print(f"V3 embeddings: {v3_embeddings.shape} (dim={EMBEDDING_DIM})")

## Cell 4: Compute Block Frequencies (NEW!)

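This cell scans all training data to compute block frequencies.
Rare blocks get higher weights (up to 10x) in the loss function; vocabulary tokens that never appear in the training data are treated as having a count of 1 and therefore receive the full 10x cap.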

In [None]:
print("Computing block frequencies from training data...")
print("(This may take a few minutes)")

# Per-token occurrence counts over every training structure.
block_counts = Counter()
train_files = sorted(Path(DATA_DIR).glob("*.h5"))

for h5_file in tqdm(train_files, desc="Scanning training data"):
    with h5py.File(h5_file, 'r') as f:
        # Each file is read through its first dataset key.
        key = list(f.keys())[0]
        structure = f[key][:]
    # Count inside NumPy and merge only the (few) distinct tokens per file,
    # instead of feeding every single voxel through a Python-level Counter.
    tokens, counts = np.unique(structure, return_counts=True)
    for tok, count in zip(tokens.tolist(), counts.tolist()):
        block_counts[tok] += count

# Fail with an actionable message rather than max()'s "empty sequence" error.
if not block_counts:
    raise ValueError(f"No block data found in {DATA_DIR}; cannot compute frequency weights")

total_blocks = sum(block_counts.values())
print(f"\nTotal blocks scanned: {total_blocks:,}")
print(f"Unique block types: {len(block_counts)}")

# Compute inverse frequency weights (capped at FREQUENCY_WEIGHT_CAP).
# The most common block gets weight 1.0; tokens never seen in training are
# treated as count 1, so they receive the cap.
max_count = max(block_counts.values())
frequency_weights = {
    tok: min(FREQUENCY_WEIGHT_CAP, max_count / block_counts.get(tok, 1))
    for tok in range(VOCAB_SIZE)
}

# Create weight tensor, indexed by token id
FREQUENCY_WEIGHT_TENSOR = torch.tensor(
    [frequency_weights[i] for i in range(VOCAB_SIZE)],
    dtype=torch.float32
)

# Show top 10 rarest blocks
print("\nTop 10 RAREST blocks (highest weights):")
sorted_by_weight = sorted(frequency_weights.items(), key=lambda x: x[1], reverse=True)
for tok, weight in sorted_by_weight[:10]:
    block_name = tok2block.get(tok, f"UNKNOWN_{tok}")
    count = block_counts.get(tok, 0)
    print(f"  {block_name}: weight={weight:.1f}x (count={count:,})")

# Show top 10 most common blocks
print("\nTop 10 MOST COMMON blocks (lowest weights):")
sorted_by_weight = sorted(frequency_weights.items(), key=lambda x: x[1])
for tok, weight in sorted_by_weight[:10]:
    block_name = tok2block.get(tok, f"UNKNOWN_{tok}")
    count = block_counts.get(tok, 0)
    print(f"  {block_name}: weight={weight:.2f}x (count={count:,})")

# Identify rare block categories for tracking
RARE_BLOCK_KEYWORDS = ['chest', 'door', 'fence', 'trapdoor', 'carpet', 'bed', 'button', 'lever']


def _is_rare_block(block_name, keywords):
    """Return True if any keyword appears as a whole word in the block's base name.

    The base name is the lower-cased text before any '[' state suffix and after
    the last ':' namespace separator, split on non-alphanumeric characters.
    Whole-word matching keeps 'bed' from matching 'bedrock' (a terrain block),
    which a plain substring test would wrongly count as a rare block.
    """
    base = block_name.lower().split('[')[0].split(':')[-1]
    words = set(''.join(ch if ch.isalnum() else ' ' for ch in base).split())
    return any(keyword in words for keyword in keywords)


RARE_BLOCK_TOKENS = {
    tok for tok, block in tok2block.items()
    if _is_rare_block(block, RARE_BLOCK_KEYWORDS)
}

RARE_BLOCK_TOKENS_TENSOR = torch.tensor(sorted(RARE_BLOCK_TOKENS), dtype=torch.long)
print(f"\nRare block tokens identified: {len(RARE_BLOCK_TOKENS)}")
print(f"Keywords: {RARE_BLOCK_KEYWORDS}")

## Cell 5: Terrain Detection

In [None]:
# Block ids (without any "[state]" suffix) that count as natural terrain.
TERRAIN_BLOCKS: Set[str] = {
    # soil
    'minecraft:dirt', 'minecraft:grass_block', 'minecraft:coarse_dirt',
    'minecraft:podzol', 'minecraft:mycelium', 'minecraft:rooted_dirt',
    'minecraft:dirt_path', 'minecraft:farmland', 'minecraft:mud',
    # stone
    'minecraft:stone', 'minecraft:cobblestone', 'minecraft:mossy_cobblestone',
    'minecraft:bedrock', 'minecraft:deepslate', 'minecraft:tuff',
    'minecraft:granite', 'minecraft:diorite', 'minecraft:andesite',
    # loose ground
    'minecraft:sand', 'minecraft:red_sand', 'minecraft:gravel', 'minecraft:clay',
    # fluids
    'minecraft:water', 'minecraft:lava',
    # terracotta
    'minecraft:terracotta', 'minecraft:white_terracotta', 'minecraft:orange_terracotta',
    'minecraft:brown_terracotta', 'minecraft:red_terracotta',
    # other dimensions
    'minecraft:netherrack', 'minecraft:soul_sand', 'minecraft:soul_soil',
    'minecraft:end_stone',
    # frozen
    'minecraft:snow_block', 'minecraft:ice', 'minecraft:packed_ice',
}

# A token is terrain when its name, cut at the first '[', is in the set above.
TERRAIN_TOKENS: Set[int] = {
    tok for tok, block in tok2block.items()
    if block.split('[')[0] in TERRAIN_BLOCKS
}

TERRAIN_TOKENS_TENSOR = torch.tensor(sorted(TERRAIN_TOKENS), dtype=torch.long)
print(f"Terrain tokens: {len(TERRAIN_TOKENS)}")

## Cell 6: Dataset

In [None]:
class VQVAEDataset(Dataset):
    """Map-style dataset with one item per ``*.h5`` file in a directory.

    Files are ordered by sorted path. Each item is the first dataset stored in
    the file, returned as an int64 tensor.
    """

    def __init__(self, data_dir: str):
        """Index the ``*.h5`` files in *data_dir*.

        Raises:
            ValueError: if the directory contains no ``*.h5`` files.
        """
        self.data_dir = Path(data_dir)
        self.h5_files = sorted(self.data_dir.glob("*.h5"))
        if len(self.h5_files) == 0:
            raise ValueError(f"No H5 files in {data_dir}")
        print(f"Found {len(self.h5_files)} structures in {data_dir}")

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.h5_files)

    def __getitem__(self, idx):
        """Load file *idx* and return its first dataset as an int64 tensor."""
        path = self.h5_files[idx]
        with h5py.File(path, 'r') as h5:
            first_key = list(h5.keys())[0]
            voxels = h5[first_key][:].astype(np.int64)
        return torch.from_numpy(voxels).long()

# Index both splits; each constructor raises ValueError if its directory has
# no *.h5 files.
train_dataset = VQVAEDataset(DATA_DIR)
val_dataset = VQVAEDataset(VAL_DIR)
print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

## Cell 7: FSQ Base Module

In [None]:
class FSQ(nn.Module):
    """Finite scalar quantizer with a straight-through gradient.

    Each entry of the last input axis is squashed with ``tanh``, scaled to
    ``[-(L-1)/2, (L-1)/2]``, rounded to an integer level and rescaled back to
    ``[-1, 1]``. The per-dimension levels are also combined into one flat code
    index per position, and a running count of how often each code was
    produced is kept in the ``_usage`` buffer.

    NOTE(review): this appears designed for odd level counts. With an even
    count, ``(L-1)/2`` is fractional and distinct quantized values can round to
    the same index - confirm before using even levels.

    Args:
        levels: number of quantization levels per latent dimension; its length
            is the size expected on the last axis of the input.
        eps: stored on the instance but not used by the quantizer.
    """

    def __init__(self, levels: List[int], eps: float = 1e-3):
        super().__init__()
        self.levels = levels
        self.dim = len(levels)
        self.eps = eps
        self.codebook_size = int(np.prod(levels))

        self.register_buffer('_levels', torch.tensor(levels, dtype=torch.float32))
        # Mixed-radix place values: the last dimension is the least significant.
        basis = []
        acc = 1
        for L in reversed(levels):
            basis.append(acc)
            acc *= L
        self.register_buffer('_basis', torch.tensor(list(reversed(basis)), dtype=torch.long))
        half_levels = [(L - 1) / 2 for L in levels]
        self.register_buffer('_half_levels', torch.tensor(half_levels, dtype=torch.float32))
        self.register_buffer('_usage', torch.zeros(self.codebook_size))

    def reset_usage(self):
        """Zero the per-code usage counters."""
        self._usage.zero_()

    def get_usage_stats(self) -> Tuple[float, float]:
        """Return ``(fraction of codes used, perplexity of the usage distribution)``.

        Perplexity is reported as 0.0 when no code has been counted yet.
        """
        usage = (self._usage > 0).float().mean().item()
        if self._usage.sum() == 0:
            return usage, 0.0
        probs = self._usage / self._usage.sum()
        probs = probs[probs > 0]
        entropy = -(probs * probs.log()).sum()
        perplexity = entropy.exp().item()
        return usage, perplexity

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Quantize *z* along its last axis.

        Returns:
            ``(z_q, indices)``: ``z_q`` has the shape of *z* and carries
            gradients as if quantization were the identity on ``tanh(z)``;
            ``indices`` is a long tensor of flat code ids with the last axis
            removed.
        """
        z_bounded = torch.tanh(z)

        # Round every latent dimension to its own integer grid.
        z_q_list = []
        for i in range(self.dim):
            half_L = self._half_levels[i]
            z_i = torch.round(z_bounded[..., i] * half_L)
            z_i = torch.clamp(z_i, -half_L, half_L)
            z_q_list.append(z_i / half_L)
        z_q = torch.stack(z_q_list, dim=-1)

        # Straight-through estimator: forward uses the rounded value, backward
        # sees only z_bounded.
        z_q = z_bounded + (z_q - z_bounded).detach()

        # Combine the per-dimension level indices into one flat code id.
        indices = torch.zeros(z_q.shape[:-1], dtype=torch.long, device=z_q.device)
        for i, num_levels in enumerate(self.levels):
            half_L = self._half_levels[i]
            level_idx = ((z_q[..., i] * half_L) + half_L).round().long()
            # Python-int bounds: no tensor-to-scalar conversion is needed.
            level_idx = torch.clamp(level_idx, 0, int(num_levels) - 1)
            indices = indices + level_idx * self._basis[i]

        # One histogram pass instead of a Python loop over every unique code
        # (which rescanned `indices` and synchronized per iteration).
        with torch.no_grad():
            counts = torch.bincount(indices.reshape(-1), minlength=self.codebook_size)
            self._usage += counts[:self.codebook_size].to(self._usage.dtype)
        return z_q, indices

## Cell 8: RFSQ Module

In [None]:
class InvertibleLayerNorm(nn.Module):
    """Normalization that remembers its statistics so it can be undone.

    ``forward`` standardizes the input with the mean and standard deviation
    taken over axes 1, 2 and 3 (kept as size-1 axes), then applies a learned
    per-feature scale and shift on the last axis. ``inverse`` reverses both
    steps using the statistics stored by the most recent ``forward`` call, so
    it must be called after ``forward`` on a matching input.
    """

    def __init__(self, num_features: int, eps: float = 1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        # Filled in by forward(); None until the first call.
        self.stored_mean = None
        self.stored_std = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Standardize *x*, cache the statistics, and apply scale and shift."""
        reduce_axes = (1, 2, 3)
        mean = x.mean(dim=reduce_axes, keepdim=True)
        std = x.std(dim=reduce_axes, keepdim=True) + self.eps
        self.stored_mean, self.stored_std = mean, std
        standardized = (x - mean) / std
        return standardized * self.weight + self.bias

    def inverse(self, x_norm: torch.Tensor) -> torch.Tensor:
        """Undo scale/shift, then undo standardization with the cached stats."""
        standardized = (x_norm - self.bias) / self.weight
        return standardized * self.stored_std + self.stored_mean


class RFSQStage(nn.Module):
    """One residual quantization stage: normalize, quantize, de-normalize."""

    def __init__(self, levels: List[int]):
        super().__init__()
        self.fsq = FSQ(levels)
        self.layernorm = InvertibleLayerNorm(len(levels))

    @property
    def codebook_size(self) -> int:
        """Number of codes in this stage's quantizer."""
        return self.fsq.codebook_size

    def reset_usage(self):
        """Clear the quantizer's usage counters."""
        self.fsq.reset_usage()

    def get_usage_stats(self):
        """Return the quantizer's usage statistics."""
        return self.fsq.get_usage_stats()

    def forward(self, residual):
        """Quantize *residual*.

        Returns:
            ``(z_q, new_residual, indices)`` where ``z_q`` is the quantized
            value mapped back to the input's scale, ``new_residual`` is what
            is left over (``residual - z_q``), and ``indices`` are the code ids.
        """
        normalized = self.layernorm(residual)
        quantized_norm, indices = self.fsq(normalized)
        quantized = self.layernorm.inverse(quantized_norm)
        return quantized, residual - quantized, indices


class RFSQ(nn.Module):
    """Stack of residual quantization stages.

    Every stage quantizes what the previous stages left over; the final
    quantized latent is the sum of all stage outputs.
    """

    def __init__(self, levels_per_stage: List[int], num_stages: int = 2):
        super().__init__()
        self.num_stages = num_stages
        self.dim = len(levels_per_stage)
        self.stages = nn.ModuleList([RFSQStage(levels_per_stage) for _ in range(num_stages)])
        codes_per_stage = int(np.prod(levels_per_stage))
        self.codebook_size = codes_per_stage ** num_stages
        self.codes_per_stage = codes_per_stage

    def reset_usage(self):
        """Clear usage counters in every stage."""
        for stage in self.stages:
            stage.reset_usage()

    def get_usage_stats(self):
        """Return ``{'stage<i>': (usage, perplexity)}`` for each stage."""
        return {f'stage{i}': stage.get_usage_stats() for i, stage in enumerate(self.stages)}

    def _run_stages(self, z, record_norms):
        """Run all stages; optionally log the residual norm before each stage and at the end."""
        remaining = z
        quantized_total = torch.zeros_like(z)
        stage_indices = []
        norms = []
        for stage in self.stages:
            if record_norms:
                norms.append(remaining.norm().item())
            stage_out, remaining, codes = stage(remaining)
            quantized_total = quantized_total + stage_out
            stage_indices.append(codes)
        if record_norms:
            norms.append(remaining.norm().item())
        return quantized_total, stage_indices, norms

    def forward(self, z):
        """Return ``(summed quantized latent, list of per-stage indices)``."""
        quantized_total, stage_indices, _ = self._run_stages(z, record_norms=False)
        return quantized_total, stage_indices

    def forward_with_norms(self, z):
        """Like ``forward`` but also returns the residual norms.

        The norm list holds one entry taken before each stage plus one for the
        residual left after the last stage.
        """
        return self._run_stages(z, record_norms=True)

print("RFSQ module defined")

## Cell 9: VQ-VAE v6-freq Architecture with Frequency-Weighted Loss

In [None]:
class ResidualBlock3D(nn.Module):
    """Two 3x3x3 conv + batch-norm layers with a skip connection.

    The channel count and spatial size are unchanged (padding of 1 on a
    kernel of 3), so the input can be added back before the final ReLU.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Registration order is kept as conv1, conv2, bn1, bn2.
        self.conv1 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(channels)
        self.bn2 = nn.BatchNorm3d(channels)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        out = self.conv1(x)
        out = F.relu(self.bn1(out))
        out = self.bn2(self.conv2(out))
        return F.relu(out + x)


class EncoderV6(nn.Module):
    """3D convolutional encoder producing ``rfsq_dim`` output channels.

    For every width in ``hidden_dims`` it adds a stride-2 convolution
    (kernel 4, padding 1), batch norm, ReLU, channel dropout and a residual
    block. Two more residual blocks and a 3x3x3 projection to ``rfsq_dim``
    channels follow.
    """

    def __init__(self, in_channels: int, hidden_dims: List[int], rfsq_dim: int, dropout: float = 0.1):
        super().__init__()
        stages = []
        in_ch = in_channels
        for width in hidden_dims:
            stages += [
                nn.Conv3d(in_ch, width, 4, stride=2, padding=1),
                nn.BatchNorm3d(width),
                nn.ReLU(inplace=True),
                nn.Dropout3d(dropout),
                ResidualBlock3D(width),
            ]
            in_ch = width
        stages += [
            ResidualBlock3D(in_ch),
            ResidualBlock3D(in_ch),
            nn.Conv3d(in_ch, rfsq_dim, 3, padding=1),
        ]
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Run *x* through the encoder stack."""
        return self.encoder(x)


class DecoderV6(nn.Module):
    """3D convolutional decoder from the quantized latent to block logits.

    Starts with a 3x3x3 convolution to ``hidden_dims[0]`` channels and two
    residual blocks. Each later width adds a residual block followed by a
    stride-2 transposed convolution. One further stride-2 transposed
    convolution and a 3x3x3 convolution to ``num_blocks`` channels close the
    stack.
    """

    def __init__(self, rfsq_dim: int, hidden_dims: List[int], num_blocks: int, dropout: float = 0.1):
        super().__init__()
        first_width = hidden_dims[0]
        stack = [
            nn.Conv3d(rfsq_dim, first_width, 3, padding=1),
            nn.BatchNorm3d(first_width),
            nn.ReLU(inplace=True),
            ResidualBlock3D(first_width),
            ResidualBlock3D(first_width),
        ]
        width = first_width
        for next_width in hidden_dims[1:]:
            stack += [
                ResidualBlock3D(width),
                nn.ConvTranspose3d(width, next_width, 4, stride=2, padding=1),
                nn.BatchNorm3d(next_width),
                nn.ReLU(inplace=True),
                nn.Dropout3d(dropout),
            ]
            width = next_width
        stack += [
            ResidualBlock3D(width),
            nn.ConvTranspose3d(width, width, 4, stride=2, padding=1),
            nn.BatchNorm3d(width),
            nn.ReLU(inplace=True),
            nn.Conv3d(width, num_blocks, 3, padding=1),
        ]
        self.decoder = nn.Sequential(*stack)

    def forward(self, z_q):
        """Run *z_q* through the decoder stack and return the logits."""
        return self.decoder(z_q)


class FrequencyWeightedLoss(nn.Module):
    """Weighted cross-entropy: category weight times per-block frequency weight.

    Each target position gets a base weight from its category (air, terrain,
    or building for everything else) multiplied by the frequency weight of its
    target token. The result is the weighted mean of the per-position
    cross-entropy.
    """

    def __init__(self, frequency_weights: torch.Tensor,
                 terrain_weight: float = 0.2, building_weight: float = 1.0, air_weight: float = 0.1):
        super().__init__()
        # Buffer indexed by token id; moves with the module across devices.
        self.register_buffer('frequency_weights', frequency_weights)
        self.terrain_weight = terrain_weight
        self.building_weight = building_weight
        self.air_weight = air_weight

    def forward(self, logits: torch.Tensor, targets: torch.Tensor,
                terrain_mask: torch.Tensor, air_mask: torch.Tensor) -> torch.Tensor:
        """Return the weighted-mean cross-entropy.

        Where both masks are set, the air weight wins because it is applied last.
        """
        per_voxel_ce = F.cross_entropy(logits, targets, reduction='none')

        # Start from the building weight, then overwrite terrain and air.
        category_w = torch.full_like(per_voxel_ce, self.building_weight)
        category_w = category_w.masked_fill(terrain_mask, self.terrain_weight)
        category_w = category_w.masked_fill(air_mask, self.air_weight)

        # Per-token rarity weight, looked up by the target token.
        weights = category_w * self.frequency_weights[targets]

        weighted_total = (per_voxel_ce * weights).sum()
        return weighted_total / weights.sum()


class VQVAEv6Freq(nn.Module):
    """VQ-VAE v6-freq with RFSQ + Frequency-Weighted Loss.

    Pipeline: frozen block embedding -> EncoderV6 -> RFSQ quantizer ->
    DecoderV6 -> per-voxel logits over the vocabulary. ``compute_loss`` wraps
    the forward pass with the frequency-weighted cross-entropy and a set of
    reconstruction metrics.

    Args:
        vocab_size: number of block tokens (rows of the embedding, logit channels).
        emb_dim: embedding dimension; must match ``pretrained_emb``.
        hidden_dims: encoder channel widths; the decoder gets the reversed list.
        rfsq_levels: quantization levels per latent dimension.
        num_stages: number of residual quantization stages.
        pretrained_emb: array copied into the embedding table, which is then frozen.
        frequency_weights: per-token loss weights, indexed by token id.
        terrain_weight, building_weight, air_weight: category base weights.
        dropout: dropout probability for the encoder and decoder.
    """

    def __init__(self, vocab_size: int, emb_dim: int, hidden_dims: List[int],
                 rfsq_levels: List[int], num_stages: int, pretrained_emb: np.ndarray,
                 frequency_weights: torch.Tensor,
                 terrain_weight: float = 0.2, building_weight: float = 1.0,
                 air_weight: float = 0.1, dropout: float = 0.1):
        super().__init__()

        self.vocab_size = vocab_size
        self.rfsq_dim = len(rfsq_levels)
        self.num_stages = num_stages

        # Embedding table is initialized from the pretrained array and frozen.
        self.block_emb = nn.Embedding(vocab_size, emb_dim)
        self.block_emb.weight.data.copy_(torch.from_numpy(pretrained_emb))
        self.block_emb.weight.requires_grad = False

        self.encoder = EncoderV6(emb_dim, hidden_dims, self.rfsq_dim, dropout)
        self.rfsq = RFSQ(rfsq_levels, num_stages)
        self.decoder = DecoderV6(self.rfsq_dim, list(reversed(hidden_dims)), vocab_size, dropout)

        # Frequency-weighted loss (KEY DIFFERENCE from v6)
        self.loss_fn = FrequencyWeightedLoss(frequency_weights, terrain_weight, building_weight, air_weight)

    def forward(self, block_ids, return_norms=False):
        """Encode, quantize and decode a batch of token grids.

        Args:
            block_ids: integer token ids; the permutes below imply a 4-D
                (batch, X, Y, Z) layout.
            return_norms: when True, also collect the RFSQ residual norms.

        Returns:
            dict with ``logits`` (channels-first), ``all_indices`` (one index
            tensor per RFSQ stage), ``z_e`` (channels-last encoder output),
            ``z_q`` (channels-first quantized latent) and, when requested,
            ``residual_norms``.
        """
        x = self.block_emb(block_ids)
        # Embedding puts channels last; Conv3d wants them at axis 1.
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        z_e = self.encoder(x)
        # RFSQ quantizes along the last axis, so move channels back to the end.
        z_e = z_e.permute(0, 2, 3, 4, 1).contiguous()
        if return_norms:
            z_q, all_indices, residual_norms = self.rfsq.forward_with_norms(z_e)
        else:
            z_q, all_indices = self.rfsq(z_e)
            residual_norms = None
        z_q = z_q.permute(0, 4, 1, 2, 3).contiguous()
        logits = self.decoder(z_q)
        result = {'logits': logits, 'all_indices': all_indices, 'z_e': z_e, 'z_q': z_q}
        if residual_norms:
            result['residual_norms'] = residual_norms
        return result

    def compute_loss(self, block_ids, air_tokens, terrain_tokens, rare_tokens,
                     structure_weight=50.0, return_norms=False):
        """Run the model and return the loss plus reconstruction metrics.

        Args:
            block_ids: batch of token grids, also used as the targets.
            air_tokens, terrain_tokens, rare_tokens: 1-D tensors of token ids
                defining each category; moved to the batch's device here.
            structure_weight: NOTE(review) accepted but never used in this method.
            return_norms: forwarded to ``forward``.

        Returns:
            dict of tensors: ``loss`` (differentiable) and the metrics
            ``overall_acc``, ``terrain_acc``, ``building_acc``,
            ``building_recall``, ``building_false_air``, ``struct_recall``,
            ``vol_ratio``, ``error_similarity``, ``rare_acc``, ``rare_recall``;
            plus ``residual_norms`` when they were collected.

        NOTE: a category metric is reported as 0.0 (1.0 for ``vol_ratio``) when
        the batch has no voxel of that category, so a plain per-batch average
        of these values is pulled toward that fallback.
        """
        out = self(block_ids, return_norms=return_norms)
        logits = out['logits']

        # Flatten to one row of class scores per voxel.
        B, C, X, Y, Z = logits.shape
        logits_flat = logits.permute(0, 2, 3, 4, 1).reshape(-1, C)
        targets_flat = block_ids.view(-1)

        device = targets_flat.device
        air_dev = air_tokens.to(device)
        terrain_dev = terrain_tokens.to(device)
        rare_dev = rare_tokens.to(device)

        # Categories by TARGET token: air, terrain (non-air), building (the rest).
        is_air = torch.isin(targets_flat, air_dev)
        is_terrain = torch.isin(targets_flat, terrain_dev) & ~is_air
        is_building = ~is_air & ~is_terrain
        is_rare = torch.isin(targets_flat, rare_dev)

        # Frequency-weighted loss
        loss = self.loss_fn(logits_flat, targets_flat, is_terrain, is_air)

        with torch.no_grad():
            preds = logits_flat.argmax(dim=1)
            is_air_pred = torch.isin(preds, air_dev)
            correct = (preds == targets_flat).float()

            overall_acc = correct.mean()
            terrain_acc = correct[is_terrain].mean() if is_terrain.any() else torch.tensor(0.0, device=device)

            if is_building.any():
                building_acc = correct[is_building].mean()
                # "Recall" here = building voxels predicted as any non-air block.
                building_recall = (is_building & ~is_air_pred).sum().float() / is_building.sum()
                building_false_air = (is_building & is_air_pred).sum().float() / is_building.sum()
            else:
                building_acc = building_recall = building_false_air = torch.tensor(0.0, device=device)

            # RARE BLOCK METRICS (KEY FOR v6-freq)
            if is_rare.any():
                rare_acc = correct[is_rare].mean()
                # Fraction of rare-block voxels predicted as any non-air block.
                rare_recall = (is_rare & ~is_air_pred).sum().float() / is_rare.sum()
            else:
                rare_acc = rare_recall = torch.tensor(0.0, device=device)

            # Structure = every non-air target voxel.
            is_struct = ~is_air
            struct_recall = (is_struct & ~is_air_pred).sum().float() / is_struct.sum() if is_struct.any() else torch.tensor(0.0, device=device)

            # Ratio of predicted non-air voxel count to true non-air voxel count.
            orig_vol = is_struct.sum().float()
            pred_vol = (~is_air_pred).sum().float()
            vol_ratio = pred_vol / orig_vol if orig_vol > 0 else torch.tensor(1.0, device=device)

            # For misclassified building voxels: mean cosine similarity between
            # the embeddings of the predicted and the true block.
            wrong_building = is_building & (preds != targets_flat)
            if wrong_building.any():
                error_similarity = F.cosine_similarity(
                    self.block_emb.weight[preds[wrong_building]],
                    self.block_emb.weight[targets_flat[wrong_building]], dim=-1).mean()
            else:
                error_similarity = torch.tensor(0.0, device=device)

        result = {
            'loss': loss, 'overall_acc': overall_acc, 'terrain_acc': terrain_acc,
            'building_acc': building_acc, 'building_recall': building_recall,
            'building_false_air': building_false_air, 'struct_recall': struct_recall,
            'vol_ratio': vol_ratio, 'error_similarity': error_similarity,
            'rare_acc': rare_acc, 'rare_recall': rare_recall,  # NEW!
        }
        if 'residual_norms' in out:
            result['residual_norms'] = out['residual_norms']
        return result


print("VQ-VAE v6-freq architecture defined!")
print("KEY: FrequencyWeightedLoss with per-block weights")

## Cell 10: Training Functions

In [None]:
def train_epoch(model, loader, optimizer, scaler, device,
                air_tokens, terrain_tokens, rare_tokens, structure_weight):
    """Train for one epoch with gradient accumulation and AMP loss scaling.

    Reads the module-level ``USE_AMP`` and ``GRAD_ACCUM_STEPS`` settings.
    The optimizer steps once every ``GRAD_ACCUM_STEPS`` batches; gradients
    from a trailing partial group are applied at the end of the epoch instead
    of being discarded.

    Args:
        model: exposes ``compute_loss(...)`` and an ``rfsq`` submodule with
            ``reset_usage()`` / ``get_usage_stats()``.
        loader: iterable of batches (tensors).
        optimizer, scaler: optimizer and gradient scaler used for the update.
        device: device each batch is moved to.
        air_tokens, terrain_tokens, rare_tokens, structure_weight: forwarded
            to ``model.compute_loss``.

    Returns:
        dict with the per-batch mean of every loss/accuracy metric, plus
        ``<stage>_usage`` / ``<stage>_perplexity`` for every RFSQ stage,
        ``grad_norm`` (mean pre-clip gradient norm over optimizer steps) and
        ``residual_decay`` (last / first mean residual norm, 1.0 if not sampled).
    """
    model.train()
    model.rfsq.reset_usage()

    totals = {k: 0.0 for k in [
        'loss', 'overall_acc', 'terrain_acc', 'building_acc',
        'building_recall', 'building_false_air', 'struct_recall',
        'vol_ratio', 'error_similarity', 'rare_acc', 'rare_recall',
    ]}
    grad_norms = []
    all_residual_norms = []
    n = 0
    optimizer.zero_grad()

    def _apply_accumulated_gradients():
        # Unscale first so clipping sees true gradient magnitudes.
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        grad_norms.append(grad_norm.item())
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    has_pending_grads = False
    for batch_idx, batch in enumerate(tqdm(loader, desc="Train", leave=False)):
        batch = batch.to(device)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            # Residual norms force a device sync, so sample them sparsely.
            return_norms = (batch_idx % 100 == 0)
            out = model.compute_loss(batch, air_tokens, terrain_tokens, rare_tokens,
                                    structure_weight, return_norms=return_norms)
            loss = out['loss'] / GRAD_ACCUM_STEPS

        if return_norms and 'residual_norms' in out:
            all_residual_norms.append(out['residual_norms'])

        scaler.scale(loss).backward()
        has_pending_grads = True

        if (batch_idx + 1) % GRAD_ACCUM_STEPS == 0:
            _apply_accumulated_gradients()
            has_pending_grads = False

        for k in totals:
            if k in out:
                totals[k] += out[k].item() if torch.is_tensor(out[k]) else out[k]
        n += 1

    # Flush the tail when the batch count is not a multiple of GRAD_ACCUM_STEPS;
    # otherwise those gradients would be zeroed, unused, at the next epoch start.
    if has_pending_grads:
        _apply_accumulated_gradients()

    # Only the accumulated metrics are averaged; max() guards an empty loader.
    metrics = {k: total / max(n, 1) for k, total in totals.items()}

    # Stage statistics are already epoch-level values; add them for every
    # stage the model has, not just stage0/stage1.
    for stage_name, (usage, perp) in model.rfsq.get_usage_stats().items():
        metrics[f'{stage_name}_usage'] = usage
        metrics[f'{stage_name}_perplexity'] = perp

    metrics['grad_norm'] = sum(grad_norms) / len(grad_norms) if grad_norms else 0.0

    if all_residual_norms:
        avg_norms = np.mean(all_residual_norms, axis=0)
        metrics['residual_decay'] = avg_norms[-1] / avg_norms[0] if avg_norms[0] > 0 else 1.0
    else:
        metrics['residual_decay'] = 1.0

    return metrics


@torch.no_grad()
def validate(model, loader, device, air_tokens, terrain_tokens, rare_tokens, structure_weight):
    """Evaluate ``model`` on every batch of ``loader`` and return epoch-level metrics.

    The model is switched to eval mode and its RFSQ usage counters are reset
    before the pass, so the reported usage/perplexity cover this pass only.

    Args:
        model: Model exposing ``compute_loss(...)`` and an ``rfsq`` attribute with
            ``reset_usage()`` / ``get_usage_stats()``.
        loader: Iterable of batches; each batch must support ``.to(device)``.
        device: Device the batches are moved to.
        air_tokens, terrain_tokens, rare_tokens, structure_weight: Forwarded
            unchanged to ``model.compute_loss``.

    Returns:
        dict with
          * the per-batch mean of each tracked metric (``loss``, ``*_acc``,
            ``*_recall``, ...); a metric missing from ``compute_loss``'s output
            stays at 0.0,
          * ``<stage>_usage`` / ``<stage>_perplexity`` for every stage reported by
            ``model.rfsq.get_usage_stats()`` (not averaged),
          * ``residual_decay``: last / first entry of the mean residual norms
            sampled every 50th batch (1.0 when nothing was sampled or the first
            entry is not positive).

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    model.rfsq.reset_usage()

    # Metrics that are summed per batch and averaged at the end.
    averaged_keys = (
        'loss', 'overall_acc', 'terrain_acc', 'building_acc',
        'building_recall', 'building_false_air', 'struct_recall',
        'vol_ratio', 'error_similarity', 'rare_acc', 'rare_recall',
    )
    metrics = {k: 0.0 for k in averaged_keys}
    all_residual_norms = []
    n = 0

    for batch_idx, batch in enumerate(tqdm(loader, desc="Val", leave=False)):
        batch = batch.to(device)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            # Residual norms are only requested on every 50th batch.
            return_norms = (batch_idx % 50 == 0)
            out = model.compute_loss(batch, air_tokens, terrain_tokens, rare_tokens,
                                    structure_weight, return_norms=return_norms)

        if return_norms and 'residual_norms' in out:
            all_residual_norms.append(out['residual_norms'])

        for k in metrics:
            if k in out:
                metrics[k] += out[k].item() if torch.is_tensor(out[k]) else out[k]
        n += 1

    if n == 0:
        raise ValueError("validate() received an empty loader; there are no batches to evaluate")

    # Average only the accumulated sums. Stage statistics and residual_decay are
    # already epoch-level values, so they are added afterwards and never divided
    # by n -- regardless of how many RFSQ stages the model reports.
    results = {k: total / n for k, total in metrics.items()}

    stage_stats = model.rfsq.get_usage_stats()
    for stage_name, (usage, perp) in stage_stats.items():
        results[f'{stage_name}_usage'] = usage
        results[f'{stage_name}_perplexity'] = perp

    if all_residual_norms:
        # presumably one norm per RFSQ stage, ordered first -> last -- TODO confirm
        avg_norms = np.mean(all_residual_norms, axis=0)
        results['residual_decay'] = avg_norms[-1] / avg_norms[0] if avg_norms[0] > 0 else 1.0
    else:
        results['residual_decay'] = 1.0

    return results


# Cell-level confirmation that the training/validation helpers above were (re)defined.
print("Training functions defined with RARE BLOCK metrics!")

## Cell 11: Create Model and Optimizer

In [None]:
# Seed the torch, NumPy and Python RNGs (in that order) with the shared SEED.
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.empty_cache()

# A dedicated, seeded generator drives the shuffling of the training loader.
g = torch.Generator().manual_seed(SEED)
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, generator=g, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# Build the model from the notebook-level configuration and move it to `device`.
model_kwargs = dict(
    vocab_size=VOCAB_SIZE,
    emb_dim=EMBEDDING_DIM,
    hidden_dims=HIDDEN_DIMS,
    rfsq_levels=RFSQ_LEVELS_PER_STAGE,
    num_stages=NUM_STAGES,
    pretrained_emb=v3_embeddings,
    frequency_weights=FREQUENCY_WEIGHT_TENSOR,
    terrain_weight=TERRAIN_WEIGHT,
    building_weight=BUILDING_WEIGHT,
    air_weight=AIR_WEIGHT,
    dropout=DROPOUT,
)
model = VQVAEv6Freq(**model_kwargs).to(device)

# Parameter census: (element count, requires_grad) for every parameter tensor.
param_census = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_census)
trainable_params = sum(count for count, is_trainable in param_census if is_trainable)
print(f"Total params: {total_params:,}")
print(f"Trainable params: {trainable_params:,}")
print(f"RFSQ total codes: {model.rfsq.codebook_size:,}")

# Only parameters that require gradients are handed to the optimizer.
trainable_tensors = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_tensors, lr=BASE_LR, weight_decay=1e-5)
scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP)
print(f"\nOptimizer: AdamW, LR={BASE_LR}")

## Cell 12: Training Loop

In [None]:
banner = "=" * 70
print(banner)
print("VQ-VAE V6-FREQ TRAINING - RFSQ + FREQUENCY WEIGHTING")
print(banner)
print("Key features:")
print(f"  - RFSQ: {NUM_STAGES}-stage residual quantization")
print(f"  - Frequency weighting: up to {FREQUENCY_WEIGHT_CAP}x for rare blocks")
print("  - Tracking: rare_acc, rare_recall")
print()

# One list per tracked series. The 'train_' / 'val_' prefix selects which
# epoch-metrics dict the value is read from; the remainder is the metric name.
# Insertion order is kept as-is because it is the key order of the saved JSON.
history_keys = (
    'train_loss', 'train_building_acc', 'train_building_recall',
    'train_terrain_acc', 'train_struct_recall',
    'val_loss', 'val_building_acc', 'val_building_recall',
    'val_terrain_acc', 'val_struct_recall',
    'train_stage0_usage', 'train_stage0_perplexity',
    'train_stage1_usage', 'train_stage1_perplexity',
    'val_stage0_usage', 'val_stage0_perplexity',
    'val_stage1_usage', 'val_stage1_perplexity',
    'train_building_false_air', 'val_building_false_air',
    'train_vol_ratio', 'val_vol_ratio',
    'train_error_similarity', 'val_error_similarity',
    'train_grad_norm',
    'train_residual_decay', 'val_residual_decay',
    # NEW: Rare block metrics
    'train_rare_acc', 'val_rare_acc',
    'train_rare_recall', 'val_rare_recall',
)
history = {series_name: [] for series_name in history_keys}

best_building_acc = 0
best_epoch = 0
start_time = time.time()

# Arguments shared by the train and validation passes.
loss_args = (AIR_TOKENS_TENSOR, TERRAIN_TOKENS_TENSOR, RARE_BLOCK_TOKENS_TENSOR, STRUCTURE_WEIGHT)

for epoch in range(TOTAL_EPOCHS):
    train_m = train_epoch(model, train_loader, optimizer, scaler, device, *loss_args)
    val_m = validate(model, val_loader, device, *loss_args)

    # Append this epoch's value to every series; metrics that a pass did not
    # report are recorded as 0.
    for series_name, series in history.items():
        for prefix, epoch_metrics in (('train_', train_m), ('val_', val_m)):
            if series_name.startswith(prefix):
                series.append(epoch_metrics.get(series_name[len(prefix):], 0))
                break

    # Keep the weights of the epoch with the highest validation building accuracy.
    current_building_acc = val_m['building_acc']
    if current_building_acc > best_building_acc:
        best_building_acc = current_building_acc
        best_epoch = epoch + 1
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v6_freq_best.pt")

    progress_fields = [
        f"Epoch {epoch+1:2d}",
        f"Build: {train_m['building_acc']:.1%}/{val_m['building_acc']:.1%}",
        f"Rare: {train_m['rare_acc']:.1%}/{val_m['rare_acc']:.1%}",
        f"RareRecall: {val_m['rare_recall']:.1%}",
        f"S0: {val_m.get('stage0_perplexity', 0):.0f}",
    ]
    print(" | ".join(progress_fields))

train_time = time.time() - start_time
print(f"\nTraining complete in {train_time/60:.1f} minutes")
print(f"Best val building accuracy: {best_building_acc:.1%} at epoch {best_epoch}")

## Cell 13: Plot Training Curves

In [None]:
# Take the x-axis from what was actually recorded rather than from TOTAL_EPOCHS,
# so the plots still work when training was stopped early or TOTAL_EPOCHS was
# edited after the run.
n_recorded_epochs = len(history['val_building_acc'])
if n_recorded_epochs == 0:
    raise RuntimeError("No epochs recorded in `history`; run the training loop before plotting.")

fig, axes = plt.subplots(4, 4, figsize=(20, 16))
epochs = range(1, n_recorded_epochs + 1)


def _plot_train_val(ax, metric, title, *, labels=('Train', 'Val'), hline=None, **title_kwargs):
    """Plot history['train_<metric>'] (solid blue) and history['val_<metric>'] (dashed red) on ``ax``.

    ``hline`` is an optional dict of ``ax.axhline`` keyword arguments for a
    reference line; ``title_kwargs`` are forwarded to ``ax.set_title``.
    """
    ax.plot(epochs, history[f'train_{metric}'], 'b-', label=labels[0])
    ax.plot(epochs, history[f'val_{metric}'], 'r--', label=labels[1])
    if hline is not None:
        # Drawn before legend() so a labelled reference line shows up in it.
        ax.axhline(**hline)
    ax.set_title(title, **title_kwargs)
    ax.legend()
    ax.grid(True, alpha=0.3)


# Last-epoch validation values reused by the bar charts, summary and report.
final_building_acc = history['val_building_acc'][-1]
final_building_recall = history['val_building_recall'][-1]
final_rare_acc = history['val_rare_acc'][-1]
final_rare_recall = history['val_rare_recall'][-1]

# Row 1: Core metrics
_plot_train_val(axes[0, 0], 'building_acc', 'Building Accuracy', fontweight='bold')
_plot_train_val(axes[0, 1], 'rare_acc', 'RARE Block Accuracy (KEY!)',
                fontweight='bold', color='green')
_plot_train_val(axes[0, 2], 'rare_recall', 'RARE Block Recall (KEY!)',
                hline=dict(y=0.2, color='g', linestyle='--', alpha=0.5, label='Target 20%'),
                fontweight='bold', color='green')
_plot_train_val(axes[0, 3], 'building_recall', 'Building Recall')

# Row 2: RFSQ metrics
_plot_train_val(axes[1, 0], 'stage0_perplexity', 'Stage 0 Perplexity',
                labels=('Train S0', 'Val S0'))
_plot_train_val(axes[1, 1], 'stage1_perplexity', 'Stage 1 Perplexity',
                labels=('Train S1', 'Val S1'))
_plot_train_val(axes[1, 2], 'residual_decay', 'Residual Decay',
                hline=dict(y=0.5, color='orange', linestyle='--', alpha=0.5))
_plot_train_val(axes[1, 3], 'terrain_acc', 'Terrain Accuracy')

# Row 3: Loss and diagnostics
_plot_train_val(axes[2, 0], 'loss', 'Loss')
_plot_train_val(axes[2, 1], 'building_false_air', 'Building False Air')
_plot_train_val(axes[2, 2], 'struct_recall', 'Structure Recall')

# Gradient norm is only tracked for training; the red line marks y=1.0.
ax = axes[2, 3]
ax.plot(epochs, history['train_grad_norm'], 'g-')
ax.axhline(y=1.0, color='r', linestyle='--', alpha=0.5)
ax.set_title('Gradient Norm')
ax.grid(True, alpha=0.3)

# Row 4: Comparisons and summary
ax = axes[3, 0]
ax.plot(epochs, history['val_error_similarity'], 'b-')
ax.axhline(y=0.5, color='orange', linestyle='--', alpha=0.5)
ax.set_title('Error Similarity (Val)')
ax.grid(True, alpha=0.3)

# v5.1 vs v6-freq comparison (v5.1 numbers are fixed baseline values)
ax = axes[3, 1]
v51_baseline = {'Build\nAcc': 0.456, 'Build\nRecall': 0.847, 'Rare\nAcc': 0.0, 'Rare\nRecall': 0.0}
v6_freq = {
    'Build\nAcc': final_building_acc,
    'Build\nRecall': final_building_recall,
    'Rare\nAcc': final_rare_acc,
    'Rare\nRecall': final_rare_recall,
}
x = np.arange(len(v51_baseline))
width = 0.35
ax.bar(x - width/2, list(v51_baseline.values()), width, label='v5.1', color='gray')
ax.bar(x + width/2, list(v6_freq.values()), width, label='v6-freq', color='green')
ax.set_xticks(x)
ax.set_xticklabels(v51_baseline.keys())
ax.set_title('v5.1 vs v6-freq')
ax.legend()
ax.grid(True, alpha=0.3)

# Final metrics
ax = axes[3, 2]
final = {
    'Build\nAcc': final_building_acc,
    'Rare\nAcc': final_rare_acc,
    'Rare\nRecall': final_rare_recall,
    'S0\nUsage': history['val_stage0_usage'][-1] if history['val_stage0_usage'] else 0,
}
colors = ['green', 'blue', 'purple', 'orange']
bars = ax.bar(final.keys(), final.values(), color=colors)
ax.set_title('Final Val Metrics')
ax.set_ylim(0, 1)
# Print each value just above its bar.
for bar, val in zip(bars, final.values()):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
            f'{val:.2f}', ha='center', fontsize=9)

# Summary
ax = axes[3, 3]
ax.axis('off')
summary = f"""VQ-VAE v6-freq Results
─────────────────────────
Best Building Acc: {best_building_acc:.1%} (epoch {best_epoch})

RARE BLOCK METRICS:
  Rare Accuracy: {final_rare_acc:.1%}
  Rare Recall: {final_rare_recall:.1%}
  (Target: >20%)

Frequency Weighting:
  Cap: {FREQUENCY_WEIGHT_CAP}x

Training Time: {train_time/60:.1f} min"""
ax.text(0.1, 0.9, summary, transform=ax.transAxes, fontsize=11,
        verticalalignment='top', fontfamily='monospace')

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/vqvae_v6_freq_training.png", dpi=150, bbox_inches='tight')
plt.show()

# Text verdict on the final validation rare-block recall (thresholds: 20% / 10%).
print("\n" + "="*70)
print("RARE BLOCK ANALYSIS")
print("="*70)
print(f"Rare block accuracy: {final_rare_acc:.1%}")
print(f"Rare block recall: {final_rare_recall:.1%}")
if final_rare_recall > 0.2:
    print("  SUCCESS: Rare blocks are being reconstructed!")
elif final_rare_recall > 0.1:
    print("  PARTIAL: Some improvement over v5.1's 0%")
else:
    print("  NEEDS MORE: Rare blocks still struggling")

## Cell 14: Save Results

In [None]:
# JSON-serialisable summary of the run: configuration, headline numbers and
# the full per-epoch history (values cast to plain floats for json.dump).
results = {
    'config': {
        'version': 'v6-freq',
        'changes_from_v6': ['Frequency-based block weighting', f'Weight cap: {FREQUENCY_WEIGHT_CAP}x'],
        'hidden_dims': HIDDEN_DIMS,
        'rfsq_levels_per_stage': RFSQ_LEVELS_PER_STAGE,
        'num_stages': NUM_STAGES,
        'frequency_weight_cap': FREQUENCY_WEIGHT_CAP,
        'total_epochs': TOTAL_EPOCHS,
        'batch_size': BATCH_SIZE,
        'seed': SEED,
    },
    'results': {
        'best_building_acc': float(best_building_acc),
        'best_epoch': best_epoch,
        'final_building_acc': float(history['val_building_acc'][-1]),
        'final_building_recall': float(history['val_building_recall'][-1]),
        'final_rare_acc': float(history['val_rare_acc'][-1]),
        'final_rare_recall': float(history['val_rare_recall'][-1]),
        'final_stage0_perplexity': float(history['val_stage0_perplexity'][-1]) if history['val_stage0_perplexity'] else 0,
        'final_stage1_perplexity': float(history['val_stage1_perplexity'][-1]) if history['val_stage1_perplexity'] else 0,
        'training_time_min': float(train_time / 60),
    },
    'history': {k: [float(x) for x in v] for k, v in history.items()},
}

with open(f"{OUTPUT_DIR}/vqvae_v6_freq_results.json", 'w') as f:
    json.dump(results, f, indent=2)

# The "best" checkpoint must contain the weights of the best epoch, not the
# weights the model happens to hold after the last epoch. The training loop
# writes those weights to vqvae_v6_freq_best.pt whenever validation building
# accuracy improves, so reload them from there. `best_epoch > 0` guards against
# picking up a stale file from an earlier run when this run never improved.
best_weights_path = Path(OUTPUT_DIR) / "vqvae_v6_freq_best.pt"
if best_epoch > 0 and best_weights_path.exists():
    best_state_dict = torch.load(best_weights_path, map_location=device)
else:
    print("WARNING: best-epoch weights not found; best checkpoint will hold the FINAL weights")
    best_state_dict = model.state_dict()

# Self-describing checkpoint: weights plus the configuration and token lists
# saved alongside them.
checkpoint = {
    'model_state_dict': best_state_dict,
    'config': {
        'version': 'v6-freq',
        'vocab_size': VOCAB_SIZE,
        'emb_dim': EMBEDDING_DIM,
        'hidden_dims': HIDDEN_DIMS,
        'rfsq_levels': RFSQ_LEVELS_PER_STAGE,
        'num_stages': NUM_STAGES,
        'frequency_weight_cap': FREQUENCY_WEIGHT_CAP,
    },
    'air_tokens': AIR_TOKENS_LIST,
    'terrain_tokens': sorted(TERRAIN_TOKENS),
    'rare_tokens': sorted(RARE_BLOCK_TOKENS),
    'best_building_acc': float(best_building_acc),
    'best_epoch': best_epoch,
}

torch.save(checkpoint, f"{OUTPUT_DIR}/vqvae_v6_freq_best_checkpoint.pt")
# The final-epoch weights are saved separately under their own name.
torch.save(model.state_dict(), f"{OUTPUT_DIR}/vqvae_v6_freq_final.pt")

print("\nResults saved:")
print(f"  - {OUTPUT_DIR}/vqvae_v6_freq_results.json")
print(f"  - {OUTPUT_DIR}/vqvae_v6_freq_best_checkpoint.pt")
print(f"  - {OUTPUT_DIR}/vqvae_v6_freq_final.pt")
print(f"  - {OUTPUT_DIR}/vqvae_v6_freq_training.png")

print("\n" + "="*70)
print("FINAL RESULTS - VQ-VAE v6-freq (RFSQ + Frequency Weighting)")
print("="*70)
print(f"Best building accuracy: {best_building_acc:.1%} at epoch {best_epoch}")
print(f"\nRARE BLOCK PERFORMANCE:")
print(f"  Rare accuracy: {history['val_rare_acc'][-1]:.1%}")
print(f"  Rare recall: {history['val_rare_recall'][-1]:.1%}")
print(f"  (v5.1 baseline: ~0%)")
print(f"\nTraining time: {train_time/60:.1f} minutes")