Skip to content

Commit

Permalink
Add example for combining auto module with NNI.
Browse files Browse the repository at this point in the history
  • Loading branch information
iffiX committed May 21, 2021
1 parent a27275a commit d10727b
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 3 deletions.
@@ -0,0 +1,25 @@
authorName: default
experimentName: example_nni_auto
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 10
#choice: local, remote, pai
trainingServicePlatform: local
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
#choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner
#SMAC (SMAC should be installed through nnictl)
builtinTunerName: TPE
classArgs:
#choice: maximize, minimize
optimize_mode: maximize
trial:
command: python nni_main.py
codeDir: .
gpuNum: 1
localConfig:
useActiveGpu: true
maxTrialNumPerGpu: 2
gpuIndices: "0"
@@ -0,0 +1 @@
nnictl create --config config.yml --port 8088 --debug
@@ -0,0 +1,56 @@
from machin.auto.config import (
generate_algorithm_config,
generate_env_config,
generate_training_config,
launch,
)
from pytorch_lightning.callbacks import Callback

import nni
import torch as t
import torch.nn as nn


class SomeQNet(nn.Module):
def __init__(self, state_dim, action_num):
super().__init__()

self.fc1 = nn.Linear(state_dim, 16)
self.fc2 = nn.Linear(16, 16)
self.fc3 = nn.Linear(16, action_num)

def forward(self, state):
a = t.relu(self.fc1(state))
a = t.relu(self.fc2(a))
return self.fc3(a)


class InspectCallback(Callback):
def __init__(self):
self.total_reward = 0

def on_train_batch_end(
self, trainer, pl_module, outputs, batch, _batch_idx, _dataloader_idx
) -> None:
for l in batch[0].logs:
if "total_reward" in l:
self.total_reward = l["total_reward"]


if __name__ == "__main__":
param = nni.get_next_parameter()
cb = InspectCallback()
while param:
config = generate_algorithm_config("DQN")
config = generate_env_config("openai_gym", config)
config = generate_training_config(
root_dir="trial", episode_per_epoch=10, max_episodes=10000, config=config
)
config["frame_config"]["models"] = ["SomeQNet", "SomeQNet"]
config["frame_config"]["model_kwargs"] = [{"state_dim": 4, "action_num": 2}] * 2
config["frame_config"]["learning_rate"] = param["lr"]
config["frame_config"]["update_rate"] = param["upd"]
launch(config, pl_callbacks=[cb])
# we use total reward as "accuracy"
nni.report_final_result(cb.total_reward)
param = nni.get_next_parameter()
@@ -0,0 +1,4 @@
{
"lr":{"_type":"choice","_value":[0.0001, 0.001, 0.01, 0.1]},
"upd":{"_type": "choice", "_value": [0.005, 0.002, 0.001, 0.0001]}
}
7 changes: 4 additions & 3 deletions machin/auto/config.py
@@ -1,5 +1,6 @@
from copy import deepcopy
from typing import Dict, Any, Union
from typing import Dict, Any, Union, List
from pytorch_lightning.callbacks import Callback
from machin.frame.algorithms import TorchFramework
from machin.utils.conf import Config
from . import envs
Expand Down Expand Up @@ -128,9 +129,9 @@ def assert_env_config_complete(config: Union[Dict[str, Any], Config]):
assert "test_env_config" in config, 'Missing key "test_env_config" ' "in config."


def launch(config: Union[Dict[str, Any], Config]):
def launch(config: Union[Dict[str, Any], Config], pl_callbacks: List[Callback] = None):
assert_training_config_complete(config)
assert_env_config_complete(config)
assert_algorithm_config_complete(config)
e_module = getattr(envs, config["env"])
return e_module.launch(config)
return e_module.launch(config, pl_callbacks)

0 comments on commit d10727b

Please sign in to comment.