In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import anndata
from pathlib import Path
from abc_atlas_access.abc_atlas_cache.abc_project_cache import AbcProjectCache

In [None]:
import os

# Project root on the cluster scratch filesystem. Export SRF_SRRM3_DIR before
# starting the kernel to run the notebook from a different location; the default
# is the original hard-coded path, so existing setups keep working unchanged.
PROJECT_DIR = os.environ.get(
    'SRF_SRRM3_DIR',
    '/beegfs/scratch/ric.broccoli/kubacki.michal/SRF_SRRM3'
)

# Set the current working directory; relative paths such as ./DATA/abc_atlas
# used in later cells are resolved against it.
os.chdir(PROJECT_DIR)

# Print the current working directory to confirm the change
print(f"Current working directory: {os.getcwd()}")

In [25]:
class BrainCellAtlasAnalyzer:
    """
    Analyzer for Allen Brain Cell Atlas scRNA-seq data with hierarchical annotations.
    """
    def __init__(self, download_base: str | Path = './DATA/abc_atlas', create_if_missing: bool = True):
        """
        Prepare the local ABC atlas cache and load the taxonomy tables.

        Parameters:
        -----------
        download_base : str | Path
            Directory handed to the atlas cache as its cache directory
        create_if_missing : bool
            When True, the directory (and any missing parents) is created first
        """
        cache_dir = Path(download_base)
        self.download_base = cache_dir

        if create_if_missing:
            cache_dir.mkdir(parents=True, exist_ok=True)

        # Point the cache at the directory, then switch to the latest manifest
        cache = AbcProjectCache.from_cache_dir(cache_dir)
        self.abc_cache = cache
        cache.load_latest_manifest()

        print(f"Using data directory: {cache_dir.absolute()}")
        print(f"Current manifest: {cache.current_manifest}")

        # The summary methods read the taxonomy tables, so load them eagerly
        self._load_taxonomy_data()
        
    def _load_taxonomy_data(self):
        """Load hierarchical taxonomy data"""
        # Load cluster and annotation data
        self.clusters = self.abc_cache.get_metadata_dataframe(
            directory='WMB-taxonomy', 
            file_name='cluster'
        )
        
        self.term_sets = self.abc_cache.get_metadata_dataframe(
            directory='WMB-taxonomy', 
            file_name='cluster_annotation_term_set'
        )
        
        self.terms = self.abc_cache.get_metadata_dataframe(
            directory='WMB-taxonomy', 
            file_name='cluster_annotation_term',
            keep_default_na=False
        )
        
        self.membership = self.abc_cache.get_metadata_dataframe(
            directory='WMB-taxonomy', 
            file_name='cluster_to_cluster_annotation_membership'
        )

    def get_annotated_data(self, 
                          dataset: str = 'WMB-10Xv2',
                          region: str = 'CTXsp',
                          matrix_type: str = 'log2') -> anndata.AnnData:
        """
        Get expression data with full hierarchical annotations.
        
        Parameters:
        -----------
        dataset : str
            Dataset name (e.g., 'WMB-10Xv2', 'WMB-10Xv3')
        region : str
            Brain region (e.g., 'CTXsp', 'HPF')
        matrix_type : str
            Expression matrix type ('log2' or 'raw')
            
        Returns:
        --------
        anndata.AnnData
            Annotated expression data
        """
        # Get expression data
        expr_file = f"{dataset}-{region}/{matrix_type}"
        adata_path = self.abc_cache.get_data_path(
            directory=dataset,
            file_name=expr_file
        )
        adata = anndata.read_h5ad(adata_path)
            
        # Get cluster information
        clusters = self.abc_cache.get_metadata_dataframe(
            directory='WMB-taxonomy',
            file_name='cluster'
        )
        
        # Get membership with colors
        membership_color = self.abc_cache.get_metadata_dataframe(
            directory='WMB-taxonomy',
            file_name='cluster_to_cluster_annotation_membership_color'
        )
        
        # Add cluster info to adata.obs - using cluster_alias as the index
        adata.obs = adata.obs.join(clusters.set_index('cluster_alias'), on='cluster_alias')
        
        # Add colored membership annotations
        adata.obs = adata.obs.join(membership_color.set_index('cluster_alias'))
        
        # Add dataset info
        adata.uns['dataset'] = dataset
        adata.uns['region'] = region
        adata.uns['matrix_type'] = matrix_type
        
        return adata

    def get_taxonomy_summary(self) -> pd.DataFrame:
        """
        Get summary of the hierarchical taxonomy structure.
        
        Returns:
        --------
        pd.DataFrame
            Summary of term counts at each level
        """
        summary = self.terms.groupby('cluster_annotation_term_set_name')['label'].count()
        summary.name = 'number_of_terms'
        return summary
    
    def get_cell_type_counts(self, level: str = 'cluster') -> pd.DataFrame:
        """
        Get cell type counts at specified taxonomy level.
        
        Parameters:
        -----------
        level : str
            Taxonomy level ('neurotransmitter', 'class', 'subclass', 'supertype', 'cluster')
            
        Returns:
        --------
        pd.DataFrame
            Cell counts per cell type at specified level
        """
        counts = self.membership[
            self.membership['cluster_annotation_term_set_name'] == level
        ].groupby('cluster_annotation_term_name').agg({
            'number_of_cells': 'sum',
            'cluster_alias': 'count',
            'color_hex_triplet': 'first'
        })
        counts.columns = ['total_cells', 'number_of_clusters', 'color']
        return counts.sort_values('total_cells', ascending=False)

    def plot_hierarchy_distribution(self, level_a: str, level_b: str, 
                                  figsize: tuple = (10, 6)):
        """
        Plot distribution of cell types between two taxonomy levels.
        
        Parameters:
        -----------
        level_a : str
            Parent level for rows
        level_b : str
            Child level for columns
        figsize : tuple
            Figure size (width, height)
        """
        # Create pivot table
        pivot = self.membership.pivot_table(
            values='cluster_alias',
            index=[f'{level_a}_name'],
            columns=[f'{level_b}_name'],
            aggfunc='count',
            fill_value=0
        )
        
        # Get colors for level_b
        colors = self.terms[
            self.terms['cluster_annotation_term_set_name'] == level_b
        ].set_index('name')['color_hex_triplet']
        
        # Plot
        fig, ax = plt.subplots(figsize=figsize)
        bottom = np.zeros(len(pivot))
        
        for col in pivot.columns:
            ax.barh(pivot.index, pivot[col], left=bottom, 
                   label=col, color=colors.get(col, '#CCCCCC'))
            bottom += pivot[col].values
            
        ax.set_title(f'Distribution of {level_b} in each {level_a}')
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        return fig, ax


