In [101]:
import os
import warnings
from typing import Dict, List, Optional

import gseapy as gp
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from scipy import sparse, stats
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm
# Suppress FutureWarnings
warnings.filterwarnings("ignore", category=FutureWarning)

In [86]:
class CellTypeGSEA:
    """
    Pre-ranked GSEA of each cell type against all other cells in single-cell data.

    For every cell type found in ``adata.obs[cell_type_key]``, genes are ranked by
    ``sign(log2 fold change) * -log10(Mann-Whitney U p-value)`` (cell type vs. rest).
    That ranking is passed to ``gseapy.prerank`` together with the gene sets parsed
    from a GMT file.

    Attributes
    ----------
    adata : AnnData
        The annotated data matrix being analysed.
    cell_type_key : str
        Column of ``adata.obs`` holding the cell type labels.
    gene_sets : dict or None
        Pathway name -> list of gene symbols, or ``None`` if no GMT file was given.
    results : dict
        Cell type -> ``gseapy`` result table (``res2d``), filled by ``run_gsea``.
    """

    def __init__(
        self,
        adata: "sc.AnnData",
        cell_type_key: str = 'cell_type',
        gmt_file: Optional[str] = None
    ):
        """
        Initialize the GSEA analysis object.

        Parameters:
        -----------
        adata : AnnData
            Annotated data matrix with cell type annotations
        cell_type_key : str
            Key in adata.obs containing cell type labels
        gmt_file : str, optional
            Path to the gene set file (GMT format). If omitted, ``gene_sets`` is
            ``None`` and ``run_gsea`` raises until gene sets are assigned.
        """
        self.adata = adata
        self.cell_type_key = cell_type_key
        # Always define gene_sets so that a missing GMT file produces a clear error
        # in run_gsea instead of an AttributeError.
        self.gene_sets = self._load_gmt_file(gmt_file) if gmt_file is not None else None
        self.results = {}

    def _load_gmt_file(self, gmt_file: str) -> Dict[str, List[str]]:
        """
        Load and parse a GMT file into a ``{pathway_name: [genes]}`` dictionary.

        Each line is ``name<TAB>description<TAB>gene1<TAB>gene2...``. The
        description column is ignored. Lines with fewer than three fields are
        skipped, and empty gene fields are dropped.

        Raises
        ------
        ValueError
            If the file contains no valid gene set.
        """
        gene_sets = {}
        with open(gmt_file, 'r') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) < 3:  # Skip malformed / blank lines
                    continue
                pathway_name = parts[0]
                genes = [gene.strip() for gene in parts[2:] if gene.strip()]
                gene_sets[pathway_name] = genes

        print(f"Loaded {len(gene_sets)} gene sets")
        if not gene_sets:
            # Fail early: GSEA without any gene set is meaningless.
            raise ValueError(f"No valid gene sets found in GMT file: {gmt_file}")

        # Print first gene set as example
        first_name, first_genes = next(iter(gene_sets.items()))
        print("\nExample gene set:")
        print(f"Name: {first_name}")
        print(f"Number of genes: {len(first_genes)}")
        print(f"First few genes: {first_genes[:5]}")

        return gene_sets

    def compute_cell_type_rankings(
        self,
        cell_type: str
    ) -> pd.Series:
        """
        Compute differential expression rankings for one cell type vs all others.

        For each gene, a two-sided Mann-Whitney U test compares the cells of
        ``cell_type`` with all remaining cells. The ranking score is
        ``-log10(p) * sign(log2 fold change of means)``.

        Parameters:
        -----------
        cell_type : str
            Cell type to analyze

        Returns:
        --------
        pd.Series
            Ranking scores indexed by gene name (named 'ranking'), sorted in
            descending order.

        Raises:
        -------
        ValueError
            If ``cell_type`` labels no cells, or labels every cell (leaving no
            reference group).
        """
        labels = self.adata.obs[self.cell_type_key]
        in_group = (labels == cell_type).to_numpy()
        n_in_group = int(in_group.sum())
        if n_in_group == 0 or n_in_group == in_group.size:
            raise ValueError(
                f"Cell type {cell_type!r} must label at least one cell and leave at least "
                f"one other cell for comparison (labels {n_in_group} of {in_group.size} cells)."
            )

        n_genes = self.adata.n_vars
        scores = np.zeros(n_genes)
        pvals = np.ones(n_genes)

        X = self.adata.X
        is_sparse = sparse.issparse(X)
        if is_sparse:
            # CSC makes per-column (per-gene) slicing cheap. This avoids materialising
            # the full dense cells x genes matrix on every call.
            X = X.tocsc()

        for i in range(n_genes):
            if is_sparse:
                gene_expr = X[:, i].toarray().ravel()
            else:
                gene_expr = np.asarray(X[:, i]).ravel()
            group_expr = gene_expr[in_group]
            rest_expr = gene_expr[~in_group]

            _, pval = stats.mannwhitneyu(group_expr, rest_expr, alternative='two-sided')

            # Effect direction from the log2 ratio of means. The epsilon avoids
            # division by zero; this assumes non-negative expression values.
            log2fc = np.log2((np.mean(group_expr) + 1e-10) / (np.mean(rest_expr) + 1e-10))

            scores[i] = log2fc
            pvals[i] = pval

        # NaN p-values (degenerate inputs) are treated as "no evidence" (p = 1).
        # Then clamp at the smallest positive float to avoid log10(0).
        pvals = np.nan_to_num(pvals, nan=1.0)
        pvals = np.maximum(pvals, np.finfo(float).tiny)
        ranking_metric = -np.log10(pvals) * np.sign(scores)

        gene_names = [f'gene_{i}' for i in range(n_genes)] if self.adata.var_names.empty else self.adata.var_names
        rankings = pd.Series(
            ranking_metric,
            index=gene_names,
            name='ranking'
        ).sort_values(ascending=False)

        return rankings

    def run_gsea(
        self,
        min_size: int = 15,
        max_size: int = 500,
        permutations: int = 1000,
        threads: int = 4,
        seed: int = 42
    ) -> Dict:
        """
        Run GSEA analysis for all cell types.

        Parameters:
        -----------
        min_size : int
            Minimum gene set size
        max_size : int
            Maximum gene set size
        permutations : int
            Number of permutations
        threads : int
            Number of parallel threads
        seed : int
            Random seed passed to gseapy.prerank (default 42, as before)

        Returns:
        --------
        Dict
            Dictionary mapping each cell type to its GSEA result table (res2d).
            The same dict is stored in ``self.results``.

        Raises:
        -------
        ValueError
            If no gene sets are loaded.
        """
        if not self.gene_sets:
            raise ValueError(
                "No gene sets loaded: pass gmt_file=... to CellTypeGSEA "
                "or assign a {pathway: [genes]} dict to .gene_sets first."
            )

        # dropna: cells without a label must not form a spurious 'nan' group
        cell_types = self.adata.obs[self.cell_type_key].dropna().unique()

        print("Running GSEA analysis for each cell type...")
        for cell_type in tqdm(cell_types):
            rankings = self.compute_cell_type_rankings(cell_type)

            pre_res = gp.prerank(
                rnk=rankings,
                gene_sets=self.gene_sets,
                min_size=min_size,
                max_size=max_size,
                permutation_num=permutations,
                threads=threads,
                seed=seed,
                no_plot=True
            )

            self.results[cell_type] = pre_res.res2d

        return self.results

    def plot_top_pathways(
        self,
        n_pathways: int = 10,
        fdr_cutoff: float = 0.05,
        figsize: tuple = (15, 10)
    ) -> Optional["sns.FacetGrid"]:
        """
        Plot the top positively enriched pathways for each cell type.

        Each facet shows one cell type. It plots up to ``n_pathways`` pathways with
        the highest NES among those with ``FDR q-val < fdr_cutoff``. Bars are
        coloured by FDR q-value, and a shared colorbar is drawn on the right of the
        grid's own figure.

        Parameters:
        -----------
        n_pathways : int
            Number of top pathways to show per cell type
        fdr_cutoff : float
            FDR cutoff for significance (also the upper end of the colorbar)
        figsize : tuple
            Not used. The figure size is determined by the facet grid (each panel
            is 6 in high with aspect 1.5, 3 panels per row). Kept for backward
            compatibility.

        Returns:
        --------
        seaborn.FacetGrid or None
            The grid (its figure is ``g.figure``). Returns None when there are no
            results or no pathway passes ``fdr_cutoff``.
        """
        if not self.results:
            print("No GSEA results to plot - run run_gsea() first.")
            return None

        # Stack all per-cell-type tables and coerce the needed columns to numbers
        combined_results = pd.concat(
            [res.assign(cell_type=cell_type) for cell_type, res in self.results.items()],
            ignore_index=True,
        )
        combined_results['NES'] = pd.to_numeric(combined_results['NES'], errors='coerce')
        combined_results['FDR q-val'] = pd.to_numeric(combined_results['FDR q-val'], errors='coerce')

        sig_pathways = combined_results[combined_results['FDR q-val'] < fdr_cutoff]

        # Top pathways by NES for each cell type
        top_frames = [
            group.nlargest(n_pathways, 'NES')
            for _, group in sig_pathways.groupby('cell_type')
        ]
        top_frames = [frame for frame in top_frames if not frame.empty]
        if not top_frames:
            print(f"No pathways with FDR q-val < {fdr_cutoff}; nothing to plot.")
            return None
        top_pathways = pd.concat(top_frames, ignore_index=True)

        # One colour scale for both the bars and the colorbar
        cmap = plt.get_cmap('RdBu_r')
        norm = plt.Normalize(vmin=top_pathways['FDR q-val'].min(), vmax=fdr_cutoff)

        # sharey=False: each panel has its own pathway list. With a shared
        # categorical y-axis, every panel would display the last panel's labels.
        g = sns.FacetGrid(
            data=top_pathways,
            col='cell_type',
            col_wrap=3,
            height=6,
            aspect=1.5,
            sharey=False
        )

        for cell_type, ax in g.axes_dict.items():
            # Ascending sort puts the largest NES at the top of a horizontal bar chart
            panel = top_pathways[top_pathways['cell_type'] == cell_type].sort_values('NES')
            ax.barh(
                panel['Term'].astype(str).tolist(),
                panel['NES'].to_numpy(),
                color=cmap(norm(panel['FDR q-val'].to_numpy()))
            )
            ax.tick_params(axis='y', pad=15)  # Increase padding
            plt.setp(ax.get_yticklabels(), ha='right')  # Align labels

        g.set_titles()
        g.set_axis_labels('Normalized Enrichment Score', 'Pathway')

        fig = g.figure
        fig.suptitle('Top Enriched Pathways by Cell Type', y=1.02)
        fig.tight_layout()
        # Leave room on the right for the colorbar
        fig.subplots_adjust(right=0.9, wspace=0.4, hspace=0.4)

        # Colorbar goes on the grid's own figure, so it is included when g.figure is saved
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
        cbar = fig.colorbar(sm, cax=cbar_ax)
        cbar.set_label('FDR q-value')

        return g

    


