In [None]:
# ===================================================================
# TRAIN STAGE A: VICReg + Domain Adversary (3 ST slides + SC)
# ===================================================================
import torch
import torch.nn.functional as F
import scanpy as sc
import numpy as np
import sys
sys.path.insert(0, '/home/ehtesamul/sc_st/model')

from core_models_et_p1 import SharedEncoder, train_encoder
import utils_et as uet

# Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

_rule = "=" * 70
for _banner_line in (_rule, "TRAIN STAGE A: VICReg + Domain Adversary (3 ST + SC)", _rule):
    print(_banner_line)

def subsample_domain(X, n_max):
    """Randomly subsample the rows of one domain to at most ``n_max``.

    Args:
        X: Tensor whose first dimension indexes samples.
        n_max: Maximum number of rows to keep.

    Returns:
        ``X`` itself when it has ``<= n_max`` rows; otherwise a new tensor of
        ``n_max`` rows drawn without replacement, in random order. The
        permutation comes from the global torch RNG, so call
        ``torch.manual_seed`` first for a reproducible subset.
    """
    n = X.shape[0]
    if n <= n_max:
        return X
    # Draw the permutation on X's own device instead of the notebook-global
    # `device`. This removes a hidden global dependency and avoids a device
    # mismatch when X lives elsewhere (a CPU tensor cannot be indexed by a
    # CUDA index tensor).
    idx = torch.randperm(n, device=X.device)[:n_max]
    return X[idx]

# ===================================================================
# 1) LOAD DATA
# ===================================================================
print("\n--- Loading HSCC data ---")
scadata = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/scP2.h5ad')
stadata1 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP2.h5ad')
stadata2 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP2rep2.h5ad')
stadata3 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP2rep3.h5ad')

# Normalize each dataset in place: total-count normalization then log1p,
# both with scanpy's default arguments.
for adata in [scadata, stadata1, stadata2, stadata3]:
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)

# Genes present in all four datasets, sorted so every expression matrix
# below shares one column order.
common = sorted(
    set(scadata.var_names)
    & set(stadata1.var_names)
    & set(stadata2.var_names)
    & set(stadata3.var_names)
)
n_genes = len(common)
print(f"✓ Common genes: {n_genes}")

# ===================================================================
# 2) PREPARE TRAINING DATA - ALL 3 ST SLIDES + SC
# ===================================================================
# SC expression. Anything exposing .toarray() (i.e. a sparse .X) is
# densified before the torch tensor is built.
X_sc = scadata[:, common].X
if hasattr(X_sc, "toarray"):
    X_sc = X_sc.toarray()
sc_expr = torch.tensor(X_sc, dtype=torch.float32, device=device)

# ST expression from ALL 3 slides. Each is densified the same way, then they
# are stacked in slide order 0, 1, 2, matching slide_ids below.
X_st1, X_st2, X_st3 = (
    X.toarray() if hasattr(X, "toarray") else X
    for X in (stadata1[:, common].X, stadata2[:, common].X, stadata3[:, common].X)
)
st_expr = torch.tensor(np.vstack([X_st1, X_st2, X_st3]), dtype=torch.float32, device=device)

# ST coordinates (required for function signature, but not used by VICReg)
st_coords1 = stadata1.obsm['spatial']
st_coords2 = stadata2.obsm['spatial']
st_coords3 = stadata3.obsm['spatial']
st_coords_raw = torch.tensor(np.vstack([st_coords1, st_coords2, st_coords3]),
                             dtype=torch.float32, device=device)

# Slide IDs: one integer label per stacked ST row (0, 1, 2).
slide_ids = torch.tensor(
    np.concatenate([
        np.zeros(X_st1.shape[0], dtype=int),    # slide 0
        np.ones(X_st2.shape[0], dtype=int),     # slide 1
        np.full(X_st3.shape[0], 2, dtype=int)   # slide 2
    ]),
    dtype=torch.long, device=device
)

# Canonicalize coordinates per slide (for function signature)
st_coords, st_mu, st_scale = uet.canonicalize_st_coords_per_slide(st_coords_raw, slide_ids)

print(f"✓ SC expr: {sc_expr.shape}")
print(f"✓ ST expr: {st_expr.shape}")
print(f"✓ ST coords: {st_coords.shape}")
print(f"✓ Slide IDs: {slide_ids.shape} (slides: {torch.unique(slide_ids).tolist()})")
print("✓ Training will use 4 domains: ST-slide0, ST-slide1, ST-slide2, SC")

# ===================================================================
# 3) CREATE AND TRAIN ENCODER WITH VICREG + DOMAIN ADVERSARY
# ===================================================================
print("\n" + "="*70)
print("TRAINING STAGE A ENCODER (VICReg + Domain Adversary)")
print("="*70)

encoder_vicreg = SharedEncoder(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    dropout=0.1
)

# With return_aux=True the result is unpacked into four values
# (encoder, projector, discriminator, hist). Only the encoder is used below.
encoder_vicreg, projector, discriminator, hist = train_encoder(
    model=encoder_vicreg,
    st_gene_expr=st_expr,
    st_coords=st_coords,
    sc_gene_expr=sc_expr,
    slide_ids=slide_ids,
    n_epochs=1000,  # Short run to test the fix quickly
    batch_size=256,
    lr=1e-3,
    device=device,
    outf='/home/ehtesamul/sc_st/model/gems_hscc_vicreg_output/run41_fixed',
    # ========== VICReg Mode ==========
    stageA_obj='vicreg_adv',
    vicreg_lambda_inv=25.0,
    vicreg_lambda_var=25.0,
    vicreg_lambda_cov=1.0,
    vicreg_gamma=1.0,
    vicreg_eps=1e-4,
    vicreg_project_dim=256,
    vicreg_use_projector=True,
    vicreg_float32_stats=True,
    vicreg_ddp_gather=False,
    # Expression augmentations (SAME)
    aug_gene_dropout=0.5,
    aug_gauss_std=0.05,
    aug_scale_jitter=0.4,
    # Domain adversary (INCREASED WEIGHT)
    adv_slide_weight=50.0,  # ← INCREASED from 20.0
    adv_warmup_epochs=50,
    adv_ramp_epochs=200,
    grl_alpha_max=1.0,
    disc_hidden=256,
    disc_dropout=0.1,
    # Balanced domain sampling
    stageA_balanced_slides=True,
    # ========== FIX FROM RUN 1 ==========
    adv_representation_mode='clean',
    adv_use_layernorm=False,  # Don't use LayerNorm
    adv_log_diagnostics=True,
    adv_log_grad_norms=False,
    # ========== LOCAL ALIGNMENT (optional - try without first) ==========
    use_local_align=False,  # ← Disabled for now; test the adversary fix first
    return_aux=True
)

print("\n✓ VICReg Stage A training complete!")

# Save encoder
import os
os.makedirs('/home/ehtesamul/sc_st/model/gems_hscc_vicreg_output', exist_ok=True)
# NOTE(review): this run logs to .../run41_fixed, but the checkpoint is written
# to the same encoder_vicreg.pt that the baseline VICReg cell (adv weight 20)
# writes. Running both cells overwrites one checkpoint with the other.
torch.save(encoder_vicreg.state_dict(), 
           '/home/ehtesamul/sc_st/model/gems_hscc_vicreg_output/encoder_vicreg.pt')
print("✓ Encoder saved to: encoder_vicreg.pt")

# ===================================================================
# 4) EVALUATE DOMAIN MIXING (ALL 4 DOMAINS: 3 ST + SC)
# ===================================================================
print("\n" + "="*70)
print("EVALUATION: Domain Mixing (3 ST slides + SC)")
print("="*70)

N_MAX = 2000
# Seed before subsampling so the evaluated subsets, and every metric computed
# from them below, are reproducible across re-runs. The slide-level evaluation
# cells seed with 42 the same way.
torch.manual_seed(42)

# Subsample all 4 domains (draw order: ST0, ST1, ST2, SC)
X1 = torch.tensor(X_st1, dtype=torch.float32, device=device)
X2 = torch.tensor(X_st2, dtype=torch.float32, device=device)
X3 = torch.tensor(X_st3, dtype=torch.float32, device=device)
X_sc_torch = torch.tensor(X_sc, dtype=torch.float32, device=device)

X1_sub = subsample_domain(X1, N_MAX)
X2_sub = subsample_domain(X2, N_MAX)
X3_sub = subsample_domain(X3, N_MAX)
X_sc_sub = subsample_domain(X_sc_torch, N_MAX)

# Compute embeddings in eval mode (dropout off) without autograd tracking.
encoder_vicreg.eval()
with torch.no_grad():
    Z1 = encoder_vicreg(X1_sub)
    Z2 = encoder_vicreg(X2_sub)
    Z3 = encoder_vicreg(X3_sub)
    Z_sc = encoder_vicreg(X_sc_sub)

print(f"Z1 (ST-slide0): {Z1.shape}")
print(f"Z2 (ST-slide1): {Z2.shape}")
print(f"Z3 (ST-slide2): {Z3.shape}")
print(f"Z_sc (SC):      {Z_sc.shape}")

# ===================================================================
# TEST 1: Expression Mixing (baseline)
# ===================================================================
print("\n[EXPR-MIXING] Expression space kNN domain distribution:")
X_all = torch.cat([X1_sub, X2_sub, X3_sub, X_sc_sub], dim=0)
# Rows are L2-normalized, so Euclidean distance is monotone in cosine similarity.
X_all_norm = F.normalize(X_all, dim=1)

n1, n2, n3, n_sc = X1_sub.shape[0], X2_sub.shape[0], X3_sub.shape[0], X_sc_sub.shape[0]
n_total = n1 + n2 + n3 + n_sc

# Domain labels: 0=ST1, 1=ST2, 2=ST3, 3=SC
labels = torch.cat([
    torch.zeros(n1, dtype=torch.long, device=device),
    torch.ones(n2, dtype=torch.long, device=device),
    torch.full((n3,), 2, dtype=torch.long, device=device),
    torch.full((n_sc,), 3, dtype=torch.long, device=device)
])

K_mix = 20

# SC rows occupy [sc_start, n_total) of X_all; each is compared to every row.
sc_start = n1 + n2 + n3
D_expr_sc = torch.cdist(X_all_norm[sc_start:], X_all_norm)

# Exclude self-matches in one vectorized assignment: query row i is column sc_start + i.
sc_rows = torch.arange(n_sc, device=D_expr_sc.device)
D_expr_sc[sc_rows, sc_start + sc_rows] = float('inf')

_, knn_expr_sc = torch.topk(D_expr_sc, k=K_mix, dim=1, largest=False)

# Per-SC-cell fraction of its K neighbours that are also SC (label 3).
# Each row's float32 mean is cast to float64, matching per-row .item() values.
frac_same_expr = (labels[knn_expr_sc] == 3).float().mean(dim=1).cpu().numpy().astype(np.float64)
base_rate = n_sc / n_total

print(f"  SC neighbors (K={K_mix}):")
print(f"    Same-domain fraction: {frac_same_expr.mean():.4f}")
print(f"    Base rate (chance):   {base_rate:.4f}")
print(f"    → Expression space shows domain clustering (expected)")

# ===================================================================
# TEST 2: Z Mixing (should approach base_rate if working)
# ===================================================================
print("\n[Z-MIXING] Embedding space kNN domain distribution:")
Z_all = torch.cat([Z1, Z2, Z3, Z_sc], dim=0)
Z_all_norm = F.normalize(Z_all, dim=1)

D_emb_sc = torch.cdist(Z_all_norm[sc_start:], Z_all_norm)

# Exclude self-matches (query row i is column sc_start + i) in one assignment.
sc_rows = torch.arange(n_sc, device=D_emb_sc.device)
D_emb_sc[sc_rows, sc_start + sc_rows] = float('inf')

_, knn_emb_sc = torch.topk(D_emb_sc, k=K_mix, dim=1, largest=False)

# Per-SC-cell fraction of embedding-space neighbours that are also SC.
frac_same_z = (labels[knn_emb_sc] == 3).float().mean(dim=1).cpu().numpy().astype(np.float64)

print(f"  SC neighbors (K={K_mix}):")
print(f"    Same-domain fraction: {frac_same_z.mean():.4f}")
print(f"    Base rate (chance):   {base_rate:.4f}")