In [None]:
# Keep the atlas cache inside the project's DATA folder
custom_path = Path('./DATA/abc_atlas')
analyzer = BrainCellAtlasAnalyzer(custom_path)

In [None]:
# Show which metadata files the cache lists for the WMB-taxonomy directory
print("Available metadata files in WMB-taxonomy:")
taxonomy_files = analyzer.abc_cache.list_metadata_files('WMB-taxonomy')
print(taxonomy_files)

In [None]:
# Load the log2 matrix for CTXsp (WMB-10Xv2) together with its cluster annotations
adata = analyzer.get_annotated_data(dataset='WMB-10Xv2', region='CTXsp', matrix_type='log2')

In [None]:
# Write the annotated AnnData object to disk so it can be reloaded without re-joining
annotated_output = './DATA/annotated_data.h5ad'
adata.write_h5ad(annotated_output)

In [None]:
# Number of terms defined in each taxonomy term set
print("\nTaxonomy levels and term counts:")
taxonomy_summary = analyzer.get_taxonomy_summary()
print(taxonomy_summary)

In [None]:
# Largest cell types (by total cell count) at three taxonomy levels
print("\nTop 5 cell types by number of cells:")
for level in ('class', 'subclass', 'supertype'):
    counts = analyzer.get_cell_type_counts(level)
    heading = f"\n{level.capitalize()} level:"
    print(heading)
    print(counts.head())

In [None]:

# Stacked bars: neurotransmitter terms within each class
analyzer.plot_hierarchy_distribution(level_a='class', level_b='neurotransmitter')
plt.show()

