In [None]:
import os

# Process-wide environment configuration, applied before torch / numpy are
# imported in the cells below:
#   CUDA_VISIBLE_DEVICES  - which GPU index this process is allowed to see
#   OMP_NUM_THREADS       - force single-threaded OpenMP
#   OPENBLAS_NUM_THREADS  - force single-threaded OpenBLAS
os.environ.update(
    CUDA_VISIBLE_DEVICES="1",
    OMP_NUM_THREADS="1",
    OPENBLAS_NUM_THREADS="1",
)

In [None]:
# Cell 3: actually cap BLAS to 1 thread
# Import the BLAS-backed packages FIRST: threadpool_limits only acts on the
# native libraries that are already loaded in the process when it is called,
# so calling it before these imports would silently do nothing.
import numpy as np
import scipy
# … any other packages that use OpenBLAS …

from threadpoolctl import threadpool_limits

# 'blas' covers OpenBLAS, MKL, etc.
# Called as a plain function (not as a context manager) so the limit stays
# in effect for the rest of the session.
threadpool_limits(limits=1, user_api='blas')


In [None]:
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import scanpy as sc
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.neighbors import kneighbors_graph
from scipy.sparse.csgraph import dijkstra
from scipy.sparse import csr_matrix, issparse
from sklearn.preprocessing import normalize
import ot 
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt

In [None]:
def construct_graph_torch(X, k, mode='connectivity', metric = 'minkowski', p=2, device='cuda'):
    '''Build a k-nearest-neighbour graph and return it as a torch tensor.

    The neighbour search itself is done on the CPU with scikit-learn; only the
    resulting graph is placed on ``device``.

    args:
        X: input data containing features, one row per sample (torch tensor,
            may live on any device)
        k: number of neighbors for each data point
        mode: 'connectivity' (each sample is included as its own neighbour) or
            'distance' (edge weights are distances, self excluded)
        metric: 'euclidean', or the default 'minkowski' together with p=2
            (which is the same metric); anything else is rejected
        p: minkowski exponent; only consulted when metric is 'minkowski'
        device: device for the returned tensor. A CUDA device is replaced by
            'cpu' when CUDA is not available.

    Returns:
        mode='connectivity': sparse COO float tensor, shape (n_samples, n_samples)
        mode='distance':     dense float32 tensor, shape (n_samples, n_samples)

    Raises:
        AssertionError: on an unsupported ``mode`` or ``metric``.
    '''

    assert mode in ['connectivity', 'distance'], "mode must be 'connectivity' or 'distance'."
    # Minkowski with p=2 *is* the Euclidean metric, so the function's own
    # default arguments must not be rejected by the check below.
    if metric == 'minkowski' and p == 2:
        metric = 'euclidean'
    assert metric == 'euclidean', "for gpu knn, only 'euclidean' metric is currently supported in this implementation"

    # Self-neighbours are only kept for connectivity graphs.
    include_self = (mode == 'connectivity')

    # Do not try to place the result on a GPU that does not exist.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        device = 'cpu'

    # detach()+cpu() works for CPU and CUDA tensors alike and for tensors
    # that track gradients.
    X_cpu = X.detach().cpu().numpy()

    knn = NearestNeighbors(n_neighbors=k, metric=metric, algorithm='auto')
    knn.fit(X_cpu)
    # scipy sparse matrix on cpu
    knn_graph_cpu = kneighbors_graph(knn, k, mode=mode, include_self=include_self, metric=metric)

    if mode == 'connectivity':
        knn_graph_coo = knn_graph_cpu.tocoo()
        # Stack into one ndarray first: building a tensor from a Python list
        # of ndarrays is very slow.
        indices = torch.from_numpy(np.vstack((knn_graph_coo.row, knn_graph_coo.col))).long()
        values = torch.from_numpy(knn_graph_coo.data).float()
        return torch.sparse_coo_tensor(indices, values, size=knn_graph_coo.shape).to(device)

    # mode == 'distance': hand back a dense matrix on the target device
    return torch.tensor(knn_graph_cpu.toarray(), dtype=torch.float32, device=device)
    
def distances_cal_torch(graph, type_aware=None, aware_power =2, device='cuda'):
    '''
    Compute a normalised all-pairs shortest-path matrix from a graph.

    Dijkstra is run on the CPU (scipy) with the graph treated as undirected.
    Unreachable pairs are assigned the largest finite distance, and the matrix
    is divided by that distance.

    args:
        graph: knn graph as a torch tensor (sparse or dense), or anything
            accepted by scipy.sparse.csr_matrix
        type_aware: accepted for signature compatibility; not used
        aware_power: accepted for signature compatibility; not used
        device (str): device on which the returned tensor is created
    Returns:
        (C_dis, scale):
            C_dis - float32 tensor of distances divided by ``scale``
            scale - Python float, the largest finite shortest-path distance,
                    or 1.0 when that distance is zero (e.g. no edges at all)
    '''

    if isinstance(graph, torch.Tensor):
        dense_graph = graph.to_dense() if graph.is_sparse else graph
        graph_cpu_csr = csr_matrix(dense_graph.detach().cpu().numpy())
    else:
        graph_cpu_csr = csr_matrix(graph) #assume scipy sparse matrix if not torch tensor

    shortestPath_cpu = dijkstra(csgraph = graph_cpu_csr, directed=False, return_predecessors=False) #dijkstra on cpu
    shortestPath = torch.tensor(shortestPath_cpu, dtype=torch.float32, device=device)

    # Entries that are inf (unreachable) or NaN take the largest finite distance.
    finite = torch.isfinite(shortestPath)
    the_max = float(shortestPath[finite].max()) if bool(finite.any()) else 0.0
    shortestPath[~finite] = the_max

    # A zero maximum would turn the division into 0/0 = NaN; divide by 1 instead.
    scale = the_max if the_max > 0 else 1.0
    C_dis = shortestPath / scale
    return C_dis, scale

def calculate_D_sc_torch(X_sc, k_neighbors=10, graph_mode='connectivity', device='cpu'):
    '''Compute the normalised graph-distance matrix for single-cell data.

    Builds a euclidean kNN graph over the rows of ``X_sc`` with
    construct_graph_torch and converts it to shortest-path distances with
    distances_cal_torch.

    args:
        X_sc: torch tensor of features, one row per cell
            (presumably (n_cells, n_features) - confirm against caller)
        k_neighbors: number of neighbours for the kNN graph
        graph_mode: 'connectivity' or 'distance', forwarded to construct_graph_torch
        device: 'cuda' is honoured only when CUDA is available; every other
            value results in 'cpu'

    returns:
        (D_sc, sc_max_distance): the normalised distance matrix as a torch
        tensor and the scale value reported by distances_cal_torch

    raises:
        TypeError: if X_sc is not a torch tensor
    '''

    if not isinstance(X_sc, torch.Tensor):
        raise TypeError('Input X_sc must be a pytorch tensor')

    if not (device == 'cuda' and torch.cuda.is_available()):
        device = 'cpu'

    print(f'using device: {device}')
    print('constructing knn graph...')
    # Features are used as-is (no L2 normalisation is applied here).
    # .to() converts device and dtype in one step; torch.tensor(<tensor>)
    # would copy-construct and emit a UserWarning.
    X_features = X_sc.detach().to(device=device, dtype=torch.float32)

    Xgraph = construct_graph_torch(X_features, k=k_neighbors, mode=graph_mode, metric='euclidean', device=device)

    print('calculating distances from graph....')
    D_sc, sc_max_distance = distances_cal_torch(Xgraph, device=device)

    print('D_sc calculation complete')

    return D_sc, sc_max_distance


In [None]:
from sklearn.neighbors import kneighbors_graph, NearestNeighbors
from scipy.sparse.csgraph import dijkstra
from scipy.sparse import csr_matrix, issparse
from sklearn.preprocessing import normalize
import ot

def construct_graph_spatial(location_array, k, mode='distance', metric='euclidean', p=2):
    '''Build a kNN graph over spot coordinates with scikit-learn.

    args:
        location_array: spatial coordinates of the spots, one row per spot
        k: number of neighbors for each spot
        mode: 'connectivity' or 'distance'
        metric: distance metric forwarded to kneighbors_graph
        p: minkowski exponent forwarded to kneighbors_graph

    returns:
        the sparse matrix produced by sklearn.neighbors.kneighbors_graph
    '''
    assert mode in ('connectivity', 'distance'), "mode must be 'connectivity' or 'distance'"

    # A spot is listed as its own neighbour only in connectivity graphs.
    keep_self = (mode == 'connectivity')

    return kneighbors_graph(
        location_array,
        k,
        mode=mode,
        metric=metric,
        include_self=keep_self,
        p=p,
    )

def distances_cal_spatial(graph, spot_ids=None, spot_types=None, aware_power=2):
    '''Compute a normalised shortest-path distance matrix from a kNN graph.

    args:
        graph: kNN graph; anything accepted by scipy.sparse.csr_matrix
        spot_ids (list, optional): ids of the spots, in the row/column order of
            ``graph``. Required for the type-aware adjustment.
        spot_types (optional): spot type per spot. A pandas Series is aligned
            to ``spot_ids`` by label; any other sequence is taken positionally.
            Required for the type-aware adjustment.
        aware_power: factor by which the distance between two spots of
            different type is multiplied (a multiplier, not an exponent)

    returns:
        (C_dis, scale):
            C_dis - np.ndarray of distances divided by ``scale``
            scale - float, the largest finite (adjusted) distance, or 1.0 when
                    that distance is zero

    raises:
        ValueError: if the number of spot types does not match the graph size
    '''
    shortestPath = dijkstra(csgraph = csr_matrix(graph), directed=False, return_predecessors=False)
    # NOTE: np.nan_to_num(..., nan=np.inf) must not be used here - it also
    # rewrites +inf to the largest finite float, which hides unreachable pairs
    # from the masking below and wrecks the normalisation.
    shortestPath = np.asarray(shortestPath, dtype=float)

    if spot_types is not None and spot_ids is not None:
        # Integer code per spot type; missing types get -1.
        type_codes, _ = pd.factorize(pd.Series(spot_types, index=spot_ids))
        if len(type_codes) != shortestPath.shape[0]:
            raise ValueError('spot_ids/spot_types length does not match the graph size')
        # Two spots share a type only if both codes are known and equal.
        same_type = (type_codes[:, None] == type_codes[None, :]) & (type_codes[:, None] >= 0)
        # Penalise paths between spots of different type.
        shortestPath[~same_type] *= aware_power

    # Entries that are inf (unreachable) or NaN take the largest finite distance.
    finite = np.isfinite(shortestPath)
    the_max = float(shortestPath[finite].max()) if finite.any() else 0.0
    shortestPath[~finite] = the_max

    # A zero maximum would turn the division into 0/0 = NaN; divide by 1 instead.
    scale = the_max if the_max > 0 else 1.0
    C_dis = shortestPath / scale

    return C_dis, scale

def calculate_D_st_from_coords(spatial_coords, X_st=None, k_neighbors=10, graph_mode='distance', aware_st=False, 
                               spot_types=None, aware_power_st=2, spot_ids=None):
    '''Calculate the spatial distance matrix D_st for spatial transcriptomics
    data directly from spot coordinates, optionally with a type-aware adjustment.

    args:
        spatial_coords: spot coordinates as a pandas DataFrame or numpy array,
            one row per spot
        X_st: ST gene expression data (not used for the D_st calculation itself)
        k_neighbors: number of neighbors for the kNN graph
        graph_mode: 'connectivity' or 'distance' for the kNN graph
        aware_st: whether to apply the type-aware distance adjustment
        spot_types: spot types, required when aware_st=True; a non-Series
            sequence is wrapped in a pandas Series indexed by spot_ids
        aware_power_st: multiplier passed to distances_cal_spatial as aware_power
        spot_ids: list or index of spot ids; defaults to the DataFrame index,
            or 0..n-1 for a numpy array

    returns:
        (D_st, st_max_distance) as returned by distances_cal_spatial

    raises:
        TypeError: if spatial_coords is neither a DataFrame nor a numpy array
        ValueError: if aware_st=True but spot_types is missing
    '''

    if isinstance(spatial_coords, pd.DataFrame):
        location_array = spatial_coords.values
        if spot_ids is None:
            spot_ids = spatial_coords.index.tolist() #use index of dataframe if available
    elif isinstance(spatial_coords, np.ndarray):
        location_array = spatial_coords
        if spot_ids is None:
            spot_ids = list(range(location_array.shape[0])) #generate default ids if not provided

    else:
        raise TypeError('spatial_coords must be a pandas dataframe or a numpy array')
    
    print(f'constructing {graph_mode} graph for ST data with k={k_neighbors}.....')
    Xgraph_st = construct_graph_spatial(location_array, k=k_neighbors, mode=graph_mode)
    
    if aware_st:
        if spot_types is None or spot_ids is None:
            raise ValueError('spot_types and spot_ids must be provided when aware_st=True')
        if not isinstance(spot_types, pd.Series):
            # keyword is `index`; the previous misspelling raised a TypeError
            spot_types = pd.Series(spot_types, index=spot_ids)
        print('applying type aware distance adjustment for ST data')
        print(f'aware power for ST: {aware_power_st}')
    else:
        # Passing None switches the type-aware branch off downstream.
        spot_types = None 

    print(f'calculating spatial distances.....')
    D_st, st_max_distance = distances_cal_spatial(Xgraph_st, spot_ids=spot_ids, spot_types=spot_types, aware_power=aware_power_st)

    print('D_st calculation complete')
    return D_st, st_max_distance


def calculate_D_st_euclidean(spatial_coords):
    """
    Pairwise Euclidean distances between ST spots, scaled by the largest one.

    Args:
        spatial_coords: spot coordinates, one row per spot; a pandas
            DataFrame, a numpy array, or anything np.array can convert

    Returns:
        float32 square matrix of pairwise distances divided by their maximum.
        The division is skipped when the maximum is zero.
    """
    from scipy.spatial.distance import pdist, squareform

    # Bring the input to a plain ndarray.
    if isinstance(spatial_coords, pd.DataFrame):
        points = spatial_coords.values
    else:
        points = spatial_coords if isinstance(spatial_coords, np.ndarray) else np.array(spatial_coords)

    # Condensed distances -> full symmetric matrix.
    pairwise = squareform(pdist(points, metric='euclidean'))

    # Scale into [0, 1]; guard against all points coinciding.
    largest = pairwise.max()
    scaled = pairwise / largest if largest > 0 else pairwise

    return scaled.astype(np.float32)

# patient 2 data load

In [None]:
def load_and_process_cscc_data(data_dir='/home/ehtesamul/sc_st/data/cSCC/processed'):
    """
    Load and process the cSCC dataset with multiple ST replicates.

    Reads one single-cell file and three ST replicate files from ``data_dir``,
    applies scanpy's normalize_total and log1p to each, and adds a
    'rough_celltype' column to the single-cell obs table.

    Args:
        data_dir: directory containing scP2.h5ad, stP2.h5ad, stP2rep2.h5ad and
            stP2rep3.h5ad. Defaults to the location the notebook was written for.

    Returns:
        (scadata, stadata1, stadata2, stadata3)
    """
    import os

    print("Loading cSCC data...")

    # Load SC data
    scadata = sc.read_h5ad(os.path.join(data_dir, 'scP2.h5ad'))

    # Load all 3 ST datasets
    stadata1 = sc.read_h5ad(os.path.join(data_dir, 'stP2.h5ad'))
    stadata2 = sc.read_h5ad(os.path.join(data_dir, 'stP2rep2.h5ad'))
    stadata3 = sc.read_h5ad(os.path.join(data_dir, 'stP2rep3.h5ad'))

    # Normalize and log transform (both calls modify the object in place)
    for adata in [scadata, stadata1, stadata2, stadata3]:
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)

    # Rough cell types start as the level-1 label; selected level-1 labels
    # are then merged or renamed.
    level1_to_rough = {
        'CLEC9A': 'DC',
        'CD1C': 'DC',
        'ASDC': 'DC',
        'PDC': 'PDC',
        'MDSC': 'DC',
        'LC': 'DC',
        'Mac': 'Myeloid cell',
        'Tcell': 'T cell',
    }
    scadata.obs['rough_celltype'] = scadata.obs['level1_celltype'].astype(str)
    for level1_label, rough_label in level1_to_rough.items():
        scadata.obs.loc[scadata.obs['level1_celltype'] == level1_label, 'rough_celltype'] = rough_label

    # Level-2 labels are applied last so they override the level-1 result.
    scadata.obs.loc[scadata.obs['level2_celltype'] == 'TSK', 'rough_celltype'] = 'TSK'
    tumor_kc_labels = ['Tumor_KC_Basal', 'Tumor_KC_Diff', 'Tumor_KC_Cyc']
    scadata.obs.loc[scadata.obs['level2_celltype'].isin(tumor_kc_labels), 'rough_celltype'] = 'NonTSK'

    return scadata, stadata1, stadata2, stadata3

def prepare_combined_st_for_diffusion(stadata1, stadata2, stadata3, scadata):
    """
    Stack the three ST replicates into one training set, restricted to the
    genes shared with the single-cell data.

    Expression matrices are densified and standardised with StandardScaler.
    The scaler is fitted once on the stacked ST matrix and then fitted again,
    separately, on the SC matrix.

    Returns:
        X_sc:            float32 tensor of standardised SC expression
        X_st_combined:   float32 tensor of standardised, stacked ST expression
        Y_st_combined:   float32 ndarray of stacked ST coordinates
        dataset_labels:  list with 'dataset1'/'dataset2'/'dataset3' per ST row
        common_genes:    sorted list of the shared gene names
        st_coords_list:  the three per-replicate coordinate arrays
    """
    print("Preparing combined ST data for diffusion training...")

    st_adatas = (stadata1, stadata2, stadata3)

    # Genes present in the SC data and in every ST replicate.
    shared = set(scadata.var_names)
    for st_adata in st_adatas:
        shared = shared & set(st_adata.var_names)
    common_genes = sorted(shared)
    print(f"Common genes across all datasets: {len(common_genes)}")

    def _dense_expression(adata):
        # Gene-aligned expression; sparse matrices are converted to dense.
        expr = adata[:, common_genes].X
        return expr.toarray() if hasattr(expr, 'toarray') else expr

    sc_expr = _dense_expression(scadata)
    st_exprs = [_dense_expression(st_adata) for st_adata in st_adatas]

    # Per-replicate coordinates, kept separate for block-diagonal graphs.
    st_coords_list = [st_adata.obsm['spatial'] for st_adata in st_adatas]

    from sklearn.preprocessing import StandardScaler
    scaler = StandardScaler()

    # ST: stack, then standardise.
    st_expr_combined = scaler.fit_transform(np.vstack(st_exprs))
    st_coords_combined = np.vstack(st_coords_list)

    # SC: the same scaler object is re-fitted on the SC matrix.
    sc_expr = scaler.fit_transform(sc_expr)

    # One label per ST row, recording which replicate it came from.
    dataset_labels = []
    for label, expr in zip(('dataset1', 'dataset2', 'dataset3'), st_exprs):
        dataset_labels += [label] * len(expr)

    print(f"Combined ST data shape: {st_expr_combined.shape}")
    print(f"Combined ST coords shape: {st_coords_combined.shape}")
    print(f"SC data shape: {sc_expr.shape}")

    X_sc = torch.tensor(sc_expr, dtype=torch.float32)
    X_st_combined = torch.tensor(st_expr_combined, dtype=torch.float32)
    Y_st_combined = st_coords_combined.astype(np.float32)

    return X_sc, X_st_combined, Y_st_combined, dataset_labels, common_genes, st_coords_list

# Read the SC object and the three ST replicates from disk.
scadata, stadata1, stadata2, stadata3 = load_and_process_cscc_data()

# Build the gene-aligned tensors / arrays used for diffusion training.
(X_sc,
 X_st_combined,
 Y_st_combined,
 dataset_labels,
 common_genes,
 st_coords_list) = prepare_combined_st_for_diffusion(stadata1, stadata2, stadata3, scadata)

# Short summary of what was prepared.
print("Data preparation complete!")
print(
    f"SC cells: {X_sc.shape[0]}",
    f"Combined ST spots: {X_st_combined.shape[0]}",
    f"Common genes: {len(common_genes)}",
    sep="\n",
)



# diffusion model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import numpy as np
from tqdm import tqdm
import os
import time
import scipy
from sklearn.neighbors import NearestNeighbors
from scipy.spatial import cKDTree
from typing import Optional, Dict, Tuple, List
import torch_geometric
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data

# =====================================================
# PART X: Graph-VAE Components
# =====================================================

class GraphVAEEncoder(nn.Module):
    """
    Node-level graph encoder: two GCN layers followed by two linear heads that
    give a mean and a log-variance vector for every node.

    NOTE: do **not** touch `train_encoder`; its aligned embeddings are the sole
    conditioning signal throughout.
    """
    def __init__(self, input_dim, hidden_dim=128, latent_dim=32):
        super().__init__()
        self.input_dim, self.hidden_dim, self.latent_dim = input_dim, hidden_dim, latent_dim

        # Message passing: input_dim -> hidden_dim -> hidden_dim
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

        # Per-node heads (no graph-level pooling anywhere in this module)
        self.mu_head = nn.Linear(hidden_dim, latent_dim)
        self.logvar_head = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x, edge_index, edge_weight=None, batch=None):
        """
        Args:
            x: node features, shape (n_nodes, input_dim)
            edge_index: graph edges passed to both GCN layers
            edge_weight: optional edge weights passed to both GCN layers
            batch: accepted but unused (outputs are per node)

        Returns:
            (mu, logvar), each of shape (n_nodes, latent_dim)
        """
        hidden = self.conv1(x, edge_index, edge_weight).relu()
        hidden = self.conv2(hidden, edge_index, edge_weight).relu()
        return self.mu_head(hidden), self.logvar_head(hidden)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * sigma with eps ~ N(0, I), element-wise."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma


class GraphVAEDecoder(nn.Module):
    """
    MLP decoder mapping a latent vector to a 2D coordinate.

    Only the latent z is consumed; no features or conditioning are passed in,
    so that the geometry has to be carried by z.
    """
    def __init__(self, latent_dim=32, hidden_dim=128):
        super().__init__()
        self.latent_dim = latent_dim

        # latent_dim -> hidden_dim -> hidden_dim -> 2
        stack = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        ]
        self.decoder = nn.Sequential(*stack)

    def forward(self, z):
        """Decode latent vectors (batch_size, latent_dim) to (batch_size, 2)."""
        return self.decoder(z)

def precompute_knn_edges(coords, k=30, device='cuda'):
    """
    Precompute K-NN edges in torch-geometric format.

    Args:
        coords: point coordinates, one row per point; torch tensor or anything
            sklearn's kneighbors_graph accepts
        k: number of neighbours per point (self excluded)
        device: device for the returned tensors. A CUDA device is replaced by
            'cpu' when CUDA is not available.

    Returns:
        edge_index:  long tensor of edges from from_scipy_sparse_matrix
        edge_weight: float32 tensor with one weight per edge
    """
    if isinstance(coords, torch.Tensor):
        # detach() so tensors that track gradients can be converted as well
        coords_np = coords.detach().cpu().numpy()
    else:
        coords_np = coords

    from sklearn.neighbors import kneighbors_graph

    # Build the connectivity KNN graph directly with scikit-learn.
    knn_graph = kneighbors_graph(
        coords_np, 
        n_neighbors=k, 
        mode='connectivity', 
        include_self=False
    )

    # Convert to torch-geometric format
    from torch_geometric.utils import from_scipy_sparse_matrix
    edge_index, edge_weight = from_scipy_sparse_matrix(knn_graph)

    # Do not try to move the result to a GPU that does not exist.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        device = 'cpu'

    # Ensure the dtypes graph layers expect
    edge_index = edge_index.long().to(device)      # Edge indices should be long
    edge_weight = edge_weight.float().to(device)   # Edge weights should be float32

    return edge_index, edge_weight

class LatentDenoiser(nn.Module):
    """
    Residual-MLP noise predictor operating in latent space.

    The noisy latent, the timestep embedding and the condition embedding are
    each projected to ``hidden_dim``, summed, refined by ``n_blocks`` residual
    blocks and projected back to ``latent_dim``.

    NOTE: do **not** touch `train_encoder`; its aligned embeddings are the sole
    conditioning signal throughout.
    """
    def __init__(self, latent_dim=32, condition_dim=128, hidden_dim=256, n_blocks=6):
        super().__init__()
        self.latent_dim = latent_dim
        self.condition_dim = condition_dim

        def two_layer(in_dim, out_dim):
            # Linear(in_dim, hidden) -> ReLU -> Linear(hidden, out_dim)
            return nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, out_dim),
            )

        # Timestep -> sinusoidal features -> small MLP
        self.time_embed = nn.Sequential(
            SinusoidalEmbedding(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Projections of the noisy latent and of the aligned-embedding condition
        self.latent_encoder = two_layer(latent_dim, hidden_dim)
        self.condition_encoder = two_layer(condition_dim, hidden_dim)

        # Residual refinement blocks (the skip connection is added in forward)
        residual_blocks = []
        for _ in range(n_blocks):
            residual_blocks.append(
                nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 2),
                    nn.ReLU(),
                    nn.Dropout(0.1),
                    nn.Linear(hidden_dim * 2, hidden_dim),
                    nn.LayerNorm(hidden_dim),
                )
            )
        self.denoising_blocks = nn.ModuleList(residual_blocks)

        # Back to latent space
        self.output_head = two_layer(hidden_dim, latent_dim)

    def forward(self, z_noisy, t, condition):
        """
        Args:
            z_noisy: noisy latent vectors (batch_size, latent_dim)
            t: timesteps, (batch_size,) or (batch_size, 1)
            condition: aligned embeddings E(X) (batch_size, condition_dim)

        Returns:
            predicted noise produced by the output head
        """
        n_rows = z_noisy.size(0)

        # Drop singleton axes from inputs that arrive with more than 2 dims.
        if z_noisy.dim() > 2:
            z_noisy = z_noisy.squeeze()
        if condition.dim() > 2:
            condition = condition.squeeze()

        # Timesteps are fed to the embedding as a column vector.
        if t.dim() == 1:
            t = t[:, None]
        t = t.view(n_rows, 1)

        # Project the three inputs and fuse them by summation.
        hidden = self.latent_encoder(z_noisy) + self.time_embed(t) + self.condition_encoder(condition)

        # Residual refinement.
        for block in self.denoising_blocks:
            hidden = hidden + block(hidden)

        return self.output_head(hidden)

# =====================================================
# PART 1: Advanced Network Components
# =====================================================

class FeatureNet(nn.Module):
    """Three-layer MLP encoder for gene-expression vectors.

    Layout: Linear -> LayerNorm -> ReLU -> Linear -> LayerNorm -> ReLU -> Linear,
    with layer widths given by the first three entries of ``n_embedding``.
    Dropout (rate ``dp``) is applied to the *input*, and only when the caller
    passes ``isdp=True``.
    """
    def __init__(self, n_genes, n_embedding=[512, 256, 128], dp=0):
        super().__init__()
        width1, width2, width3 = n_embedding[0], n_embedding[1], n_embedding[2]

        self.fc1 = nn.Linear(n_genes, width1)
        self.bn1 = nn.LayerNorm(width1)   # named bn*, but these are LayerNorms
        self.fc2 = nn.Linear(width1, width2)
        self.bn2 = nn.LayerNorm(width2)
        self.fc3 = nn.Linear(width2, width3)

        self.dp = nn.Dropout(dp)

    def forward(self, x, isdp=False):
        """Encode ``x``; set ``isdp=True`` to apply input dropout first."""
        hidden = self.dp(x) if isdp else x
        hidden = self.bn1(self.fc1(hidden)).relu()
        hidden = self.bn2(self.fc2(hidden)).relu()
        return self.fc3(hidden)

class SinusoidalEmbedding(nn.Module):
    """Fixed sinusoidal embedding of a scalar (e.g. a diffusion timestep).

    The first ``dim // 2`` output columns are sines and the next ``dim // 2``
    are cosines, over geometrically spaced frequencies from 1 down to 1/10000.
    For an odd ``dim`` a final zero column is appended so the output always
    has exactly ``dim`` columns.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """
        x: (batch_size, 1) or (batch_size,)
        Returns: (batch_size, dim)
        """
        if x.dim() == 1:
            x = x.unsqueeze(1)  # Make (batch_size, 1)

        device = x.device
        half_dim = self.dim // 2
        # max(..., 1): with half_dim == 1 (dim 2 or 3) the plain divisor would
        # be 0, producing inf and then NaN embeddings.
        log_step = torch.log(torch.tensor(10000.0)) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -log_step)
        angles = x * freqs.unsqueeze(0)  # (batch_size, half_dim)
        emb = torch.cat([angles.sin(), angles.cos()], dim=-1)  # (batch_size, 2 * half_dim)

        if emb.size(1) < self.dim:
            # Odd dim: 2 * half_dim == dim - 1, so pad the missing column
            # with zeros (slicing cannot lengthen the tensor).
            emb = F.pad(emb, (0, self.dim - emb.size(1)))

        return emb

import torch.optim as optim   
from geomloss import SamplesLoss

# OT refinement function
def refine_with_ot(sc_coords, st_coords, n_steps=50, lr=1e-2):
    """
    Nudge SC coordinates towards the ST point cloud by gradient descent on a
    Sinkhorn divergence.

    sc_coords: Tensor (N,2) initial SC coordinates (not modified; a detached
        copy is optimised)
    st_coords: Tensor (M,2) ST spot coordinates
    n_steps: number of Adam steps
    lr: Adam learning rate

    Returns the optimised coordinates as a detached tensor.
    """
    divergence = SamplesLoss("sinkhorn", p=2, blur=0.05, scaling=0.9)

    # Optimise a private copy so the caller's tensor is left untouched.
    refined = sc_coords.clone().detach().requires_grad_(True)
    adam = optim.Adam([refined], lr=lr)

    for _step in range(n_steps):
        adam.zero_grad()
        # Both clouds get a leading batch axis of size 1.
        ot_loss = divergence(refined.unsqueeze(0), st_coords.unsqueeze(0))
        ot_loss.backward()
        adam.step()

    return refined.detach()
    
class OTGuidedSampler:
    """Turns an OT transport plan into per-cell target coordinates.

    T_opt holds one row of transport weights per SC cell over the ST spots;
    st_coords_norm holds the ST spot coordinates. T_opt may be None, in which
    case no guidance is produced.
    """
    def __init__(self,
                 T_opt: Optional[torch.Tensor],
                 st_coords_norm: torch.Tensor,
                 n_timesteps: int):
        self.T_opt       = T_opt           # (n_sc, n_st)
        self.st_coords   = st_coords_norm  # (n_st, 2)
        self.n_timesteps = n_timesteps

    def get_ot_guidance(self, sc_indices: List[int], t: int) -> Optional[torch.Tensor]:
        """
        Returns the expected OT-based target for each sc index,
        plus a small decaying jitter.

        Args:
            sc_indices: row indices into T_opt
            t: current timestep; the jitter scale is 0.02 * t / n_timesteps

        Returns:
            None when no transport plan is set; otherwise a tensor with one
            2D target per index (shape (0, 2) for an empty index list).
            Cells whose transport weights do not sum to a positive value get
            a zero target without jitter.
        """
        if self.T_opt is None:
            return None

        # torch.stack cannot handle an empty list; return an empty batch.
        if len(sc_indices) == 0:
            return torch.zeros(0, 2, device=self.st_coords.device)

        # decaying noise: maximal when t≈T, zero at t=0
        noise_scale = 0.02 * (t / self.n_timesteps)

        guidance = []
        for sc_idx in sc_indices:
            st_w = self.T_opt[sc_idx]                 # (n_st,)
            total = st_w.sum()
            if total <= 0:
                guidance.append(torch.zeros(2, device=self.st_coords.device))
                continue
            # expected spot location (no argmax sampling)
            w_norm = st_w / total                    # normalize
            target_mean = (w_norm.unsqueeze(1) * self.st_coords).sum(dim=0)
            jitter = torch.randn_like(target_mean) * noise_scale
            guidance.append(target_mean + jitter)

        return torch.stack(guidance)  # (batch_size, 2)
    
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.neighbors import NearestNeighbors
import numpy as np

class CellTypeEmbedding(nn.Module):
    """Trainable lookup table with one embedding vector per cell type."""
    def __init__(self, num_cell_types, embedding_dim):
        super().__init__()
        # One row of size embedding_dim for each of the num_cell_types ids.
        self.embedding = nn.Embedding(num_cell_types, embedding_dim)

    def forward(self, cell_type_indices):
        """Map integer cell-type ids to their embedding vectors."""
        table = self.embedding
        return table(cell_type_indices)

class UncertaintyHead(nn.Module):
    """MLP head that outputs a strictly positive uncertainty for x and y."""
    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        # input_dim -> hidden_dim -> hidden_dim -> 2 (one value per axis)
        stages = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return softplus(net(x)) shifted up by 0.01, so outputs stay positive."""
        raw = self.net(x)
        return 0.01 + F.softplus(raw)

