In [1]:
%pip install -q gseapy==1.1.10

Collecting gseapy
  Downloading gseapy-1.1.10-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)
Downloading gseapy-1.1.10-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (600 kB)
[2K   [38;2;114;156;31m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m600.9/600.9 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m0m eta [36m-:--:--[0m
[?25hInstalling collected packages: gseapy
Successfully installed gseapy-1.1.10


In [3]:
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import gseapy as gp
import matplotlib.pyplot as plt
import seaborn as sns

# Silence only the noisy deprecation chatter emitted by pandas/gseapy/seaborn.
# A blanket `filterwarnings('ignore')` would also hide genuinely useful
# warnings (e.g. anndata read problems or pandas chained-assignment warnings).
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)

# ============================================================================
# CONFIGURATION
# ============================================================================

# Output directory (created if missing; re-running the notebook is safe)
OUTPUT_DIR = Path('GO_enrichment_results')
OUTPUT_DIR.mkdir(exist_ok=True)

# Enrichr libraries to query: GO Biological Process terms and KEGG pathways
GENE_SETS = [
    'GO_Biological_Process_2023',
    'KEGG_2021_Human'
]

# Organism passed to gseapy.enrichr
ORGANISM = 'Human'

# Adjusted p-value threshold used to call a term significant
PADJ_THRESHOLD = 0.05

print("="*70)
print("GO ENRICHMENT ANALYSIS - scVI DEGs (Upregulated Genes)")
print("="*70)

# ============================================================================
#  LOADING CELL TYPES FROM ADATA
# ============================================================================

print("\n[1/5] Loading cell types from AnnData object...")

import scanpy as sc
adata = sc.read('annotated.h5ad')

# Unique annotations, in order of first appearance. Missing labels (NaN) are
# dropped because they are not a real cell type. Labels are cast to str so the
# DEG-loading step can build filenames with `str.replace` even when the
# annotation column is numeric or categorical.
cell_types = [str(ct) for ct in adata.obs['Cell_Type'].dropna().unique()]

print(f"Found {len(cell_types)} cell types:")
for i, ct in enumerate(cell_types, 1):
    print(f"  {i}. {ct}")

# ============================================================================
#  LOADING DEG FILES AND EXTRACT UPREGULATED GENES
# ============================================================================

print("\n[2/5] Loading DEG files and extracting upregulated genes...")

upregulated_genes = {}
deg_stats = []


def _deg_record(cell_type, filename, status, total=0, up=0, down=0):
    """Build one row of the per-cell-type DEG statistics table."""
    return {
        'Cell_Type': cell_type,
        'Total_DEGs': total,
        'Upregulated': up,
        'Downregulated': down,
        'File': filename,
        'Status': status,
    }


for cell_type in cell_types:
    # DEG tables are named after the cell type with spaces replaced by underscores
    filename = f"DEG_{cell_type.replace(' ', '_')}_tumor_vs_normal.csv"

    try:
        df = pd.read_csv(filename)

        # Gene names are stored in the unnamed first column (the saved index)
        df['gene'] = df['Unnamed: 0']

        # Upregulated = positive mean log-fold change. Missing gene names are
        # dropped and duplicates collapsed so Enrichr receives a clean gene set.
        upregulated = (
            df.loc[df['lfc_mean'] > 0, 'gene']
            .dropna()
            .astype(str)
            .drop_duplicates()
            .tolist()
        )
        upregulated_genes[cell_type] = upregulated

        total_degs = len(df)
        n_upregulated = len(upregulated)
        # Count downregulated genes explicitly: `total - up` would also count
        # rows with lfc_mean == 0 or NaN as downregulated.
        n_downregulated = int((df['lfc_mean'] < 0).sum())

        deg_stats.append(_deg_record(cell_type, filename, 'Loaded',
                                     total_degs, n_upregulated, n_downregulated))

        print(f"  ‚úì {cell_type}: {n_upregulated} upregulated / {total_degs} total DEGs")

    except FileNotFoundError:
        print(f"  ‚úó {cell_type}: File not found - {filename}")
        upregulated_genes[cell_type] = []
        deg_stats.append(_deg_record(cell_type, filename, 'Not Found'))
    except Exception as e:
        print(f"  ‚úó {cell_type}: Error - {str(e)}")
        upregulated_genes[cell_type] = []
        deg_stats.append(_deg_record(cell_type, filename, f'Error: {str(e)}'))

# Save DEG statistics
deg_stats_df = pd.DataFrame(deg_stats)
deg_stats_df.to_csv(OUTPUT_DIR / 'DEG_statistics_summary.csv', index=False)
print(f"\n  ‚Üí Saved: {OUTPUT_DIR / 'DEG_statistics_summary.csv'}")

# ============================================================================
#  RUNNING GO ENRICHMENT FOR EACH CELL TYPE
# ============================================================================

print("\n[3/5] Running GO enrichment analysis...")

# Over-representation tests on tiny gene lists are not informative; cell types
# with fewer upregulated genes than this are skipped.
MIN_GENES_FOR_ENRICHMENT = 3

enrichment_results = {}
successful_enrichments = 0
failed_enrichments = 0

for cell_type in cell_types:
    gene_list = upregulated_genes[cell_type]

    if not gene_list:
        print(f"  ‚äò {cell_type}: Skipped (no upregulated genes)")
        continue

    if len(gene_list) < MIN_GENES_FOR_ENRICHMENT:
        print(f"  ‚äò {cell_type}: Skipped (only {len(gene_list)} genes, "
              f"minimum {MIN_GENES_FOR_ENRICHMENT} required)")
        continue

    print(f"  ‚ü≥ {cell_type}: Analyzing {len(gene_list)} genes...")

    try:
        # Per-cell-type output directory
        safe_name = cell_type.replace(' ', '_')
        cell_type_dir = OUTPUT_DIR / safe_name
        cell_type_dir.mkdir(exist_ok=True)

        # `cutoff` only controls which terms gseapy would plot (plots are off
        # here), but tie it to PADJ_THRESHOLD so the thresholds cannot drift apart.
        enr = gp.enrichr(
            gene_list=gene_list,
            gene_sets=GENE_SETS,
            organism=ORGANISM,
            outdir=str(cell_type_dir),
            cutoff=PADJ_THRESHOLD,
            no_plot=True  # We'll create custom plots
        )
        results = enr.results

        # A missing or column-less table would break every downstream
        # `results['Adjusted P-value']` lookup; report it as a clear failure.
        if results is None or 'Adjusted P-value' not in results.columns:
            raise ValueError("Enrichr returned no result table")

        enrichment_results[cell_type] = results
        results.to_csv(cell_type_dir / f'{safe_name}_GO_enrichment.csv', index=False)

        sig_terms = int((results['Adjusted P-value'] < PADJ_THRESHOLD).sum())
        print(f"    ‚úì Found {sig_terms} significant terms (padj < {PADJ_THRESHOLD})")
        successful_enrichments += 1

    except Exception as e:
        print(f"    ‚úó Error: {str(e)}")
        failed_enrichments += 1
        enrichment_results[cell_type] = None

print(f"\n  Summary: {successful_enrichments} successful, {failed_enrichments} failed")

# ============================================================================
#  CREATING VISUALIZATIONS
# ============================================================================

print("\n[4/5] Creating visualizations...")

# Create a figure directory
fig_dir = OUTPUT_DIR / 'figures'
fig_dir.mkdir(exist_ok=True)

# --- Individual dotplots for each cell type ---
print("  ‚Ä¢ Generating individual dotplots...")

TOP_N_DOTPLOT = 15       # terms shown per cell type
DOT_SIZE_PER_GENE = 10   # marker area per overlapping gene

for cell_type, results in enrichment_results.items():
    if results is None or len(results) == 0:
        continue

    # Keep significant terms only
    sig_results = results[results['Adjusted P-value'] < PADJ_THRESHOLD].copy()

    if len(sig_results) == 0:
        print(f"    ‚äò {cell_type}: No significant terms to plot")
        continue

    # Most significant terms first
    sig_results = sig_results.sort_values('Adjusted P-value').head(TOP_N_DOTPLOT)

    # Adjusted p-values can underflow to 0; clip so -log10 stays finite and the
    # colour scale is not blown out by inf.
    padj = sig_results['Adjusted P-value'].clip(lower=np.finfo(float).tiny)
    # 'Overlap' is "<hits>/<set size>"; the hit count drives the marker size.
    n_overlap = sig_results['Overlap'].str.split('/').str[0].astype(int)

    fig, ax = plt.subplots(figsize=(10, 8))

    scatter = ax.scatter(
        sig_results['Combined Score'],
        range(len(sig_results)),
        s=n_overlap * DOT_SIZE_PER_GENE,
        c=-np.log10(padj),
        cmap='Reds',
        alpha=0.7,
        edgecolors='black',
        linewidth=0.5
    )

    ax.set_yticks(range(len(sig_results)))
    ax.set_yticklabels(sig_results['Term'], fontsize=9)
    # Row 0 is the most significant term; invert so it is drawn at the top.
    ax.invert_yaxis()
    ax.set_xlabel('Combined Score', fontsize=11)
    ax.set_title(f'GO Enrichment - {cell_type}\n(Upregulated genes, n={len(upregulated_genes[cell_type])})', 
                 fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3, axis='x')

    # Size legend: map marker area back to the number of overlapping genes
    size_handles, size_labels = scatter.legend_elements(
        prop='sizes', num=4, func=lambda s: s / DOT_SIZE_PER_GENE, alpha=0.7
    )
    ax.legend(size_handles, size_labels, title='Overlap genes', loc='best', fontsize=8)

    cbar = fig.colorbar(scatter, ax=ax)
    cbar.set_label('-log10(Adjusted P-value)', fontsize=10)

    fig.tight_layout()
    fig.savefig(fig_dir / f'{cell_type.replace(" ", "_")}_dotplot.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"    ‚úì {cell_type}")

# --- Summary barplot: Number of enriched terms per cell type ---
print("  ‚Ä¢ Generating summary plots...")

# Significant-term count for every cell type whose enrichment succeeded
term_counts = [
    {'Cell_Type': cell_type,
     'Significant_Terms': int((results['Adjusted P-value'] < PADJ_THRESHOLD).sum())}
    for cell_type, results in enrichment_results.items()
    if results is not None
]

if term_counts:
    term_counts_df = pd.DataFrame(term_counts).sort_values('Significant_Terms', ascending=False)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.barh(term_counts_df['Cell_Type'], term_counts_df['Significant_Terms'],
            color='steelblue', alpha=0.7)
    # barh draws the first row at the bottom; invert so the largest count is on top.
    ax.invert_yaxis()
    # Counts include every queried library (GO and KEGG), not just GO.
    ax.set_xlabel(f'Number of Significant Terms (padj < {PADJ_THRESHOLD})', fontsize=11)
    ax.set_title('Enriched GO/KEGG Terms per Cell Type\n(Upregulated genes)',
                 fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3, axis='x')
    fig.tight_layout()
    fig.savefig(fig_dir / 'summary_enriched_terms_per_celltype.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"    ‚úì Summary barplot saved")

# --- Top enriched terms across all cell types ---
print("  ‚Ä¢ Generating top terms heatmap...")

TOP_N_HEATMAP = 20                # number of terms (rows) in the heatmap
MIN_PADJ = np.finfo(float).tiny   # floor so -log10(padj) stays finite


def _neglog10_padj(results, terms):
    """Return -log10(adjusted p-value) for each of `terms` in one result table.

    Terms missing from the table (or a failed enrichment, `results is None`)
    get 0. When a term appears more than once, its first row is used.
    """
    if results is None:
        return pd.Series(0.0, index=terms)
    padj = (results.drop_duplicates('Term', keep='first')
            .set_index('Term')['Adjusted P-value']
            .reindex(terms))
    return (-np.log10(padj.clip(lower=MIN_PADJ))).fillna(0.0)


# Significant terms from every successful enrichment
all_terms = []
for cell_type, results in enrichment_results.items():
    if results is not None:
        sig = results[results['Adjusted P-value'] < PADJ_THRESHOLD].copy()
        sig['Cell_Type'] = cell_type
        all_terms.append(sig[['Term', 'Cell_Type', 'Adjusted P-value', 'Gene_set']])

all_terms_df = (pd.concat(all_terms, ignore_index=True) if all_terms
                else pd.DataFrame(columns=['Term', 'Cell_Type', 'Adjusted P-value', 'Gene_set']))

# Terms significant in the largest number of cell types
top_terms = all_terms_df['Term'].value_counts().head(TOP_N_HEATMAP).index

if len(top_terms) == 0:
    # Without this guard an empty term list crashes the matrix construction.
    print("    ‚äò No significant terms in any cell type; heatmap skipped")
else:
    # Matrix: rows = terms, columns = cell types, values = -log10(padj)
    heatmap_df = pd.DataFrame(
        {ct: _neglog10_padj(res, top_terms) for ct, res in enrichment_results.items()},
        index=top_terms,
    )
    heatmap_df.index.name = 'Term'

    fig, ax = plt.subplots(figsize=(12, 10))
    sns.heatmap(heatmap_df, cmap='YlOrRd', cbar_kws={'label': '-log10(Adjusted P-value)'}, 
                linewidths=0.5, ax=ax)
    ax.set_title(f'Top {len(top_terms)} Enriched Terms Across Cell Types', fontsize=12, fontweight='bold')
    ax.set_xlabel('Cell Type', fontsize=11)
    ax.set_ylabel('Term', fontsize=11)
    ax.tick_params(axis='x', labelsize=9, labelrotation=45)
    for label in ax.get_xticklabels():
        label.set_horizontalalignment('right')
    ax.tick_params(axis='y', labelsize=8)
    fig.tight_layout()
    fig.savefig(fig_dir / 'top_terms_heatmap.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"    ‚úì Heatmap saved")

print(f"\n  ‚Üí All figures saved to: {fig_dir}")

# ============================================================================
#  CREATING SUMMARY TABLES
# ============================================================================

print("\n[5/5] Creating summary tables...")

TOP_N_TABLE = 10  # significant terms kept per cell type

# Source column -> output column for the top-terms table (order preserved)
TOP_TERM_COLUMNS = {
    'Cell_Type': 'Cell_Type',
    'Term': 'Term',
    'Gene_set': 'Gene_set',
    'Adjusted P-value': 'Adjusted_P_value',
    'Combined Score': 'Combined_Score',
    'Genes': 'Genes',
    'Overlap': 'Overlap',
}

# --- Top terms per cell type ---
top_term_frames = []
for cell_type, results in enrichment_results.items():
    if results is None or len(results) == 0:
        continue
    top = (results[results['Adjusted P-value'] < PADJ_THRESHOLD]
           .sort_values('Adjusted P-value')
           .head(TOP_N_TABLE)
           .assign(Cell_Type=cell_type))
    top_term_frames.append(top[list(TOP_TERM_COLUMNS)].rename(columns=TOP_TERM_COLUMNS))

top_terms_df = (pd.concat(top_term_frames, ignore_index=True) if top_term_frames
                else pd.DataFrame(columns=list(TOP_TERM_COLUMNS.values())))

if not top_terms_df.empty:
    top_terms_df.to_csv(OUTPUT_DIR / 'top_enriched_terms_all_celltypes.csv', index=False)
    print(f"  ‚úì Top terms summary saved")

# --- Overall summary ---
summary_data = []
for cell_type in cell_types:
    row = {
        'Cell_Type': cell_type,
        'Upregulated_Genes': len(upregulated_genes[cell_type]),
    }

    results = enrichment_results.get(cell_type)
    if results is not None:
        is_sig = results['Adjusted P-value'] < PADJ_THRESHOLD
        # Classify by library-name prefix rather than a substring match, so a
        # library that merely contains "GO" somewhere is not miscounted.
        is_go = results['Gene_set'].str.startswith('GO_', na=False)
        is_kegg = results['Gene_set'].str.startswith('KEGG_', na=False)
        row.update({
            'Total_Terms_Found': len(results),
            'Significant_Terms': int(is_sig.sum()),
            'GO_Terms': int(is_go.sum()),
            'GO_Significant': int((is_sig & is_go).sum()),
            'KEGG_Terms': int(is_kegg.sum()),
            'KEGG_Significant': int((is_sig & is_kegg).sum()),
        })
    else:
        # Skipped or failed enrichment: report zero terms everywhere
        row.update(dict.fromkeys(
            ['Total_Terms_Found', 'Significant_Terms', 'GO_Terms',
             'GO_Significant', 'KEGG_Terms', 'KEGG_Significant'], 0))

    summary_data.append(row)

summary_df = pd.DataFrame(summary_data).sort_values('Significant_Terms', ascending=False)
summary_df.to_csv(OUTPUT_DIR / 'enrichment_summary.csv', index=False)
print(f"  ‚úì Overall summary saved")

print(f"\n  ‚Üí All summary tables saved to: {OUTPUT_DIR}")

# ============================================================================
# FINAL SUMMARY
# ============================================================================

# Only cell types whose enrichment actually completed count as "analyzed";
# skipped (too few genes) and failed cell types are excluded.
analyzed_cell_types = [ct for ct, res in enrichment_results.items() if res is not None]
n_genes_total = sum(len(g) for g in upregulated_genes.values())
n_genes_analyzed = sum(len(upregulated_genes[ct]) for ct in analyzed_cell_types)

print("\n" + "="*70)
print("ANALYSIS COMPLETE!")
print("="*70)
print(f"\nüìä Results Summary:")
print(f"  ‚Ä¢ Cell types found: {len(cell_types)}")
print(f"  ‚Ä¢ Successful enrichments: {successful_enrichments}")
print(f"  ‚Ä¢ Failed enrichments: {failed_enrichments}")
print(f"  ‚Ä¢ Total upregulated genes (all cell types): {n_genes_total}")
print(f"  ‚Ä¢ Upregulated genes analyzed: {n_genes_analyzed}")
print(f"\nüìÅ Output files:")
print(f"  ‚Ä¢ Main directory: {OUTPUT_DIR}")
print(f"  ‚Ä¢ Figures: {fig_dir}")
print(f"  ‚Ä¢ Individual cell type results: {OUTPUT_DIR}/[cell_type]/")
print("\n‚úÖ All done! Check the output directory for results.")
print("="*70)

GO ENRICHMENT ANALYSIS - scVI DEGs (Upregulated Genes)

[1/5] Loading cell types from AnnData object...
Found 21 cell types:
  1. CD4+ T Cells
  2. CMS3
  3. Tip-like ECs
  4. CD8+ T cells
  5. B Cells
  6. Spp1+
  7. Mast cells
  8. Stromal 2
  9. CMS2
  10. Regulatory T Cells
  11. Pericytes
  12. Dendritic cells
  13. Gamma delta T cells
  14. Helper 17 T cells
  15. Mature Enterocytes type 2
  16. NK cells
  17. Plasma Cells
  18. Stromal 3
  19. Plasmacytoid Dendritic Cells
  20. Follicular helper T cells
  21. Enteric glia cells

[2/5] Loading DEG files and extracting upregulated genes...
  ✓ CD4+ T Cells: 0 upregulated / 0 total DEGs
  ✓ CMS3: 10 upregulated / 19 total DEGs
  ✓ Tip-like ECs: 4 upregulated / 29 total DEGs
  ✓ CD8+ T cells: 0 upregulated / 0 total DEGs
  ✓ B Cells: 0 upregulated / 1 total DEGs
  ✓ Spp1+: 22 upregulated / 146 total DEGs
  ✓ Mast cells: 24 upregulated / 117 total DEGs
  ✓ Stromal 2: 8 upregulated / 31 total DEGs
  ‚úì CMS2: 2 upregul