In [None]:
# ===================================================================
# COMPLETE NOTEBOOK: SINGLE PATCH INFERENCE + FULL EVALUATION
# ===================================================================
import torch
import numpy as np
import scanpy as sc
from scipy.spatial.distance import cdist, pdist, squareform
from scipy.stats import pearsonr, spearmanr
import matplotlib.pyplot as plt
import seaborn as sns
import os
from datetime import datetime

# Run configuration: where artefacts live and which checkpoint to evaluate
output_dir = "/home/ehtesamul/sc_st/model/gems_mousebrain_output"
timestamp = "20251128_100055"
checkpoint_path = output_dir + "/phase2_sc_finetuned_checkpoint.pt"

print("=" * 70, "SINGLE PATCH INFERENCE (Diagnostic Mode)", "=" * 70, sep="\n")

# -------------------------------------------------------------------
# STEP 1: LOAD TEST DATA
# -------------------------------------------------------------------
print("\n--- Loading Test Data ---")

from run_mouse_brain_2 import load_mouse_data

scadata, stadata = load_mouse_data()

# Genes present in both AnnData objects, in sorted order
common = sorted(set(scadata.var_names).intersection(stadata.var_names))

# Expression restricted to the shared genes; densify if the matrix exposes
# `toarray` (e.g. a scipy sparse matrix)
X_sc = scadata[:, common].X
X_sc = X_sc.toarray() if hasattr(X_sc, "toarray") else X_sc
sc_expr = torch.tensor(X_sc, dtype=torch.float32)

n_cells, n_genes = sc_expr.shape

print(f"Loaded SC data: {n_cells} cells, {n_genes} genes")
print(f"Ground truth coords shape: {scadata.obsm['spatial_gt'].shape}")

# ===================================================================
# STEP 2: LOAD MODEL AND CHECKPOINT
# ===================================================================
print("\n--- Loading Model and Checkpoint ---")

from core_models_et_p3 import GEMSModel

model = GEMSModel(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    D_latent=32,
    c_dim=256,
    n_heads=4,
    isab_m=64,
    device='cuda',
    use_canonicalize=True,
    use_dist_bias=True,
    dist_bins=24,
    dist_head_shared=True,
    use_angle_features=True,
    angle_bins=8,
    knn_k=12,
    self_conditioning=True,
    sc_feat_mode='concat',
    landmarks_L=16
)

checkpoint = torch.load(checkpoint_path, map_location='cuda')

# Restore every sub-network from its entry in the checkpoint and switch it to
# eval mode. This notebook only runs inference, so train-time behaviour such as
# dropout must be off. NOTE(review): `infer_sc_patchwise` may already do this
# internally; calling `.eval()` here is harmless in that case.
for _submodule_name in ('encoder', 'context_encoder', 'generator', 'score_net'):
    _submodule = getattr(model, _submodule_name)
    _submodule.load_state_dict(checkpoint[_submodule_name])
    _submodule.eval()

print(f"✓ Loaded checkpoint from: {checkpoint_path}")
print(f"  Epochs trained: {checkpoint.get('epochs_finetune', 'N/A')}")

# -------------------------------------------------------------------
# STEP 3: SINGLE PATCH INFERENCE (DIAGNOSTIC MODE)
# -------------------------------------------------------------------
print("\n--- Running Single Patch Inference ---")
print(f"Config: patch_size={n_cells}, coverage_per_cell=1.0")
print("This runs ONE patch with ALL cells (no stitching)")
print("-" * 70)

# Sampler settings collected in one place. patch_size == n_cells puts every
# cell into a single patch, so the alignment iterations have nothing to stitch.
_infer_kwargs = dict(
    n_timesteps_sample=600,
    sigma_min=0.01,
    sigma_max=7.0,
    patch_size=n_cells,
    coverage_per_cell=1.0,
    n_align_iters=1,
    eta=0.0,
    guidance_scale=5.0,
    return_coords=True,
    debug_flag=True,
    debug_every=10,
)
results = model.infer_sc_patchwise(sc_gene_expr=sc_expr, **_infer_kwargs)

print("\n✓ Inference complete")

# ===================================================================
# STEP 4: EXTRACT RAW EDM (NO PROJECTION, NO RESCALING)
# ===================================================================
def _print_edm_stats(label, edm):
    """Print the shape and summary statistics of a pairwise-distance matrix.

    Min, median and mean are taken over the strictly positive entries only
    (this drops the zero diagonal and any coincident pairs); max is taken
    over the whole matrix.

    Args:
        label: Prefix used in the printed lines, e.g. ``"Raw EDM"``.
        edm: Square NumPy array of pairwise distances.
    """
    positive = edm[edm > 0]  # build the boolean mask once, not once per statistic
    print(f"{label} shape: {edm.shape}")
    print(f"{label} stats:")
    if positive.size == 0:
        # Without this guard `.min()` on the empty selection raises ValueError.
        print("  (no positive distances: all points coincide)")
        return
    print(f"  Min: {positive.min():.4f}")
    print(f"  Median: {np.median(positive):.4f}")
    print(f"  Max: {edm.max():.4f}")
    print(f"  Mean: {positive.mean():.4f}")


print("\n--- Computing Raw EDM (No Post-Processing) ---")

# Extract canonicalized coordinates
coords_canon = results['coords_canon'].cpu().numpy()

# Compute RAW EDM directly from coordinates (NO edm_project, NO rescaling)
gems_edm = cdist(coords_canon, coords_canon, metric='euclidean')
_print_edm_stats("Raw EDM", gems_edm)

# ===================================================================
# STEP 5: COMPUTE GROUND TRUTH EDM
# ===================================================================
print("\n--- Calculating Ground Truth EDM ---")

gt_coords = scadata.obsm['spatial_gt']
gt_edm = squareform(pdist(gt_coords, 'euclidean'))
_print_edm_stats("Ground Truth EDM", gt_edm)
# ===================================================================
# STEP 6: NORMALIZE FOR COMPARISON
# ===================================================================
def normalize_matrix(matrix):
    """Min-max scale ``matrix`` so its values span [0, 1].

    Args:
        matrix: NumPy array to rescale.

    Returns:
        An array of the same shape with the minimum mapped to 0 and the
        maximum mapped to 1. If every entry is identical (zero range) an
        all-zero float array is returned instead of the NaNs that 0/0 would
        produce.
    """
    min_val = matrix.min()
    value_range = matrix.max() - min_val
    if value_range == 0:
        # Constant input: avoid dividing by zero.
        return np.zeros_like(matrix, dtype=float)
    return (matrix - min_val) / value_range

# Min-max scaled copies of both distance matrices (used by the heatmaps)
gems_edm_norm = normalize_matrix(gems_edm)
gt_edm_norm = normalize_matrix(gt_edm)

# -------------------------------------------------------------------
# STEP 7: QUANTITATIVE COMPARISON
# -------------------------------------------------------------------
_banner = "=" * 70
print(f"\n{_banner}\nQUANTITATIVE COMPARISON\n{_banner}")

# Each unordered cell pair once: strict upper triangle, diagonal excluded
triu_indices = np.triu_indices(n_cells, k=1)
gt_distances_flat, gems_distances_flat = gt_edm[triu_indices], gems_edm[triu_indices]

# Bring predicted distances onto the ground-truth scale by matching medians
scale = np.median(gt_distances_flat) / np.median(gems_distances_flat)
gems_distances_flat_scaled = scale * gems_distances_flat

print(f"\nScale factor (median matching): {scale:.4f}")

# Linear and rank agreement between the two sets of pairwise distances
pearson_corr, _ = pearsonr(gt_distances_flat, gems_distances_flat_scaled)
spearman_corr, _ = spearmanr(gt_distances_flat, gems_distances_flat_scaled)

print(f"\nPearson Correlation: {pearson_corr:.4f}")
print(f"Spearman Correlation: {spearman_corr:.4f}")
print("-" * 70)

# ===================================================================
# STEP 8: VISUALIZATIONS
# ===================================================================
# Five figures comparing the predicted geometry with the ground truth:
# normalized EDM heatmaps, distance histograms, a distance scatter plot,
# the raw 2-D coordinates, and the absolute distance-error histogram.
print("\n--- Generating Visualizations ---")

# --- PLOT 1: Side-by-Side Heatmaps ---
fig, axes = plt.subplots(1, 2, figsize=(16, 7))
fig.suptitle('EDM Comparison: Ground Truth vs. GEMS (Single Patch, Raw EDM)', 
             fontsize=18, fontweight='bold')

# Random subset of at most 838 cells, drawn without replacement from the
# global NumPy RNG (unseeded here, so the subset differs between runs).
# Sorting keeps the sampled cells in their original index order.
sample_size = min(838, n_cells)
sample_indices = np.random.choice(n_cells, sample_size, replace=False)
sample_indices = np.sort(sample_indices)

# np.ix_ selects the same rows and columns, i.e. the sub-matrix of sampled cells
im1 = axes[0].imshow(gt_edm_norm[np.ix_(sample_indices, sample_indices)], cmap='viridis')
axes[0].set_title('Ground Truth EDM (Normalized)', fontsize=14)
axes[0].set_xlabel('Cell Index (Sampled)')
axes[0].set_ylabel('Cell Index (Sampled)')
fig.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)

im2 = axes[1].imshow(gems_edm_norm[np.ix_(sample_indices, sample_indices)], cmap='viridis')
axes[1].set_title('GEMS Predicted EDM (Normalized)', fontsize=14)
axes[1].set_xlabel('Cell Index (Sampled)')
fig.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)

# rect leaves the top 4% of the figure free for the suptitle
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

# --- PLOT 2: Distribution of Distances ---
# Overlaid density histograms of all upper-triangle distances; the GEMS
# distances are the median-matched (scaled) ones.
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(gt_distances_flat, color="blue", label='Ground Truth Distances', 
             ax=ax, stat='density', bins=100, alpha=0.6)
sns.histplot(gems_distances_flat_scaled, color="red", label='GEMS Distances (Scaled)', 
             ax=ax, stat='density', bins=100, alpha=0.6)
ax.set_title('Distribution of Pairwise Distances (Single Patch Mode)', fontsize=16, fontweight='bold')
ax.set_xlabel('Distance', fontsize=12)
ax.set_ylabel('Density', fontsize=12)
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# --- PLOT 3: Scatter Plot of Distances ---
# Plot at most 50,000 randomly chosen cell pairs rather than every pair
sample_size_scatter = min(50000, len(gt_distances_flat))
sample_indices_scatter = np.random.choice(len(gt_distances_flat), sample_size_scatter, replace=False)

fig, ax = plt.subplots(figsize=(8, 8))
ax.scatter(
    gt_distances_flat[sample_indices_scatter],
    gems_distances_flat_scaled[sample_indices_scatter],
    alpha=0.2, s=5, color='steelblue'
)
ax.set_title(f'GEMS vs. Ground Truth Distances (Single Patch)\nSpearman ρ = {spearman_corr:.4f}', 
             fontsize=16, fontweight='bold')
ax.set_xlabel('Ground Truth Pairwise Distance', fontsize=12)
ax.set_ylabel('GEMS Pairwise Distance (Scaled)', fontsize=12)
ax.grid(True, linestyle='--', alpha=0.5)

# y = x reference line spanning the union of the current x and y axis limits;
# the limits are read after the scatter call so they reflect the plotted data
lims = [
    min(ax.get_xlim()[0], ax.get_ylim()[0]),
    max(ax.get_xlim()[1], ax.get_ylim()[1]),
]
ax.plot(lims, lims, 'r--', alpha=0.75, linewidth=2, zorder=0, label='Ideal Correlation')
ax.set_aspect('equal', adjustable='box')
ax.legend(fontsize=12)
plt.tight_layout()
plt.show()

# --- PLOT 4: Coordinate Comparison ---
# First two columns of the ground-truth and predicted coordinates, side by side
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
fig.suptitle('Spatial Coordinates: Ground Truth vs. GEMS (Single Patch)', 
             fontsize=16, fontweight='bold')

