In [None]:
# Process-wide environment configuration. It lives in the first cell so the
# variables are set before numpy / torch are imported by the cells below.
import os
# NOTE(review): presumably restricts CUDA to the device with index 1 - confirm
# this matches the machine the notebook is run on.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Request single-threaded OpenMP / OpenBLAS through environment variables
os.environ["OMP_NUM_THREADS"]       = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"

In [None]:
# Cap BLAS thread pools at one thread.
# threadpoolctl can only limit libraries that are already loaded into the
# process, so numpy / scipy (which load the BLAS backend) are imported first;
# calling threadpool_limits before them would find nothing to limit.
import numpy as np
import scipy
from threadpoolctl import threadpool_limits

# 'blas' covers OpenBLAS, MKL, etc. Called as a plain function (not as a
# context manager) so the limit is not restored afterwards.
threadpool_limits(limits=1, user_api='blas')

# Project model used by the training cells below
from model.core_models_v2 import AdvancedHierarchicalDiffusion


In [None]:
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import scanpy as sc
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.neighbors import kneighbors_graph
from scipy.sparse.csgraph import dijkstra
from scipy.sparse import csr_matrix, issparse
from sklearn.preprocessing import normalize
import ot 
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
import pandas as pd

# Patient 2 (P2): data loading

In [None]:
def load_and_process_cscc_data(data_dir='/home/ehtesamul/sc_st/data/cSCC/processed'):
    """
    Load the patient-2 cSCC single-cell dataset and its three ST replicates.

    Every AnnData is passed through ``sc.pp.normalize_total`` and ``sc.pp.log1p``
    in place, and a coarse ``rough_celltype`` column is added to ``scadata.obs``.

    Parameters
    ----------
    data_dir : str, optional
        Directory containing ``scP2.h5ad``, ``stP2.h5ad``, ``stP2rep2.h5ad`` and
        ``stP2rep3.h5ad``. The default is the location the notebook was written
        against, so existing zero-argument calls behave as before.

    Returns
    -------
    tuple
        ``(scadata, stadata1, stadata2, stadata3)``.
    """
    print("Loading cSCC data...")

    # File names are fixed; only the directory is configurable.
    scadata, stadata1, stadata2, stadata3 = (
        sc.read_h5ad(os.path.join(data_dir, file_name))
        for file_name in ('scP2.h5ad', 'stP2.h5ad', 'stP2rep2.h5ad', 'stP2rep3.h5ad')
    )

    # Normalize and log transform
    for adata in [scadata, stadata1, stadata2, stadata3]:
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)

    # Start from the level-1 annotation, then collapse selected labels.
    # Labels not listed here (e.g. 'PDC') keep their level-1 name.
    scadata.obs['rough_celltype'] = scadata.obs['level1_celltype'].astype(str)
    level1_to_rough = {
        'CLEC9A': 'DC',
        'CD1C': 'DC',
        'ASDC': 'DC',
        'MDSC': 'DC',
        'LC': 'DC',
        'Mac': 'Myeloid cell',
        'Tcell': 'T cell',
    }
    for level1_label, rough_label in level1_to_rough.items():
        scadata.obs.loc[scadata.obs['level1_celltype'] == level1_label, 'rough_celltype'] = rough_label

    # Level-2 tumour annotations are applied last, so they override the
    # level-1 derived label for those cells.
    scadata.obs.loc[scadata.obs['level2_celltype'] == 'TSK', 'rough_celltype'] = 'TSK'
    tumor_kc_labels = ['Tumor_KC_Basal', 'Tumor_KC_Diff', 'Tumor_KC_Cyc']
    scadata.obs.loc[scadata.obs['level2_celltype'].isin(tumor_kc_labels), 'rough_celltype'] = 'NonTSK'

    return scadata, stadata1, stadata2, stadata3

def prepare_combined_st_for_diffusion(stadata1, stadata2, stadata3, scadata):
    """
    Stack the three ST replicates into one gene-aligned training set.

    Genes are restricted to those present in the SC data and in every ST
    replicate (sorted by name). ST expression is stacked and standardised with
    a ``StandardScaler``; the SC expression is standardised by re-fitting that
    same scaler on the SC matrix, i.e. the two modalities are scaled
    independently. Spatial coordinates are stacked without rescaling.

    Returns
    -------
    tuple
        ``(X_sc, X_st_combined, Y_st_combined, dataset_labels, common_genes,
        st_coords_list)`` - float32 tensors for SC and stacked ST expression,
        float32 array of stacked ST coordinates, one ``'datasetN'`` label per ST
        spot, the shared gene list, and the per-replicate coordinate arrays.
    """
    print("Preparing combined ST data for diffusion training...")

    st_slides = [stadata1, stadata2, stadata3]

    # Intersect gene names across SC and every ST replicate
    shared_genes = set(scadata.var_names)
    for slide in st_slides:
        shared_genes = shared_genes & set(slide.var_names)
    common_genes = sorted(shared_genes)
    print(f"Common genes across all datasets: {len(common_genes)}")

    def _as_dense(matrix):
        # Sparse matrices expose toarray(); dense arrays are passed through
        return matrix.toarray() if hasattr(matrix, 'toarray') else matrix

    # Gene-aligned expression, densified where necessary
    sc_expr = _as_dense(scadata[:, common_genes].X)
    st_exprs = [_as_dense(slide[:, common_genes].X) for slide in st_slides]

    # Per-replicate coordinates are kept separately for a block-diagonal graph
    st_coords_list = [slide.obsm['spatial'] for slide in st_slides]

    st_expr_combined = np.vstack(st_exprs)

    from sklearn.preprocessing import StandardScaler
    scaler = StandardScaler()
    st_expr_combined = scaler.fit_transform(st_expr_combined)

    st_coords_combined = np.vstack(st_coords_list)

    # The scaler is re-fit here, so SC statistics are independent of ST
    sc_expr = scaler.fit_transform(sc_expr)

    # One label per ST spot, in stacking order
    dataset_labels = []
    for slide_no, expr in enumerate(st_exprs, start=1):
        dataset_labels.extend([f'dataset{slide_no}'] * len(expr))

    print(f"Combined ST data shape: {st_expr_combined.shape}")
    print(f"Combined ST coords shape: {st_coords_combined.shape}")
    print(f"SC data shape: {sc_expr.shape}")

    X_sc = torch.tensor(sc_expr, dtype=torch.float32)
    X_st_combined = torch.tensor(st_expr_combined, dtype=torch.float32)
    Y_st_combined = st_coords_combined.astype(np.float32)

    return X_sc, X_st_combined, Y_st_combined, dataset_labels, common_genes, st_coords_list

# Load and preprocess the P2 single-cell data and the three ST replicates.
# These module-level names are reused by later cells.
scadata, stadata1, stadata2, stadata3 = load_and_process_cscc_data()

# Build gene-aligned, standardised expression tensors plus the stacked ST
# coordinates for diffusion training
X_sc, X_st_combined, Y_st_combined, dataset_labels, common_genes, st_coords_list = prepare_combined_st_for_diffusion(
    stadata1, stadata2, stadata3, scadata
)

# Summary of the prepared arrays
print(f"Data preparation complete!")
print(f"SC cells: {X_sc.shape[0]}")
print(f"Combined ST spots: {X_st_combined.shape[0]}")
print(f"Common genes: {len(common_genes)}")



In [None]:
def load_and_process_cscc_data_individual_norm(data_dir='/home/ehtesamul/sc_st/data/cSCC/processed'):
    """
    Load the patient-2 cSCC data for the individually-normalised pipeline.

    Despite the name, this loader performs no per-dataset normalisation itself:
    all four AnnData objects receive the same ``sc.pp.normalize_total`` +
    ``sc.pp.log1p`` treatment, and ``rough_celltype`` is added to
    ``scadata.obs``. Per-replicate coordinate normalisation is done afterwards
    by ``prepare_individually_normalized_st_data``.

    Parameters
    ----------
    data_dir : str, optional
        Directory containing ``scP2.h5ad``, ``stP2.h5ad``, ``stP2rep2.h5ad`` and
        ``stP2rep3.h5ad``. The default is the location the notebook was written
        against, so existing zero-argument calls behave as before.

    Returns
    -------
    tuple
        ``(scadata, stadata1, stadata2, stadata3)``.
    """
    print("Loading cSCC data with individual normalization...")

    # File names are fixed; only the directory is configurable.
    scadata, stadata1, stadata2, stadata3 = (
        sc.read_h5ad(os.path.join(data_dir, file_name))
        for file_name in ('scP2.h5ad', 'stP2.h5ad', 'stP2rep2.h5ad', 'stP2rep3.h5ad')
    )

    # Normalize expression data (same for all)
    for adata in [scadata, stadata1, stadata2, stadata3]:
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)

    # Start from the level-1 annotation, then collapse selected labels.
    # Labels not listed here (e.g. 'PDC') keep their level-1 name.
    scadata.obs['rough_celltype'] = scadata.obs['level1_celltype'].astype(str)
    level1_to_rough = {
        'CLEC9A': 'DC',
        'CD1C': 'DC',
        'ASDC': 'DC',
        'MDSC': 'DC',
        'LC': 'DC',
        'Mac': 'Myeloid cell',
        'Tcell': 'T cell',
    }
    for level1_label, rough_label in level1_to_rough.items():
        scadata.obs.loc[scadata.obs['level1_celltype'] == level1_label, 'rough_celltype'] = rough_label

    # Level-2 tumour annotations are applied last, so they override the
    # level-1 derived label for those cells.
    scadata.obs.loc[scadata.obs['level2_celltype'] == 'TSK', 'rough_celltype'] = 'TSK'
    tumor_kc_labels = ['Tumor_KC_Basal', 'Tumor_KC_Diff', 'Tumor_KC_Cyc']
    scadata.obs.loc[scadata.obs['level2_celltype'].isin(tumor_kc_labels), 'rough_celltype'] = 'NonTSK'

    return scadata, stadata1, stadata2, stadata3

def normalize_coordinates_individually(coords):
    """
    Rescale every coordinate axis independently onto [-1, 1].

    Returns ``(coords_normalized, coords_min, coords_max, coords_range)``.
    An axis with zero extent is given a range of 1, so all of its values map
    to -1; the returned range reflects that substitution.
    """
    lower = coords.min(axis=0)
    upper = coords.max(axis=0)

    # Constant axes would divide by zero; treat them as having unit extent
    extent = upper - lower
    degenerate = extent == 0
    extent[degenerate] = 1.0

    shifted = coords - lower
    rescaled = 2 * shifted / extent - 1

    return rescaled, lower, upper, extent

