# Simulate trials

In [1]:
import os
from pathlib import Path


# Get project directory
def get_project_dir():
    """Return the project location under the current user's home directory.

    Returns:
        str: ``<home>/github/llm-drug-discovery``, where ``<home>`` is the
        expanded ``~`` and the remaining components are joined with ``/``.
    """
    home = Path(os.path.expanduser("~"))
    return "/".join([str(home), "github", "llm-drug-discovery"])


# Resolve the project directory once. The bare expression on the last line
# makes the notebook echo the resulting path string as the cell output.
# NOTE(review): the later cells shown here build paths from PROJECT_DIR
# instead, so this value looks unused beyond the display -- confirm.
project_dir = get_project_dir()
project_dir

'/home/mgustineli/github/llm-drug-discovery'

In [3]:
# Constants
# Filesystem layout: the project is expected at ~/github/llm-drug-discovery,
# with the input CSV (and the generated JSON files) inside its data/ folder.
PROJECT_DIR = Path.home() / "github" / "llm-drug-discovery"
DATA_DIR = PROJECT_DIR / "data"
INPUT_CSV = DATA_DIR / "heart_failure_clinical_records_dataset.csv"


# The lists below are the pools generate_trial_protocols draws from; each
# trial takes one random pick from every pool.

# Study designs. The picked value is lower-cased inside the trial description.
TRIAL_TYPES = [
    "Randomized Controlled Trial",
    "Open-Label Study",
    "Double-Blind Study",
    "Observational Study",
    "Phase II Clinical Trial",
    "Phase III Clinical Trial",
]

# Base trial names. The generator appends "-<n>" (the trial's 1-based number).
TRIAL_NAMES = [
    "HEART-PROTECT",
    "CARDIAC-SHIELD",
    "FAILURE-PREVENT",
    "CARDIAC-RECOVERY",
    "HEART-RESTORE",
    "CARDIAC-CARE",
    "HEART-FUNCTION",
    "CARDIAC-BOOST",
    "FAILURE-REVERSE",
    "HEART-STRENGTH",
    "CARDIAC-HEALTH",
    "FAILURE-CONTROL",
]

# Interventions under study; the picked value is embedded in the description.
INTERVENTIONS = [
    "ACE inhibitor therapy",
    "Beta-blocker treatment",
    "Mineralocorticoid receptor antagonist",
    "SGLT2 inhibitor therapy",
    "Novel anti-inflammatory agent",
    "Stem cell therapy",
    "Gene therapy approach",
    "Cardiac rehabilitation program",
    "Remote monitoring system",
    "Digital health intervention",
]

# Trial lengths in days (the description reads "over {duration} days").
DURATIONS = [90, 180, 365, 730]
# Candidate (min, max) eligibility windows; each tuple is unpacked into the
# matching *_min / *_max fields of EligibilityCriteria.
AGE_RANGES = [(40, 65), (50, 75), (65, 95), (40, 95)]  # years
EF_RANGES = [(20, 40), (30, 50), (15, 35), (25, 45)]  # ejection fraction, %
CREATININE_RANGES = [(0.5, 1.5), (0.7, 2.0), (0.9, 3.0)]  # serum creatinine, mg/dL
SODIUM_RANGES = [(130, 145), (125, 140), (135, 150)]  # serum sodium, mEq/L

In [4]:
import json
import random
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from dataclasses import dataclass, asdict, field
from pathlib import Path
from typing import List, Optional  # Data classes


@dataclass
class EligibilityCriteria:
    """Inclusion criteria for a single simulated trial.

    The four numeric windows are always present. Each of the five binary
    fields is optional: ``None`` means "no restriction", while ``0`` / ``1``
    require the patient's flag to equal that value.
    """

    age_min: int
    age_max: int
    ef_min: int
    ef_max: int
    creatinine_min: float
    creatinine_max: float
    sodium_min: int
    sodium_max: int
    anaemia: Optional[int] = None
    diabetes: Optional[int] = None
    high_blood_pressure: Optional[int] = None
    sex: Optional[int] = None
    smoking: Optional[int] = None

    def to_text(self) -> str:
        """Describe the criteria in plain English.

        Returns:
            str: The four range statements followed by one statement per
            binary criterion that is set, joined with ``". "`` (there is no
            trailing period).
        """
        sentences = [
            f"Age between {self.age_min} and {self.age_max} years",
            f"Ejection fraction between {self.ef_min}% and {self.ef_max}%",
            f"Serum creatinine between {self.creatinine_min} and {self.creatinine_max} mg/dL",
            f"Serum sodium between {self.sodium_min} and {self.sodium_max} mEq/L",
        ]

        # (value, wording when truthy, wording when falsy), in output order.
        optional_wording = (
            (self.anaemia, "With anaemia", "Without anaemia"),
            (self.diabetes, "With diabetes", "Without diabetes"),
            (
                self.high_blood_pressure,
                "With high blood pressure",
                "Without high blood pressure",
            ),
            (self.sex, "Male patients only", "Female patients only"),
            (self.smoking, "Current smokers only", "Non-smokers only"),
        )
        for value, if_truthy, if_falsy in optional_wording:
            if value is None:
                # Unset criterion: contributes nothing to the description.
                continue
            sentences.append(if_truthy if value else if_falsy)

        return ". ".join(sentences)


