In [1]:
import copy
from collections.abc import Iterable
import functools
import itertools
import operator
from matplotlib import pyplot as plt

import pandas as pd
from pandas.api.types import is_numeric_dtype
import numpy as np
import numpy_ext as npe
import math
import random
from pprint import pprint
from scipy.optimize import curve_fit
from scipy.stats import poisson
from scipy.sparse import hstack, vstack, csr_matrix
import scipy

from sklearn.cluster import KMeans
from sklearn.linear_model import LinearRegression
from sklearn.decomposition import PCA
from sklearn.impute import KNNImputer
from sklearn.preprocessing import Normalizer, StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn import metrics
import joblib

import seaborn as sns

import utils
import safety
import ope

import sys

from config import demographics, vital_sign_vars, lab_vars, treatment_vars, vent_vars, guideline_vars, ffill_windows_clinical, SAMPLE_TIME_H
from config import fio2_bins, peep_bins, tv_bins

# Experiment configuration. `seed` and `shaping` are substituted into the
# file-name templates below with str.format.
seed = 3
unsafety_prob = 0.0
shaping = 'unshaped'

# Trajectory datasets; the single placeholder is filled with `seed`.
test_set_file = 'data/test_unshaped_traj_{}.csv'
train_set_file = 'data/train_unshaped_traj_{}.csv'
# Q-table dump; formatted with (shaping, seed). The "0.0" part is hard-coded.
q_fname = 'models/peine_mc_{}_0.0_q_table_{}.bin'

greedy_policy_file = 'models/mcp_greedy_policy_{}{}.bin'
# NOTE(review): this template has two placeholders, but it is formatted with
# (seed, shaping, unsafety_prob) further down. str.format ignores the surplus
# argument, so changing `unsafety_prob` does not change which file is loaded.
sm_policy_file = 'models/mcp_softmax_policy_{}_{}_0.0.bin'
# Clinician policy dumps; formatted with (seed, '') in this notebook.
behavior_policy_train_file = 'models/clinicians_policy_train_{}{}.bin'
behavior_policy_test_file = 'models/clinicians_policy_test_{}{}.bin'
behavior_policy_file = 'models/clinicians_policy_train_test_{}{}.bin'

# Variable groups imported from config, kept as a list of groups.
all_var_types = [
    vital_sign_vars,
    lab_vars,
    treatment_vars,
    vent_vars,
    guideline_vars,
]
# Concatenate the groups with `+` into one flat sequence
# (presumably lists of column names -- TODO confirm against config).
all_vars = functools.reduce(operator.add, all_var_types)

def _add_discounted_terminal_return(dataset, death_reward, survival_reward, gamma):
    """Return a copy of `dataset` with `traj_reward` and `traj_return` columns.

    `traj_reward` is `death_reward` where mort90day == 't', `survival_reward`
    where mort90day == 'f', and NaN for any other value.
    `traj_return` is gamma ** traj_len * traj_reward.
    The input frame is not modified; it must already contain `traj_len`.
    """
    return_set = dataset.copy()
    # Start from NaN so rows with an unexpected mort90day value stay detectable.
    return_set['traj_reward'] = np.nan
    return_set.loc[return_set.mort90day == 't', 'traj_reward'] = death_reward
    return_set.loc[return_set.mort90day == 'f', 'traj_reward'] = survival_reward
    return_set['traj_return'] = (gamma ** return_set['traj_len']) * return_set['traj_reward']
    return return_set


def add_traj_return(dataset, death_reward=-100, survival_reward=100, gamma=.99):
    """Add the discounted trajectory return using a -100 / +100 terminal reward.

    Args:
        dataset: DataFrame with `mort90day` ('t'/'f') and `traj_len` columns.
        death_reward: reward assigned where mort90day == 't'.
        survival_reward: reward assigned where mort90day == 'f'.
        gamma: discount factor raised to the power `traj_len`.

    Returns:
        A copy of `dataset` with `traj_reward` and `traj_return` added.
    """
    return _add_discounted_terminal_return(dataset, death_reward, survival_reward, gamma)


def add_scaled_traj_return(dataset, death_reward=0, survival_reward=1, gamma=.99):
    """Same as `add_traj_return`, but with a 0 / 1 terminal reward by default.

    Args:
        dataset: DataFrame with `mort90day` ('t'/'f') and `traj_len` columns.
        death_reward: reward assigned where mort90day == 't'.
        survival_reward: reward assigned where mort90day == 'f'.
        gamma: discount factor raised to the power `traj_len`.

    Returns:
        A copy of `dataset` with `traj_reward` and `traj_return` added.
    """
    return _add_discounted_terminal_return(dataset, death_reward, survival_reward, gamma)