# Share of the expression-space excess over chance that the embedding removes:
# 1.0 means Z neighbourhoods are at chance, 0.0 means no better than raw expression.
excess_expr = frac_same_expr.mean() - base_rate
if abs(excess_expr) > 1e-12:
    improvement = (frac_same_expr.mean() - frac_same_z.mean()) / excess_expr
    print(f"  → Mixing improvement: {improvement*100:.1f}%")
else:
    # The expression space is already at chance, so the ratio is undefined.
    improvement = float('nan')
    print("  → Mixing improvement: undefined (expression-space mixing already at chance)")

# ===================================================================
# TEST 3: All-vs-All Domain Mixing Matrix
# ===================================================================
print("\n[MIXING-MATRIX] Cross-domain neighbor fractions:")
domain_names = ['ST-slide0', 'ST-slide1', 'ST-slide2', 'SC']
domain_sizes = [n1, n2, n3, n_sc]
domain_starts = [0, n1, n1+n2, n1+n2+n3]

print(f"\nQuery → Neighbors (mean fraction of K={K_mix} neighbors):")
print("         ST-s0  ST-s1  ST-s2    SC")

for query_idx, query_name in enumerate(domain_names):
    start_idx = domain_starts[query_idx]
    end_idx = start_idx + domain_sizes[query_idx]

    # kNN of this domain's rows against all rows of the normalized embedding.
    D_query = torch.cdist(Z_all_norm[start_idx:end_idx], Z_all_norm)

    # Exclude self: query row i sits at column start_idx + i.
    rows = torch.arange(domain_sizes[query_idx], device=D_query.device)
    D_query[rows, start_idx + rows] = float('inf')

    _, knn_query = torch.topk(D_query, k=K_mix, dim=1, largest=False)

    # Labels of all K neighbours of all queries. This does not depend on the
    # target domain, so it is computed once per query domain.
    neighbor_labels = labels[knn_query.flatten()]
    fracs = [(neighbor_labels == target_idx).float().mean().item()
             for target_idx in range(len(domain_names))]

    print(f"{query_name:8s}  {fracs[0]:.3f}  {fracs[1]:.3f}  {fracs[2]:.3f}  {fracs[3]:.3f}")

print("\nIdeal (perfect mixing): uniform fractions matching domain sizes")
print(f"Expected: ST-s0={n1/n_total:.3f}, ST-s1={n2/n_total:.3f}, ST-s2={n3/n_total:.3f}, SC={n_sc/n_total:.3f}")

# ===================================================================
# TEST 4: Domain Linear Probe (should be near chance)
# ===================================================================
print("\n[DOMAIN-PROBE] Linear probe accuracy on Z:")
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

labels_cpu = labels.cpu().numpy()
Z_all_n = F.normalize(Z_all, dim=1).cpu().numpy()

# Fit on one stratified split and score on a held-out split. Accuracy on the
# probe's own training rows overstates how much domain information Z carries.
Z_tr, Z_te, y_tr, y_te = train_test_split(
    Z_all_n, labels_cpu, test_size=0.3, random_state=42, stratify=labels_cpu
)
probe = LogisticRegression(max_iter=1000, random_state=42)
probe.fit(Z_tr, y_tr)
acc = probe.score(Z_te, y_te)
# Majority-class rate is the right chance level when domain sizes differ.
# It equals 1/4 when all four domains were capped at the same size.
chance = np.bincount(labels_cpu).max() / labels_cpu.shape[0]

print(f"  Held-out accuracy: {acc:.4f} (chance={chance:.3f})")


# ===================================================================
# TEST 5: Per-dimension statistics (detect collapse)
# ===================================================================
print("\n[COLLAPSE-CHECK] Per-dimension statistics:")
# Standard deviation of each embedding dimension over all evaluated rows.
std_per_dim = Z_all.std(dim=0)
std_mean = std_per_dim.mean().item()
std_min = std_per_dim.min().item()
print(f"  Mean std:   {std_mean:.4f}")
print(f"  Min std:    {std_min:.4f}")
print(f"  Max std:    {std_per_dim.max().item():.4f}")
# NOTE: "dead" here uses a 0.1 threshold, while the collapse verdict below
# uses 0.01, so a non-zero dead count does not by itself trigger COLLAPSE.
print(f"  Dead dims:  {(std_per_dim < 0.1).sum().item()}/{std_per_dim.shape[0]}")

if std_min < 0.01:
    print("  ⚠️  COLLAPSE DETECTED (some dims dead)")
elif std_mean < 0.5:
    print("  ⚠️  LOW VARIANCE (increase vicreg_gamma or reduce adversary)")
else:
    print("  ✓  Healthy variance")

# ===================================================================
# TEST 6: ST↔SC CORAL distance (should be low)
# ===================================================================
print("\n[ST-SC-ALIGNMENT] CORAL distance between ST and SC:")

# The three ST slides are pooled into one ST domain for this comparison.
Z_st_all = torch.cat([Z1, Z2, Z3], dim=0)

# First-order term: mean over dimensions of the squared gap between means.
mu_st, mu_sc = Z_st_all.mean(dim=0), Z_sc.mean(dim=0)
mean_diff = torch.mean((mu_st - mu_sc) ** 2).item()

# Second-order term: mean squared entry-wise gap between the two sample
# covariances. The denominator is n - 1, floored at 1 so one-row inputs
# do not divide by zero.
z_st_c, z_sc_c = Z_st_all - mu_st, Z_sc - mu_sc
cov_st, cov_sc = [
    (centered.T @ centered) / max(centered.shape[0] - 1, 1)
    for centered in (z_st_c, z_sc_c)
]
cov_diff = torch.mean((cov_st - cov_sc) ** 2).item()

coral_dist = mean_diff + cov_diff

print(f"  Mean difference:       {mean_diff:.6f}")
print(f"  Covariance difference: {cov_diff:.6f}")
print(f"  Total CORAL distance:  {coral_dist:.6f}")



In [None]:
# ===================================================================
# DECISIVE TEST: Train Stage A on SLIDE3 ALONE
# ===================================================================
import torch
import scanpy as sc
import numpy as np
import sys
sys.path.insert(0, '/home/ehtesamul/sc_st/model')

from core_models_et_p1 import SharedEncoder, train_encoder
import utils_et as uet

# Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print("="*70)
print("DECISIVE TEST: Train Stage A on SLIDE3 ALONE")
print("="*70)
print("This tests if slide3 expression→space is intrinsically learnable")
print("="*70)

# ===================================================================
# 1) LOAD SLIDE3 DATA ONLY
# ===================================================================
print("\n--- Loading slide3 data ---")
stadata3 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP2rep3.h5ad')

# The other slides and SC are loaded only to build the shared gene set, so
# slide3 uses the same gene columns as the multi-slide runs.
stadata1 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP2.h5ad')
stadata2 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP2rep2.h5ad')
scadata = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/scP2.h5ad')

# Normalize all four in place (total-count normalization then log1p). Only
# slide3's values are used below; the gene intersection needs just var_names.
for adata in [stadata1, stadata2, stadata3, scadata]:
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)

# Genes present in all four datasets, in sorted order.
common = sorted(
    set(scadata.var_names)
    & set(stadata1.var_names)
    & set(stadata2.var_names)
    & set(stadata3.var_names)
)
n_genes = len(common)
print(f"✓ Common genes: {n_genes}")

# ===================================================================
# 2) PREPARE SLIDE3 DATA
# ===================================================================
# Slide3 expression, densified if .X exposes .toarray(). The same tensor is
# later passed as both st_expr and sc_expr to avoid an SC dependency.
X_st3 = stadata3[:, common].X
if hasattr(X_st3, "toarray"):
    X_st3 = X_st3.toarray()
st3_expr = torch.tensor(X_st3, dtype=torch.float32, device=device)

# Canonicalize slide3 coords as a single slide (every spot gets slide id 0).
st3_coords_raw = torch.tensor(stadata3.obsm['spatial'], dtype=torch.float32, device=device)
slide_ids_st3 = torch.zeros(st3_coords_raw.shape[0], dtype=torch.long, device=device)
st3_coords, st3_mu, st3_scale = uet.canonicalize_st_coords_per_slide(st3_coords_raw, slide_ids_st3)

print(f"✓ Slide3 expr: {st3_expr.shape}")
print(f"✓ Slide3 coords: {st3_coords.shape}")

# ===================================================================
# 3) TRAIN ENCODER ON SLIDE3 ALONE
# ===================================================================
print("\n" + "="*70)
print("TRAINING ON SLIDE3 ONLY (local minisets)")
print("="*70)

encoder_st3 = SharedEncoder(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    dropout=0.1
)

# Slide3 is supplied as both the ST and the SC input. The MMD, circular,
# slide-alignment and SC terms are switched off through the arguments below.
# NOTE(review): return_aux is not passed, so this assumes train_encoder then
# returns just the model (the VICReg cells unpack four values with
# return_aux=True). Confirm.
encoder_st3 = train_encoder(
    model=encoder_st3,
    st_gene_expr=st3_expr,
    st_coords=st3_coords,
    sc_gene_expr=st3_expr,  # Use same as ST to avoid SC dependency
    slide_ids=slide_ids_st3,
    n_epochs=1000,
    batch_size=256,
    lr=1e-4,
    sigma=None,
    alpha=0.0,  # No MMD (only 1 slide)
    ratio_start=0.0,
    ratio_end=0.0,  # No circular loss
    mmdbatch=0.0,
    device=device,
    outf='/home/ehtesamul/sc_st/model/gems_hscc_output_anchored_new',
    local_miniset_mode=True,
    n_min=128,
    n_max=384,
    pool_mult=4.0,
    stochastic_tau=1.0,
    slide_align_mode='none',  # No alignment (only 1 slide)
    slide_align_weight=0.0,
    use_circle=False,
    use_mmd_sc=False,
)

print("\n✓ Slide3-only training complete!")

# ===================================================================
# 4) EVALUATE ON SLIDE3
# ===================================================================
_rule = "=" * 70
print("\n" + _rule)
print("EVALUATION: Slide3 Learnability")
print(_rule)

# Cap evaluation at N_MAX spots and fix the global RNG so the random subset
# (and therefore the reported overlap) is reproducible.
N_MAX = 2000
torch.manual_seed(42)
def subsample_slide(X, C, n_max):
    """Randomly subsample paired expression/coordinate rows to at most ``n_max``.

    Args:
        X: Expression tensor whose rows are spots.
        C: Coordinate tensor with the same number of rows as ``X``.
        n_max: Maximum number of rows to keep.

    Returns:
        ``(X, C)`` unchanged when ``X`` has ``<= n_max`` rows; otherwise
        ``(X[idx], C[idx])`` for one random index set, so rows stay paired.
        The permutation uses the global torch RNG (seed it for reproducibility).

    Raises:
        ValueError: if ``X`` and ``C`` have different numbers of rows.
    """
    n = X.shape[0]
    # A length mismatch would otherwise raise IndexError or silently pair the
    # wrong rows, depending on which tensor is longer.
    if C.shape[0] != n:
        raise ValueError(f"X has {n} rows but C has {C.shape[0]}")
    if n <= n_max:
        return X, C
    # Draw on X's own device rather than the notebook-global `device`.
    idx = torch.randperm(n, device=X.device)[:n_max]
    return X[idx], C[idx]

# Evaluate on at most N_MAX slide3 spots, keeping expression/coordinates paired.
X3_sub, C3_sub = subsample_slide(st3_expr, st3_coords, N_MAX)

# Embed with dropout disabled (eval() returns the module itself) and autograd off.
with torch.no_grad():
    Z3 = encoder_st3.eval()(X3_sub)

print(f"Z3 shape: {Z3.shape}")

