In [1]:
# Load IPython's autoreload extension and select mode 2, so that edits to the
# imported project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
from matplotlib import pyplot as plt

from agents.pgp.pgp_softmax import SoftMaxPGP
from environments.gridworlds.gridworlds_classic import GridWorld
from plots.gridworlds.gridworld_visualizer import GridWorldVisualizer
from utils.policy_tools import *
from utils.policy_functions import *

In [3]:
# ---------------------------------------------------------------------------
# Tunable constants
# ---------------------------------------------------------------------------
gamma = 0.95        # agent parameter, passed to SoftMaxPGP below
alpha = 0.3         # training parameter, forwarded to gpp.learn(alpha=...)
n_steps = 100       # training parameter, forwarded to gpp.learn(n_steps=...)
epsilon = 0.1       # policy softening
do_plots = False    # when True, the inference loop draws a diagnostic figure per state

# Grid size (n x n) and initial location(s) handed to GridWorld
n = 4
init_locs = np.array([[1, 1]])

# Human-readable symbol for each action index
action_labels = np.array([
    "↑",  # 0
    "→",  # 1
    "↓",  # 2
    "←",  # 3
    "X",  # 4
])

# Observed action sequence (indices into action_labels): → ↓ → ↓ ↑ ↓
aa = np.array([1, 2, 1, 2, 0, 2])

# ---------------------------------------------------------------------------
# Environment, agent and visualizer
# ---------------------------------------------------------------------------
world = GridWorld(n, n, init_locs)
world.A[:] = True   # set every entry of world.A to True

gpp = SoftMaxPGP(world, gamma)
viz = GridWorldVisualizer(world, gpp)

# ---------------------------------------------------------------------------
# Initial conditions read from the freshly built agent
# ---------------------------------------------------------------------------
s = gpp.state           # initial position
sr = gpp.SR             # agent's SR attribute, used later to build the prior
ss = np.array([s])      # visited states so far (starts with the initial one)

# p(s0): initial belief built by p0_onehot from the initial position
p_s = p0_onehot(gpp, s)

In [4]:
# Recursive belief update over the agent's state, given only the observed
# actions `aa`. For one step (s1 = state before the action, s2 = state after):
#
#   P(s2 | a) = Σ_s1 P(s2, s1 | a)
#             = Σ_s1 P(s1 | a) · P(s2 | a, s1)
#             = Σ_s1 P(s1 | a) · P(a | s2, s1) · P(s2 | s1) / P(a | s1)
#
# Mapping to the variables below (sP plays s1, sL plays s2):
#   p_s[sP]        -> P(s1 | a)       belief carried over from the previous step
#   prior          -> P(s2 | s1)      softmax over row sP of `sr`
#   likelihood[sL] -> P(a | s2, s1)   gpp.policy(sP)[a] after learning with a
#                                     reward of 100 registered for state sL
#   np.sum(joint)  -> P(a | s1)       normalising constant
#
# NOTE(review): `sr` was read from gpp.SR before this loop; whether gpp.reset()
# or gpp.learn() modify that array in place is not visible here -- confirm.

