# Superconductor VAE — Generative Evaluation

**Goal**: Evaluate whether the trained VAE is a good *generator* of superconductor formulas.

This notebook runs the following evaluations against the **45 held-out superconductors** (5 per family, never seen during training):

1. **Roundtrip Validation**: Encode each holdout → Z → decode all heads (formula, Tc, Magpie, SC class, family)
2. **Targeted Holdout Search**: Element-anchored Z-space exploration with PCA walks, perturbation, interpolation, and temperature sampling (~27K candidates per target)
3. **Self-Consistency Check**: Verify all model heads agree (SC↔Tc, SC↔Family, Tc↔Bucket)

**Families**: YBCO, LSCO, Hg-cuprate, Tl-cuprate, Bi-cuprate, Iron-based, MgB2, Conventional, Other

## 0. Setup

In [None]:
# Clone the repo (or update an existing clone) and install dependencies.
import os

# Directory the repository is cloned into; later cells build paths from it.
REPO_DIR = '/content/superconductor-vae'
if not os.path.exists(REPO_DIR):
    # First run: no clone yet, so fetch one into REPO_DIR.
    !git clone https://github.com/jamesconde/superconductor-vae.git {REPO_DIR}
else:
    # Re-run: the directory already exists, so just pull the latest changes.
    !cd {REPO_DIR} && git pull

# Make the repo the working directory for the rest of the notebook.
os.chdir(REPO_DIR)
!pip install -q scipy matminer pymatgen scikit-learn

In [None]:
# Make checkpoint_best.pt available at CHECKPOINT_PATH.
# Option A: Upload directly (runs only when the file is not already present)
# Option B: Copy from Google Drive (uncomment below)

import os
# Every later cell reads the checkpoint from this single location.
CHECKPOINT_PATH = os.path.join(REPO_DIR, 'outputs', 'checkpoint_best.pt')

# --- Option A: Direct upload ---
if not os.path.exists(CHECKPOINT_PATH):
    from google.colab import files
    print("Upload checkpoint_best.pt:")
    uploaded = files.upload()
    # NOTE: every uploaded file is written to the same CHECKPOINT_PATH, so if
    # more than one file is selected the last one written wins.
    for name, data in uploaded.items():
        # Create the outputs/ directory on demand before writing.
        os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)
        with open(CHECKPOINT_PATH, 'wb') as f:
            f.write(data)
        print(f"Saved {name} → {CHECKPOINT_PATH} ({len(data)/1e6:.1f} MB)")

# --- Option B: Google Drive (uncomment) ---
# from google.colab import drive
# drive.mount('/content/drive')
# !cp /content/drive/MyDrive/path/to/checkpoint_best.pt {CHECKPOINT_PATH}

# Report where the checkpoint is expected and whether it is actually there.
print(f"Checkpoint: {CHECKPOINT_PATH}")
print(f"Exists: {os.path.exists(CHECKPOINT_PATH)}")
if os.path.exists(CHECKPOINT_PATH):
    print(f"Size: {os.path.getsize(CHECKPOINT_PATH)/1e6:.1f} MB")

In [None]:
import sys
import json
import time
from pathlib import Path
from collections import defaultdict

import torch
import torch.nn.functional as F
import numpy as np

# Add source to path
sys.path.insert(0, os.path.join(REPO_DIR, 'src'))

from superconductor.models.attention_vae import FullMaterialsVAE
from superconductor.models.autoregressive_decoder import (
    EnhancedTransformerDecoder, IDX_TO_TOKEN, TOKEN_TO_IDX,
    PAD_IDX, START_IDX, END_IDX,
)
from superconductor.data.canonical_ordering import CanonicalOrderer, ElementWithFraction

# Filesystem locations, all resolved relative to the cloned repository.
PROJECT_ROOT = Path(REPO_DIR)
# JSON file listing the held-out materials evaluated by this notebook.
HOLDOUT_PATH = PROJECT_ROOT / 'data' / 'GENERATIVE_HOLDOUT_DO_NOT_TRAIN.json'
# Directory holding the cached *.pt tensors and cache_meta.json read by load_data.
CACHE_DIR = PROJECT_ROOT / 'data' / 'processed' / 'cache'
# Use the GPU when one is visible, otherwise fall back to CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Device: {DEVICE}")
if torch.cuda.is_available():
    # Report the first CUDA device's name and total memory in GB.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Constants

# Statistics used by denormalize_tc to undo Tc normalization: a normalized
# value is mapped back with `tc_norm * TC_STD + TC_MEAN` and then `expm1`,
# i.e. these describe Tc in log1p space.
TC_MEAN = 2.725219433789196
TC_STD = 1.3527019896187407

# Index -> name lookup for the 14-way family prediction ('family_pred_14').
FAMILY_14_NAMES = [
    'NOT_SC', 'BCS_CONVENTIONAL', 'CUPRATE_YBCO', 'CUPRATE_LSCO',
    'CUPRATE_BSCCO', 'CUPRATE_TBCCO', 'CUPRATE_HBCCO', 'CUPRATE_OTHER',
    'IRON_PNICTIDE', 'IRON_CHALCOGENIDE', 'MGB2_TYPE',
    'HEAVY_FERMION', 'ORGANIC', 'OTHER_UNKNOWN',
]

# Display names for the Tc bucket classes.
# NOTE(review): not referenced by any cell visible in this notebook.
TC_BUCKET_NAMES = ['non-SC (0K)', 'low (0-10K)', 'medium (10-50K)',
                   'high (50-100K)', 'very-high (100K+)']

# Holdout family label -> index into FAMILY_14_NAMES, used to score the
# family prediction. Note that 'Iron-based' maps only to IRON_PNICTIDE (8),
# so a prediction of IRON_CHALCOGENIDE (9) is counted as a mismatch.
HOLDOUT_FAMILY_TO_14 = {
    'YBCO': 2, 'LSCO': 3, 'Hg-cuprate': 6, 'Tl-cuprate': 5,
    'Bi-cuprate': 4, 'Iron-based': 8, 'MgB2': 10,
    'Conventional': 1, 'Other': 13,
}

# Element symbol -> atomic number, covering H (1) through U (92).
ELEMENT_TO_Z = {
    'H': 1, 'He': 2, 'Li': 3, 'Be': 4, 'B': 5, 'C': 6, 'N': 7, 'O': 8,
    'F': 9, 'Ne': 10, 'Na': 11, 'Mg': 12, 'Al': 13, 'Si': 14, 'P': 15,
    'S': 16, 'Cl': 17, 'Ar': 18, 'K': 19, 'Ca': 20, 'Sc': 21, 'Ti': 22,
    'V': 23, 'Cr': 24, 'Mn': 25, 'Fe': 26, 'Co': 27, 'Ni': 28, 'Cu': 29,
    'Zn': 30, 'Ga': 31, 'Ge': 32, 'As': 33, 'Se': 34, 'Br': 35, 'Kr': 36,
    'Rb': 37, 'Sr': 38, 'Y': 39, 'Zr': 40, 'Nb': 41, 'Mo': 42, 'Tc': 43,
    'Ru': 44, 'Rh': 45, 'Pd': 46, 'Ag': 47, 'Cd': 48, 'In': 49, 'Sn': 50,
    'Sb': 51, 'Te': 52, 'I': 53, 'Xe': 54, 'Cs': 55, 'Ba': 56, 'La': 57,
    'Ce': 58, 'Pr': 59, 'Nd': 60, 'Pm': 61, 'Sm': 62, 'Eu': 63, 'Gd': 64,
    'Tb': 65, 'Dy': 66, 'Ho': 67, 'Er': 68, 'Tm': 69, 'Yb': 70, 'Lu': 71,
    'Hf': 72, 'Ta': 73, 'W': 74, 'Re': 75, 'Os': 76, 'Ir': 77, 'Pt': 78,
    'Au': 79, 'Hg': 80, 'Tl': 81, 'Pb': 82, 'Bi': 83, 'Po': 84, 'At': 85,
    'Rn': 86, 'Fr': 87, 'Ra': 88, 'Ac': 89, 'Th': 90, 'Pa': 91, 'U': 92,
}

# Shared formula parser used by parse_formula_elements.
_CANONICALIZER = CanonicalOrderer()

## 1. Load Model & Data

In [None]:
def tokens_to_formula(token_ids):
    """Render a sequence of token IDs as a formula string.

    PAD and START tokens are dropped, decoding stops at the first END token,
    and any ID missing from ``IDX_TO_TOKEN`` is rendered as ``'?'``.
    """
    ignorable = (PAD_IDX, START_IDX)
    pieces = []
    for raw_id in token_ids:
        token_id = int(raw_id)
        if token_id in ignorable:
            continue
        if token_id == END_IDX:
            break
        pieces.append(IDX_TO_TOKEN.get(token_id, '?'))
    return ''.join(pieces)


def denormalize_tc(tc_norm):
    """Map a normalized Tc prediction back to Kelvin.

    Undoes the standardization with ``TC_STD``/``TC_MEAN``, inverts the
    log1p transform with ``expm1``, and floors the result at 0.0.
    """
    log_tc = TC_MEAN + TC_STD * tc_norm
    kelvin = np.expm1(log_tc)
    return max(0.0, kelvin)


def parse_formula_elements(formula):
    """Return ``{element: summed fraction_value}`` for a formula string.

    Repeated elements have their fractions added together. An empty dict is
    returned when the parser yields nothing or raises (best-effort parsing).
    """
    try:
        parsed = _CANONICALIZER.parse_formula(formula)
        if not parsed:
            return {}
        totals = {}
        for item in parsed:
            amount = item.fraction_value
            totals[item.element] = totals.get(item.element, 0) + amount
        return totals
    except Exception:
        # Deliberately best-effort: an unparseable formula counts as empty.
        return {}


def element_similarity(formula_a, formula_b):
    """Score compositional similarity of two formulas in [0, 1].

    The score is the average of (a) the Jaccard index of the two element
    sets and (b) the overlap of the normalized fractions on shared elements
    (sum over shared elements of the smaller of the two fractions).
    Returns 0.0 if either formula cannot be parsed.
    """
    comp_a = parse_formula_elements(formula_a)
    comp_b = parse_formula_elements(formula_b)
    if not (comp_a and comp_b):
        return 0.0

    names_a = set(comp_a.keys())
    names_b = set(comp_b.keys())
    union = names_a | names_b
    common = names_a & names_b
    if not union:
        return 0.0
    jaccard = len(common) / len(union)

    frac_sim = 0.0
    if common:
        # Normalize each composition by its total so fractions are comparable.
        norm_a = max(sum(comp_a.values()), 1e-8)
        norm_b = max(sum(comp_b.values()), 1e-8)
        for name in common:
            frac_sim += min(comp_a[name] / norm_a, comp_b[name] / norm_b)

    return 0.5 * jaccard + 0.5 * frac_sim


def element_overlap_score(target_elements, cache_elem_idx):
    """Rate a candidate by the elements it shares with the target.

    ``target_elements`` is a set of atomic numbers; ``cache_elem_idx`` is an
    iterable of atomic numbers in which non-positive entries are ignored.
    Returns ``(n_shared, jaccard)``, or ``(0, 0.0)`` when both sets are empty.
    """
    present = {int(z) for z in cache_elem_idx if z > 0}
    union_size = len(target_elements | present)
    if union_size == 0:
        return (0, 0.0)
    n_common = len(target_elements & present)
    return (n_common, n_common / union_size)


def slerp(z1, z2, t):
    """Spherically interpolate between ``z1`` and ``z2`` at fraction ``t``.

    Directions are interpolated along the arc between the normalized vectors
    (over the last dimension) while the magnitude is interpolated linearly.
    Falls back to plain linear interpolation when the sine of the angle is
    too small to divide by safely.
    """
    dir_a = F.normalize(z1, dim=-1)
    dir_b = F.normalize(z2, dim=-1)

    cos_angle = torch.clamp((dir_a * dir_b).sum(dim=-1, keepdim=True), -1.0, 1.0)
    angle = torch.acos(cos_angle).clamp(min=1e-6)
    denom = torch.sin(angle)
    if denom.abs().min() < 1e-6:
        # Nearly (anti-)parallel inputs: the slerp weights are ill-conditioned.
        return (1 - t) * z1 + t * z2

    weight_a = torch.sin((1 - t) * angle) / denom
    weight_b = torch.sin(t * angle) / denom
    len_a = z1.norm(dim=-1, keepdim=True)
    len_b = z2.norm(dim=-1, keepdim=True)
    radius = (1 - t) * len_a + t * len_b
    return (weight_a * dir_a + weight_b * dir_b) * radius

In [None]:
def load_models(checkpoint_path):
    """Load encoder and decoder from checkpoint.

    Several hyperparameters are not hard-coded but recovered from the
    checkpoint: either from top-level checkpoint keys or, when those are
    absent, from the shapes of the saved weight tensors.

    Args:
        checkpoint_path: Path to a checkpoint holding 'encoder_state_dict'
            and 'decoder_state_dict'.

    Returns:
        Tuple ``(encoder, decoder, magpie_dim, has_numden_head, epoch)`` with
        both models moved to DEVICE and switched to eval mode. ``epoch`` is
        ``'?'`` when the checkpoint does not record it.
    """
    print(f"Loading checkpoint: {checkpoint_path}")
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code. Only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)

    enc_state_raw = checkpoint.get('encoder_state_dict', {})
    # Default Magpie width; overridden by the input width (shape[1]) of the
    # first 2-D 'magpie_encoder' weight found in the checkpoint.
    magpie_dim = 145
    for k, v in enc_state_raw.items():
        if 'magpie_encoder' in k and k.endswith('.weight') and v.dim() == 2:
            magpie_dim = v.shape[1]
            break

    # Strip the '_orig_mod.' prefix from parameter names so they match the
    # freshly built module (presumably added by torch.compile — confirm).
    enc_state = {k.replace('_orig_mod.', ''): v for k, v in enc_state_raw.items()}

    # Detect numden_head architecture
    has_numden_head = any('numden_head.' in k for k in enc_state)
    numden_first_key = 'numden_head.0.weight'
    old_numden_arch = False
    if numden_first_key in enc_state:
        # A first layer with 128 output units marks the older head layout.
        if enc_state[numden_first_key].shape[0] == 128:
            old_numden_arch = True
            print(f"  Detected OLD numden_head architecture (128-dim)")

    encoder = FullMaterialsVAE(
        n_elements=118, element_embed_dim=128, n_attention_heads=8,
        magpie_dim=magpie_dim, fusion_dim=256, encoder_hidden=[512, 256],
        latent_dim=2048, decoder_hidden=[256, 512], dropout=0.1
    ).to(DEVICE)

    if old_numden_arch:
        # Swap in a head matching the old 2048 -> 128 -> max_elements*2 layout
        # so the saved weights can be loaded into it.
        import torch.nn as nn
        encoder.numden_head = nn.Sequential(
            nn.Linear(2048, 128), nn.ReLU(), nn.Linear(128, encoder.max_elements * 2),
        ).to(DEVICE)

    # NOTE: strict=False silently ignores missing/unexpected keys, so a
    # mismatched checkpoint loads without error.
    encoder.load_state_dict(enc_state, strict=False)

    dec_state_raw = checkpoint.get('decoder_state_dict', {})
    dec_state = {k.replace('_orig_mod.', ''): v for k, v in dec_state_raw.items()}

    # Width of the stoichiometry conditioning input, read from the first
    # stoich_to_memory layer when present (falls back to 37).
    stoich_weight_key = 'stoich_to_memory.0.weight'
    stoich_dim = dec_state[stoich_weight_key].shape[1] if stoich_weight_key in dec_state else 37

    # Vocabulary size: explicit checkpoint entry first, else the number of
    # rows in the token embedding table.
    dec_vocab_size = checkpoint.get('tokenizer_vocab_size', None)
    if dec_vocab_size is None and 'token_embedding.weight' in dec_state:
        dec_vocab_size = dec_state['token_embedding.weight'].shape[0]

    # Model width: explicit entry, else the embedding width, else 512.
    _d_model = checkpoint.get('d_model', None)
    if _d_model is None and 'token_embedding.weight' in dec_state:
        _d_model = dec_state['token_embedding.weight'].shape[1]
    _d_model = _d_model or 512
    # Feed-forward width: explicit entry, else layer 0's linear1 output
    # size, else 2048.
    _dim_ff = checkpoint.get('dim_feedforward', None)
    if _dim_ff is None and 'transformer_decoder.layers.0.linear1.weight' in dec_state:
        _dim_ff = dec_state['transformer_decoder.layers.0.linear1.weight'].shape[0]
    _dim_ff = _dim_ff or 2048
    # Remaining hyperparameters come from the checkpoint with fixed defaults.
    _nhead = checkpoint.get('nhead', 8)
    _num_layers = checkpoint.get('num_layers', 12)
    _max_len = checkpoint.get('max_formula_len', 60)

    decoder = EnhancedTransformerDecoder(
        latent_dim=2048, d_model=_d_model, nhead=_nhead, num_layers=_num_layers,
        dim_feedforward=_dim_ff, dropout=0.1, max_len=_max_len,
        n_memory_tokens=16, encoder_skip_dim=256,
        use_skip_connection=False, use_stoich_conditioning=True,
        max_elements=12, n_stoich_tokens=4,
        vocab_size=dec_vocab_size, stoich_input_dim=stoich_dim,
    ).to(DEVICE)
    # Same caveat as the encoder: strict=False hides key mismatches.
    decoder.load_state_dict(dec_state, strict=False)

    # Inference only: switch both modules to eval mode.
    encoder.eval()
    decoder.eval()

    epoch = checkpoint.get('epoch', '?')
    print(f"  Loaded epoch {epoch}, magpie_dim={magpie_dim}, d_model={_d_model}, "
          f"dim_ff={_dim_ff}, vocab={dec_vocab_size}, numden={has_numden_head}")
    return encoder, decoder, magpie_dim, has_numden_head, epoch


def load_data(magpie_dim):
    """Load the cached dataset tensors and the train split from CACHE_DIR.

    Args:
        magpie_dim: Number of Magpie feature columns the model expects. If the
            cached Magpie tensor is wider, it is truncated to this many
            leading columns.

    Returns:
        Dict with the CPU tensors 'elem_idx', 'elem_frac', 'elem_mask', 'tc',
        'magpie', 'is_sc', 'tokens', plus 'train_indices' read from
        cache_meta.json.
    """
    def _load(filename):
        # weights_only=True restricts unpickling to plain tensor data.
        return torch.load(CACHE_DIR / filename, map_location='cpu', weights_only=True)

    data = {
        'elem_idx': _load('element_indices.pt'),
        'elem_frac': _load('element_fractions.pt'),
        'elem_mask': _load('element_mask.pt'),
        'tc': _load('tc_tensor.pt'),
        'magpie': _load('magpie_tensor.pt'),
        'is_sc': _load('is_sc_tensor.pt'),
        'tokens': _load('formula_tokens.pt'),
    }
    if data['magpie'].shape[1] > magpie_dim:
        data['magpie'] = data['magpie'][:, :magpie_dim]

    # Use a context manager so the metadata file handle is closed promptly.
    with open(CACHE_DIR / 'cache_meta.json') as f:
        meta = json.load(f)

    # NOTE(review): when the metadata has no 'train_indices', every cached row
    # is treated as training data. The holdout samples are indexed into these
    # same tensors, so that fallback would include them — confirm that
    # cache_meta.json always provides 'train_indices'.
    data['train_indices'] = meta.get('train_indices', list(range(len(data['elem_idx']))))
    print(f"  {len(data['elem_idx'])} total samples, {len(data['train_indices'])} train")
    return data

In [None]:
# Load everything: models from the checkpoint, cached tensors, holdout list.
encoder, decoder, magpie_dim, has_numden_head, model_epoch = load_models(CHECKPOINT_PATH)
# magpie_dim comes from the checkpoint so the cached features match the model.
data = load_data(magpie_dim)

with open(HOLDOUT_PATH) as f:
    holdout = json.load(f)

# Later cells read 'formula', 'Tc', 'family' and (optionally)
# 'original_index' from each entry.
holdout_samples = holdout['holdout_samples']
print(f"\n{len(holdout_samples)} holdout materials loaded")
print(f"Families: {sorted(set(s['family'] for s in holdout_samples))}")

## 2. Part 1: Roundtrip Validation

For each of the 45 holdout superconductors, encode it → Z → decode ALL heads:
- Formula (autoregressive generation)
- Tc (regression, Kelvin)
- SC classification (binary)
- Family classification (14-class)
- Magpie features (145-dim)
- High-pressure prediction

In [None]:
@torch.no_grad()
def full_forward(encoder, decoder, elem_idx, elem_frac, elem_mask, magpie, tc, temperature=0.01):
    """Encode one input and decode every available head plus the formula.

    The encoder is run on the inputs (moved to DEVICE); its latent ``z`` and
    stoichiometry predictions condition the autoregressive formula decoder.
    Only element ``[0]`` of each head output is reported, so results describe
    the first item of the batch.

    Returns:
        Dict with 'formula', 'tc_pred_kelvin', 'tc_pred_norm', 'magpie_pred',
        and — only when the encoder emits the corresponding head — 'tc_class',
        'sc_prob'/'sc_pred', 'family_pred_14', 'hp_prob'.
    """
    device_inputs = [tensor.to(DEVICE) for tensor in (elem_idx, elem_frac, elem_mask, magpie, tc)]
    heads = encoder(*device_inputs)

    latent = heads['z']
    tc_pred_norm = heads['tc_pred'][0].item()
    magpie_pred = heads['magpie_pred'][0].cpu()

    # Stoichiometry conditioning: fractions, optional numerator/denominator
    # predictions, then the element count as a trailing column.
    stoich_parts = [heads['fraction_pred']]
    count_column = heads['element_count_pred']
    numden = heads.get('numden_pred')
    if numden is not None:
        stoich_parts.append(numden)
    stoich_parts.append(count_column.unsqueeze(-1))
    stoich_pred = torch.cat(stoich_parts, dim=-1)

    generated, _log_probs, _entropy = decoder.generate_with_kv_cache(
        z=latent, stoich_pred=stoich_pred, temperature=temperature,
    )

    result = {
        'formula': tokens_to_formula(generated[0]),
        'tc_pred_kelvin': denormalize_tc(tc_pred_norm),
        'tc_pred_norm': tc_pred_norm,
        'magpie_pred': magpie_pred,
    }

    # Optional heads: each is reported only if the encoder returned it.
    bucket_logits = heads.get('tc_class_logits')
    if bucket_logits is not None:
        bucket_probs = torch.softmax(bucket_logits[0].cpu(), dim=-1)
        result['tc_class'] = bucket_probs.argmax().item()

    sc_logit = heads.get('sc_pred')
    if sc_logit is not None:
        sc_prob = torch.sigmoid(sc_logit[0]).item()
        result['sc_prob'] = sc_prob
        result['sc_pred'] = sc_prob > 0.5

    family_scores = heads.get('family_composed_14')
    if family_scores is not None:
        result['family_pred_14'] = family_scores[0].cpu().argmax().item()

    hp_logit = heads.get('hp_pred')
    if hp_logit is not None:
        result['hp_prob'] = torch.sigmoid(hp_logit[0]).item()

    return result

In [None]:
# Run roundtrip validation: encode each holdout sample, decode every head,
# and compare against the ground truth stored with the sample.
print("=" * 80)
print(f"ROUNDTRIP VALIDATION — Epoch {model_epoch}")
print("=" * 80)
print(f"{'Family':<14s} {'True Tc':>8s} {'Pred Tc':>9s} {'Err':>7s} {'Sim':>5s} {'SC?':>5s} {'FamPred':<16s} | Formula")
print("-" * 120)

# Per-sample accumulators consumed by the summary, plots and grade card.
roundtrip_results = []
tc_errors = []
magpie_mses = []
formula_sims = []
sc_correct = 0
family_correct = 0

for sample in holdout_samples:
    formula = sample['formula']
    true_tc = sample['Tc']
    family = sample['family']
    orig_idx = sample.get('original_index')
    if orig_idx is None:
        # Without a row index the cached tensors cannot be looked up. Say so
        # explicitly: a silent skip would shrink every downstream count
        # without any trace in the output.
        print(f"  [{family:<12s}] SKIPPED (no original_index) | {formula}")
        continue

    idx_t = torch.tensor([orig_idx], dtype=torch.long)
    decoded = full_forward(
        encoder, decoder,
        data['elem_idx'][idx_t], data['elem_frac'][idx_t],
        data['elem_mask'][idx_t], data['magpie'][idx_t], data['tc'][idx_t],
    )

    pred_tc = decoded['tc_pred_kelvin']
    gen_formula = decoded['formula']
    # Absolute error in Kelvin (so the '+' in the table below is always shown).
    tc_err = abs(pred_tc - true_tc)
    tc_errors.append(tc_err)

    mag_mse = F.mse_loss(decoded['magpie_pred'], data['magpie'][orig_idx]).item()
    magpie_mses.append(mag_mse)

    sim = element_similarity(gen_formula, formula)
    formula_sims.append(sim)
    # "Exact" means string equality after trimming surrounding whitespace.
    exact = gen_formula.strip() == formula.strip()

    sc_str = f"{decoded.get('sc_prob', 0):.2f}"
    # Every holdout material is a superconductor, so a positive SC
    # prediction is the correct one.
    if decoded.get('sc_pred', False):
        sc_correct += 1

    family_pred_str = ''
    if 'family_pred_14' in decoded:
        pred_idx = decoded['family_pred_14']
        pred_name = FAMILY_14_NAMES[pred_idx]
        true_idx = HOLDOUT_FAMILY_TO_14.get(family, -1)
        match = pred_idx == true_idx
        if match:
            family_correct += 1
        family_pred_str = f"{pred_name} {'OK' if match else 'X'}"

    exact_str = ' [EXACT]' if exact else ''
    print(f"  [{family:<12s}] {true_tc:8.1f} {pred_tc:9.1f} {tc_err:+7.1f} {sim:5.3f} {sc_str:>5s} {family_pred_str:<16s} | {formula}")
    # exact_str is empty for non-exact matches, so one print covers both cases.
    print(f"     -> {gen_formula}{exact_str}")

    roundtrip_results.append({
        'formula': formula, 'generated': gen_formula, 'exact': exact,
        'similarity': sim, 'true_tc': true_tc, 'pred_tc': pred_tc,
        'tc_error': tc_err, 'family': family,
        'sc_prob': decoded.get('sc_prob'), 'family_pred': family_pred_str,
    })

In [None]:
# Roundtrip Summary
tc_arr = np.array(tc_errors)
sim_arr = np.array(formula_sims)
mag_arr = np.array(magpie_mses)

# Use the number of samples actually evaluated as the denominator; the
# roundtrip loop may skip samples, so it is not guaranteed to be 45.
n_eval = len(tc_errors)
# Count true string matches as recorded by the roundtrip loop, rather than
# inferring "exact" from a similarity threshold.
n_exact_roundtrip = sum(1 for r in roundtrip_results if r['exact'])

print("\n" + "=" * 80)
print("ROUNDTRIP SUMMARY")
print("=" * 80)

if n_eval == 0:
    # Avoid NaN means and division by zero when nothing was evaluated.
    print("\nNo holdout samples were evaluated; nothing to summarize.")
else:
    print(f"\nTc Prediction:")
    print(f"  MAE: {tc_arr.mean():.2f} K (median: {np.median(tc_arr):.2f} K)")
    print(f"  Within 1K: {(tc_arr < 1).sum()}/{n_eval} | Within 5K: {(tc_arr < 5).sum()}/{n_eval}")

    print(f"\nFormula Roundtrip:")
    print(f"  Mean similarity: {sim_arr.mean():.3f}")
    print(f"  Exact matches: {n_exact_roundtrip}/{n_eval}")
    print(f"  >= 0.999 similarity: {(sim_arr >= 0.999).sum()}/{n_eval}")
    print(f"  > 0.95 similarity: {(sim_arr > 0.95).sum()}/{n_eval}")

    print(f"\nSC Classification: {sc_correct}/{n_eval} ({sc_correct/n_eval*100:.1f}%)")
    print(f"Family Classification: {family_correct}/{n_eval} ({family_correct/n_eval*100:.1f}%)")
    print(f"Magpie MSE: {mag_arr.mean():.6f}")

## 3. Part 2: Targeted Holdout Search

For each holdout formula, find training samples sharing the same elements, encode those as Z seeds, and explore the Z neighborhood using:

1. **Fine-grained perturbation** (8 noise scales x 100 samples x 30 seeds)
2. **Pairwise interpolation** (linear + SLERP, 15 steps, up to 100 pairs)
3. **Centroid + random walks** (5 scales x 30 directions)
4. **PCA-directed walks** (top 20 principal components of neighbor Z distribution)
5. **Temperature sampling** (8 temperatures x 30 samples x 15 seeds)

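Total: up to ~27,500 greedy-decoded Z-space candidates per target formula (strategies 1–4), plus up to 3,600 temperature-sampled decodes of the seed vectors (strategy 5).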

In [None]:
# Search parameters
# Noisy copies drawn per seed for each noise scale.
N_PERTURBATIONS = 100
# Multipliers applied to standard-normal noise added to each seed.
NOISE_SCALES = [0.02, 0.05, 0.08, 0.1, 0.15, 0.2, 0.3, 0.5]
# Interpolation points per seed pair (each used for both linear and SLERP).
N_INTERPOLATION_STEPS = 15
# Decodes per (seed, temperature) combination in the sampling pass.
N_TEMPERATURE_SAMPLES = 30
# Decoder temperatures tried in the sampling pass.
TEMPERATURES = [0.01, 0.05, 0.1, 0.2, 0.3, 0.5, 0.7, 1.0]


@torch.no_grad()
def encode_indices(encoder, data, indices):
    """Encode the given dataset rows and return their Z vectors on the CPU.

    Rows are pushed through ``encoder.encode`` in chunks of 128 and the 'z'
    outputs are concatenated in the order of ``indices``.
    """
    chunk_size = 128
    input_keys = ('elem_idx', 'elem_frac', 'elem_mask', 'magpie', 'tc')
    latents = []
    for offset in range(0, len(indices), chunk_size):
        rows = torch.tensor(indices[offset:offset + chunk_size], dtype=torch.long)
        encoded = encoder.encode(*(data[key][rows].to(DEVICE) for key in input_keys))
        latents.append(encoded['z'].cpu())
    return torch.cat(latents, dim=0)


@torch.no_grad()
def decode_z_batch(encoder, decoder, z_batch, has_numden_head=False, temperature=0.01):
    """Decode a batch of Z vectors into formula strings.

    Z vectors are processed in chunks of 64. For each chunk the encoder's
    fraction head (and its numden head, when requested and present) provides
    the stoichiometry conditioning passed to the decoder.

    Returns:
        List of formula strings, one per row of ``z_batch``, in order.
    """
    chunk_size = 64
    use_numden = has_numden_head and hasattr(encoder, 'numden_head')
    formulas = []
    for offset in range(0, len(z_batch), chunk_size):
        z = z_batch[offset:offset + chunk_size].to(DEVICE)
        frac_out = encoder.fraction_head(z)
        # Leading columns are per-element fractions; the last is the count.
        stoich_parts = [frac_out[:, :encoder.max_elements]]
        count_column = frac_out[:, -1]
        if use_numden:
            stoich_parts.append(encoder.numden_head(z))
        stoich_parts.append(count_column.unsqueeze(-1))
        generated, _, _ = decoder.generate_with_kv_cache(
            z=z, stoich_pred=torch.cat(stoich_parts, dim=-1), temperature=temperature,
        )
        formulas.extend(tokens_to_formula(generated[row]) for row in range(len(z)))
    return formulas


def find_element_neighbors(target_formula, data, top_k=100):
    """Return up to ``top_k`` training indices that share elements with the target.

    Candidates are ranked by number of shared elements, then by Jaccard
    overlap of the element sets (both descending). Samples sharing no element
    are excluded; an unparseable target yields an empty list.
    """
    composition = parse_formula_elements(target_formula)
    if not composition:
        return []

    # Translate element symbols to atomic numbers, dropping unknown symbols.
    atomic_nums = {ELEMENT_TO_Z[sym] for sym in composition.keys() if ELEMENT_TO_Z.get(sym)}

    ranked = []
    for idx in data['train_indices']:
        n_shared, jaccard = element_overlap_score(atomic_nums, data['elem_idx'][idx])
        if n_shared > 0:
            ranked.append((idx, n_shared, jaccard))
    # reverse=True keeps the sort stable, so ties stay in dataset order.
    ranked.sort(key=lambda entry: (entry[1], entry[2]), reverse=True)
    return [entry[0] for entry in ranked[:top_k]]

In [None]:
def search_single_target(encoder, decoder, data, target_formula, target_tc, target_family, has_numden_head=False):
    """Targeted search for a single holdout formula.

    Finds training samples sharing elements with the target, encodes them to
    Z vectors, generates candidate Z vectors around them (perturbation,
    pairwise linear/SLERP interpolation, centroid random walks, PCA walks),
    decodes all candidates near-greedily (temperature 0.01), additionally
    samples the seeds at several temperatures, and scores every distinct
    decoded formula against the target with ``element_similarity``.

    Args:
        encoder, decoder: Loaded models.
        data: Dataset dict as returned by ``load_data``.
        target_formula: Holdout formula to search for.
        target_tc, target_family: Carried through to the result for reporting.
        has_numden_head: Forwarded to ``decode_z_batch``.

    Returns:
        Dict with keys 'target', 'target_tc', 'target_family', 'exact',
        'best_sim', 'best_gen', 'n_unique', 'n_total', 'n_neighbors',
        'top_matches', 'top_frequent', 'all_candidates'. The same keys are
        present when the search is skipped for lack of neighbors.
    """
    print(f"\n  TARGET: {target_formula} (Tc={target_tc}K, {target_family})")

    # Step 1: Find element-matched neighbors
    neighbor_indices = find_element_neighbors(target_formula, data, top_k=100)
    if len(neighbor_indices) < 3:
        print(f"    Only {len(neighbor_indices)} neighbors — skipping")
        # Same schema as the normal return so consumers need no special case.
        return {'target': target_formula, 'target_tc': target_tc, 'target_family': target_family,
                'best_sim': 0.0, 'best_gen': '', 'n_unique': 0, 'n_total': 0,
                'n_neighbors': len(neighbor_indices),
                'exact': False, 'top_matches': [], 'top_frequent': [],
                'all_candidates': []}

    # Step 2: Encode neighbors; the best-ranked 30 act as seeds
    z_neighbors = encode_indices(encoder, data, neighbor_indices)
    z_seeds = z_neighbors[:min(30, len(z_neighbors))]
    print(f"    {len(neighbor_indices)} neighbors, {len(z_seeds)} seeds, Z norm: {z_neighbors.norm(dim=-1).mean():.2f}")

    # Step 3: Generate candidate Z vectors
    all_candidates_z = []

    # Strategy 1: Fine-grained perturbation (Gaussian noise around each seed)
    for z in z_seeds:
        z_exp = z.unsqueeze(0)
        for scale in NOISE_SCALES:
            noise = torch.randn(N_PERTURBATIONS, z.shape[0]) * scale
            all_candidates_z.append(z_exp + noise)

    # Strategy 2: Pairwise interpolation (linear + SLERP), capped at 100 pairs
    n_seeds = len(z_seeds)
    max_pairs = min(n_seeds * (n_seeds - 1) // 2, 100)
    pair_count = 0
    for i in range(n_seeds):
        for j in range(i + 1, n_seeds):
            if pair_count >= max_pairs:
                break
            z1 = z_seeds[i].unsqueeze(0)
            z2 = z_seeds[j].unsqueeze(0)
            for t in np.linspace(0.05, 0.95, N_INTERPOLATION_STEPS):
                all_candidates_z.append((1 - t) * z1 + t * z2)
                all_candidates_z.append(slerp(z1, z2, float(t)))
            pair_count += 1
        if pair_count >= max_pairs:
            break

    # Strategy 3: Centroid + random direction walks, scaled per dimension by
    # the seeds' standard deviation
    centroid = z_seeds.mean(dim=0, keepdim=True)
    std = z_seeds.std(dim=0, keepdim=True).clamp(min=1e-6)
    for scale in [0.3, 0.5, 1.0, 1.5, 2.0]:
        directions = torch.randn(30, z_seeds.shape[1])
        directions = F.normalize(directions, dim=-1)
        all_candidates_z.append(centroid + scale * std * directions)

    # Strategy 4: PCA-directed walks along the top principal components of
    # the neighbor distribution, from -3 to +3 standard deviations
    if len(z_neighbors) >= 10:
        z_np = z_neighbors.numpy()
        mean = z_np.mean(axis=0)
        centered = z_np - mean
        _, S, Vt = np.linalg.svd(centered, full_matrices=False)
        n_comp = min(20, len(S))
        # Loop-invariant: build the mean tensor once instead of per step.
        mean_t = torch.from_numpy(mean).float().unsqueeze(0)
        for c in range(n_comp):
            direction = torch.from_numpy(Vt[c]).float().unsqueeze(0)
            std_along = S[c] / np.sqrt(len(z_neighbors) - 1)
            for alpha in np.linspace(-3.0, 3.0, 20):
                # Plain Python float so the product keeps the tensor's dtype.
                step = float(alpha * std_along)
                all_candidates_z.append(mean_t + step * direction)

    all_z = torch.cat(all_candidates_z, dim=0)
    print(f"    {len(all_z)} Z candidates (greedy decode)...")

    # Step 4: Greedy decode
    t0 = time.time()
    greedy_formulas = decode_z_batch(encoder, decoder, all_z, has_numden_head=has_numden_head, temperature=0.01)
    print(f"    Greedy decoded in {time.time()-t0:.1f}s")

    # Step 5: Temperature sampling on (at most) the first 15 seeds
    temp_formulas = []
    for temp in TEMPERATURES:
        for z_idx in range(min(len(z_seeds), 15)):
            z_repeated = z_seeds[z_idx].unsqueeze(0).repeat(N_TEMPERATURE_SAMPLES, 1)
            temp_formulas.extend(decode_z_batch(encoder, decoder, z_repeated, has_numden_head=has_numden_head, temperature=temp))

    all_generated = greedy_formulas + temp_formulas

    # Step 6: Score candidates. Empty and single-character decodes are dropped.
    formula_counts = defaultdict(int)
    for f in all_generated:
        if f and len(f) > 1:
            formula_counts[f] += 1
    unique_formulas = set(formula_counts)

    candidates = []
    best_sim = 0.0
    best_gen = ''
    top_matches = []

    for formula_gen, count in formula_counts.items():
        sim = element_similarity(formula_gen, target_formula)
        candidates.append({'formula': formula_gen, 'count': count, 'similarity': sim})
        if sim > best_sim:
            best_sim = sim
            best_gen = formula_gen
        if sim >= 0.8:
            top_matches.append((formula_gen, sim, count))

    candidates.sort(key=lambda x: -x['similarity'])
    top_matches.sort(key=lambda x: -x[1])

    # "Exact" means string equality after trimming surrounding whitespace.
    target_norm = target_formula.strip()
    exact = any(f.strip() == target_norm for f in unique_formulas)

    # Top by frequency
    top_freq = sorted([(c['formula'], c['count']) for c in candidates], key=lambda x: -x[1])[:5]

    print(f"    Total: {len(all_generated)}, unique: {len(unique_formulas)}")
    print(f"    RESULT: best_sim={best_sim:.4f}, exact={'YES' if exact else 'no'}")
    if top_matches[:3]:
        for f, s, c in top_matches[:3]:
            print(f"      sim={s:.4f} ({c}x): {f}")

    return {
        'target': target_formula, 'target_tc': target_tc, 'target_family': target_family,
        'exact': exact, 'best_sim': float(best_sim), 'best_gen': best_gen,
        'n_unique': len(unique_formulas), 'n_total': len(all_generated),
        'n_neighbors': len(neighbor_indices),
        'top_matches': [(f, float(s)) for f, s, c in top_matches[:10]],
        'top_frequent': top_freq,
        'all_candidates': candidates[:50],  # Top 50 by similarity
    }

In [None]:
# Run targeted search on every holdout target
print("=" * 80)
print(f"TARGETED HOLDOUT SEARCH — Epoch {model_epoch}")
print("Element-Anchored Z-Space Exploration")
print("=" * 80)

search_results = []
t_start = time.time()

# Take the progress denominator from the data instead of assuming 45 targets.
n_targets = len(holdout_samples)
for i, sample in enumerate(holdout_samples):
    print(f"\n--- [{i+1}/{n_targets}] ---")
    result = search_single_target(
        encoder, decoder, data,
        sample['formula'], sample['Tc'], sample['family'],
        has_numden_head=has_numden_head,
    )
    search_results.append(result)

total_time = time.time() - t_start
print(f"\nTotal search time: {total_time/60:.1f} minutes")

In [None]:
# Search Results Summary
print("\n" + "=" * 80)
print(f"TARGETED SEARCH SUMMARY — Epoch {model_epoch}")
print("=" * 80)

# One line per target: '***' = exact match found, ' + ' = best sim >= 0.95.
for r in search_results:
    marker = '***' if r.get('exact') else ' + ' if r['best_sim'] >= 0.95 else '   '
    print(f"{marker} [{r['target_family']:12s}] sim={r['best_sim']:.3f} | {r['target']}")
    if r['best_gen'] and not r.get('exact'):
        print(f"     Best: {r['best_gen']}")

# All denominators come from the results themselves, so the counts shown
# always agree with the percentages next to them.
n_targets = len(search_results)
n_exact = sum(1 for r in search_results if r.get('exact'))
print(f"\nExact matches: {n_exact}/{n_targets}")

print(f"\nRESULTS BY THRESHOLD:")
for thresh in [1.0, 0.99, 0.98, 0.95, 0.90, 0.85, 0.80]:
    found = sum(1 for r in search_results if r['best_sim'] >= thresh)
    # Guard against an empty result list.
    pct = found / n_targets * 100 if n_targets else 0.0
    bar = '#' * int(pct / 2)
    print(f"  >= {thresh:.2f}: {found:2d}/{n_targets} ({pct:5.1f}%) {bar}")

# Per-family breakdown
print(f"\nPER-FAMILY BREAKDOWN:")
families = sorted(set(r['target_family'] for r in search_results))
for fam in families:
    fam_results = [r for r in search_results if r['target_family'] == fam]
    n_fam = len(fam_results)
    n_exact_fam = sum(1 for r in fam_results if r.get('exact'))
    n_095 = sum(1 for r in fam_results if r['best_sim'] >= 0.95)
    avg_sim = np.mean([r['best_sim'] for r in fam_results])
    print(f"  {fam:14s}: exact={n_exact_fam}/{n_fam}, >=0.95={n_095}/{n_fam}, avg_sim={avg_sim:.3f}")

## 4. Visualization

In [None]:
import matplotlib.pyplot as plt

# Three side-by-side panels: per-target search similarity, roundtrip Tc
# error, and per-family average search similarity.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Plot 1: Similarity distribution (green >= 0.95, orange >= 0.80, else red)
sims = [r['best_sim'] for r in search_results]
colors = ['#2ecc71' if s >= 0.95 else '#f39c12' if s >= 0.80 else '#e74c3c' for s in sims]
axes[0].barh(range(len(sims)), sims, color=colors)
axes[0].set_xlabel('Best Similarity')
axes[0].set_ylabel('Holdout Target Index')
axes[0].set_title('Best Match Similarity per Holdout Target')
axes[0].axvline(x=0.95, color='green', linestyle='--', alpha=0.5, label='0.95')
axes[0].axvline(x=0.80, color='orange', linestyle='--', alpha=0.5, label='0.80')
axes[0].legend()
axes[0].set_xlim(0, 1.05)

# Plot 2: Tc prediction error (roundtrip)
fam_colors = {
    'YBCO': '#e74c3c', 'LSCO': '#3498db', 'Hg-cuprate': '#2ecc71',
    'Tl-cuprate': '#9b59b6', 'Bi-cuprate': '#f39c12', 'Iron-based': '#1abc9c',
    'MgB2': '#e67e22', 'Conventional': '#95a5a6', 'Other': '#34495e',
}
# Read Tc, error and family from the same roundtrip record. Pairing
# holdout_samples with tc_errors by position would misalign the points
# whenever the roundtrip loop skipped a sample.
for rt in roundtrip_results:
    axes[1].scatter(rt['true_tc'], rt['tc_error'], color=fam_colors.get(rt['family'], 'gray'),
                   s=50, alpha=0.7)
axes[1].set_xlabel('True Tc (K)')
axes[1].set_ylabel('Absolute Error (K)')
axes[1].set_title('Tc Prediction Error vs True Tc')
axes[1].axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)

# Plot 3: Per-family average similarity
fam_avg = {}
for r in search_results:
    fam = r['target_family']
    fam_avg.setdefault(fam, []).append(r['best_sim'])
fam_names = sorted(fam_avg.keys())
fam_means = [np.mean(fam_avg[f]) for f in fam_names]
axes[2].barh(fam_names, fam_means, color=[fam_colors.get(f, 'gray') for f in fam_names])
axes[2].set_xlabel('Average Best Similarity')
axes[2].set_title('Average Similarity by Family')
axes[2].set_xlim(0, 1.05)
axes[2].axvline(x=0.95, color='green', linestyle='--', alpha=0.5)

plt.tight_layout()
# The outputs/ directory is only created by the upload path, so make sure it
# exists before saving.
out_dir = os.path.join(REPO_DIR, 'outputs')
os.makedirs(out_dir, exist_ok=True)
plt.savefig(os.path.join(out_dir, 'generative_evaluation.png'), dpi=150, bbox_inches='tight')
plt.show()
print("Saved: outputs/generative_evaluation.png")

In [None]:
# Grade card: headline numbers combining the roundtrip and search results.
print("\n" + "=" * 80)
print(f"GENERATIVE EVALUATION GRADE CARD — Epoch {model_epoch}")
print("=" * 80)

# Roundtrip Tc metrics (tc_errors holds absolute errors in Kelvin).
tc_mae = np.mean(tc_errors)
tc_within_1k = (np.array(tc_errors) < 1).sum()
# Search metrics: exact string matches and best-similarity thresholds.
n_exact_search = sum(1 for r in search_results if r.get('exact'))
n_095 = sum(1 for r in search_results if r['best_sim'] >= 0.95)
n_099 = sum(1 for r in search_results if r['best_sim'] >= 0.99)
avg_best_sim = np.mean([r['best_sim'] for r in search_results])

def grade(val, thresholds):
    """Map a score to a letter grade (A+, A, B+, B, C, D, F).

    ``thresholds`` is listed from best to worst; the grade at the position
    of the first threshold that ``val`` meets or exceeds is returned, and
    'F' is returned when none is met.
    """
    letters = ['A+', 'A', 'B+', 'B', 'C', 'D', 'F']
    return next(
        (letters[pos] for pos, cutoff in enumerate(thresholds) if val >= cutoff),
        letters[-1],
    )

# Denominators come from the data: roundtrip metrics are relative to the
# samples actually evaluated, search metrics to the targets searched. The
# max(..., 1) guards keep the ratios defined if either list is empty.
n_eval = len(tc_errors)
n_targets = len(search_results)
rt_den = max(n_eval, 1)
sr_den = max(n_targets, 1)

# (metric name, formatted result, letter grade). Error-type metrics are
# mapped through 1/(1+x) so that higher is better before grading.
metrics = [
    ('SC Classification', f'{sc_correct}/{n_eval} ({sc_correct/rt_den*100:.0f}%)',
     grade(sc_correct/rt_den, [0.98, 0.95, 0.90, 0.80, 0.70, 0.60])),
    ('Tc Prediction MAE', f'{tc_mae:.2f}K',
     grade(1/(1+tc_mae), [0.67, 0.50, 0.33, 0.20, 0.10, 0.05])),
    ('Tc within 1K', f'{tc_within_1k}/{n_eval} ({tc_within_1k/rt_den*100:.0f}%)',
     grade(tc_within_1k/rt_den, [0.90, 0.80, 0.70, 0.60, 0.50, 0.30])),
    ('Family Classification', f'{family_correct}/{n_eval} ({family_correct/rt_den*100:.0f}%)',
     grade(family_correct/rt_den, [0.95, 0.90, 0.80, 0.70, 0.60, 0.50])),
    ('Formula Exact (search)', f'{n_exact_search}/{n_targets} ({n_exact_search/sr_den*100:.0f}%)',
     grade(n_exact_search/sr_den, [0.80, 0.60, 0.40, 0.25, 0.15, 0.05])),
    ('Formula >= 0.99 sim', f'{n_099}/{n_targets} ({n_099/sr_den*100:.0f}%)',
     grade(n_099/sr_den, [0.90, 0.80, 0.60, 0.40, 0.25, 0.10])),
    ('Formula >= 0.95 sim', f'{n_095}/{n_targets} ({n_095/sr_den*100:.0f}%)',
     grade(n_095/sr_den, [0.95, 0.85, 0.70, 0.55, 0.40, 0.25])),
    ('Average Best Similarity', f'{avg_best_sim:.3f}',
     grade(avg_best_sim, [0.95, 0.90, 0.85, 0.80, 0.70, 0.60])),
    ('Magpie MSE', f'{np.mean(magpie_mses):.4f}',
     grade(1/(1+np.mean(magpie_mses)*10), [0.90, 0.80, 0.60, 0.40, 0.25, 0.10])),
]

print(f"\n{'Metric':<28s} {'Result':<28s} {'Grade':>5s}")
print("-" * 65)
for name, result, g in metrics:
    print(f"  {name:<26s} {result:<28s} {g:>5s}")

print(f"\n{'='*65}")
# Verdict: at least 10 exact search matches counts as "CAN".
overall = (
    f"The model {'CAN' if n_exact_search >= 10 else 'CANNOT yet'} reliably generate "
    f"unseen superconductors.\n"
    f"  - {n_exact_search}/{n_targets} exact matches from Z-space exploration\n"
    f"  - {n_095}/{n_targets} with >= 0.95 compositional similarity\n"
    f"  - Tc prediction: {tc_mae:.2f}K MAE on held-out materials\n"
)
print(overall)

In [None]:
# Save results to JSON
# Accuracies are relative to the samples actually evaluated in the roundtrip
# pass (recorded as 'n_evaluated'); they are null if nothing was evaluated.
n_eval = len(tc_errors)
output = {
    'timestamp': time.strftime('%Y-%m-%dT%H:%M:%S'),
    'checkpoint': str(CHECKPOINT_PATH),
    'epoch': model_epoch,
    'roundtrip': {
        'n_evaluated': n_eval,
        'tc_mae_K': float(tc_mae),
        'tc_within_1K': int(tc_within_1k),
        'sc_accuracy': sc_correct / n_eval if n_eval else None,
        'family_accuracy': family_correct / n_eval if n_eval else None,
        'magpie_mse': float(np.mean(magpie_mses)),
        'formula_mean_sim': float(np.mean(formula_sims)),
        'details': roundtrip_results,
    },
    'search': {
        'n_exact': n_exact_search,
        'n_095': n_095,
        'n_099': n_099,
        'avg_best_sim': float(avg_best_sim),
        # 'all_candidates' is dropped to keep the file small.
        'details': [{k: v for k, v in r.items() if k != 'all_candidates'} for r in search_results],
    },
}
# The outputs/ directory is only created by the upload path, so make sure it
# exists before writing.
out_dir = os.path.join(REPO_DIR, 'outputs')
os.makedirs(out_dir, exist_ok=True)
output_path = os.path.join(out_dir, f'generative_evaluation_epoch_{model_epoch}.json')
with open(output_path, 'w') as f:
    # default=str stringifies anything json cannot encode natively
    # (e.g. numpy scalars) instead of raising.
    json.dump(output, f, indent=2, default=str)
print(f"Results saved to: {output_path}")