In [1]:
import numpy as np
import pandas as pd
import ast

# Generic Functions

In [2]:
def get_training_hyper(row, hyper="seed"):
    hypers = ast.literal_eval(row["config/training"])
    return hypers[hyper]

def get_d_hypers(row, hyper="n_hidden_units"):
    hypers = ast.literal_eval(row["config/discriminator"])
    return hypers[hyper] 

def get_g_hypers(row, hyper="n_hidden_units"):
    hypers = ast.literal_eval(row["config/generator"])
    return hypers[hyper] 

def get_expanded_df(results_df):
    
    # List the training hyperparameters
    training_hypers = ["seed", "d_lr", "g_lr", "gamma", "momentum", "step_size"]

    # List the discriminator and generator network hyperparameters
    nn_hypers = ["n_hidden_units", "n_hidden_layers"]

    # Add a new column for each training hyperparameter
    for hyper in training_hypers:
        results_df[hyper] = results_df.apply(lambda row: get_training_hyper(row, hyper), axis=1)

    # Add a new column for each discriminator and generator hyperparameter
    for hyper in nn_hypers:
        results_df["d_"+hyper] = results_df.apply(lambda row: get_d_hypers(row, hyper), axis=1)
        results_df["g_"+hyper] = results_df.apply(lambda row: get_g_hypers(row, hyper), axis=1)
        
    return results_df

# 1. EXP

In [12]:
# Read in the EXP results
exp_results = pd.read_csv("EXP_results.csv", usecols=np.arange(1,22))
exp_results = get_expanded_df(exp_results)

In [13]:
# Get all rows for the lowest MSE for each random seed
exp_results.loc[exp_results.groupby("seed")["mean_squared_error"].idxmin()]

Unnamed: 0,mean_squared_error,time_this_iter_s,done,timesteps_total,episodes_total,training_iteration,experiment_id,date,timestamp,time_total_s,...,seed,d_lr,g_lr,gamma,momentum,step_size,d_n_hidden_units,g_n_hidden_units,d_n_hidden_layers,g_n_hidden_layers
710,3.626995e-08,0.218002,False,,,199,91099197f8834cf3a85f6ae57c011b31,2021-06-09_11-57-53,1623254273,50.878741,...,0,0.074483,0.065553,0.995229,0.929729,6,30,40,2,3
40,3.748609e-11,0.450999,False,,,48,d3a7b04453ef4acea03c632f224e376b,2021-06-09_08-14-51,1623240891,22.482649,...,1,0.067565,0.097713,0.993886,0.942638,15,20,30,3,4
888,2.487941e-14,0.582383,False,,,138,8bb58dd6ea32416ebc9ccfcf75db3533,2021-06-09_13-21-46,1623259306,88.656755,...,2,0.083371,0.073697,0.990045,0.947535,11,40,40,3,2
772,1.317922e-08,0.29,False,,,178,f6c31c6e6ba141bb99529801eb50c76d,2021-06-09_12-18-13,1623255493,54.741657,...,3,0.076678,0.048351,0.991895,0.927196,18,40,20,4,4
608,3.046888e-11,0.280822,False,,,42,397dfa2eb4b84e24a3a6586700c4465d,2021-06-09_11-26-26,1623252386,11.630288,...,4,0.080288,0.097001,0.995761,0.928902,20,20,30,4,4
925,1.494417e-08,1.331141,False,,,44,9c0a78c70afa4278807eaa91d3b6857b,2021-06-09_13-48-15,1623260895,22.494441,...,5,0.074211,0.080597,0.998966,0.951447,14,20,20,3,2
494,7.789988e-08,0.519998,False,,,192,301f336152f54aa7927ba42bb175b6a4,2021-06-09_10-47-15,1623250035,115.489961,...,6,0.044028,0.092102,0.994797,0.932087,10,20,30,4,4
656,2.10748e-07,0.217001,True,,,200,446e7b785a8342cfb50e24785bf3064a,2021-06-09_11-43-20,1623253400,48.129216,...,7,0.024223,0.028349,0.999897,0.926489,19,20,40,2,4
839,1.201932e-08,0.68324,False,,,194,0a0ba3a9f25e4c1ea34b0cc32e23d092,2021-06-09_12-53-42,1623257622,153.647188,...,8,0.086091,0.095919,0.99784,0.942006,9,30,30,4,3
749,3.705156e-15,0.279523,False,,,49,e41befcaf1f745d49c7acf1735b49763,2021-06-09_12-09-29,1623254969,15.118922,...,9,0.094798,0.068751,0.997908,0.961214,15,30,20,4,4


