In [None]:
import pickle
from scipy import stats
from tqdm import tqdm
from prettytable import PrettyTable
from nas_201_api import NASBench201API as API
import numpy as np
from functools import cmp_to_key

# Location of the NAS-Bench-201 v1.1 benchmark file (relative to the notebook);
# change it here if the data lives elsewhere.
NB201_PATH = './data/NAS-Bench-201-v1_1-096897.pth'
api = API(NB201_PATH, verbose=False)
print(len(api))  # number of architectures indexed by the benchmark (api[i] for i < len(api))

## Validation and test accuracy on NAS-Bench-201

In [None]:
# For every NAS-Bench-201 architecture, record its validation and its test accuracy on
# each of the three datasets. The outer keys are the dataset names used throughout the
# notebook; the 'cifar10' entries are queried from NB201's 'cifar10-valid' dataset key.
# NOTE(review): the validation queries rely on get_more_info's default hp / is_random
# arguments, whereas the test queries pin hp='200', is_random=False — confirm intended.
_nb201_datasets = (
    ('cifar10', 'cifar10-valid'),
    ('cifar100', 'cifar100'),
    ('ImageNet16-120', 'ImageNet16-120'),
)
val_accs = {name: {} for name, _ in _nb201_datasets}
accs = {name: {} for name, _ in _nb201_datasets}
# Per-dataset aliases (same dict objects as the entries above).
val_acc_c10, val_acc_c100, val_acc_im16 = (val_accs[name] for name, _ in _nb201_datasets)
acc_c10, acc_c100, acc_im16 = (accs[name] for name, _ in _nb201_datasets)

for i in tqdm(range(len(api))):
    arch_str = api[i]
    index = api.query_index_by_arch(arch_str)
    # All validation queries first, then all test queries (same order as before).
    for name, nb201_name in _nb201_datasets:
        val_accs[name][arch_str] = api.get_more_info(index, nb201_name)['valid-accuracy']
    for name, nb201_name in _nb201_datasets:
        info = api.get_more_info(index, nb201_name, iepoch=None, hp='200', is_random=False)
        accs[name][arch_str] = info['test-accuracy']

## Collect the result files (`nb2*.p` path, dataset) and the names of all metrics

In [None]:
import os, glob, pickle
def search_file(pattern, search_path):
    """Lazily yield every path matching the glob ``pattern``.

    ``search_path`` may list several directories separated by ``os.pathsep``;
    each directory is searched in turn and its matches are yielded in glob order.
    """
    directories = search_path.split(os.pathsep)
    for directory in directories:
        yield from glob.glob(os.path.join(directory, pattern))

# Each nb2*.p file in the working directory holds proxy results for one dataset,
# inferred from a tag in the file name. glob returns files in arbitrary order, so the
# paths are sorted first: that makes the row order within a dataset, and the file
# metric_names is read from, reproducible.
path_list = sorted(search_file('nb2*.p', './'))


def _dataset_of(path):
    """Return (dataset name, display order) for a result file, based on its name."""
    if 'im120' in path:
        return 'ImageNet16-120', 3
    if 'cf100' in path:
        return 'cifar100', 2
    return 'cifar10', 1


# (path, dataset, order) triples, ordered cifar10 -> cifar100 -> ImageNet16-120.
p_list = sorted(((x,) + _dataset_of(x) for x in path_list), key=lambda x: x[2])
if not p_list:
    raise FileNotFoundError('no nb2*.p result files found in ./')
print(p_list)

# Metric names come from the first pickled record of the first file; 'val_acc' is
# prepended because it is filled from NAS-Bench-201, not from the result files.
# NOTE: pickle.load can execute arbitrary code — only load result files you trust.
with open(p_list[0][0], 'rb') as f:
    metric_names = ['val_acc'] + list(pickle.load(f)['logmeasures'].keys())
print(metric_names)

## Vote setting

In [None]:
# Proxy ensembles ("votes"). Each entry of vote_metric_names lists proxies that vote
# together; the matching entry of vote_signs holds one multiplier per proxy, which the
# vote functions apply to the proxy's values (1 keeps its direction, -1 flips it).
orignal_vote_metric_names = ['snip', 'synflow', 'jacob_cor']
orignal_vote_signs = [1, 1, 1]
our_vote_1_metric_names = ['synflow', 'jacob_cor', 'act_grad_cor_weighted']
our_vote_1_signs = [1, 1, 1]
vote_metric_names = [orignal_vote_metric_names, our_vote_1_metric_names]
vote_signs        = [orignal_vote_signs, our_vote_1_signs]
assert len(vote_metric_names) == len(vote_signs)
assert all(len(names) == len(signs) for names, signs in zip(vote_metric_names, vote_signs))

