In [None]:
from __future__ import annotations

import random
from dataclasses import dataclass
from pathlib import Path

import plotly.express as px
from joblib import Parallel, delayed
from stable_baselines3 import PPO, DQN
from tqdm.autonotebook import tqdm
import pandas as pd

import main as M

%load_ext autoreload
%autoreload 2

In [None]:
# Show the spaces of the random-goal environment the agents are evaluated on.
print("Action space:", M.RANDOM_GOAL_ENV.action_space)
print("Observation space:", M.RANDOM_GOAL_ENV.observation_space)

# Collect every saved agent archive directly under agents/ (non-recursive).
# Globbing yields entries in a filesystem-dependent order, so sort the paths:
# this keeps random picks (under a fixed seed) and plot rows reproducible.
agent_files = sorted(Path("agents").glob("*.zip"))
print(f"Collected {len(agent_files)} agents")

In [None]:
def get_random_agent():
    """Pick one path from ``agent_files`` at random and load it with ``PPO.load``.

    The chosen path is printed before loading so the agent can be identified.

    Returns:
        The model returned by ``PPO.load`` for the chosen file.
    """
    chosen_path = random.choice(agent_files)
    print("Loading agent:", chosen_path)
    loaded_agent = PPO.load(chosen_path)
    return loaded_agent

In [None]:
# Load a random agent, then print its policy network and its total number
# of parameters (sum of element counts over all parameter tensors).
policy = get_random_agent()
print("Policy:", policy.policy)
param_counts = [tensor.numel() for tensor in policy.policy.parameters()]
print("Model size:", sum(param_counts))

In [None]:
# Load another random agent and build its `M.Perfs` via `from_agent`.
# That call is the last statement of the cell, so its value is what the
# notebook displays. The commented-out lines are alternative inspections
# of the same agent, kept disabled.
policy = get_random_agent()
# M.eval_agent(policy, M.RANDOM_GOAL_ENV, end_condition=lambda locals_: locals_["env"].agent_pos == (3, 3))
M.Perfs.from_agent(policy)
# M.show_behavior(policy, M.RANDOM_GOAL_ENV, 40)
# M.show_behavior(policy, M.BR_GOAL_ENV, 10)
# M.eval_agent(policy, plot=True)

In [None]:
# Evaluate every collected agent in parallel (n_jobs=-3: all CPUs but two).
# Each checkpoint is loaded with PPO.load as the generator is consumed, and
# tqdm reports progress over the agent files. Results keep the order of
# `agent_files`.
eval_tasks = (
    delayed(M.Perfs.from_agent)(PPO.load(agent_file))
    for agent_file in tqdm(agent_files)
)
perfs = list(Parallel(n_jobs=-3)(eval_tasks))

In [None]:
# Scatter plot of the perfs: one point per agent, coloured by `br_env`.
#
# The per-metric lists are built directly inside the DataFrame instead of
# being stored as notebook globals. A global list named `br_env` would
# overwrite the `br_env` environment factory defined later in this notebook
# whenever this cell is re-run after it, which breaks `get_agent`.
perf_metrics = ("br_env", "general_env", "general_br_freq")

df = pd.DataFrame(
    {
        **{metric: [getattr(p, metric) for p in perfs] for metric in perf_metrics},
        # `perfs` was computed in `agent_files` order, so rows line up by index;
        # pandas raises ValueError if the two lengths ever disagree.
        "file": [f.name for f in agent_files],
    }
)

px.scatter(
    df, x="general_br_freq", y="general_env", color="br_env", hover_name="file"
).show()

In [None]:
# Load one specific saved checkpoint and show its behavior on
# M.RANDOM_GOAL_ENV.
# NOTE(review): the meaning of the trailing 40 (episodes? steps?) is defined
# by M.show_behavior -- confirm there.
agent = PPO.load("agents/ppo_50000steps_612gen_998br_2odds_1689943572.zip")
M.show_behavior(agent, M.RANDOM_GOAL_ENV, 40)

In [None]:
import train

# Train a new agent through the `train` module without saving it (save=False).
# NOTE(review): the positional arguments are presumably
# (bottom_right_odds, steps) -- confirm against train.get_agent.
policy = train.get_agent(5, 100_000, net_arch=(64, 32), env_size=6, save=False)

In [None]:
# Build a wrapped SimpleEnv(6, None, None) that renders to RGB arrays, then
# show the current policy acting in it.
base_env = M.SimpleEnv(6, None, None, render_mode="rgb_array")
random_env = M.wrap_env(base_env)
M.show_behavior(policy, random_env, 40)

# Try 2: training a DQN agent directly in the notebook

In [None]:
import time

import click
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

from main import wrap_env, SimpleEnv, eval_agent, uniform_distribution