In [87]:
# Usage:

### Paths - edit these for your machine
BASE_DIR = "/Users/guyshani/Documents/PHD/Aim_2/10x_data_mouse/20_1_2025__normalized/"
# File-name prefix of the model run whose generated data is analysed
RUN_PREFIX = "run_20250121_150654_dataset_cell_type_generated"
# Data path (model outputs, including the generated data)
file_path = BASE_DIR + "saved_models/"
# Set directory to save output files
output_dir = file_path + "GSEA_plots/"

# Load the gene symbols of my data: first row of the combined CSV (column 0 is the index)
gene_symbols = pd.read_csv(
    BASE_DIR + "combined_normalized_data.csv",
    nrows=1, sep=";", header=None, index_col=0,
).iloc[0]


def build_labeled_adata(matrix: np.ndarray, cell_types, var_names) -> "sc.AnnData":
    """
    Wrap an expression matrix (one row per cell) in an AnnData object.

    Parameters
    ----------
    matrix : np.ndarray
        Expression matrix; rows are cells, columns are genes.
    cell_types : array-like
        One cell type label per row of ``matrix``; stored in ``obs['cell_type']``.
    var_names : array-like
        Gene names, one per column of ``matrix``.

    Returns
    -------
    AnnData
        The wrapped matrix with gene names and the ``'cell_type'`` obs column.

    Raises
    ------
    ValueError
        If the number of labels differs from the number of rows.
    """
    if matrix.shape[0] != len(cell_types):
        raise ValueError(
            f"{len(cell_types)} cell type labels for {matrix.shape[0]} cells"
        )
    adata = sc.AnnData(matrix)
    adata.var_names = var_names
    adata.obs['cell_type'] = np.asarray(cell_types)
    return adata


