# Prepare data

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import ast
import matplotlib.pyplot as plt
import dill
import torch
import nbimporter
import shap

import os
import sys

# NOTE(review): hard-coded absolute path; the notebook only runs where the repository is checked out at this location.
os.chdir('/data/repos/actin-personalization/prediction')
# Put the project's Python sources first on sys.path so the project imports below resolve.
sys.path.insert(0, os.path.abspath("src/main/python"))

# Project-local imports; they rely on the working directory / sys.path change above.
# NOTE(review): the star import hides the origin of names used later without an explicit import
# (e.g. ExperimentConfig, ModelTrainer) - presumably they come from `models`; confirm.
from models import *
from data.data_processing import DataSplitter, DataPreprocessor
from data.lookups import lookup_manager
from utils.settings import settings
from src.main.python.analysis.predictive_algorithms_training import get_data, plot_different_models_survival_curves

# Preprocessor configured from the project settings; used below to load and preprocess the raw data.
preprocessor = DataPreprocessor(settings.db_config_path, settings.db_name)

In [None]:
# get_data() comes from predictive_algorithms_training; the unpacked names suggest a full frame plus
# train/test splits (not verified here). In the visible cells only X_train is used, to supply the
# feature-column names when the trained models are loaded.
df, X_train, X_test, y_train, y_test, encoded_columns = get_data()

In [None]:

def get_preprocessed_data_with_sourceId(preprocessor):
    """Load the raw data, preprocess it and re-attach the ``sourceId`` column.

    Args:
        preprocessor: object exposing ``load_data()`` and
            ``preprocess_data(features, df=...)``.

    Returns:
        Tuple ``(df_raw, df_all, updated_features)``: the raw frame, the
        preprocessed frame with an added ``sourceId`` column, and the second
        value returned by ``preprocess_data``.
    """
    raw_frame = preprocessor.load_data()
    processed_frame, feature_names, _unused = preprocessor.preprocess_data(
        lookup_manager.features,
        df=raw_frame,
    )
    # Look the ids up through the processed frame's index so that only the rows
    # present after preprocessing are used.
    source_ids = raw_frame.loc[processed_frame.index, "sourceId"]
    processed_frame["sourceId"] = source_ids
    return raw_frame, processed_frame, feature_names

# df_all (preprocessed features + sourceId) is the frame used by the counterfactual predictions
# and the covariate matching further down.
df_raw, df_all, updated_features = get_preprocessed_data_with_sourceId(preprocessor)

In [None]:
def load_trained_model(model_name, model_class, model_kwargs=None):
    """Load one trained model from ``settings.save_path``.

    The file name is ``<settings.outcome>_<model_name>`` with a ``.pkl``
    extension for the models listed below (loaded with ``dill``) and a ``.pt``
    extension for every other model (loaded with ``torch.load``).

    Args:
        model_name: name of the model; selects the file and the loading path.
        model_class: class instantiated for the ``.pt`` path; ignored for the
            ``.pkl`` path.
        model_kwargs: keyword arguments for ``model_class``; ``None`` (default)
            means no arguments.

    Returns:
        The loaded model object.
    """
    # Avoid a shared mutable default argument.
    if model_kwargs is None:
        model_kwargs = {}

    model_file_prefix = os.path.join(settings.save_path, f"{settings.outcome}_{model_name}")
    nn_file = model_file_prefix + ".pt"
    sk_file = model_file_prefix + ".pkl"

    if model_name in ['CoxPH', 'RandomSurvivalForest', 'GradientBoosting', 'AalenAdditive']:
        # NOTE(review): dill.load executes arbitrary code from the file; only load trusted files.
        with open(sk_file, "rb") as f:
            model = dill.load(f)
        print(f"Model {model_name} loaded from {sk_file}")
        return model

    model = model_class(**model_kwargs)

    # NOTE(review): the checkpoint holds more than tensors ('labtrans', baseline hazards), so it
    # may not load under a weights-only default of torch.load - confirm for the torch version in use.
    state = torch.load(nn_file, map_location=torch.device('cpu'))

    model.model.net.load_state_dict(state['net_state'])

    if 'labtrans' in state:
        # Restore the label transform and keep the duration index in sync with its cuts.
        model.labtrans = state['labtrans']
        model.model.duration_index = model.labtrans.cuts

    if 'baseline_hazards' in state:
        model.model.baseline_hazards_ = state['baseline_hazards']
        model.model.baseline_cumulative_hazards_ = state['baseline_cumulative_hazards']

        print(f"Baseline hazards loaded for {model_name}.")

    # Switch the network to evaluation mode before it is used for prediction.
    model.model.net.eval()
    print(f"Model {model_name} loaded from {nn_file}")

    return model
    
