In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"


from threadpoolctl import threadpool_limits
threadpool_limits(limits=1, user_api='blas')

import scanpy as sc
import pandas as pd
import torch
import numpy as np
import scipy
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Import new GEMS model components
from core_models_et_p1 import SharedEncoder, STStageBPrecomputer, STSetDataset
from core_models_et_p2 import SetEncoderContext, DiffusionScoreNet
from core_models_et_p3 import GEMSModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load SC data
print("Loading SC data...")
scdata = pd.read_csv('/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_sc_counts.csv', index_col=0)
scdata = scdata.T
scmetadata = pd.read_csv('/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/metadata.csv', index_col=0)
# Load ST data  
print("Loading ST data...")
stdata = pd.read_csv('/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st_counts.csv', index_col=0)

stdata = stdata.T
spcoor = pd.read_csv('/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st_metadata.csv', index_col=0)
stgtcelltype = pd.read_csv('/home/ehtesamul/sc_st/data/mousedata_2020/E1z2/simu_st_celltype.csv', index_col=0)

print(f"SC data shape: {scdata.shape}")
print(f"ST data shape: {stdata.shape}")
print(f"ST coords shape: {spcoor.shape}")
print(f"ST celltype shape: {stgtcelltype.shape}")

Loading SC data...
Loading ST data...
SC data shape: (10150, 351)
ST data shape: (581, 351)
ST coords shape: (581, 2)
ST celltype shape: (581, 23)


In [3]:
# Create SC AnnData
scadata = sc.AnnData(scdata, obs=scmetadata)
sc.pp.normalize_total(scadata)
sc.pp.log1p(scadata)

# Add spatial coordinates from metadata
scadata.obsm['spatial'] = scmetadata[['x_global', 'y_global']].values

print(f"SC AnnData: {scadata}")

# Create ST AnnData
stadata = sc.AnnData(stdata)
sc.pp.normalize_total(stadata)
sc.pp.log1p(stadata)

# Add spatial coordinates
stadata.obsm['spatial'] = spcoor[['coord_x', 'coord_y']].values

# Process cell type information
cell_type_columns = stgtcelltype.columns
dominant_celltypes = []

for i in range(stgtcelltype.shape[0]):
    cell_types_present = [col for col, val in zip(cell_type_columns, stgtcelltype.iloc[i]) if val > 0]
    dominant_celltype = cell_types_present[0] if cell_types_present else 'Unknown'
    dominant_celltypes.append(dominant_celltype)

stadata.obs['celltype'] = dominant_celltypes

print(f"ST AnnData: {stadata}")
print(f"ST cell types: {stadata.obs['celltype'].value_counts()}")

SC AnnData: AnnData object with n_obs × n_vars = 10150 × 351
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'uniqueID', 'embryo', 'pos', 'z', 'x_global', 'y_global', 'x_global_affine', 'y_global_affine', 'embryo_pos', 'embryo_pos_z', 'Area', 'UMAP1', 'UMAP2', 'celltype_mapped_refined'
    uns: 'log1p'
    obsm: 'spatial'
ST AnnData: AnnData object with n_obs × n_vars = 581 × 351
    obs: 'celltype'
    uns: 'log1p'
    obsm: 'spatial'
ST cell types: celltype
Forebrain/Midbrain/Hindbrain      99
Endothelium                       65
Gut tube                          49
Spinal cord                       45
Cranial mesoderm                  41
Cardiomyocytes                    37
Dermomyotome                      30
Intermediate mesoderm             30
Neural crest                      29
Haematoendothelial progenitors    22
Definitive endoderm               22
Lateral plate mesoderm            22
Splanchnic mesoderm               15
Presomitic mesoderm               14
Anterior somi

