# Granule subtyping for paired WT and AD benchmark samples

This notebook completes granule subtyping and annotation for the **main result** (optimal mcDETECT) and for **benchmark settings** (benchmark_rho, benchmark_filtering, benchmark_sphere, etc.). The main result uses pre-built data at `output/MERSCOPE_WT_AD_comparison/granule_adata_tsne.h5ad` (already concatenated and normalized); for other settings we load from `MERSCOPE_WT_1_representative_data` and `MERSCOPE_AD_1_representative_data`. For each setting we:
1. Load data (main result: load h5ad; benchmark: concatenate WT+AD and normalize).
2. Subtype granules (manual: k-means + heatmap → annotation; optional automatic tuning).
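3. Export subtype labels and per-region density as per-setting files. A k-means heatmap (jpeg) is saved for every setting: to `output/MERSCOPE_WT_AD_comparison/` for the main result and to `output/benchmark/benchmark_xxx/WT_AD_comparison/` for benchmark settings. The tSNE plot and the subtype-ordered heatmap are produced for the main result only.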

In [None]:
import anndata
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import scanpy as sc
from scipy.sparse import csr_matrix
from scipy.stats import ttest_ind
from sklearn.cluster import MiniBatchKMeans

from mcDETECT.utils import *
from mcDETECT.model import mcDETECT
from mcDETECT.downstream import classify_granules

import warnings
warnings.filterwarnings("ignore")
sc.settings.verbosity = 0

## 1. Paths and parameters

In [None]:
# ---- Input data: processed data per sample (transcripts, genes, spots for profiles and density) ----
DATA_PATH_WT = "../../data/MERSCOPE_WT_1/"
DATA_PATH_AD = "../../data/MERSCOPE_AD_1/"

# ---- Benchmark outputs live under output/benchmark/<benchmark_xxx>/ ----
BENCHMARK_ROOT = "../../output/benchmark/"
# Per-sample representative-data folders inside each benchmark_xxx
WT_REPR_DIR, AD_REPR_DIR = "MERSCOPE_WT_1_representative_data", "MERSCOPE_AD_1_representative_data"
# Folder inside each benchmark_xxx that receives labels, density tables and heatmaps
COMPARISON_DIR = "WT_AD_comparison"

# ---- Main result: pre-built concatenated + normalized AnnData with tSNE ----
MAIN_RESULT_DIR = "../../output/MERSCOPE_WT_AD_comparison"
MAIN_RESULT_PATH = os.path.join(MAIN_RESULT_DIR, "granule_adata_tsne.h5ad")

In [None]:
# Granule markers
# Axon marker spectrin beta 4 uses its gene symbol "Sptbn4" (previously misspelled "Sptnb4", which can never match a panel gene).
marker_genes = {"pre-syn": ["Bsn", "Gap43", "Nrxn1", "Slc17a6", "Slc17a7", "Slc32a1", "Snap25", "Stx1a", "Syn1", "Syp", "Syt1", "Vamp2", "Cplx2"],
                "post-syn": ["Camk2a", "Dlg3", "Dlg4", "Gphn", "Gria1", "Gria2", "Homer1", "Homer2", "Nlgn1", "Nlgn2", "Nlgn3", "Shank1", "Shank3"],
                "axons": ["Ank3", "Nav1", "Sptbn4", "Nfasc", "Mapt", "Tubb3"],
                "dendrites": ["Actb", "Cyfip2", "Ddn", "Dlg4", "Map1a", "Map2"]}
# Genes used for k-means / heatmaps (marker genes expected on the panel)
ref_genes = ["Bsn", "Gap43", "Nrxn1", "Slc17a6", "Slc17a7", "Slc32a1", "Stx1a", "Syn1", "Syp", "Syt1", "Vamp2", "Cplx2", "Camk2a", "Dlg3", "Dlg4", "Gphn", "Gria1", "Gria2", "Homer1", "Homer2", "Nlgn1", "Nlgn2", "Nlgn3", "Shank1", "Shank3", "Cyfip2", "Ddn", "Map1a", "Map2", "Ank3", "Nav1", "Nfasc", "Mapt", "Tubb3"]
# Genes passed to mcDETECT as gnl_genes when building profiles from parquet granules
gnl_genes = ["Camk2a", "Cplx2", "Slc17a7", "Ddn", "Syp", "Map1a", "Shank1", "Syn1", "Gria1", "Gria2", "Cyfip2", "Vamp2", "Bsn", "Slc32a1", "Nfasc", "Syt1", "Tubb3", "Nav1", "Shank3", "Mapt"]

## 2. Explicit list of benchmark settings

Under each benchmark folder (benchmark_rho, benchmark_filtering, benchmark_sphere), the WT and AD outputs for a given parameter setting are stored in
`MERSCOPE_WT_1_representative_data/` and `MERSCOPE_AD_1_representative_data/` under **the same filename**.
The file format is inferred from the extension: `.parquet` → granule table, profiled via `mc.profile()`; `.h5ad` → expression profile, loaded directly. The list starts with the `main_result` entry, which is instead loaded from the pre-built `MAIN_RESULT_PATH`.

In [None]:
# Main result first, then benchmark settings.
# Each entry is (benchmark_name, filename). For benchmark settings the same filename exists in both the WT and AD
# representative-data folders; for "main_result" the filename only provides the setting_key.
BENCHMARK_SETTINGS = (
    # main_result: pre-built concatenated + normalized adata, not read from the benchmark dirs
    [("main_result", "granule_adata_tsne.h5ad")]
    # benchmark_filtering: inSomaThr x ncThr grid (parquet -> profile is built from granules)
    + [("benchmark_filtering", fname) for fname in (
        "granules_inSomaThr1.0_ncThr1.0.parquet",
        "granules_inSomaThr0.1_ncThr1.0.parquet",
        "granules_inSomaThr1.0_ncThr0.1.parquet",
        "granules_inSomaThr0.1_ncThr0.1.parquet",
    )]
    # benchmark_sphere: sphere-radius variants (h5ad -> expression profile loaded directly)
    + [("benchmark_sphere", fname) for fname in (
        "granules_expression_default.h5ad",
        "granules_expression_fixed.h5ad",
        "granules_expression_expand.h5ad",
        "granules_expression_shrink.h5ad",
    )]
    # benchmark_rho: representative rho values (parquet)
    + [("benchmark_rho", fname) for fname in (
        "granules_rho_0.0.parquet",
        "granules_rho_0.2.parquet",
        "granules_rho_0.4.parquet",
        "granules_rho_0.6.parquet",
    )]
)
print(f"Total settings: {len(BENCHMARK_SETTINGS)}")

for bench, fname in BENCHMARK_SETTINGS:
    if fname.endswith(".parquet"):
        fmt = "parquet (build profile)"
    else:
        fmt = "h5ad (load)"
    print(f"  {bench} / {fname}  [{fmt}]")

## 3. Load one file per sample and concatenate

In [None]:
def load_genes_and_transcripts(data_path):
    """
    Read one sample's gene panel and transcript table from <data_path>/processed_data/.

    Returns (genes, transcripts): the first column of genes.csv as a list, and the transcripts DataFrame,
    with its 'gene' column renamed to 'target' when the table has no 'target' column yet.
    """
    genes = list(pd.read_csv(os.path.join(data_path, "processed_data/genes.csv")).iloc[:, 0])
    transcripts = pd.read_parquet(os.path.join(data_path, "processed_data/transcripts.parquet"))
    has_target = "target" in transcripts.columns
    has_gene = "gene" in transcripts.columns
    if has_gene and not has_target:
        transcripts = transcripts.rename(columns={"gene": "target"})
    return genes, transcripts

def build_mc(transcripts, gnl_genes, nc_genes, **kwargs):
    """
    Construct a discrete-mode mcDETECT object; in this notebook it is only used for profile().

    The hyperparameters below are fixed. Extra keyword arguments are forwarded to mcDETECT; passing one of
    the fixed names again raises TypeError (duplicate keyword). A falsy nc_genes (None / empty) becomes [].
    """
    fixed_params = {
        "eps": 1.5, "minspl": 3, "grid_len": 1, "cutoff_prob": 0.95, "alpha": 10, "low_bound": 3,
        "size_thr": 4.0, "in_soma_thr": 0.1, "l": 1, "rho": 0.2, "s": 1, "nc_top": 20, "nc_thr": 0.1,
    }
    negative_controls = nc_genes if nc_genes else []
    return mcDETECT(type="discrete", transcripts=transcripts, gnl_genes=gnl_genes, nc_genes=negative_controls,
                    **fixed_params, **kwargs)

def load_one_sample(path, sample_label, genes, mc, is_parquet):
    """
    Read one WT or AD sample file into an AnnData and tag it with obs['sample'] = sample_label.

    is_parquet=True : `path` holds granules, which are profiled with mc.profile(granules, genes=genes).
    is_parquet=False: `path` is an .h5ad read as-is (genes and mc are ignored).
    """
    if not is_parquet:
        adata = sc.read_h5ad(path)
    else:
        adata = mc.profile(pd.read_parquet(path), genes=genes)
    adata.obs["sample"] = sample_label
    return adata

def concatenate_wt_ad_for_setting(benchmark_name, filename, genes_wt, genes_ad, transcripts_wt, transcripts_ad, nc_wt, nc_ad, ref_genes_common):
    """
    Load the **same filename** from the WT and AD representative dirs of this benchmark, then concatenate.

    - .parquet -> build a profile with mc.profile(granules, genes) for WT and AD separately.
    - .h5ad    -> load directly.

    Returns the combined AnnData with obs['sample'] in ('WT', 'AD') and var restricted to the genes of
    ref_genes_common present in both samples. obs_names are made unique if WT and AD share granule names.

    Raises FileNotFoundError if either file is missing, and ValueError if the two profiles share none of
    ref_genes_common.
    """
    wt_dir = os.path.join(BENCHMARK_ROOT, benchmark_name, WT_REPR_DIR)
    ad_dir = os.path.join(BENCHMARK_ROOT, benchmark_name, AD_REPR_DIR)
    path_wt = os.path.join(wt_dir, filename)
    path_ad = os.path.join(ad_dir, filename)
    if not os.path.isfile(path_wt):
        raise FileNotFoundError(f"WT file not found: {path_wt}")
    if not os.path.isfile(path_ad):
        raise FileNotFoundError(f"AD file not found: {path_ad}")

    is_parquet = filename.endswith(".parquet")
    if is_parquet:
        # Each sample is profiled against its own transcripts / negative controls
        mc_wt = build_mc(transcripts_wt, gnl_genes=gnl_genes, nc_genes=nc_wt)
        mc_ad = build_mc(transcripts_ad, gnl_genes=gnl_genes, nc_genes=nc_ad)
        adata_wt = load_one_sample(path_wt, "WT", genes_wt, mc_wt, True)
        adata_ad = load_one_sample(path_ad, "AD", genes_ad, mc_ad, True)
    else:
        adata_wt = load_one_sample(path_wt, "WT", None, None, False)
        adata_ad = load_one_sample(path_ad, "AD", None, None, False)

    common = [g for g in ref_genes_common if g in adata_wt.var_names and g in adata_ad.var_names]
    if not common:
        # Without this check the concatenation would silently produce a zero-gene AnnData
        raise ValueError(f"No genes of ref_genes_common are shared by the WT and AD profiles for {benchmark_name}/{filename}")
    adata_wt = adata_wt[:, common].copy()
    adata_ad = adata_ad[:, common].copy()
    adata = anndata.concat([adata_wt, adata_ad], join="outer", label="sample", keys=["WT", "AD"])
    # Granule names come from two independent files and may coincide; duplicated obs_names make
    # label-based obs operations ambiguous, so de-duplicate them (no-op when already unique).
    if not adata.obs_names.is_unique:
        adata.obs_names_make_unique()
    adata.obs["sample"] = adata.obs["sample"].astype(str)
    return adata

In [None]:
# Read each sample's gene panel and transcripts once; profile() needs them when a setting is stored as parquet
genes_wt, transcripts_wt = load_genes_and_transcripts(DATA_PATH_WT)
genes_ad, transcripts_ad = load_genes_and_transcripts(DATA_PATH_AD)
# Negative-control gene lists, WT first then AD
nc_wt, nc_ad = (
    list(pd.read_csv(os.path.join(sample_path, "processed_data/negative_controls.csv"))["Gene"])
    for sample_path in (DATA_PATH_WT, DATA_PATH_AD)
)
# Reference marker genes measured in both samples, kept in ref_genes order
ref_genes_common = [g for g in ref_genes if all(g in panel for panel in (genes_wt, genes_ad))]
print(f"Marker genes present in both WT and AD: {len(ref_genes_common)}")

## 4. Per-setting pipeline: subtyping and export

In [None]:
# K-means and manual subtyping parameters (the k-means seed itself is SEED, set in the manual-subtyping cell)
N_CLUSTERS = 15
KMEANS_BATCH_SIZE = 5000
KMEANS_N_INIT = 20

# Main result only: additional k values for cluster-count sensitivity (only subtype labels parquet saved; no tSNE/heatmap/density)
MAIN_RESULT_EXTRA_K = [20, 25, 30]

# Brain areas for density; GROUND_TRUTH_DENSITIES below holds one value per area, in this order
AREA_LIST = ["Isocortex", "OLF", "HPF-CA", "HPF-DG", "HPF-SR", "CTXsp", "TH", "MB", "FT"]

# Convention: granule_subtype_manual = finer labels (e.g. 'pre & post'); use only for defining synaptic granules. Density and automatic-annotation guidance use granule_subtype_manual_simple ('mixed' for mixed types).
# Synaptic subtypes (from granule_subtype_manual) for correlation with ground truth
SYNAPTIC_SUBTYPES = ["pre-syn", "post-syn", "pre & post"]
# HPF-SR has no ground-truth density; its slot holds this placeholder value
GROUND_TRUTH_PLACEHOLDER = 99
GROUND_TRUTH_DENSITIES = [1.71603565, 1.964351308, 2.052720791, 1.139278326, GROUND_TRUTH_PLACEHOLDER, 1.678527951, 1.082904337, 0.444031185, 0.0199885]
assert len(GROUND_TRUTH_DENSITIES) == len(AREA_LIST), "GROUND_TRUTH_DENSITIES needs exactly one value per brain area in AREA_LIST"

### 4.1 Manual subtyping: k-means (fixed seed) + heatmap; the cluster → subtype annotation is filled in by hand from the heatmap

In [None]:
def run_manual_subtyping(granule_adata, n_clusters, seed, batch_size=5000, n_init=20, obs_key="granule_subtype_kmeans"):
    """
    Cluster granules with MiniBatchKMeans on the full (densified) expression matrix.

    Writes obs[obs_key] as an ordered categorical with categories "0".."n_clusters-1" and returns
    granule_adata (modified in place). Both the global NumPy RNG and the k-means random_state are seeded with `seed`.
    """
    matrix = granule_adata.X.copy()
    if hasattr(matrix, "toarray"):
        matrix = matrix.toarray()  # sparse -> dense
    np.random.seed(seed)
    model = MiniBatchKMeans(n_clusters=n_clusters, batch_size=batch_size, random_state=seed, n_init=n_init).fit(matrix)
    categories = [str(i) for i in range(n_clusters)]
    granule_adata.obs[obs_key] = pd.Categorical(model.labels_.astype(str), categories=categories, ordered=True)
    return granule_adata

def apply_manual_annotation(granule_adata, mapping, cluster_column="granule_subtype_kmeans"):
    """
    Map k-means cluster labels to manual subtypes.

    mapping: {subtype: [cluster ids]}. Cluster ids are compared as strings, so "3" and 3 are equivalent.
    Adds obs['granule_subtype_manual'] (subtype as given) and obs['granule_subtype_manual_simple']
    (any subtype containing " & " collapsed to "mixed"). Granules in clusters absent from the mapping
    get a missing value (NaN) in both columns. Returns granule_adata (modified in place).
    """
    cluster_to_subtype = {}
    for subtype, clusters in mapping.items():
        for cluster in clusters:
            # obs labels are strings, so normalise ids (mappings may be written with ints)
            cluster_to_subtype[str(cluster)] = subtype

    def _simplify(label):
        # Keep unmapped clusters missing instead of turning them into the string "nan"
        if pd.isna(label):
            return label
        label = str(label)
        return "mixed" if " & " in label else label

    manual = granule_adata.obs[cluster_column].astype(str).map(cluster_to_subtype)
    granule_adata.obs["granule_subtype_manual"] = manual
    granule_adata.obs["granule_subtype_manual_simple"] = manual.map(_simplify)
    return granule_adata

def compute_manual_dist(granule_adata, subtype_column="granule_subtype_manual_simple"):
    """
    Build the manual subtype distribution used as the target in run_automatic_tuning_grid.

    Returns {'pre-syn', 'post-syn', 'dendrites', 'axons', 'mixed', 'others'} -> proportion of granules
    (missing labels excluded). Any label outside these keys is added to 'others', together with granules
    explicitly labelled 'others'.
    """
    proportions = granule_adata.obs[subtype_column].value_counts(normalize=True)
    manual_dist = dict.fromkeys(("pre-syn", "post-syn", "dendrites", "axons", "mixed", "others"), 0.0)
    for label, fraction in proportions.items():
        key = label if label in manual_dist else "others"
        # Accumulate (not assign), so an explicit 'others' label no longer overwrites leftover labels counted earlier
        manual_dist[key] += float(fraction)
    return manual_dist

### 4.2 Automatic subtyping: GranuleSubtyper and hyperparameter tuning steps

Steps for tuning automatic subtyping (after you have manual annotation):
1. Define manual distribution (proportions) from your manual annotation, e.g.:
   manual_dist = {'pre-syn': 0.16, 'post-syn': 0.31, 'dendrites': 0.17, 'axons': 0.0, 'mixed': 0.34, 'others': 0.02}
2. Grid search: for (enrichment_threshold, min_zscore_threshold) run classify_granules with cluster_column='granule_subtype_kmeans'
   (or cluster_column=None for granule-level). Get proportion per subtype for each (enrich_thr, zscore_thr).
3. Compute least_square_error = sum over subtypes of (auto_proportion - manual_proportion)^2.
4. Pick (enrich_thr, zscore_thr) with smallest LSE; use that for final automatic labels.

See cells below for the actual grid and LSE computation.

In [None]:
def run_automatic_tuning_grid(granule_adata, marker_genes, manual_dist, cluster_column="granule_subtype_kmeans",
                              enrich_thresholds=None, zscore_thresholds=None):
    """
    Grid-search classify_granules thresholds so the automatic subtype distribution matches the manual one.

    manual_dist: dict e.g. {'pre-syn': 0.16, 'post-syn': 0.31, 'dendrites': 0.17, 'axons': 0.0, 'mixed': 0.34, 'others': 0.02}
                 (missing keys count as 0).
    enrich_thresholds / zscore_thresholds: grids to search; None uses the default grids below.

    For every (enrichment_threshold, min_zscore_threshold) pair, least_square_error is the sum over subtypes
    of (automatic proportion - manual proportion)^2.
    Returns (results_df, best_enrich_thr, best_zscore_thr). On ties the first grid point wins
    (enrichment outer loop, z-score inner loop).
    Raises ValueError if either grid is empty.
    """
    if enrich_thresholds is None:
        enrich_thresholds = [0.2, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.6]
    if zscore_thresholds is None:
        zscore_thresholds = [-1, -0.5, 0.0, 0.5, 1.0, 1.5, 2]
    enrich_thresholds = list(enrich_thresholds)
    zscore_thresholds = list(zscore_thresholds)
    if not enrich_thresholds or not zscore_thresholds:
        raise ValueError("enrich_thresholds and zscore_thresholds must each contain at least one value")

    dist_cols = ["pre-syn", "post-syn", "dendrites", "axons", "mixed", "others"]
    manual_series = pd.Series(manual_dist).reindex(dist_cols).fillna(0)

    results = []
    for enrich_thr in enrich_thresholds:
        for zscore_thr in zscore_thresholds:
            _, subtypes_simple = classify_granules(
                granule_adata,
                cluster_column=cluster_column,
                enrichment_threshold=enrich_thr,
                min_zscore_threshold=zscore_thr,
                custom_markers=marker_genes,
            )
            counts = subtypes_simple.value_counts(normalize=True)
            results.append({
                "enrich_thr": enrich_thr,
                "zscore_thr": zscore_thr,
                **{c: counts.get(c, 0) for c in dist_cols},
            })
    results_df = pd.DataFrame(results)
    results_df["least_square_error"] = (results_df[dist_cols].sub(manual_series, axis=1) ** 2).sum(axis=1)
    # idxmin returns the first minimum, so tie-breaking is deterministic (an unstable sort was used before)
    best = results_df.loc[results_df["least_square_error"].idxmin()]
    return results_df, best["enrich_thr"], best["zscore_thr"]

def run_automatic_subtyping(granule_adata, marker_genes, enrichment_threshold=0.35, min_zscore_threshold=0.0, cluster_column="granule_subtype_kmeans"):
    """
    Label granules automatically with classify_granules and store the result on obs.

    Writes obs['granule_subtype_automated'] (full label) and obs['granule_subtype_automated_simple']
    (mixed vs pure). Values are assigned positionally (.values), so no index alignment takes place.
    Returns granule_adata (modified in place).
    """
    labels, labels_simple = classify_granules(
        granule_adata,
        cluster_column=cluster_column,
        enrichment_threshold=enrichment_threshold,
        min_zscore_threshold=min_zscore_threshold,
        custom_markers=marker_genes,
    )
    for column, values in (("granule_subtype_automated", labels), ("granule_subtype_automated_simple", labels_simple)):
        granule_adata.obs[column] = values.values
    return granule_adata

### 4.3 Export labels (parquet) and subtype density (CSV)

In [None]:
def _sort_by_x(xs, ys):
    """Return granule x / y coordinate arrays re-ordered by ascending x, so spot windows can be found by binary search."""
    order = np.argsort(xs, kind="stable")
    return xs[order], ys[order]


def _granules_per_spot(sorted_xy, spot_x, spot_y, half):
    """
    Count granules inside each spot's half-open square [x - half, x + half) x [y - half, y + half).

    sorted_xy is the (x, y) pair returned by _sort_by_x. Returns an int64 array with one count per spot.
    Granules or spots with NaN coordinates never match (all comparisons with NaN are False).
    """
    gx, gy = sorted_xy
    # Binary search on the x-sorted granules gives, per spot, the slice whose x lies in [x - half, x + half)
    lo = np.searchsorted(gx, spot_x - half, side="left")
    hi = np.searchsorted(gx, spot_x + half, side="left")
    counts = np.zeros(len(spot_x), dtype=np.int64)
    for i in range(len(spot_x)):
        ys = gy[lo[i]:hi[i]]
        counts[i] = np.count_nonzero((ys >= spot_y[i] - half) & (ys < spot_y[i] + half))
    return counts


def _iter_spot_counts(granule_obs, spots, area_col, subtype_col, sample_col, coord_keys, grid_len):
    """
    Yield (sample, brain_area, subtype, n_spots, per_spot_counts) for every sample x brain area.

    Within each area, every subtype of the sample comes first (in order of appearance), followed by 'overall'
    (all granules of the sample, whatever their subtype). Samples without usable coordinate columns are skipped.
    """
    half = grid_len / 2
    areas = spots.obs[area_col].dropna().unique()
    for sample in granule_obs[sample_col].dropna().unique():
        go = granule_obs[granule_obs[sample_col] == sample]
        xcol = "global_x" if "global_x" in go.columns else "sphere_x"
        ycol = "global_y" if "global_y" in go.columns else "sphere_y"
        if xcol not in go.columns or ycol not in go.columns:
            continue
        gx = go[xcol].to_numpy(dtype=float, na_value=np.nan)
        gy = go[ycol].to_numpy(dtype=float, na_value=np.nan)
        subtype_values = go[subtype_col]
        # Sort each granule group by x once per sample; the sorted arrays are reused for every brain area
        groups = []
        for subtype in subtype_values.dropna().unique():
            mask = (subtype_values == subtype).to_numpy(dtype=bool)
            groups.append((subtype, _sort_by_x(gx[mask], gy[mask])))
        groups.append(("overall", _sort_by_x(gx, gy)))
        for area in areas:
            sp = spots.obs.loc[spots.obs[area_col] == area]
            n_spots = len(sp)
            if n_spots == 0:
                continue
            spot_x = sp[coord_keys[0]].to_numpy(dtype=float, na_value=np.nan)
            spot_y = sp[coord_keys[1]].to_numpy(dtype=float, na_value=np.nan)
            for subtype, sorted_xy in groups:
                yield sample, area, subtype, n_spots, _granules_per_spot(sorted_xy, spot_x, spot_y, half)


def compute_subtype_density_per_region(granule_obs, spots, area_col="brain_area", subtype_col="granule_subtype_manual_simple",
                                        sample_col="sample", coord_keys=("global_x", "global_y"), grid_len=50):
    """
    Spot-level density per subtype: for each (sample, brain_area, subtype), density = (granules falling in that
    area's spots) / n_spots in that area, plus an 'overall' row per (sample, brain_area) counting all granules.

    A granule counts for a spot when it lies in the half-open square of side grid_len centred on the spot.
    granule_obs must have coords (global_x, global_y or sphere_x, sphere_y), sample_col, and subtype_col.
    spots is an AnnData-like object whose spots.obs contains area_col and coord_keys; an empty DataFrame is
    returned if any of those columns is missing. Call with one sample's granule_obs and that sample's spots
    (e.g. WT granules + WT spots, AD granules + AD spots) so coordinates are not mixed across samples.
    Columns: sample, brain_area, subtype, density, n_spots.
    """
    if area_col not in spots.obs.columns or coord_keys[0] not in spots.obs.columns or coord_keys[1] not in spots.obs.columns:
        return pd.DataFrame()
    rows = [
        {"sample": sample, "brain_area": area, "subtype": subtype, "density": counts.sum() / n_spots, "n_spots": n_spots}
        for sample, area, subtype, n_spots, counts in _iter_spot_counts(
            granule_obs, spots, area_col, subtype_col, sample_col, coord_keys, grid_len)
    ]
    return pd.DataFrame(rows)


def compute_subtype_per_spot_counts(granule_obs, spots, area_col="brain_area", subtype_col="granule_subtype_manual_simple",
                                     sample_col="sample", coord_keys=("global_x", "global_y"), grid_len=50):
    """
    Return one row per (sample, brain_area, subtype, spot) with the number of granules in that spot.

    Uses the same spot definition and 'overall' rows as compute_subtype_density_per_region; spots appear in
    spots.obs order. Columns: sample, brain_area, subtype, count. Returns an empty DataFrame if spots.obs
    lacks area_col or coord_keys.
    """
    if area_col not in spots.obs.columns or coord_keys[0] not in spots.obs.columns or coord_keys[1] not in spots.obs.columns:
        return pd.DataFrame()
    rows = [
        {"sample": sample, "brain_area": area, "subtype": subtype, "count": count}
        for sample, area, subtype, _n_spots, counts in _iter_spot_counts(
            granule_obs, spots, area_col, subtype_col, sample_col, coord_keys, grid_len)
        for count in counts
    ]
    return pd.DataFrame(rows)

def export_setting_results(benchmark_name, setting_key, benchmark_root, comparison_dir, granule_adata, subtype_col="granule_subtype_manual_simple", density_df=None, out_dir=None):
    """
    Export one setting's results to benchmark_root/benchmark_name/comparison_dir/ (or out_dir if provided):

    - granule_subtype_labels_{setting_key}.parquet : the label / coordinate obs columns that exist, index dropped
    - subtype_density_per_region_{setting_key}.csv  : density_df, written only when it is non-empty
    """
    if out_dir is None:
        out_dir = os.path.join(benchmark_root, benchmark_name, comparison_dir)
    os.makedirs(out_dir, exist_ok=True)
    wanted = ["granule_id", "sample", subtype_col, "global_x", "global_y", "granule_subtype_kmeans", "granule_subtype_manual"]
    # dict.fromkeys drops repeats while keeping order: subtype_col may itself be "granule_subtype_kmeans" or
    # "granule_subtype_manual", and a duplicated column makes the parquet writer reject the frame.
    keep_cols = [c for c in dict.fromkeys(wanted) if c in granule_adata.obs.columns]
    labels_df = granule_adata.obs[keep_cols].copy()
    labels_df.to_parquet(os.path.join(out_dir, f"granule_subtype_labels_{setting_key}.parquet"), index=False)
    if density_df is not None and len(density_df) > 0:
        density_df.to_csv(os.path.join(out_dir, f"subtype_density_per_region_{setting_key}.csv"), index=False)
    print(f"Exported {out_dir}")

def export_labels_only(granule_adata, out_dir, setting_key, label_col, extra_cols=None):
    """
    Export only the subtype-labels parquet (e.g. for the main result's extra k values).

    label_col is the obs column holding the cluster ids; extra_cols are further obs columns to keep.
    Only columns present in obs are written, each once, with the obs index dropped. Writes
    out_dir/granule_subtype_labels_{setting_key}.parquet.
    """
    os.makedirs(out_dir, exist_ok=True)
    wanted = ["granule_id", "sample", label_col, "global_x", "global_y"] + list(extra_cols or [])
    # De-duplicate (order-preserving): label_col may also appear in extra_cols, and parquet rejects repeated columns
    keep = [c for c in dict.fromkeys(wanted) if c in granule_adata.obs.columns]
    granule_adata.obs[keep].copy().to_parquet(os.path.join(out_dir, f"granule_subtype_labels_{setting_key}.parquet"), index=False)
    print(f"Exported labels {setting_key} -> {out_dir}")

## 5. Run full pipeline for each setting

In [None]:
# Run for a selected setting
benchmark_name, filename = BENCHMARK_SETTINGS[0]    # change this index
setting_key = os.path.splitext(filename)[0]
is_main_result = (benchmark_name == "main_result")
print(f"Processing: {benchmark_name} / {filename}  (setting_key={setting_key}, is_main_result={is_main_result})")

In [None]:
# Load data: main result = pre-built h5ad (already normalized); else concatenate WT+AD and normalize
if is_main_result:
    adata_combined = sc.read_h5ad(MAIN_RESULT_PATH)
    adata_combined.obs.rename(columns={"batch": "sample"}, inplace=True)
    adata_combined.obs["sample_simple"] = adata_combined.obs["sample"].replace({"MERSCOPE_WT_1": "WT", "MERSCOPE_AD_1": "AD"})
    print(adata_combined.shape)
    print(adata_combined.obs["sample"].value_counts())
else:
    adata_combined = concatenate_wt_ad_for_setting(benchmark_name, filename, genes_wt, genes_ad, transcripts_wt, transcripts_ad, nc_wt, nc_ad, ref_genes_common)
    adata_combined.layers["counts"] = csr_matrix(adata_combined.X.copy())
    sc.pp.normalize_total(adata_combined, target_sum=1e4)
    sc.pp.log1p(adata_combined)
    print(adata_combined.shape)
    print(adata_combined.obs["sample"].value_counts())

In [None]:
# Manual subtyping: k-means with fixed seed
SEED = 42
np.random.seed(SEED)
run_manual_subtyping(adata_combined, n_clusters=N_CLUSTERS, seed=SEED, batch_size=KMEANS_BATCH_SIZE, n_init=KMEANS_N_INIT)

# Heatmap for manual annotation (always save to corresponding directory)
groupby = "granule_subtype_kmeans"
var_names = [g for g in ref_genes if g in adata_combined.var_names]
adata_combined.obs[groupby] = pd.Categorical(adata_combined.obs[groupby], categories=[str(i) for i in range(N_CLUSTERS)], ordered=True)
ax = sc.pl.heatmap(adata_combined, var_names=var_names, groupby=groupby, cmap="Reds", standard_scale="var", dendrogram=False, swap_axes=True, show=False, figsize=(10, 6))
heatmap_out_dir = MAIN_RESULT_DIR if is_main_result else os.path.join(BENCHMARK_ROOT, benchmark_name, COMPARISON_DIR)
os.makedirs(heatmap_out_dir, exist_ok=True)
heatmap_fname = "heatmap_subtype.jpeg" if is_main_result else f"heatmap_subtype_{setting_key}.jpeg"
plt.gcf().savefig(os.path.join(heatmap_out_dir, heatmap_fname), dpi=500, bbox_inches="tight")
plt.close()
print(f"Saved {heatmap_fname} to {heatmap_out_dir}")

In [None]:
# Scratch check: fraction of granules in k-means clusters 3, 7, 11 and 12 (the pre-syn clusters of the granules_rho_0.6 mapping)
adata_combined.obs["granule_subtype_kmeans"].isin(["3", "7", "11", "12"]).mean()

In [None]:
# -------------------- Fill MANUAL_SUBTYPE_MAPPING from the heatmap -------------------- #

# A cluster -> subtype mapping is only meaningful for the exact k-means run it was read from, so the mappings are
# keyed by (setting_key, n_clusters, seed) and the one matching the current run is selected below.
# To annotate a new run: inspect the heatmap, copy this template, fill it in, and add it under its key.
# {
#     "pre-syn": [],
#     "post-syn": [],
#     "dendrites": [],
#     "axons": [],
#     "pre & post": [],
#     "pre & den": [],
#     "post & den": [],
#     "pre & post & den": [],
#     "others": [],
# }
MANUAL_SUBTYPE_MAPPINGS = {
    # [main result] K = 15, seed = 42
    ("granule_adata_tsne", 15, 42): {
        "pre-syn": ["0", "7", "14"],
        "post-syn": ["1", "3", "6", "10"],
        "dendrites": ["2", "11"],
        "axons": [],
        "pre & post": ["9"],
        "pre & den": [],
        "post & den": ["4", "8"],
        "pre & post & den": ["5", "12", "13"],
        "others": [],
    },
    # [main result] K = 20, seed = 42 (18 can be pre-den)
    ("granule_adata_tsne", 20, 42): {
        "pre-syn": ["0", "7", "12", "15", "18"],
        "post-syn": ["1", "3", "6", "19"],
        "dendrites": ["2", "4", "11"],
        "axons": [],
        "pre & post": ["9", "10"],
        "pre & den": ["14"],
        "post & den": ["5", "8", "17"],
        "pre & post & den": ["13", "16"],
        "others": [],
    },
    # [main result] K = 25, seed = 42 (18 can be pre-den)
    ("granule_adata_tsne", 25, 42): {
        "pre-syn": ["3", "15", "16", "21"],
        "post-syn": ["5", "6", "7", "8", "11", "22", "23"],
        "dendrites": ["9", "14", "17"],
        "axons": [],
        "pre & post": [],
        "pre & den": [],
        "post & den": ["1", "10", "18", "19", "20"],
        "pre & post & den": ["0", "2", "4", "12", "13", "24"],
        "others": [],
    },
    # [granules_rho_0.6] seed = 42
    ("granules_rho_0.6", 15, 42): {
        "pre-syn": ["3", "7", "11", "12"],
        "post-syn": ["4", "5", "10"],
        "dendrites": ["6", "13"],
        "axons": [],
        "pre & post": ["8"],
        "pre & den": ["2"],
        "post & den": ["0", "1"],
        "pre & post & den": ["9", "14"],
        "others": [],
    },
    # [granules_rho_0.4] seed = 42
    ("granules_rho_0.4", 15, 42): {
        "pre-syn": ["2", "4", "5", "14"],
        "post-syn": ["1", "7", "12"],
        "dendrites": ["0", "6"],
        "axons": [],
        "pre & post": ["13"],
        "pre & den": [],
        "post & den": ["3", "8", "9"],
        "pre & post & den": ["10", "11"],
        "others": [],
    },
    # [granules_rho_0.2] seed = 1
    ("granules_rho_0.2", 15, 1): {
        "pre-syn": ["2", "11", "13"],
        "post-syn": ["1", "8"],
        "dendrites": ["3", "9"],
        "axons": [],
        "pre & post": ["0", "4"],
        "pre & den": [],
        "post & den": ["6", "7", "10"],
        "pre & post & den": ["5", "12", "14"],
        "others": [],
    },
    # [granules_rho_0.0] seed = 42
    ("granules_rho_0.0", 15, 42): {
        "pre-syn": ["7", "13", "14"],
        "post-syn": ["5", "8"],
        "dendrites": ["4", "6"],
        "axons": [],
        "pre & post": ["0", "12"],
        "pre & den": [],
        "post & den": ["1", "3"],
        "pre & post & den": ["2", "9", "10", "11"],
        "others": [],
    },
    # [granules_expression_shrink] seed = 1 (cluster 0 left unannotated)
    ("granules_expression_shrink", 15, 1): {
        "pre-syn": ["4", "10", "11", "13"],
        "post-syn": ["1", "7", "8", "9"],
        "dendrites": ["6", "12"],
        "axons": [],
        "pre & post": ["5"],
        "pre & den": ["2"],
        "post & den": ["14"],
        "pre & post & den": [],
        "others": ["3"],
    },
    # [granules_expression_expand] seed = 1
    ("granules_expression_expand", 15, 1): {
        "pre-syn": ["5", "7", "12"],
        "post-syn": ["0", "1"],
        "dendrites": ["2", "4"],
        "axons": [],
        "pre & post": ["3", "10"],
        "pre & den": [],
        "post & den": ["8", "9", "11"],
        "pre & post & den": ["6", "13", "14"],
        "others": [],
    },
    # [granules_expression_fixed] seed = 42
    ("granules_expression_fixed", 15, 42): {
        "pre-syn": ["6", "7", "8", "12"],
        "post-syn": ["0", "5", "9"],
        "dendrites": ["2", "11"],
        "axons": [],
        "pre & post": ["4"],
        "pre & den": ["10"],
        "post & den": ["1", "3"],
        "pre & post & den": ["13", "14"],
        "others": [],
    },
    # [granules_expression_default] seed = 0
    ("granules_expression_default", 15, 0): {
        "pre-syn": ["5", "7", "13"],
        "post-syn": ["1", "3", "8"],
        "dendrites": ["6", "9"],
        "axons": [],
        "pre & post": ["10"],
        "pre & den": ["0"],
        "post & den": ["2", "4"],
        "pre & post & den": ["11", "12", "14"],
        "others": [],
    },
    # [granules_inSomaThr0.1_ncThr0.1] seed = 42
    ("granules_inSomaThr0.1_ncThr0.1", 15, 42): {
        "pre-syn": ["2", "4", "6", "8"],
        "post-syn": ["0", "5"],
        "dendrites": ["7", "11"],
        "axons": [],
        "pre & post": ["10", "13"],
        "pre & den": [],
        "post & den": ["1", "3", "14"],
        "pre & post & den": ["9", "12"],
        "others": [],
    },
    # [granules_inSomaThr1.0_ncThr0.1] seed = 42
    ("granules_inSomaThr1.0_ncThr0.1", 15, 42): {
        "pre-syn": ["0", "2", "6", "10"],
        "post-syn": ["1", "7", "9", "11"],
        "dendrites": ["8", "12", "14"],
        "axons": [],
        "pre & post": [],
        "pre & den": ["5"],
        "post & den": ["3", "13"],
        "pre & post & den": ["4"],
        "others": [],
    },
    # [granules_inSomaThr0.1_ncThr1.0] seed = 42
    ("granules_inSomaThr0.1_ncThr1.0", 15, 42): {
        "pre-syn": ["3", "9", "14"],
        "post-syn": ["5", "7", "8"],
        "dendrites": ["2", "10"],
        "axons": [],
        "pre & post": ["1", "4", "12"],
        "pre & den": [],
        "post & den": ["0", "6"],
        "pre & post & den": ["11", "13"],
        "others": [],
    },
    # [granules_inSomaThr1.0_ncThr1.0] seed = 42
    ("granules_inSomaThr1.0_ncThr1.0", 15, 42): {
        "pre-syn": ["2", "9", "10"],
        "post-syn": ["4", "5", "8", "12"],
        "dendrites": ["7", "11"],
        "axons": [],
        "pre & post": ["3"],
        "pre & den": ["6"],
        "post & den": ["0", "1", "13"],
        "pre & post & den": ["14"],
        "others": [],
    },
}

# Select the mapping annotated on exactly this k-means run (setting, number of clusters, seed)
mapping_key = (setting_key, N_CLUSTERS, SEED)
if mapping_key not in MANUAL_SUBTYPE_MAPPINGS:
    annotated_runs = sorted((k, s) for (key, k, s) in MANUAL_SUBTYPE_MAPPINGS if key == setting_key)
    raise KeyError(
        f"No manual mapping for setting_key={setting_key!r}, N_CLUSTERS={N_CLUSTERS}, SEED={SEED}. "
        f"Annotated (n_clusters, seed) runs for this setting: {annotated_runs}. Rerun k-means with one of those, "
        f"or annotate the heatmap and add an entry to MANUAL_SUBTYPE_MAPPINGS."
    )
MANUAL_SUBTYPE_MAPPING = MANUAL_SUBTYPE_MAPPINGS[mapping_key]

In [None]:
# Apply manual mapping
cluster_col = "granule_subtype_kmeans"
apply_manual_annotation(adata_combined, MANUAL_SUBTYPE_MAPPING, cluster_column=cluster_col)
adata_combined.obs["granule_subtype_manual_simple"] = pd.Categorical(adata_combined.obs["granule_subtype_manual_simple"], categories=["pre-syn", "post-syn", "dendrites", "axons", "mixed", "others"], ordered=True)
adata_combined.obs["granule_subtype_manual_simple"] = adata_combined.obs["granule_subtype_manual_simple"].cat.remove_unused_categories()

# Compute manual distribution
manual_dist = compute_manual_dist(adata_combined, subtype_column="granule_subtype_manual_simple")
print(manual_dist)

Recorded `manual_dist` outputs for each setting (K = 15 unless the label says otherwise):

[main_results] {'pre-syn': 0.15519290910673186, 'post-syn': 0.3885835803678392, 'dendrites': 0.19101862155671548, 'axons': 0.0, 'mixed': 0.26520488896871347, 'others': 0.0}

[main_results_k=20] {'pre-syn': 0.1494297993049088, 'post-syn': 0.3530587531685531, 'dendrites': 0.17469767975810677, 'axons': 0.0, 'mixed': 0.3228137677684313, 'others': 0.0}

[main_results_k=25] {'pre-syn': 0.15192112918068484, 'post-syn': 0.3642424264867897, 'dendrites': 0.16830780283406133, 'axons': 0.0, 'mixed': 0.3155286414984641, 'others': 0.0}

[granules_inSomaThr1.0_ncThr1.0] {'pre-syn': 0.20004894375147722, 'post-syn': 0.32047398511396424, 'dendrites': 0.13986541025789678, 'axons': 0.0, 'mixed': 0.3396116608766617, 'others': 0.0}

[granules_inSomaThr0.1_ncThr1.0] {'pre-syn': 0.15000669937335093, 'post-syn': 0.3225894280452946, 'dendrites': 0.16438114968117853, 'axons': 0.0, 'mixed': 0.3630227229001759, 'others': 0.0}

[granules_inSomaThr1.0_ncThr0.1] {'pre-syn': 0.21476695152392114, 'post-syn': 0.36276955690919316, 'dendrites': 0.18148625632298146, 'axons': 0.0, 'mixed': 0.24097723524390424, 'others': 0.0}

[granules_inSomaThr0.1_ncThr0.1] {'pre-syn': 0.14834252370315307, 'post-syn': 0.31774366264535653, 'dendrites': 0.16014409358640266, 'axons': 0.0, 'mixed': 0.37376972006508774, 'others': 0.0}

[granules_expression_default] {'pre-syn': 0.13860811408828067, 'post-syn': 0.3466022185889685, 'dendrites': 0.167492172354478, 'axons': 0.0, 'mixed': 0.34729749496827284, 'others': 0.0}

[granules_expression_fixed] {'pre-syn': 0.15189150355600076, 'post-syn': 0.38605058945735116, 'dendrites': 0.17114538219833245, 'axons': 0.0, 'mixed': 0.29091252478831564, 'others': 0.0}

[granules_expression_expand] {'pre-syn': 0.1437027957331694, 'post-syn': 0.25536270096820246, 'dendrites': 0.1657785151266588, 'axons': 0.0, 'mixed': 0.4351559881719693, 'others': 0.0}

[granules_expression_shrink] {'pre-syn': 0.1381218555207944, 'post-syn': 0.35294652196343296, 'dendrites': 0.16116393652115882, 'axons': 0.0, 'mixed': 0.15625961388280174, 'others': 0.19150807211181212}

[granules_rho_0.0] {'pre-syn': 0.11979762252445636, 'post-syn': 0.3103549865084521, 'dendrites': 0.1830056733470193, 'axons': 0.0, 'mixed': 0.38684171762007225, 'others': 0.0}

[granules_rho_0.2] {'pre-syn': 0.13002397585843736, 'post-syn': 0.3150292843714473, 'dendrites': 0.17234260451108296, 'axons': 0.0, 'mixed': 0.38260413525903236, 'others': 0.0}

[granules_rho_0.4] {'pre-syn': 0.17187769923442658, 'post-syn': 0.33082858846804586, 'dendrites': 0.17280395712649974, 'axons': 0.0, 'mixed': 0.32448975517102785, 'others': 0.0}

[granules_rho_0.6] {'pre-syn': 0.16717631228142837, 'post-syn': 0.34092826323352887, 'dendrites': 0.18520620252608516, 'axons': 0.0, 'mixed': 0.3066892219589576, 'others': 0.0}


In [None]:
# [Main result only] Ordered heatmap and tSNE
#
# For the heatmap, the k-means clusters (and their colours) are temporarily reordered so
# that clusters sharing a simplified subtype are plotted next to each other. Afterwards the
# cluster column is rebuilt from a saved copy as an ordered categorical over
# "0".."N_CLUSTERS-1", and the colours are reset to match.


def _strip_embedding_axes(ax):
    """Hide grid, ticks, axis labels, title and spines of an embedding plot (in place)."""
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_title("")
    for spine in ax.spines.values():
        spine.set_visible(False)


if is_main_result:

    # -------------------- Ordered heatmap -------------------- #
    colors_key = groupby + "_colors"
    orig_colors = list(adata_combined.uns.get(colors_key, []))
    # Saved so the temporary reordering below can never lose cluster labels.
    orig_groupby_col = adata_combined.obs[groupby].copy()
    all_cluster_ids = [str(i) for i in range(N_CLUSTERS)]

    # Map each annotated cluster id to its simplified subtype ("A & B" annotations -> "mixed").
    cluster_to_simple = {}
    for subtype, clusters in MANUAL_SUBTYPE_MAPPING.items():
        simple = "mixed" if (pd.notna(subtype) and " & " in str(subtype)) else str(subtype)
        for c in clusters:
            cluster_to_simple[str(c)] = simple

    subtype_order = ["pre-syn", "post-syn", "dendrites", "axons", "mixed", "others"]
    ordered_cluster_ids = []
    for simple in subtype_order:
        ordered_cluster_ids.extend(sorted(
            (c for c, s in cluster_to_simple.items() if s == simple and c in all_cluster_ids),
            key=int,
        ))
    # Clusters that are unmapped (or mapped to a label outside subtype_order) go last.
    # Leaving them out of the categories would turn their granules into NaN.
    unplaced_cluster_ids = [c for c in all_cluster_ids if c not in ordered_cluster_ids]
    ordered_cluster_ids.extend(unplaced_cluster_ids)

    # Colours are looked up by numeric cluster id, so they follow the new category order.
    if len(orig_colors) >= N_CLUSTERS:
        adata_combined.uns[colors_key] = [orig_colors[int(c)] for c in ordered_cluster_ids]
    adata_combined.obs[groupby] = pd.Categorical(
        orig_groupby_col.astype(str), categories=ordered_cluster_ids, ordered=True
    )
    sc.pl.heatmap(adata_combined, var_names=var_names, groupby=groupby, cmap="Reds", standard_scale="var",
                  dendrogram=False, swap_axes=True, show=False, figsize=(10, 6))
    plt.gcf().savefig(os.path.join(MAIN_RESULT_DIR, "heatmap_subtype_ordered.jpeg"), dpi=500, bbox_inches="tight")
    plt.close()
    print(f"Saved heatmap_subtype_ordered.jpeg to {MAIN_RESULT_DIR}")

    # Restore the numeric cluster order from the saved copy, plus the matching colours.
    adata_combined.obs[groupby] = pd.Categorical(
        orig_groupby_col.astype(str), categories=all_cluster_ids, ordered=True
    )
    if len(orig_colors) >= N_CLUSTERS:
        adata_combined.uns[colors_key] = orig_colors[:N_CLUSTERS]

    # -------------------- tSNE -------------------- #

    # Plot tSNE colored by sample
    sc.set_figure_params(figsize = (8, 8))
    ax = sc.pl.tsne(adata_combined, color = "sample_simple", palette={"WT": "#a0ccec", "AD": "#f48488"}, size = 1, show = False)
    _strip_embedding_axes(ax)
    plt.gcf().savefig(os.path.join(MAIN_RESULT_DIR, "granules_by_batch_tsne.jpeg"), dpi=500, bbox_inches="tight")
    plt.close()
    print(f"Saved granules_by_batch_tsne.jpeg to {MAIN_RESULT_DIR}")

    # Plot tSNE colored by granule subtype (Set2 palette only when it has enough distinct colours)
    for col in ["granule_subtype_kmeans", "granule_subtype_manual", "granule_subtype_manual_simple"]:
        if adata_combined.obs[col].nunique() <= 10:
            adata_combined = assign_palette_to_adata(adata_combined, obs_key = col, cmap_name = "Set2")
        sc.set_figure_params(figsize = (8, 8))
        ax = sc.pl.embedding(adata_combined, basis="tsne", color=col, size=1, show=False)
        _strip_embedding_axes(ax)
        plt.gcf().savefig(os.path.join(MAIN_RESULT_DIR, f"{col}_tsne.jpeg"), dpi=500, bbox_inches="tight")
        plt.close()
        print(f"Saved {col}_tsne.jpeg to {MAIN_RESULT_DIR}")

In [None]:
# # Automatic subtyping: tuning step (run after you have manual proportions)
# # Disabled: uncomment to tune the automatic subtyping hyperparameters against manual_dist
# # (the manual subtype proportions computed by compute_manual_dist earlier in this notebook).
# results_df, best_enrich, best_zscore = run_automatic_tuning_grid(adata_combined, marker_genes, manual_dist, cluster_column=None)
# print(results_df.sort_values('least_square_error').head(10))

# # Then run with best hyperparameters:
# # NOTE(review): the thresholds below are hard-coded (0.3 / 0.0) instead of using best_enrich /
# # best_zscore returned above — confirm they match the tuned values before running.
# run_automatic_subtyping(adata_combined, marker_genes, enrichment_threshold=0.3, min_zscore_threshold=0.0)

In [None]:
# Density: use granule_subtype_manual_simple (mixed types → 'mixed'). Overlay granules to their sample's spots only.
# AD sample adjusted by CAPTURE_EFFICIENCY_COEF (divide AD count/density by this to correct for capture efficiency).
# For each (brain_area, subtype), WT vs AD is compared with a two-sample t-test on log1p per-spot counts.
# The p-value is NaN when either side has fewer than 2 spots. No multiple-testing correction is applied.
CAPTURE_EFFICIENCY_COEF = 0.818691
spots_wt = sc.read_h5ad(os.path.join(DATA_PATH_WT, "processed_data/spots.h5ad"))
spots_ad = sc.read_h5ad(os.path.join(DATA_PATH_AD, "processed_data/spots.h5ad"))
sample_col = "sample_simple" if is_main_result else "sample"
if "brain_area" in spots_wt.obs.columns and "brain_area" in spots_ad.obs.columns:
    subtype_col_simple = "granule_subtype_manual_simple"
    group_keys = ["brain_area", "subtype"]

    # Split granules by sample once, instead of re-filtering for every helper call.
    granules_wt = adata_combined.obs[adata_combined.obs[sample_col] == "WT"]
    granules_ad = adata_combined.obs[adata_combined.obs[sample_col] == "AD"]

    density_df_wt = compute_subtype_density_per_region(granules_wt, spots_wt, subtype_col=subtype_col_simple, sample_col=sample_col)
    density_df_ad = compute_subtype_density_per_region(granules_ad, spots_ad, subtype_col=subtype_col_simple, sample_col=sample_col)
    density_df_ad["density"] = density_df_ad["density"] / CAPTURE_EFFICIENCY_COEF
    density_df = pd.concat([density_df_wt, density_df_ad], ignore_index=True)
    density_df["setting"] = setting_key

    per_spot_wt = compute_subtype_per_spot_counts(granules_wt, spots_wt, subtype_col=subtype_col_simple, sample_col=sample_col)
    per_spot_ad = compute_subtype_per_spot_counts(granules_ad, spots_ad, subtype_col=subtype_col_simple, sample_col=sample_col)
    per_spot_ad["count"] = per_spot_ad["count"] / CAPTURE_EFFICIENCY_COEF

    # Index per-spot counts by (brain_area, subtype) once, rather than boolean-masking both
    # tables for every group. Groups absent from a table fall back to an empty array.
    wt_counts_by_group = {key: g["count"].to_numpy() for key, g in per_spot_wt.groupby(group_keys, observed=True)}
    ad_counts_by_group = {key: g["count"].to_numpy() for key, g in per_spot_ad.groupby(group_keys, observed=True)}
    empty_counts = np.array([])

    p_vals = []
    for (area, subtype), _ in density_df.groupby(group_keys, observed=True):
        wt_log = np.log1p(wt_counts_by_group.get((area, subtype), empty_counts))
        ad_log = np.log1p(ad_counts_by_group.get((area, subtype), empty_counts))
        if len(wt_log) >= 2 and len(ad_log) >= 2:
            _, p = ttest_ind(wt_log, ad_log)
        else:
            p = np.nan
        p_vals.append({"brain_area": area, "subtype": subtype, "p_val": p})
    p_df = pd.DataFrame(p_vals)
    density_df = density_df.merge(p_df, on=group_keys, how="left")
    density_df["p_val_star"] = density_df["p_val"].apply(lambda p: "" if pd.isna(p) else p_val_to_star(p))
else:
    density_df = pd.DataFrame()
print(density_df.head())

In [None]:
# WT synaptic granule density vs ground truth. Two areas are excluded from the correlation;
# the variable names below mark them as HPF-SR and FT. Synaptic granules are defined by
# granule_subtype_manual only (SYNAPTIC_SUBTYPES).

# Positions in AREA_LIST / GROUND_TRUTH_DENSITIES that are left out of the correlation.
# NOTE(review): these are positional, so update them if AREA_LIST is ever reordered.
EXCLUDED_AREA_POSITIONS = (4, 8)

if "brain_area" not in spots_wt.obs.columns:
    print(f"Setting {setting_key}: skip correlation (no brain_area in spots)")
else:
    subtype_col_synaptic = "granule_subtype_manual"  # use finer annotation only for defining synaptic
    obs_wt = adata_combined.obs[adata_combined.obs[sample_col] == "WT"].copy()
    obs_wt_synaptic = obs_wt[obs_wt[subtype_col_synaptic].isin(SYNAPTIC_SUBTYPES)].copy()
    obs_wt_synaptic["_synaptic_agg"] = "synaptic"  # collapse all synaptic subtypes into a single label
    density_synaptic_df = compute_subtype_density_per_region(obs_wt_synaptic, spots_wt, subtype_col="_synaptic_agg")

    # WT rows are identified differently in the main result ("sample" == "MERSCOPE_WT_1")
    # than in benchmark settings (sample_col == "WT").
    if is_main_result:
        is_wt_row = density_synaptic_df["sample"] == "MERSCOPE_WT_1"
    else:
        is_wt_row = density_synaptic_df[sample_col] == "WT"
    df_wt_synaptic = density_synaptic_df[is_wt_row & (density_synaptic_df["subtype"] == "synaptic")]

    # One density per AREA_LIST entry, in AREA_LIST order. Areas without synaptic granules get 0.
    wt_density_list = df_wt_synaptic.set_index("brain_area").reindex(AREA_LIST)["density"].values
    wt_density_list = np.nan_to_num(wt_density_list, nan=0.0)
    # Number of spots per area, used as correlation weights.
    weights = [np.sum(spots_wt.obs["brain_area"] == i) for i in AREA_LIST]

    # Drop the same positions from densities, ground truth and weights.
    excluded_positions = list(EXCLUDED_AREA_POSITIONS)
    wt_list_no_hpfsr_ft = np.delete(wt_density_list, excluded_positions)
    gt_no_hpfsr_ft = np.delete(np.array(GROUND_TRUTH_DENSITIES), excluded_positions)
    weights_no_hpfsr_ft = [w for i, w in enumerate(weights) if i not in EXCLUDED_AREA_POSITIONS]

    scaled_wt = scale(wt_list_no_hpfsr_ft)
    scaled_gt = scale(gt_no_hpfsr_ft)
    r_pearson = weighted_corr(scaled_wt, scaled_gt, weights_no_hpfsr_ft)
    r_spearman = weighted_spearmanr(scaled_wt, scaled_gt, weights_no_hpfsr_ft)
    print(f"Setting {setting_key}: weighted Pearson = {r_pearson:.4f}, weighted Spearman = {r_spearman:.4f}")

Setting main_results: weighted Pearson = 0.9134, weighted Spearman = 0.8657

Setting main_results (k = 20): weighted Pearson = 0.9274, weighted Spearman = 0.9306

Setting main_results (k = 25): weighted Pearson = 0.9090, weighted Spearman = 0.8657

Setting granules_inSomaThr1.0_ncThr1.0: weighted Pearson = 0.8904, weighted Spearman = 0.8297

Setting granules_inSomaThr0.1_ncThr1.0: weighted Pearson = 0.9216, weighted Spearman = 0.9528

Setting granules_inSomaThr1.0_ncThr0.1: weighted Pearson = 0.9005, weighted Spearman = 0.8297

Setting granules_inSomaThr0.1_ncThr0.1: weighted Pearson = 0.9142, weighted Spearman = 0.8657

Setting granules_expression_default: weighted Pearson = 0.9070, weighted Spearman = 0.8657

Setting granules_expression_fixed: weighted Pearson = 0.9002, weighted Spearman = 0.8657

Setting granules_expression_expand: weighted Pearson = 0.9199, weighted Spearman = 0.9528

Setting granules_expression_shrink: weighted Pearson = 0.8840, weighted Spearman = 0.8657

Setting granules_rho_0.0: weighted Pearson = 0.9041, weighted Spearman = 0.8657

Setting granules_rho_0.2: weighted Pearson = 0.9237, weighted Spearman = 0.9306

Setting granules_rho_0.4: weighted Pearson = 0.9212, weighted Spearman = 0.8754

Setting granules_rho_0.6: weighted Pearson = 0.9157, weighted Spearman = 0.8657

In [None]:
# Export labels and density for the current setting. Only the main result passes an explicit
# out_dir (MAIN_RESULT_DIR); benchmark settings go to benchmark_xxx/WT_AD_comparison/.
export_setting_results(
    benchmark_name,
    setting_key,
    BENCHMARK_ROOT,
    COMPARISON_DIR,
    adata_combined,
    subtype_col="granule_subtype_manual_simple",
    density_df=density_df,
    **({"out_dir": MAIN_RESULT_DIR} if is_main_result else {}),
)

In [None]:
# # Main result only: run k-means for extra k (20, 25, 30) and save subtype labels parquet only (no tSNE, heatmap, or density)
# # Disabled by default. Each k is written to its own obs key (granule_subtype_kmeans_k{k}), so the
# # primary k-means column is not overwritten.
# # NOTE(review): the k values listed above should match MAIN_RESULT_EXTRA_K (defined outside this chunk) — confirm.
# if is_main_result:
#     for k in MAIN_RESULT_EXTRA_K:
#         run_manual_subtyping(adata_combined, n_clusters=k, seed=SEED, batch_size=KMEANS_BATCH_SIZE, n_init=KMEANS_N_INIT, obs_key=f"granule_subtype_kmeans_k{k}")
#         export_labels_only(adata_combined, MAIN_RESULT_DIR, f"granule_adata_tsne_k{k}", label_col=f"granule_subtype_kmeans_k{k}")