# Prepare data

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import ast
import matplotlib.pyplot as plt
import dill
import torch
import nbimporter
import shap

import os
import sys

# Work from the prediction project root so the relative path below resolves.
# NOTE(review): hard-coded absolute path; only valid on the original workspace.
os.chdir('/data/repos/actin-personalization/prediction')
# Make the project's Python sources importable as top-level packages.
sys.path.insert(0, os.path.abspath("src/main/python"))

# NOTE(review): star import; presumably supplies ExperimentConfig and ModelTrainer,
# which are used later without an explicit import — confirm.
from models import *
from data.data_processing import DataSplitter, DataPreprocessor
from data.lookups import lookup_manager
from utils.settings import settings
from src.main.python.analysis.predictive_algorithms_training import get_data, plot_different_models_survival_curves

# Preprocessor bound to the database configured in settings.
preprocessor = DataPreprocessor(settings.db_config_path, settings.db_name)

In [None]:
# Unpack the data returned by get_data (the order must match its return tuple); the names
# suggest the full frame, train/test features and targets, and encoded column names — confirm.
df, X_train, X_test, y_train, y_test, encoded_columns = get_data()

In [None]:
def get_preprocessed_data_with_sourceId(preprocessor):
    """Load and preprocess the raw data, then re-attach each row's ``sourceId``.

    ``preprocessor.preprocess_data`` is called with ``lookup_manager.features``.
    ``sourceId`` is then copied back from the raw frame by index label, so every
    index label of the preprocessed frame must also exist in the raw frame.

    Args:
        preprocessor: Object providing ``load_data()`` and ``preprocess_data()``.

    Returns:
        tuple: ``(df_raw, df_all, updated_features)`` — the raw frame, the
        preprocessed frame (with a ``sourceId`` column added), and the feature
        list returned by ``preprocess_data``.
    """
    raw = preprocessor.load_data()
    processed, features, _ignored = preprocessor.preprocess_data(lookup_manager.features, df=raw)

    # Label-based lookup: raises KeyError if preprocessing produced index labels not in ``raw``.
    processed["sourceId"] = raw.loc[processed.index, "sourceId"]

    return raw, processed, features


df_raw, df_all, updated_features = get_preprocessed_data_with_sourceId(preprocessor)

In [None]:
def load_trained_model(model_name, model_class, model_kwargs=None):
    """Load a previously saved survival model from ``settings.save_path``.

    The file is ``<save_path>/<outcome>_<model_name>`` with suffix:

    * ``.pkl`` — a dill-pickled whole model, for the classical models named in
      the tuple below;
    * ``.pt``  — a torch checkpoint dict for every other model. Its
      ``'net_state'`` entry is loaded into ``model_class(**model_kwargs)``, and
      its optional ``'labtrans'`` and baseline-hazard entries are restored too.

    Args:
        model_name: Model key; also part of the file name.
        model_class: Wrapper class instantiated for torch-based models (unused
            for the pickled ones).
        model_kwargs: Keyword arguments for ``model_class``. ``None`` (default)
            means no arguments.

    Returns:
        The loaded model. For torch-based models the network is put in eval mode.
    """
    model_file_prefix = os.path.join(settings.save_path, f"{settings.outcome}_{model_name}")

    if model_name in ('CoxPH', 'RandomSurvivalForest', 'GradientBoosting', 'AalenAdditive'):
        sk_file = model_file_prefix + ".pkl"
        # NOTE(review): dill.load can execute arbitrary code; only load trusted files.
        with open(sk_file, "rb") as f:
            model = dill.load(f)
        print(f"Model {model_name} loaded from {sk_file}")
        return model

    nn_file = model_file_prefix + ".pt"
    # None instead of a shared mutable {} default.
    model = model_class(**(model_kwargs or {}))

    # NOTE(review): the checkpoint may hold non-tensor objects (e.g. 'labtrans');
    # torch >= 2.6 defaults to weights_only=True, which may refuse to load them — confirm
    # against the installed torch version.
    state = torch.load(nn_file, map_location=torch.device('cpu'))
    model.model.net.load_state_dict(state['net_state'])

    if 'labtrans' in state:
        # Restore the label transform and use its cut points as the model's duration index.
        model.labtrans = state['labtrans']
        model.model.duration_index = model.labtrans.cuts

    if 'baseline_hazards' in state:
        model.model.baseline_hazards_ = state['baseline_hazards']
        model.model.baseline_cumulative_hazards_ = state['baseline_cumulative_hazards']
        print(f"Baseline hazards loaded for {model_name}.")

    model.model.net.eval()
    print(f"Model {model_name} loaded from {nn_file}")
    return model
    