axes[0].scatter(gt_coords[:, 0], gt_coords[:, 1], s=5, alpha=0.6, color='blue')
axes[0].set_title('Ground Truth Coordinates', fontsize=14)
axes[0].set_xlabel('X', fontsize=12)
axes[0].set_ylabel('Y', fontsize=12)
axes[0].set_aspect('equal')
axes[0].grid(True, alpha=0.3)

axes[1].scatter(coords_canon[:, 0], coords_canon[:, 1], s=5, alpha=0.6, color='red')
axes[1].set_title('GEMS Predicted Coordinates', fontsize=14)
axes[1].set_xlabel('X', fontsize=12)
axes[1].set_ylabel('Y', fontsize=12)
axes[1].set_aspect('equal')
axes[1].grid(True, alpha=0.3)

plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

# --- PLOT 5: Distance Error Distribution ---
# Per-pair absolute difference between ground-truth and scaled GEMS distances,
# with vertical markers at the median and mean error
distance_errors = np.abs(gt_distances_flat - gems_distances_flat_scaled)
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(distance_errors, bins=100, kde=True, ax=ax, color='purple')
ax.set_title('Distance Prediction Error Distribution', fontsize=16, fontweight='bold')
ax.set_xlabel('Absolute Error |GT - GEMS|', fontsize=12)
ax.set_ylabel('Count', fontsize=12)
ax.axvline(np.median(distance_errors), color='r', linestyle='--', linewidth=2, 
           label=f'Median Error: {np.median(distance_errors):.4f}')
ax.axvline(np.mean(distance_errors), color='g', linestyle='--', linewidth=2, 
           label=f'Mean Error: {np.mean(distance_errors):.4f}')
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# -------------------------------------------------------------------
# STEP 9: SAVE RESULTS
# -------------------------------------------------------------------
print("\n--- Saving Results ---")

# Fresh wall-clock stamp so repeated runs never overwrite each other
new_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_suffix = "single_patch_" + new_timestamp

# Everything needed to reproduce the numbers reported above
results_processed = dict(
    D_edm=gems_edm,  # RAW EDM (no projection, no rescaling)
    coords=results['coords'].cpu().numpy(),
    coords_canon=coords_canon,
    n_cells=n_cells,
    timestamp=new_timestamp,
    mode='single_patch_no_projection',
    scale_factor=scale,
    pearson_corr=pearson_corr,
    spearman_corr=spearman_corr,
    model_config=dict(n_genes=n_genes, D_latent=32, c_dim=256),
)

processed_path = os.path.join(output_dir, f"sc_inference_processed_{output_suffix}.pt")
# torch.save(results_processed, processed_path)
# print(f"✓ Saved: {processed_path}")

# Attach the predicted coordinates to the AnnData object and write it to disk
scadata.obsm['X_gems'] = coords_canon
adata_path = os.path.join(output_dir, f"scadata_with_gems_{output_suffix}.h5ad")
scadata.write_h5ad(adata_path)
print(f"✓ Saved: {adata_path}")

_summary_lines = [
    "\n" + "=" * 70,
    "SINGLE PATCH DIAGNOSTIC COMPLETE",
    "=" * 70,
    "\nResults Summary:",
    f"  Mode: Single patch (patch_size={n_cells})",
    "  EDM: Raw (no projection, no rescaling)",
    f"  Pearson: {pearson_corr:.4f}",
    f"  Spearman: {spearman_corr:.4f}",
    f"  Scale factor: {scale:.4f}",
    f"  Output timestamp: {output_suffix}",
]
print("\n".join(_summary_lines))

In [None]:
import torch
import numpy as np
import pandas as pd
import anndata as ad
import matplotlib.pyplot as plt
import seaborn as sns
import utils_et as uet  # Ensure this is in your python path

# 1. Read the raw ST spot coordinates from the metadata CSV
print("Loading ST Data...")
st_counts = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_counts_et.csv'
st_meta   = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_metadata_et.csv'

st_meta_df = pd.read_csv(st_meta, index_col=0)
raw_coords = st_meta_df[['coord_x', 'coord_y']].to_numpy()
st_coords_tensor = torch.tensor(raw_coords, dtype=torch.float32)

# 2. Canonicalize with the same helper the training script uses
print("Applying Global RMS Normalization...")
# Every spot gets slide id 0, i.e. the data is treated as a single slide
slide_ids = torch.zeros(len(st_coords_tensor), dtype=torch.long)

norm_coords, mu, scale = uet.canonicalize_st_coords_per_slide(
    st_coords_tensor, slide_ids
)
norm_coords = norm_coords.numpy()
print(f"Normalization Scale Factor used: {scale[0].item():.4f}")

# 3. How many normalized points fall outside the unit circle?
radii = np.sqrt((norm_coords ** 2).sum(axis=1))  # distance of each point from the origin
points_outside = (radii > 1.0).sum()
pct_outside = (points_outside / len(radii)) * 100

_rule = "-" * 40
print(_rule)
print(f"Total Points: {len(radii)}")
print(f"Points outside Unit Circle (Radius > 1.0): {points_outside}")
print(f"Percentage outside: {pct_outside:.2f}%")
print(f"Max Radius: {radii.max():.4f}")
print(_rule)

# 4. Two-panel figure: normalized geometry (left) and radius histogram (right)
fig, axes = plt.subplots(1, 2, figsize=(16, 7))
_ax_geom, _ax_hist = axes

# Left panel: normalized points with the radius-1 circle drawn on top
_ax_geom.scatter(
    norm_coords[:, 0], norm_coords[:, 1],
    s=5, alpha=0.6, c='steelblue', label='ST Cells',
)
circle = plt.Circle(
    (0, 0), 1.0,
    color='red', fill=False, linestyle='--', linewidth=2, label='Unit RMS Circle',
)
_ax_geom.add_patch(circle)
_ax_geom.set_title(f"Normalized ST Data\n({pct_outside:.1f}% points outside red circle)", fontsize=14)
_ax_geom.set_aspect('equal')
_ax_geom.legend()
_ax_geom.grid(True, alpha=0.3)

# Right panel: distribution of radii with a marker at radius 1
sns.histplot(radii, bins=50, ax=_ax_hist, kde=True, color='purple')
_ax_hist.axvline(1.0, color='red', linestyle='--', linewidth=2, label='Radius = 1.0')
_ax_hist.set_title("Distribution of Radii from Center", fontsize=14)
_ax_hist.set_xlabel("Distance from Center")
_ax_hist.legend()

plt.tight_layout()
plt.show()

In [None]:
# ===================================================================
# COMPLETE NOTEBOOK: ST-ONLY MODEL (PHASE 1) - SINGLE PATCH INFERENCE
# ===================================================================
import torch
import numpy as np
import scanpy as sc
from scipy.spatial.distance import cdist, pdist, squareform
from scipy.stats import pearsonr, spearmanr
import matplotlib.pyplot as plt
import seaborn as sns
import os
from datetime import datetime

# Run configuration: where artefacts live and which checkpoint to evaluate
output_dir = "/home/ehtesamul/sc_st/model/gems_mousebrain_output"
timestamp = "20251128_100055"

# Phase 1 checkpoint: ST-only, taken before any SC fine-tuning
checkpoint_path = output_dir + "/phase1_st_checkpoint.pt"

print("=" * 70, "ST-ONLY MODEL INFERENCE (Phase 1, Single Patch)", "=" * 70, sep="\n")

# -------------------------------------------------------------------
# STEP 1: LOAD TEST DATA
# -------------------------------------------------------------------
print("\n--- Loading Test Data ---")

from run_mouse_brain_2 import load_mouse_data

scadata, stadata = load_mouse_data()

# Genes present in both AnnData objects, in sorted order
common = sorted(set(scadata.var_names).intersection(stadata.var_names))

# Expression restricted to the shared genes; densify if the matrix exposes
# `toarray` (e.g. a scipy sparse matrix)
X_sc = scadata[:, common].X
X_sc = X_sc.toarray() if hasattr(X_sc, "toarray") else X_sc
sc_expr = torch.tensor(X_sc, dtype=torch.float32)

n_cells, n_genes = sc_expr.shape

print(f"Loaded SC data: {n_cells} cells, {n_genes} genes")
print(f"Ground truth coords shape: {scadata.obsm['spatial_gt'].shape}")

# ===================================================================
# STEP 2: LOAD MODEL AND ST-ONLY CHECKPOINT (PHASE 1)
# ===================================================================
print("\n--- Loading Model and ST-Only Checkpoint (Phase 1) ---")

from core_models_et_p3 import GEMSModel

model = GEMSModel(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    D_latent=32,
    c_dim=256,
    n_heads=4,
    isab_m=64,
    device='cuda',
    use_canonicalize=True,
    use_dist_bias=True,
    dist_bins=24,
    dist_head_shared=True,
    use_angle_features=True,
    angle_bins=8,
    knn_k=12,
    self_conditioning=True,
    sc_feat_mode='concat',
    landmarks_L=16
)

checkpoint = torch.load(checkpoint_path, map_location='cuda')

# Restore every sub-network from its entry in the checkpoint and switch it to
# eval mode. This notebook only runs inference, so train-time behaviour such as
# dropout must be off. NOTE(review): `infer_sc_patchwise` may already do this
# internally; calling `.eval()` here is harmless in that case.
for _submodule_name in ('encoder', 'context_encoder', 'generator', 'score_net'):
    _submodule = getattr(model, _submodule_name)
    _submodule.load_state_dict(checkpoint[_submodule_name])
    _submodule.eval()

print(f"✓ Loaded ST-ONLY checkpoint from: {checkpoint_path}")
print(f"  Best ST epoch: {checkpoint.get('E_ST_best', 'N/A')}")
print(f"  This model was trained ONLY on ST data (NO SC fine-tuning)")

# -------------------------------------------------------------------
# STEP 3: SINGLE PATCH INFERENCE (DIAGNOSTIC MODE)
# -------------------------------------------------------------------
print("\n--- Running Single Patch Inference (ST-Only Model) ---")
print(f"Config: patch_size={n_cells}, coverage_per_cell=1.0, n_align_iters=1")
print("This runs ONE patch with ALL cells (no stitching)")
print("-" * 70)

# Sampler settings collected in one place. patch_size == n_cells puts every
# cell into a single patch, so the alignment iterations have nothing to stitch.
_infer_kwargs = dict(
    n_timesteps_sample=600,
    sigma_min=0.01,
    sigma_max=7.0,
    patch_size=n_cells,
    coverage_per_cell=1.0,
    n_align_iters=1,
    eta=0.0,
    guidance_scale=5.0,
    return_coords=True,
    debug_flag=True,
    debug_every=10,
)
results = model.infer_sc_patchwise(sc_gene_expr=sc_expr, **_infer_kwargs)

print("\n✓ Inference complete")

# ===================================================================
# STEP 4: EXTRACT RAW EDM (NO PROJECTION, NO RESCALING)
# ===================================================================
def _print_edm_stats(label, edm):
    """Print the shape and summary statistics of a pairwise-distance matrix.

    Min, median and mean are taken over the strictly positive entries only
    (this drops the zero diagonal and any coincident pairs); max is taken
    over the whole matrix.

    Args:
        label: Prefix used in the printed lines, e.g. ``"Raw EDM"``.
        edm: Square NumPy array of pairwise distances.
    """
    positive = edm[edm > 0]  # build the boolean mask once, not once per statistic
    print(f"{label} shape: {edm.shape}")
    print(f"{label} stats:")
    if positive.size == 0:
        # Without this guard `.min()` on the empty selection raises ValueError.
        print("  (no positive distances: all points coincide)")
        return
    print(f"  Min: {positive.min():.4f}")
    print(f"  Median: {np.median(positive):.4f}")
    print(f"  Max: {edm.max():.4f}")
    print(f"  Mean: {positive.mean():.4f}")


print("\n--- Computing Raw EDM (No Post-Processing) ---")

# Extract canonicalized coordinates
coords_canon = results['coords_canon'].cpu().numpy()

# Compute RAW EDM directly from coordinates (NO edm_project, NO rescaling)
gems_edm = cdist(coords_canon, coords_canon, metric='euclidean')
_print_edm_stats("Raw EDM", gems_edm)

