# Report the Best Parameter Sets found in each NNI Experiment

In [1]:
import os
import numpy as np
import pandas as pd
import pickle
import matplotlib.pyplot as plt

from utils import add_roll_mean, trials_to_goal, get_best, str_to_arr

import warnings
## Silence all warnings for the rest of the session (presumably to keep cell output tidy).
## NOTE(review): this also hides pandas warnings such as SettingWithCopyWarning and
## deprecation notices -- consider narrowing the filter to specific categories.
warnings.filterwarnings('ignore')

Choose whether to load the experiment data from CSV or pickle files

In [2]:
## Which saved format to load; the loading loop understands "csvs" and "pickles"
data_type = "csvs" #"pickles"

## Build the data folder path with os.path.join so it resolves on every OS
## (hard-coded backslashes are only path separators on Windows)
folder = os.path.join('.', 'data', f'nni_{data_type}')

Make a list of the hyperparameters that were tested

In [3]:
## Names of the hyperparameter columns; used to select columns from each
## experiment's results when collecting the best parameter sets
params_list = [
    'n_rotates',
    'n_scales',
    'length_scale',
    'lr',
    'act_dis',
    'state_dis',
    'active_prop',
    'eps',
]

Create a single dataframe containing the parameter sets that achieve the top 5% performance

In [4]:
## Gather the best parameter sets of every experiment file in `folder`,
## then join them into a single dataframe (`df_best_all`) in one concat.
best_frames = []

## sorted(): os.listdir returns entries in arbitrary order, so sort for reproducible output
for filename in sorted(os.listdir(folder)):
    filepath = os.path.join(folder, filename)

    ## Skip sub-directories (e.g. .ipynb_checkpoints); only regular files can be read below
    if not os.path.isfile(filepath):
        continue

    ## Read the data to a pandas dataframe
    if data_type == "csvs":
        df = pd.read_csv(filepath)
    elif data_type == "pickles":
        ## NOTE: unpickling can execute arbitrary code -- only load trusted files
        df = pd.read_pickle(filepath)
    else:
        ## Fail loudly instead of silently reusing a stale `df` from a previous iteration
        raise ValueError(f"Unknown data_type {data_type!r}; expected 'csvs' or 'pickles'")

    ## Presumably converts the string-encoded 'episodes'/'rewards' columns to arrays
    ## -- TODO confirm against utils.str_to_arr
    df = str_to_arr(df, ['episodes', 'rewards'])

    ## Add the rolling mean rewards as a new column
    df = add_roll_mean(df)

    ## Calculate number of trials to goal rolling mean and add it as a new column
    df = trials_to_goal(df, 95)

    df_best = get_best(df, params_list)

    ## Keep only the columns of interest, remove duplicate rows (drop_duplicates
    ## returns a new frame, so its result must be kept), and tag every row with
    ## the experiment file it came from. `assign` builds a new frame, avoiding
    ## assignment into a slice of the original.
    df_best = (
        df_best[['index'] + params_list]
        .drop_duplicates()
        .assign(exp=filename)
    )

    best_frames.append(df_best)

## Number of experiment files processed (kept for any later cell that reads `count`)
count = len(best_frames)

if best_frames:
    ## Single concat; each frame keeps its own row index, as before
    df_best_all = pd.concat(best_frames)
else:
    ## No experiment files found: expose an empty frame with the expected columns
    ## rather than leaving `df_best_all` undefined
    df_best_all = pd.DataFrame(columns=['index'] + params_list + ['exp'])
    

Show results

In [5]:
df_best_all

Unnamed: 0,index,n_rotates,n_scales,length_scale,lr,act_dis,state_dis,active_prop,eps,exp
0,0,4,1,1.0,0.14826,0.815828,0.981973,0.1,0.383644,ratbox100_discrete10.csv
1,1,4,1,1.0,0.009559,0.997999,0.978057,0.1,0.446527,ratbox100_discrete10.csv
2,2,4,1,1.0,0.003878,0.970027,0.879743,0.1,0.550991,ratbox100_discrete10.csv
3,3,4,1,1.0,0.004342,0.82777,0.920651,0.1,0.346343,ratbox100_discrete10.csv
4,4,4,1,1.0,0.006811,0.854644,0.828833,0.1,0.535418,ratbox100_discrete10.csv
0,0,4,1,1.0,0.34378,0.908315,0.908915,0.1,0.36112,ratbox100_discrete12.csv
1,1,4,1,1.0,0.008628,0.961214,0.828801,0.1,0.401866,ratbox100_discrete12.csv
2,2,4,1,1.0,0.023758,0.952302,0.849024,0.1,0.541772,ratbox100_discrete12.csv
3,3,4,1,1.0,0.00384,0.895973,0.857758,0.1,0.31886,ratbox100_discrete12.csv
4,4,4,1,1.0,0.004963,0.971761,0.924515,0.1,0.513579,ratbox100_discrete12.csv
