In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Cell 2: force single‐threaded BLAS
os.environ["OMP_NUM_THREADS"]       = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"

In [None]:
# Cell 3: cap the BLAS thread pools to a single thread at runtime
from threadpoolctl import threadpool_limits

# user_api='blas' targets BLAS implementations (OpenBLAS, MKL, etc.).
# NOTE(review): threadpoolctl appears to apply limits only to thread-pool
# libraries that are already loaded when the call runs. numpy/scipy are
# imported *below*, so this call may be a no-op and the effective cap may come
# solely from the OMP_NUM_THREADS / OPENBLAS_NUM_THREADS variables set in the
# first cell -- confirm, and consider moving the call after the imports.
threadpool_limits(limits=1, user_api='blas')

# import the BLAS-backed packages
import numpy as np
import scipy
# ... any other packages that use OpenBLAS ...


In [None]:
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import scanpy as sc
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.neighbors import kneighbors_graph
from scipy.sparse.csgraph import dijkstra
from scipy.sparse import csr_matrix, issparse
from sklearn.preprocessing import normalize
import ot 
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt

In [None]:
def construct_graph_torch(X, k, mode='connectivity', metric = 'minkowski', p=2, device='cuda'):
    '''Build a k-nearest-neighbour graph from a feature tensor.

    The neighbour search itself runs on the CPU through scikit-learn; only the
    resulting graph is moved to `device`.

    args:
        X: input data containing features (torch tensor, one row per sample)
        k: number of neighbors for each data point
        mode: 'connectivity' (unit edge weights, each point is its own
            neighbour) or 'distance' (edge weights are distances, no self edges)
        metric: 'euclidean', or 'minkowski' together with p=2 (which is the
            same metric). Anything else fails the assertion below.
        p: Minkowski exponent; only p=2 is accepted
        device: torch device the returned graph is placed on

    Returns:
        mode='connectivity': torch sparse COO tensor of shape (n, n)
        mode='distance': dense float32 torch tensor of shape (n, n)

    Raises:
        AssertionError: on an unsupported `mode` or `metric`/`p` combination
    '''

    assert mode in ['connectivity', 'distance'], "mode must be 'connectivity' or 'distance'."

    # Minkowski with p=2 is exactly the Euclidean metric. Without this mapping
    # the function's own default arguments could never pass the check below.
    if metric == 'minkowski' and p == 2:
        metric = 'euclidean'
    assert metric == 'euclidean', "for gpu knn, only 'euclidean' metric is currently supported in this implementation"

    # sklearn convention used here: connectivity graphs include the self edge
    include_self = (mode == 'connectivity')

    # sklearn needs a host-side array, wherever X currently lives
    X_cpu = X.detach().cpu().numpy()

    knn = NearestNeighbors(n_neighbors=k, metric=metric, algorithm='auto')
    knn.fit(X_cpu)
    # scipy sparse matrix on cpu
    knn_graph_cpu = kneighbors_graph(knn, k, mode=mode, include_self=include_self, metric=metric)

    if mode == 'connectivity':
        knn_graph_coo = knn_graph_cpu.tocoo()
        # stack into one (2, nnz) array first: building a tensor from a Python
        # list of ndarrays is slow
        edge_index = torch.from_numpy(np.vstack((knn_graph_coo.row, knn_graph_coo.col))).long()
        edge_weight = torch.from_numpy(knn_graph_coo.data).float()
        knn_graph = torch.sparse_coo_tensor(edge_index, edge_weight,
                                            size=knn_graph_coo.shape).to(device)
    else:
        # distance graphs are returned dense
        knn_graph = torch.tensor(knn_graph_cpu.toarray(), dtype=torch.float32, device=device)

    return knn_graph
    
def distances_cal_torch(graph, type_aware=None, aware_power =2, device='cuda'):
    '''
    Compute normalised all-pairs shortest-path distances on a graph.

    Dijkstra runs on the CPU through scipy with the graph treated as
    undirected; the result is moved to `device`.

    args:
        graph: knn graph (pytorch sparse or dense tensor, or anything
            `scipy.sparse.csr_matrix` accepts)
        type_aware: unused (kept for signature compatibility)
        aware_power: unused (kept for signature compatibility)
        device (str): device the returned tensor is placed on
    Returns:
        C_dis: float32 tensor of shortest-path distances divided by the largest
            finite distance; unreachable pairs are set to that largest distance
            first, so they end up as 1
        original_max_distance: the largest finite distance before scaling
            (python float)
    '''

    if isinstance(graph, torch.Tensor):
        graph_cpu = graph.detach().cpu()
        if graph_cpu.is_sparse:
            graph_cpu = graph_cpu.to_dense()
        graph_cpu_csr = csr_matrix(graph_cpu.numpy())
    else:
        graph_cpu_csr = csr_matrix(graph) #assume scipy sparse matrix if not torch tensor

    shortestPath_cpu = dijkstra(csgraph = graph_cpu_csr, directed=False, return_predecessors=False) #dijkstra on cpu
    shortestPath = torch.tensor(shortestPath_cpu, dtype=torch.float32, device=device)

    #mask out infinite distances
    mask = shortestPath != float('inf')
    if mask.any():
        original_max_distance = shortestPath[mask].max().item()
        shortestPath[~mask] = original_max_distance #replace inf with max value
    else:
        original_max_distance = 1.0 #fallback if nothing is finite (e.g. empty graph)

    # A graph without edges has a largest finite distance of 0 (the diagonal);
    # dividing by it would turn the whole matrix into NaN.
    scale = original_max_distance if original_max_distance > 0 else 1.0
    C_dis = shortestPath / scale
    return C_dis, original_max_distance

def calculate_D_sc_torch(X_sc, k_neighbors=10, graph_mode='connectivity', device='cpu'):
    '''Compute the normalised graph-distance matrix for single-cell features.

    Builds a Euclidean kNN graph over the rows of `X_sc` with
    `construct_graph_torch` and converts it to shortest-path distances with
    `distances_cal_torch`.

    args:
        X_sc: feature tensor, one row per cell (must be a torch tensor)
        k_neighbors: number of neighbours per cell in the kNN graph
        graph_mode: 'connectivity' or 'distance', forwarded to the graph builder
        device: 'cuda' is honoured only when CUDA is available; every other
            value (and 'cuda' without CUDA) falls back to 'cpu'

    returns:
        D_sc: distance matrix as returned by `distances_cal_torch`
        sc_max_distance: largest finite graph distance before normalisation

    raises:
        TypeError: if `X_sc` is not a torch tensor
    '''

    if not isinstance(X_sc, torch.Tensor):
        raise TypeError('Input X_sc must be a pytorch tensor')

    if not (device == 'cuda' and torch.cuda.is_available()):
        device = 'cpu'

    print(f'using device: {device}')
    print('constructing knn graph...')
    # No feature normalisation is applied. `.to` is used instead of
    # re-wrapping the tensor in `torch.tensor(...)`, which warns when handed an
    # existing tensor.
    X_features = X_sc.detach().to(device=device, dtype=torch.float32)

    Xgraph = construct_graph_torch(X_features, k=k_neighbors, mode=graph_mode, metric='euclidean', device=device)

    print('calculating distances from graph....')
    D_sc, sc_max_distance = distances_cal_torch(Xgraph, device=device)

    print('D_sc calculation complete')

    return D_sc, sc_max_distance


In [None]:
from sklearn.neighbors import kneighbors_graph, NearestNeighbors
from scipy.sparse.csgraph import dijkstra
from scipy.sparse import csr_matrix, issparse
from sklearn.preprocessing import normalize
import ot

def construct_graph_spatial(location_array, k, mode='distance', metric='euclidean', p=2):
    '''Build the k-nearest-neighbour graph of the spots from their coordinates.

    args:
        location_array: spatial coordinates of the spots, one row per spot
        k: number of neighbours per spot
        mode: 'connectivity' or 'distance', forwarded to `kneighbors_graph`
        metric: distance metric used for the neighbour search
        p: Minkowski exponent forwarded to `kneighbors_graph`

    returns:
        the sparse graph produced by `sklearn.neighbors.kneighbors_graph`
    '''
    allowed_modes = ('connectivity', 'distance')
    assert mode in allowed_modes, "mode must be 'connectivity' or 'distance'"

    # a spot is listed as its own neighbour only in the connectivity graph
    with_self_edges = (mode == 'connectivity')

    return kneighbors_graph(
        location_array,
        k,
        mode=mode,
        metric=metric,
        include_self=with_self_edges,
        p=p,
    )

def distances_cal_spatial(graph, spot_ids=None, spot_types=None, aware_power=2):
    '''Compute the normalised shortest-path distance matrix of a spot graph.

    args:
        graph (scipy.sparse.csr_matrix): knn graph (anything `csr_matrix` accepts)
        spot_ids (list, optional): ids of the spots, in the row/column order of
            the graph. The type-aware adjustment is applied only when both
            `spot_ids` and `spot_types` are given.
        spot_types (pd.Series or sequence, optional): type of every spot. A
            Series is aligned to `spot_ids` through its index; any other
            sequence is taken positionally.
        aware_power (int): factor by which distances between spots of
            different types are multiplied

    returns:
        C_dis: distances divided by the largest finite distance; unreachable
            pairs are set to that largest distance first, so they end up as 1
        original_max_distance: largest finite distance before scaling
    '''
    shortestPath = dijkstra(csgraph = csr_matrix(graph), directed=False, return_predecessors=False)
    # Map NaN to inf but keep inf as inf. With its default posinf,
    # np.nan_to_num turns +inf into the largest finite float, which would then
    # be picked up as "the max" and squash every real distance to ~0.
    shortestPath = np.nan_to_num(shortestPath, nan=np.inf, posinf=np.inf)

    if spot_types is not None and spot_ids is not None:
        if isinstance(spot_types, pd.Series):
            # align by label so a differently ordered Series is still correct
            types = spot_types.reindex(spot_ids).to_numpy()
        else:
            types = np.asarray(spot_types)

        # penalise every pair of spots whose types differ
        different_type = types[:, None] != types[None, :]
        shortestPath[different_type] *= aware_power

    #mask out infinite distances
    mask = shortestPath != float('inf')
    if mask.any():
        the_max = np.max(shortestPath[mask])
        shortestPath[~mask] = the_max #replace inf with max value
    else:
        the_max = 1.0 #fallback if nothing is finite (e.g. empty graph)

    #store original max distance for scale reference
    original_max_distance = the_max
    # an edgeless graph has max distance 0; avoid 0/0 -> NaN
    scale = the_max if the_max > 0 else 1.0
    C_dis = shortestPath / scale

    return C_dis, original_max_distance

def calculate_D_st_from_coords(spatial_coords, X_st=None, k_neighbors=10, graph_mode='distance', aware_st=False, 
                               spot_types=None, aware_power_st=2, spot_ids=None):
    '''Compute the spatial distance matrix D_st of ST spots from their coordinates.

    args:
        spatial_coords: spot coordinates, a pandas DataFrame or numpy array
            with one row per spot
        X_st: ST gene expression data (not used by this function)
        k_neighbors: number of neighbors for the knn graph
        graph_mode: 'connectivity' or 'distance' for the knn graph
        aware_st: whether to apply the type-aware distance adjustment
        spot_types: spot types, required when aware_st=True; anything that is
            not a pandas Series is wrapped in one indexed by `spot_ids`
        aware_power_st: factor applied to distances between spots of different types
        spot_ids: ids of the spots in row order; defaults to the DataFrame
            index, or to 0..n-1 for an array

    returns:
        D_st: normalised spatial distance matrix
        st_max_distance: largest finite graph distance before normalisation

    raises:
        TypeError: if `spatial_coords` is neither a DataFrame nor an ndarray
        ValueError: if aware_st=True but `spot_types` is missing
    '''

    if isinstance(spatial_coords, pd.DataFrame):
        location_array = spatial_coords.values
        if spot_ids is None:
            spot_ids = spatial_coords.index.tolist() #use index of dataframe if available
    elif isinstance(spatial_coords, np.ndarray):
        location_array = spatial_coords
        if spot_ids is None:
            spot_ids = list(range(location_array.shape[0])) #generate default ids if not provided
    else:
        raise TypeError('spatial_coords must be a pandas dataframe or a numpy array')

    # Validate the type-aware inputs before the (potentially expensive) graph
    # construction so a bad call fails immediately.
    if aware_st:
        if spot_types is None or spot_ids is None:
            raise ValueError('spot_types and spot_ids must be provided when aware_st=True')
        if not isinstance(spot_types, pd.Series):
            spot_types = pd.Series(spot_types, index=spot_ids)
    else:
        spot_types = None

    print(f'constructing {graph_mode} graph for ST data with k={k_neighbors}.....')
    Xgraph_st = construct_graph_spatial(location_array, k=k_neighbors, mode=graph_mode)

    if aware_st:
        print('applying type aware distance adjustment for ST data')
        print(f'aware power for ST: {aware_power_st}')

    print('calculating spatial distances.....')
    D_st, st_max_distance = distances_cal_spatial(Xgraph_st, spot_ids=spot_ids, spot_types=spot_types, aware_power=aware_power_st)

    print('D_st calculation complete')
    return D_st, st_max_distance


def calculate_D_st_euclidean(spatial_coords):
    """
    Pairwise Euclidean distances between ST spots, scaled by their maximum.

    Args:
        spatial_coords: spot coordinates, one row per spot (DataFrame,
            ndarray, or anything numpy can convert to an array)

    Returns:
        float32 square matrix of pairwise distances divided by the largest
        one; left unscaled when the largest distance is 0
    """
    from scipy.spatial.distance import pdist, squareform

    if isinstance(spatial_coords, pd.DataFrame):
        points = spatial_coords.values
    else:
        points = np.asarray(spatial_coords)

    # condensed pair distances -> full symmetric matrix
    condensed = pdist(points, metric='euclidean')
    dist_matrix = squareform(condensed)

    # scale into [0, 1] unless every spot coincides
    largest = dist_matrix.max()
    scaled = dist_matrix / largest if largest > 0 else dist_matrix

    return scaled.astype(np.float32)

In [None]:
import numpy as np
import torch
import ot
from tqdm import tqdm
from scipy.sparse.csgraph import dijkstra
from scipy.sparse import csr_matrix
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances


def compute_feature_cost_matrix(X_sc, X_st, metric='euclidean'):
    """
    Compute feature dissimilarity matrix M between SC cells and ST spots.
    
    Args:
        X_sc: SC gene expression (n_cells, n_genes)
        X_st: ST gene expression (m_spots, n_genes) 
        metric: Distance metric ('euclidean', 'cosine', 'correlation')
    
    Returns:
        M: Feature cost matrix (n_cells, m_spots), float32. For 'correlation'
           the cost is 1 - Pearson r, and pairs whose correlation is undefined
           (a constant row on either side) get cost 1.0.

    Raises:
        ValueError: if `metric` is not one of the supported names
    """
    if metric == 'euclidean':
        M = euclidean_distances(X_sc, X_st)
    elif metric == 'cosine':
        # Convert cosine similarity to distance
        M = 1.0 - cosine_similarity(X_sc, X_st)
    elif metric == 'correlation':
        # Pearson correlation for all pairs at once: centre every row, then
        # r = <a, b> / (|a| |b|). Replaces an n_cells x m_spots Python loop.
        sc_centered = np.asarray(X_sc, dtype=np.float64)
        st_centered = np.asarray(X_st, dtype=np.float64)
        sc_centered = sc_centered - sc_centered.mean(axis=1, keepdims=True)
        st_centered = st_centered - st_centered.mean(axis=1, keepdims=True)

        covariance = sc_centered @ st_centered.T
        norms = np.linalg.norm(sc_centered, axis=1)[:, None] * np.linalg.norm(st_centered, axis=1)[None, :]
        # constant rows have zero norm -> 0/0 -> NaN, handled just below
        with np.errstate(divide='ignore', invalid='ignore'):
            corr = np.clip(covariance / norms, -1.0, 1.0)
        M = np.where(np.isnan(corr), 1.0, 1.0 - corr)
    else:
        raise ValueError(
            f"unknown metric {metric!r}; expected 'euclidean', 'cosine' or 'correlation'"
        )
    
    return M.astype(np.float32)

def compute_marginal_weights(M, method='exponential'):
    """
    Compute marginal weights for unbalanced OT from feature cost matrix.
    
    Args:
        M: Feature cost matrix (n_cells, m_spots)
        method: How to compute weights. 'uniform' gives equal weights;
            'exponential' and 'softmax' are the same computation here: row and
            column sums of exp(-M), each normalised to sum to 1.
    
    Returns:
        w_a: Cell weights (n_cells,), float32
        w_b: Spot weights (m_spots,), float32

    Raises:
        ValueError: if `method` is not one of the supported names
    """
    if method == 'uniform':
        w_a = np.ones(M.shape[0]) / M.shape[0]
        w_b = np.ones(M.shape[1]) / M.shape[1]
    elif method in ('exponential', 'softmax'):
        # Shift by the smallest cost before exponentiating. The common factor
        # exp(min) cancels in the normalisation, so the weights are unchanged,
        # but exp() no longer underflows to all zeros (-> 0/0 = NaN) when
        # every cost is large.
        M = np.asarray(M)
        shift = M.min() if M.size else 0.0
        weight_matrix = np.exp(-(M - shift))
        w_a = np.sum(weight_matrix, axis=1)
        w_b = np.sum(weight_matrix, axis=0)
        w_a = w_a / np.sum(w_a)
        w_b = w_b / np.sum(w_b)
    else:
        raise ValueError(
            f"unknown method {method!r}; expected 'uniform', 'exponential' or 'softmax'"
        )
    
    return w_a.astype(np.float32), w_b.astype(np.float32)

