# Superconductor VAE — Generative Evaluation & Novel Discovery

Evaluate generative capability and discover novel superconductors.

**Core idea**: Phase 2 self-supervised training generates thousands of formulas per sub-epoch.
While training the model, we simultaneously check every generated formula against:
1. The **45 held-out superconductors** (holdout recovery)
2. The training set (known vs novel classification)

This means Phase 2 training IS the search — it improves the model AND discovers materials at the same time.
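
As a rough illustration of the two checks above, the sketch below labels a generated formula as a holdout recovery, a known training formula, or a novel candidate. It assumes exact string matching after whitespace stripping, which is how Cell 6 builds its `training_formulas` and `holdout_formulas` sets. The helper name `classify_formula` and the toy formula sets are hypothetical, not part of the repo.

```python
def classify_formula(formula, training_formulas, holdout_formulas):
    """Label a generated formula by set membership (hypothetical helper)."""
    f = formula.strip()
    # Holdout is checked first: these formulas are withheld from training,
    # so a match here is a genuine recovery rather than memorization.
    if f in holdout_formulas:
        return 'holdout_recovery'
    if f in training_formulas:
        return 'known'
    return 'novel'


# Toy sets standing in for the real training / holdout indexes.
training = {'MgB2', 'YBa2Cu3O7'}
holdout = {'HgBa2Ca2Cu3O8'}

for generated in ['MgB2', ' HgBa2Ca2Cu3O8 ', 'LaFeAsO']:
    print(generated.strip(), '->', classify_formula(generated, training, holdout))
```

Exact string matching is cheap enough to run on every generated formula, but it treats differently written versions of the same composition as distinct formulas.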

**Workflow**:
1. **Baseline**: Quick roundtrip validation (encode holdout → decode, ~1 min)
2. **Phase 2 Loop** (main event): Train N sub-epochs, each generating + filtering + checking formulas
   - Holdout recoveries flagged in real-time
   - Novel candidates logged to `phase2_discoveries.jsonl`
3. **Post-training search**: Run targeted holdout search on the **improved** model
4. **Broad exploration**: Generate novel candidates from the improved Z-space

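**GPU**: Any GPU works. Batch sizes, search budgets, and sample counts scale automatically with available VRAM, so larger GPUs process more samples per sub-epoch. Without a GPU the notebook falls back to CPU, which is slow.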
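
The auto-scaling rule applied in Cell 4 can be restated as two small functions. This is only a summary of the thresholds and formula that appear in that cell; the function names are hypothetical, and the VRAM figures in the usage loop are round approximations rather than values reported by any particular device.

```python
def pick_search_budget(vram_gb):
    """Map detected VRAM (GB) to a search-budget tier, mirroring Cell 4."""
    if vram_gb >= 70:
        return 'xlarge'
    if vram_gb >= 38:
        return 'large'
    if vram_gb >= 14:
        return 'medium'
    return 'small'


def novel_sample_count(vram_gb):
    """1,250 Z-space samples per GB of VRAM, clamped to [5,000, 100,000]."""
    return int(max(5000, min(100000, vram_gb * 1250)))


for label, vram in [('CPU only', 0.0), ('~15 GB', 15.0), ('~40 GB', 40.0), ('~80 GB', 80.0)]:
    print(f'{label:>9s}: budget={pick_search_budget(vram)}, samples={novel_sample_count(vram):,}')
```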

**Families**: YBCO, LSCO, Hg-cuprate, Tl-cuprate, Bi-cuprate, Iron-based, MgB2, Conventional, Other

**Prerequisites** (one-time setup):
- A checkpoint in `outputs/` (from `train_colab.ipynb` or uploaded manually)
- Data cache is built automatically on first run (~5 min)
- Z-cache is built automatically on first run (~2 min on A100)


## Cell 1: Mount Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## Cell 1b: Sync Repo with GitHub

Pull the latest code from GitHub so Colab matches your most recent push.

In [None]:
# Sync repo with GitHub (clone or pull latest changes)
_REPO = "/content/drive/My Drive/Colab Notebooks/SuperconductorVAE/superconductor-vae"
_GITHUB_URL = "https://github.com/jamesconde/superconductor-vae.git"

import subprocess, os

def _run_git(*args, **kwargs):
    """Run a git command in _REPO, return (stdout, stderr, returncode)."""
    r = subprocess.run(list(args), cwd=_REPO, capture_output=True, text=True, timeout=120, **kwargs)
    return r.stdout.strip(), r.stderr.strip(), r.returncode

if os.path.isdir(os.path.join(_REPO, '.git')):
    # Repo exists — ensure 'origin' remote is configured
    out, err, rc = _run_git("git", "remote", "get-url", "origin")
    if rc != 0:
        print("Git repo detected but no 'origin' remote — adding it...")
        _run_git("git", "remote", "add", "origin", _GITHUB_URL)
        print(f"  Added origin → {_GITHUB_URL}")

    print("Pulling latest from GitHub...")
    out, err, rc = _run_git("git", "pull", "--ff-only", "origin", "main")
    if rc != 0:
        # Try without specifying branch (maybe default branch differs)
        out, err, rc = _run_git("git", "pull", "--ff-only")
    print(out)
    if rc != 0:
        print(f"git pull failed (exit {rc}):")
        print(err)
        print("\nTo fix: delete and re-clone:")
        print(f"  !rm -rf '{_REPO}'")
        print("  Then re-run this cell.")
    else:
        out, _, _ = _run_git("git", "log", "--oneline", "-1")
        print(f"Current commit: {out}")

elif os.path.isdir(os.path.dirname(_REPO)):
    # Parent dir exists but repo not cloned yet — clone it
    print("Cloning repo from GitHub (first time)...")
    r = subprocess.run(
        ["git", "clone", _GITHUB_URL, _REPO],
        capture_output=True, text=True, timeout=300,
    )
    print(r.stdout.strip())
    if r.returncode != 0:
        print(f"Clone failed: {r.stderr.strip()}")
    else:
        out, _, _ = _run_git("git", "log", "--oneline", "-1")
        print(f"Cloned successfully. Current commit: {out}")
else:
    print(f"Parent directory not found: {os.path.dirname(_REPO)}")
    print("Make sure Google Drive is mounted and the path is correct.")
    print(f"Expected: {_REPO}")

## Cell 2: Configuration

In [None]:
# Path to the superconductor-vae repo on your Google Drive
REPO_PATH = "/content/drive/My Drive/Colab Notebooks/SuperconductorVAE/superconductor-vae"

# Checkpoint to evaluate (uses Drive-based outputs/)
CHECKPOINT = 'auto'  # 'auto' finds checkpoint_best.pt, or set explicit path

# --- Phase 2 training + discovery (the main event) ---
# Phase 2 generates formulas as part of self-supervised training.
# Every formula is checked against holdout set + training set.
# This IS the search — training and discovery happen simultaneously.
PHASE2_EPOCHS = 20          # Number of Phase 2 sub-epochs (set 0 for inference-only)
PHASE2_LR = 1e-5            # Learning rate for Phase 2 fine-tuning
PHASE2_HOLDOUT_INTERVAL = 5 # Run mini holdout search every N sub-epochs

# --- Search budget (for post-Phase-2 targeted holdout search) ---
# 'auto' scales with GPU VRAM. Override with 'small', 'medium', 'large', 'xlarge'.
SEARCH_BUDGET = 'auto'

# --- Novel discovery (for post-Phase-2 broad Z-space exploration) ---
# 'auto' scales with VRAM at 1,250 samples per GB, clamped to 5K-100K
# (roughly 18K on a T4, 50K on an A100-40, 100K on an A100-80)
NOVEL_N_SAMPLES = 'auto'

# Minimum predicted Tc (K) to report as novel candidate
NOVEL_MIN_TC = 10.0

## Cell 3: Install Dependencies

In [None]:
!pip install -q scipy matminer pymatgen scikit-learn

## Cell 4: Setup Paths and Verify Environment

In [None]:
import sys
import os
import json
import time
from pathlib import Path
from collections import defaultdict

import torch
import torch.nn.functional as F
import numpy as np

repo = Path(REPO_PATH)

# Add src/ to Python path
src_path = str(repo / 'src')
if src_path not in sys.path:
    sys.path.insert(0, src_path)

# Verify key files
required_files = {
    'Training data': repo / 'data/processed/supercon_fractions_contrastive.csv',
    'Holdout set': repo / 'data/GENERATIVE_HOLDOUT_DO_NOT_TRAIN.json',
    'VAE model': repo / 'src/superconductor/models/attention_vae.py',
    'Decoder model': repo / 'src/superconductor/models/autoregressive_decoder.py',
}
optional_files = {
    'Best checkpoint': repo / 'outputs/checkpoint_best.pt',
    'Z-cache': repo / 'outputs/latent_cache.pt',
    'Phase 2 discoveries': repo / 'outputs/phase2_discoveries.jsonl',
}

all_found = True
for name, path in required_files.items():
    exists = path.exists()
    status = 'OK' if exists else 'MISSING'
    print(f'  [{status}] {name}: {path.name}')
    if not exists:
        all_found = False

print()
for name, path in optional_files.items():
    exists = path.exists()
    status = 'OK' if exists else '---'
    print(f'  [{status}] {name}: {path.name}')

if not all_found:
    raise FileNotFoundError(f'Missing required files. Check REPO_PATH: {REPO_PATH}')

# Resolve checkpoint path
if CHECKPOINT == 'auto':
    CHECKPOINT_PATH = str(repo / 'outputs/checkpoint_best.pt')
    if not os.path.exists(CHECKPOINT_PATH):
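        # Sort numerically by epoch so that epoch_10 ranks above epoch_9
        # (a plain lexicographic sort would treat epoch_9 as the latest).
        import re

        def _epoch_num(p):
            m = re.search(r'checkpoint_epoch_(\d+)', p.name)
            return int(m.group(1)) if m else -1

        epoch_files = sorted((repo / 'outputs').glob('checkpoint_epoch_*.pt'), key=_epoch_num)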
        if epoch_files:
            CHECKPOINT_PATH = str(epoch_files[-1])
        else:
            raise FileNotFoundError('No checkpoint found in outputs/')
else:
    CHECKPOINT_PATH = str(repo / CHECKPOINT)

print(f'\nCheckpoint: {Path(CHECKPOINT_PATH).name}')
print(f'Size: {os.path.getsize(CHECKPOINT_PATH)/1e6:.1f} MB')

# GPU info + VRAM-aware scaling
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
VRAM_GB = 0.0
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    VRAM_GB = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f'\nGPU: {gpu_name}')
    print(f'VRAM: {VRAM_GB:.1f} GB')
else:
    print('\nWARNING: No GPU detected. Running on CPU (slow).')

# --- Auto-scale search parameters based on VRAM ---
if SEARCH_BUDGET == 'auto':
    if VRAM_GB >= 70:
        SEARCH_BUDGET = 'xlarge'
    elif VRAM_GB >= 38:
        SEARCH_BUDGET = 'large'
    elif VRAM_GB >= 14:
        SEARCH_BUDGET = 'medium'
    else:
        SEARCH_BUDGET = 'small'

BUDGET_CONFIGS = {
    'small':  {'encode_batch': 128, 'decode_batch': 64,  'n_perturbations': 50,
               'noise_scales': [0.05, 0.1, 0.2, 0.3], 'n_seeds': 20,
               'n_interp_steps': 10, 'max_pairs': 50,
               'n_temp_samples': 15, 'temps': [0.01, 0.1, 0.3, 0.5],
               'n_pca_comps': 10, 'pca_steps': 10, 'n_neighbors': 60},
    'medium': {'encode_batch': 256, 'decode_batch': 128, 'n_perturbations': 100,
               'noise_scales': [0.02, 0.05, 0.08, 0.1, 0.15, 0.2, 0.3, 0.5], 'n_seeds': 30,
               'n_interp_steps': 15, 'max_pairs': 100,
               'n_temp_samples': 30, 'temps': [0.01, 0.05, 0.1, 0.2, 0.3, 0.5, 0.7, 1.0],
               'n_pca_comps': 20, 'pca_steps': 20, 'n_neighbors': 100},
    'large':  {'encode_batch': 512, 'decode_batch': 256, 'n_perturbations': 200,
               'noise_scales': [0.02, 0.05, 0.08, 0.1, 0.12, 0.15, 0.2, 0.3, 0.4, 0.5],
               'n_seeds': 50, 'n_interp_steps': 20, 'max_pairs': 150,
               'n_temp_samples': 50, 'temps': [0.01, 0.03, 0.05, 0.1, 0.15, 0.2, 0.3, 0.5, 0.7, 1.0],
               'n_pca_comps': 30, 'pca_steps': 30, 'n_neighbors': 150},
    'xlarge': {'encode_batch': 1024, 'decode_batch': 512, 'n_perturbations': 300,
               'noise_scales': [0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5],
               'n_seeds': 80, 'n_interp_steps': 25, 'max_pairs': 200,
               'n_temp_samples': 80, 'temps': [0.01, 0.02, 0.05, 0.08, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.7, 1.0],
               'n_pca_comps': 40, 'pca_steps': 40, 'n_neighbors': 200},
}
B = BUDGET_CONFIGS[SEARCH_BUDGET]

if NOVEL_N_SAMPLES == 'auto':
    NOVEL_N_SAMPLES = int(max(5000, min(100000, VRAM_GB * 1250)))

print(f'\nSearch budget: {SEARCH_BUDGET}')
print(f'  encode_batch={B["encode_batch"]}, decode_batch={B["decode_batch"]}')
print(f'  seeds={B["n_seeds"]}, perturbations={B["n_perturbations"]}, noise_scales={len(B["noise_scales"])}')
print(f'  temps={len(B["temps"])}, temp_samples={B["n_temp_samples"]}')
print(f'Novel discovery: {NOVEL_N_SAMPLES:,} Z-space samples')
print(f'Phase 2 epochs: {PHASE2_EPOCHS}')

## Cell 5: Constants & Helpers

In [None]:
from superconductor.models.attention_vae import FullMaterialsVAE
from superconductor.models.autoregressive_decoder import (
    EnhancedTransformerDecoder, IDX_TO_TOKEN, TOKEN_TO_IDX,
    PAD_IDX, START_IDX, END_IDX,
)
from superconductor.data.canonical_ordering import CanonicalOrderer
from superconductor.tokenizer.fraction_tokenizer import FractionAwareTokenizer

# V13+ tokenizer for semantic fraction + isotope tokens (vocab 4752)
_frac_vocab = str(repo / 'data' / 'fraction_vocab.json')
_iso_vocab = str(repo / 'data' / 'isotope_vocab.json')
V13_TOKENIZER = FractionAwareTokenizer(_frac_vocab, isotope_vocab_path=_iso_vocab)
print(f'V13 tokenizer: vocab_size={V13_TOKENIZER.vocab_size}')

PROJECT_ROOT = repo
HOLDOUT_PATH = repo / 'data' / 'GENERATIVE_HOLDOUT_DO_NOT_TRAIN.json'
CACHE_DIR = repo / 'data' / 'processed' / 'cache'
OUTPUT_DIR = repo / 'outputs'

# TC normalization constants — loaded from cache_meta.json after data loading.
# These are placeholder defaults; overwritten in Cell 6 (Load Model & Data).
TC_MEAN = 2.725219433789196
TC_STD = 1.3527019896187407

FAMILY_14_NAMES = [
    'NOT_SC', 'BCS_CONVENTIONAL', 'CUPRATE_YBCO', 'CUPRATE_LSCO',
    'CUPRATE_BSCCO', 'CUPRATE_TBCCO', 'CUPRATE_HBCCO', 'CUPRATE_OTHER',
    'IRON_PNICTIDE', 'IRON_CHALCOGENIDE', 'MGB2_TYPE',
    'HEAVY_FERMION', 'ORGANIC', 'OTHER_UNKNOWN',
]

HOLDOUT_FAMILY_TO_14 = {
    'YBCO': 2, 'LSCO': 3, 'Hg-cuprate': 6, 'Tl-cuprate': 5,
    'Bi-cuprate': 4, 'Iron-based': 8, 'MgB2': 10,
    'Conventional': 1, 'Other': 13,
}

ELEMENT_TO_Z = {
    'H':1,'He':2,'Li':3,'Be':4,'B':5,'C':6,'N':7,'O':8,'F':9,'Ne':10,
    'Na':11,'Mg':12,'Al':13,'Si':14,'P':15,'S':16,'Cl':17,'Ar':18,'K':19,'Ca':20,
    'Sc':21,'Ti':22,'V':23,'Cr':24,'Mn':25,'Fe':26,'Co':27,'Ni':28,'Cu':29,'Zn':30,
    'Ga':31,'Ge':32,'As':33,'Se':34,'Br':35,'Kr':36,'Rb':37,'Sr':38,'Y':39,'Zr':40,
    'Nb':41,'Mo':42,'Tc':43,'Ru':44,'Rh':45,'Pd':46,'Ag':47,'Cd':48,'In':49,'Sn':50,
    'Sb':51,'Te':52,'I':53,'Xe':54,'Cs':55,'Ba':56,'La':57,'Ce':58,'Pr':59,'Nd':60,
    'Pm':61,'Sm':62,'Eu':63,'Gd':64,'Tb':65,'Dy':66,'Ho':67,'Er':68,'Tm':69,'Yb':70,
    'Lu':71,'Hf':72,'Ta':73,'W':74,'Re':75,'Os':76,'Ir':77,'Pt':78,'Au':79,'Hg':80,
    'Tl':81,'Pb':82,'Bi':83,'Po':84,'At':85,'Rn':86,'Fr':87,'Ra':88,'Ac':89,'Th':90,
    'Pa':91,'U':92,
}
Z_TO_ELEMENT = {v: k for k, v in ELEMENT_TO_Z.items()}
_CANONICALIZER = CanonicalOrderer()


def tokens_to_formula(token_ids):
    """Convert token IDs to formula string using V13+ tokenizer."""
    ids = [int(t) for t in token_ids]
    return V13_TOKENIZER.decode(ids, strip_special=True)


def denormalize_tc(tc_norm):
    """Convert normalized Tc prediction back to Kelvin."""
    tc_log = tc_norm * TC_STD + TC_MEAN
    return max(0.0, float(np.expm1(tc_log)))


def parse_formula_elements(formula):
    """Extract {element: fraction_value} from formula string."""
    try:
        elements = _CANONICALIZER.parse_formula(formula)
        if not elements:
            return {}
        result = {}
        for ef in elements:
            result[ef.element] = result.get(ef.element, 0) + ef.fraction_value
        return result
    except Exception:
        return {}


def element_similarity(formula_a, formula_b):
    """Compositional similarity: Jaccard on elements + fraction overlap."""
    parsed_a = parse_formula_elements(formula_a)
    parsed_b = parse_formula_elements(formula_b)
    if not parsed_a or not parsed_b:
        return 0.0
    all_elements = set(parsed_a.keys()) | set(parsed_b.keys())
    shared = set(parsed_a.keys()) & set(parsed_b.keys())
    if not all_elements:
        return 0.0
    jaccard = len(shared) / len(all_elements)
    frac_sim = 0.0
    if shared:
        total_a = sum(parsed_a.values())
        total_b = sum(parsed_b.values())
        for elem in shared:
            fa = parsed_a[elem] / max(total_a, 1e-8)
            fb = parsed_b[elem] / max(total_b, 1e-8)
            frac_sim += min(fa, fb)
    return 0.5 * jaccard + 0.5 * frac_sim


def element_overlap_score(target_elements, cache_elem_idx):
    """Score a training sample by how many target elements it contains."""
    candidate_elements = set(int(z) for z in cache_elem_idx if z > 0)
    shared_elems = target_elements & candidate_elements
    all_elem = target_elements | candidate_elements
    if not all_elem:
        return (0, 0.0)
    return (len(shared_elems), len(shared_elems) / len(all_elem))


def slerp(z1, z2, t):
    """Spherical linear interpolation."""
    z1_norm = F.normalize(z1, dim=-1)
    z2_norm = F.normalize(z2, dim=-1)
    omega = torch.acos(torch.clamp(
        (z1_norm * z2_norm).sum(dim=-1, keepdim=True), -1.0, 1.0
    ))
    omega = omega.clamp(min=1e-6)
    sin_omega = torch.sin(omega)
    if sin_omega.abs().min() < 1e-6:
        return (1 - t) * z1 + t * z2
    s1 = torch.sin((1 - t) * omega) / sin_omega
    s2 = torch.sin(t * omega) / sin_omega
    mag1 = z1.norm(dim=-1, keepdim=True)
    mag2 = z2.norm(dim=-1, keepdim=True)
    mag = (1 - t) * mag1 + t * mag2
    return (s1 * z1_norm + s2 * z2_norm) * mag

## Cell 6: Load Model & Data

In [None]:
def load_models(checkpoint_path):
    """Load encoder and decoder from checkpoint."""
    print(f'Loading checkpoint: {Path(checkpoint_path).name}')
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)

    enc_state_raw = checkpoint.get('encoder_state_dict', {})
    magpie_dim = 145
    for k, v in enc_state_raw.items():
        if 'magpie_encoder' in k and k.endswith('.weight') and v.dim() == 2:
            magpie_dim = v.shape[1]
            break

    enc_state = {k.replace('_orig_mod.', ''): v for k, v in enc_state_raw.items()}

    has_numden_head = any('numden_head.' in k for k in enc_state)
    old_numden_arch = False
    if 'numden_head.0.weight' in enc_state:
        if enc_state['numden_head.0.weight'].shape[0] == 128:
            old_numden_arch = True

    encoder = FullMaterialsVAE(
        n_elements=118, element_embed_dim=128, n_attention_heads=8,
        magpie_dim=magpie_dim, fusion_dim=256, encoder_hidden=[512, 256],
        latent_dim=2048, decoder_hidden=[256, 512], dropout=0.1
    ).to(DEVICE)

    if old_numden_arch:
        import torch.nn as nn
        encoder.numden_head = nn.Sequential(
            nn.Linear(2048, 128), nn.ReLU(), nn.Linear(128, encoder.max_elements * 2),
        ).to(DEVICE)

    encoder.load_state_dict(enc_state, strict=False)

    dec_state_raw = checkpoint.get('decoder_state_dict', {})
    dec_state = {k.replace('_orig_mod.', ''): v for k, v in dec_state_raw.items()}

    stoich_weight_key = 'stoich_to_memory.0.weight'
    stoich_dim = dec_state[stoich_weight_key].shape[1] if stoich_weight_key in dec_state else 37

    dec_vocab_size = checkpoint.get('tokenizer_vocab_size', None)
    if dec_vocab_size is None and 'token_embedding.weight' in dec_state:
        dec_vocab_size = dec_state['token_embedding.weight'].shape[0]

    _d_model = checkpoint.get('d_model', None)
    if _d_model is None and 'token_embedding.weight' in dec_state:
        _d_model = dec_state['token_embedding.weight'].shape[1]
    _d_model = _d_model or 512
    _dim_ff = checkpoint.get('dim_feedforward', None)
    if _dim_ff is None and 'transformer_decoder.layers.0.linear1.weight' in dec_state:
        _dim_ff = dec_state['transformer_decoder.layers.0.linear1.weight'].shape[0]
    _dim_ff = _dim_ff or 2048
    _nhead = checkpoint.get('nhead', 8)
    _num_layers = checkpoint.get('num_layers', 12)
    # Resolve max_len: explicit checkpoint key first, then the PE buffer shape, then a default
    _max_len = checkpoint.get('max_formula_len', None)
    if _max_len is None and 'pos_encoding.pe' in dec_state:
        _max_len = dec_state['pos_encoding.pe'].shape[1]
    _max_len = _max_len or 90  # Safe fallback: covers all formula lengths

    decoder = EnhancedTransformerDecoder(
        latent_dim=2048, d_model=_d_model, nhead=_nhead, num_layers=_num_layers,
        dim_feedforward=_dim_ff, dropout=0.1, max_len=_max_len,
        n_memory_tokens=16, encoder_skip_dim=256,
        use_skip_connection=False, use_stoich_conditioning=True,
        max_elements=12, n_stoich_tokens=4,
        vocab_size=dec_vocab_size, stoich_input_dim=stoich_dim,
    ).to(DEVICE)
    decoder.load_state_dict(dec_state, strict=False)

    encoder.eval()
    decoder.eval()

    epoch = checkpoint.get('epoch', '?')
    best_exact = checkpoint.get('best_exact', 0)
    print(f'  Epoch {epoch}, best_exact={best_exact:.4f}, magpie_dim={magpie_dim}, '
          f'd_model={_d_model}, dim_ff={_dim_ff}, max_len={_max_len}, vocab={dec_vocab_size}')
    return encoder, decoder, magpie_dim, has_numden_head, epoch


def _build_cache_if_needed():
    """Build data cache from CSV if it doesn't exist (one-time, ~5 min)."""
    if (CACHE_DIR / 'cache_meta.json').exists():
        # Validate cache was built with V13 tokenization (not stale V12)
        with open(CACHE_DIR / 'cache_meta.json') as _f:
            _meta = json.load(_f)
        if _meta.get('use_semantic_fractions', False):
            return  # Cache is V13, good
        print('WARNING: Existing cache uses V12 tokenization. Rebuilding for V13...')
    print('=' * 60)
    print('DATA CACHE NOT FOUND \u2014 building from CSV (one-time, ~5 min)')
    print('=' * 60)
    _scripts_dir = str(repo)
    if _scripts_dir not in sys.path:
        sys.path.insert(0, _scripts_dir)
    import scripts.train_v12_clean as _train
    _train.TRAIN_CONFIG['use_semantic_fractions'] = True
    _train.TRAIN_CONFIG['use_isotope_tokens'] = True
    _train.TRAIN_CONFIG['max_formula_len'] = 30
    _train.TRAIN_CONFIG['contrastive_mode'] = True
    _train.TRAIN_CONFIG['num_epochs'] = 0
    _result = _train.load_and_prepare_data()
    del _result
    print('Cache built successfully.')


def load_data(magpie_dim):
    """Load cached tensors. Builds cache from CSV if needed (one-time)."""
    _build_cache_if_needed()
    data = {
        'elem_idx': torch.load(CACHE_DIR / 'element_indices.pt', map_location='cpu', weights_only=True),
        'elem_frac': torch.load(CACHE_DIR / 'element_fractions.pt', map_location='cpu', weights_only=True),
        'elem_mask': torch.load(CACHE_DIR / 'element_mask.pt', map_location='cpu', weights_only=True),
        'tc': torch.load(CACHE_DIR / 'tc_tensor.pt', map_location='cpu', weights_only=True),
        'magpie': torch.load(CACHE_DIR / 'magpie_tensor.pt', map_location='cpu', weights_only=True),
        'is_sc': torch.load(CACHE_DIR / 'is_sc_tensor.pt', map_location='cpu', weights_only=True),
        'tokens': torch.load(CACHE_DIR / 'formula_tokens.pt', map_location='cpu', weights_only=True),
    }
    if data['magpie'].shape[1] > magpie_dim:
        data['magpie'] = data['magpie'][:, :magpie_dim]
    with open(CACHE_DIR / 'cache_meta.json') as _f:
        meta = json.load(_f)
    data['train_indices'] = meta.get('train_indices', list(range(len(data['elem_idx']))))
    print(f'  {len(data["elem_idx"])} total samples, {len(data["train_indices"])} train')
    return data, meta


# Load everything
encoder, decoder, magpie_dim, has_numden_head, model_epoch = load_models(CHECKPOINT_PATH)
data, cache_meta = load_data(magpie_dim)

# Override TC normalization constants from cache metadata (more accurate than hardcoded)
TC_MEAN = cache_meta.get('tc_mean', TC_MEAN)
TC_STD = cache_meta.get('tc_std', TC_STD)
print(f'  TC normalization: mean={TC_MEAN:.6f}, std={TC_STD:.6f}')

with open(HOLDOUT_PATH) as f:
    holdout = json.load(f)
holdout_samples = holdout['holdout_samples']
print(f'\n{len(holdout_samples)} holdout materials loaded')
print(f'Families: {sorted(set(s["family"] for s in holdout_samples))}')

# Build training formula set for novelty checking
print('Building training formula index...')
training_formulas = set()
for i in data['train_indices']:
    f = tokens_to_formula(data['tokens'][i])
    if f:
        training_formulas.add(f.strip())
holdout_formulas = set(s['formula'].strip() for s in holdout_samples)
print(f'  {len(training_formulas)} unique training formulas, {len(holdout_formulas)} holdout')

## Cell 7: Core Encode/Decode Functions (VRAM-scaled)

In [None]:
@torch.no_grad()
def encode_indices(encoder, data, indices):
    """Encode dataset indices -> Z vectors. Batch size scales with VRAM."""
    batch_size = B['encode_batch']
    all_z = []
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        idx_t = torch.tensor(batch_idx, dtype=torch.long)
        result = encoder.encode(
            data['elem_idx'][idx_t].to(DEVICE),
            data['elem_frac'][idx_t].to(DEVICE),
            data['elem_mask'][idx_t].to(DEVICE),
            data['magpie'][idx_t].to(DEVICE),
            data['tc'][idx_t].to(DEVICE),
        )
        all_z.append(result['z'].cpu())
    return torch.cat(all_z, dim=0)


@torch.no_grad()
def decode_z_batch(encoder, decoder, z_batch, temperature=0.01):
    """Decode Z vectors -> formula strings. Batch size scales with VRAM."""
    batch_size = B['decode_batch']
    all_formulas = []
    for start in range(0, len(z_batch), batch_size):
        z = z_batch[start:start + batch_size].to(DEVICE)
        fraction_output = encoder.fraction_head(z)
        fraction_pred = fraction_output[:, :encoder.max_elements]
        element_count_pred = fraction_output[:, -1]
        if has_numden_head and hasattr(encoder, 'numden_head'):
            numden_pred = encoder.numden_head(z)
            stoich_pred = torch.cat([fraction_pred, numden_pred, element_count_pred.unsqueeze(-1)], dim=-1)
        else:
            stoich_pred = torch.cat([fraction_pred, element_count_pred.unsqueeze(-1)], dim=-1)
        generated, _, _ = decoder.generate_with_kv_cache(
            z=z, stoich_pred=stoich_pred, temperature=temperature,
        )
        for i in range(len(z)):
            all_formulas.append(tokens_to_formula(generated[i]))
    return all_formulas


@torch.no_grad()
def decode_z_with_tc(encoder, decoder, z_batch, temperature=0.01):
    """Decode Z vectors -> (formula, tc_kelvin, sc_prob) tuples."""
    batch_size = B['decode_batch']
    results = []
    for start in range(0, len(z_batch), batch_size):
        z = z_batch[start:start + batch_size].to(DEVICE)
        # Get Tc + SC predictions from encoder heads
        tc_pred_norm = encoder.tc_head(z).squeeze(-1)
        sc_logits = encoder.sc_head(z).squeeze(-1) if hasattr(encoder, 'sc_head') else None
        # Get family prediction
        family_pred = None
        if hasattr(encoder, 'hierarchical_family_head'):
            fam_logits = encoder.hierarchical_family_head(z)
            family_pred = fam_logits.argmax(dim=-1)
        # Decode formula
        fraction_output = encoder.fraction_head(z)
        fraction_pred = fraction_output[:, :encoder.max_elements]
        element_count_pred = fraction_output[:, -1]
        if has_numden_head and hasattr(encoder, 'numden_head'):
            numden_pred = encoder.numden_head(z)
            stoich_pred = torch.cat([fraction_pred, numden_pred, element_count_pred.unsqueeze(-1)], dim=-1)
        else:
            stoich_pred = torch.cat([fraction_pred, element_count_pred.unsqueeze(-1)], dim=-1)
        generated, _, _ = decoder.generate_with_kv_cache(
            z=z, stoich_pred=stoich_pred, temperature=temperature,
        )
        for i in range(len(z)):
            formula = tokens_to_formula(generated[i])
            tc_k = denormalize_tc(tc_pred_norm[i].item())
            sc_p = torch.sigmoid(sc_logits[i]).item() if sc_logits is not None else 0.5
            fam = int(family_pred[i].item()) if family_pred is not None else -1
            results.append({'formula': formula, 'tc_kelvin': tc_k, 'sc_prob': sc_p, 'family': fam})
    return results


def find_element_neighbors(target_formula, data, top_k=100):
    """Find training samples sharing elements with target."""
    parsed = parse_formula_elements(target_formula)
    if not parsed:
        return []
    target_atomic_nums = set()
    for elem in parsed.keys():
        z = ELEMENT_TO_Z.get(elem)
        if z:
            target_atomic_nums.add(z)
    scores = []
    for i in data['train_indices']:
        n_shared, jaccard = element_overlap_score(target_atomic_nums, data['elem_idx'][i])
        if n_shared > 0:
            scores.append((i, n_shared, jaccard))
    scores.sort(key=lambda x: (-x[1], -x[2]))
    return [s[0] for s in scores[:top_k]]

---
# PART 1: Baseline Roundtrip Validation (Pre-Phase-2)

Encode each holdout target → decode → check roundtrip fidelity.
This establishes the baseline before Phase 2 training improves the model.

---
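
To make the "Sim" column in the roundtrip table easier to read, here is a minimal, dependency-free restatement of the similarity score defined in Cell 5 (`element_similarity`). It takes already-parsed `{element: amount}` dicts instead of formula strings, so it does not rely on `CanonicalOrderer`; the function name `composition_similarity` and the example compositions are illustrative assumptions.

```python
def composition_similarity(comp_a, comp_b):
    """0.5 * Jaccard(element sets) + 0.5 * overlap of normalized fractions."""
    if not comp_a or not comp_b:
        return 0.0
    all_elems = set(comp_a) | set(comp_b)
    shared = set(comp_a) & set(comp_b)
    jaccard = len(shared) / len(all_elems)
    total_a = sum(comp_a.values())
    total_b = sum(comp_b.values())
    # For each shared element, credit the smaller of the two atomic fractions.
    frac_sim = sum(
        min(comp_a[e] / max(total_a, 1e-8), comp_b[e] / max(total_b, 1e-8))
        for e in shared
    )
    return 0.5 * jaccard + 0.5 * frac_sim


ybco7 = {'Y': 1, 'Ba': 2, 'Cu': 3, 'O': 7}
ybco6 = {'Y': 1, 'Ba': 2, 'Cu': 3, 'O': 6}
mgb2 = {'Mg': 1, 'B': 2}
mg2b4 = {'Mg': 2, 'B': 4}

print(f'same composition : {composition_similarity(ybco7, ybco7):.4f}')
print(f'oxygen off by one: {composition_similarity(ybco7, ybco6):.4f}')
print(f'rescaled formula : {composition_similarity(mgb2, mg2b4):.4f}')
print(f'no shared element: {composition_similarity(ybco7, mgb2):.4f}')
```

Because fractions are normalized before comparison, a rescaled formula scores the same as an identical one. A similarity of 1.0 therefore indicates the same composition, not necessarily the same formula string.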

## Cell 8: Roundtrip Validation

In [None]:
@torch.no_grad()
def full_forward(encoder, decoder, elem_idx, elem_frac, elem_mask, magpie, tc, temperature=0.01):
    """Run full encoder forward (all heads) + formula decoder."""
    enc_out = encoder(
        elem_idx.to(DEVICE), elem_frac.to(DEVICE), elem_mask.to(DEVICE),
        magpie.to(DEVICE), tc.to(DEVICE),
    )
    z = enc_out['z']
    tc_pred_norm = enc_out['tc_pred'][0].item()
    magpie_pred = enc_out['magpie_pred'][0].cpu()
    fraction_pred = enc_out['fraction_pred']
    element_count_pred = enc_out['element_count_pred']
    numden_pred = enc_out.get('numden_pred')
    if numden_pred is not None:
        stoich_pred = torch.cat([fraction_pred, numden_pred, element_count_pred.unsqueeze(-1)], dim=-1)
    else:
        stoich_pred = torch.cat([fraction_pred, element_count_pred.unsqueeze(-1)], dim=-1)
    generated, log_probs, entropy = decoder.generate_with_kv_cache(
        z=z, stoich_pred=stoich_pred, temperature=temperature,
    )
    result = {
        'formula': tokens_to_formula(generated[0]),
        'tc_pred_kelvin': denormalize_tc(tc_pred_norm),
        'magpie_pred': magpie_pred,
    }
    tc_class_logits = enc_out.get('tc_class_logits')
    if tc_class_logits is not None:
        result['tc_class'] = torch.softmax(tc_class_logits[0].cpu(), dim=-1).argmax().item()
    sc_pred = enc_out.get('sc_pred')
    if sc_pred is not None:
        result['sc_prob'] = torch.sigmoid(sc_pred[0]).item()
        result['sc_pred'] = result['sc_prob'] > 0.5
    family_composed = enc_out.get('family_composed_14')
    if family_composed is not None:
        result['family_pred_14'] = family_composed[0].cpu().argmax().item()
    hp_pred = enc_out.get('hp_pred')
    if hp_pred is not None:
        result['hp_prob'] = torch.sigmoid(hp_pred[0]).item()
    return result


# Run roundtrip validation
print('=' * 80)
print(f'ROUNDTRIP VALIDATION — Epoch {model_epoch}')
print('=' * 80)
print(f'{"Family":<14s} {"True Tc":>8s} {"Pred Tc":>9s} {"Err":>7s} {"Sim":>5s} {"SC?":>5s} {"FamPred":<16s} | Formula')
print('-' * 120)

roundtrip_results = []
tc_errors = []
magpie_mses = []
formula_sims = []
sc_correct = 0
family_correct = 0

for sample in holdout_samples:
    formula = sample['formula']
    true_tc = sample['Tc']
    family = sample['family']
    orig_idx = sample.get('original_index')
    if orig_idx is None:
        continue
    idx_t = torch.tensor([orig_idx], dtype=torch.long)
    decoded = full_forward(
        encoder, decoder,
        data['elem_idx'][idx_t], data['elem_frac'][idx_t],
        data['elem_mask'][idx_t], data['magpie'][idx_t], data['tc'][idx_t],
    )
    pred_tc = decoded['tc_pred_kelvin']
    gen_formula = decoded['formula']
    tc_err = abs(pred_tc - true_tc)
    tc_errors.append(tc_err)
    mag_mse = F.mse_loss(decoded['magpie_pred'], data['magpie'][orig_idx]).item()
    magpie_mses.append(mag_mse)
    sim = element_similarity(gen_formula, formula)
    formula_sims.append(sim)
    exact = gen_formula.strip() == formula.strip()
    sc_str = f"{decoded.get('sc_prob', 0):.2f}"
    if decoded.get('sc_pred', False):
        sc_correct += 1
    family_pred_str = ''
    if 'family_pred_14' in decoded:
        pred_idx = decoded['family_pred_14']
        pred_name = FAMILY_14_NAMES[pred_idx]
        true_idx = HOLDOUT_FAMILY_TO_14.get(family, -1)
        match = pred_idx == true_idx
        if match:
            family_correct += 1
        family_pred_str = f"{pred_name} {'OK' if match else 'X'}"
    exact_str = ' [EXACT]' if exact else ''
    # tc_err is an absolute error, so print it without a forced sign.
    print(f'  [{family:<12s}] {true_tc:8.1f} {pred_tc:9.1f} {tc_err:7.1f} {sim:5.3f} {sc_str:>5s} {family_pred_str:<16s} | {formula}')
    # exact_str is empty for non-exact matches, so one print covers both cases.
    print(f'     -> {gen_formula}{exact_str}')
    roundtrip_results.append({
        'formula': formula, 'generated': gen_formula, 'exact': exact,
        'similarity': sim, 'true_tc': true_tc, 'pred_tc': pred_tc,
        'tc_error': tc_err, 'family': family,
        'sc_prob': decoded.get('sc_prob'), 'family_pred': family_pred_str,
    })

In [None]:
# Roundtrip Summary
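tc_arr = np.array(tc_errors)
sim_arr = np.array(formula_sims)
mag_arr = np.array(magpie_mses)

# Holdout samples without an 'original_index' are skipped in the loop above,
# so report against the number actually evaluated rather than a hardcoded 45.
n_eval = len(roundtrip_results)
if n_eval == 0:
    raise RuntimeError('No holdout samples were evaluated; check original_index in the holdout file.')
n_exact = sum(r['exact'] for r in roundtrip_results)

print('\n' + '=' * 80)
print('ROUNDTRIP SUMMARY')
print('=' * 80)
print('\nTc Prediction:')
print(f'  MAE: {tc_arr.mean():.2f} K (median: {np.median(tc_arr):.2f} K)')
print(f'  Within 1K: {(tc_arr < 1).sum()}/{n_eval} | Within 5K: {(tc_arr < 5).sum()}/{n_eval}')
print('\nFormula Roundtrip:')
print(f'  Mean similarity: {sim_arr.mean():.3f}')
print(f'  Exact string matches: {n_exact}/{n_eval}')
print(f'  Same composition (similarity >= 0.999): {(sim_arr >= 0.999).sum()}/{n_eval}')
print(f'  > 0.95 similarity: {(sim_arr > 0.95).sum()}/{n_eval}')
print(f'\nSC Classification: {sc_correct}/{n_eval} ({sc_correct/n_eval*100:.1f}%)')
print(f'Family Classification: {family_correct}/{n_eval} ({family_correct/n_eval*100:.1f}%)')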
print(f'Magpie MSE: {mag_arr.mean():.6f}')

---
# PHASE 2: Training + Discovery (Main Event)

Phase 2 generates formulas as part of self-supervised training. Each sub-epoch:
1. Samples Z-vectors (perturbation, SLERP, PCA, element-anchored)
2. Generates formulas via the decoder
3. Filters for chemical validity (parse, plausibility, physics, constraints)
4. **Checks every valid formula against holdout set + training set**
5. Computes round-trip consistency losses and updates the model
6. Logs novel discoveries to `outputs/phase2_discoveries.jsonl`
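
Steps 4 and 6 can be sketched roughly as follows. This is not the repo's Phase 2 code: the helper name `log_novel_candidates`, the record fields (`sub_epoch`, `formula`, `tc_kelvin`), and the choice to keep holdout recoveries out of the discoveries log are all assumptions made for illustration. The candidate formulas and Tc values are placeholders.

```python
import json
import tempfile
from pathlib import Path


def log_novel_candidates(candidates, training_formulas, holdout_formulas, log_path, sub_epoch):
    """Append novel candidates to a JSONL file; return (holdout recoveries, novel count)."""
    recoveries, n_novel = [], 0
    with open(log_path, 'a') as fh:
        for cand in candidates:
            formula = cand['formula'].strip()
            if formula in holdout_formulas:
                recoveries.append(formula)      # flagged, not logged as novel
            elif formula not in training_formulas:
                record = {'sub_epoch': sub_epoch, 'formula': formula,
                          'tc_kelvin': cand['tc_kelvin']}
                fh.write(json.dumps(record) + '\n')   # one JSON object per line
                n_novel += 1
    return recoveries, n_novel


# Toy run in a temporary directory (placeholder formulas and Tc values).
log_path = Path(tempfile.mkdtemp()) / 'phase2_discoveries.jsonl'
candidates = [
    {'formula': 'MgB2', 'tc_kelvin': 39.0},            # already in training
    {'formula': 'HgBa2Ca2Cu3O8', 'tc_kelvin': 130.0},  # held out
    {'formula': 'LaFeAsO', 'tc_kelvin': 20.0},         # in neither set
]
recoveries, n_novel = log_novel_candidates(
    candidates, {'MgB2'}, {'HgBa2Ca2Cu3O8'}, log_path, sub_epoch=1)
records = [json.loads(line) for line in log_path.read_text().splitlines()]
print(recoveries, n_novel, records)
```

Appending one JSON object per line lets the log grow across sub-epochs and be read back incrementally without parsing the whole file.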

Every `PHASE2_HOLDOUT_INTERVAL` sub-epochs, the loop also runs a mini targeted holdout search on one target per family.

---
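
Step 4 above amounts to a set lookup per generated formula. The sketch below is illustrative only and is not the code inside `SelfSupervisedEpoch`: the helper name `classify_formula`, the `'known'` label, and the whitespace-only normalization are assumptions made for this example. The `'novel'` and `'holdout_recovery'` labels match the `category` values read back from `phase2_discoveries.jsonl` later in this notebook.

```python
def classify_formula(formula, training_formulas, holdout_formulas):
    """Label a generated formula as 'holdout_recovery', 'known', or 'novel'.

    Hypothetical helper: the real check lives inside the Phase 2 engine
    and may normalize formulas differently.
    """
    key = formula.strip()
    if key in holdout_formulas:
        return 'holdout_recovery'
    if key in training_formulas:
        return 'known'
    return 'novel'


# Toy sets standing in for the notebook's training/holdout formula sets
training = {'YBa2Cu3O7', 'MgB2'}
holdout = {'HgBa2Ca2Cu3O8'}
generated = ['MgB2', ' HgBa2Ca2Cu3O8', 'LaFeAsO0.9F0.1']

labels = [classify_formula(f, training, holdout) for f in generated]
print(labels)
```

Using sets keeps each lookup O(1), which matters when every sub-epoch generates thousands of formulas.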

In [None]:
def search_single_target(encoder, decoder, data, target_formula, target_tc, target_family):
    """Targeted search for a single holdout formula. Budget scales with VRAM."""
    print(f'\n  TARGET: {target_formula} (Tc={target_tc}K, {target_family})')

    neighbor_indices = find_element_neighbors(target_formula, data, top_k=B['n_neighbors'])
    if len(neighbor_indices) < 3:
        print(f'    Only {len(neighbor_indices)} neighbors — skipping')
        return {'target': target_formula, 'target_tc': target_tc, 'target_family': target_family,
                'best_sim': 0.0, 'best_gen': '', 'n_unique': 0, 'n_total': 0,
                'exact': False, 'top_matches': [], 'top_frequent': []}

    z_neighbors = encode_indices(encoder, data, neighbor_indices)
    z_seeds = z_neighbors[:min(B['n_seeds'], len(z_neighbors))]
    print(f'    {len(neighbor_indices)} neighbors, {len(z_seeds)} seeds, Z norm: {z_neighbors.norm(dim=-1).mean():.2f}')

    all_candidates_z = []

    # Strategy 1: Fine-grained perturbation
    for z in z_seeds:
        z_exp = z.unsqueeze(0)
        for scale in B['noise_scales']:
            noise = torch.randn(B['n_perturbations'], z.shape[0]) * scale
            all_candidates_z.append(z_exp + noise)

    # Strategy 2: Pairwise interpolation (linear + SLERP)
    n_seeds = len(z_seeds)
    max_pairs = min(n_seeds * (n_seeds - 1) // 2, B['max_pairs'])
    pair_count = 0
    for i in range(n_seeds):
        for j in range(i + 1, n_seeds):
            if pair_count >= max_pairs:
                break
            z1 = z_seeds[i].unsqueeze(0)
            z2 = z_seeds[j].unsqueeze(0)
            for t in np.linspace(0.05, 0.95, B['n_interp_steps']):
                all_candidates_z.append((1 - t) * z1 + t * z2)
                all_candidates_z.append(slerp(z1, z2, float(t)))
            pair_count += 1
        if pair_count >= max_pairs:
            break

    # Strategy 3: Centroid + random direction walks
    centroid = z_seeds.mean(dim=0, keepdim=True)
    std = z_seeds.std(dim=0, keepdim=True).clamp(min=1e-6)
    for scale in [0.3, 0.5, 1.0, 1.5, 2.0]:
        n_dirs = max(30, B['n_seeds'])
        directions = torch.randn(n_dirs, z_seeds.shape[1])
        directions = F.normalize(directions, dim=-1)
        all_candidates_z.append(centroid + scale * std * directions)

    # Strategy 4: PCA-directed walks
    if len(z_neighbors) >= 10:
        z_np = z_neighbors.numpy()
        mean = z_np.mean(axis=0)
        centered = z_np - mean
        U, S, Vt = np.linalg.svd(centered, full_matrices=False)
        n_comp = min(B['n_pca_comps'], len(S))
        for c in range(n_comp):
            direction = torch.from_numpy(Vt[c]).float()
            std_along = S[c] / np.sqrt(len(z_neighbors) - 1)
            for alpha in np.linspace(-3.0, 3.0, B['pca_steps']):
                z_walked = torch.from_numpy(mean).float().unsqueeze(0) + alpha * std_along * direction.unsqueeze(0)
                all_candidates_z.append(z_walked)

    all_z = torch.cat(all_candidates_z, dim=0)
    print(f'    {len(all_z)} Z candidates (greedy decode)...')

    # Greedy decode
    t0 = time.time()
    greedy_formulas = decode_z_batch(encoder, decoder, all_z, temperature=0.01)
    print(f'    Greedy decoded in {time.time()-t0:.1f}s')

    # Temperature sampling
    temp_formulas = []
    for temp in B['temps']:
        for z_idx in range(min(len(z_seeds), B['n_seeds'] // 2)):
            z_repeated = z_seeds[z_idx].unsqueeze(0).repeat(B['n_temp_samples'], 1)
            temp_formulas.extend(decode_z_batch(encoder, decoder, z_repeated, temperature=temp))

    all_generated = greedy_formulas + temp_formulas
    unique_formulas = set(f for f in all_generated if f and len(f) > 1)

    # Score candidates
    formula_counts = defaultdict(int)
    for f in all_generated:
        if f and len(f) > 1:
            formula_counts[f] += 1

    best_sim = 0.0
    best_gen = ''
    top_matches = []
    candidates = []

    for formula_gen, count in formula_counts.items():
        sim = element_similarity(formula_gen, target_formula)
        candidates.append({'formula': formula_gen, 'count': count, 'similarity': sim})
        if sim > best_sim:
            best_sim = sim
            best_gen = formula_gen
        if sim >= 0.8:
            top_matches.append((formula_gen, sim, count))

    candidates.sort(key=lambda x: -x['similarity'])
    top_matches.sort(key=lambda x: -x[1])
    target_norm = target_formula.strip()
    exact = any(f.strip() == target_norm for f in unique_formulas)
    top_freq = sorted([(c['formula'], c['count']) for c in candidates], key=lambda x: -x[1])[:5]

    print(f'    Total: {len(all_generated)}, unique: {len(unique_formulas)}')
    print(f'    RESULT: best_sim={best_sim:.4f}, exact={"YES" if exact else "no"}')
    if top_matches[:3]:
        for f, s, c in top_matches[:3]:
            print(f'      sim={s:.4f} ({c}x): {f}')

    return {
        'target': target_formula, 'target_tc': target_tc, 'target_family': target_family,
        'exact': exact, 'best_sim': float(best_sim), 'best_gen': best_gen,
        'n_unique': len(unique_formulas), 'n_total': len(all_generated),
        'n_neighbors': len(neighbor_indices),
        'top_matches': [(f, float(s)) for f, s, c in top_matches[:10]],
        'top_frequent': top_freq,
    }

## PHASE 2: Training + Discovery (Main Event)

Self-supervised training generates formulas, checks them against holdouts, and discovers novel materials.
The `search_single_target` helper defined above is used for the periodic mini holdout searches during Phase 2 and again for the full targeted search in Cell 9.
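
The mini holdout search in the cell below temporarily shrinks the global budget dict `B` and restores it in a `finally` block. One possible way to package that pattern is a context manager. This is a sketch under assumptions: `temporary_budget` and `B_demo` are hypothetical names, not part of the notebook or the `superconductor` package.

```python
from contextlib import contextmanager


@contextmanager
def temporary_budget(budget, **overrides):
    """Temporarily override keys in a budget dict, restoring them on exit."""
    saved = {k: budget[k] for k in overrides if k in budget}
    added = [k for k in overrides if k not in budget]
    budget.update(overrides)
    try:
        yield budget
    finally:
        # Restore original values and drop keys that did not exist before
        budget.update(saved)
        for k in added:
            budget.pop(k, None)


# Toy budget standing in for the notebook's global B
B_demo = {'n_seeds': 50, 'temps': [0.1, 0.3, 0.5]}
with temporary_budget(B_demo, n_seeds=10, temps=[0.1]):
    inside = dict(B_demo)  # snapshot of the reduced budget

print(inside)
print(B_demo)
```

The original values are restored even if the search raises, which is the same guarantee the `try`/`finally` in the loop provides.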

In [None]:
from superconductor.training.self_supervised import (
    SelfSupervisedConfig, SelfSupervisedEpoch,
)

# --- Initialize Phase 2 engine ---
phase2_config = SelfSupervisedConfig(
    enabled=True,
    start='0',             # Activate immediately (we're past training)
    max_weight=0.1,
    n_samples=0,            # Auto-scale with VRAM
    warmup_epochs=5,        # Short warmup since model is already trained
    interval=1,             # Run every sub-epoch
    min_resume_epochs=0,    # No delay
)
n_samples = phase2_config.resolve_n_samples(DEVICE)

print('=' * 80)
print(f'PHASE 2: TRAINING + DISCOVERY — {PHASE2_EPOCHS} sub-epochs')
print(f'  Samples per sub-epoch: {n_samples} (VRAM-scaled)')
print(f'  Holdout search every {PHASE2_HOLDOUT_INTERVAL} sub-epochs')
print('=' * 80)

# Load or build z-cache for Phase 2 sampling (one-time if missing)
z_cache_path = OUTPUT_DIR / 'latent_cache.pt'
if not z_cache_path.exists():
    print('Z-cache not found. Encoding all training data (one-time, ~2 min on A100)...')
    encoder.eval()
    _z_idx = data['train_indices']
    _all_z = encode_indices(encoder, data, _z_idx)
    _z_cache = {
        'z_vectors': _all_z,
        'tc_values': data['tc'][_z_idx],
        'is_sc': data['is_sc'][_z_idx],
    }
    if 'elem_idx' in data:
        _z_cache['element_indices'] = data['elem_idx'][_z_idx]
        _z_cache['element_mask'] = data['elem_mask'][_z_idx]
    torch.save(_z_cache, z_cache_path)
    print(f'  Saved z-cache: {_all_z.shape[0]} vectors to {z_cache_path.name}')
    del _all_z, _z_cache
z_cache_available = z_cache_path.exists()

if PHASE2_EPOCHS > 0:
    # Create optimizers
    encoder.train()
    decoder.train()
    enc_opt = torch.optim.AdamW(encoder.parameters(), lr=PHASE2_LR, weight_decay=1e-4)
    dec_opt = torch.optim.AdamW(decoder.parameters(), lr=PHASE2_LR, weight_decay=1e-4)

    phase2_engine = SelfSupervisedEpoch(
        config=phase2_config,
        encoder=encoder,
        decoder=decoder,
        device=DEVICE,
        v13_tokenizer=V13_TOKENIZER,
        known_formulas=training_formulas,
        holdout_formulas=holdout_formulas,
        discovery_output_path=str(OUTPUT_DIR / 'phase2_discoveries.jsonl'),
    )

    if z_cache_available:
        phase2_engine.load_z_cache(str(z_cache_path))
    else:
        print('  WARNING: No z-cache found. Phase 2 sampling will be limited.')
        print('  Run training first to generate outputs/latent_cache.pt')

    # Force activation
    phase2_engine._phase2_activation_epoch = 0
    phase2_engine._activation_exact = 0.9

    # --- Main Phase 2 loop ---
    phase2_history = []
    cumulative_holdout = 0
    cumulative_novel = 0
    holdout_search_snapshots = []

    for sub_epoch in range(PHASE2_EPOCHS):
        print(f'\n{"="*60}')
        print(f'Phase 2 sub-epoch {sub_epoch + 1}/{PHASE2_EPOCHS}')
        print(f'{"="*60}')

        metrics = phase2_engine.run(
            epoch=sub_epoch,
            current_exact=0.9,
            enc_opt=enc_opt,
            dec_opt=dec_opt,
            main_lr=PHASE2_LR,
            use_amp=torch.cuda.is_available(),
            amp_dtype=torch.bfloat16 if torch.cuda.is_available() and
                      torch.cuda.get_device_capability(0)[0] >= 8 else torch.float16,
        )

        if metrics:
            n_valid = metrics.get('phase2_n_valid', 0)
            n_novel = metrics.get('phase2_n_novel', 0)
            n_holdout = metrics.get('phase2_n_holdout_recovered', 0)
            cumulative_holdout += n_holdout
            cumulative_novel += n_novel

            print(f'  Loss: {metrics.get("phase2_total_loss", 0):.4f} '
                  f'(RT={metrics.get("phase2_loss1_rt", 0):.4f}, '
                  f'Consist={metrics.get("phase2_loss2_consist", 0):.4f}, '
                  f'Phys={metrics.get("phase2_loss3_physics", 0):.4f})')
            print(f'  Valid: {n_valid}/{metrics.get("phase2_n_sampled", 0)} '
                  f'({metrics.get("phase2_valid_rate", 0):.1%}), '
                  f'Unique: {metrics.get("phase2_n_unique_formulas", 0)}, '
                  f'Z-MSE: {metrics.get("phase2_z_mse", 0):.4f}')
            print(f'  Discoveries this epoch: novel={n_novel}, holdout={n_holdout}')
            print(f'  Cumulative: novel={cumulative_novel}, holdout={cumulative_holdout}/45')

            phase2_history.append(metrics)

        # --- Periodic mini holdout search ---
        if (sub_epoch + 1) % PHASE2_HOLDOUT_INTERVAL == 0 and sub_epoch < PHASE2_EPOCHS - 1:
            print(f'\n  --- Mini holdout search (sub-epoch {sub_epoch + 1}) ---')
            encoder.eval()
            decoder.eval()

            # Quick search on a subset (9 targets, 1 per family)
            mini_targets = []
            seen_families = set()
            for s in holdout_samples:
                if s['family'] not in seen_families:
                    seen_families.add(s['family'])
                    mini_targets.append(s)

            mini_exact = 0
            mini_best_sims = []
            for s in mini_targets:
                # Use small budget for speed
                old_budget = B.copy()
                try:
                    for k in ['n_perturbations', 'n_seeds', 'n_temp_samples', 'max_pairs']:
                        B[k] = max(5, B[k] // 5)
                    B['noise_scales'] = [0.05, 0.1, 0.2]
                    B['temps'] = [0.1, 0.3]
                    B['n_pca_comps'] = 5
                    B['pca_steps'] = 5
                    B['n_interp_steps'] = 5
                    B['n_neighbors'] = 30

                    r = search_single_target(encoder, decoder, data,
                                             s['formula'], s['Tc'], s['family'])
                    if r.get('exact'):
                        mini_exact += 1
                    mini_best_sims.append(r['best_sim'])
                finally:
                    for k, v in old_budget.items():
                        B[k] = v

            avg_sim = np.mean(mini_best_sims) if mini_best_sims else 0
            holdout_search_snapshots.append({
                'sub_epoch': sub_epoch + 1,
                'mini_exact': mini_exact,
                'avg_sim': avg_sim,
                'n_targets': len(mini_targets),
            })
            print(f'  Mini search: {mini_exact}/{len(mini_targets)} exact, avg_sim={avg_sim:.3f}')

            encoder.train()
            decoder.train()

    encoder.eval()
    decoder.eval()

    # --- Phase 2 summary ---
    print(f'\n{"="*80}')
    print(f'PHASE 2 COMPLETE — {PHASE2_EPOCHS} sub-epochs')
    print(f'{"="*80}')
    print(f'  Total holdout recoveries: {cumulative_holdout}/45')
    print(f'  Total novel discoveries: {cumulative_novel}')
    if phase2_history:
        avg_valid_rate = np.mean([m.get('phase2_valid_rate', 0) for m in phase2_history])
        avg_z_mse = np.mean([m.get('phase2_z_mse', 0) for m in phase2_history])
        print(f'  Avg valid rate: {avg_valid_rate:.1%}')
        print(f'  Avg Z-MSE: {avg_z_mse:.4f}')
    if holdout_search_snapshots:
        print(f'\n  Holdout search progress:')
        for snap in holdout_search_snapshots:
            print(f'    Sub-epoch {snap["sub_epoch"]}: '
                  f'{snap["mini_exact"]}/{snap["n_targets"]} exact, avg_sim={snap["avg_sim"]:.3f}')

    # Check phase2_discoveries.jsonl for full results
    disc_path = OUTPUT_DIR / 'phase2_discoveries.jsonl'
    if disc_path.exists():
        import json as _json
        discoveries = []
        with open(disc_path) as f:
            for line in f:
                if line.strip():
                    discoveries.append(_json.loads(line))
        n_novel_disc = sum(1 for d in discoveries if d.get('category') == 'novel')
        n_holdout_disc = sum(1 for d in discoveries if d.get('category') == 'holdout_recovery')
        print(f'\n  Discoveries file: {len(discoveries)} total entries')
        print(f'    Novel: {n_novel_disc}, Holdout recoveries: {n_holdout_disc}')
        if n_holdout_disc > 0:
            print(f'    Recovered holdout formulas:')
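            for d in discoveries:
                if d.get('category') == 'holdout_recovery':
                    # Guard the format spec: a missing Tc would otherwise raise on ':.1f'
                    tc = d.get('tc_kelvin')
                    tc_str = f'{tc:.1f}K' if isinstance(tc, (int, float)) else '?'
                    print(f'      {d["formula"]} (Tc={tc_str})')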

    print(f'\nModel has been updated in-memory. Proceeding with improved model...')
else:
    phase2_history = []
    holdout_search_snapshots = []
    cumulative_holdout = 0
    cumulative_novel = 0
    print('Phase 2 skipped (PHASE2_EPOCHS=0). Running inference-only evaluation.')

## Cell 9: Targeted Holdout Search (Post-Phase-2)

Full targeted search on all 45 holdout targets using the **improved** model (after Phase 2 training).
Uses element-anchored Z-space exploration with VRAM-scaled budgets.
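
One of the interpolation strategies inside `search_single_target` is SLERP between pairs of seed vectors. The notebook's own `slerp` is defined in an earlier cell and operates on tensors. The pure-Python version below is only a sketch of the standard formula, with the hypothetical name `slerp_sketch`, and may differ from that helper in edge-case handling.

```python
import math


def slerp_sketch(z1, z2, t, eps=1e-6):
    """Spherical linear interpolation between two vectors (lists of floats)."""
    n1 = math.sqrt(sum(a * a for a in z1))
    n2 = math.sqrt(sum(b * b for b in z2))
    cos_omega = sum(a * b for a, b in zip(z1, z2)) / max(n1 * n2, eps)
    cos_omega = max(-1.0, min(1.0, cos_omega))  # clamp for acos
    omega = math.acos(cos_omega)
    s = math.sin(omega)
    if abs(s) < eps:
        # Nearly parallel or antiparallel: fall back to linear interpolation
        return [(1 - t) * a + t * b for a, b in zip(z1, z2)]
    w1 = math.sin((1 - t) * omega) / s
    w2 = math.sin(t * omega) / s
    return [w1 * a + w2 * b for a, b in zip(z1, z2)]


# Midpoint between two orthogonal unit vectors stays on the unit circle
mid = slerp_sketch([1.0, 0.0], [0.0, 1.0], 0.5)
mid_norm = math.sqrt(sum(x * x for x in mid))
print(mid, mid_norm)
```

Unlike linear interpolation, the midpoint keeps the norm of the endpoints, so interpolated Z-vectors stay at a typical distance from the origin.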

In [None]:
# Run targeted search on all 45 holdout targets
print('=' * 80)
print(f'TARGETED HOLDOUT SEARCH — Epoch {model_epoch} (budget={SEARCH_BUDGET})')
print('Element-Anchored Z-Space Exploration')
print('=' * 80)

search_results = []
t_start = time.time()

for i, sample in enumerate(holdout_samples):
    print(f'\n--- [{i+1}/45] ---')
    result = search_single_target(
        encoder, decoder, data,
        sample['formula'], sample['Tc'], sample['family'],
    )
    search_results.append(result)

total_time = time.time() - t_start
print(f'\nTotal search time: {total_time/60:.1f} minutes')

In [None]:
# Search Results Summary
print('\n' + '=' * 80)
print(f'TARGETED SEARCH SUMMARY — Epoch {model_epoch}')
print('=' * 80)

for r in search_results:
    marker = '***' if r.get('exact') else ' + ' if r['best_sim'] >= 0.95 else '   '
    print(f"{marker} [{r['target_family']:12s}] sim={r['best_sim']:.3f} | {r['target']}")
    if r['best_gen'] and not r.get('exact'):
        print(f"     Best: {r['best_gen']}")

n_exact = sum(1 for r in search_results if r.get('exact'))
print(f'\nExact matches: {n_exact}/{len(search_results)}')

print(f'\nRESULTS BY THRESHOLD:')
for thresh in [1.0, 0.99, 0.98, 0.95, 0.90, 0.85, 0.80]:
    found = sum(1 for r in search_results if r['best_sim'] >= thresh)
    pct = found / len(search_results) * 100
    bar = '#' * int(pct / 2)
    print(f'  >= {thresh:.2f}: {found:2d}/{len(search_results)} ({pct:5.1f}%) {bar}')

print(f'\nPER-FAMILY BREAKDOWN:')
families = sorted(set(r['target_family'] for r in search_results))
for fam in families:
    fam_results = [r for r in search_results if r['target_family'] == fam]
    n_exact_fam = sum(1 for r in fam_results if r.get('exact'))
    n_095 = sum(1 for r in fam_results if r['best_sim'] >= 0.95)
    avg_sim = np.mean([r['best_sim'] for r in fam_results])
    print(f'  {fam:14s}: exact={n_exact_fam}/{len(fam_results)}, >=0.95={n_095}/{len(fam_results)}, avg_sim={avg_sim:.3f}')

---
# PART 2: Novel Superconductor Discovery (Post-Phase-2)

Broad Z-space exploration using the **improved** model (after Phase 2 training).

---
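
The broad exploration splits its sample budget across five strategies in fixed proportions (35% perturbation, 25% high-Tc region, 20% SLERP, 10% PCA walks, 10% gradient-based Tc optimization). The sketch below shows that split on a toy budget. The helper `allocate_budget` is hypothetical, and handing the rounding remainder to the first strategy is an assumption of this example; the notebook's cell simply truncates each share with `int()`.

```python
def allocate_budget(n_samples, fractions):
    """Split n_samples across strategies by fraction, truncating like int()."""
    counts = {name: int(n_samples * frac) for name, frac in fractions.items()}
    # Assumption: give any rounding remainder to the first strategy
    first = next(iter(counts))
    counts[first] += n_samples - sum(counts.values())
    return counts


# Fractions taken from the notebook's strategy allocation
fractions = {
    'perturb': 0.35,
    'high_tc': 0.25,
    'slerp': 0.20,
    'pca': 0.10,
    'gradient': 0.10,
}
counts = allocate_budget(1003, fractions)
print(counts)
```

These are target allocations only. The number of candidates each strategy actually emits can be lower, for example because the gradient strategy caps its starting points at 500.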

## Cell 10: Encode Training Data -> Z Cache

Load the pre-computed Z-cache if it exists, otherwise encode all training data. Then select the superconducting samples and the top 10% by Tc as seeds for broad exploration.

In [None]:
# Check for pre-computed z-cache first
z_cache_path = OUTPUT_DIR / 'latent_cache.pt'
if z_cache_path.exists():
    print(f'Loading pre-computed Z-cache from {z_cache_path.name}...')
    z_cache = torch.load(z_cache_path, map_location='cpu', weights_only=True)
    all_z = z_cache['z_vectors']
    all_tc = z_cache.get('tc_values', data['tc'])
    all_is_sc = z_cache.get('is_sc', data['is_sc'])
    print(f'  Z-cache: {all_z.shape[0]} vectors, {all_z.shape[1]} dims, '
          f'norm={all_z.norm(dim=-1).mean():.2f}')
else:
    print('No Z-cache found. Encoding all training data (takes ~2 min on A100)...')
    all_z = encode_indices(encoder, data, data['train_indices'])
    all_tc = data['tc'][data['train_indices']]
    all_is_sc = data['is_sc'][data['train_indices']]
    print(f'  Encoded {all_z.shape[0]} samples, Z norm={all_z.norm(dim=-1).mean():.2f}')

# SC-only z-vectors for high-Tc exploration
sc_mask = all_is_sc.bool()
z_sc = all_z[sc_mask]
tc_sc = all_tc[sc_mask]
print(f'  SC samples: {z_sc.shape[0]}, mean Tc_norm={tc_sc.mean():.3f}')

# Identify high-Tc region (top 10%)
tc_threshold = torch.quantile(tc_sc, 0.9).item()
high_tc_mask = tc_sc >= tc_threshold
z_high_tc = z_sc[high_tc_mask]
print(f'  High-Tc (top 10%): {z_high_tc.shape[0]} samples, '
      f'Tc_norm >= {tc_threshold:.3f} ({denormalize_tc(tc_threshold):.1f} K)')

## Cell 11: Broad Z-Space Exploration for Novel Materials

In [None]:
print('=' * 80)
print(f'NOVEL SUPERCONDUCTOR DISCOVERY — {NOVEL_N_SAMPLES:,} Z-space samples')
print('=' * 80)

z_norm_mean = z_sc.norm(dim=-1).mean().item()
z_norm_std = z_sc.norm(dim=-1).std().item()
z_mean = z_sc.mean(dim=0)
z_std = z_sc.std(dim=0).clamp(min=1e-6)

novel_z_candidates = []

# Allocate budget across strategies
n_perturb = int(NOVEL_N_SAMPLES * 0.35)   # 35% perturbation around SC samples
n_high_tc = int(NOVEL_N_SAMPLES * 0.25)    # 25% around high-Tc region
n_slerp = int(NOVEL_N_SAMPLES * 0.20)      # 20% SLERP between SC pairs
n_pca = int(NOVEL_N_SAMPLES * 0.10)         # 10% PCA-directed walks
n_gradient = int(NOVEL_N_SAMPLES * 0.10)    # 10% gradient-based Tc optimization

print(f'\nStrategy allocation:')
print(f'  Perturbation (SC):  {n_perturb:,}')
print(f'  High-Tc region:     {n_high_tc:,}')
print(f'  SLERP interpolation: {n_slerp:,}')
print(f'  PCA walks:          {n_pca:,}')
print(f'  Gradient Tc optim:  {n_gradient:,}')

# Strategy 1: Perturbation around random SC training samples
print('\n[1/5] Perturbation around SC samples...')
seed_indices = torch.randperm(len(z_sc))[:min(500, len(z_sc))]
samples_per_seed = max(1, n_perturb // len(seed_indices))
for idx in seed_indices:
    z_seed = z_sc[idx].unsqueeze(0)
    noise_scale = np.random.choice([0.05, 0.1, 0.15, 0.2, 0.3])
    noise = torch.randn(samples_per_seed, z_seed.shape[1]) * noise_scale
    novel_z_candidates.append(z_seed + noise)
print(f'  Generated {sum(c.shape[0] for c in novel_z_candidates)} perturbation candidates')

# Strategy 2: Dense exploration around high-Tc region
print('[2/5] High-Tc region exploration...')
high_tc_centroid = z_high_tc.mean(dim=0, keepdim=True)
high_tc_std = z_high_tc.std(dim=0, keepdim=True).clamp(min=1e-6)
for scale in [0.3, 0.5, 0.8, 1.0, 1.5, 2.0]:
    n_this = n_high_tc // 6
    noise = torch.randn(n_this, z_sc.shape[1]) * scale
    novel_z_candidates.append(high_tc_centroid + high_tc_std * noise)
print(f'  Cumulative: {sum(c.shape[0] for c in novel_z_candidates)} candidates')

# Strategy 3: SLERP between random SC pairs
print('[3/5] SLERP interpolation between SC pairs...')
# 10 interpolation points per pair; cap n_pairs so randperm can supply two distinct indices per pair
n_pairs = min(n_slerp // 10, len(z_sc) // 2)
pair_idx = torch.randperm(len(z_sc))[:2 * n_pairs].reshape(n_pairs, 2)
for i in range(n_pairs):
    z1 = z_sc[pair_idx[i, 0]].unsqueeze(0)
    z2 = z_sc[pair_idx[i, 1]].unsqueeze(0)
    for t in np.linspace(0.1, 0.9, 10):
        novel_z_candidates.append(slerp(z1, z2, float(t)))
print(f'  Cumulative: {sum(c.shape[0] for c in novel_z_candidates)} candidates')

# Strategy 4: PCA-directed walks
print('[4/5] PCA-directed walks...')
z_sc_np = z_sc.numpy()
sc_mean = z_sc_np.mean(axis=0)
sc_centered = z_sc_np - sc_mean
U, S, Vt = np.linalg.svd(sc_centered, full_matrices=False)
n_comps = min(30, len(S))
steps_per_comp = max(1, n_pca // (n_comps * 10))
for c in range(n_comps):
    direction = torch.from_numpy(Vt[c]).float()
    std_along = S[c] / np.sqrt(len(z_sc) - 1)
    for alpha in np.linspace(-3.0, 3.0, steps_per_comp):
        z_walked = torch.from_numpy(sc_mean).float().unsqueeze(0) + alpha * std_along * direction.unsqueeze(0)
        novel_z_candidates.append(z_walked)
print(f'  Cumulative: {sum(c.shape[0] for c in novel_z_candidates)} candidates')

# Strategy 5: Gradient-based Tc optimization
print('[5/5] Gradient-based Tc optimization...')
n_grad_starts = min(n_gradient, 500)
grad_start_idx = torch.randperm(len(z_high_tc))[:n_grad_starts]
grad_z_list = []
encoder.eval()  # Keep encoder in eval mode
for idx in grad_start_idx:
    z_opt = z_high_tc[idx].clone().unsqueeze(0).to(DEVICE).requires_grad_(True)
    optimizer = torch.optim.Adam([z_opt], lr=0.1)
    for step in range(20):  # 20 gradient steps per starting point
        optimizer.zero_grad()
        tc_pred = encoder.tc_head(z_opt).squeeze()
        # Maximize Tc while staying near training distribution
        z_reg = ((z_opt.norm(dim=-1) - z_norm_mean) ** 2).mean() * 0.01
        loss = -tc_pred + z_reg  # Negative because we want to maximize Tc
        loss.backward()
        optimizer.step()
    grad_z_list.append(z_opt.detach().cpu())
if grad_z_list:
    novel_z_candidates.append(torch.cat(grad_z_list, dim=0))
print(f'  Cumulative: {sum(c.shape[0] for c in novel_z_candidates)} candidates')

all_novel_z = torch.cat(novel_z_candidates, dim=0)
print(f'\nTotal novel Z candidates: {len(all_novel_z):,}')

## Cell 12: Decode & Validate Novel Candidates

In [None]:
print('Decoding novel candidates...')
t0 = time.time()
novel_results = decode_z_with_tc(encoder, decoder, all_novel_z, temperature=0.01)
print(f'  Decoded {len(novel_results):,} candidates in {time.time()-t0:.1f}s')

# Deduplicate and classify
formula_best = {}  # formula -> best result (highest Tc)
for r in novel_results:
    f = r['formula'].strip()
    if not f or len(f) < 2:
        continue
    if f not in formula_best or r['tc_kelvin'] > formula_best[f]['tc_kelvin']:
        formula_best[f] = r

print(f'  Unique formulas: {len(formula_best):,}')

# Classify: known (training), holdout recovery, or novel
known_count = 0
holdout_recovered = []
novel_candidates = []

for formula, result in formula_best.items():
    if formula in training_formulas:
        known_count += 1
    elif formula in holdout_formulas:
        holdout_recovered.append(result)
    else:
        # Check if it parses as a valid chemical formula
        parsed = parse_formula_elements(formula)
        if parsed and len(parsed) >= 2:  # At least 2 elements
            result['parsed_elements'] = parsed
            result['n_elements'] = len(parsed)
            novel_candidates.append(result)

print(f'\nClassification:')
print(f'  Known (training): {known_count}')
print(f'  Holdout recovered: {len(holdout_recovered)}')
print(f'  Novel (valid parse, 2+ elements): {len(novel_candidates)}')
print(f'  Invalid/single-element: {len(formula_best) - known_count - len(holdout_recovered) - len(novel_candidates)}')

if holdout_recovered:
    print(f'\nHoldout Recoveries (found during broad exploration!):')
    for r in holdout_recovered:
        print(f'  {r["formula"]} (pred Tc={r["tc_kelvin"]:.1f}K, SC prob={r["sc_prob"]:.3f})')

In [None]:
# Filter and rank novel candidates
# Keep those predicted as superconductors with Tc above threshold
sc_novels = [r for r in novel_candidates if r['sc_prob'] > 0.5 and r['tc_kelvin'] >= NOVEL_MIN_TC]
sc_novels.sort(key=lambda x: -x['tc_kelvin'])

print('=' * 80)
print(f'TOP NOVEL SUPERCONDUCTOR CANDIDATES (SC prob > 0.5, Tc >= {NOVEL_MIN_TC}K)')
print('=' * 80)
print(f'{"Rank":>4s} {"Formula":<35s} {"Tc (K)":>8s} {"SC prob":>8s} {"Family":>16s} {"Elements":>10s}')
print('-' * 90)

for i, r in enumerate(sc_novels[:50]):
    fam_name = FAMILY_14_NAMES[r['family']] if 0 <= r['family'] < len(FAMILY_14_NAMES) else '?'
    elems = ', '.join(sorted(r.get('parsed_elements', {}).keys()))
    print(f'{i+1:4d}  {r["formula"]:<35s} {r["tc_kelvin"]:8.1f} {r["sc_prob"]:8.3f} {fam_name:>16s} {elems}')

print(f'\nTotal novel SC candidates: {len(sc_novels)}')

# Tc distribution of novel candidates
if sc_novels:
    tc_vals = [r['tc_kelvin'] for r in sc_novels]
    print(f'\nNovel Tc distribution:')
    for lo, hi, label in [(0,10,'0-10K'), (10,30,'10-30K'), (30,77,'30-77K'), (77,120,'77-120K'), (120,200,'120-200K'), (200,9999,'>200K')]:
        n = sum(1 for t in tc_vals if lo <= t < hi)
        if n > 0:
            print(f'  {label:>10s}: {n:5d}')

---
# Results & Visualization
---

## Cell 14: Visualization

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(16, 12))

fam_colors = {
    'YBCO': '#e74c3c', 'LSCO': '#3498db', 'Hg-cuprate': '#2ecc71',
    'Tl-cuprate': '#9b59b6', 'Bi-cuprate': '#f39c12', 'Iron-based': '#1abc9c',
    'MgB2': '#e67e22', 'Conventional': '#95a5a6', 'Other': '#34495e',
}

# Plot 1: Similarity distribution (holdout search)
sims = [r['best_sim'] for r in search_results]
colors = ['#2ecc71' if s >= 0.95 else '#f39c12' if s >= 0.80 else '#e74c3c' for s in sims]
axes[0,0].barh(range(len(sims)), sims, color=colors)
axes[0,0].set_xlabel('Best Similarity')
axes[0,0].set_ylabel('Holdout Target Index')
axes[0,0].set_title('Part 1: Holdout Recovery — Best Match per Target')
axes[0,0].axvline(x=0.95, color='green', linestyle='--', alpha=0.5, label='0.95')
axes[0,0].axvline(x=0.80, color='orange', linestyle='--', alpha=0.5, label='0.80')
axes[0,0].legend()
axes[0,0].set_xlim(0, 1.05)

# Plot 2: Tc prediction error (roundtrip)
for sample, err in zip(holdout_samples, tc_errors):
    axes[0,1].scatter(sample['Tc'], err, color=fam_colors.get(sample['family'], 'gray'),
                     s=50, alpha=0.7, label=sample['family'])
handles, labels = axes[0,1].get_legend_handles_labels()
by_label = dict(zip(labels, handles))
axes[0,1].legend(by_label.values(), by_label.keys(), fontsize=7, loc='upper left')
axes[0,1].set_xlabel('True Tc (K)')
axes[0,1].set_ylabel('Absolute Error (K)')
axes[0,1].set_title('Tc Prediction Error vs True Tc')
axes[0,1].axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)

# Plot 3: Per-family average similarity
fam_avg = {}
for r in search_results:
    fam = r['target_family']
    fam_avg.setdefault(fam, []).append(r['best_sim'])
fam_names = sorted(fam_avg.keys())
fam_means = [np.mean(fam_avg[f]) for f in fam_names]
axes[1,0].barh(fam_names, fam_means, color=[fam_colors.get(f, 'gray') for f in fam_names])
axes[1,0].set_xlabel('Average Best Similarity')
axes[1,0].set_title('Average Similarity by Family')
axes[1,0].set_xlim(0, 1.05)
axes[1,0].axvline(x=0.95, color='green', linestyle='--', alpha=0.5)

# Plot 4: Novel candidate Tc distribution
if sc_novels:
    novel_tcs = [r['tc_kelvin'] for r in sc_novels]
    axes[1,1].hist(novel_tcs, bins=30, color='#3498db', edgecolor='white', alpha=0.8)
    axes[1,1].axvline(x=77, color='red', linestyle='--', alpha=0.7, label='LN2 (77K)')
    axes[1,1].axvline(x=120, color='orange', linestyle='--', alpha=0.7, label='120K')
    axes[1,1].set_xlabel('Predicted Tc (K)')
    axes[1,1].set_ylabel('Count')
    axes[1,1].set_title(f'Part 2: Novel Candidates — Tc Distribution (n={len(sc_novels)})')
    axes[1,1].legend()
else:
    axes[1,1].text(0.5, 0.5, 'No novel SC candidates found', ha='center', va='center', fontsize=14)
    axes[1,1].set_title('Part 2: Novel Candidates')

plt.tight_layout()
fig_path = str(OUTPUT_DIR / 'generative_evaluation.png')
plt.savefig(fig_path, dpi=150, bbox_inches='tight')
plt.show()
print(f'Saved: {fig_path}')

## Cell 15: Grade Card

In [None]:
tc_mae = np.mean(tc_errors)
tc_within_1k = (np.array(tc_errors) < 1).sum()
n_exact_search = sum(1 for r in search_results if r.get('exact'))
n_095 = sum(1 for r in search_results if r['best_sim'] >= 0.95)
n_099 = sum(1 for r in search_results if r['best_sim'] >= 0.99)
avg_best_sim = np.mean([r['best_sim'] for r in search_results])

def grade(val, thresholds):
    grades = ['A+', 'A', 'B+', 'B', 'C', 'D', 'F']
    for i, t in enumerate(thresholds):
        if val >= t:
            return grades[i]
    return grades[-1]

print('\n' + '=' * 80)
print(f'GENERATIVE EVALUATION GRADE CARD — Epoch {model_epoch}')
print('=' * 80)

metrics_list = [
    ('--- Part 1: Holdout Recovery ---', '', ''),
    ('SC Classification', f'{sc_correct}/45 ({sc_correct/45*100:.0f}%)',
     grade(sc_correct/45, [0.98, 0.95, 0.90, 0.80, 0.70, 0.60])),
    ('Tc Prediction MAE', f'{tc_mae:.2f}K',
     grade(1/(1+tc_mae), [0.67, 0.50, 0.33, 0.20, 0.10, 0.05])),
    ('Tc within 1K', f'{tc_within_1k}/45 ({tc_within_1k/45*100:.0f}%)',
     grade(tc_within_1k/45, [0.90, 0.80, 0.70, 0.60, 0.50, 0.30])),
    ('Family Classification', f'{family_correct}/45 ({family_correct/45*100:.0f}%)',
     grade(family_correct/45, [0.95, 0.90, 0.80, 0.70, 0.60, 0.50])),
    ('Formula Exact (search)', f'{n_exact_search}/45 ({n_exact_search/45*100:.0f}%)',
     grade(n_exact_search/45, [0.80, 0.60, 0.40, 0.25, 0.15, 0.05])),
    ('Formula >= 0.95 sim', f'{n_095}/45 ({n_095/45*100:.0f}%)',
     grade(n_095/45, [0.95, 0.85, 0.70, 0.55, 0.40, 0.25])),
    ('Average Best Similarity', f'{avg_best_sim:.3f}',
     grade(avg_best_sim, [0.95, 0.90, 0.85, 0.80, 0.70, 0.60])),
    ('Magpie MSE', f'{np.mean(magpie_mses):.4f}',
     grade(1/(1+np.mean(magpie_mses)*10), [0.90, 0.80, 0.60, 0.40, 0.25, 0.10])),
    ('--- Part 2: Novel Discovery ---', '', ''),
    ('Novel SC candidates', f'{len(sc_novels):,}', ''),
    ('Holdout recovered (broad)', f'{len(holdout_recovered)}/45', ''),
    ('Novel Tc > 77K (LN2)', f'{sum(1 for r in sc_novels if r["tc_kelvin"] > 77):,}', ''),
    ('Novel Tc > 120K', f'{sum(1 for r in sc_novels if r["tc_kelvin"] > 120):,}', ''),
]

print(f'\n{"Metric":<28s} {"Result":<28s} {"Grade":>5s}')
print('-' * 65)
for name, result, g in metrics_list:
    if name.startswith('---'):
        print(f'\n{name}')
    else:
        print(f'  {name:<26s} {result:<28s} {g:>5s}')

print(f'\n{"="*65}')
print(
    f'The model {"CAN" if n_exact_search >= 10 else "CANNOT yet"} reliably generate '
    f'unseen superconductors.\n'
    f'  - {n_exact_search}/45 exact matches from Z-space exploration\n'
    f'  - {n_095}/45 with >= 0.95 compositional similarity\n'
    f'  - Tc prediction: {tc_mae:.2f}K MAE on held-out materials\n'
    f'  - {len(sc_novels):,} novel SC candidates generated ({sum(1 for r in sc_novels if r["tc_kelvin"] > 77)} above LN2 temp)\n'
)

## Cell 16: Save Results

In [None]:
output = {
    'timestamp': time.strftime('%Y-%m-%dT%H:%M:%S'),
    'checkpoint': Path(CHECKPOINT_PATH).name,
    'epoch': model_epoch,
    'search_budget': SEARCH_BUDGET,
    'vram_gb': round(VRAM_GB, 1),
    'part1_roundtrip': {
        'tc_mae_K': float(tc_mae),
        'tc_within_1K': int(tc_within_1k),
        'sc_accuracy': sc_correct / 45,
        'family_accuracy': family_correct / 45,
        'magpie_mse': float(np.mean(magpie_mses)),
        'formula_mean_sim': float(np.mean(formula_sims)),
        'details': roundtrip_results,
    },
    'part1_search': {
        'n_exact': n_exact_search,
        'n_095': n_095,
        'n_099': n_099,
        'avg_best_sim': float(avg_best_sim),
        'details': [{k: v for k, v in r.items() if k != 'all_candidates'} for r in search_results],
    },
    'part2_novel': {
        'n_total_sampled': len(all_novel_z),
        'n_unique_formulas': len(formula_best),
        'n_known': known_count,
        'n_holdout_recovered': len(holdout_recovered),
        'n_novel_sc': len(sc_novels),
        'n_novel_above_77K': sum(1 for r in sc_novels if r['tc_kelvin'] > 77),
        'n_novel_above_120K': sum(1 for r in sc_novels if r['tc_kelvin'] > 120),
        'top_50': [{'formula': r['formula'], 'tc_kelvin': round(r['tc_kelvin'], 2),
                     'sc_prob': round(r['sc_prob'], 4),
                     'family': FAMILY_14_NAMES[r['family']] if 0 <= r['family'] < len(FAMILY_14_NAMES) else '?',
                     'elements': sorted(r.get('parsed_elements', {}).keys())}
                    for r in sc_novels[:50]],
        'holdout_recovered': [{'formula': r['formula'], 'tc_kelvin': round(r['tc_kelvin'], 2)}
                              for r in holdout_recovered],
    },
}

output_path = str(OUTPUT_DIR / f'generative_evaluation_epoch_{model_epoch}.json')
with open(output_path, 'w') as f:
    json.dump(output, f, indent=2, default=str)
print(f'Results saved to: {output_path}')

# Also save full novel candidates list
if sc_novels:
    novel_path = str(OUTPUT_DIR / f'novel_candidates_epoch_{model_epoch}.json')
    novel_output = [{'formula': r['formula'], 'tc_kelvin': round(r['tc_kelvin'], 2),
                      'sc_prob': round(r['sc_prob'], 4),
                      'family': FAMILY_14_NAMES[r['family']] if 0 <= r['family'] < len(FAMILY_14_NAMES) else '?',
                      'n_elements': r.get('n_elements', 0),
                      'elements': sorted(r.get('parsed_elements', {}).keys())}
                     for r in sc_novels]
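    with open(novel_path, 'w') as f:
        # default=str mirrors the dump above, so non-JSON-native values
        # (e.g. NumPy scalars) do not raise TypeError
        json.dump(novel_output, f, indent=2, default=str)
    print(f'Novel candidates saved to: {novel_path} ({len(sc_novels):,} candidates)')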
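
Once the candidates file is written, it can be reloaded for follow-up filtering without rerunning the notebook. The sketch below assumes the record schema written by this cell (`formula`, `tc_kelvin`, `sc_prob`, `family`, `n_elements`, `elements`). The function name `load_candidates_above`, the placeholder formulas, and the demo file are hypothetical and stand in for a real `novel_candidates_epoch_{model_epoch}.json`.

```python
import json
import tempfile
from pathlib import Path


def load_candidates_above(path, min_tc_kelvin):
    """Load a candidates JSON file and keep records with Tc above the cutoff,
    sorted from highest to lowest predicted Tc."""
    with open(path) as f:
        records = json.load(f)
    hits = [r for r in records if r['tc_kelvin'] > min_tc_kelvin]
    return sorted(hits, key=lambda r: r['tc_kelvin'], reverse=True)


# Placeholder records in the assumed schema (not real model output).
sample = [
    {'formula': 'X1', 'tc_kelvin': 35.2, 'sc_prob': 0.91,
     'family': 'Other', 'n_elements': 2, 'elements': ['A', 'B']},
    {'formula': 'X2', 'tc_kelvin': 91.0, 'sc_prob': 0.88,
     'family': 'Other', 'n_elements': 3, 'elements': ['A', 'B', 'C']},
    {'formula': 'X3', 'tc_kelvin': 128.5, 'sc_prob': 0.75,
     'family': 'Other', 'n_elements': 4, 'elements': ['A', 'B', 'C', 'D']},
]

with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / 'novel_candidates_demo.json'
    demo_path.write_text(json.dumps(sample, indent=2))
    above_ln2 = load_candidates_above(demo_path, 77)

print([r['formula'] for r in above_ln2])
```

In the notebook itself, the same call would take `novel_path` in place of the demo file, with 77 or 120 as the cutoff to match the thresholds reported in the summary.
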