# Reproduction of the `epsilon` metric

If you wish to reproduce the results presented in our paper from scratch, feel free to use the code below.
In this notebook, we provide the code to reproduce the results for the NAS-Bench-201 search space on the CIFAR10, CIFAR100 and ImageNet16-120 datasets.

In [None]:
import os
import json
import time
import itertools

import numpy as np
import pandas as pd
import pickle as pkl
import seaborn as sns
from scipy import stats
from tqdm import trange
from dotmap import DotMap
from statistics import mean
import matplotlib.pyplot as plt

import torch

import nasspace
from datasets import data
from epsinas_utils import prepare_seed, compute_epsinas

## Select the dataset
Choose the dataset whose results you would like to reproduce.
Choose among: 'cifar10', 'cifar100', 'ImageNet16-120'

In [None]:
# Dataset to evaluate; one of 'cifar10', 'cifar100', 'ImageNet16-120'.
dataset = 'cifar10'

In [None]:
# ImageNet16-120 is read from its own sub-folder; the CIFAR datasets use the root folder.
data_loc = './datasets/ImageNet16' if dataset == 'ImageNet16-120' else './datasets'

In [None]:
# Size of the single input batch fed to data.get_data (only one batch is scored).
batch_size = 256
# Forwarded to data.get_data as-is (semantics defined in `datasets.data`).
repeat = 1
# Written into CUDA_VISIBLE_DEVICES before the device is chosen.
GPU = '0'
# Augmentation setting forwarded to data.get_data.
augtype = 'none'
# Forwarded to data.get_data; presumably selects the train+val split — confirm.
trainval = True

In [None]:
# Arguments consumed by nasspace.get_search_space for the NAS-Bench-201 search space.
args = DotMap()
args.nasspace = 'nasbench201'
args.dataset = dataset
args.api_loc = './api/NAS-Bench-201-v1_1-096897.pth'

# `dataset` collapses to 'fake' when the chosen name carries a 'fake' marker;
# `savedataset` keeps the name with that marker stripped (and a '-valid' suffix for cifar10).
savedataset = dataset
if 'fake' in savedataset:
    dataset = 'fake'
savedataset = savedataset.replace('fake', '')
if savedataset == 'cifar10':
    savedataset += '-valid'

In [None]:
# Load the search space (it takes some time)
# Builds the NAS-Bench-201 search space from `args` (nasspace, dataset, api_loc set above).
# NOTE(review): the slow part is presumably loading the benchmark file at args.api_loc — confirm.
searchspace = nasspace.get_search_space(args)

In [None]:
# Drop the '-valid' suffix again; replace() is a no-op when the suffix is absent.
savedataset = savedataset.replace('-valid', '')

# Keys passed to searchspace.get_final_accuracy: CIFAR-10 test accuracy is read
# under 'ori-test', the other datasets under 'x-test'; validation is always 'x-valid'.
acc_type = 'ori-test' if dataset == 'cifar10' else 'x-test'
val_acc_type = 'x-valid'

In [None]:
# Restrict CUDA to the selected GPU; set before any CUDA query so it takes effect.
os.environ['CUDA_VISIBLE_DEVICES'] = GPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Reproducibility: no cuDNN autotuning, deterministic kernels, fixed seed.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
prepare_seed(21)

In [None]:
# Build the training loader. On first use the data is downloaded,
# presumably into `data_loc` — confirm in `datasets.data.get_data`.
train_loader = data.get_data(
    dataset, data_loc, trainval, batch_size, augtype, repeat, args
)

# A single input batch is enough to score the untrained networks; labels are discarded.
data_iterator = iter(train_loader)
x, _ = next(data_iterator)
x = x.to(device)

In [None]:
# Compute (or reload cached) epsilon scores, parameter counts and accuracies
# for every architecture in the NAS-Bench-201 search space.
save_dir = f'../epsinas-release-data/NAS-Bench-201/evaluation/{dataset.upper()}'
os.makedirs(save_dir, exist_ok=True)

datafile_name = f'{save_dir}/data_NAS-Bench-201_{dataset.upper()}'

if os.path.exists(datafile_name):
    # Load precomputed results.
    # NOTE: pickle can execute arbitrary code on load; only point this at files
    # produced by this notebook.
    with open(datafile_name, 'rb') as datafile:
        input_data = pkl.load(datafile)
    scores = input_data["scores"]
    test_accs = input_data["test_accs"]
    val_accs = input_data["val_accs"]
    nparams = input_data["nparams"]
    # Keep `times` defined in both branches; older cache files may lack the key.
    times = input_data.get("times", [])
else:
    weights = [1e-7, 1]
    test_accs = []
    val_accs = []
    nparams = []
    scores = []
    times = []
    # The validation query passes True as third argument only for CIFAR-10
    # (its meaning is defined by searchspace.get_final_accuracy).
    val_flag = dataset == 'cifar10'
    for i in trange(len(searchspace)):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        uid = searchspace[i]
        network = searchspace.get_network(uid).to(device)
        score = compute_epsinas(x, network, weights)
        scores.append(score)
        nparams.append(sum(p.numel() for p in network.parameters()))
        test_accs.append(searchspace.get_final_accuracy(uid, acc_type, False))
        val_accs.append(searchspace.get_final_accuracy(uid, val_acc_type, val_flag))
        times.append(time.perf_counter() - start)

    # Save results so later runs can reload them instead of recomputing.
    save_dic = {
        "scores": scores,
        "nparams": nparams,
        "test_accs": test_accs,
        "val_accs": val_accs,
        "times": times,
    }
    # Context manager guarantees the pickle is flushed and the handle closed.
    with open(datafile_name, "wb") as outfile:
        pkl.dump(save_dic, outfile)