In [None]:
from __future__ import annotations

import dataclasses
import json
import pickle
import random
import re
from pathlib import Path

import pandas as pd
import plotly.express as px
from joblib import Parallel, delayed
from stable_baselines3 import PPO
from tqdm.autonotebook import tqdm

import main as M
import train


%load_ext autoreload
%autoreload 2

In [None]:
# Report the spaces exposed by the shared random-goal environment from `main`.
print(f"Action space: {M.RANDOM_GOAL_ENV.action_space}")
print(f"Observation space: {M.RANDOM_GOAL_ENV.observation_space}")

# Every saved agent archive sitting directly inside `agents/`.
agent_files = [*Path("agents").glob("*.zip")]
print(f"Number of agents: {len(agent_files)}")

In [None]:
def get_random_agent():
    """Load one saved PPO agent chosen with ``random.choice`` from ``agent_files``.

    The chosen path is printed before loading.

    Returns:
        Whatever ``PPO.load`` returns for the chosen file.

    Raises:
        IndexError: Propagated from ``random.choice`` when ``agent_files`` is empty.
    """
    chosen_path = random.choice(agent_files)
    print("Loading agent:", chosen_path)
    loaded_agent = PPO.load(chosen_path)
    return loaded_agent

# Pick one saved agent, then show its policy object and total parameter count.
policy = get_random_agent()
print(f"Policy: {policy.policy}")
print(f"Model size: {sum(p.numel() for p in policy.policy.parameters())}")

## Baseline

Baseline: an agent that always goes to the bottom-right corner.

In [None]:
# Environment size used for the baseline evaluation here and for the behaviour demo in the next cell.
env_size = 9

# Left as the cell's last expression so Jupyter displays the resulting record.
M.Perfs.from_agent(
    M.BottomRightAgent(),
    episodes=1000,
    env_size=env_size,
)

In [None]:
# Visualise the baseline agent in a random-goal environment of the same size.
M.show_behavior(
    M.BottomRightAgent(),
    M.random_goal_env(env_size),
    40,
)

## Performance of agents

In [None]:
# Only evaluate agents whose file name starts with "ppo_7env" (get_perfs reads the
# number in front of "env" as the environment size). A broader "*.zip" glob used to
# run first but was overwritten on the very next line, so it has been dropped.
# Sorting makes the order reproducible; glob order is otherwise unspecified.
agent_files = sorted(Path("agents").glob("ppo_7env*.zip"))

print("Number of agents:", len(agent_files))

def get_perfs(file: Path) -> "M.Perfs":
    """Return the performance record for one saved agent, backed by a JSON cache.

    The cache file sits next to the agent archive (same name, ``.json`` suffix).
    On a cache hit its fields are passed to ``M.Perfs``. Otherwise the agent is
    loaded with ``PPO.load``, evaluated with ``M.Perfs.from_agent`` and the result
    is written to the cache.

    Args:
        file: Path to the saved agent. When no cache exists, its string form must
            contain ``<digits>env``; the digits are used as ``env_size``.

    Returns:
        The cached or freshly computed ``M.Perfs``.

    Raises:
        ValueError: If evaluation is needed but no ``<digits>env`` marker is found.
    """
    cache_file = file.with_suffix(".json")
    if cache_file.exists():
        # Context manager so the handle is closed deterministically.
        with cache_file.open("r") as cache_handle:
            return M.Perfs(**json.load(cache_handle))

    size_match = re.search(r"(\d+)env", str(file))
    if size_match is None:
        raise ValueError(
            f"Cannot infer env_size from {str(file)!r}: expected '<digits>env' in the path"
        )
    env_size = int(size_match.group(1))

    perf = M.Perfs.from_agent(PPO.load(file), file=str(file), env_size=env_size)

    # Serialise before touching the file: if the record is not JSON-serialisable we
    # fail here instead of leaving a truncated cache that later runs would try to load.
    payload = json.dumps(dataclasses.asdict(perf))
    cache_file.write_text(payload)
    return perf

# Evaluate (or read from cache) every selected agent in parallel.
# With joblib, n_jobs=-3 means "all CPUs but two".
worker_pool = Parallel(n_jobs=-3)
perfs = list(
    worker_pool(delayed(get_perfs)(agent_file) for agent_file in tqdm(agent_files))
)

