In [1]:
import pandas as pd
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
import os, pickle
import matplotlib.pyplot as plt

from sklearn.cluster import KMeans
import tensorflow as tf
from sklearn.preprocessing import OneHotEncoder

from itertools import product

import sys, os

import trajectory as T                      # trajectory generation
import optimizer as O                       # stochastic gradient descent optimizer
import solver as S                          # MDP solver (value-iteration)
import plot as P


num_data = 355504


np.random.seed(66)

def to_interval(istr):
    c_left = istr[0]=='['
    c_right = istr[-1]==']'
    closed = {(True, False): 'left',
              (False, True): 'right',
              (True, True): 'both',
              (False, False): 'neither'
              }[c_left, c_right]
    left, right = map(pd.to_datetime, istr[1:-1].split(','))
    return pd.Interval(left, right, closed)

re_split = False
frac = [0.4,0.2,0.4]
assert np.sum(frac) == 1
frac = np.cumsum(frac)
print (frac)
data_save_path= 'data/'

def sliding(gs, window_size = 6):
    npr_l = []
    for g in gs:
        npr = np.concatenate([np.zeros([window_size-1, g.shape[1]]),g])
        npr_l.append(sliding_window_view(npr, (window_size, g.shape[1])).squeeze(1))
    return np.vstack(npr_l)

[0.4 0.6 1. ]


# LOADING THE DATA

In [2]:
# if re_split:


aggr_df = pd.read_csv('mimic_iv_hypotensive_cut2.csv',sep = ',', header = 0,converters={1:to_interval}).set_index(['stay_id','time']).sort_index()
# create action bins (four actions in total)
aggr_df['action'] = aggr_df['bolus(binary)']*2 + aggr_df['vaso(binary)']
all_idx = np.random.permutation(aggr_df.index.get_level_values(0).unique())
train_df = aggr_df.loc[all_idx[:int(len(all_idx)*frac[0])]].sort_index()
test_df = aggr_df.loc[all_idx[int(len(all_idx)*frac[0]):int(len(all_idx)*frac[1])]].sort_index()
valid_df = aggr_df.loc[all_idx[int(len(all_idx)*frac[1]):]].sort_index()
# print (np.unique(train_df['action'],return_counts=True)[1]*1./len(train_df))
# pickle.dump([train_df, test_df, valid_df], open(data_save_path+'processed_mimic_hyp_2.pkl','wb'))
drop_columns = ['vaso(amount)','bolus(amount)',\
            'any_treatment(binary)','vaso(binary)','bolus(binary)']




# LOOKING AT THE DATA

In [3]:
print(len(aggr_df))
aggr_df.head()

355504


Unnamed: 0_level_0,Unnamed: 1_level_0,creatinine,fraction_inspired_oxygen,lactate,urine_output,alanine_aminotransferase,asparate_aminotransferase,mean_blood_pressure,diastolic_blood_pressure,systolic_blood_pressure,gcs,partial_pressure_of_oxygen,heart_rate,temperature,respiratory_rate,vaso(binary),vaso(amount),bolus(binary),bolus(amount),any_treatment(binary),action
stay_id,time,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
30001446,"[2186-04-12 03:49:00, 2186-04-12 04:49:00)",,,,,,,61.0,56.0,75.0,,,82.0,36.722222,22.0,0.0,0.0,1.0,1000.0,1.0,2.0
30001446,"[2186-04-12 04:49:00, 2186-04-12 05:49:00)",2.7,,,40.0,38.0,114.0,63.0,46.0,96.0,15.0,,80.0,,22.0,1.0,0.077612,0.0,0.0,1.0,1.0
30001446,"[2186-04-12 05:49:00, 2186-04-12 06:49:00)",,,,100.0,,,99.0,75.0,152.0,,,79.0,,19.0,1.0,0.238806,0.0,0.0,1.0,1.0
30001446,"[2186-04-12 06:49:00, 2186-04-12 07:49:00)",,,1.7,,,,72.0,55.5,107.0,,98.0,83.0,,18.0,1.0,0.166559,0.0,0.0,1.0,1.0
30001446,"[2186-04-12 07:49:00, 2186-04-12 08:49:00)",,,,40.0,,,70.0,53.0,105.0,,,77.0,35.944444,21.0,1.0,0.172458,0.0,0.0,1.0,1.0


