In [1]:
from collections import defaultdict
from glob import glob
import os
import yaml

from tensorboard.backend.event_processing import event_accumulator
import numpy as np
import plotly.graph_objects as go

In [2]:
def get_color(idx, alpha):
    """Return a plotly ``rgba(...)`` colour string from a fixed 10-colour palette.

    Args:
        idx: Palette position; wraps around modulo the palette size.
        alpha: Opacity, inserted verbatim via ``str(alpha)``.

    Returns:
        A string such as ``"rgba(99, 110, 250,0.5)"``.
    """
    palette = (
        (99, 110, 250),
        (239, 85, 59),
        (0, 204, 150),
        (171, 99, 250),
        (255, 161, 90),
        (25, 211, 243),
        (255, 102, 146),
        (182, 232, 128),
        (255, 151, 255),
        (254, 203, 82),
    )
    red, green, blue = palette[idx % len(palette)]
    return f"rgba({red}, {green}, {blue},{alpha!s})"


def extract_scalar_values(logdir, metric_name):
    """Load every logged value of one scalar tag from a run's TensorBoard log.

    Args:
        logdir: Run directory. The event file is looked up as
            ``logdir/<one subdirectory>/events.*``: ``glob`` is called without
            ``recursive=True``, so ``**`` matches exactly one path segment.
        metric_name: Scalar tag to read, e.g. ``"rollout/ep_rew_mean"``.

    Returns:
        List of the scalar values, in the order ``EventAccumulator.Scalars``
        returns them.

    Raises:
        AssertionError: if not exactly one event file is found.
    """
    tensorboard_files = glob(os.path.join(logdir, "**", "events.*"))
    assert len(tensorboard_files) == 1, (
        f"Expected exactly one tensorboard event file under {logdir!r}, "
        f"found {len(tensorboard_files)}: {tensorboard_files}"
    )
    # SCALARS: 0 keeps every scalar event. The default size guidance keeps at
    # most 10,000 per tag (reservoir sampled), which would silently drop points
    # and misalign the per-index averaging across seeds.
    event_acc = event_accumulator.EventAccumulator(
        tensorboard_files[0], size_guidance={event_accumulator.SCALARS: 0}
    )
    event_acc.Reload()
    return [event.value for event in event_acc.Scalars(metric_name)]


def plot_curves(curves, title=None, metric_name="Return"):
    """Show mean curves with shaded confidence bands in a single plotly figure.

    Args:
        curves: Iterable of ``(run_name, x_values, y_values, lower_bound,
            upper_bound)`` tuples. Each tuple becomes a solid line for
            ``y_values`` plus a shaded band between ``lower_bound`` and
            ``upper_bound``, coloured by its position in ``curves``.
        title: Figure title.
        metric_name: Y-axis label.

    Returns:
        None -- the figure is displayed with ``Figure.show()``.
    """
    learning_curve = go.Figure()
    for idx, (run_name, x_values, y_values, lower_bound, upper_bound) in enumerate(
        curves
    ):
        # One legend group per curve, so clicking a legend entry toggles the
        # mean line AND its band instead of leaving an orphaned band behind.
        group = f"curve-{idx}"
        learning_curve.add_trace(
            go.Scatter(
                x=x_values,
                y=y_values,
                mode="lines",
                name=f"{run_name}",
                legendgroup=group,
                line=dict(color=get_color(idx, 1)),
            )
        )
        # Band edges are drawn almost transparent; only the fill is meant to show.
        learning_curve.add_trace(
            go.Scatter(
                x=x_values,
                y=lower_bound,
                fill=None,
                mode="lines",
                showlegend=False,
                legendgroup=group,
                line=dict(color=get_color(idx, 0.01)),
            )
        )
        # fill="tonexty" shades down to the previously added trace (the lower
        # bound above), so this trace must be added directly after it.
        learning_curve.add_trace(
            go.Scatter(
                x=x_values,
                y=upper_bound,
                fill="tonexty",
                mode="lines",
                showlegend=False,
                legendgroup=group,
                line=dict(color=get_color(idx, 0.01)),
            )
        )

    learning_curve.update_layout(
        title=title,
        xaxis_title="Step",
        yaxis_title=metric_name,
        legend=dict(x=0, y=1),
        template="plotly",
        width=800,
        height=500,
    )

    learning_curve.show()

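An *experiment* compares different *trials* (configurations). Each trial consists of several *runs*, one per random seed, which are averaged into a single learning curve with a confidence band.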

In [3]:
def _ragged_mean_and_sem(values_list):
    """Per-index mean and standard error over runs of possibly unequal length.

    Index ``i`` is aggregated over only the runs that have more than ``i``
    values, so shorter runs stop contributing instead of truncating the curve.

    Args:
        values_list: List of 1-D numeric sequences, one per run.

    Returns:
        ``(mean, sem)`` float arrays of length ``max(len(v) for v in values_list)``
        (length 0 when there are no runs or all runs are empty). ``sem`` is the
        sample standard deviation (ddof=1) divided by the square root of the
        number of runs contributing at that index; it is 0 where only one run
        contributes.
    """
    lengths = np.array([len(values) for values in values_list], dtype=int)
    max_length = int(lengths.max()) if lengths.size else 0

    # Zero-pad into a (runs, max_length) matrix with a mask of real entries.
    padded = np.zeros((len(values_list), max_length))
    for row, values in zip(padded, values_list):
        row[: len(values)] = values
    present = np.arange(max_length) < lengths[:, None]
    # Every column has >= 1 contributor: the longest run covers all indices.
    counts = present.sum(axis=0)

    mean = padded.sum(axis=0) / counts
    squared_dev = np.where(present, (padded - mean) ** 2, 0.0)
    # Sample variance (ddof=1); left at 0 where a single run contributes.
    variance = np.divide(
        squared_dev.sum(axis=0),
        counts - 1,
        out=np.zeros(max_length),
        where=counts > 1,
    )
    sem = np.sqrt(variance) / np.sqrt(counts)
    return mean, sem


def get_trial_curves(trial_run_dirs, metric_name, z=1.96):
    """Aggregate one metric over all runs (seeds) of a trial into a mean curve + CI.

    Args:
        trial_run_dirs: Run directories belonging to the trial.
        metric_name: TensorBoard scalar tag to read from each run.
        z: Critical value multiplying the standard error; the default 1.96 is
            the normal-approximation 95% interval. Runs may have different
            lengths; each index uses only the runs that reached it, both for
            the mean and for the ``n`` in the standard error.

    Returns:
        ``(x_values, y_values, lower_bound, upper_bound)``. ``x_values`` is the
        1-based log index ``[1, ..., N]``; the other three are float arrays of
        length ``N``.
    """
    values_list = [
        extract_scalar_values(run_dir, metric_name) for run_dir in trial_run_dirs
    ]
    mean_values, sem = _ragged_mean_and_sem(values_list)
    conf_int = z * sem

    x_values = list(range(1, len(mean_values) + 1))
    return x_values, mean_values, mean_values - conf_int, mean_values + conf_int


def get_trials_and_runs(
    experiment_name,
    config_queries: dict,
    key_to_compare: str,
    metric_name: str,
    runs_dir: str = "runs",
):
    """Group matching runs into trials and plot one learning curve per trial.

    Scans ``<runs_dir>/*/configs.yaml``. A run matches when, for every
    ``key: value`` in ``config_queries``, ``config.get(key) == value``.
    Matching runs are grouped by ``config[key_to_compare]`` (one trial per
    distinct value) and read from ``<runs_dir>/<config["run_id"]>``.

    Args:
        experiment_name: Prefix of the figure title.
        config_queries: Config key/value pairs a run must match.
        key_to_compare: Config key whose value defines the trials.
        metric_name: TensorBoard scalar tag to plot.
        runs_dir: Directory holding one sub-directory per run.

    Returns:
        Whatever ``plot_curves`` returns (it displays the figure).
    """
    trials = defaultdict(list)
    # glob order is filesystem-dependent; sorting keeps run order reproducible.
    for config_file in sorted(glob(os.path.join(runs_dir, "*", "configs.yaml"))):
        with open(config_file, "r") as f:
            # NOTE(review): FullLoader can build some Python-specific objects;
            # yaml.safe_load is safer if these configs only hold plain data.
            config = yaml.load(f, Loader=yaml.FullLoader)
        if all(config.get(key) == value for key, value in config_queries.items()):
            trials[config[key_to_compare]].append(
                os.path.join(runs_dir, config["run_id"])
            )

    if not trials:
        print(f"No runs in {runs_dir!r} matched {config_queries}")

    # Order trials by the compared value: colours are assigned by position, so
    # this keeps legend order and colour per value stable across figures.
    try:
        ordered_trials = sorted(trials.items(), key=lambda item: item[0])
    except TypeError:  # unorderable mix of key types -> keep discovery order
        ordered_trials = list(trials.items())

    curves = []
    for trial_name, run_dirs in ordered_trials:
        print(f"Found {len(run_dirs)} runs for {trial_name}: {run_dirs}")
        curves.append(
            (
                f"{trial_name} ({len(run_dirs)} seeds)",
                *get_trial_curves(run_dirs, metric_name),
            )
        )

    return plot_curves(
        curves,
        title=f"{experiment_name} {metric_name.split('/')[-1]}",
        metric_name=metric_name,
    )

In [9]:
# NOTE: env shuffle steps may have influenced these results.
# One figure per metric (a loop instead of copy-pasted calls).
for metric in ("rollout/ep_rew_mean", "rollout/success_rate"):
    get_trials_and_runs(
        "2x2 Normalized",
        {"env_variation": "normalized"},
        "env_h",
        metric,
    )

Found 3 runs for 2: ['runs/20240121-211534', 'runs/20240121-213457', 'runs/20240121-215623']
Found 3 runs for 3: ['runs/20240121-221624', 'runs/20240121-223822', 'runs/20240121-225912']
Found 3 runs for 4: ['runs/20240121-232028', 'runs/20240121-234316', 'runs/20240122-000706']
Found 3 runs for 5: ['runs/20240122-003016', 'runs/20240122-005711', 'runs/20240122-012249']


Found 3 runs for 2: ['runs/20240121-211534', 'runs/20240121-213457', 'runs/20240121-215623']
Found 3 runs for 3: ['runs/20240121-221624', 'runs/20240121-223822', 'runs/20240121-225912']
Found 3 runs for 4: ['runs/20240121-232028', 'runs/20240121-234316', 'runs/20240122-000706']
Found 3 runs for 5: ['runs/20240122-003016', 'runs/20240122-005711', 'runs/20240122-012249']


In [10]:
# One figure per metric (a loop instead of copy-pasted calls).
for metric in ("rollout/ep_rew_mean", "rollout/success_rate"):
    get_trials_and_runs(
        "2x2 Single Image",
        {
            "policy": "CnnPolicy",
            "env_variation": "image",
            "env_image_folder": "./imgs/single",
        },
        "env_h",
        metric,
    )

Found 3 runs for 2: ['runs/20240121-211511', 'runs/20240121-215723', 'runs/20240121-224203']
Found 3 runs for 3: ['runs/20240121-234457', 'runs/20240122-003003', 'runs/20240122-011335']
Found 3 runs for 4: ['runs/20240122-020307', 'runs/20240122-024853', 'runs/20240122-033339']
Found 3 runs for 5: ['runs/20240122-041809', 'runs/20240122-050008', 'runs/20240122-081702']


Found 3 runs for 2: ['runs/20240121-211511', 'runs/20240121-215723', 'runs/20240121-224203']
Found 3 runs for 3: ['runs/20240121-234457', 'runs/20240122-003003', 'runs/20240122-011335']
Found 3 runs for 4: ['runs/20240122-020307', 'runs/20240122-024853', 'runs/20240122-033339']
Found 3 runs for 5: ['runs/20240122-041809', 'runs/20240122-050008', 'runs/20240122-081702']


In [13]:
# One figure per metric (a loop instead of copy-pasted calls).
for metric in ("rollout/ep_rew_mean", "rollout/success_rate"):
    get_trials_and_runs(
        "Mnist",
        {
            "policy": "CnnPolicy",
            "env_variation": "image",
            "env_image_folder": "./imgs/mnist/trainingSet",
        },
        "env_h",
        metric,
    )

Found 3 runs for 3: ['runs/20240121-220126', 'runs/20240121-224017', 'runs/20240121-231922']
Found 3 runs for 4: ['runs/20240122-000016', 'runs/20240122-004921', 'runs/20240122-013311']
Found 3 runs for 5: ['runs/20240122-021834', 'runs/20240122-030800', 'runs/20240122-035510']
Found 5 runs for 2: ['runs/20240122-101921', 'runs/20240122-103007', 'runs/20240122-104245', 'runs/20240122-130104', 'runs/20240122-144452']


Found 3 runs for 3: ['runs/20240121-220126', 'runs/20240121-224017', 'runs/20240121-231922']
Found 3 runs for 4: ['runs/20240122-000016', 'runs/20240122-004921', 'runs/20240122-013311']
Found 3 runs for 5: ['runs/20240122-021834', 'runs/20240122-030800', 'runs/20240122-035510']
Found 5 runs for 2: ['runs/20240122-101921', 'runs/20240122-103007', 'runs/20240122-104245', 'runs/20240122-130104', 'runs/20240122-144452']


In [7]:
# One figure per metric (a loop instead of copy-pasted calls).
for metric in ("rollout/ep_rew_mean", "rollout/success_rate"):
    get_trials_and_runs(
        "Imagenet1k",
        {
            "policy": "CnnPolicy",
            "env_variation": "image",
            "env_image_folder": "./imgs/imagenet-1k/val_images",
        },
        "env_h",
        metric,
    )

Found 5 runs for 2: ['runs/20240121-211603', 'runs/20240121-224929', 'runs/20240121-232617', 'runs/20240122-001258', 'runs/20240122-011208']
Found 6 runs for 3: ['runs/20240122-011037', 'runs/20240122-015902', 'runs/20240122-021437', 'runs/20240122-024512', 'runs/20240122-030110', 'runs/20240122-034553']
Found 6 runs for 4: ['runs/20240122-032923', 'runs/20240122-041441', 'runs/20240122-042946', 'runs/20240122-045722', 'runs/20240122-050858', 'runs/20240122-054417']
Found 6 runs for 5: ['runs/20240122-053358', 'runs/20240122-060525', 'runs/20240122-061357', 'runs/20240122-063613', 'runs/20240122-064422', 'runs/20240122-071307']


Found 5 runs for 2: ['runs/20240121-211603', 'runs/20240121-224929', 'runs/20240121-232617', 'runs/20240122-001258', 'runs/20240122-011208']
Found 6 runs for 3: ['runs/20240122-011037', 'runs/20240122-015902', 'runs/20240122-021437', 'runs/20240122-024512', 'runs/20240122-030110', 'runs/20240122-034553']
Found 6 runs for 4: ['runs/20240122-032923', 'runs/20240122-041441', 'runs/20240122-042946', 'runs/20240122-045722', 'runs/20240122-050858', 'runs/20240122-054417']
Found 6 runs for 5: ['runs/20240122-053358', 'runs/20240122-060525', 'runs/20240122-061357', 'runs/20240122-063613', 'runs/20240122-064422', 'runs/20240122-071307']
