# Granule subtyping for paired WT and AD benchmark samples

This notebook completes granule subtyping and annotation for paired WT and AD samples across benchmark settings (benchmark_rho, benchmark_filtering, benchmark_sphere, etc.). For each setting with both `MERSCOPE_WT_1_representative_data` and `MERSCOPE_AD_1_representative_data`, we:
1. Load or build granule expression profiles and concatenate WT + AD into one anndata.
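2. Subtype granules — manual: k-means clustering plus a marker heatmap, from which a cluster → subtype mapping is filled in by hand; automatic: `classify_granules`, with its thresholds tuned to match the manual subtype proportions.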
3. Export subtype labels and subtype density per brain region per sample to `output/benchmark/benchmark_{xxx}/WT_AD_comparison/`.

In [2]:
import anndata
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import scanpy as sc
from sklearn.cluster import MiniBatchKMeans

from mcDETECT.utils import *
from mcDETECT.model import mcDETECT
from mcDETECT.downstream import spot_granule, classify_granules

import warnings
warnings.filterwarnings("ignore")
sc.settings.verbosity = 0

## Paths and parameters

In [None]:
# One seed drives every stochastic step in this notebook (k-means initialisation, etc.).
SEED = 0

# Root folder that holds every benchmark_<name>/ directory.
BENCHMARK_ROOT = "../../output/benchmark/"

# Raw-data roots of the two samples. Transcripts, gene panels and spots are read from
# <root>/processed_data/ when profiles have to be built or densities computed.
DATA_PATH_WT = "../../data/MERSCOPE_WT_1/"
DATA_PATH_AD = "../../data/MERSCOPE_AD_1/"

# Per-sample representative-data folder names inside each benchmark_<name>/.
WT_REPR_DIR, AD_REPR_DIR = "MERSCOPE_WT_1_representative_data", "MERSCOPE_AD_1_representative_data"

# Sub-folder of each benchmark_<name>/ that receives the exported labels and densities.
COMPARISON_DIR = "WT_AD_comparison"

In [None]:
# Granule marker panel grouped by compartment. Dlg4 is listed under both "post-syn" and
# "dendrites", so the flattened panel has 37 unique genes; fewer may be present in a given
# gene panel (see ref_genes_common).
# NOTE(review): "Sptnb4" may be a typo for the gene symbol "Sptbn4" — confirm against the panel.
marker_genes = {"pre-syn": ["Bsn", "Gap43", "Nrxn1", "Slc17a6", "Slc17a7", "Slc32a1", "Snap25", "Stx1a", "Syn1", "Syp", "Syt1", "Vamp2", "Cplx2"],
                "post-syn": ["Camk2a", "Dlg3", "Dlg4", "Gphn", "Gria1", "Gria2", "Homer1", "Homer2", "Nlgn1", "Nlgn2", "Nlgn3", "Shank1", "Shank3"],
                "axons": ["Ank3", "Nav1", "Sptnb4", "Nfasc", "Mapt", "Tubb3"],
                "dendrites": ["Actb", "Cyfip2", "Ddn", "Dlg4", "Map1a", "Map2"]}
# Flatten and de-duplicate while keeping first-seen order. list(set(...)) would make the order
# (and hence ref_genes[:20], the gnl_genes used by build_mc, and the profile column order)
# depend on Python's per-process string hash seed, breaking the SEED-driven reproducibility.
ref_genes = list(dict.fromkeys(marker_genes["pre-syn"] + marker_genes["post-syn"] + marker_genes["axons"] + marker_genes["dendrites"]))
print(f"Number of unique marker genes: {len(ref_genes)}")

## Explicit list of benchmark settings (WT and AD share the same filename per setting)

Under each benchmark (benchmark_rho, benchmark_filtering, benchmark_sphere), WT and AD outputs are stored in
`MERSCOPE_WT_1_representative_data/` and `MERSCOPE_AD_1_representative_data/` with **the same filename** for a given parameter setting.
Format is inferred from extension: `.parquet` → granules, build profile via mc.profile(); `.h5ad` → expression profile, load directly.

In [None]:
# Parameter settings to process, grouped by benchmark folder. Within one benchmark the WT and AD
# representative-data folders hold a file with the same name for every setting.
# The extension decides how it is read: *.parquet = granule table -> profile(); *.h5ad = expression profile -> loaded as-is.
_SETTINGS_BY_BENCHMARK = {
    # representative rho values (parquet)
    "benchmark_rho": [
        "granules_rho_0.0.parquet",
        "granules_rho_0.2.parquet",
        "granules_rho_0.5.parquet",
        "granules_rho_0.8.parquet",
    ],
    # inSomaThr / ncThr combinations (parquet)
    "benchmark_filtering": [
        "granules_inSomaThr1.0_ncThr1.0.parquet",
        "granules_inSomaThr0.1_ncThr1.0.parquet",
        "granules_inSomaThr1.0_ncThr0.1.parquet",
        "granules_inSomaThr0.1_ncThr0.1.parquet",
        "granules_inSomaThr0.05_ncThr0.1.parquet",
    ],
    # sphere-radius variants (h5ad = expression profile)
    "benchmark_sphere": [
        "granules_expression_default.h5ad",
        "granules_expression_fixed.h5ad",
        "granules_expression_expand.h5ad",
        "granules_expression_shrink.h5ad",
    ],
}
# Flat list of (benchmark_name, filename) pairs in the order above.
BENCHMARK_SETTINGS = [(bench, fname) for bench, fnames in _SETTINGS_BY_BENCHMARK.items() for fname in fnames]

print(f"Total settings: {len(BENCHMARK_SETTINGS)}")
for bench, fname in BENCHMARK_SETTINGS:
    fmt = "h5ad (load)" if not fname.endswith(".parquet") else "parquet (build profile)"
    print(f"  {bench} / {fname}  [{fmt}]")

## Load one file per sample (same filename in WT and AD) and concatenate

In [None]:
def load_genes_and_transcripts(data_path):
    """Read the gene panel and the transcript table of one sample.

    ``genes`` is the first column of ``<data_path>/processed_data/genes.csv`` as a list.
    ``transcripts`` is ``<data_path>/processed_data/transcripts.parquet``. When that table has a
    ``gene`` column but no ``target`` column, ``gene`` is renamed to ``target``.

    Returns ``(genes, transcripts)``.
    """
    genes = list(pd.read_csv(os.path.join(data_path, "processed_data/genes.csv")).iloc[:, 0])
    transcripts = pd.read_parquet(os.path.join(data_path, "processed_data/transcripts.parquet"))
    columns = transcripts.columns
    needs_rename = "gene" in columns and "target" not in columns
    return genes, (transcripts.rename(columns={"gene": "target"}) if needs_rename else transcripts)

def build_mc(transcripts, nc_genes, **kwargs):
    """Build an mcDETECT instance used here for its ``profile()`` method.

    Parameters
    ----------
    transcripts : transcript table of one sample, passed straight to mcDETECT.
    nc_genes : negative-control gene names; ``None`` or an empty list becomes ``[]``.
    **kwargs : extra mcDETECT keyword arguments. Keys that match one of the defaults below
        override that default. Previously any such collision raised
        "got multiple values for keyword argument".

    The detection defaults (eps, rho, thresholds, ...) are fixed values. ``gnl_genes`` is the first
    20 entries of the module-level ``ref_genes`` list.
    """
    params = dict(
        type="discrete",
        transcripts=transcripts,
        gnl_genes=ref_genes[:20],
        nc_genes=nc_genes or [],
        eps=1.5, minspl=3, grid_len=1, cutoff_prob=0.95, alpha=10, low_bound=3, size_thr=4.0,
        in_soma_thr=0.1, l=1, rho=0.2, s=1, nc_top=20, nc_thr=0.1,
    )
    # Caller-supplied values take precedence over the defaults.
    params.update(kwargs)
    return mcDETECT(**params)

In [None]:
# Gene panels and transcripts of both samples are read once; profile() needs them for parquet inputs.
genes_wt, transcripts_wt = load_genes_and_transcripts(DATA_PATH_WT)
genes_ad, transcripts_ad = load_genes_and_transcripts(DATA_PATH_AD)

# Negative-control probe names ("Gene" column) of each sample, WT first.
nc_wt, nc_ad = (
    list(pd.read_csv(os.path.join(data_root, "processed_data/negative_controls.csv"))["Gene"])
    for data_root in (DATA_PATH_WT, DATA_PATH_AD)
)

# Marker genes measured in both panels, kept in ref_genes order.
_shared_panel = set(genes_wt) & set(genes_ad)
ref_genes_common = [gene for gene in ref_genes if gene in _shared_panel]
print(f"Marker genes present in both WT and AD: {len(ref_genes_common)}")

In [None]:
def load_one_sample(path, sample_label, genes, transcripts, mc, is_parquet):
    """Read one WT or AD sample from ``path`` and tag every granule with ``sample_label``.

    - ``is_parquet=True``: the file is a granule table, profiled with ``mc.profile(granules, genes=genes)``.
    - ``is_parquet=False``: the file is an .h5ad expression profile and is read as-is.

    ``transcripts`` is accepted for a uniform call signature but is not used here.
    Returns the AnnData with ``obs['sample'] = sample_label``.
    """
    adata = mc.profile(pd.read_parquet(path), genes=genes) if is_parquet else sc.read_h5ad(path)
    adata.obs["sample"] = sample_label
    return adata

def concatenate_wt_ad_for_setting(benchmark_name, filename, genes_wt, genes_ad, transcripts_wt, transcripts_ad, nc_wt, nc_ad, ref_genes_common):
    """
    Load the **same filename** from the WT and AD representative dirs of one benchmark, then concatenate.

    - ``.parquet``: each sample's granule table is profiled with its own mcDETECT instance,
      built from that sample's transcripts and negative controls, via ``mc.profile(granules, genes)``.
    - anything else (``.h5ad``): expression profiles are loaded directly.

    Returns a combined AnnData with:
    - ``obs['sample']`` in ('WT', 'AD'), stored as str;
    - vars restricted to the genes of ``ref_genes_common`` present in both samples, in that order;
    - obs names suffixed with "-WT" / "-AD" so they stay unique across the two samples.

    Raises FileNotFoundError if either file is missing, and ValueError if no marker gene is shared.
    """
    wt_dir = os.path.join(BENCHMARK_ROOT, benchmark_name, WT_REPR_DIR)
    ad_dir = os.path.join(BENCHMARK_ROOT, benchmark_name, AD_REPR_DIR)
    path_wt = os.path.join(wt_dir, filename)
    path_ad = os.path.join(ad_dir, filename)
    if not os.path.isfile(path_wt):
        raise FileNotFoundError(f"WT file not found: {path_wt}")
    if not os.path.isfile(path_ad):
        raise FileNotFoundError(f"AD file not found: {path_ad}")

    is_parquet = filename.endswith(".parquet")
    if is_parquet:
        mc_wt = build_mc(transcripts_wt, nc_wt)
        mc_ad = build_mc(transcripts_ad, nc_ad)
        adata_wt = load_one_sample(path_wt, "WT", genes_wt, transcripts_wt, mc_wt, True)
        adata_ad = load_one_sample(path_ad, "AD", genes_ad, transcripts_ad, mc_ad, True)
    else:
        adata_wt = load_one_sample(path_wt, "WT", None, None, None, False)
        adata_ad = load_one_sample(path_ad, "AD", None, None, None, False)

    # Marker genes present in both profiles, in ref_genes_common order (set lookups instead of list scans).
    wt_vars, ad_vars = set(adata_wt.var_names), set(adata_ad.var_names)
    common = [g for g in ref_genes_common if g in wt_vars and g in ad_vars]
    if not common:
        raise ValueError(f"No gene of ref_genes_common is present in both WT and AD profiles for {benchmark_name}/{filename}")
    adata_wt = adata_wt[:, common].copy()
    adata_ad = adata_ad[:, common].copy()
    # index_unique="-" appends "-WT"/"-AD" to obs names. The two profiles can share index labels
    # (e.g. both numbered from "0"), and duplicate obs names trigger AnnData warnings and break
    # name-based indexing.
    adata = anndata.concat([adata_wt, adata_ad], join="outer", label="sample", keys=["WT", "AD"], index_unique="-")
    adata.obs["sample"] = adata.obs["sample"].astype(str)
    return adata

## Per-setting pipeline: subtyping and export

In [None]:
# --- Clustering parameters (all randomness is controlled by SEED) ---
N_CLUSTERS = 15          # k for MiniBatchKMeans
KMEANS_BATCH_SIZE = 5000
KMEANS_N_INIT = 50

# --- Brain areas for the density summaries (same order as 3_post_detection) ---
AREA_LIST = ["Isocortex", "OLF", "HPF-CA", "HPF-DG", "HPF-SR", "CTXsp", "TH", "MB", "FT"]

# --- Subtypes pooled as "synaptic" when comparing WT densities with the ground truth ---
SYNAPTIC_SUBTYPES = ["pre-syn", "post-syn", "pre & post"]

# Reference densities, one per AREA_LIST entry and in the same order. The 99 for HPF-SR is a
# placeholder; that area is dropped before the correlations are computed.
GROUND_TRUTH_DENSITIES = [
    1.71603565,    # Isocortex
    1.964351308,   # OLF
    2.052720791,   # HPF-CA
    1.139278326,   # HPF-DG
    99,            # HPF-SR (placeholder)
    1.678527951,   # CTXsp
    1.082904337,   # TH
    0.444031185,   # MB
    0.0199885,     # FT
]

### Manual subtyping: k-means (fixed seed) + marker heatmap, then a hand-filled cluster → subtype mapping

In [None]:
def _build_category_features(granule_adata, marker_genes, agg="mean"):
    """Build (n_obs x 4) matrix: one value per category (pre-syn, post-syn, dendrites, axons).
    agg: 'mean' or 'sum' over genes in that category present in adata.var_names."""
    var_names = set(granule_adata.var_names)
    X = granule_adata.X
    if hasattr(X, "toarray"):
        X = X.toarray()
    else:
        X = np.asarray(X)
    cat_order = ["pre-syn", "post-syn", "dendrites", "axons"]
    cols = []
    for cat in cat_order:
        genes = [g for g in marker_genes.get(cat, []) if g in var_names]
        if len(genes) == 0:
            cols.append(np.zeros(granule_adata.n_obs, dtype=np.float32))
            continue
        idx = [list(granule_adata.var_names).index(g) for g in genes]
        block = X[:, np.array(idx)]
        if agg == "mean":
            cols.append(block.mean(axis=1).astype(np.float32))
        else:
            cols.append(block.sum(axis=1).astype(np.float32))
    return np.column_stack(cols)

def run_manual_subtyping(granule_adata, n_clusters, seed, batch_size=5000, n_init=50, use_category_markers=False, marker_genes=None, category_agg="mean"):
    """
    Cluster granules with MiniBatchKMeans and write the labels into ``granule_adata.obs``.

    The AnnData is modified in place and also returned.

    - ``use_category_markers=False``: cluster the full expression matrix and write
      ``obs['granule_subtype_kmeans']``.
    - ``use_category_markers=True``: cluster the 4 category-level features from
      ``_build_category_features`` and write ``obs['granule_subtype_kmeans_category']``.
      ``marker_genes`` must then be given; otherwise ValueError is raised.

    Labels are stored as an ordered categorical with categories "0" .. str(n_clusters - 1).
    ``seed`` seeds both NumPy's global RNG and the k-means ``random_state``.
    """
    if use_category_markers:
        if marker_genes is None:
            raise ValueError("marker_genes required when use_category_markers=True")
        label_key = "granule_subtype_kmeans_category"
        features = _build_category_features(granule_adata, marker_genes, agg=category_agg)
    else:
        label_key = "granule_subtype_kmeans"
        features = granule_adata.X.copy()
        if hasattr(features, "toarray"):
            features = features.toarray()
    np.random.seed(seed)
    model = MiniBatchKMeans(n_clusters=n_clusters, batch_size=batch_size, random_state=seed, n_init=n_init)
    model.fit(features)
    label_order = [str(k) for k in range(n_clusters)]
    granule_adata.obs[label_key] = pd.Categorical(model.labels_.astype(str), categories=label_order, ordered=True)
    return granule_adata

In [None]:
def adata_for_category_heatmap(granule_adata, marker_genes, category_agg="mean"):
    """Wrap the four category-level features in an AnnData for plotting.

    The result carries a copy of ``granule_adata.obs``. Its vars are pre-syn, post-syn,
    dendrites and axons, mirrored in ``var['genes']``, so category-level clusters can be drawn
    with ``sc.pl.heatmap``.
    """
    categories = ["pre-syn", "post-syn", "dendrites", "axons"]
    features = _build_category_features(granule_adata, marker_genes, agg=category_agg).astype(np.float32)
    heat = anndata.AnnData(X=features, obs=granule_adata.obs.copy())
    heat.var_names = categories
    heat.var["genes"] = categories
    return heat

In [None]:
# Manual annotation: map k-means cluster id -> subtype. Fill in based on your heatmap; leave blank for framework.
def apply_manual_annotation(granule_adata, mapping, cluster_column="granule_subtype_kmeans"):
    """Label granules from a hand-written cluster -> subtype mapping.

    mapping: dict ``{subtype: [cluster ids as str, ...]}``. If a cluster is listed under several
        subtypes, the last one wins.
    cluster_column: obs column holding the cluster ids, e.g. 'granule_subtype_kmeans' or
        'granule_subtype_kmeans_category'.

    Writes into ``granule_adata.obs`` (in place; the AnnData is also returned):
    - ``granule_subtype``: the fine subtype, NaN for clusters absent from ``mapping``;
    - ``granule_subtype_manual``: "mixed" for combined subtypes (labels containing " & "),
      otherwise the fine subtype. Unmapped granules stay NaN instead of becoming the string
      "nan", so ``dropna``/``value_counts`` no longer treat them as a real subtype.
    """
    cluster_to_subtype = {c: subtype for subtype, clusters in mapping.items() for c in clusters}
    fine = granule_adata.obs[cluster_column].astype(str).map(cluster_to_subtype)
    granule_adata.obs["granule_subtype"] = fine
    granule_adata.obs["granule_subtype_manual"] = fine.map(
        lambda s: ("mixed" if " & " in str(s) else str(s)) if pd.notna(s) else s
    )
    return granule_adata

### Optional: annotate clusters with Scanpy score_genes

After clustering (gene-level or category-level), you can instead assign each cluster the category with the highest mean `sc.tl.score_genes` score, so no manual mapping is needed. If you store the result in `obs['granule_subtype_score_genes']`, the export step includes it.

### Automatic subtyping: `classify_granules` and hyperparameter tuning steps

In [None]:
# Steps for tuning automatic subtyping (after you have manual annotation):
# 1. Define manual distribution (proportions) from your manual annotation, e.g.:
#    manual_dist = {'pre-syn': 0.16, 'post-syn': 0.31, 'dendrites': 0.17, 'axons': 0.0, 'mixed': 0.34, 'others': 0.02}
# 2. Grid search: for (enrichment_threshold, min_zscore_threshold) run classify_granules with cluster_column='granule_subtype_kmeans'
#    (or cluster_column=None for granule-level). Get proportion per subtype for each (enrich_thr, zscore_thr).
# 3. Compute least_square_error = sum over subtypes of (auto_proportion - manual_proportion)^2.
# 4. Pick (enrich_thr, zscore_thr) with smallest LSE; use that for final automatic labels.
# See cells below for the actual grid and LSE computation.

In [None]:
def run_automatic_tuning_grid(granule_adata, marker_genes, manual_dist, cluster_column="granule_subtype_kmeans",
                              enrich_thresholds=None, zscore_thresholds=None):
    """
    Grid-search the classify_granules thresholds so automatic subtype proportions match the manual ones.

    For every (enrichment_threshold, min_zscore_threshold) pair:
    1. classify_granules is run on ``granule_adata``, grouping by ``cluster_column``
       (None for granule-level).
    2. The simplified labels become proportions over pre-syn / post-syn / dendrites / axons /
       mixed / others; any other label is ignored.
    3. The least-squares error against ``manual_dist`` is recorded.

    manual_dist: dict e.g. {'pre-syn': 0.16, 'post-syn': 0.31, 'dendrites': 0.17, 'axons': 0.0, 'mixed': 0.34, 'others': 0.02};
        missing classes count as 0.
    enrich_thresholds / zscore_thresholds: grids to search; None uses the defaults below.
    Returns (results_df, best_enrich_thr, best_zscore_thr). On ties the earliest grid point wins.
    Raises ValueError if either grid is empty.
    """
    if enrich_thresholds is None:
        enrich_thresholds = [0.2, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50]
    if zscore_thresholds is None:
        zscore_thresholds = [-0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5]
    if len(enrich_thresholds) == 0 or len(zscore_thresholds) == 0:
        raise ValueError("enrich_thresholds and zscore_thresholds must be non-empty")
    dist_cols = ["pre-syn", "post-syn", "dendrites", "axons", "mixed", "others"]
    manual_series = pd.Series(manual_dist).reindex(dist_cols).fillna(0)

    results = []
    for enrich_thr in enrich_thresholds:
        for zscore_thr in zscore_thresholds:
            _, subtypes_simple = classify_granules(
                granule_adata,
                cluster_column=cluster_column,
                enrichment_threshold=enrich_thr,
                min_zscore_threshold=zscore_thr,
                custom_markers=marker_genes,
            )
            counts = subtypes_simple.value_counts(normalize=True)
            results.append({
                "enrich_thr": enrich_thr,
                "zscore_thr": zscore_thr,
                **{c: counts.get(c, 0) for c in dist_cols},
            })
    results_df = pd.DataFrame(results)
    results_df["least_square_error"] = (
        (results_df[dist_cols].sub(manual_series, axis=1)) ** 2
    ).sum(axis=1)
    # idxmin returns the first minimum, so ties resolve deterministically in grid order
    # (an unstable sort_values(...).iloc[0] gave no such guarantee).
    best = results_df.loc[results_df["least_square_error"].idxmin()]
    return results_df, float(best["enrich_thr"]), float(best["zscore_thr"])

In [None]:
def run_automatic_subtyping(granule_adata, marker_genes, enrichment_threshold=0.35, min_zscore_threshold=0.0, cluster_column="granule_subtype_kmeans"):
    """Label granules with ``classify_granules`` and store the result in ``granule_adata.obs``.

    The AnnData is modified in place and also returned. Writes
    ``obs['granule_subtype_automated']`` (full labels) and
    ``obs['granule_subtype_automated_simple']`` (simplified labels). Both are assigned by
    position from the two series returned by ``classify_granules``.
    """
    full_labels, simple_labels = classify_granules(
        granule_adata,
        cluster_column=cluster_column,
        enrichment_threshold=enrichment_threshold,
        min_zscore_threshold=min_zscore_threshold,
        custom_markers=marker_genes,
    )
    for column, labels in (("granule_subtype_automated", full_labels),
                           ("granule_subtype_automated_simple", simple_labels)):
        granule_adata.obs[column] = labels.values
    return granule_adata

### Export labels (parquet) and subtype density (CSV)

In [None]:
def compute_subtype_density_per_region(granule_obs, spots, area_col="brain_area", subtype_col="granule_subtype_manual",
                                        sample_col="sample", coord_keys=("global_x", "global_y"), grid_len=50):
    """
    Spot-level subtype density per brain region.

    For each (sample, brain_area, subtype):
        density = (granules of that subtype inside the spots of that area) / (number of spots in that area).
    A spot centred at (x, y) covers [x - grid_len/2, x + grid_len/2) x [y - grid_len/2, y + grid_len/2).
    A granule inside several overlapping spots is counted once per spot.

    granule_obs: DataFrame with sample_col, subtype_col and coordinates (global_x/global_y, falling
        back to sphere_x/sphere_y). Samples lacking both coordinate pairs are skipped.
    spots: object whose ``.obs`` DataFrame holds area_col and the two coord_keys columns.
    Returns a DataFrame with columns sample, brain_area, subtype, density and n_spots. The result is
    an empty DataFrame when spots.obs lacks area_col or a coordinate column.
    """
    spot_obs = spots.obs
    if area_col not in spot_obs.columns or coord_keys[0] not in spot_obs.columns or coord_keys[1] not in spot_obs.columns:
        return pd.DataFrame()
    half = grid_len / 2

    # Spot centres per area depend only on the spots, so extract them once.
    area_spots = []
    for area in spot_obs[area_col].dropna().unique():
        sp = spot_obs.loc[spot_obs[area_col] == area]
        if len(sp) == 0:
            continue
        area_spots.append((area, sp[coord_keys[0]].to_numpy(dtype=float), sp[coord_keys[1]].to_numpy(dtype=float)))

    rows = []
    for sample in granule_obs[sample_col].dropna().unique():
        go = granule_obs[granule_obs[sample_col] == sample]
        xcol = "global_x" if "global_x" in go.columns else "sphere_x"
        ycol = "global_y" if "global_y" in go.columns else "sphere_y"
        if xcol not in go.columns or ycol not in go.columns:
            continue
        # Per subtype: coordinates sorted by x, so the granules of one spot's x-range form a slice.
        # NaN x values sort last and therefore never fall inside a finite range.
        per_subtype = []
        for subtype in go[subtype_col].dropna().unique():
            gs = go.loc[go[subtype_col] == subtype]
            gx = gs[xcol].to_numpy(dtype=float)
            gy = gs[ycol].to_numpy(dtype=float)
            order = np.argsort(gx, kind="stable")
            per_subtype.append((subtype, gx[order], gy[order]))
        for area, sx, sy in area_spots:
            n_spots = len(sx)
            lo_y, hi_y = sy - half, sy + half
            for subtype, gx, gy in per_subtype:
                # side="left" on both ends reproduces the half-open test lo <= x < hi.
                start = np.searchsorted(gx, sx - half, side="left")
                stop = np.searchsorted(gx, sx + half, side="left")
                total = 0
                for a, b, y0, y1 in zip(start, stop, lo_y, hi_y):
                    ys = gy[a:b]
                    total += int(np.count_nonzero((ys >= y0) & (ys < y1)))
                rows.append({"sample": sample, "brain_area": area, "subtype": subtype, "density": total / n_spots, "n_spots": n_spots})
    return pd.DataFrame(rows)

# Simpler variant: proportion per (sample, subtype) or per (sample, brain_area, subtype) when area_col present.
def compute_subtype_density_simple(granule_obs, area_col="brain_area", subtype_col="granule_subtype_manual", sample_col="sample"):
    """Per-sample subtype proportions, where "density" = count / total granules of that sample.

    If ``area_col`` is a column of ``granule_obs``, counts are grouped by (sample, area_col, subtype);
    they are still divided by the sample-wide total. Otherwise counts are grouped by
    (sample, subtype) and a constant ``brain_area='all'`` column is added.
    NaN keys form their own groups (``dropna=False``); rows whose sample is NaN get NaN density.
    ``n_spots`` is always NaN because no spot information is used.
    """
    has_area = area_col in granule_obs.columns
    keys = [sample_col, area_col, subtype_col] if has_area else [sample_col, subtype_col]
    total_per_sample = granule_obs.groupby(sample_col).size()
    g = granule_obs.groupby(keys, dropna=False).size().reset_index(name="count")
    # Vectorised lookup of each row's sample total. Unlike a row-wise apply this works on empty
    # input and maps NaN samples to NaN instead of raising KeyError. astype(object) avoids
    # Categorical.map semantics if the sample column is categorical.
    g["density"] = g["count"] / g[sample_col].astype(object).map(total_per_sample)
    if not has_area:
        g["brain_area"] = "all"
    g["n_spots"] = np.nan
    return g

In [None]:
def export_setting_results(benchmark_name, setting_key, benchmark_root, comparison_dir, granule_adata, subtype_col="granule_subtype_manual", density_df=None):
    """
    Export one setting's results to ``<benchmark_root>/<benchmark_name>/<comparison_dir>/``.

    Files written (names include the setting):
    - ``granule_subtype_labels_{setting_key}.parquet``: whichever of these obs columns exist —
      granule_id, sample, ``subtype_col``, the fine manual label, the automated labels,
      coordinates and k-means clusters. If obs has no ``granule_id`` column, the obs index is
      written as ``granule_id`` so the labels can still be joined back to granules.
    - ``subtype_density_per_region_{setting_key}.csv``: ``density_df`` when it is non-empty.
      Otherwise per-sample proportions from ``compute_subtype_density_simple`` are written, but
      only if that result is non-empty.
    """
    out_dir = os.path.join(benchmark_root, benchmark_name, comparison_dir)
    os.makedirs(out_dir, exist_ok=True)
    candidate_cols = [
        "granule_id", "sample", subtype_col, "granule_subtype",
        "granule_subtype_automated", "granule_subtype_automated_simple",
        "global_x", "global_y", "sphere_x", "sphere_y",
        "granule_subtype_kmeans", "granule_subtype_kmeans_category", "granule_subtype_score_genes",
    ]
    # dict.fromkeys de-duplicates while keeping order: subtype_col may equal one of the fixed
    # names, and duplicate column names make to_parquet fail.
    keep_cols = [c for c in dict.fromkeys(candidate_cols) if c in granule_adata.obs.columns]
    labels_df = granule_adata.obs[keep_cols].copy()
    if "granule_id" not in labels_df.columns:
        # The index is dropped by index=False below; keep it as an explicit identifier column.
        labels_df.insert(0, "granule_id", labels_df.index.astype(str).to_numpy())
    labels_df.to_parquet(os.path.join(out_dir, f"granule_subtype_labels_{setting_key}.parquet"), index=False)
    if density_df is not None and len(density_df) > 0:
        density_df.to_csv(os.path.join(out_dir, f"subtype_density_per_region_{setting_key}.csv"), index=False)
    else:
        density_simple = compute_subtype_density_simple(granule_adata.obs, subtype_col=subtype_col)
        if len(density_simple) > 0:
            density_simple.to_csv(os.path.join(out_dir, f"subtype_density_per_region_{setting_key}.csv"), index=False)
    print(f"Exported {out_dir}")

## Run full pipeline for one setting (example)

In [None]:
# Which entry of BENCHMARK_SETTINGS to process; change the index to run another setting.
benchmark_name, filename = BENCHMARK_SETTINGS[0]
# Filename without its extension tags every exported file, e.g. granules_rho_0.0 or granules_expression_default.
setting_key, setting_ext = os.path.splitext(filename)
print(f"Processing: {benchmark_name} / {filename}  (setting_key={setting_key})")

In [None]:
# Concatenate WT + AD adata (same filename in each representative dir)
adata_combined = concatenate_wt_ad_for_setting(benchmark_name, filename, genes_wt, genes_ad, transcripts_wt, transcripts_ad, nc_wt, nc_ad, ref_genes_common)

# Scale each granule to 1e4 total counts, then log1p, in place on adata_combined.
# The k-means step below (run_manual_subtyping) clusters this transformed .X.
# NOTE(review): if the .h5ad inputs are already normalised/log-transformed, this transforms them twice — confirm.
sc.pp.normalize_total(adata_combined, target_sum=1e4)
sc.pp.log1p(adata_combined)

# Sanity check: (granules, genes) and granule count per sample.
print(adata_combined.shape)
print(adata_combined.obs["sample"].value_counts())

In [None]:
# Manual subtyping: k-means with fixed seed (default: marker-gene matrix → granule_subtype_kmeans)
run_manual_subtyping(adata_combined, n_clusters=N_CLUSTERS, seed=SEED, batch_size=KMEANS_BATCH_SIZE, n_init=KMEANS_N_INIT)
# Later cells use 'granule_subtype_kmeans_category' whenever it exists, so remove any stale copy
# left by an earlier category-level run. The guard is needed because the gene-level run above
# never creates this column, and a bare `del` raised KeyError on a fresh run.
if "granule_subtype_kmeans_category" in adata_combined.obs.columns:
    del adata_combined.obs["granule_subtype_kmeans_category"]

# Option: use 4 category-level features → stores in granule_subtype_kmeans_category
# run_manual_subtyping(adata_combined, N_CLUSTERS, SEED, KMEANS_BATCH_SIZE, KMEANS_N_INIT, use_category_markers=True, marker_genes=marker_genes, category_agg="sum")

In [None]:
# Heatmap used to hand-annotate clusters. It shows the 4 category-level features when clustering
# was category-level; otherwise it shows the individual marker genes in a fixed biological order.
use_category_view = "granule_subtype_kmeans_category" in adata_combined.obs.columns
if not use_category_view:
    adata_plot = adata_combined
    groupby = "granule_subtype_kmeans"
    ref_genes_sorted = ["Bsn", "Gap43", "Nrxn1", "Slc17a6", "Slc17a7", "Slc32a1", "Stx1a", "Syn1", "Syp", "Syt1", "Vamp2", "Cplx2",
                       "Camk2a", "Dlg3", "Dlg4", "Gphn", "Gria1", "Gria2", "Homer1", "Homer2", "Nlgn1", "Nlgn2", "Nlgn3", "Shank1", "Shank3",
                       "Cyfip2", "Ddn", "Map1a", "Map2", "Ank3", "Nav1", "Nfasc", "Mapt", "Tubb3"]
    # Only genes actually in the profile can be plotted.
    var_names = [gene for gene in ref_genes_sorted if gene in adata_combined.var_names]
    # Fix the cluster order 0..N_CLUSTERS-1 along the heatmap axis.
    cluster_order = [str(k) for k in range(N_CLUSTERS)]
    adata_plot.obs[groupby] = pd.Categorical(adata_plot.obs[groupby], categories=cluster_order, ordered=True)
else:
    adata_plot = adata_for_category_heatmap(adata_combined, marker_genes)
    groupby = "granule_subtype_kmeans_category"
    var_names = ["pre-syn", "post-syn", "dendrites", "axons"]
ax = sc.pl.heatmap(adata_plot, var_names=var_names, groupby=groupby, cmap="Reds", standard_scale="var",
                   dendrogram=False, swap_axes=True, show=False, figsize=(10, 6))
plt.show()

In [None]:
# Cluster id (as str) -> subtype, read off the heatmap above for this setting's k-means run
# (N_CLUSTERS = 15, ids "0".."14"). List every cluster exactly once and leave unused subtypes
# empty. Combined subtypes use " & " and are collapsed to "mixed" by apply_manual_annotation.
MANUAL_SUBTYPE_MAPPING = {
    "pre-syn":          ["2", "5", "6", "8", "13"],
    "post-syn":         ["0", "1", "7"],
    "dendrites":        ["3", "9"],
    "axons":            [],
    "pre & post":       ["4"],
    "pre & den":        [],
    "post & den":       ["11", "12"],
    "pre & post & den": ["10", "14"],
    "others":           [],
}

# Blank template to copy when annotating a new setting:
# MANUAL_SUBTYPE_MAPPING = {
#     "pre-syn": [], "post-syn": [], "dendrites": [], "axons": [],
#     "pre & post": [], "pre & den": [], "post & den": [], "pre & post & den": [],
#     "others": [],
# }

In [None]:
# Apply manual mapping (after you fill MANUAL_SUBTYPE_MAPPING from the heatmap)
# Map the same cluster column the heatmap was grouped by: category-level labels if present, else gene-level.
cluster_col = "granule_subtype_kmeans_category" if "granule_subtype_kmeans_category" in adata_combined.obs.columns else "granule_subtype_kmeans"
apply_manual_annotation(adata_combined, MANUAL_SUBTYPE_MAPPING, cluster_column=cluster_col)

In [None]:
# Fraction of all granules (WT + AD together) carrying each manual label.
adata_combined.obs["granule_subtype_manual"].value_counts() / len(adata_combined)

In [None]:
# Build manual_dist from manual annotation proportions (for run_automatic_tuning_grid).
# Labels outside the six tuning classes (e.g. unexpected values) are folded into "others".
manual_props = adata_combined.obs["granule_subtype_manual"].value_counts(normalize=True)
manual_dist = dict.fromkeys(["pre-syn", "post-syn", "dendrites", "axons", "mixed", "others"], 0.0)
for label, prop in manual_props.items():
    # Accumulate instead of assigning. With plain assignment, a genuine "others" label processed
    # after unknown labels would overwrite the mass already folded into "others".
    manual_dist[label if label in manual_dist else "others"] += float(prop)
manual_dist

In [None]:
# Automatic subtyping: tuning step (run after you have manual proportions)
# manual_dist = {'pre-syn': 0.16, 'post-syn': 0.31, 'dendrites': 0.17, 'axons': 0.0, 'mixed': 0.34, 'others': 0.02}
results_df, best_enrich, best_zscore = run_automatic_tuning_grid(adata_combined, marker_genes, manual_dist)
print(results_df.sort_values('least_square_error').head(10))
print(f"Best: enrichment_threshold={best_enrich}, min_zscore_threshold={best_zscore}")

# Label granules with the best thresholds. The next cell reads obs['granule_subtype_automated_simple'],
# which only exists after this call; with the call commented out, that cell raised KeyError.
run_automatic_subtyping(adata_combined, marker_genes, enrichment_threshold=best_enrich, min_zscore_threshold=best_zscore)

In [None]:
# Fraction of all granules (WT + AD together) per automated (simplified) label; compare with the manual proportions above.
adata_combined.obs["granule_subtype_automated_simple"].value_counts() / len(adata_combined)

In [None]:
# Subtype density: spot-based per brain region when both spot tables carry 'brain_area';
# otherwise fall back to per-sample proportions.
spots_wt = sc.read_h5ad(os.path.join(DATA_PATH_WT, "processed_data/spots.h5ad"))
spots_ad = sc.read_h5ad(os.path.join(DATA_PATH_AD, "processed_data/spots.h5ad"))
has_regions = all("brain_area" in spot_table.obs.columns for spot_table in (spots_wt, spots_ad))
if not has_regions:
    density_df = compute_subtype_density_simple(adata_combined.obs, subtype_col="granule_subtype_manual")
else:
    combined_obs = adata_combined.obs
    # Each sample's granules are matched against that sample's own spots.
    per_sample_density = [
        compute_subtype_density_per_region(combined_obs[combined_obs["sample"] == label], spot_table, subtype_col="granule_subtype_manual")
        for label, spot_table in (("WT", spots_wt), ("AD", spots_ad))
    ]
    density_df = pd.concat(per_sample_density, ignore_index=True)
density_df["setting"] = setting_key
print(density_df.head())

In [None]:
# WT synaptic granule density vs ground truth: weighted Pearson and Spearman (HPF-SR excluded)
if "brain_area" not in spots_wt.obs.columns:
    print(f"Setting {setting_key}: skip correlation (no brain_area in spots)")
else:
    # SYNAPTIC_SUBTYPES contains "pre & post", which exists only in fine-grained label columns:
    # 'granule_subtype' from the manual mapping, and presumably 'granule_subtype_automated'.
    # 'granule_subtype_manual' collapses it into "mixed", so the fine columns are preferred.
    synaptic_label_candidates = ["granule_subtype", "granule_subtype_automated",
                                 "granule_subtype_manual", "granule_subtype_automated_simple"]
    subtype_col_synaptic = next((c for c in synaptic_label_candidates if c in adata_combined.obs.columns), "granule_subtype")
    obs_wt = adata_combined.obs[adata_combined.obs["sample"] == "WT"].copy()
    obs_wt_synaptic = obs_wt[obs_wt[subtype_col_synaptic].isin(SYNAPTIC_SUBTYPES)].copy()
    # Pool all synaptic subtypes into one label so a single density per area is computed.
    obs_wt_synaptic["_synaptic_agg"] = "synaptic"
    density_synaptic_df = compute_subtype_density_per_region(obs_wt_synaptic, spots_wt, subtype_col="_synaptic_agg")
    if density_synaptic_df.empty:
        print(f"Setting {setting_key}: skip correlation (no WT synaptic granule densities computed)")
    else:
        df_wt_synaptic = density_synaptic_df[(density_synaptic_df["sample"] == "WT") & (density_synaptic_df["subtype"] == "synaptic")]
        # One value per AREA_LIST entry; areas without synaptic rows count as density 0.
        wt_density_list = df_wt_synaptic.set_index("brain_area").reindex(AREA_LIST)["density"].values
        wt_density_list = np.nan_to_num(wt_density_list, nan=0.0)
        # HPF-SR is dropped: its ground-truth entry is only a placeholder.
        hpf_sr_idx = AREA_LIST.index("HPF-SR")
        wt_list_no_hpf = np.delete(wt_density_list, hpf_sr_idx)
        gt_no_hpf = np.delete(np.array(GROUND_TRUTH_DENSITIES), hpf_sr_idx)
        # Weight each area by its number of WT spots.
        weights = [np.sum(spots_wt.obs["brain_area"] == i) for i in AREA_LIST]
        weights_no_hpf = weights.copy()
        weights_no_hpf.pop(hpf_sr_idx)
        scaled_wt = scale(wt_list_no_hpf)
        scaled_gt = scale(gt_no_hpf)
        r_pearson = weighted_corr(scaled_wt, scaled_gt, weights_no_hpf)
        r_spearman = weighted_spearmanr(scaled_wt, scaled_gt, weights_no_hpf)
        print(f"Setting {setting_key}: weighted Pearson = {r_pearson:.4f}, weighted Spearman = {r_spearman:.4f}")

In [None]:
# Export labels and density to benchmark_xxx/WT_AD_comparison/ (files named with setting_key).
# Passing density_df from the density cell means the per-sample-proportion fallback inside
# export_setting_results is used only when that table is empty.
export_setting_results(
    benchmark_name, setting_key, BENCHMARK_ROOT, COMPARISON_DIR, adata_combined,
    subtype_col="granule_subtype_manual", density_df=density_df
)