# ===================================================================
# STEP 5: COMPUTE GROUND TRUTH EDM
# ===================================================================
print("\n--- Calculating Ground Truth EDM ---")

gt_coords = scadata.obsm['spatial_gt']
gt_edm = squareform(pdist(gt_coords, 'euclidean'))
_print_edm_stats("Ground Truth EDM", gt_edm)

# ===================================================================
# STEP 6: NORMALIZE FOR COMPARISON
# ===================================================================
def normalize_matrix(matrix):
    """Min-max scale ``matrix`` so its values span [0, 1].

    Args:
        matrix: NumPy array to rescale.

    Returns:
        An array of the same shape with the minimum mapped to 0 and the
        maximum mapped to 1. If every entry is identical (zero range) an
        all-zero float array is returned instead of the NaNs that 0/0 would
        produce.
    """
    min_val = matrix.min()
    value_range = matrix.max() - min_val
    if value_range == 0:
        # Constant input: avoid dividing by zero.
        return np.zeros_like(matrix, dtype=float)
    return (matrix - min_val) / value_range

# Min-max scaled copies of both distance matrices (used by the heatmaps)
gems_edm_norm = normalize_matrix(gems_edm)
gt_edm_norm = normalize_matrix(gt_edm)

# -------------------------------------------------------------------
# STEP 7: QUANTITATIVE COMPARISON
# -------------------------------------------------------------------
_banner = "=" * 70
print(f"\n{_banner}\nQUANTITATIVE COMPARISON (ST-ONLY MODEL)\n{_banner}")

# Each unordered cell pair once: strict upper triangle, diagonal excluded
triu_indices = np.triu_indices(n_cells, k=1)
gt_distances_flat, gems_distances_flat = gt_edm[triu_indices], gems_edm[triu_indices]

# Bring predicted distances onto the ground-truth scale by matching medians
scale = np.median(gt_distances_flat) / np.median(gems_distances_flat)
gems_distances_flat_scaled = scale * gems_distances_flat

print(f"\nScale factor (median matching): {scale:.4f}")

# Linear and rank agreement between the two sets of pairwise distances
pearson_corr, _ = pearsonr(gt_distances_flat, gems_distances_flat_scaled)
spearman_corr, _ = spearmanr(gt_distances_flat, gems_distances_flat_scaled)

print(f"\nPearson Correlation: {pearson_corr:.4f}")
print(f"Spearman Correlation: {spearman_corr:.4f}")
print("-" * 70)

# ===================================================================
# STEP 8: VISUALIZATIONS
# ===================================================================
# Five figures comparing the ST-only model's predicted geometry with the
# ground truth: normalized EDM heatmaps, distance histograms, a distance
# scatter plot, the raw 2-D coordinates, and the absolute-error histogram.
print("\n--- Generating Visualizations ---")

# --- PLOT 1: Side-by-Side Heatmaps ---
fig, axes = plt.subplots(1, 2, figsize=(16, 7))
fig.suptitle('EDM Comparison: Ground Truth vs. GEMS (ST-Only Model, Single Patch)', 
             fontsize=18, fontweight='bold')

# Random subset of at most 838 cells, drawn without replacement from the
# global NumPy RNG (unseeded here, so the subset differs between runs).
# Sorting keeps the sampled cells in their original index order.
sample_size = min(838, n_cells)
sample_indices = np.random.choice(n_cells, sample_size, replace=False)
sample_indices = np.sort(sample_indices)

# np.ix_ selects the same rows and columns, i.e. the sub-matrix of sampled cells
im1 = axes[0].imshow(gt_edm_norm[np.ix_(sample_indices, sample_indices)], cmap='viridis')
axes[0].set_title('Ground Truth EDM (Normalized)', fontsize=14)
axes[0].set_xlabel('Cell Index (Sampled)')
axes[0].set_ylabel('Cell Index (Sampled)')
fig.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)

im2 = axes[1].imshow(gems_edm_norm[np.ix_(sample_indices, sample_indices)], cmap='viridis')
axes[1].set_title('GEMS Predicted EDM (ST-Only, Normalized)', fontsize=14)
axes[1].set_xlabel('Cell Index (Sampled)')
fig.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)

# rect leaves the top 4% of the figure free for the suptitle
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

# --- PLOT 2: Distribution of Distances ---
# Overlaid density histograms of all upper-triangle distances; the GEMS
# distances are the median-matched (scaled) ones.
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(gt_distances_flat, color="blue", label='Ground Truth Distances', 
             ax=ax, stat='density', bins=100, alpha=0.6)
sns.histplot(gems_distances_flat_scaled, color="orange", label='GEMS Distances (ST-Only, Scaled)', 
             ax=ax, stat='density', bins=100, alpha=0.6)
ax.set_title('Distribution of Pairwise Distances (ST-Only Model)', fontsize=16, fontweight='bold')
ax.set_xlabel('Distance', fontsize=12)
ax.set_ylabel('Density', fontsize=12)
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# --- PLOT 3: Scatter Plot of Distances ---
# Plot at most 50,000 randomly chosen cell pairs rather than every pair
sample_size_scatter = min(50000, len(gt_distances_flat))
sample_indices_scatter = np.random.choice(len(gt_distances_flat), sample_size_scatter, replace=False)

fig, ax = plt.subplots(figsize=(8, 8))
ax.scatter(
    gt_distances_flat[sample_indices_scatter],
    gems_distances_flat_scaled[sample_indices_scatter],
    alpha=0.2, s=5, color='orange'
)
ax.set_title(f'GEMS vs. Ground Truth Distances (ST-Only Model)\nSpearman ρ = {spearman_corr:.4f}', 
             fontsize=16, fontweight='bold')
ax.set_xlabel('Ground Truth Pairwise Distance', fontsize=12)
ax.set_ylabel('GEMS Pairwise Distance (Scaled)', fontsize=12)
ax.grid(True, linestyle='--', alpha=0.5)

# y = x reference line spanning the union of the current x and y axis limits;
# the limits are read after the scatter call so they reflect the plotted data
lims = [
    min(ax.get_xlim()[0], ax.get_ylim()[0]),
    max(ax.get_xlim()[1], ax.get_ylim()[1]),
]
ax.plot(lims, lims, 'r--', alpha=0.75, linewidth=2, zorder=0, label='Ideal Correlation')
ax.set_aspect('equal', adjustable='box')
ax.legend(fontsize=12)
plt.tight_layout()
plt.show()

# --- PLOT 4: Coordinate Comparison ---
# First two columns of the ground-truth and predicted coordinates, side by side
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
fig.suptitle('Spatial Coordinates: Ground Truth vs. GEMS (ST-Only Model)', 
             fontsize=16, fontweight='bold')

axes[0].scatter(gt_coords[:, 0], gt_coords[:, 1], s=5, alpha=0.6, color='blue')
axes[0].set_title('Ground Truth Coordinates', fontsize=14)
axes[0].set_xlabel('X', fontsize=12)
axes[0].set_ylabel('Y', fontsize=12)
axes[0].set_aspect('equal')
axes[0].grid(True, alpha=0.3)

axes[1].scatter(coords_canon[:, 0], coords_canon[:, 1], s=5, alpha=0.6, color='orange')
axes[1].set_title('GEMS Predicted Coordinates (ST-Only)', fontsize=14)
axes[1].set_xlabel('X', fontsize=12)
axes[1].set_ylabel('Y', fontsize=12)
axes[1].set_aspect('equal')
axes[1].grid(True, alpha=0.3)

plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

# --- PLOT 5: Distance Error Distribution ---
# Per-pair absolute difference between ground-truth and scaled GEMS distances,
# with vertical markers at the median and mean error
distance_errors = np.abs(gt_distances_flat - gems_distances_flat_scaled)
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(distance_errors, bins=100, kde=True, ax=ax, color='orange')
ax.set_title('Distance Prediction Error Distribution (ST-Only Model)', fontsize=16, fontweight='bold')
ax.set_xlabel('Absolute Error |GT - GEMS|', fontsize=12)
ax.set_ylabel('Count', fontsize=12)
ax.axvline(np.median(distance_errors), color='r', linestyle='--', linewidth=2, 
           label=f'Median Error: {np.median(distance_errors):.4f}')
ax.axvline(np.mean(distance_errors), color='g', linestyle='--', linewidth=2, 
           label=f'Mean Error: {np.mean(distance_errors):.4f}')
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# -------------------------------------------------------------------
# STEP 9: SAVE RESULTS
# -------------------------------------------------------------------
print("\n--- Saving Results ---")

# Fresh wall-clock stamp so repeated runs never overwrite each other
new_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_suffix = "st_only_single_patch_" + new_timestamp

# Everything needed to reproduce the numbers reported above
results_processed = dict(
    D_edm=gems_edm,  # RAW EDM (no projection, no rescaling)
    coords=results['coords'].cpu().numpy(),
    coords_canon=coords_canon,
    n_cells=n_cells,
    timestamp=new_timestamp,
    mode='st_only_single_patch_no_projection',
    scale_factor=scale,
    pearson_corr=pearson_corr,
    spearman_corr=spearman_corr,
    model_config=dict(
        n_genes=n_genes,
        D_latent=32,
        c_dim=256,
        phase='ST-only (Phase 1)',
    ),
)

processed_path = os.path.join(output_dir, f"sc_inference_processed_{output_suffix}.pt")
# torch.save(results_processed, processed_path)
# print(f"✓ Saved: {processed_path}")

# Attach the predicted coordinates to the AnnData object and write it to disk
scadata.obsm['X_gems_st_only'] = coords_canon
adata_path = os.path.join(output_dir, f"scadata_with_gems_{output_suffix}.h5ad")
scadata.write_h5ad(adata_path)
print(f"✓ Saved: {adata_path}")

_summary_lines = [
    "\n" + "=" * 70,
    "ST-ONLY MODEL DIAGNOSTIC COMPLETE",
    "=" * 70,
    "\nResults Summary:",
    "  Model: ST-Only (Phase 1, BEFORE SC fine-tuning)",
    f"  Mode: Single patch (patch_size={n_cells})",
    "  EDM: Raw (no projection, no rescaling)",
    f"  Pearson: {pearson_corr:.4f}",
    f"  Spearman: {spearman_corr:.4f}",
    f"  Scale factor: {scale:.4f}",
    f"  Output timestamp: {output_suffix}",
    "\nThis tells you if ring collapse happens during:",
    "  - ST-only training (Phase 1) → if you see ring now",
    "  - SC fine-tuning (Phase 2) → if you saw ring only with fine-tuned model",
]
print("\n".join(_summary_lines))

In [None]:
# ===================================================================
# TIMESTEP-BY-TIMESTEP DIFFUSION VISUALIZATION
# ===================================================================
import torch
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import os
from datetime import datetime

# Run configuration: where artefacts live and which checkpoint to evaluate
output_dir = "/home/ehtesamul/sc_st/model/gems_mousebrain_output"
timestamp = "20251128_100055"
checkpoint_path = output_dir + "/phase2_sc_finetuned_checkpoint.pt"

print("=" * 70, "DIFFUSION TIMESTEP VISUALIZATION (Single Patch)", "=" * 70, sep="\n")

# -------------------------------------------------------------------
# STEP 1: LOAD TEST DATA
# -------------------------------------------------------------------
print("\n--- Loading Test Data ---")

from run_mouse_brain_2 import load_mouse_data

scadata, stadata = load_mouse_data()

# Genes present in both AnnData objects, in sorted order
common = sorted(set(scadata.var_names).intersection(stadata.var_names))

# Expression restricted to the shared genes; densify if the matrix exposes
# `toarray` (e.g. a scipy sparse matrix)
X_sc = scadata[:, common].X
X_sc = X_sc.toarray() if hasattr(X_sc, "toarray") else X_sc
sc_expr = torch.tensor(X_sc, dtype=torch.float32)

n_cells, n_genes = sc_expr.shape

print(f"Loaded SC data: {n_cells} cells, {n_genes} genes")