class PhysicsInformedLayer(nn.Module):
    """Incorporates cell non-overlap constraints.

    A small MLP predicts a radius for every cell from its features; cells whose
    radii overlap are pushed apart along the line joining them. The overall
    strength of the push is a learnable scalar.
    """
    def __init__(self, feature_dim):
        super().__init__()
        # Softplus keeps predicted radii positive.
        self.radius_predictor = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Softplus()
        )
        self.repulsion_strength = nn.Parameter(torch.tensor(0.1))
        
    def compute_repulsion_gradient(self, coords, radii, cell_types=None):
        """Compute repulsion forces between cells.

        Args:
            coords: (B, 2) cell coordinates
            radii: one radius per cell, shape (B,) or (B, 1)
            cell_types: optional (B,) tensor of type ids; pairs of the same
                type are repelled 50% more strongly

        Returns:
            (B, 2) tensor, the summed repulsion force acting on each cell
        """
        batch_size = coords.shape[0]
        
        # Compute pairwise distances
        distances = torch.cdist(coords, coords, p=2)
        
        # Pairwise sum of radii, entry (i, j) = r_i + r_j. The radii are
        # flattened and broadcast explicitly: `radii + radii.T` on a 1-D
        # tensor is just 2 * radii and gives asymmetric, wrong overlaps.
        radii = radii.reshape(-1)
        radii_sum = radii.unsqueeze(1) + radii.unsqueeze(0)
        
        # Compute overlap (positive when cells overlap)
        overlap = F.relu(radii_sum - distances + 1e-6)
        
        # Mask out self-interactions
        mask = (1 - torch.eye(batch_size, device=coords.device))
        overlap = overlap * mask
        
        # Unit vectors pointing from cell j to cell i
        coord_diff = coords.unsqueeze(1) - coords.unsqueeze(0)  # (B, B, 2)
        distances_safe = distances + 1e-6  # Avoid division by zero
        directions = coord_diff / distances_safe.unsqueeze(-1)
        
        # Apply stronger repulsion for same cell types (optional)
        if cell_types is not None:
            same_type_mask = (cell_types.unsqueeze(1) == cell_types.unsqueeze(0)).float()
            repulsion_weight = 1.0 + 0.5 * same_type_mask  # 50% stronger for same type
        else:
            repulsion_weight = torch.ones(batch_size, batch_size, device=coords.device)
            
        # Force magnitude per pair (self-pairs are already zeroed via `overlap`)
        repulsion_magnitude = (overlap * repulsion_weight).unsqueeze(-1)
        
        # Sum repulsion forces from all other cells
        repulsion_forces = (repulsion_magnitude * directions).sum(dim=1)
        
        return repulsion_forces
        
    def forward(self, coords, features, cell_types=None):
        """Return (scaled repulsion forces, predicted radii) for a batch of cells."""
        # Predict cell radii based on features
        radii = self.radius_predictor(features).squeeze(-1) * 0.01  # Scale to reasonable size
        
        # Compute repulsion gradient
        repulsion_grad = self.compute_repulsion_gradient(coords, radii, cell_types)
        
        return repulsion_grad * self.repulsion_strength, radii
    
class SpatialBatchSampler:
    """Sample spatially contiguous batches for geometric attention.

    A kd-tree nearest-neighbour index is fitted once over `coordinates`; each
    call to `sample_spatial_batch` picks a random centre point and returns the
    indices of its nearest neighbours.
    """

    def __init__(self, coordinates, batch_size, k_neighbors=None):
        """
        coordinates: (N, 2) array of spatial coordinates
        batch_size: size of each batch
        k_neighbors: number of neighbors to precompute (default: batch_size).
            Values larger than the number of points are clamped to N, because
            the neighbour query cannot return more neighbours than were fitted.
        """
        self.coordinates = coordinates
        self.batch_size = batch_size

        # A falsy k_neighbors (None or 0) falls back to batch_size, as before.
        requested = k_neighbors or batch_size
        self.k_neighbors = min(requested, len(coordinates))

        # Precompute nearest neighbors
        self.nbrs = NearestNeighbors(
            n_neighbors=self.k_neighbors,
            algorithm='kd_tree'
        ).fit(coordinates)

    def sample_spatial_batch(self):
        """Sample a spatially contiguous batch.

        Returns:
            1-D `torch.long` tensor with at most `batch_size` indices (fewer
            when `k_neighbors` is smaller than `batch_size`).
        """
        # Pick random center point (uses NumPy's global RNG)
        center_idx = np.random.randint(len(self.coordinates))

        # Only the neighbour indices are needed; distances are discarded.
        _, indices = self.nbrs.kneighbors(
            self.coordinates[center_idx:center_idx+1],
            return_distance=True
        )

        return torch.tensor(indices.flatten()[:self.batch_size], dtype=torch.long)

# =====================================================
# PART 2: Hierarchical Diffusion Architecture
# =====================================================

class HierarchicalDiffusionBlock(nn.Module):
    """Coarse-to-fine block whose two predictions are blended by timestep.

    A coarse MLP and a fine MLP (the latter fed `x` concatenated with a
    coarse signal) each produce a `dim`-wide output. A small timestep-driven
    network emits `num_scales` softmax weights used to blend them.
    """
    def __init__(self, dim, num_scales=3):
        super().__init__()
        self.num_scales = num_scales
        wide = dim * 2

        # Coarse branch: dim -> 2*dim -> dim
        self.coarse_net = nn.Sequential(
            nn.Linear(dim, wide),
            nn.ReLU(),
            nn.Linear(wide, dim),
        )

        # Fine branch: consumes [x, coarse signal], hence the 2*dim input
        self.fine_net = nn.Sequential(
            nn.Linear(wide, wide),
            nn.ReLU(),
            nn.Linear(wide, dim),
        )

        # Maps a scalar timestep to softmax weights over `num_scales` scales
        self.scale_mixer = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, num_scales),
            nn.Softmax(dim=-1),
        )

    def forward(self, x, t, coarse_context=None):
        """Blend coarse and fine predictions for `x` at timestep `t`.

        Args:
            x: input whose last axis has size `dim`.
            t: per-sample timestep; a trailing axis is added before it is fed
                to the mixer (assumes a 1-D floating tensor - TODO confirm).
            coarse_context: optional conditioning for the fine branch; when
                omitted the block's own coarse prediction is used instead.

        Returns:
            Weighted sum of the coarse and fine predictions.
        """
        scale_weights = self.scale_mixer(t.unsqueeze(-1))
        coarse_pred = self.coarse_net(x)

        conditioning = coarse_pred if coarse_context is None else coarse_context
        fine_pred = self.fine_net(torch.cat([x, conditioning], dim=-1))

        # NOTE(review): only the first two softmax weights are used, so for
        # num_scales > 2 the applied weights do not sum to 1.
        coarse_weight = scale_weights[:, 0:1]
        fine_weight = scale_weights[:, 1:2]
        return coarse_weight * coarse_pred + fine_weight * fine_pred


# =====================================================
# PART 3: Main Advanced Diffusion Model
# =====================================================