# Data for patient with hypotension, two treatment vasopressors and a bolus dose of epinephrine depending on certain features of the patient. In order to do IRL we need to discretize the action space:
* Action 0 = No treatment
* Action 1 = Vaso
* Action 2 = Bolus
* Action 3 = Vaso + Bolus



# We also need to do a bit of data cleaning such as taking care of missing values before running our algorithms on it

In [4]:
# for now drop indicators about bolus and vaso
train_df = train_df.drop(columns=drop_columns)
test_df = test_df.drop(columns=drop_columns)
valid_df = valid_df.drop(columns=drop_columns)

#### imputation
impute_table = pd.read_csv('mimic_iv_hypotensive_cut2_impute_table.csv',sep=',',header=0).set_index(['feature'])
train_df = train_df.fillna(method='ffill')
test_df = test_df.fillna(method='ffill')
valid_df = valid_df.fillna(method='ffill')




for f in impute_table.index:
    train_df[f] = train_df[f].fillna(value = impute_table.loc[f].values[0])
    test_df[f] = test_df[f].fillna(value = impute_table.loc[f].values[0])
    valid_df[f] = valid_df[f].fillna(value = impute_table.loc[f].values[0])


data_non_normalized_df = pd.concat([train_df, valid_df, test_df], axis=0, ignore_index=False).head(num_data).copy()


#### standard normalization ####
normalize_features = ['creatinine', 'fraction_inspired_oxygen', 'lactate', 'urine_output',
                  'alanine_aminotransferase', 'asparate_aminotransferase',
                  'mean_blood_pressure', 'diastolic_blood_pressure',
                  'systolic_blood_pressure', 'gcs', 'partial_pressure_of_oxygen']
mu, std = (train_df[normalize_features]).mean().values,(train_df[normalize_features]).std().values
train_df[normalize_features] = (train_df[normalize_features] - mu)/std
test_df[normalize_features] = (test_df[normalize_features] - mu)/std
valid_df[normalize_features] = (valid_df[normalize_features] - mu)/std




### create data matrix ####
X_train = train_df.loc[:,train_df.columns!='action']
y_train = train_df['action']

X_test = test_df.loc[:,test_df.columns!='action']
y_test = test_df['action']

X_valid = valid_df.loc[:, valid_df.columns!='action']
y_valid = valid_df['action']

In [5]:
X_df = pd.concat([X_train, X_valid, X_test], axis=0, ignore_index=True).copy()
y_df = pd.concat([y_train, y_valid, y_test], axis=0, ignore_index=True).copy()


In [6]:
data_df = pd.concat([train_df, valid_df, test_df], axis=0, ignore_index=False).copy()
# data_df = data_df.head(num_data).copy()
# X_df = X_df.head(num_data).copy()
# y_df = y_df.head(num_data).copy()


In [7]:
print(len(data_df))
print(len(X_df))
print(len(y_df))
print(len(data_non_normalized_df))

355504
355504
355504
355504


# Normalized version of the data

In [8]:
data_df.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,creatinine,fraction_inspired_oxygen,lactate,urine_output,alanine_aminotransferase,asparate_aminotransferase,mean_blood_pressure,diastolic_blood_pressure,systolic_blood_pressure,gcs,partial_pressure_of_oxygen,heart_rate,temperature,respiratory_rate,action
stay_id,time,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
30004811,"[2139-10-06 10:40:29, 2139-10-06 11:40:29)",-0.422008,-1.760743,-0.182521,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.36885,-2.479374,-0.186486,86.0,37.0,19.0,0.0
30004811,"[2139-10-06 11:40:29, 2139-10-06 12:40:29)",-0.422008,-1.760743,-0.182521,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.36885,-2.479374,-0.186486,86.0,37.0,19.0,0.0
30004811,"[2139-10-06 12:40:29, 2139-10-06 13:40:29)",-0.422008,-1.760743,0.360532,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.36885,-2.479374,2.356989,86.0,37.0,19.0,0.0
30004811,"[2139-10-06 13:40:29, 2139-10-06 14:40:29)",-0.422008,-1.760743,0.360532,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.36885,-2.479374,2.356989,86.0,37.0,19.0,0.0
30004811,"[2139-10-06 14:40:29, 2139-10-06 15:40:29)",-0.422008,-1.760743,0.360532,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.36885,-2.479374,2.356989,86.0,37.0,19.0,0.0


