# NB 02: Extract Gene-Fitness Matrices

Build gene × experiment fitness matrices (and matching t-statistic matrices) for each pilot organism.
Also extracts gene metadata and experiment metadata.

**Strategy**: Try `fitbyexp_{orgId}` tables first (pre-pivoted),
fall back to pivoting `genefitness` if needed.

**Run on BERDL JupyterHub** for Spark access.

In [None]:
import pandas as pd
import numpy as np
from pathlib import Path

# Spark session used by every query in this notebook.
spark = get_spark_session()
print(f"Spark version: {spark.version}")

# Output layout: matrices and annotation tables live under ../data.
DATA_DIR = Path('../data')
MATRIX_DIR = DATA_DIR / 'matrices'
ANNOT_DIR = DATA_DIR / 'annotations'

# Create both output directories up front; existing ones are left alone.
for _out_dir in (MATRIX_DIR, ANNOT_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Read the pilot organism table and pull out the organism ids as a plain list.
pilots_csv = DATA_DIR / 'pilot_organisms.csv'
pilots = pd.read_csv(pilots_csv)
pilot_ids = pilots.loc[:, 'orgId'].tolist()
print(f"Pilot organisms: {pilot_ids}")

## 1. Extract Gene Metadata

In [None]:
# Write one gene-annotation CSV per pilot organism, skipping organisms
# whose CSV already exists and is non-empty.
for org_id in pilot_ids:
    out_file = ANNOT_DIR / f'{org_id}_genes.csv'

    cache_hit = out_file.exists() and out_file.stat().st_size > 0
    if cache_hit:
        print(f"CACHED: {org_id} genes")
        continue

    # begin/end are cast to INT in the query; rows are ordered by scaffold
    # and start coordinate.
    gene_query = f"""
        SELECT locusId, sysName, gene, desc, scaffoldId,
               CAST(begin AS INT) as begin, CAST(end AS INT) as end, strand
        FROM kescience_fitnessbrowser.gene
        WHERE orgId = '{org_id}'
        ORDER BY scaffoldId, begin
    """
    genes = spark.sql(gene_query).toPandas()
    genes.to_csv(out_file, index=False)
    print(f"Saved: {org_id} — {genes.shape[0]} genes")

## 2. Extract Experiment Metadata

In [None]:
# Write one experiment-metadata CSV per pilot organism, skipping organisms
# whose CSV already exists and is non-empty.
for org_id in pilot_ids:
    out_file = ANNOT_DIR / f'{org_id}_experiments.csv'

    cache_hit = out_file.exists() and out_file.stat().st_size > 0
    if cache_hit:
        print(f"CACHED: {org_id} experiments")
        continue

    # The QC columns (cor12, mad12, nMapped) are cast to numeric types in
    # the query so they land in the CSV as numbers.
    exp_query = f"""
        SELECT expName, expDesc, expGroup, condition_1, media,
               CAST(cor12 AS FLOAT) as cor12,
               CAST(mad12 AS FLOAT) as mad12,
               CAST(nMapped AS INT) as nMapped
        FROM kescience_fitnessbrowser.experiment
        WHERE orgId = '{org_id}'
        ORDER BY expName
    """
    exps = spark.sql(exp_query).toPandas()
    exps.to_csv(out_file, index=False)
    print(f"Saved: {org_id} — {exps.shape[0]} experiments")

## 3. Extract Fitness Matrices

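Try `fitbyexp_{orgId}` first; fall back to a `genefitness` pivot. Freshly built matrices keep only experiments with `cor12` >= 0.1 and genes missing fitness in at most 50% of experiments; remaining gaps are filled with 0.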

In [None]:
def extract_fitness_matrix_fitbyexp(spark, org_id):
    """Load the per-organism ``fitbyexp_<orgid>`` table as a numeric matrix.

    The table name is built from the lower-cased ``org_id``. The first
    column of the result becomes the index and every remaining column is
    converted to numeric, with unparseable values turned into NaN.

    NOTE(review): this assumes the table is already laid out as
    gene x experiment with the gene identifier in the first column —
    confirm against the table schema.

    Returns:
        The indexed numeric DataFrame, or ``None`` when the table has no
        rows or when anything in the load/convert steps raises (the error
        is printed so the caller can fall back to another source).
    """
    suffix = org_id.lower()
    source_table = f"kescience_fitnessbrowser.fitbyexp_{suffix}"
    try:
        wide = spark.sql(f"SELECT * FROM {source_table}").toPandas()
        if wide.shape[0] == 0:
            return None

        # Index by the leading identifier column, then coerce the rest.
        key_column = wide.columns[0]
        indexed = wide.set_index(key_column)
        return indexed.apply(lambda col: pd.to_numeric(col, errors='coerce'))
    except Exception as e:
        # Deliberately broad: any failure means "this source is unusable".
        print(f"  fitbyexp_{suffix} not available: {e}")
        return None


def _pivot_genefitness(spark, org_id, value_col):
    """Query one column of ``genefitness`` and pivot it to gene x experiment.

    Args:
        spark: Object exposing ``sql(query)`` whose result has ``toPandas()``.
        org_id: Organism id used in the ``WHERE orgId = ...`` filter.
        value_col: Name of the ``genefitness`` column to extract; it is
            cast to FLOAT in the query and used as the cell values.

    Returns:
        DataFrame indexed by ``locusId`` with one column per ``expName``,
        or ``None`` when the query returns no rows.

    Raises:
        ValueError: propagated from ``DataFrame.pivot`` if the same
            (locusId, expName) pair appears more than once.
    """
    long_df = spark.sql(f"""
        SELECT locusId, expName,
               CAST({value_col} AS FLOAT) as {value_col}
        FROM kescience_fitnessbrowser.genefitness
        WHERE orgId = '{org_id}'
    """).toPandas()

    if long_df.empty:
        return None

    # Long (gene, experiment, value) rows -> wide gene x experiment matrix.
    return long_df.pivot(index='locusId', columns='expName', values=value_col)


def extract_fitness_matrix_genefitness(spark, org_id):
    """Extract and pivot the fitness (``fit``) matrix from genefitness.

    Returns a gene x experiment DataFrame, or ``None`` if the organism has
    no rows in the table.
    """
    return _pivot_genefitness(spark, org_id, 'fit')


def extract_t_matrix(spark, org_id):
    """Extract and pivot the t-statistic (``t``) matrix from genefitness.

    Returns a gene x experiment DataFrame, or ``None`` if the organism has
    no rows in the table.
    """
    return _pivot_genefitness(spark, org_id, 't')

In [None]:
# QC thresholds applied when a fitness matrix is built from scratch.
MIN_COR12 = 0.1          # experiments need cor12 >= this to be kept
MAX_MISSING_FRAC = 0.5   # genes may lack fitness in at most this fraction of experiments

matrix_summary = []

for org_id in pilot_ids:
    fit_file = MATRIX_DIR / f'{org_id}_fitness_matrix.csv'
    t_file = MATRIX_DIR / f'{org_id}_t_matrix.csv'

    if fit_file.exists() and fit_file.stat().st_size > 0:
        print(f"CACHED: {org_id} fitness matrix")
        # NOTE(review): read_csv infers the index dtype, so all-numeric
        # locusIds would come back as integers and might not match the ids
        # returned by Spark in the t-matrix step below — confirm.
        fit_matrix = pd.read_csv(fit_file, index_col=0)
    else:
        print(f"\nExtracting {org_id}...")

        # Try fitbyexp first
        fit_matrix = extract_fitness_matrix_fitbyexp(spark, org_id)
        if fit_matrix is None:
            print("  Falling back to genefitness pivot...")
            fit_matrix = extract_fitness_matrix_genefitness(spark, org_id)

        if fit_matrix is None:
            print(f"  ERROR: No fitness data for {org_id}")
            continue

        # Load experiment QC to filter bad experiments. Rows whose cor12 is
        # missing fail the >= comparison and are therefore dropped as well.
        exp_meta = pd.read_csv(ANNOT_DIR / f'{org_id}_experiments.csv')
        good_exps = exp_meta[exp_meta['cor12'] >= MIN_COR12]['expName'].tolist()
        shared_exps = [e for e in good_exps if e in fit_matrix.columns]

        # Record the experiment count BEFORE subsetting; reading it from the
        # filtered matrix would always report N/N.
        n_exps_before = fit_matrix.shape[1]
        fit_matrix = fit_matrix[shared_exps]
        print(f"  Kept {len(shared_exps)}/{n_exps_before} experiments (cor12 >= {MIN_COR12})")

        # Drop genes missing fitness in too many of the kept experiments
        missing_frac = fit_matrix.isna().mean(axis=1)
        fit_matrix = fit_matrix[missing_frac <= MAX_MISSING_FRAC]
        print(f"  Kept {len(fit_matrix)} genes (<={MAX_MISSING_FRAC:.0%} missing)")

        # Fill remaining NaN with 0
        fit_matrix = fit_matrix.fillna(0.0)

        fit_matrix.to_csv(fit_file)
        print(f"  Saved: {fit_file}")

    # Extract t-statistic matrix
    if t_file.exists() and t_file.stat().st_size > 0:
        print(f"CACHED: {org_id} t-statistic matrix")
    else:
        print(f"  Extracting t-statistics for {org_id}...")
        t_matrix = extract_t_matrix(spark, org_id)
        if t_matrix is not None:
            # Restrict to the genes/experiments that survived the fitness
            # filters, keeping the fitness matrix's row and column order.
            shared_cols = [c for c in fit_matrix.columns if c in t_matrix.columns]
            shared_rows = [r for r in fit_matrix.index if r in t_matrix.index]
            t_matrix = t_matrix.loc[shared_rows, shared_cols].fillna(0.0)
            t_matrix.to_csv(t_file)
            print(f"  Saved: {t_file}")

    # Summary stats are computed on the zero-filled matrix, so "density" is
    # the fraction of non-zero cells (filled gaps count as empty).
    matrix_summary.append({
        'orgId': org_id,
        'n_genes': fit_matrix.shape[0],
        'n_experiments': fit_matrix.shape[1],
        'matrix_density': (fit_matrix != 0).mean().mean(),
        'mean_fitness': fit_matrix.values.mean(),
        'std_fitness': fit_matrix.values.std()
    })

summary_df = pd.DataFrame(matrix_summary)
summary_df.to_csv(MATRIX_DIR / 'matrix_summary.csv', index=False)
print("\n" + "="*60)
print("MATRIX EXTRACTION SUMMARY")
print("="*60)
summary_df