In [None]:
import torch
import scanpy as sc
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.neighbors import kneighbors_graph
from scipy.sparse.csgraph import dijkstra
from scipy.sparse import csr_matrix, issparse
from sklearn.preprocessing import normalize
import ot 
from sklearn.neighbors import NearestNeighbors

In [None]:
def construct_graph_torch(X, k, mode='connectivity', metric = 'minkowski', p=2, device='cuda'):
    '''Build a k-nearest-neighbour graph from a feature tensor.

    The neighbour search itself runs on the CPU with scikit-learn; only the
    resulting graph is moved to ``device``.

    args:
        X: input data containing features (torch tensor, one row per sample)
        k: number of neighbors for each data point
        mode: 'connectivity' (unweighted edges, each point is its own neighbour)
              or 'distance' (edges weighted by euclidean distance, no self edges)
        metric: 'euclidean', or 'minkowski' together with p=2 (the same metric)
        p: minkowski exponent; only p=2 is accepted when metric is 'minkowski'
        device: device for the returned tensor; falls back to 'cpu' when a CUDA
                device is requested but CUDA is not available

    Returns:
        'connectivity': torch sparse COO tensor (float32)
        'distance': dense float32 tensor
    '''

    assert mode in ['connectivity', 'distance'], "mode must be 'connectivity' or 'distance'."
    # Minkowski with p=2 is the euclidean metric, so the signature defaults
    # must pass this check instead of failing on every default call.
    is_euclidean = metric == 'euclidean' or (metric == 'minkowski' and p == 2)
    assert is_euclidean, "for gpu knn, only 'euclidean' metric is currently supported in this implementation"

    # Never try to place the result on a GPU that does not exist.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        device = 'cpu'

    # Self edges are only kept in the unweighted graph.
    include_self = mode == 'connectivity'

    # scikit-learn needs a host-side NumPy array, wherever X currently lives.
    X_np = X.detach().cpu().numpy()
    knn_graph_cpu = kneighbors_graph(X_np, k, mode=mode, metric='euclidean',
                                     include_self=include_self)  # scipy sparse matrix on cpu

    if mode == 'distance':
        return torch.tensor(knn_graph_cpu.toarray(), dtype=torch.float32, device=device)

    coo = knn_graph_cpu.tocoo()
    # Stack into one (2, nnz) array first: building a tensor from a list of
    # ndarrays is slow and emits a warning.
    indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data.astype(np.float32))
    return torch.sparse_coo_tensor(indices, values, size=coo.shape).to(device)
    
def distances_cal_torch(graph, type_aware=None, aware_power =2, device='cuda'):
    '''
    Turn a neighbour graph into a normalised, mean-centred shortest-path matrix.

    Shortest paths are computed on the CPU with scipy's dijkstra (edges are
    treated as undirected). Unreachable pairs are set to the largest finite
    distance, the matrix is divided by that maximum and its mean is subtracted.

    args:
        graph: knn graph (pytorch sparse or dense tensor, or anything
               scipy.sparse.csr_matrix accepts)
        type_aware: accepted for signature compatibility; not used
        aware_power: accepted for signature compatibility; not used
        device (str): device for the returned tensor; falls back to 'cpu' when
                      a CUDA device is requested but CUDA is not available
    Returns:
        distance matrix as a float32 torch tensor
    '''
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        device = 'cpu'

    if isinstance(graph, torch.Tensor):
        if graph.is_sparse:
            # Build the CSR matrix straight from the COO triplets rather than
            # materialising an n x n dense array first.
            coo = graph.detach().coalesce().cpu()
            rows, cols = coo.indices().numpy()
            graph_cpu_csr = csr_matrix((coo.values().numpy(), (rows, cols)), shape=tuple(coo.shape))
            # Stored zeros would otherwise be read as zero-length edges; the
            # dense conversion used previously dropped them.
            graph_cpu_csr.eliminate_zeros()
        else:
            graph_cpu_csr = csr_matrix(graph.detach().cpu().numpy())
    else:
        graph_cpu_csr = csr_matrix(graph) #assume scipy sparse matrix if not torch tensor

    shortestPath_cpu = dijkstra(csgraph = graph_cpu_csr, directed=False, return_predecessors=False) #dijkstra on cpu
    shortestPath = torch.tensor(shortestPath_cpu, dtype=torch.float32, device=device)

    finite = torch.isfinite(shortestPath)
    the_max = 1.0 #fallback scale when there is no positive finite distance
    if finite.any():
        finite_max = shortestPath[finite].max()
        shortestPath[~finite] = finite_max #replace inf with max finite value
        # An edgeless graph has only the zero diagonal as finite entries;
        # dividing by that zero would turn the whole matrix into NaN.
        if finite_max > 0:
            the_max = finite_max

    C_dis = shortestPath / the_max
    C_dis -= torch.mean(C_dis)
    return C_dis

def calculate_D_sc_torch(X_sc, k_neighbors=10, graph_mode='connectivity', device='cpu'):
    '''Compute the single-cell distance matrix D_sc from expression features.

    Builds a euclidean kNN graph over the rows of X_sc with
    construct_graph_torch and converts it to a normalised shortest-path
    matrix with distances_cal_torch.

    args:
        X_sc: feature tensor, one row per cell (must be a torch tensor)
        k_neighbors: number of neighbours per cell in the kNN graph
        graph_mode: 'connectivity' or 'distance', forwarded to construct_graph_torch
        device: a CUDA device string is honoured only when CUDA is available;
                anything else runs on 'cpu'

    returns:
        distance matrix as torch tensor

    raises:
        TypeError: if X_sc is not a torch tensor
    '''

    if not isinstance(X_sc, torch.Tensor):
        raise TypeError('Input X_sc must be a pytorch tensor')

    # Accept 'cuda' as well as indexed forms such as 'cuda:0'; everything
    # else (or a missing GPU) resolves to the CPU.
    if not (str(device).startswith('cuda') and torch.cuda.is_available()):
        device = 'cpu'

    print(f'using device: {device}')
    print('constructing knn graph...')
    # .to() converts dtype/device without the copy-construct warning that
    # torch.tensor(existing_tensor) raises. No feature normalisation is applied.
    X_features = X_sc.detach().to(device=device, dtype=torch.float32)

    Xgraph = construct_graph_torch(X_features, k=k_neighbors, mode=graph_mode, metric='euclidean', device=device)

    print('calculating distances from graph....')
    D_sc = distances_cal_torch(Xgraph, device=device)

    print('D_sc calculation complete')

    return D_sc


In [None]:
from sklearn.neighbors import kneighbors_graph, NearestNeighbors
from scipy.sparse.csgraph import dijkstra
from scipy.sparse import csr_matrix, issparse
from sklearn.preprocessing import normalize
import ot

def construct_graph_spatial(location_array, k, mode='distance', metric='euclidean', p=2):
    '''Build a kNN graph over spot coordinates.

    args:
        location_array: spatial coordinates of the spots, one row per spot
        k: number of neighbors for each spot
        mode: 'connectivity' (unweighted, self edges included) or
              'distance' (distance-weighted, no self edges)
        metric: distance metric handed to kneighbors_graph
        p: minkowski exponent handed to kneighbors_graph

    returns:
        the sparse graph returned by sklearn's kneighbors_graph
    '''
    assert mode in ['connectivity', 'distance'], "mode must be 'connectivity' or 'distance'"

    # A spot is its own neighbour only in the unweighted graph.
    knn_options = dict(mode=mode, metric=metric, include_self=(mode == 'connectivity'), p=p)
    return kneighbors_graph(location_array, k, **knn_options)

