In [32]:
import os

# Process-level settings. The BLAS/OpenMP thread variables are typically only
# honoured if they are set before numpy/scipy load their native libraries, so
# this cell must run before any numerical import.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "1",   # expose only GPU index 1 to this kernel
    "OMP_NUM_THREADS": "1",        # single-threaded OpenMP
    "OPENBLAS_NUM_THREADS": "1",   # single-threaded OpenBLAS
})

In [33]:
# Cap BLAS thread pools to a single thread.
# threadpoolctl can only limit native BLAS libraries that are already loaded in
# the process, so numpy/scipy (which load OpenBLAS/MKL) are imported *before*
# calling threadpool_limits. Calling it first would find nothing to limit.
import numpy as np
import scipy
import scipy.linalg  # noqa: F401 -- loads SciPy's BLAS (if bundled separately) so it is capped too
from threadpoolctl import threadpool_limits

# 'blas' covers OpenBLAS, MKL, etc. Keep the handle so the original limits can
# be restored later with `blas_limiter.restore_original_limits()` if needed.
blas_limiter = threadpool_limits(limits=1, user_api='blas')


In [34]:
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import scanpy as sc
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.neighbors import kneighbors_graph
from scipy.sparse.csgraph import dijkstra
from scipy.sparse import csr_matrix, issparse
from sklearn.preprocessing import normalize
import ot 
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
import pandas as pd

In [35]:
from model import AdvancedHierarchicalDiffusion

# patient 2 data load

In [36]:
def load_and_process_cscc_data(data_dir='/home/ehtesamul/sc_st/data/cSCC/processed'):
    """
    Load the patient-2 cSCC SC data and its three ST replicates, log-normalize
    them, and add a coarse ``rough_celltype`` annotation to the SC data.

    Parameters
    ----------
    data_dir : str, optional
        Directory containing ``scP2.h5ad``, ``stP2.h5ad``, ``stP2rep2.h5ad`` and
        ``stP2rep3.h5ad``. The default is the previously hard-coded location,
        so existing zero-argument calls behave exactly as before.

    Returns
    -------
    scadata, stadata1, stadata2, stadata3
        AnnData objects whose ``.X`` has been replaced in place by
        ``sc.pp.normalize_total`` followed by ``sc.pp.log1p``. ``scadata.obs``
        gains a ``rough_celltype`` column.
    """
    import os

    print("Loading cSCC data...")

    def _read(fname):
        """Read one .h5ad file from ``data_dir``."""
        return sc.read_h5ad(os.path.join(data_dir, fname))

    # SC reference and the three ST replicates
    scadata = _read('scP2.h5ad')
    stadata1 = _read('stP2.h5ad')
    stadata2 = _read('stP2rep2.h5ad')
    stadata3 = _read('stP2rep3.h5ad')

    # Library-size normalization + log1p, applied to every object in place
    for adata in [scadata, stadata1, stadata2, stadata3]:
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)

    # Coarse cell types: start from level1, collapse selected level1 labels,
    # then let the listed level2 labels override. Labels not mentioned keep
    # their level1 name. Order matters: level2 overrides are applied last.
    level1_to_rough = {
        'CLEC9A': 'DC',
        'CD1C': 'DC',
        'ASDC': 'DC',
        'PDC': 'PDC',
        'MDSC': 'DC',
        'LC': 'DC',
        'Mac': 'Myeloid cell',
        'Tcell': 'T cell',
    }
    scadata.obs['rough_celltype'] = scadata.obs['level1_celltype'].astype(str)
    for level1_name, rough_name in level1_to_rough.items():
        scadata.obs.loc[scadata.obs['level1_celltype'] == level1_name, 'rough_celltype'] = rough_name
    scadata.obs.loc[scadata.obs['level2_celltype'] == 'TSK', 'rough_celltype'] = 'TSK'
    scadata.obs.loc[
        scadata.obs['level2_celltype'].isin(['Tumor_KC_Basal', 'Tumor_KC_Diff', 'Tumor_KC_Cyc']),
        'rough_celltype',
    ] = 'NonTSK'

    return scadata, stadata1, stadata2, stadata3

