In [1]:
%reload_ext autoreload
%autoreload 2

import os, math, heapq
import joblib
from joblib import Parallel, delayed, load

import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
import numpy as np
import pandas as pd
import cvxpy as cp
import torch, wandb, lifelines

import pytorch_lightning as pl
pl.utilities.seed.seed_everything(seed=42)

from sklearn.cluster import KMeans
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import mean_squared_error

import matplotlib.pyplot as plt

# OWN MODULES
from src.data.utils import get_data_tuples
from src.data.data_module import UNOSDataModule, UKRegDataModule, UNOS2UKRegDataModule
from src.models.organsync import OrganSync_Network
from src.models.organite import OrganITE_Network_VAE, OrganITE_Network
from src.models.transplantbenefit import UKELDModel
from src.models.confidentmatch import ConfidentMatch as ConfidentMatch_Network

from src.eval_policies.policy import MELD, MELD_na, FIFO, OrganSync, OrganSyncMax, OrganITE, TransplantBenefit, ConfidentMatch
from src.eval_policies.simulation import Inference_ConfidentMatch, Inference_OrganITE, Inference_OrganITE_VAE, Inference_OrganSync, Inference_ConfidentMatch, Inference_TransplantBenefit, Sim

In [2]:
# SETUP DATA

data = 'U2U'
batch_size = 256

if data == 'UNOS':
    project = 'organsync-net'
    data_dir = '../data/processed'
    dm = UNOSDataModule(data_dir, batch_size=batch_size)
elif data == 'U2U':
    project = 'organsync-net-u2u'
    project_vae = 'organsync-organite-pnet-u2u'
    project_oite = 'organsync-organite-net-u2u'
    project_cm = 'organsync-cm-u2u'
    project_tb = 'organsync-tb-u2u'
    data_dir = '../data/processed_UNOS2UKReg_no_split'
    model_id_0 = '17t1b4qq'
    model_id_1 = 'iyzdtu8l'
    
    model_tb_id = 'gowo2lt0'
    model_tb_id_0 = '3pbz6quu'
    model_vae_id = 'fq0niu48'
    model_oite_id = 'ase8ebrm'
    model_cm = './u2u_cm'
    dm = UNOS2UKRegDataModule(data_dir, batch_size=batch_size, control=False)
    dm_control = UNOS2UKRegDataModule(data_dir, batch_size=batch_size, control=True)
else:
    project = 'organsync-net-ukreg'
    project_vae = 'organsync-organite-pnet-ukreg'
    project_oite = 'organsync-organite-net-ukreg'
    project_cm = 'organsync-cm-ukreg'
    project_tb = 'organsync-tb-ukreg'
    data_dir = '../data/processed_UKReg/clinical_ukeld_2_ukeld'
    model_id_0 = '2gsswo91'
    model_id_1 = '8298slm5'
    model_tb_id = 'kexhhfry'
    model_tb_id_0 = 'w3cgeh30'
    model_vae_id = 'jx1xmfgr'
    model_oite_id = 'or6o700x'
    model_cm = './ukreg_cm'
    dm = UKRegDataModule(data_dir, batch_size=batch_size, control=False)
    dm_control = UKRegDataModule(data_dir, batch_size=batch_size, control=True)
    
dm_control.prepare_data()
dm_control.setup(stage='fit')

dm.prepare_data()
dm.setup(stage='fit')

In [3]:
# INFERENCE LOADING
# OrganSync


params_0 = wandb.restore(f'organsync_net.ckpt-v0.ckpt', run_path=f'jeroenbe/{project}/{model_id_0}', replace=True)
model_0 = OrganSync_Network.load_from_checkpoint(params_0.name).double()

params_1 = wandb.restore(f'organsync_net.ckpt-v0.ckpt', run_path=f'jeroenbe/{project}/{model_id_1}', replace=True)
model_1 = OrganSync_Network.load_from_checkpoint(params_1.name).double()

trainer_0 = pl.Trainer()
trainer_1 = pl.Trainer()

trainer_0.datamodule = dm_control
trainer_1.datamodule = dm
model_0.trainer = trainer_0
model_1.trainer = trainer_1

