In [10]:
import anndata
import matplotlib.colors as clr
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import random
import scanpy as sc
import sys
from collections import defaultdict
from sklearn.cluster import KMeans, MiniBatchKMeans
from typing import Dict, List, Tuple, Optional

import warnings
warnings.filterwarnings("ignore")
sc.settings.verbosity = 0

from mcDETECT.utils import *
from mcDETECT.model import mcDETECT

In [11]:
"""
Automated Granule Subtyping Function for mcDETECT
==================================================

This module provides an automated method to classify granules into subtypes
(pre-synaptic, post-synaptic, dendritic, axonal, and mixed types) based on
their expression patterns of marker genes, without requiring clustering or
manual annotation.

Author: Claude
Date: February 2026
"""

import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional
import anndata


class AutomatedGranuleSubtyper:
    """
    Rule-based subtyping of granules from marker-gene expression.

    Each granule receives one score per marker category (pre-synaptic,
    post-synaptic, dendritic, axonal). A category is "enriched" when its
    share of the summed category scores reaches that category's threshold.
    The subtype label is derived from the set of enriched categories.
    """

    # Canonical category order; combined labels are always written in this order.
    _CATEGORY_ORDER = ("pre", "post", "den", "axon")

    # Label returned when exactly one category is enriched (or one dominates).
    _PURE_LABELS = {
        "pre": "pre-syn",
        "post": "post-syn",
        "den": "dendrites",
        "axon": "axons",
    }

    def __init__(
        self,
        genes_syn_pre: Optional[List[str]] = None,
        genes_syn_post: Optional[List[str]] = None,
        genes_dendrite: Optional[List[str]] = None,
        genes_axon: Optional[List[str]] = None,
        expression_threshold: float = 0.1,
        enrichment_threshold: float = 0.35,
        enrichment_threshold_pre: Optional[float] = None,
        enrichment_threshold_post: Optional[float] = None,
        enrichment_threshold_den: Optional[float] = None,
        enrichment_threshold_axon: Optional[float] = None,
        min_total_expression: int = 1
    ):
        """
        Initialize the automated granule subtyper.

        Parameters
        ----------
        genes_syn_pre : List[str], optional
            Pre-synaptic marker genes. A built-in list is used when this is
            None or empty.
        genes_syn_post : List[str], optional
            Post-synaptic marker genes. A built-in list is used when this is
            None or empty.
        genes_dendrite : List[str], optional
            Dendritic marker genes. A built-in list is used when this is
            None or empty.
        genes_axon : List[str], optional
            Axonal marker genes. A built-in list is used when this is
            None or empty.
        expression_threshold : float, default=0.1
            Stored on the instance. NOTE: no method of this class reads it,
            so it currently has no effect on classification.
        enrichment_threshold : float, default=0.35
            Minimum share of the total score for a category to count as
            enriched. Used for every category that has no specific override.
        enrichment_threshold_pre, enrichment_threshold_post,
        enrichment_threshold_den, enrichment_threshold_axon : float, optional
            Per-category overrides of ``enrichment_threshold``.
        min_total_expression : int, default=1
            Granules whose total score is below this value are labelled
            "others" without further checks.
        """
        # `or` means an explicitly empty list also falls back to the defaults.
        self.genes_syn_pre = genes_syn_pre or [
            "Bsn", "Gap43", "Nrxn1", "Slc17a6", "Slc17a7", "Slc32a1",
            "Snap25", "Stx1a", "Syn1", "Syp", "Syt1", "Vamp2", "Cplx2"
        ]

        self.genes_syn_post = genes_syn_post or [
            "Camk2a", "Dlg3", "Dlg4", "Gphn", "Gria1", "Gria2",
            "Homer1", "Homer2", "Nlgn1", "Nlgn2", "Nlgn3", "Shank1", "Shank3"
        ]

        self.genes_dendrite = genes_dendrite or [
            "Actb", "Cyfip2", "Ddn", "Dlg4", "Map1a", "Map2"
        ]

        self.genes_axon = genes_axon or [
            "Ank3", "Nav1", "Sptnb4", "Nfasc", "Mapt", "Tubb3"
        ]

        # Thresholds
        self.expression_threshold = expression_threshold
        self.enrichment_threshold = enrichment_threshold
        self.min_total_expression = min_total_expression

        # Category-specific thresholds fall back to the global one.
        def _or_global(specific: Optional[float]) -> float:
            return enrichment_threshold if specific is None else specific

        self.enrichment_threshold_pre = _or_global(enrichment_threshold_pre)
        self.enrichment_threshold_post = _or_global(enrichment_threshold_post)
        self.enrichment_threshold_den = _or_global(enrichment_threshold_den)
        self.enrichment_threshold_axon = _or_global(enrichment_threshold_axon)

        # Marker lists keyed by the public subtype names.
        self.marker_genes = {
            "pre-syn": self.genes_syn_pre,
            "post-syn": self.genes_syn_post,
            "dendrites": self.genes_dendrite,
            "axons": self.genes_axon
        }

    def _compute_category_scores(
        self,
        expression_matrix: np.ndarray,
        gene_names: List[str]
    ) -> pd.DataFrame:
        """
        Compute one expression score per marker category for every granule.

        A category score is the sum of the category's marker columns found in
        ``gene_names``, divided by the length of the configured marker list.
        Markers that are absent from the data therefore still count in the
        denominator. A category with no marker present scores 0.

        Parameters
        ----------
        expression_matrix : np.ndarray
            Expression matrix (granules x genes).
        gene_names : List[str]
            Gene names corresponding to the matrix columns.

        Returns
        -------
        pd.DataFrame
            Columns ``pre_score``, ``post_score``, ``den_score``,
            ``axon_score`` and ``total_score`` (the sum of the other four),
            one row per granule in matrix order.
        """
        n_granules = expression_matrix.shape[0]
        scores = {
            "pre_score": np.zeros(n_granules),
            "post_score": np.zeros(n_granules),
            "den_score": np.zeros(n_granules),
            "axon_score": np.zeros(n_granules),
            "total_score": np.zeros(n_granules)
        }

        # With duplicated gene names the last occurrence wins.
        gene_to_idx = {gene: idx for idx, gene in enumerate(gene_names)}

        for category, genes in [
            ("pre", self.genes_syn_pre),
            ("post", self.genes_syn_post),
            ("den", self.genes_dendrite),
            ("axon", self.genes_axon)
        ]:
            gene_indices = [gene_to_idx[g] for g in genes if g in gene_to_idx]

            if len(gene_indices) > 0:
                category_expr = expression_matrix[:, gene_indices].sum(axis=1)
                # Flatten so a (n, 1) matrix-style result also yields a 1-D column.
                category_expr = np.asarray(category_expr).ravel()
                # Denominator is the full configured list, not only the genes found.
                scores[f"{category}_score"] = category_expr / len(genes)

        scores["total_score"] = (
            scores["pre_score"] + scores["post_score"] +
            scores["den_score"] + scores["axon_score"]
        )

        return pd.DataFrame(scores)

    def _is_unclassifiable(self, total_score: float) -> bool:
        """
        Return True when a granule must be labelled "others" up front.

        This is the case when the total is below ``min_total_expression``, or
        when it is zero, negative or NaN. Proportions are undefined for such
        totals, and dividing by a zero total would otherwise raise or give NaN.
        """
        return bool(total_score < self.min_total_expression or not total_score > 0)

    def _enriched_scores(
        self,
        pre_score: float,
        post_score: float,
        den_score: float,
        axon_score: float,
        total_score: float
    ) -> Dict[str, float]:
        """
        Return ``{category: score}`` for categories meeting their threshold.

        Keys are inserted in canonical order (pre, post, den, axon). The
        caller must ensure ``total_score`` is positive.
        """
        candidates = (
            ("pre", pre_score, self.enrichment_threshold_pre),
            ("post", post_score, self.enrichment_threshold_post),
            ("den", den_score, self.enrichment_threshold_den),
            ("axon", axon_score, self.enrichment_threshold_axon),
        )
        return {
            name: score
            for name, score, threshold in candidates
            if score / total_score >= threshold
        }

    def _classify_granule(
        self,
        pre_score: float,
        post_score: float,
        den_score: float,
        axon_score: float,
        total_score: float
    ) -> str:
        """
        Classify a single granule from its category scores.

        Steps:
        1. Total below ``min_total_expression`` (or not positive) -> "others".
        2. Compute each category's share of the total.
        3. Mark categories whose share reaches their threshold as enriched.
        4. No enriched category -> "others"; one -> its pure label; several
           -> the short names joined with " & " in canonical order.

        Parameters
        ----------
        pre_score, post_score, den_score, axon_score : float
            Category scores for the granule.
        total_score : float
            Sum of the four category scores.

        Returns
        -------
        str
            Assigned subtype label.
        """
        if self._is_unclassifiable(total_score):
            return "others"

        enriched = self._enriched_scores(
            pre_score, post_score, den_score, axon_score, total_score
        )

        if not enriched:
            return "others"
        if len(enriched) == 1:
            (only,) = enriched
            return self._PURE_LABELS[only]
        return self._classify_granule_mixed(list(enriched))

    def _classify_granule_dominant(
        self,
        pre_score: float,
        post_score: float,
        den_score: float,
        axon_score: float,
        total_score: float,
        min_dominance_ratio: float = 1.5
    ) -> str:
        """
        Classify a granule, collapsing mixed calls that have a clear winner.

        Strategy:
        1. Find enriched categories exactly as in ``_classify_granule``.
        2. If several are enriched, take the highest-scoring one (the first
           in canonical order on ties).
        3. If its score is at least ``min_dominance_ratio`` times every other
           enriched category with a positive score, return its pure label.
           Otherwise return the combined label.

        Parameters
        ----------
        pre_score, post_score, den_score, axon_score : float
            Category scores for the granule.
        total_score : float
            Sum of the four category scores.
        min_dominance_ratio : float, default=1.5
            Required ratio between the top category and each other enriched one.

        Returns
        -------
        str
            Assigned subtype label.
        """
        if self._is_unclassifiable(total_score):
            return "others"

        enriched = self._enriched_scores(
            pre_score, post_score, den_score, axon_score, total_score
        )

        if not enriched:
            return "others"
        if len(enriched) == 1:
            (only,) = enriched
            return self._PURE_LABELS[only]

        # max() keeps the first maximal entry, i.e. canonical order breaks ties.
        top_cat, top_score = max(enriched.items(), key=lambda item: item[1])

        # Zero-score competitors are skipped to avoid dividing by zero.
        challenged = any(
            top_score / score < min_dominance_ratio
            for cat, score in enriched.items()
            if cat != top_cat and score > 0
        )

        if not challenged:
            return self._PURE_LABELS[top_cat]

        return self._classify_granule_mixed(list(enriched))

    def _classify_granule_mixed(self, enriched: List[str]) -> str:
        """
        Build the combined label for two or more enriched categories.

        Returns the short category names joined with " & " in canonical order
        (pre, post, den, axon). Returns "mixed" when fewer than two distinct
        categories are given or when any name is not a known category.
        """
        enriched_set = set(enriched)
        if len(enriched_set) < 2 or not enriched_set <= set(self._CATEGORY_ORDER):
            return "mixed"
        return " & ".join(
            cat for cat in self._CATEGORY_ORDER if cat in enriched_set
        )

    def predict(
        self,
        granule_adata: "anndata.AnnData",
        add_scores: bool = True,
        use_dominance: bool = False,
        min_dominance_ratio: float = 1.5
    ) -> pd.Series:
        """
        Predict a subtype for every granule in an AnnData object.

        Parameters
        ----------
        granule_adata : anndata.AnnData
            Object providing ``X`` (granules x genes), ``var_names`` and ``obs``.
        add_scores : bool, default=True
            When True, writes the columns ``pre_score``, ``post_score``,
            ``den_score``, ``axon_score`` and ``total_score`` into
            ``granule_adata.obs`` (the input object is modified in place).
        use_dominance : bool, default=False
            Use ``_classify_granule_dominant`` instead of ``_classify_granule``.
        min_dominance_ratio : float, default=1.5
            Dominance ratio; only used when ``use_dominance`` is True.

        Returns
        -------
        pd.Series
            Subtype labels (plain strings) named ``granule_subtype_auto`` and
            indexed like ``granule_adata.obs``.
        """
        # NOTE: this densifies the whole matrix, which can be large.
        expr_matrix = granule_adata.X
        if hasattr(expr_matrix, 'toarray'):
            expr_matrix = expr_matrix.toarray()

        gene_names = list(granule_adata.var_names)
        scores_df = self._compute_category_scores(expr_matrix, gene_names)

        score_columns = [
            "pre_score", "post_score", "den_score", "axon_score", "total_score"
        ]

        # Iterate over the raw column arrays rather than doing five
        # `.loc[idx, col]` scalar lookups per granule. The values seen by the
        # classifiers are the same NumPy scalars, just fetched far more cheaply.
        score_rows = zip(*(scores_df[col].to_numpy() for col in score_columns))

        if use_dominance:
            subtypes = [
                self._classify_granule_dominant(*row, min_dominance_ratio)
                for row in score_rows
            ]
        else:
            subtypes = [self._classify_granule(*row) for row in score_rows]

        if add_scores:
            for col in score_columns:
                granule_adata.obs[col] = scores_df[col].values

        # Object-dtype Series of strings, aligned to the granule index.
        subtypes_series = pd.Series(
            subtypes,
            index=granule_adata.obs.index,
            name="granule_subtype_auto"
        )

        return subtypes_series

    def predict_and_compare(
        self,
        granule_adata: "anndata.AnnData",
        manual_subtype_column: str = "granule_subtype",
        use_dominance: bool = False,
        min_dominance_ratio: float = 1.5
    ) -> Tuple[pd.Series, Optional[pd.DataFrame]]:
        """
        Predict subtypes and cross-tabulate them against existing labels.

        Parameters
        ----------
        granule_adata : anndata.AnnData
            Object containing granule expression data. Score columns are
            always written into ``granule_adata.obs``.
        manual_subtype_column : str, default="granule_subtype"
            Column of ``granule_adata.obs`` holding the reference labels.
        use_dominance : bool, default=False
            Forwarded to ``predict``.
        min_dominance_ratio : float, default=1.5
            Forwarded to ``predict``.

        Returns
        -------
        subtypes : pd.Series
            Predicted subtypes.
        comparison : pd.DataFrame or None
            Contingency table (rows "Manual", columns "Automated"), or None
            when ``manual_subtype_column`` is not present in ``obs``.
        """
        subtypes = self.predict(
            granule_adata,
            add_scores=True,
            use_dominance=use_dominance,
            min_dominance_ratio=min_dominance_ratio
        )

        if manual_subtype_column not in granule_adata.obs.columns:
            return subtypes, None

        comparison = pd.crosstab(
            granule_adata.obs[manual_subtype_column],
            subtypes,
            rownames=["Manual"],
            colnames=["Automated"]
        )
        return subtypes, comparison