def prepare_combined_st_for_diffusion(stadata1, stadata2, stadata3, scadata):
    """
    Build gene-aligned, standardized SC and pooled-ST inputs for diffusion training.

    All three ST replicates are stacked into a single training set so every
    spot is used. Only genes present in all four objects are kept, in sorted order.

    Standardization is done independently per modality. ST expression is
    z-scored per gene over all pooled ST spots. SC expression is z-scored per
    gene over SC cells, using its own statistics rather than the ST ones.
    NOTE(review): if SC is meant to live in the ST reference frame, the SC
    matrix should use the ST scaler's ``transform`` instead -- confirm intent.

    Returns
    -------
    X_sc : torch.Tensor (float32)
        Standardized SC expression, cells x common genes.
    X_st_combined : torch.Tensor (float32)
        Standardized expression of the stacked ST replicates, spots x common genes.
    Y_st_combined : np.ndarray (float32)
        Stacked ``obsm['spatial']`` coordinates, row-aligned with ``X_st_combined``.
    dataset_labels : list of str
        ``'dataset1'`` / ``'dataset2'`` / ``'dataset3'`` for each stacked ST row.
    common_genes : list of str
        Sorted genes shared by SC and all ST replicates.
    st_coords_list : list
        Per-replicate ``obsm['spatial']`` arrays (not stacked).
    """
    from sklearn.preprocessing import StandardScaler

    print("Preparing combined ST data for diffusion training...")

    st_datasets = [stadata1, stadata2, stadata3]

    shared_genes = set(scadata.var_names)
    for st in st_datasets:
        shared_genes &= set(st.var_names)
    common_genes = sorted(shared_genes)
    print(f"Common genes across all datasets: {len(common_genes)}")

    def _dense(matrix):
        """Densify scipy-sparse matrices; pass dense arrays through."""
        return matrix.toarray() if hasattr(matrix, 'toarray') else matrix

    sc_expr = _dense(scadata[:, common_genes].X)
    st_expr_list = [_dense(st[:, common_genes].X) for st in st_datasets]

    # Raw coordinates, kept per replicate (e.g. for block-diagonal graphs) and stacked
    st_coords_list = [st.obsm['spatial'] for st in st_datasets]
    st_coords_combined = np.vstack(st_coords_list)

    # Independent per-gene z-scoring for each modality (see docstring)
    st_expr_combined = StandardScaler().fit_transform(np.vstack(st_expr_list))
    sc_expr = StandardScaler().fit_transform(sc_expr)

    # Replicate tag for every stacked ST row
    dataset_labels = [
        f'dataset{k}'
        for k, expr in enumerate(st_expr_list, start=1)
        for _ in range(len(expr))
    ]

    print(f"Combined ST data shape: {st_expr_combined.shape}")
    print(f"Combined ST coords shape: {st_coords_combined.shape}")
    print(f"SC data shape: {sc_expr.shape}")

    X_sc = torch.tensor(sc_expr, dtype=torch.float32)
    X_st_combined = torch.tensor(st_expr_combined, dtype=torch.float32)
    Y_st_combined = st_coords_combined.astype(np.float32)

    return X_sc, X_st_combined, Y_st_combined, dataset_labels, common_genes, st_coords_list

# Load and preprocess all datasets, then build the gene-aligned, pooled-ST training inputs.
scadata, stadata1, stadata2, stadata3 = load_and_process_cscc_data()

(X_sc, X_st_combined, Y_st_combined,
 dataset_labels, common_genes, st_coords_list) = prepare_combined_st_for_diffusion(
    stadata1, stadata2, stadata3, scadata
)

print("Data preparation complete!")
print(f"SC cells: {len(X_sc)}")
print(f"Combined ST spots: {len(X_st_combined)}")
print(f"Common genes: {len(common_genes)}")



Loading cSCC data...
Preparing combined ST data for diffusion training...
Common genes across all datasets: 2000
Combined ST data shape: (1950, 2000)
Combined ST coords shape: (1950, 2)
SC data shape: (2688, 2000)
Data preparation complete!
SC cells: 2688
Combined ST spots: 1950
Common genes: 2000


