In [None]:
# ===================================================================
# ENCODER TRAINING v3 — FINAL (Phase 1 only + slide-scale regularizer)
# ===================================================================
# Decision: the trunk is the Phase 1 checkpoint (VICReg + spatial InfoNCE
#           only). Domain alignment is handled afterwards by a separately
#           trained adapter while the trunk stays frozen.
#
# Changes from v2:
#   - two_phase_training=False  (Phase 2 dropped — its alignment broke locality)
#   - n_epochs=1200  (Phase 1 saturates around epoch 900-1000; the extra
#     epochs leave headroom for best-checkpoint selection)
#   - slide_scale_weight=5.0  (pulls per-slide RMS norms together during training)
#   - adv_slide_weight=0.0  (GRL off; discriminator training skipped to save compute)
#   - All CORAL/MMD/alignment weights=0  (no alignment applied to the trunk)
#   - Everything else matches v2 (NCE=5.0, λ_inv warmup, etc.)
# ===================================================================

_rule = "=" * 70
print(f"\n{_rule}")
print("TRAINING ENCODER v3 — FINAL (Phase 1 only)")
print(_rule)

# Re-seed immediately before building the model so weight init is reproducible.
set_seed(SEED)
encoder_v3 = SharedEncoder(n_genes=n_genes, n_embedding=[512, 256, 128], dropout=0.1)

outdir_v3 = '/home/ehtesamul/sc_st/model/gems_liver_crossslide_v3'
os.makedirs(outdir_v3, exist_ok=True)

# train_encoder keywords, grouped by concern. The groups are disjoint, so
# unpacking all of them into one call is the same as listing every keyword.

# Data: ST slides (with coordinates) plus the inference slide acting as "SC".
_v3_data = dict(
    model=encoder_v3,
    st_gene_expr=st_expr,
    st_coords=st_coords,
    sc_gene_expr=inf_expr,
    slide_ids=slide_ids,
    sc_slide_ids=inf_slide_ids,
    sc_patient_ids=inf_patient_ids,
    st_source_ids=st_source_ids,
    sc_source_ids=inf_source_ids,
)

# Optimisation and run bookkeeping.
_v3_run = dict(
    n_epochs=1200,                    # Phase 1 only — saturates by ~1000
    batch_size=256,
    lr=1e-4,
    device=device,
    outf=outdir_v3,
    seed=SEED,
    use_best_checkpoint=True,
    return_aux=True,
    inference_dropout_prob=0.0,       # no SC-alignment phase, so no inference dropout
    stageA_obj='vicreg_adv',
    stageA_balanced_slides=True,
)

# VICReg + augmentations (same λ_inv warmup as v2: 5 → 25 over the first 30%).
_v3_vicreg = dict(
    vicreg_lambda_inv=25.0,
    vicreg_lambda_var=50.0,
    vicreg_lambda_cov=1.0,
    vicreg_inv_warmup_frac=0.3,
    vicreg_inv_start=5.0,
    vicreg_gamma=1.0,
    vicreg_eps=1e-4,
    vicreg_project_dim=256,
    vicreg_use_projector=False,
    vicreg_float32_stats=True,
    vicreg_ddp_gather=False,
    aug_gene_dropout=0.25,
    aug_gauss_std=0.01,
    aug_scale_jitter=0.1,
)

# Trunk alignment is fully disabled (domain compatibility comes from the
# post-hoc adapter); the remaining knobs are presumably inert at zero weight.
_v3_alignment = dict(
    use_source_adversary=False,
    source_coral_weight=0.0,
    adv_slide_weight=0.0,             # GRL off + discriminator training skipped
    patient_coral_weight=0.0,
    mmd_weight=0.0,                   # was 30.0 in v2
    mmd_use_l2norm=True,
    mmd_ramp=True,
    adv_warmup_epochs=50,
    adv_ramp_epochs=200,
    grl_alpha_max=1.0,
    disc_hidden=512,
    disc_dropout=0.1,
    adv_representation_mode='clean',
    adv_use_layernorm=False,
    adv_log_diagnostics=False,        # no discriminator → no diagnostics
    adv_log_grad_norms=False,
    use_local_align=False,
    local_align_bidirectional=True,
    local_align_weight=0.0,
    local_align_tau_z=0.07,
    coral_raw_weight=0.0,
    knn_weight=0.0,
)

# Spatial InfoNCE (unchanged from v2).
_v3_spatial_nce = dict(
    spatial_nce_weight=5.0,
    spatial_nce_k_phys=20,
    spatial_nce_far_mult=4.0,
    spatial_nce_n_hard=20,
    spatial_nce_tau=0.1,
    spatial_nce_n_rand_neg=128,
    spatial_nce_n_anchors=64,
)

# Single-phase schedule, slide-scale regularizer, per-slide norm tracking.
_v3_schedule = dict(
    two_phase_training=False,
    slide_scale_weight=5.0,           # L_ss = Σ_s (log RMS(z_s) - log RMS(z_all))^2
    per_slide_norm=True,
    per_slide_norm_target=1.0,
)

encoder_v3, projector_v3, discriminator_v3, hist_v3 = train_encoder(
    **_v3_data,
    **_v3_run,
    **_v3_vicreg,
    **_v3_alignment,
    **_v3_spatial_nce,
    **_v3_schedule,
)

_rule = "=" * 70
print(f"\n{_rule}")
print("v3 TRAINING COMPLETE")
print(_rule)
_v3_ckpt = f'{outdir_v3}/encoder_v3_final.pt'
torch.save(encoder_v3.state_dict(), _v3_ckpt)
print(f"Saved: {_v3_ckpt}")

# Post-training: embed every ST spot, then compare the mean embedding norm of
# each slide before and after per-slide RMS normalization.
from core_models_et_p1 import normalize_embeddings_per_slide
encoder_v3.eval()
with torch.no_grad():
    # Forward in 512-row chunks to bound activation memory.
    z_all = torch.cat([encoder_v3(chunk) for chunk in torch.split(st_expr, 512)], dim=0)
    z_normed = normalize_embeddings_per_slide(z_all, slide_ids, target_rms=1.0)

    print("\nPer-slide norm report (before → after normalization):")
    for sid in torch.unique(slide_ids):
        mask = (slide_ids == sid)
        before = z_all[mask].norm(dim=1).mean().item()
        after = z_normed[mask].norm(dim=1).mean().item()
        print(f"  Slide {sid.item()}: {before:.2f} → {after:.2f}")

# Persist raw + normalized ST embeddings (on CPU) for the adapter stage.
_emb_payload = {
    'z_raw': z_all.cpu(),
    'z_normed': z_normed.cpu(),
    'slide_ids': slide_ids.cpu(),
}
torch.save(_emb_payload, f'{outdir_v3}/st_embeddings.pt')
print(f"\nSaved ST embeddings: {outdir_v3}/st_embeddings.pt")
print("Next step: Train affine adapter with trunk frozen (see adapter_topology_eval.py)")


In [None]:
# ===================================================================
# ENCODER TRAINING v2 — with diagnostic improvements
# ===================================================================
# Changes from v1 (the original cross-slide run):
#   - adv_slide_weight=0.0  (GRL code kept but its weight zeroed, so it can
#     be re-enabled later)
#   - vicreg_inv_warmup_frac=0.3, vicreg_inv_start=5.0  (λ_inv ramps up over
#     the first 30% of Phase 1)
#   - spatial_nce_weight=5.0  (raised from 3.0)
#   - per_slide_norm=True  (per-slide norms tracked in diagnostics)
#   - Extra logging: overlap@20, hit@20, per-slide norms every 100 epochs
# ===================================================================

_rule = "=" * 70
print(f"\n{_rule}")
print("TRAINING ENCODER v2 — DIAGNOSTIC IMPROVEMENTS")
print(_rule)

# Re-seed immediately before building the model so weight init is reproducible.
set_seed(SEED)
encoder_v2 = SharedEncoder(n_genes=n_genes, n_embedding=[512, 256, 128], dropout=0.1)

outdir_v2 = '/home/ehtesamul/sc_st/model/gems_liver_crossslide_v2'
os.makedirs(outdir_v2, exist_ok=True)

# train_encoder keywords, grouped by concern. The groups are disjoint, so
# unpacking all of them into one call is the same as listing every keyword.

# Data: ST slides (with coordinates) plus the inference slide acting as "SC".
_v2_data = dict(
    model=encoder_v2,
    st_gene_expr=st_expr,
    st_coords=st_coords,
    sc_gene_expr=inf_expr,
    slide_ids=slide_ids,
    sc_slide_ids=inf_slide_ids,
    sc_patient_ids=inf_patient_ids,
    st_source_ids=st_source_ids,
    sc_source_ids=inf_source_ids,
)

# Optimisation and run bookkeeping.
_v2_run = dict(
    n_epochs=2000,
    batch_size=256,
    lr=1e-4,
    device=device,
    outf=outdir_v2,
    seed=SEED,
    use_best_checkpoint=True,
    return_aux=True,
    inference_dropout_prob=0.5,
    stageA_obj='vicreg_adv',
    stageA_balanced_slides=True,
)

# VICReg: λ_inv ramps 5 → 25 over the first 30% of Phase 1.
_v2_vicreg = dict(
    vicreg_lambda_inv=25.0,
    vicreg_lambda_var=50.0,
    vicreg_lambda_cov=1.0,
    vicreg_inv_warmup_frac=0.3,   # NEW in v2: ramp λ_inv over first 30% of Phase 1
    vicreg_inv_start=5.0,         # NEW in v2: start λ_inv low
    vicreg_gamma=1.0,
    vicreg_eps=1e-4,
    vicreg_project_dim=256,
    vicreg_use_projector=False,
    vicreg_float32_stats=True,
    vicreg_ddp_gather=False,
    aug_gene_dropout=0.25,
    aug_gauss_std=0.01,
    aug_scale_jitter=0.1,
)

# Alignment terms. GRL weight is zeroed (code path kept for future use);
# CORAL / MMD / raw-CORAL remain active.
_v2_alignment = dict(
    use_source_adversary=False,
    source_coral_weight=75.0,
    adv_slide_weight=0.0,         # CHANGED in v2: was 75.0
    patient_coral_weight=0.0,
    mmd_weight=30.0,
    mmd_use_l2norm=True,
    mmd_ramp=True,
    adv_warmup_epochs=50,
    adv_ramp_epochs=200,
    grl_alpha_max=1.0,
    disc_hidden=512,
    disc_dropout=0.1,
    adv_representation_mode='clean',
    adv_use_layernorm=False,
    adv_log_diagnostics=True,
    adv_log_grad_norms=False,
    use_local_align=True,
    local_align_bidirectional=True,
    local_align_weight=0.0,
    local_align_tau_z=0.07,
    coral_raw_weight=2.0,
    knn_weight=0.0,
)

# Spatial InfoNCE, strengthened relative to v1.
_v2_spatial_nce = dict(
    spatial_nce_weight=5.0,       # CHANGED in v2: was 3.0
    spatial_nce_k_phys=20,
    spatial_nce_far_mult=4.0,
    spatial_nce_n_hard=20,
    spatial_nce_tau=0.1,
    spatial_nce_n_rand_neg=128,
    spatial_nce_n_anchors=64,
)

# Two-phase schedule plus per-slide norm tracking for diagnostics.
_v2_schedule = dict(
    two_phase_training=True,
    phase1_epochs=1000,
    phase2_lr_factor=0.1,
    per_slide_norm=True,           # NEW in v2: track per-slide norms
    per_slide_norm_target=1.0,
)

encoder_v2, projector_v2, discriminator_v2, hist_v2 = train_encoder(
    **_v2_data,
    **_v2_run,
    **_v2_vicreg,
    **_v2_alignment,
    **_v2_spatial_nce,
    **_v2_schedule,
)

print("\n✓ Training v2 complete!")
_v2_ckpt = f'{outdir_v2}/encoder_v2_final.pt'
torch.save(encoder_v2.state_dict(), _v2_ckpt)
print(f"✓ Saved to: {_v2_ckpt}")

# Post-training: apply per-slide RMS normalization to the ST embeddings and
# report mean norms per slide. Only the report is produced here; the
# embeddings themselves are not written to disk.
from core_models_et_p1 import normalize_embeddings_per_slide
encoder_v2.eval()
with torch.no_grad():
    # Forward in 512-row chunks to bound activation memory.
    z_all = torch.cat([encoder_v2(chunk) for chunk in torch.split(st_expr, 512)], dim=0)
    z_normed = normalize_embeddings_per_slide(z_all, slide_ids, target_rms=1.0)

    # Mean L2 norm per slide, before vs after normalization.
    for sid in torch.unique(slide_ids):
        mask = (slide_ids == sid)
        before = z_all[mask].norm(dim=1).mean().item()
        after = z_normed[mask].norm(dim=1).mean().item()
        print(f"  Slide {sid.item()}: norm {before:.2f} → {after:.2f}")


