In [1]:
import os
import sys
import logging
import pickle
import types

import numpy as np
import matplotlib.pyplot as plt
from joblib import hash, dump, load

from deer.default_parser import process_args
from deer.agent import NeuralAgent
from deer.learning_algos.CRAR_torch import CRAR
from figure8_env import MyEnv as figure8_env
from figure8_alt1 import MyEnv as figure8_alt1
import deer.experiment.base_controllers as bc

from deer.policies import EpsilonGreedyPolicy, FixedFigure8Policy

In [2]:
# --- Run configuration ------------------------------------------------------
# Run identifier: names the params/, scores/ and figs/ outputs of this run.
fname = 'figure8_alt1'
# Network architecture spec handed to CRAR.
nn_yaml = 'network_noconv.yaml'
# Dimension of CRAR's internal (encoded) state; also passed to the environment.
internal_dim = 10
# Forwarded to the environment constructor.
higher_dim_obs = False
# True -> epsilon-greedy train/test policies; False -> FixedFigure8Policy.
figure8_give_rewards = True
# Optional pretrained network to load: (directory, epoch, encoder_only).
# None skips loading.
set_network = None

In [3]:
class Defaults:
    """Hyperparameters for the figure-8 CRAR experiment.

    Every setting is a class attribute and there is no ``__init__``, so a
    ``Defaults()`` instance carries no per-instance state. Read values through
    attribute access (e.g. ``parameters.nstep``). The first group copies the
    module-level setup variables so they travel with the rest of the config.
    """

    # --- Setup (copied from the configuration cell for convenience) ---
    figure8_give_rewards = figure8_give_rewards
    nn_yaml = nn_yaml
    higher_dim_obs = higher_dim_obs
    internal_dim = internal_dim
    fname = fname

    # --- Experiment ---
    steps_per_epoch = 5_000
    epochs = 50
    steps_per_test = 1_000
    period_btw_summary_perfs = 1

    # --- Temporal processing ---
    nstep = 15
    nstep_decay = 0.8
    encoder_type = 'regular'

    # --- Environment ---
    frame_skip = 2
    show_rewards = False

    # --- DQN agent ---
    learning_rate = 1e-4
    learning_rate_decay = 1.0
    discount = 0.9
    # NOTE(review): the epsilon_* settings are not referenced by the visible
    # cells; the policies are built with a literal epsilon of 0.2.
    epsilon_start = 1.0
    epsilon_min = 1.0
    epsilon_decay = 1_000
    update_frequency = 1
    replay_memory_size = 100_000  # previously 50_000
    batch_size = 64
    freeze_interval = 1_000
    deterministic = False

    # --- Learning-algorithm loss weights ---
    # Order: T, entropy_neighbor, entropy_random, volume, gamma, R, Q, variational.
    # Only the Q term is active in this run. Alternatives tried earlier:
    #   [5E-3, 1E-3, 5E-3, 5E-3, 5E-3, 5E-3, 1.]
    #   [0., 0., 0., 0., 0., 0., 1., 2E-4]
    #   [5E-3, 5E-3, 5E-3, 0, 5E-3, 5E-3, 1., 0.]
    loss_weights = [0, 0, 0, 0, 0, 0, 1., 0.]

In [4]:
parameters = Defaults()

# Every setting lives on the Defaults class, so the instance's __dict__ is empty.
# Pickling the instance directly would store only a reference to
# `__main__.Defaults`, not the values used by this run. Persist an explicit
# snapshot instead. SimpleNamespace keeps attribute-style access (`p.nstep`) for
# whoever loads the file.
params_snapshot = types.SimpleNamespace(**{
    name: getattr(parameters, name)
    for name in dir(parameters)
    if not name.startswith('_')
})

# params/ may not exist yet on a fresh kernel/checkout; it used to be created
# only in a later cell.
os.makedirs('params', exist_ok=True)
with open(f'params/{fname}.p', 'wb') as f:
    pickle.dump(params_snapshot, f)

In [5]:
# Fixed seed when `deterministic` is set; otherwise an unseeded RandomState.
if parameters.deterministic:
    rng = np.random.RandomState(123456)
else:
    rng = np.random.RandomState()

