# Temporal Core Genome Dynamics - Data Extraction

## Goal
Extract genome data with collection dates and gene cluster memberships for:
- **Pseudomonas aeruginosa** (~6,760 genomes)
- **Acinetobacter baumannii** (~6,647 genomes)

## Data Needed
1. Genomes with valid collection dates (from ncbi_env table)
2. Gene cluster memberships per genome (which gene clusters are present in which genomes)

In [None]:
# Cell 1: Setup

import numpy as np
import pandas as pd
import os
import re
from datetime import datetime
from pyspark.sql.functions import col, count, countDistinct

# Output directory. Defaults to the original analysis location; set the
# TEMPORAL_CORE_OUTPUT environment variable to write somewhere else
# (e.g. when a different user runs this notebook).
OUTPUT_PATH = os.environ.get(
    "TEMPORAL_CORE_OUTPUT",
    "/home/psdehal/pangenome_science/data/temporal_core",
)
os.makedirs(OUTPUT_PATH, exist_ok=True)

# Target species: short key -> GTDB species-clade ID prefix.
# Each prefix is used as `gtdb_species_clade_id LIKE '<prefix>%'` in the
# queries below. NOTE(review): in SQL LIKE, `_` is a single-character
# wildcard, and the trailing `%` also matches any longer clade ID sharing
# this prefix (e.g. a split species with an "_A" suffix). Confirm each
# prefix resolves to exactly one clade; Cell 8 prints every matched clade.
SPECIES = {
    'p_aeruginosa': 's__Pseudomonas_aeruginosa',
    'a_baumannii': 's__Acinetobacter_baumannii',
}

print(f"Output path: {OUTPUT_PATH}")
print(f"Target species: {list(SPECIES.keys())}")

---
## STEP 1: Extract Genomes with Collection Dates

In [None]:
# Cell 2: Query genomes with collection dates for both species

# Get genomes with collection date metadata and drop placeholder values
# (e.g. "missing", "not collected") that carry no date information.

INVALID_DATES = [
    'missing', 'not applicable', 'not collected', 
    'Unknown', 'unknown', 'not provided', 'NA', 'N/A'
]

# Compare placeholders case- and whitespace-insensitively so variants such as
# "Missing" or " not provided" are removed here as well.
_invalid_dates_normalized = {v.strip().lower() for v in INVALID_DATES}

genomes_with_dates = {}

for species_key, species_prefix in SPECIES.items():
    print(f"\n=== Extracting {species_key} ===")

    # Join each genome to its BioSample's `collection_date` attribute row
    # in ncbi_env (matched on ncbi_biosample_id = accession).
    query = f"""
        SELECT
            g.genome_id,
            g.gtdb_species_clade_id,
            g.ncbi_biosample_id,
            ne.content as collection_date_raw
        FROM kbase_ke_pangenome.genome g
        JOIN kbase_ke_pangenome.ncbi_env ne
            ON g.ncbi_biosample_id = ne.accession
        WHERE g.gtdb_species_clade_id LIKE '{species_prefix}%'
            AND ne.attribute_name = 'collection_date'
            AND ne.content IS NOT NULL
    """

    df = spark.sql(query).toPandas()
    print(f"  Total genomes with collection_date field: {len(df)}")

    # Filter out placeholder (non-date) values
    normalized_raw = df['collection_date_raw'].astype(str).str.strip().str.lower()
    df_valid = df[~normalized_raw.isin(_invalid_dates_normalized)].copy()
    print(f"  After removing invalid values: {len(df_valid)}")

    # Downstream cells treat genome_id as a key. If the join yields several
    # collection_date rows for one genome, keep the first (arbitrary) one.
    # NOTE(review): confirm whether such duplicates occur in ncbi_env.
    n_duplicates = int(df_valid['genome_id'].duplicated().sum())
    if n_duplicates:
        print(f"  WARNING: {n_duplicates} duplicate genome_id rows; keeping first occurrence")
        df_valid = df_valid.drop_duplicates(subset='genome_id', keep='first')

    genomes_with_dates[species_key] = df_valid

    # Show sample dates
    print(f"  Sample dates: {df_valid['collection_date_raw'].head(10).tolist()}")