In [None]:
# ===================================================================
# SECTION D (cont): VISUALIZE ADAPTER RESULTS
# ===================================================================
# Reads `adapter_results` (returned by train_sc_adapter in the Section D
# training cell) and draws a 2x2 summary figure.

import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA

hist = adapter_results['history']
epochs_h = hist['epoch']

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Panel 1: CORAL + MMD loss
ax = axes[0, 0]
ax.plot(epochs_h, hist['loss_coral'], label='CORAL', color='blue', linewidth=2)
ax.plot(epochs_h, hist['loss_mmd'], label='MMD', color='red', linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Domain Alignment Losses')
ax.legend()
ax.grid(alpha=0.3)

# Panel 2: Centroid distance
ax = axes[0, 1]
ax.plot(epochs_h, hist['domain_centroid_dist'], color='purple', linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('L2 Distance')
ax.set_title('ST-SC Centroid Distance')
ax.grid(alpha=0.3)

# Panel 3: ST overlap@20 — ST embeddings come from the frozen encoder, so
# this curve is expected to stay flat.
ax = axes[1, 0]
ax.plot(epochs_h, hist['st_overlap_at_20'], color='green', linewidth=2, marker='o', markersize=3)
ax.set_xlabel('Epoch')
ax.set_ylabel('overlap@20')
ax.set_title('ST Spatial Locality (should be STABLE)')
ax.grid(alpha=0.3)
# Lower bound 0; upper bound at least 0.7, widened if the curve goes higher.
ax.set_ylim([0, max(0.7, max(hist['st_overlap_at_20']) * 1.1)])

# Panel 4: 2-D PCA of frozen ST embeddings vs adapted SC embeddings. One PCA
# is fit on the union so both point clouds share the same axes.
ax = axes[1, 1]
z_st = adapter_results['z_st_frozen'].cpu().numpy()
z_sc_adapted = adapter_results['z_sc_adapted'].cpu().numpy()

Z_all_vis = np.vstack([z_st, z_sc_adapted])
pca = PCA(n_components=2).fit(Z_all_vis)
st_pca = pca.transform(z_st)
sc_pca = pca.transform(z_sc_adapted)

ax.scatter(st_pca[:, 0], st_pca[:, 1], c='steelblue', s=8, alpha=0.3, label='ST (frozen)')
ax.scatter(sc_pca[:, 0], sc_pca[:, 1], c='orange', s=8, alpha=0.3, label='SC (adapted)')
ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_title('PCA: ST (frozen) vs SC (adapted)')
ax.legend()
ax.grid(alpha=0.3)

plt.suptitle('SC Adapter Training Results', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# Summary table: one row per final metric in adapter_results['results'].
res = adapter_results['results']
_rule = "=" * 60
# (label, results key, trailing hint); labels are left-aligned to 21 columns.
_summary_rows = (
    ("ST overlap@20:", 'st_overlap_at_20', "  (should be ~0.686 for v3)"),
    ("Domain acc BEFORE:", 'domain_acc_before', "  (closer to 0.5 = better mixing)"),
    ("Domain acc AFTER:", 'domain_acc_after', ""),
    ("CORAL (final):", 'final_coral', ""),
    ("MMD (final):", 'final_mmd', ""),
    ("Centroid dist:", 'final_centroid_dist', ""),
)
print(f"\n{_rule}")
print("SC ADAPTER EXPERIMENT SUMMARY")
print(_rule)
for _label, _key, _hint in _summary_rows:
    print(f"  {_label:<21}{res[_key]:.4f}{_hint}")
print(_rule)

In [None]:
# ===================================================================
# SECTION D: SC ADAPTER TRAINING (ST-Anchored Alignment)  — v3
# ===================================================================
# The v3 encoder is frozen. A per-dimension affine adapter (z' = a*z + b) is
# trained with CORAL + MMD to move SC embeddings toward the frozen ST space.
# The adapter is low-capacity and does not mix dimensions. ST embeddings come
# straight from the frozen encoder, so ST overlap@20 is unchanged by design.
# NOTE: unequal per-dimension scales can still re-rank Euclidean neighbors
# among SC points, so "topology-safe" holds only approximately for SC.
# ===================================================================

import os

from sc_adapter import SCAdapter, train_sc_adapter
from ssl_utils import precompute_spatial_nce_structures

# Load the v3 encoder (Phase 1 only, slide-scale regularized)
encoder_for_adapter = SharedEncoder(n_genes=n_genes, n_embedding=[512, 256, 128], dropout=0.1)
encoder_for_adapter.load_state_dict(torch.load(
    '/home/ehtesamul/sc_st/model/gems_liver_crossslide_v3/encoder_v3_final.pt',
    map_location=device
))
encoder_for_adapter.to(device)

# Precompute physical kNN for overlap monitoring
spatial_nce_data_full = precompute_spatial_nce_structures(
    st_coords=st_coords, st_gene_expr=st_expr, slide_ids=slide_ids,
    k_phys=20, far_mult=4.0, n_hard=20, device=device,
)

adapter_outdir = '/home/ehtesamul/sc_st/model/gems_liver_crossslide_v3/sc_adapter'
# Create the output folder up front, as the encoder-training cells do for
# their `outf`, instead of relying on train_sc_adapter to create it.
os.makedirs(adapter_outdir, exist_ok=True)

adapter_results = train_sc_adapter(
    encoder=encoder_for_adapter,
    st_gene_expr=st_expr,
    sc_gene_expr=inf_expr,
    st_coords=st_coords,
    slide_ids=slide_ids,
    # Adapter config — affine (per-dim scale+shift)
    adapter_mode='affine',
    adapter_dropout=0.0,          # not used for affine, kept for API compat
    # Training config
    n_epochs=500,
    batch_size=256,
    lr=1e-3,
    weight_decay=1e-4,
    # Loss weights
    coral_weight=10.0,
    mmd_weight=10.0,
    # Diagnostics
    log_every=25,
    device=device,
    seed=SEED,
    outf=adapter_outdir,
    phys_knn_idx=spatial_nce_data_full['pos_idx'],
)

In [None]:
import torch
import torch.nn.functional as F
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import cross_val_score, StratifiedKFold
import sys
import os

sys.path.insert(0, '/home/ehtesamul/sc_st/model')
from core_models_et_p1 import SharedEncoder, train_encoder
from ssl_utils import set_seed
import utils_et as uet

# Global run settings: fixed seed, GPU when available.
SEED = 42
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
set_seed(SEED)

_rule = "=" * 70
for _line in (
    _rule,
    "CROSS-SLIDE ALIGNMENT: Mouse Liver",
    "ST training = ST1, ST2, ST3 | Inference target = ST4",
    "Same patient (mouse) — cross-slide gap only",
    _rule,
):
    print(_line)

# Load data: ST1-ST3 train the encoder; ST4 is held out as the inference
# ("SC") domain.
_LIVER_DIR = '/home/ehtesamul/sc_st/data/liver'
ST_PATHS = {f'liver_ST{i}': f'{_LIVER_DIR}/stadata{i}.h5ad' for i in (1, 2, 3)}
INF_PATHS = {'liver_ST4': f'{_LIVER_DIR}/stadata4.h5ad'}


def _read_all(paths):
    """Read every h5ad in ``paths`` (name -> path), printing spot counts."""
    loaded = {}
    for key, fpath in paths.items():
        loaded[key] = sc.read_h5ad(fpath)
        print(f"  {key}: {loaded[key].n_obs} spots")
    return loaded


st_data = _read_all(ST_PATHS)
inf_data = _read_all(INF_PATHS)

# Normalize: library-size normalization + log1p, in place, per AnnData.
# NOTE(review): normalize_total() with the default target_sum=None rescales
# each slide to ITS OWN median total count, so the post-log1p scale can
# differ between slides. A shared target_sum (e.g. 1e4) would remove that
# slide-specific factor, but would also change the inputs that the saved
# encoder checkpoints were trained on — change only together with retraining.
all_data = [*st_data.values(), *inf_data.values()]
for adata in all_data:
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)

# Common genes: sorted intersection of var_names across all slides.
_gene_sets = [set(a.var_names) for a in all_data]
common_genes = sorted(set.intersection(*_gene_sets))
n_genes = len(common_genes)
print(f"\n✓ Common genes: {n_genes}")
# Extract expression
def extract_expr(adata, genes):
    """Return ``adata``'s expression matrix restricted to ``genes`` as a dense array.

    Anything exposing ``toarray`` (e.g. a scipy sparse matrix) is densified;
    an already-dense ``.X`` is returned unchanged.
    """
    subset = adata[:, genes].X
    if not hasattr(subset, "toarray"):
        return subset
    return subset.toarray()

# Dense (spots x common_genes) matrices per slide, in ST_PATHS / INF_PATHS order.
X_st, X_inf = {}, {}
for _key in ST_PATHS:
    X_st[_key] = extract_expr(st_data[_key], common_genes)
for _key in INF_PATHS:
    X_inf[_key] = extract_expr(inf_data[_key], common_genes)

for _key, _mat in [*X_st.items(), *X_inf.items()]:
    print(f"  {_key}: {_mat.shape}")

labels_str = [*ST_PATHS, *INF_PATHS]

# Fixed plotting color per slide.
colors_map = dict(
    liver_ST1='#e74c3c',
    liver_ST2='#3498db',
    liver_ST3='#2ecc71',
    liver_ST4='#f39c12',
)

In [None]:
_rule = "=" * 70
print(f"\n{_rule}")
print("PREPARING TRAINING DATA")
print(_rule)

# --- ST domain: the training slides, which carry spatial coordinates ---
_st_names = list(ST_PATHS)
ns = [X_st[k].shape[0] for k in _st_names]

st_expr = torch.tensor(
    np.vstack([X_st[k] for k in _st_names]), dtype=torch.float32, device=device
)

st_coords_list = [st_data[k].obsm['spatial'] for k in _st_names]
st_coords_raw = torch.tensor(np.vstack(st_coords_list), dtype=torch.float32, device=device)

# Slide index per spot: ns[0] zeros, then ns[1] ones, and so on.
slide_ids = torch.tensor(np.repeat(np.arange(len(ns)), ns), dtype=torch.long, device=device)

st_coords, st_mu, st_scale = uet.canonicalize_st_coords_per_slide(st_coords_raw, slide_ids)

print(f"✓ ST expr: {st_expr.shape}")
for _idx, (_key, _count) in enumerate(zip(ST_PATHS, ns)):
    print(f"  - {_key}: {_count} spots (slide {_idx})")
print(f"✓ ST coords: {st_coords.shape} (canonicalized per-slide)")

# --- Inference domain: ST4 stands in for the "SC" side ---
_inf_matrix = X_inf['liver_ST4']
n_inf = _inf_matrix.shape[0]
inf_expr = torch.tensor(_inf_matrix, dtype=torch.float32, device=device)
# One slide, one patient (same mouse), so every id is 0.
inf_slide_ids = torch.zeros(n_inf, dtype=torch.long, device=device)
inf_patient_ids = torch.zeros(n_inf, dtype=torch.long, device=device)

print(f"\n✓ Inference expr: {inf_expr.shape}")
print(f"  - liver_ST4: {n_inf} spots (inference domain)")

# Source ids for the adversary: ST slides keep their index 0..len(ns)-1 and
# the inference slide gets the next id, len(ns).
# NOTE(review): unlike the tensors above these stay on the CPU — presumably
# train_encoder moves them to `device`; confirm.
st_source_ids = torch.tensor(np.repeat(np.arange(len(ns)), ns), dtype=torch.long)
inf_source_ids = torch.full((n_inf,), len(ns), dtype=torch.long)

print(f"\n✓ ST source IDs: {st_source_ids.unique().tolist()}, counts: {ns}")
print(f"✓ Inf source IDs: {inf_source_ids.unique().tolist()}, count: {n_inf}")
print(f"\nTotal for Stage A: {st_expr.shape[0] + n_inf}")

In [None]:
# ===================================================================
# SPATIAL InfoNCE DIAGNOSTICS (SUPPORT-SET VERSION)
# ===================================================================
# CHECK A: Per-loss-component gradient norms (is spatial NCE "live"?)
# CHECK B: Logit gap over 200 steps (sim_pos vs sim_hard vs sim_rand)
# CHECK C: Index-set sanity (pos/neg physical distances + overlap)
# ===================================================================

from ssl_utils import diagnose_spatial_infonce, set_seed
from core_models_et_p1 import SharedEncoder

_rule = "=" * 70
print(_rule)
print("SPATIAL InfoNCE DIAGNOSTIC (support-set)")
print(_rule)

# Fresh, untrained encoder built under the global seed.
set_seed(SEED)
diag_encoder = SharedEncoder(n_genes=n_genes, n_embedding=[512, 256, 128], dropout=0.1)

# NOTE(review): spatial_nce_weight=3.0 matches the v1 run; v2/v3 train with
# 5.0 — change it here if the diagnostic should mirror those runs.
_diag_nce = dict(
    spatial_nce_weight=3.0,
    spatial_nce_k_phys=20,
    spatial_nce_far_mult=4.0,
    spatial_nce_n_hard=20,
    spatial_nce_tau=0.1,
    spatial_nce_n_rand_neg=128,
    spatial_nce_n_anchors=64,
)
_diag_vicreg = dict(
    vicreg_lambda_inv=25.0,
    vicreg_lambda_var=50.0,
    vicreg_lambda_cov=1.0,
    vicreg_gamma=1.0,
    vicreg_eps=1e-4,
    aug_gene_dropout=0.25,
    aug_gauss_std=0.01,
    aug_scale_jitter=0.1,
)
_diag_local = dict(
    local_align_weight=0.0,
    local_align_tau_z=0.07,
    local_align_bidirectional=True,
)

diag_results = diagnose_spatial_infonce(
    model=diag_encoder,
    st_gene_expr=st_expr,
    st_coords=st_coords,
    sc_gene_expr=inf_expr,
    slide_ids=slide_ids,
    sc_slide_ids=inf_slide_ids,
    **_diag_nce,
    **_diag_vicreg,
    **_diag_local,
    batch_size=256,
    n_diagnostic_steps=200,
    lr=1e-4,
    device=device,
    seed=SEED,
)

# ===================================================================
# PLOTS
# ===================================================================
# Six-panel figure summarising diag_results: similarity traces and loss
# curves (Check B), gradient norms (Check A), index-set distances and sizes
# (Check C), and per-step support-set coverage (Check B).
import matplotlib.pyplot as plt
import numpy as np

# logs: per-step traces keyed by metric name (each aligned with logs['step']);
# cc: Check C statistics. Both come from diagnose_spatial_infonce.
logs = diag_results['step_logs']
steps = np.array(logs['step'])
cc = diag_results['check_c']

fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# --- Panel 1: Similarity traces (Check B) ---
# Shaded bands show ±1 std for positives and hard negatives only.
ax = axes[0, 0]
ax.plot(steps, logs['sim_pos_mean'], label='sim(anchor, pos)', color='green', linewidth=2)
ax.plot(steps, logs['sim_hard_mean'], label='sim(anchor, hard_neg)', color='red', linewidth=2)
ax.plot(steps, logs['sim_rand_mean'], label='sim(anchor, rand_neg)', color='blue', linewidth=2)
ax.fill_between(steps,
    np.array(logs['sim_pos_mean']) - np.array(logs['sim_pos_std']),
    np.array(logs['sim_pos_mean']) + np.array(logs['sim_pos_std']),
    alpha=0.15, color='green')
ax.fill_between(steps,
    np.array(logs['sim_hard_mean']) - np.array(logs['sim_hard_std']),
    np.array(logs['sim_hard_mean']) + np.array(logs['sim_hard_std']),
    alpha=0.15, color='red')
ax.set_xlabel('Step')
ax.set_ylabel('Cosine Similarity')
ax.set_title('CHECK B: Similarity Gap (support-set)')
ax.legend(fontsize=8)
ax.grid(alpha=0.3)

# --- Panel 2: NCE loss over steps (Check B) ---
ax = axes[0, 1]
ax.plot(steps, logs['loss_nce'], label='L_spatialNCE', color='purple', linewidth=2)
ax.plot(steps, logs['loss_vicreg'], label='L_VICReg', color='orange', linewidth=2, alpha=0.7)
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
ax.set_title('CHECK B: Loss Components')
ax.legend(fontsize=8)
ax.grid(alpha=0.3)

# --- Panel 3: Gradient norms bar chart (Check A) ---
# One bar per entry of diag_results['grad_norms'], value printed above it.
ax = axes[0, 2]
gn = diag_results['grad_norms']
names_gn = list(gn.keys())
vals_gn = [gn[k] for k in names_gn]
bars = ax.bar(names_gn, vals_gn, color=['purple', 'orange', 'green', 'gray'])
ax.set_ylabel('||grad_theta||')
ax.set_title('CHECK A: Gradient Norms')
for bar, val in zip(bars, vals_gn):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.001,
            f'{val:.4f}', ha='center', va='bottom', fontsize=8)
ax.grid(alpha=0.3, axis='y')
plt.setp(ax.get_xticklabels(), rotation=15, ha='right', fontsize=8)

# --- Panel 4: Physical distance distributions (Check C) ---
# Empty distance arrays are skipped; the dashed line marks the r_far cutoff.
ax = axes[1, 0]
if len(cc['pos_phys_dists']) > 0:
    ax.hist(cc['pos_phys_dists'], bins=60, alpha=0.6, label='Positives', color='green', density=True)
if len(cc['hard_phys_dists']) > 0:
    ax.hist(cc['hard_phys_dists'], bins=60, alpha=0.6, label='Hard negatives', color='red', density=True)
ax.axvline(cc['r_far_threshold'], color='black', linestyle='--', linewidth=2,
           label=f'r_far = {cc["r_far_threshold"]:.3f}')
ax.set_xlabel('Physical distance to anchor')
ax.set_ylabel('Density')
ax.set_title('CHECK C: Pos vs Neg Physical Distances')
ax.legend(fontsize=8)
ax.grid(alpha=0.3)

# --- Panel 5: Per-anchor counts (Check C) ---
# Bar height = mean count per anchor; the annotation also shows the minimum.
# NOTE(review): pos_counts / hard_counts are used via .mean()/.min(), so they
# are assumed to be numpy arrays; far_counts only needs to be array-like.
ax = axes[1, 1]
x_pos = np.arange(3)
means = [cc['pos_counts'].mean(), cc['hard_counts'].mean(), np.mean(cc['far_counts'])]
mins = [cc['pos_counts'].min(), cc['hard_counts'].min(), np.min(cc['far_counts'])]
labels_c = ['Positives\n(phys neighbors)', 'Hard negatives\n(expr-sim + far)', 'Far mask\n(total far spots)']
bar_colors = ['green', 'red', 'blue']
bars = ax.bar(x_pos, means, color=bar_colors, alpha=0.7)
for i, (bar, m, mn) in enumerate(zip(bars, means, mins)):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
            f'mean={m:.1f}\nmin={mn}', ha='center', va='bottom', fontsize=8)
