# NB01: Data Extraction — Ecotype Species Screening

**Run on**: BERDL JupyterHub (Spark available via `get_spark_session()`)

## Goal

Extract and summarize three data dimensions for all 338 species that have phylogenetic tree data:

1. **Phylogenetic substructure** — per-species branch distance statistics from `phylogenetic_tree_distance_pairs`
2. **Environmental diversity** — per-species environmental category counts and entropy from `nmdc_ncbi_biosamples.env_triads_flattened`
3. **Pangenome openness** — core/accessory/singleton counts from `pangenome`

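All heavy aggregations run on Spark. Only compact tables are brought to the driver via `.toPandas()`: the per-species summaries (~338 rows each), the genome → biosample map (~90K rows), and the per-species environment category counts.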

## Outputs

| File | Description | Rows |
|------|-------------|------|
| `species_tree_list.csv` | All tree species with `phylogenetic_tree_id` and pangenome stats | ~338 |
| `species_phylo_stats.csv` | Branch distance statistics per species | ~338 |
| `species_pangenome_stats.csv` | Pangenome openness metrics per species | ~338 |
| `species_env_stats.csv` | Environmental diversity metrics per species | ~338 |
| `species_env_category_counts.csv` | Genome counts per (species, `env_broad_scale` category) pair | one per pair |
| `genome_biosample_map.csv` | `genome_id` → `ncbi_biosample_accession_id` for tree species | ~90K |

## Key Pitfalls

- **Genome ID format in `phylogenetic_tree_distance_pairs`**: IDs may be bare accessions (no `RS_`/`GB_` prefix). Always verify before joining to the `genome` table.
- **Cross-database joins**: Use `nmdc_ncbi_biosamples.env_triads_flattened` with full `database.table` notation.
- **`--` in species IDs**: Avoid placing these in SQL `IN()` clauses. Use Spark DataFrame joins instead.
- **Never `.toPandas()` on large tables**: Aggregate on Spark, collect only summary rows.

In [None]:
# Cell 1: Setup

import os
import numpy as np
import pandas as pd
from pyspark.sql import functions as F

# On JupyterHub: spark is available via get_spark_session() with no import
spark = get_spark_session()

OUTPUT_PATH = "../data"
os.makedirs(OUTPUT_PATH, exist_ok=True)

print(f"Spark version: {spark.version}")
print(f"Output path: {os.path.abspath(OUTPUT_PATH)}")

---
## STEP 1: Load Tree Species Universe

In [None]:
# Cell 2: Load all species with phylogenetic trees + pangenome stats

tree_species_df = spark.sql("""
    SELECT
        pt.gtdb_species_clade_id,
        pt.phylogenetic_tree_id,
        sc.GTDB_species,
        p.no_genomes,
        p.no_core,
        p.no_aux_genome,
        p.no_singleton_gene_clusters,
        p.no_gene_clusters,
        sc.mean_intra_species_ANI,
        sc.min_intra_species_ANI
    FROM kbase_ke_pangenome.phylogenetic_tree pt
    JOIN kbase_ke_pangenome.gtdb_species_clade sc
        ON pt.gtdb_species_clade_id = sc.gtdb_species_clade_id
    JOIN kbase_ke_pangenome.pangenome p
        ON pt.gtdb_species_clade_id = p.gtdb_species_clade_id
""")

tree_species_pd = tree_species_df.toPandas()

# Derived pangenome metrics
tree_species_pd['singleton_fraction'] = (
    tree_species_pd['no_singleton_gene_clusters'] / tree_species_pd['no_gene_clusters']
)
tree_species_pd['core_fraction'] = (
    tree_species_pd['no_core'] / tree_species_pd['no_gene_clusters']
)
tree_species_pd['aux_fraction'] = (
    tree_species_pd['no_aux_genome'] / tree_species_pd['no_gene_clusters']
)

print(f"Tree species loaded: {len(tree_species_pd)}")
print(f"Genome count range: {tree_species_pd['no_genomes'].min()} – {tree_species_pd['no_genomes'].max()}")
print(f"Genome count median: {tree_species_pd['no_genomes'].median():.0f}")
print(f"\nTop 10 by genome count:")
print(tree_species_pd.nlargest(10, 'no_genomes')[['GTDB_species', 'no_genomes', 'singleton_fraction']].to_string())

In [None]:
# Cell 3: Save species tree list and pangenome stats

tree_species_pd.to_csv(f"{OUTPUT_PATH}/species_tree_list.csv", index=False)
print(f"Saved species_tree_list.csv ({len(tree_species_pd)} species)")

# Also save pangenome stats separately for downstream notebooks
pangenome_cols = [
    'gtdb_species_clade_id', 'GTDB_species', 'no_genomes', 'no_core',
    'no_aux_genome', 'no_singleton_gene_clusters', 'no_gene_clusters',
    'singleton_fraction', 'core_fraction', 'aux_fraction',
    'mean_intra_species_ANI', 'min_intra_species_ANI'
]
tree_species_pd[pangenome_cols].to_csv(f"{OUTPUT_PATH}/species_pangenome_stats.csv", index=False)
print(f"Saved species_pangenome_stats.csv")

---
## STEP 2: Phylogenetic Substructure Statistics

Compute branch distance statistics entirely on Spark.
The full table has 22.6M rows — we never `.toPandas()` it directly.
We aggregate to ~338 rows before collecting.

In [None]:
# Cell 4: Check genome ID format in phylogenetic_tree_distance_pairs
# IMPORTANT: IDs here may be bare accessions (no RS_/GB_ prefix)

sample_pairs = spark.sql("""
    SELECT phylogenetic_tree_id, genome1_id, genome2_id, branch_distance
    FROM kbase_ke_pangenome.phylogenetic_tree_distance_pairs
    LIMIT 5
""").toPandas()

print("Sample rows from phylogenetic_tree_distance_pairs:")
print(sample_pairs.to_string())

# Check prefix across all sampled IDs (both columns), not just the first row
sampled_ids = pd.concat([sample_pairs['genome1_id'], sample_pairs['genome2_id']])
print(f"\nSample genome1_id: {sample_pairs['genome1_id'].iloc[0]}")
n_prefixed = sampled_ids.str.match(r'(RS|GB)_').sum()  # str.match anchors at the start
print(f"IDs starting with RS_ or GB_: {n_prefixed}/{len(sampled_ids)}")
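Besides the ID format, it is worth checking whether each unordered genome pair is stored once or in both orientations, because Cell 6 counts genomes with `COUNT(DISTINCT genome1_id)`. The plain-Python sketch below uses toy IDs and a hypothetical helper, `count_genomes`, to show the difference: with one row per pair, a genome that only ever appears as `genome2_id` is missed, while the union of both columns gives the true count.

```python
from itertools import combinations, permutations

def count_genomes(pairs):
    """Return (distinct genome1 IDs, distinct IDs across both columns)."""
    genome1_only = {g1 for g1, _ in pairs}
    both_columns = {g for pair in pairs for g in pair}
    return len(genome1_only), len(both_columns)

genomes = ["GCF_A", "GCF_B", "GCF_C"]  # toy IDs, not real accessions

# Layout 1: each unordered pair stored once (genome1 < genome2)
once = list(combinations(genomes, 2))
# Layout 2: each pair stored in both orientations
both = list(permutations(genomes, 2))

print(count_genomes(once))  # (2, 3): the genome1-only count misses GCF_C
print(count_genomes(both))  # (3, 3)
```

On the real table, comparing a tree's `n_pairs` with n(n-1)/2 (stored once) or n(n-1) (both orientations) for its n genomes indicates which layout applies; self-pairs, if present, add n more rows.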

In [None]:
# Cell 5: Determine correct join key for genome IDs in distance table
# If IDs lack prefix, we need to strip RS_/GB_ from genome.genome_id before joining

# Check alignment between phylogenetic_tree_distance_pairs genome IDs
# and kbase_ke_pangenome.genome genome_ids

# Get one species to test
test_tree_id = tree_species_pd['phylogenetic_tree_id'].iloc[0]
test_species_id = tree_species_pd['gtdb_species_clade_id'].iloc[0]

# Get genome IDs from distance table for this species
dist_genome_ids = spark.sql(f"""
    SELECT DISTINCT genome1_id
    FROM kbase_ke_pangenome.phylogenetic_tree_distance_pairs
    WHERE phylogenetic_tree_id = '{test_tree_id}'
    LIMIT 5
""").toPandas()

# Get genome IDs from genome table for same species
genome_table_ids = spark.sql(f"""
    SELECT genome_id
    FROM kbase_ke_pangenome.genome
    WHERE gtdb_species_clade_id = '{test_species_id}'
    LIMIT 5
""").toPandas()

print(f"Test species: {test_species_id[:60]}")
print(f"\ngenome1_id values from distance table: {dist_genome_ids['genome1_id'].tolist()}")
print(f"genome_id values from genome table:    {genome_table_ids['genome_id'].tolist()}")
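If the two printouts above differ only by an `RS_`/`GB_` prefix, the IDs need to be normalized on one side before joining. A minimal plain-Python sketch of that normalization follows; the helper name is hypothetical and the IDs are illustrative. On Spark, the same rule could be applied column-wise with `F.regexp_replace('genome_id', r'^(RS|GB)_', '')`.

```python
import re

# Matches a leading GTDB source prefix: RS_ (RefSeq) or GB_ (GenBank)
_GTDB_PREFIX = re.compile(r"^(RS|GB)_")

def strip_gtdb_prefix(genome_id: str) -> str:
    """Return genome_id without a leading RS_/GB_ prefix; bare IDs pass through unchanged."""
    return _GTDB_PREFIX.sub("", genome_id)

# Illustrative IDs only
for gid in ["RS_GCF_000005845.2", "GB_GCA_012345678.1", "GCF_000005845.2"]:
    print(gid, "->", strip_gtdb_prefix(gid))
```

Anchoring the pattern with `^` means an `RS_` or `GB_` substring elsewhere in an ID is never touched.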

In [None]:
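# Cell 6: Compute per-species branch distance statistics on Spark
#
# Metrics computed per species (grouped by phylogenetic_tree_id):
#   - n_genomes_in_tree: count of distinct genome1_id values. If each unordered
#     pair is stored only once, a genome that never appears as genome1_id is
#     missed, so this may undercount; check the pair layout before relying on it.
#   - n_pairs: total pairwise comparisons (rows)
#   - mean, std (sample), min, max, q25, median (q50), q75, q90, IQR (q75 - q25) of branch_distance
#   - cv (coefficient of variation = std / mean): high CV suggests substructure
#   - max_median_ratio (max / median): large values suggest a few deeply divergent genomes or clades
# PERCENTILE is exact; percentile_approx is a cheaper alternative if this is slow on 22.6M rows.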

phylo_stats_df = spark.sql("""
    SELECT
        phylogenetic_tree_id,
        COUNT(DISTINCT genome1_id)                                        AS n_genomes_in_tree,
        COUNT(*)                                                           AS n_pairs,
        AVG(branch_distance)                                              AS mean_branch_dist,
        STDDEV(branch_distance)                                           AS std_branch_dist,
        MIN(branch_distance)                                              AS min_branch_dist,
        MAX(branch_distance)                                              AS max_branch_dist,
        PERCENTILE(branch_distance, 0.25)                                 AS q25_branch_dist,
        PERCENTILE(branch_distance, 0.50)                                 AS median_branch_dist,
        PERCENTILE(branch_distance, 0.75)                                 AS q75_branch_dist,
        PERCENTILE(branch_distance, 0.90)                                 AS q90_branch_dist,
        PERCENTILE(branch_distance, 0.75) - PERCENTILE(branch_distance, 0.25) AS iqr_branch_dist,
        STDDEV(branch_distance) / NULLIF(AVG(branch_distance), 0)        AS cv_branch_dist,
        MAX(branch_distance) / NULLIF(PERCENTILE(branch_distance, 0.50), 0) AS max_median_ratio
    FROM kbase_ke_pangenome.phylogenetic_tree_distance_pairs
    GROUP BY phylogenetic_tree_id
""")

phylo_stats_pd = phylo_stats_df.toPandas()

print(f"Phylo stats computed for {len(phylo_stats_pd)} species")
print(f"\nCV summary (coefficient of variation):")
print(phylo_stats_pd['cv_branch_dist'].describe())

print(f"\nSpecies with highest CV (most substructure):")
print(phylo_stats_pd.nlargest(10, 'cv_branch_dist')[
    ['phylogenetic_tree_id', 'n_genomes_in_tree', 'cv_branch_dist', 'max_median_ratio']
].to_string())
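To build intuition for why a high CV or max/median ratio points to substructure, the plain-Python sketch below computes the same statistics as Cell 6 on two toy distance lists: one tight cluster, and one where genomes fall into two distant clades. The quartiles use linear interpolation between order statistics (`statistics.quantiles(..., method="inclusive")`), which we assume is close to how Spark's exact `PERCENTILE` interpolates; `statistics.stdev` is the sample standard deviation, like Spark's `STDDEV`. The function name and distances are hypothetical.

```python
import statistics

def branch_stats(dists):
    """Per-tree summary mirroring the Cell 6 columns (toy, plain-Python version)."""
    q25, q50, q75 = statistics.quantiles(dists, n=4, method="inclusive")
    mean = statistics.fmean(dists)
    std = statistics.stdev(dists)  # sample std, like Spark's STDDEV
    return {
        "mean": mean,
        "median": q50,
        "iqr": q75 - q25,
        "cv": std / mean if mean else None,
        "max_median_ratio": max(dists) / q50 if q50 else None,
    }

# Toy branch distances (illustrative values, not from the real table)
one_cluster = [0.010, 0.011, 0.012, 0.009, 0.010, 0.013]
two_clades = [0.010, 0.011, 0.012, 0.200, 0.210, 0.190]

for name, d in [("one_cluster", one_cluster), ("two_clades", two_clades)]:
    s = branch_stats(d)
    print(f"{name}: cv={s['cv']:.2f}, max/median={s['max_median_ratio']:.2f}")
```

Both indicators are heuristics: a single outlier genome can also inflate max/median without any clade structure, so they are best read alongside the IQR and genome count.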

In [None]:
# Cell 7: Join phylo stats back to species IDs and save

phylo_stats_pd = phylo_stats_pd.merge(
    tree_species_pd[['gtdb_species_clade_id', 'GTDB_species', 'phylogenetic_tree_id']],
    on='phylogenetic_tree_id',
    how='left'
)

# How many tree IDs matched to a species?
n_matched = phylo_stats_pd['gtdb_species_clade_id'].notna().sum()
print(f"Phylo stats with matched species ID: {n_matched}/{len(phylo_stats_pd)}")
if n_matched < len(phylo_stats_pd):
    print("Unmatched tree IDs (investigate):")
    print(phylo_stats_pd[phylo_stats_pd['gtdb_species_clade_id'].isna()]['phylogenetic_tree_id'].tolist())

phylo_stats_pd.to_csv(f"{OUTPUT_PATH}/species_phylo_stats.csv", index=False)
print(f"\nSaved species_phylo_stats.csv")

---
## STEP 3: Environmental Diversity via NMDC BioSamples

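Link: `kbase_ke_pangenome.genome` → `kbase_ke_pangenome.sample` (on `genome_id`) → `nmdc_ncbi_biosamples.env_triads_flattened` (on `sample.ncbi_biosample_accession_id = env_triads_flattened.accession`)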

We compute, per species:
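- Number of genomes with a non-null `env_broad_scale` annotation (and the fraction of the species' genomes this covers)
- Number of distinct `env_broad_scale` categories (plus distinct `env_local_scale` and `env_medium` values)
- Shannon entropy (base 2, in bits) of the `env_broad_scale` distribution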
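For reference, the entropy computed in Cell 12 uses base-2 logarithms, so it is measured in bits. With $c_k$ the number of genomes in `env_broad_scale` category $k$ for a species, and $K$ categories observed:

$$
H = -\sum_{k=1}^{K} p_k \log_2 p_k, \qquad p_k = \frac{c_k}{\sum_{j=1}^{K} c_j}
$$

Worked examples: all genomes in one category gives $H = 0$; counts $(2, 2)$ give $p = (1/2, 1/2)$ and $H = 1$ bit; four equally populated categories give $H = \log_2 4 = 2$ bits. The maximum for $K$ categories is $\log_2 K$, so species observed in more categories can reach higher values.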

In [None]:
# Cell 8: Explore nmdc_ncbi_biosamples schema

# Check what tables are available
spark.sql("SHOW TABLES IN nmdc_ncbi_biosamples").show(20, truncate=False)

# Check env_triads_flattened schema
print("\nenv_triads_flattened schema:")
spark.sql("DESCRIBE nmdc_ncbi_biosamples.env_triads_flattened").show(30, truncate=False)

print("\nSample rows from env_triads_flattened:")
spark.sql("SELECT * FROM nmdc_ncbi_biosamples.env_triads_flattened LIMIT 3").show(truncate=False)

In [None]:
# Cell 9: Check the biosamples_flattened table for fallback env fields
# (if env_triads_flattened has sparse coverage)

print("biosamples_flattened schema (first 20 cols):")
spark.sql("DESCRIBE nmdc_ncbi_biosamples.biosamples_flattened").show(20, truncate=False)

# Check a sample row for env-related columns
print("\nSample row (env-related columns only):")
spark.sql("""
    SELECT accession, isolation_source, env_broad_scale, env_local_scale, env_medium,
           geo_loc_name, lat_lon
    FROM nmdc_ncbi_biosamples.biosamples_flattened
    WHERE env_broad_scale IS NOT NULL
    LIMIT 5
""").show(truncate=False)

In [None]:
# Cell 10: Get genome → biosample accession map for all tree species
#
# Join genome table to sample table, filtered to tree species only.
# Use a Spark join (not isin() on a Python list) to avoid serialization issues
# with large numbers of species IDs.

# Create a Spark DataFrame of tree species IDs to use as a filter
tree_species_spark = spark.createDataFrame(
    tree_species_pd[['gtdb_species_clade_id']]
)

# Get genome → biosample accession for tree species
genome_biosample_df = spark.sql("""
    SELECT
        g.genome_id,
        g.gtdb_species_clade_id,
        s.ncbi_biosample_accession_id
    FROM kbase_ke_pangenome.genome g
    JOIN kbase_ke_pangenome.sample s
        ON g.genome_id = s.genome_id
""").join(tree_species_spark, on='gtdb_species_clade_id', how='inner')

genome_biosample_df.cache()

n_genomes = genome_biosample_df.count()
n_with_biosample = genome_biosample_df.filter(
    F.col('ncbi_biosample_accession_id').isNotNull()
).count()

print(f"Genomes in tree species: {n_genomes:,}")
print(f"Genomes with biosample accession: {n_with_biosample:,} ({100*n_with_biosample/n_genomes:.1f}%)")

# Save genome-biosample map
genome_biosample_pd = genome_biosample_df.toPandas()
genome_biosample_pd.to_csv(f"{OUTPUT_PATH}/genome_biosample_map.csv", index=False)
print(f"\nSaved genome_biosample_map.csv ({len(genome_biosample_pd):,} rows)")

In [None]:
# Cell 11: Join genomes to env_triads_flattened and compute env diversity per species
#
# Strategy: join genome → biosample → env_triads, then aggregate per species.
# env_broad_scale is the primary diversity signal (marine, soil, freshwater, host, etc.)
# Compute on Spark; collect only the ~338-row summary.

env_diversity_df = spark.sql("""
    SELECT
        g.gtdb_species_clade_id,
        COUNT(DISTINCT g.genome_id)                                           AS n_genomes_total,
        COUNT(DISTINCT CASE WHEN e.env_broad_scale IS NOT NULL
                            THEN g.genome_id END)                             AS n_genomes_with_env,
        COUNT(DISTINCT e.env_broad_scale)                                     AS n_distinct_env_broad,
        COUNT(DISTINCT e.env_local_scale)                                     AS n_distinct_env_local,
        COUNT(DISTINCT e.env_medium)                                          AS n_distinct_env_medium
    FROM kbase_ke_pangenome.genome g
    JOIN kbase_ke_pangenome.sample s
        ON g.genome_id = s.genome_id
    LEFT JOIN nmdc_ncbi_biosamples.env_triads_flattened e
        ON s.ncbi_biosample_accession_id = e.accession
    GROUP BY g.gtdb_species_clade_id
""").join(tree_species_spark, on='gtdb_species_clade_id', how='inner')

env_diversity_pd = env_diversity_df.toPandas()

env_diversity_pd['env_coverage_fraction'] = (
    env_diversity_pd['n_genomes_with_env'] / env_diversity_pd['n_genomes_total']
)

print(f"Species with env diversity computed: {len(env_diversity_pd)}")
print(f"\nEnv coverage fraction summary:")
print(env_diversity_pd['env_coverage_fraction'].describe())
print(f"\nSpecies with zero env coverage: {(env_diversity_pd['env_coverage_fraction'] == 0).sum()}")
print(f"\nn_distinct_env_broad summary:")
print(env_diversity_pd['n_distinct_env_broad'].describe())
print(f"\nTop 10 by env_broad diversity:")
top_env = env_diversity_pd.merge(
    tree_species_pd[['gtdb_species_clade_id', 'GTDB_species']],
    on='gtdb_species_clade_id', how='left'
).nlargest(10, 'n_distinct_env_broad')
print(top_env[['GTDB_species', 'n_genomes_with_env', 'n_distinct_env_broad', 'env_coverage_fraction']].to_string())
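The `COUNT(DISTINCT CASE WHEN ... THEN g.genome_id END)` pattern in Cell 11 relies on SQL aggregates ignoring NULLs: the `LEFT JOIN` keeps genomes without an env match (their `env_*` columns are NULL), `CASE` maps them to NULL, and `COUNT(DISTINCT ...)` skips them. The plain-Python sketch below reproduces the three core columns on toy rows so the semantics can be checked by hand; the function name and rows are hypothetical.

```python
from collections import defaultdict

def env_diversity(rows):
    """rows: (species_id, genome_id, env_broad_scale or None), as from the LEFT JOIN."""
    genomes = defaultdict(set)
    genomes_with_env = defaultdict(set)
    envs = defaultdict(set)
    for species, genome, env in rows:
        genomes[species].add(genome)              # COUNT(DISTINCT g.genome_id)
        if env is not None:                       # CASE WHEN env IS NOT NULL
            genomes_with_env[species].add(genome)
            envs[species].add(env)                # COUNT(DISTINCT) skips NULLs
    # (n_genomes_total, n_genomes_with_env, n_distinct_env_broad) per species
    return {
        s: (len(genomes[s]), len(genomes_with_env[s]), len(envs[s]))
        for s in genomes
    }

# Toy rows: g3 and g4 have no env match; g2 has two env rows
rows = [
    ("sp1", "g1", "marine biome"),
    ("sp1", "g2", "marine biome"),
    ("sp1", "g2", "terrestrial biome"),
    ("sp1", "g3", None),
    ("sp2", "g4", None),
]
print(env_diversity(rows))
# -> {'sp1': (3, 2, 2), 'sp2': (1, 0, 0)}
```

If one biosample can have several `env_broad_scale` rows (as `g2` does here), the genome counts stay correct thanks to `DISTINCT`, but in Cell 12 such a genome is counted once in each of its categories.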

In [None]:
# Cell 12: Compute Shannon entropy of env_broad_scale distribution per species
#
# Shannon entropy H = -sum(p * log2(p)), in bits, where p = fraction of genomes in each
# env_broad_scale category. Higher entropy = more evenly spread across environments.
# (If a biosample has more than one env_broad_scale row, its genome counts toward each.)

# Get per-species per-env counts
env_counts_df = spark.sql("""
    SELECT
        g.gtdb_species_clade_id,
        e.env_broad_scale,
        COUNT(DISTINCT g.genome_id) AS n_genomes_in_env
    FROM kbase_ke_pangenome.genome g
    JOIN kbase_ke_pangenome.sample s
        ON g.genome_id = s.genome_id
    JOIN nmdc_ncbi_biosamples.env_triads_flattened e
        ON s.ncbi_biosample_accession_id = e.accession
    WHERE e.env_broad_scale IS NOT NULL
    GROUP BY g.gtdb_species_clade_id, e.env_broad_scale
""").join(tree_species_spark, on='gtdb_species_clade_id', how='inner')

env_counts_pd = env_counts_df.toPandas()

# Compute Shannon entropy per species
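def shannon_entropy(counts):
    """Shannon entropy (bits) of a Series of per-category counts."""
    total = counts.sum()
    if total == 0:
        return 0.0
    # Counts from the GROUP BY are >= 1; dropping zeros avoids log2(0) without an epsilon
    probs = counts[counts > 0] / total
    return float(-np.sum(probs * np.log2(probs)))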

entropy_per_species = (
    env_counts_pd
    .groupby('gtdb_species_clade_id')['n_genomes_in_env']
    .apply(shannon_entropy)
    .reset_index()
    .rename(columns={'n_genomes_in_env': 'env_broad_entropy'})
)

print(f"Entropy computed for {len(entropy_per_species)} species with env annotations")
print(entropy_per_species['env_broad_entropy'].describe())

# Save per-environment breakdown for downstream analysis
env_counts_pd.to_csv(f"{OUTPUT_PATH}/species_env_category_counts.csv", index=False)
print(f"\nSaved species_env_category_counts.csv ({len(env_counts_pd):,} rows)")

In [None]:
# Cell 13: Merge env stats and save

species_env_stats = env_diversity_pd.merge(
    entropy_per_species, on='gtdb_species_clade_id', how='left'
).merge(
    tree_species_pd[['gtdb_species_clade_id', 'GTDB_species']],
    on='gtdb_species_clade_id', how='left'
)

# Fill entropy = 0 for species with no env annotations
species_env_stats['env_broad_entropy'] = species_env_stats['env_broad_entropy'].fillna(0.0)

species_env_stats.to_csv(f"{OUTPUT_PATH}/species_env_stats.csv", index=False)
print(f"Saved species_env_stats.csv ({len(species_env_stats)} species)")

print(f"\nSpecies with zero env coverage (will score 0 for env diversity): "
      f"{(species_env_stats['n_genomes_with_env'] == 0).sum()}")
print(f"Tip: species with 0 env coverage are not disqualified — they score lowest on env dimension only.")

---
## STEP 4: Sanity Checks and Handoff Summary

In [None]:
# Cell 14: Verify all output files and print summary

import os

output_files = [
    'species_tree_list.csv',
    'species_pangenome_stats.csv',
    'species_phylo_stats.csv',
    'species_env_stats.csv',
    'species_env_category_counts.csv',
    'genome_biosample_map.csv',
]

print("=== Output File Inventory ===")
for f in output_files:
    fpath = os.path.join(OUTPUT_PATH, f)
    if os.path.exists(fpath):
        df = pd.read_csv(fpath)
        print(f"  {f}: {len(df):,} rows, {df.shape[1]} cols")
    else:
        print(f"  MISSING: {f}")

print()
print("=== Universe Summary ===")
print(f"  Total tree species: {len(tree_species_pd)}")
print(f"  Species with >=20 genomes (pangenome no_genomes): {(tree_species_pd['no_genomes'] >= 20).sum()}")
print(f"  Species with >=50 genomes (pangenome no_genomes): {(tree_species_pd['no_genomes'] >= 50).sum()}")
print()
print("=== Phylo Substructure Summary ===")
phylo_check = pd.read_csv(f"{OUTPUT_PATH}/species_phylo_stats.csv")
print(f"  Median CV: {phylo_check['cv_branch_dist'].median():.3f}")
print(f"  Top quartile CV threshold (75th pct): {phylo_check['cv_branch_dist'].quantile(0.75):.3f}")
print()
print("=== Environmental Diversity Summary ===")
env_check = pd.read_csv(f"{OUTPUT_PATH}/species_env_stats.csv")
n_any_env = (env_check['n_genomes_with_env'] > 0).sum()
print(f"  Species with any env annotation: {n_any_env}/{len(env_check)} ({100*n_any_env/len(env_check):.0f}%)")
print(f"  Median n_distinct_env_broad (among annotated): "
      f"{env_check[env_check['n_genomes_with_env']>0]['n_distinct_env_broad'].median():.1f}")
print(f"  Median env entropy (among annotated): "
      f"{env_check[env_check['n_genomes_with_env']>0]['env_broad_entropy'].median():.2f}")

print()
print("All done. Download the data/ directory and proceed with NB02 locally.")
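The inventory above checks that files exist, but not that the species-level tables really have one row per species; a duplicated `gtdb_species_clade_id` (for example from a join fanning out) would silently distort downstream scoring. A minimal stdlib sketch of such a check follows, using a hypothetical helper name and a temporary toy file with illustrative species IDs.

```python
import csv
import os
import tempfile
from collections import Counter

def duplicate_keys(path, key="gtdb_species_clade_id"):
    """Return {key_value: count} for values of `key` that appear more than once."""
    with open(path, newline="") as fh:
        counts = Counter(row[key] for row in csv.DictReader(fh))
    return {k: n for k, n in counts.items() if n > 1}

# Toy file standing in for e.g. species_env_stats.csv (illustrative species IDs)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_species_stats.csv")
    with open(path, "w", newline="") as fh:
        writer = csv.writer(fh)
        writer.writerow(["gtdb_species_clade_id", "n_genomes_with_env"])
        writer.writerows([
            ["s__A--RS_GCF_1", 5],
            ["s__B--RS_GCF_2", 0],
            ["s__A--RS_GCF_1", 5],  # deliberate duplicate
        ])
    dups = duplicate_keys(path)

print(dups)  # {'s__A--RS_GCF_1': 2}
```

On the real outputs this would be called as, e.g., `duplicate_keys(f"{OUTPUT_PATH}/species_phylo_stats.csv")`; an empty dict means one row per species. Unmatched tree IDs from Cell 7 (empty species ID) would appear under the empty-string key if there are several of them.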