def unbalanced_optimal_transport(w_a, w_b, cost_matrix, epsilon=0.1, rho=100.0, max_iter=1000, tol=1e-7):
    """
    Solve unbalanced optimal transport using Sinkhorn-like iterations on the
    dual potentials u and v.
    
    Args:
        w_a: Source marginals (n,)
        w_b: Target marginals (m,)
        cost_matrix: Transport cost (n, m)
        epsilon: Entropic regularization
        rho: KL penalty weight for unbalanced transport (np.inf gives the
            balanced update, i.e. damping factor 1)
        max_iter: Maximum iterations
        tol: Convergence tolerance on the change of u between iterations
    
    Returns:
        gamma: Transport plan (n, m), exp((-cost + u + v^T) / epsilon)
    """
    # function-scope import, matching the file's existing local-import style
    from scipy.special import logsumexp

    lmbda = rho / (rho + epsilon) if not np.isinf(rho) else 1.0
    
    w_a = np.asarray(w_a).reshape(-1, 1)
    w_b = np.asarray(w_b).reshape(-1, 1)
    
    n, m = cost_matrix.shape
    u = np.zeros((n, 1))
    v = np.zeros((m, 1))

    # the log-marginals never change, so compute them once
    log_a = np.log(w_a)
    log_b = np.log(w_b)
    
    for _ in range(max_iter):
        u_old = u  # u is rebound below, never mutated in place
        
        # Update u. logsumexp evaluates log(sum(exp(.))) without forming
        # exp(.) itself, which under/overflows when costs are large relative
        # to epsilon and would poison the potentials with inf/NaN.
        # (n, 1) + (1, m) broadcasts to the full (n, m) matrix.
        log_kernel = (-cost_matrix + u + v.T) / epsilon
        u = lmbda * epsilon * log_a - lmbda * epsilon * logsumexp(log_kernel, axis=1, keepdims=True) + lmbda * u
        
        # Update v with the refreshed u
        log_kernel = (-cost_matrix + u + v.T) / epsilon
        v = lmbda * epsilon * log_b - lmbda * epsilon * logsumexp(log_kernel, axis=0, keepdims=True).T + lmbda * v
        
        # Check convergence
        if np.linalg.norm(u - u_old) < tol:
            break
    
    # Final transport plan
    gamma = np.exp((-cost_matrix + u + v.T) / epsilon)
    
    return gamma

def structured_optimal_transport(w_a, w_b, M, D_sc, D_st, alpha=0.1, epsilon=0.1, rho=100.0, max_iter=50):
    """
    Solve structured optimal transport that aligns both features and internal geometries.
    This is the core of SpaOTsc Stage 1.

    Each iteration linearises the combined cost around the current plan,
    solves an entropic OT problem for that cost, and moves towards the
    solution with a step size obtained from a quadratic line search.
    
    Args:
        w_a: Cell marginals (n_cells,)
        w_b: Spot marginals (m_spots,) 
        M: Feature cost matrix (n_cells, m_spots)
        D_sc: SC distance matrix (n_cells, n_cells); used as given, no rescaling
        D_st: ST distance matrix (m_spots, m_spots); used as given, no rescaling
        alpha: Weight for structured term (0=pure feature, 1=pure structure)
        epsilon: Entropic regularization
        rho: Unbalanced transport penalty; np.inf selects balanced Sinkhorn
            (`ot.sinkhorn`), anything else `unbalanced_optimal_transport`
        max_iter: Maximum iterations
    
    Returns:
        gamma: Transport plan (n_cells, m_spots)
    """
    D_sc_norm = D_sc
    D_st_norm = D_st
    
    # Initialize with the independent coupling (outer product of the marginals)
    w_a = w_a.reshape(-1, 1) 
    w_b = w_b.reshape(-1, 1)
    gamma = w_a @ w_b.T
    w_a_flat = w_a.flatten()
    w_b_flat = w_b.flatten()
    
    n, m = M.shape
    use_structure = alpha > 0

    # Everything below depends only on the inputs, so it is computed once
    # instead of on every iteration.
    # Feature cost component (1-alpha) * M
    cost_feature = (1.0 - alpha) * M
    if use_structure:
        # Constant part of the Gromov-Wasserstein term
        # (squared loss: L(a,b) = 0.5 * (a-b)^2)
        constC1 = 0.5 * (D_sc_norm**2) @ w_a @ np.ones((1, m))
        constC2 = np.ones((n, 1)) @ w_b.T @ (0.5 * (D_st_norm**2)).T
        constC = constC1 + constC2
    
    for iteration in range(max_iter):
        gamma_old = gamma  # gamma is rebound below, never mutated in place
        
        # === STRUCTURED COST COMPUTATION ===
        if use_structure:
            # Plan-dependent part: D_sc @ gamma @ D_st^T (reused in the line search)
            variable_term = D_sc_norm @ gamma @ D_st_norm.T
            cost_structured = alpha * 2.0 * (constC - variable_term)
        else:
            cost_structured = 0
        
        # Total cost matrix
        total_cost = cost_feature + cost_structured
        
        # === OPTIMAL TRANSPORT STEP ===
        if np.isinf(rho):
            # Balanced case: use standard Sinkhorn
            gamma_new = ot.sinkhorn(w_a_flat, w_b_flat, total_cost, epsilon)
        else:
            # Unbalanced case
            gamma_new = unbalanced_optimal_transport(w_a_flat, w_b_flat, total_cost, epsilon, rho)
        
        # === LINE SEARCH UPDATE ===
        # Minimise the quadratic a*tau^2 + b*tau over tau in [0, 1]
        if use_structure:
            delta = gamma_new - gamma
            CxC_diff = D_sc_norm @ delta @ D_st_norm.T
            a = -alpha * np.sum(CxC_diff * delta)
            b = np.sum((cost_feature + alpha * constC - 2.0 * alpha * variable_term) * delta)
            
            if a > 0:
                # convex: clamp the stationary point into [0, 1]
                tau = min(1.0, max(0.0, -0.5 * b / a))
            elif a + b < 0:
                # not convex: take the better endpoint
                tau = 1.0
            else:
                tau = 0.0
        else:
            tau = 1.0
        
        # Update with line search
        gamma = (1.0 - tau) * gamma + tau * gamma_new
        
        # Check convergence
        if np.linalg.norm(gamma - gamma_old) < 1e-6:
            print(f"Structured OT converged at iteration {iteration}")
            break
    
    return gamma

import faiss
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import shortest_path
import ot
from tqdm import tqdm
from scipy.spatial.distance import pdist

def transport_aware_geodesic_distance(gamma, Y_st_coords, sigma=3.0, k_neighbors=30):
    '''
    Compute sc-sc distances using a transport-weighted approach.

    Steps: Gaussian kernel K on the normalised Euclidean spot distances ->
    cell-cell similarity W = gamma K gamma^T -> kNN graph on W with edge cost
    -log(similarity) -> all-pairs shortest paths -> min-max scaling.

    args:
        gamma: transport plan (n_cells, m_spots)
        Y_st_coords: spot coordinates, passed to `calculate_D_st_euclidean`
        sigma: bandwidth of the Gaussian kernel
        k_neighbors: neighbours kept per cell; clamped to n_cells - 1

    returns:
        float32 (n_cells, n_cells) matrix of geodesic distances, min-max scaled
        (left unscaled if all distances are equal)
    '''
    print(f'computing transport-aware geodasic distances (sigma={sigma}, k={k_neighbors})')

    n_cells, m_spots = gamma.shape

    #1. compute euclidean distance matrix for ST spots
    print("Computing Euclidean D_st for transport weighting...")
    D_st_euclid = calculate_D_st_euclidean(Y_st_coords)

    # 2) Compute spatial affinity kernel using Euclidean distances
    print("Computing spatial affinity kernel...")
    K = np.exp(-(D_st_euclid**2) / (2 * sigma**2))  # (m_spots, m_spots)

    # 3) Compute transport-weighted similarity matrix W
    print("Computing transport-weighted similarity matrix...")
    W = gamma.dot(K).dot(gamma.T)  # (n_cells, n_cells)

    # Make symmetric for numerical stability
    W = (W + W.T) / 2
    
    print(f"Similarity matrix W: shape={W.shape}, range=[{W.min():.6f}, {W.max():.6f}]")
    
    # 4) Build sparse k-NN graph from similarity matrix W
    # np.argpartition needs kth < n_cells, so k cannot exceed n_cells - 1
    k = min(k_neighbors, n_cells - 1)
    if k < k_neighbors:
        print(f"Warning: k_neighbors={k_neighbors} reduced to {k} (only {n_cells} cells)")
    print(f"Building k-NN graph with k={k}...")

    # For each cell, find the k cells with highest similarity
    neighbors = np.argpartition(-W, k, axis=1)[:, :k]  # (n_cells, k)
    
    # Create sparse graph edges
    rows = np.repeat(np.arange(n_cells), k)
    cols = neighbors.flatten()
    similarities = W[rows, cols]
    
    # Convert similarities to costs using log distance (more stable than 1/similarity)
    # NOTE(review): a similarity >= 1 yields a non-positive cost, which
    # Dijkstra does not handle -- confirm W stays below 1 for real inputs.
    epsilon = 1e-10
    costs = -np.log(similarities + epsilon)
    
    # Store each directed kNN edge once. shortest_path(directed=False) already
    # walks an edge in both directions. Appending the transposed edges by
    # concatenation made every mutual-neighbour pair appear twice, and the
    # COO -> CSR conversion sums duplicates, silently doubling their cost.
    graph_csr = coo_matrix((costs, (rows, cols)),
                           shape=(n_cells, n_cells)).tocsr()
    
    print(f"Sparse graph: {graph_csr.nnz} edges, density={graph_csr.nnz/(n_cells**2):.4f}")
    
    # 5) Compute all-pairs shortest path distances
    print("Computing shortest path distances...")
    
    D_geodesic = shortest_path(csgraph=graph_csr, directed=False, method='D')
    
    # 6) Handle infinite distances (disconnected components)
    inf_mask = ~np.isfinite(D_geodesic)
    if inf_mask.any():
        max_finite = D_geodesic[~inf_mask].max()
        D_geodesic[inf_mask] = max_finite * 2.0
        print(f"Warning: {inf_mask.sum()} infinite distances replaced with {max_finite*2.0:.4f}")
    
    # 7) Normalize to [0,1] range
    d_min, d_max = D_geodesic.min(), D_geodesic.max()
    if d_max > d_min:
        D_geodesic = (D_geodesic - d_min) / (d_max - d_min)
    
    print(f"Final geodesic distances: range=[{D_geodesic.min():.6f}, {D_geodesic.max():.6f}]")
    
    return D_geodesic.astype(np.float32)

def knn_sparse_ot_FAISS(gamma, D_st, Y_st_coords, k_neighbors=300, batch_size=1024, epsilon=0.01, device='cuda'):
    """
    k-NN graph + sparse OT + shortest-path completion.

    Cells are linked to their nearest neighbours in transport-plan space
    (FAISS, L2 on the rows of gamma). For every linked pair a Sinkhorn loss
    between the two cells' spot distributions is computed with GeomLoss, and
    the sparse graph of these losses is completed to a full matrix with
    shortest paths.

    Args:
        gamma: transport plan (n_cells, m_spots)
        D_st: unused (kept for signature compatibility)
        Y_st_coords: spot coordinates (m_spots, d) as a numpy array
        k_neighbors: neighbours per cell; clamped to n_cells - 1
        batch_size: number of cell pairs per GeomLoss call
        epsilon: passed to GeomLoss as `blur`
        device: torch device for the OT computation. The FAISS index is always
            built on GPU 0.

    Returns:
        float32 (n_cells, n_cells) distance matrix, min-max scaled (left
        unscaled if all distances are equal)

    Raises:
        ImportError: if geomloss is not installed
    """    
    n_cells, m_spots = gamma.shape
    print(f"Processing {n_cells} cells, {m_spots} spots with k={k_neighbors}")

    # NOTE(review): FAISS is documented to pad with index -1 when fewer than
    # k results exist; clamp so that cannot reach the sparse-matrix
    # construction -- confirm against the installed FAISS version.
    k = min(k_neighbors, n_cells - 1)
    
    # =================== STEP 1: FAISS k-NN SEARCH ===================
    print("Building FAISS-GPU index...")
    
    # Convert to float32 for FAISS
    gamma_faiss = gamma.astype(np.float32)
    
    # Build FAISS-GPU index
    res = faiss.StandardGpuResources()
    index_flat = faiss.IndexFlatL2(m_spots)  # L2 distance in gamma space
    gpu_index = faiss.index_cpu_to_gpu(res, 0, index_flat)
    
    # Add vectors to index and search for neighbors
    gpu_index.add(gamma_faiss)
    _, indices_knn = gpu_index.search(gamma_faiss, k + 1)  # +1 for self
    
    # =============== STEP 2: VECTORIZED NEIGHBOR EXTRACTION ===============
    # Column 0 is assumed to be the query itself.
    # NOTE(review): with duplicated rows in gamma the self-match is not
    # guaranteed to come first -- confirm this is acceptable.
    neighbors = indices_knn[:, 1:]
    rows = np.repeat(np.arange(n_cells), k)
    cols = neighbors.flatten()
    total_pairs = len(rows)
    
    print(f"Found {total_pairs} k-NN pairs instead of {n_cells*(n_cells-1)//2} total pairs")
    
    # ================= STEP 3: MOVE DATA TO GPU ONCE =================
    gamma_cuda = torch.from_numpy(gamma_faiss).to(device)
    spot_coords = torch.from_numpy(Y_st_coords.astype(np.float32)).to(device)  # [m, d]
    
    print(f"Spot coordinates shape: {spot_coords.shape}")
    
    # ============= STEP 4: INITIALIZE GEOMLOSS ONCE OUTSIDE LOOP =============
    try:
        from geomloss import SamplesLoss
        sinkhorn = SamplesLoss("sinkhorn", p=2, blur=epsilon)
        print("Using GeomLoss batch mode - no Python loops!")
    except ImportError:
        raise ImportError("GeomLoss required. Install with: pip install geomloss")
    
    # ========== STEP 5: BATCH PROCESSING - SINGLE GEOMLOSS CALLS ==========
    distance_chunks = []
    
    for batch_start in tqdm(range(0, total_pairs, batch_size), desc="Batch OT"):
        batch_end = min(batch_start + batch_size, total_pairs)
        B = batch_end - batch_start  # Actual batch size
        
        # Get batch probability distributions
        batch_rows = rows[batch_start:batch_end] 
        batch_cols = cols[batch_start:batch_end]

        P = gamma_cuda[batch_rows].reshape(B, -1)  # [B, m] - source distributions
        Q = gamma_cuda[batch_cols].reshape(B, -1)  # [B, m] - target distributions

        # Ensure they have exactly the same shape
        assert P.shape == Q.shape, f"Shape mismatch: P {P.shape} vs Q {Q.shape}"

        # Normalise each row to a probability vector
        P = P / (P.sum(dim=1, keepdim=True) + 1e-8)
        Q = Q / (Q.sum(dim=1, keepdim=True) + 1e-8)

        # Clamp to avoid zeros
        P = torch.clamp(P, min=1e-8)
        Q = torch.clamp(Q, min=1e-8)

        # Re-normalize after clamping
        P = P / P.sum(dim=1, keepdim=True)
        Q = Q / Q.sum(dim=1, keepdim=True)
        
        # One GeomLoss call for the whole batch. Arguments are passed as
        # (source weights, source points, target weights, target points);
        # both measures live on the same spot coordinates.
        spot_coords_batch_i = spot_coords.unsqueeze(0).repeat(B, 1, 1)  # [B, m, d]
        spot_coords_batch_j = spot_coords.unsqueeze(0).repeat(B, 1, 1)  # [B, m, d]

        distances_cuda = sinkhorn(
            P,                      # [B, m]  weights for source
            spot_coords_batch_i,    # [B, m, d] points for source
            Q,                      # [B, m]  weights for target
            spot_coords_batch_j     # [B, m, d] points for target
        )
        
        distance_chunks.append(distances_cuda.detach().cpu().numpy().reshape(-1))
    
    # ============== STEP 6: BUILD SPARSE MATRIX ==============
    print("Building sparse matrix...")

    pair_distances = np.concatenate(distance_chunks) if distance_chunks else np.empty(0, dtype=np.float32)

    # Store each directed kNN edge once. shortest_path(directed=False) already
    # uses an edge in both directions. Appending the transposed edges by
    # concatenation made every mutual-neighbour pair appear twice, and the
    # COO -> CSR conversion sums duplicates, silently doubling their distance.
    sparse_matrix = coo_matrix((pair_distances, (rows, cols)),
                               shape=(n_cells, n_cells))
    
    # =============== STEP 7: SHORTEST PATH COMPLETION ===============
    print("Computing shortest paths to fill distance matrix...")
    
    # Convert to CSR for efficient shortest path computation
    D_spatial = shortest_path(sparse_matrix.tocsr(), directed=False, method='D')
    
    # Handle infinite distances (disconnected components)
    finite_mask = D_spatial != np.inf
    if finite_mask.any():
        max_finite_dist = np.max(D_spatial[finite_mask])
        D_spatial[~finite_mask] = max_finite_dist * 2  # Set to 2x max for disconnected pairs
    else:
        print("Warning: All distances are infinite - check k_neighbors value")
        D_spatial[D_spatial == np.inf] = 1.0  # Fallback
    
    print(f"Final distance matrix: {D_spatial.shape}, range [{D_spatial.min():.6f}, {D_spatial.max():.6f}]")
    D_min = D_spatial.min()
    D_max = D_spatial.max()

    if D_max > D_min:
        D_spatial_normalized = (D_spatial - D_min) / (D_max - D_min)
    else:
        D_spatial_normalized = D_spatial  # All same value, keep as is
    
    return D_spatial_normalized.astype(np.float32)



