In [None]:
import torch
import torch.nn.functional as F
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import cross_val_score, StratifiedKFold
import sys
import os

sys.path.insert(0, '/home/ehtesamul/sc_st/model')
from core_models_et_p1 import SharedEncoder, train_encoder
from ssl_utils import set_seed
import utils_et as uet

# Run on the GPU when torch can see one, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
SEED = 42
# set_seed comes from the project module ssl_utils; presumably it seeds the
# RNGs used downstream — verify against its implementation.
set_seed(SEED)

print("=" * 70)
print("CROSS-SLIDE ALIGNMENT: Mouse Liver")
print("ST training = ST1, ST2, ST3 | Inference target = ST4")
print("Same patient (mouse) — cross-slide gap only")
print("=" * 70)

# Load data
# ST_PATHS: slides used for training; INF_PATHS: held-out slide used only at
# inference time. Dict insertion order defines the slide order used later.
ST_PATHS = {
    'liver_ST1': '/home/ehtesamul/sc_st/data/liver/stadata1.h5ad',
    'liver_ST2': '/home/ehtesamul/sc_st/data/liver/stadata2.h5ad',
    'liver_ST3': '/home/ehtesamul/sc_st/data/liver/stadata3.h5ad',
}
INF_PATHS = {
    'liver_ST4': '/home/ehtesamul/sc_st/data/liver/stadata4.h5ad',
}

st_data = {}
for name, path in ST_PATHS.items():
    st_data[name] = sc.read_h5ad(path)
    print(f"  {name}: {st_data[name].n_obs} spots")

inf_data = {}
for name, path in INF_PATHS.items():
    inf_data[name] = sc.read_h5ad(path)
    print(f"  {name}: {inf_data[name].n_obs} spots")

# Normalize
# Both calls modify each AnnData in place; the objects were freshly read above,
# so re-running this whole cell does not normalize twice.
# NOTE(review): normalize_total is called without target_sum, so each slide is
# scaled independently — confirm this per-slide scaling is intended for a
# cross-slide comparison.
all_data = list(st_data.values()) + list(inf_data.values())
for adata in all_data:
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)

# Common genes
# Intersection of var_names across all four slides, sorted so the gene order
# (and therefore the encoder input layout) is deterministic.
common_genes = sorted(set.intersection(*[set(a.var_names) for a in all_data]))
n_genes = len(common_genes)
print(f"\n✓ Common genes: {n_genes}")

# Extract expression
def extract_expr(adata, genes):
    """Return the expression matrix of ``adata`` restricted to ``genes``.

    The columns follow the order of ``genes``. If the selected ``.X`` exposes
    a ``toarray`` method (e.g. a sparse matrix) it is densified; otherwise it
    is returned as-is.
    """
    expr = adata[:, genes].X
    if hasattr(expr, "toarray"):
        return expr.toarray()
    return expr

# Dense expression matrices per slide, all restricted to the same ordered
# gene list so their columns line up.
X_st = {name: extract_expr(st_data[name], common_genes) for name in ST_PATHS}
X_inf = {name: extract_expr(inf_data[name], common_genes) for name in INF_PATHS}

for name, X in {**X_st, **X_inf}.items():
    print(f"  {name}: {X.shape}")

# Slide names in training-then-inference order.
labels_str = list(ST_PATHS.keys()) + list(INF_PATHS.keys())

# One fixed hex colour per slide name.
colors_map = {
    'liver_ST1': '#e74c3c',
    'liver_ST2': '#3498db',
    'liver_ST3': '#2ecc71',
    'liver_ST4': '#f39c12',
}

In [None]:
print("\n" + "=" * 70)
print("PREPARING TRAINING DATA")
print("=" * 70)

# ST domain (training slides with coords)
# Rows are stacked in ST_PATHS order; the same order is used for the
# coordinates and the slide ids below so all three stay row-aligned.
st_expr = torch.tensor(
    np.vstack([X_st[n] for n in ST_PATHS]),
    dtype=torch.float32, device=device
)

st_coords_list = [st_data[n].obsm['spatial'] for n in ST_PATHS]
st_coords_raw = torch.tensor(
    np.vstack(st_coords_list), dtype=torch.float32, device=device
)

# ns[i] = number of spots in the i-th training slide; slide_ids assigns the
# integer i to every row that came from that slide.
ns = [X_st[n].shape[0] for n in ST_PATHS]
slide_ids = torch.tensor(
    np.concatenate([np.full(n, i, dtype=int) for i, n in enumerate(ns)]),
    dtype=torch.long, device=device
)

# Project helper (utils_et); returns the transformed coordinates plus two more
# values stored as st_mu / st_scale — presumably the per-slide offset and
# scale that were applied; verify against utils_et.
st_coords, st_mu, st_scale = uet.canonicalize_st_coords_per_slide(st_coords_raw, slide_ids)

print(f"✓ ST expr: {st_expr.shape}")
for i, (name, n) in enumerate(zip(ST_PATHS, ns)):
    print(f"  - {name}: {n} spots (slide {i})")
print(f"✓ ST coords: {st_coords.shape} (canonicalized per-slide)")

# Inference domain (ST4 treated as "SC")
# All inference rows get slide id 0 and patient id 0.
inf_expr = torch.tensor(X_inf['liver_ST4'], dtype=torch.float32, device=device)
n_inf = X_inf['liver_ST4'].shape[0]
inf_slide_ids = torch.zeros(n_inf, dtype=torch.long, device=device)
inf_patient_ids = torch.zeros(n_inf, dtype=torch.long, device=device)  # same patient

print(f"\n✓ Inference expr: {inf_expr.shape}")
print(f"  - liver_ST4: {n_inf} spots (inference domain)")

# Source IDs for adversary
# Training slides keep ids 0..len(ns)-1; the inference slide gets the next id
# (len(ns)). Unlike slide_ids, these tensors are created without device=, so
# they live on the CPU.
st_source_ids = torch.tensor(
    np.concatenate([np.full(n, i, dtype=int) for i, n in enumerate(ns)]),
    dtype=torch.long
)
inf_source_ids = torch.full((n_inf,), len(ns), dtype=torch.long)

print(f"\n✓ ST source IDs: {st_source_ids.unique().tolist()}, counts: {ns}")
print(f"✓ Inf source IDs: {inf_source_ids.unique().tolist()}, count: {n_inf}")
print(f"\nTotal for Stage A: {st_expr.shape[0] + n_inf}")

In [None]:
# ===========================================================================
# SPATIAL InfoNCE + ALL-LOSS DIAGNOSTICS
# ===========================================================================
# Add after Cell 1 (data prep). Uses: st_expr, inf_expr, st_coords,
# slide_ids, n_genes, device, SEED, X_ssl is built below.
# ===========================================================================

import torch
import torch.nn.functional as F
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt

from core_models_et_p1 import SharedEncoder
from ssl_utils import (
    set_seed, vicreg_loss, compute_local_alignment_loss,
    coral_loss, mmd_rbf_loss,
    precompute_spatial_nce_structures, compute_spatial_infonce_loss,
)

set_seed(SEED)

# ---- build concat pool + domain labels ----
# X_ssl stacks training rows first, then inference rows, so a pooled index
# >= n_st always refers to an inference row. domain_ids is 0 for training
# rows and 1 for inference rows.
n_st = st_expr.shape[0]
n_sc = inf_expr.shape[0]
X_ssl = torch.cat([st_expr, inf_expr], dim=0)
domain_ids = torch.cat([
    torch.zeros(n_st, dtype=torch.long, device=device),
    torch.ones(n_sc, dtype=torch.long, device=device),
])

# ---- precompute spatial NCE structures ----
# Project helper (ssl_utils) returning a dict; only the four keys unpacked
# below are used in this notebook. Their exact shapes/semantics are defined in
# ssl_utils — the code below treats negative entries of pos_idx / hard_neg as
# padding and r_pos as an iterable with one value per slide.
nce_data = precompute_spatial_nce_structures(
    st_coords=st_coords, st_gene_expr=st_expr, slide_ids=slide_ids,
    k_phys=20, far_mult=4.0, n_hard=20, device=device,
)
pos_idx  = nce_data['pos_idx']
far_mask = nce_data['far_mask']
hard_neg = nce_data['hard_neg']
r_pos    = nce_data['r_pos']

# ---- tiny helpers ----
def augment_expr(X, drop=0.2, gauss=0.01, jitter=0.1):
    """Return a randomly augmented copy of the expression matrix ``X``.

    Three perturbations are applied in order:
      1. each entry is zeroed independently with probability ``drop``;
      2. Gaussian noise with standard deviation ``gauss`` is added;
      3. every row is multiplied by its own factor drawn uniformly from
         ``[1 - jitter, 1 + jitter)``.

    ``X`` itself is not modified.
    """
    keep_prob = torch.full_like(X, 1.0 - drop)
    keep_mask = torch.bernoulli(keep_prob)
    noise = torch.randn_like(X) * gauss
    perturbed = X * keep_mask + noise
    # One uniform draw per row, mapped from [0, 1) to [-1, 1).
    unit = torch.rand(X.shape[0], 1, device=X.device) * 2 - 1
    row_scale = 1.0 + unit * jitter
    return perturbed * row_scale

def enc_grad_norm(enc):
    """Global L2 norm of the gradients currently stored on ``enc``'s parameters.

    Parameters whose ``.grad`` is ``None`` are skipped; if none has a gradient
    the result is ``0.0``.
    """
    squared_total = 0
    for param in enc.parameters():
        if param.grad is None:
            continue
        squared_total += param.grad.data.norm(2).item() ** 2
    return squared_total ** 0.5

def balanced_batch(n_st, n_sc, bs, device):
    """Sample a shuffled batch of pooled indices, half from each domain.

    ``bs // 2`` indices are drawn without replacement from ``[0, n_st)`` and
    the remaining ``bs - bs // 2`` from ``[n_st, n_st + n_sc)``; the two parts
    are concatenated and randomly permuted. If a domain has fewer rows than
    requested, all of its rows are used and the batch is shorter than ``bs``.
    """
    half = bs // 2
    st_part = torch.randperm(n_st, device=device)[:half]
    # Offset by n_st so these index the second half of the pooled matrix.
    sc_part = torch.randperm(n_sc, device=device)[:bs - half] + n_st
    pooled = torch.cat([st_part, sc_part])
    shuffle = torch.randperm(pooled.shape[0], device=device)
    return pooled[shuffle]

