In [None]:
%reload_ext autoreload
%autoreload 2

import os, math, heapq
import joblib
from joblib import Parallel, delayed, load

import sys
from pathlib import Path

# Make the directory three levels above this notebook importable, so that the
# project packages (`experiments`, `organsync`) imported below resolve.
dir_path = Path.cwd().absolute()
module_path = os.fspath(dir_path.parent.parent.parent)

# Guard keeps sys.path free of duplicates when the cell is re-run.
if module_path not in sys.path:
    sys.path.append(module_path)
    
import numpy as np
import pandas as pd
import cvxpy as cp
import torch, wandb, lifelines

import pytorch_lightning as pl
# Global seed for reproducibility (per the Lightning docs this seeds Python's
# `random`, NumPy and torch). Use the stable top-level `pl.seed_everything`:
# `pl.utilities.seed.seed_everything` is deprecated and absent in recent
# Lightning releases.
pl.seed_everything(seed=42)

from sklearn.cluster import KMeans
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import mean_squared_error

import matplotlib.pyplot as plt

# OWN MODULES
from experiments.data.utils import get_data_tuples
from experiments.data.data_module import UNOSDataModule, UKRegDataModule, UNOS2UKRegDataModule
from organsync.models.organsync_network import OrganSync_Network
from organsync.models.organite_network import OrganITE_Network_VAE, OrganITE_Network
from organsync.models.transplantbenefit import UKELDModel
from organsync.models.confidentmatch import ConfidentMatch as ConfidentMatch_Network

from organsync.policies import MELD, MELD_na, FIFO, OrganSync, OrganSyncMax, OrganITE, TransplantBenefit, ConfidentMatch
from organsync.models.inference import Inference_ConfidentMatch, Inference_OrganITE, Inference_OrganITE_VAE, Inference_OrganSync, Inference_ConfidentMatch, Inference_TransplantBenefit
from organsync.simulation import Sim

In [None]:
# SETUP DATA
#
# `data` selects the experimental setting. Each branch sets the W&B project
# names, pretrained run ids and data modules that the following cells use:
#   'U2U'           -> UNOS2UKRegDataModule
#   'UNOS'          -> not configured for this notebook (raises, see below)
#   any other value -> UKRegDataModule
# Every usable setting builds two data modules: `dm` (control=False) and
# `dm_control` (control=True).

data = 'U2U'
batch_size = 256

if data == 'UNOS':
    project = 'organsync-net'
    data_dir = '../data/processed'
    dm = UNOSDataModule(data_dir, batch_size=batch_size)
    # This branch defines no `dm_control` and none of the project names / run
    # ids restored in the inference cell. Without this guard the cell would die
    # with a NameError on `dm_control` below, so fail early with a clear message.
    raise NotImplementedError(
        "data='UNOS' is not configured in this notebook: no control data module "
        "or pretrained W&B run ids are defined for it."
    )
elif data == 'U2U':
    project = 'organsync-net-u2u'
    project_vae = 'organsync-organite-pnet-u2u'
    project_oite = 'organsync-organite-net-u2u'
    project_cm = 'organsync-cm-u2u'
    project_tb = 'organsync-tb-u2u'
    data_dir = '../data/processed_UNOS2UKReg_no_split'
    model_id_0 = '17t1b4qq'
    model_id_1 = 'iyzdtu8l'

    model_tb_id = 'gowo2lt0'
    model_tb_id_0 = '3pbz6quu'
    model_vae_id = 'fq0niu48'
    model_oite_id = 'ase8ebrm'
    model_cm = './u2u_cm'
    dm = UNOS2UKRegDataModule(data_dir, batch_size=batch_size, control=False)
    dm_control = UNOS2UKRegDataModule(data_dir, batch_size=batch_size, control=True)
else:
    project = 'organsync-net-ukreg'
    project_vae = 'organsync-organite-pnet-ukreg'
    project_oite = 'organsync-organite-net-ukreg'
    project_cm = 'organsync-cm-ukreg'
    project_tb = 'organsync-tb-ukreg'
    data_dir = '../data/processed_UKReg/clinical_ukeld_2_ukeld'
    model_id_0 = '2gsswo91'
    model_id_1 = '8298slm5'
    model_tb_id = 'kexhhfry'
    model_tb_id_0 = 'w3cgeh30'
    model_vae_id = 'jx1xmfgr'
    model_oite_id = 'or6o700x'
    model_cm = './ukreg_cm'
    dm = UKRegDataModule(data_dir, batch_size=batch_size, control=False)
    dm_control = UKRegDataModule(data_dir, batch_size=batch_size, control=True)

dm_control.prepare_data()
dm_control.setup(stage='fit')

dm.prepare_data()
dm.setup(stage='fit')

