In [None]:
# Reload imported project modules (e.g. the agents/environments/plots/utils
# packages imported in the next cell) automatically before each cell runs,
# so edits to their .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import numpy as np
from matplotlib import pyplot as plt
from agents.pgp.pgp_softmax import SoftMaxPGP
from environments.mutantworlds.mutantworlds_foster import FosterFosterWorldMaze
from plots.gridworlds.gridworld_visualizer import GridWorldVisualizer
from plots.gridworlds.gridworld_animator import GridWorldAnimator
from utils.policy_tools import *
from utils.policy_functions import *

In [None]:
# Maze Parameters
# The environment class is the one explicitly imported from
# `environments.mutantworlds.mutantworlds_foster` in the imports cell, so this
# cell runs on a fresh kernel without relying on a star import to define it.
maze = FosterFosterWorldMaze(goal_reward=100, doors_cost=-100, pillars_as_cost=False, small_version=False)

# Agent Parameters
gamma = 0.99  # passed to SoftMaxPGP as `gamma` (presumably the discount factor)
p_situational = 0.995  # default `situational` value forwarded to p0_situational by the agent's p0_func

In [None]:
# --- Replay / training parameters ---
# Number of learning steps run for every maze mutation (gpp.learn n_steps).
replay_steps = 2000
# Forwarded to gpp.learn as `alpha_kwargs` (consumed by situational_alpha).
replay_alpha_kwargs = {
    'situational': 0.00,
    'alpha_norm': None,
    'alpha_mean': 0.02,
}
# Forwarded to gpp.learn as `gradient_kwargs`.
replay_grad_kwargs = {
    "λ": 0,
    "natural": False,
}
# Passed to rank_states as `value_min` when ranking states by policy_divergence.
replay_kld_threshold = 0.03

# --- Plot / animation parameters ---
n_rows, n_cols = 3, 3  # subplot grid; one panel per maze mutation
plot_n_histo_bins = 20  # rank_states `n_histo_bins`
animation_fps = 10
animation_nframes = 50  # target number of frames per training animation
animation_id_prefix = "__figures/hippocampal_replays/hippocampal_hierarchical_#"


In [None]:
%matplotlib qt

plt.close("all")
plt.pause(1)

f_policies = plt.figure(1)
f_divergences = plt.figure(2)

for i_m in range(1, maze.n_mutations):
    # change the configuration of the maze
    maze.mutate(i_m)

    # Reset the policy and the starting position
    gpp = SoftMaxPGP(maze, gamma=gamma, p0_func=(lambda agent, s0s=None, p0s=None, situational=p_situational: p0_situational(agent, s0s, p0s, situational)))
    viz = GridWorldVisualizer(maze, gpp)
    ani = GridWorldAnimator(viz)

    # Training
    print("Foster's Maze #{}".format(i_m+1))
    gpp.learn(n_steps=replay_steps, alpha_func=situational_alpha, gradient_kwargs=replay_grad_kwargs, alpha_kwargs=replay_alpha_kwargs)
    values, ranks, times, hist, mask, time_max = rank_states(gpp, f=policy_divergence, value_min=replay_kld_threshold, n_histo_bins=plot_n_histo_bins)

    # Plot

    plt.figure(f_policies.number)
    plt.subplot(n_rows, n_cols, i_m)

    viz.plot_trajectory_distribution(plot_maze=False, plot_axis=False, min_hue=0.1, max_hue=0.2)
    viz.plot_maze(plot_grid=False, plot_axis=False, neg_rew_cmap="Greys")
    viz.plot_policy(plot_maze=False, plot_grid=False, plot_axis=False)
    viz.plot_trajectory(plot_maze=False, plot_axis=False, greedy=True)
    plt.tight_layout()
    plt.show()
    plt.pause(1)

    plt.figure(f_divergences.number)
    plt.subplot(n_rows, n_cols, i_m)

    viz.plot_alpha_grid(hist, values, mask=mask, plot_axis=False, plot_grid=False, cmap="cool")
    viz.plot_maze(plot_grid=False, plot_axis=False, neg_rew_cmap="Greys")
    viz.plot_trajectory(greedy=True, plot_maze=False, plot_axis=False)
    plt.tight_layout()
    plt.show()
    plt.pause(1)

    animation = ani.animate_gradient_colored(ts_interval=round(replay_steps/animation_nframes), fps=animation_fps)
    animation_id = animation_id_prefix + str(i_m + 1) + ".mp4"
    animation.write_videofile(animation_id, fps=animation_fps)
    plt.show()
    plt.pause(1)

plt.figure(f_policies.number)
plt.savefig(animation_id_prefix + "policies" + ".png")
plt.figure(f_divergences.number)
plt.savefig(animation_id_prefix + "divergencies" + ".png")