def load_all_trained_models(X_train):
    """Load every model listed in the experiment config, skipping failures.

    For each configured model, ``load_trained_model`` is called. If loading
    fails, the error is printed and the model is skipped. After a successful
    load, ``ModelTrainer._set_attention_indices`` is called with the training
    feature names. If that step fails, the model is kept, and the failure is
    reported with its own message.

    Args:
        X_train: Training feature frame; only its column names are used.

    Returns:
        dict: model name -> loaded model.
    """
    loaded_models = {}
    config_mgr = ExperimentConfig(settings.json_config_file)
    loaded_configs = config_mgr.load_model_configs()
    feature_names = list(X_train.columns)

    for model_name, (model_class, model_kwargs) in loaded_configs.items():
        print(model_name, model_class)
        try:
            loaded_model = load_trained_model(
                model_name=model_name,
                model_class=model_class,
                model_kwargs=model_kwargs
            )
        except Exception as exc:  # best-effort: one bad file should not abort the notebook
            print(f'Could not load: {model_name} ({type(exc).__name__}: {exc})')
            continue

        # Store before configuring attention so a failure there does not discard the model.
        loaded_models[model_name] = loaded_model
        try:
            ModelTrainer._set_attention_indices(loaded_model, feature_names)
        except Exception as exc:
            print(f'Could not set attention indices for: {model_name} ({type(exc).__name__}: {exc})')

    return loaded_models

In [None]:
# Load all configured models; models that fail to load are skipped with a printed message.
trained_models = load_all_trained_models(X_train)

# Risk-stratified heterogeneity of treatment effect analysis

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.linear_model import LogisticRegression
from lifelines import CoxPHFitter


In [None]:
# Score the test set with the attention-based DeepSurv model and collect, per row,
# the predicted risk together with the observed survival outcome.
model = trained_models["DeepSurv_attention"]
X_input = X_test.copy()

if hasattr(model, "model") and hasattr(model.model, "predict"):
    # Wrapper with an inner model: call it directly on a float32 array.
    predicted_risks = model.model.predict(X_input.values.astype("float32"))
else:
    predicted_risks = model.predict(X_input)

# Normalise to a flat NumPy vector. pandas does not reliably turn a torch.Tensor
# (possibly on GPU or requiring grad) into a numeric column.
if isinstance(predicted_risks, torch.Tensor):
    predicted_risks = predicted_risks.detach().cpu().numpy()
predicted_risks = np.asarray(predicted_risks).ravel()

risk_df = pd.DataFrame({
    # NOTE(review): assumes X_test's index labels address the same patients in df_all,
    # which comes from a separate preprocessing run — confirm.
    "sourceId": df_all.loc[X_input.index, "sourceId"].values,
    "predicted_risk": predicted_risks,
    "duration": y_test["survivalDaysSinceMetastaticDiagnosis"],
    "event": y_test["hadSurvivalEvent"]
})


# Step 1: General definition of the research aim

The typical research aim is: “to compare the effect of a target treatment with that of a comparator treatment in patients with a given disease, with respect to outcomes $O_1, \ldots, O_n$”.

We use a comparative cohort design. This means that at least three cohorts of patients need to be defined at this stage of the framework:

A single treatment cohort ($T$), which includes patients with the disease who receive the target treatment of interest.

A single comparator cohort ($C$), which includes patients with the disease who receive the comparator treatment.

One or more outcome cohorts ($O_1, \ldots, O_n$), which contain patients who develop the outcomes of interest.

# Step 2: Identification of the databases

Including multiple databases that represent the population of interest in our analyses potentially increases the generalizability of the results. Furthermore, the cohorts should preferably have an adequate sample size and adequate follow-up time to ensure precise effect estimation, even within smaller risk strata. Other relevant issues, such as the depth of data capture (the precision with which measurements, lab tests, and conditions are recorded) and the reliability of data entry, should also be considered.

In [None]:
# Covariates for the propensity model: all columns of df_all except identifiers,
# treatment, outcome, columns prefixed "systemicTreatmentPlan", and the derived
# risk/propensity columns that later cells add to df_all. The derived columns are
# excluded so that re-running this cell after those cells does not leak them in.
exclude_cols = [
    'sourceId', 'hasTreatment', 'hadSurvivalEvent',
    'survivalDaysSinceMetastaticDiagnosis', 'predicted_risk', 'risk_score', 'risk_group',
    'propensity_score', 'preference_score', 'ps_bin',
]
covariates = [
    col for col in df_all.columns
    if col not in exclude_cols and not col.startswith("systemicTreatmentPlan")
]


# Step 3: Prediction

In [None]:
# Keep only patients present in risk_df (inner join on sourceId) and attach their
# predicted risk; ``risk_score`` is a copy of ``predicted_risk``.
df_all = (
    df_all
    .merge(risk_df[['sourceId', 'predicted_risk']], on='sourceId', how='inner')
    .assign(risk_score=lambda frame: frame['predicted_risk'])
)

