In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
ADNI clinical/cognitive/demographic baseline extractor + EDA (CogNID-style).
- One row per PTID using baseline visit priority:
  bl/init > sc > m03 > m06 > m12 > m24 > later (m36, m48, ...)
- Optional code mapping:
  Gender: 1->male, 2->female
  Diagnosis: 1->CN, 2->MCI, 3->DEMENTIA
- Writes baseline Excel + preview CSV + EDA plots (PNG).

Usage:
  python clinical_baseline_pipeline.py --input "/path/dem_cli_cog ADNI.xlsx" --outdir "./out"

Requirements:
  pandas, numpy, matplotlib, openpyxl (for .xlsx I/O)
"""

import argparse
from pathlib import Path
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# ---------------------- Config (edit if you like) ----------------------

# Visit priority (lower number = higher priority).
# Keys are compared against the visit value after it is stripped, lower-cased
# and has its spaces removed (see parse_visit_priority).
VISIT_PRIORITY = {
    "bl": 1, "init": 1,
    "sc": 2, "screening": 2,
    "m03": 3, "month3": 3, "3m": 3,
    "m06": 4, "month6": 4, "6m": 4,
    "m12": 5, "month12": 5, "12m": 5,
    # m18 must be listed explicitly: the generic mNN fallback in
    # parse_visit_priority scores unlisted months as 7 + months (m18 -> 25),
    # which would sort an m18 visit after m24.
    "m18": 6, "month18": 6, "18m": 6,
    "m24": 7, "month24": 7, "24m": 7,
    # later mXX will be handled generically (m36, m48, ...)
}

# Simple mappings (only applied if the columns exist).
# Both int and str keys are listed because the Excel reader may yield either.
GENDER_MAP = {1: "male", 2: "female", "1": "male", "2": "female"}
DIAG_MAP   = {1: "CN",   2: "MCI",    3: "DEMENTIA", "1": "CN", "2": "MCI", "3": "DEMENTIA"}

# Heuristic tokens to detect common columns (don't change unless needed).
# PTID/VISIT tokens must equal the lower-cased, space-free column name exactly;
# the remaining token lists are matched as case-insensitive substrings.
PTID_TOKENS   = ["ptid", "subjectid", "subject_id", "participantid", "participant_id"]
# "visist" looks like a deliberate alias for a misspelled header — confirm before removing.
VISIT_TOKENS  = ["visit", "visist", "viscode", "viscode2"]
GENDER_TOKENS = ["gender"]
DIAG_TOKENS   = ["diagnosis", "diagnoses", "diag"]
AGE_TOKENS    = ["entry_age", "age", "ptage", "baselineage"]

# Cognitive/functional feature tokens for plots (auto-detected if present)
MMSE_TOKENS   = ["mmscore", "mmse"]
CDRSB_TOKENS  = ["cdr sum of boxes", "cdrsb"]
FAQ_TOKENS    = ["faq total", "faq total score", "faq"]
ADAS_TOKENS   = ["adas13", "adas 13"]

# Comorbidity tokens to scan for simple prevalence bars (optional)
COMORB_TOKENS = ["hypertension", "stroke", "smok", "diabet", "cardio", "t2dm"]

# ----------------------------------------------------------------------


def normalize_colnames(cols):
    """Return *cols* as a list of cleaned-up header strings.

    Each entry is converted with ``str``, trimmed at both ends, and every
    internal run of whitespace is collapsed to a single space.
    """
    cleaned = []
    for raw_name in cols:
        trimmed = str(raw_name).strip()
        cleaned.append(re.sub(r"\s+", " ", trimmed))
    return cleaned


def find_exact_col(df, candidate_keys):
    """Return the first column whose squashed name equals one of *candidate_keys*.

    A column's squashed name is its lower-cased text with all spaces removed.
    Columns are scanned in frame order; ``None`` is returned when none match.
    """
    def squashed(name):
        return name.lower().replace(" ", "")

    return next(
        (col for col in df.columns
         if any(squashed(col) == key for key in candidate_keys)),
        None,
    )


def find_contains_col(df, token_list):
    """Return the column that best matches the ordered *token_list*.

    Matching is a case-insensitive substring test. Tokens are tried in list
    order, so an earlier (more specific) token takes precedence over a later,
    broader one: with ``["faq total", "faq"]`` a "FAQ Total" column wins over
    "FAQ1" regardless of column order. For a given token, the first column in
    frame order that contains it is returned.

    Returns:
        The matching column label, or ``None`` if no token occurs in any column.
    """
    # Lower-case each header once; str() tolerates non-string labels.
    lowered = [(col, str(col).lower()) for col in df.columns]
    for tok in token_list:
        for col, lc in lowered:
            if tok in lc:
                return col
    return None


def parse_visit_priority(raw, priorities=None):
    """Map a visit code to an integer sort priority; lower = better (earlier).

    Resolution order:

    1. Missing values (``pd.isna``) -> 10_000, i.e. sorted last.
    2. The value is stripped, lower-cased and has spaces removed, then looked
       up in *priorities*.
    3. Month-style codes -- ``m<N>`` / ``month<N>`` (prefix match) or ``<N>m``
       (whole string) -- are canonicalised to zero-padded ``mNN`` so that
       "m3", "month03" and "m03" share the configured entry. Months with no
       configured entry score ``7 + N`` (so m36 sorts after m24).
    4. ``v<N>`` codes score ``9_000 + N``.
    5. Anything else -> 10_000.

    Args:
        raw: Visit value from the input column (any type; converted with str).
        priorities: Mapping of normalised visit code -> priority. Defaults to
            the module-level ``VISIT_PRIORITY`` when ``None``.

    Returns:
        int priority.
    """
    table = VISIT_PRIORITY if priorities is None else priorities
    if pd.isna(raw):
        return 10_000
    s = str(raw).strip().lower().replace(" ", "")
    if s in table:
        return table[s]

    # Month-style codes, e.g. m3, m036, month36 (prefix match) or 36m.
    month_match = re.match(r"m(?:onth)?(\d+)", s) or re.fullmatch(r"(\d+)m", s)
    if month_match:
        months = int(month_match.group(1))  # \d+ guarantees int() succeeds
        canonical = f"m{months:02d}"
        if canonical in table:
            return table[canonical]
        return 7 + months  # unlisted months: later months sort later

    # vNN fallback demoted below all month-coded visits
    v = re.match(r"v(\d+)", s)
    if v:
        return 9_000 + int(v.group(1))
    return 10_000


def plot_hist(series, title, out_png):
    """Save a 30-bin histogram of *series* to *out_png* (PNG, 150 dpi).

    Values are coerced to numeric (unparseable -> NaN) and NaNs are dropped.
    If plotting raises, the figure is still written, titled
    "<title> (not plottable)".
    """
    fig, ax = plt.subplots()
    try:
        numeric_values = pd.to_numeric(series, errors="coerce").dropna()
        numeric_values.plot(kind="hist", bins=30, title=title, ax=ax)
        ax.set_xlabel(title)
    except Exception:
        ax.set_title(f"{title} (not plottable)")
    fig.tight_layout()
    fig.savefig(out_png, dpi=150)
    plt.close(fig)


def plot_bar_counts(series, title, out_png):
    """Save a bar chart of value frequencies in *series* to *out_png*.

    Missing values are counted as their own category (``dropna=False``).
    When *series* is empty there is nothing to draw, and pandas' bar-plot
    layout fails on zero bars. In that case a figure carrying only *title*
    is saved instead, so callers may pass an empty series as a placeholder.
    """
    counts = series.value_counts(dropna=False)
    plt.figure()
    if counts.empty:
        plt.title(title)
    else:
        counts.plot(kind="bar", title=title)
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()


def boxplot_by_diag(df, feat_col, diag_col, out_png):
    """Save a box plot of *feat_col* grouped by the values of *diag_col*.

    Feature values are coerced to numeric (unparseable -> NaN) and NaNs are
    dropped; diagnosis groups left with no values are omitted. Groups appear
    in order of first appearance in *diag_col*. If either column name is
    falsy, the feature column is entirely missing, or no group has data, a
    figure titled "<feat_col> by <diag_col> (no data)" is saved instead.

    Args:
        df: Frame containing both columns.
        feat_col: Feature column name, or None when it was not detected.
        diag_col: Diagnosis column name, or None when it was not detected.
        out_png: Output image path.
    """
    plt.figure()
    groups, labels = [], []
    if feat_col and diag_col and df[feat_col].notna().any():
        for dlab in df[diag_col].dropna().unique():
            vals = pd.to_numeric(
                df.loc[df[diag_col] == dlab, feat_col], errors="coerce"
            ).dropna().to_numpy()
            if vals.size > 0:
                groups.append(vals)
                labels.append(str(dlab))

    if groups:
        plt.boxplot(groups)
        # Boxes sit at x = 1..n. Setting tick labels here instead of passing
        # boxplot(labels=...) avoids that kwarg, which Matplotlib 3.9 renamed
        # to `tick_labels` and deprecated.
        plt.xticks(range(1, len(labels) + 1), labels)
        plt.title(f"{feat_col} by {diag_col}")
        plt.xlabel(diag_col)
        plt.ylabel(feat_col)
    else:
        plt.title(f"{feat_col} by {diag_col} (no data)")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()


def missingness_heatmap(df, out_png):
    """Save an image of which cells of *df* are missing to *out_png*.

    Rows of the image are DataFrame rows and columns are DataFrame columns.
    Missing cells are drawn white and present cells black: a grayscale map is
    used with fixed 0..1 limits, so the "white = missing" title is accurate.
    An empty frame produces a titled placeholder instead of an image.
    """
    plt.figure(figsize=(8, 6))
    if df.size == 0:
        # imshow cannot render a zero-size array; keep the output file anyway.
        plt.title("Missingness heatmap (no data)")
    else:
        # 1.0 = missing, 0.0 = present. Fixed vmin/vmax keep this mapping even
        # when every cell (or no cell) is missing, where autoscaling collapses.
        msk = df.isna().to_numpy(dtype=float)
        plt.imshow(msk, aspect="auto", interpolation="nearest",
                   cmap="gray", vmin=0.0, vmax=1.0)
        plt.title("Missingness heatmap (white = missing)")
        plt.xlabel("Columns")
        plt.ylabel("Rows")
        plt.colorbar()
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()


def _save_placeholder(title, out_png):
    """Save a figure that carries only *title* (used when the source column is absent)."""
    plt.figure()
    plt.title(title)
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()


def _load_sheet(in_path, sheet):
    """Read one sheet of the workbook at *in_path*; return ``(df, sheet_name_used)``.

    Uses *sheet* when truthy, otherwise the workbook's first sheet. The
    ExcelFile handle is closed before returning.
    """
    with pd.ExcelFile(in_path) as xl:
        sheet_used = sheet if sheet else xl.sheet_names[0]
        df = xl.parse(sheet_used)
    return df, sheet_used


def _select_baseline(df, ptid_col, visit_col):
    """Keep one row per PTID: the row with the lowest parse_visit_priority.

    Rows with a missing PTID are dropped first; otherwise drop_duplicates would
    keep one of them as a spurious subject. A stable sort ensures that, among
    visits of equal priority, the earliest row in the input is kept.
    """
    work = df.dropna(subset=[ptid_col]).copy()
    work["_visit_priority"] = work[visit_col].apply(parse_visit_priority)
    work = work.sort_values(by="_visit_priority", kind="mergesort")
    baseline = work.drop_duplicates(subset=[ptid_col], keep="first").copy()
    return baseline.drop(columns=["_visit_priority"])


def _count_positives(series):
    """Count non-missing entries equal to numeric 1 or to "yes" (trimmed, any case).

    Both kinds of positive are summed, so a column mixing 1 and "Yes" is
    counted fully.
    """
    values = series.dropna()
    is_one = pd.to_numeric(values, errors="coerce").eq(1)
    is_yes = values.astype(str).str.strip().str.lower().eq("yes")
    return int((is_one | is_yes).sum())


def _plot_comorbidities(baseline, out_png):
    """Bar chart of heuristic positive counts per column whose name contains a COMORB_TOKENS entry."""
    comorb_cols = [c for c in baseline.columns
                   if any(tok in str(c).lower() for tok in COMORB_TOKENS)]
    if not comorb_cols:
        _save_placeholder("Comorbidity prevalence (columns not found)", out_png)
        return
    counts = pd.Series(
        {c: _count_positives(baseline[c]) for c in comorb_cols}
    ).sort_values(ascending=False)
    if counts.sum() > 0:
        title = "Comorbidity prevalence (heuristic positives)"
    else:
        title = "Comorbidity prevalence (no positives found)"
    plt.figure()
    counts.plot(kind="bar", title=title)
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()


def _make_plots(baseline, age_col, gender_col, diag_col, plots_dir):
    """Write all EDA PNGs for *baseline* into *plots_dir*; absent columns get placeholders."""
    # 1) Age histogram
    if age_col:
        plot_hist(baseline[age_col], f"Histogram of {age_col}", plots_dir / "age_hist.png")
    else:
        _save_placeholder("Histogram of age (not found)", plots_dir / "age_hist.png")

    # 2) Gender distribution
    if gender_col:
        plot_bar_counts(baseline[gender_col], f"Gender distribution ({gender_col})",
                        plots_dir / "gender_bar.png")
    else:
        _save_placeholder("Gender distribution (not found)", plots_dir / "gender_bar.png")

    # 3) Genotype distribution
    geno_col = find_contains_col(baseline, ["genotype", "apoe"])
    if geno_col:
        plot_bar_counts(baseline[geno_col], f"Genotype distribution ({geno_col})",
                        plots_dir / "genotype_bar.png")
    else:
        _save_placeholder("Genotype distribution (not found)", plots_dir / "genotype_bar.png")

    # 4) Cognitive/functional by diagnosis
    for tokens, fname in (
        (MMSE_TOKENS, "mmse_by_diag.png"),
        (CDRSB_TOKENS, "cdrsb_by_diag.png"),
        (FAQ_TOKENS, "faq_by_diag.png"),
        (ADAS_TOKENS, "adas13_by_diag.png"),
    ):
        feat_col = find_contains_col(baseline, tokens)
        boxplot_by_diag(baseline, feat_col, diag_col, plots_dir / fname)

    # 5) Comorbidity prevalence (simple heuristic)
    _plot_comorbidities(baseline, plots_dir / "comorbidities_bar.png")

    # 6) Missingness heatmap
    missingness_heatmap(baseline, plots_dir / "missingness_heatmap.png")


def main():
    """CLI entry point: extract one baseline row per PTID and write EDA outputs.

    Reads the sheet selected by --input/--sheet and normalises its headers.
    It then keeps each subject's best-priority visit (see parse_visit_priority)
    and applies GENDER_MAP / DIAG_MAP to the detected columns. Outputs written
    under --outdir are the baseline workbook, a 50-row CSV preview and PNG
    plots; a summary is printed at the end.

    Raises:
        SystemExit: if no PTID-like or VISIT-like column is found.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True, help="Path to the clinical/cognitive/demographic Excel file")
    parser.add_argument("--sheet", default=None, help="Sheet name (optional). If not set, uses the first sheet.")
    parser.add_argument("--outdir", default="./out", help="Output directory")
    args = parser.parse_args()

    in_path = Path(args.input)
    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)
    plots_dir = outdir / "plots"
    plots_dir.mkdir(parents=True, exist_ok=True)

    # Load and normalize column names
    df, sheet_used = _load_sheet(in_path, args.sheet)
    df.columns = normalize_colnames(df.columns)

    # Find key columns
    ptid_col = find_exact_col(df, PTID_TOKENS)
    visit_col = find_exact_col(df, VISIT_TOKENS)
    if ptid_col is None or visit_col is None:
        # Show available columns to help the user rename
        print("\n[ERROR] Required columns not found.")
        print("  Need a PTID-like column and a VISIT-like column.")
        print("  Detected columns:")
        for c in df.columns:
            print("  -", c)
        missing = []
        if ptid_col is None:
            missing.append("PTID")
        if visit_col is None:
            missing.append("VISIT")
        raise SystemExit(f"Missing required: {', '.join(missing)}")

    # Optional columns
    gender_col = find_contains_col(df, GENDER_TOKENS)
    diag_col = find_contains_col(df, DIAG_TOKENS)
    age_col = find_contains_col(df, AGE_TOKENS)

    # Visit priority -> one row per PTID
    n_missing_ptid = int(df[ptid_col].isna().sum())
    baseline = _select_baseline(df, ptid_col, visit_col)

    # Code mappings (unmapped values pass through unchanged)
    if gender_col:
        baseline[gender_col] = baseline[gender_col].map(lambda x: GENDER_MAP.get(x, x))
    if diag_col:
        baseline[diag_col] = baseline[diag_col].map(lambda x: DIAG_MAP.get(x, x))

    # Tabular outputs
    baseline_xlsx = outdir / "clinical_cognitive_demographic_baseline.xlsx"
    preview_csv = outdir / "clinical_baseline_preview.csv"
    baseline.to_excel(baseline_xlsx, index=False)
    baseline.head(50).to_csv(preview_csv, index=False)

    # Plots (saved as PNGs)
    _make_plots(baseline, age_col, gender_col, diag_col, plots_dir)

    # Summary print
    print("\n=== Baseline extraction complete ===")
    print(f"Input file:     {in_path}")
    print(f"Sheet used:     {sheet_used}")
    print(f"PTID column:    {ptid_col}")
    print(f"Visit column:   {visit_col}")
    print(f"Gender column:  {gender_col or '(not found)'}")
    print(f"Diagnosis col:  {diag_col or '(not found)'}")
    print(f"Rows in input:  {len(df)}")
    if n_missing_ptid:
        print(f"Rows w/o PTID:  {n_missing_ptid} (dropped)")
    print(f"Unique PTIDs:   {df[ptid_col].nunique()}")
    print(f"Baseline rows:  {len(baseline)}")
    print("\nOutputs:")
    print(f"- Baseline Excel: {baseline_xlsx}")
    print(f"- Preview CSV:    {preview_csv}")
    print(f"- Plots folder:   {plots_dir.resolve()}")


if __name__ == "__main__":
    # Run the CLI only when the file is executed directly, not when it is imported.
    main()
