# Calculate Effective Horizon and Effective Planning Window for a given MDP

EH calculation works as follows:
- loads np arrays for the transitions and rewards:
    - transitions_array of shape (num_states, num_actions)
        - holds ints
        - Each cell [i, j] holds the index of the next state that the MDP transitions to when action j is taken in state i. If the next state is a terminal state, the value stored is -1.
    - rewards_array of shape (num_states, num_actions)
        - holds rewards (stored as float32; the dtype can be changed at the top of the project file)
        - Each cell [i, j] stores the reward received for the transition out of state i when action j is taken.

In [10]:
# imports
import math
import threading
from concurrent.futures import ThreadPoolExecutor

import gym
import numpy as np
from tqdm import tqdm

# constants and structs
class ViResults:
    """Bundle of the arrays produced by value_iteration.

    The constructor stores every argument unchanged (no copies) under an
    attribute of the same name.
    """

    def __init__(self, exploration_qs, exploration_values, optimal_qs, optimal_values, worst_qs, worst_values, visitable_states):
        # Q-values and state values under the exploration policy
        self.exploration_qs, self.exploration_values = exploration_qs, exploration_values
        # Q-values and state values when always taking the best action
        self.optimal_qs, self.optimal_values = optimal_qs, optimal_values
        # Q-values and state values when always taking the worst action
        self.worst_qs, self.worst_values = worst_qs, worst_values
        # per-timestep collections of states considered by value iteration
        self.visitable_states = visitable_states

class EHResults:
    """Accumulator for the outcome of the effective-horizon search.

    The constructor stores every argument unchanged (no copies) under an
    attribute of the same name.
    """

    def __init__(self, ks, ms, vars, gaps, effective_horizon):
        # parallel per-k records: the k value, its m, and the variance / gap
        # taken at the state that produced that m
        self.ks, self.ms = ks, ms
        self.vars, self.gaps = vars, gaps
        # best (smallest) bound found so far
        self.effective_horizon = effective_horizon

# Tolerance for Q-value comparisons: an action is treated as suboptimal only
# when its optimal Q-value is more than this far below the optimal state value.
REWARD_PRECISION = 1e-4

# module-level lock shared by the worker functions
lock = threading.Lock()

def findmax(arr):
    """Return the largest value in ``arr`` together with its position.

    Args:
        arr: array-like of numbers (converted with ``np.asarray``).

    Returns:
        ``(max_val, max_index)``. For 0-d or 1-d input ``max_index`` is a plain
        int. For input with two or more dimensions it is a tuple of ints, one
        coordinate per axis, so it can be unpacked
        (``timestep, state = max_index``) or used directly as an index.

    Raises:
        ValueError: if ``arr`` is empty.
    """
    arr = np.asarray(arr)
    max_val = np.max(arr)
    flat_index = int(np.argmax(arr))
    if arr.ndim <= 1:
        return max_val, flat_index
    # np.argmax indexes the flattened array; convert back to per-axis coordinates
    coords = np.unravel_index(flat_index, arr.shape)
    return max_val, tuple(int(c) for c in coords)

In [11]:
### Pipeline: functions to get bounds and their helpers ###

## OUTSTANDING TODO: ##
    # substitute env.P with transition table generating code loaded from consolidated.npz file we can load in per mdp from dataset
    # ensure all variables are being updated in the proper scope (ie. threading modifications are passed through to the variables at least when the threading is over)
    # clean up documentation

