In [7]:
import os, math, heapq

import sys
import warnings
from pathlib import Path

import cloudpickle
import cvxpy as cp
import numpy as np
import pandas as pd
import torch

# OWN MODULES
from organsync.experiments.data.data_module import UKRegDataModule

warnings.filterwarnings("ignore")

In [8]:
# SETUP DATA
# Builds the UKReg data module that every later cell reads through `dm`.

batch_size = 256

# Resolved relative to the notebook's working directory.
root_data_dir = Path("../datasets").absolute()

data_dir = root_data_dir / "processed_UKReg"  # your path to UKReg

dm = UKRegDataModule(data_dir, batch_size=batch_size, control=False)
# Later cells read dm.x_cols, dm.o_cols, dm.scaler and dm._train_processed;
# NOTE(review): presumably these are populated by the two calls below — confirm.
dm.prepare_data()
dm.setup(stage="fit")

In [11]:
# Load the previously saved inference object used by generate_interm_surv_data.
# NOTE(security): unpickling can execute arbitrary code embedded in the file;
# only load artifacts that you produced yourself or otherwise trust.
with open("models/organite_inference.p", "rb") as f:
    organite_inf = cloudpickle.load(f)

In [12]:
from adjutorium.utils.tester import evaluate_survival_estimator
from adjutorium.plugins.prediction import Predictions

# Instantiated with category="risk_estimation".
# NOTE(review): `predictions` is not referenced by any later cell visible
# here — confirm it is still needed.
predictions = Predictions(category="risk_estimation")


def generate_surv_data(working_df, use_organ):
    """Split a processed frame into (features, times, outcome) for survival fitting.

    Args:
        working_df: DataFrame containing the columns listed in ``dm.x_cols``
            (plus ``dm.o_cols`` when ``use_organ`` is true), as well as
            ``"Y"`` and ``"PCENS"``.
        use_organ: when true, the ``dm.o_cols`` columns are appended to the
            ``dm.x_cols`` columns.

    Returns:
        A tuple ``(X, T, Y)`` restricted to rows whose rescaled time is
        strictly positive:
        ``X`` - selected feature columns with ``"CENS"`` removed,
        ``T`` - ``"Y"`` mapped through ``scale_[-1] * Y + mean_[-1]`` of
        ``dm.scaler`` (presumably undoing the standardisation of the target
        - confirm against the data module),
        ``Y`` - the ``"PCENS"`` column.
    """
    feature_names = [*dm.x_cols, *dm.o_cols] if use_organ else [*dm.x_cols]
    features = working_df[feature_names].drop(columns=["CENS"])

    # Affine rescaling of the target using the scaler's last scale/mean entries.
    times = dm.scaler.scale_[-1] * working_df["Y"] + dm.scaler.mean_[-1]
    outcome = working_df["PCENS"]

    # One shared mask keeps the three outputs row-aligned.
    positive_time = times > 0
    return features[positive_time], times[positive_time], outcome[positive_time]


def generate_interm_surv_data(working_df, use_organ):
    """Build survival inputs whose features are the model's intermediate representation.

    The selected columns are passed through
    ``organite_inf.model.representation`` and the result is returned as the
    feature matrix. Previously the representation was computed and then
    discarded, so the raw input columns were returned instead.

    Args:
        working_df: DataFrame containing the columns listed in ``dm.x_cols``
            (plus ``dm.o_cols`` when ``use_organ`` is true), as well as
            ``"Y"`` and ``"PCENS"``.
        use_organ: when true, the ``dm.o_cols`` columns are appended to the
            ``dm.x_cols`` columns before they are fed to the model.

    Returns:
        A tuple ``(interm_X, T, Y)`` restricted to rows whose rescaled time is
        strictly positive:
        ``interm_X`` - DataFrame of representation values, indexed like
        ``working_df``, with integer column labels,
        ``T`` - ``"Y"`` mapped through ``scale_[-1] * Y + mean_[-1]`` of
        ``dm.scaler``,
        ``Y`` - the ``"PCENS"`` column.
    """
    cols = list(dm.x_cols)
    if use_organ:
        cols += list(dm.o_cols)
    # Unlike generate_surv_data, "CENS" is kept: the selected columns are fed
    # to the model unchanged.
    X = working_df[cols]

    representation = organite_inf.model.representation(
        torch.from_numpy(np.asarray(X))
    )
    # NOTE(review): assumes `representation` returns a 2-D tensor with one row
    # per input row, in the same order — confirm against the model.
    # detach() drops any autograd graph so the values can be exported to numpy.
    interm_X = pd.DataFrame(representation.detach().cpu().numpy(), index=X.index)

    T = working_df["Y"]
    T = dm.scaler.scale_[-1] * T + dm.scaler.mean_[-1]

    Y = working_df["PCENS"]

    # One shared mask keeps the three outputs row-aligned.
    keep = T > 0
    return interm_X[keep], T[keep], Y[keep]


