In [None]:
# --- Experiment configuration -------------------------------------------------
# Name of the benchmark dataset; passed to the dataset loader and to the
# NAS-Bench-201 API queries further down.
dataset = 'ImageNet16-120' # choose between 'ImageNet16-120', 'cifar10' and 'cifar100'
# Directory holding the raw image data for `dataset`.
data_loc = '../datasets/ImageNet16' # choose ImageNet16 for ImageNet16-120 and cifar for cifar10 and cifar100
# Path of the NAS-Bench-201 benchmark file handed to the API constructor.
api_loc = '../datasets/NAS-Bench-201-v1_1-096897.pth'
# Number of independent search runs (iterations of the outer loop).
n_runs = 500
# Number of re-initialisations evaluated for every sampled architecture.
n_init = 100
# Number of architectures sampled in each run.
n_samples = 1000
# Mini-batch size of the training loader; one batch is drawn per run.
batch_size = 256
# When True the benchmark is queried under the 'cifar10-valid' key and the
# training loader is restricted to the split listed in config_utils/cifar-split.txt.
# NOTE(review): only appears meaningful together with dataset == 'cifar10' — confirm.
trainval = False # set to True to get access to the validation error for cifar10
# Value written to the CUDA_VISIBLE_DEVICES environment variable.
GPU = '0' # choose the GPU to be used

In [None]:
import os
import time
import argparse
import random
import numpy as np
from tqdm import trange
from statistics import mean

%matplotlib inline

import matplotlib.pyplot as plt
import matplotlib.colors as colors

# Restrict CUDA to the configured GPU. This is set here, before torch is
# imported below.
os.environ['CUDA_VISIBLE_DEVICES'] = GPU
# Global experiment seed: seeds the random number generators further down and
# is embedded in the name of the results file saved at the end.
seed = 1

# Directory prefix of the serialized results file written at the end.
save_loc = 'results'

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torch.optim as optim

from models import get_cell_based_tiny_net

# Release cached GPU memory held by PyTorch's allocator before the experiment starts.
torch.cuda.empty_cache()

# Reproducibility: request deterministic cuDNN kernels and disable the cuDNN
# benchmark mode (auto-tuner).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Setting the seed
def prepare_seed(rand_seed):
    """Seed PyTorch's CPU and CUDA random number generators with `rand_seed`.

    Only the torch generators are touched; Python's `random` module and NumPy
    are left as they are.
    """
    torch_seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in torch_seeders:
        seed_fn(rand_seed)

# Colormap for plots
def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """Return a new colormap covering only the [minval, maxval] part of `cmap`.

    The source colormap is sampled at `n` evenly spaced positions between
    `minval` and `maxval`, and the samples are turned into a
    `LinearSegmentedColormap` whose name records the source name and bounds.
    """
    label = 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval)
    sampled_colors = cmap(np.linspace(minval, maxval, n))
    return colors.LinearSegmentedColormap.from_list(label, sampled_colors)

def plot_test(scores, accs, nparams):
    """Scatter-plot scores against trained accuracies, coloured by model size.

    Args:
        scores: sequence of per-architecture scores (y axis).
        accs: sequence of trained accuracies (x axis).
        nparams: sequence of parameter counts; log10 of these values drives
            the point colour, with the colour range spanning min to max.

    The figure is shown with `plt.show()`; nothing is returned.
    """
    cmap = plt.get_cmap('inferno_r')
    new_cmap = truncate_colormap(cmap, 0.12, 0.63)
    log_nparams = np.log10(np.array(nparams))

    plt.figure(figsize=(8, 6))

    plt.box(on=None)
    plt.grid(color='#dbdbd9', linewidth=0.5)
    plt.xlabel('Trained accuracy', fontsize=12)
    # Raw string: '\s' is not a valid escape sequence in a normal string
    # literal and triggers a warning on recent Python versions.
    plt.ylabel(r'$\sigma_{R}$', fontsize=12)
    plt.scatter(np.array(accs),
                np.array(scores),
                s=20,
                c=log_nparams,
                cmap=new_cmap,
                vmin=log_nparams.min(),
                vmax=log_nparams.max(),
                )
    plt.show()

    