# --- Instantiate environment ---
# All hyperparameters are read from `parameters` (not the module-level config
# variables), so the run always matches the snapshot saved under params/.
env = figure8_alt1(
    give_rewards=parameters.figure8_give_rewards,
    intern_dim=parameters.internal_dim,
    higher_dim_obs=parameters.higher_dim_obs,
    show_rewards=parameters.show_rewards,
    nstep=parameters.nstep, nstep_decay=parameters.nstep_decay
    )

# --- Instantiate learning_algo ---
learning_algo = CRAR(
    env,
    parameters.freeze_interval,
    parameters.batch_size,
    rng,
    high_int_dim=False,
    internal_dim=parameters.internal_dim, lr=parameters.learning_rate,
    nn_yaml=parameters.nn_yaml, double_Q=True,
    loss_weights=parameters.loss_weights,
    nstep=parameters.nstep, nstep_decay=parameters.nstep_decay,
    encoder_type=parameters.encoder_type
    )

# --- Policies ---
# With rewards: epsilon-greedy training (epsilon=0.2) and greedy testing (0.).
# Without rewards: FixedFigure8Policy for both (presumably a hand-coded figure-8
# route -- confirm in deer.policies).
if parameters.figure8_give_rewards:
    train_policy = EpsilonGreedyPolicy(
        learning_algo, env.nActions(), rng, 0.2,
        consider_valid_transitions=False
        )
    test_policy = EpsilonGreedyPolicy(
        learning_algo, env.nActions(), rng, 0.
        )
else:
    train_policy = FixedFigure8Policy.FixedFigure8Policy(
        learning_algo, env.nActions(), rng, epsilon=0.2,
        height=env.HEIGHT, width=env.WIDTH
        )
    test_policy = FixedFigure8Policy.FixedFigure8Policy(
        learning_algo, env.nActions(), rng,
        height=env.HEIGHT, width=env.WIDTH
        )

# --- Instantiate agent ---
agent = NeuralAgent(
    env, learning_algo,
    parameters.replay_memory_size,
    1, parameters.batch_size, rng,
    train_policy=train_policy, test_policy=test_policy)
if set_network is not None:
    # set_network = (directory, epoch, encoder_only).
    # NOTE(review): the path ends in the literal text "fname", not the value of
    # `fname`. It is kept as-is because every network-loading call uses the same
    # pattern; confirm it matches where FindBestController saves networks.
    agent.setNetwork(
        f'{set_network[0]}/fname', nEpoch=set_network[1],
        encoder_only=set_network[2]
        )

# Warm-up: 10 epochs of 500 steps before any controller is attached. No
# TrainerController is attached yet, so this presumably only collects
# transitions into the replay memory.
agent.run(10, 500)
print("end gathering data")

# --- Bind controllers to the agent ---
# Controllers are attached in the order below; keep that order.

# At every epoch (periodicity=1), print a summary of the agent's epsilon, discount
# and learning rate, together with the epoch number.
agent.attach(bc.VerboseController(
    evaluate_on='epoch', 
    periodicity=1))

# The learning rate follows an (optional) decay schedule.
agent.attach(bc.LearningRateController(
    initial_learning_rate=parameters.learning_rate, 
    learning_rate_decay=parameters.learning_rate_decay,
    periodicity=1))

# During training epochs, train the agent after every [parameters.update_frequency]
# actions it takes. After each training episode (not after every training step),
# also display the average Bellman residual and the average V value over that
# episode; hence the last two arguments.
agent.attach(bc.TrainerController(
    evaluate_on='action', 
    periodicity=parameters.update_frequency, 
    show_episode_avg_V_value=True, 
    show_avg_Bellman_residual=True))

# Find which version of the network (one per training epoch) has the highest
# validation score. FindBestController works together with an
# InterleavedTestEpochController. Its validationID must equal the id argument
# of that controller. It dumps the validation scores of every network plus the
# best network to disk; those dumps can be used to plot validation/test scores
# or to recover the trained network.
# NOTE(review): the env is figure8_alt1, but the mode id comes from figure8_env.
# This only works if both classes define the same VALIDATION_MODE; confirm.
agent.attach(bc.FindBestController(
    validationID=figure8_env.VALIDATION_MODE,
    testID=None,
    unique_fname=fname, savefrequency=5))

# Interleave a validation epoch between training epochs. For each one, display
# the sum of rewards (show_score=True). Summarize the environment's performance
# every [parameters.period_btw_summary_perfs] validation epochs.
agent.attach(bc.InterleavedTestEpochController(
    id=figure8_env.VALIDATION_MODE, 
    epoch_length=parameters.steps_per_test,
    periodicity=1,
    show_score=True,
    summarize_every=parameters.period_btw_summary_perfs,
    unique_fname=fname))

