In [None]:
import sys
sys.path.append('../')
from configs import stronger_tradeoff
from reward_model import CondensedRewardModel
from algorithms import TransitionModel, DumbTransitionModel, update_transition_matrix
import time
import itertools
from algorithms import update_reward_model
from transition_model import OracleTransitionModel
from gyms.sim_library import get_stronger_tradeoff_simulator
import torch
import numpy.random as npr
from gyms.sim_library import stronger_reward_function
import numpy as np
import matplotlib.pyplot as plt
from sklearn.utils import shuffle


cfg = stronger_tradeoff
state_p0 = cfg['state_p0']

# Human-readable labels; everything downstream works with integer indices into these.
action_names = cfg['action_names']
state_names = cfg['state_names']
delta_names = range(1, 11)

# Integer index spaces over the label collections above.
actions = [*range(len(action_names))]
states = [*range(len(state_names))]
deltas = [*range(len(delta_names))]

# Discount factor, per-action cost and episode length shared by simulator and models.
gamma, action_cost, horizon = 0.99, 5, 50

env = get_stronger_tradeoff_simulator(
    actions,
    states,
    state_p0,
    gamma=gamma,
    action_cost=action_cost,
    horizon=horizon,
    terminal_reward=0,
)

device = torch.device('cpu')
dumb = False
is_windygrid = False
terminal_state = env.terminal_state


In [2]:

def get_expanded_T(T, num_deltas):
    """Stack the first ``num_deltas`` matrix powers of every per-action transition matrix.

    Args:
        T: tensor of shape (A, S, S); ``T[a]`` is the square matrix for action ``a``.
        num_deltas: number of powers to compute (powers 1 .. num_deltas).

    Returns:
        Tensor of shape (A, S, S, num_deltas) where ``out[a, :, :, k]`` is the
        matrix power ``T[a] ** (k + 1)``. It has the same dtype and device as ``T``.
    """
    A, S1, S2 = T.shape
    assert S1 == S2, f'transition matrices must be square, got {S1}x{S2}'
    # Allocate with T's dtype/device: a default float32 CPU buffer would silently
    # truncate float64 input (and force a device copy for non-CPU tensors).
    expanded_T = torch.empty(A, S1, S2, num_deltas, dtype=T.dtype, device=T.device)
    for a in range(A):
        for k in range(num_deltas):
            expanded_T[a, :, :, k] = torch.matrix_power(T[a], k + 1)
    return expanded_T


def sample_reward(T, s, a, delay, terminal_state, terminal_reward, cost, gamma, horizon, reward_function):
    """Simulate holding action ``a`` for up to ``delay`` steps and return the discounted reward.

    The aggregate reward is ``-cost + sum_t gamma**t * reward_function(s_t, a)`` over the
    steps actually taken, with next states drawn from ``T[a][s_t]``. The rollout stops
    early as soon as ``terminal_state`` is reached (no ``terminal_reward`` is added then).
    If ``s`` is already terminal, nothing is simulated, no cost is charged, and
    ``terminal_reward`` is returned as the reward.

    Args:
        T: per-action transition matrices, indexable as ``T[a][s]`` -> probability row
            over next states (numpy array or CPU torch tensor).
        s: start state index.
        a: action index.
        delay: maximum number of steps to simulate.
        terminal_state: state index that ends the rollout.
        terminal_reward: reward returned when starting in ``terminal_state``.
        cost: action cost subtracted once before the rollout.
        gamma: per-step discount factor.
        horizon: unused; kept for call-site compatibility.
        reward_function: callable ``(s, a) -> float`` giving the per-step reward.

    Returns:
        Tuple ``(final_state, reward, steps_taken, terminated)``.
    """
    reward = 0
    cur_s = s
    cur_t = 0  # steps taken
    terminated = (cur_s == terminal_state)

    if terminated:
        reward += terminal_reward
        return cur_s, reward, cur_t, terminated

    # numpy's multinomial casts pvals to float64 and raises if sum(pvals[:-1]) > 1;
    # float32 rows (e.g. from torch) routinely overshoot after the cast, so convert
    # once to float64 and renormalise each row before sampling.
    P_a = np.asarray(T[a], dtype=np.float64)

    reward -= cost
    for _ in range(delay):
        reward += gamma ** cur_t * reward_function(cur_s, a)
        row = P_a[cur_s]
        total = row.sum()
        probs = row / total if total > 0 else row
        cur_s = int(np.argmax(npr.multinomial(1, probs)))
        cur_t += 1
        if cur_s == terminal_state:
            terminated = True
            break

    return cur_s, reward, cur_t, terminated

