# 02. Train/Test Split & Imputation

For each outcome label, split the data 70/30 (stratified on that label), then apply four imputation methods,
each fitted on the training split only and applied to the test split: Simple, MissForest, MICE, and Hybrid.

In [None]:
import sys
sys.path.insert(0, "..")

import os
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import SimpleImputer, IterativeImputer
from missforest import MissForest
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

from src.config import PROJECT_ROOT, SPLIT_SEED, MODEL_SEED, LABELS
from src.variables import CATEGORY_COLS, LLM_COLS, CODE_COLS

## 1. Load raw data and outcome labels

In [None]:
# ── Paths (config-based): everything is resolved relative to PROJECT_ROOT ──
save_base_path = PROJECT_ROOT / "data/processed_imp/260106_split_corr_LLM_ADER/outcome_dataset"
save_base_path.mkdir(exist_ok=True, parents=True)

raw_csv_path = PROJECT_ROOT / "data/raw/ADER_windowday_dataset_number_with_llm_v2.csv"
base_df = pd.read_csv(raw_csv_path)

# Outcome label tables (from EMR preprocessed data), one CSV per window length in
# days. The folder lives under PROJECT_ROOT's parent directory, not inside the project.
outcome_root = PROJECT_ROOT.parent / "Psychiatry/EMR Data/Raw_data/2010-2025/preprocessed_data/outcome"
outcome_dict = {}
for horizon_days in (30, 60, 90, 180, 365):
    outcome_dict[f"label_{horizon_days}d"] = pd.read_csv(outcome_root / f"outcome_{horizon_days}d.csv")

In [None]:
# Patient IDs (column "환자번호") that exist in the raw feature table.
valid_ids = set(base_df["환자번호"])
filtered_outcomes = {}

# For every outcome label: keep only outcome rows whose patient appears in base_df,
# and export the matching slice of base_df to "<label>_dataset.csv".
for target, outcome_df in outcome_dict.items():
    keep_mask = outcome_df["환자번호"].isin(valid_ids)
    filtered_outcomes[target] = outcome_df.loc[keep_mask].copy()

    label_patient_ids = filtered_outcomes[target]["환자번호"].unique()
    label_rows = base_df.loc[base_df["환자번호"].isin(label_patient_ids)].copy()

    label_rows.to_csv(
        save_base_path / f"{target}_dataset.csv",
        index=False,
        encoding="utf-8-sig",
    )

In [None]:
# Report row count and class distribution of each filtered outcome table.
for label_name, outcome_tbl in filtered_outcomes.items():
    n_rows = len(outcome_tbl)
    print(f"\n{label_name}: {n_rows} samples")
    print(outcome_tbl[label_name].value_counts())

## 2. Variable removal (pre-identified columns: >60% missing, zero variance, correlation >0.7)

In [None]:
# >60% missing
missing_cols = ["BL3125", "BL3137", "NR4303"]
# zero variance
zero_var_cols = ["BL201801", "BL201802", "BL201812", "BL201813", "BL201818"]
# correlation >0.7
corr_cols = [
    "ASI-3-Score", "Agoraphobia-Score", "BDI-II-Score", "BL2011", "BL2012",
    "BL2013", "BL201401", "BL201803", "BL201804", "BL201806", "BL3111",
    "BL311201", "BL312002", "BL314201", "Cognitive-Score", "D", "F", "F(B)",
    "FullScaleIQ-Compositescore", "HAM-A-Score", "Hs", "Interoceptive-Score",
    "K", "LSAS-SR-Score", "PHQ-15-Score", "PSWQ-Score", "Pa",
    "Physical-Score", "Pt", "Social-Score", "Socialphobia-Score",
]

# Union of the three removal lists, deduplicated and sorted alphabetically.
cols_to_drop = sorted({*missing_cols, *zero_var_cols, *corr_cols})
# Restrict to columns actually present; DataFrame.drop raises on unknown labels.
present_cols = set(base_df.columns)
cols_to_drop_existing = [name for name in cols_to_drop if name in present_cols]
df_filtered = base_df.drop(columns=cols_to_drop_existing)
print(f"Dropped {len(cols_to_drop_existing)} columns. Shape: {df_filtered.shape}")

## 3. Train/test split + imputation (4 methods)

