In [None]:
import os

# Cell 2: pin this process to physical GPU 1 and keep BLAS/OpenMP single-threaded.
# These variables are only honoured if they are set before numpy/torch initialise
# their native thread pools / CUDA context, hence this is the very first cell.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "1",
    "OMP_NUM_THREADS": "1",
    "OPENBLAS_NUM_THREADS": "1",
})

In [None]:
# Cell 3: cap BLAS thread pools at 1 thread.
# threadpoolctl can only limit native libraries that are already loaded, so the
# BLAS-backed packages are imported *before* threadpool_limits is called. The
# previous order called it first, when (in a fresh kernel) there was typically
# no BLAS loaded yet and the call had nothing to limit.
import numpy as np
import scipy
import scipy.linalg  # loads scipy's own BLAS/LAPACK bindings so they are covered too
from threadpoolctl import threadpool_limits

# 'blas' covers OpenBLAS, MKL, BLIS, etc. Called as a plain function (not `with`),
# the limit stays in place for the rest of the kernel session.
threadpool_limits(limits=1, user_api='blas')
# … any other packages that use OpenBLAS should be imported above this call …


In [None]:
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import scanpy as sc
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.neighbors import kneighbors_graph
from scipy.sparse.csgraph import dijkstra
from scipy.sparse import csr_matrix, issparse
from sklearn.preprocessing import normalize
import ot 
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt

In [None]:
def construct_graph_torch(X, k, mode='connectivity', metric = 'minkowski', p=2, device='cuda'):
    '''Construct a k-nearest-neighbour graph over the rows of ``X``.

    The neighbour search runs on the CPU with scikit-learn; only the finished
    graph is moved to ``device``.

    Args:
        X: 2-D torch tensor of features, one row per sample. It may live on any
            device; it is detached and copied to host memory for the search.
        k: number of neighbours for each sample.
        mode: 'connectivity' (unit edge weights, each sample counted as one of its
            own neighbours) or 'distance' (euclidean edge lengths, self excluded).
        metric: 'euclidean', or 'minkowski' together with p=2 (the same metric).
            Any other metric is rejected.
        p: Minkowski power; only consulted when metric='minkowski'.
        device: device for the returned tensor.

    Returns:
        'connectivity': sparse COO float32 tensor of shape (n, n).
        'distance': dense float32 tensor of shape (n, n); 0 means "no edge".

    Raises:
        AssertionError: for an unknown mode or an unsupported metric.
    '''
    assert mode in ['connectivity', 'distance'], "mode must be 'connectivity' or 'distance'."
    # Minkowski with p=2 *is* the euclidean metric. Normalising it here makes the
    # default signature (metric='minkowski', p=2) usable; previously it always
    # failed the assertion below.
    if metric == 'minkowski' and p == 2:
        metric = 'euclidean'
    assert metric == 'euclidean', "for gpu knn, only 'euclidean' metric is currently supported in this implementation"

    # Each sample is included as its own neighbour only in connectivity mode.
    include_self = mode == 'connectivity'

    # sklearn needs host memory. Always detach + .cpu(): the old `device == 'cuda'`
    # check called X.numpy() directly on CUDA tensors whenever the device was
    # spelled differently (e.g. 'cuda:0'), and .numpy() also fails on tensors
    # that require grad.
    X_np = X.detach().cpu().numpy()

    knn = NearestNeighbors(n_neighbors=k, metric=metric, algorithm='auto')
    knn.fit(X_np)
    # kneighbors_graph accepts the fitted estimator in place of raw data.
    knn_graph_cpu = kneighbors_graph(knn, k, mode=mode, include_self=include_self, metric=metric)

    if mode == 'connectivity':
        coo = knn_graph_cpu.tocoo()
        # Stack rows/cols into one (2, nnz) int64 array first; building a tensor
        # from a Python list of numpy arrays is slow and warns in recent torch.
        indices = torch.from_numpy(np.vstack([coo.row, coo.col]).astype(np.int64))
        values = torch.from_numpy(coo.data.astype(np.float32))
        return torch.sparse_coo_tensor(indices, values, size=coo.shape).to(device)

    # 'distance' mode: densify (O(n^2) memory) and move to the target device.
    return torch.tensor(knn_graph_cpu.toarray(), dtype=torch.float32, device=device)
    
def distances_cal_torch(graph, type_aware=None, aware_power =2, device='cuda'):
    '''Shortest-path distance matrix of a graph, scaled by its largest finite entry.

    Args:
        graph: (n, n) graph whose non-zero entries are edge weights: a sparse COO
            or dense torch tensor (any device), or anything scipy's csr_matrix
            accepts.
        type_aware: accepted for signature compatibility; ignored here.
        aware_power: accepted for signature compatibility; ignored here.
        device: device for the returned tensor.

    Returns:
        (C_dis, original_max_distance): C_dis is a float32 (n, n) tensor on
        ``device``. Unreachable pairs are first set to the largest finite
        distance, then everything is divided by that maximum.
        original_max_distance is that maximum as a Python float. When the
        maximum is 0 (e.g. a single node or a graph without edges), the matrix
        is returned unscaled instead of becoming NaN.
    '''
    if isinstance(graph, torch.Tensor):
        host = graph.detach().cpu()
        if host.is_sparse:
            # Build the CSR directly from the COO entries instead of densifying
            # (O(n^2) memory). coalesce() sums duplicate entries exactly like
            # to_dense() would.
            host = host.coalesce()
            rows, cols = host.indices().numpy()
            graph_cpu_csr = csr_matrix((host.values().numpy(), (rows, cols)), shape=tuple(host.shape))
            # csr_matrix(dense) keeps only non-zeros; mirror that, since scipy's
            # csgraph routines treat explicitly stored zeros as zero-length edges.
            graph_cpu_csr.eliminate_zeros()
        else:
            graph_cpu_csr = csr_matrix(host.numpy())
    else:
        graph_cpu_csr = csr_matrix(graph)  # assume a scipy sparse matrix / array-like

    shortestPath_cpu = dijkstra(csgraph=graph_cpu_csr, directed=False, return_predecessors=False)
    shortestPath = torch.tensor(shortestPath_cpu, dtype=torch.float32, device=device)

    # dijkstra reports unreachable pairs as +inf; replace them with the largest
    # finite distance.
    mask = shortestPath != float('inf')
    if mask.any():
        the_max = torch.max(shortestPath[mask]).item()
        shortestPath[~mask] = the_max
    else:
        # Only reachable for an empty graph (the diagonal is always 0). The old
        # code kept a plain float here and then crashed calling .item() on it.
        the_max = 1.0

    original_max_distance = float(the_max)
    # Avoid 0/0 -> NaN when every finite distance is 0.
    C_dis = shortestPath / (the_max if the_max > 0 else 1.0)
    return C_dis, original_max_distance

def calculate_D_sc_torch(X_sc, k_neighbors=10, graph_mode='connectivity', device='cpu'):
    '''Geodesic distance matrix between cells: kNN graph + Dijkstra shortest paths.

    Args:
        X_sc: torch tensor of cell features, one row per cell.
        k_neighbors: number of neighbours per cell in the kNN graph.
        graph_mode: 'connectivity' or 'distance' edge weights (see construct_graph_torch).
        device: requested device ('cpu', 'cuda', 'cuda:N' or a torch.device).
            CUDA requests fall back to 'cpu' when CUDA is not available.

    Returns:
        (D_sc, sc_max_distance): D_sc is the float32 shortest-path matrix on the
        chosen device, divided by its largest finite entry; sc_max_distance is
        that entry.

    Raises:
        TypeError: if X_sc is not a torch tensor.
    '''
    if not isinstance(X_sc, torch.Tensor):
        raise TypeError('Input X_sc must be a pytorch tensor')

    # Accept any CUDA spelling. The previous `device == 'cuda'` test silently sent
    # 'cuda:0' / torch.device('cuda') requests to the CPU.
    if torch.device(device).type != 'cuda' or not torch.cuda.is_available():
        device = 'cpu'

    print(f'using device: {device}')
    print('constructing knn graph...')
    # The kNN search runs on the CPU (scikit-learn) anyway, so hand it a host
    # float32 copy instead of uploading X to the GPU first. The old
    # `torch.tensor(tensor)` call also emitted a copy-construct UserWarning.
    X_knn = X_sc.detach().to(device='cpu', dtype=torch.float32)

    Xgraph = construct_graph_torch(X_knn, k=k_neighbors, mode=graph_mode, metric='euclidean', device=device)

    print('calculating distances from graph....')
    D_sc, sc_max_distance = distances_cal_torch(Xgraph, device=device)

    print('D_sc calculation complete')

    return D_sc, sc_max_distance


In [None]:
from sklearn.neighbors import kneighbors_graph, NearestNeighbors
from scipy.sparse.csgraph import dijkstra
from scipy.sparse import csr_matrix, issparse
from sklearn.preprocessing import normalize
import ot

def construct_graph_spatial(location_array, k, mode='distance', metric='euclidean', p=2):
    '''Build a kNN graph over spot coordinates with scikit-learn.

    Args:
        location_array: spot coordinates, one row per spot.
        k: neighbours per spot.
        mode: 'connectivity' (unit weights, each spot counted as its own
            neighbour) or 'distance' (edge lengths, self excluded).
        metric: distance metric forwarded to kneighbors_graph.
        p: Minkowski power forwarded to kneighbors_graph (in either mode).

    Returns:
        scipy sparse matrix produced by kneighbors_graph.
    '''
    assert mode in ['connectivity', 'distance'], "mode must be 'connectivity' or 'distance'"
    return kneighbors_graph(
        location_array,
        k,
        mode=mode,
        metric=metric,
        include_self=(mode == 'connectivity'),
        p=p,
    )

def distances_cal_spatial(graph, spot_ids=None, spot_types=None, aware_power=2):
    '''Shortest-path distance matrix between spots, scaled by its largest finite value.

    Args:
        graph: (n, n) kNN graph (anything scipy's csr_matrix accepts); non-zero
            entries are edge weights.
        spot_ids: labels for the n rows/cols. Type-aware scaling is applied only
            when both spot_ids and spot_types are given.
        spot_types: spot types. They are aligned to spot_ids the way
            ``pd.DataFrame({...}, index=spot_ids)`` aligns them: a Series by its
            index (unmatched ids become NaN), anything else positionally.
        aware_power: factor applied to distances between spots of different types.
            Spots with a missing type count as different from every spot.

    Returns:
        (C_dis, original_max_distance): C_dis is an (n, n) numpy array. Unreachable
        pairs are first set to the largest finite distance, then everything is
        divided by that maximum. original_max_distance is that maximum. If the
        maximum is 0 (e.g. a single spot), C_dis is returned unscaled rather
        than NaN.
    '''
    shortestPath = np.asarray(
        dijkstra(csgraph=csr_matrix(graph), directed=False, return_predecessors=False), dtype=float
    )
    # Map NaN -> inf but leave +inf alone. np.nan_to_num(x, nan=np.inf) (used
    # before) also rewrites +inf to the largest finite float (~1.8e308). That hid
    # unreachable pairs from the inf-mask below and made the "max" 1.8e308.
    shortestPath[np.isnan(shortestPath)] = np.inf

    if spot_types is not None and spot_ids is not None:
        # Same alignment rule as the original DataFrame-based implementation.
        aligned_types = pd.DataFrame({'spot': spot_ids, 'spot_type': spot_types}, index=spot_ids)['spot_type']
        codes, _ = pd.factorize(aligned_types)  # missing types -> -1
        same_type = (codes[:, None] == codes[None, :]) & (codes[:, None] >= 0)
        shortestPath = np.where(same_type, shortestPath, shortestPath * aware_power)

    mask = shortestPath != float('inf')
    if mask.any():
        the_max = np.max(shortestPath[mask])
        shortestPath[~mask] = the_max  # unreachable -> largest finite distance
    else:
        the_max = 1.0  # empty graph

    # store original max distance for scale reference
    original_max_distance = the_max
    C_dis = shortestPath / (the_max if the_max > 0 else 1.0)

    return C_dis, original_max_distance

def calculate_D_st_from_coords(spatial_coords, X_st=None, k_neighbors=10, graph_mode='distance', aware_st=False,
                               spot_types=None, aware_power_st=2, spot_ids=None):
    '''Spatial (kNN + Dijkstra) distance matrix D_st for ST spots, from their coordinates.

    Args:
        spatial_coords: spot coordinates (n_spots, 2) as a pandas DataFrame, numpy
            array or torch tensor.
        X_st: ST expression; unused, kept for interface compatibility.
        k_neighbors: neighbours per spot in the kNN graph.
        graph_mode: 'connectivity' or 'distance' kNN graph.
        aware_st: apply type-aware distance scaling.
        spot_types: spot types (Series or sequence); required when aware_st=True.
        aware_power_st: factor for distances between spots of different types.
        spot_ids: spot labels. Defaults to the DataFrame index, or 0..n-1 for
            arrays and tensors.

    Returns:
        (D_st, st_max_distance): numpy distance matrix scaled to its largest
        finite entry, and that entry.

    Raises:
        TypeError: for unsupported spatial_coords types.
        ValueError: if aware_st=True without spot_types.
    '''
    if isinstance(spatial_coords, pd.DataFrame):
        location_array = spatial_coords.values
        if spot_ids is None:
            spot_ids = spatial_coords.index.tolist()  # use index of dataframe if available
    elif isinstance(spatial_coords, (np.ndarray, torch.Tensor)):
        # Tensors (e.g. the Y_st passed through fused_gw_torch) are moved to host memory.
        location_array = (spatial_coords.detach().cpu().numpy()
                          if isinstance(spatial_coords, torch.Tensor) else spatial_coords)
        if spot_ids is None:
            spot_ids = list(range(location_array.shape[0]))  # default ids
    else:
        raise TypeError('spatial_coords must be a pandas dataframe, a numpy array or a torch tensor')

    print(f'constructing {graph_mode} graph for ST data with k={k_neighbors}.....')
    Xgraph_st = construct_graph_spatial(location_array, k=k_neighbors, mode=graph_mode)

    if aware_st:
        if spot_types is None or spot_ids is None:
            raise ValueError('spot_types and spot_ids must be provided when aware_st=True')
        if not isinstance(spot_types, pd.Series):
            # Fixed keyword typo (`idnex=`), which made this line raise TypeError.
            spot_types = pd.Series(spot_types, index=spot_ids)
        print('applying type aware distance adjustment for ST data')
        print(f'aware power for ST: {aware_power_st}')
    else:
        spot_types = None

    print('calculating spatial distances.....')
    D_st, st_max_distance = distances_cal_spatial(Xgraph_st, spot_ids=spot_ids, spot_types=spot_types, aware_power=aware_power_st)

    print('D_st calculation complete')
    return D_st, st_max_distance


def calculate_D_st_euclidean(spatial_coords):
    """
    Pairwise Euclidean distances between ST spots, rescaled so the largest is 1.

    Args:
        spatial_coords: (m_spots, 2) coordinates as a DataFrame, ndarray, or any
            array-like accepted by np.array.

    Returns:
        (m_spots, m_spots) float32 matrix. It is left unscaled when every
        distance is 0.
    """
    from scipy.spatial.distance import pdist, squareform

    points = (
        spatial_coords.values if isinstance(spatial_coords, pd.DataFrame)
        else spatial_coords if isinstance(spatial_coords, np.ndarray)
        else np.array(spatial_coords)
    )

    dist_matrix = squareform(pdist(points, metric='euclidean'))
    largest = dist_matrix.max()
    scaled = dist_matrix / largest if largest > 0 else dist_matrix
    return scaled.astype(np.float32)

In [None]:
def compute_D_induced_proper_scaling(T, D_st, n_sc, n_st):
    '''Pull the ST distance matrix back onto cells through a transport plan.

    Computes (n_sc * T) @ D_st @ (n_sc * T)^T, then divides by its largest
    positive entry (when that entry exceeds 1e-10).

    Args:
        T: (n_cells, n_spots) transport plan tensor. Scaling by n_sc presumably
            turns rows summing to 1/n_sc (the uniform source marginal used by
            fused_gw_torch) into rows summing to 1; TODO confirm for other marginals.
        D_st: (n_spots, n_spots) distance tensor.
        n_sc: number of cells, used as the reweighting factor.
        n_st: number of spots; unused, kept for interface compatibility.

    Returns:
        (n_cells, n_cells) tensor. It is returned unscaled when it has no entry
        above 1e-10 (including when it has no positive entry at all).
    '''
    T_reweight = T * n_sc
    D_induced_raw = T_reweight @ D_st @ T_reweight.t()

    positive = D_induced_raw[D_induced_raw > 0]
    # torch.max on an empty tensor raises (e.g. an all-zero D_st); leave such input as-is.
    if positive.numel() == 0:
        return D_induced_raw

    D_induced_max = positive.max()
    if D_induced_max > 1e-10:  # avoid division by very small numbers
        return D_induced_raw / D_induced_max
    return D_induced_raw

def fused_gw_torch(X_sc, X_st, Y_st, alpha, k_sc=100, k_st=30, G0=None, max_iter = 100, tol=1e-9, epsilon=0.1, device='cuda', n_iter = 1, D_st_precomputed=None):
    '''Align single cells to ST spots with entropic fused Gromov-Wasserstein (POT).

    Args:
        X_sc: (n_cells, n_features) torch tensor of cell features.
        X_st: (n_spots, n_features) torch tensor of spot features.
        Y_st: spot coordinates (DataFrame, ndarray or tensor). Only used when
            D_st_precomputed is None.
        alpha: FGW trade-off forwarded to POT.
        k_sc: kNN size for the cell geodesic distances.
        k_st: kNN size for the spot geodesic distances.
        G0: optional initial plan (tensor or array) for the first solve.
        max_iter, tol, epsilon: forwarded to ot.gromov.entropic_fused_gromov_wasserstein.
        device: torch device for inputs and outputs.
        n_iter: number of solves. Each solve is warm-started from the previous
            plan with the same epsilon, so this does not anneal epsilon.
        D_st_precomputed: optional (n_spots, n_spots) distance matrix (tensor or
            array), assumed already normalised (st_max_distance is reported as 1.0).

    Returns:
        (T, D_sc, D_st, D_induced, fgw_dist, sc_max_distance, st_max_distance)

    Raises:
        ValueError: if n_iter < 1 (no solve would run and there would be no plan).
    '''
    if n_iter < 1:
        raise ValueError('n_iter must be >= 1')

    n = X_sc.shape[0]
    m = X_st.shape[0]

    X_sc = X_sc.to(device)
    X_st = X_st.to(device)

    # calculate distance matrices
    print('calculating SC distances with knn-dijkstra.....')
    D_sc, sc_max_distance = calculate_D_sc_torch(X_sc, graph_mode='distance', k_neighbors=k_sc, device=device)

    if D_st_precomputed is not None:
        print("Using precomputed block-diagonal D_st...")
        # as_tensor handles tensors and numpy arrays; forcing float32 keeps the
        # later T @ D_st matmul dtype-consistent.
        D_st = torch.as_tensor(D_st_precomputed, dtype=torch.float32, device=device)
        st_max_distance = 1.0  # assume already normalized
    else:
        print('Calculating ST distances.....')
        # calculate_D_st_from_coords historically accepted only DataFrames/arrays,
        # so give it host data when Y_st is a tensor.
        coords = Y_st.detach().cpu().numpy() if torch.is_tensor(Y_st) else Y_st
        D_st, st_max_distance = calculate_D_st_from_coords(spatial_coords=coords, k_neighbors=k_st, graph_mode="distance")
        D_st = torch.tensor(D_st, dtype=torch.float32, device=device)

    # expression cost: euclidean distance, scaled to [0, 1]
    C_exp = torch.cdist(X_sc, X_st, p=2)
    C_exp = C_exp / (torch.max(C_exp) + 1e-16)

    # POT's numpy backend: C-contiguous host arrays (detach in case inputs require grad)
    D_sc_np = np.ascontiguousarray(D_sc.detach().cpu().numpy())
    D_st_np = np.ascontiguousarray(D_st.detach().cpu().numpy())
    C_exp_np = np.ascontiguousarray(C_exp.detach().cpu().numpy())

    # uniform marginals
    p = ot.unif(n)
    q = ot.unif(m)

    # Initial plan: the caller's G0 (tensor or array) or POT's default.
    if G0 is None:
        T_np = None
    elif torch.is_tensor(G0):
        T_np = G0.detach().cpu().numpy()
    else:
        T_np = np.asarray(G0)

    for _ in range(n_iter):
        T_np, log = ot.gromov.entropic_fused_gromov_wasserstein(
            M=C_exp_np,
            C1=D_sc_np,
            C2=D_st_np,
            p=p,
            q=q,
            loss_fun='square_loss',
            epsilon=epsilon,
            alpha=alpha,
            G0=T_np,
            log=True,
            verbose=True,
            max_iter=max_iter,
            tol=tol
        )

    fgw_dist = log['fgw_dist']

    print(f'fgw distance: {fgw_dist}')

    T = torch.tensor(T_np, dtype=torch.float32, device=device)

    D_induced = compute_D_induced_proper_scaling(T, D_st, n, m)

    return T, D_sc, D_st, D_induced, fgw_dist, sc_max_distance, st_max_distance