from sklearn.cluster import KMeans
import time

def proper_spaotsc_landmarks(gamma_transport, D_st, Y_st_coords, n_landmarks=300):
    '''
    SpaOTsc stage 2 on a reduced set of transport-aware landmark spots.

    Spots are replicated in proportion to the transport mass they receive,
    clustered with KMeans, and the spot closest to each cluster centre becomes
    a landmark. The transport plan is restricted to the landmarks,
    renormalised per cell, and a Sinkhorn cost is evaluated for every pair of
    cells with the landmark-to-landmark block of D_st as ground cost.

    args:
        gamma_transport: transport plan (n_cells, m_spots); not modified
        D_st: spot-to-spot distance matrix (m_spots, m_spots)
        Y_st_coords: spot coordinates (m_spots, d)
        n_landmarks: number of KMeans clusters / upper bound on landmarks

    returns:
        float32 numpy array (n_cells, n_cells) of pairwise distances with a
        zero diagonal
    '''

    print(f'proper spaotsc with {n_landmarks} landmarks')

    Y_st_coords = np.asarray(Y_st_coords)

    #step 1: select transport aware landmarks
    spot_importance = np.sum(gamma_transport, axis=0)
    spot_importance = spot_importance / np.sum(spot_importance)

    #weight coordinates by transport importance: every spot appears at least
    #once, heavier spots several times
    n_copies = np.maximum(1, (spot_importance * n_landmarks * 3).astype(int))
    weighted_indices = np.repeat(np.arange(len(spot_importance)), n_copies)
    weighted_coords = Y_st_coords[weighted_indices]

    #cluster and select landmarks
    kmeans = KMeans(n_clusters = n_landmarks, random_state=42, n_init=10)
    cluster_labels = kmeans.fit_predict(weighted_coords)

    landmark_indices = []
    for cluster_id in range(n_landmarks):
        cluster_mask = cluster_labels == cluster_id
        if not cluster_mask.any():
            continue

        unique_indices = np.unique(weighted_indices[cluster_mask])

        #landmark = the member spot closest to the cluster centre
        centroid = kmeans.cluster_centers_[cluster_id]
        distances = np.linalg.norm(Y_st_coords[unique_indices] - centroid, axis=1)
        landmark_indices.append(unique_indices[np.argmin(distances)])

    landmark_indices = np.array(landmark_indices, dtype=int)
    print(f'selected {len(landmark_indices)} landmarks')

    #step 2: reduce to landmark space (fancy indexing copies, the input plan
    #is left untouched)
    gamma_landmarks = gamma_transport[:, landmark_indices]
    D_st_landmarks = D_st[np.ix_(landmark_indices, landmark_indices)]

    #renormalize transport plan row-wise; rows without mass become uniform
    row_sums = gamma_landmarks.sum(axis=1, keepdims=True)
    has_mass = row_sums > 0
    gamma_landmarks = np.where(
        has_mass,
        gamma_landmarks / np.where(has_mass, row_sums, 1.0),
        1.0 / gamma_landmarks.shape[1],
    )

    #step 3: compute pairwise wasserstein distances
    #normalize ground cost for numerical stability
    D_st_max = np.max(D_st_landmarks)
    if D_st_max > 0:
        D_st_norm = D_st_landmarks / D_st_max
    else:
        D_st_norm = D_st_landmarks

    # Ensure arrays are C-contiguous for POT library
    D_st_norm = np.ascontiguousarray(D_st_norm)
    gamma_landmarks = np.ascontiguousarray(gamma_landmarks)

    print("Computing pairwise Wasserstein distances with TRUE GPU batching...")

    import torch

    # use the GPU when there is one instead of failing outright on CPU-only hosts
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Move to the compute device once
    gamma_gpu = torch.tensor(gamma_landmarks, device=device, dtype=torch.float32)
    D_st_gpu = torch.tensor(D_st_norm, device=device, dtype=torch.float32)

    # Generate all pairs (i < j) at once
    n_cells = gamma_landmarks.shape[0]
    idx_i, idx_j = torch.triu_indices(n_cells, n_cells, offset=1, device=device)
    total_pairs = idx_i.numel()

    print(f"Processing {total_pairs} pairs in batches...")

    batch_size = 50000
    D_spatial = torch.zeros((n_cells, n_cells), device=device, dtype=torch.float32)

    for start in tqdm(range(0, total_pairs, batch_size), desc="True batch processing"):
        end = min(start + batch_size, total_pairs)
        
        # Get batch indices  
        batch_i = idx_i[start:end]
        batch_j = idx_j[start:end]
        
        # One row of the reduced plan per pair member: (batch_size, n_landmarks)
        a_batch = gamma_gpu[batch_i]
        b_batch = gamma_gpu[batch_j]
        
        # NOTE(review): POT documents `a` as a 1-D histogram (only `b` may
        # carry several histograms, as columns). Passing one row per pair for
        # both arguments may not compute per-pair costs -- confirm against the
        # installed POT version.
        try:
            dists = ot.sinkhorn2(
                a_batch,      # (batch_size, n_landmarks) 
                b_batch,      # (batch_size, n_landmarks)
                D_st_gpu,     # (n_landmarks, n_landmarks)
                reg=0.01, 
                numItermax=100
            )
            
            # Handle POT return format
            if isinstance(dists, (list, tuple)):
                dists = dists[0]
            
            # Restore scale
            dists = dists * D_st_max
            
            # Fill both triangles of the symmetric matrix
            D_spatial[batch_i, batch_j] = dists
            D_spatial[batch_j, batch_i] = dists
            
        except Exception as e:
            print(f"Batch failed, falling back to smaller chunks: {e}")
            # Fallback: split this batch into smaller pieces
            mini_batch_size = 1000
            for mini_start in range(start, end, mini_batch_size):
                mini_end = min(mini_start + mini_batch_size, end)
                mini_i = idx_i[mini_start:mini_end] 
                mini_j = idx_j[mini_start:mini_end]
                
                a_mini = gamma_gpu[mini_i]
                b_mini = gamma_gpu[mini_j]
                
                mini_dists = ot.sinkhorn2(a_mini, b_mini, D_st_gpu, reg=0.01, numItermax=100)
                if isinstance(mini_dists, (list, tuple)):
                    mini_dists = mini_dists[0]
                
                mini_dists = mini_dists * D_st_max
                D_spatial[mini_i, mini_j] = mini_dists
                D_spatial[mini_j, mini_i] = mini_dists

    # Convert back to numpy
    D_spatial = D_spatial.cpu().numpy().astype(np.float32)

    print(f"GPU batch processing complete, Range: [{D_spatial.min():.6f}, {D_spatial.max():.6f}]")

    # the original fell off the end here and handed callers None
    return D_spatial


