In [None]:
%reload_ext autoreload
%autoreload 2

import os, math, heapq
import joblib
from joblib import Parallel, delayed, load

import sys
from pathlib import Path

dir_path = Path(os.getcwd()).absolute()
module_path = str(dir_path.parent.parent.parent)

if module_path not in sys.path:
    sys.path.append(module_path)
    
import numpy as np
import pandas as pd
import cvxpy as cp
import torch, wandb, lifelines

import pytorch_lightning as pl
pl.utilities.seed.seed_everything(seed=42)

from sklearn.cluster import KMeans
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import mean_squared_error

import matplotlib.pyplot as plt

# OWN MODULES
from experiments.data.utils import get_data_tuples
from experiments.data.data_module import UNOSDataModule, UKRegDataModule, UNOS2UKRegDataModule
from organsync.models.organsync_network import OrganSync_Network
from organsync.models.organite_network import OrganITE_Network_VAE, OrganITE_Network
from organsync.models.transplantbenefit import UKELDModel
from organsync.models.confidentmatch import ConfidentMatch as ConfidentMatch_Network
from organsync.models.organsync_network import OrganSync_Network
from organsync.policies import MELD, MELD_na, FIFO, OrganSync, OrganSyncMax, OrganITE, TransplantBenefit, ConfidentMatch
from organsync.models.inference import Inference_ConfidentMatch, Inference_OrganITE, Inference_OrganITE_VAE, Inference_OrganSync, Inference_ConfidentMatch, Inference_TransplantBenefit
from organsync.simulation import Sim

import warnings
warnings.filterwarnings('ignore')

In [None]:
# SETUP DATA
# Two views of the processed UK registry data, identical except for the
# `control` flag: `dm` (control=False) and `dm_control` (control=True).

data = 'UKReg'  # NOTE(review): not referenced elsewhere in this notebook
batch_size = 256

root_data_dir = Path("../datasets").absolute()

project = 'organsync-net-ukreg'  # NOTE(review): not referenced elsewhere in this notebook
data_dir = root_data_dir / 'processed_UKReg'


def load_ukreg_datamodule(control, directory=data_dir, batch=batch_size):
    """Build a UKRegDataModule, run prepare_data() and setup(stage='fit'), and return it.

    control   -- forwarded to UKRegDataModule's `control` argument.
    directory -- processed-data directory (defaults to `data_dir` above).
    batch     -- batch size for the data module (defaults to `batch_size` above).
    """
    datamodule = UKRegDataModule(directory, batch_size=batch, control=control)
    datamodule.prepare_data()
    datamodule.setup(stage='fit')
    return datamodule


dm = load_ukreg_datamodule(control=False)
dm_control = load_ukreg_datamodule(control=True)

In [None]:
# INFERENCE LOADING
# OrganSync: train one network on `dm` (with organ features) and one on
# `dm_control` (control=True), then wrap each in Inference_OrganSync using the
# normalisation statistics of its own data module.

from pytorch_lightning import Trainer

# hyperparams from the supplementary
lr = .0095
weight_decay = 0.00001
num_hidden_layers = 3
hidden_dim = 50
output_dim = 19
dropout_prob = 0.006
epochs = 50
# NOTE(review): `batch_size` is not used below. Both data modules were already
# built with the batch size from the data-setup cell. Confirm whether the
# supplementary's 128 should be passed to the data modules instead.
batch_size = 128
activation_type = 'leaky_relu'
gamma = 0.87  # lr decay
lambd = .1

# Snapshot the hyperparameters now. Later cells reassign names like `lr` and
# `hidden_dim`, and the helper should not pick those up.
organsync_hparams = dict(
    hidden_dim=hidden_dim,
    num_hidden_layers=num_hidden_layers,
    output_dim=output_dim,
    lr=lr, gamma=gamma, lambd=lambd, weight_decay=weight_decay,
    activation_type=activation_type,
    dropout_prob=dropout_prob,
)


def fit_organsync(datamodule, hparams, max_epochs):
    """Build an OrganSync_Network sized to `datamodule`, fit it on that module, return (model, trainer).

    datamodule -- data module providing `size(1)` (input width) and the training data.
    hparams    -- keyword arguments forwarded to OrganSync_Network (all but input_dim).
    max_epochs -- forwarded to the Lightning Trainer.
    """
    model = OrganSync_Network(input_dim=datamodule.size(1), **hparams).double()
    trainer = Trainer(callbacks=[], max_epochs=max_epochs)
    trainer.fit(model, datamodule=datamodule)
    return model, trainer


# Train the with-organ model first, then the control model.
organsync_model_with_organ, trainer_with_organ = fit_organsync(dm, organsync_hparams, epochs)
organsync_model_control, trainer_control = fit_organsync(dm_control, organsync_hparams, epochs)
input_dim = dm_control.size(1)  # keeps `input_dim` bound to the last model's input width

