In [2]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from jax.nn import softmax
from jax.config import config
from pathlib import Path
from collections import namedtuple

# Select the CPU as the JAX platform for this notebook.
config.update('jax_platform_name', 'cpu')

# Print numpy arrays with 4 digits of precision.
np.set_printoptions(precision=4)

# Figure styling: white axes background and a larger base font, set in one call.
plt.rcParams.update({'axes.facecolor': 'white', 'font.size': 22})

from grl.mdp import MDP, AbstractMDP
from grl.environment import load_spec

from grl.utils import load_info
from grl.utils.mdp import get_perf
from definitions import ROOT_DIR

In [3]:
from jax import jit
from jax.tree_util import register_pytree_node


# class Test:
#     def __init__(self, a: jnp.ndarray, b: jnp.ndarray):
#         self.a = a
#         self.b = b
    
#     @property
#     def r_max(self):
#         return jnp.max(a)
    
#     def add(self):
#         return a + b
    
# register_pytree_node(
#     Test,
#     lambda x: ((x.a, x.b), None),
#     lambda _, xs: Test(*xs)
# )

# def hello(t: Test):
#     return t.a

# jitted_hello = jit(hello)
# a = jnp.zeros((3, 2))
# b = jnp.ones((3, 2))
# test = Test(a, b)
# jitted_hello(test)

# Load the 'tmaze_5_two_thirds_up' environment spec.
spec = load_spec('tmaze_5_two_thirds_up')


def test_amdp(amdp: AbstractMDP):
    """Return the ``T`` attribute of ``amdp``.

    This is the function handed to ``jit`` below, with an ``AbstractMDP``
    instance as its only argument.
    """
    transition = amdp.T
    return transition


jitted_test_amdp = jit(test_amdp)

# Build the MDP from the spec entries, then wrap it together with the spec's `phi`.
mdp = MDP(*(spec[name] for name in ('T', 'R', 'p0')), gamma=spec['gamma'])
amdp = AbstractMDP(mdp, spec['phi'])

# Last expression of the cell: its value is what the notebook displays.
jitted_test_amdp(amdp)