def spaotsc_spatial_distance_matrix(X_sc, X_st, D_sc, D_st, Y_st,
                                  alpha=0.1, epsilon_stage1=0.1, epsilon_stage2=0.01,
                                  rho=100.0, feature_metric='euclidean',
                                  marginal_method='exponential', max_iter_stage1=50, k_neighbors=50,
                                  use_landmarks=False, n_landmarks=500,
                                  verbose=True):
    """
    Main function implementing SpaOTsc's two-stage approach for spatial distance matrix.
    
    Stage 1: Structured & Unbalanced Optimal Transport
    Stage 2: Pairwise Wasserstein Distance Computation
    
    Args:
        X_sc: SC gene expression (n_cells, n_genes) 
        X_st: ST gene expression (m_spots, n_genes)
        D_sc: SC distance matrix from k-NN graph (n_cells, n_cells)
        D_st: ST spatial distance matrix (m_spots, m_spots); numpy array or torch tensor
        Y_st: ST spatial coordinates, one row per spot; numpy array or torch tensor
        alpha: Structure vs feature weight (0=pure feature, 1=pure structure)
        epsilon_stage1: Entropic regularization for Stage 1
        epsilon_stage2: Entropic regularization for Stage 2
            (accepted for API compatibility; not referenced in this function body)
        rho: Unbalanced transport penalty (np.inf for balanced)
        feature_metric: Distance metric for gene expression ('euclidean', 'cosine', 'correlation')
        marginal_method: How to compute marginal weights ('uniform', 'exponential', 'softmax')
        max_iter_stage1: Maximum iterations for structured OT
        k_neighbors: Accepted for API compatibility; not referenced in this function body
        use_landmarks: Accepted for API compatibility; not referenced in this function body
        n_landmarks: Accepted for API compatibility; the landmark count actually used is
            min(300, m_spots // 3) (see Stage 2 below)
        verbose: Print progress information and the spatial-spread diagnostic
    
    Returns:
        D_induced_spaotsc: Spatial distance matrix for SC cells (n_cells, n_cells)
        gamma_transport: Transport plan from Stage 1 (n_cells, m_spots)
    """
    # Imported locally so the diagnostic below does not depend on another cell
    # having brought `pdist` into the notebook namespace.
    from scipy.spatial.distance import pdist

    # Convert coordinates to numpy once, up front: both the diagnostic and Stage 2
    # need a host array, and pdist cannot consume a CUDA tensor.
    if isinstance(Y_st, torch.Tensor):
        Y_st_np = Y_st.detach().cpu().numpy()
    else:
        Y_st_np = np.asarray(Y_st)

    if verbose:
        print("=== SpaOTsc Spatial Distance Matrix Computation ===")
        print(f"SC data: {X_sc.shape}, ST data: {X_st.shape}")
        print(f"Alpha (structure weight): {alpha}")
        print(f"Feature metric: {feature_metric}")
        print(f"Marginal method: {marginal_method}")
    
    # === STAGE 1: STRUCTURED & UNBALANCED OPTIMAL TRANSPORT ===
    if verbose:
        print("\n--- Stage 1: Structured Optimal Transport ---")
    
    # Step 1.1: Compute feature cost matrix M
    if verbose:
        print("Computing feature cost matrix...")
    M = compute_feature_cost_matrix(X_sc, X_st, metric=feature_metric)
    
    # Step 1.2: Compute marginal weights
    if verbose:
        print("Computing marginal weights...")
    w_a, w_b = compute_marginal_weights(M, method=marginal_method)
    
    # Step 1.3: Solve structured optimal transport
    if verbose:
        print("Solving structured optimal transport...")
    gamma_transport = structured_optimal_transport(
        w_a, w_b, M, D_sc, D_st, 
        alpha=alpha, epsilon=epsilon_stage1, rho=rho, max_iter=max_iter_stage1
    )
    
    if verbose:
        print(f"Transport plan shape: {gamma_transport.shape}")
        print(f"Transport plan mass: {np.sum(gamma_transport):.6f}")

        # Sanity diagnostic: are the spots receiving the most mass from one cell
        # spatially tighter than a random draw of spots?
        cell_idx = 0  # Pick any cell
        spot_weights = gamma_transport[cell_idx, :]
        top_spots = np.argsort(spot_weights)[-10:]  # Top 10 spots

        # pdist needs at least two points to produce any pairwise distance
        if len(top_spots) >= 2:
            spatial_spread = np.std(pdist(Y_st_np[top_spots]))
            print(f"Spatial spread of top spots: {spatial_spread}")

            # Compare to random (sampled with replacement, as before)
            random_spots = np.random.choice(len(Y_st_np), 10)
            random_spread = np.std(pdist(Y_st_np[random_spots]))
            print(f"Random spread: {random_spread}")
        
    # === STAGE 2: PAIRWISE WASSERSTEIN DISTANCES ===
    if verbose:
        print("\n--- Stage 2: Pairwise Wasserstein Distances ---")
    
    # Step 2.1: FAST k-NN + Sparse OT using FAISS
    # D_induced_spaotsc = knn_sparse_ot_faiss(
    #     gamma_transport, D_st,
    #     k_neighbors=50,
    #     batch_size=1024,
    #     epsilon=epsilon_stage2,
    #     device='cuda'
    # )

    # NOTE: the landmark count is derived from the number of ST spots here;
    # the `n_landmarks` argument is not consulted.
    D_induced_spaotsc = proper_spaotsc_landmarks(
        gamma_transport, 
        D_st if isinstance(D_st, np.ndarray) else D_st.cpu().numpy(),
        Y_st_np, 
        n_landmarks=min(300, X_st.shape[0] // 3) 
    )
    
    if verbose:
        print(f"Final spatial distance matrix shape: {D_induced_spaotsc.shape}")
        print(f"Distance range: [{np.min(D_induced_spaotsc):.6f}, {np.max(D_induced_spaotsc):.6f}]")
        print(f"Mean distance: {np.mean(D_induced_spaotsc):.6f}")
        print("=== SpaOTsc computation complete! ===")
    
    return D_induced_spaotsc, gamma_transport

# Example usage function to replace your current approach
def replace_fused_gw_with_spaotsc(X_sc, X_st, Y_st, k_neighbors=10, alpha=0.9, device='cuda'):
    """
    Drop-in replacement for your current fused_gw_torch function using SpaOTsc approach.
    
    Args:
        X_sc: SC gene expression (torch tensor or numpy array)
        X_st: ST gene expression (torch tensor or numpy array)
        Y_st: ST spatial coordinates (torch tensor or numpy array)
        k_neighbors: Number of neighbors for the SC k-NN graph
            (the ST graph uses a fixed k_neighbors=50)
        alpha: Structure vs feature weight for SpaOTsc
        device: Device for computation; outputs are converted to torch tensors
            only when this is exactly 'cuda', otherwise numpy arrays are returned
    
    Returns:
        gamma_transport: Transport plan (n_cells, m_spots)
        D_sc: SC distance matrix
        D_st: ST spatial distance matrix
        D_induced_spaotsc: SpaOTsc spatial distance matrix
        spaotsc_quality: Quality metric (transport plan mass)
    """

    def _to_numpy(value):
        # detach() so tensors that carry gradients can still be exported
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    # Convert tensors to numpy
    X_sc_np = _to_numpy(X_sc)
    X_st_np = _to_numpy(X_st)
    Y_st_np = _to_numpy(Y_st)
    
    # Compute distance matrices (keeping your existing functions)
    print('Calculating SC distances with k-NN Dijkstra...')
    
    # D_sc_tensor = calculate_D_sc_torch(torch.tensor(X_sc_np), k_neighbors=k_neighbors, device=device)
    # D_sc = D_sc_tensor.cpu().numpy()

    D_sc_tensor, sc_max_distance = calculate_D_sc_torch(torch.tensor(X_sc_np), k_neighbors=k_neighbors, device=device)
    D_sc = D_sc_tensor.cpu().numpy()
    
    print('Calculating ST distances...')
    D_st, st_max_distance = calculate_D_st_from_coords(
        spatial_coords=Y_st_np, k_neighbors=50, graph_mode="distance"
    )
    
    # Apply SpaOTsc method.
    # Pass the numpy coordinates (not the raw input): everything else handed to
    # the callee is numpy, and its diagnostics operate on host arrays.
    print('Applying SpaOTsc spatial distance computation...')
    D_induced_spaotsc, gamma_transport = spaotsc_spatial_distance_matrix(
        X_sc_np, X_st_np, D_sc, D_st, Y_st=Y_st_np, k_neighbors=k_neighbors,
        alpha=alpha,  # Higher alpha = more emphasis on spatial structure
        epsilon_stage1=0.1,
        epsilon_stage2=0.1, 
        rho=50.0,
        feature_metric='euclidean',
        use_landmarks=True,  # Speed up computation
        n_landmarks=min(500, X_st_np.shape[0]),
        verbose=True
    )
    
    # Convert back to tensors if needed
    if device == 'cuda':
        D_sc = torch.tensor(D_sc, dtype=torch.float32, device=device)
        D_st = torch.tensor(D_st, dtype=torch.float32, device=device)
        D_induced_spaotsc = torch.tensor(D_induced_spaotsc, dtype=torch.float32, device=device)
        gamma_transport = torch.tensor(gamma_transport, dtype=torch.float32, device=device)
    
    # Quality metric (transport plan mass conservation)
    spaotsc_quality = np.sum(gamma_transport) if isinstance(gamma_transport, np.ndarray) else torch.sum(gamma_transport).item()
    
    print(f'SpaOTsc quality (transport mass): {spaotsc_quality:.6f}')
    
    return gamma_transport, D_sc, D_st, D_induced_spaotsc, spaotsc_quality


In [None]:
def quick_validation(D_spatial):
    """Quick validation of distance matrix properties.

    Checks symmetry exactly and the triangle inequality on randomly sampled
    triples, printing a short report along the way.

    Args:
        D_spatial: Square numpy distance matrix.

    Returns:
        Truthy when the symmetry error is below 1e-6 and no sampled triple
        violates the triangle inequality; falsy otherwise.
    """
    print("🧪 Quick Validation:")
    
    n_points = D_spatial.shape[0]

    # Symmetry
    symmetry_error = np.max(np.abs(D_spatial - D_spatial.T))
    print(f"   Symmetry error: {symmetry_error:.2e} (should be ~0)")
    
    # Triangle inequality (sample).
    # A triple needs three distinct points; drawing 3 without replacement from
    # fewer than 3 raises ValueError, so the check is skipped for tiny matrices.
    violations = 0
    n_check = min(1000, n_points) if n_points >= 3 else 0
    for _ in range(n_check):
        i, j, k = np.random.choice(n_points, 3, replace=False)
        # 1e-6 slack absorbs floating-point noise on degenerate (collinear) triples
        if D_spatial[i,k] > D_spatial[i,j] + D_spatial[j,k] + 1e-6:
            violations += 1
    
    print(f"   Triangle violations: {violations}/{n_check} (should be 0)")
    print(f"   Range: [{D_spatial.min():.6f}, {D_spatial.max():.6f}]")
    print(f"   Mean: {D_spatial.mean():.6f}")
    
    is_valid = symmetry_error < 1e-6 and violations == 0
    print(f"   Overall: {'✅ VALID' if is_valid else '❌ ISSUES DETECTED'}")
    
    return is_valid

def compare_methods(gamma_transport, D_st, Y_st_coords, subset_size=1000):
    """Compare your geodesic vs proper SpaOTsc on subset.

    Draws a random subset of rows from the transport plan, computes the cell-cell
    distance matrix with both methods, and reports timing, correlation and symmetry.

    Args:
        gamma_transport: Transport plan indexed by cell along axis 0.
        D_st: ST distance matrix forwarded to `proper_spaotsc_landmarks`.
        Y_st_coords: ST coordinates forwarded to both methods.
        subset_size: Number of cells to sample; clamped to the number of
            available cells.

    Returns:
        (correlation, speed_ratio) where speed_ratio is SpaOTsc time divided by
        geodesic time.
    """
    # Local import so this cell does not rely on a later cell having imported `time`
    import time

    # Sampling without replacement cannot draw more rows than exist
    n_cells = gamma_transport.shape[0]
    subset_size = min(subset_size, n_cells)

    # Test on subset
    indices = np.random.choice(n_cells, subset_size, replace=False)
    gamma_subset = gamma_transport[indices]
    
    print(f"🔬 Comparing methods on {subset_size} cells...")
    
    # Your current method
    print("\n1️⃣ Your geodesic method:")
    start = time.time()
    D_geodesic = transport_aware_geodesic_distance(gamma_subset, Y_st_coords, sigma=3.0, k_neighbors=30)
    geodesic_time = time.time() - start
    print(f"   Time: {geodesic_time:.2f}s")
    
    # Proper SpaOTsc  
    print("\n2️⃣ Proper SpaOTsc:")
    start = time.time()
    D_proper = proper_spaotsc_landmarks(gamma_subset, D_st, Y_st_coords, n_landmarks=200)
    proper_time = time.time() - start
    print(f"   Time: {proper_time:.2f}s")
    
    # Correlation between the two flattened distance matrices
    correlation = np.corrcoef(D_geodesic.flatten(), D_proper.flatten())[0,1]
    print(f"\n📊 Correlation: {correlation:.4f}")
    print(f"   Speed ratio: {proper_time/geodesic_time:.2f}x")
    
    # Validation
    print(f"\n🧪 Validation:")
    geodesic_symmetry = np.max(np.abs(D_geodesic - D_geodesic.T))
    proper_symmetry = np.max(np.abs(D_proper - D_proper.T))
    print(f"   Symmetry - Geodesic: {geodesic_symmetry:.2e}, SpaOTsc: {proper_symmetry:.2e}")
    
    return correlation, proper_time/geodesic_time


In [None]:
def compute_D_induced_proper_scaling(T, D_st, n_sc, n_st):
    '''Compute D_induced with proper scaling.

    Args:
        T: 2-D transport plan tensor; rows index SC cells.
        D_st: ST distance matrix tensor, compatible with `T @ D_st`.
        n_sc: Factor the transport plan is multiplied by before projection.
        n_st: Unused; kept so existing callers keep working.

    Returns:
        `(T * n_sc) @ D_st @ (T * n_sc).T`, divided by its largest positive
        entry when that entry exceeds 1e-10. If there is no positive entry the
        unnormalized product is returned as is.
    '''
    # Reweight transport matrix
    T_reweight = T * n_sc
    D_induced_raw = T_reweight @ D_st @ T_reweight.t()

    # Normalize to [0,1] range.
    # torch.max on an empty selection raises, so handle the all-non-positive case first.
    positive_entries = D_induced_raw[D_induced_raw > 0]
    if positive_entries.numel() == 0:
        return D_induced_raw

    D_induced_max = torch.max(positive_entries)
    if D_induced_max > 1e-10:  # avoid division by very small numbers
        D_induced = D_induced_raw / D_induced_max
    else:
        D_induced = D_induced_raw

    return D_induced

def fused_gw_torch(X_sc, X_st, Y_st, alpha, k=100, G0=None, max_iter = 100, tol=1e-9, device='cuda', n_iter = 1, D_st_precomputed=None):
    """Align SC cells to ST spots with POT's fused Gromov-Wasserstein solver.

    Args:
        X_sc: SC expression tensor, one row per cell.
        X_st: ST expression tensor, one row per spot.
        Y_st: ST coordinates (numpy array or torch tensor); only used when
            `D_st_precomputed` is None.
        alpha: Trade-off weight forwarded to `ot.gromov.fused_gromov_wasserstein`.
        k: Number of neighbours for the SC k-NN graph.
        G0: Optional initial transport plan (torch tensor).
        max_iter: Maximum solver iterations per call.
        tol: Absolute tolerance forwarded as `tol_abs`.
        device: Torch device for tensors created here.
        n_iter: Number of solver calls; each call after the first is warm-started
            from the previous plan. Must be >= 1.
        D_st_precomputed: Optional ST distance tensor; when given it is used
            directly and `st_max_distance` is reported as 1.0.

    Returns:
        (T, D_sc, D_st, D_induced, fgw_dist, sc_max_distance, st_max_distance)

    Raises:
        ValueError: if `n_iter` is smaller than 1.
    """
    # Without at least one solver call there is no plan and no log to report
    if n_iter < 1:
        raise ValueError(f"n_iter must be >= 1, got {n_iter}")

    n_sc = X_sc.shape[0]
    n_st = X_st.shape[0]

    X_sc = X_sc.to(device)
    X_st = X_st.to(device)

    #calculate distance matrices
    print('calculating SC distances with knn-dijkstra.....')
    D_sc, sc_max_distance = calculate_D_sc_torch(X_sc, graph_mode='distance', k_neighbors=k, device=device)

    if D_st_precomputed is not None:
        print("Using precomputed block-diagonal D_st...")
        D_st = D_st_precomputed.to(device)
        st_max_distance = 1.0 #assume already normalized
    else:
        print('Calculating ST distances.....')
        # Hand the helper host-side numpy coordinates even when a tensor was supplied
        coords_np = Y_st.detach().cpu().numpy() if torch.is_tensor(Y_st) else Y_st
        D_st, st_max_distance = calculate_D_st_from_coords(spatial_coords=coords_np, k_neighbors=50, graph_mode="distance")
        D_st = torch.tensor(D_st, dtype=torch.float32, device=device)

    #get expression distance matrix
    C_exp = torch.cdist(X_sc, X_st, p=2) #euclidean distance
    C_exp = C_exp / (torch.max(C_exp) + 1e-16) #normalize

    #ensure distance matrices are C-contiguous numpy arrays for POT
    D_sc_np = np.ascontiguousarray(D_sc.cpu().numpy())
    D_st_np = np.ascontiguousarray(D_st.cpu().numpy())
    C_exp_np = np.ascontiguousarray(C_exp.cpu().numpy())

    #uniform distributions
    p = ot.unif(n_sc)
    q = ot.unif(n_st)

    #repeat the solve n_iter times, warm-starting each run from the previous plan
    initial_plan = G0.cpu().numpy() if G0 is not None else None
    T_np = None
    for i in range(n_iter):
        #run fused gw with POT
        T_np, log = ot.gromov.fused_gromov_wasserstein(
            M=C_exp_np, C1=D_sc_np, C2=D_st_np,
            p=p, q=q, loss_fun='square_loss',
            alpha=alpha,
            G0=T_np if T_np is not None else initial_plan,
            log=True,
            verbose=True,
            max_iter = max_iter,
            tol_abs=tol
        )

    fgw_dist = log['fgw_dist']

    print(f'fgw distance: {fgw_dist}')

    T = torch.tensor(T_np, dtype=torch.float32, device=device)

    D_induced = compute_D_induced_proper_scaling(T, D_st, n_sc, n_st)

    return T, D_sc, D_st, D_induced, fgw_dist, sc_max_distance, st_max_distance

# patient 2 data load

In [None]:
def load_and_process_cscc_data(data_dir='/home/ehtesamul/sc_st/data/cSCC/processed'):
    """
    Load and process the cSCC dataset with multiple ST replicates.

    Args:
        data_dir: Directory containing `scP2.h5ad`, `stP2.h5ad`, `stP2rep2.h5ad`
            and `stP2rep3.h5ad`. Defaults to the original hard-coded location.

    Returns:
        (scadata, stadata1, stadata2, stadata3): the four AnnData objects after
        `normalize_total` + `log1p`; `scadata.obs` additionally gets a
        `rough_celltype` column.
    """
    print("Loading cSCC data...")
    
    # Load SC data
    scadata = sc.read_h5ad(os.path.join(data_dir, 'scP2.h5ad'))
    
    # Load all 3 ST datasets
    stadata1 = sc.read_h5ad(os.path.join(data_dir, 'stP2.h5ad'))
    stadata2 = sc.read_h5ad(os.path.join(data_dir, 'stP2rep2.h5ad'))
    stadata3 = sc.read_h5ad(os.path.join(data_dir, 'stP2rep3.h5ad'))
    
    # Normalize and log transform
    for adata in [scadata, stadata1, stadata2, stadata3]:
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)
    
    # Create rough cell types for SC data: start from the level-1 label ...
    scadata.obs['rough_celltype'] = scadata.obs['level1_celltype'].astype(str)

    # ... collapse selected level-1 labels into coarser groups ...
    level1_to_rough = {
        'CLEC9A': 'DC',
        'CD1C': 'DC',
        'ASDC': 'DC',
        'PDC': 'PDC',
        'MDSC': 'DC',
        'LC': 'DC',
        'Mac': 'Myeloid cell',
        'Tcell': 'T cell',
    }
    for level1_label, rough_label in level1_to_rough.items():
        scadata.obs.loc[scadata.obs['level1_celltype']==level1_label,'rough_celltype'] = rough_label

    # ... then let level-2 tumour labels override (applied last so they win)
    scadata.obs.loc[scadata.obs['level2_celltype']=='TSK','rough_celltype'] = 'TSK'
    scadata.obs.loc[scadata.obs['level2_celltype'].isin(['Tumor_KC_Basal', 'Tumor_KC_Diff','Tumor_KC_Cyc']),'rough_celltype'] = 'NonTSK'
    
    return scadata, stadata1, stadata2, stadata3

def prepare_combined_st_for_diffusion(stadata1, stadata2, stadata3, scadata):
    """
    Stack the three ST replicates into one training set aligned to the SC genes.

    Genes are restricted to those present in every dataset (sorted by name).
    The stacked ST expression and the SC expression are each standardized with
    their own `StandardScaler` fit.

    Returns:
        X_sc: SC expression as a float32 tensor.
        X_st_combined: Stacked ST expression as a float32 tensor.
        Y_st_combined: Stacked ST coordinates as a float32 numpy array.
        dataset_labels: One 'dataset1'/'dataset2'/'dataset3' label per ST row.
        common_genes: Sorted list of shared gene names.
        st_coords_list: Per-replicate coordinate arrays, in replicate order.
    """
    print("Preparing combined ST data for diffusion training...")

    st_adatas = [stadata1, stadata2, stadata3]

    # Intersect gene names across SC and every ST replicate
    shared_genes = set(scadata.var_names)
    for st_adata in st_adatas:
        shared_genes = shared_genes & set(st_adata.var_names)
    common_genes = sorted(list(shared_genes))
    print(f"Common genes across all datasets: {len(common_genes)}")

    def _densify(matrix):
        # Sparse matrices expose toarray(); dense arrays pass through untouched
        if hasattr(matrix, 'toarray'):
            return matrix.toarray()
        return matrix

    # Gene-aligned expression, dense
    sc_expr = _densify(scadata[:, common_genes].X)
    st_exprs = [_densify(st_adata[:, common_genes].X) for st_adata in st_adatas]

    # Per-replicate coordinates, kept separate for the block-diagonal graph
    st_coords_list = [st_adata.obsm['spatial'] for st_adata in st_adatas]

    # Stack replicates and standardize; fit_transform refits for the SC matrix
    from sklearn.preprocessing import StandardScaler
    scaler = StandardScaler()
    st_expr_combined = scaler.fit_transform(np.vstack(st_exprs))
    st_coords_combined = np.vstack(st_coords_list)
    sc_expr = scaler.fit_transform(sc_expr)

    # One label per ST row so rows can be traced back to their replicate
    dataset_labels = []
    for replicate_label, replicate_expr in zip(('dataset1', 'dataset2', 'dataset3'), st_exprs):
        dataset_labels.extend([replicate_label] * len(replicate_expr))

    print(f"Combined ST data shape: {st_expr_combined.shape}")
    print(f"Combined ST coords shape: {st_coords_combined.shape}")
    print(f"SC data shape: {sc_expr.shape}")

    # Expression goes to torch; coordinates stay numpy (float32)
    X_sc = torch.tensor(sc_expr, dtype=torch.float32)
    X_st_combined = torch.tensor(st_expr_combined, dtype=torch.float32)
    Y_st_combined = st_coords_combined.astype(np.float32)

    return X_sc, X_st_combined, Y_st_combined, dataset_labels, common_genes, st_coords_list

# Read the SC dataset and the three ST replicates (normalized + log1p)
scadata, stadata1, stadata2, stadata3 = load_and_process_cscc_data()

# Build the gene-aligned, standardized inputs used by the diffusion model
(
    X_sc,
    X_st_combined,
    Y_st_combined,
    dataset_labels,
    common_genes,
    st_coords_list,
) = prepare_combined_st_for_diffusion(stadata1, stadata2, stadata3, scadata)

# Short summary of what was prepared
print(
    "Data preparation complete!",
    f"SC cells: {X_sc.shape[0]}",
    f"Combined ST spots: {X_st_combined.shape[0]}",
    f"Common genes: {len(common_genes)}",
    sep="\n",
)



# diffusion model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import numpy as np
from tqdm import tqdm
import os
import time
import scipy
from sklearn.neighbors import NearestNeighbors
from scipy.spatial import cKDTree
from typing import Optional, Dict, Tuple, List

# =====================================================
# PART 1: Advanced Network Components
# =====================================================

class FeatureNet(nn.Module):
    """Three-layer MLP encoder: Linear -> LayerNorm -> ReLU (x2), then a final Linear.

    Args:
        n_genes: Size of the input feature vector.
        n_embedding: Sequence whose first three entries are the widths of the
            two hidden layers and of the output layer.
        dp: Dropout probability for the optional input dropout.
    """

    def __init__(self, n_genes, n_embedding=[512, 256, 128], dp=0):
        super().__init__()

        hidden1_width = n_embedding[0]
        hidden2_width = n_embedding[1]
        output_width = n_embedding[2]

        # Attribute names (and creation order) are kept so checkpoints and
        # seeded initialization stay compatible. bn1/bn2 are LayerNorm layers.
        self.fc1 = nn.Linear(n_genes, hidden1_width)
        self.bn1 = nn.LayerNorm(hidden1_width)
        self.fc2 = nn.Linear(hidden1_width, hidden2_width)
        self.bn2 = nn.LayerNorm(hidden2_width)
        self.fc3 = nn.Linear(hidden2_width, output_width)

        self.dp = nn.Dropout(dp)

    def forward(self, x, isdp=False):
        """Encode `x`; when `isdp` is true, dropout is applied to the input first."""
        hidden = self.dp(x) if isdp else x

        hidden = self.fc1(hidden)
        hidden = F.relu(self.bn1(hidden))

        hidden = self.fc2(hidden)
        hidden = F.relu(self.bn2(hidden))

        return self.fc3(hidden)

class SinusoidalEmbedding(nn.Module):
    """Sinusoidal embedding of a scalar per element (e.g. a diffusion timestep).

    The output concatenates `dim // 2` sine features and `dim // 2` cosine
    features along a new last axis, so for odd `dim` the output width is
    `dim - 1`.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return the embedding of `x` with shape `x.shape + (2 * (dim // 2),)`."""
        half_dim = self.dim // 2
        # With dim of 2 or 3, half_dim - 1 is 0; dividing by it produced an
        # infinite scale and NaN features. Clamp the denominator to 1 so the
        # single frequency becomes exp(0) = 1.
        log_scale = np.log(10000) / max(half_dim - 1, 1)
        # Geometric frequency ladder from 1 down to 1/10000
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -log_scale)
        angles = x.unsqueeze(-1) * freqs.unsqueeze(0)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

# class GeometricAttentionBlock(nn.Module):
#     """Attention mechanism that respects spatial relationships"""
#     def __init__(self, dim, num_heads=8, temperature=1.0):
#         super().__init__()
#         self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
#         self.norm = nn.LayerNorm(dim)
#         self.temperature = temperature
        
#     def forward(self, x, coords):
#         # x: (batch, dim)
#         # coords: (batch, 2)

#         batch_size = x.shape[0]
        
#         # Compute pairwise distances for geometric bias
#         distances = torch.cdist(coords, coords, p=2)
#         geometric_bias = -distances / self.temperature
        
