# Imports

In [1]:
import os, sys, time, glob, random, argparse
import wandb
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
import time
import tqdm
import scipy.stats as stats
import matplotlib.pyplot as plt
import pickle
import pandas as pd

# XAutoDL 
from xautodl.config_utils import load_config, dict2config, configure2str
from xautodl.datasets import get_datasets, get_nas_search_loaders
from xautodl.procedures import (
    prepare_seed,
    prepare_logger,
    save_checkpoint,
    copy_checkpoint,
    get_optim_scheduler,
)
from xautodl.utils import get_model_infos, obtain_accuracy
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.models import get_search_spaces

# API
from nats_bench import create

# custom modules
from custom.tss_model import TinyNetwork
from xautodl.models.cell_searchs.genotypes import Structure
from ZeroShotProxy import *
from tss_utils import compute_vkdnw, plot_stats, get_stats, get_metrics, analyze_results, generate_accs, get_results_from_api, get_scores

# All analyses: zero-cost proxy evaluation on NATS-Bench TSS

In [2]:
# Sync logs from the cluster:
# scp 'tyblondr@login.rci.cvut.cz:~/PycharmProjects/VKDNW/NB201/*_log.csv' data_0811/

# ---- experiment configuration ----
dataset = 'cifar100'   # key into run_dict; also used in cache / output file names
target = 'val_accs'    # ground-truth column the proxies and surrogate models are evaluated against
compute_graf = True    # keep only architectures with features ('net' not null) and fit surrogate models

if compute_graf:
    compute_graf_str = 'filtered'
else:
    compute_graf_str = 'unfiltered'

# Zero-cost proxies scored by get_scores
zero_cost_score_list = [
    'vkdnw', 'vkdnw_dim', 'vkdnw_chisquare', 'az_nas', 'jacov', 'gradsign', 'zico', 'zen', 'grad_norm',
    'naswot', 'synflow', 'snip', 'grasp', 'te_nas', 'flops', 'trainability', 'expressivity', 'progressivity',
]

plt.rcParams['text.usetex'] = False  # render figure text with matplotlib's mathtext instead of a LaTeX install

In [3]:
# W&B run paths (entity/project/run_id) whose logged proxy values are analysed, keyed by
# dataset; each run path maps to the seed it belongs to (consumed as `seed` downstream).
# Only the "final" series is active. Earlier experiment series, listed in seed order 1..5
# ('...' = placeholder in the original list):
#   cifar10
#     same images:      n6n44keg, 6c1pv095, 45us11be, rwp2qvkw, 6bsssh4z
#     different images: cixw6r9b, 1nikalf3, r5l3qg12, fq581h0d, 0yfzaxra  (test run, seed -1: 7a8jm975)
#     entropy:          87vy5rby, koxirqt3, bh09isa0, ei31s875, 8sd2hdws
#     entropy all:      6lcy0jqp, orl2me1l, y5n4lnsn, ..., 7gkdl3gl
#   cifar100
#     same images:      qmq5vp3k, 31lrq6p7, 3qazc6po, 424twoyv, 783h4opf
#     different images: n2m8i53l, 48xbnfdg, f0czkx5u, ldv3b1bh, qwae4nqx
#     entropy:          lfnin2ui, lqka0nwl, mqrfyv7t, 1cs8orlm, 4tep5lv0
#     entropy all:      ..., ..., ..., ..., ...
#   ImageNet16-120
#     same images:      ftg0tdsa, vqf1ey6x, v0a0m67q, uiv37u18, c1338vfg
#     different images: 55f1omxn, amdcxrz7, sl0rjhwh, z2ph6iav, ol9rwkeo
#     entropy:          f9n7j83e, kptcxbi2, 2ztcsos6, u214u8p1, bycu6ed3  (test run, seed -1: hwtw58ot)
#     unlabelled:       hrrxcexh, jwhy017w, 82rycftz, bw6b93p9, ca423k7t  (TODO confirm series name)
run_dict = {
    ds_name: {
        f'nazderaze/VKDNW/{run_suffix}': run_seed
        for run_seed, run_suffix in enumerate(run_suffixes, start=1)
    }
    for ds_name, run_suffixes in {
        'cifar10': ['k7llaf2s', 'knu2gv65', 'fnfbpspk', 'w4b39d0d', 'vbr5n8iy'],
        'cifar100': ['v2b816ul', 'velwxtxd', 'f3ljf5jf', 'e3w3dkv7', '9ibvj04q'],
        'ImageNet16-120': ['ss2kwvpp', 'es9t2696', 'm1dusbs9', 'kgmcpu2y', 'cpbmsvuw'],
    }.items()
}

