In [None]:
import sys
# Add the parent directory (relative to the kernel's working directory) to the
# import path so that the `src` package imported further down can be resolved.
sys.path.append("..")

%load_ext autoreload
%autoreload 2
import numpy as np
from gym import wrappers
from torch import nn

from matplotlib import pyplot as plt
from src import utils as utils
from src.model import SimpleRecurrent

In [None]:
# Environment selection: exactly one `env_name` assignment is active; the
# alternatives are kept commented out for quick switching.
# env_name = "MiniGrid-FourRooms-v0"
# env_name = "MiniGrid-Empty-Random-5x5-v0"
env_name = "MiniGrid-DoorKey-5x5-v0"
env = utils.make_env(env_name)

# Sizes the agent is built from: observation shape and number of discrete actions.
obs_space_shape, n_actions = env.observation_space.shape, env.action_space.n

# Render one frame as a sanity check that the environment was constructed.
frame = env.render('rgb_array')
plt.imshow(frame)
plt.title('Game image')
plt.show()

In [None]:
import os
import wandb

from tqdm import tqdm
from src.a2c import A2CAlgo

# os.environ['WANDB_MODE'] = 'dryrun'

# Iteration cadences used by the training loop (evaluation / commit frequency).
LOG_EACH = 30
VIDEO_EACH = 400

# Single source of hyperparameters: handed to the model, the env pool, the
# A2C trainer and recorded with the wandb run.
config = dict(
    time=8,             # rollout length passed to pool.interact in the training loop
    n_games_mean=1,
    max_reward=2,       # evaluation reward at which the training loop stops early
    device="cpu",
    env=env_name,
    hidden_dim=128,
    emb_dim=128,
    n_env=128,          # number of environments handed to EnvPool
    gamma=0.9,

    max_grad_norm=0.5,
    lr=0.001,
    value_loss_coef=1,
    entropy_coef=0.01,
)

obs = env.reset()
agent = SimpleRecurrent(obs_space_shape, n_actions, config)

# Overwrite every parameter tensor with samples drawn uniformly from [-0.1, 0.1].
for param in agent.parameters():
    nn.init.uniform_(param, -0.1, 0.1)

# Name the run after the environment with its "MiniGrid-" prefix stripped.
run_name = env_name[len("MiniGrid-"):]
wandb.init(
    project="mlsh",
    monitor_gym=True,
    name=run_name,
    config=config,
    dir="..",
    magic=True,
)
wandb.watch(agent)

In [None]:
from src.env_pool import EnvPool
def make_pool_env():
    """Build one fresh environment for the pool, using the currently selected `env_name`."""
    return utils.make_env(env_name)


# The pool receives the agent, an environment factory and the number of
# environments to create.
pool = EnvPool(agent, make_pool_env, config["n_env"])

# Short 10-step interaction; the returned arrays are inspected in the next cell.
warmup_rollout = pool.interact(10)
rollout_obs, rollout_actions, rollout_rewards, rollout_mask = warmup_rollout

In [None]:
# Report the shape of each array returned by the warm-up rollout.
for label, rollout_part in (
    ("Actions shape:", rollout_actions),
    ("Rewards shape:", rollout_rewards),
    ("Mask shape:", rollout_mask),
    ("Observations shape: ", rollout_obs),
):
    print(label, rollout_part.shape)

In [None]:
from time import time

# Build the A2C trainer from the shared hyperparameter dict.
alg = A2CAlgo(agent, config["device"], n_actions,
              config["gamma"],
              config["max_grad_norm"],
              config["entropy_coef"],
              config["lr"],
              config["value_loss_coef"])

# Cumulative wall-clock seconds spent in each phase of the loop.
int_time = 0     # pool.interact (environment interaction)
step_time = 0    # alg.step (optimisation step)
wandb_time = 0   # per-iteration wandb.log call

for i in tqdm(range(4000)):
    # Snapshot the pool's memory states *before* interacting; this copy is what
    # alg.step receives. Presumably these are the recurrent states at the start
    # of the rollout -- TODO confirm against EnvPool.
    memory = list(pool.prev_memory_states)

    t = time()
    rollout_obs, rollout_actions, rollout_rewards, rollout_mask = pool.interact(config["time"])
    int_time += time() - t

    t = time()
    loss, grad_norm, entropy, values, al, cl = alg.step(
        rollout_obs, rollout_actions, rollout_rewards, rollout_mask, memory, config["gamma"])
    step_time += time() - t

    t = time()
    # commit=False leaves step `i` open so that the evaluation metrics logged
    # below (on LOG_EACH iterations) are attached to the same step.
    wandb.log({
            "rew": np.mean(rollout_rewards),
            "values": np.mean(values),
            "policy_loss": al,
            "value_loss": cl
        }, commit=False, step=i)
    wandb_time += time() - t

    if i % LOG_EACH == 0:
        # Periodic evaluation on the single (non-pooled) environment.
        reward = np.mean(utils.evaluate(agent, env, n_games=10))
        log = {
            "rewards": reward,
            "grad_norm": grad_norm,
            "entropy": entropy,
            "loss": loss
        }
        # The step is explicitly committed only every VIDEO_EACH iterations.
        wandb.log(log, step=i, commit=i%VIDEO_EACH==0)

        # Video recording is currently disabled:
        # if i % VIDEO_EACH == 0:
        #     env_monitor = wrappers.Monitor(env, directory="videos", force=True)
        #     rw = utils.evaluate(agent, env_monitor, n_games=config["n_games_mean"],)
        #     env_monitor.close()

        # Stop early once the evaluation reward reaches the configured threshold.
        if reward  >= config["max_reward"]:
            print("Your agent has just passed the minimum homework threshold")
            break