# NB03: Module Co-occurrence Validation

**Project**: Prophage Ecology Across Bacterial Phylogeny and Environmental Gradients

**Goal**: Test H1c — do genes within each predefined prophage module co-occur across genomes significantly more than random gene groupings of equal size? Also test contig co-localization.

**Dependencies**: NB01 outputs (`data/prophage_gene_clusters.tsv`, `data/species_module_summary.tsv`)

**Environment**: Requires BERDL JupyterHub (Spark SQL) for genome-level extraction

**Outputs**:
- `data/module_cooccurrence_stats.tsv` — co-occurrence statistics per module per species
- `data/contig_colocation.tsv` — contig co-localization statistics
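- `figures/module_cooccurrence_heatmap.png` — two-panel box plot of per-module co-occurrence z-scores and contig co-localization fractions (a box plot, despite the filename)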

In [None]:
import sys
import os
import pandas as pd
import numpy as np
from scipy import stats
from itertools import combinations

# Spark session used by the later cells for genome-level queries.
# NOTE(review): get_spark_session is not imported in this notebook; presumably
# the BERDL JupyterHub kernel provides it — confirm.
spark = get_spark_session()

# Put the project's src/ directory on the import path before importing the
# module definitions from it.
sys.path.insert(0, '../src')
from prophage_utils import MODULES

# Output directories must exist before results and figures are written.
for _out_dir in ('../data', '../figures'):
    os.makedirs(_out_dir, exist_ok=True)

# Load NB01 outputs (both are tab-separated tables).
prophage_clusters, species_summary = (
    pd.read_csv(_tsv_path, sep='\t')
    for _tsv_path in (
        '../data/prophage_gene_clusters.tsv',
        '../data/species_module_summary.tsv',
    )
)

print(f'Prophage clusters: {len(prophage_clusters):,}')
print(f'Species with prophage: {len(species_summary):,}')

## 1. Select Representative Species

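Choose ~50 species spanning phylogenetic diversity, each with >=20 genomes, >=3 prophage modules present, and >=10 prophage gene clusters. Species are sampled per phylum in proportion to the number of candidates in that phylum (at least one per phylum).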

In [None]:
# Get taxonomy for species selection
taxonomy = spark.sql("""
    SELECT DISTINCT g.gtdb_species_clade_id, t.phylum, t.class, t.order, t.family
    FROM kbase_ke_pangenome.genome g
    JOIN kbase_ke_pangenome.gtdb_taxonomy_r214v1 t ON g.gtdb_taxonomy_id = t.gtdb_taxonomy_id
""").toPandas()

# Keep one taxonomy row per species: if genomes of one clade ever carried
# differing taxonomy rows, the merge below would duplicate that species and it
# would be analysed (and counted) more than once.
taxonomy = taxonomy.drop_duplicates(subset='gtdb_species_clade_id')

# Merge with species summary
candidates = species_summary.merge(taxonomy, on='gtdb_species_clade_id', how='left')

# Filter: >=20 genomes, >=3 modules present, >=10 prophage clusters
candidates = candidates[
    (candidates['no_genomes'] >= 20) &
    (candidates['n_modules_present'] >= 3) &
    (candidates['n_prophage_clusters'] >= 10)
]

print(f'Candidate species (>=20 genomes, >=3 modules, >=10 clusters): {len(candidates):,}')

# Stratified sampling: pick ~50 species across phyla.
# Each phylum gets a quota proportional to its share of the candidates (at
# least 1) and contributes its species with the most prophage clusters.
# An explicit loop + concat is used instead of
# groupby.apply(..., include_groups=False).reset_index(drop=True), because that
# combination discards the 'phylum' column that the summary below reads.
target_n = 50
per_phylum_picks = []
for _, phylum_group in candidates.groupby('phylum'):
    quota = max(1, int(target_n * len(phylum_group) / len(candidates)))
    per_phylum_picks.append(phylum_group.nlargest(quota, 'n_prophage_clusters'))

# pd.concat rejects an empty list, so fall back to an empty frame with the
# same columns when there are no candidates with a known phylum.
selected = pd.concat(per_phylum_picks) if per_phylum_picks else candidates.iloc[0:0]

# If we have too few, top up with the remaining candidates that have the most
# prophage clusters (this can also bring in species whose phylum is missing,
# which groupby skips).
if len(selected) < target_n:
    remaining = candidates[~candidates['gtdb_species_clade_id'].isin(selected['gtdb_species_clade_id'])]
    extra = remaining.nlargest(target_n - len(selected), 'n_prophage_clusters')
    selected = pd.concat([selected, extra])

# Cap at 50 and renumber rows 0..n-1 so positional labels are sequential
selected = selected.head(target_n).reset_index(drop=True)

print(f'\nSelected species: {len(selected)}')
print(f'Phyla represented: {selected["phylum"].nunique()}')
print(f'Genome range: {selected["no_genomes"].min()}-{selected["no_genomes"].max()}')
print(f'Module range: {selected["n_modules_present"].min()}-{selected["n_modules_present"].max()}')

## 2. Extract Genome × Cluster Presence Matrices

For each selected species, extract which prophage gene clusters are present in which genomes. This requires joining the billion-row `gene` and `gene_genecluster_junction` tables, filtered per species.

In [None]:
# For each species, extract genome-level prophage cluster presence
# Strategy: get genome IDs for species, get prophage cluster IDs for species,
# then query the junction table


def _mean_pairwise_jaccard(binary_matrix):
    """Compute mean pairwise Jaccard similarity between columns of a binary matrix.

    Defined before the per-species loop because that loop calls it; defining it
    after the loop raises NameError when the cell runs on a fresh kernel.

    Args:
        binary_matrix: Object exposing ``.values`` as a 2-D array of 0/1
            integers; each column is one cluster's presence vector.

    Returns:
        Mean Jaccard index over all unordered column pairs. A pair whose union
        is empty contributes 0. Returns 0.0 when there are fewer than two
        columns.
    """
    cols = binary_matrix.values.T  # each row is a cluster's presence vector
    if cols.shape[0] < 2:
        return 0.0
    jaccards = []
    for vec_a, vec_b in combinations(cols, 2):
        union = np.sum(vec_a | vec_b)
        jaccards.append(np.sum(vec_a & vec_b) / union if union > 0 else 0)
    return np.mean(jaccards)


all_cooccurrence_results = []
all_colocation_results = []

n_selected = len(selected)
# enumerate gives a true 1..N progress counter; the DataFrame index labels are
# not guaranteed to be sequential after the selection/concat steps.
for position, (_, species_row) in enumerate(selected.iterrows(), start=1):
    species_id = species_row['gtdb_species_clade_id']

    print(f'\n[{position}/{n_selected}] {species_id} ({species_row["no_genomes"]} genomes, '
          f'{species_row["n_prophage_clusters"]} prophage clusters)')

    # Get prophage cluster IDs for this species
    sp_clusters = prophage_clusters[
        prophage_clusters['gtdb_species_clade_id'] == species_id
    ]
    if len(sp_clusters) < 5:
        print(f'  Skipping: too few prophage clusters ({len(sp_clusters)})')
        continue

    prophage_cluster_ids = sp_clusters['gene_cluster_id'].tolist()

    # Register as temp view for BROADCAST hint
    spark.createDataFrame(
        [(c,) for c in prophage_cluster_ids], ['gene_cluster_id']
    ).createOrReplaceTempView('target_clusters')

    # Query genome × cluster presence via gene + junction tables.
    # A failed query skips this species instead of aborting the whole run.
    try:
        presence_df = spark.sql(f"""
            SELECT /*+ BROADCAST(tc) */
                g.genome_id, j.gene_cluster_id, g.gene_id
            FROM kbase_ke_pangenome.gene g
            JOIN kbase_ke_pangenome.gene_genecluster_junction j ON g.gene_id = j.gene_id
            JOIN target_clusters tc ON j.gene_cluster_id = tc.gene_cluster_id
            WHERE g.genome_id IN (
                SELECT genome_id FROM kbase_ke_pangenome.genome
                WHERE gtdb_species_clade_id = '{species_id}'
            )
        """).toPandas()
    except Exception as e:
        print(f'  ERROR: {e}')
        continue

    if len(presence_df) == 0:
        print(f'  No presence data found')
        continue

    n_genomes = presence_df['genome_id'].nunique()
    n_clusters_found = presence_df['gene_cluster_id'].nunique()
    print(f'  Found {len(presence_df):,} gene-cluster pairs across {n_genomes} genomes, {n_clusters_found} clusters')

    # Build binary presence/absence matrix: genomes × prophage clusters
    presence_binary = presence_df.groupby(['genome_id', 'gene_cluster_id']).size().unstack(fill_value=0)
    presence_binary = (presence_binary > 0).astype(int)

    # Pool for the null draws: every prophage cluster of this species that
    # appears in the matrix. It does not depend on the module, so build it once.
    all_available = [c for c in prophage_cluster_ids if c in presence_binary.columns]

    # === MODULE CO-OCCURRENCE TEST ===
    # For each module, compute mean pairwise Jaccard between member clusters
    # Compare to null: random gene sets of same size

    for module_id in MODULES.keys():
        # Get clusters in this module for this species
        module_mask = sp_clusters[f'has_{module_id}'] == True
        module_clusters = sp_clusters.loc[module_mask, 'gene_cluster_id'].tolist()
        # Only those present in the matrix
        module_clusters = [c for c in module_clusters if c in presence_binary.columns]

        if len(module_clusters) < 2:
            continue

        # Observed mean Jaccard
        observed_jaccard = _mean_pairwise_jaccard(presence_binary[module_clusters])

        # Null: 500 random gene sets of same size from all prophage clusters in species
        null_jaccards = []
        for _ in range(500):
            random_set = np.random.choice(all_available, size=len(module_clusters), replace=False)
            null_jaccards.append(_mean_pairwise_jaccard(presence_binary[random_set]))

        null_mean = np.mean(null_jaccards)
        null_std = np.std(null_jaccards)
        # A degenerate null (zero spread) yields z = 0 rather than a division error
        z_score = (observed_jaccard - null_mean) / null_std if null_std > 0 else 0
        p_value = 1 - stats.norm.cdf(z_score)  # one-sided: is observed > null?

        all_cooccurrence_results.append({
            'gtdb_species_clade_id': species_id,
            'module': module_id,
            'n_clusters': len(module_clusters),
            'n_genomes': n_genomes,
            'observed_jaccard': observed_jaccard,
            'null_mean': null_mean,
            'null_std': null_std,
            'z_score': z_score,
            'p_value': p_value,
        })

    # === CONTIG CO-LOCALIZATION TEST ===
    # Parse gene_id to extract contig info (format: CONTIG_GENENUM):
    # everything before the last underscore; IDs without one are kept whole.
    presence_df['contig'] = presence_df['gene_id'].str.rsplit('_', n=1).str[0]

    for module_id in MODULES.keys():
        module_mask = sp_clusters[f'has_{module_id}'] == True
        module_clusters = set(sp_clusters.loc[module_mask, 'gene_cluster_id'].tolist())

        if len(module_clusters) < 2:
            continue

        # For each genome, check if module genes are on the same contig
        module_genes = presence_df[presence_df['gene_cluster_id'].isin(module_clusters)]

        n_genomes_with_module = 0
        n_colocalized = 0
        for _, genome_genes in module_genes.groupby('genome_id'):
            # Only genomes carrying at least two distinct module clusters count
            if genome_genes['gene_cluster_id'].nunique() < 2:
                continue
            n_genomes_with_module += 1
            # Co-localized = at least two distinct module clusters share one
            # contig (this is not a majority criterion)
            contig_counts = genome_genes.groupby('contig')['gene_cluster_id'].nunique()
            if contig_counts.max() >= 2:
                n_colocalized += 1

        coloc_frac = n_colocalized / n_genomes_with_module if n_genomes_with_module > 0 else 0
        all_colocation_results.append({
            'gtdb_species_clade_id': species_id,
            'module': module_id,
            'n_genomes_with_module': n_genomes_with_module,
            'n_colocalized': n_colocalized,
            'colocation_fraction': coloc_frac,
        })

## 3. Analyze Co-occurrence Results

In [None]:
# Build the result tables with explicit columns so that they keep their schema
# even when no test produced a row; a column-less empty frame would raise
# KeyError in the per-module lookups below.
cooc_df = pd.DataFrame(all_cooccurrence_results, columns=[
    'gtdb_species_clade_id', 'module', 'n_clusters', 'n_genomes',
    'observed_jaccard', 'null_mean', 'null_std', 'z_score', 'p_value',
])
coloc_df = pd.DataFrame(all_colocation_results, columns=[
    'gtdb_species_clade_id', 'module', 'n_genomes_with_module',
    'n_colocalized', 'colocation_fraction',
])

print(f'Co-occurrence tests: {len(cooc_df):,}')
print(f'Co-localization tests: {len(coloc_df):,}')

# Summarize by module (modules with no tests are skipped)
print('\n--- Co-occurrence Summary by Module ---')
for module_id in sorted(MODULES.keys()):
    mod_data = cooc_df[cooc_df['module'] == module_id]
    if len(mod_data) == 0:
        continue
    n_sig = (mod_data['p_value'] < 0.05).sum()
    mean_z = mod_data['z_score'].mean()
    mean_obs = mod_data['observed_jaccard'].mean()
    mean_null = mod_data['null_mean'].mean()
    print(f'  {module_id}: {n_sig}/{len(mod_data)} significant (p<0.05), '
          f'mean z={mean_z:.2f}, obs Jaccard={mean_obs:.3f} vs null={mean_null:.3f}')

print('\n--- Contig Co-localization Summary by Module ---')
for module_id in sorted(MODULES.keys()):
    mod_data = coloc_df[coloc_df['module'] == module_id]
    if len(mod_data) == 0:
        continue
    mean_frac = mod_data['colocation_fraction'].mean()
    print(f'  {module_id}: mean co-localization fraction = {mean_frac:.3f} '
          f'({len(mod_data)} species tested)')

In [None]:
# Write both result tables as tab-separated files (no index column), then
# report how many rows each one holds.
_result_tables = (
    (cooc_df, '../data/module_cooccurrence_stats.tsv'),
    (coloc_df, '../data/contig_colocation.tsv'),
)
for _table, _tsv_path in _result_tables:
    _table.to_csv(_tsv_path, sep='\t', index=False)

print(f'Saved data/module_cooccurrence_stats.tsv: {len(cooc_df):,} rows')
print(f'Saved data/contig_colocation.tsv: {len(coloc_df):,} rows')

In [None]:
# Visualization
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Module id -> display name, built once and reused by both panels.
module_order = sorted(MODULES.keys())
module_full_names = {m: MODULES[m]['full_name'] for m in module_order}
# Display names in module-id order; passed as `order=` so both panels list the
# modules in the same sequence.
module_names = [module_full_names[m] for m in module_order]

# Panel 1: Z-scores by module (box plot)
plot_data = cooc_df[cooc_df['module'].isin(module_order)].copy()
plot_data['module_name'] = plot_data['module'].map(module_full_names)

sns.boxplot(data=plot_data, y='module_name', x='z_score', ax=axes[0], orient='h',
            color='steelblue', order=module_names)
# The notebook's p-value is one-sided (1 - norm.cdf(z)), so p<0.05 corresponds
# to z > norm.ppf(0.95) ~= 1.645, not the two-sided 1.96.
z_crit = stats.norm.ppf(0.95)
axes[0].axvline(x=z_crit, color='red', linestyle='--', alpha=0.5,
                label=f'z={z_crit:.2f} (one-sided p<0.05)')
axes[0].set_xlabel('Z-score (observed vs null co-occurrence)')
axes[0].set_ylabel('')
axes[0].set_title('Module Co-occurrence Significance')
axes[0].legend()

# Panel 2: Co-localization fraction by module
coloc_plot = coloc_df[coloc_df['module'].isin(module_order)].copy()
coloc_plot['module_name'] = coloc_plot['module'].map(module_full_names)

sns.boxplot(data=coloc_plot, y='module_name', x='colocation_fraction', ax=axes[1], orient='h',
            color='darkorange', order=module_names)
axes[1].set_xlabel('Contig Co-localization Fraction')
axes[1].set_ylabel('')
axes[1].set_title('Module Genes on Same Contig')

plt.tight_layout()
plt.savefig('../figures/module_cooccurrence_heatmap.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved figures/module_cooccurrence_heatmap.png')

In [None]:
# Summary
# All figures are guarded for the case where no tests were run: an empty
# result frame may lack the expected columns, and the percentage would
# otherwise divide by zero.
n_tests = len(cooc_df)
n_species_tested = cooc_df['gtdb_species_clade_id'].nunique() if n_tests else 0
n_sig = (cooc_df['p_value'] < 0.05).sum() if n_tests else 0
pct_sig = n_sig / n_tests * 100 if n_tests else 0.0
mean_coloc = coloc_df['colocation_fraction'].mean() if len(coloc_df) else float('nan')

print('='*60)
print('NB03 SUMMARY')
print('='*60)
print(f'Species tested: {n_species_tested}')
print(f'Module × species tests: {n_tests}')
print(f'Significant co-occurrence (p<0.05): {n_sig}/{n_tests} ({pct_sig:.1f}%)')
print(f'Mean co-localization fraction: {mean_coloc:.3f}')