In [None]:
from datetime import datetime
from functools import lru_cache

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import wandb
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
from matplotlib.ticker import FuncFormatter

# NOTE(review): shadowed by the local fetch_run_data defined in the next cell.
from plot_results import fetch_run_data


In [None]:

@lru_cache(maxsize=None)
def fetch_run_data(descriptor: str, metrics):
    """Fetch sampled metric curves for every wandb run with the given descriptor.

    Results are memoised with ``lru_cache``, so ``metrics`` must be hashable:
    pass a single metric name or a tuple of names (a list raises TypeError).

    Args:
        descriptor: value of ``config.descriptor`` used to filter runs in the
            ``cswinter/deep-codecraft-vs`` project.
        metrics: a metric name, or a tuple of names. With several names the
            value recorded for a history row is their mean.

    Returns:
        ``(curves, timestamp)``. ``curves`` holds one ``(steps, values)`` pair of
        numpy arrays per matching run, with steps in millions (``_step * 1e-6``).
        ``timestamp`` is ``summary["_timestamp"]`` of the first matching run.

    Raises:
        ValueError: if no run matches ``descriptor``.
    """
    metrics = [metrics] if isinstance(metrics, str) else list(metrics)
    api = wandb.Api()
    # Materialise the paginated result so emptiness can be checked and runs[0] indexed.
    runs = list(api.runs("cswinter/deep-codecraft-vs", {"config.descriptor": descriptor}))
    if not runs:
        raise ValueError(f"no wandb runs found with config.descriptor={descriptor!r}")

    curves = []
    for run in runs:
        step = []
        value = []
        for entry in run.history(keys=metrics, samples=100, pandas=False):
            # Skip rows missing any requested metric. Checking only metrics[0]
            # let a row without a later metric raise KeyError.
            if all(metric in entry for metric in metrics):
                step.append(entry['_step'] * 1e-6)
                value.append(np.mean([entry[metric] for metric in metrics]))
        curves.append((np.array(step), np.array(value)))
    return curves, runs[0].summary["_timestamp"]

In [None]:
# wandb `config.descriptor` values of the experiments to compare. Each one is used
# as an exact-match filter by fetch_run_data, so the strings must match wandb
# verbatim (including embedded spaces in the schedule specs).
runs = [
    "154506-agents15-hpsetstandard-steps150e6",
    "24e131-agents15-hpsetstandard-steps150e6",
    "613056-agents15-hpsetstandard-steps150e6",
    "87c1ab-hpsetstandard",
    "8af81d-hpsetstandard-num_self_play30-num_vs_aggro_replicator1-num_vs_destroyer2-num_vs_replicator1",
    "d33903-batches_per_update32-batches_per_update_schedule-hpsetstandard-lr0.001-lr_schedulecosine-steps150e6",
    "49b7fa-entropy_bonus0.02-entropy_bonus_schedulelin 20e6:0.005,60e6:0.0-hpsetstandard",
    "49b7fa-feat_dist_to_wallTrue-hpsetstandard",
    "b9bab7-hpsetstandard-max_hardness150",
    "46e0b2-hpsetstandard-spatial_attnFalse",
    "2d9e29-hpsetstandard",
    "30ed5b-hpsetstandard-max_hardness175",
    "fc244e-hpsetstandard-spatial_attnTrue-spatial_attn_lr_multiplier10.0",
    "0a5940-hpsetstandard-item_item_attn_layers1-item_item_spatial_attnTrue-item_item_spatial_attn_vfFalse-max_grad_norm200",
    "0a5940-hpsetstandard-mothership_damage_scale4.0-mothership_damage_scale_schedulelin 50e6:1.0,150:0.0",  # NOTE(review): "150" not "150e6" - presumably as launched; verify
    "83a3af-hpsetstandard-mothership_damage_scale4.0-mothership_damage_scale_schedulelin 50e6:0.0",
    "667ac7-hpsetstandard",
    "80a87d-entropy_bonus0.15-entropy_bonus_schedulelin 15e6:0.07,60e6:0.0-hpsetstandard",
    "80a87d-entropy_bonus0.2-entropy_bonus_schedulelin 15e6:0.1,60e6:0.0-final_lr5e-05-hpsetstandard-lr0.0005-vf_coef1.0",
    "c0b3b4-hpsetstandard-partial_score0",
    "9fc3de-hpsetstandard",
    "9fc3de-adr_hstepsize0.001-hpsetstandard-linear_hardnessFalse",
    "ac84c0-gamma0.9997-hpsetstandard",
    "a1210b-gamma_schedulecos 1.0-hpsetstandard",
    "b9f907-adr_average_cost_target1-hpsetstandard",
    "5fb181-hpsetstandard",
    "5fb181-hpsetstandard-steps150e6",
    "3c69a5-adr_average_cost_target0.5-adr_avg_cost_schedulelin 80e6:1.0-hpsetstandard",
    "35b3a7-hpsetstandard-nearby_mapFalse-steps150e6",
    "152ec3-hpsetstandard-nearby_mapFalse-steps125e6",
]