def random_goal_env(size):
    """Return a wrapped ``SimpleEnv`` of the given size.

    ``goal_pos`` and ``agent_start_pos`` are both left as ``None`` and the
    environment renders to RGB arrays.
    """
    base_env = SimpleEnv(
        size=size,
        goal_pos=None,
        agent_start_pos=None,
        render_mode="rgb_array",
    )
    return wrap_env(base_env)

def br_env(size):
    """Return a wrapped ``SimpleEnv`` whose goal is fixed at ``(size - 2, size - 2)``.

    ``agent_start_pos`` is left as ``None`` and the environment renders to
    RGB arrays.
    """
    bottom_right = (size - 2, size - 2)
    base_env = SimpleEnv(
        size=size,
        goal_pos=bottom_right,
        agent_start_pos=None,
        render_mode="rgb_array",
    )
    return wrap_env(base_env)


def get_agent(
    bottom_right_odds: int,
    steps: int = 50_000,
    n_envs: int = 1,
    net_arch: tuple = (30, 10),
    env_size: int = 5,
    save: bool = True,
):
    """Train a DQN agent on ``SimpleEnv`` with a goal biased to the bottom-right cell.

    Args:
        bottom_right_odds: Weight of the bottom-right goal cell relative to
            all other goal cells combined. This assumes
            ``uniform_distribution`` gives every cell the same unit weight
            -- TODO confirm against ``main.uniform_distribution``.
        steps: Total number of training timesteps passed to ``learn``.
        n_envs: Number of environments in the vectorised training env.
        net_arch: Currently unused. The ``policy_kwargs`` line that consumed
            it is disabled below, so the library's default network is used.
            Kept in the signature for backward compatibility.
        env_size: Size passed to ``SimpleEnv`` for training and evaluation.
        save: When true, save the trained model under ``agents/``.

    Returns:
        The trained ``DQN`` model.
    """
    # Define the training environment
    goal_distrib = uniform_distribution((env_size - 1, env_size - 1))
    # There are (env_size - 2) ** 2 - 1 other positions. Compute that count
    # first: `odds * (env_size - 2) ** 2 - 1` binds as `(odds * n ** 2) - 1`,
    # which is wrong for every odds != 1 and even negative for odds == 0.
    other_positions = (env_size - 2) ** 2 - 1
    goal_distrib[env_size - 2, env_size - 2] = bottom_right_odds * other_positions

    env = make_vec_env(
        lambda: wrap_env(
            SimpleEnv(
                size=env_size,
                goal_pos=goal_distrib,
                # goal_pos=(-2, -2),
                agent_start_pos=None,
                # render_mode='rgb_array'
            )
        ),
        n_envs=n_envs,
    )

    # Define the policy network
    policy = DQN(
        "MlpPolicy",
        env,
        verbose=1,
        learning_rate=0.001,
        # learning_rate=lambda f: 0.001 * f,
        # learning_rate=lambda f: 0.01 * f ** 1.5,
        # NOTE(review): net_arch is ignored while this line stays disabled.
        # policy_kwargs=dict(net_arch=net_arch),
        # n_steps=2000 // n_envs,
        # batch_size=100,
        # n_epochs=40,
        buffer_size=5_000,
        learning_starts=5_000,
        gradient_steps=100,
        target_update_interval=1000,
        exploration_fraction=0.2,
        # exploration_final_eps=0.2,
        # gamma=1,
        tensorboard_log="run_logs",
        device="cpu",
    )
    # Train the agent
    policy.learn(total_timesteps=steps)

    # Evaluate the agent on the fixed bottom-right goal and on random goals
    # perfs = M.Perfs.from_agent(policy, env_size=env_size, episodes=300)
    br_success_rate = eval_agent(policy, br_env(env_size), 1000)
    success_rate = eval_agent(policy, random_goal_env(env_size), 1000)
    print("Bottom right success rate:", br_success_rate)
    print("Success rate:", success_rate)

    # Save the agent
    if save:
        # Use the real algorithm name as prefix so that DQN checkpoints are
        # not labelled (and later loaded) as PPO ones.
        algo = type(policy).__name__.lower()
        name = f"agents/{algo}_{steps}steps_{success_rate * 1000:03.0f}gen_{br_success_rate * 1000:03.0f}br_{bottom_right_odds}odds_{time.time():.0f}"
        policy.save(name)
        print(f"Saved model to {name}")

    return policy

In [None]:
# Train (without saving) a size-7 agent with bottom-right odds of 100.
policy = get_agent(
    bottom_right_odds=100,
    steps=50_000,
    net_arch=(64, 32),
    env_size=7,
    save=False,
)

In [None]:
# Save the trained size-7 agent outside the top-level agents/*.zip glob.
# The path must not end with a bare ".": the library appends ".zip" to
# suffix-less paths, which would produce "learned_size7..zip".
policy.save("agents/old/learned_size7.zip")

In [None]:
# Show the trained policy acting in a fresh size-7 random-goal environment.
size7_env = random_goal_env(7)
M.show_behavior(policy, size7_env, 40, 15)