In [None]:
# Clinical / Cognitive / Demographic Baseline Pipeline (CogNID-style)
# Run cells top-to-bottom. Outputs go to data/processed/ and data/processed/plots/
# - Baseline visit priority: bl/init > sc > m03 > m06 > m12 > m24 > later mXX
# - One row per PTID
# - Light tidy + mappings
# - EDA plots
# - Optional: Class-aware KNN imputation (CogNID-style)


In [None]:
# --- imports
import re
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# --- visit priority (CogNID-style)
# Lower rank = preferred as the baseline visit. Baseline/initial codes rank 1,
# screening ranks 2, and the month-3/6/12/24 visits rank 3..6; each month visit
# is accepted in three spellings: "mNN" (zero-padded), "monthN" and "Nm".
VISIT_PRIORITY = {
    "bl": 1, "init": 1,
    "sc": 2, "screening": 2,
    **{
        alias: rank
        for rank, months in ((3, 3), (4, 6), (5, 12), (6, 24))
        for alias in (f"m{months:02d}", f"month{months}", f"{months}m")
    },
}

# --- mappings from your spec
# Every numeric code is accepted both as an int and as its string form.
GENDER_MAP = {
    key: label
    for code, label in ((1, "male"), (2, "female"))
    for key in (code, str(code))
}
DIAG_MAP = {
    key: label
    for code, label in ((1, "CN"), (2, "MCI"), (3, "DEMENTIA"))
    for key in (code, str(code))
}

# --- column finder tokens (all lower-case)
# PTID/VISIT tokens are compared exactly against lower-cased, space-stripped
# column names (find_exact_col); the others are substring tokens.
PTID_TOKENS = ["ptid", "subjectid", "subject_id", "participantid", "participant_id"]
VISIT_TOKENS = ["visit", "visist", "viscode", "viscode2"]  # "visist" presumably tolerates a misspelled header
GENDER_TOKENS = ["gender"]
DIAG_TOKENS = ["diagnosis", "diagnoses", "diag"]
AGE_TOKENS = ["entry_age", "age", "ptage", "baselineage"]

MMSE_TOKENS = ["mmscore", "mmse"]
CDRSB_TOKENS = ["cdr sum of boxes", "cdrsb"]
FAQ_TOKENS = ["faq total", "faq total score", "faq"]
ADAS_TOKENS = ["adas13", "adas 13"]
COMORB_TOKENS = ["hypertension", "stroke", "smok", "diabet", "cardio", "t2dm"]

# --- small helpers
def normalize_colnames(cols):
    """Return the labels in ``cols`` as tidy strings.

    Each label is converted with ``str``, trimmed at both ends, and every run of
    internal whitespace (spaces, tabs, newlines) is collapsed to a single space.
    """
    whitespace_run = re.compile(r"\s+")
    return [whitespace_run.sub(" ", str(label).strip()) for label in cols]

def find_exact_col(df, candidate_keys):
    """Return the first column of ``df`` whose squashed name equals a candidate key.

    A column name is squashed by lower-casing it and removing all spaces, so
    "Subject ID" matches the key "subjectid". Columns are scanned in frame order
    and the first one matching *any* key wins. Returns ``None`` if nothing matches.
    """
    hits = (
        col
        for col in df.columns
        if any(col.lower().replace(" ", "") == key for key in candidate_keys)
    )
    return next(hits, None)

def find_contains_col(df, token_list):
    """Return the first column of ``df`` whose lower-cased name contains a token.

    Tokens are tried in the order given; for each token, columns are scanned in
    frame order. Earlier (more specific) tokens therefore take priority over later
    generic ones. For example, with ``["faq total", "faq"]`` a "FAQ Total" column
    is preferred over an earlier "FAQ item 1" column, and with
    ``["entry_age", "age"]`` an "entry_age" column beats an earlier "Image ID".

    Args:
        df: frame whose column labels are searched (non-string labels are
            compared via ``str``).
        token_list: lower-case substrings, most specific first.

    Returns:
        The matching column label, or ``None`` if no column contains any token.
    """
    # Lower-case each label once instead of once per token.
    lowered = [(col, str(col).lower()) for col in df.columns]
    for tok in token_list:
        for col, lc in lowered:
            if tok in lc:
                return col
    return None

