In [5]:
# src/data_prep.py
"""
Data prep for CTD curated gene-disease associations.
Outputs:
- data/processed/gd_pairs.csv  (positives + negatives)
- data/processed/train_pairs.csv, test_pairs.csv (pair-level)
- data/processed/train_entityheld.csv, test_entityheld.csv (entity-held-out)
- data/processed/genes.csv, diseases.csv
"""
import os
import pandas as pd
import numpy as np
import json
from pathlib import Path
from sklearn.model_selection import train_test_split

# ========== USER CONFIG ==========
RAW_PATH = "data/raw/CTD_curated_genes_diseases.csv"   # change if needed
OUT_DIR = "data/processed"
SEED = 42
NEG_RATIO = 1        # negatives per positive (1 = balanced)
EVIDENCE_FILTER = True   # set False to use all rows
EVIDENCE_KEYWORDS = ["marker/mechanism", "therapeutic"]  # keep rows that match any
HOLDOUT_FRACTION = 0.10  # for entity-held-out
MAX_NEG_ATTEMPTS = 100000
# =================================

rng = np.random.default_rng(SEED)
Path(OUT_DIR).mkdir(parents=True, exist_ok=True)

# --- 1. Load CSV robustly ---
print("Loading raw file:", RAW_PATH)
# CTD files have comment lines starting with #, and the header is in a comment
# The actual data starts without a header, so we need to specify column names
try:
    # First, try reading with comment skip and no header, then assign column names
    df = pd.read_csv(RAW_PATH, comment='#', header=None)
    df.columns = ['GeneSymbol','GeneID','DiseaseName','DiseaseID','DirectEvidence','OmimIDs','PubMedIDs']
except Exception as e:
    print(f"Error reading with comment skip: {e}")
    # fallback: manually find where data starts
    with open(RAW_PATH, 'r') as f:
        lines = f.readlines()
    
    # find first non-comment line
    data_start = 0
    for i, line in enumerate(lines):
        if not line.strip().startswith('#') and line.strip():
            data_start = i
            break
    
    # read from data start without header
    df = pd.read_csv(RAW_PATH, skiprows=data_start, header=None)
    df.columns = ['GeneSymbol','GeneID','DiseaseName','DiseaseID','DirectEvidence','OmimIDs','PubMedIDs']

print(f"Loaded data with {len(df)} rows and columns: {df.columns.tolist()}")

# normalize column names
cols = [c.strip() for c in df.columns]
df.columns = cols

# Basic check: required columns
required = ['GeneSymbol','DiseaseName']
for r in required:
    if r not in df.columns:
        raise ValueError(f"Required column '{r}' not found in data. Found columns: {df.columns.tolist()}")

# --- 2. Clean text fields ---
df['GeneSymbol'] = df['GeneSymbol'].astype(str).str.strip()
df['DiseaseName'] = df['DiseaseName'].astype(str).str.strip()
# ensure PubMedIDs column exists
if 'PubMedIDs' not in df.columns:
    df['PubMedIDs'] = ""

# drop missing
n_before = len(df)
df = df[ (~df['GeneSymbol'].isna()) & (~df['DiseaseName'].isna()) ]
print(f"Dropped {n_before - len(df)} rows with missing gene/disease")

# --- 3. Optional: evidence filtering for higher confidence ---
if EVIDENCE_FILTER and 'DirectEvidence' in df.columns:
    mask = df['DirectEvidence'].fillna("").str.lower().apply(
        lambda s: any(k.lower() in s for k in EVIDENCE_KEYWORDS)
    )
    n_keep = mask.sum()
    print(f"Evidence filter ON: keeping {n_keep} rows matching keywords {EVIDENCE_KEYWORDS}")
    df = df[mask].copy()
else:
    print("Evidence filter OFF or DirectEvidence missing; using full dataset")

# --- 4. Build positive sentences and dedupe ---
df['text'] = df['GeneSymbol'] + " is associated with " + df['DiseaseName']
pos_df = df[['GeneSymbol','DiseaseName','text','PubMedIDs']].drop_duplicates().reset_index(drop=True)
pos_df['label'] = 1
print("Positive pairs:", len(pos_df))

# --- 5. Negative sampling (random) ---
genes = pos_df['GeneSymbol'].unique()
diseases = pos_df['DiseaseName'].unique()
pos_text_set = set(pos_df['text'].tolist())

negatives = []
attempts = 0
target_neg = len(pos_df) * NEG_RATIO
print(f"Sampling {target_neg} negative examples (neg ratio={NEG_RATIO})")
while len(negatives) < target_neg and attempts < MAX_NEG_ATTEMPTS:
    g = rng.choice(genes)
    d = rng.choice(diseases)
    txt = f"{g} is associated with {d}"
    if txt not in pos_text_set:
        negatives.append((g,d,txt))
    attempts += 1

