In [17]:
import anndata
import matplotlib.colors as clr
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import random
import scanpy as sc
import sys
from collections import defaultdict
from sklearn.cluster import KMeans, MiniBatchKMeans
from typing import Dict, List, Tuple, Optional

import warnings
warnings.filterwarnings("ignore")
sc.settings.verbosity = 0

from mcDETECT.utils import *
from mcDETECT.model import mcDETECT

In [18]:
class AutomatedGranuleSubtyper:
    """
    Automated subtyping of granules based on marker gene expression patterns.
    
    This classifier uses a rule-based approach with thresholds to assign
    granules to subtypes based on their enrichment of different marker
    gene categories (pre-synaptic, post-synaptic, dendritic, axonal).
    """
    
    def __init__(
        self,
        genes_syn_pre: Optional[List[str]] = None,
        genes_syn_post: Optional[List[str]] = None,
        genes_dendrite: Optional[List[str]] = None,
        genes_axon: Optional[List[str]] = None,
        expression_threshold: float = 0.1,
        enrichment_threshold: float = 0.35,
        enrichment_threshold_pre: Optional[float] = None,
        enrichment_threshold_post: Optional[float] = None,
        enrichment_threshold_den: Optional[float] = None,
        enrichment_threshold_axon: Optional[float] = None,
        min_total_expression: int = 1
    ):
        """
        Initialize the automated granule subtyper.
        
        Parameters
        ----------
        genes_syn_pre : List[str], optional
            Pre-synaptic marker genes. Defaults to standard set if None.
        genes_syn_post : List[str], optional
            Post-synaptic marker genes. Defaults to standard set if None.
        genes_dendrite : List[str], optional
            Dendritic marker genes. Defaults to standard set if None.
        genes_axon : List[str], optional
            Axonal marker genes. Defaults to standard set if None.
        expression_threshold : float, default=0.1
            Minimum normalized expression to consider a marker "present"
        enrichment_threshold : float, default=0.35
            Global minimum proportion of total marker expression to consider
            a category "enriched" (e.g., 0.35 means 35% of expression).
            This is used as the default for all categories.
        enrichment_threshold_pre : float, optional
            Category-specific threshold for pre-synaptic markers.
            If None, uses enrichment_threshold.
        enrichment_threshold_post : float, optional
            Category-specific threshold for post-synaptic markers.
            If None, uses enrichment_threshold.
        enrichment_threshold_den : float, optional
            Category-specific threshold for dendritic markers.
            If None, uses enrichment_threshold.
        enrichment_threshold_axon : float, optional
            Category-specific threshold for axonal markers.
            If None, uses enrichment_threshold.
        min_total_expression : int, default=1
            Minimum total marker expression to attempt classification
        """
        # Default marker gene sets (from your notebook)
        self.genes_syn_pre = genes_syn_pre or [
            "Bsn", "Gap43", "Nrxn1", "Slc17a6", "Slc17a7", "Slc32a1",
            "Snap25", "Stx1a", "Syn1", "Syp", "Syt1", "Vamp2", "Cplx2"
        ]
        
        self.genes_syn_post = genes_syn_post or [
            "Camk2a", "Dlg3", "Dlg4", "Gphn", "Gria1", "Gria2",
            "Homer1", "Homer2", "Nlgn1", "Nlgn2", "Nlgn3", "Shank1", "Shank3"
        ]
        
        self.genes_dendrite = genes_dendrite or [
            "Actb", "Cyfip2", "Ddn", "Dlg4", "Map1a", "Map2"
        ]
        
        self.genes_axon = genes_axon or [
            "Ank3", "Nav1", "Sptnb4", "Nfasc", "Mapt", "Tubb3"
        ]
        
        # Thresholds
        self.expression_threshold = expression_threshold
        self.enrichment_threshold = enrichment_threshold
        self.min_total_expression = min_total_expression
        
        # Category-specific thresholds (use global if not specified)
        self.enrichment_threshold_pre = enrichment_threshold_pre if enrichment_threshold_pre is not None else enrichment_threshold
        self.enrichment_threshold_post = enrichment_threshold_post if enrichment_threshold_post is not None else enrichment_threshold
        self.enrichment_threshold_den = enrichment_threshold_den if enrichment_threshold_den is not None else enrichment_threshold
        self.enrichment_threshold_axon = enrichment_threshold_axon if enrichment_threshold_axon is not None else enrichment_threshold
        
        # Create marker dictionary
        self.marker_genes = {
            "pre-syn": self.genes_syn_pre,
            "post-syn": self.genes_syn_post,
            "dendrites": self.genes_dendrite,
            "axons": self.genes_axon
        }
        
    def _compute_category_scores(
        self,
        expression_matrix: np.ndarray,
        gene_names: List[str]
    ) -> pd.DataFrame:
        """
        Compute normalized expression scores for each marker category.
        
        Parameters
        ----------
        expression_matrix : np.ndarray
            Expression matrix (granules × genes)
        gene_names : List[str]
            Gene names corresponding to columns
            
        Returns
        -------
        pd.DataFrame
            DataFrame with columns: pre_score, post_score, den_score, axon_score
        """
        n_granules = expression_matrix.shape[0]
        scores = {
            "pre_score": np.zeros(n_granules),
            "post_score": np.zeros(n_granules),
            "den_score": np.zeros(n_granules),
            "axon_score": np.zeros(n_granules),
            "total_score": np.zeros(n_granules)
        }
        
        # Create gene index mapping
        gene_to_idx = {gene: idx for idx, gene in enumerate(gene_names)}
        
        # Compute average expression for each category
        for category, genes in [
            ("pre", self.genes_syn_pre),
            ("post", self.genes_syn_post),
            ("den", self.genes_dendrite),
            ("axon", self.genes_axon)
        ]:
            # Find indices of genes present in the data
            gene_indices = [gene_to_idx[g] for g in genes if g in gene_to_idx]
            
            if len(gene_indices) > 0:
                # Sum expression across category genes
                category_expr = expression_matrix[:, gene_indices].sum(axis=1)
                # Normalize by number of genes in category
                scores[f"{category}_score"] = category_expr / len(genes)
        
        # Compute total marker expression
        scores["total_score"] = (
            scores["pre_score"] + scores["post_score"] + 
            scores["den_score"] + scores["axon_score"]
        )
        
        return pd.DataFrame(scores)
    
    def _classify_granule(
        self,
        pre_score: float,
        post_score: float,
        den_score: float,
        axon_score: float,
        total_score: float
    ) -> str:
        """
        Classify a single granule based on its category scores.
        
        The classification logic:
        1. If total expression is too low → "others"
        2. Compute proportion of expression in each category
        3. Identify enriched categories (using category-specific thresholds)
        4. Assign combined subtype based on enriched categories
        
        Parameters
        ----------
        pre_score, post_score, den_score, axon_score : float
            Normalized expression scores for each category
        total_score : float
            Total marker expression
            
        Returns
        -------
        str
            Assigned subtype
        """
        # Check minimum expression threshold
        if total_score < self.min_total_expression:
            return "others"
        
        # Compute proportions
        pre_prop = pre_score / total_score
        post_prop = post_score / total_score
        den_prop = den_score / total_score
        axon_prop = axon_score / total_score
        
        # Identify enriched categories (using category-specific thresholds)
        is_pre = pre_prop >= self.enrichment_threshold_pre
        is_post = post_prop >= self.enrichment_threshold_post
        is_den = den_prop >= self.enrichment_threshold_den
        is_axon = axon_prop >= self.enrichment_threshold_axon
        
        # Classification logic based on enriched categories
        # Pure types (single enriched category)
        if is_pre and not is_post and not is_den and not is_axon:
            return "pre-syn"
        elif is_post and not is_pre and not is_den and not is_axon:
            return "post-syn"
        elif is_den and not is_pre and not is_post and not is_axon:
            return "dendrites"
        elif is_axon and not is_pre and not is_post and not is_den:
            return "axons"
        
        # Two-category combinations
        elif is_pre and is_post and not is_den and not is_axon:
            return "pre & post"
        elif is_pre and is_den and not is_post and not is_axon:
            return "pre & den"
        elif is_post and is_den and not is_pre and not is_axon:
            return "post & den"
        elif is_pre and is_axon and not is_post and not is_den:
            return "pre & axon"
        elif is_post and is_axon and not is_pre and not is_den:
            return "post & axon"
        elif is_den and is_axon and not is_pre and not is_post:
            return "den & axon"
        
        # Three-category combinations (most common mixed types)
        elif is_pre and is_post and is_den and not is_axon:
            return "pre & post & den"
        elif is_pre and is_post and is_axon and not is_den:
            return "pre & post & axon"
        elif is_pre and is_den and is_axon and not is_post:
            return "pre & den & axon"
        elif is_post and is_den and is_axon and not is_pre:
            return "post & den & axon"
        
        # Four-category combination
        elif is_pre and is_post and is_den and is_axon:
            return "pre & post & den & axon"
        
        # No enriched categories or patterns not covered above
        else:
            return "others"
    
    def predict(
        self,
        granule_adata: anndata.AnnData,
        add_scores: bool = True
    ) -> pd.Series:
        """
        Predict granule subtypes for all granules in an AnnData object.
        
        Parameters
        ----------
        granule_adata : anndata.AnnData
            AnnData object containing granule expression data
        add_scores : bool, default=True
            Whether to add category scores to granule_adata.obs
            
        Returns
        -------
        pd.Series
            Series of predicted subtypes with same index as granule_adata.obs
        """
        # Extract expression matrix
        expr_matrix = granule_adata.X
        if hasattr(expr_matrix, 'toarray'):
            expr_matrix = expr_matrix.toarray()
        
        # Get gene names
        gene_names = list(granule_adata.var_names)
        
        # Compute category scores
        scores_df = self._compute_category_scores(expr_matrix, gene_names)
        
        # Classify each granule
        subtypes = []
        for idx in range(len(scores_df)):
            subtype = self._classify_granule(
                scores_df.loc[idx, "pre_score"],
                scores_df.loc[idx, "post_score"],
                scores_df.loc[idx, "den_score"],
                scores_df.loc[idx, "axon_score"],
                scores_df.loc[idx, "total_score"]
            )
            subtypes.append(subtype)
        
        # Add scores to AnnData if requested
        if add_scores:
            for col in ["pre_score", "post_score", "den_score", "axon_score", "total_score"]:
                granule_adata.obs[col] = scores_df[col].values
        
        # Return subtypes as categorical series
        subtypes_series = pd.Series(
            subtypes,
            index=granule_adata.obs.index,
            name="granule_subtype_auto"
        )
        
        return subtypes_series
    
    def predict_and_compare(
        self,
        granule_adata: anndata.AnnData,
        manual_subtype_column: str = "granule_subtype"
    ) -> Tuple[pd.Series, pd.DataFrame]:
        """
        Predict subtypes and compare with manual annotations if available.
        
        Parameters
        ----------
        granule_adata : anndata.AnnData
            AnnData object containing granule expression data
        manual_subtype_column : str, default="granule_subtype"
            Column name in obs containing manual annotations
            
        Returns
        -------
        subtypes : pd.Series
            Predicted subtypes
        comparison : pd.DataFrame
            Confusion matrix comparing automated vs manual subtypes
        """
        # Predict subtypes
        subtypes = self.predict(granule_adata, add_scores=True)
        
        # Create comparison if manual annotations exist
        if manual_subtype_column in granule_adata.obs.columns:
            manual_subtypes = granule_adata.obs[manual_subtype_column]
            
            # Create confusion matrix
            comparison = pd.crosstab(
                manual_subtypes,
                subtypes,
                rownames=["Manual"],
                colnames=["Automated"]
            )
            
            return subtypes, comparison
        else:
            return subtypes, None