def add_traj_len(dataset):
    """Return a copy of `dataset` with a per-row `traj_len` column.

    `traj_len` is the largest `traj_count` seen for the row's `icustay_id`,
    plus one. Raises AssertionError if any `traj_count` is missing.
    """
    assert not dataset.traj_count.isna().any()
    last_step = dataset.groupby('icustay_id')['traj_count'].transform('max')
    return dataset.assign(traj_len=last_step + 1)

def postprocess(dataset):
    """Attach `traj_len`, then the discounted `traj_reward` / `traj_return` columns."""
    with_lengths = add_traj_len(dataset)
    return add_traj_return(with_lengths)

# Seed numpy's global RNG, then load the seed-specific train/test CSVs and
# add the traj_len / traj_reward / traj_return columns via postprocess.
np.random.seed(seed)
test_set = postprocess(pd.read_csv(test_set_file.format(seed)))
train_set = postprocess(pd.read_csv(train_set_file.format(seed)))

# Load the Q-table and derive three views of it:
#   q_table      -- NaNs replaced by 0.0
#   q_table_neg  -- zero entries replaced by -inf
#   q_table_nan  -- zero entries replaced by NaN
# NOTE(review): joblib.load unpickles the file; only load trusted model dumps.
q_file = joblib.load(q_fname.format(shaping, seed))
q_table_nan = q_file['model']
# The second positional parameter of np.nan_to_num is `copy`, not `nan`.
# Passing 0.0 positionally requested an in-place conversion, which made
# q_table an alias of q_table_nan and let the NaN re-masking below leak back
# into q_table. Pass `nan=` by keyword so q_table is an independent copy.
q_table = np.nan_to_num(q_table_nan, nan=0.0)
# Both masked variants are built from the zero-filled copy, so original NaNs
# and exact zeros are treated alike (as the in-place version did).
q_table_neg = np.where(q_table == 0.0, float('-inf'), q_table)
t = 1.0
q_table_nan = np.where(q_table == 0.0, np.nan, q_table)

# Every row must have received a reward, i.e. mort90day was 't' or 'f' everywhere.
assert test_set.traj_reward.isna().sum() == 0
assert train_set.traj_reward.isna().sum() == 0
# Clinician policy from the train+test file, passed through utils.repair_policy_uniform.
behavior_policy = utils.repair_policy_uniform(joblib.load(behavior_policy_file.format(seed,'')))
behavior_train_policy = joblib.load(behavior_policy_train_file.format(seed,''))
behavior_test_policy = joblib.load(behavior_policy_test_file.format(seed,''))
# NOTE(review): the same policy is passed as both arguments -- confirm this is intended.
behavior_safe_train = safety.repaired_safe(behavior_train_policy, behavior_train_policy)
# NOTE(review): sm_policy_file has only two placeholders, so `unsafety_prob`
# is ignored by str.format here.
sm_unsafe = joblib.load(sm_policy_file.format(seed, shaping, unsafety_prob))
# NOTE(review): shaping_scalar is only defined when shaping == 'unshaped';
# any other value raises NameError when the list below is built.
if shaping == 'unshaped':
    shaping_scalar = 0.0
# Each entry: (dataset, evaluation policy, behavior policy, split label,
# policy label, shaping, shaping_scalar, safety label, seed).
# Entries alternate train / test, so even indices are train and odd are test.
evaluations = [
    (train_set, sm_unsafe, behavior_policy, 'train', 'softmax', shaping, shaping_scalar, 'unsafe', seed),
    (test_set, sm_unsafe, behavior_policy, 'test', 'softmax', shaping, shaping_scalar, 'unsafe', seed),
]
# The clinician-policy evaluations are only added for the unshaped run.
if shaping == 'unshaped':
    evaluations += [
        (train_set, behavior_train_policy, behavior_train_policy, 'train', 'observed', shaping, shaping_scalar, 'unsafe', seed),
        (test_set, behavior_test_policy, behavior_test_policy, 'test', 'observed', shaping, shaping_scalar, 'unsafe', seed),
        (train_set, behavior_train_policy, behavior_policy, 'train', 'behavior', shaping, shaping_scalar, 'unsafe', seed),
        (test_set, behavior_train_policy, behavior_policy, 'test', 'behavior', shaping, shaping_scalar, 'unsafe', seed),
        (train_set, behavior_safe_train, behavior_policy, 'train', 'behavior', shaping, shaping_scalar, 'safe', seed),
        (test_set, behavior_safe_train, behavior_policy, 'test', 'behavior', shaping, shaping_scalar, 'safe', seed),
    ]