In [4]:
# NATS-Bench TSS API. The benchmark directory can be overridden with the NATS_TSS_PATH
# environment variable; the default is the original (absolute, cluster-specific) path.
NATS_TSS_PATH = os.environ.get('NATS_TSS_PATH', '/mnt/personal/tyblondr/NATS-tss-v1_0-3ffb9-simple/')
api_nats = create(NATS_TSS_PATH, 'tss', fast_mode=True, verbose=False)

# Per-dataset architecture table (as returned by generate_accs), cached as a pickle in the
# working directory so later runs skip the generation step.
# NOTE: pd.read_pickle can execute arbitrary code — only load cache files you created yourself.
archs_cache = f"./tss_features_{dataset}.pickle"
if os.path.exists(archs_cache):
    archs = pd.read_pickle(archs_cache)
else:
    archs = generate_accs(api_nats, dataset=dataset)
    print(f'No. of generated archs: {archs.shape[0]}')
    archs.to_pickle(archs_cache)

In [5]:
api_wandb = wandb.Api()

# For every W&B run configured for `dataset` (one run per seed), join the logged proxy
# values with the architecture table `archs` and compute scores/ranks via get_scores.
score_frames = []
for run_id, seed in run_dict[dataset].items():
    run = pd.DataFrame(api_wandb.run(run_id).scan_history()).rename(columns={'arch': 'net_str'})

    # `archs` stays the left table so the original row/column order of the join is kept.
    run = pd.merge(archs, run, on='net_str', how='inner')
    if compute_graf:
        run = run.loc[run['net'].notnull(), :].copy()  # keep only nets with features

    print(f'No. of archs for seed {seed} after filtering: {run.shape[0]}.')

    # Every column except 'net_str' and 'net' is cast to float.
    for col in run.columns:
        if col not in ['net_str', 'net']:
            run[col] = run[col].astype(float)

    if 'jacov' in run.columns:
        # Missing jacov values are filled with the column minimum (column is already float).
        run['jacov'] = run['jacov'].fillna(run['jacov'].min())

    df_scores = get_scores(run.copy(), compute_graf=compute_graf, zero_cost_score_list=zero_cost_score_list)
    df_scores['dataset'] = dataset
    df_scores['seed'] = seed
    score_frames.append(df_scores)

if not score_frames:
    raise ValueError(f'No W&B runs configured for dataset {dataset!r} in run_dict.')
# Concatenate once instead of growing `results` inside the loop (avoids quadratic copying).
results = pd.concat(score_frames, ignore_index=True)
print(f'Total number of records: {results.shape[0]}')

No. of archs for seed 1 after filtering: 9445.
Running vkdnw
Running vkdnw_dim
Running vkdnw_chisquare
Running az_nas
Running jacov
Running gradsign
Running zico
Running zen
Running grad_norm
Running naswot
Running synflow
Running snip
Running grasp
Running te_nas
Running flops
Running trainability
Running expressivity
Running progressivity
No. of archs for seed 2 after filtering: 9445.
Running vkdnw
Running vkdnw_dim
Running vkdnw_chisquare
Running az_nas
Running jacov
Running gradsign
Running zico
Running zen
Running grad_norm
Running naswot
Running synflow
Running snip
Running grasp
Running te_nas
Running flops
Running trainability
Running expressivity
Running progressivity
No. of archs for seed 3 after filtering: 9445.
Running vkdnw
Running vkdnw_dim
Running vkdnw_chisquare
Running az_nas
Running jacov
Running gradsign
Running zico
Running zen
Running grad_norm
Running naswot
Running synflow
Running snip
Running grasp
Running te_nas
Running flops
Running trainability
Running expres

