<a href="https://colab.research.google.com/github/iskra3138/stable-baselines/blob/main/Cell_Based_Stable_baselines_A2C_HPO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TF 1.X 선택 및 필요한 Library 설치

In [None]:
# Stable Baselines only supports tensorflow 1.x for now
# (Colab-only magic: selects the TensorFlow 1.x runtime for this notebook.)
%tensorflow_version 1.x

# Shell commands (IPython `!` syntax) that install this notebook's dependencies.
# swig: presumably required to build a native dependency — confirm which one.
!apt install swig
# Stable Baselines pinned to 2.10.0, with the optional MPI extras.
!pip install stable-baselines[mpi]==2.10.0
# Optuna supplies the study, samplers and pruners used for the search.
!pip install optuna

# 필요 모듈 호출

In [None]:
import os
import numpy as np
import gym
import time

import optuna
from optuna.pruners import SuccessiveHalvingPruner, MedianPruner
from optuna.samplers import RandomSampler, TPESampler
from optuna.integration.skopt import SkoptSampler

from stable_baselines import A2C
#from stable_baselines.common.noise import AdaptiveParamNoiseSpec, NormalActionNoise, OrnsteinUhlenbeckActionNoise
from stable_baselines.common.vec_env import VecNormalize, DummyVecEnv, VecEnv
from stable_baselines.common import set_global_seeds
#from stable_baselines.bench import Monitor
from stable_baselines.common import make_vec_env

# <font color='red'> Arguments 입력 </font>

In [None]:
# --- Algorithm and environment under tuning ---
algo = 'a2c'
env_id = 'CartPole-v1'

# --- Search budget and strategy ---
n_trials = 1000            # upper bound on the number of trials in the study
n_timesteps = 50000        # training timesteps allotted to each trial
sampler_method = 'tpe'     # one of 'random', 'tpe', 'skopt'
pruner_method = 'median'   # one of 'halving', 'median', 'none'

# --- Parallelism and reproducibility ---
n_jobs = 2   # passed as n_jobs to study.optimize (trials run concurrently)
seed = 0     # base seed for samplers and environments

n_envs = 1   # training environments per trial

# --- Output locations ---
tensorboard_log = './tb_log'  # TensorBoard log directory
log_folder = './log'          # folder where the HPO report CSV is written

# --- Evaluation schedule ---
n_startup_trials = 10   # startup-trial count given to the TPE sampler and median pruner
n_eval_episodes = 5     # episodes played per evaluation
n_evaluations = 20      # evaluations performed during each trial

# Timesteps between two evaluations, so that a trial is evaluated n_evaluations times.
eval_freq = n_timesteps // n_evaluations

# <font color='red'> HPO 탐색공간 입력 </font>

In [None]:
# n_steps is pinned to a single value so the search looks for the best
# hyperparameters at a fixed batch size.

def sample_a2c_params(trial):
    """
    Sample A2C hyperparameters for one Optuna trial.

    :param trial: (optuna.trial.Trial) trial used to draw the values
    :return: (dict) keyword arguments for ``A2C(...)``: ``n_steps``, ``gamma``,
        ``learning_rate``, ``lr_schedule``, ``ent_coef`` and ``vf_coef``
    """
    gamma = trial.suggest_categorical('gamma', [0.9, 0.95, 0.98, 0.99, 0.995, 0.999, 0.9999])
    # Wider choice previously searched: [8, 16, 32, 64, 128, 256, 512, 1024, 2048]
    n_steps = trial.suggest_categorical('n_steps', [128])
    lr_schedule = trial.suggest_categorical('lr_schedule', ['linear', 'constant'])
    # suggest_float(..., log=True) replaces the deprecated suggest_loguniform /
    # suggest_uniform; the parameter names stored in the study stay the same.
    learning_rate = trial.suggest_float('lr', 1e-5, 1, log=True)
    ent_coef = trial.suggest_float('ent_coef', 0.00000001, 0.1, log=True)
    vf_coef = trial.suggest_float('vf_coef', 0, 1)
    # normalize = trial.suggest_categorical('normalize', [True, False])
    # TODO: take into account the normalization (also for the test env)

    return {
        'n_steps': n_steps,
        'gamma': gamma,
        'learning_rate': learning_rate,
        'lr_schedule': lr_schedule,
        'ent_coef': ent_coef,
        'vf_coef': vf_coef,
    }