neg_df = pd.DataFrame(negatives, columns=['GeneSymbol','DiseaseName','text'])
neg_df['label'] = 0
neg_df['PubMedIDs'] = ""  # no pmid for synthetic negatives
print("Negative pairs sampled:", len(neg_df))

# Optional: create some hard negatives (same gene but slightly related diseases)
# (LEFT AS TODO for team if they have disease categories)

# --- 6. Combine, shuffle, save master pairs file ---
all_df = pd.concat([pos_df[['GeneSymbol','DiseaseName','text','PubMedIDs','label']], neg_df], ignore_index=True)
all_df = all_df.sample(frac=1, random_state=SEED).reset_index(drop=True)
print("Total dataset size (pos+neg):", len(all_df))
all_df.to_csv(os.path.join(OUT_DIR, "gd_pairs.csv"), index=False)
print("Saved:", os.path.join(OUT_DIR, "gd_pairs.csv"))

# --- 7. Train/test split (pair-level stratified by label) ---
train, test = train_test_split(all_df, test_size=0.2, stratify=all_df['label'], random_state=SEED)
train.to_csv(os.path.join(OUT_DIR, "train_pairs.csv"), index=False)
test.to_csv(os.path.join(OUT_DIR, "test_pairs.csv"), index=False)
print("Saved pair-level splits: train:", len(train), " test:", len(test))

# --- 8. Entity-held-out split (hold out some genes and diseases) ---
unique_genes = list(genes)
unique_diseases = list(diseases)
n_g_hold = max(1, int(len(unique_genes) * HOLDOUT_FRACTION))
n_d_hold = max(1, int(len(unique_diseases) * HOLDOUT_FRACTION))
genes_hold = list(rng.choice(unique_genes, size=n_g_hold, replace=False))
diseases_hold = list(rng.choice(unique_diseases, size=n_d_hold, replace=False))
print("Entity-held-out:", len(genes_hold), "genes held out,", len(diseases_hold), "diseases held out")

# test set = rows where gene in genes_hold or disease in diseases_hold
ent_test_mask = all_df['GeneSymbol'].isin(genes_hold) | all_df['DiseaseName'].isin(diseases_hold)
ent_test = all_df[ent_test_mask]
ent_train = all_df[~ent_test_mask]
print("Entity-held splits: train:", len(ent_train), " test:", len(ent_test))
ent_train.to_csv(os.path.join(OUT_DIR, "train_entityheld.csv"), index=False)
ent_test.to_csv(os.path.join(OUT_DIR, "test_entityheld.csv"), index=False)

# --- 9. Save entities lists and mapping for later KG building ---
genes_df = pd.DataFrame({'GeneSymbol': sorted(list(genes))})
diseases_df = pd.DataFrame({'DiseaseName': sorted(list(diseases))})
genes_df.to_csv(os.path.join(OUT_DIR, "genes.csv"), index=False)
diseases_df.to_csv(os.path.join(OUT_DIR, "diseases.csv"), index=False)

meta = {
    "n_pos": int(pos_df.shape[0]),
    "n_neg": int(neg_df.shape[0]),
    "total": int(all_df.shape[0]),
    "holdout_genes": genes_hold,
    "holdout_diseases": diseases_hold,
    "seed": int(SEED)
}
with open(os.path.join(OUT_DIR, "prep_metadata.json"), "w") as f:
    json.dump(meta, f, indent=2)

print("Done. Outputs in:", OUT_DIR)

Loading raw file: data/raw/CTD_curated_genes_diseases.csv
Loaded data with 34222 rows and columns: ['GeneSymbol', 'GeneID', 'DiseaseName', 'DiseaseID', 'DirectEvidence', 'OmimIDs', 'PubMedIDs']
Dropped 0 rows with missing gene/disease
Evidence filter ON: keeping 34222 rows matching keywords ['marker/mechanism', 'therapeutic']
Positive pairs: 34222
Sampling 34222 negative examples (neg ratio=1)
Negative pairs sampled: 34222
Total dataset size (pos+neg): 68444
Negative pairs sampled: 34222
Total dataset size (pos+neg): 68444
Saved: data/processed/gd_pairs.csv
Saved pair-level splits: train: 54755  test: 13689
Entity-held-out: 911 genes held out, 585 diseases held out
Entity-held splits: train: 56298  test: 12146
Saved: data/processed/gd_pairs.csv
Saved pair-level splits: train: 54755  test: 13689
Entity-held-out: 911 genes held out, 585 diseases held out
Entity-held splits: train: 56298  test: 12146
Done. Outputs in: data/processed
Done. Outputs in: data/processed
