In [None]:
import copy
from collections.abc import Iterable
import functools
import itertools
import operator
from matplotlib import pyplot as plt

import pandas as pd
from pandas.api.types import is_numeric_dtype
import numpy as np
import numpy_ext as npe
import math
import random
from pprint import pprint
from scipy.optimize import curve_fit
from scipy.stats import poisson
from scipy.sparse import hstack, vstack, csr_matrix
import scipy

from sklearn.cluster import KMeans
from sklearn.linear_model import LinearRegression
from sklearn.decomposition import PCA
from sklearn.impute import KNNImputer
from sklearn.preprocessing import Normalizer, StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn import metrics
import joblib

import seaborn as sns

import utils
import safety
import ope

import sys

from config import demographics, vital_sign_vars, lab_vars, treatment_vars, vent_vars, guideline_vars, ffill_windows_clinical, SAMPLE_TIME_H
from config import fio2_bins, peep_bins, tv_bins

In [11]:
# Path templates, filled in with str.format.
# Dataset CSVs take (shaping, seed), e.g. ('unshaped', 0) -> 'data/test_unshaped_traj_0.csv'.
# Every caller passes both values; with a single '{}' the seed would be silently
# ignored by str.format and the same file would be loaded for every seed.
test_set_file = 'data/test_{}_traj_{}.csv'
train_set_file = 'data/train_{}_traj_{}.csv'

# Policy binaries take (seed, suffix): the suffix is '' for unshaped/clinician
# policies or a shaping suffix such as '_avgpotential2-1.0'.
greedy_policy_file = 'models/mcp_greedy_policy_{}{}.bin'
sm_policy_file = 'models/mcp_softmax_policy_{}{}.bin'
behavior_policy_train_file = 'models/clinicians_policy_train_{}{}.bin'
behavior_policy_test_file = 'models/clinicians_policy_test_{}{}.bin'
behavior_policy_file = 'models/clinicians_policy_train_test_{}{}.bin'

In [12]:
def _add_discounted_terminal_return(dataset, death_reward, survival_reward, gamma):
    """Return a copy of ``dataset`` with a terminal reward and its discounted return.

    Adds two columns:
      * ``traj_reward``: ``death_reward`` where ``mort90day == 't'``,
        ``survival_reward`` where ``mort90day == 'f'``, NaN for any other value.
      * ``traj_return``: ``gamma ** traj_len * traj_reward``.

    ``dataset`` must have ``mort90day`` and ``traj_len`` columns; it is not modified.
    """
    return_set = dataset.copy()
    return_set['traj_reward'] = np.nan
    return_set.loc[return_set.mort90day == 't', 'traj_reward'] = death_reward
    return_set.loc[return_set.mort90day == 'f', 'traj_reward'] = survival_reward
    return_set['traj_return'] = (gamma ** return_set['traj_len']) * return_set['traj_reward']
    return return_set


def add_traj_return(dataset, gamma=.99):
    """Add ``traj_reward`` (-100 if ``mort90day == 't'``, +100 if ``'f'``) and
    its return discounted by ``gamma ** traj_len`` as ``traj_return``.

    Returns a new DataFrame; the input is left unchanged.
    """
    return _add_discounted_terminal_return(dataset, -100, 100, gamma)


def add_scaled_traj_return(dataset, gamma=.99):
    """Like ``add_traj_return`` but with rewards scaled to 0 (``'t'``) / 1 (``'f'``)."""
    return _add_discounted_terminal_return(dataset, 0, 1, gamma)

In [13]:
# Sanity-check how the test-set path template expands for a (shaping, seed) pair.
example_test_path = test_set_file.format('avgpotential2-0.1', 0)
example_test_path

'data/test_avgpotential2-0.1_traj_0.csv'

In [19]:
# Seeds index the train/test splits and trained policies loaded further down.
seeds = range(10)

