In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import scanpy as sc
import pandas as pd
import torch
import scipy
import time
from sklearn.neighbors import RadiusNeighborsTransformer
from model.core_models_v2 import AdvancedHierarchicalDiffusion

In [None]:
# Load the simulated mouse E1z2 dataset from one directory. The count matrices
# are transposed after loading so that rows are observations (cells / spots),
# which is the orientation the AnnData construction further down relies on.
DATA_DIR = './data/mousedata_2020/E1z2'

# Load SC data
print("Loading SC data...")
scdata = pd.read_csv(os.path.join(DATA_DIR, 'simu_sc_counts.csv'), index_col=0).T
scmetadata = pd.read_csv(os.path.join(DATA_DIR, 'metadata.csv'), index_col=0)

# Load ST data
print("Loading ST data...")
stdata = pd.read_csv(os.path.join(DATA_DIR, 'simu_st_counts_et.csv'), index_col=0).T
spcoor = pd.read_csv(os.path.join(DATA_DIR, 'simu_st_metadata_et.csv'), index_col=0)
stgtcelltype = pd.read_csv(os.path.join(DATA_DIR, 'simu_st_celltype_et.csv'), index_col=0)

print(f"SC data shape: {scdata.shape}")
print(f"ST data shape: {stdata.shape}")
print(f"ST coords shape: {spcoor.shape}")
print(f"ST celltype shape: {stgtcelltype.shape}")

In [None]:
# Create SC AnnData: cell metadata becomes .obs, then library-size normalise
# and log-transform the counts.
scadata = sc.AnnData(scdata, obs=scmetadata)
sc.pp.normalize_total(scadata)
sc.pp.log1p(scadata)

# Ground-truth single-cell positions. Read them from scadata.obs (rather than the
# raw scmetadata frame) so they follow the AnnData's own row order.
scadata.obsm['spatial'] = scadata.obs[['x_global', 'y_global']].to_numpy()

print(f"SC AnnData: {scadata}")
# These are the .obs metadata columns (the cell-type label is one of them).
print(f"SC obs columns: {scadata.obs.columns.tolist()}")

# ===================================================================
# Process ST Data
# ===================================================================


def _align_rows_to_obs(frame, obs_names, label):
    """Return `frame` with one row per entry of `obs_names`, in that order.

    Rows are matched by label when `frame` has a unique index whose string form
    contains every name in `obs_names`. Otherwise the rows are assumed to already
    be in AnnData order, which is only accepted when the row counts agree.

    Raises:
        ValueError: if rows cannot be matched by label and the lengths differ.
    """
    index_as_str = frame.index.astype(str)
    if index_as_str.is_unique and obs_names.isin(index_as_str).all():
        return frame.set_axis(index_as_str, axis=0).loc[obs_names]
    if len(frame) != len(obs_names):
        raise ValueError(
            f"{label}: {len(frame)} rows cannot be matched to {len(obs_names)} spots"
        )
    return frame


# Create ST AnnData with the same normalisation as the SC data
stadata = sc.AnnData(stdata)
sc.pp.normalize_total(stadata)
sc.pp.log1p(stadata)

# Spot coordinates, aligned to stadata's spots (not attached by position blindly).
spot_coords = _align_rows_to_obs(spcoor, stadata.obs_names, "ST coordinates")
stadata.obsm['spatial'] = spot_coords[['coord_x', 'coord_y']].to_numpy()

# Per-spot cell-type table: one column per cell type, a positive value meaning present.
spot_celltypes = _align_rows_to_obs(stgtcelltype, stadata.obs_names, "ST cell types").fillna(0)
# Label each spot with its dominant (largest-valued) cell type; idxmax keeps the
# first column on ties. Spots with no positive entry are labelled 'Unknown'.
has_celltype = spot_celltypes.gt(0).any(axis=1)
dominant_celltypes = spot_celltypes.idxmax(axis=1).where(has_celltype, 'Unknown')
# Assign as an array: the series is already in stadata's row order.
stadata.obs['celltype'] = dominant_celltypes.to_numpy()

print(f"ST AnnData: {stadata}")
print(f"ST cell types: {stadata.obs['celltype'].value_counts()}")

In [None]:
# # ============================================================================
# # COMPLETE TRAINING AND INFERENCE FOR MOUSE BRAIN DATA
# # ============================================================================

# import torch
# import numpy as np
# from tqdm import tqdm
# import matplotlib.pyplot as plt
# import seaborn as sns
# from model.core_models_v2 import AdvancedHierarchicalDiffusion
# # from model.utils import canonicalize_coordinates

# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# print(f"Using device: {device}")

