# 03 — Cleaning + Feature Engineering (TrialPulse)

## Objective
Build the **analysis-ready dataset** used for all figures and the dashboard:
- Correct extraction for fields that vary in the API (notably: enrollment + design flags)
- Standardize phases and statuses
- Parse dates and compute durations
- Engineer sponsor type, condition area, geography, and termination reason themes

## Outputs
- `data/processed/trialpulse_analysis.parquet`
- `data/processed/trialpulse_analysis.csv`
- `data/processed/data_quality_notes.csv`

In [1]:
import glob
import json
from pathlib import Path
from typing import Any, Dict, List, Tuple

import pandas as pd
import numpy as np

In [2]:
REPO_ROOT = Path("..").resolve()
DATA_RAW = REPO_ROOT / "data" / "raw"
DATA_PROCESSED = REPO_ROOT / "data" / "processed"
DATA_PROCESSED.mkdir(parents=True, exist_ok=True)

def newest_raw_ndjson() -> Path:
    files = sorted(glob.glob(str(DATA_RAW / "ctgov_studies_*.ndjson")))
    if not files:
        raise FileNotFoundError(f"No ctgov_studies_*.ndjson found in {DATA_RAW}")
    return Path(files[-1])

raw_path = newest_raw_ndjson()
raw_path

PosixPath('/Users/saturnine/Desktop/trialpulse/data/raw/ctgov_studies_20260210T040533Z_3c62edb50608b1de.ndjson')

## Field discovery (why enrollment/design were 100% missing)

ClinicalTrials.gov v2 can place fields in different submodules depending on the study.
We’ll quickly search a handful of records for where keys like `enrollment`, `allocation`, or `masking` actually live.

In [3]:
def find_key_paths(obj: Any, target_key_substr: str, max_paths: int = 20) -> List[str]:
    """
    Recursively search nested dict/list structures for keys containing target_key_substr.
    Returns dotted paths like protocolSection.designModule.enrollmentInfo.count
    """
    paths = []

    def _walk(x: Any, prefix: str = ""):
        nonlocal paths
        if len(paths) >= max_paths:
            return
        if isinstance(x, dict):
            for k, v in x.items():
                p = f"{prefix}.{k}" if prefix else k
                if target_key_substr.lower() in str(k).lower():
                    paths.append(p)
                    if len(paths) >= max_paths:
                        return
                _walk(v, p)
        elif isinstance(x, list):
            for i, v in enumerate(x[:5]):  # cap list exploration
                p = f"{prefix}[{i}]"
                _walk(v, p)

    _walk(obj)
    return paths

In [4]:
sample = []
with raw_path.open("rb") as f:
    for i, line in enumerate(f):
        if i >= 50:
            break
        sample.append(json.loads(line))

targets = ["enrollment", "allocation", "masking", "interventionModel", "primaryPurpose"]
found = {t: set() for t in targets}

for st in sample:
    for t in targets:
        for p in find_key_paths(st, t, max_paths=50):
            found[t].add(p)

{t: sorted(list(paths))[:10] for t, paths in found.items()}

{'enrollment': ['protocolSection.designModule.enrollmentInfo'],
 'allocation': ['protocolSection.designModule.designInfo.allocation'],
 'masking': ['protocolSection.designModule.designInfo.maskingInfo',
  'protocolSection.designModule.designInfo.maskingInfo.masking',
  'protocolSection.designModule.designInfo.maskingInfo.maskingDescription'],
 'interventionModel': ['protocolSection.designModule.designInfo.interventionModel',
  'protocolSection.designModule.designInfo.interventionModelDescription'],
 'primaryPurpose': ['protocolSection.designModule.designInfo.primaryPurpose']}

## Robust flattening (v2)

We rebuild a flat table directly from raw NDJSON using **fallback paths**.
This fixes:
- `enrollment_count`, `enrollment_type`
- design fields where available

In [5]:
def safe_get(d: Dict, path: List[str], default=None):
    cur: Any = d
    for p in path:
        if not isinstance(cur, dict) or p not in cur:
            return default
        cur = cur[p]
    return cur

def to_list(x):
    if x is None:
        return []
    return x if isinstance(x, list) else [x]

