# Multi-Species COG Functional Category Analysis

This notebook analyzes COG functional category distributions across core, auxiliary, and singleton genes for multiple species to identify conserved patterns.

## Goals
1. Run COG analysis on 32 taxonomically diverse species
2. Identify universal patterns vs species-specific differences
3. Test if novel genes are consistently enriched in mobile elements (L) and surface variation (M)
4. Test if core genes are consistently enriched in housekeeping functions (J, C, H)
5. Examine phylum-level patterns

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from get_spark_session import get_spark_session
from tqdm import tqdm
import warnings
# Suppress every Python warning for the rest of the session. Note that this
# also hides deprecation notices emitted by the libraries used below.
warnings.filterwarnings('ignore')

# Configure plotting: seaborn "whitegrid" style and a default figure size of
# (14, 8) for figures that do not pass their own figsize.
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (14, 8)

In [None]:
# Initialize Spark session
# get_spark_session is the project-local helper imported at the top of the
# notebook; the object it returns is used later through spark.sql(...).
spark = get_spark_session()
print(f"Spark version: {spark.version}")

In [None]:
# COG category descriptions
# Maps a COG category code to a short human-readable label. The labels are
# used when printing summary tables and when building plot tick labels.
COG_DESCRIPTIONS = {
    'J': 'Translation, ribosomal structure',
    'A': 'RNA processing and modification',
    'K': 'Transcription',
    'L': 'Replication, recombination, repair',
    'B': 'Chromatin structure',
    'D': 'Cell cycle control, division',
    'Y': 'Nuclear structure',
    'V': 'Defense mechanisms',
    'T': 'Signal transduction',
    'M': 'Cell wall/membrane biogenesis',
    'N': 'Cell motility',
    'Z': 'Cytoskeleton',
    'W': 'Extracellular structures',
    'U': 'Intracellular trafficking',
    'O': 'Posttranslational modification, chaperones',
    'C': 'Energy production and conversion',
    'G': 'Carbohydrate transport and metabolism',
    'E': 'Amino acid transport and metabolism',
    'F': 'Nucleotide transport and metabolism',
    'H': 'Coenzyme transport and metabolism',
    'I': 'Lipid transport and metabolism',
    'P': 'Inorganic ion transport',
    'Q': 'Secondary metabolites biosynthesis',
    'R': 'General function prediction only',
    'S': 'Function unknown',
    'NU': 'Motility and trafficking',  # Composite (two-letter) category: N + U
}

# Expected enrichments based on N. gonorrhoeae results
# (each entry is a key of COG_DESCRIPTIONS)
# Enriched: L replication/recombination/repair (mobile elements), M cell
# wall/membrane (surface), NU and U motility/trafficking, E amino acid metabolism
EXPECTED_NOVEL_ENRICHED = ['L', 'M', 'NU', 'U', 'E']
# Depleted: J translation, C energy, H coenzyme, S function unknown, D cell cycle
EXPECTED_NOVEL_DEPLETED = ['J', 'C', 'H', 'S', 'D']

## Step 1: Load sampled species

In [None]:
# Load the stratified sample of species
# The CSV must provide at least the columns read in this notebook:
# 'phylum', 'no_genomes', 'gtdb_species_clade_id' and 'GTDB_species'.
sampled_species = pd.read_csv('../data/sampled_species_for_cog_analysis.csv')

# Quick sanity report: sample size, species per phylum, and the range of
# genome counts per species.
print(f"Loaded {len(sampled_species)} species for analysis")
print(f"\nPhylum distribution:")
print(sampled_species['phylum'].value_counts())
print(f"\nGenome count range: {sampled_species['no_genomes'].min()}-{sampled_species['no_genomes'].max()}")

# Last expression of the cell: the notebook renders the first 10 rows.
sampled_species.head(10)

## Step 2: Define query function for COG distributions