# Output column of each ensemble, in the same order as vote_metric_names; the
# correlation cells read them back from the tail of metric_names. Only missing names
# are appended so that re-running this cell does not produce duplicate column names.
for vote_out_name in ['ori_vote', 'our_vote']:
    if vote_out_name not in metric_names:
        metric_names.append(vote_out_name)
print(metric_names)

## Soft vote and Hard vote

In [None]:
def soft_vote(vote_metric_names, vote_signs, metrics, normlize='zscore'):
    """Soft ensemble: mean of per-proxy normalised, signed scores.

    Args:
        vote_metric_names: keys of ``metrics`` to combine.
        vote_signs: one multiplier per proxy, applied to its values before
            normalisation (e.g. -1 flips the proxy's direction).
        metrics: dict mapping metric name to a per-architecture sequence of values.
        normlize: 'zscore' (subtract mean, divide by std) or 'minmax'
            (rescale to [0, 1]). Statistics ignore NaN/inf entries.

    Returns:
        list with the averaged score per architecture; NaN/inf inputs propagate.

    Raises:
        ValueError: for an unknown ``normlize`` or an empty ``vote_metric_names``.
    """
    if normlize not in ('minmax', 'zscore'):
        raise ValueError('No normlize {}.'.format(normlize))
    if not vote_metric_names:
        raise ValueError('soft_vote needs at least one metric name')

    vote_metrics = []
    for i, k in enumerate(vote_metric_names):
        # Convert to an array BEFORE applying the sign: `sign * list` is list
        # repetition in Python, so a sign of -1 would silently give an empty list.
        metric_nparray = vote_signs[i] * np.asarray(metrics[k], dtype=float)
        finite = np.ma.masked_invalid(metric_nparray)  # NaN/inf excluded from the stats
        if normlize == 'minmax':
            lo, hi = finite.min(), finite.max()
            metric_nparray = (metric_nparray - lo) / (hi - lo)
        else:
            metric_nparray = (metric_nparray - finite.mean()) / finite.std()
        vote_metrics.append(metric_nparray)
    return (sum(vote_metrics) / len(vote_metrics)).tolist()

def hard_vote(vote_metric_names, vote_signs, metrics):
    """Rank architectures by pairwise majority vote over several proxies.

    Two architectures are compared proxy by proxy: each proxy votes for the one
    with the larger signed value. If more than half of the proxies agree, that
    architecture is ordered higher; otherwise the pair is treated as a tie.

    Args:
        vote_metric_names: keys of ``metrics`` whose values vote.
        vote_signs: one multiplier per proxy applied to value differences
            (e.g. -1 flips the proxy's direction).
        metrics: dict mapping metric name to a per-architecture sequence of values;
            the voting sequences are assumed to have equal length.

    Returns:
        list whose entry i is the position of architecture i in the ascending
        vote order (0 = lowest).

    Raises:
        ValueError: if ``vote_metric_names`` is empty.

    Note: the majority relation need not be transitive, so the resulting order can
    depend on the comparisons the sort happens to perform.
    """
    if not vote_metric_names:
        raise ValueError('hard_vote needs at least one metric name')
    # Size the ranking from a voting metric, not from whichever key happens to be
    # first in `metrics` (which may be an unrelated or still-empty column).
    num_archs = len(metrics[vote_metric_names[0]])
    majority = len(vote_metric_names) / 2

    def cmp(idx1, idx2):
        # A NaN difference is neither < 0 nor > 0, so that proxy abstains.
        ret = np.array([vote_signs[i] * (metrics[k][idx1] - metrics[k][idx2])
                        for i, k in enumerate(vote_metric_names)])
        if sum(ret < 0) > majority:
            return -1
        if sum(ret > 0) > majority:
            return 1
        return 0

    sorted_archs_idx = sorted(range(num_archs), key=cmp_to_key(cmp))
    # Invert the permutation in O(n); calling list.index per element is O(n^2),
    # which is slow for benchmark-sized inputs.
    archs_idx_ranking = [0] * num_archs
    for position, idx in enumerate(sorted_archs_idx):
        archs_idx_ranking[idx] = position
    return archs_idx_ranking

