In [2]:
import numpy as np
import matplotlib.pyplot as plt
import or_gym
import os

from common import make_env

from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor, SubprocVecEnv
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.env_util import make_vec_env

from stable_baselines3 import SAC
from stable_baselines3.sac.policies import MlpPolicy as SACPolicy

from stable_baselines3 import A2C
from stable_baselines3.a2c.policies import MlpPolicy as A2CPolicy

from stable_baselines3 import PPO
from stable_baselines3.ppo.policies import MlpPolicy as PPOPolicy

from sb3_contrib import ARS
from sb3_contrib.ars.policies import ARSPolicy

from sb3_contrib import RecurrentPPO
from sb3_contrib.ppo_recurrent.policies import RecurrentActorCriticPolicy

from sb3_contrib import TQC
from sb3_contrib.tqc.policies import MlpPolicy as TQCPolicy

from sb3_contrib import TRPO
from sb3_contrib.trpo.policies import MlpPolicy as TRPOPolicy

In [3]:
def train_model_on_env(env_name, algo_name, name, n_envs=1, timesteps=int(1e5), eval_freq=int(5e3), env_seed=0, load_existing=True):
    """Train (or resume training) an SB3 / sb3-contrib agent on an environment from ``make_env``.

    Everything for a run lives in ``./data/{env_name}/{algo_name}/{name}/``. That includes the
    training Monitor log, plus the ``best_model.zip`` checkpoint and the evaluation log that
    ``EvalCallback`` writes there.

    Args:
        env_name: Environment id passed to ``common.make_env``.
        algo_name: One of 'PPO', 'RecurrentPPO', 'A2C', 'ARS', 'SAC', 'TQC', 'TRPO'.
        name: Run name, used as the last directory component of the save path.
        n_envs: Number of training environments. 1 -> a single ``Monitor``-wrapped env;
            more -> a ``SubprocVecEnv`` (one worker process per env) wrapped in ``VecMonitor``.
        timesteps: Total environment steps passed to ``model.learn``.
        eval_freq: Evaluate the policy roughly every ``eval_freq`` total timesteps
            (a value <= 0 disables periodic evaluation).
        env_seed: Seed passed to ``set_random_seed`` before any env or model is built.
        load_existing: If truthy, resume from ``best_model.zip`` when it exists and ask the
            monitor to append to its existing log. Otherwise, start a fresh model and
            overwrite the log.

    Returns:
        The trained model. Its environments are closed before returning, so attach a new
        one with ``model.set_env(...)`` before calling ``learn`` again.

    Raises:
        ValueError: If ``algo_name`` is not one of the supported algorithms.
    """
    # algo name -> (model class, policy class, rollout-size kwargs its constructor is given)
    algorithms = {
        'PPO': (PPO, PPOPolicy, ('n_steps', 'batch_size')),
        'RecurrentPPO': (RecurrentPPO, RecurrentActorCriticPolicy, ('n_steps', 'batch_size')),
        'A2C': (A2C, A2CPolicy, ('n_steps',)),
        'ARS': (ARS, ARSPolicy, ()),
        'SAC': (SAC, SACPolicy, ('batch_size',)),
        'TQC': (TQC, TQCPolicy, ('batch_size',)),
        'TRPO': (TRPO, TRPOPolicy, ('n_steps', 'batch_size')),
    }
    # Validate before creating any environments or worker processes.
    if algo_name not in algorithms:
        raise ValueError(f'Unsupported algo_name {algo_name!r}; expected one of {sorted(algorithms)}')
    algo_cls, policy_cls, size_kwargs = algorithms[algo_name]

    set_random_seed(env_seed)

    save_path = f'./data/{env_name}/{algo_name}/{name}/'
    model_path = os.path.join(save_path, 'best_model.zip')

    if os.path.exists(save_path):
        print(f'Using existing directory {save_path}')
    else:
        print(f'Creating new directory {save_path}')
        # Actually create it so Monitor writes <save_path>/monitor.csv on the first run as well.
        os.makedirs(save_path, exist_ok=True)

    def env_factory():
        # Must *return* the env: each SubprocVecEnv worker builds its env by calling this.
        return make_env(env_name)

    # Append to the existing monitor log when resuming; overwrite it otherwise.
    override_existing = not load_existing

    env = eval_env = None
    try:
        if n_envs == 1:
            env = Monitor(env_factory(), save_path, override_existing=override_existing)
            # Evaluate on a separate env. Evaluation resets the env it is given, so reusing the
            # training env would interrupt the training episode in progress.
            eval_env = Monitor(env_factory())
        else:
            # NOTE(review): confirm the installed SB3's VecMonitor accepts `override_existing`.
            env = VecMonitor(SubprocVecEnv([env_factory for _ in range(n_envs)]), save_path,
                             override_existing=override_existing)
            # Same wrapper type as the training env so EvalCallback's env-type check is satisfied.
            eval_env = VecMonitor(DummyVecEnv([env_factory]))

        if load_existing and os.path.isfile(model_path):
            print('Loading existing model...')
            model = algo_cls.load(model_path, env)
        else:
            # A VecEnv does not expose its sub-envs' attributes directly; query them via get_attr.
            num_periods = env.num_periods if n_envs == 1 else env.get_attr('num_periods')[0]
            rollout_sizes = {'n_steps': num_periods, 'batch_size': num_periods * n_envs}
            model = algo_cls(policy_cls, env, **{key: rollout_sizes[key] for key in size_kwargs})

        # EvalCallback's eval_freq counts callback calls, and each call advances all n_envs
        # sub-envs by one step. Rescale so eval_freq stays in units of total timesteps.
        eval_calls = int(eval_freq)
        if eval_calls > 0:
            eval_calls = max(eval_calls // n_envs, 1)
        eval_callback = EvalCallback(eval_env, best_model_save_path=save_path, verbose=1, log_path=save_path,
                                     eval_freq=eval_calls, deterministic=True, render=False)

        model.learn(total_timesteps=int(timesteps), callback=eval_callback)
    finally:
        # Release env resources (monitor file handles, SubprocVecEnv worker processes),
        # including when training is interrupted.
        for opened in (eval_env, env):
            if opened is not None:
                opened.close()

    return model

# Train TRPO on the 100-period network-management env: 4M steps, evaluating every 5k steps.
train_model_on_env(
    env_name='NetworkManagement-v1-100',
    algo_name='TRPO',
    name='default',
    n_envs=1,
    timesteps=4_000_000,
    eval_freq=5_000,
);
    

Using existing directory ./data/NetworkManagement-v1-100/TRPO/default/
Loading existing model...
Eval num_timesteps=5000, episode_reward=2550.23 +/- 15.49
Episode length: 100.00 +/- 0.00
New best mean reward!
Eval num_timesteps=10000, episode_reward=2199.96 +/- 7.23
Episode length: 100.00 +/- 0.00
Eval num_timesteps=15000, episode_reward=2053.74 +/- 5.48
Episode length: 100.00 +/- 0.00
Eval num_timesteps=20000, episode_reward=2547.31 +/- 140.30
Episode length: 100.00 +/- 0.00
Eval num_timesteps=25000, episode_reward=1589.08 +/- 3.37
Episode length: 100.00 +/- 0.00
Eval num_timesteps=30000, episode_reward=1832.96 +/- 7.92
Episode length: 100.00 +/- 0.00
Eval num_timesteps=35000, episode_reward=1892.72 +/- 3.46
Episode length: 100.00 +/- 0.00
Eval num_timesteps=40000, episode_reward=1582.47 +/- 1.62
Episode length: 100.00 +/- 0.00
Eval num_timesteps=45000, episode_reward=1452.99 +/- 2.12
Episode length: 100.00 +/- 0.00
Eval num_timesteps=50000, episode_reward=1900.89 +/- 2.62
Episode len

KeyboardInterrupt: 