inference_0 = Inference_OrganSync(model=organsync_model_control, mean=dm_control.mean, std=dm_control.std)
inference_1 = Inference_OrganSync(model=organsync_model_with_organ, mean=dm.mean, std=dm.std)

In [None]:
# ConfidentMatch
from sklearn import svm
from sklearn.ensemble import RandomForestRegressor
from sklearn.neural_network import MLPRegressor

n_clusters = 15
# Iteration cap for the MLP regressor. It is pinned here instead of reusing
# `epochs`, which later cells reassign (the OrganITE cell sets it to 30).
# That way this cell gives the same model no matter when it is (re-)run.
cm_mlp_max_iter = 50
cm_kwargs = {
    "k": n_clusters,
    "x_col": dm.x_cols,
    "y_col": "Y",
    # name -> (regressor class, constructor kwargs)
    "H": {
        "RFR": (RandomForestRegressor, {}),
        "SVR": (svm.SVR, {}),
        "MLPR": (
            MLPRegressor,
            {"hidden_layer_sizes": (30, 100, 100, 30), "max_iter": cm_mlp_max_iter},
        ),
    },
}

cm = ConfidentMatch_Network(data=dm._train_processed, o_col=dm.o_cols, **cm_kwargs)
# NOTE(review): training is driven through private methods; confirm there is
# no public fit entry point on ConfidentMatch_Network.
cm._get_partitions()
cm._train()

inference_cm = Inference_ConfidentMatch(model=cm, mean=dm.mean, std=dm.std)

inference_cm

In [None]:
# TransplantBenefit
# Fit two UKELD models: `ukeld` on patient + organ covariates and `ukeld_0` on
# the control data module's covariates only. They are handed to
# Inference_TransplantBenefit as [without-organ, with-organ].

# Work on a copy so the shared `dm._train_processed` frame is not mutated.
# This also keeps the cell idempotent: re-running it no longer flips CENS back.
DATA = dm._train_processed.copy()
DATA['CENS'] = np.abs(DATA['CENS'] - 1)  # maps 0 -> 1 and 1 -> 0

cols = np.union1d(dm.x_cols, dm.o_cols)
cols = cols[cols != 'CENS']

ukeld = UKELDModel(data=DATA, cols=cols, censor_col='CENS', duration_col='Y', penalizer=0.01)
ukeld.fit()

# Covariates of the control data module, minus the censoring column.
cols_0 = np.asarray(dm_control.x_cols)
cols_0 = cols_0[cols_0 != 'CENS']

# The control model must use `cols_0`. Fitting it on `cols` would make it
# identical to `ukeld`.
# NOTE(review): this still fits on the `dm` training rows. Confirm whether it
# should instead be fit on `dm_control`'s training data.
ukeld_0 = UKELDModel(data=DATA, cols=cols_0, censor_col='CENS', duration_col='Y', penalizer=0.01)
ukeld_0.fit()

inference_tb = Inference_TransplantBenefit(model=[ukeld_0, ukeld], mean=dm.mean, std=dm.std)

In [None]:
# OrganITE
# Two models are trained on `dm`: an ITE network over the module's full input
# width, and a VAE whose input width is the number of organ columns.
# The assignments below overwrite the OrganSync hyperparameters of the same
# names from the earlier cell (including `epochs`).

input_dim = dm.size(1)

# Shared hyperparameters
hidden_dim, num_hidden_layers, output_dim = 16, 4, 6
lr, gamma = 0.007, 0.79
lambd = kappa = 0.15
weight_decay = 0.0006
n_clusters = 15
activation_type = "leaky_relu"
dropout_prob = 0.11
epochs = 30

# --- ITE network ---
organite_model = OrganITE_Network(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    num_hidden_layers=num_hidden_layers,
    output_dim=output_dim,
    lr=lr,
    gamma=gamma,
    lambd=lambd,
    kappa=kappa,
    weight_decay=weight_decay,
    n_clusters=n_clusters,
    activation_type=activation_type,
    dropout_prob=dropout_prob,
).double()

organite_trainer = Trainer(callbacks=[], max_epochs=epochs)
organite_trainer.fit(organite_model, datamodule=dm)
inference_oite = Inference_OrganITE(model=organite_model, mean=dm.mean, std=dm.std)

# --- VAE over the organ columns ---
input_dim = len(dm.o_cols)
organite_vae_model_model = OrganITE_Network_VAE(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
    lr=lr,
    gamma=gamma,
    weight_decay=weight_decay,
).double()