In [None]:
# Cell 3: Parse collection dates into standardized format

# Accepted single-date layouts, most specific first. The second element says
# how missing components are filled in: None = exact day, 'month' = 15th of
# the month, 'year' = July 1. The `%b` formats cover NCBI's "DD-Mmm-YYYY" and
# "Mmm-YYYY" styles (English month abbreviations; locale-dependent).
_COLLECTION_DATE_FORMATS = [
    ('%Y-%m-%d', None),       # 2015-03-12 -> exact
    ('%Y/%m/%d', None),       # 2015/03/12 -> exact
    ('%d-%b-%Y', None),       # 12-Mar-2015 -> exact
    ('%Y-%m', 'month'),       # 2015-03 -> mid-month (15th)
    ('%Y/%m', 'month'),       # 2015/03 -> mid-month
    ('%b-%Y', 'month'),       # Mar-2015 -> mid-month
    ('%Y', 'year'),           # 2015 -> mid-year (July 1)
]

# A full date followed by a time, e.g. "2015-10-11T17:53:03Z" or
# "2015/10/11 17:53"; only the date part is kept.
_DATE_WITH_TIME_RE = re.compile(r'^(\d{4}[-/]\d{1,2}[-/]\d{1,2})[T ]')

# Last-resort year extraction from free text (e.g. "circa 2015").
_YEAR_RE = re.compile(r'(19|20)\d{2}')


def _parse_single_date(text):
    """Parse one (non-range) date string with the known formats.

    Returns a datetime, or None if no format matches.
    """
    time_match = _DATE_WITH_TIME_RE.match(text)
    if time_match:
        text = time_match.group(1)
    for fmt, granularity in _COLLECTION_DATE_FORMATS:
        try:
            dt = datetime.strptime(text, fmt)
        except ValueError:
            continue
        if granularity == 'month':
            dt = dt.replace(day=15)
        elif granularity == 'year':
            dt = dt.replace(month=7, day=1)
        return dt
    return None


def parse_collection_date(date_str):
    """
    Parse variable date formats from ncbi_env collection_date field.

    Handled formats:
    - YYYY-MM-DD (2015-03-12), YYYY/MM/DD (2015/03/12), DD-Mmm-YYYY (12-Mar-2015)
    - A full date with a time suffix (2015-03-12T10:00:00Z) -> the date part
    - YYYY-MM (2015-03), YYYY/MM (2015/03), Mmm-YYYY (Mar-2015)
    - YYYY (2015)
    - Date ranges separated by '/' (2013/2014, 2015-04/2016-09) -> start date
    - Anything else containing a 19xx/20xx year -> that year

    Returns: datetime object, or None for non-strings and unparseable text.
    Partial dates use a mid-point: the 15th for year+month, July 1 for
    year-only.
    """
    if not isinstance(date_str, str):
        # Covers None, NaN and other non-text values.
        return None

    date_str = date_str.strip()
    if not date_str:
        return None

    # Try the whole string first so slash-separated dates (with or without a
    # time suffix) are not mistaken for ranges.
    parsed = _parse_single_date(date_str)

    # Otherwise treat '/' as a range separator and use the start date.
    if parsed is None and '/' in date_str:
        parsed = _parse_single_date(date_str.split('/', 1)[0].strip())

    if parsed is not None:
        return parsed

    # Fall back to just the year, if one appears anywhere in the text.
    year_match = _YEAR_RE.search(date_str)
    if year_match:
        return datetime(int(year_match.group()), 7, 1)

    return None