def load_all_trained_models(X_train):
    """Load every model listed in the experiment configuration.

    Models that fail to load are reported and skipped, so the returned dict may
    contain fewer entries than the configuration.

    Args:
        X_train: frame whose column names are passed to
            ``ModelTrainer._set_attention_indices`` for each loaded model.

    Returns:
        Dict mapping model name to loaded model.
    """
    loaded_models = {}
    config_mgr = ExperimentConfig(settings.json_config_file)
    loaded_configs = config_mgr.load_model_configs()

    for model_name, (model_class, model_kwargs) in loaded_configs.items():
        print(model_name, model_class)
        try:
            loaded_model = load_trained_model(
                model_name=model_name,
                model_class=model_class,
                model_kwargs=model_kwargs
            )
            # Registered before the attention step on purpose (as before): a model whose
            # attention indices cannot be set stays available in the result.
            loaded_models[model_name] = loaded_model

            ModelTrainer._set_attention_indices(loaded_models[model_name], list(X_train.columns))
        except Exception as exc:
            # Catch Exception only, so Ctrl-C / SystemExit still interrupt the loop, and report
            # the reason instead of hiding it.
            print(f'Could not load: {model_name} ({type(exc).__name__}: {exc})')
            continue
    return loaded_models

In [None]:
# Dict of model name -> loaded model; models that failed to load are missing from it.
trained_models = load_all_trained_models(X_train)

In [None]:
import json
# Treatment options to evaluate; passed below as `treatment_map`, where each entry is iterated as
# label -> {column: value} mapping. The relative path depends on the working directory set above.
with open('src/main/python/data/treatment_combinations.json', 'r') as f:
    valid_treatment_combinations = json.load(f)

# C-for-benefit estimation based on the method described by Maas et al.

Definition of treatment effect

