# NB 01: Explore Fitness Browser & Select Pilot Organisms

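Survey all 48 organisms in the Fitness Browser and select ~5 pilot organisms
for ICA module decomposition. Organisms are ranked by a weighted composite score of:
- Number of experiments (primary; weight 0.40)
- Number of genes (weight 0.20)
- Ortholog connectivity (weight 0.20)
- Data quality: mean cor12 (weight 0.20). mad12 is reported but not used in the score.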

**Run on BERDL JupyterHub** for Spark access.

In [None]:
import pandas as pd
import numpy as np
from pathlib import Path

# NOTE(review): get_spark_session is not defined in this notebook; presumably
# it is injected by the BERDL JupyterHub environment.
spark = get_spark_session()
print(f"Spark version: {spark.version}")

# Output directory for the CSVs written at the end of this notebook.
DATA_DIR = Path('..') / 'data'
DATA_DIR.mkdir(exist_ok=True)

## 1. Organism Overview

In [None]:
# Load the organism table into pandas, ordered by genus then species.
organism_query = """
    SELECT orgId, division, genus, species, strain, taxonomyId
    FROM kescience_fitnessbrowser.organism
    ORDER BY genus, species
"""
organisms_sdf = spark.sql(organism_query)
organisms = organisms_sdf.toPandas()
print(f"Total organisms: {len(organisms)}")
organisms.head(10)

## 2. Experiment Counts per Organism

In [None]:
# Per organism: total experiments, distinct expGroup values and distinct
# condition_1 values, ordered by experiment count (largest first).
exp_query = """
    SELECT orgId,
           COUNT(*) as n_experiments,
           COUNT(DISTINCT expGroup) as n_exp_groups,
           COUNT(DISTINCT condition_1) as n_conditions
    FROM kescience_fitnessbrowser.experiment
    GROUP BY orgId
    ORDER BY n_experiments DESC
"""
exp_counts = spark.sql(exp_query).toPandas()
print(f"Organisms with experiments: {len(exp_counts)}")
exp_counts.head(10)

## 3. Gene Counts per Organism

In [None]:
# Count genes per organism, and how many are annotated as "hypothetical" or
# "uncharacterized".
# - The description is lower-cased before matching because LIKE is
#   case-sensitive; without LOWER() capitalised variants such as
#   "Uncharacterized protein" would not be counted.
# - `desc` is back-quoted because DESC is also an SQL keyword (ORDER BY ... DESC).
# - Genes with a NULL description fall through to ELSE 0.
gene_counts = spark.sql("""
    SELECT orgId,
           COUNT(*) as n_genes,
           SUM(CASE WHEN LOWER(`desc`) LIKE '%hypothetical%'
                      OR LOWER(`desc`) LIKE '%uncharacterized%'
               THEN 1 ELSE 0 END) as n_hypothetical
    FROM kescience_fitnessbrowser.gene
    GROUP BY orgId
    ORDER BY n_genes DESC
""").toPandas()
# n_genes >= 1 for every group produced by GROUP BY, so the division is safe.
gene_counts['pct_hypothetical'] = (gene_counts['n_hypothetical'] / gene_counts['n_genes'] * 100).round(1)
print(f"Organisms with genes: {len(gene_counts)}")
gene_counts.head(10)

## 4. Experiment Quality Metrics

In [None]:
# Per-organism experiment QC summary: mean/min cor12, mean mad12, and the
# number of experiments with cor12 >= 0.1. Values are CAST to FLOAT in SQL,
# and experiments whose cor12 is NULL are excluded.
qc_query = """
    SELECT orgId,
           AVG(CAST(cor12 AS FLOAT)) as mean_cor12,
           AVG(CAST(mad12 AS FLOAT)) as mean_mad12,
           MIN(CAST(cor12 AS FLOAT)) as min_cor12,
           SUM(CASE WHEN CAST(cor12 AS FLOAT) >= 0.1 THEN 1 ELSE 0 END) as n_good_experiments
    FROM kescience_fitnessbrowser.experiment
    WHERE cor12 IS NOT NULL
    GROUP BY orgId
    ORDER BY mean_cor12 DESC
"""
qc_metrics = spark.sql(qc_query).toPandas()
print(f"Organisms with QC data: {len(qc_metrics)}")
qc_metrics.head(10)

## 5. Ortholog Connectivity

In [None]:
# Ortholog connectivity per organism: number of distinct partner organisms
# (orgId2) and number of ortholog rows, counted from the orgId1 side only.
# NOTE(review): this assumes the ortholog table lists each pair in both
# directions; otherwise some organisms are under-counted — TODO confirm.
ortholog_query = """
    SELECT orgId1,
           COUNT(DISTINCT orgId2) as n_ortholog_partners,
           COUNT(*) as n_ortholog_pairs
    FROM kescience_fitnessbrowser.ortholog
    GROUP BY orgId1
    ORDER BY n_ortholog_partners DESC
"""
ortholog_conn = (
    spark.sql(ortholog_query)
    .toPandas()
    .rename(columns={'orgId1': 'orgId'})  # align the key with the other tables
)
print(f"Organisms with orthologs: {len(ortholog_conn)}")
ortholog_conn.head(10)

## 6. Probe fitbyexp Table Schema

Check the structure of pre-pivoted fitness tables.

In [None]:
# Find every table in the database whose name starts with "fitbyexp_".
tables = spark.sql("SHOW TABLES IN kescience_fitnessbrowser").toPandas()
is_fitbyexp = tables['tableName'].str.startswith('fitbyexp_')
fitbyexp_tables = tables.loc[is_fitbyexp]
print(f"fitbyexp tables: {len(fitbyexp_tables)}")
# Show at most the first ten table names.
print(fitbyexp_tables['tableName'].head(10).tolist())

In [None]:
# Inspect the first fitbyexp table found: column count, first/last five column
# names, and a 3-row sample restricted to its first five columns.
# Skipped entirely when no fitbyexp tables were found.
if not fitbyexp_tables.empty:
    sample_table = fitbyexp_tables['tableName'].iloc[0]
    schema = spark.sql(f"DESCRIBE kescience_fitnessbrowser.{sample_table}").toPandas()
    col_names = schema['col_name']
    print(f"Schema for {sample_table}:")
    print(f"  Total columns: {len(schema)}")
    print(f"  First 5 columns: {col_names.head().tolist()}")
    print(f"  Last 5 columns: {col_names.tail().tolist()}")

    # A handful of rows to eyeball the value layout.
    sample = spark.sql(f"SELECT * FROM kescience_fitnessbrowser.{sample_table} LIMIT 3").toPandas()
    print(f"\nSample rows shape: {sample.shape}")
    print(sample.iloc[:, :5])

## 7. Build Combined Statistics & Select Pilots

In [None]:
# Merge all per-organism statistics onto the organism table. Left joins keep
# every organism even when it is absent from one of the summary tables.
stats = (
    organisms
    .merge(exp_counts, on='orgId', how='left')
    .merge(gene_counts, on='orgId', how='left')
    .merge(qc_metrics, on='orgId', how='left')
    .merge(ortholog_conn, on='orgId', how='left')
)

# Fill NaN with 0 only in the merged-in metric columns. Filling the whole
# frame would also overwrite missing organism metadata (division, strain,
# taxonomyId, ...) with the number 0. That 0 would then show up in the printed
# pilot list and in the saved CSVs.
metric_cols = [c for c in stats.columns if c not in organisms.columns]
stats[metric_cols] = stats[metric_cols].fillna(0)

# Composite score for pilot selection. Each metric is divided by its maximum
# (giving [0, 1] for non-negative metrics), then combined with these weights.
# Experiments count the most.
score_weights = {
    'n_experiments': 0.40,
    'n_genes': 0.20,
    'n_ortholog_partners': 0.20,
    'mean_cor12': 0.20,
}
for col in score_weights:
    col_max = stats[col].max()
    if col_max > 0:
        stats[f'{col}_norm'] = stats[col] / col_max
    else:
        # Avoid dividing by zero (or by a negative max): the metric contributes 0.
        stats[f'{col}_norm'] = 0

stats['composite_score'] = sum(
    weight * stats[f'{col}_norm'] for col, weight in score_weights.items()
)

# Rank by composite score. orgId breaks ties so the ranking, and therefore the
# pilot selection, is reproducible. A single-key sort_values uses a non-stable
# sort by default, so the order of tied rows is not specified.
stats = stats.sort_values(['composite_score', 'orgId'], ascending=[False, True])
print("\nTop 10 organisms by composite score:")
display_cols = ['orgId', 'genus', 'species', 'strain', 'n_experiments',
                'n_genes', 'n_ortholog_partners', 'mean_cor12', 'composite_score']
stats[display_cols].head(10)

In [None]:
# Take the n_pilots highest-ranked organisms (stats is sorted by
# composite_score, descending) and print their key metrics.
n_pilots = 5
pilots = stats.iloc[:n_pilots].copy()
print(f"\nSelected {n_pilots} pilot organisms:")
for org in pilots.itertuples(index=False):
    print(f"  {org.orgId:20s} {org.genus} {org.species} {org.strain}")
    print(f"    Experiments: {int(org.n_experiments)}, Genes: {int(org.n_genes)}, "
          f"Ortholog partners: {int(org.n_ortholog_partners)}, "
          f"Mean cor12: {org.mean_cor12:.3f}")

## 8. Save Results

In [None]:
# Write the full ranked statistics and the pilot subset as CSV (no index),
# reporting each path and its row count.
out_stats = DATA_DIR / 'organism_stats.csv'
out_pilots = DATA_DIR / 'pilot_organisms.csv'
for frame, out_path, label in ((stats, out_stats, 'organisms'),
                               (pilots, out_pilots, 'pilots')):
    frame.to_csv(out_path, index=False)
    print(f"Saved: {out_path} ({len(frame)} {label})")

In [None]:
# Print a boxed summary of the survey and the chosen pilots.
rule = "=" * 60
summary_lines = [
    rule,
    "EXPLORATION SUMMARY",
    rule,
    f"Total organisms: {len(stats)}",
    f"Total experiments: {int(stats['n_experiments'].sum())}",
    f"Total genes: {int(stats['n_genes'].sum())}",
    f"Pilot organisms: {pilots['orgId'].tolist()}",
    f"Pilot experiment range: {int(pilots['n_experiments'].min())}-{int(pilots['n_experiments'].max())}",
    rule,
]
print("\n".join(summary_lines))