inference_0 = Inference_OrganSync(model=model_0, mean=dm_control.mean, std=dm_control.std)
inference_1 = Inference_OrganSync(model=model_1, mean=dm.mean, std=dm.std)

lambd = .1

inference_0.model.lambd = lambd
inference_1.model.lambd = lambd

# ConfidentMatch
cm_kwargs = {'k': 1, 'x_col': dm.x_cols, 'y_col': 'Y', 'H': {}}
cm = ConfidentMatch_Network(data=dm._train_processed, o_col=dm.o_cols, **cm_kwargs)
cm.load(model_cm)
inference_cm = Inference_ConfidentMatch(model=cm, mean=dm.mean, std=dm.std)

# TransplantBenefit

params_tb = wandb.restore(f'{data}_cph', run_path=f'jeroenbe/{project_tb}/{model_tb_id}', replace=True)
params_tb_0 = wandb.restore(f'{data}_cph', run_path=f'jeroenbe/{project_tb}/{model_tb_id_0}', replace=True)
cph = load(params_tb.name)# TransplantBenefit
cph_0 = load(params_tb_0.name)# TransplantBenefit
cols = np.union1d(dm.x_cols, dm.o_cols)
cols = cols[cols != 'CENS']
cols_0 = dm.x_cols
cols_0 = cols_0[cols_0!='CENS']
ukeld = UKELDModel(data=dm._train_processed, cols=cols, censor_col='CENS', duration_col='Y')
ukeld_0 = UKELDModel(data=dm_control._train_processed, cols=cols_0, censor_col='CENS', duration_col='Y')
ukeld.cph = cph
ukeld_0.cph = cph_0
inference_tb = Inference_TransplantBenefit(model=[ukeld_0, ukeld], mean=dm.mean, std=dm.std)



# OrganITE

params_vae = wandb.restore(f'organite_vae_net.ckpt', run_path=f'jeroenbe/{project_vae}/{model_vae_id}', replace=True)
O_VAE = OrganITE_Network_VAE.load_from_checkpoint(params_vae.name).double()
inference_oite_vae = Inference_OrganITE_VAE(model=O_VAE, mean=dm.mean, std=dm.std)
inference_oite_vae.model.trainer = trainer_1


params_oite = wandb.restore(f'organite_net.ckpt-v0.ckpt', run_path=f'jeroenbe/{project_oite}/{model_oite_id}', replace=True) 
organite_net = OrganITE_Network.load_from_checkpoint(params_oite.name).double()
inference_oite = Inference_OrganITE(model=organite_net, mean=dm.mean, std=dm.std)
inference_oite.model.trainer = trainer_1


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [35]:
# Please adjust to the setting you wish to test
# values presented are those used on our ICML21 
# paper

sim_params = {
    'dm': dm, 'inference_0': inference_0, 'inference_1': inference_1, 
    'initial_waitlist_size': 170, 
    'organ_deficit': .8,
    'patient_count': 700
}

In [6]:
sim_tb = Sim(**sim_params)
tb = TransplantBenefit(inference=inference_tb, name='TransplantBenefit', initial_waitlist=[p.id for p in sim_tb.waitlist], dm=dm)

stats_tb = sim_tb.simulate(tb)

100%|██████████| 365/365 [01:06<00:00,  5.52it/s]


In [8]:
sim_na = Sim(**sim_params)
meld_na = MELD_na(name='MELD-na', initial_waitlist=[p.id for p in sim_na.waitlist], dm=dm)

stats_na = sim_na.simulate(meld_na)

  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
100%|██████████| 365/365 [00:03<00:00, 97.34it/s] 


In [9]:
sim_m = Sim(**sim_params)
meld = MELD(name='MELD', initial_waitlist=[p.id for p in sim_m.waitlist], dm=dm)

stats_m = sim_m.simulate(meld)

100%|██████████| 365/365 [00:02<00:00, 155.89it/s]


In [10]:
sim_f = Sim(**sim_params)
fifo = FIFO(name='FIFO', initial_waitlist=[p.id for p in sim_f.waitlist], dm=dm)

