In [244]:
import numpy as np
from jax.nn import softmax
from pathlib import Path
from collections import namedtuple
np.set_printoptions(precision=4)

from grl.utils import load_info
from definitions import ROOT_DIR

In [226]:
# Directory holding the saved `.npy` results files, one loaded per run below.
results_dir = Path(ROOT_DIR, 'results', 'runs_tmaze_dm')

# Hyperparameters that identify a group of runs; runs sharing all of these
# values are aggregated together in the loading cell.
split_by = ['spec', 'algo', 'mi_iterations', 'policy_optim_alg']
# The typename must match the bound name: pickle locates namedtuple classes by
# their typename in the defining module, so `namedtuple('args', ...)` assigned to
# `Args` could not be pickled (e.g. when saving `all_results`).
Args = namedtuple('Args', split_by)

In [239]:
# Load every saved run in `results_dir` and group runs by their `split_by`
# hyperparameters. Each group maps a metric name to an array stacked over runs
# (leading axis = run). Files are visited in sorted order because
# `Path.iterdir()` yields entries in arbitrary order, which would make run
# indices (e.g. `idx` below) and the leftover loop variables irreproducible.
all_results = {}

for results_path in sorted(results_dir.iterdir()):
    if results_path.suffix != '.npy':
        continue

    info = load_info(results_path)
    args = info['args']
    logs = info['logs']
    agent = info['agent']

    single_res = {
        'final_mem_discrep_v': logs['final_mem_discrep']['v'],
        'final_mem_discrep_q': logs['final_mem_discrep']['q'],
        'final_mem': agent.memory,
        'final_policy': agent.policy,
        'initial_improvement_policy': logs['initial_improvement_policy'],
        'initial_expanded_improvement_policy': logs['initial_expanded_improvement_policy'],
        'initial_improvement_discrep': logs['initial_improvement_discrep'],
        'initial_discrep': logs['initial_discrep'],
        'initial_policy': logs['initial_policy'],
    }

    hparams = Args._make(args[s] for s in split_by)
    group = all_results.setdefault(hparams, {})
    for k, v in single_res.items():
        group.setdefault(k, []).append(v)
    # Runs in a group share every `split_by` value; the args of the last-loaded
    # run are kept for reference (args outside `split_by` may differ per run).
    group['args'] = args

# Stack each per-run list into a single array. `args` is one mapping (presumably
# a dict), not a per-run list: `np.stack` would iterate its keys and replace it
# with an array of key strings, so it is left untouched.
for res_dict in all_results.values():
    for k, v in list(res_dict.items()):
        if k == 'args':
            continue
        res_dict[k] = np.stack(v)

In [240]:
# Shape of the stacked final memory for the first hyperparameter group
# (leading axis = run, as produced by `np.stack` in the loading cell).
first_group = next(iter(all_results.values()))
first_group['final_mem'].shape

(30, 4, 5, 2, 2)

In [241]:
# Show the final 'v'/'q' outputs of one explicitly chosen run (the last `.npy`
# file in sorted order). Reading `logs` left over from the loading loop would be
# hidden state whose underlying file depends on directory iteration order.
example_run_path = sorted(p for p in results_dir.iterdir() if p.suffix == '.npy')[-1]
load_info(example_run_path)['logs']['final_outputs']

{'v': DeviceArray([1.04, 1.04, 1.15, 1.95, 0.  ], dtype=float64),
 'q': DeviceArray([[0.93, 0.93, 1.04, 1.95, 0.  ],
              [0.93, 0.93, 1.04, 1.95, 0.  ],
              [1.04, 1.04, 1.15, 1.75, 0.  ],
              [0.93, 0.93, 1.01, 1.04, 0.  ]], dtype=float64)}

In [249]:
# Inspect a single run (`idx`) from the first hyperparameter group.
res = next(iter(all_results.values()))
idx = 1  # index along the stacked run axis
mem = res['final_mem'][idx]
final_policy = res['final_policy'][idx]
first_improvement_policy = res['initial_improvement_policy'][idx]

initial_policy = res['initial_policy'][idx]

initial_expanded_improvement_policy = res['initial_expanded_improvement_policy'][idx]
initial_improvement_discrep = res['initial_improvement_discrep'][idx]
initial_discrep = res['initial_discrep'][idx]

# Memory function for action index 2, labelled the RIGHT action in earlier
# debugging output. That output read right[0, 0] / right[1, 0] as the
# Goal(Up) / Goal(Down) start obs from memory 0 and right[2] as the corridor
# obs memory function. NOTE(review): confirm against the environment spec.
right = mem[2]


def _print_discreps(discrep) -> None:
    """Print the 'v' entry and then the 'q' entry of a lambda-discrepancy dict."""
    print(discrep['v'])
    print(discrep['q'])


print()
print(f"initial policy: \n {initial_policy}")
print("initial policy lambda discreps:")
_print_discreps(initial_discrep)
print()
print(f"policy after first improvement: \n{first_improvement_policy}")
print("lambda discreps:")
_print_discreps(initial_improvement_discrep)
print("MC Q-vals of policy after first improvement:")
print(initial_improvement_discrep['mc_vals_q'])
print("TD Q-vals of policy after first improvement:")
print(initial_improvement_discrep['td_vals_q'])


