In [None]:
from __future__ import annotations

import dataclasses
import json
import pickle
import random
import re
from pathlib import Path
from pprint import pprint

import pandas as pd
import plotly.express as px
from joblib import Parallel, delayed
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import (
    FlattenExtractor,
    BaseFeaturesExtractor,
)
from tqdm.autonotebook import tqdm

import typeguard
import jaxtyping

import main as M


# Automatically re-import edited modules (e.g. the local `main`, `train`, `transformer`)
# before each cell runs, so changes to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Inspect the spaces of the module-level random-goal environment, then count saved agents.
reference_env = M.RANDOM_GOAL_ENV
print("Action space:", reference_env.action_space)
print("Observation space:", reference_env.observation_space)

agent_files = [*Path("agents").glob("*.zip")]
print("Number of trained agents:", len(agent_files))

## Baseline

An agent that always goes to the bottom-right corner, wherever the goal is.

In [None]:
# Evaluate the bottom-right baseline over 100 episodes on a size-7 environment.
env_size = 7
M.Perfs.from_agent(
    M.BottomRightAgent(),
    episodes=100,
    env_size=env_size,
)

In [None]:
# Show how the baseline behaves on a random-goal environment of the size used above.
M.show_behavior(
    M.BottomRightAgent(),
    M.random_goal_env(env_size),
)

# Training agents

In [None]:
import transformer

# Smoke test: check that the custom policy can be built on a size-7 random-goal env.
env = M.random_goal_env(7)
print(env.__class__.__mro__)
print(env.observation_space)


def constant_lr(_progress):
    """Learning-rate schedule that ignores its argument and always returns 0.01."""
    return 0.01


# NOTE(review): `policy_kwargs` is handed directly to the policy constructor here;
# confirm that CustomActorCriticPolicy actually accepts/uses this keyword.
smoke_policy = transformer.CustomActorCriticPolicy(
    env.observation_space,
    env.action_space,
    lr_schedule=constant_lr,
    policy_kwargs=dict(d_model=32),
)
smoke_policy

In [None]:
import train
import transformer

env_size = 17
n_envs = 10
env = M.random_goal_env(env_size)
print(env.observation_space)

# NOTE(review): `bottom_right_prob` presumably is the probability that the goal is
# placed in the bottom-right corner — confirm against `train.get_agent`.
# The result is bound to `train_perfs` (not `perfs`) so that re-running this cell
# does not clobber the list of per-agent results used by the plots further down.
policy, train_perfs = train.get_agent(
    bottom_right_prob=0.9,
    total_timesteps=100_000,
    net_arch=(10,),
    n_epochs=40,
    # 4 000 transitions per rollout in total; presumably forwarded to PPO, where
    # `n_steps` counts steps *per environment* — confirm in `train.get_agent`.
    n_steps=4_000 // n_envs,
    batch_size=400,
    learning_rate=0.0001,
    env_size=env_size,
    n_envs=n_envs,
    can_turn=False,
    # Alternative: transformer-based policy (uncomment both lines to use it).
    # policy=transformer.CustomActorCriticPolicy,
    # policy_kwargs=dict(features_extractor_class=transformer.CustomFeaturesExtractor, arch=dict(d_model=20, d_head=6, heads=3, layers=1)),
    save=True,
)

pprint(train_perfs)

In [None]:
# Leftover scratch cell (`1 + 1`) removed: it played no part in the analysis.

In [None]:
import torchinfo

# Layer-by-layer summary of the trained policy network, up to 4 levels of nesting.
# NOTE(review): the dummy input shape is hard-coded; it presumably must match the
# observation space printed by the training cell — re-check after changing env_size.
summary_input_size = (1, 3, 3, 7)
torchinfo.summary(policy.policy, input_size=summary_input_size, depth=4)

In [None]:
# Show the trained agent on a fresh random-goal environment of the training size.
M.show_behavior(
    policy,
    M.random_goal_env(env_size),
    20,
)

