# Sparse-MENDER MERFISH Experiments

### Import Dependencies

In [None]:
import warnings
warnings.filterwarnings("ignore")

import os
import json
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import sys
sys.path.append("..")
from smender.data_loader import DataLoader
from smender.SMENDER import SMENDER
from smender.utils import compute_NMI, compute_ARI, compute_PAS, compute_CHAOS
from ann.AnnoyANN import AnnoyANN
from ann.FaissANN import FaissANN
from ann.HNSWANN import HNSWANN

### Define Configuration, Constants and Directories

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
# Approximate-nearest-neighbour backend. Allowed: 'annoy', 'faiss', 'hnsw', 'none'
ANN_TYPE = 'annoy'
# Dimensionality-reduction method. Allowed: 'pca', 'nmf', 'ica', 'fa'
DIM_REDUCTION = 'pca'

# Backend name -> ANN class handed to SMENDER ('none' maps to None).
ANN_MAP = {
    'annoy': AnnoyANN,
    'faiss': FaissANN,
    'hnsw': HNSWANN,
    'none': None
}
try:
    SELECTED_ANN = ANN_MAP[ANN_TYPE]
except KeyError:
    # Unknown backend name: fail fast with the list of accepted names.
    raise ValueError(f"Invalid ANN_TYPE: {ANN_TYPE}. Choose from {list(ANN_MAP.keys())}") from None

# Accepted dimensionality-reduction identifiers (kept as a list so the
# error message below renders it in list form).
VALID_DIM_REDUCTIONS = ['pca', 'nmf', 'ica', 'fa']
if DIM_REDUCTION not in VALID_DIM_REDUCTIONS:
    raise ValueError(f"Invalid DIM_REDUCTION: {DIM_REDUCTION}. Choose from {VALID_DIM_REDUCTIONS}")

# SODB dataset identifier
dataset_name = "Allen2022Molecular_aging"

# Worker processes for SMENDER's multiprocessing representation step;
# tune to the number of available CPU cores.
MP_PROCESSES = 4

# Global NumPy seed (SMENDER also receives it explicitly) and quiet scanpy.
seed = 100
np.random.seed(seed)
sc.settings.verbosity = 1

# Output locations: ../results/merfish/<ANN_TYPE> and ../plots/merfish/<ANN_TYPE>
result_dir = os.path.join(os.path.pardir, "results", "merfish", ANN_TYPE)
plots_result_dir = os.path.join(os.path.pardir, "plots", "merfish", ANN_TYPE)
os.makedirs(result_dir, exist_ok=True)
os.makedirs(plots_result_dir, exist_ok=True)

# .obs column that will hold the reference (tissue) labels
ground_truth_key = 'gt'

### Load the MERFISH Dataset

In [None]:
# Load the configured dataset (`dataset_name`) from SODB.
# Reuse the configuration constant instead of repeating the literal, so that
# changing the dataset only requires editing the configuration cell.
loader = DataLoader(dataset_name)
# Used below as a mapping of slice id -> AnnData (iterated by key, .obs mutated).
adata_dict = loader.load()

### Prepare Dictionaries to Store Final Results 

In [None]:
# Empty containers for final outputs; adata_list is rebuilt when SMENDER runs.
results_dict, scores_dict, adata_list = {}, {}, []

### Run SMENDER

In [None]:
# Prepare input data: tag every slice with its id, then merge all slices into
# a single AnnData so SMENDER can process them jointly.
print("\nPreparing input data...")
batch_obs = 'slice_id'
adata_list = []
for si, slice_adata in adata_dict.items():
    slice_adata.obs[batch_obs] = si  # note: mutates the loaded slice in place
    adata_list.append(slice_adata)
if not adata_list:
    raise ValueError("No slices were loaded; cannot run SMENDER on an empty dataset.")
adata_raw = adata_list[0].concatenate(adata_list[1:])
adata_raw.obs[batch_obs] = adata_raw.obs[batch_obs].astype('category')
# Reference labels ('tissue') and cell-type labels ('clust_annot') as categoricals.
adata_raw.obs[ground_truth_key] = adata_raw.obs['tissue'].astype('category')
adata_raw.obs['ct'] = adata_raw.obs['clust_annot'].astype('category')

# SMENDER works on a copy so adata_raw keeps only the labels added here.
adata = adata_raw.copy()

# Run SMENDER
print("\nRunning SMENDER...")
smender = SMENDER(
    adata,
    batch_obs=batch_obs,
    ct_obs='ct',
    random_seed=seed,
    ann=SELECTED_ANN,
    dim_reduction=DIM_REDUCTION
)

