# Scale Up Ecotype Analysis to More Species

## Goal
Expand the environment-ecotype analysis to many more species by:
1. Finding species with **good environmental embedding coverage** (≥20 genomes with embeddings, ≥30% coverage)
2. **Downsampling large species** to maximize phylogenetic diversity
3. Extracting data for the expanded species set

## Selection Criteria

| Parameter | Value | Rationale |
|-----------|-------|----------|
| **Minimum genomes with embeddings** | 20 | Basic statistical power |
| **Embedding coverage** | ≥30% | Representative sample |
| **Maximum genomes per species** | 250 | Tractable pairwise computations |
| **Downsampling method** | Maximize diversity using ANI distances | Preserve full genetic spread |
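
The two thresholds and the cap combine into a simple per-species rule. A minimal sketch in plain Python (the function names are illustrative and not part of the notebook), where coverage is `n_with_embeddings / n_total`:

```python
def passes_selection(n_with_embeddings: int, n_total: int,
                     min_embedded: int = 20, min_coverage: float = 0.30) -> bool:
    """Return True if a species meets both selection thresholds."""
    if n_total <= 0:
        return False
    coverage = n_with_embeddings / n_total
    return n_with_embeddings >= min_embedded and coverage >= min_coverage


def n_after_downsampling(n_with_embeddings: int, max_genomes: int = 250) -> int:
    """Genomes kept per species: all genomes with embeddings, capped at max_genomes."""
    return min(n_with_embeddings, max_genomes)


print(passes_selection(25, 60))    # 25/60 = 41.7% coverage -> True
print(passes_selection(40, 200))   # 40/200 = 20% coverage  -> False
print(passes_selection(15, 20))    # 75% coverage, but only 15 embedded -> False
print(n_after_downsampling(1200))  # 250
```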

---
## STEP 1: Find Species with Good Embedding Coverage

In [None]:
# Cell 1: Compute Embedding Coverage per Species

import numpy as np
import pandas as pd
import os
from pyspark.sql.functions import col, count, countDistinct

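# Initialize Spark (get_spark_session() is expected to be provided by the
# cluster's notebook environment; it is not defined in this notebook)
spark = get_spark_session()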

OUTPUT_PATH = "../data"
os.makedirs(OUTPUT_PATH, exist_ok=True)

# Join genome table with embeddings to compute coverage per species
coverage_df = spark.sql("""
    SELECT
        g.gtdb_species_clade_id,
        COUNT(DISTINCT g.genome_id) as n_total,
        COUNT(DISTINCT e.genome_id) as n_with_embeddings
    FROM kbase_ke_pangenome.genome g
    LEFT JOIN kbase_ke_pangenome.alphaearth_embeddings_all_years e
        ON g.genome_id = e.genome_id
    GROUP BY g.gtdb_species_clade_id
""")

coverage_df.cache()
coverage_pd = coverage_df.toPandas()
coverage_pd['coverage'] = coverage_pd['n_with_embeddings'] / coverage_pd['n_total']

# Save coverage data
coverage_pd.to_csv(f"{OUTPUT_PATH}/species_embedding_coverage.csv", index=False)
print(f"Saved coverage data for {len(coverage_pd)} species")

# Show species with best coverage (>=20 embeddings, >=30% coverage)
good_coverage = coverage_pd[
    (coverage_pd['n_with_embeddings'] >= 20) &
    (coverage_pd['coverage'] >= 0.30)
].sort_values('n_with_embeddings', ascending=False)

print(f"\nSpecies with >=20 embeddings AND >=30% coverage: {len(good_coverage)}")
print(good_coverage.head(30))
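
The coverage query relies on two details: the `LEFT JOIN` keeps genomes without embeddings in `n_total`, and `COUNT(DISTINCT e.genome_id)` keeps `n_with_embeddings` correct even if a genome has several rows in the embeddings table. Whether `alphaearth_embeddings_all_years` actually holds repeated rows per genome is an assumption here; the plain-Python sketch below (hypothetical toy IDs) only shows that the counts are unaffected if it does.

```python
from collections import defaultdict

def embedding_coverage(genomes, embedded_ids):
    """
    genomes: iterable of (genome_id, species_id) pairs, one per genome.
    embedded_ids: genome_ids present in the embeddings table (may repeat).
    Returns {species_id: (n_total, n_with_embeddings, coverage)}.
    """
    embedded = set(embedded_ids)  # duplicates collapse, like COUNT(DISTINCT e.genome_id)
    all_ids = defaultdict(set)
    hit_ids = defaultdict(set)
    for genome_id, species_id in genomes:
        all_ids[species_id].add(genome_id)  # LEFT JOIN: every genome counts toward n_total
        if genome_id in embedded:
            hit_ids[species_id].add(genome_id)
    return {
        sp: (len(ids), len(hit_ids[sp]), len(hit_ids[sp]) / len(ids))
        for sp, ids in all_ids.items()
    }

genomes = [("g1", "s__A"), ("g2", "s__A"), ("g3", "s__A"), ("g4", "s__B")]
embedded = ["g1", "g1", "g2"]  # g1 appears twice
print(embedding_coverage(genomes, embedded))
# {'s__A': (3, 2, 0.6666666666666666), 's__B': (1, 0, 0.0)}
```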

In [None]:
# Cell 2: Select Target Species for Expanded Analysis

# Use ALL species meeting criteria:
# - ≥20 genomes with embeddings (for statistical power)
# - ≥30% coverage (representative sample)

TARGET_SPECIES = good_coverage['gtdb_species_clade_id'].tolist()
print(f"Selected {len(TARGET_SPECIES)} species for expanded analysis")

# Save target species list
with open(f"{OUTPUT_PATH}/target_species_expanded.txt", 'w') as f:
    for sp in TARGET_SPECIES:
        f.write(sp + '\n')

---
## STEP 2: Downsample Large Species (Maximize Phylogenetic Diversity)

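For species with more than 250 genomes with embeddings, select 250 representatives that **spread across the species' genetic diversity**, using greedy maximin (farthest-point) selection on ANI-based distances (distance = 100 - ANI). At each step, the genome farthest from everything already selected is added, which approximates maximal diversity rather than exactly maximizing total pairwise distance.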

*Note: Using ANI as a proxy for phylogenetic distance until per-species tree data is available.*
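
The greedy maximin rule can be illustrated on a toy distance matrix. The sketch below uses made-up distances (points on a line, not real ANI values): genomes 0 to 2 form one tight group and genomes 3 and 4 another. It keeps a running "distance to nearest selected genome" vector, so each step costs one vectorized pass instead of a loop over the selected set. The function name is illustrative.

```python
import numpy as np

def greedy_maximin(dist: np.ndarray, k: int) -> list[int]:
    """Greedy maximin (farthest-point) selection of k indices from a symmetric distance matrix."""
    n = dist.shape[0]
    if k >= n:
        return list(range(n))
    # Seed with the point that has the largest summed distance to all others
    first = int(np.argmax(dist.sum(axis=1)))
    selected = [first]
    chosen = np.zeros(n, dtype=bool)
    chosen[first] = True
    # nearest[i] = distance from point i to its closest selected point
    nearest = dist[first].copy()
    while len(selected) < k:
        # Pick the unselected point farthest from the current selection
        nxt = int(np.argmax(np.where(chosen, -np.inf, nearest)))
        selected.append(nxt)
        chosen[nxt] = True
        nearest = np.minimum(nearest, dist[nxt])
    return selected

positions = np.array([0.0, 0.2, 0.4, 3.0, 3.1])
dist = np.abs(positions[:, None] - positions[None, :])
print(greedy_maximin(dist, 3))  # [4, 0, 2]
```

With three picks, both groups are represented and the larger group contributes its two extremes, which is the "spread" behavior the downsampling step relies on.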

In [None]:
# Cell 3: Identify Species Needing Downsampling

MAX_GENOMES = 250  # Maximum genomes per species after downsampling

# All target species from good_coverage (already filtered to ≥20 embeddings, ≥30% coverage)
print(f"Total target species: {len(TARGET_SPECIES)}")

# Show which need downsampling (more than MAX_GENOMES genomes with embeddings)
needs_downsampling = good_coverage[good_coverage['n_with_embeddings'] > MAX_GENOMES]
print(f"\nSpecies needing downsampling (>{MAX_GENOMES} with embeddings): {len(needs_downsampling)}")
if len(needs_downsampling) > 0:
    print(needs_downsampling[['gtdb_species_clade_id', 'n_total', 'n_with_embeddings']].head(20))

In [None]:
# Cell 4: Diversity-Maximizing Downsampling Function

import numpy as np

def downsample_maximize_diversity(species_id, max_genomes=250):
    """
    Downsample a species by selecting genomes that (approximately) MAXIMIZE phylogenetic diversity.
    
    Only considers genomes WITH embeddings (required for the analysis).
    Uses ANI-based distances (100 - ANI) as proxy for phylogenetic distance.

    Algorithm:
    1. Get genomes with embeddings for this species
    2. Build distance matrix from genome_ani table (distance = 100 - ANI)
    3. Use maximin selection to maximize phylogenetic spread
    """
    # Get genomes WITH embeddings for this species (DISTINCT guards against
    # repeated rows per genome in the embeddings table)
    embed_genomes = spark.sql(f"""
        SELECT DISTINCT g.genome_id
        FROM kbase_ke_pangenome.genome g
        JOIN kbase_ke_pangenome.alphaearth_embeddings_all_years e
            ON g.genome_id = e.genome_id
        WHERE g.gtdb_species_clade_id = '{species_id}'
    """).collect()
    genome_ids = [r.genome_id for r in embed_genomes]
    
    n_genomes = len(genome_ids)
    short_name = species_id.split('__')[1].split('--')[0]
    print(f"{short_name}: {n_genomes} genomes with embeddings")

    # If small enough, return all genomes
    if n_genomes <= max_genomes:
        return genome_ids

    # Need to downsample - get ANI matrix
    print(f"  Building ANI distance matrix for {n_genomes} genomes...")
    ani_df = spark.sql(f"""
        SELECT genome1_id, genome2_id, ANI
        FROM kbase_ke_pangenome.genome_ani
        WHERE genome1_id IN ({','.join([f"'{g}'" for g in genome_ids])})
          AND genome2_id IN ({','.join([f"'{g}'" for g in genome_ids])})
    """).toPandas()

    # Build distance matrix (distance = 100 - ANI) with vectorized indexing.
    # Pairs absent from genome_ani keep distance 0, i.e. they are treated as identical.
    genome_to_idx = {g: i for i, g in enumerate(genome_ids)}
    dist_matrix = np.zeros((n_genomes, n_genomes))

    i_idx = ani_df['genome1_id'].map(genome_to_idx)
    j_idx = ani_df['genome2_id'].map(genome_to_idx)
    valid = i_idx.notna() & j_idx.notna()
    dist_matrix[i_idx[valid].astype(int).to_numpy(),
                j_idx[valid].astype(int).to_numpy()] = 100 - ani_df.loc[valid, 'ANI'].to_numpy(dtype=float)
    # Symmetrize; if both directions are present, keep the larger distance
    dist_matrix = np.maximum(dist_matrix, dist_matrix.T)

    # Greedy maximin (farthest-point) selection
    print(f"  Selecting {max_genomes} representatives to maximize diversity...")

    # Start with genome that has max sum of distances (most divergent)
    first = int(np.argmax(dist_matrix.sum(axis=1)))
    selected_idx = [first]
    is_selected = np.zeros(n_genomes, dtype=bool)
    is_selected[first] = True
    # min_dist[k] = distance from genome k to its nearest selected genome
    min_dist = dist_matrix[first].copy()

    # Iteratively add the genome with the largest minimum distance to the selected set.
    # n_genomes > max_genomes here, so unselected candidates always remain.
    while len(selected_idx) < max_genomes:
        best_idx = int(np.argmax(np.where(is_selected, -np.inf, min_dist)))
        selected_idx.append(best_idx)
        is_selected[best_idx] = True
        min_dist = np.minimum(min_dist, dist_matrix[best_idx])

    # Summary: total pairwise distance among selected genomes (each pair counted once)
    sel = np.array(selected_idx)
    sub = dist_matrix[np.ix_(sel, sel)]
    total_diversity = sub[np.triu_indices(len(sel), k=1)].sum()
    print(f"  Selected {len(selected_idx)} genomes")
    print(f"  Total pairwise ANI-distance: {total_diversity:.2f}")

    return [genome_ids[i] for i in selected_idx]

# Test on one species (uncomment to test)
# if len(needs_downsampling) > 0:
#     test_species = needs_downsampling.iloc[0]['gtdb_species_clade_id']
#     genomes = downsample_maximize_diversity(test_species)

In [None]:
# Cell 5: Build Final Genome List (All Target Species)

# For each target species, get genomes (downsampled if needed)
# All selected genomes will have embeddings
all_target_genomes = []
species_stats = []

for species_id in TARGET_SPECIES:
    species_info = good_coverage[good_coverage['gtdb_species_clade_id'] == species_id].iloc[0]
    n_with_embed = species_info['n_with_embeddings']
    n_total = species_info['n_total']

    if n_with_embed > MAX_GENOMES:
        # Downsample large species using diversity-maximizing selection
        genomes = downsample_maximize_diversity(species_id, max_genomes=MAX_GENOMES)
    else:
        # Use all genomes with embeddings
        genomes_result = spark.sql(f"""
            SELECT DISTINCT g.genome_id
            FROM kbase_ke_pangenome.genome g
            JOIN kbase_ke_pangenome.alphaearth_embeddings_all_years e
                ON g.genome_id = e.genome_id
            WHERE g.gtdb_species_clade_id = '{species_id}'
        """).collect()
        genomes = [r.genome_id for r in genomes_result]

    for g in genomes:
        all_target_genomes.append({
            'genome_id': g, 
            'gtdb_species_clade_id': species_id
        })

    species_stats.append({
        'species': species_id,
        'n_total': n_total,
        'n_with_embeddings': n_with_embed,
        'n_selected': len(genomes)
    })

target_genomes_df = pd.DataFrame(all_target_genomes)
target_genomes_df.to_csv(f"{OUTPUT_PATH}/target_genomes_expanded.csv", index=False)

stats_df = pd.DataFrame(species_stats)
stats_df.to_csv(f"{OUTPUT_PATH}/species_selection_stats.csv", index=False)

print(f"\n=== Summary ===")
print(f"Total species: {len(TARGET_SPECIES)}")
print(f"Total genomes selected: {len(target_genomes_df)}")
print(f"Genomes per species: {stats_df['n_selected'].mean():.0f} mean, {stats_df['n_selected'].min()}-{stats_df['n_selected'].max()} range")

In [None]:
# Cell 5b: Compute Embedding Diversity per Species

from scipy.spatial.distance import pdist, squareform
import matplotlib.pyplot as plt

# Embedding columns (A00-A63, 64 dimensions)
EMBEDDING_COLS = [f"A{i:02d}" for i in range(64)]

# Get embeddings for all target genomes
target_genome_ids = target_genomes_df['genome_id'].tolist()

# Query embeddings with all dimension columns
embeddings_query = f"""
    SELECT genome_id, {', '.join(EMBEDDING_COLS)}
    FROM kbase_ke_pangenome.alphaearth_embeddings_all_years
"""
embeddings_df = spark.sql(embeddings_query).filter(
    col("genome_id").isin(target_genome_ids)
).toPandas()

print(f"Retrieved {len(embeddings_df)} embeddings")

# Convert embedding columns to numpy array
embeddings_df['embedding_vec'] = embeddings_df[EMBEDDING_COLS].values.tolist()
embeddings_df['embedding_vec'] = embeddings_df['embedding_vec'].apply(np.array)

# Merge with species info
embeddings_df = embeddings_df.merge(
    target_genomes_df[['genome_id', 'gtdb_species_clade_id']], 
    on='genome_id'
)

# Compute embedding diversity per species
diversity_stats = []

for species_id in TARGET_SPECIES:
    species_embeddings = embeddings_df[
        embeddings_df['gtdb_species_clade_id'] == species_id
    ]['embedding_vec'].values
    
    if len(species_embeddings) < 2:
        continue
    
    # Stack into matrix
    emb_matrix = np.vstack(species_embeddings)
    
    # Compute pairwise cosine distances
    # cosine distance = 1 - cosine_similarity
    cosine_dists = pdist(emb_matrix, metric='cosine')
    
    short_name = species_id.split('__')[1].split('--')[0]
    diversity_stats.append({
        'species': species_id,
        'short_name': short_name,
        'n_genomes': len(species_embeddings),
        'mean_cosine_dist': np.mean(cosine_dists),
        'std_cosine_dist': np.std(cosine_dists),
        'min_cosine_dist': np.min(cosine_dists),
        'max_cosine_dist': np.max(cosine_dists),
        'median_cosine_dist': np.median(cosine_dists)
    })

diversity_df = pd.DataFrame(diversity_stats)
diversity_df.to_csv(f"{OUTPUT_PATH}/species_embedding_diversity.csv", index=False)

print(f"Computed embedding diversity for {len(diversity_df)} species")
print(f"\nSummary statistics:")
print(diversity_df[['mean_cosine_dist', 'std_cosine_dist']].describe())

# Plot distribution
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].hist(diversity_df['mean_cosine_dist'], bins=30, edgecolor='black', alpha=0.7)
axes[0].set_xlabel('Mean Pairwise Cosine Distance')
axes[0].set_ylabel('Number of Species')
axes[0].set_title('Distribution of Embedding Diversity Across Species')

axes[1].scatter(diversity_df['n_genomes'], diversity_df['mean_cosine_dist'], alpha=0.5)
axes[1].set_xlabel('Number of Genomes')
axes[1].set_ylabel('Mean Pairwise Cosine Distance')
axes[1].set_title('Embedding Diversity vs Sample Size')

plt.tight_layout()
plt.savefig(f"{OUTPUT_PATH}/embedding_diversity_distribution.png", dpi=150)
plt.show()

# Show species with lowest diversity (potential concern)
print("\nSpecies with LOWEST embedding diversity (potential concern):")
print(diversity_df.nsmallest(10, 'mean_cosine_dist')[['short_name', 'n_genomes', 'mean_cosine_dist']])
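
The diversity metric above is cosine distance, which SciPy's `pdist(..., metric='cosine')` defines as 1 minus the cosine similarity. It depends only on the direction of the embedding vectors, not their length, and ranges from 0 (same direction) through 1 (orthogonal) to 2 (opposite). A plain-Python sketch of the definition and of the per-species mean (function names are illustrative):

```python
import math
from itertools import combinations

def cosine_distance(u, v):
    """1 - cosine similarity (undefined for zero vectors)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)

def mean_pairwise_cosine(vectors):
    """Mean over all unordered pairs, analogous to np.mean(pdist(X, metric='cosine'))."""
    dists = [cosine_distance(u, v) for u, v in combinations(vectors, 2)]
    return sum(dists) / len(dists)

print(cosine_distance([1, 0], [2, 0]))   # 0.0 (same direction)
print(cosine_distance([1, 0], [0, 3]))   # 1.0 (orthogonal)
print(cosine_distance([1, 0], [-1, 0]))  # 2.0 (opposite)
```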

---
## STEP 3: Extract Data for Expanded Species Set

In [None]:
# Cell 6: Extract Embeddings for Target Genomes

target_genome_ids = target_genomes_df['genome_id'].tolist()

# Spark DataFrame (named to avoid shadowing the pandas embeddings_df from Cell 5b)
embeddings_sdf = spark.sql("""
    SELECT *
    FROM kbase_ke_pangenome.alphaearth_embeddings_all_years
""").filter(col("genome_id").isin(target_genome_ids))

embeddings_pd = embeddings_sdf.toPandas()
embeddings_pd.to_csv(f"{OUTPUT_PATH}/embeddings_expanded.csv", index=False)
print(f"Saved {len(embeddings_pd)} embeddings")

In [None]:
# Cell 7: Extract Within-Species ANI (Chunked)

import math
from pyspark.sql.functions import hash as spark_hash

ANI_OUTPUT_PATH = f"{OUTPUT_PATH}/ani_expanded"
os.makedirs(ANI_OUTPUT_PATH, exist_ok=True)

# Get ANI for target genomes, keeping only pairs where both genomes share a species
ani_df = spark.sql("""
    SELECT
        a.genome1_id,
        a.genome2_id,
        a.ANI,
        g1.gtdb_species_clade_id
    FROM kbase_ke_pangenome.genome_ani a
    JOIN kbase_ke_pangenome.genome g1 ON a.genome1_id = g1.genome_id
    JOIN kbase_ke_pangenome.genome g2 ON a.genome2_id = g2.genome_id
    WHERE g1.gtdb_species_clade_id = g2.gtdb_species_clade_id
""").filter(
    col("genome1_id").isin(target_genome_ids) &
    col("genome2_id").isin(target_genome_ids)
)

ani_df.cache()
total_count = ani_df.count()
print(f"ANI pairs: {total_count}")

# Export in chunks of roughly CHUNK_SIZE rows.
# monotonically_increasing_id() is unique but NOT consecutive (the partition ID sits in
# the upper bits), so filtering on contiguous id ranges would silently drop most
# partitions. Instead, assign each row to a chunk by a non-negative hash of genome1_id.
CHUNK_SIZE = 1000000
n_chunks = max(1, math.ceil(total_count / CHUNK_SIZE))
ani_chunked = ani_df.withColumn(
    "_chunk", ((spark_hash(col("genome1_id")) % n_chunks) + n_chunks) % n_chunks
)

for i in range(n_chunks):
    chunk_pd = ani_chunked.filter(col("_chunk") == i).drop("_chunk").toPandas()
    chunk_pd.to_csv(f"{ANI_OUTPUT_PATH}/ani_chunk_{i:03d}.csv", index=False)
    print(f"  Saved chunk {i+1}/{n_chunks}: {len(chunk_pd)} rows")

print(f"\nCompleted ANI export to {ANI_OUTPUT_PATH}")

In [None]:
# Cell 8: Extract Gene Clusters per Genome (Chunked)

import math
from pyspark.sql.functions import hash as spark_hash

CLUSTERS_OUTPUT_PATH = f"{OUTPUT_PATH}/gene_clusters_expanded"
os.makedirs(CLUSTERS_OUTPUT_PATH, exist_ok=True)

# Get gene clusters for target genomes
genome_clusters_df = spark.sql("""
    SELECT
        g.genome_id,
        gg.gene_cluster_id,
        gm.gtdb_species_clade_id
    FROM kbase_ke_pangenome.gene g
    JOIN kbase_ke_pangenome.gene_genecluster_junction gg
        ON g.gene_id = gg.gene_id
    JOIN kbase_ke_pangenome.genome gm
        ON g.genome_id = gm.genome_id
""").filter(col("genome_id").isin(target_genome_ids))

genome_clusters_df.cache()
total_count = genome_clusters_df.count()
print(f"Gene-cluster associations: {total_count}")

# Export in chunks of roughly CHUNK_SIZE rows.
# monotonically_increasing_id() is unique but not consecutive, so contiguous id ranges
# would miss most partitions; assign chunks by a non-negative hash of genome_id instead
# (all clusters of one genome land in the same chunk).
CHUNK_SIZE = 5000000
n_chunks = max(1, math.ceil(total_count / CHUNK_SIZE))
clusters_chunked = genome_clusters_df.withColumn(
    "_chunk", ((spark_hash(col("genome_id")) % n_chunks) + n_chunks) % n_chunks
)

for i in range(n_chunks):
    chunk_pd = clusters_chunked.filter(col("_chunk") == i).drop("_chunk").toPandas()
    chunk_pd.to_csv(f"{CLUSTERS_OUTPUT_PATH}/clusters_chunk_{i:03d}.csv", index=False)
    print(f"  Saved chunk {i+1}/{n_chunks}: {len(chunk_pd)} rows")

print(f"\nCompleted gene clusters export to {CLUSTERS_OUTPUT_PATH}")

---
## Summary

After running this notebook, download the following from the cluster:

- `species_embedding_coverage.csv` - Coverage for all species
- `target_genomes_expanded.csv` - Selected genomes (all have embeddings)
- `species_selection_stats.csv` - Selection statistics per species
- `species_embedding_diversity.csv` - Embedding diversity per species
- `embedding_diversity_distribution.png` - Visualization of diversity
- `embeddings_expanded.csv` - Environmental embeddings
- `ani_expanded/` - Pairwise ANI chunks
- `gene_clusters_expanded/` - Gene cluster chunks

Then run local analysis to:
1. Review embedding diversity distribution and decide on any flagging criteria
2. Compute environment-gene content correlations for all species
3. Identify which species show strongest environment-ecotype signal
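
Locally, the chunk directories can be read back as one stream of rows. Because the export cells zero-pad the chunk index (`{i:03d}`), sorting file names also sorts chunks numerically. A minimal stdlib sketch; the helper name and the local directory in the usage comment are hypothetical and depend on where the files were downloaded.

```python
import csv
import glob
import os

def iter_chunk_rows(directory, pattern):
    """Yield dict rows from every chunk CSV in `directory` matching `pattern`, in file-name order."""
    for path in sorted(glob.glob(os.path.join(directory, pattern))):
        with open(path, newline="") as fh:
            yield from csv.DictReader(fh)

# Hypothetical local usage after downloading:
# for row in iter_chunk_rows("data/ani_expanded", "ani_chunk_*.csv"):
#     ...
```

Values come back as strings (e.g. `row["ANI"]`), so convert them as needed; for large chunks, loading each file with pandas and concatenating may be more convenient.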