@dataclass
class Trial:
    """A simulated trial protocol together with the patients who qualify.

    Instances are created by generate_trial_protocols with an empty
    ``eligible_patients`` list, which determine_eligible_patients fills in.
    """

    id: str  # e.g. "HF-TRIAL-001" (zero-padded 1-based trial number)
    name: str  # base name from TRIAL_NAMES plus "-<trial number>"
    type: str  # study design, one of TRIAL_TYPES
    description: str  # one-sentence summary built from type, intervention, duration
    intervention: str  # one of INTERVENTIONS
    duration: int  # trial length in days
    eligibility_criteria: EligibilityCriteria
    # DataFrame index labels (as ints) of the patients meeting the criteria.
    eligible_patients: List[int] = field(default_factory=list)


@dataclass
class Patient:
    """Text rendering of one patient row, as produced by format_patient_data."""

    id: int  # the row's DataFrame index label, cast to int
    demographics: str  # "<Male|Female>, <age> years"
    medical_history: str  # ". "-joined conditions followed by lab measurements

In [6]:
# Helper: Load data
def load_data(path: Path) -> pd.DataFrame:
    """Read the CSV file at ``path`` into a DataFrame.

    Args:
        path: Location of the CSV file.

    Returns:
        The parsed table, using ``pandas.read_csv`` with its default options.
    """
    frame = pd.read_csv(path)
    return frame


# Generate trial protocols
def generate_trial_protocols(num_trials: int = 10, seed: int = 42) -> List[Trial]:
    """Generate ``num_trials`` randomized heart-failure trial protocols.

    Every trial takes one random pick from each module-level pool
    (TRIAL_NAMES, TRIAL_TYPES, INTERVENTIONS, DURATIONS and the four
    ``*_RANGES`` lists), plus one pick from ``[0, 1, None]`` for each binary
    criterion, where ``None`` means the criterion is not applied.

    Args:
        num_trials: Number of trials to create. Values <= 0 give an empty list.
        seed: Seed for the random draws; the same seed gives the same trials.

    Returns:
        The generated trials, in creation order, each with an empty
        ``eligible_patients`` list.
    """
    # A private generator yields exactly the same draws as seeding the
    # module-level ``random`` functions with ``seed``, but it does not reset
    # the process-wide ``random`` / ``numpy.random`` state as a side effect.
    rng = random.Random(seed)
    binary_options = [0, 1, None]

    trials: List[Trial] = []
    for number in range(1, num_trials + 1):
        # NOTE: the order of the draws below determines the output for a
        # given seed -- do not reorder them.
        name = f"{rng.choice(TRIAL_NAMES)}-{number}"
        ttype = rng.choice(TRIAL_TYPES)
        intervention = rng.choice(INTERVENTIONS)
        duration = rng.choice(DURATIONS)

        # Numeric eligibility windows.
        age_min, age_max = rng.choice(AGE_RANGES)
        ef_min, ef_max = rng.choice(EF_RANGES)
        cr_min, cr_max = rng.choice(CREATININE_RANGES)
        na_min, na_max = rng.choice(SODIUM_RANGES)

        # Binary criteria: 0 / 1 require that value, None leaves it open.
        anaemia = rng.choice(binary_options)
        diabetes = rng.choice(binary_options)
        high_bp = rng.choice(binary_options)
        sex = rng.choice(binary_options)
        smoking = rng.choice(binary_options)

        criteria = EligibilityCriteria(
            age_min=age_min,
            age_max=age_max,
            ef_min=ef_min,
            ef_max=ef_max,
            creatinine_min=float(cr_min),
            creatinine_max=float(cr_max),
            sodium_min=na_min,
            sodium_max=na_max,
            anaemia=anaemia,
            diabetes=diabetes,
            high_blood_pressure=high_bp,
            sex=sex,
            smoking=smoking,
        )
        description = (
            f"A {ttype.lower()} investigating the efficacy of {intervention} "
            f"in heart failure patients over {duration} days."
        )
        trials.append(
            Trial(
                id=f"HF-TRIAL-{number:03d}",
                name=name,
                type=ttype,
                description=description,
                intervention=intervention,
                duration=duration,
                eligibility_criteria=criteria,
            )
        )
    return trials