In [6]:
# Rank-quality metrics of every zero-cost proxy (`*_rank` columns), computed per seed
# and then summarised as mean/std across seeds.
log_frames = []
seeds = results['seed'].unique()
plot_seed = min(seeds)  # plots are produced for the first seed only
for seed in seeds:
    results_temp = results.loc[results['seed'] == seed, :].copy()
    for zero_cost_rank in [p for p in results_temp.columns if '_rank' in p]:
        # Make the rank column finite: -inf and NaN become the smallest remaining value,
        # +inf becomes the largest remaining non-inf value.
        rank = results_temp[zero_cost_rank]
        low = rank[rank != -np.inf].min()  # NaN-skipping min
        rank = rank.replace(-np.inf, low).fillna(low)
        rank = rank.replace(np.inf, rank[rank != np.inf].max())
        results_temp[zero_cost_rank] = rank

        log_frames.append(pd.DataFrame(get_metrics(results_temp, pred_name=zero_cost_rank, show_plot=False, seed=seed)))

        if seed == plot_seed:
            plot_stats(get_stats(results_temp, 'vkdnw_dim', target, zero_cost_rank), 'vkdnw_dim', target, zero_cost_rank, f'{dataset}_{str(compute_graf)}_{zero_cost_rank}')

if not log_frames:
    raise ValueError("No '*_rank' columns found in `results`; nothing to evaluate.")
# Single concat instead of growing `log` inside the loop.
log = pd.concat(log_frames, ignore_index=True)
log = log.groupby('pred_name', as_index=False).agg(['mean', 'std']).reset_index()
log['dataset'] = dataset
log['no_seeds'] = len(seeds)
log['archs_filtered'] = compute_graf_str
log

  return x.astype(dtype, copy=copy, casting=casting)
  return x.astype(dtype, copy=copy, casting=casting)
  return x.astype(dtype, copy=copy, casting=casting)
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
  return x.astype(dtype, copy=copy, casting=casting)
  return x.astype(dtype, copy=copy, casting=casting)
  return x.astype(dtype, copy=copy, casting=casting)
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
  return x.astype(dtype, copy=copy, casting=casting)
  return x.astype(dtype, copy=copy, casting=casting)
  return x.astype(dtype, copy=copy, casting=casting)
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
  return x.astype(dtype, copy=copy, casting=casting)
  return x.astype(dtype, copy=copy, casting=casting)
  return x.astype(dtype, copy=copy, casting=casting)
The PostScript backend does not suppo

Unnamed: 0_level_0,pred_name,kendall,kendall,spearman,spearman,pearson,pearson,ndcg50,ndcg50,ndcg100,...,ndcg1000,ndcg5000,ndcg5000,acc_top,acc_top,acc_top_true,acc_top_true,dataset,no_seeds,archs_filtered
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,mean,std,mean,std,mean,std,mean,...,std,mean,std,mean,std,mean,std,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,az_nas_rank,0.695602,0.001746,0.879845,0.001545,0.858297,0.001490419,0.2371874,0.004582295,0.268264,...,0.000205,0.799945,0.002586,71.035333,0.415522,73.513333,0.0,cifar100,5,filtered
1,expressivity_rank,0.654863,0.001015,0.845054,0.000665,0.804432,0.002075859,0.09657334,0.01020208,0.1155037,...,0.000158,0.745047,0.002101,69.858667,0.221731,73.513333,0.0,cifar100,5,filtered
2,flops_rank,0.585992,0.0,0.763405,0.0,0.552692,0.0,0.2558686,0.0,0.3184627,...,0.0,0.795824,0.0,71.106667,0.0,73.513333,0.0,cifar100,5,filtered
3,grad_norm_rank,0.340531,0.000945,0.451059,0.001457,0.092914,0.00687614,0.004226659,5.29801e-05,0.01820478,...,0.000346,0.460893,0.000558,61.643333,1.106933,73.513333,0.0,cifar100,5,filtered
4,gradsign_rank,-0.528416,0.001065,-0.72081,0.001284,-0.694566,0.0008028933,2.759439e-08,2.863677e-08,6.847989e-07,...,0.001872,0.089389,0.000709,32.150667,6.960669,73.513333,0.0,cifar100,5,filtered
5,grasp_rank,0.349295,0.004538,0.49832,0.00673,-0.00488,0.01173927,0.01315595,0.003172422,0.03012571,...,0.000683,0.474424,0.006881,56.568667,6.210313,73.513333,0.0,cifar100,5,filtered
6,jacov_rank,0.638973,0.002053,0.819601,0.001824,0.49753,0.0307489,0.1247021,0.04317487,0.1525877,...,0.001052,0.759757,0.008842,65.196,4.213467,73.513333,0.0,cifar100,5,filtered
7,naswot_rank,0.606959,0.000323,0.798992,0.000274,0.833875,3.399385e-05,0.1695155,0.002965319,0.1928396,...,0.000167,0.766585,0.000577,68.665333,1.664259,73.513333,0.0,cifar100,5,filtered
8,progressivity_rank,0.513938,0.004868,0.698894,0.005675,0.745085,0.006396033,0.177445,0.01199341,0.202988,...,0.000668,0.680026,0.005277,69.424,0.089889,73.513333,0.0,cifar100,5,filtered
9,snip_rank,0.439946,0.000344,0.596509,0.000223,0.174439,0.005036299,0.007845338,0.001824556,0.03041179,...,0.00031,0.466606,0.001977,54.526667,6.465118,73.513333,0.0,cifar100,5,filtered