# patient 2 data load

In [None]:
def load_and_process_cscc_data(data_dir='/home/ehtesamul/sc_st/data/cSCC/processed'):
    """
    Load the patient-2 cSCC scRNA-seq data and its three ST replicates.

    Every AnnData is total-count normalised and log1p-transformed in place.
    ``scadata.obs['rough_celltype']`` is derived from the level1/level2
    annotations.

    Args:
        data_dir: directory containing scP2.h5ad, stP2.h5ad, stP2rep2.h5ad and
            stP2rep3.h5ad. The default is the previously hard-coded location.

    Returns:
        (scadata, stadata1, stadata2, stadata3)
    """
    print("Loading cSCC data...")

    # SC data first, then the three ST replicates.
    scadata, stadata1, stadata2, stadata3 = (
        sc.read_h5ad(os.path.join(data_dir, fname))
        for fname in ('scP2.h5ad', 'stP2.h5ad', 'stP2rep2.h5ad', 'stP2rep3.h5ad')
    )

    # Normalize and log transform
    for adata in (scadata, stadata1, stadata2, stadata3):
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)

    # level1 label -> coarse label; labels not listed are kept unchanged.
    level1_to_rough = {
        'CLEC9A': 'DC',
        'CD1C': 'DC',
        'ASDC': 'DC',
        'PDC': 'PDC',
        'MDSC': 'DC',
        'LC': 'DC',
        'Mac': 'Myeloid cell',
        'Tcell': 'T cell',
    }
    scadata.obs['rough_celltype'] = (
        scadata.obs['level1_celltype'].astype(str).map(lambda label: level1_to_rough.get(label, label))
    )

    # Tumour level2 annotations override the level1-derived label (applied last).
    level2 = scadata.obs['level2_celltype']
    scadata.obs.loc[level2 == 'TSK', 'rough_celltype'] = 'TSK'
    scadata.obs.loc[level2.isin(['Tumor_KC_Basal', 'Tumor_KC_Diff', 'Tumor_KC_Cyc']), 'rough_celltype'] = 'NonTSK'

    return scadata, stadata1, stadata2, stadata3

def prepare_combined_st_for_diffusion(stadata1, stadata2, stadata3, scadata):
    """
    Stack the three ST replicates into one training set, aligned on the genes they share with the SC data.

    Args:
        stadata1, stadata2, stadata3: ST AnnData objects with coordinates in ``obsm['spatial']``.
        scadata: single-cell AnnData.

    Returns:
        X_sc: float32 tensor (n_cells, n_common_genes), standardised.
        X_st_combined: float32 tensor (n_spots_total, n_common_genes), standardised jointly over the replicates.
        Y_st_combined: float32 array of stacked spot coordinates (replicate 1, then 2, then 3).
        dataset_labels: 'dataset1' / 'dataset2' / 'dataset3', one per stacked spot.
        common_genes: sorted list of shared gene names.
        st_coords_list: per-replicate coordinate arrays (for block-diagonal graphs).
    """
    print("Preparing combined ST data for diffusion training...")

    st_adatas = (stadata1, stadata2, stadata3)

    # Genes measured in the SC data and in every ST replicate, alphabetically ordered.
    shared_genes = set(scadata.var_names).intersection(*(set(a.var_names) for a in st_adatas))
    common_genes = sorted(shared_genes)
    print(f"Common genes across all datasets: {len(common_genes)}")

    def _dense_expression(adata):
        # Restrict to the shared genes and densify sparse matrices.
        matrix = adata[:, common_genes].X
        return matrix.toarray() if hasattr(matrix, 'toarray') else matrix

    sc_expr = _dense_expression(scadata)
    st_blocks = [_dense_expression(a) for a in st_adatas]

    # Each replicate's coordinates kept separately for block-diagonal graphs.
    st_coords_list = [a.obsm['spatial'] for a in st_adatas]

    from sklearn.preprocessing import StandardScaler
    scaler = StandardScaler()
    st_expr_combined = scaler.fit_transform(np.vstack(st_blocks))
    st_coords_combined = np.vstack(st_coords_list)
    # fit_transform refits the same scaler on the SC matrix, so SC and ST are each
    # standardised with their own per-gene statistics.
    sc_expr = scaler.fit_transform(sc_expr)

    # One replicate tag per stacked spot, in stacking order.
    dataset_labels = [
        f'dataset{rep}' for rep, block in enumerate(st_blocks, start=1) for _ in range(len(block))
    ]

    print(f"Combined ST data shape: {st_expr_combined.shape}")
    print(f"Combined ST coords shape: {st_coords_combined.shape}")
    print(f"SC data shape: {sc_expr.shape}")

    X_sc = torch.tensor(sc_expr, dtype=torch.float32)
    X_st_combined = torch.tensor(st_expr_combined, dtype=torch.float32)
    Y_st_combined = st_coords_combined.astype(np.float32)

    return X_sc, X_st_combined, Y_st_combined, dataset_labels, common_genes, st_coords_list

# Load the SC data and the three ST replicates (normalised + log1p).
scadata, stadata1, stadata2, stadata3 = load_and_process_cscc_data()

# Stack the ST replicates, align genes with the SC data and convert to tensors.
(X_sc, X_st_combined, Y_st_combined,
 dataset_labels, common_genes, st_coords_list) = prepare_combined_st_for_diffusion(
    stadata1, stadata2, stadata3, scadata
)

print("\n".join([
    "Data preparation complete!",
    f"SC cells: {X_sc.shape[0]}",
    f"Combined ST spots: {X_st_combined.shape[0]}",
    f"Common genes: {len(common_genes)}",
]))



# diffusion model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import numpy as np
from tqdm import tqdm
import os
import time
import scipy
from sklearn.neighbors import NearestNeighbors
from scipy.spatial import cKDTree
from typing import Optional, Dict, Tuple, List
import torch_geometric
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data

# =====================================================
# PART X: Graph-VAE Components
# =====================================================

class GraphVAEEncoder(nn.Module):
    """
    Two-layer GCN encoder producing a per-node Gaussian posterior (mu, logvar) over latent codes.
    ⚠️ Do **not** touch `train_encoder`; its aligned embeddings are the sole conditioning signal throughout.
    """
    def __init__(self, input_dim, hidden_dim=128, latent_dim=32):
        super().__init__()
        self.input_dim, self.hidden_dim, self.latent_dim = input_dim, hidden_dim, latent_dim

        # Message passing over the spot graph. Submodule names are state_dict keys,
        # and creation order fixes parameter-initialisation order: keep both.
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

        # Per-node heads (no graph pooling) for the posterior mean and log-variance.
        self.mu_head = nn.Linear(hidden_dim, latent_dim)
        self.logvar_head = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x, edge_index, edge_weight=None, batch=None):
        """
        Args:
            x: node features, shape (n_nodes, input_dim).
            edge_index: graph connectivity passed to both GCN layers.
            edge_weight: optional per-edge weights passed to both GCN layers.
            batch: accepted for API compatibility and ignored (outputs stay node-level).

        Returns:
            (mu, logvar), each of shape (n_nodes, latent_dim).
        """
        hidden = x
        for conv in (self.conv1, self.conv2):
            hidden = torch.relu(conv(hidden, edge_index, edge_weight))
        return self.mu_head(hidden), self.logvar_head(hidden)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * exp(logvar / 2) with eps ~ N(0, I), elementwise."""
        std = logvar.mul(0.5).exp()
        eps = torch.randn_like(std)
        return mu + eps * std

class GraphVAEDecoder(nn.Module):
    """
    MLP decoder mapping (latent z, conditioning embedding) to 2-D coordinates.
    """
    def __init__(self, latent_dim=32, condition_dim=128, hidden_dim=128):
        super().__init__()
        self.latent_dim = latent_dim
        self.condition_dim = condition_dim

        # The layer order fixes the state_dict keys decoder.0 / .2 / .4; do not reorder.
        layers = [
            nn.Linear(latent_dim + condition_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),  # (x, y)
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, z, condition):
        """
        Args:
            z: latent vectors (..., latent_dim).
            condition: conditioning embeddings (..., condition_dim).

        Returns:
            coordinates (..., 2).
        """
        return self.decoder(torch.cat((z, condition), dim=-1))

def precompute_knn_edges(coords, k=30, device='cuda'):
    """
    Build a k-NN connectivity graph over coordinates in torch-geometric edge format.

    Args:
        coords: (n, d) coordinates as a torch tensor or numpy array.
        k: neighbours per point (self excluded).
        device: device for the returned tensors.

    Returns:
        (edge_index, edge_weight): int64 (2, n*k) indices and float32 weights (all 1).

    NOTE(review): kneighbors_graph is directed (row i -> its k neighbours) and is
    not symmetrised here, so message-passing layers see the edges in that
    direction only. Confirm this is intended.
    """
    from sklearn.neighbors import kneighbors_graph
    from torch_geometric.utils import from_scipy_sparse_matrix

    points = coords.cpu().numpy() if isinstance(coords, torch.Tensor) else coords

    adjacency = kneighbors_graph(points, n_neighbors=k, mode='connectivity', include_self=False)
    edge_index, edge_weight = from_scipy_sparse_matrix(adjacency)

    # GNN layers expect int64 indices and float32 weights.
    return edge_index.long().to(device), edge_weight.float().to(device)

class LatentDenoiser(nn.Module):
    """
    Residual-MLP noise predictor for latent diffusion, conditioned on a timestep and an embedding.
    ⚠️ Do **not** touch `train_encoder`; its aligned embeddings are the sole conditioning signal throughout.
    """
    def __init__(self, latent_dim=32, condition_dim=128, hidden_dim=256, n_blocks=6):
        super().__init__()
        self.latent_dim = latent_dim
        self.condition_dim = condition_dim

        # Time embedding (sinusoidal features followed by an MLP)
        self.time_embed = nn.Sequential(
            SinusoidalEmbedding(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Latent encoder
        self.latent_encoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Condition encoder (for aligned embeddings)
        self.condition_encoder = nn.Sequential(
            nn.Linear(condition_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Residual denoising blocks
        self.denoising_blocks = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim * 2),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim * 2, hidden_dim),
                nn.LayerNorm(hidden_dim)
            ) for _ in range(n_blocks)
        ])

        # Output head
        self.output_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)
        )

    @staticmethod
    def _as_batch_2d(x):
        """Squeeze singleton dims out of a >2-D input while keeping a batch axis of size 1."""
        if x.dim() > 2:
            x = x.squeeze()
            if x.dim() < 2:
                # squeeze() also dropped the batch axis (batch of one); restore it.
                x = x.reshape(1, -1)
        return x

    def forward(self, z_noisy, t, condition):
        """
        Args:
            z_noisy: noisy latent vectors (batch_size, latent_dim).
            t: one timestep per sample, shape (batch_size,) or (batch_size, 1).
            condition: aligned embeddings E(X) (batch_size, condition_dim).

        Returns:
            predicted noise (batch_size, latent_dim).
        """
        z_noisy = self._as_batch_2d(z_noisy)
        condition = self._as_batch_2d(condition)

        # SinusoidalEmbedding appends its own feature axis, so t must be 1-D. The old
        # t.unsqueeze(1) produced a (B, 1, hidden) time encoding, which broadcast
        # against the (B, hidden) latent/condition encodings into (B, B, hidden) and
        # made the output (B, B, latent_dim).
        if t.dim() > 1:
            t = t.reshape(-1)

        # Encode inputs
        z_enc = self.latent_encoder(z_noisy)
        t_enc = self.time_embed(t)
        c_enc = self.condition_encoder(condition)

        # Combine features
        h = z_enc + t_enc + c_enc

        # Apply denoising blocks with residual connections
        for block in self.denoising_blocks:
            h = h + block(h)

        # Output predicted noise
        return self.output_head(h)

# =====================================================
# PART 1: Advanced Network Components
# =====================================================

class FeatureNet(nn.Module):
    """Three-layer MLP gene encoder: Linear+LayerNorm+ReLU twice, then a final Linear."""
    def __init__(self, n_genes, n_embedding=[512, 256, 128], dp=0):
        super().__init__()

        width_a, width_b, width_out = n_embedding[0], n_embedding[1], n_embedding[2]
        # Attribute names are state_dict keys (bn* are LayerNorms despite the
        # name), and creation order fixes init order: keep both.
        self.fc1 = nn.Linear(n_genes, width_a)
        self.bn1 = nn.LayerNorm(width_a)
        self.fc2 = nn.Linear(width_a, width_b)
        self.bn2 = nn.LayerNorm(width_b)
        self.fc3 = nn.Linear(width_b, width_out)

        self.dp = nn.Dropout(dp)

    def forward(self, x, isdp=False):
        """Embed x; when isdp is true, dropout is applied to the raw input first."""
        h = self.dp(x) if isdp else x
        for linear, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2)):
            h = F.relu(norm(linear(h)))
        return self.fc3(h)

class SinusoidalEmbedding(nn.Module):
    """Transformer-style sinusoidal features of a scalar input (e.g. a diffusion timestep).

    Input of shape (...) maps to output of shape (..., dim): [sin(x*f_i), cos(x*f_i)]
    over half_dim = dim // 2 geometrically spaced frequencies from 1 down to 1/10000.
    For odd ``dim`` a zero column is appended, so the width always equals ``dim``.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        # max(..., 1) avoids a 0 denominator when half_dim == 1 (dim 2 or 3),
        # which previously produced NaN features.
        scale = np.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -scale)
        args = x.unsqueeze(-1) * freqs.unsqueeze(0)
        emb = torch.cat([args.sin(), args.cos()], dim=-1)
        if self.dim % 2 == 1:
            # Odd dims used to return dim - 1 features and break the following Linear(dim, ...).
            emb = F.pad(emb, (0, 1))
        return emb

import torch.optim as optim   
from geomloss import SamplesLoss

# OT refinement function
def refine_with_ot(sc_coords, st_coords, n_steps=50, lr=1e-2, blur=0.05, scaling=0.9):
    """
    Refine SC coordinates by minimising the entropic OT (Sinkhorn) divergence to the ST coordinates.

    Args:
        sc_coords: (N, 2) tensor of initial SC coordinates; it is not modified.
        st_coords: (M, 2) tensor of ST spot coordinates, treated as a fixed target.
        n_steps: number of Adam steps.
        lr: Adam learning rate.
        blur: geomloss Sinkhorn blur (entropic length scale); default matches the old hard-coded value.
        scaling: geomloss epsilon-scaling factor; default matches the old hard-coded value.

    Returns:
        (N, 2) detached tensor of refined coordinates.
    """
    sinkhorn = SamplesLoss("sinkhorn", p=2, blur=blur, scaling=scaling)
    coords = sc_coords.clone().detach().requires_grad_(True)
    # Detach the target so backward() never accumulates gradients into a caller
    # tensor that happens to require grad, and match the optimised tensor's
    # device/dtype.
    target = st_coords.detach().to(device=coords.device, dtype=coords.dtype).unsqueeze(0)
    optimizer = optim.Adam([coords], lr=lr)

    for _ in range(n_steps):
        optimizer.zero_grad()
        loss_ot = sinkhorn(coords.unsqueeze(0), target)
        loss_ot.backward()
        optimizer.step()

    return coords.detach()
    
class OTGuidedSampler:
    """Turns an OT plan into per-cell coordinate targets for guided diffusion sampling."""
    def __init__(self,
                 T_opt: torch.Tensor,
                 st_coords_norm: torch.Tensor,
                 n_timesteps: int):
        self.T_opt       = T_opt           # (n_sc, n_st)
        self.st_coords   = st_coords_norm  # (n_st, 2)
        self.n_timesteps = n_timesteps

    def get_ot_guidance(self, sc_indices: List[int], t: int) -> torch.Tensor:
        """
        Expected OT-based location for each requested SC index, plus a small decaying jitter.

        Each target is the transport-weighted mean of the ST coordinates over that
        cell's row of T_opt. Gaussian jitter with std 0.02 * t / n_timesteps is
        then added. Cells whose row has no positive mass get a zero target and no
        jitter.

        Returns:
            (len(sc_indices), 2) tensor on st_coords' device, or None when T_opt is None.
        """
        if self.T_opt is None:
            return None

        coords = self.st_coords
        # Vectorised over the batch (was a per-cell Python loop). An empty index
        # list now yields a (0, 2) tensor instead of crashing in torch.stack([]).
        idx = torch.as_tensor(sc_indices, dtype=torch.long, device=self.T_opt.device)
        weights = self.T_opt[idx].to(coords.device)                      # (B, n_st)
        dtype = torch.promote_types(weights.dtype, coords.dtype)
        weights, coords = weights.to(dtype), coords.to(dtype)

        totals = weights.sum(dim=1, keepdim=True)                        # (B, 1)
        valid = totals > 0
        # expected spot location (no argmax sampling)
        safe_totals = torch.where(valid, totals, torch.ones_like(totals))
        target_mean = (weights / safe_totals) @ coords                   # (B, 2)

        # decaying noise: maximal when t≈T, zero at t=0
        noise_scale = 0.02 * (t / self.n_timesteps)
        jitter = torch.randn_like(target_mean) * noise_scale

        return torch.where(valid, target_mean + jitter, torch.zeros_like(target_mean))
    
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.neighbors import NearestNeighbors
import numpy as np

class CellTypeEmbedding(nn.Module):
    """Learned lookup table mapping integer cell-type indices to dense vectors."""
    def __init__(self, num_cell_types, embedding_dim):
        super().__init__()
        # 'embedding' is the state_dict prefix; keep the attribute name.
        self.embedding = nn.Embedding(num_embeddings=num_cell_types, embedding_dim=embedding_dim)

    def forward(self, cell_type_indices):
        """Integer indices of shape (...) -> vectors of shape (..., embedding_dim)."""
        vectors = self.embedding(cell_type_indices)
        return vectors