# -------------------------------------------------------------------
# STEP 2: LOAD MODEL
# -------------------------------------------------------------------
print("\n--- Loading Model ---")

from core_models_et_p3 import GEMSModel
import utils_et as uet

# Architecture hyper-parameters passed to the GEMSModel constructor
_model_kwargs = dict(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    D_latent=32,
    c_dim=256,
    n_heads=4,
    isab_m=64,
    device='cuda',
    use_canonicalize=True,
    use_dist_bias=True,
    dist_bins=24,
    dist_head_shared=True,
    use_angle_features=True,
    angle_bins=8,
    knn_k=12,
    self_conditioning=True,
    sc_feat_mode='concat',
    landmarks_L=16,
)
model = GEMSModel(**_model_kwargs)

# Each sub-network is stored under its own key in the checkpoint
checkpoint = torch.load(checkpoint_path, map_location='cuda')
for _part in ('encoder', 'context_encoder', 'generator', 'score_net'):
    getattr(model, _part).load_state_dict(checkpoint[_part])

print(f"✓ Loaded checkpoint")

# ===================================================================
# STEP 3: INLINE DIFFUSION SAMPLER WITH TIMESTEP CAPTURE
# ===================================================================
# Runs the reverse diffusion by hand (instead of through
# `model.infer_sc_patchwise`) so intermediate samples can be stored.
print("\n--- Running Diffusion with Timestep Capture ---")

device = 'cuda'
n_timesteps_sample = 600
sigma_min = 0.01
sigma_max = 7.0
guidance_scale = 2.0
D_latent = 32

model.encoder.eval()
model.context_encoder.eval()
model.score_net.eval()

print(f"Config: n_timesteps={n_timesteps_sample}, guidance_scale={guidance_scale}")
print(f"        sigma_min={sigma_min}, sigma_max={sigma_max}")

# Encode all SC cells
print("\n[1/4] Encoding SC cells...")
with torch.no_grad():
    Z_all = model.encoder(sc_expr.to(device))  # (n_cells, hidden_dim)

# Prepare context. This runs under no_grad as well: nothing here is ever
# back-propagated, and tracking gradients would keep the context encoder's
# autograd graph alive for the whole sampling loop.
print("[2/4] Computing context...")
with torch.no_grad():
    Z_batch = Z_all.unsqueeze(0)  # (1, n_cells, hidden_dim)
    mask = torch.ones(1, n_cells, dtype=torch.bool, device=device)
    H = model.context_encoder(Z_batch, mask)  # (1, n_cells, c_dim)

# All-zero context used as the unconditional branch of classifier-free
# guidance; it never changes, so build it once instead of once per step.
H_null = torch.zeros_like(H)

# Sigma schedule: log-spaced from sigma_max down to sigma_min
sigmas = torch.exp(torch.linspace(
    torch.log(torch.tensor(sigma_max, device=device)),
    torch.log(torch.tensor(sigma_min, device=device)),
    n_timesteps_sample,
    device=device,
))

# Initialize noise at the largest sigma
print("[3/4] Running reverse diffusion...")
V_t = torch.randn(1, n_cells, D_latent, device=device) * sigmas[0]

# Timesteps to save: every 100th step plus the final one. Deriving the list
# from n_timesteps_sample keeps it valid if the step count is changed
# (for 600 steps this is [0, 100, 200, 300, 400, 500, 599]).
save_timesteps = sorted(set(range(0, n_timesteps_sample, 100)) | {n_timesteps_sample - 1})
saved_samples = {}

with torch.no_grad():
    for t_idx in range(n_timesteps_sample):
        sigma_t = sigmas[t_idx]
        # Step index normalised to [0, 1]
        t_norm = torch.tensor([[t_idx / float(n_timesteps_sample - 1)]], device=device)

        # CFG sampling: extrapolate from the unconditional towards the
        # conditional prediction by guidance_scale
        eps_uncond = model.score_net(V_t, t_norm, H_null, mask)
        eps_cond = model.score_net(V_t, t_norm, H, mask)
        eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

        # Update: estimate the clean sample, then re-noise it to the next
        # sigma; the last step keeps the clean estimate itself
        if t_idx < n_timesteps_sample - 1:
            sigma_next = sigmas[t_idx + 1]
            V_0_pred = V_t - sigma_t * eps
            V_t = V_0_pred + (sigma_next / sigma_t) * (V_t - V_0_pred)
        else:
            V_t = V_t - sigma_t * eps

        # Save at specific timesteps
        if t_idx in save_timesteps:
            # Canonicalize the current sample
            V_canon = uet.canonicalize_coords(V_t.squeeze(0))
            saved_samples[t_idx] = V_canon.cpu().numpy()
            print(f"  Saved timestep {t_idx}/{n_timesteps_sample-1}")

# Final sample
V_final = V_t.squeeze(0)  # (n_cells, D_latent)
V_final_canon = uet.canonicalize_coords(V_final)
coords_final = V_final_canon.cpu().numpy()

print("[4/4] Complete!")

# ===================================================================
# STEP 4: CONVERT TO 2D COORDINATES VIA MDS
# ===================================================================
print("\n--- Converting to 2D coordinates ---")

def latent_to_2d(V_latent):
    """Convert D_latent dimensional coordinates to 2D via MDS"""
    n = V_latent.shape[0]
    V_tensor = torch.tensor(V_latent, dtype=torch.float32)
    
    # Compute EDM from latent coordinates
    D = torch.cdist(V_tensor, V_tensor)
    
    # Classical MDS
    Jn = torch.eye(n) - torch.ones(n, n) / n
    B = -0.5 * (Jn @ (D**2) @ Jn)
    
    # Extract 2D coordinates
    coords_2d = uet.classical_mds(B, d_out=2).numpy()
    coords_2d = uet.canonicalize_coords(torch.tensor(coords_2d)).numpy()
    
    return coords_2d

coords_at_timesteps = {}
for t_idx, V in saved_samples.items():
    coords_at_timesteps[t_idx] = latent_to_2d(V)
    print(f"  Converted timestep {t_idx} to 2D")

# ===================================================================
# STEP 5: VISUALIZE DIFFUSION EVOLUTION
# ===================================================================
print("\n--- Generating Visualizations ---")

# Reference layout: the ground-truth spatial coordinates stored on the SC data.
gt_coords = scadata.obsm['spatial_gt']

# One panel for the ground truth plus one per saved timestep, four per row.
n_plots = len(save_timesteps) + 1
n_cols = 4
n_rows = int(np.ceil(n_plots / n_cols))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 5*n_rows))
axes = axes.flatten()

# First panel: ground truth.
gt_ax = axes[0]
gt_ax.scatter(gt_coords[:, 0], gt_coords[:, 1], s=5, alpha=0.6, c='blue')
gt_ax.set_title('Ground Truth', fontsize=14, fontweight='bold')
gt_ax.set_aspect('equal')
gt_ax.grid(True, alpha=0.3)

# Remaining panels: one saved snapshot each, in save_timesteps order.
for ax, t_idx in zip(axes[1:], save_timesteps):
    coords = coords_at_timesteps[t_idx]

    ax.scatter(coords[:, 0], coords[:, 1], s=5, alpha=0.6, c='red')
    ax.set_title(
        f'Timestep {t_idx}/{n_timesteps_sample-1}\n(σ={sigmas[t_idx]:.4f})',
        fontsize=12,
        fontweight='bold',
    )
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)

# Blank out grid cells that have nothing to show.
for spare_ax in axes[n_plots:]:
    spare_ax.axis('off')

plt.suptitle(
    f'Diffusion Evolution (guidance_scale={guidance_scale}, n_timesteps={n_timesteps_sample})',
    fontsize=18,
    fontweight='bold',
    y=0.995,
)
plt.tight_layout(rect=[0, 0, 1, 0.99])
plt.show()

# ===================================================================
# ADDITIONAL PLOT: SIDE-BY-SIDE EVOLUTION
# ===================================================================
fig, axes = plt.subplots(2, 4, figsize=(20, 10))

# Row 0 holds the early timesteps and row 1 the late ones. Row 1 has only
# three snapshots, which leaves its last cell free for the ground truth.
trajectory_rows = ([0, 100, 200, 300], [400, 500, 599])
for row, row_steps in enumerate(trajectory_rows):
    for col, t_idx in enumerate(row_steps):
        panel = axes[row, col]
        coords = coords_at_timesteps[t_idx]
        panel.scatter(coords[:, 0], coords[:, 1], s=5, alpha=0.6, c='red')
        panel.set_title(f't={t_idx} (σ={sigmas[t_idx]:.3f})', fontsize=12, fontweight='bold')
        panel.set_aspect('equal')
        panel.grid(True, alpha=0.3)

# Ground truth in the bottom-right cell.
gt_panel = axes[1, 3]
gt_panel.scatter(gt_coords[:, 0], gt_coords[:, 1], s=5, alpha=0.6, c='blue')
gt_panel.set_title('Ground Truth', fontsize=12, fontweight='bold')
gt_panel.set_aspect('equal')
gt_panel.grid(True, alpha=0.3)

plt.suptitle('Diffusion Denoising Trajectory', fontsize=18, fontweight='bold')
plt.tight_layout(rect=[0, 0, 1, 0.97])
plt.show()

# ===================================================================
# STEP 6: QUANTIFY STRUCTURE COLLAPSE
# ===================================================================
print("\n--- Analyzing Structure Collapse ---")


def compute_pca_variance_ratio(coords):
    """Return the explained-variance ratios of a 2-component PCA of ``coords``.

    A 2-component scikit-learn ``PCA`` is fitted on ``coords`` and its
    ``explained_variance_ratio_`` attribute is returned unchanged.
    """
    # Imported lazily so scikit-learn is only needed when this is called.
    from sklearn.decomposition import PCA

    return PCA(n_components=2).fit(coords).explained_variance_ratio_

def compute_circularity(coords):
    """Compute a circularity score (higher = more ring-like).

    The score is ``1 - std(r) / mean(r)``, where ``r`` is each point's
    distance to the centroid of the cloud. Points that all lie at the same
    distance from the centroid score exactly 1.0.

    Args:
        coords: array-like of shape (n_points, n_dims); anything
            ``np.asarray`` accepts.

    Returns:
        The score as a float. NaN is returned when the score is undefined:
        for an empty input, or when every point coincides with the centroid
        (mean radius 0), which is the fully collapsed case.
    """
    points = np.asarray(coords)
    if points.size == 0:
        return float("nan")

    center = points.mean(axis=0)
    radii = np.linalg.norm(points - center, axis=1)
    mean_radius = radii.mean()
    # A fully collapsed cloud would otherwise evaluate 0 / 0 and emit a
    # RuntimeWarning; report the undefined score explicitly instead.
    if mean_radius == 0:
        return float("nan")
    return 1.0 - (radii.std() / mean_radius)

def _collapse_row(label, ratios, circularity):
    """Format one table row: label, the two PCA variance ratios, circularity."""
    return f"{label:<10} {ratios[0]:<15.4f} {ratios[1]:<15.4f} {circularity:<15.4f}"


print("\n{:<10} {:<15} {:<15} {:<15}".format("Timestep", "PCA-1 Var", "PCA-2 Var", "Circularity"))
print("-"*60)

# One row per saved diffusion snapshot.
for t_idx in save_timesteps:
    coords = coords_at_timesteps[t_idx]
    var_ratios = compute_pca_variance_ratio(coords)
    circ = compute_circularity(coords)
    print(_collapse_row(t_idx, var_ratios, circ))

# Final row: the same statistics on the ground-truth coordinates.
gt_var_ratios = compute_pca_variance_ratio(gt_coords)
gt_circ = compute_circularity(gt_coords)
print(_collapse_row('GT', gt_var_ratios, gt_circ))

print("\n" + "="*70)
print("TIMESTEP ANALYSIS COMPLETE")
print("="*70)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
import anndata as ad
import matplotlib.pyplot as plt
import os

from core_models_et_p3 import GEMSModel
from core_models_et_p1 import STSetDataset, collate_minisets
import utils_et as uet

# ============================================================================
# SETUP
# ============================================================================

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# ============================================================================
# LOAD DATA (from run_mouse_brain_2.py)
# ============================================================================