In [7]:
best_per_seed.mean(axis=0)

mean_squared_error          3.651285e-08
time_this_iter_s            4.853109e-01
done                        1.000000e-01
timesteps_total                      NaN
episodes_total                       NaN
training_iteration          1.284000e+02
timestamp                   1.623254e+09
time_total_s                5.832698e+01
pid                         1.405880e+04
time_since_restore          5.832698e+01
timesteps_since_restore     0.000000e+00
iterations_since_restore    1.284000e+02
seed                        4.500000e+00
d_lr                        7.057363e-02
g_lr                        7.480333e-02
gamma                       9.956225e-01
momentum                    9.389243e-01
step_size                   1.370000e+01
n_hidden_units              3.000000e+01
n_hidden_layers             3.400000e+00
dtype: float64

In [None]:
# early abandonment: restart the training if LHS goes up fast (around 100 epochs)
# adaptive learning rates, noise
# goal: e-09 aveage for EXP
# try larger networks

# 2. SHO

In [14]:
# Read in the SHO results
sho_results = pd.read_csv("SHO_results.csv", usecols=np.arange(1,22))
sho_results = get_expanded_df(sho_results)

In [15]:
# Get all rows for the lowest MSE for each random seed
sho_results.loc[sho_results.groupby("seed")["mean_squared_error"].idxmin()]

Unnamed: 0,mean_squared_error,time_this_iter_s,done,timesteps_total,episodes_total,training_iteration,experiment_id,date,timestamp,time_total_s,...,seed,d_lr,g_lr,gamma,momentum,step_size,d_n_hidden_units,g_n_hidden_units,d_n_hidden_layers,g_n_hidden_layers
4,1.380299e-10,0.381963,False,,,846,9894ec6c9faf4df9b614b3bcbd2c34e7,2021-06-11_09-48-49,1623419329,334.988411,...,0,0.084104,0.069159,0.994216,0.916283,18,20,30,2,2
90,2.577914e-07,0.436886,False,,,396,d2db588b6251490a9580395c5c194379,2021-06-11_14-47-39,1623437259,177.537283,...,1,0.018642,0.043509,0.997259,0.909367,9,30,40,3,2
1,2.288413e-08,0.491001,False,,,986,4e62f1013160416cb921dff1bb617fc1,2021-06-11_09-33-53,1623418433,495.873694,...,2,0.01618,0.004661,0.996712,0.950134,18,40,30,4,2
75,2.479653e-10,0.443586,False,,,635,d10b054a2f1343c880920ce5cfd4d2ca,2021-06-11_14-22-08,1623435728,363.160771,...,3,0.077784,0.076084,0.993551,0.942164,12,20,40,2,3
63,2.71836e-06,0.824998,False,,,398,ab7d02ad81c94830a7485b1b110b98bf,2021-06-11_13-59-08,1623434348,328.340985,...,4,0.060863,0.01763,0.996388,0.938141,17,40,40,4,4
11,3.303208e-10,0.486004,False,,,887,b4a24b46a8f34a72a1984103d3ebf21a,2021-06-11_10-14-48,1623420888,426.529831,...,5,0.064445,0.018871,0.992179,0.943717,20,40,40,4,2
98,0.03049682,0.361995,True,,,400,7fa8fcffe9524f55af3fb1fdd2cec30e,2021-06-11_15-01-06,1623438066,148.14029,...,6,0.07071,0.079077,0.990227,0.921346,9,40,20,2,2
39,8.936116e-07,0.628546,False,,,390,3dfda59fd8c04c56ab92b0e0888a17c2,2021-06-11_11-08-28,1623424108,296.402668,...,7,0.038719,0.039024,0.994312,0.957289,11,20,40,2,4
10,4.189644e-10,0.581406,False,,,955,ba48cdb3b5c84589b90551b25e590440,2021-06-11_10-07-07,1623420427,690.117605,...,8,0.070874,0.027802,0.998684,0.908396,19,20,30,3,4
78,4.377555e-08,0.430998,False,,,393,fccfadcf22544d0daed02dd7f818d36d,2021-06-11_14-29-58,1623436198,171.268598,...,9,0.034648,0.071626,0.999795,0.940977,14,40,20,2,3