class UncertaintyHead(nn.Module):
    """Small MLP that predicts a strictly positive uncertainty for x and y.

    Args:
        input_dim: size of the incoming feature vector.
        hidden_dim: width of both hidden layers.
    """

    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),  # one output each for x and y
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``softplus(net(x)) + 0.01``, so every entry exceeds 0.01."""
        raw = self.net(x)
        positive = F.softplus(raw)
        return positive + 0.01

class PhysicsInformedLayer(nn.Module):
    """Soft non-overlap constraint between cells.

    Predicts a radius for every cell from its features. It then returns, per
    cell, a repulsion vector that pushes it away from neighbours whose discs
    overlap its own.

    Args:
        feature_dim: size of the per-cell feature vector fed to the radius predictor.
    """
    def __init__(self, feature_dim):
        super().__init__()
        self.radius_predictor = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Softplus()
        )
        # Learnable global scale applied to the repulsion vectors in forward().
        self.repulsion_strength = nn.Parameter(torch.tensor(0.1))

    def compute_repulsion_gradient(self, coords, radii, cell_types=None):
        """Compute repulsion forces between cells.

        Args:
            coords: (B, 2) cell coordinates.
            radii: per-cell radii, shape (B,) or (B, 1).
            cell_types: optional (B,) labels. Pairs of the same type get a
                1.5x stronger push.

        Returns:
            (B, 2) tensor. Row i is the sum over j != i of
            ``overlap_ij * weight_ij * (coords_i - coords_j) / (d_ij + 1e-6)``,
            where ``overlap_ij = relu(r_i + r_j - d_ij + 1e-6)``.
        """
        batch_size = coords.shape[0]
        r = radii.reshape(-1)

        # Pairwise Euclidean distances d_ij.
        distances = torch.cdist(coords, coords, p=2)

        # r_i + r_j for every pair. ``radii + radii.T`` on the 1-D radii that
        # forward() produces is just 2*radii broadcast along rows (2*r_j),
        # so an explicit outer sum is required.
        radii_sum = r.unsqueeze(1) + r.unsqueeze(0)

        # Positive where discs i and j intersect; self-pairs are masked out.
        mask = 1 - torch.eye(batch_size, device=coords.device, dtype=coords.dtype)
        overlap = F.relu(radii_sum - distances + 1e-6) * mask

        # Vectors from j towards i (push i away from j), normalised by distance.
        coord_diff = coords.unsqueeze(1) - coords.unsqueeze(0)  # (B, B, 2)
        directions = coord_diff / (distances + 1e-6).unsqueeze(-1)

        if cell_types is not None:
            same_type_mask = (cell_types.unsqueeze(1) == cell_types.unsqueeze(0)).to(coords.dtype)
            magnitude = overlap * (1.0 + 0.5 * same_type_mask)  # 50% stronger for same type
        else:
            magnitude = overlap

        # Sum the contributions of all other cells j.
        return (magnitude.unsqueeze(-1) * directions).sum(dim=1)

    def forward(self, coords, features, cell_types=None):
        """Return ``(repulsion * repulsion_strength, radii)`` for a batch of cells.

        ``radii`` has shape (B,): the softplus radius prediction scaled by 0.01.
        """
        # Predict cell radii from features, scaled to a small size.
        radii = self.radius_predictor(features).squeeze(-1) * 0.01

        repulsion_grad = self.compute_repulsion_gradient(coords, radii, cell_types)

        return repulsion_grad * self.repulsion_strength, radii
    
class SpatialBatchSampler:
    """Sample spatially contiguous batches for geometric attention.

    A batch is a randomly chosen centre point together with its nearest
    spatial neighbours, as returned by a k-d tree neighbour search fitted
    once on all coordinates.
    """

    def __init__(self, coordinates, batch_size, k_neighbors=None):
        """
        coordinates: (N, 2) array of spatial coordinates
        batch_size: maximum size of each batch
        k_neighbors: number of neighbours queried per batch
            (default: min(batch_size, N)). Always clamped to N, because the
            neighbour search cannot return more points than it was fitted on.
            If it is smaller than batch_size, batches have k_neighbors entries.
        """
        self.coordinates = coordinates
        self.batch_size = batch_size
        n_points = len(coordinates)
        self.k_neighbors = min(k_neighbors or batch_size, n_points)

        # Precompute the neighbour index once; every batch queries it.
        self.nbrs = NearestNeighbors(
            n_neighbors=self.k_neighbors,
            algorithm='kd_tree'
        ).fit(coordinates)

    def sample_spatial_batch(self):
        """Return a 1-D LongTensor of up to ``batch_size`` indices around a random centre."""
        center_idx = np.random.randint(len(self.coordinates))

        # Only the indices are needed; skip returning distances.
        indices = self.nbrs.kneighbors(
            self.coordinates[center_idx:center_idx + 1],
            return_distance=False,
        )

        return torch.tensor(indices.flatten()[:self.batch_size], dtype=torch.long)

# =====================================================
# PART 2: Hierarchical Diffusion Architecture
# =====================================================

class HierarchicalDiffusionBlock(nn.Module):
    """Multi-scale diffusion block for coarse-to-fine generation.

    A coarse MLP and a fine MLP are blended with timestep-dependent softmax
    weights. The fine MLP sees the input concatenated with coarse features.

    Args:
        dim: feature width of the input and output.
        num_scales: number of softmax mixing weights produced per sample.
    """
    def __init__(self, dim, num_scales=3):
        super().__init__()
        self.num_scales = num_scales
        wide = dim * 2

        # Coarse-level predictor (clusters/regions).
        self.coarse_net = nn.Sequential(nn.Linear(dim, wide), nn.ReLU(), nn.Linear(wide, dim))

        # Fine-level predictor (individual cells); input is [x, coarse features].
        self.fine_net = nn.Sequential(nn.Linear(wide, wide), nn.ReLU(), nn.Linear(wide, dim))

        # Timestep -> mixing weights that sum to 1 over num_scales.
        self.scale_mixer = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, num_scales),
            nn.Softmax(dim=-1),
        )

    def forward(self, x, t, coarse_context=None):
        """Blend coarse and fine predictions for features ``x`` at timesteps ``t``.

        Args:
            x: (B, dim) features.
            t: per-sample timesteps, fed straight into ``nn.Linear(1, 64)``
                after ``unsqueeze(-1)``. Presumably a float (B,) tensor, e.g.
                normalised timesteps — confirm with callers.
            coarse_context: optional (B, dim) tensor that replaces this
                block's own coarse prediction when conditioning the fine branch.

        Returns:
            (B, dim) tensor ``w[:, 0] * coarse + w[:, 1] * fine``.
        """
        weights = self.scale_mixer(t.unsqueeze(-1))
        coarse_pred = self.coarse_net(x)

        context = coarse_pred if coarse_context is None else coarse_context
        fine_pred = self.fine_net(torch.cat([x, context], dim=-1))

        # NOTE(review): only the first two of ``num_scales`` weights are used,
        # so any remaining softmax mass just scales the output down. Kept as-is
        # to preserve trained behaviour.
        w_coarse = weights[:, 0:1]
        w_fine = weights[:, 1:2]
        return w_coarse * coarse_pred + w_fine * fine_pred


# =====================================================
# PART 3: Main Advanced Diffusion Model
# =====================================================

class AdvancedHierarchicalDiffusion(nn.Module):
    def __init__(
        self,
        st_gene_expr,
        st_coords,
        sc_gene_expr,
        cell_types_sc=None,  # Cell type labels for SC data
        transport_plan=None,  # Optimal transport plan from domain alignment
        D_st=None,
        D_induced=None,
        n_genes=None,
        n_embedding=[512, 256, 128],
        coord_space_diameter=200,
        st_max_distance=None,
        sc_max_distance=None,
        sigma=3.0,
        alpha=0.9,
        mmdbatch=0.1,
        batch_size=64,
        device='cuda',
        lr_e=0.0001,
        lr_d=0.0002,
        n_timesteps=1000,
        n_denoising_blocks=6,
        hidden_dim=512,
        num_heads=8,
        num_hierarchical_scales=3,
        dp=0.1,
        outf='output'
    ):
        """Build the feature encoder, Graph-VAE, latent denoiser and diffusion heads.

        Args:
            st_gene_expr: ST expression matrix; copied to ``device`` as float32.
            st_coords: ST spatial coordinates, one row per spot.
            sc_gene_expr: SC expression matrix; copied to ``device`` as float32.
            cell_types_sc: optional per-cell labels. They are mapped to integer
                indices in ``np.unique`` order and embedded.
            transport_plan: optional OT plan. When given, it creates the
                learnable ``ot_guidance_strength``.
            D_st: optional ST distance matrix. When omitted it is computed from
                ``st_coords`` with a k=50 distance graph.
            D_induced: optional induced distance matrix, stored as-is.
            n_genes: input gene count (defaults to ``st_gene_expr.shape[1]``).
            n_embedding: encoder layer widths; the last entry is the aligned
                embedding size. It is a shared mutable default, so do not mutate it.
            coord_space_diameter: stored as an attribute; not used in this constructor.
            st_max_distance: stored as ``self.st_max_distance``. It is replaced by
                the computed value when ``D_st`` has to be calculated.
            sc_max_distance: stored as ``self.sc_max_distance``.
            sigma, alpha, mmdbatch: encoder-training hyper-parameters (see train_encoder).
            batch_size: batch size used when sampling during training.
            device: torch device for all tensors and modules.
            lr_e, lr_d: forwarded to ``setup_optimizers``.
            n_timesteps: number of steps in the diffusion noise schedule.
            n_denoising_blocks: depth of the latent denoiser and of the
                hierarchical block stack.
            hidden_dim: width of the diffusion heads.
            num_heads: accepted for compatibility; unused here.
            num_hierarchical_scales: ``num_scales`` for each HierarchicalDiffusionBlock.
            dp: forwarded to the feature encoder.
            outf: output directory for logs and checkpoints (created if missing).
        """
        super().__init__()

        # Loss tracking for coordinate-space diffusion training
        self.diffusion_losses = {
            'total': [],
            'diffusion': [],
            'struct': [],
            'physics': [],
            'uncertainty': [],
            'epochs': []
        }

        # Loss tracking for Graph-VAE training
        self.vae_losses = {
            'total': [],
            'reconstruction': [],
            'kl': [],
            'epochs': []
        }

        # Loss tracking for Latent Diffusion training
        self.latent_diffusion_losses = {
            'total': [],
            'diffusion': [],
            'struct': [],
            'epochs': []
        }

        # Encoder losses, tracked separately
        self.encoder_losses = {
            'total': [],
            'pred': [],
            'circle': [],
            'mmd': [],
            'epochs': []
        }

        self.device = device
        self.batch_size = batch_size
        self.n_timesteps = n_timesteps
        self.sigma = sigma
        self.alpha = alpha
        self.mmdbatch = mmdbatch
        self.n_embedding = n_embedding
        self.coord_space_diameter = coord_space_diameter
        # Always define these attributes. Previously the arguments were dropped,
        # and self.st_max_distance existed only when D_st was computed below.
        self.st_max_distance = st_max_distance
        self.sc_max_distance = sc_max_distance

        # Create the output directory (exist_ok avoids a check-then-create race)
        self.outf = outf
        os.makedirs(outf, exist_ok=True)

        # Store data
        self.st_gene_expr = torch.tensor(st_gene_expr, dtype=torch.float32).to(device)
        self.st_coords = torch.tensor(st_coords, dtype=torch.float32).to(device)
        self.sc_gene_expr = torch.tensor(sc_gene_expr, dtype=torch.float32).to(device)

        # Temperature regularization for geometric attention
        self.temp_weight_decay = 1e-4

        # Store transport plan if provided
        self.transport_plan = torch.tensor(transport_plan, dtype=torch.float32).to(device) if transport_plan is not None else None

        # Process cell types
        if cell_types_sc is not None:
            # Convert cell type labels to integer indices (np.unique order)
            unique_cell_types = np.unique(cell_types_sc)
            self.cell_type_to_idx = {ct: i for i, ct in enumerate(unique_cell_types)}
            self.num_cell_types = len(unique_cell_types)
            cell_type_indices = [self.cell_type_to_idx[ct] for ct in cell_types_sc]
            self.sc_cell_types = torch.tensor(cell_type_indices, dtype=torch.long).to(device)
        else:
            self.sc_cell_types = None
            self.num_cell_types = 0

        # Store distance matrices
        self.D_st = torch.tensor(D_st, dtype=torch.float32).to(device) if D_st is not None else None
        self.D_induced = torch.tensor(D_induced, dtype=torch.float32).to(device) if D_induced is not None else None

        # If D_st is not provided, calculate it from spatial coordinates
        if self.D_st is None:
            print("D_st not provided, calculating from spatial coordinates...")
            if isinstance(st_coords, torch.Tensor):
                # detach() so tensors that require grad can still be converted
                st_coords_np = st_coords.detach().cpu().numpy()
            else:
                st_coords_np = st_coords

            D_st_np, computed_st_max_distance = calculate_D_st_from_coords(
                spatial_coords=st_coords_np,
                k_neighbors=50,
                graph_mode="distance"
            )
            self.D_st = torch.tensor(D_st_np, dtype=torch.float32).to(device)
            self.st_max_distance = computed_st_max_distance
            print(f"D_st calculated, shape: {self.D_st.shape}")

        print(f"Final matrices - D_st: {self.D_st.shape if self.D_st is not None else None}, "
            f"D_induced: {self.D_induced.shape if self.D_induced is not None else None}")

        # Normalize coordinates
        self.st_coords_norm, self.coords_center, self.coords_radius = self.normalize_coordinates_isotropic(self.st_coords)
        # Model parameters
        self.n_genes = n_genes or st_gene_expr.shape[1]

        # ========== FEATURE ENCODER ==========
        self.netE = self.build_feature_encoder(self.n_genes, n_embedding, dp)

        self.train_log = os.path.join(outf, 'train.log')

        # ========== CELL TYPE EMBEDDING ==========
        self.use_cell_types = (cell_types_sc is not None)  # Check if SC data has cell types

        if self.num_cell_types > 0:
            self.cell_type_embedding = CellTypeEmbedding(self.num_cell_types, n_embedding[-1] // 2)
            total_feature_dim = n_embedding[-1] + n_embedding[-1] // 2
        else:
            self.cell_type_embedding = None
            total_feature_dim = n_embedding[-1]

        # ========== HIERARCHICAL DIFFUSION COMPONENTS ==========
        # Time embedding
        self.time_embed = nn.Sequential(
            SinusoidalEmbedding(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Coordinate encoder
        self.coord_encoder = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Feature projection (includes cell type if available)
        self.feat_proj = nn.Sequential(
            nn.Linear(total_feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # ========== GRAPH-VAE COMPONENTS ==========
        self.latent_dim = 32

        # Graph-VAE Encoder (learns latent representations from ST graphs)
        self.graph_vae_encoder = GraphVAEEncoder(
            input_dim=n_embedding[-1],  # Aligned embedding dimension
            hidden_dim=128,             # GraphConv hidden dimension
            latent_dim=self.latent_dim
        ).to(device)

        # Graph-VAE Decoder (outputs coordinates from latent + conditioning)
        self.graph_vae_decoder = GraphVAEDecoder(
            latent_dim=self.latent_dim,
            condition_dim=n_embedding[-1],  # Same as aligned embedding
            hidden_dim=128
        ).to(device)

        # Latent denoiser used by the latent-space diffusion stage
        self.latent_denoiser = LatentDenoiser(
            latent_dim=self.latent_dim,
            condition_dim=n_embedding[-1],
            hidden_dim=hidden_dim,
            n_blocks=n_denoising_blocks
        ).to(device)

        # ========== HIERARCHICAL DENOISING BLOCKS ==========
        self.hierarchical_blocks = nn.ModuleList([
            HierarchicalDiffusionBlock(hidden_dim, num_hierarchical_scales)
            for _ in range(n_denoising_blocks)
        ])

        # ========== PHYSICS-INFORMED COMPONENTS ==========
        self.physics_layer = PhysicsInformedLayer(hidden_dim)

        # ========== UNCERTAINTY QUANTIFICATION ==========
        self.uncertainty_head = UncertaintyHead(hidden_dim)

        # ========== OPTIMAL TRANSPORT GUIDANCE ==========
        if self.transport_plan is not None:
            self.ot_guidance_strength = nn.Parameter(torch.tensor(0.1))

        # ========== OUTPUT LAYERS ==========
        self.noise_predictor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2)
        )

        # Create noise schedule
        self.noise_schedule = self.create_noise_schedule()

        # Optimizers (must run after every module above is registered)
        self.setup_optimizers(lr_e, lr_d)

        # MMD Loss for domain alignment
        self.mmd_loss = MMDLoss()

        # Move entire model to device
        self.to(self.device)

    def setup_spatial_sampling(self):
        if hasattr(self, 'st_coords_norm'):
            self.spatial_sampler = SpatialBatchSampler(
                coordinates=self.st_coords_norm.cpu().numpy(),
                batch_size=self.batch_size
            )
        else:
            self.spatial_sampler = None

    def get_spatial_batch(self):
        """Get spatially contiguous batch for training"""
        if self.spatial_sampler is not None:
            return self.spatial_sampler.sample_spatial_batch()
        else:
            # Fallback to random sampling
            return torch.randperm(len(self.st_coords_norm))[:self.batch_size]
        
    def normalize_coordinates_isotropic(self, coords):
        """Normalize coordinates isotropically to [-1, 1]"""
        center = coords.mean(dim=0)
        centered_coords = coords - center
        max_dist = torch.max(torch.norm(centered_coords, dim=1))
        normalized_coords = centered_coords / (max_dist + 1e-8)
        return normalized_coords, center, max_dist
        

    def build_feature_encoder(self, n_genes, n_embedding, dp):
        """Instantiate the gene-expression encoder (``FeatureNet``) on ``self.device``.

        Args:
            n_genes: number of input genes.
            n_embedding: layer widths forwarded to FeatureNet.
            dp: forwarded to FeatureNet as ``dp`` (presumably a dropout rate — confirm).
        """
        encoder = FeatureNet(n_genes, n_embedding=n_embedding, dp=dp)
        return encoder.to(self.device)
        
    def create_noise_schedule(self):
        """Create the noise schedule for diffusion"""
        betas = torch.linspace(0.0001, 0.02, self.n_timesteps, device=self.device)
        alphas = 1 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        
        return {
            'betas': betas,
            'alphas': alphas,
            'alphas_cumprod': alphas_cumprod,
            'sqrt_alphas_cumprod': torch.sqrt(alphas_cumprod),
            'sqrt_one_minus_alphas_cumprod': torch.sqrt(1 - alphas_cumprod)
        }
        
    def setup_optimizers(self, lr_e, lr_d):
        """Create the encoder and diffusion optimizers/schedulers and the MMD loss.

        Sets ``optimizer_E``/``scheduler_E`` (AdamW + StepLR over ``netE``),
        ``mmd_fn``, and ``optimizer_diff``/``scheduler_diff`` (Adam + cosine
        warm restarts over the diffusion-head modules).

        Args:
            lr_e: NOTE(review): currently unused. The encoder learning rate is
                hard-coded to 0.002 below; confirm whether that is intended.
            lr_d: learning rate of the diffusion-head optimizer.
        """
        # Encoder optimizer (hard-coded lr kept as-is; see note above).
        self.optimizer_E = torch.optim.AdamW(self.netE.parameters(), lr=0.002)
        self.scheduler_E = lr_scheduler.StepLR(self.optimizer_E, step_size=200, gamma=0.5)

        # MMD loss used by train_encoder.
        self.mmd_fn = MMDLoss()

        # Modules trained by the diffusion optimizer, in a fixed order that
        # defines the optimizer's parameter ordering.
        diff_modules = [
            self.time_embed,
            self.coord_encoder,
            self.feat_proj,
            self.hierarchical_blocks,
            self.physics_layer,
            self.uncertainty_head,
            self.noise_predictor,
        ]
        if self.cell_type_embedding is not None:
            diff_modules.append(self.cell_type_embedding)

        diff_params = [param for module in diff_modules for param in module.parameters()]
        if self.transport_plan is not None:
            diff_params.append(self.ot_guidance_strength)

        self.optimizer_diff = torch.optim.Adam(diff_params, lr=lr_d, betas=(0.9, 0.999))
        self.scheduler_diff = lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer_diff, T_0=500)
        
    def add_noise(self, coords, t, noise_schedule):
        """Add noise to coordinates according to the diffusion schedule"""
        noise = torch.randn_like(coords)
        sqrt_alphas_cumprod_t = noise_schedule['sqrt_alphas_cumprod'][t].view(-1, 1)
        sqrt_one_minus_alphas_cumprod_t = noise_schedule['sqrt_one_minus_alphas_cumprod'][t].view(-1, 1)
        
        noisy_coords = sqrt_alphas_cumprod_t * coords + sqrt_one_minus_alphas_cumprod_t * noise
        return noisy_coords, noise
        
        
    def forward_diffusion(self, noisy_coords, t, features, cell_types=None):
        """Forward pass through the advanced diffusion model"""
        batch_size = noisy_coords.shape[0]
        
        # Encode inputs
        time_emb = self.time_embed(t)
        coord_emb = self.coord_encoder(noisy_coords)
        
        # Process features with optional cell type
        if cell_types is not None and self.cell_type_embedding is not None:
            cell_type_emb = self.cell_type_embedding(cell_types)
            combined_features = torch.cat([features, cell_type_emb], dim=-1)
        else:
            #when no cell types, pad with zeros to match expected input size
            if self.cell_type_embedding is not None:
                #create zero padding for cell type embedding
                cell_type_dim = self.n_embedding[-1] // 2
                zero_padding = torch.zeros(batch_size, cell_type_dim, device=features.device)
                combined_features = torch.cat([features, zero_padding], dim=-1)
            else:
                combined_features = features
            # combined_features = features
            
        feat_emb = self.feat_proj(combined_features)
        
        # Combine embeddings
        h = coord_emb + time_emb + feat_emb
        
        # Process through hierarchical blocks with geometric attention
        for i, block in enumerate(self.hierarchical_blocks):
            h = block(h, t)
                
        # Predict noise
        noise_pred = self.noise_predictor(h)
        
        # Compute physics-informed correction
        physics_correction, cell_radii = self.physics_layer(noisy_coords, h, cell_types)
        
        # Compute uncertainty
        uncertainty = self.uncertainty_head(h)
        
        # Apply corrections based on timestep (less physics at high noise)
        # t_factor = 1 - t / self.n_timesteps  # 0 at start, 1 at end
        # noise_pred = noise_pred + t_factor * physics_correction * 0.1
        t_factor = (1 - t).unsqueeze(-1) #shape: (natch_size, 1)
        noise_pred = noise_pred + t_factor * physics_correction * 0.1
        
        return noise_pred, uncertainty, cell_radii
        
    def train_encoder(self, n_epochs=1000, ratio_start=0, ratio_end=1.0):
        """Train the STEM encoder to align ST and SC data"""
        print("Training STEM encoder...")
        
        # Log training start
        with open(self.train_log, 'a') as f:
            localtime = time.asctime(time.localtime(time.time()))
            f.write(f"{localtime} - Starting STEM encoder training\n")
            f.write(f"n_epochs={n_epochs}, ratio_start={ratio_start}, ratio_end={ratio_end}\n")
        
        # Calculate spatial adjacency matrix
        if self.sigma == 0:
            nettrue = torch.eye(self.st_coords.shape[0], device=self.device)
        else:
            nettrue = torch.tensor(scipy.spatial.distance.cdist(
                self.st_coords.cpu().numpy(), 
                self.st_coords.cpu().numpy()
            ), device=self.device).to(torch.float32)
            
            sigma = self.sigma
            nettrue = torch.exp(-nettrue**2/(2*sigma**2))/(np.sqrt(2*np.pi)*sigma)
            nettrue = F.normalize(nettrue, p=1, dim=1)
        
        # Training loop
        for epoch in range(n_epochs):
            # Schedule for circle loss weight
            ratio = ratio_start + (ratio_end - ratio_start) * min(epoch / (n_epochs * 0.8), 1.0)
            
            # Forward pass ST data
            e_seq_st = self.netE(self.st_gene_expr, True)
            
            # Sample from SC data due to large size
            sc_idx = torch.randint(0, self.sc_gene_expr.shape[0], (min(self.batch_size, self.mmdbatch),), device=self.device)
            sc_batch = self.sc_gene_expr[sc_idx]
            e_seq_sc = self.netE(sc_batch, False)
            
            # Calculate losses
            self.optimizer_E.zero_grad()
            
            # Prediction loss (equivalent to netpred in STEM)
            netpred = e_seq_st.mm(e_seq_st.t())
            loss_E_pred = F.cross_entropy(netpred, nettrue, reduction='mean')
            
            # Mapping matrices
            st2sc = F.softmax(e_seq_st.mm(e_seq_sc.t()), dim=1)
            sc2st = F.softmax(e_seq_sc.mm(e_seq_st.t()), dim=1)
            
            # Circle loss
            st2st = torch.log(st2sc.mm(sc2st) + 1e-7)
            loss_E_circle = F.kl_div(st2st, nettrue, reduction='none').sum(1).mean()
            
            # MMD loss
            ranidx = torch.randint(0, e_seq_sc.shape[0], (min(self.mmdbatch, e_seq_sc.shape[0]),), device=self.device)
            loss_E_mmd = self.mmd_fn(e_seq_st, e_seq_sc[ranidx])
            
            # Total loss
            loss_E = loss_E_pred + self.alpha * loss_E_mmd + ratio * loss_E_circle
            
            # Backward and optimize
            loss_E.backward()
            self.optimizer_E.step()
            self.scheduler_E.step()
            
            # Log progress
            if epoch % 200 == 0:
                log_msg = (f"Encoder epoch {epoch}/{n_epochs}, "
                          f"Loss_E: {loss_E.item():.6f}, "
                          f"Loss_E_pred: {loss_E_pred.item():.6f}, "
                          f"Loss_E_circle: {loss_E_circle.item():.6f}, "
                          f"Loss_E_mmd: {loss_E_mmd.item():.6f}, "
                          f"Ratio: {ratio:.4f}")
                
                print(log_msg)
                with open(self.train_log, 'a') as f:
                    f.write(log_msg + '\n')
                
                # Save checkpoint
                if epoch % 500 == 0:
                    torch.save({
                        'epoch': epoch,
                        'netE_state_dict': self.netE.state_dict(),
                        'optimizer_state_dict': self.optimizer_E.state_dict(),
                        'scheduler_state_dict': self.scheduler_E.state_dict(),
                    }, os.path.join(self.outf, f'encoder_checkpoint_epoch_{epoch}.pt'))
        
        # Save final encoder
        torch.save({
            'netE_state_dict': self.netE.state_dict(),
        }, os.path.join(self.outf, 'final_encoder.pt'))
        
        print("Encoder training complete!")

    def train_graph_vae(self, epochs=800, lr=1e-3, lambda_recon=1.0, lambda_kl=0.01):
        """
        Train the Graph-VAE that learns the outline.
        ⚠️ Do **not** touch `train_encoder`; its aligned embeddings are the sole conditioning signal throughout.

        The frozen encoder's ST embeddings are encoded over a k=30 spatial
        graph into per-node latents. The decoder maps those latents (plus the
        embeddings) back to normalised ST coordinates. The loss is
        ``lambda_recon * MSE(coords_pred, st_coords_norm) + lambda_kl * KL``,
        with KL summed over latent dimensions and averaged over nodes.

        Args:
            epochs: number of full-graph training epochs.
            lr: Adam learning rate, annealed with CosineAnnealingLR over ``epochs``.
            lambda_recon: weight of the reconstruction term.
            lambda_kl: weight of the KL term.

        Side effects: appends to ``self.vae_losses`` and sets
        ``self._current_adj_idx``. Writes ``best_graph_vae.pt`` (state at the
        lowest training loss, saved even if training is interrupted) and
        ``final_graph_vae.pt`` to ``self.outf``.
        """
        print("Training Graph-VAE...")

        # Freeze the alignment encoder: its embeddings are fixed conditioning.
        self.netE.eval()
        for param in self.netE.parameters():
            param.requires_grad = False

        # train_diffusion_latent() puts the VAE encoder in eval mode and freezes
        # it. Undo that so (re)training actually updates both VAE halves.
        for module in (self.graph_vae_encoder, self.graph_vae_decoder):
            module.train()
            for param in module.parameters():
                param.requires_grad = True

        # Build graph from ST spatial coordinates
        print("Building spatial graph for ST data...")
        adj_idx, adj_w = precompute_knn_edges(self.st_coords_norm, k=30, device=self.device)
        self._current_adj_idx = adj_idx

        # Aligned embeddings E(X_st) from the trained encoder, as float32
        with torch.no_grad():
            st_features_aligned = self.netE(self.st_gene_expr).float()

        vae_params = list(self.graph_vae_encoder.parameters()) + list(self.graph_vae_decoder.parameters())
        optimizer_vae = torch.optim.Adam(vae_params, lr=lr, weight_decay=1e-5)
        scheduler_vae = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_vae, T_max=epochs)

        best_loss = float('inf')
        best_state = None

        try:
            for epoch in range(epochs):
                optimizer_vae.zero_grad()

                # Node-level posterior parameters over the spatial graph
                mu, logvar = self.graph_vae_encoder(st_features_aligned, adj_idx, adj_w)
                z = self.graph_vae_encoder.reparameterize(mu, logvar)

                # Decode to coordinates, conditioned on the aligned embeddings
                coords_pred = self.graph_vae_decoder(z, st_features_aligned)

                L_recon = torch.nn.functional.mse_loss(coords_pred, self.st_coords_norm)
                # KL(q(z|x) || N(0, I)): summed over dims, averaged over nodes
                L_KL = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / mu.shape[0]

                total_loss = lambda_recon * L_recon + lambda_kl * L_KL
                total_item = total_loss.item()

                # Record losses for plotting
                self.vae_losses['total'].append(total_item)
                self.vae_losses['reconstruction'].append(L_recon.item())
                self.vae_losses['kl'].append(L_KL.item())
                self.vae_losses['epochs'].append(epoch)

                total_loss.backward()
                optimizer_vae.step()
                scheduler_vae.step()

                if epoch % 200 == 0:
                    log_msg = (f"Graph-VAE epoch {epoch}/{epochs}, "
                            f"Total Loss: {total_item:.6f}, "
                            f"L_recon: {L_recon.item():.6f}, "
                            f"L_KL: {L_KL.item():.6f}")
                    print(log_msg)
                    with open(self.train_log, 'a') as f:
                        f.write(log_msg + '\n')

                # Keep the best state in memory instead of writing a checkpoint
                # on every improving epoch (nearly every epoch early on).
                if total_item < best_loss:
                    best_loss = total_item
                    best_state = {
                        'encoder_state_dict': {k: v.clone() for k, v in self.graph_vae_encoder.state_dict().items()},
                        'decoder_state_dict': {k: v.clone() for k, v in self.graph_vae_decoder.state_dict().items()},
                        'epoch': epoch,
                        'loss': best_loss
                    }
        finally:
            # Written even if training is interrupted (e.g. a notebook interrupt).
            if best_state is not None:
                torch.save(best_state, os.path.join(self.outf, 'best_graph_vae.pt'))

        # Save final models
        torch.save({
            'encoder_state_dict': self.graph_vae_encoder.state_dict(),
            'decoder_state_dict': self.graph_vae_decoder.state_dict(),
        }, os.path.join(self.outf, 'final_graph_vae.pt'))

        print("Graph-VAE training complete!")

    def train_diffusion_latent(self, n_epochs=400, lambda_struct=10.0):
        """
        Train latent-space conditional DDPM.
        ⚠️ Do **not** touch `train_encoder`; its aligned embeddings are the sole conditioning signal throughout.
        """
        print("Training latent-space diffusion model...")
        
        # Freeze encoder and Graph-VAE encoder
        self.netE.eval()
        self.graph_vae_encoder.eval()
        for param in self.netE.parameters():
            param.requires_grad = False
        for param in self.graph_vae_encoder.parameters():
            param.requires_grad = False
        
        # Precompute fixed ST latents as specified
        print("Computing fixed ST latents...")
        st_adj_idx, st_adj_w = precompute_knn_edges(self.st_coords_norm, k=30, device=self.device)
        
        with torch.no_grad():
            st_features_aligned = self.netE(self.st_gene_expr)
            # ENSURE FLOAT32 DTYPE
            st_features_aligned = st_features_aligned.float()
            st_mu, st_logvar = self.graph_vae_encoder(st_features_aligned, st_adj_idx, st_adj_w)
            z_st = self.graph_vae_encoder.reparameterize(st_mu, st_logvar)
        
        # Setup optimizer for latent denoiser
        optimizer_latent = torch.optim.AdamW(
            self.latent_denoiser.parameters(), 
            lr=self.optimizer_diff.param_groups[0]['lr'], 
            weight_decay=1e-5
        )
        scheduler_latent = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer_latent, T_max=n_epochs, eta_min=1e-6
        )
        
        # Training loop - identical to old train_diffusion but in latent space
        best_loss = float('inf')
        
        for epoch in range(n_epochs):
            # Sample batch
            idx = torch.randperm(len(z_st))[:self.batch_size]
            batch_z_st = z_st[idx]
            batch_st_features = st_features_aligned[idx]
            
            # Sample random timesteps
            t = torch.randint(0, self.n_timesteps, (len(batch_z_st),), device=self.device)
            
            # Add noise to latent vectors (targets = random noise in latent space)
            noise = torch.randn_like(batch_z_st)
            
            # Get noise schedule parameters
            alpha_t = self.noise_schedule['alphas_cumprod'][t].view(-1, 1)
            
            # Forward diffusion: add noise to clean latents
            z_noisy = torch.sqrt(alpha_t) * batch_z_st + torch.sqrt(1 - alpha_t) * noise
            
            # Predict noise using latent denoiser
            t_normalized = t.float().unsqueeze(1) / self.n_timesteps
            noise_pred = self.latent_denoiser(z_noisy, t_normalized, batch_st_features)
            
            # Compute diffusion loss
            loss_diffusion = torch.nn.functional.mse_loss(noise_pred, noise)
            
            # Structure loss in latent space (optional, can be simplified)
            # loss_struct = 0.0
            if lambda_struct > 0:
                # Simple latent space structure loss
                latent_distances = torch.cdist(batch_z_st, batch_z_st, p=2)
                pred_distances = torch.cdist(noise_pred, noise_pred, p=2)
                loss_struct = torch.nn.functional.mse_loss(pred_distances, latent_distances)
            
            # Combined loss
            total_loss = loss_diffusion + lambda_struct * loss_struct

            # Record losses for plotting
            self.latent_diffusion_losses['total'].append(total_loss.item())
            self.latent_diffusion_losses['diffusion'].append(loss_diffusion.item())
            self.latent_diffusion_losses['struct'].append(loss_struct.item() if isinstance(loss_struct, torch.Tensor) else loss_struct)
            self.latent_diffusion_losses['epochs'].append(epoch)
            
            # Backward pass
            optimizer_latent.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.latent_denoiser.parameters(), 1.0)
            optimizer_latent.step()
            scheduler_latent.step()
            
            # Logging
            if epoch % 500 == 0:
                log_msg = (f"Latent Diffusion epoch {epoch}/{n_epochs}, "
                        f"Total Loss: {total_loss.item():.6f}, "
                        f"Diffusion Loss: {loss_diffusion.item():.6f}, "
                        f"Struct Loss: {loss_struct:.6f}" if isinstance(loss_struct, float) 
                        else f"Struct Loss: {loss_struct.item():.6f}")
                print(log_msg)
                with open(self.train_log, 'a') as f:
                    f.write(log_msg + '\n')
            
            # Save checkpoint
            if epoch % 500 == 0:
                torch.save({
                    'epoch': epoch,
                    'latent_denoiser_state_dict': self.latent_denoiser.state_dict(),
                    'optimizer_state_dict': optimizer_latent.state_dict(),
                }, os.path.join(self.outf, f'latent_diffusion_checkpoint_epoch_{epoch}.pt'))
        
        # Save final model
        torch.save({
            'latent_denoiser_state_dict': self.latent_denoiser.state_dict(),
        }, os.path.join(self.outf, 'final_latent_diffusion.pt'))
        
        print("Latent diffusion training complete!")
                        

    def train(self, encoder_epochs=1000, vae_epochs=800, diffusion_epochs=400, **kwargs):
        """
        Run the three training stages in order: encoder → graph_vae → diffusion_latent.

        Extra keyword arguments are forwarded only to `train_diffusion_latent`.
        ⚠️ Do **not** touch `train_encoder`; its aligned embeddings are the sole conditioning signal throughout.
        """
        print("Starting Graph-VAE + Latent Diffusion training pipeline...")

        # Each stage is (banner, thunk). The thunks look up the bound methods
        # lazily, so the execution order is exactly the listed order.
        pipeline = (
            ("Stage 1: Training domain alignment encoder...",
             lambda: self.train_encoder(n_epochs=encoder_epochs)),
            ("Stage 2: Training Graph-VAE...",
             lambda: self.train_graph_vae(epochs=vae_epochs)),
            ("Stage 3: Training latent diffusion...",
             lambda: self.train_diffusion_latent(n_epochs=diffusion_epochs, **kwargs)),
        )
        for banner, run_stage in pipeline:
            print(banner)
            run_stage()

        print("Complete training pipeline finished!")

    def sample_sc_coordinates_batched(self, batch_size=512):
        """
        Batched version of sample_sc_coordinates to handle memory constraints.
        ⚠️ Do **not** touch `train_encoder`; its aligned embeddings are the sole conditioning signal throughout.

        Processes ``self.sc_gene_expr`` in consecutive chunks of ``batch_size`` cells.
        For each chunk:
          1. embed it with ``self.netE``;
          2. build a kNN graph (k=30) over the chunk's expression;
          3. encode to Graph-VAE latents;
          4. noise those latents to the final timestep;
          5. run the full reverse DDPM chain with ``self.latent_denoiser``
             (conditioned on the aligned embeddings);
          6. decode with ``self.graph_vae_decoder``.

        Args:
            batch_size: number of SC cells processed per chunk.

        Returns:
            numpy array of decoder outputs for all cells, concatenated along dim 0
            in the original cell order.

        Notes:
            - Puts netE and the Graph-VAE / denoiser modules in eval mode and
              leaves them there.
            - The kNN graph is built within each chunk only, so the graph (and
              therefore the latents) can depend on ``batch_size``.
            - If ``self.sc_gene_expr`` is empty, ``torch.cat`` receives an empty
              list and raises.
        """
        # NOTE(review): presumably self.sc_gene_expr is a tensor already on
        # self.device — confirm against the constructor.
        n_total = len(self.sc_gene_expr)
        print(f"Sampling {n_total} SC coordinates using Graph-VAE + Latent Diffusion (batched)...")
        
        # Set models to eval mode
        self.netE.eval()
        self.graph_vae_encoder.eval()
        self.graph_vae_decoder.eval()
        self.latent_denoiser.eval()
        
        all_coords = []
        # Ceiling division: the last batch may be smaller than batch_size.
        n_batches = (n_total + batch_size - 1) // batch_size
        
        with torch.no_grad():
            for batch_idx in range(n_batches):
                start_idx = batch_idx * batch_size
                end_idx = min(start_idx + batch_size, n_total)
                batch_sc_expr = self.sc_gene_expr[start_idx:end_idx]
                
                print(f"Processing batch {batch_idx + 1}/{n_batches} ({len(batch_sc_expr)} cells)...")
                
                # Get aligned SC embeddings for this batch
                sc_features_aligned = self.netE(batch_sc_expr).float()
                
                # Build SC graph for this batch in expression space
                # NOTE(review): a final batch with very few cells (<= k) may be too
                # small for k=30 — confirm precompute_knn_edges handles that case.
                sc_adj_idx, sc_adj_w = precompute_knn_edges(batch_sc_expr, k=30, device=self.device)
                
                # Get SC latents using Graph-VAE encoder
                sc_mu, sc_logvar = self.graph_vae_encoder(sc_features_aligned, sc_adj_idx, sc_adj_w)
                z_sc = self.graph_vae_encoder.reparameterize(sc_mu, sc_logvar)
                
                # ENSURE z_sc is 2D
                if z_sc.dim() > 2:
                    z_sc = z_sc.squeeze()
                
                # Initialize random noise in latent space
                # z_t = torch.randn_like(z_sc)
                # Instead of starting from pure random noise, start from noised Graph-VAE latents
                # (z_T = sqrt(ᾱ_T) * z_sc + sqrt(1 - ᾱ_T) * eps, i.e. the forward process
                # applied to the encoder latent), so sampling stays anchored to z_sc.
                eps = torch.randn_like(z_sc)
                alpha_bar_T = self.noise_schedule['alphas_cumprod'][self.n_timesteps - 1]  # ᾱ_T (final timestep)
                z_t = (alpha_bar_T.sqrt() * z_sc) + ((1 - alpha_bar_T).sqrt() * eps)
                
                # Reverse diffusion process in latent space
                for t in reversed(range(self.n_timesteps)):
                    # FIX: Use 1D timestep tensor instead of 2D
                    # NOTE(review): training feeds the denoiser t as shape (B, 1)
                    # (t.float().unsqueeze(1) / n_timesteps), while here it is (B,).
                    # Confirm latent_denoiser treats both identically; the shape
                    # guards below suggest this mismatch once produced 3-D outputs.
                    t_tensor = torch.full((len(z_sc),), t / self.n_timesteps, device=self.device)
                    
                    # Predict noise in latent space
                    noise_pred = self.latent_denoiser(z_t, t_tensor, sc_features_aligned)
                    
                    # ENSURE noise_pred is 2D
                    if noise_pred.dim() > 2:
                        noise_pred = noise_pred.squeeze()
                    
                    # Update latent representation (standard DDPM reverse step)
                    alpha_t = self.noise_schedule['alphas'][t]
                    alpha_cumprod_t = self.noise_schedule['alphas_cumprod'][t]
                    beta_t = self.noise_schedule['betas'][t]
                    
                    # No fresh noise is added on the last step (t == 0).
                    if t > 0:
                        noise = torch.randn_like(z_t)
                    else:
                        noise = 0
                        
                    # Posterior mean from the predicted noise, plus sigma_t * noise with sigma_t^2 = beta_t.
                    z_t = (1 / torch.sqrt(alpha_t)) * (
                        z_t - ((1 - alpha_t) / torch.sqrt(1 - alpha_cumprod_t)) * noise_pred
                    ) + torch.sqrt(beta_t) * noise
                    
                    # ENSURE z_t stays 2D
                    if z_t.dim() > 2:
                        z_t = z_t.squeeze()
                
                    # ENSURE both inputs to decoder are 2D - FORCE CORRECT SHAPES
                    # NOTE(review): these guards run on every reverse step (they are
                    # inside the timestep loop), and the diagonal extraction is a
                    # workaround for an upstream broadcasting issue rather than a fix.
                    if z_t.dim() > 2:
                        # If z_t is [256, 256, 32], we want [256, 32] - take diagonal or first slice
                        if z_t.shape[0] == z_t.shape[1]:  # Square matrix case
                            # Take diagonal elements to get [256, 32]
                            z_t = torch.diagonal(z_t, dim1=0, dim2=1).T
                        else:
                            z_t = z_t.squeeze()

                    if sc_features_aligned.dim() > 2:
                        sc_features_aligned = sc_features_aligned.squeeze()

                    # print(f"DEBUG AFTER FIX: z_t.shape = {z_t.shape}, sc_features_aligned.shape = {sc_features_aligned.shape}")

                # Decode final latent to 2D coordinates using Graph-VAE decoder
                batch_coords = self.graph_vae_decoder(z_t, sc_features_aligned)
                
                # Move to CPU and store
                all_coords.append(batch_coords.cpu())
                
                # Clear GPU cache between batches
                torch.cuda.empty_cache()
        
        # Combine all batches
        final_coords = torch.cat(all_coords, dim=0)
                
        print("Batched sampling complete!")
        return final_coords.cpu().numpy()


    def plot_training_losses(self):
        """Plot Graph-VAE and latent-diffusion training curves and print a loss summary.

        Reads ``self.vae_losses`` (keys 'epochs', 'total', 'reconstruction', 'kl')
        and ``self.latent_diffusion_losses`` (keys 'epochs', 'total', 'diffusion',
        'struct'). Each history with at least one recorded epoch gets two panels:
        its component losses, and a moving-average-smoothed total loss. Both panels
        use a log y-axis. Panels that are allocated but not drawn are hidden. If
        neither history has entries, a message is printed and nothing is plotted.
        """
        import matplotlib.pyplot as plt
        import numpy as np

        def _plot_smoothed_total(ax, epochs, totals, raw_color, smooth_color, title):
            # Moving-average window: ~10% of the history, capped at 50 points.
            window = min(50, len(totals) // 10)
            if window > 1:
                smoothed = np.convolve(totals, np.ones(window) / window, mode='valid')
                ax.plot(epochs, totals, raw_color, alpha=0.5, label='Raw')
                ax.plot(epochs[window - 1:], smoothed, smooth_color, linewidth=2,
                        label=f'Smoothed (window={window})')
            else:
                ax.plot(epochs, totals, smooth_color, linewidth=2)
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            ax.set_title(title)
            if window > 1:
                ax.legend()
            ax.grid(True, alpha=0.3)
            ax.set_yscale('log')

        def _format_reduction(history):
            # Percentage drop from the first to the last total loss. Returns "n/a"
            # instead of raising ZeroDivisionError when the first loss is 0.
            initial, final = history[0], history[-1]
            if initial == 0:
                return "n/a"
            return f"{(1 - final / initial) * 100:.2f}%"

        vae = self.vae_losses
        diff = self.latent_diffusion_losses
        has_vae = len(vae['epochs']) > 0
        has_diff = len(diff['epochs']) > 0
        n_plots = 2 * int(has_vae) + 2 * int(has_diff)

        if n_plots == 0:
            print("No training losses to plot.")
            return

        if n_plots == 4:
            fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        else:
            fig, axes = plt.subplots(1, 2, figsize=(15, 5))
        # plt.subplots returns an ndarray of Axes for a multi-panel grid. Flatten it
        # so panels are addressed by a running index. The previous code wrapped the
        # 1x2 array in a list, so axes[0] was the array itself and ax.plot(...) raised.
        axes = list(axes.ravel()) if isinstance(axes, np.ndarray) else [axes]

        plot_idx = 0

        # Graph-VAE: component losses, then smoothed total.
        if has_vae:
            epochs_vae = np.array(vae['epochs'])
            ax = axes[plot_idx]
            ax.plot(epochs_vae, vae['total'], 'b-', label='Total VAE Loss', linewidth=2)
            ax.plot(epochs_vae, vae['reconstruction'], 'g-', label='Reconstruction Loss', linewidth=2)
            # KL is scaled down so it fits on the same axis as the other terms.
            ax.plot(epochs_vae, np.array(vae['kl']) * 0.01, 'r--', label='KL Loss (×0.01)', alpha=0.8)
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            ax.set_title('Graph-VAE Training Losses')
            ax.legend()
            ax.grid(True, alpha=0.3)
            ax.set_yscale('log')
            plot_idx += 1

            if len(vae['total']) > 1:
                _plot_smoothed_total(axes[plot_idx], epochs_vae, vae['total'],
                                     'lightblue', 'blue', 'Graph-VAE Total Loss (Smoothed)')
                plot_idx += 1

        # Latent diffusion: component losses, then smoothed total.
        if has_diff:
            epochs_diff = np.array(diff['epochs'])
            ax = axes[plot_idx]
            ax.plot(epochs_diff, diff['total'], 'b-', label='Total Loss', linewidth=2)
            ax.plot(epochs_diff, diff['diffusion'], 'k-', label='Diffusion Loss', linewidth=2)
            # Only plot the structure loss when it was actually active (non-zero).
            struct_losses = np.array(diff['struct'])
            if np.any(struct_losses > 0):
                ax.plot(epochs_diff, struct_losses, 'r--', label='Structure Loss', alpha=0.8)
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            ax.set_title('Latent Diffusion Training Losses')
            ax.legend()
            ax.grid(True, alpha=0.3)
            ax.set_yscale('log')
            plot_idx += 1

            if len(diff['total']) > 1:
                _plot_smoothed_total(axes[plot_idx], epochs_diff, diff['total'],
                                     'lightcoral', 'red', 'Latent Diffusion Total Loss (Smoothed)')
                plot_idx += 1

        # Hide panels that were allocated but not drawn (e.g. a single-entry history).
        for unused_ax in axes[plot_idx:]:
            unused_ax.set_visible(False)

        plt.tight_layout()
        plt.show()

        # Print final loss values
        print("\n=== Training Loss Summary ===")

        if len(vae['total']) > 0:
            print(f"Graph-VAE - Initial Loss: {vae['total'][0]:.6f}")
            print(f"Graph-VAE - Final Loss: {vae['total'][-1]:.6f}")
            print(f"Graph-VAE - Loss Reduction: {_format_reduction(vae['total'])}")
            print(f"Graph-VAE - Final Reconstruction Loss: {vae['reconstruction'][-1]:.6f}")
            print(f"Graph-VAE - Final KL Loss: {vae['kl'][-1]:.6f}")

        if len(diff['total']) > 0:
            print(f"Latent Diffusion - Initial Loss: {diff['total'][0]:.6f}")
            print(f"Latent Diffusion - Final Loss: {diff['total'][-1]:.6f}")
            print(f"Latent Diffusion - Loss Reduction: {_format_reduction(diff['total'])}")
            print(f"Latent Diffusion - Final Diffusion Loss: {diff['diffusion'][-1]:.6f}")
            if np.any(np.array(diff['struct']) > 0):
                print(f"Latent Diffusion - Final Structure Loss: {diff['struct'][-1]:.6f}")


# =====================================================
# PART 4: MMD Loss Implementation
# =====================================================

class MMDLoss(nn.Module):
    """Maximum Mean Discrepancy (MMD) loss between two batches of samples.

    Args:
        kernel_type: 'rbf' for a sum of Gaussian kernels at several bandwidths,
            or 'linear' for the squared distance between the batch means.
        kernel_mul: ratio between consecutive RBF bandwidths.
        kernel_num: number of RBF bandwidths that are summed.
        fix_sigma: base bandwidth to use instead of the data-driven one (the mean
            off-diagonal squared distance). Falsy values select the data-driven
            bandwidth.
    """

    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        super(MMDLoss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = fix_sigma
        self.kernel_type = kernel_type

    def gaussian_kernel(self, source, target, kernel_mul, kernel_num, fix_sigma):
        """Return the (n+m, n+m) multi-bandwidth RBF kernel over ``[source; target]``.

        The bandwidths are ``base * kernel_mul**i / kernel_mul**(kernel_num // 2)``
        for ``i in range(kernel_num)``, so they are spread around ``base``.
        """
        n_samples = int(source.size()[0]) + int(target.size()[0])
        total = torch.cat([source, target], dim=0)
        # Pairwise squared Euclidean distances via broadcasting:
        # L2_distance[i, j] = ||total[i] - total[j]||^2.
        L2_distance = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(2)
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            # Mean squared distance over off-diagonal pairs. It is detached so that
            # no gradient flows through the bandwidth choice.
            bandwidth = torch.sum(L2_distance.detach()) / (n_samples ** 2 - n_samples)
        bandwidth = bandwidth / kernel_mul ** (kernel_num // 2)
        return sum(torch.exp(-L2_distance / (bandwidth * (kernel_mul ** i)))
                   for i in range(kernel_num))

    # Backward-compatible alias for the original (misspelled) method name.
    guassian_kernel = gaussian_kernel

    def linear_mmd2(self, f_of_X, f_of_Y):
        """Squared Euclidean distance between the batch means of X and Y."""
        delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0)
        # delta is 1-D. Tensor.T on a 1-D tensor is deprecated (and was a no-op here).
        return delta.dot(delta)

    def forward(self, source, target):
        """Compute the MMD between ``source`` and ``target`` (first dim = samples).

        Raises:
            ValueError: if ``kernel_type`` is neither 'rbf' nor 'linear'.
        """
        if self.kernel_type == 'linear':
            return self.linear_mmd2(source, target)
        if self.kernel_type == 'rbf':
            batch_size = int(source.size()[0])
            kernels = self.gaussian_kernel(
                source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
            XX = torch.mean(kernels[:batch_size, :batch_size])
            YY = torch.mean(kernels[batch_size:, batch_size:])
            XY = torch.mean(kernels[:batch_size, batch_size:])
            YX = torch.mean(kernels[batch_size:, :batch_size])
            return torch.mean(XX + YY - XY - YX)
        # Previously an unknown kernel_type silently returned None, which only
        # surfaced later as a confusing error in the loss arithmetic.
        raise ValueError(f"Unsupported kernel_type {self.kernel_type!r}; expected 'rbf' or 'linear'.")

In [None]:
def load_and_process_cscc_data_individual_norm(data_dir='/home/ehtesamul/sc_st/data/cSCC/processed'):
    """Load the cSCC single-cell and spatial datasets and log-normalise expression.

    Every dataset gets ``sc.pp.normalize_total`` followed by ``sc.pp.log1p``,
    applied in place. The SC object additionally gains
    ``obs['rough_celltype']``, derived from ``level1_celltype`` and then
    overridden by ``level2_celltype`` for TSK / tumour-keratinocyte labels.

    Args:
        data_dir: directory containing scP2.h5ad, stP2.h5ad, stP2rep2.h5ad and
            stP2rep3.h5ad. Defaults to the previously hard-coded location.

    Returns:
        (scadata, stadata1, stadata2, stadata3)
    """
    print("Loading cSCC data with individual normalization...")

    def _read(file_name):
        return sc.read_h5ad(os.path.join(data_dir, file_name))

    # Load SC data, then the 3 ST replicates
    scadata = _read('scP2.h5ad')
    stadata1 = _read('stP2.h5ad')
    stadata2 = _read('stP2rep2.h5ad')
    stadata3 = _read('stP2rep3.h5ad')

    # Normalize expression data (same for all)
    for adata in (scadata, stadata1, stadata2, stadata3):
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)

    # Coarse cell types: level1 labels collapsed by this table (unlisted labels
    # are kept as-is), then level2 labels take precedence for tumour cells.
    level1_to_rough = {
        'CLEC9A': 'DC', 'CD1C': 'DC', 'ASDC': 'DC', 'MDSC': 'DC', 'LC': 'DC',
        'PDC': 'PDC',
        'Mac': 'Myeloid cell',
        'Tcell': 'T cell',
    }
    non_tsk_labels = ['Tumor_KC_Basal', 'Tumor_KC_Diff', 'Tumor_KC_Cyc']

    obs = scadata.obs
    level2 = obs['level2_celltype']
    rough = obs['level1_celltype'].astype(str).map(lambda label: level1_to_rough.get(label, label))
    rough = rough.mask(level2 == 'TSK', 'TSK')
    rough = rough.mask(level2.isin(non_tsk_labels), 'NonTSK')
    scadata.obs['rough_celltype'] = rough

    return scadata, stadata1, stadata2, stadata3

def normalize_coordinates_individually(coords):
    """
    Min-max scale each coordinate column of ``coords`` into the range [-1, 1].

    A column with zero extent is divided by 1 instead of 0, so all its values map to -1.

    Returns:
        (normalized coords, per-column min, per-column max,
         per-column range with zero entries replaced by 1)
    """
    lo = coords.min(axis=0)
    hi = coords.max(axis=0)

    span = hi - lo
    degenerate = span == 0
    span[degenerate] = 1.0  # avoid 0/0 on constant columns

    return 2 * (coords - lo) / span - 1, lo, hi, span

def prepare_individually_normalized_st_data(stadata1, stadata2, stadata3, scadata):
    """
    Align SC and ST data on shared genes, scale each ST replicate's coordinates
    on its own to [-1, 1], and stack the three replicates.

    Returns:
        X_sc: torch.float32 tensor of SC expression restricted to common genes.
        X_st_combined: torch.float32 tensor of the three ST replicates stacked row-wise.
        Y_st_combined: float32 ndarray of the stacked, per-replicate-normalized coords.
        dataset_info: dict with per-spot 'labels' ('dataset1'/'dataset2'/'dataset3')
            and per-replicate 'normalization_params' ({'min', 'max', 'range'}).
        common_genes: sorted list of genes present in all four datasets.
    """
    print("Preparing individually normalized ST data...")

    replicates = {'dataset1': stadata1, 'dataset2': stadata2, 'dataset3': stadata3}

    # Genes shared by the SC reference and every ST replicate
    shared_genes = set(scadata.var_names)
    for adata in replicates.values():
        shared_genes &= set(adata.var_names)
    common_genes = sorted(shared_genes)
    print(f"Common genes across all datasets: {len(common_genes)}")

    def to_dense(adata):
        mat = adata[:, common_genes].X
        return mat.toarray() if hasattr(mat, 'toarray') else mat

    sc_expr = to_dense(scadata)
    st_exprs = {name: to_dense(adata) for name, adata in replicates.items()}

    print("Normalizing coordinates individually...")
    norm_results = {
        name: normalize_coordinates_individually(adata.obsm['spatial'])
        for name, adata in replicates.items()
    }

    for rank, (coords_norm, _lo, _hi, _span) in enumerate(norm_results.values(), start=1):
        print(f"ST{rank} coord range: [{coords_norm.min():.3f}, {coords_norm.max():.3f}]")

    # Stack replicates in dataset1, dataset2, dataset3 order
    st_expr_combined = np.vstack(list(st_exprs.values()))
    st_coords_combined = np.vstack([result[0] for result in norm_results.values()])

    labels = []
    for name, expr in st_exprs.items():
        labels.extend([name] * len(expr))

    dataset_info = {
        'labels': labels,
        'normalization_params': {
            name: {'min': lo, 'max': hi, 'range': span}
            for name, (_coords, lo, hi, span) in norm_results.items()
        },
    }

    print(f"Combined ST data shape: {st_expr_combined.shape}")
    print(f"Combined ST coords shape: {st_coords_combined.shape}")
    print(f"SC data shape: {sc_expr.shape}")

    X_sc = torch.tensor(sc_expr, dtype=torch.float32)
    X_st_combined = torch.tensor(st_expr_combined, dtype=torch.float32)
    Y_st_combined = st_coords_combined.astype(np.float32)

    return X_sc, X_st_combined, Y_st_combined, dataset_info, common_genes

In [None]:
# Load the log-normalised SC reference plus the three ST replicates.
(scadata, stadata1, stadata2, stadata3) = load_and_process_cscc_data_individual_norm()
# Prefer a GPU when one is available; otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def analyze_cell_interactions_advanced(scadata, coords_key='advanced_diffusion_coords_avg'):
    """Compute, plot and return the minimum distance between every pair of cell types.

    Args:
        scadata: AnnData-like object providing ``obsm[coords_key]`` (per-cell
            coordinates) and ``obs['rough_celltype']``.
        coords_key: obsm key holding the coordinates to analyse.

    Returns:
        min_distances: dict mapping ``(type_a, type_b)`` (with ``type_a`` not after
            ``type_b`` in ``np.unique`` order) to the smallest distance between a
            cell of ``type_a`` and a cell of ``type_b``. When a type is paired with
            itself, each cell's zero self-distance is excluded. A type with a
            single cell therefore gets NaN.
        interaction_matrix: symmetric (n_types, n_types) array of the same values.
            Rows and columns are in ``np.unique`` order.

    Note:
        Builds the full cell-by-cell distance matrix, so memory grows as O(n_cells^2).
    """
    from scipy.spatial.distance import pdist, squareform

    # Get coordinates from scadata and compute all pairwise distances
    coords_mean = scadata.obsm[coords_key]
    distances = squareform(pdist(coords_mean))

    # np.asarray so masks are plain numpy bool arrays even for categorical/extension columns
    cell_types = np.asarray(scadata.obs['rough_celltype'])
    unique_types = np.unique(cell_types)
    n_types = len(unique_types)

    min_distances = {}
    interaction_matrix = np.full((n_types, n_types), np.nan)
    for i, type1 in enumerate(unique_types):
        mask1 = cell_types == type1
        for j in range(i, n_types):  # upper triangle, including self-pairs
            type2 = unique_types[j]
            sub_dist = distances[np.ix_(mask1, cell_types == type2)]
            if i == j:
                # Same cell type: exclude each cell's distance to itself.
                np.fill_diagonal(sub_dist, np.inf)
                off_diag = sub_dist[sub_dist < np.inf]
                # A singleton type has no other cell to compare against. The
                # previous code called np.min on an empty array and raised ValueError.
                min_dist = np.min(off_diag) if off_diag.size > 0 else np.nan
            else:
                min_dist = np.min(sub_dist) if sub_dist.size > 0 else np.nan

            min_distances[(type1, type2)] = min_dist
            interaction_matrix[i, j] = min_dist
            interaction_matrix[j, i] = min_dist

    # Plot
    plt.figure(figsize=(10, 8))
    mask = ~np.isnan(interaction_matrix)
    sns.heatmap(interaction_matrix,
                annot=True, fmt='.3f',
                xticklabels=unique_types,
                yticklabels=unique_types,
                cmap='coolwarm_r',
                mask=~mask,
                cbar_kws={'label': 'Minimum Distance'})
    plt.title(f'Minimum Distances Between Cell Types ({coords_key})')
    plt.tight_layout()
    plt.show()

    return min_distances, interaction_matrix

def visualize_advanced_results_multi_model(
        scadata,
        avg_key='advanced_diffusion_coords_avg',
        replicate_keys=('advanced_diffusion_coords_rep1',
                        'advanced_diffusion_coords_rep2',
                        'advanced_diffusion_coords_rep3')):
    """Plot averaged coordinates together with cross-model agreement diagnostics.

    Args:
        scadata: AnnData-like object. Reads ``obsm[avg_key]``, ``obsm[k]`` for each
            ``k`` in ``replicate_keys`` and ``obs['rough_celltype']``. Columns 0 and
            1 of the coordinate arrays are used as x and y.
        avg_key: obsm key of the averaged coordinates drawn in every scatter panel.
        replicate_keys: obsm keys of the per-model coordinates. Their spread
            across models defines the uncertainty. Defaults to the three
            per-replicate models.

    Side effects:
        Writes ``scadata.obs['spatial_confidence'] = 1 / (1 + total_std)``.

    Returns:
        (fig, total_std, confidence). ``total_std`` is, per cell, the Euclidean
        norm of the per-axis standard deviation across models.

    Raises:
        ValueError: if ``replicate_keys`` is empty.
    """
    from scipy.stats import gaussian_kde

    if len(replicate_keys) == 0:
        raise ValueError("replicate_keys must name at least one obsm entry")

    coords_avg = scadata.obsm[avg_key]

    # Uncertainty across models: stack to (n_models, n_cells, n_dims)
    all_coords = np.stack([scadata.obsm[key] for key in replicate_keys], axis=0)
    coords_std = np.std(all_coords, axis=0)  # per-cell, per-axis std across models
    total_std = np.sqrt(coords_std[:, 0]**2 + coords_std[:, 1]**2)

    # Confidence score: inverse of variability, in (0, 1]
    confidence = 1 / (1 + total_std)
    scadata.obs['spatial_confidence'] = confidence

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # 1. Spatial coordinates colored by cell type
    ax = axes[0, 0]
    cell_types = scadata.obs['rough_celltype']
    unique_types = cell_types.unique()
    colors = sns.color_palette('tab20', n_colors=len(unique_types))
    for i, ct in enumerate(unique_types):
        mask = cell_types == ct
        ax.scatter(coords_avg[mask, 0], coords_avg[mask, 1],
                   c=[colors[i]], label=ct, s=30, alpha=0.7)
    ax.set_title('Averaged Spatial Coordinates by Cell Type', fontsize=14)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)

    # 2. Model variability (standard deviation across models)
    ax = axes[0, 1]
    scatter = ax.scatter(coords_avg[:, 0], coords_avg[:, 1],
                         c=total_std, cmap='viridis_r',
                         s=30, alpha=0.7)
    plt.colorbar(scatter, ax=ax, label='Model Std Dev')
    ax.set_title('Model Prediction Variability', fontsize=14)

    # 3. X vs Y coordinate uncertainty
    ax = axes[0, 2]
    scatter = ax.scatter(coords_std[:, 0], coords_std[:, 1],
                         c=total_std, cmap='plasma', s=30, alpha=0.7)
    plt.colorbar(scatter, ax=ax, label='Total Std')
    ax.set_xlabel('X Coordinate Std')
    ax.set_ylabel('Y Coordinate Std')
    ax.set_title('Coordinate Uncertainty (X vs Y)', fontsize=14)

    # 4. Cell density (Gaussian KDE evaluated at each cell)
    ax = axes[1, 0]
    xy = coords_avg.T
    z = gaussian_kde(xy)(xy)
    scatter = ax.scatter(coords_avg[:, 0], coords_avg[:, 1],
                         c=z, cmap='hot', s=30, alpha=0.7)
    plt.colorbar(scatter, ax=ax, label='Density')
    ax.set_title('Cell Density (Averaged Coordinates)', fontsize=14)

    # 5. Confidence scores
    ax = axes[1, 1]
    scatter = ax.scatter(coords_avg[:, 0], coords_avg[:, 1],
                         c=confidence, cmap='RdYlGn', s=30, alpha=0.7)
    plt.colorbar(scatter, ax=ax, label='Confidence')
    ax.set_title('Prediction Confidence Across Models', fontsize=14)

    # 6. Model agreement: bottom quartile of std = high, top quartile = low
    ax = axes[1, 2]
    high_agreement = total_std < np.percentile(total_std, 25)
    low_agreement = total_std > np.percentile(total_std, 75)
    medium_agreement = ~(high_agreement | low_agreement)
    ax.scatter(coords_avg[high_agreement, 0], coords_avg[high_agreement, 1],
               c='green', s=20, alpha=0.7, label='High Agreement')
    ax.scatter(coords_avg[low_agreement, 0], coords_avg[low_agreement, 1],
               c='red', s=20, alpha=0.7, label='Low Agreement')
    ax.scatter(coords_avg[medium_agreement, 0],
               coords_avg[medium_agreement, 1],
               c='gray', s=10, alpha=0.5, label='Medium Agreement')
    ax.set_title('Model Agreement', fontsize=14)
    ax.legend()

    plt.tight_layout()
    plt.show()

    return fig, total_std, confidence

In [None]:
def train_individual_advanced_diffusion_models(scadata, stadata1, stadata2, stadata3):
    """
    Train one AdvancedHierarchicalDiffusion model per ST replicate and average their SC coordinates.

    For each ST dataset, SC and ST expression are restricted to the genes they share
    (sorted by name), a model is trained on that pair, and coordinates are sampled for
    every SC cell. The per-model coordinate arrays are then averaged element-wise.

    Args:
        scadata: AnnData of single-cell expression. Must contain obs['rough_celltype'],
            which is passed to every model as its SC cell-type labels.
        stadata1, stadata2, stadata3: AnnData ST replicates, each with obsm['spatial'].

    Returns:
        scadata: The same object, updated in place with
            obsm['advanced_diffusion_coords_avg'] (the element-wise average) and
            obsm['advanced_diffusion_coords_rep{1,2,3}'] (one entry per model).
        models_all: The trained models, in dataset order.

    Raises:
        AssertionError: If the models return coordinate arrays of different shapes.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Per-model results, in dataset order
    sc_coords_results = []
    models_all = []

    st_datasets = [
        (stadata1, "dataset1"),
        (stadata2, "dataset2"),
        (stadata3, "dataset3"),
    ]
    n_models = len(st_datasets)

    for i, (stadata, dataset_name) in enumerate(st_datasets):
        print(f"\n{'='*50}")
        print(f"Training AdvancedHierarchicalDiffusion model {i+1}/{n_models} for {dataset_name}")
        print(f"{'='*50}")

        # Genes measured in both modalities; sorting fixes the column order
        common_genes = sorted(set(scadata.var_names) & set(stadata.var_names))
        print(f"Common genes for {dataset_name}: {len(common_genes)}")

        sc_expr = scadata[:, common_genes].X
        st_expr = stadata[:, common_genes].X

        # Densify sparse matrices
        if hasattr(sc_expr, 'toarray'):
            sc_expr = sc_expr.toarray()
        if hasattr(st_expr, 'toarray'):
            st_expr = st_expr.toarray()

        st_coords = stadata.obsm['spatial']

        print(f"SC data shape: {sc_expr.shape}")
        print(f"ST data shape: {st_expr.shape}")
        print(f"ST coords shape: {st_coords.shape}")

        model = AdvancedHierarchicalDiffusion(
            st_gene_expr=st_expr,
            st_coords=st_coords,
            sc_gene_expr=sc_expr,
            cell_types_sc=scadata.obs['rough_celltype'].values,  # SC cell-type labels
            transport_plan=None,  # No OT transport plan
            D_st=None,            # No precomputed distance matrices
            D_induced=None,
            n_genes=len(common_genes),
            n_embedding=[512, 256, 128],  # Same as STEMDiffusion
            coord_space_diameter=2.00,
            sigma=3.0,
            alpha=0.8,
            mmdbatch=1000,
            batch_size=256,
            device=device,
            lr_e=0.0001,
            lr_d=0.0002,
            n_timesteps=600,     # Same as STEMDiffusion
            n_denoising_blocks=4,
            hidden_dim=256,      # Same as STEMDiffusion
            num_heads=8,
            num_hierarchical_scales=3,
            dp=0.2,
            outf=f'advanced_diffusion_{dataset_name}'
        )

        print(f"Training model for {dataset_name}...")

        # Three-stage Graph-VAE + latent diffusion pipeline
        model.train(
            encoder_epochs=1000,    # Stage 1: domain-alignment encoder
            vae_epochs=1000,        # Stage 2: Graph-VAE
            diffusion_epochs=2500,  # Stage 3: latent diffusion
            lambda_struct=2.0       # Structure loss weight
        )

        model.plot_training_losses()

        print(f"Generating SC coordinates using model {i+1}...")
        # Sample coordinates for all SC cells in batches of 512 to bound memory use
        sc_coords = model.sample_sc_coordinates_batched(
            batch_size=512
        )

        sc_coords_results.append(sc_coords)
        models_all.append(model)

        print(f"Model {i+1} complete! Generated coordinates shape: {sc_coords.shape}")

        # models_all still references the model, so its own tensors stay alive;
        # empty_cache only releases cached allocator blocks no longer in use.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"\nAveraging results from {len(sc_coords_results)} models...")

    # Check shapes BEFORE averaging: np.mean over ragged arrays raises (or builds an
    # object array) before a later assertion could report which shapes disagree.
    shapes = [coords.shape for coords in sc_coords_results]
    assert all(shape == shapes[0] for shape in shapes), f"Shape mismatch: {shapes}"

    sc_coords_avg = np.mean(sc_coords_results, axis=0)

    print(f"Final averaged coordinates shape: {sc_coords_avg.shape}")

    scadata.obsm['advanced_diffusion_coords_avg'] = sc_coords_avg

    # Also keep each model's coordinates for per-model comparison
    for i, coords in enumerate(sc_coords_results):
        scadata.obsm[f'advanced_diffusion_coords_rep{i+1}'] = coords

    print(f"\nAdvanced diffusion training complete!")
    print(f"Results saved in scadata.obsm['advanced_diffusion_coords_avg']")

    return scadata, models_all

# Load and process data
scadata, stadata1, stadata2, stadata3 = load_and_process_cscc_data()

# Train one AdvancedHierarchicalDiffusion model per ST replicate and average their predictions
scadata, advanced_models = train_individual_advanced_diffusion_models(
    scadata, stadata1, stadata2, stadata3
)

print("Advanced diffusion training complete! Results saved in scadata.obsm['advanced_diffusion_coords_avg']")

# Visualize results
import matplotlib.pyplot as plt
import seaborn as sns

my_tab20 = sns.color_palette("tab20", n_colors=20).as_hex()

# sc.pl.embedding opens its own figure when no `ax` is given, which would leave a
# preceding plt.figure() empty and ignore its figsize. Create the axes explicitly,
# pass it in, and let plt.show() do the display (show=False).

# Plot 1: Averaged coordinates
fig, ax = plt.subplots(figsize=(8, 6))
sc.pl.embedding(scadata, basis='advanced_diffusion_coords_avg', color='rough_celltype',
                size=85, title='SC Advanced Diffusion Coords (Averaged)',
                palette=my_tab20, legend_loc='right margin', legend_fontsize=10,
                ax=ax, show=False)
plt.show()

# Plot 2: Individual model results, one per trained replicate model
for i in range(len(advanced_models)):
    fig, ax = plt.subplots(figsize=(6, 5))
    sc.pl.embedding(scadata, basis=f'advanced_diffusion_coords_rep{i+1}', color='rough_celltype',
                    size=85, title=f'SC Coordinates (Advanced Model {i+1})',
                    palette=my_tab20, legend_loc='right margin', legend_fontsize=10,
                    ax=ax, show=False)
    plt.show()

In [None]:
# Inspect the averaged SC coordinates stored by the training cell
coords_avg_view = scadata.obsm['advanced_diffusion_coords_avg']
coords_avg_view

In [None]:
# Sanity-check the averaged coordinates rather than re-displaying the same array:
# NaN/inf here would silently corrupt every downstream plot and metric.
coords_avg_check = np.asarray(scadata.obsm['advanced_diffusion_coords_avg'])
assert np.isfinite(coords_avg_check).all(), "averaged coordinates contain NaN/inf"
coords_avg_check.shape

In [None]:
# Visualize results with separate plots
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns

mpl.rcParams['figure.figsize'] = (4, 4)

# Use one 20-colour tab20 palette for every panel. A shorter palette makes scanpy
# repeat colours once there are more categories than colours, and because scanpy
# stores the palette in adata.uns it would also recolour later plots.
my_tab20 = sns.color_palette("tab20", n_colors=20).as_hex()

embedding_panels = [
    ('advanced_diffusion_coords_avg', 'SC Spatial Coordinates (Averaged from 3 Models)'),
    ('advanced_diffusion_coords_rep1', 'SC Coordinates (Model 1)'),
    ('advanced_diffusion_coords_rep2', 'SC Coordinates (Model 2)'),
    ('advanced_diffusion_coords_rep3', 'SC Coordinates (Model 3)'),
]

for basis, title in embedding_panels:
    # Pass the axes explicitly: without `ax`, sc.pl.embedding opens its own figure,
    # leaving a preceding plt.figure() empty and its figsize ignored.
    fig, ax = plt.subplots(figsize=(4, 4))
    sc.pl.embedding(scadata, basis=basis, color='rough_celltype',
                    size=85, title=title,
                    palette=my_tab20, legend_loc='right margin', legend_fontsize=10,
                    ax=ax, show=False)
    plt.show()

In [None]:
# Downstream analysis of the coordinates stored by the training cell:
# scadata.obsm['advanced_diffusion_coords_avg'] and 'advanced_diffusion_coords_rep{1,2,3}'.

print("\n=== Advanced Analysis and Visualization ===")

mpl.rcParams['figure.figsize'] = (8, 6)

# Minimum distance below which an entry of min_distances_avg is flagged as a possible overlap
OVERLAP_THRESHOLD = 0.01

# 1. Visualize advanced results with uncertainty analysis
print("Creating advanced visualization plots...")
fig, model_uncertainty, confidence_scores = visualize_advanced_results_multi_model(scadata)

# 2. Analyze cell interactions for averaged coordinates
print("Analyzing cell interactions (averaged coordinates)...")
min_distances_avg, interaction_matrix_avg = analyze_cell_interactions_advanced(
    scadata, coords_key='advanced_diffusion_coords_avg'
)

# 3. Interactions for each individual model, kept per replicate number instead of
#    being overwritten on every iteration
print("Analyzing cell interactions (individual models)...")
interactions_per_model = {}
for i in range(1, 4):
    print(f"\nModel {i} interactions:")
    interactions_per_model[i] = analyze_cell_interactions_advanced(
        scadata, coords_key=f'advanced_diffusion_coords_rep{i}'
    )
# Also expose the last replicate's results under the loop's variable names
min_distances, interaction_matrix = interactions_per_model[i]

# 4. Print summary statistics
print("\n=== Advanced Model Statistics ===")
print(f"Total cells mapped: {len(scadata.obsm['advanced_diffusion_coords_avg'])}")
print(f"Average model uncertainty: {model_uncertainty.mean():.4f}")
print(f"Model uncertainty range: [{model_uncertainty.min():.4f}, {model_uncertainty.max():.4f}]")
print(f"Average confidence: {confidence_scores.mean():.4f}")

print("\n=== Cell Type Confidence ===")
celltype_labels = scadata.obs['rough_celltype']
for ct in celltype_labels.unique():
    # Plain boolean array so indexing does not depend on pandas/NumPy interop
    mask = (celltype_labels == ct).to_numpy()
    print(f"{ct}: {mask.sum()} cells, "
          f"avg confidence: {confidence_scores[mask].mean():.3f}, "
          f"avg uncertainty: {model_uncertainty[mask].mean():.4f}")

print("\n=== Physics Constraints (Averaged Coordinates) ===")
all_distances = []
for key, dist in min_distances_avg.items():
    if not np.isnan(dist):
        all_distances.append(dist)
        print(f"Min distance {key[0]} - {key[1]}: {dist:.4f}")

if all_distances:
    distances_arr = np.asarray(all_distances)
    print(f"\nOverall minimum cell-cell distance: {distances_arr.min():.4f}")
    # Each value is one entry of min_distances_avg (keyed by a pair, presumably of
    # cell types), so this counts pairs, not individual cells.
    print(f"Pairs with potential overlaps (min distance < {OVERLAP_THRESHOLD}): "
          f"{np.sum(distances_arr < OVERLAP_THRESHOLD)}")

In [None]:
# Visualize results
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
mpl.rcParams['figure.figsize'] = (8, 6)

my_tab20 = sns.color_palette("tab20", n_colors=20).as_hex()

# Pass explicit axes to sc.pl.embedding: without `ax` it opens its own figure,
# leaving a preceding plt.figure() empty and its figsize ignored.

# Plot 1: Averaged coordinates
fig, ax = plt.subplots(figsize=(8, 6))
sc.pl.embedding(scadata, basis='advanced_diffusion_coords_avg', color='rough_celltype',
                size=85, title='SC Advanced Diffusion Coords (Averaged)',
                palette=my_tab20, legend_loc='right margin', legend_fontsize=10,
                ax=ax, show=False)
plt.show()

# Plot 2: Individual model results
for i in range(3):
    fig, ax = plt.subplots(figsize=(6, 5))
    sc.pl.embedding(scadata, basis=f'advanced_diffusion_coords_rep{i+1}', color='rough_celltype',
                    size=85, title=f'SC Coordinates (Advanced Model {i+1})',
                    palette=my_tab20, legend_loc='right margin', legend_fontsize=10,
                    ax=ax, show=False)
    plt.show()

In [None]:
def normalize_coordinates_isotropic(coords):
    '''Center coordinates and scale them into the unit disc, preserving aspect ratio.

    The same scale factor is applied to every dimension, so shapes are not distorted.

    Args:
        coords: (n_points, n_dims) NumPy array or torch tensor.

    Returns:
        (coords_norm, center, max_radius): the normalized coordinates, the
        per-dimension mean that was subtracted, and the largest distance of any
        point from that mean. ``coords_norm * max_radius + center`` recovers the input.

    If every point coincides (max_radius == 0), the centered coordinates (all zeros)
    are returned instead of dividing by zero; max_radius is still reported as 0.
    '''
    if torch.is_tensor(coords):
        center = coords.mean(dim=0)
        centered = coords - center
        max_radius = torch.max(torch.norm(centered, dim=1))
    else:
        center = coords.mean(axis=0)
        centered = coords - center
        max_radius = np.max(np.linalg.norm(centered, axis=1))

    # Guard the degenerate single-location case, which would otherwise yield NaNs
    scale = max_radius if max_radius > 0 else 1
    coords_norm = centered / scale
    return coords_norm, center, max_radius

# Load and prepare data for validation
scadata_val, stadata1_val, stadata2_val, stadata3_val = load_and_process_cscc_data_individual_norm()

# Normalized ground-truth coordinates for the held-out replicate ST3
st3_coords_gt = stadata3_val.obsm['spatial']
st3_coords_gt_norm, _, _ = normalize_coordinates_isotropic(st3_coords_gt)

print("=== VALIDATION EXPERIMENT ===")
print("Training diffusion models on ST1+ST2, testing on ST3...")

# Training replicates (ST3 is held out)
st_datasets_train = [
    (stadata1_val, "dataset1"),
    (stadata2_val, "dataset2")
]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Train on the genes shared by SC and ALL three ST replicates. The held-out test
# models are built with n_genes equal to this shared set and receive the trained
# weights through load_state_dict, which (for a torch.nn.Module) raises on tensor-size
# mismatches even with strict=False. Per-dataset gene sets (SC ∩ ST_i) would break
# that transfer whenever they differ from the shared set — assuming, as the
# constructor signature suggests, that n_genes sizes the model's input layers.
shared_genes = sorted(
    set(scadata_val.var_names)
    & set(stadata1_val.var_names)
    & set(stadata2_val.var_names)
    & set(stadata3_val.var_names)
)

# SC expression is the same for every training replicate, so extract it once
sc_expr = scadata_val[:, shared_genes].X
if hasattr(sc_expr, 'toarray'):
    sc_expr = sc_expr.toarray()

# Labels come from scadata_val, the same AnnData the SC expression is taken from
cell_types_sc = scadata_val.obs['rough_celltype'].values
unique_cell_types = np.unique(cell_types_sc)
print(f"Found {len(unique_cell_types)} unique cell types: {unique_cell_types}")

trained_models = []

for i, (stadata, dataset_name) in enumerate(st_datasets_train):
    print(f"\n{'='*50}")
    print(f"Training Advanced Diffusion model {i+1}/{len(st_datasets_train)} for {dataset_name}")
    print(f"{'='*50}")

    common_genes = shared_genes
    print(f"Common genes for {dataset_name}: {len(common_genes)}")

    st_expr = stadata[:, common_genes].X
    if hasattr(st_expr, 'toarray'):
        st_expr = st_expr.toarray()

    st_coords = stadata.obsm['spatial']

    print(f"SC data shape: {sc_expr.shape}")
    print(f"ST data shape: {st_expr.shape}")
    print(f"ST coords shape: {st_coords.shape}")

    print(f"Initializing AdvancedHierarchicalDiffusion for {dataset_name}...")

    # Dedicated output location so validation runs do not reuse the
    # 'advanced_diffusion_{dataset_name}' outf of the full training run
    # (presumably outf is where the model writes its outputs).
    output_dir = f'./cscc_advanced_diffusion_{dataset_name}_validation'

    model = AdvancedHierarchicalDiffusion(
        st_gene_expr=st_expr,
        st_coords=st_coords,
        sc_gene_expr=sc_expr,
        cell_types_sc=cell_types_sc,  # SC cell-type labels from scadata_val
        transport_plan=None,  # No OT transport plan
        D_st=None,            # No precomputed distance matrices
        D_induced=None,
        n_genes=len(common_genes),
        n_embedding=[512, 256, 128],  # Same as STEMDiffusion
        coord_space_diameter=2.00,
        sigma=3.0,
        alpha=0.8,
        mmdbatch=1000,
        batch_size=256,
        device=device,
        lr_e=0.0001,
        lr_d=0.0002,
        n_timesteps=800,
        n_denoising_blocks=6,
        hidden_dim=256,      # Same as STEMDiffusion
        num_heads=8,
        num_hierarchical_scales=3,
        dp=0.2,
        outf=output_dir
    )

    print(f"Training model for {dataset_name}...")
    model.train(
        encoder_epochs=1000,    # Stage 1: domain-alignment encoder
        vae_epochs=1000,        # Stage 2: Graph-VAE
        diffusion_epochs=2500,  # Stage 3: latent diffusion
        lambda_struct=5.0       # Structure loss weight
    )

    trained_models.append(model)
    print(f"Model {i+1} training completed!")

print(f"\nTraining completed! {len(trained_models)} models trained.")

# Test on the held-out ST3 replicate: build fresh model instances that receive the trained weights
print(f"\n{'='*50}")
print("Testing on ST3 dataset...")
print(f"{'='*50}")

# Genes shared by SC and all three ST replicates (sorted so the column order is fixed)
common_genes_test = sorted(
    set(scadata_val.var_names)
    & set(stadata1_val.var_names)
    & set(stadata2_val.var_names)
    & set(stadata3_val.var_names)
)
print(f"Common genes for testing: {len(common_genes_test)}")

# ST3 expression is fed in the "SC" slot: these are the profiles to be placed
st3_expr = stadata3_val[:, common_genes_test].X
if hasattr(st3_expr, 'toarray'):
    st3_expr = st3_expr.toarray()

# The constructor needs an ST reference with coordinates; ST1 provides it
st1_coords = stadata1_val.obsm['spatial']
st1_expr_ref = stadata1_val[:, common_genes_test].X
if hasattr(st1_expr_ref, 'toarray'):
    st1_expr_ref = st1_expr_ref.toarray()

# float32 tensors (later cells rebuild the test models from these names)
X_st1_ref = torch.tensor(st1_expr_ref, dtype=torch.float32).to(device)
X_st3_test = torch.tensor(st3_expr, dtype=torch.float32).to(device)
Y_st1_ref = torch.tensor(st1_coords, dtype=torch.float32).to(device)

# float32 NumPy views for the constructor, converted once rather than per model
st1_expr_ref_f32 = X_st1_ref.cpu().numpy()
st1_coords_f32 = Y_st1_ref.cpu().numpy()
st3_expr_f32 = X_st3_test.cpu().numpy()

# Cell-type vocabulary of the SC data the models were trained with
unique_cell_types = np.unique(scadata_val.obs['rough_celltype'].values)
num_original_cell_types = len(unique_cell_types)
print(f"Original model has {num_original_cell_types} cell types")

# ST3 spots carry no cell-type labels, so placeholder labels are drawn from the SC
# vocabulary. A fixed seed makes them — and the predictions, if the model uses
# labels during sampling — reproducible across runs.
DUMMY_LABEL_SEED = 0
dummy_cell_types = np.random.default_rng(DUMMY_LABEL_SEED).choice(
    unique_cell_types, size=X_st3_test.shape[0]
)

all_predictions = []

for i, trained_model in enumerate(trained_models):
    print(f"Creating test model based on trained model {i+1}...")

    # Fresh, untrained instance sized for the shared gene set; weights are copied below
    test_model = AdvancedHierarchicalDiffusion(
        st_gene_expr=st1_expr_ref_f32,   # ST1 expression as the spatial reference
        st_coords=st1_coords_f32,        # ST1 coordinates as the spatial reference
        sc_gene_expr=st3_expr_f32,       # ST3 spots treated as the "SC" cells to place
        cell_types_sc=dummy_cell_types,  # placeholder labels (ST3 spots have none)
        transport_plan=None,             # No OT transport plan
        D_st=None,                       # No precomputed distance matrices
        D_induced=None,
        n_genes=len(common_genes_test),
        n_embedding=[512, 256, 128],
        coord_space_diameter=2.00,
        sigma=3.0,
        alpha=0.8,
        mmdbatch=1000,
        batch_size=256,
        device=device,
        lr_e=0.0001,
        lr_d=0.0002,
        n_timesteps=800,
        n_denoising_blocks=6,
        hidden_dim=256,
        num_heads=8,
        num_hierarchical_scales=3,
        dp=0.2,
        outf=f'./temp_test_model_{i}'
    )

    # Copy the trained weights. strict=False tolerates keys present in only one of
    # the two models but would hide them, so report any that were skipped.
    load_result = test_model.load_state_dict(trained_model.state_dict(), strict=False)
    missing_keys = getattr(load_result, 'missing_keys', [])
    unexpected_keys = getattr(load_result, 'unexpected_keys', [])
    if missing_keys or unexpected_keys:
        print(f"  Weight transfer: {len(missing_keys)} missing keys, "
              f"{len(unexpected_keys)} unexpected keys")

    print(f"Generating predictions from test model {i+1}...")
    # Sample in batches of 512 to bound memory use
    predicted_coords = test_model.sample_sc_coordinates_batched(
        batch_size=512
    )
    all_predictions.append(predicted_coords)

# Evaluate the ensemble prediction for ST3 against the normalized ground truth.
# NOTE(review): coordinates are compared without any rigid alignment, so a rotated or
# mirrored but otherwise correct layout scores poorly on MSE and on the per-axis
# correlations; consider a Procrustes-style alignment before scoring.
n_models = len(all_predictions)

# Average predictions element-wise across models
predicted_coords_avg = np.mean(all_predictions, axis=0)
print(f"Predicted coordinates shape: {predicted_coords_avg.shape}")
print(f"Ground truth coordinates shape: {st3_coords_gt_norm.shape}")
assert predicted_coords_avg.shape == st3_coords_gt_norm.shape, (
    f"prediction shape {predicted_coords_avg.shape} != ground truth shape {st3_coords_gt_norm.shape}"
)

from sklearn.metrics import mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr

# MSE and MAE, averaged over both coordinate dimensions
mse = mean_squared_error(st3_coords_gt_norm, predicted_coords_avg)
mae = mean_absolute_error(st3_coords_gt_norm, predicted_coords_avg)

# Pearson correlation for each dimension separately
corr_x, p_x = pearsonr(st3_coords_gt_norm[:, 0], predicted_coords_avg[:, 0])
corr_y, p_y = pearsonr(st3_coords_gt_norm[:, 1], predicted_coords_avg[:, 1])

print("=== VALIDATION RESULTS ===")
print(f"Mean Squared Error: {mse:.6f}")
print(f"Mean Absolute Error: {mae:.6f}")
print(f"Correlation X-dimension: {corr_x:.4f} (p={p_x:.6f})")
print(f"Correlation Y-dimension: {corr_y:.4f} (p={p_y:.6f})")

# Visualization
import matplotlib.pyplot as plt

spot_index = np.arange(len(st3_coords_gt_norm))
gt_lo, gt_hi = st3_coords_gt_norm.min(), st3_coords_gt_norm.max()
fig, (ax_gt, ax_pred, ax_corr) = plt.subplots(1, 3, figsize=(15, 5))

# Panel 1: ground truth, coloured by spot index
gt_points = ax_gt.scatter(st3_coords_gt_norm[:, 0], st3_coords_gt_norm[:, 1],
                          c=spot_index, cmap='viridis', alpha=0.6, s=20)
ax_gt.set(title='Ground Truth ST3 Coordinates\n(Isotropic Normalized)',
          xlabel='X coordinate', ylabel='Y coordinate')
fig.colorbar(gt_points, ax=ax_gt, label='Spot index')

# Panel 2: ensemble prediction, same colouring so spots can be matched by eye
pred_points = ax_pred.scatter(predicted_coords_avg[:, 0], predicted_coords_avg[:, 1],
                              c=spot_index, cmap='viridis', alpha=0.6, s=20)
ax_pred.set(title=f'Predicted ST3 Coordinates\n(Averaged from {n_models} Models)',
            xlabel='X coordinate', ylabel='Y coordinate')
fig.colorbar(pred_points, ax=ax_pred, label='Spot index')

# Panel 3: predicted vs true value per axis, with the y = x reference line
ax_corr.scatter(st3_coords_gt_norm[:, 0], predicted_coords_avg[:, 0],
                alpha=0.5, label=f'X-coord (r={corr_x:.3f})', s=15)
ax_corr.scatter(st3_coords_gt_norm[:, 1], predicted_coords_avg[:, 1],
                alpha=0.5, label=f'Y-coord (r={corr_y:.3f})', s=15)
ax_corr.plot([gt_lo, gt_hi], [gt_lo, gt_hi], 'r--', alpha=0.8, label='Perfect correlation')
ax_corr.set(xlabel='Ground Truth Coordinates', ylabel='Predicted Coordinates',
            title='Prediction vs Ground Truth')
ax_corr.legend()

plt.tight_layout()
plt.show()

# Per-spot Euclidean error
euclidean_distances = np.linalg.norm(st3_coords_gt_norm - predicted_coords_avg, axis=1)
median_distance = np.median(euclidean_distances)
mean_distance = np.mean(euclidean_distances)

print(f"\nDistance-based metrics:")
print(f"Mean Euclidean distance: {mean_distance:.6f}")
print(f"Median Euclidean distance: {median_distance:.6f}")
print(f"Max Euclidean distance: {np.max(euclidean_distances):.6f}")
print(f"Min Euclidean distance: {np.min(euclidean_distances):.6f}")

print("=== VALIDATION EXPERIMENT COMPLETED ===")

In [None]:
# Re-run the held-out ST3 predictions from the trained models.
# Uses names defined by the validation cell: trained_models, X_st1_ref, Y_st1_ref,
# X_st3_test, dummy_cell_types, common_genes_test and device.

# float32 NumPy copies for the constructor, converted once rather than per model
st1_expr_ref_f32 = X_st1_ref.cpu().numpy()
st1_coords_f32 = Y_st1_ref.cpu().numpy()
st3_expr_f32 = X_st3_test.cpu().numpy()

all_predictions = []
for i, trained_model in enumerate(trained_models):
    print(f"Creating test model based on trained model {i+1}...")

    # Fresh, untrained instance sized for the shared gene set; weights are copied below
    test_model = AdvancedHierarchicalDiffusion(
        st_gene_expr=st1_expr_ref_f32,   # ST1 expression as the spatial reference
        st_coords=st1_coords_f32,        # ST1 coordinates as the spatial reference
        sc_gene_expr=st3_expr_f32,       # ST3 spots treated as the "SC" cells to place
        cell_types_sc=dummy_cell_types,  # placeholder labels (ST3 spots have none)
        transport_plan=None,             # No OT transport plan
        D_st=None,                       # No precomputed distance matrices
        D_induced=None,
        n_genes=len(common_genes_test),
        n_embedding=[512, 256, 128],
        coord_space_diameter=2.00,
        sigma=3.0,
        alpha=0.8,
        mmdbatch=1000,
        batch_size=256,
        device=device,
        lr_e=0.0001,
        lr_d=0.0002,
        n_timesteps=800,
        n_denoising_blocks=6,
        hidden_dim=256,
        num_heads=8,
        num_hierarchical_scales=3,
        dp=0.2,
        outf=f'./temp_test_model_{i}'
    )

    # Copy the trained weights. strict=False tolerates keys present in only one of
    # the two models but would hide them, so report any that were skipped.
    load_result = test_model.load_state_dict(trained_model.state_dict(), strict=False)
    missing_keys = getattr(load_result, 'missing_keys', [])
    unexpected_keys = getattr(load_result, 'unexpected_keys', [])
    if missing_keys or unexpected_keys:
        print(f"  Weight transfer: {len(missing_keys)} missing keys, "
              f"{len(unexpected_keys)} unexpected keys")

    print(f"Generating predictions from test model {i+1}...")
    # Sample in batches of 512 to bound memory use
    predicted_coords = test_model.sample_sc_coordinates_batched(
        batch_size=512
    )
    all_predictions.append(predicted_coords)

# Evaluate the ensemble prediction for ST3 against the normalized ground truth.
# NOTE(review): coordinates are compared without any rigid alignment, so a rotated or
# mirrored but otherwise correct layout scores poorly on MSE and on the per-axis
# correlations; consider a Procrustes-style alignment before scoring.
n_models = len(all_predictions)

# Average predictions element-wise across models
predicted_coords_avg = np.mean(all_predictions, axis=0)
print(f"Predicted coordinates shape: {predicted_coords_avg.shape}")
print(f"Ground truth coordinates shape: {st3_coords_gt_norm.shape}")
assert predicted_coords_avg.shape == st3_coords_gt_norm.shape, (
    f"prediction shape {predicted_coords_avg.shape} != ground truth shape {st3_coords_gt_norm.shape}"
)

from sklearn.metrics import mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr

# MSE and MAE, averaged over both coordinate dimensions
mse = mean_squared_error(st3_coords_gt_norm, predicted_coords_avg)
mae = mean_absolute_error(st3_coords_gt_norm, predicted_coords_avg)

# Pearson correlation for each dimension separately
corr_x, p_x = pearsonr(st3_coords_gt_norm[:, 0], predicted_coords_avg[:, 0])
corr_y, p_y = pearsonr(st3_coords_gt_norm[:, 1], predicted_coords_avg[:, 1])

print("=== VALIDATION RESULTS ===")
print(f"Mean Squared Error: {mse:.6f}")
print(f"Mean Absolute Error: {mae:.6f}")
print(f"Correlation X-dimension: {corr_x:.4f} (p={p_x:.6f})")
print(f"Correlation Y-dimension: {corr_y:.4f} (p={p_y:.6f})")

# Visualization
import matplotlib.pyplot as plt

spot_index = np.arange(len(st3_coords_gt_norm))
gt_lo, gt_hi = st3_coords_gt_norm.min(), st3_coords_gt_norm.max()
fig, (ax_gt, ax_pred, ax_corr) = plt.subplots(1, 3, figsize=(15, 5))

# Panel 1: ground truth, coloured by spot index
gt_points = ax_gt.scatter(st3_coords_gt_norm[:, 0], st3_coords_gt_norm[:, 1],
                          c=spot_index, cmap='viridis', alpha=0.6, s=20)
ax_gt.set(title='Ground Truth ST3 Coordinates\n(Isotropic Normalized)',
          xlabel='X coordinate', ylabel='Y coordinate')
fig.colorbar(gt_points, ax=ax_gt, label='Spot index')

# Panel 2: ensemble prediction, same colouring so spots can be matched by eye
pred_points = ax_pred.scatter(predicted_coords_avg[:, 0], predicted_coords_avg[:, 1],
                              c=spot_index, cmap='viridis', alpha=0.6, s=20)
ax_pred.set(title=f'Predicted ST3 Coordinates\n(Averaged from {n_models} Models)',
            xlabel='X coordinate', ylabel='Y coordinate')
fig.colorbar(pred_points, ax=ax_pred, label='Spot index')

# Panel 3: predicted vs true value per axis, with the y = x reference line
ax_corr.scatter(st3_coords_gt_norm[:, 0], predicted_coords_avg[:, 0],
                alpha=0.5, label=f'X-coord (r={corr_x:.3f})', s=15)
ax_corr.scatter(st3_coords_gt_norm[:, 1], predicted_coords_avg[:, 1],
                alpha=0.5, label=f'Y-coord (r={corr_y:.3f})', s=15)
ax_corr.plot([gt_lo, gt_hi], [gt_lo, gt_hi], 'r--', alpha=0.8, label='Perfect correlation')
ax_corr.set(xlabel='Ground Truth Coordinates', ylabel='Predicted Coordinates',
            title='Prediction vs Ground Truth')
ax_corr.legend()

plt.tight_layout()
plt.show()

# Per-spot Euclidean error
euclidean_distances = np.linalg.norm(st3_coords_gt_norm - predicted_coords_avg, axis=1)
median_distance = np.median(euclidean_distances)
mean_distance = np.mean(euclidean_distances)

print(f"\nDistance-based metrics:")
print(f"Mean Euclidean distance: {mean_distance:.6f}")
print(f"Median Euclidean distance: {median_distance:.6f}")
print(f"Max Euclidean distance: {np.max(euclidean_distances):.6f}")
print(f"Min Euclidean distance: {np.min(euclidean_distances):.6f}")

print("=== VALIDATION EXPERIMENT COMPLETED ===")

In [None]:
# Score every model on its own, then the ensemble average, against the ST3 ground truth.
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr

predicted_coords_avg = np.mean(all_predictions, axis=0)
print(f"Predicted coordinates shape: {predicted_coords_avg.shape}")
print(f"Ground truth coordinates shape: {st3_coords_gt_norm.shape}")

gt_x = st3_coords_gt_norm[:, 0]
gt_y = st3_coords_gt_norm[:, 1]

for model_no, pred in enumerate(all_predictions, start=1):
    mse_i = mean_squared_error(st3_coords_gt_norm, pred)
    mae_i = mean_absolute_error(st3_coords_gt_norm, pred)
    corr_x_i, p_x_i = pearsonr(gt_x, pred[:, 0])
    corr_y_i, p_y_i = pearsonr(gt_y, pred[:, 1])

    print(f"\n=== MODEL {model_no} RESULTS ===")
    print(f"MSE: {mse_i:.6f}, MAE: {mae_i:.6f}")
    print(f"Corr X: {corr_x_i:.4f}, Corr Y: {corr_y_i:.4f}")

# Ensemble (element-wise mean over models)
mse = mean_squared_error(st3_coords_gt_norm, predicted_coords_avg)
mae = mean_absolute_error(st3_coords_gt_norm, predicted_coords_avg)
corr_x, p_x = pearsonr(gt_x, predicted_coords_avg[:, 0])
corr_y, p_y = pearsonr(gt_y, predicted_coords_avg[:, 1])

print("\n=== AVERAGED RESULTS ===")
print(f"MSE: {mse:.6f}, MAE: {mae:.6f}")
print(f"Corr X: {corr_x:.4f}, Corr Y: {corr_y:.4f}")

In [None]:
# Number of per-model prediction arrays collected for ST3
n_models = len(all_predictions)
n_models

In [None]:
# Visualization - ground truth, each individual model, and the ensemble average.
# Panels scale with len(all_predictions), so any number of models is supported.
import matplotlib.pyplot as plt
from scipy.stats import pearsonr

# Recomputed here so this cell does not depend on values left by other cells
predicted_coords_avg = np.mean(all_predictions, axis=0)
gt_lo, gt_hi = st3_coords_gt_norm.min(), st3_coords_gt_norm.max()

# --- Coordinate maps: ground truth, one panel per model, then the average ---
coord_panels = (
    [('Ground Truth ST3 Coordinates\n(Isotropic Normalized)', st3_coords_gt_norm)]
    + [(f'Model {k} Predictions', pred) for k, pred in enumerate(all_predictions, start=1)]
    + [('Averaged Predictions', predicted_coords_avg)]
)
fig, axes = plt.subplots(1, len(coord_panels), figsize=(5 * len(coord_panels), 5), squeeze=False)
for ax, (title, coords) in zip(axes[0], coord_panels):
    points = ax.scatter(coords[:, 0], coords[:, 1],
                        c=np.arange(len(coords)), cmap='viridis', alpha=0.6, s=20)
    ax.set(title=title, xlabel='X coordinate', ylabel='Y coordinate')
    fig.colorbar(points, ax=ax, label='Spot index')
plt.tight_layout()
plt.show()

# --- Predicted vs true value per axis, for each model and for the average ---
corr_panels = (
    [(f'Model {k} vs Ground Truth', pred) for k, pred in enumerate(all_predictions, start=1)]
    + [('Averaged Predictions vs Ground Truth', predicted_coords_avg)]
)
fig, axes = plt.subplots(1, len(corr_panels), figsize=(5 * len(corr_panels), 5), squeeze=False)
for ax, (title, pred) in zip(axes[0], corr_panels):
    r_x, _ = pearsonr(st3_coords_gt_norm[:, 0], pred[:, 0])
    r_y, _ = pearsonr(st3_coords_gt_norm[:, 1], pred[:, 1])
    ax.scatter(st3_coords_gt_norm[:, 0], pred[:, 0],
               alpha=0.5, label=f'X-coord (r={r_x:.3f})', s=15)
    ax.scatter(st3_coords_gt_norm[:, 1], pred[:, 1],
               alpha=0.5, label=f'Y-coord (r={r_y:.3f})', s=15)
    ax.plot([gt_lo, gt_hi], [gt_lo, gt_hi], 'r--', alpha=0.8, label='Perfect correlation')
    ax.set(xlabel='Ground Truth Coordinates', ylabel='Predicted Coordinates', title=title)
    ax.legend()
plt.tight_layout()
plt.show()

# --- Per-spot Euclidean error distributions ---
euclidean_distances_all = [
    np.linalg.norm(st3_coords_gt_norm - pred, axis=1) for pred in all_predictions
]
euclidean_distances_avg = np.linalg.norm(st3_coords_gt_norm - predicted_coords_avg, axis=1)

hist_panels = (
    [(f'Model {k} Error Distribution', d) for k, d in enumerate(euclidean_distances_all, start=1)]
    + [('Averaged Model Error Distribution', euclidean_distances_avg)]
)
fig, axes = plt.subplots(1, len(hist_panels), figsize=(5 * len(hist_panels), 4), squeeze=False)
for ax, (title, distances) in zip(axes[0], hist_panels):
    ax.hist(distances, bins=50, alpha=0.7, edgecolor='black')
    ax.set(xlabel='Euclidean Distance', ylabel='Frequency',
           title=f'{title}\nMean: {np.mean(distances):.4f}')
plt.tight_layout()
plt.show()

# Print distance metrics for all models
for i, distances in enumerate(euclidean_distances_all):
    print(f"\nModel {i+1} distance metrics:")
    print(f"Mean: {np.mean(distances):.6f}, Median: {np.median(distances):.6f}")
    print(f"Max: {np.max(distances):.6f}, Min: {np.min(distances):.6f}")

print(f"\nAveraged model distance metrics:")
print(f"Mean: {np.mean(euclidean_distances_avg):.6f}, Median: {np.median(euclidean_distances_avg):.6f}")
print(f"Max: {np.max(euclidean_distances_avg):.6f}, Min: {np.min(euclidean_distances_avg):.6f}")

In [None]:
mpl.rcParams.update({'figure.figsize': (4, 4)})

# All rough cell types on the averaged reconstructed coordinates
sc.pl.spatial(
    scadata,
    basis='advanced_diffusion_coords_avg',
    color="rough_celltype",
    spot_size=0.04,
    title='reconstructed',
    show=True,
)


In [None]:
mpl.rcParams.update({'figure.figsize': (4, 4)})
# Highlight only the PDC group of level2_celltype; other cells are drawn as NA without a legend entry
sc.pl.spatial(
    scadata,
    basis='advanced_diffusion_coords_avg',
    color="level2_celltype",
    groups=["PDC"],
    spot_size=0.04,
    title='reconstructed',
    na_in_legend=False,
    show=True,
)

In [None]:
# Highlight only the TSK group of level3_celltype on the averaged coordinates.
# To write the figure to disk, add save='TSK' to the call.
mpl.rcParams.update({'figure.figsize': (4, 4)})
sc.pl.spatial(
    scadata,
    basis='advanced_diffusion_coords_avg',
    color="level3_celltype",
    groups=["TSK"],
    spot_size=0.04,
    title='reconstructed',
    na_in_legend=False,
    show=True,
)

In [None]:
mpl.rcParams['figure.figsize'] = (4, 4)

# One figure per non-TSK tumour keratinocyte subtype; the tag is passed as scanpy's `save` argument.
for kc_subtype, save_tag in [
    ("Tumor_KC_Cyc", 'P2cyc'),
    ("Tumor_KC_Basal", 'P2bas'),
    ("Tumor_KC_Diff", 'P2diff'),
]:
    sc.pl.spatial(
        scadata,
        color="level2_celltype",
        groups=[kc_subtype],
        basis='advanced_diffusion_coords_avg',
        spot_size=0.04,
        title='reconstructed',
        na_in_legend=False,
        save=save_tag,
        show=True,
    )

In [None]:
import squidpy as sq

# Neighbourhood statistics on the averaged reconstructed coordinates, all coarse cell types.
sq.gr.spatial_neighbors(scadata, spatial_key='advanced_diffusion_coords_avg')
sq.gr.nhood_enrichment(scadata, cluster_key='rough_celltype')
sq.gr.interaction_matrix(scadata, cluster_key='rough_celltype')

# Same enrichment restricted to tumour keratinocyte subtypes (including TSK).
kscadata = scadata[
    scadata.obs['level2_celltype'].isin(['Tumor_KC_Cyc', 'Tumor_KC_Basal', 'Tumor_KC_Diff', 'TSK'])
].copy()
sq.gr.spatial_neighbors(kscadata, spatial_key='advanced_diffusion_coords_avg')
sq.gr.nhood_enrichment(kscadata, cluster_key='level2_celltype')

# To export, add save='TSKKC_new_good.png' to this call.
sq.pl.nhood_enrichment(
    kscadata,
    cluster_key="level2_celltype",
    cmap='coolwarm',
    figsize=(3, 5),
)


# Patient 10 (P10) analysis

In [None]:
def load_and_process_cscc_data_p10(data_dir='/home/ehtesamul/sc_st/data/cSCC/processed'):
    """
    Load and preprocess the patient-10 cSCC data: one scRNA-seq dataset and three ST replicates.

    Each AnnData is normalised with ``sc.pp.normalize_total`` and log-transformed with
    ``sc.pp.log1p`` (both in place). The SC object also gets ``obs['rough_celltype']``:
    ``level1_celltype`` converted to strings, with several level-1 labels merged into coarser
    groups. Tumour keratinocytes are then relabelled from ``level2_celltype``: 'TSK' becomes
    'TSK', and 'Tumor_KC_Basal'/'Tumor_KC_Diff'/'Tumor_KC_Cyc' become 'NonTSK'.

    Args:
        data_dir: directory holding ``scP10.h5ad`` and ``stP10rep1.h5ad`` .. ``stP10rep3.h5ad``.
            Defaults to the previously hard-coded location.

    Returns:
        (scadata_p10, stadata1_p10, stadata2_p10, stadata3_p10)
    """
    print("Loading cSCC data...")

    # Load the SC data and the three ST replicates
    scadata_p10 = sc.read_h5ad(os.path.join(data_dir, 'scP10.h5ad'))
    stadata1_p10 = sc.read_h5ad(os.path.join(data_dir, 'stP10rep1.h5ad'))
    stadata2_p10 = sc.read_h5ad(os.path.join(data_dir, 'stP10rep2.h5ad'))
    stadata3_p10 = sc.read_h5ad(os.path.join(data_dir, 'stP10rep3.h5ad'))

    # Normalize and log transform (in place)
    for adata in (scadata_p10, stadata1_p10, stadata2_p10, stadata3_p10):
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)

    # level1 label -> coarse label; level1 labels not listed here are kept unchanged
    level1_to_rough = {
        'CLEC9A': 'DC',
        'CD1C': 'DC',
        'ASDC': 'DC',
        'PDC': 'PDC',
        'MDSC': 'DC',
        'LC': 'DC',
        'Mac': 'Myeloid cell',
        'Tcell': 'T cell',
    }
    scadata_p10.obs['rough_celltype'] = scadata_p10.obs['level1_celltype'].astype(str)
    for level1_label, rough_label in level1_to_rough.items():
        scadata_p10.obs.loc[scadata_p10.obs['level1_celltype'] == level1_label, 'rough_celltype'] = rough_label

    # level2 tumour-keratinocyte annotations are applied last, so they override the level1-derived label
    scadata_p10.obs.loc[scadata_p10.obs['level2_celltype'] == 'TSK', 'rough_celltype'] = 'TSK'
    scadata_p10.obs.loc[
        scadata_p10.obs['level2_celltype'].isin(['Tumor_KC_Basal', 'Tumor_KC_Diff', 'Tumor_KC_Cyc']),
        'rough_celltype',
    ] = 'NonTSK'

    return scadata_p10, stadata1_p10, stadata2_p10, stadata3_p10