In [41]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import anndata
from pathlib import Path
from typing import List, Union, Tuple
from abc_atlas_access.abc_atlas_cache.abc_project_cache import AbcProjectCache

class BrainCellAtlasAnalyzer:
    def __init__(self, download_base: str | Path = './DATA/abc_atlas', create_if_missing: bool = True):
        """
        Set up the atlas cache and load cell metadata plus taxonomy annotations.

        Parameters:
        -----------
        download_base : str | Path
            Directory handed to the atlas cache as its cache directory
        create_if_missing : bool
            When True, the directory (and any missing parents) is created first
        """
        cache_dir = Path(download_base)
        self.download_base = cache_dir
        if create_if_missing:
            cache_dir.mkdir(parents=True, exist_ok=True)

        # Point the cache at the directory, then switch to the latest manifest
        cache = AbcProjectCache.from_cache_dir(cache_dir)
        self.abc_cache = cache
        cache.load_latest_manifest()

        # Order matters: the taxonomy loader joins its tables onto the cell metadata
        self._load_cell_metadata()
        self._load_taxonomy_data()
        
    def _load_cell_metadata(self):
        """Load cell metadata"""
        # Load basic cell metadata
        self.cell_metadata = self.abc_cache.get_metadata_dataframe(
            directory='WMB-10X',
            file_name='cell_metadata',
            dtype={'cell_label': str}
        )
        self.cell_metadata.set_index('cell_label', inplace=True)
        
        # Get available matrices
        self.matrices = self.cell_metadata.groupby(
            ['dataset_label', 'feature_matrix_label']
        )[['library_label']].count()
        self.matrices.columns = ['cell_count']

    def _load_taxonomy_data(self):
        """Load hierarchical taxonomy data"""
        # Load cluster details
        self.cluster_details = self.abc_cache.get_metadata_dataframe(
            directory='WMB-taxonomy',
            file_name='cluster_to_cluster_annotation_membership_pivoted',
            keep_default_na=False
        )
        self.cluster_details.set_index('cluster_alias', inplace=True)
        
        # Load cluster colors
        self.cluster_colors = self.abc_cache.get_metadata_dataframe(
            directory='WMB-taxonomy',
            file_name='cluster_to_cluster_annotation_membership_color'
        )
        self.cluster_colors.set_index('cluster_alias', inplace=True)
        
        # Extend cell metadata with annotations
        self._extend_cell_metadata()
        
    def _extend_cell_metadata(self):
        """Add cluster annotations to cell metadata"""
        self.cell_metadata = self.cell_metadata.join(
            self.cluster_details, 
            on='cluster_alias'
        )
        self.cell_metadata = self.cell_metadata.join(
            self.cluster_colors, 
            on='cluster_alias'
        )

    def get_expression_data(self,
                          dataset: str,
                          region: str,
                          matrix_type: str = 'log2',
                          genes: List[str] = None) -> anndata.AnnData:
        """
        Get expression data for specific genes from a brain region
        
        Parameters:
        -----------
        dataset : str
            Dataset name (e.g., 'WMB-10Xv2', 'WMB-10Xv3')
        region : str
            Brain region (e.g., 'CTXsp', 'TH')
        matrix_type : str
            'log2' or 'raw'
        genes : List[str]
            List of gene symbols to extract. If None, gets all genes.
            
        Returns:
        --------
        anndata.AnnData
            Expression data with annotations
        """
        # Construct file path
        expr_file = f"{dataset}-{region}/{matrix_type}"
        file_path = self.abc_cache.get_data_path(
            directory=dataset,
            file_name=expr_file
        )
        
        # Load data
        adata = anndata.read_h5ad(file_path, backed='r')
        
        # Filter for specific genes if requested
        if genes:
            gene_mask = [x in genes for x in adata.var.gene_symbol]
            adata = adata[:, gene_mask]
        
        return adata

    def analyze_gene_expression(self,
                              adata: anndata.AnnData,
                              genes: List[str],
                              group_by: str = 'neurotransmitter',
                              min_cells: int = 10) -> pd.DataFrame:
        """
        Analyze gene expression across cell types
        
        Parameters:
        -----------
        adata : anndata.AnnData
            Expression data
        genes : List[str]
            Genes to analyze
        group_by : str
            How to group cells ('neurotransmitter', 'class', 'subclass', etc)
        min_cells : int
            Minimum number of cells required in a group
            
        Returns:
        --------
        pd.DataFrame
            Mean expression per group
        """
        # Create expression DataFrame
        expr_df = pd.DataFrame(
            adata[:, genes].X.toarray(),
            index=adata.obs.index,
            columns=genes
        )
        
        # Get cell annotations
        cell_groups = self.cell_metadata.loc[expr_df.index, group_by]
        
        # Calculate mean expression per group
        grouped = expr_df.groupby(cell_groups).agg(['mean', 'count'])
        
        # Filter by minimum cell count
        mask = grouped.xs('count', axis=1, level=1).min(axis=1) >= min_cells
        grouped = grouped.xs('mean', axis=1, level=1)[mask]
        
        return grouped.sort_values(genes[0], ascending=False)

    def plot_expression_heatmap(self,
                              expression_data: pd.DataFrame,
                              figsize: Tuple[int, int] = (10, 6),
                              cmap: str = 'magma',
                              vmax: float = None) -> Tuple[plt.Figure, plt.Axes]:
        """
        Create heatmap of gene expression
        
        Parameters:
        -----------
        expression_data : pd.DataFrame
            Expression data from analyze_gene_expression
        figsize : tuple
            Figure size
        cmap : str
            Colormap name
        vmax : float
            Maximum value for color scaling
            
        Returns:
        --------
        (Figure, Axes)
            Matplotlib figure and axes objects
        """
        fig, ax = plt.subplots(figsize=figsize)
        
        sns.heatmap(
            expression_data,
            cmap=cmap,
            robust=True,
            vmax=vmax,
            ax=ax
        )
        
        # Rotate x-axis labels for better readability
        plt.xticks(rotation=45, ha='right')
        
        return fig, ax

