# Prepare data

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import ast
import matplotlib.pyplot as plt
import dill
import torch
import nbimporter
import shap

import os
import sys

# Run from the prediction project root so the relative paths used in this
# notebook (e.g. "src/main/python/...") resolve.
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
os.chdir('/data/repos/actin-personalization/prediction')
# Make the project's Python sources importable as top-level packages
# (models, data, utils).
sys.path.insert(0, os.path.abspath("src/main/python"))

# Star import presumably provides ExperimentConfig and ModelTrainer, which are
# used below without an explicit import — TODO confirm and import explicitly.
from models import *
from data.data_processing import DataSplitter, DataPreprocessor
from data.lookups import lookup_manager
from utils.settings import settings
from src.main.python.analysis.predictive_algorithms_training import get_data, plot_different_models_survival_curves

# Preprocessor bound to the configured database; reused below to rebuild the
# preprocessed dataset together with patient identifiers.
preprocessor = DataPreprocessor(settings.db_config_path, settings.db_name)

In [None]:
# Data as prepared by the training module; X_train's column order is later
# passed to ModelTrainer._set_attention_indices when loading models.
df, X_train, X_test, y_train, y_test, encoded_columns = get_data()

In [None]:

def get_preprocessed_data_with_sourceId(preprocessor):
    """Load raw data, preprocess it, and re-attach the ``sourceId`` column.

    Args:
        preprocessor: Object exposing ``load_data()`` and
            ``preprocess_data(features, df=...)``.

    Returns:
        tuple ``(df_raw, df_all, updated_features)``: the raw frame, the
        preprocessed frame with ``sourceId`` added, and the feature list
        returned by ``preprocess_data``.
    """
    raw_frame = preprocessor.load_data()
    processed_frame, feature_list, _unused = preprocessor.preprocess_data(
        lookup_manager.features, df=raw_frame
    )
    # Look identifiers up by index label; this assumes preprocess_data keeps
    # the raw frame's index for the rows it retains.
    processed_frame["sourceId"] = raw_frame.loc[processed_frame.index, "sourceId"]
    return raw_frame, processed_frame, feature_list


df_raw, df_all, updated_features = get_preprocessed_data_with_sourceId(preprocessor)

In [None]:
def load_trained_model(model_name, model_class, model_kwargs=None):
    """Load a trained model stored under ``settings.save_path``.

    Files are named ``{settings.outcome}_{model_name}``. Models listed in the
    pickled-model names below are read from ``.pkl`` with dill. All others are
    instantiated from ``model_class`` and get their network weights (and
    optional label transform / baseline hazards) from a ``.pt`` checkpoint.

    Args:
        model_name: Model key; also part of the file name.
        model_class: Wrapper class instantiated for checkpoint-based models.
        model_kwargs: Keyword arguments for ``model_class``; ``None`` means no
            arguments.

    Returns:
        The loaded model object.

    Raises:
        OSError: if the model file cannot be opened.
        KeyError: if the checkpoint lacks ``'net_state'``, or has
            ``'baseline_hazards'`` without ``'baseline_cumulative_hazards'``.
    """
    # None instead of a shared mutable {} default.
    if model_kwargs is None:
        model_kwargs = {}

    model_file_prefix = os.path.join(settings.save_path, f"{settings.outcome}_{model_name}")

    if model_name in ("CoxPH", "RandomSurvivalForest", "GradientBoosting", "AalenAdditive"):
        sk_file = model_file_prefix + ".pkl"
        # NOTE: dill.load can execute arbitrary code; only load trusted artifacts.
        with open(sk_file, "rb") as f:
            model = dill.load(f)
        print(f"Model {model_name} loaded from {sk_file}")
        return model

    nn_file = model_file_prefix + ".pt"
    model = model_class(**model_kwargs)

    # NOTE: torch.load unpickles the checkpoint; only load trusted files.
    state = torch.load(nn_file, map_location=torch.device('cpu'))
    model.model.net.load_state_dict(state['net_state'])

    if 'labtrans' in state:
        model.labtrans = state['labtrans']
        model.model.duration_index = model.labtrans.cuts

    if 'baseline_hazards' in state:
        model.model.baseline_hazards_ = state['baseline_hazards']
        model.model.baseline_cumulative_hazards_ = state['baseline_cumulative_hazards']
        print(f"Baseline hazards loaded for {model_name}.")

    # Inference mode for the network.
    model.model.net.eval()
    print(f"Model {model_name} loaded from {nn_file}")
    return model
    
def load_all_trained_models(X_train):
    """Load every model listed in the experiment config.

    A model that fails to load is reported and skipped, so one broken
    artifact does not stop the rest.

    Args:
        X_train: Training feature frame; its column order is passed to
            ``ModelTrainer._set_attention_indices`` for each loaded model.

    Returns:
        dict mapping model name to the loaded model.
    """
    loaded_models = {}
    config_mgr = ExperimentConfig(settings.json_config_file)
    loaded_configs = config_mgr.load_model_configs()
    feature_names = list(X_train.columns)

    for model_name, (model_class, model_kwargs) in loaded_configs.items():
        print(model_name, model_class)
        try:
            loaded_model = load_trained_model(
                model_name=model_name,
                model_class=model_class,
                model_kwargs=model_kwargs
            )
        except Exception as err:  # not bare: let KeyboardInterrupt/SystemExit through
            print(f'Could not load: {model_name} ({type(err).__name__}: {err})')
            continue

        # The model is kept even if the attention indices cannot be set; that
        # failure is reported as its own problem rather than as a load failure.
        loaded_models[model_name] = loaded_model
        try:
            ModelTrainer._set_attention_indices(loaded_model, feature_names)
        except Exception as err:
            print(f'Could not set attention indices for {model_name} ({type(err).__name__}: {err})')
    return loaded_models

In [None]:
# Model name -> loaded model; models that failed to load are absent.
trained_models = load_all_trained_models(X_train)

In [None]:
import json
# Treatment label -> {feature column: value} overrides, consumed via
# treatment_map.items() / mapping.items() by the functions below.
# JSON is UTF-8 by specification, so don't rely on the platform default encoding.
with open('src/main/python/data/treatment_combinations.json', 'r', encoding='utf-8') as f:
    valid_treatment_combinations = json.load(f)

# Risk distribution analysis

Ranking by risk