# Convenience function for quick usage
def classify_granules(
    granule_adata: anndata.AnnData,
    expression_threshold: float = 0.1,
    enrichment_threshold: float = 0.35,
    enrichment_threshold_pre: Optional[float] = None,
    enrichment_threshold_post: Optional[float] = None,
    enrichment_threshold_den: Optional[float] = None,
    enrichment_threshold_axon: Optional[float] = None,
    min_total_expression: int = 1,
    custom_markers: Optional[Dict[str, List[str]]] = None
) -> pd.Series:
    """
    Classify granules into subtypes using automated rule-based classification.
    
    This is a convenience function that creates a subtyper and runs prediction.
    
    Parameters
    ----------
    granule_adata : anndata.AnnData
        AnnData object containing granule expression profiles
    expression_threshold : float, default=0.1
        Minimum normalized expression to consider a marker "present"
    enrichment_threshold : float, default=0.35
        Global minimum proportion of expression to consider a category enriched.
        This is used as the default for all categories if specific thresholds are not set.
    enrichment_threshold_pre : float, optional
        Category-specific threshold for pre-synaptic markers (overrides enrichment_threshold)
    enrichment_threshold_post : float, optional
        Category-specific threshold for post-synaptic markers (overrides enrichment_threshold)
    enrichment_threshold_den : float, optional
        Category-specific threshold for dendritic markers (overrides enrichment_threshold)
    enrichment_threshold_axon : float, optional
        Category-specific threshold for axonal markers (overrides enrichment_threshold)
    min_total_expression : int, default=1
        Minimum total marker expression for classification
    custom_markers : Dict[str, List[str]], optional
        Custom marker gene sets. Should have keys: 'pre-syn', 'post-syn',
        'dendrites', 'axons'. If None, uses default markers.
        
    Returns
    -------
    pd.Series
        Predicted granule subtypes
        
    Examples
    --------
    >>> # Basic usage with default parameters
    >>> subtypes = classify_granules(granule_adata)
    >>> granule_adata.obs['granule_subtype_auto'] = subtypes
    
    >>> # Adjust thresholds for stricter classification
    >>> subtypes = classify_granules(
    ...     granule_adata,
    ...     enrichment_threshold=0.4,  # Require 40% enrichment
    ...     min_total_expression=2      # Require more total expression
    ... )
    
    >>> # Use category-specific thresholds
    >>> # (more permissive for synaptic, stricter for dendritic/axonal)
    >>> subtypes = classify_granules(
    ...     granule_adata,
    ...     enrichment_threshold_pre=0.25,   # Lower threshold for pre-syn (25%)
    ...     enrichment_threshold_post=0.25,  # Lower threshold for post-syn (25%)
    ...     enrichment_threshold_den=0.45,   # Higher threshold for dendrites (45%)
    ...     enrichment_threshold_axon=0.45   # Higher threshold for axons (45%)
    ... )
    
    >>> # Use custom marker genes
    >>> custom_markers = {
    ...     'pre-syn': ['Syp', 'Syn1', 'Bsn'],
    ...     'post-syn': ['Dlg4', 'Shank3', 'Homer1'],
    ...     'dendrites': ['Map2', 'Cyfip2'],
    ...     'axons': ['Mapt', 'Nav1']
    ... }
    >>> subtypes = classify_granules(granule_adata, custom_markers=custom_markers)
    """
    # Create subtyper
    if custom_markers is not None:
        subtyper = AutomatedGranuleSubtyper(
            genes_syn_pre=custom_markers.get('pre-syn'),
            genes_syn_post=custom_markers.get('post-syn'),
            genes_dendrite=custom_markers.get('dendrites'),
            genes_axon=custom_markers.get('axons'),
            expression_threshold=expression_threshold,
            enrichment_threshold=enrichment_threshold,
            enrichment_threshold_pre=enrichment_threshold_pre,
            enrichment_threshold_post=enrichment_threshold_post,
            enrichment_threshold_den=enrichment_threshold_den,
            enrichment_threshold_axon=enrichment_threshold_axon,
            min_total_expression=min_total_expression
        )
    else:
        subtyper = AutomatedGranuleSubtyper(
            expression_threshold=expression_threshold,
            enrichment_threshold=enrichment_threshold,
            enrichment_threshold_pre=enrichment_threshold_pre,
            enrichment_threshold_post=enrichment_threshold_post,
            enrichment_threshold_den=enrichment_threshold_den,
            enrichment_threshold_axon=enrichment_threshold_axon,
            min_total_expression=min_total_expression
        )
    
    # Predict and return
    return subtyper.predict(granule_adata, add_scores=True)