class AdvancedHierarchicalDiffusion(nn.Module):
    def __init__(
        self,
        st_gene_expr,
        st_coords,
        sc_gene_expr,
        cell_types_sc=None,  # Cell type labels for SC data
        transport_plan=None,  # Optimal transport plan from domain alignment
        D_st=None,
        D_induced=None,
        n_genes=None,
        # n_embedding=128,
        n_embedding=[512, 256, 128],
        coord_space_diameter=200,
        st_max_distance=None,
        sc_max_distance=None,
        sigma=3.0,
        alpha=0.9,
        mmdbatch=0.1,
        batch_size=64,
        device='cuda',
        lr_e=0.0001,
        lr_d=0.0002,
        n_timesteps=1000,
        n_denoising_blocks=6,
        hidden_dim=512,
        num_heads=8,
        num_hierarchical_scales=3,
        dp=0.1,
        outf='output'
    ):
        """Store the data tensors and build every sub-network and optimizer.

        Args:
            st_gene_expr: ST expression matrix; converted to a float32 tensor.
                Its second dimension is used as `n_genes` when that is not given.
            st_coords: ST spatial coordinates; converted to a float32 tensor and
                also normalised via `normalize_coordinates_isotropic`.
            sc_gene_expr: SC expression matrix; converted to a float32 tensor.
            cell_types_sc: optional per-cell labels for the SC data. Unique
                labels are mapped to integer indices.
            transport_plan: optional matrix stored as a float32 tensor; when
                given, a learnable `ot_guidance_strength` parameter is created.
            D_st, D_induced: optional distance matrices. When `D_st` is None it
                is computed from `st_coords` with `calculate_D_st_from_coords`.
            n_genes: input width of the feature encoder (defaults to the
                column count of `st_gene_expr`).
            n_embedding: encoder layer widths; the last entry is the embedding
                size used by the downstream modules.
            sigma, alpha, mmdbatch, batch_size, n_timesteps: stored on `self`.
            device: device on which tensors and modules are placed.
            lr_e, lr_d: passed to `setup_optimizers`.
            n_denoising_blocks, hidden_dim, num_hierarchical_scales, dp:
                architecture hyper-parameters for the modules built here.
            outf: output directory; created if missing. `train.log` lives in it.

        NOTE(review): `coord_space_diameter`, `sc_max_distance` and `num_heads`
        are accepted but not referenced in this constructor, and
        `self.st_max_distance` is only assigned when `D_st` is None.
        """
        super().__init__()

        # Copy so the shared default list (or a caller's list) is never the
        # object stored on the model or handed to the encoder builder.
        n_embedding = list(n_embedding)

        self.diffusion_losses = {
            'total': [],
            'diffusion': [],
            'struct': [],
            'physics': [],
            'uncertainty': [],
            'epochs': []
        }

        # Loss tracking for Graph-VAE training
        self.vae_losses = {
            'total': [],
            'reconstruction': [],
            'kl': [],
            'epochs': []
        }

        # Loss tracking for Latent Diffusion training
        self.latent_diffusion_losses = {
            'total': [],
            'diffusion': [],
            'struct': [],
            'epochs': []
        }

        # Keep encoder losses separate (if you want to track them)
        self.encoder_losses = {
            'total': [],
            'pred': [],
            'circle': [],
            'mmd': [],
            'epochs': []
        }

        self.device = device
        self.batch_size = batch_size
        self.n_timesteps = n_timesteps
        self.sigma = sigma
        self.alpha = alpha
        self.mmdbatch = mmdbatch
        self.n_embedding = n_embedding

        # Create output directory; exist_ok avoids the check-then-create race.
        self.outf = outf
        os.makedirs(outf, exist_ok=True)

        # Store data
        self.st_gene_expr = torch.tensor(st_gene_expr, dtype=torch.float32).to(device)
        self.st_coords = torch.tensor(st_coords, dtype=torch.float32).to(device)
        self.sc_gene_expr = torch.tensor(sc_gene_expr, dtype=torch.float32).to(device)

        # Temperature regularization for geometric attention
        self.temp_weight_decay = 1e-4

        # Store transport plan if provided
        self.transport_plan = torch.tensor(transport_plan, dtype=torch.float32).to(device) if transport_plan is not None else None

        # Process cell types
        if cell_types_sc is not None:
            # Convert cell type strings to indices
            unique_cell_types = np.unique(cell_types_sc)
            self.cell_type_to_idx = {ct: i for i, ct in enumerate(unique_cell_types)}
            self.num_cell_types = len(unique_cell_types)
            cell_type_indices = [self.cell_type_to_idx[ct] for ct in cell_types_sc]
            self.sc_cell_types = torch.tensor(cell_type_indices, dtype=torch.long).to(device)
        else:
            self.sc_cell_types = None
            self.num_cell_types = 0

        # Store distance matrices
        self.D_st = torch.tensor(D_st, dtype=torch.float32).to(device) if D_st is not None else None
        self.D_induced = torch.tensor(D_induced, dtype=torch.float32).to(device) if D_induced is not None else None

        # If D_st is not provided, calculate it from spatial coordinates
        if self.D_st is None:
            print("D_st not provided, calculating from spatial coordinates...")
            if isinstance(st_coords, torch.Tensor):
                st_coords_np = st_coords.cpu().numpy()
            else:
                st_coords_np = st_coords

            D_st_np, st_max_distance = calculate_D_st_from_coords(
                spatial_coords=st_coords_np,
                k_neighbors=50,
                graph_mode="distance"
            )
            self.D_st = torch.tensor(D_st_np, dtype=torch.float32).to(device)
            self.st_max_distance = st_max_distance
            print(f"D_st calculated, shape: {self.D_st.shape}")

        print(f"Final matrices - D_st: {self.D_st.shape if self.D_st is not None else None}, "
            f"D_induced: {self.D_induced.shape if self.D_induced is not None else None}")

        # Normalize coordinates
        self.st_coords_norm, self.coords_center, self.coords_radius = self.normalize_coordinates_isotropic(self.st_coords)
        # Model parameters
        self.n_genes = n_genes or st_gene_expr.shape[1]

        # ========== FEATURE ENCODER ==========
        self.netE = self.build_feature_encoder(self.n_genes, n_embedding, dp)

        self.train_log = os.path.join(outf, 'train.log')

        # ========== CELL TYPE EMBEDDING ==========

        use_cell_types = (cell_types_sc is not None)  # Check if SC data has cell types
        self.use_cell_types = use_cell_types

        # With cell types, the projected feature vector is the embedding plus a
        # half-width cell-type embedding.
        if self.num_cell_types > 0:
            self.cell_type_embedding = CellTypeEmbedding(self.num_cell_types, n_embedding[-1] // 2)
            total_feature_dim = n_embedding[-1] + n_embedding[-1] // 2
        else:
            self.cell_type_embedding = None
            total_feature_dim = n_embedding[-1]

        # ========== HIERARCHICAL DIFFUSION COMPONENTS ==========
        # Time embedding
        self.time_embed = nn.Sequential(
            SinusoidalEmbedding(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Coordinate encoder
        self.coord_encoder = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Feature projection (includes cell type if available)
        self.feat_proj = nn.Sequential(
            nn.Linear(total_feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # ========== GRAPH-VAE COMPONENTS (REPLACING HIERARCHICAL DIFFUSION) ==========
        # Graph-VAE parameters
        self.latent_dim = 32  # As specified in instructions

        # Graph-VAE Encoder (learns latent representations from ST graphs)
        self.graph_vae_encoder = GraphVAEEncoder(
            input_dim=n_embedding[-1],  # Aligned embedding dimension
            hidden_dim=128,             # GraphConv hidden dimension
            latent_dim=self.latent_dim
        ).to(device)

        self.graph_vae_decoder = GraphVAEDecoder(
            latent_dim=self.latent_dim,
            hidden_dim=128  # Remove condition_dim
        ).to(device)

        # Latent Denoiser (replaces hierarchical_blocks)
        self.latent_denoiser = LatentDenoiser(
            latent_dim=self.latent_dim,
            condition_dim=n_embedding[-1],
            hidden_dim=hidden_dim,
            n_blocks=n_denoising_blocks
        ).to(device)

        # ========== HIERARCHICAL DENOISING BLOCKS ==========
        self.hierarchical_blocks = nn.ModuleList([
            HierarchicalDiffusionBlock(hidden_dim, num_hierarchical_scales)
            for _ in range(n_denoising_blocks)
        ])

        # ========== PHYSICS-INFORMED COMPONENTS ==========
        self.physics_layer = PhysicsInformedLayer(hidden_dim)

        # ========== UNCERTAINTY QUANTIFICATION ==========
        self.uncertainty_head = UncertaintyHead(hidden_dim)

        # ========== OPTIMAL TRANSPORT GUIDANCE ==========
        if self.transport_plan is not None:
            self.ot_guidance_strength = nn.Parameter(torch.tensor(0.1))

        # ========== OUTPUT LAYERS ==========
        self.noise_predictor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2)
        )

        # Create noise schedule
        # self.noise_schedule = self.create_noise_schedule()
        self.noise_schedule = self.build_noise_schedule(self.n_timesteps)

        self.guidance_scale = 2.0

        # Optimizers (must come after every module referenced there exists)
        self.setup_optimizers(lr_e, lr_d)

        # MMD Loss for domain alignment
        self.mmd_loss = MMDLoss()

        # Move entire model to device
        self.to(self.device)

    def build_noise_schedule(self, T, beta_start=1e-4, beta_end=2e-2):
        """Build a linear-beta noise schedule with `T` steps.

        Args:
            T: number of diffusion timesteps.
            beta_start: first value of the linearly spaced betas.
            beta_end: last value of the linearly spaced betas.

        Returns:
            Dict of 1-D tensors of length `T` on `self.device` with keys
            `betas`, `alphas`, `alphas_cumprod`, `alphas_cumprod_prev`,
            `posterior_variance`, `sqrt_alphas_cumprod` and
            `sqrt_one_minus_alphas_cumprod`. The two square-root entries are
            the ones `add_noise` indexes, so this schedule can be passed to it
            directly.
        """
        betas = torch.linspace(beta_start, beta_end, T, device=self.device)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Cumulative product shifted right by one step, with 1.0 for step 0.
        alphas_cumprod_prev = torch.cat([torch.tensor([1.0], device=self.device), alphas_cumprod[:-1]], dim=0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

        return {
            'betas': betas,
            'alphas': alphas,
            'alphas_cumprod': alphas_cumprod,
            'alphas_cumprod_prev': alphas_cumprod_prev,
            'posterior_variance': posterior_variance,
            'sqrt_alphas_cumprod': torch.sqrt(alphas_cumprod),
            'sqrt_one_minus_alphas_cumprod': torch.sqrt(1.0 - alphas_cumprod),
        }

    def update_noise_schedule(self):
        """Rebuild `self.noise_schedule` from the current `self.n_timesteps`.

        Prints the new step count together with the first and last entries of
        the cumulative alpha product.
        """
        rebuilt = self.build_noise_schedule(self.n_timesteps)
        self.noise_schedule = rebuilt
        print(f"Updated noise schedule for T={self.n_timesteps}")
        print(f"alpha_bar[0]={self.noise_schedule['alphas_cumprod'][0]:.6f}, alpha_bar[-1]={self.noise_schedule['alphas_cumprod'][-1]:.6f}")

    def setup_spatial_sampling(self):
        """Create `self.spatial_sampler` from the normalised ST coordinates.

        When the model has no `st_coords_norm` attribute the sampler is set to
        None instead.
        """
        if not hasattr(self, 'st_coords_norm'):
            self.spatial_sampler = None
            return

        # The sampler works on a CPU NumPy copy of the coordinates.
        self.spatial_sampler = SpatialBatchSampler(
            coordinates=self.st_coords_norm.cpu().numpy(),
            batch_size=self.batch_size,
        )

    def get_spatial_batch(self):
        """Get a batch of ST indices for training.

        Uses `self.spatial_sampler` when one has been set up; otherwise falls
        back to a random permutation of the ST spots truncated to
        `self.batch_size`.

        Returns:
            1-D `torch.long` tensor of indices.
        """
        # The sampler attribute only exists after setup_spatial_sampling() has
        # run, so read it defensively to keep the random fallback reachable.
        sampler = getattr(self, 'spatial_sampler', None)
        if sampler is not None:
            return sampler.sample_spatial_batch()

        # Fallback to random sampling
        return torch.randperm(len(self.st_coords_norm))[:self.batch_size]
        
    def _evaluate_sigma_quality(self, st_embeddings, k=10):
        """Evaluate how well encoder embeddings preserve spatial k-NN structure"""
        with torch.no_grad():
            # Get k-NN from encoder similarity
            netpred = st_embeddings.mm(st_embeddings.t())
            pred_knn = self._get_knn_from_similarity(netpred, k=k)
            
            # Get k-NN from physical coordinates  
            phys_knn = self._get_knn_from_coords(self.st_coords_norm, k=k)
            
            # Compute overlap
            overlap = (pred_knn == phys_knn).float().mean().item()
            return overlap

    def _get_knn_from_similarity(self, similarity_matrix, k=10):
        """Extract top-k neighbors from similarity matrix"""
        # Get top-k indices for each node
        _, topk_indices = torch.topk(similarity_matrix, k=k+1, dim=1)  # +1 to exclude self
        topk_indices = topk_indices[:, 1:]  # Remove self-connections
        return topk_indices

    def _get_knn_from_coords(self, coords, k=10):
        """Extract top-k spatial neighbors from coordinates"""
        # Compute pairwise distances
        distances = torch.cdist(coords, coords)
        # Get top-k closest (smallest distances)
        _, topk_indices = torch.topk(distances, k=k+1, dim=1, largest=False)  # +1 for self
        topk_indices = topk_indices[:, 1:]  # Remove self-connections  
        return topk_indices
        
    def normalize_coordinates_isotropic(self, coords):
        """Centre the coordinates and scale them by one shared factor.

        The points are shifted so their mean is at the origin, then divided by
        the largest distance from that mean (plus 1e-8), so every normalised
        point has norm at most about 1.

        Returns:
            Tuple `(normalized_coords, center, max_dist)`.
        """
        centroid = coords.mean(dim=0)
        offsets = coords - centroid
        radius = torch.max(torch.norm(offsets, dim=1))
        # The epsilon guards against division by zero when all points coincide.
        return offsets / (radius + 1e-8), centroid, radius
        

    def build_feature_encoder(self, n_genes, n_embedding, dp):
        """Instantiate the `FeatureNet` encoder and place it on `self.device`.

        Args:
            n_genes: first positional argument of `FeatureNet`.
            n_embedding: forwarded as `FeatureNet`'s `n_embedding` keyword.
            dp: forwarded as `FeatureNet`'s `dp` keyword.
        """
        encoder = FeatureNet(n_genes, n_embedding=n_embedding, dp=dp)
        return encoder.to(self.device)
        
    def create_noise_schedule(self):
        """Build a linear-beta diffusion schedule over `self.n_timesteps` steps.

        Returns:
            Dict of 1-D tensors on `self.device` with keys `betas`, `alphas`,
            `alphas_cumprod`, `sqrt_alphas_cumprod` and
            `sqrt_one_minus_alphas_cumprod`.
        """
        betas = torch.linspace(0.0001, 0.02, self.n_timesteps, device=self.device)
        alphas = 1 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        schedule = {
            'betas': betas,
            'alphas': alphas,
            'alphas_cumprod': alphas_cumprod,
        }
        # Square roots looked up when noising samples.
        schedule['sqrt_alphas_cumprod'] = torch.sqrt(alphas_cumprod)
        schedule['sqrt_one_minus_alphas_cumprod'] = torch.sqrt(1 - alphas_cumprod)
        return schedule
        
    def setup_optimizers(self, lr_e, lr_d):
        """Create the encoder and diffusion optimizers with their schedulers.

        Args:
            lr_e: accepted for the encoder learning rate.
                NOTE(review): currently unused - the encoder optimizer is
                built with a fixed lr of 0.002; confirm whether that is intended.
            lr_d: learning rate of the diffusion optimizer.
        """
        # Encoder: AdamW with a step decay (halved every 200 scheduler steps).
        self.optimizer_E = torch.optim.AdamW(self.netE.parameters(), lr=0.002)
        self.scheduler_E = lr_scheduler.StepLR(self.optimizer_E, step_size=200, gamma=0.5)

        # MMD Loss
        self.mmd_fn = MMDLoss()

        # Modules trained by the diffusion optimizer, in registration order.
        diffusion_modules = [
            self.time_embed,
            self.coord_encoder,
            self.feat_proj,
            self.hierarchical_blocks,
            # self.geometric_attention_blocks,
            self.physics_layer,
            self.uncertainty_head,
            self.noise_predictor,
        ]
        if self.cell_type_embedding is not None:
            diffusion_modules.append(self.cell_type_embedding)

        diff_params = [
            param
            for module in diffusion_modules
            for param in module.parameters()
        ]

        # The OT guidance strength is a bare parameter, not a module.
        if self.transport_plan is not None:
            diff_params.append(self.ot_guidance_strength)

        self.optimizer_diff = torch.optim.Adam(diff_params, lr=lr_d, betas=(0.9, 0.999))
        self.scheduler_diff = lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer_diff, T_0=500)
        
    def add_noise(self, coords, t, noise_schedule):
        """Noise `coords` to timestep `t` using the given schedule.

        Args:
            coords: tensor to be noised.
            t: indices into the schedule tensors, one per row of `coords`.
            noise_schedule: dict providing `sqrt_alphas_cumprod` and
                `sqrt_one_minus_alphas_cumprod`.

        Returns:
            Tuple `(noisy_coords, noise)` where `noise` is the standard normal
            sample that was mixed in.
        """
        noise = torch.randn_like(coords)

        # Per-sample scale factors, reshaped to column vectors for broadcasting.
        signal_scale = noise_schedule['sqrt_alphas_cumprod'][t].view(-1, 1)
        noise_scale = noise_schedule['sqrt_one_minus_alphas_cumprod'][t].view(-1, 1)

        return signal_scale * coords + noise_scale * noise, noise
        
        
    def forward_diffusion(self, noisy_coords, t, features, cell_types=None):
        """Run the denoising network on noisy coordinates.

        Args:
            noisy_coords: noised coordinates, one row per sample.
            t: per-sample timestep; it is embedded, passed to every
                hierarchical block and used as `(1 - t)` to weight the physics
                term (presumably a float in [0, 1] - TODO confirm).
            features: per-sample feature vectors.
            cell_types: optional cell-type indices. When the model has a
                cell-type embedding but no labels are given, zeros of the
                embedding's width are appended instead.

        Returns:
            Tuple `(noise_pred, uncertainty, cell_radii)`.
        """
        # Encode inputs
        time_emb = self.time_embed(t)
        coord_emb = self.coord_encoder(noisy_coords)

        # Assemble the feature vector expected by feat_proj.
        embedder = self.cell_type_embedding
        if embedder is None:
            combined_features = features
        elif cell_types is not None:
            combined_features = torch.cat([features, embedder(cell_types)], dim=-1)
        else:
            # No labels: pad with zeros so the input width still matches.
            n_rows = noisy_coords.shape[0]
            pad_width = self.n_embedding[-1] // 2
            zero_padding = torch.zeros(n_rows, pad_width, device=features.device)
            combined_features = torch.cat([features, zero_padding], dim=-1)

        feat_emb = self.feat_proj(combined_features)

        # Combine embeddings
        h = coord_emb + time_emb + feat_emb

        # Process through the hierarchical blocks
        for block in self.hierarchical_blocks:
            h = block(h, t)

        # Predict noise
        noise_pred = self.noise_predictor(h)

        # Compute physics-informed correction
        physics_correction, cell_radii = self.physics_layer(noisy_coords, h, cell_types)

        # Compute uncertainty
        uncertainty = self.uncertainty_head(h)

        # Apply corrections based on timestep (less physics at high noise)
        t_factor = (1 - t).unsqueeze(-1)  # shape: (batch_size, 1)
        noise_pred = noise_pred + t_factor * physics_correction * 0.1

        return noise_pred, uncertainty, cell_radii
        
    def train_encoder(self, n_epochs=1000, ratio_start=0, ratio_end=1.0):
        """Train the STEM encoder to align ST and SC data.

        Each epoch encodes all ST spots and a random SC batch, then minimises
        a cross-entropy between the ST embedding similarity and a Gaussian
        spatial kernel, plus `alpha` times an MMD term and `ratio` times a
        cycle-consistency ("circle") term. Progress is printed and appended to
        `self.train_log` every 200 epochs, a checkpoint is written every 500
        epochs, and the final weights go to `final_encoder.pt` in `self.outf`.
        After training, a diagnostic sweep over multiples of `self.sigma`
        prints how well the embedding neighbours match kernel neighbours.

        Args:
            n_epochs: number of optimisation steps.
            ratio_start: circle-loss weight at epoch 0.
            ratio_end: circle-loss weight reached after 80% of the epochs.

        Notes:
            `self.mmdbatch` is resolved to a sample count: a value strictly
            between 0 and 1 is treated as a fraction of the SC cells, anything
            else is truncated to an int (minimum 1).
            NOTE(review): the fraction reading is an assumption made so the
            constructor default of 0.1 no longer crashes `torch.randint`;
            confirm the intended meaning.
        """
        # Imported locally: a bare `import scipy` does not guarantee that the
        # `scipy.spatial` submodule is loaded.
        from scipy.spatial.distance import cdist

        def pairwise_st_distances():
            """Euclidean distances between all ST spots as a float32 tensor."""
            coords_np = self.st_coords.cpu().numpy()
            return torch.tensor(cdist(coords_np, coords_np), device=self.device).to(torch.float32)

        def spatial_targets(distances, sigma):
            """Row-normalised Gaussian kernel for `sigma` (identity if 0)."""
            if sigma == 0:
                return torch.eye(self.st_coords.shape[0], device=self.device)
            kernel = torch.exp(-distances**2/(2*sigma**2))/(np.sqrt(2*np.pi)*sigma)
            return F.normalize(kernel, p=1, dim=1)

        def knn_set_overlap(knn_a, knn_b):
            """Mean fraction of `knn_a` neighbours also present in `knn_b`."""
            shared = (knn_a.unsqueeze(2) == knn_b.unsqueeze(1)).any(dim=2)
            return shared.float().mean().item()

        print("Training STEM encoder...")

        # Log training start
        with open(self.train_log, 'a') as f:
            localtime = time.asctime(time.localtime(time.time()))
            f.write(f"{localtime} - Starting STEM encoder training\n")
            f.write(f"n_epochs={n_epochs}, ratio_start={ratio_start}, ratio_end={ratio_end}\n")

        # Calculate spatial adjacency matrix (the distance matrix is released
        # afterwards so it does not occupy memory during training).
        nettrue = spatial_targets(
            None if self.sigma == 0 else pairwise_st_distances(),
            self.sigma,
        )

        # Resolve mmdbatch to a usable positive integer sample count.
        n_sc_cells = self.sc_gene_expr.shape[0]
        mmd_count = self.mmdbatch
        if 0 < mmd_count < 1:
            mmd_count = mmd_count * n_sc_cells
        mmd_count = max(1, int(mmd_count))
        sc_batch_count = min(self.batch_size, mmd_count)

        # The evaluation below leaves the encoder in eval mode, so make sure a
        # repeated call trains in training mode again.
        self.netE.train()

        # Training loop
        for epoch in range(n_epochs):
            # Schedule for circle loss weight
            ratio = ratio_start + (ratio_end - ratio_start) * min(epoch / (n_epochs * 0.8), 1.0)

            # Forward pass ST data
            e_seq_st = self.netE(self.st_gene_expr, True)

            # Sample from SC data due to large size
            sc_idx = torch.randint(0, n_sc_cells, (sc_batch_count,), device=self.device)
            sc_batch = self.sc_gene_expr[sc_idx]
            e_seq_sc = self.netE(sc_batch, False)

            # Calculate losses
            self.optimizer_E.zero_grad()

            # Prediction loss (equivalent to netpred in STEM)
            netpred = e_seq_st.mm(e_seq_st.t())
            loss_E_pred = F.cross_entropy(netpred, nettrue, reduction='mean')

            # Mapping matrices
            st2sc = F.softmax(e_seq_st.mm(e_seq_sc.t()), dim=1)
            sc2st = F.softmax(e_seq_sc.mm(e_seq_st.t()), dim=1)

            # Circle loss: ST -> SC -> ST round trip should match the kernel
            st2st = torch.log(st2sc.mm(sc2st) + 1e-7)
            loss_E_circle = F.kl_div(st2st, nettrue, reduction='none').sum(1).mean()

            # MMD loss
            ranidx = torch.randint(
                0, e_seq_sc.shape[0],
                (min(mmd_count, e_seq_sc.shape[0]),),
                device=self.device,
            )
            loss_E_mmd = self.mmd_fn(e_seq_st, e_seq_sc[ranidx])

            # Total loss
            loss_E = loss_E_pred + self.alpha * loss_E_mmd + ratio * loss_E_circle

            # Backward and optimize
            loss_E.backward()
            self.optimizer_E.step()
            self.scheduler_E.step()

            # Log progress
            if epoch % 200 == 0:
                log_msg = (f"Encoder epoch {epoch}/{n_epochs}, "
                          f"Loss_E: {loss_E.item():.6f}, "
                          f"Loss_E_pred: {loss_E_pred.item():.6f}, "
                          f"Loss_E_circle: {loss_E_circle.item():.6f}, "
                          f"Loss_E_mmd: {loss_E_mmd.item():.6f}, "
                          f"Ratio: {ratio:.4f}")

                print(log_msg)
                with open(self.train_log, 'a') as f:
                    f.write(log_msg + '\n')

            # Save checkpoint every 500 epochs. This check is independent of
            # the logging interval; nesting it under the 200-epoch check would
            # only fire at multiples of 1000.
            if epoch % 500 == 0:
                torch.save({
                    'epoch': epoch,
                    'netE_state_dict': self.netE.state_dict(),
                    'optimizer_state_dict': self.optimizer_E.state_dict(),
                    'scheduler_state_dict': self.scheduler_E.state_dict(),
                }, os.path.join(self.outf, f'encoder_checkpoint_epoch_{epoch}.pt'))

        print("\n" + "="*50)
        print("EVALUATING SIGMA QUALITY")
        print("="*50)

        # Evaluate current sigma
        with torch.no_grad():
            self.netE.eval()
            st_embeddings = self.netE(self.st_gene_expr, True)  # Get final ST embeddings
            current_overlap = self._evaluate_sigma_quality(st_embeddings, k=10)
            print(f"Current sigma ({self.sigma:.4f}) -> kNN overlap = {current_overlap:.3f}")

        # Test different sigma values to find optimal
        print("\nTesting different sigma values...")
        sigma_multipliers = (0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 2.5, 3.0, 4.0)
        baseline_idx = sigma_multipliers.index(1.0)
        sigma_candidates = [self.sigma * m for m in sigma_multipliers]

        # Distances and the encoder-side neighbours do not depend on the
        # candidate sigma, so compute them once instead of once per candidate.
        sweep_distances = None if self.sigma == 0 else pairwise_st_distances()
        with torch.no_grad():
            netpred = st_embeddings.mm(st_embeddings.t())
            pred_knn = self._get_knn_from_similarity(netpred, k=15)

        # NOTE(review): the Gaussian kernel is monotone in distance and the row
        # normalisation is a positive per-row scale, so each row's neighbour
        # ranking is the same for every sigma > 0 except where small sigmas
        # underflow to ties. This sweep therefore mostly detects underflow.
        overlaps = []
        for test_sigma in sigma_candidates:
            test_nettrue = spatial_targets(sweep_distances, test_sigma)
            with torch.no_grad():
                true_knn = self._get_knn_from_similarity(test_nettrue, k=15)
            # Compare neighbour sets, not rank positions.
            overlap = knn_set_overlap(pred_knn, true_knn)
            overlaps.append(overlap)

            print(f"  sigma = {test_sigma:.4f} -> overlap = {overlap:.5f}")

        # Find best sigma; on ties keep the current sigma so no change is
        # recommended without an actual gain.
        best_idx = int(np.argmax(overlaps))
        baseline_overlap = overlaps[baseline_idx]
        if baseline_overlap >= overlaps[best_idx]:
            best_idx = baseline_idx
        best_sigma = sigma_candidates[best_idx]
        best_overlap = overlaps[best_idx]

        print(f"\nBest sigma: {best_sigma:.4f} (overlap = {best_overlap:.5f})")
        if best_idx != baseline_idx:
            print(f"⚠️  Consider using sigma = {best_sigma:.4f} instead of {self.sigma:.4f}")
            # Both numbers come from the same k=15 kernel-based metric.
            print(f"   Improvement: {best_overlap:.3f} vs {baseline_overlap:.3f} (+{(best_overlap-baseline_overlap)*100:.1f}%)")
        else:
            print("✅ Current sigma is optimal!")

        print("="*50)
        # ===================================

        # Save final encoder
        torch.save({
            'netE_state_dict': self.netE.state_dict(),
        }, os.path.join(self.outf, 'final_encoder.pt'))

        print("Encoder training complete!")

    def train_graph_vae(self, epochs=800, lr=1e-3, warmup_epochs=320,  # 40% of epochs
                    lambda_cov_max=0.3, angle_loss_weight=0.2, 
                    radius_loss_weight=1.0, angle_warmup_epochs=0, 
                    beta_final=1e-3):  # Small β_final to prevent blow-up
        """
        Train the Graph-VAE with:
        1) normalized KL (no free-bits, mean over batch and dims)
        2) covariance warm-up (fix axes)  
        3) geometry-anchored angle/radius loss (fix gauge deterministically)
        4) NO reconstruction loss, NO gradient loss (as requested)
        """
        print("Training Graph-VAE with normalized KL and polar losses...")
        
        # Freeze encoder
        self.netE.eval()
        for p in self.netE.parameters():
            p.requires_grad = False

        # Build ST graph
        adj_idx, adj_w = precompute_knn_edges(self.st_coords_norm, k=30, device=self.device)

        # Precompute aligned features (ST only for training)
        with torch.no_grad():
            st_features_aligned = self.netE(self.st_gene_expr).float()

        # Precompute canonical frame and angle/radius targets for ST
        print('Computing canonical angular frame for geometry anchoring....')
        c, a_theta, R, theta_true, r_true = self._compute_canonical_frame(self.st_coords_norm)

        # Precompute true covariance of ST spots
        with torch.no_grad():
            centered = self.st_coords_norm - self.st_coords_norm.mean(0, keepdim=True)
            self.cov_true = (centered.T @ centered) / (centered.shape[0] - 1)

        # Optimizer + scheduler
        vae_params = list(self.graph_vae_encoder.parameters()) + list(self.graph_vae_decoder.parameters())
        optimizer = torch.optim.Adam(vae_params, lr=lr, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

        # Initialize loss logs
        for key in ('total','reconstruction','kl','kl_raw','cov','lambda_cov','angle','radius','beta','epochs'):
            self.vae_losses.setdefault(key, [])

        # Training loop
        for epoch in range(epochs):
            optimizer.zero_grad()

            # Compute loss weights with warmups
            lambda_cov = lambda_cov_max * min(epoch+1, warmup_epochs) / warmup_epochs
            w_theta = 0.0 if epoch < angle_warmup_epochs else angle_loss_weight
            w_radius = 0.0 if epoch < angle_warmup_epochs else radius_loss_weight
            
            # KL warm-up: β from 0 to β_final (small value)
            beta = beta_final * min(epoch+1, warmup_epochs) / warmup_epochs

            # Forward pass for ST spots only
            mu_st, logvar_st = self.graph_vae_encoder(st_features_aligned, adj_idx, adj_w)
            z_st = self.graph_vae_encoder.reparameterize(mu_st, logvar_st)
            
            # Decoder takes ONLY z (no features)
            coords_pred_st = self.graph_vae_decoder(z_st)
            
            # 1) NO Reconstruction loss (as requested)
            L_recon = torch.tensor(0.0, device=self.device)

            # 2) Normalized KL divergence (Option A - no free-bits)
            # KL per dim (positive values)
            kl_per_dim = -0.5 * (1 + logvar_st - mu_st.pow(2) - logvar_st.exp())  # [N, latent_dim]
            
            # Mean over batch AND latent dims (this prevents blow-up)
            KL_raw = kl_per_dim.mean()  # O(1) scale, not O(latent_dim)
            L_KL = beta * KL_raw

            # 3) Covariance alignment (ST spots only, warm-up)
            coords_pred_centered = coords_pred_st - coords_pred_st.mean(0, keepdim=True)
            cov_pred = (coords_pred_centered.T @ coords_pred_centered) / (coords_pred_centered.shape[0] - 1)
            L_cov = F.mse_loss(cov_pred, self.cov_true)

            # 4) NO Gradient anchoring (as requested)
            # L_grad = 0.0

            # 5) Geometry-anchored angle/radius losses
            # Compute predicted angles/radii using the SAME canonical frame (c, a_theta, R)
            v_hat = coords_pred_st - c 
            cross = a_theta[0] * v_hat[:, 1] - a_theta[1] * v_hat[:, 0]
            dot = a_theta[0] * v_hat[:, 0] + a_theta[1] * v_hat[:, 1]
            theta_hat = torch.atan2(cross, dot)
            r_hat = v_hat.norm(dim=1) / (R + 1e-8)

            # Circular angle loss: 1 - cos(delta)
            delta = theta_hat - theta_true
            L_angle_raw = (1.0 - torch.cos(delta)).mean()
            L_angle = w_theta * L_angle_raw

            # Radius loss
            L_radius_raw = F.mse_loss(r_hat, r_true)
            L_radius = w_radius * L_radius_raw

            # 6) Total loss
            total_loss = (
                L_recon +           # 0.0 as requested
                L_KL +              # Small, normalized
                lambda_cov * L_cov +
                L_angle + 
                L_radius
            )
            
            total_loss.backward()
            optimizer.step()
            scheduler.step()

            # Log losses
            self.vae_losses['total'].append(total_loss.item())
            self.vae_losses['reconstruction'].append(L_recon.item())
            self.vae_losses['kl'].append(L_KL.item())
            self.vae_losses['kl_raw'].append(KL_raw.item())
            self.vae_losses['cov'].append(L_cov.item())
            self.vae_losses['lambda_cov'].append(lambda_cov)
            self.vae_losses['angle'].append(L_angle.item())
            self.vae_losses['radius'].append(L_radius.item())
            self.vae_losses['beta'].append(beta)
            self.vae_losses['epochs'].append(epoch)

            if epoch % 100 == 0 or epoch == epochs-1:
                # Compute interpretable angle error in degrees
                with torch.no_grad():
                    mean_angle_err_deg = torch.mean(torch.abs(torch.rad2deg(torch.atan2(torch.sin(delta), torch.cos(delta))))).item()
                
                print(f"Epoch {epoch+1}/{epochs}  "
                    f"Loss={total_loss:.4f}  "
                    f"L_recon={L_recon:.4f}  "
                    f"L_KL={L_KL:.6f}(β={beta:.6f})  "
                    f"KL_raw={KL_raw:.4f}  "
                    f"L_cov={L_cov:.4f}  "
                    f"L_angle={L_angle:.4f}  "
                    f"L_radius={L_radius:.4f}  "
                    f"AngleErr={mean_angle_err_deg:.2f}°")

        print("Graph-VAE training complete.")
        
        # Diagnostic: Check if decoder uses z meaningfully
        print("\n=== DECODER DEPENDENCY DIAGNOSTIC ===")
        with torch.no_grad():
            # Fix features, vary z
            mu_fixed = mu_st[:5]
            logvar_fixed = logvar_st[:5]
            
            coords_samples = []
            for _ in range(5):
                z_sample = self.graph_vae_encoder.reparameterize(mu_fixed, logvar_fixed)
                coords_sample = self.graph_vae_decoder(z_sample)
                coords_samples.append(coords_sample)
            
            coords_stack = torch.stack(coords_samples)  # (5, 5, 2)
            coord_variance = coords_stack.var(dim=0).mean().item()  # Average variance across samples
            
            print(f"Coordinate variance from z sampling: {coord_variance:.6f}")
            if coord_variance < 1e-4:
                print("⚠️  WARNING: Low variance suggests potential posterior collapse!")
            else:
                print("✅ Good: Decoder shows dependency on z")
        print("=" * 50)


    def _compute_canonical_frame(self, X_st):
        '''compute canonical angular frame from ST coordinates for geometry anchoring
        
        returns:
            c: centroid (2, )
            a_theta: reference direction vector (2, )
            R: max radius (scalar)
            theta_true: true_angles (N, )
            r_true: true normalized radii (N,)
        '''
        #centroid
        c = X_st.mean(dim=0)

        #find farthest point to define reference direction
        d = torch.linalg.norm(X_st - c, dim=1)
        A = torch.argmax(d).item()
        a_theta = (X_st[A] - c)
        R = d.max().clamp_min(1e-8)

        #compute true angles and radii for all points
        v = X_st - c
        cross = a_theta[0] * v[:, 1] - a_theta[1] * v[:, 0]
        dot = a_theta[0] * v[:, 0] + a_theta[1] * v[:, 1]
        theta_true = torch.atan2(cross, dot)
        r_true = (v.norm(dim=1) / R)

        print(f"Canonical frame: center=({c[0]:.3f}, {c[1]:.3f}), "
            f"ref_dir=({a_theta[0]:.3f}, {a_theta[1]:.3f}), max_radius={R:.3f}")
        
        return c, a_theta, R, theta_true, r_true
    
    def _compute_spatial_pc1(self):
        """
        Compute first spatial principal component from continuous SVGs for anchoring
        """
        import numpy as np
        from scipy.ndimage import gaussian_filter
        from sklearn.decomposition import TruncatedSVD
        
        # Step 1: Filter to continuous genes (5%-95% expression)
        st_expr = self.st_gene_expr.cpu().numpy()
        nonzero_frac = (st_expr > 0).mean(0)
        mask = (nonzero_frac >= 0.05) & (nonzero_frac <= 0.95)
        expr_cont = st_expr[:, mask]
        
        print(f"Filtered to {expr_cont.shape[1]} continuous genes from {st_expr.shape[1]} total")
        
        if expr_cont.shape[1] < 10:
            print("Warning: Too few continuous genes, using all genes")
            expr_cont = st_expr
        
        # Step 2: Smooth expression and compute PCA
        expr_smooth = gaussian_filter(expr_cont, sigma=(1, 0))  # smooth over spots only
        
        # Compute first PC
        svd = TruncatedSVD(n_components=1, random_state=42)
        pc1 = svd.fit_transform(expr_smooth).flatten()
        
        print(f"PC-1 explains {svd.explained_variance_ratio_[0]:.3f} of spatial variance")
        
        return torch.tensor(pc1, device=self.device, dtype=torch.float32)

    def train_diffusion_latent(self, n_epochs=400, lambda_struct=10.0, p_drop=0.05, posterior_temp_floor=0.6):
        """
        Train the latent-space conditional DDPM as a conditional prior p(z|h).

        ``netE`` and the Graph-VAE encoder are frozen and evaluated once to get
        the posterior parameters (mu, logvar) of every ST spot. Each step then
        draws a random mini-batch, samples z0 from those posteriors, noises it
        to a random timestep and trains ``latent_denoiser`` to predict the
        added noise (epsilon-prediction MSE).

        Args:
            n_epochs: number of optimisation steps (one random mini-batch each).
            lambda_struct: when > 0, a pairwise-distance "structure" loss on the
                denoised latent is computed and logged as a diagnostic. It is
                NOT part of the optimised objective, so only its sign matters.
            p_drop: probability of zeroing a sample's conditioning vector
                (classifier-free-guidance training).
            posterior_temp_floor: std floor combined in quadrature with the
                posterior std when sampling z0.

        Side effects:
            - ``netE`` / ``graph_vae_encoder`` are put in eval mode and their
              parameters are left with ``requires_grad=False``.
            - ``latent_denoiser`` is put in train mode.
            - appends to ``self.latent_diffusion_losses`` and ``self.train_log``;
              writes checkpoints and the final weights under ``self.outf``.
        """
        print("Training latent-space diffusion model...")

        # Freeze encoder and Graph-VAE encoder
        self.netE.eval()
        self.graph_vae_encoder.eval()
        for param in self.netE.parameters():
            param.requires_grad = False
        for param in self.graph_vae_encoder.parameters():
            param.requires_grad = False

        # Sampling leaves the denoiser in eval mode; re-running this cell
        # afterwards would otherwise train with dropout/batch-norm disabled.
        self.latent_denoiser.train()

        # Precompute fixed ST latents as specified
        print("Computing fixed ST latents...")
        st_adj_idx, st_adj_w = precompute_knn_edges(self.st_coords_norm, k=30, device=self.device)

        with torch.no_grad():
            st_features_aligned = self.netE(self.st_gene_expr).float()
            st_mu, st_logvar = self.graph_vae_encoder(st_features_aligned, st_adj_idx, st_adj_w)

        # Setup optimizer for latent denoiser (learning rate reused from optimizer_diff)
        optimizer_latent = torch.optim.AdamW(
            self.latent_denoiser.parameters(),
            lr=self.optimizer_diff.param_groups[0]['lr'],
            weight_decay=1e-5
        )
        scheduler_latent = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer_latent, T_max=n_epochs, eta_min=1e-6
        )

        # Log/checkpoint cadence. With the default n_epochs=400 only epoch 0 hits it.
        report_every = 500

        for epoch in range(n_epochs):
            # Sample batch indices
            idx = torch.randperm(len(st_mu))[:self.batch_size]
            batch_mu = st_mu[idx]
            batch_logvar = st_logvar[idx]
            batch_h = st_features_aligned[idx]

            # Sample FRESH z0 from posterior (with temperature floor)
            eps0 = torch.randn_like(batch_mu)
            posterior_std = torch.sqrt(torch.exp(batch_logvar) + posterior_temp_floor**2)
            z0 = batch_mu + posterior_std * eps0

            # Sample random timesteps
            t = torch.randint(0, self.n_timesteps, (len(z0),), device=self.device)

            # Forward diffusion: add noise to z0
            eps = torch.randn_like(z0)
            alpha_bar_t = self.noise_schedule['alphas_cumprod'][t].view(-1, 1)
            z_t = torch.sqrt(alpha_bar_t) * z0 + torch.sqrt(1 - alpha_bar_t) * eps

            # Classifier-free guidance: zero the conditioning of randomly chosen rows
            drop_mask = (torch.rand(len(batch_h), 1, device=self.device) < p_drop).float()
            cond = batch_h * (1 - drop_mask)

            # Predict noise; timestep is normalised to [0, 1] exactly as in sampling
            t_norm = (t.float().unsqueeze(1) / max(self.n_timesteps - 1, 1)).clamp(0, 1)
            eps_pred = self.latent_denoiser(z_t, t_norm, cond)

            # Diffusion loss
            loss_diffusion = F.mse_loss(eps_pred, eps)

            # Structure loss on x0 (denoised latent). Diagnostic only, so it is
            # evaluated without building an autograd graph. Drop the no_grad()
            # if it is ever added back into total_loss.
            loss_struct = torch.tensor(0.0, device=self.device)
            if lambda_struct > 0:
                with torch.no_grad():
                    x0_pred = (z_t - torch.sqrt(1 - alpha_bar_t) * eps_pred) / torch.sqrt(alpha_bar_t + 1e-8)
                    D_target = torch.cdist(z0, z0, p=2)
                    D_pred = torch.cdist(x0_pred, x0_pred, p=2)
                    loss_struct = F.mse_loss(D_pred, D_target)

            # Total loss (structure term intentionally excluded)
            total_loss = loss_diffusion

            # Record losses for plotting
            struct_value = loss_struct.item()
            self.latent_diffusion_losses['total'].append(total_loss.item())
            self.latent_diffusion_losses['diffusion'].append(loss_diffusion.item())
            self.latent_diffusion_losses['struct'].append(struct_value)
            self.latent_diffusion_losses['epochs'].append(epoch)

            # Backward pass
            optimizer_latent.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.latent_denoiser.parameters(), 1.0)
            optimizer_latent.step()
            scheduler_latent.step()

            # Logging + checkpoint
            if epoch % report_every == 0:
                log_msg = (
                    f"Latent Diffusion epoch {epoch}/{n_epochs}, "
                    f"Total: {total_loss.item():.6f}, "
                    f"Diffusion: {loss_diffusion.item():.6f}, "
                    f"Struct: {struct_value:.6f}"
                )
                print(log_msg)
                with open(self.train_log, 'a') as f:
                    f.write(log_msg + '\n')

                torch.save({
                    'epoch': epoch,
                    'latent_denoiser_state_dict': self.latent_denoiser.state_dict(),
                    'optimizer_state_dict': optimizer_latent.state_dict(),
                }, os.path.join(self.outf, f'latent_diffusion_checkpoint_epoch_{epoch}.pt'))

        # Save final model
        torch.save({
            'latent_denoiser_state_dict': self.latent_denoiser.state_dict(),
        }, os.path.join(self.outf, 'final_latent_diffusion.pt'))

        print("Latent diffusion training complete!")
                        

    def train(self, encoder_epochs=1000, vae_epochs=800, diffusion_epochs=400, **kwargs):
        """
        Run the three training stages back to back:
        domain-alignment encoder -> Graph-VAE -> latent diffusion.

        ⚠️ `train_encoder` must stay as it is: the aligned embeddings it
        produces are the only conditioning signal used by the later stages.

        Args:
            encoder_epochs: forwarded to ``train_encoder`` as ``n_epochs``.
            vae_epochs: forwarded to ``train_graph_vae`` as ``epochs``.
            diffusion_epochs: forwarded to ``train_diffusion_latent`` as ``n_epochs``.
            **kwargs: extra keyword arguments for ``train_diffusion_latent`` only.
        """
        print("Starting Graph-VAE + Latent Diffusion training pipeline...")

        # (banner, stage runner) pairs, executed strictly in this order.
        stages = (
            ("Stage 1: Training domain alignment encoder...",
             lambda: self.train_encoder(n_epochs=encoder_epochs)),
            ("Stage 2: Training Graph-VAE...",
             lambda: self.train_graph_vae(epochs=vae_epochs)),
            ("Stage 3: Training latent diffusion...",
             lambda: self.train_diffusion_latent(n_epochs=diffusion_epochs, **kwargs)),
        )
        for banner, run_stage in stages:
            print(banner)
            run_stage()

        print("Complete training pipeline finished!")

    def refine_coordinates_scale_free(self, coords, tau=0.9, allow_reflection=True):
        '''Scale-free coordinate refinement using a relative minimum separation.

        Alternates between (a) projecting the squared-distance matrix onto
        rank-2 Euclidean distance matrices (classical MDS) and (b) raising every
        off-diagonal squared distance to at least ``(tau * median NN distance)^2``,
        for at most 100 rounds or until the relative change drops below 1e-6.
        The resulting 2-D embedding is then Procrustes-aligned back onto the
        input so the output lives in the same frame as ``coords``.

        Args:
            coords: (n, 2) tensor of coordinates.
            tau: factor for the minimum separation (0.8-1.0), where 1.0 is tight.
            allow_reflection: if True (default) the final alignment may use any
                orthogonal transform. An MDS embedding has arbitrary handedness,
                so a rotation-only fit can return a mirror image of the input.
                Pass False for the previous rotation-only behaviour.

        Returns:
            (n, 2) tensor of refined coordinates on the same device as ``coords``.
            Inputs with fewer than two points are returned as an unchanged copy.

        Note:
            Builds dense n x n matrices and runs an eigendecomposition per
            iteration, so cost grows quickly with n.
        '''
        n = coords.shape[0]

        # Nothing to separate, and the 2-D embedding / alignment is undefined.
        if n < 2:
            return coords.clone()

        def compute_nearest_neighbor_distances(X):
            '''Distance from every point to its nearest other point.'''
            distances = torch.cdist(X, X)
            distances.fill_diagonal_(float('inf'))  # exclude self-matches
            return distances.min(dim=1).values

        def pairwise_distances_squared(X):
            '''Pairwise squared Euclidean distances with an exact-zero diagonal.'''
            s = (X * X).sum(1, keepdim=True)
            S = (s + s.T - 2 * X @ X.T).clamp_min_(0.0)  # clamp round-off negatives
            S.fill_diagonal_(0.0)
            return S

        def gram_matrix(S):
            '''Double-centre squared distances into a Gram matrix (classical MDS).'''
            row_mean = S.mean(1, keepdim=True)
            col_mean = S.mean(0, keepdim=True)
            return -0.5 * (S - row_mean - col_mean + S.mean())

        def distance_from_gram(B):
            '''Convert a Gram matrix back to squared distances.'''
            d = torch.diag(B)
            S = (d[:, None] + d[None, :] - 2 * B).clamp_min_(0.0)
            S.fill_diagonal_(0.0)
            return S

        def top2_embedding(S):
            '''2-D classical-MDS coordinates for squared distances S.'''
            vals, vecs = torch.linalg.eigh(gram_matrix(S))
            idx = torch.argsort(vals, descending=True)[:2]
            # Negative eigenvalues carry no Euclidean signal: clamp before sqrt.
            return vecs[:, idx] * vals[idx].clamp_min(0.0).sqrt()

        def project_to_euclidean(S):
            '''Nearest rank-2 Euclidean squared-distance matrix (via MDS).'''
            U = top2_embedding(S)
            return distance_from_gram(U @ U.T)

        def enforce_min_separation(S, min_dist):
            '''Raise every off-diagonal squared distance to at least min_dist^2.'''
            S = torch.maximum(S, min_dist ** 2)
            S.fill_diagonal_(0.0)
            return 0.5 * (S + S.T)

        def align_to_original(U, X_orig):
            '''Procrustes-fit U onto X_orig (rotation, optional reflection, shift).'''
            Um = U.mean(0)
            Xm = X_orig.mean(0)
            M = (U - Um).T @ (X_orig - Xm)
            U_svd, _, Vt = torch.linalg.svd(M, full_matrices=False)
            R = U_svd @ Vt
            if not allow_reflection and torch.linalg.det(R) < 0:
                # Kabsch correction: flip the weakest singular direction.
                U_svd[:, -1] *= -1
                R = U_svd @ Vt
            t = Xm - (Um @ R)
            return U @ R + t

        # Step 1: Compute scale-free minimum separation
        nn_distances = compute_nearest_neighbor_distances(coords)
        median_nn_dist = torch.median(nn_distances)
        min_separation = tau * median_nn_dist

        print(f"Nearest neighbor distances: min={nn_distances.min():.4f}, "
            f"median={median_nn_dist:.4f}, max={nn_distances.max():.4f}")
        print(f"Using minimum separation: {min_separation:.4f} (tau={tau})")

        # Step 2: Refinement loop (alternating projections)
        S = pairwise_distances_squared(coords)

        for iteration in range(100):
            S_prev = S.clone()

            S = project_to_euclidean(S)
            S = enforce_min_separation(S, min_separation)

            # Check convergence
            change = (S - S_prev).norm() / (S_prev.norm() + 1e-12)
            if change < 1e-6:
                print(f"Converged after {iteration + 1} iterations")
                break

        # Step 3: Extract refined coordinates and put them back in the input frame
        coords_refined = top2_embedding(S)
        coords_final = align_to_original(coords_refined, coords)

        # Compute final statistics
        final_nn_distances = compute_nearest_neighbor_distances(coords_final)
        print(f"After refinement - min distance: {final_nn_distances.min():.4f}, "
            f"median: {torch.median(final_nn_distances):.4f}")

        return coords_final



    def sample_sc_coordinates_batched(self, batch_size=512, guidance_scale=1.0, refine_coords=True, refinement_tau=0.9):
        """
        Sample SC coordinates: pure noise -> guided latent denoising -> decode.

        No Graph-VAE latents are used here. For each batch of cells the aligned
        embedding ``netE(x)`` conditions a full DDPM reverse pass that starts
        from Gaussian noise, and the final latent is decoded to coordinates.

        Args:
            batch_size: number of cells denoised together.
            guidance_scale: classifier-free-guidance weight ``w`` in
                ``eps = (1 + w) * eps_cond - w * eps_uncond`` (0 = no guidance).
            refine_coords: run ``refine_coordinates_scale_free`` on the
                concatenated result (skipped when fewer than two cells).
            refinement_tau: ``tau`` forwarded to the refinement step.

        Returns:
            numpy array with one row of coordinates per SC cell.

        Side effects:
            puts ``netE``, ``graph_vae_decoder`` and ``latent_denoiser`` in eval
            mode and calls ``self.update_noise_schedule()``.
        """
        n_total = len(self.sc_gene_expr)
        print(f"Sampling {n_total} SC coordinates using pure noise → guided diffusion → decode...")
        print(f"Using guidance scale: {guidance_scale}")

        # Set models to eval mode
        self.netE.eval()
        self.graph_vae_decoder.eval()
        self.latent_denoiser.eval()

        # Update noise schedule for current n_timesteps
        self.update_noise_schedule()

        if n_total == 0:
            # torch.cat([]) would raise; no cells means no coordinates.
            # NOTE(review): assumes the decoder emits 2-D coordinates, as the
            # rest of the pipeline does -- confirm if the decoder changes.
            return np.empty((0, 2), dtype=np.float32)

        all_coords = []
        n_batches = (n_total + batch_size - 1) // batch_size
        # Same timestep normalisation as used during training.
        t_denominator = max(self.n_timesteps - 1, 1)
        schedule = self.noise_schedule

        with torch.no_grad():
            for batch_idx in range(n_batches):
                start_idx = batch_idx * batch_size
                end_idx = min(start_idx + batch_size, n_total)
                batch_sc_expr = self.sc_gene_expr[start_idx:end_idx]

                print(f"Processing batch {batch_idx + 1}/{n_batches} ({len(batch_sc_expr)} cells)...")

                # Get aligned SC embeddings
                h_sc = self.netE(batch_sc_expr).float()
                B = h_sc.size(0)
                # All-zero conditioning == "unconditional" branch; built once per batch.
                null_cond = torch.zeros_like(h_sc)

                # Start from PURE NOISE in latent space
                z_t = torch.randn(B, self.latent_dim, device=self.device)

                # Reverse diffusion with classifier-free guidance
                for t in reversed(range(self.n_timesteps)):
                    t_norm = torch.full((B, 1), t / t_denominator, device=self.device)

                    eps_c = self.latent_denoiser(z_t, t_norm, h_sc)
                    eps_u = self.latent_denoiser(z_t, t_norm, null_cond)
                    eps_pred = (1 + guidance_scale) * eps_c - guidance_scale * eps_u

                    # DDPM reverse step: posterior mean given the predicted noise
                    alpha_t = schedule['alphas'][t]
                    alpha_bar_t = schedule['alphas_cumprod'][t]
                    mu = (1 / torch.sqrt(alpha_t)) * (
                        z_t - ((1 - alpha_t) / torch.sqrt(1 - alpha_bar_t)) * eps_pred
                    )

                    if t > 0:
                        # Add noise (posterior variance); the last step is deterministic
                        sigma_t = torch.sqrt(schedule['posterior_variance'][t])
                        z_t = mu + sigma_t * torch.randn_like(z_t)
                    else:
                        z_t = mu

                # Decode final latent to coordinates (z ONLY, no features)
                batch_coords = self.graph_vae_decoder(z_t)

                # Move to CPU and store
                all_coords.append(batch_coords.cpu())

                # Clear GPU cache
                torch.cuda.empty_cache()

        # Combine all batches
        final_coords = torch.cat(all_coords, dim=0)

        # Scale-free coordinate refinement on ALL coordinates
        # (needs at least two points to define a nearest-neighbour distance).
        if refine_coords and n_total > 1:
            print("Applying scale-free refinement to all coordinates...")
            coords_tensor = final_coords.to(self.device)
            refined_coords = self.refine_coordinates_scale_free(coords_tensor, tau=refinement_tau)
            final_coords = refined_coords.cpu()

        print("Pure noise → guided diffusion → decode sampling complete!")
        return final_coords.numpy()


    def plot_training_losses(self):
        """Plot and summarise the losses recorded by the training stages.

        For each stage that has logged at least one epoch (Graph-VAE via
        ``self.vae_losses``, latent diffusion via ``self.latent_diffusion_losses``)
        two log-scale panels are drawn: the individual loss terms and the total
        loss with a moving-average overlay. A 1x2 grid is used when only one
        stage has data and a 2x2 grid when both do. A text summary of initial /
        final losses is printed afterwards.

        Prints a notice and returns without plotting when nothing was recorded.
        """
        import matplotlib.pyplot as plt
        import numpy as np

        has_vae = len(self.vae_losses['epochs']) > 0
        has_diffusion = len(self.latent_diffusion_losses['epochs']) > 0
        n_plots = 2 * (int(has_vae) + int(has_diffusion))

        if n_plots == 0:
            print("No training losses to plot.")
            return

        if n_plots == 4:
            fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        else:
            fig, axes = plt.subplots(1, 2, figsize=(15, 5))
        # Always index a flat sequence of Axes. (Wrapping the returned array in
        # a list made axes[0] an ndarray and broke the single-stage case.)
        axes = np.atleast_1d(axes).ravel()

        def finish_axis(ax, title, with_legend):
            """Shared labelling for every panel."""
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            ax.set_title(title)
            if with_legend:
                ax.legend()
            ax.grid(True, alpha=0.3)
            ax.set_yscale('log')

        def plot_smoothed(ax, epochs, values, raw_color, smooth_color, title):
            """Total loss plus a moving average (window = min(50, len // 10))."""
            window = min(50, len(values) // 10)
            if window > 1:
                smoothed = np.convolve(values, np.ones(window) / window, mode='valid')
                # 'valid' drops the first window-1 points, so shift the x axis too.
                ax.plot(epochs, values, raw_color, alpha=0.5, label='Raw')
                ax.plot(epochs[window - 1:], smoothed, smooth_color, linewidth=2,
                        label=f'Smoothed (window={window})')
            else:
                ax.plot(epochs, values, smooth_color, linewidth=2)
            finish_axis(ax, title, with_legend=window > 1)

        def print_reduction(label, values):
            """Relative drop from first to last loss; guarded against a zero start."""
            first, last = values[0], values[-1]
            if first != 0:
                print(f"{label} - Loss Reduction: {(1 - last/first)*100:.2f}%")
            else:
                print(f"{label} - Loss Reduction: n/a (initial loss is 0)")

        plot_idx = 0

        # Plot 1: Graph-VAE losses
        if has_vae:
            epochs_vae = np.array(self.vae_losses['epochs'])
            ax = axes[plot_idx]

            ax.plot(epochs_vae, self.vae_losses['total'], 'b-', label='Total VAE Loss', linewidth=2)
            ax.plot(epochs_vae, self.vae_losses['reconstruction'], 'g-', label='Reconstruction Loss', linewidth=2)
            ax.plot(epochs_vae, np.array(self.vae_losses['kl']) * 0.01, 'r--', label='KL Loss (×0.01)', alpha=0.8)
            finish_axis(ax, 'Graph-VAE Training Losses', with_legend=True)
            plot_idx += 1

            # Plot 2: Graph-VAE smoothed
            if len(self.vae_losses['total']) > 1:
                plot_smoothed(axes[plot_idx], epochs_vae, self.vae_losses['total'],
                              'lightblue', 'blue', 'Graph-VAE Total Loss (Smoothed)')
                plot_idx += 1

        # Plot 3: Latent Diffusion losses
        if has_diffusion:
            epochs_diff = np.array(self.latent_diffusion_losses['epochs'])
            ax = axes[plot_idx]

            ax.plot(epochs_diff, self.latent_diffusion_losses['total'], 'b-', label='Total Loss', linewidth=2)
            ax.plot(epochs_diff, self.latent_diffusion_losses['diffusion'], 'k-', label='Diffusion Loss', linewidth=2)

            # Only plot struct loss if it's non-zero
            struct_losses = np.array(self.latent_diffusion_losses['struct'])
            if np.any(struct_losses > 0):
                ax.plot(epochs_diff, struct_losses, 'r--', label='Structure Loss', alpha=0.8)
            finish_axis(ax, 'Latent Diffusion Training Losses', with_legend=True)
            plot_idx += 1

            # Plot 4: Latent Diffusion smoothed
            if len(self.latent_diffusion_losses['total']) > 1:
                plot_smoothed(axes[plot_idx], epochs_diff, self.latent_diffusion_losses['total'],
                              'lightcoral', 'red', 'Latent Diffusion Total Loss (Smoothed)')
                plot_idx += 1

        plt.tight_layout()
        plt.show()

        # Print final loss values
        print("\n=== Training Loss Summary ===")

        if len(self.vae_losses['total']) > 0:
            print(f"Graph-VAE - Initial Loss: {self.vae_losses['total'][0]:.6f}")
            print(f"Graph-VAE - Final Loss: {self.vae_losses['total'][-1]:.6f}")
            print_reduction("Graph-VAE", self.vae_losses['total'])
            print(f"Graph-VAE - Final Reconstruction Loss: {self.vae_losses['reconstruction'][-1]:.6f}")
            print(f"Graph-VAE - Final KL Loss: {self.vae_losses['kl'][-1]:.6f}")

        if len(self.latent_diffusion_losses['total']) > 0:
            print(f"Latent Diffusion - Initial Loss: {self.latent_diffusion_losses['total'][0]:.6f}")
            print(f"Latent Diffusion - Final Loss: {self.latent_diffusion_losses['total'][-1]:.6f}")
            print_reduction("Latent Diffusion", self.latent_diffusion_losses['total'])
            print(f"Latent Diffusion - Final Diffusion Loss: {self.latent_diffusion_losses['diffusion'][-1]:.6f}")
            if np.any(np.array(self.latent_diffusion_losses['struct']) > 0):
                print(f"Latent Diffusion - Final Structure Loss: {self.latent_diffusion_losses['struct'][-1]:.6f}")


# =====================================================
# PART 4: MMD Loss Implementation
# =====================================================

class MMDLoss(nn.Module):
    """Maximum Mean Discrepancy between two batches of feature vectors.

    Args:
        kernel_type: 'rbf' (sum of Gaussian kernels) or 'linear'
            (squared distance between the batch means).
        kernel_mul: geometric ratio between successive RBF bandwidths.
        kernel_num: number of RBF kernels that are summed.
        fix_sigma: fixed base bandwidth; when falsy the bandwidth is set from
            the mean pairwise squared distance of the combined batch.
    """

    # Floor for the data-driven bandwidth: identical samples would otherwise
    # give bandwidth 0 and a NaN kernel (0 / 0).
    _MIN_BANDWIDTH = 1e-12

    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        super(MMDLoss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = fix_sigma
        self.kernel_type = kernel_type

    def guassian_kernel(self, source, target, kernel_mul, kernel_num, fix_sigma):
        """Sum of ``kernel_num`` Gaussian kernels over the stacked [source; target].

        (The misspelt method name is kept so existing callers keep working.)

        Returns:
            (Ns + Nt, Ns + Nt) tensor of summed kernel values.
        """
        n_samples = int(source.size(0)) + int(target.size(0))
        total = torch.cat([source, target], dim=0)
        # Pairwise squared Euclidean distances via broadcasting.
        L2_distance = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(2)

        if fix_sigma:
            bandwidth = fix_sigma
        else:
            # Mean off-diagonal squared distance; detached so the bandwidth is
            # treated as a constant during backprop.
            bandwidth = torch.sum(L2_distance.detach()) / (n_samples ** 2 - n_samples)
            bandwidth = bandwidth.clamp_min(self._MIN_BANDWIDTH)

        # Out-of-place on purpose: an in-place divide would shrink a tensor
        # `fix_sigma` owned by the caller on every forward pass.
        bandwidth = bandwidth / kernel_mul ** (kernel_num // 2)

        return sum(
            torch.exp(-L2_distance / (bandwidth * kernel_mul ** i))
            for i in range(kernel_num)
        )

    def linear_mmd2(self, f_of_X, f_of_Y):
        """Squared Euclidean distance between the two batch means."""
        delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0)
        return delta.dot(delta)

    def forward(self, source, target):
        """Return the scalar MMD loss between ``source`` and ``target``.

        Raises:
            ValueError: if ``kernel_type`` is neither 'linear' nor 'rbf'.
        """
        if self.kernel_type == 'linear':
            return self.linear_mmd2(source, target)
        if self.kernel_type == 'rbf':
            batch_size = int(source.size(0))
            kernels = self.guassian_kernel(
                source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
            # Block means: source-source, target-target and the two cross blocks.
            XX = torch.mean(kernels[:batch_size, :batch_size])
            YY = torch.mean(kernels[batch_size:, batch_size:])
            XY = torch.mean(kernels[:batch_size, batch_size:])
            YX = torch.mean(kernels[batch_size:, :batch_size])
            return torch.mean(XX + YY - XY - YX)
        raise ValueError(
            f"Unknown kernel_type {self.kernel_type!r}; expected 'linear' or 'rbf'")

def analyze_sc_st_patterns(model):
    """Run the SC-vs-ST expression comparison and print summary statistics.

    Calls ``model.compare_sc_st_expression_patterns`` (20 genes), feeds the
    genes it returns to ``model.plot_detailed_gene_comparison`` (10 genes),
    then prints the shapes of the SC / ST expression tensors and the mean and
    standard deviation of their per-gene means.
    """
    print("Analyzing SC vs ST expression patterns...")

    # Overview comparison; its return value drives the per-gene plots.
    top_genes = model.compare_sc_st_expression_patterns(n_genes=20)
    n_top = len(top_genes)

    print(f"\nDetailed analysis for top {n_top} variable genes...")
    model.plot_detailed_gene_comparison(top_genes, n_genes=10)

    sc_expr = model.sc_gene_expr
    st_expr = model.st_gene_expr

    print("\nExpression Statistics:")
    print(f"SC data shape: {sc_expr.shape}")
    print(f"ST data shape: {st_expr.shape}")

    # Pure reporting: keep autograd out of it.
    with torch.no_grad():
        sc_gene_means = sc_expr.mean(0)
        st_gene_means = st_expr.mean(0)
        sc_avg, sc_spread = sc_gene_means.mean(), sc_gene_means.std()
        st_avg, st_spread = st_gene_means.mean(), st_gene_means.std()

        print(f"SC mean expression: {sc_avg:.3f} ± {sc_spread:.3f}")
        print(f"ST mean expression: {st_avg:.3f} ± {st_spread:.3f}")

In [None]:
def load_and_process_cscc_data_individual_norm(data_dir='/home/ehtesamul/sc_st/data/cSCC/processed'):
    """
    Load the cSCC SC dataset and its three ST replicates and normalize expression.

    Each AnnData is processed on its own with ``sc.pp.normalize_total`` followed
    by ``sc.pp.log1p``. Spatial coordinates are NOT touched here; per-dataset
    coordinate normalization is done by ``normalize_coordinates_individually``.

    A coarse ``rough_celltype`` column is added to ``scadata.obs``: it starts as
    a string copy of ``level1_celltype``, selected level-1 labels are merged or
    renamed, and finally tumour labels from ``level2_celltype`` override it.

    Args:
        data_dir: directory containing ``scP2.h5ad``, ``stP2.h5ad``,
            ``stP2rep2.h5ad`` and ``stP2rep3.h5ad``. Defaults to the location
            that used to be hard-coded.

    Returns:
        (scadata, stadata1, stadata2, stadata3)
    """
    print("Loading cSCC data with individual normalization...")

    def _read(filename):
        # Read one .h5ad file from the data directory.
        return sc.read_h5ad(os.path.join(data_dir, filename))

    # SC data, then the three ST replicates
    scadata = _read('scP2.h5ad')
    stadata1 = _read('stP2.h5ad')
    stadata2 = _read('stP2rep2.h5ad')
    stadata3 = _read('stP2rep3.h5ad')

    # Same expression normalization for every dataset, applied in place
    for adata in (scadata, stadata1, stadata2, stadata3):
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)

    # level-1 label -> coarse label (labels not listed keep their level-1 name)
    level1_to_rough = {
        'CLEC9A': 'DC',
        'CD1C': 'DC',
        'ASDC': 'DC',
        'PDC': 'PDC',
        'MDSC': 'DC',
        'LC': 'DC',
        'Mac': 'Myeloid cell',
        'Tcell': 'T cell',
    }
    scadata.obs['rough_celltype'] = scadata.obs['level1_celltype'].astype(str)
    for level1_label, rough_label in level1_to_rough.items():
        scadata.obs.loc[scadata.obs['level1_celltype'] == level1_label, 'rough_celltype'] = rough_label

    # Tumour sub-populations are defined at level 2 and take precedence
    scadata.obs.loc[scadata.obs['level2_celltype'] == 'TSK', 'rough_celltype'] = 'TSK'
    non_tsk_labels = ['Tumor_KC_Basal', 'Tumor_KC_Diff', 'Tumor_KC_Cyc']
    scadata.obs.loc[scadata.obs['level2_celltype'].isin(non_tsk_labels), 'rough_celltype'] = 'NonTSK'

    return scadata, stadata1, stadata2, stadata3

def normalize_coordinates_individually(coords):
    """
    Rescale one dataset's coordinates column-wise into the [-1, 1] range.

    Args:
        coords: array of coordinates; min/max are taken along axis 0.

    Returns:
        (normalized, minimum, maximum, span) where ``span`` is ``max - min``
        with zero entries replaced by 1, so a constant column maps to -1
        instead of dividing by zero.
    """
    lower = coords.min(axis=0)
    upper = coords.max(axis=0)
    span = upper - lower

    # Guard against a zero-width column
    span[span == 0] = 1.0

    # Shift to start at 0, stretch to [0, 2], then move down to [-1, 1]
    shifted = coords - lower
    rescaled = 2 * shifted / span - 1

    return rescaled, lower, upper, span

def prepare_individually_normalized_st_data(stadata1, stadata2, stadata3, scadata):
    """
    Build model inputs from the SC dataset and three ST datasets.

    Expression is restricted to the genes shared by all four datasets. Each ST
    dataset's ``obsm['spatial']`` coordinates are normalized on their own with
    ``normalize_coordinates_individually`` before the three are stacked.

    Returns:
        X_sc: float32 tensor of SC expression.
        X_st_combined: float32 tensor of stacked ST expression (dataset 1, 2, 3).
        Y_st_combined: float32 array of stacked normalized ST coordinates.
        dataset_info: dict with per-row 'labels' ('dataset1'..'dataset3') and
            'normalization_params' (min / max / range per dataset).
        common_genes: sorted list of the shared gene names.
    """
    print("Preparing individually normalized ST data...")

    st_sections = (stadata1, stadata2, stadata3)

    # Intersect the gene sets of SC and every ST dataset
    shared = set(scadata.var_names)
    for section in st_sections:
        shared = shared & set(section.var_names)
    common_genes = sorted(shared)
    print(f"Common genes across all datasets: {len(common_genes)}")

    def _dense_expression(adata):
        # Expression restricted to the shared genes, densified when sparse.
        matrix = adata[:, common_genes].X
        return matrix.toarray() if hasattr(matrix, 'toarray') else matrix

    sc_expr = _dense_expression(scadata)
    st_exprs = [_dense_expression(section) for section in st_sections]

    raw_coords = [section.obsm['spatial'] for section in st_sections]

    print("Normalizing coordinates individually...")
    # Each entry: (normalized coords, min, max, range)
    normalized = [normalize_coordinates_individually(coords) for coords in raw_coords]

    for number, (coords_norm, _lo, _hi, _span) in enumerate(normalized, start=1):
        print(f"ST{number} coord range: [{coords_norm.min():.3f}, {coords_norm.max():.3f}]")

    # Stack the three datasets in order 1, 2, 3
    st_expr_combined = np.vstack(st_exprs)
    st_coords_combined = np.vstack([entry[0] for entry in normalized])

    # Row-level dataset labels plus the parameters used to normalize each dataset
    row_labels = []
    norm_params = {}
    for number, (expr, (_coords_norm, lo, hi, span)) in enumerate(zip(st_exprs, normalized), start=1):
        name = f'dataset{number}'
        row_labels = row_labels + [name] * len(expr)
        norm_params[name] = {'min': lo, 'max': hi, 'range': span}
    dataset_info = {
        'labels': row_labels,
        'normalization_params': norm_params,
    }

    print(f"Combined ST data shape: {st_expr_combined.shape}")
    print(f"Combined ST coords shape: {st_coords_combined.shape}")
    print(f"SC data shape: {sc_expr.shape}")

    # Expression goes to torch tensors; coordinates stay a float32 numpy array
    X_sc = torch.tensor(sc_expr, dtype=torch.float32)
    X_st_combined = torch.tensor(st_expr_combined, dtype=torch.float32)
    Y_st_combined = st_coords_combined.astype(np.float32)

    return X_sc, X_st_combined, Y_st_combined, dataset_info, common_genes

In [None]:
# Load the SC dataset and the three ST replicates; the loader applies
# normalize_total + log1p to each and adds obs['rough_celltype'] to scadata.
scadata, stadata1, stadata2, stadata3 = load_and_process_cscc_data_individual_norm()
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Display the summary of the loaded SC dataset.
scadata

In [None]:
def analyze_cell_interactions_advanced(scadata, coords_key='advanced_diffusion_coords_avg'):
    """Compute and plot the minimum distance between every pair of cell types.

    Args:
        scadata: object providing ``obsm[coords_key]`` (cell coordinates) and
            ``obs['rough_celltype']`` (one label per cell).
        coords_key: key in ``scadata.obsm`` holding the coordinates to analyse.

    Returns:
        min_distances: dict mapping ``(type1, type2)`` (each unordered pair
            once, in ``np.unique`` order, self-pairs included) to the smallest
            distance between a cell of type1 and a *different* cell of type2.
            A type that contains a single cell has no within-type pair, so
            its self-distance is NaN.
        interaction_matrix: symmetric (n_types, n_types) array with the same
            values.

    Note:
        Relies on the module-level ``plt`` (matplotlib) and ``sns`` (seaborn)
        names being available when the function is called.
    """
    from scipy.spatial.distance import pdist, squareform

    coords_mean = scadata.obsm[coords_key]

    # Dense all-vs-all distance matrix between cells
    distances = squareform(pdist(coords_mean))

    cell_types = scadata.obs['rough_celltype'].values
    unique_types = np.unique(cell_types)
    n_types = len(unique_types)
    type_masks = [cell_types == cell_type for cell_type in unique_types]

    min_distances = {}
    interaction_matrix = np.full((n_types, n_types), np.nan)
    for i in range(n_types):
        for j in range(i, n_types):
            # Boolean fancy indexing returns a copy, so ``distances`` is never modified
            sub_dist = distances[np.ix_(type_masks[i], type_masks[j])]

            if i == j:
                # Same cell type: ignore each cell's zero distance to itself
                np.fill_diagonal(sub_dist, np.inf)
                candidates = sub_dist[sub_dist < np.inf]
                # A single-cell type leaves no candidates; report NaN rather
                # than letting np.min fail on an empty array.
                min_dist = np.min(candidates) if candidates.size > 0 else np.nan
            else:
                min_dist = np.min(sub_dist) if sub_dist.size > 0 else np.nan

            min_distances[(unique_types[i], unique_types[j])] = min_dist
            interaction_matrix[i, j] = min_dist
            interaction_matrix[j, i] = min_dist

    # Heatmap of the matrix; NaN entries are masked out
    plt.figure(figsize=(10, 8))
    sns.heatmap(interaction_matrix,
                annot=True, fmt='.3f',
                xticklabels=unique_types,
                yticklabels=unique_types,
                cmap='coolwarm_r',
                mask=np.isnan(interaction_matrix),
                cbar_kws={'label': 'Minimum Distance'})
    plt.title(f'Minimum Distances Between Cell Types ({coords_key})')
    plt.tight_layout()
    plt.show()

    return min_distances, interaction_matrix

def visualize_advanced_results_multi_model(scadata):
    """Visualize the averaged SC coordinates together with across-model uncertainty.

    Reads ``obsm['advanced_diffusion_coords_avg']`` and the three per-model
    coordinate sets ``obsm['advanced_diffusion_coords_rep{1,2,3}']``. When a
    plain ``..._rep{k}`` key is absent, ``..._rep{k}_aligned`` (as written by
    the Procrustes-alignment training pipeline) is used instead.

    Side effects:
        Writes ``scadata.obs['spatial_confidence']`` and shows a 2x3 figure.

    Returns:
        fig: the matplotlib figure.
        total_std: per-cell sqrt(std_x^2 + std_y^2) across the three models.
        confidence: per-cell ``1 / (1 + total_std)``.

    Note:
        Relies on the module-level ``plt`` (matplotlib) and ``sns`` (seaborn)
        names being available when the function is called.
    """
    from scipy.stats import gaussian_kde

    def _replicate_coords(rep_idx):
        # Coordinates of one model, preferring the plain key over the aligned one.
        key = f'advanced_diffusion_coords_rep{rep_idx}'
        if key not in scadata.obsm:
            key = f'{key}_aligned'
        return scadata.obsm[key]

    coords_avg = scadata.obsm['advanced_diffusion_coords_avg']

    # Spread of each cell's position across the three models
    all_coords = np.stack([_replicate_coords(rep_idx) for rep_idx in (1, 2, 3)], axis=0)
    coords_std = np.std(all_coords, axis=0)
    total_std = np.sqrt(coords_std[:, 0]**2 + coords_std[:, 1]**2)

    # Confidence decreases monotonically with variability and lies in (0, 1]
    confidence = 1 / (1 + total_std)
    scadata.obs['spatial_confidence'] = confidence

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # 1. Averaged coordinates colored by cell type
    ax = axes[0, 0]
    cell_types = scadata.obs['rough_celltype']
    unique_types = cell_types.unique()
    colors = sns.color_palette('tab20', n_colors=len(unique_types))
    for i, ct in enumerate(unique_types):
        mask = cell_types == ct
        ax.scatter(coords_avg[mask, 0], coords_avg[mask, 1],
                  c=[colors[i]], label=ct, s=30, alpha=0.7)
    ax.set_title('Averaged Spatial Coordinates by Cell Type', fontsize=14)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)

    # 2. Across-model standard deviation at each averaged position
    ax = axes[0, 1]
    scatter = ax.scatter(coords_avg[:, 0], coords_avg[:, 1],
                        c=total_std, cmap='viridis_r',
                        s=30, alpha=0.7)
    plt.colorbar(scatter, ax=ax, label='Model Std Dev')
    ax.set_title('Model Prediction Variability', fontsize=14)

    # 3. X-coordinate std against Y-coordinate std
    ax = axes[0, 2]
    scatter = ax.scatter(coords_std[:, 0], coords_std[:, 1],
                        c=total_std, cmap='plasma', s=30, alpha=0.7)
    plt.colorbar(scatter, ax=ax, label='Total Std')
    ax.set_xlabel('X Coordinate Std')
    ax.set_ylabel('Y Coordinate Std')
    ax.set_title('Coordinate Uncertainty (X vs Y)', fontsize=14)

    # 4. Kernel density estimate evaluated at every averaged position
    ax = axes[1, 0]
    xy = coords_avg.T
    density = gaussian_kde(xy)(xy)
    scatter = ax.scatter(coords_avg[:, 0], coords_avg[:, 1],
                        c=density, cmap='hot', s=30, alpha=0.7)
    plt.colorbar(scatter, ax=ax, label='Density')
    ax.set_title('Cell Density (Averaged Coordinates)', fontsize=14)

    # 5. Confidence scores
    ax = axes[1, 1]
    scatter = ax.scatter(coords_avg[:, 0], coords_avg[:, 1],
                        c=confidence, cmap='RdYlGn', s=30, alpha=0.7)
    plt.colorbar(scatter, ax=ax, label='Confidence')
    ax.set_title('Prediction Confidence Across Models', fontsize=14)

    # 6. Agreement classes: below the 25th percentile of total_std is "high",
    #    above the 75th percentile is "low", everything else "medium"
    ax = axes[1, 2]
    high_agreement = total_std < np.percentile(total_std, 25)
    low_agreement = total_std > np.percentile(total_std, 75)
    medium_agreement = ~(high_agreement | low_agreement)

    ax.scatter(coords_avg[high_agreement, 0], coords_avg[high_agreement, 1],
              c='green', s=20, alpha=0.7, label='High Agreement')
    ax.scatter(coords_avg[low_agreement, 0], coords_avg[low_agreement, 1],
              c='red', s=20, alpha=0.7, label='Low Agreement')
    ax.scatter(coords_avg[medium_agreement, 0], coords_avg[medium_agreement, 1],
              c='gray', s=10, alpha=0.5, label='Medium Agreement')
    ax.set_title('Model Agreement', fontsize=14)
    ax.legend()

    plt.tight_layout()
    plt.show()

    return fig, total_std, confidence

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist
from sklearn.preprocessing import StandardScaler

def procrustes_alignment_2d(coords_to_align, reference_coords):
    """
    Rigidly align one 2D point set onto another (rotation + translation, no scaling).

    Points are matched by row: row i of ``coords_to_align`` corresponds to row i
    of ``reference_coords``.

    Args:
        coords_to_align: (n_points, 2) coordinates to be aligned
        reference_coords: (n_points, 2) reference coordinates

    Returns:
        aligned_coords: (n_points, 2) aligned coordinates
        rotation_matrix: (2, 2) proper rotation (determinant +1)
        translation: (2,) offset added to the centered-and-rotated points,
            i.e. ``aligned = (coords - mean(coords)) @ R.T + translation``
    """
    assert coords_to_align.shape == reference_coords.shape, "Coordinate shapes must match"
    assert coords_to_align.shape[1] == 2, "Only 2D coordinates supported"

    # Remove the centroid from both point sets
    ref_centroid = np.mean(reference_coords, axis=0)
    moving_centered = coords_to_align - np.mean(coords_to_align, axis=0)
    ref_centered = reference_coords - ref_centroid

    # SVD of the 2x2 cross-covariance yields the optimal orthogonal transform
    cross_cov = moving_centered.T @ ref_centered
    U, _singular_values, Vt = np.linalg.svd(cross_cov)
    R = Vt.T @ U.T

    # A negative determinant means a reflection; flip the last singular
    # direction to obtain a proper rotation instead
    if np.linalg.det(R) < 0:
        Vt[-1, :] = -Vt[-1, :]
        R = Vt.T @ U.T

    rotated = moving_centered @ R.T

    # Move the rotated points so their centroid coincides with the reference centroid
    translation = ref_centroid - np.mean(rotated, axis=0)
    aligned_coords = rotated + translation

    return aligned_coords, R, translation

def plot_alignment_comparison(coords_list, cell_types, labels, title_prefix=""):
    """
    Plot several coordinate sets side by side plus one overlay of all of them.

    The top row holds one panel per coordinate set, colored by cell type. The
    first panel of the bottom row overlays every set (one color per set); the
    remaining bottom-row panels are hidden.

    Args:
        coords_list: List of coordinate arrays [(n_points, 2), ...]
        cell_types: Array of cell type labels
        labels: List of labels for each coordinate set
        title_prefix: Prefix for plot titles
    """
    n_models = len(coords_list)
    fig, axes = plt.subplots(2, n_models, figsize=(5*n_models, 10))

    # With a single model, subplots returns a 1D array; restore the 2D layout
    if n_models == 1:
        axes = axes.reshape(2, 1)

    # One tab20 color per cell type
    unique_types = np.unique(cell_types)
    palette = plt.cm.tab20(np.linspace(0, 1, len(unique_types)))
    color_map = dict(zip(unique_types, palette))

    def _set_titles(ax, title):
        # Title and axis labels shared by every panel.
        ax.set_title(title)
        ax.set_xlabel('X coordinate')
        ax.set_ylabel('Y coordinate')

    # Top row: each model on its own, colored by cell type
    for column, (coords, label) in enumerate(zip(coords_list, labels)):
        panel = axes[0, column]
        for cell_type in unique_types:
            selected = cell_types == cell_type
            panel.scatter(coords[selected, 0], coords[selected, 1],
                          c=[color_map[cell_type]], label=cell_type,
                          alpha=0.6, s=20)
        _set_titles(panel, f'{title_prefix}{label}')
        panel.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        panel.grid(True, alpha=0.3)

    # Bottom-left: all models in one panel
    for coords, label in zip(coords_list, labels):
        axes[1, 0].scatter(coords[:, 0], coords[:, 1],
                           label=label, alpha=0.4, s=15)
    overlay = axes[1, 0]
    _set_titles(overlay, f'{title_prefix}All Models Overlay')
    overlay.legend()
    overlay.grid(True, alpha=0.3)

    # The rest of the bottom row is unused
    for column in range(1, n_models):
        axes[1, column].set_visible(False)

    plt.tight_layout()
    plt.show()

def train_individual_advanced_diffusion_models_with_alignment(scadata, stadata1, stadata2, stadata3):
    """
    Train one AdvancedHierarchicalDiffusion model per ST dataset, align the
    predicted SC coordinates with Procrustes, and average them.

    Steps:
        1. For each ST dataset, restrict SC and ST expression to their shared
           genes, train a model and sample coordinates for the SC cells.
        2. Rigidly align every model's coordinates to those of model 2
           (rotation + translation, see ``procrustes_alignment_2d``).
        3. Average the aligned coordinate sets and store everything in
           ``scadata.obsm``.

    Args:
        scadata: SC AnnData; must provide ``obs['rough_celltype']``.
        stadata1, stadata2, stadata3: ST AnnData objects with ``obsm['spatial']``.

    Returns:
        scadata: the same object, with ``obsm['advanced_diffusion_coords_avg']``,
            ``obsm['advanced_diffusion_coords_avg_aligned']`` (the same array)
            and, per model k, ``obsm['advanced_diffusion_coords_rep{k}_original']``
            and ``obsm['advanced_diffusion_coords_rep{k}_aligned']`` filled in.
        models_all: list of the trained models, in dataset order.
        alignment_info: dict with 'reference_model', 'rotations',
            'translations', 'rmse_before' and 'rmse_after' (one entry per model).

    Raises:
        AssertionError: if the models return coordinate arrays of different shapes.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Per-model outputs, in dataset order
    sc_coords_results = []
    models_all = []

    st_datasets = [
        (stadata1, "dataset1"),
        (stadata2, "dataset2"),
        (stadata3, "dataset3")
    ]
    n_models = len(st_datasets)

    cell_types = scadata.obs['rough_celltype'].values

    for i, (stadata, dataset_name) in enumerate(st_datasets):
        print(f"\n{'='*50}")
        print(f"Training AdvancedHierarchicalDiffusion model {i+1}/{n_models} for {dataset_name}")
        print(f"{'='*50}")

        # Genes shared by the SC data and this ST dataset (the set may differ per dataset)
        sc_genes = set(scadata.var_names)
        st_genes = set(stadata.var_names)
        common_genes = sorted(list(sc_genes & st_genes))

        print(f"Common genes for {dataset_name}: {len(common_genes)}")

        # Expression restricted to the shared genes, densified when sparse
        sc_expr = scadata[:, common_genes].X
        st_expr = stadata[:, common_genes].X
        if hasattr(sc_expr, 'toarray'):
            sc_expr = sc_expr.toarray()
        if hasattr(st_expr, 'toarray'):
            st_expr = st_expr.toarray()

        st_coords = stadata.obsm['spatial']

        print(f"SC data shape: {sc_expr.shape}")
        print(f"ST data shape: {st_expr.shape}")
        print(f"ST coords shape: {st_coords.shape}")

        model = AdvancedHierarchicalDiffusion(
            st_gene_expr=st_expr,
            st_coords=st_coords,
            sc_gene_expr=sc_expr,
            cell_types_sc=cell_types,
            transport_plan=None,
            D_st=None,
            D_induced=None,
            n_genes=len(common_genes),
            n_embedding=[512, 256, 128],
            coord_space_diameter=2.00,
            sigma=3.0,
            alpha=0.8,
            mmdbatch=1000,
            batch_size=256,
            device=device,
            lr_e=0.0001,
            lr_d=0.0002,
            n_timesteps=600,
            n_denoising_blocks=4,
            hidden_dim=256,
            num_heads=8,
            num_hierarchical_scales=3,
            dp=0.2,
            outf=f'advanced_diffusion_{dataset_name}'
        )

        print(f"Training model for {dataset_name}...")

        model.train(
            encoder_epochs=1000,
            vae_epochs=1000,
            diffusion_epochs=2500,
            lambda_struct=2.0
        )

        model.plot_training_losses()

        print(f"Generating SC coordinates using model {i+1}...")
        # NOTE(review): the result is used below as a numpy array of shape
        # (n_cells, 2) (``.copy()``, Procrustes) -- confirm the sampler's return type.
        sc_coords = model.sample_sc_coordinates_batched(
            batch_size=512
        )

        sc_coords_results.append(sc_coords)
        models_all.append(model)

        print(f"Model {i+1} complete! Generated coordinates shape: {sc_coords.shape}")

        # The model stays referenced from models_all (it is returned to the
        # caller), so its own tensors are not freed here; this only releases
        # cached blocks that are no longer in use.
        torch.cuda.empty_cache()

    # All models must describe the same cells before they can be aligned or
    # averaged; check this up front so a mismatch is reported clearly.
    shapes = [coords.shape for coords in sc_coords_results]
    assert all(shape == shapes[0] for shape in shapes), f"Shape mismatch: {shapes}"

    # ========================
    # PROCRUSTES ALIGNMENT
    # ========================
    print(f"\n{'='*50}")
    print("APPLYING PROCRUSTES ALIGNMENT")
    print(f"{'='*50}")

    # The second model (index 1) is the reference everything else is aligned to
    reference_idx = 1
    reference_coords = sc_coords_results[reference_idx].copy()

    print(f"Using model {reference_idx + 1} (dataset{reference_idx + 1}) as reference")

    print("Plotting coordinates BEFORE alignment...")
    plot_alignment_comparison(
        coords_list=sc_coords_results.copy(),
        cell_types=cell_types,
        labels=[f"Model {i+1}" for i in range(len(sc_coords_results))],
        title_prefix="BEFORE Alignment - "
    )

    aligned_coords_results = []
    alignment_info = {
        'reference_model': reference_idx + 1,
        'rotations': [],
        'translations': [],
        'rmse_before': [],
        'rmse_after': []
    }

    for i, coords in enumerate(sc_coords_results):
        if i == reference_idx:
            # The reference is kept as is: identity rotation, zero translation
            aligned_coords = coords.copy()
            alignment_info['rotations'].append(np.eye(2))
            alignment_info['translations'].append(np.zeros(2))
            rmse_before = 0.0
            rmse_after = 0.0
        else:
            # RMSE over all coordinate entries, before and after alignment
            rmse_before = np.sqrt(np.mean((coords - reference_coords)**2))

            aligned_coords, rotation_matrix, translation = procrustes_alignment_2d(
                coords, reference_coords
            )

            rmse_after = np.sqrt(np.mean((aligned_coords - reference_coords)**2))

            alignment_info['rotations'].append(rotation_matrix)
            alignment_info['translations'].append(translation)

            print(f"Model {i+1} -> Reference alignment:")
            print(f"  RMSE before: {rmse_before:.6f}")
            print(f"  RMSE after:  {rmse_after:.6f}")
            print(f"  Improvement: {rmse_before - rmse_after:.6f}")

        alignment_info['rmse_before'].append(rmse_before)
        alignment_info['rmse_after'].append(rmse_after)
        aligned_coords_results.append(aligned_coords)

    print("Plotting coordinates AFTER alignment...")
    plot_alignment_comparison(
        coords_list=aligned_coords_results.copy(),
        cell_types=cell_types,
        labels=[f"Model {i+1}" for i in range(len(aligned_coords_results))],
        title_prefix="AFTER Alignment - "
    )

    # ========================
    # AVERAGING
    # ========================
    print(f"\n{'='*50}")
    print("AVERAGING ALIGNED RESULTS")
    print(f"{'='*50}")

    # Element-wise mean over the aligned models
    sc_coords_avg = np.mean(aligned_coords_results, axis=0)

    print(f"Final averaged coordinates shape: {sc_coords_avg.shape}")

    # Store the average under both keys (the second name only makes the
    # alignment explicit; it is the same array)
    scadata.obsm['advanced_diffusion_coords_avg'] = sc_coords_avg
    scadata.obsm['advanced_diffusion_coords_avg_aligned'] = sc_coords_avg

    # Keep every model's raw and aligned coordinates as well
    for i, (orig_coords, aligned_coords) in enumerate(zip(sc_coords_results, aligned_coords_results)):
        scadata.obsm[f'advanced_diffusion_coords_rep{i+1}_original'] = orig_coords
        scadata.obsm[f'advanced_diffusion_coords_rep{i+1}_aligned'] = aligned_coords

    print("Plotting final averaged coordinates...")
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # One tab20 color per cell type
    unique_types = np.unique(cell_types)
    colors = plt.cm.tab20(np.linspace(0, 1, len(unique_types)))
    color_map = dict(zip(unique_types, colors))

    # Left: overlay of the aligned models
    for i, coords in enumerate(aligned_coords_results):
        ax1.scatter(coords[:, 0], coords[:, 1],
                   label=f"Model {i+1}", alpha=0.4, s=15)
    ax1.set_title('Aligned Models (Before Averaging)')
    ax1.set_xlabel('X coordinate')
    ax1.set_ylabel('Y coordinate')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Right: averaged coordinates colored by cell type
    for cell_type in unique_types:
        mask = cell_types == cell_type
        ax2.scatter(sc_coords_avg[mask, 0], sc_coords_avg[mask, 1],
                   c=[color_map[cell_type]], label=cell_type,
                   alpha=0.6, s=20)
    ax2.set_title('Final Averaged Coordinates (by Cell Type)')
    ax2.set_xlabel('X coordinate')
    ax2.set_ylabel('Y coordinate')
    ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print(f"\nAdvanced diffusion training with alignment complete!")
    print(f"Results saved in scadata.obsm['advanced_diffusion_coords_avg']")
    print(f"Individual results saved with '_original' and '_aligned' suffixes")

    return scadata, models_all, alignment_info

In [None]:
# Train one model per ST replicate, Procrustes-align the predicted SC
# coordinates and store the averaged and per-model results in scadata.obsm.
scadata, models_all, alignment_info = train_individual_advanced_diffusion_models_with_alignment(
    scadata, stadata1, stadata2, stadata3
)

In [None]:
def train_individual_advanced_diffusion_models(scadata, stadata1, stadata2, stadata3):
    """
    Train separate AdvancedHierarchicalDiffusion models for each ST dataset and
    average the predicted SC coordinates (no alignment between models).

    Args:
        scadata: SC AnnData; must provide ``obs['rough_celltype']``.
        stadata1, stadata2, stadata3: ST AnnData objects with ``obsm['spatial']``.

    Returns:
        scadata: Updated with averaged coordinates in obsm['advanced_diffusion_coords_avg']
            and per-model coordinates in obsm['advanced_diffusion_coords_rep{k}']
        models_all: All trained models for further analysis

    Raises:
        AssertionError: if the models return coordinate arrays of different shapes.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Per-model outputs, in dataset order
    sc_coords_results = []
    models_all = []

    st_datasets = [
        (stadata1, "dataset1"),
        (stadata2, "dataset2"),
        (stadata3, "dataset3")
    ]
    n_models = len(st_datasets)

    for i, (stadata, dataset_name) in enumerate(st_datasets):
        print(f"\n{'='*50}")
        print(f"Training AdvancedHierarchicalDiffusion model {i+1}/{n_models} for {dataset_name}")
        print(f"{'='*50}")

        # Genes shared by the SC data and this ST dataset (the set may differ per dataset)
        sc_genes = set(scadata.var_names)
        st_genes = set(stadata.var_names)
        common_genes = sorted(list(sc_genes & st_genes))

        print(f"Common genes for {dataset_name}: {len(common_genes)}")

        # Expression restricted to the shared genes, densified when sparse
        sc_expr = scadata[:, common_genes].X
        st_expr = stadata[:, common_genes].X
        if hasattr(sc_expr, 'toarray'):
            sc_expr = sc_expr.toarray()
        if hasattr(st_expr, 'toarray'):
            st_expr = st_expr.toarray()

        st_coords = stadata.obsm['spatial']

        print(f"SC data shape: {sc_expr.shape}")
        print(f"ST data shape: {st_expr.shape}")
        print(f"ST coords shape: {st_coords.shape}")

        model = AdvancedHierarchicalDiffusion(
            st_gene_expr=st_expr,
            st_coords=st_coords,
            sc_gene_expr=sc_expr,
            cell_types_sc=scadata.obs['rough_celltype'].values,  # coarse SC cell-type labels
            transport_plan=None,  # no precomputed transport plan is supplied
            D_st=None,            # no precomputed distance matrices are supplied
            D_induced=None,
            n_genes=len(common_genes),
            n_embedding=[512, 256, 128],
            coord_space_diameter=2.00,
            sigma=3.0,
            alpha=0.8,
            mmdbatch=1000,
            batch_size=256,
            device=device,
            lr_e=0.0001,
            lr_d=0.0002,
            n_timesteps=800,
            n_denoising_blocks=4,
            hidden_dim=256,
            num_heads=8,
            num_hierarchical_scales=3,
            dp=0.2,
            outf=f'advanced_diffusion_{dataset_name}'
        )

        print(f"Training model for {dataset_name}...")

        # Epoch budget for the encoder, VAE and diffusion phases, plus the
        # weight passed as lambda_struct
        model.train(
            encoder_epochs=800,
            vae_epochs=1500,
            diffusion_epochs=2500,
            lambda_struct=2.0
        )

        print(f"Generating SC coordinates using model {i+1}...")
        # Sampling is done in batches of 512
        sc_coords = model.sample_sc_coordinates_batched(
            batch_size=512
        )

        sc_coords_results.append(sc_coords)
        models_all.append(model)

        print(f"Model {i+1} complete! Generated coordinates shape: {sc_coords.shape}")

        # The model stays referenced from models_all (it is returned to the
        # caller), so its own tensors are not freed here; this only releases
        # cached blocks that are no longer in use.
        torch.cuda.empty_cache()

    # All models must describe the same cells before averaging; check this
    # first so a mismatch is reported clearly instead of failing inside np.mean.
    shapes = [coords.shape for coords in sc_coords_results]
    assert all(shape == shapes[0] for shape in shapes), f"Shape mismatch: {shapes}"

    # Element-wise mean over the models
    print(f"\nAveraging results from {len(sc_coords_results)} models...")
    sc_coords_avg = np.mean(sc_coords_results, axis=0)

    print(f"Final averaged coordinates shape: {sc_coords_avg.shape}")

    scadata.obsm['advanced_diffusion_coords_avg'] = sc_coords_avg

    # Keep every model's own coordinates as well
    for i, coords in enumerate(sc_coords_results):
        scadata.obsm[f'advanced_diffusion_coords_rep{i+1}'] = coords

    print(f"\nAdvanced diffusion training complete!")
    print(f"Results saved in scadata.obsm['advanced_diffusion_coords_avg']")

    return scadata, models_all

# Load the datasets, then train one model per ST dataset and average the
# predicted SC coordinates (stored in scadata.obsm).
scadata, stadata1, stadata2, stadata3 = load_and_process_cscc_data()
scadata, advanced_models = train_individual_advanced_diffusion_models(
    scadata, stadata1, stadata2, stadata3
)

print("Advanced diffusion training complete! Results saved in scadata.obsm['advanced_diffusion_coords_avg']")

# Plotting setup
import matplotlib.pyplot as plt
import seaborn as sns

my_tab20 = sns.color_palette("tab20", n_colors=20).as_hex()

# (figure size, obsm basis, title) for the averaged result followed by the
# three individual models
_embedding_plots = [
    ((8, 6), 'advanced_diffusion_coords_avg', 'SC Advanced Diffusion Coords (Averaged)'),
]
for i in range(3):
    _embedding_plots.append(
        ((6, 5), f'advanced_diffusion_coords_rep{i+1}', f'SC Coordinates (Advanced Model {i+1})')
    )

for _figsize, _basis, _title in _embedding_plots:
    plt.figure(figsize=_figsize)
    sc.pl.embedding(scadata, basis=_basis, color='rough_celltype',
                    size=85, title=_title,
                    palette=my_tab20, legend_loc='right margin', legend_fontsize=10)
    plt.show()

In [None]:
# Inspect the averaged SC coordinates stored by the training step.
scadata.obsm['advanced_diffusion_coords_avg']

In [None]:
# Inspect the averaged SC coordinates again (same expression as the previous cell).
scadata.obsm['advanced_diffusion_coords_avg']

In [None]:
# Separate figures for the averaged and the per-model embeddings.
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

# 20 hex colours taken from the "tab20" palette.
my_tab20 = sns.color_palette("tab20", n_colors=20).as_hex()

# Default figure size for everything plotted after this point.
mpl.rcParams['figure.figsize'] = (6, 6)
# import scanpy as sc
# sc.settings.set_figure_params(figsize=(4,4), dpi=100)

# Plot 1: averaged coordinates.
plt.figure(figsize=(4, 4))
sc.pl.embedding(
    scadata,
    basis='advanced_diffusion_coords_avg',
    color='rough_celltype',
    size=85,
    title='SC Spatial Coordinates (Averaged from 3 Models)',
    palette=my_tab20,
    legend_loc='right margin',
    legend_fontsize=10,
)
plt.show()

# Plot 2: Model 1 results, one figure per stored variant
# ('..._rep1_original' and '..._rep1_aligned').
# The variant is part of the title so the two figures can be told apart;
# both were previously titled just 'SC Coordinates (Model 1)'.
for variant in ('original', 'aligned'):
    plt.figure(figsize=(4, 4))
    sc.pl.embedding(scadata, basis=f'advanced_diffusion_coords_rep1_{variant}', color='rough_celltype',
                   size=85, title=f'SC Coordinates (Model 1, {variant})',
                   palette=my_tab20, legend_loc='right margin', legend_fontsize=10)
    plt.show()

# Plot 3: Model 2 results, one figure per stored variant
# ('..._rep2_original' and '..._rep2_aligned').
# The aligned figure used to be titled 'Model 1' (copy-paste slip); the title
# now names the right model and the variant being shown.
for variant in ('original', 'aligned'):
    plt.figure(figsize=(4, 4))
    sc.pl.embedding(scadata, basis=f'advanced_diffusion_coords_rep2_{variant}', color='rough_celltype',
                   size=85, title=f'SC Coordinates (Model 2, {variant})',
                   palette=my_tab20, legend_loc='right margin', legend_fontsize=10)
    plt.show()

import seaborn as sns
# NOTE(review): the palette is rebuilt here with 12 colours instead of the 20
# used for the earlier figures in this cell - confirm this is intentional.
my_tab20 = sns.color_palette("tab20", n_colors=12).as_hex()

# Plot 4: Model 3 results, one figure per stored variant
# ('..._rep3_original' and '..._rep3_aligned').
# The aligned figure used to be titled 'Model 1' (copy-paste slip); the title
# now names the right model and the variant being shown.
for variant in ('original', 'aligned'):
    plt.figure(figsize=(4, 4))
    sc.pl.embedding(scadata, basis=f'advanced_diffusion_coords_rep3_{variant}', color='rough_celltype',
                   size=85, title=f'SC Coordinates (Model 3, {variant})',
                   palette=my_tab20, legend_loc='right margin', legend_fontsize=10)
    plt.show()

In [None]:
# Requires scadata.obsm['advanced_diffusion_coords_avg'] and the per-model
# 'advanced_diffusion_coords_rep{1,2,3}' entries to be present.

print("\n=== Advanced Analysis and Visualization ===")

mpl.rcParams['figure.figsize'] = (8, 6)


# 1. Visualize advanced results with uncertainty analysis
print("Creating advanced visualization plots...")
fig, model_uncertainty, confidence_scores = visualize_advanced_results_multi_model(scadata)

# 2. Analyze cell interactions for averaged coordinates
print("Analyzing cell interactions (averaged coordinates)...")
min_distances_avg, interaction_matrix_avg = analyze_cell_interactions_advanced(
    scadata, coords_key='advanced_diffusion_coords_avg'
)

# 3. Optional: analyze interactions for the individual models too.
# Each model's result is kept in `interactions_by_model` (keyed by model
# number); previously every iteration overwrote the last one, so only model
# 3's result was still available afterwards.
print("Analyzing cell interactions (individual models)...")
interactions_by_model = {}
for i in range(1, 4):
    print(f"\nModel {i} interactions:")
    min_distances, interaction_matrix = analyze_cell_interactions_advanced(
        scadata, coords_key=f'advanced_diffusion_coords_rep{i}'
    )
    interactions_by_model[i] = (min_distances, interaction_matrix)

# 4. Summary statistics over all mapped cells
n_cells_mapped = len(scadata.obsm['advanced_diffusion_coords_avg'])
uncertainty_lo = model_uncertainty.min()
uncertainty_hi = model_uncertainty.max()

print("\n=== Advanced Model Statistics ===")
print(f"Total cells mapped: {n_cells_mapped}")
print(f"Average model uncertainty: {model_uncertainty.mean():.4f}")
print(f"Model uncertainty range: [{uncertainty_lo:.4f}, {uncertainty_hi:.4f}]")
print(f"Average confidence: {confidence_scores.mean():.4f}")

# Per cell type: cell count plus mean confidence / uncertainty
print("\n=== Cell Type Confidence ===")
celltype_labels = scadata.obs['rough_celltype']
for ct in celltype_labels.unique():
    mask = celltype_labels == ct
    ct_confidence = confidence_scores[mask].mean()
    ct_uncertainty = model_uncertainty[mask].mean()
    print(f"{ct}: {mask.sum()} cells, avg confidence: {ct_confidence:.3f}, avg uncertainty: {ct_uncertainty:.4f}")

print("\n=== Physics Constraints (Averaged Coordinates) ===")
# Collect the finite minimum distances; NaN entries are skipped.
all_distances = []
for key, dist in min_distances_avg.items():
    if np.isnan(dist):
        continue
    all_distances.append(dist)
    print(f"Min distance {key[0]} - {key[1]}: {dist:.4f}")

if all_distances:
    print(f"\nOverall minimum cell-cell distance: {np.min(all_distances):.4f}")
    # Each entry of all_distances belongs to one (key[0], key[1]) pair, so this
    # counts pairs, not individual cells; the label now says so.
    print(f"Pairs with potential overlaps (< 0.01): {np.sum(np.array(all_distances) < 0.01)}")

In [None]:
# Plot the averaged embedding followed by one embedding per model.
import matplotlib.pyplot as plt
import seaborn as sns
mpl.rcParams['figure.figsize'] = (6, 6)

my_tab20 = sns.color_palette("tab20", n_colors=20).as_hex()

# Figure 1: averaged coordinates.
plt.figure(figsize=(8, 6))
sc.pl.embedding(
    scadata,
    basis='advanced_diffusion_coords_avg',
    color='rough_celltype',
    size=85,
    title='SC Advanced Diffusion Coords (Averaged)',
    palette=my_tab20,
    legend_loc='right margin',
    legend_fontsize=10,
)
plt.show()

# Remaining figures: rep1..rep3, one per individual model.
for i in range(3):
    model_no = i + 1
    plt.figure(figsize=(6, 5))
    sc.pl.embedding(
        scadata,
        basis=f'advanced_diffusion_coords_rep{model_no}',
        color='rough_celltype',
        size=85,
        title=f'SC Coordinates (Advanced Model {model_no})',
        palette=my_tab20,
        legend_loc='right margin',
        legend_fontsize=10,
    )
    plt.show()

In [None]:
def normalize_coordinates_isotropic(coords):
    '''Center coordinates and scale them isotropically into the unit disk.

    The centroid is subtracted and every axis is divided by the same factor
    (the largest distance of any point from the centroid), so the aspect
    ratio of the point cloud is preserved.

    args:
        coords: (n_points, n_dims) torch tensor or numpy array

    Returns:
        coords_norm: centered coordinates divided by max_radius
        center: per-dimension mean of the input
        max_radius: largest centroid distance of the input

    If all points coincide (max_radius == 0) the centered coordinates, which
    are all zero, are returned unscaled instead of dividing by zero and
    producing NaN. max_radius is still reported as 0 in that case.
    '''
    if torch.is_tensor(coords):
        center = coords.mean(dim=0)
        centered = coords - center
        max_radius = torch.max(torch.norm(centered, dim=1))
    else:
        center = coords.mean(axis=0)
        centered = coords - center
        max_radius = np.max(np.linalg.norm(centered, axis=1))

    # Degenerate cloud: nothing to scale, and 0 / 0 would yield NaN.
    if max_radius == 0:
        return centered, center, max_radius

    return centered / max_radius, center, max_radius

# --- Validation experiment: fit on ST1 + ST2, evaluate on ST3 ---
scadata_val, stadata1_val, stadata2_val, stadata3_val = load_and_process_cscc_data_individual_norm()

# Ground truth for the held-out slide: ST3 spot coordinates after isotropic
# normalization (only the normalized coordinates are kept).
st3_coords_gt = stadata3_val.obsm['spatial']
st3_coords_gt_norm = normalize_coordinates_isotropic(st3_coords_gt)[0]

print("=== VALIDATION EXPERIMENT ===")
print("Training diffusion models on ST1+ST2, testing on ST3...")

# (AnnData, label) pairs for the two training slides.
st_datasets_train = list(zip(
    (stadata1_val, stadata2_val),
    ("dataset1", "dataset2"),
))


# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Result containers
T_all, D_induced_all, D_st_all, D_sc_all = [], [], [], []
trained_models = []

# Train one model per training slide (ST1 and ST2).
#
# Every model is fitted on genes shared by the SC data and by ALL three ST
# slides. The test section further down rebuilds each model with exactly that
# four-way intersection (sorted) before loading the trained weights, so a
# per-slide gene list here could leave the two feature spaces out of step
# (different size or column order).
genes_in_every_slide = (
    set(stadata1_val.var_names)
    & set(stadata2_val.var_names)
    & set(stadata3_val.var_names)
)

for i, (stadata, dataset_name) in enumerate(st_datasets_train):
    print(f"\n{'='*50}")
    print(f"Training Advanced Diffusion model {i+1}/{len(st_datasets_train)} for {dataset_name} using SpaOTsc")
    print(f"{'='*50}")

    # Genes present in the SC data, in the current slide and in every slide.
    sc_genes = set(scadata_val.var_names)
    st_genes = set(stadata.var_names)
    common_genes = sorted(sc_genes & st_genes & genes_in_every_slide)

    print(f"Common genes for {dataset_name}: {len(common_genes)}")

    # Extract expression data
    sc_expr = scadata_val[:, common_genes].X
    st_expr = stadata[:, common_genes].X

    # Convert to dense if sparse
    if hasattr(sc_expr, 'toarray'):
        sc_expr = sc_expr.toarray()
    if hasattr(st_expr, 'toarray'):
        st_expr = st_expr.toarray()

    # Get coordinates
    st_coords = stadata.obsm['spatial']

    # Tensor copies on the target device. The model below receives the
    # arrays directly; these tensors are only used for the shape report.
    X_sc = torch.tensor(sc_expr, dtype=torch.float32).to(device)
    X_st = torch.tensor(st_expr, dtype=torch.float32).to(device)
    Y_st = torch.tensor(st_coords, dtype=torch.float32).to(device)

    print(f"SC data shape: {X_sc.shape}")
    print(f"ST data shape: {X_st.shape}")
    print(f"ST coords shape: {Y_st.shape}")

    # NOTE(review): no transport plan is computed in this loop and
    # transport_plan=None is passed below - confirm this log line (and the
    # "using SpaOTsc" banner above) still describe what happens.
    print(f"Running optimal transport for {dataset_name}...")

    # === PREPARE CELL TYPE INFORMATION ===
    # Labels come from scadata_val, the same AnnData that supplies sc_expr,
    # so expression rows and labels stay aligned.
    if 'rough_celltype' in scadata_val.obs.columns:
        cell_types_sc = scadata_val.obs['rough_celltype'].values
        unique_cell_types = np.unique(cell_types_sc)
        print(f"Found {len(unique_cell_types)} unique cell types: {unique_cell_types}")
    else:
        cell_types_sc = None
        print("No cell type information found")

    # === INITIALIZE ADVANCED HIERARCHICAL DIFFUSION MODEL ===
    print(f"Initializing AdvancedHierarchicalDiffusion for {dataset_name}...")

    # NOTE(review): output_dir is never used; `outf` below gets a different
    # path. Confirm which location is intended for the validation outputs.
    output_dir = f'./cscc_advanced_diffusion_{dataset_name}_validation'

    model = AdvancedHierarchicalDiffusion(
        st_gene_expr=st_expr,
        st_coords=st_coords,
        sc_gene_expr=sc_expr,
        # Previously read from `scadata` (a different AnnData left over from
        # an earlier cell) instead of the labels prepared above.
        cell_types_sc=cell_types_sc,
        transport_plan=None,  # No OT transport plan
        D_st=None,           # No distance matrices
        D_induced=None,
        n_genes=len(common_genes),
        n_embedding=[512, 256, 128],  # Same as STEMDiffusion
        coord_space_diameter=2.00,
        sigma=3.0,
        alpha=0.8,
        mmdbatch=1000,
        batch_size=256,
        device=device,
        lr_e=0.0001,
        lr_d=0.0002,
        n_timesteps=800,     # Same as STEMDiffusion
        n_denoising_blocks=6,
        hidden_dim=256,      # Same as STEMDiffusion
        num_heads=8,
        num_hierarchical_scales=3,
        dp=0.2,
        outf=f'advanced_diffusion_{dataset_name}'
    )

    # Train the model
    print(f"Training model for {dataset_name}...")

    model.train(
        encoder_epochs=1000,  # Stage 1: Domain alignment encoder
        vae_epochs=1000,       # Stage 2: Graph-VAE training
        diffusion_epochs=2500, # Stage 3: Latent diffusion
        lambda_struct=5.0     # Structure loss weight
    )

    # Store the trained model
    trained_models.append(model)
    print(f"Model {i+1} training completed!")

print(f"\nTraining completed! {len(trained_models)} models trained.")

# Test on the third dataset by creating new model instances
print(f"\n{'='*50}")
print("Testing on ST3 dataset...")
print(f"{'='*50}")

# Genes present in the SC data and in all three ST slides.
sc_genes = set(scadata_val.var_names)
st1_genes = set(stadata1_val.var_names)
st2_genes = set(stadata2_val.var_names)
st3_genes = set(stadata3_val.var_names)
common_genes_test = sorted(list(sc_genes & st1_genes & st2_genes & st3_genes))

print(f"Common genes for testing: {len(common_genes_test)}")

# ST3 expression restricted to the shared genes (densified if sparse).
st3_expr = stadata3_val[:, common_genes_test].X
if hasattr(st3_expr, 'toarray'):
    st3_expr = st3_expr.toarray()

# The test models still need an ST reference (expression + coordinates);
# ST1 is used for that role.
st1_coords = stadata1_val.obsm['spatial']
st1_expr_ref = stadata1_val[:, common_genes_test].X
if hasattr(st1_expr_ref, 'toarray'):
    st1_expr_ref = st1_expr_ref.toarray()

# Convert to tensors
X_st1_ref = torch.tensor(st1_expr_ref, dtype=torch.float32).to(device)
X_st3_test = torch.tensor(st3_expr, dtype=torch.float32).to(device)
Y_st1_ref = torch.tensor(st1_coords, dtype=torch.float32).to(device)

# Get cell types
cell_types_sc = scadata_val.obs['rough_celltype'].values if 'rough_celltype' in scadata_val.obs.columns else None

all_predictions = []

# Cell-type vocabulary of the SC data
original_cell_types = scadata_val.obs['rough_celltype'].values
unique_cell_types = np.unique(original_cell_types)
num_original_cell_types = len(unique_cell_types)

print(f"Original model has {num_original_cell_types} cell types")

# Placeholder labels for the ST3 spots, drawn from the SC vocabulary.
# A fixed seed keeps the draw identical across runs, so repeated executions
# of the validation feed the test models the same labels.
dummy_label_rng = np.random.default_rng(0)
dummy_cell_types = dummy_label_rng.choice(unique_cell_types, size=X_st3_test.shape[0])

# For each trained model, build a fresh (untrained) instance whose "SC" input
# is the ST3 expression matrix, copy the trained weights into it, and sample
# coordinates for the ST3 spots.
# NOTE(review): ST1 is the reference for every test model, including the one
# whose weights were trained on ST2 - confirm this is intended.
for i, trained_model in enumerate(trained_models):
    print(f"Creating test model based on trained model {i+1}...")

    # Create a minimal model instance for testing (no training)
    test_model = AdvancedHierarchicalDiffusion(
        st_gene_expr=X_st1_ref.cpu().numpy(),  # Use ST1 as reference
        st_coords=Y_st1_ref.cpu().numpy(),     # Use ST1 coords as reference
        sc_gene_expr=X_st3_test.cpu().numpy(), # ST3 data as "SC" data
        cell_types_sc=dummy_cell_types,        # placeholder labels for the ST3 spots
        transport_plan=None,                   # no transport plan supplied
        D_st=None,                             # no distance matrices supplied
        D_induced=None,
        n_genes=len(common_genes_test),
        n_embedding=[512, 256, 128],
        coord_space_diameter=2.00,
        sigma=3.0,
        alpha=0.8,
        mmdbatch=1000,
        batch_size=256,
        device=device,
        lr_e=0.0001,
        lr_d=0.0002,
        n_timesteps=800,
        n_denoising_blocks=6,
        hidden_dim=256,
        num_heads=8,
        num_hierarchical_scales=3,
        dp=0.2,
        outf=f'./temp_test_model_{i}'
    )

    # Copy trained parameters (this is a hack, but should work).
    # The state dict is fetched once and reused; it was previously fetched
    # twice with the first copy left unused.
    # NOTE(review): with strict=False, entries that do not match are
    # presumably skipped without an error - confirm nothing important is
    # left at its initial value.
    state_dict = trained_model.state_dict()
    test_model.load_state_dict(state_dict, strict=False)

    # Sample coordinates; the ST3 spots are this model's "SC" data.
    print(f"Generating predictions from test model {i+1}...")

    predicted_coords = test_model.sample_sc_coordinates_batched(
            batch_size=512  # Even smaller batches
    )
    all_predictions.append(predicted_coords)

# --- Evaluation of the averaged prediction against the ST3 ground truth ---

# Element-wise mean over the per-model coordinate arrays
predicted_coords_avg = np.mean(all_predictions, axis=0)
print(f"Predicted coordinates shape: {predicted_coords_avg.shape}")
print(f"Ground truth coordinates shape: {st3_coords_gt_norm.shape}")

from scipy.stats import pearsonr
from sklearn.metrics import mean_squared_error, mean_absolute_error

# Error metrics computed over both coordinate dimensions together
mse = mean_squared_error(st3_coords_gt_norm, predicted_coords_avg)
mae = mean_absolute_error(st3_coords_gt_norm, predicted_coords_avg)

# Pearson correlation, separately for dimension 0 (x) and dimension 1 (y)
(corr_x, p_x), (corr_y, p_y) = (
    pearsonr(st3_coords_gt_norm[:, dim], predicted_coords_avg[:, dim])
    for dim in (0, 1)
)

print("=== VALIDATION RESULTS ===")
print(f"Mean Squared Error: {mse:.6f}")
print(f"Mean Absolute Error: {mae:.6f}")
print(f"Correlation X-dimension: {corr_x:.4f} (p={p_x:.6f})")
print(f"Correlation Y-dimension: {corr_y:.4f} (p={p_y:.6f})")

# Three-panel summary: ground truth, averaged prediction, per-axis agreement
import matplotlib.pyplot as plt

plt.figure(figsize=(15, 5))

# Panels 1-2: spatial scatter, coloured by spot index
scatter_panels = (
    (st3_coords_gt_norm, 'Ground Truth ST3 Coordinates\n(Isotropic Normalized)'),
    (predicted_coords_avg, 'Predicted ST3 Coordinates\n(Averaged from 2 Models)'),
)
for panel_idx, (xy, panel_title) in enumerate(scatter_panels, start=1):
    plt.subplot(1, 3, panel_idx)
    plt.scatter(xy[:, 0], xy[:, 1],
                c=range(len(xy)), cmap='viridis', alpha=0.6, s=20)
    plt.title(panel_title)
    plt.xlabel('X coordinate')
    plt.ylabel('Y coordinate')
    plt.colorbar(label='Spot index')

# Panel 3: predicted vs ground-truth value for each axis, with the identity line
plt.subplot(1, 3, 3)
for dim, (axis_name, r_value) in enumerate((('X', corr_x), ('Y', corr_y))):
    plt.scatter(st3_coords_gt_norm[:, dim], predicted_coords_avg[:, dim],
                alpha=0.5, label=f'{axis_name}-coord (r={r_value:.3f})', s=15)
gt_lo, gt_hi = st3_coords_gt_norm.min(), st3_coords_gt_norm.max()
plt.plot([gt_lo, gt_hi], [gt_lo, gt_hi],
         'r--', alpha=0.8, label='Perfect correlation')
plt.xlabel('Ground Truth Coordinates')
plt.ylabel('Predicted Coordinates')
plt.title('Prediction vs Ground Truth')
plt.legend()

plt.tight_layout()
plt.show()

# Per-spot Euclidean error between the averaged prediction and ground truth
coord_residuals = st3_coords_gt_norm - predicted_coords_avg
euclidean_distances = np.sqrt(np.sum(coord_residuals**2, axis=1))
mean_distance = np.mean(euclidean_distances)
median_distance = np.median(euclidean_distances)
max_distance = np.max(euclidean_distances)
min_distance = np.min(euclidean_distances)

print("\nDistance-based metrics:")
print(f"Mean Euclidean distance: {mean_distance:.6f}")
print(f"Median Euclidean distance: {median_distance:.6f}")
print(f"Max Euclidean distance: {max_distance:.6f}")
print(f"Min Euclidean distance: {min_distance:.6f}")

print("=== VALIDATION EXPERIMENT COMPLETED ===")

In [None]:
# Re-run the prediction step: for each trained model, build a fresh instance
# whose "SC" input is the ST3 expression matrix, copy the trained weights into
# it, and sample coordinates for the ST3 spots.
# (A progress message that read the loop index `i` before the loop had
# assigned it used to sit here; it has been dropped.)
all_predictions = []
for i, trained_model in enumerate(trained_models):
    print(f"Creating test model based on trained model {i+1}...")

    # Create a minimal model instance for testing (no training)
    test_model = AdvancedHierarchicalDiffusion(
        st_gene_expr=X_st1_ref.cpu().numpy(),  # Use ST1 as reference
        st_coords=Y_st1_ref.cpu().numpy(),     # Use ST1 coords as reference
        sc_gene_expr=X_st3_test.cpu().numpy(), # ST3 data as "SC" data
        cell_types_sc=dummy_cell_types,        # placeholder labels for the ST3 spots
        transport_plan=None,                   # no transport plan supplied
        D_st=None,                             # no distance matrices supplied
        D_induced=None,
        n_genes=len(common_genes_test),
        n_embedding=[512, 256, 128],
        coord_space_diameter=2.00,
        sigma=3.0,
        alpha=0.8,
        mmdbatch=1000,
        batch_size=256,
        device=device,
        lr_e=0.0001,
        lr_d=0.0002,
        n_timesteps=800,
        n_denoising_blocks=6,
        hidden_dim=256,
        num_heads=8,
        num_hierarchical_scales=3,
        dp=0.2,
        outf=f'./temp_test_model_{i}'
    )

    # Copy trained parameters (this is a hack, but should work).
    # The state dict is fetched once and reused; it was previously fetched
    # twice with the first copy left unused.
    # NOTE(review): with strict=False, entries that do not match are
    # presumably skipped without an error - confirm nothing important is
    # left at its initial value.
    state_dict = trained_model.state_dict()
    test_model.load_state_dict(state_dict, strict=False)

    # Sample coordinates; the ST3 spots are this model's "SC" data.
    print(f"Generating predictions from test model {i+1}...")
    predicted_coords = test_model.sample_sc_coordinates_batched(
            batch_size=512  # Even smaller batches
    )
    all_predictions.append(predicted_coords)

# --- Evaluation of the averaged prediction against the ST3 ground truth ---

# Element-wise mean over the per-model coordinate arrays
predicted_coords_avg = np.mean(all_predictions, axis=0)
print(f"Predicted coordinates shape: {predicted_coords_avg.shape}")
print(f"Ground truth coordinates shape: {st3_coords_gt_norm.shape}")

from scipy.stats import pearsonr
from sklearn.metrics import mean_squared_error, mean_absolute_error

# Error metrics computed over both coordinate dimensions together
mse = mean_squared_error(st3_coords_gt_norm, predicted_coords_avg)
mae = mean_absolute_error(st3_coords_gt_norm, predicted_coords_avg)

# Pearson correlation, separately for dimension 0 (x) and dimension 1 (y)
(corr_x, p_x), (corr_y, p_y) = (
    pearsonr(st3_coords_gt_norm[:, dim], predicted_coords_avg[:, dim])
    for dim in (0, 1)
)

print("=== VALIDATION RESULTS ===")
print(f"Mean Squared Error: {mse:.6f}")
print(f"Mean Absolute Error: {mae:.6f}")
print(f"Correlation X-dimension: {corr_x:.4f} (p={p_x:.6f})")
print(f"Correlation Y-dimension: {corr_y:.4f} (p={p_y:.6f})")

# Three-panel summary: ground truth, averaged prediction, per-axis agreement
import matplotlib.pyplot as plt

plt.figure(figsize=(15, 5))

# Panels 1-2: spatial scatter, coloured by spot index
scatter_panels = (
    (st3_coords_gt_norm, 'Ground Truth ST3 Coordinates\n(Isotropic Normalized)'),
    (predicted_coords_avg, 'Predicted ST3 Coordinates\n(Averaged from 2 Models)'),
)
for panel_idx, (xy, panel_title) in enumerate(scatter_panels, start=1):
    plt.subplot(1, 3, panel_idx)
    plt.scatter(xy[:, 0], xy[:, 1],
                c=range(len(xy)), cmap='viridis', alpha=0.6, s=20)
    plt.title(panel_title)
    plt.xlabel('X coordinate')
    plt.ylabel('Y coordinate')
    plt.colorbar(label='Spot index')

# Panel 3: predicted vs ground-truth value for each axis, with the identity line
plt.subplot(1, 3, 3)
for dim, (axis_name, r_value) in enumerate((('X', corr_x), ('Y', corr_y))):
    plt.scatter(st3_coords_gt_norm[:, dim], predicted_coords_avg[:, dim],
                alpha=0.5, label=f'{axis_name}-coord (r={r_value:.3f})', s=15)
gt_lo, gt_hi = st3_coords_gt_norm.min(), st3_coords_gt_norm.max()
plt.plot([gt_lo, gt_hi], [gt_lo, gt_hi],
         'r--', alpha=0.8, label='Perfect correlation')
plt.xlabel('Ground Truth Coordinates')
plt.ylabel('Predicted Coordinates')
plt.title('Prediction vs Ground Truth')
plt.legend()

plt.tight_layout()
plt.show()

# Per-spot Euclidean error between the averaged prediction and ground truth
coord_residuals = st3_coords_gt_norm - predicted_coords_avg
euclidean_distances = np.sqrt(np.sum(coord_residuals**2, axis=1))
mean_distance = np.mean(euclidean_distances)
median_distance = np.median(euclidean_distances)
max_distance = np.max(euclidean_distances)
min_distance = np.min(euclidean_distances)

print("\nDistance-based metrics:")
print(f"Mean Euclidean distance: {mean_distance:.6f}")
print(f"Median Euclidean distance: {median_distance:.6f}")
print(f"Max Euclidean distance: {max_distance:.6f}")
print(f"Min Euclidean distance: {min_distance:.6f}")

print("=== VALIDATION EXPERIMENT COMPLETED ===")

In [None]:
# Score every individual model, then the averaged prediction
predicted_coords_avg = np.mean(all_predictions, axis=0)
print(f"Predicted coordinates shape: {predicted_coords_avg.shape}")
print(f"Ground truth coordinates shape: {st3_coords_gt_norm.shape}")

from scipy.stats import pearsonr
from sklearn.metrics import mean_squared_error, mean_absolute_error

# Individual models: MSE / MAE plus per-axis Pearson correlation
for i, pred in enumerate(all_predictions):
    mse_i = mean_squared_error(st3_coords_gt_norm, pred)
    mae_i = mean_absolute_error(st3_coords_gt_norm, pred)
    (corr_x_i, p_x_i), (corr_y_i, p_y_i) = (
        pearsonr(st3_coords_gt_norm[:, dim], pred[:, dim])
        for dim in (0, 1)
    )

    print(f"\n=== MODEL {i+1} RESULTS ===")
    print(f"MSE: {mse_i:.6f}, MAE: {mae_i:.6f}")
    print(f"Corr X: {corr_x_i:.4f}, Corr Y: {corr_y_i:.4f}")

# Averaged prediction: same metrics
mse = mean_squared_error(st3_coords_gt_norm, predicted_coords_avg)
mae = mean_absolute_error(st3_coords_gt_norm, predicted_coords_avg)
(corr_x, p_x), (corr_y, p_y) = (
    pearsonr(st3_coords_gt_norm[:, dim], predicted_coords_avg[:, dim])
    for dim in (0, 1)
)

print("\n=== AVERAGED RESULTS ===")
print(f"MSE: {mse:.6f}, MAE: {mae:.6f}")
print(f"Corr X: {corr_x:.4f}, Corr Y: {corr_y:.4f}")

In [None]:
# Number of per-model prediction arrays collected so far.
len(all_predictions)

In [None]:
# Visualization - ground truth, every individual model, and the average.
# The panel list is built from all_predictions, so the figure adapts to the
# number of trained models instead of assuming exactly two.
import matplotlib.pyplot as plt

scatter_panels = [(st3_coords_gt_norm, 'Ground Truth ST3 Coordinates\n(Isotropic Normalized)')]
scatter_panels += [
    (model_pred, f'Model {model_no} Predictions')
    for model_no, model_pred in enumerate(all_predictions, start=1)
]
scatter_panels.append((predicted_coords_avg, 'Averaged Predictions'))

n_panels = len(scatter_panels)
plt.figure(figsize=(5 * n_panels, 5))  # 5 inches per panel -> (20, 5) for two models

for panel_idx, (xy, panel_title) in enumerate(scatter_panels, start=1):
    plt.subplot(1, n_panels, panel_idx)
    # Points are coloured by spot index.
    plt.scatter(xy[:, 0], xy[:, 1],
               c=range(len(xy)), cmap='viridis', alpha=0.6, s=20)
    plt.title(panel_title)
    plt.xlabel('X coordinate')
    plt.ylabel('Y coordinate')
    plt.colorbar(label='Spot index')

plt.tight_layout()
plt.show()

# Correlation plots: one panel per model plus one for the averaged prediction.
# The grid is sized from all_predictions; with the former fixed 1x3 grid the
# averaged panel would be drawn over a model panel for any count other than two.
# Correlations are computed here for every panel, so the averaged panel no
# longer depends on corr_x / corr_y left behind by another cell.
corr_panels = [
    (model_pred, f'Model {model_no} vs Ground Truth')
    for model_no, model_pred in enumerate(all_predictions, start=1)
]
corr_panels.append((predicted_coords_avg, 'Averaged Predictions vs Ground Truth'))

n_panels = len(corr_panels)
gt_lo, gt_hi = st3_coords_gt_norm.min(), st3_coords_gt_norm.max()

plt.figure(figsize=(5 * n_panels, 5))  # (15, 5) for two models

for panel_idx, (pred, panel_title) in enumerate(corr_panels, start=1):
    corr_x_i, _ = pearsonr(st3_coords_gt_norm[:, 0], pred[:, 0])
    corr_y_i, _ = pearsonr(st3_coords_gt_norm[:, 1], pred[:, 1])

    plt.subplot(1, n_panels, panel_idx)
    plt.scatter(st3_coords_gt_norm[:, 0], pred[:, 0],
               alpha=0.5, label=f'X-coord (r={corr_x_i:.3f})', s=15)
    plt.scatter(st3_coords_gt_norm[:, 1], pred[:, 1],
               alpha=0.5, label=f'Y-coord (r={corr_y_i:.3f})', s=15)
    # Identity line: where points would fall if prediction == ground truth.
    plt.plot([gt_lo, gt_hi], [gt_lo, gt_hi],
             'r--', alpha=0.8, label='Perfect correlation')
    plt.xlabel('Ground Truth Coordinates')
    plt.ylabel('Predicted Coordinates')
    plt.title(panel_title)
    plt.legend()

plt.tight_layout()
plt.show()

# Per-spot Euclidean error for every model and for the averaged prediction
euclidean_distances_all = [
    np.sqrt(np.sum((st3_coords_gt_norm - pred)**2, axis=1))
    for pred in all_predictions
]
euclidean_distances_avg = np.sqrt(np.sum((st3_coords_gt_norm - predicted_coords_avg)**2, axis=1))

# Error histograms: one panel per model plus one for the average. The grid is
# sized from the data; the former fixed 1x3 grid only fit exactly two models.
hist_panels = [
    (model_distances, f'Model {model_no} Error Distribution')
    for model_no, model_distances in enumerate(euclidean_distances_all, start=1)
]
hist_panels.append((euclidean_distances_avg, 'Averaged Model Error Distribution'))

n_panels = len(hist_panels)
plt.figure(figsize=(5 * n_panels, 4))  # (15, 4) for two models

for panel_idx, (panel_distances, panel_label) in enumerate(hist_panels, start=1):
    plt.subplot(1, n_panels, panel_idx)
    plt.hist(panel_distances, bins=50, alpha=0.7, edgecolor='black')
    plt.xlabel('Euclidean Distance')
    plt.ylabel('Frequency')
    plt.title(f'{panel_label}\nMean: {np.mean(panel_distances):.4f}')

plt.tight_layout()
plt.show()

# Print distance metrics for all models
for i, distances in enumerate(euclidean_distances_all):
    print(f"\nModel {i+1} distance metrics:")
    print(f"Mean: {np.mean(distances):.6f}, Median: {np.median(distances):.6f}")
    print(f"Max: {np.max(distances):.6f}, Min: {np.min(distances):.6f}")

print("\nAveraged model distance metrics:")
print(f"Mean: {np.mean(euclidean_distances_avg):.6f}, Median: {np.median(euclidean_distances_avg):.6f}")
print(f"Max: {np.max(euclidean_distances_avg):.6f}, Min: {np.min(euclidean_distances_avg):.6f}")

In [None]:
# Reconstructed layout: cells drawn at their averaged coordinates and
# coloured by rough cell type.
mpl.rcParams['figure.figsize'] = (4, 4)

sc.pl.spatial(
    scadata,
    basis='advanced_diffusion_coords_avg',
    color="rough_celltype",
    spot_size=0.06,
    title='reconstructed',
    show=True,
)


In [None]:
# Reconstructed layout showing only the "PDC" group of level2_celltype.
mpl.rcParams['figure.figsize'] = (4, 4)
sc.pl.spatial(
    scadata,
    basis='advanced_diffusion_coords_avg',
    color="level2_celltype",
    groups=["PDC"],
    spot_size=0.06,
    title='reconstructed',
    na_in_legend=False,
    show=True,
)

In [None]:
# Reconstructed layout showing only the "TSK" group of level3_celltype.
# Disabled options kept from the original cell:
# rcParams['pdf.fonttype'] = 42
# rcParams['ps.fonttype'] = 42
# figsize(4,4)
mpl.rcParams['figure.figsize'] = (4, 4)
sc.pl.spatial(
    scadata,
    basis='advanced_diffusion_coords_avg',
    color="level3_celltype",
    groups=["TSK"],
    spot_size=0.06,
    title='reconstructed',
    na_in_legend=False,
    show=True,
)
#save='TSK',

In [None]:
# One reconstructed-layout figure per tumor KC group of level2_celltype;
# each call is given its own `save` tag.
# figsize(4,4)
mpl.rcParams['figure.figsize'] = (4, 4)
for kc_group, save_tag in (
    ("Tumor_KC_Cyc", 'P2cyc'),
    ("Tumor_KC_Basal", 'P2bas'),
    ("Tumor_KC_Diff", 'P2diff'),
):
    sc.pl.spatial(
        scadata,
        basis='advanced_diffusion_coords_avg',
        color="level2_celltype",
        groups=[kc_group],
        spot_size=0.06,
        title='reconstructed',
        na_in_legend=False,
        show=True,
        save=save_tag,
    )
#save='nonTSK',

In [None]:
import squidpy as sq

# Full data set: neighbour graph on the averaged coordinates, then
# neighbourhood enrichment and interaction matrix by rough cell type.
sq.gr.spatial_neighbors(scadata, spatial_key='advanced_diffusion_coords_avg')
sq.gr.nhood_enrichment(scadata, cluster_key='rough_celltype')
sq.gr.interaction_matrix(scadata, cluster_key='rough_celltype')

# Subset to the three Tumor_KC groups plus TSK and repeat the enrichment
# at level2_celltype resolution on a copy.
kc_tsk_labels = ['Tumor_KC_Cyc', 'Tumor_KC_Basal', 'Tumor_KC_Diff', 'TSK']
in_kc_tsk = scadata.obs['level2_celltype'].isin(kc_tsk_labels)
kscadata = scadata[in_kc_tsk].copy()
sq.gr.spatial_neighbors(kscadata, spatial_key='advanced_diffusion_coords_avg')
sq.gr.nhood_enrichment(kscadata, cluster_key='level2_celltype')
# sq.pl.nhood_enrichment(kscadata, cluster_key="level2_celltype",cmap='coolwarm',save='TSKKC_new_good.png',figsize=(3,5))
sq.pl.nhood_enrichment(
    kscadata,
    cluster_key="level2_celltype",
    cmap='coolwarm',
    figsize=(3, 5),
)


# Patient 10 (P10) analysis

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import cdist
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd

# Read the three Patient-10 ST replicate files (one AnnData each).
stadata1_p10 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP10rep1.h5ad')
stadata2_p10 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP10rep2.h5ad')
stadata3_p10 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP10rep3.h5ad')

# Parallel lists: datasets[k] is labelled names[k] everywhere below.
datasets = [stadata1_p10, stadata2_p10, stadata3_p10]
names = ['ST_P10_Rep1', 'ST_P10_Rep2', 'ST_P10_Rep3']

# Per-replicate size and bounding box of the spatial coordinates.
print("Dataset Basic Info:")
for i, (data, name) in enumerate(zip(datasets, names)):
    _n_spots, _n_genes = data.shape[0], data.shape[1]
    _xs = data.obsm['spatial'][:, 0]
    _ys = data.obsm['spatial'][:, 1]
    print(f"{name}: {_n_spots} spots, {_n_genes} genes")
    print(f"  Spatial coords range: X[{_xs.min():.2f}, {_xs.max():.2f}], Y[{_ys.min():.2f}, {_ys.max():.2f}]")

# 2x2 figure: one scatter per replicate plus a colour-coded overlay of all
# three in the bottom-right panel.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Replicates 1 and 2 fill the top row; replicate 3 goes bottom-left.
for i, (data, name) in enumerate(zip(datasets, names)):
    coords = data.obsm['spatial']
    row, col = (0, i) if i < 2 else (1, 0)
    _panel = axes[row, col]
    _panel.scatter(coords[:, 0], coords[:, 1], alpha=0.6, s=20)
    _panel.set_title(f'{name}\n{data.shape[0]} spots')
    _panel.set_xlabel('X coordinate')
    _panel.set_ylabel('Y coordinate')

# Overlay: every replicate in the same axes, each with its own colour.
colors = ['red', 'blue', 'green']
_overlay = axes[1, 1]
for i, (data, name, color) in enumerate(zip(datasets, names, colors)):
    coords = data.obsm['spatial']
    _overlay.scatter(coords[:, 0], coords[:, 1], alpha=0.5, s=15, c=color, label=name)
_overlay.set_title('All Datasets Overlay')
_overlay.set_xlabel('X coordinate')
_overlay.set_ylabel('Y coordinate')
_overlay.legend()

plt.tight_layout()
plt.show()

# Genes present in all three replicates, in sorted order.
all_genes = [set(data.var_names) for data in datasets]
common_genes = sorted(all_genes[0] & all_genes[1] & all_genes[2])
print(f"\nCommon genes across all datasets: {len(common_genes)}")

# For each ordered pair (i < j): how close is every spot of replicate i to
# its nearest spot in replicate j?
print("\nCoordinate Overlap Analysis:")
tolerance = 1.0  # a spot counts as "overlapping" when its nearest partner is closer than this

for i in range(len(datasets)):
    for j in range(i+1, len(datasets)):
        coords_i = datasets[i].obsm['spatial']
        coords_j = datasets[j].obsm['spatial']

        # Dense (n_i x n_j) Euclidean distance matrix, reduced to the row-wise minimum.
        distances = cdist(coords_i, coords_j)
        min_distances = distances.min(axis=1)

        # Number of spots in replicate i with a partner inside the tolerance.
        overlaps = (min_distances < tolerance).sum()
        _n_i = len(coords_i)

        print(f"{names[i]} vs {names[j]}:")
        print(f"  Spots within {tolerance} units: {overlaps}/{_n_i} ({overlaps/_n_i*100:.1f}%)")
        print(f"  Mean min distance: {np.mean(min_distances):.2f}")

# For each replicate pair, correlate the expression profile of every spot in
# replicate i with the profile of its spatially nearest spot in replicate j.
print("\nGene Expression Similarity (for closest coordinate pairs):")

for i in range(len(datasets)):
    for j in range(i+1, len(datasets)):
        # Expression restricted to the shared genes, densified if sparse.
        expr_i = datasets[i][:, common_genes].X
        expr_j = datasets[j][:, common_genes].X

        expr_i = expr_i.toarray() if hasattr(expr_i, 'toarray') else expr_i
        expr_j = expr_j.toarray() if hasattr(expr_j, 'toarray') else expr_j

        coords_i = datasets[i].obsm['spatial']
        coords_j = datasets[j].obsm['spatial']

        # Index of the nearest replicate-j spot for every replicate-i spot.
        distances = cdist(coords_i, coords_j)
        closest_j_indices = np.argmin(distances, axis=1)

        # Pearson correlation across genes for each matched pair; NaN results
        # (np.corrcoef yields NaN for a constant profile) are dropped.
        correlations = []
        for spot_i in range(len(expr_i)):
            closest_j = closest_j_indices[spot_i]
            corr = np.corrcoef(expr_i[spot_i], expr_j[closest_j])[0, 1]
            if np.isnan(corr):
                continue
            correlations.append(corr)

        _n_high = np.sum(np.array(correlations) > 0.5)

        print(f"{names[i]} vs {names[j]}:")
        print(f"  Mean gene expression correlation: {np.mean(correlations):.4f}")
        print(f"  Median correlation: {np.median(correlations):.4f}")
        print(f"  Correlations > 0.5: {_n_high}/{len(correlations)} ({_n_high/len(correlations)*100:.1f}%)")

# One histogram per replicate pair of the nearest-neighbour distances, with
# the overlap tolerance drawn as a dashed red line.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
pair_idx = 0  # column of the current pair in the 1x3 figure

for i in range(len(datasets)):
    for j in range(i+1, len(datasets)):
        coords_i = datasets[i].obsm['spatial']
        coords_j = datasets[j].obsm['spatial']

        distances = cdist(coords_i, coords_j)
        min_distances = distances.min(axis=1)

        _panel = axes[pair_idx]
        _panel.hist(min_distances, bins=50, alpha=0.7)
        _panel.set_title(f'{names[i]} vs {names[j]}\nMin Distance Distribution')
        _panel.set_xlabel('Distance to closest spot')
        _panel.set_ylabel('Frequency')
        _panel.axvline(tolerance, color='red', linestyle='--', label=f'Tolerance={tolerance}')
        _panel.legend()

        pair_idx += 1

plt.tight_layout()
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import cdist
from scipy.stats import pearsonr
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
import umap
from scipy.spatial import cKDTree
from scipy.interpolate import griddata
import pandas as pd

# Shared gene set across the three replicates (sorted so column order is fixed).
all_genes = [set(data.var_names) for data in datasets]
common_genes = sorted(all_genes[0] & all_genes[1] & all_genes[2])

print(f"Running comprehensive analysis on {len(common_genes)} common genes...")

# Per-replicate dense expression (columns follow `common_genes`) and the
# matching spatial coordinates; both lists are aligned with `datasets`.
expr_matrices = []
coord_matrices = []
for data in datasets:
    expr = data[:, common_genes].X
    expr = expr.toarray() if hasattr(expr, 'toarray') else expr
    expr_matrices.append(expr)
    coord_matrices.append(data.obsm['spatial'])

# 1. FIXED SPATIAL GENE EXPRESSION GRADIENTS
print("\n1. Calculating spatial gradients...")

def calculate_spatial_gradients_fixed(expr, coords, top_n_genes=20, gene_names=None, n_neighbors=5):
    """Estimate a per-spot spatial gradient magnitude for the leading genes.

    For every spot, the ``n_neighbors`` nearest other spots are looked up and
    the gradient magnitude is the largest ``|expr[neighbor] - expr[spot]| /
    distance`` among them. Neighbors at zero distance and missing neighbors
    (fewer spots than requested) are ignored; a spot with no usable neighbor
    gets a magnitude of 0.

    Args:
        expr: 2-D array indexed as ``expr[:, gene]`` (spots x genes).
        coords: 2-D array of spot coordinates, one row per spot.
        top_n_genes: Number of leading columns of ``expr`` to process. Columns
            are taken in order; no ranking is performed here.
        gene_names: Labels for the columns of ``expr``. When omitted, the
            notebook-level ``common_genes`` list is used (original behavior).
        n_neighbors: Number of nearest neighbors compared against each spot.

    Returns:
        dict mapping gene name -> ``{'magnitude': ndarray (n_spots,),
        'mean_magnitude': float}``.
    """
    if gene_names is None:
        # Backward-compatible fallback to the global defined earlier in the notebook.
        gene_names = common_genes

    coords = np.asarray(coords, dtype=float)
    n_spots = coords.shape[0]

    # The neighbor structure depends only on the coordinates, so query it once
    # for all spots instead of once per spot per gene.
    tree = cKDTree(coords)
    dists, idx = tree.query(coords, k=n_neighbors + 1)
    # Drop the first column (the query point itself at distance 0).
    dists = np.asarray(dists, dtype=float).reshape(n_spots, n_neighbors + 1)[:, 1:]
    idx = np.asarray(idx).reshape(n_spots, n_neighbors + 1)[:, 1:]

    # cKDTree pads missing neighbors with index == n_spots; zero-distance
    # neighbors would divide by zero. Both are excluded.
    usable = (idx < n_spots) & (dists > 0)
    safe_idx = np.where(usable, idx, 0)

    gradients = {}
    for gene_idx in range(min(top_n_genes, expr.shape[1])):
        gene_expr = np.asarray(expr[:, gene_idx], dtype=float)

        # Absolute directional difference quotient towards every neighbor.
        abs_diff = np.abs(gene_expr[safe_idx] - gene_expr[:, None])
        slopes = np.zeros_like(dists)
        np.divide(abs_diff, dists, out=slopes, where=usable)

        # `initial=0.0` yields 0 for spots without any usable neighbor.
        grad_magnitudes = slopes.max(axis=1, initial=0.0)
        gradients[gene_names[gene_idx]] = {
            'magnitude': grad_magnitudes,
            'mean_magnitude': np.mean(grad_magnitudes)
        }

    return gradients

# Gradient magnitudes for the first 20 shared genes of every replicate.
gradient_results = []
for i, (expr, coords, name) in enumerate(zip(expr_matrices, coord_matrices, names)):
    print(f"  Processing {name}...")
    gradients = calculate_spatial_gradients_fixed(expr, coords, top_n_genes=20)
    gradient_results.append(gradients)

# 3x3 grid: row = replicate, column = one of its three genes with the largest
# mean gradient magnitude; colour encodes the per-spot magnitude.
fig, axes = plt.subplots(3, 3, figsize=(15, 12))
for i, (gradients, name) in enumerate(zip(gradient_results, names)):
    top_genes = sorted(
        gradients.items(),
        key=lambda item: item[1]['mean_magnitude'],
        reverse=True,
    )[:3]

    for j, (gene, grad_data) in enumerate(top_genes):
        coords = coord_matrices[i]
        _panel = axes[i, j]
        scatter = _panel.scatter(coords[:, 0], coords[:, 1], c=grad_data['magnitude'],
                                 cmap='viridis', s=20, alpha=0.7)
        _panel.set_title(f'{name}\n{gene} (grad: {grad_data["mean_magnitude"]:.3f})')
        _panel.set_xlabel('X')
        _panel.set_ylabel('Y')
        plt.colorbar(scatter, ax=_panel)

plt.tight_layout()
plt.show()

# 2. PCA/UMAP ANALYSIS
print("\n2. Performing PCA and UMAP analysis...")

# Stack the replicates row-wise; `dataset_labels[r]` is the replicate index of row r.
all_expr = np.vstack(expr_matrices)
all_coords = np.vstack(coord_matrices)
dataset_labels = np.repeat(
    np.arange(len(expr_matrices)),
    [len(matrix) for matrix in expr_matrices],
)

# Z-score each gene across the pooled spots.
scaler = StandardScaler()
all_expr_scaled = scaler.fit_transform(all_expr)

# 50-component PCA of the scaled matrix.
pca = PCA(n_components=50)
pca_result = pca.fit_transform(all_expr_scaled)

# 2-D UMAP, fitted on the scaled expression matrix (not on the PCA scores).
umap_reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
umap_result = umap_reducer.fit_transform(all_expr_scaled)

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Top row: PCA (left) and UMAP (right), one colour per replicate.
colors = ['red', 'blue', 'green']
for i, (name, color) in enumerate(zip(names, colors)):
    mask = dataset_labels == i
    for _panel, _embedding in ((axes[0, 0], pca_result), (axes[0, 1], umap_result)):
        _panel.scatter(_embedding[mask, 0], _embedding[mask, 1],
                       c=color, label=name, alpha=0.6, s=15)

axes[0, 0].set_title('PCA by Dataset')
axes[0, 0].set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%})')
axes[0, 0].set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%})')
axes[0, 0].legend()