st_counts = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_counts_et.csv'
st_meta   = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_metadata_et.csv'
st_ct     = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_celltype_et.csv'

print("Loading ST1 (training ST data)...")
# All three tables use their first CSV column as the row index.
st_expr_df, st_meta_df, st_ct_df = (
    pd.read_csv(csv_path, index_col=0) for csv_path in (st_counts, st_meta, st_ct)
)

# The count table's columns become observations and its index becomes
# variables, hence the transpose.
stadata = ad.AnnData(X=st_expr_df.values.T)
stadata.obs_names = st_expr_df.columns
stadata.var_names = st_expr_df.index
# NOTE(review): metadata and cell-type rows are attached positionally; this
# assumes they are ordered like the count table's columns - confirm.
stadata.obsm['spatial'] = st_meta_df[['coord_x', 'coord_y']].values
stadata.obsm['celltype_proportions'] = st_ct_df.values
# Label each row with the column holding its largest cell-type value.
stadata.obs['celltype_mapped_refined'] = st_ct_df.idxmax(axis=1).values

print(f"ST1 loaded: {stadata.shape[0]} spots, {stadata.shape[1]} genes")

# Expression as a dense array (sparse matrices expose .toarray()).
X_st = stadata.X
X_st = X_st.toarray() if hasattr(X_st, "toarray") else X_st

st_expr = torch.tensor(X_st, dtype=torch.float32, device=device)
st_coords_raw = torch.tensor(stadata.obsm['spatial'], dtype=torch.float32, device=device)

# Per-slide canonicalization (same as training); every row gets slide id 0.
slide_ids = torch.zeros(st_expr.shape[0], dtype=torch.long, device=device)
st_coords, st_mu, st_scale = uet.canonicalize_st_coords_per_slide(st_coords_raw, slide_ids)

print(f"ST coords canonicalized: scale={st_scale[0].item():.4f}")

# ============================================================================
# LOAD TRAINED ENCODER
# ============================================================================

outdir = '/home/ehtesamul/sc_st/model/gems_mousebrain_output'
checkpoint_path = os.path.join(outdir, 'ab_init.pt')

n_genes = stadata.shape[1]

# Create model with same config as run_mouse_brain_2.py
model = GEMSModel(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    D_latent=32,
    c_dim=256,
    n_heads=4,
    isab_m=64,
    device=str(device),
    use_canonicalize=True,
    use_dist_bias=True,
    dist_bins=24,
    dist_head_shared=True,
    use_angle_features=True,
    angle_bins=8,
    knn_k=12,
    self_conditioning=True,
    sc_feat_mode='concat',
    landmarks_L=16,
)

print(f"\nLoading checkpoint from: {checkpoint_path}")
ckpt = torch.load(checkpoint_path, map_location=device)
# Only the encoder weights are restored from this checkpoint.
model.encoder.load_state_dict(ckpt['encoder'])
model.encoder.eval()
# eval() only switches layer modes (dropout / normalisation); turn gradients
# off as well so the encoder really is frozen, as the message below states.
model.encoder.requires_grad_(False)

print("Encoder loaded and frozen.")

# ============================================================================
# RUN STAGE B (takes ~3 seconds)
# ============================================================================

print("\n=== Running Stage B ===")
# A single slide, registered under id 0, as a (coordinates, expression) pair.
slides_dict = {0: (st_coords, st_expr)}
model.train_stageB(slides=slides_dict, outdir='temp_stageB_cache')

# model.targets_dict is read when the miniset dataset is built.
print("Stage B complete. targets_dict populated.")

# ============================================================================
# DEFINE SUPERVISED REGRESSION HEAD
# ============================================================================

# class SupervisedEDMHead(nn.Module):
#     """
#     Simple supervised head that predicts EDM from encoder embeddings.
    
#     Architecture:
#     Z (from encoder) -> MLP -> upper triangular EDM prediction
#     """
#     def __init__(self, h_dim: int, hidden_dim: int = 256):
#         super().__init__()
#         self.h_dim = h_dim
        
#         # MLP to predict pairwise distances
#         self.mlp = nn.Sequential(
#             nn.Linear(h_dim * 2, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, 1)
#         )
    
#     def forward(self, Z: torch.Tensor, mask: torch.Tensor):
#         """
#         Args:
#             Z: (batch, n, h_dim) encoder embeddings
#             mask: (batch, n) validity mask
            
#         Returns:
#             D_pred: (batch, n, n) predicted distance matrix
#         """
#         batch, n, h = Z.shape
        
#         # Create pairwise concatenations
#         Z_i = Z.unsqueeze(2).expand(-1, -1, n, -1)  # (batch, n, n, h)
#         Z_j = Z.unsqueeze(1).expand(-1, n, -1, -1)  # (batch, n, n, h)
#         Z_pairs = torch.cat([Z_i, Z_j], dim=-1)     # (batch, n, n, 2h)
        
#         # Predict distances
#         D_pred = self.mlp(Z_pairs).squeeze(-1)      # (batch, n, n)
#         D_pred = torch.relu(D_pred)                  # Ensure non-negative
        
#         # Symmetrize
#         D_pred = (D_pred + D_pred.transpose(-1, -2)) / 2.0
        
#         # Zero out diagonal
#         diag_mask = torch.eye(n, device=Z.device).unsqueeze(0).bool()
#         D_pred = D_pred.masked_fill(diag_mask, 0.0)
        
#         # Apply validity mask
#         valid_mask = mask.unsqueeze(-1) & mask.unsqueeze(-2)
#         D_pred = D_pred * valid_mask.float()
        
#         return D_pred
    
class SupervisedCoordHead(nn.Module):
    """Per-point MLP that regresses coordinates from encoder embeddings.

    Three linear layers with ReLU in between map each ``h_dim`` embedding to
    ``D_out`` values (2 by default). Points are processed independently.
    """

    def __init__(self, h_dim: int, hidden_dim: int = 256, D_out: int = 2):
        super().__init__()
        layers = [
            nn.Linear(h_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, D_out),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, Z: torch.Tensor, mask: torch.Tensor):
        """Predict coordinates and zero the rows the mask marks as padding.

        Args:
            Z: (batch, n, h_dim) encoder embeddings.
            mask: (batch, n) validity mask.

        Returns:
            (batch, n, D_out) predicted coordinates, multiplied by the mask
            so padded entries are zero.
        """
        keep = mask.unsqueeze(-1)  # (batch, n, 1), broadcasts over D_out
        return self.mlp(Z) * keep

# ============================================================================
# CREATE DATASET AND DATALOADER
# ============================================================================

# Create ST miniset dataset (same as Stage C training)
# The dataset receives a CPU copy of the expression matrix under slide id 0.
st_gene_expr_dict_cpu = {0: st_expr.cpu()}

st_dataset = STSetDataset(
    # Inputs produced earlier in this cell
    targets_dict=model.targets_dict,
    encoder=model.encoder,
    st_gene_expr_dict=st_gene_expr_dict_cpu,
    # Miniset sampling configuration
    n_min=64, n_max=384,
    num_samples=4000,  # Same as run_mouse_brain_2.py
    # Geometry settings mirrored from the model configuration
    D_latent=model.D_latent,
    knn_k=12,
    landmarks_L=16,
    device=device,
)

# Shuffled batches of 8 minisets, collated with the project collate function
# and loaded in the main process.
st_loader = DataLoader(
    st_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_minisets
)

print(f"ST dataset created: {len(st_dataset)} samples")

In [None]:
# ============================================================================
# INITIALIZE SUPERVISED HEAD
# ============================================================================

# Discover the encoder's output width with one throw-away forward pass
# (alternative: model.encoder.fc_list[-1].out_features).
with torch.no_grad():
    dummy_input = torch.randn(1, n_genes, device=device)
    h_dim = model.encoder(dummy_input).shape[-1]

# Coordinate-regression head. The commented-out SupervisedEDMHead above is the
# alternative that predicts the distance matrix directly.
supervised_head = SupervisedCoordHead(h_dim=h_dim, hidden_dim=256).to(device)

# The cosine schedule spans exactly the training run, so its length is taken
# from num_epochs instead of repeating the literal.
num_epochs = 50
optimizer = optim.Adam(supervised_head.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

print(f"\nSupervised head initialized: h_dim={h_dim}")

# ============================================================================
# TRAINING LOOP
# ============================================================================

loss_history = []  # mean batch loss per epoch

print("\n=== Training Supervised Baseline ===\n")

supervised_head.train()

for epoch in range(num_epochs):
    epoch_losses = []

    for batch in st_loader:
        # Move batch to device
        Z = batch['Z_set'].to(device)              # (batch, n, h)
        mask = batch['mask'].to(device)            # (batch, n)
        D_target = batch['D_target'].to(device)    # (batch, n, n)

        # Forward pass: coordinates first, then the distance matrix they imply.
        coords_pred = supervised_head(Z, mask)          # (batch, n, 2)
        D_pred = torch.cdist(coords_pred, coords_pred)  # (batch, n, n)

        # Loss: MSE over the entries where both points are valid.
        valid_mask = mask.unsqueeze(-1) & mask.unsqueeze(-2)
        valid_weight = valid_mask.float()
        loss = ((D_pred - D_target) ** 2 * valid_weight).sum() / valid_weight.sum()

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())

    # The learning rate is stepped once per epoch.
    scheduler.step()

    avg_loss = np.mean(epoch_losses)
    loss_history.append(avg_loss)

    # Report the first epoch and every fifth one after that.
    if (epoch + 1) % 5 == 0 or epoch == 0:
        print(f"Epoch {epoch+1:3d}/{num_epochs} | Loss: {avg_loss:.6f}")

print("\n=== Training Complete ===\n")

# ============================================================================
# EVALUATE: SAMPLE A FEW MINISETS AND CHECK GEOMETRY
# ============================================================================

In [None]:
supervised_head.eval()

print("=== Evaluating Supervised Baseline ===\n")

num_eval_samples = 5
eval_results = []

with torch.no_grad():
    eval_iter = iter(st_loader)

    for i in range(num_eval_samples):
        batch = next(eval_iter)
        Z, mask, D_target = (
            batch[key].to(device) for key in ('Z_set', 'mask', 'D_target')
        )

        # The head predicts coordinates directly.
        coords_pred = supervised_head(Z, mask)

        # Only the first miniset of each batch is scored.
        b = 0
        m = mask[b]
        n_valid = m.sum().item()

        # Keep the valid points only and continue on the CPU.
        coords_pred_sample = coords_pred[b, m].cpu()
        D_target_sample = D_target[b, m][:, m].cpu()

        # Reference layout: classical MDS of the target distance matrix
        # (double-centre the squared distances, then embed in 2D).
        n = D_target_sample.shape[0]
        Jn = torch.eye(n) - torch.ones(n, n) / n
        B_target = -0.5 * (Jn @ (D_target_sample ** 2) @ Jn)
        coords_target = uet.classical_mds(B_target, d_out=2)

        # Canonicalize both layouts before comparing them axis by axis.
        coords_pred_canon = uet.canonicalize_coords(coords_pred_sample)
        coords_target_canon = uet.canonicalize_coords(coords_target)

        pred_np = coords_pred_canon.numpy()
        target_np = coords_target_canon.numpy()

        # Per-axis Pearson correlation; absolute values are averaged so an
        # axis flip does not count against the prediction.
        corr_x = np.corrcoef(pred_np[:, 0], target_np[:, 0])[0, 1]
        corr_y = np.corrcoef(pred_np[:, 1], target_np[:, 1])[0, 1]
        avg_corr = (abs(corr_x) + abs(corr_y)) / 2.0

        # Correlation between all entries of predicted and target distances.
        pred_batched = coords_pred_sample.unsqueeze(0)
        D_pred_sample = torch.cdist(pred_batched, pred_batched).squeeze(0)
        edm_corr = np.corrcoef(
            D_pred_sample.flatten().numpy(),
            D_target_sample.flatten().numpy(),
        )[0, 1]

        eval_results.append({
            'sample': i,
            'n_points': n_valid,
            'corr_x': corr_x,
            'corr_y': corr_y,
            'avg_corr': avg_corr,
            'edm_corr': edm_corr,
            'coords_pred': pred_np,
            'coords_target': target_np,
        })

        print(f"Sample {i}: n={n_valid:3d} | EDM_corr={edm_corr:.4f} | "
              f"Coord_corr: x={corr_x:.4f}, y={corr_y:.4f}, avg={avg_corr:.4f}")