# Convenience function for quick usage
def classify_granules(
    granule_adata: anndata.AnnData,
    expression_threshold: float = 0.1,
    enrichment_threshold: float = 0.35,
    enrichment_threshold_pre: Optional[float] = None,
    enrichment_threshold_post: Optional[float] = None,
    enrichment_threshold_den: Optional[float] = None,
    enrichment_threshold_axon: Optional[float] = None,
    min_total_expression: int = 1,
    custom_markers: Optional[Dict[str, List[str]]] = None,
    use_dominance: bool = False,
    min_dominance_ratio: float = 1.5
) -> pd.Series:
    """
    One-call wrapper: build an AutomatedGranuleSubtyper and run ``predict``.

    Category scores are always written into ``granule_adata.obs`` because
    ``predict`` is called with ``add_scores=True``.

    Parameters
    ----------
    granule_adata : anndata.AnnData
        AnnData object containing granule expression profiles.
    expression_threshold : float, default=0.1
        Forwarded unchanged to the subtyper constructor.
    enrichment_threshold : float, default=0.35
        Global minimum share of marker expression for a category to count as
        enriched. Applies to every category without a specific override.
    enrichment_threshold_pre : float, optional
        Override of ``enrichment_threshold`` for pre-synaptic markers.
    enrichment_threshold_post : float, optional
        Override of ``enrichment_threshold`` for post-synaptic markers.
    enrichment_threshold_den : float, optional
        Override of ``enrichment_threshold`` for dendritic markers.
    enrichment_threshold_axon : float, optional
        Override of ``enrichment_threshold`` for axonal markers.
    min_total_expression : int, default=1
        Minimum total marker score required to attempt classification.
    custom_markers : Dict[str, List[str]], optional
        Marker lists keyed by 'pre-syn', 'post-syn', 'dendrites', 'axons'.
        Any key that is missing falls back to the subtyper's default list;
        passing None uses the defaults for all four categories.
    use_dominance : bool, default=False
        Use the dominance-based classifier, which reports a pure type when
        one enriched category sufficiently outweighs the others.
    min_dominance_ratio : float, default=1.5
        Ratio required for dominance (only used if ``use_dominance=True``).

    Returns
    -------
    pd.Series
        Predicted granule subtypes.

    Examples
    --------
    >>> subtypes = classify_granules(granule_adata)
    >>> granule_adata.obs['granule_subtype_auto'] = subtypes

    >>> # Category-specific thresholds combined with dominance mode
    >>> subtypes = classify_granules(
    ...     granule_adata,
    ...     enrichment_threshold_pre=0.15,
    ...     enrichment_threshold_post=0.15,
    ...     enrichment_threshold_den=0.45,
    ...     enrichment_threshold_axon=0.45,
    ...     use_dominance=True,
    ...     min_dominance_ratio=1.5
    ... )

    >>> # Custom marker genes
    >>> custom_markers = {
    ...     'pre-syn': ['Syp', 'Syn1', 'Bsn'],
    ...     'post-syn': ['Dlg4', 'Shank3', 'Homer1'],
    ...     'dendrites': ['Map2', 'Cyfip2'],
    ...     'axons': ['Mapt', 'Nav1']
    ... }
    >>> subtypes = classify_granules(granule_adata, custom_markers=custom_markers)
    """
    # An absent mapping behaves like an empty one: every lookup yields None,
    # which the subtyper treats as "use the default marker list".
    marker_lookup = {} if custom_markers is None else custom_markers

    subtyper = AutomatedGranuleSubtyper(
        genes_syn_pre=marker_lookup.get('pre-syn'),
        genes_syn_post=marker_lookup.get('post-syn'),
        genes_dendrite=marker_lookup.get('dendrites'),
        genes_axon=marker_lookup.get('axons'),
        expression_threshold=expression_threshold,
        enrichment_threshold=enrichment_threshold,
        enrichment_threshold_pre=enrichment_threshold_pre,
        enrichment_threshold_post=enrichment_threshold_post,
        enrichment_threshold_den=enrichment_threshold_den,
        enrichment_threshold_axon=enrichment_threshold_axon,
        min_total_expression=min_total_expression
    )

    predicted = subtyper.predict(
        granule_adata,
        add_scores=True,
        use_dominance=use_dominance,
        min_dominance_ratio=min_dominance_ratio
    )
    return predicted