In [None]:
def plot_observed_vs_predicted_rank_by_risk(
    model,
    df_all,
    *,
    max_observed_days: int = 365,
    num_buckets: int = 10,
    figsize=(10, 6),
    show_fit: bool = False,
    print_errors: bool = True,
    show_decile_lines: bool = True,
    show_decile_stats: bool = True
):
    """Scatter observed survival time against the predicted-risk rank.

    Only patients with an observed event (``settings.event_col == 1``) and a
    duration ``<= max_observed_days`` are plotted. Rank 0 is the patient with
    the highest predicted risk. The ranks are split into ``num_buckets``
    equally sized buckets.

    Args:
        model: Either a wrapper exposing ``model.model.predict`` (given a
            float32 array) or an object with ``predict`` (given the feature
            DataFrame).
        df_all: Features plus ``sourceId``, event and duration columns.
        max_observed_days: Inclusion cut-off on observed duration.
        num_buckets: Number of equally sized rank buckets.
        figsize: Matplotlib figure size.
        show_fit: Overlay a least-squares line of survival time on rank.
        print_errors: Print how many patients were excluded and why.
        show_decile_lines: Draw bucket boundaries and per-bucket
            n / treated counts.
        show_decile_stats: Draw per-bucket median ± IQR and the P10–P90 range.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator

    X = df_all.drop(columns=['sourceId', settings.event_col, settings.duration_col])
    times_obs = df_all[settings.duration_col].values
    event_flags = df_all[settings.event_col].values
    observed_event_mask = event_flags == 1
    within_window_mask = times_obs <= max_observed_days

    if hasattr(model, "model") and hasattr(model.model, "predict"):
        predicted_risk = model.model.predict(X.values.astype("float32"))
    else:
        predicted_risk = model.predict(X)

    if isinstance(predicted_risk, (np.ndarray, torch.Tensor)) and predicted_risk.ndim > 1:
        predicted_risk = predicted_risk.ravel()

    # Negated so that ascending rank 0 is the highest-risk patient.
    tpred = -predicted_risk

    mask = within_window_mask & observed_event_mask
    tobs = times_obs[mask]
    tpred_masked = tpred[mask]

    if print_errors:
        # Report the two exclusion reasons separately; previously both were
        # counted as "censored".
        n_censored = int((~observed_event_mask).sum())
        n_late = int((observed_event_mask & ~within_window_mask).sum())
        print(f"Dropped {n_censored} censored patients (event_flag = 0)")
        print(f"Dropped {n_late} patients with an event after {max_observed_days} days")

    ranks = np.argsort(np.argsort(tpred_masked))
    n = len(ranks)

    if "hasTreatment" in df_all.columns:
        has_treatment = df_all["hasTreatment"].values[mask]
        treated_mask = has_treatment == 1
        untreated_mask = has_treatment == 0
        print(f"Treated patients:   {treated_mask.sum()}")
        print(f"Untreated patients: {untreated_mask.sum()}")
    else:
        # Without treatment information every patient is drawn as "Treated".
        treated_mask = np.ones(n, dtype=bool)
        untreated_mask = np.zeros(n, dtype=bool)

    fig, ax = plt.subplots(figsize=figsize, dpi=110)
    # x in data units, y as a fraction of the axes height: annotations stay put
    # when the y-limits autoscale as more artists are added.
    annot_transform = ax.get_xaxis_transform()

    ax.scatter(
        ranks[untreated_mask], tobs[untreated_mask],
        color="red", s=10, alpha=0.6, label="No treatment"
    )
    ax.scatter(
        ranks[treated_mask], tobs[treated_mask],
        color="C0", s=10, alpha=0.6, label="Treated"
    )

    if show_fit and n > 1:
        coeffs = np.polyfit(ranks, tobs, 1)
        xfit = np.linspace(ranks.min(), ranks.max(), 100)
        ax.plot(xfit, np.polyval(coeffs, xfit), ls="--", color="tab:orange", label="Linear fit")

    bucket_edges = [int(i * n / num_buckets) for i in range(num_buckets + 1)]

    # n > 0 guard: bucket labels would otherwise divide by zero.
    if show_decile_lines and n > 0:
        for i in range(1, num_buckets):
            edge = bucket_edges[i]
            ax.axvline(edge, color="gray", ls="--", lw=0.8, alpha=0.7)
            ax.text(edge + 2, 0.02, f"{100 * i // num_buckets}%", transform=annot_transform,
                    rotation=90, verticalalignment='bottom', fontsize=7, color='gray')

        for i in range(num_buckets):
            start, end = bucket_edges[i], bucket_edges[i + 1]
            in_bucket = (ranks >= start) & (ranks < end)
            ax.text((start + end) // 2, 0.93,
                    f"n={int(in_bucket.sum())}\nT={int(treated_mask[in_bucket].sum())}",
                    transform=annot_transform, ha="center", fontsize=8, color="black")

    if show_decile_stats:
        bucket_centers, medians, iqr_low, iqr_high = [], [], [], []
        p10_vals, p90_vals = [], []

        for i in range(num_buckets):
            start, end = bucket_edges[i], bucket_edges[i + 1]
            y_vals = tobs[(ranks >= start) & (ranks < end)]
            if len(y_vals) == 0:
                continue
            q1, med, q3 = np.percentile(y_vals, [25, 50, 75])
            p10, p90 = np.percentile(y_vals, [10, 90])

            bucket_centers.append((start + end) // 2)
            medians.append(med)
            iqr_low.append(med - q1)
            iqr_high.append(q3 - med)
            p10_vals.append(p10)
            p90_vals.append(p90)

        for j, (x, y1, y2) in enumerate(zip(bucket_centers, p10_vals, p90_vals)):
            # The 10th–90th percentile covers the central 80% of patients.
            ax.plot([x, x], [y1, y2], color="black", lw=1.5,
                    label="80% range (P10–P90)" if j == 0 else None)

        if bucket_centers:
            ax.errorbar(
                bucket_centers, medians,
                yerr=[iqr_low, iqr_high],
                fmt='o', color='black', elinewidth=2.5, capsize=5,
                label="Median ± IQR"
            )

    ax.set_xlabel("Predicted risk rank (lower = higher risk)")
    ax.set_ylabel("Observed survival time (days)")
    ax.set_title(f"Observed Survival vs. Predicted Risk Rank (buckets={num_buckets})")
    ax.grid(True, ls=":", lw=0.5)
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax.legend()
    plt.tight_layout()
    plt.show()


In [None]:
# Rank plot for the attention DeepSurv model. max_observed_days=10000 keeps
# essentially every observed event.
plot_observed_vs_predicted_rank_by_risk(
    model              = trained_models["DeepSurv_attention"],
    df_all             = df_all,
    max_observed_days  = 10000,
    show_fit           = False,
    show_decile_lines  = True,
    show_decile_stats  = True,
    num_buckets        = 10 
)


In [None]:
def get_patient_treatment_percentile_info(
    model,
    df_all,
    treatment_map,
    *,
    patient_id=None,
    patient_row=None,
    num_buckets: int = 10,
    max_observed_days: int = 365,
    treatment_prefix="systemicTreatmentPlan"
):
    """Place a patient's counterfactual treatment risks into population risk buckets.

    Reference population: rows of ``df_all`` with an observed event and a
    duration ``<= max_observed_days``. They are ranked by predicted risk
    (rank 0 = highest risk) and split into ``num_buckets`` equally sized
    buckets. For each bucket the observed-survival median, IQR (P25–P75) and
    P2.5–P97.5 range are recorded.

    For every ``label -> {column: value}`` entry in ``treatment_map``:
    - all of the patient's columns starting with ``treatment_prefix`` are set
      to 0;
    - the mapping is applied;
    - the prediction is ranked against the reference population.

    Args:
        model: Wrapper exposing ``model.model.predict`` or an object with
            ``predict``.
        df_all: Features plus ``sourceId``, event and duration columns.
        treatment_map: Treatment label -> {feature column: value}.
        patient_id: ``sourceId`` of a patient in ``df_all``.
        patient_row: Alternative to ``patient_id``. A row with the same
            columns as ``df_all``, including ``sourceId``; it is appended to
            the lookup (not to the reference population).
        num_buckets: Number of rank buckets.
        max_observed_days: Inclusion cut-off for the reference population.
        treatment_prefix: Prefix identifying treatment columns.

    Returns:
        dict label -> {"predicted_risk", "percentile_range",
        "bucket_median_days", "bucket_iqr_days", "bucket_ci_days"}.
        ``predicted_risk`` is the model's raw output (higher = higher risk).
        ``bucket_ci_days`` is the P2.5–P97.5 range of observed survival.

    Raises:
        ValueError: if the patient cannot be found.
    """
    import numpy as np
    import pandas as pd

    def apply_treatment(df, mapping, treatment_cols, msi_flag):
        df_copy = df.copy()
        df_copy[treatment_cols] = 0
        for col, val in mapping.items():
            if col in df_copy:
                df_copy[col] = val
        if "hasMsi" in df_copy:
            df_copy["hasMsi"] = msi_flag
        # Update the summary flag only if it is a feature; creating it would
        # add a column and break the model's input shape.
        if "hasTreatment" in df_copy:
            df_copy["hasTreatment"] = (df_copy[treatment_cols].sum(axis=1) > 0).astype(int)
        return df_copy

    def predict_risk(X):
        if hasattr(model, "model") and hasattr(model.model, "predict"):
            raw = model.model.predict(X.values.astype("float32"))
        else:
            raw = model.predict(X)
        return np.asarray(raw).ravel()

    X_all = df_all.drop(columns=["sourceId", settings.event_col, settings.duration_col])
    times_obs = df_all[settings.duration_col].values
    event_flags = df_all[settings.event_col].values
    observed_mask = (event_flags == 1) & (times_obs <= max_observed_days)

    # Negated risk: ascending order puts the highest-risk patient first.
    tpred = -predict_risk(X_all)[observed_mask]
    tobs = times_obs[observed_mask]

    order = np.argsort(tpred)
    tpred_sorted = tpred[order]
    tobs_sorted = tobs[order]
    n = len(tpred_sorted)
    bucket_edges = [int(i * n / num_buckets) for i in range(num_buckets + 1)]

    bucket_stats = []
    for i in range(num_buckets):
        obs_vals = tobs_sorted[bucket_edges[i]:bucket_edges[i + 1]]
        if len(obs_vals) == 0:
            med = q1 = q3 = p025 = p975 = np.nan
        else:
            q1, med, q3 = np.percentile(obs_vals, [25, 50, 75])
            p025, p975 = np.percentile(obs_vals, [2.5, 97.5])
        bucket_stats.append({
            "percentile": f"{int(100 * i / num_buckets)}%–{int(100 * (i+1) / num_buckets)}%",
            "median_days": round(med, 1) if not np.isnan(med) else None,
            "iqr_days": (round(q1, 1), round(q3, 1)) if not np.isnan(q1) else (None, None),
            "ci_days": (round(p025, 1), round(p975, 1)) if not np.isnan(p025) else (None, None)
        })

    if patient_row is not None:
        df_all = pd.concat([df_all, pd.DataFrame([patient_row])], ignore_index=True)
        source_id = patient_row["sourceId"]
    else:
        source_id = patient_id

    patient_df = df_all[df_all["sourceId"] == source_id]
    if patient_df.empty:
        raise ValueError(f"Patient {source_id} not found")

    X_base = patient_df.drop(columns=["sourceId", settings.event_col, settings.duration_col])
    msi_flag = int(X_base.get("hasMsi", pd.Series([0])).iloc[0])
    treatment_cols = [col for col in X_base.columns if col.startswith(treatment_prefix)]

    results = {}
    for label, mapping in treatment_map.items():
        X_mod = apply_treatment(X_base, mapping, treatment_cols, msi_flag)
        raw_risk = predict_risk(X_mod)[0]

        # Rank of this prediction if it were inserted into the sorted reference
        # population (after any ties). This sorts once instead of re-ranking
        # n + 1 values for every treatment.
        rank = int(np.searchsorted(tpred_sorted, -raw_risk, side="right"))

        # rank == n (lower risk than everyone) falls into the last bucket.
        bucket_idx = num_buckets - 1
        for i in range(num_buckets):
            if bucket_edges[i] <= rank < bucket_edges[i + 1]:
                bucket_idx = i
                break

        stats = bucket_stats[bucket_idx]
        results[label] = {
            "predicted_risk": float(raw_risk),
            "percentile_range": stats["percentile"],
            "bucket_median_days": stats["median_days"],
            "bucket_iqr_days": stats["iqr_days"],
            "bucket_ci_days": stats["ci_days"]
        }

    return results


In [None]:
import pandas as pd
import numpy as np

def display_patient_treatment_percentile_info(
    model,
    df_all,
    treatment_map,
    *,
    patient_id=None,
    patient_row=None,
    max_observed_days=10000,
    num_buckets=10,
    treatment_prefix="systemicTreatmentPlan_"
):
    """Print a patient's actual outcome and, per candidate treatment, the
    risk bucket it falls into plus that bucket's observed survival.

    Day values are converted to months at 30.44 days per month.

    The "95% range" is the ``bucket_ci_days`` value from
    ``get_patient_treatment_percentile_info``: the 2.5th–97.5th percentile of
    observed survival in the bucket. It describes the spread of individual
    outcomes, not a confidence interval.

    Args:
        Same as ``get_patient_treatment_percentile_info``.

    Raises:
        ValueError: if the patient cannot be found.
    """
    days_per_month = 30.44

    def months_str(days):
        return f"{days / days_per_month:.1f} mo" if days is not None else "NA"

    def range_str(bounds):
        low, high = bounds
        if low is None or high is None:
            return "NA"
        return f"{low / days_per_month:.1f}–{high / days_per_month:.1f} mo"

    info = get_patient_treatment_percentile_info(
        model=model,
        df_all=df_all,
        treatment_map=treatment_map,
        patient_id=patient_id,
        patient_row=patient_row,
        max_observed_days=max_observed_days,
        num_buckets=num_buckets,
        treatment_prefix=treatment_prefix
    )

    if patient_row is not None:
        source_id = patient_row["sourceId"]
        df_subset = pd.DataFrame([patient_row])
    else:
        source_id = patient_id
        df_subset = df_all[df_all["sourceId"] == source_id]

    if df_subset.empty:
        raise ValueError(f"Patient {source_id} not found")

    patient = df_subset.iloc[0]
    observed_months = round(patient[settings.duration_col] / days_per_month, 1)

    active_treatments = [
        col.replace(treatment_prefix, "")
        for col in df_subset.columns
        if col.startswith(treatment_prefix) and patient[col] == 1
    ]
    treatment_str = ", ".join(active_treatments) if active_treatments else "No Treatment"

    print(f"Patient {source_id}")
    print(f"Actual survival: {observed_months} months")
    print(f"Actual treatment(s): {treatment_str}\n")

    for label, entry in info.items():
        clean_label = label.replace(treatment_prefix, "")
        med_str = months_str(entry.get("bucket_median_days"))
        iqr_str = range_str(entry.get("bucket_iqr_days", (None, None)))
        ci_str = range_str(entry.get("bucket_ci_days", (None, None)))
        print(
            f"{clean_label}: {entry['percentile_range']} risk bucket "
            f"(risk: {entry['predicted_risk']:.4f} | med: {med_str} | IQR: {iqr_str} | 95% range: {ci_str})"
        )


In [None]:
# TODO: set the sourceId of the patient to inspect; the first patient in
# df_all is used as a runnable default.
example_patient_id = df_all["sourceId"].iloc[0]
display_patient_treatment_percentile_info(
    model=trained_models["DeepSurv_attention"],
    df_all=df_all,
    treatment_map=valid_treatment_combinations,
    patient_id=example_patient_id
)


In [None]:
def plot_observed_vs_predicted_risk(
    model,
    df_all,
    *,
    max_observed_days: int = 365,
    num_buckets: int = 10,
    figsize=(10, 6),
    show_fit: bool = False,
    print_errors: bool = True,
    show_bucket_lines: bool = True,
    show_bucket_stats: bool = True
):
    """Scatter observed survival time against the raw predicted risk.

    Only patients with an observed event and a duration
    ``<= max_observed_days`` are plotted. Patients are sorted by predicted
    risk and split into ``num_buckets`` equally sized buckets. Each bucket is
    summarised at the median predicted risk of its members.

    Args:
        model: Wrapper exposing ``model.model.predict`` or an object with
            ``predict``.
        df_all: Features plus ``sourceId``, event and duration columns.
        max_observed_days: Inclusion cut-off on observed duration.
        num_buckets: Number of equally sized buckets.
        figsize: Matplotlib figure size.
        show_fit: Overlay a least-squares line of survival time on risk.
        print_errors: Print how many patients were excluded and why.
        show_bucket_lines: Draw bucket boundaries with their percentile label.
        show_bucket_stats: Draw per-bucket median ± IQR and the P10–P90 range.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator
    import torch

    X = df_all.drop(columns=["sourceId", settings.event_col, settings.duration_col])
    times_obs = df_all[settings.duration_col].values
    event_flags = df_all[settings.event_col].values
    event_mask = event_flags == 1
    window_mask = times_obs <= max_observed_days
    observed_mask = event_mask & window_mask

    if hasattr(model, "model") and hasattr(model.model, "predict"):
        predicted_risk = model.model.predict(X.values.astype("float32"))
    else:
        predicted_risk = model.predict(X)

    if isinstance(predicted_risk, (np.ndarray, torch.Tensor)) and predicted_risk.ndim > 1:
        predicted_risk = predicted_risk.ravel()

    tpred_masked = predicted_risk[observed_mask]
    tobs = times_obs[observed_mask]

    if print_errors:
        n_censored = int((~event_mask).sum())
        n_late = int((event_mask & ~window_mask).sum())
        print(f"Dropped {n_censored} censored patients (event_flag = 0)")
        print(f"Dropped {n_late} patients with an event after {max_observed_days} days")

    if "hasTreatment" in df_all.columns:
        has_treatment = df_all["hasTreatment"].values[observed_mask]
        treated_mask = has_treatment == 1
        untreated_mask = has_treatment == 0
        print(f"Treated patients:   {treated_mask.sum()}")
        print(f"Untreated patients: {untreated_mask.sum()}")
    else:
        # Without treatment information every patient is drawn as "Treated".
        treated_mask = np.ones(len(tpred_masked), dtype=bool)
        untreated_mask = np.zeros(len(tpred_masked), dtype=bool)

    fig, ax = plt.subplots(figsize=figsize, dpi=110)
    # x in data units, y as a fraction of the axes height: annotations stay put
    # when the y-limits autoscale.
    annot_transform = ax.get_xaxis_transform()

    ax.scatter(
        tpred_masked[untreated_mask], tobs[untreated_mask],
        color="red", s=10, alpha=0.6, label="No treatment"
    )
    ax.scatter(
        tpred_masked[treated_mask], tobs[treated_mask],
        color="C0", s=10, alpha=0.6, label="Treated"
    )

    if show_fit and len(tpred_masked) > 1:
        coeffs = np.polyfit(tpred_masked, tobs, 1)
        xfit = np.linspace(min(tpred_masked), max(tpred_masked), 100)
        ax.plot(xfit, np.polyval(coeffs, xfit), ls="--", color="tab:orange", label="Linear fit")

    if (show_bucket_lines or show_bucket_stats) and len(tpred_masked) > 0:
        sorted_idx = np.argsort(tpred_masked)
        tpred_sorted = tpred_masked[sorted_idx]
        tobs_sorted = tobs[sorted_idx]
        treated_sorted = treated_mask[sorted_idx]

        n = len(tpred_sorted)
        bucket_edges = [int(i * n / num_buckets) for i in range(num_buckets + 1)]
        bucket_centers = []
        medians, iqr_low, iqr_high = [], [], []
        p10_vals, p90_vals = [], []

        for i in range(num_buckets):
            start, end = bucket_edges[i], bucket_edges[i + 1]
            bucket_obs = tobs_sorted[start:end]
            if len(bucket_obs) == 0:
                continue

            center_x = np.median(tpred_sorted[start:end])
            bucket_centers.append(center_x)

            q1, med, q3 = np.percentile(bucket_obs, [25, 50, 75])
            p10, p90 = np.percentile(bucket_obs, [10, 90])
            medians.append(med)
            iqr_low.append(med - q1)
            iqr_high.append(q3 - med)
            p10_vals.append(p10)
            p90_vals.append(p90)

            if show_bucket_lines and i > 0:
                ax.axvline(x=tpred_sorted[start], color="gray", ls="--", lw=0.8, alpha=0.6)
                ax.text(tpred_sorted[start], 0.02, f"{100 * i // num_buckets}%",
                        transform=annot_transform,
                        rotation=90, verticalalignment='bottom', fontsize=7, color='gray')

            ax.text(center_x, 0.93,
                    f"n={len(bucket_obs)}\nT={int(treated_sorted[start:end].sum())}",
                    transform=annot_transform, ha="center", fontsize=8, color="black")

        if show_bucket_stats:
            for j, (x, y1, y2) in enumerate(zip(bucket_centers, p10_vals, p90_vals)):
                # The 10th–90th percentile covers the central 80% of patients.
                ax.plot([x, x], [y1, y2], color="black", lw=1.5,
                        label="80% range (P10–P90)" if j == 0 else None)

            if bucket_centers:
                ax.errorbar(
                    bucket_centers, medians,
                    yerr=[iqr_low, iqr_high],
                    fmt='o', color='black', elinewidth=2.5, capsize=5,
                    label="Median ± IQR"
                )

    ax.set_xlabel("Predicted risk (higher = more likely event)")
    ax.set_ylabel("Observed survival time (days)")
    ax.set_title(f"Observed Survival vs. Predicted Risk (buckets={num_buckets})")
    ax.grid(True, ls=":", lw=0.5)
    # Predicted risk is continuous, so only the day axis gets integer ticks.
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax.legend()
    plt.tight_layout()
    plt.show()