In [None]:
def get_cog_distribution(species_id, gene_class='core'):
    """
    Count annotated genes per COG category for one gene class of one species.

    Parameters:
    - species_id: GTDB species clade ID, compared with
      gene_cluster.gtdb_species_clade_id in the query
    - gene_class: 'core', 'auxiliary', or 'singleton'

    Returns:
    - DataFrame with COG_category and gene_count, ordered by descending
      gene_count. If running the query raises, the error is printed and an
      empty DataFrame with the same two columns is returned.

    Raises:
    - ValueError: if gene_class is not one of the three supported values
    """
    # SQL predicate on the gene_cluster flags for each supported gene class.
    # 'auxiliary' explicitly excludes clusters flagged as singletons.
    class_filters = {
        'core': "gc.is_core = 1",
        'singleton': "gc.is_singleton = 1",
        'auxiliary': "gc.is_auxiliary = 1 AND gc.is_singleton = 0",
    }
    if gene_class not in class_filters:
        raise ValueError(f"Unknown gene_class: {gene_class}")
    class_filter = class_filters[gene_class]

    # Rows with a NULL or '-' COG_category are left out of the counts.
    query = f"""
    SELECT 
        ann.COG_category,
        COUNT(*) as gene_count
    FROM kbase_ke_pangenome.gene_cluster gc
    JOIN kbase_ke_pangenome.gene_genecluster_junction j 
        ON gc.gene_cluster_id = j.gene_cluster_id
    JOIN kbase_ke_pangenome.eggnog_mapper_annotations ann 
        ON j.gene_id = ann.query_name
    WHERE gc.gtdb_species_clade_id = '{species_id}'
        AND {class_filter}
        AND ann.COG_category IS NOT NULL
        AND ann.COG_category != '-'
    GROUP BY ann.COG_category
    ORDER BY gene_count DESC
    """

    try:
        counts = spark.sql(query).toPandas()
        # Force a numeric dtype; unparseable values become NaN
        counts['gene_count'] = pd.to_numeric(counts['gene_count'], errors='coerce')
    except Exception as e:
        # Best effort: report the failure and let the caller carry on
        print(f"Error querying {gene_class} for {species_id}: {e}")
        return pd.DataFrame(columns=['COG_category', 'gene_count'])
    return counts


def analyze_species_cog(species_row):
    """
    Build the per-gene-class COG category table for a single species.

    Parameters:
    - species_row: one row of the sampled-species table; the fields
      'gtdb_species_clade_id', 'GTDB_species', 'phylum' and 'no_genomes'
      are read from it

    Returns:
    - DataFrame with one row per (gene class, COG category) holding
      gene_count, the proportion within that gene class, and the species
      metadata columns; None when none of the three queries returned rows
    """
    species_id = species_row['gtdb_species_clade_id']

    # One query per gene class, each tagged with the label used downstream
    class_labels = (
        ('core', 'Core'),
        ('auxiliary', 'Auxiliary'),
        ('singleton', 'Singleton/Novel'),
    )
    frames = []
    for gene_class, label in class_labels:
        class_df = get_cog_distribution(species_id, gene_class)
        class_df['gene_class'] = label
        frames.append(class_df)

    combined = pd.concat(frames, ignore_index=True)
    if not len(combined):
        return None

    # Proportion = category count divided by the total count of its gene class
    class_totals = combined.groupby('gene_class')['gene_count'].sum()
    row_totals = combined['gene_class'].map(class_totals)
    combined['proportion'] = (combined['gene_count'] / row_totals).astype(float)

    # Species metadata, repeated on every row
    combined['species_id'] = species_id
    combined['species_name'] = species_row['GTDB_species']
    combined['phylum'] = species_row['phylum']
    combined['no_genomes'] = species_row['no_genomes']

    return combined


# Test on one species
# Smoke test: run the full per-species analysis on the first sampled species
# before launching the long loop over all of them.
print("Testing query function on first species...")
test_result = analyze_species_cog(sampled_species.iloc[0])
# analyze_species_cog returns None when no rows came back for any gene class
if test_result is not None:
    print(f"Success! Retrieved {len(test_result)} COG category records")
    print(test_result.head())