# # ===========================
# # STEP 1: PREPARE DATA
# # ===========================
# print("\n" + "="*70)
# print("STEP 1: Preparing data")
# print("="*70)

# # Get common genes
# sc_genes = set(scadata.var_names)
# st_genes = set(stadata.var_names)
# common_genes = sorted(list(sc_genes & st_genes))
# print(f"Common genes: {len(common_genes)}")

# # Extract and convert to dense
# def to_dense(X):
#     return X.toarray() if hasattr(X, 'toarray') else X

# sc_expr = to_dense(scadata[:, common_genes].X)
# st_expr = to_dense(stadata[:, common_genes].X)
# st_coords = stadata.obsm['spatial']

# print(f"SC expression shape: {sc_expr.shape}")
# print(f"ST expression shape: {st_expr.shape}")
# print(f"ST coords shape: {st_coords.shape}")

# # Calculate dropout rate
# sc.pp.calculate_qc_metrics(scadata, percent_top=None, log1p=False, inplace=True)
# sc.pp.calculate_qc_metrics(stadata, percent_top=None, log1p=False, inplace=True)
# dp = 1 - scadata.obs['n_genes_by_counts'].median() / stadata.obs['n_genes_by_counts'].median()
# print(f"Calculated dropout rate (dp): {dp:.4f}")

# # Create slide IDs (single slide, all zeros)
# slide_ids = np.zeros(len(st_expr), dtype=np.int64)

# # ===========================
# # STEP 2: INITIALIZE MODEL
# # ===========================
# print("\n" + "="*70)
# print("STEP 2: Initializing model")
# print("="*70)

# torch.manual_seed(42)
# np.random.seed(42)

# # Get cell type info (handle different possible column names)
# cell_type_col = None
# for col in ['celltype', 'cell_type', 'CellType', 'cluster', 'leiden']:
#     if col in scadata.obs.columns:
#         cell_type_col = col
#         break

# if cell_type_col:
#     cell_types_sc = scadata.obs[cell_type_col].values
#     print(f"Using cell type column: {cell_type_col}")
#     print(f"Unique cell types: {len(np.unique(cell_types_sc))}")
# else:
#     cell_types_sc = None
#     print("No cell type information found")

# model = AdvancedHierarchicalDiffusion(
#     st_gene_expr=st_expr,
#     st_coords=st_coords,
#     sc_gene_expr=sc_expr,
#     cell_types_sc=cell_types_sc,
#     transport_plan=None,
#     D_st=None,
#     D_induced=None,
#     n_genes=len(common_genes),
#     n_embedding=[512, 256, 128],
#     coord_space_diameter=2.00,
#     sigma=3.0,
#     alpha=0.8,
#     mmdbatch=1000,
#     batch_size=512,
#     device=device,
#     lr_e=0.002,
#     lr_d=0.0002,
#     n_timesteps=400,
#     n_denoising_blocks=4,
#     hidden_dim=256,
#     num_heads=6,
#     num_hierarchical_scales=3,
#     dp=dp,
#     outf='advanced_diffusion_mousebrain'
# )

# # CRITICAL: Set slide_ids (all zeros for single slide)
# # model.slide_ids = torch.tensor(slide_ids, dtype=torch.long, device=device)
# # print(f"Model initialized with D_GEOM={model.D_GEOM}")

# # ===========================
# # STEP 3: TRAIN ENCODER
# # ===========================
# print("\n" + "="*70)
# print("STEP 3: Training encoder (cross-modal alignment)")
# print("="*70)

# model.train_encoder(n_epochs=1201)

# # ===========================
# # STEP 4: TRAIN GRAPH-VAE (COSINE-NCA)
# # ===========================
# print("\n" + "="*70)
# print("STEP 4: Training Graph-VAE with cosine-NCA")
# print("="*70)

# model.train_graph_vae(
#     epochs=2501,
#     lr=1e-3,
#     warmup_epochs=1000,
#     beta_final=1e-3,
#     num_anchors_per_epoch=10000,
#     max_pos_per_anchor=10,
#     temperature=0.5,        # Cosine temperature
#     neg_pool_size=512       # Sampled negatives
# )

# # ===========================
# # STEP 5: TRAIN DIFFUSION
# # ===========================
# print("\n" + "="*70)
# print("STEP 5: Training latent diffusion")
# print("="*70)

# model.train_diffusion_latent(n_epochs=5001, p_drop_max=0.2)

# # ===========================
# # STEP 6: GENERATE SC COORDINATES
# # ===========================
# print("\n" + "="*70)
# print("STEP 6: Generating canonicalized SC coordinates")
# print("="*70)

