# TRAINING ML CLASSIFICATION AND REGRESSION MODELS ON STUDY A AND VALIDATION ON INDEPENDENT STUDY B COHORT

### Required libraries

In [None]:
import pandas as pd
import numpy as np
import random
import os
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, roc_curve
from scipy import stats

In [None]:
# Seed both the stdlib and NumPy global random generators with one value.
SEED = 42
for _seed_fn in (random.seed, np.random.seed):
    _seed_fn(SEED)

### Read data

In [None]:
# Input files are resolved relative to DATA_DIR.
DATA_DIR = "."
WCCT_PATH, VXA_PATH, MCL_PATH = (
    os.path.join(DATA_DIR, _fname)
    for _fname in ("WCCT_data.csv", "VXA_data.csv", "matched_populations.xlsx")
)

In [None]:
# Fail fast: raise for the first input file that is not on disk.
_first_missing = next(
    (path for path in (WCCT_PATH, VXA_PATH, MCL_PATH) if not os.path.exists(path)),
    None,
)
if _first_missing is not None:
    raise FileNotFoundError(_first_missing)

In [None]:
# Load both study tables and the population-name mapping sheet.
wcct, vxa = pd.read_csv(WCCT_PATH), pd.read_csv(VXA_PATH)
mcl = pd.read_excel(MCL_PATH)

# DAY_REAL keeps the raw DPC as a number; non-numeric labels become NaN.
for _frame in (wcct, vxa):
    _frame["DAY_REAL"] = pd.to_numeric(_frame["DPC"], errors="coerce")

In [None]:
def make_dpc_num(df):
    """Return a copy of *df* with an integer ``DPC_NUM`` column.

    ``DPC`` values are converted to text and stripped; the labels
    "Baseline 1", "Baseline 2" and "BLS2" map to -1, 0 and 0. Rows whose
    ``DPC`` cannot be parsed as a number are dropped.

    The input frame is left untouched (the previous version added
    ``DPC_NUM`` to the caller's frame before copying).
    """
    out = df.copy()
    labels = out["DPC"].astype(str).str.strip()
    labels = labels.replace({"Baseline 1": "-1", "Baseline 2": "0", "BLS2": "0"})
    out["DPC_NUM"] = pd.to_numeric(labels, errors="coerce")
    # Drop unparseable rows before casting: NaN cannot be stored as int.
    out = out[out["DPC_NUM"].notna()].copy()
    out["DPC_NUM"] = out["DPC_NUM"].astype(int)
    return out

In [None]:
# Add DPC_NUM to both cohorts and drop rows with unparseable DPC.
wcct, vxa = map(make_dpc_num, (wcct, vxa))

### Truncate columns

In [None]:
def truncate_columns(df, exclude_columns=None, decimals=5):
    """Truncate (not round) numeric columns of *df* to *decimals* places.

    Columns named in *exclude_columns* are skipped. Infinite values are
    replaced by NaN before truncation. *df* is modified in place and also
    returned.
    """
    skipped = [] if exclude_columns is None else exclude_columns
    numeric_cols = df.select_dtypes(include="number").columns
    targets = numeric_cols.difference(skipped)

    scale = 10 ** decimals
    finite = df[targets].replace([np.inf, -np.inf], np.nan)
    # Scale up, cut the fractional part toward zero, scale back down.
    df[targets] = np.trunc(finite * scale) / scale
    return df

In [None]:
# Identifier / label columns are left at full precision.
_untruncated = ["VOLUNTEER", "DPC", "SHEDDER"]
wcct = truncate_columns(wcct, exclude_columns=_untruncated, decimals=5)
vxa = truncate_columns(vxa, exclude_columns=_untruncated, decimals=5)

### Preprocessing

In [None]:
def reverse_days(x):
    """Convert a DPC label or number to an integer day.

    "Baseline 2" / "BLS2" map to 0 and "Baseline 1" maps to -1. Anything
    else is parsed as a number and converted to ``int``.

    String input is stripped first and may be written as a float
    ("3.0"), so every value that ``make_dpc_num`` keeps is also accepted
    here. Raises ``ValueError`` for text that is not a number.
    """
    baseline_days = {"Baseline 2": 0, "BLS2": 0, "Baseline 1": -1}
    if isinstance(x, str):
        label = x.strip()
        if label in baseline_days:
            return baseline_days[label]
        try:
            return int(label)
        except ValueError:
            # Accept float-formatted days such as "3.0".
            return int(float(label))
    return int(x)

