# NB01: Data Extraction for Co-fitness Co-inheritance Analysis

Extract genome × gene cluster presence matrices, co-fitness pairs, gene coordinates,
and phylogenetic distances for 11 target organisms.

**Requires BERDL JupyterHub** — `get_spark_session()` must be available.

This notebook is the interactive equivalent of `src/extract_data.py`.
For batch extraction, run the script directly.

In [1]:
import pandas as pd
import numpy as np
from pathlib import Path

# Import Spark session — works on JupyterHub and locally
try:
    get_spark_session
except NameError:
    from berdl_notebook_utils.setup_spark_session import get_spark_session

spark = get_spark_session()
print(f"Spark version: {spark.version}")

DATA_DIR = Path('../data')
CONS_DIR = Path('../../conservation_vs_fitness/data')

for subdir in ['genome_cluster_matrices', 'cofit', 'gene_coords', 'phylo_distances']:
    (DATA_DIR / subdir).mkdir(parents=True, exist_ok=True)

# Target organisms and their species clades
TARGET_ORGANISMS = {
    'Koxy': 's__Klebsiella_michiganensis--RS_GCF_002925905.1',
    'Btheta': 's__Bacteroides_thetaiotaomicron--RS_GCF_000011065.1',
    'Smeli': 's__Sinorhizobium_meliloti--RS_GCF_017876815.1',
    'RalstoniaUW163': 's__Ralstonia_solanacearum--RS_GCF_002251695.1',
    'Putida': 's__Pseudomonas_E_alloputida--RS_GCF_021282585.1',
    'SyringaeB728a': 's__Pseudomonas_E_syringae_M--RS_GCF_009176725.1',
    'Korea': 's__Sphingomonas_koreensis--RS_GCF_002797435.1',
    'RalstoniaGMI1000': 's__Ralstonia_pseudosolanacearum--RS_GCF_024925465.1',
    'Phaeo': 's__Phaeobacter_inhibens--RS_GCF_000473105.1',
    'Ddia6719': 's__Dickeya_dianthicola--RS_GCF_000365305.1',
    'pseudo3_N2E3': 's__Pseudomonas_E_fluorescens_E--RS_GCF_001307155.1',
}

# Load shared data
link = pd.read_csv(CONS_DIR / 'fb_pangenome_link.tsv', sep='\t')
link = link[link['orgId'] != 'Dyella79']
org_mapping = pd.read_csv(CONS_DIR / 'organism_mapping.tsv', sep='\t')

print(f"Link table: {len(link):,} rows, {link['orgId'].nunique()} organisms")
print(f"Target organisms: {len(TARGET_ORGANISMS)}")

Spark version: 4.0.1


Link table: 173,582 rows, 43 organisms
Target organisms: 11


## Step 0: Verify Target Species Pangenome Stats

In [2]:
clade_ids = list(TARGET_ORGANISMS.values())
clade_str = "','".join(clade_ids)

stats = spark.sql(f"""
    SELECT p.gtdb_species_clade_id,
           s.GTDB_species,
           p.no_genomes,
           p.no_core,
           p.no_aux_genome,
           p.no_singleton_gene_clusters,
           p.no_gene_clusters,
           s.mean_intra_species_ANI
    FROM kbase_ke_pangenome.pangenome p
    JOIN kbase_ke_pangenome.gtdb_species_clade s
        ON p.gtdb_species_clade_id = s.gtdb_species_clade_id
    WHERE p.gtdb_species_clade_id IN ('{clade_str}')
    ORDER BY p.no_genomes DESC
""").toPandas()

# Add orgId column
clade_to_org = {v: k for k, v in TARGET_ORGANISMS.items()}
stats['orgId'] = stats['gtdb_species_clade_id'].map(clade_to_org)

# Count FB genes per organism
fb_counts = link.groupby('orgId').agg(
    n_fb_genes=('locusId', 'nunique'),
    n_aux_fb=('is_auxiliary', lambda x: (x == True).sum())
).reset_index()
stats = stats.merge(fb_counts, on='orgId', how='left')

