In [1]:
import pandas as pd
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
import os, pickle
import matplotlib.pyplot as plt
import tensorflow as tf

from sklearn.cluster import KMeans

from sklearn.preprocessing import OneHotEncoder

from itertools import product

import sys, os

import trajectory as T                      # trajectory generation
import optimizer as O                       # stochastic gradient descent optimizer
import solver as S                          # MDP solver (value-iteration)
import plot as P


num_data = 355504


np.random.seed(66)

def to_interval(istr):
    """Parse an interval string such as ``"[start, end)"`` into a ``pd.Interval``.

    The first and last characters decide which sides are closed
    (``[``/``]`` closed, anything else open); the text in between is split
    on ``,`` and each half is converted with ``pd.to_datetime``.
    """
    opens_closed = istr[0] == '['
    ends_closed = istr[-1] == ']'

    if opens_closed and ends_closed:
        closed = 'both'
    elif opens_closed:
        closed = 'left'
    elif ends_closed:
        closed = 'right'
    else:
        closed = 'neither'

    # Exactly two comma-separated endpoints are expected; anything else fails to unpack.
    left, right = [pd.to_datetime(endpoint) for endpoint in istr[1:-1].split(',')]
    return pd.Interval(left, right, closed)

# Only referenced in a commented-out line of the loading cell.
re_split = False
# Fractions of stays used for the three splits, consumed in the order
# (train, test, validation) by the loading cell.
frac = [0.4,0.2,0.4]
# Compare with a tolerance: a sum of decimal fractions is generally not exactly 1.0
# in floating point, so `== 1` can reject perfectly valid splits (e.g. [0.7, 0.2, 0.1]).
assert np.isclose(np.sum(frac), 1.0), "split fractions must sum to 1"
# Turn the fractions into cumulative cut points, e.g. [0.4, 0.6, 1.0].
frac = np.cumsum(frac)
print (frac)
data_save_path= 'data/'

def sliding(gs, window_size = 6):
    """Build zero-padded look-back windows for every 2-D array in ``gs``.

    Each array ``g`` of shape ``(T, F)`` is prefixed with ``window_size - 1``
    rows of zeros so that every original row gets a full window ending at
    that row. The per-array results, each of shape ``(T, window_size, F)``,
    are stacked along the first axis.
    """
    stacked_windows = []
    for group in gs:
        n_features = group.shape[1]
        padding = np.zeros([window_size - 1, n_features])
        padded = np.concatenate([padding, group])
        # The window spans all feature columns, so axis 1 of the view has length 1.
        windows = sliding_window_view(padded, (window_size, n_features))
        stacked_windows.append(windows.squeeze(1))
    return np.vstack(stacked_windows)

2023-11-28 19:54:18.636804: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


[0.4 0.6 1. ]


# LOADING THE DATA

In [2]:
# if re_split:


# Load the aggregated table; CSV column 1 (presumably the `time` interval string — TODO confirm)
# is parsed with to_interval, then rows are indexed by (stay_id, time) and sorted.
aggr_df = pd.read_csv('mimic_iv_hypotensive_cut2.csv',sep = ',', header = 0,converters={1:to_interval}).set_index(['stay_id','time']).sort_index()
# create action bins (four actions in total)
# action = 2*bolus + vaso  ->  0: none, 1: vaso only, 2: bolus only, 3: vaso + bolus
aggr_df['action'] = aggr_df['bolus(binary)']*2 + aggr_df['vaso(binary)']
# Shuffle the unique stay ids so the split is made per stay rather than per row.
# The permutation comes from the global NumPy RNG seeded in the first cell.
all_idx = np.random.permutation(aggr_df.index.get_level_values(0).unique())
# `frac` holds cumulative cut points: stays up to frac[0] -> train,
# frac[0]..frac[1] -> test, the remainder -> validation.
train_df = aggr_df.loc[all_idx[:int(len(all_idx)*frac[0])]].sort_index()
test_df = aggr_df.loc[all_idx[int(len(all_idx)*frac[0]):int(len(all_idx)*frac[1])]].sort_index()
valid_df = aggr_df.loc[all_idx[int(len(all_idx)*frac[1]):]].sort_index()
# print (np.unique(train_df['action'],return_counts=True)[1]*1./len(train_df))
# pickle.dump([train_df, test_df, valid_df], open(data_save_path+'processed_mimic_hyp_2.pkl','wb'))
# Treatment columns that are already summarised by `action`; removed from the splits in the cleaning cell.
drop_columns = ['vaso(amount)','bolus(amount)',\
            'any_treatment(binary)','vaso(binary)','bolus(binary)']




# LOOKING AT THE DATA