# Start tracking time and memory.
# NOTE(review): all six trackers start here and stop together after
# clustering; whether SMENDER limits the dim-reduction / NN figures to those
# phases internally is not visible from this notebook — TODO confirm.
smender.start_smender_timing()
smender.start_smender_memory()
smender.start_dim_reduction_timing()
smender.start_dim_reduction_memory()
smender.start_nn_timing()
smender.start_nn_memory()

print("\nSetting SMENDER parameters...")
smender.prepare()
smender.set_MENDER_para(
    nn_mode='radius',
    nn_para=15,
    n_scales=6
)

print("Extracting multi-scale context representation...")
smender.run_representation_mp(mp=MP_PROCESSES)

print("Running clustering...")
# NOTE(review): `-0.` is numerically 0.0; how SMENDER interprets a zero /
# negative argument is defined in SMENDER itself — confirm this is the
# intended value and not a truncated one (e.g. a resolution like -0.2).
smender.run_clustering_normal(-0.)

# Stop tracking time and memory
performance_metrics = {
    'smender_time': smender.stop_smender_timing(),
    'smender_memory': smender.stop_smender_memory(),
    'dim_reduction_time': smender.stop_dim_reduction_timing(),
    'dim_reduction_memory': smender.stop_dim_reduction_memory(),
    'nn_time': smender.stop_nn_timing(),
    'nn_memory': smender.stop_nn_memory()
}

# Transfer clusters back onto adata_raw (pandas aligns on the obs index).
pred_key = "smender_clusters"
adata_raw.obs[pred_key] = smender.adata_MENDER.obs['MENDER'].astype('category')

### Compute Metrics

In [None]:
# Function to compute metrics
def compute_metrics(adata, ground_truth_key, cluster_key):
    """Compute clustering-quality scores for the labels stored in ``adata.obs``.

    The ``smender.utils`` scorers are called with the same convention used by
    ``output_cluster_all_modified``: ``compute_NMI/ARI(adata, gt_key, pred_key)``
    and ``compute_PAS/CHAOS(adata, pred_key)``. The previous version passed
    label Series where column keys are expected.

    Parameters
    ----------
    adata : AnnData
        Object whose ``.obs`` holds both label columns.
    ground_truth_key : str
        ``.obs`` column with the reference labels.
    cluster_key : str
        ``.obs`` column with the predicted cluster labels.

    Returns
    -------
    dict
        Mapping with keys ``'NMI'``, ``'ARI'``, ``'PAS'`` and ``'CHAOS'``.
    """
    # NOTE(review): when `adata` concatenates several slices, PAS/CHAOS are
    # computed on the pooled spatial coordinates; if slices share a coordinate
    # frame, spatial neighbours may cross slices — confirm this is intended.
    return {
        'NMI': compute_NMI(adata, ground_truth_key, cluster_key),
        'ARI': compute_ARI(adata, ground_truth_key, cluster_key),
        'PAS': compute_PAS(adata, cluster_key),
        'CHAOS': compute_CHAOS(adata, cluster_key),
    }

# Compute overall metrics
print("Computing metrics...")
final_scores = compute_metrics(adata_raw, ground_truth_key, pred_key)

### Visualize

In [None]:
# Function to plot UMAP
def plot_umap(adata, title, color_key, prefix, ann_type, save_path=None):
    """Draw a UMAP of ``adata`` coloured by ``color_key`` on a new figure.

    The title is rendered as ``"<prefix> - <title> - <ann_type>"``. When
    ``save_path`` is truthy the figure is written at 200 dpi and closed;
    otherwise it is left open (e.g. for inline notebook display).
    """
    heading = f"{prefix} - {title} - {ann_type}"
    _, axis = plt.subplots()
    sc.pl.umap(adata, color=color_key, title=heading, ax=axis, show=False)
    if not save_path:
        return
    plt.savefig(save_path, dpi=200, bbox_inches='tight')
    plt.close()

In [None]:
# Plot UMAP for ground truth.
# sc.pl.umap requires adata_raw.obsm['X_umap'], and no cell shown here computes
# it on adata_raw. Build a (seeded) embedding only when it is missing, so
# datasets that already ship one are left untouched.
if 'X_umap' not in adata_raw.obsm:
    sc.pp.neighbors(adata_raw, random_state=seed)
    sc.tl.umap(adata_raw, random_state=seed)

final_umap_gt_path = os.path.join(plots_result_dir, f"final_ground_truth_umap_{ANN_TYPE}.png")
plot_umap(adata_raw, "All Samples", ground_truth_key, "Ground Truth", ANN_TYPE, save_path=final_umap_gt_path)

