# Sparse-MENDER STARMAP Prelimbic Area Experiments

### Import Dependencies

In [1]:
import warnings
warnings.filterwarnings("ignore")

import os
import json
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import sys
sys.path.append("..")
from smender.data_loader import DataLoader
from smender.SMENDER import SMENDER
from smender.utils import compute_NMI, compute_ARI, compute_PAS, compute_CHAOS
from ann.AnnoyANN import AnnoyANN
from ann.HNSWANN import HNSWANN

### Define Configuration, Constants and Directories

In [2]:
# Configuration
ANN_TYPE = 'none'  # Options: 'annoy', 'hnsw', 'none'
DIM_REDUCTION = 'pca'  # Options: 'pca', 'nmf', 'ica', 'fa'

# Configuration for plots and results
ANN_TYPE_STR = 'none' # Options: 'Annoy', 'HNSW', 'ENN'
DIM_REDUCTION_STR = 'PCA' # Options: 'PCA', 'NMF', 'ICA', 'FA'
DATASET_STR = "STARMap Prelimbic Area"

# Validate ANN_TYPE
ANN_MAP = {
    'annoy': AnnoyANN,
    'hnsw': HNSWANN,
    'none': None
}
if ANN_TYPE not in ANN_MAP:
    raise ValueError(f"Invalid ANN_TYPE: {ANN_TYPE}. Choose from {list(ANN_MAP.keys())}")
SELECTED_ANN = ANN_MAP[ANN_TYPE]

# Validate DIM_REDUCTION
VALID_DIM_REDUCTIONS = ['pca', 'nmf', 'ica', 'fa']
if DIM_REDUCTION not in VALID_DIM_REDUCTIONS:
    raise ValueError(f"Invalid DIM_REDUCTION: {DIM_REDUCTION}. Choose from {VALID_DIM_REDUCTIONS}")

# Define dataset
dataset_name = "Dataset11_MS_raw"

# Multiprocessing for final SMENDER run
MP_PROCESSES = 4  # Adjust based on CPU cores

# Set random seeds for reproducibility
seed = 100
np.random.seed(seed)
sc.settings.verbosity = 1

# Define directories
result_dir = os.path.join(os.path.pardir, "results", "starmap", ANN_TYPE, DIM_REDUCTION)
plots_result_dir = os.path.join(os.path.pardir, "plots", "starmap", ANN_TYPE, DIM_REDUCTION)
os.makedirs(result_dir, exist_ok=True)
os.makedirs(plots_result_dir, exist_ok=True)

# Define batch and ground truth keys
ground_truth_key = 'gt'
batch_obs = 'slice_id'

### Load the STARMAP Dataset

In [3]:
# Load STARMAP Dataset from SODB
loader = DataLoader(dataset_name)
adata_dict = loader.load()
adata = list(adata_dict.values())[0]

Loading dataset: Dataset11_MS_raw
load experiment[Dataset11] in dataset[Dataset11_MS_raw]
Dataset loaded successfully.


### Prepare Dictionaries to Store Final Results 

In [4]:
results_dict = {}
scores_dict = {}
adata_list = []

### Run SMENDER

In [5]:
# Prepare input data
print("Preparing input data...")
adata_raw = adata.copy()
adata = adata_raw.copy()
# Estimate number of domains
n_cls = np.unique(adata.obs[ground_truth_key]).shape[0]

# Run SMENDER
print("Running SMENDER...")
smender = SMENDER(
    adata,
    batch_obs=batch_obs,
    ct_obs='ct',
    random_seed=seed,
    verbose=0,
    ann=SELECTED_ANN,
    dim_reduction=DIM_REDUCTION
)

# Start tracking time and memory
smender.start_smender_timing()
smender.start_smender_memory()
smender.start_dim_reduction_timing()
smender.start_dim_reduction_memory()
smender.start_nn_timing()
smender.start_nn_memory()

print("Setting SMENDER parameters...")
smender.prepare()
smender.set_MENDER_para(
    nn_mode='radius',
    nn_para=150,
    n_scales=6
)

print("Extracting multi-scale context representation...")
smender.run_representation_mp(mp=MP_PROCESSES)