def distances_cal_spatial(graph, spot_ids=None, spot_types=None, aware_power=2):
    '''Calculate the normalised spatial distance matrix from a kNN graph.

    Shortest paths are computed with dijkstra on the undirected graph. When
    both spot_ids and spot_types are given, the distance between two spots of
    different type is multiplied by aware_power. Unreachable pairs are then
    set to the largest finite distance, the matrix is divided by that maximum
    and finally mean-centred.

    args:
        graph (scipy.sparse.csr_matrix): knn graph (anything csr_matrix accepts)
        spot_ids (list, optional): spot ids in the row/column order of the graph.
            required for the type-aware adjustment
        spot_types (pd.Series or array-like, optional): spot type per spot.
            A Series whose index contains every spot id is aligned by id;
            anything else is taken positionally. required for the type-aware
            adjustment
        aware_power (int): multiplicative penalty applied to distances between
            spots of different type

    returns:
        np.ndarray: spatial distance matrix

    raises:
        ValueError: if spot_ids / spot_types do not match the graph size
    '''
    shortestPath = dijkstra(csgraph = csr_matrix(graph), directed=False, return_predecessors=False)
    shortestPath = np.asarray(shortestPath, dtype=float)
    # Treat NaN as "unreachable". (np.nan_to_num is deliberately avoided: by
    # default it rewrites +inf to the largest finite float, which would then
    # be mistaken for a real distance and wreck the normalisation below.)
    shortestPath[np.isnan(shortestPath)] = np.inf

    if spot_types is not None and spot_ids is not None:
        n_spots = shortestPath.shape[0]
        if len(spot_ids) != n_spots:
            raise ValueError('spot_ids must have one entry per graph node')

        by_id = (
            isinstance(spot_types, pd.Series)
            and not spot_types.index.has_duplicates
            and pd.Index(spot_ids).isin(spot_types.index).all()
        )
        if by_id:
            types = spot_types.reindex(spot_ids).to_numpy()
        else:
            types = np.asarray(spot_types)
        if types.shape[0] != n_spots:
            raise ValueError('spot_types must have one entry per graph node')

        # Pairwise "different type" mask; replaces the melt/merge/pivot round
        # trip with a single broadcast comparison.
        different_type = types[:, None] != types[None, :]
        # Only scale real distances so inf stays inf whatever aware_power is.
        shortestPath[different_type & np.isfinite(shortestPath)] *= aware_power

    finite = np.isfinite(shortestPath)
    the_max = 1.0 #fallback scale when there is no positive finite distance
    if finite.any():
        finite_max = shortestPath[finite].max()
        shortestPath[~finite] = finite_max #replace inf with max finite value
        # Guard against 0/0 when the graph has no edges at all.
        if finite_max > 0:
            the_max = finite_max

    C_dis = shortestPath / the_max
    C_dis -= np.mean(C_dis)

    return C_dis

def calculate_D_st_from_coords(spatial_coords, X_st=None, k_neighbors=10, graph_mode='distance', aware_st=False, 
                               spot_types=None, aware_power_st=2, spot_ids=None):
    '''Calculate the spatial distance matrix D_st directly from spot coordinates.

    args:
        spatial_coords: spot coordinates, one row per spot
            (pandas DataFrame or numpy array)
        X_st: ST gene expression data (accepted but not used here)
        k_neighbors: number of neighbors for knn graph
        graph_mode: 'connectivity' or 'distance' for knn graph
        aware_st: whether to use type-aware distance adjustment
        spot_types: spot type per spot; required when aware_st=True. Anything
            that is not already a pandas Series is wrapped in one indexed by
            spot_ids
        aware_power_st: multiplicative penalty for the type-aware adjustment
        spot_ids: list or index of spot ids. Defaults to the DataFrame index,
            or to 0..n-1 for array input

    returns:
        np.ndarray: spatial distance matrix D_st

    raises:
        TypeError: if spatial_coords is neither a DataFrame nor an ndarray
        ValueError: if aware_st=True but spot_types is missing
    '''

    if isinstance(spatial_coords, pd.DataFrame):
        location_array = spatial_coords.values
        if spot_ids is None:
            spot_ids = spatial_coords.index.tolist() #use index of dataframe if available
    elif isinstance(spatial_coords, np.ndarray):
        location_array = spatial_coords
        if spot_ids is None:
            spot_ids = list(range(location_array.shape[0])) #generate default ids if not provided
    else:
        raise TypeError('spatial_coords must be a pandas dataframe or a numpy array')

    # Validate the type-aware inputs before doing any graph work so that bad
    # arguments fail immediately.
    if aware_st:
        if spot_types is None or spot_ids is None:
            raise ValueError('spot_types and spot_ids must be provided when aware_st=True')
        if not isinstance(spot_types, pd.Series):
            spot_types = pd.Series(spot_types, index=spot_ids)
    else:
        # Passing None disables the adjustment downstream.
        spot_types = None

    print(f'constructing {graph_mode} graph for ST data with k={k_neighbors}.....')
    Xgraph_st = construct_graph_spatial(location_array, k=k_neighbors, mode=graph_mode)

    if aware_st:
        print('applying type aware distance adjustment for ST data')
        print(f'aware power for ST: {aware_power_st}')

    print('calculating spatial distances.....')
    D_st = distances_cal_spatial(Xgraph_st, spot_ids=spot_ids, spot_types=spot_types, aware_power=aware_power_st)

    print('D_st calculation complete')
    return D_st


In [None]:
def fused_gw_torch(X_sc, X_st, Y_st, alpha, k=100, G0=None, max_iter = 100, tol=1e-9, device='cuda', n_iter = 1, k_st=15):
    '''Align single cells to spots with POT's fused Gromov-Wasserstein solver.

    args:
        X_sc: single-cell expression tensor, one row per cell
        X_st: spot expression tensor, one row per spot (same features as X_sc)
        Y_st: spot coordinates (numpy array, pandas DataFrame or torch tensor)
        alpha: trade-off parameter forwarded to ot.gromov.fused_gromov_wasserstein
        k: number of neighbours for the single-cell kNN graph
        G0: optional initial transport plan (torch tensor or array-like)
        max_iter: maximum solver iterations per call
        tol: absolute tolerance (tol_abs) forwarded to the solver
        device: device for the returned tensors
        n_iter: number of consecutive solver calls; every call after the first
                is warm-started from the previous plan. Must be >= 1
        k_st: number of neighbours for the spatial kNN graph

    returns:
        (T, D_sc, D_st, fgw_dist): transport plan (n_cells x n_spots), the two
        structure matrices as tensors, and the FGW distance reported by POT

    raises:
        ValueError: if n_iter < 1
    '''
    if n_iter < 1:
        # With zero iterations there would be neither a plan nor a solver log.
        raise ValueError('n_iter must be at least 1')

    n = X_sc.shape[0]
    m = X_st.shape[0]

    X_sc = X_sc.to(device)
    X_st = X_st.to(device)

    #calculate distance matrices
    print('calculating SC distances with knn-dijkstra.....')
    D_sc = calculate_D_sc_torch(X_sc, k_neighbors=k, device=device)

    print('Calculating ST distances.....')
    # calculate_D_st_from_coords only understands DataFrames and ndarrays.
    if torch.is_tensor(Y_st):
        spatial_coords = Y_st.detach().cpu().numpy()
    elif isinstance(Y_st, pd.DataFrame):
        spatial_coords = Y_st
    else:
        spatial_coords = np.asarray(Y_st)
    D_st = calculate_D_st_from_coords(spatial_coords=spatial_coords, k_neighbors=k_st, graph_mode="distance")
    D_st = torch.tensor(D_st, dtype=torch.float32, device=device)

    #expression cost between cells and spots, scaled to [0, 1]
    C_exp = torch.cdist(X_sc, X_st, p=2) #euclidean distance
    C_exp = C_exp / (torch.max(C_exp) + 1e-16) #normalize

    #POT works on C-contiguous numpy arrays
    D_sc_np = np.ascontiguousarray(D_sc.cpu().numpy())
    D_st_np = np.ascontiguousarray(D_st.cpu().numpy())
    C_exp_np = np.ascontiguousarray(C_exp.cpu().numpy())

    #uniform distributions
    p = ot.unif(n)
    q = ot.unif(m)

    # Starting plan: the caller's G0 if given, otherwise POT's own default.
    if G0 is None:
        T_np = None
    elif torch.is_tensor(G0):
        T_np = G0.detach().cpu().numpy()
    else:
        T_np = np.asarray(G0)

    # Each pass restarts the solver from the previous plan (warm start); no
    # solver parameter changes between passes.
    for _ in range(n_iter):
        T_np, log = ot.gromov.fused_gromov_wasserstein(
            M=C_exp_np, C1=D_sc_np, C2=D_st_np,
            p=p, q=q, loss_fun='square_loss',
            alpha=alpha,
            G0=T_np,
            log=True,
            verbose=True,
            max_iter = max_iter,
            tol_abs=tol
        )

    fgw_dist = log['fgw_dist']

    print(f'fgw distance: {fgw_dist}')

    T = torch.tensor(T_np, dtype=torch.float32, device=device)

    return T, D_sc, D_st, fgw_dist

In [None]:
# All inputs for this sample live in one directory.
DATA_DIR = './data/mousedata_2020/E1z2'

