In [1]:
import torch
import pandas as pd

import json
from math import sqrt

In [2]:
# Two-sided critical value of the standard normal distribution for a 99%
# confidence level.
Z_99_PERCENT = 2.576

def ci(t):
    """Return the half-width of a z-based 99% confidence interval for the mean.

    The value is ``Z_99_PERCENT * t.std() / sqrt(n)`` where ``n`` is the size
    of the first dimension of ``t``.

    Args:
        t: a floating-point tensor of samples (``t.std()`` is called on it).

    Returns:
        The half-width as a plain Python float. NaN when ``t.std()`` is NaN
        (e.g. fewer than two samples with PyTorch's default estimator).
    """
    n_samples = t.shape[0]
    # Keep the arithmetic on the tensor so the result has the tensor's dtype
    # before it is converted to a Python number.
    standard_error = t.std() / sqrt(n_samples)
    half_width = Z_99_PERCENT * standard_error
    return half_width.item()

def to_dict(t):
    """Summarize a tensor of samples as a dict of plain Python floats.

    The samples are cast to float and sorted. Besides the overall mean, max,
    min and standard deviation, an interquartile ("truncated") mean is
    computed after dropping the lowest and highest ``n // 4`` samples, together
    with a 99% confidence interval around that truncated mean (see ``ci``).

    Args:
        t: tensor of samples. Assumed to be 1-D (the first dimension is used
            as the sample count) -- TODO confirm against the saved eval files.

    Returns:
        dict with keys ``'mean'``, ``'max'``, ``'min'``, ``'std'``,
        ``'trunc-mean'``, ``'CI-low'`` and ``'CI-high'``.
    """
    t = t.float().sort().values
    n = t.size(0)
    quartile = n // 4
    # Use an explicit end index. With fewer than 4 samples ``quartile`` is 0
    # and ``t[quartile:-quartile]`` would be ``t[0:0]`` (empty), turning the
    # truncated mean and both CI bounds into NaN.
    trunc = t[quartile:n - quartile]
    trunc_mean = trunc.mean().item()
    ci_range = ci(trunc)

    return {
        'mean': t.mean().item(),
        'max': t.max().item(),
        'min': t.min().item(),
        'std': t.std().item(),
        'trunc-mean': trunc_mean,
        'CI-low': trunc_mean - ci_range,
        'CI-high': trunc_mean + ci_range
    }

def get_results(s):
    """Load one saved evaluation file and summarize its rewards and lengths.

    Args:
        s: run identifier; the file read is ``../results/{s}_eval.pt``.

    Returns:
        dict with two entries, ``'rewards'`` and ``'episode lens'``, each the
        ``to_dict`` summary of the corresponding tensor in the file.
    """
    # NOTE(review): torch.load unpickles the file (how strictly depends on the
    # installed PyTorch version) -- only point this at result files you trust.
    saved = torch.load(f'../results/{s}_eval.pt')
    episode_lens = saved['lens']
    episode_rews = saved['rews']

    summary = {}
    summary['rewards'] = to_dict(episode_rews)
    summary['episode lens'] = to_dict(episode_lens)
    return summary


In [5]:
import pandas as pd
def compare_one(n, dir):
    """Fetch the summaries for the ``dqn_{n}N_0_last`` run and tag them with ``n``.

    Args:
        n: value substituted into the run name; also stored under the ``'n'``
            key of both returned dicts.
        dir: path prefix placed in front of the run name (include the
            trailing slash, e.g. ``'transductive/'``).

    Returns:
        ``(reward_stats, length_stats)`` -- the ``'rewards'`` and
        ``'episode lens'`` dicts from ``get_results``, each with ``'n'`` added.

    Note: a second comparison against a "best" checkpoint (``ppo_{n}N_0``) was
    sketched here previously but is disabled; only the "last" run is reported.
    """
    results = get_results(f'{dir}dqn_{n}N_0_last')
    reward_stats, length_stats = results['rewards'], results['episode lens']

    # Record which n produced these numbers so rows stay identifiable once
    # they are stacked into a table.
    for stats in (reward_stats, length_stats):
        stats['n'] = n

    return reward_stats, length_stats

def eval(dir='', extra=()):
    """Build a table of reward statistics, one column per value of ``n``.

    Runs ``compare_one`` for ``n`` in 10, 20, 40 followed by every value in
    ``extra`` and stacks the reward summaries into a DataFrame whose rows are
    the statistics and whose columns are the runs (in that order). The
    episode-length summaries are computed by ``compare_one`` but not returned.

    Note: the name shadows the ``eval`` builtin (and ``dir`` shadows ``dir``);
    both are kept because existing cells call ``eval(dir=...)``.

    Args:
        dir: path prefix of the result files, passed through to
            ``compare_one`` (include the trailing slash).
        extra: additional ``n`` values to evaluate; any iterable is accepted.

    Returns:
        pandas.DataFrame of reward statistics, transposed so each column is
        one ``n``.
    """
    # Unpacking (instead of ``list + extra``) accepts tuples/ranges as well as
    # lists, and the immutable default avoids a shared mutable argument.
    sizes = [10, 20, 40, *extra]
    reward_rows = [compare_one(n, dir)[0] for n in sizes]
    return pd.DataFrame(reward_rows).transpose()


