# Advanced Features of Eve
## Making learning interesting

## NAS with RL

Now, let's try some more interesting things with **Eve**.

Let's begin with an MNIST example that uses the DDPG (Deep Deterministic Policy Gradient) method to search for the best
quantization bit width of the network.

First, let's import the necessary packages.

In [1]:
import argparse
import difflib
import importlib
import os
import sys
import uuid

import gym
import numpy as np

# `eve.rl.envs` is never referenced directly below, but importing it is
# required: the import registers Eve's environments (presumably including
# FixedNas-v0) in the global gym registry checked in a later cell.
import eve.rl.envs

import seaborn
import torch
from pprint import pprint
from eve.rl.exp_manager import ExperimentManager
from eve.rl.utils.utils import ALGOS, StoreDict
from stable_baselines3.common.utils import set_random_seed

# Apply seaborn's default theme globally.
# NOTE(review): no plots are drawn in the cells shown here; confirm this is still needed.
seaborn.set()

Define the hyperparameters for the experiment.

In [2]:
# Machine-specific locations used by this example. Override them with the
# EVE_EXAMPLES_DIR / EVE_DATASET_ROOT environment variables (or edit the
# fallbacks) to run the notebook on another machine.
EVE_EXAMPLES_DIR = os.environ.get(
    "EVE_EXAMPLES_DIR", "/media/densechen/data/code/eve/examples")
EVE_DATASET_ROOT = os.environ.get(
    "EVE_DATASET_ROOT", "/media/densechen/data/dataset")