# kNN Overlap
def compute_knn_overlap(Z, C, k=10):
    """Mean fraction of shared k-nearest neighbours between two spaces.

    For each row i, take its k nearest other rows under Euclidean distance in
    the embedding ``Z`` and in the coordinates ``C``, and measure
    ``|knn_Z(i) ∩ knn_C(i)| / k``. The result is the mean over rows.

    Args:
        Z: (n, d) embedding tensor.
        C: (n, c) coordinate tensor, row-aligned with ``Z``.
        k: Neighbourhood size; must satisfy ``0 < k < n``.

    Returns:
        ``numpy.float64`` in [0, 1].

    Raises:
        ValueError: if ``Z`` and ``C`` disagree on row count or ``k`` is out of range.
    """
    n = Z.shape[0]
    if C.shape[0] != n:
        raise ValueError(f"Z has {n} rows but C has {C.shape[0]}")
    if not 0 < k < n:
        # k == n would pull each (inf-masked) point into its own neighbour
        # set; k > n makes topk fail outright.
        raise ValueError(f"k must satisfy 0 < k < n (got k={k}, n={n})")

    def _knn(X):
        # Indices of each row's k nearest rows, excluding the row itself.
        D = torch.cdist(X, X)
        D.fill_diagonal_(float('inf'))
        return torch.topk(D, k=k, dim=1, largest=False).indices

    knn_emb = _knn(Z)
    knn_coord = _knn(C).to(knn_emb.device)

    # Vectorized set intersection: mark each row's embedding neighbours in a
    # boolean (n, n) mask, then count how many coordinate neighbours are
    # marked. topk indices are distinct within a row, so this count equals the
    # set-intersection size, without n per-row GPU->CPU copies.
    rows = torch.arange(n, device=knn_emb.device).unsqueeze(1)
    in_emb = torch.zeros(n, n, dtype=torch.bool, device=knn_emb.device)
    in_emb[rows, knn_emb] = True
    shared = in_emb[rows, knn_coord].sum(dim=1)

    return np.mean(shared.cpu().numpy() / k)

# Evaluate the same subset at two neighbourhood sizes.
overlap_k10, overlap_k20 = (compute_knn_overlap(Z3, C3_sub, k=kk) for kk in (10, 20))

print("\n[SLIDE3-ALONE] kNN Overlap:")
for kk, ov in ((10, overlap_k10), (20, overlap_k20)):
    print(f"  k={kk}: {ov:.4f}")


In [None]:
# ===================================================================
# RETRAIN STAGE A: Fixed Local Minisets + Slide Alignment
# ===================================================================
import torch
import scanpy as sc
import numpy as np
import sys
sys.path.insert(0, '/home/ehtesamul/sc_st/model')

from core_models_et_p1 import SharedEncoder, train_encoder
import utils_et as uet

# Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

_rule = "=" * 70
for _banner_line in (_rule, "RETRAIN STAGE A: FIXED Local Minisets + Slide Alignment", _rule):
    print(_banner_line)

# ===================================================================
# 1) LOAD DATA
# ===================================================================
print("\n--- Loading HSCC data ---")
scadata = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/scP2.h5ad')
# NOTE(review): the slide-to-variable mapping differs from the other cells.
# Here stP2 is the held-out "slide3" and stP2rep2 / stP2rep3 are the training
# slides, whereas the DECISIVE TEST cell trains on stP2rep3 as "slide3".
# The "slide3" numbers from the two cells therefore refer to different
# sections. Confirm this rotation is intentional before comparing them.
stadata3 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP2.h5ad')
stadata1 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP2rep2.h5ad')
stadata2 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP2rep3.h5ad')

# Normalize each dataset in place (total-count normalization then log1p).
for adata in [scadata, stadata1, stadata2, stadata3]:
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)

# Genes present in all four datasets, in sorted order.
common = sorted(
    set(scadata.var_names)
    & set(stadata1.var_names)
    & set(stadata2.var_names)
    & set(stadata3.var_names)
)
n_genes = len(common)
print(f"✓ Common genes: {n_genes}")

# ===================================================================
# 2) PREPARE TRAINING DATA
# ===================================================================
# SC expression, densified if .X exposes .toarray()
X_sc = scadata[:, common].X
if hasattr(X_sc, "toarray"):
    X_sc = X_sc.toarray()
sc_expr = torch.tensor(X_sc, dtype=torch.float32, device=device)

# ST expression from the 2 training slides (stadata3 is held out for evaluation)
X_st1, X_st2 = (
    X.toarray() if hasattr(X, "toarray") else X
    for X in (stadata1[:, common].X, stadata2[:, common].X)
)
st_expr = torch.tensor(np.vstack([X_st1, X_st2]), dtype=torch.float32, device=device)

# ST coordinates, stacked in the same slide order as st_expr
st_coords1 = stadata1.obsm['spatial']
st_coords2 = stadata2.obsm['spatial']
st_coords_raw = torch.tensor(np.vstack([st_coords1, st_coords2]), 
                             dtype=torch.float32, device=device)

# Slide IDs: 0 for stadata1 rows, 1 for stadata2 rows
slide_ids = torch.tensor(
    np.concatenate([
        np.zeros(X_st1.shape[0], dtype=int),
        np.ones(X_st2.shape[0], dtype=int)
    ]),
    dtype=torch.long, device=device
)

# Canonicalize coordinates per-slide
st_coords, st_mu, st_scale = uet.canonicalize_st_coords_per_slide(st_coords_raw, slide_ids)

print(f"✓ SC expr: {sc_expr.shape}")
print(f"✓ ST expr: {st_expr.shape}")
print(f"✓ ST coords: {st_coords.shape}")
print(f"✓ Slide IDs: {slide_ids.shape} (slides: {torch.unique(slide_ids).tolist()})")

# ===================================================================
# 3) CREATE AND TRAIN ENCODER
# ===================================================================
print("\n" + "="*70)
print("TRAINING STAGE A ENCODER (FIXED)")
print("="*70)

encoder = SharedEncoder(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    dropout=0.1
)

# NOTE(review): return_aux is not passed; this assumes train_encoder returns
# just the model in that case (it is used directly as a module below). Confirm.
encoder = train_encoder(
    model=encoder,
    st_gene_expr=st_expr,
    st_coords=st_coords,
    sc_gene_expr=sc_expr,
    slide_ids=slide_ids,
    n_epochs=2000,
    batch_size=256,
    lr=1e-4,
    sigma=None,  # Auto-compute
    alpha=0.0,   # Disable MMD (ST-only)
    ratio_start=0.0,
    ratio_end=0.0,  # Disable circular (ST-only)
    mmdbatch=0.0,
    device=device,
    outf='/home/ehtesamul/sc_st/model/gems_hscc_output_anchored_new',
    # LOCAL MINISET MODE
    local_miniset_mode=True,
    n_min=96,
    n_max=384,
    pool_mult=2.0,
    stochastic_tau=1.0,
    # INFONCE ALIGNMENT
    slide_align_mode='infonce',
    slide_align_weight=1.0,  # Start with 1.0; reduce to 0.5 if unstable
    infonce_tau=0.07,
    infonce_match='expr',
    infonce_topk=5,  # Sample from top-5 expression matches
    infonce_sym=True,
    # DISABLE SC LOSSES
    use_circle=False,
    use_mmd_sc=False,
)


print("\n✓ Stage A training complete!")

# Save encoder
import os
os.makedirs('/home/ehtesamul/sc_st/model/gems_hscc_output_anchored_new', exist_ok=True)
torch.save(encoder.state_dict(), 
           '/home/ehtesamul/sc_st/model/gems_hscc_output_anchored_new/encoder_final_fixed.pt')
print("✓ Encoder saved to: encoder_final_fixed.pt")

# ===================================================================
# 4) EVALUATE ON ALL SLIDES
# ===================================================================
_rule = "=" * 70
print("\n" + _rule)
print("EVALUATION: Learnability + OOD (ALL SLIDES)")
print(_rule)

# Training slides: reuse the per-slide canonical coordinates from training,
# selected by slide id.
st1_coords_canon, st2_coords_canon = [st_coords[slide_ids == sid] for sid in (0, 1)]

# Held-out slide3: canonicalize on its own, as a single slide with id 0.
st3_coords_raw = torch.tensor(stadata3.obsm['spatial'], dtype=torch.float32, device=device)
slide_ids_st3 = torch.zeros(st3_coords_raw.shape[0], dtype=torch.long, device=device)
st3_coords_canon, _, _ = uet.canonicalize_st_coords_per_slide(st3_coords_raw, slide_ids_st3)

# Expression tensors for evaluation (slide3 densified if .X exposes .toarray()).
X1, X2 = [torch.tensor(arr, dtype=torch.float32, device=device) for arr in (X_st1, X_st2)]
_X3_raw = stadata3[:, common].X
X3 = torch.tensor(
    _X3_raw.toarray() if hasattr(_X3_raw, "toarray") else _X3_raw,
    dtype=torch.float32, device=device,
)

# Cap each slide at N_MAX spots; seed so the random subsets are reproducible.
N_MAX = 2000
torch.manual_seed(42)

def subsample_slide(X, C, n_max):
    """Randomly subsample paired expression/coordinate rows to at most ``n_max``.

    Args:
        X: Expression tensor whose rows are spots.
        C: Coordinate tensor with the same number of rows as ``X``.
        n_max: Maximum number of rows to keep.

    Returns:
        ``(X, C)`` unchanged when ``X`` has ``<= n_max`` rows; otherwise
        ``(X[idx], C[idx])`` for one random index set, so rows stay paired.
        The permutation uses the global torch RNG (seed it for reproducibility).

    Raises:
        ValueError: if ``X`` and ``C`` have different numbers of rows.
    """
    n = X.shape[0]
    # A length mismatch would otherwise raise IndexError or silently pair the
    # wrong rows, depending on which tensor is longer.
    if C.shape[0] != n:
        raise ValueError(f"X has {n} rows but C has {C.shape[0]}")
    if n <= n_max:
        return X, C
    # Draw on X's own device rather than the notebook-global `device`.
    idx = torch.randperm(n, device=X.device)[:n_max]
    return X[idx], C[idx]

# Draw the three subsets in slide order 1, 2, 3 so RNG consumption is unchanged.
(X1_sub, C1_sub), (X2_sub, C2_sub), (X3_sub, C3_sub) = [
    subsample_slide(expr, coords, N_MAX)
    for expr, coords in ((X1, st1_coords_canon), (X2, st2_coords_canon), (X3, st3_coords_canon))
]

# Embed each subset with dropout disabled and autograd off.
encoder.eval()
with torch.no_grad():
    Z1, Z2, Z3 = [encoder(x_sub) for x_sub in (X1_sub, X2_sub, X3_sub)]

for tag, emb in (("Z1", Z1), ("Z2", Z2), ("Z3", Z3)):
    print(f"{tag} shape: {emb.shape}")

# kNN Overlap
def compute_knn_overlap(Z, C, k=10):
    """Mean fraction of shared k-nearest neighbours between two spaces.

    For each row i, take its k nearest other rows under Euclidean distance in
    the embedding ``Z`` and in the coordinates ``C``, and measure
    ``|knn_Z(i) ∩ knn_C(i)| / k``. The result is the mean over rows.

    Args:
        Z: (n, d) embedding tensor.
        C: (n, c) coordinate tensor, row-aligned with ``Z``.
        k: Neighbourhood size; must satisfy ``0 < k < n``.

    Returns:
        ``numpy.float64`` in [0, 1].

    Raises:
        ValueError: if ``Z`` and ``C`` disagree on row count or ``k`` is out of range.
    """
    n = Z.shape[0]
    if C.shape[0] != n:
        raise ValueError(f"Z has {n} rows but C has {C.shape[0]}")
    if not 0 < k < n:
        # k == n would pull each (inf-masked) point into its own neighbour
        # set; k > n makes topk fail outright.
        raise ValueError(f"k must satisfy 0 < k < n (got k={k}, n={n})")

    def _knn(X):
        # Indices of each row's k nearest rows, excluding the row itself.
        D = torch.cdist(X, X)
        D.fill_diagonal_(float('inf'))
        return torch.topk(D, k=k, dim=1, largest=False).indices

    knn_emb = _knn(Z)
    knn_coord = _knn(C).to(knn_emb.device)

    # Vectorized set intersection: mark each row's embedding neighbours in a
    # boolean (n, n) mask, then count how many coordinate neighbours are
    # marked. topk indices are distinct within a row, so this count equals the
    # set-intersection size, without n per-row GPU->CPU copies.
    rows = torch.arange(n, device=knn_emb.device).unsqueeze(1)
    in_emb = torch.zeros(n, n, dtype=torch.bool, device=knn_emb.device)
    in_emb[rows, knn_emb] = True
    shared = in_emb[rows, knn_coord].sum(dim=1)

    return np.mean(shared.cpu().numpy() / k)

