# Prepare data

In [None]:
# Automatically reload imported modules before each cell runs, so edits to the
# project's Python sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import ast
import matplotlib.pyplot as plt
import dill
import torch
import nbimporter
import shap

import os
import sys

# NOTE(review): hard-coded absolute path. Relative paths used later in this
# notebook (e.g. 'src/main/python/data/treatment_combinations.json') are
# resolved against this directory.
os.chdir('/data/repos/actin-personalization/prediction')
# Make the project's Python sources importable as top-level packages
# (`models`, `data`, `utils`).
sys.path.insert(0, os.path.abspath("src/main/python"))

# NOTE(review): star import. ExperimentConfig and ModelTrainer, used below, are
# not imported explicitly anywhere visible, so presumably they come from here.
# TODO: confirm and import them by name.
from models import *
from data.data_processing import DataSplitter, DataPreprocessor
from data.lookups import lookup_manager
from utils.settings import settings
from src.main.python.analysis.predictive_algorithms_training import get_data, plot_different_models_survival_curves

# Preprocessor bound to the configured database. It is used below to load and
# preprocess the raw cohort.
preprocessor = DataPreprocessor(settings.db_config_path, settings.db_name)

In [None]:
# Data from the shared training pipeline. Further down, only X_train is used:
# its columns configure the models' attention indices.
df, X_train, X_test, y_train, y_test, encoded_columns = get_data()

In [None]:

def get_preprocessed_data_with_sourceId(preprocessor):
    """Load and preprocess the cohort, keeping each row's ``sourceId``.

    The raw frame comes from ``preprocessor.load_data()`` and is preprocessed
    with ``lookup_manager.features``. ``sourceId`` is then copied back from the
    raw frame by index label.

    Assumes preprocessing keeps the raw index labels (it may drop rows). If a
    label is missing from the raw frame, ``.loc`` raises ``KeyError``.

    Returns:
        Tuple ``(raw_frame, preprocessed_frame, updated_features)``.
    """
    raw = preprocessor.load_data()
    processed, features, _unused = preprocessor.preprocess_data(lookup_manager.features, df=raw)
    processed["sourceId"] = raw.loc[processed.index, "sourceId"]
    return raw, processed, features

# Full preprocessed cohort (with sourceId re-attached). It is the input for the
# counterfactual treatment predictions below.
df_raw, df_all, updated_features = get_preprocessed_data_with_sourceId(preprocessor)

In [None]:
def load_trained_model(model_name, model_class, model_kwargs=None):
    """Load one trained survival model from ``settings.save_path``.

    Files are named ``<settings.outcome>_<model_name>``:

    * ``.pkl`` for the models in the name list below. These are unpickled
      with dill and returned as-is.
    * ``.pt`` for all other (neural) models. ``model_class(**model_kwargs)``
      is instantiated and the checkpoint is restored onto the CPU:
      - the network weights (``net_state``);
      - the label transform (``labtrans``, whose ``cuts`` become
        ``model.model.duration_index``) when present;
      - the baseline hazards when present.
      The network is then put in eval mode.

    Args:
        model_name: Model key. It selects both the file and the loading path.
        model_class: Wrapper class to instantiate for ``.pt`` models.
        model_kwargs: Constructor keyword arguments for ``model_class``.
            ``None`` (the default) means no arguments.

    Returns:
        The loaded model object.

    NOTE(review): dill.load / torch.load unpickle arbitrary objects, so only
    load trusted files. Newer torch releases default ``torch.load`` to
    ``weights_only=True``, which may reject pickled objects such as
    ``labtrans``. TODO: confirm against the pinned torch version.
    """
    model_file_prefix = os.path.join(settings.save_path, f"{settings.outcome}_{model_name}")

    if model_name in ('CoxPH', 'RandomSurvivalForest', 'GradientBoosting', 'AalenAdditive'):
        sk_file = model_file_prefix + ".pkl"
        with open(sk_file, "rb") as f:
            model = dill.load(f)
        print(f"Model {model_name} loaded from {sk_file}")
        return model

    nn_file = model_file_prefix + ".pt"
    # `None` default instead of a shared mutable `{}`.
    model = model_class(**(model_kwargs or {}))

    state = torch.load(nn_file, map_location=torch.device('cpu'))
    model.model.net.load_state_dict(state['net_state'])

    if 'labtrans' in state:
        # Discrete-time models need the time cut points to map outputs back to durations.
        model.labtrans = state['labtrans']
        model.model.duration_index = model.labtrans.cuts

    if 'baseline_hazards' in state:
        model.model.baseline_hazards_ = state['baseline_hazards']
        model.model.baseline_cumulative_hazards_ = state['baseline_cumulative_hazards']
        print(f"Baseline hazards loaded for {model_name}.")

    model.model.net.eval()
    print(f"Model {model_name} loaded from {nn_file}")
    return model
    
def load_all_trained_models(X_train):
    """Load every model listed in the experiment config file.

    Model classes and constructor kwargs come from
    ``ExperimentConfig(settings.json_config_file).load_model_configs()``. Each
    model is loaded with ``load_trained_model`` and then has its attention
    indices set from ``X_train``'s column names.

    A model that fails to load is reported (with the exception) and skipped.
    If setting the attention indices fails, the failure is reported separately
    and the loaded model is still returned. Unlike the previous bare
    ``except:``, this no longer swallows ``KeyboardInterrupt``.

    Args:
        X_train: Training feature frame. Only its column names are used.

    Returns:
        Dict mapping model name to loaded model.
    """
    loaded_models = {}
    config_mgr = ExperimentConfig(settings.json_config_file)
    loaded_configs = config_mgr.load_model_configs()
    feature_names = list(X_train.columns)

    for model_name, (model_class, model_kwargs) in loaded_configs.items():
        print(model_name, model_class)
        try:
            loaded_model = load_trained_model(
                model_name=model_name,
                model_class=model_class,
                model_kwargs=model_kwargs
            )
        except Exception as exc:
            print(f'Could not load: {model_name} ({type(exc).__name__}: {exc})')
            continue
        loaded_models[model_name] = loaded_model

        try:
            ModelTrainer._set_attention_indices(loaded_model, feature_names)
        except Exception as exc:
            # The model stays in the result. Report the failure distinctly so it
            # is not mistaken for a load error.
            print(f'Loaded {model_name}, but could not set attention indices '
                  f'({type(exc).__name__}: {exc})')
    return loaded_models

In [None]:
# Dict of model name -> loaded model. Models that fail to load are reported and left out.
trained_models = load_all_trained_models(X_train)

In [None]:
import json

# Candidate treatments: label -> {feature column: value} overrides, later used
# as `treatment_map`. The path is relative to the working directory set by os.chdir.
with open('src/main/python/data/treatment_combinations.json', 'r') as treatment_file:
    valid_treatment_combinations = json.load(treatment_file)

# C-for-benefit estimation

In [None]:
def get_all_patient_treatment_risks(
    model,
    df_all,
    treatment_map,
    treatment_prefix="systemicTreatmentPlan_",
    horizon_days=365,
    duration_col=None,
    event_col=None,
):
    """Predict survival for every patient under every candidate treatment.

    For each row of ``df_all``, the treatment columns (names starting with
    ``treatment_prefix``) are zeroed. Each mapping in ``treatment_map`` is then
    applied and the model is queried for the counterfactual prediction.

    Args:
        model: Fitted survival model exposing ``predict_survival_function``
            (returning callables with an ``x`` time domain). It must also expose
            either ``predict`` or, for wrapper objects, ``model.predict``.
        df_all: Frame with ``sourceId``, the duration and event columns, and
            the model's feature columns.
        treatment_map: Dict mapping a treatment label to ``{column: value}``.
            Columns not present in ``df_all`` are ignored.
        treatment_prefix: Prefix identifying the treatment indicator columns.
        horizon_days: Time at which ``predicted_prob_1yr`` is read off.
        duration_col: Observed survival time column. Defaults to
            ``settings.duration_col``.
        event_col: Event indicator column. Defaults to ``settings.event_col``.

    Returns:
        DataFrame with one row per (patient, treatment) and these columns:
        ``sourceId``, ``treatment``, ``predicted_risk``,
        ``predicted_prob_1yr`` (S(horizon_days), linearly interpolated),
        ``predicted_median_survival`` (first grid time with S(t) <= 0.5, else
        the grid end), ``predicted_auc`` (trapezoidal area under S on the grid),
        ``observed_survival``, ``event`` and ``actual_treatment``.
    """
    try:
        from tqdm.auto import tqdm
    except ImportError:  # the progress bar is optional
        def tqdm(iterable, **_kwargs):
            return iterable

    if duration_col is None:
        duration_col = settings.duration_col
    if event_col is None:
        event_col = settings.event_col

    # NumPy 2.0 renamed np.trapz to np.trapezoid (np.trapz is deprecated).
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    def apply_treatment(df, mapping, treatment_cols, msi_flag):
        """Return a copy of ``df`` with only the treatment in ``mapping`` switched on."""
        df_copy = df.copy()
        df_copy[treatment_cols] = 0
        for col, val in mapping.items():
            if col in df_copy.columns:
                df_copy[col] = val
        if "hasMsi" in df_copy.columns:
            df_copy["hasMsi"] = msi_flag
        # Only update hasTreatment when it is a model feature. Assigning it
        # unconditionally would append a column the model was never trained on.
        if "hasTreatment" in df_copy.columns:
            df_copy["hasTreatment"] = (df_copy[treatment_cols].sum(axis=1) > 0).astype(int)
        return df_copy

    def compute_survival_stats(time_grid, surv_probs):
        """Median (first grid time with S(t) <= 0.5, no interpolation) and area under S."""
        below = np.flatnonzero(surv_probs <= 0.5)
        median_days = time_grid[below[0]] if below.size else time_grid[-1]
        auc_days = trapezoid(surv_probs, time_grid)
        return median_days, auc_days

    results = []
    treatment_cols = [col for col in df_all.columns if col.startswith(treatment_prefix)]
    non_feature_cols = ["sourceId", event_col, duration_col]
    # Loop-invariant: which prediction API the model exposes.
    use_inner_predict = hasattr(model, "model") and hasattr(model.model, "predict")

    for _, row in tqdm(df_all.iterrows(), total=len(df_all), desc="Processing patients"):
        source_id = row["sourceId"]
        survival_days = row[duration_col]
        event = row[event_col]
        msi_flag = int(row.get("hasMsi", 0))

        # Single-row feature frame with the patient's observed covariates.
        X_base = row.drop(labels=non_feature_cols).to_frame().T

        # Common evaluation grid: the time domain of the survival function
        # predicted for the patient's observed features.
        survival_fs = model.predict_survival_function(X_base)
        time_start = max(sf.x[0] for sf in survival_fs)
        time_end = min(sf.x[-1] for sf in survival_fs)
        time_grid = np.linspace(time_start, time_end, 500)

        actual_treatments = [col for col in treatment_cols if row[col] == 1]
        actual_treatment_str = (
            ", ".join(col.replace(treatment_prefix, "") for col in actual_treatments)
            if actual_treatments else "No Treatment"
        )

        for treatment_label, mapping in treatment_map.items():
            X_mod = apply_treatment(X_base, mapping, treatment_cols, msi_flag)

            if use_inner_predict:
                risk_val = float(model.model.predict(X_mod.values.astype("float32"))[0])
            else:
                risk_val = float(model.predict(X_mod)[0])

            surv_fn, = model.predict_survival_function(X_mod)
            surv_probs = surv_fn(time_grid)

            # NOTE: np.interp clamps outside the grid. If horizon_days > time_end,
            # this returns S(time_end) instead of an extrapolated value.
            prob_1yr = float(np.interp(horizon_days, time_grid, surv_probs))
            median_survival, auc = compute_survival_stats(time_grid, surv_probs)

            results.append({
                "sourceId": source_id,
                "treatment": treatment_label,
                "predicted_risk": risk_val,
                "predicted_prob_1yr": prob_1yr,
                "predicted_median_survival": median_survival,
                "predicted_auc": auc,
                "observed_survival": survival_days,
                "event": event,
                "actual_treatment": actual_treatment_str
            })

    return pd.DataFrame(results)


# Quick smoke test on a handful of patients. The full-cohort computation is
# expensive, so it runs only in the cached "Predict ... for all patients" cell
# below. Running it here as well would bypass that cache on every Run All.
risk_df_preview = get_all_patient_treatment_risks(
    model=trained_models["DeepSurv_attention"],
    df_all=df_all.head(5),
    treatment_map=valid_treatment_combinations,
    treatment_prefix="systemicTreatmentPlan_"
)

print(risk_df_preview.head())


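Predict survival for all patients under every treatment option (risk, 1-year survival probability, median survival and area under the survival curve), cached to CSV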

In [None]:
import os
import pandas as pd

recalculate = False   # True: ignore any cached CSV and recompute
max_patients = None   # int: restrict to a random subset of patients for quick runs

# Subsampled runs are cached under their own file name. Otherwise a partial
# result could later be loaded as if it covered the full cohort.
if max_patients is None:
    risk_path = "med_surv_df_deepsurv_attention.csv"
else:
    risk_path = f"med_surv_df_deepsurv_attention_n{max_patients}.csv"

df_input = df_all.copy()
if max_patients is not None:
    df_input = df_input.sample(n=max_patients, random_state=42).reset_index(drop=True)
    print(f"Limiting to {max_patients} patients for processing.")

if not recalculate and os.path.exists(risk_path):
    print(f"Loading risk_df from: {risk_path}")
    risk_df = pd.read_csv(risk_path)
else:
    print("Recalculating risk_df...")
    risk_df = get_all_patient_treatment_risks(
        model=trained_models["DeepSurv_attention"],
        df_all=df_input,
        treatment_map=valid_treatment_combinations,
        treatment_prefix="systemicTreatmentPlan_"
    )
    risk_df.to_csv(risk_path, index=False)
    print(f"Saved risk_df to: {risk_path}")

print(risk_df.head())



Visualise the distribution of predicted 1-year survival without treatment, per actually received treatment

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
from scipy.stats import gaussian_kde

# Predicted 1-year survival under "No Treatment", split by the treatment each
# patient actually received.
baseline_df = (
    risk_df[risk_df["treatment"] == "No Treatment"]
    .copy()
    .dropna(subset=["predicted_prob_1yr", "actual_treatment"])
    .assign(actual_treatment=lambda frame: frame["actual_treatment"].astype(str))
)

# Only plot groups large enough for a meaningful density estimate.
min_group_size = 20
group_counts = baseline_df["actual_treatment"].value_counts()
valid_groups = [grp for grp, cnt in group_counts.items() if cnt >= min_group_size]
filtered_df = baseline_df[baseline_df["actual_treatment"].isin(valid_groups)]

fig, ax = plt.subplots(figsize=(10, 6))
colors = sns.color_palette("tab10", len(valid_groups))

for group, color in zip(valid_groups, colors):
    group_data = filtered_df.loc[filtered_df["actual_treatment"] == group, "predicted_prob_1yr"].values
    if len(group_data) > 1 and np.std(group_data) > 0:
        kde = gaussian_kde(group_data, bw_method=0.3)
        x_vals = np.linspace(group_data.min() - 0.05, group_data.max() + 0.05, 300)
        # Scale the density by group size so curve heights reflect patient counts.
        y_vals = kde(x_vals) * len(group_data)
        ax.plot(x_vals, y_vals, label=group, color=color, linewidth=2)

ax.set_title("Smoothed Histogram (Counts) of 1-Year Survival Probability (No Treatment)")
ax.set_xlabel("Predicted 1-Year Survival Probability (No Treatment)")
ax.set_ylabel("Number of Patients (Smoothed)")
ax.grid(True, linestyle=":", linewidth=0.5)
ax.legend(title="Actual Treatment", bbox_to_anchor=(1.05, 1), loc="upper left")
fig.tight_layout()
plt.show()


Visualise the predicted survival benefit of 5-FU + oxaliplatin + bevacizumab vs. baseline (no-treatment) 1-year survival

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress

# Predicted 1-year survival per patient without treatment and with
# 5-FU + oxaliplatin + bevacizumab.
df_no_treatment = risk_df[risk_df["treatment"] == "No Treatment"].copy()
df_5FU_oxa_beva = risk_df[risk_df["treatment"] == "5-FU + oxaliplatin + bevacizumab"].copy()

df_merged = pd.merge(
    df_no_treatment[["sourceId", "predicted_prob_1yr", "actual_treatment"]],
    df_5FU_oxa_beva[["sourceId", "predicted_prob_1yr"]],
    on="sourceId",
    suffixes=("_no_treatment", "_5FU_oxa_beva")
)

# Keep only patients who actually received no treatment or this combination.
valid_actuals = ["No Treatment", "5-FU, oxaliplatin, bevacizumab"]
df_merged = df_merged[df_merged["actual_treatment"].isin(valid_actuals)].copy()

# Predicted absolute 1-year survival benefit of the treatment per patient.
df_merged["delta_surv_prob"] = df_merged["predicted_prob_1yr_5FU_oxa_beva"] - df_merged["predicted_prob_1yr_no_treatment"]

baseline_surv = df_merged["predicted_prob_1yr_no_treatment"]
actual_groups = pd.Categorical(df_merged["actual_treatment"])

fig, ax = plt.subplots(figsize=(10, 7))
scatter = ax.scatter(
    baseline_surv,
    df_merged["delta_surv_prob"],
    c=actual_groups.codes,
    cmap="Set1",
    alpha=0.2,
    marker='.',
    s=20
)

slope, intercept, r_value, p_value, std_err = linregress(baseline_surv, df_merged["delta_surv_prob"])
x_vals = np.linspace(baseline_surv.min(), baseline_surv.max(), 500)
y_vals = slope * x_vals + intercept
trend_line, = ax.plot(x_vals, y_vals, color="black", linestyle="--", label="Trend line")
zero_line = ax.axhline(y=0.0, color='gray', linestyle='--', linewidth=1.2, label="No Change")

# Passing explicit handles/labels replaces matplotlib's automatic legend. The
# reference lines must therefore be appended, or their labels are never shown.
handles, labels = scatter.legend_elements(prop="colors")
unique_treatments = actual_groups.categories
ax.legend(
    list(handles) + [trend_line, zero_line],
    list(unique_treatments) + [trend_line.get_label(), zero_line.get_label()],
    title="Actual Treatment",
    bbox_to_anchor=(1.05, 1),
    loc="upper left",
)

ax.set_xlabel("Predicted 1-Year Survival (No Treatment)")
ax.set_ylabel("Δ 1-Year Survival\n(5FU_oxa_beva − No Treatment)")
ax.set_title("Survival Benefit vs. Baseline 1-Year Survival\n(Only Actual No Tx or 5FU_oxa_beva)")
ax.grid(True, linestyle=':', linewidth=0.5)
fig.tight_layout()
plt.show()


Pair each untreated patient with the nearest treated patient (based on predicted 1-year survival benefit)

In [None]:
from sklearn.neighbors import NearestNeighbors
import pandas as pd
import numpy as np

# Each patient's predicted 1-year survival without treatment and with
# 5-FU + oxaliplatin + bevacizumab.
df_no_treatment = risk_df.loc[risk_df["treatment"] == "No Treatment"].copy()
df_5FU_oxa_beva = risk_df.loc[risk_df["treatment"] == "5-FU + oxaliplatin + bevacizumab"].copy()

df_merged = df_no_treatment[["sourceId", "predicted_prob_1yr", "actual_treatment"]].merge(
    df_5FU_oxa_beva[["sourceId", "predicted_prob_1yr"]],
    on="sourceId",
    suffixes=("_no_treatment", "_5FU_oxa_beva"),
)

valid_actuals = ["No Treatment", "5-FU, oxaliplatin, bevacizumab"]
df_merged = df_merged.loc[df_merged["actual_treatment"].isin(valid_actuals)].copy()

# Predicted benefit = survival with treatment minus survival without.
df_merged["delta_surv_prob"] = df_merged["predicted_prob_1yr_5FU_oxa_beva"].sub(
    df_merged["predicted_prob_1yr_no_treatment"]
)

treated = df_merged[df_merged["actual_treatment"] == "5-FU, oxaliplatin, bevacizumab"].copy()
untreated = df_merged[df_merged["actual_treatment"] == "No Treatment"].copy()

# 1-NN on predicted benefit, with replacement: every untreated patient gets its
# closest treated patient.
nn = NearestNeighbors(n_neighbors=1, metric='euclidean').fit(treated[["delta_surv_prob"]])
distances, indices = nn.kneighbors(untreated[["delta_surv_prob"]])
nearest = treated.iloc[indices[:, 0]]

matched_pairs = pd.DataFrame({
    "untreated_sourceId": untreated["sourceId"].values,
    "untreated_delta_surv_prob": untreated["delta_surv_prob"].values,
    "treated_sourceId": nearest["sourceId"].values,
    "treated_delta_surv_prob": nearest["delta_surv_prob"].values,
})
matched_pairs["delta_difference"] = (
    matched_pairs["untreated_delta_surv_prob"] - matched_pairs["treated_delta_surv_prob"]
).abs()

matched_pairs = matched_pairs.sort_values("delta_difference").reset_index(drop=True)


In [None]:
import matplotlib.pyplot as plt

# How far each untreated patient's predicted benefit is from that of its matched treated patient.
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(
    matched_pairs["untreated_delta_surv_prob"],
    matched_pairs["delta_difference"],
    alpha=0.6,
    s=30,
    edgecolor='k',
)
ax.axhline(y=0, color='gray', linestyle='--', linewidth=1)
ax.set(
    xlabel="Predicted 1-Year Survival Benefit\n(Untreated Patient: 5FU_oxa_beva − No Treatment)",
    ylabel="Absolute Difference with Matched Treated Patient",
    title="Matching Quality:\nUntreated vs. Treated Δ 1-Year Survival Probability",
)
ax.grid(True, linestyle=':', linewidth=0.5)
fig.tight_layout()
plt.show()


Remove dissimilar pairs and attach observed 1-year outcomes

In [None]:
from sklearn.neighbors import NearestNeighbors
import pandas as pd
import numpy as np

ONE_YEAR_DAYS = 365
# Largest |predicted benefit difference| accepted between an untreated patient
# and its matched treated patient. Looser matches are discarded.
MAX_BENEFIT_DIFFERENCE = 0.005


def _observed_1yr_survival(observed_survival, event):
    """Observed 1-year status per patient as a nullable Int64 Series.

    The status is:
    * 1 if follow-up reached ``ONE_YEAR_DAYS``;
    * 0 if an event occurred before that;
    * <NA> if the patient was censored before one year (outcome unknown).
    """
    survived = observed_survival >= ONE_YEAR_DAYS
    died = event.ne(0) & event.notna() & ~survived
    status = pd.Series(pd.NA, index=observed_survival.index, dtype="Int64")
    status.loc[survived] = 1
    status.loc[died] = 0
    return status


df_no_treatment = risk_df[risk_df["treatment"] == "No Treatment"].copy()
df_5FU_oxa_beva = risk_df[risk_df["treatment"] == "5-FU + oxaliplatin + bevacizumab"].copy()
# Follow-up is recorded on every treatment row of a patient. Take it from the
# "No Treatment" rows.
df_no_treatment["survived_1yr"] = _observed_1yr_survival(
    df_no_treatment["observed_survival"], df_no_treatment["event"]
)

df_merged = pd.merge(
    df_no_treatment[["sourceId", "predicted_prob_1yr", "actual_treatment", "survived_1yr"]],
    df_5FU_oxa_beva[["sourceId", "predicted_prob_1yr"]],
    on="sourceId",
    suffixes=("_no_treatment", "_5FU_oxa_beva")
)

valid_actuals = ["No Treatment", "5-FU, oxaliplatin, bevacizumab"]
# Patients censored before one year have an unknown outcome. Counting them as
# survivors (which `event == 0` did) biases the observed pairwise effect, so
# they are excluded before matching.
df_merged = df_merged[
    df_merged["actual_treatment"].isin(valid_actuals) & df_merged["survived_1yr"].notna()
].copy()
df_merged["survived_1yr"] = df_merged["survived_1yr"].astype(int)

df_merged["delta_surv_prob"] = (
    df_merged["predicted_prob_1yr_5FU_oxa_beva"] - df_merged["predicted_prob_1yr_no_treatment"]
)

treated = df_merged[df_merged["actual_treatment"] == "5-FU, oxaliplatin, bevacizumab"].copy()
untreated = df_merged[df_merged["actual_treatment"] == "No Treatment"].copy()
if treated.empty or untreated.empty:
    raise ValueError(
        f"Need treated ({len(treated)}) and untreated ({len(untreated)}) patients "
        "with a known 1-year outcome to form matched pairs."
    )

# 1-NN on predicted benefit, with replacement.
nn = NearestNeighbors(n_neighbors=1, metric='euclidean')
nn.fit(treated[["delta_surv_prob"]])
distances, indices = nn.kneighbors(untreated[["delta_surv_prob"]])
matched_treated = treated.iloc[indices.flatten()]

untreated_benefit = untreated["delta_surv_prob"].values
treated_benefit = matched_treated["delta_surv_prob"].values
# Predicted probabilities are taken under each patient's own arm: no treatment
# for the untreated patient, 5-FU + oxaliplatin + bevacizumab for the treated one.
matched_pairs = pd.DataFrame({
    "untreated_sourceId": untreated["sourceId"].astype(str).values,
    "untreated_predicted_benefit": untreated_benefit,
    "treated_sourceId": matched_treated["sourceId"].astype(str).values,
    "treated_predicted_benefit": treated_benefit,
    "delta_difference": np.abs(untreated_benefit - treated_benefit),
    "untreated_pred_prob_1yr": untreated["predicted_prob_1yr_no_treatment"].values,
    "untreated_survived_1yr": untreated["survived_1yr"].values,
    "treated_pred_prob_1yr": matched_treated["predicted_prob_1yr_5FU_oxa_beva"].values,
    "treated_survived_1yr": matched_treated["survived_1yr"].values,
})

matched_pairs = matched_pairs[
    matched_pairs["delta_difference"] <= MAX_BENEFIT_DIFFERENCE
].reset_index(drop=True)

# Predicted benefit of the pair: the mean of its two members' predicted
# benefits. This is the quantity the c-for-benefit ranks. The previous
# definition (difference of the two benefits) was just the matching error.
matched_pairs["predicted_1yr_benefit"] = 0.5 * (
    matched_pairs["untreated_predicted_benefit"] + matched_pairs["treated_predicted_benefit"]
)

print(matched_pairs[[
    "untreated_sourceId", "treated_sourceId",
    "untreated_pred_prob_1yr", "treated_pred_prob_1yr",
    "untreated_predicted_benefit", "treated_predicted_benefit",
    "predicted_1yr_benefit",
    "untreated_survived_1yr", "treated_survived_1yr",
    "delta_difference"
]].head())


In [None]:
import matplotlib.pyplot as plt

# Matching quality after dropping dissimilar pairs.
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(
    matched_pairs["untreated_predicted_benefit"],
    matched_pairs["delta_difference"],
    alpha=0.6,
    s=30,
    edgecolor='k',
)
ax.axhline(y=0, color='gray', linestyle='--', linewidth=1)
ax.set(
    xlabel="Predicted 1-Year Survival Benefit\n(Untreated Patient: 5FU_oxa_beva − No Treatment)",
    ylabel="Absolute Difference with Matched Treated Patient",
    title="Matching Quality:\nUntreated vs. Treated Δ 1-Year Survival Probability",
)
ax.grid(True, linestyle=':', linewidth=0.5)
fig.tight_layout()
plt.show()


Compute the observed pairwise treatment effect

In [None]:
def compute_pairwise_effect(row):
    """Observed treatment effect of one matched (treated, untreated) pair.

    Returns:
        +1 if only the treated patient survived one year, -1 if only the
        untreated patient did, and 0 otherwise (both survived, both died, or
        any other value).
    """
    outcome = (row["treated_survived_1yr"], row["untreated_survived_1yr"])
    if outcome == (1, 0):
        return +1
    if outcome == (0, 1):
        return -1
    return 0


# Observed effect per matched pair: +1 if only the treated patient survived one
# year, -1 if only the untreated patient did, 0 otherwise.
matched_pairs["pairwise_effect"] = matched_pairs.apply(compute_pairwise_effect, axis=1)

print(matched_pairs[[
    "untreated_survived_1yr", "treated_survived_1yr", "pairwise_effect"
]].head(30))


Compute the predicted benefit of each matched pair (the mean of its two patients' predicted benefits)

In [None]:
# Predicted benefit of a matched pair: the mean of the predicted benefits of
# its untreated and treated members. The c-for-benefit ranks pairs by this value.
matched_pairs["predicted_benefit_avg"] = (
    matched_pairs["untreated_predicted_benefit"]
    .add(matched_pairs["treated_predicted_benefit"])
    .div(2)
)



C-for-benefit calculation

In [None]:
from itertools import combinations
import numpy as np
import pandas as pd


def _c_for_benefit(observed_effect, predicted_benefit):
    """Concordance between observed pair effects and predicted pair benefits.

    Any two matched pairs with different observed effects form a comparison.
    The comparison is concordant when the pair with the larger observed effect
    also has the larger predicted benefit. Ties in predicted benefit count 1/2,
    as in Harrell's C.

    For each pair of effect levels, the count is the Mann-Whitney U statistic
    (computed from average ranks). This equals enumerating every comparison,
    but runs in O(n log n) instead of O(n^2).

    Pairs with a missing effect or benefit are ignored. Returns NaN when there
    is no comparable pair.
    """
    effect = np.asarray(observed_effect, dtype=float)
    benefit = np.asarray(predicted_benefit, dtype=float)
    known = ~(np.isnan(effect) | np.isnan(benefit))
    effect, benefit = effect[known], benefit[known]

    concordant = 0.0
    comparable = 0
    levels = np.unique(effect)
    for k, low in enumerate(levels):
        for high in levels[k + 1:]:
            benefit_high = benefit[effect == high]
            benefit_low = benefit[effect == low]
            ranks = (
                pd.Series(np.concatenate([benefit_high, benefit_low]))
                .rank(method="average")
                .to_numpy()
            )
            n_high = benefit_high.size
            # U = (#high ranked above low) + 0.5 * (#ties).
            concordant += ranks[:n_high].sum() - n_high * (n_high + 1) / 2
            comparable += n_high * benefit_low.size
    return concordant / comparable if comparable else np.nan


valid_pairs = matched_pairs.copy().reset_index(drop=True)
c_for_benefit = _c_for_benefit(valid_pairs["pairwise_effect"], valid_pairs["predicted_benefit_avg"])

print(f"c-for-benefit : {c_for_benefit:.3f}")


Overview of a random sample of pair-of-pairs comparisons

In [None]:
from itertools import combinations
import pandas as pd
import numpy as np
import random

N_SAMPLED_PAIRS = 100000  # maximum number of pair-of-pairs comparisons to draw

valid_pairs = matched_pairs.reset_index(drop=True)
n_pairs = len(valid_pairs)
n_combinations = n_pairs * (n_pairs - 1) // 2

# Start position of each i in the lexicographic order of
# combinations(range(n_pairs), 2). A sampled position can then be mapped back
# to (i, j) without materialising all O(n^2) pairs.
row_idx = np.arange(n_pairs, dtype=np.int64)
row_offsets = row_idx * (2 * n_pairs - row_idx - 1) // 2

random.seed(42)
# random.sample chooses positions based only on the population size, so
# sampling range(n_combinations) picks the same pairs as sampling the
# materialised list. The cap avoids a ValueError when there are fewer pairs.
sample_positions = random.sample(range(n_combinations), min(N_SAMPLED_PAIRS, n_combinations))

positions = np.asarray(sample_positions, dtype=np.int64)
pair_i = np.searchsorted(row_offsets, positions, side="right") - 1
pair_j = pair_i + 1 + (positions - row_offsets[pair_i])

effects = valid_pairs["pairwise_effect"].to_numpy()
benefits = valid_pairs["predicted_benefit_avg"].to_numpy()
untreated_ids = valid_pairs["untreated_sourceId"].to_numpy()
treated_ids = valid_pairs["treated_sourceId"].to_numpy()

records = []
for i, j in zip(pair_i, pair_j):
    if effects[i] == effects[j]:
        continue
    winner = "i" if effects[i] > effects[j] else "j"

    # Ties in predicted benefit count 1/2 (Harrell's C) instead of always favouring j.
    if benefits[i] == benefits[j]:
        predicted, concordant = "tie", 0.5
    else:
        predicted = "i" if benefits[i] > benefits[j] else "j"
        concordant = 1.0 if winner == predicted else 0.0

    records.append({
        "pair_i_untreated": untreated_ids[i],
        "pair_i_treated": treated_ids[i],
        "pair_j_untreated": untreated_ids[j],
        "pair_j_treated": treated_ids[j],
        "effect_i": effects[i],
        "effect_j": effects[j],
        "benefit_i": benefits[i],
        "benefit_j": benefits[j],
        "predicted_winner": predicted,
        "observed_winner": winner,
        "concordant": concordant
    })

# Explicit columns so an empty result does not raise KeyError below.
comparison_df = pd.DataFrame(records, columns=[
    "pair_i_untreated", "pair_i_treated", "pair_j_untreated", "pair_j_treated",
    "effect_i", "effect_j", "benefit_i", "benefit_j",
    "predicted_winner", "observed_winner", "concordant",
])

if len(comparison_df) > 0:
    c_for_benefit_sample = comparison_df["concordant"].mean()
    print(f"Sampled c-for-benefit ({len(comparison_df)} comparisons): {c_for_benefit_sample:.3f}")
else:
    print("No discordant pairs found in sampled comparisons.")

print(comparison_df.head(100))