# Reward-shaping variants to evaluate, mapped to the filename suffix of their
# trained policy files. The 'unshaped' variant (suffix '') and the older
# 'avgpotential-*' family are currently excluded from this sweep.
_shaping_scalars = ('0.1', '0.5', '1.0', '2.0', '3.0', '5.0', '6.0', '10.0')
shaping_fname = {
    'avgpotential2-{}'.format(scalar): '_avgpotential2-{}'.format(scalar)
    for scalar in _shaping_scalars
}

In [20]:
%%time
# seeds = range(10)
# Off-policy evaluation of every (seed, shaping) combination. Each result row is
# printed as a CSV line and also kept in `ope_results` with the same fields:
# (split, policy, shaped, shaping_scalar, safety, seed, wis_mean, wis_var, am,
#  hcope5, n_train_rows, n_test_rows, n_train_stays, n_test_stays).
ope_results = []
for seed in seeds:
    for shaping in shaping_fname.keys():
        if shaping == 'unshaped':
            shaped = False
            shaping_scalar = 0.0
        else:
            shaped = True
            # e.g. 'avgpotential2-0.5' -> 0.5, a float like the unshaped 0.0.
            shaping_scalar = float(shaping.split('-')[-1])
        np.random.seed(seed)
        test_set = add_traj_return(pd.read_csv(test_set_file.format('unshaped', seed)))
        train_set = add_traj_return(pd.read_csv(train_set_file.format('unshaped', seed)))

        behavior_policy = utils.repair_policy_uniform(joblib.load(behavior_policy_file.format(seed, '')))
        behavior_train_policy = joblib.load(behavior_policy_train_file.format(seed, ''))
        behavior_test_policy = joblib.load(behavior_policy_test_file.format(seed, ''))
        behavior_safe_train = safety.repaired_safe(behavior_train_policy, behavior_train_policy)

        greedy_unsafe = utils.repair_unsupported_greedy_policy(
            joblib.load(greedy_policy_file.format(seed, shaping_fname[shaping])),
            train_set
        )
        greedy_safe = safety.repaired_safe(greedy_unsafe, behavior_train_policy, greedy=True)
        sm_unsafe = joblib.load(sm_policy_file.format(seed, shaping_fname[shaping]))
        sm_safe = safety.repaired_safe(sm_unsafe, behavior_train_policy)

        # (dataset, evaluation policy, behavior policy, *config)
        evaluations = [
#             (train_set, greedy_unsafe, behavior_policy, 'train', 'greedy', shaped, shaping_scalar, 'unsafe', seed),
#             (test_set, greedy_unsafe, behavior_policy, 'test', 'greedy', shaped, shaping_scalar, 'unsafe', seed),
#             (train_set, greedy_safe, behavior_policy, 'train', 'greedy', shaped, shaping_scalar, 'safe', seed),
#             (test_set, greedy_safe, behavior_policy, 'test', 'greedy', shaped, shaping_scalar, 'safe', seed),
            (train_set, sm_unsafe, behavior_policy, 'train', 'softmax', shaped, shaping_scalar, 'unsafe', seed),
            (test_set, sm_unsafe, behavior_policy, 'test', 'softmax', shaped, shaping_scalar, 'unsafe', seed),
            (train_set, sm_safe, behavior_policy, 'train', 'softmax', shaped, shaping_scalar, 'safe', seed),
            (test_set, sm_safe, behavior_policy, 'test', 'softmax', shaped, shaping_scalar, 'safe', seed),
        ]

        if shaping == 'unshaped':
            evaluations += [
                (train_set, behavior_train_policy, behavior_train_policy, 'train', 'observed', shaped, shaping_scalar, 'unsafe', seed),
                (test_set, behavior_test_policy, behavior_test_policy, 'test', 'observed', shaped, shaping_scalar, 'unsafe', seed),
                (train_set, behavior_train_policy, behavior_policy, 'train', 'behavior', shaped, shaping_scalar, 'unsafe', seed),
                (test_set, behavior_train_policy, behavior_policy, 'test', 'behavior', shaped, shaping_scalar, 'unsafe', seed),
                (train_set, behavior_safe_train, behavior_policy, 'train', 'behavior', shaped, shaping_scalar, 'safe', seed),
                (test_set, behavior_safe_train, behavior_policy, 'test', 'behavior', shaped, shaping_scalar, 'safe', seed),
            ]

        # Distinct loop names so the outer `behavior_policy` is never rebound here.
        for ds, pi_e, pi_b, *config in evaluations:
            mean, var, _wis_weights = ope.wis_policy(ds, pi_e, pi_b)
            am = ope.am(ds, pi_e, pi_b, delta=0.05)
            hcope5 = ope.hcope(ds, pi_e, pi_b, delta=0.05, c=5)
            row = (*config, mean, var, am, hcope5,
                   len(train_set), len(test_set),
                   train_set.icustay_id.nunique(), test_set.icustay_id.nunique())
            ope_results.append(row)
            print(','.join(map(str, row)))