print("\n[EMB-LEARN-FIXED] kNN Overlap:")
for name, Z, C in [("slide1", Z1, C1_sub), ("slide2", Z2, C2_sub), ("slide3", Z3, C3_sub)]:
    overlap_k10 = compute_knn_overlap(Z, C, k=10)
    overlap_k20 = compute_knn_overlap(Z, C, k=20)
    print(f"  {name} k=10: {overlap_k10:.4f}  k=20: {overlap_k20:.4f}")

# OOD Mixing: where do held-out slide3 spots land among all three slides?
Z_all = torch.cat([Z1, Z2, Z3], dim=0)
n1, n2, n3 = Z1.shape[0], Z2.shape[0], Z3.shape[0]
labels = torch.cat([
    torch.zeros(n1, dtype=torch.long, device=device),
    torch.ones(n2, dtype=torch.long, device=device),
    torch.full((n3,), 2, dtype=torch.long, device=device)
])

K_mix = 20
st3_start = n1 + n2
D_st3 = torch.cdist(Z3, Z_all)

# Exclude self-matches in one assignment: query row i of Z3 is column st3_start + i.
st3_rows = torch.arange(n3, device=D_st3.device)
D_st3[st3_rows, st3_start + st3_rows] = float('inf')

_, knn_st3 = torch.topk(D_st3, k=K_mix, dim=1, largest=False)

# Per-spot fraction of neighbours that are also slide3 (label 2).
frac_same = (labels[knn_st3] == 2).float().mean(dim=1).cpu().numpy().astype(np.float64)
base_rate = n3 / (n1 + n2 + n3)

print(f"\n[EMB-OOD-FIXED] K={K_mix} slide3_frac_same:")
print(f"  mean={frac_same.mean():.4f} p50={np.percentile(frac_same, 50):.4f} p90={np.percentile(frac_same, 90):.4f}")
print(f"  base_rate={base_rate:.4f}")

# NOTE(review): when all three slides are capped at N_MAX, base_rate is 1/3,
# so `base_rate * 2` is about 0.667 and the MODERATE band (0.667, 0.7] is very
# narrow. If base_rate >= 0.35 the band is unreachable. Confirm the intended
# threshold.
if frac_same.mean() > 0.7:
    print("  ⚠️  HIGH OOD")
elif frac_same.mean() > base_rate * 2:
    print("  ⚠️  MODERATE OOD")
else:
    print("  ✓  WELL-MIXED")

print("\n" + "="*70)
print("EXPECTED RESULTS:")
print("="*70)
print("Slide1/2 overlap should RECOVER to ~0.60-0.70 (was ~0.46-0.53)")
print("Slide3 overlap depends on DECISIVE TEST result:")
print("  - If slide3-alone was high: slide3 should improve here too")
print("  - If slide3-alone was low: slide3 will stay ~0.18-0.26")
print("="*70)


In [None]:
# ===================================================================
# TRAIN STAGE A: VICReg + Domain Adversary (3 ST slides + SC)
# ===================================================================
import torch
import torch.nn.functional as F
import scanpy as sc
import numpy as np
import sys
sys.path.insert(0, '/home/ehtesamul/sc_st/model')

from core_models_et_p1 import SharedEncoder, train_encoder
import utils_et as uet

# Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

_rule = "=" * 70
for _banner_line in (_rule, "TRAIN STAGE A: VICReg + Domain Adversary (3 ST + SC)", _rule):
    print(_banner_line)

def subsample_domain(X, n_max):
    """Randomly subsample the rows of one domain to at most ``n_max``.

    Args:
        X: Tensor whose first dimension indexes samples.
        n_max: Maximum number of rows to keep.

    Returns:
        ``X`` itself when it has ``<= n_max`` rows; otherwise a new tensor of
        ``n_max`` rows drawn without replacement, in random order. The
        permutation comes from the global torch RNG, so call
        ``torch.manual_seed`` first for a reproducible subset.
    """
    n = X.shape[0]
    if n <= n_max:
        return X
    # Draw the permutation on X's own device instead of the notebook-global
    # `device`. This removes a hidden global dependency and avoids a device
    # mismatch when X lives elsewhere (a CPU tensor cannot be indexed by a
    # CUDA index tensor).
    idx = torch.randperm(n, device=X.device)[:n_max]
    return X[idx]

# ===================================================================
# 1) LOAD DATA
# ===================================================================
print("\n--- Loading HSCC data ---")
scadata = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/scP2.h5ad')
stadata1 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP2.h5ad')
stadata2 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP2rep2.h5ad')
stadata3 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP2rep3.h5ad')

# Normalize each dataset in place: total-count normalization then log1p,
# both with scanpy's default arguments.
for adata in [scadata, stadata1, stadata2, stadata3]:
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)

# Genes present in all four datasets, sorted so every expression matrix
# below shares one column order.
common = sorted(
    set(scadata.var_names)
    & set(stadata1.var_names)
    & set(stadata2.var_names)
    & set(stadata3.var_names)
)
n_genes = len(common)
print(f"✓ Common genes: {n_genes}")

# ===================================================================
# 2) PREPARE TRAINING DATA - ALL 3 ST SLIDES + SC
# ===================================================================
# SC expression. Anything exposing .toarray() (i.e. a sparse .X) is
# densified before the torch tensor is built.
X_sc = scadata[:, common].X
if hasattr(X_sc, "toarray"):
    X_sc = X_sc.toarray()
sc_expr = torch.tensor(X_sc, dtype=torch.float32, device=device)

# ST expression from ALL 3 slides. Each is densified the same way, then they
# are stacked in slide order 0, 1, 2, matching slide_ids below.
X_st1, X_st2, X_st3 = (
    X.toarray() if hasattr(X, "toarray") else X
    for X in (stadata1[:, common].X, stadata2[:, common].X, stadata3[:, common].X)
)
st_expr = torch.tensor(np.vstack([X_st1, X_st2, X_st3]), dtype=torch.float32, device=device)

# ST coordinates (required for function signature, but not used by VICReg)
st_coords1 = stadata1.obsm['spatial']
st_coords2 = stadata2.obsm['spatial']
st_coords3 = stadata3.obsm['spatial']
st_coords_raw = torch.tensor(np.vstack([st_coords1, st_coords2, st_coords3]),
                             dtype=torch.float32, device=device)

# Slide IDs: one integer label per stacked ST row (0, 1, 2).
slide_ids = torch.tensor(
    np.concatenate([
        np.zeros(X_st1.shape[0], dtype=int),    # slide 0
        np.ones(X_st2.shape[0], dtype=int),     # slide 1
        np.full(X_st3.shape[0], 2, dtype=int)   # slide 2
    ]),
    dtype=torch.long, device=device
)

# Canonicalize coordinates per slide (for function signature)
st_coords, st_mu, st_scale = uet.canonicalize_st_coords_per_slide(st_coords_raw, slide_ids)

print(f"✓ SC expr: {sc_expr.shape}")
print(f"✓ ST expr: {st_expr.shape}")
print(f"✓ ST coords: {st_coords.shape}")
print(f"✓ Slide IDs: {slide_ids.shape} (slides: {torch.unique(slide_ids).tolist()})")
print("✓ Training will use 4 domains: ST-slide0, ST-slide1, ST-slide2, SC")

# ===================================================================
# 3) CREATE AND TRAIN ENCODER WITH VICREG + DOMAIN ADVERSARY
# ===================================================================
print("\n" + "="*70)
print("TRAINING STAGE A ENCODER (VICReg + Domain Adversary)")
print("="*70)

encoder_vicreg = SharedEncoder(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    dropout=0.1
)

# With return_aux=True the result is unpacked into four values
# (encoder, projector, discriminator, hist).
encoder_vicreg, projector, discriminator, hist = train_encoder(
    model=encoder_vicreg,
    st_gene_expr=st_expr,
    st_coords=st_coords,
    sc_gene_expr=sc_expr,  # ← SC DATA INCLUDED
    slide_ids=slide_ids,
    n_epochs=1000,
    batch_size=256,  # Divisible by 4 (one equal share per domain)
    lr=1e-3,
    device=device,
    outf='/home/ehtesamul/sc_st/model/gems_hscc_vicreg_output',
    # ========== VICReg Mode ==========
    stageA_obj='vicreg_adv',
    # VICReg loss weights
    vicreg_lambda_inv=25.0,
    vicreg_lambda_var=25.0,
    vicreg_lambda_cov=1.0,
    vicreg_gamma=1.0,
    vicreg_eps=1e-4,
    vicreg_project_dim=256,
    vicreg_use_projector=True,
    vicreg_float32_stats=True,
    vicreg_ddp_gather=False,
    # Expression augmentations
    aug_gene_dropout=0.5,
    aug_gauss_std=0.05,
    aug_scale_jitter=0.4,
    # Domain adversary
    adv_slide_weight=20.0,
    adv_warmup_epochs=50,
    adv_ramp_epochs=200,
    grl_alpha_max=1.0,
    disc_hidden=256,
    disc_dropout=0.1,
    # Balanced domain sampling
    stageA_balanced_slides=True,
    return_aux=True
)

print("\n✓ VICReg Stage A training complete!")

# Save encoder
import os
os.makedirs('/home/ehtesamul/sc_st/model/gems_hscc_vicreg_output', exist_ok=True)
torch.save(encoder_vicreg.state_dict(), 
           '/home/ehtesamul/sc_st/model/gems_hscc_vicreg_output/encoder_vicreg.pt')
print("✓ Encoder saved to: encoder_vicreg.pt")

# ===================================================================
# 4) EVALUATE DOMAIN MIXING (ALL 4 DOMAINS: 3 ST + SC)
# ===================================================================
print("\n" + "="*70)
print("EVALUATION: Domain Mixing (3 ST slides + SC)")
print("="*70)

N_MAX = 2000

# Dense float32 copies of each domain's expression matrix on the device.
X1, X2, X3, X_sc_torch = (
    torch.tensor(expr, dtype=torch.float32, device=device)
    for expr in (X_st1, X_st2, X_st3, X_sc)
)

# Cap every domain at N_MAX rows. subsample_domain draws from the global
# torch RNG, so the draw order (ST1, ST2, ST3, SC) is kept fixed.
X1_sub, X2_sub, X3_sub, X_sc_sub = (
    subsample_domain(x, N_MAX) for x in (X1, X2, X3, X_sc_torch)
)

# Embed the subsamples with the trained encoder in inference mode.
encoder_vicreg.eval()
with torch.no_grad():
    Z1, Z2, Z3, Z_sc = (encoder_vicreg(x) for x in (X1_sub, X2_sub, X3_sub, X_sc_sub))

for caption, z in (
    ("Z1 (ST-slide0): ", Z1),
    ("Z2 (ST-slide1): ", Z2),
    ("Z3 (ST-slide2): ", Z3),
    ("Z_sc (SC):      ", Z_sc),
):
    print(f"{caption}{z.shape}")

# ===================================================================
# Shared helper for the kNN mixing tests below (TESTS 1-3)
# ===================================================================
def _knn_excluding_self(points, start, end, k):
    """k-nearest-neighbour search for a contiguous block of query rows.

    Queries are ``points[start:end]``; candidates are all rows of ``points``.
    Each query's own row (column ``start + i``) is set to +inf, so a point is
    never counted as its own neighbour.

    Returns:
        (D, knn): D is the (end - start, N) distance matrix after
        self-exclusion; knn holds the (end - start, k) column indices of the
        k smallest distances per query row.

    Raises:
        ValueError: if k exceeds the N - 1 non-self candidates.
    """
    n_points = points.shape[0]
    if k > n_points - 1:
        raise ValueError(f"k={k} exceeds the {n_points - 1} non-self neighbours available")
    D = torch.cdist(points[start:end], points)
    # Mask each query's own column in one vectorized write, instead of a
    # Python loop issuing one tiny device write per query row.
    rows = torch.arange(end - start, device=D.device)
    D[rows, start + rows] = float('inf')
    _, knn = torch.topk(D, k=k, dim=1, largest=False)
    return D, knn


# ===================================================================
# TEST 1: Expression Mixing (baseline)
# ===================================================================
print("\n[EXPR-MIXING] Expression space kNN domain distribution:")
X_all = torch.cat([X1_sub, X2_sub, X3_sub, X_sc_sub], dim=0)
X_all_norm = F.normalize(X_all, dim=1)

n1, n2, n3, n_sc = X1_sub.shape[0], X2_sub.shape[0], X3_sub.shape[0], X_sc_sub.shape[0]
n_total = n1 + n2 + n3 + n_sc