for i, a in enumerate(aa):

    # Advance the reference trajectory (only used for plotting below)
    s, _, _ = world.step(s, a)
    ss = np.append(ss, s)

    aa_sequence = "".join(action_labels[aa[:i+1]])
    a_current = action_labels[a]

    # Zero-initialised so that no uninitialised memory can leak into results
    posterior = np.zeros([world.n_state, world.n_state])
    p_s2 = np.zeros([world.n_state])

    for sP in range(world.n_state):

        sP_label = world.decode(sP)

        # Compute prior: softmax of row sP of sr. Subtracting the row maximum
        # leaves the result mathematically unchanged but prevents exp() overflow.
        logits = sr[sP, :]
        weights = np.exp(logits - np.max(logits))
        prior = weights / np.sum(weights)

        # Compute likelihood
        print("STEP #{}: action sequence = {}. Considering s{} = {}".format(i+1, aa_sequence, i, sP_label))
        print("")

        likelihood = np.zeros([world.n_state])

        for sL in range(world.n_state):

            # Fresh reward configuration with a single reward of 100 for sL
            world.reset_rewards()
            world.add_rewards( np.array([100]), np.array([sL]))
            world.A[:] = True

            # Re-train the agent from scratch, starting at sP
            gpp.reset()
            gpp.state = sP
            gpp.p0_func = (lambda agent, ss, ps, : prior/np.sum(prior))

            gpp.learn(n_steps=n_steps, alpha=alpha)
            likelihood[sL] = gpp.policy(sP)[a]

        # Compute posterior
        joint = prior * likelihood
        evidence = np.sum(joint)
        if not (np.isfinite(evidence) and evidence > 0):
            # Dividing anyway would silently fill the belief with NaN/inf and
            # corrupt every subsequent step.
            raise ValueError(
                "Cannot normalise posterior for state {} at step {}: "
                "sum(prior * likelihood) = {}".format(sP, i+1, evidence)
            )
        posterior[sP, :] = joint / evidence
        p_s2 = p_s2 + posterior[sP, :] * p_s[sP]

        # Plots
        traj = np.array([sP])

        if do_plots:

            # Panels that share the same layout: a grid plus the marker for sP
            panels = [
                (231, prior,
                 "p(s{} | s{} = {})".format(i+1, i, sP_label)),
                (232, likelihood,
                 "p(a = {} | s{}, s{} = {})".format(a_current, i+1, i, sP_label)),
                (233, posterior[sP, :],
                 "p(s{} | a = {}, s{} = {})".format(i+1, a_current, i, sP_label)),
                (234, p_s,
                 "p(s{} | a = {})".format(i, aa_sequence)),
            ]
            for position, grid, title in panels:
                plt.subplot(position)
                viz.plot_grid(grid, plot_axis=False)
                viz.plot_trajectory(ss=traj, plot_maze=False, plot_grid=False, plot_axis=False, jitter=0)
                plt.title(title)

            # Running sum over the sP values processed so far
            plt.subplot(235)
            viz.plot_grid(p_s2, plot_axis=False)
            plt.title("partial p(s{} | a={})".format(i+1, a_current))

            plt.tight_layout()
            plt.show()

    # The belief after this action becomes the starting belief of the next step
    p_s = p_s2

    # NOTE(review): p_s2 is conditioned on every action seen so far (through
    # p_s), but the title only shows the latest one -- confirm this is intended.
    viz.plot_grid(p_s2, plot_axis=False)
    viz.plot_trajectory(ss=ss, plot_maze=False, plot_grid=False, plot_axis=False)
    plt.title("p(s{} | a={})".format(i+1, a_current))

    plt.tight_layout()
    plt.show()
            


    





STEP #1: action sequence = →. Considering s0 = [0 0]



100%|██████████| 100/100 [00:00<00:00, 949.85it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:00<00:00, 1113.82it/s]
100%|██████████| 100/100 [00:00<00:00, 1120.23it/s]
100%|██████████| 100/100 [00:00<00:00, 1111.24it/s]
100%|██████████| 100/100 [00:00<00:00, 935.08it/s]
100%|██████████| 100/100 [00:00<00:00, 1112.87it/s]
100%|██████████| 100/100 [00:00<00:00, 1107.12it/s]
100%|██████████| 100/100 [00:00<00:00, 1108.26it/s]
100%|██████████| 100/100 [00:00<00:00, 945.30it/s]
100%|██████████| 100/100 [00:00<00:00, 1113.85it/s]
100%|██████████| 100/100 [00:00<00:00, 1112.42it/s]
100%|██████████| 100/100 [00:00<00:00, 1103.53it/s]
100%|██████████| 100/100 [00:00<00:00, 936.42it/s]
100%|██████████| 100/100 [00:00<00:00, 1104.55it/s]
100%|██████████| 100/100 [00:00<00:00, 1103.68it/s]
100%|██████████| 100/100 [00:00<00:00, 926.91it/s]


STEP #1: action sequence = →. Considering s0 = [0 1]