In [None]:
# ============================================================================
# PLOT GROUND TRUTH VS PREDICTED COORDINATES
# ============================================================================

# One column per evaluated miniset: prediction on top, reference below.
fig, axes = plt.subplots(2, num_eval_samples, figsize=(4*num_eval_samples, 8))

for i, res in enumerate(eval_results):
    pred_ax, truth_ax = axes[0, i], axes[1, i]

    # Predicted coordinates
    pred_xy = res['coords_pred']
    pred_ax.scatter(pred_xy[:, 0], pred_xy[:, 1], s=10, alpha=0.6, c='blue')
    pred_ax.set_title(f"Sample {i}: Predicted\ncorr={res['avg_corr']:.3f}")

    # Ground truth coordinates
    truth_xy = res['coords_target']
    truth_ax.scatter(truth_xy[:, 0], truth_xy[:, 1], s=10, alpha=0.6, c='red')
    truth_ax.set_title("Ground Truth")

    # Shared cosmetics for both panels of the column.
    for panel in (pred_ax, truth_ax):
        panel.set_aspect('equal')
        panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('supervised_baseline_coords_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ===================================================================
# COMPLETE NOTEBOOK: ST-ONLY MODEL (PHASE 1) - SINGLE PATCH INFERENCE
# ===================================================================
import torch
import numpy as np
import scanpy as sc
from scipy.spatial.distance import cdist, pdist, squareform
from scipy.stats import pearsonr, spearmanr
import matplotlib.pyplot as plt
import seaborn as sns
import os
from datetime import datetime

# Setup
output_dir = "/home/ehtesamul/sc_st/model/gems_mousebrain_output"
timestamp = "20251125_105556"

# Phase 1 checkpoint: the ST-only model, before SC fine-tuning.
checkpoint_path = f"{output_dir}/phase1_st_checkpoint.pt"

banner = "="*70
print(banner)
print("ST-ONLY MODEL INFERENCE (Phase 1, Single Patch)")
print(banner)

# ===================================================================
# STEP 1: LOAD TEST DATA
# ===================================================================
print("\n--- Loading Test Data ---")

from run_mouse_brain_2 import load_mouse_data
scadata, stadata = load_mouse_data()

# Restrict the SC expression to the genes both datasets share, in sorted
# name order.
common = sorted(set(scadata.var_names) & set(stadata.var_names))
X_sc = scadata[:, common].X
# Sparse matrices expose .toarray(); dense arrays pass through unchanged.
X_sc = X_sc.toarray() if hasattr(X_sc, "toarray") else X_sc
sc_expr = torch.tensor(X_sc, dtype=torch.float32)

n_cells, n_genes = sc_expr.shape[0], sc_expr.shape[1]

print(f"Loaded SC data: {n_cells} cells, {n_genes} genes")
print(f"Ground truth coords shape: {scadata.obsm['spatial_gt'].shape}")

# ===================================================================
# STEP 2: LOAD MODEL AND ST-ONLY CHECKPOINT (PHASE 1)
# ===================================================================
print("\n--- Loading Model and ST-Only Checkpoint (Phase 1) ---")

from core_models_et_p3 import GEMSModel

# Pick the device the same way the other cells of this notebook do, so this
# cell also runs on a CPU-only host instead of failing on a hard-coded 'cuda'.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GEMSModel(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    D_latent=32,
    c_dim=256,
    n_heads=4,
    isab_m=64,
    device=str(device),
    use_canonicalize=True,
    use_dist_bias=True,
    dist_bins=24,
    dist_head_shared=True,
    use_angle_features=True,
    angle_bins=8,
    knn_k=12,
    self_conditioning=True,
    sc_feat_mode='concat',
    landmarks_L=16
)

# Restore all four sub-modules stored in the phase-1 checkpoint.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.encoder.load_state_dict(checkpoint['encoder'])
model.context_encoder.load_state_dict(checkpoint['context_encoder'])
model.generator.load_state_dict(checkpoint['generator'])
model.score_net.load_state_dict(checkpoint['score_net'])

print(f"✓ Loaded ST-ONLY checkpoint from: {checkpoint_path}")
# 'E_ST_best' is optional in the checkpoint; fall back to 'N/A' when absent.
print(f"  Best ST epoch: {checkpoint.get('E_ST_best', 'N/A')}")
print(f"  This model was trained ONLY on ST data (NO SC fine-tuning)")

# ===================================================================
# STEP 3: SINGLE PATCH INFERENCE (DIAGNOSTIC MODE)
# ===================================================================
print("\n--- Running Single Patch Inference (ST-Only Model) ---")
print(f"Config: patch_size={n_cells}, coverage_per_cell=1.0, n_align_iters=1")
print("This runs ONE patch with ALL cells (no stitching)")
print("-"*70)

In [None]:
# ============================================================================
# DIFFUSION INFERENCE ON ST MINISETS - COMPLETE CODE
# ============================================================================

import torch
import numpy as np
import pandas as pd
import anndata as ad
import matplotlib.pyplot as plt
import os
from torch.utils.data import DataLoader

from core_models_et_p3 import GEMSModel
from core_models_et_p1 import STSetDataset, collate_minisets
import utils_et as uet

# ============================================================================
# SETUP
# ============================================================================

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
output_dir = '/home/ehtesamul/sc_st/model/gems_mousebrain_output'
checkpoint_path = os.path.join(output_dir, 'phase1_st_checkpoint.pt')

rule = "="*80
print(rule)
print("DIFFUSION MODEL INFERENCE ON ST MINISETS")
print(rule)

# ============================================================================
# LOAD ST DATA
# ============================================================================

st_counts = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_counts_et.csv'
st_meta   = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_metadata_et.csv'
st_ct     = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_celltype_et.csv'

print("\nLoading ST1 data...")
# Counts, metadata and cell-type tables, each indexed by its first column.
st_expr_df, st_meta_df, st_ct_df = (
    pd.read_csv(table_path, index_col=0)
    for table_path in (st_counts, st_meta, st_ct)
)

# Transposed so the count table's columns are observations and its index
# supplies the variable names.
stadata = ad.AnnData(X=st_expr_df.values.T)
stadata.obs_names = st_expr_df.columns
stadata.var_names = st_expr_df.index
# NOTE(review): rows of the metadata / cell-type tables are attached by
# position, not by name - confirm they follow the count-table column order.
stadata.obsm['spatial'] = st_meta_df[['coord_x', 'coord_y']].values
stadata.obs['celltype_mapped_refined'] = st_ct_df.idxmax(axis=1).values

print(f"ST1 loaded: {stadata.shape[0]} spots, {stadata.shape[1]} genes")

# Dense expression matrix and raw coordinates as float32 tensors on `device`.
X_st = stadata.X
X_st = X_st.toarray() if hasattr(X_st, "toarray") else X_st

st_expr = torch.tensor(X_st, dtype=torch.float32, device=device)
st_coords_raw = torch.tensor(stadata.obsm['spatial'], dtype=torch.float32, device=device)

# Canonicalize per slide; all rows share slide id 0.
slide_ids = torch.zeros(st_expr.shape[0], dtype=torch.long, device=device)
st_coords, st_mu, st_scale = uet.canonicalize_st_coords_per_slide(st_coords_raw, slide_ids)

print(f"ST coords canonicalized: scale={st_scale[0].item():.4f}")

# ============================================================================
# LOAD MODEL AND PHASE 1 CHECKPOINT
# ============================================================================

n_genes = stadata.shape[1]

model = GEMSModel(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    D_latent=32, c_dim=256,
    n_heads=4, isab_m=64,
    device=str(device),
    use_canonicalize=True,
    use_dist_bias=True, dist_bins=24, dist_head_shared=True,
    use_angle_features=True, angle_bins=8,
    knn_k=12,
    self_conditioning=True, sc_feat_mode='concat',
    landmarks_L=16,
)