#         # Apply attention with geometric bias

#         #reshape for attention: treat each cell as a sequence element
#         x_norm = self.norm(x).unsqueeze(0) #(1, batch_dim)

#         #expand mask for all heads: (num_heads, batch, batch)
#         geometric_bias = geometric_bias.unsqueeze(0).expand(
#             self.attention.num_heads, -1, -1
#         ).contiguous().view(self.attention.num_heads * batch_size, batch_size)

#         #apply attention with geometric bias
#         attended, _ = self.attention(
#             x_norm, #(1, batch, dim)
#             x_norm, 
#             x_norm,
#             attn_mask=geometric_bias #(num_heads * batch, batch)
#         )
        
#         return x + attended.squeeze(1)
    
class GeometricAttentionBlock(nn.Module):
    """Residual self-attention across the batch with a distance-based bias.

    Every row of the batch is treated as one token of a single sequence; the
    negative pairwise coordinate distance (divided by `temperature`) is passed
    to the attention layer as an additive mask.
    """

    def __init__(self, dim, num_heads=8, temperature=1.0):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.temperature = temperature

    def forward(self, x, coords):
        """
        Args:
            x: (batch_size, dim) features.
            coords: (batch_size, n_coord_dims) positions used for the bias.

        Returns:
            (batch_size, dim) tensor: `x` plus the attended features.
        """
        # Additive bias of shape (batch_size, batch_size): far apart -> more negative
        pairwise_dist = torch.cdist(coords, coords, p=2)
        distance_bias = -pairwise_dist / self.temperature

        # One pseudo-sequence holding the whole batch: (1, batch_size, dim)
        tokens = self.norm(x).unsqueeze(0)

        attn_out, _ = self.attention(tokens, tokens, tokens, attn_mask=distance_bias)

        # Drop the leading singleton axis and add the residual connection
        return x + attn_out.squeeze(0)

class CellTypeEmbedding(nn.Module):
    """Lookup table mapping integer cell-type ids to learned vectors."""

    def __init__(self, num_cell_types, embedding_dim):
        super().__init__()
        # One trainable row of size `embedding_dim` per cell type
        self.embedding = nn.Embedding(
            num_embeddings=num_cell_types,
            embedding_dim=embedding_dim,
        )

    def forward(self, cell_type_indices):
        """Return the embedding rows selected by `cell_type_indices`."""
        type_vectors = self.embedding(cell_type_indices)
        return type_vectors