# Processed training frame exposed by the data module.
full_df = dm._train_processed

# Rows with CENS == 0 feed the "transplant" set (organ columns included);
# rows with CENS == 1 feed the "no transplant" set (organ columns excluded).
transplant_data = generate_surv_data(
    full_df.loc[full_df["CENS"].eq(0)], use_organ=True
)
no_transplant_data = generate_surv_data(
    full_df.loc[full_df["CENS"].eq(1)], use_organ=False
)

# Same CENS == 0 subset, routed through the model-based feature builder.
interm_transplant_data = generate_interm_surv_data(
    full_df.loc[full_df["CENS"].eq(0)], use_organ=True
)

# Every row, organ columns included.
full_data = generate_surv_data(full_df, use_organ=True)

In [14]:
from organsync.survival_analysis.xgboost import XGBoostRiskEstimation
from tabulate import tabulate

# Column titles for the results table printed by eval_xgboost.
headers = ["Horizon", "C-INDEX", "Brier score", "AUCROC"]


def eval_xgboost(X, T, Y):
    """Evaluate an XGBoost risk estimator horizon by horizon and print a table.

    A single ``XGBoostRiskEstimation(objective="cox",
    strategy="debiased_bce")`` instance is passed to
    ``evaluate_survival_estimator`` once per horizon (180, 270, 360, 365,
    1095 and 1825, in that order) with ``n_folds=3``. The ``"str"`` entry of
    each result supplies the c-index, Brier score and AUCROC cells of the
    corresponding table row.

    Args:
        X: feature matrix forwarded to the evaluator.
        T: time values forwarded to the evaluator.
        Y: outcome values forwarded to the evaluator.

    Returns:
        None. The table is printed via ``tabulate`` using the module-level
        ``headers`` list.
    """
    metric_names = ["c_index", "brier_score", "aucroc"]
    estimator = XGBoostRiskEstimation(objective="cox", strategy="debiased_bce")

    # 6/9/12 thirty-day months followed by 1/3/5 365-day years.
    month_horizons = [months * 30 for months in (6, 9, 12)]
    year_horizons = [years * 365 for years in (1, 3, 5)]

    table_rows = []
    for horizon in month_horizons + year_horizons:
        outcome = evaluate_survival_estimator(
            estimator,
            X,
            T,
            Y,
            [horizon],
            n_folds=3,
            metrics=list(metric_names),
        )
        formatted = outcome["str"]
        table_rows.append([horizon] + [formatted[name] for name in metric_names])

    print(tabulate(table_rows, headers=headers, tablefmt="fancy_grid"))


# Evaluate on the CENS == 0 subset (features include the organ columns).
print("XGBoost eval for transplant data")
eval_xgboost(*transplant_data)


# Evaluate on the CENS == 1 subset (features exclude the organ columns).
print("XGBoost eval for no transplant data")
eval_xgboost(*no_transplant_data)

XGBoost eval for transplant data






╒═══════════╤═════════════════╤═════════════════╤═════════════════╕
│   Horizon │ C-INDEX         │ Brier score     │ AUCROC          │
╞═══════════╪═════════════════╪═════════════════╪═════════════════╡
│       180 │ 0.741 +/- 0.027 │ 0.051 +/- 0.003 │ 0.78 +/- 0.012  │
├───────────┼─────────────────┼─────────────────┼─────────────────┤
│       270 │ 0.718 +/- 0.028 │ 0.068 +/- 0.004 │ 0.756 +/- 0.02  │
├───────────┼─────────────────┼─────────────────┼─────────────────┤
│       360 │ 0.707 +/- 0.023 │ 0.081 +/- 0.005 │ 0.751 +/- 0.033 │
├───────────┼─────────────────┼─────────────────┼─────────────────┤
│       365 │ 0.706 +/- 0.023 │ 0.082 +/- 0.005 │ 0.745 +/- 0.025 │
├───────────┼─────────────────┼─────────────────┼─────────────────┤
│      1095 │ 0.718 +/- 0.012 │ 0.11 +/- 0.001  │ 0.769 +/- 0.007 │
├───────────┼─────────────────┼─────────────────┼─────────────────┤
│      1825 │ 0.714 +/- 0.002 │ 0.142 +/- 0.002 │ 0.803 +/- 0.008 │
╘═══════════╧═════════════════╧═════════════════╧═════════════════╛