train,softmax,True,0.1,unsafe,0,83.80257125640732,0.17476231173206214,227.89137133773562,-497.44346001008415,82798,28261,5239,1778
test,softmax,True,0.1,unsafe,0,61.062608077499114,305.41675624238826,-93.20653478367618,-113.65242122488755,82798,28261,5239,1778
train,softmax,True,0.1,safe,0,83.49142652684543,0.0026085798818080304,-93.1987125870041,-821.7039873500379,82798,28261,5239,1778
test,softmax,True,0.1,safe,0,83.37745994922007,0.0114200075382849,-93.206534790699,-580.7891341360822,82798,28261,5239,1778
train,softmax,True,0.5,unsafe,0,83.80257125640732,0.17476231173206214,227.89137133773562,-497.44346001008415,82798,28261,5239,1778
test,softmax,True,0.5,unsafe,0,61.062608077499114,305.41675624238826,-93.20653478367618,-113.65242122488755,82798,28261,5239,1778
train,softmax,True,0.5,safe,0,83.49142652684543,0.0026085798818080304,-93.1987125870041,-821.7039873500379,82798,28261,5239,1778
test,softmax,True,0.5,safe,0,83.37745994922007,0.0114200075382849,-93.206534790699,-580.78913413

test,softmax,True,0.1,unsafe,2,83.15120596315647,0.15322980314976148,-93.20653477798336,-198.27907243270545,83691,27525,5284,1737
train,softmax,True,0.1,safe,2,84.69377401874522,1.5495686876603851,-93.206534790699,-524.2080364831268,83691,27525,5284,1737
test,softmax,True,0.1,safe,2,80.36802503683238,18.30963952655378,-93.206534790699,-108.95237882027884,83691,27525,5284,1737
train,softmax,True,0.5,unsafe,2,85.61053981288705,1.2484630525198255,800.5883259842884,-1138900.3380706948,83691,27525,5284,1737
test,softmax,True,0.5,unsafe,2,83.15120596315647,0.15322980314976148,-93.20653477798336,-198.27907243270545,83691,27525,5284,1737
train,softmax,True,0.5,safe,2,84.69377401874522,1.5495686876603851,-93.206534790699,-524.2080364831268,83691,27525,5284,1737
test,softmax,True,0.5,safe,2,80.36802503683238,18.30963952655378,-93.206534790699,-108.95237882027884,83691,27525,5284,1737
train,softmax,True,1.0,unsafe,2,84.54277915537219,0.7661032222771464,766.1052076594701,-7730188.950452557,83691,2

train,softmax,True,0.1,safe,4,83.82691904050412,0.13490701330463167,-93.20493479448868,-109.69011438177577,84557,27294,5335,1735
test,softmax,True,0.1,safe,4,83.33864060325473,0.025250273190896372,-93.206534790699,-98.77194559370899,84557,27294,5335,1735
train,softmax,True,0.5,unsafe,4,85.70189997871282,1.8557293461734696,757.0278958262543,-86.92797957500278,84557,27294,5335,1735
test,softmax,True,0.5,unsafe,4,-61.19849213780116,476.92794619918914,-93.20653478647895,-702.1111696290634,84557,27294,5335,1735
train,softmax,True,0.5,safe,4,83.82691904050412,0.13490701330463167,-93.20493479448868,-109.69011438177577,84557,27294,5335,1735
test,softmax,True,0.5,safe,4,83.33864060325473,0.025250273190896372,-93.206534790699,-98.77194559370899,84557,27294,5335,1735
train,softmax,True,1.0,unsafe,4,85.81259769203058,1.6099123561763227,631.1181496060398,-68.59946601319439,84557,27294,5335,1735
test,softmax,True,1.0,unsafe,4,-44.18835152166712,1051.6151400085673,-93.2065347879032,-294.1257223805272