# Example usage script
if __name__ == "__main__":
    # Usage examples for the automated subtyper. The examples are kept as
    # comments; only the confirmation message at the end is executed.
    import scanpy as sc

    # Load your granule data
    # granule_adata = sc.read_h5ad("output/granule_adata_raw.h5ad")

    # Method 1: Quick classification with default parameters
    # subtypes = classify_granules(granule_adata)
    # granule_adata.obs['granule_subtype_auto'] = subtypes

    # Method 2: More control with custom parameters
    # subtyper = AutomatedGranuleSubtyper(
    #     expression_threshold=0.1,
    #     enrichment_threshold=0.35,
    #     min_total_expression=1
    # )
    # subtypes, comparison = subtyper.predict_and_compare(
    #     granule_adata,
    #     manual_subtype_column="granule_subtype"
    # )
    #
    # # Add to AnnData
    # granule_adata.obs['granule_subtype_auto'] = subtypes
    #
    # # Print comparison with manual annotations
    # if comparison is not None:
    #     print("\nComparison with manual annotations:")
    #     print(comparison)
    #     print(f"\nOverall agreement: {(comparison.values.diagonal().sum() / comparison.values.sum() * 100):.2f}%")

    # Method 3: Optimize thresholds (optional)
    # Try several thresholds and compare each result with the manual labels
    # for enrich_thr in [0.3, 0.35, 0.4, 0.45]:
    #     subtyper = AutomatedGranuleSubtyper(enrichment_threshold=enrich_thr)
    #     subtypes = subtyper.predict(granule_adata)
    #     # Evaluate agreement with manual labels
    #     # ...

    # One call, three lines of output.
    print(
        "Automated Granule Subtyper module loaded successfully!",
        "Use classify_granules() for quick classification or",
        "AutomatedGranuleSubtyper class for more control.",
        sep="\n",
    )