# multithreading value_iteration helper: process a timestep of vi visitable states to get output vals
def process_timestep_thread(ts, transitions, rewards, state, num_actions, horizon, exploration_values, exploration_qs, optimal_qs, optimal_values, worst_qs, worst_values, exploration_policy=None):
    """Fill in the Q-values and state values of one (timestep, state) pair.

    Writes ``*_qs[ts, state, :]`` and ``*_values[ts, state]`` in place. The
    values of timestep ``ts + 1`` must already be filled in, so callers have to
    process timesteps from last to first.

    Args:
        ts: timestep index, ``0 <= ts < horizon``.
        transitions: ``transitions[state][action]`` is the next-state index, or
            -1 when the transition ends the episode.
        rewards: ``rewards[state][action]`` is the reward for that transition.
        state: index of the state to process.
        num_actions: number of actions.
        horizon: number of timesteps; the first axis of the output arrays.
        exploration_values, optimal_values, worst_values: ``(horizon,
            num_states)`` arrays, updated in place.
        exploration_qs, optimal_qs, worst_qs: ``(horizon, num_states,
            num_actions)`` arrays, updated in place.
        exploration_policy: optional array indexed ``[ts, state, action]`` with
            the weight of each action; ``None`` means uniform over actions.

    Raises:
        AssertionError: if any resulting state value is NaN.
    """
    for action in range(num_actions):
        obs = transitions[state][action]
        reward = rewards[state][action]
        # Bootstrap from the next timestep only when there is one and the
        # transition is not terminal (-1); otherwise the return is the reward.
        if ts + 1 < horizon and obs >= 0:
            exploration_qs[ts, state, action] = reward + exploration_values[ts + 1, obs]
            optimal_qs[ts, state, action] = reward + optimal_values[ts + 1, obs]
            worst_qs[ts, state, action] = reward + worst_values[ts + 1, obs]
        else:
            exploration_qs[ts, state, action] = reward
            optimal_qs[ts, state, action] = reward
            worst_qs[ts, state, action] = reward

    # np.max / np.min propagate NaN, so a missing successor value is caught below
    optimal_value = np.max(optimal_qs[ts, state, :num_actions])
    worst_value = np.min(worst_qs[ts, state, :num_actions])
    # `is None`, not `== None`: comparing an ndarray to None is elementwise
    if exploration_policy is None:
        exploration_value = np.sum(exploration_qs[ts, state, :num_actions]) / num_actions
    else:
        exploration_value = 0
        for action in range(num_actions):
            exploration_value += exploration_qs[ts, state, action] * exploration_policy[ts, state, action]

    optimal_values[ts, state] = optimal_value
    worst_values[ts, state] = worst_value

    # clamp the exploration value between worst and optimal to absorb
    # floating-point error
    exploration_values[ts, state] = min(optimal_value, max(worst_value, exploration_value))

    # verification: no NaNs
    assert not np.isnan(exploration_values[ts, state])
    assert not np.isnan(optimal_values[ts, state])
    assert not np.isnan(worst_values[ts, state])


# value iteration: calculate exploration_qs, exploration_values, optimal_qs, optimal_values, worst_qs, worst_values, visitable_states, returns in single results structure in this order
def value_iteration(transitions, rewards, horizon, exploration_policy=None):
    """Run finite-horizon value iteration from state 0.

    Args:
        transitions: ``transitions[state][action]`` is the next-state index, or
            -1 when the transition ends the episode.
        rewards: ``rewards[state][action]`` is the reward for that transition.
        horizon: number of timesteps.
        exploration_policy: optional array indexed ``[timestep, state, action]``
            passed through to ``process_timestep_thread``; ``None`` means a
            uniform policy.

    Returns:
        ViResults holding, in order: exploration_qs, exploration_values,
        optimal_qs, optimal_values, worst_qs, worst_values (NaN for entries
        that were never computed) and visitable_states, a list with one set
        of state indices per timestep.
    """
    # get constants from environment
    num_states = len(transitions)
    num_actions = len(transitions[0])

    # visitable_states[t] is the set of states that can be occupied at
    # timestep t when starting from state 0
    visitable_states = []
    current_visitable_states = {0}
    for _ in tqdm(range(horizon)):
        visitable_states.append(current_visitable_states)
        next_visitable_states = set()
        for state in current_visitable_states:
            for action in range(num_actions):
                next_state = int(transitions[state][action])
                # -1 marks a terminal transition; it is not a state index
                if next_state >= 0:
                    next_visitable_states.add(next_state)
        current_visitable_states = next_visitable_states

    # initialize outputs
    exploration_qs = np.full((horizon, num_states, num_actions), np.nan)
    exploration_values = np.full((horizon, num_states), np.nan)
    optimal_qs = np.full((horizon, num_states, num_actions), np.nan)
    optimal_values = np.full((horizon, num_states), np.nan)
    worst_qs = np.full((horizon, num_states, num_actions), np.nan)
    worst_values = np.full((horizon, num_states), np.nan)

    # Backward induction: timestep t reads the values of t + 1, so go from
    # the last timestep to the first. Within a timestep every state writes
    # only its own entries, so states can be processed concurrently.
    with ThreadPoolExecutor() as pool:
        for timestep in tqdm(range(horizon - 1, -1, -1)):
            # list() waits for every worker and re-raises any worker exception
            list(pool.map(
                lambda state, ts=timestep: process_timestep_thread(
                    ts, transitions, rewards, state, num_actions, horizon,
                    exploration_values, exploration_qs, optimal_qs,
                    optimal_values, worst_qs, worst_values, exploration_policy,
                ),
                visitable_states[timestep],
            ))

    # define results struct and return results
    return ViResults(exploration_qs, exploration_values, optimal_qs, optimal_values, worst_qs, worst_values, visitable_states)


