In [2]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import wandb

In [3]:
# Placeholder for the wandb path whose runs are analysed below.
# Replace it with your own wandb account before running
# (wandb's `Api.runs` documents the path format as "entity/project").
WANDB_PATH = "[ENTER YOUR WANDB ACCOUNT]"

api = wandb.Api()
runs = api.runs(WANDB_PATH)

In [4]:
from collections import defaultdict

# Per-run learning curves, grouped as {"env": {"algo": [DataFrame, DataFrame, ...]}}.
# Each DataFrame has the columns "Samples" (global step) and "Return" (episodic return).
data = defaultdict(lambda: defaultdict(list))

# (is_munchausen, is_weighted) -> algorithm name used as the key in `data`.
ALGO_NAMES = {
    (True, True): "DVW M-DQN",
    (True, False): "M-DQN",
    (False, True): "DVW DQN",
    (False, False): "DQN",
}

for run in runs:
    config = run.config

    if config["env_random_prob"] != 0:  # we use the results with env_random_prob == 0
        continue

    # A run counts as Munchausen only when both the KL and the entropy coefficient are non-zero.
    is_munchausen = bool(config["kl_coef"] != 0 and config["ent_coef"] != 0)
    is_weighted = bool(config["weight_type"] == "variance-net")

    if is_weighted and config["weight_epsilon"] != 0.1:  # we use the results with weight_epsilon == 0.1
        continue

    algo = ALGO_NAMES[(is_munchausen, is_weighted)]

    history = run.history(keys=["charts/episodic_return", "global_step"])
    if history.empty:
        # Nothing was logged for this run, so it has no curve to contribute.
        continue

    # Take the two columns directly instead of copying them row by row.
    # Casting to float keeps the dtypes independent of how the history columns were typed.
    df = pd.DataFrame(
        {
            "Samples": history["global_step"].to_numpy(dtype=float),
            "Return": history["charts/episodic_return"].to_numpy(dtype=float),
        }
    )
    data[config["env_id"]][algo].append(df)

In [5]:
# Smoothed curves pooled over all runs of an algorithm: {"env": {"algo": DataFrame}}.
data2 = defaultdict(dict)

for env, runs_by_algo in data.items():
    for algo, run_frames in runs_by_algo.items():
        smoothed_frames = [
            pd.DataFrame(
                {
                    # Negative decimals round to the left of the decimal point:
                    # -5 snaps every step count to a multiple of 100,000, so rows
                    # from different runs end up on shared x values.
                    "Samples": frame["Samples"].round(-5),
                    # Running mean over up to 50 consecutive rows; min_periods=1
                    # keeps the first rows instead of turning them into NaN.
                    "Return": frame["Return"].rolling(50, min_periods=1).mean(),
                }
            )
            for frame in run_frames
        ]
        # Every run carries its own 0..n-1 index; ignore_index=True gives the
        # pooled frame a fresh unique index instead of repeated labels.
        data2[env][algo] = pd.concat(smoothed_frames, ignore_index=True)

In [12]:
# One fixed color per algorithm, taken from the first four entries of the
# current seaborn palette, in the order the curves are drawn.
_palette = sns.color_palette()
colors = {
    name: _palette[position]
    for position, name in enumerate(("DVW M-DQN", "M-DQN", "DVW DQN", "DQN"))
}

def plot_result(env, ylim, target="both", label=True, title=True):
    """Plot the pooled learning curves of one environment and save them as PNG and PDF.

    The curves are read from the module-level ``data2`` mapping and colored
    according to the module-level ``colors`` mapping. The files are written to
    ``"{env}-Both"``, ``"{env}-M-DQN"`` or ``"{env}-DQN"`` (depending on
    ``target``) with the extensions ``.png`` and ``.pdf``.

    Args:
        env: Key into ``data2``; also used in the file name and, capitalized,
            as the plot title.
        ylim: Upper limit of the y-axis.
        target: Which algorithms to draw: ``"both"`` (every entry of
            ``colors``), ``"munchausen"`` (names containing ``"M-"``) or
            ``"vanilla"`` (names not containing ``"M-"``).
        label: If true, label each curve with its algorithm name and draw a
            legend in the upper left corner.
        title: If true, put the capitalized environment name above the plot.

    Raises:
        ValueError: If ``target`` is not one of the supported values.
        KeyError: If ``data2[env]`` has no entry for a selected algorithm.
    """
    # target -> (predicate selecting algorithm names, file name suffix)
    selectors = {
        "both": (lambda algo: True, "Both"),
        "munchausen": (lambda algo: "M-" in algo, "M-DQN"),
        "vanilla": (lambda algo: "M-" not in algo, "DQN"),
    }
    # Validate before any figure is created so a bad argument leaves nothing open.
    if not isinstance(target, str) or target not in selectors:
        raise ValueError(
            f"unknown target {target!r}; expected one of {sorted(selectors)}"
        )
    cond, suffix = selectors[target]
    filename = f"{env}-{suffix}"

    sns.set(font_scale=1.4)
    fig = plt.figure(figsize=(7, 5))
    try:
        with sns.axes_style("whitegrid"):
            # The axes must be created inside the style context to pick up "whitegrid".
            ax = fig.add_subplot()

            for algo, color in colors.items():
                if not cond(algo):
                    continue
                # Look the curve up only for selected algorithms, so a missing
                # entry for an algorithm that is filtered out cannot fail the plot.
                sns.lineplot(
                    data=data2[env][algo],
                    x="Samples",
                    y="Return",
                    errorbar=("sd", 0.5),
                    estimator="mean",
                    label=algo if label else None,
                    color=color,
                    ax=ax,
                )

            # Set the limit once, after all curves are drawn: set_ylim switches
            # y autoscaling off, so calling it after the first curve would freeze
            # the lower limit before the remaining curves are known.
            ax.set_ylim(top=ylim)
            ax.set_xlabel("Samples", fontsize=21)
            ax.set_ylabel("Return", fontsize=21)

            if label:
                ax.legend(loc="upper left")

        if title:
            ax.set_title(env.capitalize(), fontsize=21)
        fig.savefig(f"{filename}.png", bbox_inches="tight")
        fig.savefig(f"{filename}.pdf", bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)

In [None]:
# Breakout: all three figure variants, drawn with a legend (label=True).
env = "breakout"
for target in ("both", "vanilla", "munchausen"):
    plot_result(env, ylim=95, target=target, label=True)

In [None]:
# Seaquest: all three figure variants, drawn without a legend (label=False).
env = "seaquest"
for target in ("both", "vanilla", "munchausen"):
    plot_result(env, ylim=120, target=target, label=False)

In [None]:
# Freeway: all three figure variants, drawn without a legend (label=False).
env = "freeway"
for target in ("both", "vanilla", "munchausen"):
    plot_result(env, ylim=75, target=target, label=False)

In [10]:
# Space Invaders: all three figure variants, drawn without a legend (label=False).
env = "space_invaders"
for target in ("both", "vanilla", "munchausen"):
    plot_result(env, ylim=300, target=target, label=False)

In [11]:
# Asterix: all three figure variants, drawn without a legend (label=False).
env = "asterix"
for target in ("both", "vanilla", "munchausen"):
    plot_result(env, ylim=55, target=target, label=False)