Automated Granule Subtyper module loaded successfully!
Use classify_granules() for quick classification or
AutomatedGranuleSubtyper class for more control.


In [12]:
# Read the granule AnnData from disk (path is relative to the notebook's
# working directory) and display its summary as the cell output.
granule_adata = sc.read_h5ad("../data/adata_final.h5ad")
granule_adata

AnnData object with n_obs × n_vars = 1498704 × 290
    obs: 'global_x', 'global_y', 'global_z', 'layer_z', 'sphere_r', 'size', 'comp', 'in_nucleus', 'gene', 'brain_area', 'global_y_new', 'global_x_new', 'synapse_id', 'global_x_adjusted', 'global_y_adjusted', 'batch', 'brain_area_merged', 'batch_simple', 'granule_subtype_kmeans', 'granule_subtype'
    var: 'genes'
    uns: 'batch_simple_colors', 'brain_area_colors', 'log1p', 'pca', 'rank_genes_groups', 'tsne'
    obsm: 'X_pca', 'X_tsne'
    varm: 'PCs'

In [13]:
# Fraction of all granules carrying each existing `granule_subtype` label
# (label counts divided by the total number of rows).
granule_adata.obs["granule_subtype"].value_counts() / granule_adata.shape[0]

granule_subtype
mixed        0.300492
post-syn     0.237391
pre-syn      0.193643
dendrites    0.143209
others       0.125265
Name: count, dtype: float64