# sc_coords_canonical = model.sample_sc_coordinates(
#     batch_size=512,
#     guidance_scale=10.0,
#     temp_sigma=1.0,
#     return_normalized=True
# )

# # Convert to numpy
# if torch.is_tensor(sc_coords_canonical):
#     sc_coords_np = sc_coords_canonical.cpu().numpy()
# else:
#     sc_coords_np = sc_coords_canonical

# # Store results
# scadata.obsm['spatial_canonical'] = sc_coords_np

# print(f"\nGenerated coordinates shape: {sc_coords_np.shape}")
# print(f"X range: [{sc_coords_np[:, 0].min():.3f}, {sc_coords_np[:, 0].max():.3f}]")
# print(f"Y range: [{sc_coords_np[:, 1].min():.3f}, {sc_coords_np[:, 1].max():.3f}]")

# # ===========================
# # STEP 7: VISUALIZE RESULTS
# # ===========================
# print("\n" + "="*70)
# print("STEP 7: Visualizing results")
# print("="*70)

# # Plot SC predictions colored by cell type
# fig, axes = plt.subplots(1, 2, figsize=(20, 8))

# # Get cell type column for coloring
# color_col = cell_type_col if cell_type_col else None

# if color_col:
#     # SC predictions
#     sc.pl.embedding(
#         scadata, 
#         basis='spatial_canonical', 
#         color=color_col,
#         size=50, 
#         title='SC Predicted Coordinates (Canonical)',
#         palette='tab20',
#         legend_loc='right margin',
#         ax=axes[0],
#         show=False
#     )
    
#     # ST ground truth
#     sc.pl.embedding(
#         stadata,
#         basis='spatial',
#         color='celltype',
#         size=100,
#         title='ST Reference (Ground Truth)',
#         palette='tab20',
#         legend_loc='right margin',
#         ax=axes[1],
#         show=False
#     )
# else:
#     # No cell type info - just plot coordinates
#     axes[0].scatter(sc_coords_np[:, 0], sc_coords_np[:, 1], s=10, alpha=0.5)
#     axes[0].set_title('SC Predicted Coordinates')
#     axes[0].set_xlabel('X')
#     axes[0].set_ylabel('Y')
#     axes[0].axis('equal')
    
#     axes[1].scatter(st_coords[:, 0], st_coords[:, 1], s=50, alpha=0.5)
#     axes[1].set_title('ST Reference')
#     axes[1].set_xlabel('X')
#     axes[1].set_ylabel('Y')
#     axes[1].axis('equal')

# plt.tight_layout()
# plt.savefig('mousebrain_sc_predictions.png', dpi=300, bbox_inches='tight')
# plt.show()

# print("\n" + "="*70)
# print("TRAINING AND INFERENCE COMPLETE!")
# print("Results saved in:")
# print("  - scadata.obsm['spatial_canonical']")
# print("  - mousebrain_sc_predictions.png")
# print("="*70)

# # ===========================
# # STEP 8: SAVE RESULTS
# # ===========================
# # Save model checkpoint
# torch.save({
#     'model_state': model.state_dict(),
#     'transform_info': model.last_canonicalization_transform,
#     'common_genes': common_genes
# }, 'model_checkpoint_mousebrain.pt')

# # Save SC predictions
# # scadata.write_h5ad('scadata_with_predictions.h5ad')

# print("\nCheckpoint and predictions saved!")
# print("  - model_checkpoint_mousebrain.pt")
# print("  - scadata_with_predictions.h5ad")

In [None]:
import numpy as np
def _to_dense(matrix):
    """Return `matrix` as a dense NumPy array, densifying scipy-sparse input."""
    return matrix.toarray() if hasattr(matrix, 'toarray') else np.asarray(matrix)


def _estimate_dropout_rate(scadata, stadata):
    """Estimate the SC-vs-ST dropout rate from the median number of detected genes.

    Runs ``sc.pp.calculate_qc_metrics`` in place on both objects, which adds QC
    columns (including ``n_genes_by_counts``) to their ``.obs``. Returns
    ``1 - median(SC n_genes_by_counts) / median(ST n_genes_by_counts)``.
    NOTE(review): this uses each object's full gene set, not only the shared
    genes, and the result is negative if SC cells detect more genes than ST spots.

    Raises:
        ValueError: if the median ST spot has no detected genes (ratio undefined).
    """
    sc.pp.calculate_qc_metrics(scadata, percent_top=None, log1p=False, inplace=True)
    sc.pp.calculate_qc_metrics(stadata, percent_top=None, log1p=False, inplace=True)
    st_median = stadata.obs['n_genes_by_counts'].median()
    if st_median == 0:
        raise ValueError("Median ST spot has no detected genes; cannot compute dp")
    return 1 - scadata.obs['n_genes_by_counts'].median() / st_median


