In [None]:
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import scipy.stats as stats

sys.path.insert(0, '../../')

from xautodl.config_utils import load_config, dict2config, configure2str
from xautodl.datasets import get_datasets, get_nas_search_loaders
from xautodl.procedures import (
    prepare_seed,
    prepare_logger,
    save_checkpoint,
    copy_checkpoint,
    get_optim_scheduler,
)
from xautodl.utils import get_model_infos, obtain_accuracy
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API

# Order CUDA devices by PCI bus id and expose only physical GPU 2 to this process.
# NOTE(review): these variables are set after `import torch`; this presumably still
# works because CUDA is initialised lazily — safer to move them above the torch import.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="2")

# All relative paths below (benchmark file, datasets, pickles) are resolved from the
# repository root, two levels up from this notebook.
os.chdir('../../')
api = API('nasbench201/NAS-Bench-201-v1_1-096897.pth')

# file_name = 'baseline'

In [None]:
epochs = 250

search_space = get_search_spaces("cell", 'nas-bench-201')


def _cell_net_config(name):
    """Return the cell-network config shared by the search model and the supernet.

    Both configurations differ only in ``name``; every other field is identical.
    """
    return dict2config(
        {
            "name": name,
            "C": 16,
            "N": 5,
            "max_nodes": 4,
            "num_classes": 10,
            "space": search_space,
            "affine": False,
            "track_running_stats": False,
        },
        None,
    )


model_config = _cell_net_config("RANDOM")
supernet_config = _cell_net_config("supernet")

def distill(result):
    """Extract top-1 accuracies from a NAS-Bench-201 ``query_by_arch`` text report.

    The report is split into lines; lines 5, 7 and 9 are read as the CIFAR-10,
    CIFAR-100 and ImageNet16-120 metric lines (assumes the layout produced by
    ``api.query_by_arch(arch, '200')`` — TODO confirm for other API versions).
    Every ``top1 = <value>%`` occurrence on a line is parsed, left to right.

    Args:
        result: The multi-line report string.

    Returns:
        Tuple ``(cifar10_train, cifar10_test, cifar100_train, cifar100_valid,
        cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test)``
        of floats, in percent.

    Raises:
        IndexError: if the report has fewer than 10 lines.
        ValueError: if a metric line does not contain the expected number of
            ``top1`` values.
    """
    import re

    # Match the number itself rather than slicing a fixed-width window: a slice such
    # as ``[-7:-2]`` assumes at most two integer digits and turns ``100.00%`` into 0.0.
    top1_pattern = re.compile(r'top1\s*=\s*([0-9]+(?:\.[0-9]+)?)%')

    def _top1_values(line):
        """Return every top-1 accuracy found on ``line`` as floats, in order."""
        return [float(value) for value in top1_pattern.findall(line)]

    lines = result.split('\n')
    # Tuple unpacking fails loudly (ValueError) if a line has an unexpected layout.
    cifar10_train, cifar10_test = _top1_values(lines[5])
    cifar100_train, cifar100_valid, cifar100_test = _top1_values(lines[7])
    imagenet16_train, imagenet16_valid, imagenet16_test = _top1_values(lines[9])

    return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
        cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test

# Build the searched network ("RANDOM" config) and the supernet from the configs above.
# NOTE(review): none of the objects created in this section (models, optimizer,
# scheduler, loaders) are used by the analysis cells visible in this notebook;
# presumably left over from the training script — confirm before removing.
search_model = get_cell_based_tiny_net(model_config)
supernet = get_cell_based_tiny_net(supernet_config)
supernet = supernet.cuda()

# Move the search model to the GPU *before* creating its optimizer: the PyTorch
# `nn.Module.cuda` docs advise calling it before constructing the optimizer, because
# the conversion may replace the parameter objects the optimizer would hold.
network = search_model.cuda()