train,softmax,True,0.1,safe,6,83.4513774365393,2.615802602452123e-12,-93.206534790699,-105.99207939792663,83070,28092,5257,1768
test,softmax,True,0.1,safe,6,87.43897379399597,2.1609954840101624,-93.206534790699,-97.7541249846436,83070,28092,5257,1768
train,softmax,True,0.5,unsafe,6,83.45145020308783,1.0479818394761786e-08,262.14337909884097,-36.705154613534404,83070,28092,5257,1768
test,softmax,True,0.5,unsafe,6,-12.65573230381005,2620.7329420452916,-93.20653467919254,-299.54707403021166,83070,28092,5257,1768
train,softmax,True,0.5,safe,6,83.4513774365393,2.615802602452123e-12,-93.206534790699,-105.99207939792663,83070,28092,5257,1768
test,softmax,True,0.5,safe,6,87.43897379399597,2.1609954840101624,-93.206534790699,-97.7541249846436,83070,28092,5257,1768
train,softmax,True,1.0,unsafe,6,83.4514953960318,2.551791717324839e-08,233.01633697674748,-40.84087430575396,83070,28092,5257,1768
test,softmax,True,1.0,unsafe,6,55.10417047931482,1138.0415623687563,-93.20653468105021,-149.77581062889

train,softmax,True,0.1,safe,8,83.45147543603821,1.5400178538557937e-08,-93.206534790699,-122.2265153417232,83553,27912,5277,1768
test,softmax,True,0.1,safe,8,83.44684783952377,4.256023646536525e-05,-93.206534790699,-134.5587320078644,83553,27912,5277,1768
train,softmax,True,0.5,unsafe,8,83.4647007398053,0.00011226653916252348,539.8287658300176,-29.43752767800801,83553,27912,5277,1768
test,softmax,True,0.5,unsafe,8,88.65101051245074,5.413967824078202,-93.20653333560206,-98.75857226223117,83553,27912,5277,1768
train,softmax,True,0.5,safe,8,83.45147543603821,1.5400178538557937e-08,-93.206534790699,-122.2265153417232,83553,27912,5277,1768
test,softmax,True,0.5,safe,8,83.44684783952377,4.256023646536525e-05,-93.206534790699,-134.5587320078644,83553,27912,5277,1768
train,softmax,True,1.0,unsafe,8,83.4662920999599,0.0002216360366286866,540.7048526438462,-32.327763556805344,83553,27912,5277,1768
test,softmax,True,1.0,unsafe,8,89.44522901751493,3.5333474911210603,-93.2065340154166,-100.96203789