# Apply parsing to both species and keep only genomes with a usable date.
for species_key, df in genomes_with_dates.items():
    print(f"\n=== Parsing dates for {species_key} ===")

    # Coerce to datetime64 so the `.dt` accessor used in later cells always
    # works. Series.apply alone can yield an object column, e.g. when every
    # value is None. Values pandas cannot represent become NaT (and are
    # dropped below).
    df['collection_date'] = pd.to_datetime(
        df['collection_date_raw'].apply(parse_collection_date),
        errors='coerce',
    )

    n_parsed = int(df['collection_date'].notna().sum())
    n_total = len(df)
    # Guard against a species with no rows (avoids division by zero).
    pct_parsed = 100 * n_parsed / n_total if n_total else 0.0
    print(f"  Successfully parsed: {n_parsed}/{n_total} ({pct_parsed:.1f}%)")

    # Filter to only parsed dates
    df_parsed = df[df['collection_date'].notna()].copy()
    genomes_with_dates[species_key] = df_parsed

    # Date range
    print(f"  Date range: {df_parsed['collection_date'].min()} to {df_parsed['collection_date'].max()}")

    # Show failed parses
    failed = df[df['collection_date'].isna()]['collection_date_raw'].unique()[:10]
    print(f"  Sample unparsed values: {failed}")

In [None]:
# Cell 4: Visualize date distributions

import matplotlib.pyplot as plt

# One histogram panel per species. squeeze=False keeps `axes` 2-D, so the
# loop below also works when SPECIES has a single entry.
n_panels = max(len(genomes_with_dates), 1)
fig, axes = plt.subplots(1, n_panels, figsize=(7 * n_panels, 5), squeeze=False)

for ax, (species_key, df) in zip(axes[0], genomes_with_dates.items()):
    # Integer collection years (drop any missing dates defensively).
    years = df['collection_date'].dt.year.dropna().astype(int)

    ax.set_xlabel('Collection Year')
    ax.set_ylabel('Number of Genomes')
    ax.set_title(f"{species_key}\n(n={len(df)} genomes with dates)")

    if years.empty:
        # min()/max() of an empty series are NaN, which range() rejects.
        continue

    # One bin per calendar year: edges first_year, ..., last_year + 1.
    first_year, last_year = int(years.min()), int(years.max())
    ax.hist(years, bins=range(first_year, last_year + 2),
            edgecolor='black', alpha=0.7)
    median_year = years.median()
    ax.axvline(median_year, color='red', linestyle='--',
               label=f'Median: {int(median_year)}')
    ax.legend()

plt.tight_layout()
plt.savefig(f"{OUTPUT_PATH}/collection_date_distributions.png", dpi=150)
plt.show()

print("Saved: collection_date_distributions.png")

In [None]:
# Cell 5: Save genome lists with dates
#
# For each species, write the dated genomes in chronological order (with a
# derived integer collection_year column) to <species>_genomes.parquet and
# print a short per-year summary.

for species_key, df in genomes_with_dates.items():
    output_file = f"{OUTPUT_PATH}/{species_key}_genomes.parquet"

    keep_cols = ['genome_id', 'gtdb_species_clade_id',
                 'collection_date_raw', 'collection_date']
    export_df = (
        df[keep_cols]
        .assign(collection_year=lambda frame: frame['collection_date'].dt.year)
        .sort_values('collection_date')
        .reset_index(drop=True)
    )

    export_df.to_parquet(output_file)
    print(f"Saved {len(export_df)} genomes to {output_file}")

    # Summary stats
    year_col = export_df['collection_year']
    per_year_counts = export_df.groupby('collection_year').size()
    print(f"  Year range: {year_col.min()} - {year_col.max()}")
    print(f"  Genomes per year: {per_year_counts.describe()}\n")

---
## STEP 2: Extract Gene Cluster Memberships

For each genome, we need to know which gene clusters are present.
This step writes a long-format table of distinct (genome_id, gene_cluster_id) pairs, from which a genome × cluster presence matrix can be built.

In [None]:
# Cell 6: Extract gene cluster memberships per genome

# For each species, get which gene clusters are present in which genomes.
# This joins: gene -> gene_genecluster_junction (gene_id -> gene_cluster_id),
# restricted to the genomes that have a parsed collection date.

from pyspark.sql.functions import monotonically_increasing_id
from pyspark.sql.functions import col, expr

# Target number of (genome_id, gene_cluster_id) rows per output file. Larger
# results are split into chunk_NNN.parquet files.
CHUNK_SIZE = 5000000