end gathering data


In [6]:
# Ensure the output directory exists. This is a no-op if it is already there.
# Unlike the previous bare `except Exception: pass`, genuine failures (e.g.
# permissions, or a file named "params") now surface instead of being hidden.
os.makedirs("params", exist_ok=True)

# `vars(parameters)` is `{}` because every setting is a class attribute of
# Defaults, so collect the public attributes explicitly before dumping.
params_dict = {
    name: getattr(parameters, name)
    for name in dir(parameters)
    if not name.startswith('_')
}
dump(params_dict, "params/" + fname + ".jldump")

# Reload the pretrained network (if any) right before training.
# set_network = (directory, epoch, encoder_only); see the note in the setup cell
# about the literal "fname" path component.
if set_network is not None:
    agent.setNetwork(
        f'{set_network[0]}/fname', nEpoch=set_network[1],
        encoder_only=set_network[2]
        )
agent.run(parameters.epochs, parameters.steps_per_epoch)

# --- Show results ---
# Validation scores dumped during training (FindBestController, unique_fname=fname).
basename = "scores/" + fname
scores = load(basename + "_scores.jldump")
print(scores)

Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
2 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([ 0.2286, -0.2181,  0.1445, -0.3463, -0.1055, -0.1022, -0.1445, -0.3482,
         0.4319,  0.1541]) tensor([ 0.2875, -0.3783,  0.3928, -0.2561, -0.0773, -0.4406, -0.4212, -0.1797,
         0.7151,  0.0861]) tensor([ 0.2343, -0.2141,  0.1566, -0.3309, -0.0969, -0.1152, -0.1554, -0.3498,
         0.4264,  0.1532])
R[0]
tensor([-0.0353], grad_fn=<SelectBackward0>)
LOSSES
T = 0.0390671131759882; R = 0.0009367903624661267;                 Gamma = 0.559142674446106; Q = 0.0010826034946221626;
Entropy Neighbor = 0.9167151342630386;                 Entropy Random = 0.7394493681788444;                 Volume = 0.0; VAE = 0.0
Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
2 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([ 0.0322, -0.4239, -0.3579, -0.2188, -0.2336, -0.0344, -0.1115, -0.1618,
         0.4818,  0.1785]) tensor



Testing score per episode (id: 0) is 0.0 (average over 1 episode(s))
== Mean score per episode is 0.0 over 1 episodes ==


  plt.show()
  dist_matrix = dist_matrix/np.nanpercentile(dist_matrix.flatten(), 99)
  ylim_max = np.nanmax(self._separability_tracking)*1.1


Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
2 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([ 0.1535, -0.3471, -0.0997, -0.2279, -0.1796,  0.0454, -0.0832, -0.0746,
         0.2726,  0.1162]) tensor([ 0.2114, -0.5072,  0.1524, -0.1348, -0.1445, -0.2949, -0.3677,  0.0974,
         0.5512,  0.0510]) tensor([ 0.1532, -0.3428, -0.0918, -0.2028, -0.1670,  0.0300, -0.1126, -0.0482,
         0.2704,  0.1012])
R[0]
tensor([-0.0305], grad_fn=<SelectBackward0>)
LOSSES
T = 0.04058194959908724; R = 0.0008959239681717009;                 Gamma = 0.6048417918682099; Q = 1.1376452060574138e-06;
Entropy Neighbor = 0.8878252104520797;                 Entropy Random = 0.4723686958551407;                 Volume = 0.0; VAE = 0.0
Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
2 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([ 0.1475, -0.3545, -0.0894, -0.1148, -0.0895, -0.0310, -0.1999,  0.0401,
         0.2449,  0.0621]) ten

  abs_states[i:i+1], torch.as_tensor([action_encoding])
  plt.show()
  ylim_max = np.nanmax(self._separability_tracking)*1.1


Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
1 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([ 0.1534, -0.3422, -0.0838, -0.1187, -0.1084, -0.0379, -0.1999,  0.0425,
         0.2576,  0.0580]) tensor([ 0.2267, -0.4960,  0.1631, -0.0213, -0.0790, -0.3631, -0.4983,  0.2225,
         0.5271, -0.0169]) tensor([ 0.1529, -0.3427, -0.0844, -0.1175, -0.1066, -0.0388, -0.2008,  0.0437,
         0.2570,  0.0575])