# VICReg criterion shared by Check A and Check B. It is called below as
# vicreg_fn(y1, y2) and returns (loss, stats), where stats is indexed with the
# keys 'inv', 'var', 'cov' and 'std_min'. The lambda/gamma/eps values here
# equal the vicreg_* arguments passed to train_encoder in the training cell.
vicreg_fn = vicreg_loss(
    lambda_inv=25.0, lambda_var=50.0, lambda_cov=1.0,
    gamma=1.0, eps=1e-4, use_projector=False, float32_stats=True,
)

# ======================================================================
# CHECK C — Index-set sanity
# ======================================================================
# For a random sample of training spots, measure how far (in canonicalized
# coordinates) their listed positives and hard negatives are, and verify the
# two index sets never overlap.
print("=" * 70)
print("CHECK C: Index-Set Sanity (positives / hard-negs / overlap)")
print("=" * 70)

coords_np = st_coords.cpu().numpy()
rng = np.random.default_rng(SEED)
anchors = rng.choice(n_st, size=min(300, n_st), replace=False)

pos_d, hard_d, overlaps = [], [], []
for i in anchors:
    # Negative entries are treated as padding and dropped.
    pg = pos_idx[i]; pg = pg[pg >= 0].cpu().numpy()
    hg = hard_neg[i]; hg = hg[hg >= 0].cpu().numpy()
    if len(pg): pos_d.extend(np.linalg.norm(coords_np[pg] - coords_np[i], axis=1))
    if len(hg): hard_d.extend(np.linalg.norm(coords_np[hg] - coords_np[i], axis=1))
    overlaps.append(len(set(pg) & set(hg)))

pos_d, hard_d = np.asarray(pos_d), np.asarray(hard_d)
# The 4.0 here repeats the far_mult=4.0 passed to
# precompute_spatial_nce_structures; keep the two in sync.
r_far = [r * 4.0 for r in r_pos]

print(f"\n  Sampled {len(anchors)} anchors")
print(f"  POSITIVES  — n={len(pos_d):,}  median_dist={np.median(pos_d):.4f}  "
      f"max={np.max(pos_d):.4f}")
print(f"  HARD_NEG   — n={len(hard_d):,}  median_dist={np.median(hard_d):.4f}  "
      f"min={np.min(hard_d):.4f}")
print(f"  r_pos/slide: {[f'{r:.4f}' for r in r_pos]}")
print(f"  r_far/slide: {[f'{r:.4f}' for r in r_far]}")
# NOTE(review): the leak test compares every distance against the smallest
# r_far over all slides, not the r_far of the anchor's own slide, so it can
# miss leaks on slides whose r_far is larger than the minimum.
leak = (hard_d < min(r_far)).sum() if len(hard_d) else 0
print(f"  Hard negs below r_far (LEAK): {leak}/{len(hard_d)}  "
      f"{'✓ OK' if leak == 0 else '⚠️ BUG'}")
print(f"  Pos ∩ Neg overlap total: {sum(overlaps)}  "
      f"{'✓ zero' if sum(overlaps) == 0 else '⚠️ OVERLAP!'}")
# A spot counts as having no positives / hard negatives when the first column
# of its row is negative (assumes valid entries are packed first — confirm).
no_pos = (pos_idx[:, 0] < 0).sum().item()
no_neg = (hard_neg[:, 0] < 0).sum().item()
print(f"  Spots w/ no positives: {no_pos}/{n_st}   no hard_neg: {no_neg}/{n_st}")

# ======================================================================
# CHECK A — Per-loss gradient norms (single batch)
# ======================================================================
# A freshly initialised encoder is run on one balanced batch. Each loss is
# back-propagated on its own (gradients cleared first) and the resulting
# encoder gradient norm is stored, so the relative pull of each loss on the
# encoder can be compared.
print("\n" + "=" * 70)
print("CHECK A: Per-Loss Gradient Norms (single batch, bs=256)")
print("=" * 70)

set_seed(SEED)
enc_a = SharedEncoder(n_genes=n_genes, n_embedding=[512, 256, 128],
                      dropout=0.1).to(device).train()
BS = 256
idx = balanced_batch(n_st, n_sc, BS, device)
X_b = X_ssl[idx]; s_b = domain_ids[idx]
# is_sc: row comes from the inference pool; is_st: row comes from training
# slides. With the pooled layout these two masks are complements.
is_sc = (idx >= n_st); is_st = (s_b == 0)

results = {}  # name → (loss_val, grad_norm)

# --- VICReg ---
# Two independently augmented views of the same batch.
enc_a.zero_grad()
y1, y2 = enc_a(augment_expr(X_b)), enc_a(augment_expr(X_b))
lv, _ = vicreg_fn(y1, y2)
lv.backward(); results['VICReg'] = (lv.item(), enc_grad_norm(enc_a))

# --- Spatial InfoNCE ---
enc_a.zero_grad()
z = enc_a(X_b)
ln = compute_spatial_infonce_loss(
    z=z, batch_idx=idx, pos_idx=pos_idx, far_mask=far_mask,
    hard_neg=hard_neg, tau=0.1, n_rand_neg=128, is_st_mask=is_st)
ln.backward(); results['SpatialNCE'] = (ln.item(), enc_grad_norm(enc_a))

# --- Local Alignment ---
# NOTE(review): the fallback constant does not require grad, so .backward()
# on it would raise if either domain had <= 8 rows; with BS=256 split evenly
# that only happens when a pool itself has that few rows.
enc_a.zero_grad()
z = enc_a(X_b)
zst, zsc = z[~is_sc], z[is_sc]; xst, xsc = X_b[~is_sc], X_b[is_sc]
ll = compute_local_alignment_loss(z_sc=zsc, z_st=zst, x_sc=xsc, x_st=xst,
                                  tau_z=0.07, bidirectional=True
     ) if zst.shape[0] > 8 and zsc.shape[0] > 8 else torch.tensor(0., device=device)
ll.backward(); results['LocalAlign'] = (ll.item(), enc_grad_norm(enc_a))

# --- CORAL ---
# Same caveat as above for the no-grad fallback constant.
enc_a.zero_grad()
z = enc_a(X_b); zst, zsc = z[~is_sc], z[is_sc]
lc = coral_loss(zst, zsc) if zst.shape[0] > 4 and zsc.shape[0] > 4 \
     else torch.tensor(0., device=device)
lc.backward(); results['CORAL'] = (lc.item(), enc_grad_norm(enc_a))

# --- MMD ---
# MMD is measured on L2-normalised embeddings. F.normalize's second positional
# argument is the norm order `p`, not `dim`, so `F.normalize(x, 1)` would
# L1-normalise the rows; p and dim are therefore passed by keyword.
enc_a.zero_grad()
z = enc_a(X_b)
zst = F.normalize(z[~is_sc], p=2, dim=1)
zsc = F.normalize(z[is_sc], p=2, dim=1)
if zst.shape[0] > 4 and zsc.shape[0] > 4:
    lm, _ = mmd_rbf_loss(zst, zsc, return_sigma=True)
    lm.backward()
    results['MMD'] = (lm.item(), enc_grad_norm(enc_a))
else:
    # Too few rows in one domain: record a zero loss with zero gradient
    # instead of calling backward() on a constant that has no graph.
    lm = torch.tensor(0., device=device)
    results['MMD'] = (0.0, 0.0)

# --- Full combined (weights = your training config) ---
# All losses on the same batch, summed with fixed weights, then a single
# backward pass (no adversarial term).
# NOTE(review): the training cell passes local_align_weight=0.0 and
# coral_raw_weight=2.0, whereas this sum uses 3.0 and 1.0 — confirm which
# weighting this diagnostic is meant to mirror.
enc_a.zero_grad()
y1, y2 = enc_a(augment_expr(X_b)), enc_a(augment_expr(X_b))
lv_f, _ = vicreg_fn(y1, y2)
z_f = enc_a(X_b)
ln_f = compute_spatial_infonce_loss(z=z_f, batch_idx=idx, pos_idx=pos_idx,
    far_mask=far_mask, hard_neg=hard_neg, tau=0.1, n_rand_neg=128, is_st_mask=is_st)
zst_f, zsc_f = z_f[~is_sc], z_f[is_sc]
xst_f, xsc_f = X_b[~is_sc], X_b[is_sc]
ll_f = compute_local_alignment_loss(z_sc=zsc_f, z_st=zst_f, x_sc=xsc_f,
    x_st=xst_f, tau_z=0.07, bidirectional=True
) if zst_f.shape[0] > 8 and zsc_f.shape[0] > 8 else torch.tensor(0., device=device)
lc_f = coral_loss(zst_f, zsc_f) if zst_f.shape[0] > 4 and zsc_f.shape[0] > 4 \
       else torch.tensor(0., device=device)
# L2-normalise rows before MMD. The second positional argument of F.normalize
# is `p`, so p and dim are given by keyword to avoid an accidental L1 norm.
zst_n = F.normalize(zst_f, p=2, dim=1)
zsc_n = F.normalize(zsc_f, p=2, dim=1)
# Guard MMD with the same minimum-size rule as the other alignment losses.
if zst_f.shape[0] > 4 and zsc_f.shape[0] > 4:
    lm_f, _ = mmd_rbf_loss(zst_n, zsc_n, return_sigma=True)
else:
    lm_f = torch.tensor(0., device=device)
L = lv_f + 3.0*ln_f + 3.0*ll_f + lc_f + 30.0*lm_f   # no adv here
L.backward(); results['FULL'] = (L.item(), enc_grad_norm(enc_a))

# Table of loss value, encoder gradient norm, and that norm as a percentage of
# the combined-loss gradient norm. Norms of different losses need not add up,
# so the percentages are a rough scale comparison, not a decomposition.
print(f"\n  {'Loss':<16s} {'value':>10s} {'‖∇θ‖':>12s} {'% of Full':>10s}")
print(f"  {'-'*16} {'-'*10} {'-'*12} {'-'*10}")
gn_full = results['FULL'][1]
for name in ['VICReg','SpatialNCE','LocalAlign','CORAL','MMD','FULL']:
    val, gn = results[name]
    # max(..., 1e-10) avoids division by zero when the full gradient vanishes.
    pct = 100*gn/max(gn_full,1e-10)
    # A (near-)zero gradient means the loss has no effect on the encoder.
    flag = " ← NO-OP!" if gn < 1e-8 and name != 'FULL' else ""
    print(f"  {name:<16s} {val:>10.4f} {gn:>12.6f} {pct:>9.1f}%{flag}")

gn_nce = results['SpatialNCE'][1]
if gn_nce < 1e-8:
    print("\n  ⚠️  SpatialNCE grad is ZERO — loss is disconnected!")
