# NATS-Bench (same topology search space as NAS-Bench-201)

In [None]:
import os
import sys

import gym
import ray
from ray import tune
#from ray.rllib.agents.ppo import ppo 
from ray.rllib.algorithms.ppo import PPO, PPOConfig
from ray.rllib.algorithms.callbacks import MultiCallbacks
from ray.tune.registry import register_env

from ray.air.callbacks.wandb import WandbLoggerCallback

# Make the local `nas-bench-envs` checkout importable in this kernel (sys.path)
# and, via PYTHONPATH, presumably in the Ray worker processes started later,
# which do not share this kernel's sys.path.
module_path = os.path.abspath('nas-bench-envs')
if module_path not in sys.path:
    sys.path.append(module_path)
# Extend any existing PYTHONPATH instead of overwriting it, and do so even when
# the path was already on sys.path, so worker processes can still find it.
_existing_pythonpath = [p for p in os.environ.get('PYTHONPATH', '').split(os.pathsep) if p]
if module_path not in _existing_pythonpath:
    os.environ['PYTHONPATH'] = os.pathsep.join([*_existing_pythonpath, module_path])
from nas_bench_envs.envs.nas_bench_201_envs import NasBench201, NasBench201Clusters
from nas_bench_envs.callbacks import MetricsCallbacks


#### Test NatsBench

In [None]:
from nats_bench import create

# Open the NATS-Bench topology search space ("tss") in fast mode, then print
# three views of one architecture on CIFAR-10, separated by blank lines.
api = create("/scratch2/sem22hs2/NATS-tss-v1_0-3ffb9-simple", 'tss', fast_mode=True, verbose=False)
arch_index, dataset_name = 1234, 'cifar10'
for position, query in enumerate((api.get_more_info, api.get_latency, api.get_cost_info)):
    if position:
        print()
    print(query(arch_index, dataset_name))

## Train using RLlib

#### Whole NasBench Dataset

In [None]:
def env_creator(env_config):
    """Build a ``NasBench201`` environment from RLlib's ``env_config`` mapping."""
    environment = NasBench201(env_config)
    return environment


# Register the creator under a string id so RLlib can build the env by name.
select_env = "NasBench201"
register_env(select_env, env_creator)

In [None]:
# PPO on the whole NAS-Bench-201 search space, logged to Weights & Biases.
config = (
    PPOConfig()
    .framework("torch")
    .resources(num_gpus=1, num_cpus_per_worker=1)
    .environment(env=NasBench201, render_env=False)
    .rollouts(horizon=1000, num_rollout_workers=8)
    .callbacks(MetricsCallbacks)
)
ray.init(ignore_reinit_error=True)
try:
    tune.run(
        PPO,
        config=config.to_dict(),
        stop={"training_iteration": 100},
        # Never commit API keys. With no key passed, WandbLoggerCallback reads
        # the WANDB_API_KEY environment variable (or, presumably, an existing
        # `wandb login`). The key that was previously hard-coded here should be
        # revoked.
        callbacks=[WandbLoggerCallback(project="RayNasBenchV0")],
    )
finally:
    # Release the Ray cluster even if training fails or is interrupted.
    ray.shutdown()

#### On selected clusters

In [None]:
# Shut down any Ray session still running from a previous cell before the
# cluster experiments below call ray.init again.
ray.shutdown()

In [None]:
# PPO restricted to one NAS-Bench-201 cluster (cluster 11, CIFAR-100),
# logged to Weights & Biases.
env_config = {"cluster": 11, "network_init": "cluster", "dataset": "cifar100"}
config = (
    PPOConfig()
    .framework("torch")
    .resources(num_gpus=1, num_cpus_per_worker=1)
    .environment(env=NasBench201Clusters, render_env=False, env_config=env_config)
    .rollouts(horizon=1000, num_rollout_workers=8)
    # Optionally add: .reporting(keep_per_episode_custom_metrics=True)
    .callbacks(MetricsCallbacks)
)
ray.init(ignore_reinit_error=True)
try:
    tune.run(
        PPO,
        config=config.to_dict(),
        stop={"training_iteration": 100},
        # Never commit API keys. With no key passed, WandbLoggerCallback reads
        # the WANDB_API_KEY environment variable (or, presumably, an existing
        # `wandb login`). The key that was previously hard-coded here should be
        # revoked.
        callbacks=[WandbLoggerCallback(project="RayNasBenchClustersV0")],
    )
finally:
    # Release the Ray cluster even if training fails or is interrupted.
    ray.shutdown()

In [None]:
from ray.tune.schedulers import PopulationBasedTraining

# Base PPO config for the PBT run: cluster 11 of NAS-Bench-201 on CIFAR-10.
env_config = dict(cluster=11, network_init="cluster", dataset="cifar10")
config = PPOConfig()
config = config.framework("torch")
config = config.resources(num_gpus=1, num_cpus_per_worker=1)
config = config.environment(env=NasBench201Clusters, render_env=False, env_config=env_config)
config = config.rollouts(horizon=1000, num_rollout_workers=8)
config = config.callbacks(MetricsCallbacks)