In [None]:
# Classify every granule with category-specific thresholds and dominance mode.
# Side effect: classify_granules() calls predict(add_scores=True), so the
# pre/post/den/axon/total score columns are written into granule_adata.obs.
subtypes = classify_granules(
    granule_adata,
    # Low thresholds: a small share of synaptic marker expression is enough
    enrichment_threshold_pre=0.15,
    enrichment_threshold_post=0.15,
    # High thresholds: dendritic/axonal calls need a large share
    enrichment_threshold_den=0.45,
    enrichment_threshold_axon=0.45,
    # Dominance mode: a sufficiently dominant category is reported as a pure type
    use_dominance=True,
    min_dominance_ratio=1.5  # top category must be >= 1.5x each other enriched one
)

# Store the automated labels next to the existing annotations
granule_adata.obs['granule_subtype_auto'] = subtypes

# Share of granules per automated label (proportions sum to 1)
print(subtypes.value_counts(normalize=True))

granule_subtype_auto
dendrites            0.301581
others               0.271758
pre-syn              0.154302
pre & post           0.100073
post-syn             0.084719
post & den           0.028068
axons                0.025595
pre & den            0.018433
pre & post & den     0.005224
pre & axon           0.004547
den & axon           0.004019
post & axon          0.001484
pre & post & axon    0.000199
Name: proportion, dtype: float64