else:
    print(f"\n  ✓ SpatialNCE grad live (ratio vs VICReg: "
          f"{gn_nce/max(results['VICReg'][1],1e-10):.3f})")

# ======================================================================
# CHECK B — Loss scale + similarity gap (200 steps)
# ======================================================================
# Trains a second, freshly initialised encoder for N_STEPS optimiser steps on
# the combined loss and records each loss term plus cosine similarities
# between anchors and their positives / hard negatives / random far spots.
print("\n" + "=" * 70)
print("CHECK B: Loss Scale + Similarity Gap (200 training steps)")
print("=" * 70)

set_seed(SEED)
enc_b = SharedEncoder(n_genes=n_genes, n_embedding=[512, 256, 128],
                      dropout=0.1).to(device).train()
opt_b = torch.optim.Adam(enc_b.parameters(), lr=1e-4)

# The banner above and the summary text below hard-code "200"; keep them in
# sync if N_STEPS changes.
N_STEPS = 200
# log[key] is a list with one entry appended per training step.
log = defaultdict(list)

# Size of the global→local lookup table. It only depends on the pooled row
# count and on the largest index stored in pos_idx / hard_neg, so it is
# computed once here instead of on every step (each .item() forces a host
# sync).
g2l_size = max(X_ssl.shape[0],
               int(pos_idx.max().item()) + 1,
               int(hard_neg.max().item()) + 1)

for step in range(N_STEPS):
    idx = balanced_batch(n_st, n_sc, 256, device)
    X_b = X_ssl[idx]; s_b = domain_ids[idx]
    is_sc_b = (idx >= n_st); is_st_b = (s_b == 0)

    # VICReg on two independently augmented views
    y1, y2 = enc_b(augment_expr(X_b)), enc_b(augment_expr(X_b))
    l_vic, sv = vicreg_fn(y1, y2)

    # Clean embeddings
    z = enc_b(X_b)

    # --- Per-anchor similarity logging (diagnostic only, no gradients) ---
    sp_list, sh_list, sr_list = [], [], []
    with torch.no_grad():
        z_n = F.normalize(z, dim=1)

        # g2l maps a global row index to its position inside this batch,
        # or -1 when that row was not sampled.
        g2l = torch.full((g2l_size,), -1, dtype=torch.long, device=device)
        g2l[idx] = torch.arange(len(idx), device=device)

        st_locals = torch.where(is_st_b)[0]

        for li in st_locals:
            gi = idx[li].item()
            za = z_n[li]

            # Positives of this anchor that are present in the batch.
            pg = pos_idx[gi]; pg = pg[pg >= 0]
            if pg.numel() == 0: continue
            pl = g2l[pg]; pl = pl[pl >= 0]
            if pl.numel() == 0: continue

            # Hard negatives of this anchor that are present in the batch.
            hg = hard_neg[gi]; hg = hg[hg >= 0]
            hl = g2l[hg]; hl = hl[hl >= 0]

            # Random "far" negatives: batch rows flagged by far_mask for this
            # anchor (excluding the anchor itself), capped at 128.
            # NOTE(review): idx holds pooled indices (inference rows are
            # >= n_st); this assumes far_mask's second axis can be indexed by
            # them — confirm against precompute_spatial_nce_structures.
            fb = far_mask[gi][idx]; fb[li] = False
            fc = torch.where(fb)[0]
            rn = fc[torch.randperm(fc.numel(), device=device)[:128]] \
                 if fc.numel() > 128 else fc

            sp_list.append((za @ z_n[pl].T).mean().item())
            if hl.numel(): sh_list.append((za @ z_n[hl].T).mean().item())
            if rn.numel(): sr_list.append((za @ z_n[rn].T).mean().item())

    l_nce = compute_spatial_infonce_loss(
        z=z, batch_idx=idx, pos_idx=pos_idx, far_mask=far_mask,
        hard_neg=hard_neg, tau=0.1, n_rand_neg=128, is_st_mask=is_st_b)

    # Other losses (skipped when either domain has <= 8 rows in the batch)
    zst, zsc = z[~is_sc_b], z[is_sc_b]
    xst, xsc = X_b[~is_sc_b], X_b[is_sc_b]
    ok = zst.shape[0] > 8 and zsc.shape[0] > 8
    l_la  = compute_local_alignment_loss(z_sc=zsc, z_st=zst, x_sc=xsc,
                x_st=xst, tau_z=0.07, bidirectional=True) if ok \
            else torch.tensor(0., device=device)
    l_cor = coral_loss(zst, zsc) if ok else torch.tensor(0., device=device)
    # L2-normalise rows before MMD. F.normalize's second positional argument
    # is `p`, so p and dim are passed by keyword to avoid an L1 normalisation.
    zst_n = F.normalize(zst, p=2, dim=1)
    zsc_n = F.normalize(zsc, p=2, dim=1)
    l_mmd, sig = mmd_rbf_loss(zst_n, zsc_n, return_sigma=True) if ok \
                 else (torch.tensor(0., device=device), 1.0)

    L = l_vic + 3.0*l_nce + 3.0*l_la + l_cor + 30.0*l_mmd
    opt_b.zero_grad(); L.backward(); opt_b.step()

    # log
    log['loss_total'].append(L.item())
    log['loss_vic'].append(l_vic.item())
    log['loss_nce'].append(l_nce.item())
    log['loss_la'].append(l_la.item())
    log['loss_coral'].append(l_cor.item())
    log['loss_mmd'].append(l_mmd.item())  # l_mmd is a tensor on both branches
    log['inv'].append(sv['inv']); log['var'].append(sv['var'])
    log['cov'].append(sv['cov']); log['std_min'].append(sv['std_min'])
    # NaN marks steps where no anchor had the corresponding partner in-batch.
    sp = np.mean(sp_list) if sp_list else float('nan')
    sh = np.mean(sh_list) if sh_list else float('nan')
    sr = np.mean(sr_list) if sr_list else float('nan')
    log['sim_pos'].append(sp); log['sim_hard'].append(sh)
    log['sim_rand'].append(sr); log['n_anchors'].append(len(sp_list))

    if step % 50 == 0 or step < 3:
        gap = sp - sh if not np.isnan(sh) else float('nan')
        print(f"  {step:>3d} | tot={L.item():.3f} vic={l_vic.item():.3f} "
              f"nce={l_nce.item():.4f} la={l_la.item():.4f} "
              f"cor={l_cor.item():.5f} mmd={l_mmd.item():.5f} | "
              f"sim p/h/r={sp:+.4f}/{sh:+.4f}/{sr:+.4f} gap={gap:+.4f} "
              f"| anchors={len(sp_list)}")

# summary
# Compares the first and last logged step of each series. Single-step values
# are noisy, so treat the deltas as a rough indication only.
print(f"\n  --- 200-step summary ---")
for k in ['loss_nce','sim_pos','sim_hard','sim_rand']:
    v = log[k]
    print(f"  {k:<12s}  start={v[0]:.4f}  end={v[-1]:.4f}  "
          f"Δ={v[-1]-v[0]:+.4f}")
# Gap between mean similarity to positives and to hard negatives; NaN if
# either series had no in-batch partners at that step.
gap0 = log['sim_pos'][0] - log['sim_hard'][0]
gapN = log['sim_pos'][-1] - log['sim_hard'][-1]
print(f"  gap(p-h)    start={gap0:+.4f}  end={gapN:+.4f}  Δ={gapN-gap0:+.4f}")
print(f"  avg anchors/batch: {np.mean(log['n_anchors']):.1f}")

if np.mean(log['n_anchors']) < 5:
    print("  ⚠️  Very few in-batch anchors — positives rarely in same batch")
if abs(log['loss_nce'][-1] - log['loss_nce'][0]) < 0.01:
    print("  ⚠️  NCE loss barely moved — check gradient flow")
# A NaN gap makes this comparison False, so the warning branch is taken.
if gapN > gap0 + 0.02:
    print("  ✓ Gap is GROWING — NCE is learning spatial locality")
else:
    print("  ⚠️  Gap not growing — NCE may be dominated by other losses")

# ======================================================================
# PLOTS
# ======================================================================
# 2x4 grid summarising the Check B log: top row = loss curves, bottom row =
# similarity diagnostics and VICReg statistics. The figure is also written to
# disk as a PNG.
steps = list(range(N_STEPS))
fig, axes = plt.subplots(2, 4, figsize=(22, 9))
fig.suptitle('Spatial InfoNCE + All-Loss Diagnostics (200 steps)',
             fontweight='bold', fontsize=14)

axes[0,0].plot(steps, log['loss_total'], lw=2)
axes[0,0].set_title('Total Loss'); axes[0,0].grid(alpha=.3)

axes[0,1].plot(steps, log['loss_vic'], c='tab:blue', lw=1.5)
axes[0,1].set_title('VICReg'); axes[0,1].grid(alpha=.3)

axes[0,2].plot(steps, log['loss_nce'], c='tab:red', lw=2)
axes[0,2].set_title('Spatial InfoNCE'); axes[0,2].grid(alpha=.3)

axes[0,3].plot(steps, log['loss_la'], label='LocalAlign', c='tab:green')
axes[0,3].plot(steps, log['loss_coral'], label='CORAL', c='tab:orange')
axes[0,3].plot(steps, log['loss_mmd'], label='MMD', c='tab:purple')
axes[0,3].set_title('Alignment Losses'); axes[0,3].legend(fontsize=7)
axes[0,3].grid(alpha=.3)

axes[1,0].plot(steps, log['sim_pos'], c='green', lw=2, label='pos')
axes[1,0].plot(steps, log['sim_hard'], c='red', lw=2, label='hard_neg')
axes[1,0].plot(steps, log['sim_rand'], c='gray', lw=1.5, ls='--', label='rand_neg')
axes[1,0].set_title('Similarity (KEY)'); axes[1,0].legend(fontsize=7)
axes[1,0].set_ylabel('cosine sim'); axes[1,0].grid(alpha=.3)

# Positive-minus-hard-negative similarity; above the dashed zero line means
# positives are closer than hard negatives in embedding space.
gap_curve = [p-h for p,h in zip(log['sim_pos'], log['sim_hard'])]
axes[1,1].plot(steps, gap_curve, c='darkblue', lw=2)
axes[1,1].axhline(0, c='red', ls='--', alpha=.5)
axes[1,1].set_title('Gap: sim(pos)−sim(hard)'); axes[1,1].grid(alpha=.3)