else:
    print("Warning: Test query returned no results")

## Step 3: Run analysis on all species

**Note**: This may take 10-30 minutes depending on database load. One status line is printed per species as it is processed.

In [None]:
# Run analysis on all species
# all_results collects one DataFrame per species that returned data;
# failed_species collects the names of species that returned nothing.
all_results = []
failed_species = []

n_species = len(sampled_species)
print(f"Analyzing {n_species} species...\n")

# Number the progress lines with enumerate instead of the DataFrame index:
# index labels are only 0..n-1 while the table is unfiltered and unsorted.
for position, (_, row) in enumerate(sampled_species.iterrows(), start=1):
    species_name = row['GTDB_species']
    print(f"[{position}/{n_species}] Processing {species_name}...", end=' ')

    result = analyze_species_cog(row)

    if result is not None and len(result) > 0:
        all_results.append(result)
        print(f"✓ ({len(result)} records)")
    else:
        # Covers both "no annotated genes" and queries that errored out
        failed_species.append(species_name)
        print("✗ (no data)")

print(f"\n{'='*80}")
print(f"Analysis complete!")
print(f"  Successful: {len(all_results)} species")
print(f"  Failed: {len(failed_species)} species")
if failed_species:
    # Only the first five names are listed to keep the line short
    print(f"  Failed species: {', '.join(failed_species[:5])}{'...' if len(failed_species) > 5 else ''}")

In [None]:
# Combine all results
# Stack the per-species tables into one long table, persist it, and print a
# short summary. all_cog_data is only defined when at least one species
# returned data.
if all_results:
    all_cog_data = pd.concat(all_results, ignore_index=True)

    # Save raw results
    all_cog_data.to_csv('../data/multi_species_cog_results.csv', index=False)
    print(f"Saved {len(all_cog_data)} records to ../data/multi_species_cog_results.csv")

    print(f"\nData summary:")
    print(f"  Total records: {len(all_cog_data)}")
    print(f"  Species: {all_cog_data['species_name'].nunique()}")
    print(f"  COG categories: {all_cog_data['COG_category'].nunique()}")
    print(f"  Total genes analyzed: {all_cog_data['gene_count'].sum():,.0f}")

    # A bare expression inside an `if` block is not rendered by the notebook
    # (only a cell's final top-level expression is), so print the preview.
    print(all_cog_data.head(20))
else:
    print("ERROR: No results collected!")

## Step 4: Calculate enrichment scores across all species

In [None]:
# Per species and COG category: enrichment = novel proportion - core proportion
enrichment_data = []

for species_name in all_cog_data['species_name'].unique():
    species_data = all_cog_data[all_cog_data['species_name'] == species_name]

    # Rows: COG categories, columns: gene classes, cells: proportions.
    # A category absent from a gene class counts as proportion 0.
    pivot = species_data.pivot_table(
        index='COG_category',
        columns='gene_class',
        values='proportion',
        fill_value=0
    )

    # Both classes are required for the comparison; skip the species otherwise
    if not ('Core' in pivot.columns and 'Singleton/Novel' in pivot.columns):
        continue

    enrichment = pivot['Singleton/Novel'] - pivot['Core']
    enrichment_data.extend(
        {
            'species_name': species_name,
            'phylum': species_data.iloc[0]['phylum'],
            'COG_category': cog_cat,
            'enrichment': enrichment[cog_cat],
            'core_prop': pivot.loc[cog_cat, 'Core'],
            'novel_prop': pivot.loc[cog_cat, 'Singleton/Novel'],
        }
        for cog_cat in enrichment.index
    )

enrichment_df = pd.DataFrame(enrichment_data)

print(f"Calculated enrichment scores for {len(enrichment_df)} species × COG combinations")
enrichment_df.head(20)

