In [2]:
import wandb
import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Timestamped file stem for the saved figure, e.g. "icml_3d_24-01-31-12-00-00".
timestamp = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
img_name = "icml_3d_" + timestamp

# W&B location of the runs, plus the tags every selected run must carry.
entity = "ocrl_benchmark"
project = "agent-learning"
default_tags = ["toplot", "sparserewtype", "hardmode"]

# Environment grid: one row per mode, one column per type. Each cell holds
# the [mode, type] tag pair; dtype=object lets a cell store a Python list.
env_modes = ["hardmode"]
env_types = ["cwtargetN4C3S1S1"]
envs = np.zeros((len(env_modes), len(env_types)), dtype=object)
for mode_idx, env_mode in enumerate(env_modes):
    for type_idx, env_type in enumerate(env_types):
        envs[mode_idx, type_idx] = [env_mode, env_type]

# Panel titles, indexed by grid column.
titles = ["Object Reaching Task"]

# Colormap: "tab20" is a ListedColormap whose `.colors` is a flat sequence of
# RGB tuples. Consecutive entries (0/1, 2/3, ...) are used below as the
# line colour and the matching shaded-range colour of one model.
# `plt.get_cmap` replaces `plt.cm.get_cmap`, which was deprecated in
# Matplotlib 3.7 and removed in 3.9.
cm = plt.get_cmap("tab20").colors

# Per-model plotting configuration, keyed by the label shown in the legend.
#   tags     - W&B tags a run must carry to be counted for this model
#   line_cm  - colour of the mean curve
#   range_cm - colour of the shaded band around the mean
# "marker", "fill_style" and "line_style" are kept here but are not read by
# the plotting code visible in this cell.
models = {
    "GT": {
        "tags": ["gt", "pooling-transformer", "toplot1"],
        "line_cm": cm[2],
        "range_cm": cm[3],
        "marker": "o",
        "fill_style": "full",
        "line_style": "dotted",
    },
    "CNN": {
        "tags": ["e2ecnn", "pooling-identity", "toplot1"],
        "line_cm": cm[4],
        "range_cm": cm[5],
        "marker": "v",
        "fill_style": "full",
        "line_style": "dashed",
    },
    "SLATE": {
        "tags": ["slate", "ocr", "pooling-transformer", "toplot1"],
        "line_cm": cm[6],
        "range_cm": cm[7],
        "marker": "s",
        "fill_style": "full",
        "line_style": "solid",
    },
    "VAE": {
        "tags": ["vae", "kld5", "singlevector", "pooling-mlp", "toplot1"],
        "line_cm": cm[0],
        "range_cm": cm[1],
        "marker": "P",
        "fill_style": "full",
        "line_style": "dashdot",
    },
}

print(f"Models: {models.keys()}")

# Manual subplot layout in figure-fraction coordinates: margins and gaps are
# fixed, and the remaining space is split evenly between the panels.
n_rows, n_cols = envs.shape

# Vertical geometry. A single-row figure gets a taller bottom margin.
top = 0.1
hspace = 0.13
if n_rows != 1:
    bottom = 0.15
else:
    bottom = 0.27
height = (1 - top - bottom - hspace * (n_rows - 1)) / n_rows

# Horizontal geometry.
left = 0.15
wspace = 0.07
right = 0.036
width = (1 - left - right - wspace * (n_cols - 1)) / n_cols

# recs[row, col] = [left, bottom, width, height] of that panel. Row 0 is the
# top row, so its bottom edge is the highest one.
recs = np.zeros(envs.shape, dtype=object)
for row, col in np.ndindex(recs.shape):
    panel_left = left + col * (width + wspace)
    panel_bottom = bottom + (recs.shape[0] - 1 - row) * (height + hspace)
    recs[row, col] = [panel_left, panel_bottom, width, height]

plt.figure(figsize=(6.3 * recs.shape[1], 4.5 * recs.shape[0]))

# Fetch every model's evaluation curve from W&B, aggregate it across the
# matching runs, and draw one mean +/- std curve per model on each panel.
#
# `results` ends up as
#     results[model_name][env_type] = {"step": [...], "mean": [...], "std": [...]}
# holding plain Python lists so it can be written out with json.dump.
# A model with no usable runs keeps an empty dict for that environment.

# Maximum number of evaluation points kept per run (1004000 // 4000 == 251).
# NOTE(review): presumably one evaluation every 4000 steps - confirm.
step = 1004000 // 4000