In [None]:
%%time
# seeds = range(10)
# Load the unshaped seed-0 datasets and policies used by the HCOPE and analysis
# cells below. A dedicated name is used so the global `seeds` consumed by the
# evaluation sweep is not overwritten (re-running that sweep would otherwise
# silently evaluate only seed 0).
hcope_seeds = (0,)
for seed in hcope_seeds:
    for shaping in ('unshaped',):
        np.random.seed(seed)
        test_set = add_traj_return(pd.read_csv(test_set_file.format(shaping, seed)))
        train_set = add_traj_return(pd.read_csv(train_set_file.format(shaping, seed)))

        behavior_policy = utils.repair_policy_uniform(joblib.load(behavior_policy_file.format(seed, '')))
        behavior_train_policy = joblib.load(behavior_policy_train_file.format(seed, ''))
        behavior_test_policy = joblib.load(behavior_policy_test_file.format(seed, ''))
        behavior_safe_train = safety.repaired_safe(behavior_train_policy, behavior_train_policy)

        greedy_unsafe = utils.repair_unsupported_greedy_policy(
            joblib.load(greedy_policy_file.format(seed, '')),
            train_set
        )
        greedy_safe = safety.repaired_safe(greedy_unsafe, behavior_train_policy, greedy=True)
        sm_unsafe = joblib.load(sm_policy_file.format(seed, ''))
        sm_safe = safety.repaired_safe(sm_unsafe, behavior_train_policy)

        # (dataset, evaluation policy, behavior policy, split, policy, shaping, safety, seed)
        # NOTE(review): built but not consumed in this cell.
        evaluations = (
            (train_set, behavior_train_policy, behavior_train_policy, 'train', 'observed', 'unshaped', 'unsafe', seed),
            (test_set, behavior_test_policy, behavior_test_policy, 'test', 'observed', 'unshaped', 'unsafe', seed),
            (train_set, behavior_train_policy, behavior_policy, 'train', 'behavior', 'unshaped', 'unsafe', seed),
            (test_set, behavior_train_policy, behavior_policy, 'test', 'behavior', 'unshaped', 'unsafe', seed),
            (train_set, behavior_safe_train, behavior_policy, 'train', 'behavior', 'unshaped', 'safe', seed),
            (test_set, behavior_safe_train, behavior_policy, 'test', 'behavior', 'unshaped', 'safe', seed),
            (train_set, greedy_unsafe, behavior_policy, 'train', 'greedy', 'unshaped', 'unsafe', seed),
            (test_set, greedy_unsafe, behavior_policy, 'test', 'greedy', 'unshaped', 'unsafe', seed),
            (train_set, greedy_safe, behavior_policy, 'train', 'greedy', 'unshaped', 'safe', seed),
            (test_set, greedy_safe, behavior_policy, 'test', 'greedy', 'unshaped', 'safe', seed),
            (train_set, sm_unsafe, behavior_policy, 'train', 'softmax', 'unshaped', 'unsafe', seed),
            (test_set, sm_unsafe, behavior_policy, 'test', 'softmax', 'unshaped', 'unsafe', seed),
            (train_set, sm_safe, behavior_policy, 'train', 'softmax', 'unshaped', 'safe', seed),
            (test_set, sm_safe, behavior_policy, 'test', 'softmax', 'unshaped', 'safe', seed)
        )

# HCOPE hyperparameter optimization: sweep of the c parameter

In [None]:
# Candidate values of HCOPE's c hyperparameter.
cs = (.01, .1, 1, 2, 5, 10, 20, 50, 1e2, 1.2e2, 1.5e2, 1.75e2, 1e3, 1.5e3, 1e4, 1e5, 1e6)
us = True  # forwarded as `unscale=` to ope.hcope_prediction
delta = .05
n_post = train_set.icustay_id.nunique()


def _hcope_sweep(evaluation_policy):
    """HCOPE lower-bound predictions on ``test_set`` for ``evaluation_policy``, one per c in ``cs``.

    Uses the notebook-level ``behavior_policy``, ``n_post``, ``delta`` and ``us``.
    """
    return [
        ope.hcope_prediction(test_set, evaluation_policy, behavior_policy,
                             n_post=n_post, c=c, delta=delta, unscale=us)
        for c in cs
    ]


print('sm_unsafe')
sm_unsafe_results = _hcope_sweep(sm_unsafe)
print('sm_safe')
sm_safe_results = _hcope_sweep(sm_safe)
print('observed')
observed_results = _hcope_sweep(behavior_policy)
print('bh')
bh_unsafe_results = _hcope_sweep(behavior_train_policy)
print('bh_safe')
bh_safe_results = _hcope_sweep(behavior_safe_train)
print('greedy_unsafe')
greedy_unsafe_results = _hcope_sweep(greedy_unsafe)
print('greedy_safe')
greedy_safe_results = _hcope_sweep(greedy_safe)