# Domain labels: 0=ST1, 1=ST2, 2=ST3, 3=SC (row order of the concatenation above)
labels = torch.cat([
    torch.zeros(n1, dtype=torch.long, device=device),
    torch.ones(n2, dtype=torch.long, device=device),
    torch.full((n3,), 2, dtype=torch.long, device=device),
    torch.full((n_sc,), 3, dtype=torch.long, device=device)
])

K_mix = 20

# SC rows are the last block of the pool; query them against every row.
sc_start = n1 + n2 + n3
D_expr_sc, knn_expr_sc = _knn_excluding_self(X_all_norm, sc_start, n_total, K_mix)

# Per-SC-cell fraction of its K neighbours that are also SC (label 3),
# computed for all rows at once instead of one host sync per row.
frac_same_expr = (labels[knn_expr_sc] == 3).float().mean(dim=1).cpu().numpy()
base_rate = n_sc / n_total

print(f"  SC neighbors (K={K_mix}):")
print(f"    Same-domain fraction: {frac_same_expr.mean():.4f}")
print(f"    Base rate (chance):   {base_rate:.4f}")
print(f"    ‚Üí Expression space shows domain clustering (expected)")

# ===================================================================
# TEST 2: Z Mixing (should approach base_rate if working)
# ===================================================================
print("\n[Z-MIXING] Embedding space kNN domain distribution:")
Z_all = torch.cat([Z1, Z2, Z3, Z_sc], dim=0)
Z_all_norm = F.normalize(Z_all, dim=1)

D_emb_sc, knn_emb_sc = _knn_excluding_self(Z_all_norm, sc_start, n_total, K_mix)
frac_same_z = (labels[knn_emb_sc] == 3).float().mean(dim=1).cpu().numpy()

print(f"  SC neighbors (K={K_mix}):")
print(f"    Same-domain fraction: {frac_same_z.mean():.4f}")
print(f"    Base rate (chance):   {base_rate:.4f}")

# Share of the expression-space excess over chance that the encoder removed:
# 1.0 = SC neighbourhoods are at chance in Z; 0.0 = no better than raw expression.
excess_expr = frac_same_expr.mean() - base_rate
if abs(excess_expr) > 1e-12:
    improvement = (frac_same_expr.mean() - frac_same_z.mean()) / excess_expr
    print(f"  ‚Üí Mixing improvement: {improvement*100:.1f}%")
else:
    # Expression space is already at the base rate, so the ratio would be x/0.
    improvement = float('nan')
    print("  ‚Üí Mixing improvement: undefined (expression space already at base rate)")

# ===================================================================
# TEST 3: All-vs-All Domain Mixing Matrix
# ===================================================================
print("\n[MIXING-MATRIX] Cross-domain neighbor fractions:")
domain_names = ['ST-slide0', 'ST-slide1', 'ST-slide2', 'SC']
domain_sizes = [n1, n2, n3, n_sc]
domain_starts = [0, n1, n1+n2, n1+n2+n3]

print(f"\nQuery ‚Üí Neighbors (mean fraction of K={K_mix} neighbors):")
print("         ST-s0  ST-s1  ST-s2    SC")

for query_idx, query_name in enumerate(domain_names):
    start_idx = domain_starts[query_idx]
    end_idx = start_idx + domain_sizes[query_idx]

    # kNN for this domain's rows against the whole embedding pool
    D_query, knn_query = _knn_excluding_self(Z_all_norm, start_idx, end_idx, K_mix)

    # Neighbour labels do not depend on the target domain: gather them once.
    neighbor_labels = labels[knn_query.flatten()]
    fracs = [(neighbor_labels == target_idx).float().mean().item() for target_idx in range(4)]

    print(f"{query_name:8s}  {fracs[0]:.3f}  {fracs[1]:.3f}  {fracs[2]:.3f}  {fracs[3]:.3f}")

print("\nIdeal (perfect mixing): uniform fractions matching domain sizes")
print(f"Expected: ST-s0={n1/n_total:.3f}, ST-s1={n2/n_total:.3f}, ST-s2={n3/n_total:.3f}, SC={n_sc/n_total:.3f}")

# ===================================================================
# TEST 4: Domain Linear Probe (should be near chance)
# ===================================================================
print("\n[DOMAIN-PROBE] Linear probe accuracy on Z:")
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score

Z_all_cpu = Z_all.cpu().numpy()
labels_cpu = labels.cpu().numpy()

# Domain pools are unequal whenever a domain has fewer than N_MAX rows
# (subsample_domain caps at N_MAX, it does not resample up to it). Plain
# accuracy can then beat 1/4 just by favouring the largest domain. So the
# probe is class-weighted, and separability is judged by balanced accuracy,
# whose chance level is 1/4 regardless of domain sizes.
probe = LogisticRegression(max_iter=1000, random_state=42, class_weight='balanced')
Z_all_n = F.normalize(Z_all, dim=1).cpu().numpy()
probe.fit(Z_all_n, labels_cpu)
acc = probe.score(Z_all_n, labels_cpu)
bal_acc_probe = balanced_accuracy_score(labels_cpu, probe.predict(Z_all_n))
chance = 1.0 / 4.0
# Accuracy of always predicting the most frequent domain (plain-accuracy floor).
majority_rate = np.bincount(labels_cpu, minlength=4).max() / labels_cpu.shape[0]

print(f"  Accuracy: {acc:.4f} (majority-class baseline={majority_rate:.3f})")
print(f"  Balanced accuracy: {bal_acc_probe:.4f} (chance={chance:.3f})")


# ===================================================================
# TEST 5: Per-dimension statistics (detect collapse)
# ===================================================================
print("\n[COLLAPSE-CHECK] Per-dimension statistics:")
# Spread of every embedding dimension over the pooled 4-domain sample.
std_per_dim = Z_all.std(dim=0)
std_mean = std_per_dim.mean().item()
std_min = std_per_dim.min().item()
std_max = std_per_dim.max().item()
n_dead = (std_per_dim < 0.1).sum().item()

print(f"  Mean std:   {std_mean:.4f}")
print(f"  Min std:    {std_min:.4f}")
print(f"  Max std:    {std_max:.4f}")
print(f"  Dead dims:  {n_dead}/{std_per_dim.shape[0]}")

# Any near-constant dimension takes precedence over a merely low average spread.
if std_min < 0.01:
    verdict = "  ‚ö†Ô∏è  COLLAPSE DETECTED (some dims dead)"
elif std_mean < 0.5:
    verdict = "  ‚ö†Ô∏è  LOW VARIANCE (increase vicreg_gamma or reduce adversary)"
else:
    verdict = "  ‚úì  Healthy variance"
print(verdict)

# ===================================================================
# TEST 6: ST↔SC CORAL distance (should be low)
# ===================================================================
print("\n[ST-SC-ALIGNMENT] CORAL distance between ST and SC:")

# All three ST slides pooled into one domain
Z_st_all = torch.cat([Z1, Z2, Z3], dim=0)

# First-moment term: mean squared difference of the domain means.
mu_st, mu_sc = Z_st_all.mean(dim=0), Z_sc.mean(dim=0)
mean_diff = (mu_st - mu_sc).pow(2).mean().item()


def _sample_cov(centered):
    """Covariance of already-centered rows, dividing by max(n - 1, 1)."""
    return (centered.T @ centered) / max(centered.shape[0] - 1, 1)


# Second-moment term: mean squared difference of the covariance matrices.
z_st_c = Z_st_all - mu_st
z_sc_c = Z_sc - mu_sc
cov_st, cov_sc = _sample_cov(z_st_c), _sample_cov(z_sc_c)
cov_diff = (cov_st - cov_sc).pow(2).mean().item()

coral_dist = mean_diff + cov_diff

for caption, value in (
    ("Mean difference:", mean_diff),
    ("Covariance difference:", cov_diff),
    ("Total CORAL distance:", coral_dist),
):
    print(f"  {caption:<23}{value:.6f}")



In [None]:
# ===================================================================
# DIAGNOSTIC: Per-Domain Embedding Analysis (3 ST + SC)
# ===================================================================
_rule = "="*70
print("\n" + _rule)
print("DIAGNOSTIC: Per-Domain Embedding Statistics (3 ST + SC)")
print(_rule)

import torch.nn.functional as F
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score, confusion_matrix
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np

# Embed every row of all four domains (no subsampling) with the trained encoder.
encoder_vicreg.eval()
with torch.no_grad():
    Z1_full, Z2_full, Z3_full, Z_sc_full = (
        encoder_vicreg(torch.tensor(expr, dtype=torch.float32, device=device))
        for expr in (X_st1, X_st2, X_st3, X_sc)
    )

# Per-domain statistics
def domain_stats(Z, name):
    """Print summary statistics for one domain's embeddings and return its centroid.

    Statistics are taken over dim 0 (presumably rows = cells/spots).

    Args:
        Z: embedding tensor.
        name: label printed as the section header.

    Returns:
        The mean of Z over dim 0, i.e. the domain centroid.
    """
    centroid = Z.mean(dim=0)
    spread = Z.std(dim=0)
    report = [
        f"\n{name}:",
        f"  Shape: {Z.shape}",
        f"  Mean norm: {centroid.norm().item():.4f}",
        f"  Mean of stds: {spread.mean().item():.4f}",
        f"  First 5 dims mean: {centroid[:5].cpu().numpy()}",
    ]
    for line in report:
        print(line)
    return centroid

# Print per-domain stats and keep each domain's centroid (called in ST1, ST2, ST3, SC order).
mean_st1, mean_st2, mean_st3, mean_sc = (
    domain_stats(Z, title)
    for Z, title in (
        (Z1_full, "ST Slide 1"),
        (Z2_full, "ST Slide 2"),
        (Z3_full, "ST Slide 3"),
        (Z_sc_full, "SC"),
    )
)

# Distance between all domain centroids
_banner = f"{'='*70}"
print(f"\n{_banner}")
print("Centroid Distances (all pairs):")
print(_banner)

# Euclidean distance between every pair of centroids.
dist_12, dist_13, dist_23, dist_1sc, dist_2sc, dist_3sc = (
    (a - b).norm().item()
    for a, b in (
        (mean_st1, mean_st2),
        (mean_st1, mean_st3),
        (mean_st2, mean_st3),
        (mean_st1, mean_sc),
        (mean_st2, mean_sc),
        (mean_st3, mean_sc),
    )
)

for pair, d in (
    ("ST1 ‚Üî ST2:  ", dist_12),
    ("ST1 ‚Üî ST3:  ", dist_13),
    ("ST2 ‚Üî ST3:  ", dist_23),
    ("ST1 ‚Üî SC:   ", dist_1sc),
    ("ST2 ‚Üî SC:   ", dist_2sc),
    ("ST3 ‚Üî SC:   ", dist_3sc),
):
    print(f"  {pair}{d:.4f}")

# Within-ST vs ST-SC distances
st_dists = [dist_12, dist_13, dist_23]
st_sc_dists = [dist_1sc, dist_2sc, dist_3sc]

for heading, group in (
    ("\nWithin-ST distances:", st_dists),
    ("\nST ‚Üî SC distances:", st_sc_dists),
):
    print(heading)
    print(f"  Mean: {np.mean(group):.4f}")
    print(f"  Max:  {np.max(group):.4f}")

# Alignment verdict: how far SC sits relative to the spread among ST slides.
within_st_mean = np.mean(st_dists)
st_sc_mean = np.mean(st_sc_dists)
if st_sc_mean > 2.0 * within_st_mean:
    print("  ‚ö†Ô∏è  SC centroid is FAR from ST centroids (poor alignment!)")
elif st_sc_mean > 1.5 * within_st_mean:
    print("  ‚ö†Ô∏è  SC centroid is separated from ST (moderate misalignment)")
else:
    print("  ‚úì  SC and ST centroids are reasonably close (good alignment!)")

# ===================================================================
# Linear Separability Check - ALL 4 DOMAINS
# ===================================================================
_sep = '='*70
print(f"\n{_sep}")
print("Linear Separability Test (4-class: ST1, ST2, ST3, SC)")
print(_sep)

# Subsample for faster probe (optional)
N_PROBE = 2000


def _probe_subset(Z, n_max):
    """Return Z when it has at most n_max rows, else n_max rows chosen via torch.randperm."""
    if Z.shape[0] <= n_max:
        return Z
    return Z[torch.randperm(Z.shape[0])[:n_max]]