In [3]:
## print(len(aggr_df))
aggr_df.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,creatinine,fraction_inspired_oxygen,lactate,urine_output,alanine_aminotransferase,asparate_aminotransferase,mean_blood_pressure,diastolic_blood_pressure,systolic_blood_pressure,gcs,partial_pressure_of_oxygen,heart_rate,temperature,respiratory_rate,vaso(binary),vaso(amount),bolus(binary),bolus(amount),any_treatment(binary),action
stay_id,time,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
30001446,"[2186-04-12 03:49:00, 2186-04-12 04:49:00)",,,,,,,61.0,56.0,75.0,,,82.0,36.722222,22.0,0.0,0.0,1.0,1000.0,1.0,2.0
30001446,"[2186-04-12 04:49:00, 2186-04-12 05:49:00)",2.7,,,40.0,38.0,114.0,63.0,46.0,96.0,15.0,,80.0,,22.0,1.0,0.077612,0.0,0.0,1.0,1.0
30001446,"[2186-04-12 05:49:00, 2186-04-12 06:49:00)",,,,100.0,,,99.0,75.0,152.0,,,79.0,,19.0,1.0,0.238806,0.0,0.0,1.0,1.0
30001446,"[2186-04-12 06:49:00, 2186-04-12 07:49:00)",,,1.7,,,,72.0,55.5,107.0,,98.0,83.0,,18.0,1.0,0.166559,0.0,0.0,1.0,1.0
30001446,"[2186-04-12 07:49:00, 2186-04-12 08:49:00)",,,,40.0,,,70.0,53.0,105.0,,,77.0,35.944444,21.0,1.0,0.172458,0.0,0.0,1.0,1.0


# Data for patients with hypotension. Two treatments are recorded: vasopressors and a bolus dose of epinephrine, given depending on certain features of the patient. In order to do IRL we need to discretize the action space:
* Action 0 = No treatment
* Action 1 = Vaso
* Action 2 = Bolus
* Action 3 = Vaso + Bolus



# Data cleaning for taking care of missing values

In [4]:
# for now drop indicators about bolus and vaso
# errors='ignore' keeps the cell re-runnable once the columns are already gone.
train_df = train_df.drop(columns=drop_columns, errors='ignore')
test_df = test_df.drop(columns=drop_columns, errors='ignore')
valid_df = valid_df.drop(columns=drop_columns, errors='ignore')

#### imputation
# One fallback value per feature, indexed by feature name.
impute_table = pd.read_csv('mimic_iv_hypotensive_cut2_impute_table.csv',sep=',',header=0).set_index(['feature'])

# Forward-fill *within each stay only*. A plain frame-wide ffill would copy the
# last measurements of one patient into the leading rows of the next patient
# in the sorted frame. groupby(...).ffill() keeps the original index and row order.
train_df = train_df.groupby(level='stay_id').ffill()
test_df = test_df.groupby(level='stay_id').ffill()
valid_df = valid_df.groupby(level='stay_id').ffill()

# Whatever is still missing (no earlier measurement in the same stay) falls
# back to the per-feature value from the impute table.
for f in impute_table.index:
    fill_value = impute_table.loc[f].values[0]
    train_df[f] = train_df[f].fillna(value = fill_value)
    test_df[f] = test_df[f].fillna(value = fill_value)
    valid_df[f] = valid_df[f].fillna(value = fill_value)

# Add observational and action ambiguity

In [5]:
train_df.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,creatinine,fraction_inspired_oxygen,lactate,urine_output,alanine_aminotransferase,asparate_aminotransferase,mean_blood_pressure,diastolic_blood_pressure,systolic_blood_pressure,gcs,partial_pressure_of_oxygen,heart_rate,temperature,respiratory_rate,action
stay_id,time,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
30004811,"[2139-10-06 10:40:29, 2139-10-06 11:40:29)",1.0,0.21,1.8,80.0,34.0,40.0,77.0,59.0,118.0,11.0,112.0,86.0,37.0,19.0,0.0
30004811,"[2139-10-06 11:40:29, 2139-10-06 12:40:29)",1.0,0.21,1.8,80.0,34.0,40.0,77.0,59.0,118.0,11.0,112.0,86.0,37.0,19.0,0.0
30004811,"[2139-10-06 12:40:29, 2139-10-06 13:40:29)",1.0,0.21,3.0,80.0,34.0,40.0,77.0,59.0,118.0,11.0,272.0,86.0,37.0,19.0,0.0
30004811,"[2139-10-06 13:40:29, 2139-10-06 14:40:29)",1.0,0.21,3.0,80.0,34.0,40.0,77.0,59.0,118.0,11.0,272.0,86.0,37.0,19.0,0.0
30004811,"[2139-10-06 14:40:29, 2139-10-06 15:40:29)",1.0,0.21,3.0,80.0,34.0,40.0,77.0,59.0,118.0,11.0,272.0,86.0,37.0,19.0,0.0