### Generated data
with h5py.File(file_path + RUN_PREFIX + "_data.h5", 'r') as f:
    matrix_gen = f['matrix'][:]
# Convert negative values to 0
matrix_gen[matrix_gen < 0] = 0
labels_gen = pd.read_csv(file_path + RUN_PREFIX + "_labels.csv")
# Cell types are assumed to be in the third column of the labels file
adataGen = build_labeled_adata(matrix_gen, labels_gen.iloc[:, 2].to_numpy(), gene_symbols)

### Real data
with h5py.File(BASE_DIR + "combined_normalized_data.h5", 'r') as f:
    matrix_real = f['matrix'][:]
# After dropping 'cluster', cell types are assumed to be in the third column
labels_real = pd.read_csv(BASE_DIR + "combined_metadata.csv", sep=";").drop(columns='cluster')
adataReal = build_labeled_adata(matrix_real, labels_real.iloc[:, 2].to_numpy(), gene_symbols)


In [None]:
# Gene set collection (MSigDB 'mh' file). Alternative collection:
# /Users/guyshani/Documents/PHD/gene_sets/m8.all.v2024.1.Mm.symbols.gmt.txt
GMT_FILE = '/Users/guyshani/Documents/PHD/gene_sets/mh.all.v2024.1.Mm.symbols.gmt.txt'

