# Training and Evaluating Our Rare Disease Prediction Models


In [14]:
import os
import json
import pandas as pd
import numpy as np
import torch
import logging
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import MultiLabelBinarizer
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim

# Config: every path, seed and hyperparameter a reader may want to tune.
RANDOM_STATE = 42

config = {
    "train_path": "../data/processed/train.jsonl",
    "val_path": "../data/processed/val.jsonl",
    "test_path": "../data/processed/test.jsonl",
    # Resolved relative to the notebook's working directory, unlike the
    # "../data/..." inputs above. NOTE(review): confirm this location is intended.
    "log_dir": "data/logs/",
    "random_state": RANDOM_STATE,
    "n_splits_cv": 5,
    "rf_params": {
        "n_estimators": 100,
        "max_depth": 10,
        "random_state": RANDOM_STATE,
    },
}

# Seed the global numpy and torch RNGs once, up front.
np.random.seed(config['random_state'])
torch.manual_seed(config['random_state'])

# Logging: write to a file under config["log_dir"]. exist_ok=True replaces the
# check-then-create pattern and makes re-running this cell harmless.
# Note: logging.basicConfig does nothing if the root logger already has handlers.
os.makedirs(config["log_dir"], exist_ok=True)
logging.basicConfig(
    filename=os.path.join(config["log_dir"], 'training_notebook.log'),
    level=logging.INFO,
    format='%(asctime)s %(message)s',
)
logger = logging.getLogger()
logger.info("Starting training and evaluation notebook...")

# Fail fast if any of the pre-split input files is missing.
assert os.path.exists(config["train_path"]), f"Train file not found: {config['train_path']}"
assert os.path.exists(config["val_path"]), f"Validation file not found: {config['val_path']}"
assert os.path.exists(config["test_path"]), f"Test file not found: {config['test_path']}"



2024-12-08 18:26:53,798 Starting training and evaluation notebook...



## Load Data

In [15]:
# Load the three pre-split JSONL files (one JSON record per line).
train_df = pd.read_json(config["train_path"], lines=True)
val_df = pd.read_json(config["val_path"], lines=True)
test_df = pd.read_json(config["test_path"], lines=True)

# Columns every downstream cell relies on; report all gaps for a split at once
# and name the split so a failure is easy to trace.
required_cols = ["positive_phenotypes", "all_candidate_genes", "true_diseases"]
for split_name, split_df in (("train", train_df), ("val", val_df), ("test", test_df)):
    missing_cols = [col for col in required_cols if col not in split_df.columns]
    assert not missing_cols, f"Missing required column(s) in {split_name}: {missing_cols}"

# Train+val feed the RF cross-validation and final refit. ignore_index=True gives
# the combined frame a unique 0..n-1 index instead of two overlapping ones.
combined_df = pd.concat([train_df, val_df], axis=0, ignore_index=True)
logger.info(f"Loaded data: Train={len(train_df)}, Val={len(val_df)}, Test={len(test_df)}")


2024-12-08 18:27:19,732 Loaded data: Train=64, Val=16, Test=20


## Feature Extraction (shared by the Random Forest and the MLP)

In [16]:
# Vocabularies come from train+val only; the test split never adds feature columns
# (MultiLabelBinarizer.transform ignores unknown labels, with a warning).
all_phenos = {pheno for ph_list in combined_df['positive_phenotypes'] for pheno in ph_list}
all_genes = {gene for gene_list in combined_df['all_candidate_genes'] for gene in gene_list}

pheno_classes = sorted(all_phenos)
gene_classes = sorted(all_genes)

pheno_mlb = MultiLabelBinarizer(classes=pheno_classes)
gene_mlb = MultiLabelBinarizer(classes=gene_classes)

# Feature layout: phenotype columns first, then gene columns.
pheno_matrix = pheno_mlb.fit_transform(combined_df['positive_phenotypes'])
gene_matrix = gene_mlb.fit_transform(combined_df['all_candidate_genes'])
X_combined = np.hstack([pheno_matrix, gene_matrix])

pheno_matrix_test = pheno_mlb.transform(test_df['positive_phenotypes'])
gene_matrix_test = gene_mlb.transform(test_df['all_candidate_genes'])
X_test = np.hstack([pheno_matrix_test, gene_matrix_test])

# Binary target: any `true_diseases` value that is not equal to 0 is the positive class.
y_combined = np.array([1 if d != 0 else 0 for d in combined_df['true_diseases']])
y_test = np.array([1 if d != 0 else 0 for d in test_df['true_diseases']])

## Cross-Validation with Random Forest

In [17]:
# Stratified k-fold CV on train+val. The same estimator object is refit from
# scratch in every fold (and again on all data in the next cell).
clf = RandomForestClassifier(**config["rf_params"])
skf = StratifiedKFold(
    n_splits=config["n_splits_cv"],
    shuffle=True,
    random_state=config["random_state"],
)
f1_scores = []
logger.info("Starting cross-validation...")

fold_splits = skf.split(X_combined, y_combined)
for fold_idx, (train_idx, val_idx) in enumerate(fold_splits, start=1):
    clf.fit(X_combined[train_idx], y_combined[train_idx])
    fold_f1 = f1_score(
        y_combined[val_idx],
        clf.predict(X_combined[val_idx]),
        average='macro',
        zero_division=0,
    )
    f1_scores.append(fold_f1)
    logger.info(f"Fold {fold_idx}: F1-macro = {fold_f1:.4f}")

# Population std (numpy default ddof=0) across folds.
mean_f1, std_f1 = np.mean(f1_scores), np.std(f1_scores)
cv_summary = f"Cross-validated F1-macro: {mean_f1:.4f} ± {std_f1:.4f}"
logger.info(cv_summary)
print(cv_summary)

2024-12-08 18:28:56,102 Starting cross-validation...
2024-12-08 18:28:56,175 Fold 1: F1-macro = 0.7922
2024-12-08 18:28:56,239 Fold 2: F1-macro = 0.5429
2024-12-08 18:28:56,302 Fold 3: F1-macro = 0.4667
2024-12-08 18:28:56,369 Fold 4: F1-macro = 0.7257
2024-12-08 18:28:56,432 Fold 5: F1-macro = 0.5636
2024-12-08 18:28:56,434 Cross-validated F1-macro: 0.6182 ± 0.1212


Cross-validated F1-macro: 0.6182 ± 0.1212


## Final Training on Combined Set and Testing




In [18]:
# Refit the CV estimator on all of train+val, then score it once on the test split.
logger.info("Retraining final model on combined train+val set.")
clf.fit(X_combined, y_combined)
y_test_pred = clf.predict(X_test)

# Same macro averaging as the cross-validation scores.
macro_kwargs = {"average": "macro", "zero_division": 0}
test_precision, test_recall, test_f1 = (
    metric(y_test, y_test_pred, **macro_kwargs)
    for metric in (precision_score, recall_score, f1_score)
)

logger.info(
    f"Test results - Precision(macro): {test_precision:.4f}, Recall(macro): {test_recall:.4f}, F1(macro): {test_f1:.4f}"
)
print("Test Classification Report:")
print(classification_report(y_test, y_test_pred, zero_division=0))

2024-12-08 18:29:10,281 Retraining final model on combined train+val set.
2024-12-08 18:29:10,349 Test results - Precision(macro): 0.4333, Recall(macro): 0.4405, F1(macro): 0.4357


Test Classification Report:
              precision    recall  f1-score   support

           0       0.20      0.17      0.18         6
           1       0.67      0.71      0.69        14

    accuracy                           0.55        20
   macro avg       0.43      0.44      0.44        20
weighted avg       0.53      0.55      0.54        20


## Feature Importance




In [19]:
# Rank every input column by the fitted forest's feature_importances_.
# Name order mirrors the np.hstack([phenotypes, genes]) feature layout.
importances = clf.feature_importances_
feature_names = [*pheno_classes, *gene_classes]
importance_df = (
    pd.DataFrame({'feature': feature_names, 'importance': importances})
    .sort_values('importance', ascending=False)
)
print("Top 10 most important features:")
print(importance_df.head(10))


Top 10 most important features:
            feature  importance
8        HP:0001257    0.074662
32  ENSG00000144285    0.060936
11       HP:0001332    0.060438
35  ENSG00000177628    0.050995
7        HP:0001251    0.048419
21       HP:0012378    0.048321
15       HP:0002367    0.043687
13       HP:0002076    0.043364
24  ENSG00000080815    0.042907
36  ENSG00000186868    0.039434


## Training a Simple MLP Model

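This neural network uses the same multi-hot phenotype and gene features as the random forest and serves as a comparative baseline. It is trained on the training split only and monitored on the validation split after each epoch.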


In [21]:
class PatientDataset(Dataset):
    """Map-style torch dataset of multi-hot patient features and binary labels.

    Each item is ``(x, y)``:
    - ``x`` is a float32 tensor: the phenotype vector from ``pheno_mlb``
      followed by the gene vector from ``gene_mlb``.
    - ``y`` is a length-1 float32 tensor: 1.0 when ``true_diseases != 0``,
      otherwise 0.0.

    All encoding happens once in ``__init__``. ``__getitem__`` only indexes
    pre-built tensors instead of re-reading a pandas row on every access.

    Args:
        dataframe: Must have ``positive_phenotypes``, ``all_candidate_genes``
            and ``true_diseases`` columns. It is copied, not modified.
        pheno_mlb: Phenotype binarizer. Only ``.transform`` is called, so it
            must already be fitted.
        gene_mlb: Gene binarizer. Only ``.transform`` is called, so it must
            already be fitted.
    """

    def __init__(self, dataframe, pheno_mlb, gene_mlb):
        self.data = dataframe.copy()
        self.pheno_mlb = pheno_mlb
        self.gene_mlb = gene_mlb

        pheno_matrix = np.asarray(self.pheno_mlb.transform(self.data['positive_phenotypes']))
        gene_matrix = np.asarray(self.gene_mlb.transform(self.data['all_candidate_genes']))

        # Per-row vectors are kept on `self.data` for callers that inspect them.
        self.data['pheno_vec'] = list(pheno_matrix)
        self.data['gene_vec'] = list(gene_matrix)

        # Pre-built tensors: (n_rows, n_pheno + n_gene) features, (n_rows, 1) labels.
        self.features = torch.as_tensor(np.hstack([pheno_matrix, gene_matrix]), dtype=torch.float32)
        labels = [1.0 if d != 0 else 0.0 for d in self.data['true_diseases']]
        self.labels = torch.tensor(labels, dtype=torch.float32).reshape(-1, 1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for row ``idx``, with label shape ``(1,)``."""
        return self.features[idx], self.labels[idx]

class SimpleMLP(nn.Module):
    """Two-layer perceptron that emits one sigmoid probability per input row.

    Layers: Linear(input_dim -> hidden_dim) -> ReLU -> Linear(hidden_dim -> 1) -> Sigmoid.

    Args:
        input_dim: Number of input features (size of the last input dimension).
        hidden_dim: Width of the single hidden layer (default 128).
    """

    def __init__(self, input_dim, hidden_dim=128):
        super().__init__()
        # Submodule names and creation order are kept stable: they define the
        # state_dict keys and the order in which weights draw from torch's RNG.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, 1)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Map inputs whose last dimension is ``input_dim`` to probabilities with last dimension 1."""
        hidden = self.relu(self.fc1(x))
        logit = self.fc2(hidden)
        return self.sig(logit)

# MLP hyperparameters, named here rather than buried in the calls below.
MLP_HIDDEN_DIM = 128
MLP_BATCH_SIZE = 32
MLP_LEARNING_RATE = 0.001
epochs = 10

# Re-seed right before weight init and shuffling so re-running this cell alone
# reproduces the same run, however much torch randomness earlier cells used.
torch.manual_seed(config['random_state'])

# Same multi-hot layout as the random forest: phenotype columns, then gene columns.
input_dim = len(pheno_classes) + len(gene_classes)
train_data = PatientDataset(train_df, pheno_mlb, gene_mlb)
val_data = PatientDataset(val_df, pheno_mlb, gene_mlb)
test_data = PatientDataset(test_df, pheno_mlb, gene_mlb)
train_loader = DataLoader(train_data, batch_size=MLP_BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_data, batch_size=MLP_BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_data, batch_size=MLP_BATCH_SIZE, shuffle=False)

model = SimpleMLP(input_dim=input_dim, hidden_dim=MLP_HIDDEN_DIM)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=MLP_LEARNING_RATE)


def evaluate_mlp(net, loader, loss_fn, threshold=0.5):
    """Score `net` on every batch of `loader` without tracking gradients.

    Args:
        net: Model returning probabilities, as SimpleMLP does.
        loader: DataLoader yielding (features, label) batches.
        loss_fn: Loss computed per batch and averaged over batches.
        threshold: Probability at or above which a prediction counts as 1.

    Returns:
        (mean batch loss, predicted labels, true labels), with labels as 0/1 ints.
    """
    net.eval()
    loss_total = 0.0
    predicted, truth = [], []
    with torch.no_grad():
        for Xb, yb in loader:
            probs = net(Xb)
            loss_total += loss_fn(probs, yb).item()
            predicted.extend(int(p) for p in (probs >= threshold).flatten().tolist())
            truth.extend(int(t) for t in yb.flatten().tolist())
    return loss_total / len(loader), predicted, truth


for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for Xb, yb in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(Xb), yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_train_loss = total_loss / len(train_loader)

    # Per-epoch monitoring on the validation split (binary, positive-class metrics).
    avg_val_loss, preds_list, true_list = evaluate_mlp(model, val_loader, criterion)
    precision = precision_score(true_list, preds_list, zero_division=0)
    recall = recall_score(true_list, preds_list, zero_division=0)
    f1 = f1_score(true_list, preds_list, zero_division=0)
    epoch_msg = f"Epoch {epoch+1}/{epochs}: Train Loss={avg_train_loss:.4f}, Val Loss={avg_val_loss:.4f}, Val Prec={precision:.4f}, Val Recall={recall:.4f}, Val F1={f1:.4f}"
    print(epoch_msg)
    logger.info(epoch_msg)

logger.info("Finished training MLP model.")

# Held-out test evaluation with the same macro-averaged metrics reported for the
# random forest, so the MLP can actually serve as a comparative baseline.
_, mlp_test_pred, mlp_test_true = evaluate_mlp(model, test_loader, criterion)
mlp_test_precision = precision_score(mlp_test_true, mlp_test_pred, average='macro', zero_division=0)
mlp_test_recall = recall_score(mlp_test_true, mlp_test_pred, average='macro', zero_division=0)
mlp_test_f1 = f1_score(mlp_test_true, mlp_test_pred, average='macro', zero_division=0)
logger.info(
    f"MLP test results - Precision(macro): {mlp_test_precision:.4f}, Recall(macro): {mlp_test_recall:.4f}, F1(macro): {mlp_test_f1:.4f}"
)
print("MLP Test Classification Report:")
print(classification_report(mlp_test_true, mlp_test_pred, zero_division=0))

2024-12-08 18:30:06,364 Finished training MLP model.


Epoch 1/10: Train Loss=0.7140, Val Loss=0.7040, Val Prec=0.8000, Val Recall=0.3636, Val F1=0.5000
Epoch 2/10: Train Loss=0.7025, Val Loss=0.6955, Val Prec=0.6667, Val Recall=0.3636, Val F1=0.4706
Epoch 3/10: Train Loss=0.6932, Val Loss=0.6873, Val Prec=0.7778, Val Recall=0.6364, Val F1=0.7000
Epoch 4/10: Train Loss=0.6833, Val Loss=0.6798, Val Prec=0.7273, Val Recall=0.7273, Val F1=0.7273
Epoch 5/10: Train Loss=0.6747, Val Loss=0.6726, Val Prec=0.7692, Val Recall=0.9091, Val F1=0.8333
Epoch 6/10: Train Loss=0.6656, Val Loss=0.6658, Val Prec=0.7143, Val Recall=0.9091, Val F1=0.8000
Epoch 7/10: Train Loss=0.6580, Val Loss=0.6592, Val Prec=0.7333, Val Recall=1.0000, Val F1=0.8462
Epoch 8/10: Train Loss=0.6506, Val Loss=0.6529, Val Prec=0.6875, Val Recall=1.0000, Val F1=0.8148
Epoch 9/10: Train Loss=0.6418, Val Loss=0.6472, Val Prec=0.6875, Val Recall=1.0000, Val F1=0.8148
Epoch 10/10: Train Loss=0.6352, Val Loss=0.6416, Val Prec=0.6875, Val Recall=1.0000, Val F1=0.8148