# 필요한 Class, 함수 정의

In [None]:
from stable_baselines.common.callbacks import BaseCallback, EvalCallback
class TrialEvalCallback(EvalCallback):
    """
    EvalCallback that reports every evaluation to an Optuna trial and stops
    training when the trial's pruner asks for it.

    :param eval_env: environment used for evaluation
    :param trial: (optuna.trial.Trial) trial receiving the intermediate values
    :param n_eval_episodes: (int) episodes played per evaluation
    :param eval_freq: (int) evaluate every ``eval_freq`` callback calls
    :param deterministic: (bool) whether evaluation uses deterministic actions
    :param verbose: (int) verbosity level
    """

    def __init__(self, eval_env, trial, n_eval_episodes=5,
                 eval_freq=10000, deterministic=True, verbose=0):
        super(TrialEvalCallback, self).__init__(
            eval_env=eval_env,
            n_eval_episodes=n_eval_episodes,
            eval_freq=eval_freq,
            deterministic=deterministic,
            verbose=verbose,
        )
        self.trial = trial
        self.eval_idx = 0       # evaluations reported so far; used as the report step
        self.is_pruned = False  # becomes True once the pruner stops this trial

    def _on_step(self):
        evaluation_due = self.eval_freq > 0 and self.n_calls % self.eval_freq == 0
        if not evaluation_due:
            return True

        # The parent callback runs the evaluation and updates last_mean_reward.
        super(TrialEvalCallback, self)._on_step()
        self.eval_idx += 1
        # Report the negated reward, matching the cost the objective returns.
        # Open questions: report best vs. current reward, and timesteps or
        # elapsed time vs. evaluation index as the step.
        self.trial.report(-self.last_mean_reward, self.eval_idx)

        if not self.trial.should_prune():
            return True
        self.is_pruned = True
        # Returning False asks the model to stop learning early.
        return False

In [None]:
def make_env(env_id, rank, seed=0):
    """
    Return a zero-argument factory that builds one seeded gym environment.

    The global RNGs are seeded with ``seed`` right away, when the factory is
    created; the environment itself is only built when the factory is called.

    :param env_id: (str) gym environment ID given to ``gym.make``
    :param rank: (int) index of this copy among the vectorized environments;
        added to ``seed`` so every copy gets its own environment seed
    :param seed: (int) base seed for the random number generators
    :return: (callable) factory returning the seeded environment
    """
    set_global_seeds(seed)

    def _init():
        environment = gym.make(env_id)
        # A distinct seed per copy keeps parallel environments decorrelated.
        environment.seed(seed + rank)
        return environment

    return _init

# HPO 탐색 사전 준비

In [None]:
def _build_sampler(method):
    """Return the Optuna sampler selected by ``method`` ('random', 'tpe' or 'skopt')."""
    factories = {
        'random': lambda: RandomSampler(seed=seed),
        'tpe': lambda: TPESampler(n_startup_trials=n_startup_trials, seed=seed),
        # cf https://scikit-optimize.github.io/#skopt.Optimizer
        # base_estimator "GP" = Gaussian process ("GBRT" would be gradient boosted regression)
        'skopt': lambda: SkoptSampler(skopt_kwargs={'base_estimator': "GP", 'acq_func': 'gp_hedge'}),
    }
    if method not in factories:
        raise ValueError('Unknown sampler: {}'.format(method))
    return factories[method]()