# Per-evaluation results, index-aligned with `evaluations`.
ois_weights = []   # trajectory weights as returned by ope.phwis_policy (see note below)
means = []         # WIS estimates
phwis_means = []   # PHWIS estimates
policies = []      # (evaluation policy, behavior policy) pairs
# The loop target is named `eval_behavior_policy` so that the module-level
# `behavior_policy` is not rebound on every iteration (and is not left
# pointing at an arbitrary entry if an iteration raises).
for ds, evaluation_policy, eval_behavior_policy, *config in evaluations:
    mean, var, traj_weights = ope.wis_policy(ds, evaluation_policy, eval_behavior_policy)
    # NOTE(review): `var` and `traj_weights` from the WIS call are overwritten
    # here, so the values stored and printed below come from PHWIS.
    phwismean, var, traj_weights = ope.phwis_policy(ds, evaluation_policy, eval_behavior_policy)
    ois_weights.append(traj_weights)
    means.append(mean)
    # Store the scalar estimate; the previous code appended the list to itself.
    phwis_means.append(phwismean)
    policies.append((evaluation_policy, eval_behavior_policy))
    #am = ope.am(ds, evaluation_policy, behavior_policy, delta=0.05)
    #hcope5 = ope.hcope(ds, evaluation_policy, behavior_policy, delta=0.05, c=5)
    am, hcope5 = np.nan, np.nan
    # Number of trajectories with a strictly positive weight.
    ess = (traj_weights > 0.0).sum()
    print(','.join(map(str, (*config, mean, phwismean, var, ess))))
    # TODO: write result to file with config

AssertionError: Evaluation policy should have some support in behavior policy

In [2]:
# Copy of the training set with the 0 / 1 terminal reward and its discounted
# return, built with the helper that already implements exactly these steps.
return_set = add_scaled_traj_return(train_set)

In [None]:
# Number of rows left without a reward (mort90day was neither 't' nor 'f').
return_set.traj_reward.isna().sum()

In [None]:
# Inspect the raw mort90day column that the reward assignment compares against 't' / 'f'.
train_set.mort90day

In [None]:
# Distribution of the non-zero entries of behavior_policy.
sns.histplot(behavior_policy[behavior_policy != 0.0])

In [None]:
# Histogram of how many training rows each state_action_id value has.
train_set.state_action_id.value_counts().hist()

In [None]:
# Histogram of the number of training rows per state value.
train_set.state.value_counts().hist(bins=400)
plt.title('Train set state visitations')
plt.show()

# Same for state_action_id values.
train_set.state_action_id.value_counts().hist(bins=140)
plt.title('Train set state-action visitations')
plt.show()

In [None]:
# Inspect the trajectory weights of a single evaluation.
sm_result = 1 # index of solution (even indices of `evaluations` are train, odd are test)
e_policy, b_policy = policies[sm_result]
is_train_split = sm_result % 2 == 0
train_test = train_set if is_train_split else test_set
train_test_label = 'train' if is_train_split else 'test'
traj_weights = ois_weights[sm_result]
sa_weights = ope.ois_sa_weights(train_test, e_policy, b_policy)
weight_cutoff = .0001 / len(traj_weights)
# weight_cutoff = 0.0 #.01 / len(traj_weights)
# traj_weights is indexed by icustay_id (it is combined with the per-stay
# returns below), whereas train_test has one row per time step. Assigning the
# Series directly would align on the row index, so broadcast each stay's
# weight to its rows through icustay_id instead.
# NOTE: this adds columns to train_set / test_set in place (train_test is an alias).
train_test['ois_weights_plot'] = train_test['icustay_id'].map(traj_weights)
train_test['sa_weights'] = sa_weights
traj_returns = train_test.groupby('icustay_id')['traj_return'].first()
weights_returns = pd.DataFrame({'traj_return': traj_returns, 'ois_weight': traj_weights})
weights_returns['weighted_return'] = weights_returns.traj_return * weights_returns.ois_weight
above_cutoff = traj_weights > weight_cutoff
incl_trajs = traj_weights[above_cutoff]
g = sns.histplot(traj_weights, bins=500)
g.set(xscale='log', yscale='log')
# Count and percentage in the title now use the same threshold; previously the
# percentage was taken over weights > 0.0 while the count used weight_cutoff.
plt.title('{}, weights>{:.2E}={}={:.2f}%,seed={},ope={:.1f}'.format(train_test_label, weight_cutoff, above_cutoff.sum(), above_cutoff.mean() * 100, seed, means[sm_result]))
plt.xlabel('OIS weight')
plt.show()