In [None]:
# INFERENCE LOADING
# OrganSync
# Two OrganSync networks are restored from W&B: model_0 is paired with
# `dm_control`, model_1 with `dm`.
#
# Both runs store their checkpoint under the same file name, and
# wandb.restore(..., replace=True) downloads into the same root directory.
# Each checkpoint is therefore loaded before the next restore overwrites it;
# keep every restore/load pair together and in this order.
params_0 = wandb.restore(
    'organsync_net.ckpt-v0.ckpt',
    run_path=f'jeroenbe/{project}/{model_id_0}',
    replace=True,
)
model_0 = OrganSync_Network.load_from_checkpoint(params_0.name).double()

params_1 = wandb.restore(
    'organsync_net.ckpt-v0.ckpt',
    run_path=f'jeroenbe/{project}/{model_id_1}',
    replace=True,
)
model_1 = OrganSync_Network.load_from_checkpoint(params_1.name).double()

# One Trainer per model, each bound to that model's data module.
trainer_0, trainer_1 = pl.Trainer(), pl.Trainer()
trainer_0.datamodule, trainer_1.datamodule = dm_control, dm
model_0.trainer, model_1.trainer = trainer_0, trainer_1

# Inference wrappers; each receives the mean/std of its own data module.
inference_0 = Inference_OrganSync(model=model_0, mean=dm_control.mean, std=dm_control.std)
inference_1 = Inference_OrganSync(model=model_1, mean=dm.mean, std=dm.std)

# Same `lambd` on both wrapped networks.
lambd = .1
inference_0.model.lambd = inference_1.model.lambd = lambd

# ConfidentMatch: built on dm's processed training data, then `cm.load` is
# called with the local path `model_cm` (presumably previously saved state).
cm_kwargs = dict(k=1, x_col=dm.x_cols, y_col='Y', H={})
cm = ConfidentMatch_Network(
    data=dm._train_processed,
    o_col=dm.o_cols,
    **cm_kwargs,
)
cm.load(model_cm)
inference_cm = Inference_ConfidentMatch(model=cm, mean=dm.mean, std=dm.std)

# TransplantBenefit
# Two models are restored from W&B and loaded with joblib: `cph` (run
# `model_tb_id`), used with `dm`, and `cph_0` (run `model_tb_id_0`), used with
# `dm_control`.
#
# Both runs store their model under the same file name (f'{data}_cph'), and
# wandb.restore(..., replace=True) downloads into the same root directory.
# Each file is therefore loaded right after it is restored. Restoring both
# first would overwrite the first download, so `cph` and `cph_0` would both
# hold run `model_tb_id_0`'s model.
params_tb = wandb.restore(f'{data}_cph', run_path=f'jeroenbe/{project_tb}/{model_tb_id}', replace=True)
cph = load(params_tb.name)
params_tb_0 = wandb.restore(f'{data}_cph', run_path=f'jeroenbe/{project_tb}/{model_tb_id_0}', replace=True)
cph_0 = load(params_tb_0.name)

# Covariates for the `dm` model: union of x_cols and o_cols, minus the censoring column.
cols = np.union1d(dm.x_cols, dm.o_cols)
cols = cols[cols != 'CENS']
# Covariates for the `dm_control` model: x_cols only, minus the censoring column.
# np.asarray keeps the boolean mask valid even if x_cols is a plain list; a list
# indexed by the scalar `list != str` result (True) would return one element.
cols_0 = np.asarray(dm.x_cols)
cols_0 = cols_0[cols_0 != 'CENS']

ukeld = UKELDModel(data=dm._train_processed, cols=cols, censor_col='CENS', duration_col='Y')
ukeld_0 = UKELDModel(data=dm_control._train_processed, cols=cols_0, censor_col='CENS', duration_col='Y')
ukeld.cph = cph
ukeld_0.cph = cph_0
# The inference wrapper takes the pair as [control model, treated model].
inference_tb = Inference_TransplantBenefit(model=[ukeld_0, ukeld], mean=dm.mean, std=dm.std)



# OrganITE
# Restore the OrganITE VAE and the OrganITE network from W&B, wrap each in its
# inference helper, and attach `trainer_1` (the Trainer bound to `dm`).
params_vae = wandb.restore(
    'organite_vae_net.ckpt',
    run_path=f'jeroenbe/{project_vae}/{model_vae_id}',
    replace=True,
)
O_VAE = OrganITE_Network_VAE.load_from_checkpoint(params_vae.name).double()
inference_oite_vae = Inference_OrganITE_VAE(model=O_VAE, mean=dm.mean, std=dm.std)
inference_oite_vae.model.trainer = trainer_1

params_oite = wandb.restore(
    'organite_net.ckpt-v0.ckpt',
    run_path=f'jeroenbe/{project_oite}/{model_oite_id}',
    replace=True,
)
organite_net = OrganITE_Network.load_from_checkpoint(params_oite.name).double()
inference_oite = Inference_OrganITE(model=organite_net, mean=dm.mean, std=dm.std)
inference_oite.model.trainer = trainer_1