def parse_visit_priority(raw):
    """Return a sort rank for a raw visit code; lower ranks are preferred as baseline.

    The code is stringified, trimmed, lower-cased and stripped of spaces, then
    resolved in this order:

      * missing value (``pd.isna``)     -> 10_000 (sorted last)
      * key of VISIT_PRIORITY            -> its table rank
      * ``mNN`` with NN == 0             -> 1 (month 0 is the baseline visit)
      * ``mNN`` with 0 < NN <= 24        -> interpolated between the table anchors
        m00/sc=2, m03=3, m06=4, m12=5, m24=6, so "m3" ranks like "m03", months 1-2
        rank just after screening, and "m18" falls between m12 and m24
      * ``mNN`` with NN > 24             -> 7 + NN (after every tabled visit, by month)
      * ``vNN``                          -> 9_000 + NN
      * anything else                    -> 10_000

    The ``mNN``/``vNN`` patterns only need to match at the start of the code.
    Ranks may be int or float; they are only used for sorting.
    """
    if pd.isna(raw):
        return 10_000
    code = str(raw).strip().lower().replace(" ", "")
    if code in VISIT_PRIORITY:
        return VISIT_PRIORITY[code]

    month_match = re.match(r"m(\d+)", code)
    if month_match:
        months = int(month_match.group(1))  # cannot fail: the group is all digits
        if months == 0:
            return 1
        if months <= 24:
            # Previously every untabled month got 7 + months, which put e.g. m18 (25)
            # after m24 (6) and made "m3" (10) disagree with "m03" (3).
            return float(np.interp(months, [0, 3, 6, 12, 24], [2, 3, 4, 5, 6]))
        return 7 + months

    visit_match = re.match(r"v(\d+)", code)
    if visit_match:
        return 9_000 + int(visit_match.group(1))
    return 10_000

def drop_empty_columns(df: pd.DataFrame, keep_cols=None, min_non_null=1):
    """Drop columns that have fewer than ``min_non_null`` non-missing values.

    Columns listed in ``keep_cols`` are never dropped; ``None`` entries in that
    list are ignored, so optional (possibly undetected) column names can be passed
    straight through.

    Returns:
        A tuple ``(trimmed_df, dropped)`` where ``dropped`` lists the removed
        column names in their original order.
    """
    protected = [name for name in (keep_cols or []) if name is not None]
    dropped = []
    for col in df.columns:
        if col in protected:
            continue
        if df[col].count() < min_non_null:  # Series.count() == number of non-NA cells
            dropped.append(col)
    return df.drop(columns=dropped), dropped

def standardize_yes_no(df: pd.DataFrame, yes_tokens=("y","yes","1"), no_tokens=("n","no","0"), exclude_cols=None):
    """Return a copy of ``df`` with yes/no-style text cells normalised to "Yes"/"No".

    Only text columns are rewritten: object dtype and pandas ``StringDtype``
    (e.g. columns read with ``dtype="string"``, or the default string dtype in
    newer pandas). The previous ``dtype == object`` check silently skipped the
    latter. Numeric columns, such as 0/1 integers, are left unchanged.

    Each cell is stringified, trimmed and lower-cased before comparison, so tokens
    should be lower-case. Missing values become "nan"/"<NA>" and never match.

    Args:
        df: input frame (not modified).
        yes_tokens: values mapped to "Yes".
        no_tokens: values mapped to "No".
        exclude_cols: optional column names to leave untouched (e.g. ID or
            coded-category columns whose "1"/"0" values are not yes/no answers).

    Returns:
        A new DataFrame.
    """
    out = df.copy()
    skip = set(exclude_cols or ())
    for c in out.columns:
        if c in skip:
            continue
        dtype = out[c].dtype
        if not (pd.api.types.is_object_dtype(dtype) or isinstance(dtype, pd.StringDtype)):
            continue
        lc = out[c].astype(str).str.strip().str.lower()
        mask_yes = lc.isin(yes_tokens)
        mask_no = lc.isin(no_tokens)
        out.loc[mask_yes, c] = "Yes"
        out.loc[mask_no, c] = "No"
    return out