print(f"\nLoading checkpoint: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=device)

# Each sub-module is stored in the checkpoint under its attribute name.
for submodule_name in ('encoder', 'context_encoder', 'generator', 'score_net'):
    getattr(model, submodule_name).load_state_dict(checkpoint[submodule_name])

print("✓ Loaded Phase 1 ST-only checkpoint")
print(f"  Best ST epoch: {checkpoint.get('E_ST_best', 'N/A')}")

# Switch the listed sub-modules to eval mode.
# NOTE(review): the generator is loaded above but not put in eval mode here -
# confirm that is intentional.
for submodule_name in ('encoder', 'context_encoder', 'score_net'):
    getattr(model, submodule_name).eval()

# ============================================================================
# RUN STAGE B TO GET TARGETS_DICT
# ============================================================================

print("\n=== Running Stage B ===")
# Single slide under id 0: (canonicalized coordinates, expression).
slides_dict = {0: (st_coords, st_expr)}
model.train_stageB(slides=slides_dict, outdir='temp_stageB_cache')
print("Stage B complete.")

# ============================================================================
# CREATE ST MINISET DATASET
# ============================================================================

# The dataset is given a CPU copy of the expression matrix.
st_gene_expr_dict_cpu = {0: st_expr.cpu()}

st_dataset = STSetDataset(
    # Stage B output and the encoder
    targets_dict=model.targets_dict,
    encoder=model.encoder,
    st_gene_expr_dict=st_gene_expr_dict_cpu,
    # Miniset sampling configuration
    n_min=64, n_max=384,
    num_samples=4000,
    # Geometry settings matching the model
    D_latent=model.D_latent,
    knn_k=12,
    landmarks_L=16,
    device=device,
)

# Shuffled batches of 8 minisets, loaded in the main process.
st_loader = DataLoader(
    st_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_minisets
)

print(f"ST dataset created: {len(st_dataset)} samples")

# ============================================================================
# RUN DIFFUSION INFERENCE ON ST MINISETS
# ============================================================================

num_eval_samples = 5
diffusion_results = []

print("\n--- Running diffusion inference on ST minisets ---\n")

# One host copy of the expression matrix serves every miniset lookup; the
# original re-copied it from the device on each iteration.
st_expr_cpu = st_expr.cpu()

with torch.no_grad():
    eval_iter = iter(st_loader)

    for i in range(num_eval_samples):
        batch = next(eval_iter)

        mask = batch['mask'].to(device)
        D_target = batch['D_target'].to(device)

        # Take first sample in batch
        b = 0
        m = mask[b]
        n_valid = m.sum().item()

        # Get indices for this miniset. The mask was moved to `device` above
        # while the collated indices may still be on the CPU, so index with a
        # mask on the indices' own device.
        indices = batch['overlap_info']['indices'][b]
        valid_indices = indices[m.to(indices.device)].cpu()

        # Get gene expression for these specific ST spots
        miniset_expr = st_expr_cpu[valid_indices]

        print(f"Sample {i}: Running diffusion inference on {n_valid} points...")

        # Run patchwise inference with single patch (no stitching)
        inf_results = model.infer_sc_patchwise(
            sc_gene_expr=miniset_expr,
            n_timesteps_sample=300,
            sigma_min=0.01,
            sigma_max=7.0,
            patch_size=n_valid,          # Single patch = all points
            coverage_per_cell=1.0,       # No overlap
            n_align_iters=1,             # No stitching
            eta=0.0,
            guidance_scale=6.0,
            return_coords=True,
            debug_flag=False,
            debug_every=10,
        )

        # Extract predicted coordinates
        coords_diffusion = inf_results['coords_canon']

        # Reference layout: classical MDS of the target distance matrix
        # (double-centre the squared distances, embed in 2D, canonicalize).
        D_target_sample = D_target[b, m][:, m].cpu()
        n = D_target_sample.shape[0]
        Jn = torch.eye(n) - torch.ones(n, n) / n
        B_target = -0.5 * (Jn @ (D_target_sample ** 2) @ Jn)
        coords_target = uet.classical_mds(B_target, d_out=2)
        coords_target_canon = uet.canonicalize_coords(coords_target)

        # Per-axis Pearson correlation; absolute values are averaged so an
        # axis flip does not count against the prediction.
        corr_x = np.corrcoef(coords_diffusion[:, 0].numpy(), coords_target_canon[:, 0].numpy())[0, 1]
        corr_y = np.corrcoef(coords_diffusion[:, 1].numpy(), coords_target_canon[:, 1].numpy())[0, 1]
        avg_corr = (abs(corr_x) + abs(corr_y)) / 2.0

        # Correlation between all entries of predicted and target distances.
        D_diffusion = torch.cdist(coords_diffusion.unsqueeze(0), coords_diffusion.unsqueeze(0)).squeeze(0)
        edm_corr = np.corrcoef(
            D_diffusion.flatten().numpy(),
            D_target_sample.flatten().numpy()
        )[0, 1]

        diffusion_results.append({
            'sample': i,
            'n_points': n_valid,
            'corr_x': corr_x,
            'corr_y': corr_y,
            'avg_corr': avg_corr,
            'edm_corr': edm_corr,
            'coords_diffusion': coords_diffusion.numpy(),
            'coords_target': coords_target_canon.numpy()
        })

        print(f"  EDM_corr={edm_corr:.4f} | Coord_corr: x={corr_x:.4f}, y={corr_y:.4f}, avg={avg_corr:.4f}\n")

print("="*80)
print("DIFFUSION INFERENCE COMPLETE")
print("="*80)

# ============================================================================
# PRINT COMPARISON (assuming eval_results from supervised baseline exists)
# ============================================================================

# Averages over the minisets evaluated above.
mean_edm_corr = np.mean([r['edm_corr'] for r in diffusion_results])
mean_coord_corr = np.mean([r['avg_corr'] for r in diffusion_results])

print("\nDiffusion Model (Phase 1 ST-only) Results:")
print(f"  Average EDM correlation:   {mean_edm_corr:.4f}")
print(f"  Average Coord correlation: {mean_coord_corr:.4f}")

# ============================================================================
# PLOT: DIFFUSION vs GROUND TRUTH
# ============================================================================

# One column per miniset: diffusion output on top, reference layout below.
fig, axes = plt.subplots(2, num_eval_samples, figsize=(4*num_eval_samples, 8))

for i in range(num_eval_samples):
    result = diffusion_results[i]
    top_ax, bottom_ax = axes[0, i], axes[1, i]

    # Diffusion prediction
    diffusion_xy = result['coords_diffusion']
    top_ax.scatter(diffusion_xy[:, 0], diffusion_xy[:, 1], s=10, alpha=0.6, c='green')
    top_ax.set_title(f"Sample {i}: Diffusion\ncorr={result['avg_corr']:.3f}")

    # Ground truth
    target_xy = result['coords_target']
    bottom_ax.scatter(target_xy[:, 0], target_xy[:, 1], s=10, alpha=0.6, c='red')
    bottom_ax.set_title("Ground Truth")

    # Shared cosmetics for both panels of the column.
    for panel in (top_ax, bottom_ax):
        panel.set_aspect('equal')
        panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('diffusion_vs_groundtruth.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n✓ Plot saved: diffusion_vs_groundtruth.png")

In [None]:
# ============================================================================
# DIFFUSION INFERENCE ON ST MINISETS - FIXED
# ============================================================================

import torch
import numpy as np
import pandas as pd
import anndata as ad
import matplotlib.pyplot as plt
import os

from core_models_et_p3 import GEMSModel
import utils_et as uet

# Run on GPU when one is visible, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
output_dir = '/home/ehtesamul/sc_st/model/gems_mousebrain_output'
checkpoint_path = os.path.join(output_dir, 'phase2_sc_finetuned_checkpoint.pt')

print("="*80)
print("DIFFUSION MODEL INFERENCE ON ST MINISETS")
print("="*80)

# ============================================================================
# LOAD ST DATA
# ============================================================================

st_counts = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_counts_et.csv'
st_meta   = '/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st1_metadata_et.csv'

print("\nLoading ST1 data...")
st_expr_df = pd.read_csv(st_counts, index_col=0)
st_meta_df = pd.read_csv(st_meta, index_col=0)

# The counts table is transposed below, so its columns become the observations
# (spots) and its index becomes the variables (genes).
spot_ids = st_expr_df.columns

# Attach coordinates by spot label rather than by row position whenever the
# metadata index contains every spot label exactly once; this protects against
# the two CSVs listing spots in a different order.
if st_meta_df.index.is_unique and spot_ids.isin(st_meta_df.index).all():
    st_meta_aligned = st_meta_df.loc[spot_ids]
else:
    # NOTE(review): labels do not match, so the original positional pairing is
    # kept — confirm that both files really list spots in the same order.
    if len(st_meta_df) != len(spot_ids):
        raise ValueError(
            f"ST metadata has {len(st_meta_df)} rows but the count matrix has "
            f"{len(spot_ids)} spots; cannot pair coordinates with spots."
        )
    st_meta_aligned = st_meta_df

stadata = ad.AnnData(X=st_expr_df.values.T)
stadata.obs_names = spot_ids
stadata.var_names = st_expr_df.index
stadata.obsm['spatial'] = st_meta_aligned[['coord_x', 'coord_y']].values

# Densify if AnnData handed back a sparse matrix.
X_st = stadata.X
if hasattr(X_st, "toarray"):
    X_st = X_st.toarray()

st_expr = torch.tensor(X_st, dtype=torch.float32, device=device)
st_coords_raw = torch.tensor(stadata.obsm['spatial'], dtype=torch.float32, device=device)

# Every spot is assigned slide id 0, i.e. the data is treated as one slide.
slide_ids = torch.zeros(st_expr.shape[0], dtype=torch.long, device=device)
st_coords, st_mu, st_scale = uet.canonicalize_st_coords_per_slide(
    st_coords_raw, slide_ids
)

print(f"ST loaded: {stadata.shape[0]} spots, {stadata.shape[1]} genes")

# ============================================================================
# LOAD MODEL
# ============================================================================

# Number of genes = number of variables in the ST AnnData built above.
n_genes = stadata.shape[1]

# NOTE(review): these hyperparameters presumably have to match the ones used
# when the checkpoint was trained, otherwise load_state_dict will reject the
# weights — confirm against the training script.
model = GEMSModel(
    n_genes=n_genes,
    n_embedding=[512, 256, 128],
    D_latent=32,
    c_dim=256,
    n_heads=4,
    isab_m=64,
    device=str(device),
    use_canonicalize=True,
    use_dist_bias=True,
    dist_bins=24,
    dist_head_shared=True,
    use_angle_features=True,
    angle_bins=8,
    knn_k=12,
    self_conditioning=True,
    sc_feat_mode='concat',
    landmarks_L=16,
)

print(f"\nLoading checkpoint: {checkpoint_path}")
# NOTE(review): torch.load unpickles the file; only load checkpoints that come
# from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=device)

# Sub-modules whose weights are stored under same-named checkpoint keys.
loaded_submodules = ('encoder', 'context_encoder', 'generator', 'score_net')
for submodule_name in loaded_submodules:
    getattr(model, submodule_name).load_state_dict(checkpoint[submodule_name])

# Report the file actually loaded instead of a hard-coded phase label (the
# path above points at the phase-2 fine-tuned checkpoint).
print(f"✓ Loaded checkpoint '{os.path.basename(checkpoint_path)}' "
      f"(best epoch: {checkpoint.get('E_ST_best', 'N/A')})")

# Put every restored sub-module in eval mode, including the generator, so
# that no restored sub-module is left in training mode during inference.
for submodule_name in loaded_submodules:
    getattr(model, submodule_name).eval()

# ============================================================================
# SAMPLE ST MINISETS AND RUN DIFFUSION INFERENCE
# ============================================================================

num_eval_samples = 10
diffusion_results = []

print("\n--- Running diffusion inference on ST minisets ---\n")

# Seed BOTH generators: NumPy draws the miniset size, while torch draws the
# spot indices (randperm). Seeding only NumPy leaves the selected spots
# different on every run.
np.random.seed(42)
torch.manual_seed(42)

# Miniset size bounds (same logic as STSetDataset); invariant across samples.
n_min, n_max = 192, 384
n_total = st_coords.shape[0]
# Exclusive upper bound handed to np.random.randint.
n_upper = min(n_max + 1, n_total)

for i in range(num_eval_samples):
    # Random subset size. When the slide has no more than n_min spots the
    # randint range would be empty (ValueError), so use every spot instead.
    if n_upper > n_min:
        n = np.random.randint(n_min, n_upper)
    else:
        n = n_total

    # Random indices
    indices = torch.randperm(n_total)[:n]

    # Get gene expression and coords for this miniset
    miniset_expr = st_expr[indices].cpu()
    miniset_coords = st_coords[indices].cpu()

    # Compute ground truth EDM (pairwise Euclidean distances)
    D_target = torch.cdist(miniset_coords, miniset_coords)

    print(f"Sample {i}: Running diffusion on {n} points...")

    # Run inference with single patch (no stitching)
    with torch.no_grad():
        inf_results = model.infer_sc_patchwise(
            sc_gene_expr=miniset_expr,
            n_timesteps_sample=300,
            sigma_min=0.01,
            sigma_max=7.0,
            patch_size=n,            # Single patch
            coverage_per_cell=1.0,   # No overlap
            n_align_iters=1,         # No alignment
            eta=0.0,
            guidance_scale=6.0,
            return_coords=True,
            debug_flag=False,
        )

    # Move to CPU and drop any autograd link so the .numpy() calls below are
    # valid regardless of the device the model returned the tensor on.
    coords_diffusion = inf_results['coords_canon'].detach().cpu()

    # Ground truth coords via classical MDS on the double-centred squared EDM
    Jn = torch.eye(n) - torch.ones(n, n) / n
    B_target = -0.5 * (Jn @ (D_target**2) @ Jn)
    coords_target = uet.classical_mds(B_target, d_out=2)
    coords_target_canon = uet.canonicalize_coords(coords_target)

    # Per-axis Pearson correlation; abs() makes the average insensitive to an
    # axis sign flip.
    # NOTE(review): assumes canonicalize_coords brings both point sets into a
    # comparable orientation; a residual rotation would lower this score — confirm.
    corr_x = np.corrcoef(coords_diffusion[:, 0].numpy(), coords_target_canon[:, 0].numpy())[0, 1]
    corr_y = np.corrcoef(coords_diffusion[:, 1].numpy(), coords_target_canon[:, 1].numpy())[0, 1]
    avg_corr = (abs(corr_x) + abs(corr_y)) / 2.0

    # EDM correlation over the full flattened n x n matrices (the zero
    # diagonal and both symmetric halves are included).
    D_diffusion = torch.cdist(coords_diffusion.unsqueeze(0), coords_diffusion.unsqueeze(0)).squeeze(0)
    edm_corr = np.corrcoef(
        D_diffusion.flatten().numpy(),
        D_target.flatten().numpy()
    )[0, 1]

    diffusion_results.append({
        'sample': i,
        'n_points': n,
        'corr_x': corr_x,
        'corr_y': corr_y,
        'avg_corr': avg_corr,
        'edm_corr': edm_corr,
        'coords_diffusion': coords_diffusion.numpy(),
        'coords_target': coords_target_canon.numpy()
    })

    print(f"  EDM_corr={edm_corr:.4f} | Coord: x={corr_x:.4f}, y={corr_y:.4f}, avg={avg_corr:.4f}\n")

print("="*80)
print(f"\nDiffusion Results (avg over {num_eval_samples} samples):")
print(f"  EDM correlation:   {np.mean([r['edm_corr'] for r in diffusion_results]):.4f}")
print(f"  Coord correlation: {np.mean([r['avg_corr'] for r in diffusion_results]):.4f}")
print("="*80)

# ============================================================================
# PLOT - 3 COLUMNS MAX PER ROW
# ============================================================================

# Each sample occupies one column and two stacked rows: prediction on top,
# ground truth directly underneath.
n_cols = min(3, num_eval_samples)
n_rows = int(np.ceil(num_eval_samples / n_cols)) * 2  # *2 for diffusion + GT rows

# squeeze=False always yields a 2-D axes array, so no reshaping is needed for
# the single-column case (n_rows is a multiple of 2 and therefore never 1).
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows),
                         squeeze=False)

