In [1]:
from pathlib import Path
import re
from collections import defaultdict
import pandas as pd
from statistics import mean, stdev

In [5]:
results = defaultdict(list)

for log_dir in [i for i in Path("..").glob("slurm-*.out")]:
    with open(log_dir, "r") as f:
        log = f.read().split("\n")
        params = [i for i in log if i.startswith("Params")][0].split(" ")[1]
        epochs = [i for i in log if i.startswith("PPO epochs")][0].split(" ")[-1]
        params += f"-{epochs}"
        reward = float([i for i in log if i.startswith(" Largest")][0].split(" ")[-1])
        results[params].append(reward)

In [6]:
{k: (mean(v), stdev(v)) for k, v in results.items()}

{'baseline-10': (-359.63, 153.65118253608347),
 'baseline-pretrained_transe-1': (-285.88, 163.5393707627423),
 'xlvin-no_exe-pretrained_transe-10': (-294.7, 159.03834003157854),
 'xlvin-no_exe-10': (-356.79217326732675, 153.83374989517742),
 'xlvin-no_exe-pretrained_transe-1': (-197.5375, 94.71711927046101),
 'baseline-pretrained_transe-10': (-239.4924801980198, 156.5859804331391),
 'xlvin-pretrained_transe-1': (-204.4900297029703, 140.27003084654214),
 'xlvin-no_exe-1': (-449.082, 105.6866427152507),
 'xlvin-pretrained_transe-10': (-207.611, 138.57354068470192),
 'xlvin-1': (-443.4245, 104.29428018201291),
 'xlvin-10': (-345.9785, 153.1277262966752),
 'baseline-1': (-373.8405, 178.98283135761784)}

In [7]:
for k, v in results.items():
    print(f"{k}," + ",".join(str(i) for i in v))

baseline-10,-384.38,-498.83,-500.0,-500.0,-390.52,-238.93,-500.0,-103.42,-176.47,-417.17,-139.07,-102.62,-500.0,-500.0,-157.91,-500.0,-279.77,-338.98,-487.47,-477.06
baseline-pretrained_transe-1,-86.61,-135.52,-500.0,-117.34,-228.87,-211.42,-257.72,-500.0,-90.37,-331.95,-500.0,-213.77,-90.24,-468.93,-500.0,-380.64,-162.94,-91.39,-482.09,-367.8
xlvin-no_exe-pretrained_transe-10,-154.02,-172.7,-205.98,-500.0,-400.62,-500.0,-482.11,-201.29,-191.04,-240.67,-89.54,-499.59,-494.67,-315.58,-198.32,-466.16,-100.32,-441.45,-93.19,-146.75
xlvin-no_exe-10,-500.0,-460.84,-494.8,-312.09,-489.71,-133.17,-495.68,-118.07,-370.93,-489.65346534653463,-206.19,-104.55,-128.9,-224.69,-358.11,-494.89,-500.0,-473.91,-279.66,-500.0
xlvin-no_exe-pretrained_transe-1,-121.69,-119.39,-178.76,-85.27,-129.54,-274.72,-119.88,-85.24,-286.89,-299.21,-159.89,-174.48,-235.42,-196.91,-243.29,-178.89,-195.05,-169.62,-500.0,-196.61
baseline-pretrained_transe-10,-500.0,-310.03960396039605,-116.14,-433.59,-129.97,-500.0,-179

In [17]:
for i in range(0, 20, 5):
    print({k: mean(v[i:i+5]) for k, v in results.items()})

{'transe-baseline': -349.606, 'xlvin': -198.25, 'ppo-baseline': -296.986}
{'transe-baseline': -286.234, 'xlvin': -308.242, 'ppo-baseline': -258.662}
{'transe-baseline': -252.71, 'xlvin': -378.132, 'ppo-baseline': -223.782}
{'transe-baseline': -353.918, 'xlvin': -248.964, 'ppo-baseline': -297.134}


In [2]:
def find_first_with(log, prefix):
    return [i for i in log if i.startswith(prefix)][0]

In [3]:
root = Path("..")

In [4]:
def get_table(log, n, prefix):
    test_start_idx = [i for i, l in enumerate(log) if l.startswith("┡")][n]
    test_end_idx = [i for i, l in enumerate(log) if l.startswith("└")][n]
    test_lines = log[test_start_idx+1:test_end_idx]
    split_lines = [i.split("│") for i in test_lines]
    test_pairs = [(i[1].strip(), float(i[2].strip())) for i in split_lines]
    test_metrics = [(i.split("/")[-1], j) for i, j in test_pairs if i.startswith(prefix)]
    test_metrics = [(f"{prefix}_{i}" if not i.startswith(prefix) else i, float(j)) for i, j in test_metrics]
    return test_metrics

