In [None]:
%reload_ext autoreload
%autoreload 2

import os, math, heapq

import sys
from pathlib import Path

# Put the directory three levels above the current working directory on
# sys.path (once), so the project imports further down can resolve.
# NOTE(review): assumes the notebook is launched from its own folder - confirm.
dir_path = Path.cwd()
module_path = str(dir_path.parent.parent.parent)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)
    
import numpy as np
import pandas as pd
import cvxpy as cp

# OWN MODULES
from experiments.data.data_module import UKRegDataModule
from organsync.models.organsync_network import OrganSync_Network
from organsync.models.organite_network import  OrganITE_Network
from organsync.models.inference import Inference_ConfidentMatch, Inference_OrganITE, Inference_OrganITE_VAE, Inference_OrganSync, Inference_ConfidentMatch, Inference_TransplantBenefit

import warnings
# Install a global "ignore" filter: no warning of any category is shown after
# this point (presumably to keep cell output short).
# NOTE(review): this also hides deprecation/numerical warnings raised by the
# libraries used in the cells below.
warnings.filterwarnings('ignore')

In [None]:
# SETUP DATA
# Instantiate the UKReg data module and run its prepare/setup steps for the
# "fit" stage; `dm` is the handle the following cells read from.

batch_size = 256

# "../datasets" is interpreted relative to the current working directory.
root_data_dir = Path("..", "datasets").absolute()
data_dir = root_data_dir.joinpath("processed_UKReg")  # your path to UKReg

dm = UKRegDataModule(data_dir, batch_size=batch_size, control=False)
dm.prepare_data()
dm.setup(stage="fit")

In [None]:
import cloudpickle
import torch

# Restore the previously pickled inference object (path is relative to the
# current working directory).
# NOTE: unpickling executes code embedded in the file - only load files you trust.
with Path("models/organite_inference.p").open("rb") as model_file:
    organite_inf = cloudpickle.load(model_file)

In [None]:
from adjutorium.utils.tester import evaluate_survival_estimator
from adjutorium.plugins.prediction import Predictions

# Adjutorium `Predictions` object for the "risk_estimation" category.
# NOTE(review): `predictions` is not referenced by any other visible cell -
# confirm it is still needed.
predictions = Predictions(category="risk_estimation")

def generate_surv_data(working_df, use_organ):
    """Assemble ``(X, T, Y)`` inputs for the survival estimators.

    Args:
        working_df: Frame containing the columns listed in ``dm.x_cols``
            (and ``dm.o_cols`` when ``use_organ`` is set) together with the
            ``"Y"`` and ``"PCENS"`` columns.
        use_organ: When truthy, the ``dm.o_cols`` columns are appended to the
            ``dm.x_cols`` covariates.

    Returns:
        Tuple ``(X, T, Y)`` limited to the rows whose rescaled time is
        strictly positive. ``X`` is the covariate frame without ``"CENS"``,
        ``T`` is the ``"Y"`` column mapped through the last ``scale_`` /
        ``mean_`` entries of ``dm.scaler``, and ``Y`` is the ``"PCENS"``
        column.
    """
    feature_names = [*dm.x_cols]
    if use_organ:
        feature_names.extend(dm.o_cols)
    covariates = working_df[feature_names].drop(columns=["CENS"])

    # Map "Y" back through the scaler's last scale/mean pair
    # (presumably undoing the standardisation of the outcome - TODO confirm).
    raw_times = working_df["Y"]
    times = dm.scaler.scale_[-1] * raw_times + dm.scaler.mean_[-1]

    events = working_df["PCENS"]

    # One shared mask keeps the three outputs row-aligned.
    positive = times > 0
    return covariates[positive], times[positive], events[positive]


def generate_interm_surv_data(working_df, use_organ):
    """Assemble ``(X, T, Y)`` where ``X`` is the network's intermediate representation.

    The covariates are passed through ``organite_inf.model.representation``
    and the result is returned as ``X``. Previously the representation was
    computed and then discarded, so the raw covariates were returned instead.

    Args:
        working_df: Frame containing the columns listed in ``dm.x_cols``
            (and ``dm.o_cols`` when ``use_organ`` is set) together with the
            ``"Y"`` and ``"PCENS"`` columns.
        use_organ: When truthy, the ``dm.o_cols`` columns are appended to the
            ``dm.x_cols`` covariates before they are fed to the network.

    Returns:
        Tuple ``(X, T, Y)`` limited to the rows whose rescaled time is
        strictly positive. ``X`` is a DataFrame of representation features
        (columns ``interm_0``, ``interm_1``, ...) sharing ``working_df``'s
        index, ``T`` is the rescaled ``"Y"`` column and ``Y`` is ``"PCENS"``.
    """
    cols = list(dm.x_cols)
    if use_organ:
        cols += list(dm.o_cols)
    X = working_df[cols]

    # "CENS" is intentionally NOT dropped here (unlike generate_surv_data): the
    # original had the drop disabled, presumably because the network's input
    # layout includes that column - TODO confirm.
    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        interm = organite_inf.model.representation(torch.from_numpy(np.asarray(X)))

    # NOTE(review): assumes representation() returns a single 2-D tensor with
    # one row per input row - confirm against the model definition.
    interm = interm.detach().cpu().numpy()
    interm_X = pd.DataFrame(
        interm,
        index=X.index,  # keep the original index so the boolean mask below aligns
        columns=[f"interm_{i}" for i in range(interm.shape[1])],
    )

    T = working_df["Y"]
    T = dm.scaler.scale_[-1] * T + dm.scaler.mean_[-1]

    Y = working_df["PCENS"]

    # One shared mask keeps the three outputs row-aligned.
    keep = T > 0
    return interm_X[keep], T[keep], Y[keep]