In [7]:
"""
for zero_cost_score in ['vkdnw']:
    for i in range(5):
        results_temp = results.loc[results['seed'] == results['seed'].unique().min(), :].copy().sample(n=3000)
        analyze_results(api_nats, results_temp, zero_cost_score, target)
        
df_top = None
for seed in range(1, 6):
    results_temp = results.loc[results['seed'] == seed, :].sample(n=3000, random_state=seed).copy()
    for pred in [p for p in results_temp.columns if '_rank' in p]:
        top_acc = results_temp.loc[results_temp[pred].idxmax(), target]
        if df_top is None:
            df_top = pd.DataFrame({'acc': top_acc, 'acc_max': results_temp[target].max(), 'pred': pred, 'seed': seed}, index=[0])
        else:
            df_top = pd.concat([df_top, pd.DataFrame({'acc': top_acc, 'acc_max': results_temp[target].max(), 'pred': pred, 'seed': seed}, index=[0])], ignore_index=True)
df_top.groupby('pred', as_index=False)[['acc', 'acc_max']].agg(['max', 'mean', 'std']).reset_index()
"""

"\nfor zero_cost_score in ['vkdnw']:\n    for i in range(5):\n        results_temp = results.loc[results['seed'] == results['seed'].unique().min(), :].copy().sample(n=3000)\n        analyze_results(api_nats, results_temp, zero_cost_score, target)\n        \ndf_top = None\nfor seed in range(1, 6):\n    results_temp = results.loc[results['seed'] == seed, :].sample(n=3000, random_state=seed).copy()\n    for pred in [p for p in results_temp.columns if '_rank' in p]:\n        top_acc = results_temp.loc[results_temp[pred].idxmax(), target]\n        if df_top is None:\n            df_top = pd.DataFrame({'acc': top_acc, 'acc_max': results_temp[target].max(), 'pred': pred, 'seed': seed}, index=[0])\n        else:\n            df_top = pd.concat([df_top, pd.DataFrame({'acc': top_acc, 'acc_max': results_temp[target].max(), 'pred': pred, 'seed': seed}, index=[0])], ignore_index=True)\ndf_top.groupby('pred', as_index=False)[['acc', 'acc_max']].agg(['max', 'mean', 'std']).reset_index()\n"

In [8]:
# Surrogate models: predict `target` from proxy / GRAF feature groups with a random forest
# trained on `train_size` architectures per seed and evaluated on the remaining ones.
log_train = None
if not compute_graf:
    print('No graf prediction.')
