# Reproduction of the `epsilon` metric

If you wish to reproduce the results presented in our paper from scratch, feel free to use the code below.
In this notebook, we provide the code to reproduce the results for the NAS-Bench-101 search space on the CIFAR-10 dataset.

In [None]:
import os
import json
import time
import itertools

import numpy as np
import pandas as pd
import pickle as pkl
from scipy import stats
from tqdm import trange
from dotmap import DotMap

import torch.nn as nn
import torch

import nasspace
from datasets import data
from epsinas_utils import prepare_seed, compute_epsinas

In [None]:
# ---- Experiment configuration ----
dataset = 'cifar10'                # dataset name; also used to build the results path
data_loc = './datasets/cifardata'  # data root handed to data.get_data
batch_size, repeat = 256, 1        # loader batch size and repeat factor
GPU = '0'                          # written to CUDA_VISIBLE_DEVICES
augtype = 'none'                   # augmentation setting handed to data.get_data
trainval = False                   # trainval flag handed to data.get_data

In [None]:
# Arguments consumed by nasspace.get_search_space (and forwarded to
# data.get_data) to initialise the NAS-Bench-101 search space.
args = DotMap(
    api_loc='./nasbench_only108.tfrecord',
    nasspace='nasbench101',
    dataset=dataset,
    stem_out_channels=128,
    num_stacks=3,
    num_modules_per_stack=3,
    num_labels=1,
)

# Any dataset name containing 'fake' is collapsed to plain 'fake' for the
# loader; savedataset keeps the name with 'fake' stripped, plus a '-valid'
# suffix for CIFAR-10.
savedataset = dataset.replace('fake', '')
if 'fake' in dataset:
    dataset = 'fake'
if savedataset == 'cifar10':
    savedataset += '-valid'

In [None]:
# Build the NAS-Bench-101 search space from `args`. This is slow; presumably
# most of the time goes into reading the benchmark record at args.api_loc.
searchspace = nasspace.get_search_space(args)

In [None]:
# Drop the '-valid' suffix again. A no-op when it is absent, so no guard is
# needed.
savedataset = savedataset.replace('-valid', '')

# Accuracy keys passed to searchspace.get_final_accuracy: the validation key
# is the same for every dataset; only the test key differs for CIFAR-10.
val_acc_type = 'x-valid'
acc_type = 'ori-test' if args.dataset == 'cifar10' else 'x-test'

In [None]:
# Expose only the configured GPU. Set this before any CUDA call so it takes
# effect, then fall back to the CPU when CUDA is unavailable.
os.environ['CUDA_VISIBLE_DEVICES'] = GPU
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Reproducibility: no cuDNN autotuning, deterministic kernels only.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Fix the random seed via epsinas_utils.prepare_seed (presumably seeds the
# Python/NumPy/torch RNGs — see its definition).
prepare_seed(21)

In [None]:
# Build the training loader and keep only the inputs of its first mini-batch;
# this single batch `x` is what every network is scored on in the loop below.
train_loader = data.get_data(
    dataset, data_loc, trainval, batch_size, augtype, repeat, args
)
data_iterator = iter(train_loader)
x, _ = next(data_iterator)  # labels are not needed for scoring
x = x.to(device)

In [None]:
# Score every architecture in the search space with compute_epsinas and
# collect its parameter count and benchmark accuracies. Results are cached as
# a pickle so re-running the notebook reuses them instead of re-evaluating.
save_dir = f'../epsinas-release-data/NAS-Bench-101/evaluation/{dataset.upper()}'
os.makedirs(save_dir, exist_ok=True)

datafile_name = f'{save_dir}/data_NAS-Bench-101_{dataset.upper()}_test'

if os.path.exists(datafile_name):
    # Load precomputed results. NOTE: unpickling can execute arbitrary code,
    # so only load cache files produced by this notebook.
    with open(datafile_name, 'rb') as datafile:
        input_data = pkl.load(datafile)
    scores = input_data["scores"]
    test_accs_mean = input_data["test_accs_mean"]
    test_accs_min = input_data["test_accs_min"]
    test_accs_max = input_data["test_accs_max"]
    val_accs_mean = input_data["val_accs_mean"]
    val_accs_min = input_data["val_accs_min"]
    val_accs_max = input_data["val_accs_max"]
    nparams = input_data["nparams"]
else:
    # Weight values forwarded to compute_epsinas.
    weights = [1e-4, 10]
    test_accs_mean = []
    test_accs_min = []
    test_accs_max = []
    val_accs_mean = []
    val_accs_min = []
    val_accs_max = []
    nparams = []
    scores = []
    times = []  # per-architecture wall time (not saved to the cache)

    for i in trange(len(searchspace)):
        start = time.perf_counter()  # monotonic clock for interval timing
        uid = searchspace[i]
        network = searchspace.get_network(uid)
        network = network.to(device)
        score = compute_epsinas(x, network, weights)
        scores.append(score)
        nparams.append(sum(p.numel() for p in network.parameters()))

        # Query each accuracy type once and use entries 0/1/2 as mean/min/max.
        # Previously each entry came from a separate call; a single call avoids
        # redundant lookups and keeps the triple consistent even if the
        # underlying query is not deterministic.
        test_stats = searchspace.get_final_accuracy(uid, acc_type, False)
        test_accs_mean.append(test_stats[0])
        test_accs_min.append(test_stats[1])
        test_accs_max.append(test_stats[2])

        val_stats = searchspace.get_final_accuracy(uid, val_acc_type, False)
        val_accs_mean.append(val_stats[0])
        val_accs_min.append(val_stats[1])
        val_accs_max.append(val_stats[2])

        times.append(time.perf_counter() - start)

    # Save your results
    save_dic = {
        "scores": scores,
        "test_accs_mean": test_accs_mean,
        "test_accs_min": test_accs_min,
        "test_accs_max": test_accs_max,
        "val_accs_mean": val_accs_mean,
        "val_accs_min": val_accs_min,
        "val_accs_max": val_accs_max,
        "nparams": nparams,
    }

    # Write to a temporary file, then atomically move it into place, so an
    # interrupted dump never leaves a truncated cache that the load branch
    # above would later pick up.
    tmp_name = f'{datafile_name}.tmp'
    with open(tmp_name, 'wb') as datafile:
        pkl.dump(save_dic, datafile)
    os.replace(tmp_name, datafile_name)