In [4]:
def train_gems_mousebrain(scadata, stadata, output_dir='gems_mousebrain_output', device='cuda'):
    """
    Train GEMS model for mouse brain data with mixed ST/SC training.
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    
    print("\n" + "="*70)
    print("GEMS MOUSE BRAIN TRAINING (MIXED ST/SC)")
    print("="*70)
    print(f"Device: {device}")
    
    # Get common genes
    sc_genes = set(scadata.var_names)
    st_genes = set(stadata.var_names)
    common_genes = sorted(list(sc_genes & st_genes))
    n_genes = len(common_genes)
    
    print(f"Common genes: {n_genes}")
    
    # Extract expression data
    sc_expr = scadata[:, common_genes].X
    st_expr = stadata[:, common_genes].X
    
    # Convert to dense if sparse
    if hasattr(sc_expr, 'toarray'):
        sc_expr = sc_expr.toarray()
    if hasattr(st_expr, 'toarray'):
        st_expr = st_expr.toarray()
    
    # Get spatial coordinates
    st_coords = stadata.obsm['spatial']
    
    print(f"SC expression shape: {sc_expr.shape}")
    print(f"ST expression shape: {st_expr.shape}")
    print(f"ST coords shape: {st_coords.shape}")
    
    # Convert to tensors
    sc_expr_tensor = torch.tensor(sc_expr, dtype=torch.float32)
    st_expr_tensor = torch.tensor(st_expr, dtype=torch.float32)
    st_coords_tensor = torch.tensor(st_coords, dtype=torch.float32)
    
    # Single slide - all zeros for slide IDs
    slide_ids = torch.zeros(st_expr.shape[0], dtype=torch.long)
    
    # Prepare slide dictionary for Stage B
    slides_dict = {
        0: (st_coords_tensor, st_expr_tensor)
    }
    
    # Prepare gene expression dictionary for Stage C
    st_gene_expr_dict = {
        0: st_expr_tensor
    }
    
    # Initialize GEMS model
    model = GEMSModel(
        n_genes=n_genes,
        n_embedding=[512, 256, 128],
        D_latent=16,
        c_dim=256,
        n_heads=4,
        isab_m=64,
        device=device
    )
    
    # ========================================================================
    # STAGE A: Train Shared Encoder
    # ========================================================================
    print("\n" + "="*70)
    print("STAGE A: Training Shared Encoder")
    print("="*70)
    
    model.train_stageA(
        st_gene_expr=st_expr_tensor,
        st_coords=st_coords_tensor,
        sc_gene_expr=sc_expr_tensor,
        slide_ids=slide_ids,
        n_epochs=1000,
        batch_size=256,
        lr=0.0001,
        sigma=None,
        alpha=0.8,
        ratio_start=0.0,
        ratio_end=1.0,
        mmdbatch=1.0,
        outf=output_dir
    )
    
    # ========================================================================
    # STAGE B: Precompute Geometric Targets (with whitening + geodesic)
    # ========================================================================
    print("\n" + "="*70)
    print("STAGE B: Precomputing Geometric Targets")
    print("="*70)
    
    slides_dict_device = {
        sid: (coords.to(device), expr.to(device))
        for sid, (coords, expr) in slides_dict.items()
    }
    
    model.train_stageB(
        slides=slides_dict_device,
        outdir=str(Path(output_dir) / 'stage_b_cache')
    )
    
    # ========================================================================
    # STAGE C: Train Diffusion Generator (Mixed ST/SC with new losses)
    # ========================================================================
    print("\n" + "="*70)
    print("STAGE C: Training Diffusion Generator (Mixed ST/SC)")
    print("="*70)
    
    st_gene_expr_dict_device = {
        sid: expr.to(device)
        for sid, expr in st_gene_expr_dict.items()
    }
    
    model.train_stageC(
        st_gene_expr_dict=st_gene_expr_dict_device,
        sc_gene_expr=sc_expr_tensor,  # NEW: SC data for mixed training
        n_min=64,
        n_max=192,
        num_st_samples=600,  # NEW: separate ST sample count
        num_sc_samples=9,   # NEW: SC sample count
        n_epochs=20,         # Increased epochs
        batch_size=4,          # Smaller batch for mixed training
        lr=1e-4,
        n_timesteps=600,
        sigma_min=0.01,
        sigma_max=5.0,         # CHANGED: reduced from 50.0
        outf=output_dir
    )
    
    # Save model
    model.save(str(Path(output_dir) / 'gems_model_mousebrain.pt'))
    
    return model, common_genes

In [5]:
# ============================================================================
# TRAINING
# ============================================================================

print("Starting GEMS training with mixed ST/SC regimen...")
model, common_genes = train_gems_mousebrain(scadata, stadata, device='cuda')
print("\nTraining complete! Model saved.")

Starting GEMS training with mixed ST/SC regimen...

GEMS MOUSE BRAIN TRAINING (MIXED ST/SC)
Device: cuda
Common genes: 351
SC expression shape: (10150, 351)
ST expression shape: (581, 351)
ST coords shape: (581, 2)
GEMS Model initialized:
  Encoder: 351 → [512, 256, 128]
  D_latent: 16
  Context dim: 256
  ISAB inducing points: 64

STAGE A: Training Shared Encoder

STAGE A: Training Shared Encoder
Auto-computed sigma = 0.1761
Training encoder for 1000 epochs...
Epoch 0/1000 | Loss: 59.0723 | Pred: 59.0661 | Circle: 2.9948 | MMD: 0.0078
Epoch 100/1000 | Loss: 4.0427 | Pred: 3.7182 | Circle: 2.5459 | MMD: 0.0078
Epoch 200/1000 | Loss: 2.7264 | Pred: 2.0791 | Circle: 2.5637 | MMD: 0.0080
Epoch 300/1000 | Loss: 2.2325 | Pred: 1.2894 | Circle: 2.4976 | MMD: 0.0081
Epoch 400/1000 | Loss: 2.0397 | Pred: 0.8187 | Circle: 2.4289 | MMD: 0.0082
Epoch 500/1000 | Loss: 2.1279 | Pred: 0.6413 | Circle: 2.3677 | MMD: 0.0084
Epoch 600/1000 | Loss: 2.2469 | Pred: 0.5684 | Circle: 2.2288 | MMD: 0.0086
Ep

Epoch 1/20: 100%|██████████| 153/153 [01:42<00:00,  1.50it/s]
Epoch 2/20: 100%|██████████| 153/153 [01:45<00:00,  1.45it/s]
Epoch 3/20: 100%|██████████| 153/153 [01:41<00:00,  1.51it/s]
Epoch 4/20: 100%|██████████| 153/153 [01:44<00:00,  1.47it/s]
Epoch 5/20: 100%|██████████| 153/153 [01:43<00:00,  1.48it/s]



Epoch 5/20 | Total: 13267.6984 | Score: 13150.3667 | Gram: 231.6238 | Heat: 0.0000 | SW_ST: 0.0022 | SW_SC: 0.0044 | Overlap: 3.4030 | Ord_SC: 1.3331


Epoch 6/20: 100%|██████████| 153/153 [01:42<00:00,  1.50it/s]
Epoch 7/20: 100%|██████████| 153/153 [01:46<00:00,  1.44it/s]
Epoch 8/20: 100%|██████████| 153/153 [01:44<00:00,  1.46it/s]
Epoch 9/20: 100%|██████████| 153/153 [01:41<00:00,  1.50it/s]
Epoch 10/20: 100%|██████████| 153/153 [01:45<00:00,  1.46it/s]



Epoch 10/20 | Total: 12814.2795 | Score: 12719.2944 | Gram: 188.6810 | Heat: 0.0000 | SW_ST: 0.0019 | SW_SC: 0.0038 | Overlap: 1.2917 | Ord_SC: 0.6392


Epoch 11/20: 100%|██████████| 153/153 [01:38<00:00,  1.55it/s]
Epoch 12/20: 100%|██████████| 153/153 [01:46<00:00,  1.44it/s]
Epoch 13/20: 100%|██████████| 153/153 [01:49<00:00,  1.40it/s]
Epoch 14/20: 100%|██████████| 153/153 [01:46<00:00,  1.43it/s]
Epoch 15/20: 100%|██████████| 153/153 [01:38<00:00,  1.56it/s]



Epoch 15/20 | Total: 12372.9133 | Score: 12264.6566 | Gram: 215.2868 | Heat: 0.0000 | SW_ST: 0.0019 | SW_SC: 0.0034 | Overlap: 1.3145 | Ord_SC: 0.5654


Epoch 16/20: 100%|██████████| 153/153 [01:41<00:00,  1.51it/s]
Epoch 17/20: 100%|██████████| 153/153 [01:42<00:00,  1.50it/s]
Epoch 18/20: 100%|██████████| 153/153 [01:43<00:00,  1.48it/s]
Epoch 19/20: 100%|██████████| 153/153 [01:40<00:00,  1.52it/s]
Epoch 20/20: 100%|██████████| 153/153 [01:46<00:00,  1.44it/s]


Epoch 20/20 | Total: 12560.2731 | Score: 12476.3698 | Gram: 167.0511 | Heat: 0.0000 | SW_ST: 0.0019 | SW_SC: 0.0032 | Overlap: 0.6764 | Ord_SC: 0.4138
Training complete!
Stage C complete.
Model saved to gems_mousebrain_output/gems_model_mousebrain.pt

Training complete! Model saved.





In [None]:
# ============================================================================
# INFERENCE
# ============================================================================

print("\n" + "="*70)
print("SC COORDINATE INFERENCE (ANCHOR-CONDITIONED)")
print("="*70)

# Prepare SC expression tensor
sc_expr = scadata[:, common_genes].X
if hasattr(sc_expr, 'toarray'):
    sc_expr = sc_expr.toarray()
sc_expr_tensor = torch.tensor(sc_expr, dtype=torch.float32)

print(f"SC data shape: {sc_expr_tensor.shape}")

# Clear cache before inference
if torch.cuda.is_available():
    torch.cuda.empty_cache()
import gc
gc.collect()

# Run anchor-conditioned inference
results = model.infer_sc_anchored(
    sc_gene_expr=sc_expr_tensor,
    n_timesteps_sample=160,    # Recommended: 120-200 for 20GB GPU
    return_coords=True,
    anchor_size=384,           # Recommended: 256-384
    batch_size=512,            # Recommended: 384-512
    eta=0.0                    # Deterministic (safer)
)

print(f"\nInference complete:")
print(f"  D_edm shape: {results['D_edm'].shape}")
if 'coords_canon' in results:
    print(f"  Coordinates shape: {results['coords_canon'].shape}")

# Add coordinates to scadata
coords_canon = results['coords_canon'].numpy()
scadata.obsm['gems_coords'] = coords_canon

print(f"\nGenerated coordinates added to scadata.obsm['gems_coords']")
print(f"Shape: {scadata.obsm['gems_coords'].shape}")

In [None]:
# ============================================================================
# VISUALIZATION
# ============================================================================

import matplotlib.pyplot as plt
import seaborn as sns

fig, axes = plt.subplots(1, 2, figsize=(16, 7))

# Plot 1: GEMS coordinates colored by cell type
if 'cell_type' in scadata.obs.columns:
    cell_types = scadata.obs['cell_type']
    unique_types = cell_types.unique()
    
    for ct in unique_types:
        mask = cell_types == ct
        axes[0].scatter(
            coords_canon[mask, 0], 
            coords_canon[mask, 1],
            s=1, 
            alpha=0.6, 
            label=ct
        )
    
    axes[0].set_title('GEMS Coordinates (by cell type)', fontsize=14)
    axes[0].set_xlabel('GEMS Dim 1')
    axes[0].set_ylabel('GEMS Dim 2')
    axes[0].legend(markerscale=5, fontsize=8, loc='best')
else:
    axes[0].scatter(coords_canon[:, 0], coords_canon[:, 1], s=1, alpha=0.6)
    axes[0].set_title('GEMS Coordinates', fontsize=14)
    axes[0].set_xlabel('GEMS Dim 1')
    axes[0].set_ylabel('GEMS Dim 2')

axes[0].axis('equal')

# Plot 2: Distance distribution sanity check
D_edm = results['D_edm'].numpy()
upper_tri_idx = np.triu_indices_from(D_edm, k=1)
distances = D_edm[upper_tri_idx]

axes[1].hist(distances, bins=100, alpha=0.7, edgecolor='black')
axes[1].set_title('Distance Distribution (EDM)', fontsize=14)
axes[1].set_xlabel('Distance')
axes[1].set_ylabel('Count')
axes[1].axvline(distances.mean(), color='r', linestyle='--', label=f'Mean: {distances.mean():.2f}')
axes[1].axvline(np.median(distances), color='g', linestyle='--', label=f'Median: {np.median(distances):.2f}')
axes[1].legend()

plt.tight_layout()
plt.savefig(str(Path(output_dir) / 'gems_inference_results.png'), dpi=300, bbox_inches='tight')
plt.show()

print(f"\nVisualization saved to {output_dir}/gems_inference_results.png")

# Save results
results_save = {
    'coords': coords_canon,
    'D_edm': D_edm,
    'common_genes': common_genes
}
torch.save(results_save, str(Path(output_dir) / 'gems_inference_results.pt'))
print(f"Results saved to {output_dir}/gems_inference_results.pt")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# Get the distance matrix from results
D = results['D_edm'].cpu().numpy()  # Shape: (n_cells, n_cells)
n = D.shape[0]

# Build centering matrix H = I - (1/n) * ones
H = np.eye(n) - np.ones((n, n)) / n

# Double-center: G = -0.5 * H * D^2 * H
D_squared = D ** 2
G = -0.5 * H @ D_squared @ H

# Compute eigenvalues (only top 20 to save time)
eigenvalues = np.linalg.eigvalsh(G)
eigenvalues = np.sort(eigenvalues)[::-1]  # Sort descending
top_20 = eigenvalues[:20]

print("Top 20 eigenvalues:")
print(top_20)
print(f"\nTop 10 eigenvalues: {top_20[:10]}")
print(f"Ratio λ1/λ10: {top_20[0] / top_20[9]:.2f}")

# Plot
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(range(1, 21), top_20, 'o-')
plt.xlabel('Eigenvalue rank')
plt.ylabel('Eigenvalue')
plt.title('Top 20 Eigenvalues')
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(range(1, 11), top_20[:10], 'o-', color='red')
plt.xlabel('Eigenvalue rank')
plt.ylabel('Eigenvalue')
plt.title('Top 10 Eigenvalues (Blob Check)')
plt.grid(True)

plt.tight_layout()
plt.show()

# Blob indicator check
if top_20[0] / top_20[9] < 3.0:
    print("\n⚠️ BLOB PATTERN: Top 10 eigenvalues are similar → coordinates may collapse to a blob")
else:
    print("\n✓ Good eigenvalue spread → coordinates should be well-separated")

In [None]:
# Set up plotting
plt.rcParams['figure.figsize'] = (6,5)

# Check if celltype exists in scadata
if 'celltype' not in scadata.obs.columns:
    # Use the cell type from metadata if available
    if 'celltype_mapped_refined' in scmetadata.columns:
        scadata.obs['celltype'] = scmetadata['celltype_mapped_refined'].values
    else:
        print("No celltype column found. Creating dummy column.")
        scadata.obs['celltype'] = 'Unknown'

# Get number of unique cell types and create palette
n_celltypes = scadata.obs['celltype'].nunique()
my_tab20 = sns.color_palette("tab20", n_colors=n_celltypes).as_hex()

# Plot 1: Original SC coordinates (if they exist)
if 'x_global' in scadata.obs.columns:
    fig = plt.figure(figsize=(6, 3))
    sc.pl.embedding(
        scadata, 
        basis='spatial', 
        color='celltype',
        title='Original SC Coordinates',
        size=60,
        palette=my_tab20,
        legend_loc='right margin',
        show=True
    )

# Plot 2: Generated GEMS coordinates
fig = plt.figure(figsize=(6, 3))
sc.pl.embedding(
    scadata,
    basis='gems_coords_avg',
    color='celltype',
    title='Generated GEMS Coordinates - Mouse Brain',
    size=60,
    palette=my_tab20,
    legend_loc='right margin',
    show=True
)