def _resolve_sc_cell_types(scadata, cell_type_col):
    """Return SC cell-type labels from ``scadata.obs``, or None when none are found.

    If `cell_type_col` is None, the first present column among common label names
    is used. 'celltype' is tried first (it was the only name checked before);
    'celltype_mapped_refined' is the label column this dataset's plots use.
    """
    if cell_type_col is None:
        candidates = ('celltype', 'cell_type', 'CellType',
                      'celltype_mapped_refined', 'cluster', 'leiden')
        cell_type_col = next((c for c in candidates if c in scadata.obs.columns), None)
        if cell_type_col is None:
            print("No cell type column found; training without SC cell types")
            return None
    print(f"Using cell type column: {cell_type_col}")
    return scadata.obs[cell_type_col].values


def train_advanced_diffusion_mousebrain(scadata, stadata, cell_type_col=None):
    """
    Train Advanced Hierarchical Diffusion model for mouse brain data.

    The model is fitted on the genes shared by both datasets. It then samples one
    coordinate per single cell, and the result is stored in
    ``scadata.obsm['advanced_diffusion_coords']``.

    Parameters
    ----------
    scadata : AnnData
        Single-cell data (normalised / log1p-transformed earlier in the notebook).
    stadata : AnnData
        Spatial data with spot positions in ``stadata.obsm['spatial']``.
    cell_type_col : str, optional
        Column of ``scadata.obs`` with SC cell-type labels. If None, a column is
        auto-detected from common names; if none exists, no labels are passed.

    Returns
    -------
    tuple
        ``(scadata, model)``. ``scadata`` is the same object that was passed in,
        modified in place. Its ``.obs`` also gains the QC columns written by
        ``sc.pp.calculate_qc_metrics``, and ``stadata.obs`` gains them too.

    Raises
    ------
    ValueError
        If the datasets share no genes, or the ST median gene count is zero.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Restrict both modalities to the shared genes (sorted for a stable column order).
    common_genes = sorted(set(scadata.var_names) & set(stadata.var_names))
    print(f"Common genes: {len(common_genes)}")
    if not common_genes:
        raise ValueError("scadata and stadata share no genes")

    sc_expr = _to_dense(scadata[:, common_genes].X)
    st_expr = _to_dense(stadata[:, common_genes].X)
    st_coords = stadata.obsm['spatial']

    print(f"SC expression shape: {sc_expr.shape}")
    print(f"ST expression shape: {st_expr.shape}")
    print(f"ST coords shape: {st_coords.shape}")

    dp = _estimate_dropout_rate(scadata, stadata)
    print(f"Calculated dropout rate (dp): {dp:.4f}")

    cell_types_sc = _resolve_sc_cell_types(scadata, cell_type_col)

    # Seed right before building the model, presumably so that weight
    # initialisation and training are reproducible.
    torch.manual_seed(42)
    np.random.seed(42)

    model = AdvancedHierarchicalDiffusion(
        st_gene_expr=st_expr,
        st_coords=st_coords,
        sc_gene_expr=sc_expr,
        cell_types_sc=cell_types_sc,
        transport_plan=None,
        D_st=None,
        D_induced=None,
        n_genes=len(common_genes),
        n_embedding=[512, 256, 128],
        coord_space_diameter=2.00,
        sigma=3.0,
        alpha=0.8,
        mmdbatch=1000,
        batch_size=512,
        device=device,
        lr_e=0.002,
        lr_d=0.0002,
        n_timesteps=400,
        n_denoising_blocks=4,
        hidden_dim=256,
        num_heads=6,
        num_hierarchical_scales=3,
        dp=dp,
        outf='advanced_diffusion_mousebrain'
    )

    # Train the model
    print("Training Advanced Hierarchical Diffusion model...")
    model.train(
        encoder_epochs=1201,
        vae_epochs=2501,
        diffusion_epochs=5001, p_drop_max=.2
    )

    # Optional boundary fine-tuning (disabled). Author's tuning notes:
    # model.fine_tune_decoder_boundary(
    #     epochs=10,           # 8–15 is typical
    #     batch_size=1024,
    #     lambda_hull=5.0,     # 3–8 works well; increase only if leaks remain
    #     outlier_sigma=0.8    # amount of latent perturbation for hull shaping
    # )

    # Generate SC coordinates
    print("Generating SC coordinates...")
    sc_coords = model.sample_sc_coordinates(
        batch_size=512,
        guidance_scale=10.0,
        return_normalized=True
    )
    # Alternative: model.sample_sc_coordinates_batched(batch_size=512, return_normalized=False)

    # Store results as a NumPy array; detach() avoids an error on grad-tracking tensors.
    scadata.obsm['advanced_diffusion_coords'] = (
        sc_coords.detach().cpu().numpy() if torch.is_tensor(sc_coords) else np.asarray(sc_coords)
    )

    print(f"Generated coordinates shape: {scadata.obsm['advanced_diffusion_coords'].shape}")

    return scadata, model

# Run training end to end. The function writes the sampled coordinates onto
# `scadata` in place and returns that same object along with the trained model.
scadata, model = train_advanced_diffusion_mousebrain(scadata=scadata, stadata=stadata)

In [None]:
# Re-draw single-cell coordinates from the already-trained `model`. This overwrites
# scadata.obsm['advanced_diffusion_coords'], which was already set during training.
# Seed torch first so re-running this cell yields the same draw. This assumes the
# sampler's randomness comes from torch's RNG (TODO confirm in core_models_v2).
torch.manual_seed(42)
sc_coords = model.sample_sc_coordinates(
    batch_size=512,
    guidance_scale=10.0,
    return_normalized=True
)

# Alternative sampler kept for reference:
# sc_coords = model.sample_sc_coordinates_batched(batch_size=512, return_normalized=False)

# Store results as a NumPy array; detach() avoids an error on grad-tracking tensors.
scadata.obsm['advanced_diffusion_coords'] = (
    sc_coords.detach().cpu().numpy() if torch.is_tensor(sc_coords) else np.asarray(sc_coords)
)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Set up plotting
plt.rcParams['figure.figsize'] = (6, 8)
# tab20 has only 20 colours, and seaborn cycles when asked for more. The previous
# 25-colour request therefore repeated 5 colours across cell types. Appending
# tab20b gives 40 distinct colours.
my_tab20 = (sns.color_palette("tab20", n_colors=20).as_hex()
            + sns.color_palette("tab20b", n_colors=20).as_hex())


def _plot_celltype_embedding(adata, basis, title):
    """Plot ``adata.obsm[basis]`` coloured by 'celltype_mapped_refined' on a new 8x6 figure.

    The Axes is passed explicitly with show=False so scanpy draws into this sized
    figure. Otherwise scanpy opens its own figure and a preceding plt.figure()
    would be left empty.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    sc.pl.embedding(adata, basis=basis, color='celltype_mapped_refined',
                    title=title, size=60, palette=my_tab20,
                    legend_loc='right margin', ax=ax, show=False)
    plt.show()


