# NB06: Environment-Enriched Modules & Lineages

**Project**: Prophage Ecology Across Bacterial Phylogeny and Environmental Gradients

**Goal**: Identify specific prophage modules and TerL-defined lineages whose environmental enrichment exceeds phylogenetic expectation. Use null models with constrained permutations preserving host family composition and genome size distribution.

**Dependencies**: NB04 (`data/species_prophage_environment.tsv`), NB05 (`data/nmdc_prophage_prevalence.tsv`, `data/nmdc_module_by_environment.tsv`), NB02 (`data/terL_lineages.tsv`)

**Environment**: Local (all data cached from previous notebooks)

**Outputs**:
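- `data/enriched_modules.tsv` — module × environment enrichment tests against the phylogeny-constrained null (all tests are written; FDR-significant rows are flagged in `significant_fdr`)
- `data/enriched_lineages.tsv` — TerL lineage × environment enrichment tests (all tests are written; FDR-significant rows are flagged in `significant_fdr`)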
- `figures/module_environment_enrichment.png`
- `figures/lineage_environment_heatmap.png`

In [None]:
import sys
import os
import pandas as pd
import numpy as np
from scipy import stats
from statsmodels.stats.multitest import multipletests

sys.path.insert(0, '../src')
from prophage_utils import MODULES

os.makedirs('../data', exist_ok=True)
os.makedirs('../figures', exist_ok=True)

# --- Inputs -----------------------------------------------------------------
# NB04: merged species-level prophage + environment table (required).
species_all = pd.read_csv('../data/species_prophage_environment.tsv', sep='\t')
print(f'Species with prophage + environment: {len(species_all):,}')

# NB02: TerL lineage assignments (optional; `lineages` is None when absent).
try:
    lineages = pd.read_csv('../data/terL_lineages.tsv', sep='\t')
except FileNotFoundError:
    lineages = None
    print('NB02 outputs not found — lineage enrichment will be skipped')
else:
    print(f'TerL lineage assignments: {len(lineages):,}')
    print(f'Unique lineages: {lineages["lineage_id"].nunique():,}')

# NB05: NMDC module-by-environment table used for cross-validation (optional;
# `nmdc_module_corr` is None when absent).
try:
    nmdc_module_corr = pd.read_csv('../data/nmdc_module_by_environment.tsv', sep='\t')
except FileNotFoundError:
    nmdc_module_corr = None
    print('NB05 outputs not found — NMDC cross-validation will be skipped')
else:
    print(f'NMDC module-abiotic correlations: {len(nmdc_module_corr):,}')

# Module identifiers in a stable (sorted) order for all downstream loops.
module_ids = sorted(MODULES)

## 1. Module × Environment Enrichment with Null Models

For each module × environment combination:
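1. Compute the observed log2 odds ratio (OR, with a 0.5 Haldane correction) of module presence in that environment vs all others
2. Generate a null distribution by permuting environment labels 1000× within family × genome-size-quartile strata, which preserves:
   - Host family-level composition (labels move only within a family)
   - Genome size distribution (labels move only within a genome size quartile)
3. Z-score = (observed log2 OR − mean null log2 OR) / std null log2 OR
4. Compute an empirical two-sided p-value from the null distribution, then apply Benjamini–Hochberg FDR correction across all tests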

In [None]:
# Build the analysis set: species need a usable environment label (not missing,
# not 'other_unknown'), a family assignment, and a genome size.
has_env = species_all['primary_env'].notna() & (species_all['primary_env'] != 'other_unknown')
has_covariates = species_all['family'].notna() & species_all['genome_size_Mbp'].notna()
analysis_df = species_all.loc[has_env & has_covariates].copy()

# Genome size quartiles (Q1..Q4) are one of the two stratification axes.
analysis_df['size_quartile'] = pd.qcut(
    analysis_df['genome_size_Mbp'],
    q=4,
    labels=['Q1', 'Q2', 'Q3', 'Q4'],
)

# Permutation stratum = family × size_quartile, encoded as "<family>_<Qn>".
analysis_df['stratum'] = (
    analysis_df['family'] + '_' + analysis_df['size_quartile'].astype(str)
)

# Only environments represented by at least 30 species are tested.
env_counts = analysis_df['primary_env'].value_counts()
valid_envs = [env for env, n_species in env_counts.items() if n_species >= 30]
analysis_df = analysis_df[analysis_df['primary_env'].isin(valid_envs)]

print(f'Species for enrichment analysis: {len(analysis_df):,}')
print(f'Families: {analysis_df["family"].nunique()}')
print(f'Environments: {analysis_df["primary_env"].nunique()}')
print(f'Strata (family × size_quartile): {analysis_df["stratum"].nunique()}')
print(f'\nEnvironment distribution:')
print(analysis_df['primary_env'].value_counts())

In [None]:
def compute_odds_ratio(df, module_col, env_col, target_env):
    """Return the log2 odds ratio of module presence in ``target_env`` vs the rest.

    Builds the 2×2 table (environment membership × module presence) from
    ``df`` and applies a 0.5 Haldane correction to every cell so that empty
    cells never cause a division by zero.

    Parameters:
        df: table with one row per species.
        module_col: name of the column flagging module presence (compared to ``True``).
        env_col: name of the column holding environment labels.
        target_env: environment label contrasted against all other labels.

    Returns:
        Tuple ``(log2_or, a, b, c, d)`` where ``a`` = in env with module,
        ``b`` = in env without module, ``c`` = outside env with module and
        ``d`` = outside env without module.
    """
    inside = df[env_col].eq(target_env)
    outside = ~inside
    present = df[module_col].eq(True)
    absent = ~present

    a = (inside & present).sum()
    b = (inside & absent).sum()
    c = (outside & present).sum()
    d = (outside & absent).sum()

    # Haldane correction: add 0.5 to each cell before forming the ratio.
    numerator = (a + 0.5) * (d + 0.5)
    denominator = (b + 0.5) * (c + 0.5)
    return np.log2(numerator / denominator), a, b, c, d


def constrained_permutation(df, env_col, strata_col):
    """Permute environment labels within strata (family × genome_size_quartile).

    Labels are shuffled only among rows that share the same ``strata_col``
    value, so each stratum keeps exactly the same multiset of labels. Strata
    with a single row are left unchanged. Randomness comes from the global
    ``np.random`` state, so seed it beforehand for reproducible results.

    Parameters:
        df: table with one row per species; its index may be arbitrary
            (e.g. non-contiguous after filtering).
        env_col: name of the column holding the labels to permute.
        strata_col: name of the column defining the permutation strata.

    Returns:
        A new Series of permuted labels carrying ``df``'s index; ``df`` itself
        is not modified.
    """
    labels = df[env_col].to_numpy(copy=True)
    # `.indices` maps each stratum to integer *positions*. `.groups` would give
    # index labels, which are not valid positional indexers once the frame has
    # been filtered and its index is no longer 0..n-1.
    for positions in df.groupby(strata_col).indices.values():
        if len(positions) > 1:
            labels[positions] = np.random.permutation(labels[positions])
    return pd.Series(labels, index=df.index, name=env_col, dtype=df[env_col].dtype)


print('Functions defined for null model testing.')

In [None]:
# Run null model enrichment tests for each module × environment combination.
#
# For every pair the observed log2 odds ratio is compared with a null
# distribution obtained by permuting environment labels within strata
# (family × genome size quartile). Both the z-score and the empirical p-value
# measure the deviation from the null mean, i.e. enrichment or depletion
# relative to what the stratum composition alone produces.
N_PERMUTATIONS = 1000
np.random.seed(42)  # fixed seed so the null distributions are reproducible

enrichment_results = []

for module_id in module_ids:
    has_col = f'has_{module_id}'
    if has_col not in analysis_df.columns:
        continue

    # Only the module column is needed next to the permuted labels, so avoid
    # copying the whole table on every permutation.
    module_only = analysis_df[[has_col]]

    for target_env in valid_envs:
        # Observed odds ratio (log2) and the underlying 2×2 counts
        obs_log_or, a, b, c, d = compute_odds_ratio(
            analysis_df, has_col, 'primary_env', target_env
        )

        # Null distribution: permute environment labels within strata
        null_log_ors = np.empty(N_PERMUTATIONS)
        for perm_idx in range(N_PERMUTATIONS):
            perm_env = constrained_permutation(analysis_df, 'primary_env', 'stratum')
            perm_df = module_only.assign(perm_env=perm_env)
            null_log_ors[perm_idx] = compute_odds_ratio(
                perm_df, has_col, 'perm_env', target_env
            )[0]

        null_mean = null_log_ors.mean()
        null_std = null_log_ors.std()
        z_score = (obs_log_or - null_mean) / null_std if null_std > 0 else 0.0

        # Empirical two-sided p-value. The stratified null is generally not
        # centred on zero, so extremeness is measured as distance from the
        # null mean (consistent with the z-score) rather than from zero.
        # The +1 terms count the observed statistic itself, so p is never 0.
        obs_deviation = abs(obs_log_or - null_mean)
        n_extreme = int((np.abs(null_log_ors - null_mean) >= obs_deviation).sum())
        p_empirical = (n_extreme + 1) / (N_PERMUTATIONS + 1)

        enrichment_results.append({
            'module': module_id,
            'module_name': MODULES[module_id]['full_name'],
            'environment': target_env,
            'observed_log2_OR': obs_log_or,
            'null_mean_log2_OR': null_mean,
            'null_std_log2_OR': null_std,
            'z_score': z_score,
            'p_empirical': p_empirical,
            'n_env_module': a,
            'n_env_no_module': b,
            'n_other_module': c,
            'n_other_no_module': d,
        })

    print(f'  Completed {module_id}')

enrichment_df = pd.DataFrame(enrichment_results)
print(f'\nTotal module × environment tests: {len(enrichment_df)}')

In [None]:
# FDR correction (Benjamini–Hochberg) across all module × environment tests,
# followed by a printed overview of the significant results.
display_cols = ['module_name', 'environment', 'observed_log2_OR', 'z_score', 'p_fdr']

if len(enrichment_df) > 0:
    reject, pvals_corrected, _, _ = multipletests(
        enrichment_df['p_empirical'], method='fdr_bh'
    )
    enrichment_df['p_fdr'] = pvals_corrected
    enrichment_df['significant_fdr'] = reject

    significant = enrichment_df[enrichment_df['significant_fdr'] == True]

    # NOTE: `sig_enrichments` holds every FDR-significant test (both positive
    # and negative z-scores), ordered from most enriched to most depleted.
    sig_enrichments = significant.sort_values('z_score', ascending=False)

    # Significant depletions: the negative-z subset, most depleted first
    sig_depletions = significant[significant['z_score'] < 0].sort_values('z_score')
else:
    # No tests were run (e.g. no has_<module> columns were present), so the
    # result columns do not exist; keep empty frames instead of raising KeyError.
    sig_enrichments = enrichment_df.copy()
    sig_depletions = enrichment_df.copy()

print(f'Significant module × environment enrichments (FDR < 0.05): {len(sig_enrichments)}')
if len(sig_enrichments) > 0:
    print('\nTop enrichments (highest z-score):')
    print(sig_enrichments[display_cols].head(20).to_string(index=False))

if len(sig_depletions) > 0:
    print('\nTop depletions (most negative z-score):')
    print(sig_depletions[display_cols].head(10).to_string(index=False))

## 2. Lineage × Environment Enrichment

Test which TerL-defined lineages are enriched in specific environments beyond phylogenetic expectation.

In [None]:
# Lineage × environment enrichment against the same stratified null model
# used for modules (environment labels permuted within family × size quartile).
N_LINEAGE_PERMUTATIONS = 500   # fewer permutations for lineages (more tests)
MAX_LINEAGES = 50              # cap on tested lineages to keep runtime manageable
MIN_SPECIES_PER_LINEAGE = 10   # minimum number of distinct species per lineage

lineage_enrichment_results = []

if lineages is not None:
    # Merge lineage data with species environment info
    lineage_species = lineages.merge(
        analysis_df[['gtdb_species_clade_id', 'primary_env', 'family',
                      'genome_size_Mbp', 'stratum']],
        on='gtdb_species_clade_id', how='inner'
    )

    # Count *distinct species* per lineage. Counting merged rows would let a
    # species that appears several times for the same lineage inflate the
    # lineage's apparent representation.
    lineage_counts = (
        lineage_species.groupby('lineage_id')['gtdb_species_clade_id']
        .nunique()
        .sort_values(ascending=False, kind='stable')
    )
    major_lineages = lineage_counts[
        lineage_counts >= MIN_SPECIES_PER_LINEAGE
    ].index.tolist()
    print(f'Lineages with >= {MIN_SPECIES_PER_LINEAGE} species in analysis set: {len(major_lineages)}')

    # One row per (species, lineage) pair → binary lineage presence per species
    species_lineage = lineage_species.groupby(
        ['gtdb_species_clade_id', 'lineage_id']
    ).size().reset_index(name='count')

    tested_lineages = major_lineages[:MAX_LINEAGES]
    for lineage_pos, lineage_id in enumerate(tested_lineages):
        # Which species have this lineage?
        species_with_lineage = set(
            species_lineage.loc[
                species_lineage['lineage_id'] == lineage_id, 'gtdb_species_clade_id'
            ]
        )
        analysis_df_copy = analysis_df.copy()
        analysis_df_copy['has_lineage'] = analysis_df_copy['gtdb_species_clade_id'].isin(species_with_lineage)
        # Slim frame reused for every permutation instead of a full-table copy
        lineage_only = analysis_df_copy[['has_lineage']]

        for target_env in valid_envs:
            obs_log_or, a, b, c, d = compute_odds_ratio(
                analysis_df_copy, 'has_lineage', 'primary_env', target_env
            )

            # Constrained permutation null
            null_log_ors = np.empty(N_LINEAGE_PERMUTATIONS)
            for perm_idx in range(N_LINEAGE_PERMUTATIONS):
                perm_env = constrained_permutation(analysis_df_copy, 'primary_env', 'stratum')
                perm_df = lineage_only.assign(perm_env=perm_env)
                null_log_ors[perm_idx] = compute_odds_ratio(
                    perm_df, 'has_lineage', 'perm_env', target_env
                )[0]

            null_mean = null_log_ors.mean()
            null_std = null_log_ors.std()
            z_score = (obs_log_or - null_mean) / null_std if null_std > 0 else 0.0

            # Two-sided empirical p-value measured as deviation from the null
            # mean (the stratified null is generally not centred on zero).
            obs_deviation = abs(obs_log_or - null_mean)
            n_extreme = int((np.abs(null_log_ors - null_mean) >= obs_deviation).sum())
            p_empirical = (n_extreme + 1) / (N_LINEAGE_PERMUTATIONS + 1)

            lineage_enrichment_results.append({
                'lineage_id': lineage_id,
                'environment': target_env,
                'n_species_with_lineage': len(species_with_lineage),
                'observed_log2_OR': obs_log_or,
                'z_score': z_score,
                'p_empirical': p_empirical,
            })

        # Progress report after the 1st, 11th, 21st, ... lineage
        if lineage_pos % 10 == 0:
            print(f'  Completed {lineage_pos + 1}/{len(tested_lineages)} lineages')

lineage_enrich_df = pd.DataFrame(lineage_enrichment_results)

if len(lineage_enrich_df) > 0:
    # Benjamini–Hochberg FDR across all lineage × environment tests
    reject, pvals_corrected, _, _ = multipletests(
        lineage_enrich_df['p_empirical'], method='fdr_bh'
    )
    lineage_enrich_df['p_fdr'] = pvals_corrected
    lineage_enrich_df['significant_fdr'] = reject

    sig_lineages = lineage_enrich_df[lineage_enrich_df['significant_fdr'] == True]
    print(f'\nSignificant lineage × environment enrichments (FDR < 0.05): {len(sig_lineages)}')
    if len(sig_lineages) > 0:
        print(sig_lineages.sort_values('z_score', ascending=False).head(15).to_string(index=False))
else:
    print('No lineage enrichment tests were performed')

## 3. Specialist vs Generalist Lineages

Classify lineages as environment-specialists (concentrated in 1-2 environments) vs generalists (broadly distributed).

In [None]:
# Classify lineages by environmental breadth (Shannon entropy of the
# environment distribution of the species carrying each lineage).
if lineages is not None and len(lineage_species) > 0:
    # Work at species level: one row per (lineage, species). Otherwise a
    # species appearing several times for the same lineage would be counted
    # repeatedly in the environment distribution, while the size filter below
    # counts distinct species.
    species_level = lineage_species.drop_duplicates(
        subset=['lineage_id', 'gtdb_species_clade_id']
    )

    # Fraction of each lineage's species found in each environment
    lineage_env_dist = species_level.groupby('lineage_id')['primary_env'].value_counts(
        normalize=True
    ).unstack(fill_value=0)
    env_columns = list(lineage_env_dist.columns)

    # Shannon entropy (bits) as environmental breadth metric
    def shannon_entropy(row):
        probs = row[row > 0].values
        # "+ 0.0" turns the -0.0 produced for single-environment lineages into 0.0
        return -np.sum(probs * np.log2(probs)) + 0.0

    lineage_env_dist['shannon'] = lineage_env_dist[env_columns].apply(shannon_entropy, axis=1)
    lineage_env_dist['n_species'] = species_level.groupby('lineage_id')['gtdb_species_clade_id'].nunique()
    lineage_env_dist['dominant_env'] = lineage_env_dist[env_columns].idxmax(axis=1)
    lineage_env_dist['dominant_pct'] = lineage_env_dist[env_columns].max(axis=1) * 100

    # Filter to lineages with >= 5 species for meaningful breadth
    lineage_breadth = lineage_env_dist[
        lineage_env_dist['n_species'] >= 5
    ][['shannon', 'n_species', 'dominant_env', 'dominant_pct']].reset_index()

    # Classify: specialist (shannon < 1.0 or dominant_pct > 80%) vs generalist
    lineage_breadth['category'] = 'generalist'
    lineage_breadth.loc[
        (lineage_breadth['shannon'] < 1.0) | (lineage_breadth['dominant_pct'] > 80),
        'category'
    ] = 'specialist'

    print(f'Lineages with >= 5 species: {len(lineage_breadth):,}')
    print(f'\nSpecialist vs Generalist:')
    print(lineage_breadth['category'].value_counts())

    report_cols = ['lineage_id', 'n_species', 'dominant_env', 'dominant_pct', 'shannon']

    print(f'\nTop specialist lineages (lowest Shannon entropy):')
    top_specialists = lineage_breadth.nsmallest(10, 'shannon')
    print(top_specialists[report_cols].to_string(index=False))

    print(f'\nTop generalist lineages (highest Shannon entropy):')
    top_generalists = lineage_breadth.nlargest(10, 'shannon')
    print(top_generalists[report_cols].to_string(index=False))
else:
    lineage_breadth = None
    print('No lineage data available for specialist/generalist classification')

## 4. Cross-Validation with NMDC Data

Check whether pangenome-enriched module × environment patterns match the NMDC signal.

In [None]:
# Qualitative cross-validation: list the significant pangenome enrichments and
# the significant NMDC module–abiotic correlations, then compare which modules
# appear in both sets.
if nmdc_module_corr is not None and len(enrichment_df) > 0:
    # The pangenome tells us which modules are enriched in which environments
    # NMDC tells us which modules correlate with which abiotic variables
    # Cross-validation: do soil-enriched modules (pangenome) also correlate with
    # soil-related abiotic variables (NMDC) like pH, organic carbon, etc.?

    print('=== Cross-validation: Pangenome enrichment vs NMDC abiotic correlations ===')

    # Pangenome: significant module × environment enrichments
    print('\nPangenome: significant enrichments')
    if len(sig_enrichments) > 0:
        for _, row in sig_enrichments.head(10).iterrows():
            print(f'  {row["module_name"]} in {row["environment"]}: '
                  f'log2(OR)={row["observed_log2_OR"]:.2f}, z={row["z_score"]:.2f}')
    else:
        print('  No significant enrichments found')

    # NMDC: significant module × abiotic correlations.
    # DataFrame.get() returns the plain default when the column is missing, and
    # indexing the frame with the scalar `False` raises KeyError, so test for
    # the column explicitly and fall back to an empty selection.
    if 'significant_fdr' in nmdc_module_corr.columns:
        nmdc_sig = nmdc_module_corr[nmdc_module_corr['significant_fdr'] == True]
    else:
        nmdc_sig = nmdc_module_corr.iloc[0:0]
    print(f'\nNMDC: significant module-abiotic correlations: {len(nmdc_sig)}')
    if len(nmdc_sig) > 0:
        for _, row in nmdc_sig.head(10).iterrows():
            # Strip the column-name boilerplate to leave the bare variable name
            clean = row['abiotic_variable'].replace('annotations_', '').replace('_has_numeric_value', '')
            print(f'  {row["module_name"]} ~ {clean}: rho={row["spearman_rho"]:.3f}')

    # Module-level concordance: do the same modules show up in both?
    if len(sig_enrichments) > 0 and len(nmdc_sig) > 0:
        pangenome_modules = set(sig_enrichments['module'].unique())
        nmdc_modules = set(nmdc_sig['module'].unique())
        overlap = pangenome_modules & nmdc_modules
        print(f'\nModules significant in both pangenome and NMDC: {overlap}')
        print(f'Pangenome-only: {pangenome_modules - nmdc_modules}')
        print(f'NMDC-only: {nmdc_modules - pangenome_modules}')
else:
    print('NMDC cross-validation data not available')

## 5. Save Outputs

In [None]:
# Module × environment results are always written, even when the table is empty.
enrichment_df.to_csv('../data/enriched_modules.tsv', sep='\t', index=False)
print(f'Saved data/enriched_modules.tsv: {len(enrichment_df):,} rows')

# Lineage × environment results are written only when at least one test ran.
if len(lineage_enrich_df) == 0:
    print('No lineage enrichment data to save')
else:
    lineage_enrich_df.to_csv('../data/enriched_lineages.tsv', sep='\t', index=False)
    print(f'Saved data/enriched_lineages.tsv: {len(lineage_enrich_df):,} rows')

## 6. Figures

In [None]:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns

# Figure 1: two side-by-side module × environment heatmaps
#   left  — z-score against the constrained null
#   right — observed log2 odds ratio, with '*' marking FDR-significant cells
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
have_results = len(enrichment_df) > 0

# Panel A: Z-scores
ax = axes[0]
if have_results:
    pivot_z = enrichment_df.pivot_table(
        index='module_name', columns='environment', values='z_score'
    )
    sns.heatmap(
        pivot_z,
        cmap='RdBu_r',
        center=0,
        annot=True,
        fmt='.1f',
        ax=ax,
        cbar_kws={'label': 'Z-score (vs null model)'},
    )
    ax.set_title('Module × Environment Enrichment\n(Z-score vs constrained null)')
    ax.set_ylabel('')

# Panel B: log2 OR (observed)
ax = axes[1]
if have_results:
    pivot_or = enrichment_df.pivot_table(
        index='module_name', columns='environment', values='observed_log2_OR'
    )
    # Cell annotations: rounded value as text, with '*' appended when the
    # corresponding test is FDR-significant.
    annot_matrix = pivot_or.round(1).astype(str)
    for idx, row in enrichment_df.iterrows():
        if row.get('significant_fdr', False):
            annot_matrix.loc[row['module_name'], row['environment']] += '*'

    sns.heatmap(
        pivot_or,
        cmap='RdBu_r',
        center=0,
        annot=annot_matrix.values,
        fmt='',
        ax=ax,
        cbar_kws={'label': 'log2(Odds Ratio)'},
    )
    ax.set_title('Module × Environment Observed Enrichment\n(* = FDR < 0.05)')
    ax.set_ylabel('')

plt.tight_layout()
plt.savefig('../figures/module_environment_enrichment.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved figures/module_environment_enrichment.png')

In [None]:
# Figure 2: lineage environmental breadth (left) and the dominant environment
# of specialist lineages (right). Skipped when no breadth table exists.
if lineage_breadth is None or len(lineage_breadth) == 0:
    print('No lineage data for figure')
else:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    specialist_mask = lineage_breadth['category'] == 'specialist'

    # Panel A: Shannon entropy distribution, specialists drawn first
    ax = axes[0]
    hist_layers = (
        (specialist_mask, '#E91E63', 'Specialist'),
        (~specialist_mask, '#2196F3', 'Generalist'),
    )
    for layer_mask, layer_color, layer_label in hist_layers:
        ax.hist(
            lineage_breadth.loc[layer_mask, 'shannon'],
            bins=30,
            alpha=0.7,
            color=layer_color,
            label=layer_label,
        )
    ax.set_xlabel('Shannon entropy (environmental breadth)')
    ax.set_ylabel('Number of lineages')
    ax.set_title('Lineage Environmental Breadth')
    ax.legend()

    # Panel B: Dominant environment distribution for specialists
    ax = axes[1]
    specialists = lineage_breadth[specialist_mask]
    if len(specialists) > 0:
        env_counts = specialists['dominant_env'].value_counts()
        ax.barh(env_counts.index, env_counts.values, color='#E91E63', alpha=0.8)
        ax.set_xlabel('Number of specialist lineages')
        ax.set_title('Dominant Environment of Specialist Lineages')

    plt.tight_layout()
    plt.savefig('../figures/lineage_environment_heatmap.png', dpi=150, bbox_inches='tight')
    plt.show()
    print('Saved figures/lineage_environment_heatmap.png')

In [None]:
# Summary: headline counts and the list of files written by this notebook
banner = '=' * 60
print(banner)
print('NB06 SUMMARY')
print(banner)

print(f'Species analyzed: {len(analysis_df):,}')
print(f'Module × environment tests: {len(enrichment_df)}')
if len(enrichment_df) > 0:
    n_sig_mod = len(sig_enrichments)
else:
    n_sig_mod = 0
print(f'Significant module enrichments (FDR<0.05): {n_sig_mod}')

have_lineage_tests = len(lineage_enrich_df) > 0
if have_lineage_tests:
    if 'significant_fdr' in lineage_enrich_df:
        n_sig_lin = lineage_enrich_df['significant_fdr'].sum()
    else:
        n_sig_lin = 0
    print(f'Lineage × environment tests: {len(lineage_enrich_df)}')
    print(f'Significant lineage enrichments (FDR<0.05): {n_sig_lin}')

have_breadth = lineage_breadth is not None
if have_breadth:
    n_spec = (lineage_breadth['category'] == 'specialist').sum()
    n_gen = (lineage_breadth['category'] == 'generalist').sum()
    print(f'Lineage classification: {n_spec} specialists, {n_gen} generalists')

print(f'\nFiles saved:')
print(f'  data/enriched_modules.tsv')
if have_lineage_tests:
    print(f'  data/enriched_lineages.tsv')
print(f'  figures/module_environment_enrichment.png')
if have_breadth:
    print(f'  figures/lineage_environment_heatmap.png')