# compute variance bounds: finds a bound on the variance of the qs based on the best and worst qs for each state at a timestep
def compute_variance_thread(var_bounds, ts, state, num_actions, vi):
    """Write the per-action variance bounds of one (timestep, state) pair.

    For each action the bound is
    ``(exploration_q - worst_q) * (optimal_q - worst_q)``, stored in place in
    ``var_bounds[ts, state, action]``.

    Args:
        var_bounds: array indexed ``[ts, state, action]``, updated in place.
        ts: timestep index.
        state: state index.
        num_actions: number of actions to fill in.
        vi: object exposing ``exploration_qs``, ``worst_qs`` and ``optimal_qs``
            arrays indexed ``[ts, state, action]``.
    """
    # NOTE(review): for a variable in [worst_q, optimal_q] with mean q the
    # Bhatia-Davis bound is (optimal_q - q) * (q - worst_q); the product used
    # here is never smaller than that. Confirm the looser form is intended.
    actions = slice(0, num_actions)
    above_worst = vi.exploration_qs[ts, state, actions] - vi.worst_qs[ts, state, actions]
    q_range = vi.optimal_qs[ts, state, actions] - vi.worst_qs[ts, state, actions]
    var_bounds[ts, state, actions] = above_worst * q_range


# k iteration thread: checks k validity and adds to seeable next states in following timestep
def k_working_thread(k_works, ts, state, states_can_be_visited, num_actions, current_qs, vi, var_bounds, state_gaps, state_vars, horizon, k, state_ms, transitions, rewards):
    """Check whether lookahead depth ``k`` works for one (timestep, state) pair.

    A state that is not flagged in ``states_can_be_visited`` is skipped. For a
    flagged state, k fails when the largest current Q-value is attained by an
    action that is suboptimal (optimal Q more than REWARD_PRECISION below the
    optimal value). Otherwise the gap, variance bound and sample count m are
    recorded for the state and the successors of every action whose current Q
    exceeds the best suboptimal Q are flagged for the next timestep.

    Args:
        k_works: kept for interface compatibility and not used. Rebinding a
            bool argument cannot reach the caller, so the outcome is returned.
        ts, state: timestep and state index to check.
        states_can_be_visited: bool array ``(horizon, num_states)``, updated
            in place for timestep ``ts + 1``.
        num_actions: number of actions.
        current_qs: Q-values after ``k - 1`` Bellman backups, indexed
            ``[ts, state, action]``.
        vi: value-iteration results (``optimal_qs`` and ``optimal_values``
            are read).
        var_bounds: variance bounds indexed ``[ts, state, action]``.
        state_gaps, state_vars, state_ms: ``(horizon, num_states)`` arrays,
            updated in place at ``[ts, state]``.
        horizon: number of timesteps.
        k: lookahead depth being tested.
        transitions: next-state table; -1 marks a terminal transition.
        rewards: unused here; kept for a uniform worker signature.

    Returns:
        bool: False if this state rules out k, True otherwise (including when
        the state is skipped).
    """
    if not states_can_be_visited[ts, state]:
        return True

    max_q = -np.inf
    max_suboptimal_q = -np.inf
    max_var = 0
    for action in range(num_actions):
        q = current_qs[ts, state, action]
        max_q = max(max_q, q)
        if vi.optimal_qs[ts, state, action] < vi.optimal_values[ts, state] - REWARD_PRECISION:
            max_suboptimal_q = max(max_suboptimal_q, q)
        max_var = max(max_var, var_bounds[ts, state, action])

    # Only decide once every action has been seen: a suboptimal action ties
    # for the best current Q-value, so greedy selection could pick it.
    if max_q == max_suboptimal_q:
        return False

    gap = max_q - max_suboptimal_q
    state_gaps[ts, state] = gap
    state_vars[ts, state] = max_var
    m = math.ceil(16 * max_var / (gap**2) * math.log(2 * horizon * (num_actions**k)))
    state_ms[ts, state] = max(1, m)

    # flag the successors of every action that beats the best suboptimal one
    if ts + 1 < horizon:
        for action in range(num_actions):
            if current_qs[ts, state, action] > max_suboptimal_q:
                next_state = transitions[state][action]
                # -1 marks a terminal transition; there is no successor state
                if next_state >= 0:
                    states_can_be_visited[ts + 1, next_state] = True
    return True