In [37]:
def load_and_process_cscc_data_individual_norm(data_dir='/home/ehtesamul/sc_st/data/cSCC/processed'):
    """
    Load the patient-2 cSCC SC data and its three ST replicates for the
    "individually normalized" pipeline.

    Expression preprocessing here is the same for every object:
    ``normalize_total`` then ``log1p``, applied in place. The per-replicate
    ("individual") normalization of this pipeline is applied to spatial
    coordinates later, in ``prepare_individually_normalized_st_data``. The SC
    data also gains a coarse ``rough_celltype`` annotation in ``.obs``.

    Parameters
    ----------
    data_dir : str, optional
        Directory containing ``scP2.h5ad``, ``stP2.h5ad``, ``stP2rep2.h5ad`` and
        ``stP2rep3.h5ad``. Defaults to the previously hard-coded location.

    Returns
    -------
    scadata, stadata1, stadata2, stadata3
        The preprocessed AnnData objects.
    """
    import os

    print("Loading cSCC data with individual normalization...")

    def _read(fname):
        """Read one .h5ad file from ``data_dir``."""
        return sc.read_h5ad(os.path.join(data_dir, fname))

    # SC reference and the three ST replicates
    scadata = _read('scP2.h5ad')
    stadata1 = _read('stP2.h5ad')
    stadata2 = _read('stP2rep2.h5ad')
    stadata3 = _read('stP2rep3.h5ad')

    # Expression normalization is identical for all objects
    for adata in [scadata, stadata1, stadata2, stadata3]:
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)

    # Coarse cell types: start from level1, collapse selected level1 labels,
    # then let the listed level2 labels override (applied last, so they win).
    level1_to_rough = {
        'CLEC9A': 'DC',
        'CD1C': 'DC',
        'ASDC': 'DC',
        'PDC': 'PDC',
        'MDSC': 'DC',
        'LC': 'DC',
        'Mac': 'Myeloid cell',
        'Tcell': 'T cell',
    }
    scadata.obs['rough_celltype'] = scadata.obs['level1_celltype'].astype(str)
    for level1_name, rough_name in level1_to_rough.items():
        scadata.obs.loc[scadata.obs['level1_celltype'] == level1_name, 'rough_celltype'] = rough_name
    scadata.obs.loc[scadata.obs['level2_celltype'] == 'TSK', 'rough_celltype'] = 'TSK'
    scadata.obs.loc[
        scadata.obs['level2_celltype'].isin(['Tumor_KC_Basal', 'Tumor_KC_Diff', 'Tumor_KC_Cyc']),
        'rough_celltype',
    ] = 'NonTSK'

    return scadata, stadata1, stadata2, stadata3

def normalize_coordinates_individually(coords):
    """
    Min-max rescale every coordinate axis of ``coords`` onto [-1, 1].

    Axes with zero extent get a divisor of 1.0 instead of 0, so their values
    all map to -1 rather than producing NaN.

    Returns
    -------
    coords_normalized, coords_min, coords_max, coords_range
        ``coords_range`` is the per-axis divisor actually used (with the 1.0
        substitution applied). ``(x + 1) / 2 * coords_range + coords_min``
        therefore maps normalized values back to the input coordinates.
    """
    lo = coords.min(axis=0)
    hi = coords.max(axis=0)
    span = hi - lo

    # Guard against 0/0 on constant axes
    span[span == 0] = 1.0

    return 2 * (coords - lo) / span - 1, lo, hi, span