# Unormalized version of the data

In [9]:
data_non_normalized_df.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,creatinine,fraction_inspired_oxygen,lactate,urine_output,alanine_aminotransferase,asparate_aminotransferase,mean_blood_pressure,diastolic_blood_pressure,systolic_blood_pressure,gcs,partial_pressure_of_oxygen,heart_rate,temperature,respiratory_rate,action
stay_id,time,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
30004811,"[2139-10-06 10:40:29, 2139-10-06 11:40:29)",1.0,0.21,1.8,80.0,34.0,40.0,77.0,59.0,118.0,11.0,112.0,86.0,37.0,19.0,0.0
30004811,"[2139-10-06 11:40:29, 2139-10-06 12:40:29)",1.0,0.21,1.8,80.0,34.0,40.0,77.0,59.0,118.0,11.0,112.0,86.0,37.0,19.0,0.0
30004811,"[2139-10-06 12:40:29, 2139-10-06 13:40:29)",1.0,0.21,3.0,80.0,34.0,40.0,77.0,59.0,118.0,11.0,272.0,86.0,37.0,19.0,0.0
30004811,"[2139-10-06 13:40:29, 2139-10-06 14:40:29)",1.0,0.21,3.0,80.0,34.0,40.0,77.0,59.0,118.0,11.0,272.0,86.0,37.0,19.0,0.0
30004811,"[2139-10-06 14:40:29, 2139-10-06 15:40:29)",1.0,0.21,3.0,80.0,34.0,40.0,77.0,59.0,118.0,11.0,272.0,86.0,37.0,19.0,0.0


# Matrix form of the data (Normalized and features only)

In [10]:
X_df.head()

Unnamed: 0,creatinine,fraction_inspired_oxygen,lactate,urine_output,alanine_aminotransferase,asparate_aminotransferase,mean_blood_pressure,diastolic_blood_pressure,systolic_blood_pressure,gcs,partial_pressure_of_oxygen,heart_rate,temperature,respiratory_rate
0,-0.422008,-1.760743,-0.182521,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.36885,-2.479374,-0.186486,86.0,37.0,19.0
1,-0.422008,-1.760743,-0.182521,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.36885,-2.479374,-0.186486,86.0,37.0,19.0
2,-0.422008,-1.760743,0.360532,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.36885,-2.479374,2.356989,86.0,37.0,19.0
3,-0.422008,-1.760743,0.360532,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.36885,-2.479374,2.356989,86.0,37.0,19.0
4,-0.422008,-1.760743,0.360532,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.36885,-2.479374,2.356989,86.0,37.0,19.0


# Corresponding output data for training BC (corresponding treatments for each data point in X_df)

In [11]:
y_df.head()

0    0.0
1    0.0
2    0.0
3    0.0
4    0.0
Name: action, dtype: float64

# Clustering the feature space to extract a discrete state space form the clusters

In [12]:
num_clusters = 100
kmeans = KMeans(n_clusters= num_clusters , random_state=0)
kmeans.fit(X_df)



In [13]:
# Looking at the values counts for each cluster

np.unique(kmeans.labels_, return_counts = True)[1]