axes[0, 1].set_title('UMAP by Dataset')
axes[0, 1].set_xlabel('UMAP1')
axes[0, 1].set_ylabel('UMAP2')
axes[0, 1].legend()

# Bottom-left: variance ratio of the first 20 principal components.
axes[1, 0].plot(range(1, 21), pca.explained_variance_ratio_[:20], 'bo-')
axes[1, 0].set_title('PCA Variance Explained')
axes[1, 0].set_xlabel('Principal Component')
axes[1, 0].set_ylabel('Variance Explained')

# Bottom-right: all spots at their spatial position, coloured by PC1 score.
scatter = axes[1, 1].scatter(all_coords[:, 0], all_coords[:, 1], c=pca_result[:, 0],
                             cmap='viridis', s=15, alpha=0.7)
axes[1, 1].set_title('Spatial Distribution (colored by PC1)')
axes[1, 1].set_xlabel('X coordinate')
axes[1, 1].set_ylabel('Y coordinate')
plt.colorbar(scatter, ax=axes[1, 1])

plt.tight_layout()
plt.show()

# 3. HIGHLY VARIABLE GENES
print("\n3. Finding highly variable genes...")

def find_highly_variable_genes(expr, gene_names, top_n=20):
    """Rank genes by coefficient of variation (std / mean) and return the top ones.

    Args:
        expr: 2-D array indexed as ``expr[:, gene]``.
        gene_names: Sequence of labels aligned with the columns of ``expr``.
        top_n: Number of genes to return.

    Returns:
        List of ``(gene_name, cv)`` tuples, highest CV first.
    """
    def _coefficient_of_variation(column):
        # The 1e-8 offset keeps the ratio finite for all-zero columns.
        return np.std(column) / (np.mean(column) + 1e-8)

    cv_scores = [_coefficient_of_variation(expr[:, col]) for col in range(expr.shape[1])]

    # argsort is ascending: take the tail, then flip to descending order.
    ranked = np.argsort(cv_scores)[-top_n:][::-1]
    return [(gene_names[col], cv_scores[col]) for col in ranked]