In [None]:
# Category cols for imputation (includes sleep/appetite/weight + LLM).
# Columns listed here are imputed as categorical (mode / RF classifier); every
# other feature column is treated as numeric (mean / RF regressor / MICE).
category_cols_for_imp = [
    "sex", "edu", "job", "marry", "drink", "smoke", "substance_abuse", "psy_family",
    "sleep", "appetite", "weight", "AD_more_three", "ER_more_two",
    "Suicidalidea", "Suicidalplan", "Suicidalattempt",
    "benzodiazepine", "quetiapine", "aripiprazole", "lithium", "divalproex", "olanzapine",
    "bipolar", "depression", "schizophrenia", "anxiety",
    "trauma_stressor_related", "somatic_symptom_disorder", "psychotic_other",
] + LLM_COLS

TARGET = LABELS
base_dir = str(PROJECT_ROOT / "data/processed_imp/260106_split_corr_LLM_ADER/imputation")
# Train-set missing rate above which the Hybrid method uses MissForest instead of
# simple mean/mode imputation.
HYBRID_MISSING_THRESHOLD = 0.3

before_splits, simpleimp_splits, missforest_splits, mice_splits, hybrid_splits = {}, {}, {}, {}, {}


def _as_frame(values, columns):
    """Wrap imputer output as a DataFrame with ``columns`` and a fresh 0..n-1 index."""
    return pd.DataFrame(values, columns=columns).reset_index(drop=True)


def _with_label(X, raw, target):
    """Append ``raw[target]`` (re-indexed 0..n-1) as the last column of feature frame ``X``."""
    return pd.concat([X, raw[[target]].reset_index(drop=True)], axis=1)


def _make_missforest(categorical):
    """Build a MissForest imputer backed by 100-tree, MODEL_SEED-seeded random forests."""
    return MissForest(
        categorical=categorical,
        clf=RandomForestClassifier(n_estimators=100, random_state=MODEL_SEED, n_jobs=-1),
        rgr=RandomForestRegressor(n_estimators=100, random_state=MODEL_SEED, n_jobs=-1),
        verbose=False,
    )


