In [1]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import pandas as pd
from tqdm.notebook import tqdm

from jax.nn import softmax
from jax.config import config
from pathlib import Path
from collections import namedtuple

config.update('jax_platform_name', 'cpu')
np.set_printoptions(precision=4)
plt.rcParams['axes.facecolor'] = 'white'
plt.rcParams.update({'font.size': 18})

from grl.utils import load_info
from grl.utils.data import uncompress_episode_rewards

from definitions import ROOT_DIR

In [2]:
results_dir = Path(ROOT_DIR, 'results', 'rnn_reruns_td')

In [3]:

offline_eval = []
# results_path = list(results_dir.iterdir())[10]

for results_path in tqdm(list(results_dir.iterdir())):
    if results_path.is_dir() or results_path.suffix != '.npy':
        continue
    
    info = load_info(results_path)
    args = info['args']

    offline_evals = info['episodes_info']['offline_eval']
    eval_freq = args['offline_eval_freq']
    total_steps = args['total_steps']

    d = {**args}
    
    all_t_undisc_returns = []
    for i, oe in enumerate(offline_evals):
        avg_undisc_returns = 0
        for ep in oe['episode_rewards']:
            episode_rewards = np.array(uncompress_episode_rewards(ep['episode_length'], ep['most_common_reward'], ep['compressed_rewards']))

            # TODO: calculate value estimation error through q_t - g_t
            undisc_returns = episode_rewards.sum()
            avg_undisc_returns += undisc_returns

        avg_undisc_returns /= len(oe['episode_rewards'])

        all_t_undisc_returns.append(undisc_returns)

    all_t_undisc_returns = np.array(all_t_undisc_returns)
    assert all_t_undisc_returns.shape[0] == total_steps // eval_freq
    
    d['undisc_returns'] = all_t_undisc_returns
    offline_eval.append(d)

offline_eval_df = pd.DataFrame(offline_eval)

  0%|          | 0/960 [00:00<?, ?it/s]

In [24]:
offline_lstm_eval_df = offline_eval_df[offline_eval_df['arch'] == 'lstm']
unique_seeds = offline_lstm_eval_df.seed.unique()

split_by_args = ['spec', 'multihead_action_mode', 'multihead_loss_mode', 'multihead_lambda_coeff']
lstm_grouped = offline_lstm_eval_df.groupby(split_by_args, as_index=False)

In [79]:
lstm_res = lstm_grouped.mean(numeric_only=True)
lstm_res['undisc_returns_means'] = lstm_grouped.undisc_returns.apply(lambda g: g.mean())
lstm_res['undisc_returns_std_err'] = lstm_grouped.undisc_returns.agg(np.stack)['undisc_returns'].apply(lambda g: g.std(axis=0) / np.sqrt(g.shape[0]))


Unnamed: 0,spec,multihead_action_mode,multihead_loss_mode,multihead_lambda_coeff,no_gamma_terminal,max_episode_steps,epsilon,lr,hidden_size,value_head_layers,...,batch_size,offline_eval_freq,offline_eval_episodes,offline_eval_epsilon,checkpoint_freq,save_all_checkpoints,total_steps,seed,undisc_returns_means,undisc_returns_std_err
0,4x3.95,td,both,-1.0,0.0,1000.0,0.1,0.001,12.0,0.0,...,1.0,1000.0,5.0,0.1,-1.0,0.0,150000.0,2024.5,"[-45.12800000000001, -44.848000000000006, -42....","[5.210235464928625, 5.284558032607837, 5.14065..."
1,4x3.95,td,both,0.0,0.0,1000.0,0.1,0.001,12.0,0.0,...,1.0,1000.0,5.0,0.1,-1.0,0.0,150000.0,2024.5,"[-45.12800000000001, -44.848000000000006, -42....","[5.210235464928625, 5.284558032607836, 5.14065..."
2,4x3.95,td,both,1.0,0.0,1000.0,0.1,0.001,12.0,0.0,...,1.0,1000.0,5.0,0.1,-1.0,0.0,150000.0,2024.5,"[-45.12799999999999, -44.84800000000001, -42.0...","[5.210235464928626, 5.284558032607836, 5.14065..."
3,4x3.95,td,td,-1.0,0.0,1000.0,0.1,0.001,12.0,0.0,...,1.0,1000.0,5.0,0.1,-1.0,0.0,150000.0,2024.5,"[-49.704, -39.53600000000001, -41.704000000000...","[3.1571744329384153, 5.941720154971959, 3.8979..."
4,4x3.95,td,td,0.0,0.0,1000.0,0.1,0.001,12.0,0.0,...,1.0,1000.0,5.0,0.1,-1.0,0.0,150000.0,2024.5,"[-49.70400000000001, -39.53600000000001, -41.7...","[3.1571744329384153, 5.941720154971959, 3.8979..."
5,4x3.95,td,td,1.0,0.0,1000.0,0.1,0.001,12.0,0.0,...,1.0,1000.0,5.0,0.1,-1.0,0.0,150000.0,2024.5,"[-49.704, -39.536, -41.704, -49.85600000000001...","[3.157174432938415, 5.941720154971959, 3.89799..."
6,cheese.95,td,both,-1.0,0.0,1000.0,0.1,0.001,12.0,0.0,...,1.0,1000.0,5.0,0.1,-1.0,0.0,150000.0,2024.5,"[0.6, 0.9, 1.5, 1.0, 0.4, 1.9, 3.4, 3.0, 3.3, ...","[0.25298221281347033, 0.6549809157525127, 0.49..."
7,cheese.95,td,both,0.0,0.0,1000.0,0.1,0.001,12.0,0.0,...,1.0,1000.0,5.0,0.1,-1.0,0.0,150000.0,2024.5,"[0.6, 0.9, 1.5, 1.0, 0.4, 1.9, 3.4, 3.0, 3.3, ...","[0.25298221281347033, 0.6549809157525125, 0.49..."
8,cheese.95,td,both,1.0,0.0,1000.0,0.1,0.001,12.0,0.0,...,1.0,1000.0,5.0,0.1,-1.0,0.0,150000.0,2024.5,"[0.6, 0.9, 1.5, 1.0, 0.4, 1.9, 3.4, 3.0, 3.3, ...","[0.25298221281347033, 0.6549809157525127, 0.49..."
9,cheese.95,td,td,-1.0,0.0,1000.0,0.1,0.001,12.0,0.0,...,1.0,1000.0,5.0,0.1,-1.0,0.0,150000.0,2024.5,"[2.1, 2.9, 16.3, 2.6, 4.6, 5.1, 4.0, 11.0, 3.7...","[1.1441153787970861, 2.036909423612155, 9.2661..."