# savefig / to_csv fail if the output directory does not exist yet
os.makedirs(output_dir, exist_ok=True)

# Initialize GSEA analyzers (one per dataset, same gene sets)
gsea_analyzer_gen = CellTypeGSEA(adataGen, cell_type_key='cell_type', gmt_file=GMT_FILE)
gsea_analyzer_real = CellTypeGSEA(adataReal, cell_type_key='cell_type', gmt_file=GMT_FILE)

# Run GSEA
resultsGen = gsea_analyzer_gen.run_gsea()
resultsReal = gsea_analyzer_real.run_gsea()

# Plot results; only save when a grid was actually returned
gG = gsea_analyzer_gen.plot_top_pathways()
if gG is not None:
    gG.figure.savefig(output_dir + 'gsea_gen_results.pdf', bbox_inches='tight', dpi=300)
gR = gsea_analyzer_real.plot_top_pathways()
if gR is not None:
    gR.figure.savefig(output_dir + 'gsea_real_results.pdf', bbox_inches='tight', dpi=300)

# Save results to CSV. A '/' in a cell type name would otherwise be read as a sub-directory.
for dataset_tag, dataset_results in (('gen', resultsGen), ('real', resultsReal)):
    for cell_type, res in dataset_results.items():
        safe_cell_type = str(cell_type).replace('/', '_')
        res.to_csv(f'{output_dir}gsea_{dataset_tag}_results_{safe_cell_type}.csv')

In [None]:
# Organize GSEA results of several cell types into one comparable long dataframe
def prepare_results_df(results: Dict[str, pd.DataFrame], suffix: str = '') -> pd.DataFrame:
    """
    Stack per-cell-type GSEA result tables into a single dataframe.

    Parameters
    ----------
    results : dict
        Mapping cell type -> GSEA result table (must contain an 'NES' column).
        The input tables are not modified.
    suffix : str
        Currently unused. Kept so existing calls such as
        ``prepare_results_df(resultsGen, '_gen')`` keep working.

    Returns
    -------
    pd.DataFrame
        All rows with an added 'cell_type' column and a fresh 0..n-1 index.
        'NES' is coerced to numeric (unparseable values become NaN). If
        ``results`` is empty, an empty frame with 'Term', 'cell_type' and 'NES'
        columns is returned, so downstream merges still work.
    """
    if not results:
        return pd.DataFrame(columns=['Term', 'cell_type', 'NES'])

    df = pd.concat(
        [res.assign(cell_type=cell_type) for cell_type, res in results.items()],
        ignore_index=True,
    )

    # Ensure NES is numeric (gseapy tables may store it as object dtype)
    df['NES'] = pd.to_numeric(df['NES'], errors='coerce')

    return df

# Stack the per-cell-type tables of each analysis
df_gen = prepare_results_df(resultsGen, '_gen')
df_real = prepare_results_df(resultsReal, '_real')

# Pair generated and real NES for every (pathway, cell type) present in both analyses
merged_df = (
    df_gen[['Term', 'cell_type', 'NES']]
    .rename(columns={'NES': 'NES_gen'})
    .merge(
        df_real[['Term', 'cell_type', 'NES']].rename(columns={'NES': 'NES_real'}),
        how='inner',
        on=['Term', 'cell_type'],
    )
)

In [None]:
# Preview the aligned table (one row per cell type x pathway) instead of dumping it all
merged_df.head(20)