In [None]:
# Simulation settings shared by every policy run below.
# Adjust them to the setting you wish to test; the values here are the ones
# used in our ICML 2021 paper.
# NOTE(review): each policy cell builds a fresh Sim(**sim_params). If Sim draws
# patients from a global RNG, each policy may see a different cohort; consider
# reseeding before each Sim. TODO confirm against Sim.
sim_params = dict(
    dm=dm,
    inference_0=inference_0,
    inference_1=inference_1,
    initial_waitlist_size=170,
    organ_deficit=.8,
    patient_count=700,
)

In [None]:
# TransplantBenefit policy, seeded with the ids of the simulation's initial waitlist.
sim_tb = Sim(**sim_params)
tb = TransplantBenefit(
    inference=inference_tb,
    name='TransplantBenefit',
    initial_waitlist=[patient.id for patient in sim_tb.waitlist],
    dm=dm,
)

stats_tb = sim_tb.simulate(tb)

In [None]:
# MELD-na policy, seeded with the ids of the simulation's initial waitlist.
sim_na = Sim(**sim_params)
meld_na = MELD_na(
    name='MELD-na',
    initial_waitlist=[patient.id for patient in sim_na.waitlist],
    dm=dm,
)

stats_na = sim_na.simulate(meld_na)

In [None]:
# MELD policy, seeded with the ids of the simulation's initial waitlist.
sim_m = Sim(**sim_params)
meld = MELD(
    name='MELD',
    initial_waitlist=[patient.id for patient in sim_m.waitlist],
    dm=dm,
)

stats_m = sim_m.simulate(meld)

In [None]:
# FIFO policy, seeded with the ids of the simulation's initial waitlist.
sim_f = Sim(**sim_params)
fifo = FIFO(
    name='FIFO',
    initial_waitlist=[patient.id for patient in sim_f.waitlist],
    dm=dm,
)

stats_f = sim_f.simulate(fifo)

In [None]:
# ConfidentMatch policy, seeded with the ids of the simulation's initial waitlist.
sim_cm = Sim(**sim_params)
cm_policy = ConfidentMatch(
    inference=inference_cm,
    name='ConfidentMatch',
    initial_waitlist=[patient.id for patient in sim_cm.waitlist],
    dm=dm,
)

stats_cm = sim_cm.simulate(cm_policy)

In [None]:
# OrganITE policy: uses both the ITE network and the VAE inference wrappers.
sim_organite = Sim(**sim_params)
organite = OrganITE(
    name='O-ITE',
    initial_waitlist=[patient.id for patient in sim_organite.waitlist],
    dm=dm,
    inference_ITE=inference_oite,
    inference_VAE=inference_oite_vae,
)

stats_oite = sim_organite.simulate(organite)

In [None]:
# OrganSync policy with K=10, using the two OrganSync inference wrappers.
sim_organsync = Sim(**sim_params)
organsync = OrganSync(
    name='O-Sync',
    initial_waitlist=[patient.id for patient in sim_organsync.waitlist],
    dm=dm,
    K=10,
    inference_0=inference_0,
    inference_1=inference_1,
)

stats_os = sim_organsync.simulate(organsync)

In [None]:
# Summary of every policy's simulation statistics, one section per policy.
print('\n'.join(
    f'{label}\n - - - \n{stats}\n_________'
    for label, stats in (
        ('OrganSync', stats_os),
        ('OrganITE', stats_oite),
        ('CM', stats_cm),
        ('TransplantBenefit', stats_tb),
        ('MELD', stats_m),
        ('MELD-na', stats_na),
        ('FIFO', stats_f),
    )
))

In [None]:
# Length of each entry in the OrganSync policy's `queues`, keyed like `queues`.
{key: len(queue) for key, queue in organsync.queues.items()}

In [None]:
# Re-print the OrganITE statistics on their own.
print('OrganITE', ' - - - ', stats_oite, '_________', sep='\n')

In [None]:
# Debug inspection: display the MELD-na simulation's `waitlist` attribute.
sim_na.waitlist

In [None]:
# .2 | 450 | 2000 | -10
# NOTE(review): the line above looks like an alternative simulation setting
# (organ_deficit | initial_waitlist_size | patient_count | ?). Confirm it and
# move it into the sim_params cell, or delete it.
# Debug inspection: number of `o_cols` (presumably the organ features) in `dm`.

len(dm.o_cols)

In [None]:
# Debug inspection: the `dm` attribute of the OrganITE policy (constructed with dm=dm).
organite.dm

In [None]:
# np.empty's second positional argument is the dtype, so `np.empty(1, 0)`
# raises TypeError. Pass the shape as one tuple; this yields a (1, 0) float
# array (presumably the intended shape).
np.empty((1, 0))