In [1]:
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
import time
import tqdm
import scipy.stats as stats
import matplotlib.pyplot as plt
import pickle

# XAutoDL 
from xautodl.config_utils import load_config, dict2config, configure2str
from xautodl.datasets import get_datasets, get_nas_search_loaders
from xautodl.procedures import (
    prepare_seed,
    prepare_logger,
    save_checkpoint,
    copy_checkpoint,
    get_optim_scheduler,
)
from xautodl.utils import get_model_infos, obtain_accuracy
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.models import get_search_spaces

# API
from nats_bench import create

# custom modules
from custom.tss_model import TinyNetwork
from xautodl.models.cell_searchs.genotypes import Structure
from ZeroShotProxy import *

# Limit the CUDA devices visible to this process to the one with index 1.
# NOTE(review): this runs after `import torch`; presumably it still takes effect because
# CUDA has not been initialised at this point — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

2024-01-25 11:11:24.123008: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-01-25 11:11:24.168000: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Experiment configuration. The options are declared argparse-style, but inside the
# notebook they are always parsed from an empty argv, so every option takes its default.
parser = argparse.ArgumentParser("Training-free NAS on NAS-Bench-201 (NATS-Bench-TSS)")

# --- dataset ---
parser.add_argument(
    "--data_path",
    type=str,
    default="./cifar.python",
    help="The path to dataset",
)
parser.add_argument(
    "--dataset",
    type=str,
    default="cifar10",
    choices=["cifar10", "cifar100", "ImageNet16-120"],
    help="Choose between Cifar10/100 and ImageNet-16.",
)

# --- search space: channels and number-of-cells ---
parser.add_argument(
    "--search_space",
    type=str,
    default="tss",
    help="The search space name.",
)
parser.add_argument(
    "--config_path",
    type=str,
    default="./configs/nas-benchmark/algos/weight-sharing.config",
    help="The path to the configuration.",
)
parser.add_argument(
    "--max_nodes",
    type=int,
    default=4,
    help="The maximum number of nodes.",
)
parser.add_argument(
    "--channel",
    type=int,
    default=16,
    help="The number of channels.",
)
parser.add_argument(
    "--num_cells",
    type=int,
    default=5,
    help="The number of cells in one stage.",
)
parser.add_argument(
    "--affine",
    type=int,
    default=1,
    choices=[0, 1],
    help="Whether use affine=True or False in the BN layer.",
)
parser.add_argument(
    "--track_running_stats",
    type=int,
    default=0,
    choices=[0, 1],
    help="Whether use track_running_stats or not in the BN layer.",
)

# --- logging ---
parser.add_argument(
    "--print_freq",
    type=int,
    default=200,
    help="print frequency (default: 200)",
)

# --- custom options for this notebook ---
parser.add_argument(
    "--gpu",
    type=int,
    default=0,
    help="",
)
parser.add_argument(
    "--workers",
    type=int,
    default=4,
    help="number of data loading workers",
)
parser.add_argument(
    "--api_data_path",
    type=str,
    default="./api_data/NATS-tss-v1_0-3ffb9-simple/",
    help="",
)
parser.add_argument(
    "--save_dir",
    type=str,
    default="./results/tmp",
    help="Folder to save checkpoints and log.",
)
parser.add_argument(
    "--zero_shot_score",
    type=str,
    default="az_nas_time",
)
parser.add_argument(
    "--rand_seed",
    type=int,
    default=1,
    help="manual seed (we use 1-to-5)",
)

# Parse an explicit empty argv so the kernel's own command line is never consulted.
args = parser.parse_args(args=[])
# `xargs` is an alias for the same namespace object; later cells use both names.
xargs = args

# A missing or negative seed is replaced by a random one in [1, 100000].
if args.rand_seed is None or args.rand_seed < 0:
    args.rand_seed = random.randint(1, 100000)

print(args.rand_seed)
print(args)