In [None]:



start = time.time()
# `reps` independent datasets are drawn for every N; N is the number of rollouts
# sampled per (state, action, delta) triple.
reps = 30
Ns = [1, 2, 5, 10, 20, 50, 100]
terminal_state = env.terminal_state
terminal_reward = env.terminal_reward

# Seed every RNG the experiment draws from so Restart & Run All reproduces the figure:
# numpy's global RNG (rollouts via npr.multinomial, and sklearn.utils.shuffle, which
# falls back to it when random_state is None) and torch (presumably used for model
# parameter initialisation).
SEED = 0
npr.seed(SEED)
torch.manual_seed(SEED)

def R_to_R_lookup(R, transition, state_space=None, action_space=None, delta_space=None):
    """Tabulate a reward model's prediction for every (state, action, delta) triple.

    Args:
        R: model exposing ``get_prediction(s, a, k, transition)``.
        transition: transition model forwarded unchanged to ``R.get_prediction``.
        state_space, action_space, delta_space: index iterables to enumerate. Each
            defaults to the notebook-level ``states`` / ``actions`` / ``deltas``.

    Returns:
        dict mapping ``(s, a, k)`` -> ``R.get_prediction(s, a, k, transition)``, with keys
        in ``itertools.product(state_space, action_space, delta_space)`` order.
    """
    # Fall back to the notebook globals only when not given explicitly, so the
    # function can also be used (and tested) without them.
    if state_space is None:
        state_space = states
    if action_space is None:
        action_space = actions
    if delta_space is None:
        delta_space = deltas
    return {
        (s, a, k): R.get_prediction(s, a, k, transition)
        for (s, a, k) in itertools.product(state_space, action_space, delta_space)
    }


# get all oracle expected rewards
oracle_R = CondensedRewardModel(
    states, actions, deltas, delta_names,
    terminal_state, gamma, action_cost, device,
    is_windygrid=is_windygrid, env=env,
    oracle_reward_lookup=True).to(device)

# Ground-truth dynamics, built once and shared by every rep. The constructor only receives
# `env` and fixed config, so it is assumed deterministic — confirm if OracleTransitionModel changes.
oracle_transition = OracleTransitionModel(env, actions, states, delta_names, device, dumb=False)
oracle_T = oracle_transition()
expanded_true_T = get_expanded_T(oracle_T, len(delta_names))
oracle_R_lookup = R_to_R_lookup(oracle_R, oracle_transition)


def reward_function(s, a):
    """Per-step reward of the simulator, bound to this notebook's terminal reward."""
    return stronger_reward_function(s, a, terminal_reward=terminal_reward)


# Optimisation hyperparameters shared by every reward / transition fit below.
R_lr = 0.2
T_lr = 0.01
convergence = 1e-5
patience = 3
model_batchsize = 500
batchsize = model_batchsize
weight_by_state_action = False
max_iters = 5000
weight_by_actions = False