In [6]:
test_df.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,creatinine,fraction_inspired_oxygen,lactate,urine_output,alanine_aminotransferase,asparate_aminotransferase,mean_blood_pressure,diastolic_blood_pressure,systolic_blood_pressure,gcs,partial_pressure_of_oxygen,heart_rate,temperature,respiratory_rate,action
stay_id,time,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
30009339,"[2145-05-09 23:35:00, 2145-05-10 00:35:00)",1.0,0.21,3.4,80.0,34.0,40.0,77.0,59.0,118.0,11.0,340.0,86.0,37.0,19.0,0.0
30009339,"[2145-05-10 00:35:00, 2145-05-10 01:35:00)",1.0,0.21,3.4,80.0,34.0,40.0,77.0,59.0,118.0,11.0,372.0,86.0,37.0,19.0,0.0
30009339,"[2145-05-10 01:35:00, 2145-05-10 02:35:00)",1.0,0.21,3.4,80.0,34.0,40.0,77.0,59.0,118.0,11.0,360.0,86.0,37.0,19.0,0.0
30009339,"[2145-05-10 02:35:00, 2145-05-10 03:35:00)",1.0,0.21,3.4,80.0,34.0,40.0,77.0,59.0,118.0,11.0,449.0,86.0,37.0,19.0,0.0
30009339,"[2145-05-10 03:35:00, 2145-05-10 04:35:00)",1.0,0.21,4.6,80.0,34.0,40.0,77.0,59.0,118.0,11.0,345.0,86.0,37.0,19.0,0.0


In [7]:
valid_df.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,creatinine,fraction_inspired_oxygen,lactate,urine_output,alanine_aminotransferase,asparate_aminotransferase,mean_blood_pressure,diastolic_blood_pressure,systolic_blood_pressure,gcs,partial_pressure_of_oxygen,heart_rate,temperature,respiratory_rate,action
stay_id,time,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
30001446,"[2186-04-12 03:49:00, 2186-04-12 04:49:00)",1.0,0.21,1.8,80.0,34.0,40.0,61.0,56.0,75.0,11.0,112.0,82.0,36.722222,22.0,2.0
30001446,"[2186-04-12 04:49:00, 2186-04-12 05:49:00)",2.7,0.21,1.8,40.0,38.0,114.0,63.0,46.0,96.0,15.0,112.0,80.0,36.722222,22.0,1.0
30001446,"[2186-04-12 05:49:00, 2186-04-12 06:49:00)",2.7,0.21,1.8,100.0,38.0,114.0,99.0,75.0,152.0,15.0,112.0,79.0,36.722222,19.0,1.0
30001446,"[2186-04-12 06:49:00, 2186-04-12 07:49:00)",2.7,0.21,1.7,100.0,38.0,114.0,72.0,55.5,107.0,15.0,98.0,83.0,36.722222,18.0,1.0
30001446,"[2186-04-12 07:49:00, 2186-04-12 08:49:00)",2.7,0.21,1.7,40.0,38.0,114.0,70.0,53.0,105.0,15.0,98.0,77.0,35.944444,21.0,1.0


In [8]:
import suboptimality as SO

## UNCOMMENT TO ADD OBSERVATIONAL AMBIGUITY

In [9]:
# suboptimal_features = ['creatinine', 'fraction_inspired_oxygen', 'lactate', 'urine_output',
#                   'alanine_aminotransferase', 'asparate_aminotransferase',
#                   'mean_blood_pressure', 'diastolic_blood_pressure',
#                   'systolic_blood_pressure', 'gcs', 'partial_pressure_of_oxygen', 
#                   'heart_rate', 'temperature', 'respiratory_rate']

# train_df = SO.observation_ambiguity(train_df, suboptimal_features, 0.1, 0.3)
# test_df = SO.observation_ambiguity(test_df, suboptimal_features, 0.1, 0.3)
# valid_df = SO.observation_ambiguity(valid_df, suboptimal_features, 0.1, 0.3)

## UNCOMMENT TO ADD ACTION AMBIGUITY

In [None]:
# Pass each split through SO.action_ambiguity with parameter 0.3
# (presumably the fraction/probability of perturbed actions — TODO confirm in suboptimality.py).
# NOTE: every frame is overwritten with the function's result, so running this
# cell twice feeds already-processed data back into the function.
train_df = SO.action_ambiguity(train_df, 0.3)
test_df = SO.action_ambiguity(test_df, 0.3)
valid_df = SO.action_ambiguity(valid_df, 0.3)

In [None]:
train_df.head()

In [None]:
test_df.head()

In [None]:
valid_df.head()

# Normalize data

In [None]:
# Snapshot of the data before normalisation, stacked in [train, valid, test] order.
# NOTE(review): .head(num_data) keeps at most `num_data` rows. The later assignment of
# kmeans.labels_ to this frame needs one label per row, so it only works if num_data >= the
# total row count or if the commented-out head(num_data) lines for the other frames are enabled too.
data_non_normalized_df = pd.concat([train_df, valid_df, test_df], axis=0, ignore_index=False).head(num_data).copy()


