In [23]:
from collections import defaultdict
import pandas as pd
import numpy as np

import wandb
# W&B public API client; the cells below query runs through it (api.runs / api.run).
api = wandb.Api()

In [32]:
# W&B project holding the runs for plots 1 and 3, and the metric that is
# read per task as "<task>/<metric>".
entity = "nupic-research"
project = "multitask_journal"
metric = "SuccessRate"

# "Average" followed by the ten individual (all "-v2") tasks.
tasks = ["Average"] + [
    f"{env}-v2"
    for env in (
        "button-press-topdown",
        "door-open",
        "drawer-close",
        "drawer-open",
        "peg-insert-side",
        "pick-place",
        "push",
        "reach",
        "window-close",
        "window-open",
    )
]

# Tasks whose complete history is stored alongside the trailing mean.
tasks_all_data = ["Average"]

# Run ids excluded from every query.
runs_to_ignore = "clzl304 phhk331".split()

def collect_runs(state_filter="finished", mv_avg_window=10):
    """Collect the "Strategy 1" runs of the W&B project into a DataFrame.

    Reads the module-level ``api``, ``entity``, ``project``, ``metric``,
    ``tasks``, ``tasks_all_data`` and ``runs_to_ignore``.

    Parameters
    ----------
    state_filter : str
        Only runs whose ``run.state`` equals this value are kept.
    mv_avg_window : int
        Each task score is the mean of the last ``mv_avg_window`` logged
        values of ``f"{task}/{metric}"`` (a single trailing mean, not a
        moving-average curve). Expected to be a positive int.

    Returns
    -------
    pandas.DataFrame
        One row per kept run with columns ``id``, ``name``, ``summary``,
        ``config`` (keys starting with ``_`` removed), ``net_type``,
        ``group``, one score column per entry of ``tasks``, and a
        ``f"{task}_fullhist"`` column (the full list of logged values) for
        each task in ``tasks_all_data``. A task with no logged values scores
        ``NaN``.
    """
    keys = [f"{task}/{metric}" for task in tasks]
    data = defaultdict(list)
    for run in api.runs(entity + "/" + project):
        # .get() covers runs that have no "wandb_group" entry at all.
        group = run.config.get("wandb_group")
        if (run.state != state_filter
                or group != "Strategy 1"
                or run.id in runs_to_ignore):
            continue

        # -- Main config
        data["id"].append(run.id)
        data["name"].append(run.name)  # human-readable run name
        # ._json_dict is used to omit large files from the summary.
        data["summary"].append(run.summary._json_dict)
        # Hyperparameters, minus special entries whose keys start with "_".
        data["config"].append(
            {k: v for k, v in run.config.items() if not k.startswith("_")})
        data["net_type"].append(run.config["net_type"])
        data["group"].append(group)

        # -- Success rate per task.
        # Materialize the history once. Iterating the scan object once per
        # task would either re-scan the run's history for every task or, if
        # it is a one-shot iterator, return nothing after the first task.
        rows = list(run.scan_history(keys=keys))
        for key, task in zip(keys, tasks):
            values = [row[key] for row in rows]
            recent = values[-mv_avg_window:]
            # Explicit NaN instead of np.mean([]) (NaN plus a RuntimeWarning).
            data[task].append(np.mean(recent) if recent else np.nan)
            if task in tasks_all_data:
                data[f"{task}_fullhist"].append(values)

    return pd.DataFrame(data)

# Fetch the runs for plots 1 and 3 (default filters spelled out) and show
# the resulting table size.
runs_df = collect_runs(state_filter="finished", mv_avg_window=10)
runs_df.shape

(19, 18)

In [34]:
# Persist the plots 1/3 table for the plotting code.
pd.to_pickle(runs_df, "data_for_plots_1_and_3.p")

In [42]:
# Identical values to the configuration used for plots 1 and 3, re-declared
# here (presumably so this cell can run without re-running that cell's
# expensive run collection). Only entity, project and runs_to_ignore are read
# by collect_runs_plot2 below.
entity = "nupic-research"
project = "multitask_journal"
metric = "SuccessRate"