In [None]:
# Raw-risk version of the plot above for the attention DeepSurv model.
plot_observed_vs_predicted_risk(
    model=trained_models["DeepSurv_attention"],
    df_all=df_all,
    max_observed_days=10000,
    num_buckets=10,
    show_fit=False,
    show_bucket_lines=True,
    show_bucket_stats=True
)


In [None]:
def get_patient_treatment_risks(
    model,
    df_all,
    treatment_map,
    *,
    patient_id=None,
    patient_row=None,
    treatment_prefix="systemicTreatmentPlan"
):
    """Predict one patient's risk under each treatment combination.

    For every ``label -> {column: value}`` entry in ``treatment_map``:
    - all of the patient's columns starting with ``treatment_prefix`` are set
      to 0;
    - the mapping is applied;
    - the model is evaluated.

    Args:
        model: Wrapper exposing ``model.model.predict`` or an object with
            ``predict``.
        df_all: Features plus ``sourceId``, event and duration columns.
        treatment_map: Treatment label -> {feature column: value}.
        patient_id: ``sourceId`` of a patient in ``df_all``.
        patient_row: Alternative to ``patient_id``: a row with the same
            columns as ``df_all``, including ``sourceId``.
        treatment_prefix: Prefix identifying treatment columns.

    Returns:
        dict label -> {"predicted_risk": float}. The value is the model's raw
        output (higher = higher risk), the same convention as
        ``get_all_patient_treatment_risks`` and
        ``plot_observed_vs_predicted_risk``.

    Raises:
        ValueError: if the patient cannot be found.
    """
    import numpy as np
    import pandas as pd

    def apply_treatment(df, mapping, treatment_cols, msi_flag):
        df_copy = df.copy()
        df_copy[treatment_cols] = 0
        for col, val in mapping.items():
            if col in df_copy:
                df_copy[col] = val
        if "hasMsi" in df_copy:
            df_copy["hasMsi"] = msi_flag
        # Update the summary flag only if it is a feature; creating it would
        # add a column and break the model's input shape.
        if "hasTreatment" in df_copy:
            df_copy["hasTreatment"] = (df_copy[treatment_cols].sum(axis=1) > 0).astype(int)
        return df_copy

    if patient_row is not None:
        df_all = pd.concat([df_all, pd.DataFrame([patient_row])], ignore_index=True)
        source_id = patient_row["sourceId"]
    else:
        source_id = patient_id

    patient_df = df_all[df_all["sourceId"] == source_id]
    if patient_df.empty:
        raise ValueError(f"Patient {source_id} not found")

    X_base = patient_df.drop(columns=["sourceId", settings.event_col, settings.duration_col])
    msi_flag = int(X_base.get("hasMsi", pd.Series([0])).iloc[0])
    treatment_cols = [col for col in X_base.columns if col.startswith(treatment_prefix)]

    results = {}
    for label, mapping in treatment_map.items():
        X_mod = apply_treatment(X_base, mapping, treatment_cols, msi_flag)
        if hasattr(model, "model") and hasattr(model.model, "predict"):
            raw = model.model.predict(X_mod.values.astype("float32"))
        else:
            raw = model.predict(X_mod)
        # No negation: the value is reported as a risk, so keep the model's sign.
        results[label] = {
            "predicted_risk": float(np.asarray(raw).ravel()[0])
        }

    return results