╒═══════════╤═════════════════╤═════════════════╤═════════════════╕
│   Horizon │ C-INDEX         │ Brier score     │ AUCROC          │
╞═══════════╪═════════════════╪═════════════════╪═════════════════╡
│       180 │ 0.948 +/- 0.023 │ 0.008 +/- 0.002 │ 0.891 +/- 0.024 │
├───────────┼─────────────────┼─────────────────┼─────────────────┤
│       270 │ 0.91 +/- 0.009  │ 0.018 +/- 0.004 │ 0.869 +/- 0.037 │
├───────────┼─────────────────┼─────────────────┼─────────────────┤
│       360 │ 0.938 +/- 0.013 │ 0.033 +/- 0.007 │ 0.904 +/- 0.013 │
├───────────┼─────────────────┼─────────────────┼─────────────────┤
│       365 │ 0.938 +/- 0.013 │ 0.033 +/- 0.007 │ 0.895 +/- 0.02  │
├───────────┼─────────────────┼─────────────────┼─────────────────┤
│      1095 │ 0.858 +/- 0.028 │ 0.11 +/- 0.008  │ 0.76 +/- 0.056  │
├───────────┼─────────────────┼─────────────────┼─────────────────┤
│      1825 │ 0.795 +/- 0.029 │ 0.174 +/- 0.01  │ 0.736 +/- 0.032 │
╘═══════════╧═════════════════╧═════════════════╧═════════════════╛





In [10]:
from organsync.survival_analysis.cox_ph import CoxPH


def eval_cox_ph(X, T, Y):
    """Evaluate a CoxPH estimator at yearly horizons and print one line per horizon.

    A single ``CoxPH()`` instance is passed to ``evaluate_survival_estimator``
    once for each horizon of 365, 730, 1095 and 1460 (in that order) with
    ``n_folds=3`` and the ``c_index`` / ``brier_score`` metrics. The ``"str"``
    entry of each result is printed.

    Args:
        X: feature matrix forwarded to the evaluator.
        T: time values forwarded to the evaluator.
        Y: outcome values forwarded to the evaluator.

    Returns:
        None. Results are only printed.
    """
    estimator = CoxPH()
    days_per_year = 365

    for years in (1, 2, 3, 4):
        horizon = years * days_per_year
        result = evaluate_survival_estimator(
            estimator,
            X,
            T,
            Y,
            [horizon],
            n_folds=3,
            metrics=["c_index", "brier_score"],
        )
        print(f"   horizon {horizon} -> score {result['str']}")


# Evaluate on the CENS == 0 subset (features include the organ columns).
print("CoxPH eval for transplant data")
eval_cox_ph(*transplant_data)

# Evaluate on the CENS == 1 subset (features exclude the organ columns).
print("CoxPH eval for no transplant data")
eval_cox_ph(*no_transplant_data)

CoxPH eval for transplant data
   horizon 365 -> score {'c_index': '0.683 +/- 0.021', 'brier_score': '0.073 +/- 0.004'}
   horizon 730 -> score {'c_index': '0.661 +/- 0.017', 'brier_score': '0.101 +/- 0.003'}
   horizon 1095 -> score {'c_index': '0.646 +/- 0.011', 'brier_score': '0.122 +/- 0.001'}
   horizon 1460 -> score {'c_index': '0.639 +/- 0.008', 'brier_score': '0.144 +/- 0.001'}
CoxPH eval for no transplant data
   horizon 365 -> score {'c_index': '0.933 +/- 0.022', 'brier_score': '0.028 +/- 0.007'}
   horizon 730 -> score {'c_index': '0.871 +/- 0.041', 'brier_score': '0.072 +/- 0.007'}
   horizon 1095 -> score {'c_index': '0.811 +/- 0.028', 'brier_score': '0.111 +/- 0.0'}
   horizon 1460 -> score {'c_index': '0.779 +/- 0.013', 'brier_score': '0.142 +/- 0.011'}


In [13]:
from organsync.survival_analysis.xgboost import XGBoostRiskEstimation
import cloudpickle


# Fit one XGBoost risk estimator per cohort on the full (X, T, Y) tuples.
# NOTE(review): these use the default constructor arguments, whereas
# eval_xgboost evaluates objective="cox", strategy="debiased_bce" — confirm
# the saved models are meant to differ from the evaluated configuration.
transplant_surv_analysis = XGBoostRiskEstimation()
transplant_surv_analysis.fit(*transplant_data)

no_transplant_surv_analysis = XGBoostRiskEstimation()
no_transplant_surv_analysis.fit(*no_transplant_data)


# Persist both fitted estimators as a single (transplant, no_transplant) tuple.
with open("models/organite_survival.p", "wb") as f:
    cloudpickle.dump((transplant_surv_analysis, no_transplant_surv_analysis), f)

In [12]:
from organsync.survival_analysis.cox_ph import CoxPH
import cloudpickle


# Fit one CoxPH estimator per cohort on the full (X, T, Y) tuples.
cox_transplant_surv_analysis = CoxPH()
cox_transplant_surv_analysis.fit(*transplant_data)

cox_no_transplant_surv_analysis = CoxPH()
cox_no_transplant_surv_analysis.fit(*no_transplant_data)


# Persist both fitted estimators as a single (transplant, no_transplant) tuple.
with open("models/cox_ph_survival.p", "wb") as f:
    cloudpickle.dump((cox_transplant_surv_analysis, cox_no_transplant_surv_analysis), f)