print(stats[['orgId', 'GTDB_species', 'no_genomes', 'mean_intra_species_ANI',
             'no_gene_clusters', 'n_fb_genes', 'n_aux_fb']].to_string(index=False))

           orgId                    GTDB_species  no_genomes  mean_intra_species_ANI  no_gene_clusters  n_fb_genes  n_aux_fb
            Koxy     s__Klebsiella_michiganensis         399                   98.57             61735        4965       822
          Btheta s__Bacteroides_thetaiotaomicron         287                   98.44             65634        4727      1632
           Smeli       s__Sinorhizobium_meliloti         241                   98.93             58199        6123      1365
  RalstoniaUW163       s__Ralstonia_solanacearum         141                   96.27             23007        4303       867
          Putida     s__Pseudomonas_E_alloputida         128                   97.49             42747        5470      1372
   SyringaeB728a     s__Pseudomonas_E_syringae_M         126                   98.70             29917        5031       723
           Korea       s__Sphingomonas_koreensis          72                   98.13              7633        4116       637


## Step 1: Extract Genome × Gene Cluster Presence Matrices

For each species, build a binary matrix: rows = genomes, columns = gene clusters.
Only include clusters that FB genes map to (from `fb_pangenome_link.tsv`).

**Performance note**: Each organism requires joining `gene_genecluster_junction` (~1B rows)
with `gene` (~1B rows). BROADCAST hints on the small filter tables reduce join cost.
Expect ~3-5 min per organism, ~45 min total. Already-cached matrices are skipped.

In [3]:
import time

matrix_summary = []

for orgId, clade_id in TARGET_ORGANISMS.items():
    outpath = DATA_DIR / 'genome_cluster_matrices' / f'{orgId}_presence.tsv'
    if outpath.exists() and outpath.stat().st_size > 0:
        cached = pd.read_csv(outpath, sep='\t', index_col=0)
        matrix_summary.append({'orgId': orgId, 'genomes': cached.shape[0],
                               'clusters': cached.shape[1], 'status': 'cached'})
        print(f"  [{orgId}] Cached: {cached.shape[0]} genomes x {cached.shape[1]} clusters")
        continue

    print(f"  [{orgId}] Extracting...", flush=True)
    t0 = time.time()

    # Get target cluster IDs
    org_clusters = link[link['gtdb_species_clade_id'] == clade_id]['gene_cluster_id'].unique()
    if len(org_clusters) == 0:
        org_clusters = link[link['orgId'] == orgId]['gene_cluster_id'].unique()
    print(f"    Target clusters: {len(org_clusters)}")

    if len(org_clusters) == 0:
        print(f"    WARNING: No clusters for {orgId}, skipping")
        continue

    # Register small filter tables for BROADCAST joins
    cluster_df = spark.createDataFrame([(c,) for c in org_clusters], ['gene_cluster_id'])
    cluster_df.createOrReplaceTempView('target_clusters')

    genome_ids = spark.sql(f"""
        SELECT genome_id FROM kbase_ke_pangenome.genome
        WHERE gtdb_species_clade_id = '{clade_id}'
    """).toPandas()['genome_id'].tolist()
    genome_df = spark.createDataFrame([(g,) for g in genome_ids], ['genome_id'])
    genome_df.createOrReplaceTempView('target_genomes')

    # Use BROADCAST hints on small tables to avoid shuffle joins
    presence = spark.sql("""
        SELECT /*+ BROADCAST(tc), BROADCAST(tg) */
            DISTINCT g.genome_id, j.gene_cluster_id
        FROM kbase_ke_pangenome.gene_genecluster_junction j
        JOIN target_clusters tc ON j.gene_cluster_id = tc.gene_cluster_id
        JOIN kbase_ke_pangenome.gene g ON j.gene_id = g.gene_id
        JOIN target_genomes tg ON g.genome_id = tg.genome_id
    """).toPandas()

    elapsed = time.time() - t0
    print(f"    Raw presence rows: {len(presence):,} ({elapsed:.0f}s)")

    if len(presence) == 0:
        print(f"    WARNING: No data for {orgId}")
        continue

    presence['present'] = 1
    matrix = presence.pivot_table(
        index='genome_id', columns='gene_cluster_id',
        values='present', fill_value=0, aggfunc='max'
    )

    matrix_summary.append({'orgId': orgId, 'genomes': matrix.shape[0],
                           'clusters': matrix.shape[1], 'status': 'extracted'})
    print(f"    Matrix: {matrix.shape[0]} genomes x {matrix.shape[1]} clusters")
    matrix.to_csv(outpath, sep='\t')

