In [3]:
# imports
import gym
import numpy as np
from tqdm import tqdm
import threading
import math
from time import sleep
import queue

# constants and structs
class ViResults:
    def __init__(self, exploration_qs, exploration_values, optimal_qs, optimal_values, worst_qs, worst_values, visitable_states):
        self.exploration_qs = exploration_qs
        self.exploration_values = exploration_values
        self.optimal_qs = optimal_qs
        self.optimal_values = optimal_values
        self.worst_qs = worst_qs
        self.worst_values = worst_values
        self.visitable_states = visitable_states

class EHResults:
    def __init__(self, ks, ms, vars, gaps, effective_horizon):
        self.ks = ks
        self.ms = ms
        self.vars = vars
        self.gaps = gaps
        self.effective_horizon = effective_horizon

REWARD_PRECISION = 1e-4

def findmax(arr):
    max_val = np.max(arr)
    max_index = np.argmax(arr)
    return max_val, max_index

In [85]:
### Pipeline: functions to get bounds and their helpers ###


# value iteration: calculate exploration_qs, exploration_values, optimal_qs, optimal_values, worst_qs, worst_values, visitable_states, returns in single results structure in this order
def value_iteration(transitions, rewards, horizon, exploration_policy=None):
    # get constants from environment
    num_states = len(transitions) #env.observation_space.n
    num_actions = len(transitions[0]) #env.action_space.n

    # get list of visitable states where each item is a set of visitable states at a given timestep (index)
    visitable_states = []
    current_visitable_states = set()
    current_visitable_states.add(0)
    for _ in tqdm(range(horizon)):
        next_visitable_states = set()
        for state in current_visitable_states:
            for action in range(num_actions):
                next_state = transitions[state][action]
                next_visitable_states.add(next_state)
        visitable_states.append(next_visitable_states.union(current_visitable_states))
        current_visitable_states = next_visitable_states
    print(len(visitable_states))
    print(horizon)
    
    # initialize outputs
    exploration_qs = np.full((horizon, num_states, num_actions), 0.)
    exploration_values = np.full((horizon, num_states), 0.)
    optimal_qs = np.full((horizon, num_states, num_actions), 0.)
    optimal_values = np.full((horizon, num_states), 0.)
    worst_qs = np.full((horizon, num_states, num_actions), 0.)
    worst_values = np.full((horizon, num_states), 0.)

    # get output vals for vi
    for ts in tqdm(reversed(range(horizon))):
        for state in visitable_states[ts]:
            for action in range(num_actions):
                obs = transitions[state][action]
                reward = rewards[state][action]
                if ts < horizon:
                    # a = reward + exploration_values[ts+1, obs]
                    # if a != 0:
                    #     print(a)
                    exploration_qs[ts, state, action] = reward + exploration_values[ts+1, obs]
                    # if a != 0:
                    #     print(exploration_qs[ts, state, action])
                    optimal_qs[ts, state, action] = reward + optimal_values[ts+1, obs]
                    worst_qs[ts, state, action] = reward + worst_values[ts+1, obs]
                else:
                    exploration_qs[ts, state, action] = reward
                    optimal_qs[ts, state, action] = reward
                    worst_qs[ts, state, action] = reward

            optimal_value = max(optimal_qs[ts, state, :])
            worst_value = min(worst_qs[ts, state, :])
            if exploration_policy == None:
                exploration_value = sum(exploration_qs[ts, state, :]) / num_actions
            else:
                exploration_value = 0
                for action in range(num_actions):
                    exploration_value += exploration_qs[ts, state, action] * exploration_policy[ts, state, action]
            
            optimal_values[ts, state] = optimal_value
            worst_values[ts, state] = worst_value

            # verification: no nans and no float errs
            # ensure exploration val is bw worst and optimal vals (avoid floating pt errors)
            exploration_values[ts, state] = min(optimal_value, max(worst_value, exploration_value))
            assert(not np.isnan(exploration_values[ts, state]))
            assert(not np.isnan(optimal_values[ts, state]))
            assert(not np.isnan(worst_values[ts, state]))

    # define results struct and return results
    # assert(not np.isclose(exploration_values, np.full((horizon, num_states), 0.)).all())
    results = ViResults(exploration_qs, exploration_values, optimal_qs, optimal_values, worst_qs, worst_values, visitable_states)
    return results