# bellman backup thread: update the current qs and variance bounds for the run
def bellman_backup_thread(ts, num_actions, transitions, rewards, state, current_qs, var_bounds):
    """Apply one Bellman backup to every action of one (timestep, state) pair.

    Sets ``current_qs[ts, state, a]`` to the reward plus the largest Q-value
    of the successor at ``ts + 1`` and ``var_bounds[ts, state, a]`` to the
    largest variance bound of that successor. A terminal transition (-1)
    contributes 0 for both. Requires ``ts + 1`` to be a valid timestep.
    """
    for action in range(num_actions):
        next_state = transitions[state][action]
        reward = rewards[state][action]
        max_next_q = -np.inf
        max_next_var_bound = 0
        if next_state < 0:
            # terminal: nothing follows the reward
            max_next_q = 0
        else:
            # distinct loop name so the outer `action` is not overwritten
            for next_action in range(num_actions):
                max_next_q = max(max_next_q, current_qs[ts + 1, next_state, next_action])
                max_next_var_bound = max(max_next_var_bound, var_bounds[ts + 1, next_state, next_action])
        current_qs[ts, state, action] = reward + max_next_q
        var_bounds[ts, state, action] = max_next_var_bound


# calculate EH bound using GORP bounds
def get_EH_bound(transitions, rewards, horizon, exploration_policy=None):
    """Search over lookahead depths k for the smallest bound H_k = k + log_A(m).

    For k = 1, 2, ... (while k is below the best bound so far) every state
    reachable under near-greedy actions is checked with ``k_working_thread``.
    If k works, H_k is computed from the largest per-state m and recorded.
    The Q-values and variance bounds then receive one more Bellman backup
    before the next k.

    Args:
        transitions: ``transitions[state][action]`` next-state table, -1 for
            terminal transitions.
        rewards: ``rewards[state][action]`` reward table.
        horizon: number of timesteps.
        exploration_policy: optional policy passed to ``value_iteration``.

    Returns:
        EHResults with the working ks, their ms, the variance and gap at the
        state with the largest m, and ``effective_horizon`` set to the
        smallest H_k found (``horizon`` if no k worked).
    """
    num_states = len(transitions)
    num_actions = len(transitions[0])
    # Perform value iteration
    vi = value_iteration(transitions, rewards, horizon, exploration_policy)

    var_bounds = np.full((horizon, num_states, num_actions), np.nan)
    # NOTE: aliases vi.exploration_qs, so the backups below update it in place
    current_qs = vi.exploration_qs
    results = EHResults([], [], [], [], horizon)

    # Within a timestep every state writes only its own entries, so states can
    # be processed concurrently. list(pool.map(...)) waits for all workers and
    # re-raises any exception raised in one of them.
    with ThreadPoolExecutor() as pool:
        # initialize and compute variance bounds
        for timestep in tqdm(range(horizon)):
            list(pool.map(
                lambda state, ts=timestep: compute_variance_thread(var_bounds, ts, state, num_actions, vi),
                vi.visitable_states[timestep],
            ))

        # find best working k
        k = 1
        while k < results.effective_horizon:
            k_works = True
            state_ms = np.zeros((horizon, num_states))
            state_vars = np.zeros((horizon, num_states))
            state_gaps = np.zeros((horizon, num_states))
            states_can_be_visited = np.full((horizon, num_states), False)
            states_can_be_visited[0, 0] = True

            for timestep in tqdm(range(horizon)):
                outcomes = list(pool.map(
                    lambda state, ts=timestep: k_working_thread(
                        k_works, ts, state, states_can_be_visited, num_actions,
                        current_qs, vi, var_bounds, state_gaps, state_vars,
                        horizon, k, state_ms, transitions, rewards,
                    ),
                    vi.visitable_states[timestep],
                ))
                if not all(outcomes):
                    # this k fails; skip the remaining timesteps but keep
                    # searching larger k
                    k_works = False
                    break

            if k_works:
                results.ks.append(k)
                # per-axis coordinates of the largest m
                timestep, state = np.unravel_index(np.argmax(state_ms), state_ms.shape)
                highest_m = state_ms[timestep, state]
                results.ms.append(highest_m)
                # log of highest_m with base num_actions; log(1) is 0 for any
                # base, which also avoids dividing by log(1) when there is
                # only one action
                H_k = k + (math.log(highest_m, num_actions) if highest_m > 1 else 0.0)
                print(f"H_{k} = {H_k}")
                results.gaps.append(state_gaps[timestep, state])
                results.vars.append(state_vars[timestep, state])
                results.effective_horizon = min(results.effective_horizon, H_k)

            # run a bellman backup; forward order so that timestep t reads the
            # not-yet-updated values of t + 1 (exactly one backup per pass)
            for timestep in tqdm(range(horizon - 1)):
                list(pool.map(
                    lambda state, ts=timestep: bellman_backup_thread(
                        ts, num_actions, transitions, rewards, state, current_qs, var_bounds,
                    ),
                    vi.visitable_states[timestep],
                ))

            k += 1

    return results
    

