In [25]:
import argparse
import os

import pandas as pd

In [26]:
# Command-line options for the oracle baseline. The defaults cover every
# dataset and every epoch budget listed below.
parser = argparse.ArgumentParser()

# --- locations -------------------------------------------------------------
parser.add_argument(
    '--data_path',
    type=str,
    default='/home/gean/nns_performance_prediction/meta_datasets/',
    help='location of the dataset',
)
parser.add_argument(
    '--results_path',
    type=str,
    default='/home/gean/nns_performance_prediction/results/oracle/',
    help='location of the results directory',
)

# --- experiment size -------------------------------------------------------
parser.add_argument(
    '--k',
    type=int,
    default=5000,
    help='number of k best archs to extract accs',
)

# --- multi-valued options --------------------------------------------------
# nargs reminder: '+' takes one or more values, '*' zero or more, '?' zero or one.
parser.add_argument(
    '--dataset',
    type=str,
    nargs='+',
    default=['cifar10valid', 'cifar100', 'imagenet16_120'],
    help='one of the datasets from nasbench201, being cifar10valid, cifar100, or imagenet16_120',
)
parser.add_argument(
    '--data_subset',
    type=int,
    nargs='+',
    default=[4, 12, 36, 108, 200],
    help='one of the subsets from nasbench201 with 1, 4, 12, 36, 108, or 200 epochs',
)

# --- logging ---------------------------------------------------------------
parser.add_argument(
    '--verbose',
    type=int,
    default=1,
    help='control the logging prints. 0 for deactivate and 1 for activate',
)

# parse_known_args (instead of parse_args) does not abort on unrecognised
# command-line entries; they are returned separately in `unknown`.
args, unknown = parser.parse_known_args()

In [27]:
def save_results(performance_dict, dataset, subset):
    """Write the oracle results of one dataset / epoch-budget pair to a CSV file.

    The file is named ``nasbench201_<dataset>_<subset>epochs_k<args.k>_oracle.csv``
    and is written inside ``args.results_path``. The directory is created when it
    does not exist yet, so a fresh checkout does not fail at the very last step.

    Args:
        performance_dict: mapping of column name -> list of values, passed to
            ``pd.DataFrame.from_dict``.
        dataset: dataset identifier, used only to build the file name.
        subset: epoch budget, used only to build the file name.

    Returns:
        The DataFrame that was written to disk.
    """
    df_results = pd.DataFrame.from_dict(performance_dict)

    file_name = 'nasbench201_{}_{}epochs_k{}_oracle.csv'.format(dataset, subset, args.k)

    # An empty results_path means "current directory"; makedirs('') would raise.
    if args.results_path:
        os.makedirs(args.results_path, exist_ok=True)

    # os.path.join yields a valid path whether or not results_path ends with a
    # separator (plain concatenation glued the directory name onto the file name).
    df_results.to_csv(os.path.join(args.results_path, file_name),
                      index=False, float_format='%.6f')

    return df_results

In [30]:
def main():
    """Compute and save the oracle baseline for every dataset / epoch-budget pair.

    For each combination of ``args.dataset`` and ``args.data_subset`` the file
    ``nasbench201_<dataset>_<subset>epochs.csv`` is read from ``args.data_path``
    and three reference values are extracted:

    * best_test_acc  -> best test acc of the whole subset
    * true_valid_acc -> best valid acc of the whole subset
    * true_test_acc  -> test acc of the arch with the best true_valid_acc of the
      whole subset

    The values are repeated for K = 1..args.k (the oracle result does not depend
    on K) and handed to ``save_results``, one file per dataset-subset pair.
    Input files that are missing or contain no rows are reported and skipped so
    the remaining combinations still run.
    """
    for data in args.dataset:
        for subset in args.data_subset:
            # os.path.join is correct with or without a trailing separator on data_path.
            csv_path = os.path.join(
                args.data_path, 'nasbench201_{}_{}epochs.csv'.format(data, subset))

            try:
                df_whole = pd.read_csv(csv_path, index_col=0)
            except FileNotFoundError as e:
                print(e)
                continue

            if df_whole.empty:
                # .iloc[0] below would raise IndexError and abort the whole sweep.
                print("No rows in {}; skipping".format(csv_path))
                continue

            if args.verbose:
                print("\n\n######### {}, Subset{}".format(data, subset))

            # max() skips missing values; the result is NaN only when the column
            # holds no valid test accuracy at all.
            best_test_acc = df_whole['acc_test'].max()

            # Row of the architecture with the highest validation accuracy.
            best_acc_valid_row = df_whole.sort_values(by='acc_valid', ascending=False).iloc[0]
            true_valid_acc = best_acc_valid_row['acc_valid']
            true_test_acc = best_acc_valid_row['acc_test']

            # results for each k does not change in the oracle
            df_dict = {'Dataset': [data] * args.k,
                       'Epoch': [subset] * args.k,
                       'Model': ['Oracle'] * args.k,
                       'K': list(range(1, args.k + 1)),
                       'True_Val_Acc': [true_valid_acc] * args.k,
                       'True_Test_Acc': [true_test_acc] * args.k,
                       'Best_Test_Acc': [best_test_acc] * args.k}

            # one per dataset-subset
            df_results = save_results(df_dict, data, subset)
            if args.verbose:
                print(df_results.head())

In [31]:
if __name__ == '__main__':
    # Echo the effective configuration, one option per line, before running.
    for label, value in (
        ("data_path: ", args.data_path),
        ("results_path: ", args.results_path),
        ("k: ", args.k),
        ("dataset: ", args.dataset),
        ("data_subset: ", args.data_subset),
        ("verbose: ", args.verbose),
    ):
        print(label, value)

    main()

data_path:  /home/gean/nns_performance_prediction/meta_datasets/
results_path:  /home/gean/nns_performance_prediction/results/oracle/
k:  5000
dataset:  ['cifar10valid', 'cifar100', 'imagenet16_120']
data_subset:  [4, 12, 36, 108, 200]
verbose:  1


######### cifar10valid, Subset4
        Dataset  Epoch   Model  K  True_Val_Acc  True_Test_Acc  Best_Test_Acc
0  cifar10valid      4  Oracle  1        67.256            NaN            NaN
1  cifar10valid      4  Oracle  2        67.256            NaN            NaN
2  cifar10valid      4  Oracle  3        67.256            NaN            NaN
3  cifar10valid      4  Oracle  4        67.256            NaN            NaN
4  cifar10valid      4  Oracle  5        67.256            NaN            NaN


######### cifar10valid, Subset12
        Dataset  Epoch   Model  K  True_Val_Acc  True_Test_Acc  Best_Test_Acc
0  cifar10valid     12  Oracle  1        78.906            NaN            NaN
1  cifar10valid     12  Oracle  2        78.906            