ax.set_xticks(x_pos)
ax.set_xticklabels(labels_c, fontsize=8)
ax.set_ylabel('Count per anchor')
ax.set_title('CHECK C: Index Set Sizes')
ax.grid(alpha=0.3, axis='y')

# --- Panel 6: Batch coverage (Check B) ---
# Twin y-axes: active anchors on the left, per-anchor neighbor counts on the
# right; the handles of both axes are merged into a single legend.
ax = axes[1, 2]
ax.plot(steps, logs['n_active_anchors'], label='Active anchors', color='teal', linewidth=2)
ax2 = ax.twinx()
ax2.plot(steps, logs['n_pos_per_anchor'], label='pos/anchor', color='green', linestyle='--')
ax2.plot(steps, logs['n_hard_per_anchor'], label='hard/anchor', color='red', linestyle='--')
ax2.plot(steps, logs['n_rand_per_anchor'], label='rand/anchor', color='blue', linestyle='--')
ax.set_xlabel('Step')
ax.set_ylabel('Active Anchors')
ax2.set_ylabel('Neighbors per Anchor')
ax.set_title('CHECK B: Support-Set Coverage')
lines1, labels1 = ax.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax.legend(lines1 + lines2, labels1 + labels2, fontsize=7, loc='center right')
ax.grid(alpha=0.3)