In [5]:
per_alg_metrics = defaultdict(list)
for log_dir in [i for i in Path("..").glob("slurm-*.out")]:
    with open(log_dir, "r") as f:
        log = f.read().split("\n")

        experiment = log[0]
        alg, rpt, gated = re.search(r"Experiment ([A-Za-z]*)_([0-9]*)_(.*)", experiment).groups()
        alg = alg + "_" + gated
        
        weights_line = find_first_with(log, "Loaded model weights from checkpoint at ")
        best_ckpt_step = re.search(r"(.*)step=([0-9]*)(.*)", weights_line).groups()[1]

        val_metric_last = get_table(log, 2, "val")
        test_metric_last = get_table(log, 3, "test")
        val_metric_best = get_table(log, 0, "val")
        test_metric_best = get_table(log, 1, "test")
        
        
        metrics_best = [('step', int(best_ckpt_step)), ('repeat', int(rpt))] + val_metric_best + test_metric_best
        metrics_last = [('step', 10000), ('repeat', int(rpt))] + val_metric_last + test_metric_last
        per_alg_metrics[alg].append(dict(metrics_best))
        per_alg_metrics[alg].append(dict(metrics_last))

In [6]:
per_alg_metrics.keys()

dict_keys(['ValueIterationSampler_gated', 'ValueIterationSampler_no_gated'])

In [7]:
for key in ['ValueIterationSampler_gated', 'ValueIterationSampler_no_gated']:
    print(key)
    with pd.option_context('display.max_rows', None, 'display.max_columns', None): 
        df = pd.DataFrame(per_alg_metrics[key]).sort_values(by=['repeat']).reset_index(drop=True)
        df['best'] = df.step == 10000
        df = df.sort_values(by=['best', 'repeat']).reset_index(drop=True)
        print(df.to_csv(index=False))

ValueIterationSampler_gated
step,repeat,value,val_loss,val_pi,test_value,test_loss,test_pi,best
256,0,0.7758354544639587,0.9329655766487122,1.0,3.5566787719726562,3.8120837211608887,0.97119140625,False
256,1,1.36759352684021,1.5341386795043945,1.0,5.34912109375,5.597189903259277,1.0,False
256,2,0.6811755895614624,0.8364886045455933,1.0,3.024841785430908,3.263089895248413,0.99169921875,False
10000,0,0.011533566750586033,0.025231879204511642,0.99609375,0.017449162900447845,0.053389281034469604,0.99169921875,True
10000,1,0.05284559726715088,0.06704340130090714,0.998046875,0.038589708507061005,0.0708828866481781,0.98974609375,True
10000,2,0.033453233540058136,0.04547455906867981,1.0,0.1330200731754303,0.25534147024154663,0.9775390625,True

ValueIterationSampler_no_gated
step,repeat,value,val_loss,val_pi,test_value,test_loss,test_pi,best
256,0,1.6914010047912598,1.8476405143737793,1.0,6.979653358459473,7.24434757232666,0.97607421875,False
480,1,0.42428404092788696,0.5053846836090088,1.0,1.4

In [None]:
dfs_df = pd.DataFrame(per_alg_metrics['DFSSampler']).sort_values(by=['repeat']).reset_index(drop=True)

In [31]:
with pd.option_context('display.max_rows', None, 'display.max_columns', None): 
    display(dfs_df)