In [None]:
# Table of the perfs (one row per evaluated agent), used for the scatter plot below.
# The "odds" column is the integer found in front of "odds" in each agent's file name.
agent_paths = [str(p.info["file"]) for p in perfs]
df = pd.DataFrame(
    {
        "br_env": [p.br_env for p in perfs],
        "general_env": [p.general_env for p in perfs],
        "general_br_freq": [p.general_br_freq for p in perfs],
        "file": agent_paths,
        "odds": [
            int(re.search(r"(\d+)odds", agent_path).group(1))
            for agent_path in agent_paths
        ],
    }
)

# Plot only the agents whose br_env score is above 0.9.
# NOTE(review): the title says 5×5, but the agents selected earlier match
# "ppo_7env*" — confirm which environment size this figure is meant to describe.
strong_agents = df[df.br_env > 0.9]
axis_labels = {
    "general_br_freq": "Probability of going in the bottom right corner, regardless of where the goal is",
    "general_env": "Probability of reaching the goal",
    "odds": "Odds of goal being<br>the bottom right corner",
}
fig = px.scatter(
    strong_agents,
    x="general_br_freq",
    y="general_env",
    color="odds",
    hover_name="file",
    width=1000,
    height=800,
    title="Performances of agents in a 5×5 environment with random goals.",
    labels=axis_labels,
)
fig.show()

In [None]:
# Train one small agent (two hidden layers of 10 units) without saving it to disk.
# `train` is imported in the notebook's import cell rather than here.
# NOTE(review): this rebinds `perfs`, which the evaluation cell above uses for a
# list of results — re-run that cell before re-plotting.
policy, perfs = train.get_agent(
    1 / 25,
    100_000,
    net_arch=(10, 10),
    env_size=7,
    save=False,
    learning_rate=0.0001,
)

In [None]:
# Total number of parameters in the current policy, followed by the policy object
# itself (left as the last expression so Jupyter displays it).
print(f"Model size: {sum(p.numel() for p in policy.policy.parameters())}")
policy.policy

In [None]:
# Visualise the current policy in a size-7 random-goal environment.
M.show_behavior(
    policy,
    M.random_goal_env(7),
    20,
)

# Try 2

In [None]:
# Second attempt: a (30, 10) network with wandb logging enabled, not saved to disk.
# `train` is imported in the notebook's import cell rather than here.
policy, perfs = train.get_agent(
    1000,
    40_000,
    net_arch=(30, 10),
    env_size=7,
    save=False,
    use_wandb=True,
)

In [None]:
# Visualise the policy from the second attempt in a size-7 random-goal environment.
M.show_behavior(
    policy,
    M.random_goal_env(7),
    20,
)

In [None]:
# Train ten agents with identical settings and keep each (policy, perfs) pair.
# `policy` and `perfs` stay bound to the last run at cell scope; later cells read `policy`.
new_agents = []
for _ in range(10):
    policy, perfs = train.get_agent(
        50,
        30_000,
        net_arch=(30, 10),
        env_size=5,
        save=False,
    )
    new_agents.append((policy, perfs))

# Print one performance record per trained agent.
for _, perf in new_agents:
    print(perf)

In [None]:
# Explicit ".zip" suffix: the previous target ended in a bare dot, which does not
# yield a clean "<name>.zip" archive name.
# NOTE(review): `policy` is whichever agent was trained last — in notebook order
# that is the env_size=5 loop above — confirm the "size7" label is still right.
policy.save("agents/old/learned_size7.zip")

In [None]:
# Visualise the current policy in a size-7 random-goal environment.
M.show_behavior(
    policy,
    M.random_goal_env(7),
    10,
    15,
)
# Disabled alternative kept for reference (`br_env` is not defined in this notebook):
# M.show_behavior(policy, br_env(5), 40, 15)

# Hyperparameter search for agents

In [None]:
# Register a hyperparameter sweep with Weights & Biases.
# NOTE(review): neither `wandb` nor `config` is imported or defined in the visible
# part of this notebook, so this cell raises NameError on a fresh kernel unless
# they are provided elsewhere — confirm and add the missing import/definition.
wandb.sweep(sweep=config)