## Step 5: Identify conserved patterns

Which COG categories are consistently enriched/depleted in novel genes across species?

In [None]:
# Aggregate enrichment across species
# One row per COG category: mean / std / median of the per-species enrichment
# and the number of species in which the category was observed.
cog_summary = enrichment_df.groupby('COG_category').agg({
    'enrichment': ['mean', 'std', 'median'],
    'species_name': 'count'
}).round(4)

cog_summary.columns = ['mean_enrichment', 'std_enrichment', 'median_enrichment', 'n_species']
cog_summary = cog_summary.reset_index()
# Descending order: most enriched in novel genes first, most depleted last
cog_summary = cog_summary.sort_values('mean_enrichment', ascending=False)


def _describe_cog_category(cog_cat):
    """Return a text description for a COG category code.

    Codes listed in COG_DESCRIPTIONS use their own entry. Any other code is
    described letter by letter (known letters joined with ' / '), and
    'Unknown' is returned when no letter is recognised. This keeps the
    description column free of NaN, so it can be sliced as a string later.
    """
    if cog_cat in COG_DESCRIPTIONS:
        return COG_DESCRIPTIONS[cog_cat]
    parts = [
        COG_DESCRIPTIONS[letter]
        for letter in str(cog_cat)
        if letter in COG_DESCRIPTIONS
    ]
    return ' / '.join(parts) if parts else 'Unknown'


# Add descriptions (always a string, including for multi-letter categories
# that have no entry of their own in COG_DESCRIPTIONS)
cog_summary['description'] = cog_summary['COG_category'].map(_describe_cog_category)

# Consistency = % of species whose enrichment has the same sign as the mean
def calculate_consistency(cog_cat):
    """Return the percentage of species agreeing in sign with the mean.

    Reads the module-level enrichment_df. For the given COG category, counts
    the species whose enrichment is strictly positive when the category's
    mean enrichment is positive, and strictly negative otherwise, then
    divides by the number of species that have the category.
    """
    values = enrichment_df.loc[enrichment_df['COG_category'] == cog_cat, 'enrichment']
    same_sign = (values > 0) if values.mean() > 0 else (values < 0)
    return same_sign.sum() / len(values) * 100

# Percentage of species agreeing in sign with the mean, one value per category
cog_summary['consistency_pct'] = cog_summary['COG_category'].apply(calculate_consistency)

# cog_summary is sorted by mean_enrichment in descending order, so head(10)
# holds the highest means and tail(10) the lowest ones.
print("\n" + "="*80)
print("COG ENRICHMENT SUMMARY (Novel vs Core genes)")
print("="*80)
print("\nTop 10 ENRICHED in novel genes:")
print(cog_summary.head(10)[['COG_category', 'description', 'mean_enrichment', 'consistency_pct', 'n_species']].to_string(index=False))

print("\nTop 10 DEPLETED in novel genes:")
print(cog_summary.tail(10)[['COG_category', 'description', 'mean_enrichment', 'consistency_pct', 'n_species']].to_string(index=False))

# Save summary
cog_summary.to_csv('../data/cog_enrichment_summary.csv', index=False)
print("\nSaved summary to ../data/cog_enrichment_summary.csv")

## Step 6: Visualizations

In [None]:
# Plot 1: Heatmap of enrichment across species
# Rows are COG categories, columns are species. A category that was not
# observed in a species is shown as 0 (fill_value), i.e. as "no difference".
pivot_enrichment = enrichment_df.pivot_table(
    index='COG_category',
    columns='species_name',
    values='enrichment',
    fill_value=0
)

# Sort by mean enrichment
# (rows are reordered to follow cog_summary, most enriched category on top)
row_order = cog_summary.sort_values('mean_enrichment', ascending=False)['COG_category']
pivot_enrichment = pivot_enrichment.loc[[cat for cat in row_order if cat in pivot_enrichment.index]]