print("\n=== MATRIX SUMMARY ===")
print(pd.DataFrame(matrix_summary).to_string(index=False))

  [Koxy] Cached: 399 genomes x 4942 clusters


  [Btheta] Cached: 287 genomes x 4649 clusters


  [Smeli] Cached: 241 genomes x 6004 clusters
  [RalstoniaUW163] Cached: 141 genomes x 4413 clusters


  [Putida] Cached: 128 genomes x 5409 clusters
  [SyringaeB728a] Extracting...


    Target clusters: 4999


    Raw presence rows: 558,868 (218s)
    Matrix: 126 genomes x 4999 clusters


  [Korea] Extracting...


    Target clusters: 4075


    Raw presence rows: 254,177 (211s)
    Matrix: 72 genomes x 4075 clusters
  [RalstoniaGMI1000] Extracting...


    Target clusters: 4723


    Raw presence rows: 285,319 (220s)
    Matrix: 70 genomes x 4723 clusters
  [Phaeo] Extracting...


    Target clusters: 3790


    Raw presence rows: 145,722 (209s)
    Matrix: 43 genomes x 3790 clusters
  [Ddia6719] Extracting...


    Target clusters: 4694


    Raw presence rows: 256,026 (209s)
    Matrix: 66 genomes x 4694 clusters
  [pseudo3_N2E3] Extracting...


    Target clusters: 5513


    Raw presence rows: 214,589 (211s)
    Matrix: 40 genomes x 5513 clusters

=== MATRIX SUMMARY ===
           orgId  genomes  clusters    status
            Koxy      399      4942    cached
          Btheta      287      4649    cached
           Smeli      241      6004    cached
  RalstoniaUW163      141      4413    cached
          Putida      128      5409    cached
   SyringaeB728a      126      4999 extracted
           Korea       72      4075 extracted
RalstoniaGMI1000       70      4723 extracted
           Phaeo       43      3790 extracted
        Ddia6719       66      4694 extracted
    pseudo3_N2E3       40      5513 extracted


## Step 2: Extract Co-fitness Pairs

In [4]:
cofit_summary = []

for orgId in TARGET_ORGANISMS:
    outpath = DATA_DIR / 'cofit' / f'{orgId}_cofit.tsv'
    if outpath.exists() and outpath.stat().st_size > 0:
        cached = pd.read_csv(outpath, sep='\t')
        cofit_summary.append({'orgId': orgId, 'pairs': len(cached), 'status': 'cached'})
        print(f"  [{orgId}] Cached: {len(cached):,} pairs")
        continue

    print(f"  [{orgId}] Extracting...", end='', flush=True)
    cofit = spark.sql(f"""
        SELECT orgId, locusId, hitId,
               CAST(rank AS INT) as rank,
               CAST(cofit AS FLOAT) as cofit
        FROM kescience_fitnessbrowser.cofit
        WHERE orgId = '{orgId}'
        ORDER BY locusId, CAST(rank AS INT)
    """).toPandas()
    print(f" {len(cofit):,} pairs")

    cofit.to_csv(outpath, sep='\t', index=False)
    cofit_summary.append({'orgId': orgId, 'pairs': len(cofit), 'status': 'extracted'})

print("\n=== COFIT SUMMARY ===")
print(pd.DataFrame(cofit_summary).to_string(index=False))

  [Koxy] Extracting...

 423,936 pairs


  [Btheta] Extracting...

 328,455 pairs


  [Smeli] Extracting...

 528,699 pairs


  [RalstoniaUW163] Extracting...

 0 pairs
  [Putida] Extracting...

 458,688 pairs


  [SyringaeB728a] Extracting...

 371,004 pairs


  [Korea] Extracting...

 230,724 pairs


  [RalstoniaGMI1000] Extracting...

 0 pairs
  [Phaeo] Extracting...

 192,138 pairs


  [Ddia6719] Extracting...

 250,488 pairs


  [pseudo3_N2E3] Extracting...

 507,828 pairs



=== COFIT SUMMARY ===
           orgId  pairs    status
            Koxy 423936 extracted
          Btheta 328455 extracted
           Smeli 528699 extracted
  RalstoniaUW163      0 extracted
          Putida 458688 extracted
   SyringaeB728a 371004 extracted
           Korea 230724 extracted
RalstoniaGMI1000      0 extracted
           Phaeo 192138 extracted
        Ddia6719 250488 extracted
    pseudo3_N2E3 507828 extracted


## Step 3: Extract Gene Coordinates

