# VQ-VAE Embedding Validation: V3 Only

## Purpose
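Validate Block2Vec V3 (compositional) embeddings for VQ-VAE structure reconstruction: train a mini VQ-VAE on top of the frozen V3 embeddings and measure **structure accuracy**, i.e. reconstruction accuracy on non-air blocks only.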

## Previous Results (from V2 validation)
- V1: **44.3%** structure accuracy
- V2: **43.0%** structure accuracy  
- Random: **39.0%** structure accuracy

**Key Question**: Can V3 compositional embeddings beat V1's 44.3%?


In [None]:
# ============================================================
# CELL 1: Imports and Setup
# ============================================================

import json
import random
import time
from pathlib import Path
from typing import Dict, List, Tuple, Any

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

# Pick the compute device once; the training cells below read `device`.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# On a GPU run, also report which card is in use and its total memory.
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# ============================================================
# CELL 2: Configuration (V3 ONLY)
# ============================================================

# === Data Paths ===
# Directories scanned for *.h5 structure files by VQVAEDataset (train / val).
DATA_DIR = "/kaggle/input/minecraft-schematics/minecraft_splits/splits/train"
VAL_DIR = "/kaggle/input/minecraft-schematics/minecraft_splits/splits/val"
# JSON file mapping token id -> block name; its size defines VOCAB_SIZE.
VOCAB_PATH = "/kaggle/input/minecraft-schematics/tok2block.json"

# V3 embeddings path (.npy array loaded with np.load)
V3_EMBEDDINGS_PATH = "/kaggle/input/block2vec-v3/block_embeddings_v3.npy"

# Previous validation results (V1, V2, Random) for comparison plotting
PREVIOUS_RESULTS_PATH = "/kaggle/input/block2vec-v3/embedding_validation_full.json"

# Where the plot and the JSON summaries produced by this notebook are written.
OUTPUT_DIR = "/kaggle/working"

# === Mini Model Architecture (same as V2 validation) ===
BLOCK_EMBEDDING_DIM = 40  # V3 compositional: 16+16+8=40
# One stride-2 encoder stage (and one mirrored decoder stage) per entry.
HIDDEN_DIMS = [32, 64, 128]
# Channels of the encoder output and width of each codebook vector.
LATENT_DIM = 128
# Number of vectors in the VQ codebook.
NUM_CODEBOOK_ENTRIES = 512
# Weight of the commitment term inside the VQ loss.
COMMITMENT_COST = 0.25

# === Training ===
EPOCHS = 10
BATCH_SIZE = 4
LEARNING_RATE = 3e-4
# Enables autocast and the gradient scaler in the training/validation loops.
USE_AMP = True
# The optimizer steps once every GRAD_ACCUM_STEPS batches.
GRAD_ACCUM_STEPS = 4

# === Other ===
# Seed applied to torch, numpy and random at the start of each experiment.
SEED = 42
# Worker processes used by both DataLoaders.
NUM_WORKERS = 2

print("V3 Validation Configuration:")
print(f"  Epochs: {EPOCHS}")
print(f"  Batch size: {BATCH_SIZE} (effective: {BATCH_SIZE * GRAD_ACCUM_STEPS})")


In [None]:
# ============================================================
# CELL 3: Load Vocabulary and V3 Embeddings
# ============================================================

# Load vocabulary. JSON object keys are strings, so token ids are cast to int.
with open(VOCAB_PATH, 'r') as f:
    raw_vocab = json.load(f)
tok2block = {int(token_id): block_name for token_id, block_name in raw_vocab.items()}

VOCAB_SIZE = len(tok2block)
print(f"Vocabulary size: {VOCAB_SIZE} block types")

# ============================================================
# Find ALL air tokens (not just token 0!)
# ============================================================
# A block counts as air when its lower-cased name contains "air" but not
# "stair" (stairs would otherwise match the substring).
AIR_TOKENS = set()
for tok, block in tok2block.items():
    name_lc = block.lower()
    if 'air' not in name_lc or 'stair' in name_lc:
        continue
    AIR_TOKENS.add(tok)
    print(f"  Found air token: {tok} = {block}")

AIR_TOKENS_TENSOR = torch.tensor(sorted(AIR_TOKENS), dtype=torch.long)
print()
print(f"Air tokens: {AIR_TOKENS_TENSOR.tolist()}")

# Load V3 embeddings and keep them as a numpy array; the model copies them
# into its embedding layer.
v3_embeddings = np.load(V3_EMBEDDINGS_PATH)
print()
print(f"V3 embeddings shape: {v3_embeddings.shape}")

# Only V3 is trained in this notebook.
EMBEDDINGS = {"V3": v3_embeddings}

print()
print(f"Embeddings to test: {list(EMBEDDINGS.keys())}")
print("Note: Comparing against previous results - V1: 44.3%, V2: 43.0%, Random: 39.0%")

# Record which tokens were treated as air alongside the other outputs.
air_info = dict(
    air_tokens=sorted(AIR_TOKENS),
    note="These tokens were excluded from structure accuracy calculation",
)
with open(f"{OUTPUT_DIR}/air_tokens_used.json", 'w') as f:
    json.dump(air_info, f, indent=2)


In [None]:
# ============================================================
# CELL 4: Dataset
# ============================================================

class VQVAEDataset(Dataset):
    """Map-style dataset that yields one block-id tensor per ``*.h5`` file.

    Args:
        data_dir: Directory searched (non-recursively) for ``*.h5`` files.
        seed: Accepted for interface compatibility; not used by this class.

    Raises:
        ValueError: If ``data_dir`` contains no ``*.h5`` files.
    """

    def __init__(self, data_dir: str, seed: int = 42):
        root = Path(data_dir)
        # Sorted so that index -> file is stable across runs.
        found = sorted(root.glob("*.h5"))
        if len(found) == 0:
            raise ValueError(f"No H5 files found in {data_dir}")
        self.data_dir = root
        self.h5_files = found
        print(f"Found {len(self.h5_files)} structures in {data_dir}")

    def __len__(self):
        """Number of structure files found."""
        return len(self.h5_files)

    def __getitem__(self, idx):
        """Read the first dataset stored in file ``idx`` as an int64 tensor."""
        path = self.h5_files[idx]
        with h5py.File(path, 'r') as f:
            first_key = list(f.keys())[0]
            voxels = f[first_key][:].astype(np.int64)
        return torch.from_numpy(voxels).long()


train_dataset = VQVAEDataset(DATA_DIR, seed=SEED)
val_dataset = VQVAEDataset(VAL_DIR, seed=SEED)

# Both loaders share these settings; only shuffling differs between splits.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

In [None]:
# ============================================================
# CELL 5: Mini VQ-VAE Model
# ============================================================

class ResidualBlock3D(nn.Module):
    """Two 3x3x3 conv + batch-norm layers wrapped in an identity skip connection.

    Kernel size 3 with padding 1 keeps both the channel count and the spatial
    size of the input unchanged.
    """

    def __init__(self, channels):
        super().__init__()
        # Layers are created in the same order as before (convs, then norms)
        # so parameter registration and random initialisation are unchanged.
        self.conv1 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(channels)
        self.bn2 = nn.BatchNorm3d(channels)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # Skip connection is added before the final activation.
        out = out + x
        return F.relu(out)


class MiniVQVAE(nn.Module):
    """Small 3D VQ-VAE over block-id grids with a frozen pretrained block embedding.

    Pipeline: block ids -> frozen embedding -> strided 3D conv encoder ->
    nearest-neighbour vector quantisation -> transposed-conv decoder ->
    per-voxel logits over the block vocabulary.

    Args:
        vocab_size: Number of block tokens (rows of the embedding table and
            number of output classes).
        embedding_dim: Width of each block embedding.
        hidden_dims: Channel widths of the encoder stages; each stage halves
            the spatial size (kernel 4, stride 2, padding 1) and the decoder
            mirrors them in reverse, doubling it again.
        latent_dim: Channels of the encoder output / width of a codebook vector.
        num_codes: Number of codebook vectors.
        commitment_cost: Weight of the commitment term in the VQ loss.
        pretrained_embeddings: Array-like (numpy array or tensor) of shape
            ``(vocab_size, embedding_dim)`` copied into the frozen embedding.

    Raises:
        ValueError: If ``pretrained_embeddings`` does not have shape
            ``(vocab_size, embedding_dim)``.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dims, latent_dim,
                 num_codes, commitment_cost, pretrained_embeddings):
        super().__init__()
        self.vocab_size = vocab_size
        self.latent_dim = latent_dim
        self.num_codes = num_codes
        self.commitment_cost = commitment_cost

        # Check the table before building any layer. This consumes no random
        # numbers, so layer initialisation below is unaffected.
        table = self._as_embedding_tensor(pretrained_embeddings, vocab_size, embedding_dim)

        # Block embeddings (frozen). The layer is built with the regular
        # constructor and then overwritten, which keeps the random-number
        # consumption (and therefore the init of later layers) unchanged.
        self.block_emb = nn.Embedding(vocab_size, embedding_dim)
        self.block_emb.weight.data.copy_(table)
        self.block_emb.weight.requires_grad = False

        # Encoder: one downsampling stage + residual block per hidden dim,
        # followed by a projection to latent_dim channels.
        enc_layers = []
        in_ch = embedding_dim
        for h_dim in hidden_dims:
            enc_layers.extend([
                nn.Conv3d(in_ch, h_dim, 4, stride=2, padding=1),
                nn.BatchNorm3d(h_dim),
                nn.ReLU(inplace=True),
                ResidualBlock3D(h_dim),
            ])
            in_ch = h_dim
        enc_layers.append(nn.Conv3d(in_ch, latent_dim, 3, padding=1))
        self.encoder = nn.Sequential(*enc_layers)

        # Codebook, initialised uniformly in [-1/num_codes, 1/num_codes].
        self.codebook = nn.Embedding(num_codes, latent_dim)
        self.codebook.weight.data.uniform_(-1/num_codes, 1/num_codes)

        # Decoder: mirror of the encoder, ending in per-voxel vocabulary logits.
        dec_layers = []
        in_ch = latent_dim
        for h_dim in reversed(hidden_dims):
            dec_layers.extend([
                ResidualBlock3D(in_ch),
                nn.ConvTranspose3d(in_ch, h_dim, 4, stride=2, padding=1),
                nn.BatchNorm3d(h_dim),
                nn.ReLU(inplace=True),
            ])
            in_ch = h_dim
        dec_layers.append(nn.Conv3d(in_ch, vocab_size, 3, padding=1))
        self.decoder = nn.Sequential(*dec_layers)

    @staticmethod
    def _as_embedding_tensor(pretrained_embeddings, vocab_size, embedding_dim):
        """Convert the pretrained table to a tensor and verify its shape.

        ``Tensor.copy_`` broadcasts, so without this check a table of shape
        ``(embedding_dim,)`` or ``(vocab_size, 1)`` would be silently repeated
        into every row/column of the frozen embedding.

        Raises:
            ValueError: If the shape is not ``(vocab_size, embedding_dim)``.
        """
        table = torch.as_tensor(pretrained_embeddings)
        expected = (vocab_size, embedding_dim)
        if tuple(table.shape) != expected:
            raise ValueError(
                f"pretrained_embeddings has shape {tuple(table.shape)}, "
                f"expected {expected} (vocab_size, embedding_dim)"
            )
        return table

    def quantize(self, z_e):
        """Replace every latent vector with its nearest codebook entry.

        Args:
            z_e: Encoder output of shape ``[B, C, D, H, W]`` with
                ``C == latent_dim``.

        Returns:
            Tuple ``(z_q, vq_loss, indices)``:
              - ``z_q``: quantised latents, ``[B, C, D, H, W]``, with
                straight-through gradients back to ``z_e``;
              - ``vq_loss``: codebook loss + ``commitment_cost`` * commitment loss;
              - ``indices``: chosen codebook index per position, ``[B, D, H, W]``.
        """
        # z_e: [B, C, D, H, W]
        z_e_perm = z_e.permute(0, 2, 3, 4, 1).contiguous()  # [B, D, H, W, C]
        flat = z_e_perm.view(-1, self.latent_dim)

        # Squared Euclidean distances via |a|^2 + |b|^2 - 2 a.b
        d = (flat**2).sum(1, keepdim=True) + \
            (self.codebook.weight**2).sum(1) - \
            2 * flat @ self.codebook.weight.t()

        indices = d.argmin(1)
        z_q_flat = self.codebook(indices)
        z_q_perm = z_q_flat.view(z_e_perm.shape)

        # Losses: the first term moves codebook vectors toward the encoder
        # output, the second keeps the encoder close to its chosen codes.
        codebook_loss = F.mse_loss(z_q_perm, z_e_perm.detach())
        commit_loss = F.mse_loss(z_e_perm, z_q_perm.detach())
        vq_loss = codebook_loss + self.commitment_cost * commit_loss

        # Straight-through estimator: the forward value is z_q, the gradient
        # flows to z_e as if quantisation were the identity.
        z_q_st = z_e_perm + (z_q_perm - z_e_perm).detach()
        z_q = z_q_st.permute(0, 4, 1, 2, 3).contiguous()

        return z_q, vq_loss, indices.view(z_e_perm.shape[:-1])

    def forward(self, block_ids):
        """Encode, quantise and decode a batch of block-id grids.

        Args:
            block_ids: Integer tensor ``[B, D, H, W]`` of block tokens.
                NOTE(review): each spatial size presumably has to be divisible
                by ``2 ** len(hidden_dims)`` for the logits to match the input
                size - confirm against the data.

        Returns:
            dict with 'logits' (``[B, vocab_size, D, H, W]`` when the sizes
            round-trip), 'vq_loss' and 'indices' (see :meth:`quantize`).
        """
        # Embed: [B, D, H, W] -> [B, D, H, W, emb] -> channels-first
        x = self.block_emb(block_ids)
        x = x.permute(0, 4, 1, 2, 3).contiguous()

        # Encode
        z_e = self.encoder(x)

        # Quantize
        z_q, vq_loss, indices = self.quantize(z_e)

        # Decode
        logits = self.decoder(z_q)

        return {'logits': logits, 'vq_loss': vq_loss, 'indices': indices}

    def compute_loss(self, block_ids, air_tokens_tensor):
        """Run a forward pass and return the training loss plus accuracy metrics.

        Air voxels are identified by membership in ``air_tokens_tensor`` rather
        than by ``block_ids != 0``: the vocabulary can contain several air
        variants, and token 0 is not air (the original fix note records it as
        UNKNOWN_BLOCK).

        Args:
            block_ids: Integer tensor ``[B, D, H, W]`` of target block tokens.
            air_tokens_tensor: 1-D integer tensor of token ids treated as air;
                it is moved to the device of ``block_ids`` here.

        Returns:
            dict with:
              - 'loss': reconstruction cross-entropy + VQ loss (differentiable);
              - 'recon_loss', 'vq_loss': the two terms of 'loss';
              - 'accuracy': fraction of all voxels predicted exactly;
              - 'air_accuracy' / 'struct_accuracy': accuracy restricted to air /
                non-air targets, 0.0 when the batch has no such voxel;
              - 'air_percentage': fraction of target voxels that are air;
              - 'indices': codebook indices from :meth:`quantize`.
        """
        out = self(block_ids)

        # Reconstruction loss over every voxel (class dimension moved last).
        logits = out['logits'].permute(0, 2, 3, 4, 1).contiguous()
        recon_loss = F.cross_entropy(logits.view(-1, self.vocab_size), block_ids.view(-1))

        total_loss = recon_loss + out['vq_loss']

        # Accuracy metrics (no gradients needed)
        with torch.no_grad():
            preds = logits.argmax(-1)
            targets_flat = block_ids.view(-1)
            preds_flat = preds.view(-1)

            # Overall accuracy
            correct = (preds_flat == targets_flat).float()
            acc = correct.mean()

            # Move air tokens to same device
            air_tokens_device = air_tokens_tensor.to(targets_flat.device)

            # Find air and non-air blocks using torch.isin
            is_air = torch.isin(targets_flat, air_tokens_device)
            is_structure = ~is_air

            # Air accuracy
            if is_air.sum() > 0:
                air_acc = correct[is_air].mean()
            else:
                air_acc = torch.tensor(0.0, device=block_ids.device)

            # Structure accuracy (non-air) - THE KEY METRIC!
            if is_structure.sum() > 0:
                struct_acc = correct[is_structure].mean()
            else:
                struct_acc = torch.tensor(0.0, device=block_ids.device)

            # Track air percentage for sanity check
            air_pct = is_air.float().mean()

        return {
            'loss': total_loss,
            'recon_loss': recon_loss,
            'vq_loss': out['vq_loss'],
            'accuracy': acc,
            'air_accuracy': air_acc,
            'struct_accuracy': struct_acc,
            'air_percentage': air_pct,
            'indices': out['indices'],
        }


# Cell-completion messages: they only print text and define nothing.
print("MiniVQVAE defined!")
print("BUG FIX: compute_loss now correctly identifies all air tokens")

In [None]:
# ============================================================
# CELL 6: Training Functions
# ============================================================

def train_epoch(model, loader, optimizer, scaler, device, air_tokens_tensor):
    """Run one training pass with autocast, loss scaling and gradient accumulation.

    Args:
        model: Module exposing ``compute_loss(batch, air_tokens_tensor)``.
        loader: Iterable of block-id batches; ``len(loader)`` must be defined.
        optimizer: Optimizer stepped every ``GRAD_ACCUM_STEPS`` batches.
        scaler: Gradient scaler used for backward and optimizer steps.
        device: Device each batch is moved to.
        air_tokens_tensor: Token ids treated as air, forwarded to the model.

    Returns:
        dict of per-batch averages keyed by 'loss', 'recon', 'vq', 'acc',
        'air_acc', 'struct_acc' and 'air_pct'.
    """
    model.train()

    # Running-total key -> key in the dict returned by model.compute_loss.
    tracked = {
        'loss': 'loss',
        'recon': 'recon_loss',
        'vq': 'vq_loss',
        'acc': 'accuracy',
        'air_acc': 'air_accuracy',
        'struct_acc': 'struct_accuracy',
        'air_pct': 'air_percentage',
    }
    totals = {key: 0 for key in tracked}
    num_batches = 0

    def apply_update():
        # Unscale first so clipping sees the true gradient magnitudes.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    optimizer.zero_grad()

    for step, batch in enumerate(tqdm(loader, desc="Train", leave=False), start=1):
        batch = batch.to(device)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            out = model.compute_loss(batch, air_tokens_tensor)
            # Divide so that the accumulated gradient is an average.
            scaled_loss = out['loss'] / GRAD_ACCUM_STEPS

        scaler.scale(scaled_loss).backward()

        # Step optimizer every GRAD_ACCUM_STEPS batches
        if step % GRAD_ACCUM_STEPS == 0:
            apply_update()

        for key, source in tracked.items():
            totals[key] += out[source].item()
        num_batches += 1

    # Flush gradients left over when the loader length is not a multiple of
    # GRAD_ACCUM_STEPS.
    if len(loader) % GRAD_ACCUM_STEPS != 0:
        apply_update()

    return {key: total / num_batches for key, total in totals.items()}


@torch.no_grad()
def validate(model, loader, device, air_tokens_tensor):
    """Evaluate the model on ``loader`` without gradient tracking.

    Args:
        model: Module exposing ``compute_loss(batch, air_tokens_tensor)``.
        loader: Iterable of block-id batches.
        device: Device each batch is moved to.
        air_tokens_tensor: Token ids treated as air, forwarded to the model.

    Returns:
        dict of per-batch averages keyed by 'loss', 'recon', 'acc', 'air_acc',
        'struct_acc' and 'air_pct'.
    """
    model.eval()

    # Running-total key -> key in the dict returned by model.compute_loss.
    tracked = {
        'loss': 'loss',
        'recon': 'recon_loss',
        'acc': 'accuracy',
        'air_acc': 'air_accuracy',
        'struct_acc': 'struct_accuracy',
        'air_pct': 'air_percentage',
    }
    totals = {key: 0 for key in tracked}
    num_batches = 0

    for batch in tqdm(loader, desc="Val", leave=False):
        batch = batch.to(device)

        with torch.amp.autocast('cuda', enabled=USE_AMP):
            out = model.compute_loss(batch, air_tokens_tensor)

        for key, source in tracked.items():
            totals[key] += out[source].item()
        num_batches += 1

    return {key: total / num_batches for key, total in totals.items()}


# Cell-completion message only; nothing is trained until run_experiment is called.
print("Training functions defined!")

In [None]:
# ============================================================
# CELL 7: Run Experiment for One Embedding Type
# ============================================================

def run_experiment(name, embeddings, air_tokens_tensor):
    """Train a fresh MiniVQVAE on one embedding table and summarise the run.

    Args:
        name: Label of the embedding type, used in log output and the result.
        embeddings: Pretrained embedding table handed to the model.
        air_tokens_tensor: Token ids treated as air when computing metrics.

    Returns:
        dict with the final-epoch and best metrics, the training time in
        seconds and the full per-epoch ``history``.
    """
    print(f"\n{'='*60}")
    print(f"Training with {name} embeddings")
    print(f"{'='*60}")
    print(f"Air tokens: {air_tokens_tensor.tolist()}")

    has_cuda = torch.cuda.is_available()

    # Release memory cached by a previous experiment before allocating again.
    if has_cuda:
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    # Same seed for every experiment so runs are comparable.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(SEED)

    model = MiniVQVAE(
        vocab_size=VOCAB_SIZE,
        embedding_dim=BLOCK_EMBEDDING_DIM,
        hidden_dims=HIDDEN_DIMS,
        latent_dim=LATENT_DIM,
        num_codes=NUM_CODEBOOK_ENTRIES,
        commitment_cost=COMMITMENT_COST,
        pretrained_embeddings=embeddings,
    )
    model = model.to(device)

    # Only parameters that still require gradients are counted and optimised
    # (the block embedding is frozen inside the model).
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {trainable:,}")

    if has_cuda:
        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        print(f"GPU memory: {allocated:.2f} GB allocated, {reserved:.2f} GB reserved")

    optimizer = optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=LEARNING_RATE,
    )
    scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP)

    # history['<split>_<metric>'] holds one value per epoch.
    metric_keys = ('loss', 'acc', 'air_acc', 'struct_acc', 'air_pct')
    history = {
        f'{split}_{key}': []
        for split in ('train', 'val')
        for key in metric_keys
    }

    start_time = time.time()

    for epoch in range(EPOCHS):
        train_metrics = train_epoch(model, train_loader, optimizer, scaler, device, air_tokens_tensor)
        val_metrics = validate(model, val_loader, device, air_tokens_tensor)

        for split, split_metrics in (('train', train_metrics), ('val', val_metrics)):
            for key in metric_keys:
                history[f'{split}_{key}'].append(split_metrics[key])

        # One progress line per epoch, including structure accuracy.
        print(f"Epoch {epoch+1:2d}/{EPOCHS} | "
              f"Loss: {train_metrics['loss']:.3f} | "
              f"Acc: {train_metrics['acc']:.1%} | "
              f"Struct: {train_metrics['struct_acc']:.1%} | "
              f"Air: {train_metrics['air_acc']:.1%} | "
              f"Val Struct: {val_metrics['struct_acc']:.1%} | "
              f"Air%: {val_metrics['air_pct']:.1%}")

    train_time = time.time() - start_time
    print(f"Training time: {train_time/60:.1f} minutes")

    # Sanity check: if both accuracies coincide, air voxels were probably not
    # separated from structure voxels.
    accuracy_gap = abs(history['val_acc'][-1] - history['val_struct_acc'][-1])
    if accuracy_gap < 0.001:
        print("⚠️  WARNING: Overall and Structure accuracy are nearly identical!")
        print("    Air detection may still be broken.")
    else:
        print(f"✓ Structure accuracy ({history['val_struct_acc'][-1]:.1%}) differs from overall ({history['val_acc'][-1]:.1%})")

    # Final-epoch values first, then aggregates over all epochs.
    final_metrics = {'name': name}
    for metric in ('loss', 'acc', 'struct_acc', 'air_acc'):
        for split in ('train', 'val'):
            final_metrics[f'final_{split}_{metric}'] = history[f'{split}_{metric}'][-1]
    final_metrics.update({
        'avg_air_pct': np.mean(history['val_air_pct']),
        'best_val_loss': min(history['val_loss']),
        'best_val_acc': max(history['val_acc']),
        'best_val_struct_acc': max(history['val_struct_acc']),
        'training_time': train_time,
        'history': history,
    })

    # Drop the references so the GPU memory can be reclaimed.
    del model, optimizer, scaler
    if has_cuda:
        torch.cuda.empty_cache()

    return final_metrics


# Cell-completion message only; the experiment itself is started in the next cell.
print("Experiment function defined!")

In [None]:
# ============================================================
# CELL 8: Run All Experiments
# ============================================================

for banner_line in ("="*60, "PHASE 0: VQ-VAE EMBEDDING VALIDATION", "="*60):
    print(banner_line)
print(f"\nAir tokens being used: {AIR_TOKENS_TENSOR.tolist()}")
print("These will be EXCLUDED from structure accuracy calculation.\n")

# One result dict per embedding type, keyed by its label.
all_results = {}
for name in EMBEDDINGS:
    all_results[name] = run_experiment(name, EMBEDDINGS[name], AIR_TOKENS_TENSOR)

print("\n" + "="*60)
print("ALL EXPERIMENTS COMPLETE")
print("="*60)

In [None]:
# ============================================================
# CELL 9: Load Previous Results and Compare
# ============================================================

print("\n" + "="*70)
print("LOADING PREVIOUS RESULTS")
print("="*70)

# Earlier V1 / V2 / Random runs are read back from the saved JSON file.
with open(PREVIOUS_RESULTS_PATH, 'r') as f:
    previous_results = json.load(f)

print(f"Loaded previous results: {list(previous_results.keys())}")

# Single mapping holding the previous runs plus the new V3 run.
all_results_combined = dict(previous_results)
all_results_combined['V3'] = all_results['V3']

# Derive the best-epoch metrics of each previous run from its history.
for name in ('V1', 'V2', 'Random'):
    entry = previous_results[name]
    hist = entry['history']
    entry.update(
        best_val_loss=min(hist['val_loss']),
        best_val_acc=max(hist['val_acc']),
        best_val_struct_acc=max(hist['val_struct_acc']),
    )
    print(f"  {name}: {previous_results[name]['best_val_struct_acc']:.1%} structure accuracy")

print("\n" + "="*70)
print("RESULTS COMPARISON")
print("="*70)

# Air share of the data, taken from the V3 run.
avg_air_pct = all_results['V3']['avg_air_pct']
print(f"\nAverage air percentage in data: {avg_air_pct:.1%}")
print("This is why STRUCTURE accuracy (non-air) is the key metric!\n")

# Comparison table: header, rule, then one row per embedding type.
print("{:<10} {:>12} {:>12} {:>12} {:>12}".format(
    "Embeddings", "Val Loss", "Overall Acc", "STRUCT Acc", "Time"))
print("-"*60)

row_format = "{:<10} {:>12.4f} {:>12.1%} {:>12.1%} {:>12}"
for name in ['V3', 'V1', 'V2', 'Random']:
    is_new = name == 'V3'
    r = all_results['V3'] if is_new else previous_results[name]
    time_str = f"{r['training_time']/60:.1f}m"
    print(row_format.format(
        name + (" (NEW)" if is_new else ""),
        r['best_val_loss'],
        r['best_val_acc'],
        r['best_val_struct_acc'],
        time_str,
    ))

print("\n" + "="*70)
print("V3 IMPROVEMENT OVER BASELINES")
print("="*70)

v3_struct = all_results['V3']['best_val_struct_acc']
v1_struct = previous_results['V1']['best_val_struct_acc']
v2_struct = previous_results['V2']['best_val_struct_acc']
random_struct = previous_results['Random']['best_val_struct_acc']


def _relative_gain(baseline):
    """Relative change of V3 structure accuracy versus ``baseline``, in percent."""
    return (v3_struct - baseline) / baseline * 100


v3_vs_random = _relative_gain(random_struct)
v3_vs_v1 = _relative_gain(v1_struct)
v3_vs_v2 = _relative_gain(v2_struct)

print(f"\nV3 Structure Accuracy: {v3_struct:.1%}")
print(f"  vs Random ({random_struct:.1%}): {v3_vs_random:+.1f}%")
print(f"  vs V1 ({v1_struct:.1%}):     {v3_vs_v1:+.1f}%")
print(f"  vs V2 ({v2_struct:.1%}):     {v3_vs_v2:+.1f}%")

print("\n" + "="*70)
print("CONCLUSION")
print("="*70)

# Pick the verdict first, then print it line by line.
if v3_struct > v1_struct:
    verdict_lines = [
        f"\n✓ V3 ({v3_struct:.1%}) BEATS V1 ({v1_struct:.1%})!",
        "  Compositional embeddings are the best approach.",
        "  Use V3 embeddings for VQ-VAE training.",
    ]
elif v3_struct > v2_struct:
    verdict_lines = [
        f"\n~ V3 ({v3_struct:.1%}) beats V2 ({v2_struct:.1%}) but not V1 ({v1_struct:.1%}).",
        "  V1 skip-gram is still the best. Consider hybrid approach.",
    ]
else:
    verdict_lines = [
        f"\n✗ V3 ({v3_struct:.1%}) is worse than V1 and V2.",
        "  Compositional approach didn't help. Stick with V1.",
    ]
for verdict_line in verdict_lines:
    print(verdict_line)

In [None]:
# ============================================================
# CELL 10: Plot All Results (V1, V2, V3, Random)
# ============================================================

# Plot comparison of all embedding types: five per-epoch curve panels plus a
# bar chart of the best structure accuracy.
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

colors = {'V1': 'blue', 'V2': 'green', 'V3': 'purple', 'Random': 'red'}
names = ['V1', 'V2', 'V3', 'Random']


def _result_for(name):
    """Return the result dict of one embedding type (V3 is the new run)."""
    if name == 'V3':
        return all_results['V3']
    return previous_results[name]


def _plot_metric(ax, key, title, ylabel, ylim=None, **title_kwargs):
    """Draw one validation-history curve per embedding type on ``ax``."""
    for name in names:
        ax.plot(_result_for(name)['history'][key], label=name,
                color=colors[name], linewidth=2)
    ax.set_title(title, fontsize=12, **title_kwargs)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.legend()
    ax.grid(True, alpha=0.3)


# Validation Loss
_plot_metric(axes[0, 0], 'val_loss',
             'Validation Loss by Embedding Type', 'Validation Loss')

# Overall Accuracy
_plot_metric(axes[0, 1], 'val_acc',
             'Overall Accuracy (includes ~80% air)', 'Validation Accuracy')

# Structure Accuracy (KEY METRIC)
_plot_metric(axes[0, 2], 'val_struct_acc',
             'STRUCTURE Accuracy (KEY METRIC)', 'Structure Accuracy (non-air)',
             fontweight='bold')

# Air Accuracy
_plot_metric(axes[1, 0], 'val_air_acc', 'Air Block Accuracy', 'Air Accuracy')

# Air Percentage (sanity check). The window defaults to 0.78-0.82 but widens
# to include the data, so curves outside that range are no longer hidden.
air_values = [v for name in names for v in _result_for(name)['history']['val_air_pct']]
air_ylim = (0.78, 0.82)
if air_values:
    air_ylim = (min(air_ylim[0], min(air_values) - 0.01),
                max(air_ylim[1], max(air_values) + 0.01))
_plot_metric(axes[1, 1], 'val_air_pct',
             'Air Block % in Data (~should be constant)', 'Air Percentage',
             ylim=air_ylim)

# Bar chart comparison
ax = axes[1, 2]
struct_accs = [_result_for(name)['best_val_struct_acc'] for name in names]
bar_colors = [colors[n] for n in names]
bars = ax.bar(names, struct_accs, color=bar_colors, edgecolor='black', linewidth=1.5)
ax.set_title('Best Structure Accuracy Comparison', fontsize=12, fontweight='bold')
ax.set_ylabel('Accuracy')
# Keep the 0.55 ceiling when it fits; otherwise leave headroom so the tallest
# bar and its value label are not clipped.
ax.set_ylim(0, max(0.55, max(struct_accs) + 0.08))

# Add value labels on bars
for bar, acc in zip(bars, struct_accs):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2, height + 0.01,
            f'{acc:.1%}', ha='center', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/embedding_comparison_all.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"Saved plot to {OUTPUT_DIR}/embedding_comparison_all.png")


In [None]:
# ============================================================
# CELL 11: Save Final Results
# ============================================================

# Get V3 results from all_results
v3_result = all_results['V3']
v3_history = v3_result["history"]

# Compact summary: last-epoch validation metrics plus best structure accuracy.
v3_summary_entry = {
    f"final_{key}": v3_history[key][-1]
    for key in ("val_loss", "val_acc", "val_struct_acc", "val_air_acc")
}
v3_summary_entry["best_val_struct_acc"] = v3_result["best_val_struct_acc"]
v3_summary_entry["training_time_minutes"] = v3_result["training_time"] / 60
v3_summary = {"V3": v3_summary_entry}

with open(f"{OUTPUT_DIR}/v3_validation_summary.json", 'w') as f:
    json.dump(v3_summary, f, indent=2)

# Full history; every value is cast to a plain float so it is JSON-serialisable.
v3_full_entry = {
    'history': {key: list(map(float, series)) for key, series in v3_history.items()},
    'training_time': v3_result['training_time'],
}
for best_key in ('best_val_loss', 'best_val_acc', 'best_val_struct_acc'):
    v3_full_entry[best_key] = float(v3_result[best_key])
v3_full = {'V3': v3_full_entry}

with open(f"{OUTPUT_DIR}/v3_validation_full.json", 'w') as f:
    json.dump(v3_full, f, indent=2)

# Print comparison using loaded previous results (not hardcoded)
print("="*70)
print("FINAL COMPARISON - Structure Accuracy (non-air blocks)")
print("="*70)
print(f"  V1:     {previous_results['V1']['best_val_struct_acc']:.1%}  (previous)")
print(f"  V2:     {previous_results['V2']['best_val_struct_acc']:.1%}  (previous)")
print(f"  V3:     {v3_result['best_val_struct_acc']:.1%}  ← NEW (best)")
print(f"  Random: {previous_results['Random']['best_val_struct_acc']:.1%}  (previous)")
print()

# The verdict compares BEST-epoch accuracy, not last-epoch accuracy.
v3_best = v3_result['best_val_struct_acc']
v1_best = previous_results['V1']['best_val_struct_acc']
v2_best = previous_results['V2']['best_val_struct_acc']

if v3_best > v1_best:
    verdict = "✓ V3 BEATS V1! Use V3 embeddings for VQ-VAE."
elif v3_best > v2_best:
    verdict = "~ V3 is between V1 and V2. Could use either V1 or V3."
else:
    verdict = "✗ V3 is worse than V1. Stick with V1 embeddings."
print(verdict)


In [None]:
# ============================================================
# CELL 12: Final Summary
# ============================================================

print("\n" + "="*70)
print("PHASE 0 COMPLETE: VQ-VAE EMBEDDING VALIDATION")
print("="*70)

print("\nQuestion: Do Block2Vec embeddings help VQ-VAE reconstruct STRUCTURES?")
print(f"\nAir tokens excluded: {sorted(AIR_TOKENS)}")
print(f"Average air percentage: {avg_air_pct:.1%}")
print("\nNOTE: Overall accuracy is ~{:.0%} just from predicting air correctly.".format(avg_air_pct))
print("      STRUCTURE accuracy is the true measure of reconstruction quality!")

print("\nResults:")
print("\n{:<12} {:>12} {:>15} {:>15} {:>12}".format(
    "", "Val Loss", "Overall Acc", "★STRUCT Acc★", "Air Acc"))
print("-"*70)

# One row per embedding type; the V3 row comes from this notebook's run, the
# others from the previously saved results.
summary_row = "{:<12} {:>12.4f} {:>15.1%} {:>15.1%} {:>12.1%}"
for name in ['V3', 'V1', 'V2', 'Random']:
    is_v3 = name == 'V3'
    r = all_results['V3'] if is_v3 else previous_results[name]
    label = "V3 (NEW)" if is_v3 else name
    print(summary_row.format(
        label,
        r['best_val_loss'],
        r['best_val_acc'],
        r['best_val_struct_acc'],
        r['history']['val_air_acc'][-1],
    ))

# Final verdict
print("\n" + "="*70)
print("CONCLUSION")
print("="*70)

v3_struct = all_results['V3']['best_val_struct_acc']
v1_struct = previous_results['V1']['best_val_struct_acc']
v2_struct = previous_results['V2']['best_val_struct_acc']
random_struct = previous_results['Random']['best_val_struct_acc']

# Relative change of V3 versus each baseline, in percent.
v3_vs_random = (v3_struct - random_struct) / random_struct * 100
v3_vs_v1 = (v3_struct - v1_struct) / v1_struct * 100

print(f"\nV3 Structure Accuracy: {v3_struct:.1%}")
print(f"  Improvement over Random ({random_struct:.1%}): {v3_vs_random:+.1f}%")
print(f"  Comparison to V1 ({v1_struct:.1%}):        {v3_vs_v1:+.1f}%")

if v3_struct > v1_struct:
    closing_message = "\n✓ V3 BEATS V1! Compositional embeddings are the best!"
elif v3_struct > v2_struct:
    closing_message = "\n~ V3 is between V1 and V2. V1 still slightly better."
else:
    closing_message = "\n✗ V3 underperforms. Stick with V1 embeddings."
print(closing_message)

print("\n" + "="*70)
print("Files saved:")
print(f"  - {OUTPUT_DIR}/embedding_comparison_all.png")
print(f"  - {OUTPUT_DIR}/v3_validation_summary.json")
print(f"  - {OUTPUT_DIR}/v3_validation_full.json")
print(f"  - {OUTPUT_DIR}/air_tokens_used.json")
print("="*70)