plt.suptitle('Spatial InfoNCE Diagnostics — Support-Set (Checks A + B + C)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# ===================================================================
# VERDICT
# ===================================================================
# Turns Checks A, B and C into PASS / WARN / FAIL / INFO lines. Reads `gn`,
# `logs` and `cc` as set up by the plotting section of this cell.
print("\n" + "=" * 70)
print("DIAGNOSTIC VERDICT")
print("=" * 70)

# Check A verdict: is the spatial-NCE gradient non-trivial relative to VICReg?
gnce = gn.get('L_spatialNCE', 0)
gvic = gn.get('L_VICReg', 0)
print("\n[CHECK A] Gradient norms:")
if gnce < 1e-6:
    print("  FAIL: Spatial NCE gradient is ZERO -- loss is disconnected")
elif gvic > 0 and gnce / gvic < 1e-3:
    print(f"  WARN: Spatial NCE gradient negligible vs VICReg (ratio={gnce/gvic:.6f})")
elif gvic > 0:
    print(f"  PASS: Spatial NCE is live (ratio NCE/VICReg = {gnce/gvic:.4f})")
else:
    # VICReg gradient is zero/missing, so no ratio can be formed.
    print(f"  PASS: Spatial NCE is live (VICReg gradient is zero; ||grad NCE|| = {gnce:.4f})")

# Check B verdict: anchor coverage and positive-vs-negative similarity gap.
mean_active = np.mean(logs['n_active_anchors'])
mean_pos = np.mean(logs['sim_pos_mean'])
mean_hard = np.mean(logs['sim_hard_mean'])
mean_rand = np.mean(logs['sim_rand_mean'])
mean_n_pos = np.mean(logs['n_pos_per_anchor'])
mean_n_hard = np.mean(logs['n_hard_per_anchor'])
print(f"\n[CHECK B] Loss scale (support-set):")
print(f"  Active anchors/step: {mean_active:.1f}")
print(f"  pos/anchor: {mean_n_pos:.1f}, hard/anchor: {mean_n_hard:.1f}")
print(f"  sim(pos)={mean_pos:.4f}, sim(hard)={mean_hard:.4f}, sim(rand)={mean_rand:.4f}")
print(f"  Gap pos-hard: {mean_pos - mean_hard:.4f}, Gap pos-rand: {mean_pos - mean_rand:.4f}")
if mean_active < 2:
    print("  FAIL: Almost no anchors active -- loss is starved")
elif mean_pos > mean_hard + 0.1:
    print("  INFO: Positives already more similar than hard negatives")
elif mean_pos < mean_hard:
    print("  INFO: Positives less similar than hard negatives -- loss is working hard")
else:
    # 0 <= gap <= 0.1: previously fell through without any verdict line.
    print("  INFO: Positives only slightly more similar than hard negatives (gap <= 0.1)")

# Check C verdict: positives must never also appear among the negatives.
print(f"\n[CHECK C] Index set sanity:")
ov_ph = np.asarray(cc['overlap_pos_hard'])
ov_any = np.asarray(cc['overlap_pos_any_neg'])
# Empty overlap arrays (no sampled anchors) report 0 instead of crashing in .max().
max_overlap_ph = ov_ph.max() if ov_ph.size else 0
max_overlap_any = ov_any.max() if ov_any.size else 0
n_leak_ph = int(np.count_nonzero(ov_ph > 0))
n_leak_any = int(np.count_nonzero(ov_any > 0))
n_sampled = len(ov_ph)
print(f"  |pos ∩ hard_neg|: max={max_overlap_ph}, leaking anchors={n_leak_ph}/{n_sampled}")
print(f"  |pos ∩ (hard ∪ far)|: max={max_overlap_any}, leaking anchors={n_leak_any}/{n_sampled}")
if max_overlap_ph > 0:
    print("  FAIL: LEAKAGE -- some spots are both positive AND hard negative!")
elif max_overlap_any > 0:
    print("  WARN: Some positives appear in the far set")
else:
    print("  PASS: Zero overlap between pos and neg index sets")

if len(cc['pos_phys_dists']) > 0:
    pct_pos_ok = (cc['pos_phys_dists'] < cc['r_far_threshold']).mean() * 100
    print(f"  Pos below r_far: {pct_pos_ok:.1f}% (should be ~100%)")
if len(cc['hard_phys_dists']) > 0:
    pct_hard_ok = (cc['hard_phys_dists'] >= cc['r_far_threshold']).mean() * 100
    print(f"  Hard neg above r_far: {pct_hard_ok:.1f}% (should be ~100%)")
print("=" * 70)

In [None]:
_rule = "=" * 70
print(f"\n{_rule}")
print("TRAINING ENCODER — CROSS-SLIDE (SAME PATIENT)")
print(_rule)

# Re-seed immediately before building the model so weight init is reproducible.
set_seed(SEED)
encoder = SharedEncoder(n_genes=n_genes, n_embedding=[512, 256, 128], dropout=0.1)

outdir = '/home/ehtesamul/sc_st/model/gems_liver_crossslide'
os.makedirs(outdir, exist_ok=True)

# train_encoder keywords, grouped by concern. The groups are disjoint, so
# unpacking all of them into one call is the same as listing every keyword.

# Data: ST slides (with coordinates) plus the inference slide acting as "SC".
_v1_data = dict(
    model=encoder,
    st_gene_expr=st_expr,
    st_coords=st_coords,
    sc_gene_expr=inf_expr,
    slide_ids=slide_ids,
    sc_slide_ids=inf_slide_ids,
    sc_patient_ids=inf_patient_ids,
    st_source_ids=st_source_ids,
    sc_source_ids=inf_source_ids,
)

# Optimisation and run bookkeeping.
_v1_run = dict(
    n_epochs=2000,
    batch_size=256,
    lr=1e-4,
    device=device,
    outf=outdir,
    seed=SEED,
    use_best_checkpoint=True,
    return_aux=True,
    inference_dropout_prob=0.5,
    stageA_obj='vicreg_adv',
    stageA_balanced_slides=True,
)

# VICReg + augmentations (constant λ_inv — no warmup in v1).
_v1_vicreg = dict(
    vicreg_lambda_inv=25.0,
    vicreg_lambda_var=50.0,
    vicreg_lambda_cov=1.0,
    vicreg_gamma=1.0,
    vicreg_eps=1e-4,
    vicreg_project_dim=256,
    vicreg_use_projector=False,
    vicreg_float32_stats=True,
    vicreg_ddp_gather=False,
    aug_gene_dropout=0.25,
    aug_gauss_std=0.01,
    aug_scale_jitter=0.1,
)

# Alignment terms: GRL slide adversary, CORAL and MMD all active.
_v1_alignment = dict(
    use_source_adversary=False,
    source_coral_weight=75.0,
    adv_slide_weight=75.0,
    patient_coral_weight=0.0,
    mmd_weight=30.0,
    mmd_use_l2norm=True,
    mmd_ramp=True,
    adv_warmup_epochs=50,
    adv_ramp_epochs=200,
    grl_alpha_max=1.0,
    disc_hidden=512,
    disc_dropout=0.1,
    adv_representation_mode='clean',
    adv_use_layernorm=False,
    adv_log_diagnostics=True,
    adv_log_grad_norms=False,
    use_local_align=True,
    local_align_bidirectional=True,
    local_align_weight=0.0,
    local_align_tau_z=0.07,
    coral_raw_weight=2.0,
    knn_weight=0.0,              # OFF — this preserves expression neighborhoods (bad for liver)
)

# Spatial InfoNCE: enforces physical neighborhoods (support-set variant).
_v1_spatial_nce = dict(
    spatial_nce_weight=3.0,
    spatial_nce_k_phys=20,
    spatial_nce_far_mult=4.0,
    spatial_nce_n_hard=20,
    spatial_nce_tau=0.1,
    spatial_nce_n_rand_neg=128,
    spatial_nce_n_anchors=64,     # anchors per step for support-set NCE
)

# Two-phase schedule: Phase 1 = VICReg + NCE only; Phase 2 adds alignment.
_v1_schedule = dict(
    two_phase_training=True,
    phase1_epochs=1000,           # ~where overlap saturated in prior runs
    phase2_lr_factor=0.1,         # 10x lower LR in Phase 2 to preserve spatial geometry
)

encoder, projector, discriminator, hist = train_encoder(
    **_v1_data,
    **_v1_run,
    **_v1_vicreg,
    **_v1_alignment,
    **_v1_spatial_nce,
    **_v1_schedule,
)

print("\n✓ Training complete!")
_v1_ckpt = f'{outdir}/encoder_final_trained.pt'
torch.save(encoder.state_dict(), _v1_ckpt)
print(f"✓ Saved to: {_v1_ckpt}")

In [None]:
_rule = "=" * 70
print(f"\n{_rule}")
print("COMPUTING EMBEDDINGS")
print(_rule)

# Embed each slide in a single forward pass (eval mode, no autograd) and keep
# the embeddings on the CPU. ST slides first, then the inference slide.
encoder.eval()
Z_st, Z_inf = {}, {}
with torch.no_grad():
    for _dest, _src in ((Z_st, X_st), (Z_inf, X_inf)):
        for _key, _mat in _src.items():
            _dest[_key] = encoder(torch.tensor(_mat, dtype=torch.float32, device=device)).cpu()

for _store in (Z_st, Z_inf):
    for _key, _emb in _store.items():
        print(f"  {_key}: {_emb.shape}")

In [None]:
# Rebuild the v1 encoder architecture and restore its trained weights.
_v1_weights = '/home/ehtesamul/sc_st/model/gems_liver_crossslide/encoder_final_trained.pt'
encoder = SharedEncoder(n_genes=n_genes, n_embedding=[512, 256, 128], dropout=0.1)
encoder.load_state_dict(torch.load(_v1_weights, map_location=device))
encoder.to(device).eval()

# Embed each slide in a single forward pass and keep the results on the CPU.
Z_st, Z_inf = {}, {}
with torch.no_grad():
    for _dest, _src in ((Z_st, X_st), (Z_inf, X_inf)):
        for _key, _mat in _src.items():
            _dest[_key] = encoder(torch.tensor(_mat, dtype=torch.float32, device=device)).cpu()

for _store in (Z_st, Z_inf):
    for _key, _emb in _store.items():
        print(f"  {_key}: {_emb.shape}")

In [None]:
# === Analysis A: Embedding kNN vs Physical kNN ===
# For each ST slide, compare how far apart (in physical space) a spot's
# embedding-space neighbors are with its true physical neighbors and with a
# random baseline.
from scipy.spatial import cKDTree


def _median_neighbor_dist(coords, nbr_idx):
    """Per-row median physical distance from each spot to its neighbor set.

    coords: (n, d) array of positions; nbr_idx: (n, k) integer array whose
    row i lists the neighbor rows of spot i. Returns an (n,) array.
    """
    offsets = coords[nbr_idx] - coords[:, None, :]
    return np.median(np.linalg.norm(offsets, axis=2), axis=1)


def _random_median_dist(coords, k, rng):
    """Median distance to k spots drawn uniformly without replacement, per row.

    One sample is drawn per row, in row order, so `rng` is consumed exactly
    as a row-by-row loop would. A sample may include the spot itself.
    """
    n_spots = coords.shape[0]
    draws = np.stack([rng.choice(n_spots, k, replace=False) for _ in range(n_spots)])
    return _median_neighbor_dist(coords, draws)


print("=" * 70)
print("ANALYSIS A: Embedding kNN vs Physical Neighborhoods (ST only)")
print("=" * 70)

for name in ST_PATHS:
    coords = st_data[name].obsm['spatial']
    Z = Z_st[name].numpy()
    n = Z.shape[0]

    tree_phys = cKDTree(coords)
    tree_emb = cKDTree(Z)

    rng = np.random.default_rng(SEED)
    knn_dists = {}  # k -> (embedding-kNN dists, physical-kNN dists)

    for k in [10, 20, 50]:
        # Query k+1 and drop column 0, assumed to be the spot itself
        # (NOTE(review): not guaranteed when coordinates/embeddings repeat).
        _, emb_idx = tree_emb.query(Z, k=k + 1)
        emb_phys_dists = _median_neighbor_dist(coords, emb_idx[:, 1:])

        _, phys_idx = tree_phys.query(coords, k=k + 1)
        phys_phys_dists = _median_neighbor_dist(coords, phys_idx[:, 1:])

        # random baseline
        rand_dists = _random_median_dist(coords, k, rng)

        knn_dists[k] = (emb_phys_dists, phys_phys_dists)

        print(f"\n  {name} | k={k}")
        print(f"    Physical kNN median dist:  {np.median(phys_phys_dists):.2f}")
        print(f"    Embedding kNN median dist: {np.median(emb_phys_dists):.2f}")
        print(f"    Random median dist:        {np.median(rand_dists):.2f}")

    # Histogram for k=20: reuse the kNN distances computed above; draw a fresh
    # random baseline from the continuing rng stream.
    emb_d, phys_d = knn_dists[20]
    rand_d = _random_median_dist(coords, 20, rng)

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(phys_d, bins=50, alpha=0.5, label='Physical kNN', density=True)
    ax.hist(emb_d, bins=50, alpha=0.5, label='Embedding kNN', density=True)
    ax.hist(rand_d, bins=50, alpha=0.5, label='Random', density=True)
    ax.set_xlabel('Median physical distance to k=20 neighbors')
    ax.set_ylabel('Density')
    ax.set_title(f'{name}: kNN Physical Distance Distributions')
    ax.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# === Analysis B: Moran's I on Embedding Dimensions (no libpysal) ===
from scipy.spatial import cKDTree
from scipy.stats import norm
from statsmodels.stats.multitest import multipletests

print("=" * 70)
print("ANALYSIS B: Spatial Autocorrelation of Embeddings (Moran's I)")
print("=" * 70)

def morans_i_fast(x, idx, n):
    """Moran's I with precomputed kNN indices and a normal-approximation p-value.

    Weights are binary and unstandardized: w_ij = 1 iff j appears in idx[i].
    kNN graphs are generally asymmetric; the variance accounts for that via
    the S1 and S2 terms.

    Args:
        x: 1-D array of values, one per observation.
        idx: (n, k) integer array; row i holds the k neighbour indices of
            observation i, with self already removed.
        n: number of observations (len(x)).

    Returns:
        (I, p): Moran's I and the two-sided p-value of its z-score under the
        normality assumption. (0.0, 1.0) if x is (numerically) constant.
    """
    xm = x - x.mean()
    denom = np.sum(xm ** 2)
    if denom < 1e-12:
        return 0.0, 1.0
    k = idx.shape[1]
    S0 = float(n * k)  # total weight: n rows with k ones each
    # sum_i xm_i * sum_{j in N(i)} xm_j
    num = np.einsum('i,ij->', xm, xm[idx])
    I = (n / S0) * (num / denom)
    EI = -1.0 / (n - 1)

    # Var(I) under normality (Cliff & Ord):
    #   Var = (n^2 S1 - n S2 + 3 S0^2) / ((n^2 - 1) S0^2) - E[I]^2
    # With binary weights (w_ij + w_ji)^2 = w_ij + w_ji + 2 w_ij w_ji, so
    #   S1 = 0.5 * sum_ij (w_ij + w_ji)^2 = S0 + #(ordered pairs with w_ij = w_ji = 1)
    #   S2 = sum_i (row_sum_i + col_sum_i)^2 = sum_i (k + in_degree_i)^2
    # A placeholder of Var = 1/(n-1), further divided by n, gives sd ~ 1/n,
    # far below the true sd, and makes almost every dimension "significant".
    rows = np.arange(n)
    n_mutual = np.count_nonzero((idx[idx] == rows[:, None, None]).any(axis=2))
    S1 = S0 + n_mutual
    in_degree = np.bincount(idx.ravel(), minlength=n)
    S2 = float(np.sum((k + in_degree).astype(np.float64) ** 2))
    n2 = float(n) * n
    VI = (n2 * S1 - n * S2 + 3.0 * S0 ** 2) / ((n2 - 1.0) * S0 ** 2) - EI ** 2
    z = (I - EI) / np.sqrt(max(VI, 1e-12))
    p = 2 * norm.sf(abs(z))
    return I, p

# Per ST slide: Moran's I of every embedding dimension and of the per-spot
# embedding norm, on a binary 10-NN graph built from physical coordinates.
for name in ST_PATHS:
    coords = st_data[name].obsm['spatial']
    Z = Z_st[name].numpy()  # assumes Z_st[name] is a CPU torch tensor — TODO confirm
    n, d = Z.shape

    tree = cKDTree(coords)
    _, idx = tree.query(coords, k=11)
    idx = idx[:, 1:]  # exclude self

    # One Moran's I / p-value per embedding dimension
    morans = np.zeros(d)
    pvals = np.zeros(d)
    for j in range(d):
        morans[j], pvals[j] = morans_i_fast(Z[:, j], idx, n)

    # Benjamini-Hochberg FDR correction across the d per-dimension tests;
    # significance depends on the p-values produced by morans_i_fast.
    _, pvals_fdr, _, _ = multipletests(pvals, method='fdr_bh')
    n_sig = np.sum(pvals_fdr < 0.05)

    # Embedding norm
    Z_norm = np.linalg.norm(Z, axis=1)
    I_norm, p_norm = morans_i_fast(Z_norm, idx, n)

    print(f"\n  {name}:")
    print(f"    Embedding dims: {d}")
    print(f"    Spatially autocorrelated dims (FDR<0.05): {n_sig}/{d} ({100*n_sig/d:.1f}%)")
    print(f"    Moran's I — mean: {morans.mean():.4f}, median: {np.median(morans):.4f}")
    print(f"    Moran's I of embedding norm: {I_norm:.4f} (p={p_norm:.4e})")

    # Histogram of per-dimension Moran's I; the red line marks I = 0
    fig, ax = plt.subplots(figsize=(6, 3))
    ax.hist(morans, bins=30, edgecolor='k', alpha=0.7)
    ax.axvline(0, color='red', linestyle='--', label='No autocorrelation')
    ax.set_xlabel("Moran's I")
    ax.set_ylabel('Count')
    ax.set_title(f"{name}: Moran's I across {d} embedding dims")
    ax.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# === Analysis E: Distribution Checks ===
print("=" * 70)
print("ANALYSIS E: Distribution Checks (ST vs Inference)")
print("=" * 70)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Per-source summary statistics. Rows are treated as spots/cells and columns
# as genes (axis=1 reductions are per row, axis=0 per column).
# If a name exists in both X_st and X_inf, the X_inf entry wins in the merge.
all_stats = {}
for name, X in {**X_st, **X_inf}.items():
    lib_size = X.sum(axis=1)  # total expression per row
    genes_detected = (X > 0).sum(axis=1)  # non-zero columns per row
    sparsity = (X == 0).mean() * 100  # % zero entries over the whole matrix
    all_stats[name] = {
        'lib_size_median': np.median(lib_size),
        'lib_size_std': np.std(lib_size),
        'genes_detected_median': np.median(genes_detected),
        'sparsity_pct': sparsity,
    }
    print(f"  {name}:")
    print(f"    Library size — median: {np.median(lib_size):.1f}, std: {np.std(lib_size):.1f}")
    print(f"    Genes detected — median: {np.median(genes_detected):.0f}")
    print(f"    Sparsity: {sparsity:.1f}%")
    print()

# Library size distributions
for name, X in {**X_st, **X_inf}.items():
    axes[0].hist(X.sum(axis=1), bins=50, alpha=0.5, label=name, density=True)
axes[0].set_xlabel('Library size (post-norm)')
axes[0].set_title('Library Size Distribution')
axes[0].legend(fontsize=7)

# Genes detected
for name, X in {**X_st, **X_inf}.items():
    axes[1].hist((X > 0).sum(axis=1), bins=50, alpha=0.5, label=name, density=True)
axes[1].set_xlabel('Genes detected per spot')
axes[1].set_title('Detected Genes Distribution')
axes[1].legend(fontsize=7)

# Per-gene mean expression
for name, X in {**X_st, **X_inf}.items():
    gene_means = X.mean(axis=0)  # one mean per column (gene)
    axes[2].hist(gene_means, bins=50, alpha=0.5, label=name, density=True)
axes[2].set_xlabel('Per-gene mean expression')
axes[2].set_title('Gene Mean Distribution')
axes[2].legend(fontsize=7)

plt.tight_layout()
plt.show()

In [None]:
# === Analysis C: Spatially Variable Genes ===
from scipy.spatial import cKDTree
import pandas as pd

print("=" * 70)
print("ANALYSIS C: Spatially Variable Gene Analysis")
print("=" * 70)

for name in list(ST_PATHS.keys())[:1]:
    # Subset to the shared gene panel first, so only those genes are copied.
    adata = st_data[name][:, common_genes].copy()
    coords = adata.obsm['spatial']
    X_dense = adata.X.toarray() if hasattr(adata.X, 'toarray') else np.asarray(adata.X)
    n = X_dense.shape[0]

    tree = cKDTree(coords)
    _, idx = tree.query(coords, k=11)
    idx = idx[:, 1:]  # exclude self
    # Total weight of the binary kNN graph. It is derived from idx so it stays
    # in sync with the k passed to tree.query; it used to be hard-coded n * 10.
    W = n * idx.shape[1]

    # Moran's I for all genes at once (binary kNN weights, statistic only).
    Xm = X_dense - X_dense.mean(axis=0)
    denom = np.sum(Xm ** 2, axis=0)
    # Spatial lag: sum of each spot's neighbours' centred values. It is built
    # one neighbour column at a time to avoid an (n, k, n_genes) temporary.
    lag = np.zeros_like(Xm)
    for neighbour_col in idx.T:
        lag += Xm[neighbour_col]
    num = np.sum(Xm * lag, axis=0)
    # (Near-)constant genes get I = 0.
    morans_vals = np.zeros(X_dense.shape[1])
    varying = denom >= 1e-12
    morans_vals[varying] = (n / W) * (num[varying] / denom[varying])

    svg_results = pd.DataFrame({'I': morans_vals}, index=common_genes)
    svg_results = svg_results.sort_values('I', ascending=False)

    n_sig_svg = (svg_results['I'] > 0.1).sum()
    top300_mean_I = svg_results.head(300)['I'].mean()

    detection_rate = (X_dense > 0).mean(axis=0)
    n_ultra_sparse = (detection_rate < 0.01).sum()

    print(f"\n  {name} ({len(common_genes)} common genes):")
    print(f"    Genes with Moran's I > 0.1: {n_sig_svg}")
    print(f"    Top 300 SVGs — mean Moran's I: {top300_mean_I:.4f}")
    print(f"    Ultra-sparse genes (<1% detection): {n_ultra_sparse}/{len(common_genes)}")
    print(f"\n  Top 20 SVGs:")
    print(svg_results.head(20)[['I']].to_string())

In [None]:
# === Ablation 2: Embedding kNN under different transforms ===
from scipy.spatial import cKDTree
from sklearn.decomposition import PCA

print("=" * 70)
print("ABLATION 2: Embedding kNN with different similarity transforms")
print("=" * 70)

K_VALUES = [10, 20, 50]

def compute_knn_phys_dist(Z_np, coords, k):
    """Per-point median physical distance to its k nearest embedding neighbours.

    Neighbours are found by Euclidean kNN in ``Z_np``; distances are then
    measured in ``coords``. Returns one median per row.
    """
    # Ask for k+1 hits because each point's first hit is itself; drop column 0.
    neighbours = cKDTree(Z_np).query(Z_np, k=k + 1)[1][:, 1:]
    offsets = coords[neighbours] - coords[:, None, :]
    return np.median(np.sqrt(np.sum(offsets ** 2, axis=2)), axis=1)

for name in list(ST_PATHS.keys())[:1]:  # run on first ST slide
    coords = st_data[name].obsm['spatial']
    Z = Z_st[name].numpy()
    n = Z.shape[0]

    # Baselines
    tree_phys = cKDTree(coords)

    # Candidate representations. kNN is Euclidean in every space, so the
    # L2-normalized variant ranks neighbours exactly as cosine similarity would
    # (for unit vectors ||a - b||^2 = 2 - 2 cos(a, b)).
    transforms = {
        'Raw embedding': Z,
        'L2-normalized (cosine)': Z / (np.linalg.norm(Z, axis=1, keepdims=True) + 1e-8),
        'Per-dim z-scored': (Z - Z.mean(0)) / (Z.std(0) + 1e-8),
        'PCA(Z) → 32d': PCA(n_components=32, random_state=42).fit_transform(Z),
        'PCA(Z) → 64d': PCA(n_components=64, random_state=42).fit_transform(Z),
    }

    for k in K_VALUES:
        print(f"\n  {name} | k={k}")
        print(f"  {'Transform':<30s} {'Median phys dist':>18s}")
        print(f"  {'-'*30} {'-'*18}")

        # Physical baseline: lower bound, neighbours chosen by physical distance
        _, phys_idx = tree_phys.query(coords, k=k + 1)
        phys_idx = phys_idx[:, 1:]
        phys_d = np.median(np.sqrt(((coords[phys_idx] - coords[:, None, :]) ** 2).sum(2)), axis=1)
        print(f"  {'Physical kNN (best)':<30s} {np.median(phys_d):>18.2f}")

        for tname, Z_t in transforms.items():
            emb_d = compute_knn_phys_dist(Z_t, coords, k)
            print(f"  {tname:<30s} {np.median(emb_d):>18.2f}")

        # Random baseline: k spots drawn uniformly per point. The RNG is
        # re-seeded per k for reproducibility; a draw may include i itself.
        rng = np.random.default_rng(42)
        rand_d = np.array([
            np.median(np.linalg.norm(coords[rng.choice(n, k, replace=False)] - coords[i], axis=1))
            for i in range(n)
        ])
        print(f"  {'Random (worst)':<30s} {np.median(rand_d):>18.2f}")

In [None]:
# === Ablation 3: Identifiability — crop a single spatial region ===
from scipy.spatial import cKDTree

print("=" * 70)
print("ABLATION 3: Embedding kNN within a spatial crop (single region)")
print("=" * 70)

# Fraction of spots kept in the central crop. The crop radius is the
# CROP_FRAC-quantile of distance-to-centre, so the crop contains ~CROP_FRAC of
# the spots by construction. A quantile already counts spots, so no
# area-to-radius conversion applies. A value of 0.50 keeps ~50% of spots,
# not the intended ~25%.
CROP_FRAC = 0.25

for name in list(ST_PATHS.keys())[:1]:
    coords = st_data[name].obsm['spatial']
    Z = Z_st[name].numpy()
    n = Z.shape[0]

    # Crop: the CROP_FRAC of spots closest to the coordinate-wise median centre
    cx, cy = np.median(coords, axis=0)
    dists_to_center = np.sqrt((coords[:, 0] - cx)**2 + (coords[:, 1] - cy)**2)
    radius = np.quantile(dists_to_center, CROP_FRAC)
    mask = dists_to_center <= radius
    n_crop = mask.sum()

    coords_crop = coords[mask]
    Z_crop = Z[mask]

    print(f"  {name}: cropped {n_crop}/{n} spots (radius={radius:.1f})")

    tree_phys = cKDTree(coords_crop)
    tree_emb = cKDTree(Z_crop)

    for k in [10, 20]:
        if k >= n_crop:
            continue

        # Physical kNN inside the crop (first hit is the point itself; dropped)
        _, phys_idx = tree_phys.query(coords_crop, k=k + 1)
        phys_idx = phys_idx[:, 1:]
        phys_d = np.median(np.sqrt(((coords_crop[phys_idx] - coords_crop[:, None, :]) ** 2).sum(2)), axis=1)

        # Embedding kNN inside the crop, scored by physical distance
        _, emb_idx = tree_emb.query(Z_crop, k=k + 1)
        emb_idx = emb_idx[:, 1:]
        emb_d = np.median(np.sqrt(((coords_crop[emb_idx] - coords_crop[:, None, :]) ** 2).sum(2)), axis=1)

        # Random baseline within the crop (re-seeded per k)
        rng = np.random.default_rng(42)
        rand_d = np.array([
            np.median(np.linalg.norm(coords_crop[rng.choice(n_crop, k, replace=False)] - coords_crop[i], axis=1))
            for i in range(n_crop)
        ])

        print(f"\n  Cropped region | k={k}")
        print(f"    Physical kNN median dist:  {np.median(phys_d):.2f}")
        print(f"    Embedding kNN median dist: {np.median(emb_d):.2f}")
        print(f"    Random median dist:        {np.median(rand_d):.2f}")
        # 0 when embedding neighbours are as close as physical ones, 1 when random
        ratio = (np.median(emb_d) - np.median(phys_d)) / (np.median(rand_d) - np.median(phys_d) + 1e-8)
        print(f"    Emb→Phys ratio (0=perfect, 1=random): {ratio:.3f}")

In [None]:
# === Ablation 1: Physical vs Embedding patch construction (ST4) ===
from scipy.spatial import cKDTree

print("=" * 70)
print("ABLATION 1: Physical vs Embedding patch graph quality (ST4)")
print("=" * 70)

# Evaluated on the inference slide (coordinates from inf_data, embeddings from Z_inf)
name = 'liver_ST4'
coords = inf_data[name].obsm['spatial']
Z = Z_inf[name].numpy()
n = Z.shape[0]

K_PATCH = 20  # typical miniset neighborhood size

# Both graphs query K_PATCH + 1 neighbours and drop column 0, which is assumed
# to be the query point itself (true unless there are duplicate points).

# === Mode P: Physical kNN graph ===
tree_phys = cKDTree(coords)
_, phys_idx = tree_phys.query(coords, k=K_PATCH + 1)
phys_idx = phys_idx[:, 1:]

# === Mode Z: Embedding kNN graph ===
tree_emb = cKDTree(Z)
_, emb_idx = tree_emb.query(Z, k=K_PATCH + 1)
emb_idx = emb_idx[:, 1:]

# --- Metric 1: Physical compactness of patches ---
def patch_compactness(coords, idx):
    """Per-patch radius: max distance from each patch's centroid to its members.

    This is a radius, not a diameter. The reference point is the patch
    centroid, not the farthest other member, and one value is returned per
    patch (no averaging).

    Args:
        coords: (n, d) array of physical coordinates.
        idx: (m, k) integer array; row i lists the member indices of patch i.

    Returns:
        (m,) float array of patch radii (empty array if there are no patches).
    """
    idx = np.asarray(idx)
    if len(idx) == 0:
        return np.array([])
    patches = coords[idx]                              # (m, k, d)
    centers = patches.mean(axis=1, keepdims=True)      # (m, 1, d)
    return np.sqrt(((patches - centers) ** 2).sum(axis=2)).max(axis=1)

# Metric 1: per-patch radius (max member distance to the patch centroid) for both graphs
phys_compact = patch_compactness(coords, phys_idx)
emb_compact = patch_compactness(coords, emb_idx)

print(f"\n  Patch compactness (physical radius of k={K_PATCH} neighborhood):")
print(f"    Physical patches — median: {np.median(phys_compact):.2f}, mean: {np.mean(phys_compact):.2f}")
print(f"    Embedding patches — median: {np.median(emb_compact):.2f}, mean: {np.mean(emb_compact):.2f}")
print(f"    Ratio (emb/phys): {np.median(emb_compact) / np.median(phys_compact):.2f}x")

# --- Metric 2: Graph connectivity overlap ---
# What fraction of embedding neighbors are also physical neighbors?
# Per spot: |phys_neighbours ∩ emb_neighbours| / K_PATCH, in [0, 1].
overlap_frac = np.array([
    len(set(phys_idx[i]) & set(emb_idx[i])) / K_PATCH
    for i in range(n)
])

print(f"\n  Neighbor overlap (phys ∩ emb / k):")
print(f"    Mean overlap: {overlap_frac.mean():.4f}")
print(f"    Median overlap: {np.median(overlap_frac):.4f}")
print(f"    % spots with zero overlap: {(overlap_frac == 0).mean() * 100:.1f}%")

# --- Metric 3: Pairwise distance preservation ---
# For each patch, compute pairwise distances in coord space
# Compare consistency between the two modes
def pairwise_dist_stats(coords, idx):
    """Mean pairwise physical distance among the members of each patch.

    Args:
        coords: (n, d) array of physical coordinates.
        idx: sequence of patches; each entry lists member indices into coords.

    Returns:
        1-D array with one mean (over unordered member pairs) per patch.
    """
    def _mean_pairwise(points):
        # Full (k, k) Euclidean distance matrix; average its strict upper triangle.
        delta = points[:, None] - points[None, :]
        dist = np.sqrt((delta ** 2).sum(axis=2))
        return dist[np.triu_indices(len(points), k=1)].mean()

    return np.array([_mean_pairwise(coords[members]) for members in idx])

phys_pw = pairwise_dist_stats(coords, phys_idx)
emb_pw = pairwise_dist_stats(coords, emb_idx)

# The printed value is the median, across patches, of each patch's mean
# pairwise distance (the header text says "Mean").
print(f"\n  Mean pairwise physical distance within patches:")
print(f"    Physical patches: {np.median(phys_pw):.2f}")
print(f"    Embedding patches: {np.median(emb_pw):.2f}")

# --- Visualization ---
# Panels: (0) patch radius, (1) neighbour overlap, (2) within-patch pairwise distance
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].hist(phys_compact, bins=50, alpha=0.6, label='Physical', density=True)
axes[0].hist(emb_compact, bins=50, alpha=0.6, label='Embedding', density=True)
axes[0].set_xlabel('Patch radius')
axes[0].set_title('Patch Compactness')
axes[0].legend()

