# Explore Gene Data in BERDL

## Goal
Understand the structure of gene-related tables to properly compute pangenome openness metrics.
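One widely used openness metric is the Heaps' law exponent. It is shown here for orientation only; the project may adopt a different metric. You fit $n = \kappa N^{\gamma}$, where $n$ is the number of distinct gene clusters seen after sampling $N$ genomes. A larger $\gamma$ means each added genome keeps contributing new clusters, and $\gamma > 0$ is conventionally read as an open pangenome.

The sketch below is minimal and assumes a genomes × clusters 0/1 presence matrix is already in memory. `heaps_gamma` and its parameters are hypothetical. It averages accumulation curves over random genome orders and fits the exponent on a log-log scale:

```python
import numpy as np


def heaps_gamma(presence: np.ndarray, n_perm: int = 20, seed: int = 0) -> float:
    """Estimate the Heaps' law exponent gamma from a genomes x clusters 0/1 matrix.

    Hypothetical helper: builds pangenome accumulation curves over random
    genome orderings, averages them, then fits log(n) = log(kappa) + gamma*log(N).
    """
    present = presence > 0
    n_genomes = present.shape[0]
    if n_genomes < 2:
        raise ValueError("need at least two genomes to fit a curve")
    if not present.any(axis=1).all():
        raise ValueError("every genome must carry at least one cluster (log(0) otherwise)")

    rng = np.random.default_rng(seed)
    curves = np.zeros((n_perm, n_genomes))
    for p in range(n_perm):
        order = rng.permutation(n_genomes)
        # Running OR down the rows: which clusters have been seen after k genomes
        seen = np.logical_or.accumulate(present[order], axis=0)
        curves[p] = seen.sum(axis=1)

    mean_curve = curves.mean(axis=0)
    n_sampled = np.arange(1, n_genomes + 1)
    # polyfit with deg=1 returns [slope, intercept]; the slope is gamma
    gamma, _log_kappa = np.polyfit(np.log(n_sampled), np.log(mean_curve), 1)
    return float(gamma)


# Toy matrices (invented): a fully shared pangenome vs. one with only unique clusters
closed = np.ones((10, 50))
all_unique = np.eye(10)
print(heaps_gamma(closed), heaps_gamma(all_unique))
```

In the toy examples, a pangenome where every genome carries every cluster gives a flat curve and γ of about 0. When each genome carries only its own clusters, the curve is linear and γ is about 1. Averaging the curves before fitting is one simple choice. Another is to fit each permutation separately and report the spread.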

## Questions
1. How many genes per genome are in the `gene` table?
2. What fraction of genes are assigned to clusters in `gene_genecluster_junction`?
3. What do the gene clusters represent?
4. What data do we need to compute core/accessory genome fractions?

In [None]:
import pandas as pd
import numpy as np
from pyspark.sql.functions import col, count, countDistinct

# Load our target species
species_stats = pd.read_csv('../data/species_selection_stats.csv')
target_genomes = pd.read_csv('../data/target_genomes_expanded.csv')

print(f"Target species: {len(species_stats)}")
print(f"Target genomes: {len(target_genomes):,}")

## 1. Explore Gene Table Schema

In [None]:
# Check the gene table structure
gene_sample = spark.sql("""
    SELECT * 
    FROM kbase_ke_pangenome.gene 
    LIMIT 10
""")

print("=== Gene Table Schema ===")
gene_sample.printSchema()

print("\n=== Sample Rows ===")
gene_sample.show(truncate=False)

In [None]:
# Check gene_genecluster_junction table
junction_sample = spark.sql("""
    SELECT * 
    FROM kbase_ke_pangenome.gene_genecluster_junction 
    LIMIT 10
""")

print("=== Gene-Cluster Junction Schema ===")
junction_sample.printSchema()

print("\n=== Sample Rows ===")
junction_sample.show(truncate=False)

In [None]:
# Check gene_cluster table if it exists
try:
    cluster_sample = spark.sql("""
        SELECT * 
        FROM kbase_ke_pangenome.gene_cluster 
        LIMIT 10
    """)
    print("=== Gene Cluster Table Schema ===")
    cluster_sample.printSchema()
    cluster_sample.show(truncate=False)
except Exception as e:  # e.g. pyspark AnalysisException when the table does not exist
    print(f"gene_cluster table could not be queried: {e}")

## 2. Count Genes per Genome

In [None]:
# Count total genes per genome for a sample species (first 20 genomes only, to keep queries small)
test_species = 's__Klebsiella_pneumoniae--RS_GCF_000742135.1'
test_genomes = target_genomes[target_genomes['gtdb_species_clade_id'] == test_species]['genome_id'].tolist()

print(f"Klebsiella pneumoniae: {len(test_genomes)} target genomes; the queries below use the first {min(20, len(test_genomes))}")

# Total genes in gene table
gene_counts = spark.sql(f"""
    SELECT genome_id, COUNT(*) as n_genes
    FROM kbase_ke_pangenome.gene
    WHERE genome_id IN ({','.join([f"'{g}'" for g in test_genomes[:20]])})
    GROUP BY genome_id
""").toPandas()

print(f"\n=== Genes per Genome (from gene table) ===")
print(gene_counts.describe())
print(f"\nSample:\n{gene_counts.head(10)}")
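Every query in this notebook builds its `IN (...)` list with the same inline string join. That join produces invalid SQL (`IN ()`) for an empty list and would pass any quote characters in an ID straight into the query.

The helper below is hypothetical. It assumes genome IDs contain only letters, digits, `_`, `.` and `-`, and rejects anything outside that allowlist rather than trying to escape it:

```python
import re
from typing import Iterable

# Assumed allowlist for accession-style genome IDs; anything else is rejected, not escaped
_SAFE_ID = re.compile(r"[A-Za-z0-9_.\-]+")


def sql_in_list(ids: Iterable[str]) -> str:
    """Render IDs as a quoted, comma-separated list for a SQL IN (...) clause."""
    ids = list(ids)
    if not ids:
        raise ValueError("IN list must contain at least one ID")
    bad = [i for i in ids if not _SAFE_ID.fullmatch(i)]
    if bad:
        raise ValueError(f"unexpected characters in IDs: {bad[:5]}")
    return ", ".join(f"'{i}'" for i in ids)


print(sql_in_list(["GCF_000742135.1", "GB_GCA_000001.1"]))
```

You would use it as `WHERE genome_id IN ({sql_in_list(test_genomes[:20])})`. An alternative that avoids string building entirely is the DataFrame API, for example `spark.table("kbase_ke_pangenome.gene").filter(col("genome_id").isin(test_genomes[:20]))`.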

In [None]:
# Genes with at least one cluster assignment (count genes, not clusters: paralogs can share a cluster)
clustered_counts = spark.sql(f"""
    SELECT g.genome_id, COUNT(DISTINCT g.gene_id) as n_clustered
    FROM kbase_ke_pangenome.gene g
    JOIN kbase_ke_pangenome.gene_genecluster_junction j ON g.gene_id = j.gene_id
    WHERE g.genome_id IN ({','.join([f"'{g}'" for g in test_genomes[:20]])})
    GROUP BY g.genome_id
""").toPandas()

print(f"\n=== Genes with Cluster Assignments ===")
print(clustered_counts.describe())

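# Compare total vs. clustered genes per genome
comparison = gene_counts.merge(clustered_counts, on='genome_id', how='left')
comparison['n_clustered'] = comparison['n_clustered'].fillna(0)  # genomes with no clustered genes
comparison['clustered_fraction'] = comparison['n_clustered'] / comparison['n_genes']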
print(f"\n=== Comparison ===")
print(comparison)
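The clustered fraction should divide distinct *genes* that have a cluster by total genes. `COUNT(DISTINCT gene_cluster_id)` counts clusters instead, so paralogs in the same cluster collapse into one and the fraction is understated.

A pandas sketch on invented toy tables shows the difference. The column names mirror the query; the data are made up:

```python
import pandas as pd

# Toy data (invented): genome A has two paralogs (a1, a2) in the same cluster c1
genes = pd.DataFrame({
    "gene_id": ["a1", "a2", "a3", "b1", "b2"],
    "genome_id": ["A", "A", "A", "B", "B"],
})
junction = pd.DataFrame({
    "gene_id": ["a1", "a2", "b1"],
    "gene_cluster_id": ["c1", "c1", "c1"],
})

merged = genes.merge(junction, on="gene_id", how="left")
# Keep gene_id only where the gene has a cluster, so nunique counts clustered genes
merged["clustered_gene"] = merged["gene_id"].where(merged["gene_cluster_id"].notna())

per_genome = merged.groupby("genome_id").agg(
    n_genes=("gene_id", "nunique"),
    n_clustered_genes=("clustered_gene", "nunique"),
    n_distinct_clusters=("gene_cluster_id", "nunique"),
)
per_genome["clustered_fraction"] = per_genome["n_clustered_genes"] / per_genome["n_genes"]
print(per_genome)
```

For genome A, two of three genes are clustered, but only one distinct cluster exists. A cluster-based count would report 1/3 instead of 2/3.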

## 3. Understand Gene Cluster Structure

In [None]:
# What do the gene clusters look like?
# Are they species-specific or global?

# Check cluster IDs for our test species
cluster_info = spark.sql(f"""
    SELECT j.gene_cluster_id, COUNT(*) as n_genes
    FROM kbase_ke_pangenome.gene g
    JOIN kbase_ke_pangenome.gene_genecluster_junction j ON g.gene_id = j.gene_id
    WHERE g.genome_id IN ({','.join([f"'{g}'" for g in test_genomes[:20]])})
    GROUP BY j.gene_cluster_id
    ORDER BY n_genes DESC
    LIMIT 20
""").toPandas()

print("=== Top Gene Clusters (by gene count) ===")
print(cluster_info)

# How many genomes is each cluster in?
cluster_prevalence = spark.sql(f"""
    SELECT j.gene_cluster_id, COUNT(DISTINCT g.genome_id) as n_genomes
    FROM kbase_ke_pangenome.gene g
    JOIN kbase_ke_pangenome.gene_genecluster_junction j ON g.gene_id = j.gene_id
    WHERE g.genome_id IN ({','.join([f"'{g}'" for g in test_genomes[:20]])})
    GROUP BY j.gene_cluster_id
    ORDER BY n_genomes DESC
    LIMIT 20
""").toPandas()

print("\n=== Cluster Prevalence (how many genomes) ===")
print(cluster_prevalence)
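The `LIMIT 20` query shows only the most widespread clusters. Openness and core/accessory splits depend on the full prevalence distribution: how many clusters occur in exactly 1, 2, …, N genomes.

The sketch assumes the `(genome_id, gene_cluster_id)` pairs for the sampled genomes have been pulled into pandas, for example with the same join but without `GROUP BY`/`LIMIT`. `prevalence_spectrum` is hypothetical:

```python
import pandas as pd


def prevalence_spectrum(pairs: pd.DataFrame) -> pd.Series:
    """Number of clusters found in exactly k genomes, for k = 1..n_genomes."""
    pairs = pairs.drop_duplicates(["genome_id", "gene_cluster_id"])
    n_genomes = pairs["genome_id"].nunique()
    per_cluster = pairs.groupby("gene_cluster_id")["genome_id"].nunique()
    # value_counts gives {k: number of clusters in k genomes}; fill missing k with 0
    spectrum = per_cluster.value_counts().reindex(range(1, n_genomes + 1), fill_value=0)
    spectrum.index.name = "n_genomes"
    return spectrum


# Toy pairs (invented): c1 in 3 genomes, c2 in 2, c3 and c4 in 1 each
pairs = pd.DataFrame({
    "genome_id": ["G1", "G1", "G1", "G2", "G2", "G3", "G3"],
    "gene_cluster_id": ["c1", "c2", "c3", "c1", "c2", "c1", "c4"],
})
print(prevalence_spectrum(pairs))
```

With real data, you could compute the per-cluster genome counts in Spark and collect only the spectrum, which is at most N rows.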

In [None]:
# Check if clusters are shared across species
# Pick a common cluster and see which species have it

common_cluster = cluster_prevalence.iloc[0]['gene_cluster_id']
print(f"Checking cluster: {common_cluster}")

species_with_cluster = spark.sql(f"""
    SELECT gm.gtdb_species_clade_id, COUNT(DISTINCT g.genome_id) as n_genomes
    FROM kbase_ke_pangenome.gene g
    JOIN kbase_ke_pangenome.gene_genecluster_junction j ON g.gene_id = j.gene_id
    JOIN kbase_ke_pangenome.genome gm ON g.genome_id = gm.genome_id
    WHERE j.gene_cluster_id = '{common_cluster}'
    GROUP BY gm.gtdb_species_clade_id
    ORDER BY n_genomes DESC
    LIMIT 10
""").toPandas()

print(f"\n=== Species with cluster {common_cluster} ===")
print(species_with_cluster)
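Checking a single cluster is anecdotal. The sketch below summarizes how many clusters span more than one species across a sample of clusters. If clusters are built per species, that fraction should be 0.

It assumes the `(gene_cluster_id, gtdb_species_clade_id)` pairs for the sample have been collected into pandas. The helper name and toy data are hypothetical:

```python
import pandas as pd


def species_per_cluster(pairs: pd.DataFrame) -> pd.Series:
    """For each cluster, count the distinct species whose genomes carry it."""
    return pairs.groupby("gene_cluster_id")["gtdb_species_clade_id"].nunique()


# Toy pairs (invented): c1 appears in two species, c2 and c3 in one each
pairs = pd.DataFrame({
    "gene_cluster_id": ["c1", "c1", "c2", "c3"],
    "gtdb_species_clade_id": ["sp_a", "sp_b", "sp_a", "sp_b"],
})
n_species = species_per_cluster(pairs)
shared_fraction = (n_species > 1).mean()
print(n_species)
print(f"Clusters shared across species: {shared_fraction:.1%}")
```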

## 4. Summary: What Data Do We Need?

Based on the exploration above, determine:
1. Whether to use `gene` table (all genes) or `gene_genecluster_junction` (clustered genes only)
2. Whether clusters are species-specific or can be compared across species
3. What query to use for computing core/accessory genome fractions
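Once per-species `(genome_id, gene_cluster_id)` pairs are available, the core/accessory split reduces to counting genomes per cluster. The 95% core cutoff and the singleton/accessory split are assumptions: common conventions in pangenome tools, not values taken from BERDL.

`pangenome_partition` is hypothetical. Genomes with no clustered genes never appear in the pairs, so it accepts the genome count explicitly:

```python
from typing import Optional

import pandas as pd


def pangenome_partition(pairs: pd.DataFrame, core_threshold: float = 0.95,
                        n_genomes: Optional[int] = None) -> dict:
    """Fraction of gene clusters that are core, accessory, or singleton.

    core:      present in >= core_threshold of genomes (assumed cutoff)
    singleton: present in exactly one genome (and not core)
    accessory: everything else
    """
    pairs = pairs.drop_duplicates(["genome_id", "gene_cluster_id"])
    if n_genomes is None:
        n_genomes = pairs["genome_id"].nunique()
    counts = pairs.groupby("gene_cluster_id")["genome_id"].nunique()
    core = counts / n_genomes >= core_threshold
    singleton = (counts == 1) & ~core
    accessory = ~core & ~singleton
    return {
        "n_genomes": n_genomes,
        "n_clusters": int(counts.size),
        "core": float(core.mean()),
        "accessory": float(accessory.mean()),
        "singleton": float(singleton.mean()),
    }


# Toy pairs (invented): 4 genomes; c1 everywhere, c2 in two, c3 and c4 unique
pairs = pd.DataFrame({
    "genome_id": ["G1", "G1", "G1", "G2", "G2", "G3", "G3", "G4"],
    "gene_cluster_id": ["c1", "c2", "c3", "c1", "c2", "c1", "c4", "c1"],
})
print(pangenome_partition(pairs))
```

Passing `n_genomes=len(genomes_for_species)` taken from the genome list, rather than inferring it from the pairs, keeps genomes without clustered genes in the denominator.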

In [None]:
# Summary statistics for our target species
print("=== Summary for Data Extraction ===")
print(f"\nTarget species: {len(species_stats)}")
print(f"Target genomes: {len(target_genomes):,}")

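# Estimate data sizes: (mean genes per genome) x (number of target genomes) ~ total rows needed
# Rough: extrapolated from the K. pneumoniae sample only
mean_genes = gene_counts['n_genes'].mean()
print(f"Mean genes per genome (K. pneumoniae sample): {mean_genes:,.0f}")
print(f"Estimated gene rows for all target genomes: {mean_genes * len(target_genomes):,.0f}")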