def prepare_combined_st_for_diffusion(stadata1, stadata2, stadata3, scadata):
    """
    Build gene-aligned model inputs from one SC dataset and three ST replicates.

    Genes are restricted to those present in all four datasets (sorted by name). The three ST
    expression matrices are stacked row-wise (replicate 1, 2, 3) and standardised together.
    The SC matrix is standardised with its own fitted statistics. ST spatial coordinates are
    stacked in the same row order.

    Returns:
        X_sc: float32 torch tensor of standardised SC expression.
        X_st_combined: float32 torch tensor of the standardised, stacked ST expression.
        Y_st_combined: float32 numpy array of the stacked ST coordinates.
        dataset_labels: list with 'dataset1'/'dataset2'/'dataset3' for each stacked ST row.
        common_genes: sorted list of the shared gene names.
        st_coords_list: the three per-replicate ``obsm['spatial']`` arrays (not stacked).
    """
    print("Preparing combined ST data for diffusion training...")
    st_datasets = (stadata1, stadata2, stadata3)

    # Genes shared by the SC data and every ST replicate
    shared_genes = set(scadata.var_names)
    for st in st_datasets:
        shared_genes &= set(st.var_names)
    common_genes = sorted(shared_genes)
    print(f"Common genes across all datasets: {len(common_genes)}")

    def _dense_expr(adata):
        # Gene-aligned expression matrix, densified when stored sparse
        mat = adata[:, common_genes].X
        return mat.toarray() if hasattr(mat, 'toarray') else mat

    sc_expr = _dense_expr(scadata)
    st_exprs = [_dense_expr(st) for st in st_datasets]
    st_coords_list = [st.obsm['spatial'] for st in st_datasets]

    from sklearn.preprocessing import StandardScaler

    # The ST replicates share one scaler fitted on the stacked matrix; SC is fitted separately.
    st_expr_combined = StandardScaler().fit_transform(np.vstack(st_exprs))
    st_coords_combined = np.vstack(st_coords_list)
    sc_expr = StandardScaler().fit_transform(sc_expr)

    # One replicate label per stacked ST row
    dataset_labels = [
        f'dataset{rep}'
        for rep, expr in enumerate(st_exprs, start=1)
        for _ in range(len(expr))
    ]

    print(f"Combined ST data shape: {st_expr_combined.shape}")
    print(f"Combined ST coords shape: {st_coords_combined.shape}")
    print(f"SC data shape: {sc_expr.shape}")

    X_sc = torch.tensor(sc_expr, dtype=torch.float32)
    X_st_combined = torch.tensor(st_expr_combined, dtype=torch.float32)
    Y_st_combined = st_coords_combined.astype(np.float32)

    return X_sc, X_st_combined, Y_st_combined, dataset_labels, common_genes, st_coords_list