R[0]
tensor([-0.0266], grad_fn=<SelectBackward0>)
LOSSES
T = 0.040675786569714545; R = 0.0007149546800646931;                 Gamma = 0.6068195993900299; Q = 6.237174925871613e-07;
Entropy Neighbor = 0.910564215540886;                 Entropy Random = 0.5748949845433236;                 Volume = 0.0; VAE = 0.0
Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
2 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([ 0.1563, -0.3378, -0.0722, -0.1262, -0.1224, -0.0412, -0.1967,  0.0364,
         0.2551,  0.0546]) tens

  plt.show()
  ylim_max = np.nanmax(self._separability_tracking)*1.1


Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
0 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([-0.0708, -0.3815, -0.3132, -0.1291, -0.3121, -0.0398, -0.0164, -0.1977,
         0.1963,  0.0141]) tensor([-0.0153, -0.5527, -0.0735, -0.0342, -0.2555, -0.3829, -0.3171, -0.0453,
         0.4587, -0.0332]) tensor([-0.0708, -0.3815, -0.3132, -0.1291, -0.3121, -0.0398, -0.0164, -0.1977,
         0.1963,  0.0141])
R[0]
tensor([0.0070], grad_fn=<SelectBackward0>)
LOSSES
T = 0.04141333675384522; R = 0.0019666570110130122;                 Gamma = 0.6144813226461411; Q = 0.0013595634132980193;
Entropy Neighbor = 0.8476336225271225;                 Entropy Random = 0.46406168496608735;                 Volume = 0.0; VAE = 0.0
Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
0 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([-0.0811, -0.3803, -0.3156, -0.1182, -0.3031, -0.0439, -0.0300, -0.1853,
         0.1916,  0.0082]) tens

  ax.scatter(
  plt.show()
  ylim_max = np.nanmax(self._separability_tracking)*1.1


Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
3 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([-0.0322, -0.4622, -0.3447, -0.1156, -0.0932, -0.0767, -0.0692, -0.1767,
         0.2042,  0.2460]) tensor([ 0.0242, -0.6411, -0.0807,  0.0050, -0.0529, -0.3990, -0.3892,  0.0124,
         0.4421,  0.1671]) tensor([-0.1061, -0.4231, -0.3569, -0.1215, -0.1358, -0.0774, -0.0933, -0.2161,
         0.2496,  0.1729])
R[0]
tensor([0.0032], grad_fn=<SelectBackward0>)
LOSSES
T = 0.04271558713912964; R = 0.0031103744656720664;                 Gamma = 0.6220539962053299; Q = 0.002431794282318151;
Entropy Neighbor = 0.7833728437423706;                 Entropy Random = 0.3441389217078686;                 Volume = 0.0; VAE = 0.0
Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
2 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([-0.0736, -0.2615, -0.2630, -0.0826, -0.3922, -0.0163, -0.1237, -0.0934,
         0.2522, -0.0206]) tensor

  ax.scatter(
  plt.show()
  ylim_max = np.nanmax(self._separability_tracking)*1.1


Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
1 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([-0.1644, -0.2581, -0.3036, -0.0774, -0.4119, -0.0207, -0.1265, -0.1308,
         0.2672, -0.0706]) tensor([-0.0849, -0.4072, -0.0581,  0.0202, -0.3869, -0.3443, -0.4291,  0.0499,
         0.5382, -0.1423]) tensor([-0.1644, -0.2581, -0.3036, -0.0774, -0.4119, -0.0207, -0.1265, -0.1308,
         0.2672, -0.0706])
R[0]
tensor([-0.0138], grad_fn=<SelectBackward0>)
LOSSES
T = 0.04270715135335922; R = 0.0034717901558033192;                 Gamma = 0.6247943366765976; Q = 0.0027675981817737922;
Entropy Neighbor = 0.7696245709657669;                 Entropy Random = 0.34583619925379755;                 Volume = 0.0; VAE = 0.0
Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
0 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([-0.1587, -0.2338, -0.2887, -0.0882, -0.4321, -0.0209, -0.1310, -0.1302,
         0.2836, -0.0723]) ten

  ax.scatter(
  plt.show()
  ylim_max = np.nanmax(self._separability_tracking)*1.1


Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
3 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([-0.2399, -0.2849, -0.3953, -0.2971, -0.2975,  0.0149,  0.0244, -0.2391,
         0.3615,  0.0197]) tensor([-0.1872, -0.4659, -0.1265, -0.1762, -0.2628, -0.3083, -0.2978, -0.0498,
         0.5996, -0.0515]) tensor([-0.2332, -0.2850, -0.3947, -0.3079, -0.2835,  0.0204,  0.0224, -0.2314,
         0.3689,  0.0125])
R[0]
tensor([0.0153], grad_fn=<SelectBackward0>)
LOSSES
T = 0.04288621859997511; R = 0.005040755304333288;                 Gamma = 0.6351458985805511; Q = 0.004220619219006039;
Entropy Neighbor = 0.7369756959676742;                 Entropy Random = 0.257257558748126;                 Volume = 0.0; VAE = 0.0
Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
2 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([-0.0420, -0.3704, -0.2083, -0.2554, -0.3066, -0.0628,  0.0248, -0.2725,
         0.1903,  0.0798]) tensor([

  ax.scatter(
  plt.show()
  ylim_max = np.nanmax(self._separability_tracking)*1.1


Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
3 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([ 0.1131, -0.4682, -0.2996, -0.3168, -0.3408, -0.0966,  0.2799, -0.4762,
         0.2940,  0.4647]) tensor([ 0.1606, -0.6496, -0.0309, -0.1990, -0.3047, -0.4243, -0.0508, -0.2873,
         0.5256,  0.3970]) tensor([ 0.1407, -0.4914, -0.3057, -0.3576, -0.3264, -0.0756,  0.3418, -0.5162,
         0.2913,  0.5049])
R[0]
tensor([0.0053], grad_fn=<SelectBackward0>)
LOSSES
T = 0.04318473277240992; R = 0.006972488264465938;                 Gamma = 0.6390772536993027; Q = 0.005653495355858467;
Entropy Neighbor = 0.7192538998126984;                 Entropy Random = 0.24633111676573755;                 Volume = 0.0; VAE = 0.0
Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
0 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([ 0.0746, -0.4287, -0.2723, -0.2203, -0.3663, -0.0593,  0.1917, -0.3597,
         0.1950,  0.4005]) tensor

  ax.scatter(
  plt.show()
  ylim_max = np.nanmax(self._separability_tracking)*1.1


Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
2 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([-0.1095, -0.2807, -0.3081, -0.1091, -0.4811,  0.0648,  0.0232, -0.2456,
         0.2209,  0.0011]) tensor([-0.0442, -0.4350, -0.0560, -0.0154, -0.4489, -0.2745, -0.2693, -0.0706,
         0.4999, -0.0606]) tensor([-0.1095, -0.2807, -0.3081, -0.1091, -0.4811,  0.0648,  0.0232, -0.2456,
         0.2209,  0.0011])
R[0]
tensor([-0.0170], grad_fn=<SelectBackward0>)
LOSSES
T = 0.0447517713829875; R = 0.011125648298708256;                 Gamma = 0.6347776871919631; Q = 0.008461621734779328;
Entropy Neighbor = 0.6666777740716934;                 Entropy Random = 0.18714740991592407;                 Volume = 0.0; VAE = 0.0
Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
3 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([-0.1198, -0.4233, -0.3736, -0.0899, -0.2125, -0.0529, -0.1037, -0.1496,
         0.2103,  0.1090]) tensor

  ax.scatter(
  plt.show()
  ylim_max = np.nanmax(self._separability_tracking)*1.1


Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
2 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([-0.1325, -0.2739, -0.3136, -0.2175, -0.1781,  0.0057, -0.1095, -0.1646,
         0.3489,  0.0490]) tensor([-0.0668, -0.4298, -0.0695, -0.1254, -0.1451, -0.3305, -0.3863,  0.0021,
         0.6327, -0.0183]) tensor([-0.1477, -0.3040, -0.3236, -0.1705, -0.1425, -0.0409, -0.1509, -0.1586,
         0.3329,  0.0709])
R[0]
tensor([-0.0244], grad_fn=<SelectBackward0>)
LOSSES
T = 0.045297746263444426; R = 0.01169893319538096;                 Gamma = 0.6318807172775268; Q = 0.009433930134633556;
Entropy Neighbor = 0.6373021242618561;                 Entropy Random = 0.1662583377957344;                 Volume = 0.0; VAE = 0.0
Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
2 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([-0.0892, -0.2717, -0.2768, -0.1218, -0.4500,  0.0590, -0.0159, -0.1777,
         0.1966, -0.0067]) tensor

  ax.scatter(
  plt.show()
  ylim_max = np.nanmax(self._separability_tracking)*1.1


Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
2 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([ 0.1181, -0.3761, -0.1856, -0.0931, -0.0191, -0.0682, -0.2359, -0.1067,
         0.2745,  0.2470]) tensor([ 0.1848, -0.5311,  0.0561, -0.0017,  0.0173, -0.4049, -0.5110,  0.0601,
         0.5577,  0.1748]) tensor([ 0.1154, -0.4213, -0.2006, -0.0326,  0.1066, -0.0968, -0.3602,  0.0506,
         0.2277,  0.1987])
R[0]
tensor([-0.0369], grad_fn=<SelectBackward0>)
LOSSES
T = 0.045643136650323865; R = 0.014920905178587417;                 Gamma = 0.6302621161937714; Q = 0.011004713690374047;
Entropy Neighbor = 0.6091972035169602;                 Entropy Random = 0.16026164031028747;                 Volume = 0.0; VAE = 0.0
Printing a few elements useful for debugging:
actions_val[0], rewards_val[0], terminals_val[0]
1 0.0 0.0
Es[0], TEs[0], Esp_[0]
tensor([ 0.0029, -0.2363, -0.1013, -0.4533, -0.1663,  0.1144, -0.1421,  0.0844,
         0.2041, -0.1038]) tens


KeyboardInterrupt



## Visualize performance

In [None]:
# Load the network saved at epoch 15 for visualization.
# NOTE(review): the path ends in the literal "fname", the same convention as the training cells.
checkpoint_path = f'{fname}/fname'
agent.setNetwork(checkpoint_path, nEpoch=15)

In [None]:
# Manually open a test episode so the cells below can step it one action at a
# time. This appears to mirror how NeuralAgent starts an episode internally --
# confirm against deer.agent.
agent._in_episode = True
agent._mode = 0  # Testing mode with plan_depth=0
initState = env.reset(agent._mode)
inputDims = env.inputDimensions()

# For inputs that keep a history (first dimension > 1), copy everything but the
# first slot of the reset observation into the agent's state buffer.
for i, dims in enumerate(inputDims):
    if dims[0] > 1:
        agent._state[i][1:] = initState[i][1:]

agent._Vs_on_last_episode = []
is_terminal, reward = False, 0

### Frame by frame

In [None]:
%matplotlib inline

for i in range(100):
    obs = env.observe()
    _obs = obs[0].reshape((env.WIDTH, env.HEIGHT))
    plt.figure()
    plt.imshow(np.flip(_obs.squeeze()))
    plt.show()
    for i in range(len(obs)):
        agent._state[i][0:-1] = agent._state[i][1:]
        agent._state[i][-1] = obs[i]
    V, action, reward, _ = agent._step()
    print(action)
    agent._Vs_on_last_episode.append(V)
    is_terminal = env.inTerminalState()
    if is_terminal: break

### As animation

In [None]:
%matplotlib notebook

import numpy as np 
import matplotlib.pyplot as plt
from IPython.display import HTML
import matplotlib.animation as animation

fig, ax = plt.subplots(1,1)
obs = env.observe()
_obs = obs[0].reshape((env.WIDTH, env.HEIGHT))
_obs = np.flip(_obs.squeeze())
ax.set_xticks([])
ax.set_yticks([])
im = ax.imshow(np.zeros(_obs.shape))

def init():
    plt.cla()
    im = ax.imshow(_obs)
    return [im]

def animate(i, *args, **kwargs):
    plt.cla()
    obs = env.observe()
    _obs = obs[0].reshape((env.WIDTH, env.HEIGHT))
    _obs = np.flip(_obs.squeeze())
    im = ax.imshow(_obs)
    for i in range(len(obs)):
        agent._state[i][0:-1] = agent._state[i][1:]
        agent._state[i][-1] = obs[i]
        V, action, reward, _ = agent._step()
        agent._Vs_on_last_episode.append(V)
    return [im]

ani = animation.FuncAnimation(fig, animate, init_func=init, 
     frames=100, blit=False, repeat=True)
ani.save(f'figs/{fname}/behavior.gif', writer="ffmpeg", fps = 15)
fig.show()