def prepare_individually_normalized_st_data(stadata1, stadata2, stadata3, scadata):
    """
    Rescale each ST replicate's coordinates on its own, then stack the replicates.

    Genes are restricted to those shared by the SC data and all three ST
    replicates (sorted by name). Expression values are stacked as-is; only the
    spatial coordinates are rescaled, one replicate at a time, with
    ``normalize_coordinates_individually``.

    Returns
    -------
    tuple
        ``(X_sc, X_st_combined, Y_st_combined, dataset_info, common_genes)``.
        ``dataset_info['labels']`` holds one ``'datasetN'`` entry per ST spot and
        ``dataset_info['normalization_params']`` maps each ``'datasetN'`` to its
        ``min`` / ``max`` / ``range`` arrays.
    """
    print("Preparing individually normalized ST data...")

    st_slides = (stadata1, stadata2, stadata3)

    # Intersect gene names across SC and every ST replicate
    shared_genes = set(scadata.var_names)
    for slide in st_slides:
        shared_genes = shared_genes & set(slide.var_names)
    common_genes = sorted(shared_genes)
    print(f"Common genes across all datasets: {len(common_genes)}")

    def _as_dense(matrix):
        # Sparse matrices expose toarray(); dense arrays are passed through
        return matrix.toarray() if hasattr(matrix, 'toarray') else matrix

    sc_expr = _as_dense(scadata[:, common_genes].X)
    st_exprs = [_as_dense(slide[:, common_genes].X) for slide in st_slides]

    # Each replicate is rescaled against its own bounding box
    raw_coords = [slide.obsm['spatial'] for slide in st_slides]

    print("Normalizing coordinates individually...")
    normalized = [normalize_coordinates_individually(xy) for xy in raw_coords]

    for slide_no, (xy_norm, _, _, _) in enumerate(normalized, start=1):
        print(f"ST{slide_no} coord range: [{xy_norm.min():.3f}, {xy_norm.max():.3f}]")

    st_expr_combined = np.vstack(st_exprs)
    st_coords_combined = np.vstack([result[0] for result in normalized])

    # Per-spot provenance plus the parameters needed to undo the rescaling
    labels = []
    norm_params = {}
    for slide_no, (expr, (_, xy_min, xy_max, xy_range)) in enumerate(zip(st_exprs, normalized), start=1):
        key = f'dataset{slide_no}'
        labels.extend([key] * len(expr))
        norm_params[key] = {'min': xy_min, 'max': xy_max, 'range': xy_range}
    dataset_info = {
        'labels': labels,
        'normalization_params': norm_params,
    }

    print(f"Combined ST data shape: {st_expr_combined.shape}")
    print(f"Combined ST coords shape: {st_coords_combined.shape}")
    print(f"SC data shape: {sc_expr.shape}")

    X_sc = torch.tensor(sc_expr, dtype=torch.float32)
    X_st_combined = torch.tensor(st_expr_combined, dtype=torch.float32)
    Y_st_combined = st_coords_combined.astype(np.float32)

    return X_sc, X_st_combined, Y_st_combined, dataset_info, common_genes