#### standard normalization ####
# NOTE(review): heart_rate, temperature and respiratory_rate are not in this list, so they
# enter the clustering un-scaled — confirm this is intended.
normalize_features = ['creatinine', 'fraction_inspired_oxygen', 'lactate', 'urine_output',
                  'alanine_aminotransferase', 'asparate_aminotransferase',
                  'mean_blood_pressure', 'diastolic_blood_pressure',
                  'systolic_blood_pressure', 'gcs', 'partial_pressure_of_oxygen']
# Statistics come from the training split only and are reused for test/validation.
# They must be computed before train_df is overwritten on the next line.
mu, std = (train_df[normalize_features]).mean().values,(train_df[normalize_features]).std().values
# The splits are normalised in place: re-running this cell normalises already-normalised data.
train_df[normalize_features] = (train_df[normalize_features] - mu)/std
test_df[normalize_features] = (test_df[normalize_features] - mu)/std
valid_df[normalize_features] = (valid_df[normalize_features] - mu)/std




### create data matrix ####
# X_* hold every column except 'action'; y_* hold the 'action' column.
X_train = train_df.loc[:,train_df.columns!='action']
y_train = train_df['action']

X_test = test_df.loc[:,test_df.columns!='action']
y_test = test_df['action']

X_valid = valid_df.loc[:, valid_df.columns!='action']
y_valid = valid_df['action']

In [None]:
# Stack the splits in [train, valid, test] order with a fresh 0..n-1 index.
# data_df and data_non_normalized_df are concatenated in the same order, which is what
# allows the positional k-means labels to be copied onto all three frames later.
X_df = pd.concat([X_train, X_valid, X_test], axis=0, ignore_index=True).copy()
y_df = pd.concat([y_train, y_valid, y_test], axis=0, ignore_index=True).copy()


In [None]:
# Normalised features plus 'action', stacked in [train, valid, test] order and keeping
# the (stay_id, time) index so rows can later be grouped into per-stay trajectories.
data_df = pd.concat([train_df, valid_df, test_df], axis=0, ignore_index=False).copy()
# Optional sub-sampling to the first `num_data` rows (disabled). If enabled, all three
# lines must be enabled together so the frames keep the same length.
# data_df = data_df.head(num_data).copy()
# X_df = X_df.head(num_data).copy()
# y_df = y_df.head(num_data).copy()


In [None]:
print(len(data_df))
print(len(X_df))
print(len(y_df))
print(len(data_non_normalized_df))

# Normalized version of the data

In [None]:
data_df.head()

# Unnormalized version of the data

In [None]:
data_non_normalized_df.head()

# Matrix form of the data (Normalized and features only)

In [None]:
X_df.head()

# Corresponding output data for training BC (corresponding treatments for each data point in X_df)

In [None]:
y_df.head()

# Clustering the feature space to extract a discrete state space from the clusters

In [None]:
# Number of discrete states; every cluster of the feature space becomes one MDP state.
num_clusters = 100
kmeans = KMeans(n_clusters= num_clusters , random_state=0)
# Fit on the feature columns only. A later cell adds a 'cluster' column to X_df;
# dropping it here (if present) keeps a re-run of this cell from clustering on
# its own previous labels. On a fresh kernel this is identical to fitting X_df.
kmeans.fit(X_df.drop(columns=['cluster'], errors='ignore'))

In [None]:
# Looking at the values counts for each cluster

np.unique(kmeans.labels_, return_counts = True)[1]

In [None]:
# Assigning each data point to a cluster

# kmeans.labels_ is ordered like the rows of X_df. The same labels are copied positionally
# onto data_df and data_non_normalized_df, which is only valid because all three frames
# were concatenated in [train, valid, test] order and have the same number of rows.
X_df['cluster'] = kmeans.labels_.copy()
data_df['cluster'] = kmeans.labels_.copy()
data_non_normalized_df['cluster'] = kmeans.labels_.copy()

In [None]:
print(X_df['cluster'].unique())

In [None]:
data_df.head()

In [None]:
X_df.head()

In [None]:
data_non_normalized_df.head()

## Add Static and Dynamic Occlusion

## UNCOMMENT TO ADD STATIC OCCLUSION

In [None]:
# Pass data_df through SO.static_occlusion with the list [0, 10, ..., 90]
# (presumably the cluster/state ids to occlude — TODO confirm in suboptimality.py).
# NOTE(review): only data_df is replaced here; X_df and data_non_normalized_df keep their
# original 'cluster' column. Re-running feeds the already-processed frame back in.
data_df = SO.static_occlusion(data_df, [0, 10, 20, 30, 40, 50, 60, 70, 80, 90])

## UNCOMMENT TO ADD DYNAMIC OCCLUSION