In [None]:
import pandas as pd
import numpy as np

def display_patient_treatment_risks(
    model,
    df_all,
    treatment_map,
    *,
    patient_id=None,
    patient_row=None,
    treatment_prefix="systemicTreatmentPlan_"
):
    """Print a patient's actual outcome and their predicted risk per treatment.

    Survival is shown in months (30.44 days per month). Risks come from
    ``get_patient_treatment_risks``.

    Raises:
        ValueError: if the patient cannot be found.
    """
    risks_by_label = get_patient_treatment_risks(
        model=model,
        df_all=df_all,
        treatment_map=treatment_map,
        patient_id=patient_id,
        patient_row=patient_row,
        treatment_prefix=treatment_prefix
    )

    from_row = patient_row is not None
    source_id = patient_row["sourceId"] if from_row else patient_id
    patient_rows = pd.DataFrame([patient_row]) if from_row else df_all[df_all["sourceId"] == source_id]
    if patient_rows.empty:
        raise ValueError(f"Patient {source_id} not found")

    first = patient_rows.iloc[0]
    observed_months = round(first[settings.duration_col] / 30.44, 1)
    received = [
        name.replace(treatment_prefix, "")
        for name in patient_rows.columns
        if name.startswith(treatment_prefix) and first[name] == 1
    ]
    received_str = ", ".join(received) or "No Treatment"

    for header_line in (
        f"Patient {source_id}",
        f"Actual survival: {observed_months} months",
        f"Actual treatment(s): {received_str}\n",
    ):
        print(header_line)

    for label, entry in risks_by_label.items():
        short_label = label.replace(treatment_prefix, "")
        print(f"{short_label}: predicted risk = {entry['predicted_risk']:.4f}")