1
Namespace(data_path='./cifar.python', dataset='cifar10', search_space='tss', config_path='./configs/nas-benchmark/algos/weight-sharing.config', max_nodes=4, channel=16, num_cells=5, affine=1, track_running_stats=0, print_freq=200, gpu=0, workers=4, api_data_path='./api_data/NATS-tss-v1_0-3ffb9-simple/', save_dir='./results/tmp', zero_shot_score='az_nas_time', rand_seed=1)


In [3]:
# Everything below runs on the GPU, so stop immediately when CUDA is missing.
assert torch.cuda.is_available(), "CUDA is not available."

# cuDNN stays enabled, with the benchmark autotuner off and deterministic mode on.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = True

# The data-loading worker count doubles as the PyTorch thread count.
torch.set_num_threads(xargs.workers)

# Seed the run, then create the logger from the parsed arguments
# (`xargs` is the same namespace object as `args`).
prepare_seed(xargs.rand_seed)
logger = prepare_logger(xargs)

Main Function with logger : Logger(dir=results/tmp, use-tf=False, writer=None)
Arguments : -------------------------------
data_path        : ./cifar.python
dataset          : cifar10
search_space     : tss
config_path      : ./configs/nas-benchmark/algos/weight-sharing.config
max_nodes        : 4
channel          : 16
num_cells        : 5
affine           : 1
track_running_stats : 0
print_freq       : 200
gpu              : 0
workers          : 4
api_data_path    : ./api_data/NATS-tss-v1_0-3ffb9-simple/
save_dir         : ./results/tmp
zero_shot_score  : az_nas_time
rand_seed        : 1
Python  Version  : 3.10.11 (main, Apr 20 2023, 19:02:41) [GCC 11.2.0]
Pillow  Version  : 9.4.0
PyTorch Version  : 2.0.1
cuDNN   Version  : 8500
CUDA available   : True
CUDA GPU numbers : 1
CUDA_VISIBLE_DEVICES : 1


In [4]:
## API
# Benchmark API handle, created in fast mode with verbose output disabled.
api = create(
    xargs.api_data_path,
    xargs.search_space,
    fast_mode=True,
    verbose=False,
)
logger.log("Create API = {:} done".format(api))

## data
# NOTE(review): the third positional argument (-1) is passed through to get_datasets
# unchanged; its meaning is defined there — confirm.
train_data, valid_data, xshape, class_num = get_datasets(
    xargs.dataset,
    xargs.data_path,
    -1,
)
config = load_config(
    xargs.config_path,
    {"class_num": class_num, "xshape": xshape},
    logger,
)
search_loader, train_loader, valid_loader = get_nas_search_loaders(
    train_data,
    valid_data,
    xargs.dataset,
    "./configs/nas-benchmark/",
    (config.batch_size, config.test_batch_size),
    xargs.workers,
)
logger.log(
    "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
        xargs.dataset,
        len(search_loader),
        len(valid_loader),
        config.batch_size,
    )
)
logger.log(
    "||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)
)

## model
# Candidate operation names for the chosen search space.
search_space = get_search_spaces(xargs.search_space, "nats-bench")
logger.log("search space : {:}".format(search_space))

# Target device, indexed by --gpu.
device = torch.device('cuda:{}'.format(xargs.gpu))

Create API = NATStopology(0/15625 architectures, fast_mode=True, file=) done
Files already downloaded and verified
Files already downloaded and verified
./configs/nas-benchmark/algos/weight-sharing.config
Configure(scheduler='cos', LR=0.025, eta_min=0.001, epochs=100, warmup=0, optim='SGD', decay=0.0005, momentum=0.9, nesterov=True, criterion='Softmax', batch_size=64, test_batch_size=512, class_num=10, xshape=(1, 3, 32, 32))
||||||| cifar10    ||||||| Search-Loader-Num=391, Valid-Loader-Num=49, batch size=64
||||||| cifar10    ||||||| Config=Configure(scheduler='cos', LR=0.025, eta_min=0.001, epochs=100, warmup=0, optim='SGD', decay=0.0005, momentum=0.9, nesterov=True, criterion='Softmax', batch_size=64, test_batch_size=512, class_num=10, xshape=(1, 3, 32, 32))
search space : ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']