In [None]:
# Replace DPC labels with integer days in both cohorts.
wcct["DPC"] = wcct["DPC"].apply(reverse_days)
vxa["DPC"] = vxa["DPC"].apply(reverse_days)

In [None]:
# Keep DPC_NUM identical to the cleaned DPC column.
for _frame in (wcct, vxa):
    _frame["DPC_NUM"] = _frame["DPC"]

In [None]:
# Sanity check: DPC must be an integer column in both cohorts.
assert all(pd.api.types.is_integer_dtype(_frame["DPC"]) for _frame in (wcct, vxa))

### Baseline-normalization

In [None]:
def _subtract_baseline(df, name):
    """Express each volunteer's measurements as change from their DPC == 0 row.

    For every volunteer, the first row with ``DPC == 0`` (BL2) is
    subtracted from all of that volunteer's non-metadata columns. The
    DPC == 0 rows themselves are removed from the result.

    Volunteers without a DPC == 0 row are kept unchanged, as before, but
    are now reported because their rows hold raw values rather than
    baseline deltas. *name* is only used in that message.
    """
    meta_cols = ["VOLUNTEER", "DPC", "DPC_NUM", "DAY_REAL", "SHEDDER"]
    value_cols = df.columns.difference(meta_cols)

    pieces, no_baseline = [], []
    for vid, sub in df.groupby("VOLUNTEER"):
        sub = sub.sort_values("DPC")
        base = sub[sub["DPC"] == 0]  # BL2 baseline
        if base.empty:
            no_baseline.append(vid)
        else:
            sub.loc[:, value_cols] = sub.loc[:, value_cols] - base.iloc[0][value_cols]
        pieces.append(sub)

    if no_baseline:
        print(f"{name}: no DPC == 0 baseline for volunteers {no_baseline}; "
              "their rows are left un-normalized")
    if not pieces:
        # pd.concat raises on an empty list; return an empty frame instead.
        return df.iloc[0:0].copy()

    out = pd.concat(pieces, ignore_index=True)
    return out[out["DPC"] != 0]  # remove BL2


wcct = _subtract_baseline(wcct, "WCCT")
vxa = _subtract_baseline(vxa, "VXA")

### Match cell populations

In [None]:
# Strip surrounding whitespace from the population names in both mapping
# columns. Missing cells stay NaN: a plain astype(str) would turn them
# into the literal string "nan", which a later dropna() cannot remove.
for _side in ("WCCT", "VXA"):
    _names = mcl[_side]
    mcl[_side] = _names.where(_names.isna(), _names.astype(str).str.strip())

In [None]:
# Distinct, non-missing population names listed for each study.
wcct_cols = list(mcl["WCCT"].dropna().unique())
vxa_cols = list(mcl["VXA"].dropna().unique())

In [None]:
# Make sure the identifier / label columns are in both keep-lists.
_meta_names = ("VOLUNTEER", "DPC", "DPC_NUM", "DAY_REAL", "SHEDDER")
for _keep_list in (wcct_cols, vxa_cols):
    _keep_list.extend([name for name in _meta_names if name not in _keep_list])

In [None]:
# Keep only the listed columns, preserving each table's own column order.
_wcct_keep, _vxa_keep = set(wcct_cols), set(vxa_cols)
wcct = wcct[[name for name in wcct.columns if name in _wcct_keep]]
vxa = vxa[[name for name in vxa.columns if name in _vxa_keep]]

In [None]:
# Align names
# Rename VXA columns to their WCCT counterparts through the mapping sheet
# instead of by position: positional assignment only works when both CSVs
# list their columns in the same order, and mislabels features silently
# when they do not.
# NOTE(review): assumes each row of `mcl` pairs a VXA name with its WCCT
# equivalent -- confirm against matched_populations.xlsx.
_pairs = mcl[["VXA", "WCCT"]].dropna()
_vxa_to_wcct = dict(zip(_pairs["VXA"], _pairs["WCCT"]))
vxa = vxa.rename(columns=_vxa_to_wcct)

_unmatched = [name for name in wcct.columns if name not in vxa.columns]
if _unmatched:
    raise KeyError(f"VXA has no counterpart for WCCT columns: {_unmatched}")

# Same columns, same order as the training table.
vxa = vxa[wcct.columns]

In [None]:
# Populations removed from both cohorts (matched case-insensitively).
to_drop = [
    "CD4-CD8- T Cells",
    "CD4-CD8- T Cells CD38+",
    "CD4-CD8- T Cells CD38+Ki67+",
    "CD8 T Cells",
    "CD8 T Cells CD38+",
    "CD8 T Cells CD38+Ki67+",
]
drop_set = set(map(str.upper, to_drop))
# The matching column names are looked up in WCCT and applied to both tables.
cols_to_drop = list(filter(lambda name: name.upper() in drop_set, wcct.columns))
wcct, vxa = (
    frame.drop(columns=cols_to_drop, errors="ignore") for frame in (wcct, vxa)
)