In [None]:
def get_all_patient_treatment_risks(
    model,
    df_all,
    treatment_map,
    treatment_prefix="systemicTreatmentPlan_",
    horizon_days=365
):
    """Predict risk and survival summaries for every patient under every treatment option.

    For each row of ``df_all`` the treatment columns are rewritten once per entry
    of ``treatment_map`` and the model is queried on the modified row.

    Args:
        model: object exposing ``predict_survival_function`` (returning callables
            with an ``x`` attribute) and either ``model.model.predict`` or ``predict``.
        df_all: frame with the model features plus ``sourceId`` and the columns
            named by ``settings.event_col`` / ``settings.duration_col``.
        treatment_map: mapping ``label -> {column: value}`` describing each option.
        treatment_prefix: prefix identifying the treatment indicator columns.
        horizon_days: time at which the survival probability is read off; it is
            stored under ``predicted_prob_1yr`` whatever the horizon is.

    Returns:
        DataFrame with one row per (patient, treatment option).
    """
    import numpy as np
    import pandas as pd
    from tqdm import tqdm

    def apply_treatment(df, mapping, treatment_cols, msi_flag):
        """Return a copy of ``df`` whose treatment columns follow ``mapping``."""
        counterfactual = df.copy()
        # Start from "no treatment", then switch on what the option prescribes.
        counterfactual[treatment_cols] = 0
        for column, value in mapping.items():
            if column not in counterfactual.columns:
                continue
            counterfactual[column] = value
        if "hasMsi" in counterfactual.columns:
            counterfactual["hasMsi"] = msi_flag
        any_treatment = counterfactual[treatment_cols].sum(axis=1) > 0
        counterfactual["hasTreatment"] = any_treatment.astype(int)
        return counterfactual

    def compute_survival_stats(time_grid: np.ndarray, surv_probs: np.ndarray):
        """Return ``(median_days, auc_days)`` for one survival curve.

        ``median_days`` is the first grid time with S(t) <= 0.5 (no
        interpolation), or the last grid time when the curve never gets that
        low; ``auc_days`` is the trapezoidal area under the curve.
        """
        crossing = np.nonzero(surv_probs <= 0.5)[0]
        median_days = time_grid[crossing[0]] if crossing.size else time_grid[-1]
        auc_days = np.trapz(surv_probs, time_grid)
        return median_days, auc_days

    records = []
    treatment_cols = [name for name in df_all.columns if name.startswith(treatment_prefix)]

    progress = tqdm(df_all.iterrows(), total=len(df_all), desc="Processing patients")
    for _, patient in progress:
        source_id = patient["sourceId"]
        survival_days = patient[settings.duration_col]
        event = patient[settings.event_col]
        msi_flag = int(patient.get("hasMsi", 0))

        # One-row feature frame without the identifier and outcome columns.
        non_feature_labels = ["sourceId", settings.event_col, settings.duration_col]
        X_base = patient.drop(labels=non_feature_labels).to_frame().T

        # The time grid is taken from the prediction for the unmodified row and
        # reused for every treatment option of this patient.
        baseline_curves = model.predict_survival_function(X_base)
        grid_start = max(curve.x[0] for curve in baseline_curves)
        grid_end = min(curve.x[-1] for curve in baseline_curves)
        time_grid = np.linspace(grid_start, grid_end, 500)

        received = [
            name.replace(treatment_prefix, "")
            for name in treatment_cols
            if patient[name] == 1
        ]
        if received:
            actual_treatment_str = ", ".join(received)
        else:
            actual_treatment_str = "No Treatment"

        for treatment_label, mapping in treatment_map.items():
            # apply_treatment copies its input, so X_base itself is never modified.
            X_mod = apply_treatment(X_base, mapping, treatment_cols, msi_flag)

            # Prefer the wrapped model's predict() when the model has one.
            if hasattr(getattr(model, "model", None), "predict"):
                raw_prediction = model.model.predict(X_mod.values.astype("float32"))
            else:
                raw_prediction = model.predict(X_mod)
            risk_val = float(raw_prediction[0])

            (surv_fn,) = model.predict_survival_function(X_mod)
            surv_probs = surv_fn(time_grid)

            prob_at_horizon = float(np.interp(horizon_days, time_grid, surv_probs))
            median_survival, auc = compute_survival_stats(time_grid, surv_probs)

            records.append({
                "sourceId": source_id,
                "treatment": treatment_label,
                "predicted_risk": risk_val,
                "predicted_prob_1yr": prob_at_horizon,
                "predicted_median_survival": median_survival,
                "predicted_auc": auc,
                "observed_survival": survival_days,
                "event": event,
                "actual_treatment": actual_treatment_str
            })

    return pd.DataFrame(records)


# Uncached run over all patients with the DeepSurv attention model.
# NOTE(review): the following cell assigns risk_df again (from a CSV cache or a recomputation),
# so this expensive call looks redundant - confirm before removing.
risk_df = get_all_patient_treatment_risks(
    trained_models["DeepSurv_attention"],
    df_all,
    valid_treatment_combinations,
    treatment_prefix="systemicTreatmentPlan_",
)

print(risk_df.head())


In [None]:
import os
import pandas as pd

# Settings for this cell: cache location, forced recomputation, optional patient cap.
risk_path = "1yr_surv_df_deepsurv_attention.csv"
recalculate = False
max_patients = None