def plot_bar_counts(series, title, out_png):
    """Save a bar chart of ``series`` value counts (NaN counted as its own bar).

    The figure is written to ``out_png`` at 150 dpi and then closed, so nothing
    is left open in the pyplot state.
    """
    fig, ax = plt.subplots()
    counts = series.value_counts(dropna=False)
    counts.plot(kind="bar", title=title, ax=ax)
    fig.tight_layout()
    fig.savefig(out_png, dpi=150)
    plt.close(fig)

def boxplot_by_diag(df, feat_col, diag_col, out_png):
    """Save a box plot of ``feat_col`` grouped by the values of ``diag_col``.

    Feature values are coerced to numeric (non-numeric -> NaN) and NaNs are
    dropped. Groups follow the order in which diagnosis labels first appear, and
    groups with no numeric values are omitted. If either column is missing/None
    or no group has data, a figure titled "... (no data)" is still written, so
    every expected PNG exists.

    Args:
        df: source frame.
        feat_col: feature column name, or None.
        diag_col: grouping column name, or None.
        out_png: output image path (saved at 150 dpi).
    """
    fig, ax = plt.subplots()
    drew = False
    has_cols = bool(feat_col) and bool(diag_col) and feat_col in df.columns and diag_col in df.columns
    if has_cols and df[feat_col].notna().sum() > 0:
        groups, labels = [], []
        for dlab in df[diag_col].dropna().unique():
            vals = pd.to_numeric(df.loc[df[diag_col] == dlab, feat_col], errors="coerce").dropna().values
            if len(vals) > 0:
                groups.append(vals)
                labels.append(str(dlab))
        if groups:
            ax.boxplot(groups)
            # Label ticks explicitly: boxplot's `labels=` kwarg was renamed to
            # `tick_labels` in Matplotlib 3.9 (old name deprecated), and this form
            # works on every version. Boxes sit at positions 1..n by default.
            ax.set_xticks(list(range(1, len(labels) + 1)))
            ax.set_xticklabels(labels)
            ax.set_title(f"{feat_col} by {diag_col}")
            ax.set_xlabel(diag_col)
            ax.set_ylabel(feat_col)
            drew = True
    if not drew:
        ax.set_title(f"{feat_col} by {diag_col} (no data)")
    fig.tight_layout()
    fig.savefig(out_png, dpi=150)
    plt.close(fig)