stats_f = sim_f.simulate(fifo)

100%|██████████| 365/365 [00:01<00:00, 264.11it/s]


In [31]:
sim_cm = Sim(**sim_params)
cm_policy = ConfidentMatch(inference=inference_cm, name='ConfidentMatch', initial_waitlist=[p.id for p in sim_cm.waitlist], dm=dm)

stats_cm = sim_cm.simulate(cm_policy)

100%|██████████| 365/365 [03:08<00:00,  1.94it/s]


In [36]:
sim_organite = Sim(**sim_params)
organite = OrganITE(
    name='O-ITE', 
    initial_waitlist=[p.id for p in sim_organite.waitlist], 
    dm=dm, inference_ITE=inference_oite, inference_VAE=inference_oite_vae)

stats_oite = sim_organite.simulate(organite)

  organ = torch.tensor(organ)
100%|██████████| 365/365 [02:53<00:00,  2.10it/s]


In [42]:
sim_organsync = Sim(**sim_params)
organsync = OrganSync(
    name='O-Sync', 
    initial_waitlist=[p.id for p in sim_organsync.waitlist], 
    dm=dm, K=10, inference_0=inference_0, inference_1=inference_1)

stats_os = sim_organsync.simulate(organsync)

100%|██████████| 365/365 [06:30<00:00,  1.07s/it]


In [43]:
print(f'OrganSync\n - - - \n{stats_os}\n_________')
print(f'OrganITE\n - - - \n{stats_oite}\n_________')
print(f'CM\n - - - \n{stats_cm}\n_________')
print(f'TransplantBenefit\n - - - \n{stats_tb}\n_________')
print(f'MELD\n - - - \n{stats_m}\n_________')
print(f'MELD-na\n - - - \n{stats_na}\n_________')
print(f'FIFO\n - - - \n{stats_f}\n_________')

OrganSync
 - - - 
Deaths: 294
Population life-years: 847189.1196494042
Transplant count: 524
First empty day: -1
_________
OrganITE
 - - - 
Deaths: 316
Population life-years: 1209874.4949307248
Transplant count: 544
First empty day: -1
_________
CM
 - - - 
Deaths: 301
Population life-years: 1219249.056150522
Transplant count: 567
First empty day: -1
_________
TransplantBenefit
 - - - 
Deaths: 239
Population life-years: 949579.3633840501
Transplant count: 582
First empty day: -1
_________
MELD
 - - - 
Deaths: 208
Population life-years: 918717.3843586533
Transplant count: 591
First empty day: -1
_________
MELD-na
 - - - 
Deaths: 275
Population life-years: 828706.2019244415
Transplant count: 559
First empty day: -1
_________
FIFO
 - - - 
Deaths: 155
Population life-years: 903628.6660625652
Transplant count: 570
First empty day: -1
_________


In [41]:
{i: len(organsync.queues[i]) for i in organsync.queues.keys()}

{0: 1,
 1: 0,
 2: 0,
 3: 0,
 4: 0,
 5: 2,
 6: 13,
 7: 0,
 8: 1,
 9: 1,
 10: 0,
 11: 1,
 12: 0,
 13: 2,
 14: 0,
 15: 0,
 16: 0,
 17: 0,
 18: 0,
 19: 0,
 20: 0,
 21: 15,
 22: 1,
 23: 0,
 24: 4,
 25: 0,
 26: 0,
 27: 0,
 28: 0,
 29: 0,
 30: 0,
 31: 2,
 32: 0,
 33: 0,
 34: 0,
 35: 1,
 36: 0,
 37: 0,
 38: 0,
 39: 0,
 40: 0,
 41: 0,
 42: 0,
 43: 1,
 44: 0,
 45: 0,
 46: 0,
 47: 1,
 48: 1,
 49: 2}

In [None]:
#print(f'CM\n - - - \n{stats_cm}\n_________')
print(f'OrganITE\n - - - \n{stats_oite}\n_________')

In [None]:
sim_na.waitlist

In [None]:
# .2 | 450 | 2000 | -10

len(dm.o_cols)

In [None]:
organite.dm

In [None]:
np.empty(1, 0)