initial policy: 
 [[0.1802 0.2435 0.2942 0.282 ]
 [0.2775 0.2725 0.2003 0.2497]
 [0.218  0.2306 0.2531 0.2983]
 [0.308  0.2354 0.1596 0.297 ]
 [0.1675 0.23   0.2128 0.3897]]
initial policy lambda discreps:
[3.6342e-05 3.6342e-05 3.9138e-03 4.4655e-02 0.0000e+00]
[[2.9437e-05 2.9437e-05 3.1702e-03 1.9722e-31 0.0000e+00]
 [2.9437e-05 2.9437e-05 3.1702e-03 4.9304e-32 0.0000e+00]
 [5.5867e-05 7.1154e-05 1.5641e-02 3.6170e-02 0.0000e+00]
 [2.9437e-05 2.9437e-05 3.5920e-04 3.7123e-01 0.0000e+00]]

policy after first improvement: 
[[1.5399e-04 1.6015e-04 9.9952e-01 1.6269e-04]
 [1.9259e-04 1.9221e-04 9.9942e-01 1.9036e-04]
 [1.3115e-05 1.3172e-05 9.9997e-01 7.0919e-06]
 [9.9992e-01 3.5453e-06 4.7719e-05 3.1299e-05]
 [1.6747e-01 2.2996e-01 2.1283e-01 3.8974e-01]]
lambda discreps:
[1.1867e+00 1.1867e+00 6.4741e-02 2.8909e-10 0.0000e+00]
[[9.6122e-01 9.6122e-01 5.2441e-02 0.0000e+00 0.0000e+00]
 [9.6122e-01 9.6122e-01 5.2441e-02 0.0000e+00 0.0000e+00]
 [1.1868e+00 1.1868e+00 6.4742e-02 2.3417e-

In [106]:
# NOTE(review): `policy` is not defined in any visible cell, so this cell relies
# on hidden kernel state (see its earlier execution count). TODO: define it.
# Expanded policy rows appear to be laid out as obs * 2 + memory state (2 memory
# states); observation 3 is read here as the junction.
for mem_state in range(2):
    print(f"Junction policy for memory {mem_state}: {policy[3 * 2 + mem_state]}")

print(f"Policy after first improvement:\n {first_improvement_policy}")
print(f"Policy after initial expansion:\n {initial_expanded_improvement_policy}")

Junction policy for memory 0: [0.93 0.03 0.03 0.03]
Junction policy for memory 1: [0.03 0.93 0.03 0.03]
Policy after first improvement:
 [[2.76e-02 9.53e-01 3.34e-05 1.90e-02]
 [7.07e-02 3.05e-02 8.48e-01 5.06e-02]
 [6.41e-01 2.70e-01 1.87e-03 8.77e-02]
 [2.94e-05 2.95e-01 7.05e-01 2.98e-05]
 [4.25e-01 2.09e-01 3.98e-02 3.26e-01]]
Policy after initial expansion:
 [[2.76e-02 9.53e-01 3.34e-05 1.90e-02]
 [7.40e-01 1.96e-01 4.43e-02 2.01e-02]
 [7.07e-02 3.05e-02 8.48e-01 5.06e-02]
 [2.65e-03 5.93e-01 3.51e-01 5.43e-02]
 [6.41e-01 2.70e-01 1.87e-03 8.77e-02]
 [6.45e-01 1.58e-01 6.01e-02 1.37e-01]
 [2.94e-05 2.95e-01 7.05e-01 2.98e-05]
 [2.24e-02 2.72e-01 1.92e-01 5.14e-01]
 [4.25e-01 2.09e-01 3.98e-02 3.26e-01]
 [1.95e-01 1.05e-01 6.27e-01 7.21e-02]]


In [118]:
# Average initial policy over all runs stacked in `res` (axis 0 = run).
np.mean(res['initial_policy'], axis=0)

array([[0.28, 0.24, 0.15, 0.33],
       [0.26, 0.3 , 0.19, 0.25],
       [0.21, 0.21, 0.15, 0.42],
       [0.17, 0.34, 0.32, 0.16],
       [0.23, 0.3 , 0.27, 0.19]])

In [225]:
# Sample a random softmax policy with the same shape as one run's initial
# policy, to compare against the across-run mean above. The generator is seeded
# so that Restart & Run All reproduces the output.
# NOTE(review): presumably the logit scale mirrors the agent's policy init; confirm.
init_logit_scale = 0.2
policy_rng = np.random.default_rng(0)
softmax(policy_rng.normal(size=res['initial_policy'][0].shape) * init_logit_scale, axis=-1)

DeviceArray([[0.25, 0.18, 0.26, 0.31],
             [0.3 , 0.24, 0.26, 0.2 ],
             [0.3 , 0.26, 0.23, 0.21],
             [0.19, 0.22, 0.38, 0.21],
             [0.22, 0.24, 0.29, 0.25]], dtype=float64)