def _str2bool(value):
    """Convert a command-line string such as "true", "False" or "0" to a bool.

    argparse's ``type=bool`` is a trap: ``bool("False")`` is True, so any
    explicitly passed value would enable the option. This parser accepts the
    usual spellings and rejects everything else.

    Args:
        value: the raw option value (a bool is passed through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--algo",
                    help="RL algorithm used for NAS searching.",
                    default="ddpg",
                    type=str,
                    required=False,
                    choices=list(ALGOS.keys()),
                   )
parser.add_argument("--env",
                    help="The environment used to wrap the trainer. "
                         "Different environments apply different "
                         "reward functions and interaction steps.",
                    default="FixedNas-v0",
                    type=str,
                    required=False,
                   )
parser.add_argument("-tb",
                    "--tensorboard-log",
                    help="Tensorboard log dir.",
                    default=f"{EVE_EXAMPLES_DIR}/logs/",
                    type=str,
                   )
parser.add_argument("-i",
                    "--trained_agent",
                    help="Path to a pretrained agent to continue training",
                    default="",
                    type=str,
                   )
# `type=bool` would turn any explicit value (even "False") into True, so the
# flag value is parsed explicitly.
parser.add_argument("--truncate-last-trajectory",
                    help="When using HER with online sampling the last "
                         "trajectory in the replay buffer will be truncated "
                         "after reloading the replay buffer.",
                    default=True,
                    type=_str2bool,
                   )
parser.add_argument("-n",
                    "--n-timesteps",
                    help="Overwrite the number of timesteps",
                    default=-1,
                    type=int,
                   )
parser.add_argument("--num-threads",
                    help="Number of threads for PyTorch (-1 to use default)",
                    default=-1,
                    type=int,
                   )
parser.add_argument("--log-interval",
                    help="Overwrite log interval (default: -1, no change)",
                    default=-1,
                    type=int,
                   )
parser.add_argument("--eval-freq",
                    help="Evaluate the agent every n steps (if negative, no evaluation)",
                    default=10000,
                    type=int,
                   )
parser.add_argument("--eval-episodes",
                    help="Number of episodes to use for evaluation",
                    default=5,
                    type=int,
                   )
parser.add_argument("--save-freq",
                    help="Save the model every n steps (if negative, no checkpoint)",
                    default=-1,
                    type=int,
                   )
parser.add_argument("--save-replay-buffer",
                    help="Save the replay buffer too (when applicable)",
                    action="store_true",
                    default=False,
                   )
parser.add_argument("-f",
                    "--log-folder",
                    help="Log folder",
                    type=str,
                    default="logs",
                   )
parser.add_argument("--seed",
                    help="Random generator seed",
                    type=int,
                    default=-1,
                   )
parser.add_argument("--vec-env",
                    help="VecEnv type",
                    type=str,
                    default="dummy",
                    choices=["dummy", "subproc"],
                   )
parser.add_argument("--n-trials",
                    help="Number of trials for optimizing hyperparameters",
                    type=int,
                    default=10,
                   )
parser.add_argument("-optimize",
                    "--optimize-hyperparameters",
                    action="store_true",
                    default=False,
                    help="Run hyperparameters search",
                   )
parser.add_argument("--n-jobs",
                    help="Number of parallel jobs when optimizing hyperparameters",
                    type=int,
                    default=1,
                   )
parser.add_argument("--sampler",
                    help="Sampler to use when optimizing hyperparameters",
                    type=str,
                    default="tpe",
                    choices=["random", "tpe", "skopt"],
                   )
parser.add_argument("--pruner",
                    help="Pruner to use when optimizing hyperparameters",
                    type=str,
                    default="median",
                    choices=["halving", "median", "none"],
                   )
parser.add_argument("--n-startup-trials",
                    help="Number of trials before using optuna sampler",
                    type=int,
                    default=10,
                   )
parser.add_argument("--n-evaluations",
                    help="Number of evaluations for hyperparameter optimization",
                    type=int,
                    default=20,
                   )
parser.add_argument("--storage",
                    help="Database storage path if distributed optimization should be used",
                    type=str,
                    default=None,
                   )
parser.add_argument("--study-name",
                    help="Study name for distributed optimization",
                    type=str,
                    default=None,
                   )
parser.add_argument("--verbose",
                    help="Verbose mode (0: no output, 1: INFO)",
                    default=0,
                    type=int,
                   )
parser.add_argument("--gym-packages",
                    type=str,
                    nargs="+",
                    default=[],
                    help="Additional eve Gym environment package modules to import "
                         "(e.g. gym_minigrid)",
                   )
parser.add_argument("--env-kwargs",
                    type=str,
                    nargs="+",
                    action=StoreDict,
                    help="Optional keyword argument to pass to the env constructor. "
                         "Ignored: env_kwargs is overwritten manually below.",
                   )
parser.add_argument("-params",
                    "--hyperparams",
                    type=str,
                    nargs="+",
                    action=StoreDict,
                    help="Overwrite hyperparameter (e.g. learning_rate:0.01)",
                   )
parser.add_argument("-uuid",
                    "--uuid",
                    action="store_true",
                    default=False,
                    help="Ensure that the run has a unique ID.",
                   )

# Inside a Jupyter kernel, sys.argv holds the kernel launcher's own arguments
# (typically "-f <connection-file>.json"), and "-f" collides with
# --log-folder. Parse an empty argument list there so that only the defaults
# above apply. When run as a plain script, read the real command line.
args = parser.parse_args([] if "ipykernel" in sys.modules else None)

# Create env_kwargs here: these parameters define the trainer wrapped by the env.
args.env_kwargs = dict(
    trainer_id="mnist",
    checkpoint_path=f"{EVE_EXAMPLES_DIR}/mnist.pt",
    max_timesteps=1,  # keep 1 for non-spiking mode.
    data_kwargs={
        "root": EVE_DATASET_ROOT,
    },
    kwargs={
        "device": "cuda:0",
        "root_dir": EVE_EXAMPLES_DIR,
    },
)

# Rewrite the log folder with an absolute path.
args.log_folder = f"{EVE_EXAMPLES_DIR}/logs"

pprint(args)

Namespace(algo='ddpg', env='FixedNas-v0', env_kwargs={'trainer_id': 'mnist', 'checkpoint_path': '/media/densechen/data/code/eve/examples/mnist.pt', 'max_timesteps': 1, 'data_kwargs': {'root': '/media/densechen/data/dataset'}, 'kwargs': {'device': 'cuda:0', 'root_dir': '/media/densechen/data/code/eve/examples'}}, eval_episodes=5, eval_freq=10000, gym_packages=[], hyperparams=None, log_folder='/media/densechen/data/code/eve/examples/logs', log_interval=-1, n_evaluations=20, n_jobs=1, n_startup_trials=10, n_timesteps=-1, n_trials=10, num_threads=-1, optimize_hyperparameters=False, pruner='median', sampler='tpe', save_freq=-1, save_replay_buffer=False, seed=-1, storage=None, study_name=None, tensorboard_log='/media/densechen/data/code/eve/examples/logs/', trained_agent='', truncate_last_trajectory=True, uuid=False, vec_env='dummy', verbose=0)


In [3]:
# Import any extra gym environment packages so that they register themselves
# in the global gym registry before `args.env` is looked up.
for env_module in args.gym_packages:
    importlib.import_module(env_module)

env_id = args.env
# Older gym versions expose the specs via `registry.env_specs`; newer ones
# make the registry itself a mapping of env ids to specs. Support both.
_gym_registry = gym.envs.registry
registered_envs = set(getattr(_gym_registry, "env_specs", _gym_registry).keys())
# If the environment is not found, suggest the closest match.
if env_id not in registered_envs:
    close_matches = difflib.get_close_matches(env_id, registered_envs, n=1)
    closest_match = close_matches[0] if close_matches else "no close match found..."
    # f-string (the original raw string never interpolated these placeholders).
    raise ValueError(
        f"{env_id} not found in gym registry, you maybe meant {closest_match}"
    )

# Unique id to ensure there is no race condition for the folder creation.
uuid_str = f"_{uuid.uuid4()}" if args.uuid else ""
if args.seed < 0:
    # No seed given: draw a random one so that it can still be printed and reused.
    args.seed = np.random.randint(2 ** 32 - 1, dtype="int64").item()

set_random_seed(args.seed)

# Setting num threads to 1 makes things run faster on cpu.
if args.num_threads > 0:
    if args.verbose > 0:
        print(f"Setting torch.num_threads to {args.num_threads}")
    # Applied regardless of verbosity; only the log message depends on --verbose.
    torch.set_num_threads(args.num_threads)

if args.trained_agent != "":
    assert args.trained_agent.endswith(".zip") and os.path.isfile(args.trained_agent), \
        "The trained_agent must be a valid path to a .zip file."
print("=" * 10, env_id, "=" * 10)
print(f"Seed: {args.seed}")

Seed: 1460492072


Define the ExperimentManager

In [4]:
# Build the ExperimentManager that the following cells drive through
# setup_experiment(), learn(), save_trained_model() and
# hyperparameters_optimization().
# NOTE(review): the first 19 arguments are passed positionally, so their order
# must match ExperimentManager's signature in eve.rl.exp_manager. Confirm it
# before reordering or inserting arguments; passing them by keyword would make
# this call robust to signature changes.
exp_manager = ExperimentManager(
    args,
    args.algo,
    env_id,  # checked against the gym registry in the previous cell
    args.log_folder,
    args.tensorboard_log,
    args.n_timesteps,
    args.eval_freq,
    args.eval_episodes,
    args.save_freq,
    args.hyperparams,
    args.env_kwargs,  # overwritten manually after argument parsing
    args.trained_agent,
    args.optimize_hyperparameters,
    args.storage,
    args.study_name,
    args.n_trials,
    args.n_jobs,
    args.sampler,
    args.pruner,
    n_startup_trials=args.n_startup_trials,
    n_evaluations=args.n_evaluations,
    truncate_last_trajectory=args.truncate_last_trajectory,
    uuid_str=uuid_str,
    seed=args.seed,
    log_interval=args.log_interval,
    save_replay_buffer=args.save_replay_buffer,
    verbose=args.verbose,
    vec_env_type=args.vec_env,
)

Launch the trainer

In [5]:
# Create the model for this experiment. The training cell treats a `None`
# result as "run hyperparameter optimization instead of training".
# NOTE(review): presumably None is returned when --optimize-hyperparameters is
# set; confirm against ExperimentManager.setup_experiment.
model = exp_manager.setup_experiment()

("making new trainer: mnist ({'checkpoint_path': "
 "'/media/densechen/data/code/eve/examples/mnist.pt', 'max_timesteps': 1, "
 "'data_kwargs': {'root': '/media/densechen/data/dataset'}, 'kwargs': "
 "{'device': 'cuda:0', 'root_dir': "
 "'/media/densechen/data/code/eve/examples'}})")
original accuracy: 0.893690664556962
create an upgrader automatically
("making new trainer: mnist ({'checkpoint_path': "
 "'/media/densechen/data/code/eve/examples/mnist.pt', 'max_timesteps': 1, "
 "'data_kwargs': {'root': '/media/densechen/data/dataset'}, 'kwargs': "
 "{'device': 'cuda:0', 'root_dir': "
 "'/media/densechen/data/code/eve/examples'}})")
original accuracy: 0.893690664556962
create an upgrader automatically
Applying normal noise with std 0.1
Log path: /media/densechen/data/code/eve/examples/logs/ddpg/FixedNas-v0_2


In [6]:
# Normal training
if model is not None:
    exp_manager.learn(model)
    exp_manager.save_trained_model(model)
else:
    exp_manager.hyperparameters_optimization()

original acc: 0.893690664556962  vs. rl acc: 0.8918117088607594
Saving to /media/densechen/data/code/eve/examples/logs/ddpg/FixedNas-v0_2


Go to the TensorBoard log folder and run `tensorboard --logdir .` to see the training log.