# Phase 0: VQ-VAE Embedding Validation

## Critical Question

**Do Block2Vec embeddings actually help VQ-VAE, or would random embeddings work just as well?**

Before spending more time optimizing Block2Vec, we need to answer this fundamental question.

## Experiment Design

We train three identical mini VQ-VAEs with different input embeddings:

| Condition | Embeddings | Purpose |
|-----------|------------|---------|
| **V1** | Block2Vec V1 (skip-gram) | Our baseline trained embeddings |
| **V2** | Block2Vec V2 (hybrid) | Our improved embeddings |
| **Random** | Random initialization | Null hypothesis - no learned structure |

## Success Criteria

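- **V1/V2 >> Random (>20% relative improvement in STRUCTURE accuracy)**: Block2Vec is useful
- **V1/V2 modestly better than Random (5-20% relative improvement)**: Block2Vec helps a little; weigh the optimization effort against the gain
- **V1/V2 ≈ Random (<5% relative improvement)**: Block2Vec doesn't matter for reconstruction
- **V1 > V2**: The hybrid approach was harmful, use V1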
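
The criteria above can be read as a rule on relative improvement in structure accuracy. The following is a minimal sketch, assuming the 20% threshold from this list and the 5% "modest" band used in the comparison cell later in the notebook. `relative_improvement` and `verdict` are hypothetical helper names, and the accuracies in the usage lines are placeholders, not measured results.

```python
def relative_improvement(acc: float, baseline: float) -> float:
    """Percent improvement of `acc` over `baseline`, relative to `baseline`."""
    if baseline <= 0:
        raise ValueError("baseline accuracy must be positive")
    return (acc - baseline) / baseline * 100.0


def verdict(v1: float, v2: float, rand: float) -> str:
    """Classify the outcome from the better of the two trained embeddings."""
    best = max(relative_improvement(v1, rand), relative_improvement(v2, rand))
    if best > 20:
        return "useful"
    if best > 5:
        return "modest"
    return "no significant help"


# Placeholder structure accuracies for illustration only.
print(verdict(v1=0.30, v2=0.26, rand=0.20))
print(verdict(v1=0.22, v2=0.21, rand=0.20))
```

Note that the rule is relative, not absolute: going from 20% to 30% structure accuracy counts as a 50% improvement.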

## BUG FIX (2024-12): Correct Non-Air Accuracy

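**Previous bug**: The code used `block_ids != 0` to find non-air blocks, but token 0 is `UNKNOWN_BLOCK`, not air! Every air voxel was therefore counted as structure, so the reported "non-air" accuracy was almost the same number as overall accuracy.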

**Why this matters**: ~90% of voxels are air. If we include air in accuracy:
- Predicting air everywhere ≈ 90% accuracy
- The 88% accuracy we saw tells us almost nothing about structure reconstruction
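
To see why overall accuracy is uninformative here, it helps to write it as a weighted average of air accuracy and structure accuracy. The sketch below uses the ~90% air fraction mentioned above; the per-class accuracies are hypothetical values chosen only to show how a model that is nearly useless on structure can still land near 88%.

```python
def overall_accuracy(air_fraction: float, air_acc: float, struct_acc: float) -> float:
    """Overall voxel accuracy as a mix of air and non-air accuracy."""
    return air_fraction * air_acc + (1.0 - air_fraction) * struct_acc


# Predicting air everywhere: perfect on air, zero on structure.
all_air = overall_accuracy(air_fraction=0.90, air_acc=1.0, struct_acc=0.0)

# Hypothetical weak model: good on air, only 7% of structure voxels correct.
weak_model = overall_accuracy(air_fraction=0.90, air_acc=0.97, struct_acc=0.07)

print(f"all-air baseline: {all_air:.2f}, weak model: {weak_model:.2f}")
```

Both cases score close to 0.9 overall, even though structure accuracy is 0% and 7% respectively.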

**The fix**: Minecraft has 3 types of air:
- Token 19: `minecraft:air`
- Token 164: `minecraft:cave_air`  
- Token 932: `minecraft:void_air`

Now using `torch.isin(targets, AIR_TOKENS)` to correctly identify all air blocks.
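
The difference between the old and new masks can be shown in plain Python (the notebook itself applies the same membership test to tensors with `torch.isin`). This is a sketch on a toy voxel list; token 5 standing for a solid block is a made-up assumption, while the air tokens are the three listed above.

```python
AIR_TOKENS = {19, 164, 932}  # air, cave_air, void_air


def old_structure_mask(block_ids):
    """Buggy rule: everything except token 0 (UNKNOWN_BLOCK) counts as structure."""
    return [b != 0 for b in block_ids]


def fixed_structure_mask(block_ids, air_tokens=AIR_TOKENS):
    """Fixed rule: a voxel is structure only if it is not any kind of air."""
    return [b not in air_tokens for b in block_ids]


# Toy sample: mostly air, two solid voxels (token 5 is a hypothetical solid block).
voxels = [19, 19, 19, 164, 5, 5, 19, 932]

old = old_structure_mask(voxels)
new = fixed_structure_mask(voxels)
print(f"old mask marks {sum(old)} of {len(voxels)} voxels as structure")
print(f"new mask marks {sum(new)} of {len(voxels)} voxels as structure")
```

With the old rule all eight voxels are treated as structure, so the "structure" metric was effectively overall accuracy under another name.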

## Mini VQ-VAE Config

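To make this fast (~2 hours total for all 3 conditions):
- 10 epochs (instead of 25)
- Smaller hidden dims [32, 64, 128] (instead of [64, 128, 256])
- Smaller latent dim 128 (instead of 256)
- Batch size 4 with 4 gradient accumulation steps (effective batch size 16)
- Same codebook size (512)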
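
As a rough check on what this config implies for the latent grid, the standard convolution output-size formula can be applied to the encoder defined later in the notebook. That encoder uses one stride-2 convolution (kernel 4, padding 1) per hidden dim, on 32×32×32 inputs. The helper names below are illustrative, not part of the notebook.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def latent_grid_side(input_side: int = 32, num_stages: int = 3) -> int:
    """Side length after `num_stages` stride-2 downsampling convolutions."""
    side = input_side
    for _ in range(num_stages):
        side = conv_out(side, kernel=4, stride=2, padding=1)
    return side


side = latent_grid_side()
print(f"latent grid: {side}x{side}x{side} = {side ** 3} code indices per structure")
```

With three hidden dims, each structure is compressed to a 4×4×4 grid, i.e. 64 indices into the 512-entry codebook.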

In [None]:
# ============================================================
# CELL 1: Imports and Setup
# ============================================================

import json
import random
import time
from pathlib import Path
from typing import Dict, List, Tuple, Any

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# ============================================================
# CELL 2: Configuration
# ============================================================

# === Data Paths ===
DATA_DIR = "/kaggle/input/minecraft-schematics/minecraft_splits/splits/train"
VAL_DIR = "/kaggle/input/minecraft-schematics/minecraft_splits/splits/val"
VOCAB_PATH = "/kaggle/input/minecraft-schematics/tok2block.json"

# Embeddings paths
V1_EMBEDDINGS_PATH = "/kaggle/input/block2vec-embeddings/block_embeddings.npy"
V2_EMBEDDINGS_PATH = "/kaggle/input/block2vec-v2/block_embeddings_v2.npy"
V2_MAPPING_PATH = "/kaggle/input/block2vec-v2/original_to_collapsed.json"

OUTPUT_DIR = "/kaggle/working"

# === Mini Model Architecture (faster training) ===
BLOCK_EMBEDDING_DIM = 32
HIDDEN_DIMS = [32, 64, 128]  # Smaller than full model
LATENT_DIM = 128             # Smaller than full model (256)
NUM_CODEBOOK_ENTRIES = 512
COMMITMENT_COST = 0.25

# === Training (faster) ===
EPOCHS = 10
BATCH_SIZE = 4  # Reduced from 16 - cross_entropy on 32x32x32x3717 is memory-intensive
LEARNING_RATE = 3e-4
USE_AMP = True
GRAD_ACCUM_STEPS = 4  # Accumulate gradients to simulate larger effective batch

# === Other ===
SEED = 42
NUM_WORKERS = 2

print("Mini VQ-VAE Configuration:")
print(f"  Hidden dims: {HIDDEN_DIMS}")
print(f"  Latent dim: {LATENT_DIM}")
print(f"  Epochs: {EPOCHS}")
print(f"  Batch size: {BATCH_SIZE} (effective: {BATCH_SIZE * GRAD_ACCUM_STEPS})")
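
The batch-size comment above can be backed by a quick estimate of the logits tensor alone. This sketch assumes the vocabulary size of 3717 quoted in the notebook comments (the real value is read from the vocabulary file) and 4 bytes per value; under AMP the logits are half that size, while gradients and intermediate copies add more on top.

```python
def logits_bytes(batch_size: int, vocab_size: int, side: int = 32,
                 bytes_per_value: int = 4) -> int:
    """Memory for one [B, vocab, side, side, side] logits tensor."""
    return batch_size * vocab_size * side ** 3 * bytes_per_value


for bs in (4, 16):
    gb = logits_bytes(bs, vocab_size=3717) / 1e9
    print(f"batch size {bs:2d}: ~{gb:.2f} GB for fp32 logits")
```

At batch size 16 the logits alone approach 8 GB in fp32, which is consistent with the decision to drop to 4 and accumulate gradients.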

In [None]:
# ============================================================
# CELL 3: Load Vocabulary and Prepare Embeddings
# ============================================================

# Load vocabulary
with open(VOCAB_PATH, 'r') as f:
    tok2block = {int(k): v for k, v in json.load(f).items()}

VOCAB_SIZE = len(tok2block)
print(f"Vocabulary size: {VOCAB_SIZE} block types")

# ============================================================
# BUG FIX: Find ALL air tokens (not just token 0!)
# Token 0 is UNKNOWN_BLOCK, not air!
# Air tokens are: 19 (air), 164 (cave_air), 932 (void_air)
# ============================================================
AIR_TOKENS = set()
for tok, block in tok2block.items():
    block_lower = block.lower()
    if 'air' in block_lower and 'stair' not in block_lower:
        AIR_TOKENS.add(tok)
        print(f"  Found air token: {tok} = '{block}'")

AIR_TOKENS_TENSOR = torch.tensor(sorted(AIR_TOKENS), dtype=torch.long)
print(f"\nAir tokens: {AIR_TOKENS_TENSOR.tolist()}")
print(f"Total air types: {len(AIR_TOKENS)}")

# Load V1 embeddings (3717, 32)
v1_embeddings = np.load(V1_EMBEDDINGS_PATH)
print(f"\nV1 embeddings: {v1_embeddings.shape}")

# Load V2 embeddings (1007, 32) and mapping
v2_collapsed = np.load(V2_EMBEDDINGS_PATH)
with open(V2_MAPPING_PATH, 'r') as f:
    original_to_collapsed = {int(k): int(v) for k, v in json.load(f).items()}
print(f"V2 collapsed embeddings: {v2_collapsed.shape}")

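# Expand V2 to the full vocabulary using the mapping.
# Each original token maps to a collapsed token. Tokens without a valid mapping
# keep an all-zero embedding, so count them to make that visible.
v2_embeddings = np.zeros((VOCAB_SIZE, BLOCK_EMBEDDING_DIM), dtype=np.float32)
unmapped_tokens = []
for orig_tok in range(VOCAB_SIZE):
    collapsed_tok = original_to_collapsed.get(orig_tok)
    if collapsed_tok is not None and 0 <= collapsed_tok < len(v2_collapsed):
        v2_embeddings[orig_tok] = v2_collapsed[collapsed_tok]
    else:
        unmapped_tokens.append(orig_tok)

print(f"V2 expanded embeddings: {v2_embeddings.shape}")
if unmapped_tokens:
    print(f"  WARNING: {len(unmapped_tokens)} tokens have no V2 embedding (left as zeros)")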

# Create random embeddings (same scale as V1)
np.random.seed(SEED)
v1_std = v1_embeddings.std()
random_embeddings = np.random.randn(VOCAB_SIZE, BLOCK_EMBEDDING_DIM).astype(np.float32) * v1_std
print(f"Random embeddings: {random_embeddings.shape} (std={v1_std:.4f})")

# Store all embedding variants
EMBEDDING_VARIANTS = {
    'V1': v1_embeddings,
    'V2': v2_embeddings,
    'Random': random_embeddings,
}

print("\nAll embedding variants prepared!")

In [None]:
# ============================================================
# CELL 4: Dataset
# ============================================================

class VQVAEDataset(Dataset):
    def __init__(self, data_dir: str, seed: int = 42):
        self.data_dir = Path(data_dir)
        self.h5_files = sorted(self.data_dir.glob("*.h5"))
        if not self.h5_files:
            raise ValueError(f"No H5 files found in {data_dir}")
        print(f"Found {len(self.h5_files)} structures in {data_dir}")
    
    def __len__(self):
        return len(self.h5_files)
    
    def __getitem__(self, idx):
        with h5py.File(self.h5_files[idx], 'r') as f:
            key = list(f.keys())[0]
            structure = f[key][:].astype(np.int64)
        return torch.from_numpy(structure).long()


train_dataset = VQVAEDataset(DATA_DIR, seed=SEED)
val_dataset = VQVAEDataset(VAL_DIR, seed=SEED)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, 
                          num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=True)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

In [None]:
# ============================================================
# CELL 5: Mini VQ-VAE Model
# ============================================================

class ResidualBlock3D(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv3d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv3d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm3d(channels)
        self.bn2 = nn.BatchNorm3d(channels)
    
    def forward(self, x):
        residual = x
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        return F.relu(x + residual)


class MiniVQVAE(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dims, latent_dim, 
                 num_codes, commitment_cost, pretrained_embeddings):
        super().__init__()
        self.vocab_size = vocab_size
        self.latent_dim = latent_dim
        self.num_codes = num_codes
        self.commitment_cost = commitment_cost
        
        # Block embeddings (frozen)
        self.block_emb = nn.Embedding(vocab_size, embedding_dim)
        self.block_emb.weight.data.copy_(torch.from_numpy(pretrained_embeddings))
        self.block_emb.weight.requires_grad = False
        
        # Encoder
        enc_layers = []
        in_ch = embedding_dim
        for h_dim in hidden_dims:
            enc_layers.extend([
                nn.Conv3d(in_ch, h_dim, 4, stride=2, padding=1),
                nn.BatchNorm3d(h_dim),
                nn.ReLU(inplace=True),
                ResidualBlock3D(h_dim),
            ])
            in_ch = h_dim
        enc_layers.append(nn.Conv3d(in_ch, latent_dim, 3, padding=1))
        self.encoder = nn.Sequential(*enc_layers)
        
        # Codebook
        self.codebook = nn.Embedding(num_codes, latent_dim)
        self.codebook.weight.data.uniform_(-1/num_codes, 1/num_codes)
        
        # Decoder
        dec_layers = []
        in_ch = latent_dim
        for h_dim in reversed(hidden_dims):
            dec_layers.extend([
                ResidualBlock3D(in_ch),
                nn.ConvTranspose3d(in_ch, h_dim, 4, stride=2, padding=1),
                nn.BatchNorm3d(h_dim),
                nn.ReLU(inplace=True),
            ])
            in_ch = h_dim
        dec_layers.append(nn.Conv3d(in_ch, vocab_size, 3, padding=1))
        self.decoder = nn.Sequential(*dec_layers)
    
    def quantize(self, z_e):
        # z_e: [B, C, D, H, W]
        z_e_perm = z_e.permute(0, 2, 3, 4, 1).contiguous()  # [B, D, H, W, C]
        flat = z_e_perm.view(-1, self.latent_dim)
        
        # Distances
        d = (flat**2).sum(1, keepdim=True) + \
            (self.codebook.weight**2).sum(1) - \
            2 * flat @ self.codebook.weight.t()
        
        indices = d.argmin(1)
        z_q_flat = self.codebook(indices)
        z_q_perm = z_q_flat.view(z_e_perm.shape)
        
        # Losses
        codebook_loss = F.mse_loss(z_q_perm, z_e_perm.detach())
        commit_loss = F.mse_loss(z_e_perm, z_q_perm.detach())
        vq_loss = codebook_loss + self.commitment_cost * commit_loss
        
        # Straight-through
        z_q_st = z_e_perm + (z_q_perm - z_e_perm).detach()
        z_q = z_q_st.permute(0, 4, 1, 2, 3).contiguous()
        
        return z_q, vq_loss, indices.view(z_e_perm.shape[:-1])
    
    def forward(self, block_ids):
        # Embed
        x = self.block_emb(block_ids)  # [B, 32, 32, 32, emb]
        x = x.permute(0, 4, 1, 2, 3).contiguous()  # [B, emb, 32, 32, 32]
        
        # Encode
        z_e = self.encoder(x)
        
        # Quantize
        z_q, vq_loss, indices = self.quantize(z_e)
        
        # Decode
        logits = self.decoder(z_q)
        
        return {'logits': logits, 'vq_loss': vq_loss, 'indices': indices}
    
    def compute_loss(self, block_ids, air_tokens_tensor):
        """
        Compute loss and metrics.
        
        BUG FIX: Now uses air_tokens_tensor to correctly identify air blocks.
        Previous version used `block_ids != 0`, but token 0 is UNKNOWN_BLOCK!
        Air tokens are: 19 (air), 164 (cave_air), 932 (void_air)
        """
        out = self(block_ids)
        
        # Reconstruction loss
        logits = out['logits'].permute(0, 2, 3, 4, 1).contiguous()
        recon_loss = F.cross_entropy(logits.view(-1, self.vocab_size), block_ids.view(-1))
        
        total_loss = recon_loss + out['vq_loss']
        
        # Accuracy metrics
        with torch.no_grad():
            preds = logits.argmax(-1)
            targets_flat = block_ids.view(-1)
            preds_flat = preds.view(-1)
            
            # Overall accuracy
            correct = (preds_flat == targets_flat).float()
            acc = correct.mean()
            
            # Move air tokens to same device
            air_tokens_device = air_tokens_tensor.to(targets_flat.device)
            
            # Find air and non-air blocks using torch.isin
            is_air = torch.isin(targets_flat, air_tokens_device)
            is_structure = ~is_air
            
            # Air accuracy
            if is_air.sum() > 0:
                air_acc = correct[is_air].mean()
            else:
                air_acc = torch.tensor(0.0, device=block_ids.device)
            
            # Structure accuracy (non-air) - THE KEY METRIC!
            if is_structure.sum() > 0:
                struct_acc = correct[is_structure].mean()
            else:
                struct_acc = torch.tensor(0.0, device=block_ids.device)
            
            # Track air percentage for sanity check
            air_pct = is_air.float().mean()
        
        return {
            'loss': total_loss,
            'recon_loss': recon_loss,
            'vq_loss': out['vq_loss'],
            'accuracy': acc,
            'air_accuracy': air_acc,
            'struct_accuracy': struct_acc,
            'air_percentage': air_pct,
            'indices': out['indices'],
        }


print("MiniVQVAE defined!")
print("BUG FIX: compute_loss now correctly identifies all air tokens")
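
The `quantize` method above finds the nearest codebook entry using the expansion ||z - e||² = ||z||² + ||e||² - 2 z·e, which avoids materializing every pairwise difference. The following plain-Python sketch shows the same lookup for a single vector; the tiny 2-D codebook is made up for illustration.

```python
def nearest_code(vec, codebook):
    """Return (index, squared distance) of the closest codebook row.

    Uses ||v||^2 + ||c||^2 - 2 v.c, the same expansion as the batched version.
    """
    v_sq = sum(x * x for x in vec)
    best_idx, best_dist = -1, float("inf")
    for i, code in enumerate(codebook):
        c_sq = sum(c * c for c in code)
        dot = sum(x * c for x, c in zip(vec, code))
        dist = v_sq + c_sq - 2.0 * dot
        if dist < best_dist:
            best_idx, best_dist = i, dist
    return best_idx, best_dist


# Hypothetical 3-entry codebook in 2 dimensions.
codebook = [[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]]
idx, dist = nearest_code([0.9, 0.8], codebook)
print(f"nearest code: {idx}, squared distance: {dist:.3f}")
```

In the model, the selected row replaces the encoder output in the forward pass, while the straight-through term lets gradients flow to the encoder as if no quantization had happened.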

In [None]:
# ============================================================
# CELL 6: Training Functions
# ============================================================

def train_epoch(model, loader, optimizer, scaler, device, air_tokens_tensor):
    """Train for one epoch."""
    model.train()
    metrics = {'loss': 0, 'recon': 0, 'vq': 0, 'acc': 0, 'air_acc': 0, 'struct_acc': 0, 'air_pct': 0}
    n = 0
    
    optimizer.zero_grad()
    
    for batch_idx, batch in enumerate(tqdm(loader, desc="Train", leave=False)):
        batch = batch.to(device)
        
        with torch.amp.autocast('cuda', enabled=USE_AMP):
            out = model.compute_loss(batch, air_tokens_tensor)
            # Scale loss for gradient accumulation
            loss = out['loss'] / GRAD_ACCUM_STEPS
        
        scaler.scale(loss).backward()
        
        # Step optimizer every GRAD_ACCUM_STEPS batches
        if (batch_idx + 1) % GRAD_ACCUM_STEPS == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        
        metrics['loss'] += out['loss'].item()
        metrics['recon'] += out['recon_loss'].item()
        metrics['vq'] += out['vq_loss'].item()
        metrics['acc'] += out['accuracy'].item()
        metrics['air_acc'] += out['air_accuracy'].item()
        metrics['struct_acc'] += out['struct_accuracy'].item()
        metrics['air_pct'] += out['air_percentage'].item()
        n += 1
    
    # Handle remaining gradients if loader length not divisible by GRAD_ACCUM_STEPS
    if len(loader) % GRAD_ACCUM_STEPS != 0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
    
    return {k: v/n for k, v in metrics.items()}


@torch.no_grad()
def validate(model, loader, device, air_tokens_tensor):
    """Validate model."""
    model.eval()
    metrics = {'loss': 0, 'recon': 0, 'acc': 0, 'air_acc': 0, 'struct_acc': 0, 'air_pct': 0}
    n = 0
    
    for batch in tqdm(loader, desc="Val", leave=False):
        batch = batch.to(device)
        
        with torch.amp.autocast('cuda', enabled=USE_AMP):
            out = model.compute_loss(batch, air_tokens_tensor)
        
        metrics['loss'] += out['loss'].item()
        metrics['recon'] += out['recon_loss'].item()
        metrics['acc'] += out['accuracy'].item()
        metrics['air_acc'] += out['air_accuracy'].item()
        metrics['struct_acc'] += out['struct_accuracy'].item()
        metrics['air_pct'] += out['air_percentage'].item()
        n += 1
    
    return {k: v/n for k, v in metrics.items()}


print("Training functions defined!")
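
The stepping schedule in `train_epoch` can be checked without a model. This sketch lists the batch counts at which the optimizer would step, including the final flush for a leftover partial group; `optimizer_step_batches` is an illustrative helper, not part of the notebook.

```python
def optimizer_step_batches(num_batches: int, accum_steps: int):
    """1-based batch counts after which an optimizer step fires."""
    steps = [i for i in range(1, num_batches + 1) if i % accum_steps == 0]
    if num_batches % accum_steps != 0:
        steps.append(num_batches)  # flush the leftover partial group
    return steps


print(optimizer_step_batches(num_batches=10, accum_steps=4))
print(optimizer_step_batches(num_batches=8, accum_steps=4))
```

One caveat: in the leftover group each batch loss is still divided by `GRAD_ACCUM_STEPS`, so with 2 leftover batches and 4 accumulation steps that last update is built from roughly half the usual gradient magnitude (before clipping).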

In [None]:
# ============================================================
# CELL 7: Run Experiment for One Embedding Type
# ============================================================

def run_experiment(name, embeddings, air_tokens_tensor):
    """Train and evaluate VQ-VAE with given embeddings."""
    print(f"\n{'='*60}")
    print(f"Training with {name} embeddings")
    print(f"{'='*60}")
    print(f"Air tokens: {air_tokens_tensor.tolist()}")
    
    # Clear GPU memory from previous experiment
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    
    # Set seeds for reproducibility
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)
    
    # Create model
    model = MiniVQVAE(
        vocab_size=VOCAB_SIZE,
        embedding_dim=BLOCK_EMBEDDING_DIM,
        hidden_dims=HIDDEN_DIMS,
        latent_dim=LATENT_DIM,
        num_codes=NUM_CODEBOOK_ENTRIES,
        commitment_cost=COMMITMENT_COST,
        pretrained_embeddings=embeddings,
    ).to(device)
    
    # Count params
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {trainable:,}")
    
    # Print GPU memory status
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        print(f"GPU memory: {allocated:.2f} GB allocated, {reserved:.2f} GB reserved")
    
    # Optimizer
    optimizer = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LEARNING_RATE,
    )
    scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP)
    
    # Training loop - now tracking air_acc and air_pct
    history = {
        'train_loss': [], 'train_acc': [], 'train_air_acc': [], 'train_struct_acc': [], 'train_air_pct': [],
        'val_loss': [], 'val_acc': [], 'val_air_acc': [], 'val_struct_acc': [], 'val_air_pct': [],
    }
    
    start_time = time.time()
    
    for epoch in range(EPOCHS):
        train_metrics = train_epoch(model, train_loader, optimizer, scaler, device, air_tokens_tensor)
        val_metrics = validate(model, val_loader, device, air_tokens_tensor)
        
        history['train_loss'].append(train_metrics['loss'])
        history['train_acc'].append(train_metrics['acc'])
        history['train_air_acc'].append(train_metrics['air_acc'])
        history['train_struct_acc'].append(train_metrics['struct_acc'])
        history['train_air_pct'].append(train_metrics['air_pct'])
        history['val_loss'].append(val_metrics['loss'])
        history['val_acc'].append(val_metrics['acc'])
        history['val_air_acc'].append(val_metrics['air_acc'])
        history['val_struct_acc'].append(val_metrics['struct_acc'])
        history['val_air_pct'].append(val_metrics['air_pct'])
        
        # Show all metrics including structure accuracy
        print(f"Epoch {epoch+1:2d}/{EPOCHS} | "
              f"Loss: {train_metrics['loss']:.3f} | "
              f"Acc: {train_metrics['acc']:.1%} | "
              f"Struct: {train_metrics['struct_acc']:.1%} | "
              f"Air: {train_metrics['air_acc']:.1%} | "
              f"Val Struct: {val_metrics['struct_acc']:.1%} | "
              f"Air%: {val_metrics['air_pct']:.1%}")
    
    train_time = time.time() - start_time
    print(f"Training time: {train_time/60:.1f} minutes")
    
    # Sanity check: struct_acc should differ from overall acc
    if abs(history['val_acc'][-1] - history['val_struct_acc'][-1]) < 0.001:
        print("⚠️  WARNING: Overall and Structure accuracy are nearly identical!")
        print("    Air detection may still be broken.")
    else:
        print(f"✓ Structure accuracy ({history['val_struct_acc'][-1]:.1%}) differs from overall ({history['val_acc'][-1]:.1%})")
    
    # Final metrics
    final_metrics = {
        'name': name,
        'final_train_loss': history['train_loss'][-1],
        'final_val_loss': history['val_loss'][-1],
        'final_train_acc': history['train_acc'][-1],
        'final_val_acc': history['val_acc'][-1],
        'final_train_struct_acc': history['train_struct_acc'][-1],
        'final_val_struct_acc': history['val_struct_acc'][-1],
        'final_train_air_acc': history['train_air_acc'][-1],
        'final_val_air_acc': history['val_air_acc'][-1],
        'avg_air_pct': np.mean(history['val_air_pct']),
        'best_val_loss': min(history['val_loss']),
        'best_val_acc': max(history['val_acc']),
        'best_val_struct_acc': max(history['val_struct_acc']),
        'training_time': train_time,
        'history': history,
    }
    
    # Cleanup
    del model, optimizer, scaler
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    return final_metrics


print("Experiment function defined!")

In [None]:
# ============================================================
# CELL 8: Run All Experiments
# ============================================================

print("="*60)
print("PHASE 0: VQ-VAE EMBEDDING VALIDATION")
print("="*60)
print(f"\nAir tokens being used: {AIR_TOKENS_TENSOR.tolist()}")
print("These will be EXCLUDED from structure accuracy calculation.\n")

all_results = {}

for name, embeddings in EMBEDDING_VARIANTS.items():
    results = run_experiment(name, embeddings, AIR_TOKENS_TENSOR)
    all_results[name] = results

print("\n" + "="*60)
print("ALL EXPERIMENTS COMPLETE")
print("="*60)

In [None]:
# ============================================================
# CELL 9: Compare Results
# ============================================================

print("\n" + "="*70)
print("RESULTS COMPARISON")
print("="*70)

# Show air percentage first
avg_air_pct = np.mean([all_results[name]['avg_air_pct'] for name in ['V1', 'V2', 'Random']])
print(f"\nAverage air percentage in data: {avg_air_pct:.1%}")
print("This is why STRUCTURE accuracy (non-air) is the key metric!\n")

# Create comparison table - emphasize STRUCTURE accuracy
# Val Loss, Overall Acc and STRUCT Acc each report that metric's best validation
# epoch; Air Acc is taken from the final epoch.
print("{:<10} {:>12} {:>12} {:>12} {:>12} {:>12}".format(
    "Embeddings", "Val Loss", "Overall Acc", "STRUCT Acc", "Air Acc", "Time"))
print("-"*75)

for name in ['V1', 'V2', 'Random']:
    r = all_results[name]
    print("{:<10} {:>12.4f} {:>12.1%} {:>12.1%} {:>12.1%} {:>11.1f}m".format(
        name,
        r['best_val_loss'],
        r['best_val_acc'],
        r['best_val_struct_acc'],
        r['final_val_air_acc'],
        r['training_time']/60
    ))

# Calculate improvements over random - STRUCTURE accuracy is key!
print("\n" + "="*70)
print("IMPROVEMENT OVER RANDOM BASELINE (Structure Accuracy = Key Metric)")
print("="*70)

random_loss = all_results['Random']['best_val_loss']
random_acc = all_results['Random']['best_val_acc']
random_struct = all_results['Random']['best_val_struct_acc']

for name in ['V1', 'V2']:
    r = all_results[name]
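    # Relative change vs. the random baseline. The small floor avoids a
    # ZeroDivisionError if a baseline metric is 0 (e.g. a model that only predicts air).
    loss_improvement = (random_loss - r['best_val_loss']) / max(random_loss, 1e-8) * 100
    acc_improvement = (r['best_val_acc'] - random_acc) / max(random_acc, 1e-8) * 100
    struct_improvement = (r['best_val_struct_acc'] - random_struct) / max(random_struct, 1e-8) * 100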
    
    print(f"\n{name} vs Random:")
    print(f"  Loss:            {loss_improvement:+.1f}% {'(better)' if loss_improvement > 0 else '(worse)'}")
    print(f"  Overall Acc:     {acc_improvement:+.1f}% {'(better)' if acc_improvement > 0 else '(worse)'}")
    print(f"  ★ STRUCT Acc:    {struct_improvement:+.1f}% {'(better)' if struct_improvement > 0 else '(worse)'} ← KEY METRIC")

# Decision based on STRUCTURE accuracy
print("\n" + "="*70)
print("CONCLUSION (Based on STRUCTURE Accuracy)")
print("="*70)

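# The floor on the denominator avoids a ZeroDivisionError if the random baseline scores 0.
v1_struct_improvement = (all_results['V1']['best_val_struct_acc'] - random_struct) / max(random_struct, 1e-8) * 100
v2_struct_improvement = (all_results['V2']['best_val_struct_acc'] - random_struct) / max(random_struct, 1e-8) * 100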

print(f"\nStructure accuracy improvements over random:")
print(f"  V1: {v1_struct_improvement:+.1f}%")
print(f"  V2: {v2_struct_improvement:+.1f}%")

if v1_struct_improvement > 20 or v2_struct_improvement > 20:
    print("\n✓ Block2Vec embeddings ARE useful for VQ-VAE structure reconstruction!")
    print("  Proceed with Block2Vec optimization.")
    if all_results['V1']['best_val_struct_acc'] > all_results['V2']['best_val_struct_acc']:
        print(f"  V1 is better for structure - the hybrid approach may have been harmful.")
    else:
        print(f"  V2 is better for structure - continue with hybrid approach.")
elif v1_struct_improvement > 5 or v2_struct_improvement > 5:
    print("\n~ Block2Vec provides MODEST improvement for structure reconstruction.")
    print("  Consider whether optimization effort is worth it.")
else:
    print("\n✗ Block2Vec embeddings do NOT significantly help structure reconstruction.")
    print("  Either:")
    print("  1. Use simple one-hot encoding")
    print("  2. Let VQ-VAE learn its own embeddings")
    print("  3. Focus optimization elsewhere (VQ-VAE architecture, training, etc.)")

In [None]:
# ============================================================
# CELL 10: Plot Comparison
# ============================================================

fig, axes = plt.subplots(2, 3, figsize=(18, 10))

colors = {'V1': 'blue', 'V2': 'green', 'Random': 'red'}

# Validation Loss
ax = axes[0, 0]
for name in ['V1', 'V2', 'Random']:
    ax.plot(all_results[name]['history']['val_loss'], label=name, 
            color=colors[name], linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Validation Loss')
ax.set_title('Validation Loss by Embedding Type')
ax.legend()
ax.grid(True, alpha=0.3)

# Overall Accuracy (less important - dominated by air)
ax = axes[0, 1]
for name in ['V1', 'V2', 'Random']:
    ax.plot(all_results[name]['history']['val_acc'], label=name,
            color=colors[name], linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Validation Accuracy')
ax.set_title('Overall Accuracy (includes ~90% air)')
ax.legend()
ax.grid(True, alpha=0.3)

# STRUCTURE Accuracy - THE KEY METRIC!
ax = axes[0, 2]
for name in ['V1', 'V2', 'Random']:
    ax.plot(all_results[name]['history']['val_struct_acc'], label=name,
            color=colors[name], linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Structure Accuracy (non-air)')
ax.set_title('★ STRUCTURE Accuracy (KEY METRIC) ★')
ax.legend()
ax.grid(True, alpha=0.3)

# Air Accuracy
ax = axes[1, 0]
for name in ['V1', 'V2', 'Random']:
    ax.plot(all_results[name]['history']['val_air_acc'], label=name,
            color=colors[name], linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Air Accuracy')
ax.set_title('Air Block Accuracy')
ax.legend()
ax.grid(True, alpha=0.3)

# Air Percentage (sanity check)
ax = axes[1, 1]
for name in ['V1', 'V2', 'Random']:
    ax.plot(all_results[name]['history']['val_air_pct'], label=name,
            color=colors[name], linewidth=2, alpha=0.7)
ax.set_xlabel('Epoch')
ax.set_ylabel('Air Percentage')
ax.set_title('Air Block % in Data (should be roughly constant)')
ax.legend()
ax.grid(True, alpha=0.3)

# Bar chart comparison - emphasize structure accuracy
ax = axes[1, 2]
x = np.arange(3)
width = 0.35
struct_accs = [all_results[name]['best_val_struct_acc'] for name in ['V1', 'V2', 'Random']]
overall_accs = [all_results[name]['best_val_acc'] for name in ['V1', 'V2', 'Random']]

bars1 = ax.bar(x - width/2, struct_accs, width, label='Structure Acc (KEY)', color=['blue', 'green', 'red'], alpha=0.9)
bars2 = ax.bar(x + width/2, overall_accs, width, label='Overall Acc', color=['blue', 'green', 'red'], alpha=0.4)

ax.set_ylabel('Accuracy')
ax.set_title('Best Accuracy Comparison\n(Structure Acc is the key metric!)')
ax.set_xticks(x)
ax.set_xticklabels(['V1', 'V2', 'Random'])
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

# Add value labels on bars
for bar, val in zip(bars1, struct_accs):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
            f'{val:.1%}', ha='center', va='bottom', fontsize=9, fontweight='bold')

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/embedding_comparison.png", dpi=150)
plt.show()

# Sanity check output
print("\n" + "="*60)
print("SANITY CHECK")
print("="*60)
for name in ['V1', 'V2', 'Random']:
    r = all_results[name]
    struct = r['best_val_struct_acc']
    overall = r['best_val_acc']
    diff = abs(struct - overall)
    status = "✓" if diff > 0.01 else "⚠️"
    print(f"{status} {name}: Overall={overall:.3f}, Structure={struct:.3f}, Diff={diff:.3f}")

In [None]:
# ============================================================
# CELL 11: Save Results
# ============================================================

# Save summary (without full history to keep file small)
summary = {}
for name, results in all_results.items():
    summary[name] = {
        'final_val_loss': results['final_val_loss'],
        'final_val_acc': results['final_val_acc'],
        'final_val_struct_acc': results['final_val_struct_acc'],
        'final_val_air_acc': results['final_val_air_acc'],
        'best_val_loss': results['best_val_loss'],
        'best_val_acc': results['best_val_acc'],
        'best_val_struct_acc': results['best_val_struct_acc'],
        'avg_air_percentage': results['avg_air_pct'],
        'training_time_minutes': results['training_time'] / 60,
    }

with open(f"{OUTPUT_DIR}/embedding_validation_results.json", 'w') as f:
    json.dump(summary, f, indent=2)

print("Results saved to embedding_validation_results.json")

# Save full history for detailed analysis
full_results = {}
for name, results in all_results.items():
    full_results[name] = {
        'history': results['history'],
        'training_time': results['training_time'],
    }

with open(f"{OUTPUT_DIR}/embedding_validation_full.json", 'w') as f:
    json.dump(full_results, f, indent=2)

print("Full history saved to embedding_validation_full.json")

# Also save air tokens used for reproducibility
with open(f"{OUTPUT_DIR}/air_tokens_used.json", 'w') as f:
    json.dump({
        'air_tokens': sorted(AIR_TOKENS),
        'note': 'These tokens were excluded from structure accuracy calculation'
    }, f, indent=2)

print("Air tokens saved to air_tokens_used.json")

In [None]:
# ============================================================
# CELL 12: Final Summary
# ============================================================

print("\n" + "="*70)
print("PHASE 0 COMPLETE: VQ-VAE EMBEDDING VALIDATION")
print("="*70)

print("\nQuestion: Do Block2Vec embeddings help VQ-VAE reconstruct STRUCTURES?")
print(f"\nAir tokens excluded: {sorted(AIR_TOKENS)}")
print(f"Average air percentage: {avg_air_pct:.1%}")
print("\nNOTE: Overall accuracy is ~{:.0%} just from predicting air correctly.".format(avg_air_pct))
print("      STRUCTURE accuracy is the true measure of reconstruction quality!")

print("\nResults:")
print("\n{:<12} {:>12} {:>15} {:>15} {:>12}".format(
    "", "Val Loss", "Overall Acc", "★STRUCT Acc★", "Air Acc"))
print("-"*70)
for name in ['V1', 'V2', 'Random']:
    r = all_results[name]
    print("{:<12} {:>12.4f} {:>15.1%} {:>15.1%} {:>12.1%}".format(
        name,
        r['best_val_loss'],
        r['best_val_acc'],
        r['best_val_struct_acc'],
        r['final_val_air_acc']
    ))

# Final verdict
print("\n" + "="*70)
# The floor on the denominator avoids a ZeroDivisionError if the random baseline scores 0.
random_struct_acc = max(all_results['Random']['best_val_struct_acc'], 1e-8)
v1_improvement = (all_results['V1']['best_val_struct_acc'] - all_results['Random']['best_val_struct_acc']) / random_struct_acc * 100
v2_improvement = (all_results['V2']['best_val_struct_acc'] - all_results['Random']['best_val_struct_acc']) / random_struct_acc * 100

print(f"Structure accuracy improvement over random:")
print(f"  V1: {v1_improvement:+.1f}%")
print(f"  V2: {v2_improvement:+.1f}%")

print("\n" + "="*70)
print("Files saved:")
print(f"  - {OUTPUT_DIR}/embedding_comparison.png")
print(f"  - {OUTPUT_DIR}/embedding_validation_results.json")
print(f"  - {OUTPUT_DIR}/embedding_validation_full.json")
print(f"  - {OUTPUT_DIR}/air_tokens_used.json")
print("="*70)