def join_from_dict_list(items, key, sep="|"):
    out = []
    for it in to_list(items):
        if isinstance(it, dict):
            v = it.get(key)
            if isinstance(v, str) and v.strip():
                out.append(v.strip())
    return sep.join(out)

def join_clean(items, sep="|"):
    out = []
    for x in to_list(items):
        if isinstance(x, str) and x.strip():
            out.append(x.strip())
    return sep.join(out)

def first_date(module: Dict, field: str):
    if not isinstance(module, dict):
        return None
    ds = module.get(field)
    if isinstance(ds, dict):
        return ds.get("date")
    return None

def pick_first_nonnull(*vals):
    for v in vals:
        if v is None:
            continue
        if isinstance(v, float) and np.isnan(v):
            continue
        return v
    return None

In [6]:
def flatten_study_v2(study: Dict) -> Dict:
    ps = safe_get(study, ["protocolSection"], {}) or {}

    id_mod = ps.get("identificationModule", {}) or {}
    status_mod = ps.get("statusModule", {}) or {}
    sponsor_mod = ps.get("sponsorCollaboratorsModule", {}) or {}
    design_mod = ps.get("designModule", {}) or {}
    cond_mod = ps.get("conditionsModule", {}) or {}
    arms_mod = ps.get("armsInterventionsModule", {}) or {}
    loc_mod = ps.get("contactsLocationsModule", {}) or {}

    nct_id = id_mod.get("nctId")
    brief_title = id_mod.get("briefTitle")
    official_title = id_mod.get("officialTitle")

    overall_status = status_mod.get("overallStatus")
    why_stopped = status_mod.get("whyStopped")

    start_date = first_date(status_mod, "startDateStruct")
    primary_completion_date = first_date(status_mod, "primaryCompletionDateStruct")
    completion_date = first_date(status_mod, "completionDateStruct")

    # Enrollment is inconsistently located across records; try multiple common locations
    enroll_a = safe_get(status_mod, ["enrollmentInfo", "count"])
    enroll_b = safe_get(design_mod, ["enrollmentInfo", "count"])
    enroll_c = safe_get(design_mod, ["enrollmentInfo", "enrollmentCount"])  # sometimes used
    enrollment_count = pick_first_nonnull(enroll_a, enroll_b, enroll_c)

    etype_a = safe_get(status_mod, ["enrollmentInfo", "type"])
    etype_b = safe_get(design_mod, ["enrollmentInfo", "type"])
    enrollment_type = pick_first_nonnull(etype_a, etype_b)

    phases = design_mod.get("phases")
    phases_str = ",".join(phases) if isinstance(phases, list) else phases

    # Design signals (may still be sparse depending on returned schema)
    study_type = pick_first_nonnull(design_mod.get("studyType"), ps.get("designModule", {}).get("studyType"))
    allocation = design_mod.get("allocation")
    intervention_model = design_mod.get("interventionModel")
    masking = design_mod.get("masking")
    primary_purpose = design_mod.get("primaryPurpose")

    lead = sponsor_mod.get("leadSponsor", {}) or {}
    lead_sponsor_name = lead.get("name")
    lead_sponsor_class = lead.get("class")

    conditions = cond_mod.get("conditions", [])
    interventions = arms_mod.get("interventions", [])
    intervention_types = join_from_dict_list(interventions, "type")
    intervention_names = join_from_dict_list(interventions, "name")

    locations = loc_mod.get("locations", [])
    countries = []
    states = []
    for loc in to_list(locations):
        if not isinstance(loc, dict):
            continue
        c = loc.get("country")
        s = loc.get("state")
        if isinstance(c, str) and c.strip():
            countries.append(c.strip())
        if isinstance(s, str) and s.strip():
            states.append(s.strip())

    return {
        "nct_id": nct_id,
        "brief_title": brief_title,
        "official_title": official_title,

        "overall_status": overall_status,
        "why_stopped": why_stopped,

        "start_date": start_date,
        "primary_completion_date": primary_completion_date,
        "completion_date": completion_date,

        "enrollment_count": enrollment_count,
        "enrollment_type": enrollment_type,

        "phases_raw": phases_str,

        "lead_sponsor_name": lead_sponsor_name,
        "lead_sponsor_class": lead_sponsor_class,

        "conditions": join_clean(conditions),
        "intervention_types": intervention_types,
        "intervention_names": intervention_names,

        "countries": "|".join(sorted(set(countries))),
        "states": "|".join(sorted(set(states))),

        "study_type": study_type,
        "allocation": allocation,
        "intervention_model": intervention_model,
        "masking": masking,
        "primary_purpose": primary_purpose,
    }