axes[1,2].plot(steps, log['inv'], label='inv')
axes[1,2].plot(steps, log['var'], label='var')
axes[1,2].plot(steps, log['cov'], label='cov')
axes[1,2].set_title('VICReg Components'); axes[1,2].legend(fontsize=7)
axes[1,2].grid(alpha=.3)

axes[1,3].plot(steps, log['std_min'], c='purple', lw=2)
axes[1,3].axhline(0.1, c='red', ls='--', alpha=.5, label='collapse')
axes[1,3].set_title('std_min (collapse?)'); axes[1,3].legend(fontsize=7)
axes[1,3].grid(alpha=.3)

for ax in axes.flat: ax.set_xlabel('Step', fontsize=9)
plt.tight_layout()

# The output directory is otherwise only created by the later training cell,
# so on a fresh kernel it may not exist yet when this cell runs; create it
# here so savefig cannot fail with FileNotFoundError.
nce_fig_path = '/home/ehtesamul/sc_st/model/gems_liver_crossslide/nce_diagnostics.png'
os.makedirs(os.path.dirname(nce_fig_path), exist_ok=True)
plt.savefig(nce_fig_path, dpi=150, bbox_inches='tight')
plt.show()
# print("Saved: gems_liver_crossslide/nce_diagnostics.png")

In [None]:
print("\n" + "=" * 70)
print("TRAINING ENCODER — CROSS-SLIDE (SAME PATIENT)")
print("=" * 70)

# Re-seed so the encoder initialisation does not depend on how many random
# draws earlier cells consumed.
set_seed(SEED)
encoder = SharedEncoder(n_genes=n_genes, n_embedding=[512, 256, 128], dropout=0.1)

outdir = '/home/ehtesamul/sc_st/model/gems_liver_crossslide'
os.makedirs(outdir, exist_ok=True)

# train_encoder is a project function (core_models_et_p1); the meaning of each
# keyword is defined there. Observations that can be read off this call:
#   - the held-out slide is passed through the sc_* arguments;
#   - use_source_adversary=False while adv_* / disc_* values are still given;
#   - use_local_align=True together with local_align_weight=0.0;
#   - the vicreg_* values equal those used to build vicreg_fn above.
# NOTE(review): confirm whether the adversary / local-align settings that are
# switched off or zero-weighted are intentionally left in place.
encoder, projector, discriminator, hist = train_encoder(
    inference_dropout_prob=0.5,
    model=encoder,
    st_gene_expr=st_expr,
    st_coords=st_coords,
    sc_gene_expr=inf_expr,
    slide_ids=slide_ids,
    sc_slide_ids=inf_slide_ids,
    sc_patient_ids=inf_patient_ids,
    n_epochs=2000,
    batch_size=256,
    lr=1e-4,
    device=device,
    outf=outdir,
    st_source_ids=st_source_ids,
    sc_source_ids=inf_source_ids,
    use_source_adversary=False,
    source_coral_weight=75.0,
    stageA_obj='vicreg_adv',
    vicreg_lambda_inv=25.0,
    vicreg_lambda_var=50.0,
    vicreg_lambda_cov=1.0,
    vicreg_gamma=1.0,
    vicreg_eps=1e-4,
    vicreg_project_dim=256,
    vicreg_use_projector=False,
    vicreg_float32_stats=True,
    vicreg_ddp_gather=False,
    aug_gene_dropout=0.25,
    aug_gauss_std=0.01,
    aug_scale_jitter=0.1,
    adv_slide_weight=75.0,
    patient_coral_weight=0.0,
    mmd_weight=30.0,
    mmd_use_l2norm=True,
    mmd_ramp=True,
    adv_warmup_epochs=50,
    adv_ramp_epochs=200,
    grl_alpha_max=1.0,
    disc_hidden=512,
    disc_dropout=0.1,
    stageA_balanced_slides=True,
    adv_representation_mode='clean',
    adv_use_layernorm=False,
    adv_log_diagnostics=True,
    adv_log_grad_norms=False,
    use_local_align=True,
    return_aux=True,
    local_align_bidirectional=True,
    local_align_weight=0.0,
    local_align_tau_z=0.07,
    seed=SEED,
    use_best_checkpoint=True,
    coral_raw_weight=2.0,
    knn_weight=0.0,              # OFF — this preserves expression neighborhoods (bad for liver)
    spatial_nce_weight=3.0,       # ON — enforces physical neighborhoods
    spatial_nce_k_phys=20,
    spatial_nce_far_mult=4.0,
    spatial_nce_n_hard=20,
    spatial_nce_tau=0.1,
    spatial_nce_n_rand_neg=128,
)

print("\n✓ Training complete!")
# Only the encoder weights are saved; projector, discriminator and hist stay
# in kernel memory only.
torch.save(encoder.state_dict(), f'{outdir}/encoder_final_trained.pt')
print(f"✓ Saved to: {outdir}/encoder_final_trained.pt")

In [None]:
print("\n" + "=" * 70)
print("COMPUTING EMBEDDINGS")
print("=" * 70)

# The encoder was constructed on the CPU and handed to train_encoder, whose
# device handling is not visible here. Move it explicitly so the forward
# passes below cannot hit a device mismatch (no-op if it is already there).
encoder.to(device)
# eval() disables dropout so the embeddings are deterministic.
encoder.eval()
with torch.no_grad():
    # One embedding matrix per slide, moved back to the CPU for analysis.
    Z_st = {name: encoder(torch.tensor(X, dtype=torch.float32, device=device)).cpu()
            for name, X in X_st.items()}
    Z_inf = {name: encoder(torch.tensor(X, dtype=torch.float32, device=device)).cpu()
             for name, X in X_inf.items()}

for name, Z in {**Z_st, **Z_inf}.items():
    print(f"  {name}: {Z.shape}")

In [None]:
# Load trained encoder
# Rebuilds the architecture and restores the weights written by the training
# cell, so the analysis cells below can run without retraining. The
# architecture arguments must match the ones used when the checkpoint was
# saved.
encoder = SharedEncoder(n_genes=n_genes, n_embedding=[512, 256, 128], dropout=0.1)
encoder.load_state_dict(torch.load(
    '/home/ehtesamul/sc_st/model/gems_liver_crossslide/encoder_final_trained.pt',
    map_location=device
))
encoder.to(device)
# eval() disables dropout so the embeddings are deterministic.
encoder.eval()

# Compute embeddings
# Overwrites any Z_st / Z_inf produced by the previous cell with embeddings
# from the reloaded weights; results are moved back to the CPU.
with torch.no_grad():
    Z_st = {name: encoder(torch.tensor(X, dtype=torch.float32, device=device)).cpu()
            for name, X in X_st.items()}
    Z_inf = {name: encoder(torch.tensor(X, dtype=torch.float32, device=device)).cpu()
             for name, X in X_inf.items()}

for name, Z in {**Z_st, **Z_inf}.items():
    print(f"  {name}: {Z.shape}")

In [None]:
# === Analysis A: Embedding kNN vs Physical kNN ===
# For every training slide: how physically far away are a spot's nearest
# neighbours in embedding space, compared with its true physical neighbours
# (lower bound) and with randomly chosen spots (upper bound)?
from scipy.spatial import cKDTree

print("=" * 70)
print("ANALYSIS A: Embedding kNN vs Physical Neighborhoods (ST only)")
print("=" * 70)


def _median_neighbor_dist(coords, nbr_idx):
    """Per-spot median physical distance to the spots listed in ``nbr_idx``.

    coords  : (n, d) array of physical coordinates.
    nbr_idx : (n, k) integer array; row i lists the neighbours of spot i.
    Returns an (n,) array.
    """
    offsets = coords[nbr_idx] - coords[:, None, :]
    return np.median(np.linalg.norm(offsets, axis=2), axis=1)


for name in ST_PATHS:
    coords = np.asarray(st_data[name].obsm['spatial'])
    Z = Z_st[name].numpy()
    n = Z.shape[0]

    tree_phys = cKDTree(coords)
    tree_emb = cKDTree(Z)

    rng = np.random.default_rng(SEED)
    hist_dists = None  # (physical, embedding, random) per-spot medians for k=20

    for k in [10, 20, 50]:
        # physical distances of embedding kNN (column 0 is the query itself)
        _, emb_idx = tree_emb.query(Z, k=k + 1)
        emb_phys_dists = _median_neighbor_dist(coords, emb_idx[:, 1:])

        # physical distances of true physical kNN
        _, phys_idx = tree_phys.query(coords, k=k + 1)
        phys_phys_dists = _median_neighbor_dist(coords, phys_idx[:, 1:])

        # random baseline: k distinct random spots per query spot, drawn one
        # row at a time so the random stream is consumed in spot order
        rand_idx = np.stack([rng.choice(n, k, replace=False) for _ in range(n)])
        rand_dists = _median_neighbor_dist(coords, rand_idx)

        print(f"\n  {name} | k={k}")
        print(f"    Physical kNN median dist:  {np.median(phys_phys_dists):.2f}")
        print(f"    Embedding kNN median dist: {np.median(emb_phys_dists):.2f}")
        print(f"    Random median dist:        {np.median(rand_dists):.2f}")

        if k == 20:
            # Keep these so the histogram shows exactly the distributions
            # whose medians were just printed, without recomputing them.
            hist_dists = (phys_phys_dists, emb_phys_dists, rand_dists)

    # Histogram for k=20
    phys_d, emb_d, rand_d = hist_dists

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(phys_d, bins=50, alpha=0.5, label='Physical kNN', density=True)
    ax.hist(emb_d, bins=50, alpha=0.5, label='Embedding kNN', density=True)
    ax.hist(rand_d, bins=50, alpha=0.5, label='Random', density=True)
    ax.set_xlabel('Median physical distance to k=20 neighbors')
    ax.set_ylabel('Density')
    ax.set_title(f'{name}: kNN Physical Distance Distributions')
    ax.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# === Analysis B: Moran's I on Embedding Dimensions (no libpysal) ===
from scipy.spatial import cKDTree
from scipy.stats import norm
from statsmodels.stats.multitest import multipletests

print("=" * 70)
print("ANALYSIS B: Spatial Autocorrelation of Embeddings (Moran's I)")
print("=" * 70)