def _sample_dataset(N):
    """Draw N oracle rollouts per (s, a, k); return (shuffled dataset, empirical mean rewards).

    Dataset entries are ``(a, k, s, s2, r)`` tuples; the second value maps
    ``(s, a, k)`` -> mean of the N sampled rewards.
    """
    dataset = []
    rewards_by_key = {}
    for (s, a, k) in itertools.product(states, actions, deltas):
        for _ in range(N):
            delay = k + 1
            s2, r, _, _ = sample_reward(oracle_T, s, a, delay,
                                        terminal_state, terminal_reward,
                                        action_cost, gamma, horizon, reward_function)
            dataset.append((a, k, s, s2, r))
            # append in place (the old `get(...) + [r]` copied the list on every sample)
            rewards_by_key.setdefault((s, a, k), []).append(r)
    empirical = {key: np.mean(rs) for key, rs in rewards_by_key.items()}
    return shuffle(dataset), empirical


def _fit_reward_model(dataset, transition_model):
    """Train a fresh condensed reward model given `transition_model`; return (model, lookup)."""
    R_model = CondensedRewardModel(
        states, actions, deltas, delta_names,
        terminal_state, gamma, action_cost, device,
        is_windygrid=is_windygrid, env=env,
        oracle_reward_lookup=False).to(device)
    optimizer = torch.optim.Adam(R_model.parameters(), lr=R_lr)
    R_model, _ = update_reward_model(
        dataset, R_model, torch.nn.MSELoss(), optimizer, device,
        convergence=convergence,
        patience=patience,
        batchsize=model_batchsize,
        max_iters=max_iters,
        transition_model=transition_model,
        weight_by_actions=weight_by_actions)
    return R_model, R_to_R_lookup(R_model, transition_model)


def _fit_transition_model(transition_model, dataset):
    """Train `transition_model` on `dataset` with SGD via update_transition_matrix; return it."""
    optimizer = torch.optim.SGD(transition_model.parameters(), lr=T_lr)
    # transition fitting is pinned to CPU regardless of `device`
    transition_model, _ = update_transition_matrix(
        dataset,
        transition_model,
        optimizer,
        torch.device('cpu'),
        convergence=convergence,
        patience=patience,
        batchsize=batchsize,
        weight_by_actions=weight_by_actions,
        weight_by_state_action=weight_by_state_action)
    return transition_model


print(terminal_state, terminal_reward, states, actions, deltas)
all_R_oracleT_lookups = {}
all_R_smartT_lookups = {}
all_R_dumbT_lookups = {}
all_empirical = {}
# Per-N lists (one per rep) of learned (A, S, S, num_deltas)-style transition tensors.
all_smart_Ts = {}
all_dumb_Ts = {}
for N in Ns:
    print('N: ', N, 'deltas: ', deltas)
    R_oracleT_lookups = []
    R_smartT_lookups = []
    R_dumbT_lookups = []
    R_empiric = []
    smart_Ts = []
    dumb_Ts = []
    for rep in range(reps):
        print(f'----- N: {N}, rep {rep} ({time.time() - start:.2f} sec) ------')
        dataset, empirical_rs = _sample_dataset(N)

        ## condensed reward + oracle transition
        R_oracleT, R_oracleT_lookup = _fit_reward_model(dataset, oracle_transition)
        R_oracleT_lookups.append(R_oracleT_lookup)

        ## condensed reward + smart (timing-aware) transition
        smart_transition = _fit_transition_model(
            TransitionModel(actions, states, delta_names, device, terminal_state), dataset)
        smart_Ts.append(get_expanded_T(smart_transition(), len(delta_names)).detach())
        R_smartT, R_smartT_lookup = _fit_reward_model(dataset, smart_transition)
        R_smartT_lookups.append(R_smartT_lookup)

        ## condensed reward + dumb (timing-naive) transition
        dumb_transition = _fit_transition_model(
            DumbTransitionModel(actions, states, delta_names, device, terminal_state), dataset)
        # presumably reorders to get_expanded_T's (A, S, S, num_deltas) layout — confirm
        # against DumbTransitionModel's output shape
        dumb_Ts.append(dumb_transition().permute(0, 2, 3, 1).detach())
        R_dumbT, R_dumbT_lookup = _fit_reward_model(dataset, dumb_transition)
        R_dumbT_lookups.append(R_dumbT_lookup)

        R_empiric.append(empirical_rs)

        print('R_oracleT estimated lookup: ', R_oracleT.reward_lookup)
        print('R_smartT estimated lookup: ', R_smartT.reward_lookup)
        print('R_dumbT estimated lookup: ', R_dumbT.reward_lookup)
        print('true lookup: ', oracle_R.reward_lookup)

    all_R_oracleT_lookups[N] = R_oracleT_lookups
    all_R_smartT_lookups[N] = R_smartT_lookups
    all_R_dumbT_lookups[N] = R_dumbT_lookups
    all_empirical[N] = R_empiric
    all_smart_Ts[N] = smart_Ts
    all_dumb_Ts[N] = dumb_Ts
    