def missingness_heatmap(df, out_png):
    """Save a rows x columns heatmap of missing cells in ``df`` to ``out_png``.

    Missing cells are drawn white and present cells black. The previous version
    used the default colormap, in which nothing is white, contradicting its own
    title. The color scale is pinned to [0, 1], so an all-present or all-missing
    frame is still drawn in the correct color. An empty frame produces a titled
    placeholder instead of failing inside ``imshow``.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    missing = df.isna().to_numpy(dtype=float)  # 1.0 = missing, 0.0 = present
    if missing.size == 0:
        ax.set_title("Missingness heatmap (no data)")
    else:
        im = ax.imshow(missing, aspect="auto", interpolation="nearest", cmap="gray", vmin=0, vmax=1)
        ax.set_title("Missingness heatmap (white = missing)")
        fig.colorbar(im, ax=ax)
    ax.set_xlabel("Columns")
    ax.set_ylabel("Rows")
    fig.tight_layout()
    fig.savefig(out_png, dpi=150)
    plt.close(fig)


In [None]:
# --- set paths (edit as needed)
input_path = Path("data/raw/dem_cli_cog ADNI.xlsx")
outdir = Path("data/processed")
plots_dir = outdir.joinpath("plots")

# Create the output tree up front so later cells can write without existence checks.
outdir.mkdir(parents=True, exist_ok=True)
plots_dir.mkdir(parents=True, exist_ok=True)

print(f"Input: {input_path}")
print(f"Outdir: {outdir}")


In [None]:
# --- load and baseline-filter (one row per PTID)
# Context manager so the workbook's file handle is released after parsing.
with pd.ExcelFile(input_path) as xl:
    df = xl.parse(xl.sheet_names[0])
df.columns = normalize_colnames(df.columns)

ptid_col  = find_exact_col(df, PTID_TOKENS)
visit_col = find_exact_col(df, VISIT_TOKENS)
if ptid_col is None or visit_col is None:
    raise ValueError(f"Required columns not found. PTID={ptid_col}, VISIT={visit_col}\nGot: {list(df.columns)}")

gender_col = find_contains_col(df, GENDER_TOKENS)
diag_col   = find_contains_col(df, DIAG_TOKENS)
age_col    = find_contains_col(df, AGE_TOKENS)

# Rows without a participant ID cannot belong to anyone. Without this filter,
# drop_duplicates would treat all NaN IDs as one "participant" and keep one of them.
n_missing_ptid = int(df[ptid_col].isna().sum())
if n_missing_ptid:
    print(f"Dropping {n_missing_ptid} row(s) with missing {ptid_col}")
work = df.dropna(subset=[ptid_col]).copy()
work["_visit_priority"] = work[visit_col].apply(parse_visit_priority)
# Stable sort: among visits with equal priority, the earliest row in file order
# wins, so the chosen baseline row is reproducible (the default quicksort is not stable).
work_sorted = work.sort_values(by=["_visit_priority"], kind="stable")
baseline = (
    work_sorted
    .drop_duplicates(subset=[ptid_col], keep="first")
    .drop(columns=["_visit_priority"])
)
assert baseline[ptid_col].is_unique, "baseline must have one row per PTID"

# light tidy
baseline, dropped_empty = drop_empty_columns(baseline, keep_cols=[ptid_col, visit_col, gender_col, diag_col])
# Yes/No normalisation is applied to every column except the ID/visit/code columns:
# an object-dtype "1" gender or diagnosis code (or a numeric-looking ID) would
# otherwise become "Yes" before the code mapping below could translate it.
key_cols = [c for c in (ptid_col, visit_col, gender_col, diag_col) if c is not None]
tidy_cols = [c for c in baseline.columns if c not in key_cols]
if tidy_cols:
    baseline[tidy_cols] = standardize_yes_no(baseline[tidy_cols])

# map codes
if gender_col:
    baseline[gender_col] = baseline[gender_col].map(lambda x: GENDER_MAP.get(x, x))
if diag_col:
    baseline[diag_col] = baseline[diag_col].map(lambda x: DIAG_MAP.get(x, x))

baseline_xlsx = outdir / "clinical_cognitive_demographic_baseline.xlsx"
baseline.to_excel(baseline_xlsx, index=False)
print("Saved baseline:", baseline_xlsx)

baseline.head(5)

In [None]:
# --- EDA plots (saved)
# Age
# age_col is detected on the raw sheet but is not in drop_empty_columns' keep list,
# so it may be absent from `baseline`; check membership to avoid a KeyError.
if age_col and age_col in baseline.columns:
    series = pd.to_numeric(baseline[age_col], errors="coerce").dropna()
    fig, ax = plt.subplots()
    if series.empty:
        ax.set_title(f"Histogram of {age_col} (no data)")
    else:
        series.plot(kind="hist", bins=30, title=f"Histogram of {age_col}", ax=ax)
    ax.set_xlabel(age_col)
    fig.tight_layout()
    fig.savefig(plots_dir / "age_hist.png", dpi=150)
    plt.close(fig)

# Gender
if gender_col:
    plot_bar_counts(baseline[gender_col], f"Gender distribution ({gender_col})", plots_dir / "gender_bar.png")

# Genotype
geno_col = find_contains_col(baseline, ["genotype", "apoe"])
if geno_col:
    plot_bar_counts(baseline[geno_col], f"Genotype distribution ({geno_col})", plots_dir / "genotype_bar.png")

# Cognitive/functional by Diagnosis
for tokens, fname in [
    (MMSE_TOKENS, "mmse_by_diag.png"),
    (CDRSB_TOKENS, "cdrsb_by_diag.png"),
    (FAQ_TOKENS, "faq_by_diag.png"),
    (ADAS_TOKENS, "adas13_by_diag.png"),
]:
    feat_col = find_contains_col(baseline, tokens)
    boxplot_by_diag(baseline, feat_col, diag_col, plots_dir / fname)

# Comorbidity prevalence (heuristic)
# A cell counts as positive if it equals 1 (numeric coding) or "Yes" (after
# standardize_yes_no); each column reports the larger of the two counts.
comorb_cols = [c for c in baseline.columns if any(tok in c.lower() for tok in COMORB_TOKENS)]
if comorb_cols:
    counts = {}
    for c in comorb_cols:
        vc = baseline[c].value_counts(dropna=False)
        pos = 0
        # Check membership before .get(1): on a non-integer index an integer key
        # may be treated as a position (deprecated pandas fallback).
        if 1 in vc.index:
            pos = max(pos, int(vc.get(1, 0)))
        if "Yes" in vc.index:
            pos = max(pos, int(vc.get("Yes", 0)))
        counts[c] = pos
    ser = pd.Series(counts).sort_values(ascending=False)
    fig, ax = plt.subplots()
    # `counts` always has one entry per comorbidity column, so test for actual
    # positives; otherwise the "no positives" branch could never be reached.
    if ser.gt(0).any():
        ser.plot(kind="bar", title="Comorbidity prevalence (heuristic positives)", ax=ax)
    else:
        ax.set_title("Comorbidity prevalence (no positives found)")
    fig.tight_layout()
    fig.savefig(plots_dir / "comorbidities_bar.png", dpi=150)
    plt.close(fig)

# Missingness heatmap
missingness_heatmap(baseline, plots_dir / "missingness_heatmap.png")

print("Saved plots to:", plots_dir)

In [None]:
# OPTIONAL: run this cell if you want imputed output like CogNID (clinical_imputed.xlsx)
impute = True  # set False to skip

if impute:
    # Imported here rather than in the top imports cell: DataImputation is a local
    # helper (utils/DataImputation.py, path relative to the working directory)
    # that is only needed for this optional step.
    import sys
    if "utils" not in sys.path:  # keep re-runs from appending duplicate entries
        sys.path.append("utils")  # so we can import DataImputation.py
    from DataImputation import ClassAwareKNNImputer

    if diag_col:
        target_col = diag_col
    else:
        # NOTE(review): diag_col is None only when no column name contains "diag",
        # so a "Diagnosis" column cannot exist in `baseline` either. Confirm how
        # ClassAwareKNNImputer handles a missing target column.
        target_col = "Diagnosis"
        print(f"WARNING: no diagnosis column detected; falling back to target_col={target_col!r}")

    # Parameter semantics are defined in DataImputation.ClassAwareKNNImputer (not
    # visible here); random_state presumably seeds the added noise -- TODO confirm.
    imputer = ClassAwareKNNImputer(
        target_col=target_col,
        id_cols=[ptid_col],
        n_neighbors=5,
        add_noise_frac=0.01,
        clip_strategy="iqr",
        percentile_bounds=(1, 99),
        min_class_size=5,
        random_state=42
    )

    df_imputed, report = imputer.fit_transform(baseline)
    imputed_xlsx = outdir / "clinical_imputed.xlsx"
    df_imputed.to_excel(imputed_xlsx, index=False)
    # Explicit encoding: Path.write_text otherwise uses the platform locale encoding.
    (outdir / "imputation_report.txt").write_text(str(report), encoding="utf-8")

    print("Saved imputed file:", imputed_xlsx)