print("Running clustering...")
smender.run_clustering_normal(n_cls)

# Stop tracking time and memory
performance_metrics = {
    'smender_time': smender.stop_smender_timing(),
    'smender_memory': smender.stop_smender_memory(),
    'dim_reduction_time': smender.stop_dim_reduction_timing(),
    'dim_reduction_memory': smender.stop_dim_reduction_memory(),
    'nn_time': smender.stop_nn_timing(),
    'nn_memory': smender.stop_nn_memory()
}
print(f"\nPerformance Metrics:\n{performance_metrics}")

# Transfer clusters
pred_key = "smender_clusters"
adata_raw.obs[pred_key] = smender.adata_MENDER.obs['MENDER'].astype('category')

Preparing input data...
Running SMENDER...
Setting SMENDER parameters...
Extracting multi-scale context representation...
default number of processes is 200
Concatenation failed: "Values ['69x4486-0', '93x1063-0', '143x3445-0', '88x4092-0', '120x4293-0', '104x2648-0', '105x2425-0', '146x5387-0', '107x3104-0', '127x3822-0', '259x4507-0', '210x2935-0', '241x3271-0', '239x2631-0', '281x3080-0', '224x3566-0', '221x4872-0', '290x5270-0', '276x5435-0', '321x3949-0', '339x2790-0', '317x5119-0', '342x5600-0', '352x2563-0', '372x1693-0', '311x2357-0', '385x5503-0', '396x4972-0', '388x5349-0', '402x1331-0', '409x4543-0', '413x5757-0', '393x4808-0', '446x5944-0', '462x5091-0', '482x5413-0', '482x5581-0', '461x3050-0', '471x5246-0', '526x5768-0', '555x1786-0', '563x5971-0', '526x6243-0', '548x6495-0', '554x6386-0', '558x5137-0', '561x6118-0', '581x4471-0', '599x5622-0', '604x5412-0', '598x6567-0', '578x286-0', '616x5239-0', '652x6228-0', '705x1715-0', '732x5330-0', '835x5264-0', '871x1075-0', '847

### Compute Metrics

In [6]:
# Function to compute metrics
def compute_metrics(adata, ground_truth_key, cluster_key):
    if ground_truth_key not in adata.obs or cluster_key not in adata.obs:
        raise KeyError(f"One or both keys ({ground_truth_key}, {cluster_key}) not found in adata.obs")
    if adata.obs[ground_truth_key].isna().any() or adata.obs[cluster_key].isna().any():
        raise ValueError(f"NaN values found in {ground_truth_key} or {cluster_key}. Handle NaN values before computing metrics.")
    return {
        'NMI': compute_NMI(adata, ground_truth_key, cluster_key),
        'ARI': compute_ARI(adata, ground_truth_key, cluster_key),
        'PAS': compute_PAS(adata, cluster_key),
        'CHAOS': compute_CHAOS(adata, cluster_key)
    }

# Compute overall metrics
print("Computing metrics...")
final_scores = compute_metrics(adata_raw, ground_truth_key, pred_key)
print(f"\nFinal Metrics:\n{final_scores}")

Computing metrics...


ValueError: NaN values found in gt or smender_clusters. Handle NaN values before computing metrics.

### Visualize

In [None]:
# Function to plot UMAP
def plot_umap(adata, title, color_key, prefix, save_path=None):
    _, ax = plt.subplots()
    main_title = f"{prefix} - {title}"
    subtitle = f'\n{DATASET_STR} - {ANN_TYPE_STR} + {DIM_REDUCTION_STR}'
    sc.pl.umap(adata, color=color_key, title=main_title + subtitle, ax=ax, show=False)
    if save_path:
        plt.savefig(save_path, dpi=200, bbox_inches='tight')
        plt.close()

: 

: 

In [None]:
# Plot UMAP for ground truth
final_umap_gt_path = os.path.join(plots_result_dir, f"final_ground_truth_umap_{DATASET_STR}_{ANN_TYPE_STR}_{DIM_REDUCTION_STR}.png")
plot_umap(adata_raw, "All Samples", ground_truth_key, "Ground Truth", save_path=final_umap_gt_path)

: 

: 

