Imports 

In [2]:
import random
import numpy as np
import xautodl
from nats_bench import create
from pprint import pprint
from xautodl.models import get_cell_based_tiny_net


  from .autonotebook import tqdm as notebook_tqdm


Create a Benchmark Instance

In [3]:
api = create('/work/ws-tmp/g051463-tenneti_NAS/Thesis/data/NATS-tss-v1_0-3ffb9-full/', 'tss', fast_mode=True, verbose=True)
print('\nAPI create done: {:}\n'.format(api))

[2023-01-25 12:49:26] Try to create the NATS-Bench (topology) api from /work/ws-tmp/g051463-tenneti_NAS/Thesis/data/NATS-tss-v1_0-3ffb9-full/ with fast_mode=True
[2023-01-25 12:49:26] Create NATS-Bench (topology) done with 0/15625 architectures avaliable.

API create done: NATStopology(0/15625 architectures, fast_mode=True, file=)



Get architecture of an index(0-15625)

In [4]:
arch = api.arch(1200)
pprint(arch)

Call the arch function with index=1200
'|skip_connect~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|avg_pool_3x3~0|skip_connect~1|avg_pool_3x3~2|'


Get more information for an architecture of an index(0-15625), dataset(cifar10,cifar100, ImageNet16-120) and hp(12,200)

In [5]:
info = api.get_more_info(1200,"cifar100",hp="200")
pprint(info)

[2023-01-25 12:49:31] Call the get_more_info function with index=1200, dataset=cifar100, iepoch=None, hp=200, and is_random=True.
[2023-01-25 12:49:31] Call query_index_by_arch with arch=1200
[2023-01-25 12:49:31] Call clear_params with archive_root=/work/ws-tmp/g051463-tenneti_NAS/Thesis/data/NATS-tss-v1_0-3ffb9-full/ and index=1200
{'test-accuracy': 56.76000000610352,
 'test-all-time': 121.94741850807492,
 'test-loss': 1.609574425125122,
 'test-per-time': 0.6097370925403746,
 'train-accuracy': 62.326,
 'train-all-time': 3415.230631828308,
 'train-loss': 1.3463448803710938,
 'train-per-time': 17.07615315914154,
 'valid-accuracy': 55.67999995727539,
 'valid-all-time': 121.94741850807492,
 'valid-loss': 1.6315204977035522,
 'valid-per-time': 0.6097370925403746,
 'valtest-accuracy': 56.22,
 'valtest-all-time': 243.89483701614984,
 'valtest-loss': 1.6205476230621338,
 'valtest-per-time': 1.2194741850807491}


Simulate Train Evaluation

In [11]:
sim = api.simulate_train_eval(1200,"cifar100",iepoch= 200,hp="200")
pprint(sim)

[2023-01-25 13:02:36] Call query_index_by_arch with arch=1200
[2023-01-25 13:02:36] Call the get_more_info function with index=1200, dataset=cifar100, iepoch=200, hp=200, and is_random=True.
[2023-01-25 13:02:36] Call query_index_by_arch with arch=1200
[2023-01-25 13:02:36] Call _prepare_info with index=1200 skip because it is in arch2infos_dict


ValueError: invalid iepoch=200 < 200

Get random index of all architectures

In [10]:
rand = api.random()
pprint(rand)

8381


Get the index of an architecture in search space

In [11]:
index = api.query_index_by_arch("|skip_connect~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|avg_pool_3x3~0|skip_connect~1|avg_pool_3x3~2|")
pprint(index)

[2023-01-23 09:59:11] Call query_index_by_arch with arch=|skip_connect~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|avg_pool_3x3~0|skip_connect~1|avg_pool_3x3~2|
1200


Get results for an architecture index

In [12]:
results = api.query_by_index(1200,"cifar100",hp ="200")
pprint(results)