fig, ax = plt.subplots(figsize=(18, 12))
# Diverging colormap centred on 0; species names are hidden on the x axis
# (xticklabels=False), category codes are kept on the y axis.
sns.heatmap(
    pivot_enrichment,
    cmap='RdBu_r',
    center=0,
    cbar_kws={'label': 'Enrichment (Novel - Core)'},
    ax=ax,
    xticklabels=False,
    yticklabels=True
)
ax.set_xlabel('Species', fontsize=12)
ax.set_ylabel('COG Category', fontsize=12)
ax.set_title('COG Category Enrichment in Novel Genes Across Species\n(Red = enriched in novel, Blue = enriched in core)', fontsize=14)
plt.tight_layout()
# Save before show() so the written file contains the rendered figure
plt.savefig('../data/multi_species_enrichment_heatmap.png', dpi=300, bbox_inches='tight')
plt.show()

print("Heatmap saved to ../data/multi_species_enrichment_heatmap.png")

In [None]:
# Plot 2: Distribution of enrichment scores for key COG categories
# One box per selected category, built from the per-species enrichment values.
key_cogs = ['L', 'M', 'J', 'C', 'H', 'V', 'E', 'S']
key_data = enrichment_df[enrichment_df['COG_category'].isin(key_cogs)]
# Single ordering shared by the boxes and their tick labels
cog_order = sorted(key_cogs)

fig, ax = plt.subplots(figsize=(14, 8))
sns.boxplot(
    data=key_data,
    x='COG_category',
    y='enrichment',
    order=cog_order,
    palette='Set2',
    ax=ax
)
# Reference line: above 0 = enriched in novel genes, below 0 = enriched in core
ax.axhline(y=0, color='black', linestyle='--', linewidth=1)
ax.set_xlabel('COG Category', fontsize=12)
ax.set_ylabel('Enrichment (Novel - Core)', fontsize=12)
ax.set_title('Distribution of COG Enrichment Across Species\n(Positive = enriched in novel genes)', fontsize=14)

# Add category descriptions
# Descriptions are cut at 20 characters; the ellipsis is added only when
# something was actually cut off.
max_desc_len = 20
labels = []
for cog in cog_order:
    description = COG_DESCRIPTIONS.get(cog, '')
    if len(description) > max_desc_len:
        description = description[:max_desc_len] + '...'
    labels.append(f"{cog}\n{description}")
ax.set_xticklabels(labels, fontsize=10)