In [5]:
def random_genotype(max_nodes, op_names):
    """Sample a random cell by drawing one operation per edge.

    Node ``i`` (for ``1 <= i < max_nodes``) receives one incoming edge from every
    earlier node ``j < i``; the operation on each edge is drawn independently with
    ``random.choice``, so the result depends on the state of Python's global
    ``random`` generator.

    Args:
        max_nodes: Number of nodes in the cell.
        op_names: Non-empty sequence of candidate operation names.

    Returns:
        A ``Structure`` built from a list with one tuple per node ``i``, each
        containing ``(op_name, j)`` pairs for ``j = 0 .. i-1``.
    """
    # Edges are visited in the same (i, then j) order as a nested for-loop, so the
    # sequence of random draws is unchanged.
    genotypes = [
        tuple((random.choice(op_names), j) for j in range(i))
        for i in range(1, max_nodes)
    ]
    return Structure(genotypes)

# Zero-shot scores that are handed the real training loader; every other score gets
# `trainloader=None`.
real_input_metrics = ['zico', 'snip', 'grasp', 'te_nas', 'gradsign']

# (log header, info_dict key) for each per-stage timing reported after the overall runtime.
_STAGE_TIME_SECTIONS = (
    ("------Runtime - forward------", "fwrd_time"),
    ("------Runtime - backwrad------", "bkwd_time"),
    ("------Runtime - exp/prog------", "exp_time"),
    ("------Runtime - trn------", "trn_time"),
)


def _score_and_profile(xargs, score_fn, candidates, trainloader, resolution, batch_size):
    """Score every architecture in ``candidates`` and profile each scoring call.

    For each architecture a fresh ``TinyNetwork`` is built (using the notebook-level
    ``class_num`` and ``device``), switched to train mode and passed to
    ``score_fn.compute_nas_score``. The call is bracketed by CUDA events, and the peak
    allocated CUDA memory is read once it has finished.

    Returns:
        ``(arch_list, score_dict, all_time, all_mem)`` where ``arch_list`` holds the
        evaluated architectures in order, ``score_dict`` maps every key returned by the
        score function to the list of per-architecture values (``None`` when nothing was
        evaluated), ``all_time`` holds the elapsed time of each call as reported by the
        CUDA events, and ``all_mem`` the peak allocated memory of each call.
    """
    arch_list = []
    score_dict = None
    all_time = []
    all_mem = []
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for arch in candidates:
        # Start every measurement from a clean allocator state and a fresh peak counter.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        network = TinyNetwork(xargs.channel, xargs.num_cells, arch, class_num)
        network = network.to(device)
        network.train()

        start.record()
        info_dict = score_fn.compute_nas_score(
            network,
            gpu=xargs.gpu,
            trainloader=trainloader,
            resolution=resolution,
            batch_size=batch_size,
        )
        end.record()
        # elapsed_time() needs both events to have completed.
        torch.cuda.synchronize()
        all_time.append(start.elapsed_time(end))
        # Alternative measure: torch.cuda.max_memory_reserved().
        all_mem.append(torch.cuda.max_memory_allocated())

        arch_list.append(arch)
        if score_dict is None:
            # The keys of the first result define the columns of the output dict.
            score_dict = {key: [] for key in info_dict}
        for key, value in info_dict.items():
            score_dict[key].append(value)

    return arch_list, score_dict, all_time, all_mem


def _log_profile_summary(all_time, all_mem, score_dict):
    """Log mean runtime, mean per-stage runtimes and mean/max peak memory.

    Per-stage sections are emitted only for the timing keys that the score function
    actually returned, so scores without a timing breakdown can still be summarised.
    """
    logger.log("------Runtime------")
    logger.log("All: {:.5f} ms".format(np.mean(all_time)))
    for header, key in _STAGE_TIME_SECTIONS:
        if key in score_dict:
            logger.log(header)
            logger.log("All: {:.5f} ms".format(np.mean(score_dict[key])))
    logger.log("------Avg Mem------")
    logger.log("All: {:.5f} GB".format(np.mean(all_mem)/1e9))
    logger.log("------Max Mem------")
    logger.log("All: {:.5f} GB".format(np.max(all_mem)/1e9))