def prepare_individually_normalized_st_data(stadata1, stadata2, stadata3, scadata):
    """
    Gene-align SC and ST data and rescale each ST replicate's coordinates on its own.

    Each replicate's ``obsm['spatial']`` is min-max mapped to [-1, 1] with its
    own per-axis min/range before stacking, so replicates with different
    absolute coordinate offsets share one normalized frame. Expression values
    are taken from ``.X`` as-is; no standardization is applied here.

    Returns
    -------
    X_sc : torch.Tensor (float32)
        SC expression restricted to the common genes.
    X_st_combined : torch.Tensor (float32)
        Stacked ST expression of the three replicates.
    Y_st_combined : np.ndarray (float32)
        Stacked, individually normalized ST coordinates.
    dataset_info : dict
        ``'labels'`` holds the replicate tag per stacked row.
        ``'normalization_params'`` maps ``'datasetK'`` to that replicate's
        ``{'min', 'max', 'range'}``, which is needed to undo the scaling.
    common_genes : list of str
        Sorted genes shared by SC and all ST replicates.
    """
    print("Preparing individually normalized ST data...")

    st_datasets = [stadata1, stadata2, stadata3]

    shared_genes = set(scadata.var_names)
    for st in st_datasets:
        shared_genes &= set(st.var_names)
    common_genes = sorted(shared_genes)
    print(f"Common genes across all datasets: {len(common_genes)}")

    def _dense(matrix):
        """Densify scipy-sparse matrices; pass dense arrays through."""
        return matrix.toarray() if hasattr(matrix, 'toarray') else matrix

    sc_expr = _dense(scadata[:, common_genes].X)
    st_expr_list = [_dense(st[:, common_genes].X) for st in st_datasets]

    print("Normalizing coordinates individually...")
    coords_norm_list = []
    normalization_params = {}
    for k, st in enumerate(st_datasets, start=1):
        coords_norm, c_min, c_max, c_range = normalize_coordinates_individually(st.obsm['spatial'])
        coords_norm_list.append(coords_norm)
        normalization_params[f'dataset{k}'] = {'min': c_min, 'max': c_max, 'range': c_range}

    for k, coords_norm in enumerate(coords_norm_list, start=1):
        print(f"ST{k} coord range: [{coords_norm.min():.3f}, {coords_norm.max():.3f}]")

    st_expr_combined = np.vstack(st_expr_list)
    st_coords_combined = np.vstack(coords_norm_list)

    dataset_info = {
        'labels': [
            f'dataset{k}'
            for k, expr in enumerate(st_expr_list, start=1)
            for _ in range(len(expr))
        ],
        'normalization_params': normalization_params,
    }

    print(f"Combined ST data shape: {st_expr_combined.shape}")
    print(f"Combined ST coords shape: {st_coords_combined.shape}")
    print(f"SC data shape: {sc_expr.shape}")

    X_sc = torch.tensor(sc_expr, dtype=torch.float32)
    X_st_combined = torch.tensor(st_expr_combined, dtype=torch.float32)
    Y_st_combined = st_coords_combined.astype(np.float32)

    return X_sc, X_st_combined, Y_st_combined, dataset_info, common_genes

In [38]:
# Reload all datasets via the loader for the individually-normalized pipeline, then pick the compute device.
scadata, stadata1, stadata2, stadata3 = load_and_process_cscc_data_individual_norm()
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

Loading cSCC data with individual normalization...