In [None]:
# Reload the P2 datasets; this rebinds the module-level scadata / stadata1-3
scadata, stadata1, stadata2, stadata3 = load_and_process_cscc_data_individual_norm()
# Use the GPU when torch reports one is available, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def train_individual_advanced_diffusion_models(scadata, stadata1, stadata2, stadata3):
    """
    Train one AdvancedHierarchicalDiffusion model per ST replicate and average
    the SC coordinates sampled from them.

    For each ST dataset (``stadata1``, ``stadata2``, ``stadata3``) the function
    restricts SC and ST expression to their shared genes, trains a model seeded
    with ``42 + run index``, samples normalised SC coordinates and plots an
    angle analysis against that run's ST frame. After all runs it plots a
    cross-run comparison and writes to ``scadata.obsm``:

    * ``advanced_diffusion_coords_avg`` / ``advanced_diffusion_coords_std`` -
      per-cell mean and standard deviation across runs
    * ``advanced_diffusion_coords_rep{k}`` - coordinates from run ``k`` (1-based)

    Parameters
    ----------
    scadata : AnnData
        Single-cell data; must provide ``obs['rough_celltype']``. Mutated in place.
    stadata1, stadata2, stadata3 : AnnData
        ST replicates; each must provide ``obsm['spatial']``.

    Returns
    -------
    tuple
        ``(scadata, models_all)`` - the same ``scadata`` object and the trained
        models in run order.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Store results from each model
    sc_coords_results = []
    models_all = []

    # One run per ST replicate
    st_datasets = [
        (stadata1, "run1"),
        (stadata2, "run2"),
        (stadata3, "run3")
    ]
    n_runs = len(st_datasets)
    cell_types = scadata.obs['rough_celltype'].values

    for i, (stadata, run_name) in enumerate(st_datasets):
        print(f"\n{'='*50}")
        print(f"Training AdvancedHierarchicalDiffusion model {i+1}/{n_runs} for {run_name}")
        print(f"{'='*50}")

        # Get common genes between SC and current ST dataset
        sc_genes = set(scadata.var_names)
        st_genes = set(stadata.var_names)
        common_genes = sorted(list(sc_genes & st_genes))

        print(f"Common genes for {run_name}: {len(common_genes)}")

        # Extract expression data
        sc_expr = scadata[:, common_genes].X
        st_expr = stadata[:, common_genes].X

        # Convert to dense if sparse
        if hasattr(sc_expr, 'toarray'):
            sc_expr = sc_expr.toarray()
        if hasattr(st_expr, 'toarray'):
            st_expr = st_expr.toarray()

        # Get spatial coordinates
        st_coords = stadata.obsm['spatial']

        print(f"SC data shape: {sc_expr.shape}")
        print(f"ST data shape: {st_expr.shape}")
        print(f"ST coords shape: {st_coords.shape}")

        # Different random seed for each run
        torch.manual_seed(42 + i)
        np.random.seed(42 + i)

        model = AdvancedHierarchicalDiffusion(
            st_gene_expr=st_expr,
            st_coords=st_coords,
            sc_gene_expr=sc_expr,
            cell_types_sc=cell_types,
            transport_plan=None,
            D_st=None,
            D_induced=None,
            n_genes=len(common_genes),
            n_embedding=[512, 256, 128],
            coord_space_diameter=2.00,
            sigma=0.75,
            alpha=0.8,
            mmdbatch=1000,
            batch_size=256,
            device=device,
            lr_e=0.002,
            lr_d=0.0002,
            n_timesteps=300,
            n_denoising_blocks=4,
            hidden_dim=256,
            num_heads=6,
            num_hierarchical_scales=3,
            dp=0.2,
            outf=f'advanced_diffusion_{run_name}'
        )

        # Train the model
        print(f"Training model for {run_name}...")
        model.train(
            encoder_epochs=1201,
            vae_epochs=3001,
            diffusion_epochs=5001, p_drop_max=0.2
        )

        # The angular frame is built from the model's normalised ST coordinates,
        # so it is specific to this run's replicate
        st_coords_raw = model.st_coords_norm.cpu().numpy()
        angular_frame = _build_canonical_angular_frame(st_coords_raw)

        # Generate SC coordinates
        print(f"Generating SC coordinates using {run_name} model...")

        sc_coords = model.sample_sc_coordinates(
            batch_size=512,
            guidance_scale=8.0,
            return_normalized=True
        )

        # The sampler may return a torch tensor (possibly on the GPU). Convert
        # once here so printing, plotting and averaging all see a NumPy array.
        if hasattr(sc_coords, 'cpu'):
            sc_coords = sc_coords.detach().cpu().numpy()

        print(f"\n=== Generated SC Coordinates ({run_name}) ===")
        print(f"  X range: [{sc_coords[:, 0].min():.3f}, {sc_coords[:, 0].max():.3f}]")
        print(f"  Y range: [{sc_coords[:, 1].min():.3f}, {sc_coords[:, 1].max():.3f}]")

        sc_coords_results.append(sc_coords)
        models_all.append(model)

        # Plot SC cells colored by angle (using this run's ST-derived frame)
        _plot_sc_angle_analysis(sc_coords, cell_types,
                               angular_frame, st_coords_raw, run_name, i+1)

    # Comparative analysis across runs. angular_frame / st_coords_raw are the
    # values left over from the last loop iteration, so every run is measured
    # against the final replicate's frame.
    _plot_comparative_sc_angle_analysis(sc_coords_results, cell_types,
                                       angular_frame, st_coords_raw)

    # Per-cell mean and spread across runs (all runs place the same SC cells)
    sc_coords_avg = np.mean(sc_coords_results, axis=0)
    sc_coords_std = np.std(sc_coords_results, axis=0)

    # Store results in scadata
    scadata.obsm['advanced_diffusion_coords_avg'] = sc_coords_avg
    scadata.obsm['advanced_diffusion_coords_std'] = sc_coords_std

    # Store individual results
    for i, coords in enumerate(sc_coords_results):
        scadata.obsm[f'advanced_diffusion_coords_rep{i+1}'] = coords

    print("\nTraining complete. Results stored in scadata.obsm")
    return scadata, models_all

def _build_canonical_angular_frame(st_coords):
    """Build canonical angular frame from ST coordinates (dataset-specific, run-independent)"""
    import numpy as np
    
    # Compute centroid
    centroid = st_coords.mean(axis=0)
    
    # Find farthest spot from centroid (deterministic 0° direction)
    distances = np.linalg.norm(st_coords - centroid, axis=1)
    farthest_idx = np.argmax(distances)
    a0 = st_coords[farthest_idx] - centroid  # 0° direction vector
    
    def angle_fn(x):
        """Compute angle from canonical frame"""
        if x.ndim == 1:
            x = x.reshape(1, -1)
        
        v = x - centroid
        cross = a0[0] * v[:, 1] - a0[1] * v[:, 0]  # z-component of 2D cross
        dot = a0[0] * v[:, 0] + a0[1] * v[:, 1]
        angles = np.arctan2(cross, dot)
        angles = np.where(angles < 0, angles + 2*np.pi, angles)  # Map to [0, 2π)
        return angles
    
    return {
        'centroid': centroid,
        'zero_direction': a0,
        'farthest_idx': farthest_idx,
        'angle_fn': angle_fn
    }

def _plot_sc_angle_analysis(sc_coords, cell_types, angular_frame, st_coords_bg, run_name, run_num):
    """Plot SC cells colored by angle from ST-derived frame"""
    import matplotlib.pyplot as plt
    import numpy as np
    
    # Compute angles for SC cells using ST-derived frame
    sc_angles = angular_frame['angle_fn'](sc_coords)
    sc_angles_degrees = np.degrees(sc_angles)
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    
    # Plot 1: SC cells colored by angle (with ST outline in background)
    ax1.scatter(st_coords_bg[:, 0], st_coords_bg[:, 1], 
               c='black', s=100, alpha=0.8, label='ST outline')
    
    scatter = ax1.scatter(sc_coords[:, 0], sc_coords[:, 1], 
                         c=sc_angles_degrees, cmap='hsv', s=30, alpha=0.8)
    
    # Mark centroid and 0° direction
    centroid = angular_frame['centroid']
    zero_dir = angular_frame['zero_direction']
    ax1.scatter(centroid[0], centroid[1], c='black', s=100, marker='x', linewidth=3)
    ax1.arrow(centroid[0], centroid[1], zero_dir[0]*0.3, zero_dir[1]*0.3, 
              head_width=0.05, head_length=0.05, fc='red', ec='red', linewidth=2)
    
    ax1.set_title(f'{run_name}: SC Cells Colored by Angle θ')
    ax1.set_xlabel('X coordinate')
    ax1.set_ylabel('Y coordinate')
    ax1.set_aspect('equal')
    ax1.legend()
    
    # Add colorbar
    cbar = plt.colorbar(scatter, ax=ax1)
    cbar.set_label('Angle (degrees)')
    
    # Plot 2: Per-cell-type angle distribution
    unique_types = np.unique(cell_types)
    colors = plt.cm.tab10(np.linspace(0, 1, len(unique_types)))
    
    for i, cell_type in enumerate(unique_types):
        mask = cell_types == cell_type
        if np.sum(mask) > 0:
            angles_subset = sc_angles_degrees[mask]
            ax2.hist(angles_subset, bins=36, alpha=0.6, label=cell_type, 
                    color=colors[i], density=True)
    
    ax2.set_title(f'{run_name}: Angle Distribution by Cell Type')
    ax2.set_xlabel('Angle (degrees)')
    ax2.set_ylabel('Density')
    ax2.set_xlim(0, 360)
    ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    
    plt.tight_layout()
    plt.savefig(f'sc_angle_analysis_{run_name}.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    # Print circular statistics per cell type
    print(f"\n{run_name} - Circular statistics per cell type:")
    for cell_type in unique_types:
        mask = cell_types == cell_type
        if np.sum(mask) > 5:  # Only if enough cells
            angles_rad = sc_angles[mask]
            # Circular mean
            mean_cos = np.mean(np.cos(angles_rad))
            mean_sin = np.mean(np.sin(angles_rad))
            circular_mean = np.arctan2(mean_sin, mean_cos)
            if circular_mean < 0:
                circular_mean += 2*np.pi
            
            print(f"  {cell_type}: mean={np.degrees(circular_mean):.1f}°, n={np.sum(mask)}")

def _plot_comparative_sc_angle_analysis(sc_coords_list, cell_types, angular_frame, st_coords_bg):
    """Plot comparative SC angle analysis across all runs"""
    import matplotlib.pyplot as plt
    import numpy as np
    
    n_runs = len(sc_coords_list)
    unique_types = np.unique(cell_types)
    
    fig, axes = plt.subplots(2, n_runs, figsize=(5*n_runs, 10))
    if n_runs == 1:
        axes = axes.reshape(-1, 1)
    
    # Top row: SC scatter plots per run
    for i, sc_coords in enumerate(sc_coords_list):
        ax = axes[0, i]
        
        # ST background
        ax.scatter(st_coords_bg[:, 0], st_coords_bg[:, 1], 
                  c='black', s=20, alpha=0.8)
        
        # SC cells colored by angle
        sc_angles = angular_frame['angle_fn'](sc_coords)
        sc_angles_degrees = np.degrees(sc_angles)
        
        scatter = ax.scatter(sc_coords[:, 0], sc_coords[:, 1], 
                           c=sc_angles_degrees, cmap='hsv', s=20, alpha=0.1)
        
        ax.set_title(f'Run {i+1}: SC Cells by Angle')
        ax.set_aspect('equal')
        
        if i == n_runs-1:  # Add colorbar to last plot
            cbar = plt.colorbar(scatter, ax=ax)
            cbar.set_label('Angle (degrees)')
    
    # Bottom row: Cell type angle distributions per run  
    colors = plt.cm.tab10(np.linspace(0, 1, len(unique_types)))
    
    for i, sc_coords in enumerate(sc_coords_list):
        ax = axes[1, i]
        
        sc_angles = angular_frame['angle_fn'](sc_coords)
        sc_angles_degrees = np.degrees(sc_angles)
        
        for j, cell_type in enumerate(unique_types):
            mask = cell_types == cell_type
            if np.sum(mask) > 5:
                angles_subset = sc_angles_degrees[mask]
                ax.hist(angles_subset, bins=36, alpha=0.6, 
                       label=cell_type if i == 0 else "", 
                       color=colors[j], density=True)
        
        ax.set_title(f'Run {i+1}: Cell Type Angles')
        ax.set_xlabel('Angle (degrees)')
        ax.set_ylabel('Density')
        ax.set_xlim(0, 360)
        
        if i == 0:
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    
    plt.tight_layout()
    plt.savefig('comparative_sc_angle_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    # Check for sector sliding
    print(f"\n" + "="*60)
    print("SECTOR SLIDING ANALYSIS")
    print("="*60)
    
    for cell_type in unique_types:
        mask = cell_types == cell_type
        if np.sum(mask) > 10:  # Only analyze cell types with enough cells
            circular_means = []
            
            for i, sc_coords in enumerate(sc_coords_list):
                sc_angles = angular_frame['angle_fn'](sc_coords)
                angles_subset = sc_angles[mask]
                
                # Circular mean
                mean_cos = np.mean(np.cos(angles_subset))
                mean_sin = np.mean(np.sin(angles_subset))
                circular_mean = np.arctan2(mean_sin, mean_cos)
                if circular_mean < 0:
                    circular_mean += 2*np.pi
                
                circular_means.append(np.degrees(circular_mean))
            
            # Check for large differences between runs
            max_diff = max(circular_means) - min(circular_means)
            if max_diff > 180:  # Handle wraparound
                max_diff = 360 - max_diff
            
            print(f"{cell_type}:")
            print(f"  Run means: {[f'{m:.1f}°' for m in circular_means]}")
            print(f"  Max difference: {max_diff:.1f}°")
            
            if max_diff > 30:  # Significant sliding
                print(f"  ⚠️  SECTOR SLIDING DETECTED!")
            else:
                print(f"  ✅ Consistent placement")

# Reload and preprocess the P2 data; this rebinds the module-level
# scadata / stadata1-3 before training
scadata, stadata1, stadata2, stadata3 = load_and_process_cscc_data()

# Sanity check: print the raw spatial coordinate extent of each ST replicate
for i, stdata in enumerate([stadata1, stadata2, stadata3], 1):
    coords = stdata.obsm['spatial']
    print(f"ST{i}: X[{coords[:, 0].min():.2f}, {coords[:, 0].max():.2f}], Y[{coords[:, 1].min():.2f}, {coords[:, 1].max():.2f}]")


# Train one AdvancedHierarchicalDiffusion model per replicate; the averaged and
# per-run coordinates are written into scadata.obsm by the function
scadata, advanced_models = train_individual_advanced_diffusion_models(
    scadata, stadata1, stadata2, stadata3
)

print("Advanced diffusion training complete! Results saved in scadata.obsm['advanced_diffusion_coords_avg']")

# Visualize results
import matplotlib.pyplot as plt
import seaborn as sns

# 20-colour palette (hex strings) shared by all embedding plots below
my_tab20 = sns.color_palette("tab20", n_colors=20).as_hex()

# Plot 1: Averaged coordinates
plt.figure(figsize=(8, 6))
sc.pl.embedding(scadata, basis='advanced_diffusion_coords_avg', color='rough_celltype',
               size=85, title='SC Advanced Diffusion Coords (Averaged)',
               palette=my_tab20, legend_loc='right margin', legend_fontsize=10)
plt.show()

# Plot 2: Individual model results (rep1..rep3, matching the obsm keys
# written by the training function)
for i in range(3):
    plt.figure(figsize=(6, 5))
    sc.pl.embedding(scadata, basis=f'advanced_diffusion_coords_rep{i+1}', color='rough_celltype',
                   size=85, title=f'SC Coordinates (Advanced Model {i+1})',
                   palette=my_tab20, legend_loc='right margin', legend_fontsize=10)
    plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist, squareform

# Compare the three runs through their cell-cell distance structure.
# Each matrix is dense n_cells x n_cells (pairwise Euclidean distances in the
# run's coordinate space).
coords_list = [scadata.obsm[f'advanced_diffusion_coords_rep{i}'] for i in range(1, 4)]
dist_matrices = [squareform(pdist(coords, metric='euclidean')) for coords in coords_list]

# Plot the 3 distance matrices side by side, each with its own colorbar
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for i, (ax, dist_mat) in enumerate(zip(axes, dist_matrices)):
    im = ax.imshow(dist_mat, cmap='viridis', aspect='auto')
    ax.set_title(f'Distance Matrix - Rep{i+1}')
    plt.colorbar(im, ax=ax)

# Pearson correlation between every pair of runs, computed over the full
# flattened matrices (so the zero diagonal and both triangles are included)
print("Pairwise Correlations between Distance Matrices:")
for i in range(3):
    for j in range(i+1, 3):
        corr = np.corrcoef(dist_matrices[i].flatten(), dist_matrices[j].flatten())[0, 1]
        print(f"Rep{i+1} vs Rep{j+1}: {corr:.4f}")

plt.tight_layout()
plt.show()

In [None]:
# Read the raw P2 rep-1 ST table (tab-separated, gzipped) with scanpy
rawstdata = sc.read_csv('/home/ehtesamul/sc_st/data/cSCC/processed/GSM4284316_P2_ST_rep1_stdata.tsv.gz',delimiter='\t')

def normalize_coordinates_isotropic(coords):
    """
    Centre the coordinates on their mean and scale both axes by one factor.

    The common factor is the largest distance of any point from the mean
    (plus 1e-8 to avoid dividing by zero), so every point ends up within the
    unit disc. Returns ``(normalized_coords, center, max_dist)``.
    """
    center = coords.mean(axis=0)
    offsets = coords - center
    # Largest radial distance sets the single, shared scale
    max_dist = np.linalg.norm(offsets, axis=1).max()
    return offsets / (max_dist + 1e-8), center, max_dist

# Load the spot-selection metadata first: it determines which spots are kept
rawstmeta = pd.read_csv('/home/ehtesamul/sc_st/data/cSCC/processed/GSM4284316_spot_data-selection-P2_ST_rep1.tsv.gz',delimiter='\t')

# Build "<x>x<y>" spot names from the metadata's x / y columns and use them as
# its index. These names are used to select rows of rawstdata below.
stindex=[]
for i in range(len(rawstmeta.x.tolist())):
    stindex.append(str(rawstmeta.x[i])+'x'+str(rawstmeta.y[i]))
rawstmeta.index = stindex

# Keep only the selected spots (in metadata order) and attach the metadata as obs
rawstdata = rawstdata[stindex,:]
rawstdata.obs = rawstmeta

# Recover integer grid coordinates by splitting the "<x>x<y>" spot names
coord = np.array([x.split('x') for x in rawstdata.obs_names.tolist()],dtype='int')
print(f"Coordinates shape after filtering: {coord.shape}")  # Should be (666, 2)

# Centre and isotropically scale the grid coordinates, then store them as the
# spatial basis
coord_norm, _, _ = normalize_coordinates_isotropic(coord)
rawstdata.obsm['spatial'] = coord_norm
# rawstdata.obsm['spatial'] = coord


# Expression preprocessing: library-size normalise, log1p, then z-score.
# The log1p values are kept in a layer (and in .raw) because sc.pp.scale
# overwrites X.
sc.pp.normalize_total(rawstdata, target_sum=1e4)
sc.pp.log1p(rawstdata)
rawstdata.layers["log1p"] = rawstdata.X.copy()   # keep positive values
rawstdata.raw = rawstdata[:, :]                  # optional: make this the .raw snapshot
sc.pp.scale(rawstdata)                           # z-score only for downstream PCA etc.


print(rawstdata)  # Should show 666 × 17138
print(f"Spatial coordinates shape: {rawstdata.obsm['spatial'].shape}")  # Should be (666, 2)

# Spatial expression maps for three marker genes
sc.pl.spatial(rawstdata,color=['BST2','NRP1','JCHAIN'],show=True,basis='spatial',na_in_legend=False,spot_size=0.05, save='all_three_exp')

In [None]:
from matplotlib import rcParams

# Averaged SC coordinates and the subset belonging to level-2 'PDC' cells
sccooravg = scadata.obsm['advanced_diffusion_coords_avg']
PDCcoor = sccooravg[scadata.obs.level2_celltype=='PDC',:]

# NOTE(review): the two import pairs below are duplicates of each other
from sklearn.neighbors import NearestNeighbors
import numpy as np

from sklearn.neighbors import NearestNeighbors
import numpy as np

# estimate ST spot pitch from nearest-neighbor distance in ST space
# (n_neighbors=2 because each spot's first neighbour is itself, so column 1
# holds the distance to the closest other spot)
_nn_st = NearestNeighbors(n_neighbors=2).fit(rawstdata.obsm['spatial'])
_d2, _ = _nn_st.kneighbors(rawstdata.obsm['spatial'])
spot_pitch = np.median(_d2[:, 1])

radius = 1.2 * spot_pitch  # <- toggle this (1.2–2.0× are common)


# For every PDC cell, collect the ST spots within `radius`; nearST is the
# sorted union of those spot positions.
# NOTE(review): assumes sccooravg and rawstdata.obsm['spatial'] share one
# coordinate system - confirm.
nbrs_rad = NearestNeighbors(radius=radius).fit(rawstdata.obsm['spatial'])
ind = nbrs_rad.radius_neighbors(PDCcoor, return_distance=False)
nearST = sorted(set(np.concatenate(ind)))

# Alternative (disabled): k-nearest-spot selection instead of a radius
# k = 10  # <- toggle (3–10)
# nbrs_k = NearestNeighbors(n_neighbors=k).fit(rawstdata.obsm['spatial'])
# idx = nbrs_k.kneighbors(PDCcoor, return_distance=False)
# nearST = sorted(set(idx.flatten()))


# Label spots near a PDC cell as 'pDC', everything else as 'Others'
rawstdata.obs['pDCnear'] = 'Others'
spot_idx = rawstdata.obs.index[nearST]   # index names corresponding to nearST positions
rawstdata.obs.loc[spot_idx, 'pDCnear'] = 'pDC'


# Figure styling: drop top/right spines; fonttype 42 for pdf/ps output
rcParams['axes.spines.right'] = False
rcParams['axes.spines.top'] = False
rcParams['pdf.fonttype'] = 42
rcParams['ps.fonttype'] = 42
sc.pl.spatial(rawstdata,color=['pDCnear'],show=True,basis='spatial',na_in_legend=False,spot_size=0.05,save='pDCenrich')

In [None]:
# Show the averaged SC coordinate array as this cell's output
sccooravg

In [None]:
# compute spot pitch (median distance from each spot to its closest other
# spot; same computation as in the pDC labelling cell above)
_nn_st = NearestNeighbors(n_neighbors=2).fit(rawstdata.obsm['spatial'])
_d2, _ = _nn_st.kneighbors(rawstdata.obsm['spatial'])
spot_pitch = np.median(_d2[:,1])
radius = 1.2 * spot_pitch   # choose 1.2–1.8 as sensitivity
# radius-based pDC label: union of ST spots within `radius` of any PDC cell
nbrs_rad = NearestNeighbors(radius=radius).fit(rawstdata.obsm['spatial'])
ind = nbrs_rad.radius_neighbors(PDCcoor, return_distance=False)
nearST = sorted(set(np.concatenate(ind)))

# Alternative (disabled): k-nearest-spot selection instead of a radius
# k = 15  # <- toggle (3–10)
# nbrs_k = NearestNeighbors(n_neighbors=k).fit(rawstdata.obsm['spatial'])
# idx = nbrs_k.kneighbors(PDCcoor, return_distance=False)
# nearST = sorted(set(idx.flatten()))




# Two-level grouping used for the differential expression test below
rawstdata.obs['pDCnear_radius'] = 'Others'
rawstdata.obs.loc[rawstdata.obs.index[nearST],'pDCnear_radius'] = 'pDC'

# run Scanpy DE (genome-wide) on the stored log1p layer rather than on the
# z-scored X
sc.tl.rank_genes_groups(rawstdata, groupby='pDCnear_radius', method='wilcoxon',
                        layer='log1p', use_raw=False, key_added='deg_pdc_radius')

# Results for the 'pDC' group, filtered with log2fc_min=0
degdf = sc.get.rank_genes_groups_df(rawstdata, group='pDC', key='deg_pdc_radius', log2fc_min=0)
# degdf includes: names, logfoldchanges (Scanpy style), scores, pvals, pvals_adj


# Print specific genes (BST2, NRP1) in a clean line format; a gene missing
# from the filtered table is reported as not found
for g in ['BST2','NRP1']:
    row = degdf.loc[degdf['names'] == g]
    if not row.empty:
        r = row.iloc[0]
        print(f"{g}: log2FC={r.logfoldchanges:.3f}, scores= {r.scores}, p={r.pvals:.2e}, padj={r.pvals_adj:.2e}")
    else:
        print(f"{g}: not found in DE results")

In [None]:
# One reconstructed-space panel per tumour keratinocyte subtype; each call
# highlights a single level-2 group and saves under its own suffix.
for tumor_subtype, figure_tag in (
    ("Tumor_KC_Cyc", 'P2cyc'),
    ("Tumor_KC_Basal", 'P2bas'),
    ("Tumor_KC_Diff", 'P2diff'),
):
    sc.pl.spatial(
        scadata,
        color="level2_celltype",
        groups=[tumor_subtype],
        spot_size=0.06,
        show=True,
        basis='advanced_diffusion_coords_avg',
        title='reconstructed',
        na_in_legend=False,
        save=figure_tag,
    )

In [None]:
import squidpy as sq
# Neighbourhood analysis on the reconstructed (averaged) SC coordinates of patient 2.
# nhood_enrichment is a permutation test; a fixed seed keeps the z-scores (and the
# saved heatmap) identical across Restart & Run All.
sq.gr.spatial_neighbors(scadata, spatial_key='advanced_diffusion_coords_avg')
sq.gr.nhood_enrichment(scadata, cluster_key='rough_celltype', seed=0)
sq.gr.interaction_matrix(scadata, cluster_key='rough_celltype')

# Restrict to tumour keratinocyte subtypes + TSK and rebuild the graph on that subset,
# so enrichment is computed among these four populations only.
kscadata = scadata[scadata.obs.level2_celltype.isin(['Tumor_KC_Cyc','Tumor_KC_Basal','Tumor_KC_Diff','TSK'])].copy()
sq.gr.spatial_neighbors(kscadata, spatial_key='advanced_diffusion_coords_avg')
sq.gr.nhood_enrichment(kscadata, cluster_key='level2_celltype', seed=0)
sq.pl.nhood_enrichment(kscadata, cluster_key="level2_celltype", cmap='coolwarm', figsize=(2,2), save='TSKKC_P2_avg.svg', dpi=600)


# patient 10 stuff

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import cdist
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd

# Read the three patient-10 ST replicates (raw, not yet normalised in this cell)
stadata1_p10 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP10rep1.h5ad')
stadata2_p10 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP10rep2.h5ad')
stadata3_p10 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP10rep3.h5ad')

datasets = [stadata1_p10, stadata2_p10, stadata3_p10]
names = ['ST_P10_Rep1', 'ST_P10_Rep2', 'ST_P10_Rep3']

# Per-replicate size and bounding box of the spatial coordinates
print("Dataset Basic Info:")
for data, name in zip(datasets, names):
    n_spots, n_genes = data.shape
    xs = data.obsm['spatial'][:, 0]
    ys = data.obsm['spatial'][:, 1]
    print(f"{name}: {n_spots} spots, {n_genes} genes")
    print(f"  Spatial coords range: X[{xs.min():.2f}, {xs.max():.2f}], Y[{ys.min():.2f}, {ys.max():.2f}]")

In [None]:
def load_and_process_cscc_data_p10():
    """
    Load the patient-10 cSCC single-cell data and its three ST replicates.

    Every AnnData is library-size normalised and log1p-transformed in place.
    A coarse annotation column ``obs['rough_celltype']`` is added to the SC data:
    it starts as ``level1_celltype`` (as strings), selected level-1 labels are
    collapsed, and tumour cells are then relabelled from ``level2_celltype``
    ('TSK' vs. 'NonTSK').

    Returns:
        (scadata_p10, stadata1_p10, stadata2_p10, stadata3_p10)
    """
    print("Loading cSCC data...")

    # Single-cell reference
    scadata_p10 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/scP10.h5ad')

    # Spatial replicates
    stadata1_p10 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP10rep1.h5ad')
    stadata2_p10 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP10rep2.h5ad')
    stadata3_p10 = sc.read_h5ad('/home/ehtesamul/sc_st/data/cSCC/processed/stP10rep3.h5ad')

    # Same preprocessing for all four objects
    for adata in (scadata_p10, stadata1_p10, stadata2_p10, stadata3_p10):
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)

    level1 = scadata_p10.obs['level1_celltype']
    level2 = scadata_p10.obs['level2_celltype']
    scadata_p10.obs['rough_celltype'] = level1.astype(str)

    # level-1 label -> coarse label (applied in this order)
    level1_to_rough = {
        'CLEC9A': 'DC',
        'CD1C': 'DC',
        'ASDC': 'DC',
        'PDC': 'PDC',
        'MDSC': 'DC',
        'LC': 'DC',
        'Mac': 'Myeloid cell',
        'Tcell': 'T cell',
    }
    for level1_label, rough_label in level1_to_rough.items():
        scadata_p10.obs.loc[level1 == level1_label, 'rough_celltype'] = rough_label

    # Tumour populations come from the finer level-2 annotation and override the above
    scadata_p10.obs.loc[level2 == 'TSK', 'rough_celltype'] = 'TSK'
    non_tsk_labels = ['Tumor_KC_Basal', 'Tumor_KC_Diff', 'Tumor_KC_Cyc']
    scadata_p10.obs.loc[level2.isin(non_tsk_labels), 'rough_celltype'] = 'NonTSK'

    return scadata_p10, stadata1_p10, stadata2_p10, stadata3_p10

def prepare_combined_st_for_diffusion(stadata1, stadata2, stadata3, scadata):
    """
    Stack three ST replicates into one training set aligned to the SC gene space.

    Genes are restricted to those present in the SC data and in all three ST
    objects (sorted alphabetically). The stacked ST expression is standardised
    with a StandardScaler; the SC expression is standardised separately (the
    scaler is re-fit on the SC matrix).

    Returns:
        X_sc: float32 torch tensor of standardised SC expression.
        X_st_combined: float32 torch tensor of standardised, stacked ST expression.
        Y_st_combined: float32 numpy array of stacked ST coordinates.
        dataset_labels: list with 'dataset1'/'dataset2'/'dataset3', one per ST row.
        common_genes: sorted list of shared gene names.
        st_coords_list: per-replicate coordinate arrays (unstacked).
    """
    from sklearn.preprocessing import StandardScaler

    print("Preparing combined ST data for diffusion training...")

    st_adatas = (stadata1, stadata2, stadata3)

    # Intersect gene names across SC and every ST replicate
    shared_genes = set(scadata.var_names)
    for st_adata in st_adatas:
        shared_genes &= set(st_adata.var_names)
    common_genes = sorted(shared_genes)
    print(f"Common genes across all datasets: {len(common_genes)}")

    def _dense_expression(adata):
        # Subset to the shared genes and densify sparse matrices
        expr = adata[:, common_genes].X
        return expr.toarray() if hasattr(expr, 'toarray') else expr

    sc_expr = _dense_expression(scadata)
    st_expr_blocks = [_dense_expression(st_adata) for st_adata in st_adatas]

    # Per-replicate coordinates are kept separately for a block-diagonal graph
    st_coords_list = [st_adata.obsm['spatial'] for st_adata in st_adatas]

    scaler = StandardScaler()
    st_expr_combined = scaler.fit_transform(np.vstack(st_expr_blocks))
    st_coords_combined = np.vstack(st_coords_list)
    # fit_transform re-fits the scaler: SC is standardised with its own statistics
    sc_expr = scaler.fit_transform(sc_expr)

    # Provenance label for every stacked ST row
    dataset_labels = []
    for block_number, block in enumerate(st_expr_blocks, start=1):
        dataset_labels.extend([f'dataset{block_number}'] * len(block))

    print(f"Combined ST data shape: {st_expr_combined.shape}")
    print(f"Combined ST coords shape: {st_coords_combined.shape}")
    print(f"SC data shape: {sc_expr.shape}")

    X_sc = torch.tensor(sc_expr, dtype=torch.float32)
    X_st_combined = torch.tensor(st_expr_combined, dtype=torch.float32)
    Y_st_combined = st_coords_combined.astype(np.float32)

    return X_sc, X_st_combined, Y_st_combined, dataset_labels, common_genes, st_coords_list

# Load and process data (normalised + log1p AnnData objects for patient 10)
scadata_p10, stadata1_p10, stadata2_p10, stadata3_p10 = load_and_process_cscc_data_p10()

# Prepare combined data for diffusion: shared-gene, standardised expression tensors
# plus stacked ST coordinates and per-row replicate labels
X_sc, X_st_combined, Y_st_combined, dataset_labels, common_genes, st_coords_list = prepare_combined_st_for_diffusion(
    stadata1_p10, stadata2_p10, stadata3_p10, scadata_p10
)

# Quick size summary of what was prepared
print(f"Data preparation complete!")
print(f"SC cells: {X_sc.shape[0]}")
print(f"Combined ST spots: {X_st_combined.shape[0]}")
print(f"Common genes: {len(common_genes)}")

In [None]:
# Minimal notebook cell: Procrustes alignment for 5 points across 3 coordinate systems
import numpy as np

def procrustes_align(X_ref, Y, allow_scaling=True, allow_reflection=True):
    """
    Least-squares Procrustes alignment of point set ``Y`` onto ``X_ref``.

    Finds an orthogonal matrix ``R``, a scale ``s`` and a translation ``t`` such
    that ``s * (Y @ R) + t`` is as close as possible to ``X_ref``.

    Args:
        X_ref: (n, d) reference points.
        Y: (n, d) points to align; rows must correspond to rows of ``X_ref``.
        allow_scaling: if False the scale is fixed to 1.0.
        allow_reflection: if False ``R`` is constrained to a proper rotation
            (det(R) > 0).

    Returns:
        (Y_aligned, R, s, t, rmsd) where ``Y_aligned = s * (Y @ R) + t`` and
        ``rmsd`` is the root-mean-square distance between ``X_ref`` and ``Y_aligned``.

    Raises:
        ValueError: if the two inputs do not have the same shape.
    """
    X = np.asarray(X_ref, float)
    Y = np.asarray(Y, float)
    if X.shape != Y.shape:
        raise ValueError("X_ref and Y must have the same shape")
    n = X.shape[0]

    x_mean = X.mean(axis=0, keepdims=True)
    y_mean = Y.mean(axis=0, keepdims=True)
    Xc = X - x_mean
    Yc = Y - y_mean

    # Cross-covariance; its SVD gives the optimal orthogonal map R = U diag(signs) Vt
    C = Yc.T @ Xc / n
    U, S, Vt = np.linalg.svd(C)

    signs = np.ones_like(S)
    if (not allow_reflection) and (np.linalg.det(U @ Vt) < 0):
        # Forbid the mirror solution by flipping the weakest singular direction
        signs[-1] = -1.0
    R = (U * signs) @ Vt

    if allow_scaling:
        y_var = np.linalg.norm(Yc) ** 2 / n
        # The optimal scale is trace(diag(signs) S) / var(Y): the sign flip must be
        # included, otherwise the scale is overestimated when a reflection was removed.
        # A degenerate Y (all points identical) has no defined scale; fall back to 1.
        s = (S * signs).sum() / y_var if y_var > 0 else 1.0
    else:
        s = 1.0

    Y_aligned = s * (Yc @ R) + x_mean
    t = X.mean(axis=0) - s * (Y.mean(axis=0) @ R)
    rmsd = np.sqrt(((X - Y_aligned) ** 2).sum() / n)
    return Y_aligned, R, s, t, rmsd

def transform(P, theta_deg=0, tx=0, ty=0, scale=1.0, reflect=False, noise=0.0, seed=0):
    """
    Apply a 2D similarity transform (optionally mirrored and noisy) to points ``P``.

    The points are rotated by ``theta_deg`` degrees (with the first column of the
    rotation matrix negated when ``reflect`` is True), multiplied by ``scale`` and
    shifted by ``(tx, ty)``. If ``noise > 0``, Gaussian noise with that standard
    deviation is added using a generator seeded with ``seed``.

    Returns:
        Array with the same shape as ``P``.
    """
    rng = np.random.default_rng(seed)
    angle = np.deg2rad(theta_deg)
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rot = np.array([[cos_a, -sin_a],
                    [sin_a,  cos_a]])
    if reflect:
        rot[:, 0] = -rot[:, 0]
    shift = np.array([tx, ty])
    moved = (P @ rot.T) * scale + shift
    if noise > 0:
        moved = moved + rng.normal(0, noise, size=moved.shape)
    return moved

# --- Example data (5 points) ---
base = np.array([[0.0, 0.0],
                 [1.2, 0.1],
                 [1.1, 1.1],
                 [0.2, 1.3],
                 [-0.3, 0.6]], float)
A = base
B = transform(base, theta_deg=35, tx=2.0, ty=-1.0, scale=1.0, noise=0.01, seed=1)
C = transform(base, theta_deg=-95, tx=-1.3, ty=0.9, scale=1.25, reflect=True, noise=0.01, seed=2)

# --- Align B, C → A ---
# The scale of C is bound to `scale_c`, NOT `sc`: `sc` is the scanpy module alias used
# throughout this notebook, and overwriting it with a float breaks every later
# `sc.pl...` / `sc.read_h5ad(...)` call until scanpy is re-imported.
B_aln, Rb, sb, tb, rmsd_b = procrustes_align(A, B, allow_scaling=True, allow_reflection=True)
C_aln, Rc, scale_c, tc, rmsd_c = procrustes_align(A, C, allow_scaling=True, allow_reflection=True)

# Results you can print or inspect
print("B→A  s=", round(sb,4), "RMSD=", round(rmsd_b,6), "\nR=\n", Rb, "\nt=", tb)
print("\nC→A  s=", round(scale_c,4), "RMSD=", round(rmsd_c,6), "\nR=\n", Rc, "\nt=", tc)

# Optional plotting (enabled; set to False to skip the figures)
SHOW_PLOTS = True
if SHOW_PLOTS:
    import matplotlib.pyplot as plt
    # One figure before and one after alignment, each overlaying the three point sets
    for panel_title, point_sets in (("Before", (A, B, C)), ("After", (A, B_aln, C_aln))):
        plt.figure()
        for pts in point_sets:
            plt.scatter(pts[:, 0], pts[:, 1])
        plt.axis('equal')
        plt.title(panel_title)
        plt.show()


In [None]:
import scanpy as sc
def train_individual_advanced_diffusion_models(scadata, stadata1, stadata2, stadata3):
    """
    Train one AdvancedHierarchicalDiffusion model per ST replicate and average
    the SC coordinates sampled from the three models.

    For each replicate (labelled "run1".."run3") the function restricts SC and ST
    expression to their shared genes, trains a model, samples normalised SC
    coordinates and draws a per-run angle diagnostic. After all runs a comparative
    diagnostic is drawn and the results are written to ``scadata.obsm``:

        'advanced_diffusion_coords_avg'   mean over runs
        'advanced_diffusion_coords_std'   standard deviation over runs
        'advanced_diffusion_coords_rep{k}' coordinates of run k (1-based)

    Args:
        scadata: SC AnnData; must contain ``obs['rough_celltype']``. Modified in place.
        stadata1, stadata2, stadata3: ST AnnData objects with ``obsm['spatial']``.

    Returns:
        (scadata, models_all): the annotated SC object and the list of trained models.

    NOTE(review): the runs are averaged without any rigid alignment between them,
    and the comparative plot uses the angular frame / ST background of the LAST
    run only (the loop variables are reused after the loop).
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Results collected over runs
    sc_coords_results = []
    models_all = []

    # One training run per ST replicate
    st_datasets = [
        (stadata1, "run1"),
        (stadata2, "run2"),
        (stadata3, "run3"),
    ]
    n_runs = len(st_datasets)

    cell_types = scadata.obs['rough_celltype'].values
    sc_genes = set(scadata.var_names)

    for i, (stadata, run_name) in enumerate(st_datasets):
        print(f"\n{'='*50}")
        print(f"Training AdvancedHierarchicalDiffusion model {i+1}/{n_runs} for {run_name}")
        print(f"{'='*50}")

        # Genes shared by SC and this ST replicate (sorted for a stable column order)
        common_genes = sorted(sc_genes & set(stadata.var_names))
        print(f"Common genes for {run_name}: {len(common_genes)}")

        # Expression matrices restricted to the shared genes, densified if sparse
        sc_expr = scadata[:, common_genes].X
        st_expr = stadata[:, common_genes].X
        if hasattr(sc_expr, 'toarray'):
            sc_expr = sc_expr.toarray()
        if hasattr(st_expr, 'toarray'):
            st_expr = st_expr.toarray()

        st_coords = stadata.obsm['spatial']

        print(f"SC data shape: {sc_expr.shape}")
        print(f"ST data shape: {st_expr.shape}")
        print(f"ST coords shape: {st_coords.shape}")

        # Different (but fixed) random seed for each run
        torch.manual_seed(42 + i)
        np.random.seed(42 + i)

        model = AdvancedHierarchicalDiffusion(
            st_gene_expr=st_expr,
            st_coords=st_coords,
            sc_gene_expr=sc_expr,
            cell_types_sc=cell_types,
            transport_plan=None,
            D_st=None,
            D_induced=None,
            n_genes=len(common_genes),
            n_embedding=[512, 256, 128],
            coord_space_diameter=2.00,
            sigma=0.75,
            alpha=0.8,
            mmdbatch=1000,
            batch_size=256,
            device=device,
            lr_e=0.002,
            lr_d=0.0002,
            n_timesteps=300,
            n_denoising_blocks=4,
            hidden_dim=256,
            num_heads=6,
            num_hierarchical_scales=3,
            dp=0.2,
            outf=f'advanced_diffusion_{run_name}'
        )

        print(f"Training model for {run_name}...")
        model.train(
            encoder_epochs=1201,
            vae_epochs=3001,
            diffusion_epochs=5001, p_drop_max=0.15
        )

        # Angular reference frame built from the model's normalised ST coordinates
        st_coords_raw = model.st_coords_norm.cpu().numpy()
        angular_frame = _build_canonical_angular_frame(st_coords_raw)

        # Sample SC coordinates in the normalised coordinate space.
        # Alternatives tried earlier: model.sample_sc_coordinates_batched(...),
        # model.sample_sc_coordinates_pure_diffusion(...).
        print(f"Generating SC coordinates using {run_name} model...")
        sc_coords = model.sample_sc_coordinates(
            batch_size=512,
            guidance_scale=10.0,
            return_normalized=True
        )
        # Convert once, up front: the plotting helpers and the averaging below do
        # numpy arithmetic and would fail on a (CUDA) tensor.
        if hasattr(sc_coords, 'detach'):
            sc_coords = sc_coords.detach().cpu().numpy()
        else:
            sc_coords = np.asarray(sc_coords)

        print(f"\n=== Generated SC Coordinates ({run_name}) ===")
        print(f"  X range: [{sc_coords[:, 0].min():.3f}, {sc_coords[:, 0].max():.3f}]")
        print(f"  Y range: [{sc_coords[:, 1].min():.3f}, {sc_coords[:, 1].max():.3f}]")

        sc_coords_results.append(sc_coords)
        models_all.append(model)

        # Per-run diagnostic: SC cells coloured by angle in the ST-derived frame
        _plot_sc_angle_analysis(sc_coords, cell_types,
                               angular_frame, st_coords_raw, run_name, i+1)

    # Comparative diagnostic across runs (frame and background come from the last run)
    _plot_comparative_sc_angle_analysis(sc_coords_results, cell_types,
                                       angular_frame, st_coords_raw)

    # Element-wise mean / std over the runs (same cells, same order in every run)
    sc_coords_avg = np.mean(sc_coords_results, axis=0)
    sc_coords_std = np.std(sc_coords_results, axis=0)

    scadata.obsm['advanced_diffusion_coords_avg'] = sc_coords_avg
    scadata.obsm['advanced_diffusion_coords_std'] = sc_coords_std

    for i, coords in enumerate(sc_coords_results):
        scadata.obsm[f'advanced_diffusion_coords_rep{i+1}'] = coords

    print("\nTraining complete. Results stored in scadata.obsm")
    return scadata, models_all