class UncertaintyHead(nn.Module):
    """MLP head producing two strictly positive uncertainty values per input."""

    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        mlp_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),  # one value each for x and y
        ]
        self.net = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Return softplus(MLP(x)) shifted by 0.01 so the output never reaches zero."""
        raw_scores = self.net(x)
        positive_scores = F.softplus(raw_scores)
        return positive_scores + 0.01

class PhysicsInformedLayer(nn.Module):
    """Incorporates cell non-overlap constraints.

    Predicts a radius per cell from its features and computes a repulsion
    vector that pushes overlapping cells apart.
    """

    def __init__(self, feature_dim):
        super().__init__()
        # Softplus keeps the predicted radius positive
        self.radius_predictor = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Softplus()
        )
        # Learnable global scale applied to the repulsion vectors in forward()
        self.repulsion_strength = nn.Parameter(torch.tensor(0.1))
        
    def compute_repulsion_gradient(self, coords, radii, cell_types=None):
        """Compute repulsion forces between cells.

        Args:
            coords: (B, n_coord_dims) cell positions.
            radii: One radius per cell, shaped (B,) or (B, 1).
            cell_types: Optional (B,) labels; pairs sharing a label repel 50% harder.

        Returns:
            (B, n_coord_dims) tensor: for each cell, the sum over all other cells
            of overlap * weight * unit vector pointing away from that cell.
        """
        n_cells = coords.shape[0]
        
        # Compute pairwise distances
        distances = torch.cdist(coords, coords, p=2)
        
        # Pairwise sum of radii, r_i + r_j, as a (B, B) matrix.
        # forward() supplies 1-D radii, for which `radii + radii.T` is just
        # 2 * radii (transposing a 1-D tensor is a no-op), so build the outer
        # sum explicitly; reshape(-1) also accepts a (B, 1) column.
        radii_flat = radii.reshape(-1)
        radii_sum = radii_flat.unsqueeze(1) + radii_flat.unsqueeze(0)
        
        # Compute overlap (positive when cells overlap)
        overlap = F.relu(radii_sum - distances + 1e-6)
        
        # Mask out self-interactions
        mask = (1 - torch.eye(n_cells, device=coords.device))
        overlap = overlap * mask
        
        # Direction from cell j towards cell i, for every pair
        coord_diff = coords.unsqueeze(1) - coords.unsqueeze(0)  # (B, B, n_coord_dims)
        distances_safe = distances + 1e-6  # Avoid division by zero
        
        # Normalize direction vectors
        directions = coord_diff / distances_safe.unsqueeze(-1)
        
        # Apply stronger repulsion for same cell types (optional)
        if cell_types is not None:
            same_type_mask = (cell_types.unsqueeze(1) == cell_types.unsqueeze(0)).float()
            repulsion_weight = 1.0 + 0.5 * same_type_mask  # 50% stronger for same type
        else:
            repulsion_weight = torch.ones(n_cells, n_cells, device=coords.device)
            
        # Compute repulsion magnitude
        repulsion_magnitude = overlap.unsqueeze(-1) * repulsion_weight.unsqueeze(-1)
        
        # Sum repulsion forces from all other cells
        repulsion_forces = (repulsion_magnitude * directions * mask.unsqueeze(-1)).sum(dim=1)
        
        return repulsion_forces
        
    def forward(self, coords, features, cell_types=None):
        """Return `(repulsion * repulsion_strength, radii)` for the given cells.

        Args:
            coords: (B, n_coord_dims) cell positions.
            features: (B, feature_dim) per-cell features fed to the radius predictor.
            cell_types: Optional (B,) labels forwarded to the repulsion computation.
        """
        # Predict cell radii based on features
        radii = self.radius_predictor(features).squeeze(-1) * 0.01  # Scale to reasonable size
        
        # Compute repulsion gradient
        repulsion_grad = self.compute_repulsion_gradient(coords, radii, cell_types)
        
        return repulsion_grad * self.repulsion_strength, radii

# =====================================================
# PART 2: Hierarchical Diffusion Architecture
# =====================================================

class HierarchicalDiffusionBlock(nn.Module):
    """Blend of a coarse and a fine prediction, weighted by the timestep.

    A small network maps the timestep to `num_scales` softmax weights; the
    first weight scales the coarse prediction and the second scales the fine
    prediction. Any further weights are not used in the output.
    """

    def __init__(self, dim, num_scales=3):
        super().__init__()
        self.num_scales = num_scales
        wide = dim * 2

        # Coarse branch: sees only the input features
        self.coarse_net = nn.Sequential(
            nn.Linear(dim, wide),
            nn.ReLU(),
            nn.Linear(wide, dim),
        )

        # Fine branch: sees the input concatenated with a coarse signal
        self.fine_net = nn.Sequential(
            nn.Linear(wide, wide),
            nn.ReLU(),
            nn.Linear(wide, dim),
        )

        # Timestep -> per-scale mixing weights (softmax over scales)
        self.scale_mixer = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, num_scales),
            nn.Softmax(dim=-1),
        )

    def forward(self, x, t, coarse_context=None):
        """
        Args:
            x: (B, dim) features.
            t: (B,) timesteps; expanded to (B, 1) before the mixer.
            coarse_context: Optional (B, dim) tensor used to condition the fine
                branch instead of this block's own coarse prediction.

        Returns:
            (B, dim) weighted sum of the coarse and fine predictions.
        """
        mix = self.scale_mixer(t.unsqueeze(-1))
        coarse_weight = mix[:, 0:1]
        fine_weight = mix[:, 1:2]

        coarse_pred = self.coarse_net(x)

        # Condition the fine branch on external context when given,
        # otherwise on the coarse prediction computed above
        conditioning = coarse_pred if coarse_context is None else coarse_context
        fine_pred = self.fine_net(torch.cat([x, conditioning], dim=-1))

        return coarse_weight * coarse_pred + fine_weight * fine_pred

# =====================================================
# PART 3: Main Advanced Diffusion Model
# =====================================================

class AdvancedHierarchicalDiffusion(nn.Module):
    def __init__(
        self,
        st_gene_expr,
        st_coords,
        sc_gene_expr,
        cell_types_sc=None,  # Cell type labels for SC data
        transport_plan=None,  # Optimal transport plan from domain alignment
        D_st=None,
        D_induced=None,
        n_genes=None,
        # n_embedding=128,
        n_embedding=[512, 256, 128],
        coord_space_diameter=200,
        st_max_distance=None,
        sc_max_distance=None,
        sigma=3.0,
        alpha=0.9,
        mmdbatch=0.1,
        batch_size=64,
        device='cuda',
        lr_e=0.0001,
        lr_d=0.0002,
        n_timesteps=1000,
        n_denoising_blocks=6,
        hidden_dim=512,
        num_heads=8,
        num_hierarchical_scales=3,
        dp=0.1,
        outf='output'
    ):
        """Store the data, derive any missing distance matrices, and build all sub-networks.

        Args:
            st_gene_expr: ST expression matrix; converted with ``torch.tensor`` to float32.
                ``st_gene_expr.shape[1]`` is used as the gene count when ``n_genes`` is falsy.
            st_coords: ST spatial coordinates; converted to float32 and normalized
                isotropically (presumably 2-D, since the coordinate encoder takes 2 inputs).
            sc_gene_expr: SC expression matrix; converted to float32.
            cell_types_sc: Optional per-cell labels. Unique labels (``np.unique``) are mapped
                to integer indices and enable the cell-type embedding.
            transport_plan: Optional transport plan; when present an
                ``ot_guidance_strength`` parameter is created.
            D_st: Optional ST distance matrix; computed from ``st_coords`` when omitted.
            D_induced: Optional induced SC distance matrix; computed with
                ``fused_gw_torch`` only when both it and ``transport_plan`` are omitted.
            n_genes: Encoder input width; defaults to ``st_gene_expr.shape[1]``.
            n_embedding: Sequence passed to the feature encoder; its last entry is used as
                the encoded feature width.
            coord_space_diameter: Accepted but not read in this constructor.
            st_max_distance: Not read as an input here; ``self.st_max_distance`` is only
                set (from the computed value) when ``D_st`` has to be calculated.
            sc_max_distance: Accepted but not read in this constructor.
            sigma, alpha, mmdbatch, batch_size: Stored on the instance for the training methods.
            device: Device that data tensors and the module are moved to.
            lr_e, lr_d: Forwarded to ``setup_optimizers``.
            n_timesteps: Number of diffusion steps in the noise schedule.
            n_denoising_blocks: Number of hierarchical blocks; half as many (integer
                division) geometric attention blocks are created.
            hidden_dim: Width of the time/coordinate/feature embeddings and denoiser.
            num_heads: Passed to each ``GeometricAttentionBlock``.
            num_hierarchical_scales: Passed to each ``HierarchicalDiffusionBlock``.
            dp: Passed to the feature encoder (presumably a dropout rate).
            outf: Output directory for logs and checkpoints; created if missing.
        """
        super().__init__()
        
        self.device = device
        self.batch_size = batch_size
        self.n_timesteps = n_timesteps
        self.sigma = sigma
        self.alpha = alpha
        self.mmdbatch = mmdbatch
        self.n_embedding = n_embedding
        
        # Create output directory
        self.outf = outf
        if not os.path.exists(outf):
            os.makedirs(outf)
        
        # Store data (plain tensor attributes, not registered buffers)
        self.st_gene_expr = torch.tensor(st_gene_expr, dtype=torch.float32).to(device)
        self.st_coords = torch.tensor(st_coords, dtype=torch.float32).to(device)
        self.sc_gene_expr = torch.tensor(sc_gene_expr, dtype=torch.float32).to(device)
        
        # Store transport plan if provided
        self.transport_plan = torch.tensor(transport_plan, dtype=torch.float32).to(device) if transport_plan is not None else None
        
        # Process cell types
        if cell_types_sc is not None:
            # Convert cell type strings to indices (index order follows np.unique's sorted order)
            unique_cell_types = np.unique(cell_types_sc)
            self.cell_type_to_idx = {ct: i for i, ct in enumerate(unique_cell_types)}
            self.num_cell_types = len(unique_cell_types)
            cell_type_indices = [self.cell_type_to_idx[ct] for ct in cell_types_sc]
            self.sc_cell_types = torch.tensor(cell_type_indices, dtype=torch.long).to(device)
        else:
            self.sc_cell_types = None
            self.num_cell_types = 0
            
        # Store distance matrices
        self.D_st = torch.tensor(D_st, dtype=torch.float32).to(device) if D_st is not None else None
        self.D_induced = torch.tensor(D_induced, dtype=torch.float32).to(device) if D_induced is not None else None

        # If D_st is not provided, calculate it from spatial coordinates
        if self.D_st is None:
            print("D_st not provided, calculating from spatial coordinates...")
            if isinstance(st_coords, torch.Tensor):
                st_coords_np = st_coords.cpu().numpy()
            else:
                st_coords_np = st_coords
            
            # The helper's second return value overwrites the st_max_distance argument
            D_st_np, st_max_distance = calculate_D_st_from_coords(
                spatial_coords=st_coords_np, 
                k_neighbors=50, 
                graph_mode="distance"
            )
            self.D_st = torch.tensor(D_st_np, dtype=torch.float32).to(device)
            # self.st_max_distance is only defined when this branch runs
            self.st_max_distance = st_max_distance
            print(f"D_st calculated, shape: {self.D_st.shape}")

        # If D_induced is not provided, calculate it using fused Gromov-Wasserstein
        # (skipped whenever a transport plan was supplied, even if D_induced is missing)
        if self.D_induced is None and transport_plan is None:
            print("D_induced not provided, calculating using Fused Gromov-Wasserstein...")
            try:
                # Calculate using fused GW if available
                # alpha and k are fixed here and independent of the constructor's `alpha`
                T_opt, D_sc, D_st_calc, D_induced_calc, _, sc_max_dist, st_max_dist = fused_gw_torch(
                    X_sc=self.sc_gene_expr,
                    X_st=self.st_gene_expr, 
                    Y_st=self.st_coords,
                    alpha=0.9,
                    k=100,
                    device=device
                )
                self.D_induced = D_induced_calc
                self.transport_plan = T_opt
                if self.D_st is None:  # Use the calculated D_st if we don't have one
                    self.D_st = D_st_calc
                print(f"D_induced calculated using FGW, shape: {self.D_induced.shape}")
            except Exception as e:
                # Best-effort: any failure falls back to a diagonal placeholder matrix
                print(f"FGW calculation failed: {e}")
                print("Computing simple D_induced approximation...")
                # Simple fallback: use identity matrix scaled by D_st
                n_sc = self.sc_gene_expr.shape[0]
                self.D_induced = torch.eye(n_sc, device=device) * self.D_st.mean()

        print(f"Final matrices - D_st: {self.D_st.shape if self.D_st is not None else None}, "
            f"D_induced: {self.D_induced.shape if self.D_induced is not None else None}")
        
        # Normalize coordinates (keeps the center and radius used for the normalization)
        self.st_coords_norm, self.coords_center, self.coords_radius = self.normalize_coordinates_isotropic(self.st_coords)
        
        # Model parameters
        self.n_genes = n_genes or st_gene_expr.shape[1]
        
        # ========== FEATURE ENCODER ==========
        self.netE = self.build_feature_encoder(self.n_genes, n_embedding, dp)

        # Path of the text log appended to during training
        self.train_log = os.path.join(outf, 'train.log')

        
        # ========== CELL TYPE EMBEDDING ==========

        use_cell_types = (cell_types_sc is not None)  # Check if SC data has cell types
        self.use_cell_types = use_cell_types

        # With cell types, the projected feature vector is the encoder output plus an
        # embedding half as wide as the encoder's last layer
        if self.num_cell_types > 0:
            self.cell_type_embedding = CellTypeEmbedding(self.num_cell_types, n_embedding[-1] // 2)
            total_feature_dim = n_embedding[-1] + n_embedding[-1] // 2
        else:
            self.cell_type_embedding = None
            total_feature_dim = n_embedding[-1]
            
        # ========== HIERARCHICAL DIFFUSION COMPONENTS ==========
        # Time embedding
        self.time_embed = nn.Sequential(
            SinusoidalEmbedding(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Coordinate encoder (expects 2 coordinate components per point)
        self.coord_encoder = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Feature projection (includes cell type if available)
        self.feat_proj = nn.Sequential(
            nn.Linear(total_feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # ========== HIERARCHICAL DENOISING BLOCKS ==========
        self.hierarchical_blocks = nn.ModuleList([
            HierarchicalDiffusionBlock(hidden_dim, num_hierarchical_scales)
            for _ in range(n_denoising_blocks)
        ])
        
        # ========== GEOMETRIC ATTENTION ==========
        # One attention block per two denoising blocks
        self.geometric_attention_blocks = nn.ModuleList([
            GeometricAttentionBlock(hidden_dim, num_heads)
            for _ in range(n_denoising_blocks // 2)
        ])
        
        # ========== PHYSICS-INFORMED COMPONENTS ==========
        self.physics_layer = PhysicsInformedLayer(hidden_dim)
        
        # ========== UNCERTAINTY QUANTIFICATION ==========
        self.uncertainty_head = UncertaintyHead(hidden_dim)
        
        # ========== OPTIMAL TRANSPORT GUIDANCE ==========
        # The learnable guidance scale only exists when a transport plan is available
        if self.transport_plan is not None:
            self.ot_guidance_strength = nn.Parameter(torch.tensor(0.1))
            
        # ========== OUTPUT LAYERS ==========
        # Maps hidden features to a 2-component noise prediction
        self.noise_predictor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2)
        )
        
        # Create noise schedule
        self.noise_schedule = self.create_noise_schedule()
        
        # Optimizers (must come after every sub-module above has been created)
        self.setup_optimizers(lr_e, lr_d)
        
        # MMD Loss for domain alignment
        self.mmd_loss = MMDLoss()

        # Move entire model to device
        self.to(self.device)
        
    def normalize_coordinates_isotropic(self, coords):
        """Normalize coordinates isotropically to [-1, 1]"""
        center = coords.mean(dim=0)
        centered_coords = coords - center
        max_dist = torch.max(torch.norm(centered_coords, dim=1))
        normalized_coords = centered_coords / (max_dist + 1e-8)
        return normalized_coords, center, max_dist
        

    def build_feature_encoder(self, n_genes, n_embedding, dp):
        """Instantiate the ``FeatureNet`` encoder and place it on ``self.device``.

        Args:
            n_genes: Passed positionally to ``FeatureNet`` (input gene count).
            n_embedding: Forwarded as ``FeatureNet``'s ``n_embedding`` argument.
            dp: Forwarded as ``FeatureNet``'s ``dp`` argument.

        Returns:
            The encoder module after ``.to(self.device)``.
        """
        encoder = FeatureNet(n_genes, n_embedding=n_embedding, dp=dp)
        return encoder.to(self.device)
        
    def create_noise_schedule(self):
        """Create the noise schedule for diffusion"""
        betas = torch.linspace(0.0001, 0.02, self.n_timesteps, device=self.device)
        alphas = 1 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        
        return {
            'betas': betas,
            'alphas': alphas,
            'alphas_cumprod': alphas_cumprod,
            'sqrt_alphas_cumprod': torch.sqrt(alphas_cumprod),
            'sqrt_one_minus_alphas_cumprod': torch.sqrt(1 - alphas_cumprod)
        }
        
    def setup_optimizers(self, lr_e, lr_d):
        """Create the encoder and diffusion optimizers with their LR schedulers.

        Args:
            lr_e: Accepted for interface compatibility but currently not used; the
                encoder optimizer is built with a fixed learning rate of 0.002.
            lr_d: Learning rate of the diffusion optimizer.

        Side effects:
            Sets ``optimizer_E``, ``scheduler_E``, ``mmd_fn``, ``optimizer_diff`` and
            ``scheduler_diff`` on the instance.
        """
        # Encoder: AdamW with the learning rate halved every 200 scheduler steps
        self.optimizer_E = torch.optim.AdamW(self.netE.parameters(), lr=0.002)
        self.scheduler_E = lr_scheduler.StepLR(self.optimizer_E, step_size=200, gamma=0.5)

        # MMD criterion used by the encoder training loop
        self.mmd_fn = MMDLoss()

        # Everything except the encoder is optimized by the diffusion optimizer.
        # The order below determines the parameter order inside the optimizer.
        diffusion_modules = [
            self.time_embed,
            self.coord_encoder,
            self.feat_proj,
            self.hierarchical_blocks,
            self.geometric_attention_blocks,
            self.physics_layer,
            self.uncertainty_head,
            self.noise_predictor,
        ]
        if self.cell_type_embedding is not None:
            diffusion_modules.append(self.cell_type_embedding)

        diff_params = [p for module in diffusion_modules for p in module.parameters()]

        # The OT guidance scale is a bare Parameter that only exists with a transport plan
        if self.transport_plan is not None:
            diff_params.append(self.ot_guidance_strength)

        self.optimizer_diff = torch.optim.Adam(diff_params, lr=lr_d, betas=(0.9, 0.999))
        self.scheduler_diff = lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer_diff, T_0=500)
        
    def add_noise(self, coords, t, noise_schedule):
        """Add noise to coordinates according to the diffusion schedule"""
        noise = torch.randn_like(coords)
        sqrt_alphas_cumprod_t = noise_schedule['sqrt_alphas_cumprod'][t].view(-1, 1)
        sqrt_one_minus_alphas_cumprod_t = noise_schedule['sqrt_one_minus_alphas_cumprod'][t].view(-1, 1)
        
        noisy_coords = sqrt_alphas_cumprod_t * coords + sqrt_one_minus_alphas_cumprod_t * noise
        return noisy_coords, noise
        
    def compute_ot_guidance(self, coords_sc, features_sc):
        """Compute guidance from optimal transport plan"""
        if self.transport_plan is None:
            return torch.zeros_like(coords_sc)
            
        # Compute target positions based on transport plan
        # T_star: (n_sc, n_st), st_coords_norm: (n_st, 2)
        target_positions = torch.matmul(self.transport_plan, self.st_coords_norm)
        
        # Compute attraction force towards target positions
        attraction = target_positions - coords_sc
        
        return attraction * self.ot_guidance_strength
        
    def forward_diffusion(self, noisy_coords, t, features, cell_types=None):
        """Forward pass through the advanced diffusion model"""
        batch_size = noisy_coords.shape[0]
        
        # Encode inputs
        time_emb = self.time_embed(t)
        coord_emb = self.coord_encoder(noisy_coords)
        
        # Process features with optional cell type
        if cell_types is not None and self.cell_type_embedding is not None:
            cell_type_emb = self.cell_type_embedding(cell_types)
            combined_features = torch.cat([features, cell_type_emb], dim=-1)
        else:
            #when no cell types, pad with zeros to match expected input size
            if self.cell_type_embedding is not None:
                #create zero padding for cell type embedding
                cell_type_dim = self.n_embedding[-1] // 2
                zero_padding = torch.zeros(batch_size, cell_type_dim, device=features.device)
                combined_features = torch.cat([features, zero_padding], dim=-1)
            else:
                combined_features = features
            # combined_features = features
            
        feat_emb = self.feat_proj(combined_features)
        
        # Combine embeddings
        h = coord_emb + time_emb + feat_emb
        
        # Process through hierarchical blocks with geometric attention
        for i, block in enumerate(self.hierarchical_blocks):
            h = block(h, t)
            
            # Apply geometric attention at certain layers
            if i % 2 == 0 and i // 2 < len(self.geometric_attention_blocks):
                h = self.geometric_attention_blocks[i // 2](h, noisy_coords)
                
        # Predict noise
        noise_pred = self.noise_predictor(h)
        
        # Compute physics-informed correction
        physics_correction, cell_radii = self.physics_layer(noisy_coords, h, cell_types)
        
        # Compute uncertainty
        uncertainty = self.uncertainty_head(h)
        
        # Apply corrections based on timestep (less physics at high noise)
        # t_factor = 1 - t / self.n_timesteps  # 0 at start, 1 at end
        # noise_pred = noise_pred + t_factor * physics_correction * 0.1
        t_factor = (1 - t).unsqueeze(-1) #shape: (natch_size, 1)
        noise_pred = noise_pred + t_factor * physics_correction * 0.1
        
        return noise_pred, uncertainty, cell_radii
        
    def train_encoder(self, n_epochs=1000, ratio_start=0, ratio_end=1.0):
        """Train the STEM encoder to align ST and SC data"""
        print("Training STEM encoder...")
        
        # Log training start
        with open(self.train_log, 'a') as f:
            localtime = time.asctime(time.localtime(time.time()))
            f.write(f"{localtime} - Starting STEM encoder training\n")
            f.write(f"n_epochs={n_epochs}, ratio_start={ratio_start}, ratio_end={ratio_end}\n")
        
        # Calculate spatial adjacency matrix
        if self.sigma == 0:
            nettrue = torch.eye(self.st_coords.shape[0], device=self.device)
        else:
            nettrue = torch.tensor(scipy.spatial.distance.cdist(
                self.st_coords.cpu().numpy(), 
                self.st_coords.cpu().numpy()
            ), device=self.device).to(torch.float32)
            
            sigma = self.sigma
            nettrue = torch.exp(-nettrue**2/(2*sigma**2))/(np.sqrt(2*np.pi)*sigma)
            nettrue = F.normalize(nettrue, p=1, dim=1)
        
        # Training loop
        for epoch in range(n_epochs):
            # Schedule for circle loss weight
            ratio = ratio_start + (ratio_end - ratio_start) * min(epoch / (n_epochs * 0.8), 1.0)
            
            # Forward pass ST data
            e_seq_st = self.netE(self.st_gene_expr, True)
            
            # Sample from SC data due to large size
            sc_idx = torch.randint(0, self.sc_gene_expr.shape[0], (min(self.batch_size, self.mmdbatch),), device=self.device)
            sc_batch = self.sc_gene_expr[sc_idx]
            e_seq_sc = self.netE(sc_batch, False)
            
            # Calculate losses
            self.optimizer_E.zero_grad()
            
            # Prediction loss (equivalent to netpred in STEM)
            netpred = e_seq_st.mm(e_seq_st.t())
            loss_E_pred = F.cross_entropy(netpred, nettrue, reduction='mean')
            
            # Mapping matrices
            st2sc = F.softmax(e_seq_st.mm(e_seq_sc.t()), dim=1)
            sc2st = F.softmax(e_seq_sc.mm(e_seq_st.t()), dim=1)
            
            # Circle loss
            st2st = torch.log(st2sc.mm(sc2st) + 1e-7)
            loss_E_circle = F.kl_div(st2st, nettrue, reduction='none').sum(1).mean()
            
            # MMD loss
            ranidx = torch.randint(0, e_seq_sc.shape[0], (min(self.mmdbatch, e_seq_sc.shape[0]),), device=self.device)
            loss_E_mmd = self.mmd_fn(e_seq_st, e_seq_sc[ranidx])
            
            # Total loss
            loss_E = loss_E_pred + self.alpha * loss_E_mmd + ratio * loss_E_circle
            
            # Backward and optimize
            loss_E.backward()
            self.optimizer_E.step()
            self.scheduler_E.step()
            
            # Log progress
            if epoch % 100 == 0:
                log_msg = (f"Encoder epoch {epoch}/{n_epochs}, "
                          f"Loss_E: {loss_E.item():.6f}, "
                          f"Loss_E_pred: {loss_E_pred.item():.6f}, "
                          f"Loss_E_circle: {loss_E_circle.item():.6f}, "
                          f"Loss_E_mmd: {loss_E_mmd.item():.6f}, "
                          f"Ratio: {ratio:.4f}")
                
                print(log_msg)
                with open(self.train_log, 'a') as f:
                    f.write(log_msg + '\n')
                
                # Save checkpoint
                if epoch % 500 == 0:
                    torch.save({
                        'epoch': epoch,
                        'netE_state_dict': self.netE.state_dict(),
                        'optimizer_state_dict': self.optimizer_E.state_dict(),
                        'scheduler_state_dict': self.scheduler_E.state_dict(),
                    }, os.path.join(self.outf, f'encoder_checkpoint_epoch_{epoch}.pt'))
        
        # Save final encoder
        torch.save({
            'netE_state_dict': self.netE.state_dict(),
        }, os.path.join(self.outf, 'final_encoder.pt'))
        
        print("Encoder training complete!")
                
    def train_diffusion(self, n_epochs=2000, lambda_struct=10.0, lambda_physics=1.0, lambda_uncertainty=0.1):
        """Train the advanced diffusion model"""
        print("Training advanced hierarchical diffusion model...")
        
        # Freeze encoder
        for param in self.netE.parameters():
            param.requires_grad = False
            
        # Precompute adjacency matrix
        def compute_adjacency_matrix(distances, sigma=3.0):
            weights = torch.exp(-(distances ** 2) / (2 * sigma * sigma))
            weights = weights * (1 - torch.eye(weights.shape[0], device=self.device))
            row_sums = weights.sum(dim=1, keepdim=True)
            row_sums = torch.clamp(row_sums, min=1e-10)
            adjacency = weights / (row_sums + 1e-8)
            return adjacency
            
        st_adj = compute_adjacency_matrix(self.D_st, sigma=self.sigma)
        
        best_loss = float('inf')
        best_state = None
        
        for epoch in range(n_epochs):
            # Sample batch
            idx = torch.randperm(len(self.st_coords_norm))[:self.batch_size]
            coords = self.st_coords_norm[idx]
            features = self.st_gene_expr[idx]
            sub_adj = st_adj[idx][:, idx]
            sub_adj = sub_adj / (sub_adj.sum(dim=1, keepdim=True) + 1e-8)
            
            # Sample timesteps with curriculum
            if epoch < n_epochs // 3:
                # Early training: focus on high noise
                t = torch.randint(int(0.7 * self.n_timesteps), self.n_timesteps, (self.batch_size,), device=self.device)
            elif epoch < 2 * n_epochs // 3:
                # Mid training: balanced
                t = torch.randint(0, self.n_timesteps, (self.batch_size,), device=self.device)
            else:
                # Late training: focus on low noise (refinement)
                t = torch.randint(0, int(0.3 * self.n_timesteps), (self.batch_size,), device=self.device)
                
            # Add noise
            noisy_coords, target_noise = self.add_noise(coords, t, self.noise_schedule)
            
            # Get encoded features
            with torch.no_grad():
                encoded_features = self.netE(features)
                
            # Forward pass
            pred_noise, uncertainty, cell_radii = self.forward_diffusion(
                noisy_coords, 
                t.float() / self.n_timesteps, 
                encoded_features,
                cell_types=None  # ST data doesn't have cell types
            )
            
            # Compute losses
            # 1. Diffusion loss
            diffusion_loss = F.mse_loss(pred_noise, target_noise)
            
            # 2. Structure loss
            sqrt_alphas_cumprod_t = self.noise_schedule['sqrt_alphas_cumprod'][t].view(-1, 1)
            sqrt_one_minus_alphas_cumprod_t = self.noise_schedule['sqrt_one_minus_alphas_cumprod'][t].view(-1, 1)
            pred_coords = (noisy_coords - sqrt_one_minus_alphas_cumprod_t * pred_noise) / sqrt_alphas_cumprod_t
            
            pred_distances = torch.cdist(pred_coords, pred_coords, p=2)
            pred_adj = compute_adjacency_matrix(pred_distances, sigma=self.sigma)
            
            struct_loss = F.kl_div(
                torch.log(pred_adj + 1e-10),
                sub_adj,
                reduction='batchmean'
            )
            
            # 3. Physics loss (encourage minimum separation)
            min_distances = pred_distances + torch.eye(self.batch_size, device=self.device) * 1e6
            min_dist = min_distances.min(dim=1)[0]
            physics_loss = F.relu(0.01 - min_dist).mean()  # Penalize if cells closer than 0.01
            
            # 4. Uncertainty regularization
            uncertainty_loss = uncertainty.mean()  # Encourage lower uncertainty
            
            # Total loss
            total_loss = (
                diffusion_loss + 
                lambda_struct * struct_loss + 
                lambda_physics * physics_loss +
                lambda_uncertainty * uncertainty_loss
            )
            
            # Optimize
            self.optimizer_diff.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                [p for p in self.parameters() if p.requires_grad], 
                1.0
            )
            self.optimizer_diff.step()
            self.scheduler_diff.step()
            
            # Save best model
            if total_loss.item() < best_loss:
                best_loss = total_loss.item()
                best_state = self.state_dict()
                
            if epoch % 100 == 0:
                print(f"Epoch {epoch}: Total loss = {total_loss.item():.6f}, "
                      f"Diff = {diffusion_loss.item():.6f}, "
                      f"Struct = {struct_loss.item():.6f}, "
                      f"Physics = {physics_loss.item():.6f}, "
                      f"Uncertainty = {uncertainty_loss.item():.6f}")
                
        # Load best model
        if best_state is not None:
            self.load_state_dict(best_state)
            
        print("Advanced diffusion training complete!")
        
    def sample_sc_coordinates(self, num_samples=5, return_uncertainty=True):
        """Sample SC coordinates with uncertainty quantification"""
        self.eval()
        
        all_samples = []
        all_uncertainties = []
        
        with torch.no_grad():
            # Get SC features and cell types
            sc_features = self.netE(self.sc_gene_expr)
            
            for sample_idx in range(num_samples):
                # Start from noise
                coords = torch.randn(len(self.sc_gene_expr), 2, device=self.device)
                
                # Reverse diffusion process
                for t in reversed(range(self.n_timesteps)):
                    t_batch = torch.full((len(coords),), t, device=self.device)
                    
                    # Predict noise and uncertainty
                    noise_pred, uncertainty, _ = self.forward_diffusion(
                        coords,
                        t_batch.float() / self.n_timesteps,
                        sc_features,
                        self.sc_cell_types
                    )
                    
                    # Apply OT guidance if available
                    if self.transport_plan is not None:
                        ot_guidance = self.compute_ot_guidance(coords, sc_features)
                        noise_pred = noise_pred - ot_guidance * 0.1
                        
                    # Denoise step
                    alpha = self.noise_schedule['alphas'][t]
                    alpha_cumprod = self.noise_schedule['alphas_cumprod'][t]
                    beta = self.noise_schedule['betas'][t]
                    
                    if t > 0:
                        noise = torch.randn_like(coords)
                        sigma = torch.sqrt(beta)
                    else:
                        noise = 0
                        sigma = 0
                        
                    coords = (1 / torch.sqrt(alpha)) * (
                        coords - (beta / torch.sqrt(1 - alpha_cumprod)) * noise_pred
                    ) + sigma * noise
                    
                all_samples.append(coords.cpu())
                all_uncertainties.append(uncertainty.cpu())
                
        # Aggregate results
        coords_mean = torch.stack(all_samples).mean(0)
        coords_std = torch.stack(all_samples).std(0)
        uncertainty_mean = torch.stack(all_uncertainties).mean(0)
        
        # Denormalize
        # coords_mean = coords_mean * self.coords_radius + self.coords_center
        
        if return_uncertainty:
            return coords_mean.numpy(), coords_std.numpy(), uncertainty_mean.numpy()
        else:
            return coords_mean.numpy()
            
    def train(self, encoder_epochs=1000, diffusion_epochs=2000, **kwargs):
        """Combined training pipeline"""
        # Train encoder
        self.train_encoder(n_epochs=encoder_epochs)
        
        # Train diffusion
        self.train_diffusion(n_epochs=diffusion_epochs, **kwargs)

# =====================================================
# PART 4: MMD Loss Implementation
# =====================================================

class MMDLoss(nn.Module):
    """Maximum Mean Discrepancy between two batches of feature vectors.

    Args:
        kernel_type: ``'rbf'`` (sum of Gaussian kernels) or ``'linear'``
            (squared distance between the batch means).
        kernel_mul: Ratio between consecutive RBF bandwidths.
        kernel_num: Number of RBF kernels that are summed.
        fix_sigma: Optional fixed base bandwidth; when falsy the bandwidth is
            estimated from the mean pairwise squared distance of the inputs.
    """
    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        super(MMDLoss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = fix_sigma
        self.kernel_type = kernel_type

    def guassian_kernel(self, source, target, kernel_mul, kernel_num, fix_sigma):
        """Return the summed multi-bandwidth Gaussian kernel matrix of ``[source; target]``.

        The result is square with ``len(source) + len(target)`` rows; source rows come
        first. (The method name keeps its historical spelling for compatibility.)
        """
        n_samples = int(source.size()[0]) + int(target.size()[0])
        total = torch.cat([source, target], dim=0)
        # Pairwise squared Euclidean distances via broadcasting
        L2_distance = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(2)
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            # Mean off-diagonal squared distance; detached so it carries no gradient
            bandwidth = torch.sum(L2_distance.detach()) / (n_samples**2-n_samples)
        # Out-of-place division so a tensor passed as fix_sigma is never modified
        bandwidth = bandwidth / kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul**i)
                          for i in range(kernel_num)]
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp)
                      for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    def linear_mmd2(self, f_of_X, f_of_Y):
        """Return the squared Euclidean distance between the two batch means."""
        delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0)
        return delta.dot(delta)

    def forward(self, source, target):
        """Compute the MMD between ``source`` and ``target`` (rows are samples).

        Raises:
            ValueError: If ``kernel_type`` is neither ``'linear'`` nor ``'rbf'``.
        """
        if self.kernel_type == 'linear':
            return self.linear_mmd2(source, target)
        if self.kernel_type == 'rbf':
            batch_size = int(source.size()[0])
            kernels = self.guassian_kernel(
                source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
            # Split the joint kernel matrix into its four source/target blocks
            XX = torch.mean(kernels[:batch_size, :batch_size])
            YY = torch.mean(kernels[batch_size:, batch_size:])
            XY = torch.mean(kernels[:batch_size, batch_size:])
            YX = torch.mean(kernels[batch_size:, :batch_size])
            loss = torch.mean(XX + YY - XY - YX)
            return loss
        # Previously an unknown kernel type silently produced None
        raise ValueError(f"Unsupported kernel_type: {self.kernel_type!r}")

In [None]:
def load_and_process_cscc_data_individual_norm(data_dir='/home/ehtesamul/sc_st/data/cSCC/processed'):
    """Load the cSCC patient-2 SC dataset and its three ST replicates.

    Each AnnData object is total-count normalized and log1p-transformed in place.
    A ``rough_celltype`` column is added to ``scadata.obs``: it starts as a string
    copy of ``level1_celltype``, selected level-1 labels are merged into broader
    groups, and finally tumour keratinocyte subtypes from ``level2_celltype``
    override it (``TSK`` vs ``NonTSK``).

    Args:
        data_dir: Directory containing ``scP2.h5ad``, ``stP2.h5ad``,
            ``stP2rep2.h5ad`` and ``stP2rep3.h5ad``. Defaults to the original
            hard-coded location.

    Returns:
        Tuple ``(scadata, stadata1, stadata2, stadata3)``.
    """
    print("Loading cSCC data with individual normalization...")

    # Load SC data
    scadata = sc.read_h5ad(os.path.join(data_dir, 'scP2.h5ad'))

    # Load all 3 ST datasets
    stadata1 = sc.read_h5ad(os.path.join(data_dir, 'stP2.h5ad'))
    stadata2 = sc.read_h5ad(os.path.join(data_dir, 'stP2rep2.h5ad'))
    stadata3 = sc.read_h5ad(os.path.join(data_dir, 'stP2rep3.h5ad'))

    # Normalize expression data (same for all)
    for adata in [scadata, stadata1, stadata2, stadata3]:
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)

    # Level-1 labels that are merged or renamed; any other label is kept as-is
    level1_to_rough = {
        'CLEC9A': 'DC',
        'CD1C': 'DC',
        'ASDC': 'DC',
        'PDC': 'PDC',
        'MDSC': 'DC',
        'LC': 'DC',
        'Mac': 'Myeloid cell',
        'Tcell': 'T cell',
    }

    # Create rough cell types for SC data
    scadata.obs['rough_celltype'] = scadata.obs['level1_celltype'].astype(str)
    for level1_label, rough_label in level1_to_rough.items():
        scadata.obs.loc[scadata.obs['level1_celltype'] == level1_label, 'rough_celltype'] = rough_label

    # Level-2 tumour labels are applied last so they take precedence over level 1
    scadata.obs.loc[scadata.obs['level2_celltype']=='TSK','rough_celltype'] = 'TSK'
    scadata.obs.loc[scadata.obs['level2_celltype'].isin(['Tumor_KC_Basal', 'Tumor_KC_Diff','Tumor_KC_Cyc']),'rough_celltype'] = 'NonTSK'

    return scadata, stadata1, stadata2, stadata3

def normalize_coordinates_individually(coords):
    """Rescale each coordinate axis independently to the range [-1, 1].

    Args:
        coords: Array with one point per row; min/max are taken over axis 0.

    Returns:
        Tuple ``(coords_normalized, coords_min, coords_max, coords_range)``. Axes with
        zero extent get a range of 1, so all their normalized values are -1.
    """
    lower = coords.min(axis=0)
    upper = coords.max(axis=0)
    span = upper - lower

    # A constant axis would divide by zero; substitute a unit range for it
    span[span == 0] = 1.0

    # Map [lower, upper] linearly onto [-1, 1]
    scaled = 2 * (coords - lower) / span - 1

    return scaled, lower, upper, span

def prepare_individually_normalized_st_data(stadata1, stadata2, stadata3, scadata):
    """
    Align SC and three ST datasets on shared genes, rescale each ST dataset's
    coordinates on its own, and stack the ST datasets together.

    Args:
        stadata1, stadata2, stadata3: spatial datasets exposing ``var_names``,
            gene-name column indexing, ``.X`` and ``obsm['spatial']``.
        scadata: single-cell dataset exposing ``var_names``, gene-name column
            indexing and ``.X``.

    Returns:
        Tuple ``(X_sc, X_st_combined, Y_st_combined, dataset_info, common_genes)``:
        float32 torch tensors for the SC and stacked ST expression, a float32
        NumPy array of stacked normalized coordinates, a dict with per-row
        dataset labels plus each dataset's normalization parameters, and the
        sorted list of shared gene names.
    """
    print("Preparing individually normalized ST data...")

    st_adatas = (stadata1, stadata2, stadata3)

    # Genes present in the SC data and in every ST dataset, in sorted order
    shared = set(scadata.var_names)
    for st_adata in st_adatas:
        shared = shared & set(st_adata.var_names)
    common_genes = sorted(shared)
    print(f"Common genes across all datasets: {len(common_genes)}")

    def _dense_expression(adata):
        # Restrict to the shared genes and densify when the matrix is sparse
        expr = adata[:, common_genes].X
        return expr.toarray() if hasattr(expr, 'toarray') else expr

    sc_expr = _dense_expression(scadata)
    st_exprs = [_dense_expression(st_adata) for st_adata in st_adatas]

    # Each ST dataset is rescaled with its own min/max
    raw_coords = [st_adata.obsm['spatial'] for st_adata in st_adatas]

    print("Normalizing coordinates individually...")
    normalized = [normalize_coordinates_individually(xy) for xy in raw_coords]

    for idx, (coords_norm, _, _, _) in enumerate(normalized, start=1):
        print(f"ST{idx} coord range: [{coords_norm.min():.3f}, {coords_norm.max():.3f}]")

    # Stack expression rows and coordinate rows in dataset order
    st_expr_combined = np.vstack(st_exprs)
    st_coords_combined = np.vstack([entry[0] for entry in normalized])

    # Per-row provenance labels and the parameters used for each dataset
    labels = []
    normalization_params = {}
    for idx, (expr, (_, lo, hi, span)) in enumerate(zip(st_exprs, normalized), start=1):
        name = f'dataset{idx}'
        labels.extend([name] * len(expr))
        normalization_params[name] = {'min': lo, 'max': hi, 'range': span}
    dataset_info = {
        'labels': labels,
        'normalization_params': normalization_params,
    }

    print(f"Combined ST data shape: {st_expr_combined.shape}")
    print(f"Combined ST coords shape: {st_coords_combined.shape}")
    print(f"SC data shape: {sc_expr.shape}")

    # Expression goes to torch tensors; coordinates stay a NumPy array
    X_sc = torch.tensor(sc_expr, dtype=torch.float32)
    X_st_combined = torch.tensor(st_expr_combined, dtype=torch.float32)
    Y_st_combined = st_coords_combined.astype(np.float32)

    return X_sc, X_st_combined, Y_st_combined, dataset_info, common_genes

In [None]:
# Unpack the four objects returned by the loader; judging by the names these
# are the single-cell dataset and three spatial datasets — TODO confirm.
scadata, stadata1, stadata2, stadata3 = load_and_process_cscc_data_individual_norm()
# Use the GPU when torch reports CUDA as available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# =====================================================
# Usage Example: Advanced Hierarchical Diffusion Model
# =====================================================

import numpy as np
import torch
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns

def prepare_data_with_cell_types(scadata, stadata):
    """Collect dense expression matrices, ST coordinates and SC cell types.

    Args:
        scadata: single-cell dataset exposing ``.X`` and ``.obs``.
        stadata: spatial dataset exposing ``.X`` and ``.obsm['spatial']``.

    Returns:
        Tuple ``(X_sc, X_st, st_coords, cell_types_sc)``. The expression
        matrices are densified via ``toarray()`` when that method exists and
        returned unchanged otherwise. ``cell_types_sc`` is the values of
        ``obs['rough_celltype']``, or ``None`` when that column is absent.
    """
    def _as_dense(matrix):
        # Sparse matrices expose toarray(); anything else is passed through
        if hasattr(matrix, 'toarray'):
            return matrix.toarray()
        return matrix

    dense_sc = _as_dense(scadata.X)
    dense_st = _as_dense(stadata.X)

    spatial_xy = stadata.obsm['spatial']

    # Cell-type labels are optional
    has_labels = 'rough_celltype' in scadata.obs.columns
    cell_types_sc = scadata.obs['rough_celltype'].values if has_labels else None

    return dense_sc, dense_st, spatial_xy, cell_types_sc

def train_advanced_model(scadata, stadata, T_opt=None, D_st=None, D_induced=None):
    """Build and train an AdvancedHierarchicalDiffusion model on one SC/ST pair.

    Expression matrices are restricted to the genes shared by both datasets.
    The shared genes are sorted so that column order (and therefore the model
    input layout) is identical from run to run; a plain ``set`` intersection
    has no stable ordering for strings.

    Args:
        scadata: single-cell dataset (``var_names``, ``.X``, ``.obs``).
        stadata: spatial dataset (``var_names``, ``.X``, ``.obsm['spatial']``).
        T_opt: forwarded to the model as ``transport_plan``.
        D_st: forwarded to the model as ``D_st``.
        D_induced: forwarded to the model as ``D_induced``.

    Returns:
        The model instance after ``model.train(...)`` has been called.

    Raises:
        ValueError: if the two datasets share no gene names.
    """

    # Prepare data
    X_sc, X_st, st_coords, cell_types_sc = prepare_data_with_cell_types(scadata, stadata)

    # Find common genes (sorted for a deterministic column order)
    common_genes = sorted(set(scadata.var_names) & set(stadata.var_names))
    if not common_genes:
        raise ValueError("No common genes between the SC and ST datasets")

    sc_gene_idx = [scadata.var_names.get_loc(g) for g in common_genes]
    st_gene_idx = [stadata.var_names.get_loc(g) for g in common_genes]

    X_sc = X_sc[:, sc_gene_idx]
    X_st = X_st[:, st_gene_idx]

    # Initialize the advanced model
    model = AdvancedHierarchicalDiffusion(
        st_gene_expr=X_st,
        st_coords=st_coords,
        sc_gene_expr=X_sc,
        cell_types_sc=cell_types_sc,  # Pass cell type information
        transport_plan=T_opt,          # Pass optimal transport plan
        D_st=D_st,
        D_induced=D_induced,
        n_genes=len(common_genes),
        n_embedding=[512, 256, 128],               # Larger embedding for richer features
        mmdbatch=1000,
        batch_size=64,
        lr_e=0.0002,
        lr_d=0.0001,
        n_timesteps=1000,
        n_denoising_blocks=8,          # More blocks for complex modeling
        hidden_dim=512,
        num_heads=8,                   # Multi-head attention
        num_hierarchical_scales=3,     # Multi-scale generation
        dp=0.15,
        device='cuda' if torch.cuda.is_available() else 'cpu',
        outf='advanced_diffusion_output'
    )

    # Train the model with all advanced features
    model.train(
        encoder_epochs=1500,           # More epochs for better alignment
        diffusion_epochs=3000,         # More epochs for complex model
        lambda_struct=5.0,             # Structure preservation
        lambda_physics=2.0,            # Cell non-overlap constraints
        lambda_uncertainty=0.5         # Uncertainty regularization
    )

    return model

def sample_and_analyze_results(model, scadata, num_samples=10):
    """Sample SC coordinates from ``model`` and store them on ``scadata``.

    Calls ``model.sample_sc_coordinates(num_samples=..., return_uncertainty=True)``
    and writes its three results to ``scadata.obsm`` under 'spatial_advanced',
    'spatial_std' and 'spatial_uncertainty'. A per-row confidence score
    ``1 / (1 + mean uncertainty over axis 1)`` is written to
    ``scadata.obs['spatial_confidence']``.

    Returns:
        Tuple ``(scadata, coords_mean, coords_std, uncertainty)``; ``scadata``
        is the same object that was passed in, modified in place.
    """

    print(f"Sampling {num_samples} coordinate predictions...")

    sampled = model.sample_sc_coordinates(
        num_samples=num_samples,
        return_uncertainty=True
    )
    coords_mean, coords_std, uncertainty = sampled

    # Persist the raw sampling outputs alongside the cells
    stored = (
        ('spatial_advanced', coords_mean),
        ('spatial_std', coords_std),
        ('spatial_uncertainty', uncertainty),
    )
    for obsm_key, values in stored:
        scadata.obsm[obsm_key] = values

    # Higher mean uncertainty -> lower confidence, bounded above by 1
    mean_uncertainty = uncertainty.mean(axis=1)
    scadata.obs['spatial_confidence'] = 1 / (1 + mean_uncertainty)

    return scadata, coords_mean, coords_std, uncertainty

def visualize_advanced_results(scadata, coords_mean, coords_std, uncertainty):
    """Create a 2x3 grid of diagnostic plots for predicted SC coordinates.

    Panels, row by row: (1) cells coloured by 'rough_celltype', (2) mean
    uncertainty, (3) combined x/y standard deviation, (4) Gaussian-KDE
    density, (5) 'spatial_confidence', (6) dominant cell type per grid region
    with the cells overlaid.

    Args:
        scadata: object whose ``obs`` provides the 'rough_celltype' and
            'spatial_confidence' columns.
        coords_mean: array indexed as ``[:, 0]`` (x) and ``[:, 1]`` (y).
        coords_std: array indexed the same way as ``coords_mean``.
        uncertainty: array reduced with ``.mean(axis=1)``.

    Returns:
        The matplotlib figure holding all six panels.
    """

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    cell_types = scadata.obs['rough_celltype']
    unique_types = cell_types.unique()
    colors = sns.color_palette('tab20', n_colors=len(unique_types))
    # One boolean mask per cell type, computed once and reused by panels 1 and 6
    type_masks = [np.asarray(cell_types == ct) for ct in unique_types]
    color_by_type = dict(zip(unique_types, colors))

    xs = coords_mean[:, 0]
    ys = coords_mean[:, 1]

    # 1. Spatial coordinates colored by cell type
    ax = axes[0, 0]
    for ct, color, type_mask in zip(unique_types, colors, type_masks):
        ax.scatter(xs[type_mask], ys[type_mask],
                   c=[color], label=ct, s=30, alpha=0.7)
    ax.set_title('Spatial Coordinates by Cell Type', fontsize=14)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)

    # 2. Coordinate uncertainty (X and Y)
    ax = axes[0, 1]
    scatter = ax.scatter(xs, ys,
                         c=uncertainty.mean(axis=1), cmap='viridis_r',
                         s=30, alpha=0.7)
    plt.colorbar(scatter, ax=ax, label='Uncertainty')
    ax.set_title('Spatial Uncertainty', fontsize=14)

    # 3. Coordinate standard deviation across samples
    ax = axes[0, 2]
    total_std = np.sqrt(coords_std[:, 0]**2 + coords_std[:, 1]**2)
    scatter = ax.scatter(xs, ys,
                         c=total_std, cmap='plasma', s=30, alpha=0.7)
    plt.colorbar(scatter, ax=ax, label='Std Dev')
    ax.set_title('Prediction Variability', fontsize=14)

    # 4. Cell density heatmap
    ax = axes[1, 0]
    from scipy.stats import gaussian_kde
    xy = coords_mean.T
    z = gaussian_kde(xy)(xy)
    scatter = ax.scatter(xs, ys,
                         c=z, cmap='hot', s=30, alpha=0.7)
    plt.colorbar(scatter, ax=ax, label='Density')
    ax.set_title('Cell Density', fontsize=14)

    # 5. Confidence scores
    ax = axes[1, 1]
    confidence = scadata.obs['spatial_confidence'].values
    scatter = ax.scatter(xs, ys,
                         c=confidence, cmap='RdYlGn', s=30, alpha=0.7)
    plt.colorbar(scatter, ax=ax, label='Confidence')
    ax.set_title('Prediction Confidence', fontsize=14)

    # 6. Cell type composition in spatial regions
    ax = axes[1, 2]
    # Divide space into a 9x9 grid (10 edges per axis)
    x_bins = np.linspace(xs.min(), xs.max(), 10)
    y_bins = np.linspace(ys.min(), ys.max(), 10)

    def _bin_masks(values, edges):
        # Bins are half-open [lo, hi) except the last one, which is closed
        # [lo, hi] so points sitting exactly on the maximum edge are counted.
        last = len(edges) - 2
        masks = []
        for k in range(last + 1):
            if k == last:
                below_upper = values <= edges[k + 1]
            else:
                below_upper = values < edges[k + 1]
            masks.append((values >= edges[k]) & below_upper)
        return masks

    x_masks = _bin_masks(xs, x_bins)
    y_masks = _bin_masks(ys, y_bins)

    # Shade each non-empty region with the colour of its most frequent type
    from matplotlib.patches import Rectangle
    for i, x_mask in enumerate(x_masks):
        for j, y_mask in enumerate(y_masks):
            region_mask = x_mask & y_mask
            if not region_mask.any():
                continue
            dominant = cell_types[region_mask].mode()[0]
            rect = Rectangle((x_bins[i], y_bins[j]),
                             x_bins[i+1]-x_bins[i],
                             y_bins[j+1]-y_bins[j],
                             facecolor=color_by_type[dominant],
                             alpha=0.3)
            ax.add_patch(rect)

    # Overlay actual cells
    for color, type_mask in zip(colors, type_masks):
        ax.scatter(xs[type_mask], ys[type_mask],
                   c=[color], s=10, alpha=1.0)
    ax.set_title('Spatial Cell Type Regions', fontsize=14)
    ax.set_xlim(xs.min(), xs.max())
    ax.set_ylim(ys.min(), ys.max())

    plt.tight_layout()
    return fig

def analyze_cell_interactions(scadata, coords_mean, model):
    """Compute and plot the minimum distance between every pair of cell types.

    For two different types the value is the smallest distance between any
    cell of one type and any cell of the other. For a type paired with itself
    the distance of a cell to itself is excluded; a type that has only one
    cell therefore has no within-type distance and gets ``nan`` (previously
    this case raised ``ValueError`` from ``np.min`` on an empty selection).

    Args:
        scadata: object whose ``obs['rough_celltype'].values`` holds one label
            per row of ``coords_mean``.
        coords_mean: 2-D array of coordinates passed to ``pdist``.
        model: unused; kept so existing callers keep working.

    Returns:
        Tuple ``(min_distances, interaction_matrix)``: a dict keyed by
        ``(type_a, type_b)`` with ``type_a <= type_b`` in ``np.unique`` order,
        and the symmetric matrix of the same values in that order.
    """

    # Full pairwise distance matrix (n_cells x n_cells)
    from scipy.spatial.distance import pdist, squareform
    distances = squareform(pdist(coords_mean))

    # Get cell types
    cell_types = scadata.obs['rough_celltype'].values
    unique_types = np.unique(cell_types)
    n_types = len(unique_types)
    type_masks = [np.asarray(cell_types == ct) for ct in unique_types]

    min_distances = {}
    interaction_matrix = np.full((n_types, n_types), np.nan)
    for i in range(n_types):
        for j in range(i, n_types):  # upper triangle incl. self-interactions
            # np.ix_ indexing returns a copy, so editing it leaves
            # `distances` untouched.
            sub_dist = distances[np.ix_(type_masks[i], type_masks[j])]
            if i == j:
                # Same cell type - exclude each cell's zero distance to itself
                np.fill_diagonal(sub_dist, np.inf)

            min_dist = sub_dist.min() if sub_dist.size > 0 else np.nan
            if np.isinf(min_dist):
                # Only the masked diagonal was available (single-cell type)
                min_dist = np.nan

            min_distances[(unique_types[i], unique_types[j])] = min_dist
            interaction_matrix[i, j] = min_dist
            interaction_matrix[j, i] = min_dist

    # Plot; cells without a value are hidden from the heatmap
    plt.figure(figsize=(10, 8))
    sns.heatmap(interaction_matrix,
                annot=True, fmt='.3f',
                xticklabels=unique_types,
                yticklabels=unique_types,
                cmap='coolwarm_r',
                mask=np.isnan(interaction_matrix),
                cbar_kws={'label': 'Minimum Distance'})
    plt.title('Minimum Distances Between Cell Types')
    plt.tight_layout()

    return min_distances, interaction_matrix

# =====================================================
# Main execution example
# =====================================================

if __name__ == "__main__":
    # End-to-end example: load data, train on one ST dataset, sample SC
    # coordinates, plot, and print summary statistics.

    # Load your data
    # NOTE(review): load_and_process_cscc_data is not defined in the visible
    # part of this notebook (an earlier cell calls
    # load_and_process_cscc_data_individual_norm) — confirm it exists.
    scadata, stadata1, stadata2, stadata3 = load_and_process_cscc_data()

    # Combine ST data or use one dataset
    # For this example, let's use stadata1

    # Run optimal transport (if you have it)
    # T_opt, D_sc, D_st, D_induced = run_optimal_transport(scadata, stadata1)

    # Train the advanced model
    model = train_advanced_model(
        scadata,
        stadata1,
        T_opt=None,  # Pass your OT result if available
        D_st=None,   # Pass your distance matrix
        D_induced=None
    )

    # Sample results with uncertainty
    scadata, coords_mean, coords_std, uncertainty = sample_and_analyze_results(
        model, scadata, num_samples=10
    )

    # Create visualizations
    fig = visualize_advanced_results(scadata, coords_mean, coords_std, uncertainty)
    plt.savefig('advanced_diffusion_results.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Analyze cell interactions
    min_distances, interaction_matrix = analyze_cell_interactions(
        scadata, coords_mean, model
    )
    plt.savefig('cell_interactions.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Print summary statistics
    print("\n=== Advanced Diffusion Model Results ===")
    print(f"Total cells mapped: {len(coords_mean)}")
    print(f"Average coordinate uncertainty: {uncertainty.mean():.4f}")
    print(f"Average prediction std dev: {coords_std.mean():.4f}")
    print(f"Confidence range: [{scadata.obs['spatial_confidence'].min():.3f}, "
          f"{scadata.obs['spatial_confidence'].max():.3f}]")

    print("\n=== Cell Type Statistics ===")
    for ct in scadata.obs['rough_celltype'].unique():
        mask = scadata.obs['rough_celltype'] == ct
        print(f"{ct}: {mask.sum()} cells, "
              f"avg confidence: {scadata.obs.loc[mask, 'spatial_confidence'].mean():.3f}")

    print("\n=== Physics Constraints ===")
    # One entry per cell-type pair: the smallest distance found for that pair
    all_distances = []
    for key, dist in min_distances.items():
        if not np.isnan(dist):
            all_distances.append(dist)
            print(f"Min distance {key[0]} - {key[1]}: {dist:.4f}")

    if all_distances:
        print(f"\nOverall minimum cell-cell distance: {np.min(all_distances):.4f}")
        # all_distances holds per-type-pair minima, so this counts type pairs
        # (not individual cells) whose closest cells are nearer than 0.01.
        print(f"Cell-type pairs with minimum distance < 0.01: "
              f"{np.sum(np.array(all_distances) < 0.01)}")
    else:
        # np.min raises on an empty sequence; report the absence instead
        print("\nNo finite minimum distances available")

In [None]:
def train_individual_advanced_diffusion_models(scadata, stadata1, stadata2, stadata3):
    """
    Train separate AdvancedHierarchicalDiffusion models for each ST dataset and average the results.

    For every ST dataset the SC and ST expression matrices are restricted to
    their (sorted) shared genes, a model is built and trained, and one set of
    SC coordinates is sampled. The per-model coordinates are checked to have
    identical shapes and then averaged element-wise.

    NOTE(review): each model is trained on that dataset's raw
    ``obsm['spatial']``; averaging assumes the three datasets share a
    coordinate frame — confirm.

    Args:
        scadata: single-cell dataset; modified in place.
        stadata1, stadata2, stadata3: spatial datasets with ``obsm['spatial']``.

    Returns:
        scadata: Updated with averaged coordinates in obsm['advanced_diffusion_coords_avg']
            and per-model coordinates in obsm['advanced_diffusion_coords_rep{k}']
        models_all: All trained models for further analysis

    Raises:
        AssertionError: if the models return coordinates of differing shapes.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Store results from each model
    sc_coords_results = []
    models_all = []

    # List of ST datasets for iteration
    st_datasets = [
        (stadata1, "dataset1"),
        (stadata2, "dataset2"),
        (stadata3, "dataset3")
    ]
    n_models = len(st_datasets)

    for i, (stadata, dataset_name) in enumerate(st_datasets):
        print(f"\n{'='*50}")
        print(f"Training AdvancedHierarchicalDiffusion model {i+1}/{n_models} for {dataset_name}")
        print(f"{'='*50}")

        # Get common genes between SC and current ST dataset
        common_genes = sorted(set(scadata.var_names) & set(stadata.var_names))

        print(f"Common genes for {dataset_name}: {len(common_genes)}")

        # Extract expression data
        sc_expr = scadata[:, common_genes].X
        st_expr = stadata[:, common_genes].X

        # Convert to dense if sparse
        if hasattr(sc_expr, 'toarray'):
            sc_expr = sc_expr.toarray()
        if hasattr(st_expr, 'toarray'):
            st_expr = st_expr.toarray()

        # Get spatial coordinates
        st_coords = stadata.obsm['spatial']

        print(f"SC data shape: {sc_expr.shape}")
        print(f"ST data shape: {st_expr.shape}")
        print(f"ST coords shape: {st_coords.shape}")

        # Initialize AdvancedHierarchicalDiffusion model
        model = AdvancedHierarchicalDiffusion(
            st_gene_expr=st_expr,
            st_coords=st_coords,
            sc_gene_expr=sc_expr,
            cell_types_sc=None,  # No cell type labels
            transport_plan=None,  # No OT transport plan
            D_st=None,           # No distance matrices
            D_induced=None,
            n_genes=len(common_genes),
            n_embedding=[512, 256, 128],  # Same as STEMDiffusion
            coord_space_diameter=200,
            sigma=3.0,
            alpha=0.8,
            mmdbatch=1000,
            batch_size=256,
            device=device,
            lr_e=0.0001,
            lr_d=0.0002,
            n_timesteps=800,     # Same as STEMDiffusion
            n_denoising_blocks=6,
            hidden_dim=256,      # Same as STEMDiffusion
            num_heads=8,
            num_hierarchical_scales=3,
            dp=0.1,
            outf=f'advanced_diffusion_{dataset_name}'
        )

        print(f"Training model for {dataset_name}...")

        # Train the model with reduced epochs for speed
        model.train(
            encoder_epochs=800,      # Reduced from 1500
            diffusion_epochs=1500,   # Reduced from 3000
            lambda_struct=5.0,
            lambda_physics=2.0,
            lambda_uncertainty=0.5
        )

        print(f"Generating SC coordinates using model {i+1}...")

        # One sample per model, without the uncertainty outputs
        sc_coords = model.sample_sc_coordinates(
            num_samples=1,
            return_uncertainty=False
        )

        # Store results
        sc_coords_results.append(sc_coords)
        models_all.append(model)

        print(f"Model {i+1} complete! Generated coordinates shape: {sc_coords.shape}")

        # Release cached GPU blocks. The model itself is still referenced by
        # models_all (it is returned), so `del` only drops the local name and
        # does not free the model's own memory.
        del model
        torch.cuda.empty_cache()

    # Verify shapes match BEFORE averaging, so a mismatch is reported by this
    # assertion rather than by an error raised inside np.mean.
    shapes = [coords.shape for coords in sc_coords_results]
    assert all(shape == shapes[0] for shape in shapes), f"Shape mismatch: {shapes}"

    # Average the results from all models
    print(f"\nAveraging results from {len(sc_coords_results)} models...")
    sc_coords_avg = np.mean(sc_coords_results, axis=0)

    print(f"Final averaged coordinates shape: {sc_coords_avg.shape}")

    # Add to AnnData
    scadata.obsm['advanced_diffusion_coords_avg'] = sc_coords_avg

    # Keep the individual results too
    for i, coords in enumerate(sc_coords_results):
        scadata.obsm[f'advanced_diffusion_coords_rep{i+1}'] = coords

    print(f"\nAdvanced diffusion training complete!")
    print(f"Results saved in scadata.obsm['advanced_diffusion_coords_avg']")

    return scadata, models_all

