In [1]:
import numpy as np
import torch
from scipy import stats

# Importance Sampling Methods
def ImportanceSampling(histories, cur_policy, gamma, ep_num, new_policy):
    """Ordinary (whole-trajectory) importance-sampling estimate for one episode.

    Args:
        histories: Indexable collection of episodes. ``histories[ep_num]`` is
            iterated row by row; each row unpacks into four values, read here
            as (state, action, reward, <ignored>).
        cur_policy: Table indexed ``[state, action]``; supplies the denominator
            of every probability ratio.
        gamma: Discount factor; the reward at step ``j`` is scaled by
            ``gamma ** j``.
        ep_num: Index of the episode to evaluate.
        new_policy: Table indexed ``[state, action]``; supplies the numerator
            of every probability ratio.

    Returns:
        The product of all per-step ratios ``new_policy / cur_policy``
        multiplied by the discounted sum of rewards of the episode.
    """
    episode = histories[ep_num]
    weight = 1
    discounted_return = 0
    for step, (state, action, reward, _) in enumerate(episode):
        # States and actions are stored as floats in the episode array, so
        # they have to be cast before being used as table indices.
        s, a = int(state), int(action)
        weight *= new_policy[s, a] / cur_policy[s, a]
        discounted_return += (gamma ** step) * reward
    return weight * discounted_return

def PDImportanceSampling(histories, cur_policy, gamma, ep_num, new_policy):
    """Per-decision importance-sampling estimate for one episode.

    The reward at step ``t`` is weighted by the product of the probability
    ratios ``new_policy / cur_policy`` for steps ``0..t`` only, and discounted
    by ``gamma ** t``.

    The cumulative weight is carried forward as a running product, so the
    episode is processed in a single pass instead of rebuilding the product
    from step 0 for every ``t``. The multiplication order is unchanged, so the
    result is numerically identical to the nested-loop formulation.

    Args:
        histories: Indexable collection of episodes. ``histories[ep_num]`` must
            expose ``shape[0]`` and rows that unpack into four values, read
            here as (state, action, reward, <ignored>).
        cur_policy: Table indexed ``[state, action]``; denominator of each ratio.
        gamma: Discount factor.
        ep_num: Index of the episode to evaluate.
        new_policy: Table indexed ``[state, action]``; numerator of each ratio.

    Returns:
        The sum over steps of ``gamma**t * (prod_{j<=t} ratio_j) * reward_t``;
        ``0`` for an episode with no steps.
    """
    traj = histories[ep_num]

    result = 0
    is_weight = 1  # running product of the ratios for steps 0..t
    for t in range(traj.shape[0]):
        St, At, Rt, _ = traj[t]
        St = int(St)
        At = int(At)
        is_weight *= new_policy[St, At] / cur_policy[St, At]
        result += (gamma ** t) * is_weight * Rt
    return result

def CalcAvgIS(histories, cur_policy, gamma, new_policy, ISFunc):
    """Average an importance-sampling estimator over every episode.

    Progress is printed roughly every quarter of the episodes.

    Args:
        histories: Sized, indexable collection of episodes.
        cur_policy: Passed through unchanged to ``ISFunc``.
        gamma: Passed through unchanged to ``ISFunc``.
        new_policy: Passed through unchanged to ``ISFunc``.
        ISFunc: Callable invoked as
            ``ISFunc(histories, cur_policy, gamma, ep, new_policy)`` for each
            episode index ``ep``; its results are summed.

    Returns:
        The sum of the per-episode estimates divided by ``len(histories)``.

    Raises:
        ZeroDivisionError: If ``histories`` is empty.
    """
    num_episodes = len(histories)
    total = 0
    print("Averaging Importance Sampling")
    # With fewer than four episodes int(0.25 * n) is 0, which would make the
    # modulo below divide by zero; report every episode in that case.
    update_freq = max(1, int(0.25 * num_episodes))
    for ep in range(num_episodes):
        if ep % update_freq == 0:
            print(str(ep) + " / " + str(num_episodes))
        total += ISFunc(histories, cur_policy, gamma, ep, new_policy)
    return total / num_episodes