# 3. NLO

In [3]:
# Read in the NLO results
nlo_results = pd.read_csv("NLO_results.csv", usecols=np.arange(1,22))
nlo_results = get_expanded_df(nlo_results)

In [4]:
# Get all rows for the lowest MSE for each random seed
nlo_results.loc[nlo_results.groupby("seed")["mean_squared_error"].idxmin()]

Unnamed: 0,mean_squared_error,time_this_iter_s,done,timesteps_total,episodes_total,training_iteration,experiment_id,date,timestamp,time_total_s,...,seed,d_lr,g_lr,gamma,momentum,step_size,d_n_hidden_units,g_n_hidden_units,d_n_hidden_layers,g_n_hidden_layers
98,0.02614585,0.558571,True,,,800,72b4351df77048bf9394b79dc79abc91,2021-07-17_06-42-28,1626518548,412.898974,...,0,0.033364,0.061963,0.994488,0.926689,11,40,40,4,2
141,0.01706401,1.359611,False,,,1175,5ae8a410fde64c1db93be2a177e2cb3d,2021-07-17_10-11-32,1626531092,1243.346731,...,1,0.080263,0.055662,0.99514,0.960317,6,50,50,4,5
191,0.02559425,0.480864,False,,,205,3ef605526ece4c07bca283afd214323f,2021-07-17_14-36-18,1626546978,96.693987,...,2,0.055791,0.019971,0.994357,0.960169,7,40,30,2,3
192,0.02125287,0.549423,False,,,816,8d3f097123de4bafbc7f3c95dbb140dd,2021-07-17_14-48-31,1626547711,433.294434,...,3,0.01104,0.082327,0.991101,0.987068,10,50,30,2,3
44,6.475934e-08,1.455528,False,,,1951,15731f5a9e064ad4a6ba4bab107eecef,2021-07-17_01-04-52,1626498292,2079.083102,...,4,0.024206,0.078354,0.999127,0.944401,10,20,50,5,5
99,0.01118759,0.789633,False,,,1998,be45d5a4164c48f4a2d7297b13821b53,2021-07-17_07-10-48,1626520248,1693.237217,...,5,0.094169,0.082332,0.996459,0.987446,17,20,50,2,4
172,0.0081892,0.43927,True,,,2000,da9adfcd67e34dfd8e35290b1065c023,2021-07-17_13-00-14,1626541214,918.846905,...,6,0.008646,0.04286,0.999485,0.914401,12,20,20,4,2
135,0.02776054,0.810272,False,,,143,52604c07f9c2469f922caff5c5022982,2021-07-17_09-31-12,1626528672,116.161511,...,7,0.080861,0.065194,0.99404,0.976775,7,30,40,2,5
178,0.01893046,0.686707,False,,,1541,a3668745aa7c4069b6d0198f2746f463,2021-07-17_13-41-26,1626543686,1108.803679,...,8,0.070149,0.088745,0.991925,0.990609,11,30,50,3,3
122,0.022253,0.449004,True,,,2000,4a1b9bbbe8774e0ea11b72c0348b023e,2021-07-17_08-43-20,1626525800,869.495967,...,9,0.018547,0.06322,0.998006,0.945189,17,50,30,2,2