In [None]:
# HCOPE lower bound as a function of c for each policy (log x-axis, symlog y-axis).
algorithms = sorted(
    (
        ('softmax-unsafe', sm_unsafe_results),
        ('softmax-safe', sm_safe_results),
        ('observed', observed_results),
        ('behavior-unsafe', bh_unsafe_results),
        ('behavior-safe', bh_safe_results),
#         ('greedy-unsafe', greedy_unsafe_results),
#         ('greedy-safe', greedy_safe_results),
    ),
    key=operator.itemgetter(0),  # order by label only; never compare result lists
)
colors = sns.color_palette(n_colors=len(algorithms))
fig, ax = plt.subplots()
for color, (label, results) in zip(colors, algorithms):
    # `color=` rather than `c=`: an RGB tuple passed as `c` can be misread as
    # per-point values to colour-map.
    ax.scatter(cs, results, color=color, label=label)
ax.set_xscale('log')
ax.set_yscale('symlog')
ax.set_xlabel('c')
ax.set_ylabel('Lower Bound')
ax.set_title('Optimization of c parameter')
ax.legend(loc=2, bbox_to_anchor=(1.0, 1.0))
# Reference lines at the +/-100 terminal rewards used by add_traj_return.
ax.axhline(-100, c='black', alpha=.2, linestyle='--')
ax.axhline(100, c='black', alpha=.2, linestyle='--')
plt.show()

# Analysis: clinician vs. softmax-policy actions for randomly sampled deceased patients

In [None]:
# Importance weights of the unsafe softmax policy on the training split.
w_sm_unsafe_train = ope.ois_traj_weights(train_set, sm_unsafe, behavior_policy)

# Pick one random training patient with mort90day == 't' and compare, step by
# step, the clinician's action with the softmax policy's most likely action.
# (Unseeded sample: a different patient on each run.)
patient_id = train_set[train_set['mort90day'] == 't'].groupby('icustay_id').first().sample(1).index[0]
patient = train_set[train_set.icustay_id == patient_id]
# One row per time step: the policy's action distribution in that step's state.
action_dists = np.array([sm_unsafe[s, :] for s in patient.state])
c_a_probs = []  # policy probability of the clinician's action at each step
alt_act = []    # policy's highest-probability action at each step
for c_a, pi_a in zip(patient.action_discrete, action_dists):
    c_a_probs.append(pi_a[c_a])
    alt_act.append(pi_a.argmax())
c_a_probs = np.array(c_a_probs)
# NOTE(review): assumes the weights are indexable by the row label of the
# trajectory's first step — confirm against ope.ois_traj_weights.
print("Importance weight:", w_sm_unsafe_train[patient.index[0]])
for i, (c_a, e_a) in enumerate(zip(patient.action_discrete, alt_act)):
    print(i, end='')
    print('\tclincn', utils.to_action_ranges(c_a))
    print('\tevaltn', utils.to_action_ranges(e_a))

In [None]:
# Another random patient with mort90day == 't' (unseeded sample), reusing
# w_sm_unsafe_train from the previous cell.
patient_id = train_set[train_set['mort90day'] == 't'].groupby('icustay_id').first().sample(1).index[0]
patient = train_set[train_set.icustay_id == patient_id]
# One row per time step: the policy's action distribution in that step's state.
action_dists = np.array([sm_unsafe[s, :] for s in patient.state])
c_a_probs = []  # policy probability of the clinician's action at each step
alt_act = []    # policy's highest-probability action at each step
for c_a, pi_a in zip(patient.action_discrete, action_dists):
    c_a_probs.append(pi_a[c_a])
    alt_act.append(pi_a.argmax())
c_a_probs = np.array(c_a_probs)
print("Importance weight:", w_sm_unsafe_train[patient.index[0]])
for i, (c_a, e_a) in enumerate(zip(patient.action_discrete, alt_act)):
    print(i, end='')
    print('\tclincn', utils.to_action_ranges(c_a))
    print('\tevaltn', utils.to_action_ranges(e_a))