# Work on a copy; optionally restrict to a reproducible random subset.
df_input = df_all.copy()
if max_patients is not None:
    df_input = df_input.sample(n=max_patients, random_state=42)
    df_input = df_input.reset_index(drop=True)
    print(f"Limiting to {max_patients} patients for processing.")

# Recompute when asked to or when no cached file exists; otherwise reuse the cache.
# NOTE(review): a cached file is loaded as-is, whatever max_patients says.
if recalculate or not os.path.exists(risk_path):
    print("Recalculating risk_df...")
    risk_df = get_all_patient_treatment_risks(
        trained_models["DeepSurv_attention"],
        df_input,
        valid_treatment_combinations,
        treatment_prefix="systemicTreatmentPlan_",
    )
    risk_df.to_csv(risk_path, index=False)
    print(f"Saved risk_df to: {risk_path}")
else:
    print(f"Loading risk_df from: {risk_path}")
    risk_df = pd.read_csv(risk_path)

print(risk_df.head())



Extraction of untreated and treated patient groups

In [None]:
# Labels of the two arms as they appear in risk_df["treatment"].
treat_control = "No Treatment"
treat_active = "5-FU + oxaliplatin + bevacizumab"

# Predicted survival probability at the horizon under each arm, one row per patient.
df_control = risk_df[risk_df["treatment"] == treat_control][["sourceId", "predicted_prob_1yr"]]
df_active = risk_df[risk_df["treatment"] == treat_active][["sourceId", "predicted_prob_1yr"]]

df_control = df_control.rename(columns={"predicted_prob_1yr": "prob_untreated_pred"})
df_active = df_active.rename(columns={"predicted_prob_1yr": "prob_treated_pred"})

df_ite = pd.merge(df_control, df_active, on="sourceId")

# The observed columns are identical on every row of a patient, so the first row is enough.
df_actual = risk_df.drop_duplicates("sourceId")[[
    "sourceId",
    "actual_treatment",
    "event",
    "observed_survival"
]]
df_ite = pd.merge(df_ite, df_actual, on="sourceId")

# Patients censored at or before 365 days have an unknown one-year status. Labelling them as
# "did not survive" would count them as deaths, so they are excluded instead.
# Assumes `event` is truthy when the survival event was observed - TODO confirm the coding.
event_observed = df_ite["event"].astype(bool)
unknown_1yr_status = (~event_observed) & (df_ite["observed_survival"] <= 365)
print("Patients excluded (censored before 365 days):", int(unknown_1yr_status.sum()))
df_ite = df_ite[~unknown_1yr_status].copy()

df_ite["survived_1yr"] = np.where(
    df_ite["observed_survival"] > 365,
    1,
    0
)

# NOTE(review): this is untreated minus treated, i.e. the opposite sign of the pairwise effect
# computed later (treated minus untreated). Left unchanged because it is not used in the
# visible cells; confirm the intended sign.
df_ite["predicted_ite"] = df_ite["prob_untreated_pred"] - df_ite["prob_treated_pred"]

# Keep only patients who actually received one of the two arms. The active arm is written with
# ", " here because actual_treatment strings are built by joining the column names with ", ".
valid_actuals = ["No Treatment", "5-FU, oxaliplatin, bevacizumab"]
df_ite = df_ite[df_ite["actual_treatment"].isin(valid_actuals)].copy()

print("Patients remaining after filtering:", df_ite.shape[0])
print(df_ite["actual_treatment"].value_counts())
print("\nPreview:")
print(df_ite.head())


Create covariate dataframe

In [None]:
# Columns that must never be used as matching covariates: outcome and treatment information.
exclude_cols = [
    'hadSurvivalEvent',
    'systemicTreatmentPlan_5-FU',
    'systemicTreatmentPlan_oxaliplatin',
    'systemicTreatmentPlan_irinotecan',
    'systemicTreatmentPlan_bevacizumab',
    'systemicTreatmentPlan_panitumumab',
    'systemicTreatmentPlan_pembrolizumab',
    'systemicTreatmentPlan_nivolumab',
    'hasTreatment',
    'survivalDaysSinceMetastaticDiagnosis'
]