In [None]:
# Rescale CD66+ by dividing by 100
for _frame in (wcct, vxa):
    cd66_cols = [name for name in _frame.columns if "CD66+" in name.upper()]
    if not cd66_cols:
        continue
    _frame.loc[:, cd66_cols] = _frame.loc[:, cd66_cols] / 100
## Random Forest Classifier

In [None]:
def evaluate_preds(true, pred_proba):
    """Score predicted probabilities against binary labels.

    Returns a tuple ``(auc, pr, f1)``: ROC AUC, average precision, and
    the F1 score after thresholding the probabilities at 0.5 (strictly
    greater than 0.5 counts as positive).

    *pred_proba* may be any array-like; it is converted to a NumPy array
    so that plain lists work with the threshold comparison.
    """
    pred_proba = np.asarray(pred_proba)
    auc = roc_auc_score(true, pred_proba)
    pr = average_precision_score(true, pred_proba)
    bin_pred = (pred_proba > 0.5).astype(int)
    f1 = f1_score(true, bin_pred)
    return auc, pr, f1
# Classifier training set: Study A (WCCT) rows from the selected days.
tr_days = [2, 3, 4, 5, 6, 7, 8]
# .copy() must be called: without the parentheses clf_wcct is a bound
# method and the .drop(...) below raises AttributeError.
clf_wcct = wcct[wcct["DPC"].isin(tr_days)].copy()
tr_x = clf_wcct.drop(["VOLUNTEER","DPC","DPC_NUM","DAY_REAL","SHEDDER"], axis=1, errors="ignore")
tr_y = clf_wcct["SHEDDER"]

### ROC Curve per Day

In [None]:
# Mean ROC curve per day on Study B (VXA), averaged over 100 random forests
# that differ only in their random_state.
plt.figure(figsize=(10, 8))
mean_fpr_grid = np.linspace(0, 1, 100)
plt.plot([0, 1], [0, 1], "--", color="yellow", label="Chance")

# Build the per-day test sets first, skipping days that cannot be scored.
day_tests = {}
for day in [1, 2, 3, 5]:
    test = vxa[vxa["DAY_REAL"] == day].copy()

    if test.empty:
        print(f"Skip Day {day}: empty")
        continue
    if test["SHEDDER"].nunique() < 2:
        # ROC AUC is undefined when only one class is present.
        print(f"Skip Day {day}: one class only")
        continue

    testx = test.drop(["VOLUNTEER","DPC","DPC_NUM","DAY_REAL","SHEDDER"], axis=1, errors="ignore")
    day_tests[day] = (testx, test["SHEDDER"])

interp_tprs = {day: [] for day in day_tests}
aucs = {day: [] for day in day_tests}

# The training data does not depend on the day, so each seeded forest is
# fit once and scored on every day. Refitting inside the day loop trained
# the same 100 forests again for each day.
if day_tests:
    for i in range(100):
        rf = RandomForestClassifier(random_state=SEED + i)
        rf.fit(tr_x, tr_y)

        for day, (testx, testy) in day_tests.items():
            probs = rf.predict_proba(testx)[:, 1]

            fpr, tpr, _ = roc_curve(testy, probs)
            # Resample onto a common FPR grid so curves can be averaged.
            interp_tpr = np.interp(mean_fpr_grid, fpr, tpr)
            interp_tpr[0] = 0.0

            interp_tprs[day].append(interp_tpr)
            aucs[day].append(roc_auc_score(testy, probs))

for day in day_tests:
    mean_tpr = np.mean(interp_tprs[day], axis=0)
    mean_tpr[-1] = 1.0
    mean_auc = np.mean(aucs[day])

    plt.plot(mean_fpr_grid, mean_tpr, label=f"Day-{day} (AUC = {mean_auc:.2f})")
plt.title("ROC Curves per Day — Study B (VXA)")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.legend(loc="lower right")
plt.tight_layout()
plt.show()

### Filter shedders

In [None]:
# Keep shedders only; three WCCT volunteers are excluded by ID.
_excluded_volunteers = [101, 109, 304]
_wcct_keep_rows = (wcct["SHEDDER"] == 1) & ~wcct["VOLUNTEER"].isin(_excluded_volunteers)
wcct_s = wcct[_wcct_keep_rows].copy()
vxa_s = vxa[vxa["SHEDDER"] == 1].copy()