In [None]:
# Another random patient with mort90day == 't' (unseeded sample), this time also
# showing the policy's probability of each clinician action.
patient_id = train_set[train_set['mort90day'] == 't'].groupby('icustay_id').first().sample(1).index[0]
patient = train_set[train_set.icustay_id == patient_id]
# One row per time step: the policy's action distribution in that step's state.
action_dists = np.array([sm_unsafe[s, :] for s in patient.state])
c_a_probs = []  # policy probability of the clinician's action at each step
alt_act = []    # policy's highest-probability action at each step
for c_a, pi_a in zip(patient.action_discrete, action_dists):
    c_a_probs.append(pi_a[c_a])
    alt_act.append(pi_a.argmax())
c_a_probs = np.array(c_a_probs)
print("Weight:", w_sm_unsafe_train[patient.index[0]])
for i, (c_a, e_a) in enumerate(zip(patient.action_discrete, alt_act)):
    print(i, end='')
    print('\tclincn', utils.to_action_ranges(c_a), c_a_probs[i])
    print('\tevaltn', utils.to_action_ranges(e_a))

# Safety analysis

In [None]:
# Restrict the behavior policy to actions that are both compliant and observed.
n_actions = 7**3
compliant_behavior_score = behavior_policy.copy()
noncompliant = [a for a in range(n_actions) if not safety.action_compliance_map[a]]
compliant_behavior_score[:, noncompliant] = float('-inf')

# Actions with probability exactly 0 are excluded too.
compliant_behavior_score[compliant_behavior_score == 0.0] = float('-inf')
# NOTE(review): softmax is applied to probabilities, not logits, so the result is
# exp(p) / sum(exp(p)) over the remaining actions rather than the behavior policy
# renormalized — confirm this is intended. Rows with every action excluded
# come out as NaN.
compliant_behavior_policy = scipy.special.softmax(compliant_behavior_score, axis=1)

assert compliant_behavior_policy.shape == (650, n_actions)
print(compliant_behavior_policy.sum(axis=1).min(), compliant_behavior_policy.sum(axis=1).max())

In [None]:
# Number of (state, action) entries where the compliant policy puts mass on an
# action the behavior policy gives zero probability (expected 0). Shown as a
# count instead of dumping the full boolean matrix.
int(((compliant_behavior_policy > 0.0) & (behavior_policy == 0.0)).sum())

In [None]:
# Sanity check: does the compliant policy ever use an action the behavior policy
# never takes? Zero-probability entries were set to -inf before the softmax, so False is expected.
np.logical_and(behavior_policy == 0.0, compliant_behavior_policy > 0.0).any()

In [None]:
# WIS estimate of the compliant behavior policy on both splits, plus how many
# importance weights exceed a small floor.
weight_floor = 0.001
print("compliant behavior, WIS")
train_mean, train_var, train_weights = ope.wis_policy(train_set, compliant_behavior_policy, behavior_policy)
n_train_above_floor = (train_weights > weight_floor).sum()
print("pi_b on train:", (train_mean, train_var, n_train_above_floor))
test_mean, test_var, test_weights = ope.wis_policy(test_set, compliant_behavior_policy, behavior_policy)
n_test_above_floor = (test_weights > weight_floor).sum()
print("pi_b on test :", (test_mean, test_var, n_test_above_floor))

In [None]:
# Mean clinical state-compliance score over the training set.
state_compliance = safety.state_compliance_clinical(train_set, safety.avg_clinical_timestep)
state_compliance.mean()

In [None]:
# Mean clinical action-compliance score over the training set.
action_compliance = safety.action_compliance_clinical(train_set)
action_compliance.mean()

In [None]:
# Individual compliance checks, each applied to the whole training set.
compliance_fs = [
    safety.tv_compliance_clinical,
    safety.rr_compliance_clinical,
    safety.spo2_compliance_clinical,
    safety.pplat_compliance_clinical,
    safety.ph_compliance_clinical,
]
compliance_scores = list(map(lambda check: check(train_set), compliance_fs))

In [None]:
# Mean score of each compliance check, labelled with the check's name so the
# output is readable without cross-referencing the list order.
for check, scores in zip(compliance_fs, compliance_scores):
    print(getattr(check, '__name__', repr(check)), scores.mean())

In [None]:
# Range and centre of the derived tidal-volume column on the training set.
tv_series = train_set.tv_derived
tv_series.min(), tv_series.max(), tv_series.mean(), tv_series.median()