In [None]:
# Pass data_df through SO.dynamic_occlusion with parameter 0.3
# (presumably an occlusion probability/fraction — TODO confirm in suboptimality.py).
# NOTE(review): only data_df is replaced; re-running feeds the already-processed frame back in.
data_df = SO.dynamic_occlusion(data_df, 0.3)

# Converting the data into trajectories to input to an IRL algorithm. Note that this is the same trajectory format we used for HW1 and HW2.

In [None]:
# Stay ids in order of first appearance (kept for inspection).
unique_stay_ids = data_df.index.get_level_values('stay_id').unique()

# One trajectory per stay: a list of (state, action, next_state) transitions
# between consecutive time steps, wrapped in T.Trajectory.
trajectories = []

# A single groupby pass visits every stay once, instead of two `.loc[stay_id]`
# look-ups per stay on the (unsorted) MultiIndex. sort=False keeps the same
# stay order as `unique_stay_ids`.
for stay_id, stay_df in data_df.groupby(level='stay_id', sort=False):
    # Plain arrays give unambiguous positional access; integer `[]` on a Series
    # with a non-integer index relies on a deprecated positional fallback.
    states = stay_df['cluster'].to_numpy()
    actions = stay_df['action'].to_numpy()

    # A stay with a single time step has no transition. Skip it rather than
    # create an empty trajectory, which breaks the terminal-state extraction
    # that reads each trajectory's last transition.
    if len(states) < 2:
        continue

    trajectory = [
        (states[i], int(actions[i]), states[i + 1])
        for i in range(len(states) - 1)
    ]
    trajectories.append(T.Trajectory(trajectory))

We need to store all possible terminal states from the trajectories list. (Needed to calculate the normalizing constant in MaxEnt)

In [None]:
# Distinct final states over all trajectories: the next_state (last element) of each
# trajectory's last transition. Trajectories without any transition are skipped,
# because indexing their last transition would raise IndexError.
terminal_states = list({
    traj._t[-1][-1]
    for traj in trajectories
    if len(traj._t) > 0
})

# Distribution of the treatments given in our data. (Most of the time no treatment is given; this might vary depending on how you cluster the data.)

In [None]:
y_df.value_counts()

# BC Policy example

In [None]:
def _dense_one_hot_encoder(categories):
    """Return a OneHotEncoder producing dense arrays over the fixed `categories`.

    scikit-learn renamed the `sparse` argument to `sparse_output` (1.2) and later
    removed `sparse`. Try the new name first and fall back to the old one so the
    cell runs on either side of the rename.
    """
    try:
        return OneHotEncoder(sparse_output=False, categories=[categories])
    except TypeError:
        return OneHotEncoder(sparse=False, categories=[categories])


# Convert states and actions to one-hot encoding
state_encoder = _dense_one_hot_encoder(np.arange(num_clusters))
action_encoder = _dense_one_hot_encoder(np.arange(4))


states_onehot = state_encoder.fit_transform(X_df['cluster'].to_numpy().reshape(-1, 1))
actions_onehot = action_encoder.fit_transform(y_df.to_numpy().reshape(-1, 1))


# Define neural network architecture: one-hot state in, action distribution out
model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(states_onehot.shape[1],)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(actions_onehot.shape[1], activation='softmax')  # Output layer with softmax for discrete actions
])

# Compile the model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics= ['accuracy'])

# Train the model on every row of X_df / y_df (train + validation + test splits stacked)
model.fit(states_onehot, actions_onehot,  epochs=5, batch_size=128)

# Evaluate the model. The evaluation data is exactly the data the model was fitted on,
# so these are training metrics, not held-out test metrics, and are labelled as such.
train_loss, train_accuracy = model.evaluate(states_onehot, actions_onehot)
# Kept under the original name for any later cell that still reads it.
test_loss = [train_loss, train_accuracy]
print("Train Loss:", train_loss, "Train Accuracy:", train_accuracy)


In [None]:
# Greedy BC action per state: one-hot encode every state id 0..num_clusters-1, run the
# network on all of them and take the arg-max over the predicted action probabilities.
bc_policy = np.argmax(model.predict(state_encoder.transform(np.arange(num_clusters).reshape(-1, 1))), axis =1)
bc_policy

In [None]:
# Default plot code

# fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(32, 16))

# for i, ax in enumerate(axes.flatten()):
#     data_df[data_df['cluster'] == i]['action'].value_counts().plot(kind = 'bar', ax = ax, title = 'Cluster: ' + str(i+ 1))
#     ax.set_ylabel('Counts')
    
fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(32, 16))

# Slice i of every pie must be action i, so counts are built per action index.
labels = ['No action', 'Vaso', 'Bolus', 'Vaso + Bolus']