# Load and process data
# NOTE(review): load_and_process_cscc_data is not defined in the visible part
# of this notebook; an earlier cell calls
# load_and_process_cscc_data_individual_norm — confirm which loader is meant.
scadata, stadata1, stadata2, stadata3 = load_and_process_cscc_data()

# Train individual AdvancedHierarchicalDiffusion models and get averaged results
# (this rebinds scadata to the object returned by the training function)
scadata, advanced_models = train_individual_advanced_diffusion_models(
    scadata, stadata1, stadata2, stadata3
)

print("Advanced diffusion training complete! Results saved in scadata.obsm['advanced_diffusion_coords_avg']")

# Visualize results
import matplotlib.pyplot as plt
import seaborn as sns

# 20-colour 'tab20' palette as hex strings, used for the cell-type colours below
my_tab20 = sns.color_palette("tab20", n_colors=20).as_hex()

# Plot 1: Averaged coordinates
# NOTE(review): sc.pl.embedding is called without an `ax`; presumably it draws
# on its own figure, leaving the plt.figure created here empty — confirm.
plt.figure(figsize=(8, 6))
sc.pl.embedding(scadata, basis='advanced_diffusion_coords_avg', color='rough_celltype',
               size=85, title='SC Advanced Diffusion Coords (Averaged)',
               palette=my_tab20, legend_loc='right margin', legend_fontsize=10)