In [None]:
shaping = 'avgpotential-1.0'
# Policy-file suffix: '' for unshaped, otherwise '_' + shaping name. Derived
# directly rather than via shaping_fname, which only lists the variants in the
# current sweep, so a lookup for other variants raises KeyError.
shaping_suffix = '' if shaping == 'unshaped' else '_' + shaping
seed = 0
np.random.seed(seed)
test_set = add_traj_return(pd.read_csv(test_set_file.format('unshaped', seed)))
train_set = add_traj_return(pd.read_csv(train_set_file.format('unshaped', seed)))

behavior_policy = utils.repair_policy_uniform(joblib.load(behavior_policy_file.format(seed, '')))
behavior_train_policy = joblib.load(behavior_policy_train_file.format(seed, ''))
behavior_test_policy = joblib.load(behavior_policy_test_file.format(seed, ''))
behavior_safe_train = safety.repaired_safe(behavior_train_policy, behavior_train_policy)

greedy_unsafe = utils.repair_unsupported_greedy_policy(
    joblib.load(greedy_policy_file.format(seed, shaping_suffix)),
    train_set
)
greedy_safe = safety.repaired_safe(greedy_unsafe, behavior_train_policy, greedy=True)
sm_unsafe = joblib.load(sm_policy_file.format(seed, shaping_suffix))
sm_safe = safety.repaired_safe(sm_unsafe, behavior_train_policy)

# WIS estimate of the unsafe softmax policy on the test split.
mean, var, _ = ope.wis_policy(test_set, sm_unsafe, behavior_policy)
mean, var

In [None]:
# WIS estimate (mean, variance) of the unsafe softmax policy on the training split.
mean, var, _unused_weights = ope.wis_policy(train_set, sm_unsafe, behavior_policy)
(mean, var)

In [None]:
# NOTE(review): identical to the previous cell (same split, same policy) —
# possibly meant to evaluate a different policy/split; confirm or remove.
mean, var, _unused_weights = ope.wis_policy(train_set, sm_unsafe, behavior_policy)
(mean, var)

In [None]:
# WIS estimate (mean, variance) of the unsafe greedy policy on the test split.
mean, var, _unused_weights = ope.wis_policy(test_set, greedy_unsafe, behavior_policy)
(mean, var)

In [None]:
shaping = 'unshaped'
# Policy-file suffix: '' for unshaped, otherwise '_' + shaping name. Derived
# directly because 'unshaped' is not a key of shaping_fname in the current sweep.
shaping_suffix = '' if shaping == 'unshaped' else '_' + shaping
seed = 0
np.random.seed(seed)
test_set = add_traj_return(pd.read_csv(test_set_file.format('unshaped', seed)))
train_set = add_traj_return(pd.read_csv(train_set_file.format('unshaped', seed)))

behavior_policy = utils.repair_policy_uniform(joblib.load(behavior_policy_file.format(seed, '')))
behavior_train_policy = joblib.load(behavior_policy_train_file.format(seed, ''))
behavior_test_policy = joblib.load(behavior_policy_test_file.format(seed, ''))
behavior_safe_train = safety.repaired_safe(behavior_train_policy, behavior_train_policy)

greedy_unsafe = utils.repair_unsupported_greedy_policy(
    joblib.load(greedy_policy_file.format(seed, shaping_suffix)),
    train_set
)
greedy_safe = safety.repaired_safe(greedy_unsafe, behavior_train_policy, greedy=True)
sm_unsafe = joblib.load(sm_policy_file.format(seed, shaping_suffix))
sm_safe = safety.repaired_safe(sm_unsafe, behavior_train_policy)

# WIS estimate of the unsafe softmax policy on the test split.
mean, var, _ = ope.wis_policy(test_set, sm_unsafe, behavior_policy)
mean, var

In [None]:
# WIS estimate (mean, variance) of the unsafe softmax policy on the training split.
mean, var, _unused_weights = ope.wis_policy(train_set, sm_unsafe, behavior_policy)
(mean, var)