sns.boxplot(x=traj_weights)
plt.show()

sns.boxplot(x=incl_trajs)
plt.show()

# Return vs. weight for the trajectories above the cutoff.
plot_returns = weights_returns[weights_returns.ois_weight > weight_cutoff]
g = sns.scatterplot(data=plot_returns, x='traj_return', y='ois_weight', alpha=.4)
g.set(yscale='log')
plt.xlabel('Trajectory return')
plt.ylabel('OIS weight')
plt.show()

# Fraction of trajectories in this split with mort90day == 'f'.
(train_test.groupby('icustay_id')['mort90day'].first() == 'f').mean()

In [None]:
# Maximum of q_table along axis 1, ignoring NaNs.
np.nanmax(q_table, axis=1)

In [None]:
# Mix the evaluation policy of entry 4 in `evaluations` with the currently
# selected e_policy and evaluate the mixture.
il_policy = policies[4][0]
m = .5
mixed_policy = ((m* il_policy + (1-m)*e_policy))
mixed_safe = safety.repaired_safe(mixed_policy, behavior_train_policy)
# Evaluate on the split / behavior policy that belong to the selected
# evaluation (train_test, b_policy) rather than on `ds` and `behavior_policy`,
# which are whatever the evaluation loop last bound them to.
mixed_wis = ope.wis_policy(train_test, mixed_policy, b_policy)
mixed_phwis = ope.phwis_policy(train_test, mixed_policy, b_policy)
# Show both results; previously the WIS result was computed and discarded.
mixed_wis, mixed_phwis

In [None]:
# WIS evaluation of the mixed policy.
# NOTE(review): `ds` is the variable left over from the evaluation loop, so
# the dataset used here depends on where that loop stopped.
ope.wis_policy(ds, mixed_policy, behavior_policy)

In [None]:
# Display il_policy (taken from policies[4][0]).
il_policy

In [None]:
# Index label (icustay_id) of the largest weighted return.
weights_returns.weighted_return.idxmax()

In [None]:
# Index label and value of the largest trajectory weight.
traj_weights.idxmax(), traj_weights.max(), 
# train_set

In [None]:
# Membership check for stay 248077; only the last expression of a cell is
# rendered, so this result is not displayed.
(train_test.icustay_id == 248077).any()
# State / action sequence of that stay.
train_test[train_test.icustay_id == 248077][['state', 'action_discrete']]

In [None]:
# Per-row probabilities / Q-values for the (state, action) actually taken.
# One vectorised fancy-index per table replaces four row-wise
# DataFrame.apply passes over the whole frame.
# Assumes the policy and Q tables are 2-D numpy arrays indexed
# [state, action] (they are indexed that way elsewhere in this notebook).
row_states = train_test['state'].to_numpy()
row_actions = train_test['action_discrete'].to_numpy()
train_test['e_policy_prob'] = e_policy[row_states, row_actions]
train_test['b_policy_prob'] = b_policy[row_states, row_actions]
train_test['il_policy_prob'] = il_policy[row_states, row_actions]
train_test['q_value'] = q_table_nan[row_states, row_actions]

train_test[train_test.icustay_id == 248077][['state', 'action_discrete', 'sa_weights', 'b_policy_prob', 'e_policy_prob', 'il_policy_prob', 'q_value']]

# train_set.state.value_counts()[479]

In [None]:
# Temperature-scaled softmax of il_policy with zero entries masked to -inf,
# so they receive zero mass.
il_policy_neg = il_policy.copy()
il_policy_neg[il_policy_neg == 0.0] = float('-inf')
temperature = 1e5
# NOTE(review): no `axis` is given, so the softmax normalises over the whole
# array rather than per state row -- confirm this is intended.
il_trans = scipy.special.softmax(il_policy_neg / temperature)

In [None]:
# Display the softmax-transformed il_policy.
il_trans