In [None]:
# Single wide panel for the estimation-error-vs-n comparison.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))

def plot_comparison(ax, all_R_lookups, label='', oracle_lookup=None, z=1.69):
    """Plot the sup-norm error of estimated reward lookups against the oracle, versus N.

    For each N and each repetition the error is ``max_{(s,a,k)} |g - g*|`` with ``g*``
    taken from ``oracle_lookup``. The line shows the mean error over repetitions; the
    shaded band spans ``mean +/- z * SEM``.

    Args:
        ax: matplotlib Axes to draw on.
        all_R_lookups: dict N -> list (one entry per repetition) of dicts
            ``(s, a, k) -> estimate``.
        label: legend label for the line.
        oracle_lookup: dict ``(s, a, k) -> true value``; defaults to the notebook's
            ``oracle_R_lookup``.
        z: band half-width in standard errors (default 1.69, presumably an
            approximately 90% interval — confirm).
    """
    if oracle_lookup is None:
        oracle_lookup = oracle_R_lookup

    # float() accepts torch scalars, numpy scalars and plain Python floats alike
    # (`.item()` fails on the latter).
    N_to_dists = {
        N: [float(max(abs(g - oracle_lookup[key]) for key, g in R_lookup.items()))
            for R_lookup in R_lookups]
        for N, R_lookups in all_R_lookups.items()
    }

    # Only plot sample sizes that actually have results, so a partially finished run still plots.
    plotted_Ns = sorted(N for N, dists in N_to_dists.items() if dists)
    means = np.array([np.mean(N_to_dists[N]) for N in plotted_Ns])
    # SEM uses the number of repetitions actually recorded for each N.
    sems = np.array([np.std(N_to_dists[N]) / np.sqrt(len(N_to_dists[N])) for N in plotted_Ns])

    ax.plot(plotted_Ns, means, 'o-', label=label)
    ax.fill_between(plotted_Ns, means - z * sems, means + z * sems, alpha=0.1)
    
# One curve per estimator: (per-N lookups, legend label).
comparisons = [
    (all_R_oracleT_lookups, r'learning $\widehat{R}$, oracle $\widehat{P}$'),
    (all_R_smartT_lookups, r'learning $\widehat{R}$, timing-aware $\widehat{P}$'),
    (all_R_dumbT_lookups, r'learning $\widehat{R}$, timing-naive $\widehat{P}$'),
    (all_empirical, r'empirical avg'),
]
for lookups, curve_label in comparisons:
    plot_comparison(ax, lookups, label=curve_label)

ax.legend(fontsize=14)
ax.set_title(r'$||\mathcal{G}_{P, R} - \mathcal{G}_{\widehat{P}, \widehat{R}}||_{\infty}$ vs. n', fontsize=16)
ax.set_xlabel('n', fontsize=14)
ax.set_ylabel('Estimation error', fontsize=14)
ax.set_yscale('log')
fig.savefig('reward_estimation_error.pdf', dpi=300)