axes[1].hist(overlap_frac, bins=50, edgecolor='k', alpha=0.7)
axes[1].set_xlabel('Neighbor overlap fraction')
axes[1].set_title(f'Physical ∩ Embedding (k={K_PATCH})')

axes[2].hist(phys_pw, bins=50, alpha=0.6, label='Physical', density=True)
axes[2].hist(emb_pw, bins=50, alpha=0.6, label='Embedding', density=True)
axes[2].set_xlabel('Mean pairwise distance')
axes[2].set_title('Within-Patch Distances')
axes[2].legend()

plt.suptitle(f'{name}: Physical vs Embedding Patch Construction', fontweight='bold')
plt.tight_layout()
plt.show()

# --- Summary verdict ---
# Heuristic thresholds: >3x larger patches with <10% median overlap counts as
# "not local". Any other ratio >1.5 (including >3x with higher overlap) is
# reported as a moderate mismatch.
ratio = np.median(emb_compact) / np.median(phys_compact)
print(f"\n  VERDICT:")
if ratio > 3.0 and np.median(overlap_frac) < 0.1:
    print(f"    Embedding patches are {ratio:.1f}x larger with {np.median(overlap_frac)*100:.1f}% overlap.")
    print(f"    → Embedding kNN is NOT spatially local. Physical patches should be used for training.")
