In [1]:
import or_gym
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from common import make_env

from tqdm import tqdm

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy

from stable_baselines3 import SAC
from stable_baselines3.sac.policies import MlpPolicy as SACPolicy

from stable_baselines3 import A2C
from stable_baselines3.a2c.policies import MlpPolicy as A2CPolicy

from stable_baselines3 import PPO
from stable_baselines3.ppo.policies import MlpPolicy as PPOPolicy

from sb3_contrib import ARS
from sb3_contrib.ars.policies import ARSPolicy

from sb3_contrib import RecurrentPPO
from sb3_contrib.ppo_recurrent.policies import RecurrentActorCriticPolicy

from sb3_contrib import TQC
from sb3_contrib.tqc.policies import MlpPolicy as TQCPolicy

from sb3_contrib import TRPO
from sb3_contrib.trpo.policies import MlpPolicy as TRPOPolicy


# Global figure settings: high-resolution rendering and LaTeX text layout.
plt.rcParams.update({
    'figure.dpi': 256,
    'text.usetex': True,
})

In [2]:
def get_algo(algo_name):
    """Return the algorithm class registered under ``algo_name``.

    Args:
        algo_name: One of ``'PPO'``, ``'A2C'`` or ``'TRPO'``.

    Returns:
        The matching algorithm class (``PPO``, ``A2C`` or ``TRPO``).

    Raises:
        ValueError: If ``algo_name`` is not one of the supported names.
            Failing here avoids handing ``None`` back to the caller, which
            would only surface later as an unrelated ``AttributeError``.
    """
    if algo_name == 'PPO':
        return PPO

    if algo_name == 'A2C':
        return A2C

    if algo_name == 'TRPO':
        return TRPO

    raise ValueError(
        f"Unknown algorithm name {algo_name!r}; expected one of 'PPO', 'A2C', 'TRPO'"
    )


def run_evals(env_name, algo_name, name, n_eval_episodes, env_seed=42):
    """Roll out a saved model and write the environment's tables to CSV.

    Loads ``best_model`` from ``./data/{env_name}/{algo_name}/{name}/``, plays
    ``n_eval_episodes`` episodes of ``env.num_periods`` steps each and, after
    every episode, writes the environment attributes ``D``, ``X``, ``R``,
    ``P`` and ``Y`` to ``eval/{episode}/{attribute}.csv`` below that directory.

    Args:
        env_name: Environment identifier; passed to ``make_env`` and used as
            a component of the data path.
        algo_name: Algorithm name resolved through ``get_algo``; also a
            component of the data path.
        name: Run name; the last component of the data path.
        n_eval_episodes: Number of episodes to roll out.
        env_seed: Seed forwarded to ``make_env``.

    Returns:
        None. The results are the CSV files written to disk.
    """
    save_path = f'./data/{env_name}/{algo_name}/{name}/'

    env = make_env(env_name, env_seed=env_seed)
    algo = get_algo(algo_name)
    model = algo.load(save_path + 'best_model', env=env)

    obs = env.reset()

    df_names = ['D', 'X', 'R', 'P', 'Y']

    for episode in tqdm(range(n_eval_episodes)):
        # Every episode runs for a fixed number of periods; the step's
        # reward, done flag and info are not used here.
        for _ in range(env.num_periods):
            # predict() returns a tuple whose first element is the action.
            action = model.predict(obs)[0]
            obs, _, _, _ = env.step(action)

        # Done with the episode, so dump its tables into a per-episode folder.
        # exist_ok=True makes this safe to re-run and avoids the
        # check-then-create race of a separate os.path.exists() test.
        episode_dir = save_path + f'eval/{episode}/'
        os.makedirs(episode_dir, exist_ok=True)

        for df_name in df_names:
            getattr(env, df_name).to_csv(episode_dir + f'{df_name}.csv')

        # Reset the env for the next episode and select a new seed.
        # NOTE(review): seed_int is assigned after reset(); if reset() is what
        # consumes seed_int, the new seed only takes effect one episode late.
        # Confirm the intended order against the environment implementation.
        obs = env.reset()
        env.seed_int = episode


# Evaluate the best default TRPO model for ten episodes.
run_evals(
    env_name='NetworkManagement-v1-100',
    algo_name='TRPO',
    name='default',
    n_eval_episodes=10,
)
            

100%|██████████| 10/10 [00:05<00:00,  1.67it/s]