# The explicit list above can go stale. Also exclude the outcome columns named in the settings
# and every column carrying the treatment prefix, which is how the prediction code identifies
# treatment columns.
non_covariates = set(exclude_cols) | {'sourceId', settings.event_col, settings.duration_col}
covariate_cols = [
    col for col in df_all.columns
    if col not in non_covariates and not col.startswith("systemicTreatmentPlan_")
]
df_covariates = df_all[["sourceId"] + covariate_cols].copy()

print("Created df_covariates with shape:", df_covariates.shape)
print("Total covariates:", len(covariate_cols))

print(df_covariates.head())

Create pairs based on similarity in covariates
Find for every untreated patient the closest treated match (without duplicates)

In [None]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from scipy.spatial.distance import cdist
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import os

output_path = "matched_pairs_Maas.csv"
overwrite = False

# NOTE(review): an existing file is reused even when df_ite or the covariates have changed since
# it was written; set overwrite = True after changing the inputs.
if os.path.exists(output_path) and not overwrite:
    print(f"Matched pairs file already exists at '{output_path}'.")
    print("Loading from file instead of recomputing. To overwrite, set overwrite = True.\n")
    matched_pairs_df = pd.read_csv(output_path)
else:
    treated_df = df_ite[df_ite['actual_treatment'] != "No Treatment"].copy()
    untreated_df = df_ite[df_ite['actual_treatment'] == "No Treatment"].copy()

    treated_cov = treated_df.merge(df_covariates, on='sourceId')
    untreated_cov = untreated_df.merge(df_covariates, on='sourceId')

    # Everything that is not a baseline covariate: identifiers, observed outcomes (including
    # the derived survived_1yr label) and model predictions. Matching on any of these would
    # leak the outcome into the pairing.
    exclude_cols = [
        'sourceId', 'actual_treatment', 'event', 'observed_survival',
        'prob_untreated_pred', 'prob_treated_pred', 'predicted_ite',
        'survived_1yr'
    ]
    X_treated = treated_cov.drop(columns=[col for col in exclude_cols if col in treated_cov.columns])
    X_untreated = untreated_cov.drop(columns=[col for col in exclude_cols if col in untreated_cov.columns])

    # Scale both groups with statistics from the combined sample.
    scaler = StandardScaler()
    X_all = pd.concat([X_treated, X_untreated], axis=0)
    X_all_scaled = scaler.fit_transform(X_all)
    X_treated_scaled = scaler.transform(X_treated)
    X_untreated_scaled = scaler.transform(X_untreated)

    # Feature-by-feature covariance (rows are patients, columns are variables), computed on the
    # same scaled data the distances use. The pseudo-inverse keeps this working when the
    # covariance is singular, e.g. with constant or collinear columns.
    cov_matrix = np.atleast_2d(np.cov(X_all_scaled, rowvar=False))
    VI = np.linalg.pinv(cov_matrix)
    dist_matrix = cdist(X_untreated_scaled, X_treated_scaled, metric='mahalanobis', VI=VI)

    treated_indices_used = set()
    matched_pairs = []
    untreated_ids = untreated_cov['sourceId'].to_numpy()
    treated_ids = treated_cov['sourceId'].to_numpy()

    # Greedy matching without replacement: untreated patients are processed in frame order and
    # each takes the nearest treated patient that is still free.
    print("Matching untreated patients to nearest treated patient:")
    for i in tqdm(range(len(dist_matrix))):
        row = dist_matrix[i]
        for j in np.argsort(row):
            if j not in treated_indices_used:
                treated_indices_used.add(j)
                matched_pairs.append({
                    'untreated_id': untreated_ids[i],
                    'treated_id': treated_ids[j],
                    'distance': row[j]
                })
                break

    # Explicit columns keep the frame usable when no pair could be formed.
    matched_pairs_df = pd.DataFrame(
        matched_pairs, columns=['untreated_id', 'treated_id', 'distance']
    ).reset_index(drop=True)

    expected_matches = len(untreated_cov)
    actual_matches = len(matched_pairs_df)
    if actual_matches < expected_matches:
        print(
            f"Matched {actual_matches} of {expected_matches} untreated patients; "
            "the others had no unused treated patient left."
        )

    plt.figure(figsize=(8, 5))
    plt.hist(matched_pairs_df['distance'], bins=30, edgecolor='k')
    plt.xlabel("Mahalanobis Distance")
    plt.ylabel("Number of Matched Pairs")
    plt.title("Distribution of Matching Distances")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    matched_pairs_df.to_csv(output_path, index=False)
    print(f"Saved matched pairs to: {output_path}")

