In [18]:
import logging
from naslib.defaults.trainer import Trainer
from naslib.optimizers import DARTSOptimizer
from naslib.search_spaces import DartsSearchSpace
from naslib.utils import utils, setup_logger
from naslib.search_spaces.core.graph import Graph, EdgeData
from naslib.search_spaces.core import primitives as ops
from torch import nn

In [21]:
config = utils.get_config_from_args(config_type='nas_predictor')
utils.set_seed(config.seed)
utils.log_args(config)

usage: ipykernel_launcher.py [-h] [--config-file FILE] [--eval-only]
                             [--seed SEED] [--resume]
                             [--model-path MODEL_PATH]
                             [--world-size WORLD_SIZE] [--rank RANK]
                             [--gpu GPU] [--dist-url DIST_URL]
                             [--dist-backend DIST_BACKEND]
                             [--multiprocessing-distributed]
                             ...
ipykernel_launcher.py: error: unrecognized arguments: --ip=127.0.0.1 --stdin=9003 --control=9001 --hb=9000 --Session.signature_scheme="hmac-sha256" --Session.key=b"4c1ec925-4650-41b0-8f62-3a52904380c8" --shell=9002 --transport="tcp" --iopub=9004 --f=/home/robertsj/.local/share/jupyter/runtime/kernel-v2-1611060d8RH7F8FeACA.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [20]:
class DartsSearchSpace(Graph):

    OPTIMIZER_SCOPE = [
        'a_stage_1',
        'a_stage_2', 
        'a_stage_3'
    ]

    QUERYABLE = False

    def __init__(self):
        super().__init__()

        # cell definition
        activation_cell = Graph()
        activation_cell.name = 'activation_cell'
        activation_cell.add_node(1) # input node
        activation_cell.add_node(2) # intermediate node
        activation_cell.add_node(3) # output node
        activation_cell.add_edge(1, 2) # mutable intermediate edge
        activation_cell.edges[1, 2].set('cell_name', 'activation_cell')
        activation_cell.add_edges_from([(2, 3, EdgeData().finalize())]) # immutable output edge
  

        # macroarchitecture definition
        self.name = 'macrograph'
        self.add_node(1) # input node
        self.add_node(2) # intermediate node
        self.add_node(3, subgraph=activation_cell.copy().set_scope('a_stage_1').set_input([2])) # activation node 1
        self.add_node(4, subgraph=activation_cell.copy().set_scope('a_stage_2').set_input([3])) # activation node 2
        self.add_node(5, subgraph=activation_cell.copy().set_scope('a_stage_3').set_input([4])) # activation node 3
        self.add_node(6) # output node
        
        self.add_edge(1, 2) # convolutional edge
        self.edges[1, 2].set('op', [
            ops.Sequential(
                nn.Conv2d(3, 6, 5),
                nn.MaxPool2d(2),
                nn.Conv2d(6, 16, 5),
                nn.MaxPool2d(2),
                nn.Flatten()
            )
        ]) # convolutional edge
        self.add_edges_from([(i, i+1) for i in range(2, 6)]) # identity edges
        dims = [(16 * 5 * 5, 120), (120, 84), (84, 10)]
        for i, (in_dim, out_dim) in enumerate(dims):
            self.update_edges(
                update_func=lambda edge: self._set_ops(edge, in_dim, out_dim),
                scope=f"a_stage_{i+1}",
                private_edge_data=True,
            )

    def _set_ops(self, edge, in_dim, out_dim):
        if out_dim != 10:
            edge.data.set('op', [
                ops.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU()),
                ops.Sequential(nn.Linear(in_dim, out_dim), nn.Hardswish()),
                ops.Sequential(nn.Linear(in_dim, out_dim), nn.LeakyReLU()),
                ops.Sequential(nn.Linear(in_dim, out_dim), nn.Identity())
            ], shared=False) # FIXME
        else:
            edge.data.set('op', [
                ops.Sequential(nn.Linear(in_dim, out_dim), nn.Softmax())
            ])    

In [21]:
search_space = DartsSearchSpace()

Update function could not be veryfied. Be cautious with the setting of `private_edge_data` in `update_edges()`
Update function could not be veryfied. Be cautious with the setting of `private_edge_data` in `update_edges()`
Update function could not be veryfied. Be cautious with the setting of `private_edge_data` in `update_edges()`


In [None]:
logger = setup_logger(config.save + "/log.log")
logger.setLevel(logging.INFO)

In [None]:
optimizer = DARTSOptimizer(config)
optimizer.adapt_search_space(search_space)

In [None]:
trainer = Trainer(optimizer, config)
trainer.search()

In [None]:
trainer.evaluate()