In [16]:
def simplify(s):
    """Map any combined label (one containing ' & ') to 'mixed'; otherwise return the label as a string."""
    label = str(s)
    return 'mixed' if ' & ' in label else label

# Collapse every combined label into a single 'mixed' class.
# NOTE(review): the cell that assigns `subtypes` above passes use_dominance=True,
# so the `_no_dom` suffix looks stale — confirm which run this was meant to hold.
simple_no_dom = subtypes.apply(simplify)

In [18]:
# Fraction of granules per simplified label (counts divided by the series length).
simple_no_dom.value_counts() / len(simple_no_dom)

granule_subtype_auto
dendrites    0.301581
others       0.271758
mixed        0.162045
pre-syn      0.154302
post-syn     0.084719
axons        0.025595
Name: count, dtype: float64

In [20]:
# Grid search over the dominance ratio and the enrichment thresholds.
# For every combination the granules are re-classified, combined labels are
# collapsed to 'mixed', and the resulting class proportions are recorded.
# NOTE: each iteration re-runs the full classification over all granules and
# rewrites the score columns in granule_adata.obs, so this cell is slow.
# NOTE: `subtypes` is reassigned here; after this cell it holds the labels of
# the last grid point, not those of the earlier single run.
results = []

for dom_ratio in [1.0, 1.2, 1.4, 1.6, 1.8, 2.0]:
    for pre_syn_thr in [0.1, 0.15, 0.20, 0.25, 0.30, 0.35]:
        for post_syn_thr in [0.1, 0.15, 0.20, 0.25, 0.30, 0.35]:
            for dendrite_axon_thr in [0.40, 0.45, 0.50, 0.55]:
                subtypes = classify_granules(
                    granule_adata,
                    enrichment_threshold_pre=pre_syn_thr,
                    enrichment_threshold_post=post_syn_thr,
                    enrichment_threshold_den=dendrite_axon_thr,
                    enrichment_threshold_axon=dendrite_axon_thr,
                    use_dominance=True,
                    # Use the swept value. A hard-coded 1.5 here made every
                    # dom_ratio setting produce identical classifications.
                    min_dominance_ratio=dom_ratio
                )

                # Collapse combined labels ('a & b') into one 'mixed' class
                simple = subtypes.apply(lambda s: 'mixed' if ' & ' in str(s) else str(s))
                counts = simple.value_counts(normalize=True)

                print(f"Completed: dom_ratio: {dom_ratio}, pre_syn_thr: {pre_syn_thr}, post_syn_thr: {post_syn_thr}, dendrite_axon_thr: {dendrite_axon_thr}")

                # Classes that do not occur at this grid point are recorded as 0
                results.append({
                    'dom_ratio': dom_ratio,
                    'pre_syn_thr': pre_syn_thr,
                    'post_syn_thr': post_syn_thr,
                    'dendrite_axon_thr': dendrite_axon_thr,
                    'pre-syn': counts.get('pre-syn', 0),
                    'post-syn': counts.get('post-syn', 0),
                    'dendrites': counts.get('dendrites', 0),
                    'mixed': counts.get('mixed', 0),
                    'others': counts.get('others', 0)
                })