def _build_canonical_angular_frame(st_coords):
    """
    Build a polar reference frame from 2D ST coordinates.

    The origin is the centroid of ``st_coords``; the 0° direction points from the
    centroid to the spot farthest from it (first such spot on ties).

    Returns:
        dict with keys
          'centroid'       - centroid of the coordinates,
          'zero_direction' - vector from the centroid to the farthest spot,
          'farthest_idx'   - row index of that spot,
          'angle_fn'       - callable mapping an (n, 2) array (or a single point)
                             to angles in radians, mapped into [0, 2π).
    """
    import numpy as np

    centroid = st_coords.mean(axis=0)
    offsets = st_coords - centroid
    farthest_idx = np.argmax(np.linalg.norm(offsets, axis=1))
    a0 = offsets[farthest_idx]  # 0° direction vector

    def angle_fn(x):
        """Signed angle from the 0° direction, wrapped to be non-negative."""
        points = x.reshape(1, -1) if x.ndim == 1 else x
        rel = points - centroid
        ref_x, ref_y = a0[0], a0[1]
        sin_like = ref_x * rel[:, 1] - ref_y * rel[:, 0]  # z-component of the 2D cross product
        cos_like = ref_x * rel[:, 0] + ref_y * rel[:, 1]  # dot product
        theta = np.arctan2(sin_like, cos_like)
        return np.where(theta < 0, theta + 2*np.pi, theta)

    frame = {}
    frame['centroid'] = centroid
    frame['zero_direction'] = a0
    frame['farthest_idx'] = farthest_idx
    frame['angle_fn'] = angle_fn
    return frame