In [6]:
# Reward statistics for the runs under 'doorman_gamma1/', one column per n in [10, 20, 40].
eval(dir='doorman_gamma1/')

Unnamed: 0,0,1,2
mean,442.728424,-23.103765,-18.027456
max,599.0,599.0,599.0
min,-95.699997,-88.5,-71.25
std,143.926498,146.813461,93.543579
trunc-mean,472.898651,-79.077591,-56.525894
CI-low,464.467513,-79.342663,-56.918476
CI-high,481.329789,-78.812519,-56.133312
n,10.0,20.0,40.0


In [4]:
# Reward statistics for the runs under 'stablebaseline_default_naive_sgd/', one column per n in [10, 20, 40].
eval(dir='stablebaseline_default_naive_sgd/')

Unnamed: 0,0,1,2
mean,440.204254,325.803619,557.552979
max,599.0,599.0,599.0
min,160.59964,148.399658,399.224976
std,129.3302,163.965408,49.178612
trunc-mean,478.156006,277.79837,569.960022
CI-low,469.154557,261.907113,563.637688
CI-high,487.157454,293.689628,576.282356
n,10.0,20.0,40.0


In [5]:
# Reward statistics for the runs under 'stablebaseline_default_sgd/', one column per n in [10, 20, 40].
eval(dir='stablebaseline_default_sgd/')

Unnamed: 0,0,1,2
mean,487.072815,339.814758,597.827637
max,599.0,599.0,599.0
min,195.299576,145.249664,255.700867
std,114.713966,165.792297,17.60755
trunc-mean,531.851624,303.050323,598.983398
CI-low,526.58382,285.07147,598.981469
CI-high,537.119427,321.029177,598.985328
n,10.0,20.0,40.0


In [6]:
# Reward statistics for the runs under 'stablebaseline_default_sgd_N40/', one column per n in
# [10, 20, 40, 100, 250, 500] (the last three come from `extra`).
eval(dir='stablebaseline_default_sgd_N40/', extra=[100,250,500])

Unnamed: 0,0,1,2,3,4,5
mean,597.083618,598.389648,597.827637,590.114136,479.579773,503.41449
max,599.0,599.0,599.0,599.0,599.0,599.0
min,223.699341,321.750427,255.700867,246.558746,362.796173,473.556122
std,23.751122,12.397595,17.60755,47.600197,89.622864,32.952225
trunc-mean,598.940796,598.968018,598.983398,598.994141,463.429382,493.600189
CI-low,598.932775,598.964101,598.981469,598.993335,452.133663,493.089901
CI-high,598.948817,598.971935,598.985328,598.994946,474.725101,494.110477
n,10.0,20.0,40.0,100.0,250.0,500.0


In [7]:
# Reward statistics for the runs under 'stablebaseline_default_naive_sgd_N40/', one column per n in
# [10, 20, 40, 100, 250, 500] (the last three come from `extra`).
eval(dir='stablebaseline_default_naive_sgd_N40/', extra=[100,250,500])

Unnamed: 0,0,1,2,3,4,5
mean,562.95282,577.624817,557.552979,518.135681,501.937561,514.394043
max,599.0,599.0,599.0,599.0,599.0,599.0
min,380.19986,444.349609,399.224976,274.940216,424.227722,499.211945
std,53.749527,37.453655,49.178612,44.233551,20.586145,8.988159
trunc-mean,587.793701,597.652771,569.960022,504.96579,500.116119,514.046265
CI-low,583.786808,597.149851,563.637688,504.094907,499.655599,513.758194
CI-high,591.800595,598.155691,576.282356,505.836672,500.57664,514.334335
n,10.0,20.0,40.0,100.0,250.0,500.0


In [10]:
# Reward statistics for the runs under 'transductive/', one column per n in [10, 20, 40].
eval(dir='transductive/')

Unnamed: 0,0,1,2
mean,254.316864,-79.509003,-57.863049
max,599.0,-68.849998,-43.449989
min,153.89975,-89.300003,-71.099998
std,148.915848,3.110183,4.454319
trunc-mean,184.013763,-79.626801,-58.048103
CI-low,178.441215,-79.827142,-58.321227
CI-high,189.586312,-79.426459,-57.77498
n,10.0,20.0,40.0