## Rebuild flat dataset from raw

This produces the canonical table used by TrialPulse going forward.

In [7]:
records = []
with raw_path.open("rb") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        st = json.loads(line)
        records.append(flatten_study_v2(st))

df = pd.DataFrame(records)
df.shape

(25000, 23)

## Standardize phases and statuses + compute timeline metrics

In [8]:
# Date parsing (CT.gov returns YYYY-MM-DD or YYYY-MM)
for c in ["start_date", "primary_completion_date", "completion_date"]:
    df[c] = pd.to_datetime(df[c], errors="coerce")

# Phase normalization
def phase_bucket(phases_raw: str) -> str:
    if not isinstance(phases_raw, str) or not phases_raw:
        return "Missing"
    p = phases_raw.upper()
    if p == "PHASE2":
        return "Phase 2"
    if p == "PHASE3":
        return "Phase 3"
    if "PHASE2" in p and "PHASE3" in p:
        return "Phase 2/3"
    if "PHASE2" in p:
        return "Phase 2+"
    if "PHASE3" in p:
        return "Phase 3+"
    return "Other"

df["phase"] = df["phases_raw"].apply(phase_bucket)

# Status grouping
ACTIVE_STATUSES = {
    "RECRUITING", "NOT_YET_RECRUITING", "ENROLLING_BY_INVITATION",
    "ACTIVE_NOT_RECRUITING"
}
def status_group(s: str) -> str:
    if not isinstance(s, str) or not s:
        return "Unknown/Other"
    s = s.upper()
    if s == "COMPLETED":
        return "Completed"
    if s == "TERMINATED":
        return "Terminated"
    if s == "WITHDRAWN":
        return "Withdrawn"
    if s in ACTIVE_STATUSES:
        return "Active"
    if s == "SUSPENDED":
        return "Suspended"
    if s == "UNKNOWN":
        return "Unknown/Other"
    return "Unknown/Other"

df["status_group"] = df["overall_status"].apply(status_group)

# Timeline metrics (days)
df["duration_start_to_primary_days"] = (df["primary_completion_date"] - df["start_date"]).dt.days
df["duration_start_to_completion_days"] = (df["completion_date"] - df["start_date"]).dt.days

# Year for trends (from start_date)
df["start_year"] = df["start_date"].dt.year

## Sponsor type normalization + intervention type primary

In [9]:
def sponsor_bucket(x: str) -> str:
    if not isinstance(x, str) or not x:
        return "Other/Unknown"
    x = x.upper()
    if x == "INDUSTRY":
        return "Industry"
    if x == "NIH":
        return "NIH"
    if x in {"OTHER_GOV", "FED"}:
        return "Government"
    # OTHER, NETWORK, INDIV, UNKNOWN, etc.
    return "Other/Unknown"

df["sponsor_type"] = df["lead_sponsor_class"].apply(sponsor_bucket)

def primary_intervention_type(types_str: str) -> str:
    if not isinstance(types_str, str) or not types_str:
        return "Missing"
    # intervention_types is pipe-separated list; take first for a stable label
    return types_str.split("|")[0].strip() if types_str.split("|")[0].strip() else "Missing"

df["intervention_type_primary"] = df["intervention_types"].apply(primary_intervention_type)

## Condition area mapping (rule-based, explainable)

This is intentionally interpretable: map condition text to broad clinical areas used in ops reporting.