array([1573, 5201, 4839, 6302,  755, 2624, 4282, 1152,  206, 2074, 5666,
       2005, 6163, 1005, 3104,  859, 5044, 3901, 1961, 7998, 5495, 3806,
       4191, 3170, 2419,  573, 1358,  721, 8024, 6950, 4028,  423, 7428,
       5084, 6395, 2433,  235, 4324, 7887, 2404, 4686, 8928, 1132,  100,
       3347,  608, 4564, 7711, 1162,  120,   18, 2624, 5086, 3867, 6196,
       1095,  955,  198,  917, 3480, 7076, 2589, 4021, 8352, 1976, 7861,
       1557, 5527, 1369, 4511, 2516, 1746,  560, 6902, 2894,  676, 1189,
       6226, 5385, 1181, 8436, 7384, 5240,  702, 4064, 3613, 1590, 6754,
       2498, 3158,    2, 2687,  283, 4901, 6062, 3146,  678, 3040, 8417,
       7879])

In [14]:
# Assigning each data point to a cluster

X_df['cluster'] = kmeans.labels_.copy()
data_df['cluster'] = kmeans.labels_.copy()
data_non_normalized_df['cluster'] = kmeans.labels_.copy()

In [15]:
data_df.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,creatinine,fraction_inspired_oxygen,lactate,urine_output,alanine_aminotransferase,asparate_aminotransferase,mean_blood_pressure,diastolic_blood_pressure,systolic_blood_pressure,gcs,partial_pressure_of_oxygen,heart_rate,temperature,respiratory_rate,action,cluster
stay_id,time,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
30004811,"[2139-10-06 10:40:29, 2139-10-06 11:40:29)",-0.422008,-1.760743,-0.182521,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.36885,-2.479374,-0.186486,86.0,37.0,19.0,0.0,38
30004811,"[2139-10-06 11:40:29, 2139-10-06 12:40:29)",-0.422008,-1.760743,-0.182521,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.36885,-2.479374,-0.186486,86.0,37.0,19.0,0.0,38
30004811,"[2139-10-06 12:40:29, 2139-10-06 13:40:29)",-0.422008,-1.760743,0.360532,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.36885,-2.479374,2.356989,86.0,37.0,19.0,0.0,38
30004811,"[2139-10-06 13:40:29, 2139-10-06 14:40:29)",-0.422008,-1.760743,0.360532,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.36885,-2.479374,2.356989,86.0,37.0,19.0,0.0,38
30004811,"[2139-10-06 14:40:29, 2139-10-06 15:40:29)",-0.422008,-1.760743,0.360532,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.36885,-2.479374,2.356989,86.0,37.0,19.0,0.0,38


In [16]:
X_df.head()

Unnamed: 0,creatinine,fraction_inspired_oxygen,lactate,urine_output,alanine_aminotransferase,asparate_aminotransferase,mean_blood_pressure,diastolic_blood_pressure,systolic_blood_pressure,gcs,partial_pressure_of_oxygen,heart_rate,temperature,respiratory_rate,cluster
0,-0.422008,-1.760743,-0.182521,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.36885,-2.479374,-0.186486,86.0,37.0,19.0,38
1,-0.422008,-1.760743,-0.182521,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.36885,-2.479374,-0.186486,86.0,37.0,19.0,38
2,-0.422008,-1.760743,0.360532,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.36885,-2.479374,2.356989,86.0,37.0,19.0,38
3,-0.422008,-1.760743,0.360532,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.36885,-2.479374,2.356989,86.0,37.0,19.0,38
4,-0.422008,-1.760743,0.360532,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.36885,-2.479374,2.356989,86.0,37.0,19.0,38