In [None]:
# Build the analyzer on top of the local atlas cache directory
analyzer = BrainCellAtlasAnalyzer('./DATA/abc_atlas')


In [None]:
# Genes of interest
srrm_genes = ['Srrm3', 'Srrm4']

# Thalamus (TH) expression from WMB-10Xv2, restricted to those genes
adata = analyzer.get_expression_data(dataset='WMB-10Xv2', region='TH', genes=srrm_genes)
    

In [None]:
# Analyze expression by different groupings
for grouping in ['neurotransmitter', 'class', 'subclass']:
    print(f"\nAnalyzing by {grouping}...")
    
    expr_data = analyzer.analyze_gene_expression(
        adata=adata,
        genes=srrm_genes,
        group_by=grouping
    )

    # With no group above the min_cells threshold the table is empty; a figure
    # of height len(expr_data)/2 == 0 cannot be drawn, so skip the plot.
    if expr_data.empty:
        print(f"No {grouping} group has enough cells; skipping heatmap")
        continue
    
    # Plot heatmap; keep a minimum height so tables with few rows stay readable
    fig, ax = analyzer.plot_expression_heatmap(
        expression_data=expr_data,
        figsize=(8, max(2, len(expr_data) / 2))
    )
    ax.set_title(f'Expression by {grouping}')
    plt.tight_layout()
    plt.show()
    # Release the figure so repeated runs of the loop do not accumulate open figures
    plt.close(fig)

In [None]:
# List available metadata
# BrainCellAtlasAnalyzer defines no list_available_metadata method; the listing is
# provided by the cache object (the same call is used for 'WMB-taxonomy' above).
# NOTE(review): cell metadata is read from the 'WMB-10X' directory elsewhere in this
# notebook - confirm that 'WMB-10Xv2' is the directory intended here.
print("\nAvailable metadata files:")
print(analyzer.abc_cache.list_metadata_files('WMB-10Xv2'))