for species_key in SPECIES:
    print(f"\n=== Extracting gene clusters for {species_key} ===")

    # Get target genome IDs (those with valid dates), de-duplicated so the
    # IN list and the printed count refer to distinct genomes.
    target_genomes = sorted(set(genomes_with_dates[species_key]['genome_id']))
    print(f"  Target genomes: {len(target_genomes)}")
    if not target_genomes:
        # An empty `IN ()` list is a SQL syntax error; nothing to extract.
        print("  No dated genomes; skipping")
        continue

    # Query unique (genome_id, gene_cluster_id) pairs for the target genomes
    genome_list_sql = ','.join(f"'{gid}'" for gid in target_genomes)
    query = f"""
        SELECT DISTINCT
            g.genome_id,
            gg.gene_cluster_id
        FROM kbase_ke_pangenome.gene g
        JOIN kbase_ke_pangenome.gene_genecluster_junction gg
            ON g.gene_id = gg.gene_id
        WHERE g.genome_id IN ({genome_list_sql})
    """

    clusters_df = spark.sql(query)
    clusters_df.cache()
    try:
        # This could be large - count first
        total_pairs = clusters_df.count()
        print(f"  Total genome-cluster pairs: {total_pairs:,}")

        n_clusters = clusters_df.select('gene_cluster_id').distinct().count()
        print(f"  Unique gene clusters: {n_clusters:,}")

        CLUSTERS_PATH = f"{OUTPUT_PATH}/{species_key}_gene_clusters"
        os.makedirs(CLUSTERS_PATH, exist_ok=True)

        # Remove parquet output from earlier runs. Cell 7 reads
        # all_clusters.parquet if present, otherwise the chunk files, so
        # leftovers from a previous run would silently mix with new results.
        for old_file in os.listdir(CLUSTERS_PATH):
            if old_file.endswith('.parquet'):
                os.remove(os.path.join(CLUSTERS_PATH, old_file))

        if total_pairs > CHUNK_SIZE:
            n_chunks = -(-total_pairs // CHUNK_SIZE)  # ceiling division

            # Assign each row to a chunk by hashing genome_id. Filtering on
            # ranges of monotonically_increasing_id() loses rows, because
            # those IDs are not consecutive across partitions (the partition
            # index lives in the upper bits). pmod(hash(...)) puts every row
            # in exactly one chunk and keeps each genome's clusters together.
            # Chunk sizes are therefore approximate.
            clusters_with_chunk = clusters_df.withColumn(
                "_chunk", expr(f"pmod(hash(genome_id), {n_chunks})")
            )

            rows_written = 0
            for i in range(n_chunks):
                chunk_pd = (
                    clusters_with_chunk
                    .filter(col("_chunk") == i)
                    .drop("_chunk")
                    .toPandas()
                )
                chunk_pd.to_parquet(f"{CLUSTERS_PATH}/chunk_{i:03d}.parquet")
                rows_written += len(chunk_pd)
                print(f"    Saved chunk {i+1}/{n_chunks}: {len(chunk_pd):,} rows")

            if rows_written != total_pairs:
                print(f"  WARNING: wrote {rows_written:,} rows but expected {total_pairs:,}")
        else:
            # Small enough to save as single file
            clusters_pd = clusters_df.toPandas()
            clusters_pd.to_parquet(f"{CLUSTERS_PATH}/all_clusters.parquet")
            print(f"  Saved all {len(clusters_pd):,} pairs")
    finally:
        # Release the cached result even if an export step fails.
        clusters_df.unpersist()

---
## STEP 3: Verify Data and Create Summary

In [None]:
# Cell 7: Verification and summary
#
# Reload what Cells 5-6 wrote, print per-species statistics, and save them
# to extraction_summary.csv.

print("=== DATA EXTRACTION SUMMARY ===")
print()

summary_stats = []

for species_key in SPECIES.keys():
    # Load genome data
    genomes_df = pd.read_parquet(f"{OUTPUT_PATH}/{species_key}_genomes.parquet")

    # Load cluster data: a single file, or the chunk files written by Cell 6
    clusters_path = f"{OUTPUT_PATH}/{species_key}_gene_clusters"
    single_file = f"{clusters_path}/all_clusters.parquet"
    if os.path.exists(single_file):
        clusters_df = pd.read_parquet(single_file)
    else:
        existing = os.listdir(clusters_path) if os.path.isdir(clusters_path) else []
        chunk_files = sorted(
            f for f in existing if f.startswith('chunk_') and f.endswith('.parquet')
        )
        if not chunk_files:
            # pd.concat([]) would raise; report and move on instead.
            print(f"{species_key}: no gene cluster files found in {clusters_path}; skipping\n")
            continue
        clusters_df = pd.concat(
            [pd.read_parquet(f"{clusters_path}/{f}") for f in chunk_files],
            ignore_index=True,
        )

    # Statistics
    n_genomes = len(genomes_df)
    n_clusters = clusters_df['gene_cluster_id'].nunique()
    year_min = genomes_df['collection_year'].min()
    year_max = genomes_df['collection_year'].max()
    # Rows are distinct (genome_id, gene_cluster_id) pairs, so the group size
    # is the number of clusters present in each genome.
    clusters_per_genome = clusters_df.groupby('genome_id').size()
    # Dated genomes that produced no cluster rows would otherwise go unnoticed.
    n_genomes_with_clusters = int(genomes_df['genome_id'].isin(clusters_per_genome.index).sum())

    print(f"{species_key}:")
    print(f"  Genomes with dates: {n_genomes:,}")
    print(f"  Genomes with cluster data: {n_genomes_with_clusters:,}")
    print(f"  Unique gene clusters: {n_clusters:,}")
    print(f"  Collection years: {year_min} - {year_max}")
    print(f"  Clusters per genome: {clusters_per_genome.mean():.0f} mean, {clusters_per_genome.min()}-{clusters_per_genome.max()} range")
    print()

    summary_stats.append({
        'species': species_key,
        'n_genomes': n_genomes,
        'n_clusters': n_clusters,
        'year_min': year_min,
        'year_max': year_max,
        'mean_clusters_per_genome': clusters_per_genome.mean()
    })

# Save summary
summary_df = pd.DataFrame(summary_stats)
summary_df.to_csv(f"{OUTPUT_PATH}/extraction_summary.csv", index=False)
print(f"Summary saved to: {OUTPUT_PATH}/extraction_summary.csv")

In [None]:
# Cell 8: Compare with full pangenome statistics

# Print pangenome-level statistics for every clade matched by each species
# prefix, and how much of that genome total the dated subset covers.
for species_key, species_prefix in SPECIES.items():
    query = f"""
        SELECT
            p.gtdb_species_clade_id,
            p.no_genomes,
            p.no_core,
            p.no_aux_genome,
            p.no_singleton_gene_clusters,
            s.mean_intra_species_ANI
        FROM kbase_ke_pangenome.pangenome p
        JOIN kbase_ke_pangenome.gtdb_species_clade s
            ON p.gtdb_species_clade_id = s.gtdb_species_clade_id
        WHERE p.gtdb_species_clade_id LIKE '{species_prefix}%'
    """

    result = spark.sql(query).toPandas()
    n_dated = len(genomes_with_dates[species_key])

    print(f"\n{species_key} - Full pangenome statistics:")
    print(result.to_string(index=False))
    print(f"\n  Our dated subset: {n_dated} genomes")

    # Cell 2 selects genomes with the same LIKE prefix, so the dated subset can
    # come from every matched clade. Compare against their combined genome
    # count, not just the first row.
    if len(result) > 1:
        print(f"  NOTE: prefix matched {len(result)} clades; coverage uses their combined genome count")
    total_genomes = result['no_genomes'].sum()
    if total_genomes > 0:
        print(f"  Coverage: {100 * n_dated / total_genomes:.1f}%")
    else:
        print("  Coverage: n/a (no pangenome rows matched)")

---
## Summary

After running this notebook, the following files are created:

- `p_aeruginosa_genomes.parquet` - Genomes with parsed collection dates
- `a_baumannii_genomes.parquet` - Genomes with parsed collection dates
- `p_aeruginosa_gene_clusters/` - Gene cluster memberships (genome_id, gene_cluster_id)
- `a_baumannii_gene_clusters/` - Gene cluster memberships
- `collection_date_distributions.png` - Visualization of date distributions
- `extraction_summary.csv` - Summary statistics

**Next Step**: Run `02_sliding_window.ipynb` to calculate core genome across time windows.