elif ratio > 1.5:
    print(f"    Embedding patches are {ratio:.1f}x larger — moderate mismatch.")
    print(f"    → Consider hybrid approach or tighter embedding constraints.")
else:
    print(f"    Embedding patches are comparable ({ratio:.1f}x). kNN graph is reasonable.")

In [None]:
print("\n" + "=" * 70)
print("EVALUATION: Domain Mixing Analysis")
print("=" * 70)

N_MAX = 5000  # max rows kept per source after random subsampling
K = 20        # neighbours used by the kNN domain-mixing metric
set_seed(SEED)  # project helper; presumably seeds the torch RNG used by subsample() — TODO confirm

def subsample(X, n_max):
    """Return at most ``n_max`` randomly chosen rows of ``X`` as a torch tensor.

    Non-tensor input is converted to float32. If ``X`` already has
    ``n_max`` rows or fewer, it is returned unchanged (same object for tensors).
    Row selection uses ``torch.randperm`` and therefore the global torch RNG.
    """
    tensor = X if isinstance(X, torch.Tensor) else torch.tensor(X, dtype=torch.float32)
    n_rows = tensor.shape[0]
    if n_rows > n_max:
        keep = torch.randperm(n_rows)[:n_max]
        return tensor[keep]
    return tensor

Z_st_sub = {n: subsample(Z, N_MAX) for n, Z in Z_st.items()}
Z_inf_sub = {n: subsample(Z, N_MAX) for n, Z in Z_inf.items()}

Z_all = torch.cat(list(Z_st_sub.values()) + list(Z_inf_sub.values()), dim=0)
Z_all_norm = F.normalize(Z_all, dim=1)
N = Z_all.shape[0]

# Domain labels: one integer id per source, in the same order the rows were
# concatenated into Z_all (training ST sources first, then inference sources).
domain_labels = []
source_names = []
for i, name in enumerate(list(Z_st_sub) + list(Z_inf_sub)):
    n = (Z_st_sub if name in Z_st_sub else Z_inf_sub)[name].shape[0]
    domain_labels.append(torch.full((n,), i, dtype=torch.long))
    source_names.append(name)
domain_labels = torch.cat(domain_labels)
S = len(source_names)

print(f"Sources: {source_names}, Total: {N}, Domains: {S}")

# --- kNN Domain-Mixing ---
print(f"\n[METRIC 1] kNN Domain-Mixing (k={K})")
# Row-chunked kNN over the L2-normalized embeddings. Materializing the full
# N x N distance matrix costs O(N^2) memory (e.g. ~1.6 GB of float32 at
# N = 20,000) and left it alive as a notebook global. Chunking bounds the
# peak at KNN_CHUNK x N while applying the same neighbour rule.
KNN_CHUNK = 2048
knn_idx = torch.empty(N, K, dtype=torch.long)
for start in range(0, N, KNN_CHUNK):
    stop = min(start + KNN_CHUNK, N)
    D_chunk = torch.cdist(Z_all_norm[start:stop], Z_all_norm)
    local_rows = torch.arange(stop - start, device=D_chunk.device)
    D_chunk[local_rows, local_rows + start] = float('inf')  # never pick self
    knn_idx[start:stop] = torch.topk(D_chunk, k=K, dim=1, largest=False).indices.cpu()
knn_labels = domain_labels[knn_idx]

# p[i, s]: fraction of point i's K neighbours that come from source s
p = torch.zeros(N, S)
for s in range(S):
    p[:, s] = (knn_labels == s).float().mean(dim=1)

eps = 1e-10
# Entropy / log(S) and inverse Simpson rescaled by (x - 1) / (S - 1): both are
# 1 for perfectly uniform neighbourhoods and 0 when all neighbours share a source.
H_norm = (-torch.sum(p * torch.log(p + eps), dim=1) / np.log(S)).mean().item()
iLISI_norm = ((1.0 / torch.sum(p ** 2, dim=1) - 1) / (S - 1)).mean().item()

print(f"  Normalized Neighbor Entropy: {H_norm:.4f}")
print(f"  Normalized iLISI: {iLISI_norm:.4f}")

# --- Domain Classification (CV) ---
print(f"\n[METRIC 2] Domain Classification (5-fold CV)")
Z_np, y_np = Z_all_norm.numpy(), domain_labels.numpy()
clf = LogisticRegression(max_iter=2000, random_state=42, class_weight='balanced', n_jobs=-1)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
cv_scores = cross_val_score(clf, Z_np, y_np, cv=cv, scoring='balanced_accuracy')
chance = 1.0 / S
print(f"  Balanced Acc: {cv_scores.mean():.4f} ± {cv_scores.std():.4f} (chance={chance:.4f})")

# --- Binary: Training ST vs Inference ST4 ---
print(f"\n[METRIC 3] Binary ST(1-3) vs ST4 (5-fold CV)")
n_st_total = sum(Z_st_sub[n].shape[0] for n in Z_st_sub)
y_binary = torch.cat([torch.zeros(n_st_total, dtype=torch.long),
                       torch.ones(N - n_st_total, dtype=torch.long)]).numpy()
cv_binary = cross_val_score(clf, Z_np, y_binary, cv=cv, scoring='balanced_accuracy')
print(f"  Balanced Acc (ST vs ST4): {cv_binary.mean():.4f} ± {cv_binary.std():.4f}")

# --- Centroid Analysis ---
print(f"\n[CENTROID ANALYSIS]")
centroids = {}
for name in Z_st_sub:
    centroids[name] = Z_st_sub[name].mean(dim=0)
for name in Z_inf_sub:
    centroids[name] = Z_inf_sub[name].mean(dim=0)

names = list(centroids)
print("         ", "  ".join([f"{n:>12}" for n in names]))
for n1 in names:
    row = f"{n1:12s}"
    for n2 in names:
        row += f"  {(centroids[n1] - centroids[n2]).norm().item():12.4f}"
    print(row)

# --- Summary ---
print("\n" + "=" * 70)
print("SUMMARY")
print("=" * 70)
print(f"{'Metric':<40} {'Value':<15}")
print("-" * 55)
print(f"{'Neighbor Entropy (k=20)':<40} {H_norm:<15.4f}")
print(f"{'iLISI (k=20)':<40} {iLISI_norm:<15.4f}")
print(f"{'Domain Class. Acc (CV)':<40} {cv_scores.mean():<15.4f}")
print(f"{'ST vs ST4 Acc (CV)':<40} {cv_binary.mean():<15.4f}")
print("=" * 70)

In [None]:
print("\n" + "=" * 70)
print("PCA COMPARISON: Raw vs Embeddings")
print("=" * 70)

N_VIS = 3000  # max rows per source shown in the PCA plots
set_seed(SEED)  # project helper; presumably seeds the torch RNG used for subsampling — TODO confirm

# Subsample matched
def subsample_matched(X, Z, n_max):
    """Draw the same random rows (at most ``n_max``) from paired X and Z.

    Both inputs are converted to float32 CPU tensors. One permutation is
    applied to both, so row i of each output still refers to the same
    observation. The permutation uses ``torch.randperm`` (global torch RNG).

    Returns:
        (X_sub, Z_sub) as numpy arrays.

    Raises:
        ValueError: if X and Z have different numbers of rows. Otherwise the
            shared index would silently mis-pair rows, or run out of bounds.
    """
    X_t = torch.tensor(X, dtype=torch.float32) if not isinstance(X, torch.Tensor) else X.float()
    Z_t = Z.float() if isinstance(Z, torch.Tensor) else torch.tensor(Z, dtype=torch.float32)
    X_t, Z_t = X_t.cpu(), Z_t.cpu()
    if X_t.shape[0] != Z_t.shape[0]:
        raise ValueError(
            f"subsample_matched: X has {X_t.shape[0]} rows but Z has {Z_t.shape[0]}"
        )
    n = min(X_t.shape[0], n_max)
    idx = torch.randperm(X_t.shape[0])[:n]
    return X_t[idx].numpy(), Z_t[idx].numpy()

# Collect a matched subsample per source (same rows in expression X and
# embedding Z), iterating sources in labels_str order (defined elsewhere).
X_vis, Z_vis, labels_vis = [], [], []
for name in labels_str:
    src = X_st if name in X_st else X_inf
    src_z = Z_st if name in Z_st else Z_inf
    x_sub, z_sub = subsample_matched(src[name], src_z[name], N_VIS)
    X_vis.append(x_sub)
    Z_vis.append(z_sub)
    labels_vis.extend([name] * x_sub.shape[0])

X_vis = np.vstack(X_vis)
Z_vis = np.vstack(Z_vis)
labels_vis = np.array(labels_vis)

# Separate 2-component PCAs, each fit on all sources pooled, so between-source
# shifts show up along the leading PCs.
pca_x = PCA(n_components=2).fit(X_vis)
X_pca = pca_x.transform(X_vis)
var_x = pca_x.explained_variance_ratio_

pca_z = PCA(n_components=2).fit(Z_vis)
Z_pca = pca_z.transform(Z_vis)
var_z = pca_z.explained_variance_ratio_

# plt.rcParams.update({'font.family': 'Arial', 'font.weight': 'bold',
#                      'axes.labelweight': 'bold', 'axes.titleweight': 'bold'})