In [None]:
# Example usage
# NOTE(review): this cell looks stale relative to the class definitions visible in
# this notebook - see the notes on the individual calls below.
if __name__ == "__main__":
    # Initialize analyzer
    analyzer = BrainCellAtlasAnalyzer(download_base='./DATA/abc_atlas')
    

    # Get data
    # NOTE(review): get_annotated_data is defined only by the first
    # BrainCellAtlasAnalyzer definition in this notebook; the later redefinition
    # does not provide it, so this call depends on which class cell ran last.
    adata = analyzer.get_annotated_data(
        dataset='WMB-10Xv2',
        region='CTXsp',
        matrix_type='log2'
    )
    
    # Compare Srrm3 and Srrm4 expression
    genes = ['Srrm3', 'Srrm4']
    
    # Plot expression distributions
    # NOTE(review): plot_gene_expression is not defined by either class version
    # visible here - confirm where it is supposed to come from.
    # The grouping column is used only when adata.obs actually contains it.
    plots = analyzer.plot_gene_expression(
        adata=adata,
        genes=genes,
        groupby='cluster_label' if 'cluster_label' in adata.obs else None
    )
    
    plt.show()

In [None]:
# Load the annotated log2 matrix for CTXsp (WMB-10Xv2).
# NOTE(review): get_annotated_data exists only on the first BrainCellAtlasAnalyzer
# definition in this notebook - confirm `analyzer` was built from that version.
adata = analyzer.get_annotated_data(dataset='WMB-10Xv2', region='CTXsp', matrix_type='log2')

In [None]:




# Analyze Srrm3 and Srrm4
genes = ['Srrm3', 'Srrm4']

# Create visualizations
# NOTE(review): plot_gene_expression and plot_gene_expression_heatmap are not
# defined by either BrainCellAtlasAnalyzer version visible in this notebook -
# confirm where they are supposed to come from.
fig1, axes = analyzer.plot_gene_expression(
    adata=adata,
    genes=genes,
    groupby='class_label'
)

fig2, ax = analyzer.plot_gene_expression_heatmap(
    adata=adata,
    genes=genes,
    groupby='class_label'
)

plt.show()

# Print statistics
# NOTE(review): the visible analyze_gene_expression takes `group_by` and `min_cells`
# (no `groupby` / `n_top_groups`) and returns mean expression per group; the
# 'gene', 'group', 'mean' and 'count' columns read below are not produced by it.
stats = analyzer.analyze_gene_expression(
    adata=adata,
    genes=genes,
    groupby='class_label',
    n_top_groups=5
)

print("\nTop expressing classes for each gene:")
for gene in genes:
    print(f"\n{gene}:")
    # Rows of the statistics table that belong to the current gene
    gene_stats = stats[stats['gene'] == gene]
    print(gene_stats[['group', 'mean', 'count']].head())

In [None]:


# Compare Srrm3 and Srrm4 expression
# NOTE(review): this cell repeats the previous analysis cell, adding n_top_groups=15
# to the two plotting calls - consider keeping only one of them.
genes_of_interest = ['Srrm3', 'Srrm4']

# Create visualizations
print("\nCreating expression plots...")

# Basic expression comparison
# NOTE(review): plot_gene_expression is not defined by either
# BrainCellAtlasAnalyzer version visible in this notebook - confirm its origin.
fig1, axes = analyzer.plot_gene_expression(
    adata=adata,
    genes=genes_of_interest,
    groupby='class_label',  # Compare at class level
    n_top_groups=15
)

# Heatmap comparison
# NOTE(review): plot_gene_expression_heatmap is likewise not defined in the
# visible source (the visible heatmap method is plot_expression_heatmap).
fig2, ax = analyzer.plot_gene_expression_heatmap(
    adata=adata,
    genes=genes_of_interest,
    groupby='class_label',
    n_top_groups=15
)

plt.show()

# Print summary statistics
# NOTE(review): the visible analyze_gene_expression takes `group_by` and `min_cells`
# (no `groupby` / `n_top_groups`) and returns mean expression per group; the
# 'gene', 'group', 'mean' and 'count' columns read below are not produced by it.
stats = analyzer.analyze_gene_expression(
    adata=adata,
    genes=genes_of_interest,
    groupby='class_label',
    n_top_groups=5  # Show top 5 classes
)

print("\nTop expressing classes for each gene:")
for gene in genes_of_interest:
    print(f"\n{gene}:")
    # Rows of the statistics table that belong to the current gene
    gene_stats = stats[stats['gene'] == gene]
    print(gene_stats[['group', 'mean', 'count']].head())