In [None]:
# Plot UMAP for SMENDER clusters
final_umap_smender_path = os.path.join(plots_result_dir, f"final_smender_clusters_umap_{DATASET_STR}_{ANN_TYPE}_{DIM_REDUCTION}.png")
plot_umap(adata_raw, "All Samples", pred_key, "SMENDER", save_path=final_umap_smender_path)

: 

: 

In [None]:
# Spatial plots
print("Generating spatial plots...")
def output_cluster_all_modified(smender, obs='MENDER', obs_gt=ground_truth_key, dirname=plots_result_dir):
    smender.adata_MENDER.obs[smender.batch_obs] = smender.adata_MENDER.obs[smender.batch_obs].astype('category')

    metrics_dict = {}
    for si in smender.adata_MENDER.obs[smender.batch_obs].cat.categories:
        cur_a = smender.adata_MENDER[smender.adata_MENDER.obs[smender.batch_obs] == si]
        main_title = si
        subsubtitle = f'{DIM_REDUCTION_STR} - {ANN_TYPE_STR}'
        metrics = {}

        if obs_gt and obs_gt in cur_a.obs:
            nmi = np.round(compute_NMI(cur_a, obs_gt, obs), 3)
            ari = np.round(compute_ARI(cur_a, obs_gt, obs), 3)
            pas = np.round(compute_PAS(cur_a, obs), 3)
            chaos = np.round(compute_CHAOS(cur_a, obs), 3)
            subtitle = f'NMI: {nmi}  ARI: {ari}  PAS: {pas}  CHAOS: {chaos}'
            metrics.update({'NMI': float(nmi), 'ARI': float(ari)})
        else:
            pas = np.round(compute_PAS(cur_a, obs), 3)
            chaos = np.round(compute_CHAOS(cur_a, obs), 3)
            subtitle = f'PAS: {pas}  CHAOS: {chaos}'

        metrics.update({'PAS': float(pas), 'CHAOS': float(chaos)})
        metrics_dict[si] = metrics

        fig, ax = plt.subplots()
        sc.pl.embedding(cur_a, basis='spatial', color=obs, ax=ax, show=False)
        ax.axis('equal')
        fig.suptitle(main_title, fontsize=12, y=1.02)
        ax.set_title(subtitle, fontsize=10, pad=20)
        ax.text(0.5, 1.04, subsubtitle, transform=ax.transAxes, fontsize=8, ha='center', va='center')
        save_path = os.path.join(dirname, f"spatial_{si}_{obs}_{ANN_TYPE}_{DIM_REDUCTION}.png")
        plt.savefig(save_path, dpi=200, bbox_inches='tight')
        plt.close()

    # Save metrics to JSON
    metrics_path = os.path.join(result_dir, f"metrics_per_batch_{obs}_{ANN_TYPE}_{DIM_REDUCTION}.json")
    with open(metrics_path, 'w') as f:
        json.dump(metrics_dict, f, indent=4)

output_cluster_all_modified(smender, obs='MENDER', obs_gt=ground_truth_key)
output_cluster_all_modified(smender, obs=ground_truth_key, obs_gt=None)

: 

: 

### Save Results

In [None]:
# Save results
output_file = os.path.join(result_dir, f"smender_{DATASET_STR}_{ANN_TYPE}_{DIM_REDUCTION}_results.json")
with open(output_file, 'w') as f:
    json.dump({
        'results': {
            'n_cells': adata_raw.n_obs,
            'n_genes': adata_raw.n_vars,
            'cluster_counts': adata_raw.obs['smender_clusters'].value_counts().to_dict()
        },
        'scores': final_scores,
        'performance': {
            'smender_time_seconds': performance_metrics['smender_time'],
            'smender_memory_mb': performance_metrics['smender_memory'],
            'dim_reduction_time_seconds': performance_metrics['dim_reduction_time'],
            'dim_reduction_memory_mb': performance_metrics['dim_reduction_memory'],
            'nn_time_seconds': performance_metrics['nn_time'],
            'nn_memory_mb': performance_metrics['nn_memory']
        }
    }, f, indent=4)
print(f"Results, scores, and performance metrics saved to {output_file}")

: 

: 