In [None]:
# For one state, list every action that has mass under e_policy or il_policy:
# action id, the three policy probabilities, the Q-value and the action's
# entry in safety.action_id_compliance.
state_id = 343
row_format = "{:4d}: {:.4f} {:.4f} {:.4f} {:.4f} {}"
per_action = zip(e_policy[state_id, :], il_policy[state_id, :], mixed_policy[state_id, :], q_table_nan[state_id, :])
for action_id, (prob_e, prob_il, prob_mixed, q_value) in enumerate(per_action):
    if not (prob_e > 0 or prob_il > 0):
        continue
    print(row_format.format(action_id, prob_e, prob_il, prob_mixed, q_value, safety.action_id_compliance[action_id]))

In [None]:
# Number of distinct trajectories (icustay_id) per state in the training set.
# Computed once and reused for both the data and the bin count.
train_trajs_per_state = train_set.groupby('state').icustay_id.nunique()
sns.histplot(train_trajs_per_state, bins=train_trajs_per_state.max())
plt.xlabel('# trajectories')
plt.title('Distribution of states over train trajectories')
plt.show()

# Same per state-action tuple.
train_trajs_per_sa = train_set.groupby('state_action_id').icustay_id.nunique()
sns.histplot(train_trajs_per_sa, bins=train_trajs_per_sa.max())
plt.xlabel('# trajectories')
plt.title('Distribution of state-action tuples over train trajectories')
plt.show()

In [None]:
# Number of distinct trajectories (icustay_id) per state in the test set.
# Computed once and reused for both the data and the bin count.
test_trajs_per_state = test_set.groupby('state').icustay_id.nunique()
sns.histplot(test_trajs_per_state, bins=test_trajs_per_state.max())
plt.xlabel('# trajectories')
plt.title('Distribution of states over test trajectories')
plt.show()

# Same per state-action tuple.
test_trajs_per_sa = test_set.groupby('state_action_id').icustay_id.nunique()
sns.histplot(test_trajs_per_sa, bins=test_trajs_per_sa.max())
plt.xlabel('# trajectories')
plt.title('Distribution of state-action tuples over test trajectories')
plt.show()

In [None]:
# Histogram of the number of training rows per state.
sns.histplot(train_set.groupby('state').size())

In [7]:
# Re-import the local `safety` module so edits made to it since the first
# import take effect, then show the mean of state_compliance_clinical on the test set.
import importlib
importlib.reload(safety)
safety.state_compliance_clinical(test_set, safety.avg_clinical_timestep).mean()

0.9741615392822848

In [13]:
# Number of test rows whose ph_imp_scaled_impknn_unscaled is not above 7.2
# (rows with NaN are counted too, because the comparison is negated).
(~(test_set.ph_imp_scaled_impknn_unscaled > 7.2)).sum()

566

In [None]:
# Display the test-set frame.
test_set

In [16]:
# Fraction of test rows whose ph_imp_scaled_impknn_unscaled is not above 7.2.
# Computed from the data instead of a count pasted from an earlier output, so
# it stays correct if the test set changes.
(~(test_set.ph_imp_scaled_impknn_unscaled > 7.2)).sum() / test_set.shape[0]

0.020130170359568943

In [18]:
# Fraction of training rows whose ph_imp_scaled_impknn_unscaled is not above 7.2.
(~(train_set.ph_imp_scaled_impknn_unscaled > 7.2)).sum() / train_set.shape[0]

0.021718297036344904

In [None]:
# Per row, count the entries where e_policy is exactly 0 while il_policy is at least 1e-6.
((e_policy == 0) & (il_policy >= 1e-6)).sum(axis=1)

In [None]:
# For one state, list every action that has mass under e_policy or il_policy:
# action id, the three policy probabilities, the Q-value and the action's
# entry in safety.action_id_compliance.
state_id = 393
row_format = "{:4d}: {:.4f} {:.4f} {:.4f} {:.4f} {}"
per_action = zip(e_policy[state_id, :], il_policy[state_id, :], mixed_policy[state_id, :], q_table_nan[state_id, :])
for action_id, (prob_e, prob_il, prob_mixed, q_value) in enumerate(per_action):
    if not (prob_e > 0 or prob_il > 0):
        continue
    print(row_format.format(action_id, prob_e, prob_il, prob_mixed, q_value, safety.action_id_compliance[action_id]))

In [None]:
# Display q_table.
q_table

In [None]:
# Number of distinct icustay_id values among training rows with state 294.
train_set.groupby('state')['icustay_id'].nunique()[294]