# The count tables are transposed after loading so that rows are the
# observations AnnData expects.
scdata = pd.read_csv(f'{DATA_DIR}/simu_sc_counts.csv', index_col=0).T
stdata = pd.read_csv(f'{DATA_DIR}/simu_st_counts.csv', index_col=0).T
stgtcelltype = pd.read_csv(f'{DATA_DIR}/simu_st_celltype.csv', index_col=0)
spcoor = pd.read_csv(f'{DATA_DIR}/simu_st_metadata.csv', index_col=0)
scmetadata = pd.read_csv(f'{DATA_DIR}/metadata.csv', index_col=0)

# Preprocess data (normalize, log transform) exactly once per modality.
# The previous version of this cell ran normalize_total + log1p a second time
# on the already-transformed frames and rebuilt the AnnData objects, which
# also discarded the obsm['spatial'] entries.
adata = sc.AnnData(scdata, obs=scmetadata)
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
scdata_processed = pd.DataFrame(adata.X, index=adata.obs_names, columns=adata.var_names)
scdata = scdata_processed  # kept under the old name for any code that reads `scdata`

stadata = sc.AnnData(stdata)
sc.pp.normalize_total(stadata)
sc.pp.log1p(stadata)
stdata_processed = pd.DataFrame(stadata.X, index=stadata.obs_names, columns=stadata.var_names)
stdata = stdata_processed  # kept under the old name for any code that reads `stdata`

# Attach coordinates to the objects that are actually kept.
adata.obsm['spatial'] = scmetadata[['x_global','y_global']].values
stadata.obsm['spatial'] = spcoor

# Dense float32 tensors consumed by the FGW alignment below.
X_sc = torch.tensor(scdata_processed.values, dtype=torch.float32)
X_st = torch.tensor(stdata_processed.values, dtype=torch.float32)

In [None]:
# Prefer the GPU, but stay runnable on CPU-only machines.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Spot coordinates as a plain array.
Y_st = spcoor.values

# --- Run FGW using POT ---
# alpha is POT's trade-off between the expression cost and the structure term,
# k is the neighbour count of the single-cell graph.
T, D_sc, D_st, fgw_dist = fused_gw_torch(
    X_sc=X_sc,
    X_st=X_st,
    Y_st=Y_st,
    alpha=0.3,
    k=300,
    max_iter=200,
    device=device,
)

In [None]:
# Push the spot-level distance matrix through the transport plan,
# T @ D_st @ T^T, to get one entry per pair of single cells.
D_st = D_st.to(device)
D_induced = torch.matmul(torch.matmul(T, D_st), T.t())
D_induced

In [None]:
# Initialize the diffusion model with the simplified parameters
if isinstance(Y_st, torch.Tensor):
    Y_st = Y_st.cpu().numpy()
if isinstance(X_st, torch.Tensor):
    X_st = X_st.cpu().numpy()
if isinstance(X_sc, torch.Tensor):
    X_sc = X_sc.cpu().numpy()
if isinstance(D_induced, torch.Tensor):
    D_induced = D_induced.cpu().numpy()


diffusion = CoordinateDiffusion(
    st_gene_expr=X_st,
    st_coords=Y_st,
    sc_gene_expr=X_sc,
    D_induced=D_induced,
    device="cuda",
    n_timesteps=800,  # Increased timesteps for better quality
    beta_start=1e-4,   # Standard DDPM values
    beta_end=0.02
)

print("Training diffusion model...")

# 4. Training with adjusted parameters
# First phase - coordinate denoising
diffusion.pretrain_denoising(
    n_epochs=2000,    # More epochs for better convergence
    batch_size=192     # Larger batch size for stability
)

# Second phase - conditional training
diffusion.train_conditional(
    n_epochs=3000,    # More epochs for conditional training
    batch_size=192   # Keep batch size consistent
)

# 5. Sample coordinates with proper parameters
sc_coords = diffusion.sample_coordinates(
    timesteps=800     # Can use fewer timesteps for sampling
)

In [None]:
# Sanity check: shape of the ST coordinate array that was passed to the diffusion model as st_coords.
Y_st.shape

In [None]:
# Stack the x_global / y_global columns of the single-cell metadata into an
# (n_cells, 2) array (presumably the ground-truth positions; confirm against the data source).
gt_sc_coords = np.stack(
    [adata.obs[axis_name].values for axis_name in ('x_global', 'y_global')],
    axis=1,
)

gt_sc_coords.shape

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import os
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import to_undirected