def search_find_best(xargs, xloader, n_samples=None, archs=None):
    """Compute a zero-shot score for a set of architectures and log runtime/memory stats.

    The score implementation is looked up as the module-level name
    ``compute_<zero_shot_score>_score`` and its ``compute_nas_score`` attribute is
    called once per architecture. Relies on the notebook-level names ``logger``,
    ``search_space``, ``class_num``, ``device`` and ``train_loader``.

    Args:
        xargs: Parsed arguments; reads ``zero_shot_score``, ``max_nodes``, ``channel``,
            ``num_cells`` and ``gpu``.
        xloader: Iterable of ``(input, target)`` batches. Only the first batch is read,
            to obtain the batch size (dim 0) and the resolution (dim 2).
        n_samples: Number of architectures to draw with ``random_genotype``.
        archs: Explicit iterable of architectures to evaluate.
            Exactly one of ``n_samples`` and ``archs`` must be given.

    Returns:
        ``(arch_list, zero_shot_score_dict)``: the evaluated architectures in order, and
        a dict mapping each key of the score function's result to the list of
        per-architecture values (``None`` if no architecture was evaluated).

    Raises:
        ValueError: If neither or both of ``n_samples``/``archs`` are given, or if no
            score implementation is found for ``xargs.zero_shot_score``.
    """
    # Previously this case fell through both branches and silently returned ([], None).
    if (archs is None) == (n_samples is None):
        raise ValueError("Provide exactly one of `n_samples` or `archs`.")

    metric = xargs.zero_shot_score.lower()
    logger.log("Searching with {}".format(metric))
    score_fn_name = "compute_{}_score".format(metric)
    score_fn = globals().get(score_fn_name)
    if score_fn is None:
        # Fail here with a clear message instead of an AttributeError on None later.
        raise ValueError(
            "Unknown zero-shot score '{}': no module-level name '{}' found.".format(
                xargs.zero_shot_score, score_fn_name
            )
        )

    input_, _ = next(iter(xloader))
    resolution = input_.size(2)
    batch_size = input_.size(0)

    if metric in real_input_metrics:
        print('Use real images as inputs')
        # NOTE(review): this reads the notebook-level `train_loader`, not `xloader` —
        # confirm that is intended.
        trainloader = train_loader
    else:
        print('Use random inputs')
        trainloader = None

    if archs is None:
        # Lazily sample one random architecture per iteration.
        candidates = (
            random_genotype(xargs.max_nodes, search_space) for _ in range(n_samples)
        )
    else:
        candidates = archs

    # `total` is n_samples for the generator; for `archs` it is None, in which case
    # tqdm falls back to len(archs) when available.
    arch_list, zero_shot_score_dict, all_time, all_mem = _score_and_profile(
        xargs,
        score_fn,
        tqdm.tqdm(candidates, total=n_samples),
        trainloader,
        resolution,
        batch_size,
    )

    if arch_list:
        _log_profile_summary(all_time, all_mem, zero_shot_score_dict)
    else:
        # np.max() raises on an empty list, so there is nothing to summarise.
        logger.log("No architectures were evaluated; skipping the runtime summary.")

    return arch_list, zero_shot_score_dict

In [6]:
# ######### search across random N archs #########
# archs, results = search_find_best(xargs, train_loader, n_samples=100)

# ######### search across N archs uniformly sampled according to test acc. #########
# def uniform_sample_archs(search_space, xargs, api, n_samples=1000, dataset='ImageNet16-120'):
#     arch = random_genotype(xargs.max_nodes, search_space)
#     search_space = get_search_spaces(xargs.search_space, "nats-bench")
#     archs = arch.gen_all(search_space, xargs.max_nodes, False)
    
