In [None]:
from __future__ import annotations

import pickle
import random
import re
from pathlib import Path

import pandas as pd
import plotly.express as px
from joblib import Parallel, delayed
from stable_baselines3 import PPO
from tqdm.autonotebook import tqdm

import main as M
import train


%load_ext autoreload
%autoreload 2

In [None]:
print("Action space:", M.RANDOM_GOAL_ENV.action_space)
print("Observation space:", M.RANDOM_GOAL_ENV.observation_space)

# Sort so the agent order (and therefore the order of `perfs` and of the
# DataFrame built from it) is deterministic: `Path.glob` yields files in no
# guaranteed order.
agent_files = sorted(Path("agents").glob("*.zip"))
print(f"Collected {len(agent_files)} agents")

In [None]:
def get_random_agent(files: list[Path] | None = None, rng: random.Random | None = None):
    """Load a PPO agent chosen uniformly at random.

    Args:
        files: paths of saved agents to choose from. Defaults to the
            notebook-level `agent_files` (the ``agents/*.zip`` glob).
        rng: optional `random.Random` instance for reproducible picks.
            Defaults to the global `random` module state.

    Returns:
        The model returned by `PPO.load` for the chosen file.

    Raises:
        IndexError: if there are no files to choose from.
    """
    candidates = agent_files if files is None else files
    chooser = random if rng is None else rng
    file = chooser.choice(candidates)
    print("Loading agent:", file)
    return PPO.load(file)

In [None]:
# Sample one saved agent and report its policy network and parameter count.
policy = get_random_agent()
network = policy.policy
n_parameters = sum(param.numel() for param in network.parameters())
print("Policy:", network)
print("Model size:", n_parameters)

In [None]:
# Sample a saved agent and display its `M.Perfs` (last expression of the cell).
policy = get_random_agent()
M.Perfs.from_agent(policy)
# Optional diagnostics, disabled by default:
# M.eval_agent(policy, M.RANDOM_GOAL_ENV, end_condition=lambda locals_: locals_["env"].agent_pos == (3, 3))
# M.show_behavior(policy, M.RANDOM_GOAL_ENV, 40)
# M.show_behavior(policy, M.BR_GOAL_ENV, 10)
# M.eval_agent(policy, plot=True)

In [None]:
# Baseline for comparison: show `M.BottomRightAgent` on `M.RANDOM_GOAL_ENV`.
baseline_agent = M.BottomRightAgent()
M.show_behavior(baseline_agent, M.RANDOM_GOAL_ENV, 40)

In [None]:
# Score the `M.BottomRightAgent` baseline over a fixed number of episodes.
BASELINE_EPISODES = 1000
M.Perfs.from_agent(M.BottomRightAgent(), episodes=BASELINE_EPISODES)

In [None]:
def _evaluate_agent_file(file):
    """Load the PPO agent saved at `file` and compute its `M.Perfs`.

    Runs inside the joblib worker, so each worker loads its own model from disk
    instead of the parent process loading every model serially and pickling it
    over to the workers.
    """
    return M.Perfs.from_agent(PPO.load(file), file=file)


# n_jobs=-3: joblib uses (n_cpus + 1 + n_jobs) workers, i.e. all CPUs but two.
# tqdm wraps the input iterable, so it tracks dispatched tasks, not finished ones.
perfs = list(
    Parallel(n_jobs=-3)(
        delayed(_evaluate_agent_file)(file) for file in tqdm(agent_files)
    )
)

In [None]:
# Save the perfs to disk. `saved` keeps an extra in-memory reference to the
# list (it is not read anywhere else in the visible notebook).
saved = perfs
# `with` guarantees the file is flushed and closed; a bare `open()` relies on
# garbage collection to close it.
with open("perfs.pkl", "wb") as perfs_file:
    pickle.dump(perfs, perfs_file)

In [None]:
# Load the perfs from disk. Only unpickle files you created yourself:
# `pickle.load` can execute arbitrary code embedded in a malicious file.
with open("perfs.pkl", "rb") as perfs_file:
    perfs = pickle.load(perfs_file)

In [None]:
# Collect the per-agent perfs into one DataFrame (one row per agent file).
ODDS_PATTERN = re.compile(r"(\d+)odds")