In [None]:
# TODO: set the sourceId of the patient to inspect; the first patient in
# df_all is used as a runnable default.
example_patient_id = df_all["sourceId"].iloc[0]
display_patient_treatment_risks(
    model=trained_models["DeepSurv_attention"],
    df_all=df_all,
    treatment_map=valid_treatment_combinations,
    patient_id=example_patient_id
)


In [None]:
def get_all_patient_treatment_risks(
    model,
    df_all,
    treatment_map,
    treatment_prefix="systemicTreatmentPlan"
):
    """Predict every patient's risk under every treatment combination.

    For each ``label -> {column: value}`` entry in ``treatment_map``:
    - all columns starting with ``treatment_prefix`` are set to 0 for every
      patient;
    - the mapping is applied;
    - the model is evaluated on the whole cohort in one batched call.

    Args:
        model: Wrapper exposing ``model.model.predict`` or an object with
            ``predict``.
        df_all: Features plus ``sourceId``, event and duration columns.
        treatment_map: Treatment label -> {feature column: value}.
        treatment_prefix: Prefix identifying treatment columns.

    Returns:
        Long-format DataFrame with one row per (patient, treatment), ordered
        patient-major (all treatments for the first patient first). Columns:
        sourceId, treatment, predicted_risk (raw model output), observed_duration,
        event, actual_treatment (received treatments joined with ", ", or
        "No Treatment").
    """
    import numpy as np
    import pandas as pd
    from tqdm import tqdm

    X_all = df_all.drop(columns=["sourceId", settings.event_col, settings.duration_col])
    treatment_cols = [col for col in X_all.columns if col.startswith(treatment_prefix)]

    def apply_treatment(X, mapping):
        X_mod = X.copy()
        X_mod[treatment_cols] = 0
        for col, val in mapping.items():
            if col in X_mod:
                X_mod[col] = val
        if "hasMsi" in X_mod:
            # Keep each patient's own hasMsi, even if the mapping sets it.
            X_mod["hasMsi"] = X["hasMsi"]
        # Update the summary flag only if it is a feature; creating it would
        # add a column and break the model's input shape.
        if "hasTreatment" in X_mod:
            X_mod["hasTreatment"] = (X_mod[treatment_cols].sum(axis=1) > 0).astype(int)
        return X_mod

    def predict_risk(X):
        if hasattr(model, "model") and hasattr(model.model, "predict"):
            raw = model.model.predict(X.values.astype("float32"))
        else:
            raw = model.predict(X)
        return np.asarray(raw, dtype=float).ravel()

    n_patients = len(df_all)
    labels = list(treatment_map)

    # One batched prediction per treatment instead of one per (patient, treatment).
    risk_matrix = np.empty((n_patients, len(labels)), dtype=float)
    if n_patients > 0:
        for j, label in enumerate(tqdm(labels, desc="Processing treatments")):
            risk_matrix[:, j] = predict_risk(apply_treatment(X_all, treatment_map[label]))

    short_names = [col.replace(treatment_prefix, "") for col in treatment_cols]
    active_flags = df_all[treatment_cols].to_numpy() == 1
    actual_treatment = [
        ", ".join(name for name, on in zip(short_names, flags) if on) or "No Treatment"
        for flags in active_flags
    ]

    n_treatments = len(labels)
    return pd.DataFrame({
        "sourceId": np.repeat(df_all["sourceId"].to_numpy(), n_treatments),
        "treatment": np.tile(np.array(labels, dtype=object), n_patients),
        # Row-major ravel matches the patient-major row order.
        "predicted_risk": risk_matrix.ravel(),
        "observed_duration": np.repeat(df_all[settings.duration_col].to_numpy(), n_treatments),
        "event": np.repeat(df_all[settings.event_col].to_numpy(), n_treatments),
        "actual_treatment": np.repeat(np.array(actual_treatment, dtype=object), n_treatments),
    })