In [None]:
# Split patients into predicted-risk tertiles (equal-frequency bins via pd.qcut).
df_all["risk_group"] = pd.qcut(df_all["predicted_risk"], q=3, labels=["Low", "Medium", "High"])


# Step 4: Estimation

In [None]:
from sklearn.linear_model import LogisticRegression

# Propensity model: probability of treatment given the baseline covariates.
# NOTE(review): fit once on all rows of df_all, not separately within each risk group.
X = df_all[covariates]
y = df_all["hasTreatment"]

ps_model = LogisticRegression(solver="liblinear", max_iter=1000)
ps_model.fit(X, y)
df_all["propensity_score"] = ps_model.predict_proba(X)[:, 1]

# Preference score (Walker et al., 2013): the propensity score with the overall
# treatment prevalence P removed on the logit scale,
#     logit(F) = logit(PS) - logit(P).
# The equivalent log-free form below stays finite for PS in {0, 1}:
#     F = PS * (1 - P) / (PS * (1 - P) + (1 - PS) * P).
# F == 0.5 exactly when a patient's PS equals the prevalence.
treatment_rate = y.mean()
ps = df_all["propensity_score"]
df_all["preference_score"] = ps * (1 - treatment_rate) / (
    ps * (1 - treatment_rate) + (1 - ps) * treatment_rate
)


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Overlap diagnostic: preference-score densities of treated vs. control patients,
# one panel per predicted-risk group.
g = sns.FacetGrid(
    df_all,
    col="risk_group",
    hue="hasTreatment",
    hue_order=[0, 1],  # fixes legend order: control (0) first, then treated (1)
    palette={0: "orange", 1: "blue"},
    height=4, aspect=1.2,
    sharex=True, sharey=True
)
g.map_dataframe(sns.kdeplot, x="preference_score", fill=True, alpha=0.4)

# FacetGrid.add_legend hands its own handles and labels to Figure.legend, so an extra
# ``labels=`` keyword conflicts with them; rename the legend entries afterwards instead.
g.add_legend(title="Treatment")
for text, label in zip(g.legend.texts, ["Control", "Treated"]):
    text.set_text(label)

g.set_axis_labels("Preference Score", "Density")
g.set_titles("Risk Group: {col_name}")

# The grid's tight_layout leaves room for the out-of-axes legend. The top margin is
# adjusted after it, so the space reserved for the suptitle is not undone.
g.tight_layout()
g.fig.subplots_adjust(top=0.85)
g.fig.suptitle("Preference Score Overlap by Predicted Risk Group")
plt.show()


In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

def compute_smd(x_treated, x_control):
    """Absolute standardized mean difference (SMD) between two samples.

    Uses the standard covariate-balance definition

        |mean_t - mean_c| / sqrt((var_t + var_c) / 2),

    which averages the within-group variances. (The SD of the concatenated
    sample would include the between-group mean difference in the denominator
    and so understate imbalance.)

    NaNs are ignored. Variances use ddof=0, so single-observation groups (common
    in small propensity-score bins) give a finite value instead of NaN.

    Args:
        x_treated: 1-D array-like of numeric or boolean values for the treated group.
        x_control: 1-D array-like of numeric or boolean values for the control group.

    Returns:
        float: The absolute SMD. Special cases:

        * ``0.0`` when both groups are constant and equal;
        * ``inf`` when both are constant but differ (complete separation);
        * ``nan`` when either group has no non-missing values.
    """
    treated = np.asarray(x_treated, dtype=float).ravel()
    control = np.asarray(x_control, dtype=float).ravel()
    treated = treated[~np.isnan(treated)]
    control = control[~np.isnan(control)]
    if treated.size == 0 or control.size == 0:
        return np.nan

    mean_diff = abs(treated.mean() - control.mean())
    sd_pooled = np.sqrt((treated.var() + control.var()) / 2.0)
    if sd_pooled == 0:
        return 0.0 if mean_diff == 0 else np.inf
    return mean_diff / sd_pooled

# Covariate balance per risk group: SMD of each covariate before and after
# stratifying on within-group preference-score quintiles.
# NOTE(review): "after" is the unweighted mean of the per-bin SMDs, so bins of very
# different size count equally — confirm this is the intended summary.
risk_order = ["Low", "Medium", "High"]
smd_data = []