def morans_i_fast(x, idx, n):
    """Moran's I and a two-sided p-value for ``x`` on a binary kNN graph.

    Parameters
    ----------
    x : (n,) array
        Value observed at each spot.
    idx : (n, k) integer array
        ``idx[i]`` lists the neighbours of spot ``i`` (self excluded). Each
        listed neighbour gets weight 1, so the weight matrix is directed and
        generally not symmetric.
    n : int
        Number of spots (rows of ``idx``).

    Returns
    -------
    (I, p) : tuple of float
        Moran's I and the two-sided p-value of the z-test that uses the
        expectation and variance of I under the normality assumption.
        Returns ``(0.0, 1.0)`` when ``x`` is (numerically) constant.
    """
    xm = x - x.mean()
    denom = np.sum(xm ** 2)
    if denom < 1e-12:
        return 0.0, 1.0

    k = idx.shape[1]
    s0 = float(n) * k  # sum of all weights
    # sum_i sum_j w_ij * xm_i * xm_j, with w_ij = 1 for j in idx[i]
    num = np.einsum('i,ij->', xm, xm[idx])
    I = (n / s0) * (num / denom)
    EI = -1.0 / (n - 1)

    # Variance of I under normality:
    #   Var = (n^2 S1 - n S2 + 3 S0^2) / ((n^2 - 1) S0^2) - E[I]^2
    # with S1 = 1/2 * sum_ij (w_ij + w_ji)^2 and
    #      S2 = sum_i (row_sum_i + col_sum_i)^2.
    # For 0/1 weights S1 = (#directed edges) + (#directed edges whose reverse
    # edge also exists), and row_sum_i = k, col_sum_i = in-degree of i.
    rows = np.repeat(np.arange(n, dtype=np.int64), k)
    cols = np.asarray(idx, dtype=np.int64).ravel()
    # Encode each directed edge as a single integer so reverse edges can be
    # matched with a set-membership test.
    n_mutual = np.isin(rows * n + cols, cols * n + rows).sum()
    s1 = s0 + float(n_mutual)
    in_deg = np.bincount(cols, minlength=n).astype(float)
    s2 = np.sum((k + in_deg) ** 2)

    nf = float(n)  # float arithmetic avoids int64 overflow for large n
    VI = (nf ** 2 * s1 - nf * s2 + 3.0 * s0 ** 2) / ((nf ** 2 - 1.0) * s0 ** 2) - EI ** 2
    if not VI > 0:
        # Degenerate graph: no usable variance, report "not significant".
        return I, 1.0

    z = (I - EI) / np.sqrt(VI)
    p = 2 * norm.sf(abs(z))
    return I, p

# For each training slide, compute Moran's I of every embedding dimension on
# the 10-nearest-neighbour graph of the raw physical coordinates, then report
# how many dimensions remain significant after FDR correction.
for name in ST_PATHS:
    coords = st_data[name].obsm['spatial']
    Z = Z_st[name].numpy()
    n, d = Z.shape

    # k=11 because the nearest hit of each query point is the point itself.
    tree = cKDTree(coords)
    _, idx = tree.query(coords, k=11)
    idx = idx[:, 1:]  # exclude self

    morans = np.zeros(d)
    pvals = np.zeros(d)
    for j in range(d):
        morans[j], pvals[j] = morans_i_fast(Z[:, j], idx, n)

    # Benjamini-Hochberg correction across the d embedding dimensions.
    _, pvals_fdr, _, _ = multipletests(pvals, method='fdr_bh')
    n_sig = np.sum(pvals_fdr < 0.05)

    # Embedding norm
    # Spatial autocorrelation of the per-spot embedding length.
    Z_norm = np.linalg.norm(Z, axis=1)
    I_norm, p_norm = morans_i_fast(Z_norm, idx, n)

    print(f"\n  {name}:")
    print(f"    Embedding dims: {d}")
    print(f"    Spatially autocorrelated dims (FDR<0.05): {n_sig}/{d} ({100*n_sig/d:.1f}%)")
    print(f"    Moran's I — mean: {morans.mean():.4f}, median: {np.median(morans):.4f}")
    print(f"    Moran's I of embedding norm: {I_norm:.4f} (p={p_norm:.4e})")

    fig, ax = plt.subplots(figsize=(6, 3))
    ax.hist(morans, bins=30, edgecolor='k', alpha=0.7)
    ax.axvline(0, color='red', linestyle='--', label='No autocorrelation')
    ax.set_xlabel("Moran's I")
    ax.set_ylabel('Count')
    ax.set_title(f"{name}: Moran's I across {d} embedding dims")
    ax.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# === Analysis E: Distribution Checks ===
# Compares simple per-spot and per-gene summaries of the expression matrices
# across all four slides to spot gross distribution shifts.
# Note: X_st / X_inf hold values taken after normalize_total + log1p, so
# "library size" below is the row sum of log-transformed values, not a raw
# count total.
print("=" * 70)
print("ANALYSIS E: Distribution Checks (ST vs Inference)")
print("=" * 70)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# all_stats collects the printed summaries per slide; it is not read again
# in this cell.
all_stats = {}
for name, X in {**X_st, **X_inf}.items():
    lib_size = X.sum(axis=1)
    genes_detected = (X > 0).sum(axis=1)
    # Percentage of exactly-zero entries over the whole matrix.
    sparsity = (X == 0).mean() * 100
    all_stats[name] = {
        'lib_size_median': np.median(lib_size),
        'lib_size_std': np.std(lib_size),
        'genes_detected_median': np.median(genes_detected),
        'sparsity_pct': sparsity,
    }
    print(f"  {name}:")
    print(f"    Library size — median: {np.median(lib_size):.1f}, std: {np.std(lib_size):.1f}")
    print(f"    Genes detected — median: {np.median(genes_detected):.0f}")
    print(f"    Sparsity: {sparsity:.1f}%")
    print()

# Library size distributions
for name, X in {**X_st, **X_inf}.items():
    axes[0].hist(X.sum(axis=1), bins=50, alpha=0.5, label=name, density=True)
axes[0].set_xlabel('Library size (post-norm)')
axes[0].set_title('Library Size Distribution')
axes[0].legend(fontsize=7)

# Genes detected
for name, X in {**X_st, **X_inf}.items():
    axes[1].hist((X > 0).sum(axis=1), bins=50, alpha=0.5, label=name, density=True)
axes[1].set_xlabel('Genes detected per spot')
axes[1].set_title('Detected Genes Distribution')
axes[1].legend(fontsize=7)

# Per-gene mean expression
for name, X in {**X_st, **X_inf}.items():
    gene_means = X.mean(axis=0)
    axes[2].hist(gene_means, bins=50, alpha=0.5, label=name, density=True)
axes[2].set_xlabel('Per-gene mean expression')
axes[2].set_title('Gene Mean Distribution')
axes[2].legend(fontsize=7)

plt.tight_layout()
plt.show()

In [None]:
# === Analysis C: Spatially Variable Genes ===
from scipy.spatial import cKDTree
import pandas as pd

print("=" * 70)
print("ANALYSIS C: Spatially Variable Gene Analysis")
print("=" * 70)

# Only the first training slide is analysed ([:1]).
for name in list(ST_PATHS.keys())[:1]:
    adata = st_data[name].copy()
    adata = adata[:, common_genes]
    coords = adata.obsm['spatial']
    X_dense = adata.X.toarray() if hasattr(adata.X, 'toarray') else adata.X
    n = X_dense.shape[0]

    # 10 nearest physical neighbours per spot (the first hit is the spot
    # itself and is dropped).
    tree = cKDTree(coords)
    _, idx = tree.query(coords, k=11)
    idx = idx[:, 1:]

    # Moran's I per gene on the binary kNN graph; constant genes get 0.
    morans_vals = np.empty(X_dense.shape[1])
    for g in range(X_dense.shape[1]):
        x = X_dense[:, g]
        xm = x - x.mean()
        denom = np.sum(xm ** 2)
        if denom < 1e-12:
            morans_vals[g] = 0.0
            continue
        num = np.einsum('i,ij->', xm, xm[idx])
        # Total weight = n spots x 10 neighbours; the 10 must match k=11 above.
        W = n * 10
        morans_vals[g] = (n / W) * (num / denom)

    svg_results = pd.DataFrame({'I': morans_vals}, index=common_genes)
    svg_results = svg_results.sort_values('I', ascending=False)

    # "Spatially variable" here is a plain threshold on I (> 0.1); no
    # significance test is performed.
    n_sig_svg = (svg_results['I'] > 0.1).sum()
    top300_mean_I = svg_results.head(300)['I'].mean()

    # Fraction of spots in which each gene is non-zero.
    detection_rate = (X_dense > 0).mean(axis=0)
    n_ultra_sparse = (detection_rate < 0.01).sum()

    print(f"\n  {name} ({len(common_genes)} common genes):")
    print(f"    Genes with Moran's I > 0.1: {n_sig_svg}")
    print(f"    Top 300 SVGs — mean Moran's I: {top300_mean_I:.4f}")
    print(f"    Ultra-sparse genes (<1% detection): {n_ultra_sparse}/{len(common_genes)}")
    print(f"\n  Top 20 SVGs:")
    print(svg_results.head(20)[['I']].to_string())

In [None]:
# === Ablation 2: Embedding kNN under different transforms ===
from scipy.spatial import cKDTree
from sklearn.decomposition import PCA

print("=" * 70)
print("ABLATION 2: Embedding kNN with different similarity transforms")
print("=" * 70)

K_VALUES = [10, 20, 50]

def compute_knn_phys_dist(Z_np, coords, k):
    """Per-spot median physical distance to the ``k`` nearest embedding neighbours.

    Neighbours are found by Euclidean distance in ``Z_np``; the first hit of
    each query (the point itself) is discarded. Distances are then measured
    in ``coords``. Returns an array with one value per spot.
    """
    _, nbrs = cKDTree(Z_np).query(Z_np, k=k + 1)
    nbrs = nbrs[:, 1:]
    offsets = coords[nbrs] - coords[:, None, :]
    phys = np.sqrt((offsets ** 2).sum(axis=2))
    return np.median(phys, axis=1)