Array([[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0.,

In [None]:
# Scratch values pasted from earlier runs, kept for reference.
# The matrices were originally pasted as raw numpy printouts (not valid Python,
# so the cell could not execute); they are now proper arrays with the same values.
# NOTE(review): the meaning of the rows/columns is not established in this
# notebook -- confirm against the code that produced them.
set_count_1 = [0.0693, 0., 0.0693, 0., 0.2567, 0.254, 0.0336, 0.0401, 0.1386, 0.1386]
no_set_count = [0.05, 0., 0.05, 0., 0.1852, 0.1833, 0.0242, 0.0289, 0.3103, 0.168]

# Column sums of this matrix reproduce `no_set_count`.
no_set_weight = np.array([
    [0., 0., 0., 0., 0., 0., 0.0161, 0.0193, 0.3103, 0.168],
    [0., 0., 0., 0., 0., 0., 0.0081, 0.0096, 0., 0.],
    [0.05, 0., 0.05, 0., 0.1852, 0.1833, 0., 0., 0., 0.],
    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
])

set_term_1_weight = np.array([
    [0., 0., 0., 0., 0., 0., 0.0158, 0.0333, 0.1386, 0.1386],
    [0., 0., 0., 0., 0., 0., 0.0079, 0.0166, 0., 0.],
    [0.0693, 0., 0.0693, 0., 0.2133, 0.2974, 0., 0., 0., 0.],
    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
])

# Entries of this matrix sum to 1 (up to the 4-digit rounding of the printout).
uniform_weight = np.array([
    [0., 0., 0., 0., 0., 0., 0.0667, 0.0667, 0.1, 0.1],
    [0., 0., 0., 0., 0., 0., 0.0333, 0.0333, 0., 0.],
    [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0., 0., 0., 0.],
    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
])

In [None]:
# Directory holding the result files aggregated below.
# Alternative result sets that were used at other times (all under ROOT_DIR/results):
#   'pomdps_mi_pi', 'pomdps_mi_pi_q_abs', 'pomdps_mi_dm'
results_dir = Path(ROOT_DIR) / 'results' / 'all_pomdps_mi_pi'

vi_results_dir = Path(ROOT_DIR) / 'results' / 'pomdps_vi'
pomdp_files_dir = Path(ROOT_DIR) / 'grl' / 'environment' / 'pomdp_files'

# Result files are grouped by these entries of their `args`.
split_by = ['spec', 'algo', 'n_mem_states']
Args = namedtuple(typename='args', field_names=split_by)

# Which baseline to compare against: the optimal belief-state solution
# or the optimal state solution. ('belief' | 'state')
compare_to = 'belief'

# Order in which the specs appear in the plots.
# 'example_7' is currently left out. An earlier ordering was:
#   ['example_7', 'slippery_tmaze_5_two_thirds_up', 'tiger', 'paint.95',
#    'cheese.95', 'network', 'shuttle.95', '4x3.95']
spec_plot_order = [
    'tmaze_5_two_thirds_up',
    'tiger-alt-start',
    'paint.95',
    'cheese.95',
    'network',
    'shuttle.95',
    '4x3.95',
]

# Spec name -> prefix used in the '<prefix>-pomdp-solver-results' file names,
# for specs where the two differ.
spec_to_belief_state = {'tmaze_5_two_thirds_up': 'tmaze5'}

In [None]:
# Maps each hyperparameter tuple (an `Args`) to the results collected for it.
all_results = {}

for results_path in results_dir.iterdir():
    # Only regular `.npy` entries are treated as result files.
    is_result_file = (not results_path.is_dir()) and results_path.suffix == '.npy'
    if not is_result_file:
        continue

    info = load_info(results_path)
    args = info['args']

    # Result files without an explicit memory size are treated as n_mem_states = 2.
    if 'n_mem_states' not in args:
        args['n_mem_states'] = 2

    logs = info['logs']
    # (The final memory / policy of `info['agent']` used to be collected here too.)
    single_res = {
        'init_policy_perf': get_perf(logs['initial_policy_stats']),
        'init_improvement_perf': get_perf(logs['greedy_initial_improvement_stats']),
        'final_mem_perf': get_perf(logs['greedy_final_mem_stats']),
        'init_policy': logs['initial_policy'],
        'init_improvement_policy': logs['initial_improvement_policy'],
    }

    hparams = Args(*(args[field] for field in split_by))

    # One list per result key, with one entry appended per result file.
    grouped = all_results.setdefault(hparams, {})
    for k, v in single_res.items():
        grouped.setdefault(k, []).append(v)
    grouped['args'] = args

# Stack every per-file list into a single array (files along axis 0).
for res_dict in all_results.values():
    for k in res_dict:
        if k != 'args':
            res_dict[k] = np.stack(res_dict[k])

In [None]:
# Attach the baseline value ('compare_perf') to every entry of `all_results`.
if compare_to == 'belief':
    # Baseline comes from the '<name>-pomdp-solver-results' files.
    for fname in pomdp_files_dir.iterdir():
        if 'pomdp-solver-results' not in fname.stem:
            continue

        for hparams in all_results.keys():
            belief_name = spec_to_belief_state.get(hparams.spec, hparams.spec)
            if fname.stem != f"{belief_name}-pomdp-solver-results":
                continue

            belief_info = load_info(fname)
            coeffs = belief_info['coeffs']
            max_start_vals = coeffs[belief_info['max_start_idx']]
            # Dot product of the selected coefficients with the start distribution `p0`.
            start_value = np.dot(max_start_vals, belief_info['p0'])
            all_results[hparams]['compare_perf'] = np.array([start_value])

elif compare_to == 'state':
    # Baseline comes from the value-iteration result whose file name contains the spec name.
    for hparams in all_results:
        for vi_path in vi_results_dir.iterdir():
            if hparams.spec not in vi_path.name:
                continue

            vi_info = load_info(vi_path)
            start_value = (vi_info['optimal_vs'] * vi_info['p0']).sum()
            all_results[hparams]['compare_perf'] = np.array([start_value])
else:
    raise NotImplementedError

In [None]:
# Rescale every '*_perf' entry so that the initial policy maps to 0 and the
# baseline ('compare_perf') maps to 1.
for hparams, res in all_results.items():
    max_key = 'compare_perf'
    if max_key not in res:
        # Show which configuration has no baseline; the lookup below then raises KeyError.
        # (A previous fallback used 'final_mem_perf' as the maximum instead.)
        print(hparams)

    max_v, min_v = res[max_key], res['init_policy_perf']
    value_range = max_v - min_v

    # `max_v` / `min_v` keep pointing at the un-normalized arrays, because each
    # entry is replaced by a new array rather than modified in place.
    for k in [name for name in res if '_perf' in name]:
        res[k] = (res[k] - min_v) / value_range

In [None]:
all_table_results = {}
# 'x' / 'xlabels' hold one entry per plotted spec; every other key maps a series
# label to {'mean': [...], 'std_err': [...]} with one entry appended per spec.
all_plot_results = {'x': [], 'xlabels': []}


def _append_stats(plot_results, label, values):
    """Append the mean and standard error of `values` (over axis 0) under `label`.

    Creates the `{'mean': [], 'std_err': []}` entry for `label` on first use.
    """
    mean = values.mean(axis=0)
    std_err = values.std(axis=0) / np.sqrt(values.shape[0])
    entry = plot_results.setdefault(label, {'mean': [], 'std_err': []})
    entry['mean'].append(mean)
    entry['std_err'].append(std_err)


for i, spec in enumerate(spec_plot_order):
    hparams = sorted([k for k in all_results.keys() if k.spec == spec],
                     key=lambda hp: hp.n_mem_states)
    if not hparams:
        # Fail with a readable message instead of a bare IndexError on hparams[0].
        raise ValueError(f"no results loaded for spec {spec!r}")

    first_res = all_results[hparams[0]]
    all_plot_results['x'].append(i)
    all_plot_results['xlabels'].append(spec)

    # Initial-policy, first-improvement and baseline stats are taken from the
    # entry with the smallest n_mem_states.
    for k, v in first_res.items():
        if 'perf' in k and k != 'final_mem_perf':
            _append_stats(all_plot_results, k.replace('_perf', ''), v)

    # Final memory performance gets one series per n_mem_states.
    for hparam in hparams:
        res = all_results[hparam]
        if 'final_mem_perf' in res:
            _append_stats(all_plot_results, f"mem_{hparam.n_mem_states}",
                          res['final_mem_perf'])

# Series in plotting order: memoryless improvement first, then the memory
# series by increasing number of memory states. Sorting on the numeric suffix
# keeps e.g. 'mem_10' after 'mem_2', which a plain string sort would not.
# ('init_policy' and a state-optimal 'vi' series were plotted here at other times.)
ordered_plot = [('init_improvement', all_plot_results['init_improvement'])]
mem_labels = sorted((k for k in all_plot_results if k.startswith('mem_')),
                    key=lambda k: int(k.split('_')[-1]))
for k in mem_labels:
    ordered_plot.append((k, all_plot_results[k]))

In [None]:
# Display the aggregated plot data as this cell's output.
all_plot_results

In [None]:
# filtered plots, where we only show n_mem_state = 2
def maybe_spec_map(id: str):
    """Translate a spec identifier into its display name.

    Identifiers that have no entry in the table are returned unchanged.
    """
    display_names = {
        '4x3.95': '4x3',
        'cheese.95': 'Cheese\nMaze',
        'paint.95': 'Paint',
        'shuttle.95': 'Shuttle',
        'example_7': 'ex. 7',
        'network': 'Network',
        'tmaze_5_two_thirds_up': 'T-maze',
        'tiger-alt-start': 'Tiger',
        'tiger-alt': 'Tiger',
    }
    return display_names.get(id, id)

def label_and_color_map(label_str: str):
    """Map a series label to its legend text and bar color.

    Args:
        label_str: a series label such as 'init_improvement' or 'mem_<n>'.

    Returns:
        A `(legend_text, color)` tuple. `color` is a hex string for the labels
        listed in the table below and `None` otherwise; passing `color=None`
        to `ax.bar` makes matplotlib use its default color cycle.
        Labels that are neither a memory label nor 'init_improvement' are
        returned unchanged as the legend text, so callers can always unpack
        the result.

    Raises:
        ValueError: if a label containing 'mem' does not end in '_<integer>'.
    """
    color_mapping = {
        'init_improvement': '#DF5B5D', # red
        'mem_2': '#5C97DF', # blue
        'mem_4': '#B99CDB', # purple
        'mem_6': '#F8E49B' # yellow
    }
    # .get() instead of [] so memory sizes without a dedicated color
    # (e.g. 'mem_8') no longer raise KeyError.
    color = color_mapping.get(label_str)
    if 'mem' in label_str:
        n_mem_states = int(label_str.split('_')[-1])
        return f"{n_mem_states} Memory States", color
    if label_str == 'init_improvement':
        return "Memoryless", color
    # Previously this path fell through and returned None.
    return label_str, color

# Drop every memory series with more than 2 memory states; all other series stay.
filtered_ordered_plot = [
    (label, v) for (label, v) in ordered_plot
    if not ('mem' in label and int(label.split('_')[-1]) > 2)
]

group_width = 1
# Two extra bar widths are reserved per group, leaving a gap between groups.
bar_width = group_width / (len(filtered_ordered_plot) + 2)
fig, ax = plt.subplots(figsize=(12, 6))

x = np.array(all_plot_results['x'])
xlabels = list(map(maybe_spec_map, all_plot_results['xlabels']))

# The k-th series (counting from 1) is shifted k bar widths to the right of x.
for i, (label, plot_dict) in enumerate(filtered_ordered_plot, start=1):
    mapped_label, mapped_color = label_and_color_map(label)
    ax.bar(
        x + i * bar_width,
        plot_dict['mean'],
        bar_width,
        yerr=plot_dict['std_err'],
        label=mapped_label,
        color=mapped_color,
    )

ax.set_ylim(bottom=0, top=1)
# NOTE(review): the run count in this label is hard-coded, not derived from the data -- confirm.
# (An earlier label read: 'Normalized Performance\n (w.r.t. optimal {compare_to} & random initial policy)')
ax.set_ylabel('Normalized performance \n (30 runs)')

# Hand-tuned offset for placing the tick under each group of bars.
ax.set_xticks(x + group_width / 2.65)
ax.set_xticklabels(xlabels)
ax.legend(bbox_to_anchor=(0.605, 1.02), framealpha=0.95)
fig.tight_layout()

# (Title once used: "Performance of Memory Iteration in POMDPs")

# Save the figure as a PDF in the user's Downloads folder.
downloads = Path.home() / 'Downloads'
fig_path = downloads / f"{results_dir.stem}_2_mem.pdf"
fig.savefig(fig_path)

In [None]:
# Unfiltered plot: one bar per series in `ordered_plot`, i.e. every n_mem_states.
group_width = 1
n_series = len(ordered_plot)
# Two extra bar widths are reserved per group, leaving a gap between groups.
bar_width = group_width / (n_series + 2)
fig, ax = plt.subplots(figsize=(12, 6))

x = np.array(all_plot_results['x'])
xlabels = list(map(maybe_spec_map, all_plot_results['xlabels']))

# The k-th series (counting from 1) is shifted k bar widths to the right of x.
for i, (label, plot_dict) in enumerate(ordered_plot, start=1):
    mapped_label, mapped_color = label_and_color_map(label)
    ax.bar(
        x + i * bar_width,
        plot_dict['mean'],
        bar_width,
        yerr=plot_dict['std_err'],
        label=mapped_label,
        color=mapped_color,
    )

ax.set_ylim(bottom=0, top=1)
# NOTE(review): the run count in this label is hard-coded, not derived from the data -- confirm.
# (An earlier label read: 'Normalized Performance\n (w.r.t. optimal {compare_to} & random initial policy)')
ax.set_ylabel('Normalized performance \n (10 runs)')

# Hand-tuned offset for placing the tick under each group of bars.
ax.set_xticks(x + group_width / 2.4)
ax.set_xticklabels(xlabels)
ax.legend(bbox_to_anchor=(0.605, 1.02), framealpha=0.95)
fig.tight_layout()
# (Title once used: "Performance of Memory Iteration in POMDPs")

# Save the figure as a PDF in the user's Downloads folder.
downloads = Path.home() / 'Downloads'
fig_path = downloads / f"{results_dir.stem}_all.pdf"
fig.savefig(fig_path)