tasks = ["Average"]
tasks.extend(
    env + "-v2"
    for env in (
        "button-press-topdown",
        "door-open",
        "drawer-close",
        "drawer-open",
        "peg-insert-side",
        "pick-place",
        "push",
        "reach",
        "window-close",
        "window-open",
    )
)

tasks_all_data = ["Average"]
runs_to_ignore = ["clzl304", "phhk331"]

def collect_runs_plot2(state_filter="finished", mv_avg_window=10):
    """Collect every run *outside* the "Strategy 1" group into a DataFrame.

    Runs with no ``wandb_group`` in their config are included (their
    ``group`` is None). Unlike ``collect_runs``, only the "Average" task is
    fetched. Reads the module-level ``api``, ``entity``, ``project``,
    ``metric`` and ``runs_to_ignore``.

    Parameters
    ----------
    state_filter : str
        Only runs whose ``run.state`` equals this value are kept.
    mv_avg_window : int
        The "Average" score is the mean of the last ``mv_avg_window`` logged
        values of ``f"Average/{metric}"``. Expected to be a positive int.

    Returns
    -------
    pandas.DataFrame
        One row per kept run with columns ``id``, ``name``, ``summary``,
        ``config`` (keys starting with ``_`` removed), ``net_type``,
        ``group``, ``Average`` (``NaN`` when nothing was logged) and
        ``Average_fullhist`` (the full list of logged values).
    """
    key = f"Average/{metric}"
    data = defaultdict(list)
    for run in api.runs(entity + "/" + project):
        # .get() yields None for runs without a group, and None != "Strategy 1",
        # so ungrouped runs are kept.
        group = run.config.get("wandb_group")
        if (run.state != state_filter
                or group == "Strategy 1"
                or run.id in runs_to_ignore):
            continue

        # -- Main config
        data["id"].append(run.id)
        data["name"].append(run.name)  # human-readable run name
        # ._json_dict is used to omit large files from the summary.
        data["summary"].append(run.summary._json_dict)
        # Hyperparameters, minus special entries whose keys start with "_".
        data["config"].append(
            {k: v for k, v in run.config.items() if not k.startswith("_")})
        data["net_type"].append(run.config["net_type"])
        data["group"].append(group)

        # -- Success rate for the Average task only
        values = [row[key] for row in run.scan_history(keys=[key])]
        recent = values[-mv_avg_window:]
        # Explicit NaN instead of np.mean([]) (NaN plus a RuntimeWarning).
        data["Average"].append(np.mean(recent) if recent else np.nan)
        data["Average_fullhist"].append(values)

    return pd.DataFrame(data)

runs_df2 = collect_runs_plot2()

# Several runs can share a name (presumably repeated launches of the same
# configuration -- TODO confirm); keep only the run with the highest
# "Average" score for each name. NaN scores cannot be ranked, so they are
# dropped first; a name whose runs are all NaN contributes no run.
best_labels = (
    runs_df2
    .dropna(subset=["Average"])
    .groupby("name", sort=False)["Average"]
    .idxmax()
)
ids_list = runs_df2.loc[best_labels, "id"].tolist()

ids_filter = runs_df2["id"].isin(ids_list)
plot2_df = runs_df2[ids_filter]

runs_df2.shape, plot2_df.shape

(49, 8)

In [82]:
# Persist the best-run-per-name table used for plot 2.
pd.to_pickle(plot2_df, "data_for_plots_2.p")

In [22]:
def get_run(run_id, entity="nupic-research", project="multitask"):
    """Fetch a single W&B run by id through the module-level ``api``.

    Parameters
    ----------
    run_id : str
        The W&B run id.
    entity, project : str
        Where to look the run up. Note the default project is "multitask",
        not the "multitask_journal" project queried by the collection cells;
        pass ``project=...`` to override.

    Returns
    -------
    Whatever ``api.run("<entity>/<project>/<run_id>")`` returns.
    """
    return api.run(f"{entity}/{project}/{run_id}")

# Fetch a single run (from get_run's default "multitask" project) for inspection.
run_id = "hcob816"
run = get_run(run_id)

In [38]:
# Reuse `get_run` from the earlier cell rather than redefining it here.
run_id = "kidg700"
run = get_run(run_id)