def _build_pruner(method):
    """Return the Optuna pruner selected by ``method`` ('halving', 'median' or 'none')."""
    # n_warmup_steps: pruning stays disabled until a trial reaches that many reported steps.
    if method == 'halving':
        return SuccessiveHalvingPruner(min_resource=1, reduction_factor=4, min_early_stopping_rate=0)
    if method == 'median':
        return MedianPruner(n_startup_trials=n_startup_trials, n_warmup_steps=n_evaluations // 3)
    if method == 'none':
        # Do not prune: the startup phase spans all n_trials trials.
        return MedianPruner(n_startup_trials=n_trials, n_warmup_steps=n_evaluations)
    raise ValueError('Unknown pruner: {}'.format(method))


sampler = _build_sampler(sampler_method)
pruner = _build_pruner(pruner_method)

study = optuna.create_study(sampler=sampler, pruner=pruner)
algo_sampler = sample_a2c_params

# Objective Fn. 정의

In [None]:
# Objective function evaluated by Optuna for every trial.
def objective(trial):
    """
    Train an A2C agent with hyperparameters sampled for ``trial`` and score it.

    :param trial: (optuna.trial.Trial) the trial being evaluated
    :return: (float) negated mean reward of the last evaluation (the study
        minimizes, so a higher reward gives a lower cost)
    :raises optuna.exceptions.TrialPruned: when the pruner stops the trial, or
        when training raises AssertionError (e.g. NaNs from bad hyperparameters)
    """
    # Hyperparameters drawn from the search space become A2C keyword arguments.
    kwargs = dict(algo_sampler(trial))

    # Training environments; copy i is seeded with seed + i (covers n_envs == 1 too).
    env = DummyVecEnv([make_env(env_id, i, seed) for i in range(n_envs)])

    model = A2C('MlpPolicy', env=env, tensorboard_log=tensorboard_log, verbose=0, **kwargs)

    # Separate single-environment copy used only for evaluation.
    eval_env = DummyVecEnv([make_env(env_id, 0, seed)])

    # Account for parallel envs: eval_freq is expressed in total timesteps,
    # presumably while the callback counts vectorized steps — hence the division.
    eval_freq_ = eval_freq
    if isinstance(model.get_env(), VecEnv):
        eval_freq_ = max(eval_freq // model.get_env().num_envs, 1)
    # TODO: use non-deterministic eval for Atari?
    eval_callback = TrialEvalCallback(eval_env, trial, n_eval_episodes=n_eval_episodes,
                                      eval_freq=eval_freq_, deterministic=True)

    try:
        model.learn(n_timesteps, callback=eval_callback)
    except AssertionError as err:
        # Sometimes, random hyperparams can generate NaN: treat as pruned.
        raise optuna.exceptions.TrialPruned() from err
    finally:
        # Free the environments whatever happened during training.
        model.env.close()
        eval_env.close()

    is_pruned = eval_callback.is_pruned
    cost = -1 * eval_callback.last_mean_reward

    # Drop references so memory can be reclaimed between trials.
    del model, env, eval_env

    if is_pruned:
        raise optuna.exceptions.TrialPruned()

    return cost

# <font color='blue'> 탐색 </font>

In [None]:
# Run the search; interrupting the kernel stops it early but keeps finished trials.
try:
    study.optimize(objective, n_trials=n_trials, n_jobs=n_jobs)
except KeyboardInterrupt:
    pass

# study.trials lists every recorded trial, including pruned and failed ones.
print('Number of finished trials: ', len(study.trials))

print('Best trial:')
try:
    trial = study.best_trial
except ValueError:
    # best_trial raises ValueError when no trial has completed yet (e.g. every
    # trial was pruned or the search was interrupted immediately); still fall
    # through so the trials table below is built for the save cell.
    print('No completed trial available.')
else:
    print('Value: ', trial.value)

    print('Params: ')
    for key, value in trial.params.items():
        print('    {}: {}'.format(key, value))

# Per-trial table written to CSV by the next cell.
data_frame = study.trials_dataframe()

# 결과 저장

In [None]:
# The file name records the environment, budget, sampler/pruner METHOD NAMES and a
# Unix timestamp. The method strings are used rather than the sampler/pruner objects,
# whose reprs ("<... object at 0x...>") would put spaces and angle brackets in the name.
report_name = "report_{}_{}-trials-{}-{}-{}_{}.csv".format(env_id, n_trials, n_timesteps,
                                                        sampler_method, pruner_method,
                                                        int(time.time()))
log_path = os.path.join(log_folder, algo, report_name)
print("Writing report to {}".format(log_path))

# Create <log_folder>/<algo>/ if needed, then write the trials table.
os.makedirs(os.path.dirname(log_path), exist_ok=True)
data_frame.to_csv(log_path)

# 결과 다운로드

In [None]:
# Colab helper that sends a file from the runtime to the browser.
from google.colab import files
# Download the CSV report written by the previous cell.
files.download(log_path)