# calculate EH bound using GORP bounds
def get_EH_bound(transitions, rewards, horizon, exploration_policy=None):
    num_states = len(transitions)#env.observation_space.n
    num_actions = len(transitions[0])#env.action_space.n
    # Perform value iteration
    vi = value_iteration(transitions, rewards, horizon, exploration_policy)
    print("done with value iteration")
    
    # initialize and compute variance bounds
    var_bounds = np.full((horizon, num_states, num_actions), 0.)
    for timestep in tqdm(range(horizon-1)):
        # across different states in a given timestep, so no overwriting vals in data structures
        for state in vi.visitable_states[timestep]:
            # finds a bound on the variance of the qs based on the best and worst qs for each state at a timestep
            for action in range(num_actions):
                q = vi.exploration_qs[timestep, state, action]
                worst_q = vi.worst_qs[timestep,state,action]
                optimal_q = vi.optimal_qs[timestep, state, action]
                var_bound = (q - worst_q) * (optimal_q - worst_q)
                var_bounds[timestep, state, action] = var_bound
    print("done with variance bounds")

    # initialize q values, starting k, and EH results structure
    current_qs = vi.exploration_qs
    results = EHResults([],[],[],[],horizon)
    k = 1

    # find best working k
    while k < results.effective_horizon:
        # initialize results objects and iteration
        k_works = True
        state_ms = np.full((num_states, horizon), 0.)
        state_vars = np.full((num_states, horizon), 0.)
        state_gaps = np.full((num_states, horizon), 0.)
        states_can_be_visited = np.full((num_states, horizon), False)
        states_can_be_visited[0,0] = True
        
        for timestep in tqdm(reversed(range(horizon-1))):
            for state in vi.visitable_states[timestep]:
                # checks k validity and adds to seeable next states in following timestep
                if states_can_be_visited[timestep, state]:
                    # init vals
                    max_q = -np.inf
                    max_suboptimal_q = -np.inf
                    max_var = 0
                    # iterate over actions to get actual vals
                    for action in range(num_actions):
                        q = current_qs[timestep, state, action]
                        max_q = max(max_q, q)
                        if vi.optimal_qs[timestep,state,action] < vi.optimal_values[timestep, state] - REWARD_PRECISION:
                            max_suboptimal_q = max(max_suboptimal_q, q)
                        max_var = max(max_var, var_bounds[timestep, state, action])
                        # check for k fail condition
                        if max_q == max_suboptimal_q:
                               k_works = False
                        # otherwise get the state m value
                        else:
                            gap = max_q - max_suboptimal_q
                            state_gaps[timestep, state] = gap
                            state_vars[timestep, state] = max_var
                            m = math.ceil(16 * max_var / (gap**2) * math.log(2 * horizon * (num_actions**k)))
                            state_ms[timestep, state] = max(1,m)
                        # iterate through actions to find next visitable states
                            for action in range(num_actions):
                                if current_qs[timestep, state, action] > max_suboptimal_q:
                                    next_state = transitions[state][action]
                                    # _, next_state, _, _ = env.P[state][action]
                                    if timestep < horizon:
                                        states_can_be_visited[timestep+1, next_state] = True
        if not k_works:
            break
        print("done with k check")
    
        # if k works flag not triggered, then k is horizon and we update result vals accordingly
        if k_works:
            results.ks.append(k)
            highest_m = np.max(state_ms)
            timestep, state = np.unravel_index(np.argmax(state_ms), state_ms.shape)
            results.ms.append(highest_m)
            # log of highest_m with base num_actions
            H_k = k + math.log(highest_m, num_actions)
            print(f"H_{k} = {H_k}")
            results.gaps.append(state_gaps[timestep, state])
            results.vars.append(state_vars[timestep, state])
            results.effective_horizon = min(results.effective_horizon, H_k)
        
        # run a bellman backup
        for timestep in tqdm(reversed(range(horizon-1))):
            # thread across different states in a given timestep, so no overwriting vals in data structures
            for state in vi.visitable_states[timestep]:
                 # update the current qs and variance bounds for the run
                 for action in range(num_actions):
                    next_state = transitions[state][action]
                    reward = rewards[state][action]
                    #_, next_state, reward, _ = env.P[state][action]
                    max_next_q = -np.inf
                    max_next_var_bound = 0
                    for action in range(num_actions):
                        next_q = current_qs[timestep+1, next_state, action]
                        max_next_q = max(max_next_q, next_q)
                        next_var_bound = var_bounds[timestep+1, next_state, action]
                        max_next_var_bound = max(max_next_var_bound, next_var_bound)
                    current_qs[timestep, state, action] = reward + max_next_q
                    var_bounds[timestep, state, action] = max_next_var_bound
        print("done w bellman backup")        
        k += 1
        print("k increased")
    print("EH results done")
    return results