# Determine eligible patients (vectorized)
def determine_eligible_patients(trials: List[Trial], df: pd.DataFrame) -> List[Trial]:
    """Attach to every trial the index labels of the patients it would accept.

    A patient qualifies when age, ejection fraction, serum creatinine and
    serum sodium each fall within the trial's window (bounds as applied by
    ``Series.between``) and every binary criterion that is not ``None``
    equals the patient's value in the column of the same name.

    Args:
        trials: Trials to update; ``eligible_patients`` is overwritten in place.
        df: Patient table with the clinical columns referenced below.

    Returns:
        The same ``trials`` list that was passed in.
    """
    binary_columns = ("anaemia", "diabetes", "high_blood_pressure", "sex", "smoking")

    for trial in trials:
        criteria = trial.eligibility_criteria

        # Start from the four numeric windows.
        age_ok = df["age"].between(criteria.age_min, criteria.age_max)
        ef_ok = df["ejection_fraction"].between(criteria.ef_min, criteria.ef_max)
        creatinine_ok = df["serum_creatinine"].between(
            criteria.creatinine_min, criteria.creatinine_max
        )
        sodium_ok = df["serum_sodium"].between(criteria.sodium_min, criteria.sodium_max)
        eligible = age_ok & ef_ok & creatinine_ok & sodium_ok

        # Narrow by each binary criterion the trial actually specifies.
        for column in binary_columns:
            required = getattr(criteria, column)
            if required is None:
                continue
            eligible = eligible & (df[column] == required)

        matching_rows = df[eligible]
        trial.eligible_patients = matching_rows.index.astype(int).tolist()

    return trials


# Format patient data
def format_patient_data(df: pd.DataFrame) -> List[Patient]:
    """Convert each row of ``df`` into a ``Patient`` with text summaries.

    Args:
        df: Patient table containing the columns read below.

    Returns:
        One ``Patient`` per row, in row order. ``id`` is the row's index
        label as an int, ``demographics`` is ``"<Male|Female>, <age> years"``
        and ``medical_history`` lists the present conditions followed by the
        lab measurements, joined with ``". "``.
    """
    condition_columns = ("anaemia", "diabetes", "high_blood_pressure", "smoking")
    formatted: List[Patient] = []

    for patient_id, record in df.iterrows():
        # A truthy ``sex`` value is rendered as male.
        sex_label = "Male" if record["sex"] else "Female"
        demographics = f"{sex_label}, {int(record['age'])} years"

        # Conditions flagged with 1, e.g. "high_blood_pressure" -> "High Blood Pressure".
        history_parts = []
        for column in condition_columns:
            if record[column] == 1:
                history_parts.append(column.replace("_", " ").title())

        # Lab measurements always follow the conditions, in this fixed order.
        history_parts.extend(
            [
                f"Ejection Fraction: {int(record['ejection_fraction'])}%",
                f"Serum Creatinine: {float(record['serum_creatinine'])} mg/dL",
                f"Serum Sodium: {int(record['serum_sodium'])} mEq/L",
                f"Creatinine Phosphokinase: {int(record['creatinine_phosphokinase'])} mcg/L",
                f"Platelets: {int(record['platelets'])} cells/mL",
            ]
        )

        formatted.append(
            Patient(int(patient_id), demographics, ". ".join(history_parts))
        )

    return formatted

In [8]:
# Build the dataset: load the records, simulate 15 trials, label each trial
# with the patients that satisfy its criteria, and render patients as text.
df = load_data(INPUT_CSV)
trials = determine_eligible_patients(generate_trial_protocols(15), df)
patients = format_patient_data(df)

# Hold out 20% of the patients and 20% of the trials, with a fixed seed.
train_patients, test_patients = train_test_split(
    patients, random_state=42, test_size=0.2
)
train_trials, test_trials = train_test_split(
    trials, random_state=42, test_size=0.2
)

# Serialize every split as a JSON array of plain dicts inside DATA_DIR.
outputs = {
    "train_patients.json": train_patients,
    "test_patients.json": test_patients,
    "train_trials.json": train_trials,
    "test_trials.json": test_trials,
}
for filename, records in outputs.items():
    with open(DATA_DIR / filename, "w") as handle:
        json.dump([asdict(record) for record in records], handle, indent=2)

# Summary
print(f"Generated {len(trials)} trials; formatted {len(patients)} patients.")
print(f"Train: {len(train_patients)} pts, {len(train_trials)} trials")
print(f"Test:  {len(test_patients)} pts, {len(test_trials)} trials")

Generated 15 trials; formatted 299 patients.
Train: 239 pts, 12 trials
Test:  60 pts, 3 trials