# Load the P10 SC data and ST replicates, then build the gene-aligned diffusion inputs.
scadata_p10, stadata1_p10, stadata2_p10, stadata3_p10 = load_and_process_cscc_data_p10()

(X_sc, X_st_combined, Y_st_combined,
 dataset_labels, common_genes, st_coords_list) = prepare_combined_st_for_diffusion(
    stadata1_p10, stadata2_p10, stadata3_p10, scadata_p10
)

print("Data preparation complete!")
print(f"SC cells: {X_sc.shape[0]}")
print(f"Combined ST spots: {X_st_combined.shape[0]}")
print(f"Common genes: {len(common_genes)}")



In [None]:
def train_individual_advanced_diffusion_models(scadata, stadata1, stadata2, stadata3):
    """
    Train one AdvancedHierarchicalDiffusion model per ST replicate and average their SC predictions.

    For each ST dataset, genes are restricted to those shared with ``scadata`` (sorted by name).
    A model is trained on that ST dataset and then sampled to produce coordinates for the SC cells.
    All per-model coordinate arrays must have the same shape. Their element-wise mean is stored in
    ``scadata.obsm['advanced_diffusion_coords_avg']``, and each individual result in
    ``scadata.obsm['advanced_diffusion_coords_rep{k}']`` (k = 1, 2, 3).

    Args:
        scadata: SC AnnData; must contain ``obs['rough_celltype']``.
        stadata1, stadata2, stadata3: ST AnnData replicates with ``obsm['spatial']``.

    Returns:
        scadata: the same object, updated in place with the obsm entries above.
        models_all: list of the trained models, in replicate order.

    Raises:
        AssertionError: if the models return coordinate arrays of different shapes.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Store results from each model
    sc_coords_results = []
    models_all = []

    # (AnnData, name) pairs; one model is trained per entry
    st_datasets = [
        (stadata1, "dataset1"),
        (stadata2, "dataset2"),
        (stadata3, "dataset3"),
    ]
    n_models = len(st_datasets)

    for i, (stadata, dataset_name) in enumerate(st_datasets):
        print(f"\n{'='*50}")
        print(f"Training AdvancedHierarchicalDiffusion model {i+1}/{n_models} for {dataset_name}")
        print(f"{'='*50}")

        # Genes shared by SC and this ST replicate, sorted so SC and ST columns line up
        common_genes = sorted(set(scadata.var_names) & set(stadata.var_names))

        print(f"Common genes for {dataset_name}: {len(common_genes)}")

        # Extract expression data, densifying sparse matrices
        sc_expr = scadata[:, common_genes].X
        st_expr = stadata[:, common_genes].X
        if hasattr(sc_expr, 'toarray'):
            sc_expr = sc_expr.toarray()
        if hasattr(st_expr, 'toarray'):
            st_expr = st_expr.toarray()

        # Get spatial coordinates
        st_coords = stadata.obsm['spatial']

        print(f"SC data shape: {sc_expr.shape}")
        print(f"ST data shape: {st_expr.shape}")
        print(f"ST coords shape: {st_coords.shape}")

        model = AdvancedHierarchicalDiffusion(
            st_gene_expr=st_expr,
            st_coords=st_coords,
            sc_gene_expr=sc_expr,
            cell_types_sc=scadata.obs['rough_celltype'].values,  # SC cell-type labels (rough_celltype)
            transport_plan=None,  # No OT transport plan
            D_st=None,            # No precomputed distance matrices
            D_induced=None,
            n_genes=len(common_genes),
            n_embedding=[512, 256, 128],  # Same as STEMDiffusion
            coord_space_diameter=2.00,
            sigma=3.0,
            alpha=0.8,
            mmdbatch=1000,
            batch_size=256,
            device=device,
            lr_e=0.0001,
            lr_d=0.0002,
            n_timesteps=600,     # Same as STEMDiffusion
            n_denoising_blocks=4,
            hidden_dim=256,      # Same as STEMDiffusion
            num_heads=8,
            num_hierarchical_scales=3,
            dp=0.2,
            outf=f'advanced_diffusion_{dataset_name}'
        )

        print(f"Training model for {dataset_name}...")

        # Train using new Graph-VAE + Latent Diffusion pipeline
        model.train(
            encoder_epochs=1000,   # Stage 1: Domain alignment encoder
            vae_epochs=1200,       # Stage 2: Graph-VAE training
            diffusion_epochs=2500, # Stage 3: Latent diffusion
            lambda_struct=3.0      # Structure loss weight
        )

        model.plot_training_losses()

        print(f"Generating SC coordinates using model {i+1}...")
        # Sample SC coordinates in batches of 512 cells
        sc_coords = model.sample_sc_coordinates_batched(
            batch_size=512
        )

        sc_coords_results.append(sc_coords)
        models_all.append(model)

        print(f"Model {i+1} complete! Generated coordinates shape: {sc_coords.shape}")

        # Drop the loop name and release cached, unused CUDA blocks. The model is still
        # referenced from models_all, so its own parameters remain allocated.
        del model
        torch.cuda.empty_cache()

    print(f"\nAveraging results from {len(sc_coords_results)} models...")
    # Validate shapes BEFORE averaging so a mismatch is reported clearly rather than inside np.mean
    shapes = [coords.shape for coords in sc_coords_results]
    assert all(shape == shapes[0] for shape in shapes), f"Shape mismatch: {shapes}"
    sc_coords_avg = np.mean(sc_coords_results, axis=0)

    print(f"Final averaged coordinates shape: {sc_coords_avg.shape}")

    # Add averaged and per-replicate coordinates to AnnData
    scadata.obsm['advanced_diffusion_coords_avg'] = sc_coords_avg
    for i, coords in enumerate(sc_coords_results):
        scadata.obsm[f'advanced_diffusion_coords_rep{i+1}'] = coords

    print(f"\nAdvanced diffusion training complete!")
    print(f"Results saved in scadata.obsm['advanced_diffusion_coords_avg']")

    return scadata, models_all

# Re-read and preprocess the P10 data from disk
scadata_p10, stadata1_p10, stadata2_p10, stadata3_p10 = load_and_process_cscc_data_p10()

# Train one AdvancedHierarchicalDiffusion model per ST replicate and average their SC predictions
scadata_p10, advanced_models_p10 = train_individual_advanced_diffusion_models(
    scadata_p10, stadata1_p10, stadata2_p10, stadata3_p10
)

print("Advanced diffusion training complete! Results saved in scadata.obsm['advanced_diffusion_coords_avg']")

# Visualize results
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

# sc.pl.embedding opens its own figure sized from rcParams['figure.figsize'], so this setting
# controls the plot size. A preceding plt.figure(figsize=...) only left an extra empty figure.
mpl.rcParams['figure.figsize'] = (4, 4)

my_tab20 = sns.color_palette("tab20", n_colors=20).as_hex()

# Averaged coordinates first, then one panel per trained model
embedding_panels = [('advanced_diffusion_coords_avg', 'SC Advanced Diffusion Coords (Averaged)')]
embedding_panels += [
    (f'advanced_diffusion_coords_rep{k}', f'SC Coordinates (Advanced Model {k})')
    for k in range(1, len(advanced_models_p10) + 1)
]
for basis, title in embedding_panels:
    sc.pl.embedding(scadata_p10, basis=basis, color='rough_celltype',
                    size=85, title=title,
                    palette=my_tab20, legend_loc='right margin', legend_fontsize=10)
    plt.show()

In [None]:
# Visualize results with separate plots
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

# sc.pl.embedding opens its own figure sized from rcParams['figure.figsize'];
# a preceding plt.figure(figsize=...) would only add an empty figure to the output.
mpl.rcParams['figure.figsize'] = (6, 6)

# Use one palette for every panel. scanpy records the palette it used in
# scadata_p10.uns['rough_celltype_colors'], so switching to a shorter palette part-way
# can change category colours between panels.
my_tab20 = sns.color_palette("tab20", n_colors=20).as_hex()

embedding_panels = [
    ('advanced_diffusion_coords_avg', 'SC Spatial Coordinates (Averaged from 3 Models)'),
    ('advanced_diffusion_coords_rep1', 'SC Coordinates (Model 1)'),
    ('advanced_diffusion_coords_rep2', 'SC Coordinates (Model 2)'),
    ('advanced_diffusion_coords_rep3', 'SC Coordinates (Model 3)'),
]
for basis, title in embedding_panels:
    sc.pl.embedding(scadata_p10, basis=basis, color='rough_celltype',
                    size=85, title=title,
                    palette=my_tab20, legend_loc='right margin', legend_fontsize=10)
    plt.show()

In [None]:
# 0/1 indicator columns marking TSK cells, fibroblasts and epithelial cells, shown with a diverging colormap.
scadata_p10.obs['selection'] = (scadata_p10.obs['level2_celltype'] == 'TSK').astype(int)
scadata_p10.obs['selection2'] = (scadata_p10.obs['level1_celltype'] == 'Fibroblast').astype(int)
scadata_p10.obs['selection3'] = (scadata_p10.obs['rough_celltype'] == 'Epithelial').astype(int)

# sc.pl.spatial builds its own multi-panel figure, so no plt.figure(...) beforehand
# (that only produced an extra empty figure).
sc.pl.spatial(
    scadata_p10,
    color=['selection', 'selection2', 'selection3', 'level3_celltype'],
    spot_size=0.025,
    cmap='bwr',
    basis='advanced_diffusion_coords_avg',
)

In [None]:
# Highlight only the TSK group (level3 annotation) on the reconstructed coordinates.
sc.pl.spatial(
    scadata_p10,
    color="level3_celltype",
    groups=["TSK"],
    basis='advanced_diffusion_coords_avg',
    spot_size=0.03,
    title='reconstructed',
    na_in_legend=False,
    show=True,
)


In [None]:
import squidpy as sq

# Neighbourhood statistics on the averaged reconstructed coordinates, all coarse cell types.
sq.gr.spatial_neighbors(scadata_p10, spatial_key='advanced_diffusion_coords_avg')
sq.gr.nhood_enrichment(scadata_p10, cluster_key='rough_celltype')
sq.gr.interaction_matrix(scadata_p10, cluster_key='rough_celltype')

# Same enrichment restricted to tumour keratinocyte subtypes (including TSK).
kscadata_p10 = scadata_p10[
    scadata_p10.obs['level2_celltype'].isin(['Tumor_KC_Cyc', 'Tumor_KC_Basal', 'Tumor_KC_Diff', 'TSK'])
].copy()
sq.gr.spatial_neighbors(kscadata_p10, spatial_key='advanced_diffusion_coords_avg')
sq.gr.nhood_enrichment(kscadata_p10, cluster_key='level2_celltype')

# Draw into our own axes so the size is fixed and the y-axis label can be blanked.
# To export, call plt.savefig('TSKKC_new_best_p10.svg') before plt.show().
fig, ax = plt.subplots(figsize=(3, 5))
sq.pl.nhood_enrichment(
    kscadata_p10,
    cluster_key="level2_celltype",
    cmap='coolwarm',
    ax=ax,
)
ax.set_ylabel('')
plt.show()


In [None]:
# Set sc.settings.file_format_figs = 'svg' beforehand to export SVGs instead of the default format.
# One figure per non-TSK tumour keratinocyte subtype; the tag is passed as scanpy's `save` argument.
for kc_subtype, save_tag in [
    ("Tumor_KC_Cyc", 'P10_cyc'),
    ("Tumor_KC_Basal", 'P10_bas'),
    ("Tumor_KC_Diff", 'P10_diff'),
]:
    sc.pl.spatial(
        scadata_p10,
        color="level2_celltype",
        groups=[kc_subtype],
        basis='advanced_diffusion_coords_avg',
        spot_size=0.03,
        title='reconstructed',
        na_in_legend=False,
        save=save_tag,
        show=True,
    )

In [None]:
# Highlight only the PDC group (level2 annotation) on the reconstructed coordinates.
sc.pl.spatial(
    scadata_p10,
    color="level2_celltype",
    groups=["PDC"],
    basis='advanced_diffusion_coords_avg',
    spot_size=0.04,
    title='reconstructed',
    na_in_legend=False,
    show=True,
)

In [None]:
# All three non-TSK tumour keratinocyte subtypes together in one figure.
sc.pl.spatial(
    scadata_p10,
    color="level2_celltype",
    groups=['Tumor_KC_Cyc', 'Tumor_KC_Basal', 'Tumor_KC_Diff'],
    basis='advanced_diffusion_coords_avg',
    spot_size=0.025,
    title='reconstructed',
    na_in_legend=False,
    show=True,
)