def initialize_resnet(m):
    """Re-initialise the parameters of a single module in place.

    Written to be passed to `nn.Module.apply`, which calls it once per
    sub-module:

    * `nn.Conv2d`: Kaiming-normal weights (fan-out, ReLU), zero bias.
    * `nn.BatchNorm2d`: weight set to 1, bias set to 0.
    * `nn.Linear`: weights drawn from N(0, 0.01), zero bias.

    Parameters that are absent (`None`, e.g. `bias=False` or a non-affine
    batch norm) are skipped. Other module types are left untouched.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        # weight/bias are None when the layer was built with affine=False
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        # bias is None when the layer was built with bias=False
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

# Seed every random number generator used below with the global seed:
# Python's `random`, NumPy, and (through prepare_seed) torch on CPU and CUDA.
random.seed(seed)
np.random.seed(seed)
prepare_seed(seed)

import torchvision.transforms as transforms
from datasets import get_datasets
from config_utils import load_config
from nas_201_api import NASBench201API as API
    
# Run on the first visible GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
THE_START = time.time()
# Load the benchmark of pre-computed architecture results.
api = API(api_loc, verbose = False)

train_data, valid_data, xshape, class_num = get_datasets(dataset, data_loc, cutout=0)

# Metric keys used when querying the benchmark: cifar10 reports its test
# accuracy under a different key than the other datasets.
acc_type = 'ori-test' if dataset == 'cifar10' else 'x-test'
val_acc_type = 'x-valid'

# Options shared by both loader variants.
loader_kwargs = dict(batch_size=batch_size, num_workers=0, pin_memory=True)
if trainval:
    # Draw batches only from the training indices listed in the split file.
    cifar_split = load_config('config_utils/cifar-split.txt', None, None)
    train_split, valid_split = cifar_split.train, cifar_split.valid
    loader_kwargs['sampler'] = torch.utils.data.sampler.SubsetRandomSampler(train_split)
else:
    # Draw batches from the whole training set in shuffled order.
    loader_kwargs['shuffle'] = True
train_loader = torch.utils.data.DataLoader(train_data, **loader_kwargs)

# Per-run records, one entry appended per search run.
times, chosen, acc, val_acc, topscores = [], [], [], [], []

# Dataset key used for benchmark queries.
dset = 'cifar10-valid' if trainval else dataset

# Output directory named after the experiment settings; created if missing.
save_dir = './results/{}bs_{}it_{}runs_{}arch/'.format(batch_size, n_init, n_runs, n_samples)
os.makedirs(save_dir, exist_ok=True)
# The log file name gets a 'val' marker when running in trainval mode.
log_suffix = 'val.txt' if trainval else '.txt'
logs_filename = save_dir + 'logs_' + dataset.upper() + log_suffix
# Search loop. Each run samples `n_samples` architectures and scores every one
# from the accuracies of `n_init` untrained re-initialisations on a single
# mini-batch (score = std / mean of those accuracies). The architecture with
# the lowest score is selected and its benchmark accuracy is recorded and logged.
with open(logs_filename, 'w') as logs:
    logs.write("Starting logging...\n")
    ind_actual_best_mean = 0
    runs = trange(n_runs, desc='acc: ')
    for N in runs:
        start = time.time()
        # Randomly select n_samples architecture indices (drawn with replacement).
        # NOTE(review): 15625 is presumably the number of architectures in the
        # benchmark — confirm against the API.
        indices = np.random.randint(0,15625,n_samples)
        # A single mini-batch is drawn per run and shared by all architectures.
        data_iterator = iter(train_loader)
        x, target = next(data_iterator)
        x, target = x.to(device), target.to(device)
        # Actual number of examples in the batch; it can be smaller than
        # `batch_size` when the loader holds fewer samples.
        n_examples = target.shape[0]
        scores = []
        acc_run = []
        acc_run_cur = []  # not used in this block; kept in case a later cell reads it
        nparams = []
        for arch in indices:
            config = api.get_net_config(arch, dataset)
            info = api.query_by_index(arch, hp='200')
            acc_run.append(info.get_metrics(dset, acc_type)['accuracy'])
            network = get_cell_based_tiny_net(config)  # create the network from configuration
            # Compute the number of parameters
            nparams.append(sum(p.numel() for p in network.parameters()))
            network = network.to(device)
            untrained_acc = []
            # Test the same network with different initializations.
            # The loop variable is deliberately NOT called `seed`: that would
            # overwrite the module-level experiment seed that is later embedded
            # in the results filename.
            for init_seed in range(n_init):
                # Initialise the network with the seed
                prepare_seed(init_seed)
                network.apply(initialize_resnet)
                # Propagate through the network once; no gradients are needed,
                # so skip building the autograd graph.
                with torch.no_grad():
                    _, y_pred = network(x)
                # Get predictions
                label_pred = torch.argmax(y_pred, dim=1)
                # Compute accuracy
                coinc = torch.sum(torch.eq(label_pred, target))
                pred_accuracy = coinc.item() / n_examples
                untrained_acc.append(pred_accuracy)
            acc_spread = np.std(untrained_acc)
            # Condition takes care of random informationless networks, that give 0 standard deviation
            if acc_spread > 0.0:
                # Compute the score
                scores.append(acc_spread / np.mean(untrained_acc))
            else:
                # Otherwise set the score to some large value
                scores.append(999)

        # Rank of the chosen architecture among the sampled ones, by benchmark
        # accuracy (0 = best). The sort is in place, so `acc_run` is no longer
        # aligned with `indices`/`scores` afterwards.
        acc_run.sort(reverse=True)
        best_pos = np.argmin(scores)
        best_arch = indices[best_pos]
        info_best = api.query_by_index(best_arch, hp='200')
        best_acc = info_best.get_metrics(dset, acc_type)['accuracy']
        ind_actual_best = acc_run.index(best_acc)
        ind_actual_best_mean += ind_actual_best

        topscores.append(scores[best_pos])
        chosen.append(best_arch)
        acc.append(best_acc)

        if not dataset == 'cifar10' or trainval:
            val_acc.append(info_best.get_metrics(dset, val_acc_type)['accuracy'])

        # Running mean over the runs so far (validation accuracy in trainval mode).
        running_mean = mean(acc if not trainval else val_acc)
        logs.write(f"Mean acc: {running_mean:.2f}% ")
        logs.write(f"Actual ranking: {ind_actual_best} \n")
        times.append(time.time()-start)
        runs.set_description(f"mean acc: {running_mean:.2f}%")

    logs.write(f"Average chosen architecure's rank: {ind_actual_best_mean/n_runs} \n")
    logs.write(f"Final mean test accuracy: {np.mean(acc)} +- {np.std(acc)} \n")
    logs.write(f"Median duration: {np.median(times)} \n")
    if len(val_acc) > 1:
        logs.write(f"Final mean validation accuracy: {np.mean(val_acc)} +- {np.std(val_acc)} \n")
    

# Bundle the per-run records and serialize them with torch.save.
state_keys = ('accs', 'val_accs', 'chosen', 'times', 'topscores')
state_values = (acc, val_acc, chosen, times, topscores)
state = dict(zip(state_keys, state_values))

dset = 'cifar10-valid' if trainval else dataset
# NOTE: `seed` is read at this point, so any reassignment of that name earlier
# in the notebook changes the filename.
fname = f"{save_loc}/{dset}_{n_runs}_{n_samples}_{seed}.t7"
torch.save(state, fname)