# One pie per subplot; subplot i shows cluster index i (titled 1-based as "Cluster i+1").
for i, ax in enumerate(axes.flatten()):
    cluster_actions = data_df.loc[data_df['cluster'] == i, 'action']
    # value_counts() orders by frequency and omits unseen actions, which pairs the
    # counts with the wrong labels. Count each action explicitly instead.
    policy = [int((cluster_actions == action).sum()) for action in range(len(labels))]
    print(policy)

    ax.set_title('Cluster ' + str(i+ 1), fontsize=25)
    if sum(policy) == 0:
        # Nothing to draw for a cluster without rows; an all-zero pie is undefined.
        ax.axis('off')
        continue
    ax.pie(policy, labels=labels, autopct='%1.1f%%', textprops={'fontsize': 17})

fig.suptitle("BC Policy", fontsize=50)

# Estimating the Transition Dynamics using the MLE (feel free to play around with the smoothing_value)

In [None]:
smoothing_value = 1

# Count table indexed as [state, next_state, action]; every entry starts at the
# smoothing value so unseen transitions keep non-zero probability.
p_transition = np.full((num_clusters, num_clusters, 4), float(smoothing_value))

# Add one count for each observed (state, action, next_state) transition.
for traj in trajectories:
    for state, action, next_state in traj._t:
        p_transition[state, next_state, action] += 1

# Normalise over the next-state axis so each (state, action) pair sums to 1.
p_transition = p_transition / p_transition.sum(axis=1, keepdims=True)

In [None]:
p_transition

# Max Causal Entropy

# Feel free to play with the discount factor

In [None]:
discount = 0.9

In [None]:
from maxent import irl, irl_causal

In [None]:
# set up features: one feature vector per state, namely the one-hot encoding of the
# state id, so `features` has one row per cluster.
features = state_encoder.transform(np.arange(num_clusters).reshape(-1, 1))

# choose our parameter initialization strategy:
#   initialize parameters with the constant 1.0
init = O.Constant(1.0)

# choose our optimization strategy:
#   exponentiated stochastic gradient ascent with linear learning-rate decay, starting at 0.2
optim = O.ExpSga(lr=O.linear_decay(lr0=0.2))

# actually do some inverse reinforcement learning
# (non-causal variant, disabled; NOTE(review): the import cell provides `irl`, not `maxent_irl`)
# reward_maxent = maxent_irl(p_transition, features, terminal_states, trajectories, optim, init, eps= 1e-3)

# Maximum causal entropy IRL with the estimated transition model and the module-level discount.
# eps / eps_svf / eps_lap are presumably convergence thresholds — TODO confirm in maxent.py.
reward_maxent_causal = irl_causal(p_transition, features, terminal_states, trajectories, optim, init, discount,
               eps=1e-3, eps_svf=1e-4, eps_lap=1e-4)

In [None]:
reward_maxent_causal

In [None]:
# Min-max scale the learned rewards to [0, 1] for plotting.
v = reward_maxent_causal
reward_span = v.max() - v.min()
if reward_span > 0:
    normalized_reward_maxent_causal = (v - v.min()) / reward_span
else:
    # All rewards identical: min-max scaling would divide by zero and yield NaNs.
    normalized_reward_maxent_causal = np.zeros_like(v, dtype=float)
print(normalized_reward_maxent_causal)

In [None]:
# Lay the per-state rewards out on the smallest square grid that holds them
# (10 x 10 for 100 states). Unused cells, if any, stay NaN and are drawn blank.
flat_rewards = np.ravel(normalized_reward_maxent_causal)
grid_side = int(np.ceil(np.sqrt(flat_rewards.size)))
padded_rewards = np.full(grid_side * grid_side, np.nan)
padded_rewards[:flat_rewards.size] = flat_rewards
reshaped_rewards = padded_rewards.reshape((grid_side, grid_side))

fig, ax = plt.subplots(figsize=(8,8))
im = ax.imshow(reshaped_rewards, cmap='inferno')
ax.set_title("MaxEnt Causal Reward Heatmap", fontsize = 20)
colorbar = fig.colorbar(im, fraction=0.046, pad=0.04)
colorbar.set_label(label = "Normalized Values", fontsize = 15)

In [None]:
# Number of rows of X_df assigned to each cluster, as a float array of length
# num_clusters. Sized from num_clusters rather than a literal 100, and counted
# in one pass instead of one boolean scan of the frame per cluster.
cluster_sizes = np.bincount(
    X_df['cluster'].to_numpy(), minlength=num_clusters
).astype(float)

In [None]:
from matplotlib.ticker import StrMethodFormatter

fig, ax = plt.subplots(figsize=(10, 10))
ax.scatter(cluster_sizes, normalized_reward_maxent_causal)
ax.set_title("Reward Value vs. Number of Data Points per Cluster", fontsize=18)
ax.set_xlabel("Number of Data Points", fontsize = 15)
ax.set_ylabel("Normalized Reward Values", fontsize = 15)

# Thousands separators on the x axis. A formatter stays in sync with the tick
# locations; set_xticklabels without fixed ticks only relabels the current ticks
# and goes stale whenever the ticks are recomputed.
ax.xaxis.set_major_formatter(StrMethodFormatter('{x:,.0f}'))