In [None]:
# Every individual run's eval curve. Each descriptor's runs are coloured by the
# summary `_timestamp` that fetch_run_data returns for that descriptor.
fig, ax = plt.subplots(figsize=(19, 10))
cmap = plt.get_cmap('viridis')

# Colour-scale range as UNIX timestamps (~2020-07-05 .. ~2020-12-12).
t0 = 1593959023.8568478
tn = 1607756232
date_norm = Normalize(vmin=t0, vmax=tn)

for run in runs:
    curves, date = fetch_run_data(run, "eval_mean_score")
    color = cmap(date_norm(date))
    for steps, scores in curves:
        ax.plot(steps, scores, color=color, marker='o')


def dateformatter(x, pos=None):
    """Format a colorbar tick value (a UNIX timestamp) as YYYY-MM-DD."""
    return datetime.fromtimestamp(x).strftime('%Y-%m-%d')


# Pass ax explicitly. A ScalarMappable attached to no Axes leaves colorbar with
# nothing to take space from, which is deprecated or an error in recent Matplotlib.
# The norm works in raw timestamps, so tick positions and the formatter agree.
# Previously AutoDateLocator was applied to 0..1 values it read as day numbers.
fig.colorbar(ScalarMappable(norm=date_norm, cmap=cmap), ax=ax,
             ticks=np.linspace(t0, tn, 6), format=FuncFormatter(dateformatter))

ax.set(xlabel='million samples', ylabel='eval_mean_score')
ax.set_yticks([-1.0, -0.5, 0, 0.5, 1])
ax.set_xlim(0, 200)
ax.grid()
plt.show()

In [None]:
# Mean eval curve per experiment descriptor (averaged over that descriptor's runs).
fig, ax = plt.subplots(figsize=(20, 15))

for run in runs:
    curves, _ = fetch_run_data(run, "eval_mean_score")
    samples = []
    values = []
    for steps, scores in curves:
        for i, score in enumerate(scores):
            if i >= len(values):
                # The first curve to reach index i supplies the x coordinate.
                samples.append(steps[i])
                values.append([score])
            else:
                values[i].append(score)
    # NOTE(review): points are averaged by sample index, not by exact step. If the
    # runs' sampled steps differ, the mean is only approximately aligned.
    mean_values = np.array([np.mean(vals) for vals in values])
    ax.plot(samples, mean_values, label=run)

ax.set(xlabel='million samples', ylabel='eval_mean_score', ylim=(-1, 1))
ax.set_yticks([-1.0, -0.5, 0, 0.5, 1])
# x values are already in millions (fetch_run_data scales _step by 1e-6).
# The previous limit of 200e6 squashed every curve against x=0.
ax.set_xlim(0, 200)
# Curves are now labelled, so the legend has entries. It sits outside the axes
# because the descriptors are long.
ax.legend(loc='upper left', bbox_to_anchor=(1.01, 1), fontsize='small')
ax.grid()
plt.show()

In [None]:
api = wandb.Api()
runs = api.runs("cswinter/deep-codecraft-vs", {"config.descriptor": runs[0]})

In [None]:
runs

In [None]:
fetch_run_data(runs[-1], 'eval_mean_score')[1]

In [None]:
#help(runs[0])
{metric: values for metric, values in runs[0].summary.items() if metric.startswith('eval')}