# For the first training slide, repeat the "physical distance of embedding
# neighbours" measurement after several re-scalings / projections of the
# embedding, to see whether the choice of metric changes the picture.
for name in list(ST_PATHS.keys())[:1]:  # run on first ST slide
    coords = st_data[name].obsm['spatial']
    Z = Z_st[name].numpy()
    n = Z.shape[0]

    # Baselines
    tree_phys = cKDTree(coords)

    # Each entry is an (n, d') matrix on which Euclidean kNN is run. The
    # small 1e-8 terms guard against division by zero.
    # PCA with 64 components requires the embedding to have at least 64
    # dimensions.
    transforms = {
        'Raw embedding': Z,
        'L2-normalized (cosine)': Z / (np.linalg.norm(Z, axis=1, keepdims=True) + 1e-8),
        'Per-dim z-scored': (Z - Z.mean(0)) / (Z.std(0) + 1e-8),
        'PCA(Z) → 32d': PCA(n_components=32, random_state=42).fit_transform(Z),
        'PCA(Z) → 64d': PCA(n_components=64, random_state=42).fit_transform(Z),
    }

    for k in K_VALUES:
        print(f"\n  {name} | k={k}")
        print(f"  {'Transform':<30s} {'Median phys dist':>18s}")
        print(f"  {'-'*30} {'-'*18}")

        # Physical baseline
        _, phys_idx = tree_phys.query(coords, k=k + 1)
        phys_idx = phys_idx[:, 1:]
        phys_d = np.median(np.sqrt(((coords[phys_idx] - coords[:, None, :]) ** 2).sum(2)), axis=1)
        print(f"  {'Physical kNN (best)':<30s} {np.median(phys_d):>18.2f}")

        for tname, Z_t in transforms.items():
            emb_d = compute_knn_phys_dist(Z_t, coords, k)
            print(f"  {tname:<30s} {np.median(emb_d):>18.2f}")

        # Random baseline
        # The generator is re-created with the same seed for every k.
        rng = np.random.default_rng(42)
        rand_d = np.array([
            np.median(np.linalg.norm(coords[rng.choice(n, k, replace=False)] - coords[i], axis=1))
            for i in range(n)
        ])
        print(f"  {'Random (worst)':<30s} {np.median(rand_d):>18.2f}")

In [None]:
# === Ablation 3: Identifiability — crop a single spatial region ===
from scipy.spatial import cKDTree

print("=" * 70)
print("ABLATION 3: Embedding kNN within a spatial crop (single region)")
print("=" * 70)

# Fraction of spots kept in the central crop.
CROP_FRACTION = 0.25

for name in list(ST_PATHS.keys())[:1]:
    coords = st_data[name].obsm['spatial']
    Z = Z_st[name].numpy()
    n = Z.shape[0]

    # Crop: keep the CROP_FRACTION of spots closest to the slide's median coordinate.
    # The q-quantile of the centre distances encloses a fraction q of the spots, so the
    # quantile level has to equal the target fraction (a level of 0.50 would keep half
    # of the spots, not a quarter).
    cx, cy = np.median(coords, axis=0)
    dists_to_center = np.sqrt((coords[:, 0] - cx)**2 + (coords[:, 1] - cy)**2)
    radius = np.quantile(dists_to_center, CROP_FRACTION)
    mask = dists_to_center <= radius
    n_crop = mask.sum()

    coords_crop = coords[mask]
    Z_crop = Z[mask]

    print(f"  {name}: cropped {n_crop}/{n} spots (radius={radius:.1f})")

    # Both kNN searches are restricted to the cropped spots.
    tree_phys = cKDTree(coords_crop)
    tree_emb = cKDTree(Z_crop)

    for k in [10, 20]:
        # Not enough spots in the crop to find k neighbors besides the spot itself.
        if k >= n_crop:
            continue

        # Physical kNN (lower bound). Column 0 is dropped as the spot itself.
        _, phys_idx = tree_phys.query(coords_crop, k=k + 1)
        phys_idx = phys_idx[:, 1:]
        phys_d = np.median(np.sqrt(((coords_crop[phys_idx] - coords_crop[:, None, :]) ** 2).sum(2)), axis=1)

        # Embedding kNN, scored by the physical distance of the chosen neighbors.
        _, emb_idx = tree_emb.query(Z_crop, k=k + 1)
        emb_idx = emb_idx[:, 1:]
        emb_d = np.median(np.sqrt(((coords_crop[emb_idx] - coords_crop[:, None, :]) ** 2).sum(2)), axis=1)

        # Random neighbors inside the crop (upper bound), re-seeded for each k.
        rng = np.random.default_rng(42)
        rand_d = np.array([
            np.median(np.linalg.norm(coords_crop[rng.choice(n_crop, k, replace=False)] - coords_crop[i], axis=1))
            for i in range(n_crop)
        ])

        print(f"\n  Cropped region | k={k}")
        print(f"    Physical kNN median dist:  {np.median(phys_d):.2f}")
        print(f"    Embedding kNN median dist: {np.median(emb_d):.2f}")
        print(f"    Random median dist:        {np.median(rand_d):.2f}")
        # Position of the embedding result between the physical (0) and random (1) baselines;
        # the 1e-8 only guards against a zero denominator.
        ratio = (np.median(emb_d) - np.median(phys_d)) / (np.median(rand_d) - np.median(phys_d) + 1e-8)
        print(f"    Emb→Phys ratio (0=perfect, 1=random): {ratio:.3f}")

In [None]:
# === Ablation 1: Physical vs Embedding patch construction (ST4) ===
from scipy.spatial import cKDTree

banner = "=" * 70
print(banner, "ABLATION 1: Physical vs Embedding patch graph quality (ST4)", banner, sep="\n")

# Inference slide under test: its encoder embeddings and its spot coordinates.
name = 'liver_ST4'
Z = Z_inf[name].numpy()
coords = inf_data[name].obsm['spatial']
n = len(Z)

K_PATCH = 20  # typical miniset neighborhood size

# === Mode P: Physical kNN graph ===
# One extra neighbor is requested and column 0 (expected to be the spot itself) is dropped.
tree_phys = cKDTree(coords)
phys_idx = tree_phys.query(coords, k=K_PATCH + 1)[1][:, 1:]

# === Mode Z: Embedding kNN graph ===
tree_emb = cKDTree(Z)
emb_idx = tree_emb.query(Z, k=K_PATCH + 1)[1][:, 1:]
# --- Metric 1: Physical compactness of patches ---
def patch_compactness(coords, idx):
    """Physical radius of every patch.

    The radius of a patch is the largest Euclidean distance between the patch
    centroid (mean coordinate of its members) and any member. It is a radius,
    not a diameter / maximum pairwise distance.

    Parameters
    ----------
    coords : array-like, shape (n_spots, n_dims)
        Physical coordinates of all spots.
    idx : array-like of int, shape (n_patches, k)
        Row ``i`` lists the spot indices that form patch ``i``.

    Returns
    -------
    numpy.ndarray, shape (n_patches,)
        Radius of each patch, in the units of ``coords``.
    """
    coords = np.asarray(coords)
    idx = np.asarray(idx)
    if idx.size == 0:
        # No patches: keep returning an empty float array, as the per-patch loop did.
        return np.array([])

    members = coords[idx]                          # (n_patches, k, n_dims)
    centers = members.mean(axis=1, keepdims=True)  # centroid of each patch
    dist_to_center = np.sqrt(((members - centers) ** 2).sum(axis=2))
    return dist_to_center.max(axis=1)

# --- Metric 1 (report): physical radius of the patches built in each mode ---
phys_compact = patch_compactness(coords, phys_idx)
emb_compact = patch_compactness(coords, emb_idx)

print(f"\n  Patch compactness (physical radius of k={K_PATCH} neighborhood):")
for mode_label, radii in (("Physical", phys_compact), ("Embedding", emb_compact)):
    print(f"    {mode_label} patches — median: {np.median(radii):.2f}, mean: {np.mean(radii):.2f}")
print(f"    Ratio (emb/phys): {np.median(emb_compact) / np.median(phys_compact):.2f}x")

# --- Metric 2: Graph connectivity overlap ---
# Share of each spot's K_PATCH embedding neighbors that are also among its physical neighbors.
overlap_frac = np.array([
    np.intersect1d(phys_idx[i], emb_idx[i]).size / K_PATCH
    for i in range(n)
])

zero_overlap_pct = 100 * np.mean(overlap_frac == 0)
print("\n  Neighbor overlap (phys ∩ emb / k):")
print(f"    Mean overlap: {overlap_frac.mean():.4f}")
print(f"    Median overlap: {np.median(overlap_frac):.4f}")
print(f"    % spots with zero overlap: {zero_overlap_pct:.1f}%")

# --- Metric 3: Pairwise distance preservation ---
# For each patch, compute pairwise distances in coord space
# Compare consistency between the two modes
def pairwise_dist_stats(coords, idx):
    """Mean pairwise physical distance inside every patch.

    ``idx[i]`` holds the spot indices of patch ``i``. For each patch the
    Euclidean distances between all unordered pairs of its members are
    averaged. Returns one value per patch as a 1-D array.
    """
    def _mean_pair_distance(members):
        # Enumerate each unordered pair (a < b) once instead of building the full matrix.
        first, second = np.triu_indices(len(members), k=1)
        gaps = members[first] - members[second]
        return np.sqrt((gaps ** 2).sum(axis=1)).mean()

    return np.array([_mean_pair_distance(coords[patch]) for patch in idx])

# --- Metric 3 (report): within-patch pairwise distance for each mode ---
phys_pw = pairwise_dist_stats(coords, phys_idx)
emb_pw = pairwise_dist_stats(coords, emb_idx)

print("\n  Mean pairwise physical distance within patches:")
for mode_label, pw_means in (("Physical", phys_pw), ("Embedding", emb_pw)):
    print(f"    {mode_label} patches: {np.median(pw_means):.2f}")

# --- Visualization ---
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
ax_compact, ax_overlap, ax_pairwise = axes

# The outer panels share a layout: physical vs embedding histograms of one per-patch statistic.
comparison_panels = (
    (ax_compact, phys_compact, emb_compact, 'Patch radius', 'Patch Compactness'),
    (ax_pairwise, phys_pw, emb_pw, 'Mean pairwise distance', 'Within-Patch Distances'),
)
for panel, phys_vals, emb_vals, x_label, panel_title in comparison_panels:
    panel.hist(phys_vals, bins=50, alpha=0.6, label='Physical', density=True)
    panel.hist(emb_vals, bins=50, alpha=0.6, label='Embedding', density=True)
    panel.set_xlabel(x_label)
    panel.set_title(panel_title)
    panel.legend()

# Middle panel: distribution of the per-spot neighbor overlap.
ax_overlap.hist(overlap_frac, bins=50, edgecolor='k', alpha=0.7)
ax_overlap.set_xlabel('Neighbor overlap fraction')
ax_overlap.set_title(f'Physical ∩ Embedding (k={K_PATCH})')

plt.suptitle(f'{name}: Physical vs Embedding Patch Construction', fontweight='bold')
plt.tight_layout()
plt.show()