# Training frame taken from the data module's `_train_processed` attribute.
full_df = dm._train_processed

# Split on the "CENS" flag: rows with 0 feed the transplant datasets, rows
# with 1 feed the no-transplant dataset.
transplant_rows = full_df[full_df["CENS"] == 0]
no_transplant_rows = full_df[full_df["CENS"] == 1]

transplant_data = generate_surv_data(transplant_rows, use_organ=True)
no_transplant_data = generate_surv_data(no_transplant_rows, use_organ=False)

interm_transplant_data = generate_interm_surv_data(transplant_rows, use_organ=True)

full_data = generate_surv_data(full_df, use_organ=True)

In [None]:
from organsync.survival_analysis.xgboost import XGBoostRiskEstimation

def eval_xgboost(X, T, Y):
    """Score an ``XGBoostRiskEstimation`` model at four horizons and print the results.

    Prints the number of rows in ``X``, then calls
    ``evaluate_survival_estimator`` once per horizon (365, 730, 1095, 1460)
    with ``n_folds=3`` and the ``c_index`` / ``brier_score`` metrics, printing
    each result's ``'str'`` entry.

    Args:
        X: Covariates passed to ``evaluate_survival_estimator``.
        T: Time values passed to ``evaluate_survival_estimator``.
        Y: Third array passed to ``evaluate_survival_estimator``.

    Returns:
        None; the scores are only printed.
    """
    print("    data cnt ", len(X))
    estimator = XGBoostRiskEstimation()

    # Multiples of 365 from 1x to 4x (presumably 1-4 years in days - TODO confirm units).
    for horizon in range(365, 5 * 365, 365):
        result = evaluate_survival_estimator(
            estimator,
            X,
            T,
            Y,
            [horizon],
            n_folds=3,
            metrics=["c_index", "brier_score"],
        )
        print(f"   horizon {horizon} -> score {result['str']}")
  
# Run the XGBoost evaluation on each prepared dataset.
# The intermediate-representation dataset is `interm_transplant_data`; the
# previously used name `interm_surv_data` is never defined in this notebook
# and raised NameError on a fresh kernel.
print("XGBoost eval on interm data")
eval_xgboost(*interm_transplant_data)

print("XGBoost eval for transplant data")
eval_xgboost(*transplant_data)


print("XGBoost eval for no transplant data")
eval_xgboost(*no_transplant_data)


In [None]:
from organsync.survival_analysis.cox_ph import CoxPH

def eval_cox_ph(X, T, Y):
    """Score a ``CoxPH`` model at four horizons and print the results.

    Calls ``evaluate_survival_estimator`` once per horizon (365, 730, 1095,
    1460) with ``n_folds=3`` and the ``c_index`` / ``brier_score`` metrics,
    printing each result's ``'str'`` entry.

    Args:
        X: Covariates passed to ``evaluate_survival_estimator``.
        T: Time values passed to ``evaluate_survival_estimator``.
        Y: Third array passed to ``evaluate_survival_estimator``.

    Returns:
        None; the scores are only printed.
    """
    estimator = CoxPH()

    # One evaluation per multiple of 365, from 1x to 4x.
    for multiple in (1, 2, 3, 4):
        horizon = multiple * 365
        result = evaluate_survival_estimator(
            estimator,
            X,
            T,
            Y,
            [horizon],
            n_folds=3,
            metrics=["c_index", "brier_score"],
        )
        print(f"   horizon {horizon} -> score {result['str']}")
        

# Run the CoxPH evaluation on the transplant and no-transplant datasets,
# printing a banner before each run.
for banner, dataset in (
    ("CoxPH eval for transplant data", transplant_data),
    ("CoxPH eval for no transplant data", no_transplant_data),
):
    print(banner)
    eval_cox_ph(*dataset)


In [None]:
from organsync.survival_analysis.xgboost import XGBoostRiskEstimation
import cloudpickle


# Fit one XGBoost risk estimator per dataset (transplant / no transplant).
transplant_surv_analysis = XGBoostRiskEstimation()
transplant_surv_analysis.fit(*transplant_data)

no_transplant_surv_analysis = XGBoostRiskEstimation()
no_transplant_surv_analysis.fit(*no_transplant_data)

# Persist both fitted estimators together as a (transplant, no_transplant) tuple.
xgb_models = (transplant_surv_analysis, no_transplant_surv_analysis)
with open("models/organite_survival.p", mode="wb") as model_file:
    cloudpickle.dump(xgb_models, model_file)



In [35]:
from organsync.survival_analysis.cox_ph import CoxPH
import cloudpickle


# Fit one CoxPH estimator per dataset (transplant / no transplant).
cox_transplant_surv_analysis = CoxPH()
cox_transplant_surv_analysis.fit(*transplant_data)

cox_no_transplant_surv_analysis = CoxPH()
cox_no_transplant_surv_analysis.fit(*no_transplant_data)

# Persist both fitted estimators together as a (transplant, no_transplant) tuple.
cox_models = (cox_transplant_surv_analysis, cox_no_transplant_surv_analysis)
with open("models/cox_ph_survival.p", mode="wb") as model_file:
    cloudpickle.dump(cox_models, model_file)