In [5]:
for orgId in TARGET_ORGANISMS:
    outpath = DATA_DIR / 'gene_coords' / f'{orgId}_coords.tsv'
    if outpath.exists() and outpath.stat().st_size > 0:
        print(f"  [{orgId}] Cached")
        continue

    print(f"  [{orgId}] Extracting...", end='', flush=True)
    coords = spark.sql(f"""
        SELECT orgId, locusId, scaffoldId,
               CAST(begin AS INT) as begin,
               CAST(end AS INT) as end,
               strand
        FROM kescience_fitnessbrowser.gene
        WHERE orgId = '{orgId}'
        ORDER BY scaffoldId, CAST(begin AS INT)
    """).toPandas()
    print(f" {len(coords):,} genes")
    coords.to_csv(outpath, sep='\t', index=False)

  [Koxy] Extracting...

 5,586 genes
  [Btheta] Extracting...

 4,902 genes
  [Smeli] Extracting...

 6,281 genes
  [RalstoniaUW163] Extracting...

 5,006 genes
  [Putida] Extracting...

 5,661 genes
  [SyringaeB728a] Extracting...

 5,216 genes
  [Korea] Extracting...

 4,245 genes
  [RalstoniaGMI1000] Extracting...

 5,204 genes
  [Phaeo] Extracting...

 3,944 genes
  [Ddia6719] Extracting...

 4,338 genes
  [pseudo3_N2E3] Extracting...

 5,854 genes


## Step 4: Extract Phylogenetic Distances

In [6]:
clade_str = "','".join(TARGET_ORGANISMS.values())

tree_mapping = spark.sql(f"""
    SELECT gtdb_species_clade_id, phylogenetic_tree_id
    FROM kbase_ke_pangenome.phylogenetic_tree
    WHERE gtdb_species_clade_id IN ('{clade_str}')
""").toPandas()

clade_to_org = {v: k for k, v in TARGET_ORGANISMS.items()}
print(f"Species with phylogenetic trees: {len(tree_mapping)}/{len(TARGET_ORGANISMS)}")

for _, row in tree_mapping.iterrows():
    clade_id = row['gtdb_species_clade_id']
    tree_id = row['phylogenetic_tree_id']
    orgId = clade_to_org.get(clade_id)
    if orgId is None:
        continue

    outpath = DATA_DIR / 'phylo_distances' / f'{orgId}_phylo_distances.tsv'
    if outpath.exists() and outpath.stat().st_size > 0:
        print(f"  [{orgId}] Cached")
        continue

    print(f"  [{orgId}] Extracting...", end='', flush=True)
    distances = spark.sql(f"""
        SELECT genome1_id, genome2_id, branch_distance
        FROM kbase_ke_pangenome.phylogenetic_tree_distance_pairs
        WHERE phylogenetic_tree_id = '{tree_id}'
    """).toPandas()
    print(f" {len(distances):,} pairs")
    distances.to_csv(outpath, sep='\t', index=False)

# Save reference genome mapping
ref_genomes = org_mapping[
    org_mapping['orgId'].isin(TARGET_ORGANISMS.keys())
][['orgId', 'gtdb_species_clade_id', 'pg_genome_id']].drop_duplicates()
ref_genomes.to_csv(DATA_DIR / 'phylo_distances' / 'reference_genomes.tsv',
                   sep='\t', index=False)
print(f"\nReference genomes saved: {len(ref_genomes)} rows")

Species with phylogenetic trees: 9/11
  [Koxy] Extracting...

 79,401 pairs
  [Btheta] Extracting...

 41,041 pairs
  [Smeli] Extracting...

 28,920 pairs
  [RalstoniaGMI1000] Extracting...

 2,415 pairs
  [RalstoniaUW163] Extracting...

 9,870 pairs
  [Putida] Extracting...

 8,128 pairs
  [SyringaeB728a] Extracting...

 7,875 pairs
  [Ddia6719] Extracting...

 2,145 pairs
  [Korea] Extracting...

 2,556 pairs

Reference genomes saved: 95 rows


In [7]:
print('=' * 60)
print('NB01 SUMMARY: Data Extraction')
print('=' * 60)
print(f'Target organisms: {len(TARGET_ORGANISMS)}')
print(f'Species with phylogenetic trees: {len(tree_mapping)}')
print(f'\nOutput directories:')
for subdir in ['genome_cluster_matrices', 'cofit', 'gene_coords', 'phylo_distances']:
    files = list((DATA_DIR / subdir).glob('*.tsv'))
    print(f'  {subdir}/: {len(files)} files')
print('=' * 60)

NB01 SUMMARY: Data Extraction
Target organisms: 11
Species with phylogenetic trees: 9

Output directories:
  genome_cluster_matrices/: 11 files
  cofit/: 11 files
  gene_coords/: 11 files
  phylo_distances/: 10 files
