# NB01: Prophage Gene Discovery

**Project**: Prophage Ecology Across Bacterial Phylogeny and Environmental Gradients

**Goal**: Identify all prophage-associated gene clusters in the BERDL pangenome using eggNOG annotations, classify them into 7 operationally defined modules (A-G), and extract terminase large subunit (TerL) sequences for lineage clustering.

**Environment**: Requires BERDL JupyterHub (Spark SQL)

**Outputs**:
- `data/prophage_gene_clusters.tsv` — all prophage-associated gene clusters with module assignments
- `data/terL_sequences.fasta` — TerL protein sequences for lineage clustering
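- `data/species_module_summary.tsv` — per-species prophage module presence/absence and cluster counts
- `figures/prophage_module_discovery.png` — gene clusters per module, species per module, and module richness per species
- `figures/prophage_conservation_by_module.png` — core/accessory/singleton fractions of prophage gene clusters by module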

In [None]:
import sys
import os
import pandas as pd
import numpy as np

# Obtain the Spark session. get_spark_session() is supplied by the BERDL
# JupyterHub environment, so no import is needed for it.
spark = get_spark_session()

# The project helpers live in ../src; put that directory first on the import
# path before importing from prophage_utils.
sys.path.insert(0, '../src')
from prophage_utils import MODULES, classify_gene_to_module, is_terL, build_spark_where_clause

# Create the output directories used by later cells (no-op if they exist).
for output_dir in ('../data', '../figures'):
    os.makedirs(output_dir, exist_ok=True)

print(f'Spark session active. Modules defined: {list(MODULES.keys())}')

## 1. Explore Prophage-Related Annotations

Before running the full extraction, let's check what prophage-related annotations exist in the eggNOG table and estimate hit counts.

In [None]:
# Count annotation rows whose Description mentions each prophage-related keyword.
keywords = [
    'terminase', 'capsid', 'portal protein', 'holin', 'endolysin',
    'integrase', 'tail sheath', 'tail tube', 'tape measure', 'baseplate',
    'excisionase', 'lysozyme', 'repressor', 'anti-crispr', 'anti-restriction',
    'phage', 'prophage', 'tail fiber', 'tail spike'
]

# Build one conditional COUNT per keyword so every keyword is answered by a
# single query over the annotation table instead of one query per keyword.
# COUNT skips the NULL a non-matching CASE yields, so each column equals the
# row count that `WHERE LOWER(Description) LIKE '%kw%'` would return.
count_columns = ',\n           '.join(
    f"COUNT(CASE WHEN LOWER(Description) LIKE '%{kw}%' THEN 1 END) AS kw_{i}"
    for i, kw in enumerate(keywords)
)
count_df = spark.sql(f"""
    SELECT {count_columns}
    FROM kbase_ke_pangenome.eggnog_mapper_annotations
""").toPandas()

# Unpack the single result row into one record per keyword.
keyword_counts = []
for i, kw in enumerate(keywords):
    n_hits = int(count_df[f'kw_{i}'].iloc[0])
    keyword_counts.append({'keyword': kw, 'count': n_hits})
    print(f'  {kw}: {n_hits:,}')

# Show the keywords again, most frequent first.
kw_df = pd.DataFrame(keyword_counts).sort_values('count', ascending=False)
print(f'\nTotal unique keywords checked: {len(keywords)}')
print(kw_df.to_string(index=False))

In [None]:
# Count annotation rows whose PFAMs field mentions each phage-related domain name.
pfam_keywords = [
    'Terminase', 'Phage_portal', 'HK97', 'Phage_cap',
    'Phage_holin', 'Phage_integrase', 'Phage_tail',
    'Phage_sheath', 'Phage_fiber', 'Phage_lysozyme',
    'XRE_N', 'ArdA', 'Phage_int_SAM'
]

# Build one conditional COUNT per domain name so all of them are answered by a
# single query over the annotation table instead of one query per name.
# COUNT skips the NULL a non-matching CASE yields, so each column equals the
# row count that `WHERE PFAMs LIKE '%pf%'` would return.
# NOTE(review): in SQL LIKE, `_` matches any single character, so a pattern
# such as 'Phage_portal' is slightly broader than a literal substring match.
# The original per-keyword queries had the same semantics; it is kept as-is.
count_columns = ',\n           '.join(
    f"COUNT(CASE WHEN PFAMs LIKE '%{pf}%' THEN 1 END) AS pf_{i}"
    for i, pf in enumerate(pfam_keywords)
)
count_df = spark.sql(f"""
    SELECT {count_columns}
    FROM kbase_ke_pangenome.eggnog_mapper_annotations
""").toPandas()

# Unpack the single result row into one record per domain keyword.
pfam_counts = []
for i, pf in enumerate(pfam_keywords):
    n_hits = int(count_df[f'pf_{i}'].iloc[0])
    pfam_counts.append({'pfam_keyword': pf, 'count': n_hits})
    print(f'  {pf}: {n_hits:,}')

# Show the domain keywords again, most frequent first.
pf_df = pd.DataFrame(pfam_counts).sort_values('count', ascending=False)
print(f'\nPFam keyword counts:')
print(pf_df.to_string(index=False))

In [None]:
# Print up to ten example annotations per key prophage marker so the
# annotation patterns (PFAMs, description, KO, COG) can be inspected by eye.
markers = ['terminase large subunit', 'major capsid protein', 'holin', 'phage integrase', 'portal protein']
for marker in markers:
    print(f'\n--- {marker.upper()} ---')
    sample = spark.sql(f"""
        SELECT PFAMs, Description, KEGG_ko, COG_category
        FROM kbase_ke_pangenome.eggnog_mapper_annotations
        WHERE LOWER(Description) LIKE '%{marker}%'
        LIMIT 10
    """).toPandas()
    # Walk the four columns in lockstep; descriptions are cut to 80 characters.
    fields = zip(sample['PFAMs'], sample['Description'], sample['KEGG_ko'], sample['COG_category'])
    for pfams, description, ko, cog in fields:
        print(f'  PFAMs={pfams}  Desc={description[:80]}  KO={ko}  COG={cog}')

## 2. Extract All Prophage Gene Clusters

Use the comprehensive WHERE clause from `prophage_utils` to find all gene clusters matching any prophage module marker. Join with `gene_cluster` to get species, core/accessory status, and representative sequences.

In [None]:
# Assemble the SQL filter from the module definitions in prophage_utils and
# show a quick preview of it.
where_clause = build_spark_where_clause()

# The condition count is estimated as (occurrences of "OR") + 1.
n_conditions = where_clause.count("OR") + 1
clause_preview = where_clause[:500]
print(f'WHERE clause has {n_conditions} conditions')
print(f'First 500 chars: {clause_preview}...')

In [None]:
# Main extraction: gene clusters joined to their eggNOG annotations, keeping
# only rows that satisfy the prophage WHERE clause.
# Per the project notes this is the most expensive query in the notebook
# (~93M annotation rows joined with 132M cluster rows); the WHERE clause is
# what keeps the result small.

query = f"""
SELECT gc.gene_cluster_id,
       gc.gtdb_species_clade_id,
       gc.is_core,
       gc.is_auxiliary,
       gc.is_singleton,
       gc.faa_sequence,
       ann.PFAMs,
       ann.Description,
       ann.KEGG_ko,
       ann.COG_category,
       ann.EC
FROM kbase_ke_pangenome.gene_cluster gc
JOIN kbase_ke_pangenome.eggnog_mapper_annotations ann
    ON gc.gene_cluster_id = ann.query_name
WHERE {where_clause}
"""

print('Running prophage gene extraction query...')

# cache() marks the DataFrame for caching and returns it, so it can be chained;
# the result is reused below for the row count and later for toPandas().
prophage_spark = spark.sql(query).cache()

n_total = prophage_spark.count()
print(f'Total prophage-related gene clusters found: {n_total:,}')

In [None]:
# Pull the filtered result into pandas for module classification. The notebook
# expects this to fit in memory because prophage clusters are a small
# fraction of the 132M total clusters.
print('Converting to pandas...')
prophage_df = prophage_spark.toPandas()

# Deep memory usage includes the string payloads, reported in megabytes.
memory_mb = prophage_df.memory_usage(deep=True).sum() / 1e6
print(f'DataFrame shape: {prophage_df.shape}')
print(f'Memory: {memory_mb:.1f} MB')

## 3. Classify Gene Clusters into Modules (A-G)

Apply the `classify_gene_to_module()` function from `prophage_utils` to assign each gene cluster to one or more of the 7 prophage modules.

In [None]:
# Assign module(s) to every gene cluster from its annotation fields.
def _modules_for_row(r):
    """Return the module list for one annotation row."""
    return classify_gene_to_module(r['Description'], r['PFAMs'], r['KEGG_ko'], r['COG_category'])


prophage_df['modules'] = prophage_df.apply(_modules_for_row, axis=1)

# Tally clusters per module; a cluster assigned to several modules is counted
# once for each of them.
from collections import Counter
module_counter = Counter(
    assigned_module
    for assigned in prophage_df['modules']
    for assigned_module in assigned
)

print('Gene clusters per module:')
for module_id in sorted(MODULES.keys()):
    full_name = MODULES[module_id]["full_name"]
    print(f'  {module_id}: {module_counter.get(module_id, 0):,} ({full_name})')

# Number of modules assigned to each cluster, used for both tallies below.
assignment_sizes = [len(assigned) for assigned in prophage_df['modules']]

# Clusters that matched the SQL filter but received no module.
n_unassigned = assignment_sizes.count(0)
print(f'\nUnassigned (matched WHERE but not any module): {n_unassigned:,}')

# Clusters that received more than one module.
n_multi = sum(size > 1 for size in assignment_sizes)
print(f'Multi-module assignments: {n_multi:,}')

In [None]:
# Flatten the module list into one comma-separated string per cluster;
# clusters without any module are labelled 'unassigned'.
def _join_modules(assigned):
    """Return the comma-joined module ids, or 'unassigned' when empty."""
    if assigned:
        return ','.join(assigned)
    return 'unassigned'


prophage_df['module_str'] = prophage_df['modules'].map(_join_modules)

# One boolean membership column per module (has_<module id>).
for module_id in MODULES.keys():
    prophage_df[f'has_{module_id}'] = prophage_df['modules'].map(
        lambda assigned, wanted=module_id: wanted in assigned
    )

print('Module assignment summary:')
module_str_counts = prophage_df['module_str'].value_counts()
print(module_str_counts.head(20))

In [None]:
# Flag terminase large subunit (TerL) clusters using the prophage_utils helper.
def _row_is_terL(r):
    """Return the is_terL() verdict for one annotation row."""
    return is_terL(r['Description'], r['PFAMs'], r['KEGG_ko'])


prophage_df['is_terL'] = prophage_df.apply(_row_is_terL, axis=1)

terL_mask = prophage_df['is_terL']
n_terL = terL_mask.sum()
n_terL_species = prophage_df.loc[terL_mask, 'gtdb_species_clade_id'].nunique()
print(f'TerL clusters: {n_terL:,}')
print(f'Species with TerL: {n_terL_species:,}')

# Show the first ten TerL hits; ids and descriptions are truncated for display.
print('\nSample TerL annotations:')
terL_sample = prophage_df[terL_mask].head(10)
sample_fields = zip(terL_sample['gene_cluster_id'], terL_sample['PFAMs'], terL_sample['Description'])
for cluster_id, pfams, description in sample_fields:
    print(f'  {cluster_id[:40]}  PFAMs={pfams}  Desc={description[:60]}')

## 4. Core/Accessory/Singleton Status of Prophage Genes

In [None]:
# Analyze core/accessory/singleton distribution per module
def classify_conservation(row):
    """Label a gene cluster as 'core', 'singleton', or 'accessory'.

    ``row`` is any mapping-like object indexable by 'is_core' and
    'is_singleton'. Core takes precedence over singleton; anything that is
    neither is reported as 'accessory'. 'is_singleton' is only read when the
    cluster is not core.
    """
    def _flag_set(value):
        # Accept both the integer 1 and the boolean True as "set".
        return value == 1 or value == True

    if _flag_set(row['is_core']):
        return 'core'
    return 'singleton' if _flag_set(row['is_singleton']) else 'accessory'

# Label every cluster as core / accessory / singleton.
prophage_df['conservation'] = prophage_df.apply(classify_conservation, axis=1)

print('Overall conservation of prophage gene clusters:')
print(prophage_df['conservation'].value_counts())

# Per-module breakdown, as percentages of the clusters assigned to that module.
print('\nConservation by module:')
for module_id in sorted(MODULES.keys()):
    in_module = prophage_df[f'has_{module_id}']
    n_in_module = in_module.sum()
    if n_in_module <= 0:
        # No clusters in this module, so there is nothing to report.
        continue
    fractions = prophage_df.loc[in_module, 'conservation'].value_counts(normalize=True)
    core_pct, acc_pct, sing_pct = (
        fractions.get(status, 0) * 100 for status in ('core', 'accessory', 'singleton')
    )
    print(f'  {module_id}: core={core_pct:.1f}% acc={acc_pct:.1f}% sing={sing_pct:.1f}% (n={n_in_module:,})')

## 5. Species-Level Module Summary

In [None]:
# Build one summary record per species: total prophage clusters, then for each
# module its cluster count and a presence flag, then TerL and module totals.
species_modules = []

for species_id, clusters in prophage_df.groupby('gtdb_species_clade_id'):
    summary = {
        'gtdb_species_clade_id': species_id,
        'n_prophage_clusters': len(clusters),
    }
    for module_id in MODULES.keys():
        n_in_module = clusters[f'has_{module_id}'].sum()
        # Count first, flag second: this fixes the column order of the table.
        summary.update({
            f'n_{module_id}': n_in_module,
            f'has_{module_id}': n_in_module > 0,
        })
    summary['n_terL'] = clusters['is_terL'].sum()
    summary['n_modules_present'] = sum(1 for m in MODULES.keys() if summary[f'has_{m}'])
    species_modules.append(summary)

species_mod_df = pd.DataFrame(species_modules)

# Pull per-species pangenome statistics and attach them for context; the left
# join keeps every species that has prophage hits.
pangenome_stats = spark.sql("""
    SELECT gtdb_species_clade_id, no_genomes, no_gene_clusters, no_core, no_aux_genome
    FROM kbase_ke_pangenome.pangenome
""").toPandas()

species_mod_df = species_mod_df.merge(pangenome_stats, on='gtdb_species_clade_id', how='left')

n_species_hit = len(species_mod_df)
n_species_total = len(pangenome_stats)
print(f'Species with any prophage annotation: {n_species_hit:,}')
print(f'Total species in pangenome: {n_species_total:,}')
print(f'Fraction with prophage: {n_species_hit/n_species_total*100:.1f}%')

# Module prevalence is expressed relative to all species in the pangenome table.
print(f'\nModule presence across species:')
for module_id in sorted(MODULES.keys()):
    n = species_mod_df[f'has_{module_id}'].sum()
    pct = n / n_species_total * 100
    print(f'  {module_id}: {n:,} species ({pct:.1f}%)')

print(f'\nNumber of modules present per species:')
modules_per_species = species_mod_df['n_modules_present'].value_counts()
print(modules_per_species.sort_index())

## 6. Save Outputs

In [None]:
# The cluster table is written without the protein sequences to keep the file
# small; the species summary is written as-is.
prophage_out = prophage_df.drop(columns=['faa_sequence']).copy()

tables_to_save = (
    (prophage_out, 'prophage_gene_clusters.tsv'),
    (species_mod_df, 'species_module_summary.tsv'),
)
for table, filename in tables_to_save:
    # Tab-separated, no index column.
    table.to_csv(f'../data/{filename}', sep='\t', index=False)
    print(f'Saved data/{filename}: {len(table):,} rows')

# Save TerL sequences as FASTA for lineage clustering.
#
# Only sequences that are present and longer than 10 residues are exported.
# The selection is made before counting so that n_with_seq (printed here and
# reused by the final summary) equals the number of FASTA records written;
# previously the count included short sequences that the writer then skipped.
def _is_usable_sequence(seq):
    """Return True for a non-missing sequence string longer than 10 residues."""
    return isinstance(seq, str) and len(seq) > 10


has_sequence = prophage_df['faa_sequence'].notna()
usable_sequence = prophage_df['faa_sequence'].map(_is_usable_sequence).astype(bool)

terL_clusters = prophage_df[prophage_df['is_terL'] & usable_sequence].copy()
n_with_seq = len(terL_clusters)
print(f'\nTerL clusters with protein sequence: {n_with_seq:,}')

# Make the exclusion visible instead of dropping short sequences silently.
n_too_short = int((prophage_df['is_terL'] & has_sequence & ~usable_sequence).sum())
if n_too_short:
    print(f'Skipped {n_too_short:,} TerL clusters with sequences of 10 residues or fewer')

# One record per cluster; the header is ">gene_cluster_id|gtdb_species_clade_id".
with open('../data/terL_sequences.fasta', 'w', encoding='utf-8') as f:
    f.writelines(
        f">{cluster_id}|{species_id}\n{seq}\n"
        for cluster_id, species_id, seq in zip(
            terL_clusters['gene_cluster_id'],
            terL_clusters['gtdb_species_clade_id'],
            terL_clusters['faa_sequence'],
        )
    )

print(f'Saved data/terL_sequences.fasta')

## 7. Validation: Check Known Prophage-Carrying Species

Verify that well-known prophage-carrying species (e.g., *Escherichia coli* with lambda, *Salmonella enterica* serovar Typhimurium with P22) are detected.

In [None]:
# Check known prophage-carrying species
known_species = [
    's__Escherichia_coli',      # lambda phage
    's__Salmonella_enterica',   # P22 phage
    's__Staphylococcus_aureus', # phi11, phiSa3
    's__Pseudomonas_aeruginosa', # multiple prophages
    's__Mycobacterium_tuberculosis', # phiRv1, phiRv2
    's__Bacillus_subtilis',     # SPBeta, PBSX
    's__Vibrio_cholerae',       # CTXphi
    's__Streptococcus_pyogenes', # multiple prophages encoding virulence
]

# Derive the denominator from the module definitions rather than hard-coding it.
n_modules_total = len(MODULES)

for species_name in known_species:
    # Species are located by clade-id prefix, which can match more than one
    # clade (any id that merely starts with the same text).
    matches = species_mod_df[
        species_mod_df['gtdb_species_clade_id'].str.startswith(species_name)
    ]
    if matches.empty:
        print(f'{species_name}: NOT FOUND in prophage results')
        continue
    # Report every matching clade under its actual id, so that an unintended
    # prefix match is visible instead of silently standing in for the species.
    for _, row in matches.iterrows():
        modules_present = [m for m in MODULES.keys() if row[f'has_{m}']]
        print(f'{species_name} [{row["gtdb_species_clade_id"]}]: '
              f'{row["n_prophage_clusters"]} clusters, '
              f'{row["n_modules_present"]}/{n_modules_total} modules: {modules_present}')

## 8. Summary Statistics

In [None]:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns

# Three-panel overview figure: clusters per module, species per module, and
# the distribution of module richness across species.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Fix the module order once so that all panels share it.
module_ids = sorted(MODULES.keys())
module_names = [MODULES[m]['full_name'] for m in module_ids]

# Panel 1: Gene clusters per module
module_counts = [module_counter.get(m, 0) for m in module_ids]
axes[0].barh(module_names, module_counts, color='steelblue')
axes[0].set_xlabel('Gene Clusters')
axes[0].set_title('Prophage Gene Clusters by Module')
# Value labels sit just right of each bar, offset by 1% of the longest bar.
cluster_label_offset = max(module_counts) * 0.01
for i, v in enumerate(module_counts):
    axes[0].text(v + cluster_label_offset, i, f'{v:,}', va='center', fontsize=9)

# Panel 2: Species per module
species_counts = [species_mod_df[f'has_{m}'].sum() for m in module_ids]
axes[1].barh(module_names, species_counts, color='darkorange')
axes[1].set_xlabel('Species')
axes[1].set_title('Species with Module Present')
species_label_offset = max(species_counts) * 0.01
for i, v in enumerate(species_counts):
    axes[1].text(v + species_label_offset, i, f'{v:,}', va='center', fontsize=9)

# Panel 3: Distribution of modules per species
mod_dist = species_mod_df['n_modules_present'].value_counts().sort_index()
axes[2].bar(mod_dist.index, mod_dist.values, color='seagreen')
axes[2].set_xlabel('Number of Modules Present')
axes[2].set_ylabel('Number of Species')
axes[2].set_title('Prophage Module Richness per Species')
# Ticks run from 0 to the number of defined modules (0-7 for modules A-G).
axes[2].set_xticks(range(len(module_ids) + 1))

plt.tight_layout()
plt.savefig('../figures/prophage_module_discovery.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved figures/prophage_module_discovery.png')

In [None]:
# Stacked horizontal bars: for each module, the fraction of its gene clusters
# that are core, accessory, or singleton.
fig, ax = plt.subplots(figsize=(10, 5))

conservation_data = []
for module_id in sorted(MODULES.keys()):
    in_module = prophage_df[f'has_{module_id}']
    if in_module.sum() <= 0:
        # Modules without any clusters are left out of the plot.
        continue
    fractions = prophage_df.loc[in_module, 'conservation'].value_counts(normalize=True)
    record = {'module': MODULES[module_id]['full_name']}
    # Insertion order here determines the stacking order of the bars.
    for status in ('core', 'accessory', 'singleton'):
        record[status] = fractions.get(status, 0)
    conservation_data.append(record)

cons_df = pd.DataFrame(conservation_data)

# Colours follow the column order: core, accessory, singleton.
status_colors = ['#2ca02c', '#ff7f0e', '#d62728']
cons_df.plot.barh(x='module', stacked=True, ax=ax, color=status_colors)
ax.set_xlabel('Fraction')
ax.set_title('Core/Accessory/Singleton Status of Prophage Gene Clusters by Module')
ax.legend(title='Conservation')

plt.tight_layout()
plt.savefig('../figures/prophage_conservation_by_module.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved figures/prophage_conservation_by_module.png')

In [None]:
# Recap of the headline numbers and of every file this notebook wrote.
banner = '='*60
n_species_with_prophage = species_mod_df.shape[0]
n_species_total = len(pangenome_stats)
pct_species_with_prophage = n_species_with_prophage/n_species_total*100

print(banner)
print('NB01 SUMMARY')
print(banner)
print(f'Total prophage gene clusters: {len(prophage_df):,}')
print(f'Species with prophage annotations: {n_species_with_prophage:,} / {n_species_total:,} ({pct_species_with_prophage:.1f}%)')
print(f'TerL clusters (for lineage clustering): {n_terL:,}')
print(f'TerL clusters with sequences: {n_with_seq:,}')

# Number of species carrying at least one cluster of each module.
print(f'\nModule prevalence (species with >= 1 cluster):')
for module_id in sorted(MODULES.keys()):
    n_species_with_module = species_mod_df[f'has_{module_id}'].sum()
    print(f'  {module_id}: {n_species_with_module:,} species')

# Output files, indented two spaces under the heading.
print(f'\nFiles saved:')
saved_files = (
    f'data/prophage_gene_clusters.tsv ({len(prophage_out):,} rows)',
    f'data/species_module_summary.tsv ({len(species_mod_df):,} rows)',
    f'data/terL_sequences.fasta ({n_with_seq:,} sequences)',
    'figures/prophage_module_discovery.png',
    'figures/prophage_conservation_by_module.png',
)
for entry in saved_files:
    print(f'  {entry}')