[2023-01-23 10:03:14] Call query_by_index with arch_index=1200, dataname=cifar100, hp=200
Call query_meta_info_by_index with arch_index=1200, hp=200
[2023-01-23 10:03:14] Call _prepare_info with index=1200 skip because it is in arch2infos_dict
{777: ResultsCount(cifar100, arch=|skip_connect~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|avg_pool_3x3~0|skip_connect~1|avg_pool_3x3~2|, FLOP=47.11M, Param=0.350MB, seed=0777, 3 eval-sets: [ori-test, x-valid, x-test]),
 888: ResultsCount(cifar100, arch=|skip_connect~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|avg_pool_3x3~0|skip_connect~1|avg_pool_3x3~2|, FLOP=47.11M, Param=0.350MB, seed=0888, 3 eval-sets: [ori-test, x-valid, x-test]),
 999: ResultsCount(cifar100, arch=|skip_connect~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|avg_pool_3x3~0|skip_connect~1|avg_pool_3x3~2|, FLOP=47.11M, Param=0.350MB, seed=0999, 3 eval-sets: [ori-test, x-valid, x-test])}


Get the best architecture wrt metrics

In [None]:
api.verbose = False
best_arch_index, highest_valid_accuracy = api.find_best(dataset='cifar100', metric_on_set='x-valid', hp='200')
print(best_arch_index, api.arch(best_arch_index))

Get network paramaters

In [16]:
params = api.get_net_param(1200,"cifar100",seed= 777,hp="200")
pprint(params)