# calculate EPW
def get_EPW(transitions, rewards, horizon, exploration_policy=None, start_with_rewards=True):
    # calculate min k for the gym environment and return it
    num_states = len(transitions) # env.observation_space.n
    num_actions = len(transitions[0]) # env.action_space.n

    # Perform value iteration
    vi = value_iteration(transitions, rewards, horizon, exploration_policy)

    # init current_qs
    if start_with_rewards:
        current_qs = np.full((horizon, num_states, num_actions), 0.)
        for timestep in tqdm(range(horizon)):
            for state in vi.visitable_states[timestep]:
                for action in range(num_actions):
                    current_qs[timestep, state, action] = rewards[state][action]
                    #_, _, current_qs[timestep, state, action], _ = env.P[state][action]
    else:
        current_qs = vi.exploration_qs

    #
    k = 1
    while True:
        # check if this k value works
        k_works = True
        states_can_be_visited = np.full((horizon, num_states), False)
        states_can_be_visited[0,0] = True
        for timestep in tqdm(range(horizon)):
            for state in vi.visitable_states[timestep]:
                #checks if given state works for k
                if states_can_be_visited[timestep, state]:
                    # init and find max q
                    max_q = -np.inf
                    for action in range(num_actions):
                        max_q = max(max_q, current_qs[timestep, state, action])
                    for action in range(num_actions):
                        # check if possible to take this action
                        if current_qs[timestep, state, action] >= max_q:
                            if vi.optimal_qs[timestep, state, action] < vi.optimal_values[timestep,state] - REWARD_PRECISION:
                                k_works = False
                            next_state = transitions[state][action]
                            #_, next_state, _, _ = env.P[state][action]
                            if timestep < horizon:
                                states_can_be_visited[timestep+1, next_state] = True
        if not k_works:
            break
    
    if k_works:
        return k
            
    # otherwise run bellman backup and up k
    for timestep in tqdm(range(horizon)):
        for state in vi.visitable_states[timestep]:
            for action in range(num_actions):
                next_state = transitions[state][action]
                reward = rewards[state][action]
                #_, next_state, reward, _ = env.P[state][action]
                max_next_q = -np.inf
                for action in range(num_actions):
                    next_q = current_qs[timestep+1, next_state, action]
                    max_next_q = max(max_next_q, next_q)
                current_qs[timestep, state, action] = reward + max_next_q

    k += 1

In [105]:
## run through ##

# load tables from dataset
pong_table = np.load('/Users/laurenc/Documents/GitHub/282_expansion/data/bridge_dataset/mdps/pong_20_fs30/consolidated.npz')
rewards = pong_table['rewards']
transitions = pong_table['transitions']
for line in transitions:
    for item in line:
        print(item, end=" ")
    print(";")
horizon = len(transitions)

# calculate eh
# eh_results = get_EH_bound(transitions, rewards, horizon)
# print(eh_results.effective_horizon)

# calculate EPW
# epw_result = get_EPW(transitions, rewards, horizon)
# print(epw_result)

# # print results
# print(f"effective horizon: {eh_results.effective_horizon} \n effective planning window: {epw_result}")


123 123 70 28 70 28 ;
66 66 66 242 66 242 ;
138 138 138 114 138 114 ;
41 41 41 20 41 20 ;
209 209 209 108 209 108 ;
140 140 170 239 170 239 ;
54 54 185 122 185 122 ;
143 143 24 143 24 143 ;
19 19 19 173 19 173 ;
67 67 67 202 67 202 ;
16 16 45 71 45 71 ;
37 37 65 77 65 77 ;
53 53 159 141 159 141 ;
86 86 219 217 219 217 ;
184 184 2 63 2 63 ;
150 150 150 204 150 204 ;
215 215 167 50 167 50 ;
-1 -1 -1 -1 -1 -1 ;
72 72 178 145 178 145 ;
41 41 41 20 41 20 ;
-1 -1 -1 -1 -1 -1 ;
157 157 94 157 94 157 ;
238 238 66 82 66 82 ;
75 75 160 157 160 157 ;
-1 -1 -1 -1 -1 -1 ;
57 57 19 225 19 225 ;
-1 -1 -1 -1 -1 -1 ;
112 112 198 112 198 112 ;
49 49 11 49 11 49 ;
161 161 235 7 235 7 ;
101 101 168 101 168 101 ;
35 35 35 39 35 39 ;
72 72 178 145 178 145 ;
86 86 219 217 219 217 ;
170 170 170 32 170 32 ;
-1 -1 -1 -1 -1 -1 ;
143 143 24 143 24 143 ;
16 16 45 71 45 71 ;
44 44 56 74 56 74 ;
-1 -1 -1 -1 -1 -1 ;
154 154 148 27 148 27 ;
-1 -1 -1 -1 -1 -1 ;
97 97 189 201 189 201 ;
98 98 81 153 81 153 ;
18 18 188 23