for group in df_all["risk_group"].unique():
    group_data = df_all[df_all["risk_group"] == group].copy()
    group_data["ps_bin"] = pd.qcut(group_data["preference_score"], q=5, labels=False, duplicates="drop")

    for cov in covariates:
        treated = group_data[group_data["hasTreatment"] == 1][cov]
        control = group_data[group_data["hasTreatment"] == 0][cov]
        smd_before = compute_smd(treated.values, control.values)

        # Only bins containing both treated and control patients contribute.
        smd_bins = []
        for b in group_data["ps_bin"].dropna().unique():
            bin_data = group_data[group_data["ps_bin"] == b]
            t_bin = bin_data[bin_data["hasTreatment"] == 1][cov]
            c_bin = bin_data[bin_data["hasTreatment"] == 0][cov]
            if len(t_bin) > 0 and len(c_bin) > 0:
                smd_bins.append(compute_smd(t_bin.values, c_bin.values))
        smd_after = np.mean(smd_bins) if smd_bins else np.nan

        smd_data.append({
            "Risk Group": group,
            "Covariate": cov,
            "SMD Before": smd_before,
            "SMD After": smd_after
        })

smd_df = pd.DataFrame(smd_data)

# Facets in Low/Medium/High order rather than order of first appearance.
g = sns.FacetGrid(smd_df, col="Risk Group", col_order=risk_order, height=4, aspect=1)
g.map_dataframe(sns.scatterplot, x="SMD Before", y="SMD After", alpha=0.7)

# y = x reference line covering the finite data range (at least [0, 1]).
smd_values = smd_df[["SMD Before", "SMD After"]].to_numpy(dtype=float)
finite_smd = smd_values[np.isfinite(smd_values)]
diag_max = max(1.0, float(finite_smd.max())) if finite_smd.size else 1.0
for ax in g.axes.flat:
    ax.plot([0, diag_max], [0, diag_max], ls="--", color="gray")

g.set_axis_labels("Before stratification on PS", "After stratification on PS")
g.set_titles("Risk Group: {col_name}")
g.fig.suptitle("Covariate Balance Before vs. After PS Stratification", y=1.05)
plt.tight_layout()
plt.show()


In [None]:
import statsmodels.api as sm
from lifelines import CoxPHFitter

# Treatment effect per risk group: a Cox model of survival on treatment, adjusted for
# preference-score quintile. Quintiles are formed within each risk group, which is the
# same stratification whose covariate balance is assessed in the SMD cell above.
df_all["ps_bin"] = (
    df_all.groupby("risk_group", observed=True)["preference_score"]
    .transform(lambda s: pd.qcut(s, q=5, labels=False, duplicates="drop"))
)

results = []

for group in df_all["risk_group"].unique():
    df_rg = df_all[df_all["risk_group"] == group].copy()

    # Categorical, so the quintile can enter as a factor rather than a linear term.
    # NOTE(review): how lifelines encodes a category column depends on its version — verify.
    df_rg["ps_bin"] = df_rg["ps_bin"].astype("category")
    cph = CoxPHFitter()

    cols = ["hasTreatment", "ps_bin"]
    cph.fit(
        df_rg[["survivalDaysSinceMetastaticDiagnosis", "hadSurvivalEvent"] + cols],
        duration_col="survivalDaysSinceMetastaticDiagnosis",
        event_col="hadSurvivalEvent"
    )

    hr = cph.hazard_ratios_["hasTreatment"]
    # confidence_intervals_ are on the coefficient (log-HR) scale; exponentiate.
    ci_log = cph.confidence_intervals_.loc["hasTreatment"]
    ci_lower = np.exp(ci_log["95% lower-bound"])
    ci_upper = np.exp(ci_log["95% upper-bound"])

    results.append({
        "Risk Group": group,
        "HR": hr,
        "CI Lower": ci_lower,
        "CI Upper": ci_upper
    })

hr_df = pd.DataFrame(results)



In [None]:
import matplotlib.pyplot as plt

# Plot of the per-risk-group hazard ratios with their 95% confidence intervals.
order = ["Low", "Medium", "High"]
hr_df["Risk Group"] = pd.Categorical(hr_df["Risk Group"], categories=order, ordered=True)
hr_df = hr_df.sort_values("Risk Group")

# Asymmetric error bars: distance from the HR down to the lower bound and up to the upper bound.
lower_err = hr_df["HR"] - hr_df["CI Lower"]
upper_err = hr_df["CI Upper"] - hr_df["HR"]
# The y-axis always reaches at least 2 so the HR = 1 reference line sits mid-plot.
y_top = max(hr_df["CI Upper"].max(), 2)

plt.figure(figsize=(8, 5))
plt.errorbar(
    x=hr_df["Risk Group"],
    y=hr_df["HR"],
    yerr=[lower_err, upper_err],
    fmt='o',
    capsize=5,
    label="Hazard Ratio"
)
plt.axhline(1.0, color='gray', linestyle='--')  # HR = 1: no difference between arms
plt.title("Heterogeneity of Treatment Effect by Risk Group")
plt.ylabel("Hazard Ratio (Treated vs. Untreated)")
plt.xlabel("Baseline Risk Group")
plt.ylim(0, y_top)
plt.tight_layout()
plt.show()