def _plot_sc_angle_analysis(sc_coords, cell_types, angular_frame, st_coords_bg, run_name, run_num):
    """
    Diagnostic figure for one run: SC cells coloured by their angle in the ST frame.

    Left panel: ST spots (black) under the SC cells coloured by angle, with the frame
    centroid and 0° direction marked. Right panel: per-cell-type angle histograms.
    The figure is saved as ``sc_angle_analysis_{run_name}.png`` and shown; the
    circular mean angle of every cell type with more than 5 cells is printed.

    Args:
        sc_coords: (n_cells, 2) SC coordinates.
        cell_types: per-cell labels, aligned with ``sc_coords``.
        angular_frame: dict produced by ``_build_canonical_angular_frame``.
        st_coords_bg: (n_spots, 2) ST coordinates drawn as background.
        run_name: label used in titles, the file name and the printout.
        run_num: run number (currently unused).
    """
    import matplotlib.pyplot as plt
    import numpy as np

    angles_rad = angular_frame['angle_fn'](sc_coords)
    angles_deg = np.degrees(angles_rad)
    origin = angular_frame['centroid']
    zero_vec = angular_frame['zero_direction']

    type_names = np.unique(cell_types)
    type_masks = [cell_types == type_name for type_name in type_names]

    fig, (ax_map, ax_hist) = plt.subplots(1, 2, figsize=(15, 6))

    # --- Left panel: spatial view ---
    ax_map.scatter(st_coords_bg[:, 0], st_coords_bg[:, 1],
                   c='black', s=100, alpha=0.8, label='ST outline')
    angle_scatter = ax_map.scatter(sc_coords[:, 0], sc_coords[:, 1],
                                   c=angles_deg, cmap='hsv', s=30, alpha=0.8)

    # Frame origin and 0° direction
    ax_map.scatter(origin[0], origin[1], c='black', s=100, marker='x', linewidth=3)
    ax_map.arrow(origin[0], origin[1], zero_vec[0]*0.3, zero_vec[1]*0.3,
                 head_width=0.05, head_length=0.05, fc='red', ec='red', linewidth=2)

    ax_map.set_title(f'{run_name}: SC Cells Colored by Angle θ')
    ax_map.set_xlabel('X coordinate')
    ax_map.set_ylabel('Y coordinate')
    ax_map.set_aspect('equal')
    ax_map.legend()

    plt.colorbar(angle_scatter, ax=ax_map).set_label('Angle (degrees)')

    # --- Right panel: angle histogram per cell type ---
    palette = plt.cm.tab10(np.linspace(0, 1, len(type_names)))
    for type_name, mask, colour in zip(type_names, type_masks, palette):
        if np.sum(mask) > 0:
            ax_hist.hist(angles_deg[mask], bins=36, alpha=0.6, label=type_name,
                         color=colour, density=True)

    ax_hist.set_title(f'{run_name}: Angle Distribution by Cell Type')
    ax_hist.set_xlabel('Angle (degrees)')
    ax_hist.set_ylabel('Density')
    ax_hist.set_xlim(0, 360)
    ax_hist.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.savefig(f'sc_angle_analysis_{run_name}.png', dpi=300, bbox_inches='tight')
    plt.show()

    # --- Circular mean per cell type (only types with more than 5 cells) ---
    print(f"\n{run_name} - Circular statistics per cell type:")
    for type_name, mask in zip(type_names, type_masks):
        n_cells = np.sum(mask)
        if not n_cells > 5:
            continue
        subset = angles_rad[mask]
        mean_direction = np.arctan2(np.mean(np.sin(subset)), np.mean(np.cos(subset)))
        if mean_direction < 0:
            mean_direction += 2*np.pi
        print(f"  {type_name}: mean={np.degrees(mean_direction):.1f}°, n={n_cells}")