# k_working_epw thread: checks if given state works for k
def k_working_epw_thread(k_works, states_can_be_visited, ts, state, num_actions, current_qs, vi, transitions, rewards, horizon):
    """Check whether the current planning depth works for one (timestep, state).

    A state that is not flagged in ``states_can_be_visited`` is skipped. For a
    flagged state, every action whose current Q-value attains the maximum
    could be chosen greedily. The depth fails if any such action is
    suboptimal (optimal Q more than REWARD_PRECISION below the optimal
    value). The successors of all such actions are flagged for ``ts + 1``.

    Args:
        k_works: kept for interface compatibility and not used. Rebinding a
            bool argument cannot reach the caller, so the outcome is returned.
        states_can_be_visited: bool array ``(horizon, num_states)``, updated
            in place for timestep ``ts + 1``.
        ts, state: timestep and state index to check.
        num_actions: number of actions.
        current_qs: current Q-values indexed ``[ts, state, action]``.
        vi: value-iteration results (``optimal_qs`` and ``optimal_values``
            are read).
        transitions: next-state table; -1 marks a terminal transition.
        rewards: unused here; kept for a uniform worker signature.
        horizon: number of timesteps.

    Returns:
        bool: False if this state rules out the depth, True otherwise
        (including when the state is skipped).
    """
    if not states_can_be_visited[ts, state]:
        return True

    # init and find max q
    max_q = -np.inf
    for action in range(num_actions):
        max_q = max(max_q, current_qs[ts, state, action])

    state_ok = True
    for action in range(num_actions):
        # only actions attaining the max can be taken greedily
        if current_qs[ts, state, action] >= max_q:
            if vi.optimal_qs[ts, state, action] < vi.optimal_values[ts, state] - REWARD_PRECISION:
                state_ok = False
            next_state = transitions[state][action]
            # no successor after the last timestep or a terminal (-1) transition
            if ts + 1 < horizon and next_state >= 0:
                states_can_be_visited[ts + 1, next_state] = True
    return state_ok


# epw_bellman_backup, modified bellman backup for epw
def epw_bellman_backup_thread(ts, num_actions, state, transitions, rewards, current_qs):
    """Apply one Bellman backup to every action of one (timestep, state) pair.

    Sets ``current_qs[ts, state, a]`` to the reward plus the largest Q-value
    of the successor at ``ts + 1``. A terminal transition (-1) contributes 0.
    Requires ``ts + 1`` to be a valid timestep.
    """
    for action in range(num_actions):
        next_state = transitions[state][action]
        reward = rewards[state][action]
        if next_state < 0:
            # terminal: nothing follows the reward
            max_next_q = 0
        else:
            max_next_q = -np.inf
            # distinct loop name so the outer `action` is not overwritten
            for next_action in range(num_actions):
                max_next_q = max(max_next_q, current_qs[ts + 1, next_state, next_action])
        current_qs[ts, state, action] = reward + max_next_q