In [None]:
# Total number of scalar weights returned by `parameters()` (trainable or not).
n_parameters = sum(param.numel() for param in policy.policy.parameters())
print("Model size:", n_parameters)

# Performance of agents

In [None]:
# Evaluate only the agents trained on the size-7 environment. Sorting makes the order
# of the results (and therefore of the plotted rows) deterministic across runs.
agent_files = sorted(Path("agents").glob("ppo_7env*.zip"))

print("Number of agents:", len(agent_files))


def get_perfs(file: Path) -> M.Perfs:
    """Return the performance summary of the agent saved in ``file``.

    Results are cached in a ``.json`` file next to the agent (same stem). On a cache
    hit the JSON is loaded back into ``M.Perfs``. On a miss the agent is loaded with
    ``PPO.load``, evaluated with ``M.Perfs.from_agent`` and the result is written to
    the cache.

    Args:
        file: Path to a saved agent. On a cache miss its file name must contain
            ``<N>env``, where ``N`` is the environment size used for evaluation.

    Returns:
        The agent's ``M.Perfs``.

    Raises:
        ValueError: On a cache miss, if no ``<N>env`` can be found in the file name.
    """
    cache_file = file.with_suffix(".json")
    if cache_file.exists():
        with cache_file.open("r") as f:
            return M.Perfs(**json.load(f))

    # The environment size is encoded in the agent's file name as "<N>env".
    match = re.search(r"(\d+)env", file.name)
    if match is None:
        raise ValueError(
            f"Cannot infer the environment size of {file}: "
            "expected '<N>env' in the file name."
        )
    env_size = int(match.group(1))
    perf = M.Perfs.from_agent(PPO.load(file), file=str(file), env_size=env_size)
    with cache_file.open("w") as f:
        json.dump(dataclasses.asdict(perf), f)
    return perf


# n_jobs=-3 uses all CPUs but two; results come back in the order of `agent_files`.
# Note that tqdm here tracks task dispatch, not task completion.
perfs = Parallel(n_jobs=-3)(delayed(get_perfs)(file) for file in tqdm(agent_files))

In [None]:
# Only agents whose `br_env` score exceeds this threshold are plotted.
# (`br_env` is presumably the success rate when the goal is in the bottom-right corner.)
BR_ENV_THRESHOLD = 0.9

PERF_COLUMNS = ["br_env", "general_env", "general_br_freq", "file", "odds"]


def odds_from_file(file: str) -> int | None:
    """Return ``N`` from the ``<N>odds`` part of ``file``, or None if there is none."""
    match = re.search(r"(\d+)odds", file)
    return int(match.group(1)) if match is not None else None


def perf_record(perf) -> dict:
    """Flatten one agent's performance summary into a row of the plotting frame."""
    file = str(perf.info["file"])
    return dict(
        br_env=perf.br_env,
        general_env=perf.general_env,
        general_br_freq=perf.general_br_freq,
        file=file,
        odds=odds_from_file(file),
    )


# One row per agent. Agents whose file name has no "<N>odds" get a missing `odds`
# value instead of crashing the cell; explicit columns keep the frame well-formed
# even when `perfs` is empty.
df = pd.DataFrame([perf_record(p) for p in perfs], columns=PERF_COLUMNS)

# Single scatter plot: general success rate vs. how often the agent goes bottom-right.
# NOTE(review): the title says 5×5 while the agents are the "ppo_7env*" ones;
# presumably env_size 7 includes the outer walls, leaving a 5×5 interior — confirm.
px.scatter(
    df[df.br_env > BR_ENV_THRESHOLD],
    x="general_br_freq",
    y="general_env",
    color="odds",
    hover_name="file",
    width=1000,
    height=800,
    title="Performances of agents in a 5×5 environment with random goals.",
    labels=dict(
        general_br_freq="Probability of going in the bottom right corner, regardless of where the goal is",
        general_env="Probability of reaching the goal",
        odds="Odds of goal being<br>the bottom right corner",
    ),
).show()