### Filter unwanted days out

In [None]:
# Days used for the regression task, applied to both shedder tables.
tr_days_reg = list(range(1, 9))
rg_wcct, rg_vxa = (
    frame[frame["DPC"].isin(tr_days_reg)].copy() for frame in (wcct_s, vxa_s)
)

## Random Forest regressor

In [None]:
# Train a random-forest regressor on Study A shedders to predict DPC, then
# predict every (volunteer, day) sample of Study B.
regr = RandomForestRegressor(random_state=SEED)
trus, preds = [], []
records = []
X_tr = rg_wcct.drop(["VOLUNTEER","DPC","SHEDDER","DAY_REAL","DPC_NUM"], axis=1, errors="ignore")
Y_tr = rg_wcct["DPC"]
clf = regr.fit(X_tr, Y_tr)  # fit() returns the estimator; clf is regr
vxa_ids  = sorted(rg_vxa["VOLUNTEER"].unique())
vxa_days = sorted(rg_vxa["DPC"].unique())
for id_ in vxa_ids:
    for d in vxa_days:
        test = rg_vxa[(rg_vxa["VOLUNTEER"] == id_) & (rg_vxa["DPC"] == d)]
        if test.empty:
            continue

        X_te = test.drop(["VOLUNTEER","DPC","SHEDDER","DAY_REAL","DPC_NUM"], axis=1, errors="ignore")
        Y_te = test["DPC"]

        pr = regr.predict(X_te)

        trus.extend(Y_te.values)
        preds.extend(pr)

        # One record per (volunteer, day), using the first matching row.
        records.append({
            "iteration": d,
            "ID": id_,
            "true": float(Y_te.values[0]),
            "predicted": float(pr[0]),
        })

# Build the table once instead of concatenating inside the loop.
to_write = pd.DataFrame(records, columns=["iteration", "ID", "true", "predicted"])

# Correlate once, after all predictions are collected. Inside the loop the
# first call saw a single sample, for which pearsonr raises ValueError.
if len(trus) >= 2:
    correlation, p_value = stats.pearsonr(trus, preds)
else:
    correlation, p_value = np.nan, np.nan
print("Pearson r:", correlation, "p:", p_value)

### Scatter Plot of True vs Predicted values
### Median-prediction line, regression line, 95% confidence interval, Pearson correlation coefficient, and p-value