def CalcStdDev(histories, cur_policy, gamma, new_policy, ISFunc, avgIS):
    """Sample standard deviation of the per-episode estimates around ``avgIS``.

    Args:
        histories: Sized, indexable collection of episodes.
        cur_policy: Passed through unchanged to ``ISFunc``.
        gamma: Passed through unchanged to ``ISFunc``.
        new_policy: Passed through unchanged to ``ISFunc``.
        ISFunc: Callable invoked as
            ``ISFunc(histories, cur_policy, gamma, ep, new_policy)``.
        avgIS: Value subtracted from every per-episode estimate before squaring.

    Returns:
        ``sqrt(sum((estimate - avgIS)**2) / (len(histories) - 1))``.

    Raises:
        ZeroDivisionError: If ``histories`` holds exactly one episode.
    """
    n_episodes = len(histories)
    squared_dev_sum = 0
    for ep in range(n_episodes):
        deviation = ISFunc(histories, cur_policy, gamma, ep, new_policy) - avgIS
        squared_dev_sum += deviation ** 2

    # n - 1 in the denominator (Bessel's correction).
    scale = 1 / (n_episodes - 1)
    return np.sqrt(scale * squared_dev_sum)

def Safety_Prediction(histories, cur_policy, gamma, new_policy, ISFunc, delta, num_safety, avgIS=None):
    """Predict, from ``histories``, the lower bound a later safety test would give.

    Args:
        histories: Episodes used to compute the mean and standard deviation of
            the importance-sampling estimates.
        cur_policy: Passed through to the estimator.
        gamma: Passed through to the estimator.
        new_policy: Passed through to the estimator.
        ISFunc: Per-episode estimator, called as
            ``ISFunc(histories, cur_policy, gamma, ep, new_policy)``.
        delta: The Student-t quantile is taken at ``1 - delta``.
        num_safety: Sample size assumed for the later test; used for the
            degrees of freedom (``num_safety - 1``) and for the
            ``sqrt(num_safety)`` scaling of the standard deviation.
        avgIS: Optional precomputed mean of the estimates over ``histories``.
            When ``None`` it is computed with ``CalcAvgIS``.

    Returns:
        ``avgIS - 2 * (std_dev / sqrt(num_safety)) * t_value``.
    """
    t_value = stats.t.ppf(1 - delta, num_safety - 1)
    # Identity comparison: ``== None`` would dispatch to the operand's own
    # ``__eq__`` (elementwise for numpy values) instead of testing for "not given".
    if avgIS is None:
        avgIS = CalcAvgIS(histories, cur_policy, gamma, new_policy, ISFunc)
    std_dev = CalcStdDev(histories, cur_policy, gamma, new_policy, ISFunc, avgIS)

    # NOTE(review): the interval half-width is doubled here but not in
    # Safety_Test; presumably a deliberate conservative margin for the
    # prediction -- confirm.
    return avgIS - 2 * (std_dev / np.sqrt(num_safety)) * t_value

def Safety_Test(histories, cur_policy, gamma, new_policy, ISFunc, delta):
    """One-sided Student-t lower bound on the mean importance-sampling estimate.

    Args:
        histories: Episodes the bound is computed from; their count is the
            sample size.
        cur_policy: Passed through to the estimator.
        gamma: Passed through to the estimator.
        new_policy: Passed through to the estimator.
        ISFunc: Per-episode estimator, called as
            ``ISFunc(histories, cur_policy, gamma, ep, new_policy)``.
        delta: The Student-t quantile is taken at ``1 - delta``.

    Returns:
        ``mean - (std_dev / sqrt(n)) * t_value`` with ``n = len(histories)``
        and ``n - 1`` degrees of freedom.
    """
    sample_size = len(histories)
    mean_estimate = CalcAvgIS(histories, cur_policy, gamma, new_policy, ISFunc)
    spread = CalcStdDev(histories, cur_policy, gamma, new_policy, ISFunc, mean_estimate)
    critical_value = stats.t.ppf(1 - delta, sample_size - 1)

    half_width = (spread / np.sqrt(sample_size)) * critical_value
    return mean_estimate - half_width