# Postprocess the perturbed config to ensure it's still valid
def explore(config):
    """Repair a PBT-perturbed PPO config dict so it remains trainable.

    Mutates ``config`` in place and returns the same dict object:

    * ``train_batch_size`` is raised to at least twice ``sgd_minibatch_size``,
      so each training batch yields at least two SGD minibatches.
    * ``num_sgd_iter`` is raised to at least 1, so SGD runs at least once.
    """
    current_batch = config["train_batch_size"]
    min_batch = 2 * config["sgd_minibatch_size"]
    config["train_batch_size"] = max(current_batch, min_batch)
    config["num_sgd_iter"] = max(config["num_sgd_iter"], 1)
    return config


# Used by the hyperparameter samplers below. It was missing, so PBT raised
# NameError the first time it resampled a value.
import random

# NOTE(review): PBT improves trials by copying from better members of the
# population. With a single trial there is nothing to exploit, so perturbation
# presumably never happens. Raise this (resources permitting: each trial takes
# 1 GPU + 8 rollout workers) for PBT to have any effect.
population_size = 1

pbt = PopulationBasedTraining(
    time_attr="time_total_s",
    perturbation_interval=120,
    resample_probability=0.25,
    # Specifies the mutations of these hyperparams
    hyperparam_mutations={
        "lambda": lambda: random.uniform(0.9, 1.0),
        "clip_param": lambda: random.uniform(0.01, 0.5),
        "lr": [1e-3, 5e-4, 1e-4, 5e-5, 1e-5],
        "num_sgd_iter": lambda: random.randint(1, 30),
        "sgd_minibatch_size": lambda: random.randint(128, 16384),
        "train_batch_size": lambda: random.randint(2000, 160000),
    },
    # Post-process each perturbed config so batch sizes / SGD iters stay valid.
    custom_explore_fn=explore,
)

tuner = tune.Tuner(
    PPO,
    tune_config=tune.TuneConfig(
        metric="episode_reward_max",
        mode="max",
        scheduler=pbt,
        num_samples=population_size,
    ),
    param_space=config.to_dict(),
)
results = tuner.fit()

print("best hyperparameters: ", results.get_best_result().config)

Train using the RLlib trainer in a for loop

In [None]:
# Import the RL algorithm we would like to use. `ray.rllib.agents` is the
# deprecated location; use `ray.rllib.algorithms`, as the imports at the top of
# this notebook do. The local alias keeps the `PPOTrainer` name used below.
from ray.rllib.algorithms.ppo import PPO as PPOTrainer

# Configure the algorithm.
config = {
    # Environment class (RLlib also accepts registered env-id strings).
    # `NasBench201Env` is not defined anywhere in this notebook and raised
    # NameError. `NasBench201` is the env class imported at the top.
    "env": NasBench201,
    # Use 2 environment workers (aka "rollout workers") that parallelly
    # collect samples from their own environment clone(s).
    "num_workers": 2,
    # Deep-learning framework: "torch", "tf", or "tf2" (eager execution).
    "framework": "torch",
    # Set up a separate evaluation worker set for the
    # `trainer.evaluate()` call after training (see below).
    "evaluation_num_workers": 1,
    # Evaluation-only overrides; set "render_env" to True to render during evaluation.
    "evaluation_config": {
        "render_env": False,
    },
    "horizon": 10,
}

# Create our RLlib Trainer.
trainer = PPOTrainer(config=config)

try:
    # Run it for n training iterations. A training iteration includes
    # parallel sample collection by the environment workers as well as
    # loss calculation on the collected batch and a model update.
    for _iteration in range(3):
        print(trainer.train())

    # Evaluate the trained Trainer on the separate evaluation worker set.
    print(trainer.evaluate())
finally:
    # Stop the trainer so its rollout/evaluation worker actors are released.
    trainer.stop()


Ray Gym Environment checker

In [None]:
# Run RLlib's environment pre-checks against a default-constructed cluster env.
cluster_env = NasBench201Clusters()
ray.rllib.utils.check_env(cluster_env)

Stable Baselines3

In [None]:
import os
import sys

# Make the local `nas-bench-envs` checkout importable in this kernel.
module_path = os.path.abspath('nas-bench-envs')
if module_path not in sys.path:
    sys.path.append(module_path)
import gym
# Was missing: `wandb.init` below raised NameError (the `from wandb.integration...`
# import does not bind the `wandb` name).
import wandb
# Kept for its import side effect; presumably this registers the
# "nas_bench_envs/..." gym ids used by gym.make below.
from nas_bench_envs.envs.nas_bench_201_envs import NasBench201  # noqa: F401
# Alias SB3's PPO so this cell does not rebind the RLlib `PPO` used by the Ray cells.
from stable_baselines3 import PPO as SB3PPO
from wandb.integration.sb3 import WandbCallback
from stable_baselines3.common.monitor import Monitor