# --- Summary verdict ---
# Decided from the median radius ratio and, for the strongest verdict, the median overlap.
ratio = np.median(emb_compact) / np.median(phys_compact)
median_overlap = np.median(overlap_frac)
print("\n  VERDICT:")
if not ratio > 1.5:
    print(f"    Embedding patches are comparable ({ratio:.1f}x). kNN graph is reasonable.")
elif ratio > 3.0 and median_overlap < 0.1:
    print(f"    Embedding patches are {ratio:.1f}x larger with {median_overlap*100:.1f}% overlap.")
    print(f"    → Embedding kNN is NOT spatially local. Physical patches should be used for training.")
else:
    print(f"    Embedding patches are {ratio:.1f}x larger — moderate mismatch.")
    print(f"    → Consider hybrid approach or tighter embedding constraints.")

In [None]:
print("\n" + "=" * 70)
print("EVALUATION: Domain Mixing Analysis")
print("=" * 70)

# Maximum number of rows kept per source by subsample() below.
N_MAX = 5000
# Number of nearest neighbors used by the kNN domain-mixing metric.
K = 20
# set_seed comes from ssl_utils; presumably it seeds the torch RNG behind the
# torch.randperm call in subsample() — TODO confirm.
set_seed(SEED)

def subsample(X, n_max):
    """Return at most ``n_max`` rows of ``X`` as a torch tensor.

    Non-tensor input is converted to a float32 tensor first. When ``X`` has
    more than ``n_max`` rows, ``n_max`` of them are picked via
    ``torch.randperm``; otherwise the tensor is returned unchanged.
    """
    rows = X if isinstance(X, torch.Tensor) else torch.tensor(X, dtype=torch.float32)
    n_rows = rows.shape[0]
    if n_rows > n_max:
        keep = torch.randperm(n_rows)[:n_max]
        rows = rows[keep]
    return rows

# Cap every source at N_MAX rows.
Z_st_sub = {src: subsample(Z_src, N_MAX) for src, Z_src in Z_st.items()}
Z_inf_sub = {src: subsample(Z_src, N_MAX) for src, Z_src in Z_inf.items()}

# One (name, embeddings) pair per source, training slides first. Z_all, the domain labels,
# the binary labels and the centroids are all derived from this single ordering.
sources = list(Z_st_sub.items()) + list(Z_inf_sub.items())
source_names = [src for src, _ in sources]
S = len(source_names)

Z_all = torch.cat([Z_src for _, Z_src in sources], dim=0)
Z_all_norm = F.normalize(Z_all, dim=1)
N = Z_all.shape[0]

# Domain labels: integer id of the source that each row of Z_all came from.
domain_labels = torch.cat([
    torch.full((Z_src.shape[0],), dom, dtype=torch.long)
    for dom, (_, Z_src) in enumerate(sources)
])

print(f"Sources: {source_names}, Total: {N}, Domains: {S}")

# --- kNN Domain-Mixing ---
print(f"\n[METRIC 1] kNN Domain-Mixing (k={K})")
D = torch.cdist(Z_all_norm, Z_all_norm)
D.fill_diagonal_(float('inf'))  # a spot must not be counted as its own neighbor
_, knn_idx = torch.topk(D, k=K, dim=1, largest=False)
knn_labels = domain_labels[knn_idx]

# p[i, s] = fraction of spot i's K neighbors that belong to source s.
p = torch.zeros(N, S)
for s in range(S):
    p[:, s] = (knn_labels == s).float().mean(dim=1)

eps = 1e-10  # keeps log() finite where p == 0
# Shannon entropy of p divided by log(S): 1 when neighbors are spread evenly over all sources.
H_norm = (-torch.sum(p * torch.log(p + eps), dim=1) / np.log(S)).mean().item()
# Inverse Simpson index 1/sum(p^2), rescaled from [1, S] to [0, 1].
iLISI_norm = ((1.0 / torch.sum(p ** 2, dim=1) - 1) / (S - 1)).mean().item()

print(f"  Normalized Neighbor Entropy: {H_norm:.4f}")
print(f"  Normalized iLISI: {iLISI_norm:.4f}")

# --- Domain Classification (CV) ---
print(f"\n[METRIC 2] Domain Classification (5-fold CV)")
Z_np, y_np = Z_all_norm.numpy(), domain_labels.numpy()
clf = LogisticRegression(max_iter=2000, random_state=42, class_weight='balanced', n_jobs=-1)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
cv_scores = cross_val_score(clf, Z_np, y_np, cv=cv, scoring='balanced_accuracy')
chance = 1.0 / S
print(f"  Balanced Acc: {cv_scores.mean():.4f} ± {cv_scores.std():.4f} (chance={chance:.4f})")

# --- Binary: Training ST vs Inference ST4 ---
print(f"\n[METRIC 3] Binary ST(1-3) vs ST4 (5-fold CV)")
# Training rows come first in Z_all, so the first n_st_total rows get label 0.
n_st_total = sum(Z_src.shape[0] for Z_src in Z_st_sub.values())
y_binary = torch.cat([torch.zeros(n_st_total, dtype=torch.long),
                       torch.ones(N - n_st_total, dtype=torch.long)]).numpy()
cv_binary = cross_val_score(clf, Z_np, y_binary, cv=cv, scoring='balanced_accuracy')
print(f"  Balanced Acc (ST vs ST4): {cv_binary.mean():.4f} ± {cv_binary.std():.4f}")

# --- Centroid Analysis ---
# Centroids are taken on the un-normalized subsampled embeddings.
print(f"\n[CENTROID ANALYSIS]")
centroids = {src: Z_src.mean(dim=0) for src, Z_src in sources}

names = list(centroids)
# The header uses the same layout as the rows (12-wide label, then two spaces + 12-wide cell)
# so that column titles sit directly above their values.
print(f"{'':12s}" + "".join(f"  {col:>12}" for col in names))
for n1 in names:
    row = f"{n1:12s}"
    for n2 in names:
        row += f"  {(centroids[n1] - centroids[n2]).norm().item():12.4f}"
    print(row)

# --- Summary ---
# Metric labels report the K actually used instead of a hard-coded value.
entropy_label = f"Neighbor Entropy (k={K})"
ilisi_label = f"iLISI (k={K})"
print("\n" + "=" * 70)
print("SUMMARY")
print("=" * 70)
print(f"{'Metric':<40} {'Value':<15}")
print("-" * 55)
print(f"{entropy_label:<40} {H_norm:<15.4f}")
print(f"{ilisi_label:<40} {iLISI_norm:<15.4f}")
print(f"{'Domain Class. Acc (CV)':<40} {cv_scores.mean():<15.4f}")
print(f"{'ST vs ST4 Acc (CV)':<40} {cv_binary.mean():<15.4f}")
print("=" * 70)

In [None]:
print("\n" + "=" * 70)
print("PCA COMPARISON: Raw vs Embeddings")
print("=" * 70)

# Maximum number of spots drawn per source for the PCA scatter plots.
N_VIS = 3000
# set_seed comes from ssl_utils; presumably it seeds the torch RNG used by the
# torch.randperm call in subsample_matched() — TODO confirm.
set_seed(SEED)

# Subsample matched
def subsample_matched(X, Z, n_max):
    """Draw the same random rows from ``X`` and ``Z`` and return them as NumPy arrays.

    Both inputs are converted to float32 CPU tensors. A single permutation of
    the rows of ``X`` (via ``torch.randperm``) is truncated to
    ``min(len(X), n_max)`` entries and applied to both inputs, so row ``i`` of
    the two outputs refers to the same original row. The rows are shuffled even
    when nothing is dropped.
    """
    def _to_cpu_float(values):
        if isinstance(values, torch.Tensor):
            return values.float().cpu()
        return torch.tensor(values, dtype=torch.float32).cpu()

    X_t = _to_cpu_float(X)
    Z_t = _to_cpu_float(Z)
    total = X_t.shape[0]
    picked = torch.randperm(total)[:min(total, n_max)]
    return X_t[picked].numpy(), Z_t[picked].numpy()

# Draw matching rows from the raw expression and the embedding of every source, so both
# PCA panels show the same spots. labels_str and colors_map are not defined in this cell;
# they are expected to exist from an earlier one.
X_vis, Z_vis, labels_vis = [], [], []
for name in labels_str:
    src = X_st if name in X_st else X_inf
    src_z = Z_st if name in Z_st else Z_inf
    x_sub, z_sub = subsample_matched(src[name], src_z[name], N_VIS)
    X_vis.append(x_sub)
    Z_vis.append(z_sub)
    labels_vis.extend([name] * x_sub.shape[0])

X_vis = np.vstack(X_vis)
Z_vis = np.vstack(Z_vis)
labels_vis = np.array(labels_vis)

# random_state is pinned (as for the other PCA calls in this notebook) so the projection
# does not depend on global RNG state when scikit-learn selects a randomized SVD solver.
pca_x = PCA(n_components=2, random_state=SEED).fit(X_vis)
X_pca = pca_x.transform(X_vis)
var_x = pca_x.explained_variance_ratio_

pca_z = PCA(n_components=2, random_state=SEED).fit(Z_vis)
Z_pca = pca_z.transform(Z_vis)
var_z = pca_z.explained_variance_ratio_

# plt.rcParams.update({'font.family': 'Arial', 'font.weight': 'bold',
#                      'axes.labelweight': 'bold', 'axes.titleweight': 'bold'})

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
titles = [
    f"(a) PCA of raw expression (log1p)\nVar: PC1={var_x[0]:.1%}, PC2={var_x[1]:.1%}",
    f"(b) PCA of encoder embeddings\nVar: PC1={var_z[0]:.1%}, PC2={var_z[1]:.1%}",
]
# Same styling for both panels; one scatter call per source so each gets its own color.
for ax, data, title in zip(axes, [X_pca, Z_pca], titles):
    for label in labels_str:
        mask = labels_vis == label
        ax.scatter(data[mask, 0], data[mask, 1], c=colors_map[label], label=label,
                   s=25, alpha=0.65, edgecolors='white', linewidths=0.5)
    ax.set_xlabel('PC1', fontsize=13, fontweight='bold')
    ax.set_ylabel('PC2', fontsize=13, fontweight='bold')
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.set_facecolor('#FAFAFA')
    ax.tick_params(labelsize=11)
    for spine in ax.spines.values():
        spine.set_color('#CCCCCC')
        spine.set_linewidth(1.5)

# A single legend on the left panel; the right panel uses the same colors.
axes[0].legend(loc='best', fontsize=11, framealpha=0.98, edgecolor='#888888')
plt.tight_layout()
# plt.savefig('liver_pca_comparison.svg', format='svg', bbox_inches='tight', dpi=600)
plt.show()