def _odds_from_filename(file) -> int | None:
    """Return the integer written before "odds" in the file name, or None if absent.

    Returning None (instead of calling `.group(1)` on a failed match, which
    raises AttributeError) keeps agents whose names lack an "<n>odds" token in
    the table.
    """
    match = ODDS_PATTERN.search(Path(str(file)).name)
    return int(match.group(1)) if match else None


df = pd.DataFrame(
    [
        dict(
            br_env=p.br_env,
            general_env=p.general_env,
            general_br_freq=p.general_br_freq,
            file=str(p.info["file"]),
            odds=_odds_from_filename(p.info["file"]),
        )
        for p in perfs
    ],
    # Explicit columns keep the schema even when `perfs` is empty.
    columns=["br_env", "general_env", "general_br_freq", "file", "odds"],
)

# Scatter the perfs of the agents whose br_env score exceeds the threshold.
BR_ENV_THRESHOLD = 0.9

px.scatter(
    df.loc[df["br_env"] > BR_ENV_THRESHOLD],
    x="general_br_freq",
    y="general_env",
    color="odds",
    hover_name="file",
    title=f"Agents with br_env > {BR_ENV_THRESHOLD}",
    width=800,
    height=800,
).show()

In [None]:
# Inspect one specific saved agent on the random-goal environment.
CHOSEN_AGENT_PATH = "agents/ppo_50000steps_612gen_998br_2odds_1689943572.zip"
agent = PPO.load(CHOSEN_AGENT_PATH)
M.show_behavior(agent, M.RANDOM_GOAL_ENV, 40)

In [None]:
import train

policy = train.get_agent(5, 30_000, net_arch=(30, 10), env_size=6, save=False)

In [None]:
policy[1]

In [None]:
random_env = M.wrap_env(M.SimpleEnv(6, None, None, render_mode="rgb_array"))
M.show_behavior(policy, random_env, 40)

# Try 2: train new agents with `train.get_agent` and inspect their behavior

In [None]:
# Train an agent on an env_size=7 environment. Unpack its perfs into
# `size7_perfs` rather than `perfs`, so the notebook-level list of saved-agent
# perfs (used to build the DataFrame) is not overwritten.
policy, size7_perfs = train.get_agent(1000, 40_000, net_arch=(30, 10), env_size=7, save=False)

In [None]:
def random_goal_env(size):
    """Build a wrapped, rgb-rendered `M.SimpleEnv` of the given size.

    Mirrors how `random_env` is constructed earlier in the notebook (positional
    `None, None` arguments to `M.SimpleEnv`).
    NOTE(review): `random_goal_env` was used here without being defined, which
    raises NameError on a fresh kernel. Confirm this construction is the
    intended random-goal environment.
    """
    return M.wrap_env(M.SimpleEnv(size, None, None, render_mode="rgb_array"))


# NOTE(review): the agent trained above used env_size=7; confirm its observation
# space is compatible with a size-5 environment.
M.show_behavior(policy, random_goal_env(5), 40)

In [None]:
# Train several small agents (env_size=5) and print each one's perfs.
# `get_agent` lives in the `train` module (a bare `get_agent` is undefined here).
# Loop-local names keep the notebook-level `policy` and `perfs` untouched, so
# later cells (e.g. the one saving `policy`) still see the values they had
# before this cell ran.
N_NEW_AGENTS = 10

new_agents = []
for _ in range(N_NEW_AGENTS):
    trained_policy, trained_perfs = train.get_agent(
        50, 30_000, net_arch=(30, 10), env_size=5, save=False
    )
    new_agents.append((trained_policy, trained_perfs))

for _, trained_perfs in new_agents:
    print(trained_perfs)

In [None]:
# Save the trained policy. The original path ended in a stray "."; use an
# explicit ".zip" so the saved file name is unambiguous.
# NOTE: make sure `policy` is still the env_size=7 agent. Any cell that rebinds
# `policy` in between changes what gets saved here.
policy.save("agents/old/learned_size7.zip")

In [None]:
# Build the environment explicitly (same construction as `random_env` earlier:
# positional `None, None` args to `M.SimpleEnv`) instead of depending on a
# `random_goal_env` helper. NOTE(review): confirm this is the intended
# random-goal env and that `policy` is compatible with a size-5 environment.
M.show_behavior(policy, M.wrap_env(M.SimpleEnv(5, None, None, render_mode="rgb_array")), 40, 15)
# `br_env` must be defined before re-enabling the line below.
# M.show_behavior(policy, br_env(5), 40, 15)