# NB 06: Pangenome-Scale Metal Tolerance Prediction

Use conserved metal-fitness gene families to predict metal tolerance
capabilities across all BERDL pangenome species (27,702 species clades are scored in Section 2).

**Requires BERDL JupyterHub**: uses `get_spark_session()` for pangenome queries.

**Strategy**: Map the gene clusters of conserved metal OG families to their eggNOG
functional annotations (KEGG KOs, Pfam domains), then count, for every pangenome
species, the gene clusters that carry those terms.

**Inputs**:
- `data/conserved_metal_families.csv` (from NB04)
- `data/metal_important_genes.csv` (from NB02)
- `conservation_vs_fitness/data/fb_pangenome_link.tsv`
- `essential_genome/data/all_ortholog_groups.csv`

**Outputs**:
- `data/metal_functional_signature.csv`
- `data/species_metal_scores.csv`
- `figures/species_metal_score_distribution.png`
- `figures/bioleaching_species_scores.png`

In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns

from berdl_notebook_utils.setup_spark_session import get_spark_session
spark = get_spark_session()
print(f'Spark version: {spark.version}')

PROJECT_DIR = Path('..').resolve()
DATA_DIR = PROJECT_DIR / 'data'
FIGURES_DIR = PROJECT_DIR / 'figures'
DATA_DIR.mkdir(exist_ok=True)
FIGURES_DIR.mkdir(exist_ok=True)

Spark version: 4.0.1


## 1. Build Metal Functional Signature

Retrieve the eggNOG-mapper annotations (COG, KEGG, Pfam) for the pangenome gene
clusters linked to conserved metal fitness families. Their KEGG KOs and Pfam
domains form the "metal signature" that we'll scan for across all pangenome species.
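The locus-to-cluster mapping in the cell below is a chain of two inner joins with a family filter in between. A minimal pandas sketch on made-up toy tables (all IDs and the helper name `metal_cluster_ids` are hypothetical) shows the order of operations and why `drop_duplicates()` on the link table matters:

```python
import pandas as pd

# Toy stand-ins for the three inputs; every ID here is made up.
metal_important = pd.DataFrame({
    "orgId": ["orgA", "orgA", "orgB"],
    "locusId": ["1", "2", "7"],
})
ortholog_groups = pd.DataFrame({
    "orgId": ["orgA", "orgA", "orgB"],
    "locusId": ["1", "2", "7"],
    "OG_id": ["OG1", "OG2", "OG1"],
})
fb_link = pd.DataFrame({
    "orgId": ["orgA", "orgA", "orgB", "orgB"],
    "locusId": ["1", "2", "7", "7"],
    "gene_cluster_id": ["gc_1", "gc_2", "gc_9", "gc_9"],  # duplicated link row on purpose
})
conserved_ogs = {"OG1"}  # families that passed the (hypothetical) conservation filter


def metal_cluster_ids(metal_important, ortholog_groups, fb_link, conserved_ogs):
    """Locus -> OG -> keep conserved OGs -> gene_cluster_id, deduplicated."""
    keys = ["orgId", "locusId"]
    with_og = metal_important.merge(ortholog_groups[keys + ["OG_id"]], on=keys, how="inner")
    with_og = with_og[with_og["OG_id"].isin(conserved_ogs)]
    linked = with_og.merge(
        fb_link[keys + ["gene_cluster_id"]].drop_duplicates(), on=keys, how="inner"
    )
    return sorted(linked["gene_cluster_id"].unique())


print(metal_cluster_ids(metal_important, ortholog_groups, fb_link, conserved_ogs))
```

Locus `orgA/2` drops out because its family `OG2` is not conserved, and the duplicated link row for `orgB/7` collapses to a single cluster.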

In [2]:
# Load conserved metal families and FB-pangenome link
conserved = pd.read_csv(DATA_DIR / 'conserved_metal_families.csv')
print(f'Conserved metal families: {len(conserved)}')

metal_important = pd.read_csv(DATA_DIR / 'metal_important_genes.csv')
metal_important['locusId'] = metal_important['locusId'].astype(str)

fb_link = pd.read_csv(
    PROJECT_DIR.parent / 'conservation_vs_fitness' / 'data' / 'fb_pangenome_link.tsv',
    sep='\t'
)
fb_link['locusId'] = fb_link['locusId'].astype(str)
print(f'FB-pangenome links: {len(fb_link):,}')

# Load ortholog groups
og_path = PROJECT_DIR.parent / 'essential_genome' / 'data' / 'all_ortholog_groups.csv'
ortholog_groups = pd.read_csv(og_path)
ortholog_groups['locusId'] = ortholog_groups['locusId'].astype(str)

# Get gene_cluster_ids for metal-important genes in conserved families
conserved_ogs = set(conserved['OG_id'].unique())
metal_og = metal_important.merge(
    ortholog_groups[['orgId', 'locusId', 'OG_id']],
    on=['orgId', 'locusId'], how='inner'
)
metal_og_conserved = metal_og[metal_og['OG_id'].isin(conserved_ogs)]

# Map to gene_cluster_ids
metal_clusters = metal_og_conserved.merge(
    fb_link[['orgId', 'locusId', 'gene_cluster_id']].drop_duplicates(),
    on=['orgId', 'locusId'], how='inner'
)
metal_cluster_ids = metal_clusters['gene_cluster_id'].unique().tolist()
print(f'Metal-important gene clusters (conserved families): {len(metal_cluster_ids):,}')

Conserved metal families: 1182


FB-pangenome links: 177,863
Metal-important gene clusters (conserved families): 3,806


In [3]:
# Query eggnog annotations for metal gene clusters
# Register cluster IDs as temp view for efficient Spark join
cluster_df = spark.createDataFrame(
    [(cid,) for cid in metal_cluster_ids], ['query_name']
)
cluster_df.createOrReplaceTempView('metal_clusters')

metal_annot = spark.sql("""
    SELECT /*+ BROADCAST(mc) */
        e.query_name, e.COG_category, e.KEGG_ko, e.PFAMs, e.Description
    FROM kbase_ke_pangenome.eggnog_mapper_annotations e
    JOIN metal_clusters mc ON e.query_name = mc.query_name
""").toPandas()

print(f'Metal cluster annotations retrieved: {len(metal_annot):,}')
print(f'With COG: {metal_annot["COG_category"].notna().sum():,}')
print(f'With KEGG: {metal_annot["KEGG_ko"].notna().sum():,}')
print(f'With PFAM: {metal_annot["PFAMs"].notna().sum():,}')

Metal cluster annotations retrieved: 3,795
With COG: 3,795
With KEGG: 3,795
With PFAM: 3,795


In [4]:
# Extract the KEGG KO terms associated with conserved metal families
# These are the most specific functional markers
metal_kegg = set()
for kos in metal_annot['KEGG_ko'].dropna():
    for ko in str(kos).split(','):
        ko = ko.strip()
        if ko.startswith('ko:'):
            metal_kegg.add(ko)

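# Extract PFAM terms
# NOTE: eggNOG-mapper typically lists Pfam domains by name (e.g. 'ABC_tran'),
# not by 'PF' accession, so this prefix filter likely keeps almost nothing
# (1 term in the output below). Section 2 scores species on KEGG KOs only.
metal_pfam = set()
for pfams in metal_annot['PFAMs'].dropna():
    for pf in str(pfams).split(','):
        pf = pf.strip()
        if pf.startswith('PF'):
            metal_pfam.add(pf)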

print(f'Metal functional signature:')
print(f'  KEGG KOs: {len(metal_kegg)}')
print(f'  PFAMs: {len(metal_pfam)}')

# Save the signature
# Sorted so the saved signature file has a stable row order across runs
sig_df = pd.DataFrame({
    'term': sorted(metal_kegg) + sorted(metal_pfam),
    'type': ['KEGG'] * len(metal_kegg) + ['PFAM'] * len(metal_pfam),
})
sig_df.to_csv(DATA_DIR / 'metal_functional_signature.csv', index=False)
print(f'Saved: data/metal_functional_signature.csv ({len(sig_df)} terms)')

Metal functional signature:
  KEGG KOs: 1286
  PFAMs: 1
Saved: data/metal_functional_signature.csv (1287 terms)
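The annotation counts in cell 3 (3,795 of 3,795 for every column) and the single Pfam term above are both consistent with eggNOG-mapper's usual output conventions: an empty field is written as `-` rather than left null, and the `PFAMs` column holds domain names (e.g. `TonB_dep_Rec`) rather than `PF` accessions. Assuming the BERDL table keeps those conventions (not verified here), a parsing sketch could look like this; `split_terms` and the toy rows are hypothetical:

```python
import pandas as pd


def split_terms(values, prefix=None):
    """Collect distinct comma-separated terms from eggNOG-style annotation strings.

    NaN cells and the '-' placeholder are skipped. If `prefix` is given,
    only terms starting with it are kept (e.g. 'ko:' for KEGG_ko).
    """
    terms = set()
    for cell in pd.Series(values).dropna():
        for term in str(cell).split(","):
            term = term.strip()
            if not term or term == "-":
                continue
            if prefix is None or term.startswith(prefix):
                terms.add(term)
    return terms


# Toy annotation rows in the assumed eggNOG-mapper format.
annot = pd.DataFrame({
    "KEGG_ko": ["ko:K02014,ko:K16087", "-", None, "ko:K02014"],
    "PFAMs": ["Plug,TonB_dep_Rec", "-", "HMA", None],
})

print(sorted(split_terms(annot["KEGG_ko"], prefix="ko:")))
print(sorted(split_terms(annot["PFAMs"])))  # Pfam names, no 'PF' prefix filter

# Counting real annotations means excluding the '-' placeholder, not only NaN.
print((annot["KEGG_ko"].notna() & annot["KEGG_ko"].ne("-")).sum())
```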


## 2. Score All Pangenome Species

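For each pangenome species, count the gene clusters annotated with at least one
metal-signature KEGG KO (`n_metal_clusters`) and the number of distinct signature
KOs they cover (`n_distinct_kegg_hits`). Only the KEGG part of the signature is
used here, since the Pfam part holds a single term. Dividing the cluster count by
the species' number of KEGG-annotated clusters gives the normalized metal score,
our pangenome-scale metal tolerance prediction.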
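For reference, the two per-species scores computed in the cell below (`metal_score_raw`, `metal_score_norm`) can be written as follows. The notation is introduced here for illustration only: $K$ is the set of metal-signature KOs ($|K| = 1{,}286$ in this run), $M_s$ the set of gene clusters of species $s$ whose `KEGG_ko` contains at least one KO in $K$, and $A_s$ the set of clusters of $s$ with a non-null `KEGG_ko`.

$$
\text{raw}_s = \frac{|M_s|}{|K|}, \qquad
\text{norm}_s = \frac{|M_s|}{|A_s|}, \qquad
0 < \text{norm}_s \le 1 \quad (\text{since } M_s \subseteq A_s).
$$

Because $|M_s|$ counts clusters rather than distinct KOs, $\text{raw}_s$ is not bounded by 1 (the Herbaspirillum median of 1.034 in Section 3 is an example). The share of signature KOs a species covers would instead be $|K_s| / |K|$, where $K_s$ is the set of distinct KOs hit (the `n_distinct_kegg_hits` column).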

In [5]:
from pyspark.sql import functions as F

kegg_list = list(metal_kegg)
kegg_spark = spark.createDataFrame([(k,) for k in kegg_list], ['ko_term'])

print(f'Querying pangenome for all {len(kegg_list)} metal KEGG terms via PySpark JOIN...')

# Step 1: Explode KEGG_ko and join with metal terms (PySpark DataFrame API)
eggnog_exploded = spark.table('kbase_ke_pangenome.eggnog_mapper_annotations') \
    .filter(F.col('KEGG_ko').isNotNull()) \
    .select('query_name', F.explode(F.split('KEGG_ko', ',')).alias('ko_raw')) \
    .withColumn('ko_term', F.trim(F.col('ko_raw')))

metal_hits = eggnog_exploded.join(F.broadcast(kegg_spark), 'ko_term') \
    .select(F.col('query_name').alias('gene_cluster_id'), 'ko_term') \
    .distinct()
metal_hits.createOrReplaceTempView('metal_kegg_hits')
n_hits = metal_hits.count()
print(f'Gene clusters matching metal KEGG terms: {n_hits:,}')

# Step 2: Count per species
species_metal_counts = spark.sql("""
    SELECT gc.gtdb_species_clade_id,
           COUNT(DISTINCT mkh.gene_cluster_id) as n_metal_clusters,
           COUNT(DISTINCT mkh.ko_term) as n_distinct_kegg_hits
    FROM metal_kegg_hits mkh
    JOIN kbase_ke_pangenome.gene_cluster gc ON mkh.gene_cluster_id = gc.gene_cluster_id
    GROUP BY gc.gtdb_species_clade_id
""")

# Step 3: Total annotated clusters per species (for normalization)
species_annotated = spark.sql("""
    SELECT gc.gtdb_species_clade_id,
           COUNT(DISTINCT gc.gene_cluster_id) as n_annotated_clusters
    FROM kbase_ke_pangenome.gene_cluster gc
    JOIN kbase_ke_pangenome.eggnog_mapper_annotations e ON gc.gene_cluster_id = e.query_name
    WHERE e.KEGG_ko IS NOT NULL
    GROUP BY gc.gtdb_species_clade_id
""")

pangenome_stats = spark.sql("""
    SELECT gtdb_species_clade_id,
           CAST(no_genomes AS INT) as no_genomes,
           CAST(no_gene_clusters AS INT) as no_gene_clusters,
           CAST(no_core AS INT) as no_core
    FROM kbase_ke_pangenome.pangenome
""").toPandas()

species_scores = species_metal_counts.toPandas()
sp_annot = species_annotated.toPandas()
print(f'Species with metal KEGG hits: {len(species_scores):,}')

species_scores = species_scores.merge(pangenome_stats, on='gtdb_species_clade_id', how='left')
species_scores = species_scores.merge(sp_annot, on='gtdb_species_clade_id', how='left')

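# Raw score: metal clusters per signature KO (a cluster count, so it can exceed 1)
species_scores['metal_score_raw'] = species_scores['n_metal_clusters'] / len(kegg_list)
# Normalized score: fraction of the species' KEGG-annotated clusters that hit the signature
species_scores['metal_score_norm'] = species_scores['n_metal_clusters'] / species_scores['n_annotated_clusters']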

species_scores = species_scores.sort_values('metal_score_norm', ascending=False)

print(f'\nTop 20 species by NORMALIZED metal score (genome-size corrected):')
print('=' * 110)
for _, row in species_scores.head(20).iterrows():
    species = row['gtdb_species_clade_id'].split('--')[0]
    print(f'  {species:50s}  norm={row.metal_score_norm:.4f}  '
          f'metal={int(row.n_metal_clusters)}/{int(row.n_annotated_clusters)} annotated  '
          f'genomes={int(row.no_genomes)}')

Querying pangenome for all 1286 metal KEGG terms via PySpark JOIN...


Gene clusters matching metal KEGG terms: 22,267,167


Species with metal KEGG hits: 27,702

Top 20 species by NORMALIZED metal score (genome-size corrected):
  s__Pantoea_A_carbekii                               norm=0.4970  metal=411/827 annotated  genomes=2
  s__Kinetoplastibacterium_blastocrithidii            norm=0.4620  metal=334/723 annotated  genomes=2
  s__Kinetoplastibacterium_crithidii                  norm=0.4609  metal=342/742 annotated  genomes=2
  s__Buchnera_aphidicola                              norm=0.4539  metal=261/575 annotated  genomes=2
  s__Buchnera_aphidicola_AO                           norm=0.4536  metal=269/593 annotated  genomes=2
  s__Buchnera_aphidicola_R                            norm=0.4528  metal=264/583 annotated  genomes=2
  s__Buchnera_aphidicola_C                            norm=0.4514  metal=265/587 annotated  genomes=5
  s__SoEE_sp002933335                                 norm=0.4513  metal=306/678 annotated  genomes=2
  s__Buchnera_aphidicola_U                            norm=0.4512  metal=259/574
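The Spark cell above runs three steps: explode `KEGG_ko`, keep signature KOs and count per species, then divide by the species' KEGG-annotated cluster count. A pure-pandas sketch of the same logic on made-up toy tables can be handy for checking the scoring without a Spark session. Table and column names mirror the query above, but the data, the `score_species` helper, and the two-KO signature are hypothetical:

```python
import pandas as pd

# Toy eggNOG rows: one row per gene cluster, KEGG_ko as a comma-separated string.
eggnog = pd.DataFrame({
    "query_name": ["c1", "c2", "c3", "c4", "c5"],
    "KEGG_ko": ["ko:K1,ko:K2", "ko:K2", "ko:K9", None, "ko:K1"],
})
# Toy cluster -> species table (stand-in for kbase_ke_pangenome.gene_cluster).
gene_cluster = pd.DataFrame({
    "gene_cluster_id": ["c1", "c2", "c3", "c4", "c5"],
    "gtdb_species_clade_id": ["sp1", "sp1", "sp1", "sp1", "sp2"],
})
signature = {"ko:K1", "ko:K2"}


def score_species(eggnog, gene_cluster, signature):
    annotated = eggnog[eggnog["KEGG_ko"].notna()]
    # One (cluster, KO) row per term, like F.explode(F.split(...)) + F.trim(...).
    exploded = annotated.assign(
        ko_term=annotated["KEGG_ko"].str.split(",")
    ).explode("ko_term", ignore_index=True)
    exploded["ko_term"] = exploded["ko_term"].str.strip()

    # Step 1: distinct (cluster, KO) hits against the signature.
    hits = exploded[exploded["ko_term"].isin(signature)][["query_name", "ko_term"]].drop_duplicates()
    hits = hits.merge(gene_cluster, left_on="query_name", right_on="gene_cluster_id")

    # Step 2: per-species counts of metal clusters and distinct KOs.
    counts = hits.groupby("gtdb_species_clade_id").agg(
        n_metal_clusters=("query_name", "nunique"),
        n_distinct_kegg_hits=("ko_term", "nunique"),
    )
    # Step 3: KEGG-annotated clusters per species (normalization denominator).
    denom = (
        annotated.merge(gene_cluster, left_on="query_name", right_on="gene_cluster_id")
        .groupby("gtdb_species_clade_id")["query_name"].nunique()
        .rename("n_annotated_clusters")
    )
    out = counts.join(denom, how="left")
    out["metal_score_raw"] = out["n_metal_clusters"] / len(signature)
    out["metal_score_norm"] = out["n_metal_clusters"] / out["n_annotated_clusters"]
    return out.reset_index()


print(score_species(eggnog, gene_cluster, signature))
```

In the toy data `sp2` outranks `sp1` on the normalized score (1.0 vs about 0.667) despite having fewer metal clusters, because its denominator is smaller.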

## 3. Validate Against Known Metal-Tolerant Species

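Check where species from the genera of known bioleaching organisms and ENIGMA
isolates rank in the metal tolerance score distribution. Matching is by genus
name within the GTDB species clade ID, so every species of a listed genus is
counted, not only the characterized strains.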
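The `match_genus` helper in the cell below tests whether a genus name appears anywhere in the clade ID. A stricter alternative, sketched here as a hypothetical option rather than what this notebook does, requires the genus to open the species name (`s__<Genus>_`), which still accepts GTDB split genera such as a `_E` suffix but ignores a genus name that happens to occur inside an epithet. The clade IDs below are made up:

```python
def match_genus_substring(clade_id, genera):
    """Same rule as the notebook's match_genus: genus name anywhere in the ID."""
    for g in genera:
        if g.lower() in clade_id.lower():
            return g
    return None


def match_genus_prefix(clade_id, genera):
    """Hypothetical stricter rule: genus must open the species name ('s__<Genus>_')."""
    species = clade_id.split("--")[0].lower()
    for g in genera:
        if species.startswith("s__" + g.lower() + "_"):
            return g
    return None


genera = ["Geobacter", "Pseudomonas", "Desulfovibrio"]
ids = [
    "s__Desulfovibrio_vulgaris--toy_1",
    "s__Pseudomonas_E_fluorescens--toy_2",  # split-genus style name; still matched
    "s__Paenibacillus_geobacteri--toy_3",   # made-up epithet containing 'geobacter'
]
for cid in ids:
    print(cid, match_genus_substring(cid, genera), match_genus_prefix(cid, genera))
```

Switching rules would change which species count as "bioleaching", so the counts and tests in this section would need to be rerun.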

In [6]:
# Known metal-relevant genera/species to look for
bioleaching_genera = [
    'Acidithiobacillus', 'Leptospirillum', 'Sulfobacillus',
    'Cupriavidus', 'Ralstonia',
    'Shewanella', 'Desulfovibrio', 'Geobacter',
    'Pseudomonas', 'Rhodanobacter',
    'Herbaspirillum', 'Caulobacter', 'Sphingomonas',
    'Gluconobacter',  # REE bioleaching
    'Marinobacter',
]

def match_genus(clade_id, genera):
    for g in genera:
        if g.lower() in clade_id.lower():
            return g
    return None

species_scores['matched_genus'] = species_scores['gtdb_species_clade_id'].apply(
    lambda x: match_genus(x, bioleaching_genera)
)

bioleaching = species_scores[species_scores['matched_genus'].notna()]
non_bioleaching = species_scores[species_scores['matched_genus'].isna()]

print(f'Bioleaching/metal-relevant species found: {len(bioleaching)}')
print(f'Other species: {len(non_bioleaching)}')

# Test with NORMALIZED score (genome-size corrected)
print(f'\nNormalized metal score comparison:')
print(f'  Bioleaching genera median: {bioleaching["metal_score_norm"].median():.4f}')
print(f'  All species median:        {species_scores["metal_score_norm"].median():.4f}')

from scipy import stats
u, p = stats.mannwhitneyu(
    bioleaching['metal_score_norm'], non_bioleaching['metal_score_norm'],
    alternative='greater'
)
print(f'  Mann-Whitney U: {u:.0f}, p={p:.3e} (bioleaching > other, normalized)')

# Also test raw score
u_raw, p_raw = stats.mannwhitneyu(
    bioleaching['metal_score_raw'], non_bioleaching['metal_score_raw'],
    alternative='greater'
)
print(f'  Mann-Whitney U: {u_raw:.0f}, p={p_raw:.3e} (bioleaching > other, raw)')

# Per-genus summary (normalized)
print(f'\nMetal scores by genus (normalized):')
genus_scores = bioleaching.groupby('matched_genus').agg(
    n_species=('metal_score_norm', 'count'),
    mean_norm=('metal_score_norm', 'mean'),
    median_norm=('metal_score_norm', 'median'),
    median_raw=('metal_score_raw', 'median'),
).sort_values('median_norm', ascending=False)

def ordinal(n):
    """Return n with its English ordinal suffix: 1st, 2nd, 11th, 91st, ..."""
    if 10 <= n % 100 <= 20:
        suffix = 'th'
    else:
        suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')
    return f'{n}{suffix}'

for _, row in genus_scores.iterrows():
    # Share of all species scoring below this genus' median normalized score
    percentile = (species_scores['metal_score_norm'] < row['median_norm']).mean() * 100
    print(f'  {row.name:20s}  {int(row.n_species):4d} species  '
          f'norm={row.median_norm:.4f}  raw={row.median_raw:.3f}  '
          f'({ordinal(round(percentile))} percentile)')

Bioleaching/metal-relevant species found: 838
Other species: 26864

Normalized metal score comparison:
  Bioleaching genera median: 0.2459
  All species median:        0.2443
  Mann-Whitney U: 11469426, p=1.746e-01 (bioleaching > other, normalized)
  Mann-Whitney U: 19848543, p=0.000e+00 (bioleaching > other, raw)

Metal scores by genus (normalized):
  Leptospirillum           4 species  norm=0.3143  raw=0.641  (91st percentile)
  Acidithiobacillus       14 species  norm=0.2805  raw=0.729  (77th percentile)
  Marinobacter            48 species  norm=0.2777  raw=0.850  (75th percentile)
  Sulfobacillus           14 species  norm=0.2725  raw=0.799  (71st percentile)
  Herbaspirillum          13 species  norm=0.2725  raw=1.034  (71st percentile)
  Desulfovibrio           54 species  norm=0.2683  raw=0.626  (68th percentile)
  Gluconobacter           16 species  norm=0.2587  raw=0.661  (61st percentile)
  Rhodanobacter           17 species  norm=0.2560  raw=0.698  (59th percentile)
  Shewa
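The U statistics above are easier to read as a common-language effect size, U / (n₁·n₂): the probability that a randomly chosen bioleaching-genus species outscores a randomly chosen other species (ties counted as one half). This sketch assumes `u_stat` is the U for the first sample, which is what `scipy.stats.mannwhitneyu` returns in SciPy 1.7 and later; `common_language_effect` is a hypothetical helper name:

```python
from scipy import stats


def common_language_effect(u_stat, n_x, n_y):
    """U / (n_x * n_y): P(random x > random y), ties counted as 1/2.

    Assumes u_stat is the U statistic of the first sample (SciPy >= 1.7 behaviour).
    """
    return u_stat / (n_x * n_y)


# U values printed above: 838 bioleaching-genus species vs 26,864 other species.
print(round(common_language_effect(11_469_426, 838, 26_864), 3))  # normalized score
print(round(common_language_effect(19_848_543, 838, 26_864), 3))  # raw score

# Sanity check on toy data where every x exceeds every y: the effect size is 1.
u, p = stats.mannwhitneyu([5, 6, 7], [1, 2, 3], alternative="greater")
print(common_language_effect(u, 3, 3))
```

On these numbers the normalized score gives about 0.51 (essentially no separation) and the raw score about 0.88, matching the contrast between the two p-values.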

## 4. Figures

In [7]:
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Left: Normalized score distribution
ax = axes[0]
ax.hist(non_bioleaching['metal_score_norm'], bins=50, color='gray', alpha=0.6,
        label='Other species', edgecolor='black', linewidth=0.3)
ax.hist(bioleaching['metal_score_norm'], bins=50, color='#e74c3c', alpha=0.8,
        label='Bioleaching genera', edgecolor='black', linewidth=0.3)
ax.set_xlabel('Normalized Metal Score (metal clusters / annotated clusters)')
ax.set_ylabel('Number of Species')
ax.set_title(f'Metal Tolerance Score Distribution ({len(species_scores):,} species)')
ax.legend()

# Right: Box plot by genus (normalized)
ax = axes[1]
top_genera = genus_scores.head(10).index.tolist()
plot_data = bioleaching[bioleaching['matched_genus'].isin(top_genera)]
if len(plot_data) > 0:
    genus_order = plot_data.groupby('matched_genus')['metal_score_norm'].median().sort_values(ascending=False).index
    sns.boxplot(data=plot_data, x='matched_genus', y='metal_score_norm',
                order=genus_order, ax=ax, color='#3498db')
    ax.axhline(species_scores['metal_score_norm'].median(), color='red',
               linestyle='--', alpha=0.7, label='Global median')
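    # Rotate the existing tick labels in place (avoids the set_xticklabels FixedFormatter warning)
    ax.tick_params(axis='x', labelrotation=45)
    plt.setp(ax.get_xticklabels(), ha='right')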
    ax.set_ylabel('Normalized Metal Score')
    ax.set_title('Metal Scores by Genus (genome-size normalized)')
    ax.legend()

plt.suptitle('Pangenome-Scale Metal Tolerance Prediction', fontsize=14, y=1.02)
plt.tight_layout()
fig.savefig(FIGURES_DIR / 'species_metal_score_distribution.png', dpi=150, bbox_inches='tight')
plt.show()
print(f'Saved: figures/species_metal_score_distribution.png')



Saved: figures/species_metal_score_distribution.png
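In the left panel, the two `ax.hist` calls each choose their own 50 bins and plot raw counts, so the 838 bioleaching-genus species sit on a different grid and a much smaller vertical scale than the ~26,864 others. One common remedy, sketched here with NumPy on made-up normal samples, is to compute a single set of bin edges over both groups and plot densities; the same `edges` could then be passed to `ax.hist(..., bins=edges, density=True)`:

```python
import numpy as np

rng = np.random.default_rng(0)
# Made-up stand-ins: many "other" species and a small focal group.
other = rng.normal(0.24, 0.03, size=5000)
focal = rng.normal(0.26, 0.03, size=150)

# One set of bin edges over both groups, so the bars line up.
edges = np.histogram_bin_edges(np.concatenate([other, focal]), bins=50)
dens_other, _ = np.histogram(other, bins=edges, density=True)
dens_focal, _ = np.histogram(focal, bins=edges, density=True)

# With density=True each histogram integrates to 1 over the shared range,
# so the small group is visible on the same axis as the large one.
widths = np.diff(edges)
print(round(float((dens_other * widths).sum()), 6),
      round(float((dens_focal * widths).sum()), 6))
```

The y-axis then reads as density rather than "Number of Species", so the axis label would need to change accordingly.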


In [8]:
# Bioleaching-focused: key genera
key_genera = ['Acidithiobacillus', 'Leptospirillum', 'Cupriavidus',
              'Shewanella', 'Desulfovibrio', 'Geobacter']
key_species = bioleaching[bioleaching['matched_genus'].isin(key_genera)].copy()
key_species['short_name'] = key_species['gtdb_species_clade_id'].apply(
    lambda x: x.split('--')[0].replace('s__', '')
)

if len(key_species) > 0:
    fig, ax = plt.subplots(figsize=(12, max(6, len(key_species) * 0.3)))
    key_species_sorted = key_species.sort_values('metal_score_norm')
    colors = [{'Acidithiobacillus': '#e74c3c', 'Leptospirillum': '#e67e22',
               'Cupriavidus': '#2ecc71', 'Shewanella': '#3498db',
               'Desulfovibrio': '#9b59b6', 'Geobacter': '#1abc9c'
              }.get(g, 'gray') for g in key_species_sorted['matched_genus']]

    ax.barh(range(len(key_species_sorted)), key_species_sorted['metal_score_norm'],
            color=colors, edgecolor='black', linewidth=0.3)
    ax.set_yticks(range(len(key_species_sorted)))
    ax.set_yticklabels(key_species_sorted['short_name'], fontsize=7)
    ax.axvline(species_scores['metal_score_norm'].median(), color='gray',
               linestyle='--', alpha=0.7, label='Global median')
    ax.set_xlabel('Normalized Metal Score')
    ax.set_title('Metal Tolerance Scores: Key Bioleaching & Metal-Reducing Species')
    ax.legend()
    plt.tight_layout()
    fig.savefig(FIGURES_DIR / 'bioleaching_species_scores.png', dpi=150, bbox_inches='tight')
    plt.show()
    print(f'Saved: figures/bioleaching_species_scores.png')
else:
    print('No key bioleaching species found in pangenome')

Saved: figures/bioleaching_species_scores.png


## 5. Save Results

In [9]:
species_scores.to_csv(DATA_DIR / 'species_metal_scores.csv', index=False)
print(f'Saved: data/species_metal_scores.csv ({len(species_scores):,} species)')

print('\n' + '=' * 80)
print('NB06 SUMMARY: Pangenome-Scale Metal Tolerance Prediction')
print('=' * 80)
print(f'Metal functional signature: {len(metal_kegg)} KEGG terms (full, no truncation)')
print(f'Gene clusters matching: {n_hits:,}')
print(f'Species scored: {len(species_scores):,}')
print(f'Normalized score range: {species_scores["metal_score_norm"].min():.4f} - '
      f'{species_scores["metal_score_norm"].max():.4f}')
print(f'Bioleaching genera: {len(bioleaching)} species')
print(f'Bioleaching vs other (normalized): Mann-Whitney p={p:.3e}')
print('=' * 80)

Saved: data/species_metal_scores.csv (27,702 species)

NB06 SUMMARY: Pangenome-Scale Metal Tolerance Prediction
Metal functional signature: 1286 KEGG terms (full, no truncation)
Gene clusters matching: 22,267,167
Species scored: 27,702
Normalized score range: 0.0657 - 0.4970
Bioleaching genera: 838 species
Bioleaching vs other (normalized): Mann-Whitney p=1.746e-01
