1.Create Benchmark Instance

In [5]:
from nats_bench import create
from pprint import pprint

# Create the API instance for the size search space in NATS
api = create('/work/ws-tmp/g051463-tenneti_NAS/data/NATS-tss-v1_0-3ffb9-full/', 'tss', fast_mode=True, verbose=True)
print('\nAPI create done: {:}\n'.format(api))


info = api.get_more_info(100, 'cifar100', hp='12')
pprint(info)

[2022-12-08 15:47:55] Try to create the NATS-Bench (topology) api from /work/ws-tmp/g051463-tenneti_NAS/data/NATS-tss-v1_0-3ffb9-full/ with fast_mode=True
[2022-12-08 15:47:55] Create NATS-Bench (topology) done with 0/15625 architectures avaliable.

API create done: NATStopology(0/15625 architectures, fast_mode=True, file=)

[2022-12-08 15:47:55] Call the get_more_info function with index=100, dataset=cifar100, iepoch=None, hp=12, and is_random=True.
[2022-12-08 15:47:55] Call query_index_by_arch with arch=100
[2022-12-08 15:47:55] Call clear_params with archive_root=/work/ws-tmp/g051463-tenneti_NAS/data/NATS-tss-v1_0-3ffb9-full/ and index=100
{'test-accuracy': 43.44000002441406,
 'test-all-time': 9.295885120119369,
 'test-loss': 2.099257713317871,
 'test-per-time': 0.7746570933432807,
 'train-accuracy': 45.168,
 'train-all-time': 270.6521048545837,
 'train-loss': 2.0303653003692625,
 'train-per-time': 22.554342071215306,
 'valid-accuracy': 44.81999994506836,
 'valid-all-time': 9.29588

In [None]:
from nats_bench import create
# Create the API instance for the size search space in NATS
# api = create(None, 'sss', fast_mode=True, verbose=True)

# Create the API instance for the topology search space in NATS
# api = create('/work/ws-tmp/g051463-tenneti_NAS/NATS-Bench/fake_torch_dir/NATS-tss-v1_0-3ffb9-simple/', 'tss', fast_mode=True, verbose=True)
# api = create('/work/ws-tmp/g051463-tenneti_NAS/data/NATS-tss-v1_0-3ffb9-full/', 'tss', fast_mode=True, verbose=True)
api = create(None, 'tss', fast_mode=True, verbose=True)

2.Query the performance

In [29]:
# Show the architecture topology string of the 12-th architecture
# For the topology search space, the string is interpreted as
# arch = '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.format(
#         edge_node_0_to_node_1,
#         edge_node_0_to_node_2,
#         edge_node_1_to_node_2,
#         edge_node_0_to_node_3,
#         edge_node_1_to_node_3,
#         edge_node_2_to_node_3,
#         )
# For the size search space, the string is interpreted as
# arch = '{}:{}:{}:{}:{}'.format(out_channel_of_1st_conv_layer,
#                                out_channel_of_1st_cell_stage,
#                                out_channel_of_1st_residual_block,
#                                out_channel_of_2nd_cell_stage,
#                                out_channel_of_2nd_residual_block,
#                                )
architecture_str = api.arch(1000)
print(architecture_str)

# Query the loss / accuracy / time for 1234-th candidate architecture on CIFAR-10
# info is a dict, where you can easily figure out the meaning by key
info = api.get_more_info(1000, 'cifar100')
pprint(info)

# Query the flops, params, latency. info is a dict.
info = api.get_cost_info(1000, 'cifar100')
pprint(info)

# Simulate the training of the 1224-th candidate:
validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(1000, dataset='cifar100', hp='12')
sim = [validation_accuracy, latency, time_cost, current_total_time_cost]
pprint(sim)

Call the arch function with index=1000
|skip_connect~0|+|nor_conv_3x3~0|nor_conv_1x1~1|+|none~0|skip_connect~1|nor_conv_3x3~2|
[2022-12-07 16:50:06] Call the get_more_info function with index=1000, dataset=cifar100, iepoch=None, hp=12, and is_random=True.
[2022-12-07 16:50:06] Call query_index_by_arch with arch=1000
[2022-12-07 16:50:06] Call _prepare_info with index=1000 skip because it is in arch2infos_dict
{'test-accuracy': 58.30000002441406,
 'test-all-time': 7.9043262004852295,
 'test-loss': 1.4708942895889283,
 'test-per-time': 0.6586938500404358,
 'train-accuracy': 63.73,
 'train-all-time': 229.09286308288577,
 'train-loss': 1.2646459322738648,
 'train-per-time': 19.091071923573814,
 'valid-accuracy': 58.49999993896484,
 'valid-all-time': 7.9043262004852295,
 'valid-loss': 1.4571917526245117,
 'valid-per-time': 0.6586938500404358,
 'valtest-accuracy': 58.4,
 'valtest-all-time': 15.808652400970459,
 'valtest-loss': 1.4640430240631104,
 'valtest-per-time': 1.3173877000808716}
[202

3.Create Instance of architecture candidate in NATS_Bench

In [30]:
# Create the instance of th 12-th candidate for CIFAR-10.
# To keep NATS-Bench repo concise, we did not include any model-related codes here because they rely on PyTorch.
# The package of [models] is defined at https://github.com/D-X-Y/AutoDL-Projects
#   so that one need to first import this package.
import xautodl
from xautodl.models import get_cell_based_tiny_net
config = api.get_net_config(1000, 'cifar100')
network = get_cell_based_tiny_net(config)

# Load the pre-trained weights: params is a dict, where the key is the seed and value is the weights.
params = api.get_net_param(1000, 'cifar100', None)
network.load_state_dict(next(iter(params.values())))

[2022-12-07 16:50:20] Call the get_net_config function with index=1000, dataset=cifar100.
[2022-12-07 16:50:20] Call _prepare_info with index=1000 skip because it is in arch2infos_dict
[2022-12-07 16:50:20] Call the get_net_param function with index=1000, dataset=cifar100, seed=None, hp=12
Call query_meta_info_by_index with arch_index=1000, hp=12
[2022-12-07 16:50:20] Call _prepare_info with index=1000 skip because it is in arch2infos_dict


<All keys matched successfully>

In [31]:
pprint(params)

{111: OrderedDict([('stem.0.weight',
                    tensor([[[[-3.9259e-02,  2.7429e-01, -1.5944e-01],
          [ 2.7453e-01,  4.8323e-01,  1.3749e-01],
          [ 3.1095e-01,  3.5012e-01,  1.9776e-01]],

         [[ 9.6513e-02, -1.6555e-01, -1.8112e-01],
          [-2.7225e-02, -6.2463e-01, -4.5833e-01],
          [-2.1087e-01, -5.5725e-01, -2.8362e-01]],

         [[ 1.4423e-01,  2.8174e-01,  1.7254e-01],
          [ 2.4322e-01,  1.3004e-01, -1.4686e-02],
          [ 1.2553e-01, -1.5606e-01, -5.7584e-02]]],


        [[[ 8.0151e-02, -2.6579e-01,  1.1896e-01],
          [-1.9732e-01, -4.4863e-01, -1.8689e-01],
          [ 1.2604e-01, -1.5146e-01,  1.1715e-01]],

         [[ 2.2554e-01, -2.2820e-01,  1.8257e-01],
          [-1.1932e-01, -4.9699e-01, -1.9194e-01],
          [ 1.4155e-02, -1.5600e-01,  1.0840e-01]],

         [[ 1.6059e-01, -1.0694e-01,  9.2701e-02],
          [ 1.7291e-02, -3.7085e-01, -2.4440e-02],
          [ 2.1436e-01,  1.1276e-01,  1.0181e-01]]],


        [

4.Others

In [24]:
# Clear the parameters of the 12-th candidate.
api.clear_params(1000)

# Reload all information of the 12-th candidate.
api.reload(index=1000)


[2022-12-07 16:41:03] Call clear_params with index=1000 and hp=None
[2022-12-07 16:41:03] Call clear_params with archive_root=None and index=1000


KeyError: 'TORCH_HOME'

API Tests

In [None]:
from nats_bench import api_test
api_test.test_nats_bench_tss('NATS-tss-v1_0-3ffb9-simple')
api_test.test_nats_bench_tss('NATS-sss-v1_0-50262-simple')