else:
    print('Graf prediction.')

    from sklearn.ensemble import RandomForestRegressor
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.model_selection import train_test_split

    # Feature groups, selected by column name.
    lambda_cols = [p for p in results.columns if '_lambda_' in p]
    zs_cols = ['expressivity', 'progressivity', 'trainability', 'jacov', 'gradsign', 'zico', 'zen',
               'grad_norm', 'naswot', 'synflow', 'snip', 'grasp', 'ntk', 'linear_region']
    graf_cols = [p for p in results.columns if 'op_' in p] + [p for p in results.columns if 'node_' in p]
    pred_lists = {
        'model_vkdnw': lambda_cols + ['vkdnw_entropy', 'vkdnw_dim', 'flops'],  # alt: ['vkdnw_chisquare', 'vkdnw_dim', 'flops']
        'model_vkdnw+zs': lambda_cols + ['vkdnw_dim', 'flops'] + zs_cols,
        'model_vkdnw+zs+graf': lambda_cols + ['vkdnw_dim', 'flops'] + zs_cols + graf_cols,
        'model_graf': graf_cols,
    }

    seeds = results['seed'].unique()
    plot_seed = min(seeds)  # plots are produced for the first seed only
    log_train_frames = []
    for train_size in [1024]:
        for seed in seeds:
            results_temp = results.loc[results['seed'] == seed, :].copy()
            # Clamp +/-inf in numeric columns to that column's remaining extremes
            # (non-numeric columns such as 'net_str' cannot hold inf and are skipped).
            num_cols = results_temp.select_dtypes(include='number').columns
            results_temp[num_cols] = (
                results_temp[num_cols]
                .apply(lambda x: x.replace(-np.inf, x[x != -np.inf].min()))
                .apply(lambda x: x.replace(np.inf, x[x != np.inf].max()))
            )
            for model_name, pred_list in pred_lists.items():
                # An integer train_size yields exactly `train_size` training rows; the former float
                # test_size = 1 - train_size / n could round ceil(test_size * n) up by one row.
                train_df, test_df = train_test_split(results_temp, train_size=train_size, random_state=seed)
                test_df = test_df.copy()  # a prediction column is added below
                model = Pipeline([
                    ('scaler', StandardScaler()),  # Step 1: standardize features (tree splits are scale-invariant)
                    ('regressor', RandomForestRegressor(n_estimators=100, random_state=seed)),  # Step 2: seeded for reproducibility
                ])
                model.fit(train_df[pred_list], train_df[target])
                pred_col = 'pred_' + model_name
                test_df[pred_col] = model.predict(test_df[pred_list])
                log_train_temp = pd.DataFrame(get_metrics(test_df, pred_col, show_plot=False, seed=seed))
                log_train_temp['train_size'] = train_size
                log_train_frames.append(log_train_temp)

                if seed == plot_seed:
                    plot_stats(get_stats(test_df, 'vkdnw_dim', target, pred_col), 'vkdnw_dim', target, model_name, f'{dataset}_{str(compute_graf)}_{model_name}_{train_size}')

    log_train = pd.concat(log_train_frames, ignore_index=True)
    log_train = log_train.groupby(['pred_name', 'train_size'], as_index=False).agg(['mean', 'std']).reset_index()
    log_train['dataset'] = dataset
    log_train['no_seeds'] = len(seeds)
    log_train['archs_filtered'] = compute_graf_str
    # A bare expression inside `else:` is not rendered by Jupyter, so display explicitly.
    display(log_train)

Graf prediction.


  return x.astype(dtype, copy=copy, casting=casting)
  return x.astype(dtype, copy=copy, casting=casting)
  return x.astype(dtype, copy=copy, casting=casting)
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
  return x.astype(dtype, copy=copy, casting=casting)
  return x.astype(dtype, copy=copy, casting=casting)
  return x.astype(dtype, copy=copy, casting=casting)
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
  return x.astype(dtype, copy=copy, casting=casting)
  return x.astype(dtype, copy=copy, casting=casting)
  return x.astype(dtype, copy=copy, casting=casting)
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
  return x.astype(dtype, copy=copy, casting=casting)
  return x.astype(dtype, copy=copy, casting=casting)
  return x.astype(dtype, copy=copy, casting=casting)
The PostScript backend does not suppo