# NOTE(review): the "(log1p)" in the title assumes X_st / X_inf hold
# log1p-transformed expression; not verifiable from this cell.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
titles = [
    f"(a) PCA of raw expression (log1p)\nVar: PC1={var_x[0]:.1%}, PC2={var_x[1]:.1%}",
    f"(b) PCA of encoder embeddings\nVar: PC1={var_z[0]:.1%}, PC2={var_z[1]:.1%}",
]
for ax, data, title in zip(axes, [X_pca, Z_pca], titles):
    for label in labels_str:
        mask = labels_vis == label
        ax.scatter(data[mask, 0], data[mask, 1], c=colors_map[label], label=label,
                   s=25, alpha=0.65, edgecolors='white', linewidths=0.5)
    ax.set_xlabel('PC1', fontsize=13, fontweight='bold')
    ax.set_ylabel('PC2', fontsize=13, fontweight='bold')
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.set_facecolor('#FAFAFA')
    ax.tick_params(labelsize=11)
    for spine in ax.spines.values():
        spine.set_color('#CCCCCC')
        spine.set_linewidth(1.5)

# A single legend on the left panel; both panels share labels_str / colors_map
axes[0].legend(loc='best', fontsize=11, framealpha=0.98, edgecolor='#888888')
plt.tight_layout()
# plt.savefig('liver_pca_comparison.svg', format='svg', bbox_inches='tight', dpi=600)
plt.show()

In [None]:
# print("\n" + "=" * 70)
# print("t-SNE COMPARISON: Raw vs Embeddings")
# print("=" * 70)

# N_TSNE = 2000
# set_seed(SEED)

# X_tsne_all, Z_tsne_all, labels_tsne = [], [], []
# for name in labels_str:
#     src = X_st if name in X_st else X_inf
#     src_z = Z_st if name in Z_st else Z_inf
#     x_sub, z_sub = subsample_matched(src[name], src_z[name], N_TSNE)
#     X_tsne_all.append(x_sub)
#     Z_tsne_all.append(z_sub)
#     labels_tsne.extend([name] * x_sub.shape[0])

# X_tsne_all = np.vstack(X_tsne_all)
# Z_tsne_all = np.vstack(Z_tsne_all)
# labels_tsne = np.array(labels_tsne)
# Z_tsne_norm = F.normalize(torch.tensor(Z_tsne_all), dim=1).numpy()

# print("Computing t-SNE for raw expression...")
# X_tsne_proj = TSNE(n_components=2, random_state=42, perplexity=30, n_iter=1000).fit_transform(X_tsne_all)

# print("Computing t-SNE for embeddings...")
# Z_tsne_proj = TSNE(n_components=2, random_state=42, perplexity=30, n_iter=1000).fit_transform(Z_tsne_norm)

# fig, axes = plt.subplots(1, 2, figsize=(18, 7))
# fig.patch.set_facecolor('white')

# for ax, data, title in zip(axes, [X_tsne_proj, Z_tsne_proj],
#     ['(A) t-SNE on Raw Expression (Before)', '(B) t-SNE on Embeddings (After)']):
#     for label in labels_str:
#         mask = labels_tsne == label
#         ax.scatter(data[mask, 0], data[mask, 1], c=colors_map[label],
#                    label=label, alpha=0.5, s=20, edgecolors='none')
#     ax.set_title(title, fontsize=14, fontweight='bold')
#     ax.set_xlabel('t-SNE 1', fontsize=12)
#     ax.set_ylabel('t-SNE 2', fontsize=12)
#     ax.legend(loc='best', fontsize=10, frameon=True)
#     ax.grid(alpha=0.3, linestyle='--')

# plt.suptitle('Mouse Liver: Before vs After Alignment (Cross-Slide)',
#              fontsize=16, fontweight='bold', y=0.98)
# plt.tight_layout()
# plt.show()

In [None]:
# ===================================================================
# EVALUATION METRICS
# ===================================================================
print("\n" + "=" * 70)
print("EVALUATION: Liver Cross-Slide Encoder")
print("=" * 70)

N_MAX = 5000  # max rows kept per source after random subsampling
set_seed(SEED)  # project helper; presumably seeds the torch RNG used by subsample() — TODO confirm

def subsample(X, n_max, device='cpu'):
    """Move ``X`` to ``device`` and keep at most ``n_max`` random rows.

    Non-tensor input becomes a float32 tensor on ``device``; tensor input is
    moved with ``.to(device)``, which is a no-op if it is already there. With
    ``n_max`` or fewer rows the (moved) tensor is returned as is.
    """
    X = X.to(device) if isinstance(X, torch.Tensor) else torch.tensor(X, dtype=torch.float32, device=device)
    total = X.shape[0]
    if total > n_max:
        return X[torch.randperm(total, device=device)[:n_max]]
    return X

Z_st_sub = {name: subsample(Z, N_MAX, 'cpu') for name, Z in Z_st.items()}
Z_inf_sub = {name: subsample(Z, N_MAX, 'cpu') for name, Z in Z_inf.items()}

print("Subsampled embeddings:")
for name, Z in {**Z_st_sub, **Z_inf_sub}.items():
    print(f"  {name}: {Z.shape}")

# ===================================================================
# TEST 1: ST(1-3) vs ST4 kNN mixing
# ===================================================================
print("\n[ST-vs-ST4 MIXING] kNN domain distribution:")

Z_st_all = torch.cat(list(Z_st_sub.values()), dim=0)
Z_inf_all = torch.cat(list(Z_inf_sub.values()), dim=0)
Z_all = torch.cat([Z_st_all, Z_inf_all], dim=0)

n_st = Z_st_all.shape[0]
n_inf = Z_inf_all.shape[0]
n_total = n_st + n_inf

# Domain labels (0=ST training, 1=ST4 inference)
domain_labels = torch.cat([
    torch.zeros(n_st, dtype=torch.long),
    torch.ones(n_inf, dtype=torch.long)
])

Z_all_norm = F.normalize(Z_all, dim=1)

K = 20
inf_start = n_st

# Distances from every inference point to all points. Row i's own entry sits
# at column inf_start + i; mask it so a point is never its own neighbour.
D_inf = torch.cdist(Z_all_norm[inf_start:], Z_all_norm)
self_rows = torch.arange(n_inf, device=D_inf.device)
D_inf[self_rows, inf_start + self_rows] = float('inf')

_, knn_inf = torch.topk(D_inf, k=K, dim=1, largest=False)
frac_same_inf = (domain_labels[knn_inf] == 1).float().mean().item()
base_rate_inf = n_inf / n_total

print(f"  ST4 neighbors (K={K}):")
print(f"    Same-domain (ST4) fraction: {frac_same_inf:.4f}")
print(f"    Base rate (chance):         {base_rate_inf:.4f}")

if frac_same_inf < base_rate_inf + 0.10:
    print("    ✓ EXCELLENT mixing (ST4 not clustering)")
elif frac_same_inf < base_rate_inf + 0.20:
    print("    ✓ Good mixing")
else:
    print("    ⚠️ ST4 may be clustering")

# ===================================================================
# TEST 2: Linear Probe (one class per source, held-out accuracy)
# ===================================================================
# Classes are assigned in the same order the rows were concatenated into
# Z_all (all ST sources, then all inference sources), so labels line up with
# rows regardless of how ST_PATHS / INF_PATHS happen to be ordered.
sources_in_order = list(Z_st_sub.values()) + list(Z_inf_sub.values())
n_sources = len(sources_in_order)
print(f"\n[{n_sources}-CLASS PROBE] Linear separability test:")

class_labels = torch.cat([
    torch.full((Z_src.shape[0],), class_id, dtype=torch.long)
    for class_id, Z_src in enumerate(sources_in_order)
])

Z_np = Z_all_norm.numpy()
y_np = class_labels.numpy()

# Out-of-fold predictions. Scoring the probe on its own training rows measures
# fit rather than separability and is optimistically biased. Stratified 5-fold
# CV matches the domain-classification metric of the domain-mixing cell.
probe = LogisticRegression(max_iter=5000, random_state=42, class_weight='balanced')
probe_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
pred = np.empty_like(y_np)
for train_idx, test_idx in probe_cv.split(Z_np, y_np):
    probe.fit(Z_np[train_idx], y_np[train_idx])
    pred[test_idx] = probe.predict(Z_np[test_idx])
bal_acc = balanced_accuracy_score(y_np, pred)
chance = 1.0 / n_sources

print(f"  Balanced accuracy (5-fold held-out): {bal_acc:.4f} (chance={chance:.3f})")

if bal_acc < 0.30:
    print("  ✓ EXCELLENT: Sources are very well-mixed")
elif bal_acc < 0.40:
    print("  ✓ Good: Moderate mixing")
elif bal_acc < 0.50:
    print("  ~ Partial alignment")
else:
    print("  ⚠️ Sources are separable")

# ===================================================================
# TEST 3: Centroid distances
# ===================================================================
print("\n[CENTROID ANALYSIS] Cross-source distances:")

centroids = {}
for name in ST_PATHS:
    centroids[name] = Z_st_sub[name].mean(dim=0)
for name in INF_PATHS:
    centroids[name] = Z_inf_sub[name].mean(dim=0)

names = list(centroids)
print("\nCentroid distance matrix:")
print("             ", "  ".join([f"{n:>12}" for n in names]))

for n1 in names:
    row = f"{n1:12s}"
    for n2 in names:
        dist = (centroids[n1] - centroids[n2]).norm().item()
        row += f"  {dist:12.4f}"
    print(row)

# ST-to-ST vs ST-to-ST4 distances
st_names = list(ST_PATHS)
inf_names = list(INF_PATHS)

st_to_st = []
for i, n1 in enumerate(st_names):
    for n2 in st_names[i + 1:]:
        st_to_st.append((centroids[n1] - centroids[n2]).norm().item())

st_to_inf = []
for n1 in st_names:
    for n2 in inf_names:
        st_to_inf.append((centroids[n1] - centroids[n2]).norm().item())

print(f"\nAverage distances:")
print(f"  ST-to-ST:  {np.mean(st_to_st):.4f} ± {np.std(st_to_st):.4f}")
print(f"  ST-to-ST4: {np.mean(st_to_inf):.4f} ± {np.std(st_to_inf):.4f}")
ratio = np.mean(st_to_inf) / np.mean(st_to_st)
print(f"  Ratio (ST4/ST): {ratio:.2f}x")

if ratio < 1.5:
    print("  ✓ EXCELLENT: ST4 very close to training ST centroids")
elif ratio < 2.5:
    print("  ✓ Good: Reasonable cross-slide distance")
else:
    print("  ⚠️ Large slide gap remains")

# ===================================================================
# SUMMARY
# ===================================================================
print("\n" + "=" * 70)
print("LIVER ENCODER EVALUATION SUMMARY")
print("=" * 70)
print(f"✓ ST4 mixing:        same-domain frac = {frac_same_inf:.4f} (base = {base_rate_inf:.4f})")
print(f"✓ {n_sources}-class probe:   balanced acc = {bal_acc:.4f} (chance = {chance:.3f})")
print(f"✓ Centroid ratio:    ST-to-ST4 / ST-to-ST = {ratio:.2f}x")
print("=" * 70)

In [None]:
# ===================================================================
# METRIC CONSISTENCY DIAGNOSTIC v4 — FULL CHAIN VERIFICATION
# ===================================================================
# v2 proved overlap@20 = 0.02 on saved model (random-level).
# v3 tested cosine vs euclidean + checkpoint integrity.
# v4 adds:
#   - Compare encoder_final_new.pt vs encoder_final_trained.pt
#   - Test both files against the same overlap metric
#   - Check if the model learned ANYTHING (weight diff + loss comparison)
#
# Robustness notes:
#   - Fails with a clear FileNotFoundError if no checkpoint is on disk.
#   - TEST 5 reseeds before each NCE evaluation so RANDOM and TRAINED are
#     scored on the same sampled anchors/negatives (otherwise the loss gap
#     mixes model quality with sampling noise).
#   - TEST 4 tolerates slides with <= K spots.
# ===================================================================

import datetime
import json
import os

import numpy as np
import torch
import torch.nn.functional as F
from scipy.spatial import cKDTree
from ssl_utils import (set_seed, precompute_spatial_nce_structures,
                       compute_knn_locality_metrics, compute_spatial_infonce_supportset)
from core_models_et_p1 import SharedEncoder


def _encode_in_chunks(model, expr, chunk_size=512):
    """Return ``model`` applied to the rows of ``expr`` in chunks, without autograd.

    The caller is responsible for putting ``model`` in the desired train/eval mode.
    """
    with torch.no_grad():
        parts = [model(expr[i:i + chunk_size]) for i in range(0, expr.shape[0], chunk_size)]
    return torch.cat(parts, dim=0)


def _safe_mean(values):
    """Mean of a list, or NaN (without numpy's empty-slice warning) if it is empty."""
    return float(np.mean(values)) if values else float('nan')


print("=" * 70)
print("DIAGNOSTIC v4: FULL CHECKPOINT CHAIN VERIFICATION")
print("=" * 70)