def _plot_comparative_sc_angle_analysis(sc_coords_list, cell_types, angular_frame, st_coords_bg):
    """
    Compare the angular placement of SC cells across several runs.

    Draws a 2 x n_runs figure (top: SC cells coloured by angle over the ST
    background; bottom: per-cell-type angle histograms), saves it as
    ``comparative_sc_angle_analysis.png`` and prints, for every cell type with more
    than 10 cells, the circular mean angle in each run and the largest circular
    difference between any two runs. A difference above 30° is flagged as
    "sector sliding".

    Args:
        sc_coords_list: list of (n_cells, 2) arrays, one per run, same cell order.
        cell_types: per-cell labels aligned with the coordinate arrays.
        angular_frame: dict produced by ``_build_canonical_angular_frame``; the same
            frame is applied to every run.
        st_coords_bg: (n_spots, 2) ST coordinates drawn as background.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    n_runs = len(sc_coords_list)
    unique_types = np.unique(cell_types)

    # Angles depend only on the run, so compute them once per run instead of once
    # per (cell type, run) pair.
    angle_fn = angular_frame['angle_fn']
    angles_rad_per_run = [angle_fn(sc_coords) for sc_coords in sc_coords_list]
    angles_deg_per_run = [np.degrees(angles) for angles in angles_rad_per_run]

    fig, axes = plt.subplots(2, n_runs, figsize=(5*n_runs, 10))
    if n_runs == 1:
        # plt.subplots returns a 1D array for a single column; keep 2D indexing valid
        axes = axes.reshape(-1, 1)

    # Top row: SC scatter plots per run
    for i, sc_coords in enumerate(sc_coords_list):
        ax = axes[0, i]

        ax.scatter(st_coords_bg[:, 0], st_coords_bg[:, 1],
                   c='black', s=20, alpha=0.8)
        scatter = ax.scatter(sc_coords[:, 0], sc_coords[:, 1],
                             c=angles_deg_per_run[i], cmap='hsv', s=20, alpha=0.1)

        ax.set_title(f'Run {i+1}: SC Cells by Angle')
        ax.set_aspect('equal')

        if i == n_runs-1:  # colorbar only on the last panel
            cbar = plt.colorbar(scatter, ax=ax)
            cbar.set_label('Angle (degrees)')

    # Bottom row: cell type angle distributions per run
    colors = plt.cm.tab10(np.linspace(0, 1, len(unique_types)))

    for i in range(n_runs):
        ax = axes[1, i]

        for j, cell_type in enumerate(unique_types):
            mask = cell_types == cell_type
            if np.sum(mask) > 5:
                ax.hist(angles_deg_per_run[i][mask], bins=36, alpha=0.6,
                        label=cell_type if i == 0 else "",
                        color=colors[j], density=True)

        ax.set_title(f'Run {i+1}: Cell Type Angles')
        ax.set_xlabel('Angle (degrees)')
        ax.set_ylabel('Density')
        ax.set_xlim(0, 360)

        if i == 0:
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.savefig('comparative_sc_angle_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Check for sector sliding
    print(f"\n" + "="*60)
    print("SECTOR SLIDING ANALYSIS")
    print("="*60)

    for cell_type in unique_types:
        mask = cell_types == cell_type
        if np.sum(mask) > 10:  # Only analyze cell types with enough cells
            circular_means = []

            for angles_rad in angles_rad_per_run:
                angles_subset = angles_rad[mask]

                # Circular mean, wrapped into [0, 2π)
                mean_cos = np.mean(np.cos(angles_subset))
                mean_sin = np.mean(np.sin(angles_subset))
                circular_mean = np.arctan2(mean_sin, mean_cos)
                if circular_mean < 0:
                    circular_mean += 2*np.pi

                circular_means.append(np.degrees(circular_mean))

            # Largest circular distance between any two runs. Taking max - min and
            # folding it at 180° is only right for two runs: e.g. means of
            # 10°, 180°, 350° would be reported as 20° apart instead of 170°.
            means_deg = np.asarray(circular_means, dtype=float)
            pairwise = np.abs(means_deg[:, None] - means_deg[None, :])
            pairwise = np.minimum(pairwise, 360.0 - pairwise)
            max_diff = float(pairwise.max())

            print(f"{cell_type}:")
            print(f"  Run means: {[f'{m:.1f}°' for m in circular_means]}")
            print(f"  Max difference: {max_diff:.1f}°")

            if max_diff > 30:  # Significant sliding
                print(f"  ⚠️  SECTOR SLIDING DETECTED!")
            else:
                print(f"  ✅ Consistent placement")


In [None]:
# Load and process data (re-reads the patient-10 files, so this cell does not depend
# on AnnData objects modified by earlier cells)
scadata_p10, stadata1_p10, stadata2_p10, stadata3_p10 = load_and_process_cscc_data_p10()

# Train individual AdvancedHierarchicalDiffusion models and get averaged results.
# Coordinates are written to scadata_p10.obsm under 'advanced_diffusion_coords_avg',
# '..._std' and '..._rep1'..'..._rep3'; the trained models are kept in advanced_models_p10.
scadata_p10, advanced_models_p10 = train_individual_advanced_diffusion_models(
    scadata_p10, stadata1_p10, stadata2_p10, stadata3_p10
)

In [None]:
import numpy as np

# --- Rigid Procrustes (no scaling, no reflection) ---
def procrustes_align(X_ref, Y, allow_scaling=False, allow_reflection=False):
    """
    Procrustes alignment of ``Y`` onto ``X_ref``; rigid (rotation + translation) by default.

    Finds an orthogonal matrix ``R``, a scale ``s`` and a translation ``t`` such
    that ``s * (Y @ R) + t`` is as close as possible to ``X_ref`` in the
    least-squares sense.

    Args:
        X_ref: (n, d) reference points.
        Y: (n, d) points to align; rows must be the same cells in the same order.
        allow_scaling: if False (default) the scale is fixed to 1.0.
        allow_reflection: if False (default) ``R`` is constrained to a proper
            rotation (det(R) > 0).

    Returns:
        (Y_aligned, R, s, t, rmsd) where ``Y_aligned = s * (Y @ R) + t`` and
        ``rmsd`` is the root-mean-square distance between ``X_ref`` and ``Y_aligned``.

    Raises:
        ValueError: if the two inputs do not have the same shape.
    """
    X = np.asarray(X_ref, float)
    Y = np.asarray(Y, float)
    if X.shape != Y.shape:
        raise ValueError("X_ref and Y must have the same shape (same cells, same order).")
    n = X.shape[0]

    x_mean = X.mean(axis=0, keepdims=True)
    y_mean = Y.mean(axis=0, keepdims=True)
    Xc = X - x_mean
    Yc = Y - y_mean

    # Cross-covariance; its SVD gives the optimal orthogonal map R = U diag(signs) Vt
    C = Yc.T @ Xc / n
    U, S, Vt = np.linalg.svd(C)

    signs = np.ones_like(S)
    if (not allow_reflection) and (np.linalg.det(U @ Vt) < 0):
        # Forbid the mirror solution by flipping the weakest singular direction
        signs[-1] = -1.0
    R = (U * signs) @ Vt

    if allow_scaling:
        y_var = np.linalg.norm(Yc) ** 2 / n
        # The optimal scale is trace(diag(signs) S) / var(Y): the sign flip must be
        # included, otherwise the scale is overestimated when a reflection was removed.
        # A degenerate Y (all points identical) has no defined scale; fall back to 1.
        s = (S * signs).sum() / y_var if y_var > 0 else 1.0
    else:
        s = 1.0

    Y_aligned = s * (Yc @ R) + x_mean
    t = X.mean(axis=0) - s * (Y.mean(axis=0) @ R)
    rmsd = np.sqrt(((X - Y_aligned) ** 2).sum() / n)
    return Y_aligned, R, s, t, rmsd

# --- Fetch reps (as numpy) ---
# Per-run SC coordinates stored in scadata_p10.obsm under '..._rep1' .. '..._rep3'
rep1 = np.asarray(scadata_p10.obsm['advanced_diffusion_coords_rep1'], float)
rep2 = np.asarray(scadata_p10.obsm['advanced_diffusion_coords_rep2'], float)
rep3 = np.asarray(scadata_p10.obsm['advanced_diffusion_coords_rep3'], float)

# Sanity checks: alignment needs identical shapes and finite values
assert rep1.shape == rep2.shape == rep3.shape, "rep shapes differ"
assert not np.isnan(rep1).any() and not np.isnan(rep2).any() and not np.isnan(rep3).any(), "NaNs present"

# --- Align rep2, rep3 to rep1 (rigid: rotation+translation only) ---
# rep1 is the reference and is left untouched
rep2_aln, R2, s2, t2, rmsd2 = procrustes_align(rep1, rep2, allow_scaling=False, allow_reflection=False)
rep3_aln, R3, s3, t3, rmsd3 = procrustes_align(rep1, rep3, allow_scaling=False, allow_reflection=False)

# NOTE(review): with allow_reflection=False the returned R is forced to a proper
# rotation, so the 'reflection' field below is expected to always print 'no'.
print(f"rep2→rep1  scale={s2:.4f} (fixed to 1.0), reflection={'yes' if np.linalg.det(R2)<0 else 'no'}, RMSD={rmsd2:.6f}")
print(f"rep3→rep1  scale={s3:.4f} (fixed to 1.0), reflection={'yes' if np.linalg.det(R3)<0 else 'no'}, RMSD={rmsd3:.6f}")

# --- Save aligned reps back and recompute average/std ---
# The '_aligned' entry for rep1 is the unmodified reference
scadata_p10.obsm['advanced_diffusion_coords_rep1_aligned'] = rep1
scadata_p10.obsm['advanced_diffusion_coords_rep2_aligned'] = rep2_aln
scadata_p10.obsm['advanced_diffusion_coords_rep3_aligned'] = rep3_aln

stack = np.stack([rep1, rep2_aln, rep3_aln], axis=0)  # (3, n_cells, 2)
scadata_p10.obsm['advanced_diffusion_coords_avg_rigid'] = stack.mean(axis=0)
scadata_p10.obsm['advanced_diffusion_coords_std_rigid'] = stack.std(axis=0)

# Optional: keep transforms if you want to reapply later
# (Y_aligned = s * (Y @ R) + t for each replicate)
scadata_p10.uns['advanced_diffusion_rigid_transforms'] = {
    'rep2': {'R': R2, 's': float(s2), 't': t2, 'rmsd': float(rmsd2)},
    'rep3': {'R': R3, 's': float(s3), 't': t3, 'rmsd': float(rmsd3)},
}

print("Aligned coords stored in:")
print("  obsm['advanced_diffusion_coords_rep2_aligned']")
print("  obsm['advanced_diffusion_coords_rep3_aligned']")
print("Averaged coords (rigid) in:")
print("  obsm['advanced_diffusion_coords_avg_rigid'], std in 'advanced_diffusion_coords_std_rigid'")


In [None]:
# Visualize results
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl
import scanpy as sc

mpl.rcParams['figure.figsize'] = (4, 4)

my_tab20 = sns.color_palette("tab20", n_colors=20).as_hex()

# scanpy opens its own figure unless it is given an Axes, so a bare
# plt.figure(figsize=...) before the call only leaves an empty figure behind and the
# requested size is ignored. Create the Axes explicitly and pass it in instead.

# Plot 1: Averaged coordinates
fig, ax = plt.subplots(figsize=(8, 6))
sc.pl.embedding(scadata_p10, basis='advanced_diffusion_coords_avg', color='rough_celltype',
                size=85, title='SC Advanced Diffusion Coords (Averaged)',
                palette=my_tab20, legend_loc='right margin', legend_fontsize=10,
                ax=ax, show=False)
plt.show()

# Plot 2: Individual model results (one figure per trained model)
for i in range(3):
    fig, ax = plt.subplots(figsize=(6, 5))
    sc.pl.embedding(scadata_p10, basis=f'advanced_diffusion_coords_rep{i+1}', color='rough_celltype',
                    size=85, title=f'SC Coordinates (Advanced Model {i+1})',
                    palette=my_tab20, legend_loc='right margin', legend_fontsize=10,
                    ax=ax, show=False)
    plt.show()

In [None]:
import os
import warnings
import matplotlib as mpl
import matplotlib.pyplot as plt
import scanpy as sc

import seaborn as sns
# One palette colour per coarse cell type
n_groups = scadata_p10.obs["rough_celltype"].nunique()
my_tab20 = sns.color_palette("tab20", n_colors=n_groups).as_hex()

# ---------- user preferences ----------
sc.settings.set_figure_params(format='svg')      # keep default SVG
mpl.rcParams['figure.figsize'] = (6, 6)          # default figsize (overridden for this figure below)
os.makedirs("figures", exist_ok=True)            # create the output folder on the fly
# NOTE(review): the file name says "rep1" but the plot uses the averaged coordinates.
outpath = "figures/P10_rep1_v1.svg"              # final save path
dpi_save = 600                                   # desired DPI for export
# --------------------------------------

# Silence FutureWarnings (e.g. scanpy's hint about the squidpy alternative).
# This filter stays active for the rest of the session.
warnings.filterwarnings("ignore", category=FutureWarning)

# Plot with scanpy but do not show/save yet, so the legend can be customised first
sc.pl.spatial(
    scadata_p10,
    color="rough_celltype",
    spot_size=0.03,
    show=False,
    basis='advanced_diffusion_coords_avg',
    title='reconstructed',
    save=None,
    palette=my_tab20
)

# Current Axes & Figure created by scanpy
ax = plt.gca()
fig = ax.get_figure()

# Widen the figure so a single-row legend fits below the plot
fig.set_size_inches(10, 6)

# --- Build a clean horizontal legend BELOW the plot ---
# Remove any legend scanpy created
old_leg = ax.get_legend()
if old_leg:
    old_leg.remove()

# Handles & labels from the labelled artists on the axes
handles, labels = ax.get_legend_handles_labels()

# loc='upper center' + bbox_to_anchor=(0.5, negative y) centres the legend under the
# axes; all entries go in one row. max(1, ...) keeps ncol valid when nothing is labelled.
ncol = max(1, len(labels))
legend = ax.legend(
    handles,
    labels,
    loc='upper center',
    bbox_to_anchor=(0.5, -0.13),
    ncol=ncol,
    frameon=False,
    handlelength=2.5,
    columnspacing=1.0,
    fontsize=14
)

import matplotlib.collections

# Enlarge the legend markers. Matplotlib 3.7 renamed Legend.legendHandles to
# Legend.legend_handles and 3.9 removed the old name, so prefer the new attribute
# and fall back only on older versions.
legend_handles = getattr(legend, "legend_handles", None)
if legend_handles is None:
    legend_handles = legend.legendHandles
for h in legend_handles:
    if isinstance(h, matplotlib.collections.PathCollection):
        # scatter handles: size is in points^2
        h.set_sizes([200.0])
    elif hasattr(h, "set_markersize"):
        # Line2D handles (if any)
        h.set_markersize(10)

# Reserve space at the bottom so the legend is not clipped
fig.subplots_adjust(bottom=0.22)

# --- Save the figure to the folder as SVG at 600 DPI ---
fig.savefig(outpath, format="svg", dpi=dpi_save, bbox_inches="tight")

# Show the final adjusted figure in the notebook
plt.show()


In [None]:
# Highlight TSK cells on the averaged reconstructed coordinates of patient 10.
# (Optional, currently disabled: embed fonts as TrueType for PDF/PS export)
# rcParams['pdf.fonttype'] = 42
# rcParams['ps.fonttype'] = 42
# figsize(4,4)
mpl.rcParams['figure.figsize'] = (5, 5)
# Patient-2 equivalent of this plot, kept for reference:
# sc.pl.spatial(scadata,color="level3_celltype",groups=["TSK"],spot_size=0.06, show=True,basis='advanced_diffusion_coords_rep3',title='reconstructed',na_in_legend=False, save='P2_rep3_tsk_v2')
sc.pl.spatial(
    scadata_p10,
    color="level3_celltype",
    groups=["TSK"],
    spot_size=0.06,
    show=False,  # <- Important: prevent auto-show
    basis='advanced_diffusion_coords_avg',
    title='reconstructed',
    na_in_legend=False
)

# Save manually with high resolution
# NOTE(review): the file name says "rep2" but the plotted basis is the averaged
# coordinates ('advanced_diffusion_coords_avg') — confirm which one is intended.
plt.savefig("P10_rep2_tsk_v10.svg", dpi=600, bbox_inches="tight")
plt.show()
#save='TSK',

In [None]:
# 0/1 indicator columns so that individual populations can be shown with a
# diverging colormap next to the full `level3_celltype` annotation.
scadata_p10.obs['selection'] = (scadata_p10.obs['level2_celltype']=='TSK').astype(int)
scadata_p10.obs['selection2'] = (scadata_p10.obs['level1_celltype']=='Fibroblast').astype(int)
scadata_p10.obs['selection3'] = (scadata_p10.obs['rough_celltype']=='Epithelial').astype(int)

# sc.pl.spatial builds its own figure, so a preceding `plt.figure(figsize=...)`
# only produced an empty extra figure and had no effect on the plot. Apply the
# intended per-panel size through a temporary rc context instead, which also
# leaves the global rcParams untouched for later cells.
with plt.rc_context({'figure.figsize': (6, 6)}):
    sc.pl.spatial(scadata_p10, color=['selection','selection2','selection3','level3_celltype'], spot_size=0.025,cmap='bwr',basis='advanced_diffusion_coords_avg')

In [None]:
import squidpy as sq
# --- Neighbourhood statistics on the full P10 reconstruction ---
# The spatial graph is built on the averaged reconstructed coordinates.
coords_key = 'advanced_diffusion_coords_avg'
sq.gr.spatial_neighbors(scadata_p10, spatial_key=coords_key)
for graph_statistic in (sq.gr.nhood_enrichment, sq.gr.interaction_matrix):
    graph_statistic(scadata_p10, cluster_key='rough_celltype')

# --- Same enrichment analysis restricted to four level-2 populations ---
tumor_kc_labels = ['Tumor_KC_Cyc', 'Tumor_KC_Basal', 'Tumor_KC_Diff', 'TSK']
in_tumor_kc = scadata_p10.obs.level2_celltype.isin(tumor_kc_labels)
kscadata_p10 = scadata_p10[in_tumor_kc].copy()

# The graph is rebuilt on the subset so that neighbours are taken among the
# retained cells only.
sq.gr.spatial_neighbors(kscadata_p10, spatial_key=coords_key)
sq.gr.nhood_enrichment(kscadata_p10, cluster_key='level2_celltype')

# Draw the enrichment heatmap on an explicit axes so the y-label can be cleared.
fig, ax = plt.subplots(figsize=(3, 5))
sq.pl.nhood_enrichment(
    kscadata_p10,
    cluster_key="level2_celltype",
    cmap='coolwarm',
    ax=ax,
)
ax.set_ylabel('')
# Uncomment to export the heatmap:
# plt.savefig('TSKKC_new_best_p10.svg')
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import cdist
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd

# Read the three P10 spatial replicates (in order rep1, rep2, rep3).
stadata1_p10, stadata2_p10, stadata3_p10 = (
    sc.read_h5ad(h5ad_path)
    for h5ad_path in (
        '/home/ehtesamul/sc_st/data/cSCC/processed/stP10rep1.h5ad',
        '/home/ehtesamul/sc_st/data/cSCC/processed/stP10rep2.h5ad',
        '/home/ehtesamul/sc_st/data/cSCC/processed/stP10rep3.h5ad',
    )
)

# Parallel lists: replicate objects and the labels used in printouts and plots.
datasets = [stadata1_p10, stadata2_p10, stadata3_p10]
names = ['ST_P10_Rep1', 'ST_P10_Rep2', 'ST_P10_Rep3']

# Per-replicate summary: matrix dimensions and the extent of obsm['spatial']
# along its first two columns.
print("Dataset Basic Info:")
for replicate, label in zip(datasets, names):
    n_spots, n_genes = replicate.shape
    xs = replicate.obsm['spatial'][:, 0]
    ys = replicate.obsm['spatial'][:, 1]
    print(f"{label}: {n_spots} spots, {n_genes} genes")
    print(f"  Spatial coords range: X[{xs.min():.2f}, {xs.max():.2f}], Y[{ys.min():.2f}, {ys.max():.2f}]")

# 2x2 overview: one panel per replicate plus a combined overlay (bottom-right).
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Replicates fill the top row first, then the bottom-left panel.
replicate_panels = (axes[0, 0], axes[0, 1], axes[1, 0])
for panel, replicate, label in zip(replicate_panels, datasets, names):
    xy = replicate.obsm['spatial']
    panel.scatter(xy[:, 0], xy[:, 1], alpha=0.6, s=20)
    panel.set(
        title=f'{label}\n{replicate.shape[0]} spots',
        xlabel='X coordinate',
        ylabel='Y coordinate',
    )

# All replicates on shared axes, one colour each.
overlay_panel = axes[1, 1]
colors = ['red', 'blue', 'green']
for replicate, label, color in zip(datasets, names, colors):
    xy = replicate.obsm['spatial']
    overlay_panel.scatter(xy[:, 0], xy[:, 1], alpha=0.5, s=15, c=color, label=label)
overlay_panel.set(
    title='All Datasets Overlay',
    xlabel='X coordinate',
    ylabel='Y coordinate',
)
overlay_panel.legend()

plt.tight_layout()
plt.show()

# Genes present in every replicate, sorted so the column order is reproducible.
all_genes = [set(replicate.var_names) for replicate in datasets]
common_genes = sorted(set.intersection(*all_genes))
print(f"\nCommon genes across all datasets: {len(common_genes)}")

# For every unordered pair of replicates, measure how close each spot of the
# first replicate lies to its nearest spot in the second one.
print("\nCoordinate Overlap Analysis:")
tolerance = 1.0  # a spot counts as overlapping when its nearest neighbour is closer than this

n_replicates = len(datasets)
replicate_pairs = [
    (first, second)
    for first in range(n_replicates)
    for second in range(first + 1, n_replicates)
]

for first, second in replicate_pairs:
    xy_first = datasets[first].obsm['spatial']
    xy_second = datasets[second].obsm['spatial']

    # Full pairwise distance matrix, reduced to the nearest-neighbour distance per row.
    nearest_dist = cdist(xy_first, xy_second).min(axis=1)

    n_first = len(xy_first)
    n_within = np.count_nonzero(nearest_dist < tolerance)

    print(f"{names[first]} vs {names[second]}:")
    print(f"  Spots within {tolerance} units: {n_within}/{n_first} ({n_within/n_first*100:.1f}%)")
    print(f"  Mean min distance: {nearest_dist.mean():.2f}")

# Gene expression similarity for closest spots.
# For every unordered pair of replicates, each spot of the first replicate is
# matched to its spatially nearest spot in the second, and the correlation
# (np.corrcoef) of their expression over the common genes is summarised.
print("\nGene Expression Similarity (for closest coordinate pairs):")

for i in range(len(datasets)):
    for j in range(i+1, len(datasets)):
        # Restrict both replicates to the shared genes, in the same column order
        expr_i = datasets[i][:, common_genes].X
        expr_j = datasets[j][:, common_genes].X

        # Densify sparse matrices so rows can be passed to np.corrcoef
        if hasattr(expr_i, 'toarray'):
            expr_i = expr_i.toarray()
        if hasattr(expr_j, 'toarray'):
            expr_j = expr_j.toarray()

        coords_i = datasets[i].obsm['spatial']
        coords_j = datasets[j].obsm['spatial']

        # Index of the nearest spot in replicate j for every spot in replicate i
        closest_j_indices = np.argmin(cdist(coords_i, coords_j), axis=1)

        # One correlation per matched spot pair; np.corrcoef returns NaN when a
        # profile has zero variance, and those pairs are dropped.
        correlations = np.array([
            np.corrcoef(expr_i[spot_i], expr_j[closest_j])[0, 1]
            for spot_i, closest_j in enumerate(closest_j_indices)
        ], dtype=float)
        correlations = correlations[~np.isnan(correlations)]

        print(f"{names[i]} vs {names[j]}:")
        if correlations.size == 0:
            # Previously this case ended in a ZeroDivisionError in the percentage below.
            print("  No valid correlations (all matched spot pairs were undefined)")
            continue

        n_high = int(np.sum(correlations > 0.5))
        print(f"  Mean gene expression correlation: {np.mean(correlations):.4f}")
        print(f"  Median correlation: {np.median(correlations):.4f}")
        print(f"  Correlations > 0.5: {n_high}/{len(correlations)} ({n_high/len(correlations)*100:.1f}%)")

# One histogram per replicate pair showing the distribution of nearest-spot
# distances, with the overlap tolerance marked as a dashed red line.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

n_replicates = len(datasets)
replicate_pairs = [
    (first, second)
    for first in range(n_replicates)
    for second in range(first + 1, n_replicates)
]

for pair_idx, (first, second) in enumerate(replicate_pairs):
    panel = axes[pair_idx]

    xy_first = datasets[first].obsm['spatial']
    xy_second = datasets[second].obsm['spatial']

    # Distance from each spot in the first replicate to its nearest spot in the second
    nearest_dist = cdist(xy_first, xy_second).min(axis=1)

    panel.hist(nearest_dist, bins=50, alpha=0.7)
    panel.set_title(f'{names[first]} vs {names[second]}\nMin Distance Distribution')
    panel.set_xlabel('Distance to closest spot')
    panel.set_ylabel('Frequency')
    panel.axvline(tolerance, color='red', linestyle='--', label=f'Tolerance={tolerance}')
    panel.legend()

plt.tight_layout()
plt.show()