In [None]:
# UMAP coloured by the SMENDER cluster assignment, saved under the plots dir.
final_umap_smender_path = os.path.join(
    plots_result_dir,
    f"final_smender_clusters_umap_{ANN_TYPE}.png",
)
plot_umap(
    adata_raw,
    title="All Samples",
    color_key=pred_key,
    prefix="SMENDER",
    ann_type=ANN_TYPE,
    save_path=final_umap_smender_path,
)

In [None]:
# Spatial plots
print("Generating spatial plots...")
def output_cluster_all_modified(smender, obs='MENDER', obs_gt=ground_truth_key, dirname=plots_result_dir):
    """Save one spatial scatter plot per slice of ``smender.adata_MENDER``.

    Each plot is coloured by the ``obs`` column. Its title carries the slice
    id, PAS and CHAOS and, when ``obs_gt`` is given and present, NMI and ARI
    against that reference column. Files are written to
    ``<dirname>/spatial_<slice>_<obs>_<ANN_TYPE>.png``.

    Side effect: converts ``smender.adata_MENDER.obs[smender.batch_obs]``
    to a categorical column in place.

    Parameters
    ----------
    smender : SMENDER
        Fitted SMENDER instance (reads ``adata_MENDER`` and ``batch_obs``).
    obs : str
        ``.obs`` column to colour by and to score.
    obs_gt : str or None
        Reference ``.obs`` column for NMI/ARI; falsy to skip them.
    dirname : str
        Output directory for the PNG files.
    """
    batch_key = smender.batch_obs
    mender_obs = smender.adata_MENDER.obs
    mender_obs[batch_key] = mender_obs[batch_key].astype('category')

    for si in mender_obs[batch_key].cat.categories:
        cur_a = smender.adata_MENDER[mender_obs[batch_key] == si]

        # Format through an f-string so non-string slice ids (e.g. ints) work;
        # plain `title = si; title += '...'` raised TypeError for them.
        if obs_gt and obs_gt in cur_a.obs:
            nmi = np.round(compute_NMI(cur_a, obs_gt, obs), 3)
            ari = np.round(compute_ARI(cur_a, obs_gt, obs), 3)
            header = f'{si}\nNMI: {nmi} ARI: {ari}'
        else:
            header = f'{si}'
        # PAS/CHAOS only need the predicted labels, so they are shared by both cases.
        pas = np.round(compute_PAS(cur_a, obs), 3)
        chaos = np.round(compute_CHAOS(cur_a, obs), 3)
        title = f'{header} PAS: {pas} CHAOS: {chaos}'

        fig, ax = plt.subplots()
        sc.pl.embedding(cur_a, basis='spatial', color=obs, ax=ax, show=False)
        ax.axis('equal')
        ax.set_title(title)
        save_path = os.path.join(dirname, f"spatial_{si}_{obs}_{ANN_TYPE}.png")
        fig.savefig(save_path, dpi=200, bbox_inches='tight')
        plt.close(fig)  # release each per-slice figure explicitly

output_cluster_all_modified(smender, obs='MENDER', obs_gt=ground_truth_key)
output_cluster_all_modified(smender, obs=ground_truth_key, obs_gt=None)

### Save Results

In [None]:
# Save results
def _json_default(value):
    """Convert NumPy values that the ``json`` module cannot encode natively.

    Scores, timings and memory figures may come back as NumPy scalars (not all
    of which subclass ``float``/``int``). Anything else is still rejected
    rather than silently stringified.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


output_file = os.path.join(result_dir, f"smender_{ANN_TYPE}_results.json")

# Cast labels/counts to native str/int: `value_counts().to_dict()` may yield
# NumPy scalars, which json.dump rejects (especially as dict keys).
cluster_counts = {
    str(label): int(count)
    for label, count in adata_raw.obs[pred_key].value_counts().items()
}

with open(output_file, 'w', encoding='utf-8') as f:
    json.dump({
        'results': {
            'n_cells': int(adata_raw.n_obs),
            'n_genes': int(adata_raw.n_vars),
            'cluster_counts': cluster_counts
        },
        'scores': final_scores,
        'performance': {
            'smender_time_seconds': performance_metrics['smender_time'],
            'smender_memory_mb': performance_metrics['smender_memory'],
            'dim_reduction_time_seconds': performance_metrics['dim_reduction_time'],
            'dim_reduction_memory_mb': performance_metrics['dim_reduction_memory'],
            'nn_time_seconds': performance_metrics['nn_time'],
            'nn_memory_mb': performance_metrics['nn_memory']
        }
    }, f, indent=4, default=_json_default)
print(f"Results, scores, and performance metrics saved to {output_file}")