In [21]:
# Colors
color_dct = ["#F56867","#FEB915","#C798EE","#59BE86","#7495D3","#6D1A9C","#15821E","#3A84E6","#997273","#787878","#DB4C6C","#9E7A7A","#554236","#AF5F3C","#93796C","#F9BD3F","#DAB370","#877F6C","#268785"]
color_cts = clr.LinearSegmentedColormap.from_list("magma", ["#000003", "#3B0F6F", "#8C2980", "#F66E5B", "#FD9F6C", "#FBFCBF"], N=256)

In [22]:
granule_adata = sc.read_h5ad("../data/adata_final.h5ad")
granule_adata

AnnData object with n_obs × n_vars = 1498704 × 290
    obs: 'global_x', 'global_y', 'global_z', 'layer_z', 'sphere_r', 'size', 'comp', 'in_nucleus', 'gene', 'brain_area', 'global_y_new', 'global_x_new', 'synapse_id', 'global_x_adjusted', 'global_y_adjusted', 'batch', 'brain_area_merged', 'batch_simple', 'granule_subtype_kmeans', 'granule_subtype'
    var: 'genes'
    uns: 'batch_simple_colors', 'brain_area_colors', 'log1p', 'pca', 'rank_genes_groups', 'tsne'
    obsm: 'X_pca', 'X_tsne'
    varm: 'PCs'

