In [199]:
from huggingface_hub import hf_hub_download, HfApi
from plotnine import ggplot, aes, geom_density, scale_fill_manual, xlim, element_text, theme
from skopt import load
from stable_baselines3 import PPO

from rl4fisheries import AsmEnv, AsmMovingAvg
from rl4fisheries.utils import evaluate_agent
from rl4fisheries.envs.asm_fns import get_r_devs

import numpy as np
import pandas as pd
import ray

In [200]:
# Hugging Face Hub repository and the folder inside it that holds the
# trained checkpoints.
repo = "boettiger-lab/rl4eco"
directory = "sb3/rl4fisheries/post-review-results/"

# Checkpoint archive names: one PPO agent trained on AsmEnv and one trained
# on AsmMovingAvg (as encoded in the file names).
regular_ppo = "PPO-AsmEnv-2obs-UM3-256-64-16-noise0.1-no-rescaling-chkpnt4.zip"
mv_avg_ppo = "PPO-AsmMovingAvg-2obs-UM3-full-avg-chkpnt3.zip"

# Fetch both archives; the returned values are later handed to PPO.load.
PPO_avg = hf_hub_download(filename=f"{directory}{mv_avg_ppo}", repo_id=repo)
PPO_reg = hf_hub_download(filename=f"{directory}{regular_ppo}", repo_id=repo)


In [201]:
# Restore both downloaded checkpoints as PPO agents, pinned to the CPU.
PPO_reg_agent = PPO.load(path=PPO_reg, device='cpu')
PPO_mv_agent = PPO.load(path=PPO_avg, device='cpu')

In [202]:
def get_rews(agent, env, agent_name, n_eval_episodes=500, ray_remote=True):
    """Evaluate ``agent`` on ``env`` and tag every reward with ``agent_name``.

    Parameters
    ----------
    agent
        Policy object forwarded unchanged to ``evaluate_agent``.
    env
        Environment forwarded unchanged to ``evaluate_agent``.
    agent_name : str
        Label repeated once per returned reward, so the result can be
        turned into a long-format table.
    n_eval_episodes : int, default 500
        Number of evaluation episodes requested from ``.evaluate``.
        The default matches the previously hard-coded value.
    ray_remote : bool, default True
        Forwarded to ``evaluate_agent``; the default matches the previously
        hard-coded value.

    Returns
    -------
    dict
        ``{'rew': rewards, 'agent': [agent_name, ...]}`` where ``rewards`` is
        whatever ``.evaluate(return_episode_rewards=True, ...)`` returns
        (presumably one entry per episode -- TODO confirm) and ``'agent'``
        has the same length as ``rewards``.
    """
    rewards = evaluate_agent(
        agent=agent, env=env, ray_remote=ray_remote,
    ).evaluate(
        return_episode_rewards=True, n_eval_episodes=n_eval_episodes,
    )
    # One label per reward keeps both columns the same length.
    return {'rew': rewards, 'agent': [agent_name] * len(rewards)}

In [203]:
# Single configuration dict shared by both environments below.
CFG_UM3 = dict(
    # observation-related keys
    observation_fn_id='observe_2o',
    n_observs=2,
    obs_noise=0.1,
    # harvest-related keys
    harvest_fn_name="trophy",
    upow=1,
    n_trophy_ages=10,
    # presumably the window length used by AsmMovingAvg -- TODO confirm
    avg_win_size=3,
)

# Both environments are built from the very same config object.
asm = AsmEnv(config=CFG_UM3)
asm_avg = AsmMovingAvg(config=CFG_UM3)


In [204]:
# Evaluate the agent loaded from the AsmEnv checkpoint on the plain AsmEnv.
PPO_reg_rew = get_rews(PPO_reg_agent, asm, 'Regular RL')


In [206]:
# Evaluate the agent loaded from the AsmMovingAvg checkpoint on AsmMovingAvg.
PPO_avg_rew = get_rews(PPO_mv_agent, asm_avg, 'Moving Average RL')

In [207]:
# Shut down Ray once both evaluations are done (presumably it was started by
# evaluate_agent(ray_remote=True) inside get_rews -- TODO confirm).
ray.shutdown()