In [10]:
AREA_KEYWORDS = {
    "Oncology": ["cancer", "tumor", "carcinoma", "neoplasm", "lymphoma", "leukemia", "melanoma"],
    "Cardiology": ["cardio", "heart", "coronary", "myocard", "hypertension", "stroke", "vascular"],
    "Neurology": ["neuro", "alzheimer", "parkinson", "migraine", "epilepsy", "multiple sclerosis", "dementia"],
    "Infectious Disease": ["infection", "infectious", "covid", "sars", "influenza", "hiv", "hepatitis", "tuberculosis"],
    "Endocrinology/Metabolic": ["diabetes", "obesity", "metabolic", "thyroid", "cholesterol", "hyperlipid"],
    "Immunology": ["autoimmune", "lupus", "rheumatoid", "immun", "crohn", "colitis", "psoriasis"],
    "Psychiatry": ["depression", "bipolar", "schizophrenia", "anxiety", "ptsd", "adhd"],
    "Pulmonology": ["asthma", "copd", "pulmonary", "respiratory", "lung"],
    "Renal": ["kidney", "renal", "nephro"],
    "Dermatology": ["derma", "skin", "eczema", "acne"],
    "Hematology": ["anemia", "hemophilia", "thromb", "platelet"],
    "Gastroenterology": ["gastro", "liver", "hepatic", "cirrhosis", "pancrea", "ibd"],
}

def map_condition_area(conditions_str: str) -> str:
    if not isinstance(conditions_str, str) or not conditions_str:
        return "Other/Unmapped"
    text = conditions_str.lower()
    hits = []
    for area, kws in AREA_KEYWORDS.items():
        if any(kw in text for kw in kws):
            hits.append(area)
    if len(hits) == 0:
        return "Other/Unmapped"
    if len(hits) == 1:
        return hits[0]
    return "Multi-area"

df["condition_area"] = df["conditions"].apply(map_condition_area)

## Geography features

We keep it simple and decision-relevant:
- number of countries
- whether US is included
- number of US states listed (when present)

In [11]:
def count_pipe(x: str) -> int:
    if not isinstance(x, str) or not x:
        return 0
    return len([p for p in x.split("|") if p.strip()])

df["n_countries"] = df["countries"].apply(count_pipe)
df["n_states"] = df["states"].apply(count_pipe)
df["has_us"] = df["countries"].fillna("").str.contains("United States", case=False, na=False)

## Termination/withdrawal reason themes (interpretable)

`why_stopped` is missing for many trials (common in public data).
We keep that explicit and theme what we have with keyword rules.

In [12]:
def clean_reason(x: str) -> str:
    if not isinstance(x, str) or not x.strip():
        return "Not reported"
    r = " ".join(x.strip().split())
    return r

df["why_stopped_clean"] = df["why_stopped"].apply(clean_reason)

THEME_RULES = {
    "Safety/Adverse events": ["adverse", "safety", "tox", "side effect", "serious", "ae"],
    "Efficacy/Futility": ["lack of efficacy", "ineffic", "futil", "not effective", "endpoint not met"],
    "Enrollment/Recruitment": ["enroll", "recruit", "accru", "slow", "insufficient subjects", "low participation"],
    "Funding/Business": ["fund", "budget", "financial", "business", "strategic", "priority", "sponsor decision"],
    "Operational/Logistics": ["operational", "logistic", "site", "staff", "supply", "vendor"],
    "COVID-19": ["covid", "pandemic", "sars-cov-2"],
    "Regulatory/Ethics": ["regulatory", "irb", "ethic", "compliance", "approval"],
}

def reason_theme(reason: str) -> str:
    r = reason.lower() if isinstance(reason, str) else ""
    if reason == "Not reported":
        return "Not reported"
    for theme, kws in THEME_RULES.items():
        if any(kw in r for kw in kws):
            return theme
    return "Other/Unclear"

df["why_theme"] = df["why_stopped_clean"].apply(reason_theme)

## Data quality snapshot (what’s missing, by design)

We export a small QA table to document missingness and field coverage.
This becomes part of the professional limitations narrative.

In [13]:
dq = pd.DataFrame({
    "column": df.columns,
    "missing_pct": (df.isna().mean() * 100).round(2).values,
    "example": [df[c].dropna().iloc[0] if df[c].notna().any() else None for c in df.columns],
}).sort_values("missing_pct", ascending=False)