# Top-20 CV-ranked genes per replicate (function default); print the first ten.
hvg_results = []
for i, (expr, name) in enumerate(zip(expr_matrices, names)):
    hvgs = find_highly_variable_genes(expr, common_genes)
    hvg_results.append(hvgs)
    print(f"\nTop 10 HVGs in {name}:")
    for gene, cv in hvgs[:10]:
        print(f"  {gene}: CV = {cv:.3f}")

# 4. SPATIAL AUTOCORRELATION
print("\n4. Calculating spatial autocorrelation...")

def moran_i(expr, coords, k=8):
    """Compute Moran's I with binary k-nearest-neighbor weights.

    Each spot is connected (weight 1) to its ``k`` nearest other spots; the
    statistic is ``(n / W) * sum_ij w_ij * d_i * d_j / sum_i d_i**2`` where
    ``d`` is the mean-centered expression and ``W`` the total weight.

    Args:
        expr: 1-D array of expression values, one per spot.
        coords: 2-D array of spot coordinates, aligned with ``expr``.
        k: Number of neighbors per spot (capped at ``n - 1``).

    Returns:
        Moran's I as a float; 0.0 when it is undefined (fewer than two spots,
        ``k < 1``, or constant expression).
    """
    values = np.asarray(expr, dtype=float).ravel()
    n = values.shape[0]

    # One extra neighbor is requested because the nearest hit is the spot itself.
    n_query = min(k + 1, n)
    if n_query < 2:
        # No neighbor besides the spot itself: the statistic is undefined.
        # (The per-row slicing used previously raised IndexError here because
        # cKDTree returns 1-D arrays for k == 1.)
        return 0.0

    tree = cKDTree(coords)
    _, indices = tree.query(coords, k=n_query)
    # Drop the first column, taken to be the query spot itself.
    # NOTE(review): with duplicated coordinates the first hit may be the
    # duplicate rather than the spot itself - confirm coordinates are unique.
    neighbors = indices[:, 1:]

    centered = values - values.mean()
    denominator = float(np.sum(centered ** 2))
    w_sum = neighbors.size  # every retained (spot, neighbor) pair has weight 1

    if w_sum == 0 or denominator <= 0:
        return 0.0

    # Sum of cross-products over all (spot, neighbor) pairs.
    numerator = float(np.sum(centered[:, None] * centered[neighbors]))
    return (n / w_sum) * (numerator / denominator)