In [None]:
# Long-format table: one row per (patient, counterfactual treatment) with the
# model's raw predicted risk; used by the hazard-ratio cells below.
risk_df = get_all_patient_treatment_risks(
    model=trained_models["DeepSurv_attention"],
    df_all=df_all,
    treatment_map=valid_treatment_combinations,
    treatment_prefix="systemicTreatmentPlan_"
)


print(risk_df.head())


# Hazard ratio vs. baseline risk

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress

df_no_treatment = risk_df[risk_df["treatment"] == "No Treatment"].copy()
df_5fu = risk_df[risk_df["treatment"] == "5-FU"].copy()

# One row per patient with both counterfactual predictions.
df_merged = pd.merge(
    df_no_treatment[["sourceId", "predicted_risk", "actual_treatment"]],
    df_5fu[["sourceId", "predicted_risk"]],
    on="sourceId",
    suffixes=("_no_treatment", "_5fu")
)

# Treats predicted_risk as a log-hazard: the difference is a per-patient log HR.
df_merged["log_HR"] = df_merged["predicted_risk_5fu"] - df_merged["predicted_risk_no_treatment"]
df_merged["HR_5fu_vs_no"] = np.exp(df_merged["log_HR"])

actual_categories = pd.Categorical(df_merged["actual_treatment"])

plt.figure(figsize=(10, 7))
# NOTE: tab10 has 10 colours; with more actual-treatment groups some share a colour.
scatter = plt.scatter(
    df_merged["predicted_risk_no_treatment"],
    df_merged["HR_5fu_vs_no"],
    c=actual_categories.codes,
    cmap="tab10",
    alpha=0.2,
    marker='.'
)

slope, intercept, r_value, p_value, std_err = linregress(
    df_merged["predicted_risk_no_treatment"],
    df_merged["HR_5fu_vs_no"]
)
x_vals = np.linspace(df_merged["predicted_risk_no_treatment"].min(),
                     df_merged["predicted_risk_no_treatment"].max(), 500)
y_vals = slope * x_vals + intercept
plt.plot(x_vals, y_vals, color="black", linestyle="--", label="Trend line")

plt.axhline(y=1.0, color='gray', linestyle='--', linewidth=1.2, label="HR = 1")

# num=None returns one handle per distinct code. The default ("auto") returns
# fewer handles when there are more than 9 groups, which misaligns the labels.
handles, labels = scatter.legend_elements(prop="colors", num=None)
unique_treatments = actual_categories.categories
plt.legend(handles, unique_treatments, title="Actual Treatment", bbox_to_anchor=(1.05, 1), loc="upper left")

plt.xlabel("Predicted Risk (No Treatment)")
plt.ylabel("Hazard Ratio (5-FU vs. No Treatment)")
plt.title("Hazard Ratio vs. Baseline Risk, Colored by Actual Treatment")
plt.grid(True, linestyle=':', linewidth=0.5)
plt.tight_layout()
plt.show()


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Only patients who actually received 5-FU alone or no treatment.
valid_actuals = ["5-FU", "No Treatment"]
filtered_df = risk_df[risk_df["actual_treatment"].isin(valid_actuals)]

no_tx = filtered_df[filtered_df["treatment"] == "No Treatment"]
five_fu = filtered_df[filtered_df["treatment"] == "5-FU"]

merged = pd.merge(
    no_tx[["sourceId", "predicted_risk", "actual_treatment"]],
    five_fu[["sourceId", "predicted_risk"]],
    on="sourceId",
    suffixes=("_no_treatment", "_5fu")
)

plt.figure(figsize=(8, 6))

mask_5fu = merged["actual_treatment"] == "5-FU"
mask_no = merged["actual_treatment"] == "No Treatment"

plt.scatter(
    merged.loc[mask_no, "predicted_risk_no_treatment"],
    merged.loc[mask_no, "predicted_risk_5fu"],
    alpha=0.2, color="darkorange", label="Actual: No Treatment", s=20, marker='.'
)

plt.scatter(
    merged.loc[mask_5fu, "predicted_risk_no_treatment"],
    merged.loc[mask_5fu, "predicted_risk_5fu"],
    alpha=0.2, color="C0", label="Actual: 5-FU", s=20, marker='.'
)

# Draw y = x over the data's own range. A fixed [-5, 5] segment either stops
# short of the points or stretches the axes far beyond them.
risk_pair_cols = ["predicted_risk_no_treatment", "predicted_risk_5fu"]
diag_lo = merged[risk_pair_cols].min().min()
diag_hi = merged[risk_pair_cols].max().max()
plt.plot([diag_lo, diag_hi], [diag_lo, diag_hi], linestyle="--", color="gray", label="y = x")

plt.xlabel("Predicted Risk (No Treatment)")
plt.ylabel("Predicted Risk (5-FU)")
plt.title("Predicted Risk: 5-FU vs. No Treatment\n(Colored by Actual Treatment)")
plt.grid(True, linestyle=":", linewidth=0.5)
plt.legend()
plt.tight_layout()
plt.show()


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# actual_treatment strings join column names with ", ", while treatment_map
# labels use " + "; presumably both spellings name the same combination.
treatments_of_interest = ["No Treatment", "5-FU, oxaliplatin, bevacizumab"]

filtered = risk_df[risk_df["actual_treatment"].isin(treatments_of_interest)].copy()

# One row per patient, one column per counterfactual treatment label.
pivot_df = filtered.pivot(index="sourceId", columns="treatment", values="predicted_risk").reset_index()

actual_treatments = (
    filtered.drop_duplicates(subset=["sourceId"])[["sourceId", "actual_treatment"]]
)
pivot_df = pd.merge(pivot_df, actual_treatments, on="sourceId", how="left")

col_no_tx = "No Treatment"
col_triplet = "5-FU + oxaliplatin + bevacizumab"

missing_cols = [col for col in (col_no_tx, col_triplet) if col not in pivot_df.columns]
if missing_cols:
    raise ValueError(f"Missing prediction for treatment(s): {missing_cols}")

pivot_df = pivot_df.dropna(subset=[col_no_tx, col_triplet])

plt.figure(figsize=(8, 6))

mask_no_tx = pivot_df["actual_treatment"] == "No Treatment"
plt.scatter(
    pivot_df.loc[mask_no_tx, col_no_tx],
    pivot_df.loc[mask_no_tx, col_triplet],
    color="darkorange", alpha=0.3, label="Actual: No Treatment", s=20, marker='.'
)

mask_triplet = pivot_df["actual_treatment"] == "5-FU, oxaliplatin, bevacizumab"
plt.scatter(
    pivot_df.loc[mask_triplet, col_no_tx],
    pivot_df.loc[mask_triplet, col_triplet],
    color="C0", alpha=0.7, label="Actual: 5-FU + oxaliplatin + bevacizumab", s=20, marker='.'
)

# Draw y = x over the data's own range rather than a fixed [-5, 5] segment.
diag_lo = pivot_df[[col_no_tx, col_triplet]].min().min()
diag_hi = pivot_df[[col_no_tx, col_triplet]].max().max()
plt.plot([diag_lo, diag_hi], [diag_lo, diag_hi], linestyle="--", color="gray", label="y = x")

plt.xlabel("Predicted Risk (No Treatment)")
plt.ylabel("Predicted Risk (5-FU + oxaliplatin + bevacizumab)")
plt.title("Predicted Risk Comparison: Triplet Therapy vs. No Treatment")
plt.grid(True, linestyle=":", linewidth=0.5)
plt.legend()
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
from scipy.stats import gaussian_kde