In [2]:
import numpy as np
import torch
from scipy import stats

# Gets histories from the CSV
def GetHistories(path):
    """Parse the episode file at ``path`` into a list of per-episode arrays.

    File layout, as consumed by this parser:
      * the first line is a header and is skipped;
      * a line with a single field gives the number of time steps of the
        episode that follows and starts a new episode;
      * every other line has four comma-separated fields, stored as
        ``(int, int, float, float)`` in columns 0..3 of the current episode.

    Blank lines (including a trailing empty line at the end of the file) are
    ignored.

    Args:
        path: Path of the text file to read.

    Returns:
        A list with one ``numpy`` array of shape ``(num_time_steps, 4)`` per
        episode, in file order.

    Raises:
        ValueError: If a field cannot be converted, or a data line does not
            have exactly four fields.
        IndexError: If a data line appears before any episode header, or an
            episode has more data lines than its declared length.
    """
    histories = []
    with open(path) as file:
        cur_timestep = 0
        for idx, line in enumerate(file):
            # Header line: its value is not needed, the episode count follows
            # from the number of episode blocks actually read.
            if idx == 0:
                continue

            line = line.strip()
            if not line:
                # A stray empty line would otherwise look like a one-field
                # episode header and fail in int().
                continue

            data = line.split(",")
            if len(data) == 1:
                # Episode header: allocate the array for the new episode.
                histories.append(np.zeros((int(data[0]), 4)))
                cur_timestep = 0
                continue

            St, At, Rt, pib = data
            episode = histories[-1]
            episode[cur_timestep, 0] = int(St)
            episode[cur_timestep, 1] = int(At)
            episode[cur_timestep, 2] = float(Rt)
            episode[cur_timestep, 3] = float(pib)
            cur_timestep += 1
    return histories

# Extracts as much policy info as possible from episode
def GetPolicyFromEpisode(histories, ep_num, num_states, num_actions):
    """Recover the policy-table entries that one episode reveals.

    For every (state, action) pair that occurs in the episode, the value in
    column 3 of its first occurrence is copied into the table. Pairs that do
    not occur keep the value 0.

    Args:
        histories: Indexable collection of episodes; ``histories[ep_num]`` is a
            2-D array with the state in column 0, the action in column 1 and
            the recorded value in column 3.
        ep_num: Index of the episode to inspect.
        num_states: Number of rows of the returned table.
        num_actions: Number of columns of the returned table.

    Returns:
        A ``(num_states, num_actions)`` array.
    """
    episode = histories[ep_num]
    states, actions, recorded = episode[:, 0], episode[:, 1], episode[:, 3]

    table = np.zeros((num_states, num_actions))
    for s in range(num_states):
        in_state = states == s
        for a in range(num_actions):
            matches = np.flatnonzero(in_state & (actions == a))
            if matches.size:
                # First occurrence wins.
                table[s, a] = recorded[matches[0]]

    return table

# Gets the policy
def GetPolicy(histories, num_states, num_actions, num_iterartions):
    """Assemble a policy table from the first ``num_iterartions`` episodes.

    Episodes are scanned in order. An entry of the table is filled the first
    time an episode yields a non-zero value for it and is never overwritten
    afterwards; entries no scanned episode reveals stay 0.

    Args:
        histories: Indexable collection of episodes.
        num_states: Number of rows of the table.
        num_actions: Number of columns of the table.
        num_iterartions: How many episodes, counted from index 0, to scan.

    Returns:
        A ``(num_states, num_actions)`` array.
    """
    merged = np.zeros((num_states, num_actions))

    for ep in range(num_iterartions):
        revealed = GetPolicyFromEpisode(histories, ep, num_states, num_actions)
        # Only entries that are still empty and that this episode reveals.
        fillable = (merged == 0) & (revealed != 0)
        merged[fillable] = revealed[fillable]
    return merged

