# In [1]:
from dataclasses import dataclass
import expt
import matplotlib.pyplot as plt
import wandb
from expt import Hypothesis, Run
from expt.plot import GridPlot
import numpy as np

# Per-request timeout handed to wandb.Api (seconds, per the wandb docs).
WANDB_API_TIMEOUT = 60
# Module-level W&B public-API client; download_metrics queries runs through it.
api = wandb.Api(timeout=WANDB_API_TIMEOUT)


# Game names (CamelCase, alphabetical); presumably the 16 Procgen environments
# that the "ppo_procgen" experiments were run on.
env_ids = """
    BigFish BossFight CaveFlyer Chaser Climber CoinRun Dodgeball FruitBot
    Heist Jumper Leaper Maze Miner Ninja Plunder StarPilot
""".split()


@dataclass
class Experiment:
    """Episodic-return learning curves of several runs plus their shared x-axis."""

    # Return values per run; expected to be 2-D (runs x logged points) when
    # every run has the same number of points.
    learning_curve: np.ndarray
    # Step values for the columns of learning_curve (the x-axis when plotting).
    global_step: np.ndarray

    @staticmethod
    def _centered_moving_average(row, window_size):
        """Return the centred moving average of one 1-D curve, same length as the input."""
        values = np.asarray(row, dtype=float)
        if values.size == 0:
            return values
        # np.convolve(mode="same") returns max(len(a), len(v)) samples, so a
        # window wider than the curve would lengthen it; clip the window.
        kernel = np.ones(min(window_size, values.size))
        # mode="same" zero-pads both ends, so dividing by the full window width
        # would drag the first and last ~window/2 points toward zero. Divide by
        # the number of real samples under each window position instead.
        support = np.convolve(np.ones(values.size), kernel, mode="same")
        return np.convolve(values, kernel, mode="same") / support

    def smoothed_learning_curve(self, window_size=50):
        """Return every run's curve smoothed with a centred moving average.

        Args:
            window_size: Moving-average width in logged points. It is clipped
                to each curve's length, so short curves keep their length.

        Returns:
            An array with one smoothed row per row of ``learning_curve``; each
            row has the same length as the corresponding input row.

        Raises:
            ValueError: if ``window_size`` is smaller than 1.
        """
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        return np.array(
            [self._centered_moving_average(row, window_size) for row in self.learning_curve]
        )


def _fetch_return_history(run):
    """Return (global_step, episodic_return) arrays for one W&B run, sorted by step.

    Rows whose step or return is missing or None are skipped.
    """
    steps, returns = [], []
    for row in run.history(keys=["global_step", "charts/episodic_return"], pandas=False):
        step = row.get("global_step")
        episodic_return = row.get("charts/episodic_return")
        if step is None or episodic_return is None:
            continue
        steps.append(step)
        returns.append(episodic_return)
    steps = np.asarray(steps)
    returns = np.asarray(returns, dtype=float)
    # np.interp (used by the caller) needs increasing x values.
    order = np.argsort(steps, kind="stable")
    return steps[order], returns[order]


def download_metrics(env: str = "bigfish", entity: str = "cswinter", project: str = "cleanRL", filters=None) -> Experiment:
    """Download the episodic-return curves of all matching ``ppo_procgen`` runs.

    Args:
        env: Value the runs' ``config.env_id`` must equal.
        entity: W&B entity that owns the project.
        project: W&B project name.
        filters: Optional list of extra MongoDB-style filter clauses, AND-ed
            with the env_id and exp_name conditions.

    Returns:
        An Experiment with one ``learning_curve`` row per run that logged at
        least one return, all sampled at ``global_step`` (the first such
        run's steps).

    Raises:
        ValueError: if no matching run logged any episodic return.
    """
    query = {
        "$and": [
            {"config.env_id.value": env},
            {"config.exp_name.value": "ppo_procgen"},
            *(filters or []),
        ]
    }
    runs = api.runs(f"{entity}/{project}", filters=query)
    histories = [h for h in (_fetch_return_history(run) for run in runs) if h[0].size]
    if not histories:
        raise ValueError(
            f"no ppo_procgen run with episodic returns for env={env!r} in {entity}/{project}"
        )
    # Runs can log different numbers of points at different steps, which would
    # make the stacked array ragged and misaligned with a single x-axis.
    # NOTE(review): wandb's Run.history also samples rows rather than returning
    # every row; consider Run.scan_history if full resolution matters.
    # Resample every run onto the first run's steps; when all runs share those
    # steps np.interp returns the logged values unchanged.
    grid = histories[0][0]
    curves = np.array([np.interp(grid, steps, returns) for steps, returns in histories])
    return Experiment(curves, grid)

# Plot the learning curves with min/max shaded
def plot(experiments: list[Experiment], labels=None, window_size=50):
    """Draw each experiment's smoothed mean return with a shaded min/max band.

    Args:
        experiments: Experiments drawn on one shared axis.
        labels: Optional list with one display name per experiment. When it
            is omitted, a single experiment keeps the plain "Mean" / "Min/Max"
            legend entries and several experiments are numbered from 1 so
            their legend entries stay distinguishable.
        window_size: Moving-average window passed to
            ``Experiment.smoothed_learning_curve``.

    Returns:
        The matplotlib figure holding the plot.

    Raises:
        ValueError: if ``labels`` does not have one entry per experiment.
    """
    experiments = list(experiments)
    if labels is None:
        if len(experiments) > 1:
            labels = [f"Experiment {i}" for i in range(1, len(experiments) + 1)]
        else:
            labels = [None] * len(experiments)
    elif len(labels) != len(experiments):
        raise ValueError(f"got {len(labels)} labels for {len(experiments)} experiments")

    fig, ax = plt.subplots()
    # Increase plot size (10 x 5 inches)
    fig.set_size_inches(10, 5)
    for xp, name in zip(experiments, labels):
        smoothed = xp.smoothed_learning_curve(window_size=window_size)
        (mean_line,) = ax.plot(
            xp.global_step,
            smoothed.mean(axis=0),
            label="Mean" if name is None else f"{name} (mean)",
        )
        # Reuse the line's colour so each band is visibly tied to its mean curve.
        ax.fill_between(
            xp.global_step,
            smoothed.min(axis=0),
            smoothed.max(axis=0),
            color=mean_line.get_color(),
            alpha=0.2,
            label="Min/Max" if name is None else f"{name} (min/max)",
        )
    ax.set_xlabel("Training Steps")
    ax.set_ylabel("Episodic Return")
    ax.legend()
    return fig

env_id = "BigFish"
# Both queries below match config.env_id against the lower-cased name
# ("bigfish"); derive it from env_id so changing the game here updates every
# query instead of leaving stale hard-coded copies.
wandb_env = env_id.lower()
baseline = download_metrics(env=wandb_env)
enn_ppo = download_metrics(
    env=wandb_env,
    entity="entity-neural-network",
    project="enn-ppo",
    # NOTE(review): unlike the env_id/exp_name filters in download_metrics, this
    # key has no ".value" suffix — confirm it matches how enn-ppo stores base_name.
    filters=[{"config.base_name": f"221010-010215-procgen-baselines-mode=easy-env={env_id}"}],
)

# [stderr] from .autonotebook import tqdm as notebook_tqdm


# [output] HTTPError: 500 Server Error: Internal Server Error for url: https://api.wandb.ai/graphql