In [None]:
def plot_scatter_with_ci(y_true, y_pred, n_bootstrap=1000):
    """Scatter true vs. predicted values with a fit line and bootstrap band.

    Draws the points, the identity line, a horizontal line at the median
    prediction, the least-squares fit of ``y_pred`` on ``y_true``, and a
    95% percentile band of that fit line from *n_bootstrap* resamples.
    The title reports Pearson r and its p-value. Axes are fixed to 1-7.

    Resampling uses the global NumPy random state. Resamples in which all
    ``y_true`` values coincide are skipped, since no line can be fitted
    to them; the band is omitted if no usable resample remains.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    r, p = stats.pearsonr(y_true, y_pred)

    plt.figure(figsize=(6, 6))
    plt.scatter(y_true, y_pred, alpha=0.7, label="Data points")

    lims = [1, 7]
    x_ends = np.array(lims)
    plt.plot(lims, lims, "k--", alpha=0.6, label="y = x")

    median_pred = np.median(y_pred)
    plt.axhline(median_pred, linestyle="--", label=f"Median pred = {median_pred:.2f}")

    slope, intercept, _, _, _ = stats.linregress(y_true, y_pred)
    line = slope * x_ends + intercept
    plt.plot(lims, line, label="Fit line")

    n_samples = len(y_true)
    boot_lines = []
    for _ in range(n_bootstrap):
        idx = np.random.choice(n_samples, n_samples, replace=True)
        x_boot = y_true[idx]
        if x_boot.min() == x_boot.max():
            # linregress raises ValueError when every x is identical.
            continue
        fit = stats.linregress(x_boot, y_pred[idx])
        boot_lines.append(fit[0] * x_ends + fit[1])

    if boot_lines:
        boot_lines = np.array(boot_lines)
        low_line  = np.percentile(boot_lines, 2.5, axis=0)
        high_line = np.percentile(boot_lines, 97.5, axis=0)
        plt.fill_between(lims, low_line, high_line, alpha=0.2, label="95% CI")

    plt.xlabel("True values")
    plt.ylabel("Predicted values")
    plt.title(f"Study B — Pearson r = {r:.2f}, p = {p:.2e}")
    plt.legend()
    plt.grid(alpha=0.3)
    plt.xlim(0.7, 7.3)
    plt.ylim(0.7, 7.3)
    plt.xticks(range(1, 8))
    plt.yticks(range(1, 8))
    plt.tight_layout()
    plt.show()
# Plot the true vs. predicted days collected in `trus` and `preds`.
plot_scatter_with_ci(trus, preds)

### rMSE (root-mean-square error) per day
### preds: list of predicted days
### trus: list of true days

In [None]:
def plot_daywise_rmse(preds, trus):
    """Bar-plot the root-mean-square error of *preds* for each true day.

    Days are the distinct rounded values of *trus*. A sample belongs to
    a day when its true value is within 0.1 of it; a day with no such
    sample gets NaN and no value label.
    """
    pred_arr = np.array(preds)
    true_arr = np.array(trus)

    def _rmse_for(day):
        # RMSE over the samples whose true value matches this day.
        hits = np.isclose(true_arr, day, atol=0.1)
        if not hits.any():
            return np.nan
        residual = pred_arr[hits] - true_arr[hits]
        return np.sqrt(np.mean(residual ** 2))

    expected_days = sorted(set(np.round(true_arr).astype(int)))
    rmses = [_rmse_for(day) for day in expected_days]
    labels = [str(int(day)) for day in expected_days]

    plt.figure(figsize=(8, 5))
    bars = plt.bar(labels, rmses, edgecolor="black")
    plt.ylabel("rMSE")
    plt.xlabel("d.p.c")
    plt.title("RMSE Study B")
    # Leave headroom above the tallest bar for the value labels.
    plt.ylim(0, np.nanmax(rmses) * 1.3)

    for bar, val in zip(bars, rmses):
        if not np.isnan(val):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.05,
                     f"{val:.2f}", ha="center", va="bottom", fontsize=9)

    plt.tight_layout()
    plt.show()
# Plot the per-day RMSE of the predictions in `preds` against `trus`.
plot_daywise_rmse(preds, trus)

### Heatmap Study B
### Predicted vs actual DPC by Subject

In [None]:
# Heatmap of predicted DPC (cell value) by actual DPC (row) and subject
# (column) for the rg_vxa samples with actual DPC between 1 and 7.
X_val = rg_vxa.drop(["VOLUNTEER","DPC","SHEDDER","DAY_REAL","DPC_NUM"], axis=1, errors="ignore")
y_true = rg_vxa["DPC"].values
ids = rg_vxa["VOLUNTEER"].astype(str).values
y_pred = clf.predict(X_val)
mask = (y_true >= 1) & (y_true <= 7)
df_h = pd.DataFrame({
    "SubjectID": ids[mask],
    "Actual_DPC": y_true[mask],
    "Predicted_DPC": y_pred[mask]
})
heatmap_df = pd.pivot_table(
    df_h,
    index="Actual_DPC",
    columns="SubjectID",
    values="Predicted_DPC",
    aggfunc="median"
)
# Fill cells with no sample: interpolate across subjects, then across
# days, then fall back to the column mean.
# NOTE(review): the first step fills a subject's missing day from the
# neighbouring subject columns, and it runs before the columns are put in
# numeric order below -- confirm this is intended.
heatmap_df = heatmap_df.interpolate(axis=1, limit_direction="both")
heatmap_df = heatmap_df.interpolate(axis=0, limit_direction="both")
heatmap_df = heatmap_df.fillna(heatmap_df.mean(numeric_only=True))
heatmap_df = heatmap_df.sort_index()
try:
    # Subject IDs are strings here; order them numerically when possible.
    heatmap_df = heatmap_df.reindex(sorted(heatmap_df.columns, key=lambda x: int(x)), axis=1)
except ValueError:
    # Non-numeric IDs: fall back to plain string order.
    heatmap_df = heatmap_df.reindex(sorted(heatmap_df.columns), axis=1)
plt.figure(figsize=(14, 6))
ax = sns.heatmap(
    heatmap_df,
    # Colour scale spans the actual DPC range, not the predicted range.
    vmin=df_h["Actual_DPC"].min(), vmax=df_h["Actual_DPC"].max(),
    cmap="YlGn",
    annot=True, fmt=".3f",
    cbar_kws={"label": "Predicted DPC"}
)
ax.set_title("Validation Heatmap of Predicted vs Actual DPC by Subject", fontsize=14)
ax.set_xlabel("Subject ID")
ax.set_ylabel("Actual DPC")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()