def vote(vote_metric_names, vote_signs, metrics, mode):
    """Dispatch to an ensemble method selected by substrings of ``mode``.

    'hard' anywhere in ``mode`` selects hard_vote. Otherwise ``mode`` must contain
    'soft' together with 'zscore' or 'minmax' (checked in that order), which selects
    soft_vote with that normalisation. Any other mode raises ValueError.
    """
    if 'hard' in mode:
        return hard_vote(vote_metric_names, vote_signs, metrics)
    if 'soft' not in mode:
        raise ValueError('No {:} mode'.format(mode))
    for normalisation in ('zscore', 'minmax'):
        if normalisation in mode:
            return soft_vote(vote_metric_names, vote_signs, metrics, normlize=normalisation)
    raise ValueError('No {:} mode'.format(mode))

## Kendall τ rank correlation between zero-cost proxies and test accuracy on NAS-Bench-201

In [None]:
# One table row per result file: Kendall's tau between the test accuracy and every
# metric column (validation accuracy, each proxy, and each vote ensemble).
hl = ['Dataset'] + metric_names
t = PrettyTable(hl)

# The vote outputs are the last len(vote_metric_names) entries of metric_names, in
# the same order as vote_metric_names; the remaining non-'val_acc' names are proxies.
vote_out_names = metric_names[len(metric_names) - len(vote_metric_names):]
proxy_names = [k for k in metric_names if k != 'val_acc' and k not in vote_out_names]

for fname, rname, _ in p_list:
    # Each file is a stream of pickled records; read them until EOF.
    # NOTE: pickle.load can execute arbitrary code — only load result files you trust.
    runs = []
    with open(fname, 'rb') as f:
        while True:
            try:
                runs.append(pickle.load(f))
            except EOFError:
                break
    print(fname, len(runs))

    metrics = {k: [] for k in metric_names}
    acc = []
    for r in runs:
        # A proxy missing from a record becomes NaN so every column stays aligned
        # with `acc`; NaNs are skipped by nan_policy='omit' below.
        for k in proxy_names:
            metrics[k].append(r['logmeasures'].get(k, np.nan))
        metrics['val_acc'].append(val_accs[rname][r['arch']])
        acc.append(accs[rname][r['arch']])

    for out_name, names, signs in zip(vote_out_names, vote_metric_names, vote_signs):
        metrics[out_name] = vote(names, signs, metrics, mode='hard')

    res = [round(stats.kendalltau(acc, metrics[k], nan_policy='omit')[0], 3)
           for k in metric_names]
    t.add_row([rname] + res)

print(t)

## Spearman ρ rank correlation between zero-cost proxies and test accuracy on NAS-Bench-201

In [None]:
# One table row per result file: Spearman's rho between the test accuracy and every
# metric column (validation accuracy, each proxy, and each vote ensemble).
hl = ['Dataset'] + metric_names
t = PrettyTable(hl)

# The vote outputs are the last len(vote_metric_names) entries of metric_names, in
# the same order as vote_metric_names; the remaining non-'val_acc' names are proxies.
vote_out_names = metric_names[len(metric_names) - len(vote_metric_names):]
proxy_names = [k for k in metric_names if k != 'val_acc' and k not in vote_out_names]

for fname, rname, _ in p_list:
    # Each file is a stream of pickled records; read them until EOF.
    # NOTE: pickle.load can execute arbitrary code — only load result files you trust.
    runs = []
    with open(fname, 'rb') as f:
        while True:
            try:
                runs.append(pickle.load(f))
            except EOFError:
                break
    print(fname, len(runs))

    metrics = {k: [] for k in metric_names}
    acc = []
    for r in runs:
        # A proxy missing from a record becomes NaN so every column stays aligned
        # with `acc`; NaNs are skipped by nan_policy='omit' below.
        for k in proxy_names:
            metrics[k].append(r['logmeasures'].get(k, np.nan))
        metrics['val_acc'].append(val_accs[rname][r['arch']])
        acc.append(accs[rname][r['arch']])

    for out_name, names, signs in zip(vote_out_names, vote_metric_names, vote_signs):
        metrics[out_name] = vote(names, signs, metrics, mode='hard')

    res = [round(stats.spearmanr(acc, metrics[k], nan_policy='omit')[0], 3)
           for k in metric_names]
    t.add_row([rname] + res)

print(t)

## Original paper results