# Plot 1: Original SC coordinates (only if the ground-truth embedding exists)
if 'spatial' in scadata.obsm:
    _plot_celltype_embedding(scadata, 'spatial', 'Original SC Coordinates')

# Plot 2: Generated diffusion coordinates
_plot_celltype_embedding(scadata, 'advanced_diffusion_coords', 'Generated Diffusion Coordinates')

In [None]:
# Inspect the generated coordinates (the array repr is shown as the cell output).
diffusion_coords = scadata.obsm['advanced_diffusion_coords']
diffusion_coords

In [None]:
import numpy as np

# Print the x/y extent of each 2-D embedding, e.g. to compare the scale of the
# ground-truth positions with the diffusion output.
emb_names = ["spatial", "advanced_diffusion_coords"]   # adjust as needed

for name in emb_names:
    # ---- fetch the coordinates: .obsm first, then paired "<name>_0"/"<name>_1" obs columns ----
    obs_pair = [f"{name}_0", f"{name}_1"]
    if name in scadata.obsm:
        XY = np.asarray(scadata.obsm[name])   # asarray: obsm may also hold a DataFrame
    elif set(obs_pair).issubset(scadata.obs.columns):
        XY = scadata.obs[obs_pair].to_numpy()
    else:
        print(f"{name}: not found in scadata.obsm or scadata.obs, skipped\n")
        continue

    # ---- compute basic stats ----
    x_min, x_max = XY[:, 0].min(), XY[:, 0].max()
    y_min, y_max = XY[:, 1].min(), XY[:, 1].max()

    print(f"{name}:")
    print(f"    x range = ({x_min:.3f}, {x_max:.3f})")
    print(f"    y range = ({y_min:.3f}, {y_max:.3f})\n")


In [None]:
# This cell was a verbatim copy of the embedding x/y-range summary cell directly
# above it, so the duplicate was removed. Re-run that cell to print the ranges again.