100%|██████████| 100/100 [00:00<00:00, 1110.98it/s]
100%|██████████| 100/100 [00:00<00:00, 1104.76it/s]
100%|██████████| 100/100 [00:00<00:00, 1095.37it/s]
100%|██████████| 100/100 [00:00<00:00, 934.00it/s]
100%|██████████| 100/100 [00:00<00:00, 1099.67it/s]
100%|██████████| 100/100 [00:00<00:00, 1096.64it/s]
100%|██████████| 100/100 [00:00<00:00, 1095.71it/s]
100%|██████████| 100/100 [00:00<00:00, 929.64it/s]
100%|██████████| 100/100 [00:00<00:00, 1104.62it/s]
100%|██████████| 100/100 [00:00<00:00, 1095.69it/s]
100%|██████████| 100/100 [00:00<00:00, 1102.12it/s]
100%|██████████| 100/100 [00:00<00:00, 934.80it/s]
100%|██████████| 100/100 [00:00<00:00, 1107.04it/s]
100%|██████████| 100/100 [00:00<00:00, 1094.13it/s]
100%|██████████| 100/100 [00:00<00:00, 1094.22it/s]
100%|██████████| 100/100 [00:00<00:00, 922.85it/s]


STEP #1: action sequence = →. Considering s0 = [0 2]



100%|██████████| 100/100 [00:00<00:00, 1105.20it/s]
100%|██████████| 100/100 [00:00<00:00, 1099.55it/s]
100%|██████████| 100/100 [00:00<00:00, 1094.69it/s]
100%|██████████| 100/100 [00:00<00:00, 935.10it/s]
100%|██████████| 100/100 [00:00<00:00, 1093.91it/s]
100%|██████████| 100/100 [00:00<00:00, 1097.09it/s]
100%|██████████| 100/100 [00:00<00:00, 1101.15it/s]
100%|██████████| 100/100 [00:00<00:00, 931.88it/s]
100%|██████████| 100/100 [00:00<00:00, 1105.11it/s]
100%|██████████| 100/100 [00:00<00:00, 1082.89it/s]
100%|██████████| 100/100 [00:00<00:00, 919.56it/s]
100%|██████████| 100/100 [00:00<00:00, 1090.31it/s]
100%|██████████| 100/100 [00:00<00:00, 1101.21it/s]
100%|██████████| 100/100 [00:00<00:00, 1094.98it/s]
100%|██████████| 100/100 [00:00<00:00, 935.01it/s]
100%|██████████| 100/100 [00:00<00:00, 1096.47it/s]


STEP #1: action sequence = →. Considering s0 = [0 3]



100%|██████████| 100/100 [00:00<00:00, 1100.71it/s]
100%|██████████| 100/100 [00:00<00:00, 1096.18it/s]
100%|██████████| 100/100 [00:00<00:00, 931.16it/s]
100%|██████████| 100/100 [00:00<00:00, 1088.82it/s]
100%|██████████| 100/100 [00:00<00:00, 1096.22it/s]
100%|██████████| 100/100 [00:00<00:00, 1084.31it/s]
100%|██████████| 100/100 [00:00<00:00, 929.76it/s]
100%|██████████| 100/100 [00:00<00:00, 1088.77it/s]
100%|██████████| 100/100 [00:00<00:00, 1092.83it/s]
100%|██████████| 100/100 [00:00<00:00, 1096.19it/s]
100%|██████████| 100/100 [00:00<00:00, 922.82it/s]
100%|██████████| 100/100 [00:00<00:00, 1089.57it/s]
100%|██████████| 100/100 [00:00<00:00, 1092.80it/s]
100%|██████████| 100/100 [00:00<00:00, 1093.65it/s]
100%|██████████| 100/100 [00:00<00:00, 934.96it/s]
100%|██████████| 100/100 [00:00<00:00, 1098.04it/s]


STEP #1: action sequence = →. Considering s0 = [1 0]



