In [2]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from jax.nn import softmax
from jax.config import config
from pathlib import Path
from collections import namedtuple

config.update('jax_platform_name', 'cpu')
np.set_printoptions(precision=4)
plt.rcParams['axes.facecolor'] = 'white'
plt.rcParams.update({'font.size': 22})

from grl.utils import load_info
from grl.utils.mdp import get_perf
from definitions import ROOT_DIR
np.set_printoptions(precision=3)

In [3]:
results_dir = Path(ROOT_DIR, 'results', 'pomdps_mi_pi_q_abs')


In [4]:
all_results = {}

# for results_path in results_dir.iterdir():
#     if 'tmaze' not in results_path.stem:
#         continue
#     info = load_info(results_path)
results_paths = [res_path for res_path in results_dir.iterdir() if 'cheese' in res_path.stem]
results_path = results_paths[0]

In [5]:
agent_path = results_path.parent / 'agents' / f"{results_path.stem}.pkl.npy"

In [6]:
agent = load_info(agent_path)
mem = agent.memory
action_map = ['NORTH', 'SOUTH', 'EAST', 'WEST']
# info['logs']['initial_improvement_policy'], agent.memory, agent.policy
for a, m in enumerate(mem):
    print(action_map[a])
    print(mem)
    print()

NORTH
[[[[1.805e-01 8.195e-01]
   [7.719e-02 9.228e-01]]

  [[9.942e-01 5.792e-03]
   [2.235e-02 9.776e-01]]

  [[5.960e-01 4.040e-01]
   [4.301e-01 5.699e-01]]

  [[4.881e-01 5.119e-01]
   [5.022e-01 4.978e-01]]

  [[2.128e-01 7.872e-01]
   [1.622e-02 9.838e-01]]

  [[3.060e-05 1.000e+00]
   [2.209e-04 9.998e-01]]

  [[9.992e-01 7.584e-04]
   [9.995e-01 4.500e-04]]]


 [[[4.340e-03 9.957e-01]
   [4.864e-03 9.951e-01]]

  [[9.944e-01 5.590e-03]
   [2.326e-02 9.767e-01]]

  [[9.998e-01 2.343e-04]
   [9.997e-01 2.877e-04]]

  [[7.170e-04 9.993e-01]
   [8.460e-04 9.992e-01]]

  [[1.738e-01 8.262e-01]
   [9.508e-01 4.919e-02]]

  [[7.842e-01 2.158e-01]
   [3.330e-01 6.670e-01]]

  [[9.655e-01 3.446e-02]
   [9.816e-01 1.838e-02]]]


 [[[3.983e-03 9.960e-01]
   [4.005e-03 9.960e-01]]

  [[8.938e-01 1.062e-01]
   [6.726e-01 3.274e-01]]

  [[8.871e-01 1.129e-01]
   [9.684e-01 3.164e-02]]

  [[3.617e-01 6.383e-01]
   [5.143e-01 4.857e-01]]

  [[9.945e-01 5.548e-03]
   [9.590e-04 9.990e-01]]

  

In [7]:
tol = 0.1

SET = np.array([
    [0, 1], 
    [0, 1]
])

RESET = np.array([
    [1, 0], 
    [1, 0]
])

HOLD = np.array([
    [1, 0], 
    [0, 1]
])

FLIP = np.array([
    [0, 1],
    [1, 0]
])

for a, act_mem in enumerate(agent.memory):
    for obs_idx, obs_act_mem in enumerate(act_mem):
        if obs_idx == act_mem.shape[0] - 1:
            continue
            
        if np.allclose(obs_act_mem, SET, atol=tol):
            print(f"(obs: {obs_idx}, action: {action_map[a]}) -> SET")
        elif np.allclose(obs_act_mem, RESET, atol=tol):
            print(f"(obs: {obs_idx}, action: {action_map[a]}) -> RESET")
        elif np.allclose(obs_act_mem, HOLD, atol=tol):
            print(f"(obs: {obs_idx}, action: {action_map[a]}) -> HOLD")
        elif np.allclose(obs_act_mem, FLIP, atol=tol):
            print(f"(obs: {obs_idx}, action: {action_map[a]}) -> FLIP")
        else:
            continue


(obs: 1, action: NORTH) -> HOLD
(obs: 5, action: NORTH) -> SET
(obs: 0, action: SOUTH) -> SET
(obs: 1, action: SOUTH) -> HOLD
(obs: 2, action: SOUTH) -> RESET
(obs: 3, action: SOUTH) -> SET
(obs: 0, action: EAST) -> SET
(obs: 4, action: EAST) -> HOLD
(obs: 3, action: WEST) -> RESET
(obs: 4, action: WEST) -> HOLD


In [9]:
agent.memory.round(2)

DeviceArray([[[[0.18, 0.82],
               [0.08, 0.92]],

              [[0.99, 0.01],
               [0.02, 0.98]],

              [[0.6 , 0.4 ],
               [0.43, 0.57]],

              [[0.49, 0.51],
               [0.5 , 0.5 ]],

              [[0.21, 0.79],
               [0.02, 0.98]],

              [[0.  , 1.  ],
               [0.  , 1.  ]],

              [[1.  , 0.  ],
               [1.  , 0.  ]]],


             [[[0.  , 1.  ],
               [0.  , 1.  ]],

              [[0.99, 0.01],
               [0.02, 0.98]],

              [[1.  , 0.  ],
               [1.  , 0.  ]],

              [[0.  , 1.  ],
               [0.  , 1.  ]],

              [[0.17, 0.83],
               [0.95, 0.05]],

              [[0.78, 0.22],
               [0.33, 0.67]],

              [[0.97, 0.03],
               [0.98, 0.02]]],


             [[[0.  , 1.  ],
               [0.  , 1.  ]],

              [[0.89, 0.11],
               [0.67, 0.33]],

              [[0.89, 0.11],
       