In [30]:
subtypes = classify_granules(
    granule_adata,
    enrichment_threshold_pre=0.15,   # Lower: easier to call pre-syn (was 0.35)
    enrichment_threshold_post=0.15,  # Lower: easier to call post-syn (was 0.35)
    enrichment_threshold_den=0.45,   # Higher: harder to call dendrites (was 0.35)
    enrichment_threshold_axon=0.45   # Higher: harder to call axons (was 0.35)
)

In [31]:
def relabel_granule_subtypes(subtype_series: pd.Series) -> pd.Series:
    """
    Relabel automated granule subtypes into high-level categories.
    
    Rules:
    - Any multi-category combination (contains ' & ') -> 'mixed'
    - Keep pure categories unchanged
    - Keep 'others'
    """
    s = subtype_series.astype(str)

    mixed_mask = s.str.contains(" & ", regex=False)

    relabeled = s.where(~mixed_mask, "mixed")

    return relabeled.astype("category")

In [32]:
subtypes_simple = relabel_granule_subtypes(subtypes)
subtypes_simple

0            mixed
1            mixed
2           others
3           others
4         post-syn
            ...   
532852       axons
532853      others
532854       mixed
532855       mixed
532856       axons
Name: granule_subtype_auto, Length: 1498704, dtype: category
Categories (6, object): ['axons', 'dendrites', 'mixed', 'others', 'post-syn', 'pre-syn']

