Enable fixed_arch on Retiarii #3972

Merged
merged 6 commits on Jul 27, 2021
4 changes: 3 additions & 1 deletion docs/en_US/NAS/ApiReference.rst
@@ -105,4 +105,6 @@ Retiarii Experiments
Utilities
---------

.. autofunction:: nni.retiarii.serialize
.. autofunction:: nni.retiarii.serialize

.. autofunction:: nni.retiarii.fixed_arch
8 changes: 7 additions & 1 deletion docs/en_US/NAS/OneshotTrainer.rst
@@ -34,4 +34,10 @@ See `API reference <./ApiReference.rst>`__ for detailed usages. Here, we show an
    trainer.fit()
    final_architecture = trainer.export()

**Format of the exported architecture.** TBD.
After the search is done, we can use the exported architecture to instantiate the full network for retraining. Here is an example:

.. code-block:: python

    from nni.retiarii import fixed_arch
    with fixed_arch('/path/to/checkpoint.json'):
        model = Model()
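
The architecture returned by ``trainer.export()`` is an ordinary dict (``fixed_arch`` also accepts such a dict directly), so the checkpoint above can be produced with the standard ``json`` module. A minimal sketch, with an illustrative path:

.. code-block:: python

    import json

    # persist the architecture exported by the one-shot trainer
    with open('/path/to/checkpoint.json', 'w') as f:
        json.dump(final_architecture, f)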
8 changes: 4 additions & 4 deletions examples/nas/oneshot/darts/model.py
@@ -7,7 +7,7 @@
import torch.nn as nn

import ops
from nni.nas.pytorch import mutables
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice


class AuxiliaryHead(nn.Module):
@@ -45,17 +45,17 @@ def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
            stride = 2 if i < num_downsample_connect else 1
            choice_keys.append("{}_p{}".format(node_id, i))
            self.ops.append(
                mutables.LayerChoice(OrderedDict([
                LayerChoice(OrderedDict([
                    ("maxpool", ops.PoolBN('max', channels, 3, stride, 1, affine=False)),
                    ("avgpool", ops.PoolBN('avg', channels, 3, stride, 1, affine=False)),
                    ("skipconnect", nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False)),
                    ("sepconv3x3", ops.SepConv(channels, channels, 3, stride, 1, affine=False)),
                    ("sepconv5x5", ops.SepConv(channels, channels, 5, stride, 2, affine=False)),
                    ("dilconv3x3", ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False)),
                    ("dilconv5x5", ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False))
                ]), key=choice_keys[-1]))
                ]), label=choice_keys[-1]))
        self.drop_path = ops.DropPath()
        self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))
        self.input_switch = InputChoice(n_candidates=len(choice_keys), n_chosen=2, label="{}_switch".format(node_id))

    def forward(self, prev_nodes):
        assert len(self.ops) == len(prev_nodes)
6 changes: 3 additions & 3 deletions examples/nas/oneshot/darts/retrain.py
@@ -12,8 +12,8 @@
import datasets
import utils
from model import CNN
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.pytorch.utils import AverageMeter
from nni.retiarii import fixed_arch

logger = logging.getLogger('nni')

@@ -119,8 +119,8 @@ def validate(config, valid_loader, model, criterion, epoch, cur_step):
    args = parser.parse_args()
    dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16)

    model = CNN(32, 3, 36, 10, args.layers, auxiliary=True)
    apply_fixed_architecture(model, args.arc_checkpoint)
    with fixed_arch(args.arc_checkpoint):
        model = CNN(32, 3, 36, 10, args.layers, auxiliary=True)
    criterion = nn.CrossEntropyLoss()

    model.to(device)
1 change: 1 addition & 0 deletions nni/retiarii/__init__.py
@@ -4,5 +4,6 @@
from .operation import Operation
from .graph import *
from .execution import *
from .fixed import fixed_arch
from .mutator import *
from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls, model_wrapper
40 changes: 40 additions & 0 deletions nni/retiarii/fixed.py
@@ -0,0 +1,40 @@
import json
import logging
from pathlib import Path
from typing import Union, Dict, Any

from .utils import ContextStack

_logger = logging.getLogger(__name__)


def fixed_arch(fixed_arch: Union[str, Path, Dict[str, Any]], verbose=True):
"""
Load architecture from ``fixed_arch`` and apply to model. This should be used as a context manager. For example,

.. code-block:: python

with fixed_arch('/path/to/export.json'):
model = Model(3, 224, 224)

Parameters
----------
fixed_arc : str, Path or dict
Path to the JSON that stores the architecture, or dict that stores the exported architecture.
verbose : bool
Print log messages if set to True

Returns
-------
ContextStack
Context manager that provides a fixed architecture when creates the model.
"""

if isinstance(fixed_arch, (str, Path)):
with open(fixed_arch) as f:
fixed_arch = json.load(f)

if verbose:
_logger.info(f'Fixed architecture: %s', fixed_arch)

return ContextStack('fixed', fixed_arch)
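
Since `fixed_arch` accepts an already-loaded dict as well as a path, the JSON round trip can be skipped when the exported architecture is still in memory. A minimal sketch, where `Model` and the label-to-choice mapping are purely illustrative:

    from nni.retiarii import fixed_arch

    arch = {'layer1_p0': 'sepconv3x3', 'layer1_switch': [0, 1]}  # illustrative labels and choices
    with fixed_arch(arch, verbose=False):
        model = Model()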
5 changes: 3 additions & 2 deletions nni/retiarii/oneshot/pytorch/darts.py
@@ -3,6 +3,7 @@

import copy
import logging
from collections import OrderedDict

import torch
import torch.nn as nn
@@ -19,7 +20,7 @@ class DartsLayerChoice(nn.Module):
    def __init__(self, layer_choice):
        super(DartsLayerChoice, self).__init__()
        self.name = layer_choice.key
Contributor
This should be `layer_choice.label`, otherwise there would be a warning. `DartsInputChoice` should also be modified.
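
A minimal sketch of the suggested change (the same rename would apply in `DartsInputChoice.__init__`):

        self.name = layer_choice.label  # use `label`; accessing the legacy `key` emits a warning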

        self.op_choices = nn.ModuleDict(layer_choice.named_children())
        self.op_choices = nn.ModuleDict(OrderedDict([(name, layer_choice[name]) for name in layer_choice.names]))
        self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)

    def forward(self, *args, **kwargs):
@@ -38,7 +39,7 @@ def named_parameters(self):
            yield name, p

    def export(self):
        return torch.argmax(self.alpha).item()
        return list(self.op_choices.keys())[torch.argmax(self.alpha).item()]


class DartsInputChoice(nn.Module):