In [None]:
# Display the aggregated surrogate-model metrics (mean/std per model and train size).
# `log_train` stays None when compute_graf is False.
log_train

In [30]:
# Combine proxy metrics (`log`) with surrogate-model metrics (`log_train`, only when
# compute_graf) and write them to '<dataset>_<filtered|unfiltered>_log.csv'.
if compute_graf:
    log_save = pd.concat([log, log_train], ignore_index=True)
else:
    # Copy so that flattening the columns below does not rename `log` itself
    # (re-running this cell would otherwise mangle already-flattened names).
    log_save = log.copy()

# Flatten the (metric, statistic) MultiIndex produced by .agg(['mean', 'std']) into
# 'metric_stat' names. Single-level columns keep a trailing '_' (e.g. 'dataset_'),
# unchanged so the CSV schema matches previously written logs.
if isinstance(log_save.columns, pd.MultiIndex):
    log_save.columns = ['_'.join(col) for col in log_save.columns]
log_save.to_csv(f'{dataset}_{compute_graf_str}_log.csv', index=False)
log_save

Unnamed: 0,pred_name_,kendall_mean,kendall_std,spearman_mean,spearman_std,pearson_mean,pearson_std,ndcg50_mean,ndcg50_std,ndcg100_mean,...,ndcg5000_mean,ndcg5000_std,acc_top_mean,acc_top_std,acc_top_true_mean,acc_top_true_std,dataset_,no_seeds_,archs_filtered_,train_size_
0,az_nas_rank,0.672816,0.001249,0.858689,0.001364,0.873234,0.001140594,0.3290842,0.02327173,0.3797528,...,0.79899,0.00559,44.792222,1.361551,47.311111,0.0,ImageNet16-120,5,filtered,
1,expressivity_rank,0.587117,0.000693,0.77801,0.000602,0.769308,0.0019705,0.01919631,0.00249915,0.02629433,...,0.706877,0.002471,40.338889,1.702005,47.311111,0.0,ImageNet16-120,5,filtered,
2,flops_rank,0.545322,0.0,0.717813,0.0,0.593022,0.0,0.1501999,0.0,0.1787965,...,0.748332,0.0,41.444444,0.0,47.311111,0.0,ImageNet16-120,5,filtered,
3,grad_norm_rank,0.310441,0.000504,0.417613,0.000802,0.036808,0.008416389,0.0003749862,0.0002213798,0.008434647,...,0.444068,0.000431,18.45,10.127342,47.311111,0.0,ImageNet16-120,5,filtered,
4,gradsign_rank,-0.45441,0.00243,-0.626827,0.003195,-0.679063,0.001058595,5.892313e-09,6.299651e-09,2.276524e-08,...,0.059683,0.001371,12.208889,1.132227,47.311111,0.0,ImageNet16-120,5,filtered,
5,grasp_rank,0.35911,0.003922,0.50166,0.006311,-0.000345,0.04345215,0.005221175,0.003141932,0.02473023,...,0.472276,0.004348,12.051111,10.259637,47.311111,0.0,ImageNet16-120,5,filtered,
6,jacov_rank,0.602018,0.00178,0.779219,0.001888,0.399077,0.01950957,0.09877489,0.02561853,0.1248206,...,0.732392,0.002704,38.788889,1.119282,47.311111,0.0,ImageNet16-120,5,filtered,
7,naswot_rank,0.605221,0.000404,0.794006,0.000386,0.815677,0.000146786,0.3198674,0.003261148,0.2831819,...,0.784324,0.000449,37.108889,3.647275,47.311111,0.0,ImageNet16-120,5,filtered,
8,progressivity_rank,0.458632,0.00468,0.629668,0.005409,0.649028,0.01048948,0.1258578,0.01959717,0.1436977,...,0.605715,0.005453,42.868889,2.220074,47.311111,0.0,ImageNet16-120,5,filtered,
9,snip_rank,0.38926,0.000437,0.52115,0.000602,0.101274,0.009506354,0.001220818,0.0004913652,0.01511454,...,0.449849,0.001805,0.833333,0.0,47.311111,0.0,ImageNet16-120,5,filtered,