#     def get_results_from_api(api, arch, dataset='cifar10'):
#         dataset_candidates = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
#         assert dataset in dataset_candidates
#         index = api.query_index_by_arch(arch)
#         api._prepare_info(index)
#         archresult = api.arch2infos_dict[index]['200']
#         if dataset == 'cifar10-valid':
#             acc = archresult.get_metrics(dataset, 'x-valid', iepoch=None, is_random=False)['accuracy']
#         elif dataset == 'cifar10':
#             acc = archresult.get_metrics(dataset, 'ori-test', iepoch=None, is_random=False)['accuracy']
#         else:
#             acc = archresult.get_metrics(dataset, 'x-test', iepoch=None, is_random=False)['accuracy']
#         return acc

#     accs = []
#     for a in archs:
#         accs.append(get_results_from_api(api, a, dataset))
#     interval = len(archs) // n_samples
#     sorted_indices = np.argsort(accs)
#     new_archs = []
#     for i, idx in enumerate(sorted_indices):
#         if i % interval == 0:
#             new_archs.append(archs[idx])
#     archs = new_archs
    
#     return archs

# if os.path.exists("./tss_uniform_arch.pickle"):
#     with open("./tss_uniform_arch.pickle", "rb") as fp:
#         uniform_archs = pickle.load(fp)
# else:
#     uniform_archs = uniform_sample_archs(search_space, xargs, api, 1000, 'ImageNet16-120')
#     with open("./tss_uniform_arch.pickle", "wb") as fp:
#         pickle.dump(uniform_archs, fp)
        
# result_path = "./{}_uniform_arch.pickle".format(args.zero_shot_score)
# if os.path.exists(result_path):
#     print("results already exists")
#     with open(result_path, "rb") as fp:
#         results = pickle.load(fp)
#     archs = uniform_archs
# else:
#     archs, results = search_find_best(xargs, train_loader, archs=uniform_archs)
#     with open(result_path, "wb") as fp:
#         pickle.dump(results, fp)


######### search across all archs #########
def generate_all_archs(search_space, xargs):
    """Return the result of ``gen_all`` for the given search space.

    ``gen_all`` is invoked on an architecture object, so one is first obtained from
    ``random_genotype`` (which draws from Python's global ``random`` generator).
    Presumably the result is every architecture in the search space — confirm against
    the ``gen_all`` implementation.

    Args:
        search_space: Candidate operation names.
        xargs: Parsed arguments; only ``max_nodes`` is read.
    """
    seed_arch = random_genotype(xargs.max_nodes, search_space)
    # NOTE(review): the trailing False is forwarded to gen_all unchanged; its meaning
    # is defined there.
    return seed_arch.gen_all(search_space, xargs.max_nodes, False)

# The architecture list is cached on disk: generate and store it on the first run,
# reload it afterwards.
# NOTE: pickle.load executes arbitrary code from the file; only use a cache file that
# this notebook wrote itself.
if not os.path.exists("./tss_all_arch.pickle"):
    all_archs = generate_all_archs(search_space, xargs)
    with open("./tss_all_arch.pickle", "wb") as cache_file:
        pickle.dump(all_archs, cache_file)
else:
    with open("./tss_all_arch.pickle", "rb") as cache_file:
        all_archs = pickle.load(cache_file)

# Score every cached architecture with the configured zero-shot metric.
archs, results = search_find_best(xargs, train_loader, archs=all_archs)

# ####
# result_path = "./{}_all_arch.pickle".format(args.zero_shot_score)
# if os.path.exists(result_path):
#     print("results already exists")
#     with open(result_path, "rb") as fp:
#         results = pickle.load(fp)
#     archs = all_archs
# else:
#     archs, results = search_find_best(xargs, train_loader, archs=all_archs)
#     with open(result_path, "wb") as fp:
#         pickle.dump(results, fp)

Searching with az_nas_time
Use random inputs


100%|██████████| 15625/15625 [18:45<00:00, 13.88it/s]

------Runtime------
All: 45.56091 ms
------Runtime - forward------
All: 7.49732 ms
------Runtime - backwrad------
All: 9.61485 ms
------Runtime - exp/prog------
All: 10.51265 ms
------Runtime - trn------
All: 13.63413 ms
------Avg Mem------
All: 0.28326 GB
------Max Mem------
All: 0.53120 GB