for t in TARGET:
    print(f"\n--- Processing target: {t} ---")
    outcome_df = outcome_dict[t]
    valid_patients = outcome_df["환자번호"].unique()
    df_full = df_filtered[df_filtered["환자번호"].isin(valid_patients)].copy()

    # Attach the label column when the feature table does not already carry it.
    if t not in df_full.columns:
        df_full = pd.merge(df_full, outcome_df[["환자번호", t]], on="환자번호", how="left")

    # Remove the patient ID and every other label so neither ends up as a feature.
    other_labels = [col for col in TARGET if col != t and col in df_full.columns]
    df = df_full.drop(columns=["환자번호"] + other_labels)

    feature_cols = [c for c in df.columns if c != t]
    numeric_cols = [c for c in feature_cols if c not in category_cols_for_imp]
    # Categorical columns actually present for this target. Shared by every method
    # so no imputer is handed column names that are absent from its input.
    cat_cols = [c for c in category_cols_for_imp if c in feature_cols]

    # Train/test split, stratified on the current label. Every imputer below is
    # fitted on the training part only and then applied to the test part.
    train_raw, test_raw = train_test_split(
        df, test_size=0.3, stratify=df[t], random_state=SPLIT_SEED
    )
    before_splits[f"{t}_train"] = train_raw.reset_index(drop=True)
    before_splits[f"{t}_test"] = test_raw.reset_index(drop=True)

    # (A) Simple Imputation: per-column mode (categorical) or mean (numeric).
    X_tr_s = train_raw[feature_cols].copy().reset_index(drop=True)
    X_te_s = test_raw[feature_cols].copy().reset_index(drop=True)
    for col in feature_cols:
        strat = "most_frequent" if col in category_cols_for_imp else "mean"
        imp = SimpleImputer(strategy=strat)
        X_tr_s[[col]] = imp.fit_transform(X_tr_s[[col]])
        X_te_s[[col]] = imp.transform(X_te_s[[col]])
    simpleimp_splits[f"{t}_train"] = _with_label(X_tr_s, train_raw, t)
    simpleimp_splits[f"{t}_test"] = _with_label(X_te_s, test_raw, t)

    # (B) MissForest over all feature columns.
    mf = _make_missforest(cat_cols)
    X_tr_mf = _as_frame(mf.fit_transform(train_raw[feature_cols]), feature_cols)
    X_te_mf = _as_frame(mf.transform(test_raw[feature_cols]), feature_cols)
    missforest_splits[f"{t}_train"] = _with_label(X_tr_mf, train_raw, t)
    missforest_splits[f"{t}_test"] = _with_label(X_te_mf, test_raw, t)

    # (C) MICE: mode-impute categorical columns, then IterativeImputer on numeric ones.
    X_tr_mi = train_raw[feature_cols].copy().reset_index(drop=True)
    X_te_mi = test_raw[feature_cols].copy().reset_index(drop=True)
    for col in cat_cols:
        imputer = SimpleImputer(strategy="most_frequent")
        X_tr_mi[[col]] = imputer.fit_transform(X_tr_mi[[col]])
        X_te_mi[[col]] = imputer.transform(X_te_mi[[col]])
    mice_num = [c for c in numeric_cols if c in X_tr_mi.columns]
    if mice_num:  # IterativeImputer cannot be fitted on zero columns
        mice = IterativeImputer(random_state=MODEL_SEED)
        X_tr_mi[mice_num] = mice.fit_transform(X_tr_mi[mice_num])
        X_te_mi[mice_num] = mice.transform(X_te_mi[mice_num])
    mice_splits[f"{t}_train"] = _with_label(X_tr_mi, train_raw, t)
    mice_splits[f"{t}_test"] = _with_label(X_te_mi, test_raw, t)

    # (D) Hybrid: Simple for columns with train missing rate <= threshold,
    #     MissForest for columns above it.
    missing_rate = train_raw[feature_cols].isna().mean()
    simple_cols = missing_rate[missing_rate <= HYBRID_MISSING_THRESHOLD].index.tolist()
    forest_cols = missing_rate[missing_rate > HYBRID_MISSING_THRESHOLD].index.tolist()
    X_tr_h = train_raw[feature_cols].copy().reset_index(drop=True)
    X_te_h = test_raw[feature_cols].copy().reset_index(drop=True)

    for col in simple_cols:
        strat = "most_frequent" if col in category_cols_for_imp else "mean"
        imp = SimpleImputer(strategy=strat)
        X_tr_h[[col]] = imp.fit_transform(X_tr_h[[col]])
        X_te_h[[col]] = imp.transform(X_te_h[[col]])

    if forest_cols:
        # Fit MissForest on ALL feature columns. The low-missing columns were filled
        # above, so they act as complete predictors and only forest_cols still hold
        # NaNs. Fitting on forest_cols alone starved the forests of predictors, and
        # left none at all when exactly one column exceeded the threshold.
        mf2 = _make_missforest(cat_cols)
        mf2.fit(X_tr_h)
        # Take only the high-missing columns from MissForest's output.
        X_tr_h[forest_cols] = _as_frame(mf2.transform(X_tr_h), feature_cols)[forest_cols]
        X_te_h[forest_cols] = _as_frame(mf2.transform(X_te_h), feature_cols)[forest_cols]

    hybrid_splits[f"{t}_train"] = _with_label(X_tr_h, train_raw, t)
    hybrid_splits[f"{t}_test"] = _with_label(X_te_h, test_raw, t)

print("\nAll imputation methods completed.")

## 4. Save imputed datasets

In [None]:
# (splits dict, output subfolder, filename prefix) for each imputation family,
# listed in the order they are written.
_split_outputs = [
    (before_splits, "before_imput", "before"),
    (simpleimp_splits, "simple_imput", "simple"),
    (missforest_splits, "missforest_imput", "missforest"),
    (mice_splits, "mice_imput", "mice"),
    (hybrid_splits, "hybrid_imput", "hybrid"),
]

for _, subfolder, _ in _split_outputs:
    os.makedirs(os.path.join(base_dir, subfolder), exist_ok=True)


def save_dict(splits, folder, prefix):
    """Write every DataFrame in ``splits`` to ``base_dir/folder`` as CSV.

    Keys are of the form ``"<target>_<part>"`` (e.g. ``"label_30d_train"``); the
    last underscore separates the part. Each frame is saved as
    ``"<prefix>_<target>_<part>.csv"`` without the index, encoded as ``utf-8-sig``.
    """
    for split_key, frame in splits.items():
        target_name, part_name = split_key.rsplit("_", 1)
        out_path = os.path.join(base_dir, folder, f"{prefix}_{target_name}_{part_name}.csv")
        frame.to_csv(out_path, index=False, encoding="utf-8-sig")


for splits, subfolder, file_prefix in _split_outputs:
    save_dict(splits, subfolder, file_prefix)
print("All imputed datasets saved.")