In [None]:
import torch
import numpy as np
import pandas as pd
import anndata as ad
import sys
import matplotlib.pyplot as plt
sys.path.insert(0, '/home/ehtesamul/sc_st/model')

from core_models_et_p3 import GEMSModel
from core_models_et_p1 import STSetDataset, SCSetDataset
import utils_et as uet

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ===================================================================
# 1. LOAD DATA
# ===================================================================
# Counts CSVs are stored genes x samples: the index becomes var_names and
# the columns become obs_names, hence the transpose when building AnnData.
print("Loading data...")
st_counts = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_counts_et.csv'
st_meta = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_metadata_et.csv'
sc_counts = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st2_counts_et.csv'
sc_meta = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st2_metadata_et.csv'

st_expr_df = pd.read_csv(st_counts, index_col=0)
st_meta_df = pd.read_csv(st_meta, index_col=0)
stadata = ad.AnnData(X=st_expr_df.values.T)
stadata.obs_names = st_expr_df.columns
stadata.var_names = st_expr_df.index
# NOTE(review): coordinates are attached positionally, which assumes the
# metadata rows are in the same order as the counts columns — confirm.
stadata.obsm['spatial'] = st_meta_df[['coord_x', 'coord_y']].values

sc_expr_df = pd.read_csv(sc_counts, index_col=0)
sc_meta_df = pd.read_csv(sc_meta, index_col=0)
scadata = ad.AnnData(X=sc_expr_df.values.T)
scadata.obs_names = sc_expr_df.columns
scadata.var_names = sc_expr_df.index
# Same positional-alignment assumption as for the ST metadata above.
scadata.obsm['spatial_gt'] = sc_meta_df[['coord_x', 'coord_y']].values

# Restrict both modalities to the shared genes, in a deterministic (sorted) order.
common = sorted(set(scadata.var_names) & set(stadata.var_names))
X_st = stadata[:, common].X
X_sc = scadata[:, common].X
if hasattr(X_st, "toarray"):
    X_st = X_st.toarray()
if hasattr(X_sc, "toarray"):
    X_sc = X_sc.toarray()

st_expr = torch.tensor(X_st, dtype=torch.float32, device=device)
sc_expr = torch.tensor(X_sc, dtype=torch.float32, device=device)

# A single slide: every spot gets slide id 0 before per-slide canonicalization.
st_coords_raw = torch.tensor(stadata.obsm['spatial'], dtype=torch.float32, device=device)
slide_ids = torch.zeros(st_expr.shape[0], dtype=torch.long, device=device)
st_coords, _, _ = uet.canonicalize_st_coords_per_slide(st_coords_raw, slide_ids)

print(f"✓ ST: {st_expr.shape[0]} spots × {st_expr.shape[1]} genes")
print(f"✓ SC: {sc_expr.shape[0]} cells × {sc_expr.shape[1]} genes")

# ===================================================================
# 2. LOAD TRAINED MODEL
# ===================================================================
print("\nLoading trained model...")
checkpoint_path = "/home/ehtesamul/sc_st/model/gems_mousebrain_output/phase1_st_checkpoint.pt"
# Alternative checkpoint:
# checkpoint_path = '/home/ehtesamul/sc_st/model/gems_v2_output/final_checkpoint_20251205_233814.pt'

n_genes = len(common)

model = GEMSModel(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    D_latent=32,
    c_dim=256,
    n_heads=4,
    isab_m=64,
    device=device,
    use_canonicalize=True,
    use_dist_bias=True,
    dist_bins=24,
    dist_head_shared=True,
    use_angle_features=True,
    angle_bins=8,
    knn_k=12,
    self_conditioning=True,
    sc_feat_mode='concat',
    landmarks_L=0,
)

# weights_only=False unpickles arbitrary objects: only use with trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

# The checkpoint stores one state dict per sub-module under the same attribute name.
trained_submodules = ('encoder', 'context_encoder', 'generator', 'score_net')
for submodule_name in trained_submodules:
    getattr(model, submodule_name).load_state_dict(checkpoint[submodule_name])
for submodule_name in trained_submodules:
    getattr(model, submodule_name).eval()

print("✓ Model loaded")

# ===================================================================
# 3. RUN STAGE B
# ===================================================================
# STSetDataset below reads model.targets_dict, presumably populated by this
# call (TODO confirm against GEMSModel.train_stageB).
print("\nRunning Stage B...")
slides_dict = {0: (st_coords, st_expr)}
model.train_stageB(
    slides=slides_dict,
    outdir='temp_miniset_cache',
)
print("✓ Stage B complete")

# ===================================================================
# 4. CREATE MINI-SET DATASETS
# ===================================================================
print("\nCreating mini-set datasets...")

# The datasets are handed CPU copies of the expression matrices.
st_gene_expr_dict_cpu = {0: st_expr.cpu()}
st_dataset = STSetDataset(
    targets_dict=model.targets_dict,
    encoder=model.encoder,
    st_gene_expr_dict=st_gene_expr_dict_cpu,
    n_min=128,
    n_max=384,
    D_latent=model.D_latent,
    num_samples=15,
    knn_k=12,
    device=device,
    landmarks_L=0,
    pool_mult=2.0,
    stochastic_tau=1.0,
)

sc_gene_expr_cpu = sc_expr.cpu()
# SC mini-sets are drawn larger (384-512 cells) than ST ones (128-384 spots).
sc_dataset = SCSetDataset(
    sc_gene_expr=sc_gene_expr_cpu,
    encoder=model.encoder,
    n_min=384,
    n_max=512,
    num_samples=25,
    device=device,
    landmarks_L=0,
    pool_mult=2.0,       # same style as ST
    stochastic_tau=1.0,  # same style as ST
    knn_k=12,
)

print(f"✓ ST dataset: {len(st_dataset)} mini-sets")
print(f"✓ SC dataset: {len(sc_dataset)} mini-sets")

# ===================================================================
# 5. SAMPLE MINI-SETS
# ===================================================================
print("\nSampling mini-sets...")

# Materialize every mini-set once; the encoder is run without autograd.
with torch.no_grad():
    st_minisets = [st_dataset[i] for i in range(len(st_dataset))]
    sc_minisets = [sc_dataset[i] for i in range(len(sc_dataset))]

print(f"✓ Sampled {len(st_minisets)} ST mini-sets")
print(f"✓ Sampled {len(sc_minisets)} SC mini-sets")

# Check structure
print(f"\nST mini-set #0 keys: {list(st_minisets[0].keys())}")
print(f"SC mini-set #0 keys: {list(sc_minisets[0].keys())}")

print("\nSample ST mini-set #0:")
print(f"  Z_set shape: {st_minisets[0]['Z_set'].shape}")
print(f"  V_target shape: {st_minisets[0]['V_target'].shape}")
print(f"  n points: {st_minisets[0]['n']}")

print("\nSample SC mini-set #0:")
print(f"  global_indices shape: {sc_minisets[0]['global_indices'].shape}")
print(f"  n points: {sc_minisets[0]['n']}")

# ===================================================================
# 6. PLOT MINI-SETS
# ===================================================================
print("\nPlotting mini-sets...")

# PARAMETERS - Adjust as needed
n_st_plots = 5   # Number of ST mini-sets to plot
n_sc_plots = 10  # Number of SC mini-sets to plot
max_cols = 3     # Maximum plots per row


def _plot_miniset_grid(coord_arrays, titles, max_cols, color=None):
    """Scatter each coordinate array on its own subplot, ``max_cols`` per row.

    Args:
        coord_arrays: sequence of arrays whose columns 0 and 1 are x and y.
        titles: one subplot title per array, in the same order.
        max_cols: maximum number of subplots per row.
        color: scatter colour; None keeps matplotlib's default colour.
    """
    n_plots = len(coord_arrays)
    if n_plots == 0:
        return
    n_rows = (n_plots + max_cols - 1) // max_cols
    n_cols = min(n_plots, max_cols)
    _, axes = plt.subplots(n_rows, n_cols, figsize=(6.5 * n_cols, 6 * n_rows))
    # plt.subplots returns a bare Axes or a 1-D array for a single row/column.
    axes = np.atleast_2d(axes).reshape(n_rows, n_cols)
    for i, ax in enumerate(axes.flat):
        if i >= n_plots:
            ax.axis('off')  # hide unused trailing subplots
            continue
        coords = coord_arrays[i]
        ax.scatter(coords[:, 0], coords[:, 1], s=20, alpha=0.7, c=color)
        ax.set_title(titles[i], fontsize=10)
        ax.set_aspect('equal')
        ax.grid(True, alpha=0.3)
    plt.tight_layout()
    # plt.savefig('minisets_visualization.png', dpi=150, bbox_inches='tight')
    plt.show()


# Slicing clamps the request to the number of mini-sets actually sampled.
st_to_plot = st_minisets[:max(n_st_plots, 0)]
_plot_miniset_grid(
    # Only the first n rows of V_target are valid points.
    [m['V_target'].cpu().numpy()[:m['n']] for m in st_to_plot],
    [f'ST Mini-set {i} (n={m["n"]})' for i, m in enumerate(st_to_plot)],
    max_cols,
)

sc_to_plot = sc_minisets[:max(n_sc_plots, 0)]
_plot_miniset_grid(
    # SC mini-sets are plotted at their ground-truth positions via global_indices.
    [scadata.obsm['spatial_gt'][m['global_indices'].cpu().numpy()[:m['n']]] for m in sc_to_plot],
    [f'SC Mini-set {i} (n={m["n"]})' for i, m in enumerate(sc_to_plot)],
    max_cols,
    color='orange',
)

print("\n✓ MINI-SETS READY FOR INFERENCE!")
# The SC mini-sets built above are addressed via 'global_indices' (rows of
# sc_expr / scadata) with the first 'n' entries valid — not as A/B pairs.
print("\nNote: SC mini-sets expose 'global_indices' (rows of sc_expr)")
print("      and 'n' (number of valid entries); there are no A/B pairs")

In [None]:
from sklearn.neighbors import NearestNeighbors

# ===================================================================
# COMPUTE GLOBAL kNN ON ST COORDINATES
# ===================================================================
k_global = 15

st_coords_cpu = st_coords.cpu().numpy()

# Query k+1 neighbours because each spot is returned as its own nearest
# neighbour; column 0 is dropped below. (Assumes ST coordinates are unique,
# otherwise ties could reorder self away from column 0.)
nbrs = NearestNeighbors(n_neighbors=k_global + 1, algorithm='ball_tree').fit(st_coords_cpu)
global_knn_indices = nbrs.kneighbors(st_coords_cpu, return_distance=False)
global_knn_indices = global_knn_indices[:, 1:]  # Remove self

print(f"✓ Computed global {k_global}-NN for {st_coords_cpu.shape[0]} ST spots")

# ===================================================================
# CHECK NEIGHBOR COVERAGE IN EACH MINISET
# ===================================================================
# For every spot appearance in every ST mini-set, record the fraction of its
# global k-NN that were also sampled into the same mini-set.
coverage_ratios = []

for miniset in st_minisets:
    n_valid = miniset['n']
    # Only the first n entries are valid (same slicing as the inference cell).
    global_idx = miniset['overlap_info']['indices'].cpu().numpy()[:n_valid]

    # kNN rows hold distinct spots, so the membership count per row equals the
    # size of the set intersection computed per spot.
    hits = np.isin(global_knn_indices[global_idx], global_idx).sum(axis=1)
    coverage_ratios.extend((hits / k_global).tolist())

coverage_ratios = np.array(coverage_ratios)

# ===================================================================
# STATISTICS
# ===================================================================
mean_cov = coverage_ratios.mean()
median_cov = np.median(coverage_ratios)
min_cov = coverage_ratios.min()
max_cov = coverage_ratios.max()

print(f"\n{'='*60}")
print("NEIGHBOR COVERAGE STATISTICS")
print(f"{'='*60}")
print(f"Total cell appearances: {len(coverage_ratios)}")
print(f"Mean coverage: {mean_cov:.3f} ({mean_cov * k_global:.1f}/{k_global})")
print(f"Median coverage: {median_cov:.3f} ({median_cov * k_global:.1f}/{k_global})")
print(f"Min coverage: {min_cov:.3f} ({min_cov * k_global:.1f}/{k_global})")
print(f"Max coverage: {max_cov:.3f} ({max_cov * k_global:.1f}/{k_global})")

# Cumulative distribution: share of appearances with at least `threshold` true neighbours present.
print("\nCoverage distribution:")
for threshold in [0, 2, 5, 8, 10, 12, 14]:
    count = np.sum(coverage_ratios >= threshold / k_global)
    pct = 100 * count / len(coverage_ratios)
    print(f"  ≥{threshold:2d}/{k_global}: {count:6d} ({pct:5.1f}%)")

# ===================================================================
# PLOT
# ===================================================================
# Histogram of how many true global neighbours each spot appearance kept,
# with one integer-centred bin per possible count (0..k_global).
neighbour_counts = coverage_ratios * k_global
bin_edges = np.arange(0, k_global + 2) - 0.5
mean_line = np.mean(coverage_ratios) * k_global
median_line = np.median(coverage_ratios) * k_global

fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.hist(neighbour_counts, bins=bin_edges, edgecolor='black', alpha=0.7)
ax.axvline(mean_line, color='red', linestyle='--', linewidth=2,
           label=f'Mean: {coverage_ratios.mean() * k_global:.1f}')
ax.axvline(median_line, color='orange', linestyle='--', linewidth=2,
           label=f'Median: {np.median(coverage_ratios) * k_global:.1f}')
ax.set_xlabel(f'Number of true {k_global}-NN present in miniset', fontsize=12)
ax.set_ylabel('Frequency', fontsize=12)
ax.set_title('ST Miniset Neighbor Coverage', fontsize=14, fontweight='bold')
ax.legend()
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist, squareform
from scipy.stats import pearsonr, spearmanr
from sklearn.neighbors import NearestNeighbors

import torch
import numpy as np
import pandas as pd
import anndata as ad
import sys
import matplotlib.pyplot as plt
import os

sys.path.insert(0, '/home/ehtesamul/sc_st/model')


from core_models_et_p3 import GEMSModel
from core_models_et_p1 import STSetDataset, SCSetDataset
import utils_et as uet

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ===================================================================
# 1. LOAD DATA
# ===================================================================
# Counts CSVs are stored genes x samples: the index becomes var_names and
# the columns become obs_names, hence the transpose when building AnnData.
print("Loading data...")
st_counts = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_counts_et.csv'
st_meta = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_metadata_et.csv'
sc_counts = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st2_counts_et.csv'
sc_meta = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st2_metadata_et.csv'

st_expr_df = pd.read_csv(st_counts, index_col=0)
st_meta_df = pd.read_csv(st_meta, index_col=0)
stadata = ad.AnnData(X=st_expr_df.values.T)
stadata.obs_names = st_expr_df.columns
stadata.var_names = st_expr_df.index
# NOTE(review): coordinates are attached positionally, which assumes the
# metadata rows are in the same order as the counts columns — confirm.
stadata.obsm['spatial'] = st_meta_df[['coord_x', 'coord_y']].values

sc_expr_df = pd.read_csv(sc_counts, index_col=0)
sc_meta_df = pd.read_csv(sc_meta, index_col=0)
scadata = ad.AnnData(X=sc_expr_df.values.T)
scadata.obs_names = sc_expr_df.columns
scadata.var_names = sc_expr_df.index
# Same positional-alignment assumption as for the ST metadata above.
scadata.obsm['spatial_gt'] = sc_meta_df[['coord_x', 'coord_y']].values

# Restrict both modalities to the shared genes, in a deterministic (sorted) order.
common = sorted(set(scadata.var_names) & set(stadata.var_names))
X_st = stadata[:, common].X
X_sc = scadata[:, common].X
if hasattr(X_st, "toarray"):
    X_st = X_st.toarray()
if hasattr(X_sc, "toarray"):
    X_sc = X_sc.toarray()

st_expr = torch.tensor(X_st, dtype=torch.float32, device=device)
sc_expr = torch.tensor(X_sc, dtype=torch.float32, device=device)

# A single slide: every spot gets slide id 0 before per-slide canonicalization.
st_coords_raw = torch.tensor(stadata.obsm['spatial'], dtype=torch.float32, device=device)
slide_ids = torch.zeros(st_expr.shape[0], dtype=torch.long, device=device)
st_coords, _, _ = uet.canonicalize_st_coords_per_slide(st_coords_raw, slide_ids)

print(f"✓ ST: {st_expr.shape[0]} spots × {st_expr.shape[1]} genes")
print(f"✓ SC: {sc_expr.shape[0]} cells × {sc_expr.shape[1]} genes")

# ===================================================================
# 2. LOAD TRAINED MODEL
# ===================================================================
print("\nLoading trained model...")
# Alternative: the phase-2 SC-finetuned checkpoint.
# checkpoint_path = "/home/ehtesamul/sc_st/model/gems_mousebrain_output/phase2_sc_finetuned_checkpoint.pt"
checkpoint_path = "/home/ehtesamul/sc_st/model/gems_mousebrain_output/phase1_st_checkpoint.pt"


n_genes = len(common)
# NOTE(review): landmarks_L=16 here while the first cell builds the model from
# the same checkpoint with landmarks_L=0 — confirm which matches training.
model = GEMSModel(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    D_latent=32,
    c_dim=256,
    n_heads=4,
    isab_m=64,
    device=device,
    use_canonicalize=True,
    use_dist_bias=True,
    dist_bins=24,
    dist_head_shared=True,
    use_angle_features=True,
    angle_bins=8,
    knn_k=12,
    self_conditioning=True,
    sc_feat_mode='concat',
    landmarks_L=16,
)

# weights_only=False unpickles arbitrary objects: only use with trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

# Restore each trained sub-module from its same-named checkpoint entry, then
# switch all of them to eval mode.
trained_submodules = ('encoder', 'context_encoder', 'generator', 'score_net')
for submodule_name in trained_submodules:
    getattr(model, submodule_name).load_state_dict(checkpoint[submodule_name])
for submodule_name in trained_submodules:
    getattr(model, submodule_name).eval()

# ===================================================================
# COMPUTE CORAL TRANSFORMATION
# ===================================================================
print("\n" + "="*70)
print("COMPUTING CORAL TRANSFORMATION")
print("="*70)

# 1. Prepare ST gene expression dict (slide id -> CPU expression matrix)
st_gene_expr_dict = {0: st_expr.cpu()}

# 2. Load Stage B targets (needed for compute_coral_params_from_st);
#    recompute them when no cached copy exists.
targets_path = "/home/ehtesamul/sc_st/model/gems_mousebrain_output/stageB_targets/targets_dict.pt"
if os.path.exists(targets_path):
    # weights_only=False matches the checkpoint load above: since torch 2.6 the
    # default (True) refuses to unpickle non-tensor objects, which the cached
    # targets may contain. Only use with trusted files.
    model.targets_dict = torch.load(targets_path, map_location='cpu', weights_only=False)
    print(f"✓ Loaded targets_dict from {targets_path}")
else:
    print(f"⚠️  Targets not found at {targets_path}")
    print("   Running Stage B precomputation...")
    slides_dict = {0: (st_coords, st_expr)}
    model.train_stageB(slides=slides_dict, outdir="/home/ehtesamul/sc_st/model/gems_mousebrain_output/stageB_targets")
    print("✓ Stage B complete")

# 3. Compute CORAL parameters
print("\n--- Computing ST context distribution ---")
model.compute_coral_params_from_st(
    st_gene_expr_dict=st_gene_expr_dict,
    n_samples=2000,
    n_min=96,
    n_max=384,
)

print("\n--- Building CORAL transformation ---")
model.build_coral_transform(
    sc_gene_expr=sc_expr,
    n_samples=2000,
    n_min=96,
    n_max=384,
    shrink=0.01,
    eps=1e-5,
)

print("✓ CORAL transformation ready!")

print("="*70)
print("MINI-SET INFERENCE AND ANALYSIS")
print("="*70)

# ===================================================================
# HELPER FUNCTION: k-NN PRESERVATION
# ===================================================================
def compute_knn_preservation(coords_gt, coords_pred, k=10):
    """Count, per point, how many of its k nearest neighbours survive prediction.

    For every point i the k nearest neighbours other than i itself are found
    in ``coords_gt`` and in ``coords_pred``; the size of the overlap of the two
    neighbour sets is recorded.

    Args:
        coords_gt: array of shape (n, d) with reference coordinates.
        coords_pred: array of shape (n, d) with predicted coordinates for the
            same n points, in the same row order.
        k: neighbourhood size (k + 1 must not exceed n).

    Returns:
        (mean_overlap, overlaps): the mean overlap count, in [0, k], and the
        list of per-point overlap counts.

    Raises:
        ValueError: if the two coordinate arrays have different row counts.
    """
    n = coords_gt.shape[0]
    if coords_pred.shape[0] != n:
        raise ValueError(
            f"coords_gt has {n} rows but coords_pred has {coords_pred.shape[0]}"
        )

    def _neighbours_excluding_self(coords):
        # Query k + 1 so that k remain after removing the query point itself.
        nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm='ball_tree').fit(coords)
        _, idx = nbrs.kneighbors(coords)
        # With duplicate coordinates (e.g. collapsed predictions) the query point
        # is not guaranteed to be returned first, so drop it by identity rather
        # than by discarding column 0; if it is absent, keep the first k.
        return [row[row != i][:k] for i, row in enumerate(idx)]

    gt_neighbours = _neighbours_excluding_self(coords_gt)
    pred_neighbours = _neighbours_excluding_self(coords_pred)

    overlaps = [
        len(set(gt_row.tolist()) & set(pred_row.tolist()))
        for gt_row, pred_row in zip(gt_neighbours, pred_neighbours)
    ]
    return np.mean(overlaps), overlaps

# ===================================================================
# HELPER FUNCTION: COMPUTE ANISOTROPY
# ===================================================================
def compute_anisotropy(coords):
    """Return the two largest covariance eigenvalues and their ratio λ1/λ2.

    Args:
        coords: NumPy array of shape (n, d), d >= 2, with n >= 2 points.

    Returns:
        (lam1, lam2, ratio) where lam1 >= lam2 are the largest eigenvalues of
        the sample covariance (n - 1 denominator) and ratio = lam1 / (lam2 + 1e-12).
    """
    points = coords.astype(float)
    centered = points - points.mean(axis=0, keepdims=True)
    n_points = centered.shape[0]
    covariance = (centered.T @ centered) / (n_points - 1)

    # eigh returns eigenvalues in ascending order; flip to descending.
    descending = np.linalg.eigh(covariance)[0][::-1]
    lam1 = descending[0]
    lam2 = descending[1]
    return lam1, lam2, lam1 / (lam2 + 1e-12)

# ===================================================================
# ANALYZE ST MINI-SETS
# ===================================================================
N_ST_SAMPLES = 5  # number of ST mini-sets to evaluate
n_st_eval = min(N_ST_SAMPLES, len(st_minisets))

print("\n" + "="*70)
print(f"ANALYZING ST MINI-SETS ({n_st_eval} samples)")
print("="*70)

# Restore the sigma settings stored in the checkpoint (when present). They do
# not change between mini-sets, so set them once instead of per iteration.
for sigma_key in ('sigma_data', 'sigma_min', 'sigma_max'):
    if sigma_key in checkpoint:
        setattr(model, sigma_key, checkpoint[sigma_key])

st_results = []

with torch.no_grad():
    for idx in range(n_st_eval):
        print(f"\n--- ST Mini-set {idx} ---")

        miniset = st_minisets[idx]
        n = miniset['n']
        k10, k20 = min(10, n - 1), min(20, n - 1)

        # Ground truth coordinates (first n rows are the valid points)
        coords_gt = miniset['V_target'].cpu().numpy()[:n]

        # overlap_info maps mini-set rows back to rows of the slide's expression matrix.
        overlap_info = miniset['overlap_info']
        indices = overlap_info['indices'][:n]
        slide_id = overlap_info['slide_id']
        gene_expr = st_gene_expr_dict_cpu[slide_id][indices]

        print(f"  n_spots: {n}")
        print(f"  Running inference with patch_size={n}, coverage=1.0, iters=1...")

        # patch_size=n with coverage 1.0: presumably one patch spanning the whole mini-set.
        results = model.infer_sc_patchwise(
            sc_gene_expr=gene_expr,
            n_timesteps_sample=500,
            return_coords=True,
            patch_size=n,
            coverage_per_cell=1.0,
            n_align_iters=1,
            eta=0.0,
            guidance_scale=2.0,
            debug_flag=False,
        )

        # Extract predictions
        D_edm_pred = results['D_edm'].cpu().numpy()
        coords_pred = results['coords_canon'].cpu().numpy()

        # Compare each pairwise distance once (strict upper triangle).
        gt_edm = squareform(pdist(coords_gt, 'euclidean'))
        triu_indices = np.triu_indices(n, k=1)
        gt_distances = gt_edm[triu_indices]
        pred_distances = D_edm_pred[triu_indices]

        # Rescale predicted distances so their median matches the GT median.
        scale = np.median(gt_distances) / (np.median(pred_distances) + 1e-12)
        pred_distances_scaled = pred_distances * scale

        pearson_corr, _ = pearsonr(gt_distances, pred_distances_scaled)
        spearman_corr, _ = spearmanr(gt_distances, pred_distances_scaled)

        knn_k10, _ = compute_knn_preservation(coords_gt, coords_pred, k=k10)
        knn_k20, _ = compute_knn_preservation(coords_gt, coords_pred, k=k20)

        lam1_gt, lam2_gt, ratio_gt = compute_anisotropy(coords_gt)
        lam1_pred, lam2_pred, ratio_pred = compute_anisotropy(coords_pred)

        st_results.append({
            'idx': idx,
            'n': n,
            'coords_gt': coords_gt,
            'coords_pred': coords_pred,
            'pearson': pearson_corr,
            'spearman': spearman_corr,
            'knn_k10': knn_k10 / k10,
            'knn_k20': knn_k20 / k20,
            'scale': scale,
            'ratio_gt': ratio_gt,
            'ratio_pred': ratio_pred,
            'gt_distances': gt_distances,
            'pred_distances_scaled': pred_distances_scaled,
        })

        print(f"  Pearson:  {pearson_corr:.4f}")
        print(f"  Spearman: {spearman_corr:.4f}")
        print(f"  k-NN@10:  {knn_k10 / k10:.4f}")
        print(f"  k-NN@20:  {knn_k20 / k20:.4f}")
        print(f"  λ₁/λ₂ GT:   {ratio_gt:.2f}")
        print(f"  λ₁/λ₂ Pred: {ratio_pred:.2f}")

# ===================================================================
# ANALYZE SC MINI-SETS (GT coords canonicalized before comparison)
# ===================================================================
N_SC_SAMPLES = 10  # Change this to analyze any number of samples
n_sc_eval = min(N_SC_SAMPLES, len(sc_minisets))

print("\n" + "="*70)
print(f"ANALYZING SC MINI-SETS ({n_sc_eval} samples) - FIXED COORDINATE SPACE")
print("="*70)

# Restore the sigma settings stored in the checkpoint (when present), once.
for sigma_key in ('sigma_data', 'sigma_min', 'sigma_max'):
    if sigma_key in checkpoint:
        setattr(model, sigma_key, checkpoint[sigma_key])

sc_results = []

with torch.no_grad():
    for idx in range(n_sc_eval):
        print(f"\n--- SC Mini-set {idx} ---")

        miniset = sc_minisets[idx]
        n = miniset['n']
        k10, k20 = min(10, n - 1), min(20, n - 1)

        indices = miniset['global_indices'].cpu().numpy()[:n]
        coords_gt_raw = scadata.obsm['spatial_gt'][indices]

        # Canonicalize GT with the same routine applied to ST coordinates so
        # that it can be compared against the model's 'coords_canon' output.
        coords_gt_tensor = torch.tensor(coords_gt_raw, dtype=torch.float32, device=device)
        slide_ids_mini = torch.zeros(n, dtype=torch.long, device=device)
        coords_gt_canon, gt_mu, gt_scale = uet.canonicalize_st_coords_per_slide(
            coords_gt_tensor, slide_ids_mini
        )
        coords_gt = coords_gt_canon.cpu().numpy()

        print(f"  [DEBUG] GT raw range: X=[{coords_gt_raw[:,0].min():.2f}, {coords_gt_raw[:,0].max():.2f}]")
        print(f"  [DEBUG] GT canon range: X=[{coords_gt[:,0].min():.3f}, {coords_gt[:,0].max():.3f}]")
        print(f"  [DEBUG] GT scale factor: {gt_scale[0].item():.4f}")

        gene_expr = sc_gene_expr_cpu[indices]

        print(f"  n_cells: {n}")
        print("  Running inference...")

        results = model.infer_sc_patchwise(
            sc_gene_expr=gene_expr,
            n_timesteps_sample=500,
            return_coords=True,
            patch_size=n,
            coverage_per_cell=1.0,
            n_align_iters=1,
            eta=0.0,
            guidance_scale=2.0,
            debug_flag=False,
        )

        D_edm_pred = results['D_edm'].cpu().numpy()
        coords_pred = results['coords_canon'].cpu().numpy()

        print(f"  [DEBUG] Pred range: X=[{coords_pred[:,0].min():.3f}, {coords_pred[:,0].max():.3f}]")

        # Compare each pairwise distance once (strict upper triangle).
        gt_edm = squareform(pdist(coords_gt, 'euclidean'))
        triu_indices = np.triu_indices(n, k=1)
        gt_distances = gt_edm[triu_indices]
        pred_distances = D_edm_pred[triu_indices]

        # Rescale predicted distances so their median matches the GT median.
        scale = np.median(gt_distances) / (np.median(pred_distances) + 1e-12)
        pred_distances_scaled = pred_distances * scale

        pearson_corr, _ = pearsonr(gt_distances, pred_distances_scaled)
        spearman_corr, _ = spearmanr(gt_distances, pred_distances_scaled)

        knn_k10, _ = compute_knn_preservation(coords_gt, coords_pred, k=k10)
        knn_k20, _ = compute_knn_preservation(coords_gt, coords_pred, k=k20)

        lam1_gt, lam2_gt, ratio_gt = compute_anisotropy(coords_gt)
        lam1_pred, lam2_pred, ratio_pred = compute_anisotropy(coords_pred)

        sc_results.append({
            'idx': idx,
            'n': n,
            'coords_gt': coords_gt,
            'coords_gt_raw': coords_gt_raw,
            'coords_pred': coords_pred,
            'pearson': pearson_corr,
            'spearman': spearman_corr,
            'knn_k10': knn_k10 / k10,
            'knn_k20': knn_k20 / k20,
            'scale': scale,
            'ratio_gt': ratio_gt,
            'ratio_pred': ratio_pred,
            'gt_distances': gt_distances,
            'pred_distances_scaled': pred_distances_scaled,
        })

        print(f"  Pearson:  {pearson_corr:.4f}")
        print(f"  Spearman: {spearman_corr:.4f}")
        print(f"  k-NN@10:  {knn_k10 / k10:.4f}")
        print(f"  k-NN@20:  {knn_k20 / k20:.4f}")
        print(f"  λ₁/λ₂ GT:   {ratio_gt:.2f}")
        print(f"  λ₁/λ₂ Pred: {ratio_pred:.2f}")

# ===================================================================
# SUMMARY STATISTICS
# ===================================================================
print("\n" + "="*70)
print("SUMMARY STATISTICS")
print("="*70)

# (printed label, result-dict key, decimals) for each metric, in print order.
_SUMMARY_METRICS = (
    ('Pearson:', 'pearson', 4),
    ('Spearman:', 'spearman', 4),
    ('k-NN@10:', 'knn_k10', 4),
    ('k-NN@20:', 'knn_k20', 4),
    ('Anisotropy GT:', 'ratio_gt', 2),
    ('Anisotropy PR:', 'ratio_pred', 2),
)


def _print_metric_summary(group, results):
    """Print mean ± std (population std) of each summary metric over ``results``."""
    print(f"\n--- {group} MINI-SETS (n={len(results)}) ---")
    if not results:
        print("(no results)")
        return
    for label, key, decimals in _SUMMARY_METRICS:
        values = [r[key] for r in results]
        print(f"{label:<15}{np.mean(values):.{decimals}f} ± {np.std(values):.{decimals}f}")


_print_metric_summary("ST", st_results)
_print_metric_summary("SC", sc_results)

# ===================================================================
# VISUALIZATION: COMPARISON PLOTS
# ===================================================================
print("\n" + "="*70)
print("CREATING COMPARISON PLOTS (FIXED COORDINATE SPACE)")
print("="*70)

MAX_COLS = 3
n_sc = len(sc_results)
n_st = len(st_results)
n_sc_rows = int(np.ceil(n_sc / MAX_COLS))
n_st_rows = int(np.ceil(n_st / MAX_COLS))

# Grid layout, top to bottom: ST GT rows, ST Pred rows, one separator row,
# SC GT rows, SC Pred rows — each group MAX_COLS panels wide.
total_rows = 2 * n_st_rows + 2 * n_sc_rows + 1
total_cols = MAX_COLS
fig = plt.figure(figsize=(10 * total_cols, 7 * total_rows))


def _scatter_results_grid(results, coords_key, color, title_fn, first_row):
    """Draw one scatter panel per result, MAX_COLS per row, from grid row ``first_row``."""
    for i, result in enumerate(results):
        row, col = divmod(i, MAX_COLS)
        ax = plt.subplot(total_rows, total_cols, (first_row + row) * total_cols + col + 1)
        coords = result[coords_key]
        ax.scatter(coords[:, 0], coords[:, 1], s=20, alpha=0.7, c=color)
        ax.set_title(title_fn(i, result), fontsize=9)
        ax.set_aspect('equal')
        ax.grid(True, alpha=0.3)


current_row = 0

# ST GT (canonical)
_scatter_results_grid(st_results, 'coords_gt', 'blue',
                      lambda i, r: f'ST {i} GT (canon)\nn={r["n"]}', current_row)
current_row += n_st_rows

# ST Pred (canonical)
_scatter_results_grid(st_results, 'coords_pred', 'red',
                      lambda i, r: f'ST {i} Pred\nkNN={r["knn_k10"]:.3f}', current_row)
current_row += n_st_rows

# Separator row: a text-only panel in the middle column.
ax = plt.subplot(total_rows, total_cols, current_row * total_cols + total_cols // 2 + 1)
ax.text(0.5, 0.5, 'SC RESULTS BELOW', ha='center', va='center', fontsize=14, fontweight='bold')
ax.axis('off')
current_row += 1

# SC GT (canonical)
_scatter_results_grid(sc_results, 'coords_gt', 'green',
                      lambda i, r: f'SC {i} GT (canon)\nn={r["n"]}', current_row)
current_row += n_sc_rows

# SC Pred (canonical)
_scatter_results_grid(sc_results, 'coords_pred', 'orange',
                      lambda i, r: f'SC {i} Pred\nkNN={r["knn_k10"]:.3f}', current_row)

plt.suptitle(f'Mini-set Inference: ST (n={n_st}) vs SC (n={n_sc}) - All in Canonical Space',
             fontsize=14, fontweight='bold', y=0.995)
plt.tight_layout()
plt.show()

# ===================================================================
# METRIC COMPARISON BAR PLOTS
# ===================================================================
# ST and SC bars are grouped by mini-set index only; ST mini-set i and
# SC mini-set i are unrelated samples.
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

metrics = ['pearson', 'spearman', 'knn_k10', 'knn_k20', 'ratio_gt', 'ratio_pred']
titles = ['Pearson Correlation', 'Spearman Correlation', 'k-NN@10 Preservation',
          'k-NN@20 Preservation', 'Anisotropy (GT)', 'Anisotropy (Pred)']

for ax, metric, title in zip(axes.flat, metrics, titles):
    st_vals = [r[metric] for r in st_results]
    sc_vals = [r[metric] for r in sc_results]

    x_st = np.arange(len(st_vals))
    x_sc = np.arange(len(sc_vals))
    width = 0.35

    ax.bar(x_st - width/2, st_vals, width, label='ST', alpha=0.7, color='blue')
    ax.bar(x_sc + width/2, sc_vals, width, label='SC', alpha=0.7, color='orange')

    ax.set_xlabel('Mini-set Index')
    ax.set_ylabel(title)
    ax.set_title(title, fontweight='bold')
    ax.set_xticks(np.arange(max(len(st_vals), len(sc_vals))))
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\n✓ Analysis complete!")


In [None]:
# ===================================================================
# CHECK 1: CONDITIONING SENSITIVITY TEST
# Tests if SC conditioning is being used or ignored
# ===================================================================
_check1_rule = "=" * 70
print("\n" + _check1_rule)
print("CHECK 1: CONDITIONING SENSITIVITY TEST")
print(_check1_rule)

def run_conditioning_sensitivity_test(model, sc_minisets, scadata, sc_gene_expr_cpu,
                                       checkpoint, device, n_tests=3):
    """
    Probe whether the sampler actually uses the SC expression conditioning.

    For each of the first ``n_tests`` SC mini-sets, inference is run three
    times with identical sampler seeds but different conditioning input:

    - REAL:     the mini-set's own expression rows
    - SHUFFLED: the same rows in a random (seeded) permutation
    - MEAN:     every row replaced by the mini-set's mean expression vector

    The predicted distance matrices (``D_edm``) are compared position-wise.
    If SHUFFLED/MEAN outputs correlate highly with REAL, the conditioning is
    effectively being ignored.

    Args:
        model: object exposing ``infer_sc_patchwise`` and writable
            ``sigma_data`` / ``sigma_min`` / ``sigma_max`` attributes.
        sc_minisets: sequence of dicts with ``'n_A'`` and ``'global_indices_A'``.
        scadata: AnnData-like object with ``obsm['spatial_gt']`` coordinates.
        sc_gene_expr_cpu: SC expression tensor indexed by global cell id.
        checkpoint: dict; any ``sigma_*`` entries are copied onto ``model``.
        device: device used for ground-truth canonicalisation.
        n_tests: number of mini-sets to test (clamped to ``len(sc_minisets)``).

    Returns:
        list of per-mini-set dicts with the cross-condition D_pred correlations
        and the Spearman / kNN@10 scores of each condition.

    Side effects:
        Overwrites ``model.sigma_*`` from ``checkpoint`` and reseeds the global
        torch / numpy RNGs.
    """
    results = []

    # Sample with the noise-schedule parameters stored in the checkpoint.
    for key in ('sigma_data', 'sigma_min', 'sigma_max'):
        if key in checkpoint:
            setattr(model, key, checkpoint[key])

    # Never index past the available mini-sets.
    n_tests = min(n_tests, len(sc_minisets))

    for idx in range(n_tests):
        print(f"\n--- SC Mini-set {idx} ---")

        miniset = sc_minisets[idx]
        n_A = miniset['n_A']
        indices_A = miniset['global_indices_A'].cpu().numpy()[:n_A]

        # Ground-truth coordinates, canonicalised the same way as ST coordinates.
        coords_gt_raw = scadata.obsm['spatial_gt'][indices_A]
        coords_gt_tensor = torch.tensor(coords_gt_raw, dtype=torch.float32, device=device)
        slide_ids_mini = torch.zeros(n_A, dtype=torch.long, device=device)
        coords_gt_canon, _, _ = uet.canonicalize_st_coords_per_slide(coords_gt_tensor, slide_ids_mini)
        coords_gt = coords_gt_canon.cpu().numpy()

        # GT distances, pair indices and k depend only on the mini-set, not on
        # the conditioning variant, so compute them once here.
        gt_edm = squareform(pdist(coords_gt, 'euclidean'))
        triu_idx = np.triu_indices(n_A, k=1)
        k10 = min(10, n_A - 1)

        # REAL: the mini-set's own expression.
        gene_expr_real = sc_gene_expr_cpu[indices_A].clone()

        # SHUFFLED: rows permuted so position i carries another cell's profile.
        # A dedicated seeded generator makes the permutation reproducible and
        # leaves the global RNG untouched.
        perm_gen = torch.Generator().manual_seed(42 + idx)
        perm = torch.randperm(n_A, generator=perm_gen)
        gene_expr_shuffled = gene_expr_real[perm].clone()

        # MEAN: every position gets the mini-set's mean expression vector.
        gene_expr_mean = gene_expr_real.mean(dim=0, keepdim=True).expand(n_A, -1).clone()

        conditions = [
            ('REAL', gene_expr_real),
            ('SHUFFLED', gene_expr_shuffled),
            ('MEAN', gene_expr_mean),
        ]

        condition_results = {}

        for cond_name, gene_expr in conditions:
            # Identical sampler noise for every condition, so output differences
            # can only come from the conditioning input.
            torch.manual_seed(42 + idx)
            np.random.seed(42 + idx)

            with torch.no_grad():
                inf_results = model.infer_sc_patchwise(
                    sc_gene_expr=gene_expr,
                    n_timesteps_sample=500,
                    return_coords=True,
                    patch_size=n_A,
                    coverage_per_cell=1.0,
                    n_align_iters=1,
                    eta=0.0,
                    guidance_scale=2.0,
                    debug_flag=False
                )

            coords_pred = inf_results['coords_canon'].cpu().numpy()
            D_pred = inf_results['D_edm'].cpu().numpy()

            spear, _ = spearmanr(gt_edm[triu_idx], D_pred[triu_idx])
            knn_k10, _ = compute_knn_preservation(coords_gt, coords_pred, k=k10)
            knn_k10_norm = knn_k10 / k10

            condition_results[cond_name] = {
                'coords': coords_pred,
                'D_pred': D_pred,
                'spearman': spear,
                'knn10': knn_k10_norm,
            }

            print(f"  [{cond_name}] Spearman={spear:.4f}, kNN@10={knn_k10_norm:.4f}")

        # Position-wise agreement of the predicted distance matrices.
        D_real = condition_results['REAL']['D_pred'][triu_idx]
        D_shuffled = condition_results['SHUFFLED']['D_pred'][triu_idx]
        D_mean = condition_results['MEAN']['D_pred'][triu_idx]

        corr_real_shuffled, _ = pearsonr(D_real, D_shuffled)
        corr_real_mean, _ = pearsonr(D_real, D_mean)

        print(f"\n[COND-SENSITIVITY-{idx}] Checking if conditioning matters...")
        print(f"[COND-SENSITIVITY-{idx}] D_pred correlation REAL vs SHUFFLED: {corr_real_shuffled:.4f}")
        print(f"[COND-SENSITIVITY-{idx}] D_pred correlation REAL vs MEAN: {corr_real_mean:.4f}")

        if corr_real_shuffled > 0.95 and corr_real_mean > 0.95:
            print(f"[COND-SENSITIVITY-{idx}] ⚠️ HIGH CORRELATION → CONDITIONING IS BEING IGNORED!")
        elif corr_real_shuffled > 0.8 or corr_real_mean > 0.8:
            print(f"[COND-SENSITIVITY-{idx}] ⚠️ MODERATE CORRELATION → WEAK CONDITIONING")
        else:
            print(f"[COND-SENSITIVITY-{idx}] ✓ LOW CORRELATION → CONDITIONING IS ACTIVE")

        results.append({
            'idx': idx,
            'n': n_A,
            'corr_real_shuffled': corr_real_shuffled,
            'corr_real_mean': corr_real_mean,
            'spear_real': condition_results['REAL']['spearman'],
            'spear_shuffled': condition_results['SHUFFLED']['spearman'],
            'spear_mean': condition_results['MEAN']['spearman'],
            'knn_real': condition_results['REAL']['knn10'],
            'knn_shuffled': condition_results['SHUFFLED']['knn10'],
            'knn_mean': condition_results['MEAN']['knn10'],
        })

    # Summary
    print("\n" + "=" * 70)
    print("[COND-SENSITIVITY-SUMMARY]")
    print("=" * 70)

    if not results:
        print("[COND-SENSITIVITY-SUMMARY] No SC mini-sets were tested")
        return results

    avg_corr_shuffled = np.mean([r['corr_real_shuffled'] for r in results])
    avg_corr_mean = np.mean([r['corr_real_mean'] for r in results])

    print(f"[COND-SENSITIVITY-SUMMARY] Avg D_pred corr REAL vs SHUFFLED: {avg_corr_shuffled:.4f}")
    print(f"[COND-SENSITIVITY-SUMMARY] Avg D_pred corr REAL vs MEAN: {avg_corr_mean:.4f}")

    print("\n[COND-SENSITIVITY-SUMMARY] Metric comparison:")
    print(f"  Spearman: REAL={np.mean([r['spear_real'] for r in results]):.4f} "
          f"SHUFFLED={np.mean([r['spear_shuffled'] for r in results]):.4f} "
          f"MEAN={np.mean([r['spear_mean'] for r in results]):.4f}")
    print(f"  kNN@10:   REAL={np.mean([r['knn_real'] for r in results]):.4f} "
          f"SHUFFLED={np.mean([r['knn_shuffled'] for r in results]):.4f} "
          f"MEAN={np.mean([r['knn_mean'] for r in results]):.4f}")

    if avg_corr_shuffled > 0.9:
        print("\n[COND-SENSITIVITY-DIAGNOSIS] 🚨 SC CONDITIONING IS EFFECTIVELY IGNORED")
        print("[COND-SENSITIVITY-DIAGNOSIS] The model produces nearly identical outputs regardless of input")
        print("[COND-SENSITIVITY-DIAGNOSIS] → Fix: Align SC context to ST manifold, or include SC in training")
    elif avg_corr_shuffled > 0.7:
        print("\n[COND-SENSITIVITY-DIAGNOSIS] ⚠️ SC CONDITIONING IS WEAK")
        print("[COND-SENSITIVITY-DIAGNOSIS] The model partially uses conditioning but falls back to prior")
    else:
        print("\n[COND-SENSITIVITY-DIAGNOSIS] ✓ SC CONDITIONING IS ACTIVE")
        print("[COND-SENSITIVITY-DIAGNOSIS] Low kNN may be due to other issues (margin/density)")

    return results

# Run Check 1 on the first three SC mini-sets.
check1_results = run_conditioning_sensitivity_test(
    model=model,
    sc_minisets=sc_minisets,
    scadata=scadata,
    sc_gene_expr_cpu=sc_gene_expr_cpu,
    checkpoint=checkpoint,
    device=device,
    n_tests=3,
)


In [None]:
# ===================================================================
# CHECK 2: EMBEDDING/CONTEXT STATISTICS - ST vs SC
# Compare Z embeddings and context H between ST and SC
# ===================================================================
# Banner: blank line, rule, title, rule (one line each).
print("\n" + "=" * 70, "CHECK 2: EMBEDDING/CONTEXT STATISTICS - ST vs SC", "=" * 70, sep="\n")

# Summary fields recorded per mini-set and per embedding type.
_STAT_KEYS = ('norm_mean', 'norm_std', 'cos_sim_mean', 'cos_sim_std')
# Thresholds used to flag an ST-vs-SC shift.
_NORM_SHIFT_TOL = 0.3   # |SC/ST norm ratio - 1|
_COS_SHIFT_TOL = 0.15   # |SC - ST| mean pairwise cosine similarity


def _embedding_set_stats(X):
    """
    Summarise one set of row embeddings ``X`` of shape (n, d).

    Returns the mean/std of the per-row L2 norms and the mean/std of the
    cosine similarity over all unordered pairs of distinct rows.
    """
    n = X.shape[0]
    norms = X.norm(dim=1)
    X_unit = torch.nn.functional.normalize(X, dim=1)
    cos = X_unit @ X_unit.T
    # Strict upper triangle = each distinct pair exactly once. Built on X's
    # device so the boolean index needs no host/device transfer.
    pair_mask = torch.triu(torch.ones(n, n, dtype=torch.bool, device=X.device), diagonal=1)
    cos_pairs = cos[pair_mask]
    return {
        'norm_mean': norms.mean().item(),
        'norm_std': norms.std().item(),
        'cos_sim_mean': cos_pairs.mean().item(),
        'cos_sim_std': cos_pairs.std().item(),
    }


def _encode_and_contextualise(model, gene_expr):
    """Return (Z, H): encoder output (n, D_z) and context-encoder output (n, c_dim) for one un-padded set."""
    Z = model.encoder(gene_expr)
    # One set, no padding: every position is valid. Sized from Z so it always matches.
    mask = torch.ones(1, Z.shape[0], dtype=torch.bool, device=Z.device)
    H = model.context_encoder(Z.unsqueeze(0), mask).squeeze(0)
    return Z, H


def _report_embedding_shift(letter, title, noun, st_stats, sc_stats):
    """
    Print ST-vs-SC averages for one embedding type and flag large shifts.

    Returns (norm_ratio, cos_diff): the SC/ST ratio of mean norms and the
    absolute difference of mean pairwise cosine similarity.
    """
    # np.mean of an empty list yields NaN (with a NumPy warning).
    st = {k: np.mean([s[k] for s in st_stats]) for k in _STAT_KEYS}
    sc = {k: np.mean([s[k] for s in sc_stats]) for k in _STAT_KEYS}
    tag = f"EMBED-STATS-{letter}"

    print("\n" + "=" * 70)
    print(f"[EMBED-STATS] {letter} ({title}) Statistics")
    print("=" * 70)

    print(f"[{tag}] ST: ||{letter}|| mean={st['norm_mean']:.4f} std={st['norm_std']:.4f}")
    print(f"[{tag}] SC: ||{letter}|| mean={sc['norm_mean']:.4f} std={sc['norm_std']:.4f}")
    print(f"[{tag}] ST: cos_sim mean={st['cos_sim_mean']:.4f} std={st['cos_sim_std']:.4f}")
    print(f"[{tag}] SC: cos_sim mean={sc['cos_sim_mean']:.4f} std={sc['cos_sim_std']:.4f}")

    norm_ratio = sc['norm_mean'] / (st['norm_mean'] + 1e-8)
    cos_diff = abs(sc['cos_sim_mean'] - st['cos_sim_mean'])

    print(f"\n[{tag}-SHIFT] ||{letter}|| ratio (SC/ST): {norm_ratio:.4f}")
    print(f"[{tag}-SHIFT] cos_sim difference: {cos_diff:.4f}")

    if abs(norm_ratio - 1.0) > _NORM_SHIFT_TOL:
        print(f"[{tag}-SHIFT] ⚠️ LARGE NORM SHIFT between ST and SC {noun}")
    if cos_diff > _COS_SHIFT_TOL:
        print(f"[{tag}-SHIFT] ⚠️ LARGE COS_SIM SHIFT between ST and SC {noun}")

    return norm_ratio, cos_diff


def compute_embedding_stats(model, st_minisets, sc_minisets, st_gene_expr_dict_cpu,
                            sc_gene_expr_cpu, device, n_samples=5):
    """
    Compare encoder (Z) and context-encoder (H) statistics between ST and SC
    mini-sets, and print a distribution-shift diagnosis.

    For up to ``n_samples`` mini-sets of each kind, records per set the
    mean/std of token L2 norms and the mean/std of pairwise cosine similarity
    between distinct tokens, then compares the ST and SC averages.

    Args:
        model: object with ``encoder`` and ``context_encoder`` modules; both are
            switched to eval mode (a persistent side effect).
        st_minisets: sequence of dicts with ``'n'`` and ``'overlap_info'``
            holding ``'indices'`` and ``'slide_id'``.
        sc_minisets: sequence of dicts with ``'n_A'`` and ``'global_indices_A'``.
        st_gene_expr_dict_cpu: mapping slide_id -> ST expression tensor.
        sc_gene_expr_cpu: SC expression tensor indexed by global cell id.
        device: device the encoders run on.
        n_samples: maximum number of mini-sets of each kind to use.

    Returns:
        dict with keys ``'st_Z'``, ``'st_H'``, ``'sc_Z'``, ``'sc_H'``, each a list
        of per-mini-set dicts (``norm_mean``, ``norm_std``, ``cos_sim_mean``,
        ``cos_sim_std``).
    """
    st_Z_stats, st_H_stats = [], []
    sc_Z_stats, sc_H_stats = [], []

    model.encoder.eval()
    model.context_encoder.eval()

    with torch.no_grad():
        print("\n--- Computing ST embedding statistics ---")
        for idx in range(min(n_samples, len(st_minisets))):
            miniset = st_minisets[idx]
            info = miniset['overlap_info']
            indices = info['indices'][:miniset['n']]
            gene_expr = st_gene_expr_dict_cpu[info['slide_id']][indices].to(device)
            Z, H = _encode_and_contextualise(model, gene_expr)
            st_Z_stats.append(_embedding_set_stats(Z))
            st_H_stats.append(_embedding_set_stats(H))

        print("--- Computing SC embedding statistics ---")
        for idx in range(min(n_samples, len(sc_minisets))):
            miniset = sc_minisets[idx]
            indices_A = miniset['global_indices_A'].cpu().numpy()[:miniset['n_A']]
            gene_expr = sc_gene_expr_cpu[indices_A].to(device)
            Z, H = _encode_and_contextualise(model, gene_expr)
            sc_Z_stats.append(_embedding_set_stats(Z))
            sc_H_stats.append(_embedding_set_stats(H))

    norm_ratio, cos_diff = _report_embedding_shift(
        'Z', 'Encoder Output', 'Z embeddings', st_Z_stats, sc_Z_stats)
    H_norm_ratio, H_cos_diff = _report_embedding_shift(
        'H', 'Context Encoder Output', 'context embeddings', st_H_stats, sc_H_stats)

    # Diagnosis
    print("\n" + "=" * 70)
    print("[EMBED-STATS-DIAGNOSIS]")
    print("=" * 70)

    if abs(norm_ratio - 1.0) > _NORM_SHIFT_TOL or abs(H_norm_ratio - 1.0) > _NORM_SHIFT_TOL:
        print("[EMBED-STATS-DIAGNOSIS] 🚨 SIGNIFICANT DISTRIBUTION SHIFT detected")
        print("[EMBED-STATS-DIAGNOSIS] SC embeddings have different scale than ST")
        print("[EMBED-STATS-DIAGNOSIS] → The diffusion model may treat SC as OOD")
    elif cos_diff > _COS_SHIFT_TOL or H_cos_diff > _COS_SHIFT_TOL:
        print("[EMBED-STATS-DIAGNOSIS] ⚠️ MODERATE DISTRIBUTION SHIFT detected")
        print("[EMBED-STATS-DIAGNOSIS] SC tokens have different similarity structure than ST")
    else:
        print("[EMBED-STATS-DIAGNOSIS] ✓ Embedding distributions appear similar")
        print("[EMBED-STATS-DIAGNOSIS] → Issue may be elsewhere (margin/density)")

    return {
        'st_Z': st_Z_stats, 'st_H': st_H_stats,
        'sc_Z': sc_Z_stats, 'sc_H': sc_H_stats,
    }

# Expose torch.nn.functional as F in the notebook namespace.
import torch.nn.functional as F

# Run Check 2 on up to five mini-sets of each kind.
check2_results = compute_embedding_stats(
    model=model,
    st_minisets=st_minisets,
    sc_minisets=sc_minisets,
    st_gene_expr_dict_cpu=st_gene_expr_dict_cpu,
    sc_gene_expr_cpu=sc_gene_expr_cpu,
    device=device,
    n_samples=5,
)


In [None]:
# ===================================================================
# CHECK 3: SC MINISETS USING TRUE SPATIAL NEIGHBORS
# Build SC minisets using GT spatial proximity (like ST sampling)
# ===================================================================
# Banner: blank line, rule, title, rule (one line each).
print("\n" + "=" * 70, "CHECK 3: SC MINISETS WITH TRUE SPATIAL NEIGHBORS", "=" * 70, sep="\n")

def build_sc_spatial_minisets(scadata, sc_gene_expr_cpu, n_minisets=5, n_min=128, n_max=256,
                              seed=None):
    """
    Build SC mini-sets from TRUE spatial neighbours (mirrors ST patch sampling).

    Each mini-set draws a size ``n`` uniformly from ``[n_min, n_max]`` (capped
    at the number of cells) and a random center cell. It then takes the ``n``
    cells closest to that center in ground-truth coordinates. Comparing these
    with randomly composed SC mini-sets separates a patch-composition problem
    from a conditioning problem.

    Args:
        scadata: AnnData-like object with ground-truth coordinates in
            ``obsm['spatial_gt']`` (one row per cell).
        sc_gene_expr_cpu: unused; kept so existing call sites keep working.
        n_minisets: number of mini-sets to build.
        n_min, n_max: inclusive bounds on the mini-set size.
        seed: optional int. When given, sampling uses a private
            ``np.random.RandomState(seed)``, which is reproducible and leaves
            the global RNG untouched. When None, the global ``np.random`` state
            is used (previous behaviour).

    Returns:
        list of dicts with ``'indices'`` (cell indices ordered by distance from
        the center), ``'n'`` and ``'center_idx'``.
    """
    gt_coords = np.asarray(scadata.obsm['spatial_gt'])
    n_cells = gt_coords.shape[0]

    rng = np.random if seed is None else np.random.RandomState(seed)

    minisets = []
    for i in range(n_minisets):
        # Draw order (size first, then center) matches the original sampler.
        n = min(rng.randint(n_min, n_max + 1), n_cells)
        center_idx = rng.randint(0, n_cells)

        # Only distances from the center are needed. Compute one row
        # (O(n_cells) memory) instead of the full n_cells x n_cells matrix.
        dists_from_center = np.linalg.norm(gt_coords - gt_coords[center_idx], axis=1)

        # Closest first. With distinct coordinates, the center itself comes
        # first (distance 0), followed by its nearest neighbours.
        indices = np.argsort(dists_from_center)[:n]

        minisets.append({
            'indices': indices,
            'n': n,
            'center_idx': center_idx,
        })

        print(f"[SC-SPATIAL-MINISET-{i}] n={n}, center={center_idx}, "
              f"max_dist_from_center={dists_from_center[indices[-1]]:.4f}")

    return minisets

def run_sc_spatial_inference(model, sc_spatial_minisets, scadata, sc_gene_expr_cpu,
                              checkpoint, device, random_results=None):
    """
    Run inference on SC mini-sets built from ground-truth spatial neighbours.

    Reports Spearman/Pearson correlation between the GT and predicted distance
    matrices, and normalised kNN@10 / kNN@20 preservation, per mini-set and
    on average. Optionally compares kNN@10 against random SC mini-sets.

    Args:
        model: object exposing ``infer_sc_patchwise`` and writable
            ``sigma_data`` / ``sigma_min`` / ``sigma_max`` attributes.
        sc_spatial_minisets: dicts with ``'indices'`` and ``'n'``, as returned
            by ``build_sc_spatial_minisets``.
        scadata: AnnData-like object with ``obsm['spatial_gt']`` coordinates.
        sc_gene_expr_cpu: SC expression tensor indexed by global cell id.
        checkpoint: dict; any ``sigma_*`` entries are copied onto ``model``.
        device: device used for ground-truth canonicalisation.
        random_results: optional list of result dicts with a ``'knn_k10'``
            entry from random SC mini-sets, used as a baseline. When None, a
            module-level ``sc_results`` is used if it exists.
            NOTE(review): assumes its ``knn_k10`` is on the same normalised
            scale as ``knn10`` here; confirm.

    Returns:
        list of per-mini-set dicts: ``idx``, ``n``, ``spearman``, ``pearson``,
        ``knn10``, ``knn20``, ``coords_gt``, ``coords_pred``.
    """
    results = []

    # Sample with the noise-schedule parameters stored in the checkpoint.
    for key in ('sigma_data', 'sigma_min', 'sigma_max'):
        if key in checkpoint:
            setattr(model, key, checkpoint[key])

    for idx, miniset in enumerate(sc_spatial_minisets):
        indices = miniset['indices']
        n = miniset['n']

        print(f"\n--- SC Spatial Miniset {idx} (n={n}) ---")

        # Ground-truth coordinates, canonicalised the same way as ST coordinates.
        coords_gt_raw = scadata.obsm['spatial_gt'][indices]
        coords_gt_tensor = torch.tensor(coords_gt_raw, dtype=torch.float32, device=device)
        slide_ids_mini = torch.zeros(n, dtype=torch.long, device=device)
        coords_gt_canon, _, _ = uet.canonicalize_st_coords_per_slide(coords_gt_tensor, slide_ids_mini)
        coords_gt = coords_gt_canon.cpu().numpy()

        gene_expr = sc_gene_expr_cpu[indices]

        # Fixed per-mini-set seed so runs are reproducible.
        torch.manual_seed(42 + idx)
        np.random.seed(42 + idx)

        with torch.no_grad():
            inf_results = model.infer_sc_patchwise(
                sc_gene_expr=gene_expr,
                n_timesteps_sample=500,
                return_coords=True,
                patch_size=n,
                coverage_per_cell=1.0,
                n_align_iters=1,
                eta=0.0,
                guidance_scale=2.0,
                debug_flag=False
            )

        coords_pred = inf_results['coords_canon'].cpu().numpy()
        D_pred = inf_results['D_edm'].cpu().numpy()

        # Distance-matrix agreement over the upper triangle (distinct pairs).
        gt_edm = squareform(pdist(coords_gt, 'euclidean'))
        triu_idx = np.triu_indices(n, k=1)
        spear, _ = spearmanr(gt_edm[triu_idx], D_pred[triu_idx])
        pear, _ = pearsonr(gt_edm[triu_idx], D_pred[triu_idx])

        # kNN preservation, normalised by k (k capped for tiny mini-sets).
        k10, k20 = min(10, n - 1), min(20, n - 1)
        knn_k10, _ = compute_knn_preservation(coords_gt, coords_pred, k=k10)
        knn_k20, _ = compute_knn_preservation(coords_gt, coords_pred, k=k20)
        knn10, knn20 = knn_k10 / k10, knn_k20 / k20

        results.append({
            'idx': idx,
            'n': n,
            'spearman': spear,
            'pearson': pear,
            'knn10': knn10,
            'knn20': knn20,
            'coords_gt': coords_gt,
            'coords_pred': coords_pred,
        })

        print(f"[SC-SPATIAL-{idx}] Spearman={spear:.4f}, Pearson={pear:.4f}")
        print(f"[SC-SPATIAL-{idx}] kNN@10={knn10:.4f}, kNN@20={knn20:.4f}")

    # Summary
    print("\n" + "=" * 70)
    print("[SC-SPATIAL-SUMMARY] SC Minisets with TRUE Spatial Neighbors")
    print("=" * 70)

    if not results:
        print("[SC-SPATIAL-SUMMARY] No mini-sets were evaluated")
        return results

    avg_spear = np.mean([r['spearman'] for r in results])
    avg_knn10 = np.mean([r['knn10'] for r in results])
    avg_knn20 = np.mean([r['knn20'] for r in results])

    print(f"[SC-SPATIAL-SUMMARY] Avg Spearman: {avg_spear:.4f}")
    print(f"[SC-SPATIAL-SUMMARY] Avg kNN@10: {avg_knn10:.4f}")
    print(f"[SC-SPATIAL-SUMMARY] Avg kNN@20: {avg_knn20:.4f}")

    # Compare with random SC mini-sets. dir() inside a function only lists
    # local names, so look the notebook-level results up in globals().
    print("\n[SC-SPATIAL-VS-RANDOM] Comparison with random SC minisets:")
    if random_results is None:
        random_results = globals().get('sc_results')
    if random_results:
        random_knn10 = np.mean([r['knn_k10'] for r in random_results])
        print(f"[SC-SPATIAL-VS-RANDOM] Spatial kNN@10: {avg_knn10:.4f}")
        print(f"[SC-SPATIAL-VS-RANDOM] Random kNN@10:  {random_knn10:.4f}")

        if avg_knn10 > random_knn10 + 0.1:
            print(f"[SC-SPATIAL-VS-RANDOM] ✓ Spatial selection helps (+{avg_knn10 - random_knn10:.3f})")
        else:
            print("[SC-SPATIAL-VS-RANDOM] ✗ Spatial selection doesn't help much")
            print("[SC-SPATIAL-VS-RANDOM] → Problem is NOT patch composition, it's CONDITIONING")
    else:
        print("[SC-SPATIAL-VS-RANDOM] No random SC mini-set results available; skipping comparison")

    return results

# Build SC spatial minisets
sc_spatial_minisets = build_sc_spatial_minisets(
    scadata, sc_gene_expr_cpu, n_minisets=5, n_min=128, n_max=256
)

# Run inference
check3_results = run_sc_spatial_inference(
    model, sc_spatial_minisets, scadata, sc_gene_expr_cpu, checkpoint, device
)

# Visualization: one column per mini-set, GT on top and prediction below.
# The grid is sized from the results instead of being hard-coded to 5
# columns, and squeeze=False keeps `axes` 2-D even for a single mini-set.
n_cols = max(len(check3_results), 1)
fig, axes = plt.subplots(2, n_cols, figsize=(4 * n_cols, 8), squeeze=False)

for i, res in enumerate(check3_results):
    # GT
    ax = axes[0, i]
    ax.scatter(res['coords_gt'][:, 0], res['coords_gt'][:, 1], s=15, alpha=0.7, c='green')
    ax.set_title(f'SC-Spatial {i} GT\nn={res["n"]}', fontsize=10)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)

    # Pred
    ax = axes[1, i]
    ax.scatter(res['coords_pred'][:, 0], res['coords_pred'][:, 1], s=15, alpha=0.7, c='orange')
    ax.set_title(f'SC-Spatial {i} Pred\nkNN={res["knn10"]:.3f}', fontsize=10)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)

plt.suptitle('CHECK 3: SC Minisets with TRUE Spatial Neighbors', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('check3_sc_spatial_minisets.png', dpi=150, bbox_inches='tight')
plt.show()


In [None]:
# ===================================================================
# HYPOTHESIS TEST: ST vs SC Context Distribution Shift
# Collect the context tokens H of every ST and SC mini-set so the two
# populations can be compared statistically and with a linear probe.
# ===================================================================
print("\n" + "="*70)
print("TESTING: ST vs SC Context Distribution Shift")
print("="*70)

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.preprocessing import StandardScaler


def _set_context_tokens(model, gene_expr):
    """
    Return the context tokens for one un-padded set of expression rows.

    gene_expr -> model.encoder -> Z, then Z as a batch of one ->
    model.context_encoder -> H. The batch axis is dropped and H is moved to CPU.
    """
    Z = model.encoder(gene_expr)
    # No padding inside a mini-set, so every position is valid. The mask is
    # sized from the encoded rows so it always matches Z.
    mask_batch = torch.ones(1, Z.shape[0], dtype=torch.bool, device=Z.device)
    H = model.context_encoder(Z.unsqueeze(0), mask_batch)
    return H.squeeze(0).cpu()


# Extract representations with both encoders in eval mode, as compute_embedding_stats does.
model.encoder.eval()
model.context_encoder.eval()

st_contexts = []
sc_contexts = []
with torch.no_grad():
    # ST mini-sets: rows are taken from the expression tensor of their slide.
    for idx in range(len(st_minisets)):
        miniset = st_minisets[idx]
        info = miniset['overlap_info']
        indices = info['indices'][:miniset['n']]
        gene_expr = st_gene_expr_dict_cpu[info['slide_id']][indices].to(device)
        st_contexts.append(_set_context_tokens(model, gene_expr))  # (n, c_dim)

    # SC mini-sets: rows are taken from the global SC expression tensor.
    for idx in range(len(sc_minisets)):
        miniset = sc_minisets[idx]
        indices_A = miniset['global_indices_A'].cpu().numpy()[:miniset['n_A']]
        gene_expr = sc_gene_expr_cpu[indices_A].to(device)
        sc_contexts.append(_set_context_tokens(model, gene_expr))  # (n_A, c_dim)

# Concatenate all context tokens
st_H = torch.cat(st_contexts, dim=0).numpy()  # (N_st, c_dim)
sc_H = torch.cat(sc_contexts, dim=0).numpy()  # (N_sc, c_dim)

print(f"\nST contexts: {st_H.shape}")
print(f"SC contexts: {sc_H.shape}")

# ===================================================================
# 1. COMPUTE STATISTICS
# ===================================================================
print("\n--- Context Token Statistics ---")

# Per-dimension mean/std over tokens, and per-token L2 norm.
st_mean = st_H.mean(axis=0)
st_std = st_H.std(axis=0)
st_norm = np.linalg.norm(st_H, axis=1)

sc_mean = sc_H.mean(axis=0)
sc_std = sc_H.std(axis=0)
sc_norm = np.linalg.norm(sc_H, axis=1)

print(f"\nST - Mean of means: {st_mean.mean():.6f}")
print(f"ST - Mean of stds:  {st_std.mean():.6f}")
print(f"ST - Mean norm:     {st_norm.mean():.6f} ± {st_norm.std():.6f}")

print(f"\nSC - Mean of means: {sc_mean.mean():.6f}")
print(f"SC - Mean of stds:  {sc_std.mean():.6f}")
print(f"SC - Mean norm:     {sc_norm.mean():.6f} ± {sc_norm.std():.6f}")

# Relative difference of the mean token norm, with ST as the reference.
norm_diff = np.abs(st_norm.mean() - sc_norm.mean()) / st_norm.mean()
print(f"\n[Norm difference: {norm_diff*100:.2f}%]")

# Cosine similarity between the ST and SC mean vectors (epsilon avoids 0/0).
cos_sim = np.dot(st_mean, sc_mean) / (np.linalg.norm(st_mean) * np.linalg.norm(sc_mean) + 1e-12)
print(f"[Mean vector cosine similarity: {cos_sim:.4f}]")

# ===================================================================
# 2. LINEAR CLASSIFIER TEST
# If a linear probe separates ST from SC context tokens, the score net is
# conditioned on visibly different inputs for the two domains.
# ===================================================================
print("\n--- Linear Classifier Test (ST vs SC) ---")

# Pool all tokens; labels: 0 = ST, 1 = SC.
X = np.vstack([st_H, sc_H])
y = np.array([0]*len(st_H) + [1]*len(sc_H))

# Stratified, shuffled 70/30 train/test split.
# NOTE(review): tokens of one mini-set can land on both sides of the split.
# The context encoder sees the whole set, so such tokens are presumably
# correlated, which may make the AUC optimistic. A split by mini-set would be
# stricter.
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

# Standardise with statistics fitted on the training split only.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Train logistic regression
clf = LogisticRegression(max_iter=1000, random_state=42)
clf.fit(X_train_scaled, y_train)

# Held-out probability of the SC class, accuracy and AUC.
y_pred_proba = clf.predict_proba(X_test_scaled)[:, 1]
auc = roc_auc_score(y_test, y_pred_proba)
acc = clf.score(X_test_scaled, y_test)

print(f"\nClassifier Accuracy: {acc:.4f}")
print(f"Classifier AUC:      {auc:.4f}")

# Interpretation
if auc > 0.85:
    print("\n⚠️  HIGH AUC (>0.85): Strong distribution shift detected!")
    print("   The score net sees ST vs SC contexts as very different.")
    print("   This supports the 'conditioning shift' hypothesis.")
elif auc > 0.70:
    print("\n⚠️  MODERATE AUC (0.70-0.85): Moderate distribution shift.")
    print("   ST and SC contexts are distinguishable but not wildly different.")
elif auc > 0.60:
    print("\n✓  LOW AUC (0.60-0.70): Mild distribution shift.")
    print("   Contexts are similar. Shift may not be the main issue.")
else:
    print("\n✓  VERY LOW AUC (<0.60): Contexts are indistinguishable.")
    print("   Conditioning shift is NOT the problem.")

# ===================================================================
# 3. VISUALIZE (Optional)
# ===================================================================
print("\n--- Visualization ---")

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

# Set to True to write the figures to disk. "Saved" is reported only when a
# file is actually written.
SAVE_FIGURES = False

# 2-D PCA of the pooled tokens (ST rows first, then SC rows).
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X)
explained_var = pca.explained_variance_ratio_

plt.figure(figsize=(8, 6))
plt.scatter(X_pca[:len(st_H), 0], X_pca[:len(st_H), 1],
            alpha=0.5, s=10, label='ST contexts', c='blue')
plt.scatter(X_pca[len(st_H):, 0], X_pca[len(st_H):, 1],
            alpha=0.5, s=10, label='SC contexts', c='red')
plt.xlabel(f'PC1 ({explained_var[0]*100:.1f}%)')
plt.ylabel(f'PC2 ({explained_var[1]*100:.1f}%)')
plt.title('Context Token Distribution: ST vs SC')
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
if SAVE_FIGURES:
    plt.savefig('context_shift_pca.png', dpi=150)
    print("✓ Saved: context_shift_pca.png")
plt.show()

# ROC curve of the held-out linear-probe predictions
fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
plt.figure(figsize=(6, 6))
plt.plot(fpr, tpr, linewidth=2, label=f'AUC = {auc:.3f}')
plt.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC: ST vs SC Context Classification')
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
if SAVE_FIGURES:
    plt.savefig('context_shift_roc.png', dpi=150)
    print("✓ Saved: context_shift_roc.png")
plt.show()

print("\n" + "="*70)
print("ANALYSIS COMPLETE")
print("="*70)

In [None]:
# ===================================================================
# CORAL TRANSFORMATION + RE-TEST
# ===================================================================
print("\n" + "="*70)
print("APPLYING CORAL TRANSFORMATION TO SC CONTEXTS")
print("="*70)

# ===================================================================
# STEP 1: Compute Covariance Statistics
# ===================================================================
print("\nComputing ST and SC covariance statistics...")

# st_H / sc_H are already flattened token matrices of shape
# (N_st, c_dim) / (N_sc, c_dim).
st_H_flat = st_H
sc_H_flat = sc_H

# Per-dimension means, kept on `device` for the transform in step 3.
mu_st = torch.tensor(st_H_flat.mean(axis=0), dtype=torch.float32, device=device)
mu_sc = torch.tensor(sc_H_flat.mean(axis=0), dtype=torch.float32, device=device)

# Unbiased (N - 1) sample covariance of the centred tokens.
st_H_centered = st_H_flat - mu_st.cpu().numpy()
sc_H_centered = sc_H_flat - mu_sc.cpu().numpy()

cov_st = torch.tensor(
    (st_H_centered.T @ st_H_centered) / (st_H_flat.shape[0] - 1),
    dtype=torch.float32, device=device
)
cov_sc = torch.tensor(
    (sc_H_centered.T @ sc_H_centered) / (sc_H_flat.shape[0] - 1),
    dtype=torch.float32, device=device
)

print(f"✓ mu_st shape: {mu_st.shape}")
print(f"✓ mu_sc shape: {mu_sc.shape}")
print(f"✓ cov_st shape: {cov_st.shape}")
print(f"✓ cov_sc shape: {cov_sc.shape}")

# ===================================================================
# STEP 2: Build CORAL Transform
# ===================================================================
print("\nBuilding CORAL transform...")

def sqrtm_psd(C, eps=1e-5):
    """
    Matrix square root of a symmetric positive semi-definite matrix.

    Uses the eigendecomposition C = V diag(w) V^T and returns
    V diag(sqrt(w)) V^T. Eigenvalues are first clamped to at least ``eps``,
    so the result is the square root of a slightly regularised matrix and is
    always finite.

    Args:
        C: tensor of shape (..., D, D); leading batch dimensions are supported.
        eps: lower bound applied to the eigenvalues.

    Returns:
        Tensor with the same shape as ``C``.
    """
    evals, evecs = torch.linalg.eigh(C)
    evals = torch.clamp(evals, min=eps)
    # Scale column j of V by sqrt(w_j). unsqueeze(-2) broadcasts across rows
    # and .mT transposes only the last two dims, so batched input also works.
    return (evecs * torch.sqrt(evals).unsqueeze(-2)) @ evecs.mT

def invsqrtm_psd(C, eps=1e-5):
    """Return the symmetric inverse square root of a symmetric PSD matrix.

    The matrix is eigendecomposed as ``C = V diag(w) V^T``. Eigenvalues are
    floored at ``eps``, which keeps the inverse finite for singular or
    near-singular input. The result ``V diag(1 / sqrt(w)) V^T`` is returned.

    Args:
        C: square symmetric 2-D tensor.
        eps: floor applied to the eigenvalues before inverting their root.

    Returns:
        Tensor with the same shape as ``C``.
    """
    w, V = torch.linalg.eigh(C)
    inv_root = 1.0 / w.clamp(min=eps).sqrt()
    # Scale column j of V by inv_root[j], then project back with V^T
    return (V * inv_root) @ V.T

# Shrink both covariances toward the identity (fixed identity, not
# trace-scaled) so that the eigen-decompositions below are well conditioned.
shrink = 0.01
D = cov_sc.shape[0]
I = torch.eye(D, device=device, dtype=torch.float32)

cov_sc_shrunk = (1 - shrink) * cov_sc + shrink * I
cov_st_shrunk = (1 - shrink) * cov_st + shrink * I

# Whitening (A) and re-coloring (B) matrices; both are symmetric
A = invsqrtm_psd(cov_sc_shrunk, eps=1e-5)  # C_sc^{-1/2}
B = sqrtm_psd(cov_st_shrunk, eps=1e-5)     # C_st^{1/2}

print(f"‚úì A (C_sc^{{-1/2}}) shape: {A.shape}")
print(f"‚úì B (C_st^{{1/2}}) shape: {B.shape}")

# ===================================================================
# STEP 3: Apply CORAL to SC Contexts
# ===================================================================
print("\nApplying CORAL transformation to SC contexts...")

sc_H_torch = torch.tensor(sc_H, dtype=torch.float32, device=device)

# Row-vector form of x' = C_st^{1/2} C_sc^{-1/2} (x - mu_sc) + mu_st.
# Because A and B are symmetric, the covariance of the result is
# B A C_sc_shrunk A B = C_st_shrunk (up to the eps eigenvalue floor).
sc_H_centered_torch = sc_H_torch - mu_sc
sc_H_transformed_torch = sc_H_centered_torch @ A @ B + mu_st
sc_H_transformed = sc_H_transformed_torch.cpu().numpy()

print(f"‚úì SC contexts transformed: {sc_H_transformed.shape}")

# Verify transformation statistics (same summaries as the hypothesis cell)
print("\n--- Post-CORAL SC Statistics ---")
sc_transformed_mean = sc_H_transformed.mean(axis=0)
sc_transformed_std = sc_H_transformed.std(axis=0)
sc_transformed_norm = np.linalg.norm(sc_H_transformed, axis=1)

print(f"SC (transformed) - Mean of means: {sc_transformed_mean.mean():.6f}")
print(f"SC (transformed) - Mean of stds:  {sc_transformed_std.mean():.6f}")
print(f"SC (transformed) - Mean norm:     {sc_transformed_norm.mean():.6f} ¬± {sc_transformed_norm.std():.6f}")

# st_mean / st_std / st_norm come from the hypothesis cell
print(f"\nST - Mean of means: {st_mean.mean():.6f} (reference)")
print(f"ST - Mean of stds:  {st_std.mean():.6f} (reference)")
print(f"ST - Mean norm:     {st_norm.mean():.6f} ¬± {st_norm.std():.6f} (reference)")

# Cosine similarity of mean vectors (expected to move toward 1.0 after CORAL)
cos_sim_after = np.dot(st_mean, sc_transformed_mean) / (
    np.linalg.norm(st_mean) * np.linalg.norm(sc_transformed_mean)
)
print(f"\n[Post-CORAL mean vector cosine similarity: {cos_sim_after:.4f}]")
print(f"[Pre-CORAL mean vector cosine similarity:  {cos_sim:.4f}]")

# ===================================================================
# STEP 4: Re-run Linear Classifier Test
# ===================================================================
print("\n" + "="*70)
print("RE-RUNNING LINEAR CLASSIFIER TEST (POST-CORAL)")
print("="*70)

# Same layout as the pre-CORAL test: ST rows first (label 0), SC rows (label 1)
X_post = np.vstack([st_H, sc_H_transformed])
y_post = np.array([0]*len(st_H) + [1]*len(sc_H_transformed))

# Same row count, same labels and same random_state as the pre-CORAL split,
# so the train/test indices match and the two AUCs are directly comparable.
X_train_post, X_test_post, y_train_post, y_test_post = train_test_split(
    X_post, y_post, test_size=0.3, random_state=42, stratify=y_post
)

# Standardize (fit on train only)
scaler_post = StandardScaler()
X_train_post_scaled = scaler_post.fit_transform(X_train_post)
X_test_post_scaled = scaler_post.transform(X_test_post)

# Train classifier
clf_post = LogisticRegression(max_iter=1000, random_state=42)
clf_post.fit(X_train_post_scaled, y_train_post)

# Predict: probability of the SC class (label 1)
y_pred_proba_post = clf_post.predict_proba(X_test_post_scaled)[:, 1]
auc_post = roc_auc_score(y_test_post, y_pred_proba_post)
acc_post = clf_post.score(X_test_post_scaled, y_test_post)

print(f"\nPOST-CORAL Classifier Accuracy: {acc_post:.4f}")
print(f"POST-CORAL Classifier AUC:      {auc_post:.4f}")

print("\n" + "="*70)
print("COMPARISON: PRE-CORAL vs POST-CORAL")
print("="*70)
print(f"PRE-CORAL  AUC: {auc:.4f}")
print(f"POST-CORAL AUC: {auc_post:.4f}")
print(f"AUC Reduction:  {(auc - auc_post):.4f} ({((auc - auc_post)/auc)*100:.1f}%)")

# Interpretation: lower AUC means ST and SC contexts are harder to separate
if auc_post < 0.60:
    print("\n‚úì SUCCESS: AUC < 0.60 ‚Üí ST and SC contexts are now indistinguishable!")
    print("  CORAL successfully aligned the conditioning distributions.")
    print("  ‚Üí Use CORAL at inference to fix SC local scrambling.")
elif auc_post < 0.70:
    print("\n‚úì GOOD: AUC 0.60-0.70 ‚Üí Distributions much more similar.")
    print("  CORAL helped but some residual shift remains.")
    print("  ‚Üí Try CORAL + small adapter for best results.")
elif auc_post < 0.85:
    print("\n‚ö†Ô∏è  PARTIAL: AUC 0.70-0.85 ‚Üí Shift reduced but still significant.")
    print("  CORAL alone may not be sufficient.")
    print("  ‚Üí Consider training a small SC‚ÜíST adapter.")
else:
    print("\n‚ö†Ô∏è  NO EFFECT: AUC still >0.85 ‚Üí CORAL did not resolve shift.")
    print("  The shift is not purely covariance-based.")
    print("  ‚Üí Need learned adapter or SC fine-tuning with geometry losses.")

# ===================================================================
# STEP 5: Visualize Post-CORAL Distribution
# ===================================================================
# Set to True to write the figures and the CORAL parameter file. The
# "Saved" messages are printed only when a file is actually written
# (previously they were printed although every save call was commented out).
SAVE_CORAL_OUTPUTS = False

print("\n--- Post-CORAL Visualization ---")

# Project with the PCA fitted on the pre-CORAL data (hypothesis cell) so that
# both panels share the same axes.
X_post_pca = pca.transform(X_post)

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Pre-CORAL
axes[0].scatter(X_pca[:len(st_H), 0], X_pca[:len(st_H), 1],
               alpha=0.5, s=10, label='ST contexts', c='blue')
axes[0].scatter(X_pca[len(st_H):, 0], X_pca[len(st_H):, 1],
               alpha=0.5, s=10, label='SC contexts (original)', c='red')
axes[0].set_xlabel(f'PC1 ({explained_var[0]*100:.1f}%)')
axes[0].set_ylabel(f'PC2 ({explained_var[1]*100:.1f}%)')
axes[0].set_title(f'PRE-CORAL (AUC={auc:.3f})')
axes[0].legend()
axes[0].grid(alpha=0.3)

# Post-CORAL (axis labels show the variance fractions of the pre-CORAL fit)
axes[1].scatter(X_post_pca[:len(st_H), 0], X_post_pca[:len(st_H), 1],
               alpha=0.5, s=10, label='ST contexts', c='blue')
axes[1].scatter(X_post_pca[len(st_H):, 0], X_post_pca[len(st_H):, 1],
               alpha=0.5, s=10, label='SC contexts (CORAL)', c='green')
axes[1].set_xlabel(f'PC1 ({explained_var[0]*100:.1f}%)')
axes[1].set_ylabel(f'PC2 ({explained_var[1]*100:.1f}%)')
axes[1].set_title(f'POST-CORAL (AUC={auc_post:.3f})')
axes[1].legend()
axes[1].grid(alpha=0.3)

plt.tight_layout()
if SAVE_CORAL_OUTPUTS:
    # Save before show(): some backends clear the figure on show()
    plt.savefig('context_shift_before_after_coral.png', dpi=150)
    print("‚úì Saved: context_shift_before_after_coral.png")
plt.show()

# ROC curves comparison (fpr/tpr come from the pre-CORAL hypothesis cell)
fpr_post, tpr_post, _ = roc_curve(y_test_post, y_pred_proba_post)

plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, linewidth=2, label=f'Pre-CORAL (AUC={auc:.3f})', color='red')
plt.plot(fpr_post, tpr_post, linewidth=2, label=f'Post-CORAL (AUC={auc_post:.3f})', color='green')
plt.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curves: ST vs SC Classification')
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
if SAVE_CORAL_OUTPUTS:
    plt.savefig('context_shift_roc_comparison.png', dpi=150)
    print("‚úì Saved: context_shift_roc_comparison.png")
plt.show()

print("\n" + "="*70)
print("CORAL ANALYSIS COMPLETE")
print("="*70)

# ===================================================================
# STEP 6: Save CORAL Transform for Inference
# ===================================================================
print("\nSaving CORAL transform parameters...")

# cov_* are the un-shrunk covariances; A and B were built from the shrunk
# versions, so `shrink` is stored alongside to allow reproducing them.
coral_params = {
    'mu_st': mu_st.cpu(),
    'mu_sc': mu_sc.cpu(),
    'cov_st': cov_st.cpu(),
    'cov_sc': cov_sc.cpu(),
    'A': A.cpu(),
    'B': B.cpu(),
    'shrink': shrink,
}

if SAVE_CORAL_OUTPUTS:
    torch.save(coral_params, 'coral_transform_params.pt')
    print("‚úì Saved: coral_transform_params.pt")
print("\nYou can load this in your inference code and apply:")
print("  sc_ctx_transformed = (sc_ctx - mu_sc) @ A @ B + mu_st")

In [None]:
# ===================================================================
# HYPOTHESIS TEST: ST vs SC Context Distribution Shift
# ===================================================================
# Encodes every ST and SC mini-set with the model's encoder + context encoder
# and collects the per-cell context tokens, to test whether the score network
# is conditioned on differently distributed contexts for ST and SC.
# NOTE(review): st_minisets, sc_minisets, st_gene_expr_dict_cpu,
# sc_gene_expr_cpu and model must come from earlier cells that are not shown
# here. model.eval() is not called in this cell; confirm the model is already
# in eval mode if any submodule uses dropout/normalization layers.
print("\n" + "="*70)
print("TESTING: ST vs SC Context Distribution Shift")
print("="*70)

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.preprocessing import StandardScaler

# Collect context tokens from ST mini-sets
st_contexts = []
with torch.no_grad():
    # Index-based loop: the mini-set containers are only known to support
    # len() and [] here.
    for idx in range(len(st_minisets)):
        miniset = st_minisets[idx]
        n = miniset['n']
        
        # Gene expression of the first n spots of this mini-set's slide
        slide_id = miniset['overlap_info']['slide_id']
        indices = miniset['overlap_info']['indices'][:n]
        gene_expr = st_gene_expr_dict_cpu[slide_id][indices].to(device)
        
        # Encode: gene_expr -> Z embeddings
        Z = model.encoder(gene_expr)  # (n, h_dim)
        
        # Create mask (all ones since no padding)
        mask = torch.ones(n, dtype=torch.bool, device=device)
        
        # Context encoder: Z -> H (context tokens), run as a batch of one set
        Z_batch = Z.unsqueeze(0)  # (1, n, h_dim)
        mask_batch = mask.unsqueeze(0)  # (1, n)
        H = model.context_encoder(Z_batch, mask_batch)  # (1, n, c_dim)
        
        # Store per-cell context tokens on CPU
        st_contexts.append(H.squeeze(0).cpu())  # (n, c_dim)

# Collect context tokens from SC mini-sets (only the "A" part of each set)
sc_contexts = []
with torch.no_grad():
    for idx in range(len(sc_minisets)):
        miniset = sc_minisets[idx]
        n_A = miniset['n_A']
        
        # Gene expression of the first n_A cells, via global cell indices
        indices_A = miniset['global_indices_A'].cpu().numpy()[:n_A]
        gene_expr = sc_gene_expr_cpu[indices_A].to(device)
        
        # Encode: gene_expr -> Z embeddings
        Z = model.encoder(gene_expr)  # (n_A, h_dim)
        
        # Create mask (all ones since no padding)
        mask = torch.ones(n_A, dtype=torch.bool, device=device)
        
        # Context encoder: Z -> H
        Z_batch = Z.unsqueeze(0)  # (1, n_A, h_dim)
        mask_batch = mask.unsqueeze(0)  # (1, n_A)
        H = model.context_encoder(Z_batch, mask_batch)  # (1, n_A, c_dim)
        
        # Store per-cell context tokens on CPU
        sc_contexts.append(H.squeeze(0).cpu())  # (n_A, c_dim)

# Stack all tokens into 2-D NumPy arrays (one row per cell/spot occurrence;
# a cell that appears in several mini-sets contributes several rows).
st_H = torch.cat(st_contexts, dim=0).numpy()  # (N_st, c_dim)
sc_H = torch.cat(sc_contexts, dim=0).numpy()  # (N_sc, c_dim)

print(f"\nST contexts: {st_H.shape}")
print(f"SC contexts: {sc_H.shape}")

# ===================================================================
# 1. COMPUTE STATISTICS
# ===================================================================
print("\n--- Context Token Statistics ---")

# Per-dimension mean/std over all tokens, and per-token L2 norms
st_mean = st_H.mean(axis=0)
st_std = st_H.std(axis=0)
st_norm = np.linalg.norm(st_H, axis=1)

sc_mean = sc_H.mean(axis=0)
sc_std = sc_H.std(axis=0)
sc_norm = np.linalg.norm(sc_H, axis=1)

print(f"\nST - Mean of means: {st_mean.mean():.6f}")
print(f"ST - Mean of stds:  {st_std.mean():.6f}")
print(f"ST - Mean norm:     {st_norm.mean():.6f} ¬± {st_norm.std():.6f}")

print(f"\nSC - Mean of means: {sc_mean.mean():.6f}")
print(f"SC - Mean of stds:  {sc_std.mean():.6f}")
print(f"SC - Mean norm:     {sc_norm.mean():.6f} ¬± {sc_norm.std():.6f}")

# Relative difference of mean token norms, w.r.t. the ST mean norm
norm_diff = np.abs(st_norm.mean() - sc_norm.mean()) / st_norm.mean()
print(f"\n[Norm difference: {norm_diff*100:.2f}%]")

# Cosine similarity between the ST and SC mean vectors
cos_sim = np.dot(st_mean, sc_mean) / (np.linalg.norm(st_mean) * np.linalg.norm(sc_mean))
print(f"[Mean vector cosine similarity: {cos_sim:.4f}]")

# ===================================================================
# 2. LINEAR CLASSIFIER TEST
# ===================================================================
# If a linear model can tell ST tokens from SC tokens (high AUC), the two
# context distributions differ.
print("\n--- Linear Classifier Test (ST vs SC) ---")

# Prepare data
X = np.vstack([st_H, sc_H])
y = np.array([0]*len(st_H) + [1]*len(sc_H))  # 0=ST, 1=SC

# Stratified, seeded 70/30 split (train_test_split shuffles by default)
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

# Standardize (fit on train only)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Train logistic regression
clf = LogisticRegression(max_iter=1000, random_state=42)
clf.fit(X_train_scaled, y_train)

# Predict: probability of the SC class (label 1)
y_pred_proba = clf.predict_proba(X_test_scaled)[:, 1]
auc = roc_auc_score(y_test, y_pred_proba)
acc = clf.score(X_test_scaled, y_test)

print(f"\nClassifier Accuracy: {acc:.4f}")
print(f"Classifier AUC:      {auc:.4f}")

# Interpretation thresholds on AUC: >0.85 / >0.70 / >0.60 / otherwise
if auc > 0.85:
    print("\n‚ö†Ô∏è  HIGH AUC (>0.85): Strong distribution shift detected!")
    print("   The score net sees ST vs SC contexts as very different.")
    print("   This supports the 'conditioning shift' hypothesis.")
elif auc > 0.70:
    print("\n‚ö†Ô∏è  MODERATE AUC (0.70-0.85): Moderate distribution shift.")
    print("   ST and SC contexts are distinguishable but not wildly different.")
elif auc > 0.60:
    print("\n‚úì  LOW AUC (0.60-0.70): Mild distribution shift.")
    print("   Contexts are similar. Shift may not be the main issue.")
else:
    print("\n‚úì  VERY LOW AUC (<0.60): Contexts are indistinguishable.")
    print("   Conditioning shift is NOT the problem.")

# ===================================================================
# 3. VISUALIZE (Optional)
# ===================================================================
print("\n--- Visualization ---")

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

# 2-component PCA on the raw (unstandardized) stacked tokens. `pca`,
# `X_pca` and `explained_var` are reused by the CORAL cell.
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X)
explained_var = pca.explained_variance_ratio_

# First len(st_H) rows of X are ST, the rest are SC (see np.vstack above)
plt.figure(figsize=(8, 6))
plt.scatter(X_pca[:len(st_H), 0], X_pca[:len(st_H), 1], 
           alpha=0.5, s=10, label='ST contexts', c='blue')
plt.scatter(X_pca[len(st_H):, 0], X_pca[len(st_H):, 1], 
           alpha=0.5, s=10, label='SC contexts', c='red')
plt.xlabel(f'PC1 ({explained_var[0]*100:.1f}%)')
plt.ylabel(f'PC2 ({explained_var[1]*100:.1f}%)')
plt.title('Context Token Distribution: ST vs SC')
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig('context_shift_pca.png', dpi=150)
print("‚úì Saved: context_shift_pca.png")
plt.close()

# ROC curve on the held-out split; fpr/tpr are reused by the CORAL cell
fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
plt.figure(figsize=(6, 6))
plt.plot(fpr, tpr, linewidth=2, label=f'AUC = {auc:.3f}')
plt.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC: ST vs SC Context Classification')
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig('context_shift_roc.png', dpi=150)
print("‚úì Saved: context_shift_roc.png")
plt.close()

print("\n" + "="*70)
print("ANALYSIS COMPLETE")
print("="*70)

In [None]:
# ===================================================================
# PATCHWISE INFERENCE WITH GLOBAL SCALE
# ===================================================================

import torch
import numpy as np
from datetime import datetime
import sys
import os
import scanpy as sc
from scipy.spatial.distance import pdist, squareform
from scipy.stats import pearsonr, spearmanr
import matplotlib.pyplot as plt
import seaborn as sns

sys.path.insert(0, '/home/ehtesamul/sc_st/model')
from core_models_et_p3 import GEMSModel
import utils_et as uet

# ===================================================================
# PATHS AND CONFIG
# ===================================================================
output_dir = "/home/ehtesamul/sc_st/model/gems_mousebrain_output"
# Alternative: the SC-finetuned (phase 2) checkpoint
# checkpoint_path = f"{output_dir}/phase2_sc_finetuned_checkpoint.pt"
checkpoint_path = f"{output_dir}/phase1_st_checkpoint.pt"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

print("="*70)
print("LOADING DATA AND MODEL")
print("="*70)

# Load data
adata_sc = sc.read_h5ad(f"{output_dir}/scadata_with_gems_20260112_023604.h5ad")

# Dense float32 SC expression (cells x genes), kept on CPU. `.raw` is preferred
# when present.
# NOTE(review): adata_sc.raw can carry a different gene list than
# adata_sc.var_names; any gene alignment against ST must use the matrix used here.
if hasattr(adata_sc, 'raw') and adata_sc.raw is not None:
    sc_expr = torch.tensor(adata_sc.raw.X.toarray() if hasattr(adata_sc.raw.X, 'toarray') else adata_sc.raw.X, dtype=torch.float32)
else:
    sc_expr = torch.tensor(adata_sc.X.toarray() if hasattr(adata_sc.X, 'toarray') else adata_sc.X, dtype=torch.float32)

# ===================================================================
# LOAD AND NORMALIZE GT COORDS
# ===================================================================

# Load raw GT coords (numpy)
gt_coords_raw = adata_sc.obsm['spatial_gt']
# n_genes (the width of sc_expr) is also used to size the model below
n_cells, n_genes = sc_expr.shape
print(f"‚úì Loaded SC data: {n_cells} cells √ó {n_genes} genes")
print(f"‚úì Ground truth coords (raw): {gt_coords_raw.shape}")

# Convert to tensor for normalization; all cells are treated as one slide (id 0)
gt_coords_tensor = torch.tensor(gt_coords_raw, dtype=torch.float32, device=device)
slide_ids = torch.zeros(gt_coords_tensor.shape[0], dtype=torch.long, device=device)

# Canonicalize with the same helper used for ST coordinates in training.
# Returns normalized coords plus per-slide center and scale.
gt_coords_norm, gt_mu, gt_scale = uet.canonicalize_st_coords_per_slide(
    gt_coords_tensor, slide_ids
)

# Keep BOTH versions:
# - gt_coords_norm: tensor on device (for torch operations, passing to inference)
# - gt_coords_np: numpy on CPU (for scipy operations like pdist)
gt_coords_np = gt_coords_norm.cpu().numpy()

print(f"‚úì GT coords normalized: scale={gt_scale[0].item():.4f}")
print(f"‚úì GT coords RMS: {gt_coords_norm.pow(2).mean().sqrt().item():.4f}")
print(f"‚úì GT coords range: X=[{gt_coords_np[:,0].min():.3f}, {gt_coords_np[:,0].max():.3f}], "
      f"Y=[{gt_coords_np[:,1].min():.3f}, {gt_coords_np[:,1].max():.3f}]")


# Restore the saved training state. weights_only=False allows arbitrary
# pickled objects, so only point this at trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
print(f"‚úì Loaded checkpoint")

# Architecture hyper-parameters (from run_mouse_brain_2.py); the gene count
# must equal the width of sc_expr.
gems_config = dict(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    D_latent=32,
    c_dim=256,
    n_heads=4,
    isab_m=64,
    dist_bins=24,
    device=device,
)
model = GEMSModel(**gems_config)

# Each sub-module's weights are stored in the checkpoint under its own name.
_gems_components = ('encoder', 'context_encoder', 'generator', 'score_net')
for _component in _gems_components:
    getattr(model, _component).load_state_dict(checkpoint[_component])
for _component in _gems_components:
    getattr(model, _component).eval()
print(f"‚úì Model loaded and set to eval mode")

# ===================================================================
# COMPUTE CORAL TRANSFORMATION
# ===================================================================
print("\n" + "="*70)
print("COMPUTING CORAL TRANSFORMATION")
print("="*70)

# 1. Load ST data (you need this to compute ST context distribution)
import pandas as pd
import anndata as ad

st_counts = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_counts_et.csv'
st_meta = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_metadata_et.csv'

# The counts CSV is indexed by gene, so transpose to spots x genes.
st_expr_df = pd.read_csv(st_counts, index_col=0)
st_meta_df = pd.read_csv(st_meta, index_col=0)
stadata = ad.AnnData(X=st_expr_df.values.T)
stadata.obs_names = st_expr_df.columns
stadata.var_names = st_expr_df.index
stadata.obsm['spatial'] = st_meta_df[['coord_x', 'coord_y']].values

# Align ST genes to the exact columns of `sc_expr`. The encoder was built with
# n_genes = sc_expr.shape[1], so ST input must contain exactly those genes in
# exactly that order. `sc_expr` is read from `adata_sc.raw` when present, and
# raw can hold a different gene list than `adata_sc.var_names`. Intersecting
# with var_names and sorting could therefore yield a different width or
# column order than the encoder expects.
if getattr(adata_sc, 'raw', None) is not None:
    sc_gene_names = list(adata_sc.raw.var_names)
else:
    sc_gene_names = list(adata_sc.var_names)

st_gene_set = set(stadata.var_names)
missing_in_st = [g for g in sc_gene_names if g not in st_gene_set]
if missing_in_st:
    raise ValueError(
        f"{len(missing_in_st)} of {len(sc_gene_names)} SC genes are missing from "
        f"the ST data (e.g. {missing_in_st[:5]}); ST expression must cover every "
        f"gene the encoder was built on."
    )

# Kept under the original name; now in SC column order
common = sc_gene_names
X_st = stadata[:, common].X
if hasattr(X_st, "toarray"):
    X_st = X_st.toarray()

st_expr = torch.tensor(X_st, dtype=torch.float32, device=device)
st_coords_raw = torch.tensor(stadata.obsm['spatial'], dtype=torch.float32, device=device)
# Single ST slide (id 0), canonicalized with the same helper as the GT coords
slide_ids = torch.zeros(st_expr.shape[0], dtype=torch.long, device=device)
st_coords, _, _ = uet.canonicalize_st_coords_per_slide(st_coords_raw, slide_ids)

print(f"‚úì Loaded ST data: {st_expr.shape[0]} spots √ó {st_expr.shape[1]} genes")

# 2. Prepare ST gene expression dict: keyed by slide id (single slide 0,
# matching `slide_ids` above); values are CPU tensors.
st_gene_expr_dict = {0: st_expr.cpu()}

# 3. Load Stage B targets (needed for compute_coral_params_from_st)
# This should be in your checkpoint or output directory
targets_path = f"{output_dir}/stageB_targets/targets_dict.pt"
if os.path.exists(targets_path):
    # Pass weights_only=False explicitly, as the checkpoint load does. Since
    # PyTorch 2.6, torch.load defaults to weights_only=True, which refuses
    # pickled non-tensor objects. This file comes from our own Stage B run;
    # never load untrusted files this way.
    model.targets_dict = torch.load(targets_path, map_location='cpu', weights_only=False)
    print(f"‚úì Loaded targets_dict from {targets_path}")
else:
    print(f"‚ö†Ô∏è  Targets not found at {targets_path}")
    print("   Running Stage B precomputation...")
    # NOTE(review): presumably train_stageB populates model.targets_dict and
    # writes its output under `outdir`; confirm in GEMSModel.
    slides_dict = {0: (st_coords, st_expr)}
    model.train_stageB(slides=slides_dict, outdir=f"{output_dir}/stageB_targets")
    print("‚úì Stage B complete")

# 4. Compute CORAL parameters
# First fit the ST-side context statistics from sampled ST sets (set sizes in
# [n_min, n_max]), then build the SC->ST transform from sampled SC sets.
# shrink/eps match the values used in the exploratory CORAL cell.
print("\n--- Computing ST context distribution ---")
model.compute_coral_params_from_st(
    st_gene_expr_dict=st_gene_expr_dict,
    n_samples=2000,
    n_min=96,
    n_max=384,
)

print("\n--- Building CORAL transformation ---")
model.build_coral_transform(
    sc_gene_expr=sc_expr,
    n_samples=2000,
    n_min=96,
    n_max=384,
    shrink=0.01,
    eps=1e-5,
)

print("‚úì CORAL transformation ready!")

print("\n" + "="*70)

# ===================================================================
# RUN INFERENCE (NEW GLOBAL SCALE VERSION)
# ===================================================================
print("\n" + "="*70)
print("RUNNING INFERENCE (GLOBAL SCALE)")
print("="*70)

# Copy the sigma_* values stored with the checkpoint (when present) onto the
# model before sampling. They are presumably the diffusion noise-schedule
# settings used in training; confirm in GEMSModel.
if 'sigma_data' in checkpoint:
    model.sigma_data = checkpoint['sigma_data']
if 'sigma_min' in checkpoint:
    model.sigma_min = checkpoint['sigma_min']
if 'sigma_max' in checkpoint:
    model.sigma_max = checkpoint['sigma_max']


with torch.no_grad():
    # NOTE(review): gt_coords is passed alongside debug_knn=True; presumably it
    # is used only for the debug kNN diagnostics, not for prediction. Confirm.
    results = model.infer_sc_patchwise(
        sc_gene_expr=sc_expr,
        n_timesteps_sample=600,
        return_coords=True,
        patch_size=224,
        coverage_per_cell=8.0,
        n_align_iters=10,
        eta=0.0,
        guidance_scale=2.0,
        gt_coords=gt_coords_norm,
        debug_knn=True,
        debug_max_patches=15,
        debug_k_list=(10, 20),
        pool_mult=2.0,
        stochastic_tau=1.0,
        tau_mode="adaptive_kth",
        ensure_connected=True,
        local_refine=False,
        anchor_sampling_mode="seq_align_only",  # anchor sampling strategy (newer option)
        commit_frac=0.75,
        seq_align_dim=32,
    )


# Move predictions to NumPy for the metric/plotting cells
D_edm_pred = results['D_edm'].cpu().numpy()
coords_pred = results['coords_canon'].cpu().numpy()
print(f"\n‚úì Inference complete!")
print(f"  Predicted EDM: {D_edm_pred.shape}")
print(f"  Predicted coords: {coords_pred.shape}")

In [None]:
# ===================================================================
# COMPUTE GROUND TRUTH EDM
# ===================================================================
print("\n" + "="*70)
print("COMPUTING METRICS")
print("="*70)

# Full pairwise Euclidean distance matrix of the normalized GT coords
gt_edm = squareform(pdist(gt_coords_np, 'euclidean'))
print(f"‚úì Ground truth EDM: {gt_edm.shape}")

# Upper triangle (excluding the diagonal): each unordered pair once.
# Assumes D_edm_pred rows/cols follow the same cell order as gt_coords_np.
triu_indices = np.triu_indices(n_cells, k=1)
gt_distances = gt_edm[triu_indices]
pred_distances = D_edm_pred[triu_indices]

# Match median distances. Pearson and Spearman are invariant to this positive
# rescaling, so it only affects the plots below.
# NOTE(review): a zero median predicted distance would make `scale` inf.
scale = np.median(gt_distances) / np.median(pred_distances)
pred_distances_scaled = pred_distances * scale

# Correlations
pearson_corr, _ = pearsonr(gt_distances, pred_distances_scaled)
spearman_corr, _ = spearmanr(gt_distances, pred_distances_scaled)

print(f"\nPearson Correlation:  {pearson_corr:.4f}")
print(f"Spearman Correlation: {spearman_corr:.4f}")
print(f"Scale factor: {scale:.4f}")

# ===================================================================
# VISUALIZATIONS
# ===================================================================
print("\n" + "="*70)
print("GENERATING PLOTS")
print("="*70)

# Plot 1: Coordinate comparison. No rotation/reflection alignment is applied
# here, so the predicted layout may appear rotated or mirrored relative to GT.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

axes[0].scatter(gt_coords_np[:, 0], gt_coords_np[:, 1], s=5, alpha=0.6, c='blue')
axes[0].set_title('Ground Truth Coordinates', fontsize=14, weight='bold')
axes[0].set_aspect('equal')

axes[1].scatter(coords_pred[:, 0], coords_pred[:, 1], s=5, alpha=0.6, c='red')
axes[1].set_title('Predicted Coordinates (Global Scale)', fontsize=14, weight='bold')
axes[1].set_aspect('equal')

plt.tight_layout()
plt.show()

# Plot 2: Distance scatter - Pearson and Spearman
# Subsample pairs for plotting. Clamp to the number of available pairs so that
# small datasets (fewer than 50k pairs, i.e. about 316 cells or fewer) do not
# make np.random.choice(..., replace=False) raise ValueError.
sample_size = min(50000, len(gt_distances))
sample_idx = np.random.choice(len(gt_distances), sample_size, replace=False)

fig, axes = plt.subplots(1, 2, figsize=(16, 8))

# Pearson correlation plot
ax = axes[0]
ax.scatter(gt_distances[sample_idx], pred_distances_scaled[sample_idx], alpha=0.2, s=5)
ax.set_title(f'Distance Correlation (Pearson)\nœÅ = {pearson_corr:.4f}', fontsize=16, weight='bold')
ax.set_xlabel('Ground Truth Distance', fontsize=12)
ax.set_ylabel('Predicted Distance (scaled)', fontsize=12)
ax.grid(True, linestyle='--', alpha=0.5)

# y = x reference line spanning both axes' current limits
lims = [min(ax.get_xlim()[0], ax.get_ylim()[0]), max(ax.get_xlim()[1], ax.get_ylim()[1])]
ax.plot(lims, lims, 'r--', alpha=0.75, label='Ideal')
ax.set_aspect('equal')
ax.legend()

# Spearman correlation plot
# NOTE(review): this panel plots the same raw (scaled) distances as the
# Pearson panel; only the title/statistic differs. A rank-vs-rank scatter
# would visualize the Spearman statistic more directly.
ax = axes[1]
ax.scatter(gt_distances[sample_idx], pred_distances_scaled[sample_idx], alpha=0.2, s=5)
ax.set_title(f'Distance Correlation (Spearman)\nœÅ = {spearman_corr:.4f}', fontsize=16, weight='bold')
ax.set_xlabel('Ground Truth Distance', fontsize=12)
ax.set_ylabel('Predicted Distance (scaled)', fontsize=12)
ax.grid(True, linestyle='--', alpha=0.5)

lims = [min(ax.get_xlim()[0], ax.get_ylim()[0]), max(ax.get_xlim()[1], ax.get_ylim()[1])]
ax.plot(lims, lims, 'r--', alpha=0.75, label='Ideal')
ax.set_aspect('equal')
ax.legend()

plt.tight_layout()
plt.show()

# Plot 3: Distance distributions (density-normalized histograms, overlaid)
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(gt_distances, color='blue', label='Ground Truth', stat='density', bins=100, alpha=0.5, ax=ax)
sns.histplot(pred_distances_scaled, color='red', label='Predicted (scaled)', stat='density', bins=100, alpha=0.5, ax=ax)
ax.set_title('Distance Distribution Comparison', fontsize=16, weight='bold')
ax.set_xlabel('Distance')
ax.legend()
plt.show()

print("\n" + "="*70)
print("COMPLETE!")
print("="*70)


# ===================================================================
# EDM HEATMAP COMPARISON
# ===================================================================
print("\n" + "="*70)
print("EDM HEATMAP VISUALIZATION")
print("="*70)

def normalize_matrix(matrix):
    """Min-max scale ``matrix`` into the range [0, 1].

    Args:
        matrix: numeric array-like of any shape (converted with ``np.asarray``).

    Returns:
        Array of the same shape in which the minimum maps to 0 and the maximum
        to 1. When every entry is equal (zero range), an all-zeros float array
        is returned instead of dividing by zero, which previously produced NaNs.
    """
    matrix = np.asarray(matrix)
    min_val = matrix.min()
    value_range = matrix.max() - min_val
    if value_range == 0:
        return np.zeros_like(matrix, dtype=float)
    return (matrix - min_val) / value_range

# Normalize EDMs: each matrix is scaled by its own min/max, so the heatmaps
# compare relative structure, not absolute distances.
gt_edm_norm = normalize_matrix(gt_edm)
pred_edm_norm = normalize_matrix(D_edm_pred)

# Sample cells for visualization (this reassigns `sample_size`). The indices
# are sorted so both heatmaps keep the original cell order, and the same
# subset is used for both.
sample_size = min(838, n_cells)
sample_indices = np.random.choice(n_cells, sample_size, replace=False)
sample_indices = np.sort(sample_indices)

print(f"\nCreating EDM heatmaps with {sample_size} sampled cells...")

# Create side-by-side heatmaps
fig, axes = plt.subplots(1, 2, figsize=(16, 7))
fig.suptitle('EDM Comparison: Ground Truth vs. Predicted', fontsize=18, fontweight='bold')

# Ground Truth EDM (np.ix_ selects the sampled rows x sampled columns)
im1 = axes[0].imshow(gt_edm_norm[np.ix_(sample_indices, sample_indices)], cmap='viridis')
axes[0].set_title('Ground Truth EDM (Normalized)', fontsize=14)
axes[0].set_xlabel('Cell Index (Sampled)', fontsize=12)
axes[0].set_ylabel('Cell Index (Sampled)', fontsize=12)
fig.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)

# Predicted EDM
im2 = axes[1].imshow(pred_edm_norm[np.ix_(sample_indices, sample_indices)], cmap='viridis')
axes[1].set_title('Predicted EDM (Normalized)', fontsize=14)
axes[1].set_xlabel('Cell Index (Sampled)', fontsize=12)
fig.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)

# Leave room at the top for the suptitle
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

print("\n" + "="*70)
print("EDM HEATMAP VISUALIZATION COMPLETE")
print("="*70)

# ===================================================================
# ANISOTROPY ANALYSIS
# ===================================================================
print("\n" + "="*70)
print("EIGENVALUE ANISOTROPY ANALYSIS")
print("="*70)

def compute_anisotropy(coords):
    """Compute the eigenvalue anisotropy ratio λ1/λ2 of a point cloud.

    The sample covariance of the centred coordinates is eigendecomposed; λ1
    and λ2 are its two largest eigenvalues (variance along the principal and
    second axes).

    Args:
        coords: (n, d) array-like of coordinates with n >= 2 and d >= 2
            (lists and NumPy arrays are both accepted).

    Returns:
        (lam1, lam2, ratio) where ratio = lam1 / (lam2 + 1e-12); the epsilon
        keeps the ratio finite for perfectly collinear points.
    """
    X = np.asarray(coords, dtype=float)
    Xc = X - X.mean(axis=0, keepdims=True)

    # Unbiased sample covariance (divides by n - 1).
    cov = Xc.T @ Xc / (Xc.shape[0] - 1)
    # eigvalsh returns ascending eigenvalues of a symmetric matrix; reverse so
    # the largest comes first. Eigenvectors are not needed here.
    eigvals = np.linalg.eigvalsh(cov)[::-1]

    lam1, lam2 = eigvals[0], eigvals[1]
    ratio = lam1 / (lam2 + 1e-12)

    return lam1, lam2, ratio

def _anisotropy_verdict(ratio):
    """Return the verdict line for a λ₁/λ₂ ratio (< 5: 2D, < 20: 2D-ish, else 1D)."""
    if ratio < 5:
        return "  → GENUINELY 2D ✓"
    if ratio < 20:
        return "  → Anisotropic but still 2D-ish"
    return "  → EFFECTIVELY 1D (very elongated) ✗"


# Compute anisotropy for ground truth
lam1_gt, lam2_gt, ratio_gt = compute_anisotropy(gt_coords_np)
print("\nGround Truth Coordinates:")
print(f"  λ₁ = {lam1_gt:.4f},  λ₂ = {lam2_gt:.4f}")
print(f"  λ₁/λ₂ = {ratio_gt:.2f}")
print(_anisotropy_verdict(ratio_gt))

# Compute anisotropy for predicted
lam1_pred, lam2_pred, ratio_pred = compute_anisotropy(coords_pred)
print("\nPredicted Coordinates (Global Scale):")
print(f"  λ₁ = {lam1_pred:.4f},  λ₂ = {lam2_pred:.4f}")
print(f"  λ₁/λ₂ = {ratio_pred:.2f}")
print(_anisotropy_verdict(ratio_pred))

# ===================================================================
# ANISOTROPY VISUALIZATIONS
# ===================================================================
print("\n=== Creating Anisotropy Plots ===\n")

fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Plot 1: Bar comparison of λ₁/λ₂ against the 2D (5) and 1D (20) thresholds
ax = axes[0, 0]
methods = ['Ground Truth', 'Predicted']
ratios = [ratio_gt, ratio_pred]
colors = ['blue', 'red']

bars = ax.bar(methods, ratios, color=colors, alpha=0.7, edgecolor='black', linewidth=2)
ax.axhline(5, color='g', linestyle='--', linewidth=2, alpha=0.5, label='2D threshold (5)')
ax.axhline(20, color='orange', linestyle='--', linewidth=2, alpha=0.5, label='1D threshold (20)')

ax.set_ylabel('$\\lambda_1/\\lambda_2$ (Anisotropy Ratio)', fontsize=13, fontweight='bold')
ax.set_title('Anisotropy Comparison: Ground Truth vs Predicted', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3, axis='y')

# Annotate each bar with its ratio value
for bar, ratio in zip(bars, ratios):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height,
            f'{ratio:.2f}',
            ha='center', va='bottom', fontsize=12, fontweight='bold')

# Plot 2: Eigenvalue scatter (λ₂ on x, λ₁ on y); the dashed diagonal marks
# isotropy, and points far above it are strongly elongated
ax = axes[0, 1]
ax.scatter(lam2_gt, lam1_gt, c='blue', s=300, marker='o', 
          edgecolors='darkblue', linewidth=2, label='Ground Truth', zorder=5)
ax.scatter(lam2_pred, lam1_pred, c='red', s=300, marker='*', 
          edgecolors='darkred', linewidth=2, label='Predicted', zorder=5)

min_val = min(lam2_gt, lam2_pred)
max_val = max(lam1_gt, lam1_pred)
ax.plot([min_val, max_val], [min_val, max_val], 'k--', 
        linewidth=2, label='$\\lambda_1 = \\lambda_2$', alpha=0.7)

ax.set_xlabel('$\\lambda_2$ (Smaller Eigenvalue)', fontsize=13, fontweight='bold')
ax.set_ylabel('$\\lambda_1$ (Larger Eigenvalue)', fontsize=13, fontweight='bold')
ax.set_title('Eigenvalue Comparison', fontsize=14, fontweight='bold')
ax.legend(fontsize=11, loc='upper left')
ax.grid(True, alpha=0.3)

# Plot 3: Ground truth coordinates with anisotropy
ax = axes[1, 0]
ax.scatter(gt_coords_np[:, 0], gt_coords_np[:, 1], alpha=0.5, s=10, 
          c='blue', edgecolors='none')
ax.set_xlabel('Dimension 1', fontsize=13, fontweight='bold')
ax.set_ylabel('Dimension 2', fontsize=13, fontweight='bold')
ax.set_title(f'Ground Truth\n$\\lambda_1/\\lambda_2$ = {ratio_gt:.2f}', 
             fontsize=14, fontweight='bold', color='blue')
ax.grid(True, alpha=0.3)
ax.set_aspect('equal', adjustable='box')

# Plot 4: Predicted coordinates with anisotropy
ax = axes[1, 1]
ax.scatter(coords_pred[:, 0], coords_pred[:, 1], alpha=0.5, s=10, 
          c='red', edgecolors='none')
ax.set_xlabel('Dimension 1', fontsize=13, fontweight='bold')
ax.set_ylabel('Dimension 2', fontsize=13, fontweight='bold')
ax.set_title(f'Predicted (Global Scale)\n$\\lambda_1/\\lambda_2$ = {ratio_pred:.2f}', 
             fontsize=14, fontweight='bold', color='red')
ax.grid(True, alpha=0.3)
ax.set_aspect('equal', adjustable='box')

plt.tight_layout()
# NOTE(review): datetime, os and output_dir are defined by a later cell of this
# notebook, which must have been run first.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
anisotropy_path = os.path.join(output_dir, f'patchwise_anisotropy_{timestamp}.png')
# Saving is off by default; set this flag to True to write the PNG.
save_anisotropy_fig = False
if save_anisotropy_fig:
    plt.savefig(anisotropy_path, dpi=300, bbox_inches='tight')
    print(f"✓ Saved anisotropy plot: {anisotropy_path}")
else:
    print(f"(Anisotropy plot not saved; set save_anisotropy_fig=True to write {anisotropy_path})")
plt.show()

# ===================================================================
# DETAILED COMPARISON TABLE
# ===================================================================
print("\n" + "="*70)
print("ANISOTROPY COMPARISON TABLE")
print("="*70)

ratio_diff = abs(ratio_gt - ratio_pred)

# Fixed-width columns so the values line up under their headers.
print(f"\n{'Metric':<35} {'Ground Truth':<20} {'Predicted':<20}")
print("-" * 70)
print(f"{'λ₁ (larger eigenvalue)':<35} {lam1_gt:<20.4f} {lam1_pred:<20.4f}")
print(f"{'λ₂ (smaller eigenvalue)':<35} {lam2_gt:<20.4f} {lam2_pred:<20.4f}")
print(f"{'λ₁/λ₂ ratio':<35} {ratio_gt:<20.2f} {ratio_pred:<20.2f}")
print(f"{'Difference in λ₁/λ₂':<35} {ratio_diff:.2f}")

print("\n" + "="*70)
print("INTERPRETATION")
print("="*70)

# Absolute gap between the two ratios: < 2 excellent, < 5 good, else warn.
if ratio_diff < 2:
    print(f"\n✓ EXCELLENT: Anisotropy closely matches ground truth (diff = {ratio_diff:.2f})")
elif ratio_diff < 5:
    print(f"\n✓ GOOD: Moderate anisotropy difference (diff = {ratio_diff:.2f})")
else:
    print(f"\n⚠ WARNING: Large anisotropy difference (diff = {ratio_diff:.2f})")

# A ratio below 5 is treated as genuinely 2D (same threshold as above).
if ratio_pred < 5 and ratio_gt < 5:
    print("\n✓ Both ground truth and predicted preserve 2D geometry")
elif ratio_pred < 5:
    print("\n✓ Predicted successfully preserves 2D geometry")
    print(f"  Ground truth is more anisotropic (λ₁/λ₂ = {ratio_gt:.2f})")
elif ratio_gt < 5:
    print("\n⚠ Ground truth is 2D but predicted shows elongation")
    print(f"  Predicted λ₁/λ₂ = {ratio_pred:.2f}")
else:
    print("\n⚠ Both show anisotropic structure")

print("\n" + "="*70)
print("ANISOTROPY ANALYSIS COMPLETE")
print("="*70)

# ===================================================================
# CELL TYPE VISUALIZATION (GROUND TRUTH vs PREDICTED)
# ===================================================================
print("\n" + "="*70)
print("CELL TYPE VISUALIZATION")
print("="*70)

# Check what cell type column exists
print("\nAvailable columns in adata_sc.obs:")
print(list(adata_sc.obs.columns))

# Find cell type column: first match among common annotation column names
cell_type_col = None
for col in ['cell_type', 'celltype', 'cluster', 'annotation', 'cell_ontology_class']:
    if col in adata_sc.obs.columns:
        cell_type_col = col
        break

# Fallback: first categorical/object column, or a constant 'Unknown' label
if cell_type_col is None:
    print("\nWARNING: No cell type column found. Using first categorical column or creating dummy labels.")
    categorical_cols = adata_sc.obs.select_dtypes(include=['category', 'object']).columns
    if len(categorical_cols) > 0:
        cell_type_col = categorical_cols[0]
    else:
        adata_sc.obs['cell_type'] = 'Unknown'
        cell_type_col = 'cell_type'

print(f"\nUsing cell type column: '{cell_type_col}'")
cell_types = adata_sc.obs[cell_type_col].values

# Get unique cell types
unique_types = np.unique(cell_types)
n_types = len(unique_types)

print(f"Found {n_types} unique cell types:")
for i, ct in enumerate(unique_types):
    count = (cell_types == ct).sum()
    print(f"  {i+1}. {ct}: {count} cells")

# Qualitative palettes for small label sets, a continuous one otherwise
if n_types <= 10:
    cmap = plt.cm.tab10
elif n_types <= 20:
    cmap = plt.cm.tab20
else:
    cmap = plt.cm.gist_ncar

# Map cell types to colors
type_to_color = {ct: cmap(i / n_types) for i, ct in enumerate(unique_types)}

# Create side-by-side plots
fig, axes = plt.subplots(1, 2, figsize=(20, 8))

# Plot 1: Ground Truth
ax = axes[0]
for ct in unique_types:
    mask = cell_types == ct
    ax.scatter(gt_coords_np[mask, 0], gt_coords_np[mask, 1], 
              c=[type_to_color[ct]], label=ct, s=15, alpha=0.7, edgecolors='none')

ax.set_xlabel('Dimension 1', fontsize=13, fontweight='bold')
ax.set_ylabel('Dimension 2', fontsize=13, fontweight='bold')
ax.set_title('Ground Truth - Cell Types', fontsize=14, fontweight='bold')
ax.set_aspect('equal', adjustable='box')
ax.grid(True, alpha=0.3)

# Plot 2: Predicted (same cell order and colors as the ground-truth panel)
ax = axes[1]
for ct in unique_types:
    mask = cell_types == ct
    ax.scatter(coords_pred[mask, 0], coords_pred[mask, 1], 
              c=[type_to_color[ct]], label=ct, s=15, alpha=0.7, edgecolors='none')

ax.set_xlabel('Dimension 1', fontsize=13, fontweight='bold')
ax.set_ylabel('Dimension 2', fontsize=13, fontweight='bold')
ax.set_title('Predicted (Global Scale) - Cell Types', fontsize=14, fontweight='bold')
ax.set_aspect('equal', adjustable='box')
ax.grid(True, alpha=0.3)

# Add legend outside the plot (skipped for many types to keep it readable)
if n_types <= 15:
    handles, labels = axes[1].get_legend_handles_labels()
    fig.legend(handles, labels, loc='center left', bbox_to_anchor=(1.0, 0.5), 
              fontsize=10, title='Cell Type', title_fontsize=12, frameon=True)

plt.tight_layout(rect=[0, 0, 0.95, 1])
# NOTE(review): relies on os, output_dir and timestamp from other cells.
celltype_path = os.path.join(output_dir, f'patchwise_celltype_{timestamp}.png')
# Saving is off by default; set this flag to True to write the PNG.
save_celltype_fig = False
if save_celltype_fig:
    plt.savefig(celltype_path, dpi=300, bbox_inches='tight')
    print(f"\n✓ Saved cell type visualization: {celltype_path}")
else:
    print(f"\n(Cell type figure not saved; set save_celltype_fig=True to write {celltype_path})")
plt.show()

print("\n" + "="*70)
print("CELL TYPE VISUALIZATION COMPLETE")
print("="*70)

In [None]:
# ===================================================================
# k-NN PRESERVATION ANALYSIS
# ===================================================================
# Compares each cell's nearest neighbours in the ground-truth coordinates
# (gt_coords_np) with its nearest neighbours in the predicted coordinates
# (coords_pred).

from sklearn.neighbors import NearestNeighbors

print("\n" + "="*70)
print("k-NN PRESERVATION ANALYSIS")
print("="*70)

def compute_knn_preservation(coords_gt, coords_pred, k=10):
    """
    Compute k-nearest neighbor preservation.

    For each point, count how many of its k nearest neighbours in
    ``coords_gt`` are also among its k nearest neighbours in ``coords_pred``.

    Args:
        coords_gt: (n, d) ground-truth coordinates.
        coords_pred: predicted coordinates with rows aligned to coords_gt.
        k: Neighbourhood size (must be smaller than the number of points).

    Returns:
        (mean_overlap, overlaps): the average number of preserved neighbours
        per point (between 0 and k), and the per-point counts as a list.
    """
    def _neighbors_excluding_self(coords):
        # kneighbors() with no query set leaves each point out of its own
        # neighbour list. Unlike slicing off column 0, this stays correct when
        # duplicate coordinates tie with the point itself.
        nbrs = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(coords)
        return nbrs.kneighbors(return_distance=False)

    indices_gt = _neighbors_excluding_self(coords_gt)
    indices_pred = _neighbors_excluding_self(coords_pred)

    overlaps = [
        len(set(gt_row).intersection(indices_pred[i]))
        for i, gt_row in enumerate(indices_gt)
    ]

    return np.mean(overlaps), overlaps

# Each call returns (mean preserved-neighbour count, per-cell counts).
knn_k10, overlaps_k10 = compute_knn_preservation(gt_coords_np, coords_pred, k=10)
knn_k20, overlaps_k20 = compute_knn_preservation(gt_coords_np, coords_pred, k=20)

print(f"\nk-NN Preservation Results:")
print(f" k=10: {knn_k10:.2f} / 10 ({knn_k10/10*100:.1f}% neighbors preserved)")
print(f" k=20: {knn_k20:.2f} / 20 ({knn_k20/20*100:.1f}% neighbors preserved)")

# Global (EDM) vs local (kNN) agreement side by side.
# NOTE(review): pearson_corr / spearman_corr are read from an earlier cell.
print(f"\nComparison:")
print(f" EDM Pearson: {pearson_corr:.4f} (global distances)")
print(f" EDM Spearman: {spearman_corr:.4f} (global distances)")
print(f" k-NN@10: {knn_k10/10:.4f} (local neighborhoods)")
print(f" k-NN@20: {knn_k20/20:.4f} (local neighborhoods)")

print("\n" + "="*70)
# ===================================================================
# k-NN PRESERVATION ANALYSIS (ENHANCED)
# ===================================================================

from sklearn.neighbors import NearestNeighbors
from scipy.stats import spearmanr

print("\n" + "="*70)
print("k-NN PRESERVATION ANALYSIS (ENHANCED)")
print("="*70)

# --- Helper functions ---

def knn_sets(coords, k):
    """Get k-NN indices for all points, excluding each point itself.

    Args:
        coords: (n, d) coordinate array.
        k: Number of neighbours (must be smaller than n).

    Returns:
        (n, k) integer array; row i lists point i's k nearest neighbours in
        order of increasing distance.
    """
    nbrs = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(coords)
    # With no query argument, kneighbors() drops each point from its own list.
    # Unlike slicing off column 0, this stays correct with duplicate coordinates.
    return nbrs.kneighbors(return_distance=False)

def knn_overlap_frac(idx_gt, idx_pr):
    """Per-point overlap fraction between two neighbour-index tables.

    Args:
        idx_gt: (n, k) integer array of ground-truth neighbour indices.
        idx_pr: row-indexable table (at least n rows) of predicted
            neighbour indices.

    Returns:
        (n,) float32 array whose entry i is
        |set(idx_gt[i]) ∩ set(idx_pr[i])| / k.
    """
    _, k = idx_gt.shape
    fractions = [
        len(set(gt_row).intersection(idx_pr[row])) / k
        for row, gt_row in enumerate(idx_gt)
    ]
    return np.asarray(fractions, dtype=np.float32)

def knn_jaccard(idx_gt, idx_pr):
    """Per-point Jaccard index between two neighbour-index tables.

    Args:
        idx_gt: (n, k) integer array of ground-truth neighbour indices.
        idx_pr: row-indexable table (at least n rows) of predicted
            neighbour indices.

    Returns:
        (n,) float32 array whose entry i is |G_i ∩ P_i| / |G_i ∪ P_i|, with
        the denominator floored at 1 so empty rows give 0.
    """
    _, _ = idx_gt.shape  # requires a 2-D table, like the original unpacking
    scores = []
    for row, gt_row in enumerate(idx_gt):
        gt_set = set(gt_row)
        pr_set = set(idx_pr[row])
        scores.append(len(gt_set & pr_set) / max(1, len(gt_set | pr_set)))
    return np.asarray(scores, dtype=np.float32)

def kth_neighbor_radius(coords, k):
    """Return each point's distance to its k-th nearest neighbour.

    k + 1 neighbours are queried because each point is returned as its own
    nearest match at distance 0. Column k of the sorted distance table is
    therefore the distance to the k-th other point.
    """
    tree = NearestNeighbors(n_neighbors=k + 1, algorithm='ball_tree')
    distances = tree.fit(coords).kneighbors(coords, return_distance=True)[0]
    return distances[:, k]

def local_spearman_within_radius(gt_coords, pr_coords, R, min_n=10, fallback_k=30):
    """Compute per-point local Spearman correlation within radius R.

    For each point i, take its ground-truth neighbours within radius ``R``.
    If fewer than ``min_n`` remain after removing i itself, use its
    ``fallback_k`` nearest GT neighbours instead. Then rank-correlate the GT
    distances from i to those neighbours with the corresponding distances in
    ``pr_coords``.

    Args:
        gt_coords: (n, d) ground-truth coordinates (NumPy array).
        pr_coords: predicted coordinates with rows aligned to gt_coords.
        R: Neighbourhood radius in ground-truth units.
        min_n: Minimum radius-neighbourhood size before falling back to kNN.
        fallback_k: Neighbourhood size for the kNN fallback.

    Returns:
        (n,) float32 array of Spearman correlations. Entries stay NaN where
        either distance vector is (numerically) constant.
    """
    nbrs_gt = NearestNeighbors(radius=R, algorithm='ball_tree').fit(gt_coords)
    ind = nbrs_gt.radius_neighbors(gt_coords, return_distance=False)

    # kneighbors() with no query leaves each point out of its own list, even
    # when duplicate coordinates would displace it from column 0.
    nbrs_gt_k = NearestNeighbors(n_neighbors=fallback_k, algorithm='ball_tree').fit(gt_coords)
    idx_k = nbrs_gt_k.kneighbors(return_distance=False)

    n = gt_coords.shape[0]
    vals = np.full(n, np.nan, dtype=np.float32)

    for i in range(n):
        neigh = ind[i]
        neigh = neigh[neigh != i]  # radius query includes the point itself
        if neigh.shape[0] < min_n:
            neigh = idx_k[i]

        d_gt = np.linalg.norm(gt_coords[neigh] - gt_coords[i], axis=1)
        d_pr = np.linalg.norm(pr_coords[neigh] - pr_coords[i], axis=1)

        # Spearman is undefined for a constant vector; leave NaN.
        if np.std(d_gt) < 1e-12 or np.std(d_pr) < 1e-12:
            continue

        vals[i] = spearmanr(d_gt, d_pr).correlation

    return vals

def soft_weighted_jaccard(gt_coords, pr_coords, k=20, tau=None):
    """Compute soft distance-weighted Jaccard (gives partial credit for near-misses).

    Each neighbour j in a point's k-NN list gets weight exp(-d_ij / tau). The
    GT list uses GT distances and the predicted list uses predicted
    distances. The per-point score is sum_j min(w_gt, w_pr) divided by
    sum_j max(w_gt, w_pr), taken over the union of both lists; a list that
    does not contain j contributes weight 0.

    Args:
        gt_coords: (n, d) ground-truth coordinates.
        pr_coords: predicted coordinates with rows aligned to gt_coords.
        k: Neighbourhood size (each point itself is excluded).
        tau: Length scale of the weights. Defaults to the median GT distance
            to the k-th neighbour, plus 1e-12 so tau is never 0.

    Returns:
        (scores, tau): per-point float32 scores in [0, 1] and the tau used.
    """
    # kneighbors() with no query leaves each point out of its own list, even
    # when duplicate coordinates would displace it from column 0.
    nbrs_gt = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(gt_coords)
    d_gt, idx_gt = nbrs_gt.kneighbors()

    nbrs_pr = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(pr_coords)
    d_pr, idx_pr = nbrs_pr.kneighbors()

    if tau is None:
        tau = np.median(d_gt[:, -1]) + 1e-12

    n = gt_coords.shape[0]
    out = np.empty(n, dtype=np.float32)

    for i in range(n):
        wgt = {int(j): np.exp(-float(d)/tau) for j, d in zip(idx_gt[i], d_gt[i])}
        wpr = {int(j): np.exp(-float(d)/tau) for j, d in zip(idx_pr[i], d_pr[i])}

        keys = set(wgt.keys()) | set(wpr.keys())
        num = 0.0
        den = 0.0
        for j in keys:
            a = wgt.get(j, 0.0)
            b = wpr.get(j, 0.0)
            num += min(a, b)
            den += max(a, b)
        out[i] = num / max(1e-12, den)

    return out, tau

def compute_hits(gt_knn_k, pred_knn_m):
    """
    Compute h_i(k,m) = |G_i^(k) ∩ P_i^(m)| for each point i.

    Args:
        gt_knn_k: (N, k) array of GT top-k neighbor indices
        pred_knn_m: (N, m) array of predicted top-m neighbor indices
            (k and m may differ)

    Returns:
        hits: (N,) int32 array of hit counts per point
    """
    return np.array(
        [len(set(gt_row) & set(pred_knn_m[i])) for i, gt_row in enumerate(gt_knn_k)],
        dtype=np.int32,
    )

def recall_at_k(gt_coords, pred_coords, k=10):
    """
    Recall@k = mean(h_i(k,k) / k).

    Strict metric: the fraction of each point's ground-truth top-k
    neighbours that also appear in its predicted top-k.

    Returns:
        (mean recall, per-point recall array)
    """
    per_point_hits = compute_hits(knn_sets(gt_coords, k), knn_sets(pred_coords, k))
    return per_point_hits.mean() / k, per_point_hits / k

def nearmiss_at_m(gt_coords, pred_coords, k_base=10, m=20):
    """
    NearMiss@m(k) = mean(h_i(k,m) / k).

    Tolerant metric: the fraction of each point's ground-truth top-k_base
    neighbours found anywhere in its predicted top-m (m >= k_base gives
    partial credit for near misses).

    Returns:
        (mean score, per-point score array)
    """
    gt_top_k = knn_sets(gt_coords, k_base)
    pred_top_m = knn_sets(pred_coords, m)
    per_point_hits = compute_hits(gt_top_k, pred_top_m)
    mean_fraction = per_point_hits.mean() / k_base
    return mean_fraction, per_point_hits / k_base

# ===================================================================
# 1) HARD kNN OVERLAP + JACCARD (k=10, 20, 50)
# ===================================================================
print("\n--- Hard k-NN Metrics ---")

# Neighbour index tables for GT and predicted coordinates at each k; the k=10
# pair is reused by the summary section of this cell.
idx_gt_10 = knn_sets(gt_coords_np, 10)
idx_pr_10 = knn_sets(coords_pred, 10)
idx_gt_20 = knn_sets(gt_coords_np, 20)
idx_pr_20 = knn_sets(coords_pred, 20)
idx_gt_50 = knn_sets(gt_coords_np, 50)
idx_pr_50 = knn_sets(coords_pred, 50)

for k, ig, ip in [(10, idx_gt_10, idx_pr_10), (20, idx_gt_20, idx_pr_20), (50, idx_gt_50, idx_pr_50)]:
    ov = knn_overlap_frac(ig, ip)  # |G ∩ P| / k per cell
    jc = knn_jaccard(ig, ip)       # |G ∩ P| / |G ∪ P| per cell
    print(f"[KNN] k={k:2d}: overlap mean={ov.mean():.3f} p50={np.median(ov):.3f} | "
          f"jaccard mean={jc.mean():.3f} p50={np.median(jc):.3f}")

# ===================================================================
# 2) RECALL@k AND NEARMISS@m METRICS
# ===================================================================
print("\n--- Recall and NearMiss Metrics ---")

# Strict: GT top-10 recovered within predicted top-10.
recall_10, recall_10_per_point = recall_at_k(gt_coords_np, coords_pred, k=10)
print(f"[RECALL@10] mean={recall_10:.4f} p50={np.median(recall_10_per_point):.4f}")

# Tolerant: GT top-10 recovered anywhere within predicted top-20 / top-50.
nearmiss_20, nearmiss_20_per_point = nearmiss_at_m(gt_coords_np, coords_pred, k_base=10, m=20)
print(f"[NEARMISS@20] (k_base=10) mean={nearmiss_20:.4f} p50={np.median(nearmiss_20_per_point):.4f}")

nearmiss_50, nearmiss_50_per_point = nearmiss_at_m(gt_coords_np, coords_pred, k_base=10, m=50)
print(f"[NEARMISS@50] (k_base=10) mean={nearmiss_50:.4f} p50={np.median(nearmiss_50_per_point):.4f}")

# ===================================================================
# 3) LOCAL SPEARMAN (distance ordering within local neighborhood)
# ===================================================================
print("\n--- Local Spearman Correlation ---")

# Neighbourhood radius R = median (over cells) GT distance to the 20th neighbour.
r20 = kth_neighbor_radius(gt_coords_np, 20)
R = np.median(r20)
print(f"[LOCAL-RADIUS] R = median GT d(20) = {R:.6f}")

rho_local = local_spearman_within_radius(gt_coords_np, coords_pred, R)
# Cells whose correlation was skipped hold NaN; the nan* reductions ignore them
# and finite_frac reports how many cells produced a value.
good = np.isfinite(rho_local)
print(f"[LOCAL-SPEARMAN] finite_frac={good.mean():.2%} "
      f"mean={np.nanmean(rho_local):.3f} p50={np.nanmedian(rho_local):.3f} "
      f"p10={np.nanpercentile(rho_local,10):.3f} p90={np.nanpercentile(rho_local,90):.3f}")

# ===================================================================
# 4) SOFT WEIGHTED JACCARD (partial credit for near-misses)
# ===================================================================
print("\n--- Soft Weighted Jaccard ---")

# tau=None makes soft_weighted_jaccard use the median GT distance to the k-th
# neighbour as the weight length scale (so tau differs between k=20 and k=50).
sj20, tau20 = soft_weighted_jaccard(gt_coords_np, coords_pred, k=20, tau=None)
sj50, tau50 = soft_weighted_jaccard(gt_coords_np, coords_pred, k=50, tau=None)
print(f"[SOFT-JACCARD] k=20 tau={tau20:.6f}: mean={sj20.mean():.3f} p50={np.median(sj20):.3f}")
print(f"[SOFT-JACCARD] k=50 tau={tau50:.6f}: mean={sj50.mean():.3f} p50={np.median(sj50):.3f}")

# ===================================================================
# SUMMARY COMPARISON
# ===================================================================
print("\n" + "="*70)
print("SUMMARY")
print("="*70)

# Mean k=10 overlap and Jaccard, recomputed from the k=10 index tables.
ov_10 = knn_overlap_frac(idx_gt_10, idx_pr_10).mean()
jc_10 = knn_jaccard(idx_gt_10, idx_pr_10).mean()

# NOTE(review): pearson_corr / spearman_corr are read from an earlier cell.
print(f"\n  Global Metrics:")
print(f"    EDM Pearson:      {pearson_corr:.4f}")
print(f"    EDM Spearman:     {spearman_corr:.4f}")

print(f"\n  Local Metrics (Hard):")
print(f"    kNN@10 overlap:   {ov_10:.4f}")
print(f"    kNN@10 Jaccard:   {jc_10:.4f}")
print(f"    Recall@10:        {recall_10:.4f}")

print(f"\n  Local Metrics (Tolerant):")
print(f"    NearMiss@20:      {nearmiss_20:.4f}")
print(f"    NearMiss@50:      {nearmiss_50:.4f}")

print(f"\n  Local Metrics (Stable):")
print(f"    Local Spearman:   {np.nanmean(rho_local):.4f}")
print(f"    Soft Jaccard@20:  {sj20.mean():.4f}")

In [None]:
import torch
import numpy as np
from datetime import datetime
import sys
import os
import scanpy as sc
from scipy.spatial.distance import pdist, squareform
from scipy.stats import pearsonr, spearmanr
import matplotlib.pyplot as plt
import seaborn as sns

sys.path.insert(0, '/home/ehtesamul/sc_st/model')
from core_models_et_p3 import GEMSModel, infer_anchor_train_from_checkpoint
import utils_et as uet

# ===================================================================
# PATHS AND CONFIG
# ===================================================================

# The checkpoint is read from the anchored run's directory; the SC AnnData file
# is read from the older output directory.
output_dir = "/home/ehtesamul/sc_st/model/gems_mousebrain_output_anchored"
output_dir_old = "/home/ehtesamul/sc_st/model/gems_mousebrain_output"

checkpoint_path = f"{output_dir}/phase1_st_checkpoint.pt"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

print("="*70)
print("LOADING DATA AND MODEL")
print("="*70)

# Load data
adata_sc = sc.read_h5ad(f"{output_dir_old}/scadata_with_gems_20260112_023604.h5ad")

# Use the .raw matrix when present, otherwise .X. Sparse matrices are
# densified before building a float32 tensor; no device is passed here, so
# the tensor starts on the default (CPU) device.
if hasattr(adata_sc, 'raw') and adata_sc.raw is not None:
    sc_expr = torch.tensor(adata_sc.raw.X.toarray() if hasattr(adata_sc.raw.X, 'toarray') else adata_sc.raw.X, dtype=torch.float32)
else:
    sc_expr = torch.tensor(adata_sc.X.toarray() if hasattr(adata_sc.X, 'toarray') else adata_sc.X, dtype=torch.float32)

# ===================================================================
# LOAD AND NORMALIZE GT COORDS
# ===================================================================
gt_coords_raw = adata_sc.obsm['spatial_gt']
n_cells, n_genes = sc_expr.shape
print(f"✓ Loaded SC data: {n_cells} cells × {n_genes} genes")
print(f"✓ Ground truth coords (raw): {gt_coords_raw.shape}")

# Canonicalize GT coordinates with the same per-slide routine applied to the ST
# coordinates (one slide, so every id is 0). Presumably this puts predictions
# and GT in the same normalized frame for the metrics.
gt_coords_tensor = torch.tensor(gt_coords_raw, dtype=torch.float32, device=device)
slide_ids = torch.zeros(gt_coords_tensor.shape[0], dtype=torch.long, device=device)

gt_coords_norm, gt_mu, gt_scale = uet.canonicalize_st_coords_per_slide(
    gt_coords_tensor, slide_ids
)

gt_coords_np = gt_coords_norm.cpu().numpy()

print(f"✓ GT coords normalized: scale={gt_scale[0].item():.4f}")
print(f"✓ GT coords RMS: {gt_coords_norm.pow(2).mean().sqrt().item():.4f}")
print(f"✓ GT coords range: X=[{gt_coords_np[:,0].min():.3f}, {gt_coords_np[:,0].max():.3f}], "
      f"Y=[{gt_coords_np[:,1].min():.3f}, {gt_coords_np[:,1].max():.3f}]")

# ===================================================================
# AUTO-DETECT ANCHOR MODE FROM CHECKPOINT BEFORE MODEL INIT
# ===================================================================
print("\n" + "="*70)
print("AUTO-DETECTING ANCHOR MODE FROM CHECKPOINT")
print("="*70)

# weights_only=False unpickles arbitrary objects, so only use it on trusted,
# locally produced checkpoint files.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

# Detect anchor mode (base_h_dim=128 for [512,256,128] encoder)
base_h_dim = 128  # Last dimension of n_embedding
anchor_train_detected = infer_anchor_train_from_checkpoint(checkpoint, base_h_dim)

print(f"✓ Detected anchor_train={anchor_train_detected}")

# ===================================================================
# INITIALIZE MODEL WITH CORRECT ANCHOR MODE
# ===================================================================
print("\n" + "="*70)
print("INITIALIZING MODEL")
print("="*70)

model = GEMSModel(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    D_latent=32,
    c_dim=256,
    n_heads=4,
    isab_m=128,
    dist_bins=24,
    device=device,
    # Must match the checkpoint: anchor_train presumably changes the
    # context_encoder input width (printed below), so a mismatch would break
    # load_state_dict.
    anchor_train=anchor_train_detected,
)

print(f"✓ Model initialized with anchor_train={model.anchor_train}")
print(f"✓ context_encoder.input_dim={model.context_encoder.input_dim}")

# ===================================================================
# LOAD CHECKPOINT WEIGHTS
# ===================================================================
model.encoder.load_state_dict(checkpoint['encoder'])
model.context_encoder.load_state_dict(checkpoint['context_encoder'])
model.generator.load_state_dict(checkpoint['generator'])
model.score_net.load_state_dict(checkpoint['score_net'])

# Restore EDM noise parameters when the checkpoint stored them; otherwise the
# model keeps whatever values it was constructed with.
if 'sigma_data' in checkpoint:
    model.sigma_data = checkpoint['sigma_data']
if 'sigma_min' in checkpoint:
    model.sigma_min = checkpoint['sigma_min']
if 'sigma_max' in checkpoint:
    model.sigma_max = checkpoint['sigma_max']

model.encoder.eval()
model.context_encoder.eval()
model.generator.eval()
model.score_net.eval()

print(f"✓ Model loaded and set to eval mode")
print(f"✓ sigma_data={getattr(model, 'sigma_data', 'N/A')}")


# ===================================================================
# COMPUTE CORAL TRANSFORMATION
# ===================================================================
print("\n" + "="*70)
print("COMPUTING CORAL TRANSFORMATION")
print("="*70)

# 1. Load ST data (needed to compute the ST context distribution)
import pandas as pd
import anndata as ad

st_counts = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_counts_et.csv'
st_meta = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_metadata_et.csv'

# The counts CSV is genes x spots (gene names are the index), hence the
# transpose to spots x genes.
st_expr_df = pd.read_csv(st_counts, index_col=0)
st_meta_df = pd.read_csv(st_meta, index_col=0)
stadata = ad.AnnData(X=st_expr_df.values.T)
stadata.obs_names = st_expr_df.columns
stadata.var_names = st_expr_df.index
stadata.obsm['spatial'] = st_meta_df[['coord_x', 'coord_y']].values

# Get common genes (sorted), and take the ST columns in that order
common = sorted(list(set(adata_sc.var_names) & set(stadata.var_names)))
X_st = stadata[:, common].X
if hasattr(X_st, "toarray"):
    X_st = X_st.toarray()

# Sanity check: sc_expr was built from adata_sc.raw.X (when .raw is set) or
# adata_sc.X, so it keeps adata_sc's own gene order, while st_expr follows the
# sorted `common` list. If the two orders differ, SC and ST expression columns
# are not aligned.
sc_gene_names = (
    list(adata_sc.raw.var_names)
    if hasattr(adata_sc, 'raw') and adata_sc.raw is not None
    else list(adata_sc.var_names)
)
if sc_gene_names != common:
    print(f"⚠️  WARNING: SC gene order ({len(sc_gene_names)} genes) differs from "
          f"the sorted common gene list used for ST ({len(common)} genes); "
          f"SC and ST expression columns may be misaligned.")

st_expr = torch.tensor(X_st, dtype=torch.float32, device=device)
st_coords_raw = torch.tensor(stadata.obsm['spatial'], dtype=torch.float32, device=device)
slide_ids = torch.zeros(st_expr.shape[0], dtype=torch.long, device=device)
st_coords, _, _ = uet.canonicalize_st_coords_per_slide(st_coords_raw, slide_ids)

print(f"✓ Loaded ST data: {st_expr.shape[0]} spots × {st_expr.shape[1]} genes")

# 2. Prepare ST gene expression dict (slide id -> CPU expression tensor)
st_gene_expr_dict = {0: st_expr.cpu()}

# 3. Load Stage B targets (needed for compute_coral_params_from_st).
# Reuse the cached targets_dict.pt when it exists; otherwise run train_stageB.
# NOTE(review): this torch.load relies on the library default for weights_only
# (True in recent PyTorch); if targets_dict holds non-tensor objects it may need
# weights_only=False like the checkpoint load.
targets_path = f"{output_dir}/stageB_targets/targets_dict.pt"
if os.path.exists(targets_path):
    model.targets_dict = torch.load(targets_path, map_location='cpu')
    print(f"✓ Loaded targets_dict from {targets_path}")
else:
    print(f"⚠️  Targets not found at {targets_path}")
    print("   Running Stage B precomputation...")
    slides_dict = {0: (st_coords, st_expr)}
    model.train_stageB(slides=slides_dict, outdir=f"{output_dir}/stageB_targets")
    print("✓ Stage B complete")

# 4. Compute CORAL parameters: first the ST-side statistics, then the SC -> ST
# transform (same sampling settings on both sides).
print("\n--- Computing ST context distribution ---")
model.compute_coral_params_from_st(
    st_gene_expr_dict=st_gene_expr_dict,
    n_samples=2000,
    n_min=96,
    n_max=384,
)

print("\n--- Building CORAL transformation ---")
model.build_coral_transform(
    sc_gene_expr=sc_expr,
    n_samples=2000,
    n_min=96,
    n_max=384,
    shrink=0.01,
    eps=1e-5,
)

print("✓ CORAL transformation ready!")

print("\n" + "="*70)

# ===================================================================
# RUN INFERENCE (NEW GLOBAL SCALE VERSION)
# ===================================================================
print("\n" + "="*70)
print("RUNNING INFERENCE (GLOBAL SCALE)")
print("="*70)

# Re-apply the checkpoint's EDM sigma values right before sampling.
# NOTE(review): this repeats the restore done after loading the weights. It only
# matters if the Stage B / CORAL calls above can modify model.sigma_*; confirm.
if 'sigma_data' in checkpoint:
    model.sigma_data = checkpoint['sigma_data']
if 'sigma_min' in checkpoint:
    model.sigma_min = checkpoint['sigma_min']
if 'sigma_max' in checkpoint:
    model.sigma_max = checkpoint['sigma_max']


# Previous configuration, kept for reference (600 steps, coverage 8.0,
# edm_anchor_local sampling, commit_frac=1.0, seq_align_dim=32):
# with torch.no_grad():
#     results = model.infer_sc_patchwise(
#         sc_gene_expr=sc_expr,
#         n_timesteps_sample=600,
#         return_coords=True,
#         patch_size=192,
#         coverage_per_cell=8.0,
#         n_align_iters=15,
#         eta=0.0,
#         guidance_scale=2.0,
#         gt_coords=gt_coords_norm,
#         debug_knn=True,
#         debug_max_patches=15,
#         debug_k_list=(10, 20),
#         pool_mult=2.0,
#         stochastic_tau=0.8,
#         tau_mode="adaptive_kth",
#         ensure_connected=True,
#         local_refine=False,
#         anchor_sampling_mode="edm_anchor_local",  # NEW PARAMETER
#         commit_frac=1.0,
#         seq_align_dim=32,
#     )

# gt_coords is passed only alongside the debug_knn options (debug diagnostics).
with torch.no_grad():
    results = model.infer_sc_patchwise(
        sc_gene_expr=sc_expr,
        n_timesteps_sample=500,
        return_coords=True,
        patch_size=192,
        coverage_per_cell=6.0,
        n_align_iters=15,
        eta=0.0,
        guidance_scale=2.0,
        gt_coords=gt_coords_norm,
        debug_knn=True,
        debug_max_patches=15,
        debug_k_list=(10, 20),
        pool_mult=2.0,
        stochastic_tau=0.8,
        tau_mode="adaptive_kth",
        ensure_connected=True,
        local_refine=False,
        # Anchored inference path; anchor_sampling_mode selects the specific
        # anchored method within that mode.
        inference_mode="anchored",
        anchor_sampling_mode="align_vote_only",
        commit_frac=0.6,
        seq_align_dim=2,
    )



# Move results to NumPy for the evaluation cells.
D_edm_pred = results['D_edm'].cpu().numpy()
coords_pred = results['coords_canon'].cpu().numpy()
print(f"\n✓ Inference complete!")
print(f"  Predicted EDM: {D_edm_pred.shape}")
print(f"  Predicted coords: {coords_pred.shape}")

In [None]:
# ===================================================================
# SINGLE-PATCH INFERENCE EXPERIMENT (NO STITCHING)
# Set-size sweep: 96, 192, 256, 384, 838
# Metrics: kNN@10, kNN@20, Pearson, Spearman (local patch only)
# ===================================================================

import torch
import numpy as np
from sklearn.neighbors import NearestNeighbors
from scipy.stats import pearsonr, spearmanr
from scipy.spatial.distance import pdist, squareform
import pandas as pd

# Banner for the single-patch experiment: each set size is inferred as one
# patch, with no stitching across patches.
print("="*70)
print("SINGLE-PATCH QUALITY EXPERIMENT")
print("="*70)

# ===================================================================
# HELPER FUNCTIONS
# ===================================================================

def compute_knn_overlap(coords_gt, coords_pred, k=10):
    """Compute the mean k-NN overlap fraction between two embeddings.

    For each point, take the fraction of its k nearest ground-truth
    neighbours that are also among its k nearest predicted neighbours.

    Args:
        coords_gt: (n, d) ground-truth coordinates.
        coords_pred: predicted coordinates with rows aligned to coords_gt.
        k: Neighbourhood size (must be smaller than n).

    Returns:
        Mean overlap fraction over points, in [0, 1].
    """
    # kneighbors() with no query leaves each point out of its own list. Unlike
    # slicing off column 0, this stays correct with duplicate coordinates.
    nbrs_gt = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(coords_gt)
    idx_gt = nbrs_gt.kneighbors(return_distance=False)

    nbrs_pred = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(coords_pred)
    idx_pred = nbrs_pred.kneighbors(return_distance=False)

    # Compute overlap per cell
    overlaps = [
        len(set(gt_row) & set(idx_pred[i])) / k
        for i, gt_row in enumerate(idx_gt)
    ]

    return np.mean(overlaps)


def compute_edm_correlations(coords_gt, coords_pred):
    """Compute Pearson and Spearman correlations of EDM distances.

    Args:
        coords_gt: (n, d) ground-truth coordinates.
        coords_pred: (n, d') predicted coordinates with rows aligned to coords_gt.

    Returns:
        (pearson_corr, spearman_corr) over all n*(n-1)/2 unordered point pairs.

    No scale alignment is applied. Both correlations are invariant to a positive
    rescaling of either distance set, and a median-ratio rescale would divide by
    zero (producing NaN) when the median predicted distance is 0, for example
    with a partially collapsed prediction.
    """
    # pdist returns the condensed upper triangle (pairs i < j in row-major
    # order), i.e. the same pairs as np.triu_indices(n, k=1) on the square EDM,
    # without building two n x n matrices.
    dist_gt = pdist(coords_gt, 'euclidean')
    dist_pred = pdist(coords_pred, 'euclidean')

    pearson_corr, _ = pearsonr(dist_gt, dist_pred)
    spearman_corr, _ = spearmanr(dist_gt, dist_pred)

    return pearson_corr, spearman_corr


def run_single_patch_inference(model, sc_expr, gt_coords_norm, subset_idx, n_timesteps=600, guidance_scale=2.0):
    """
    Run inference on a single patch (subset of cells).
    Returns predicted coordinates.

    Pipeline:
    1. Encode the subset's expression with ``model.encoder``.
    2. Optionally apply the model's CORAL transform.
    3. Append a zero anchor channel when ``model.anchor_train`` is set.
    4. Encode the set context and take the generator's proposal.
    5. Denoise that proposal with an Euler sampler over ``uet.edm_sigma_schedule``, using
       classifier-free guidance whenever ``guidance_scale != 1.0``.

    The final latent is centered and, if it has more than two columns, projected to 2-D
    with PCA. Uses the module-level ``device``.

    Args:
        model: trained model exposing encoder, context_encoder, generator and score_net.
        sc_expr: expression tensor, indexed by ``subset_idx`` along dim 0.
        gt_coords_norm: ground-truth coordinates. Kept for interface compatibility; not
            used by the sampler.
        subset_idx: indices of the cells forming the patch.
        n_timesteps: number of sigma levels in the EDM schedule.
        guidance_scale: CFG weight; 1.0 disables the unconditional pass.

    Returns:
        Centered numpy coordinates with one row per patch cell and 2 columns (or the
        latent width if it is <= 2).
    """
    # Extract subset
    sc_expr_subset = sc_expr[subset_idx]
    m = sc_expr_subset.shape[0]

    with torch.no_grad():
        # Encode genes
        H_latent = model.encoder(sc_expr_subset.to(device))  # (m, D_latent)

        # Mask covering every cell in the patch, plus a batched (1, m) view of it
        mask = torch.ones(m, dtype=torch.bool, device=device)
        mask_b = mask.unsqueeze(0)

        # Build context (apply CORAL if the model carries fitted parameters)
        coral = getattr(model, 'coral_params', None)
        if coral is not None:
            H_ctx = GEMSModel.apply_coral_transform(
                H_latent.unsqueeze(0),  # (1, m, D_latent)
                mu_sc=coral['mu_sc'],
                A=coral['A'],
                B=coral['B'],
                mu_st=coral['mu_st']
            ).squeeze(0)
        else:
            H_ctx = H_latent

        # Add a zero anchor channel when the model was trained with one
        if model.anchor_train:
            anchor_channel = torch.zeros(m, 1, device=device)
            Z_ctx = torch.cat([H_ctx, anchor_channel], dim=-1)  # (m, D_latent+1)
        else:
            Z_ctx = H_ctx

        # Encode context
        H_context = model.context_encoder(Z_ctx.unsqueeze(0), mask_b)  # (1, m, c_dim)

        # Generate initial proposal
        V_gen = model.generator(H_context, mask_b).squeeze(0)  # (m, D_latent)

        # EDM sampling hyper-parameters (model attributes if present, else defaults)
        sigma_data = getattr(model, 'sigma_data', 1.0)
        sigma_min = getattr(model, 'sigma_min', 0.02)
        sigma_max = getattr(model, 'sigma_max', 3.0)

        import utils_et as uet
        sigmas = uet.edm_sigma_schedule(n_timesteps, sigma_min, sigma_max, rho=7.0, device=device)

        # Initialize with noise around the generator's proposal
        V_t = V_gen + sigmas[0] * torch.randn_like(V_gen)

        # CFG is active whenever guidance_scale != 1.0. The unconditional (all-zero)
        # context is the same at every step, so build it once outside the loop.
        use_cfg = guidance_scale != 1.0
        H_null = torch.zeros_like(H_context) if use_cfg else None

        # Denoising loop (Euler method)
        for i in range(len(sigmas) - 1):
            sigma = sigmas[i]
            sigma_next = sigmas[i + 1]
            V_in = V_t.unsqueeze(0)
            sigma_in = sigma.view(1)

            # Conditional x0 prediction
            x0_c = model.score_net.forward_edm(
                V_in, sigma_in, H_context, mask_b, sigma_data
            ).squeeze(0)

            if use_cfg:
                # Unconditional prediction, then extrapolate toward the conditional one
                x0_u = model.score_net.forward_edm(
                    V_in, sigma_in, H_null, mask_b, sigma_data
                ).squeeze(0)
                x0 = x0_u + guidance_scale * (x0_c - x0_u)
            else:
                x0 = x0_c

            # Euler step along d = (x - x0) / sigma
            d = (V_t - x0) / sigma.clamp_min(1e-8)
            V_t = V_t + (sigma_next - sigma) * d

        # Final coordinates
        coords_pred = V_t.cpu().numpy()

        # Center, then reduce to 2-D with PCA when the latent has more than two columns
        coords_pred_centered = coords_pred - coords_pred.mean(axis=0, keepdims=True)
        if coords_pred.shape[1] > 2:
            from sklearn.decomposition import PCA
            pca = PCA(n_components=2)
            coords_pred_2d = pca.fit_transform(coords_pred_centered)
        else:
            coords_pred_2d = coords_pred_centered

    return coords_pred_2d


# ===================================================================
# EXPERIMENT: SWEEP OVER PATCH SIZES
# ===================================================================

# Total number of cells (838 for this dataset). Derive it from the data instead of
# hard-coding it, so subset indices stay in range if the input changes. The largest
# "patch" is always the full dataset.
n_total_cells = gt_coords_np.shape[0]

patch_sizes = [size for size in (96, 192, 256, 384) if size < n_total_cells] + [n_total_cells]
n_repeats_per_size = {
    96: 10,
    192: 10,
    256: 10,
    384: 6,
    n_total_cells: 1  # Full dataset, run once
}

results = []

# NOTE: the last iteration runs on the full dataset, so `coords_pred` is left holding the
# full-set prediction. Later cells compare it row-by-row against the full gt_coords_np.
for patch_size in patch_sizes:
    print(f"\n{'='*70}")
    print(f"PATCH SIZE: {patch_size}")
    print(f"{'='*70}")

    n_repeats = n_repeats_per_size[patch_size]

    metrics_this_size = {
        'knn10': [],
        'knn20': [],
        'pearson': [],
        'spearman': []
    }

    for rep in range(n_repeats):
        # Sample a random subset; the full dataset is used as-is, in order
        if patch_size == n_total_cells:
            subset_idx = np.arange(n_total_cells)
        else:
            subset_idx = np.random.choice(n_total_cells, size=patch_size, replace=False)

        # Get GT coords for subset
        gt_coords_subset = gt_coords_np[subset_idx]

        # Run inference
        print(f"  Run {rep+1}/{n_repeats}: sampling {patch_size} cells...", end=' ')

        coords_pred = run_single_patch_inference(
            model=model,
            sc_expr=sc_expr,
            gt_coords_norm=gt_coords_norm,
            subset_idx=subset_idx,
            n_timesteps=500,
            guidance_scale=2.0
        )

        # Compute metrics
        knn10 = compute_knn_overlap(gt_coords_subset, coords_pred, k=10)
        knn20 = compute_knn_overlap(gt_coords_subset, coords_pred, k=20)
        pearson, spearman = compute_edm_correlations(gt_coords_subset, coords_pred)

        metrics_this_size['knn10'].append(knn10)
        metrics_this_size['knn20'].append(knn20)
        metrics_this_size['pearson'].append(pearson)
        metrics_this_size['spearman'].append(spearman)

        print(f"kNN@10={knn10:.3f} kNN@20={knn20:.3f} Pearson={pearson:.3f} Spearman={spearman:.3f}")

    # Average across repeats
    results.append({
        'patch_size': patch_size,
        'n_repeats': n_repeats,
        'knn10_mean': np.mean(metrics_this_size['knn10']),
        'knn10_std': np.std(metrics_this_size['knn10']),
        'knn20_mean': np.mean(metrics_this_size['knn20']),
        'knn20_std': np.std(metrics_this_size['knn20']),
        'pearson_mean': np.mean(metrics_this_size['pearson']),
        'pearson_std': np.std(metrics_this_size['pearson']),
        'spearman_mean': np.mean(metrics_this_size['spearman']),
        'spearman_std': np.std(metrics_this_size['spearman']),
    })

# ===================================================================
# PRINT SUMMARY TABLE
# ===================================================================

print("\n" + "="*70)
print("SUMMARY: SINGLE-PATCH QUALITY vs PATCH SIZE")
print("="*70)

# One row per patch size with mean/std of each metric across repeats
df_results = pd.DataFrame(results)

print("\n" + df_results.to_string(index=False))

print("\n" + "="*70)
print("EXPERIMENT COMPLETE")
print("="*70)

# Export to CSV (output_dir is expected to be defined by an earlier cell)
sweep_csv_path = f"{output_dir}/single_patch_quality_sweep.csv"
df_results.to_csv(sweep_csv_path, index=False)
print(f"\n✓ Results saved to {sweep_csv_path}")


In [None]:
# ===================================================================
# k-NN PRESERVATION ANALYSIS
# ===================================================================

from sklearn.neighbors import NearestNeighbors

# Section banner: blank line, rule, title, rule.
print("", "=" * 70, "k-NN PRESERVATION ANALYSIS", "=" * 70, sep="\n")

def compute_knn_preservation(coords_gt, coords_pred, k=10):
    """
    Count how many of each point's k nearest neighbours survive in the prediction.

    Neighbour lists are built separately in ``coords_gt`` and ``coords_pred``; the first
    hit of each query (normally the point itself) is dropped. Returns
    ``(mean_count, counts)``, where ``counts`` is a list holding, for each row of
    ``coords_gt``, the number of shared neighbours (0..k).
    """
    neighbour_lists = []
    for layout in (coords_gt, coords_pred):
        finder = NearestNeighbors(n_neighbors=k + 1, algorithm='ball_tree').fit(layout)
        neighbour_lists.append(finder.kneighbors(layout)[1][:, 1:])
    gt_idx, pred_idx = neighbour_lists

    counts = [
        len(set(gt_idx[row]).intersection(set(pred_idx[row])))
        for row in range(coords_gt.shape[0])
    ]
    return np.mean(counts), counts

# Compute for k=10 and k=20 (coords_pred / gt_coords_np come from earlier cells)

knn_k10, overlaps_k10 = compute_knn_preservation(gt_coords_np, coords_pred, k=10)
knn_k20, overlaps_k20 = compute_knn_preservation(gt_coords_np, coords_pred, k=20)

# Fraction of neighbours preserved (mean count / k), reused below
frac_k10 = knn_k10 / 10
frac_k20 = knn_k20 / 20

print(f"\nk-NN Preservation Results:")
print(f" k=10: {knn_k10:.2f} / 10 ({frac_k10*100:.1f}% neighbors preserved)")
print(f" k=20: {knn_k20:.2f} / 20 ({frac_k20*100:.1f}% neighbors preserved)")

# pearson_corr / spearman_corr are expected to be defined by an earlier cell
print(f"\nComparison:")
print(f" EDM Pearson: {pearson_corr:.4f} (global distances)")
print(f" EDM Spearman: {spearman_corr:.4f} (global distances)")
print(f" k-NN@10: {frac_k10:.4f} (local neighborhoods)")
print(f" k-NN@20: {frac_k20:.4f} (local neighborhoods)")

print("\n" + "="*70)
print("INTERPRETATION")
print("="*70)

if frac_k10 < 0.3:
    print("\n⚠️ SEVERE local scrambling detected!")
    print(" < 30% of neighbors preserved - local structure heavily disrupted")
elif frac_k10 < 0.5:
    print("\n⚠️ MODERATE local scrambling detected")
    print(" 30-50% of neighbors preserved - significant local disruption")
elif frac_k10 < 0.7:
    print("\n✓ MILD local scrambling")
    print(" 50-70% of neighbors preserved - some local structure retained")
else:
    print("\n✓✓ GOOD local preservation")
    print(" > 70% of neighbors preserved - local structure mostly intact")

if pearson_corr > 0.6 and frac_k10 < 0.4:
    print("\n🔍 DIAGNOSIS: High EDM correlation but low k-NN preservation")
    print(" → Global geometry preserved, but specific neighbor relationships lost")
    print(" → Consistent with context-dependent coordinate generation")

In [None]:
# ===================================================================
# k-NN PRESERVATION ANALYSIS (ENHANCED)
# ===================================================================

from sklearn.neighbors import NearestNeighbors
from scipy.stats import spearmanr

# Section banner: blank line, rule, title, rule.
print("", "=" * 70, "k-NN PRESERVATION ANALYSIS (ENHANCED)", "=" * 70, sep="\n")

# --- Helper functions ---

def knn_sets(coords, k):
    """Return an (N, k) array of k-NN indices, dropping each query's first hit (normally itself)."""
    finder = NearestNeighbors(n_neighbors=k + 1, algorithm='ball_tree')
    finder.fit(coords)
    neighbour_idx = finder.kneighbors(coords, return_distance=False)
    return neighbour_idx[:, 1:]

def knn_overlap_frac(idx_gt, idx_pr):
    """Per-row fraction of GT neighbour indices that also appear in the predicted row.

    Returns a float32 array of length N holding |set(gt_row) & set(pr_row)| / k, where
    N, k is the shape of ``idx_gt``.
    """
    n_rows, k = idx_gt.shape
    fractions = (
        len(set(idx_gt[row]).intersection(idx_pr[row])) / k
        for row in range(n_rows)
    )
    return np.fromiter(fractions, dtype=np.float32, count=n_rows)

def knn_jaccard(idx_gt, idx_pr):
    """Per-row Jaccard index |A & B| / |A | B| of the GT and predicted neighbour sets.

    Returns a float32 array of length N. An empty union is treated as size 1 so the
    result is 0 rather than a 0/0 division.
    """
    n_rows, _ = idx_gt.shape
    scores = np.empty(n_rows, dtype=np.float32)
    for row in range(n_rows):
        gt_set = set(idx_gt[row])
        pr_set = set(idx_pr[row])
        union_size = len(gt_set | pr_set)
        scores[row] = len(gt_set & pr_set) / (union_size if union_size >= 1 else 1)
    return scores

def kth_neighbor_radius(coords, k):
    """Return, for every point, the distance in column k of its (k+1)-NN query result."""
    finder = NearestNeighbors(n_neighbors=k + 1, algorithm='ball_tree').fit(coords)
    dists, _ = finder.kneighbors(coords)
    # Column 0 is normally the query point itself, so column k is its k-th neighbour
    return dists[:, k]

def local_spearman_within_radius(gt_coords, pr_coords, R, min_n=10, fallback_k=30):
    """Per-point Spearman correlation of local distances, ground truth vs prediction.

    The neighbourhood of point i is every other point within radius ``R`` of i in the GT
    layout. If that gives fewer than ``min_n`` points, the ``fallback_k`` GT nearest
    neighbours are used instead. Distances from i to that neighbourhood are measured in
    both layouts and rank-correlated. An entry stays NaN when either distance vector has
    standard deviation below 1e-12.

    Returns:
        float32 array with one value per row of ``gt_coords`` (NaN where undefined).
    """
    radius_sets = (
        NearestNeighbors(radius=R, algorithm='ball_tree')
        .fit(gt_coords)
        .radius_neighbors(gt_coords, return_distance=False)
    )

    # kNN lists (first hit dropped) used when the radius neighbourhood is too small
    knn_fallback = (
        NearestNeighbors(n_neighbors=fallback_k + 1, algorithm='ball_tree')
        .fit(gt_coords)
        .kneighbors(gt_coords)[1][:, 1:]
    )

    n_points = gt_coords.shape[0]
    rho = np.full(n_points, np.nan, dtype=np.float32)

    for i in range(n_points):
        candidates = radius_sets[i]
        nbrs = candidates[candidates != i]
        if nbrs.shape[0] < min_n:
            nbrs = knn_fallback[i]

        gt_dist = np.linalg.norm(gt_coords[nbrs] - gt_coords[i], axis=1)
        pr_dist = np.linalg.norm(pr_coords[nbrs] - pr_coords[i], axis=1)

        degenerate = np.std(gt_dist) < 1e-12 or np.std(pr_dist) < 1e-12
        if not degenerate:
            rho[i] = spearmanr(gt_dist, pr_dist).correlation

    return rho

def soft_weighted_jaccard(gt_coords, pr_coords, k=20, tau=None):
    """Distance-weighted (soft) Jaccard between GT and predicted k-NN sets, per point.

    Each neighbour j of point i gets weight exp(-d(i, j) / tau) in the layout where it was
    found, and weight 0 in a layout whose list does not contain it. The score is
    sum(min(w_gt, w_pr)) / sum(max(w_gt, w_pr)) over the union of both lists, so near-misses
    earn partial credit. When ``tau`` is None it defaults to the median GT distance to the
    k-th neighbour plus 1e-12.

    Returns:
        (scores, tau): float32 array with one score per row of ``gt_coords``, and the tau
        actually used.
    """
    def _neighbours(points):
        # kNN distances/indices with each query's first hit (normally itself) dropped
        finder = NearestNeighbors(n_neighbors=k + 1, algorithm='ball_tree').fit(points)
        dist, idx = finder.kneighbors(points)
        return dist[:, 1:], idx[:, 1:]

    gt_dist, gt_idx = _neighbours(gt_coords)
    pr_dist, pr_idx = _neighbours(pr_coords)

    if tau is None:
        tau = np.median(gt_dist[:, -1]) + 1e-12

    def _weights(idx_row, dist_row):
        return {int(j): np.exp(-float(d) / tau) for j, d in zip(idx_row, dist_row)}

    n_points = gt_coords.shape[0]
    scores = np.empty(n_points, dtype=np.float32)

    for i in range(n_points):
        w_gt = _weights(gt_idx[i], gt_dist[i])
        w_pr = _weights(pr_idx[i], pr_dist[i])

        inter = 0.0
        union = 0.0
        for j in set(w_gt.keys()) | set(w_pr.keys()):
            a = w_gt.get(j, 0.0)
            b = w_pr.get(j, 0.0)
            inter += min(a, b)
            union += max(a, b)
        scores[i] = inter / max(1e-12, union)

    return scores, tau

# ===================================================================
# 1) HARD kNN OVERLAP + JACCARD (k=10, 20, 50)
# ===================================================================
print("\n--- Hard k-NN Metrics ---")

idx_gt_10 = knn_sets(gt_coords_np, 10)
idx_pr_10 = knn_sets(coords_pred, 10)
idx_gt_20 = knn_sets(gt_coords_np, 20)
idx_pr_20 = knn_sets(coords_pred, 20)
idx_gt_50 = knn_sets(gt_coords_np, 50)
idx_pr_50 = knn_sets(coords_pred, 50)

# Per-point overlap / Jaccard arrays keyed by k. The summary and diagnosis below reuse
# them instead of recomputing.
hard_knn_overlap = {}
hard_knn_jaccard = {}
for k, ig, ip in [(10, idx_gt_10, idx_pr_10), (20, idx_gt_20, idx_pr_20), (50, idx_gt_50, idx_pr_50)]:
    ov = knn_overlap_frac(ig, ip)
    jc = knn_jaccard(ig, ip)
    hard_knn_overlap[k] = ov
    hard_knn_jaccard[k] = jc
    print(f"[KNN] k={k:2d}: overlap mean={ov.mean():.3f} p50={np.median(ov):.3f} | "
          f"jaccard mean={jc.mean():.3f} p50={np.median(jc):.3f}")

# ===================================================================
# 2) LOCAL SPEARMAN (distance ordering within local neighborhood)
# ===================================================================
print("\n--- Local Spearman Correlation ---")

r20 = kth_neighbor_radius(gt_coords_np, 20)
R = np.median(r20)
print(f"[LOCAL-RADIUS] R = median GT d(20) = {R:.6f}")

rho_local = local_spearman_within_radius(gt_coords_np, coords_pred, R)
rho_local_mean = np.nanmean(rho_local)
good = np.isfinite(rho_local)
print(f"[LOCAL-SPEARMAN] finite_frac={good.mean():.2%} "
      f"mean={rho_local_mean:.3f} p50={np.nanmedian(rho_local):.3f} "
      f"p10={np.nanpercentile(rho_local,10):.3f} p90={np.nanpercentile(rho_local,90):.3f}")

# ===================================================================
# 3) SOFT WEIGHTED JACCARD (partial credit for near-misses)
# ===================================================================
print("\n--- Soft Weighted Jaccard ---")

sj20, tau20 = soft_weighted_jaccard(gt_coords_np, coords_pred, k=20, tau=None)
sj50, tau50 = soft_weighted_jaccard(gt_coords_np, coords_pred, k=50, tau=None)
sj20_mean = sj20.mean()
print(f"[SOFT-JACCARD] k=20 tau={tau20:.6f}: mean={sj20_mean:.3f} p50={np.median(sj20):.3f}")
print(f"[SOFT-JACCARD] k=50 tau={tau50:.6f}: mean={sj50.mean():.3f} p50={np.median(sj50):.3f}")

# ===================================================================
# SUMMARY COMPARISON
# ===================================================================
print("\n" + "="*70)
print("SUMMARY")
print("="*70)

ov_10 = hard_knn_overlap[10].mean()
jc_10 = hard_knn_jaccard[10].mean()

# pearson_corr / spearman_corr are expected to be defined by an earlier cell
print(f"\n  Global Metrics:")
print(f"    EDM Pearson:      {pearson_corr:.4f}")
print(f"    EDM Spearman:     {spearman_corr:.4f}")

print(f"\n  Local Metrics (Hard):")
print(f"    kNN@10 overlap:   {ov_10:.4f}")
print(f"    kNN@10 Jaccard:   {jc_10:.4f}")

print(f"\n  Local Metrics (Stable):")
print(f"    Local Spearman:   {rho_local_mean:.4f}")
print(f"    Soft Jaccard@20:  {sj20_mean:.4f}")

# ===================================================================
# INTERPRETATION
# ===================================================================
print("\n" + "="*70)
print("INTERPRETATION")
print("="*70)

# Diagnosis based on multiple metrics
if ov_10 < 0.3:
    print("\n⚠️ SEVERE local scrambling (kNN@10 overlap < 30%)")
elif ov_10 < 0.5:
    print("\n⚠️ MODERATE local scrambling (kNN@10 overlap 30-50%)")
elif ov_10 < 0.7:
    print("\n✓ MILD local scrambling (kNN@10 overlap 50-70%)")
else:
    print("\n✓✓ GOOD local preservation (kNN@10 overlap > 70%)")

# Check for near-ties scenario
if ov_10 < 0.4 and sj20_mean > ov_10 + 0.15:
    print("\n🔍 NEAR-TIES DETECTED:")
    print(f"   Hard kNN@10 is low ({ov_10:.3f}) but Soft Jaccard is higher ({sj20_mean:.3f})")
    print("   → Neighbors are 'almost' correct but exact ranks unstable (GT has near-ties)")

# Check for local Spearman vs global
if spearman_corr > 0.6 and rho_local_mean < 0.4:
    print("\n🔍 LOCAL vs GLOBAL MISMATCH:")
    print(f"   Global EDM Spearman is decent ({spearman_corr:.3f})")
    print(f"   But Local Spearman is poor ({rho_local_mean:.3f})")
    print("   → Macro geometry OK, but fine local ordering disrupted")

# Check if kNN@50 >> kNN@10
ov_50 = hard_knn_overlap[50].mean()
if ov_50 > ov_10 + 0.2:
    print("\n🔍 kNN SCALE EFFECT:")
    print(f"   kNN@10 overlap: {ov_10:.3f}")
    print(f"   kNN@50 overlap: {ov_50:.3f}")
    print("   → Larger neighborhoods preserved better (consistent with near-ties at small k)")

if pearson_corr > 0.6 and ov_10 < 0.4:
    print("\n🔍 CLASSIC DIAGNOSIS:")
    print("   High EDM correlation + low kNN preservation")
    print("   → Global geometry preserved, specific neighbor relationships lost")


In [None]:
# Notebook display of the predicted distance matrix. D_edm_pred is not defined in this
# cell; it is expected to come from an earlier cell and is used by the EDM-based k-NN and
# local-correlation cells below.
D_edm_pred

In [None]:
# ===================================================================
# k-NN PRESERVATION FROM EDM (HIGH-DIMENSIONAL)
# ===================================================================
print("\n" + "="*70)
print("k-NN PRESERVATION FROM EDM (NO 2D PROJECTION)")
print("="*70)

# 1. Compute ground truth EDM (dense n x n matrix; the helpers below index it by row)
D_edm_gt = squareform(pdist(gt_coords_np, metric='euclidean'))
print(f"✓ Ground truth EDM computed: {D_edm_gt.shape}")

# 2. k-NN preservation function using precomputed distances
def knn_acc_from_dist(D_pred, D_gt, k=10):
    """
    Compute k-NN preservation from distance matrices directly.
    No 2D projection involved - uses full distance information.

    Each point is excluded from its own neighbour set by masking the diagonal with +inf.
    This avoids assuming the point sorts first in its row: with duplicate points, a zero
    off-diagonal distance ties with the zero self-distance and may be ordered ahead of it.

    Args:
        D_pred: (n, n) predicted pairwise distances (not modified).
        D_gt: (n, n) ground-truth pairwise distances (not modified).
        k: neighbourhood size.

    Returns:
        (mean_overlap, overlaps) where overlaps[i] = |NN_pred(i) & NN_gt(i)| / k.
    """
    def _knn_excluding_self(D):
        # Work on a copy so the caller's matrix is untouched
        D_masked = np.array(D, dtype=float, copy=True)
        np.fill_diagonal(D_masked, np.inf)
        return np.argsort(D_masked, axis=1)[:, :k]

    n = D_pred.shape[0]

    # Get k nearest neighbors from predicted and ground-truth distances
    nn_pred_idx = _knn_excluding_self(D_pred)
    nn_gt_idx = _knn_excluding_self(D_gt)

    # Compute overlap fraction for each cell
    overlaps = np.array([
        len(set(nn_pred_idx[i]) & set(nn_gt_idx[i])) / k
        for i in range(n)
    ])

    return overlaps.mean(), overlaps

# 3. Compute k-NN from EDM
knn_edm_k10, overlaps_edm_k10 = knn_acc_from_dist(D_edm_pred, D_edm_gt, k=10)
knn_edm_k20, overlaps_edm_k20 = knn_acc_from_dist(D_edm_pred, D_edm_gt, k=20)

print(f"\nk-NN Preservation from EDM (no projection):")
print(f"  k=10: {knn_edm_k10*10:.2f} / 10  ({knn_edm_k10*100:.1f}% neighbors preserved)")
print(f"  k=20: {knn_edm_k20*20:.2f} / 20  ({knn_edm_k20*100:.1f}% neighbors preserved)")

# knn_k10 / knn_k20 are mean preserved-neighbour counts from the 2D-coordinate analysis
print(f"\n📊 COMPARISON: EDM-based vs 2D-coordinate-based k-NN:")
print(f"  Method                  k=10        k=20")
print(f"  {'─'*50}")
print(f"  From 2D coords:         {knn_k10/10:.3f}      {knn_k20/20:.3f}")
print(f"  From EDM (no proj):     {knn_edm_k10:.3f}      {knn_edm_k20:.3f}")
print(f"  Difference:             {knn_edm_k10 - knn_k10/10:+.3f}      {knn_edm_k20 - knn_k20/20:+.3f}")

print("\n" + "="*70)
print("EDM vs 2D DIAGNOSIS")
print("="*70)

improvement = knn_edm_k10 - knn_k10/10

if improvement > 0.1:
    print("\n✓✓ EDM k-NN is MUCH better than 2D k-NN")
    print("   → 2D MDS projection is destroying local structure")
    print("   → Your model IS learning good neighborhoods")
    print("   → Problem is the visualization/embedding step, not the model")
elif improvement > 0.03:
    print("\n✓ EDM k-NN is somewhat better than 2D k-NN")
    print("   → Some information loss in 2D projection")
    print("   → Model performance better than 2D metrics suggest")
elif improvement > -0.03:
    print("\n⚠️ EDM and 2D k-NN are similar")
    print("   → 2D projection is reasonably faithful")
    print("   → Local scrambling issue is real, not projection artifact")
else:
    print("\n❌ EDM k-NN is WORSE than 2D k-NN")
    print("   → This shouldn't happen (2D can't add information)")
    print("   → Check for bugs in distance matrix computation")

# 4. Optional: Distribution of per-cell overlaps
print(f"\nPer-cell k-NN@10 distribution (from EDM):")
print(f"  Min:     {overlaps_edm_k10.min():.3f}")
print(f"  25th %:  {np.percentile(overlaps_edm_k10, 25):.3f}")
print(f"  Median:  {np.median(overlaps_edm_k10):.3f}")
print(f"  75th %:  {np.percentile(overlaps_edm_k10, 75):.3f}")
print(f"  Max:     {overlaps_edm_k10.max():.3f}")

In [None]:
# ===================================================================
# LOCAL-ONLY DISTANCE CORRELATIONS
# ===================================================================
# Section banner: blank line, rule, title, rule.
print("", "=" * 70, "LOCAL-ONLY DISTANCE CORRELATIONS", "=" * 70, sep="\n")

def local_pair_corr(D_pred, D_gt, k=20):
    """
    Compute Pearson/Spearman only on k-NN edges from ground truth.
    This isolates local distance preservation from global structure.

    Each point i contributes the k directed edges (i, j) to its GT nearest neighbours.
    Self-edges are excluded by masking the GT diagonal with +inf rather than dropping the
    first sorted column: with duplicate points, a zero off-diagonal distance can sort ahead
    of the zero self-distance.

    Args:
        D_pred: (N, N) predicted pairwise distances.
        D_gt: (N, N) ground-truth pairwise distances (not modified).
        k: neighbours per point.

    Returns:
        (pearson, spearman) over the N*k edge distances.
    """
    N = D_gt.shape[0]

    # Get k nearest neighbors from GT distances, never including the point itself
    D_gt_masked = np.array(D_gt, dtype=float, copy=True)
    np.fill_diagonal(D_gt_masked, np.inf)
    nn_gt_idx = np.argsort(D_gt_masked, axis=1)[:, :k]

    # Extract (i,j) pairs for all k-NN edges
    ii = np.repeat(np.arange(N), k)
    jj = nn_gt_idx.reshape(-1)

    # Get predicted and GT distances for these pairs only
    x_pred = D_pred[ii, jj]
    y_gt = D_gt[ii, jj]

    return pearsonr(x_pred, y_gt)[0], spearmanr(x_pred, y_gt)[0]

def quantile_corr(D_pred, D_gt, q=0.05):
    """
    Correlate predicted vs ground-truth distances over the closest ``q`` fraction of pairs.

    Only unordered pairs (i < j) are considered. A pair is kept when its ground-truth
    distance is at or below the ``q``-quantile of all ground-truth pair distances, which
    focuses the comparison on local structure.

    Returns:
        (pearson, spearman) over the kept pairs.
    """
    rows, cols = np.triu_indices(D_gt.shape[0], k=1)
    gt_pairs = D_gt[rows, cols]
    pred_pairs = D_pred[rows, cols]

    keep = gt_pairs <= np.quantile(gt_pairs, q)
    kept_pred = pred_pairs[keep]
    kept_gt = gt_pairs[keep]

    pearson = pearsonr(kept_pred, kept_gt)[0]
    spearman = spearmanr(kept_pred, kept_gt)[0]
    return pearson, spearman

# Compute local correlations (D_edm_pred is expected to come from an earlier cell)
print("\n1. k-NN Edge Correlations (local neighborhoods only):")
local_p_k10, local_s_k10 = local_pair_corr(D_edm_pred, D_edm_gt, k=10)
local_p_k20, local_s_k20 = local_pair_corr(D_edm_pred, D_edm_gt, k=20)

print(f"  k=10 edges: Pearson={local_p_k10:.4f}, Spearman={local_s_k10:.4f}")
print(f"  k=20 edges: Pearson={local_p_k20:.4f}, Spearman={local_s_k20:.4f}")

# Compute quantile correlations
print("\n2. Shortest Distance Quantile Correlations:")
q5_p, q5_s = quantile_corr(D_edm_pred, D_edm_gt, q=0.05)
q10_p, q10_s = quantile_corr(D_edm_pred, D_edm_gt, q=0.10)

print(f"  Shortest 5%:  Pearson={q5_p:.4f}, Spearman={q5_s:.4f}")
print(f"  Shortest 10%: Pearson={q10_p:.4f}, Spearman={q10_s:.4f}")

# Compare with global correlations (pearson_corr / spearman_corr from an earlier cell)
print(f"\n📊 GLOBAL vs LOCAL COMPARISON:")
print(f"  Metric                          Pearson    Spearman")
print(f"  {'─'*60}")
print(f"  Global (all pairs):             {pearson_corr:.4f}     {spearman_corr:.4f}")
print(f"  Local k=20 edges only:          {local_p_k20:.4f}     {local_s_k20:.4f}")
print(f"  Shortest 5% distances only:     {q5_p:.4f}     {q5_s:.4f}")
print(f"\n  Δ (Local k20 - Global):         {local_p_k20 - pearson_corr:+.4f}     {local_s_k20 - spearman_corr:+.4f}")
print(f"  Δ (Shortest 5% - Global):       {q5_p - pearson_corr:+.4f}     {q5_s - spearman_corr:+.4f}")

print("\n" + "="*70)
print("LOCAL vs GLOBAL DIAGNOSIS")
print("="*70)

local_boost = local_p_k20 - pearson_corr

if local_boost > 0.1:
    print("\n✓✓ LOCAL correlations are MUCH better than global")
    print("   → Model is learning good local structure")
    print("   → Global metric is dominated by long-range pairs")
    print("   → Guidance IS working on neighborhoods")
elif local_boost > 0.03:
    print("\n✓ LOCAL correlations are moderately better than global")
    print("   → Some local structure learned")
    print("   → Global metric partially masking local improvements")
elif local_boost > -0.03:
    print("\n⚠️ LOCAL and GLOBAL correlations are similar")
    print("   → Model treats all scales equally")
    print("   → No specific local structure emphasis")
else:
    print("\n❌ LOCAL correlations are WORSE than global")
    print("   → Model better at long-range than short-range")
    print("   → Local scrambling confirmed at distance level")

# Actionable insight
if local_boost > 0.05 and knn_edm_k10 < 0.4:
    print("\n🔍 INSIGHT: Good local DISTANCES but poor k-NN preservation")
    print("   → Distances are right but WHICH neighbors is wrong")
    print("   → Consider: triplet losses, contrastive neighbor losses")

In [None]:
# ===================================================================
# PROCRUSTES ALIGNMENT + CELL TYPE VISUALIZATION (WITH HEATMAPS)
# ===================================================================
# Section banner: blank line, rule, title, rule.
print("", "=" * 70, "PROCRUSTES ALIGNMENT + CELL TYPE VISUALIZATION", "=" * 70, sep="\n")

def canonicalize_unit_rms(X):
    """Center X column-wise and scale it so the RMS over all entries is 1.

    RMS is sqrt(sum(Xc**2) / Xc.size), where Xc is the column-centered X. If every row is
    identical (RMS == 0), the centered array is returned unscaled instead of dividing 0/0,
    which would fill the result with NaNs.
    """
    X_centered = X - X.mean(axis=0, keepdims=True)
    rms = np.sqrt((X_centered ** 2).sum() / X_centered.size)
    if rms == 0:
        return X_centered
    return X_centered / rms

def procrustes_align(X, Y):
    """Align X to Y using Procrustes (allows rotation + reflection).

    Finds the orthogonal R minimizing ||(X - mean(X)) @ R - (Y - mean(Y))||_F. Points are
    rows and are transformed as Xc @ R. With H = Xc.T @ Yc = U S Vt, the solution for this
    row-vector convention is R = U @ Vt. Its transpose, Vt.T @ U.T, is the solution for
    column vectors (R @ x); applied to rows it would rotate by the inverse angle.

    Returns:
        (X_aligned, R): X_aligned = Xc @ R + mean(Y), and the (d, d) orthogonal matrix R.
    """
    X_mean = X.mean(axis=0, keepdims=True)
    Y_mean = Y.mean(axis=0, keepdims=True)

    X_centered = X - X_mean
    Y_centered = Y - Y_mean

    H = X_centered.T @ Y_centered
    U, _, Vt = np.linalg.svd(H)
    R = U @ Vt

    X_aligned = X_centered @ R + Y_mean
    return X_aligned, R

# Canonicalize both to unit RMS
gt_coords_canon = canonicalize_unit_rms(gt_coords_np)
coords_pred_canon = canonicalize_unit_rms(coords_pred)

print(f"\n✓ Canonicalized coordinates to unit RMS")
print(f"  GT RMS: {np.sqrt((gt_coords_canon ** 2).sum() / gt_coords_canon.size):.6f}")
print(f"  Pred RMS: {np.sqrt((coords_pred_canon ** 2).sum() / coords_pred_canon.size):.6f}")

# Procrustes alignment
coords_pred_aligned, R = procrustes_align(coords_pred_canon, gt_coords_canon)
det_R = np.linalg.det(R)
alignment_error = np.linalg.norm(gt_coords_canon - coords_pred_aligned, 'fro')

# n_cells is expected to be defined by an earlier cell
print(f"\n✓ Procrustes alignment complete")
print(f"  det(R) = {det_R:.4f} {'(reflection)' if det_R < 0 else '(rotation)'}")
print(f"  Frobenius error: {alignment_error:.4f}")
print(f"  Per-cell RMSE: {alignment_error / np.sqrt(n_cells):.4f}")

# Compute distance matrices from aligned coordinates
D_gt_aligned = squareform(pdist(gt_coords_canon, 'euclidean'))
D_pred_aligned = squareform(pdist(coords_pred_aligned, 'euclidean'))

print(f"\n✓ Computed distance matrices from aligned coords")
print(f"  GT distance matrix: {D_gt_aligned.shape}")
print(f"  Pred distance matrix: {D_pred_aligned.shape}")

# Get cell types
print("\nAvailable columns in adata_sc.obs:")
print(list(adata_sc.obs.columns))

cell_type_col = None
for col in ['cell_type', 'celltype', 'cluster', 'annotation', 'cell_ontology_class', 'celltype_mapped_refined']:
    if col in adata_sc.obs.columns:
        cell_type_col = col
        break

if cell_type_col is None:
    # Fall back to the first categorical/object column, or a constant label
    categorical_cols = adata_sc.obs.select_dtypes(include=['category', 'object']).columns
    if len(categorical_cols) > 0:
        cell_type_col = categorical_cols[0]
    else:
        adata_sc.obs['cell_type'] = 'Unknown'
        cell_type_col = 'cell_type'

print(f"\nUsing cell type column: '{cell_type_col}'")
cell_types = adata_sc.obs[cell_type_col].values

unique_types = np.unique(cell_types)
n_types = len(unique_types)

print(f"Found {n_types} unique cell types:")
for i, ct in enumerate(unique_types):
    count = (cell_types == ct).sum()
    print(f"  {i+1}. {ct}: {count} cells")

# Colormap
if n_types <= 10:
    cmap = plt.cm.tab10
elif n_types <= 20:
    cmap = plt.cm.tab20
else:
    cmap = plt.cm.gist_ncar

type_to_color = {ct: cmap(i / n_types) for i, ct in enumerate(unique_types)}

# The savefig calls below only run when this is True; otherwise the paths are just reported.
SAVE_FIGURES = False

# PLOT 1: Side-by-side cell type visualization
fig, axes = plt.subplots(1, 2, figsize=(20, 8))

# Ground Truth
ax = axes[0]
for ct in unique_types:
    mask = cell_types == ct
    ax.scatter(gt_coords_canon[mask, 0], gt_coords_canon[mask, 1],
              c=[type_to_color[ct]], label=ct, s=15, alpha=0.7, edgecolors='none')

ax.set_xlabel('Dimension 1', fontsize=13, fontweight='bold')
ax.set_ylabel('Dimension 2', fontsize=13, fontweight='bold')
ax.set_title('Ground Truth - Cell Types', fontsize=14, fontweight='bold')
ax.set_aspect('equal', adjustable='box')
ax.grid(True, alpha=0.3)

# Predicted (Procrustes Aligned)
ax = axes[1]
for ct in unique_types:
    mask = cell_types == ct
    ax.scatter(coords_pred_aligned[mask, 0], coords_pred_aligned[mask, 1],
              c=[type_to_color[ct]], label=ct, s=15, alpha=0.7, edgecolors='none')

ax.set_xlabel('Dimension 1', fontsize=13, fontweight='bold')
ax.set_ylabel('Dimension 2', fontsize=13, fontweight='bold')
ax.set_title(f'Predicted (Aligned) - Frobenius Error: {alignment_error:.2f}', fontsize=14, fontweight='bold')
ax.set_aspect('equal', adjustable='box')
ax.grid(True, alpha=0.3)

# Legend
if n_types <= 15:
    handles, labels = axes[1].get_legend_handles_labels()
    fig.legend(handles, labels, loc='center left', bbox_to_anchor=(1.0, 0.5),
              fontsize=10, title='Cell Type', title_fontsize=12, frameon=True)

plt.tight_layout(rect=[0, 0, 0.95, 1])
celltype_path = os.path.join(output_dir, f'patchwise_celltype_aligned_{timestamp}.png')
if SAVE_FIGURES:
    plt.savefig(celltype_path, dpi=300, bbox_inches='tight')
    print(f"\n✓ Saved cell type visualization: {celltype_path}")
else:
    print(f"\nCell type visualization not saved (SAVE_FIGURES=False): {celltype_path}")
plt.show()

# PLOT 2: Distance matrix heatmaps
print("\n" + "="*70)
print("DISTANCE MATRIX HEATMAP COMPARISON")
print("="*70)

# Sample cells for visualization (sorted so the heatmap keeps the original cell order)
sample_size = min(600, n_cells)
sample_indices = np.random.choice(n_cells, sample_size, replace=False)
sample_indices = np.sort(sample_indices)

print(f"\nCreating distance matrix heatmaps with {sample_size} sampled cells...")

fig, axes = plt.subplots(1, 2, figsize=(18, 8))
fig.suptitle('Distance Matrix Comparison (from aligned coordinates)', fontsize=18, fontweight='bold')

# Ground Truth Distance Matrix
im1 = axes[0].imshow(D_gt_aligned[np.ix_(sample_indices, sample_indices)],
                     cmap='viridis', aspect='auto')
axes[0].set_title('Ground Truth Distance Matrix', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Cell Index (Sampled)', fontsize=12)
axes[0].set_ylabel('Cell Index (Sampled)', fontsize=12)
fig.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04, label='Euclidean Distance')

# Predicted Distance Matrix
im2 = axes[1].imshow(D_pred_aligned[np.ix_(sample_indices, sample_indices)],
                     cmap='viridis', aspect='auto')
axes[1].set_title('Predicted Distance Matrix', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Cell Index (Sampled)', fontsize=12)
fig.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04, label='Euclidean Distance')

plt.tight_layout(rect=[0, 0, 1, 0.96])
distmat_path = os.path.join(output_dir, f'patchwise_distmat_heatmap_{timestamp}.png')
if SAVE_FIGURES:
    plt.savefig(distmat_path, dpi=300, bbox_inches='tight')
    print(f"✓ Saved distance matrix heatmap: {distmat_path}")
else:
    print(f"Distance matrix heatmap not saved (SAVE_FIGURES=False): {distmat_path}")
plt.show()

print("\n" + "="*70)
print("COMPLETE")
print("="*70)

In [None]:
# Diagnostic: correlation on canonicalized distance matrices
# Correlates all unordered pairwise distances (upper triangle, i < j) of the canonicalized
# GT against the canonicalized, not Procrustes-aligned, prediction. Centering leaves
# pairwise distances unchanged, and Pearson/Spearman ignore the uniform rescale applied by
# canonicalize_unit_rms, so this is a sanity check: the values should match correlations
# computed on the raw coordinates.
# NOTE: despite the "_flat" names, D_gt_flat / D_pred_flat are full (n, n) square matrices;
# triu_idx selects their upper triangles.
from scipy.stats import pearsonr, spearmanr
D_gt_flat = squareform(pdist(gt_coords_canon, 'euclidean'))
D_pred_flat = squareform(pdist(coords_pred_canon, 'euclidean'))
triu_idx = np.triu_indices(len(D_gt_flat), k=1)
pearson_r = pearsonr(D_gt_flat[triu_idx], D_pred_flat[triu_idx])[0]
spearman_r = spearmanr(D_gt_flat[triu_idx], D_pred_flat[triu_idx])[0]
print(f"\n  Distance correlation (canonicalized):")
print(f"    Pearson:  {pearson_r:.4f}")
print(f"    Spearman: {spearman_r:.4f}")

In [None]:
import numpy as np
import pandas as pd

# Load the ST1 training coordinates
st_meta = pd.read_csv('/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_metadata_et.csv', index_col=0)
coords = st_meta[['coord_x', 'coord_y']].values

# Covariance spectrum of the centred coordinates, largest eigenvalue first
# (eigvalsh already returns ascending order, so reversing is enough).
coords_centered = coords - coords.mean(axis=0)
cov = np.cov(coords_centered.T)
eigvals = np.linalg.eigvalsh(cov)[::-1]

# Participation ratio (sum lambda)^2 / sum lambda^2: 1 for a line, 2 for an isotropic plane.
total_variance = eigvals.sum()
dim_eff = total_variance ** 2 / (eigvals ** 2).sum()

# Anisotropy: largest over second-largest eigenvalue (epsilon guards against zero).
aniso = eigvals[0] / (eigvals[1] + 1e-8)

report_lines = [
    "Mouse Brain ST1 Training Data:",
    f"  N cells: {coords.shape[0]}",
    f"  Effective dimensionality: {dim_eff:.4f}",
    f"  Anisotropy ratio: {aniso:.4f}",
    f"  Eigenvalues: {eigvals}",
]
print("\n".join(report_lines))

In [None]:
# ============================================================================
# EIGENVALUE ANISOTROPY ANALYSIS FOR GEMS INFERENCE
# ============================================================================

print("="*70)
print("GEMS INFERENCE: 2D GEOMETRY VERIFICATION")
print("="*70)

# ============================================================================
# 1. ANALYZE GEMS PREDICTED COORDINATES
# ============================================================================

print("\n=== Analyzing GEMS Predicted Coordinates ===\n")

gems_coords = coords_pred

# coords_pred may be a numpy array or a torch tensor. A bare .numpy() raises for
# CUDA tensors and for tensors that require grad, so detach and move to CPU first.
if torch.is_tensor(gems_coords):
    coords_gems = gems_coords.detach().cpu().numpy()
else:
    coords_gems = np.asarray(gems_coords)

# Anisotropy = ratio of the two principal variances of the point cloud.
X = coords_gems.astype(float)
Xc = X - X.mean(axis=0, keepdims=True)

cov = Xc.T @ Xc / (Xc.shape[0] - 1)
eigvals_gems, eigvecs_gems = np.linalg.eigh(cov)
# eigh returns ascending order. Flip values AND vectors together so that column k
# of eigvecs_gems still belongs to eigvals_gems[k].
eigvals_gems = eigvals_gems[::-1]
eigvecs_gems = eigvecs_gems[:, ::-1]

lam1_gems, lam2_gems = eigvals_gems  # unpacking requires exactly 2 coordinate columns
ratio_gems = lam1_gems / (lam2_gems + 1e-12)

print(f"GEMS Predicted Coordinates ({coords_gems.shape[0]} cells):")
print(f"  λ1 = {lam1_gems:.4f},  λ2 = {lam2_gems:.4f}")
print(f"  λ1/λ2 = {ratio_gems:.2f}")

# Thresholds used throughout this cell: <5 genuinely 2D, <20 anisotropic, else ~1D.
if ratio_gems < 5:
    interpretation_gems = "→ GENUINELY 2D ✓"
elif ratio_gems < 20:
    interpretation_gems = "→ Anisotropic but still 2D-ish"
else:
    interpretation_gems = "→ EFFECTIVELY 1D (very elongated) ✗"

print(f"  {interpretation_gems}\n")

# ============================================================================
# 2. ANALYZE ST MINISETS FOR COMPARISON
# ============================================================================

print("=== Analyzing ST Mini-Subsets for Comparison ===\n")

ratios_st = []
eigenvalues_st = []

for data in miniset_data:
    D = data['D_edm']
    # D_edm may be a torch tensor; the numpy algebra below needs an ndarray.
    if torch.is_tensor(D):
        D = D.detach().cpu().numpy()
    D = np.asarray(D, dtype=float)

    # Classical MDS: double-centre the squared distances to obtain a Gram matrix.
    # NOTE(review): this treats D_edm as plain (unsquared) distances — confirm.
    n = D.shape[0]
    Jn = np.eye(n) - np.ones((n, n)) / n
    B = -0.5 * (Jn @ (D ** 2) @ Jn)

    eigvals_full, eigvecs_full = np.linalg.eigh(B)
    eigvals_full = eigvals_full[::-1]
    eigvecs_full = eigvecs_full[:, ::-1]

    # Top-2 embedding; negative eigenvalues (non-Euclidean residue) are clipped to 0.
    coords_patch = eigvecs_full[:, :2] @ np.diag(np.sqrt(np.maximum(eigvals_full[:2], 0)))

    # Analyze 2D variance
    X = coords_patch.astype(float)
    Xc = X - X.mean(axis=0, keepdims=True)

    cov = Xc.T @ Xc / (Xc.shape[0] - 1)
    eigvals_2d, _ = np.linalg.eigh(cov)
    eigvals_2d = eigvals_2d[::-1]

    lam1, lam2 = eigvals_2d
    ratio = lam1 / (lam2 + 1e-12)

    ratios_st.append(ratio)
    eigenvalues_st.append((lam1, lam2))

ratios_st = np.array(ratios_st)
eigenvalues_st = np.array(eigenvalues_st)

# Every statistic and plot below needs at least one miniset; fail with a clear
# message instead of an obscure reduction error on an empty array.
if ratios_st.size == 0:
    raise ValueError("miniset_data is empty: no ST mini-subsets to compare against")

print(f"ST Mini-Subsets Statistics:")
print(f"  λ1/λ2 - Median: {np.median(ratios_st):.2f}")
print(f"  λ1/λ2 - Mean:   {ratios_st.mean():.2f}")
print(f"  λ1/λ2 - Range:  [{ratios_st.min():.2f}, {ratios_st.max():.2f}]")

# ============================================================================
# 3. COMPARISON VISUALIZATION
# ============================================================================

print("\n=== Creating Comparison Visualizations ===\n")

fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Plot 1: Histogram comparison - ST vs GEMS
ax = axes[0, 0]
ax.hist(ratios_st, bins=30, alpha=0.6, edgecolor='black', color='steelblue',
        label=f'ST Minisets (n={len(ratios_st)})')
ax.axvline(ratio_gems, color='red', linestyle='--', linewidth=3,
           label=f'GEMS: {ratio_gems:.2f}')
ax.axvline(np.median(ratios_st), color='blue', linestyle='--', linewidth=2,
           label=f'ST Median: {np.median(ratios_st):.2f}')
ax.axvline(5, color='g', linestyle=':', linewidth=2, alpha=0.5, label='2D threshold (5)')
ax.axvline(20, color='orange', linestyle=':', linewidth=2, alpha=0.5, label='1D threshold (20)')

ax.set_xlabel('λ1/λ2 (Anisotropy Ratio)', fontsize=13, fontweight='bold')
ax.set_ylabel('Count', fontsize=13, fontweight='bold')
ax.set_title('Anisotropy Comparison: ST Minisets vs GEMS', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3, axis='y')

# Plot 2: Eigenvalue scatter - ST minisets (coloured by their ratio)
ax = axes[0, 1]
scatter = ax.scatter(eigenvalues_st[:, 1], eigenvalues_st[:, 0],
                     c=ratios_st, cmap='viridis', alpha=0.7, s=80,
                     edgecolors='black', linewidth=1, label='ST Minisets')

# Add GEMS point
ax.scatter(lam2_gems, lam1_gems, c='red', s=300, marker='*',
           edgecolors='darkred', linewidth=2, label='GEMS', zorder=5)

# Isotropy reference line (λ1 = λ2)
min_val = min(eigenvalues_st[:, 1].min(), lam2_gems)
max_val = max(eigenvalues_st[:, 0].max(), lam1_gems)
ax.plot([min_val, max_val], [min_val, max_val], 'r--',
        linewidth=2, label='λ1 = λ2', alpha=0.7)

ax.set_xlabel('λ2 (Smaller Eigenvalue)', fontsize=13, fontweight='bold')
ax.set_ylabel('λ1 (Larger Eigenvalue)', fontsize=13, fontweight='bold')
ax.set_title('Eigenvalue Scatter Plot', fontsize=14, fontweight='bold')
ax.legend(fontsize=11, loc='upper left')
ax.grid(True, alpha=0.3)

cbar = plt.colorbar(scatter, ax=ax)
cbar.set_label('λ1/λ2', fontsize=11)

# Plot 3: Coordinate scatter - GEMS
ax = axes[1, 0]
ax.scatter(coords_gems[:, 0], coords_gems[:, 1], alpha=0.5, s=10,
           c='red', edgecolors='none')
ax.set_xlabel('Dimension 1', fontsize=13, fontweight='bold')
ax.set_ylabel('Dimension 2', fontsize=13, fontweight='bold')
ax.set_title(f'GEMS Predicted Coordinates\nλ1/λ2 = {ratio_gems:.2f}',
             fontsize=14, fontweight='bold', color='red')
ax.grid(True, alpha=0.3)
ax.set_aspect('equal', adjustable='box')

# Plot 4: Box plot comparison
ax = axes[1, 1]

# Combine ST ratios with GEMS ratio for box plot
data_for_plot = [ratios_st, [ratio_gems]]
labels = ['ST Minisets\n(Ground Truth)', 'GEMS\n(Predicted)']

bp = ax.boxplot(data_for_plot, patch_artist=True,
                showmeans=True, meanline=True, widths=0.6)
# Tick labels are set separately: boxplot's `labels=` keyword was renamed to
# `tick_labels` in Matplotlib 3.9 (deprecated there), and this works on all versions.
ax.set_xticklabels(labels)

bp['boxes'][0].set_facecolor('steelblue')
bp['boxes'][0].set_alpha(0.7)
bp['boxes'][1].set_facecolor('red')
bp['boxes'][1].set_alpha(0.7)

ax.axhline(5, color='g', linestyle='--', linewidth=2, alpha=0.5, label='2D threshold (5)')
ax.axhline(20, color='orange', linestyle='--', linewidth=2, alpha=0.5, label='1D threshold (20)')

ax.set_ylabel('λ1/λ2 (Anisotropy Ratio)', fontsize=13, fontweight='bold')
ax.set_title('Anisotropy Distribution Comparison', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
gems_anisotropy_path = os.path.join(output_dir, f'gems_anisotropy_analysis_{timestamp}.png')
plt.savefig(gems_anisotropy_path, dpi=300, bbox_inches='tight')
print(f"✓ Saved GEMS anisotropy analysis: {gems_anisotropy_path}")
plt.show()

# ============================================================================
# 4. DETAILED COMPARISON TABLE
# ============================================================================

print("\n" + "="*70)
print("DETAILED COMPARISON: ST MINISETS vs GEMS INFERENCE")
print("="*70)

st_median = np.median(ratios_st)

print(f"\n{'Metric':<30} {'ST Minisets':<20} {'GEMS Inference':<20}")
print("-" * 70)
print(f"{'Number of samples':<30} {len(ratios_st):<20} {coords_gems.shape[0]:<20}")
print(f"{'λ1 (larger eigenvalue)':<30} {eigenvalues_st[:, 0].mean():.4f} ± {eigenvalues_st[:, 0].std():.4f}   {lam1_gems:.4f}")
print(f"{'λ2 (smaller eigenvalue)':<30} {eigenvalues_st[:, 1].mean():.4f} ± {eigenvalues_st[:, 1].std():.4f}   {lam2_gems:.4f}")
print(f"{'λ1/λ2 ratio (median)':<30} {st_median:.2f}              {ratio_gems:.2f}")
print(f"{'λ1/λ2 ratio (mean)':<30} {ratios_st.mean():.2f}              {ratio_gems:.2f}")
print(f"{'λ1/λ2 ratio (range)':<30} [{ratios_st.min():.2f}, {ratios_st.max():.2f}]     {ratio_gems:.2f}")

print("\n" + "="*70)
print("INTERPRETATION")
print("="*70)

st_genuinely_2d_pct = (ratios_st < 5).sum() / len(ratios_st) * 100

print(f"\nST Mini-Subsets (Ground Truth):")
print(f"  {st_genuinely_2d_pct:.1f}% are genuinely 2D (λ1/λ2 < 5)")
print(f"  Median anisotropy: {st_median:.2f}")
# Use the same thresholds as the GEMS verdict below, so the ST side always gets a
# verdict too (previously only the 2D case printed one).
if st_median < 5:
    print("  → GENUINELY 2D ✓")
elif st_median < 20:
    print("  → Anisotropic but still 2D-ish")
else:
    print("  → EFFECTIVELY 1D (very elongated) ✗")

print(f"\nGEMS Predicted Coordinates:")
if ratio_gems < 5:
    print(f"  ✓ GENUINELY 2D (λ1/λ2 = {ratio_gems:.2f})")
    print("  → GEMS successfully preserves 2D spatial structure")
elif ratio_gems < 20:
    print(f"  ⚠ Anisotropic but still 2D-ish (λ1/λ2 = {ratio_gems:.2f})")
    print("  → GEMS produces elongated but 2D structures")
else:
    print(f"  ✗ EFFECTIVELY 1D (λ1/λ2 = {ratio_gems:.2f})")
    print("  → WARNING: GEMS collapsed to 1D structure")

# Comparison
if ratio_gems < 5 and st_median < 5:
    print(f"\n✓ EXCELLENT: Both ST and GEMS are genuinely 2D")
elif abs(ratio_gems - st_median) < 3:
    print(f"\n✓ GOOD: GEMS anisotropy ({ratio_gems:.2f}) is similar to ST ({st_median:.2f})")
else:
    print(f"\n⚠ WARNING: Large anisotropy difference between GEMS ({ratio_gems:.2f}) and ST ({st_median:.2f})")

print("\n" + "="*70)
print("GEMS GEOMETRY ANALYSIS COMPLETE")
print("="*70)

In [None]:
# ===================================================================
# PATCHWISE INFERENCE - TESTING INIT-ONLY (NO PROCRUSTES ALIGNMENT)
# ===================================================================

import torch
import numpy as np
from datetime import datetime
import sys
import os
import scanpy as sc
from scipy.spatial.distance import pdist, squareform
from scipy.stats import pearsonr, spearmanr
import matplotlib.pyplot as plt
import seaborn as sns

sys.path.insert(0, '/home/ehtesamul/sc_st/model')
from core_models_et_p3 import GEMSModel
import utils_et as uet

# ===================================================================
# PATHS AND CONFIG
# ===================================================================
output_dir = "/home/ehtesamul/sc_st/model/gems_mousebrain_output"
checkpoint_path = f"{output_dir}/phase2_sc_finetuned_checkpoint.pt"
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print("="*70)
print("LOADING DATA AND MODEL")
print("="*70)

# Load data
adata_sc = sc.read_h5ad(f"{output_dir}/scadata_with_gems_20251129_205637.h5ad")

# Use adata_sc.raw when it is set, otherwise the main matrix; densify sparse storage.
# NOTE(review): .raw can carry a different gene set than .X — confirm it matches
# the genes the checkpoint was trained on.
expr_matrix = adata_sc.raw.X if getattr(adata_sc, 'raw', None) is not None else adata_sc.X
if hasattr(expr_matrix, 'toarray'):
    expr_matrix = expr_matrix.toarray()
sc_expr = torch.tensor(expr_matrix, dtype=torch.float32)

gt_coords = adata_sc.obsm['spatial_gt']
n_cells, n_genes = sc_expr.shape
print(f"✓ Loaded SC data: {n_cells} cells × {n_genes} genes")
print(f"✓ Ground truth coords: {gt_coords.shape}")

# weights_only=False does full unpickling, which can execute arbitrary code:
# only point checkpoint_path at files you trust.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
print(f"✓ Loaded checkpoint")

# Initialize model
model = GEMSModel(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    D_latent=32,
    c_dim=256,
    n_heads=4,
    isab_m=64,
    dist_bins=24,
    device=device
)

# Load each sub-module's weights from its checkpoint key, then switch it to eval mode.
for component_name in ('encoder', 'context_encoder', 'generator', 'score_net'):
    component = getattr(model, component_name)
    component.load_state_dict(checkpoint[component_name])
    component.eval()
print(f"✓ Model loaded and set to eval mode")

# ===================================================================
# RUN INFERENCE WITH DIFFERENT CONFIGS
# ===================================================================

configs = [
    # patch_size = n_cells puts every cell into a single patch, which is what the
    # baseline's name says; the previous literal 838 only matched one dataset.
    {"patch_size": n_cells, "n_align_iters": 1, "name": "Single Patch (baseline)"},
    {"patch_size": 256, "n_align_iters": 0, "name": "patch_size=256, INIT ONLY (no Procrustes)"},
    {"patch_size": 256, "n_align_iters": 10, "name": "patch_size=256, WITH Procrustes"},
    {"patch_size": 384, "n_align_iters": 0, "name": "patch_size=384, INIT ONLY (no Procrustes)"},
    {"patch_size": 384, "n_align_iters": 10, "name": "patch_size=384, WITH Procrustes"},
]

# Ground-truth pairwise distances are the same for every config: compute them once.
gt_edm = squareform(pdist(gt_coords, 'euclidean'))
triu_indices = np.triu_indices(n_cells, k=1)  # each unordered pair once
gt_distances = gt_edm[triu_indices]
gt_median = np.median(gt_distances)

results_comparison = []

for config in configs:
    print("\n" + "="*70)
    print(f"RUNNING: {config['name']}")
    print("="*70)

    with torch.no_grad():
        results = model.infer_sc_patchwise(
            sc_gene_expr=sc_expr,
            n_timesteps_sample=500,
            return_coords=True,
            patch_size=config['patch_size'],
            coverage_per_cell=6.0,
            n_align_iters=config['n_align_iters'],
            eta=0.0,
            guidance_scale=2.0,
            sigma_min=0.01,
            sigma_max=3.0,
        )

    D_edm_pred = results['D_edm'].cpu().numpy()
    coords_pred = results['coords_canon'].cpu().numpy()

    print(f"\n✓ Inference complete!")
    print(f"  Predicted EDM: {D_edm_pred.shape}")
    print(f"  Predicted coords: {coords_pred.shape}")

    pred_distances = D_edm_pred[triu_indices]

    # Median-ratio scale between GT and predicted distances. It is only reported:
    # Pearson and Spearman are unchanged by a positive rescaling, so correlations are
    # computed on the raw predicted distances (no NaN/inf if the predicted median is 0).
    pred_median = np.median(pred_distances)
    scale = gt_median / pred_median if pred_median > 0 else float('nan')
    pred_distances_scaled = pred_distances * scale

    pearson_corr, _ = pearsonr(gt_distances, pred_distances)
    spearman_corr, _ = spearmanr(gt_distances, pred_distances)

    print(f"\nPearson Correlation:  {pearson_corr:.4f}")
    print(f"Spearman Correlation: {spearman_corr:.4f}")
    print(f"Scale factor: {scale:.4f}")

    results_comparison.append({
        'config': config['name'],
        'patch_size': config['patch_size'],
        'n_align_iters': config['n_align_iters'],
        'pearson': pearson_corr,
        'spearman': spearman_corr,
        'scale': scale,
        'coords': coords_pred,
        'edm': D_edm_pred
    })

# ===================================================================
# SUMMARY TABLE
# ===================================================================
print("\n" + "="*70)
print("RESULTS SUMMARY")
print("="*70)
header = f"{'Config':<50} {'Pearson':>10} {'Spearman':>10} {'Scale':>12}"
print(header)
print("-" * len(header))
for r in results_comparison:
    print(f"{r['config']:<50} {r['pearson']:>10.4f} {r['spearman']:>10.4f} {r['scale']:>12.2f}")

# ===================================================================
# VISUALIZATIONS
# ===================================================================
print("\n" + "="*70)
print("GENERATING COMPARISON PLOTS")
print("="*70)

# Plot coordinates for each config. squeeze=False keeps `axes` 2-D even when
# there is only one config, so axes[row, col] indexing always works.
n_configs = len(results_comparison)
fig, axes = plt.subplots(2, n_configs, figsize=(5*n_configs, 10), squeeze=False)

for i, r in enumerate(results_comparison):
    # Top row: the same ground-truth layout, repeated above each config as a reference
    axes[0, i].scatter(gt_coords[:, 0], gt_coords[:, 1], s=3, alpha=0.6, c='blue')
    axes[0, i].set_title(f'Ground Truth\n(ref for all)', fontsize=10, weight='bold')
    axes[0, i].set_aspect('equal')

    # Bottom row: predicted coordinates for this config
    axes[1, i].scatter(r['coords'][:, 0], r['coords'][:, 1], s=3, alpha=0.6, c='red')
    axes[1, i].set_title(f"{r['config']}\nρ={r['spearman']:.3f}", fontsize=10, weight='bold')
    axes[1, i].set_aspect('equal')

plt.tight_layout()
plt.savefig('patchwise_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

# Bar plot of correlations
fig, ax = plt.subplots(figsize=(12, 6))
x = np.arange(len(results_comparison))
width = 0.35

pearson_vals = [r['pearson'] for r in results_comparison]
spearman_vals = [r['spearman'] for r in results_comparison]

ax.bar(x - width/2, pearson_vals, width, label='Pearson', alpha=0.8)
ax.bar(x + width/2, spearman_vals, width, label='Spearman', alpha=0.8)

ax.set_ylabel('Correlation', fontsize=12)
ax.set_title('EDM Correlation: Init-Only vs Full Procrustes Alignment', fontsize=14, weight='bold')
ax.set_xticks(x)
ax.set_xticklabels([r['config'] for r in results_comparison], rotation=15, ha='right')
ax.legend()
ax.grid(axis='y', alpha=0.3)
# Correlations can be negative; extend the lower limit below 0 so such bars stay visible.
lowest_corr = np.nanmin(pearson_vals + spearman_vals)
ax.set_ylim(min(0.0, lowest_corr), 1)

plt.tight_layout()
plt.savefig('correlation_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n" + "="*70)
print("COMPLETE!")
print("="*70)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import svd
from scipy.spatial.distance import cdist

# Ground truth: 8-sided polygon (octagon)
n_cells = 8
angles = np.linspace(0, 2*np.pi, n_cells, endpoint=False)
X_true = np.column_stack([np.cos(angles), np.sin(angles)])

# Define 3 overlapping patches
patches = [
    np.array([0, 1, 2, 3, 4]),
    np.array([3, 4, 5, 6, 7]),
    np.array([6, 7, 0, 1, 2])
]
# Centre cell of each patch = the middle element of its index array (the same rule
# compute_weights falls back to). The third patch [6, 7, 0, 1, 2] is centred on
# cell 0; it used to be given 7, which skewed that patch's centrality weights.
patch_centers = [int(p[len(p) // 2]) for p in patches]  # -> [2, 5, 0]

# REALISTIC: patches keep their scale and only get a random rigid motion,
# intended to mimic the patch outputs of the diffusion model.
def realistic_transform(X, rotation_deg=None, translation=None):
    """Rotate then translate a 2D point set, without any scaling.

    Args:
        X: array of 2D points, one per row.
        rotation_deg: rotation angle in degrees; drawn from U[0, 360) when None.
        translation: length-2 offset; drawn as 0.5 * N(0, I) when None.

    Returns:
        ``X @ R.T + translation`` for the counter-clockwise rotation matrix R.
    """
    # Random draws happen in the order rotation, then translation.
    theta = np.radians(np.random.uniform(0, 360) if rotation_deg is None else rotation_deg)
    offset = np.random.randn(2) * 0.5 if translation is None else translation
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]])
    return X @ rotation.T + offset

# Patch-local frames: centre each patch, apply a random rigid motion, re-centre
# (so only the rotation survives; the translation is removed again).
local_coords = []
for idx in patches:
    pts = X_true[idx].copy()
    pts -= pts.mean(axis=0)
    moved = realistic_transform(pts)
    local_coords.append(moved - moved.mean(axis=0))

def compute_weights(patch_idx, center_cell_global_idx):
    """Gaussian centrality weights indexed by position within ``patch_idx``.

    The position holding ``center_cell_global_idx`` gets weight 1. If that cell is
    not in the patch, the middle position (len // 2) is used instead. Weights decay
    as exp(-k**2 / (2 * sigma**2)), where k is the index offset and sigma = len / 4.
    """
    size = len(patch_idx)
    matches = np.flatnonzero(patch_idx == center_cell_global_idx)
    anchor = matches[0] if matches.size else size // 2
    offset = np.arange(size) - anchor
    sigma = size / 4
    return np.exp(-offset**2 / (2 * sigma**2))

def procrustes_error(X, X_true):
    """RMSE between X_true and X after optimal similarity alignment.

    Both sets are centred. X is then rotated (proper rotation, det = +1) and
    uniformly scaled to best match X_true in the least-squares sense.

    Returns:
        sqrt(sum of squared residuals / number of points).
    """
    X_c = X - X.mean(axis=0)
    X_true_c = X_true - X_true.mean(axis=0)
    C = X_true_c.T @ X_c
    U, S, Vt = svd(C)
    # Reflection guard: if the unconstrained optimum is a reflection, flip the
    # weakest singular direction to get a proper rotation. The optimal scale must
    # then use the matching *signed* singular values (the last one subtracted);
    # summing the raw singular values overestimates the scale.
    d = np.ones_like(S)
    if np.linalg.det(U @ Vt) < 0:
        d[-1] = -1.0
    R = (U * d) @ Vt
    denom = (X_c**2).sum()
    # Degenerate X (all points identical): the best fit collapses onto the centroid.
    s = (S * d).sum() / denom if denom > 0 else 0.0
    X_aligned = s * (X_c @ R.T)
    rmse = np.sqrt(((X_aligned - X_true_c)**2).sum() / len(X))
    return rmse

def distance_matrix_error(X, X_true):
    """Root-mean-square difference between the two full pairwise-distance matrices.

    The mean runs over all n*n entries, including the zero diagonal. The result is
    invariant to rigid motions of either input but sensitive to scale.
    """
    residual = cdist(X, X) - cdist(X_true, X_true)
    return np.sqrt(np.mean(np.square(residual)))

# Compare two scale strategies (run one after the other below)
def run_alignment(use_global_scale):
    """Iteratively stitch the patch-local frames into one global layout.

    Each of the 10 iterations:
      A. fits, per patch, a centrality-weighted similarity transform (rotation R,
         scale s, translation t) mapping its local coordinates onto the current
         global estimate;
      B. rebuilds every cell as the centrality-weighted average of its transformed
         local positions, then re-centres the layout.

    Args:
        use_global_scale: True -> one scale shared by all patches (pooled
            numerators / denominators); False -> an independent scale per patch.

    Returns:
        (history, rmse_hist, dm_hist, s_list): the layout after initialisation and
        after each iteration, the Procrustes RMSE and distance-matrix error for each
        history entry, and the per-patch scales from the final iteration.

    Reads module-level X_true, patches, local_coords, patch_centers and n_cells.
    """
    # Initialise from patch 0's frame, then fill cells not in patch 0 from later
    # patches' frames (so the starting layout mixes inconsistent frames).
    X_global = np.zeros_like(X_true)
    X_global[patches[0]] = local_coords[0]
    for i, patch_idx in enumerate(patches[1:], 1):
        mask = ~np.isin(patch_idx, patches[0])
        X_global[patch_idx[mask]] = local_coords[i][mask]
    X_global -= X_global.mean(axis=0)

    history = [X_global.copy()]
    rmse_hist = [procrustes_error(X_global, X_true)]
    dm_hist = [distance_matrix_error(X_global, X_true)]

    for _ in range(10):
        # Step A: per-patch weighted Procrustes fit (local -> global)
        R_list, numerators, denominators, mu_X_list, mu_V_list = [], [], [], [], []

        for patch_idx, V_local, center_cell in zip(patches, local_coords, patch_centers):
            X_patch = X_global[patch_idx]
            weights = compute_weights(patch_idx, center_cell)

            w_sum = weights.sum()
            mu_X = (weights[:, None] * X_patch).sum(axis=0) / w_sum
            mu_V = (weights[:, None] * V_local).sum(axis=0) / w_sum

            X_centered = X_patch - mu_X
            V_centered = V_local - mu_V

            # Weighted cross-covariance between global and local coordinates
            C = (X_centered.T * weights) @ V_centered

            # Proper rotation only. When the unconstrained optimum is a reflection,
            # flip the weakest direction and use the same signed singular values
            # in the scale numerator; the raw sum would overestimate the scale.
            U, S, Vt = svd(C)
            d = np.ones_like(S)
            if np.linalg.det(U @ Vt) < 0:
                d[-1] = -1.0
            R = (U * d) @ Vt

            numerators.append((S * d).sum())
            denominators.append((weights[:, None] * V_centered**2).sum())

            R_list.append(R)
            mu_X_list.append(mu_X)
            mu_V_list.append(mu_V)

        # Compute scale(s)
        if use_global_scale:
            # GLOBAL SCALE: pooled over all patches
            s_global = sum(numerators) / sum(denominators)
            s_list = [s_global] * len(patches)
        else:
            # PER-PATCH SCALE: each patch gets its own (epsilon guards a zero denominator)
            s_list = [num / (denom + 1e-8) for num, denom in zip(numerators, denominators)]

        # Translations mapping the weighted local centroid onto the global one
        t_list = [mu_X - s * (mu_V @ R.T)
                  for R, mu_X, mu_V, s in zip(R_list, mu_X_list, mu_V_list, s_list)]

        # Step B: centrality-weighted average of the transformed patches
        X_new = np.zeros_like(X_global)
        W_total = np.zeros(n_cells)

        for s, R, t, patch_idx, V_local, center_cell in zip(
            s_list, R_list, t_list, patches, local_coords, patch_centers
        ):
            X_transformed = s * (V_local @ R.T) + t
            weights = compute_weights(patch_idx, center_cell)

            for i, cell_i in enumerate(patch_idx):
                X_new[cell_i] += weights[i] * X_transformed[i]
                W_total[cell_i] += weights[i]

        X_new /= W_total[:, None]
        X_new -= X_new.mean(axis=0)
        X_global = X_new

        history.append(X_global.copy())
        rmse_hist.append(procrustes_error(X_global, X_true))
        dm_hist.append(distance_matrix_error(X_global, X_true))

    return history, rmse_hist, dm_hist, s_list

# Run both scale strategies on the same local frames (global first, then per-patch).
runs_by_mode = {mode: run_alignment(use_global_scale=mode) for mode in (True, False)}
history_global, rmse_global, dm_global, s_global = runs_by_mode[True]
history_perPatch, rmse_perPatch, dm_perPatch, s_perPatch = runs_by_mode[False]

# Plot comparison: 3 rows x 4 columns; one fill colour per patch, in `patches` order.
fig, axes = plt.subplots(3, 4, figsize=(24, 18))
colors = ['red', 'blue', 'green']

def plot_octagon(ax, X, title):
    """Draw the ring layout X on ax with numbered vertices and shaded patches.

    Consecutive cells (wrapping n_cells-1 back to 0) are joined by edges, and each
    patch in the module-level ``patches`` is filled with its entry from ``colors``.
    Both axes are fixed to [-2, 2] and ticks are hidden.
    """
    ax.scatter(X[:, 0], X[:, 1], c='black', s=150, zorder=3, edgecolors='white', linewidths=2)

    # Ring edges
    for start in range(n_cells):
        end = (start + 1) % n_cells
        ax.plot(X[[start, end], 0], X[[start, end], 1], 'k-', alpha=0.5, linewidth=2)

    # Vertex labels
    for cell in range(n_cells):
        x_pos, y_pos = X[cell, 0], X[cell, 1]
        ax.text(x_pos, y_pos, str(cell), fontsize=14,
                ha='center', va='center', color='white', weight='bold')

    # Translucent patch outlines
    for members, shade in zip(patches, colors):
        pts = X[members]
        ax.fill(pts[:, 0], pts[:, 1], alpha=0.15, color=shade, edgecolor=shade, linewidth=2)

    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    ax.set_aspect('equal')
    ax.set_title(title, fontsize=16, weight='bold', pad=10)
    ax.grid(True, alpha=0.3)
    ax.set_xticks([])
    ax.set_yticks([])

def _plot_convergence(ax, rmse_hist, dm_hist, title):
    """Plot RMSE (left axis, blue) and DM error (right twin axis, red) vs iteration; return the twin axis."""
    ax.plot(rmse_hist, 'o-', color='blue', linewidth=2, label='RMSE')
    twin = ax.twinx()
    twin.plot(dm_hist, 's-', color='red', linewidth=2, label='DM Error')
    ax.set_xlabel('Iteration', fontsize=12, weight='bold')
    ax.set_ylabel('RMSE', fontsize=12, color='blue')
    twin.set_ylabel('DM Error', fontsize=12, color='red')
    ax.tick_params(axis='y', labelcolor='blue')
    twin.tick_params(axis='y', labelcolor='red')
    ax.grid(True, alpha=0.3)
    ax.set_title(title, fontsize=14, weight='bold')
    return twin


# Row 0: ground truth in the first two panels, remaining panels blank
for col in (0, 1):
    plot_octagon(axes[0, col], X_true, 'Ground Truth')
for col in (2, 3):
    axes[0, col].axis('off')

# Rows 1-2: snapshots at iterations 0 / 5 / 10, then the convergence curves
strategy_rows = [
    ('Global Scale', history_global, rmse_global, dm_global),
    ('Per-Patch Scale', history_perPatch, rmse_perPatch, dm_perPatch),
]
for row, (strategy, hist, rmse_curve, dm_curve) in enumerate(strategy_rows, start=1):
    for col, it in enumerate((0, 5, 10)):
        plot_octagon(axes[row, col], hist[it], f'{strategy}: Iter {it}')
    ax = axes[row, 3]
    ax2 = _plot_convergence(ax, rmse_curve, dm_curve, f'{strategy} Convergence')

plt.tight_layout()
plt.show()

summary = [
    "\n" + "="*60,
    "FINAL RESULTS:",
    "="*60,
    f"\nGLOBAL SCALE METHOD:",
    f"  Final RMSE: {rmse_global[-1]:.6f}",
    f"  Final DM Error: {dm_global[-1]:.6f}",
    f"  All patches use scale: {s_global[0]:.4f}",
    f"\nPER-PATCH SCALE METHOD:",
    f"  Final RMSE: {rmse_perPatch[-1]:.6f}",
    f"  Final DM Error: {dm_perPatch[-1]:.6f}",
    f"  Patch scales: {[f'{s:.4f}' for s in s_perPatch]}",
    "="*60,
]
for line in summary:
    print(line)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import svd

# Ground truth: 8-sided polygon (octagon)
n_cells = 8
angles = np.linspace(0, 2*np.pi, n_cells, endpoint=False)
X_true = np.column_stack([np.cos(angles), np.sin(angles)])

# Define 3 overlapping patches WITH CENTER CELLS
patches = [
    np.array([0, 1, 2, 3, 4]),
    np.array([3, 4, 5, 6, 7]),
    np.array([6, 7, 0, 1, 2])
]

# Store which cell is the "center" of each patch: the middle element of its array.
# Derived rather than hand-written, because the hand-written list gave 7 for the
# third patch [6, 7, 0, 1, 2], whose middle cell is 0.
patch_centers = [int(p[len(p) // 2]) for p in patches]  # -> [2, 5, 0]

# Generate patch-local coordinates with random transforms
def random_transform(X, rotation_deg=None, scale=None, translation=None):
    """Apply a random similarity transform: rotate, scale, then translate.

    Any argument left as None is drawn at random, in this order: rotation from
    U[0, 360) degrees, scale from U[0.8, 1.2), translation from 0.5 * N(0, I).

    Returns:
        ``scale * (X @ R.T) + translation`` for the counter-clockwise rotation R.
    """
    angle = np.deg2rad(np.random.uniform(0, 360) if rotation_deg is None else rotation_deg)
    factor = np.random.uniform(0.8, 1.2) if scale is None else scale
    offset = np.random.randn(2) * 0.5 if translation is None else translation
    c, s = np.cos(angle), np.sin(angle)
    rot = np.array([[c, -s],
                    [s, c]])
    return factor * (X @ rot.T) + offset

# Patch-local frames: centre, apply a random similarity transform, re-centre.
local_coords = []
for idx in patches:
    pts = X_true[idx] - X_true[idx].mean(axis=0)
    moved = random_transform(pts)
    local_coords.append(moved - moved.mean(axis=0))

# Initial global layout: patch 0's frame, then cells not in patch 0 are taken
# from the later patches' own frames; finally re-centre.
X_global = np.zeros_like(X_true)
X_global[patches[0]] = local_coords[0]
for k in range(1, len(patches)):
    fresh = ~np.isin(patches[k], patches[0])
    X_global[patches[k][fresh]] = local_coords[k][fresh]
X_global -= X_global.mean(axis=0)

# Centrality weights based on INDEX distance from center cell
def compute_weights(patch_idx, center_cell_global_idx):
    """Gaussian weights over array positions, peaking (1.0) at the centre cell.

    If ``center_cell_global_idx`` is absent from ``patch_idx``, the middle
    position (len // 2) acts as the centre. Width: sigma = len(patch_idx) / 4.
    """
    n_points = len(patch_idx)
    found = np.nonzero(patch_idx == center_cell_global_idx)[0]
    center_pos = found[0] if len(found) else n_points // 2
    spread = n_points / 4
    gaps = np.abs(np.arange(n_points) - center_pos)
    return np.exp(-(gaps ** 2) / (2 * spread ** 2))

# Layout snapshots (index 0 = initial layout) and the iteration budget.
history = [X_global.copy()]
n_iters = 10

# Compute metrics
def procrustes_error(X, X_true):
    """RMSE between X_true and X after optimal similarity alignment.

    Both sets are centred. X is then rotated (proper rotation, det = +1) and
    uniformly scaled to best match X_true in the least-squares sense.

    Returns:
        sqrt(sum of squared residuals / number of points).
    """
    X_c = X - X.mean(axis=0)
    X_true_c = X_true - X_true.mean(axis=0)

    C = X_true_c.T @ X_c
    U, S, Vt = svd(C)
    # Reflection guard: flip the weakest singular direction when the unconstrained
    # optimum is a reflection, and use the matching signed singular values for the
    # scale; summing the raw singular values overestimates it.
    d = np.ones_like(S)
    if np.linalg.det(U @ Vt) < 0:
        d[-1] = -1.0
    R = (U * d) @ Vt

    denom = (X_c**2).sum()
    # Degenerate X (all points identical): the best fit collapses onto the centroid.
    s = (S * d).sum() / denom if denom > 0 else 0.0
    X_aligned = s * (X_c @ R.T)

    rmse = np.sqrt(((X_aligned - X_true_c)**2).sum() / len(X))
    return rmse

def distance_matrix_error(X, X_true):
    """RMS difference between the full pairwise-distance matrices of X and X_true.

    Averages over all n*n entries (diagonal included), so the result is invariant
    to rigid motions but sensitive to scale.
    """
    from scipy.spatial.distance import cdist

    gap = cdist(X, X) - cdist(X_true, X_true)
    return np.sqrt(np.mean(gap ** 2))

rmse_history = [procrustes_error(X_global, X_true)]
dm_error_history = [distance_matrix_error(X_global, X_true)]

for iter_idx in range(n_iters):
    # Step A: Align each patch via Procrustes with centrality weighting
    transforms = []
    for patch_idx, V_local, center_cell in zip(patches, local_coords, patch_centers):
        X_patch = X_global[patch_idx]
        weights = compute_weights(patch_idx, center_cell)

        # Weighted centroids
        w_sum = weights.sum()
        mu_X = (weights[:, None] * X_patch).sum(axis=0) / w_sum
        mu_V = (weights[:, None] * V_local).sum(axis=0) / w_sum

        # Center
        X_centered = X_patch - mu_X
        V_centered = V_local - mu_V

        # Weighted cross-covariance between global and local coordinates
        C = (X_centered.T * weights) @ V_centered

        # SVD for a proper rotation (det = +1). If the unconstrained optimum is a
        # reflection, flip the weakest singular direction; the scale numerator must
        # use the same signed singular values or the scale is overestimated.
        U, S, Vt = svd(C)
        d = np.ones_like(S)
        if np.linalg.det(U @ Vt) < 0:
            d[-1] = -1.0
        R = (U * d) @ Vt

        # Scale (weighted least squares); fall back to 1 for a degenerate patch
        numerator = (S * d).sum()
        denominator = (weights[:, None] * V_centered**2).sum()
        s = numerator / denominator if denominator > 0 else 1.0

        # Translation mapping the weighted local centroid onto the global one
        t = mu_X - s * (mu_V @ R.T)

        transforms.append((s, R, t))

    # Step B: Update global coordinates via centrality-weighted averaging
    X_new = np.zeros_like(X_global)
    W_total = np.zeros(n_cells)

    for (s, R, t), patch_idx, V_local, center_cell in zip(transforms, patches, local_coords, patch_centers):
        X_transformed = s * (V_local @ R.T) + t
        weights = compute_weights(patch_idx, center_cell)

        for i, cell_i in enumerate(patch_idx):
            X_new[cell_i] += weights[i] * X_transformed[i]
            W_total[cell_i] += weights[i]

    X_new /= W_total[:, None]
    X_new -= X_new.mean(axis=0)
    X_global = X_new
    history.append(X_global.copy())

    # Compute errors
    rmse_history.append(procrustes_error(X_global, X_true))
    dm_error_history.append(distance_matrix_error(X_global, X_true))

# Setup figure: one fill colour per patch (same order as `patches`), then a
# 6 x 4 grid of panels.
colors = ['red', 'blue', 'green']
fig, axes = plt.subplots(6, 4, figsize=(24, 30))

def plot_octagon(ax, X, title, show_patches=True):
    """Draw the toy ring of ``n_cells`` points on *ax*.

    Each point is a black marker labelled with its index.  Consecutive points
    are joined into a closed ring (point j to point j+1, the last wrapping back
    to point 0).  When *show_patches* is true, every patch in the global
    ``patches`` list is shaded with the matching entry of the global
    ``colors`` list.  Both axes are fixed to [-2, 2] with equal aspect and
    no tick marks.
    """
    ax.scatter(X[:, 0], X[:, 1], c='black', s=200, zorder=3,
               edgecolors='white', linewidths=2)

    # Closed ring: pair each index with its successor, wrapping the last to 0.
    ring = list(range(n_cells))
    for a, b in zip(ring, ring[1:] + ring[:1]):
        ax.plot([X[a, 0], X[b, 0]], [X[a, 1], X[b, 1]],
                'k-', alpha=0.5, linewidth=2)

    for idx in ring:
        ax.text(X[idx, 0], X[idx, 1], str(idx), fontsize=16,
                ha='center', va='center', color='white', weight='bold')

    if show_patches:
        for member_idx, shade in zip(patches, colors):
            pts = X[member_idx]
            ax.fill(pts[:, 0], pts[:, 1], alpha=0.15, color=shade,
                    edgecolor=shade, linewidth=2)

    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    ax.set_aspect('equal')
    ax.set_title(title, fontsize=18, weight='bold', pad=10)
    ax.grid(True, alpha=0.3, linewidth=1)
    ax.set_xticks([])
    ax.set_yticks([])

def plot_error_metrics(ax, iter_idx, rmse_hist, dm_hist):
    """Plot the RMSE and distance-matrix-error histories up to *iter_idx*.

    *ax* is cleared first.  RMSE is drawn on its left y-axis in blue and the
    distance-matrix error on a twin right y-axis in red.  A single legend
    combining both series is attached to *ax*.
    """
    ax.clear()
    ax_dm = ax.twinx()

    upto = iter_idx + 1
    rmse_shown = rmse_hist[:upto]
    dm_shown = dm_hist[:upto]
    steps = np.arange(len(rmse_shown))

    # Collect both Line2D handles so one legend can describe the two axes.
    handles = ax.plot(steps, rmse_shown, 'o-', color='blue', linewidth=3,
                      markersize=8, label='RMSE (Procrustes)')
    handles += ax_dm.plot(steps, dm_shown, 's-', color='red', linewidth=3,
                          markersize=8, label='Distance Matrix Error')

    ax.set_xlabel('Iteration', fontsize=14, weight='bold')
    ax.set_ylabel('RMSE', fontsize=14, weight='bold', color='blue')
    ax_dm.set_ylabel('Distance Matrix Error', fontsize=14, weight='bold', color='red')

    ax.tick_params(axis='y', labelcolor='blue', labelsize=12)
    ax_dm.tick_params(axis='y', labelcolor='red', labelsize=12)
    ax.tick_params(axis='x', labelsize=12)

    ax.grid(True, alpha=0.3, linewidth=1)
    ax.set_title(f'Convergence Metrics (Iter {iter_idx})', fontsize=16, weight='bold', pad=10)

    ax.legend(handles, [h.get_label() for h in handles],
              loc='upper right', fontsize=12, framealpha=0.9)

# Row 0: ground truth, initial layout, ground truth again for reference,
# and the metrics at iteration 0.
plot_octagon(axes[0, 0], X_true, 'Ground Truth', show_patches=True)
plot_octagon(axes[0, 1], history[0], 'Iteration 0', show_patches=True)
plot_octagon(axes[0, 2], X_true, 'Ground Truth', show_patches=True)
plot_error_metrics(axes[0, 3], 0, rmse_history, dm_error_history)

# Rows 1-5: two consecutive iterations per row (cols 0-1), then the ground
# truth for reference (col 2) and the metrics through the row's first
# iteration (col 3).  Panels with nothing to show are hidden so the grid
# does not contain blank, default-ticked axes.
for row in range(1, 6):
    for col in range(2):
        iter_idx = (row - 1) * 2 + col + 1
        if iter_idx < len(history):
            plot_octagon(axes[row, col], history[iter_idx], f'Iteration {iter_idx}', show_patches=True)
        else:
            axes[row, col].axis('off')

    iter_idx = (row - 1) * 2 + 1
    if iter_idx < len(history):
        plot_octagon(axes[row, 2], X_true, 'Ground Truth', show_patches=True)
        plot_error_metrics(axes[row, 3], iter_idx, rmse_history, dm_error_history)
    else:
        # Match cols 0-1: hide the reference and metrics panels for empty rows.
        axes[row, 2].axis('off')
        axes[row, 3].axis('off')

plt.tight_layout()
plt.show()

In [None]:
import torch
import numpy as np
import pandas as pd
import anndata as ad
import matplotlib.pyplot as plt
import os
from torch.utils.data import DataLoader

from core_models_et_p1 import STSetDataset, collate_minisets
import utils_et as uet

# ============================================================================
# SETUP
# ============================================================================

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using device: " + str(device))

# Directory holding the GEMS inference outputs, and the run to load from it.
output_dir = "/home/ehtesamul/sc_st/model/gems_mousebrain_output"
timestamp = "20251129_205637"

# ============================================================================
# LOAD ST DATA
# ============================================================================

st_counts = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_counts_et.csv'
st_meta   = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_metadata_et.csv'
st_ct     = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_celltype_et.csv'

print("\nLoading ST1 (training ST data)...")
st_expr_df, st_meta_df, st_ct_df = (
    pd.read_csv(path, index_col=0) for path in (st_counts, st_meta, st_ct)
)

# Counts CSV is genes x spots, so transpose to get spots as observations.
stadata = ad.AnnData(X=st_expr_df.values.T)
stadata.obs_names = st_expr_df.columns
stadata.var_names = st_expr_df.index
stadata.obsm['spatial'] = st_meta_df[['coord_x', 'coord_y']].values
# Label each spot with the cell-type column holding the largest value in its row.
stadata.obs['celltype_mapped_refined'] = st_ct_df.idxmax(axis=1).values

# Densify in case AnnData stored a sparse matrix.
X_st = stadata.X.toarray() if hasattr(stadata.X, "toarray") else stadata.X

st_expr = torch.tensor(X_st, dtype=torch.float32, device=device)
st_coords_raw = torch.tensor(stadata.obsm['spatial'], dtype=torch.float32, device=device)

# A single slide: every spot gets slide id 0 before per-slide canonicalization.
slide_ids = torch.zeros(st_expr.shape[0], dtype=torch.long, device=device)
st_coords, st_mu, st_scale = uet.canonicalize_st_coords_per_slide(st_coords_raw, slide_ids)

print("ST loaded: {} spots, {} genes".format(stadata.shape[0], stadata.shape[1]))

# ============================================================================
# LOAD GEMS INFERENCE RESULTS
# ============================================================================

print(f"\n=== Loading GEMS Inference Results (timestamp: {timestamp}) ===")

processed_path = os.path.join(output_dir, "sc_inference_processed_" + timestamp + ".pt")
# NOTE(review): weights_only=False unpickles arbitrary Python objects.  This is
# presumably a file written locally by the GEMS inference run; never point it
# at an untrusted file.
gems_results = torch.load(processed_path, map_location='cpu', weights_only=False)

gems_coords, gems_D_edm = gems_results['coords_canon'], gems_results['D_edm']

print("GEMS results loaded:")
print(f"  - Coordinates shape: {gems_coords.shape}")
print(f"  - EDM shape: {gems_D_edm.shape}")
print(f"  - Number of cells: {gems_results['n_cells']}")

# ============================================================================
# GENERATE ST MINI-SUBSETS (NO MODEL NEEDED)
# ============================================================================

# Miniset size bounds and how many minisets to draw.
n_min = 96
n_max = 384
num_minisets = 8

print(f"\n=== Generating {num_minisets} ST Mini-Subsets ===")

np.random.seed(42)
torch.manual_seed(42)

n_total = st_coords.shape[0]
# np.random.randint(n_min, high) needs high > n_min; fail with a clear message.
if n_total <= n_min:
    raise ValueError(
        f"Need more than n_min={n_min} ST spots to sample minisets, got {n_total}"
    )

miniset_data = []

for i in range(num_minisets):
    # Sample a random subset size (same logic as STSetDataset):
    # n is drawn from [n_min, min(n_max, n_total - 1)].
    n = np.random.randint(n_min, min(n_max + 1, n_total))

    # Random spot indices for this miniset
    indices = torch.randperm(n_total, device=device)[:n]

    miniset_coords = st_coords[indices]

    # Ground-truth Euclidean distance matrix between the sampled spots
    D_gt = torch.cdist(miniset_coords, miniset_coords).cpu().numpy()

    print(f"Miniset {i+1}: {n} points")

    miniset_data.append({
        'index': i,
        'n_points': n,
        'D_edm': D_gt
    })

# ============================================================================
# VISUALIZATION: EDM HEATMAPS (3x3 grid: ST minisets first, then GEMS)
# ============================================================================

print("\n=== Creating EDM Heatmap Visualizations ===")

fig, axes = plt.subplots(3, 3, figsize=(18, 12))
axes = axes.flatten()

# The GEMS panel goes right after the last miniset, so the grid needs
# room for every miniset plus one.
if len(miniset_data) + 1 > len(axes):
    raise ValueError(
        f"3x3 grid holds {len(axes) - 1} minisets plus GEMS; got {len(miniset_data)} minisets"
    )

# One heatmap per ST miniset
for i, data in enumerate(miniset_data):
    ax = axes[i]
    D = data['D_edm']

    im = ax.imshow(D, cmap='viridis', aspect='auto')
    ax.set_title(f'ST Miniset {i+1}\n({data["n_points"]} points)',
                 fontsize=14, fontweight='bold')
    ax.set_xlabel('Cell Index', fontsize=11)
    ax.set_ylabel('Cell Index', fontsize=11)

    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('Distance', fontsize=10)

# GEMS inference result in the first panel after the minisets.  A fixed index
# such as axes[5] would overwrite a miniset panel whenever there are more than
# five minisets.
ax = axes[len(miniset_data)]
D_gems = gems_D_edm.numpy() if torch.is_tensor(gems_D_edm) else gems_D_edm

# Subsample if too large
max_viz_size = 1000
if D_gems.shape[0] > max_viz_size:
    idx_sample = np.random.choice(D_gems.shape[0], max_viz_size, replace=False)
    idx_sample = np.sort(idx_sample)
    D_gems_viz = D_gems[np.ix_(idx_sample, idx_sample)]
    title_suffix = f'\n(showing {max_viz_size}/{D_gems.shape[0]} cells)'
else:
    D_gems_viz = D_gems
    title_suffix = f'\n({D_gems.shape[0]} cells)'

im = ax.imshow(D_gems_viz, cmap='viridis', aspect='auto')
ax.set_title(f'GEMS Inference Result{title_suffix}',
             fontsize=14, fontweight='bold', color='red')
ax.set_xlabel('Cell Index (Sampled)', fontsize=11)
ax.set_ylabel('Cell Index (Sampled)', fontsize=11)

cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label('Distance', fontsize=10)

# Hide any panels left unused
for unused_ax in axes[len(miniset_data) + 1:]:
    unused_ax.axis('off')

plt.tight_layout()
edm_heatmap_path = os.path.join(output_dir, f'comparison_edm_heatmaps_{timestamp}.png')
# plt.savefig(edm_heatmap_path, dpi=300, bbox_inches='tight')
# savefig is disabled above, so report that instead of claiming a save.
print(f"EDM heatmaps not saved (savefig disabled); target path: {edm_heatmap_path}")
plt.show()

# ============================================================================
# VISUALIZATION: DISTANCE DISTRIBUTIONS (3x3 grid: ST minisets first, then GEMS)
# ============================================================================

print("\n=== Creating Distance Distribution Visualizations ===")

fig, axes = plt.subplots(3, 3, figsize=(21, 12))
axes = axes.flatten()

# The GEMS panel goes right after the last miniset, so the grid needs
# room for every miniset plus one.
if len(miniset_data) + 1 > len(axes):
    raise ValueError(
        f"3x3 grid holds {len(axes) - 1} minisets plus GEMS; got {len(miniset_data)} minisets"
    )

# One histogram of pairwise distances (strict upper triangle) per ST miniset
for i, data in enumerate(miniset_data):
    ax = axes[i]
    D = data['D_edm']

    upper_tri_idx = np.triu_indices_from(D, k=1)
    distances = D[upper_tri_idx]

    ax.hist(distances, bins=50, alpha=0.7, edgecolor='black', color='steelblue', label='ST Miniset')
    ax.set_title(f'ST Miniset {i+1}\n({data["n_points"]} points)',
                 fontsize=14, fontweight='bold')
    ax.set_xlabel('Distance', fontsize=11)
    ax.set_ylabel('Count', fontsize=11)

    mean_dist = distances.mean()
    median_dist = np.median(distances)
    ax.axvline(mean_dist, color='r', linestyle='--', linewidth=2,
               label=f'Mean: {mean_dist:.2f}')
    ax.axvline(median_dist, color='g', linestyle='--', linewidth=2,
               label=f'Median: {median_dist:.2f}')

    ax.legend(fontsize=9, loc='upper right')
    ax.grid(True, alpha=0.3, axis='y')

# GEMS inference result in the first panel after the minisets
ax = axes[len(miniset_data)]
upper_tri_idx = np.triu_indices_from(D_gems, k=1)
distances_gems = D_gems[upper_tri_idx]

ax.hist(distances_gems, bins=100, alpha=0.7, edgecolor='black',
        color='red', label='GEMS Inference')
ax.set_title(f'GEMS Inference Result\n({D_gems.shape[0]} cells)',
             fontsize=14, fontweight='bold', color='red')
ax.set_xlabel('Distance', fontsize=11)
ax.set_ylabel('Count', fontsize=11)

mean_gems = distances_gems.mean()
median_gems = np.median(distances_gems)
ax.axvline(mean_gems, color='darkred', linestyle='--', linewidth=2,
           label=f'Mean: {mean_gems:.2f}')
ax.axvline(median_gems, color='darkgreen', linestyle='--', linewidth=2,
           label=f'Median: {median_gems:.2f}')

ax.legend(fontsize=9, loc='upper right')
ax.grid(True, alpha=0.3, axis='y')

# Hide any panels left unused
for unused_ax in axes[len(miniset_data) + 1:]:
    unused_ax.axis('off')

plt.tight_layout()
dist_hist_path = os.path.join(output_dir, f'comparison_distance_distributions_{timestamp}.png')
# plt.savefig(dist_hist_path, dpi=300, bbox_inches='tight')
# savefig is disabled above, so report that instead of claiming a save.
print(f"Distance distributions not saved (savefig disabled); target path: {dist_hist_path}")
plt.show()

# ============================================================================
# SUMMARY STATISTICS
# ============================================================================

print("\n" + "="*70)
print("COMPARISON SUMMARY STATISTICS")
print("="*70)

print("\nST Minisets (Ground Truth):")
for i, data in enumerate(miniset_data):
    D = data['D_edm']
    # Each unordered pair once: strict upper triangle, diagonal excluded.
    upper_tri_idx = np.triu_indices_from(D, k=1)
    distances = D[upper_tri_idx]

    print(f"\n  Miniset {i+1} ({data['n_points']} points):")
    print(f"    Mean distance:   {distances.mean():.4f}")
    print(f"    Median distance: {np.median(distances):.4f}")
    print(f"    Std distance:    {distances.std():.4f}")
    print(f"    Min distance:    {distances.min():.4f}")
    print(f"    Max distance:    {distances.max():.4f}")

print(f"\nGEMS Inference Result ({D_gems.shape[0]} cells):")
print(f"  Mean distance:   {distances_gems.mean():.4f}")
print(f"  Median distance: {np.median(distances_gems):.4f}")
print(f"  Std distance:    {distances_gems.std():.4f}")
print(f"  Min distance:    {distances_gems.min():.4f}")
print(f"  Max distance:    {distances_gems.max():.4f}")

print("\n" + "="*70)
# The savefig calls in this cell are commented out, so do not claim files were written.
print(f"Output directory (figures are only written where savefig is enabled): {output_dir}")
print("="*70)

In [None]:
# ============================================================================
# EIGENVALUE ANISOTROPY ANALYSIS FOR GEMS INFERENCE
# ============================================================================

print("="*70)
print("GEMS INFERENCE: 2D GEOMETRY VERIFICATION")
print("="*70)

# ============================================================================
# 1. ANALYZE GEMS PREDICTED COORDINATES
# ============================================================================

print("\n=== Analyzing GEMS Predicted Coordinates ===\n")

# GEMS coordinates were loaded earlier as gems_coords; detach/cpu makes the
# numpy conversion safe even for GPU or grad-tracking tensors.
coords_gems = gems_coords.detach().cpu().numpy() if torch.is_tensor(gems_coords) else gems_coords

# Anisotropy = ratio of the two principal variances of the centered coordinates.
X = coords_gems.astype(float)
Xc = X - X.mean(axis=0, keepdims=True)

cov = Xc.T @ Xc / (Xc.shape[0] - 1)
eigvals_gems, eigvecs_gems = np.linalg.eigh(cov)
eigvals_gems = eigvals_gems[::-1]  # eigh returns ascending order; flip to descending

lam1_gems, lam2_gems = eigvals_gems  # unpacking requires exactly 2-D coordinates
ratio_gems = lam1_gems / (lam2_gems + 1e-12)

print(f"GEMS Predicted Coordinates ({coords_gems.shape[0]} cells):")
print(f"  λ1 = {lam1_gems:.4f},  λ2 = {lam2_gems:.4f}")
print(f"  λ1/λ2 = {ratio_gems:.2f}")

if ratio_gems < 5:
    interpretation_gems = "→ GENUINELY 2D ✓"
elif ratio_gems < 20:
    interpretation_gems = "→ Anisotropic but still 2D-ish"
else:
    interpretation_gems = "→ EFFECTIVELY 1D (very elongated) ✗"

print(f"  {interpretation_gems}\n")

# ============================================================================
# 2. ANALYZE ST MINISETS FOR COMPARISON
# ============================================================================

print("=== Analyzing ST Mini-Subsets for Comparison ===\n")

ratios_st = []
eigenvalues_st = []

for i, data in enumerate(miniset_data):
    D = data['D_edm']

    # Reconstruct 2-D coordinates from the EDM with classical MDS:
    # B = -1/2 * J D^2 J (double centering), embedded via its top-2 eigenpairs.
    n = D.shape[0]
    Jn = np.eye(n) - np.ones((n, n)) / n
    B = -0.5 * (Jn @ (D ** 2) @ Jn)

    eigvals_full, eigvecs_full = np.linalg.eigh(B)
    eigvals_full = eigvals_full[::-1]
    eigvecs_full = eigvecs_full[:, ::-1]

    # Clamp negative eigenvalues to 0 before the square root.
    coords_patch = eigvecs_full[:, :2] @ np.diag(np.sqrt(np.maximum(eigvals_full[:2], 0)))

    # Principal variances of the reconstructed 2-D coordinates
    X = coords_patch.astype(float)
    Xc = X - X.mean(axis=0, keepdims=True)

    cov = Xc.T @ Xc / (Xc.shape[0] - 1)
    eigvals_2d, _ = np.linalg.eigh(cov)
    eigvals_2d = eigvals_2d[::-1]

    lam1, lam2 = eigvals_2d
    ratio = lam1 / (lam2 + 1e-12)

    ratios_st.append(ratio)
    eigenvalues_st.append((lam1, lam2))

ratios_st = np.array(ratios_st)
eigenvalues_st = np.array(eigenvalues_st)

print(f"ST Mini-Subsets Statistics:")
print(f"  λ1/λ2 - Median: {np.median(ratios_st):.2f}")
print(f"  λ1/λ2 - Mean:   {ratios_st.mean():.2f}")
print(f"  λ1/λ2 - Range:  [{ratios_st.min():.2f}, {ratios_st.max():.2f}]")

# ============================================================================
# 3. COMPARISON VISUALIZATION
# ============================================================================

print("\n=== Creating Comparison Visualizations ===\n")

fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Plot 1: Histogram comparison - ST vs GEMS
ax = axes[0, 0]
ax.hist(ratios_st, bins=30, alpha=0.6, edgecolor='black', color='steelblue',
        label=f'ST Minisets (n={len(ratios_st)})')
ax.axvline(ratio_gems, color='red', linestyle='--', linewidth=3,
           label=f'GEMS: {ratio_gems:.2f}')
ax.axvline(np.median(ratios_st), color='blue', linestyle='--', linewidth=2,
           label=f'ST Median: {np.median(ratios_st):.2f}')
ax.axvline(5, color='g', linestyle=':', linewidth=2, alpha=0.5, label='2D threshold (5)')
ax.axvline(20, color='orange', linestyle=':', linewidth=2, alpha=0.5, label='1D threshold (20)')

ax.set_xlabel('λ1/λ2 (Anisotropy Ratio)', fontsize=13, fontweight='bold')
ax.set_ylabel('Count', fontsize=13, fontweight='bold')
ax.set_title('Anisotropy Comparison: ST Minisets vs GEMS', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3, axis='y')

# Plot 2: Eigenvalue scatter - ST minisets
ax = axes[0, 1]
scatter = ax.scatter(eigenvalues_st[:, 1], eigenvalues_st[:, 0],
                    c=ratios_st, cmap='viridis', alpha=0.7, s=80,
                    edgecolors='black', linewidth=1, label='ST Minisets')

# Add GEMS point
ax.scatter(lam2_gems, lam1_gems, c='red', s=300, marker='*',
          edgecolors='darkred', linewidth=2, label='GEMS', zorder=5)

# Reference diagonal λ1 = λ2 spanning both data sets
min_val = min(eigenvalues_st[:, 1].min(), lam2_gems)
max_val = max(eigenvalues_st[:, 0].max(), lam1_gems)
ax.plot([min_val, max_val], [min_val, max_val], 'r--',
        linewidth=2, label='λ1 = λ2', alpha=0.7)

ax.set_xlabel('λ2 (Smaller Eigenvalue)', fontsize=13, fontweight='bold')
ax.set_ylabel('λ1 (Larger Eigenvalue)', fontsize=13, fontweight='bold')
ax.set_title('Eigenvalue Scatter Plot', fontsize=14, fontweight='bold')
ax.legend(fontsize=11, loc='upper left')
ax.grid(True, alpha=0.3)

cbar = plt.colorbar(scatter, ax=ax)
cbar.set_label('λ1/λ2', fontsize=11)

# Plot 3: Coordinate scatter - GEMS
ax = axes[1, 0]
ax.scatter(coords_gems[:, 0], coords_gems[:, 1], alpha=0.5, s=10,
          c='red', edgecolors='none')
ax.set_xlabel('Dimension 1', fontsize=13, fontweight='bold')
ax.set_ylabel('Dimension 2', fontsize=13, fontweight='bold')
ax.set_title(f'GEMS Predicted Coordinates\nλ1/λ2 = {ratio_gems:.2f}',
             fontsize=14, fontweight='bold', color='red')
ax.grid(True, alpha=0.3)
ax.set_aspect('equal', adjustable='box')

# Plot 4: Box plot comparison
ax = axes[1, 1]

# Combine ST ratios with GEMS ratio for box plot
data_for_plot = [ratios_st, [ratio_gems]]
labels = ['ST Minisets\n(Ground Truth)', 'GEMS\n(Predicted)']

# NOTE: Matplotlib 3.9 renamed boxplot's `labels` to `tick_labels` (the old
# name is deprecated); `labels` is kept for older installs.
bp = ax.boxplot(data_for_plot, labels=labels, patch_artist=True,
                showmeans=True, meanline=True, widths=0.6)

bp['boxes'][0].set_facecolor('steelblue')
bp['boxes'][0].set_alpha(0.7)
bp['boxes'][1].set_facecolor('red')
bp['boxes'][1].set_alpha(0.7)

ax.axhline(5, color='g', linestyle='--', linewidth=2, alpha=0.5, label='2D threshold (5)')
ax.axhline(20, color='orange', linestyle='--', linewidth=2, alpha=0.5, label='1D threshold (20)')

ax.set_ylabel('λ1/λ2 (Anisotropy Ratio)', fontsize=13, fontweight='bold')
ax.set_title('Anisotropy Distribution Comparison', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
gems_anisotropy_path = os.path.join(output_dir, f'gems_anisotropy_analysis_{timestamp}.png')
plt.savefig(gems_anisotropy_path, dpi=300, bbox_inches='tight')
print(f"✓ Saved GEMS anisotropy analysis: {gems_anisotropy_path}")
plt.show()

# ============================================================================
# 4. DETAILED COMPARISON TABLE
# ============================================================================

print("\n" + "="*70)
print("DETAILED COMPARISON: ST MINISETS vs GEMS INFERENCE")
print("="*70)

print(f"\n{'Metric':<30} {'ST Minisets':<20} {'GEMS Inference':<20}")
print("-" * 70)
print(f"{'Number of samples':<30} {len(ratios_st):<20} {coords_gems.shape[0]:<20}")
print(f"{'λ1 (larger eigenvalue)':<30} {eigenvalues_st[:, 0].mean():.4f} ± {eigenvalues_st[:, 0].std():.4f}   {lam1_gems:.4f}")
print(f"{'λ2 (smaller eigenvalue)':<30} {eigenvalues_st[:, 1].mean():.4f} ± {eigenvalues_st[:, 1].std():.4f}   {lam2_gems:.4f}")
print(f"{'λ1/λ2 ratio (median)':<30} {np.median(ratios_st):.2f}              {ratio_gems:.2f}")
print(f"{'λ1/λ2 ratio (mean)':<30} {ratios_st.mean():.2f}              {ratio_gems:.2f}")
print(f"{'λ1/λ2 ratio (range)':<30} [{ratios_st.min():.2f}, {ratios_st.max():.2f}]     {ratio_gems:.2f}")

print("\n" + "="*70)
print("INTERPRETATION")
print("="*70)

st_genuinely_2d_pct = (ratios_st < 5).sum() / len(ratios_st) * 100

print(f"\nST Mini-Subsets (Ground Truth):")
print(f"  {st_genuinely_2d_pct:.1f}% are genuinely 2D (λ1/λ2 < 5)")
print(f"  Median anisotropy: {np.median(ratios_st):.2f}")
if np.median(ratios_st) < 5:
    print(f"  → GENUINELY 2D ✓")

print(f"\nGEMS Predicted Coordinates:")
if ratio_gems < 5:
    print(f"  ✓ GENUINELY 2D (λ1/λ2 = {ratio_gems:.2f})")
    print(f"  → GEMS successfully preserves 2D spatial structure")
elif ratio_gems < 20:
    print(f"  ⚠ Anisotropic but still 2D-ish (λ1/λ2 = {ratio_gems:.2f})")
    print(f"  → GEMS produces elongated but 2D structures")
else:
    print(f"  ✗ EFFECTIVELY 1D (λ1/λ2 = {ratio_gems:.2f})")
    print(f"  → WARNING: GEMS collapsed to 1D structure")

# Comparison
if ratio_gems < 5 and np.median(ratios_st) < 5:
    print(f"\n✓ EXCELLENT: Both ST and GEMS are genuinely 2D")
elif abs(ratio_gems - np.median(ratios_st)) < 3:
    print(f"\n✓ GOOD: GEMS anisotropy ({ratio_gems:.2f}) is similar to ST ({np.median(ratios_st):.2f})")
else:
    print(f"\n⚠ WARNING: Large anisotropy difference between GEMS ({ratio_gems:.2f}) and ST ({np.median(ratios_st):.2f})")

print("\n" + "="*70)
print("GEMS GEOMETRY ANALYSIS COMPLETE")
print("="*70)

In [None]:
# ===================================================================
# ANISOTROPY CORRECTION DIAGNOSTIC
# ===================================================================

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import pdist, squareform
from scipy.stats import pearsonr, spearmanr

print("\n" + "="*70)
print("ANISOTROPY CORRECTION DIAGNOSTIC")
print("="*70)

# ===================================================================
# STEP 1: ANALYZE CURRENT PREDICTION ANISOTROPY
# ===================================================================

print("\n=== Current Prediction Analysis ===")

coords_centered = coords_pred - coords_pred.mean(axis=0, keepdims=True)

# PCA on predicted coordinates; eigh returns ascending order, so flip both
# the eigenvalues and the eigenvector columns to descending.
cov_pred = coords_centered.T @ coords_centered / (coords_centered.shape[0] - 1)
eigvals_pred, eigvecs_pred = np.linalg.eigh(cov_pred)
eigvals_pred = eigvals_pred[::-1]
eigvecs_pred = eigvecs_pred[:, ::-1]

lam1_pred, lam2_pred = eigvals_pred
r_cur = lam1_pred / (lam2_pred + 1e-12)

print(f"Current Prediction:")
print(f"  λ1 = {lam1_pred:.4f},  λ2 = {lam2_pred:.4f}")
print(f"  λ1/λ2 = {r_cur:.2f}")

if r_cur < 5:
    print(f"  → GENUINELY 2D ✓")
elif r_cur < 20:
    print(f"  → Anisotropic but still 2D-ish")
else:
    print(f"  → EFFECTIVELY 1D (very elongated) ✗")

# ===================================================================
# STEP 2: COMPUTE CORRECTION FACTOR
# ===================================================================

print("\n=== Computing Anisotropy Correction ===")

# Target ratio from ST minisets (from your previous analysis)
r_tgt = 2.2

print(f"Target anisotropy ratio (from ST): {r_tgt:.2f}")
print(f"Current anisotropy ratio: {r_cur:.2f}")
print(f"Correction needed: {r_cur/r_tgt:.2f}x")

# Scaling the PC1 scores by 1/s and the PC2 scores by s multiplies their
# variances by 1/s**2 and s**2, so the eigenvalue ratio becomes r_cur / s**4.
# Solving r_cur / s**4 == r_tgt gives s = (r_cur / r_tgt) ** (1/4).
# (Using a square root here would over-correct to r_tgt**2 / r_cur.)
s = (r_cur / r_tgt) ** 0.25

print(f"\nScaling factors:")
print(f"  First PC (compress):  1/{s:.4f} = {1/s:.4f}")
print(f"  Second PC (expand):   {s:.4f}")

# Scaling matrix in the principal-component basis
S = np.diag([1/s, s])

# ===================================================================
# STEP 3: APPLY CORRECTION
# ===================================================================

print("\n=== Applying Correction ===")

# Transform: center → rotate to PCs → scale → rotate back
coords_balanced = coords_centered @ eigvecs_pred @ S @ eigvecs_pred.T

# Verify new anisotropy (coords_balanced is already centered)
cov_balanced = coords_balanced.T @ coords_balanced / (coords_balanced.shape[0] - 1)
eigvals_balanced, _ = np.linalg.eigh(cov_balanced)
eigvals_balanced = eigvals_balanced[::-1]

lam1_balanced, lam2_balanced = eigvals_balanced
r_balanced = lam1_balanced / (lam2_balanced + 1e-12)

print(f"\nCorrected Coordinates:")
print(f"  λ1 = {lam1_balanced:.4f},  λ2 = {lam2_balanced:.4f}")
print(f"  λ1/λ2 = {r_balanced:.2f}")
print(f"  → Target was {r_tgt:.2f}, achieved {r_balanced:.2f}")

# ===================================================================
# STEP 4: RECOMPUTE METRICS
# ===================================================================

print("\n=== Recomputing Metrics ===")

# Compute balanced EDM
D_edm_balanced = squareform(pdist(coords_balanced, 'euclidean'))

# All three distance matrices must describe the same cells.
if not (gt_edm.shape == D_edm_pred.shape == D_edm_balanced.shape):
    raise ValueError(
        f"EDM shape mismatch: gt {tuple(gt_edm.shape)}, pred {tuple(D_edm_pred.shape)}, "
        f"balanced {tuple(D_edm_balanced.shape)}"
    )

# Extract upper triangle.  The size comes from the matrix itself rather than
# the global n_cells, which other cells in this notebook also assign.
triu_indices = np.triu_indices_from(gt_edm, k=1)
gt_distances = gt_edm[triu_indices]
pred_distances_original = D_edm_pred[triu_indices]
pred_distances_balanced = D_edm_balanced[triu_indices]

# Median-based scale alignment for both predictions
scale_original = np.median(gt_distances) / np.median(pred_distances_original)
scale_balanced = np.median(gt_distances) / np.median(pred_distances_balanced)

pred_distances_original_scaled = pred_distances_original * scale_original
pred_distances_balanced_scaled = pred_distances_balanced * scale_balanced

# Correlations
pearson_original, _ = pearsonr(gt_distances, pred_distances_original_scaled)
spearman_original, _ = spearmanr(gt_distances, pred_distances_original_scaled)

pearson_balanced, _ = pearsonr(gt_distances, pred_distances_balanced_scaled)
spearman_balanced, _ = spearmanr(gt_distances, pred_distances_balanced_scaled)

print(f"\nOriginal Prediction:")
print(f"  Pearson:  {pearson_original:.4f}")
print(f"  Spearman: {spearman_original:.4f}")

print(f"\nAnisotropy-Corrected Prediction:")
print(f"  Pearson:  {pearson_balanced:.4f}")
print(f"  Spearman: {spearman_balanced:.4f}")

delta_pearson = pearson_balanced - pearson_original
delta_spearman = spearman_balanced - spearman_original

print(f"\nImprovement:")
print(f"  Δ Pearson:  {delta_pearson:+.4f}")
print(f"  Δ Spearman: {delta_spearman:+.4f}")

# Distinguish "no change", "improved" and "got worse"; a significant drop
# must not be reported as the correction helping.
if abs(delta_spearman) < 0.02:
    print(f"\n→ Spearman barely changed - anisotropy was NOT the main issue")
elif delta_spearman > 0:
    print(f"\n→ Significant Spearman change - anisotropy correction helps!")
else:
    print(f"\n→ Significant Spearman drop - anisotropy correction hurts")

# ===================================================================
# VISUALIZATION 1: COORDINATE COMPARISON
# ===================================================================

print("\n=== Creating Visualizations ===")

fig, axes = plt.subplots(1, 3, figsize=(21, 6))

# Ground truth
axes[0].scatter(gt_coords[:, 0], gt_coords[:, 1], s=5, alpha=0.6, c='blue')
axes[0].set_title('Ground Truth Coordinates', fontsize=14, fontweight='bold')
axes[0].set_xlabel('X', fontsize=12)
axes[0].set_ylabel('Y', fontsize=12)
axes[0].set_aspect('equal', adjustable='box')
axes[0].grid(True, alpha=0.3)

# Original prediction
axes[1].scatter(coords_pred[:, 0], coords_pred[:, 1], s=5, alpha=0.6, c='red')
axes[1].set_title(f'Original Prediction\nλ1/λ2 = {r_cur:.2f}', fontsize=14, fontweight='bold')
axes[1].set_xlabel('X', fontsize=12)
axes[1].set_ylabel('Y', fontsize=12)
axes[1].set_aspect('equal', adjustable='box')
axes[1].grid(True, alpha=0.3)

# Corrected prediction
axes[2].scatter(coords_balanced[:, 0], coords_balanced[:, 1], s=5, alpha=0.6, c='green')
axes[2].set_title(f'Anisotropy-Corrected\nλ1/λ2 = {r_balanced:.2f} (target: {r_tgt:.2f})',
                 fontsize=14, fontweight='bold')
axes[2].set_xlabel('X', fontsize=12)
axes[2].set_ylabel('Y', fontsize=12)
axes[2].set_aspect('equal', adjustable='box')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# ===================================================================
# VISUALIZATION 2: DISTANCE SCATTER COMPARISON
# ===================================================================

# Subsample pairs so the scatter stays readable and fast to draw.
sample_size = 50000
sample_idx = np.random.choice(len(gt_distances), min(sample_size, len(gt_distances)), replace=False)

fig, axes = plt.subplots(1, 2, figsize=(16, 7))

# Original
axes[0].scatter(gt_distances[sample_idx], pred_distances_original_scaled[sample_idx],
               alpha=0.2, s=5, c='red')
axes[0].set_title(f'Original Prediction\nSpearman ρ = {spearman_original:.4f}',
                 fontsize=14, fontweight='bold')
axes[0].set_xlabel('Ground Truth Distance', fontsize=12)
axes[0].set_ylabel('Predicted Distance (scaled)', fontsize=12)
axes[0].grid(True, linestyle='--', alpha=0.5)

lims_0 = [min(axes[0].get_xlim()[0], axes[0].get_ylim()[0]),
          max(axes[0].get_xlim()[1], axes[0].get_ylim()[1])]
axes[0].plot(lims_0, lims_0, 'k--', alpha=0.75, linewidth=2, label='Ideal')
axes[0].set_aspect('equal', adjustable='box')
axes[0].legend(fontsize=11)

# Corrected
axes[1].scatter(gt_distances[sample_idx], pred_distances_balanced_scaled[sample_idx],
               alpha=0.2, s=5, c='green')
axes[1].set_title(f'Anisotropy-Corrected\nSpearman ρ = {spearman_balanced:.4f} (Δ{delta_spearman:+.4f})',
                 fontsize=14, fontweight='bold')
axes[1].set_xlabel('Ground Truth Distance', fontsize=12)
axes[1].set_ylabel('Predicted Distance (scaled)', fontsize=12)
axes[1].grid(True, linestyle='--', alpha=0.5)

lims_1 = [min(axes[1].get_xlim()[0], axes[1].get_ylim()[0]),
          max(axes[1].get_xlim()[1], axes[1].get_ylim()[1])]
axes[1].plot(lims_1, lims_1, 'k--', alpha=0.75, linewidth=2, label='Ideal')
axes[1].set_aspect('equal', adjustable='box')
axes[1].legend(fontsize=11)

plt.tight_layout()
plt.show()

# ===================================================================
# VISUALIZATION 3: DISTANCE DISTRIBUTIONS
# ===================================================================

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Left: original prediction; right: anisotropy-corrected.  Each panel overlays
# the ground-truth pairwise-distance density with the prediction's density.
for hist_ax, hist_pred, hist_color, hist_label in zip(
    axes,
    (pred_distances_original_scaled, pred_distances_balanced_scaled),
    ('red', 'green'),
    ('Original Prediction', 'Anisotropy-Corrected'),
):
    hist_ax.hist(gt_distances, bins=100, alpha=0.5, color='blue',
                 label='Ground Truth', density=True, edgecolor='black', linewidth=0.5)
    hist_ax.hist(hist_pred, bins=100, alpha=0.5, color=hist_color,
                 label=hist_label, density=True, edgecolor='black', linewidth=0.5)
    hist_ax.set_title(hist_label, fontsize=14, fontweight='bold')
    hist_ax.set_xlabel('Distance', fontsize=12)
    hist_ax.set_ylabel('Density', fontsize=12)
    hist_ax.legend(fontsize=11)
    hist_ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

# ===================================================================
# VISUALIZATION 4: EIGENVALUE COMPARISON
# ===================================================================

fig, ax = plt.subplots(1, 1, figsize=(10, 8))

# Ground-truth anisotropy, computed the same way as for the prediction
gt_coords_centered = gt_coords - gt_coords.mean(axis=0, keepdims=True)
cov_gt = gt_coords_centered.T @ gt_coords_centered / (gt_coords_centered.shape[0] - 1)
eigvals_gt, _ = np.linalg.eigh(cov_gt)
eigvals_gt = eigvals_gt[::-1]
lam1_gt, lam2_gt = eigvals_gt
r_gt = lam1_gt / (lam2_gt + 1e-12)

# Plot
methods = ['Ground Truth', 'Original\nPrediction', 'Anisotropy\nCorrected']
ratios = [r_gt, r_cur, r_balanced]
colors = ['blue', 'red', 'green']

bars = ax.bar(methods, ratios, color=colors, alpha=0.7, edgecolor='black', linewidth=2)

ax.axhline(5, color='darkgreen', linestyle='--', linewidth=2, alpha=0.7, label='2D threshold (5)')
ax.axhline(20, color='orange', linestyle='--', linewidth=2, alpha=0.7, label='1D threshold (20)')
ax.axhline(r_tgt, color='purple', linestyle=':', linewidth=2, alpha=0.7, label=f'ST target ({r_tgt:.2f})')

ax.set_ylabel('λ1/λ2 (Anisotropy Ratio)', fontsize=13, fontweight='bold')
ax.set_title('Anisotropy Comparison', fontsize=16, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3, axis='y')

# Add value labels on bars
for bar, ratio in zip(bars, ratios):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height,
            f'{ratio:.2f}',
            ha='center', va='bottom', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.show()

# ===================================================================
# SUMMARY
# ===================================================================

print("\n" + "="*70)
print("DIAGNOSTIC SUMMARY")
print("="*70)

print(f"\nAnisotropy Ratios (λ1/λ2):")
print(f"  Ground Truth:           {r_gt:.2f}")
print(f"  Original Prediction:    {r_cur:.2f}")
print(f"  Corrected Prediction:   {r_balanced:.2f}")
print(f"  ST Miniset Target:      {r_tgt:.2f}")

print(f"\nSpearman Correlation:")
print(f"  Original:   {spearman_original:.4f}")
print(f"  Corrected:  {spearman_balanced:.4f}")
print(f"  Change:     {delta_spearman:+.4f}")

print(f"\n{'='*70}")
print("CONCLUSION")
print(f"{'='*70}")

if abs(delta_spearman) < 0.02:
    print("\n✓ Spearman correlation barely changed after correction")
    print("→ The main issue is NOT just global anisotropy")
    print("→ You likely need to retrain with anisotropy regularization")
    print("→ But the visual improvement suggests anisotropy IS part of the problem")
elif delta_spearman > 0.02:
    print("\n✓ Spearman correlation IMPROVED significantly")
    print("→ Anisotropy correction helps!")
    print("→ Training with anisotropy regularization should improve results")
else:
    print("\n⚠ Spearman correlation DECREASED")
    print("→ Anisotropy correction made things worse (unexpected)")
    print("→ The elongation might be capturing real structure")

if r_cur > 10:
    print(f"\n⚠ Original prediction is very elongated (λ1/λ2 = {r_cur:.2f})")
    print("→ This is much more anisotropic than ST minisets")
    print("→ STRONGLY RECOMMEND adding anisotropy regularization to training")
elif r_cur > 5:
    print(f"\n⚠ Original prediction is somewhat elongated (λ1/λ2 = {r_cur:.2f})")
    print("→ Moderately more anisotropic than ST minisets")
    print("→ Consider adding anisotropy regularization")
else:
    print(f"\n✓ Original prediction is already well-balanced (λ1/λ2 = {r_cur:.2f})")

print("\n" + "="*70)
print("DIAGNOSTIC COMPLETE")
print("="*70)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch

# ============================================================================
# EIGENVALUE ANISOTROPY ANALYSIS FOR ST MINI-SUBSETS
# ============================================================================

print("="*70)
print("ST MINI-SUBSETS: 2D GEOMETRY VERIFICATION")
print("="*70)

# ============================================================================
# 1. ANALYZE THE 5 MINI-SUBSETS WE ALREADY GENERATED
# ============================================================================
# For each precomputed miniset:
#   1. recover 2-D coordinates from its distance matrix with classical MDS;
#   2. report the eigenvalues of their sample covariance;
#   3. report the anisotropy ratio λ1/λ2.

print("\n=== Individual Mini-Subset Analysis ===\n")

ratios_individual = []
eigenvalues_list = []

for i, data in enumerate(miniset_data):
    # NOTE(review): assumes data['D_edm'] is a square NumPy array of pairwise
    # (non-squared) distances; it is squared below before double centering.
    D = data['D_edm']
    
    # Reconstruct coordinates from the distance matrix using classical MDS:
    # B = -1/2 * J D^2 J, with J the centering matrix.
    n = D.shape[0]
    Jn = np.eye(n) - np.ones((n, n)) / n
    B = -0.5 * (Jn @ (D ** 2) @ Jn)
    
    eigvals_full, eigvecs_full = np.linalg.eigh(B)
    # eigh returns ascending order; flip to largest-first
    eigvals_full = eigvals_full[::-1]
    eigvecs_full = eigvecs_full[:, ::-1]
    
    # Take the top 2 eigenpairs for a 2D embedding (negative eigenvalues clipped)
    coords_patch = eigvecs_full[:, :2] @ np.diag(np.sqrt(np.maximum(eigvals_full[:2], 0)))
    
    # Now analyze the 2D variance of these coordinates
    X = coords_patch.astype(float)
    Xc = X - X.mean(axis=0, keepdims=True)
    
    cov = Xc.T @ Xc / (Xc.shape[0] - 1)
    eigvals_2d, eigvecs_2d = np.linalg.eigh(cov)
    eigvals_2d = eigvals_2d[::-1]
    
    lam1, lam2 = eigvals_2d
    ratio = lam1 / (lam2 + 1e-12)
    
    ratios_individual.append(ratio)
    eigenvalues_list.append((lam1, lam2))
    
    print(f"Miniset {i+1} ({data['n_points']} points):")
    print(f"  λ1 = {lam1:.4f},  λ2 = {lam2:.4f}")
    print(f"  λ1/λ2 = {ratio:.2f}")
    
    if ratio < 5:
        interpretation = "→ GENUINELY 2D ✓"
    elif ratio < 20:
        interpretation = "→ Anisotropic but still 2D-ish"
    else:
        interpretation = "→ EFFECTIVELY 1D (very elongated) ✗"
    
    print(f"  {interpretation}\n")

# ============================================================================
# 2. GENERATE MANY MORE ST MINI-SUBSETS FOR STATISTICAL ANALYSIS
# ============================================================================
# Draw num_samples random subsets of the canonicalized ST spots (sizes in
# [n_min, n_max]). For each subset, record λ1/λ2 of its 2D sample covariance.
# NOTE: the exact RNG call sequence here (np.random.randint, then
# torch.randperm, once per sample, after seeding 42) is what allows the
# subsets to be reproduced later by replaying it.

print("\n=== Statistical Analysis Over 200 Random ST Mini-Subsets ===\n")

np.random.seed(42)
torch.manual_seed(42)

num_samples = 200
ratios_stats = []
eigenvalues_stats = []

n_total = st_coords.shape[0]  # loop-invariant

for i in range(num_samples):
    # Sample a random subset
    n = np.random.randint(n_min, min(n_max + 1, n_total))
    
    indices = torch.randperm(n_total, device=device)[:n]
    miniset_coords = st_coords[indices].cpu().numpy()
    
    # Analyze the 2D variance directly from the coordinates
    X = miniset_coords.astype(float)
    Xc = X - X.mean(axis=0, keepdims=True)
    
    cov = Xc.T @ Xc / (Xc.shape[0] - 1)
    eigvals, _ = np.linalg.eigh(cov)
    eigvals = eigvals[::-1]  # largest first
    
    lam1, lam2 = eigvals
    ratio = lam1 / (lam2 + 1e-12)
    
    ratios_stats.append(ratio)
    eigenvalues_stats.append((lam1, lam2))

ratios_stats = np.array(ratios_stats)
eigenvalues_stats = np.array(eigenvalues_stats)

print(f"Number of patches analyzed: {len(ratios_stats)}")
print(f"\nλ1/λ2 Anisotropy Ratio Statistics:")
print(f"  Min:        {ratios_stats.min():.2f}")
print(f"  25th %ile:  {np.percentile(ratios_stats, 25):.2f}")
print(f"  Median:     {np.median(ratios_stats):.2f}")
print(f"  75th %ile:  {np.percentile(ratios_stats, 75):.2f}")
print(f"  95th %ile:  {np.percentile(ratios_stats, 95):.2f}")
print(f"  Max:        {ratios_stats.max():.2f}")
print(f"  Mean:       {ratios_stats.mean():.2f}")
print(f"  Std:        {ratios_stats.std():.2f}")

print(f"\nEigenvalue Statistics:")
print(f"  λ1 - Mean: {eigenvalues_stats[:, 0].mean():.4f}, Std: {eigenvalues_stats[:, 0].std():.4f}")
print(f"  λ2 - Mean: {eigenvalues_stats[:, 1].mean():.4f}, Std: {eigenvalues_stats[:, 1].std():.4f}")

# Interpretation based on the median ratio
median_ratio = np.median(ratios_stats)
print(f"\n{'='*70}")
print("INTERPRETATION:")
print(f"{'='*70}")

if median_ratio < 5:
    print("✓ ST minisets are GENUINELY 2D")
    print("  → Good for training 2D spatial reconstruction")
elif median_ratio < 20:
    print("⚠ ST minisets are ANISOTROPIC but still 2D-ish")
    print("  → Some elongation present, but dimensionality is 2D")
else:
    print("✗ ST minisets are EFFECTIVELY 1D (very elongated)")
    print("  → Warning: training on curved 1D strips, not 2D patches")

# Count how many patches fall into each anisotropy band
genuinely_2d = (ratios_stats < 5).sum()
anisotropic_2d = ((ratios_stats >= 5) & (ratios_stats < 20)).sum()
effectively_1d = (ratios_stats >= 20).sum()

print(f"\nDistribution:")
print(f"  Genuinely 2D (λ1/λ2 < 5):      {genuinely_2d}/{num_samples} ({100*genuinely_2d/num_samples:.1f}%)")
print(f"  Anisotropic 2D (5 ≤ λ1/λ2 < 20): {anisotropic_2d}/{num_samples} ({100*anisotropic_2d/num_samples:.1f}%)")
print(f"  Effectively 1D (λ1/λ2 ≥ 20):    {effectively_1d}/{num_samples} ({100*effectively_1d/num_samples:.1f}%)")

# ============================================================================
# 3. VISUALIZATIONS
# ============================================================================
# NOTE: plt.savefig is commented out below, so the "Saved ..." message only
# reports the path the figure would be written to. os, output_dir and
# timestamp are assumed to be defined by an earlier cell.

print(f"\n=== Creating Visualizations ===\n")

fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# Plot 1: Histogram of λ1/λ2 ratios
ax = axes[0, 0]
ax.hist(ratios_stats, bins=50, alpha=0.7, edgecolor='black', color='steelblue')
ax.axvline(np.median(ratios_stats), color='r', linestyle='--', linewidth=2, 
           label=f'Median: {np.median(ratios_stats):.2f}')
ax.axvline(5, color='g', linestyle='--', linewidth=2, alpha=0.5, label='2D threshold (5)')
ax.axvline(20, color='orange', linestyle='--', linewidth=2, alpha=0.5, label='1D threshold (20)')
ax.set_xlabel('λ1/λ2 (Anisotropy Ratio)', fontsize=12, fontweight='bold')
ax.set_ylabel('Count', fontsize=12, fontweight='bold')
ax.set_title('Distribution of Eigenvalue Ratios\n(ST Mini-Subsets)', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, axis='y')

# Plot 2: Log-scale histogram
ax = axes[0, 1]
ax.hist(np.log10(ratios_stats), bins=50, alpha=0.7, edgecolor='black', color='coral')
ax.axvline(np.log10(np.median(ratios_stats)), color='r', linestyle='--', linewidth=2, 
           label=f'Median: {np.median(ratios_stats):.2f}')
ax.axvline(np.log10(5), color='g', linestyle='--', linewidth=2, alpha=0.5, label='log₁₀(5)')
ax.axvline(np.log10(20), color='orange', linestyle='--', linewidth=2, alpha=0.5, label='log₁₀(20)')
ax.set_xlabel('log₁₀(λ1/λ2)', fontsize=12, fontweight='bold')
ax.set_ylabel('Count', fontsize=12, fontweight='bold')
ax.set_title('Log-Scale Distribution\n(ST Mini-Subsets)', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, axis='y')

# Plot 3: Scatter plot of λ1 vs λ2
ax = axes[1, 0]
scatter = ax.scatter(eigenvalues_stats[:, 1], eigenvalues_stats[:, 0], 
                    c=ratios_stats, cmap='viridis', alpha=0.6, s=30)
# Isotropic reference line λ1 = λ2, spanning the observed λ2 range
ax.plot([eigenvalues_stats[:, 1].min(), eigenvalues_stats[:, 1].max()],
        [eigenvalues_stats[:, 1].min(), eigenvalues_stats[:, 1].max()],
        'r--', linewidth=2, label='λ1 = λ2 (isotropic)')
ax.set_xlabel('λ2 (Smaller Eigenvalue)', fontsize=12, fontweight='bold')
ax.set_ylabel('λ1 (Larger Eigenvalue)', fontsize=12, fontweight='bold')
ax.set_title('Eigenvalue Scatter Plot\n(Color = λ1/λ2)', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
cbar = plt.colorbar(scatter, ax=ax)
cbar.set_label('λ1/λ2', fontsize=10)

# Plot 4: Empirical CDF of ratios
ax = axes[1, 1]
sorted_ratios = np.sort(ratios_stats)
cdf = np.arange(1, len(sorted_ratios) + 1) / len(sorted_ratios)
ax.plot(sorted_ratios, cdf, linewidth=2, color='steelblue')
ax.axvline(5, color='g', linestyle='--', linewidth=2, alpha=0.7, 
           label=f'2D threshold (5): {(ratios_stats < 5).sum()/num_samples*100:.1f}%')
ax.axvline(20, color='orange', linestyle='--', linewidth=2, alpha=0.7, 
           label=f'1D threshold (20): {(ratios_stats >= 20).sum()/num_samples*100:.1f}%')
ax.axhline(0.5, color='r', linestyle='--', linewidth=1, alpha=0.5, label='Median')
ax.set_xlabel('λ1/λ2 (Anisotropy Ratio)', fontsize=12, fontweight='bold')
ax.set_ylabel('Cumulative Probability', fontsize=12, fontweight='bold')
ax.set_title('Cumulative Distribution Function\n(ST Mini-Subsets)', fontsize=14, fontweight='bold')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)
ax.set_xlim(left=0)

plt.tight_layout()

anisotropy_plot_path = os.path.join(output_dir, f'st_minisets_anisotropy_analysis_{timestamp}.png')
# plt.savefig(anisotropy_plot_path, dpi=300, bbox_inches='tight')
print(f"✓ Saved anisotropy analysis plot: {anisotropy_plot_path}")
plt.show()

# ============================================================================
# 4. VISUALIZE EXAMPLE MINI-SUBSETS WITH DIFFERENT ANISOTROPIES
# ============================================================================

print("\n=== Visualizing Example Mini-Subsets by Anisotropy ===\n")

# Find examples of different anisotropy levels
low_aniso_idx = np.where(ratios_stats < 3)[0]
med_aniso_idx = np.where((ratios_stats >= 5) & (ratios_stats < 10))[0]
high_aniso_idx = np.where(ratios_stats >= 20)[0]

# Recover the exact subsets analysed in section 2 by replaying its RNG stream.
# This uses the same seeds and the same sequence of np.random.randint /
# torch.randperm calls.
# The previous approach re-seeded with 42 + example_idx and drew a fresh
# subset. That subset was unrelated to the one whose λ1/λ2 is reported, so
# the plotted points did not match the ratio in the title.
np.random.seed(42)
torch.manual_seed(42)
replayed_indices = []
for _ in range(num_samples):
    n = np.random.randint(n_min, min(n_max + 1, st_coords.shape[0]))
    replayed_indices.append(torch.randperm(st_coords.shape[0], device=device)[:n])

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

examples = [
    (low_aniso_idx, "Low Anisotropy (λ1/λ2 < 3)", 0),
    (med_aniso_idx, "Medium Anisotropy (5 ≤ λ1/λ2 < 10)", 1),
    (high_aniso_idx, "High Anisotropy (λ1/λ2 ≥ 20)", 2)
]

for idx_array, title, ax_idx in examples:
    ax = axes[ax_idx]
    if len(idx_array) == 0:
        ax.text(0.5, 0.5, f'No examples found\nfor {title}', 
               ha='center', va='center', fontsize=12)
        ax.set_title(title, fontsize=12, fontweight='bold')
        continue

    example_idx = idx_array[np.random.randint(len(idx_array))]
    miniset_coords = st_coords[replayed_indices[example_idx]].cpu().numpy()

    # Recompute λ1/λ2 from the points actually plotted (same formula as
    # section 2). Warn if the replay diverged from the stored statistics.
    Xc = miniset_coords.astype(float)
    Xc = Xc - Xc.mean(axis=0, keepdims=True)
    eigvals_ex = np.linalg.eigh(Xc.T @ Xc / (Xc.shape[0] - 1))[0][::-1]
    ratio = eigvals_ex[0] / (eigvals_ex[1] + 1e-12)
    if not np.isclose(ratio, ratios_stats[example_idx], rtol=1e-6):
        print(f"⚠ Replayed miniset {example_idx} does not match section 2 "
              f"(λ1/λ2 {ratio:.2f} vs {ratios_stats[example_idx]:.2f})")

    ax.scatter(miniset_coords[:, 0], miniset_coords[:, 1], 
              alpha=0.6, s=20, c='steelblue', edgecolors='black', linewidth=0.5)
    ax.set_xlabel('Coordinate 1', fontsize=11)
    ax.set_ylabel('Coordinate 2', fontsize=11)
    ax.set_title(f'{title}\nλ1/λ2 = {ratio:.2f}', fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.set_aspect('equal')

plt.tight_layout()
examples_plot_path = os.path.join(output_dir, f'st_minisets_anisotropy_examples_{timestamp}.png')
# plt.savefig(examples_plot_path, dpi=300, bbox_inches='tight')  (disabled)
print(f"✓ Saved example minisets plot: {examples_plot_path}")
plt.show()

print("\n" + "="*70)
print("ANISOTROPY ANALYSIS COMPLETE")
print("="*70)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch

# ============================================================================
# EFFECT OF CANONICALIZATION ON ST MINI-SUBSET GEOMETRY
# ============================================================================
# Compare how several coordinate normalizations change the λ1/λ2 anisotropy
# of randomly sampled ST mini-subsets.

print("\n".join(["=" * 70, "CANONICALIZATION IMPACT ON 2D GEOMETRY", "=" * 70]))

# ============================================================================
# HELPER FUNCTIONS FOR DIFFERENT CANONICALIZATION METHODS
# ============================================================================

def center_only(coords):
    """Translate ``coords`` so that every column has zero mean.

    Args:
        coords: NumPy array, assumed (n_points, n_dims).

    Returns:
        A new array of the same shape; scale and orientation are unchanged.
    """
    column_means = coords.mean(axis=0, keepdims=True)
    return coords - column_means

def center_and_scale_rms(coords):
    """Center ``coords`` and rescale so the RMS distance from the origin is 1.

    The RMS radius is sqrt(sum of all squared centered entries / n_points).
    A 1e-12 guard in the denominator keeps an all-identical input finite
    (it maps to zeros).
    """
    shifted = coords - coords.mean(axis=0, keepdims=True)
    n_points = shifted.shape[0]
    radius = np.sqrt(np.sum(shifted ** 2) / n_points)
    return shifted / (radius + 1e-12)

def center_and_scale_bbox(coords):
    """Center ``coords`` and divide by the longest side of their bounding box.

    After this transform the widest axis spans exactly 1 (up to the 1e-12
    guard against a zero-size box).
    """
    shifted = coords - coords.mean(axis=0, keepdims=True)
    extent_per_axis = shifted.max(axis=0) - shifted.min(axis=0)
    longest_side = extent_per_axis.max()
    return shifted / (longest_side + 1e-12)

def _axis_orientation(column):
    """Return +1.0 or -1.0, the sign that puts ``column`` in canonical orientation.

    Uses the sign of the third moment (sum of cubes). If that is negligible
    relative to sum(|x|^3), falls back to the sign of the largest-magnitude
    entry. An all-zero column returns +1.0.
    """
    cubed = column ** 3
    skew = cubed.sum()
    if abs(skew) > 1e-9 * np.abs(cubed).sum():
        return 1.0 if skew > 0 else -1.0
    pivot = column[np.argmax(np.abs(column))]
    return -1.0 if pivot < 0 else 1.0


def full_canonicalize(coords):
    """Center, scale to unit RMS, rotate onto principal axes, and fix signs.

    Steps:
      1. subtract the column means;
      2. divide by the RMS distance from the origin (1e-12 guard);
      3. rotate onto the eigenvectors of the sample covariance, ordered by
         descending eigenvalue (first column = direction of largest spread);
      4. orient each axis so its third moment is non-negative, falling back to
         the sign of its largest-magnitude entry when the third moment is ~0.

    Step 4 replaces a former "flip if the column sum is negative" rule. After
    centering, every column sums to zero, so that test only reacted to
    floating-point noise and the orientation was effectively arbitrary. With
    the skew rule, the output is unchanged by reflections, rotations and
    translations of the input, provided the covariance eigenvalues are
    distinct and the third moments are non-zero.

    Args:
        coords: array of shape (n_points, n_dims), assumed n_points >= 2.

    Returns:
        A new float array of the same shape.
    """
    # 1. Center
    centered = coords - coords.mean(axis=0, keepdims=True)

    # 2. Scale to unit RMS radius
    rms = np.sqrt((centered ** 2).sum() / centered.shape[0])
    scaled = centered / (rms + 1e-12)

    # 3. Rotate to principal axes (eigh returns ascending order -> reverse)
    cov = scaled.T @ scaled / (scaled.shape[0] - 1)
    _, eigvecs = np.linalg.eigh(cov)
    eigvecs = eigvecs[:, ::-1]
    rotated = scaled @ eigvecs

    # 4. Resolve the per-axis sign ambiguity left by the eigendecomposition
    for d in range(rotated.shape[1]):
        if _axis_orientation(rotated[:, d]) < 0:
            rotated[:, d] *= -1

    return rotated

def compute_anisotropy(coords):
    """Return ``(λ1/λ2, λ1, λ2)`` for the sample covariance of 2-D ``coords``.

    λ1 >= λ2 are the covariance eigenvalues (ddof=1). The ratio's denominator
    carries a 1e-12 guard. Unpacking fails unless the input has exactly two
    columns.
    """
    pts = coords.astype(float)
    pts = pts - pts.mean(axis=0, keepdims=True)
    covariance = (pts.T @ pts) / (pts.shape[0] - 1)
    lam_small, lam_large = np.linalg.eigh(covariance)[0]
    return lam_large / (lam_small + 1e-12), lam_large, lam_small

# ============================================================================
# GENERATE TEST DATASET
# ============================================================================
# Procedure:
#   - draw num_samples random ST mini-subsets (seed 42, as in the previous cell);
#   - apply each canonicalization variant to every subset;
#   - record the anisotropy ratio and (λ1, λ2) per variant.

print("\n=== Generating Test Dataset ===\n")

np.random.seed(42)
torch.manual_seed(42)

num_samples = 200

# One bucket per canonicalization variant (insertion order drives plot order)
results = {
    variant: {'ratios': [], 'eigenvalues': []}
    for variant in ('raw', 'center_only', 'center_rms', 'center_bbox', 'full_canon')
}

for i in range(num_samples):
    n_total = st_coords.shape[0]
    n = np.random.randint(n_min, min(n_max + 1, n_total))
    indices = torch.randperm(n_total, device=device)[:n]
    coords_raw = st_coords[indices].cpu().numpy()

    # Same patch under five normalizations
    methods = {
        'raw': coords_raw,
        'center_only': center_only(coords_raw),
        'center_rms': center_and_scale_rms(coords_raw),
        'center_bbox': center_and_scale_bbox(coords_raw),
        'full_canon': full_canonicalize(coords_raw),
    }

    for method_name, coords_transformed in methods.items():
        ratio, lam1, lam2 = compute_anisotropy(coords_transformed)
        bucket = results[method_name]
        bucket['ratios'].append(ratio)
        bucket['eigenvalues'].append((lam1, lam2))

# Freeze the per-variant lists into arrays for vectorized statistics
for method_name, bucket in results.items():
    bucket['ratios'] = np.array(bucket['ratios'])
    bucket['eigenvalues'] = np.array(bucket['eigenvalues'])

# ============================================================================
# STATISTICAL COMPARISON
# ============================================================================
# Print per-method λ1/λ2 summary statistics, the 2D / anisotropic / 1D band
# counts, and eigenvalue means/stds from `results`.

print(f"Number of patches analyzed: {num_samples}\n")
print("="*70)
print("ANISOTROPY STATISTICS (λ1/λ2) BY CANONICALIZATION METHOD")
print("="*70)

method_labels = {
    'raw': 'Raw Coordinates',
    'center_only': 'Center Only',
    'center_rms': 'Center + Unit RMS',
    'center_bbox': 'Center + Unit BBox',
    'full_canon': 'Full Canonicalization'
}

for method_name, label in method_labels.items():
    ratios = results[method_name]['ratios']
    eigenvalues = results[method_name]['eigenvalues']
    
    print(f"\n{label}:")
    print(f"  λ1/λ2 Ratio:")
    print(f"    Min:       {ratios.min():.2f}")
    print(f"    Median:    {np.median(ratios):.2f}")
    print(f"    Mean:      {ratios.mean():.2f}")
    print(f"    Max:       {ratios.max():.2f}")
    print(f"    Std:       {ratios.std():.2f}")
    
    genuinely_2d = (ratios < 5).sum()
    anisotropic_2d = ((ratios >= 5) & (ratios < 20)).sum()
    effectively_1d = (ratios >= 20).sum()
    
    print(f"  Distribution:")
    print(f"    Genuinely 2D (<5):      {genuinely_2d}/{num_samples} ({100*genuinely_2d/num_samples:.1f}%)")
    print(f"    Anisotropic (5-20):     {anisotropic_2d}/{num_samples} ({100*anisotropic_2d/num_samples:.1f}%)")
    print(f"    Effectively 1D (≥20):   {effectively_1d}/{num_samples} ({100*effectively_1d/num_samples:.1f}%)")
    
    print(f"  Eigenvalues:")
    print(f"    λ1 Mean: {eigenvalues[:, 0].mean():.4f}, Std: {eigenvalues[:, 0].std():.4f}")
    print(f"    λ2 Mean: {eigenvalues[:, 1].mean():.4f}, Std: {eigenvalues[:, 1].std():.4f}")

# ============================================================================
# VISUALIZATION: COMPARISON OF METHODS
# ============================================================================
# NOTE: every plt.savefig call below is commented out, so the "Saved ..."
# messages only report the path a figure would be written to. os, output_dir
# and timestamp are assumed to be defined by an earlier cell.

print(f"\n{'='*70}")
print("CREATING VISUALIZATIONS")
print(f"{'='*70}\n")

# Plot 1: Histogram comparison
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

colors = ['steelblue', 'coral', 'seagreen', 'purple', 'crimson']

for idx, (method_name, label) in enumerate(method_labels.items()):
    ax = axes[idx]
    ratios = results[method_name]['ratios']
    
    ax.hist(ratios, bins=50, alpha=0.7, edgecolor='black', color=colors[idx])
    ax.axvline(np.median(ratios), color='r', linestyle='--', linewidth=2, 
               label=f'Median: {np.median(ratios):.2f}')
    ax.axvline(5, color='g', linestyle='--', linewidth=2, alpha=0.5, label='2D threshold')
    ax.axvline(20, color='orange', linestyle='--', linewidth=2, alpha=0.5, label='1D threshold')
    
    ax.set_xlabel('λ1/λ2', fontsize=11, fontweight='bold')
    ax.set_ylabel('Count', fontsize=11)
    ax.set_title(label, fontsize=12, fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3, axis='y')
    # Cap the x-range at 100 so a few extreme ratios don't squash the bulk
    ax.set_xlim(0, min(100, ratios.max() + 5))

# Five methods on a 2x3 grid: hide the unused 6th panel
axes[5].axis('off')

plt.tight_layout()
comparison_hist_path = os.path.join(output_dir, f'canonicalization_comparison_histograms_{timestamp}.png')
# plt.savefig(comparison_hist_path, dpi=300, bbox_inches='tight')
print(f"✓ Saved comparison histograms: {comparison_hist_path}")
plt.show()

# Plot 2: Box plot comparison
fig, ax = plt.subplots(1, 1, figsize=(12, 8))

data_for_boxplot = [results[method]['ratios'] for method in method_labels.keys()]
labels_short = ['Raw', 'Center\nOnly', 'Center +\nUnit RMS', 'Center +\nUnit BBox', 'Full\nCanon']

# NOTE: Matplotlib 3.9 renamed boxplot's `labels` kwarg to `tick_labels`
# (the old name is deprecated); kept here for compatibility with older versions.
bp = ax.boxplot(data_for_boxplot, labels=labels_short, patch_artist=True,
                showmeans=True, meanline=True)

for patch, color in zip(bp['boxes'], colors):
    patch.set_facecolor(color)
    patch.set_alpha(0.7)

ax.axhline(5, color='g', linestyle='--', linewidth=2, alpha=0.5, label='2D threshold (5)')
ax.axhline(20, color='orange', linestyle='--', linewidth=2, alpha=0.5, label='1D threshold (20)')

ax.set_ylabel('λ1/λ2 (Anisotropy Ratio)', fontsize=14, fontweight='bold')
ax.set_xlabel('Canonicalization Method', fontsize=14, fontweight='bold')
ax.set_title('Effect of Canonicalization on Anisotropy\n(200 ST Mini-Subsets)', 
             fontsize=16, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
boxplot_path = os.path.join(output_dir, f'canonicalization_boxplot_{timestamp}.png')
# plt.savefig(boxplot_path, dpi=300, bbox_inches='tight')
print(f"✓ Saved boxplot comparison: {boxplot_path}")
plt.show()

# Plot 3: CDF comparison
fig, ax = plt.subplots(1, 1, figsize=(12, 8))

for method_name, label, color in zip(method_labels.keys(), method_labels.values(), colors):
    ratios = results[method_name]['ratios']
    sorted_ratios = np.sort(ratios)
    cdf = np.arange(1, len(sorted_ratios) + 1) / len(sorted_ratios)
    ax.plot(sorted_ratios, cdf, linewidth=2.5, label=label, color=color, alpha=0.8)

ax.axvline(5, color='g', linestyle='--', linewidth=2, alpha=0.5, label='2D threshold (5)')
ax.axvline(20, color='orange', linestyle='--', linewidth=2, alpha=0.5, label='1D threshold (20)')

ax.set_xlabel('λ1/λ2 (Anisotropy Ratio)', fontsize=14, fontweight='bold')
ax.set_ylabel('Cumulative Probability', fontsize=14, fontweight='bold')
ax.set_title('Cumulative Distribution Comparison\n(Effect of Canonicalization)', 
             fontsize=16, fontweight='bold')
ax.legend(fontsize=11, loc='lower right')
ax.grid(True, alpha=0.3)
ax.set_xlim(0, 50)

plt.tight_layout()
cdf_path = os.path.join(output_dir, f'canonicalization_cdf_{timestamp}.png')
# plt.savefig(cdf_path, dpi=300, bbox_inches='tight')
print(f"✓ Saved CDF comparison: {cdf_path}")
plt.show()

# Plot 4: Eigenvalue scatter comparison (2x3 grid)
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

for idx, (method_name, label) in enumerate(method_labels.items()):
    ax = axes[idx]
    eigenvalues = results[method_name]['eigenvalues']
    ratios = results[method_name]['ratios']
    
    # Shared color scale (0..20) so panels are comparable
    scatter = ax.scatter(eigenvalues[:, 1], eigenvalues[:, 0], 
                        c=ratios, cmap='viridis', alpha=0.6, s=30, vmin=0, vmax=20)
    
    # Add diagonal line (isotropic reference)
    min_val = min(eigenvalues[:, 1].min(), eigenvalues[:, 0].min())
    max_val = max(eigenvalues[:, 1].max(), eigenvalues[:, 0].max())
    ax.plot([min_val, max_val], [min_val, max_val], 'r--', 
            linewidth=2, label='λ1 = λ2', alpha=0.7)
    
    ax.set_xlabel('λ2 (Smaller)', fontsize=11, fontweight='bold')
    ax.set_ylabel('λ1 (Larger)', fontsize=11, fontweight='bold')
    ax.set_title(label, fontsize=12, fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)
    
    if idx == 1:  # Add colorbar to the top-middle plot
        cbar = plt.colorbar(scatter, ax=ax)
        cbar.set_label('λ1/λ2', fontsize=10)

axes[5].axis('off')

plt.tight_layout()
scatter_path = os.path.join(output_dir, f'canonicalization_eigenvalue_scatter_{timestamp}.png')
# plt.savefig(scatter_path, dpi=300, bbox_inches='tight')
print(f"✓ Saved eigenvalue scatter plots: {scatter_path}")
plt.show()

# ============================================================================
# VISUAL EXAMPLE: SAME MINISET WITH DIFFERENT CANONICALIZATIONS
# ============================================================================
# Draws one miniset (seed 123) and plots it under each normalization, with
# its λ1/λ2 and eigenvalues in the panel title.

print(f"\n=== Creating Visual Example ===\n")

# Generate one example miniset
np.random.seed(123)
torch.manual_seed(123)

n = np.random.randint(n_min, min(n_max + 1, st_coords.shape[0]))
indices = torch.randperm(st_coords.shape[0], device=device)[:n]
coords_example_raw = st_coords[indices].cpu().numpy()

example_coords = {
    'Raw': coords_example_raw,
    'Center Only': center_only(coords_example_raw),
    'Center + RMS': center_and_scale_rms(coords_example_raw),
    'Center + BBox': center_and_scale_bbox(coords_example_raw),
    'Full Canon': full_canonicalize(coords_example_raw)
}

fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

for idx, (label, coords) in enumerate(example_coords.items()):
    ax = axes[idx]
    ratio, lam1, lam2 = compute_anisotropy(coords)
    
    ax.scatter(coords[:, 0], coords[:, 1], alpha=0.6, s=30, 
              c='steelblue', edgecolors='black', linewidth=0.5)
    
    ax.set_xlabel('Dimension 1', fontsize=11, fontweight='bold')
    ax.set_ylabel('Dimension 2', fontsize=11, fontweight='bold')
    ax.set_title(f'{label}\nλ1/λ2 = {ratio:.2f}, λ1={lam1:.3f}, λ2={lam2:.3f}', 
                fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.set_aspect('equal', adjustable='box')

# Five panels on a 2x3 grid: hide the unused one
axes[5].axis('off')

plt.tight_layout()
example_path = os.path.join(output_dir, f'canonicalization_visual_example_{timestamp}.png')
# plt.savefig(example_path, dpi=300, bbox_inches='tight')  (disabled: nothing is written)
print(f"✓ Saved visual example: {example_path}")
plt.show()

# ============================================================================
# SUMMARY
# ============================================================================
# One line of verdict per canonicalization method, based on its median λ1/λ2
# and the share of patches below the 2D threshold (5).

print("\n" + "="*70)
print("CANONICALIZATION IMPACT SUMMARY")
print("="*70)

for method_name, label in method_labels.items():
    ratios = results[method_name]['ratios']
    median_ratio = np.median(ratios)
    pct_2d = (ratios < 5).sum() / num_samples * 100
    
    print(f"\n{label}:")
    print(f"  Median λ1/λ2: {median_ratio:.2f}")
    print(f"  % Genuinely 2D: {pct_2d:.1f}%")
    
    if median_ratio < 5:
        print(f"  → ✓ Remains GENUINELY 2D")
    elif median_ratio < 20:
        print(f"  → ⚠ Becomes more anisotropic but still 2D-ish")
    else:
        print(f"  → ✗ Becomes EFFECTIVELY 1D")

print("\n" + "="*70)
print("ANALYSIS COMPLETE")
print("="*70)

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import utils_et as uet

# ============================================================================
# LOAD RAW ST DATA
# ============================================================================

print("="*70)
print("CANONICALIZATION IMPACT ON RAW ST DATA")
print("="*70)

st_counts = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_counts_et.csv'
st_meta   = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_metadata_et.csv'

print("\nLoading ST1 data...")
st_meta_df = pd.read_csv(st_meta, index_col=0)

# Get raw coordinates.
# NOTE(review): this rebinds the notebook globals st_meta_df, st_coords_raw
# and slide_ids from the setup cell at the top of the file, where the tensors
# live on `device`. Here they are CPU tensors. Re-run the setup cell before
# any later cell that expects device tensors.
st_coords_raw = torch.tensor(st_meta_df[['coord_x', 'coord_y']].values, 
                             dtype=torch.float32)

print(f"Loaded {st_coords_raw.shape[0]} spots")

# ============================================================================
# APPLY CANONICALIZATION (EXACTLY AS IN run_mouse_brain_2.py)
# ============================================================================

slide_ids = torch.zeros(st_coords_raw.shape[0], dtype=torch.long)

st_coords_canon, st_mu, st_scale = uet.canonicalize_st_coords_per_slide(
    st_coords_raw, slide_ids
)

print(f"\nCanonicalization parameters:")
# detach().cpu() so printing works whatever device/grad state the helper returns
print(f"  Mean (μ):        {st_mu[0].detach().cpu().numpy()}")
print(f"  Scale factor:    {st_scale[0].item():.4f}")

# ============================================================================
# COMPUTE STATISTICS
# ============================================================================

def compute_stats(coords, label):
    """Print and return summary statistics for a 2-D point set.

    Args:
        coords: (n_points, 2) torch tensor or array-like. Tensors are detached
            and moved to the CPU first, so CUDA tensors and tensors that
            require grad are accepted (plain ``Tensor.numpy()`` rejects both).
        label: heading printed above the statistics.

    Returns:
        dict with keys:
          - 'mean', 'std' (population, ddof=0), 'min', 'max': per axis;
          - 'rms': RMS distance from the mean;
          - 'distances': condensed pairwise Euclidean distances from
            ``scipy.spatial.distance.pdist``;
          - 'eigvals': sample-covariance eigenvalues, largest first;
          - 'eigvecs': eigenvectors in ``np.linalg.eigh``'s ascending column
            order, i.e. ``eigvecs[:, -1]`` pairs with ``eigvals[0]``.
    """
    from scipy.spatial.distance import pdist

    if torch.is_tensor(coords):
        coords_np = coords.detach().cpu().numpy()
    else:
        coords_np = np.asarray(coords)
    
    # Basic per-axis stats
    mean = coords_np.mean(axis=0)
    std = coords_np.std(axis=0)
    min_vals = coords_np.min(axis=0)
    max_vals = coords_np.max(axis=0)
    
    # RMS radius (after centering)
    centered = coords_np - mean
    rms = np.sqrt((centered ** 2).sum(axis=1).mean())
    
    # Pairwise distances (O(n^2) memory)
    distances = pdist(coords_np)
    
    # Eigenvalue analysis of the sample covariance (ddof=1)
    cov = np.cov(coords_np.T)
    eigvals, eigvecs = np.linalg.eigh(cov)
    # Only the eigenvalues are reversed; eigvecs keeps eigh's ascending order
    # (kept as-is for callers that may already rely on it).
    eigvals = eigvals[::-1]
    
    print(f"\n{label}:")
    print(f"  Shape: {coords_np.shape}")
    print(f"  Mean: [{mean[0]:.4f}, {mean[1]:.4f}]")
    print(f"  Std:  [{std[0]:.4f}, {std[1]:.4f}]")
    print(f"  Min:  [{min_vals[0]:.4f}, {min_vals[1]:.4f}]")
    print(f"  Max:  [{max_vals[0]:.4f}, {max_vals[1]:.4f}]")
    print(f"  RMS radius (from mean): {rms:.4f}")
    print(f"  Pairwise distances:")
    print(f"    Mean:   {distances.mean():.4f}")
    print(f"    Median: {np.median(distances):.4f}")
    print(f"    p90:    {np.percentile(distances, 90):.4f}")
    print(f"  Eigenvalues: λ1={eigvals[0]:.4f}, λ2={eigvals[1]:.4f}")
    print(f"  Anisotropy (λ1/λ2): {eigvals[0]/(eigvals[1]+1e-12):.2f}")
    
    return {
        'mean': mean,
        'std': std,
        'min': min_vals,
        'max': max_vals,
        'rms': rms,
        'distances': distances,
        'eigvals': eigvals,
        'eigvecs': eigvecs
    }

# Statistics for both clouds (compute_stats also prints them); the dicts feed
# every panel of the 3x4 figure built below.
stats_raw = compute_stats(st_coords_raw, "RAW COORDINATES")
stats_canon = compute_stats(st_coords_canon, "CANONICALIZED COORDINATES")

# ============================================================================
# VISUALIZATION
# ============================================================================

print(f"\n{'='*70}")
print("CREATING VISUALIZATIONS")
print(f"{'='*70}\n")

# 3 rows x 4 columns: row 1 scatter/overlay/metrics, row 2 distance
# distributions, row 3 eigenvalue analysis.
fig = plt.figure(figsize=(20, 12))
gs = fig.add_gridspec(3, 4, hspace=0.3, wspace=0.3)

# ============================================================================
# ROW 1: SCATTER PLOTS
# ============================================================================

# Plot 1: Raw coordinates
# NOTE(review): axis labels say "pixels"; the unit of coord_x/coord_y is not
# established in this cell — confirm.
ax1 = fig.add_subplot(gs[0, 0])
ax1.scatter(st_coords_raw[:, 0], st_coords_raw[:, 1], 
           alpha=0.5, s=10, c='steelblue', edgecolors='none')
ax1.axhline(stats_raw['mean'][1], color='r', linestyle='--', linewidth=2, alpha=0.7, label='Mean Y')
ax1.axvline(stats_raw['mean'][0], color='r', linestyle='--', linewidth=2, alpha=0.7, label='Mean X')
ax1.set_xlabel('X (pixels)', fontsize=11, fontweight='bold')
ax1.set_ylabel('Y (pixels)', fontsize=11, fontweight='bold')
ax1.set_title(f'Raw Coordinates\nMean: [{stats_raw["mean"][0]:.1f}, {stats_raw["mean"][1]:.1f}]', 
             fontsize=12, fontweight='bold')
ax1.legend(fontsize=9)
ax1.grid(True, alpha=0.3)
ax1.set_aspect('equal', adjustable='box')

# Plot 2: Canonicalized coordinates
ax2 = fig.add_subplot(gs[0, 1])
ax2.scatter(st_coords_canon[:, 0], st_coords_canon[:, 1], 
           alpha=0.5, s=10, c='coral', edgecolors='none')
# Reference lines at the origin (the expected centroid after centering).
ax2.axhline(0, color='r', linestyle='--', linewidth=2, alpha=0.7, label='Mean Y = 0')
ax2.axvline(0, color='r', linestyle='--', linewidth=2, alpha=0.7, label='Mean X = 0')

# Add RMS circle
# Drawn at a fixed radius of 1.0; the title reports the measured RMS so any
# mismatch with the unit target is visible.
circle = plt.Circle((0, 0), 1.0, color='g', fill=False, 
                    linewidth=2, linestyle='--', alpha=0.7, label=f'RMS = 1.0')
ax2.add_patch(circle)

ax2.set_xlabel('X (canonicalized)', fontsize=11, fontweight='bold')
ax2.set_ylabel('Y (canonicalized)', fontsize=11, fontweight='bold')
ax2.set_title(f'Canonicalized Coordinates\nRMS: {stats_canon["rms"]:.4f}', 
             fontsize=12, fontweight='bold')
ax2.legend(fontsize=9)
ax2.grid(True, alpha=0.3)
ax2.set_aspect('equal', adjustable='box')

# Plot 3: Overlay comparison
ax3 = fig.add_subplot(gs[0, 2])
# Normalize raw coords for visual comparison (same scale as canonical)
# Center on the raw centroid and divide by the raw RMS radius. If the
# canonicalizer performs the same center+RMS scaling, the two clouds coincide
# (assumption about uet.canonicalize_st_coords_per_slide — verify).
raw_centered = st_coords_raw - st_coords_raw.mean(dim=0)
raw_normalized = raw_centered / stats_raw['rms']
ax3.scatter(raw_normalized[:, 0], raw_normalized[:, 1], 
           alpha=0.3, s=8, c='steelblue', edgecolors='none', label='Raw (normalized)')
ax3.scatter(st_coords_canon[:, 0], st_coords_canon[:, 1], 
           alpha=0.3, s=8, c='coral', edgecolors='none', label='Canonical')
ax3.set_xlabel('X (normalized scale)', fontsize=11, fontweight='bold')
ax3.set_ylabel('Y (normalized scale)', fontsize=11, fontweight='bold')
ax3.set_title('Overlay: Raw vs Canonical\n(both normalized to same scale)', 
             fontsize=12, fontweight='bold')
ax3.legend(fontsize=9)
ax3.grid(True, alpha=0.3)
ax3.set_aspect('equal', adjustable='box')

# Plot 4: Scale comparison bar chart
# |mean| is plotted so negative centroid components still show as bars.
ax4 = fig.add_subplot(gs[0, 3])
metrics = ['Mean X', 'Mean Y', 'RMS Radius', 'Max Distance']
raw_vals = [abs(stats_raw['mean'][0]), abs(stats_raw['mean'][1]), 
           stats_raw['rms'], stats_raw['distances'].max()]
canon_vals = [abs(stats_canon['mean'][0]), abs(stats_canon['mean'][1]), 
             stats_canon['rms'], stats_canon['distances'].max()]

x = np.arange(len(metrics))
width = 0.35

# Grouped bars: raw on the left of each tick, canonical on the right.
bars1 = ax4.bar(x - width/2, raw_vals, width, label='Raw', color='steelblue', alpha=0.8)
bars2 = ax4.bar(x + width/2, canon_vals, width, label='Canonical', color='coral', alpha=0.8)

ax4.set_ylabel('Value', fontsize=11, fontweight='bold')
ax4.set_title('Metric Comparison', fontsize=12, fontweight='bold')
ax4.set_xticks(x)
ax4.set_xticklabels(metrics, rotation=45, ha='right', fontsize=9)
ax4.legend(fontsize=10)
ax4.grid(True, alpha=0.3, axis='y')

# Add value labels on bars
for bars in [bars1, bars2]:
    for bar in bars:
        height = bar.get_height()
        ax4.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.2f}', ha='center', va='bottom', fontsize=8)

# ============================================================================
# ROW 2: DISTANCE DISTRIBUTIONS
# ============================================================================

raw_dist = stats_raw['distances']
canon_dist = stats_canon['distances']

# Plot 5: Raw distance distribution
ax5 = fig.add_subplot(gs[1, 0])
ax5.hist(raw_dist, bins=100, alpha=0.7, edgecolor='black', color='steelblue')
raw_dist_mean, raw_dist_median = raw_dist.mean(), np.median(raw_dist)
ax5.axvline(raw_dist_mean, color='r', linestyle='--',
            linewidth=2, label=f'Mean: {raw_dist_mean:.1f}')
ax5.axvline(raw_dist_median, color='g', linestyle='--',
            linewidth=2, label=f'Median: {raw_dist_median:.1f}')
ax5.set_xlabel('Pairwise Distance', fontsize=11, fontweight='bold')
ax5.set_ylabel('Count', fontsize=11, fontweight='bold')
ax5.set_title('Raw: Distance Distribution', fontsize=12, fontweight='bold')
ax5.legend(fontsize=9)
ax5.grid(True, alpha=0.3, axis='y')

# Plot 6: Canonical distance distribution
ax6 = fig.add_subplot(gs[1, 1])
ax6.hist(canon_dist, bins=100, alpha=0.7, edgecolor='black', color='coral')
canon_dist_mean, canon_dist_median = canon_dist.mean(), np.median(canon_dist)
ax6.axvline(canon_dist_mean, color='r', linestyle='--',
            linewidth=2, label=f'Mean: {canon_dist_mean:.4f}')
ax6.axvline(canon_dist_median, color='g', linestyle='--',
            linewidth=2, label=f'Median: {canon_dist_median:.4f}')
ax6.set_xlabel('Pairwise Distance', fontsize=11, fontweight='bold')
ax6.set_ylabel('Count', fontsize=11, fontweight='bold')
ax6.set_title('Canonical: Distance Distribution', fontsize=12, fontweight='bold')
ax6.legend(fontsize=9)
ax6.grid(True, alpha=0.3, axis='y')

# Plot 7: Scaling consistency check
ax7 = fig.add_subplot(gs[1, 2])
# Centering plus a uniform rescale divides every pairwise distance by the
# applied scale, so the measured mean-distance ratio should equal st_scale.
# The raw RMS radius is shown alongside; it equals st_scale only if the
# canonicalizer divides by the population RMS radius (not verified here).
scale_factor = stats_raw['rms']
applied_scale = st_scale[0].item()
actual_ratio = raw_dist_mean / canon_dist_mean
check_labels = ['Scale Factor\n(RMS)', 'Applied Scale\n(st_scale)', 'Distance Ratio\n(Mean)']
check_vals = [scale_factor, applied_scale, actual_ratio]

ax7.bar(check_labels, check_vals,
        color=['steelblue', 'gray', 'coral'], alpha=0.8, edgecolor='black')
ax7.set_ylabel('Value', fontsize=11, fontweight='bold')
ax7.set_title('Scaling Consistency Check', fontsize=12, fontweight='bold')
ax7.grid(True, alpha=0.3, axis='y')

for i, val in enumerate(check_vals):
    ax7.text(i, val, f'{val:.4f}', ha='center', va='bottom',
             fontsize=10, fontweight='bold')

print(f"Scaling check: applied scale = {applied_scale:.4f}, "
      f"measured distance ratio = {actual_ratio:.4f}")

# Plot 8: CDF comparison (empirical CDFs of the two distance sets)
ax8 = fig.add_subplot(gs[1, 3])
raw_sorted = np.sort(raw_dist)
canon_sorted = np.sort(canon_dist)
raw_cdf = np.arange(1, len(raw_sorted) + 1) / len(raw_sorted)
canon_cdf = np.arange(1, len(canon_sorted) + 1) / len(canon_sorted)

ax8.plot(raw_sorted, raw_cdf, linewidth=2, color='steelblue', label='Raw', alpha=0.8)
ax8.plot(canon_sorted, canon_cdf, linewidth=2, color='coral', label='Canonical', alpha=0.8)
ax8.set_xlabel('Pairwise Distance', fontsize=11, fontweight='bold')
ax8.set_ylabel('Cumulative Probability', fontsize=11, fontweight='bold')
ax8.set_title('Distance CDF Comparison', fontsize=12, fontweight='bold')
ax8.legend(fontsize=10)
ax8.grid(True, alpha=0.3)

# ============================================================================
# ROW 3: EIGENVALUE ANALYSIS
# ============================================================================

def _draw_principal_axes(ax, stats, value_fmt, length_mult=3.0, head_frac=0.05):
    """Draw both principal axes of a point cloud as red arrows from the origin.

    ``stats`` is a dict from ``compute_stats``: ``eigvals`` are descending while
    the columns of ``eigvecs`` keep numpy.linalg.eigh's ascending order, so
    eigenvalue ``i`` pairs with column ``-(i + 1)``. Each arrow is
    ``length_mult * sqrt(eigenvalue)`` long. The head size is ``head_frac`` of
    the longest arrow, so it adapts to the coordinate scale instead of relying
    on hard-coded sizes for raw vs. canonical units.
    """
    eigvals = stats['eigvals']
    eigvecs = stats['eigvecs']
    # Clip tiny negative eigenvalues from round-off before taking the sqrt.
    lengths = length_mult * np.sqrt(np.clip(eigvals[:2], 0.0, None))
    head_width = head_frac * max(float(lengths[0]), 1e-12)
    for i in range(2):
        direction = eigvecs[:, -(i + 1)] * lengths[i]
        ax.arrow(0, 0, direction[0], direction[1],
                 head_width=head_width, head_length=1.5 * head_width,
                 fc='red', ec='red', linewidth=2, alpha=0.7)
        ax.text(direction[0], direction[1], f'λ{i+1}={eigvals[i]:{value_fmt}}',
                fontsize=10, fontweight='bold', color='red')


def _anisotropy(stats):
    """Return λ1/λ2 with the same epsilon guard compute_stats prints with."""
    return stats['eigvals'][0] / (stats['eigvals'][1] + 1e-12)


# Plot 9: Principal axes of the raw (centered) cloud
ax9 = fig.add_subplot(gs[2, 0])
raw_centered_np = (st_coords_raw - st_coords_raw.mean(dim=0)).numpy()
ax9.scatter(raw_centered_np[:, 0], raw_centered_np[:, 1],
            alpha=0.4, s=10, c='steelblue', edgecolors='none')
_draw_principal_axes(ax9, stats_raw, '.1f')
ax9.set_xlabel('X (centered)', fontsize=11, fontweight='bold')
ax9.set_ylabel('Y (centered)', fontsize=11, fontweight='bold')
ax9.set_title(f'Raw: Principal Axes\nλ1/λ2 = {_anisotropy(stats_raw):.2f}',
              fontsize=12, fontweight='bold')
ax9.grid(True, alpha=0.3)
ax9.set_aspect('equal', adjustable='box')

# Plot 10: Principal axes of the canonical cloud
ax10 = fig.add_subplot(gs[2, 1])
canon_np = st_coords_canon.numpy()
ax10.scatter(canon_np[:, 0], canon_np[:, 1],
             alpha=0.4, s=10, c='coral', edgecolors='none')
_draw_principal_axes(ax10, stats_canon, '.3f')
ax10.set_xlabel('X (canonicalized)', fontsize=11, fontweight='bold')
ax10.set_ylabel('Y (canonicalized)', fontsize=11, fontweight='bold')
ax10.set_title(f'Canonical: Principal Axes\nλ1/λ2 = {_anisotropy(stats_canon):.2f}',
               fontsize=12, fontweight='bold')
ax10.grid(True, alpha=0.3)
ax10.set_aspect('equal', adjustable='box')

# Plot 11: Eigenvalue comparison
ax11 = fig.add_subplot(gs[2, 2])
x = np.arange(2)
width = 0.35

bars1 = ax11.bar(x - width/2, stats_raw['eigvals'], width,
                 label='Raw', color='steelblue', alpha=0.8)
bars2 = ax11.bar(x + width/2, stats_canon['eigvals'], width,
                 label='Canonical', color='coral', alpha=0.8)

ax11.set_ylabel('Eigenvalue', fontsize=11, fontweight='bold')
ax11.set_title('Eigenvalue Comparison', fontsize=12, fontweight='bold')
ax11.set_xticks(x)
ax11.set_xticklabels(['λ1', 'λ2'], fontsize=11)
ax11.legend(fontsize=10)
ax11.grid(True, alpha=0.3, axis='y')
# Raw eigenvalues are in squared raw units and can dwarf the canonical ones on
# a linear axis; use a log axis (when all values are positive) so both show.
if np.all(stats_raw['eigvals'] > 0) and np.all(stats_canon['eigvals'] > 0):
    ax11.set_yscale('log')

for bars in [bars1, bars2]:
    for bar in bars:
        height = bar.get_height()
        ax11.text(bar.get_x() + bar.get_width()/2., height,
                  f'{height:.2f}', ha='center', va='bottom', fontsize=9)

# Plot 12: Summary text
ax12 = fig.add_subplot(gs[2, 3])
ax12.axis('off')

summary_text = f"""
CANONICALIZATION SUMMARY

Transformation Applied:
  1. Center: X - μ
  2. Scale: (X - μ) / RMS

Parameters:
  μ = [{st_mu[0, 0].item():.2f}, {st_mu[0, 1].item():.2f}]
  RMS = {st_scale[0].item():.4f}

Impact:
  • Mean: {stats_raw['mean']} → {stats_canon['mean']}
  • RMS: {stats_raw['rms']:.4f} → {stats_canon['rms']:.4f}
  • Distances scaled by: 1/{st_scale[0].item():.4f}

Geometry Preserved:
  • Anisotropy ratio unchanged:
    Raw: {stats_raw['eigvals'][0]/(stats_raw['eigvals'][1]+1e-12):.2f}
    Canon: {stats_canon['eigvals'][0]/(stats_canon['eigvals'][1]+1e-12):.2f}

  • Shape identical (similarity transform)
  • Only position and scale changed
"""

ax12.text(0.1, 0.5, summary_text, fontsize=11, verticalalignment='center',
          fontfamily='monospace', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))

plt.suptitle('Impact of canonicalize_st_coords_per_slide() on Raw ST Data',
             fontsize=16, fontweight='bold', y=0.995)

# `os`, `output_dir` and `timestamp` are otherwise only defined by a LATER
# cell, so resolve them here; earlier definitions are reused if present.
import os
from datetime import datetime

SAVE_CANON_FIGURE = False  # set True to write the PNG to disk
_canon_out_dir = globals().get('output_dir', '/home/ehtesamul/sc_st/model/gems_mousebrain_output')
_canon_stamp = globals().get('timestamp', datetime.now().strftime("%Y%m%d_%H%M%S"))
output_path = os.path.join(_canon_out_dir, f'canonicalization_impact_{_canon_stamp}.png')
if SAVE_CANON_FIGURE:
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"✓ Saved visualization: {output_path}")
else:
    # Only claim a save when one actually happened.
    print(f"Figure not saved (set SAVE_CANON_FIGURE = True to write {output_path})")
plt.show()

print("\n" + "="*70)
print("CANONICALIZATION ANALYSIS COMPLETE")
print("="*70)

In [None]:
# ===================================================================
# COMPLETE NOTEBOOK: SINGLE PATCH INFERENCE + FULL EVALUATION
# ===================================================================
import torch
import numpy as np
import scanpy as sc
from scipy.spatial.distance import cdist, pdist, squareform
from scipy.stats import pearsonr, spearmanr
import matplotlib.pyplot as plt
import seaborn as sns
import os
from datetime import datetime

# Setup: output locations and the phase-2 (SC fine-tuned) checkpoint.
output_dir = "/home/ehtesamul/sc_st/model/gems_mousebrain_output"
timestamp = "20251128_100055"
checkpoint_path = f"{output_dir}/phase2_sc_finetuned_checkpoint.pt"

_banner = "=" * 70
print(_banner)
print("SINGLE PATCH INFERENCE (Diagnostic Mode)")
print(_banner)

# ===================================================================
# STEP 1: LOAD TEST DATA
# ===================================================================
print("\n--- Loading Test Data ---")

from run_mouse_brain_2 import load_mouse_data
scadata, stadata = load_mouse_data()

# Keep only genes shared by both modalities (sorted for a stable column
# order) and densify the SC matrix if it is stored sparse.
common = sorted(set(scadata.var_names).intersection(stadata.var_names))
X_sc = scadata[:, common].X
X_sc = X_sc.toarray() if hasattr(X_sc, "toarray") else X_sc
sc_expr = torch.tensor(X_sc, dtype=torch.float32)

n_cells, n_genes = sc_expr.shape

print(f"Loaded SC data: {n_cells} cells, {n_genes} genes")
print(f"Ground truth coords shape: {scadata.obsm['spatial_gt'].shape}")

# ===================================================================
# STEP 2: LOAD MODEL AND CHECKPOINT
# ===================================================================
print("\n--- Loading Model and Checkpoint ---")

from core_models_et_p3 import GEMSModel

# Same fallback rule as the setup cell: hard-coding 'cuda' made both model
# construction and torch.load(map_location=...) fail on CPU-only hosts.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = GEMSModel(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    D_latent=32,
    c_dim=256,
    n_heads=4,
    isab_m=64,
    device=device,
    use_canonicalize=True,
    use_dist_bias=True,
    dist_bins=24,
    dist_head_shared=True,
    use_angle_features=True,
    angle_bins=8,
    knn_k=12,
    self_conditioning=True,
    sc_feat_mode='concat',
    landmarks_L=16
)

checkpoint = torch.load(checkpoint_path, map_location=device)
# Restore each sub-network from its per-module state dict in the checkpoint.
for _name in ('encoder', 'context_encoder', 'generator', 'score_net'):
    _module = getattr(model, _name)
    _module.load_state_dict(checkpoint[_name])
    # Inference only: switch off training-mode layers (dropout / batch-norm
    # updates, if any). Harmless if infer_sc_patchwise already does this.
    _module.eval()

print(f"✓ Loaded checkpoint from: {checkpoint_path}")
print(f"  Epochs trained: {checkpoint.get('epochs_finetune', 'N/A')}")

# ===================================================================
# STEP 3: SINGLE PATCH INFERENCE (DIAGNOSTIC MODE)
# ===================================================================
for _line in ("\n--- Running Single Patch Inference ---",
              f"Config: patch_size={n_cells}, coverage_per_cell=1.0",
              "This runs ONE patch with ALL cells (no stitching)",
              "-" * 70):
    print(_line)

# Sampler settings in one place. patch_size=n_cells puts every cell in a
# single patch (coverage 1.0, no overlap), so n_align_iters is irrelevant.
single_patch_kwargs = dict(
    sc_gene_expr=sc_expr,
    n_timesteps_sample=600,
    sigma_min=0.01,
    sigma_max=7.0,
    patch_size=n_cells,
    coverage_per_cell=1.0,
    n_align_iters=1,
    eta=0.0,
    guidance_scale=5.0,
    return_coords=True,
    debug_flag=True,
    debug_every=10,
)
results = model.infer_sc_patchwise(**single_patch_kwargs)

print("\n‚úì Inference complete")

# ===================================================================
# STEP 4: EXTRACT RAW EDM (NO PROJECTION, NO RESCALING)
# ===================================================================
print("\n--- Computing Raw EDM (No Post-Processing) ---")

# Canonicalized coordinates from the sampler, moved to host memory.
coords_canon = results['coords_canon'].cpu().numpy()

# Dense pairwise-distance matrix straight from the coordinates
# (no edm_project, no rescaling).
gems_edm = cdist(coords_canon, coords_canon, metric='euclidean')

# Strictly positive entries: drops the diagonal and any coincident pairs.
_gems_pos = gems_edm[gems_edm > 0]
print(f"Raw EDM shape: {gems_edm.shape}")
print("Raw EDM stats:")
print(f"  Min: {_gems_pos.min():.4f}")
print(f"  Median: {np.median(_gems_pos):.4f}")
print(f"  Max: {gems_edm.max():.4f}")
print(f"  Mean: {_gems_pos.mean():.4f}")

# ===================================================================
# STEP 5: COMPUTE GROUND TRUTH EDM
# ===================================================================
print("\n--- Calculating Ground Truth EDM ---")

gt_coords = scadata.obsm['spatial_gt']
# Expand the condensed pdist vector into the full square matrix.
gt_edm = squareform(pdist(gt_coords, 'euclidean'))

# Strictly positive entries: drops the diagonal and any coincident pairs.
_gt_pos = gt_edm[gt_edm > 0]
print(f"Ground Truth EDM shape: {gt_edm.shape}")
print("Ground Truth EDM stats:")
print(f"  Min: {_gt_pos.min():.4f}")
print(f"  Median: {np.median(_gt_pos):.4f}")
print(f"  Max: {gt_edm.max():.4f}")
print(f"  Mean: {_gt_pos.mean():.4f}")

# ===================================================================
# STEP 6: NORMALIZE FOR COMPARISON
# ===================================================================
def normalize_matrix(matrix):
    """Min-max scale ``matrix`` to the range [0, 1].

    A constant matrix (max == min) returns all zeros of the same shape instead
    of the NaNs the unguarded 0/0 division produced.
    """
    matrix = np.asarray(matrix)
    min_val = matrix.min()
    max_val = matrix.max()
    span = max_val - min_val
    if span == 0:
        return np.zeros_like(matrix, dtype=np.result_type(matrix, 1.0))
    return (matrix - min_val) / span

# Min-max scaled copies of both EDMs (used for the heatmaps).
gems_edm_norm, gt_edm_norm = (normalize_matrix(_m) for _m in (gems_edm, gt_edm))

# ===================================================================
# STEP 7: QUANTITATIVE COMPARISON
# ===================================================================
_rule = "=" * 70
print("\n" + _rule)
print("QUANTITATIVE COMPARISON")
print(_rule)

# Condensed upper triangle (k=1 skips the zero diagonal) of both EDMs.
triu_indices = np.triu_indices(n_cells, k=1)
gt_distances_flat, gems_distances_flat = gt_edm[triu_indices], gems_edm[triu_indices]

# Put GEMS distances on the GT scale by matching medians. Pearson and Spearman
# are invariant to a positive rescale, so this does not change the
# correlations; it only aligns the scales for the plots below.
scale = np.median(gt_distances_flat) / np.median(gems_distances_flat)
gems_distances_flat_scaled = gems_distances_flat * scale

print(f"\nScale factor (median matching): {scale:.4f}")

pearson_corr, _ = pearsonr(gt_distances_flat, gems_distances_flat_scaled)
spearman_corr, _ = spearmanr(gt_distances_flat, gems_distances_flat_scaled)

print(f"\nPearson Correlation: {pearson_corr:.4f}")
print(f"Spearman Correlation: {spearman_corr:.4f}")
print("-" * 70)

# ===================================================================
# STEP 8: VISUALIZATIONS
# ===================================================================
print("\n--- Generating Visualizations ---")

# --- PLOT 1: Side-by-Side Heatmaps ---
fig, axes = plt.subplots(1, 2, figsize=(16, 7))
fig.suptitle('EDM Comparison: Ground Truth vs. GEMS (Single Patch, Raw EDM)', 
             fontsize=18, fontweight='bold')

# Random subset of at most 838 cells for the heatmaps. The global NumPy RNG is
# not seeded here, so the subset differs between runs.
# NOTE(review): 838 looks dataset-specific — confirm it is intended.
sample_size = min(838, n_cells)
sample_indices = np.random.choice(n_cells, sample_size, replace=False)
# Sorting keeps the original cell order along both heatmap axes.
sample_indices = np.sort(sample_indices)

# np.ix_ selects the same index set for rows and columns, so both panels show
# the identical sub-matrix of cell pairs.
im1 = axes[0].imshow(gt_edm_norm[np.ix_(sample_indices, sample_indices)], cmap='viridis')
axes[0].set_title('Ground Truth EDM (Normalized)', fontsize=14)
axes[0].set_xlabel('Cell Index (Sampled)')
axes[0].set_ylabel('Cell Index (Sampled)')
fig.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)

im2 = axes[1].imshow(gems_edm_norm[np.ix_(sample_indices, sample_indices)], cmap='viridis')
axes[1].set_title('GEMS Predicted EDM (Normalized)', fontsize=14)
axes[1].set_xlabel('Cell Index (Sampled)')
fig.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)

# rect leaves the top 4% of the figure for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

# --- PLOT 2: Distribution of Distances ---
fig, ax = plt.subplots(figsize=(10, 6))
# Overlay GT and median-rescaled GEMS distance densities with shared settings.
for _vals, _color, _label in (
    (gt_distances_flat, "blue", 'Ground Truth Distances'),
    (gems_distances_flat_scaled, "red", 'GEMS Distances (Scaled)'),
):
    sns.histplot(_vals, color=_color, label=_label,
                 ax=ax, stat='density', bins=100, alpha=0.6)
ax.set_title('Distribution of Pairwise Distances (Single Patch Mode)', fontsize=16, fontweight='bold')
ax.set_xlabel('Distance', fontsize=12)
ax.set_ylabel('Density', fontsize=12)
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# --- PLOT 3: Scatter Plot of Distances ---
# Plot at most 50k cell pairs (unseeded global RNG) to keep the scatter light.
sample_size_scatter = min(50000, len(gt_distances_flat))
sample_indices_scatter = np.random.choice(len(gt_distances_flat), sample_size_scatter, replace=False)

fig, ax = plt.subplots(figsize=(8, 8))
ax.scatter(
    gt_distances_flat[sample_indices_scatter],
    gems_distances_flat_scaled[sample_indices_scatter],
    alpha=0.2, s=5, color='steelblue'
)
ax.set_title(f'GEMS vs. Ground Truth Distances (Single Patch)\nSpearman œÅ = {spearman_corr:.4f}', 
             fontsize=16, fontweight='bold')
ax.set_xlabel('Ground Truth Pairwise Distance', fontsize=12)
ax.set_ylabel('GEMS Pairwise Distance (Scaled)', fontsize=12)
ax.grid(True, linestyle='--', alpha=0.5)

# y = x reference spanning the union of both axis ranges; points on it mean
# the median-scaled GEMS distance equals the ground-truth distance.
lims = [
    min(ax.get_xlim()[0], ax.get_ylim()[0]),
    max(ax.get_xlim()[1], ax.get_ylim()[1]),
]
ax.plot(lims, lims, 'r--', alpha=0.75, linewidth=2, zorder=0, label='Ideal Correlation')
ax.set_aspect('equal', adjustable='box')
ax.legend(fontsize=12)
plt.tight_layout()
plt.show()

# --- PLOT 4: Coordinate Comparison ---
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
fig.suptitle('Spatial Coordinates: Ground Truth vs. GEMS (Single Patch)', 
             fontsize=16, fontweight='bold')

# Left: ground-truth layout; right: GEMS canonical coordinates.
# NOTE(review): the GEMS frame is canonicalized, so orientation/scale may
# differ from GT — compare shapes rather than axis values (confirm).
_coord_panels = (
    (gt_coords, 'blue', 'Ground Truth Coordinates'),
    (coords_canon, 'red', 'GEMS Predicted Coordinates'),
)
for _ax, (_xy, _color, _title) in zip(axes, _coord_panels):
    _ax.scatter(_xy[:, 0], _xy[:, 1], s=5, alpha=0.6, color=_color)
    _ax.set_title(_title, fontsize=14)
    _ax.set_xlabel('X', fontsize=12)
    _ax.set_ylabel('Y', fontsize=12)
    _ax.set_aspect('equal')
    _ax.grid(True, alpha=0.3)

plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

# --- PLOT 5: Distance Error Distribution ---
# Per-pair absolute error after the median-scale alignment.
distance_errors = np.abs(gt_distances_flat - gems_distances_flat_scaled)
_err_median = np.median(distance_errors)
_err_mean = np.mean(distance_errors)

fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(distance_errors, bins=100, kde=True, ax=ax, color='purple')
ax.set_title('Distance Prediction Error Distribution', fontsize=16, fontweight='bold')
ax.set_xlabel('Absolute Error |GT - GEMS|', fontsize=12)
ax.set_ylabel('Count', fontsize=12)
ax.axvline(_err_median, color='r', linestyle='--', linewidth=2,
           label=f'Median Error: {_err_median:.4f}')
ax.axvline(_err_mean, color='g', linestyle='--', linewidth=2,
           label=f'Mean Error: {_err_mean:.4f}')
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# ===================================================================
# STEP 9: SAVE RESULTS
# ===================================================================
print("\n--- Saving Results ---")

# Fresh wall-clock stamp so each run writes distinct filenames (to the second).
new_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_suffix = f"single_patch_{new_timestamp}"

# NOTE: this dict is assembled but the torch.save below is commented out, so
# it is not written to disk; only the .h5ad file further down is saved.
results_processed = {
    'D_edm': gems_edm,  # RAW EDM (no projection, no rescaling)
    'coords': results['coords'].cpu().numpy(),
    'coords_canon': coords_canon,
    'n_cells': n_cells,
    'timestamp': new_timestamp,
    'mode': 'single_patch_no_projection',
    'scale_factor': scale,
    'pearson_corr': pearson_corr,
    'spearman_corr': spearman_corr,
    # Duplicates the GEMSModel constructor arguments above — keep in sync.
    'model_config': {
        'n_genes': n_genes,
        'D_latent': 32,
        'c_dim': 256,
    }
}

processed_path = os.path.join(output_dir, f"sc_inference_processed_{output_suffix}.pt")
# torch.save(results_processed, processed_path)
# print(f"‚úì Saved: {processed_path}")

# Attach the predicted canonical coordinates to the SC AnnData and write it.
scadata.obsm['X_gems'] = coords_canon
adata_path = os.path.join(output_dir, f"scadata_with_gems_{output_suffix}.h5ad")
scadata.write_h5ad(adata_path)
print(f"‚úì Saved: {adata_path}")

print("\n" + "="*70)
print("SINGLE PATCH DIAGNOSTIC COMPLETE")
print("="*70)
print(f"\nResults Summary:")
print(f"  Mode: Single patch (patch_size={n_cells})")
print(f"  EDM: Raw (no projection, no rescaling)")
print(f"  Pearson: {pearson_corr:.4f}")
print(f"  Spearman: {spearman_corr:.4f}")
print(f"  Scale factor: {scale:.4f}")
print(f"  Output timestamp: {output_suffix}")

In [None]:
import torch
import numpy as np
import pandas as pd
import anndata as ad
import matplotlib.pyplot as plt
import seaborn as sns
import utils_et as uet  # Ensure this is in your python path

# 1. Load the raw ST spot coordinates.
print("Loading ST Data...")
st_counts = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_counts_et.csv'
st_meta   = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_metadata_et.csv'

st_meta_df = pd.read_csv(st_meta, index_col=0)
raw_coords = st_meta_df[['coord_x', 'coord_y']].values
st_coords_tensor = torch.tensor(raw_coords, dtype=torch.float32)

# 2. Canonicalize with the same helper the training script uses. There is a
#    single slide, so every spot gets slide id 0.
print("Applying Global RMS Normalization...")
slide_ids = torch.zeros(st_coords_tensor.shape[0], dtype=torch.long)
norm_coords, mu, scale = uet.canonicalize_st_coords_per_slide(st_coords_tensor, slide_ids)

norm_coords = norm_coords.numpy()
print(f"Normalization Scale Factor used: {scale[0].item():.4f}")

# 3. Radial statistics: distance of each spot from the origin and how many
#    fall outside the unit circle.
radii = np.sqrt((norm_coords ** 2).sum(axis=1))
points_outside = (radii > 1.0).sum()
pct_outside = points_outside / len(radii) * 100

_sep = "-" * 40
print(_sep)
print(f"Total Points: {len(radii)}")
print(f"Points outside Unit Circle (Radius > 1.0): {points_outside}")
print(f"Percentage outside: {pct_outside:.2f}%")
print(f"Max Radius: {radii.max():.4f}")
print(_sep)

# 4. Visualization: normalized geometry (left) and radius histogram (right).
fig, axes = plt.subplots(1, 2, figsize=(16, 7))
ax_geom, ax_radii = axes

# Left panel: normalized spots with the unit circle overlaid.
ax_geom.scatter(norm_coords[:, 0], norm_coords[:, 1], s=5, alpha=0.6, c='steelblue', label='ST Cells')
circle = plt.Circle((0, 0), 1.0, color='red', fill=False, linestyle='--', linewidth=2, label='Unit RMS Circle')
ax_geom.add_patch(circle)
ax_geom.set_title(f"Normalized ST Data\n({pct_outside:.1f}% points outside red circle)", fontsize=14)
ax_geom.set_aspect('equal')
ax_geom.legend()
ax_geom.grid(True, alpha=0.3)

# Right panel: distribution of radii with the unit radius marked.
sns.histplot(radii, bins=50, ax=ax_radii, kde=True, color='purple')
ax_radii.axvline(1.0, color='red', linestyle='--', linewidth=2, label='Radius = 1.0')
ax_radii.set_title("Distribution of Radii from Center", fontsize=14)
ax_radii.set_xlabel("Distance from Center")
ax_radii.legend()

plt.tight_layout()
plt.show()

In [None]:
# ===================================================================
# COMPLETE NOTEBOOK: ST-ONLY MODEL (PHASE 1) - SINGLE PATCH INFERENCE
# ===================================================================
import torch
import numpy as np
import scanpy as sc
from scipy.spatial.distance import cdist, pdist, squareform
from scipy.stats import pearsonr, spearmanr
import matplotlib.pyplot as plt
import seaborn as sns
import os
from datetime import datetime

# Setup: output locations.
output_dir = "/home/ehtesamul/sc_st/model/gems_mousebrain_output"
timestamp = "20251128_100055"

# Phase-1 checkpoint: trained on ST only, before SC fine-tuning.
checkpoint_path = f"{output_dir}/phase1_st_checkpoint.pt"

_banner = "=" * 70
print(_banner)
print("ST-ONLY MODEL INFERENCE (Phase 1, Single Patch)")
print(_banner)

# ===================================================================
# STEP 1: LOAD TEST DATA
# ===================================================================
print("\n--- Loading Test Data ---")

from run_mouse_brain_2 import load_mouse_data
scadata, stadata = load_mouse_data()

# Keep only genes shared by both modalities (sorted for a stable column
# order) and densify the SC matrix if it is stored sparse.
common = sorted(set(scadata.var_names).intersection(stadata.var_names))
X_sc = scadata[:, common].X
X_sc = X_sc.toarray() if hasattr(X_sc, "toarray") else X_sc
sc_expr = torch.tensor(X_sc, dtype=torch.float32)

n_cells, n_genes = sc_expr.shape

print(f"Loaded SC data: {n_cells} cells, {n_genes} genes")
print(f"Ground truth coords shape: {scadata.obsm['spatial_gt'].shape}")

# ===================================================================
# STEP 2: LOAD MODEL AND ST-ONLY CHECKPOINT (PHASE 1)
# ===================================================================
print("\n--- Loading Model and ST-Only Checkpoint (Phase 1) ---")

from core_models_et_p3 import GEMSModel

# Same fallback rule as the setup cell: hard-coding 'cuda' made both model
# construction and torch.load(map_location=...) fail on CPU-only hosts.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = GEMSModel(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    D_latent=32,
    c_dim=256,
    n_heads=4,
    isab_m=64,
    device=device,
    use_canonicalize=True,
    use_dist_bias=True,
    dist_bins=24,
    dist_head_shared=True,
    use_angle_features=True,
    angle_bins=8,
    knn_k=12,
    self_conditioning=True,
    sc_feat_mode='concat',
    landmarks_L=16
)

checkpoint = torch.load(checkpoint_path, map_location=device)
# Restore each sub-network from its per-module state dict in the checkpoint.
for _name in ('encoder', 'context_encoder', 'generator', 'score_net'):
    _module = getattr(model, _name)
    _module.load_state_dict(checkpoint[_name])
    # Inference only: switch off training-mode layers (dropout / batch-norm
    # updates, if any). Harmless if infer_sc_patchwise already does this.
    _module.eval()

print(f"✓ Loaded ST-ONLY checkpoint from: {checkpoint_path}")
print(f"  Best ST epoch: {checkpoint.get('E_ST_best', 'N/A')}")
print(f"  This model was trained ONLY on ST data (NO SC fine-tuning)")

# ===================================================================
# STEP 3: SINGLE PATCH INFERENCE (DIAGNOSTIC MODE)
# ===================================================================
for _line in ("\n--- Running Single Patch Inference (ST-Only Model) ---",
              f"Config: patch_size={n_cells}, coverage_per_cell=1.0, n_align_iters=1",
              "This runs ONE patch with ALL cells (no stitching)",
              "-" * 70):
    print(_line)

# Sampler settings in one place. patch_size=n_cells puts every cell in a
# single patch (coverage 1.0, no overlap), so there is nothing to stitch.
single_patch_kwargs = dict(
    sc_gene_expr=sc_expr,
    n_timesteps_sample=600,
    sigma_min=0.01,
    sigma_max=7.0,
    patch_size=n_cells,
    coverage_per_cell=1.0,
    n_align_iters=1,
    eta=0.0,
    guidance_scale=5.0,
    return_coords=True,
    debug_flag=True,
    debug_every=10,
)
results = model.infer_sc_patchwise(**single_patch_kwargs)

print("\n‚úì Inference complete")

# ===================================================================
# STEP 4: EXTRACT RAW EDM (NO PROJECTION, NO RESCALING)
# ===================================================================
print("\n--- Computing Raw EDM (No Post-Processing) ---")

# Canonicalized coordinates from the sampler, moved to host memory.
coords_canon = results['coords_canon'].cpu().numpy()

# Dense pairwise-distance matrix straight from the coordinates
# (no edm_project, no rescaling).
gems_edm = cdist(coords_canon, coords_canon, metric='euclidean')

# Strictly positive entries: drops the diagonal and any coincident pairs.
_gems_pos = gems_edm[gems_edm > 0]
print(f"Raw EDM shape: {gems_edm.shape}")
print("Raw EDM stats:")
print(f"  Min: {_gems_pos.min():.4f}")
print(f"  Median: {np.median(_gems_pos):.4f}")
print(f"  Max: {gems_edm.max():.4f}")
print(f"  Mean: {_gems_pos.mean():.4f}")

# ===================================================================
# STEP 5: COMPUTE GROUND TRUTH EDM
# ===================================================================
print("\n--- Calculating Ground Truth EDM ---")

gt_coords = scadata.obsm['spatial_gt']
# Expand the condensed pdist vector into the full square matrix.
gt_edm = squareform(pdist(gt_coords, 'euclidean'))

# Strictly positive entries: drops the diagonal and any coincident pairs.
_gt_pos = gt_edm[gt_edm > 0]
print(f"Ground Truth EDM shape: {gt_edm.shape}")
print("Ground Truth EDM stats:")
print(f"  Min: {_gt_pos.min():.4f}")
print(f"  Median: {np.median(_gt_pos):.4f}")
print(f"  Max: {gt_edm.max():.4f}")
print(f"  Mean: {_gt_pos.mean():.4f}")

# ===================================================================
# STEP 6: NORMALIZE FOR COMPARISON
# ===================================================================
def normalize_matrix(matrix):
    """Min-max scale ``matrix`` to the range [0, 1].

    Args:
        matrix: numpy array of any shape.

    Returns:
        Array of the same shape with the minimum mapped to 0 and the maximum to
        1. A constant matrix (max == min) has no range to scale by, so it is
        mapped to all zeros. Without this, the division by zero would produce
        NaNs and a RuntimeWarning.
    """
    min_val = matrix.min()
    value_range = matrix.max() - min_val
    if value_range == 0:
        return np.zeros_like(matrix, dtype=float)
    return (matrix - min_val) / value_range

# Min-max normalized copies of both EDMs (each scaled to [0, 1] independently);
# these are what the heatmaps in Step 8 display.
gems_edm_norm = normalize_matrix(gems_edm)
gt_edm_norm = normalize_matrix(gt_edm)

# ===================================================================
# STEP 7: QUANTITATIVE COMPARISON
# ===================================================================
print("\n" + "="*70)
print("QUANTITATIVE COMPARISON (ST-ONLY MODEL)")
print("="*70)

# Extract upper triangle (excluding diagonal): each unordered pair counted once.
# NOTE(review): assumes gt_edm and gems_edm are both n_cells x n_cells with rows
# in the same cell order — confirm.
triu_indices = np.triu_indices(n_cells, k=1)
gt_distances_flat = gt_edm[triu_indices]
gems_distances_flat = gems_edm[triu_indices]

# Scale alignment (median matching): rescale GEMS distances so their median
# equals the ground-truth median. Pearson and Spearman are unchanged by a
# positive rescale, so this only affects the histogram/scatter/error plots.
scale = np.median(gt_distances_flat) / np.median(gems_distances_flat)
gems_distances_flat_scaled = gems_distances_flat * scale

print(f"\nScale factor (median matching): {scale:.4f}")

# Calculate correlations
pearson_corr, _ = pearsonr(gt_distances_flat, gems_distances_flat_scaled)
spearman_corr, _ = spearmanr(gt_distances_flat, gems_distances_flat_scaled)

print(f"\nPearson Correlation: {pearson_corr:.4f}")
print(f"Spearman Correlation: {spearman_corr:.4f}")
print("-"*70)

# ===================================================================
# STEP 8: VISUALIZATIONS
# ===================================================================
print("\n--- Generating Visualizations ---")

# --- PLOT 1: Side-by-Side Heatmaps ---
# Both panels show min-max normalized EDMs, so they are on a comparable
# [0, 1] scale regardless of the coordinate units of each source.
fig, axes = plt.subplots(1, 2, figsize=(16, 7))
fig.suptitle('EDM Comparison: Ground Truth vs. GEMS (ST-Only Model, Single Patch)', 
             fontsize=18, fontweight='bold')

# Random subset of cells (at most 838) shown in both heatmaps. Indices are
# sorted so rows/columns keep the original cell order.
# NOTE(review): np.random is not seeded here, so the subset changes between
# runs; 838 looks like a hard-coded cap — confirm whether it should be tied
# to n_cells or a parameter.
sample_size = min(838, n_cells)
sample_indices = np.random.choice(n_cells, sample_size, replace=False)
sample_indices = np.sort(sample_indices)

im1 = axes[0].imshow(gt_edm_norm[np.ix_(sample_indices, sample_indices)], cmap='viridis')
axes[0].set_title('Ground Truth EDM (Normalized)', fontsize=14)
axes[0].set_xlabel('Cell Index (Sampled)')
axes[0].set_ylabel('Cell Index (Sampled)')
fig.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)

im2 = axes[1].imshow(gems_edm_norm[np.ix_(sample_indices, sample_indices)], cmap='viridis')
axes[1].set_title('GEMS Predicted EDM (ST-Only, Normalized)', fontsize=14)
axes[1].set_xlabel('Cell Index (Sampled)')
fig.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)

plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

# --- PLOT 2: Distribution of Distances ---
# GEMS distances are the median-matched values from Step 7, so both
# histograms share the ground-truth distance units.
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(gt_distances_flat, color="blue", label='Ground Truth Distances', 
             ax=ax, stat='density', bins=100, alpha=0.6)
sns.histplot(gems_distances_flat_scaled, color="orange", label='GEMS Distances (ST-Only, Scaled)', 
             ax=ax, stat='density', bins=100, alpha=0.6)
ax.set_title('Distribution of Pairwise Distances (ST-Only Model)', fontsize=16, fontweight='bold')
ax.set_xlabel('Distance', fontsize=12)
ax.set_ylabel('Density', fontsize=12)
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# --- PLOT 3: Scatter Plot of Distances ---
# Subsample at most 50k pairs so the scatter stays responsive for large n_cells.
sample_size_scatter = min(50000, len(gt_distances_flat))
sample_indices_scatter = np.random.choice(len(gt_distances_flat), sample_size_scatter, replace=False)

fig, ax = plt.subplots(figsize=(8, 8))
ax.scatter(
    gt_distances_flat[sample_indices_scatter],
    gems_distances_flat_scaled[sample_indices_scatter],
    alpha=0.2, s=5, color='orange'
)
# The title previously contained the mis-encoded "œÅ" instead of the Greek rho.
ax.set_title(f'GEMS vs. Ground Truth Distances (ST-Only Model)\nSpearman ρ = {spearman_corr:.4f}',
             fontsize=16, fontweight='bold')
ax.set_xlabel('Ground Truth Pairwise Distance', fontsize=12)
ax.set_ylabel('GEMS Pairwise Distance (Scaled)', fontsize=12)
ax.grid(True, linestyle='--', alpha=0.5)

# y = x reference line spanning the union of the current x and y ranges.
lims = [
    min(ax.get_xlim()[0], ax.get_ylim()[0]),
    max(ax.get_xlim()[1], ax.get_ylim()[1]),
]
ax.plot(lims, lims, 'r--', alpha=0.75, linewidth=2, zorder=0, label='Ideal Correlation')
ax.set_aspect('equal', adjustable='box')
ax.legend(fontsize=12)
plt.tight_layout()
plt.show()

# --- PLOT 4: Coordinate Comparison ---
# Raw side-by-side view. No rotation/reflection/scale alignment is applied
# between the ground-truth and predicted layouts, so compare overall shape,
# not absolute positions.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
fig.suptitle('Spatial Coordinates: Ground Truth vs. GEMS (ST-Only Model)', 
             fontsize=16, fontweight='bold')

axes[0].scatter(gt_coords[:, 0], gt_coords[:, 1], s=5, alpha=0.6, color='blue')
axes[0].set_title('Ground Truth Coordinates', fontsize=14)
axes[0].set_xlabel('X', fontsize=12)
axes[0].set_ylabel('Y', fontsize=12)
axes[0].set_aspect('equal')
axes[0].grid(True, alpha=0.3)

axes[1].scatter(coords_canon[:, 0], coords_canon[:, 1], s=5, alpha=0.6, color='orange')
axes[1].set_title('GEMS Predicted Coordinates (ST-Only)', fontsize=14)
axes[1].set_xlabel('X', fontsize=12)
axes[1].set_ylabel('Y', fontsize=12)
axes[1].set_aspect('equal')
axes[1].grid(True, alpha=0.3)

plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

# --- PLOT 5: Distance Error Distribution ---
# Absolute per-pair error after median-matching GEMS distances to GT scale.
distance_errors = np.abs(gt_distances_flat - gems_distances_flat_scaled)
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(distance_errors, bins=100, kde=True, ax=ax, color='orange')
ax.set_title('Distance Prediction Error Distribution (ST-Only Model)', fontsize=16, fontweight='bold')
ax.set_xlabel('Absolute Error |GT - GEMS|', fontsize=12)
ax.set_ylabel('Count', fontsize=12)
ax.axvline(np.median(distance_errors), color='r', linestyle='--', linewidth=2, 
           label=f'Median Error: {np.median(distance_errors):.4f}')
ax.axvline(np.mean(distance_errors), color='g', linestyle='--', linewidth=2, 
           label=f'Mean Error: {np.mean(distance_errors):.4f}')
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# ===================================================================
# STEP 9: SAVE RESULTS
# ===================================================================
print("\n--- Saving Results ---")

new_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_suffix = f"st_only_single_patch_{new_timestamp}"

# Processed bundle intended for torch.save. Saving is currently disabled
# below, so this dict is only kept in memory.
results_processed = {
    'D_edm': gems_edm,  # RAW EDM (no projection, no rescaling)
    'coords': results['coords'].cpu().numpy(),
    'coords_canon': coords_canon,
    'n_cells': n_cells,
    'timestamp': new_timestamp,
    'mode': 'st_only_single_patch_no_projection',
    'scale_factor': scale,
    'pearson_corr': pearson_corr,
    'spearman_corr': spearman_corr,
    'model_config': {
        'n_genes': n_genes,
        'D_latent': 32,
        'c_dim': 256,
        'phase': 'ST-only (Phase 1)',
    }
}

processed_path = os.path.join(output_dir, f"sc_inference_processed_{output_suffix}.pt")
# torch.save(results_processed, processed_path)
# print(f"✓ Saved: {processed_path}")

# Attach the predicted coordinates to the AnnData and write it out. Ensure the
# output directory exists first so the write does not fail on a fresh setup.
scadata.obsm['X_gems_st_only'] = coords_canon
os.makedirs(output_dir, exist_ok=True)
adata_path = os.path.join(output_dir, f"scadata_with_gems_{output_suffix}.h5ad")
scadata.write_h5ad(adata_path)
print(f"✓ Saved: {adata_path}")

print("\n" + "="*70)
print("ST-ONLY MODEL DIAGNOSTIC COMPLETE")
print("="*70)
print("\nResults Summary:")
print("  Model: ST-Only (Phase 1, BEFORE SC fine-tuning)")
print(f"  Mode: Single patch (patch_size={n_cells})")
print("  EDM: Raw (no projection, no rescaling)")
print(f"  Pearson: {pearson_corr:.4f}")
print(f"  Spearman: {spearman_corr:.4f}")
print(f"  Scale factor: {scale:.4f}")
print(f"  Output timestamp: {output_suffix}")
print("\nThis tells you if ring collapse happens during:")
print("  - ST-only training (Phase 1) → if you see ring now")
print("  - SC fine-tuning (Phase 2) → if you saw ring only with fine-tuned model")

In [None]:
# ===================================================================
# TIMESTEP-BY-TIMESTEP DIFFUSION VISUALIZATION
# ===================================================================
import torch
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import os
from datetime import datetime

# Setup
output_dir = "/home/ehtesamul/sc_st/model/gems_mousebrain_output"
timestamp = "20251128_100055"  # NOTE(review): not referenced in this cell
checkpoint_path = f"{output_dir}/phase2_sc_finetuned_checkpoint.pt"

print("="*70)
print("DIFFUSION TIMESTEP VISUALIZATION (Single Patch)")
print("="*70)

# ===================================================================
# STEP 1: LOAD TEST DATA
# ===================================================================
print("\n--- Loading Test Data ---")

from run_mouse_brain_2 import load_mouse_data
scadata, stadata = load_mouse_data()

# SC expression restricted to genes shared with ST, in sorted gene order.
common = sorted(list(set(scadata.var_names) & set(stadata.var_names)))
X_sc = scadata[:, common].X
if hasattr(X_sc, "toarray"):
    X_sc = X_sc.toarray()
sc_expr = torch.tensor(X_sc, dtype=torch.float32)

n_cells = sc_expr.shape[0]
n_genes = sc_expr.shape[1]

print(f"Loaded SC data: {n_cells} cells, {n_genes} genes")

# ===================================================================
# STEP 2: LOAD MODEL
# ===================================================================
print("\n--- Loading Model ---")

from core_models_et_p3 import GEMSModel
import utils_et as uet

model = GEMSModel(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    D_latent=32,
    c_dim=256,
    n_heads=4,
    isab_m=64,
    device='cuda',
    use_canonicalize=True,
    use_dist_bias=True,
    dist_bins=24,
    dist_head_shared=True,
    use_angle_features=True,
    angle_bins=8,
    knn_k=12,
    self_conditioning=True,
    sc_feat_mode='concat',
    landmarks_L=16
)

# Restore each sub-module from the Phase-2 (SC fine-tuned) checkpoint.
checkpoint = torch.load(checkpoint_path, map_location='cuda')
model.encoder.load_state_dict(checkpoint['encoder'])
model.context_encoder.load_state_dict(checkpoint['context_encoder'])
model.generator.load_state_dict(checkpoint['generator'])
model.score_net.load_state_dict(checkpoint['score_net'])

print("✓ Loaded checkpoint")

# ===================================================================
# STEP 3: INLINE DIFFUSION SAMPLER WITH TIMESTEP CAPTURE
# ===================================================================
print("\n--- Running Diffusion with Timestep Capture ---")

device = 'cuda'
n_timesteps_sample = 600
sigma_min = 0.01
sigma_max = 7.0
guidance_scale = 2.0
D_latent = 32

model.encoder.eval()
model.context_encoder.eval()
model.score_net.eval()

print(f"Config: n_timesteps={n_timesteps_sample}, guidance_scale={guidance_scale}")
print(f"        sigma_min={sigma_min}, sigma_max={sigma_max}")

# Encode all SC cells and build the set context inside one no-grad scope.
# Previously the context encoder ran with autograd enabled. That kept an
# activation graph over all n_cells attached to H for the whole sampling loop.
print("\n[1/4] Encoding SC cells...")
with torch.no_grad():
    Z_all = model.encoder(sc_expr.to(device))  # (n_cells, hidden_dim)

    print("[2/4] Computing context...")
    Z_batch = Z_all.unsqueeze(0)  # (1, n_cells, hidden_dim)
    mask = torch.ones(1, n_cells, dtype=torch.bool, device=device)
    H = model.context_encoder(Z_batch, mask)  # (1, n_cells, c_dim)

# Log-spaced noise levels, decreasing from sigma_max (index 0) to sigma_min.
sigmas = torch.exp(torch.linspace(
    torch.log(torch.tensor(sigma_max, device=device)),
    torch.log(torch.tensor(sigma_min, device=device)),
    n_timesteps_sample,
    device=device,
))

# Initialize noise at the largest sigma
print("[3/4] Running reverse diffusion...")
V_t = torch.randn(1, n_cells, D_latent, device=device) * sigmas[0]

# Timesteps to save
save_timesteps = [0, 100, 200, 300, 400, 500, 599]
saved_samples = {}

# Null (unconditional) context for classifier-free guidance. It is
# loop-invariant, so build it once instead of once per step.
H_null = torch.zeros_like(H)

with torch.no_grad():
    for t_idx in range(n_timesteps_sample):
        sigma_t = sigmas[t_idx]
        # NOTE(review): t_norm runs 0 -> 1 while sigma runs sigma_max -> sigma_min,
        # i.e. t_norm == 0 at the noisiest step. Confirm this matches the time
        # convention score_net was trained with.
        t_norm = torch.tensor([[t_idx / float(n_timesteps_sample - 1)]], device=device)

        # CFG: blend unconditional and conditional noise predictions
        eps_uncond = model.score_net(V_t, t_norm, H_null, mask)
        eps_cond = model.score_net(V_t, t_norm, H, mask)
        eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

        # Deterministic update: predict V_0, then re-noise to the next sigma.
        # (Algebraically V_next = V_0_pred + sigma_next * eps.)
        # The last step returns V_0_pred itself.
        if t_idx < n_timesteps_sample - 1:
            sigma_next = sigmas[t_idx + 1]
            V_0_pred = V_t - sigma_t * eps
            V_t = V_0_pred + (sigma_next / sigma_t) * (V_t - V_0_pred)
        else:
            V_t = V_t - sigma_t * eps

        # Snapshot the state *after* the step taken from sigmas[t_idx]
        if t_idx in save_timesteps:
            # Canonicalize the current sample
            V_canon = uet.canonicalize_coords(V_t.squeeze(0))
            saved_samples[t_idx] = V_canon.cpu().numpy()
            print(f"  Saved timestep {t_idx}/{n_timesteps_sample-1}")

# Final sample
V_final = V_t.squeeze(0)  # (n_cells, D_latent)
V_final_canon = uet.canonicalize_coords(V_final)
coords_final = V_final_canon.cpu().numpy()

print("[4/4] Complete!")

# ===================================================================
# STEP 4: CONVERT TO 2D COORDINATES VIA MDS
# ===================================================================
print("\n--- Converting to 2D coordinates ---")

def latent_to_2d(V_latent):
    """Convert latent coordinates to a canonicalized 2-D layout via classical MDS.

    Args:
        V_latent: array of shape (n, D) (as saved by the sampler above,
            i.e. (n_cells, D_latent)).

    Returns:
        numpy array of shape (n, 2) from ``uet.classical_mds`` followed by
        ``uet.canonicalize_coords``.
    """
    V_tensor = torch.tensor(V_latent, dtype=torch.float32)

    # Classical MDS double-centres the squared EDM: B = -0.5 * J D^2 J.
    # When D comes from coordinates V, this equals the Gram matrix of the
    # mean-centred coordinates, so build B directly. This avoids the cdist
    # sqrt/square round trip. It also avoids two dense n x n x n products
    # with the centring matrix J (O(n^3)), replacing them with O(n^2 * D).
    V_centered = V_tensor - V_tensor.mean(dim=0, keepdim=True)
    B = V_centered @ V_centered.T

    # Extract 2D coordinates
    coords_2d = uet.classical_mds(B, d_out=2).numpy()
    coords_2d = uet.canonicalize_coords(torch.tensor(coords_2d)).numpy()

    return coords_2d

coords_at_timesteps = {}
for t_idx, V in saved_samples.items():
    coords_at_timesteps[t_idx] = latent_to_2d(V)
    print(f"  Converted timestep {t_idx} to 2D")

# ===================================================================
# STEP 5: VISUALIZE DIFFUSION EVOLUTION
# ===================================================================
print("\n--- Generating Visualizations ---")

# Ground truth for reference
gt_coords = scadata.obsm['spatial_gt']

# Plot grid: ground truth in the first panel, then one panel per saved timestep
n_plots = len(save_timesteps) + 1
n_cols = 4
n_rows = int(np.ceil(n_plots / n_cols))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 5*n_rows))
axes = axes.flatten()

# Plot ground truth
axes[0].scatter(gt_coords[:, 0], gt_coords[:, 1], s=5, alpha=0.6, c='blue')
axes[0].set_title('Ground Truth', fontsize=14, fontweight='bold')
axes[0].set_aspect('equal')
axes[0].grid(True, alpha=0.3)

# Plot diffusion timesteps
for idx, t_idx in enumerate(save_timesteps):
    ax = axes[idx + 1]
    coords = coords_at_timesteps[t_idx]

    ax.scatter(coords[:, 0], coords[:, 1], s=5, alpha=0.6, c='red')
    # σ is the noise level the step at t_idx started from; the plotted state is
    # the sampler output after that step. (The title previously showed the
    # mis-encoded "œÉ" instead of the Greek sigma.)
    ax.set_title(f'Timestep {t_idx}/{n_timesteps_sample-1}\n(σ={sigmas[t_idx]:.4f})',
                 fontsize=12, fontweight='bold')
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)

# Hide unused subplots
for idx in range(n_plots, len(axes)):
    axes[idx].axis('off')

plt.suptitle(f'Diffusion Evolution (guidance_scale={guidance_scale}, n_timesteps={n_timesteps_sample})',
             fontsize=18, fontweight='bold', y=0.995)
plt.tight_layout(rect=[0, 0, 1, 0.99])
plt.show()

# ===================================================================
# ADDITIONAL PLOT: SIDE-BY-SIDE EVOLUTION
# ===================================================================
fig, axes = plt.subplots(2, 4, figsize=(20, 10))

# Lay the saved timesteps out on the 2x4 grid: first 4 on top, up to 3 below,
# ground truth in the last slot. Deriving them from save_timesteps keeps this
# plot in sync with the capture list. Hard-coded steps would raise a KeyError
# as soon as save_timesteps changes.
top_timesteps = save_timesteps[:4]
bottom_timesteps = save_timesteps[4:7]


def _plot_timestep(ax, t_idx):
    """Scatter the 2-D embedding captured at ``t_idx`` onto ``ax``."""
    coords_t = coords_at_timesteps[t_idx]
    ax.scatter(coords_t[:, 0], coords_t[:, 1], s=5, alpha=0.6, c='red')
    # Title previously contained the mis-encoded "œÉ" instead of sigma.
    ax.set_title(f't={t_idx} (σ={sigmas[t_idx]:.3f})', fontsize=12, fontweight='bold')
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)


# Top row: early timesteps
for idx, t_idx in enumerate(top_timesteps):
    _plot_timestep(axes[0, idx], t_idx)

# Bottom row: late timesteps
for idx, t_idx in enumerate(bottom_timesteps):
    _plot_timestep(axes[1, idx], t_idx)

# Ground truth in last position
axes[1, 3].scatter(gt_coords[:, 0], gt_coords[:, 1], s=5, alpha=0.6, c='blue')
axes[1, 3].set_title('Ground Truth', fontsize=12, fontweight='bold')
axes[1, 3].set_aspect('equal')
axes[1, 3].grid(True, alpha=0.3)

plt.suptitle('Diffusion Denoising Trajectory', fontsize=18, fontweight='bold')
plt.tight_layout(rect=[0, 0, 1, 0.97])
plt.show()

# ===================================================================
# STEP 6: QUANTIFY STRUCTURE COLLAPSE
# ===================================================================
print("\n--- Analyzing Structure Collapse ---")

def compute_pca_variance_ratio(coords):
    """Return the explained-variance ratios of the top two principal components.

    Args:
        coords: (n_points, n_features) array.

    Returns:
        Length-2 array: fraction of total variance along PC1 and PC2.
    """
    from sklearn.decomposition import PCA

    fitted = PCA(n_components=2).fit(coords)
    return fitted.explained_variance_ratio_

def compute_circularity(coords):
    """Score how ring-like a point cloud is: 1 minus the CV of centroid distances.

    Points all at the same distance from their centroid (a ring) score 1;
    the more spread out the radii, the lower the score.
    """
    radial = np.linalg.norm(coords - coords.mean(axis=0), axis=1)
    coefficient_of_variation = radial.std() / radial.mean()
    return 1.0 - coefficient_of_variation

# Per-timestep summary table. The PCA columns are the variance fractions of the
# first two principal axes. For these 2-D layouts they sum to 1, so a PCA-1
# value near 1 means the points lie close to a line. "Circularity" is
# 1 - std(radius)/mean(radius) about the centroid; values near 1 are ring-like.
print("\n{:<10} {:<15} {:<15} {:<15}".format("Timestep", "PCA-1 Var", "PCA-2 Var", "Circularity"))
print("-"*60)

for t_idx in save_timesteps:
    coords = coords_at_timesteps[t_idx]
    var_ratios = compute_pca_variance_ratio(coords)
    circ = compute_circularity(coords)
    print(f"{t_idx:<10} {var_ratios[0]:<15.4f} {var_ratios[1]:<15.4f} {circ:<15.4f}")

# Ground truth: same metrics on the reference layout, as the last row
gt_var_ratios = compute_pca_variance_ratio(gt_coords)
gt_circ = compute_circularity(gt_coords)
print(f"{'GT':<10} {gt_var_ratios[0]:<15.4f} {gt_var_ratios[1]:<15.4f} {gt_circ:<15.4f}")

print("\n" + "="*70)
print("TIMESTEP ANALYSIS COMPLETE")
print("="*70)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
import anndata as ad
import matplotlib.pyplot as plt
import os

from core_models_et_p3 import GEMSModel
from core_models_et_p1 import STSetDataset, collate_minisets
import utils_et as uet

# ============================================================================
# SETUP
# ============================================================================

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# ============================================================================
# LOAD DATA (from run_mouse_brain_2.py)
# ============================================================================

st_counts = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_counts_et.csv'
st_meta   = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_metadata_et.csv'
st_ct     = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_celltype_et.csv'

print("Loading ST1 (training ST data)...")
st_expr_df = pd.read_csv(st_counts, index_col=0)
st_meta_df = pd.read_csv(st_meta, index_col=0)
st_ct_df = pd.read_csv(st_ct, index_col=0)

# The counts CSV has genes as rows and spots as columns (see obs_names /
# var_names below), hence the transpose to spots x genes.
# NOTE(review): metadata and cell-type tables are attached positionally via
# .values, which assumes their row order matches the counts columns — confirm.
stadata = ad.AnnData(X=st_expr_df.values.T)
stadata.obs_names = st_expr_df.columns
stadata.var_names = st_expr_df.index
stadata.obsm['spatial'] = st_meta_df[['coord_x', 'coord_y']].values
# Label each spot with the cell-type column holding its largest value
stadata.obs['celltype_mapped_refined'] = st_ct_df.idxmax(axis=1).values
stadata.obsm['celltype_proportions'] = st_ct_df.values

print(f"ST1 loaded: {stadata.shape[0]} spots, {stadata.shape[1]} genes")

# Extract expression and coordinates
# NOTE(review): unlike the other cells, genes are NOT intersected with the SC
# gene set here. n_genes (= stadata.shape[1]) must match the encoder stored in
# ab_init.pt — confirm.
X_st = stadata.X
if hasattr(X_st, "toarray"):
    X_st = X_st.toarray()

st_expr = torch.tensor(X_st, dtype=torch.float32, device=device)
st_coords_raw = torch.tensor(stadata.obsm['spatial'], dtype=torch.float32, device=device)

# Apply per-slide canonicalization (same as training); all spots belong to slide 0
slide_ids = torch.zeros(st_expr.shape[0], dtype=torch.long, device=device)
st_coords, st_mu, st_scale = uet.canonicalize_st_coords_per_slide(
    st_coords_raw, slide_ids
)

print(f"ST coords canonicalized: scale={st_scale[0].item():.4f}")

# ============================================================================
# LOAD TRAINED ENCODER
# ============================================================================

outdir = '/home/ehtesamul/sc_st/model/gems_mousebrain_output'
checkpoint_path = os.path.join(outdir, 'ab_init.pt')

n_genes = stadata.shape[1]

# Create model with same config as run_mouse_brain_2.py
model = GEMSModel(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    D_latent=32,
    c_dim=256,
    n_heads=4,
    isab_m=64,
    device=str(device),
    use_canonicalize=True,
    use_dist_bias=True,
    dist_bins=24,
    dist_head_shared=True,
    use_angle_features=True,
    angle_bins=8,
    knn_k=12,
    self_conditioning=True,
    sc_feat_mode='concat',
    landmarks_L=16,
)

print(f"\nLoading checkpoint from: {checkpoint_path}")
ckpt = torch.load(checkpoint_path, map_location=device)
model.encoder.load_state_dict(ckpt['encoder'])
model.encoder.eval()
# eval() only switches layers such as dropout/batch-norm to inference mode; it
# does not stop gradient tracking. Disable requires_grad so the encoder is
# actually frozen while the supervised head trains. Otherwise backward passes
# would compute (and store) gradients for weights nothing ever updates.
model.encoder.requires_grad_(False)

print("Encoder loaded and frozen.")

# ============================================================================
# RUN STAGE B (takes ~3 seconds)
# ============================================================================

print("\n=== Running Stage B ===")
# One slide, keyed by slide id 0: (canonicalized coords, expression)
slides_dict = {0: (st_coords, st_expr)}
# Expected to populate model.targets_dict, which STSetDataset consumes below.
# NOTE(review): outdir is a relative path (resolved against the current
# working directory); presumably a cache location — confirm in train_stageB.
model.train_stageB(
    slides=slides_dict,
    outdir='temp_stageB_cache'
)

print("Stage B complete. targets_dict populated.")

# ============================================================================
# DEFINE SUPERVISED REGRESSION HEAD
# ============================================================================

# class SupervisedEDMHead(nn.Module):
#     """
#     Simple supervised head that predicts EDM from encoder embeddings.
    
#     Architecture:
#     Z (from encoder) -> MLP -> upper triangular EDM prediction
#     """
#     def __init__(self, h_dim: int, hidden_dim: int = 256):
#         super().__init__()
#         self.h_dim = h_dim
        
#         # MLP to predict pairwise distances
#         self.mlp = nn.Sequential(
#             nn.Linear(h_dim * 2, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, 1)
#         )
    
#     def forward(self, Z: torch.Tensor, mask: torch.Tensor):
#         """
#         Args:
#             Z: (batch, n, h_dim) encoder embeddings
#             mask: (batch, n) validity mask
            
#         Returns:
#             D_pred: (batch, n, n) predicted distance matrix
#         """
#         batch, n, h = Z.shape
        
#         # Create pairwise concatenations
#         Z_i = Z.unsqueeze(2).expand(-1, -1, n, -1)  # (batch, n, n, h)
#         Z_j = Z.unsqueeze(1).expand(-1, n, -1, -1)  # (batch, n, n, h)
#         Z_pairs = torch.cat([Z_i, Z_j], dim=-1)     # (batch, n, n, 2h)
        
#         # Predict distances
#         D_pred = self.mlp(Z_pairs).squeeze(-1)      # (batch, n, n)
#         D_pred = torch.relu(D_pred)                  # Ensure non-negative
        
#         # Symmetrize
#         D_pred = (D_pred + D_pred.transpose(-1, -2)) / 2.0
        
#         # Zero out diagonal
#         diag_mask = torch.eye(n, device=Z.device).unsqueeze(0).bool()
#         D_pred = D_pred.masked_fill(diag_mask, 0.0)
        
#         # Apply validity mask
#         valid_mask = mask.unsqueeze(-1) & mask.unsqueeze(-2)
#         D_pred = D_pred * valid_mask.float()
        
#         return D_pred
    
class SupervisedCoordHead(nn.Module):
    """
    Per-point MLP mapping encoder embeddings to low-dimensional coordinates.

    Every point is embedded independently (no interaction between set members);
    pairwise geometry is obtained by taking distances between the outputs.
    """
    def __init__(self, h_dim: int, hidden_dim: int = 256, D_out: int = 2):
        super().__init__()
        # Linear -> ReLU -> Linear -> ReLU -> Linear, assembled from an explicit
        # list so the parameter names stay mlp.0 / mlp.2 / mlp.4.
        layers = [
            nn.Linear(h_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, D_out),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, Z: torch.Tensor, mask: torch.Tensor):
        """
        Args:
            Z: (batch, n, h_dim) encoder embeddings
            mask: (batch, n) validity mask; False marks padding

        Returns:
            (batch, n, D_out) predicted coordinates, zero at padded positions
        """
        keep = mask.unsqueeze(-1)  # broadcast over the coordinate axis
        return self.mlp(Z) * keep

# ============================================================================
# CREATE DATASET AND DATALOADER
# ============================================================================

# Create ST miniset dataset (same as Stage C training)
# Expression is handed over on CPU, keyed by slide id 0 (matching slides_dict).
st_gene_expr_dict_cpu = {0: st_expr.cpu()}

# NOTE(review): n_min / n_max presumably bound the number of spots per
# miniset, and num_samples the dataset length. The encoder is passed in,
# presumably to produce batch['Z_set'] — confirm in STSetDataset.
st_dataset = STSetDataset(
    targets_dict=model.targets_dict,
    encoder=model.encoder,
    st_gene_expr_dict=st_gene_expr_dict_cpu,
    n_min=64,
    n_max=384,
    D_latent=model.D_latent,
    num_samples=4000,  # Same as run_mouse_brain_2.py
    knn_k=12,
    device=device,
    landmarks_L=16
)

# collate_minisets assembles variable-size minisets into a batch; the
# training loop reads 'Z_set', 'mask' and 'D_target' from each batch.
st_loader = DataLoader(
    st_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_minisets
)

print(f"ST dataset created: {len(st_dataset)} samples")

In [None]:
# ============================================================================
# INITIALIZE SUPERVISED HEAD
# ============================================================================

# h_dim = model.encoder.fc_list[-1].out_features  # Get encoder output dim
# Get encoder output dimension by doing a forward pass
with torch.no_grad():
    dummy_input = torch.randn(1, n_genes, device=device)
    h_dim = model.encoder(dummy_input).shape[-1]
# supervised_head = SupervisedEDMHead(h_dim=h_dim, hidden_dim=256).to(device)

supervised_head = SupervisedCoordHead(h_dim=h_dim, hidden_dim=256).to(device)

# Number of passes over st_loader. Defined before the scheduler so the cosine
# period is derived from it instead of being duplicated as a literal.
num_epochs = 50

optimizer = optim.Adam(supervised_head.parameters(), lr=1e-4)
# scheduler.step() runs once per epoch, so T_max=num_epochs anneals the learning
# rate along a single half-cosine over exactly the whole run.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

print(f"\nSupervised head initialized: h_dim={h_dim}")

# ============================================================================
# TRAINING LOOP
# ============================================================================

loss_history = []

print("\n=== Training Supervised Baseline ===\n")

supervised_head.train()

for epoch in range(num_epochs):
    epoch_losses = []

    for batch in st_loader:
        # Move batch to device
        Z = batch['Z_set'].to(device)              # (batch, n, h)
        mask = batch['mask'].to(device)            # (batch, n)
        D_target = batch['D_target'].to(device)    # (batch, n, n)

        # # Forward pass
        # D_pred = supervised_head(Z, mask)

        # # Loss: MSE on valid EDM entries
        # valid_mask = mask.unsqueeze(-1) & mask.unsqueeze(-2)
        # loss = ((D_pred - D_target) ** 2 * valid_mask.float()).sum() / valid_mask.float().sum()

        # Forward pass
        coords_pred = supervised_head(Z, mask)  # (batch, n, 2)

        # Compute EDM from predicted coords
        D_pred = torch.cdist(coords_pred, coords_pred)  # (batch, n, n)

        # Loss: MSE averaged over (i, j) pairs where both points are valid
        valid_mask = mask.unsqueeze(-1) & mask.unsqueeze(-2)
        loss = ((D_pred - D_target) ** 2 * valid_mask.float()).sum() / valid_mask.float().sum()

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())

    scheduler.step()

    avg_loss = np.mean(epoch_losses)
    loss_history.append(avg_loss)

    if (epoch + 1) % 5 == 0 or epoch == 0:
        print(f"Epoch {epoch+1:3d}/{num_epochs} | Loss: {avg_loss:.6f}")

print("\n=== Training Complete ===\n")

# ============================================================================
# EVALUATE: SAMPLE A FEW MINISETS AND CHECK GEOMETRY
# ============================================================================

In [None]:
supervised_head.eval()

print("=== Evaluating Supervised Baseline ===\n")

num_eval_samples = 5
eval_results = []

with torch.no_grad():
    # Fresh pass over the shuffled loader; one batch per evaluated sample
    eval_iter = iter(st_loader)

    for i in range(num_eval_samples):
        batch = next(eval_iter)

        Z = batch['Z_set'].to(device)
        mask = batch['mask'].to(device)
        D_target = batch['D_target'].to(device)

        # Predict coordinates directly
        coords_pred = supervised_head(Z, mask)

        # Take first sample in batch and keep only its valid (unpadded) points
        b = 0
        m = mask[b]
        n_valid = m.sum().item()

        coords_pred_sample = coords_pred[b, m].cpu()
        D_target_sample = D_target[b, m][:, m].cpu()

        # Reference layout: classical MDS of the target EDM
        # (double-centred squared distances -> 2-D embedding)
        n = D_target_sample.shape[0]
        Jn = torch.eye(n) - torch.ones(n, n) / n
        B_target = -0.5 * (Jn @ (D_target_sample ** 2) @ Jn)
        coords_target = uet.classical_mds(B_target, d_out=2)

        # Canonicalize both
        coords_pred_canon = uet.canonicalize_coords(coords_pred_sample)
        coords_target_canon = uet.canonicalize_coords(coords_target)

        # Per-axis coordinate correlation; abs() so a flipped axis still counts.
        # NOTE(review): only meaningful if canonicalize_coords puts both layouts
        # in a common rotation frame. Otherwise rely on edm_corr (rotation
        # invariant) or Procrustes-align first.
        corr_x = np.corrcoef(coords_pred_canon[:, 0].numpy(), coords_target_canon[:, 0].numpy())[0, 1]
        corr_y = np.corrcoef(coords_pred_canon[:, 1].numpy(), coords_target_canon[:, 1].numpy())[0, 1]
        avg_corr = (abs(corr_x) + abs(corr_y)) / 2.0

        # EDM correlation over unique off-diagonal pairs (i < j). Flattening the
        # full matrices counts every pair twice. It also adds n diagonal zeros
        # shared by both matrices, which inflates Pearson r.
        D_pred_sample = torch.cdist(coords_pred_sample.unsqueeze(0), coords_pred_sample.unsqueeze(0)).squeeze(0)
        pair_i, pair_j = torch.triu_indices(n, n, offset=1)
        edm_corr = np.corrcoef(
            D_pred_sample[pair_i, pair_j].numpy(),
            D_target_sample[pair_i, pair_j].numpy()
        )[0, 1]

        eval_results.append({
            'sample': i,
            'n_points': n_valid,
            'corr_x': corr_x,
            'corr_y': corr_y,
            'avg_corr': avg_corr,
            'edm_corr': edm_corr,
            'coords_pred': coords_pred_canon.numpy(),
            'coords_target': coords_target_canon.numpy()
        })

        print(f"Sample {i}: n={n_valid:3d} | EDM_corr={edm_corr:.4f} | "
              f"Coord_corr: x={corr_x:.4f}, y={corr_y:.4f}, avg={avg_corr:.4f}")

In [None]:
# ============================================================================
# PLOT GROUND TRUTH VS PREDICTED COORDINATES
# ============================================================================

n_show = len(eval_results)
# squeeze=False keeps `axes` 2-D even when only one sample is shown. With the
# default squeeze=True, plt.subplots(2, 1) returns a 1-D array, and axes[0, i]
# would raise an IndexError.
fig, axes = plt.subplots(2, n_show, figsize=(4*n_show, 8), squeeze=False)

for i, res in enumerate(eval_results):
    # Top row: predicted coordinates
    axes[0, i].scatter(res['coords_pred'][:, 0], res['coords_pred'][:, 1],
                       s=10, alpha=0.6, c='blue')
    axes[0, i].set_title(f"Sample {i}: Predicted\ncorr={res['avg_corr']:.3f}")
    axes[0, i].set_aspect('equal')
    axes[0, i].grid(True, alpha=0.3)

    # Bottom row: ground-truth (MDS of the target EDM) coordinates
    axes[1, i].scatter(res['coords_target'][:, 0], res['coords_target'][:, 1],
                       s=10, alpha=0.6, c='red')
    axes[1, i].set_title("Ground Truth")
    axes[1, i].set_aspect('equal')
    axes[1, i].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('supervised_baseline_coords_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ===================================================================
# COMPLETE NOTEBOOK: ST-ONLY MODEL (PHASE 1) - SINGLE PATCH INFERENCE
# ===================================================================
# Loads the SC test data, rebuilds GEMSModel with the training-time
# hyperparameters, restores the Phase 1 (ST-only) checkpoint and prints the
# single-patch inference configuration.
import torch
import numpy as np
import scanpy as sc
from scipy.spatial.distance import cdist, pdist, squareform
from scipy.stats import pearsonr, spearmanr
import matplotlib.pyplot as plt
import seaborn as sns
import os
from datetime import datetime

# Setup
output_dir = "/home/ehtesamul/sc_st/model/gems_mousebrain_output"
timestamp = "20251125_105556"

# Fall back to CPU when no GPU is present (same rule as the first cell);
# hard-coding 'cuda' made both model construction and torch.load fail on
# CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# USE PHASE 1 CHECKPOINT (ST-ONLY, BEFORE SC FINE-TUNING)
checkpoint_path = f"{output_dir}/phase1_st_checkpoint.pt"

print("="*70)
print("ST-ONLY MODEL INFERENCE (Phase 1, Single Patch)")
print("="*70)

# ===================================================================
# STEP 1: LOAD TEST DATA
# ===================================================================
print("\n--- Loading Test Data ---")

from run_mouse_brain_2 import load_mouse_data
scadata, stadata = load_mouse_data()

# Extract SC gene expression restricted to genes shared with the ST data
common = sorted(list(set(scadata.var_names) & set(stadata.var_names)))
X_sc = scadata[:, common].X
if hasattr(X_sc, "toarray"):
    X_sc = X_sc.toarray()
sc_expr = torch.tensor(X_sc, dtype=torch.float32)

n_cells = sc_expr.shape[0]
n_genes = sc_expr.shape[1]

print(f"Loaded SC data: {n_cells} cells, {n_genes} genes")
print(f"Ground truth coords shape: {scadata.obsm['spatial_gt'].shape}")

# ===================================================================
# STEP 2: LOAD MODEL AND ST-ONLY CHECKPOINT (PHASE 1)
# ===================================================================
print("\n--- Loading Model and ST-Only Checkpoint (Phase 1) ---")

from core_models_et_p3 import GEMSModel

model = GEMSModel(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    D_latent=32,
    c_dim=256,
    n_heads=4,
    isab_m=64,
    device=device,
    use_canonicalize=True,
    use_dist_bias=True,
    dist_bins=24,
    dist_head_shared=True,
    use_angle_features=True,
    angle_bins=8,
    knn_k=12,
    self_conditioning=True,
    sc_feat_mode='concat',
    landmarks_L=16
)

checkpoint = torch.load(checkpoint_path, map_location=device)
model.encoder.load_state_dict(checkpoint['encoder'])
model.context_encoder.load_state_dict(checkpoint['context_encoder'])
model.generator.load_state_dict(checkpoint['generator'])
model.score_net.load_state_dict(checkpoint['score_net'])

# Inference only: switch every restored submodule to eval mode so layers such
# as dropout/normalization (if present) behave deterministically.
model.encoder.eval()
model.context_encoder.eval()
model.generator.eval()
model.score_net.eval()

print(f"✓ Loaded ST-ONLY checkpoint from: {checkpoint_path}")
print(f"  Best ST epoch: {checkpoint.get('E_ST_best', 'N/A')}")
print(f"  This model was trained ONLY on ST data (NO SC fine-tuning)")

# ===================================================================
# STEP 3: SINGLE PATCH INFERENCE (DIAGNOSTIC MODE)
# ===================================================================
# NOTE(review): this cell only prints the configuration below; the
# inference call itself does not appear in this cell.
print("\n--- Running Single Patch Inference (ST-Only Model) ---")
print(f"Config: patch_size={n_cells}, coverage_per_cell=1.0, n_align_iters=1")
print("This runs ONE patch with ALL cells (no stitching)")
print("-"*70)

In [None]:
# ============================================================================
# DIFFUSION INFERENCE ON ST MINISETS - COMPLETE CODE
# ============================================================================
# Restores the Phase 1 (ST-only) checkpoint, runs Stage B on the ST1 slide to
# build `model.targets_dict`, draws minisets via STSetDataset and compares
# single-patch diffusion output with classical-MDS coordinates recovered from
# each miniset's target distance matrix.

import torch
import numpy as np
import pandas as pd
import anndata as ad
import matplotlib.pyplot as plt
import os
from torch.utils.data import DataLoader

from core_models_et_p3 import GEMSModel
from core_models_et_p1 import STSetDataset, collate_minisets
import utils_et as uet

# ============================================================================
# SETUP
# ============================================================================

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
output_dir = '/home/ehtesamul/sc_st/model/gems_mousebrain_output'
checkpoint_path = os.path.join(output_dir, 'phase1_st_checkpoint.pt')

print("="*80)
print("DIFFUSION MODEL INFERENCE ON ST MINISETS")
print("="*80)

# ============================================================================
# LOAD ST DATA
# ============================================================================

st_counts = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_counts_et.csv'
st_meta   = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_metadata_et.csv'
st_ct     = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_celltype_et.csv'

print("\nLoading ST1 data...")
st_expr_df = pd.read_csv(st_counts, index_col=0)
st_meta_df = pd.read_csv(st_meta, index_col=0)
st_ct_df = pd.read_csv(st_ct, index_col=0)

# CSVs are genes x spots, so transpose to spots x genes for AnnData
stadata = ad.AnnData(X=st_expr_df.values.T)
stadata.obs_names = st_expr_df.columns
stadata.var_names = st_expr_df.index
stadata.obsm['spatial'] = st_meta_df[['coord_x', 'coord_y']].values
# Cell-type label per spot = column with the largest value in the cell-type table
stadata.obs['celltype_mapped_refined'] = st_ct_df.idxmax(axis=1).values

print(f"ST1 loaded: {stadata.shape[0]} spots, {stadata.shape[1]} genes")

# Extract and canonicalize
X_st = stadata.X
if hasattr(X_st, "toarray"):
    X_st = X_st.toarray()

st_expr = torch.tensor(X_st, dtype=torch.float32, device=device)
st_coords_raw = torch.tensor(stadata.obsm['spatial'], dtype=torch.float32, device=device)

# Single slide: every spot belongs to slide 0
slide_ids = torch.zeros(st_expr.shape[0], dtype=torch.long, device=device)
st_coords, st_mu, st_scale = uet.canonicalize_st_coords_per_slide(
    st_coords_raw, slide_ids
)

print(f"ST coords canonicalized: scale={st_scale[0].item():.4f}")

# ============================================================================
# LOAD MODEL AND PHASE 1 CHECKPOINT
# ============================================================================

n_genes = stadata.shape[1]

model = GEMSModel(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    D_latent=32,
    c_dim=256,
    n_heads=4,
    isab_m=64,
    device=str(device),
    use_canonicalize=True,
    use_dist_bias=True,
    dist_bins=24,
    dist_head_shared=True,
    use_angle_features=True,
    angle_bins=8,
    knn_k=12,
    self_conditioning=True,
    sc_feat_mode='concat',
    landmarks_L=16,
)

print(f"\nLoading checkpoint: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=device)

model.encoder.load_state_dict(checkpoint['encoder'])
model.context_encoder.load_state_dict(checkpoint['context_encoder'])
model.generator.load_state_dict(checkpoint['generator'])
model.score_net.load_state_dict(checkpoint['score_net'])

print("✓ Loaded Phase 1 ST-only checkpoint")
print(f"  Best ST epoch: {checkpoint.get('E_ST_best', 'N/A')}")

# Inference only: put every restored submodule into eval mode (the generator
# was previously left in its default train mode).
model.encoder.eval()
model.context_encoder.eval()
model.generator.eval()
model.score_net.eval()

# ============================================================================
# RUN STAGE B TO GET TARGETS_DICT
# ============================================================================

print("\n=== Running Stage B ===")
slides_dict = {0: (st_coords, st_expr)}
model.train_stageB(
    slides=slides_dict,
    outdir='temp_stageB_cache'
)
print("Stage B complete.")

# ============================================================================
# CREATE ST MINISET DATASET
# ============================================================================

st_gene_expr_dict_cpu = {0: st_expr.cpu()}

st_dataset = STSetDataset(
    targets_dict=model.targets_dict,
    encoder=model.encoder,
    st_gene_expr_dict=st_gene_expr_dict_cpu,
    n_min=64,
    n_max=384,
    D_latent=model.D_latent,
    num_samples=4000,
    knn_k=12,
    device=device,
    landmarks_L=16
)

st_loader = DataLoader(
    st_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_minisets
)

print(f"ST dataset created: {len(st_dataset)} samples")

# ============================================================================
# RUN DIFFUSION INFERENCE ON ST MINISETS
# ============================================================================

num_eval_samples = 5
diffusion_results = []

# Expression rows are gathered on CPU for every miniset; copy once, not per loop.
st_expr_cpu = st_expr.cpu()

print("\n--- Running diffusion inference on ST minisets ---\n")

with torch.no_grad():
    eval_iter = iter(st_loader)

    for i in range(num_eval_samples):
        batch = next(eval_iter)

        mask = batch['mask'].to(device)
        D_target = batch['D_target'].to(device)

        # Take first sample in batch
        b = 0
        m = mask[b]
        n_valid = m.sum().item()

        # Get indices for this miniset. `indices` comes straight from the
        # collated batch and may sit on a different device than `m` (moved to
        # `device` above); a CPU boolean mask can index a tensor on any device.
        indices = batch['overlap_info']['indices'][b]
        valid_indices = indices[m.cpu()].cpu()

        # Get gene expression for these specific ST spots
        miniset_expr = st_expr_cpu[valid_indices]

        print(f"Sample {i}: Running diffusion inference on {n_valid} points...")

        # Run patchwise inference with single patch (no stitching)
        inf_results = model.infer_sc_patchwise(
            sc_gene_expr=miniset_expr,
            n_timesteps_sample=300,
            sigma_min=0.01,
            sigma_max=7.0,
            patch_size=n_valid,          # Single patch = all points
            coverage_per_cell=1.0,       # No overlap
            n_align_iters=1,             # No stitching
            eta=0.0,
            guidance_scale=6.0,
            return_coords=True,
            debug_flag=False,
            debug_every=10,
        )

        # Extract predicted coordinates; move to CPU because .numpy() and the
        # CPU-side target tensors below require it.
        coords_diffusion = inf_results['coords_canon'].detach().cpu()

        # Ground-truth coordinates via classical MDS on the target distances:
        # B = -1/2 * J D^2 J with centering matrix J = I - 11^T / n.
        D_target_sample = D_target[b, m][:, m].cpu()
        n = D_target_sample.shape[0]
        Jn = torch.eye(n) - torch.ones(n, n) / n
        B_target = -0.5 * (Jn @ (D_target_sample ** 2) @ Jn)
        coords_target = uet.classical_mds(B_target, d_out=2)
        coords_target_canon = uet.canonicalize_coords(coords_target)

        # Per-axis correlations; abs() ignores reflections of an axis
        corr_x = np.corrcoef(coords_diffusion[:, 0].numpy(), coords_target_canon[:, 0].numpy())[0, 1]
        corr_y = np.corrcoef(coords_diffusion[:, 1].numpy(), coords_target_canon[:, 1].numpy())[0, 1]
        avg_corr = (abs(corr_x) + abs(corr_y)) / 2.0

        # EDM correlation (over the full distance matrix)
        D_diffusion = torch.cdist(coords_diffusion.unsqueeze(0), coords_diffusion.unsqueeze(0)).squeeze(0)
        edm_corr = np.corrcoef(
            D_diffusion.flatten().numpy(),
            D_target_sample.flatten().numpy()
        )[0, 1]

        diffusion_results.append({
            'sample': i,
            'n_points': n_valid,
            'corr_x': corr_x,
            'corr_y': corr_y,
            'avg_corr': avg_corr,
            'edm_corr': edm_corr,
            'coords_diffusion': coords_diffusion.numpy(),
            'coords_target': coords_target_canon.numpy()
        })

        print(f"  EDM_corr={edm_corr:.4f} | Coord_corr: x={corr_x:.4f}, y={corr_y:.4f}, avg={avg_corr:.4f}\n")

print("="*80)
print("DIFFUSION INFERENCE COMPLETE")
print("="*80)

# ============================================================================
# PRINT SUMMARY OF DIFFUSION RESULTS
# ============================================================================

print("\nDiffusion Model (Phase 1 ST-only) Results:")
print(f"  Average EDM correlation:   {np.mean([r['edm_corr'] for r in diffusion_results]):.4f}")
print(f"  Average Coord correlation: {np.mean([r['avg_corr'] for r in diffusion_results]):.4f}")

# ============================================================================
# PLOT: DIFFUSION vs GROUND TRUTH
# ============================================================================

# squeeze=False keeps `axes` 2-D even for a single sample.
fig, axes = plt.subplots(2, num_eval_samples, figsize=(4*num_eval_samples, 8),
                         squeeze=False)

for i in range(num_eval_samples):
    res = diffusion_results[i]

    # Diffusion prediction
    axes[0, i].scatter(res['coords_diffusion'][:, 0],
                       res['coords_diffusion'][:, 1],
                       s=10, alpha=0.6, c='green')
    axes[0, i].set_title(f"Sample {i}: Diffusion\ncorr={res['avg_corr']:.3f}")
    axes[0, i].set_aspect('equal')
    axes[0, i].grid(True, alpha=0.3)

    # Ground truth
    axes[1, i].scatter(res['coords_target'][:, 0],
                       res['coords_target'][:, 1],
                       s=10, alpha=0.6, c='red')
    axes[1, i].set_title("Ground Truth")
    axes[1, i].set_aspect('equal')
    axes[1, i].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('diffusion_vs_groundtruth.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n✓ Plot saved: diffusion_vs_groundtruth.png")

In [None]:
# ============================================================================
# DIFFUSION INFERENCE ON ST MINISETS - FIXED
# ============================================================================
# Restores the phase-2 (SC fine-tuned) checkpoint, samples uniformly random
# ST minisets, runs single-patch diffusion on each and compares the result
# with classical-MDS coordinates recovered from the true pairwise distances.

import torch
import numpy as np
import pandas as pd
import anndata as ad
import matplotlib.pyplot as plt
import os

from core_models_et_p3 import GEMSModel
import utils_et as uet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
output_dir = '/home/ehtesamul/sc_st/model/gems_mousebrain_output'
checkpoint_path = os.path.join(output_dir, 'phase2_sc_finetuned_checkpoint.pt')


print("="*80)
print("DIFFUSION MODEL INFERENCE ON ST MINISETS")
print("="*80)

# ============================================================================
# LOAD ST DATA
# ============================================================================

st_counts = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_counts_et.csv'
st_meta   = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_metadata_et.csv'

print("\nLoading ST1 data...")
st_expr_df = pd.read_csv(st_counts, index_col=0)
st_meta_df = pd.read_csv(st_meta, index_col=0)

# CSVs are genes x spots, so transpose to spots x genes for AnnData
stadata = ad.AnnData(X=st_expr_df.values.T)
stadata.obs_names = st_expr_df.columns
stadata.var_names = st_expr_df.index
stadata.obsm['spatial'] = st_meta_df[['coord_x', 'coord_y']].values

X_st = stadata.X
if hasattr(X_st, "toarray"):
    X_st = X_st.toarray()

st_expr = torch.tensor(X_st, dtype=torch.float32, device=device)
st_coords_raw = torch.tensor(stadata.obsm['spatial'], dtype=torch.float32, device=device)

# Single slide: every spot belongs to slide 0
slide_ids = torch.zeros(st_expr.shape[0], dtype=torch.long, device=device)
st_coords, st_mu, st_scale = uet.canonicalize_st_coords_per_slide(
    st_coords_raw, slide_ids
)

print(f"ST loaded: {stadata.shape[0]} spots, {stadata.shape[1]} genes")

# ============================================================================
# LOAD MODEL
# ============================================================================

n_genes = stadata.shape[1]

model = GEMSModel(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    D_latent=32,
    c_dim=256,
    n_heads=4,
    isab_m=64,
    device=str(device),
    use_canonicalize=True,
    use_dist_bias=True,
    dist_bins=24,
    dist_head_shared=True,
    use_angle_features=True,
    angle_bins=8,
    knn_k=12,
    self_conditioning=True,
    sc_feat_mode='concat',
    landmarks_L=16,
)

print(f"\nLoading checkpoint: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=device)

model.encoder.load_state_dict(checkpoint['encoder'])
model.context_encoder.load_state_dict(checkpoint['context_encoder'])
model.generator.load_state_dict(checkpoint['generator'])
model.score_net.load_state_dict(checkpoint['score_net'])

# The message no longer claims "Phase 1": checkpoint_path may point at the
# phase-2 fine-tuned weights.
print(f"✓ Loaded checkpoint {os.path.basename(checkpoint_path)} "
      f"(best ST epoch: {checkpoint.get('E_ST_best', 'N/A')})")

# Inference only: every restored submodule goes into eval mode.
model.encoder.eval()
model.context_encoder.eval()
model.generator.eval()
model.score_net.eval()

# ============================================================================
# SAMPLE ST MINISETS AND RUN DIFFUSION INFERENCE
# ============================================================================

num_eval_samples = 10
diffusion_results = []

print("\n--- Running diffusion inference on ST minisets ---\n")

np.random.seed(42)
# torch.randperm is not affected by np.random.seed; a dedicated seeded
# generator makes the sampled minisets reproducible without touching the
# global torch RNG used by the diffusion sampler.
perm_gen = torch.Generator().manual_seed(42)

# Miniset size range and pool size do not change between samples
n_min, n_max = 192, 384
n_total = st_coords.shape[0]

for i in range(num_eval_samples):
    # Uniformly random subset size in [n_min, n_max], clipped to n_total
    # (randint's upper bound is exclusive, hence the +1).
    n = np.random.randint(min(n_min, n_total), min(n_max, n_total) + 1)

    # Uniformly random spot indices
    indices = torch.randperm(n_total, generator=perm_gen)[:n]

    # Get gene expression and coords for this miniset
    miniset_expr = st_expr[indices].cpu()
    miniset_coords = st_coords[indices].cpu()

    # Compute ground truth EDM
    D_target = torch.cdist(miniset_coords, miniset_coords)

    print(f"Sample {i}: Running diffusion on {n} points...")

    # Run inference with single patch (no stitching)
    with torch.no_grad():
        inf_results = model.infer_sc_patchwise(
            sc_gene_expr=miniset_expr,
            n_timesteps_sample=300,
            sigma_min=0.01,
            sigma_max=7.0,
            patch_size=n,            # Single patch
            coverage_per_cell=1.0,   # No overlap
            n_align_iters=1,         # No alignment
            eta=0.0,
            guidance_scale=6.0,
            return_coords=True,
            debug_flag=False,
        )

    # Move to CPU: .numpy() and the CPU-side targets below require it
    coords_diffusion = inf_results['coords_canon'].detach().cpu()

    # Ground-truth coords via classical MDS: B = -1/2 * J D^2 J, J = I - 11^T / n
    Jn = torch.eye(n) - torch.ones(n, n) / n
    B_target = -0.5 * (Jn @ (D_target**2) @ Jn)
    coords_target = uet.classical_mds(B_target, d_out=2)
    coords_target_canon = uet.canonicalize_coords(coords_target)

    # Per-axis correlations; abs() ignores reflections of an axis
    corr_x = np.corrcoef(coords_diffusion[:, 0].numpy(), coords_target_canon[:, 0].numpy())[0, 1]
    corr_y = np.corrcoef(coords_diffusion[:, 1].numpy(), coords_target_canon[:, 1].numpy())[0, 1]
    avg_corr = (abs(corr_x) + abs(corr_y)) / 2.0

    # EDM correlation (over the full distance matrix)
    D_diffusion = torch.cdist(coords_diffusion.unsqueeze(0), coords_diffusion.unsqueeze(0)).squeeze(0)
    edm_corr = np.corrcoef(
        D_diffusion.flatten().numpy(),
        D_target.flatten().numpy()
    )[0, 1]

    diffusion_results.append({
        'sample': i,
        'n_points': n,
        'corr_x': corr_x,
        'corr_y': corr_y,
        'avg_corr': avg_corr,
        'edm_corr': edm_corr,
        'coords_diffusion': coords_diffusion.numpy(),
        'coords_target': coords_target_canon.numpy()
    })

    print(f"  EDM_corr={edm_corr:.4f} | Coord: x={corr_x:.4f}, y={corr_y:.4f}, avg={avg_corr:.4f}\n")

print("="*80)
print(f"\nDiffusion Results (avg over {num_eval_samples} samples):")
print(f"  EDM correlation:   {np.mean([r['edm_corr'] for r in diffusion_results]):.4f}")
print(f"  Coord correlation: {np.mean([r['avg_corr'] for r in diffusion_results]):.4f}")
print("="*80)

# ============================================================================
# PLOT - 3 COLUMNS MAX PER ROW
# ============================================================================

n_cols = min(3, num_eval_samples)
n_rows = int(np.ceil(num_eval_samples / n_cols)) * 2  # *2 for diffusion + GT rows

# squeeze=False always yields a 2-D axes array, so no reshape special cases
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows), squeeze=False)

for i in range(num_eval_samples):
    row_pair = (i // n_cols) * 2  # Which pair of rows (diffusion + GT)
    col = i % n_cols
    res = diffusion_results[i]

    # Diffusion prediction
    ax_diff = axes[row_pair, col]
    ax_diff.scatter(res['coords_diffusion'][:, 0],
                    res['coords_diffusion'][:, 1],
                    s=10, alpha=0.6, c='green')
    ax_diff.set_title(f"Sample {i}: Diffusion\n"
                      f"Coord: {res['avg_corr']:.3f} | "
                      f"EDM: {res['edm_corr']:.3f}",
                      fontsize=10)
    ax_diff.set_aspect('equal')
    ax_diff.grid(True, alpha=0.3)

    # Ground truth
    ax_gt = axes[row_pair + 1, col]
    ax_gt.scatter(res['coords_target'][:, 0],
                  res['coords_target'][:, 1],
                  s=10, alpha=0.6, c='red')
    ax_gt.set_title(f"Ground Truth (n={res['n_points']})",
                    fontsize=10)
    ax_gt.set_aspect('equal')
    ax_gt.grid(True, alpha=0.3)

# Hide unused subplots in the last row pair
for i in range(num_eval_samples, n_rows // 2 * n_cols):
    row_pair = (i // n_cols) * 2
    col = i % n_cols
    axes[row_pair, col].axis('off')
    axes[row_pair + 1, col].axis('off')

plt.tight_layout()
plt.show()

In [None]:
# ============================================================================
# SMART OUTLIER REMOVAL - DISTANCE-BASED METHOD
# ============================================================================
# Prints a banner describing the outlier filter applied further down in this
# cell: keep predicted points whose distance to the median center is within
# the 90th percentile.

import numpy as np
import matplotlib.pyplot as plt
import torch

print("\n" + "=" * 80)
print("OUTLIER REMOVAL - DISTANCE FROM MEDIAN CENTER")
print("=" * 80)
# One print of newline-joined lines emits exactly the same text as one print
# per line.
print("\n".join([
    "\nRationale:",
    "- Diffusion occasionally samples points in low-probability tail regions",
    "- Ground truth tissue has consistent density (filled region)",
    "- Outliers are scattered points FAR from main cluster",
    "- Method: Keep only points within 90th percentile distance from median center",
    "- Why median? Robust to outliers (unlike mean)",
    "- Why 90th percentile? Keeps main distribution, removes extreme tail",
]))
print("=" * 80 + "\n")

def remove_outliers_distance_percentile(coords, coords_target, percentile=90):
    """
    Remove outliers based on distance from the median center of `coords`.

    Strategy:
    1. Take the per-axis median of the predicted coords as the center
       (robust to outliers, unlike the mean)
    2. Compute each predicted point's Euclidean distance from that center
    3. Keep points whose distance is <= the `percentile`-th percentile
    4. Apply the same mask to the target coords so rows stay paired

    Args:
        coords: (n, d) predicted coordinates (array-like; d is 2 in this notebook)
        coords_target: (n, d) target coordinates, row-aligned with `coords`
        percentile: keep points within this percentile (90 = drop roughly the
            farthest 10%; ties at the cutoff are kept)

    Returns:
        coords_clean: inlier rows of `coords`
        coords_target_clean: the same rows of `coords_target`
        inlier_mask: (n,) boolean mask of kept rows
        threshold: distance cutoff used (NaN when the input is empty)

    Raises:
        ValueError: if `coords` and `coords_target` have different row counts.
    """
    coords = np.asarray(coords)
    coords_target = np.asarray(coords_target)

    if len(coords) != len(coords_target):
        raise ValueError(
            f"coords and coords_target must have the same number of rows, "
            f"got {len(coords)} and {len(coords_target)}"
        )

    # Nothing to filter; np.percentile has no meaningful cutoff for zero points.
    if len(coords) == 0:
        return coords, coords_target, np.zeros(0, dtype=bool), float("nan")

    # Use MEDIAN center (robust to outliers, unlike mean)
    center = np.median(coords, axis=0)

    # Distance from center for each point
    dists = np.linalg.norm(coords - center, axis=1)

    # Threshold: keep only points within percentile
    threshold = np.percentile(dists, percentile)
    inlier_mask = dists <= threshold

    # Filter predicted and target with the same mask so rows stay paired
    return coords[inlier_mask], coords_target[inlier_mask], inlier_mask, threshold

# ============================================================================
# CLEAN EACH SAMPLE
# ============================================================================
# NOTE: inliers are selected from the *predicted* coordinates only, and the
# "after" metrics are computed on that filtered (re-canonicalized) subset, so
# "before" and "after" numbers are not measured on the same set of points.

diffusion_results_clean = []

for i, res in enumerate(diffusion_results):
    coords_pred = res['coords_diffusion']
    coords_gt = res['coords_target']
    n_orig = len(coords_pred)

    # Remove outliers
    coords_clean, coords_gt_clean, mask, thresh = remove_outliers_distance_percentile(
        coords_pred, coords_gt, percentile=90
    )

    n_kept = len(coords_clean)
    n_removed = n_orig - n_kept
    pct_removed = 100 * n_removed / n_orig

    # Re-canonicalize both filtered point sets before comparing them
    coords_clean_t = torch.from_numpy(coords_clean).float()
    coords_gt_t = torch.from_numpy(coords_gt_clean).float()

    coords_clean_canon = uet.canonicalize_coords(coords_clean_t).numpy()
    coords_gt_canon = uet.canonicalize_coords(coords_gt_t).numpy()

    # Metrics on the unfiltered sample, for the before/after comparison
    corr_x_before = res['corr_x']
    corr_y_before = res['corr_y']
    avg_corr_before = res['avg_corr']
    edm_corr_before = res['edm_corr']

    # Per-axis correlations on the filtered subset; abs() ignores axis reflections
    corr_x = np.corrcoef(coords_clean_canon[:, 0], coords_gt_canon[:, 0])[0, 1]
    corr_y = np.corrcoef(coords_clean_canon[:, 1], coords_gt_canon[:, 1])[0, 1]
    avg_corr = (abs(corr_x) + abs(corr_y)) / 2.0

    # EDM correlation on the filtered subset
    D_clean = torch.cdist(
        torch.from_numpy(coords_clean_canon).unsqueeze(0).float(),
        torch.from_numpy(coords_clean_canon).unsqueeze(0).float()
    ).squeeze(0)
    D_gt = torch.cdist(
        torch.from_numpy(coords_gt_canon).unsqueeze(0).float(),
        torch.from_numpy(coords_gt_canon).unsqueeze(0).float()
    ).squeeze(0)

    edm_corr = np.corrcoef(D_clean.flatten().numpy(), D_gt.flatten().numpy())[0, 1]

    # Store
    diffusion_results_clean.append({
        'sample': i,
        'n_points': n_kept,
        'n_removed': n_removed,
        'pct_removed': pct_removed,
        'corr_x': corr_x,
        'corr_y': corr_y,
        'avg_corr': avg_corr,
        'edm_corr': edm_corr,
        'coords_diffusion': coords_clean_canon,
        'coords_target': coords_gt_canon,
        'threshold': thresh
    })

    # Print results
    print(f"Sample {i}: removed {n_removed}/{n_orig} outliers ({pct_removed:.1f}%), "
          f"threshold={thresh:.3f}")
    print(f"  Before: Coord={avg_corr_before:.3f}, EDM={edm_corr_before:.3f}")
    print(f"  After:  Coord={avg_corr:.3f} (Δ={avg_corr-avg_corr_before:+.3f}), "
          f"EDM={edm_corr:.3f} (Δ={edm_corr-edm_corr_before:+.3f})\n")

# ============================================================================
# SUMMARY STATISTICS
# ============================================================================

print("="*80)
print("SUMMARY: BEFORE vs AFTER OUTLIER REMOVAL")
print("="*80 + "\n")

avg_coord_before = np.mean([r['avg_corr'] for r in diffusion_results])
avg_edm_before = np.mean([r['edm_corr'] for r in diffusion_results])

avg_coord_after = np.mean([r['avg_corr'] for r in diffusion_results_clean])
avg_edm_after = np.mean([r['edm_corr'] for r in diffusion_results_clean])

avg_pct_removed = np.mean([r['pct_removed'] for r in diffusion_results_clean])

print("Average Coordinate Correlation:")
print(f"  Before: {avg_coord_before:.4f}")
print(f"  After:  {avg_coord_after:.4f} (Δ={avg_coord_after-avg_coord_before:+.4f})")

print("\nAverage EDM Correlation:")
print(f"  Before: {avg_edm_before:.4f}")
print(f"  After:  {avg_edm_after:.4f} (Δ={avg_edm_after-avg_edm_before:+.4f})")

print(f"\nAverage outliers removed: {avg_pct_removed:.1f}%")

print("\n" + "="*80)
print("INTERPRETATION:")
print("="*80)
# Thresholds on the mean EDM-correlation change are heuristic cut-offs
edm_delta = avg_edm_after - avg_edm_before
if edm_delta > 0.1:
    print("✓ EDM correlation IMPROVED significantly after outlier removal")
    print("  → Confirms outliers were corrupting distance metrics")
    print("  → Main cluster has better geometric structure than raw output")
elif edm_delta > 0:
    print("✓ EDM correlation improved slightly")
    print("  → Outliers had some negative effect on distances")
else:
    print("⚠ EDM correlation unchanged or decreased")
    print("  → Problem is not just outliers, geometry of main cluster needs work")
print("="*80 + "\n")

# ============================================================================
# SIMPLE PLOT: PREDICTED vs GROUND TRUTH (AFTER OUTLIER REMOVAL)
# ============================================================================

# Same layout rule as the uncleaned plot: at most 3 columns, two rows per group
n_cols = min(3, num_eval_samples)
n_rows = int(np.ceil(num_eval_samples / n_cols)) * 2

# squeeze=False always yields a 2-D axes array, so no reshape special cases
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows), squeeze=False)

for i in range(num_eval_samples):
    row_pair = (i // n_cols) * 2
    col = i % n_cols
    res = diffusion_results_clean[i]

    # Predicted (cleaned)
    ax_pred = axes[row_pair, col]
    ax_pred.scatter(res['coords_diffusion'][:, 0],
                    res['coords_diffusion'][:, 1],
                    s=10, alpha=0.7, c='#2ecc71', edgecolors='none')
    ax_pred.set_title(f"Sample {i}: Predicted\n"
                      f"Coord: {res['avg_corr']:.3f} | "
                      f"EDM: {res['edm_corr']:.3f}",
                      fontsize=10)
    ax_pred.set_aspect('equal')
    ax_pred.grid(True, alpha=0.2)

    # Ground Truth
    ax_gt = axes[row_pair + 1, col]
    ax_gt.scatter(res['coords_target'][:, 0],
                  res['coords_target'][:, 1],
                  s=10, alpha=0.7, c='#e74c3c', edgecolors='none')
    ax_gt.set_title(f"Ground Truth (n={res['n_points']})",
                    fontsize=10)
    ax_gt.set_aspect('equal')
    ax_gt.grid(True, alpha=0.2)

# Hide unused subplots in the last row pair
for i in range(num_eval_samples, n_rows // 2 * n_cols):
    row_pair = (i // n_cols) * 2
    col = i % n_cols
    axes[row_pair, col].axis('off')
    axes[row_pair + 1, col].axis('off')

plt.tight_layout()
# plt.savefig('cleaned_results.png', dpi=200, bbox_inches='tight')
plt.show()

# print("✓ Saved plot: cleaned_results.png")