results_df = pd.DataFrame(results)

Completed: dom_ratio: 1.0, pre_syn_thr: 0.1, post_syn_thr: 0.1, dendrite_axon_thr: 0.4
Completed: dom_ratio: 1.0, pre_syn_thr: 0.1, post_syn_thr: 0.1, dendrite_axon_thr: 0.45
Completed: dom_ratio: 1.0, pre_syn_thr: 0.1, post_syn_thr: 0.1, dendrite_axon_thr: 0.5
Completed: dom_ratio: 1.0, pre_syn_thr: 0.1, post_syn_thr: 0.1, dendrite_axon_thr: 0.55
Completed: dom_ratio: 1.0, pre_syn_thr: 0.1, post_syn_thr: 0.15, dendrite_axon_thr: 0.4
Completed: dom_ratio: 1.0, pre_syn_thr: 0.1, post_syn_thr: 0.15, dendrite_axon_thr: 0.45
Completed: dom_ratio: 1.0, pre_syn_thr: 0.1, post_syn_thr: 0.15, dendrite_axon_thr: 0.5
Completed: dom_ratio: 1.0, pre_syn_thr: 0.1, post_syn_thr: 0.15, dendrite_axon_thr: 0.55
Completed: dom_ratio: 1.0, pre_syn_thr: 0.1, post_syn_thr: 0.2, dendrite_axon_thr: 0.4
Completed: dom_ratio: 1.0, pre_syn_thr: 0.1, post_syn_thr: 0.2, dendrite_axon_thr: 0.45
Completed: dom_ratio: 1.0, pre_syn_thr: 0.1, post_syn_thr: 0.2, dendrite_axon_thr: 0.5
Completed: dom_ratio: 1.0, pre_syn

In [None]:
# Rank the grid-search rows by how closely their class proportions match the
# manual annotation. The targets are the manual `granule_subtype` proportions,
# rounded to four decimals.
manual = {
    'pre-syn': 0.1936,
    'post-syn': 0.2374,
    'dendrites': 0.1432,
    'mixed': 0.3005,
    'others': 0.1253,
}

# L1 distance: sum of absolute differences over the five manual classes
results_df['distance'] = results_df.apply(
    lambda row: sum(abs(row[label] - target) for label, target in manual.items()),
    axis=1
)

print("\nBest parameter combinations:")
print(results_df.sort_values('distance').head(10))