for i in range(num_eval_samples):
    row_pair = (i // n_cols) * 2  # Which pair of rows (diffusion + GT)
    col = i % n_cols
    sample = diffusion_results[i]

    # Diffusion prediction
    ax_diff = axes[row_pair, col]
    ax_diff.scatter(sample['coords_diffusion'][:, 0],
                    sample['coords_diffusion'][:, 1],
                    s=10, alpha=0.6, c='green')
    ax_diff.set_title(f"Sample {i}: Diffusion\n"
                      f"Coord: {sample['avg_corr']:.3f} | "
                      f"EDM: {sample['edm_corr']:.3f}",
                      fontsize=10)
    ax_diff.set_aspect('equal')
    ax_diff.grid(True, alpha=0.3)

    # Ground truth
    ax_gt = axes[row_pair + 1, col]
    ax_gt.scatter(sample['coords_target'][:, 0],
                  sample['coords_target'][:, 1],
                  s=10, alpha=0.6, c='red')
    ax_gt.set_title(f"Ground Truth (n={sample['n_points']})",
                    fontsize=10)
    ax_gt.set_aspect('equal')
    ax_gt.grid(True, alpha=0.3)

# Hide unused subplots (grid slots beyond the last sample)
for i in range(num_eval_samples, n_rows // 2 * n_cols):
    row_pair = (i // n_cols) * 2
    col = i % n_cols
    axes[row_pair, col].axis('off')
    axes[row_pair + 1, col].axis('off')

plt.tight_layout()
plt.show()

In [None]:
# ============================================================================
# SMART OUTLIER REMOVAL - DISTANCE-BASED METHOD
# ============================================================================

import numpy as np
import matplotlib.pyplot as plt
import torch

# Banner plus the rationale for the outlier filter, emitted with a single
# print call (one argument per output line).
print(
    "\n" + "="*80,
    "OUTLIER REMOVAL - DISTANCE FROM MEDIAN CENTER",
    "="*80,
    "\nRationale:",
    "- Diffusion occasionally samples points in low-probability tail regions",
    "- Ground truth tissue has consistent density (filled region)",
    "- Outliers are scattered points FAR from main cluster",
    "- Method: Keep only points within 90th percentile distance from median center",
    "- Why median? Robust to outliers (unlike mean)",
    "- Why 90th percentile? Keeps main distribution, removes extreme tail",
    "="*80 + "\n",
    sep="\n",
)

def remove_outliers_distance_percentile(coords, coords_target, percentile=90):
    """
    Remove outliers based on distance from the median center.

    Strategy:
    1. Find the per-axis median of `coords` (robust to outliers)
    2. Compute the Euclidean distance of each point from that center
    3. Keep only points whose distance is <= the `percentile`-th percentile
    4. Apply the same row mask to `coords_target` so the pairs stay matched

    Args:
        coords: (n, d) predicted coordinates (array-like)
        coords_target: array-like with n rows, paired row-by-row with `coords`
        percentile: keep points within this percentile (90 = remove top 10%)

    Returns:
        coords_clean: rows of `coords` that were kept
        coords_target_clean: matching rows of `coords_target`
        inlier_mask: boolean array of length n, True for kept rows
        threshold: distance cut-off that was applied
    """
    # Accept lists/tuples as well as arrays; boolean-mask indexing below
    # requires ndarrays.
    coords = np.asarray(coords)
    coords_target = np.asarray(coords_target)

    # Use MEDIAN center (robust to outliers, unlike mean)
    center = np.median(coords, axis=0)

    # Distance from center for each point
    dists = np.linalg.norm(coords - center, axis=1)

    # Threshold: keep only points within percentile
    threshold = np.percentile(dists, percentile)

    # Inlier mask (inclusive, so the point sitting on the threshold is kept)
    inlier_mask = dists <= threshold

    # Filter both predicted and target with the same mask
    return coords[inlier_mask], coords_target[inlier_mask], inlier_mask, threshold

# ============================================================================
# CLEAN EACH SAMPLE
# ============================================================================
# For every entry of `diffusion_results`: drop the predicted points that lie
# farthest from the median center, re-canonicalize what is left, recompute the
# coordinate / EDM correlations and report the change.

diffusion_results_clean = []

for i, res in enumerate(diffusion_results):
    # Metrics recorded for the unfiltered sample
    corr_x_before = res['corr_x']
    corr_y_before = res['corr_y']
    avg_corr_before = res['avg_corr']
    edm_corr_before = res['edm_corr']

    coords_pred = res['coords_diffusion']
    coords_gt = res['coords_target']
    n_orig = len(coords_pred)

    # Filter prediction and ground truth with the same inlier mask
    coords_clean, coords_gt_clean, mask, thresh = remove_outliers_distance_percentile(
        coords_pred, coords_gt, percentile=90
    )

    n_kept = len(coords_clean)
    n_removed = n_orig - n_kept
    pct_removed = 100 * n_removed / n_orig

    # Re-canonicalize after filtering (each point set independently)
    coords_clean_t = torch.from_numpy(coords_clean).float()
    coords_gt_t = torch.from_numpy(coords_gt_clean).float()
    coords_clean_canon = uet.canonicalize_coords(coords_clean_t).numpy()
    coords_gt_canon = uet.canonicalize_coords(coords_gt_t).numpy()

    # Per-axis Pearson correlation on the filtered, canonical coordinates
    corr_x, corr_y = (
        np.corrcoef(coords_clean_canon[:, axis], coords_gt_canon[:, axis])[0, 1]
        for axis in (0, 1)
    )
    avg_corr = (abs(corr_x) + abs(corr_y)) / 2.0

    # EDM correlation: build each batched tensor once and reuse it for both
    # cdist arguments
    clean_batch = torch.from_numpy(coords_clean_canon).unsqueeze(0).float()
    gt_batch = torch.from_numpy(coords_gt_canon).unsqueeze(0).float()
    D_clean = torch.cdist(clean_batch, clean_batch).squeeze(0)
    D_gt = torch.cdist(gt_batch, gt_batch).squeeze(0)

    edm_corr = np.corrcoef(D_clean.flatten().numpy(), D_gt.flatten().numpy())[0, 1]

    # Store the cleaned sample alongside its bookkeeping numbers
    cleaned_entry = {
        'sample': i,
        'n_points': n_kept,
        'n_removed': n_removed,
        'pct_removed': pct_removed,
        'corr_x': corr_x,
        'corr_y': corr_y,
        'avg_corr': avg_corr,
        'edm_corr': edm_corr,
        'coords_diffusion': coords_clean_canon,
        'coords_target': coords_gt_canon,
        'threshold': thresh
    }
    diffusion_results_clean.append(cleaned_entry)

    # Per-sample report
    print(f"Sample {i}: removed {n_removed}/{n_orig} outliers ({pct_removed:.1f}%), "
          f"threshold={thresh:.3f}")
    print(f"  Before: Coord={avg_corr_before:.3f}, EDM={edm_corr_before:.3f}")
    print(f"  After:  Coord={avg_corr:.3f} (Δ={avg_corr-avg_corr_before:+.3f}), "
          f"EDM={edm_corr:.3f} (Δ={edm_corr-edm_corr_before:+.3f})\n")

# ============================================================================
# SUMMARY STATISTICS
# ============================================================================
# Averages each metric over all samples before and after filtering, then
# prints a verdict based on the change in mean EDM correlation.

print("="*80, "SUMMARY: BEFORE vs AFTER OUTLIER REMOVAL", "="*80 + "\n", sep="\n")

avg_coord_before, avg_edm_before = (
    np.mean([r[key] for r in diffusion_results]) for key in ('avg_corr', 'edm_corr')
)
avg_coord_after, avg_edm_after = (
    np.mean([r[key] for r in diffusion_results_clean]) for key in ('avg_corr', 'edm_corr')
)

avg_pct_removed = np.mean([r['pct_removed'] for r in diffusion_results_clean])

print("Average Coordinate Correlation:")
print(f"  Before: {avg_coord_before:.4f}")
print(f"  After:  {avg_coord_after:.4f} (Δ={avg_coord_after-avg_coord_before:+.4f})")

print("\nAverage EDM Correlation:")
print(f"  Before: {avg_edm_before:.4f}")
print(f"  After:  {avg_edm_after:.4f} (Δ={avg_edm_after-avg_edm_before:+.4f})")

print(f"\nAverage outliers removed: {avg_pct_removed:.1f}%")

print("\n" + "="*80, "INTERPRETATION:", "="*80, sep="\n")

# Pick the verdict lines first, print them afterwards. The comparison order
# is kept (> 0.1, then > 0, else) so a NaN delta still lands in the last case.
edm_delta = avg_edm_after - avg_edm_before
if edm_delta > 0.1:
    verdict_lines = [
        "✓ EDM correlation IMPROVED significantly after outlier removal",
        "  → Confirms outliers were corrupting distance metrics",
        "  → Main cluster has better geometric structure than raw output",
    ]
elif edm_delta > 0:
    verdict_lines = [
        "✓ EDM correlation improved slightly",
        "  → Outliers had some negative effect on distances",
    ]
else:
    verdict_lines = [
        "⚠ EDM correlation unchanged or decreased",
        "  → Problem is not just outliers, geometry of main cluster needs work",
    ]
for verdict_line in verdict_lines:
    print(verdict_line)
print("="*80 + "\n")

# ============================================================================
# SIMPLE PLOT: PREDICTED vs GROUND TRUTH (AFTER OUTLIER REMOVAL)
# ============================================================================

# Size the grid from the list that is actually plotted, not from a counter
# defined in another cell, so the two can never disagree.
n_clean_samples = len(diffusion_results_clean)

# Each sample occupies one column and two stacked rows: prediction on top,
# ground truth directly underneath.
n_cols = 3
n_rows = int(np.ceil(n_clean_samples / n_cols)) * 2

# squeeze=False always yields a 2-D axes array, so no reshaping is required.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows),
                         squeeze=False)

for i, cleaned in enumerate(diffusion_results_clean):
    row_pair = (i // n_cols) * 2
    col = i % n_cols

    # Predicted (cleaned)
    ax_pred = axes[row_pair, col]
    ax_pred.scatter(cleaned['coords_diffusion'][:, 0],
                    cleaned['coords_diffusion'][:, 1],
                    s=10, alpha=0.7, c='#2ecc71', edgecolors='none')
    ax_pred.set_title(f"Sample {i}: Predicted\n"
                      f"Coord: {cleaned['avg_corr']:.3f} | "
                      f"EDM: {cleaned['edm_corr']:.3f}",
                      fontsize=10)
    ax_pred.set_aspect('equal')
    ax_pred.grid(True, alpha=0.2)

    # Ground Truth
    ax_gt = axes[row_pair + 1, col]
    ax_gt.scatter(cleaned['coords_target'][:, 0],
                  cleaned['coords_target'][:, 1],
                  s=10, alpha=0.7, c='#e74c3c', edgecolors='none')
    ax_gt.set_title(f"Ground Truth (n={cleaned['n_points']})",
                    fontsize=10)
    ax_gt.set_aspect('equal')
    ax_gt.grid(True, alpha=0.2)

# Hide unused subplots (grid slots beyond the last sample)
for i in range(n_clean_samples, n_rows // 2 * n_cols):
    row_pair = (i // n_cols) * 2
    col = i % n_cols
    axes[row_pair, col].axis('off')
    axes[row_pair + 1, col].axis('off')

plt.tight_layout()
# plt.savefig('cleaned_results.png', dpi=200, bbox_inches='tight')
plt.show()

# print("✓ Saved plot: outlier_removal_comparison.png")