# Subsets are drawn in the fixed order ST1, ST2, ST3, SC (global RNG order).
Z1_probe, Z2_probe, Z3_probe, Z_sc_probe = (
    _probe_subset(Z, N_PROBE) for Z in (Z1_full, Z2_full, Z3_full, Z_sc_full)
)
_probe_parts = (Z1_probe, Z2_probe, Z3_probe, Z_sc_probe)

Z_all_probe = torch.cat(_probe_parts, dim=0)
# Domain index per row in concatenation order: 0-2 = ST slides, 3 = SC.
labels_probe = torch.repeat_interleave(
    torch.arange(4, dtype=torch.long),
    torch.tensor([part.shape[0] for part in _probe_parts]),
)

Z_all_probe_np = Z_all_probe.cpu().numpy()
labels_probe_np = labels_probe.cpu().numpy()

# Class-weighted probe so unequal domain sizes do not dominate the fit.
probe = LogisticRegression(max_iter=2000, random_state=42, class_weight='balanced')
probe.fit(Z_all_probe_np, labels_probe_np)
pred = probe.predict(Z_all_probe_np)

# In-sample scores: a measure of linear separability of the pooled embeddings.
acc = probe.score(Z_all_probe_np, labels_probe_np)
bal_acc = balanced_accuracy_score(labels_probe_np, pred)
chance = 0.25

print(f"  Standard Accuracy:  {acc:.4f}")
print(f"  Balanced Accuracy:  {bal_acc:.4f} (chance={chance:.3f})")

if bal_acc > 0.45:
    print("  ‚ö†Ô∏è  Domains are VERY separable (adversary failed!)")
elif bal_acc > 0.35:
    print("  ‚ö†Ô∏è  Domains are moderately separable")
else:
    print("  ‚úì  Domains are well-mixed (near chance!)")

# Confusion matrix with each row divided by that domain's row count.
cm = confusion_matrix(labels_probe_np, pred)
cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

print("\nConfusion Matrix (normalized by row):")
print("       ST1   ST2   ST3    SC")
for row_label, row in zip(['ST1', 'ST2', 'ST3', 'SC '], cm_norm):
    print(f"{row_label}  " + "".join(f"{row[j]:.3f} " for j in range(4)))

# ===================================================================
# PCA Visualization - ALL 4 DOMAINS
# ===================================================================
print(f"\n{'='*70}")
print("PCA Visualization")
print(f"{'='*70}")

# Subsample for cleaner visualization
N_VIS = 3000


def _vis_subset(Z, n_max):
    """Return Z itself when it has at most n_max rows, else n_max random rows."""
    n = Z.shape[0]
    return Z if n <= n_max else Z[torch.randperm(n)[:n_max]]


# Drawn in the fixed order ST1, ST2, ST3, SC (global RNG order).
Z1_vis, Z2_vis, Z3_vis, Z_sc_vis = (
    _vis_subset(Z, N_VIS) for Z in (Z1_full, Z2_full, Z3_full, Z_sc_full)
)
vis_parts = (Z1_vis, Z2_vis, Z3_vis, Z_sc_vis)

Z_all_vis = torch.cat(vis_parts, dim=0).cpu().numpy()
# Float domain code per row: 0-2 = ST slides, 3 = SC.
labels_vis = torch.cat(
    [torch.full((part.shape[0],), float(code)) for code, part in enumerate(vis_parts)]
).cpu().numpy()

pca = PCA(n_components=2)
Z_pca = pca.fit_transform(Z_all_vis)
pc1_pct = pca.explained_variance_ratio_[0] * 100
pc2_pct = pca.explained_variance_ratio_[1] * 100

fig, axes = plt.subplots(1, 2, figsize=(16, 6))
ax_domains, ax_binary = axes

# Left panel: one colour per domain (circles for ST slides, square for SC).
colors = ['#e74c3c', '#3498db', '#2ecc71', '#f39c12']  # red, blue, green, orange
labels_str = ['ST Slide 1', 'ST Slide 2', 'ST Slide 3', 'SC']
markers = ['o', 'o', 'o', 's']
for code, (color, text, marker) in enumerate(zip(colors, labels_str, markers)):
    sel = labels_vis == code
    ax_domains.scatter(Z_pca[sel, 0], Z_pca[sel, 1],
                       c=color, label=text,
                       alpha=0.4, s=15, marker=marker, edgecolors='none')

# Right panel: all ST slides pooled vs SC.
is_st = labels_vis < 3
ax_binary.scatter(Z_pca[is_st, 0], Z_pca[is_st, 1],
                  c='#3498db', label='ST (all slides)',
                  alpha=0.4, s=15, edgecolors='none')
ax_binary.scatter(Z_pca[~is_st, 0], Z_pca[~is_st, 1],
                  c='#f39c12', label='SC',
                  alpha=0.4, s=20, marker='s', edgecolors='none')

# Decorations shared by both panels.
for ax, title in ((ax_domains, 'All 4 Domains (3 ST + SC)'),
                  (ax_binary, 'ST vs SC (Binary View)')):
    ax.legend(loc='best', framealpha=0.9)
    ax.set_xlabel(f'PC1 ({pc1_pct:.1f}%)', fontsize=12)
    ax.set_ylabel(f'PC2 ({pc2_pct:.1f}%)', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
# plt.savefig('/home/ehtesamul/sc_st/model/gems_hscc_vicreg_output/pca_embeddings_4domains.png', dpi=150)
plt.show()

print(f"\n‚úì PCA shows explained variance: PC1={pc1_pct:.1f}%, PC2={pc2_pct:.1f}%")

# ===================================================================
# FINAL DIAGNOSTIC SUMMARY
# ===================================================================
print(f"\n{'='*70}")
print("DIAGNOSTIC SUMMARY")
print(f"{'='*70}")

within_st = np.mean(st_dists)
st_to_sc = np.mean(st_sc_dists)
print(f"‚úì Within-ST centroid dist:  {within_st:.4f}")
print(f"‚úì ST-SC centroid dist:      {st_to_sc:.4f}")
print(f"‚úì Balanced accuracy (probe): {bal_acc:.4f} (chance=0.25)")

# Verdict combines the centroid-distance ratio with the 4-class probe score.
excellent = st_to_sc < 1.5 * within_st and bal_acc < 0.35
good = st_to_sc < 2.0 * within_st and bal_acc < 0.40
if excellent:
    print("\nüéâ EXCELLENT: SC is well-aligned with ST!")
elif good:
    print("\n‚úì GOOD: Moderate alignment, but could improve")
else:
    print("\n‚ö†Ô∏è  POOR: SC is still separated from ST (needs tuning)")
    print("\nSuggestions:")
    for tip in (
        "Switch to 2-class adversary (ST vs SC)",
        "Add MMD loss with higher weight",
        "Increase disc_steps from 10 ‚Üí 20",
        "Reduce adv_slide_weight from 20 ‚Üí 10",
    ):
        print(f"  - {tip}")

print(f"{'='*70}\n")


In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, balanced_accuracy_score, confusion_matrix

def make_repr(Z, mode: str):
    """Return a transformed view of the embedding matrix ``Z``.

    Modes:
        'raw'     -> Z unchanged (the same object).
        'norm'    -> row-wise L2 normalization.
        'ln'      -> layer_norm over the feature dimension (Z.shape[1]), no affine.
        'ln_norm' -> 'ln' followed by row-wise L2 normalization.

    Raises:
        ValueError: for any other mode (the mode is the exception argument).
    """
    if mode == 'raw':
        return Z
    if mode in ('ln', 'ln_norm'):
        standardized = F.layer_norm(Z, (Z.shape[1],))
        return standardized if mode == 'ln' else F.normalize(standardized, dim=1)
    if mode == 'norm':
        return F.normalize(Z, dim=1)
    raise ValueError(mode)

@torch.no_grad()
def eval_knn_and_probe(Z1, Z2, Z3, Zsc, K=20, mode='ln_norm'):
    """Measure SC-vs-ST domain mixing of embeddings under one representation.

    Each domain's embeddings are transformed with ``make_repr(., mode)`` and
    pooled in the order ST1, ST2, ST3, SC. Two diagnostics are computed:

    * kNN: for every SC row, the fraction of its K nearest neighbours (self
      excluded) that are also SC, reported with the SC base rate nsc / n_total.
    * probe: a class-balanced 4-way logistic-regression domain probe, fit and
      scored on the same pool (an in-sample separability measure). Plain and
      balanced accuracy are printed.

    Args:
        Z1, Z2, Z3: embeddings of ST slides 0, 1, 2.
        Zsc: embeddings of the SC domain.
        K: neighbours per SC query; must not exceed n_total - 1.
        mode: 'raw', 'norm', 'ln' or 'ln_norm' (see make_repr).

    Returns:
        dict with "mode", "sc_knn_same", "sc_base_rate", "probe_bal_acc",
        "cm_norm" (4x4 row-normalized confusion matrix, rows/cols =
        [ST1, ST2, ST3, SC]; all-zero rows for empty domains) and
        "probe_acc" (plain accuracy).

    Raises:
        ValueError: unknown mode, or K larger than the available neighbours.
    """
    # Transform each domain once. Every mode is idempotent (up to layer_norm's
    # eps), so a second make_repr pass over the concatenated pool is redundant.
    reps = [make_repr(Z, mode) for Z in (Z1, Z2, Z3, Zsc)]
    sizes = [r.shape[0] for r in reps]
    n1, n2, n3, nsc = sizes
    Z_all = torch.cat(reps, dim=0)
    sc_start = n1 + n2 + n3
    n_total = sc_start + nsc
    if K > n_total - 1:
        raise ValueError(f"K={K} exceeds the {n_total - 1} available neighbours")

    # Domain index per row: 0, 1, 2 = ST slides, 3 = SC.
    labels = torch.repeat_interleave(
        torch.arange(4, device=Z_all.device),
        torch.tensor(sizes, device=Z_all.device),
    )

    # --- kNN: query SC vs all ---
    D = torch.cdist(Z_all[sc_start:], Z_all)
    rows = torch.arange(nsc, device=D.device)
    D[rows, sc_start + rows] = float('inf')  # exclude self (vectorized)
    _, knn = torch.topk(D, k=K, dim=1, largest=False)

    frac_same_sc = (labels[knn] == 3).float().mean().item()
    base_rate_sc = nsc / n_total

    # --- probe: balanced accuracy, because domain sizes are unequal ---
    Z_np = Z_all.cpu().numpy()
    y_np = labels.cpu().numpy()

    clf = LogisticRegression(max_iter=2000, random_state=42, class_weight='balanced')
    clf.fit(Z_np, y_np)
    pred = clf.predict(Z_np)  # predicted once, reused by every metric below
    acc = accuracy_score(y_np, pred)
    bal_acc = balanced_accuracy_score(y_np, pred)
    cm = confusion_matrix(y_np, pred, labels=[0, 1, 2, 3])
    row_sums = cm.sum(axis=1, keepdims=True)
    # Rows of an empty domain stay at 0 instead of becoming 0/0 = NaN.
    cmn = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums > 0)

    print("Accuracy:", acc)
    print("Balanced accuracy:", bal_acc)

    return {
        "mode": mode,
        "sc_knn_same": frac_same_sc,
        "sc_base_rate": base_rate_sc,
        "probe_bal_acc": bal_acc,
        "cm_norm": cmn,
        "probe_acc": acc,
    }

# ---- compute embeddings once (full) ----
# Full (un-subsampled) embeddings of every domain, shared by all modes below.
encoder_vicreg.eval()
with torch.no_grad():
    Z1_full, Z2_full, Z3_full = (encoder_vicreg(x_dom) for x_dom in (X1, X2, X3))
    Zsc_full = encoder_vicreg(torch.tensor(X_sc, dtype=torch.float32, device=device))

# Same embeddings, four different representations.
_rule = "="*70
for mode in ("raw", "norm", "ln", "ln_norm"):
    out = eval_knn_and_probe(Z1_full, Z2_full, Z3_full, Zsc_full, K=20, mode=mode)
    print("\n" + _rule)
    print("MODE:", out["mode"])
    print(f"SC kNN same-domain: {out['sc_knn_same']:.4f} | base-rate: {out['sc_base_rate']:.4f}")
    print(f"Probe balanced acc: {out['probe_bal_acc']:.4f} (chance=0.25)")
    print("Confusion matrix (row-normalized) rows=[ST1,ST2,ST3,SC]:")
    print(np.round(out["cm_norm"], 3))