# GAIL CODE

- You need to add `mimic.py` and `transition_mat` to your local gymnasium installation to run this env.

- Assumes that the trajectories are saved in the variable `trajectories` (same format as above).

In [None]:
from imitation.policies.serialize import load_policy
from imitation.util.util import make_vec_env
from imitation.data.wrappers import RolloutInfoWrapper

# Seed shared by the vectorised environment RNG and the PPO learner. It was
# commented out although it is used twice below, which raises NameError on a
# fresh kernel.
SEED = 42

#need to add mimic.py and transition_mat to gymnasium local installation to use this env!!!
env = make_vec_env(
    "mimic-v0",
    rng=np.random.default_rng(SEED),
    n_envs=8,
    post_wrappers=[
        lambda env, _: RolloutInfoWrapper(env)
    ],
)

from imitation.data.types import Trajectory
import pickle

#with open("mimic-trajectories","rb") as f:
#    trajectories = pickle.load(f)

rollouts = []
# Convert every project trajectory into an imitation `Trajectory`:
# observations are its states, actions are the middle element of each transition.
for gail_trajectory in trajectories:
    obs = np.array(list(gail_trajectory.states()))
    acts = np.array([i for (_,i,_) in gail_trajectory.transitions()])
    rollouts.append(Trajectory(obs=obs, acts=acts, infos=None, terminal=True))

from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3.common.evaluation import evaluate_policy

# This is a learner that uses the proximal policy optimization algorithm; we will use it to compute the policy update steps of the GAIL algorithm
learner = PPO(
    env=env,
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0004,
    gamma=0.95,
    n_epochs=5,
    seed=SEED,
)

# This is a neural network that takes a batch of (state, action, next_state) triples and calculates the associated rewards; we will use it to compute the discriminator update steps of the GAIL algorithm
reward_net = BasicRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    normalize_input_layer=RunningNorm,
)

# imitation implementation of GAIL
gail_trainer = GAIL(
    demonstrations=rollouts,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=8,
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
)

#this takes a while, can run it for fewer iterations
gail_trainer.train(800_000)

#extract policy: deterministic action for every state id, sized from num_clusters
# like the other per-state policies instead of a literal 100
GAIL_policy = [learner.predict(i,deterministic=True)[0].item() for i in range(num_clusters)]

# Computing the policy induced by your learnt reward

In [None]:
# Solve the MDP defined by the estimated transitions and the learned reward.
# Unpacks two results, used below as state values V and action values Q
# (exact shapes are defined in solver.py — TODO confirm).
V, Q = S.value_iteration(p_transition, reward_maxent_causal, discount)

In [None]:
# Arrange Q as (action, state); the cells below reduce over axis 0 as the action axis.
# NOTE(review): reshape only re-labels the existing memory order. If the solver returns Q
# laid out as (state, action), this needs a transpose instead — confirm against solver.py.
Q = Q.reshape((4, num_clusters))

In [None]:
# Soft-max over the action axis (axis 0 of Q), transposed to one row per state.
# Subtracting the per-state maximum first leaves the result mathematically
# unchanged but keeps np.exp from overflowing to inf (and the ratio from
# becoming NaN) when Q values are large.
shifted_Q = Q - Q.max(axis=0, keepdims=True)
exp_Q = np.exp(shifted_Q)
soft_pi_mce = (exp_Q / exp_Q.sum(axis=0, keepdims=True)).T

In [None]:
soft_pi_mce

In [None]:
# Slice order follows the action index: 0, 1, 2, 3.
labels = ['No action', 'Vaso', 'Bolus', 'Vaso + Bolus']

fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(32, 16))

# One pie per subplot: subplot n shows the soft MCE action distribution of state n.
for cluster_idx, ax in enumerate(axes.ravel()):
    action_probs = soft_pi_mce[cluster_idx]
    ax.pie(action_probs, labels=labels, autopct='%1.1f%%', textprops={'fontsize': 17})
    ax.set_title(f'Cluster {cluster_idx + 1}', fontsize=25)

fig.suptitle("MCE Policy", fontsize=50)

In [None]:
# Greedy MCE policy: for every state, the action with the largest Q value, as a flat 1-D array.
best_action_per_state = np.argmax(Q, axis=0)
policy_mce = best_action_per_state.reshape(-1)

In [None]:
policy_mce

# Identify top and bottom 5 reward clusters

In [None]:
# Number of lowest / highest reward states to inspect.
k = 5

min_sorted = np.sort(reward_maxent_causal)
min_five = min_sorted[:k]

# Indices of the k smallest rewards, in ascending reward order. argsort always
# returns exactly k indices; matching sorted values back by float equality
# returns every tied index once per tied value, producing duplicates and more
# than k entries.
min_five_indices = np.argsort(reward_maxent_causal, kind='stable')[:k].tolist()