[2023-01-23 10:21:06] Call the get_net_param function with index=1200, dataset=cifar100, seed=777, hp=200
Call query_meta_info_by_index with arch_index=1200, hp=200
[2023-01-23 10:21:06] Call _prepare_info with index=1200 skip because it is in arch2infos_dict
OrderedDict([('stem.0.weight',
              tensor([[[[-8.2964e-02, -6.8212e-02, -5.3936e-02],
          [ 8.5075e-02,  1.4554e-01,  1.2761e-01],
          [-1.3669e-01, -3.7396e-02, -1.0313e-01]],

         [[-7.7681e-02, -3.1143e-02, -6.1263e-02],
          [ 1.7869e-01,  2.8117e-01,  1.9188e-01],
          [-1.5274e-01, -1.6082e-02, -1.1131e-01]],

         [[-5.1918e-01, -4.5329e-01, -3.8465e-01],
          [ 1.0595e-01,  3.6098e-01,  3.5560e-01],
          [ 1.1070e-01,  4.4192e-01,  3.5489e-01]]],


        [[[ 1.7412e-01,  4.1579e-01, -2.1199e-01],
          [ 3.7251e-01,  6.1738e-01,  1.0666e-02],
          [-1.4778e-01,  6.4297e-02, -3.4661e-01]],

         [[-1.2322e-01,  1.6360e-01, -2.5000e-01],
          [ 7.6914e-02

Get network configuration

In [18]:
config = api.get_net_config(1200,"cifar100")
pprint(config)

[2023-01-23 10:22:39] Call the get_net_config function with index=1200, dataset=cifar100.
[2023-01-23 10:22:39] Call _prepare_info with index=1200 skip because it is in arch2infos_dict
{'C': 16,
 'N': 5,
 'arch_str': '|skip_connect~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|avg_pool_3x3~0|skip_connect~1|avg_pool_3x3~2|',
 'name': 'infer.tiny',
 'num_classes': 100}


Get cost information of a network

In [19]:
c_info = api.get_cost_info(1200,"cifar100",hp="200")
pprint(c_info)

[2023-01-23 10:24:11] Call the get_cost_info function with index=1200, dataset=cifar100, and hp=200.
[2023-01-23 10:24:11] Call _prepare_info with index=1200 skip because it is in arch2infos_dict
Call query_meta_info_by_index with arch_index=1200, hp=200
[2023-01-23 10:24:11] Call _prepare_info with index=1200 skip because it is in arch2infos_dict
{'T-ori-test@epoch': 1.2194741850807491,
 'T-ori-test@total': 243.89483701614984,
 'T-train@epoch': 17.07615315914154,
 'T-train@total': 3415.230631828308,
 'T-x-test@epoch': 0.6097370925403746,
 'T-x-test@total': 121.94741850807492,
 'T-x-valid@epoch': 0.6097370925403746,
 'T-x-valid@total': 121.94741850807492,
 'flops': 47.1105,
 'latency': 0.01346242869341815,
 'params': 0.350196}


Get Latency 

In [20]:
latency = api.get_latency(1200,"cifar100", hp="200")
pprint(latency)

[2023-01-23 10:31:49] Call the get_latency function with index=1200, dataset=cifar100, and hp=200.
[2023-01-23 10:31:49] Call the get_cost_info function with index=1200, dataset=cifar100, and hp=200.
[2023-01-23 10:31:49] Call _prepare_info with index=1200 skip because it is in arch2infos_dict
Call query_meta_info_by_index with arch_index=1200, hp=200
[2023-01-23 10:31:49] Call _prepare_info with index=1200 skip because it is in arch2infos_dict
0.01346242869341815


Print information specific to an architecture

In [21]:
show =api.show(1200)
pprint(show)

>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 012 epochs >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
|skip_connect~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|avg_pool_3x3~0|skip_connect~1|avg_pool_3x3~2|
datasets : ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'], extra-info : None
cifar10-valid  FLOP= 47.10 M, Params=0.344 MB, latency=14.72 ms.
cifar10-valid  train : [loss = 0.948 & top1 = 65.98%], valid : [loss = 0.996 & top1 = 64.95%]
cifar10        FLOP= 47.10 M, Params=0.344 MB, latency=14.72 ms.
cifar10        train : [loss = 0.778 & top1 = 72.41%], test  : [loss = 0.813 & top1 = 71.16%]
cifar100       FLOP= 47.11 M, Params=0.350 MB, latency=13.46 ms.
cifar100       train : [loss = 2.556 & top1 = 34.05%], valid : [loss = 2.611 & top1 = 32.22%], test : [loss = 2.606 & top1 = 33.04%]
ImageNet16-120 FLOP= 11.78 M, Params=0.351 MB, latency=13.74 ms.
ImageNet16-120 train : [loss = 3.373 & top1 = 19.64%], valid : [loss = 3.323 & top1 = 21.33%], test : [loss = 3.348 & top1 = 19.80%]
>>>>>>>>

In [None]:
stats = api.statistics(1200,"cifar100",hp =200)
pprint(stats)

Find the largest model's performance

In [25]:
largest_candidate_tss = '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|'
arch_index = api.query_index_by_arch(largest_candidate_tss)
print('The architecture-index for the largest model is {:}'.format(arch_index))
datasets = ('cifar10', 'cifar100', 'ImageNet16-120')
for dataset in datasets:
    print('Its performance on {:} with 12-epoch-training'.format(dataset))
    info = api.get_more_info(arch_index, dataset, hp='12', is_random=False)
    pprint(info)
    print('Its performance on {:} with 200-epoch-training'.format(dataset))
    info = api.get_more_info(arch_index, dataset, hp='200', is_random=False)
    pprint(info)

[2023-01-23 10:53:15] Call query_index_by_arch with arch=|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|
The architecture-index for the largest model is 1462
Its performance on cifar10 with 12-epoch-training
[2023-01-23 10:53:15] Call the get_more_info function with index=1462, dataset=cifar10, iepoch=None, hp=12, and is_random=False.
[2023-01-23 10:53:15] Call query_index_by_arch with arch=1462
[2023-01-23 10:53:15] Call clear_params with archive_root=/work/ws-tmp/g051463-tenneti_NAS/Thesis/data/NATS-tss-v1_0-3ffb9-full/ and index=1462
{'comment': 'In this dict, train-loss/accuracy/time is the metric on the '
            'train+valid sets of CIFAR-10. The test-loss/accuracy/time is the '
            'performance of the CIFAR-10 test set after training on the '
            'train+valid sets by 12 epochs. The per-time and total-time '
            'indicate the per epoch and total time costs, respectively.',
 'test-accuracy': 82.2,
 'test-a

Total architectures on tss

In [29]:
print('There are {:} architectures on the topology search space'.format(len(api)))

There are 15625 architectures on the topology search space
