In [None]:
import numpy as np
from numpy import linalg as LA
from matplotlib import pyplot as plt
from util import populate_plt_settings, get_column_width, get_fig_size, get_latex_float
import matplotlib
import wandb

In [None]:
# Apply the project's shared plotting configuration (from `util`) to pyplot
# before any figure in this notebook is created.
# NOTE(review): presumably this sets rcParams suited to LaTeX/PGF export, since
# the final figure is saved as a .pgf file — confirm in util.py.
populate_plt_settings(plt)

In [None]:
# Client for the Weights & Biases public API; used below to fetch each run and
# scan its logged history.
# NOTE(review): timeout=30 is presumably in seconds — confirm against the
# wandb.Api documentation.
api = wandb.Api(timeout=30)

# Earlier set of runs, kept commented out for reference only (not plotted):
# run_ids = [
#     'kennychufk/alluvion-rl/2dvg1pj7', # KL v
#     'kennychufk/alluvion-rl/2v8uur8r', # density
#     'kennychufk/alluvion-rl/3f0xin9v', # shape
#     'kennychufk/alluvion-rl/3lk6qkjz', # v
# ]

# Runs compared in this notebook, one per reward formulation. Each entry is the
# path string handed to `api.run`. The order matters: the plotting cells index
# `labels` in parallel with this list.
run_ids = [
    'kennychufk/alluvion-rl/2nw7bxr4', # statistical (kldiv)
    'kennychufk/alluvion-rl/e290336m', # eulerian
    'kennychufk/alluvion-rl/1g69ksir', # geometrical
]

In [None]:
# Legend text for the plotted curves. Entries are looked up with the same index
# as `run_ids`, so both lists must be kept in the same order.
labels = ['Statistical reward', 'Eulerian reward', 'Geometric reward']

In [None]:
# Download the logged score history of every run listed in `run_ids`.
# Afterwards `score_curves[k]` and `score100_curves[k]` hold the curves that
# belong to `run_ids[k]`. Both containers are rebuilt from scratch here, so the
# cell can be re-run safely.
score100_curves = []
score_curves = []

for run_id in run_ids:
    run = api.run(run_id)
    score100_curve = []
    score_curve = []
    history = run.scan_history(keys=None,
                               page_size=1000,
                               min_step=None,
                               max_step=None)

    for row_id, row in enumerate(history):
        # Sanity check: the row's position is expected to equal its logged
        # `_step`, so list indices can be read as step numbers when plotting.
        if row_id != row['_step']:
            print('step id mismatch')
        # Raw scores are collected only from rows that contain them (no
        # padding), so this list may be shorter than the number of rows.
        if 'score' in row:
            score_curve.append(row['score'])
        # `score100` is padded with NaN when absent, which keeps its index
        # aligned with the row position.
        # NOTE(review): presumably a 100-episode moving average that is not
        # logged for the earliest episodes — confirm against the training code.
        score100_curve.append(row['score100'] if 'score100' in row else np.nan)
    score100_curves.append(score100_curve)
    score_curves.append(score_curve)
    print('finished', run_id)

In [None]:
# Quick-look plot of the raw per-step scores of every run.
# Each curve is divided by the negated minimum of that run so that runs with
# different reward scales can share one axis.
# NOTE(review): this normalisation assumes each run's minimum score is
# negative (worst value maps to -1) — confirm for new reward functions.
raw_fig, raw_ax = plt.subplots()

for run_label, score_curve in zip(labels, score_curves):
    # Convert explicitly instead of relying on list / numpy-scalar coercion.
    raw_scores = np.asarray(score_curve)
    raw_ax.plot(raw_scores / -np.min(raw_scores), label=run_label)
raw_ax.legend()
raw_ax.set_title('Raw scores (normalised per run)')
raw_ax.set_ylabel('Score')
raw_ax.set_xlabel('Episode');

In [None]:
# Final figure: moving-average score of each run, normalised per run and saved
# as 'reward-comparison.pgf'.
fig, ax = plt.subplots(1, 1, figsize = get_fig_size(get_column_width()))

for run_label, score100_curve in zip(labels, score100_curves):
    # Missing entries were padded with NaN when the history was downloaded.
    averaged = np.asarray(score100_curve, dtype=float)
    # Normalise by the negated minimum of the values that were actually logged.
    # np.nanmin skips the NaN padding; replacing NaN with 0 before taking the
    # minimum would let the padding itself become the "minimum" whenever all
    # real scores are positive, causing a division by zero.
    # NOTE(review): assumes the run's minimum is negative (worst value maps
    # to -1) — confirm for new reward functions.
    ax.plot(averaged / -np.nanmin(averaged), label=run_label)
ax.legend()
ax.set_title('Comparison of reward functions')
ax.set_ylabel('Moving-average score')
ax.set_xlabel('Episode')
fig.savefig('reward-comparison.pgf', bbox_inches='tight')