config = {
    "policy_type": "MlpPolicy",
    "total_timesteps": 25000,
    "env_name": "nas_bench_envs/NasBench201Env-v0",
}


def make_env():
    """Create the configured gym env wrapped in SB3's ``Monitor``."""
    env = gym.make(config["env_name"])
    env = Monitor(env)  # record stats such as returns
    return env


env = make_env()

run = wandb.init(
    project="debug",
    config=config,
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    monitor_gym=True,  # auto-upload videos; presumably only effective if a video-recording wrapper is used
    save_code=True,  # optional
)

try:
    model = SB3PPO(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
    model.learn(
        config["total_timesteps"],
        callback=WandbCallback(
            gradient_save_freq=100,
            model_save_path=f"models/{run.id}",
            verbose=2,
        ),
    )

    # Roll out the trained policy deterministically for 1000 steps, resetting at
    # episode end. Uses the 4-tuple `step` return of gym < 0.26.
    obs = env.reset()
    for _step in range(1000):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        # env.render()
        if done:
            obs = env.reset()
finally:
    # Always close the env and finish the W&B run, even if training fails.
    env.close()
    run.finish()

In [None]:
# Inspect the default PPO configuration as a plain dict. The original line,
# `print(PPOConfig().)`, was an incomplete expression and a syntax error.
print(PPOConfig().to_dict())

Test the functionality of NATS-Bench and the environment

In [None]:
from nats_bench import create

# Open the NATS-Bench topology search space ("tss"), look up the architecture
# string at a fixed index, and print it with its CIFAR-10 statistics.
api = create("/scratch2/sem22hs2/NATS-tss-v1_0-3ffb9-simple", 'tss', fast_mode=True, verbose=False)
arch_index = 12
architecture_str = api.arch(arch_index)
print(architecture_str)
info = api.get_more_info(architecture_str, 'cifar10')
print(info)

In [None]:
import gym.spaces as spaces

# Probe the gym-registered NAS-Bench-201 env: reset it, inspect its adjacency
# tensor, take two actions (return values discarded), then print the spaces.
env = gym.make("nas_bench_envs/NasBench201Env-v0")
env.reset()
print(env.adjacency_tensor)
env.step(20)
env.step(1)
# print(arch_str)
# api.get_more_info("|none~0|+|avg_pool_3x3~0|skip_connect~1|+|nor_conv_3x3~0|skip_connect~0|none~1|", 'cifar10')
# NOTE(review): assigning to env.spec after gym.make likely does not change the
# episode limit; gym's TimeLimit wrapper presumably captures max_episode_steps
# when it is constructed. Confirm against the installed gym version, and pass
# the limit to gym.make / TimeLimit if it is meant to take effect.
env.spec.max_episode_steps = 100
print(env.observation_space.shape)
print(env.action_space)

In [None]:
# Smoke-test rendering of the cluster-restricted environment in "human" mode.
# NOTE: rebinds the notebook-global `config` that the training cells also use.
config = {"render_mode": "human"}
env = NasBench201Clusters(config)
env.reset()
# Print the 'train-all-time' entry from the env's (private) NAS-Bench-201 lookup.
print(env._nb201_lookup()['train-all-time'])
env.step(3)
# Render via the public API, then call the private frame renderer directly.
env.render()
env._render_frame("human")

In [None]:
# Smoke test on the full (unclustered) env with default constructor arguments:
# reset, take action 3, then render a frame via the private renderer.
env = NasBench201()
env.reset()
env.step(3)
env._render_frame("human")

In [None]:
# Index of the first nonzero entry in adjacency_tensor[0, 1, :]. Presumably this
# is the operation chosen for edge 0 -> 1, assuming a one-hot op axis and a
# NumPy array (TODO confirm). Raises IndexError if the slice is all zeros.
env.adjacency_tensor[0, 1, :].nonzero()[0][0]

In [None]:
# Same smoke test as the earlier NasBench201 cell, repeated on a fresh env:
# reset, take action 3, then render a frame via the private renderer.
env = NasBench201()
env.reset()
env.step(3)
env._render_frame("human")

In [None]:
# Repeat of the earlier adjacency lookup on the fresh env: index of the first
# nonzero entry in adjacency_tensor[0, 1, :] (assumes a NumPy array; TODO confirm).
env.adjacency_tensor[0, 1, :].nonzero()[0][0]

In [None]:
!pip install wandb --upgrade

--------------------/scratch2/thorir/Data/Results----------------------
Search space:                       nas-bench-201
Data set:                           cifar100
Used algorithm:                     random
Search criterion:                   test_accuracy
Number of runs:                     20
Random Seed:                        3700
Filter applied:                     0
Clustering used:                    True
Model selection method:             silhouette
Cluster selection mode:             oracle
Quantiling used:                    False
Quantile:                           0.2
Quantiling statistic:               naswot_v2_mean
Number of networks:                 1638
Random search mean performance:     71.64
Random search performance STD:      0.85
Random search performance median:   71.6
Random search performance skew:     0.11
Time taken:                         215.0s