print(min_five)
print(min_five_indices)

In [None]:
max_sorted = np.sort(reward_maxent_causal)
max_five = max_sorted[-k:]

# Indices of the k largest rewards, in ascending reward order. argsort always
# returns exactly k indices; matching sorted values back by float equality
# returns every tied index once per tied value, producing duplicates and more
# than k entries.
max_five_indices = np.argsort(reward_maxent_causal, kind='stable')[-k:].tolist()

print(max_five)
print(max_five_indices)

In [None]:
# Reverse both index lists (turning them into arrays) so each runs from higher
# to lower reward. Running this cell a second time reverses them back.
max_five_indices = np.asarray(max_five_indices)[::-1]
min_five_indices = np.asarray(min_five_indices)[::-1]
print(max_five_indices)
print(min_five_indices)

# Average values for heat map

In [None]:
# Un-normalised features plus the cluster label; the action column is not visualised.
viz_df = data_non_normalized_df.drop(['action'], axis=1)

# Rows of the heat map: every feature column (everything except 'cluster').
color_matrix_y_labels = [column for column in viz_df.columns if column != 'cluster']
# Columns of the heat map: the top-reward clusters followed by the bottom-reward clusters.
color_matrix_x_labels = max_five_indices.tolist()
color_matrix_x_labels.extend(min_five_indices)
print(color_matrix_y_labels)
print(color_matrix_x_labels)

# Mean feature vector of each selected cluster, in the same order as the x labels.
color_matrix = [
    viz_df[viz_df['cluster'] == cluster_id].mean().drop(['cluster']).tolist()
    for cluster_id in [*max_five_indices, *min_five_indices]
]

# Transpose to (feature, cluster) so features run along the y axis.
color_matrix = np.transpose(color_matrix)

color_matrix

# Adding policy actions to heat map

In [None]:
# Action chosen by the BC policy for each cluster shown in the heat map.
bc_action = [bc_policy[cluster_id] for cluster_id in color_matrix_x_labels]

print(bc_action)

# Action chosen by the greedy MCE policy for the same clusters.
mce_action = [policy_mce[cluster_id] for cluster_id in color_matrix_x_labels]

print(mce_action)

In [None]:
# Convert the array to a nested list so rows can be appended below.
# NOTE: a second run fails, because a list has no .tolist().
color_matrix = color_matrix.tolist()

In [None]:
# Add the two policy rows beneath the feature rows.
# NOTE: appends in place, so re-running this cell adds the rows again.
color_matrix.append(bc_action)
color_matrix.append(mce_action)

In [None]:
len(color_matrix)

In [None]:
# Row labels for the two policy rows, in the same order as the rows were appended.
# NOTE: appends in place, so re-running this cell duplicates the labels.
color_matrix_y_labels.append('bc_action')
color_matrix_y_labels.append('mce_action')

In [None]:
color_matrix_y_labels

In [None]:
len(color_matrix_y_labels)
print(color_matrix_y_labels)

In [None]:
# Same list object as color_matrix (not a copy): rows are rescaled in place, and
# the plotting cell relies on color_matrix holding the normalised values.
normalized_color_matrix = color_matrix

# Min-max scale each row to [0, 1] across the displayed clusters. The two policy
# rows use the fixed action range 0..3 so their colours are comparable.
for i, row in enumerate(normalized_color_matrix):
    if color_matrix_y_labels[i] in ('bc_action', 'mce_action'):
        min_val, max_val = 0, 3
    else:
        min_val, max_val = min(row), max(row)
    dif = max_val - min_val
    for j in range(len(row)):
        # A row that is constant across all clusters has zero range; map it to 0
        # instead of dividing by zero.
        row[j] = (row[j] - min_val) / dif if dif != 0 else 0.0

In [None]:
normalized_color_matrix

In [None]:
fig, ax = plt.subplots(figsize=(32,20))
# normalized_color_matrix is the same object as color_matrix after the in-place
# rescaling; the explicit name matches the "Normalized Values" colour bar.
im = ax.imshow(normalized_color_matrix, cmap='inferno')
ax.set_title("Top and Bottom 5 States", fontsize = 28)
# Tick positions come from the label lists rather than the literals 16 and 10,
# so the plot still works when k or the feature set changes.
ax.set_yticks(np.arange(len(color_matrix_y_labels)))
ax.set_yticklabels(color_matrix_y_labels, fontsize = 15)
ax.set_xticks(np.arange(len(color_matrix_x_labels)))
ax.set_xticklabels(color_matrix_x_labels, fontsize = 15)
ax.set_ylabel('Features', fontsize = 20)
ax.set_xlabel('Clusters', fontsize = 20)
colorbar = fig.colorbar(im, fraction=0.046, pad=0.04)
colorbar.set_label(label = "Normalized Values", fontsize = 15)