In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import datetime
import yaml

import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

from env import SlidingEnv
from wrappers import NormalizedObsWrapper

In [None]:
# Every tunable of this training run lives in this one mapping.
configs = dict(
    # Total timestep budget handed to the learner.
    total_timesteps=1_000_000,

    # How the environment is built: constructor kwargs, number of parallel
    # copies, and the wrapper class applied to each copy.
    env_kwargs=dict(w=2, h=2, shuffle_steps=5),
    n_envs=16,
    wrapper_class=NormalizedObsWrapper,

    # Step limit used when registering the environment, and the random seed.
    max_episode_steps=100,
    seed=42,
)

In [None]:
# Register the environment under a Gymnasium id so it can be built by name,
# with the episode step limit taken from the config.
gym.envs.register(
    id="SlidingEnv-v0",
    entry_point=SlidingEnv,
    max_episode_steps=configs["max_episode_steps"],
)

# Smoke test: build a single instance and display the result of its first
# reset. The reset is seeded from the config so that the displayed output is
# the same on every Restart & Run All instead of depending on an unseeded RNG.
env = gym.make("SlidingEnv-v0", **configs["env_kwargs"])
env.reset(seed=configs["seed"])

In [None]:
# Build the vectorised training environment from the id registered above.
# Note that this rebinds `env`, replacing the single instance created earlier.
env = make_vec_env(
    "SlidingEnv-v0",
    n_envs=configs["n_envs"],
    seed=configs["seed"],
    env_kwargs=configs["env_kwargs"],
    wrapper_class=configs["wrapper_class"],
)

In [None]:
# Timestamp that identifies this run; it also names the run's log directory.
run_id = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
model = PPO(
    "MlpPolicy",
    env,
    # Seed the algorithm itself, not only the environments: without this the
    # configured seed never reaches PPO and training is not reproducible.
    seed=configs["seed"],
    tensorboard_log=f"runs/{run_id}",
)

In [None]:
# Train for the configured timestep budget, showing a progress bar.
# The call is left as the cell's last expression so its return value is displayed.
model.learn(configs["total_timesteps"], progress_bar=True)

In [None]:
# Record the run id in the live config so it is stored alongside the settings.
configs["run_id"] = run_id

# Serialise a shallow copy in which the wrapper class is replaced by its dotted
# import path. Dumping the class object itself makes PyYAML emit a
# Python-specific tag, which yaml.safe_load and non-Python tools cannot read.
saved_configs = dict(configs)
wrapper_cls = saved_configs.get("wrapper_class")
if isinstance(wrapper_cls, type):
    saved_configs["wrapper_class"] = f"{wrapper_cls.__module__}.{wrapper_cls.__qualname__}"

# NOTE(review): assumes runs/{run_id} already exists at this point (presumably
# created by the TensorBoard logger during training) — confirm.
with open(f"runs/{run_id}/configs.yaml", "w") as f:
    # safe_dump fails loudly if a non-plain value sneaks into the config.
    yaml.safe_dump(saved_configs, f)
model.save(f"runs/{run_id}/model")