results = {}
api = wandb.Api(timeout=19)
for i in range(envs.shape[0]):
    for j in range(envs.shape[1]):
        env_tags = envs[i, j]
        env_key = env_tags[-1]  # the env type names the entry in `results`
        print(env_tags)
        ax = plt.axes(recs[i, j])
        for model_name, model_infos in models.items():
            # Nest as results[model][env]; the entry has to exist before the
            # statistics are stored below.
            env_results = results.setdefault(model_name, {}).setdefault(env_key, {})

            # A run is selected only if it carries all environment, model and
            # default tags.
            all_tags = [*env_tags, *model_infos["tags"], *default_tags]
            filters = [{"tags": tag} for tag in all_tags]

            # IODINE runs live in a different project; every other model is
            # read from the configured entity/project.
            if model_name == "IODINE":
                run_path = "ocrl_benchmark/ocrl-synthetic-results"
            else:
                run_path = f"{entity}/{project}"
            runs = api.runs(run_path, filters={"$and": filters})

            global_steps = []
            success_rates = []
            for run in runs:
                print(run)
                history = run.scan_history(["global_step", "eval/success_rate"])
                run_steps = []
                run_rates = []
                for row in history:
                    run_steps.append(row["global_step"])
                    run_rates.append(row["eval/success_rate"])
                if not run_rates:
                    # Nothing was logged, so there is no last value to pad
                    # with; skip the run instead of failing on an empty list.
                    continue
                global_steps.append(run_steps[:step])
                success_rates.append(run_rates[:step])
            if not success_rates:  # no logs
                continue

            # Use the longest recorded step axis as x values and pad shorter
            # runs with their last success rate so all curves share a length.
            sr_steps = max(global_steps, key=len)
            n_points = len(sr_steps)
            success_rates = np.array(
                [
                    rates + [rates[-1]] * (n_points - len(rates))
                    for rates in success_rates
                ]
            )
            sr_mean = np.mean(success_rates, axis=0)
            sr_std = np.std(success_rates, axis=0)

            # Lists rather than ndarrays: json cannot serialize numpy arrays.
            env_results["step"] = list(sr_steps)
            env_results["mean"] = sr_mean.tolist()
            env_results["std"] = sr_std.tolist()

            (line,) = ax.plot(
                sr_steps,
                sr_mean,
                color=model_infos["line_cm"],
                label=model_name,
                linewidth=2,
            )
            # `ax` is the current axes, so the pyplot call draws on this panel.
            plt.fill_between(
                sr_steps,
                sr_mean - sr_std,
                sr_mean + sr_std,
                color=model_infos["range_cm"],
                alpha=0.3,
            )
            # Keep the handle so the figure-level legend can reference it.
            models[model_name]["line"] = line

        ax.grid(True)
        ax.xaxis.offsetText.set_fontsize(20)
        ax.set_ylim([0.0, 1.0])
        ax.set_xlim([0.0, 1e6])
        plt.yticks(fontsize=20)
        # Only the left-most column shows the y label and y tick labels.
        if j == 0:
            ax.set_ylabel("Success Rate", fontsize=20)
        else:
            ax.yaxis.set_ticklabels([])
            ax.yaxis.set_ticks_position("none")
        # Only the bottom row shows the x label and x tick labels.
        if i == envs.shape[0] - 1:
            plt.xticks(fontsize=18)
            ax.set_xlabel("Interaction Steps", fontsize=20)
        else:
            ax.xaxis.set_ticklabels([])
            ax.xaxis.set_ticks_position("none")
        ax.set_title(titles[j], fontsize=20)

import json
import os

# One figure-level legend built from the line handles stored on `models`
# while plotting; models that never got a curve have no "line" entry and are
# left out.
legended = [info["line"] for info in models.values() if "line" in info]
plt.figlegend(
    loc="lower center",
    ncol=4,
    fontsize=16,
    frameon=False,
    handles=legended,
)
plt.savefig(img_name + ".png", dpi=300)
# The figure is deliberately not closed here (no plt.close()).

# Write the aggregated curves next to the figure. Create the output
# directory first so a missing "results/" folder does not fail the cell after
# all runs have been fetched.
results_path = "results/3d_globalstep.json"
os.makedirs(os.path.dirname(results_path), exist_ok=True)
with open(results_path, "w", encoding="utf-8") as f:
    json.dump(results, f)


Models: dict_keys(['GT', 'CNN', 'SLATE', 'VAE'])
['hardmode', 'cwtargetN4C3S1S1']
<Run ocrl_benchmark/yifu-ocrl-cw-results/3bciaumy (killed)>
<Run ocrl_benchmark/yifu-ocrl-cw-results/2szy6l43 (killed)>
<Run ocrl_benchmark/yifu-ocrl-cw-results/o3wb6afr (killed)>
<Run ocrl_benchmark/yifu-ocrl-cw-results/1ua7ueh9 (crashed)>
<Run ocrl_benchmark/yifu-ocrl-cw-results/1laj1mkd (crashed)>
<Run ocrl_benchmark/yifu-ocrl-cw-results/2sgosr9b (killed)>
<Run ocrl_benchmark/yifu-ocrl-cw-results/3anfznvv (failed)>
<Run ocrl_benchmark/yifu-ocrl-cw-results/2fjlue6p (failed)>
<Run ocrl_benchmark/yifu-ocrl-cw-results/3j71q2a4 (killed)>
<Run ocrl_benchmark/yifu-ocrl-cw-results/10kbcvlp (failed)>
<Run ocrl_benchmark/yifu-ocrl-cw-results/1jmnlp7i (failed)>
<Run ocrl_benchmark/yifu-ocrl-cw-results/3ie5zsop (failed)>