# calculate EPW
def get_EPW(transitions, rewards, horizon, exploration_policy=None, start_with_rewards=True):
    """Return the smallest planning depth k whose greedy actions are all optimal.

    Starting from k = 1, every state reachable under greedy actions is checked
    with ``k_working_epw_thread``. If some state fails, the Q-values receive
    one Bellman backup and k is incremented.

    Args:
        transitions: ``transitions[state][action]`` next-state table, -1 for
            terminal transitions.
        rewards: ``rewards[state][action]`` reward table.
        horizon: number of timesteps.
        exploration_policy: optional policy passed to ``value_iteration``.
        start_with_rewards: if True the initial Q-values are the immediate
            rewards; otherwise they are the exploration-policy Q-values from
            value iteration.

    Returns:
        int: the first k for which no visited state fails the check.
    """
    num_states = len(transitions)
    num_actions = len(transitions[0])
    # Perform value iteration
    vi = value_iteration(transitions, rewards, horizon, exploration_policy)

    # init current_qs
    if start_with_rewards:
        current_qs = np.full((horizon, num_states, num_actions), np.nan)
        for timestep in tqdm(range(horizon)):
            for state in vi.visitable_states[timestep]:
                for action in range(num_actions):
                    current_qs[timestep, state, action] = rewards[state][action]
    else:
        # NOTE: aliases vi.exploration_qs, so the backups below update it in place
        current_qs = vi.exploration_qs

    # Within a timestep every state writes only its own entries, so states can
    # be processed concurrently. list(pool.map(...)) waits for all workers and
    # re-raises any exception raised in one of them.
    k = 1
    with ThreadPoolExecutor() as pool:
        while True:
            # check if this k value works
            k_works = True
            states_can_be_visited = np.full((horizon, num_states), False)
            states_can_be_visited[0, 0] = True
            for timestep in tqdm(range(horizon)):
                outcomes = list(pool.map(
                    lambda state, ts=timestep: k_working_epw_thread(
                        k_works, states_can_be_visited, ts, state, num_actions,
                        current_qs, vi, transitions, rewards, horizon,
                    ),
                    vi.visitable_states[timestep],
                ))
                if not all(outcomes):
                    k_works = False
                    break

            if k_works:
                return k

            # otherwise run bellman backup and up k; forward order so that
            # timestep t reads the not-yet-updated values of t + 1
            for timestep in tqdm(range(horizon - 1)):
                list(pool.map(
                    lambda state, ts=timestep: epw_bellman_backup_thread(
                        ts, num_actions, state, transitions, rewards, current_qs,
                    ),
                    vi.visitable_states[timestep],
                ))

            k += 1

# Scratchwork

Figuring out .npz structure of transitions and rewards. From documentation:

`consolidated.npz`: the tabular representation of the MDP with consolidated states. States are consolidated if any sequence of actions from them will always lead to the same sequence of observations and rewards. The file format is a NumPy NPZ archive with two arrays, `transitions` and `rewards`, each of shape `(num_states, num_actions)`. `transitions[state, action]` gives the index of the next state reached by taking action `action` in state `state`, or -1 if the next state is terminal; `rewards[state, action]` gives the reward accompanying the aforementioned transition. The initial state has index 0.


In [12]:
import numpy as np

# Load the consolidated.npz archive for the pong_20_fs30 MDP.
# NOTE(review): absolute, machine-specific path; it will only resolve where
# this directory layout exists.
pong_table = np.load('/Users/laurenc/Documents/GitHub/282_expansion/data/bridge_dataset/mdps/pong_20_fs30/consolidated.npz')

In [15]:
# List the names of the arrays stored in the archive
print("Keys in the npz file:", list(pong_table.keys()))

# For every stored array, report its shape and dtype and preview its start
for name in pong_table.keys():
    table = pong_table[name]
    print("\nArray:", name)
    print("Shape:", table.shape)
    print("Data type:", table.dtype)
    print("First few rows:")
    print(table[:5])  # first 5 rows only

Keys in the npz file: ['rewards', 'transitions']

Array: rewards
Shape: (254, 6)
Data type: float32
First few rows:
[[ 0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.]
 [-1. -1. -1. -1. -1. -1.]]

Array: transitions
Shape: (254, 6)
Data type: int64
First few rows:
[[123 123  70  28  70  28]
 [ 66  66  66 242  66 242]
 [138 138 138 114 138 114]
 [ 41  41  41  20  41  20]
 [209 209 209 108 209 108]]
