# NB 04: Module Functional Annotation

Label each ICA module with a biological function using enrichment analysis.

**Part 1 (JupyterHub)**: Extract KEGG, SEED, domain, and specific phenotype
annotations from Spark.

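**Part 2 (local)**: One-sided Fisher's exact test enrichment of KEGG, SEED subsystem, and TIGRFam terms for each module, plus a mapping of each module's strongest activities to experiment conditions.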

Run Part 1 on JupyterHub first, then Part 2 can run locally.

In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
from scipy import stats as scipy_stats
from statsmodels.stats.multitest import multipletests

# Shared locations: Part 1 writes annotation CSVs into ANNOT_DIR; Part 2
# reads per-organism module tables from MODULE_DIR.
DATA_DIR = Path('../data')
ANNOT_DIR = DATA_DIR.joinpath('annotations')
MODULE_DIR = DATA_DIR.joinpath('modules')
ANNOT_DIR.mkdir(exist_ok=True, parents=True)

# Organisms to process: the orgId column of the pilot table.
pilots = pd.read_csv(DATA_DIR.joinpath('pilot_organisms.csv'))
pilot_ids = pilots.orgId.tolist()
print(f"Pilot organisms: {pilot_ids}")

## Part 1: Extract Annotations from Spark

**Run this section on JupyterHub.**

In [None]:
# Initialize Spark when running on JupyterHub. get_spark_session is not
# defined in this file (presumably injected by the JupyterHub environment).
# Anywhere it is missing or fails, fall back to running Part 2 locally.
try:
    spark = get_spark_session()
    HAS_SPARK = True
    print(f"Spark version: {spark.version}")
except Exception as err:  # deliberate best-effort: any failure means "no Spark"
    HAS_SPARK = False
    # Report the reason so a broken Spark setup on JupyterHub is not silently
    # mistaken for local mode.
    print(f"No Spark available ({type(err).__name__}: {err}) — running Part 2 only (local mode)")

In [None]:
def _extract_cached(out_path, query):
    """Run *query* on Spark and save the result to *out_path* as CSV.

    An existing non-empty file is treated as a cache hit and left untouched.

    Returns
    -------
    int or None
        Number of rows written, or None when the cached file was reused.
    """
    if out_path.exists() and out_path.stat().st_size > 0:
        return None
    result = spark.sql(query).toPandas()
    result.to_csv(out_path, index=False)
    return len(result)


def _annotation_queries(org_id):
    """Describe each Fitness Browser annotation table to extract for *org_id*.

    Returns a list of ``(file_suffix, label, unit, cached_label, sql)`` tuples,
    in extraction order. ``org_id`` is interpolated directly into the SQL, so it
    must come from a trusted source (in this notebook, pilot_organisms.csv).
    """
    # LEFT JOIN on kgroupec may yield one row per EC number of a KEGG group.
    kegg_sql = f"""
        SELECT km.locusId, km.kgroup, kd.desc as kgroup_desc,
               ke.ec
        FROM kescience_fitnessbrowser.keggmember km
        LEFT JOIN kescience_fitnessbrowser.kgroupdesc kd
            ON km.kgroup = kd.kgroup
        LEFT JOIN kescience_fitnessbrowser.kgroupec ke
            ON km.kgroup = ke.kgroup
        WHERE km.orgId = '{org_id}'
    """
    seed_sql = f"""
        SELECT sa.locusId, sa.seed_desc,
               sc.subsystem, sc.category1, sc.category2, sc.category3
        FROM kescience_fitnessbrowser.seedannotation sa
        LEFT JOIN kescience_fitnessbrowser.seedclass sc
            ON sa.seed_desc = sc.seed_desc
        WHERE sa.orgId = '{org_id}'
    """
    domain_sql = f"""
        SELECT locusId, domainDb, domainId, domainName,
               definition, geneSymbol, ec
        FROM kescience_fitnessbrowser.genedomain
        WHERE orgId = '{org_id}'
    """
    pheno_sql = f"""
        SELECT sp.locusId, sp.expName,
               e.expDesc, e.expGroup, e.condition_1
        FROM kescience_fitnessbrowser.specificphenotype sp
        JOIN kescience_fitnessbrowser.experiment e
            ON sp.orgId = e.orgId AND sp.expName = e.expName
        WHERE sp.orgId = '{org_id}'
    """
    return [
        ('kegg', 'KEGG', 'annotations', 'KEGG', kegg_sql),
        ('seed', 'SEED', 'annotations', 'SEED', seed_sql),
        ('domains', 'Domains', 'annotations', 'domains', domain_sql),
        ('specific_phenotypes', 'Phenotypes', 'entries', 'phenotypes', pheno_sql),
    ]


# Extract every annotation table per pilot organism into ANNOT_DIR, reusing
# any non-empty CSV from a previous run.
if HAS_SPARK:
    for org_id in pilot_ids:
        for suffix, label, unit, cached_label, query in _annotation_queries(org_id):
            n_rows = _extract_cached(ANNOT_DIR / f'{org_id}_{suffix}.csv', query)
            if n_rows is None:
                print(f"CACHED: {org_id} {cached_label}")
            else:
                print(f"{label}: {org_id} — {n_rows} {unit}")

## Part 2: Enrichment Analysis

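One-sided (over-representation) Fisher's exact test of each annotation term against module membership, with Benjamini–Hochberg FDR correction pooled across all tests for an organism.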

In [None]:
def enrichment_analysis(module_genes, all_genes, annotation_map, min_annotated=3):
    """One-sided Fisher exact test enrichment for a single module.

    For every annotation term, build the 2x2 table of
    (in module / not in module) x (annotated / not annotated) over the gene
    universe ``all_genes``. Then test for over-representation
    (``alternative='greater'``).

    Parameters
    ----------
    module_genes : iterable
        Genes in the module. Genes outside ``all_genes`` are ignored, so the
        contingency table always sums to ``len(all_genes)``.
    all_genes : iterable
        Background universe (all genes in the organism).
    annotation_map : dict
        ``{term: iterable_of_genes}`` mapping. Each gene set is restricted to
        ``all_genes`` before testing.
    min_annotated : int
        Minimum number of module genes carrying a term for it to be tested.

    Returns
    -------
    results : list of dict
        One dict per tested term, with keys ``term``, ``n_overlap``,
        ``n_module``, ``n_term``, ``odds_ratio`` and ``p_value`` (uncorrected).
    """
    universe = set(all_genes)
    # Restrict the module to the universe; otherwise the "neither" cell can go
    # negative and fisher_exact raises.
    module = set(module_genes) & universe
    n_total = len(universe)
    n_module = len(module)
    results = []

    for term, term_genes in annotation_map.items():
        annotated = set(term_genes) & universe
        a = len(module & annotated)      # in module AND annotated
        if a < min_annotated:
            continue
        b = n_module - a                 # in module, NOT annotated
        c = len(annotated) - a           # annotated, NOT in module
        d = n_total - a - b - c          # neither

        odds_ratio, p_value = scipy_stats.fisher_exact([[a, b], [c, d]],
                                                        alternative='greater')
        results.append({
            'term': term,
            'n_overlap': a,
            'n_module': n_module,
            'n_term': len(annotated),
            'odds_ratio': odds_ratio,
            'p_value': p_value
        })

    return results

In [None]:
def _read_nonempty_csv(path):
    """Read *path* with pandas, or return None if it is missing or zero bytes.

    Part 1 treats a zero-byte file as "not cached". Mirroring that here avoids
    ``pd.read_csv`` raising ``EmptyDataError`` on such a file.
    """
    if not (path.exists() and path.stat().st_size > 0):
        return None
    return pd.read_csv(path)


def _term_gene_map(df, term_col):
    """Map each non-null ``term_col`` value to the set of its locusIds (as str)."""
    return (df.dropna(subset=[term_col])
              .groupby(term_col)['locusId']
              .apply(lambda loci: set(loci.astype(str)))
              .to_dict())


def _load_annotation_maps(org_id):
    """Build ``{database: {term: set_of_locusIds}}`` from the Part 1 CSVs of one organism."""
    maps = {}

    kegg = _read_nonempty_csv(ANNOT_DIR / f'{org_id}_kegg.csv')
    if kegg is not None:
        maps['KEGG'] = _term_gene_map(kegg, 'kgroup')

    seed = _read_nonempty_csv(ANNOT_DIR / f'{org_id}_seed.csv')
    if seed is not None and 'subsystem' in seed.columns:
        maps['SEED'] = _term_gene_map(seed, 'subsystem')

    # Only TIGRFam domain families are used as terms.
    domains = _read_nonempty_csv(ANNOT_DIR / f'{org_id}_domains.csv')
    if domains is not None:
        tigr = domains[domains['domainDb'] == 'TIGRFam']
        if len(tigr) > 0:
            maps['TIGRFam'] = _term_gene_map(tigr, 'domainId')

    return maps


def _enrich_modules(membership, annotation_maps):
    """Test every module (column of *membership*) against every annotation database.

    Returns a DataFrame sorted by (module, fdr), with BH-FDR computed over all
    tests pooled for the organism. Returns None if no term passed the overlap
    filter.
    """
    all_genes = set(membership.index.astype(str))
    rows = []
    for mod in membership.columns:
        # Only entries exactly equal to 1 count as module members
        # (membership is presumably a 0/1 matrix).
        mod_genes = set(membership.index[membership[mod] == 1].astype(str))
        if not mod_genes:
            continue
        for db_name, term_map in annotation_maps.items():
            for res in enrichment_analysis(mod_genes, all_genes, term_map):
                res['module'] = mod
                res['database'] = db_name
                rows.append(res)

    if not rows:
        return None
    enrich_df = pd.DataFrame(rows)
    reject, fdr, _, _ = multipletests(enrich_df['p_value'], method='fdr_bh')
    enrich_df['fdr'] = fdr
    enrich_df['significant'] = reject
    return enrich_df.sort_values(['module', 'fdr'])


def _map_module_conditions(profiles, exp_meta, top_n=5):
    """List the ``top_n`` experiments with the largest |activity| for each module.

    *profiles* has one row per module and one column per experiment. Each hit
    is annotated with metadata from *exp_meta*. Experiments absent from
    *exp_meta* are skipped.
    """
    # Index the metadata once; the first row per expName wins.
    meta = exp_meta.drop_duplicates(subset='expName').set_index('expName')
    rows = []
    for mod in profiles.index:
        activity = profiles.loc[mod]
        for exp_name, abs_act in activity.abs().nlargest(top_n).items():
            if exp_name not in meta.index:
                continue
            info = meta.loc[exp_name]
            rows.append({
                'module': mod,
                'expName': exp_name,
                'activity': float(activity[exp_name]),
                'abs_activity': float(abs_act),
                'expDesc': info.get('expDesc', ''),
                'expGroup': info.get('expGroup', ''),
                'condition_1': info.get('condition_1', '')
            })
    return rows


for org_id in pilot_ids:
    out_file = MODULE_DIR / f'{org_id}_module_annotations.csv'
    cond_file = MODULE_DIR / f'{org_id}_module_conditions.csv'

    # out_file is the cache sentinel. It is written only after every step
    # below has succeeded, so a failure part-way leaves the organism to redo.
    if out_file.exists() and out_file.stat().st_size > 0:
        print(f"CACHED: {org_id} annotations")
        continue

    print(f"\nAnnotating {org_id} modules...")

    membership = pd.read_csv(MODULE_DIR / f'{org_id}_gene_membership.csv', index_col=0)
    enrich_df = _enrich_modules(membership, _load_annotation_maps(org_id))
    if enrich_df is not None:
        n_sig = enrich_df['significant'].sum()
        n_modules_annotated = enrich_df.loc[enrich_df['significant'], 'module'].nunique()
        print(f"  {n_sig} significant enrichments across {n_modules_annotated} modules")
    else:
        print("  No enrichments found")

    # Map module activity to experiment conditions.
    profiles = pd.read_csv(MODULE_DIR / f'{org_id}_module_profiles.csv', index_col=0)
    exp_meta = pd.read_csv(ANNOT_DIR / f'{org_id}_experiments.csv')
    condition_results = _map_module_conditions(profiles, exp_meta)
    if condition_results:
        pd.DataFrame(condition_results).to_csv(cond_file, index=False)
        print("  Saved condition mappings")

    # Write the cache sentinel last (see note above).
    if enrich_df is not None:
        enrich_df.to_csv(out_file, index=False)

In [None]:
# Summary: best (lowest-FDR) significant enrichment per module, per organism.
for org_id in pilot_ids:
    ann_file = MODULE_DIR / f'{org_id}_module_annotations.csv'
    if not (ann_file.exists() and ann_file.stat().st_size > 0):
        continue
    ann = pd.read_csv(ann_file)
    sig = ann[ann['significant']]
    print(f"\n{org_id}: {len(sig)} significant enrichments")
    # Keep one whole row per module: its lowest-FDR hit. GroupBy.first() is
    # avoided because it takes the first non-null value of each column
    # independently, which can combine fields from different rows.
    top = sig.sort_values(['module', 'fdr']).drop_duplicates(subset='module')
    if len(top) > 0:
        # Only the first 10 modules (in sorted module-name order) are shown.
        print(top[['module', 'database', 'term', 'n_overlap', 'odds_ratio', 'fdr']]
              .head(10).to_string(index=False))