In [None]:
import torch
import torch.nn.functional as F

@torch.no_grad()
def disc_eval(discriminator, Z, y, mode="ln_norm"):
    """Score a domain discriminator on fixed embeddings.

    Args:
        discriminator: module mapping the transformed (N, D) input to (N, C) logits.
        Z: embeddings, shape (N, D).
        y: integer class labels, shape (N,), on the same device as the logits.
        mode: input transform applied before the discriminator.
            "norm" is row-wise L2 normalization. "ln_norm" is layer_norm over
            the feature dimension followed by L2 normalization.
            NOTE(review): use whichever matches the discriminator's training
            input. The training code reportedly feeds F.normalize(z_bar_raw),
            i.e. "norm"; confirm in train_encoder.

    Returns:
        (accuracy, mean_probs): argmax accuracy against y as a float, and the
        softmax probabilities averaged over rows, as a numpy array of length C.

    Raises:
        ValueError: unknown mode.
    """
    if mode == "norm":
        Z_in = F.normalize(Z, dim=1)
    elif mode == "ln_norm":
        Z_in = F.normalize(F.layer_norm(Z, (Z.shape[1],)), dim=1)
    else:
        raise ValueError(mode)

    # Score in eval mode so dropout (train_encoder was called with
    # disc_dropout=0.1) does not randomize the logits. Afterwards, restore
    # whatever train/eval mode the caller had set.
    was_training = discriminator.training
    discriminator.eval()
    try:
        logits = discriminator(Z_in)
    finally:
        discriminator.train(was_training)

    pred = logits.argmax(dim=1)
    acc = (pred == y).float().mean().item()
    return acc, logits.softmax(dim=1).mean(dim=0).cpu().numpy()

# Pool the full embeddings. The target is the domain index in concatenation
# order (0, 1, 2 = ST slides, 3 = SC).
_domains = (Z1_full, Z2_full, Z3_full, Zsc_full)
Z_all = torch.cat(_domains, dim=0)
y_all = torch.repeat_interleave(
    torch.arange(len(_domains), dtype=torch.long, device=device),
    torch.tensor([Z.shape[0] for Z in _domains], device=device),
)

# Score the discriminator under both candidate input transforms.
acc_norm, mean_p_norm = disc_eval(discriminator, Z_all, y_all, mode="norm")
acc_ln_norm, mean_p_ln_norm = disc_eval(discriminator, Z_all, y_all, mode="ln_norm")

for tag, acc_m, probs_m in (
    ("norm", acc_norm, mean_p_norm),
    ("ln_norm", acc_ln_norm, mean_p_ln_norm),
):
    print(f"Disc acc ({tag}):".ljust(20), acc_m)
    print(f"Mean predicted probs ({tag}):".ljust(31), np.round(probs_m, 3))


In [None]:
import torch
import torch.nn.functional as F
import numpy as np

print("="*70)
print("EXPERIMENT 1: SC ‚Üí ST Distance Ratio")
print("="*70)

# Use the embeddings we already have
encoder_vicreg.eval()
with torch.no_grad():
    Z1 = encoder_vicreg(X1)
    Z2 = encoder_vicreg(X2)
    Z3 = encoder_vicreg(X3)
    Zsc = encoder_vicreg(torch.tensor(X_sc, dtype=torch.float32, device=device))

# Subsample SC queries for speed
N_TEST = 2000
if Zsc.shape[0] > N_TEST:
    idx = torch.randperm(Zsc.shape[0])[:N_TEST]
    Zsc_test = Zsc[idx]
else:
    Zsc_test = Zsc
if Zsc_test.shape[0] < 2:
    raise ValueError("Need at least 2 SC cells to measure the SC -> nearest-SC distance")

# Concatenate all ST
Zst_all = torch.cat([Z1, Z2, Z3], dim=0)

# Density-match the ST pool to the SC reference pool. Nearest-neighbour
# distance shrinks as a pool gets denser. The SC reference pool offers
# Zsc_test.shape[0] - 1 candidates per query (self excluded), so an ST pool
# twice that size (the former N_TEST * 2) pulled nearest-ST distances down
# and biased the ratio below its mixed-domain value. Draw an ST pool of the
# same size instead. If ST has fewer rows, all of them are used and the bias
# runs the other way.
n_ref = Zsc_test.shape[0]
if Zst_all.shape[0] > n_ref:
    idx_st = torch.randperm(Zst_all.shape[0])[:n_ref]
    Zst_test = Zst_all[idx_st]
else:
    Zst_test = Zst_all

# Normalize (same L2 normalization as the kNN tests)
Zsc_norm = F.normalize(Zsc_test, dim=1)
Zst_norm = F.normalize(Zst_test, dim=1)

# Distance from each SC to nearest ST
D_sc_to_st = torch.cdist(Zsc_norm, Zst_norm)
dist_to_nearest_st = D_sc_to_st.min(dim=1)[0]

# Distance from each SC to the nearest *other* SC. The query set doubles as
# the SC reference pool, so the square matrix's diagonal is each self-distance.
D_sc_to_sc = torch.cdist(Zsc_norm, Zsc_norm)
D_sc_to_sc.fill_diagonal_(float('inf'))
dist_to_nearest_sc = D_sc_to_sc.min(dim=1)[0]

# Ratio (> 1: SC cells are closer to other SC cells than to any ST spot)
ratio = (dist_to_nearest_st / (dist_to_nearest_sc + 1e-8)).cpu().numpy()

print(f"SC ‚Üí nearest ST distance:  {dist_to_nearest_st.mean().item():.4f} ¬± {dist_to_nearest_st.std().item():.4f}")
print(f"SC ‚Üí nearest SC distance:  {dist_to_nearest_sc.mean().item():.4f} ¬± {dist_to_nearest_sc.std().item():.4f}")
print(f"Ratio (ST/SC):             {ratio.mean():.4f} ¬± {ratio.std():.4f}")
print(f"Median ratio:              {np.median(ratio):.4f}")

if ratio.mean() > 2.0:
    print("\n‚ö†Ô∏è  SC is FAR from ST (ratio > 2 means SC forms tight cluster away from ST)")
elif ratio.mean() > 1.3:
    print("\n‚ö†Ô∏è  SC is moderately separated from ST")
else:
    print("\n‚úì  SC and ST are locally overlapping (ratio near 1)")

print("="*70)


In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score, roc_auc_score
import numpy as np

print("="*70)
print("EXPERIMENT 2: Binary Probe (ST vs SC)")
print("="*70)

# Pool the full embeddings. Binary target: 0 = ST (any slide), 1 = SC.
n1, n2, n3 = Z1.shape[0], Z2.shape[0], Z3.shape[0]
nsc = Zsc.shape[0]

Z_all = torch.cat([Z1, Z2, Z3, Zsc], dim=0)
y_binary = (torch.arange(n1 + n2 + n3 + nsc) >= n1 + n2 + n3).to(torch.long)

Z_np = F.normalize(Z_all, dim=1).cpu().numpy()
y_np = y_binary.cpu().numpy()

# Class-balanced logistic probe, SAGA solver with a generous iteration budget.
clf = LogisticRegression(
    max_iter=5000, random_state=42, class_weight='balanced', solver='saga'
)
clf.fit(Z_np, y_np)

pred = clf.predict(Z_np)
pred_proba = clf.predict_proba(Z_np)[:, 1]  # probability of class 1 (SC)

bal_acc = balanced_accuracy_score(y_np, pred)
auc = roc_auc_score(y_np, pred_proba)

# Per-class recall
sc_rows = y_np == 1
acc_st = (pred[~sc_rows] == 0).mean()
acc_sc = (pred[sc_rows] == 1).mean()

print(f"Balanced Accuracy: {bal_acc:.4f} (chance=0.50)")
print(f"ROC-AUC:           {auc:.4f} (chance=0.50)")
print(f"ST accuracy:       {acc_st:.4f} (how many ST predicted as ST)")
print(f"SC accuracy:       {acc_sc:.4f} (how many SC predicted as SC)")

if bal_acc > 0.85:
    verdict = "\n‚ö†Ô∏è  ST and SC are VERY separable (adversary failed)"
elif bal_acc > 0.70:
    verdict = "\n‚ö†Ô∏è  ST and SC are moderately separable"
elif bal_acc > 0.60:
    verdict = "\n‚ö†Ô∏è  ST and SC are slightly separable"
else:
    verdict = "\n‚úì  ST and SC are mixed (near chance)"
print(verdict)

print("="*70)


In [None]:
import torch
import torch.nn.functional as F

print("="*70)
print("EXPERIMENT 3: Discriminator Eval (2-class: ST vs SC)")
print("="*70)

# Build binary labels (ST rows first, then SC: the row order of Z_all)
y_binary_disc = torch.cat([
    torch.zeros(n1 + n2 + n3, dtype=torch.long, device=device),
    torch.ones(nsc, dtype=torch.long, device=device)
])

# Inference mode: train_encoder was called with disc_dropout=0.1, so a
# discriminator left in train mode would produce dropout-randomized logits.
discriminator.eval()

# Check discriminator output shape
with torch.no_grad():
    test_out = discriminator(F.normalize(Z1[:10], dim=1))
    n_classes = test_out.shape[1]
    print(f"Discriminator outputs {n_classes} classes")

if n_classes == 2:
    print("‚úì Discriminator is 2-class (ST vs SC)")

    # Evaluate on L2-normalized embeddings (assumed to match the training input)
    Z_all_norm = F.normalize(Z_all, dim=1)

    with torch.no_grad():
        logits = discriminator(Z_all_norm)
        pred = logits.argmax(dim=1)
        probs = logits.softmax(dim=1)

    acc = (pred == y_binary_disc).float().mean().item()

    # Per-class accuracy
    acc_st = (pred[y_binary_disc == 0] == 0).float().mean().item()
    acc_sc = (pred[y_binary_disc == 1] == 1).float().mean().item()
    # ST and SC row counts need not match, so plain accuracy can be high just
    # by predicting the larger domain. Judge separation by balanced accuracy
    # (mean per-class recall), whose chance level is 0.5.
    bal_acc_disc = 0.5 * (acc_st + acc_sc)

    # Mean predicted probabilities
    mean_probs = probs.mean(dim=0).cpu().numpy()

    print(f"\nDiscriminator Accuracy: {acc:.4f}")
    print(f"  Balanced accuracy:       {bal_acc_disc:.4f} (chance=0.50)")
    print(f"  ST correctly classified: {acc_st:.4f}")
    print(f"  SC correctly classified: {acc_sc:.4f}")
    print(f"  Mean predicted probs [ST, SC]: {np.round(mean_probs, 3)}")

    if bal_acc_disc > 0.85:
        print("\n‚ö†Ô∏è  Discriminator EASILY separates ST/SC (adversary didn't work)")
    elif bal_acc_disc > 0.70:
        print("\n‚ö†Ô∏è  Discriminator can separate ST/SC (adversary partially worked)")
    elif bal_acc_disc > 0.60:
        print("\n‚úì  Discriminator is somewhat confused")
    else:
        print("\n‚úì  Discriminator is very confused (near chance)")
else:
    print(f"‚ö†Ô∏è  Discriminator is {n_classes}-class (not binary)")
    print("Cannot run binary eval - discriminator architecture mismatch")

print("="*70)


In [None]:
import torch
import torch.nn.functional as F
import numpy as np

print("="*70)
print("EXPERIMENT 4: Discriminator sanity check on training-style batch")
print("="*70)

# Build the exact same pooled data you use in training
X_ssl_eval = torch.cat([st_expr, sc_expr], dim=0)
y_ssl_eval = torch.cat([
    torch.zeros(st_expr.shape[0], dtype=torch.long, device=device),  # ST=0
    torch.ones(sc_expr.shape[0], dtype=torch.long, device=device)    # SC=1
], dim=0)

