# NB 07: Benchmarking Module-Based Predictions

Evaluate module-based function predictions against baselines:
1. **Cofitness voting**: top-N cofit partners → majority vote
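2. **Ortholog transfer**: transfer KEGG annotations from bidirectional best-hit (BBH) orthologs in other organisms
3. **Domain-only**: vote over the KEGG groups associated (in training genes) with the gene's TIGRFAM/Pfam domains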

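Evaluation: hold out 20% of KEGG-annotated genes, predict their KEGG group (KO) from the remaining 80%, and measure
precision (fraction of predictions that match one of the gene's true KOs) and coverage (fraction of held-out genes that
receive a prediction); recall is precision × coverage.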

Additional validation:
- Within-module cofitness density
- Genomic adjacency (operon proximity)
- Concordance with specific phenotype hits

**Run locally** — no Spark needed (uses extracted data).

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score
from collections import Counter

# Input/output locations, all relative to the notebook's working directory.
DATA_DIR = Path('../data')
MODULE_DIR, ANNOT_DIR, ORTHO_DIR, PRED_DIR = (
    DATA_DIR / subdir
    for subdir in ('modules', 'annotations', 'orthologs', 'predictions')
)
FIG_DIR = Path('../figures')
FIG_DIR.mkdir(exist_ok=True)

# Organisms included in the pilot benchmark.
pilot_table = pd.read_csv(DATA_DIR / 'pilot_organisms.csv')
pilots = pilot_table
pilot_ids = pilot_table['orgId'].tolist()
print(f"Pilot organisms: {pilot_ids}")

## 1. Build Gold Standard from KEGG Annotations

In [None]:
# Build gold standard: genes with KEGG annotations
gold_standard = {}  # orgId -> {locusId (str) -> set of KEGG groups}

for org_id in pilot_ids:
    kegg_file = ANNOT_DIR / f'{org_id}_kegg.csv'
    if not kegg_file.exists():
        continue
    kegg = pd.read_csv(kegg_file)
    # Rows with a blank KEGG group are not annotations. groupby() drops NaN
    # keys but not NaN values, so without this they would enter the label
    # sets as NaN and be scored like a real KEGG group.
    kegg = kegg.dropna(subset=['kgroup']).astype({'locusId': str})

    # Gene → set of KEGG groups (a gene can have multiple KOs)
    gene_kegg = kegg.groupby('locusId')['kgroup'].apply(set).to_dict()
    gold_standard[org_id] = gene_kegg
    print(f"{org_id}: {len(gene_kegg)} genes with KEGG annotations")

print(f"\nTotal annotated genes: {sum(len(v) for v in gold_standard.values())}")

## 2. Hold-Out Split

In [None]:
# Seed the global NumPy RNG; later cells draw from np.random.
np.random.seed(42)
holdout_frac = 0.2

train_genes = {}  # genes used for enrichment
test_genes = {}   # genes used for evaluation

for org_id, gene_kegg in gold_standard.items():
    all_loci = list(gene_kegg.keys())
    # train_test_split raises unless there is at least one sample for each
    # side; skip organisms that are too sparsely annotated to evaluate.
    if len(all_loci) < 2:
        print(f"{org_id}: skipped ({len(all_loci)} annotated genes; need >= 2)")
        continue
    train_loci, test_loci = train_test_split(all_loci, test_size=holdout_frac,
                                              random_state=42)
    train_genes[org_id] = {l: gene_kegg[l] for l in train_loci}
    test_genes[org_id] = {l: gene_kegg[l] for l in test_loci}
    print(f"{org_id}: train={len(train_loci)}, test={len(test_loci)}")

## 3. Module-Based Predictions

In [None]:
def module_predict(org_id, test_loci, train_kegg, membership, weights, kegg_df):
    """Predict a KEGG group for each held-out gene from its module memberships.

    Each module is labelled with the KEGG group most frequent among its
    *training* genes. Ties go to the lexicographically smallest group id, so
    the result does not depend on set iteration order. A test gene inherits
    the label of the labelled module in which it has the largest absolute
    weight.

    Parameters
    ----------
    org_id : str
        Organism id (not used here; kept for a uniform method signature).
    test_loci : iterable of str
        Held-out locus ids to predict.
    train_kegg : dict
        Training locus id (str) -> set of KEGG groups.
    membership : pandas.DataFrame
        Gene x module table; a value of 1 marks membership.
    weights : pandas.DataFrame
        Gene x module table of gene weights.
    kegg_df : pandas.DataFrame
        Not used; kept so existing callers keep working.

    Returns
    -------
    dict
        Test locus id -> predicted KEGG group, for genes that could be labelled.
    """
    # Normalise locus ids to str once (CSV indices may parse as int). Keep
    # the first row for duplicated ids, which matches a first-match lookup.
    membership = membership.copy()
    membership.index = membership.index.astype(str)
    membership = membership[~membership.index.duplicated(keep='first')]
    weights = weights.copy()
    weights.index = weights.index.astype(str)
    weights = weights[~weights.index.duplicated(keep='first')]

    train_loci = set(train_kegg)

    # Build module → KEGG label using TRAINING genes only (no test leakage).
    module_kegg = {}
    for mod in membership.columns:
        mod_genes = membership.index[(membership[mod] == 1).to_numpy()]
        kegg_counts = Counter()
        for g in train_loci.intersection(mod_genes):
            kegg_counts.update(train_kegg[g])
        if kegg_counts:
            module_kegg[mod] = min(kegg_counts,
                                   key=lambda kg: (-kegg_counts[kg], str(kg)))

    predictions = {}
    for locus in test_loci:
        if locus not in membership.index or locus not in weights.index:
            # No membership row or no weights to rank modules: no prediction.
            continue
        member_row = membership.loc[locus]
        gene_mods = member_row.index[(member_row == 1).to_numpy()]
        weight_row = weights.loc[locus]

        # Pick the labelled module with the highest |weight| (first wins ties).
        best_mod = None
        best_weight = 0
        for mod in gene_mods:
            if mod not in module_kegg or mod not in weight_row.index:
                continue
            w = abs(weight_row[mod])
            if w > best_weight:
                best_mod, best_weight = mod, w
        if best_mod is not None:
            predictions[locus] = module_kegg[best_mod]

    return predictions

## 4. Baseline: Cofitness Voting

In [None]:
def cofit_predict(org_id, test_loci, train_kegg, top_n=6, fit_matrix=None):
    """Predict a KEGG group by majority vote of the top-N cofit training genes.

    Cofitness is the Pearson correlation between fitness profiles (rows of
    the fitness matrix). Partners are ranked by |r|, and ties keep training
    order. Partners whose correlation is undefined (zero-variance or NaN
    profiles) are ignored rather than ranked arbitrarily. Vote ties go to the
    lexicographically smallest KEGG group, so results are reproducible.

    Parameters
    ----------
    org_id : str
        Organism id; used to locate the fitness matrix when ``fit_matrix`` is None.
    test_loci : iterable of str
        Held-out locus ids to predict.
    train_kegg : dict
        Training locus id (str) -> set of KEGG groups.
    top_n : int
        Number of most-correlated training genes that vote.
    fit_matrix : pandas.DataFrame, optional
        Gene x experiment fitness matrix. When omitted, it is read from
        ``DATA_DIR/matrices/{org_id}_fitness_matrix.csv``.

    Returns
    -------
    dict
        Test locus id -> predicted KEGG group ({} if no fitness matrix).
    """
    if fit_matrix is None:
        fit_file = DATA_DIR / 'matrices' / f'{org_id}_fitness_matrix.csv'
        if not fit_file.exists():
            return {}
        fit_matrix = pd.read_csv(fit_file, index_col=0)

    # Work on a copy with str locus ids so the caller's frame is untouched.
    fit_matrix = fit_matrix.copy()
    fit_matrix.index = fit_matrix.index.astype(str)
    fit_matrix = fit_matrix[~fit_matrix.index.duplicated(keep='first')]

    # Training genes with a fitness profile (loop-invariant; computed once).
    train_loci = [l for l in train_kegg if l in fit_matrix.index]
    if not train_loci:
        return {}
    train_values = fit_matrix.loc[train_loci].to_numpy(dtype=float)
    train_centered = train_values - train_values.mean(axis=1, keepdims=True)
    train_norms = np.linalg.norm(train_centered, axis=1)

    predictions = {}
    for locus in test_loci:
        if locus not in fit_matrix.index:
            continue
        profile = fit_matrix.loc[locus].to_numpy(dtype=float)
        centered = profile - profile.mean()
        # Vectorized Pearson r against every training gene; 0/0 becomes NaN.
        with np.errstate(divide='ignore', invalid='ignore'):
            corrs = (train_centered @ centered) / (train_norms * np.linalg.norm(centered))
        valid = np.flatnonzero(np.isfinite(corrs))
        if valid.size == 0:
            continue
        # Stable sort on -|r|: descending |r|, ties keep training order.
        top_idx = valid[np.argsort(-np.abs(corrs[valid]), kind='stable')][:top_n]

        # Majority vote
        kegg_votes = Counter()
        for idx in top_idx:
            kegg_votes.update(train_kegg[train_loci[idx]])
        if kegg_votes:
            predictions[locus] = min(kegg_votes,
                                     key=lambda kg: (-kegg_votes[kg], str(kg)))

    return predictions

## 5. Baseline: Ortholog Transfer

In [None]:
def ortholog_predict(org_id, test_loci, all_train_kegg, bbh_pairs):
    """Predict a KEGG group by transfer from a BBH ortholog's training labels.

    For each test gene, BBH hits (rows with ``orgId1 == org_id``) are tried in
    table order. The first hit whose partner is a *training* gene with a KEGG
    annotation supplies the prediction. When that partner has several KEGG
    groups, the lexicographically smallest is used, so the choice does not
    depend on set iteration order.

    Parameters
    ----------
    org_id : str
        Organism of the test genes.
    test_loci : iterable of str
        Held-out locus ids to predict.
    all_train_kegg : dict
        orgId -> {training locus id (str) -> set of KEGG groups}.
    bbh_pairs : pandas.DataFrame or None
        Columns ``orgId1, locusId1, orgId2, locusId2``. An empty or missing
        table yields no predictions instead of raising ``KeyError``.

    Returns
    -------
    dict
        Test locus id -> predicted KEGG group.
    """
    required = {'orgId1', 'locusId1', 'orgId2', 'locusId2'}
    if bbh_pairs is None or bbh_pairs.empty or not required.issubset(bbh_pairs.columns):
        return {}

    # Index this organism's hits by locus once, keeping table order per locus.
    org_bbh = bbh_pairs[bbh_pairs['orgId1'] == org_id]
    hits_by_locus = {}
    for row in org_bbh.itertuples(index=False):
        hits_by_locus.setdefault(str(row.locusId1), []).append(
            (row.orgId2, str(row.locusId2)))

    predictions = {}
    for locus in test_loci:
        for other_org, other_locus in hits_by_locus.get(locus, ()):
            # Only training-set annotations may be transferred (no leakage).
            other_kegg = all_train_kegg.get(other_org, {}).get(other_locus)
            if other_kegg:
                predictions[locus] = min(other_kegg, key=str)
                break

    return predictions

## 6. Baseline: Domain-Only

In [None]:
def domain_predict(org_id, test_loci, train_kegg, domains_df):
    """Predict a KEGG group from domain annotations.

    A domain → KEGG-group vote table is built from training genes: every
    domain row of a training gene adds one vote for each of that gene's KEGG
    groups. A test gene sums the tables of its domains. The winner is the
    group with the most votes, and ties go to the lexicographically smallest
    group id.

    Parameters
    ----------
    org_id : str
        Organism id (not used here; kept for a uniform method signature).
    test_loci : iterable of str
        Held-out locus ids to predict.
    train_kegg : dict
        Training locus id (str) -> set of KEGG groups.
    domains_df : pandas.DataFrame or None
        Rows of (``locusId``, ``domainId``). It is not modified.

    Returns
    -------
    dict
        Test locus id -> predicted KEGG group ({} when there are no domains).
    """
    if domains_df is None or len(domains_df) == 0:
        return {}

    # locus (str) -> list of its domain ids in row order, computed once and
    # without mutating the caller's DataFrame.
    locus_domains = (
        domains_df['domainId']
        .groupby(domains_df['locusId'].astype(str), sort=False)
        .apply(list)
        .to_dict()
    )

    # Build domain → KEGG mapping from training genes
    domain_kegg = {}
    for locus, kegg_groups in train_kegg.items():
        for dom in locus_domains.get(locus, ()):
            domain_kegg.setdefault(dom, Counter()).update(kegg_groups)

    # Predict test genes
    predictions = {}
    for locus in test_loci:
        kegg_votes = Counter()
        for dom in locus_domains.get(locus, ()):
            if dom in domain_kegg:
                kegg_votes.update(domain_kegg[dom])
        if kegg_votes:
            predictions[locus] = min(kegg_votes,
                                     key=lambda kg: (-kegg_votes[kg], str(kg)))

    return predictions

## 7. Run All Methods & Compare

In [None]:
# Load shared data
bbh_file = ORTHO_DIR / 'pilot_bbh_pairs.csv'
bbh_pairs = pd.read_csv(bbh_file) if bbh_file.exists() else pd.DataFrame()

results = []

for org_id in pilot_ids:
    if org_id not in test_genes:
        continue

    test_loci = list(test_genes[org_id].keys())
    true_labels = test_genes[org_id]  # dict: locus → set of KEGG groups
    n_test = len(test_loci)

    print(f"\n{'='*60}")
    print(f"{org_id}: {n_test} test genes")

    # Load data. A missing optional input disables only the method that
    # needs it instead of aborting the whole benchmark.
    membership_file = MODULE_DIR / f'{org_id}_gene_membership.csv'
    weights_file = MODULE_DIR / f'{org_id}_gene_weights.csv'
    kegg = pd.read_csv(ANNOT_DIR / f'{org_id}_kegg.csv')
    domain_file = ANNOT_DIR / f'{org_id}_domains.csv'
    domains = pd.read_csv(domain_file) if domain_file.exists() else None

    if membership_file.exists() and weights_file.exists():
        membership = pd.read_csv(membership_file, index_col=0)
        weights = pd.read_csv(weights_file, index_col=0)
        module_preds = module_predict(org_id, test_loci, train_genes[org_id],
                                      membership, weights, kegg)
    else:
        module_preds = {}

    methods = {
        'Module-ICA': module_preds,
        'Cofitness': cofit_predict(org_id, test_loci, train_genes[org_id]),
        # An empty BBH table (file missing) has no columns to filter on.
        'Ortholog': (ortholog_predict(org_id, test_loci, train_genes, bbh_pairs)
                     if not bbh_pairs.empty else {}),
        'Domain': domain_predict(org_id, test_loci, train_genes[org_id], domains),
    }

    for method_name, preds in methods.items():
        # Evaluate: is predicted KEGG group in gene's true set?
        scored = [(locus, kg) for locus, kg in preds.items() if locus in true_labels]
        n_predicted = len(scored)
        n_correct = sum(1 for locus, kg in scored if kg in true_labels[locus])

        # Precision is undefined without predictions (NaN, skipped by mean());
        # coverage/recall are still recorded as 0 so per-organism averages
        # are not inflated by silently dropping this method's row.
        precision = n_correct / n_predicted if n_predicted > 0 else np.nan
        coverage = n_predicted / n_test if n_test > 0 else 0.0
        recall = n_correct / n_test if n_test > 0 else 0.0  # = precision × coverage
        f1 = 2 * precision * recall / (precision + recall) if n_correct > 0 else 0.0

        results.append({
            'orgId': org_id,
            'method': method_name,
            'n_test': n_test,
            'n_predicted': n_predicted,
            'n_correct': n_correct,
            'precision': precision,
            'coverage': coverage,
            'recall': recall,
            'f1': f1,
        })
        if n_predicted == 0:
            print(f"  {method_name}: no predictions")
        else:
            print(f"  {method_name}: precision={precision:.3f}, coverage={coverage:.3f} "
                  f"recall={recall:.3f}, F1={f1:.3f} "
                  f"({n_correct}/{n_predicted} correct, {n_predicted}/{n_test} covered)")

results_df = pd.DataFrame(results)
PRED_DIR.mkdir(parents=True, exist_ok=True)  # to_csv does not create directories
results_df.to_csv(PRED_DIR / 'benchmark_results.csv', index=False)

## 8. Visualize Results

In [None]:
# Summarise the per-organism rows into one row per method.
agg = (
    results_df
    .groupby('method')
    .agg(
        mean_precision=('precision', 'mean'),
        std_precision=('precision', 'std'),
        mean_coverage=('coverage', 'mean'),
        total_correct=('n_correct', 'sum'),
        total_predicted=('n_predicted', 'sum'),
    )
    .reset_index()
)
# Micro-averaged precision: pooled correct / pooled predicted over organisms.
agg['overall_precision'] = agg['total_correct'] / agg['total_predicted']

print("\nAggregate Results:")
print(agg.to_string(index=False))

# Two side-by-side bar panels: precision (with std error bars) and coverage.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

methods_order = ['Module-ICA', 'Cofitness', 'Ortholog', 'Domain']
colors = ['#2196F3', '#FF9800', '#4CAF50', '#9C27B0']

by_method = agg.set_index('method')
for pos, (method, color) in enumerate(zip(methods_order, colors)):
    if method not in by_method.index:
        continue  # method produced no rows: leave its slot empty
    stats = by_method.loc[method]
    ax1.bar(pos, stats['mean_precision'], color=color,
            yerr=stats['std_precision'], capsize=5)
    ax2.bar(pos, stats['mean_coverage'], color=color)

panels = (
    (ax1, 'Precision', 'KEGG Group Prediction Precision'),
    (ax2, 'Coverage', 'Prediction Coverage (fraction of test genes)'),
)
for ax, ylabel, title in panels:
    ax.set_xticks(range(len(methods_order)))
    ax.set_xticklabels(methods_order, rotation=30, ha='right')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_ylim(0, 1)

plt.tight_layout()
plt.savefig(FIG_DIR / 'benchmark_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

## 9. Validation: Within-Module Cofitness Density

In [None]:
# Check that module genes are more cofit than random gene pairs
for org_id in pilot_ids:
    fit_file = DATA_DIR / 'matrices' / f'{org_id}_fitness_matrix.csv'
    member_file = MODULE_DIR / f'{org_id}_gene_membership.csv'

    if not fit_file.exists() or not member_file.exists():
        continue

    fit_matrix = pd.read_csv(fit_file, index_col=0)
    membership = pd.read_csv(member_file, index_col=0)
    # Compare locus ids as strings on both sides; pandas infers the index
    # dtype per CSV, so one file could parse ids as int and the other as str.
    fit_matrix.index = fit_matrix.index.astype(str)
    membership.index = membership.index.astype(str)

    # Compute correlation matrix (subsample for speed)
    n_sample = min(500, len(fit_matrix))
    sample_idx = np.random.choice(len(fit_matrix), n_sample, replace=False)
    sample_corr = np.corrcoef(fit_matrix.values[sample_idx])
    sampled_loci = fit_matrix.index[sample_idx]

    # Within-module correlations: every unordered pair of sampled members
    within_corrs = []
    for mod in membership.columns:
        mod_genes = set(membership.index[(membership[mod] == 1).to_numpy()])
        mod_idx = [i for i, g in enumerate(sampled_loci) if g in mod_genes]
        if len(mod_idx) >= 2:
            sub = sample_corr[np.ix_(mod_idx, mod_idx)]
            within_corrs.extend(sub[np.triu_indices(len(mod_idx), k=1)])

    # Flat (zero-variance) or NaN-containing profiles give NaN correlations;
    # drop them so a single such gene cannot turn the summary into NaN.
    within_corrs = np.asarray(within_corrs, dtype=float)
    within_corrs = within_corrs[np.isfinite(within_corrs)]

    # Random pairs (same number as within-module pairs) as the background
    random_corrs = []
    for _ in range(len(within_corrs)):
        i, j = np.random.choice(n_sample, 2, replace=False)
        random_corrs.append(sample_corr[i, j])
    random_corrs = np.asarray(random_corrs, dtype=float)
    random_corrs = random_corrs[np.isfinite(random_corrs)]

    if within_corrs.size and random_corrs.size:
        print(f"{org_id}:")
        print(f"  Within-module cofitness: mean={np.mean(within_corrs):.3f}, "
              f"median={np.median(within_corrs):.3f} (n={len(within_corrs)})")
        print(f"  Random pairs cofitness: mean={np.mean(random_corrs):.3f}, "
              f"median={np.median(random_corrs):.3f}")

## 10. Validation: Genomic Adjacency

In [None]:
# Check if module members show elevated genomic proximity (operon-like)
max_gene_gap = 3  # "adjacent" = within this many genes on the same scaffold

for org_id in pilot_ids:
    gene_file = ANNOT_DIR / f'{org_id}_genes.csv'
    member_file = MODULE_DIR / f'{org_id}_gene_membership.csv'

    if not gene_file.exists() or not member_file.exists():
        continue

    genes = pd.read_csv(gene_file)
    genes['locusId'] = genes['locusId'].astype(str)
    membership = pd.read_csv(member_file, index_col=0)
    membership.index = membership.index.astype(str)

    # Sort genes by genomic position. gene_order is global across scaffolds,
    # so the scaffold is kept too: the last gene of one scaffold and the first
    # of the next are consecutive in gene_order but not genomic neighbours.
    genes = genes.sort_values(['scaffoldId', 'begin'])
    genes['gene_order'] = range(len(genes))
    locus_to_pos = dict(zip(genes['locusId'],
                            zip(genes['gene_order'], genes['scaffoldId'])))

    # Count adjacent pairs within modules vs expected by chance
    n_adjacent_within = 0
    n_pairs_within = 0

    for mod in membership.columns:
        mod_genes = membership.index[(membership[mod] == 1).to_numpy()]
        # gene_order is unique, so tuples sort by order alone.
        positions = sorted(locus_to_pos[g] for g in mod_genes if g in locus_to_pos)
        n_pairs_within += len(positions) * (len(positions) - 1) // 2

        for i, (order_i, scaffold_i) in enumerate(positions):
            for order_j, scaffold_j in positions[i + 1:]:
                if order_j - order_i > max_gene_gap:
                    break  # sorted: every later gene is even farther away
                if scaffold_j == scaffold_i:
                    n_adjacent_within += 1

    # Expected adjacency by chance: fraction of ALL gene pairs lying within
    # max_gene_gap positions on the same scaffold. For one scaffold of N genes
    # this is 6(N-2)/(N(N-1)) ≈ 6/N with a gap of 3.
    n_total_genes = len(genes)
    scaffold_sizes = genes.groupby('scaffoldId').size()
    n_close_pairs = sum(max(n - d, 0)
                        for n in scaffold_sizes
                        for d in range(1, max_gene_gap + 1))
    n_all_pairs = n_total_genes * (n_total_genes - 1) // 2
    expected_adj_rate = n_close_pairs / n_all_pairs if n_all_pairs > 0 else 0

    if n_pairs_within > 0:
        observed_rate = n_adjacent_within / n_pairs_within
        enrichment = observed_rate / expected_adj_rate if expected_adj_rate > 0 else 0
        print(f"{org_id}: adjacency enrichment = {enrichment:.1f}× "
              f"(observed={observed_rate:.4f}, expected={expected_adj_rate:.4f})")

In [None]:
print("="*60)
print("BENCHMARKING COMPLETE")
print("="*60)
print(f"Results saved: {PRED_DIR / 'benchmark_results.csv'}")
print(f"Figure saved: {FIG_DIR / 'benchmark_comparison.png'}")