# Smoothed distribution of baseline ("No Treatment") predicted risk. One curve
# is drawn per actually-received treatment with at least `min_group_size`
# patients. Each Gaussian KDE (bandwidth factor 0.3) is multiplied by its group
# size, so curves are on a density x n scale rather than a unit-area scale.
baseline_df = (
    risk_df.loc[risk_df["treatment"] == "No Treatment"]
    .dropna(subset=["predicted_risk", "actual_treatment"])
)
baseline_df = baseline_df.assign(actual_treatment=baseline_df["actual_treatment"].astype(str))

min_group_size = 20
group_counts = baseline_df["actual_treatment"].value_counts()
# Preserve value_counts' ordering (largest group first); it drives colour assignment.
valid_groups = [name for name, count in group_counts.items() if count >= min_group_size]
filtered_df = baseline_df.loc[baseline_df["actual_treatment"].isin(valid_groups)]

plt.figure(figsize=(10, 6))
colors = sns.color_palette("tab10", len(valid_groups))

for group_color, group_name in zip(colors, valid_groups):
    risks = filtered_df.loc[filtered_df["actual_treatment"] == group_name, "predicted_risk"].values
    # gaussian_kde needs at least two distinct values; skip degenerate groups.
    if len(risks) <= 1 or not (np.std(risks) > 0):
        continue
    density = gaussian_kde(risks, bw_method=0.3)
    grid = np.linspace(risks.min() - 0.1, risks.max() + 0.1, 300)
    plt.plot(grid, density(grid) * len(risks), label=group_name, color=group_color, linewidth=2)

plt.title("Smoothed Histogram (Counts) of Baseline Risk by Actual Treatment")
plt.xlabel("Predicted Risk (No Treatment)")
plt.ylabel("Number of Patients (Smoothed)")
plt.grid(True, linestyle=":", linewidth=0.5)
plt.legend(title="Actual Treatment", bbox_to_anchor=(1.05, 1), loc="upper left")
plt.tight_layout()
plt.show()


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress

# Per-patient hazard ratio of "5-FU + oxaliplatin + bevacizumab" vs.
# "No Treatment", plotted against the no-treatment predicted risk. Points are
# coloured by the treatment each patient actually received, and a least-squares
# trend line is overlaid.
df_no_treatment = risk_df[risk_df["treatment"] == "No Treatment"].copy()
df_triplet = risk_df[risk_df["treatment"] == "5-FU + oxaliplatin + bevacizumab"].copy()

df_merged = pd.merge(
    df_no_treatment[["sourceId", "predicted_risk", "actual_treatment"]],
    df_triplet[["sourceId", "predicted_risk"]],
    on="sourceId",
    suffixes=("_no_treatment", "_triplet")
)

# Predicted risks are combined as log-hazards: their difference is used as a
# log hazard ratio and exponentiated.
df_merged["log_HR"] = df_merged["predicted_risk_triplet"] - df_merged["predicted_risk_no_treatment"]
df_merged["HR_triplet_vs_no"] = np.exp(df_merged["log_HR"])

plt.figure(figsize=(10, 7))

# One labelled scatter per actual treatment, so each legend entry is bound to
# its own points. A single scatter with integer colour codes plus
# `legend_elements()` mislabels the legend once there are more than 9 groups,
# because matplotlib then picks ~9 representative colour values rather than one
# per category. Missing treatments would also show up as code -1.
actual_labels = df_merged["actual_treatment"].fillna("Unknown").astype(str)
treatment_groups = sorted(actual_labels.unique())
# tab10 has only 10 distinct colours; use tab20 for larger group counts.
palette = plt.get_cmap("tab10" if len(treatment_groups) <= 10 else "tab20")
for i, group in enumerate(treatment_groups):
    in_group = actual_labels == group
    plt.scatter(
        df_merged.loc[in_group, "predicted_risk_no_treatment"],
        df_merged.loc[in_group, "HR_triplet_vs_no"],
        color=palette(i % palette.N),
        alpha=0.2,
        marker='.',
        s=20,
        label=group,
    )

# Linear trend of HR on baseline risk. It is skipped when the fit is undefined
# (fewer than two points, or all x identical), where linregress would raise.
fit_df = df_merged.dropna(subset=["predicted_risk_no_treatment", "HR_triplet_vs_no"])
if len(fit_df) >= 2 and fit_df["predicted_risk_no_treatment"].nunique() > 1:
    slope, intercept, r_value, p_value, std_err = linregress(
        fit_df["predicted_risk_no_treatment"],
        fit_df["HR_triplet_vs_no"]
    )
    x_vals = np.linspace(fit_df["predicted_risk_no_treatment"].min(),
                         fit_df["predicted_risk_no_treatment"].max(), 500)
    y_vals = slope * x_vals + intercept
    plt.plot(x_vals, y_vals, color="black", linestyle="--", label="Trend line")

plt.axhline(y=1.0, color='gray', linestyle='--', linewidth=1.2, label="HR = 1")

# The legend is built from the labelled artists, so it now includes the trend
# line and HR = 1 reference alongside the treatment groups.
plt.legend(title="Actual Treatment", bbox_to_anchor=(1.05, 1), loc="upper left")

plt.xlabel("Predicted Risk (No Treatment)")
plt.ylabel("Hazard Ratio (Triplet vs. No Treatment)")
plt.title("Hazard Ratio vs. Baseline Risk (Triplet Therapy), Colored by Actual Treatment")
plt.grid(True, linestyle=':', linewidth=0.5)
plt.tight_layout()
plt.show()


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress

# Hazard ratio of "5-FU + oxaliplatin + bevacizumab" vs. "No Treatment" against
# the no-treatment predicted risk. Only patients whose actual treatment was one
# of those two options are included.
df_no_treatment = risk_df[risk_df["treatment"] == "No Treatment"].copy()
df_triplet = risk_df[risk_df["treatment"] == "5-FU + oxaliplatin + bevacizumab"].copy()

df_merged = pd.merge(
    df_no_treatment[["sourceId", "predicted_risk", "actual_treatment"]],
    df_triplet[["sourceId", "predicted_risk"]],
    on="sourceId",
    suffixes=("_no_treatment", "_triplet")
)

# `actual_treatment` uses comma-separated drug names (unlike the " + " form of
# the scored `treatment` column).
valid_actuals = ["No Treatment", "5-FU, oxaliplatin, bevacizumab"]
df_merged = df_merged[df_merged["actual_treatment"].isin(valid_actuals)].copy()

# Difference of predicted risks is used as a log hazard ratio.
df_merged["log_HR"] = df_merged["predicted_risk_triplet"] - df_merged["predicted_risk_no_treatment"]
df_merged["HR_triplet_vs_no"] = np.exp(df_merged["log_HR"])

plt.figure(figsize=(10, 7))

# One labelled scatter per actual treatment, so plt.legend() can also list the
# trend line and HR = 1 reference. Passing explicit handles from
# legend_elements() dropped those two entries. Colours follow the convention of
# the triplet-vs-no-treatment risk scatter (No Treatment = darkorange,
# triplet = C0).
actual_colors = {"5-FU, oxaliplatin, bevacizumab": "C0", "No Treatment": "darkorange"}
for actual, color in actual_colors.items():
    in_group = df_merged["actual_treatment"] == actual
    if not in_group.any():
        continue
    plt.scatter(
        df_merged.loc[in_group, "predicted_risk_no_treatment"],
        df_merged.loc[in_group, "HR_triplet_vs_no"],
        color=color,
        alpha=0.2,
        marker='.',
        s=20,
        label=actual,
    )