# Moran's I for the 50 highest-CV genes of each replicate, ranked by |I|.
autocorr_results = []
for i, (expr, coords, name) in enumerate(zip(expr_matrices, coord_matrices, names)):
    print(f"  Processing {name}...")
    autocorr_genes = []

    hvgs = find_highly_variable_genes(expr, common_genes, top_n=50)

    for gene, _ in hvgs:
        # Map the gene name back to its column in the shared-gene matrix.
        gene_idx = common_genes.index(gene)
        moran = moran_i(expr[:, gene_idx], coords)
        autocorr_genes.append((gene, moran))

    # Strongest autocorrelation (positive or negative) first.
    autocorr_genes = sorted(autocorr_genes, key=lambda pair: abs(pair[1]), reverse=True)
    autocorr_results.append(autocorr_genes)

    print(f"Top 5 spatially autocorrelated genes in {name}:")
    for gene, moran in autocorr_genes[:5]:
        print(f"  {gene}: Moran's I = {moran:.3f}")

# 5. FIXED CROSS-DATASET GENE CORRELATION
print("\n5. Cross-dataset gene correlation analysis...")

def calculate_cross_dataset_correlation_fixed(expr1, expr2, coords1, coords2, gene_names,
                                              max_distance=5.0, min_matches=10):
    """Correlate each gene between two datasets over spatially matched spots.

    Every spot of dataset 1 is paired with its nearest spot in dataset 2;
    pairs farther apart than ``max_distance`` are discarded. For each gene the
    Pearson correlation is computed across the retained pairs.

    Args:
        expr1, expr2: 2-D arrays (spots x genes) with identical column order.
        coords1, coords2: 2-D coordinate arrays aligned with ``expr1``/``expr2``.
        gene_names: Labels for the expression columns; one correlation is
            attempted per entry.
        max_distance: A pair is kept only if its distance is strictly below this.
        min_matches: Minimum number of retained pairs; below it a warning is
            printed and an empty list is returned.

    Returns:
        List of ``(gene_name, r)`` tuples in ``gene_names`` order. Genes whose
        correlation is undefined (constant on either side, or NaN) are omitted.
    """
    correlations = []

    # Nearest dataset-2 spot for every dataset-1 spot.
    tree = cKDTree(coords2)
    distances, indices = tree.query(coords1)

    good_matches = distances < max_distance
    n_good = int(np.sum(good_matches))

    if n_good < min_matches:
        print(f"    Warning: Only {n_good} good spatial matches found")
        return correlations
    if n_good < 2:
        # Pearson correlation needs at least two points.
        return correlations

    matched_expr1 = expr1[good_matches]
    matched_expr2 = expr2[indices[good_matches]]

    for i in range(len(gene_names)):
        x = matched_expr1[:, i]
        y = matched_expr2[:, i]
        # A constant vector has no defined correlation; skipping it up front
        # gives the same result as discarding pearsonr's NaN, without the warning.
        if x.max() == x.min() or y.max() == y.min():
            continue
        corr, _ = pearsonr(x, y)
        if not np.isnan(corr):
            correlations.append((gene_names[i], corr))

    return correlations