# Take a balanced batch (up to B//2 random rows from each domain, then shuffled)
B = 512
idx_st = torch.randperm(st_expr.shape[0], device=device)[:B//2]
idx_sc = torch.randperm(sc_expr.shape[0], device=device)[:B//2] + st_expr.shape[0]
idx = torch.cat([idx_st, idx_sc], dim=0)
idx = idx[torch.randperm(idx.shape[0], device=device)]

Xb = X_ssl_eval[idx]
yb = y_ssl_eval[idx]

encoder_vicreg.eval()
discriminator.eval()

with torch.no_grad():
    z = encoder_vicreg(Xb)
    z_in = F.normalize(z, dim=1)   # training-style discriminator input
    logits = discriminator(z_in)

# The ST=0 / SC=1 targets above only make sense for a 2-class discriminator.
# With more output classes, argmax indices do not correspond to these labels.
n_classes = logits.shape[1]
if n_classes == 2:
    pred = logits.argmax(dim=1)
    acc = (pred == yb).float().mean().item()
    acc_st = (pred[yb==0] == 0).float().mean().item()
    acc_sc = (pred[yb==1] == 1).float().mean().item()
    pmean = logits.softmax(dim=1).mean(dim=0).cpu().numpy()

    print(f"Batch disc acc: {acc:.4f} | ST acc: {acc_st:.4f} | SC acc: {acc_sc:.4f}")
    print("Mean predicted probs [ST, SC]:", np.round(pmean, 3))
else:
    print(f"‚ö†Ô∏è  Discriminator is {n_classes}-class (not binary)")
    print("Cannot run binary eval - discriminator architecture mismatch")


In [None]:
import torch
import torch.nn.functional as F
import numpy as np

print("="*70)
print("EXPERIMENT 5: SC kNN ST-neighbor rate")
print("="*70)

K = 20
Zst = torch.cat([Z1_full, Z2_full, Z3_full], dim=0)
Zsc = Zsc_full

Zst_n = F.normalize(Zst, dim=1)
Zsc_n = F.normalize(Zsc, dim=1)

Z_all = torch.cat([Zst_n, Zsc_n], dim=0)
labels = torch.cat([
    torch.zeros(Zst_n.shape[0], dtype=torch.long, device=device),  # ST=0
    torch.ones(Zsc_n.shape[0], dtype=torch.long, device=device)    # SC=1
], dim=0)

sc_start = Zst_n.shape[0]
n_all = Z_all.shape[0]
if K > n_all - 1:
    raise ValueError(f"K={K} exceeds the {n_all - 1} available neighbours")

# Query the SC rows in chunks. On the full (un-subsampled) data, a single
# (n_sc x n_all) distance matrix can exhaust device memory. Each row's top-K
# depends only on that row, so chunking yields the same neighbours.
QUERY_CHUNK = 4096
knn_parts = []
for q0 in range(0, Zsc_n.shape[0], QUERY_CHUNK):
    q1 = min(q0 + QUERY_CHUNK, Zsc_n.shape[0])
    D = torch.cdist(Z_all[sc_start + q0:sc_start + q1], Z_all)
    # Exclude self among SC points: query row r is pool column sc_start + q0 + r.
    rows = torch.arange(q1 - q0, device=D.device)
    D[rows, sc_start + q0 + rows] = float('inf')
    knn_parts.append(torch.topk(D, k=K, largest=False, dim=1).indices)
knn = torch.cat(knn_parts, dim=0)

frac_st_neighbors = (labels[knn] == 0).float().mean().item()
print(f"SC ‚Üí fraction of ST neighbors (K={K}): {frac_st_neighbors:.4f}")
print(f"SC base-rate ST fraction:              {Zst_n.shape[0] / (Zst_n.shape[0] + Zsc_n.shape[0]):.4f}")


In [None]:
import torch
import torch.nn.functional as F
import numpy as np

print("="*70)
print("EXPERIMENT 6: Density-controlled kNN (downsample SC)")
print("="*70)

Zst = torch.cat([Z1_full, Z2_full, Z3_full], dim=0)
Zsc = Zsc_full

# Downsample both domains to the smaller domain's size so neither is denser
n_st = Zst.shape[0]
n_sc = Zsc.shape[0]
m = min(n_st, n_sc)

idx_st = torch.randperm(n_st, device=device)[:m]
idx_sc = torch.randperm(n_sc, device=device)[:m]

Zst_m = F.normalize(Zst[idx_st], dim=1)
Zsc_m = F.normalize(Zsc[idx_sc], dim=1)

Z_all = torch.cat([Zst_m, Zsc_m], dim=0)
labels = torch.cat([
    torch.zeros(m, dtype=torch.long, device=device),  # ST=0
    torch.ones(m, dtype=torch.long, device=device)    # SC=1
], dim=0)

K = 20
sc_start = m
if K > 2 * m - 1:
    raise ValueError(f"K={K} exceeds the {2 * m - 1} available neighbours")
D = torch.cdist(Z_all[sc_start:], Z_all)
# Exclude each SC query's own column in one vectorized write.
rows = torch.arange(m, device=D.device)
D[rows, sc_start + rows] = float('inf')

_, knn = torch.topk(D, k=K, dim=1, largest=False)
frac_st_neighbors = (labels[knn] == 0).float().mean().item()

print(f"Balanced pool size: ST={m}, SC={m}")
print(f"SC ‚Üí fraction of ST neighbors (K={K}): {frac_st_neighbors:.4f} (ideal ~0.50)")


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

print("="*70)
print("EXPERIMENT 7: Fresh discriminator on frozen embeddings")
print("="*70)

class SmallDisc(nn.Module):
    """Two-hidden-layer MLP probe mapping a d-dim embedding to 2 logits.

    Architecture: Linear(d, 256) -> ReLU -> Linear(256, 256) -> ReLU -> Linear(256, 2),
    stored as ``self.net`` (an ``nn.Sequential``).
    """

    def __init__(self, d):
        super().__init__()
        width = 256
        stack = [
            nn.Linear(d, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 2),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the 2 class logits for each row of ``x``."""
        return self.net(x)

# Build dataset (subsample for speed). Both domains get the same number of
# points so the chance level is 0.5 and the overall accuracy cannot be
# inflated by always predicting the larger domain.
N = 5000
Zst = torch.cat([Z1_full, Z2_full, Z3_full], dim=0)
Zsc = Zsc_full
n_per_domain = min(N, Zst.shape[0], Zsc.shape[0])
idx_st = torch.randperm(Zst.shape[0], device=device)[:n_per_domain]
idx_sc = torch.randperm(Zsc.shape[0], device=device)[:n_per_domain]

X = torch.cat([Zst[idx_st], Zsc[idx_sc]], dim=0).detach()
y = torch.cat([
    torch.zeros(idx_st.shape[0], dtype=torch.long, device=device),  # ST=0
    torch.ones(idx_sc.shape[0], dtype=torch.long, device=device),   # SC=1
], dim=0)

# Match training-style input to disc (L2-normalised embeddings)
X = F.normalize(X, dim=1)

# Hold out 20% of the points for evaluation. Accuracy on the points the probe
# was trained on measures memorisation as well as domain separability.
perm_all = torch.randperm(X.shape[0], device=device)
n_holdout = max(1, int(0.2 * X.shape[0]))
idx_eval, idx_train = perm_all[:n_holdout], perm_all[n_holdout:]
X_train, y_train = X[idx_train], y[idx_train]
X_eval, y_eval = X[idx_eval], y[idx_eval]

disc = SmallDisc(X.shape[1]).to(device)
opt = torch.optim.Adam(disc.parameters(), lr=1e-3, weight_decay=1e-4)

disc.train()
for t in range(200):
    perm = torch.randperm(X_train.shape[0], device=device)[:512]
    xb, yb = X_train[perm], y_train[perm]
    logits = disc(xb)
    loss = F.cross_entropy(logits, yb)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()

disc.eval()
with torch.no_grad():
    acc_train = (disc(X_train).argmax(dim=1) == y_train).float().mean().item()
    # Reported accuracies below are on the held-out split only.
    logits = disc(X_eval)
    pred = logits.argmax(dim=1)
    acc = (pred == y_eval).float().mean().item()
    acc_st = (pred[y_eval == 0] == 0).float().mean().item()
    acc_sc = (pred[y_eval == 1] == 1).float().mean().item()

print(f"Held-out points: {n_holdout} of {X.shape[0]} | train acc: {acc_train:.4f}")
print(f"Fresh disc acc: {acc:.4f} | ST acc: {acc_st:.4f} | SC acc: {acc_sc:.4f}")


In [None]:
# ===================================================================
# DIAGNOSTIC: Discriminator Analysis
# ===================================================================
from sklearn.metrics import confusion_matrix

print("\n" + "="*70)
print("DIAGNOSTIC: Discriminator Confusion Matrix")
print("="*70)

n_slides = 3  # ST slide ids 0, 1, 2 in slide_ids

# Remember the current modes so they can be restored afterwards; otherwise
# both models would be left in eval mode for any later training cell.
_enc_was_training = encoder_vicreg.training
_disc_was_training = discriminator.training
encoder_vicreg.eval()
discriminator.eval()

try:
    with torch.no_grad():
        # Sample up to n_test spots from each of the 3 ST slides.
        idx_s0 = torch.where(slide_ids == 0)[0]
        idx_s1 = torch.where(slide_ids == 1)[0]
        idx_s2 = torch.where(slide_ids == 2)[0]

        n_test = 200  # per slide
        # Draw each permutation on the same device as the indices it selects.
        idx_test = torch.cat([
            idx_s[torch.randperm(len(idx_s), device=idx_s.device)[:n_test]]
            for idx_s in (idx_s0, idx_s1, idx_s2)
        ])

        X_test = st_expr[idx_test]
        s_test = slide_ids[idx_test]

        # Forward pass; embeddings are L2-normalised to match training.
        z_test = encoder_vicreg(X_test)
        z_test = F.normalize(z_test, dim=1)
        logits_test = discriminator(z_test)
        preds_test = logits_test.argmax(dim=1)
        n_classes = logits_test.shape[1]  # discriminator output width

        # Fixed label set: one row per slide and one column per discriminator
        # class even if some class never occurs, so the cm[i, j] indexing below
        # cannot go out of bounds.
        cm_labels = list(range(max(n_slides, n_classes)))
        cm = confusion_matrix(
            s_test.cpu().numpy(), preds_test.cpu().numpy(), labels=cm_labels
        )

        print("Confusion Matrix:")
        print("              Predicted")
        print("              " + "    ".join(f"S{j}" for j in cm_labels))
        for i in range(n_slides):
            row_prefix = "Actual  " if i == 0 else "        "
            cells = "  ".join(f"{cm[i, j]:4d}" for j in cm_labels)
            print(f"{row_prefix}S{i}   {cells}")

        per_slide_acc = [
            cm[i, i] / cm[i].sum() if cm[i].sum() > 0 else 0
            for i in range(n_slides)
        ]
        acc_s0, acc_s1, acc_s2 = per_slide_acc
        acc_total = np.trace(cm) / cm.sum() if cm.sum() > 0 else 0.0  # diagonal / total

        print(f"\nPer-slide accuracy:")
        print(f"  Slide 0: {acc_s0:.3f}")
        print(f"  Slide 1: {acc_s1:.3f}")
        print(f"  Slide 2: {acc_s2:.3f}")
        print(f"  Overall: {acc_total:.3f}")

        # Accuracy chance level for an (approximately) balanced 3-slide test set.
        chance = 1.0 / n_slides
        print(f"  Chance: {chance:.3f}")

        if acc_total < chance + 0.05:
            print("  ‚úì Discriminator is failing (good for us!)")
        elif acc_total < chance + 0.15:
            print("  ‚ö†Ô∏è  Discriminator is learning slowly")
        else:
            print("  ‚ö†Ô∏è  Discriminator is succeeding (encoder not fighting back)")

        # Prediction confidence: entropy relative to the maximum for the
        # discriminator's actual number of output classes.
        probs_test = torch.softmax(logits_test, dim=1)
        entropy = -(probs_test * torch.log(probs_test + 1e-10)).sum(dim=1)
        max_entropy = np.log(n_classes)
        uniform_prob = 1.0 / n_classes

        print(f"\nLogit confidence:")
        print(f"  Mean entropy: {entropy.mean().item():.4f} (max={max_entropy:.4f})")
        print(f"  Mean max prob: {probs_test.max(dim=1)[0].mean().item():.4f} (uniform={uniform_prob:.3f})")

        if entropy.mean().item() > 0.9 * max_entropy:
            print("  ‚úì Predictions are uniform (discriminator confused)")
        else:
            print("  ‚ö†Ô∏è  Predictions are confident")
finally:
    encoder_vicreg.train(_enc_was_training)
    discriminator.train(_disc_was_training)