dq_path = DATA_PROCESSED / "data_quality_notes.csv"
dq.to_csv(dq_path, index=False)

dq.head(15), dq_path

(                               column  missing_pct  \
 22                    primary_purpose       100.00   
 21                            masking       100.00   
 20                 intervention_model       100.00   
 19                         allocation       100.00   
 4                         why_stopped        87.78   
 26  duration_start_to_completion_days        56.78   
 25     duration_start_to_primary_days        56.35   
 7                     completion_date        53.99   
 6             primary_completion_date        53.98   
 5                          start_date        49.93   
 27                         start_year        49.93   
 9                     enrollment_type         5.13   
 8                    enrollment_count         1.68   
 2                      official_title         1.52   
 28                       sponsor_type         0.00   
 
                                               example  
 22                                               None  
 21 

## Final analysis dataset export

This is the dataset used for:
- figures (Notebook 04)
- dashboard (Notebook 05)

In [14]:
out_parquet = DATA_PROCESSED / "trialpulse_analysis.parquet"
out_csv = DATA_PROCESSED / "trialpulse_analysis.csv"

df.to_parquet(out_parquet, index=False)
df.to_csv(out_csv, index=False)

(out_parquet, out_csv)

(PosixPath('/Users/saturnine/Desktop/trialpulse/data/processed/trialpulse_analysis.parquet'),
 PosixPath('/Users/saturnine/Desktop/trialpulse/data/processed/trialpulse_analysis.csv'))

## Expected output checks

If these are `True`, Notebook 03 is complete.

In [15]:
out_parquet.exists(), out_csv.exists(), dq_path.exists()

(True, True, True)

In [16]:
# Key coverage checks we care about
(df["enrollment_count"].isna().mean() * 100).round(1), (df["why_stopped"].isna().mean() * 100).round(1)

(np.float64(1.7), np.float64(87.8))

## Patch: Extract designInfo fields (allocation/masking/model/purpose)

Field discovery shows these live under:
`protocolSection.designModule.designInfo.*`

We backfill:
- allocation
- masking
- intervention_model
- primary_purpose

In [17]:
def extract_designinfo_fields(study: Dict) -> Dict:
    ps = study.get("protocolSection", {}) or {}
    dm = ps.get("designModule", {}) or {}
    di = dm.get("designInfo", {}) or {}

    allocation = di.get("allocation")
    intervention_model = di.get("interventionModel")
    primary_purpose = di.get("primaryPurpose")

    masking_info = di.get("maskingInfo", {}) or {}
    masking = masking_info.get("masking")  # sometimes nested

    return {
        "allocation": allocation,
        "intervention_model": intervention_model,
        "primary_purpose": primary_purpose,
        "masking": masking,
    }

# Build mapping nct_id -> designInfo fields
design_map = {}
with raw_path.open("rb") as f:
    for line in f:
        st = json.loads(line)
        ps = st.get("protocolSection", {}) or {}
        nct = (ps.get("identificationModule", {}) or {}).get("nctId")
        if nct:
            design_map[nct] = extract_designinfo_fields(st)

design_df = pd.DataFrame.from_dict(design_map, orient="index").reset_index().rename(columns={"index":"nct_id"})
design_df.shape

(25000, 5)

In [18]:
df = df.drop(columns=["allocation","intervention_model","primary_purpose","masking"], errors="ignore").merge(
    design_df, on="nct_id", how="left"
)

# Recompute DQ quickly for those fields (optional)
(df[["allocation","intervention_model","primary_purpose","masking"]].isna().mean() * 100).round(1)

allocation            2.2
intervention_model    3.0
primary_purpose       0.7
masking               2.2
dtype: float64

In [19]:
out_parquet = DATA_PROCESSED / "trialpulse_analysis.parquet"
out_csv = DATA_PROCESSED / "trialpulse_analysis.csv"
df.to_parquet(out_parquet, index=False)
df.to_csv(out_csv, index=False)

out_parquet, out_csv

(PosixPath('/Users/saturnine/Desktop/trialpulse/data/processed/trialpulse_analysis.parquet'),
 PosixPath('/Users/saturnine/Desktop/trialpulse/data/processed/trialpulse_analysis.csv'))