In [33]:
subtypes_simple.value_counts() / len(subtypes_simple)

granule_subtype_auto
mixed        0.514663
others       0.271758
dendrites    0.084002
pre-syn      0.075360
post-syn     0.045928
axons        0.008289
Name: count, dtype: float64

In [29]:
granule_adata.obs['granule_subtype'].value_counts() / granule_adata.shape[0]

granule_subtype
mixed        0.300492
post-syn     0.237391
pre-syn      0.193643
dendrites    0.143209
others       0.125265
Name: count, dtype: float64

In [5]:
custom_markers = {
    "pre-syn": ["Bsn", "Gap43", "Nrxn1", "Slc17a6", "Slc17a7", "Slc32a1", "Snap25", "Stx1a", "Syn1", "Syp", "Syt1", "Vamp2", "Cplx2"],
    "post-syn": ["Camk2a", "Dlg3", "Dlg4", "Gphn", "Gria1", "Gria2", "Homer1", "Homer2", "Nlgn1", "Nlgn2", "Nlgn3", "Shank1", "Shank3"],
    "axons": ["Ank3", "Nav1", "Sptnb4", "Nfasc", "Mapt", "Tubb3"],
    "dendrites": ["Actb", "Cyfip2", "Ddn", "Dlg4", "Map1a", "Map2"]
}
subtypes_auto = automated_granule_subtyping(granule_adata, custom_markers = custom_markers)

In [6]:
granule_adata.obs['granule_subtype_auto'] = subtypes_auto

In [None]:


granule_adata.obs["granule_subtype_simple"] = relabel_granule_subtypes(
    granule_adata.obs["granule_subtype_auto"]
)

granule_subtype
mixed        0.300492
post-syn     0.237391
pre-syn      0.193643
dendrites    0.143209
others       0.125265
Name: count, dtype: float64

In [None]:
granule_adata.obs['granule_subtype'].value_counts() / granule_adata.shape[0]

In [27]:
# same category order
granule_adata.obs["granule_subtype_simple"] = pd.Categorical(granule_adata.obs["granule_subtype_simple"], categories=["pre-syn", "post-syn", "dendrites", "axons", "mixed", "others"], ordered=True)
granule_adata.obs["granule_subtype"] = pd.Categorical(granule_adata.obs["granule_subtype"], categories=["pre-syn", "post-syn", "dendrites", "mixed", "others"], ordered=True)

In [28]:
comparison = pd.crosstab(granule_adata.obs["granule_subtype"], granule_adata.obs["granule_subtype_simple"], rownames=["Manual"], colnames=["Automated"])
print(comparison)

# Calculate agreement
agreement = (comparison.values.diagonal().sum() / comparison.values.sum() * 100)
print(f"\nOverall agreement: {agreement:.2f}%")

Automated  pre-syn  post-syn  dendrites  axons   mixed  others
Manual                                                        
pre-syn      91909      1876      35884  13911   72231   58512
post-syn      5480     53757      23610  11205   70132  188367
dendrites      221        43      66073   1563   10432     836
mixed        75340      9392     325583   6595  132501   57947
others        3775       215      14868  17538    6317  153120

Overall agreement: 14.89%