plt.show()

# Plot 2: Individual model results
# range(3) matches the three 'advanced_diffusion_coords_rep{k}' entries written
# by train_individual_advanced_diffusion_models.
for i in range(3):
    plt.figure(figsize=(6, 5))
    sc.pl.embedding(scadata, basis=f'advanced_diffusion_coords_rep{i+1}', color='rough_celltype',
                   size=85, title=f'SC Coordinates (Advanced Model {i+1})',
                   palette=my_tab20, legend_loc='right margin', legend_fontsize=10)
    plt.show()

In [None]:
%matplotlib inline
import scanpy as sc
# Scanpy-wide figure defaults: 100 dpi on a white background
sc.settings.set_figure_params(dpi=100, facecolor='white')

# Visualize results with separate plots
import matplotlib.pyplot as plt
# Default matplotlib figure size; set after the scanpy call above
plt.rcParams["figure.figsize"] = (6,6)


import seaborn as sns
# 20-colour 'tab20' palette as hex strings, passed as `palette` to the plots below
my_tab20 = sns.color_palette("tab20", n_colors=20).as_hex()

# One embedding plot per coordinate set stored in scadata.obsm, coloured by
# 'rough_celltype'. Each entry is (obsm basis key, plot title); the first is
# the averaged result, the rest are the individual models.
embedding_panels = [
    ('sb_coords', 'SC Spatial Coordinates (Averaged from 3 Models)'),
    ('diffusion_coords_rep1', 'SC Coordinates (Model 1)'),
    ('diffusion_coords_rep2', 'SC Coordinates (Model 2)'),
    ('diffusion_coords_rep3', 'SC Coordinates (Model 3)'),
]

for basis_key, panel_title in embedding_panels:
    plt.figure(figsize=(4, 4))
    sc.pl.embedding(
        scadata,
        basis=basis_key,
        color='rough_celltype',
        size=85,
        title=panel_title,
        palette=my_tab20,
        legend_loc='right margin',
        legend_fontsize=10,
    )
    plt.show()

In [None]:
# Display the 'sb_coords' entry of scadata.obsm as this cell's output.
# NOTE(review): nothing in the visible part of the notebook writes
# obsm['sb_coords']; presumably an earlier cell does — confirm.
scadata.obsm['sb_coords']