matched_pairs_df.head()


Calculate pairwise effect

In [None]:
# Per-patient predictions and one-year outcome, indexed by sourceId.
lookup_cols = ["prob_untreated_pred", "prob_treated_pred", "survived_1yr"]
df_lookup = df_ite.set_index("sourceId")[lookup_cols]

# Attach the lookup twice: first for the untreated member of each pair, then for the treated
# member. The second merge makes pandas add the "_x" (untreated) and "_y" (treated) suffixes.
matched = (
    matched_pairs_df.copy()
    .merge(df_lookup, left_on="untreated_id", right_index=True)
    .merge(df_lookup, left_on="treated_id", right_index=True)
)

# Predicted effect: treated patient's prediction under treatment minus untreated patient's
# prediction without treatment.
matched["pred_pairwise_effect"] = (
    matched["prob_treated_pred_y"] - matched["prob_untreated_pred_x"]
)

# Observed effect: +1 when only the treated patient survived, -1 when only the untreated
# patient survived, 0 when both had the same outcome.
only_treated_survived = (matched["survived_1yr_x"] == 0) & (matched["survived_1yr_y"] == 1)
only_untreated_survived = (matched["survived_1yr_x"] == 1) & (matched["survived_1yr_y"] == 0)
matched["obs_pairwise_effect"] = np.select(
    [only_treated_survived, only_untreated_survived],
    [1, -1],
    default=0,
)

preview_cols = [
    "untreated_id", "treated_id", "distance",
    "prob_untreated_pred_x", "prob_treated_pred_y",
    "survived_1yr_x", "survived_1yr_y",
    "pred_pairwise_effect", "obs_pairwise_effect"
]
matched[preview_cols].head()


Calibration of benefit and average absolute vertical distance

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.api as sm

n_bins = 4

# Bin the pairs by predicted effect and average predicted / observed effect per bin.
matched["quantile_bin"] = pd.qcut(matched["pred_pairwise_effect"], q=n_bins, duplicates='drop')
effect_cols = ["pred_pairwise_effect", "obs_pairwise_effect"]
quantile_df = matched.groupby("quantile_bin")[effect_cols].mean().reset_index()

# Smooth observed against predicted effect; the result has one (x, fitted) row per pair.
lowess = sm.nonparametric.lowess
smoothed = lowess(
    endog=matched["obs_pairwise_effect"],
    exog=matched["pred_pairwise_effect"],
    frac=0.3,
)
smoothed_df = pd.DataFrame(smoothed, columns=["pred_pairwise_effect", "smoothed_obs"])

# Vertical distance between the smoothed curve and the diagonal at each pair's prediction.
calib_df = matched[["pred_pairwise_effect"]].assign(
    smoothed_obs=np.interp(
        matched["pred_pairwise_effect"],
        smoothed_df["pred_pairwise_effect"],
        smoothed_df["smoothed_obs"],
    )
)
calib_df["abs_error"] = (calib_df["smoothed_obs"] - calib_df["pred_pairwise_effect"]).abs()
Eavg = calib_df["abs_error"].mean()
E50 = calib_df["abs_error"].median()
E90 = calib_df["abs_error"].quantile(0.9)