Unnamed: 0,step,repeat,val_time,val_color,val_d,val_f,val_pi_h,val_s,val_s_last,val_s_prev,val_u,val_v,val_loss,val_pi,test_time,test_color,test_d,test_f,test_pi_h,test_s,test_s_last,test_s_prev,test_u,test_v,test_loss,test_pi
0,110,0,0.000128,0.923975,0.002999,0.005632,0.851771,0.98969,0.833682,0.837654,0.890258,0.649377,0.223682,0.994141,0.000224,0.928523,0.042222,0.091383,0.497388,0.999332,0.790824,0.612357,0.885517,0.508551,2.297445,0.473633
1,10000,0,0.00013,0.986489,0.000645,0.00061,0.994332,0.998355,0.991615,0.993735,0.992049,0.970012,0.018074,1.0,0.000429,0.878968,0.030236,0.065549,0.677656,0.982008,0.915698,0.669193,0.962685,0.69393,37.201763,0.188477
2,121,1,0.000117,0.925475,0.003368,0.013767,0.840603,0.993875,0.807149,0.843222,0.859966,0.647586,0.242261,0.994141,0.000262,0.934372,0.046767,0.074749,0.473972,0.999681,0.785308,0.614996,0.787003,0.564345,3.422422,0.260254
3,10000,1,9.7e-05,0.990605,0.000525,0.000434,0.997379,1.0,1.0,0.998082,1.0,0.991193,0.006391,1.0,0.000253,0.868816,0.035779,0.049603,0.609648,0.984094,0.862032,0.686662,0.885462,0.70859,6.280212,0.704102
4,128,2,0.000154,0.916672,0.003648,0.004253,0.858063,0.985701,0.821251,0.851767,0.879325,0.682412,0.220581,1.0,0.000216,0.923075,0.046005,0.081735,0.613199,0.992807,0.799015,0.64022,0.805163,0.583552,3.544262,0.315918
5,10000,2,0.000124,0.98969,0.00062,0.00066,0.996961,1.0,0.998641,0.996347,0.998641,0.965886,0.02365,0.996094,0.000213,0.883142,0.030441,0.056074,0.557535,0.978,0.885393,0.656977,0.944747,0.737868,39.796669,0.211426
6,118,3,0.00015,0.914465,0.004725,0.006214,0.84183,0.988623,0.764806,0.836666,0.855906,0.62668,0.253443,0.996094,0.00049,0.890811,0.034745,0.063905,0.538611,0.990008,0.701083,0.56055,0.75568,0.573712,3.081721,0.416504
7,10000,3,0.000127,0.989729,0.00057,0.000536,0.997422,1.0,0.993958,0.996346,0.996619,0.96495,0.012559,1.0,0.000193,0.828299,0.028893,0.056837,0.584738,0.943039,0.785099,0.682952,0.813991,0.691677,11.942507,0.429199
8,140,4,0.000174,0.926839,0.002216,0.003841,0.884708,0.98669,0.832118,0.868605,0.885376,0.66493,0.200212,0.998047,0.000283,0.890331,0.04255,0.085215,0.530692,0.999668,0.657225,0.678173,0.749651,0.542238,2.169616,0.393555
9,10000,4,0.000118,0.988598,0.000531,0.000485,0.997859,1.0,0.998162,0.998384,1.0,0.986329,0.007838,1.0,0.000339,0.814246,0.028648,0.05792,0.609427,0.995329,0.885253,0.612892,0.931067,0.71999,10.793756,0.509277


In [32]:
bf_df = pd.DataFrame(per_alg_metrics['BellmanFordSampler']).sort_values(by='repeat').reset_index(drop=True)

In [33]:
with pd.option_context('display.max_rows', None, 'display.max_columns', None): 
    display(bf_df)

Unnamed: 0,step,repeat,val_d,val_msk,val_pi_h,val_loss,val_pi,test_d,test_msk,test_pi_h,test_loss,test_pi
0,997,0,0.004049,1.0,0.990764,0.067764,0.986328,0.005388,1.0,0.91152,0.588678,0.890625
1,10000,0,0.001441,1.0,0.997005,0.016976,0.994141,0.002303,1.0,0.97558,0.184152,0.970215
2,967,1,0.016572,1.0,0.988111,0.091783,0.970703,0.022943,1.0,0.893472,0.500055,0.924805
3,10000,1,0.002578,1.0,0.99496,0.013153,1.0,0.002619,1.0,0.954131,0.171866,0.976074
4,1363,2,0.004435,1.0,0.987326,0.03348,0.996094,0.002882,1.0,0.941023,0.3176,0.961914
5,10000,2,0.006223,1.0,0.990208,0.095521,0.978516,0.009307,1.0,0.921834,0.306496,0.961914
6,1087,3,0.003798,1.0,0.980679,0.055809,0.998047,0.003055,1.0,0.89579,0.433687,0.939453
7,10000,3,0.002176,1.0,0.994722,0.016356,0.998047,0.001535,1.0,0.966257,0.156265,0.976562
8,1234,4,0.019248,1.0,0.98566,0.08812,0.984375,0.009364,1.0,0.89729,0.552906,0.944336
9,10000,4,0.010181,1.0,0.99664,0.047425,0.994141,0.006899,1.0,0.962572,0.194967,0.969238