# Gene-wise cross-replicate correlations for every pair (i < j), stored under
# the key "<name_i>_vs_<name_j>" and sorted by absolute correlation.
corr_results = {}
for i in range(len(datasets)):
    for j in range(i+1, len(datasets)):
        pair_name = f"{names[i]}_vs_{names[j]}"
        correlations = calculate_cross_dataset_correlation_fixed(
            expr_matrices[i], expr_matrices[j],
            coord_matrices[i], coord_matrices[j],
            common_genes
        )
        correlations.sort(key=lambda pair: abs(pair[1]), reverse=True)
        corr_results[pair_name] = correlations

        print(f"\nTop correlated genes between {names[i]} and {names[j]}:")
        for gene, corr in correlations[:10]:
            print(f"  {gene}: r = {corr:.3f}")

# 6. EXPRESSION CLUSTERING
print("\n6. Performing expression clustering...")

# K-means (5 clusters) on z-scored expression, fitted separately per replicate.
n_clusters = 5
clustering_results = []

for i, (expr, coords, name) in enumerate(zip(expr_matrices, coord_matrices, names)):
    expr_scaled = StandardScaler().fit_transform(expr)

    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    cluster_labels = kmeans.fit_predict(expr_scaled)

    # Keep both the labels and the fitted model for each replicate.
    clustering_results.append((cluster_labels, kmeans))

    print(f"\nCluster sizes in {name}:")
    unique, counts = np.unique(cluster_labels, return_counts=True)
    for cluster, count in zip(unique, counts):
        print(f"  Cluster {cluster}: {count} spots")

# Spatial map of the cluster assignment, one panel per replicate.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for i, ((cluster_labels, _), coords, name) in enumerate(zip(clustering_results, coord_matrices, names)):
    _panel = axes[i]
    scatter = _panel.scatter(coords[:, 0], coords[:, 1], c=cluster_labels,
                             cmap='tab10', s=20, alpha=0.7)
    _panel.set_title(f'{name}\nExpression Clusters')
    _panel.set_xlabel('X coordinate')
    _panel.set_ylabel('Y coordinate')
    plt.colorbar(scatter, ax=_panel)

plt.tight_layout()
plt.show()

# 7. DIFFERENTIAL EXPRESSION BY SPATIAL REGIONS
print("\n7. Differential expression by spatial regions...")

def spatial_differential_expression(expr, coords, gene_names, n_regions=4, max_genes=100):
    """Score genes by how much their mean expression differs across quadrants.

    The spots are split into four quadrants at the median x and median y
    coordinate. For each gene the mean expression per quadrant is computed and
    the coefficient of variation (std / mean) of those means is the score.
    Quadrants that contain no spots are left out of the score, so an empty
    quadrant no longer produces a NaN score that breaks the ranking.

    Args:
        expr: 2-D array indexed as ``expr[:, gene]``.
        coords: 2-D array; column 0 is x, column 1 is y.
        gene_names: Labels aligned with the columns of ``expr``.
        n_regions: Accepted for backward compatibility; the split is always
            the four median quadrants and this value is ignored.
        max_genes: Only the first ``max_genes`` columns are scored
            (``None`` scores every column).

    Returns:
        List of ``(gene_name, cv, region_means)`` sorted by ``cv`` descending.
        ``region_means`` always has four entries; an empty quadrant is NaN.
    """
    x_median = np.median(coords[:, 0])
    y_median = np.median(coords[:, 1])

    x_low, x_high = coords[:, 0] <= x_median, coords[:, 0] > x_median
    y_low, y_high = coords[:, 1] <= y_median, coords[:, 1] > y_median
    regions = [x_low & y_low, x_high & y_low, x_low & y_high, x_high & y_high]
    # Quadrants can be empty (e.g. spots lying along a diagonal).
    occupied = [bool(region.any()) for region in regions]

    n_scored = expr.shape[1] if max_genes is None else min(max_genes, expr.shape[1])

    de_results = []
    for gene_idx in range(n_scored):
        gene_expr = expr[:, gene_idx]
        region_means = [
            np.mean(gene_expr[region]) if has_spots else np.nan
            for region, has_spots in zip(regions, occupied)
        ]

        present = [m for m, has_spots in zip(region_means, occupied) if has_spots]
        if present:
            cv_regions = np.std(present) / (np.mean(present) + 1e-8)
        else:
            cv_regions = 0.0
        de_results.append((gene_names[gene_idx], cv_regions, region_means))

    de_results.sort(key=lambda entry: entry[1], reverse=True)
    return de_results

# Quadrant-based DE ranking per replicate; print the ten highest-scoring genes.
de_results = []
for i, (expr, coords, name) in enumerate(zip(expr_matrices, coord_matrices, names)):
    de_genes = spatial_differential_expression(expr, coords, common_genes)
    de_results.append(de_genes)

    print(f"\nTop spatially DE genes in {name}:")
    for gene, cv, means in de_genes[:10]:
        print(f"  {gene}: CV = {cv:.3f}")

print("\n=== COMPREHENSIVE ANALYSIS COMPLETE ===")

In [None]:
# Print the raw value distribution of the highest-CV genes of each replicate.
print("INVESTIGATING SUSPICIOUS HVG PATTERNS")
print("=" * 50)

# Names of the ten highest-CV genes per replicate.
top_hvg_genes = []
for i, (expr, name) in enumerate(zip(expr_matrices, names)):
    hvgs = find_highly_variable_genes(expr, common_genes, top_n=10)
    top_hvg_genes.append([gene for gene, cv in hvgs])

for dataset_idx, (expr, name, top_genes) in enumerate(zip(expr_matrices, names, top_hvg_genes)):
    print(f"\n{name} - Examining top 3 HVG genes:")
    print("-" * 40)

    for gene_idx in range(3):  # only the three highest-ranked genes
        gene_name = top_genes[gene_idx]
        gene_position = common_genes.index(gene_name)
        gene_expr = expr[:, gene_position]

        # Distinct values (ascending) and how many spots carry each of them.
        unique_vals, counts = np.unique(gene_expr, return_counts=True)
        _mean, _std = np.mean(gene_expr), np.std(gene_expr)

        print(f"\nGene: {gene_name}")
        print(f"Expression values (first 20 spots): {gene_expr[:20]}")
        print(f"Unique values: {unique_vals}")
        print(f"Number of unique values: {len(unique_vals)}")
        print(f"Min: {np.min(gene_expr):.4f}, Max: {np.max(gene_expr):.4f}")
        print(f"Mean: {_mean:.4f}, Std: {_std:.4f}")
        print(f"CV: {_std / (_mean + 1e-8):.4f}")

        # Zero vs non-zero spot counts.
        zeros = np.sum(gene_expr == 0)
        non_zeros = np.sum(gene_expr != 0)
        print(f"Zeros: {zeros}, Non-zeros: {non_zeros}")

        # First ten distinct values in ascending value order (np.unique sorts
        # by value, not by frequency), each with its spot count.
        print("Value distribution:")
        for val, count in zip(unique_vals[:10], counts[:10]):
            print(f"  {val:.4f}: {count} spots")
        if len(unique_vals) > 10:
            print(f"  ... and {len(unique_vals)-10} more unique values")

# Check whether low-cardinality expression is a pattern across genes in
# general (first 100 shared genes), not only among the top HVGs.
print(f"\n\nCHECKING OVERALL GENE EXPRESSION PATTERNS")
print("="*50)

for dataset_idx, (expr, name) in enumerate(zip(expr_matrices, names)):
    print(f"\n{name} - Overall statistics:")

    # Number of genes actually inspected; used as the denominator in the
    # report so the fractions stay correct when fewer than 100 genes exist.
    n_checked = min(100, expr.shape[1])

    genes_with_few_values = 0   # genes with at most 3 distinct values
    genes_with_binary = 0       # genes with exactly 2 distinct values

    cv_values = []

    for gene_idx in range(n_checked):
        gene_expr = expr[:, gene_idx]
        unique_vals = np.unique(gene_expr)
        cv = np.std(gene_expr) / (np.mean(gene_expr) + 1e-8)
        cv_values.append(cv)

        if len(unique_vals) <= 3:
            genes_with_few_values += 1
        if len(unique_vals) == 2:
            genes_with_binary += 1

    # Distinct CV values in ascending order; many repeated CVs would show up
    # as a small count here.
    unique_cvs = np.unique(cv_values)
    print(f"  Genes with ≤3 unique values: {genes_with_few_values}/{n_checked}")
    print(f"  Genes with binary expression: {genes_with_binary}/{n_checked}")
    print(f"  Number of unique CV values: {len(unique_cvs)}")
    # np.unique sorts by value, so this slice is the ten smallest CVs
    # (the previous label called them the "most common" ones).
    print(f"  Smallest unique CV values: {unique_cvs[:10]}")

    # Range and mean of the CVs over the inspected genes.
    cv_values = np.array(cv_values)
    print(f"  CV range: {np.min(cv_values):.4f} to {np.max(cv_values):.4f}")
    print(f"  CV mean: {np.mean(cv_values):.4f}")

# Histogram of the coefficient of variation over the first 200 shared genes,
# one panel per replicate, with a log-scaled count axis.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for dataset_idx, (expr, name) in enumerate(zip(expr_matrices, names)):
    cv_values = []
    for gene_idx in range(min(200, expr.shape[1])):
        gene_expr = expr[:, gene_idx]
        cv = np.std(gene_expr) / (np.mean(gene_expr) + 1e-8)
        cv_values.append(cv)

    _panel = axes[dataset_idx]
    _panel.hist(cv_values, bins=50, alpha=0.7)
    _panel.set_title(f'{name}\nCV Distribution')
    _panel.set_xlabel('Coefficient of Variation')
    _panel.set_ylabel('Number of Genes')
    _panel.set_yscale('log')  # rare CV bins stay visible next to dominant ones

plt.tight_layout()
plt.show()

In [None]:
def load_and_process_cscc_data_p10():
    """Read the Patient-10 cSCC single-cell and ST files and preprocess them.

    Every dataset is passed through ``sc.pp.normalize_total`` and
    ``sc.pp.log1p``. The single-cell data additionally gets an
    ``obs['rough_celltype']`` column derived from ``level1_celltype`` and
    ``level2_celltype``.

    Returns:
        Tuple ``(scadata_p10, stadata1_p10, stadata2_p10, stadata3_p10)``.
    """
    print("Loading cSCC data...")

    # Single-cell reference.
    scadata_p10 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/scP10.h5ad')

    # Three ST replicates.
    stadata1_p10 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP10rep1.h5ad')
    stadata2_p10 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP10rep2.h5ad')
    stadata3_p10 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP10rep3.h5ad')

    # Library-size normalisation followed by log1p, applied in place.
    for adata in (scadata_p10, stadata1_p10, stadata2_p10, stadata3_p10):
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)

    # Coarse labels start as a string copy of level1 and are then overwritten
    # for the level1 categories listed below.
    obs = scadata_p10.obs
    obs['rough_celltype'] = obs['level1_celltype'].astype(str)
    level1_to_rough = {
        'CLEC9A': 'DC',
        'CD1C': 'DC',
        'ASDC': 'DC',
        'PDC': 'PDC',
        'MDSC': 'DC',
        'LC': 'DC',
        'Mac': 'Myeloid cell',
        'Tcell': 'T cell',
    }
    for level1_label, rough_label in level1_to_rough.items():
        obs.loc[obs['level1_celltype'] == level1_label, 'rough_celltype'] = rough_label

    # Level2 assignments come last so they override the level1-derived label.
    obs.loc[obs['level2_celltype'] == 'TSK', 'rough_celltype'] = 'TSK'
    non_tsk_tumor = obs['level2_celltype'].isin(['Tumor_KC_Basal', 'Tumor_KC_Diff', 'Tumor_KC_Cyc'])
    obs.loc[non_tsk_tumor, 'rough_celltype'] = 'NonTSK'

    return scadata_p10, stadata1_p10, stadata2_p10, stadata3_p10

def prepare_combined_st_for_diffusion(stadata1, stadata2, stadata3, scadata):
    """Build gene-aligned, standardized SC and pooled-ST inputs.

    The genes shared by the single-cell data and all three ST datasets are
    selected (sorted order), the three ST expression matrices are stacked and
    z-scored together, and the SC matrix is z-scored on its own statistics.

    Args:
        stadata1, stadata2, stadata3: ST AnnData objects with ``obsm['spatial']``.
        scadata: Single-cell AnnData object.

    Returns:
        Tuple ``(X_sc, X_st_combined, Y_st_combined, dataset_labels,
        common_genes, st_coords_list)`` where the first two are float32 torch
        tensors, ``Y_st_combined`` is a float32 NumPy array of stacked
        coordinates, ``dataset_labels`` tags every stacked ST row with
        'dataset1'/'dataset2'/'dataset3', and ``st_coords_list`` holds the
        per-dataset coordinate arrays.
    """
    print("Preparing combined ST data for diffusion training...")

    st_adatas = (stadata1, stadata2, stadata3)

    # Intersect gene names across SC and every ST dataset.
    shared = set(scadata.var_names)
    for st_adata in st_adatas:
        shared &= set(st_adata.var_names)
    common_genes = sorted(shared)
    print(f"Common genes across all datasets: {len(common_genes)}")

    def _dense_expression(adata):
        # Subset to the shared genes and densify when the matrix is sparse.
        matrix = adata[:, common_genes].X
        return matrix.toarray() if hasattr(matrix, 'toarray') else matrix

    sc_expr = _dense_expression(scadata)
    st_exprs = [_dense_expression(st_adata) for st_adata in st_adatas]

    # Per-dataset coordinates are kept separately as well as stacked.
    st_coords_list = [st_adata.obsm['spatial'] for st_adata in st_adatas]

    from sklearn.preprocessing import StandardScaler

    # ST: one scaler fitted on the pooled spots of all three datasets.
    st_expr_combined = StandardScaler().fit_transform(np.vstack(st_exprs))
    st_coords_combined = np.vstack(st_coords_list)

    # SC: standardized with its own mean/std, independent of the ST fit.
    sc_expr = StandardScaler().fit_transform(sc_expr)

    # One label per stacked ST row, in stacking order.
    dataset_labels = []
    for tag, block in zip(('dataset1', 'dataset2', 'dataset3'), st_exprs):
        dataset_labels.extend([tag] * len(block))

    print(f"Combined ST data shape: {st_expr_combined.shape}")
    print(f"Combined ST coords shape: {st_coords_combined.shape}")
    print(f"SC data shape: {sc_expr.shape}")

    # Expression goes to torch tensors; coordinates stay a NumPy array.
    X_sc = torch.tensor(sc_expr, dtype=torch.float32)
    X_st_combined = torch.tensor(st_expr_combined, dtype=torch.float32)
    Y_st_combined = st_coords_combined.astype(np.float32)

    return X_sc, X_st_combined, Y_st_combined, dataset_labels, common_genes, st_coords_list

# Read and preprocess the Patient-10 SC + ST data.
scadata_p10, stadata1_p10, stadata2_p10, stadata3_p10 = load_and_process_cscc_data_p10()

# Gene-align, standardize and pool the ST replicates for diffusion training.
(
    X_sc,
    X_st_combined,
    Y_st_combined,
    dataset_labels,
    common_genes,
    st_coords_list,
) = prepare_combined_st_for_diffusion(stadata1_p10, stadata2_p10, stadata3_p10, scadata_p10)

print("Data preparation complete!")
print(f"SC cells: {X_sc.shape[0]}")
print(f"Combined ST spots: {X_st_combined.shape[0]}")
print(f"Common genes: {len(common_genes)}")



In [None]:
# Add this code block to your notebook BEFORE the training loop

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from scipy import stats
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import torch

def analyze_sc_st_patterns(model, n_genes=20, figsize=(20, 15)):
    """Draw a multi-panel SC-vs-ST comparison figure for a model's data.

    Reads ``model.sc_gene_expr`` and ``model.st_gene_expr``, selects genes that
    are variable in both, and fills one figure with five groups of panels
    (distributions, mean-expression correlation, joint t-SNE, spatial
    expression, encoder alignment) before showing it.

    Args:
        model: Object exposing ``sc_gene_expr`` and ``st_gene_expr`` (tensor or
            array) plus whatever the panel helpers read from it.
        n_genes: Number of variable genes to select.
        figsize: Size of the figure in inches.

    Returns:
        The selected gene column indices, as returned by
        ``find_common_variable_genes``.
    """
    sc_expr, st_expr = model.sc_gene_expr, model.st_gene_expr

    # The helpers below work on NumPy arrays; move tensors to host memory.
    sc_expr = sc_expr.cpu().numpy() if torch.is_tensor(sc_expr) else sc_expr
    st_expr = st_expr.cpu().numpy() if torch.is_tensor(st_expr) else st_expr

    common_genes = find_common_variable_genes(sc_expr, st_expr, n_genes)

    fig = plt.figure(figsize=figsize)

    # Panel groups, in drawing order; the trailing integer is the group number.
    plot_expression_distributions(sc_expr, st_expr, common_genes, fig, 1)  # violin plots
    plot_sc_st_correlation(sc_expr, st_expr, common_genes, fig, 2)         # SC vs ST means
    plot_joint_embedding(sc_expr, st_expr, fig, 3)                         # joint t-SNE
    plot_spatial_expression_patterns(model, common_genes, fig, 4)          # ST spatial maps
    plot_domain_alignment(model, fig, 5)                                   # encoder embeddings

    plt.tight_layout()
    plt.show()

    return common_genes

def find_common_variable_genes(sc_expr, st_expr, n_genes=20):
    """Select the genes with the highest combined SC + ST variability.

    Variability is the coefficient of variation (std / mean) per gene, summed
    over the two datasets. A gene is eligible only if it is positive
    (``> 0``) in 10%-90% of the cells and of the spots.

    Ineligible genes are ranked below every eligible gene (score ``-inf``).
    They are returned only as padding when fewer than ``n_genes`` genes are
    eligible, in which case a note is printed.

    Args:
        sc_expr: 2-D NumPy array (cells x genes).
        st_expr: 2-D NumPy array (spots x genes) with the same gene columns.
        n_genes: Number of gene indices to return (capped at the gene count).

    Returns:
        1-D integer array of column indices, ordered from lowest to highest
        combined score (the most variable gene is last).
    """
    # Coefficient of variation per gene; 1e-8 avoids division by exactly zero.
    sc_cv = np.std(sc_expr, axis=0) / (np.mean(sc_expr, axis=0) + 1e-8)
    st_cv = np.std(st_expr, axis=0) / (np.mean(st_expr, axis=0) + 1e-8)

    # Fraction of cells / spots in which each gene is positive.
    sc_nonzero_frac = (sc_expr > 0).mean(0)
    st_nonzero_frac = (st_expr > 0).mean(0)

    # Exclude genes that are almost never or almost always positive.
    valid_mask = (sc_nonzero_frac >= 0.1) & (sc_nonzero_frac <= 0.9) & \
                 (st_nonzero_frac >= 0.1) & (st_nonzero_frac <= 0.9)

    # -inf (rather than 0) keeps ineligible genes below eligible genes whose
    # CV is negative, which happens whenever a gene's mean is negative
    # (e.g. for standardized input).
    combined_cv = np.where(valid_mask, sc_cv + st_cv, -np.inf)

    order = np.argsort(combined_cv)
    # Explicit start index: a plain `[-n_genes:]` would return every gene for n_genes == 0.
    n_selected = min(max(int(n_genes), 0), order.shape[0])
    top_gene_indices = order[order.shape[0] - n_selected:]

    n_valid = int(np.sum(valid_mask))
    if n_valid < n_selected:
        print(f"Only {n_valid} genes passed the expression-fraction filter; "
              f"the remaining {n_selected - n_valid} selected genes did not")

    print(f"Selected {len(top_gene_indices)} variable genes for analysis")

    return top_gene_indices

