# NB04: Phylogenetic Distribution & Variance Partitioning

**Project**: Prophage Ecology Across Bacterial Phylogeny and Environmental Gradients

**Goal**: Map prophage module prevalence across GTDB taxonomy; partition variance into phylogeny vs environment components; control for genome size and assembly completeness.

**Dependencies**: NB01 outputs (`data/prophage_gene_clusters.tsv`, `data/species_module_summary.tsv`), NB02 (`data/terL_lineages.tsv`), NB03 (`data/module_cooccurrence_stats.tsv`)

**Environment**: Requires BERDL JupyterHub (Spark SQL) for environment metadata extraction; statistical analysis runs locally.

**Outputs**:
- `data/species_prophage_environment.tsv` — merged species × module × environment × genome size
- `data/variance_partitioning_results.tsv` — PERMANOVA results
- `data/species_genome_size.tsv` — genome size data
- `figures/prophage_prevalence_by_phylum.png`
- `figures/variance_partitioning.png`
- `figures/genome_size_confound.png`

In [None]:
import sys
import os
import pandas as pd
import numpy as np
from scipy import stats
from scipy.spatial.distance import pdist, squareform

# get_spark_session is not imported in this cell; presumably it is provided by
# the BERDL JupyterHub kernel — TODO confirm.
spark = get_spark_session()

# Make the project's ../src modules importable. MODULES maps each prophage
# module id to its metadata (this notebook reads its 'full_name').
sys.path.insert(0, '../src')
from prophage_utils import MODULES

os.makedirs('../data', exist_ok=True)
os.makedirs('../figures', exist_ok=True)

# Load NB01 outputs
# species_summary is reported below as the species *with* prophage; later cells
# treat species absent from it as prophage-negative.
species_summary = pd.read_csv('../data/species_module_summary.tsv', sep='\t')
prophage_clusters = pd.read_csv('../data/prophage_gene_clusters.tsv', sep='\t')

# Load NB02 lineage data if available
# Both NB02 files are optional together: if either is missing, both are set to
# None and the per-species lineage merge later in the notebook is skipped.
try:
    lineages = pd.read_csv('../data/terL_lineages.tsv', sep='\t')
    lineage_summary = pd.read_csv('../data/lineage_summary.tsv', sep='\t')
    print(f'TerL lineages: {lineages["lineage_id"].nunique():,}')
except FileNotFoundError:
    lineages = None
    lineage_summary = None
    print('NB02 outputs not found — lineage analysis will be skipped')

print(f'Species with prophage: {len(species_summary):,}')
print(f'Prophage clusters: {len(prophage_clusters):,}')

## 1. GTDB Taxonomy & Prophage Prevalence by Phylum

Map prophage module presence across GTDB taxonomy tree.

In [None]:
# Get full GTDB taxonomy for all species.
# `order` is backtick-quoted: ORDER is a reserved word in ANSI-mode Spark SQL.
# The output column is still named 'order'.
taxonomy = spark.sql("""
    SELECT DISTINCT g.gtdb_species_clade_id,
           t.phylum, t.class, t.`order`, t.family, t.genus
    FROM kbase_ke_pangenome.genome g
    JOIN kbase_ke_pangenome.gtdb_taxonomy_r214v1 t ON g.gtdb_taxonomy_id = t.gtdb_taxonomy_id
""").toPandas()

# DISTINCT only collapses identical rows. If genomes of one species carry
# different lineage strings, that species would appear more than once and be
# double-counted by every merge on gtdb_species_clade_id below. Keep exactly one
# row per species. Sort first so the kept row does not depend on Spark's
# (unordered) output.
tax_cols = ['gtdb_species_clade_id', 'phylum', 'class', 'order', 'family', 'genus']
n_extra_tax_rows = int(taxonomy['gtdb_species_clade_id'].duplicated().sum())
if n_extra_tax_rows:
    print(f'Warning: {n_extra_tax_rows:,} extra taxonomy rows for species with conflicting lineages; '
          'keeping one row per species')
    taxonomy = (
        taxonomy.sort_values(tax_cols, na_position='last')
        .drop_duplicates('gtdb_species_clade_id', keep='first')
        .reset_index(drop=True)
    )

# Total number of species in the pangenome, reported for context. The
# per-phylum prevalence denominators are computed from `taxonomy` later.
total_species = spark.sql("""
    SELECT COUNT(DISTINCT g.gtdb_species_clade_id) as n
    FROM kbase_ke_pangenome.genome g
""").toPandas()['n'].iloc[0]

print(f'Taxonomy entries: {len(taxonomy):,}')
print(f'Total species in pangenome: {total_species:,}')
print(f'Phyla: {taxonomy["phylum"].nunique()}')

# Merge with species module summary
species_tax = species_summary.merge(taxonomy, on='gtdb_species_clade_id', how='left')
print(f'\nSpecies with taxonomy + prophage: {len(species_tax):,}')

In [None]:
# Prophage prevalence by phylum
# Denominator: distinct species per phylum across the whole pangenome taxonomy.
# Numerator: rows of species_tax in that phylum (species_summary is treated as
# the prophage-carrying set — see the load cell).
phylum_stats = []
all_phylum_species = taxonomy.groupby('phylum')['gtdb_species_clade_id'].nunique()

# groupby drops NaN keys by default, so species without a taxonomy match are not
# counted in any phylum.
for phylum, grp in species_tax.groupby('phylum'):
    total_in_phylum = all_phylum_species.get(phylum, 0)
    row = {
        'phylum': phylum,
        'n_species_total': total_in_phylum,
        'n_species_prophage': len(grp),
        'pct_with_prophage': len(grp) / total_in_phylum * 100 if total_in_phylum > 0 else 0,
        'mean_modules': grp['n_modules_present'].mean(),
        'mean_prophage_clusters': grp['n_prophage_clusters'].mean(),
    }
    # Per-module prevalence as a percentage: the mean of the has_<module> flag
    # (assumed to be bool or 0/1).
    for module_id in MODULES.keys():
        row[f'pct_{module_id}'] = grp[f'has_{module_id}'].mean() * 100
    phylum_stats.append(row)

phylum_df = pd.DataFrame(phylum_stats).sort_values('n_species_prophage', ascending=False)

print('Top 20 phyla by prophage-carrying species:')
display_cols = ['phylum', 'n_species_total', 'n_species_prophage', 'pct_with_prophage', 'mean_modules']
print(phylum_df[display_cols].head(20).to_string(index=False))

In [None]:
# Module prevalence by phylum (top 15 phyla with most prophage-carrying species)
# phylum_df is sorted by n_species_prophage descending, so head(15) is the top 15.
top_phyla = phylum_df.head(15)['phylum'].tolist()
module_cols = [f'pct_{m}' for m in sorted(MODULES.keys())]
module_names = [MODULES[m]['full_name'] for m in sorted(MODULES.keys())]

print('Module prevalence (% of species) by phylum:')
phylum_module_df = phylum_df[phylum_df['phylum'].isin(top_phyla)][['phylum'] + module_cols].copy()
# Swap pct_<module> headers for human-readable names (same sorted module order).
phylum_module_df.columns = ['phylum'] + module_names
print(phylum_module_df.round(1).to_string(index=False))

## 2. Environment Metadata

Extract environment classifications from `ncbi_env` (EAV format) following the established pattern from PHB NB03.

In [None]:
# Extract environment metadata for all genomes (pivot from EAV)
# ncbi_env appears to be entity-attribute-value data (accession /
# harmonized_name / content). Each MAX(CASE ...) turns one attribute into a
# column. If a genome has several rows for the same attribute, MAX keeps a
# single (lexicographically largest) value. The inner JOIN keeps only genomes
# with at least one ncbi_env row.
env_data = spark.sql("""
    SELECT g.genome_id, g.gtdb_species_clade_id,
           MAX(CASE WHEN ne.harmonized_name = 'isolation_source' THEN ne.content END) as isolation_source,
           MAX(CASE WHEN ne.harmonized_name = 'env_broad_scale' THEN ne.content END) as env_broad_scale,
           MAX(CASE WHEN ne.harmonized_name = 'env_local_scale' THEN ne.content END) as env_local_scale,
           MAX(CASE WHEN ne.harmonized_name = 'env_medium' THEN ne.content END) as env_medium,
           MAX(CASE WHEN ne.harmonized_name = 'host' THEN ne.content END) as host
    FROM kbase_ke_pangenome.genome g
    JOIN kbase_ke_pangenome.ncbi_env ne ON g.ncbi_biosample_id = ne.accession
    GROUP BY g.genome_id, g.gtdb_species_clade_id
""").toPandas()

print(f'Genomes with environment metadata: {len(env_data):,}')
print(f'Species represented: {env_data["gtdb_species_clade_id"].nunique():,}')
print(f'\nField coverage:')
# Share of genomes with a non-null value per field (env_medium is extracted but
# not reported here).
for col in ['isolation_source', 'env_broad_scale', 'env_local_scale', 'host']:
    n = env_data[col].notna().sum()
    print(f'  {col}: {n:,} ({n/len(env_data)*100:.1f}%)')

In [None]:
# Classify environments into broad categories
import re

# Ordered classification rules: (category, metadata field, keyword regex).
# Rules are tried top to bottom and the first match wins, so earlier categories
# take precedence (e.g. 'human' in isolation_source is classed human_clinical
# before the host field is consulted).
# Keywords match as plain substrings, so stems and compounds still hit
# ('intestin' -> intestinal, 'topsoil' -> soil). The exception is short keywords
# that are also substrings of unrelated words; those are anchored to a word
# start with \b:
#   sea   -> not 'disease', 'research'     stem  -> not 'system', 'ecosystem'
#   plant -> not 'implant', 'transplant'   river -> not 'driver'
#   peat  -> not 'repeat'
_ENV_RULES = [
    (category, field, re.compile(pattern))
    for category, field, pattern in [
        ('human_clinical', 'isolation_source',
         r'blood|sputum|urine|wound|clinical|patient|hospital|human'),
        ('human_associated', 'host', r'homo sapiens|human'),
        ('animal_associated', 'isolation_source',
         r'animal|bovine|chicken|pig|cattle|poultry|feces|gut|intestin'),
        ('soil', 'isolation_source', r'soil|rhizosphere|root|compost|\bpeat'),
        ('plant_associated', 'isolation_source',
         r'\bplant|leaf|\bstem|flower|seed|phyllosphere'),
        ('freshwater', 'isolation_source',
         r'freshwater|lake|\briver|pond|stream|groundwater|spring'),
        ('marine', 'isolation_source',
         r'marine|ocean|\bsea|seawater|coastal|estuarine|estuary|tidal|coral'),
        ('wastewater_engineered', 'isolation_source',
         r'wastewater|sewage|activated sludge|bioreactor|ferment'),
        ('sediment', 'isolation_source', r'sediment|mud|silt'),
        ('food', 'isolation_source', r'food|milk|cheese|meat|fish'),
    ]
]


def _env_text(row, field):
    """Return row[field] lower-cased, or '' when the field is absent or missing (None/NaN)."""
    value = row.get(field)
    if isinstance(value, str):
        return value.lower()
    if value is None or pd.isna(value):
        return ''
    return str(value).lower()


def classify_environment(row):
    """Classify genome environment from NCBI metadata.

    Parameters
    ----------
    row : mapping (e.g. a pandas Series from ``DataFrame.apply(axis=1)``)
        Only ``isolation_source`` and ``host`` are read. Missing keys and
        None/NaN values are treated as empty text.

    Returns
    -------
    str
        The first category in ``_ENV_RULES`` whose keywords match, or
        ``'other_unknown'`` if none match.

    Notes
    -----
    env_broad_scale, env_local_scale and env_medium are extracted upstream but
    are not used by this classifier.
    NOTE(review): plant_associated is checked before wastewater_engineered, so a
    source such as 'wastewater treatment plant' is classed plant_associated —
    confirm whether that is intended.
    """
    texts = {field: _env_text(row, field) for field in ('isolation_source', 'host')}
    for category, field, pattern in _ENV_RULES:
        if pattern.search(texts[field]):
            return category
    return 'other_unknown'

# Label every genome row of env_data with a broad environment category.
env_data['env_category'] = env_data.apply(classify_environment, axis=1)
print('Environment classification:')
print(env_data['env_category'].value_counts())

In [None]:
# Aggregate environment to species level (most common env category).
# Ties are broken alphabetically: Series.mode() returns tied modes in sorted
# order, so primary_env no longer depends on the (unspecified) row order in
# which Spark returned env_data.
# NOTE(review): 'other_unknown' competes like any other category, so a species
# whose genomes are mostly unclassifiable gets primary_env='other_unknown' and
# is excluded from the environment analyses, even if some of its genomes are
# classifiable — consider taking the mode over known categories only.
species_env = env_data.groupby('gtdb_species_clade_id').agg(
    n_genomes_with_env=('genome_id', 'count'),
    primary_env=('env_category', lambda cats: cats.mode().iloc[0]),
    n_env_categories=('env_category', 'nunique'),
).reset_index()

print(f'Species with environment data: {len(species_env):,}')
print(f'\nPrimary environment distribution:')
print(species_env['primary_env'].value_counts())

## 3. Genome Size & Assembly Completeness

Extract genome size and CheckM completeness to use as covariates.

In [None]:
# Get species-level genome size and completeness from gtdb_metadata
# One row per species. Genome size, protein count, CheckM completeness and
# contamination are averaged over the species' genomes, plus an approximate
# median genome size (PERCENTILE_APPROX). Genomes with NULL genome_size are
# excluded.
# NOTE(review): the join assumes genome.genome_id uses the same accession format
# as gtdb_metadata.accession — confirm.
genome_meta = spark.sql("""
    SELECT g.gtdb_species_clade_id,
           COUNT(*) as n_genomes,
           AVG(CAST(m.genome_size AS DOUBLE)) as mean_genome_size_bp,
           PERCENTILE_APPROX(CAST(m.genome_size AS DOUBLE), 0.5) as median_genome_size_bp,
           AVG(CAST(m.protein_count AS DOUBLE)) as mean_protein_count,
           AVG(CAST(m.checkm_completeness AS DOUBLE)) as mean_completeness,
           AVG(CAST(m.checkm_contamination AS DOUBLE)) as mean_contamination
    FROM kbase_ke_pangenome.genome g
    JOIN kbase_ke_pangenome.gtdb_metadata m ON g.genome_id = m.accession
    WHERE m.genome_size IS NOT NULL
    GROUP BY g.gtdb_species_clade_id
""").toPandas()

# Species-mean genome size in megabase pairs.
genome_meta['genome_size_Mbp'] = genome_meta['mean_genome_size_bp'] / 1e6

print(f'Species with genome metadata: {len(genome_meta):,}')
print(f'\nGenome size (Mbp):')
print(genome_meta['genome_size_Mbp'].describe())
print(f'\nCompleteness:')
print(genome_meta['mean_completeness'].describe())

In [None]:
# Genome size vs prophage module count
# Start from every species with genome metadata. Species absent from
# species_summary get 0 prophage clusters/modules and count as prophage-negative.
species_size = genome_meta.merge(
    species_summary[['gtdb_species_clade_id', 'n_prophage_clusters', 'n_modules_present']],
    on='gtdb_species_clade_id', how='left'
)
species_size['n_prophage_clusters'] = species_size['n_prophage_clusters'].fillna(0)
species_size['n_modules_present'] = species_size['n_modules_present'].fillna(0)
species_size['has_prophage'] = species_size['n_prophage_clusters'] > 0

# Correlation: genome size vs prophage count
rho_size_count, p_size_count = stats.spearmanr(
    species_size['genome_size_Mbp'], species_size['n_prophage_clusters'])

# Genome size: prophage+ vs prophage-
# Two-sided Mann-Whitney U test on species-mean genome size.
prophage_pos_size = species_size[species_size['has_prophage']]['genome_size_Mbp']
prophage_neg_size = species_size[~species_size['has_prophage']]['genome_size_Mbp']
u_stat, p_size_mw = stats.mannwhitneyu(prophage_pos_size, prophage_neg_size, alternative='two-sided')

print(f'Genome size vs prophage cluster count: rho={rho_size_count:.3f}, p={p_size_count:.2e}')
print(f'\nGenome size (Mbp):')
print(f'  Prophage+: median={prophage_pos_size.median():.2f}, n={len(prophage_pos_size):,}')
print(f'  Prophage-: median={prophage_neg_size.median():.2f}, n={len(prophage_neg_size):,}')
print(f'  Mann-Whitney p={p_size_mw:.2e}')
print(f'  Difference: {prophage_pos_size.median() - prophage_neg_size.median():.2f} Mbp')

## 4. Merge All Data & Prophage by Environment

In [None]:
# Merge species summary + taxonomy + environment + genome size
# Left joins onto species_summary: every species_summary row is kept. The row
# count is preserved as long as each right-hand table has one row per species.
species_all = species_summary.merge(taxonomy, on='gtdb_species_clade_id', how='left')
species_all = species_all.merge(species_env, on='gtdb_species_clade_id', how='left')
species_all = species_all.merge(
    genome_meta[['gtdb_species_clade_id', 'genome_size_Mbp', 'mean_protein_count',
                  'mean_completeness', 'mean_contamination']],
    on='gtdb_species_clade_id', how='left'
)

# Add lineage count per species if available
# (species with no TerL lineage rows get n_lineages = 0)
if lineages is not None:
    lineage_per_species = lineages.groupby('gtdb_species_clade_id')['lineage_id'].nunique().reset_index()
    lineage_per_species.columns = ['gtdb_species_clade_id', 'n_lineages']
    species_all = species_all.merge(lineage_per_species, on='gtdb_species_clade_id', how='left')
    species_all['n_lineages'] = species_all['n_lineages'].fillna(0).astype(int)

print(f'Species with all data merged: {len(species_all):,}')
print(f'With environment: {species_all["primary_env"].notna().sum():,}')
print(f'With genome size: {species_all["genome_size_Mbp"].notna().sum():,}')

# Prophage module prevalence by environment
# "Known" environment = primary_env present and not 'other_unknown'.
env_mask = species_all['primary_env'].notna() & (species_all['primary_env'] != 'other_unknown')
species_known_env = species_all[env_mask]

print(f'\nProphage module prevalence by environment:')
for env_cat, egrp in species_known_env.groupby('primary_env'):
    n = len(egrp)
    # Environments with fewer than 10 species are not printed (they stay in
    # species_known_env).
    if n < 10:
        continue
    # Module ids are truncated to 3 characters for a compact one-line summary.
    modules_str = '  '.join(
        f'{m[:3]}={egrp[f"has_{m}"].mean()*100:.0f}%' for m in sorted(MODULES.keys())
    )
    print(f'  {env_cat:25s} (n={n:5d}): {modules_str}')

## 5. Variance Partitioning: PERMANOVA

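Use PERMANOVA (Anderson 2001) on Bray–Curtis distances to test how much of the variance in prophage module composition is associated with each of the following factors. Each factor is tested separately (marginal one-factor tests), so the effects are not jointly partitioned: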
- Phylogeny (GTDB family)
- Environment (primary_env category)
- Genome size (quartiles)

We use family-level taxonomy as the phylogenetic grouping variable since it provides a balance between resolution and group sizes.

In [None]:
from skbio.stats.distance import permanova
from skbio import DistanceMatrix

# Prepare data for PERMANOVA
# Filter to species with known environment, taxonomy, and genome size
permanova_df = species_all[
    species_all['primary_env'].notna() &
    (species_all['primary_env'] != 'other_unknown') &
    species_all['family'].notna() &
    species_all['genome_size_Mbp'].notna()
].copy()

# Need sufficient observations per group; filter families with >= 3 species
family_counts = permanova_df['family'].value_counts()
valid_families = family_counts[family_counts >= 3].index
permanova_df = permanova_df[permanova_df['family'].isin(valid_families)]

# Create genome size quartiles
# Quartile edges are computed on this filtered subset, so they may differ from
# the quartiles used in the genome-size confound section.
permanova_df['size_quartile'] = pd.qcut(
    permanova_df['genome_size_Mbp'], 4,
    labels=['Q1_small', 'Q2', 'Q3', 'Q4_large']
)

print(f'Species for PERMANOVA: {len(permanova_df):,}')
print(f'Families: {permanova_df["family"].nunique()}')
print(f'Environments: {permanova_df["primary_env"].nunique()}')
print(f'\nEnvironment distribution:')
print(permanova_df['primary_env'].value_counts())

In [None]:
# Build prophage module composition matrix (species × 7 modules)
module_ids = sorted(MODULES.keys())
module_count_cols = [f'n_{m}' for m in module_ids]

# Use count-based composition (how many clusters per module per species)
# Missing per-module counts are treated as 0.
X = permanova_df[module_count_cols].fillna(0).values

# Compute Bray-Curtis distance matrix
# NOTE(review): this builds the full condensed N×N matrix even when the next
# cell subsamples to MAX_PERMANOVA species; memory grows as O(N²).
bc_dists = pdist(X, metric='braycurtis')
# Handle NaN distances (species with all-zero module counts)
# With non-negative counts, Bray-Curtis is 0/0 = NaN only when both species are
# all-zero. Such pairs are set to 1.0 (maximally dissimilar).
# NOTE(review): identical empty profiles could arguably be 0.0 — confirm intent.
bc_dists = np.nan_to_num(bc_dists, nan=1.0)

dm = DistanceMatrix(bc_dists, ids=permanova_df['gtdb_species_clade_id'].tolist())

print(f'Distance matrix: {dm.shape[0]} × {dm.shape[1]}')
print(f'Mean Bray-Curtis distance: {np.mean(bc_dists):.3f}')

In [None]:
# PERMANOVA tests
# Each factor is tested on its own (marginal one-factor PERMANOVA). The per-factor
# R² values reported below are therefore not additive and do not jointly
# partition the variance.
# If dataset is too large for full PERMANOVA (>10K species), subsample
MAX_PERMANOVA = 10000
# Fixed seed so the subsample (and everything computed from it) is reproducible.
PERMANOVA_SEED = 42
if len(permanova_df) > MAX_PERMANOVA:
    print(f'Subsampling from {len(permanova_df):,} to {MAX_PERMANOVA:,} for PERMANOVA...')
    rng = np.random.default_rng(PERMANOVA_SEED)
    subsample_idx = rng.choice(len(permanova_df), MAX_PERMANOVA, replace=False)
    perm_sub = permanova_df.iloc[subsample_idx].copy().reset_index(drop=True)
    X_sub = X[subsample_idx]
    bc_sub = pdist(X_sub, metric='braycurtis')
    bc_sub = np.nan_to_num(bc_sub, nan=1.0)  # same all-zero handling as the full matrix
    dm_sub = DistanceMatrix(bc_sub, ids=perm_sub['gtdb_species_clade_id'].tolist())
else:
    perm_sub = permanova_df.copy()
    dm_sub = dm

# Recompute valid families for the subsample. A family needs >= 2 members to
# contribute within-group variation.
fam_counts_sub = perm_sub['family'].value_counts()
valid_fam_sub = fam_counts_sub[fam_counts_sub >= 2].index
fam_mask = perm_sub['family'].isin(valid_fam_sub)

if fam_mask.sum() < len(perm_sub):
    perm_sub_filt = perm_sub[fam_mask].reset_index(drop=True)
    ids_filt = perm_sub_filt['gtdb_species_clade_id'].tolist()
    # filter() keeps the ids_filt order, so dm_filt rows line up with perm_sub_filt
    dm_filt = dm_sub.filter(ids_filt)
else:
    perm_sub_filt = perm_sub
    dm_filt = dm_sub

# PERMANOVA p-values come from random permutations. Seeding NumPy's global RNG
# presumably makes them reproducible on scikit-bio releases that draw from it —
# confirm for the installed scikit-bio version.
np.random.seed(PERMANOVA_SEED)


def _run_permanova(distances, grouping, predictor):
    """Run a 999-permutation PERMANOVA of `distances` ~ `grouping`; return a summary row or None on error.

    `grouping` must be positionally aligned with the IDs of `distances`.
    R² is derived from the pseudo-F statistic:
    F = [SS_B/(g-1)] / [SS_W/(N-g)]  =>  R² = F(g-1) / (F(g-1) + (N-g)).
    """
    try:
        res = permanova(distances, grouping, permutations=999)
    except Exception as e:  # one failing factor should not abort the remaining tests
        print(f'  Error: {e}')
        return None
    f_stat = res['test statistic']
    n_groups = grouping.nunique()
    n_samples = len(grouping)
    between = f_stat * (n_groups - 1)
    denom = between + (n_samples - n_groups)
    r_squared = between / denom if denom > 0 else np.nan
    print(f'  F={f_stat:.2f}, R²={r_squared:.3f}, p={res["p-value"]}')
    return {
        'predictor': predictor,
        'test_statistic': f_stat,
        'p_value': res['p-value'],
        'n_groups': n_groups,
        'n_samples': n_samples,
        'r_squared': r_squared,
    }


# (predictor id, label for the progress message, grouping vector)
permanova_factors = [
    ('phylogeny_family', 'phylogeny (family)', perm_sub_filt['family']),
    ('environment', 'environment', perm_sub_filt['primary_env']),
    # nunique() on the observed labels replaces the previously hard-coded 4 groups
    ('genome_size_quartile', 'genome size quartile', perm_sub_filt['size_quartile'].astype(str)),
]

results = []
for predictor, label, grouping in permanova_factors:
    print(f'Running PERMANOVA: prophage composition ~ {label}...')
    row = _run_permanova(dm_filt, grouping, predictor)
    if row is not None:
        results.append(row)

results_df = pd.DataFrame(results)
print('\n=== PERMANOVA Summary ===')
print(results_df.to_string(index=False))

## 6. Per-Module Environment Enrichment

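Test whether each module's prevalence differs across environments. For each module, run one chi-squared test of independence on its environment × presence/absence contingency table, and report the environments with the highest and lowest prevalence.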

In [None]:
# Per-module chi-squared: is module prevalence different across environments?
# One chi-squared test of independence per module, on the
# primary_env × has_<module> contingency table.
# NOTE(review): every known environment is included here (the n >= 10 cutoff in
# the merge step only limited printing), so sparse environments can produce small
# expected counts where the chi-squared approximation is unreliable.
module_env_results = []

for module_id in sorted(MODULES.keys()):
    has_col = f'has_{module_id}'
    ct = pd.crosstab(species_known_env['primary_env'], species_known_env[has_col])
    # A single column means the module is uniformly present or absent; nothing to test.
    if ct.shape[1] < 2:
        continue
    chi2, p, dof, expected = stats.chi2_contingency(ct)
    
    # Compute per-environment prevalence
    # (idxmax/idxmin below report the first environment on ties)
    env_prev = species_known_env.groupby('primary_env')[has_col].mean()
    overall_prev = species_known_env[has_col].mean()
    
    module_env_results.append({
        'module': module_id,
        'module_name': MODULES[module_id]['full_name'],
        'chi2': chi2,
        'p_value': p,
        'dof': dof,
        'overall_prevalence': overall_prev,
        'max_env': env_prev.idxmax(),
        'max_env_prev': env_prev.max(),
        'min_env': env_prev.idxmin(),
        'min_env_prev': env_prev.min(),
    })
    print(f'{module_id}: chi2={chi2:.1f}, p={p:.2e}')
    print(f'  Overall: {overall_prev*100:.1f}%')
    print(f'  Highest: {env_prev.idxmax()} ({env_prev.max()*100:.1f}%)')
    print(f'  Lowest:  {env_prev.idxmin()} ({env_prev.min()*100:.1f}%)')

module_env_df = pd.DataFrame(module_env_results)

## 7. Genome Size Confound Analysis

Test whether prophage-environment associations hold after controlling for genome size (same stratification approach as PHB NB03).

In [None]:
# Genome size stratified analysis
# Does prophage module count vary by environment within each genome size quartile?
# Subset: species with a genome size and a known (non-'other_unknown') primary
# environment. Quartile edges are computed within this subset.
species_with_size = species_all[
    species_all['genome_size_Mbp'].notna() &
    species_all['primary_env'].notna() &
    (species_all['primary_env'] != 'other_unknown')
].copy()

species_with_size['size_quartile'] = pd.qcut(
    species_with_size['genome_size_Mbp'], 4,
    labels=['Q1 (small)', 'Q2', 'Q3', 'Q4 (large)']
)

print('Mean prophage module count by genome size quartile and environment:\n')

# Select environments with enough species
env_counts = species_with_size['primary_env'].value_counts()
major_envs = env_counts[env_counts >= 50].index.tolist()

# NOTE(review): species_all is built from species_summary (the prophage-carrying
# species), so pct_prophage below is expected to be ~100% in every cell — confirm
# whether prophage-negative species were meant to be included.
for q in ['Q1 (small)', 'Q2', 'Q3', 'Q4 (large)']:
    q_sub = species_with_size[species_with_size['size_quartile'] == q]
    size_range = f'{q_sub["genome_size_Mbp"].min():.1f}-{q_sub["genome_size_Mbp"].max():.1f} Mbp'
    print(f'\n{q} ({size_range}, n={len(q_sub):,}):')
    
    # Cells (quartile × environment) with fewer than 5 species are not printed.
    for env in sorted(major_envs):
        e_sub = q_sub[q_sub['primary_env'] == env]
        if len(e_sub) >= 5:
            mean_modules = e_sub['n_modules_present'].mean()
            pct_prophage = (e_sub['n_prophage_clusters'] > 0).mean() * 100
            print(f'  {env:25s}: {pct_prophage:.0f}% with prophage, '
                  f'mean {mean_modules:.1f} modules (n={len(e_sub)})')

# Kruskal-Wallis per quartile: does module count differ by environment?
# Only environments with >= 5 species in the quartile enter the test; at least
# two such groups are required.
print('\n--- Kruskal-Wallis: module count ~ environment within each size quartile ---')
for q in ['Q1 (small)', 'Q2', 'Q3', 'Q4 (large)']:
    q_sub = species_with_size[species_with_size['size_quartile'] == q]
    groups = [g['n_modules_present'].values for _, g in q_sub.groupby('primary_env') if len(g) >= 5]
    if len(groups) >= 2:
        h_stat, p_kw = stats.kruskal(*groups)
        print(f'  {q}: H={h_stat:.1f}, p={p_kw:.2e}')

In [None]:
# Partial Spearman: prophage modules ~ environment, controlling for genome size
# Environment is encoded as a binary soil indicator (is_soil, defined below)
from scipy.stats import rankdata

def partial_spearman(x, y, z):
    """Partial Spearman correlation between x and y, controlling for z.

    All three inputs are rank-transformed (average ranks for ties). The ranks of
    x and y are each regressed linearly on the ranks of z, and the Pearson
    correlation of the two residual vectors is returned. This is the standard
    rank-based partial correlation: it equals
    (r_xy - r_xz*r_yz) / sqrt((1 - r_xz**2) * (1 - r_yz**2)) built from Spearman
    coefficients. The residuals must not be re-ranked before correlating.

    Parameters
    ----------
    x, y, z : 1-D array-like of equal length

    Returns
    -------
    (rho, p_value) : tuple of float
        p_value is two-sided, from a t distribution with n - 3 degrees of
        freedom (one controlled covariate). Both values are NaN when n < 4, or
        when x or y has no variation left after removing z (e.g. a constant
        input).
    """
    rx = rankdata(x)
    ry = rankdata(y)
    rz = rankdata(z)
    n = len(rx)
    if n < 4:
        return np.nan, np.nan

    rz_c = rz - rz.mean()
    ss_z = np.dot(rz_c, rz_c)

    def _residuals(ranks):
        # OLS residuals of ranks ~ 1 + rank(z) in closed form; a constant z
        # (ss_z == 0) leaves only the intercept, i.e. the mean.
        centered = ranks - ranks.mean()
        if ss_z == 0:
            return centered
        return centered - (np.dot(centered, rz_c) / ss_z) * rz_c

    res_x = _residuals(rx)
    res_y = _residuals(ry)
    ss_x = np.dot(res_x, res_x)
    ss_y = np.dot(res_y, res_y)
    if ss_x == 0 or ss_y == 0:
        return np.nan, np.nan

    rho = float(np.clip(np.dot(res_x, res_y) / np.sqrt(ss_x * ss_y), -1.0, 1.0))
    if abs(rho) >= 1.0:
        return rho, 0.0
    df = n - 3
    t_stat = rho * np.sqrt(df / (1.0 - rho ** 2))
    p_value = float(2 * stats.t.sf(abs(t_stat), df))
    return rho, p_value

# Test: prophage cluster count ~ genome size
valid = species_with_size['n_prophage_clusters'].notna() & species_with_size['genome_size_Mbp'].notna()
rho_raw, p_raw = stats.spearmanr(
    species_with_size.loc[valid, 'n_prophage_clusters'],
    species_with_size.loc[valid, 'genome_size_Mbp']
)
print(f'Prophage count ~ genome size: rho={rho_raw:.3f}, p={p_raw:.2e}')

# For each module, test: module presence ~ environment (encoded) | genome size
# Use soil as proxy for 'enriched' environment based on lysogeny literature
# is_soil: 1 when primary_env == 'soil', else 0 (adds a column to species_with_size).
species_with_size['is_soil'] = (species_with_size['primary_env'] == 'soil').astype(int)

print('\nPartial Spearman (module ~ soil | genome_size):')
for module_id in sorted(MODULES.keys()):
    has_col = f'has_{module_id}'
    valid = species_with_size[has_col].notna() & species_with_size['genome_size_Mbp'].notna()
    # Modules with fewer than 100 usable species are skipped silently.
    if valid.sum() < 100:
        continue
    
    rho_partial, p_partial = partial_spearman(
        species_with_size.loc[valid, has_col].astype(float).values,
        species_with_size.loc[valid, 'is_soil'].values,
        species_with_size.loc[valid, 'genome_size_Mbp'].values
    )
    # Unadjusted association, printed alongside for comparison.
    rho_unadj, p_unadj = stats.spearmanr(
        species_with_size.loc[valid, has_col].astype(float),
        species_with_size.loc[valid, 'is_soil']
    )
    print(f'  {module_id}: raw rho={rho_unadj:.3f}, partial rho={rho_partial:.3f} (p={p_partial:.2e})')

## 8. AlphaEarth Embedding Analysis

Use 64-dim environmental embeddings (28% genome coverage) for continuous environmental gradient analysis.

In [None]:
# Extract AlphaEarth embeddings
# 64 embedding columns named A00..A63.
emb_cols = [f'A{i:02d}' for i in range(64)]
emb_select = ', '.join([f'ae.{c}' for c in emb_cols])

# Inner JOIN: only genomes that have an embedding row are returned.
embeddings = spark.sql(f"""
    SELECT g.genome_id, g.gtdb_species_clade_id,
           {emb_select}
    FROM kbase_ke_pangenome.genome g
    JOIN kbase_ke_pangenome.alphaearth_embeddings_all_years ae
        ON g.genome_id = ae.genome_id
""").toPandas()

print(f'Genomes with embeddings: {len(embeddings):,}')
print(f'Species with embeddings: {embeddings["gtdb_species_clade_id"].nunique():,}')

# Filter NaN embeddings
# A genome is dropped if any of its 64 dimensions is NaN.
valid_mask = ~embeddings[emb_cols].isna().any(axis=1)
embeddings = embeddings[valid_mask]
print(f'After NaN filter: {len(embeddings):,} genomes')

In [None]:
# Per-species embedding statistics
species_emb = embeddings.groupby('gtdb_species_clade_id').agg(
    n_genomes_emb=('genome_id', 'count'),
    **{f'mean_{c}': (c, 'mean') for c in emb_cols},
    **{f'var_{c}': (c, 'var') for c in emb_cols},
).reset_index()

# Per-species variance summed over all embedding dimensions; used below as a
# proxy for environmental (niche) breadth.
var_cols = [f'var_{c}' for c in emb_cols]
species_emb['total_emb_variance'] = species_emb[var_cols].sum(axis=1)

# Filter to species with >= 5 genomes for stable estimates
species_emb = species_emb[species_emb['n_genomes_emb'] >= 5]
print(f'Species with >= 5 embedded genomes: {len(species_emb):,}')

# Merge with prophage data.
# species_all is built from species_summary (the species *with* prophage), so an
# inner join would silently drop every prophage-negative species and leave the
# prophage+/- comparison below without negatives. Left-join instead: species
# absent from species_summary carry no prophage. Genome size comes from
# genome_meta, which covers all species.
has_cols = [f'has_{m}' for m in MODULES.keys()]
emb_prophage = (
    species_emb[['gtdb_species_clade_id', 'n_genomes_emb', 'total_emb_variance']]
    .merge(
        species_all[['gtdb_species_clade_id', 'n_prophage_clusters', 'n_modules_present'] + has_cols],
        on='gtdb_species_clade_id', how='left',
    )
    .merge(
        genome_meta[['gtdb_species_clade_id', 'genome_size_Mbp']],
        on='gtdb_species_clade_id', how='left',
    )
)
for count_col in ['n_prophage_clusters', 'n_modules_present']:
    emb_prophage[count_col] = emb_prophage[count_col].fillna(0)
# NaN (species without prophage) -> False; existing True/1 flags stay True.
emb_prophage[has_cols] = emb_prophage[has_cols].eq(True)
emb_prophage['has_prophage'] = emb_prophage['n_prophage_clusters'] > 0

print(f'Species with embeddings + prophage data: {len(emb_prophage):,}')

# Test: embedding variance (environmental breadth) vs prophage burden
pos = emb_prophage[emb_prophage['has_prophage']]['total_emb_variance']
neg = emb_prophage[~emb_prophage['has_prophage']]['total_emb_variance']

print(f'\nEmbedding variance (niche breadth):')
print(f'  Prophage+: median={pos.median():.4f}, n={len(pos):,}')
print(f'  Prophage-: median={neg.median():.4f}, n={len(neg):,}')
if len(pos) and len(neg):
    u, p_emb = stats.mannwhitneyu(pos, neg, alternative='two-sided')
    print(f'  Mann-Whitney p={p_emb:.2e}')
else:
    # mannwhitneyu needs two non-empty samples; keep p_emb defined for the figures.
    p_emb = np.nan
    print('  Mann-Whitney skipped: one group is empty')

# Partial correlation controlling for genome size
valid = emb_prophage['genome_size_Mbp'].notna()
rho_partial_emb, p_partial_emb = partial_spearman(
    emb_prophage.loc[valid, 'total_emb_variance'].values,
    emb_prophage.loc[valid, 'n_modules_present'].values,
    emb_prophage.loc[valid, 'genome_size_Mbp'].values
)
print(f'\nPartial Spearman (emb_variance ~ modules | genome_size): rho={rho_partial_emb:.3f}, p={p_partial_emb:.2e}')

## 9. Save Outputs

In [None]:
# Save merged species data
# One row per species_summary species, with taxonomy, environment and genome size.
species_all.to_csv('../data/species_prophage_environment.tsv', sep='\t', index=False)
print(f'Saved data/species_prophage_environment.tsv: {len(species_all):,} rows')

# Save PERMANOVA results
results_df.to_csv('../data/variance_partitioning_results.tsv', sep='\t', index=False)
print(f'Saved data/variance_partitioning_results.tsv: {len(results_df)} rows')

# Save genome size data
# Covers every species with genome metadata, including prophage-negative ones.
species_size.to_csv('../data/species_genome_size.tsv', sep='\t', index=False)
print(f'Saved data/species_genome_size.tsv: {len(species_size):,} rows')

## 10. Figures

In [None]:
import matplotlib
# Non-interactive backend: figures are persisted with savefig. With 'Agg' the
# plt.show() calls likely display nothing inline.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns

# Figure 1: Prophage module prevalence by phylum (heatmap)
fig, ax = plt.subplots(figsize=(12, 8))

# Use top 15 phyla
# Rows are phyla (in phylum_df order); columns are modules relabelled with their
# full names; cell values are percentages.
heatmap_data = phylum_df[phylum_df['phylum'].isin(top_phyla)].set_index('phylum')
heatmap_cols = [f'pct_{m}' for m in sorted(MODULES.keys())]
heatmap_labels = [MODULES[m]['full_name'] for m in sorted(MODULES.keys())]

sns.heatmap(
    heatmap_data[heatmap_cols].rename(columns=dict(zip(heatmap_cols, heatmap_labels))),
    cmap='YlOrRd', annot=True, fmt='.0f',
    ax=ax, cbar_kws={'label': '% species with module'}
)
ax.set_title('Prophage Module Prevalence by Phylum (top 15)')
ax.set_ylabel('')
ax.set_xlabel('')

plt.tight_layout()
plt.savefig('../figures/prophage_prevalence_by_phylum.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved figures/prophage_prevalence_by_phylum.png')

In [None]:
# Figure 2: Variance partitioning summary
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Panel A: PERMANOVA F-statistics
# One horizontal bar per predictor (in results_df row order). The p-value label
# sits 0.5 F-units to the right of each bar end.
ax = axes[0]
if len(results_df) > 0:
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c'][:len(results_df)]
    bars = ax.barh(results_df['predictor'], results_df['test_statistic'], color=colors)
    for i, (_, row) in enumerate(results_df.iterrows()):
        pstr = f'p={row["p_value"]:.3f}' if row['p_value'] >= 0.001 else f'p<0.001'
        ax.text(row['test_statistic'] + 0.5, i, pstr, va='center', fontsize=10)
    ax.set_xlabel('PERMANOVA pseudo-F')
    ax.set_title('Variance in Prophage Composition\nExplained by Each Factor')

# Panel B: Module prevalence by environment
ax = axes[1]
env_module_heatmap = species_known_env.groupby('primary_env')[
    [f'has_{m}' for m in sorted(MODULES.keys())]
].mean() * 100

# Filter to environments with >= 50 species
# (rebinds env_counts; rows end up ordered by descending species count)
env_counts = species_known_env['primary_env'].value_counts()
plot_envs = env_counts[env_counts >= 50].index
env_module_heatmap = env_module_heatmap.loc[plot_envs]
env_module_heatmap.columns = [MODULES[m]['full_name'] for m in sorted(MODULES.keys())]

sns.heatmap(env_module_heatmap, cmap='YlOrRd', annot=True, fmt='.0f', ax=ax,
            cbar_kws={'label': '% species'})
ax.set_title('Module Prevalence by Environment')
ax.set_ylabel('')

plt.tight_layout()
plt.savefig('../figures/variance_partitioning.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved figures/variance_partitioning.png')

In [None]:
# Figure 3: Genome size confound analysis
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Panel A: Genome size distribution prophage+ vs prophage-
# Density-normalised histograms; dashed lines mark each group's median.
ax = axes[0]
ax.hist(prophage_neg_size, bins=50, alpha=0.6, color='#9E9E9E',
        label=f'No prophage (n={len(prophage_neg_size):,})', density=True)
ax.hist(prophage_pos_size, bins=50, alpha=0.6, color='#E91E63',
        label=f'Prophage+ (n={len(prophage_pos_size):,})', density=True)
ax.axvline(prophage_neg_size.median(), color='#616161', linestyle='--', linewidth=1.5)
ax.axvline(prophage_pos_size.median(), color='#880E4F', linestyle='--', linewidth=1.5)
ax.set_xlabel('Genome size (Mbp)')
ax.set_ylabel('Density')
ax.set_title(f'Genome Size: Prophage+ vs Prophage-\n(p={p_size_mw:.2e})')
ax.legend(fontsize=8)

# Panel B: Genome size vs prophage cluster count
ax = axes[1]
ax.scatter(species_size['genome_size_Mbp'], species_size['n_prophage_clusters'],
           alpha=0.1, s=5, color='#E91E63')
ax.set_xlabel('Genome size (Mbp)')
ax.set_ylabel('Prophage cluster count')
ax.set_title(f'Genome Size vs Prophage Burden\n(rho={rho_size_count:.3f})')

# Panel C: Embedding variance by prophage status
# Adds a 'status' label column to emb_prophage for the x-axis.
# NOTE(review): newer seaborn releases warn when palette is passed without hue.
ax = axes[2]
emb_prophage['status'] = emb_prophage['has_prophage'].map({True: 'Prophage+', False: 'No prophage'})
sns.boxplot(data=emb_prophage, x='status', y='total_emb_variance', ax=ax,
            palette={'Prophage+': '#E91E63', 'No prophage': '#9E9E9E'})
ax.set_ylabel('Total embedding variance (niche breadth)')
ax.set_title(f'Environmental Breadth\n(p={p_emb:.2e})')
ax.set_xlabel('')

plt.suptitle('Genome Size Confound Analysis', fontsize=13, y=1.02)
plt.tight_layout()
plt.savefig('../figures/genome_size_confound.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved figures/genome_size_confound.png')

In [None]:
# Summary
print('='*60)
print('NB04 SUMMARY')
print('='*60)
print(f'Species analyzed: {len(species_all):,}')
print(f'Phyla represented: {species_all["phylum"].nunique()}')
print(f'Species with known environment: {species_known_env.shape[0]:,}')
print(f'Species with genome size: {species_with_size.shape[0]:,}')
print(f'\nPERMANOVA results:')
# Significance stars: *** p<0.001, ** p<0.01, * p<0.05, otherwise 'ns'.
for _, row in results_df.iterrows():
    sig = '***' if row['p_value'] < 0.001 else ('**' if row['p_value'] < 0.01 else ('*' if row['p_value'] < 0.05 else 'ns'))
    print(f'  {row["predictor"]}: F={row["test_statistic"]:.2f}, p={row["p_value"]} {sig}')
print(f'\nGenome size confound: rho={rho_size_count:.3f} (prophage count ~ genome size)')
print(f'\nFiles saved:')
print(f'  data/species_prophage_environment.tsv')
print(f'  data/variance_partitioning_results.tsv')
print(f'  data/species_genome_size.tsv')
print(f'  figures/prophage_prevalence_by_phylum.png')
print(f'  figures/variance_partitioning.png')
print(f'  figures/genome_size_confound.png')