In [None]:
import numpy as np
import wandb
from matplotlib import pyplot as plt
from util import populate_plt_settings, get_column_width, get_fig_size, get_latex_float
import re
import plotly.express as px
import pandas as pd

In [None]:
# Apply the project's shared matplotlib settings from util.populate_plt_settings.
# NOTE(review): presumably LaTeX/pgf-related rcParams, given the .pgf output at the end — confirm.
populate_plt_settings(plt)

In [None]:
api = wandb.Api()

# Legend label -> W&B run handle, fetched in the order listed below.
# NOTE(review): the 'Bidirectional' run id uses "2amh43ojAuga" rather than
# "2amh43ojAug" like the others — confirm this is intentional.
runs = {
    label: api.run(path)
    for label, path in (
        ('Training set', 'kennychufk/alluvion-rl/2amh43ojAug-val'),
        ('Nephroid', 'kennychufk/alluvion-rl/2amh43ojAug-nephroid'),
        ('Bidirectional', 'kennychufk/alluvion-rl/2amh43ojAuga-bidir-circles'),
        ('Epitrochoid', 'kennychufk/alluvion-rl/2amh43ojAug-interesting-loop'),
    )
}

In [None]:
# Download each run's complete logged history once (all keys, full step range,
# 1000 rows per page). `scan_history` returns a lazy scan that re-fetches every
# page each time it is iterated, so it is materialized into a list here. This
# lets downstream analysis cells iterate (and be re-run) without network access.
histories = {}

for run_name in runs:
    run = runs[run_name]
    histories[run_name] = list(run.scan_history(keys=None,
                                                page_size=1000,
                                                min_step=None,
                                                max_step=None))

In [None]:
metric_name = 'eulerian_masked'
# metric_name = 'eulerian'

# History key holding the aggregate error of one evaluated artifact.
overall_name = f'overall-{metric_name}'


# Artifact rows are those whose `_step + 1` is a multiple of `step_interval`.
step_interval = 50
# Multiplier from a logged step to the x-axis "Steps" value.
# NOTE(review): presumably simulation steps per training sequence — confirm.
num_steps_per_sequence = 2000
num_artifacts = {}
overall_errors = {}
step_numbers = {}

for run_name in runs:
    # One pass over the history; the same `step_interval` filter is used for
    # counting and for extracting, so the two can never disagree.
    artifact_rows = [row for row in histories[run_name]
                     if (row['_step'] + 1) % step_interval == 0]
    num_artifacts[run_name] = len(artifact_rows)

    # x-values come from the actual logged steps. For a contiguous history this
    # equals (k + 1) * step_interval * num_steps_per_sequence for artifact k.
    step_numbers[run_name] = np.array(
        [(row['_step'] + 1) * num_steps_per_sequence for row in artifact_rows],
        dtype=int)

    # A missing (or None) metric becomes NaN, so it shows as a gap in the plot
    # instead of a fake zero error.
    overall_errors[run_name] = np.array(
        [row.get(overall_name, np.nan) for row in artifact_rows],
        dtype=float)

In [None]:
num_rows = 1
num_cols = 1
# Single-column figure with a golden-ratio aspect: ratio = (sqrt(5) - 1) / 2.
fig, ax = plt.subplots(
    num_rows,
    num_cols,
    figsize=get_fig_size(get_column_width(), ratio=(np.sqrt(5) - 1) * 0.5),
)

# This explicit order fixes both the legend order and the colour each curve gets.
for run_name in ('Bidirectional', 'Training set', 'Nephroid', 'Epitrochoid'):
    ax.plot(step_numbers[run_name], overall_errors[run_name], label=run_name)

ax.set(xlabel='Steps',
       ylabel=r'Eulerian error ($\textup{m}^{2}\textup{s}^{-2}$)')
ax.legend()

# Tighten margins before saving. If a fig.add_axes() call is ever added, it
# must come after this line (there is none in this cell).
fig.tight_layout(pad=0.05)
fig.savefig('learning-curve-generality.pgf')