In [None]:
# print("\n" + "=" * 70)
# print("t-SNE COMPARISON: Raw vs Embeddings")
# print("=" * 70)

# N_TSNE = 2000
# set_seed(SEED)

# X_tsne_all, Z_tsne_all, labels_tsne = [], [], []
# for name in labels_str:
#     src = X_st if name in X_st else X_inf
#     src_z = Z_st if name in Z_st else Z_inf
#     x_sub, z_sub = subsample_matched(src[name], src_z[name], N_TSNE)
#     X_tsne_all.append(x_sub)
#     Z_tsne_all.append(z_sub)
#     labels_tsne.extend([name] * x_sub.shape[0])

# X_tsne_all = np.vstack(X_tsne_all)
# Z_tsne_all = np.vstack(Z_tsne_all)
# labels_tsne = np.array(labels_tsne)
# Z_tsne_norm = F.normalize(torch.tensor(Z_tsne_all), dim=1).numpy()

# print("Computing t-SNE for raw expression...")
# X_tsne_proj = TSNE(n_components=2, random_state=42, perplexity=30, n_iter=1000).fit_transform(X_tsne_all)

# print("Computing t-SNE for embeddings...")
# Z_tsne_proj = TSNE(n_components=2, random_state=42, perplexity=30, n_iter=1000).fit_transform(Z_tsne_norm)

# fig, axes = plt.subplots(1, 2, figsize=(18, 7))
# fig.patch.set_facecolor('white')

# for ax, data, title in zip(axes, [X_tsne_proj, Z_tsne_proj],
#     ['(A) t-SNE on Raw Expression (Before)', '(B) t-SNE on Embeddings (After)']):
#     for label in labels_str:
#         mask = labels_tsne == label
#         ax.scatter(data[mask, 0], data[mask, 1], c=colors_map[label],
#                    label=label, alpha=0.5, s=20, edgecolors='none')
#     ax.set_title(title, fontsize=14, fontweight='bold')
#     ax.set_xlabel('t-SNE 1', fontsize=12)
#     ax.set_ylabel('t-SNE 2', fontsize=12)
#     ax.legend(loc='best', fontsize=10, frameon=True)
#     ax.grid(alpha=0.3, linestyle='--')

# plt.suptitle('Mouse Liver: Before vs After Alignment (Cross-Slide)',
#              fontsize=16, fontweight='bold', y=0.98)
# plt.tight_layout()
# plt.show()

In [None]:
# ===================================================================
# EVALUATION METRICS
# ===================================================================
print("\n" + "=" * 70)
print("EVALUATION: Liver Cross-Slide Encoder")
print("=" * 70)

# Maximum number of rows kept per source by subsample() below.
N_MAX = 5000
# set_seed comes from ssl_utils; presumably it seeds the torch RNG behind the
# torch.randperm call in subsample() — TODO confirm.
set_seed(SEED)

def subsample(X, n_max, device='cpu'):
    """Return at most ``n_max`` rows of ``X`` as a tensor on ``device``.

    Tensors are moved to ``device``; anything else is converted to a float32
    tensor created on ``device``. When there are more than ``n_max`` rows,
    ``n_max`` of them are picked via ``torch.randperm`` on that device.
    """
    if isinstance(X, torch.Tensor):
        rows = X.to(device)
    else:
        rows = torch.tensor(X, dtype=torch.float32, device=device)

    n_rows = rows.shape[0]
    if n_rows > n_max:
        keep = torch.randperm(n_rows, device=device)[:n_max]
        rows = rows[keep]
    return rows

# Cap every source at N_MAX rows, on the CPU. Training slides are drawn first.
Z_st_sub = {}
for name, Z in Z_st.items():
    Z_st_sub[name] = subsample(Z, N_MAX, 'cpu')
Z_inf_sub = {}
for name, Z in Z_inf.items():
    Z_inf_sub[name] = subsample(Z, N_MAX, 'cpu')

print("Subsampled embeddings:")
all_sub = dict(Z_st_sub)
all_sub.update(Z_inf_sub)
for name, Z in all_sub.items():
    print(f"  {name}: {Z.shape}")

# ===================================================================
# TEST 1: ST(1-3) vs ST4 kNN mixing
# ===================================================================
print("\n[ST-vs-ST4 MIXING] kNN domain distribution:")

# Stack training rows first, inference rows last.
Z_st_all = torch.cat(list(Z_st_sub.values()), dim=0)
Z_inf_all = torch.cat(list(Z_inf_sub.values()), dim=0)
Z_all = torch.cat([Z_st_all, Z_inf_all], dim=0)

n_st, n_inf = Z_st_all.shape[0], Z_inf_all.shape[0]
n_total = n_st + n_inf

# Domain labels (0=ST training, 1=ST4 inference): rows at position >= n_st are inference rows.
domain_labels = (torch.arange(n_total) >= n_st).long()

Z_all_norm = F.normalize(Z_all, dim=1)

K = 20
inf_start = n_st

# Distances from every inference row to all rows; each inference row's distance to
# itself is set to +inf so it cannot be picked as its own neighbor.
D_inf = torch.cdist(Z_all_norm[inf_start:], Z_all_norm)
inf_rows = torch.arange(n_inf)
D_inf[inf_rows, inf_start + inf_rows] = float('inf')

knn_inf = torch.topk(D_inf, k=K, dim=1, largest=False).indices
frac_same_inf = (domain_labels[knn_inf] == 1).float().mean().item()
base_rate_inf = n_inf / n_total

print(f"  ST4 neighbors (K={K}):")
print(f"    Same-domain (ST4) fraction: {frac_same_inf:.4f}")
print(f"    Base rate (chance):         {base_rate_inf:.4f}")

# The first margin above the base rate that the observed fraction stays under picks the verdict.
mixing_verdict = "    ⚠️ ST4 may be clustering"
for margin, message in (
    (0.10, "    ✓ EXCELLENT mixing (ST4 not clustering)"),
    (0.20, "    ✓ Good mixing"),
):
    if frac_same_inf < base_rate_inf + margin:
        mixing_verdict = message
        break
print(mixing_verdict)

# ===================================================================
# TEST 2: Linear Probe (4-class separability)
# ===================================================================
n_sources = len(ST_PATHS) + len(INF_PATHS)
print(f"\n[{n_sources}-CLASS PROBE] Linear separability test:")

# One class per source: training slides in ST_PATHS order, then inference slides.
source_sizes = [Z_st_sub[name].shape[0] for name in ST_PATHS]
source_sizes += [Z_inf_sub[name].shape[0] for name in INF_PATHS]
class_labels_list = [
    torch.full((n_cells,), class_idx, dtype=torch.long)
    for class_idx, n_cells in enumerate(source_sizes)
]
class_labels = torch.cat(class_labels_list)

Z_np = Z_all_norm.numpy()
y_np = class_labels.numpy()

# The probe is scored on the same rows it was fitted on.
probe = LogisticRegression(max_iter=5000, random_state=42, class_weight='balanced')
pred = probe.fit(Z_np, y_np).predict(Z_np)
bal_acc = balanced_accuracy_score(y_np, pred)
chance = 1.0 / n_sources

print(f"  Balanced accuracy: {bal_acc:.4f} (chance={chance:.3f})")

# The first accuracy ceiling the probe stays under picks the verdict.
probe_verdict = "  ⚠️ Sources are separable"
for ceiling, message in (
    (0.30, "  ✓ EXCELLENT: Sources are very well-mixed"),
    (0.40, "  ✓ Good: Moderate mixing"),
    (0.50, "  ~ Partial alignment"),
):
    if bal_acc < ceiling:
        probe_verdict = message
        break
print(probe_verdict)

# ===================================================================
# TEST 3: Centroid distances
# ===================================================================
print("\n[CENTROID ANALYSIS] Cross-source distances:")

# Mean embedding of every source, training slides first.
centroids = {name: Z_st_sub[name].mean(dim=0) for name in ST_PATHS}
centroids.update({name: Z_inf_sub[name].mean(dim=0) for name in INF_PATHS})


def _centroid_gap(a, b):
    """Euclidean distance between the centroids of sources ``a`` and ``b``."""
    return (centroids[a] - centroids[b]).norm().item()


names = list(centroids)
print("\nCentroid distance matrix:")
print("             ", "  ".join([f"{n:>12}" for n in names]))

for n1 in names:
    print(f"{n1:12s}" + "".join(f"  {_centroid_gap(n1, n2):12.4f}" for n2 in names))

# ST-to-ST vs ST-to-ST4 distances
st_names = list(ST_PATHS)
inf_names = list(INF_PATHS)

# Every unordered pair of training slides, then every (training, inference) pair.
st_to_st = [
    _centroid_gap(n1, n2)
    for i, n1 in enumerate(st_names)
    for n2 in st_names[i + 1:]
]
st_to_inf = [_centroid_gap(n1, n2) for n1 in st_names for n2 in inf_names]

print(f"\nAverage distances:")
print(f"  ST-to-ST:  {np.mean(st_to_st):.4f} ± {np.std(st_to_st):.4f}")
print(f"  ST-to-ST4: {np.mean(st_to_inf):.4f} ± {np.std(st_to_inf):.4f}")
ratio = np.mean(st_to_inf) / np.mean(st_to_st)
print(f"  Ratio (ST4/ST): {ratio:.2f}x")

# The first ratio ceiling the result stays under picks the verdict.
gap_verdict = "  ⚠️ Large slide gap remains"
for ceiling, message in (
    (1.5, "  ✓ EXCELLENT: ST4 very close to training ST centroids"),
    (2.5, "  ✓ Good: Reasonable cross-slide distance"),
):
    if ratio < ceiling:
        gap_verdict = message
        break
print(gap_verdict)

# ===================================================================
# SUMMARY
# ===================================================================
# Recap of the three tests above, one line per metric.
# NOTE(review): the leading "✓" is printed unconditionally; it does not reflect whether the
# corresponding test passed its thresholds above.
print("\n" + "=" * 70)
print("LIVER ENCODER EVALUATION SUMMARY")
print("=" * 70)
print(f"✓ ST4 mixing:        same-domain frac = {frac_same_inf:.4f} (base = {base_rate_inf:.4f})")
print(f"✓ {n_sources}-class probe:   balanced acc = {bal_acc:.4f} (chance = {chance:.3f})")
print(f"✓ Centroid ratio:    ST-to-ST4 / ST-to-ST = {ratio:.2f}x")
print("=" * 70)