def plot_expression_distributions(sc_expr, st_expr, gene_indices, fig, subplot_num):
    """Draw SC-vs-ST violin plots for up to ten genes in the top band of ``fig``.

    The violins are laid out as 2 rows x 5 columns inside the top fifth of the
    figure. The other panel helpers in this notebook place their axes on a
    5-row grid starting at row 2 (``add_subplot(5, 2, 3)`` and
    ``add_subplot(5, 2, 4)``), so the violins must stay within row 1 of that
    grid; the former 5x5 placement put violins 6-10 into row 2, on top of
    those panels.

    Args:
        sc_expr: 2-D array (cells x genes).
        st_expr: 2-D array (spots x genes).
        gene_indices: Column indices of the genes to plot; the first ten are used.
        fig: Matplotlib figure to add the axes to.
        subplot_num: Panel-group number passed by the caller; not used here.
    """
    n_genes = min(len(gene_indices), 10)  # Limit to 10 genes for clarity
    n_cols = 5
    n_rows = 2
    # 5 figure bands x n_rows violin rows per band: positions 1..10 of this
    # finer grid cover exactly the first band.
    grid_rows = 5 * n_rows

    for i in range(n_genes):
        gene_idx = gene_indices[i]
        ax = fig.add_subplot(grid_rows, n_cols, i + 1)

        # Long-form frame: one row per observation, tagged with its source.
        sc_values = sc_expr[:, gene_idx]
        st_values = st_expr[:, gene_idx]

        data_df = pd.DataFrame({
            'Expression': np.concatenate([sc_values, st_values]),
            'Data_Type': ['SC'] * len(sc_values) + ['ST'] * len(st_values)
        })

        sns.violinplot(data=data_df, x='Data_Type', y='Expression', ax=ax)
        ax.set_title(f'Gene {gene_idx}', fontsize=8)
        ax.set_xlabel('')

        # Keep the y-label only on the first panel of each violin row.
        if i % n_cols != 0:
            ax.set_ylabel('')

def plot_sc_st_correlation(sc_expr, st_expr, gene_indices, fig, subplot_num):
    """Scatter per-gene mean expression in SC against ST and report Pearson r.

    The panel is added at position 3 of a 5x2 grid on ``fig``.

    Args:
        sc_expr: 2-D array (cells x genes).
        st_expr: 2-D array (spots x genes).
        gene_indices: Column indices of the genes to compare.
        fig: Matplotlib figure to add the axes to.
        subplot_num: Panel-group number passed by the caller; not used here.
    """
    ax = fig.add_subplot(5, 2, 3)

    # One point per selected gene: its mean over cells vs its mean over spots.
    sc_means = np.mean(sc_expr[:, gene_indices], axis=0)
    st_means = np.mean(st_expr[:, gene_indices], axis=0)
    ax.scatter(sc_means, st_means, alpha=0.7)

    corr, p_val = stats.pearsonr(sc_means, st_means)
    ax.set_title(f'SC vs ST Mean Expression\nCorr = {corr:.3f}, p = {p_val:.3e}')
    ax.set_xlabel('SC Mean Expression')
    ax.set_ylabel('ST Mean Expression')

    # Identity line from the origin to the larger of the two current axis limits.
    upper = max(ax.get_xlim()[1], ax.get_ylim()[1])
    ax.plot([0, upper], [0, upper], 'r--', alpha=0.5)

def plot_joint_embedding(sc_expr, st_expr, fig, subplot_num):
    """Embed a random subsample of SC cells and ST spots together with t-SNE.

    At most 2000 cells and 500 spots are drawn without replacement using
    NumPy's global random state; the t-SNE itself uses ``random_state=42``.
    The panel is added at position 4 of a 5x2 grid on ``fig``.

    Args:
        sc_expr: 2-D array (cells x genes).
        st_expr: 2-D array (spots x genes) with the same gene columns.
        fig: Matplotlib figure to add the axes to.
        subplot_num: Panel-group number passed by the caller; not used here.
    """
    ax = fig.add_subplot(5, 2, 4)

    # Subsample to keep t-SNE fast (SC first, then ST).
    n_sample = min(2000, sc_expr.shape[0])
    sc_sample_idx = np.random.choice(sc_expr.shape[0], n_sample, replace=False)

    n_st_sample = min(500, st_expr.shape[0])
    st_sample_idx = np.random.choice(st_expr.shape[0], n_st_sample, replace=False)

    # SC rows come first in the stacked matrix, ST rows after them.
    stacked = np.vstack([sc_expr[sc_sample_idx], st_expr[st_sample_idx]])
    embedding = TSNE(n_components=2, random_state=42, perplexity=30).fit_transform(stacked)

    # Row membership follows directly from the stacking order.
    is_sc = np.arange(n_sample + n_st_sample) < n_sample
    for label, color, mask in (('SC', 'blue', is_sc), ('ST', 'red', ~is_sc)):
        ax.scatter(embedding[mask, 0], embedding[mask, 1],
                   c=color, alpha=0.6, s=20, label=label)

    ax.set_title('Joint t-SNE: SC vs ST')
    ax.legend()
    ax.set_xlabel('t-SNE 1')
    ax.set_ylabel('t-SNE 2')

def plot_spatial_expression_patterns(model, gene_indices, fig, subplot_num):
    """Draw spatial expression maps of the ST spots for up to four genes.

    Args:
        model: object exposing ``st_gene_expr`` (indexed as ``[:, gene]``) and
            ``st_coords`` (at least two columns); each may be a torch tensor
            or an array that supports the same indexing.
        gene_indices: sequence of gene column indices. Genes are taken from
            the END of the sequence, so the caller is presumably expected to
            sort it by increasing variability - TODO confirm.
        fig: figure that receives the panels; they occupy slots 9-12 of a
            5x4 grid.
        subplot_num: accepted for signature compatibility; not used here.
    """

    def _as_numpy(values):
        # detach() lets tensors that still track gradients be exported;
        # non-tensor inputs are passed through untouched.
        if torch.is_tensor(values):
            return values.detach().cpu().numpy()
        return values

    n_genes_show = min(4, len(gene_indices))
    if n_genes_show == 0:
        return

    # The spot coordinates are the same for every gene: convert them once.
    coords = _as_numpy(model.st_coords)

    for i in range(n_genes_show):
        ax = fig.add_subplot(5, 4, 8 + i + 1)
        gene_idx = gene_indices[-(i + 1)]  # walk the list from its last entry

        expr_values = _as_numpy(model.st_gene_expr[:, gene_idx])

        # Spatial scatter plot coloured by expression
        scatter = ax.scatter(coords[:, 0], coords[:, 1],
                             c=expr_values, cmap='viridis',
                             s=30, alpha=0.8)

        ax.set_title(f'ST Spatial: Gene {gene_idx}', fontsize=8)
        ax.set_aspect('equal')
        plt.colorbar(scatter, ax=ax, shrink=0.8)

def plot_domain_alignment(model, fig, subplot_num):
    """Show how well SC and ST data overlap in the encoder's embedding space.

    Args:
        model: object exposing ``sc_gene_expr``, ``st_gene_expr``, ``device``
            and a callable encoder ``netE``.
        fig: figure that receives the panel; always the fifth row of a 5x1 grid.
        subplot_num: accepted for signature compatibility; not used here.

    Both expression matrices are encoded without gradient tracking, at most
    1000 SC and 300 ST embeddings are drawn without replacement from NumPy's
    global random state (SC first), and the sample is projected to 2-D by PCA.
    """
    ax = fig.add_subplot(5, 1, 5)

    def _on_device(expr):
        # Tensors are used as they are; anything else is copied to the model device.
        return expr if torch.is_tensor(expr) else torch.tensor(expr, device=model.device)

    with torch.no_grad():
        sc_embedding = model.netE(_on_device(model.sc_gene_expr)).cpu().numpy()
        st_embedding = model.netE(_on_device(model.st_gene_expr)).cpu().numpy()

    # Subsample for visualization
    n_sample = min(1000, sc_embedding.shape[0])
    sc_idx = np.random.choice(sc_embedding.shape[0], n_sample, replace=False)
    st_idx = np.random.choice(st_embedding.shape[0], min(300, st_embedding.shape[0]), replace=False)

    # SC rows come first in the stacked matrix, ST rows after them.
    pca = PCA(n_components=2)
    projected = pca.fit_transform(np.vstack([sc_embedding[sc_idx], st_embedding[st_idx]]))

    n_sc = len(sc_idx)
    for label, color, rows in (('SC', 'blue', slice(None, n_sc)), ('ST', 'red', slice(n_sc, None))):
        ax.scatter(projected[rows, 0], projected[rows, 1],
                   c=color, alpha=0.6, s=20, label=label)

    ax.set_title('Domain Alignment (Encoder Embeddings)')
    ax.legend()
    explained = pca.explained_variance_ratio_
    ax.set_xlabel(f'PC1 ({explained[0]:.2f})')
    ax.set_ylabel(f'PC2 ({explained[1]:.2f})')

# Progress marker: printed once the cell defining the plotting helpers above has run.
print("Analysis functions loaded successfully!")

In [None]:
def train_individual_advanced_diffusion_models(scadata, stadata1, stadata2, stadata3):
    """Train one AdvancedHierarchicalDiffusion model per ST dataset and average the results.

    A separate model is trained on each of ``stadata1``, ``stadata2`` and
    ``stadata3`` (seeded with 42, 43 and 44), SC coordinates are sampled from
    every model, and per-run plus cross-run angular diagnostics are plotted to
    check for SC cluster rotation/sliding between runs.

    Args:
        scadata: SC dataset exposing ``var_names``, ``X``, ``obs['rough_celltype']``
            and ``obsm``. It is modified in place.
        stadata1, stadata2, stadata3: ST datasets exposing ``var_names``, ``X``
            and ``obsm['spatial']``.

    Returns:
        tuple: ``(scadata, models_all)``. ``scadata.obsm`` gains
        ``advanced_diffusion_coords_avg``, ``advanced_diffusion_coords_std`` and
        one ``advanced_diffusion_coords_rep{k}`` entry per run; ``models_all``
        holds the trained models in run order.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Store results from each model
    sc_coords_results = []
    models_all = []

    # One run per ST dataset
    st_datasets = [
        (stadata1, "run1"),
        (stadata2, "run2"),
        (stadata3, "run3")
    ]
    n_runs = len(st_datasets)
    cell_types = scadata.obs['rough_celltype'].values

    for i, (stadata, run_name) in enumerate(st_datasets):
        print(f"\n{'='*50}")
        print(f"Training AdvancedHierarchicalDiffusion model {i+1}/{n_runs} for {run_name}")
        print(f"{'='*50}")

        # Get common genes between SC and current ST dataset
        sc_genes = set(scadata.var_names)
        st_genes = set(stadata.var_names)
        common_genes = sorted(list(sc_genes & st_genes))

        print(f"Common genes for {run_name}: {len(common_genes)}")

        # Extract expression data
        sc_expr = scadata[:, common_genes].X
        st_expr = stadata[:, common_genes].X

        # Convert to dense if sparse
        if hasattr(sc_expr, 'toarray'):
            sc_expr = sc_expr.toarray()
        if hasattr(st_expr, 'toarray'):
            st_expr = st_expr.toarray()

        # Get spatial coordinates
        st_coords = stadata.obsm['spatial']

        print(f"SC data shape: {sc_expr.shape}")
        print(f"ST data shape: {st_expr.shape}")
        print(f"ST coords shape: {st_coords.shape}")

        # Initialize model with different random seed for each run
        torch.manual_seed(42 + i)
        np.random.seed(42 + i)

        model = AdvancedHierarchicalDiffusion(
            st_gene_expr=st_expr,
            st_coords=st_coords,
            sc_gene_expr=sc_expr,
            cell_types_sc=cell_types,
            transport_plan=None,
            D_st=None,
            D_induced=None,
            n_genes=len(common_genes),
            n_embedding=[512, 256, 128],
            coord_space_diameter=2.00,
            sigma=2.0,
            alpha=0.8,
            mmdbatch=1000,
            batch_size=256,
            device=device,
            lr_e=0.0001,
            lr_d=0.0002,
            n_timesteps=800,
            n_denoising_blocks=6,
            hidden_dim=256,
            num_heads=8,
            num_hierarchical_scales=3,
            dp=0.2,
            outf=f'advanced_diffusion_{run_name}'
        )

        # Train the model
        print(f"Training model for {run_name}...")
        model.train(
            encoder_epochs=1000,
            vae_epochs=1500,
            diffusion_epochs=3000,
            lambda_struct=10.0
        )

        # STEP 1: Build the angular frame from this run's ST coordinates as
        # normalized by the model.
        st_coords_raw = model.st_coords_norm.cpu().numpy()
        angular_frame = _build_canonical_angular_frame(st_coords_raw)

        # Generate SC coordinates
        print(f"Generating SC coordinates using {run_name} model...")
        sc_coords = model.sample_sc_coordinates_batched(
            batch_size=512,
            refine_coords=False
        )
        sc_coords_results.append(sc_coords)
        models_all.append(model)

        # STEP 2: Plot SC cells colored by angle (using this run's ST-derived frame)
        _plot_sc_angle_analysis(sc_coords, cell_types,
                                angular_frame, st_coords_raw, run_name, i+1)

    # STEP 3: Comparative analysis across runs.
    # NOTE: every run is measured against ONE shared frame, namely the frame
    # (and ST background) of the last run trained above.
    reference_frame, reference_st_coords = angular_frame, st_coords_raw
    _plot_comparative_sc_angle_analysis(sc_coords_results, cell_types,
                                        reference_frame, reference_st_coords)

    # Compute averaged SC coordinates (element-wise over runs)
    sc_coords_avg = np.mean(sc_coords_results, axis=0)
    sc_coords_std = np.std(sc_coords_results, axis=0)

    # Store results in scadata
    scadata.obsm['advanced_diffusion_coords_avg'] = sc_coords_avg
    scadata.obsm['advanced_diffusion_coords_std'] = sc_coords_std

    # Store individual results
    for i, coords in enumerate(sc_coords_results):
        scadata.obsm[f'advanced_diffusion_coords_rep{i+1}'] = coords

    print("\nTraining complete. Results stored in scadata.obsm")
    return scadata, models_all

def _build_canonical_angular_frame(st_coords):
    """Build canonical angular frame from ST coordinates (dataset-specific, run-independent)"""
    import numpy as np
    
    # Compute centroid
    centroid = st_coords.mean(axis=0)
    
    # Find farthest spot from centroid (deterministic 0° direction)
    distances = np.linalg.norm(st_coords - centroid, axis=1)
    farthest_idx = np.argmax(distances)
    a0 = st_coords[farthest_idx] - centroid  # 0° direction vector
    
    def angle_fn(x):
        """Compute angle from canonical frame"""
        if x.ndim == 1:
            x = x.reshape(1, -1)
        
        v = x - centroid
        cross = a0[0] * v[:, 1] - a0[1] * v[:, 0]  # z-component of 2D cross
        dot = a0[0] * v[:, 0] + a0[1] * v[:, 1]
        angles = np.arctan2(cross, dot)
        angles = np.where(angles < 0, angles + 2*np.pi, angles)  # Map to [0, 2π)
        return angles
    
    return {
        'centroid': centroid,
        'zero_direction': a0,
        'farthest_idx': farthest_idx,
        'angle_fn': angle_fn
    }

def _plot_sc_angle_analysis(sc_coords, cell_types, angular_frame, st_coords_bg, run_name, run_num):
    """Plot SC cells colored by angle from ST-derived frame"""
    import matplotlib.pyplot as plt
    import numpy as np
    
    # Compute angles for SC cells using ST-derived frame
    sc_angles = angular_frame['angle_fn'](sc_coords)
    sc_angles_degrees = np.degrees(sc_angles)
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    
    # Plot 1: SC cells colored by angle (with ST outline in background)
    ax1.scatter(st_coords_bg[:, 0], st_coords_bg[:, 1], 
               c='lightgray', s=10, alpha=0.3, label='ST outline')
    
    scatter = ax1.scatter(sc_coords[:, 0], sc_coords[:, 1], 
                         c=sc_angles_degrees, cmap='hsv', s=30, alpha=0.8)
    
    # Mark centroid and 0° direction
    centroid = angular_frame['centroid']
    zero_dir = angular_frame['zero_direction']
    ax1.scatter(centroid[0], centroid[1], c='black', s=100, marker='x', linewidth=3)
    ax1.arrow(centroid[0], centroid[1], zero_dir[0]*0.3, zero_dir[1]*0.3, 
              head_width=0.05, head_length=0.05, fc='red', ec='red', linewidth=2)
    
    ax1.set_title(f'{run_name}: SC Cells Colored by Angle θ')
    ax1.set_xlabel('X coordinate')
    ax1.set_ylabel('Y coordinate')
    ax1.set_aspect('equal')
    ax1.legend()
    
    # Add colorbar
    cbar = plt.colorbar(scatter, ax=ax1)
    cbar.set_label('Angle (degrees)')
    
    # Plot 2: Per-cell-type angle distribution
    unique_types = np.unique(cell_types)
    colors = plt.cm.tab10(np.linspace(0, 1, len(unique_types)))
    
    for i, cell_type in enumerate(unique_types):
        mask = cell_types == cell_type
        if np.sum(mask) > 0:
            angles_subset = sc_angles_degrees[mask]
            ax2.hist(angles_subset, bins=36, alpha=0.6, label=cell_type, 
                    color=colors[i], density=True)
    
    ax2.set_title(f'{run_name}: Angle Distribution by Cell Type')
    ax2.set_xlabel('Angle (degrees)')
    ax2.set_ylabel('Density')
    ax2.set_xlim(0, 360)
    ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    
    plt.tight_layout()
    plt.savefig(f'sc_angle_analysis_{run_name}.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    # Print circular statistics per cell type
    print(f"\n{run_name} - Circular statistics per cell type:")
    for cell_type in unique_types:
        mask = cell_types == cell_type
        if np.sum(mask) > 5:  # Only if enough cells
            angles_rad = sc_angles[mask]
            # Circular mean
            mean_cos = np.mean(np.cos(angles_rad))
            mean_sin = np.mean(np.sin(angles_rad))
            circular_mean = np.arctan2(mean_sin, mean_cos)
            if circular_mean < 0:
                circular_mean += 2*np.pi
            
            print(f"  {cell_type}: mean={np.degrees(circular_mean):.1f}°, n={np.sum(mask)}")

def _plot_comparative_sc_angle_analysis(sc_coords_list, cell_types, angular_frame, st_coords_bg):
    """Plot comparative SC angle analysis across all runs"""
    import matplotlib.pyplot as plt
    import numpy as np
    
    n_runs = len(sc_coords_list)
    unique_types = np.unique(cell_types)
    
    fig, axes = plt.subplots(2, n_runs, figsize=(5*n_runs, 10))
    if n_runs == 1:
        axes = axes.reshape(-1, 1)
    
    # Top row: SC scatter plots per run
    for i, sc_coords in enumerate(sc_coords_list):
        ax = axes[0, i]
        
        # ST background
        ax.scatter(st_coords_bg[:, 0], st_coords_bg[:, 1], 
                  c='lightgray', s=5, alpha=0.3)
        
        # SC cells colored by angle
        sc_angles = angular_frame['angle_fn'](sc_coords)
        sc_angles_degrees = np.degrees(sc_angles)
        
        scatter = ax.scatter(sc_coords[:, 0], sc_coords[:, 1], 
                           c=sc_angles_degrees, cmap='hsv', s=20, alpha=0.8)
        
        ax.set_title(f'Run {i+1}: SC Cells by Angle')
        ax.set_aspect('equal')
        
        if i == n_runs-1:  # Add colorbar to last plot
            cbar = plt.colorbar(scatter, ax=ax)
            cbar.set_label('Angle (degrees)')
    
    # Bottom row: Cell type angle distributions per run  
    colors = plt.cm.tab10(np.linspace(0, 1, len(unique_types)))
    
    for i, sc_coords in enumerate(sc_coords_list):
        ax = axes[1, i]
        
        sc_angles = angular_frame['angle_fn'](sc_coords)
        sc_angles_degrees = np.degrees(sc_angles)
        
        for j, cell_type in enumerate(unique_types):
            mask = cell_types == cell_type
            if np.sum(mask) > 5:
                angles_subset = sc_angles_degrees[mask]
                ax.hist(angles_subset, bins=36, alpha=0.6, 
                       label=cell_type if i == 0 else "", 
                       color=colors[j], density=True)
        
        ax.set_title(f'Run {i+1}: Cell Type Angles')
        ax.set_xlabel('Angle (degrees)')
        ax.set_ylabel('Density')
        ax.set_xlim(0, 360)
        
        if i == 0:
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    
    plt.tight_layout()
    plt.savefig('comparative_sc_angle_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    # Check for sector sliding
    print(f"\n" + "="*60)
    print("SECTOR SLIDING ANALYSIS")
    print("="*60)
    
    for cell_type in unique_types:
        mask = cell_types == cell_type
        if np.sum(mask) > 10:  # Only analyze cell types with enough cells
            circular_means = []
            
            for i, sc_coords in enumerate(sc_coords_list):
                sc_angles = angular_frame['angle_fn'](sc_coords)
                angles_subset = sc_angles[mask]
                
                # Circular mean
                mean_cos = np.mean(np.cos(angles_subset))
                mean_sin = np.mean(np.sin(angles_subset))
                circular_mean = np.arctan2(mean_sin, mean_cos)
                if circular_mean < 0:
                    circular_mean += 2*np.pi
                
                circular_means.append(np.degrees(circular_mean))
            
            # Check for large differences between runs
            max_diff = max(circular_means) - min(circular_means)
            if max_diff > 180:  # Handle wraparound
                max_diff = 360 - max_diff
            
            print(f"{cell_type}:")
            print(f"  Run means: {[f'{m:.1f}°' for m in circular_means]}")
            print(f"  Max difference: {max_diff:.1f}°")
            
            if max_diff > 30:  # Significant sliding
                print(f"  ⚠️  SECTOR SLIDING DETECTED!")
            else:
                print(f"  ✅ Consistent placement")

In [None]:
# Load and process data
# load_and_process_cscc_data_p10 is defined elsewhere in this notebook; its four
# return values are taken here as one SC dataset followed by three ST datasets.
scadata_p10, stadata1_p10, stadata2_p10, stadata3_p10 = load_and_process_cscc_data_p10()

# Train individual AdvancedHierarchicalDiffusion models and get averaged results.
# scadata_p10 is rebound to the returned dataset (which carries the
# 'advanced_diffusion_coords_*' obsm entries); advanced_models_p10 is the list
# of trained models, one per ST dataset.
scadata_p10, advanced_models_p10 = train_individual_advanced_diffusion_models(
    scadata_p10, stadata1_p10, stadata2_p10, stadata3_p10
)

In [None]:
# Visualize results
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl

mpl.rcParams['figure.figsize'] = (4, 4)

my_tab20 = sns.color_palette("tab20", n_colors=20).as_hex()

# Each plot gets an explicit Axes that is handed to scanpy. Without `ax=`,
# sc.pl.embedding opens its own figure, so a preceding plt.figure(figsize=...)
# only produces an empty figure and the requested size is never applied.

# Plot 1: Averaged coordinates
fig, ax = plt.subplots(figsize=(8, 6))
sc.pl.embedding(scadata_p10, basis='advanced_diffusion_coords_avg', color='rough_celltype',
               size=85, title='SC Advanced Diffusion Coords (Averaged)',
               palette=my_tab20, legend_loc='right margin', legend_fontsize=10,
               ax=ax, show=False)
plt.show()

# Plot 2: Individual model results (one figure per trained model)
for i in range(3):
    fig, ax = plt.subplots(figsize=(6, 5))
    sc.pl.embedding(scadata_p10, basis=f'advanced_diffusion_coords_rep{i+1}', color='rough_celltype',
                   size=85, title=f'SC Coordinates (Advanced Model {i+1})',
                   palette=my_tab20, legend_loc='right margin', legend_fontsize=10,
                   ax=ax, show=False)
    plt.show()

In [None]:
# Visualize results
# NOTE(review): this cell repeats the plots of the preceding cell (averaged
# coordinates, then one plot per model); consider removing one of the two.
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl

# Default figure size for figures created without an explicit figsize
mpl.rcParams['figure.figsize'] = (4, 4)

# 20-colour tab20 palette as hex strings, used for the cell-type colours
my_tab20 = sns.color_palette("tab20", n_colors=20).as_hex()

# Plot 1: Averaged coordinates
# NOTE(review): sc.pl.embedding is not given an `ax`, so it presumably draws on
# its own figure and the plt.figure(...) calls below only leave empty figures
# behind - confirm, and pass ax= if the figsize is meant to apply.
plt.figure(figsize=(8, 6))
sc.pl.embedding(scadata_p10, basis='advanced_diffusion_coords_avg', color='rough_celltype',
               size=85, title='SC Advanced Diffusion Coords (Averaged)',
               palette=my_tab20, legend_loc='right margin', legend_fontsize=10)
plt.show()

# Plot 2: Individual model results (obsm keys ..._rep1 to ..._rep3)
for i in range(3):
    plt.figure(figsize=(6, 5))
    sc.pl.embedding(scadata_p10, basis=f'advanced_diffusion_coords_rep{i+1}', color='rough_celltype',
                   size=85, title=f'SC Coordinates (Advanced Model {i+1})',
                   palette=my_tab20, legend_loc='right margin', legend_fontsize=10)
    plt.show()

In [None]:
# Inspect the averaged SC coordinates stored in obsm; as a bare expression the
# whole array becomes the cell output.
scadata_p10.obsm['advanced_diffusion_coords_avg']

In [None]:
# Visualize results with separate plots
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns  # imported before first use so the cell runs on a fresh kernel

mpl.rcParams['figure.figsize'] = (4, 4)
# import scanpy as sc
# sc.settings.set_figure_params(figsize=(4,4), dpi=100)

# One palette for all four plots, so every plot colours the cell types identically.
my_tab20 = sns.color_palette("tab20", n_colors=20).as_hex()

# (obsm key, title) for the averaged coordinates and each individual model
embedding_panels = [
    ('advanced_diffusion_coords_avg', 'SC Spatial Coordinates (Averaged from 3 Models)'),
    ('advanced_diffusion_coords_rep1', 'SC Coordinates (Model 1)'),
    ('advanced_diffusion_coords_rep2', 'SC Coordinates (Model 2)'),
    ('advanced_diffusion_coords_rep3', 'SC Coordinates (Model 3)'),
]

for basis_key, panel_title in embedding_panels:
    # Hand scanpy an explicit Axes: without `ax=` it opens its own figure and
    # a separate plt.figure(...) call would only leave an empty figure behind.
    fig, ax = plt.subplots(figsize=(4, 4))
    sc.pl.embedding(scadata_p10, basis=basis_key, color='rough_celltype',
                   size=85, title=panel_title,
                   palette=my_tab20, legend_loc='right margin', legend_fontsize=10,
                   ax=ax, show=False)
    plt.show()

In [None]:
# Add 0/1 indicator columns to obs marking TSK cells (level2), fibroblasts
# (level1) and epithelial cells (rough cell type). This mutates scadata_p10.
scadata_p10.obs['selection'] = (scadata_p10.obs['level2_celltype']=='TSK').astype(int)
scadata_p10.obs['selection2'] = (scadata_p10.obs['level1_celltype']=='Fibroblast').astype(int)
scadata_p10.obs['selection3'] = (scadata_p10.obs['rough_celltype']=='Epithelial').astype(int)

# NOTE(review): sc.pl.spatial below is not given this figure or an `ax`, so the
# figsize here presumably has no effect on the plotted panels - confirm.
plt.figure(figsize=(6, 6))

# One panel per indicator plus level3 cell types, drawn at the averaged reconstructed coordinates
sc.pl.spatial(scadata_p10, color=['selection','selection2','selection3','level3_celltype'], spot_size=0.025,cmap='bwr',basis='advanced_diffusion_coords_avg')

In [None]:
# Highlight only the "TSK" group of level3_celltype at the averaged reconstructed coordinates
sc.pl.spatial(scadata_p10,color="level3_celltype",groups=["TSK"],spot_size=0.03, show=True,basis='advanced_diffusion_coords_avg',title='reconstructed',na_in_legend=False)


In [None]:
import squidpy as sq
# Spatial neighbour graph on the averaged reconstructed coordinates, followed by
# neighbourhood enrichment and the interaction matrix per rough cell type.
sq.gr.spatial_neighbors(scadata_p10,spatial_key='advanced_diffusion_coords_avg')
sq.gr.nhood_enrichment(scadata_p10,cluster_key='rough_celltype')
sq.gr.interaction_matrix(scadata_p10,cluster_key='rough_celltype')
# Restrict to the four tumour keratinocyte populations (copied so the subset
# gets its own graph) and repeat the enrichment at level2 resolution.
kscadata_p10 = scadata_p10[ scadata_p10.obs.level2_celltype.isin(['Tumor_KC_Cyc','Tumor_KC_Basal','Tumor_KC_Diff','TSK'])].copy()
sq.gr.spatial_neighbors(kscadata_p10,spatial_key='advanced_diffusion_coords_avg')
sq.gr.nhood_enrichment(kscadata_p10,cluster_key='level2_celltype')
# NOTE(review): the disabled variants below refer to `kscadata`, which this cell
# never defines (the subset is named kscadata_p10).
# sq.pl.nhood_enrichment(kscadata, cluster_key="level2_celltype",cmap='coolwarm',save='TSKKC_new_best_p10.svg',figsize=(3,5), title=None)
# sq.pl.nhood_enrichment(kscadata, cluster_key="level2_celltype", cmap='coolwarm', save='TSKKC_new_best_p10.svg', figsize=(3,5), ylabel='')
# sq.pl.nhood_enrichment(kscadata, cluster_key="level2_celltype",cmap='coolwarm',figsize=(3,5))

# Draw the enrichment heatmap on an explicit Axes so the y-label can be cleared afterwards
fig, ax = plt.subplots(figsize=(3,5))
sq.pl.nhood_enrichment(kscadata_p10, cluster_key="level2_celltype", cmap='coolwarm', ax=ax)
ax.set_ylabel('')
# plt.savefig('TSKKC_new_best_p10.svg')
plt.show()


In [None]:
# figsize(4,4)
# sc.settings.file_format_figs = 'svg'

# One plot per tumour keratinocyte subtype (cycling, basal, differentiating) at
# the averaged reconstructed coordinates. Each call passes `save=`, so each
# figure is also written to disk with the given suffix.
sc.pl.spatial(scadata_p10,color="level2_celltype",groups=["Tumor_KC_Cyc"],spot_size=0.03, show=True,basis='advanced_diffusion_coords_avg',title='reconstructed',na_in_legend=False,save='P10_cyc')
sc.pl.spatial(scadata_p10,color="level2_celltype",groups=["Tumor_KC_Basal"],spot_size=0.03, show=True,basis='advanced_diffusion_coords_avg',title='reconstructed',na_in_legend=False,save='P10_bas')
sc.pl.spatial(scadata_p10,color="level2_celltype",groups=["Tumor_KC_Diff"],spot_size=0.03, show=True,basis='advanced_diffusion_coords_avg',title='reconstructed',na_in_legend=False,save='P10_diff')
#save='nonTSK',

In [None]:
# Highlight only the "PDC" group of level2_celltype (slightly larger spots: 0.04) at the averaged reconstructed coordinates
sc.pl.spatial(scadata_p10,color="level2_celltype",groups=["PDC"],spot_size=0.04, show=True,basis='advanced_diffusion_coords_avg',title='reconstructed',na_in_legend=False)

In [None]:
# Show the three tumour keratinocyte subtypes together in a single plot (not saved to disk)
sc.pl.spatial(scadata_p10,color="level2_celltype",groups=['Tumor_KC_Cyc','Tumor_KC_Basal','Tumor_KC_Diff'],spot_size=0.025, 
              show=True,basis='advanced_diffusion_coords_avg',title='reconstructed',na_in_legend=False)