In [None]:
# Number of training rows whose state_action_id is '343-191'.
train_set.state_action_id.value_counts()['343-191']

In [None]:
# Row counts per action_discrete value among training rows with state 112.
train_set[(train_set['state'] == 112)].action_discrete.value_counts()

In [None]:
# All column names of the training set.
list(train_set.columns)

In [None]:
# Entry [294, 191] of the evaluation policy and of the behavior policy.
e_policy[294,191], b_policy[294,191]

In [None]:
# Number of distinct icustay_id values among training rows with state 294
# (same expression as an earlier cell).
train_set.groupby('state')['icustay_id'].nunique()[294]

In [None]:
# Trajectory weight stored under index label 252599.
traj_weights[traj_weights.index == 252599]

In [None]:
# Index label of the second-largest trajectory weight.
traj_weights.sort_values(ascending=False).index[1]

In [None]:
# `demographics` columns (from config) for the training rows of stay 252599.
train_set[train_set['icustay_id'] == 252599][demographics]

In [None]:
# Summary statistics of the vent_duration_h column over all training rows.
train_set.vent_duration_h.describe()

In [None]:
# bun_imp_scaled_impknn_unscaled values for the training rows of stay 201896.
train_set[train_set['icustay_id'] == 201896]['bun_imp_scaled_impknn_unscaled']

In [None]:
# Inspect the elixhauser_vanwalraven column of the test set.
test_set['elixhauser_vanwalraven']

In [None]:
# State / action sequence of stay 201896 in the currently selected split.
train_test[train_test.icustay_id == 201896][['state', 'action_discrete']]

In [None]:
# Entry [241, 183] of behavior_policy.
behavior_policy[241,183]

In [None]:
# Entry [241, 183] of the train-only clinician policy, for comparison.
behavior_train_policy[241,183]

In [None]:
# Number of training rows with state 241.
train_set.state.value_counts()[241]

In [None]:
# Trajectories with weight below 10e5 (i.e. 1e6) versus all trajectories.
len(weights_returns[weights_returns.ois_weight < 10e5]), len(weights_returns)

In [None]:
# Weighted average of the returns restricted to trajectories with weight below 10e5 (1e6).
non_outlier = weights_returns[weights_returns.ois_weight < 10e5]
(non_outlier.traj_return * non_outlier.ois_weight).sum() / non_outlier.ois_weight.sum()

In [None]:
# Box plot of each non-outlier trajectory's return times its normalised weight.
sns.boxplot(non_outlier.traj_return * (non_outlier.ois_weight / non_outlier.ois_weight.sum()))

In [None]:
# Median of the same per-trajectory weighted contributions.
(non_outlier.traj_return * (non_outlier.ois_weight / non_outlier.ois_weight.sum())).median()

In [None]:
# Linear-scale scatter of trajectory return against weight.
# NOTE(review): plot_returns is computed but the scatter below uses the
# unfiltered weights_returns -- confirm which one is intended.
plot_returns = weights_returns[weights_returns.ois_weight > weight_cutoff]
g = plt.scatter(x=weights_returns['traj_return'], y=weights_returns['ois_weight'], alpha=.4)
# g.set(yscale='log')
plt.xlabel('Trajectory return')
plt.ylabel('OIS weight')
plt.show()

In [None]:
# Fraction of trajectories with mort90day == 't' in each split.
# Returned as a (train, test) tuple so both values are displayed; a cell only
# renders its last expression, so the train value used to be discarded.
train_mortality = (train_set.groupby('icustay_id')['mort90day'].first() == 't').mean()
test_mortality = (test_set.groupby('icustay_id')['mort90day'].first() == 't').mean()
train_mortality, test_mortality

In [None]:
# Fraction with mort90day == 't' among test trajectories whose weight exceeds weight_cutoff.
(test_set[test_set.icustay_id.isin(incl_trajs.index)].groupby('icustay_id')['mort90day'].first() == 't').mean()

In [None]:
# Display the list of per-evaluation trajectory weights.
ois_weights

In [None]:
# NOTE(review): `traj` is not defined in any visible cell of this notebook;
# this raises NameError unless it was defined elsewhere.
traj

In [None]:
# Fraction of trajectories with a strictly positive weight for evaluation index 1.
(ois_weights[1] > 0.0).mean()

In [None]:
# Number of trajectories whose weight exceeds 1 / (number of trajectories).
(traj_weights > (1/len(traj_weights))).sum()