In [None]:
import wandb
import pandas as pd 
import json
from research.utils.plot_utils import PlotHandler as ph
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

In [None]:

api = wandb.Api()

# Runs are addressed as "<entity>/<project-name>".
runs = api.runs("mtm_team/exorl_experiments1")

summary_list, config_list, name_list = [], [], []
for run in runs:
    # Human-readable run name, used as the first column of the table.
    name_list.append(run.name)

    # Config = the values the run was launched with; keys starting with "_"
    # are dropped as special/internal entries.
    config = {key: value for key, value in run.config.items() if not key.startswith('_')}
    config_list.append(config)

    # Summary = the final logged key/values (e.g. accuracy). `_json_dict` is
    # read so that large file objects are omitted.
    summary_list.append(run.summary._json_dict)

# NOTE: later cells rely on `run` still pointing at the last run of this loop.
name_df = pd.DataFrame({'name': name_list})
config_df = pd.DataFrame.from_records(config_list)
summary_df = pd.DataFrame.from_records(summary_list)
all_df = pd.concat([name_df, config_df, summary_df], axis=1)

# Optional export: all_df.to_csv("project.csv")

In [None]:
# Metric columns logged by `run` -- the last run left over from the fetch loop,
# so this cell only works after that cell has been executed.
run.history().columns

In [None]:
# `_step` column of the same leftover `run`'s history.
run.history().loc[:, "_step"]

In [None]:

# Parse the raw JSON config of the leftover `run`; `cfg` is reused by later cells.
cfg = json.loads(run.json_config)
# The dataset entry's payload sits under its "value" key.
dataset_entry = cfg["dataset"]
dataset_entry["value"]

In [None]:
def cfg_to_name(cfg, dataset_part=3):
    """Build a short, human-readable label for a run from its wandb JSON config.

    Args:
        cfg: Parsed ``run.json_config``. The entries read here ("args",
            "tokenizers", "dataset") each wrap their payload in a ``"value"`` key.
        dataset_part: Index of the dot-separated component of the dataset
            ``_target_`` path used as the dataset label. Defaults to 3, the
            component this notebook has always used.

    Returns:
        ``"maskids-<mask_indices> | <tokenizer> | <dataset>"`` where <tokenizer> is
        ``"continuous"`` if the state tokenizer's target path contains
        "continuous", ``"discrete"`` if it contains "uniform", and the raw target
        path otherwise. If the dataset target has no component at
        ``dataset_part``, the full dataset target path is used instead of
        raising ``IndexError``.

    Raises:
        KeyError: if one of the expected config entries is missing.
    """
    mask_ids = cfg["args"]["value"]["mask_indices"]
    tokenizer_target = cfg["tokenizers"]["value"]["states"]["_target_"]
    dataset_target = cfg["dataset"]["value"]["_target_"]

    # "continuous" is checked first, so a path containing both words counts as continuous.
    if "continuous" in tokenizer_target:
        tokenizer = "continuous"
    elif "uniform" in tokenizer_target:
        tokenizer = "discrete"
    else:
        tokenizer = tokenizer_target

    # Use one component of the dotted import path; fall back to the whole path so
    # an unexpectedly short target does not abort a loop over many runs.
    parts = dataset_target.split(".")
    try:
        dataset = parts[dataset_part]
    except IndexError:
        dataset = dataset_target

    return f"maskids-{mask_ids} | {tokenizer} | {dataset}"

In [None]:
# Label for the config parsed from the leftover `run` above.
cfg_to_name(cfg=cfg)

In [None]:
# Plot runs whose label contains any of these mask-index substrings.
require = ["[0]", "[0, 1, 2, 3]", "[3]"]

k = "f_dynamics_eval/mse_sum"
with ph.plot_context() as (fig, ax):
    for run in runs:
        # Filter on the (cheap) config first so history is only downloaded
        # for runs that will actually be plotted.
        name = cfg_to_name(json.loads(run.json_config))
        if not any(f in name for f in require):
            continue
        h = run.history(keys=[k])
        if k not in h.columns:
            continue  # this run never logged the metric
        ax.plot(h["_step"].to_numpy(), h[k].to_numpy(), label=name)
    ax.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
    ax.set_title(k)
    ax.set_xlabel("_step")
    ax.set_ylabel(k)
    ax.set_ylim(0, 30)  # fixed y-range; adjust if the metric exceeds 30
    # pyplot.show takes no figure argument (only `block`); passing `fig`
    # positionally fails on non-inline backends.
    plt.show()

In [None]:
# History of the leftover `run`, showing metric `k` as last set by the plot cell above.
h = run.history()
h.loc[:, k]

In [None]:
# Every logged `_step` of the leftover `run`, collected into one Series rather
# than printed row by row (pandas truncates the display of long Series).
scanned_steps = pd.Series([row["_step"] for row in run.scan_history()], name="_step")
scanned_steps

In [None]:
# Plot runs whose label contains any of these mask-index substrings.
require = ["[0]", "[0, 1, 2, 3]", "[1]"]

k = "goal_eval/mse_sum"
with ph.plot_context() as (fig, ax):
    for run in runs:
        # Filter on the (cheap) config first so history is only downloaded
        # for runs that will actually be plotted.
        name = cfg_to_name(json.loads(run.json_config))
        if not any(f in name for f in require):
            continue
        h = run.history(keys=[k])
        if k not in h.columns:
            continue  # this run never logged the metric
        ax.plot(h["_step"].to_numpy(), h[k].to_numpy(), label=name)
    ax.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
    ax.set_ylim(0, 30)  # fixed y-range; adjust if the metric exceeds 30
    ax.set_title(k)
    ax.set_xlabel("_step")
    ax.set_ylabel(k)
    # pyplot.show takes no figure argument (only `block`); passing `fig`
    # positionally fails on non-inline backends.
    plt.show()