K = 20
CKPT_DIR = '/home/ehtesamul/sc_st/model/gems_liver_crossslide'
CKPT_FILES = ['encoder_final_new.pt', 'encoder_final_trained.pt']

# Precompute pos_idx (same as training uses)
spatial_nce_data = precompute_spatial_nce_structures(
    st_coords=st_coords, st_gene_expr=st_expr, slide_ids=slide_ids,
    k_phys=20, far_mult=4.0, n_hard=20, device=device,
)
pos_idx = spatial_nce_data['pos_idx']

# ===================================================================
# TEST 0: Check which checkpoint files exist
# ===================================================================
print("\n[TEST 0] Checkpoint files on disk")
print("-" * 60)
for fname in CKPT_FILES:
    fpath = os.path.join(CKPT_DIR, fname)
    if os.path.exists(fpath):
        fsize = os.path.getsize(fpath) / 1024
        mtime_str = datetime.datetime.fromtimestamp(
            os.path.getmtime(fpath)).strftime('%Y-%m-%d %H:%M:%S')
        print(f"  {fname}: {fsize:.1f} KB, modified {mtime_str}")
    else:
        print(f"  {fname}: NOT FOUND")

# ===================================================================
# TEST 1: Load BOTH checkpoint files and compare weights
# ===================================================================
print("\n[TEST 1] Compare encoder_final_new.pt vs encoder_final_trained.pt")
print("-" * 60)

models = {}
for fname in CKPT_FILES:
    fpath = os.path.join(CKPT_DIR, fname)
    if not os.path.exists(fpath):
        print(f"  SKIP: {fname} not found")
        continue
    m = SharedEncoder(n_genes=n_genes, n_embedding=[512, 256, 128], dropout=0.1)
    m.load_state_dict(torch.load(fpath, map_location=device))
    m.to(device).eval()
    models[fname] = m

if len(models) == 2:
    sd1 = models['encoder_final_new.pt'].state_dict()
    sd2 = models['encoder_final_trained.pt'].state_dict()
    max_diff = 0.0
    for k in sd1:
        d = (sd1[k].cpu().float() - sd2[k].cpu().float()).abs().max().item()
        max_diff = max(max_diff, d)
    print(f"  Max param diff between files: {max_diff:.2e}")
    if max_diff < 1e-6:
        print("  IDENTICAL — both files have the same weights")
    else:
        print("  DIFFERENT — files diverge!")

# ===================================================================
# TEST 2: Random model (same seed as training init)
# ===================================================================
print("\n[TEST 2] Random model baseline")
print("-" * 60)
set_seed(SEED)
encoder_random = SharedEncoder(n_genes=n_genes, n_embedding=[512, 256, 128], dropout=0.1)
encoder_random.to(device).eval()

# ===================================================================
# TEST 3: Compare trained vs random weights
# ===================================================================
print("\n[TEST 3] Trained vs Random weight comparison")
print("-" * 60)

# Every later test needs a trained model; fail with an actionable message
# instead of an IndexError from indexing an empty key list.
if not models:
    raise FileNotFoundError(
        f"No encoder checkpoint found in {CKPT_DIR} "
        f"(looked for: {', '.join(CKPT_FILES)})"
    )

# Prefer encoder_final_trained.pt; otherwise use whichever file was loaded.
trained_key = ('encoder_final_trained.pt' if 'encoder_final_trained.pt' in models
               else next(iter(models)))
encoder_trained = models[trained_key]

sd_random = encoder_random.state_dict()
sd_trained = encoder_trained.state_dict()

params_differ = 0
params_same = 0
total_l2_diff = 0.0
for k in sd_random:
    r = sd_random[k].cpu().float()
    t = sd_trained[k].cpu().float()
    max_diff = (r - t).abs().max().item()
    l2_diff = (r - t).norm().item()
    total_l2_diff += l2_diff
    if max_diff > 1e-6:
        params_differ += 1
        if params_differ <= 5:
            print(f"  {k}: DIFFERENT — max_diff={max_diff:.6f}, L2_diff={l2_diff:.4f}")
            print(f"    random[:5]: {r.flatten()[:5].tolist()}")
            print(f"    trained[:5]: {t.flatten()[:5].tolist()}")
    else:
        params_same += 1

print(f"\n  Summary: {params_differ} different, {params_same} same, total L2 diff={total_l2_diff:.4f}")
if params_differ == 0:
    print("  *** CRITICAL: Trained model == Random init! Checkpoint is CORRUPT ***")
else:
    print(f"  OK: Model has learned (weights changed)")

# ===================================================================
# TEST 4: Overlap@20 with EUCLIDEAN vs COSINE kNN
# ===================================================================
print("\n[TEST 4] Overlap@20 — Euclidean vs Cosine kNN")
print("-" * 60)

# Contiguous row ranges per slide; assumes st_expr rows are stacked in
# ST_PATHS order with ns[i] rows for slide i (TODO confirm against loader).
slide_names_list = list(ST_PATHS.keys())
offset = 0
slide_ranges = {}
for i, name in enumerate(slide_names_list):
    slide_ranges[name] = (offset, offset + ns[i])
    offset += ns[i]

for enc_name, enc_model in [("RANDOM", encoder_random), ("TRAINED", encoder_trained)]:
    enc_model.eval()
    z_all = _encode_in_chunks(enc_model, st_expr)
    z_all_norm = F.normalize(z_all, dim=1)

    print(f"\n  === {enc_name} ===")

    # Embedding norm stats
    norms = z_all.norm(dim=1)
    print(f"    Norms: mean={norms.mean():.3f}, std={norms.std():.3f}, "
          f"min={norms.min():.3f}, max={norms.max():.3f}, CV={norms.std()/norms.mean():.3f}")

    for name in slide_names_list:
        s, e = slide_ranges[name]
        n_s = e - s
        if n_s < 2:
            print(f"\n    {name} (n={n_s}): too few spots for kNN — skipped")
            continue
        # topk needs k <= n_s - 1 once the diagonal (self) is excluded.
        k_eff = min(K, n_s - 1)

        z_slide = z_all[s:e]
        z_slide_norm = z_all_norm[s:e]

        # Euclidean kNN
        dists_euc = torch.cdist(z_slide, z_slide)
        dists_euc.fill_diagonal_(float('inf'))
        _, euc_knn = torch.topk(dists_euc, k=k_eff, dim=1, largest=False)

        # Cosine kNN
        sims_cos = z_slide_norm @ z_slide_norm.T
        sims_cos.fill_diagonal_(-float('inf'))
        _, cos_knn = torch.topk(sims_cos, k=k_eff, dim=1, largest=True)

        # Move each index table to host once (not once per row); shift the
        # per-slide indices to global row indices to match pos_idx.
        euc_rows = (euc_knn + s).cpu().tolist()
        cos_rows = (cos_knn + s).cpu().tolist()
        phys_rows = pos_idx[s:e].cpu().tolist()

        overlaps_euc = []
        overlaps_cos = []
        euc_cos_agree = []

        for euc_row, cos_row, phys_row in zip(euc_rows, cos_rows, phys_rows):
            euc_set = set(euc_row)
            cos_set = set(cos_row)
            phys_set = set(phys_row)
            phys_set.discard(-1)  # -1 is treated as a padding entry
            if phys_set:
                overlaps_euc.append(len(euc_set & phys_set) / len(phys_set))
                overlaps_cos.append(len(cos_set & phys_set) / len(phys_set))
            euc_cos_agree.append(len(euc_set & cos_set) / k_eff)

        print(f"\n    {name} (n={n_s}):")
        print(f"      Euclidean overlap@{k_eff}: {_safe_mean(overlaps_euc):.4f}")
        print(f"      Cosine overlap@{k_eff}:    {_safe_mean(overlaps_cos):.4f}")
        print(f"      Euc-Cos agreement:     {_safe_mean(euc_cos_agree):.4f}")

    # Also test with the EXACT same function used during training
    ov = compute_knn_locality_metrics(
        model=enc_model, st_gene_expr=st_expr,
        st_coords=st_coords, slide_ids=slide_ids,
        phys_knn_idx=pos_idx, k=K, n_sample=300,
    )
    print(f"\n    compute_knn_locality_metrics: overlap={ov['overlap_mean']:.4f}, "
          f"emb_phys_dist={ov['emb_phys_dist_median']:.4f}")

# ===================================================================
# TEST 5: NCE loss comparison (trained should be LOWER than random)
# ===================================================================
print("\n\n[TEST 5] NCE Loss: Random vs Trained")
print("-" * 60)

for enc_name, enc_model in [("RANDOM", encoder_random), ("TRAINED", encoder_trained)]:
    enc_model.eval()
    z_cache = _encode_in_chunks(enc_model, st_expr)

    # Same seed for both models so they are scored on the same sampled
    # anchors/negatives (and, in train mode, the same dropout masks).
    # Assumes set_seed seeds every RNG the sampler uses.
    set_seed(SEED)
    enc_model.train()  # mirror the training-time forward pass
    try:
        with torch.no_grad():  # evaluation only; no autograd graph needed
            loss_nce = compute_spatial_infonce_supportset(
                model=enc_model,
                st_gene_expr=st_expr,
                pos_idx=pos_idx,
                far_mask=spatial_nce_data['far_mask'],
                hard_neg=spatial_nce_data['hard_neg'],
                slide_ids=slide_ids,
                tau=0.1,
                n_rand_neg=128,
                n_anchors_per_step=64,
                slide_override=0,
                z_cache=z_cache.detach(),
                n_hard_mine=20,
            )
    finally:
        enc_model.eval()
    print(f"  {enc_name}: NCE loss = {loss_nce.item():.4f}")

# ===================================================================
# TEST 6: Training history check (if available)
# ===================================================================
print("\n\n[TEST 6] Training history check")
print("-" * 60)
hist_path = os.path.join(CKPT_DIR, 'stageA_vicreg_history.json')
if os.path.exists(hist_path):
    with open(hist_path, encoding="utf-8") as f:
        hist = json.load(f)
    # Check the overlap history
    ov_S = hist.get('locality_overlap_after_S', [])
    ov_M = hist.get('locality_overlap_after_M', [])
    epochs = hist.get('epoch', [])

    # Non-zero overlap values (computed every 100 epochs); skip JSON nulls.
    nonzero_S = [(e, v) for e, v in zip(epochs, ov_S) if v is not None and v > 0.001]
    nonzero_M = [(e, v) for e, v in zip(epochs, ov_M) if v is not None and v > 0.001]

    print(f"  History length: {len(epochs)} epochs")
    print(f"  Non-zero overlap measurements: {len(nonzero_S)} (after S), {len(nonzero_M)} (after M)")
    if nonzero_M:
        print(f"\n  overlap@20 after Step M (every 100 epochs):")
        for e, v in nonzero_M:
            print(f"    epoch {e:5d}: {v:.4f}")
        best_e, best_v = max(nonzero_M, key=lambda x: x[1])
        print(f"\n  Best overlap: {best_v:.4f} at epoch {best_e}")
else:
    print("  History file not found")

# ===================================================================
# SUMMARY
# ===================================================================
print("\n\n" + "=" * 70)
print("DIAGNOSTIC v4 SUMMARY")
print("=" * 70)
print("""
KEY TESTS:
  TEST 1: Are encoder_final_new.pt and encoder_final_trained.pt identical?
  TEST 3: Are trained weights different from random init?
  TEST 4: Cosine vs Euclidean overlap — which is higher?
  TEST 5: NCE loss — is trained lower than random?
  TEST 6: Training history — what overlap was recorded?

INTERPRETATION:
  If TEST 3 says identical    → checkpoint save/restore is broken
  If TEST 4 cosine >> euclid  → metric should use cosine (NCE trains cosine)
  If TEST 5 losses are same   → model didn't learn spatial structure
  If TEST 6 shows 0.69        → training metric was live, bug is in save/load
""")
print("=" * 70)

In [None]:
# ===================================================================
# SECTION B+C: BUG VERIFICATION & METRIC CONSISTENCY
# ===================================================================
# Runs the external verification suite against the saved trained encoder:
#   B1) Bug #1: coordinate mismatch after subsampling
#   B2) Bug #2: subsampling destroys spatial resolution
#   C)  Per-slide overlap@20, p25/median physical distance
# The returned object is kept in `verification_results` for later cells.
# ===================================================================

from verify_bugs_and_metrics import run_all_checks

CKPT_DIR = '/home/ehtesamul/sc_st/model/gems_liver_crossslide'
ckpt_path = CKPT_DIR + '/encoder_final_trained.pt'

# Slide names in ST_PATHS insertion order (iterating a dict yields its keys).
slide_name_order = list(ST_PATHS)

verification_results = run_all_checks(
    st_expr=st_expr,
    st_coords=st_coords,
    slide_ids=slide_ids,
    inf_expr=inf_expr,
    n_genes=n_genes,
    ns=ns,
    slide_names=slide_name_order,
    ckpt_path=ckpt_path,
    device=device,
    seed=SEED,
)