class SinusoidalEmbedding(nn.Module):
    """Sinusoidal embeddings for diffusion timesteps.

    Maps timesteps ``t`` of shape (batch, 1), or (batch,), to a
    (batch, dim) tensor: the first ``dim // 2`` columns are sines, the next
    ``dim // 2`` cosines, over geometrically spaced frequencies from 1 down
    to 1/10000. An odd ``dim`` gets one trailing zero column.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
    
    def forward(self, t):
        # Accept a flat vector of timesteps as well as a column vector;
        # without this a (batch,) input would broadcast against the
        # frequency axis instead of the batch axis.
        if t.dim() == 1:
            t = t.unsqueeze(-1)

        half_dim = self.dim // 2
        # max(..., 1): with dim 2 or 3 there is a single frequency and
        # half_dim - 1 is zero, which used to produce NaN embeddings.
        emb = torch.log(torch.tensor(10000.0)) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t * emb[None, :]
        emb = torch.cat((torch.sin(emb), torch.cos(emb)), dim=-1)
        if self.dim % 2 == 1:
            # Pad the last dimension so the output width equals dim.
            emb = F.pad(emb, (0, 1, 0, 0))
        return emb

class ResidualBlock(nn.Module):
    """Two-layer MLP with a skip connection.

    Computes ``SiLU(x + block2(block1(x)))``, where ``block1`` is
    Linear -> LayerNorm -> SiLU (in_dim -> hidden_dim) and ``block2`` is
    Linear -> LayerNorm (hidden_dim -> in_dim). ``hidden_dim`` falls back to
    ``in_dim`` when it is not given.
    """
    def __init__(self, in_dim, hidden_dim=None):
        super().__init__()
        if not hidden_dim:
            hidden_dim = in_dim

        expand_layers = [nn.Linear(in_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.SiLU()]
        project_layers = [nn.Linear(hidden_dim, in_dim), nn.LayerNorm(in_dim)]

        self.block1 = nn.Sequential(*expand_layers)
        self.block2 = nn.Sequential(*project_layers)
        self.activation = nn.SiLU()
    
    def forward(self, x):
        update = self.block2(self.block1(x))
        return self.activation(x + update)
    
class GraphSageModel(nn.Module):
    """Stack of SAGEConv layers, each followed by LayerNorm.

    Layer widths run in_dim -> hidden_dim (n_layers - 1 times) -> out_dim, so
    exactly ``n_layers`` convolutions are built. SiLU is applied after every
    layer except the last.

    Args:
        in_dim: input feature width
        hidden_dim: width of the intermediate layers
        out_dim: width of the returned embedding
        n_layers: number of SAGEConv layers (>= 1); with 1 the single layer
            maps in_dim straight to out_dim

    Raises:
        ValueError: if n_layers < 1
    """
    def __init__(self, in_dim, hidden_dim=256, out_dim=128, n_layers=2):
        super().__init__()
        if n_layers < 1:
            raise ValueError('n_layers must be at least 1')
        self.n_layers = n_layers

        # One width per layer boundary. Deriving both module lists from the
        # same sequence keeps convolutions and norms in lockstep; the old
        # "first + hidden + output" construction built two convolutions but
        # only one norm for n_layers=1 and crashed in forward().
        widths = [in_dim] + [hidden_dim] * (n_layers - 1) + [out_dim]

        self.conv_layers = nn.ModuleList([
            SAGEConv(width_in, width_out)
            for width_in, width_out in zip(widths[:-1], widths[1:])
        ])
        self.norm_layers = nn.ModuleList([
            nn.LayerNorm(width_out) for width_out in widths[1:]
        ])
        
    def forward(self, x, edge_index):
        last = self.n_layers - 1
        for i, (conv, norm) in enumerate(zip(self.conv_layers, self.norm_layers)):
            x = norm(conv(x, edge_index))
            # No activation on the final embedding.
            if i < last:
                x = F.silu(x)
        return x
    
class CoordinateDenoiser(nn.Module):
    """Noise-prediction network for coordinates.

    Coordinates, timestep, and optionally expression features and GNN
    embeddings are each embedded to ``hidden_dim`` and summed; the sum passes
    through ``n_blocks`` residual blocks and a two-layer head that outputs
    ``coord_dim`` values per row.
    """
    def __init__(
        self,
        coord_dim=2,
        feature_dim=None,
        time_dim=256,
        hidden_dim=256,
        gnn_dim=128,
        n_blocks=4
    ):
        super().__init__()

        def embed(in_features):
            # Shared Linear -> LayerNorm -> SiLU stage used by every encoder.
            return [nn.Linear(in_features, hidden_dim), nn.LayerNorm(hidden_dim), nn.SiLU()]

        # Timestep -> sinusoidal features -> MLP
        self.time_embed = nn.Sequential(
            SinusoidalEmbedding(time_dim),
            nn.Linear(time_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Expression encoder exists only when a feature width is supplied.
        if feature_dim is None:
            self.feature_encoder = None
        else:
            self.feature_encoder = nn.Sequential(*embed(feature_dim), *embed(hidden_dim))

        # Encoders for the noisy coordinates and the GNN embedding
        self.coord_encoder = nn.Sequential(*embed(coord_dim))
        self.gnn_encoder = nn.Sequential(*embed(gnn_dim))

        # Residual trunk
        self.blocks = nn.ModuleList([ResidualBlock(hidden_dim) for _ in range(n_blocks)])

        # Output head
        half_hidden = hidden_dim//2
        self.final = nn.Sequential(
            nn.Linear(hidden_dim, half_hidden),
            nn.SiLU(),
            nn.Linear(half_hidden, coord_dim)
        )
    
    def forward(self, coords, t, features=None, gnn_emb=None):
        t_emb = self.time_embed(t)
        h = self.coord_encoder(coords) + t_emb

        # Each conditioning signal is optional and purely additive.
        has_features = features is not None and self.feature_encoder is not None
        if has_features:
            h = h + self.feature_encoder(features)
        if gnn_emb is not None:
            h = h + self.gnn_encoder(gnn_emb)

        for residual_block in self.blocks:
            h = residual_block(h)

        # Predicted noise, one row per input coordinate
        return self.final(h)

class CoordinateDiffusion:
    def __init__(
        self, 
        st_gene_expr,
        st_coords,
        sc_gene_expr,
        D_induced,
        device='cuda',
        n_timesteps = 1000,
        beta_start = 1e-4,
        beta_end = 0.02,
        gnn_dim = 128
    ):
        """Store the data, normalise ST coordinates and build the models.

        Args:
            st_gene_expr: spot expression matrix (spots x genes), array or tensor
            st_coords: spot coordinates (spots x 2), array or tensor
            sc_gene_expr: single-cell expression matrix, array or tensor
            D_induced: matrix stored as ``self.D_induced``, array or tensor
            device: torch device; a CUDA request falls back to CPU when CUDA
                is not available
            n_timesteps: number of diffusion steps
            beta_start, beta_end: end points of the linear beta schedule
            gnn_dim: output width of the GraphSAGE embedding
        """
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            print('CUDA requested but not available; falling back to CPU')
            device = 'cpu'
        self.device = torch.device(device)

        def to_device_tensor(data):
            # Always produce an independent float32 copy on the target device.
            # Tensors are cloned instead of being passed to torch.tensor(),
            # which warns when copy-constructing from a tensor.
            if torch.is_tensor(data):
                return data.detach().clone().to(device=self.device, dtype=torch.float32)
            return torch.tensor(data, dtype=torch.float32).to(self.device)

        # Store data
        self.st_gene_expr = to_device_tensor(st_gene_expr)
        self.st_coords_init = to_device_tensor(st_coords)
        self.sc_gene_expr = to_device_tensor(sc_gene_expr)
        self.D_induced = to_device_tensor(D_induced)

        # Rescale every coordinate axis to [-1, 1].
        coords_min = self.st_coords_init.min(dim=0)[0]
        coords_max = self.st_coords_init.max(dim=0)[0]
        coords_range = coords_max - coords_min
        # An axis on which all spots coincide has zero range; use 1 there so
        # the division yields -1 instead of NaN.
        coords_range = torch.where(coords_range > 0, coords_range, torch.ones_like(coords_range))
        self.st_coords = 2 * (self.st_coords_init - coords_min) / coords_range -1 
        self.coords_min, self.coords_max = coords_min, coords_max
        self.coords_range = coords_range 

        # Model dimensions
        self.n_genes = self.st_gene_expr.shape[1]
        self.gnn_dim = gnn_dim
        
        # Setup noise schedule
        self.n_timesteps = n_timesteps
        self.noise_schedule = self.get_noise_schedule(n_timesteps, beta_start, beta_end)
        
        # Initialize GraphSAGE model for gene expression embedding
        self.gnn_model = GraphSageModel(
            in_dim=self.n_genes,
            hidden_dim=256,
            out_dim=self.gnn_dim,
            n_layers=2
        ).to(self.device)
        
        # Initialize denoiser model
        self.model = CoordinateDenoiser(
            coord_dim=2,
            feature_dim=self.n_genes,
            time_dim=256,
            hidden_dim=256,
            gnn_dim=self.gnn_dim,
            n_blocks=4
        ).to(self.device)

        # Optimizer (covers the denoiser only; pretrain_gnn builds its own
        # optimizer for the GraphSAGE model)
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-4, weight_decay=1e-5)
        
        # Pairwise distances between the normalised ST coordinates
        self.st_distances = torch.cdist(self.st_coords, self.st_coords, p=2)
        
        # GNN embeddings are filled in by pretrain_gnn
        self.st_gnn_embeddings = None

    def get_noise_schedule(self, timesteps=1000, beta1=1e-4, beta2=0.02):
        """Returns DDPM noise schedule parameters (from Code 1)"""
        # Linear schedule
        betas = torch.linspace(beta1, beta2, timesteps, device=self.device)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
        
        # Fix posterior variance calculation
        posterior_variance = torch.zeros_like(betas)
        posterior_variance[1:] = betas[1:] * (1. - alphas_cumprod[:-1]) / (1. - alphas_cumprod[1:])
        posterior_variance[0] = betas[0]
        
        return {
            'betas': betas,
            'alphas': alphas,
            'alphas_cumprod': alphas_cumprod,
            'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
            'sqrt_one_minus_alphas_cumprod': sqrt_one_minus_alphas_cumprod,
            'posterior_variance': posterior_variance
        }
    
    def add_noise(self, x_0, t, noise_schedule):
        """Add noise to coordinates according to timestep t (from Code 1)"""
        noise = torch.randn_like(x_0)
        sqrt_alphas_cumprod_t = noise_schedule['sqrt_alphas_cumprod'][t].view(-1, 1)
        sqrt_one_minus_alphas_cumprod_t = noise_schedule['sqrt_one_minus_alphas_cumprod'][t].view(-1, 1)
        
        # Add noise according to schedule
        x_t = sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise
        
        return x_t, noise
    
    def pretrain_gnn(self, epochs=100, lr=1e-3):
        """Train the GraphSAGE encoder on ST expression without labels.

        A cosine-similarity kNN graph (up to 20 neighbours) is built over the
        spots. The loss pushes embeddings of connected spots together and
        embeddings of randomly drawn spot pairs apart. The final embeddings
        are stored in ``self.st_gnn_embeddings`` and returned.
        """
        print("Pretraining GraphSAGE on ST data...")

        expr = self.st_gene_expr
        n_spots = expr.shape[0]

        # Cosine-similarity kNN graph; the top hit is dropped as the spot itself.
        unit_expr = F.normalize(expr, p=2, dim=1)
        similarity = unit_expr @ unit_expr.t()
        k = min(20, n_spots - 1)
        neighbours = torch.topk(similarity, k + 1, dim=1).indices[:, 1:]
        sources = torch.arange(n_spots, device=self.device).repeat_interleave(k)
        directed_edges = torch.stack([sources, neighbours.reshape(-1)], dim=0).to(self.device)
        edge_index = to_undirected(directed_edges)
        src, dst = edge_index[0], edge_index[1]

        optimizer = torch.optim.Adam(self.gnn_model.parameters(), lr=lr)
        for epoch in range(epochs):
            self.gnn_model.train()
            optimizer.zero_grad()
            embeddings = self.gnn_model(expr, edge_index)

            # Positive pairs: the two ends of every edge.
            pos_sim = F.cosine_similarity(embeddings[src], embeddings[dst])

            # Negative pairs: each source against a uniformly random spot.
            neg_indices = torch.randint(0, n_spots, (src.size(0),), device=self.device)
            neg_sim = F.cosine_similarity(embeddings[src], embeddings[neg_indices])

            # High similarity along edges, low similarity to random spots.
            loss = neg_sim.mean() - pos_sim.mean()
            loss.backward()
            optimizer.step()

            if epoch % 10 == 0:
                print(f'GNN pretraining epoch {epoch}/{epochs}, loss: {loss.item():.4f}')

        # Cache the embeddings produced by the trained encoder.
        self.gnn_model.eval()
        with torch.no_grad():
            self.st_gnn_embeddings = self.gnn_model(expr, edge_index)

        print("GraphSAGE pretraining complete!")
        return self.st_gnn_embeddings
    
    def compute_sc_gnn_embeddings(self, batch_size=1000):
        """Embed the single-cell expression data with the GraphSAGE model.

        Cells are processed in chunks of at most 10000. Each chunk gets its
        own cosine-similarity kNN graph (up to 20 neighbours), so no edges
        cross chunk boundaries. ``batch_size`` is accepted but currently not
        used.

        Returns an (n_cells, gnn_dim) tensor on ``self.device``.
        """
        print("Computing SC GNN embeddings...")

        sc_expr = self.sc_gene_expr
        n_sc = sc_expr.shape[0]
        all_embeddings = torch.zeros(n_sc, self.gnn_dim, device=self.device)

        # Upper bound on how many cells share one similarity matrix.
        chunk_size = min(10000, n_sc)

        for lo in range(0, n_sc, chunk_size):
            hi = min(lo + chunk_size, n_sc)
            chunk = sc_expr[lo:hi]
            n_chunk = chunk.shape[0]

            # kNN graph inside the chunk; the top hit is dropped as the cell itself.
            unit_chunk = F.normalize(chunk, p=2, dim=1)
            similarity = unit_chunk @ unit_chunk.t()
            k = min(20, n_chunk - 1)
            neighbours = torch.topk(similarity, k + 1, dim=1).indices[:, 1:]
            sources = torch.arange(n_chunk, device=self.device).repeat_interleave(k)
            directed_edges = torch.stack([sources, neighbours.reshape(-1)], dim=0).to(self.device)
            edge_index = to_undirected(directed_edges)

            # Inference only: no gradients, eval mode.
            self.gnn_model.eval()
            with torch.no_grad():
                all_embeddings[lo:hi] = self.gnn_model(chunk, edge_index)

        return all_embeddings
    
    def compute_multi_scale_adjacency(self, coords, sigma=0.1, alpha=0.5):
        """Compute a multi-scale adjacency matrix for structural constraint"""
        # Compute pairwise euclidean distances
        distances = torch.cdist(coords, coords, p=2)
        
        # Local adjacency using Gaussian kernel
        weights = torch.exp(-(distances ** 2) / (2 * sigma ** 2))
        identity_mask = torch.eye(weights.shape[0], device=self.device)
        local_adj = weights * (1 - identity_mask)
        local_adj = local_adj / (local_adj.sum(dim=1, keepdim=True) + 1e-8)
        
        # Global adjacency via random walk (4-hop)
        rw_adj = torch.matrix_power(local_adj, 4)
        rw_adj = rw_adj / (rw_adj.sum(dim=1, keepdim=True) + 1e-8)
        
        # Combine two scales
        multi_scale = alpha * local_adj + (1 - alpha) * rw_adj
        multi_scale = multi_scale / (multi_scale.sum(dim=1, keepdim=True) + 1e-8)
        
        return multi_scale
    
    def _calculate_adjacency_from_distances(self, distances, sigma=3.0):
        '''convert dstances to adj matrix using gaussian kernel'''
        weights = torch.exp(-(distances ** 2) / (2* sigma ** 2))
        #zero out self connections to avoid numerical issues
        identity_mask = torch.eye(weights.shape[0], device=weights.device)
        weights = weights * (1- identity_mask)
        #normalzie rows to sum to 1
        row_sums = weights.sum(dim=1, keepdim=True)
        adjacecny = weights / (row_sums + 1e-8)
        return adjacecny
    
    def pretrain_denoising(self, n_epochs=1000, batch_size=128, lambda_structure=1.0):
        '''first phase: train on st coords denoising (without conditioning)

        each step draws a random batch (with replacement) from self.st_coords,
        noises it at random timesteps with self.add_noise and fits the model's
        noise prediction with an mse loss. only the diffusion loss is optimised;
        a distance-based structure loss is evaluated for logging only.

        args:
            n_epochs: number of optimisation steps (one random batch per step)
            batch_size: number of st coordinates drawn per step
            lambda_structure: kept for interface compatibility; currently unused
                because the structure term is not part of the objective

        Returns:
            None. progress is printed every 100 steps.
        '''
        print('starting first training phase: coordinate denosiing pretraining....')

        running_loss = 0.0
        running_diffusion_loss = 0.0
        running_structure_loss = 0.0

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=n_epochs)

        for epoch in range(n_epochs):
            #random batch of st coords
            idx = torch.randint(0, len(self.st_coords), (batch_size,))
            coords = self.st_coords[idx]

            #random timesteps
            t = torch.randint(0, self.n_timesteps, (batch_size,)).to(self.device)

            #add noise
            noisy_coords, target_noise = self.add_noise(coords, t, self.noise_schedule)

            #predict noise (without conditioning)
            pred_noise = self.model(noisy_coords, t.unsqueeze(1).float() / self.n_timesteps)

            #diffusion loss: the only term that is optimised in this phase
            diffusion_loss = F.mse_loss(pred_noise, target_noise)
            loss = diffusion_loss

            #structure loss is a monitoring metric only, so no autograd graph is needed
            with torch.no_grad():
                sqrt_alphas_cumprod_t = self.noise_schedule['sqrt_alphas_cumprod'][t].view(-1, 1)
                sqrt_one_minus_alphas_cumprod_t = self.noise_schedule['sqrt_one_minus_alphas_cumprod'][t].view(-1, 1)
                #one-step estimate of the clean coordinates from the predicted noise
                pred_coords = (noisy_coords - sqrt_one_minus_alphas_cumprod_t * pred_noise) / sqrt_alphas_cumprod_t
                pred_distances = torch.cdist(pred_coords, pred_coords, p=2)
                true_distances = torch.cdist(coords, coords, p=2)
                structure_loss = F.mse_loss(pred_distances, true_distances)

            #update
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            scheduler.step()

            running_loss += loss.item()
            running_diffusion_loss += diffusion_loss.item()
            running_structure_loss += structure_loss.item()

            if epoch % 100 == 0:
                n_steps = epoch + 1
                print(f'pretrain epoch {epoch},'
                      f'Loss: {loss.item():.6f},'
                      f'diff loss: {diffusion_loss.item():.6f},'
                      f'struct loss: {structure_loss.item():.6f},'
                      f'LR: {scheduler.get_last_lr()[0]:.6f}')
                #running means smooth out the noise of single random batches
                print(f'pretrain running mean over {n_steps} steps, '
                      f'loss: {running_loss / n_steps:.6f}, '
                      f'diff loss: {running_diffusion_loss / n_steps:.6f}, '
                      f'struct loss: {running_structure_loss / n_steps:.6f}')

        print('first training phase complete!')

    def train_conditional(self, n_epochs=2000, batch_size=128, lambda_structure=5.0):
        '''second phase: train with gene expression conditioning

        each step draws a random batch (with replacement) of st coordinates and
        their gene expression, noises the coordinates and trains the model to
        predict the noise. the objective is
        5 * diffusion_loss + lambda_structure * adj_loss, where adj_loss is a kl
        divergence between the adjacency of the one-step denoised coordinates and
        the matching sub-block of the st adjacency matrix. a distance-based
        structure loss is evaluated for logging only. the gnn embedding argument
        of the model is passed as None at present.

        args:
            n_epochs: number of optimisation steps (one random batch per step)
            batch_size: number of st spots drawn per step
            lambda_structure: weight of the adjacency kl term

        Returns:
            None. progress is printed every 100 steps.
        '''
        print('starting second training phase: conditional training with GNN.....')

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, n_epochs, eta_min=1e-6)
        running_loss = 0.0
        running_diffusion_loss = 0.0
        running_structure_loss = 0.0
        running_adj_loss = 0.0

        #compute multi scale adj matrix for structural constraints
        st_adj_matrix = self.compute_multi_scale_adjacency(self.st_coords, sigma=3.0, alpha=0.5)

        for epoch in range(n_epochs):
            #random batch
            idx = torch.randint(0, len(self.st_coords), (batch_size,))
            coords = self.st_coords[idx]
            features = self.st_gene_expr[idx]
            #gnn conditioning is disabled for now (was: self.st_gnn_embeddings[idx])
            gnn_emb = None

            #get sub-adj matrix for this batch
            sub_adj = st_adj_matrix[idx, :][:, idx]

            #random timesteps
            t = torch.randint(0, self.n_timesteps, (batch_size,)).to(self.device)

            noisy_coords, target_noise = self.add_noise(coords, t, self.noise_schedule)

            #predict noise with conditioning
            pred_noise = self.model(
                noisy_coords,
                t.unsqueeze(1).float() / self.n_timesteps,
                features,
                gnn_emb
            )

            #diffusion loss
            diffusion_loss = F.mse_loss(pred_noise, target_noise)

            #estimate one-step denoised coordinates
            sqrt_alphas_cumprod_t = self.noise_schedule['sqrt_alphas_cumprod'][t].view(-1, 1)
            sqrt_one_minus_alphas_cumprod_t = self.noise_schedule['sqrt_one_minus_alphas_cumprod'][t].view(-1, 1)
            pred_coords = (noisy_coords - sqrt_one_minus_alphas_cumprod_t * pred_noise) / sqrt_alphas_cumprod_t

            #structure loss based on distances: logged only, so skip the autograd graph
            with torch.no_grad():
                pred_distances = torch.cdist(pred_coords, pred_coords, p=2)
                true_distances = torch.cdist(coords, coords, p=2)
                structure_loss = F.mse_loss(pred_distances, true_distances)

            #adj loss (gradients flow through pred_coords)
            pred_adj = self.compute_multi_scale_adjacency(pred_coords, sigma=3.0, alpha=0.5)
            adj_loss = F.kl_div(
                torch.log(pred_adj + 1e-8),
                sub_adj,
                reduction='batchmean'
            )

            #total loss (the distance-based structure term is intentionally left out)
            loss = 5 * diffusion_loss + lambda_structure * adj_loss

            #update
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            scheduler.step()

            running_loss += loss.item()
            running_diffusion_loss += diffusion_loss.item()
            #was a bare `+` expression, so the accumulator never changed
            running_structure_loss += structure_loss.item()
            running_adj_loss += adj_loss.item()

            if epoch % 100 == 0:
                n_steps = epoch + 1
                print(f'conditional train epoch {epoch},'
                      f'loss: {loss.item(): .6f},'
                      f'diff loss: {diffusion_loss.item(): .6f},'
                      f'struct loss: {structure_loss.item(): .6f},'
                      f'adj loss: {adj_loss.item(): .6f},'
                      f'LR: {scheduler.get_last_lr()[0]: .6f}')
                #running means smooth out the noise of single random batches
                print(f'conditional train running mean over {n_steps} steps, '
                      f'loss: {running_loss / n_steps: .6f}, '
                      f'diff loss: {running_diffusion_loss / n_steps: .6f}, '
                      f'struct loss: {running_structure_loss / n_steps: .6f}, '
                      f'adj loss: {running_adj_loss / n_steps: .6f}')

        print('second training phase comeplete!')

    def train(self, n_epochs=2000, batch_size=64, timesteps=500, beta1=1e-4, beta2=0.02,
            lambda_start=1.0, lambda_end=5.0, save_every=100, checkpoint_dir="./checkpoints"):
        """Train the diffusion model with GNN conditioning - single stage approach.

        Every epoch iterates once over all ST spots in shuffled mini-batches. The
        objective is ``diffusion_loss + lambda * structure_loss`` where the
        structure term is a KL divergence between adjacency matrices that
        ``self._calculate_adjacency_from_distances`` derives from the
        row-normalised pairwise distances of the one-step denoised and the true
        batch coordinates.

        Notes:
            * A fresh AdamW optimiser (lr=1e-4, weight_decay=1e-5) is created
              here; ``self.optimizer`` is not used.
            * The noise schedule returned by ``self.get_noise_schedule`` is kept
              local to this call and is not written to ``self.noise_schedule``.

        Args:
            n_epochs: Number of passes over the ST spots.
            batch_size: Mini-batch size (the last batch may be smaller).
            timesteps: Number of diffusion timesteps of the noise schedule.
            beta1, beta2: Forwarded to ``self.get_noise_schedule``.
            lambda_start, lambda_end: Structure weight, ramped linearly from
                ``lambda_start`` to ``lambda_end`` over the first 80% of epochs.
            save_every: Save a checkpoint every this many epochs; a falsy value
                (0 or None) disables checkpointing.
            checkpoint_dir: Directory for checkpoints (created if missing).

        Returns:
            dict with per-epoch mean losses under the keys ``'diffusion_losses'``,
            ``'structure_losses'`` and ``'total_losses'``.
        """
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Setup noise schedule
        noise_schedule = self.get_noise_schedule(timesteps, beta1, beta2)

        # Create a simple dataloader over spot indices
        indices = torch.arange(len(self.st_coords))
        dataloader = DataLoader(indices, batch_size=batch_size, shuffle=True, drop_last=False)

        # Training metrics
        diffusion_losses = []
        structure_losses = []
        total_losses = []

        def lambda_at(epoch):
            """Linear ramp of the structure weight, saturating at 80% of training."""
            progress = min(epoch / (n_epochs * 0.8), 1.0)
            return lambda_start + (lambda_end - lambda_start) * progress

        # Report roughly 20 times per run; max() avoids a modulo by zero when n_epochs < 20
        log_every = max(1, n_epochs // 20)

        # Optimizer and LR scheduler
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-4, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=1e-6)

        self.model.train()
        for epoch in range(n_epochs):
            epoch_diffusion_loss = 0
            epoch_structure_loss = 0
            epoch_total_loss = 0

            # Current lambda value
            lambda_structure = lambda_at(epoch)

            for batch_idx in dataloader:
                # Get batch data
                coords = self.st_coords[batch_idx]
                gene_expr = self.st_gene_expr[batch_idx]
                gnn_emb = self.st_gnn_embeddings[batch_idx]

                # Sample random timesteps
                t = torch.randint(0, timesteps, (len(batch_idx),), device=self.device)

                # Add noise to coordinates
                noisy_coords, noise = self.add_noise(coords, t, noise_schedule)

                # Predict noise (time is fed to the model scaled to [0, 1))
                predicted_noise = self.model(noisy_coords, t.unsqueeze(1).float() / timesteps, gene_expr, gnn_emb)

                # Diffusion loss (MSE between predicted and actual noise)
                diffusion_loss = F.mse_loss(predicted_noise, noise)

                # One-step estimate of the clean coordinates from the predicted noise
                sqrt_alphas_cumprod_t = noise_schedule['sqrt_alphas_cumprod'][t].view(-1, 1)
                sqrt_one_minus_alphas_cumprod_t = noise_schedule['sqrt_one_minus_alphas_cumprod'][t].view(-1, 1)
                pred_coords = (noisy_coords - sqrt_one_minus_alphas_cumprod_t * predicted_noise) / sqrt_alphas_cumprod_t

                # Pairwise distances, normalised so that every row sums to ~1
                pred_distances = torch.cdist(pred_coords, pred_coords, p=2)
                true_distances = torch.cdist(coords, coords, p=2)
                pred_distances = pred_distances / (pred_distances.sum(dim=1, keepdim=True) + 1e-8)
                true_distances = true_distances / (true_distances.sum(dim=1, keepdim=True) + 1e-8)

                # Adjacency matrices derived from the normalised distances
                pred_adj = self._calculate_adjacency_from_distances(pred_distances)
                true_adj = self._calculate_adjacency_from_distances(true_distances)

                # KL divergence between predicted and true adjacency
                structure_loss = F.kl_div(
                    torch.log(pred_adj + 1e-8),
                    true_adj,
                    reduction='batchmean'
                )

                # Total loss with scheduled lambda
                total_loss = diffusion_loss + lambda_structure * structure_loss

                # Backpropagation
                optimizer.zero_grad()
                total_loss.backward()

                # Gradient clipping for stability
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

                optimizer.step()

                # Accumulate losses
                epoch_diffusion_loss += diffusion_loss.item()
                epoch_structure_loss += structure_loss.item()
                epoch_total_loss += total_loss.item()

            # Step scheduler (once per epoch)
            scheduler.step()

            # Average losses over batches
            n_batches = len(dataloader)
            epoch_diffusion_loss /= n_batches
            epoch_structure_loss /= n_batches
            epoch_total_loss /= n_batches

            # Store metrics
            diffusion_losses.append(epoch_diffusion_loss)
            structure_losses.append(epoch_structure_loss)
            total_losses.append(epoch_total_loss)

            # Print progress
            if (epoch + 1) % log_every == 0 or epoch == 0:
                print(f"Epoch {epoch+1}/{n_epochs}, "
                    f"Diffusion Loss: {epoch_diffusion_loss:.6f}, "
                    f"Structure Loss: {epoch_structure_loss:.6f}, "
                    f"Total Loss: {epoch_total_loss:.6f}, "
                    f"Lambda: {lambda_structure:.2f}, "
                    f"LR: {scheduler.get_last_lr()[0]:.6f}")

            # Save checkpoint (skipped entirely when save_every is 0/None)
            if save_every and (epoch + 1) % save_every == 0:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'diffusion_loss': epoch_diffusion_loss,
                    'structure_loss': epoch_structure_loss,
                    'total_loss': epoch_total_loss
                }, f"{checkpoint_dir}/model_epoch_{epoch+1}.pt")

        return {
            'diffusion_losses': diffusion_losses,
            'structure_losses': structure_losses,
            'total_losses': total_losses
        }

    def denormalize_coordinates(self, normalized_coords):
        '''map normalized coordinates back to the original coordinate frame

        computes (normalized_coords + 1) / 2 * self.coords_range + self.coords_min.
        the stored range/min are converted to match the input container: numpy
        input gets numpy operands, any other input is treated as a torch tensor
        and non-tensor operands are wrapped as tensors on the input's device.

        args:
            normalized_coords: numpy array or torch tensor of normalized coordinates

        Returns:
            coordinates in the original frame, same container type as the input
        '''
        span = self.coords_range
        offset = self.coords_min

        if isinstance(normalized_coords, np.ndarray):
            # numpy in -> numpy operands (tensors are copied to host first)
            if isinstance(span, torch.Tensor):
                span = span.cpu().numpy()
            if isinstance(offset, torch.Tensor):
                offset = offset.cpu().numpy()
        else:
            # tensor in -> tensor operands; existing tensors are used as they are
            if not isinstance(span, torch.Tensor):
                span = torch.tensor(span, device=normalized_coords.device)
            if not isinstance(offset, torch.Tensor):
                offset = torch.tensor(offset, device=normalized_coords.device)

        return (normalized_coords + 1) / 2 * span + offset
    
    @torch.no_grad()
    def sample_without_D_induced(self, timesteps=None):
        """Generate coordinates for all SC cells at once without using D_induced.

        Runs the reverse diffusion loop from Gaussian noise, conditioning the
        model on ``self.sc_gene_expr`` and on the embeddings returned by
        ``self.compute_sc_gnn_embeddings()``. Step ``t`` indexes
        ``self.noise_schedule``, so ``timesteps`` must not exceed the length of
        that schedule.

        Args:
            timesteps: Number of reverse steps; a falsy value selects
                ``self.n_timesteps``.

        Returns:
            NumPy array with one 2-D coordinate per SC cell. The result is
            returned as sampled, without calling ``denormalize_coordinates``.
        """
        print("Sampling coordinates for SC cells without structure guidance...")
        self.model.eval()

        # Compute GNN embeddings for SC data
        sc_gnn_embeddings = self.compute_sc_gnn_embeddings()

        # If no timesteps specified, use default
        timesteps = timesteps or self.n_timesteps

        # Start from random noise for all cells
        x = torch.randn(len(self.sc_gene_expr), 2, device=self.device)

        # Gradually denoise
        for t in tqdm(range(timesteps-1, -1, -1), desc="Sampling coordinates"):
            # Timestep fed to the model, scaled to [0, 1)
            time_tensor = torch.ones(len(x), 1, device=self.device) * t / timesteps

            # Predict noise using gene expression and GNN embeddings as conditioning
            # (attribute name fixed: it used to read the misspelled `sc_gene_emxpr`)
            pred_noise = self.model(x, time_tensor, self.sc_gene_expr, sc_gnn_embeddings)

            # Get parameters for this timestep
            alpha_t = self.noise_schedule['alphas'][t]
            alpha_cumprod_t = self.noise_schedule['alphas_cumprod'][t]
            beta_t = self.noise_schedule['betas'][t]

            # Fresh noise is injected at every step except the final one
            if t > 0:
                noise = torch.randn_like(x)
            else:
                noise = 0

            # Update sample with reverse diffusion step
            x = (1 / torch.sqrt(alpha_t)) * (
                x - ((1 - alpha_t) / torch.sqrt(1 - alpha_cumprod_t)) * pred_noise
            ) + torch.sqrt(beta_t) * noise

        return x.cpu().numpy()


    @torch.no_grad()
    def sample_coordinates(self, timesteps=None, batch_size=None):
        '''generate coords for the SC cells, all at once or in batches

        puts the model in eval mode, computes the sc gnn embeddings and hands
        over to one of the two samplers.

        args:
            timesteps: number of reverse diffusion steps; a falsy value selects
                self.n_timesteps
            batch_size: None samples every cell in one pass; otherwise cells are
                processed in chunks of this size

        Returns:
            whatever the selected sampler returns
        '''
        print('sampling coords for SC cells....')
        self.model.eval()

        # embeddings are computed once and forwarded to the chosen sampler
        embeddings = self.compute_sc_gnn_embeddings()
        n_steps = timesteps or self.n_timesteps

        # batched sampling only when the caller asks for it
        if batch_size is not None:
            return self._sample_batch_coordinates(n_steps, embeddings, batch_size)
        return self._sample_all_coordinates(n_steps, embeddings)
        
    @torch.no_grad()
    def _sample_all_coordinates(self, timesteps, sc_gnn_embeddings):
        '''sa,ple coordinates for all cells at once'''
        #start from random noise
        x = torch.randn(len(self.sc_gene_expr), 2, device=self.device)

        #gradual denoise
        for t in tqdm(range(timesteps-1, -1, -1), desc='sampling coordinates'):
            #create timestep tensor
            time_tensor = torch.ones(len(x), 1, device=self.device) * t / timesteps

            #predict noise using gene expression and gnn embeddings as condition
            # pred_noise = self.model(x, time_tensor, self.sc_gene_expr, sc_gnn_embeddings)
            pred_noise = self.model(x, time_tensor, self.sc_gene_expr)


            #get params for this timestep
            alpha_t = self.noise_schedule['alphas'][t]
            alpha_cumprod_t = self.noise_schedule['alphas_cumprod'][t]
            beta_t = self.noise_schedule['betas'][t]

            #apply noise 
            if t > 0:
                noise = torch.randn_like(x)
            else:
                noise = 0

            #update sample with reverse diffusion step
            x = (1/ torch.sqrt(alpha_t)) * (
                x - ((1-alpha_t) / torch.sqrt(1- alpha_cumprod_t)) * pred_noise
            ) + torch.sqrt(beta_t) * noise

            sc_coords = x.cpu().numpy()

            sc_coords = self.denormalize_coordinates(sc_coords)

            # #appl structure constraint after sometime
            # if t < timesteps * 0.8 and t % 3 == 0:
            #     #adjust coordinates to match D_induced
            #     #do periodically to avoid over correction
            #     x = self._adjust_to_match_distances(x, scale=0.05)

        return sc_coords
    
    @torch.no_grad()
    def _sample_batch_coordinates(self, timesteps, sc_gnn_embeddings, batch_size):
        '''sample coordinates in batches to save memory'''
        n_cells = len(self.sc_gene_expr)
        coordinates = np.zeros((n_cells, 2))

        #process in batches
        for start_idx in range(0, n_cells, batch_size):
            end_idx = min(start_idx+ batch_size, n_cells)
            batch_size_actual = end_idx - start_idx

            #get batch data
            batch_genes = self.sc_gene_expr[start_idx: end_idx]
            batch_gnn = sc_gnn_embeddings[start_idx: end_idx]

            #start from random noise
            x = torch.randn(batch_size_actual, 2, device=self.device)

            #gradually denoise
            for t in tqdm(range(timesteps-1, -1, -1), desc=f'sampling batch {start_idx//batch_size +1} / {(n_cells-1)// batch_size+1}'):
                #create timestep tensor
                time_tensor = torch.ones(batch_size_actual, 1, device=self.device) * t / timesteps

                #predict nosie using gene expression and gnn embeddinsg as conditioning
                # pred_noise = self.model(x, time_tensor, batch_genes, batch_gnn)
                pred_noise = self.model(x, time_tensor, batch_genes)

                #gete params for this timestep
                alpha_t = self.noise_schedule['alphas'][t]
                alpha_cumprod_t = self.noise_schedule['alphas_cumprod'][t]
                beta_t = self.noise_schedule['betas'][t]

                #apply noise 
                if t > 0:
                    noise = torch.randn_like(x)
                else:
                    noise = 0

                #update sample with reverse diffusion step
                x = (1/ torch.sqrt(alpha_t)) * (
                    x - ((1-alpha_t)/ torch.sqrt(1- alpha_cumprod_t)) * pred_noise 
                ) + torch.sqrt(beta_t) * noise

            #store batch results
            coordinates[start_idx: end_idx] = x.cpu().numpy()

        return coordinates
        

    def _adjust_to_match_distances(self, coords, scale=0.05):
        '''adjust coordinates to better match D_induced during sampling'''
        #compute cuurent distances
        current_dist = torch.cdist(coords, coords)

        #process in batches for large matrices
        if self.D_induced.shape[0] > 10000:
            current_dist_norm = current_dist / current_dist.max()
            target_dist_norm = self.D_induced / self.D_induced.max()


            #calculate gradient direction (in batches if needed)
            batch_size = 512
            n_coords = coords.shape[0]
            grad_sum = torch.zeros_like(coords)

            for i in range(0, n_coords, batch_size):
                end_i = min(i + batch_size, n_coords)

                for j in range(0, n_coords, batch_size):
                    end_j = min(j+batch_size, n_coords)

                    #extract submatrices
                    curr_dist_sub = current_dist_norm[i:end_i, j:end_j]
                    # target_dist_sub = self.D_induced[i:end_i, j:end_j]
                    target_dist_sub = target_dist_norm[i:end_i, j:end_j]


                    #compute difference
                    diff_sub = curr_dist_sub - target_dist_sub

                    #accumulate gradients for coordinates i:end_i
                    #this compues how each coordinate should move to better match the target
                    for idx in range(i, end_i):
                        rel_idx = idx - i
                        #calculate direction vectors from this point to all others
                        if idx < coords.shape[0]:
                            directions = coords[idx].unsqueeze(0) - coords
                            #normalize directions
                            norms = torch.norm(directions, dim=1, keepdim=True)
                            normalized_dirs = directions / (norms + 1e-8)
                            #weight by distance differences
                            row_diff = diff_sub[rel_idx] if rel_idx < diff_sub.shape[0] else torch.zeros(diff_sub.shape[1], device=self.device)
                            weighted_dirs = normalized_dirs * row_diff.unsqueeze(1)
                            #sum contributions
                            grad_sum[idx] += weighted_dirs.sum(dim=0)

            else:
                #for smaller datasets we can do this in one go
                #normalize distances
                current_dist_norm = current_dist / current_dist.max()
                target_dist_norm = self.D_induced / self.D_induced.max()

                #compute difference
                diff = current_dist_norm - target_dist_norm

                #rehsape for broadcasting
                diff_expanded = diff.unsqueeze(-1) #[n, n, 1]

                #compute direction vectors between all pairs
                coords_expnaded_i = coords.unsqueeze(1) #[n, 1, 2]
                coords_expnaded_j = coords.unsqueeze(0) #[1, n, 2]
                directions = coords_expnaded_i - coords_expnaded_j #[n, n, 2]

                #normalize directions
                norms = torch.norm(directions, dim=2, keepdim=True)
                normalized_dirs = directions / (norms + 1e-8)

                #weigh by distance difference
                weighted_dirs = normalized_dirs * diff_expanded

                #sum along j dimension to get gradient for each point
                grad_sum = weighted_dirs.sum(dim=1)

            #apply gradients with scaling
            coords_adj = coords - scale * grad_sum

            return coords_adj

In [None]:
# Display D_induced, the matrix later passed to CoordinateDiffusion(D_induced=...).
# NOTE(review): it is defined in an earlier cell; presumably a cell-by-cell distance matrix — confirm.
D_induced

In [None]:
# Build and train the coordinate diffusion model.

# CoordinateDiffusion is handed NumPy inputs here, so copy any torch tensors to host memory first.
if isinstance(Y_st, torch.Tensor):
    Y_st = Y_st.cpu().numpy()
if isinstance(X_st, torch.Tensor):
    X_st = X_st.cpu().numpy()
if isinstance(X_sc, torch.Tensor):
    X_sc = X_sc.cpu().numpy()
if isinstance(D_induced, torch.Tensor):
    D_induced = D_induced.cpu().numpy()

# Noise-schedule settings shared by the constructor and train(). train() builds its own
# schedule from the values it is given, while the samplers read diffusion.noise_schedule
# (presumably built by the constructor from the values below), so both must agree.
N_TIMESTEPS = 800   # more timesteps than the train() default of 500
BETA_START = 1e-4   # standard DDPM values
BETA_END = 0.02

# Fall back to the CPU when no GPU is available instead of failing on "cuda".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize the diffusion model
diffusion = CoordinateDiffusion(
    st_gene_expr=X_st,
    st_coords=Y_st,
    sc_gene_expr=X_sc,
    D_induced=D_induced,
    device=DEVICE,
    n_timesteps=N_TIMESTEPS,
    beta_start=BETA_START,
    beta_end=BETA_END,
    gnn_dim=128  # Dimension for GNN embeddings
)

print("Pretraining GNN model...")
# Pretrain the GNN model to get better embeddings
diffusion.pretrain_gnn(epochs=500)

print("Training diffusion model...")

# Alternative two-phase recipe (currently disabled in favour of single-stage training):
# First phase - coordinate denoising with structure preservation
# diffusion.pretrain_denoising(
#     n_epochs=2000,     # More epochs for better convergence
#     batch_size=192,    # Larger batch size for stability
#     lambda_structure=1.0  # Weight for structure preservation loss
# )

# # Second phase - conditional training with gene expression and GNN embeddings
# diffusion.train_conditional(
#     n_epochs=3000,     # More epochs for conditional training
#     batch_size=192,    # Keep batch size consistent
#     lambda_structure=5.0  # Increased weight for structure in second phase
# )

# Single-stage training with structural constraints
diffusion.train(
    n_epochs=3000,
    batch_size=256,
    timesteps=N_TIMESTEPS,
    beta1=BETA_START,
    beta2=BETA_END,
    lambda_start=10,  # structure-loss weight at the first epoch
    lambda_end=50,    # weight reached after 80% of the epochs (linear ramp in between)
    save_every=500,
    checkpoint_dir="./checkpoints"
)



In [None]:
# Sample SC coordinates with the plain reverse-diffusion sampler (no D_induced guidance).
# The returned array is not passed through denormalize_coordinates, and sc_coords is
# reassigned by the sample_coordinates cell that follows.
sc_coords = diffusion.sample_without_D_induced(timesteps=800)

In [None]:
# Sample spatial coordinates for the SC cells with diffusion.sample_coordinates.
# Up to 11000 cells are sampled in a single pass (batch_size=None); larger datasets
# are processed in chunks of 11000 cells to save memory.
max_cells_per_pass = 11000
needs_batching = X_sc.shape[0] > max_cells_per_pass

sc_coords = diffusion.sample_coordinates(
    timesteps=800,
    batch_size=max_cells_per_pass if needs_batching else None,
)

In [None]:
import matplotlib.pyplot as plt
import scanpy as sc

# Attach the sampled coordinates to the AnnData object as a custom embedding.
# (earlier variant: sc_coords = diffusion.sample_sc_coordinates(timesteps=1000, lambda_structure=2.0))
adata.obsm['diffusion_coords'] = sc_coords

sc.settings.set_figure_params(figsize=(10, 10))

# Scatter the cells in the sampled coordinate frame, coloured by cell type.
embedding_options = dict(
    basis='diffusion_coords',
    color='celltype_mapped_refined',
    size=75,
    title='SC spatial coordinates (Diffusion Model)',
    palette='tab20',
    legend_loc='right margin',
    legend_fontsize=10,
)
sc.pl.embedding(adata, **embedding_options)

In [None]:
# Make sure every input lives in host memory as a NumPy array
# (values that are not torch tensors are left untouched).
Y_st = Y_st.cpu().numpy() if isinstance(Y_st, torch.Tensor) else Y_st
X_st = X_st.cpu().numpy() if isinstance(X_st, torch.Tensor) else X_st
X_sc = X_sc.cpu().numpy() if isinstance(X_sc, torch.Tensor) else X_sc
D_induced = D_induced.cpu().numpy() if isinstance(D_induced, torch.Tensor) else D_induced

# Run the GNN analysis on the SC expression matrix and the induced distance matrix
results = run_gnn_analysis(X_sc, D_induced)