print(f"Eavg-for-benefit: {Eavg:.4f}")
print(f"E50-for-benefit:  {E50:.4f}")
print(f"E90-for-benefit:  {E90:.4f}")

fig, ax = plt.subplots(figsize=(7, 5))
ax.plot([-1, 1], [-1, 1], linestyle="--", color="gray", label="Perfect calibration")
ax.plot(
    smoothed_df["pred_pairwise_effect"],
    smoothed_df["smoothed_obs"],
    color="blue",
    label="LOWESS-smoothed",
)
ax.scatter(
    quantile_df["pred_pairwise_effect"],
    quantile_df["obs_pairwise_effect"],
    color="red", s=50, label="Quantile averages"
)
ax.set_xlabel("Predicted pairwise treatment effect")
ax.set_ylabel("Observed pairwise treatment effect")
ax.set_title("Calibration of benefit")
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()


Estimate C-for-benefit

In [None]:
from itertools import combinations

df_c = matched[["pred_pairwise_effect", "obs_pairwise_effect"]].copy()

obs_values = df_c["obs_pairwise_effect"].to_numpy()
pred_values = df_c["pred_pairwise_effect"].to_numpy()

concordant = discordant = tied = 0

# Exact count over all pairs of matched pairs with different observed effects, without a Python
# loop over every pair. For each two observed-effect levels (lower, higher), the sorted
# predictions of the lower group are searched once per member of the higher group:
#   lower prediction  -> concordant (higher observed effect was predicted higher)
#   equal prediction  -> tied
#   higher prediction -> discordant
# Assumes the predictions contain no NaN.
obs_levels = np.unique(obs_values)  # sorted ascending
for lower_pos, lower_level in enumerate(obs_levels):
    preds_lower = np.sort(pred_values[obs_values == lower_level])
    for higher_level in obs_levels[lower_pos + 1:]:
        preds_higher = pred_values[obs_values == higher_level]
        n_below = np.searchsorted(preds_lower, preds_higher, side="left")
        n_at_or_below = np.searchsorted(preds_lower, preds_higher, side="right")
        concordant += int(n_below.sum())
        tied += int((n_at_or_below - n_below).sum())
        discordant += int((len(preds_lower) - n_at_or_below).sum())

# Ties in the predictions are reported but left out of the denominator.
n_informative = concordant + discordant
if n_informative > 0:
    c_for_benefit = concordant / n_informative
else:
    c_for_benefit = np.nan

print(f"C-for-benefit: {c_for_benefit:.4f}")
print(f"Concordant: {concordant}, Discordant: {discordant}, Tied: {tied}")


Sped-up version of the C-for-benefit estimation (based on a random sample of pair comparisons)

In [None]:
from itertools import combinations
import pandas as pd
import numpy as np
import random

valid_pairs = matched.reset_index(drop=True)

# Draw a reproducible random sample of pair comparisons, capped at the number that exists.
max_comparisons = 100000
random.seed(42)
pair_indices = list(combinations(valid_pairs.index, 2))
n_sampled = min(max_comparisons, len(pair_indices))
sample_indices = random.sample(pair_indices, n_sampled)

record_columns = [
    "pair_i_untreated", "pair_i_treated",
    "pair_j_untreated", "pair_j_treated",
    "effect_i", "effect_j",
    "benefit_i", "benefit_j",
    "predicted_winner", "observed_winner",
    "concordant"
]