In [17]:
data_non_normalized_df.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,creatinine,fraction_inspired_oxygen,lactate,urine_output,alanine_aminotransferase,asparate_aminotransferase,mean_blood_pressure,diastolic_blood_pressure,systolic_blood_pressure,gcs,partial_pressure_of_oxygen,heart_rate,temperature,respiratory_rate,action,cluster
stay_id,time,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
30004811,"[2139-10-06 10:40:29, 2139-10-06 11:40:29)",1.0,0.21,1.8,80.0,34.0,40.0,77.0,59.0,118.0,11.0,112.0,86.0,37.0,19.0,0.0,38
30004811,"[2139-10-06 11:40:29, 2139-10-06 12:40:29)",1.0,0.21,1.8,80.0,34.0,40.0,77.0,59.0,118.0,11.0,112.0,86.0,37.0,19.0,0.0,38
30004811,"[2139-10-06 12:40:29, 2139-10-06 13:40:29)",1.0,0.21,3.0,80.0,34.0,40.0,77.0,59.0,118.0,11.0,272.0,86.0,37.0,19.0,0.0,38
30004811,"[2139-10-06 13:40:29, 2139-10-06 14:40:29)",1.0,0.21,3.0,80.0,34.0,40.0,77.0,59.0,118.0,11.0,272.0,86.0,37.0,19.0,0.0,38
30004811,"[2139-10-06 14:40:29, 2139-10-06 15:40:29)",1.0,0.21,3.0,80.0,34.0,40.0,77.0,59.0,118.0,11.0,272.0,86.0,37.0,19.0,0.0,38


# Converting the data into trajectories to input to an IRL algorithm Note this is the same format of trajectories we used for HW1 and HW2.

In [18]:
unique_stay_ids = data_df.index.get_level_values('stay_id').unique()

trajectories = []


for stay_id in unique_stay_ids:


  states, actions = data_df.loc[stay_id]['cluster'], data_df.loc[stay_id]['action']

  trajectory = []
  for i in range(len(states) - 1):
    trajectory.append((states[i], int(actions[i]), states[i+1] ))

  trajectories.append(T.Trajectory(trajectory))

We need to store all possible terminal states from the trajectories list. (Needed to calculate the normalizing constant in MaxEnt)

In [19]:
terminal_states = []

for traj in trajectories:
  terminal_states.append(traj._t[-1][-1])

terminal_states = list(set(terminal_states))

In [20]:
terminal_states

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99]

# Distribution of the treatments given in our data. (Most of the time no treatment is given, might vary on depending on how you cluster the data)

In [21]:
y_df.value_counts()

0.0    195786
1.0    135305
3.0     15978
2.0      8435
Name: action, dtype: int64

# Estimating the Transition Dynamics using the MLE (feel free to play around with the smoothing_value)

In [22]:
smoothing_value = 1

p_transition = np.zeros((num_clusters, num_clusters, 4)) + smoothing_value


for traj in trajectories:

  for tran in traj._t:

    p_transition[tran[0], tran[2], tran[1]] +=1

p_transition = p_transition/ p_transition.sum(axis = 1)[:, np.newaxis, :]

# Adverse Reinforcement Learning

In [23]:
import gym
import numpy as np

class HealthcareIRLEnvironment(gym.Env):
    def __init__(self, num_states, num_actions, transition_matrix, expert_trajectories):
        super(HealthcareIRLEnvironment, self).__init__()
        self.num_states = num_states
        self.num_actions = num_actions
        self.transition_matrix = transition_matrix
        self.expert_trajectories = expert_trajectories
        self.current_state = None
        self.current_step = None

        # Define action and observation spaces
        self.action_space = gym.spaces.Discrete(num_actions)
        self.observation_space = gym.spaces.Discrete(num_states)

    def reset(self):
        # Reset the environment to an initial state
        self.current_state = np.random.randint(self.num_states)
        self.current_step = 0
        return self.current_state

    def step(self, action):
        # Take a step in the environment based on the given action
        # Compute the next state based on the transition matrix
        next_state_probs = self.transition_matrix[self.current_state, action, :]
        next_state = np.random.choice(self.num_states, p=next_state_probs)

        # Compute the reward based on the expert trajectories
        reward = 0.0
        for trajectory in self.expert_trajectories:
            for t in trajectory:
                if t._t[0] == self.current_state and t._t[1] == action:
                    # TODO: fix later, for now just add 0.5 everytime a state, action pair is present in the expert trajectories
                    reward += 0.5

        self.current_state = next_state
        self.current_step += 1

        # Define a termination condition (e.g., episode ends after a certain number of steps)
        done = self.current_step >= 100

        # Return the next state, reward, done flag, and additional info
        return next_state, reward, done, {}

In [24]:
healthcare_env = HealthcareIRLEnvironment(num_clusters, 4, p_transition, trajectories)

  and should_run_async(code)