100%|██████████| 100/100 [00:00<00:00, 1091.12it/s]
100%|██████████| 100/100 [00:00<00:00, 1098.96it/s]
100%|██████████| 100/100 [00:00<00:00, 929.95it/s]
100%|██████████| 100/100 [00:00<00:00, 1092.47it/s]
100%|██████████| 100/100 [00:00<00:00, 1090.89it/s]
100%|██████████| 100/100 [00:00<00:00, 923.23it/s]
100%|██████████| 100/100 [00:00<00:00, 1098.34it/s]
100%|██████████| 100/100 [00:00<00:00, 1090.97it/s]
100%|██████████| 100/100 [00:00<00:00, 1098.43it/s]
100%|██████████| 100/100 [00:00<00:00, 941.69it/s]
100%|██████████| 100/100 [00:00<00:00, 1099.27it/s]
100%|██████████| 100/100 [00:00<00:00, 1101.51it/s]
100%|██████████| 100/100 [00:00<00:00, 1099.27it/s]
100%|██████████| 100/100 [00:00<00:00, 927.16it/s]
100%|██████████| 100/100 [00:00<00:00, 1101.33it/s]
100%|██████████| 100/100 [00:00<00:00, 1100.34it/s]


STEP #1: action sequence = →. Considering s0 = [1 1]



100%|██████████| 100/100 [00:00<00:00, 1094.99it/s]
100%|██████████| 100/100 [00:00<00:00, 927.49it/s]
100%|██████████| 100/100 [00:00<00:00, 1100.10it/s]
100%|██████████| 100/100 [00:00<00:00, 1099.51it/s]
100%|██████████| 100/100 [00:00<00:00, 1098.06it/s]
100%|██████████| 100/100 [00:00<00:00, 931.43it/s]
100%|██████████| 100/100 [00:00<00:00, 1102.59it/s]
100%|██████████| 100/100 [00:00<00:00, 1108.88it/s]
100%|██████████| 100/100 [00:00<00:00, 1101.04it/s]
100%|██████████| 100/100 [00:00<00:00, 879.34it/s]
100%|██████████| 100/100 [00:00<00:00, 1089.43it/s]
100%|██████████| 100/100 [00:00<00:00, 1092.29it/s]
100%|██████████| 100/100 [00:00<00:00, 927.81it/s]
100%|██████████| 100/100 [00:00<00:00, 1095.85it/s]
100%|██████████| 100/100 [00:00<00:00, 1091.84it/s]
100%|██████████| 100/100 [00:00<00:00, 1089.59it/s]


STEP #1: action sequence = →. Considering s0 = [1 2]



100%|██████████| 100/100 [00:00<00:00, 931.80it/s]
100%|██████████| 100/100 [00:00<00:00, 1094.74it/s]
100%|██████████| 100/100 [00:00<00:00, 1101.90it/s]
100%|██████████| 100/100 [00:00<00:00, 1097.64it/s]
100%|██████████| 100/100 [00:00<00:00, 935.05it/s]
100%|██████████| 100/100 [00:00<00:00, 1096.22it/s]
100%|██████████| 100/100 [00:00<00:00, 1099.55it/s]
100%|██████████| 100/100 [00:00<00:00, 1096.69it/s]
100%|██████████| 100/100 [00:00<00:00, 928.99it/s]
100%|██████████| 100/100 [00:00<00:00, 1103.34it/s]
100%|██████████| 100/100 [00:00<00:00, 1104.99it/s]
100%|██████████| 100/100 [00:00<00:00, 1093.46it/s]
100%|██████████| 100/100 [00:00<00:00, 933.15it/s]
100%|██████████| 100/100 [00:00<00:00, 1098.15it/s]
100%|██████████| 100/100 [00:00<00:00, 1102.64it/s]
100%|██████████| 100/100 [00:00<00:00, 1098.70it/s]


STEP #1: action sequence = →. Considering s0 = [1 3]



100%|██████████| 100/100 [00:00<00:00, 902.53it/s]
100%|██████████| 100/100 [00:00<00:00, 1091.45it/s]
100%|██████████| 100/100 [00:00<00:00, 1096.30it/s]
100%|██████████| 100/100 [00:00<00:00, 1099.42it/s]
100%|██████████| 100/100 [00:00<00:00, 925.52it/s]
100%|██████████| 100/100 [00:00<00:00, 1096.01it/s]
100%|██████████| 100/100 [00:00<00:00, 1102.34it/s]
 11%|█         | 11/100 [00:00<00:00, 1032.92it/s]


KeyboardInterrupt: 