plt.tight_layout()
plt.savefig('../data/cog_enrichment_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

print("Distribution plot saved to ../data/cog_enrichment_distribution.png")

In [None]:
# Plot 3: Mean enrichment with error bars
# Horizontal bars for the top_n highest and top_n lowest mean enrichments,
# with the across-species standard deviation as error bar.
top_n = 15
top_enriched = cog_summary.nlargest(top_n, 'mean_enrichment')
top_depleted = cog_summary.nsmallest(top_n, 'mean_enrichment')
# With fewer than 2 * top_n categories the two selections overlap; keep a
# single bar per category instead of drawing it twice.
plot_data = pd.concat([top_enriched, top_depleted]).drop_duplicates(subset='COG_category')

fig, ax = plt.subplots(figsize=(12, 10))
# Coral for positive means (enriched in novel), sky blue otherwise
colors = ['coral' if x > 0 else 'skyblue' for x in plot_data['mean_enrichment']]
ax.barh(
    range(len(plot_data)),
    plot_data['mean_enrichment'],
    xerr=plot_data['std_enrichment'],
    color=colors,
    alpha=0.7,
    error_kw={'linewidth': 1, 'ecolor': 'gray'}
)

# Labels with descriptions
# A category without an entry in COG_DESCRIPTIONS may carry a NaN description,
# which cannot be sliced; fall back to 'Unknown' in that case.
labels = []
for _, row in plot_data.iterrows():
    description = row['description'] if isinstance(row['description'], str) else 'Unknown'
    labels.append(f"{row['COG_category']}: {description[:40]}")
ax.set_yticks(range(len(plot_data)))
ax.set_yticklabels(labels, fontsize=9)
ax.axvline(x=0, color='black', linestyle='-', linewidth=1)
ax.set_xlabel('Mean Enrichment ± SD (Novel - Core)', fontsize=12)
ax.set_title(f'Top {top_n} Enriched and Depleted COG Categories in Novel Genes\n(Across {len(all_cog_data["species_name"].unique())} species)', fontsize=14)
ax.grid(axis='x', alpha=0.3)

plt.tight_layout()
plt.savefig('../data/cog_mean_enrichment.png', dpi=300, bbox_inches='tight')
plt.show()

print("Mean enrichment plot saved to ../data/cog_mean_enrichment.png")

In [None]:
# Plot 4: Phylum-specific patterns
# Mean enrichment per (phylum, COG category), laid out as categories x phyla.
# A category never observed in a phylum is shown as 0 (fill_value).
phylum_enrichment = enrichment_df.groupby(['phylum', 'COG_category'])['enrichment'].mean().reset_index()
pivot_phylum = phylum_enrichment.pivot_table(
    index='COG_category',
    columns='phylum',
    values='enrichment',
    fill_value=0
)

# Keep the 20 categories with the highest overall mean enrichment, in that order
row_order = cog_summary.sort_values('mean_enrichment', ascending=False)['COG_category'].head(20)
pivot_phylum = pivot_phylum.loc[[cat for cat in row_order if cat in pivot_phylum.index]]

fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(
    pivot_phylum,
    cmap='RdBu_r',
    center=0,
    annot=True,
    fmt='.2f',
    cbar_kws={'label': 'Mean Enrichment'},
    ax=ax
)
ax.set_xlabel('Phylum', fontsize=12)
ax.set_ylabel('COG Category', fontsize=12)
# The rows are selected by mean enrichment, not by variability, so the
# subtitle states that selection rule.
ax.set_title('COG Enrichment Patterns by Phylum\n(Top 20 categories by mean enrichment in novel genes)', fontsize=14)
plt.tight_layout()
plt.savefig('../data/cog_enrichment_by_phylum.png', dpi=300, bbox_inches='tight')
plt.show()

print("Phylum-specific plot saved to ../data/cog_enrichment_by_phylum.png")

## Step 7: Statistical analysis

In [None]:
# Test if observed patterns match expectations from N. gonorrhoeae
print("\n" + "="*80)
print("HYPOTHESIS TESTING")
print("="*80)

print("\nBased on N. gonorrhoeae, we expected:")
print(f"  Enriched in novel: {', '.join(EXPECTED_NOVEL_ENRICHED)}")
print(f"  Depleted in novel: {', '.join(EXPECTED_NOVEL_DEPLETED)}")

# Check how many of our expectations are confirmed
# Rank explicitly so this cell does not depend on how cog_summary was last
# sorted, and require the sign of the mean to match the claimed direction:
# a category in the top 10 with a negative mean is not "enriched".
ranked_summary = cog_summary.sort_values('mean_enrichment', ascending=False)
highest_10 = ranked_summary.head(10)
lowest_10 = ranked_summary.tail(10)
top_10_enriched = set(highest_10.loc[highest_10['mean_enrichment'] > 0, 'COG_category'])
top_10_depleted = set(lowest_10.loc[lowest_10['mean_enrichment'] < 0, 'COG_category'])

enriched_confirmed = [cog for cog in EXPECTED_NOVEL_ENRICHED if cog in top_10_enriched]
depleted_confirmed = [cog for cog in EXPECTED_NOVEL_DEPLETED if cog in top_10_depleted]

print("\nResults across all species:")
print(f"  Enriched expectations confirmed: {len(enriched_confirmed)}/{len(EXPECTED_NOVEL_ENRICHED)} ({', '.join(enriched_confirmed)})")
print(f"  Depleted expectations confirmed: {len(depleted_confirmed)}/{len(EXPECTED_NOVEL_DEPLETED)} ({', '.join(depleted_confirmed)})")

# Calculate consistency for expected categories
# Report both directions; categories absent from cog_summary are skipped.
print("\nConsistency of expected patterns:")
expected_groups = (
    ('enriched', EXPECTED_NOVEL_ENRICHED),
    ('depleted', EXPECTED_NOVEL_DEPLETED),
)
for direction, expected_cogs in expected_groups:
    print(f"  Expected {direction} in novel genes:")
    for cog in expected_cogs:
        if cog in cog_summary['COG_category'].values:
            row = cog_summary[cog_summary['COG_category'] == cog].iloc[0]
            print(f"  {cog} ({COG_DESCRIPTIONS.get(cog, 'Unknown')}):")
            print(f"    Mean enrichment: {row['mean_enrichment']:+.3f}")
            print(f"    Consistency: {row['consistency_pct']:.1f}% of species")

In [None]:
# Per-phylum mean enrichment for four selected COG categories
print("\n" + "="*80)
print("PHYLUM-SPECIFIC PATTERNS")
print("="*80)

for cog in ('L', 'M', 'J', 'C'):
    cog_data = enrichment_df[enrichment_df['COG_category'] == cog]
    # Phyla listed from highest to lowest mean enrichment
    phylum_means = (
        cog_data.groupby('phylum')['enrichment']
        .mean()
        .sort_values(ascending=False)
    )

    print(f"\n{cog} ({COG_DESCRIPTIONS.get(cog, 'Unknown')}):")
    for phylum, mean_enrich in phylum_means.items():
        print(f"  {phylum:25s}: {mean_enrich:+.3f}")

## Step 8: Summary and conclusions

In [None]:
# Final report: dataset size, the strongest enriched/depleted categories, the
# verdict on the N. gonorrhoeae expectations, and the files written so far.
print("\n" + "="*80)
print("ANALYSIS SUMMARY")
print("="*80)

print(f"\nDataset:")
print(f"  Species analyzed: {len(all_cog_data['species_name'].unique())}")
print(f"  Phyla represented: {len(all_cog_data['phylum'].unique())}")
print(f"  Total genes analyzed: {all_cog_data['gene_count'].sum():,.0f}")
print(f"  COG categories found: {all_cog_data['COG_category'].nunique()}")

print(f"\nKey findings:")
# cog_summary is ordered by descending mean enrichment, so head(5) / tail(5)
# are the five highest and five lowest means.
findings = (
    ("  1. Most consistently enriched in novel genes:", cog_summary.head(5)),
    ("\n  2. Most consistently depleted in novel genes:", cog_summary.tail(5)),
)
for heading, rows in findings:
    print(heading)
    for _, row in rows.iterrows():
        # A category without an entry in COG_DESCRIPTIONS may carry a NaN
        # description, which cannot be sliced; fall back to 'Unknown'.
        description = row['description'] if isinstance(row['description'], str) else 'Unknown'
        print(f"     - {row['COG_category']}: {description[:50]} ({row['consistency_pct']:.0f}% consistent)")

print(f"\n  3. Patterns from N. gonorrhoeae:")
# Counted as confirmed when at least 3 expectations hold in each direction
if len(enriched_confirmed) >= 3 and len(depleted_confirmed) >= 3:
    print(f"     ✓ CONFIRMED across species")
else:
    print(f"     ✗ NOT fully replicated across species")

print(f"\nGenerated files:")
print(f"  - ../data/multi_species_cog_results.csv")
print(f"  - ../data/cog_enrichment_summary.csv")
print(f"  - ../data/multi_species_enrichment_heatmap.png")
print(f"  - ../data/cog_enrichment_distribution.png")
print(f"  - ../data/cog_mean_enrichment.png")
print(f"  - ../data/cog_enrichment_by_phylum.png")