# Linear trend of HR on baseline risk. It is skipped when linregress would
# raise (fewer than two points, or all x identical).
fit_df = df_merged.dropna(subset=["predicted_risk_no_treatment", "HR_triplet_vs_no"])
if len(fit_df) >= 2 and fit_df["predicted_risk_no_treatment"].nunique() > 1:
    slope, intercept, r_value, p_value, std_err = linregress(
        fit_df["predicted_risk_no_treatment"],
        fit_df["HR_triplet_vs_no"]
    )
    x_vals = np.linspace(fit_df["predicted_risk_no_treatment"].min(),
                         fit_df["predicted_risk_no_treatment"].max(), 500)
    y_vals = slope * x_vals + intercept
    plt.plot(x_vals, y_vals, color="black", linestyle="--", label="Trend line")

plt.axhline(y=1.0, color='gray', linestyle='--', linewidth=1.2, label="HR = 1")

plt.legend(title="Actual Treatment", bbox_to_anchor=(1.05, 1), loc="upper left")

plt.xlabel("Predicted Risk (No Treatment)")
plt.ylabel("Hazard Ratio (Triplet vs. No Treatment)")
plt.title("Hazard Ratio vs. Baseline Risk\n(Only Actual No Tx or Triplet)")
plt.grid(True, linestyle=':', linewidth=0.5)
plt.tight_layout()
plt.show()


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress

# Hazard ratio (triplet vs. no treatment) against the no-treatment predicted
# risk, stratified into fixed bands of age at metastatic diagnosis.
triplet_col = "5-FU + oxaliplatin + bevacizumab"
subset = risk_df[risk_df["treatment"].isin(["No Treatment", triplet_col])].copy()

# Wide table: one row per patient, one predicted-risk column per treatment.
# Age is attached after pivoting. Merging it in beforehand is discarded by the
# pivot and can create duplicate (sourceId, treatment) entries that make it fail.
pivot = subset.pivot(index="sourceId", columns="treatment", values="predicted_risk").reset_index()

# risk_df holds several rows per sourceId (one per scored treatment). Reduce it
# to the first row per patient before merging, instead of multiplying rows and
# de-duplicating afterwards.
actual_lookup = risk_df[["sourceId", "actual_treatment"]].drop_duplicates("sourceId")
pivot = pivot.merge(actual_lookup, on="sourceId", how="left")

pivot = pivot.merge(
    df_all[["sourceId", "ageAtMetastaticDiagnosis"]],
    on="sourceId",
    how="left"
)

pivot = pivot.dropna(subset=["No Treatment", triplet_col, "ageAtMetastaticDiagnosis"])

# Difference of predicted risks is used as a log hazard ratio.
pivot["log_HR"] = pivot[triplet_col] - pivot["No Treatment"]
pivot["HR_triplet_vs_no"] = np.exp(pivot["log_HR"])

# Half-open bands [0, 60), [60, 70), [70, inf). The previous right-closed bins
# (0, 60] put an age of exactly 60 into the "<60" band.
pivot["age_group"] = pd.cut(
    pivot["ageAtMetastaticDiagnosis"],
    bins=[0, 60, 70, np.inf],
    labels=["<60", "60–69", "≥70"],
    right=False,
)

print(pivot["age_group"].value_counts(dropna=False))

plt.figure(figsize=(10, 7))
colors = {"<60": "C0", "60–69": "C1", "≥70": "C2"}

# observed=True: only iterate bands that actually contain patients (no empty
# legend entries) and avoid pandas' categorical-groupby FutureWarning.
for group, df_group in pivot.groupby("age_group", observed=True):
    plt.scatter(
        df_group["No Treatment"],
        df_group["HR_triplet_vs_no"],
        label=f"Age {group}",
        alpha=0.5,
        marker=".",
        s=20,
        color=colors.get(group, "gray")
    )

plt.axhline(y=1.0, color="gray", linestyle="--", linewidth=1.2, label="HR = 1")
plt.xlabel("Predicted Risk (No Treatment)")
plt.ylabel("Hazard Ratio (Triplet vs. No Treatment)")
plt.title("Hazard Ratio vs. Baseline Risk\nStratified by Age at Metastatic Diagnosis")
plt.legend(title="Age Group")
plt.grid(True, linestyle=":", linewidth=0.5)
plt.tight_layout()
plt.show()


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress

# Hazard ratio (triplet vs. no treatment) against the no-treatment predicted
# risk, stratified by quartiles of age at metastatic diagnosis.
triplet_col = "5-FU + oxaliplatin + bevacizumab"
subset = risk_df[risk_df["treatment"].isin(["No Treatment", triplet_col])].copy()

# Wide table: one row per patient, one predicted-risk column per treatment.
# Age is attached after pivoting. Merging it in beforehand is discarded by the
# pivot and can create duplicate (sourceId, treatment) entries that make it fail.
pivot = subset.pivot(index="sourceId", columns="treatment", values="predicted_risk").reset_index()

# risk_df holds several rows per sourceId (one per scored treatment). Reduce it
# to the first row per patient before merging, instead of multiplying rows and
# de-duplicating afterwards.
actual_lookup = risk_df[["sourceId", "actual_treatment"]].drop_duplicates("sourceId")
pivot = pivot.merge(actual_lookup, on="sourceId", how="left")

pivot = pivot.merge(
    df_all[["sourceId", "ageAtMetastaticDiagnosis"]],
    on="sourceId",
    how="left"
)

pivot = pivot.dropna(subset=["No Treatment", triplet_col, "ageAtMetastaticDiagnosis"])

# Difference of predicted risks is used as a log hazard ratio.
pivot["log_HR"] = pivot[triplet_col] - pivot["No Treatment"]
pivot["HR_triplet_vs_no"] = np.exp(pivot["log_HR"])

# Equal-count age bins, computed on the patients that remain after the dropna above.
pivot["age_quartile"] = pd.qcut(
    pivot["ageAtMetastaticDiagnosis"],
    q=4,
    labels=["Q1 (youngest)", "Q2", "Q3", "Q4 (oldest)"]
)

print(pivot["age_quartile"].value_counts(dropna=False))

plt.figure(figsize=(10, 7))
quartile_colors = {
    "Q1 (youngest)": "C0",
    "Q2": "C1",
    "Q3": "C2",
    "Q4 (oldest)": "C3"
}

# observed=True: iterate only quartiles that contain patients and avoid pandas'
# categorical-groupby FutureWarning.
for group, df_group in pivot.groupby("age_quartile", observed=True):
    plt.scatter(
        df_group["No Treatment"],
        df_group["HR_triplet_vs_no"],
        label=group,
        alpha=0.1,
        marker=".",
        s=20,
        color=quartile_colors.get(group, "gray")
    )

plt.axhline(y=1.0, color="gray", linestyle="--", linewidth=1.2, label="HR = 1")
plt.xlabel("Predicted Risk (No Treatment)")
plt.ylabel("Hazard Ratio (Triplet vs. No Treatment)")
plt.title("Hazard Ratio vs. Baseline Risk\nStratified by Age Quartiles")
plt.legend(title="Age Quartile")
plt.grid(True, linestyle=":", linewidth=0.5)
plt.tight_layout()
plt.show()
