In [None]:
import sys
import wandb
import torch
from tqdm import tqdm
sys.path.append("..")

%load_ext autoreload
%autoreload 2
import numpy as np
from gym import wrappers
from torch import nn

from matplotlib import pyplot as plt
from src import utils as utils

In [None]:
env_name = "MiniGrid-DoorKey-5x5-v0"
# Other environments tried with this notebook (swap in to compare):
# env_name = "MiniGrid-Empty-Random-5x5-v0"
# env_name = "MiniGrid-FourRooms-v0"
env = utils.make_env(env_name)

# Spaces used to size the agent and the A2C learners in later cells.
obs_space_shape = env.observation_space.shape
n_actions = env.action_space.n

# Preview: render one RGB frame of the environment and display it.
frame = env.render('rgb_array')
fig, ax = plt.subplots()
ax.set_title('Game image')
ax.imshow(frame)
plt.show()

In [None]:
from src.a2c import A2CAlgo

# Run hyper-parameters. The whole dict is passed to wandb.init (so it is logged
# with the run) and to mlsh_algo.warmup in the training cell.
# NOTE(review): hidden_dim / emb_dim / lr / time / n_games_mean / max_reward /
# n_iter_epoch / n_steps_sub are not read by any visible cell other than via the
# `config` object passed to warmup — confirm they actually take effect.
config = dict(
    # --- run / environment ---
    time=10,
    n_games_mean=1,
    max_reward=0.99,
    device="cpu",
    env=env_name,
    hidden_dim=128,
    emb_dim=128,
    n_env=1,
    gamma=0.99,

    # --- shared A2C settings ---
    max_grad_norm=0.5,
    lr=0.001,
    value_loss_coef=0.5,
    entropy_coef=0.01,

    # --- sub-policies ---
    n_sub=1,
    sub_n_iter=50,
    sub_step_size=4,
    sub_n_steps=4,
    sub_lr=1e-3,

    # --- master policy ---
    master_n_iter=30,
    master_step_size=3,
    master_n_steps=4,
    master_lr=1e-3,

    # --- epoch bookkeeping ---
    n_iter_epoch=50,
    n_steps_sub=16,
)

In [None]:

# import os
# os.environ["WANDB_MODE"] = "dryrun"
from src.mlsh_model import MLSHAgent
from src.env_pool import MLSHPool

# Construct the MLSH agent from the sub-policy count, the env's action count and
# an observation dimension.
agent = MLSHAgent(
    config["n_sub"],
    n_actions,
    obs_space_shape[1]  # NOTE(review): dim 1 of the observation shape — confirm this is what MLSHAgent expects
)
# Small uniform init for every parameter (weights and biases alike).
for p in agent.parameters():
    nn.init.uniform_(p, -0.1, 0.1)

# Bind env_name now via a default argument: a bare `lambda: make_env(env_name)`
# looks the global up at call time, so reassigning `env_name` later in the
# notebook would silently change the environments this pool creates.
pool = MLSHPool(agent,
                lambda name=env_name: utils.make_env(name),
                config["n_env"],
                random_reset=False)

# Run name: strip the "MiniGrid-" prefix explicitly rather than slicing a magic
# 9 characters, so env names without that prefix are not truncated arbitrarily.
# For "MiniGrid-..." names the result is identical to the old env_name[9:].
_env_prefix = "MiniGrid-"
run_env_name = (env_name[len(_env_prefix):]
                if env_name.startswith(_env_prefix) else env_name)

wandb.init(project="mlsh",
           monitor_gym=True,
           name=f"mlsh_{run_env_name}+{config['n_sub']}",
           config=config,
           dir="..",
           magic=True,
           group="comparing")
wandb.watch(agent)

In [None]:
# One A2C learner per level of the hierarchy.
#
# `nn.Module.parameters()` returns a one-shot iterator. If A2CAlgo walks the
# parameters more than once (e.g. once to build its optimizer and again for
# gradient clipping), a bare iterator would be empty on the second pass.
# Materialise it with list(), exactly as the master learner below already does.
a2c_subpolicies = \
    A2CAlgo(list(agent.subpolicies.parameters()),
            config["device"],
            n_actions,           # sub-policies act in the env's action space
            config["gamma"],
            config["max_grad_norm"],
            config["entropy_coef"],
            config["sub_lr"],
            config["value_loss_coef"])

# Master learner: its "action count" is the number of sub-policies
# (presumably it picks which sub-policy to run — confirm in src.a2c / mlsh_model).
# The name `ac2_master` is kept as-is because the training cell references it.
ac2_master = \
    A2CAlgo(list(agent.master_policy.parameters()),
            config["device"],
            config["n_sub"],
            config["gamma"],
            config["max_grad_norm"],
            config["entropy_coef"],
            config["master_lr"],
            config["value_loss_coef"])

In [None]:
from src import mlsh_algo
N_EPOCHS = 4000  # outer MLSH epochs: master reset -> warmup -> joint training -> eval

# Dedicated evaluation env. The old loop `for env in pool.envs` rebound the
# notebook-global `env`, so evaluation silently ran on (and disturbed) the
# training pool's own environment, and re-running cells changed what `env` meant.
eval_env = utils.make_env(env_name)

for _ in tqdm(range(N_EPOCHS)):
    # Start a new task: update the pool's seeds and reset every pool env.
    pool.update_seeds()
    for pool_env in pool.envs:
        pool_env.reset()

    # Re-initialise only the master policy; sub-policy weights carry over.
    # NOTE(review): ac2_master's optimizer state (e.g. Adam moments, if any) is not
    # reset together with the weights — confirm that is intended.
    for p in agent.master_policy.parameters():
        nn.init.uniform_(p, -0.1, 0.1)

    # Warmup receives only the master learner; joint_train receives both.
    mlsh_algo.warmup(ac2_master, pool,
                     config["master_n_iter"],
                     config["master_step_size"],
                     config["master_n_steps"], config)
    epoch_rew = mlsh_algo.joint_train(
        ac2_master,
        a2c_subpolicies,
        pool,
        config["sub_n_iter"],
        config["sub_step_size"],
        config["sub_n_steps"],
        config["n_env"])[0]

    with torch.no_grad():
        # "seen": evaluate with the pool's first training seed; "unseen": no seed given.
        seen_rewards = utils.evaluate(agent, eval_env, n_games=5,
                                      last_env=pool.seeds[0])[0]
        unseen_rewards = utils.evaluate(agent, eval_env, n_games=5,
                                        last_env=None)[0]
        wandb.log({
            "mean_rewards_epoch": epoch_rew,
            "seen_evaluate_reward": np.mean(seen_rewards),
            "unseen_evaluate_reward": np.mean(unseen_rewards),
        })
