# Block2Vec V3: Compositional Embeddings

## What This Notebook Does

This notebook trains **version 3** of our Block2Vec model, a neural network that learns numerical representations ("embeddings") for Minecraft blocks. V3 keeps the Skip-gram training objective but builds each block's embedding in a fundamentally different way, to solve a problem we discovered in V1 and V2.

## Why V3? What Went Wrong With V1 and V2?

### The Problem We're Solving

We want blocks that serve the **same function** to have similar embeddings. For example:
- `oak_planks` and `spruce_planks` should be similar (both are wooden planks)
- `oak_stairs` and `stone_stairs` should be similar (both are stairs)
- `white_wool` and `red_wool` should be similar (both are wool)

This is important because when our VQ-VAE (Phase 3) learns to generate structures, it should understand that swapping `oak_planks` for `spruce_planks` is a minor change, while swapping `oak_planks` for `lava` is a major change.

### What V1 Did (Skip-gram Only)

V1 used the same approach as Word2Vec: **"You shall know a block by its neighbors."**

The idea is that blocks appearing next to similar neighbors should have similar embeddings. This worked great for some blocks:

```
diamond_ore neighbors: [stone, stone, stone, deepslate, air, stone]
emerald_ore neighbors: [stone, stone, deepslate, stone, air, stone]
→ Nearly identical neighbors! So diamond_ore ≈ emerald_ore ✓
```

But it **completely failed** for wood planks:

```
oak_planks neighbors: [oak_planks, oak_stairs, oak_log, dark_oak_planks, air]
spruce_planks neighbors: [spruce_planks, spruce_stairs, spruce_log, air]
→ Almost NO overlap! oak_planks and spruce_planks never "meet"
```

**The result:** Ores clustered beautifully (~80% coherence), but planks were scattered (~20% coherence).
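To make "overlap" concrete, here is a minimal sketch that scores two neighbor lists with Jaccard similarity on their sets (shared neighbor types divided by all neighbor types). The notebook does not say how overlap was measured, so this metric is an illustrative assumption; the lists are the toy examples from this section.

```python
def neighbor_overlap(a: list[str], b: list[str]) -> float:
    """Jaccard similarity of two neighbor lists, treated as sets of block types."""
    set_a, set_b = set(a), set(b)
    if not set_a and not set_b:
        return 0.0
    return len(set_a & set_b) / len(set_a | set_b)


diamond = ["stone", "stone", "stone", "deepslate", "air", "stone"]
emerald = ["stone", "stone", "deepslate", "stone", "air", "stone"]
oak = ["oak_planks", "oak_stairs", "oak_log", "dark_oak_planks", "air"]
spruce = ["spruce_planks", "spruce_stairs", "spruce_log", "air"]

print(f"ores:   {neighbor_overlap(diamond, emerald):.3f}")  # ores:   1.000
print(f"planks: {neighbor_overlap(oak, spruce):.3f}")       # planks: 0.125
```

Only `air` is shared between the two plank neighborhoods, so Skip-gram has almost nothing to pull them together with.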

### What V2 Did (Hybrid Skip-gram + CBOW)

V2 tried to fix this by adding **CBOW** (Continuous Bag of Words), which learns in the opposite direction:
- Skip-gram: Given center block, predict neighbors
- CBOW: Given neighbors, predict center block
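Both objectives read the same neighborhoods but turn them into different training examples. A minimal sketch, assuming a 3D grid of block tokens and the 6 face-adjacent neighbors; the helper `make_examples` and the toy grid are hypothetical, not the V1/V2 data pipeline.

```python
import numpy as np

# The 6 face-adjacent offsets (dx, dy, dz) around a voxel.
OFFSETS = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]


def make_examples(grid: np.ndarray):
    """Return (skipgram_pairs, cbow_examples) for every interior voxel.

    skipgram_pairs: (center, neighbor) tuples, one per neighbor.
    cbow_examples:  (neighbors, center) tuples, one per voxel.
    """
    skipgram_pairs, cbow_examples = [], []
    X, Y, Z = grid.shape
    for x in range(1, X - 1):
        for y in range(1, Y - 1):
            for z in range(1, Z - 1):
                center = int(grid[x, y, z])
                neighbors = [int(grid[x + dx, y + dy, z + dz]) for dx, dy, dz in OFFSETS]
                # Skip-gram: the center predicts each neighbor separately.
                skipgram_pairs.extend((center, n) for n in neighbors)
                # CBOW: all neighbors together predict the center.
                cbow_examples.append((neighbors, center))
    return skipgram_pairs, cbow_examples


grid = np.arange(27).reshape(3, 3, 3)  # toy grid of token ids 0..26
sg, cbow = make_examples(grid)
print(len(sg), len(cbow))  # 6 1
print(cbow[0])             # ([22, 4, 16, 10, 14, 12], 13)
```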

The theory was that CBOW would learn: "If neighbors are [planks, stairs, log, air], the center is probably some kind of plank." This should connect oak_planks and spruce_planks.

**But it failed even worse!**

| Metric | V1 | V2 |
|--------|----|----|  
| Overall Coherence | 20.4% | 16.1% |
| Planks Coherence | ~20% | **2.5%** |
| Ores Coherence | ~80% | 44% |

### Why V2 Failed: The Chicken-and-Egg Problem

Here's the subtle issue:

For CBOW to learn that `oak_planks ≈ spruce_planks`, it needs:
- `oak_stairs ≈ spruce_stairs` (so their contexts look similar)
- `oak_log ≈ spruce_log` (so their contexts look similar)

But `oak_stairs` and `spruce_stairs` face the same problem! Their neighbors don't overlap either.

**It's a chicken-and-egg problem.** The model can't learn planks are similar without stairs being similar, and can't learn stairs are similar without planks being similar.

### The Root Cause

The fundamental issue is that **Minecraft builders don't mix wood types**. A house made of oak uses oak planks, oak stairs, oak logs, oak doors. A house made of spruce uses all spruce. There's no "bridge" connecting oak and spruce in the training data.

This is a **problem with the training data**, not with the algorithm. No amount of hyperparameter tuning will fix it.

---

# Part 1: The V3 Solution - Compositional Embeddings

## The Key Insight

Instead of learning one embedding per block and hoping similar blocks become similar, we **force** the structure we want by decomposing blocks into components:

```
oak_planks    = concat(material(oak),    shape(planks), property(solid))
spruce_planks = concat(material(spruce), shape(planks), property(solid))
```

Since `oak_planks` and `spruce_planks` share the **same shape and property embeddings**, those parts of their vectors are identical by construction: the two blocks can differ only in their material slice.

## How It Works

Each block embedding is composed of three parts:

### 1. Material Embedding (16 dimensions)
What the block is made of: oak, spruce, stone, iron, diamond, white, red, etc.

### 2. Shape Embedding (16 dimensions)
The geometric form: planks, stairs, slab, fence, door, block, ore, wool, etc.

### 3. Property Embedding (8 dimensions)
Functional characteristics: solid, transparent, light_emitting, interactable, etc.
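Assembling a block's vector from the three parts can be sketched with toy lookup tables. This assumes, as the model later in the notebook does, that a block's property vector is the sum of the vectors for each property it has (a multi-hot row times the property table) and that the three parts are concatenated. The component vocabularies and random values are illustrative only.

```python
import numpy as np

MATERIAL_DIM, SHAPE_DIM, PROPERTY_DIM = 16, 16, 8
rng = np.random.default_rng(0)

materials = ["_none_", "oak", "spruce"]
shapes = ["planks", "stairs"]
properties = ["interactable", "light_emitting", "solid", "transparent"]

material_table = rng.uniform(-0.1, 0.1, (len(materials), MATERIAL_DIM))
shape_table = rng.uniform(-0.1, 0.1, (len(shapes), SHAPE_DIM))
property_table = rng.uniform(-0.1, 0.1, (len(properties), PROPERTY_DIM))


def compose(material: str, shape: str, props: list[str]) -> np.ndarray:
    """Concatenate material, shape and summed property vectors into one 40-dim vector."""
    multi_hot = np.zeros(len(properties))
    for p in props:
        multi_hot[properties.index(p)] = 1.0
    return np.concatenate([
        material_table[materials.index(material)],
        shape_table[shapes.index(shape)],
        multi_hot @ property_table,  # sum of the rows for each active property
    ])


oak_planks = compose("oak", "planks", ["solid"])
spruce_planks = compose("spruce", "planks", ["solid"])
print(oak_planks.shape)  # (40,)
# Only the first 16 dims (material) can differ; dims 16..39 are identical.
print(np.array_equal(oak_planks[16:], spruce_planks[16:]))  # True
```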

## The Mathematical Guarantee

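```
similarity(oak_planks, spruce_planks)
= concat(oak, planks, solid) · concat(spruce, planks, solid)
= (oak · spruce) + (planks · planks) + (solid · solid)
= (oak · spruce) + ‖planks‖²         + ‖solid‖²
≈ (small)        + (≈ 1.0)           + (≈ 1.0)      # if each component has roughly unit norm
= HIGH SIMILARITY! ✓
```

Here "similarity" is the dot product. Because a V3 embedding is a concatenation, its dot product splits into one term per component. The shape and property terms are squared norms, which are never negative, so only the material term can pull the two planks apart. Cosine similarity additionally divides by the vector norms, so "high" assumes the material vectors do not dominate in magnitude.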
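The per-component split holds for any concatenated vectors, which a few lines of NumPy can confirm. The random vectors below are stand-ins for learned embeddings, not trained values.

```python
import numpy as np

rng = np.random.default_rng(42)
# Stand-ins for learned component vectors (16 + 16 + 8 dims).
oak, spruce = rng.normal(size=16), rng.normal(size=16)
planks = rng.normal(size=16)
solid = rng.normal(size=8)

oak_planks = np.concatenate([oak, planks, solid])
spruce_planks = np.concatenate([spruce, planks, solid])

full = oak_planks @ spruce_planks
by_component = oak @ spruce + planks @ planks + solid @ solid
print(np.isclose(full, by_component))  # True

# The shared terms are squared norms, so they can only add to the score.
print(planks @ planks >= 0 and solid @ solid >= 0)  # True
cosine = full / (np.linalg.norm(oak_planks) * np.linalg.norm(spruce_planks))
print(f"cosine similarity: {cosine:.3f}")
```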

## Expected Improvements

| Metric | V1 | V2 | V3 (Expected) |
|--------|----|----|---------------|
| Overall Coherence | 20.4% | 16.1% | **>50%** |
| Planks Coherence | ~20% | 2.5% | **>80%** |
| Stairs Coherence | ~60% | ~44% | **>80%** |
| Wool Coherence | ? | ? | **>80%** |

---

# Part 2: Setup and Configuration

Let's start by importing libraries and setting up our environment.

In [None]:
# ============================================================
# CELL 1: Imports and Setup
# ============================================================
# These are the libraries we need:
# - torch: PyTorch deep learning framework
# - numpy: Numerical operations on arrays
# - h5py: Reading HDF5 files (our training data)
# - json: Reading/writing configuration files
# - matplotlib: Creating visualizations
# - sklearn: For t-SNE dimensionality reduction and similarity metrics

import json
import os
import random
import re
import time
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Dict, List, Tuple, Iterator

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import cosine_similarity
from torch.utils.data import Dataset, DataLoader, IterableDataset
from tqdm.notebook import tqdm

# Check if GPU is available - this is why we're using Kaggle!
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# ============================================================
# CELL 2: Configuration
# ============================================================
# These are our HYPERPARAMETERS - values we choose before training.
#
# KEY DIFFERENCE FROM V1/V2:
# Instead of one 32-dim embedding per block, we have:
# - 16 dims for MATERIAL (what it's made of)
# - 16 dims for SHAPE (geometric form)
# - 8 dims for PROPERTIES (functional characteristics)
# Total: 40 dimensions

# === Data Paths ===
DATA_DIR = "/kaggle/input/minecraft-schematics/minecraft_splits/splits/train"
VOCAB_PATH = "/kaggle/input/minecraft-schematics/tok2block.json"
OUTPUT_DIR = "/kaggle/working"

# === Compositional Embedding Dimensions ===
MATERIAL_DIM = 16   # Captures: oak, spruce, stone, iron, white, red, etc.
SHAPE_DIM = 16      # Captures: planks, stairs, slab, fence, door, etc.
PROPERTY_DIM = 8    # Captures: solid, transparent, light_emitting, etc.
TOTAL_EMBEDDING_DIM = MATERIAL_DIM + SHAPE_DIM + PROPERTY_DIM  # 40

# === Training Hyperparameters ===
# NOTE: Using fewer epochs than V1/V2 because:
# 1. V3's center-block embeddings have far fewer rows (~425 components vs
#    3717 blocks); only the context embeddings are still one per block
# 2. Compositional structure constrains the solution space
# 3. V1 peaked at epoch 10, V2 at epoch 25
# 4. This is a validation run - we can train longer if it works
EPOCHS = 15              # Reduced from 50 - enough to validate approach
BATCH_SIZE = 4096        # Examples per gradient update
LEARNING_RATE = 0.001    # Step size for optimization
NUM_NEGATIVE = 10        # Negative samples per positive pair
SUBSAMPLE_THRESHOLD = 0.001  # Subsample blocks more frequent than this

# === Other ===
SEED = 42  # Random seed for reproducibility

# Set random seeds
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

print("V3 Compositional Embedding Configuration:")
print(f"  Material dimensions: {MATERIAL_DIM}")
print(f"  Shape dimensions: {SHAPE_DIM}")
print(f"  Property dimensions: {PROPERTY_DIM}")
print(f"  Total embedding size: {TOTAL_EMBEDDING_DIM}")
print(f"\nTraining:")
print(f"  Epochs: {EPOCHS} (reduced for validation)")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {LEARNING_RATE}")

---

# Part 3: Block Decomposition

## The Core Innovation of V3

We need to decompose each Minecraft block name into its components:

```
"minecraft:oak_planks" → material=oak, shape=planks, properties=[solid]
"minecraft:white_wool" → material=white, shape=wool, properties=[solid]
"minecraft:glass"      → material=None, shape=glass, properties=[transparent]
"minecraft:torch"      → material=None, shape=torch, properties=[light_emitting, solid]
```

This decomposition is done automatically using pattern matching on block names.

## Material Categories

We define several categories of materials:

- **Wood types**: oak, spruce, birch, jungle, acacia, dark_oak, etc.
- **Stone types**: stone, cobblestone, granite, diorite, andesite, etc.
- **Colors**: white, orange, red, blue, green, etc. (for wool, concrete, terracotta)
- **Metals**: iron, gold, copper, netherite
- **Minerals**: diamond, emerald, lapis, redstone, coal, quartz

## Shape Categories

Shapes are identified by suffixes in block names:

- `_planks` → shape is "planks"
- `_stairs` → shape is "stairs"
- `_slab` → shape is "slab"
- `_fence` → shape is "fence"
- `_door` → shape is "door"
- etc.

Some blocks ARE their own shape (like `air`, `water`, `dirt`, `torch`).
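Both rules come down to "longest match wins". Materials such as `stone` and `stone_brick` overlap as prefixes, and suffixes such as `_wall` and `_wall_torch` overlap at the end, so checking candidates in arbitrary order (for example, iterating a Python `set`) can pick the shorter one. A minimal sketch using small, illustrative subsets of the material and suffix lists used in this notebook:

```python
MATERIALS = {"stone", "stone_brick", "oak", "dark_oak", "red", "red_nether_brick"}
SUFFIXES = {"_wall": "wall", "_wall_torch": "wall_torch", "_stairs": "stairs", "_planks": "planks"}


def match_material(name: str):
    """Longest material that equals the name or prefixes it followed by '_'."""
    for m in sorted(MATERIALS, key=lambda m: (-len(m), m)):
        if name == m or name.startswith(m + "_"):
            return m
    return None


def match_shape(name: str) -> str:
    """Shape of the longest matching suffix, defaulting to 'block'."""
    for suffix in sorted(SUFFIXES, key=lambda s: (-len(s), s)):
        if name.endswith(suffix):
            return SUFFIXES[suffix]
    return "block"


for name in ["stone_brick_stairs", "dark_oak_planks", "stone_brick_wall", "redstone_wall_torch"]:
    print(f"{name:<22} material={match_material(name)} shape={match_shape(name)}")
```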

## Properties

Properties describe functional characteristics:

- **solid**: Most blocks (stone, planks, wool)
- **transparent**: Glass, ice, water, air, leaves
- **light_emitting**: Torch, glowstone, lantern, fire
- **interactable**: Doors, chests, buttons, furnaces
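Keyword checks like these are easy to get wrong with plain substring tests, because many block names embed shorter words: `stairs` contains `air` and `bedrock` contains `bed`. A minimal sketch of a safer test, assuming keywords should match whole underscore-separated segments of the name (the helper `has_segment` is hypothetical):

```python
def has_segment(base_name: str, keyword: str) -> bool:
    """True if `keyword` appears as whole '_'-separated segment(s) of `base_name`."""
    return f"_{keyword}_" in f"_{base_name}_"


cases = [("oak_stairs", "air"), ("cave_air", "air"),
         ("bedrock", "bed"), ("red_bed", "bed"),
         ("oak_fence_gate", "fence_gate")]
for name, kw in cases:
    substring = kw in name
    segment = has_segment(name, kw)
    print(f"{name:<15} keyword={kw:<11} substring={substring!s:<5} segment={segment}")
```

Multi-word keywords such as `fence_gate` still match, because the padded keyword is compared against the padded name as a whole.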

In [None]:
# ============================================================
# CELL 3: Block Decomposition Definitions
# ============================================================
# This cell defines how we parse Minecraft block names into
# their material, shape, and property components.

@dataclass
class BlockComponents:
    """Holds the decomposed components of a Minecraft block."""
    original_name: str      # e.g., "minecraft:oak_planks[axis=y]"
    base_name: str          # e.g., "oak_planks" (without namespace and state)
    material: Optional[str] = None   # e.g., "oak"
    shape: str = "block"    # e.g., "planks"
    properties: list = field(default_factory=list)  # e.g., ["solid"]


# === MATERIAL DEFINITIONS ===
# These are the "what it's made of" prefixes

WOOD_MATERIALS = {
    "oak", "spruce", "birch", "jungle", "acacia", "dark_oak",
    "mangrove", "cherry", "bamboo", "crimson", "warped"
}

STONE_MATERIALS = {
    "stone", "cobblestone", "mossy_cobblestone", "smooth_stone",
    "granite", "polished_granite", "diorite", "polished_diorite",
    "andesite", "polished_andesite", "deepslate", "cobbled_deepslate",
    "polished_deepslate", "calcite", "tuff", "dripstone", "blackstone",
    "polished_blackstone", "basalt", "smooth_basalt", "polished_basalt"
}

BRICK_MATERIALS = {
    "brick", "stone_brick", "mossy_stone_brick", "nether_brick",
    "red_nether_brick", "end_stone_brick", "prismarine_brick",
    "deepslate_brick", "polished_blackstone_brick", "mud_brick"
}

SANDSTONE_MATERIALS = {
    "sandstone", "red_sandstone", "smooth_sandstone", "smooth_red_sandstone",
    "cut_sandstone", "cut_red_sandstone", "chiseled_sandstone", "chiseled_red_sandstone"
}

METAL_MATERIALS = {
    "iron", "gold", "copper", "exposed_copper", "weathered_copper",
    "oxidized_copper", "waxed_copper", "netherite"
}

MINERAL_MATERIALS = {"diamond", "emerald", "lapis", "redstone", "coal", "quartz", "amethyst"}

# Colors for wool, concrete, terracotta, etc.
COLOR_MATERIALS = {
    "white", "orange", "magenta", "light_blue", "yellow", "lime",
    "pink", "gray", "light_gray", "cyan", "purple", "blue",
    "brown", "green", "red", "black"
}

NETHER_MATERIALS = {
    "nether", "soul", "shroomlight", "glowstone", "netherrack", "magma"
}

END_MATERIALS = {"end_stone", "purpur", "end"}

# Combine all materials
ALL_MATERIALS = (WOOD_MATERIALS | STONE_MATERIALS | BRICK_MATERIALS | 
                 SANDSTONE_MATERIALS | METAL_MATERIALS | MINERAL_MATERIALS | 
                 COLOR_MATERIALS | NETHER_MATERIALS | END_MATERIALS)

print(f"Defined {len(ALL_MATERIALS)} material types")
print(f"  Wood: {len(WOOD_MATERIALS)} types")
print(f"  Stone: {len(STONE_MATERIALS)} types")
print(f"  Brick: {len(BRICK_MATERIALS)} types")
print(f"  Sandstone: {len(SANDSTONE_MATERIALS)} types")
print(f"  Metal: {len(METAL_MATERIALS)} types")
print(f"  Colors: {len(COLOR_MATERIALS)} types")

In [None]:
# ============================================================
# CELL 4: Shape and Property Definitions
# ============================================================
# Shapes are identified by suffixes in block names.
# Properties describe functional characteristics.

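# === SHAPE PATTERNS ===
# Maps block name suffixes to shape categories.
# identify_shape() sorts these by suffix length at lookup time, so the
# longest matching suffix always wins (e.g. "_wall_torch" before "_wall")
# regardless of the order written here.
SHAPE_PATTERNS = {
    # Suffix matches (listed roughly longest first for readability)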
    "_pressure_plate": "pressure_plate",
    "_wall_hanging_sign": "wall_hanging_sign",
    "_wall_sign": "wall_sign",
    "_wall_banner": "wall_banner",
    "_wall_head": "wall_head",
    "_wall_skull": "wall_skull",
    "_wall_torch": "wall_torch",
    "_wall_fan": "wall_fan",
    "_hanging_sign": "hanging_sign",
    "_fence_gate": "fence_gate",
    "_trap_door": "trapdoor",
    "_trapdoor": "trapdoor",
    "_coral_block": "coral_block",
    "_coral_fan": "coral_fan",
    "_coral": "coral",
    "_mushroom_block": "mushroom_block",
    "_mushroom": "mushroom",
    "_concrete_powder": "concrete_powder",
    "_glazed_terracotta": "glazed_terracotta",
    "_stained_glass_pane": "stained_glass_pane",
    "_stained_glass": "stained_glass",
    "_candle_cake": "candle_cake",
    "_shulker_box": "shulker_box",
    "_amethyst_bud": "amethyst_bud",
    "_froglight": "froglight",
    "_button": "button",
    "_stairs": "stairs",
    "_planks": "planks",
    "_slab": "slab",
    "_wall": "wall",
    "_fence": "fence",
    "_door": "door",
    "_sign": "sign",
    "_log": "log",
    "_wood": "wood",
    "_stem": "stem",
    "_hyphae": "hyphae",
    "_roots": "roots",
    "_leaves": "leaves",
    "_sapling": "sapling",
    "_carpet": "carpet",
    "_bed": "bed",
    "_banner": "banner",
    "_candle": "candle",
    "_head": "head",
    "_skull": "skull",
    "_pot": "pot",
    "_ore": "ore",
    "_block": "block",
    "_glass": "glass",
    "_pane": "pane",
    "_bars": "bars",
    "_chain": "chain",
    "_lantern": "lantern",
    "_torch": "torch",
    "_rail": "rail",
    "_chest": "chest",
    "_terracotta": "terracotta",
    "_concrete": "concrete",
    "_wool": "wool",
    "_tulip": "flower",
    "_vines": "vines",
    "_plant": "plant",
}

# Blocks that ARE their own shape (no material prefix)
STANDALONE_SHAPES = {
    # Air and fluids
    "air", "cave_air", "void_air", "water", "lava", "fire", "soul_fire",
    "bubble_column", "powder_snow",

    # Terrain blocks
    "grass_block", "dirt", "coarse_dirt", "rooted_dirt", "podzol", "mycelium",
    "sand", "red_sand", "gravel", "clay", "mud", "packed_mud", "farmland", "dirt_path",
    "snow", "snow_block", "snow_layer", "ice", "packed_ice", "blue_ice", "frosted_ice",
    "glass", "tinted_glass", "bedrock", "obsidian", "crying_obsidian",

    # Storage and crafting
    "chest", "trapped_chest", "ender_chest", "barrel", "shulker_box",
    "crafting_table", "furnace", "blast_furnace", "smoker",
    "anvil", "chipped_anvil", "damaged_anvil",
    "grindstone", "stonecutter", "cartography_table",
    "fletching_table", "smithing_table", "loom", "lectern",
    "composter", "brewing_stand", "cauldron", "water_cauldron", "lava_cauldron", "powder_snow_cauldron",

    # Redstone
    "hopper", "dropper", "dispenser", "observer", "piston", "piston_head",
    "sticky_piston", "moving_piston", "slime_block", "honey_block",
    "tnt", "target", "repeater", "comparator", "daylight_detector",
    "note_block", "jukebox", "lever", "tripwire", "tripwire_hook",
    "lightning_rod", "redstone_wire",

    # Special blocks
    "respawn_anchor", "lodestone", "beacon", "conduit",
    "enchanting_table", "end_portal_frame", "end_portal", "end_gateway", "dragon_egg",
    "bell", "campfire", "soul_campfire",
    "spawner", "structure_block", "structure_void", "jigsaw", "barrier", "light",
    "command_block", "chain_command_block", "repeating_command_block",

    # Torches and lighting
    "torch", "soul_torch", "redstone_torch", "wall_torch", "soul_wall_torch", "redstone_wall_torch",
    "lantern", "soul_lantern", "chain", "end_rod", "sea_lantern",

    # Flowers (single blocks)
    "dandelion", "poppy", "blue_orchid", "allium", "azure_bluet",
    "oxeye_daisy", "cornflower", "lily_of_the_valley", "wither_rose",
    "sunflower", "lilac", "rose_bush", "peony", "torchflower",
    "pitcher_plant", "pitcher_crop", "spore_blossom",

    # Tall plants and grass
    "grass", "tall_grass", "fern", "large_fern", "dead_bush",
    "seagrass", "tall_seagrass", "kelp", "kelp_plant",
    "sugar_cane", "bamboo", "cactus",
    "vine", "glow_lichen", "sculk_vein",
    "hanging_roots", "azalea", "flowering_azalea",
    "big_dripleaf", "big_dripleaf_stem", "small_dripleaf",
    "lily_pad", "moss_carpet",

    # Crops
    "wheat", "carrots", "potatoes", "beetroots", "melon", "pumpkin",
    "carved_pumpkin", "jack_o_lantern", "melon_stem", "pumpkin_stem",
    "attached_melon_stem", "attached_pumpkin_stem",
    "sweet_berry_bush", "cocoa", "nether_wart", "torchflower_crop",

    # Cave/underground
    "pointed_dripstone", "sculk", "sculk_sensor", "calibrated_sculk_sensor",
    "sculk_catalyst", "sculk_shrieker", "moss_block",
    "amethyst_cluster", "budding_amethyst",
    "cave_vines", "cave_vines_plant", "glow_item_frame", "item_frame",
    "twisting_vines", "twisting_vines_plant", "weeping_vines", "weeping_vines_plant",

    # Corals (standalone base types)
    "brain_coral", "bubble_coral", "fire_coral", "horn_coral", "tube_coral",
    "dead_brain_coral", "dead_bubble_coral", "dead_fire_coral", "dead_horn_coral", "dead_tube_coral",

    # Nether
    "nether_sprouts", "crimson_fungus", "warped_fungus", "crimson_nylium", "warped_nylium",
    "shroomlight", "nether_wart_block", "warped_wart_block",

    # Misc blocks
    "sponge", "wet_sponge", "cobweb", "bookshelf", "chiseled_bookshelf",
    "hay_block", "bone_block", "honeycomb_block", "dried_kelp_block",
    "mushroom_stem", "chorus_flower", "chorus_plant",
    "decorated_pot", "flower_pot", "scaffolding", "ladder",
    "rail", "powered_rail", "detector_rail", "activator_rail",
    "frogspawn", "turtle_egg", "sniffer_egg",

    # Skulls and heads
    "skull", "creeper_head", "dragon_head", "piglin_head", "player_head", "zombie_head",
    "skeleton_skull", "wither_skeleton_skull",

    # Infested blocks
    "infested_stone", "infested_cobblestone", "infested_stone_bricks",
    "infested_mossy_stone_bricks", "infested_cracked_stone_bricks",
    "infested_chiseled_stone_bricks", "infested_deepslate",

    # Prismarine
    "prismarine", "prismarine_bricks", "dark_prismarine",

    # Generic/base blocks (when no color prefix)
    "terracotta", "concrete_powder", "banner", "wall_banner",
    "button", "carpet", "wool", "candle", "candle_cake",
    "stained_glass", "stained_glass_pane", "stained_hardened_clay",

    # Remaining unique blocks
    "ancient_debris", "azalea_leaves_flowers",
    "bee_hive", "bee_nest", "beehive",
    "bricks", "cake",
    "exposed_cut_copper", "oxidized_cut_copper", "weathered_cut_copper",
    "gilded_blackstone", "reinforced_deepslate",
    "mossy_stone_bricks", "mud_bricks",
    "sea_pickle", "suspicious_gravel", "suspicious_sand",
}

# === PROPERTY KEYWORDS ===
TRANSPARENT_KEYWORDS = {"glass", "pane", "ice", "leaves", "air", "water", "lava", "barrier"}
LIGHT_KEYWORDS = {"torch", "lantern", "glowstone", "shroomlight", "fire", "lava", "beacon", 
                  "sea_lantern", "end_rod", "froglight", "campfire", "magma"}
INTERACTABLE_KEYWORDS = {"door", "trapdoor", "fence_gate", "button", "chest", "furnace",
                         "lever", "bed", "barrel", "anvil", "enchanting_table"}

print(f"Defined {len(SHAPE_PATTERNS)} shape patterns")
print(f"Defined {len(STANDALONE_SHAPES)} standalone shapes")

In [None]:
# ============================================================
# CELL 5: Block Decomposition Functions
# ============================================================
# These functions parse a block name into its components.

def extract_base_name(block_name: str) -> str:
    """
    Remove namespace and block state from a block name.
    
    Examples:
        "minecraft:oak_planks" → "oak_planks"
        "minecraft:oak_stairs[facing=north]" → "oak_stairs"
    """
    name = block_name.replace("minecraft:", "")
    if "[" in name:
        name = name[:name.index("[")]
    return name


def identify_material(base_name: str) -> Optional[str]:
    """
    Identify the material component of a block name.
    
    Examples:
        "oak_planks" → "oak"
        "white_wool" → "white"
        "stone" → "stone"
        "glass" → None (no material prefix)
    """
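    # Check longer materials first so that, e.g., "stone_brick_stairs" maps to
    # "stone_brick" rather than "stone". Iterating the set directly would make
    # the result depend on Python's randomized string hashing.
    for material in sorted(ALL_MATERIALS, key=lambda m: (-len(m), m)):
        if base_name.startswith(material + "_") or base_name == material:
            return material
    return None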


def identify_shape(base_name: str) -> str:
    """
    Identify the shape component of a block name.
    
    Examples:
        "oak_planks" → "planks"
        "stone_stairs" → "stairs"
        "glass" → "glass" (standalone shape)
        "potted_cactus" → "potted_plant"
        "stripped_oak_log" → "stripped_log"
    """
    # Check standalone shapes first
    if base_name in STANDALONE_SHAPES:
        return base_name
    
    # Special case: potted plants
    if base_name.startswith("potted_"):
        return "potted_plant"
    
    # Special case: stripped wood/logs
    if base_name.startswith("stripped_"):
        rest = base_name[9:]  # Remove "stripped_"
        if rest.endswith("_log"):
            return "stripped_log"
        elif rest.endswith("_wood"):
            return "stripped_wood"
        elif rest.endswith("_stem"):
            return "stripped_stem"
        elif rest.endswith("_hyphae"):
            return "stripped_hyphae"
        elif rest.endswith("_block"):
            return "stripped_block"
    
    # Special case: chiseled variants
    if base_name.startswith("chiseled_"):
        return "chiseled_block"
    
    # Special case: cracked variants
    if base_name.startswith("cracked_"):
        return "cracked_block"
    
    # Special case: cut variants (cut_copper, cut_sandstone)
    if base_name.startswith("cut_"):
        return "cut_block"
    
    # Special case: smooth variants
    if base_name.startswith("smooth_"):
        rest = base_name[7:]
        if any(rest.endswith(s) for s in ["_slab", "_stairs"]):
            pass  # Let it fall through to pattern matching
        else:
            return "smooth_block"
    
    # Special case: waxed copper variants
    if base_name.startswith("waxed_"):
        rest = base_name[6:]
        if rest.endswith("_slab"):
            return "slab"
        elif rest.endswith("_stairs"):
            return "stairs"
        else:
            return "waxed_block"
    
    # Special case: raw ore blocks
    if base_name.startswith("raw_") and base_name.endswith("_block"):
        return "raw_block"
    
    # Check shape patterns (longer patterns first)
    for pattern, shape in sorted(SHAPE_PATTERNS.items(), key=lambda x: -len(x[0])):
        if base_name.endswith(pattern):
            return shape
    
    # Default to "block" for full blocks
    return "block"


def identify_properties(base_name: str, shape: str) -> list:
    """
    Identify functional properties of a block.
    
    Properties:
        - solid: Block is solid (most blocks)
        - transparent: Light passes through (glass, water, air)
        - light_emitting: Block produces light (torch, glowstone)
        - interactable: Block can be interacted with (doors, chests)
    """
    properties = []
    # Match keywords against whole "_"-separated segments, not raw substrings:
    # a plain `kw in base_name` test marks "oak_stairs" as transparent (it
    # contains "air") and "bedrock" as interactable (it contains "bed").
    padded = f"_{base_name}_"

    def has_keyword(keywords: set) -> bool:
        return any(f"_{kw}_" in padded or kw == shape for kw in keywords)

    # Check transparency
    if has_keyword(TRANSPARENT_KEYWORDS):
        properties.append("transparent")

    # Check light emission
    if has_keyword(LIGHT_KEYWORDS):
        properties.append("light_emitting")

    # Check interactability
    if has_keyword(INTERACTABLE_KEYWORDS):
        properties.append("interactable")

    # Default to solid if not transparent and not air/water/lava/fire
    if "transparent" not in properties and base_name not in {"air", "water", "lava", "fire"}:
        properties.append("solid")

    return properties


def decompose_block(block_name: str) -> BlockComponents:
    """
    Decompose a full block name into its semantic components.
    
    Example:
        "minecraft:oak_stairs[facing=north]" →
        BlockComponents(material="oak", shape="stairs", properties=["solid"])
    """
    base_name = extract_base_name(block_name)
    material = identify_material(base_name)
    shape = identify_shape(base_name)
    properties = identify_properties(base_name, shape)
    
    return BlockComponents(
        original_name=block_name,
        base_name=base_name,
        material=material,
        shape=shape,
        properties=properties
    )


# Test the decomposition
print("Testing block decomposition:")
print("="*70)
test_blocks = [
    "minecraft:oak_planks",
    "minecraft:spruce_planks",
    "minecraft:oak_stairs[facing=north]",
    "minecraft:stone",
    "minecraft:glass",
    "minecraft:white_wool",
    "minecraft:torch",
    "minecraft:iron_door[facing=east]",
    "minecraft:diamond_ore",
    "minecraft:potted_cactus",
    "minecraft:stripped_oak_log",
    "minecraft:chiseled_stone_bricks",
]

for block in test_blocks:
    comp = decompose_block(block)
    print(f"{comp.base_name:<25} material={str(comp.material):<10} shape={comp.shape:<15} props={comp.properties}")

---

# Part 4: Loading and Processing the Vocabulary

Now we load the full vocabulary and decompose all 3,717 blocks into their components. We'll also create index mappings for each component type.

In [None]:
# ============================================================
# CELL 6: Load Vocabulary and Decompose All Blocks
# ============================================================
# We load the tok2block.json mapping and decompose every block.

# Load vocabulary
with open(VOCAB_PATH, 'r') as f:
    tok2block = {int(k): v for k, v in json.load(f).items()}

VOCAB_SIZE = len(tok2block)
print(f"Vocabulary size: {VOCAB_SIZE} unique block states")

# Decompose all blocks and collect unique components
block_components = {}  # token_id → BlockComponents
materials_set = set()
shapes_set = set()
properties_set = set()

print("\nDecomposing all blocks...")
for token_id, block_name in tqdm(tok2block.items(), desc="Decomposing"):
    comp = decompose_block(block_name)
    block_components[token_id] = comp
    
    if comp.material:
        materials_set.add(comp.material)
    shapes_set.add(comp.shape)
    properties_set.update(comp.properties)

# Create sorted lists for consistent indexing
# "_none_" represents blocks without a material (like "glass" or "air")
materials_list = ["_none_"] + sorted(materials_set)
shapes_list = sorted(shapes_set)
properties_list = sorted(properties_set)

# Create index mappings
material2idx = {m: i for i, m in enumerate(materials_list)}
shape2idx = {s: i for i, s in enumerate(shapes_list)}
property2idx = {p: i for i, p in enumerate(properties_list)}

NUM_MATERIALS = len(materials_list)
NUM_SHAPES = len(shapes_list)
NUM_PROPERTIES = len(properties_list)

print(f"\nComponent counts:")
print(f"  Materials: {NUM_MATERIALS} (including _none_)")
print(f"  Shapes: {NUM_SHAPES}")
print(f"  Properties: {NUM_PROPERTIES}")

print(f"\nShapes found: {shapes_list}")
print(f"\nProperties found: {properties_list}")

In [None]:
# ============================================================
# CELL 7: Create Block-to-Component Mapping Tensors
# ============================================================
# We create tensors that map each block token to its component indices.
# These will be used by the model to look up component embeddings.

# Create mapping tensors
# block_to_material[token_id] = material index
# block_to_shape[token_id] = shape index
# block_to_properties[token_id, property_idx] = 1 if block has property, else 0

block_to_material = torch.zeros(VOCAB_SIZE, dtype=torch.long)
block_to_shape = torch.zeros(VOCAB_SIZE, dtype=torch.long)
block_to_properties = torch.zeros(VOCAB_SIZE, NUM_PROPERTIES)

for token_id, comp in block_components.items():
    # Material index (0 = _none_ for blocks without material)
    if comp.material:
        block_to_material[token_id] = material2idx[comp.material]
    else:
        block_to_material[token_id] = 0  # _none_
    
    # Shape index
    block_to_shape[token_id] = shape2idx[comp.shape]
    
    # Properties (multi-hot encoding)
    for prop in comp.properties:
        block_to_properties[token_id, property2idx[prop]] = 1.0

print("Created block-to-component mapping tensors")
print(f"  block_to_material: {block_to_material.shape}")
print(f"  block_to_shape: {block_to_shape.shape}")
print(f"  block_to_properties: {block_to_properties.shape}")

# Show distribution of shapes
print("\nBlocks per shape (top 15):")
shape_counts = Counter(comp.shape for comp in block_components.values())
for shape, count in shape_counts.most_common(15):
    print(f"  {shape}: {count}")

---

# Part 5: The Compositional Block2Vec Model

## How the Model Works

The key difference from V1/V2 is how we compute block embeddings:

### V1/V2 (Direct Lookup)
```
embedding = embedding_matrix[block_token]
```
Each block has its own independent embedding. No guaranteed relationship between similar blocks.

### V3 (Compositional)
```
material_emb = material_matrix[block_to_material[block_token]]
shape_emb = shape_matrix[block_to_shape[block_token]]
property_emb = block_to_properties[block_token] @ property_matrix  # sum of the block's property vectors

embedding = concat(material_emb, shape_emb, property_emb)
```

Blocks that share a component (same material, same shape, or the same set of properties) share that slice of the embedding exactly.
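For a whole batch of tokens, the three lookups become array indexing plus one matrix product. This NumPy sketch mirrors the pseudo-code with toy mapping arrays (four hypothetical tokens, tiny dimensions); in the notebook's PyTorch model the same roles are played by the registered `block_to_*` buffers and the `nn.Embedding` tables.

```python
import numpy as np

rng = np.random.default_rng(0)
MAT_D, SHP_D, PROP_D = 4, 4, 2

# Toy component tables: 3 materials, 2 shapes, 3 properties.
material_matrix = rng.normal(size=(3, MAT_D))
shape_matrix = rng.normal(size=(2, SHP_D))
property_matrix = rng.normal(size=(3, PROP_D))

# Hypothetical token -> component mappings for a 4-token vocabulary.
block_to_material = np.array([0, 1, 2, 1])    # material index per token
block_to_shape = np.array([0, 0, 0, 1])       # shape index per token
block_to_properties = np.array([              # multi-hot property rows
    [1, 0, 0],
    [0, 0, 1],
    [0, 0, 1],
    [0, 1, 1],
], dtype=float)


def block_embeddings(tokens: np.ndarray) -> np.ndarray:
    """[batch] token ids -> [batch, MAT_D + SHP_D + PROP_D] compositional embeddings."""
    mat = material_matrix[block_to_material[tokens]]        # [batch, MAT_D]
    shp = shape_matrix[block_to_shape[tokens]]              # [batch, SHP_D]
    prop = block_to_properties[tokens] @ property_matrix    # [batch, PROP_D]
    return np.concatenate([mat, shp, prop], axis=-1)


emb = block_embeddings(np.array([1, 2, 3]))
print(emb.shape)  # (3, 10)
# Tokens 1 and 2 share shape and properties, so only their material slice differs.
print(np.array_equal(emb[0, MAT_D:], emb[1, MAT_D:]))  # True
```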

## Training Objective

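We still use **Skip-gram with negative sampling**, the same objective as V1. The difference is in how the center-block embeddings are computed, not in how they're trained: the context (neighbor) embeddings remain one learned vector per block (`context_emb` in the model below).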

- **Positive pairs**: (center block, neighbor block) from actual builds
- **Negative pairs**: (center block, random block)
- **Goal**: Make positive pairs have high similarity, negative pairs have low similarity
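For one center block, the objective can be written out directly. This is the standard Skip-gram negative-sampling loss from word2vec, sketched in NumPy with random vectors. How the notebook draws negatives is not shown in this section, so the negatives here are random stand-ins; in V3, `center` would come from the compositional lookup and the positive/negative vectors from the per-block context embeddings.

```python
import numpy as np


def log_sigmoid(x: np.ndarray) -> np.ndarray:
    """Numerically stable log(sigmoid(x))."""
    return -np.logaddexp(0.0, -x)


def sgns_loss(center: np.ndarray, positive: np.ndarray, negatives: np.ndarray) -> float:
    """Skip-gram negative-sampling loss for one (center, neighbor) pair.

    center:    [dim]     embedding of the center block
    positive:  [dim]     context embedding of a real neighbor
    negatives: [k, dim]  context embeddings of k randomly sampled blocks
    """
    pos_term = log_sigmoid(center @ positive)              # reward a high positive score
    neg_term = log_sigmoid(-(negatives @ center)).sum()    # reward low negative scores
    return float(-(pos_term + neg_term))


rng = np.random.default_rng(0)
dim, k = 40, 10  # TOTAL_EMBEDDING_DIM and NUM_NEGATIVE from the config
center = rng.normal(scale=0.1, size=dim)
negatives = rng.normal(scale=0.1, size=(k, dim))

# A neighbor aligned with the center gets a lower loss than an opposed one.
print(sgns_loss(center, center, negatives) < sgns_loss(center, -center, negatives))  # True
```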

In [None]:
# ============================================================
# CELL 8: Compositional Block2Vec Model
# ============================================================

class CompositionalBlock2Vec(nn.Module):
    """
    Block2Vec V3 with compositional embeddings.
    
    Instead of one embedding per block, we compose embeddings from:
    - Material embedding (16 dims)
    - Shape embedding (16 dims)
    - Property embedding (8 dims)
    
    This means blocks with the same shape and properties (e.g., oak_planks
    and spruce_planks) share those embedding slices exactly and differ only
    in their material slice.
    """
    
    def __init__(
        self,
        num_blocks: int,
        num_materials: int,
        num_shapes: int,
        num_properties: int,
        material_dim: int,
        shape_dim: int,
        property_dim: int,
        block_to_material: torch.Tensor,
        block_to_shape: torch.Tensor,
        block_to_properties: torch.Tensor,
    ):
        super().__init__()
        
        self.num_blocks = num_blocks
        self.embedding_dim = material_dim + shape_dim + property_dim
        
        # Component embedding matrices
        # Each component type has its own learned embedding table
        self.material_emb = nn.Embedding(num_materials, material_dim)
        self.shape_emb = nn.Embedding(num_shapes, shape_dim)
        self.property_emb = nn.Embedding(num_properties, property_dim)
        
        # Context embedding for Skip-gram (like V1)
        # This is a separate embedding used for context/neighbor blocks
        self.context_emb = nn.Embedding(num_blocks, self.embedding_dim)
        
        # Register mappings as buffers (saved with model but not trained)
        self.register_buffer('block_to_material', block_to_material)
        self.register_buffer('block_to_shape', block_to_shape)
        self.register_buffer('block_to_properties', block_to_properties.float())
        
        # Initialize with small random values
        self._init_weights()
    
    def _init_weights(self):
        """Initialize embeddings with small uniform random values."""
        nn.init.uniform_(self.material_emb.weight, -0.1, 0.1)
        nn.init.uniform_(self.shape_emb.weight, -0.1, 0.1)
        nn.init.uniform_(self.property_emb.weight, -0.1, 0.1)
        nn.init.uniform_(self.context_emb.weight, -0.1, 0.1)
    
    def get_block_embedding(self, block_ids: torch.Tensor) -> torch.Tensor:
        """
        Compute compositional embeddings for blocks.
        
        This is the KEY FUNCTION that makes V3 different from V1/V2.
        
        Args:
            block_ids: Tensor of block tokens [batch_size]
            
        Returns:
            Tensor of embeddings [batch_size, embedding_dim]
        """
        flat_ids = block_ids.reshape(-1)  # reshape also handles non-contiguous input
        
        # Step 1: Look up component indices for each block
        material_ids = self.block_to_material[flat_ids]  # [N]
        shape_ids = self.block_to_shape[flat_ids]        # [N]
        property_mask = self.block_to_properties[flat_ids]  # [N, num_properties]
        
        # Step 2: Get component embeddings
        mat_emb = self.material_emb(material_ids)  # [N, material_dim]
        shp_emb = self.shape_emb(shape_ids)        # [N, shape_dim]
        
        # Step 3: Property embedding (mean of the embeddings of all properties the block has)
        # property_mask: [N, num_properties] (multi-hot)
        # property_emb.weight: [num_properties, property_dim]
        # The matmul sums the block's property embeddings; dividing by the count gives the mean.
        prop_emb = torch.matmul(property_mask, self.property_emb.weight)  # [N, property_dim]
        num_props = property_mask.sum(dim=1, keepdim=True).clamp(min=1)   # Avoid div by 0 (blocks with no properties)
        prop_emb = prop_emb / num_props
        
        # Step 4: Concatenate all components
        combined = torch.cat([mat_emb, shp_emb, prop_emb], dim=-1)  # [N, total_dim]
        
        return combined.view(*block_ids.shape, self.embedding_dim)
    
    def forward(
        self,
        center_ids: torch.Tensor,
        context_ids: torch.Tensor,
        negative_ids: torch.Tensor,
    ) -> dict:
        """
        Skip-gram forward pass with negative sampling.
        
        Same as V1, but center embeddings are compositional.
        
        Args:
            center_ids: Center block tokens [batch_size]
            context_ids: Positive context tokens [batch_size]
            negative_ids: Negative sample tokens [batch_size, num_neg]
            
        Returns:
            Dictionary with loss and sub-losses
        """
        # Get center embeddings (COMPOSITIONAL - this is the V3 innovation)
        center_emb = self.get_block_embedding(center_ids)  # [B, D]
        
        # Get context embeddings (direct lookup, like V1)
        pos_ctx = self.context_emb(context_ids)       # [B, D]
        neg_ctx = self.context_emb(negative_ids)      # [B, num_neg, D]
        
        # Positive scores (dot product)
        pos_scores = (center_emb * pos_ctx).sum(dim=1)  # [B]
        
        # Negative scores
        neg_scores = torch.bmm(neg_ctx, center_emb.unsqueeze(2)).squeeze(2)  # [B, num_neg]
        
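        # Loss: maximize positive scores, minimize negative scores.
        # neg_loss is averaged over all B * num_neg negatives; the classic word2vec
        # objective sums over the num_neg negatives of each pair instead.
        pos_loss = F.logsigmoid(pos_scores).mean()
        neg_loss = F.logsigmoid(-neg_scores).mean()
        loss = -(pos_loss + neg_loss)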
        
        return {
            'loss': loss,
            'pos_loss': -pos_loss.item(),
            'neg_loss': -neg_loss.item(),
        }
    
    def get_all_embeddings(self) -> np.ndarray:
        """Get embeddings for all blocks as numpy array."""
        with torch.no_grad():
            all_ids = torch.arange(self.num_blocks, device=self.material_emb.weight.device)
            return self.get_block_embedding(all_ids).cpu().numpy()


# Create the model
model = CompositionalBlock2Vec(
    num_blocks=VOCAB_SIZE,
    num_materials=NUM_MATERIALS,
    num_shapes=NUM_SHAPES,
    num_properties=NUM_PROPERTIES,
    material_dim=MATERIAL_DIM,
    shape_dim=SHAPE_DIM,
    property_dim=PROPERTY_DIM,
    block_to_material=block_to_material,
    block_to_shape=block_to_shape,
    block_to_properties=block_to_properties,
).to(device)

# Count parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model created!")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Embedding dimension: {model.embedding_dim}")
print(f"\nParameter breakdown:")
print(f"  Material embeddings: {NUM_MATERIALS} × {MATERIAL_DIM} = {NUM_MATERIALS * MATERIAL_DIM:,}")
print(f"  Shape embeddings: {NUM_SHAPES} × {SHAPE_DIM} = {NUM_SHAPES * SHAPE_DIM:,}")
print(f"  Property embeddings: {NUM_PROPERTIES} × {PROPERTY_DIM} = {NUM_PROPERTIES * PROPERTY_DIM:,}")
print(f"  Context embeddings: {VOCAB_SIZE} × {model.embedding_dim} = {VOCAB_SIZE * model.embedding_dim:,}")

---

# Part 6: The Dataset

We use the same data pipeline as V1: iterate through Minecraft builds, extract (center, neighbor) pairs, and sample negative examples. The dataset is unchanged; only the model differs.
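A simplified sketch of the pair extraction in `Block2VecDataset.__iter__`, without subsampling, on a tiny hypothetical 3×3×3 build. Like the notebook, it visits interior voxels only, so border voxels appear as context but never as a center. The `noise_distribution` helper (a name introduced here) mirrors the frequency^0.75 weighting in `_build_tables`.

```python
import numpy as np

# (dy, dx, dz) offsets, matching build[y, x, z] indexing
NEIGHBORS_6 = [(-1, 0, 0), (1, 0, 0), (0, -1, 0), (0, 1, 0), (0, 0, -1), (0, 0, 1)]

def extract_pairs(build):
    """Yield (center, context) token pairs for every interior voxel."""
    h, w, d = build.shape
    for y in range(1, h - 1):
        for x in range(1, w - 1):
            for z in range(1, d - 1):
                center = int(build[y, x, z])
                for dy, dx, dz in NEIGHBORS_6:
                    yield center, int(build[y + dy, x + dx, z + dz])

def noise_distribution(build, vocab_size, power=0.75):
    """Unigram frequencies raised to `power` and renormalized (word2vec-style)."""
    counts = np.bincount(build.ravel(), minlength=vocab_size).astype(np.float64)
    weighted = np.power(counts / counts.sum() + 1e-10, power)
    return weighted / weighted.sum()

# Hypothetical build: token 0 (air) everywhere, token 1 (stone) in the middle
build = np.zeros((3, 3, 3), dtype=np.int64)
build[1, 1, 1] = 1

pairs = list(extract_pairs(build))
print(pairs)  # six (1, 0) pairs: the stone block and each of its air neighbors
probs = noise_distribution(build, vocab_size=2)
print(probs)
```

The 0.75 exponent flattens the distribution, so the rare stone block is drawn as a negative more often than its raw 1/27 frequency.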

In [None]:
# ============================================================
# CELL 9: Dataset Class
# ============================================================
# Same dataset as V1 - we only changed the model, not the data.

class Block2VecDataset(IterableDataset):
    """
    Dataset that yields (center, context, negatives) tuples.
    
    Streams through H5 files, extracting training pairs on-the-fly.
    Uses subsampling to reduce dominance of frequent blocks (like air).
    """
    
    # Offsets are (dy, dx, dz), matching the build[y, x, z] indexing in __iter__
    NEIGHBORS_6 = [
        (-1, 0, 0), (1, 0, 0),  # down, up (y-axis)
        (0, -1, 0), (0, 1, 0),  # left, right (x-axis)
        (0, 0, -1), (0, 0, 1),  # back, front (z-axis)
    ]
    
    def __init__(
        self,
        data_dir: str,
        vocab_size: int,
        num_negative: int = 10,
        subsample_threshold: float = 0.001,
        seed: int = 42,
    ):
        self.data_dir = Path(data_dir)
        self.vocab_size = vocab_size
        self.num_negative = num_negative
        self.subsample_threshold = subsample_threshold
        self.seed = seed
        
        self.h5_files = sorted(self.data_dir.glob("*.h5"))
        print(f"Found {len(self.h5_files)} training files")
        
        self._negative_table = None
        self._subsample_probs = None
        self._build_tables()
    
    def _build_tables(self):
        """Build negative-sampling and subsampling tables from block frequencies.

        Frequencies are estimated from the first 100 H5 files only.
        """
        print("Building frequency tables...")
        freqs = np.zeros(self.vocab_size, dtype=np.float64)
        
        for h5_path in tqdm(self.h5_files[:100], desc="Counting blocks"):
            with h5py.File(h5_path, 'r') as f:
                build = f[list(f.keys())[0]][:]
                unique, counts = np.unique(build, return_counts=True)
                for tok, count in zip(unique, counts):
                    if tok < self.vocab_size:
                        freqs[tok] += count
        
        freqs /= freqs.sum()
        
        # Negative sampling table (frequency^0.75)
        weighted = np.power(freqs + 1e-10, 0.75)
        weighted /= weighted.sum()
        self._negative_table = np.random.choice(
            self.vocab_size, size=10_000_000, p=weighted
        )
        
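        # Subsampling keep-probabilities: blocks more frequent than the threshold
        # are kept with probability sqrt(threshold / freq); rarer blocks are always kept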
        self._subsample_probs = np.ones(self.vocab_size, dtype=np.float32)
        for i, freq in enumerate(freqs):
            if freq > self.subsample_threshold:
                self._subsample_probs[i] = np.sqrt(self.subsample_threshold / freq)
        
        print(f"Tables built!")
    
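    def __iter__(self) -> Iterator:
        # Shard files across DataLoader workers. Without this, each of the
        # num_workers processes would stream every file with the same seed,
        # so every training pair would be yielded num_workers times.
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        num_workers = worker_info.num_workers if worker_info is not None else 1

        files = list(self.h5_files)
        random.Random(self.seed).shuffle(files)  # identical order in every worker
        files = files[worker_id::num_workers]    # each worker takes a disjoint slice

        rng = random.Random(self.seed + worker_id)
        # Start each worker at a different offset in the negative table
        neg_idx = (worker_id * len(self._negative_table)) // num_workers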
        
        for h5_path in files:
            with h5py.File(h5_path, 'r') as f:
                build = f[list(f.keys())[0]][:]
            
            h, w, d = build.shape
            
            for y in range(1, h-1):
                for x in range(1, w-1):
                    for z in range(1, d-1):
                        center = int(build[y, x, z])
                        
                        # Subsampling check
                        if rng.random() >= self._subsample_probs[center]:
                            continue
                        
                        # Check each neighbor
                        for dy, dx, dz in self.NEIGHBORS_6:
                            context = int(build[y+dy, x+dx, z+dz])
                            
                            # Get negatives
                            negatives = self._negative_table[neg_idx:neg_idx + self.num_negative]
                            neg_idx = (neg_idx + self.num_negative) % len(self._negative_table)
                            
                            yield center, context, negatives


def collate_fn(batch):
    """Convert batch to tensors."""
    centers = torch.tensor([b[0] for b in batch], dtype=torch.long)
    contexts = torch.tensor([b[1] for b in batch], dtype=torch.long)
    negatives = torch.tensor(np.array([b[2] for b in batch]), dtype=torch.long)
    return centers, contexts, negatives


# Create dataset and dataloader
dataset = Block2VecDataset(
    data_dir=DATA_DIR,
    vocab_size=VOCAB_SIZE,
    num_negative=NUM_NEGATIVE,
    subsample_threshold=SUBSAMPLE_THRESHOLD,
    seed=SEED,
)

dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    pin_memory=str(device).startswith("cuda"),  # works for both str and torch.device
    num_workers=2,
)

print(f"\nDataLoader ready with batch size {BATCH_SIZE}")

---

# Part 7: Training

Now we train the model. This is the same training loop as V1 - only the model is different.

In [None]:
# ============================================================
# CELL 10: Create Optimizer
# ============================================================

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

print(f"Optimizer: AdamW (lr={LEARNING_RATE})")
print(f"Scheduler: CosineAnnealingLR over {EPOCHS} epochs")

In [None]:
# ============================================================
# CELL 11: Training Loop
# ============================================================

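# Cap on batches per epoch (the dataset is iterable, so its true length is unknown).
# This heuristic allows up to 500 batches per H5 file; to cap at ~500 *samples*
# per file instead, use len(dataset.h5_files) * 500 // BATCH_SIZE.
BATCHES_PER_EPOCH = len(dataset.h5_files) * 500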

print("="*60)
print("STARTING BLOCK2VEC V3 TRAINING")
print("="*60)
print(f"Epochs: {EPOCHS}")
print(f"Estimated batches per epoch: ~{BATCHES_PER_EPOCH}")
print()

history = {'loss': [], 'pos_loss': [], 'neg_loss': []}
start_time = time.time()

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0
    epoch_pos = 0
    epoch_neg = 0
    n_batches = 0
    
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)
    
    for centers, contexts, negatives in pbar:
        # Move to GPU
        centers = centers.to(device)
        contexts = contexts.to(device)
        negatives = negatives.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        out = model(centers, contexts, negatives)
        
        # Backward pass
        out['loss'].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        # Track metrics
        epoch_loss += out['loss'].item()
        epoch_pos += out['pos_loss']
        epoch_neg += out['neg_loss']
        n_batches += 1
        
        pbar.set_postfix({'loss': f"{out['loss'].item():.4f}"})
        
        # Limit batches per epoch for reasonable training time
        if n_batches >= BATCHES_PER_EPOCH:
            break
    
    scheduler.step()
    
    # Record epoch metrics
    avg_loss = epoch_loss / n_batches
    avg_pos = epoch_pos / n_batches
    avg_neg = epoch_neg / n_batches
    
    history['loss'].append(avg_loss)
    history['pos_loss'].append(avg_pos)
    history['neg_loss'].append(avg_neg)
    
    elapsed = time.time() - start_time
    lr = scheduler.get_last_lr()[0]
    
    print(f"Epoch {epoch+1:2d}/{EPOCHS} | Loss: {avg_loss:.4f} | "
          f"Pos: {avg_pos:.4f} | Neg: {avg_neg:.4f} | "
          f"LR: {lr:.6f} | Time: {elapsed/60:.1f}m")

total_time = time.time() - start_time
print(f"\nTraining complete in {total_time/60:.1f} minutes")

---

# Part 8: Save Results

Let's save the trained block embeddings, the per-component embedding tables, the training history, and the component vocabulary for use in Phase 3 (VQ-VAE).
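A sketch of how downstream code might load the saved array and query nearest neighbors by cosine similarity. So that it runs on its own, it writes a small random stand-in array to a temporary directory; in practice the path would be `f"{OUTPUT_DIR}/block_embeddings_v3.npy"` and token IDs would come from `tok2block`. The helper `nearest_blocks` is hypothetical, not part of this notebook.

```python
import os
import tempfile
import numpy as np

def nearest_blocks(embeddings, token_id, k=5):
    """Return the k token IDs most cosine-similar to token_id (excluding itself)."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    normed = embeddings / np.clip(norms, 1e-12, None)
    sims = normed @ normed[token_id]
    sims[token_id] = -np.inf  # never return the query block itself
    return np.argsort(-sims)[:k]

with tempfile.TemporaryDirectory() as out_dir:
    path = os.path.join(out_dir, "block_embeddings_v3.npy")
    stand_in = np.random.default_rng(0).normal(size=(10, 40)).astype(np.float32)  # hypothetical data
    np.save(path, stand_in)
    loaded = np.load(path)  # [VOCAB_SIZE, embedding_dim] for the real file

print(loaded.shape, nearest_blocks(loaded, token_id=3, k=3))
```

The sketch deliberately avoids the name `embeddings` so running it inside this notebook does not overwrite the trained array.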

In [None]:
# ============================================================
# CELL 12: Extract and Save Embeddings
# ============================================================

# Get final embeddings
embeddings = model.get_all_embeddings()
print(f"Final embeddings shape: {embeddings.shape}")

# Save embeddings
np.save(f"{OUTPUT_DIR}/block_embeddings_v3.npy", embeddings)
print(f"Saved embeddings to {OUTPUT_DIR}/block_embeddings_v3.npy")

# Save component embeddings separately (for analysis)
component_embs = {
    'material': model.material_emb.weight.detach().cpu().numpy(),
    'shape': model.shape_emb.weight.detach().cpu().numpy(),
    'property': model.property_emb.weight.detach().cpu().numpy(),
}
np.savez(f"{OUTPUT_DIR}/component_embeddings_v3.npz", **component_embs)
print(f"Saved component embeddings")

# Save training history
with open(f"{OUTPUT_DIR}/training_history_v3.json", 'w') as f:
    json.dump(history, f, indent=2)
print(f"Saved training history")

# Save vocabulary info
vocab_info = {
    'materials': materials_list,
    'shapes': shapes_list,
    'properties': properties_list,
    'material2idx': material2idx,
    'shape2idx': shape2idx,
    'property2idx': property2idx,
}
with open(f"{OUTPUT_DIR}/vocab_info_v3.json", 'w') as f:
    json.dump(vocab_info, f, indent=2)
print(f"Saved vocabulary info")

---

# Part 9: Evaluation - Did V3 Fix the Problem?

The key question: **Do blocks with the same shape now have high similarity?**

We compute a **coherence** score for each shape category: the mean cosine similarity over all pairs of distinct blocks with that shape.
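For a shape with n blocks and embeddings e₁…eₙ, coherence = (1 / (n(n-1))) Σ_{i≠j} cos(eᵢ, eⱼ). A vectorized NumPy sketch of that formula; it subtracts the diagonal explicitly instead of assuming every self-similarity is 1, which fails for an all-zero embedding.

```python
import numpy as np

def coherence(embs):
    """Mean cosine similarity over all ordered pairs of distinct rows of embs ([n, D], n >= 2)."""
    n = len(embs)
    if n < 2:
        raise ValueError("coherence needs at least two blocks")
    norms = np.linalg.norm(embs, axis=1, keepdims=True)
    unit = embs / np.clip(norms, 1e-12, None)  # all-zero rows stay zero
    sims = unit @ unit.T
    return (sims.sum() - np.trace(sims)) / (n * (n - 1))

same = np.array([[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]])  # parallel vectors
orth = np.array([[1.0, 0.0], [0.0, 1.0]])              # orthogonal vectors
print(coherence(same), coherence(orth))
```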

In [None]:
# ============================================================
# CELL 13: Compute Category Coherence
# ============================================================

def compute_category_coherence(embeddings, block_components):
    """
    Compute coherence (average similarity) for each shape category.
    
    High coherence = blocks of this shape are similar to each other.
    """
    # Group blocks by shape
    shape_groups = defaultdict(list)
    for token_id, comp in block_components.items():
        shape_groups[comp.shape].append(token_id)
    
    results = {}
    
    for shape, token_ids in shape_groups.items():
        if len(token_ids) < 2:
            continue
        
        # Get embeddings for this shape
        shape_embs = embeddings[token_ids]
        
        # Compute pairwise cosine similarities
        sims = cosine_similarity(shape_embs)
        
        # Average similarity over distinct pairs (exclude the diagonal of self-similarities)
        n = len(token_ids)
        avg_sim = (sims.sum() - np.trace(sims)) / (n * (n - 1))
        
        results[shape] = {
            'count': n,
            'avg_similarity': avg_sim,
        }
    
    return results


coherence = compute_category_coherence(embeddings, block_components)

print("\n" + "="*60)
print("CATEGORY COHERENCE BY SHAPE")
print("="*60)
print("\nThis is the KEY METRIC - higher is better!")
print("V1 had ~20% overall, V2 had ~16% overall.")
print("V3 should be >50% for most shapes.\n")

# Sort by count
sorted_shapes = sorted(coherence.items(), key=lambda x: -x[1]['count'])

print(f"{'Shape':<25} {'Count':>8} {'Coherence':>12}")
print("-"*50)

# Weighted average over the 20 largest shape groups shown below (not all shapes)
total_coherence = 0
total_count = 0

for shape, data in sorted_shapes[:20]:
    coh = data['avg_similarity']
    count = data['count']
    indicator = "***" if coh > 0.8 else "**" if coh > 0.5 else "*" if coh > 0.3 else ""
    print(f"{shape:<25} {count:>8} {coh:>11.1%} {indicator}")
    total_coherence += coh * count
    total_count += count

overall = total_coherence / total_count if total_count > 0 else 0
print(f"\n{'WEIGHTED AVG (top 20)':<25} {total_count:>8} {overall:>11.1%}")

In [None]:
# ============================================================
# CELL 14: Compare Specific Block Pairs
# ============================================================
# Let's check the exact similarity between blocks we care about.

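def get_similarity(block1_name, block2_name, embeddings, tok2block):
    """Cosine similarity between two blocks by base name, or None if either is missing."""
    # Strip the 'minecraft:' prefix and any '[state]' suffix. If several tokens share
    # a base name (block-state variants), the last one in tok2block wins.
    block2tok = {v.replace('minecraft:', '').split('[')[0]: k for k, v in tok2block.items()}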
    
    tok1 = block2tok.get(block1_name)
    tok2 = block2tok.get(block2_name)
    
    if tok1 is None or tok2 is None:
        return None
    
    emb1 = embeddings[tok1]
    emb2 = embeddings[tok2]
    return np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))


print("\n" + "="*60)
print("SPECIFIC BLOCK PAIR SIMILARITIES")
print("="*60)
print("\nThese pairs SHOULD be highly similar in V3:")

# Block pairs that should be similar
should_be_similar = [
    ("oak_planks", "spruce_planks"),
    ("oak_planks", "dark_oak_planks"),
    ("oak_stairs", "spruce_stairs"),
    ("oak_stairs", "stone_stairs"),
    ("white_wool", "red_wool"),
    ("white_concrete", "black_concrete"),
]

print(f"\n{'Block 1':<20} {'Block 2':<20} {'Similarity':>12}  {'Status'}")
print("-"*65)

for b1, b2 in should_be_similar:
    sim = get_similarity(b1, b2, embeddings, tok2block)
    if sim is not None:
        status = "GREAT" if sim > 0.8 else "Good" if sim > 0.5 else "Poor"
        print(f"{b1:<20} {b2:<20} {sim:>11.3f}   {status}")

print("\nThese pairs should have LOWER similarity (different shapes):")

should_be_different = [
    ("oak_planks", "oak_stairs"),
    ("oak_planks", "stone"),
    ("glass", "dirt"),
]

print(f"\n{'Block 1':<20} {'Block 2':<20} {'Similarity':>12}")
print("-"*55)

for b1, b2 in should_be_different:
    sim = get_similarity(b1, b2, embeddings, tok2block)
    if sim is not None:
        print(f"{b1:<20} {b2:<20} {sim:>11.3f}")

---

# Part 10: Visualizations

Let's create visualizations to understand what the model learned.

In [None]:
# ============================================================
# CELL 15: Create Visualizations
# ============================================================

fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# 1. Training Loss
ax = axes[0, 0]
ax.plot(history['loss'], label='Total Loss', linewidth=2)
ax.plot(history['pos_loss'], label='Positive Loss', alpha=0.7)
ax.plot(history['neg_loss'], label='Negative Loss', alpha=0.7)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss')
ax.legend()
ax.grid(True, alpha=0.3)

# 2. t-SNE colored by shape
ax = axes[0, 1]
print("Running t-SNE...")

# Sample blocks
np.random.seed(42)
sample_ids = np.random.choice(VOCAB_SIZE, min(500, VOCAB_SIZE), replace=False)
sample_embs = embeddings[sample_ids]

tsne = TSNE(n_components=2, random_state=42, perplexity=30)
coords = tsne.fit_transform(sample_embs)

# Color by shape category
shape_colors = {}
color_idx = 0
cmap = plt.get_cmap('tab20')  # plt.cm.get_cmap was removed in Matplotlib 3.9

for sid in sample_ids:
    shape = block_components[sid].shape
    if shape not in shape_colors:
        shape_colors[shape] = cmap(color_idx % 20)
        color_idx += 1

colors = [shape_colors[block_components[sid].shape] for sid in sample_ids]
ax.scatter(coords[:, 0], coords[:, 1], c=colors, s=20, alpha=0.6)
ax.set_title('t-SNE (colored by shape)')
ax.set_xticks([])
ax.set_yticks([])

# 3. Shape Coherence Bar Chart
ax = axes[1, 0]
top_shapes = sorted_shapes[:15]
shape_names = [s[0] for s in top_shapes]
coherences = [s[1]['avg_similarity'] for s in top_shapes]
colors_bar = ['green' if c > 0.5 else 'orange' if c > 0.3 else 'red' for c in coherences]
ax.barh(shape_names, coherences, color=colors_bar)
ax.invert_yaxis()  # largest shape group at the top
ax.set_xlabel('Average Cosine Similarity')
ax.set_title('Shape Coherence (Top 15 by count)')
ax.set_xlim(0, 1)
ax.axvline(x=0.5, color='green', linestyle='--', alpha=0.5, label='Good threshold')
ax.legend()

# 4. Material Embedding Heatmap
ax = axes[1, 1]
mat_embs = model.material_emb.weight.detach().cpu().numpy()
n_mat = min(20, len(mat_embs))  # guard against vocabularies with fewer than 20 materials
mat_sims = cosine_similarity(mat_embs[:n_mat])
im = ax.imshow(mat_sims, cmap='RdYlGn', vmin=-1, vmax=1)
ax.set_xticks(range(n_mat))
ax.set_yticks(range(n_mat))
ax.set_xticklabels(materials_list[:n_mat], rotation=45, ha='right', fontsize=8)
ax.set_yticklabels(materials_list[:n_mat], fontsize=8)
ax.set_title(f'Material Embedding Similarities (first {n_mat})')
plt.colorbar(im, ax=ax)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/block2vec_v3_results.png", dpi=150)
plt.show()

print(f"\nSaved visualization to {OUTPUT_DIR}/block2vec_v3_results.png")

---

# Part 11: Final Summary

Let's summarize what V3 achieved compared to V1 and V2.

In [None]:
# ============================================================
# CELL 16: Final Summary
# ============================================================

print("\n" + "="*70)
print("BLOCK2VEC V3 TRAINING COMPLETE")
print("="*70)

print(f"\n## Model Architecture")
print(f"  - Material embedding: {NUM_MATERIALS} materials × {MATERIAL_DIM} dims")
print(f"  - Shape embedding: {NUM_SHAPES} shapes × {SHAPE_DIM} dims")
print(f"  - Property embedding: {NUM_PROPERTIES} properties × {PROPERTY_DIM} dims")
print(f"  - Total embedding dimension: {model.embedding_dim}")

print(f"\n## Training")
print(f"  - Epochs: {EPOCHS}")
print(f"  - Batch size: {BATCH_SIZE}")
print(f"  - Final loss: {history['loss'][-1]:.4f}")
print(f"  - Training time: {total_time/60:.1f} minutes")

print(f"\n## Key Coherence Scores (V3 vs V1 vs V2)")
print(f"  {'Shape':<20} {'V3':>10} {'V1':>10} {'V2':>10}")
print(f"  {'-'*50}")
key_shapes = ['planks', 'stairs', 'slab', 'wool', 'concrete', 'door']
# Hardcoded approximate V1/V2 coherence values (not recomputed in this notebook)
v1_approx = {'planks': 0.20, 'stairs': 0.60, 'slab': 0.40, 'wool': 0.30, 'concrete': 0.25, 'door': 0.50}
v2_approx = {'planks': 0.025, 'stairs': 0.44, 'slab': 0.30, 'wool': 0.20, 'concrete': 0.15, 'door': 0.40}

for shape in key_shapes:
    if shape in coherence:
        v3_val = coherence[shape]['avg_similarity']
        v1_val = v1_approx.get(shape, 0)
        v2_val = v2_approx.get(shape, 0)
        improvement = "↑↑↑" if v3_val > v1_val * 2 else "↑↑" if v3_val > v1_val * 1.5 else "↑" if v3_val > v1_val else ""
        print(f"  {shape:<20} {v3_val:>9.1%} {v1_val:>9.1%} {v2_val:>9.1%}  {improvement}")

print(f"\n## Output Files")
print(f"  - {OUTPUT_DIR}/block_embeddings_v3.npy")
print(f"  - {OUTPUT_DIR}/component_embeddings_v3.npz")
print(f"  - {OUTPUT_DIR}/training_history_v3.json")
print(f"  - {OUTPUT_DIR}/vocab_info_v3.json")
print(f"  - {OUTPUT_DIR}/block2vec_v3_results.png")

print("\n" + "="*70)
print("V3 composes embeddings so blocks that share a shape (or material) share")
print("that slice of the embedding by construction. Compare the scores above with V1/V2.")
print("="*70)

---

# Part 12: Diagnostics and Failure Detection

These diagnostics check whether V3 is working as intended and, if it is not, help pinpoint why.

## Key Diagnostics

1. **Similarity Distribution**: Within-shape similarity should be HIGH, across-shape should be LOWER
2. **Embedding Collapse Detection**: Check if embeddings are degenerating
3. **Component Usage Statistics**: Identify underrepresented components
4. **Nearest Neighbor Quality**: Is each block's nearest neighbor the same shape?
5. **Embedding Norm Distribution**: Are norms uniform across blocks?
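For the nearest-neighbor check, the chance level is 1/(number of shapes) only if all shapes are equally common. In general, the probability that a random *other* block shares the query's shape is Σₛ nₛ(nₛ-1) / (N(N-1)), which is dominated by the largest groups. A sketch with hypothetical labels and toy 2-D embeddings; both helper names are introduced here for illustration.

```python
import numpy as np
from collections import Counter

def chance_same_shape(shape_labels):
    """Probability that a uniformly random other block has the same shape."""
    sizes = np.array(list(Counter(shape_labels).values()), dtype=np.float64)
    n = sizes.sum()
    return float((sizes * (sizes - 1)).sum() / (n * (n - 1)))

def nn_same_shape_rate(embeddings, shape_labels):
    """Fraction of blocks whose cosine nearest neighbor (excluding itself) shares their shape."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.clip(norms, 1e-12, None)
    sims = unit @ unit.T
    np.fill_diagonal(sims, -np.inf)  # a block is not its own neighbor
    nearest = sims.argmax(axis=1)
    labels = np.asarray(shape_labels)
    return float(np.mean(labels == labels[nearest]))

# Hypothetical: 6 planks, 2 stairs, 2 doors -> chance is (30 + 2 + 2) / 90, not 1/3
labels = ["planks"] * 6 + ["stairs"] * 2 + ["door"] * 2
print(chance_same_shape(labels))
```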

In [None]:
# ============================================================
# CELL 17: Similarity Distribution Analysis
# ============================================================
# This is a KEY DIAGNOSTIC: Within-shape similarity should be HIGH,
# across-shape similarity should be LOWER. The gap between them
# tells us if V3's compositional structure is working.

from scipy.spatial.distance import pdist

print("="*60)
print("SIMILARITY DISTRIBUTION ANALYSIS")
print("="*60)

# Compute within-shape similarities
within_shape_sims = []
across_shape_sims = []

# Group blocks by shape
shape_groups = defaultdict(list)
for token_id, comp in block_components.items():
    shape_groups[comp.shape].append(token_id)

# All within-shape pairs (upper triangle of each shape's similarity matrix, excluding the diagonal)
print("\nComputing within-shape similarities...")
for shape, token_ids in shape_groups.items():
    if len(token_ids) >= 2:
        sims = cosine_similarity(embeddings[token_ids])
        iu = np.triu_indices(len(token_ids), k=1)
        within_shape_sims.extend(sims[iu])

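# Sample across-shape pairs instead of enumerating them all. Shapes are drawn
# uniformly, so this is a shape-balanced estimate, whereas the within-shape mean
# above is dominated by the largest shape groups.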
print("Computing across-shape similarities (sampling)...")
np.random.seed(42)
shapes_with_blocks = [s for s, ids in shape_groups.items() if len(ids) >= 2]
n_samples = min(10000, len(within_shape_sims))

for _ in range(n_samples):
    # Pick two different shapes
    s1, s2 = np.random.choice(shapes_with_blocks, 2, replace=False)
    # Pick one block from each
    t1 = np.random.choice(shape_groups[s1])
    t2 = np.random.choice(shape_groups[s2])
    # Compute similarity
    sim = np.dot(embeddings[t1], embeddings[t2]) / (
        np.linalg.norm(embeddings[t1]) * np.linalg.norm(embeddings[t2])
    )
    across_shape_sims.append(sim)

within_shape_sims = np.array(within_shape_sims)
across_shape_sims = np.array(across_shape_sims)

# Statistics
print(f"\nWithin-Shape Similarity:")
print(f"  Mean: {within_shape_sims.mean():.3f}")
print(f"  Std:  {within_shape_sims.std():.3f}")
print(f"  Min:  {within_shape_sims.min():.3f}")
print(f"  Max:  {within_shape_sims.max():.3f}")

print(f"\nAcross-Shape Similarity:")
print(f"  Mean: {across_shape_sims.mean():.3f}")
print(f"  Std:  {across_shape_sims.std():.3f}")
print(f"  Min:  {across_shape_sims.min():.3f}")
print(f"  Max:  {across_shape_sims.max():.3f}")

separation = within_shape_sims.mean() - across_shape_sims.mean()
print(f"\n*** SEPARATION (within - across): {separation:.3f} ***")
print("  > 0.5 = GREAT (strong separation)")
print("  > 0.3 = GOOD (clear distinction)")
print("  < 0.1 = POOR (shapes not distinct)")

# Save for visualization
diagnostics = {
    'within_shape_mean': float(within_shape_sims.mean()),
    'within_shape_std': float(within_shape_sims.std()),
    'across_shape_mean': float(across_shape_sims.mean()),
    'across_shape_std': float(across_shape_sims.std()),
    'separation': float(separation),
}

In [None]:
# ============================================================
# CELL 18: Embedding Collapse Detection
# ============================================================
# Check if embeddings are degenerating (becoming too similar or uniform)

print("="*60)
print("EMBEDDING COLLAPSE DETECTION")
print("="*60)

# Check embedding standard deviation per dimension
emb_std_per_dim = embeddings.std(axis=0)
print(f"\nPer-dimension std dev:")
print(f"  Mean: {emb_std_per_dim.mean():.4f}")
print(f"  Min:  {emb_std_per_dim.min():.4f}")
print(f"  Max:  {emb_std_per_dim.max():.4f}")

# Count collapsed dimensions (std < 0.01)
collapsed_dims = (emb_std_per_dim < 0.01).sum()
print(f"\n  Collapsed dimensions (<0.01 std): {collapsed_dims}/{len(emb_std_per_dim)}")
if collapsed_dims > len(emb_std_per_dim) // 4:
    print("  ⚠️ WARNING: Many dimensions have collapsed!")

# Check pairwise distances
print("\nPairwise distance statistics (sampling up to 1000 blocks)...")
np.random.seed(42)
sample_indices = np.random.choice(len(embeddings), min(1000, len(embeddings)), replace=False)
sample_embs = embeddings[sample_indices]
pairwise_dists = pdist(sample_embs, metric='cosine')

print(f"  Min distance: {pairwise_dists.min():.4f}")
print(f"  Mean distance: {pairwise_dists.mean():.4f}")
print(f"  Max distance: {pairwise_dists.max():.4f}")

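if pairwise_dists.min() < 0.001:
    print("  ⚠️ WARNING: Some embeddings are nearly identical!")
    print("     (Expected when distinct tokens map to identical material, shape,")
    print("      and properties; otherwise a possible sign of collapse.)")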

# Check embedding norms
norms = np.linalg.norm(embeddings, axis=1)
print(f"\nEmbedding norms:")
print(f"  Mean: {norms.mean():.4f}")
print(f"  Std:  {norms.std():.4f}")
print(f"  Min:  {norms.min():.4f}")
print(f"  Max:  {norms.max():.4f}")

if norms.std() / norms.mean() > 0.5:  # coefficient of variation of the norms
    print("  ⚠️ WARNING: Norm variance is high (some blocks have very different scales)")

diagnostics['collapsed_dims'] = int(collapsed_dims)
diagnostics['norm_mean'] = float(norms.mean())
diagnostics['norm_std'] = float(norms.std())

In [None]:
# ============================================================
# CELL 19: Nearest Neighbor Quality
# ============================================================
# For each block, is its nearest neighbor the same shape?
# This is a strong test of embedding quality.

print("="*60)
print("NEAREST NEIGHBOR QUALITY")
print("="*60)

# Compute nearest neighbors
print("\nFinding nearest neighbors for all blocks...")
all_sims = cosine_similarity(embeddings)

# For each block, find nearest neighbor (excluding self)
nn_same_shape = []
nn_same_material = []

for i in range(len(embeddings)):
    # Set self-similarity to -inf
    sims = all_sims[i].copy()
    sims[i] = -float('inf')
    
    # Find nearest neighbor
    nearest = sims.argmax()
    
    # Check if same shape
    my_shape = block_components[i].shape
    nn_shape = block_components[nearest].shape
    nn_same_shape.append(my_shape == nn_shape)
    
    # Check if same material
    my_mat = block_components[i].material
    nn_mat = block_components[nearest].material
    nn_same_material.append(my_mat == nn_mat)

nn_same_shape_pct = np.mean(nn_same_shape)
nn_same_material_pct = np.mean(nn_same_material)

print(f"\nNearest Neighbor Results:")
print(f"  Same Shape:    {nn_same_shape_pct:>6.1%}")
print(f"  Same Material: {nn_same_material_pct:>6.1%}")

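print("\nInterpretation:")
print("  V3 expects HIGHER same-shape NN than V1/V2")
print("  > 50% same shape = GOOD")
print("  > 70% same shape = GREAT")

# Chance level: probability that a random *other* block has the same shape.
# This equals 1/num_shapes only if all shapes are equally common.
shape_sizes = np.array([len(ids) for ids in shape_groups.values()], dtype=np.float64)
n_total = shape_sizes.sum()
random_baseline = (shape_sizes * (shape_sizes - 1)).sum() / (n_total * (n_total - 1))
print(f"  Random baseline: {random_baseline:.1%}")
diagnostics['nn_random_baseline'] = float(random_baseline)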

diagnostics['nn_same_shape'] = float(nn_same_shape_pct)
diagnostics['nn_same_material'] = float(nn_same_material_pct)

In [None]:
# ============================================================
# CELL 20: Diagnostic Visualizations
# ============================================================

fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# 1. Similarity Distribution Histogram
ax = axes[0, 0]
ax.hist(within_shape_sims, bins=50, alpha=0.7, label=f'Same Shape (μ={within_shape_sims.mean():.2f})', color='green', density=True)
ax.hist(across_shape_sims, bins=50, alpha=0.7, label=f'Different Shape (μ={across_shape_sims.mean():.2f})', color='red', density=True)
ax.axvline(x=0.5, color='black', linestyle='--', alpha=0.5, label='Threshold')
ax.set_xlabel('Cosine Similarity')
ax.set_ylabel('Density')
ax.set_title(f'Similarity Distribution (Separation: {separation:.3f})')
ax.legend()
ax.grid(True, alpha=0.3)

# 2. Embedding Norm Distribution
ax = axes[0, 1]
ax.hist(norms, bins=50, color='blue', alpha=0.7)
ax.axvline(norms.mean(), color='red', linestyle='-', linewidth=2, label=f'Mean: {norms.mean():.2f}')
ax.axvline(norms.mean() - norms.std(), color='red', linestyle='--', alpha=0.5)
ax.axvline(norms.mean() + norms.std(), color='red', linestyle='--', alpha=0.5)
ax.set_xlabel('Embedding L2 Norm')
ax.set_ylabel('Count')
ax.set_title('Embedding Norm Distribution')
ax.legend()
ax.grid(True, alpha=0.3)

# 3. Component Embedding t-SNE: Materials
ax = axes[1, 0]
mat_embs = model.material_emb.weight.detach().cpu().numpy()
if len(mat_embs) > 3:
    # scikit-learn's TSNE requires perplexity < n_samples, hence the cap
    tsne_mat = TSNE(n_components=2, random_state=42, perplexity=min(5, len(mat_embs)-1))
    mat_coords = tsne_mat.fit_transform(mat_embs)
    ax.scatter(mat_coords[:, 0], mat_coords[:, 1], c='blue', s=50, alpha=0.7)
    # Label only the first 15 materials to keep the plot readable
    for i, mat in enumerate(materials_list[:15]):
        ax.annotate(mat, (mat_coords[i, 0], mat_coords[i, 1]), fontsize=8, alpha=0.8)
ax.set_title('Material Embedding Space (t-SNE)')
ax.set_xticks([])
ax.set_yticks([])

# 4. Shape Embedding t-SNE
ax = axes[1, 1]
shp_embs = model.shape_emb.weight.detach().cpu().numpy()
if len(shp_embs) > 3:
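    # Sample at most 50 shapes to keep the plot legible (seeded so the sample is reproducible)
    n_shapes = min(50, len(shp_embs))
    sample_rng = np.random.default_rng(42)
    shape_indices = sample_rng.choice(len(shp_embs), n_shapes, replace=False)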
    shp_sample = shp_embs[shape_indices]
    shape_names_sample = [shapes_list[i] for i in shape_indices]
    
    tsne_shp = TSNE(n_components=2, random_state=42, perplexity=min(10, n_shapes-1))
    shp_coords = tsne_shp.fit_transform(shp_sample)
    ax.scatter(shp_coords[:, 0], shp_coords[:, 1], c='green', s=50, alpha=0.7)
    # Label shapes
    for i, shp in enumerate(shape_names_sample):
        ax.annotate(shp, (shp_coords[i, 0], shp_coords[i, 1]), fontsize=7, alpha=0.8)
ax.set_title('Shape Embedding Space (t-SNE)')
ax.set_xticks([])
ax.set_yticks([])

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/block2vec_v3_diagnostics.png", dpi=150)
plt.show()

print(f"\nSaved diagnostic visualization to {OUTPUT_DIR}/block2vec_v3_diagnostics.png")
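Panel 1 plots `within_shape_sims` and `across_shape_sims`, which are computed in an earlier cell outside this section. The sketch below shows one way to build such arrays from a cosine-similarity matrix and per-block shape labels, counting each unordered pair of blocks exactly once. The helper name `split_pair_similarities` is hypothetical. Taking `separation` as the difference between the two means is also an assumption, and the notebook's own definition may differ.

```python
import numpy as np


def split_pair_similarities(sims: np.ndarray, labels: np.ndarray):
    """Return (within, across) similarities over unique unordered pairs i < j."""
    iu, ju = np.triu_indices(len(labels), k=1)  # upper triangle: each pair once, no diagonal
    pair_sims = sims[iu, ju]
    same = labels[iu] == labels[ju]             # True where both blocks share a shape
    return pair_sims[same], pair_sims[~same]


# Toy example: two shapes with two blocks each
sims = np.array([
    [1.0, 0.8, 0.1, 0.2],
    [0.8, 1.0, 0.0, 0.3],
    [0.1, 0.0, 1.0, 0.6],
    [0.2, 0.3, 0.6, 1.0],
])
labels = np.array(["stairs", "stairs", "slab", "slab"])

within, across = split_pair_similarities(sims, labels)
separation = within.mean() - across.mean()  # ASSUMED definition: gap between the two means
print(within, across, round(float(separation), 3))
```

Using only the upper triangle avoids double-counting symmetric pairs and keeps the trivial self-similarity of 1.0 out of the "same shape" histogram. Either of those would otherwise inflate the green distribution.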

In [None]:
# ============================================================
# CELL 21: Save All Diagnostics
# ============================================================

# Add coherence scores to diagnostics
diagnostics['overall_coherence'] = float(overall)
diagnostics['key_shape_coherence'] = {
    shape: float(coherence[shape]['avg_similarity']) 
    for shape in key_shapes if shape in coherence
}

# Save diagnostics
with open(f"{OUTPUT_DIR}/diagnostics_v3.json", 'w') as f:
    json.dump(diagnostics, f, indent=2)

print("="*60)
print("DIAGNOSTICS SUMMARY")
print("="*60)
print(f"\n1. Similarity Separation: {diagnostics['separation']:.3f}")
print("   (Higher = better shape distinction)")
print(f"\n2. Nearest Neighbor Same Shape: {diagnostics['nn_same_shape']:.1%}")
print("   (Higher = better shape clustering)")
print(f"\n3. Collapsed Dimensions: {diagnostics['collapsed_dims']}/{TOTAL_EMBEDDING_DIM}")
print("   (Lower = healthier embeddings)")
print(f"\n4. Embedding Norm Std: {diagnostics['norm_std']:.4f}")
print("   (Lower = more consistent embedding scales)")
print(f"\n5. Overall Coherence: {diagnostics['overall_coherence']:.1%}")
print("   (V1 was ~20%, V2 was ~16%)")

print(f"\nDiagnostics saved to {OUTPUT_DIR}/diagnostics_v3.json")
print("\n" + "="*60)
print("If V3 fails, check these metrics to understand why:")
print("  - Low separation → Components not distinct enough")
print("  - Low NN same shape → Shape embedding not dominating")
print("  - Many collapsed dims → Optimization issue")
print("  - High norm variance → Imbalanced training")
print("="*60)
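Items 3 and 4 of the summary read `collapsed_dims` and `norm_std`, which an earlier cell outside this section computes. Below is a hedged sketch of what such checks commonly measure. A dimension counts as "collapsed" when its standard deviation across all blocks is near zero, and the norm spread is the standard deviation of per-block L2 norms. The function `embedding_health` and the `collapse_tol=1e-3` threshold are hypothetical, and the notebook's actual criterion may differ. The explicit `int()`/`float()` casts matter for `json.dump`: NumPy integer and `float32` scalars are not JSON-serializable.

```python
import json

import numpy as np


def embedding_health(embeddings: np.ndarray, collapse_tol: float = 1e-3) -> dict:
    """Collapse count and norm spread for an (n_blocks, dim) embedding matrix.

    collapse_tol is a HYPOTHETICAL threshold: a dimension counts as collapsed
    when its standard deviation across all blocks falls below it.
    """
    dim_std = embeddings.std(axis=0)            # spread of each dimension over all blocks
    norms = np.linalg.norm(embeddings, axis=1)  # L2 norm of each block's embedding
    return {
        "collapsed_dims": int(np.sum(dim_std < collapse_tol)),  # plain int is JSON-safe
        "norm_mean": float(norms.mean()),
        "norm_std": float(norms.std()),
    }


rng = np.random.default_rng(0)
emb = rng.normal(size=(100, 8))
emb[:, 3] = 0.5  # force one dimension to be constant, i.e. collapsed

health = embedding_health(emb)
print(json.dumps(health, indent=2))
```

A dimension that never varies carries no information about any block, so a high `collapsed_dims` count means part of `TOTAL_EMBEDDING_DIM` is unused capacity. That is why the checklist above points to an optimization issue.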