In [39]:
def train_individual_advanced_diffusion_models(scadata, stadata1, stadata2, stadata3):
    """
    Train one AdvancedHierarchicalDiffusion model per ST replicate and aggregate
    the SC coordinates the models generate.

    Each of ``stadata1``, ``stadata2`` and ``stadata3`` is used for exactly one
    run (``run1``..``run3``), seeded with ``42 + run_index``. After each run,
    the generated SC coordinates are plotted by angle in a frame built from
    that model's ``st_coords_norm``. A comparative plot across all runs is drawn
    at the end.

    Parameters
    ----------
    scadata : AnnData
        SC data; must have ``obs['rough_celltype']``. Modified in place: the
        per-run (``advanced_diffusion_coords_rep{k}``), mean
        (``..._avg``) and standard-deviation (``..._std``) coordinates are
        written to ``obsm``.
    stadata1, stadata2, stadata3 : AnnData
        ST replicates with ``obsm['spatial']``.

    Returns
    -------
    scadata, models_all
        The annotated ``scadata`` and the trained models, in run order.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    cell_types = scadata.obs['rough_celltype'].values

    def _as_numpy(arr):
        """Return ``arr`` as a host-side NumPy array (detaching and moving torch tensors)."""
        if torch.is_tensor(arr):
            return arr.detach().cpu().numpy()
        return np.asarray(arr)

    sc_coords_results = []  # one NumPy array of generated SC coordinates per run
    models_all = []

    # One run per ST replicate; the run index also offsets the random seed.
    st_datasets = [
        (stadata1, "run1"),
        (stadata2, "run2"),
        (stadata3, "run3"),
    ]
    n_runs = len(st_datasets)

    for i, (stadata, run_name) in enumerate(st_datasets):
        print(f"\n{'='*50}")
        print(f"Training AdvancedHierarchicalDiffusion model {i+1}/{n_runs} for {run_name}")
        print(f"{'='*50}")

        # Genes shared by SC and this replicate
        common_genes = sorted(set(scadata.var_names) & set(stadata.var_names))
        print(f"Common genes for {run_name}: {len(common_genes)}")

        sc_expr = scadata[:, common_genes].X
        st_expr = stadata[:, common_genes].X
        if hasattr(sc_expr, 'toarray'):
            sc_expr = sc_expr.toarray()
        if hasattr(st_expr, 'toarray'):
            st_expr = st_expr.toarray()

        st_coords = stadata.obsm['spatial']

        print(f"SC data shape: {sc_expr.shape}")
        print(f"ST data shape: {st_expr.shape}")
        print(f"ST coords shape: {st_coords.shape}")

        # Different but reproducible seed per run
        torch.manual_seed(42 + i)
        np.random.seed(42 + i)

        model = AdvancedHierarchicalDiffusion(
            st_gene_expr=st_expr,
            st_coords=st_coords,
            sc_gene_expr=sc_expr,
            cell_types_sc=cell_types,
            transport_plan=None,
            D_st=None,
            D_induced=None,
            n_genes=len(common_genes),
            n_embedding=[512, 256, 128],
            coord_space_diameter=2.00,
            sigma=0.75,
            alpha=0.8,
            mmdbatch=1000,
            batch_size=256,
            device=device,
            lr_e=0.0001,
            lr_d=0.0002,
            n_timesteps=800,
            n_denoising_blocks=6,
            hidden_dim=256,
            num_heads=8,
            num_hierarchical_scales=3,
            dp=0.2,
            outf=f'advanced_diffusion_{run_name}'
        )

        print(f"Training model for {run_name}...")
        model.train(
            encoder_epochs=1000,
            vae_epochs=1200,
            diffusion_epochs=2500,
            lambda_struct=10.0
        )

        # Angular frame from this model's normalized ST coordinates
        # (presumably the space SC coordinates are generated in -- confirm in model.py).
        st_coords_raw = _as_numpy(model.st_coords_norm)
        angular_frame = _build_canonical_angular_frame(st_coords_raw)

        print(f"Generating SC coordinates using {run_name} model...")
        sc_coords = model.sample_sc_coordinates_with_geometry_guidance(
            batch_size=512,
            guidance_scale=1.0,
            geometry_guidance=False,
            k_nn=5,
            lambda_trip=0.05,
            lambda_rep=0.08,
            gamma=0.05
        )
        # Printing, plotting (matplotlib / NumPy ufuncs) and averaging all need a
        # host-side array; a (possibly CUDA) tensor would fail in several of them.
        sc_coords_np = _as_numpy(sc_coords)

        print(f"\n=== Generated SC Coordinates ({run_name}) ===")
        print(f"  X range: [{sc_coords_np[:, 0].min():.3f}, {sc_coords_np[:, 0].max():.3f}]")
        print(f"  Y range: [{sc_coords_np[:, 1].min():.3f}, {sc_coords_np[:, 1].max():.3f}]")

        # The model receives its own sampler output unchanged; the returned metrics are not used further here.
        model.evaluate_geometry_preservation(sc_coords)

        sc_coords_results.append(sc_coords_np)
        models_all.append(model)

        # Per-run SC-by-angle plot in this run's ST-derived frame
        _plot_sc_angle_analysis(sc_coords_np, cell_types,
                                angular_frame, st_coords_raw, run_name, i+1)

    # NOTE(review): the comparative plot measures every run in the angular frame
    # (and ST background) of the LAST run only, although each run has its own
    # normalized ST frame. Cross-run angle differences may partly reflect that.
    _plot_comparative_sc_angle_analysis(sc_coords_results, cell_types,
                                        angular_frame, st_coords_raw)

    scadata.obsm['advanced_diffusion_coords_avg'] = np.mean(sc_coords_results, axis=0)
    scadata.obsm['advanced_diffusion_coords_std'] = np.std(sc_coords_results, axis=0)

    for i, coords in enumerate(sc_coords_results):
        scadata.obsm[f'advanced_diffusion_coords_rep{i+1}'] = coords

    print("\nTraining complete. Results stored in scadata.obsm")
    return scadata, models_all

def _build_canonical_angular_frame(st_coords):
    """Build canonical angular frame from ST coordinates (dataset-specific, run-independent)"""
    import numpy as np
    
    # Compute centroid
    centroid = st_coords.mean(axis=0)
    
    # Find farthest spot from centroid (deterministic 0° direction)
    distances = np.linalg.norm(st_coords - centroid, axis=1)
    farthest_idx = np.argmax(distances)
    a0 = st_coords[farthest_idx] - centroid  # 0° direction vector
    
    def angle_fn(x):
        """Compute angle from canonical frame"""
        if x.ndim == 1:
            x = x.reshape(1, -1)
        
        v = x - centroid
        cross = a0[0] * v[:, 1] - a0[1] * v[:, 0]  # z-component of 2D cross
        dot = a0[0] * v[:, 0] + a0[1] * v[:, 1]
        angles = np.arctan2(cross, dot)
        angles = np.where(angles < 0, angles + 2*np.pi, angles)  # Map to [0, 2π)
        return angles
    
    return {
        'centroid': centroid,
        'zero_direction': a0,
        'farthest_idx': farthest_idx,
        'angle_fn': angle_fn
    }

def _plot_sc_angle_analysis(sc_coords, cell_types, angular_frame, st_coords_bg, run_name, run_num):
    """
    Visualise one run's SC placement in terms of angle in an ST-derived frame.

    Left panel: ST spots in grey with SC cells coloured by angle (degrees, hsv
    colormap), the frame centroid as a black cross, and a red arrow along 0.3x
    the 0° direction. Right panel: per-cell-type angle histograms (36 bins,
    density-normalized). The figure is saved as
    ``sc_angle_analysis_{run_name}.png`` (300 dpi) and shown. Afterwards the
    circular mean angle of every cell type with more than 5 cells is printed.

    ``sc_coords`` and ``st_coords_bg`` are indexed as (n, 2) arrays.
    ``cell_types`` must support element-wise ``==`` (e.g. a NumPy array).
    ``run_num`` is accepted but not used.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    theta = angular_frame['angle_fn'](sc_coords)
    theta_deg = np.degrees(theta)

    fig, (map_ax, hist_ax) = plt.subplots(1, 2, figsize=(15, 6))

    # --- Left panel: spatial map ---
    map_ax.scatter(st_coords_bg[:, 0], st_coords_bg[:, 1],
                   c='lightgray', s=10, alpha=0.3, label='ST outline')
    sc_points = map_ax.scatter(sc_coords[:, 0], sc_coords[:, 1],
                               c=theta_deg, cmap='hsv', s=30, alpha=0.8)

    centroid = angular_frame['centroid']
    zero_dir = angular_frame['zero_direction']
    map_ax.scatter(centroid[0], centroid[1], c='black', s=100, marker='x', linewidth=3)
    map_ax.arrow(centroid[0], centroid[1], zero_dir[0]*0.3, zero_dir[1]*0.3,
                 head_width=0.05, head_length=0.05, fc='red', ec='red', linewidth=2)

    map_ax.set_title(f'{run_name}: SC Cells Colored by Angle θ')
    map_ax.set_xlabel('X coordinate')
    map_ax.set_ylabel('Y coordinate')
    map_ax.set_aspect('equal')
    map_ax.legend()

    colorbar = plt.colorbar(sc_points, ax=map_ax)
    colorbar.set_label('Angle (degrees)')

    # --- Right panel: angle distribution per cell type ---
    unique_types = np.unique(cell_types)
    palette = plt.cm.tab10(np.linspace(0, 1, len(unique_types)))

    for cell_type, color in zip(unique_types, palette):
        members = cell_types == cell_type
        if np.sum(members) > 0:
            hist_ax.hist(theta_deg[members], bins=36, alpha=0.6, label=cell_type,
                         color=color, density=True)

    hist_ax.set_title(f'{run_name}: Angle Distribution by Cell Type')
    hist_ax.set_xlabel('Angle (degrees)')
    hist_ax.set_ylabel('Density')
    hist_ax.set_xlim(0, 360)
    hist_ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.savefig(f'sc_angle_analysis_{run_name}.png', dpi=300, bbox_inches='tight')
    plt.show()

    # --- Circular mean angle per cell type ---
    print(f"\n{run_name} - Circular statistics per cell type:")
    for cell_type in unique_types:
        members = cell_types == cell_type
        count = np.sum(members)
        if not count > 5:  # too few cells for a meaningful mean
            continue
        angles_rad = theta[members]
        mean_angle = np.arctan2(np.mean(np.sin(angles_rad)), np.mean(np.cos(angles_rad)))
        if mean_angle < 0:
            mean_angle += 2*np.pi
        print(f"  {cell_type}: mean={np.degrees(mean_angle):.1f}°, n={count}")