# SGD with Nesterov momentum, cosine-annealed from lr=0.025 down to 0.001 over `epochs`.
optimizer = torch.optim.SGD(
    params=network.parameters(),
    lr=0.025,
    momentum=0.9,
    weight_decay=0.0005,
    nesterov=True,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=epochs,
    eta_min=0.001,
)

criterion = torch.nn.CrossEntropyLoss().cuda()

# CIFAR-10 train/valid splits. The positional args are presumably
# (dataset, root, cutout_length) — TODO confirm against xautodl.datasets.get_datasets.
train_data, valid_data, _, _ = get_datasets(
    'cifar10', './dataset', -1
)

# NAS search loader built from both splits; (64, 256) and 4 are presumably the
# (train, valid) batch sizes and the worker count — TODO confirm.
search_loader, _, _ = get_nas_search_loaders(
    train_data,
    valid_data,
    'cifar10',
    "configs/nas-benchmark/",
    (64, 256),
    4,
)

# Unshuffled loader over the validation split.
valid_loader = torch.utils.data.DataLoader(
    valid_data,
    batch_size=256,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

from xautodl.models.cell_searchs.genotypes import Structure

# Sample one random cell: node i (1..3) gets one randomly chosen op on each incoming
# edge j < i. `op_names` is also the op vocabulary that `genotype` decodes into.
# NOTE(review): `random` is not seeded here, so `arch` differs between runs.
op_names = deepcopy(search_space)
genotypes = [
    tuple((random.choice(op_names), j) for j in range(i))
    for i in range(1, 4)
]
arch = Structure(genotypes)

edge2index = network.edge2index
max_nodes = 4


def genotype(enc):
    """Decode a per-edge score encoding into a ``Structure``.

    For each node ``dst`` in ``1 .. max_nodes - 1`` and each predecessor ``src < dst``,
    the row ``enc[edge2index['dst<-src']]`` is read and the op in ``op_names`` at that
    row's argmax position is chosen for the edge.

    Args:
        enc: Indexable of score rows supporting ``.argmax()`` — presumably a
            ``(num_edges, len(op_names))`` tensor such as the one-hot entries of
            ``struc`` (TODO confirm).

    Returns:
        ``Structure`` built from one tuple of ``(op_name, src)`` pairs per node.
    """
    nodes = []
    with torch.no_grad():
        for dst in range(1, max_nodes):
            incoming = tuple(
                (op_names[enc[edge2index[f'{dst}<-{src}']].argmax().item()], src)
                for src in range(dst)
            )
            nodes.append(incoming)
    return Structure(nodes)


# Enumerate every 6x5 encoding with exactly one 1 per row: 5**6 = 15625 tensors
# (presumably one per cell in the search space; rows are decoded as edges by `genotype`).
# Each list index is read as a 6-digit base-5 number whose digit k is the hot column of
# row k. Row 0 is the most significant digit, matching nested loops over rows 0..5
# with row 5 innermost.
one_hot_rows = torch.eye(5)
struc = []
for code in range(5 ** 6):
    digits = []
    rest = code
    for _ in range(6):
        rest, digit = divmod(rest, 5)
        digits.append(digit)
    # `digits` was filled least-significant (row 5) first; reverse it so position k
    # picks the one-hot row for row k. Advanced indexing yields a fresh tensor.
    struc.append(one_hot_rows[digits[::-1]])
    
def get_num_params(result):
    """Return the parameter count printed on line 2 of a ``query_by_arch`` report.

    Line 2 is split on single spaces. The fourth-from-last token (expected to look
    like ``Params=0.802``) has the characters ``P a r m s =`` stripped from both ends
    and is converted to ``float``.
    """
    fields = result.split('\n')[2].split(' ')
    params_field = fields[-4]
    return float(params_field.strip('Params='))

In [None]:
import pickle

In [None]:


def _load_pickle(path):
    """Load and return the single pickled object stored at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only use it on trusted,
    locally produced files.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Reference accuracies and parameter counts, presumably one entry per encoding in
# `struc` order (TODO confirm how these pickles were generated).
cifar10_accs = _load_pickle("./exps/NAS-Bench-201-algos/kendal_valid_accs/cifar10_accs.pkl")
cifar100_accs = _load_pickle("./exps/NAS-Bench-201-algos/kendal_valid_accs/cifar100_accs.pkl")
imagenet_accs = _load_pickle("./exps/NAS-Bench-201-algos/kendal_valid_accs/imagenet_accs.pkl")
num_params = _load_pickle("./exps/NAS-Bench-201-algos/kendal_valid_accs/num_params.pkl")



In [None]:


# Parameter-count thresholds for the subsets compared in `analysis`:
# low = entries with num_params < low_param_val, high = entries with num_params > high_param_val.
# NOTE(review): the two ranges overlap on (0.343, 0.344). If counts are 3-decimal
# values this acts as a clean split (<= 0.343 vs >= 0.344) — confirm that is intended.
low_param_val, high_param_val = 0.344, 0.343

print(low_param_val, high_param_val)

def _params_subset(valid_accs, keep):
    """Return ``(valid, cifar10_true)`` accuracy lists for indices whose ``num_params`` value satisfies ``keep``."""
    picked = [i for i, n_params in enumerate(num_params) if keep(n_params)]
    return [valid_accs[i] for i in picked], [cifar10_accs[i] for i in picked]


def analysis(file_name):
    """Print rank-correlation diagnostics for one run's per-architecture scores.

    Loads ``./exps/NAS-Bench-201-algos/kendal_valid_accs/{file_name}.pkl`` (a sequence
    of scores, presumably aligned index-for-index with ``struc`` and the module-level
    ``cifar10_accs`` / ``cifar100_accs`` / ``imagenet_accs`` / ``num_params`` —
    TODO confirm), then prints:

    * Kendall's tau between the scores and each dataset's reference accuracies;
    * Kendall's tau between the scores and ``num_params``;
    * Kendall's tau vs ``cifar10_accs`` restricted to the low-/high-parameter subsets
      defined by the module-level ``low_param_val`` / ``high_param_val``;
    * the ``api.query_by_arch`` report of the highest-scoring encoding.

    Args:
        file_name: Base name (without ``.pkl``) of the score pickle to analyse.
    """
    print(f'===============  {file_name}  ===============')
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted local files.
    with open(f"./exps/NAS-Bench-201-algos/kendal_valid_accs/{file_name}.pkl", "rb") as f:
        valid_accs = pickle.load(f)

    # Kendall's tau is symmetric in its two arguments, so one call per dataset suffices.
    for dataset, true_accs in (
        ('cifar10', cifar10_accs),
        ('cifar100', cifar100_accs),
        ('imagenet', imagenet_accs),
    ):
        tau, _ = stats.kendalltau(valid_accs, true_accs)
        print(f'{dataset}_valid_true_tau: {tau}')

    print(f'param vs valid_accs: {stats.kendalltau(valid_accs, num_params) }')

    # Rank agreement with CIFAR-10 within the small- and large-model subsets.
    low_valid, low_real = _params_subset(valid_accs, lambda p: p < float(low_param_val))
    print(f'low_param_kendal: {stats.kendalltau(low_valid, low_real)}')

    high_valid, high_real = _params_subset(valid_accs, lambda p: p > float(high_param_val))
    print(f'high_param_kendal: {stats.kendalltau(high_valid, high_real)}')

    # Report the benchmark record (hp='200') of the top-scoring encoding.
    best = int(np.argmax(np.array(valid_accs)))
    print(api.query_by_arch(genotype(struc[best]), '200'))

    
# Compare the baseline run with the adaptive-LR run, printing a banner after each report.
for run_name in ('baseline_ep_250', 'Adaptive_LR_max_coeff_3_log_formulation'):
    analysis(run_name)
    print('*********************************************************\n*********************************************************')