In [14]:
#---SET PARAMS---
USE_GRIDWORLD = True   # True: gridworld data set and state count
USE_PDIS = True        # True: per-decision importance sampling, else ordinary
num_train_intervals = 40
percent_increase = 0.1

# The gridworld data set has 23 states, the other one 18.
num_states = 23 if USE_GRIDWORLD else 18
num_actions = 4
gamma = 0.95   # discount factor
delta = 0.01   # bounds are computed at confidence level 1 - delta

In [15]:
# Forward slashes work on every platform. The previous backslash form relied
# on "\d" and "\g" being kept as invalid escape sequences (which Python warns
# about) and only resolved on Windows.
path = "data/data.csv"
if USE_GRIDWORLD:
    path = "data/gridworld_data.csv"

histories = GetHistories(path)

In [17]:
# The first 80% of the episodes form the training set, the remainder the
# held-out test set (no shuffling).
split_idx = int(len(histories) * .8)
train, test = histories[:split_idx], histories[split_idx:]
print(len(train), len(test), sep="\n")

80000
20000


In [18]:
# Rebuild the data-collection policy table from the values recorded in the
# first 1000 training episodes.
exploration_policy = GetPolicy(
    train,
    num_states,
    num_actions,
    num_iterartions=1000,
)
# Policy under evaluation, stored on disk as a numpy array.
new_policy = np.load("policies/gw/delta_0.01/safety_0.4380582971385134.npy")
print(new_policy)

[[0.20763181 0.32261548 0.2320871  0.2376656 ]
 [0.43754009 0.30769659 0.18706861 0.06769472]
 [0.1496868  0.07095748 0.60367863 0.17567709]
 [0.1293479  0.19027268 0.62913662 0.0512428 ]
 [0.18533418 0.01572689 0.30536101 0.49357793]
 [0.0960113  0.71620255 0.11046996 0.07731619]
 [0.00503686 0.97096792 0.01639663 0.00759859]
 [0.19370951 0.45149438 0.30279422 0.05200189]
 [0.28703467 0.12633633 0.2157527  0.37087631]
 [0.2284537  0.24358334 0.19415907 0.33380389]
 [0.75345447 0.08064593 0.05413773 0.11176187]
 [0.10877635 0.50308273 0.19032245 0.19781848]
 [0.26561551 0.21333928 0.2680081  0.25303711]
 [0.12665282 0.05972497 0.79498376 0.01863844]
 [0.33565089 0.33904759 0.01103397 0.31426754]
 [0.957755   0.02293884 0.00478514 0.01452102]
 [0.05718595 0.18441702 0.32750377 0.43089327]
 [0.03235585 0.30904219 0.42189042 0.23671155]
 [0.03575216 0.05057249 0.87646227 0.03721309]
 [0.5877959  0.02863936 0.09584122 0.28772351]
 [0.02369365 0.01379787 0.87515103 0.08735745]
 [0.18271901 

In [19]:
# Pick the per-episode estimator used for both bounds.
ISFunc = PDImportanceSampling if USE_PDIS else ImportanceSampling

In [20]:
# Lower bound predicted from the training episodes, sized for a safety test
# on len(test) episodes.
J_predicted_lower_bound = Safety_Prediction(
    train, exploration_policy, gamma, new_policy, ISFunc, delta, num_safety=len(test)
)
print(J_predicted_lower_bound)

# Lower bound actually obtained on the held-out episodes.
J_safety_lower_bound = Safety_Test(
    test, exploration_policy, gamma, new_policy, ISFunc, delta
)
print(J_safety_lower_bound)

Averaging Importance Sampling
0 / 80000
20000 / 80000
40000 / 80000
60000 / 80000
-1.182750951767137
Averaging Importance Sampling
0 / 20000
5000 / 20000
10000 / 20000
15000 / 20000
0.4322671551878296