def _plot_comparative_sc_angle_analysis(sc_coords_list, cell_types, angular_frame, st_coords_bg):
    """
    Compare SC angular placement across runs and report per-cell-type "sector sliding".

    Draws a 2 x n_runs grid and saves it as ``comparative_sc_angle_analysis.png``
    before showing it:
    - top row: SC cells coloured by angle over the ST background;
    - bottom row: per-cell-type angle histograms (cell types with more than 5 cells).
    Then, for every cell type with more than 10 cells, it prints each run's
    circular mean angle and the largest wrap-aware difference between any two
    runs, flagging differences above 30°.

    All runs are measured in the single ``angular_frame`` passed in.
    ``cell_types`` must support element-wise ``==`` (e.g. a NumPy array).
    """
    import matplotlib.pyplot as plt
    import numpy as np

    def _circular_mean_deg(angles_rad):
        """Circular mean of angles given in radians, returned in degrees (non-negative)."""
        mean_angle = np.arctan2(np.mean(np.sin(angles_rad)), np.mean(np.cos(angles_rad)))
        if mean_angle < 0:
            mean_angle += 2*np.pi
        return np.degrees(mean_angle)

    def _max_circular_spread_deg(angles_deg):
        """Largest wrap-aware separation (degrees, 0..180) between any two angles."""
        spread = 0.0
        for a_idx, a in enumerate(angles_deg):
            for b in angles_deg[a_idx + 1:]:
                d = abs(a - b) % 360.0
                spread = max(spread, min(d, 360.0 - d))
        return spread

    n_runs = len(sc_coords_list)
    unique_types = np.unique(cell_types)
    # Angles (radians) of every SC cell per run; computed once and reused below.
    run_angles = [angular_frame['angle_fn'](sc_coords) for sc_coords in sc_coords_list]

    fig, axes = plt.subplots(2, n_runs, figsize=(5*n_runs, 10))
    if n_runs == 1:
        axes = axes.reshape(-1, 1)

    # Top row: SC scatter plots per run, coloured by angle
    for i, (sc_coords, sc_angles) in enumerate(zip(sc_coords_list, run_angles)):
        ax = axes[0, i]
        ax.scatter(st_coords_bg[:, 0], st_coords_bg[:, 1],
                   c='lightgray', s=5, alpha=0.3)
        scatter = ax.scatter(sc_coords[:, 0], sc_coords[:, 1],
                             c=np.degrees(sc_angles), cmap='hsv', s=20, alpha=0.8)
        ax.set_title(f'Run {i+1}: SC Cells by Angle')
        ax.set_aspect('equal')

        if i == n_runs-1:  # one shared colorbar on the last panel
            cbar = plt.colorbar(scatter, ax=ax)
            cbar.set_label('Angle (degrees)')

    # Bottom row: cell-type angle distributions per run
    colors = plt.cm.tab10(np.linspace(0, 1, len(unique_types)))

    for i, sc_angles in enumerate(run_angles):
        ax = axes[1, i]
        sc_angles_degrees = np.degrees(sc_angles)

        for j, cell_type in enumerate(unique_types):
            mask = cell_types == cell_type
            if np.sum(mask) > 5:
                ax.hist(sc_angles_degrees[mask], bins=36, alpha=0.6,
                        label=cell_type if i == 0 else "",
                        color=colors[j], density=True)

        ax.set_title(f'Run {i+1}: Cell Type Angles')
        ax.set_xlabel('Angle (degrees)')
        ax.set_ylabel('Density')
        ax.set_xlim(0, 360)

        if i == 0:
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.savefig('comparative_sc_angle_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Sector-sliding check: does a cell type's mean angle move between runs?
    print("\n" + "="*60)
    print("SECTOR SLIDING ANALYSIS")
    print("="*60)

    for cell_type in unique_types:
        mask = cell_types == cell_type
        if np.sum(mask) <= 10:  # only analyze cell types with enough cells
            continue

        circular_means = [_circular_mean_deg(sc_angles[mask]) for sc_angles in run_angles]

        # Largest pairwise circular distance. A plain max-min folded once at
        # 180° under-reports spread when 3+ means straddle 0°/360°
        # (e.g. 10°, 100°, 350° would give 20° instead of 110°).
        max_diff = _max_circular_spread_deg(circular_means)

        print(f"{cell_type}:")
        print(f"  Run means: {[f'{m:.1f}°' for m in circular_means]}")
        print(f"  Max difference: {max_diff:.1f}°")

        if max_diff > 30:  # significant sliding
            print(f"  ⚠️  SECTOR SLIDING DETECTED!")
        else:
            print(f"  ✅ Consistent placement")

# Load and preprocess all datasets
scadata, stadata1, stadata2, stadata3 = load_and_process_cscc_data()

# Raw spatial extent of each ST replicate (sanity check before training)
for i, stdata in enumerate([stadata1, stadata2, stadata3], 1):
    coords = stdata.obsm['spatial']
    print(f"ST{i}: X[{coords[:, 0].min():.2f}, {coords[:, 0].max():.2f}], Y[{coords[:, 1].min():.2f}, {coords[:, 1].max():.2f}]")

# Train one AdvancedHierarchicalDiffusion model per replicate; results land in scadata.obsm
scadata, advanced_models = train_individual_advanced_diffusion_models(
    scadata, stadata1, stadata2, stadata3
)

print("Advanced diffusion training complete! Results saved in scadata.obsm['advanced_diffusion_coords_avg']")

# Visualize results
import matplotlib.pyplot as plt
import seaborn as sns

my_tab20 = sns.color_palette("tab20", n_colors=20).as_hex()

# sc.pl.embedding creates its own figure when no `ax` is given, so a preceding
# plt.figure(figsize=...) would only add an empty figure and its size would be
# ignored. Draw into explicit axes and let plt.show() display them.

# Plot 1: averaged coordinates
fig, ax = plt.subplots(figsize=(8, 6))
sc.pl.embedding(scadata, basis='advanced_diffusion_coords_avg', color='rough_celltype',
                size=85, title='SC Advanced Diffusion Coords (Averaged)',
                palette=my_tab20, legend_loc='right margin', legend_fontsize=10,
                ax=ax, show=False)
plt.show()

# Plot 2: one panel per trained model
for i in range(len(advanced_models)):
    fig, ax = plt.subplots(figsize=(6, 5))
    sc.pl.embedding(scadata, basis=f'advanced_diffusion_coords_rep{i+1}', color='rough_celltype',
                    size=85, title=f'SC Coordinates (Advanced Model {i+1})',
                    palette=my_tab20, legend_loc='right margin', legend_fontsize=10,
                    ax=ax, show=False)
    plt.show()

Loading cSCC data...
ST1: X[7.00, 50.00], Y[12.00, 56.00]
ST2: X[8.00, 50.00], Y[2.00, 44.00]
ST3: X[6.00, 48.00], Y[9.00, 52.00]

Training AdvancedHierarchicalDiffusion model 1/3 for run1
Common genes for run1: 2000
SC data shape: (2688, 2000)
ST data shape: (666, 2000)
ST coords shape: (666, 2)
D_st not provided, calculating from spatial coordinates...


NameError: name 'pd' is not defined