# Temporal Core Genome Dynamics - Sliding Window Analysis

## Goal
Calculate how the core genome changes as we progressively add genomes sorted by collection date.

## Approaches
1. **Cumulative Expansion**: Start with the earliest 30 genomes, then keep adding genomes in chronological order and track the core (recomputed every 10 added genomes and for the full set)
2. **Fixed Time Windows**: Group into 2-year bins, calculate core within each bin

## Core Thresholds
- 90% presence
- 95% presence  
- 99% presence

In [None]:
# Cell 1: Setup and load data

import numpy as np
import pandas as pd
import os
from collections import defaultdict
import matplotlib.pyplot as plt
from tqdm import tqdm

# Data paths: inputs are read from DATA_PATH; every table, figure and gene list
# produced by this notebook is written to OUTPUT_PATH (the same directory).
DATA_PATH = "/home/psdehal/pangenome_science/data/temporal_core"
OUTPUT_PATH = DATA_PATH

# Analysis parameters
SPECIES = ['p_aeruginosa', 'a_baumannii']
THRESHOLDS = [0.90, 0.95, 0.99]   # minimum presence fraction for a cluster to count as core
MIN_WINDOW_SIZE = 30              # fewest genomes for which a core is calculated

print(
    f"Loading data from: {DATA_PATH}",
    f"Thresholds: {THRESHOLDS}",
    f"Minimum window size: {MIN_WINDOW_SIZE}",
    sep="\n",
)

In [None]:
# Cell 2: Load genome and cluster data
#
# For each species builds data[species] with:
#   genomes_df      - dated genomes sorted by collection_date (stable sort)
#   genome_clusters - genome_id -> set of gene_cluster_id
#   ordered_genomes - genome_ids in the same order as genomes_df rows
#   all_clusters    - set of every gene_cluster_id seen for the species

data = {}

for species in SPECIES:
    print(f"\n=== Loading {species} ===")

    # Load genomes with dates
    genomes_df = pd.read_parquet(f"{DATA_PATH}/{species}_genomes.parquet")
    print(f"  Genomes: {len(genomes_df)}")

    # Load gene clusters: prefer a single consolidated file, otherwise
    # concatenate every per-chunk parquet file in the directory.
    clusters_path = f"{DATA_PATH}/{species}_gene_clusters"
    if os.path.exists(f"{clusters_path}/all_clusters.parquet"):
        clusters_df = pd.read_parquet(f"{clusters_path}/all_clusters.parquet")
    else:
        chunk_files = sorted(f for f in os.listdir(clusters_path) if f.endswith('.parquet'))
        if not chunk_files:
            # pd.concat([]) would otherwise fail with an opaque "No objects to concatenate"
            raise FileNotFoundError(f"No cluster parquet files found in {clusters_path}")
        clusters_df = pd.concat(
            [pd.read_parquet(f"{clusters_path}/{f}") for f in chunk_files],
            ignore_index=True,
        )
    print(f"  Cluster memberships: {len(clusters_df):,}")

    # Create genome-to-clusters mapping (set of clusters per genome)
    genome_clusters = clusters_df.groupby('genome_id')['gene_cluster_id'].apply(set).to_dict()
    print(f"  Unique clusters: {clusters_df['gene_cluster_id'].nunique():,}")

    # Undated genomes cannot be placed chronologically; left in, their missing
    # dates would sort to the end and be treated as the most recent genomes.
    n_undated = int(genomes_df['collection_date'].isna().sum())
    if n_undated:
        print(f"  Dropping {n_undated} genomes without collection_date")
        genomes_df = genomes_df.dropna(subset=['collection_date'])

    # Stable sort so genomes sharing a collection date keep their file order,
    # making the chronological order (and everything derived from it) reproducible.
    genomes_df = genomes_df.sort_values('collection_date', kind='mergesort').reset_index(drop=True)
    ordered_genomes = genomes_df['genome_id'].tolist()

    n_without_clusters = sum(gid not in genome_clusters for gid in ordered_genomes)
    if n_without_clusters:
        # calculate_core still counts these genomes in its denominator
        print(f"  WARNING: {n_without_clusters} genomes have no cluster memberships")

    data[species] = {
        'genomes_df': genomes_df,
        'genome_clusters': genome_clusters,
        'ordered_genomes': ordered_genomes,
        'all_clusters': set(clusters_df['gene_cluster_id'].unique())
    }

    print(f"  Date range: {genomes_df['collection_date'].min()} to {genomes_df['collection_date'].max()}")

---
## Core Calculation Functions

In [None]:
# Cell 3: Define core calculation functions

def calculate_core(genome_ids, genome_clusters, threshold):
    """
    Calculate the core genome for a set of genomes at a given threshold.

    Args:
        genome_ids: Sequence (must support len()) of genome IDs to include.
            IDs absent from ``genome_clusters`` still count toward the
            denominator, i.e. they are treated as carrying no clusters.
        genome_clusters: Dict mapping genome_id -> set of cluster_ids
        threshold: Fraction of genomes that must have a cluster (e.g., 0.95)

    Returns:
        set of cluster_ids present in at least ceil(threshold * len(genome_ids))
        of the genomes. Empty set when ``genome_ids`` is empty.
    """
    n_genomes = len(genome_ids)
    # Round before ceil: threshold * n can land a hair above an integer in
    # floating point (e.g. 0.55 * 100 == 55.00000000000001), which would
    # otherwise demand one extra genome.
    min_presence = int(np.ceil(round(threshold * n_genomes, 9)))

    # Count in how many of the selected genomes each cluster occurs
    cluster_counts = defaultdict(int)
    for gid in genome_ids:
        for cid in genome_clusters.get(gid, ()):
            cluster_counts[cid] += 1

    # Filter to core
    return {cid for cid, count in cluster_counts.items() if count >= min_presence}


def jaccard_similarity(set1, set2):
    """Return |set1 & set2| / |set1 | set2|; two empty sets count as identical (1.0)."""
    union_size = len(set1 | set2)
    if not union_size:
        # The union is empty only when both inputs are empty.
        return 1.0
    return len(set1 & set2) / union_size


# Smoke test: 95% core of the (up to) 50 chronologically earliest genomes
# of the first configured species.
test_species = SPECIES[0]
test_genomes = data[test_species]['ordered_genomes'][:50]
test_core = calculate_core(
    genome_ids=test_genomes,
    genome_clusters=data[test_species]['genome_clusters'],
    threshold=0.95,
)
print("Test: Core size for first 50 " + test_species + " genomes at 95%: " + str(len(test_core)))

---
## Approach 1: Cumulative Expansion

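Sort genomes by collection date, then progressively add genomes and track core size. The core is recomputed every 10 genomes (and for the full set). Each sample point is compared with the previous one using Jaccard similarity and the number of clusters gained and lost.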

In [None]:
# Cell 4: Cumulative expansion analysis
#
# For each species, grow the genome set in chronological order and record the
# core size (plus change vs. the previous sample point) at every threshold.

CUMULATIVE_STEP = 10  # recompute the core after every this many added genomes

cumulative_results = []

for species in SPECIES:
    print(f"\n=== Cumulative analysis: {species} ===")

    ordered_genomes = data[species]['ordered_genomes']
    genome_clusters = data[species]['genome_clusters']
    genomes_df = data[species]['genomes_df']

    n_total = len(ordered_genomes)
    if n_total < MIN_WINDOW_SIZE:
        # Without this guard the full (too small) set would be the only sample
        # point, and a species with no genomes would index iloc[-1].
        print(f"  Only {n_total} genomes (< {MIN_WINDOW_SIZE}), skipping")
        continue

    # Sample every CUMULATIVE_STEP genomes from MIN_WINDOW_SIZE; range() excludes
    # its stop value, so the full set is appended explicitly.
    sample_points = [*range(MIN_WINDOW_SIZE, n_total, CUMULATIVE_STEP), n_total]

    print(f"  Calculating core at {len(sample_points)} sample points...")

    # Core at the previous sample point, per threshold (None before the first)
    prev_cores = {t: None for t in THRESHOLDS}

    for n in tqdm(sample_points):
        current_genomes = ordered_genomes[:n]

        # Date of the last genome added; genomes_df rows are in the same order
        # as ordered_genomes (both come from the date-sorted frame).
        last_date = genomes_df.iloc[n - 1]['collection_date']

        for threshold in THRESHOLDS:
            core = calculate_core(current_genomes, genome_clusters, threshold)
            prev = prev_cores[threshold]

            # Change from the previous sample point (undefined for the first)
            if prev is not None:
                jaccard = jaccard_similarity(core, prev)
                gained = len(core - prev)
                lost = len(prev - core)
            else:
                jaccard = gained = lost = None

            cumulative_results.append({
                'species': species,
                'approach': 'cumulative',
                'n_genomes': n,
                'last_date': last_date,
                'threshold': threshold,
                'core_size': len(core),
                'jaccard_vs_prev': jaccard,
                'gained': gained,
                'lost': lost
            })

            prev_cores[threshold] = core

    # Summary: core size at the first vs. last sample point
    for threshold in THRESHOLDS:
        species_results = [r for r in cumulative_results
                           if r['species'] == species and r['threshold'] == threshold]
        initial = species_results[0]['core_size']
        final = species_results[-1]['core_size']
        # An empty initial core would otherwise raise ZeroDivisionError
        pct_change = f"{100 * (final - initial) / initial:+.1f}%" if initial else "n/a"
        print(f"  {int(threshold*100)}% threshold: {initial} -> {final} ({pct_change})")

cumulative_df = pd.DataFrame(cumulative_results)
print(f"\nTotal results: {len(cumulative_df)}")

In [None]:
# Cell 5: Visualize cumulative core decay

# Line colour per threshold (also used by the fixed-window bar chart in Cell 8).
colors = {0.90: 'blue', 0.95: 'green', 0.99: 'red'}

# One panel per species. squeeze=False keeps `axes` two-dimensional so the
# indexing works for one species or more than two (not just exactly two).
n_panels = len(SPECIES)
fig, axes = plt.subplots(1, n_panels, figsize=(7 * n_panels, 5), squeeze=False)

for ax, species in zip(axes[0], SPECIES):
    species_data = cumulative_df[cumulative_df['species'] == species]

    for threshold in THRESHOLDS:
        thresh_data = species_data[species_data['threshold'] == threshold]
        # colors.get -> None lets matplotlib choose a colour for unlisted thresholds
        ax.plot(thresh_data['n_genomes'], thresh_data['core_size'],
                color=colors.get(threshold), label=f"{int(threshold*100)}% threshold",
                linewidth=2)

    ax.set_xlabel('Number of Genomes (chronological)')
    ax.set_ylabel('Core Size (gene clusters)')
    ax.set_title(f'{species}\nCore Decay by Collection Date')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{OUTPUT_PATH}/cumulative_core_decay.png", dpi=150)
plt.show()

print("Saved: cumulative_core_decay.png")

---
## Approach 2: Fixed Time Windows

Group genomes into 2-year bins by collection year and calculate the core within each bin independently. Bins with fewer than `MIN_WINDOW_SIZE` (30) genomes are skipped.

In [None]:
# Cell 6: Fixed time window analysis
#
# Bin genomes by collection year into WINDOW_YEARS-wide windows and compute
# the core independently within each sufficiently large window.

WINDOW_YEARS = 2  # width of each time window, in years

window_results = []
window_cores = {}  # species -> window label -> threshold -> set of core cluster IDs

for species in SPECIES:
    print(f"\n=== Time window analysis: {species} ===")

    genomes_df = data[species]['genomes_df'].copy()
    genome_clusters = data[species]['genome_clusters']
    # Created up front so later cells can index it even if this species is skipped
    window_cores[species] = {}

    # Assumes collection_date is a datetime column (uses the .dt accessor);
    # undated genomes get a NaN year and fall outside every window.
    genomes_df['year'] = genomes_df['collection_date'].dt.year
    if genomes_df['year'].isna().all():
        print("  No dated genomes, skipping")
        continue
    # int(): .dt.year is float when missing dates are present, and range() needs ints
    min_year = int(genomes_df['year'].min())
    max_year = int(genomes_df['year'].max())

    # Left-closed windows [start, start + WINDOW_YEARS). The stop value
    # max_year + WINDOW_YEARS + 1 guarantees the last edge is > max_year;
    # a stop of max_year + WINDOW_YEARS silently dropped the final year
    # whenever (max_year - min_year) was a multiple of WINDOW_YEARS.
    bins = list(range(min_year, max_year + WINDOW_YEARS + 1, WINDOW_YEARS))
    labels = [f"{start}-{end - 1}" for start, end in zip(bins, bins[1:])]
    genomes_df['window'] = pd.cut(genomes_df['year'], bins=bins, right=False, labels=labels)

    print(f"  Year range: {min_year} - {max_year}")
    print(f"  Windows: {len(labels)}")

    for window_label in genomes_df['window'].cat.categories:
        window_genomes = genomes_df.loc[genomes_df['window'] == window_label, 'genome_id'].tolist()
        n_genomes = len(window_genomes)

        if n_genomes < MIN_WINDOW_SIZE:
            print(f"    {window_label}: {n_genomes} genomes (< {MIN_WINDOW_SIZE}, skipping)")
            continue

        cores = {}
        window_cores[species][window_label] = cores

        for threshold in THRESHOLDS:
            core = calculate_core(window_genomes, genome_clusters, threshold)
            cores[threshold] = core

            window_results.append({
                'species': species,
                'approach': 'fixed_window',
                'window': window_label,
                'n_genomes': n_genomes,
                'threshold': threshold,
                'core_size': len(core)
            })

        # Report every configured threshold (no hard-coded 0.95 lookup)
        core_summary = ", ".join(f"{int(t*100)}% core = {len(cores[t])}" for t in THRESHOLDS)
        print(f"    {window_label}: {n_genomes} genomes, {core_summary}")

window_df = pd.DataFrame(window_results)
print(f"\nTotal window results: {len(window_df)}")

In [None]:
# Cell 7: Calculate inter-window Jaccard similarities

# Compare core composition between consecutive time windows. Windows skipped
# in Cell 6 (too few genomes) are absent, so "consecutive" means consecutive
# among the windows that were analysed. Pairs where either core is empty are
# not reported.
jaccard_results = []

for species in SPECIES:
    print(f"\n=== Inter-window similarity: {species} ===")

    species_cores = window_cores[species]
    ordered_windows = sorted(species_cores)

    for threshold in THRESHOLDS:
        print(f"  {int(threshold*100)}% threshold:")

        for earlier, later in zip(ordered_windows, ordered_windows[1:]):
            core_a = species_cores[earlier].get(threshold, set())
            core_b = species_cores[later].get(threshold, set())
            if not (core_a and core_b):
                continue

            n_shared = len(core_a & core_b)
            n_lost = len(core_a - core_b)
            n_gained = len(core_b - core_a)
            similarity = jaccard_similarity(core_a, core_b)

            jaccard_results.append({
                'species': species,
                'threshold': threshold,
                'window_1': earlier,
                'window_2': later,
                'core_1_size': len(core_a),
                'core_2_size': len(core_b),
                'shared': n_shared,
                'only_in_w1': n_lost,
                'only_in_w2': n_gained,
                'jaccard': similarity
            })

            print(f"    {earlier} -> {later}: Jaccard={similarity:.3f}, shared={n_shared}, lost={n_lost}, gained={n_gained}")

jaccard_df = pd.DataFrame(jaccard_results)

In [None]:
# Cell 8: Visualize fixed window cores
#
# Grouped bar chart: one panel per species, one bar per threshold in each
# window. Relies on the `colors` mapping (threshold -> colour) from Cell 5.

# squeeze=False keeps `axes` 2-D for any number of species
n_panels = len(SPECIES)
fig, axes = plt.subplots(1, n_panels, figsize=(7 * n_panels, 5), squeeze=False)

n_thresholds = len(THRESHOLDS)
width = 0.75 / n_thresholds  # 0.25 for three thresholds; a group always spans 0.75

for ax, species in zip(axes[0], SPECIES):
    species_data = window_df[window_df['species'] == species]
    # Order of first appearance, i.e. the chronological order rows were added in Cell 6
    windows = species_data['window'].unique()

    x = np.arange(len(windows))

    for i, threshold in enumerate(THRESHOLDS):
        thresh_data = (species_data[species_data['threshold'] == threshold]
                       .set_index('window')
                       .loc[windows])
        ax.bar(x + i * width, thresh_data['core_size'], width,
               label=f"{int(threshold*100)}%", color=colors.get(threshold), alpha=0.8)

    ax.set_xlabel('Time Window')
    ax.set_ylabel('Core Size')
    ax.set_title(f'{species}\nCore Size by Time Window')
    # Centre each tick under its group of bars (x + width for three thresholds)
    ax.set_xticks(x + width * (n_thresholds - 1) / 2)
    ax.set_xticklabels(windows, rotation=45, ha='right')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig(f"{OUTPUT_PATH}/fixed_window_cores.png", dpi=150)
plt.show()

print("Saved: fixed_window_cores.png")

---
## Track Core Turnover: Which Genes Leave/Enter Core?

In [None]:
# Cell 9: Identify stable core vs transient core genes

# Per species and threshold, over all analysed windows with a non-empty core:
# - stable core:    clusters that are core in EVERY window (intersection)
# - ever core:      clusters that are core in AT LEAST ONE window (union)
# - transient core: ever core minus stable core
# Combinations with fewer than two non-empty window cores are skipped.

core_stability = []

for species in SPECIES:
    print(f"\n=== Core stability analysis: {species} ===")

    species_cores = window_cores[species]
    ordered_windows = sorted(species_cores)

    for threshold in THRESHOLDS:
        cores = [c for c in (species_cores[w].get(threshold, set()) for w in ordered_windows) if c]
        if len(cores) < 2:
            continue

        stable_core = set.intersection(*cores)
        ever_core = set.union(*cores)
        transient_core = ever_core - stable_core
        n_stable, n_transient, n_ever = len(stable_core), len(transient_core), len(ever_core)

        print(f"  {int(threshold*100)}% threshold:")
        print(f"    Stable core (in ALL windows): {n_stable}")
        print(f"    Transient (in SOME windows): {n_transient}")
        print(f"    Ever-core (union): {n_ever}")

        core_stability.append({
            'species': species,
            'threshold': threshold,
            'n_windows': len(cores),
            'stable_core_size': n_stable,
            'transient_core_size': n_transient,
            'ever_core_size': n_ever,
            'stability_ratio': n_stable / n_ever if n_ever else 0
        })

stability_df = pd.DataFrame(core_stability)
print("\nCore stability summary:")
print(stability_df.to_string(index=False))

In [None]:
# Cell 10: Track when specific genes leave the core

# For the cumulative (chronological) ordering, find the first checkpoint at
# which each gene of the initial 95% core is no longer core. Only the first
# exit is recorded; a gene that later re-enters the core still counts as exited.

gene_exit_points = []

for species in SPECIES:
    print(f"\n=== Gene exit tracking: {species} ===")

    ordered_genomes = data[species]['ordered_genomes']
    genome_clusters = data[species]['genome_clusters']
    n_total = len(ordered_genomes)

    if n_total < MIN_WINDOW_SIZE:
        # Otherwise the "initial core" would come from fewer than MIN_WINDOW_SIZE genomes
        print(f"  Only {n_total} genomes (< {MIN_WINDOW_SIZE}), skipping")
        continue

    # Use 95% threshold
    threshold = 0.95

    # Checkpoints every 100 genomes from MIN_WINDOW_SIZE, always ending with the full set
    check_points = [*range(MIN_WINDOW_SIZE, n_total, 100), n_total]

    # Calculate core at each checkpoint
    checkpoint_cores = {
        n: calculate_core(ordered_genomes[:n], genome_clusters, threshold)
        for n in tqdm(check_points, desc="Calculating cores")
    }

    initial_core = checkpoint_cores[check_points[0]]
    midpoint = n_total // 2

    species_exits = []
    # sorted(): set iteration order is not stable across runs for string IDs
    # (hash randomization), so sort to make the output row order reproducible.
    for gene_id in sorted(initial_core):
        exit_point = next((n for n in check_points if gene_id not in checkpoint_cores[n]), None)
        species_exits.append({
            'species': species,
            'gene_cluster_id': gene_id,
            'exit_at_n_genomes': exit_point,
            'exits_early': exit_point is not None and exit_point < midpoint,
            'never_exits': exit_point is None
        })
    gene_exit_points.extend(species_exits)

    # Summary
    never_exits = sum(e['never_exits'] for e in species_exits)
    early_exits = sum(e['exits_early'] for e in species_exits)
    print(f"  Initial core: {len(initial_core)}")
    print(f"  Never exit (stable): {never_exits}")
    print(f"  Exit early (< 50%): {early_exits}")

exit_df = pd.DataFrame(gene_exit_points)

---
## Save Results

In [None]:
# Cell 11: Save all results

# Tabular outputs, written in this order; *.csv files are saved without the
# index, everything else as parquet.
_tabular_outputs = [
    (cumulative_df, "cumulative_results.parquet"),
    (window_df, "window_results.parquet"),
    (jaccard_df, "inter_window_jaccard.csv"),
    (stability_df, "core_stability.csv"),
    (exit_df, "gene_exit_points.parquet"),
]
for _frame, _filename in _tabular_outputs:
    _target = f"{OUTPUT_PATH}/{_filename}"
    if _filename.endswith(".csv"):
        _frame.to_csv(_target, index=False)
    else:
        _frame.to_parquet(_target)
    print(f"Saved: {_filename} ({len(_frame)} rows)")

# Stable 95% core gene lists: clusters core in every window with a non-empty
# 95% core; written only when at least two such windows exist.
for species in SPECIES:
    species_cores = window_cores[species]
    cores_95 = [c for c in (species_cores[w].get(0.95, set()) for w in sorted(species_cores)) if c]
    if len(cores_95) < 2:
        continue

    stable = set.intersection(*cores_95)
    with open(f"{OUTPUT_PATH}/{species}_stable_core_genes.txt", 'w') as fh:
        fh.writelines(f"{gene_id}\n" for gene_id in sorted(stable))
    print(f"Saved: {species}_stable_core_genes.txt ({len(stable)} genes)")

print("\nAll results saved!")

---
## Summary

### Files Created
- `cumulative_results.parquet` - Core size at each cumulative step
- `window_results.parquet` - Core size for each fixed time window
- `inter_window_jaccard.csv` - Jaccard similarity between adjacent windows
- `core_stability.csv` - Stable vs transient core statistics
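- `gene_exit_points.parquet` - For each gene in the initial 95% core, the first checkpoint (number of chronologically added genomes) at which it drops out of the core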
- `{species}_stable_core_genes.txt` - Gene IDs in stable core
- `cumulative_core_decay.png` - Visualization of core decay
- `fixed_window_cores.png` - Core size by time window

### Next Steps
Run `03_analysis.ipynb` for:
- Decay rate fitting (power law)
- Species comparison
- Detailed turnover analysis