In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

# Define the discriminator (a neural network) for AIRL
class Discriminator(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=0)
        x = torch.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x

class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(in_features=100, out_features=64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=64, out_features=4)
        self.softmax = nn.Softmax()

    def forward(self, state):
        x = self.fc1(state)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x



In [None]:
def actions_to_one_hot_array(input_array):

  num_unique_values = 4
  one_hot_array = np.zeros((len(input_array), num_unique_values), dtype=int)

  for i, val in enumerate(input_array):
    one_hot_array[i, val] = 1

  return one_hot_array

def states_to_one_hot_array(input_array):

  num_unique_values = 100
  one_hot_array = np.zeros((len(input_array), num_unique_values), dtype=int)

  for i, val in enumerate(input_array):
    one_hot_array[i, val] = 1

  return one_hot_array


In [None]:
states = []
actions = []

for traject in trajectories:
  for t in traject._t:
    states.append(t[0])
    actions.append(t[1])

In [None]:
# warm start policy
LEARNING_RATE = 0.01
NUM_EPOCHS = 500

model = PolicyNetwork(num_clusters, 4)
loss = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)

one_hot_actions = torch.tensor(actions_to_one_hot_array(actions)).to(torch.float32)
one_hot_states = torch.tensor(states_to_one_hot_array(states)).to(torch.float32)

for epoch in range(NUM_EPOCHS):
    optimizer.zero_grad()

    output = model(one_hot_states)

    l = loss(output, one_hot_actions)
    l.backward()
    optimizer.step()

print("Training complete!")


In [None]:
all_states = torch.tensor(states_to_one_hot_array(range(100))).to(torch.float32)

with torch.no_grad():
    # Make predictions
    predictions = model(all_states)

bc_policy = torch.argmax(predictions, dim=1).numpy()  #This should be an array of shape (25, ) with entries being either 0, 1, 2 or 3 (corresponding to the actions for each of the 25 states)

assert bc_policy.shape == (100, )

In [None]:
state_dim = num_clusters
action_dim = 4
discriminator = Discriminator(state_dim, action_dim)
policy_network = PolicyNetwork(num_clusters, 4)

discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.001)
policy_optimizer = optim.Adam(policy_network.parameters(), lr=0.001)

criterion = nn.BCELoss()


In [31]:
num_epochs = 5
for epoch in range(num_epochs):
    # Collect expert demonstrations (multiple trajectories)
    expert_trajectories = trajectories  # Replace with your data or expert policy

    # Update the discriminator
    for trajectory in expert_trajectories:
        for i in range(len(trajectory.transitions())):
            state, action, _ = trajectory.transitions()[i]
            state = torch.nn.functional.one_hot(torch.tensor(int(state)), num_classes=num_clusters).float()
            action = torch.nn.functional.one_hot(torch.tensor(int(action)), num_classes=4).float()
            predicted_expert_prob = discriminator(state, action)
            expert_label = Variable(torch.ones(1))
            loss_expert = criterion(predicted_expert_prob, expert_label)

            # Update discriminator
            discriminator_optimizer.zero_grad()
            loss_expert.backward()
            discriminator_optimizer.step()

    # Update the policy network using the discriminator
    for trajectory in expert_trajectories:
        for i in range(len(trajectory.transitions())):
            state, _, _ = trajectory.transitions()[i]
            state = Variable(torch.tensor(state).float())
            state = torch.nn.functional.one_hot(torch.tensor(int(state)), num_classes=num_clusters).float()
            policy_action = policy_network(state)
            predicted_expert_prob = discriminator(state, policy_action)
            policy_loss = -torch.log(predicted_expert_prob)

            # Update policy network
            policy_optimizer.zero_grad()
            policy_loss.backward()
            policy_optimizer.step()

  x = self.softmax(x)


In [34]:
# for i in range(num_clusters):
#   state = torch.nn.functional.one_hot(torch.tensor(i), num_classes=num_clusters).float()
#   policy_action = policy_network(state)
#   print(np.argmax([float(a) for a in policy_action]))