In [1]:
import anndata
import matplotlib.colors as clr
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import random
import scanpy as sc
import sys
from collections import defaultdict
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.preprocessing import StandardScaler
from typing import Dict, List, Tuple, Optional

import warnings
warnings.filterwarnings("ignore")
sc.settings.verbosity = 0

from mcDETECT.utils import *
from mcDETECT.model import mcDETECT

  from pkg_resources import get_distribution, DistributionNotFound


In [2]:
class HeatmapBasedSubtyper:
    """
    Heatmap-based granule subtyping that mimics manual annotation from z-scored heatmaps.
    
    This classifier:
    1. Z-score normalizes gene expression (per gene across clusters/granules)
    2. Computes category scores based on z-scored values
    3. Assigns subtypes based on which category has highest z-score
    
    This reproduces the visual pattern you see in a heatmap with standard_scale="var"
    and automates the manual annotation process.
    """
    
    def __init__(
        self,
        genes_syn_pre: Optional[List[str]] = None,
        genes_syn_post: Optional[List[str]] = None,
        genes_dendrite: Optional[List[str]] = None,
        genes_axon: Optional[List[str]] = None,
        enrichment_threshold: float = 0.35,
        min_zscore_threshold: float = 0.0
    ):
        """
        Initialize the heatmap-based subtyper.
        
        Parameters
        ----------
        genes_syn_pre : List[str], optional
            Pre-synaptic marker genes
        genes_syn_post : List[str], optional
            Post-synaptic marker genes
        genes_dendrite : List[str], optional
            Dendritic marker genes
        genes_axon : List[str], optional
            Axonal marker genes
        enrichment_threshold : float, default=0.35
            Minimum proportion of z-score sum to consider category enriched
        min_zscore_threshold : float, default=0.0
            Minimum mean z-score for a category to be considered present
            (0.0 = at least average expression)
        """
        # Default marker gene sets
        self.genes_syn_pre = genes_syn_pre or [
            "Bsn", "Gap43", "Nrxn1", "Slc17a6", "Slc17a7", "Slc32a1",
            "Snap25", "Stx1a", "Syn1", "Syp", "Syt1", "Vamp2", "Cplx2"
        ]
        
        self.genes_syn_post = genes_syn_post or [
            "Camk2a", "Dlg3", "Dlg4", "Gphn", "Gria1", "Gria2",
            "Homer1", "Homer2", "Nlgn1", "Nlgn2", "Nlgn3", "Shank1", "Shank3"
        ]
        
        self.genes_dendrite = genes_dendrite or [
            "Actb", "Cyfip2", "Ddn", "Dlg4", "Map1a", "Map2"
        ]
        
        self.genes_axon = genes_axon or [
            "Ank3", "Nav1", "Sptnb4", "Nfasc", "Mapt", "Tubb3"
        ]
        
        self.enrichment_threshold = enrichment_threshold
        self.min_zscore_threshold = min_zscore_threshold
        
        # Create marker dictionary
        self.marker_genes = {
            "pre-syn": self.genes_syn_pre,
            "post-syn": self.genes_syn_post,
            "dendrites": self.genes_dendrite,
            "axons": self.genes_axon
        }
    
    def _compute_zscore_matrix(
        self,
        expression_matrix: np.ndarray,
        cluster_labels: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """
        Compute z-score normalized expression matrix.
        
        If cluster_labels provided: z-score per gene across clusters (cluster-level means)
        If no cluster_labels: z-score per gene across all granules
        
        Parameters
        ----------
        expression_matrix : np.ndarray
            Raw expression matrix (granules × genes)
        cluster_labels : np.ndarray, optional
            Cluster assignments for each granule
            
        Returns
        -------
        np.ndarray
            Z-scored expression matrix
        """
        if cluster_labels is not None:
            # Cluster-level z-scoring (what heatmap shows)
            unique_clusters = np.unique(cluster_labels)
            cluster_means = np.zeros((len(unique_clusters), expression_matrix.shape[1]))
            
            for i, cluster in enumerate(unique_clusters):
                mask = cluster_labels == cluster
                cluster_means[i, :] = expression_matrix[mask, :].mean(axis=0)
            
            # Z-score across clusters (per gene)
            scaler = StandardScaler()
            zscore_means = scaler.fit_transform(cluster_means.T).T
            
            # Map back to individual granules
            zscore_matrix = np.zeros_like(expression_matrix)
            for i, cluster in enumerate(unique_clusters):
                mask = cluster_labels == cluster
                zscore_matrix[mask, :] = zscore_means[i, :]
            
            return zscore_matrix
        else:
            # Granule-level z-scoring (if no clusters)
            scaler = StandardScaler()
            return scaler.fit_transform(expression_matrix)
    
    def _compute_category_scores_zscore(
        self,
        zscore_matrix: np.ndarray,
        gene_names: List[str],
        cluster_labels: Optional[np.ndarray] = None
    ) -> pd.DataFrame:
        """
        Compute category scores based on z-scored expression.
        
        For each category, sum the z-scores of genes in that category.
        This mimics looking at the heatmap and seeing which markers are "red" (high z-score).
        
        Parameters
        ----------
        zscore_matrix : np.ndarray
            Z-scored expression matrix
        gene_names : List[str]
            Gene names
        cluster_labels : np.ndarray, optional
            Cluster assignments
            
        Returns
        -------
        pd.DataFrame
            DataFrame with category z-score sums
        """
        gene_to_idx = {gene: idx for idx, gene in enumerate(gene_names)}
        
        if cluster_labels is not None:
            # Compute cluster-level scores
            unique_clusters = np.unique(cluster_labels)
            n_clusters = len(unique_clusters)
            
            scores = {
                "cluster": unique_clusters,
                "pre_zscore": np.zeros(n_clusters),
                "post_zscore": np.zeros(n_clusters),
                "den_zscore": np.zeros(n_clusters),
                "axon_zscore": np.zeros(n_clusters)
            }
            
            for i, cluster in enumerate(unique_clusters):
                mask = cluster_labels == cluster
                cluster_zscore = zscore_matrix[mask, :].mean(axis=0)
                
                # Sum z-scores for each category
                for category, genes in [
                    ("pre", self.genes_syn_pre),
                    ("post", self.genes_syn_post),
                    ("den", self.genes_dendrite),
                    ("axon", self.genes_axon)
                ]:
                    gene_indices = [gene_to_idx[g] for g in genes if g in gene_to_idx]
                    if len(gene_indices) > 0:
                        # Sum of z-scores (like visual intensity in heatmap)
                        scores[f"{category}_zscore"][i] = cluster_zscore[gene_indices].sum()
        else:
            # Compute granule-level scores
            n_granules = zscore_matrix.shape[0]
            
            scores = {
                "pre_zscore": np.zeros(n_granules),
                "post_zscore": np.zeros(n_granules),
                "den_zscore": np.zeros(n_granules),
                "axon_zscore": np.zeros(n_granules)
            }
            
            for category, genes in [
                ("pre", self.genes_syn_pre),
                ("post", self.genes_syn_post),
                ("den", self.genes_dendrite),
                ("axon", self.genes_axon)
            ]:
                gene_indices = [gene_to_idx[g] for g in genes if g in gene_to_idx]
                if len(gene_indices) > 0:
                    # Sum z-scores across category genes for each granule
                    scores[f"{category}_zscore"] = zscore_matrix[:, gene_indices].sum(axis=1)
        
        return pd.DataFrame(scores)
    
    def _classify_from_zscores(
        self,
        pre_zscore: float,
        post_zscore: float,
        den_zscore: float,
        axon_zscore: float
    ) -> str:
        """
        Classify based on z-score sums (mimics manual heatmap interpretation).
        
        Logic:
        1. Compute total positive z-score
        2. Compute proportion of each category
        3. Identify enriched categories (above threshold)
        4. Assign subtype based on enriched categories
        
        Parameters
        ----------
        pre_zscore, post_zscore, den_zscore, axon_zscore : float
            Summed z-scores for each category
            
        Returns
        -------
        str
            Assigned subtype
        """
        # Only consider positive z-scores (above average)
        pre_pos = max(0, pre_zscore)
        post_pos = max(0, post_zscore)
        den_pos = max(0, den_zscore)
        axon_pos = max(0, axon_zscore)
        
        total_pos = pre_pos + post_pos + den_pos + axon_pos
        
        if total_pos == 0:
            return "others"
        
        # Compute proportions
        pre_prop = pre_pos / total_pos
        post_prop = post_pos / total_pos
        den_prop = den_pos / total_pos
        axon_prop = axon_pos / total_pos
        
        # Identify enriched categories
        is_pre = pre_prop >= self.enrichment_threshold and pre_zscore >= self.min_zscore_threshold
        is_post = post_prop >= self.enrichment_threshold and post_zscore >= self.min_zscore_threshold
        is_den = den_prop >= self.enrichment_threshold and den_zscore >= self.min_zscore_threshold
        is_axon = axon_prop >= self.enrichment_threshold and axon_zscore >= self.min_zscore_threshold
        
        # Classification logic (same as proportion-based classifier)
        # Pure types
        if is_pre and not is_post and not is_den and not is_axon:
            return "pre-syn"
        elif is_post and not is_pre and not is_den and not is_axon:
            return "post-syn"
        elif is_den and not is_pre and not is_post and not is_axon:
            return "dendrites"
        elif is_axon and not is_pre and not is_post and not is_den:
            return "axons"
        
        # Two-category combinations
        elif is_pre and is_post and not is_den and not is_axon:
            return "pre & post"
        elif is_pre and is_den and not is_post and not is_axon:
            return "pre & den"
        elif is_post and is_den and not is_pre and not is_axon:
            return "post & den"
        elif is_pre and is_axon and not is_post and not is_den:
            return "pre & axon"
        elif is_post and is_axon and not is_pre and not is_den:
            return "post & axon"
        elif is_den and is_axon and not is_pre and not is_post:
            return "den & axon"
        
        # Three-category combinations
        elif is_pre and is_post and is_den and not is_axon:
            return "pre & post & den"
        elif is_pre and is_post and is_axon and not is_den:
            return "pre & post & axon"
        elif is_pre and is_den and is_axon and not is_post:
            return "pre & den & axon"
        elif is_post and is_den and is_axon and not is_pre:
            return "post & den & axon"
        
        # Four-category combination
        elif is_pre and is_post and is_den and is_axon:
            return "pre & post & den & axon"
        
        # No enriched categories or other patterns
        else:
            return "others"
    
    def predict(
        self,
        granule_adata: anndata.AnnData,
        cluster_column: Optional[str] = None,
        add_scores: bool = True
    ) -> pd.Series:
        """
        Predict granule subtypes using heatmap-based (z-score) approach.
        
        Parameters
        ----------
        granule_adata : anndata.AnnData
            AnnData object containing granule expression data
        cluster_column : str, optional
            Column in obs containing cluster labels. If provided, classification
            is done at the cluster level (mimicking manual annotation of clusters).
            If None, classification is done at the granule level.
        add_scores : bool, default=True
            Whether to add z-score sums to granule_adata.obs
            
        Returns
        -------
        pd.Series
            Series of predicted subtypes
        """
        # Extract expression matrix
        expr_matrix = granule_adata.X
        if hasattr(expr_matrix, 'toarray'):
            expr_matrix = expr_matrix.toarray()
        
        gene_names = list(granule_adata.var_names)
        
        # Get cluster labels if provided
        cluster_labels = None
        if cluster_column is not None and cluster_column in granule_adata.obs.columns:
            cluster_labels = granule_adata.obs[cluster_column].values
        
        # Compute z-score matrix
        zscore_matrix = self._compute_zscore_matrix(expr_matrix, cluster_labels)
        
        # Compute category scores
        scores_df = self._compute_category_scores_zscore(
            zscore_matrix,
            gene_names,
            cluster_labels
        )
        
        # Classify
        if cluster_labels is not None:
            # Cluster-level classification
            cluster_subtypes = {}
            for idx, row in scores_df.iterrows():
                cluster = row['cluster']
                subtype = self._classify_from_zscores(
                    row['pre_zscore'],
                    row['post_zscore'],
                    row['den_zscore'],
                    row['axon_zscore']
                )
                cluster_subtypes[cluster] = subtype
            
            # Map cluster subtypes to granules
            subtypes = [cluster_subtypes[c] for c in cluster_labels]
            
            # Add cluster-level scores to obs
            if add_scores:
                for idx, row in scores_df.iterrows():
                    cluster = row['cluster']
                    mask = cluster_labels == cluster
                    for col in ['pre_zscore', 'post_zscore', 'den_zscore', 'axon_zscore']:
                        if col not in granule_adata.obs.columns:
                            granule_adata.obs[col] = 0.0
                        granule_adata.obs.loc[mask, col] = row[col]
        else:
            # Granule-level classification
            subtypes = []
            for idx in range(len(scores_df)):
                subtype = self._classify_from_zscores(
                    scores_df.loc[idx, 'pre_zscore'],
                    scores_df.loc[idx, 'post_zscore'],
                    scores_df.loc[idx, 'den_zscore'],
                    scores_df.loc[idx, 'axon_zscore']
                )
                subtypes.append(subtype)
            
            # Add scores to obs
            if add_scores:
                for col in ['pre_zscore', 'post_zscore', 'den_zscore', 'axon_zscore']:
                    granule_adata.obs[col] = scores_df[col].values
        
        return pd.Series(
            subtypes,
            index=granule_adata.obs.index,
            name="granule_subtype_heatmap"
        )
    
    def predict_and_compare(
        self,
        granule_adata: anndata.AnnData,
        cluster_column: Optional[str] = None,
        manual_subtype_column: str = "granule_subtype"
    ) -> Tuple[pd.Series, Optional[pd.DataFrame]]:
        """
        Predict subtypes and compare with manual annotations.
        
        Parameters
        ----------
        granule_adata : anndata.AnnData
            AnnData object
        cluster_column : str, optional
            Cluster column name
        manual_subtype_column : str, default="granule_subtype"
            Column with manual annotations
            
        Returns
        -------
        subtypes : pd.Series
            Predicted subtypes
        comparison : pd.DataFrame
            Confusion matrix if manual annotations exist
        """
        subtypes = self.predict(granule_adata, cluster_column, add_scores=True)
        
        if manual_subtype_column in granule_adata.obs.columns:
            manual_subtypes = granule_adata.obs[manual_subtype_column]
            comparison = pd.crosstab(
                manual_subtypes,
                subtypes,
                rownames=["Manual"],
                colnames=["Heatmap-Based"]
            )
            return subtypes, comparison
        else:
            return subtypes, None


# Convenience function
def classify_granules_heatmap(
    granule_adata: anndata.AnnData,
    cluster_column: Optional[str] = None,
    enrichment_threshold: float = 0.35,
    min_zscore_threshold: float = 0.0,
    custom_markers: Optional[Dict[str, List[str]]] = None
) -> pd.Series:
    """
    Classify granules using heatmap-based (z-score) approach.
    
    This mimics the manual annotation process from 3_subtyping.ipynb where
    you looked at a z-scored heatmap and assigned subtypes to clusters.
    
    Parameters
    ----------
    granule_adata : anndata.AnnData
        AnnData object with granule expression data
    cluster_column : str, optional
        Column in obs with cluster labels. If provided, classification is done
        at cluster level. If None, done at granule level.
    enrichment_threshold : float, default=0.35
        Minimum proportion of positive z-score to consider category enriched
    min_zscore_threshold : float, default=0.0
        Minimum z-score for category to be considered (0 = average)
    custom_markers : Dict[str, List[str]], optional
        Custom marker gene sets
        
    Returns
    -------
    pd.Series
        Predicted subtypes
        
    Examples
    --------
    >>> # Cluster-level (mimics manual annotation exactly)
    >>> subtypes = classify_granules_heatmap(
    ...     granule_adata,
    ...     cluster_column="granule_subtype_kmeans"
    ... )
    
    >>> # Granule-level (no clustering)
    >>> subtypes = classify_granules_heatmap(granule_adata)
    """
    if custom_markers is not None:
        subtyper = HeatmapBasedSubtyper(
            genes_syn_pre=custom_markers.get('pre-syn'),
            genes_syn_post=custom_markers.get('post-syn'),
            genes_dendrite=custom_markers.get('dendrites'),
            genes_axon=custom_markers.get('axons'),
            enrichment_threshold=enrichment_threshold,
            min_zscore_threshold=min_zscore_threshold
        )
    else:
        subtyper = HeatmapBasedSubtyper(
            enrichment_threshold=enrichment_threshold,
            min_zscore_threshold=min_zscore_threshold
        )
    
    return subtyper.predict(granule_adata, cluster_column, add_scores=True)

In [9]:
granule_adata = sc.read_h5ad("../data/adata_final.h5ad")
granule_adata

AnnData object with n_obs × n_vars = 1498704 × 290
    obs: 'global_x', 'global_y', 'global_z', 'layer_z', 'sphere_r', 'size', 'comp', 'in_nucleus', 'gene', 'brain_area', 'global_y_new', 'global_x_new', 'synapse_id', 'global_x_adjusted', 'global_y_adjusted', 'batch', 'brain_area_merged', 'batch_simple', 'granule_subtype_kmeans', 'granule_subtype'
    var: 'genes'
    uns: 'batch_simple_colors', 'brain_area_colors', 'log1p', 'pca', 'rank_genes_groups', 'tsne'
    obsm: 'X_pca', 'X_tsne'
    varm: 'PCs'

In [10]:
granule_adata.obs["granule_subtype"].value_counts() / granule_adata.shape[0]

granule_subtype
mixed        0.300492
post-syn     0.237391
pre-syn      0.193643
dendrites    0.143209
others       0.125265
Name: count, dtype: float64

In [11]:
subtypes = classify_granules_heatmap(granule_adata, cluster_column="granule_subtype_kmeans")

In [12]:
def simplify(s):
    if ' & ' in str(s):
        return 'mixed'
    return str(s)

simple_no_dom = subtypes.apply(simplify)

In [13]:
simple_no_dom.value_counts() / len(simple_no_dom)

granule_subtype_heatmap
post-syn     0.276759
pre-syn      0.262064
dendrites    0.174790
mixed        0.161122
others       0.125265
Name: count, dtype: float64

In [15]:
# Define parameter ranges
enrich_thresholds = [0.25, 0.30, 0.35, 0.40, 0.45, 0.50]
zscore_thresholds = [-0.5, 0.0, 0.5, 1.0]

results = []

for enrich_thr in enrich_thresholds:
    for zscore_thr in zscore_thresholds:
        subtypes = classify_granules_heatmap(
            granule_adata,
            cluster_column="cluster_kmeans",
            enrichment_threshold=enrich_thr,
            min_zscore_threshold=zscore_thr
        )
        
        # Simplify
        simple = subtypes.apply(lambda s: 'mixed' if ' & ' in str(s) else str(s))
        counts = simple.value_counts(normalize=True)
        
        results.append({
            'enrich_thr': enrich_thr,
            'zscore_thr': zscore_thr,
            'pre-syn': counts.get('pre-syn', 0),
            'post-syn': counts.get('post-syn', 0),
            'dendrites': counts.get('dendrites', 0),
            'mixed': counts.get('mixed', 0),
            'others': counts.get('others', 0)
        })
        
        print(f"Completed: enrich_thr: {enrich_thr}, zscore_thr: {zscore_thr}")

results_df = pd.DataFrame(results)

# Find best match to manual (if you have manual annotations)
manual_dist = {
    'pre-syn': 0.1936,
    'post-syn': 0.2374,
    'dendrites': 0.1432,
    'mixed': 0.3005,
    'others': 0.1253
}

results_df['distance'] = results_df.apply(
    lambda row: sum(abs(row[k] - manual_dist[k]) for k in manual_dist.keys()),
    axis=1
)

print("Best parameter combinations:")
print(results_df.sort_values('distance').head(10))

Completed: enrich_thr: 0.25, zscore_thr: -0.5
Completed: enrich_thr: 0.25, zscore_thr: 0.0
Completed: enrich_thr: 0.25, zscore_thr: 0.5
Completed: enrich_thr: 0.25, zscore_thr: 1.0
Completed: enrich_thr: 0.3, zscore_thr: -0.5
Completed: enrich_thr: 0.3, zscore_thr: 0.0
Completed: enrich_thr: 0.3, zscore_thr: 0.5
Completed: enrich_thr: 0.3, zscore_thr: 1.0
Completed: enrich_thr: 0.35, zscore_thr: -0.5
Completed: enrich_thr: 0.35, zscore_thr: 0.0
Completed: enrich_thr: 0.35, zscore_thr: 0.5
Completed: enrich_thr: 0.35, zscore_thr: 1.0
Completed: enrich_thr: 0.4, zscore_thr: -0.5
Completed: enrich_thr: 0.4, zscore_thr: 0.0
Completed: enrich_thr: 0.4, zscore_thr: 0.5
Completed: enrich_thr: 0.4, zscore_thr: 1.0
Completed: enrich_thr: 0.45, zscore_thr: -0.5
Completed: enrich_thr: 0.45, zscore_thr: 0.0
Completed: enrich_thr: 0.45, zscore_thr: 0.5
Completed: enrich_thr: 0.45, zscore_thr: 1.0
Completed: enrich_thr: 0.5, zscore_thr: -0.5
Completed: enrich_thr: 0.5, zscore_thr: 0.0
Completed: enr

In [17]:
results_df

Unnamed: 0,enrich_thr,zscore_thr,pre-syn,post-syn,dendrites,mixed,others,distance
0,0.25,-0.5,0.145328,0.112597,0.226849,0.199591,0.274764,0.507097
1,0.25,0.0,0.145328,0.112597,0.226849,0.199591,0.274764,0.507097
2,0.25,0.5,0.131997,0.114205,0.172011,0.187829,0.353087,0.554068
3,0.25,1.0,0.12564,0.110892,0.073385,0.176263,0.472949,0.73617
4,0.3,-0.5,0.162989,0.130096,0.234092,0.149663,0.275162,0.529506
5,0.3,0.0,0.162989,0.130096,0.234092,0.149663,0.275162,0.529506
6,0.3,0.5,0.149365,0.129279,0.177852,0.142021,0.353485,0.573672
7,0.3,1.0,0.141844,0.123664,0.078216,0.134931,0.473347,0.744092
8,0.35,-0.5,0.180581,0.144642,0.242886,0.097217,0.278653,0.562098
9,0.35,0.0,0.180581,0.144642,0.242886,0.097217,0.278653,0.562098