organite_vae_trainer = Trainer(callbacks=[], max_epochs=epochs)
organite_vae_trainer.fit(organite_vae_model_model, datamodule=dm)
inference_oite_vae = Inference_OrganITE_VAE(model=organite_vae_model_model, mean=dm.mean, std=dm.std)


In [None]:
# Simulation settings shared by every policy below. Adjust them to the
# setting you wish to test; the values shown are the ones used in our
# ICML 2021 paper.
sim_params = dict(
    dm=dm,
    inference_0=inference_0,
    inference_1=inference_1,
    initial_waitlist_size=170,
    organ_deficit=.7,
    patient_count=1000,
)

In [None]:
# Reseed right before building the simulation so that every policy cell starts
# from the same RNG state, regardless of how much randomness earlier cells
# consumed or the order the cells were run in.
# NOTE(review): this matters if Sim / the policy draw from the global
# python/numpy/torch RNGs. Confirm in organsync.simulation.
pl.seed_everything(seed=42)
sim_tb = Sim(**sim_params)
tb = TransplantBenefit(
    inference=inference_tb,
    name='TransplantBenefit',
    initial_waitlist=[p.id for p in sim_tb.waitlist],
    dm=dm,
)

stats_tb = sim_tb.simulate(tb)

In [None]:
# Reseed right before building the simulation so that every policy cell starts
# from the same RNG state, regardless of earlier cells or execution order.
# NOTE(review): this matters if Sim / the policy draw from the global RNGs.
pl.seed_everything(seed=42)
sim_na = Sim(**sim_params)
meld_na = MELD_na(name='MELD-na', initial_waitlist=[p.id for p in sim_na.waitlist], dm=dm)

stats_na = sim_na.simulate(meld_na)

In [None]:
# Reseed right before building the simulation so that every policy cell starts
# from the same RNG state, regardless of earlier cells or execution order.
# NOTE(review): this matters if Sim / the policy draw from the global RNGs.
pl.seed_everything(seed=42)
sim_m = Sim(**sim_params)
meld = MELD(name='MELD', initial_waitlist=[p.id for p in sim_m.waitlist], dm=dm)

stats_m = sim_m.simulate(meld)

In [None]:
# Reseed right before building the simulation so that every policy cell starts
# from the same RNG state, regardless of earlier cells or execution order.
# NOTE(review): this matters if Sim / the policy draw from the global RNGs.
pl.seed_everything(seed=42)
sim_f = Sim(**sim_params)
fifo = FIFO(name='FIFO', initial_waitlist=[p.id for p in sim_f.waitlist], dm=dm)

stats_f = sim_f.simulate(fifo)

In [None]:
# Reseed right before building the simulation so that every policy cell starts
# from the same RNG state, regardless of earlier cells or execution order.
# NOTE(review): this matters if Sim / the policy draw from the global RNGs.
pl.seed_everything(seed=42)
sim_cm = Sim(**sim_params)
cm_policy = ConfidentMatch(
    inference=inference_cm,
    name='ConfidentMatch',
    initial_waitlist=[p.id for p in sim_cm.waitlist],
    dm=dm,
)

stats_cm = sim_cm.simulate(cm_policy)

In [None]:
# Reseed right before building the simulation so that every policy cell starts
# from the same RNG state, regardless of earlier cells or execution order.
# NOTE(review): this matters if Sim / the policy draw from the global RNGs.
pl.seed_everything(seed=42)
sim_organite = Sim(**sim_params)
organite = OrganITE(
    name='O-ITE',
    initial_waitlist=[p.id for p in sim_organite.waitlist],
    dm=dm, inference_ITE=inference_oite, inference_VAE=inference_oite_vae)

stats_oite = sim_organite.simulate(organite)

In [None]:
# Reseed right before building the simulation so that every policy cell starts
# from the same RNG state, regardless of earlier cells or execution order.
# NOTE(review): this matters if Sim / the policy draw from the global RNGs.
pl.seed_everything(seed=42)
sim_organsync = Sim(**sim_params)
organsync = OrganSync(
    name='O-Sync',
    initial_waitlist=[p.id for p in sim_organsync.waitlist],
    dm=dm, K=10, inference_0=inference_0, inference_1=inference_1)

stats_os = sim_organsync.simulate(organsync)

In [None]:
# Report `.deaths` from the first element of each policy's simulation stats,
# one policy per line, in a fixed order.
for policy_label, policy_stats in (
    ('OrganSync', stats_os),
    ('OrganITE', stats_oite),
    ('CM', stats_cm),
    ('TransplantBenefit', stats_tb),
    ('MELD', stats_m),
    ('MELD-na', stats_na),
    ('FIFO', stats_f),
):
    print(f'{policy_label} - - - {policy_stats[0].deaths}\n_________')

In [None]:
{key: len(queue) for key, queue in organsync.queues.items()}