In [None]:


# Per-cell-type comparison of generated vs. real NES:
#  - paired t-test: is NES systematically shifted between generated and real data?
#  - Pearson r: do the two analyses score pathways similarly?
MIN_PAIRS_FOR_TEST = 3  # fewer complete NES pairs than this -> statistics reported as NaN


def _fdr_ignore_nan(pvals) -> np.ndarray:
    """
    Benjamini-Hochberg adjust p-values, skipping NaNs.

    Finite p-values are corrected together with ``fdrcorrection``. NaN inputs
    stay NaN and do not count toward the number of tests.
    """
    pvals = np.asarray(pvals, dtype=float)
    adjusted = np.full(pvals.shape, np.nan)
    finite = np.isfinite(pvals)
    if finite.any():
        _, adjusted[finite] = fdrcorrection(pvals[finite])
    return adjusted


statistical_results = []
for cell_type, cell_data in merged_df.groupby('cell_type', sort=False):
    # NaN NES values would propagate into the test statistics, so use complete pairs only
    paired = cell_data[['NES_gen', 'NES_real']].dropna()
    if len(paired) >= MIN_PAIRS_FOR_TEST:
        t_stat, p_val = stats.ttest_rel(paired['NES_gen'], paired['NES_real'])
        corr, corr_p = stats.pearsonr(paired['NES_gen'], paired['NES_real'])
    else:
        t_stat = p_val = corr = corr_p = np.nan

    statistical_results.append({
        'cell_type': cell_type,
        'n_pathways': len(cell_data),
        'n_pairs_tested': len(paired),
        't_statistic': t_stat,
        'p_value': p_val,
        'correlation': corr,
        'correlation_p': corr_p,
    })

# Explicit columns keep the frame well-formed even if merged_df is empty
results_df = pd.DataFrame(
    statistical_results,
    columns=['cell_type', 'n_pathways', 'n_pairs_tested', 't_statistic',
             'p_value', 'correlation', 'correlation_p'],
)

# Multiple-testing correction across cell types (NaN p-values excluded)
results_df['p_value_adj'] = _fdr_ignore_nan(results_df['p_value'])
results_df['correlation_p_adj'] = _fdr_ignore_nan(results_df['correlation_p'])

# Sort by p-value (NaN last)
results_df = results_df.sort_values('p_value')

print("Statistical comparison results:")
print(results_df.to_string(float_format='{:.2e}'.format))



# One panel per cell type: generated NES (x) vs. real NES (y), with a fitted
# regression line (red) and the y = x reference line (dashed grey).
plot_cell_types = list(merged_df['cell_type'].unique())
n_cell_types = len(plot_cell_types)
n_cols = 3
n_rows = (n_cell_types + n_cols - 1) // n_cols
# Look up correlations by cell type; a missing entry is shown as 'nan' instead of raising
corr_by_cell_type = dict(zip(results_df['cell_type'], results_df['correlation']))

if n_cell_types == 0:
    print("merged_df is empty - no cell types to plot.")
else:
    # squeeze=False guarantees a 2-D array of axes regardless of the grid shape
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
    axes = axes.flatten()

    for ax, cell_type in zip(axes, plot_cell_types):
        cell_data = merged_df[merged_df['cell_type'] == cell_type]

        sns.scatterplot(data=cell_data, x='NES_gen', y='NES_real', ax=ax)

        # A regression line needs a few complete (non-NaN) points to be meaningful
        if len(cell_data[['NES_gen', 'NES_real']].dropna()) >= 3:
            sns.regplot(
                data=cell_data,
                x='NES_gen',
                y='NES_real',
                scatter=False,
                ax=ax,
                color='red'
            )

        ax.set_title(f'{cell_type}\nr={corr_by_cell_type.get(cell_type, np.nan):.2f}')
        ax.set_xlabel('Generated NES')
        ax.set_ylabel('Real NES')

        # Equal x/y limits so the dashed y = x line is a true diagonal
        lims = [
            min(ax.get_xlim()[0], ax.get_ylim()[0]),
            max(ax.get_xlim()[1], ax.get_ylim()[1])
        ]
        ax.plot(lims, lims, '--', color='gray', alpha=0.5)
        ax.set_xlim(lims)
        ax.set_ylim(lims)

    # Remove empty subplots
    for ax in axes[n_cell_types:]:
        fig.delaxes(ax)

    plt.tight_layout()
    plt.show()