records = []
for i, j in sample_indices:
    row_i = valid_pairs.loc[i]
    row_j = valid_pairs.loc[j]

    # Only comparisons with different observed effects are informative.
    if row_i["obs_pairwise_effect"] == row_j["obs_pairwise_effect"]:
        continue

    # Equal predictions have no predicted winner; skip them, as the exact computation does,
    # instead of crediting pair j arbitrarily.
    if row_i["pred_pairwise_effect"] == row_j["pred_pairwise_effect"]:
        continue

    winner = "i" if row_i["obs_pairwise_effect"] > row_j["obs_pairwise_effect"] else "j"
    predicted = "i" if row_i["pred_pairwise_effect"] > row_j["pred_pairwise_effect"] else "j"

    records.append({
        "pair_i_untreated": row_i["untreated_id"],
        "pair_i_treated": row_i["treated_id"],
        "pair_j_untreated": row_j["untreated_id"],
        "pair_j_treated": row_j["treated_id"],
        "effect_i": row_i["obs_pairwise_effect"],
        "effect_j": row_j["obs_pairwise_effect"],
        "benefit_i": row_i["pred_pairwise_effect"],
        "benefit_j": row_j["pred_pairwise_effect"],
        "predicted_winner": predicted,
        "observed_winner": winner,
        "concordant": winner == predicted
    })

# Explicit columns keep the frame well-formed when no informative comparison was sampled.
comparison_df = pd.DataFrame(records, columns=record_columns)
comparison_df["concordant"] = comparison_df["concordant"].astype(int)

if len(comparison_df) > 0:
    c_for_benefit_sample = comparison_df["concordant"].mean()
    print(
        f"Sampled C-for-benefit ({n_sampled} sampled comparisons, "
        f"{len(comparison_df)} informative): {c_for_benefit_sample:.3f}"
    )
else:
    print("No informative pairs (with different observed effects) found.")

comparison_df.head(10)


Distribution of pairwise effect

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Stack the predicted effects of both sides of every sampled comparison into one long frame.
benefits_i = (
    comparison_df[["benefit_i"]]
    .rename(columns={"benefit_i": "benefit"})
    .assign(source="pair_i")
)
benefits_j = (
    comparison_df[["benefit_j"]]
    .rename(columns={"benefit_j": "benefit"})
    .assign(source="pair_j")
)
benefits = pd.concat([benefits_i, benefits_j], ignore_index=True)

# One density curve per side, each normalised on its own.
fig, ax = plt.subplots(figsize=(8, 5))
sns.kdeplot(data=benefits, x="benefit", hue="source", common_norm=False, linewidth=2, ax=ax)
ax.axvline(0, color="gray", linestyle=":")
ax.set_title("Distribution of Predicted Pairwise Effects (from comparison_df)")
ax.set_xlabel("Predicted Pairwise Effect")
ax.set_ylabel("Density")
ax.grid(True)
fig.tight_layout()
plt.show()


Compute the Brier score and cross-entropy for benefit

In [None]:
import numpy as np
from sklearn.metrics import log_loss, brier_score_loss

# Keep only the pairs whose two patients had different outcomes (observed effect +1 or -1).
discordant = matched[matched["obs_pairwise_effect"].isin([1, -1])].copy()

# Label 1 = observed benefit (only the treated patient survived), 0 = observed harm.
discordant["label"] = (discordant["obs_pairwise_effect"] == 1).astype(int)

# Logistic transform of the predicted effect: sigmoid(x) = 1 / (1 + exp(-x)), so a larger
# predicted benefit gives a larger probability of benefit. The earlier form without the minus
# sign did the opposite.
# NOTE(review): the effect lies in [-1, 1], so these probabilities stay within about
# [0.27, 0.73]; the transform is a heuristic, not a calibrated probability.
discordant["pred_prob_benefit"] = 1 / (1 + np.exp(-discordant["pred_pairwise_effect"]))

if discordant.empty:
    brier = cross_entropy = np.nan
    print("No pairs with different outcomes; Brier and cross-entropy are undefined.")
else:
    brier = brier_score_loss(discordant["label"], discordant["pred_prob_benefit"], pos_label=1)
    # labels=[0, 1] keeps log_loss defined when only one class is present.
    cross_entropy = log_loss(discordant["label"], discordant["pred_prob_benefit"], labels=[0, 1])

print(f"Brier-for-benefit:         {brier:.4f}")
print(f"Cross-entropy-for-benefit: {cross_entropy:.4f}")
