In [2]:
from collections import defaultdict
from glob import glob
import json
import os
import pathlib
from typing import List, Dict
import yaml

from tensorboard.backend.event_processing import event_accumulator
import numpy as np
import plotly.graph_objects as go

# Root directory scanned for runs: one sub-directory per run, each holding a
# config.yaml (read by get_trials_and_runs) and a metrics.jsonl (read by
# get_trial_curves).
BASE_PATH = "logdir"


def _construct_posix_path(loader, node):
    """Rebuild a ``pathlib.Path`` from a ``python/object/apply:pathlib.PosixPath`` YAML node.

    The node's sequence items are the path parts, so they are splatted into
    ``pathlib.Path``. Registering this lets ``yaml.FullLoader`` read run
    configs that serialized paths as PosixPath objects.
    """
    path_parts = loader.construct_sequence(node)
    return pathlib.Path(*path_parts)


yaml.add_constructor(
    "tag:yaml.org,2002:python/object/apply:pathlib.PosixPath",
    _construct_posix_path,
)

In [3]:
def ci95(std_dev, n_trials: int):
    """Half-width of a normal-approximation 95% confidence interval of a mean.

    Computes ``1.96 * std_dev / sqrt(n_trials)``.

    Args:
        std_dev: Standard deviation of the samples. Either a scalar or an
            array-like (e.g. a list of per-bin standard deviations); the input
            is converted with ``np.asarray`` and processed element-wise.
        n_trials: Number of samples the standard deviation was estimated from.

    Returns:
        A numpy float for scalar input, or an ndarray with the same shape as
        ``std_dev`` for array-like input (the previous ``-> float`` annotation
        was wrong for list input).
    """
    # Explicit conversion: do not rely on numpy-scalar coercion of Python lists.
    std = np.asarray(std_dev, dtype=float)
    return 1.96 * (std / np.sqrt(n_trials))

In [4]:
# Random-policy reference results per grid size w (for the random-baseline
# overlay in plot_curves). Each metric maps to (mean, 95% CI half-width); the
# half-width is ci95 of the listed standard deviation over
# _RANDOM_N_EPISODES samples (presumably random-policy episodes - confirm).
_RANDOM_N_EPISODES = 1000
_random_stats = {
    # w: (return mean, return std, success mean, success std)
    2: (-41.38883333333334, 50.217405158581336, 1.0, 0.0),
    3: (-615.9455625, 54.096148844087956, 0.006, 0.0772269383052313),
    4: (-521.9251777777773, 26.923752997348288, 0.0, 0.0),
    5: (-428.0011354166666, 23.243022947447443, 0.0, 0.0),
}
random_results = {
    w: {
        "scalars/eval_return": (ret_mean, ci95(ret_std, _RANDOM_N_EPISODES)),
        "scalars/log_success_mean": (succ_mean, ci95(succ_std, _RANDOM_N_EPISODES)),
    }
    for w, (ret_mean, ret_std, succ_mean, succ_std) in _random_stats.items()
}

In [5]:
def get_color(idx, alpha):
    """Return the ``idx``-th palette color as an ``"rgba(r, g, b,alpha)"`` string.

    The 10-entry palette (these RGB values match Plotly's default qualitative
    colors) is cycled with ``idx % 10``, so any integer index is valid.
    ``alpha`` is inserted verbatim via ``str()``.
    """
    palette = (
        (99, 110, 250),
        (239, 85, 59),
        (0, 204, 150),
        (171, 99, 250),
        (255, 161, 90),
        (25, 211, 243),
        (255, 102, 146),
        (182, 232, 128),
        (255, 151, 255),
        (254, 203, 82),
    )
    red, green, blue = palette[idx % len(palette)]
    # Note: no space before alpha, matching the original string layout.
    prefix = f"rgba({red}, {green}, {blue},"
    return prefix + str(alpha) + ")"


def _add_random_baselines(figure, metric_name, ws_filter, sparse):
    """Overlay random-policy baselines from the global ``random_results`` on ``figure``."""
    # Reward/return baselines are not drawn for sparse-reward experiments.
    # NOTE(review): assumes random_results was measured without sparse
    # rewards - confirm.
    if sparse and ("reward" in metric_name or "return" in metric_name):
        return
    for w, metrics in random_results.items():
        if ws_filter is not None and w not in ws_filter:
            continue
        # random_results uses "scalars/<name>" keys, while the metric names
        # read from metrics.jsonl (e.g. "eval_return") have no prefix.
        baseline = metrics.get(metric_name) or metrics.get(f"scalars/{metric_name}")
        if baseline is None:
            continue
        mean, ci = baseline
        figure.add_hline(
            y=mean,
            line_dash="dot",
            line_color="black",
            annotation_text=f"Random {w}x{w}",
            annotation_position="bottom right",
        )
        figure.add_hrect(
            y0=mean - ci,
            y1=mean + ci,
            fillcolor="black",
            opacity=0.1,
            line_width=0,
        )


def plot_curves(
    curves,
    title=None,
    metric_name="Return",
    ws_filter=None,
    sparse=False,
    show_random_baseline=False,
):
    """Plot mean learning curves with shaded confidence bands and show the figure.

    Args:
        curves: Iterable of ``(run_name, x_values, y_values, lower_bound,
            upper_bound)`` tuples. Each tuple is drawn as a solid mean line
            plus a filled band between ``lower_bound`` and ``upper_bound``.
        title: Figure title.
        metric_name: Y-axis label; also the key used to look up baselines in
            ``random_results`` when ``show_random_baseline`` is set.
        ws_filter: Optional collection of grid sizes ``w``; when given, only
            baselines for those sizes are drawn. Only used with
            ``show_random_baseline``.
        sparse: Whether the plotted runs use sparse rewards; reward/return
            baselines are then skipped. Only used with ``show_random_baseline``.
        show_random_baseline: Draw dotted horizontal lines with 95% CI
            rectangles for the random-policy results. Off by default, which
            keeps the plot identical to the version where this overlay was
            commented out.

    Returns:
        None; the figure is displayed with ``Figure.show()``.
    """
    learning_curve = go.Figure()

    if show_random_baseline:
        _add_random_baselines(learning_curve, metric_name, ws_filter, sparse)

    for idx, (run_name, x_values, y_values, lower_bound, upper_bound) in enumerate(
        curves
    ):
        # The mean line and its band share a legend group, so clicking the
        # legend entry hides/shows the band together with the line.
        group = f"curve{idx}"
        learning_curve.add_trace(
            go.Scatter(
                x=x_values,
                y=y_values,
                mode="lines",
                name=f"{run_name}",
                legendgroup=group,
                line=dict(color=get_color(idx, 1)),
            )
        )
        # Lower bound first; the upper bound then fills "tonexty" down to it.
        # The near-transparent line color hides the band edges.
        learning_curve.add_trace(
            go.Scatter(
                x=x_values,
                y=lower_bound,
                fill=None,
                mode="lines",
                showlegend=False,
                legendgroup=group,
                line=dict(color=get_color(idx, 0.01)),
            )
        )
        learning_curve.add_trace(
            go.Scatter(
                x=x_values,
                y=upper_bound,
                fill="tonexty",
                mode="lines",
                showlegend=False,
                legendgroup=group,
                line=dict(color=get_color(idx, 0.01)),
            )
        )

    learning_curve.update_layout(
        title=title,
        xaxis_title="Step",
        yaxis_title=metric_name,
        template="plotly",
        width=800,
        height=500,
        legend=dict(
            yanchor="bottom",
            y=0.01,
            xanchor="right",
            x=0.9905,
        ),
    )

    learning_curve.show()

In [15]:
def _load_run_metrics(run_dir, max_episode_length=1000):
    """Read ``<run_dir>/metrics.jsonl`` into a ``{step: {metric: value}}`` dict.

    Records sharing a step are merged (later records overwrite earlier keys).
    Every step that logs ``eval_length`` also gets a derived ``success_rate``
    of 1 if ``eval_length < max_episode_length`` else 0.
    """
    run_metrics = {}
    with open(os.path.join(run_dir, "metrics.jsonl"), "r") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing lines
            record = json.loads(line)
            step_metrics = run_metrics.setdefault(record["step"], {})
            step_metrics.update({key: value for key, value in record.items() if key != "step"})
            if "eval_length" in record:
                # An episode ending before the length cap counts as a success.
                # NOTE(review): assumes episodes only terminate early on
                # success - confirm against the environment.
                step_metrics["success_rate"] = int(record["eval_length"] < max_episode_length)
    return run_metrics


def get_trial_curves(trial_run_dirs, metric_name, bin_size=1, max_episode_length=1000):
    """Aggregate one metric over the runs (seeds) of a trial into a binned curve.

    Steps are grouped into consecutive bins ``[b*bin_size, (b+1)*bin_size)``.
    Each bin is reduced in two stages: first the mean of the values each run
    logged in that bin, then the mean and std of those per-run means. Runs are
    therefore the independent samples, and the 95% CI uses the number of runs
    that actually logged ``metric_name`` in the bin. Bins no run logged in are
    skipped; a trailing incomplete bin is dropped.

    Args:
        trial_run_dirs: Run directories, each containing ``metrics.jsonl``.
        metric_name: Metric key to aggregate (``success_rate`` is derived from
            ``eval_length``, see ``max_episode_length``).
        bin_size: Bin width in steps (>= 1).
        max_episode_length: Episode-length cap used to derive ``success_rate``.

    Returns:
        ``(x_values, mean_values, lower_bound, upper_bound)`` lists, where
        ``x_values`` are bin centres and the bounds are ``mean -/+ ci95``.

    Raises:
        ValueError: If ``bin_size`` < 1.
    """
    if bin_size < 1:
        raise ValueError(f"bin_size must be >= 1, got {bin_size}")

    runs = [_load_run_metrics(run_dir, max_episode_length) for run_dir in trial_run_dirs]

    # Largest step logged by any run (for any metric); -1 if nothing was logged.
    max_step = max((step for run in runs for step in run), default=-1)
    # Number of complete bins covering steps 0..max_step inclusive.
    n_bins = int((max_step + 1) // bin_size)
    binned_end = n_bins * bin_size

    # Per run: bin index -> values of metric_name logged in that bin.
    binned_runs = []
    for run in runs:
        bins = defaultdict(list)
        for step, step_metrics in run.items():
            if metric_name in step_metrics and 0 <= step < binned_end:
                bins[int(step // bin_size)].append(step_metrics[metric_name])
        binned_runs.append(bins)

    x_values, mean_values, lower_bound, upper_bound = [], [], [], []
    for bin_idx in range(n_bins):
        run_means = [np.mean(bins[bin_idx]) for bins in binned_runs if bin_idx in bins]
        if not run_means:
            continue  # no run logged the metric in this bin
        mean = float(np.mean(run_means))
        half_width = float(ci95(np.std(run_means), len(run_means)))
        x_values.append(bin_idx * bin_size + bin_size / 2)
        mean_values.append(mean)
        lower_bound.append(mean - half_width)
        upper_bound.append(mean + half_width)

    return x_values, mean_values, lower_bound, upper_bound

In [16]:
def _config_matches(config, query):
    """True when every key/value pair of ``query`` equals ``config.get(key)``."""
    return all(config.get(key) == value for key, value in query.items())


def _add_train_every(config):
    """Set ``config["train_every"] = batch_size * batch_length / train_ratio`` when computable."""
    needed = ("batch_size", "batch_length", "train_ratio")
    if all(config.get(key) is not None for key in needed) and config["train_ratio"] != 0:
        config["train_every"] = (
            config["batch_size"] * config["batch_length"]
        ) / config["train_ratio"]


def _first_truthy(query, keys):
    """Return the first truthy ``query[key]`` for ``key`` in ``keys``, else None."""
    for key in keys:
        value = query.get(key)
        if value:
            return value
    return None


def get_trials_and_runs(
    experiment_name,
    config_queries: List[dict],
    key_to_compare: str,
    metric_name: str,
    debug=False,
    bin_size=1,
):
    """Group matching runs by a config value and plot one learning curve per group.

    Every ``<BASE_PATH>/*/config.yaml`` is loaded. A run is selected (once)
    when its config matches at least one query, i.e. all key/value pairs of
    that query are equal. Selected runs are grouped by
    ``config[key_to_compare]`` and each group is reduced with
    ``get_trial_curves``.

    Args:
        experiment_name: Prefix of the plot title.
        config_queries: A query dict, or a list/iterable of query dicts (OR-ed).
        key_to_compare: Config key whose value defines the groups ("trials").
        metric_name: Metric read from each run's ``metrics.jsonl``.
        debug: Print every selected run with its full config.
        bin_size: Step-bin width forwarded to ``get_trial_curves``.

    Returns:
        The return value of ``plot_curves`` (which displays the figure).
    """
    queries = [config_queries] if isinstance(config_queries, dict) else list(config_queries)

    trials = defaultdict(list)
    # glob() order is filesystem-dependent; sorting makes run order reproducible.
    for config_file in sorted(glob(os.path.join(BASE_PATH, "*/config.yaml"))):
        with open(config_file, "r") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        if not isinstance(config, dict):
            continue  # empty or non-mapping config file

        # add derived keys that queries / key_to_compare may refer to
        _add_train_every(config)

        # OR across queries; any() also ensures a run matching several queries
        # is added only once instead of once per matching query.
        if not any(_config_matches(config, query) for query in queries):
            continue

        # Use the directory that actually holds config.yaml (and metrics.jsonl)
        # instead of re-deriving it from config["logdir"], which may load as a
        # pathlib.Path or end with a separator.
        run_dir = os.path.dirname(config_file)
        trials[config[key_to_compare]].append(run_dir)
        if debug:
            print(config[key_to_compare], os.path.basename(run_dir), config)

    curves = []
    for trial_name, run_dirs in sorted(trials.items()):
        print(f"Processing {trial_name}", run_dirs)
        curves.append(
            (
                f"{trial_name} ({len(run_dirs)} seeds)",
                *get_trial_curves(run_dirs, metric_name, bin_size=bin_size),
            )
        )

    # The queries use "__"-flattened config keys (e.g. "env__w"); the old
    # single-underscore spellings are still accepted as a fallback.
    ws_filter = []
    for query in queries:
        w = _first_truthy(query, ("env__h", "env__w", "env_h", "env_w"))
        if w is not None and w not in ws_filter:
            ws_filter.append(w)
    sparse = any(
        _first_truthy(query, ("env__sparse_rewards", "env_sparse_rewards"))
        for query in queries
    )
    return plot_curves(
        curves,
        title=f"{experiment_name} {metric_name.split('/')[-1]}",
        metric_name=metric_name,
        ws_filter=ws_filter or None,
        sparse=sparse,
    )


def get_reward_and_success(*args, **kwargs):
    """Plot episode length, return and success rate for one experiment.

    Calls ``get_trials_and_runs`` once per metric (in that order), forwarding
    all arguments and overriding any ``metric_name`` keyword.
    """
    for metric in ("eval_length", "eval_return", "success_rate"):
        get_trials_and_runs(*args, **dict(kwargs, metric_name=metric))

## Hiperparâmetros

In [17]:
# Dense-reward 3x3 one-hot runs with 512 units. The two queries (batch size 16
# or 48) are OR-ed, and runs are grouped by the derived train_every value.
common_query = dict(
    task="sldp_onehot",
    env__sparse_rewards=False,
    env__w=3,
    units=512,
)
get_reward_and_success(
    experiment_name="3x3 r1 vs r2",
    config_queries=[dict(common_query, batch_size=size) for size in (16, 48)],
    key_to_compare="train_every",
    bin_size=100,
)

Processing 1.0 ['logdir/20240330_110938-onehot_w3', 'logdir/20240325_174437-onehot_w3', 'logdir/20240329_135320-onehot_w3', 'logdir/20240331_123341-onehot_w3', 'logdir/20240325_075737-onehot_w3', 'logdir/20240328_163635-onehot_w3']
Processing 2.0 ['logdir/20240401_233732-onehot_w3_r2', 'logdir/20240401_104005-onehot_w3_r2', 'logdir/20240402_124310-onehot_w3_r2']


Processing 1.0 ['logdir/20240330_110938-onehot_w3', 'logdir/20240325_174437-onehot_w3', 'logdir/20240329_135320-onehot_w3', 'logdir/20240331_123341-onehot_w3', 'logdir/20240325_075737-onehot_w3', 'logdir/20240328_163635-onehot_w3']
Processing 2.0 ['logdir/20240401_233732-onehot_w3_r2', 'logdir/20240401_104005-onehot_w3_r2', 'logdir/20240402_124310-onehot_w3_r2']


Processing 1.0 ['logdir/20240330_110938-onehot_w3', 'logdir/20240325_174437-onehot_w3', 'logdir/20240329_135320-onehot_w3', 'logdir/20240331_123341-onehot_w3', 'logdir/20240325_075737-onehot_w3', 'logdir/20240328_163635-onehot_w3']
Processing 2.0 ['logdir/20240401_233732-onehot_w3_r2', 'logdir/20240401_104005-onehot_w3_r2', 'logdir/20240402_124310-onehot_w3_r2']


In [20]:
# Dense-reward 3x3 one-hot runs with train_every == 1. Batch size 16 or 48
# (OR-ed), grouped by the "units" setting.
common_query = dict(
    task="sldp_onehot",
    env__sparse_rewards=False,
    env__w=3,
    train_every=1,
)
get_reward_and_success(
    experiment_name="3x3 small vs large",
    config_queries=[dict(common_query, batch_size=size) for size in (16, 48)],
    key_to_compare="units",
    bin_size=100,
)

Processing 512 ['logdir/20240330_110938-onehot_w3', 'logdir/20240325_174437-onehot_w3', 'logdir/20240329_135320-onehot_w3', 'logdir/20240331_123341-onehot_w3', 'logdir/20240325_075737-onehot_w3', 'logdir/20240328_163635-onehot_w3']
Processing 768 ['logdir/20240401_121409-onehot_w3_large', 'logdir/20240402_013440-onehot_w3_large']


Processing 512 ['logdir/20240330_110938-onehot_w3', 'logdir/20240325_174437-onehot_w3', 'logdir/20240329_135320-onehot_w3', 'logdir/20240331_123341-onehot_w3', 'logdir/20240325_075737-onehot_w3', 'logdir/20240328_163635-onehot_w3']
Processing 768 ['logdir/20240401_121409-onehot_w3_large', 'logdir/20240402_013440-onehot_w3_large']


Processing 512 ['logdir/20240330_110938-onehot_w3', 'logdir/20240325_174437-onehot_w3', 'logdir/20240329_135320-onehot_w3', 'logdir/20240331_123341-onehot_w3', 'logdir/20240325_075737-onehot_w3', 'logdir/20240328_163635-onehot_w3']
Processing 768 ['logdir/20240401_121409-onehot_w3_large', 'logdir/20240402_013440-onehot_w3_large']


## Resultados gerais

In [42]:
# Dense-reward one-hot runs (512 units, train_every == 1), one curve per grid
# size env__w.
common_query = dict(
    task="sldp_onehot",
    env__sparse_rewards=False,
    units=512,
    train_every=1,
)
get_reward_and_success(
    experiment_name="onehot",
    config_queries=[dict(common_query)],
    key_to_compare="env__w",
    bin_size=100,
)

Processing 2 ['logdir/20240330_111422-onehot_w2']
Processing 3 ['logdir/20240329_192841-onehot_w3_b96', 'logdir/20240330_110938-onehot_w3', 'logdir/20240325_174437-onehot_w3', 'logdir/20240328_163649-onehot_w3_b96', 'logdir/20240325_175242-onehot_w3_b96', 'logdir/20240329_135320-onehot_w3', 'logdir/20240331_123341-onehot_w3', 'logdir/20240325_075737-onehot_w3', 'logdir/20240328_163635-onehot_w3']
Processing 4 ['logdir/20240331_024449-onehot_w4', 'logdir/20240330_111437-onehot_w4', 'logdir/20240331_160502-onehot_w4']
Processing 5 ['logdir/20240330_110943-onehot_w5', 'logdir/20240331_123827-onehot_w5']


Processing 2 ['logdir/20240330_111422-onehot_w2']
Processing 3 ['logdir/20240329_192841-onehot_w3_b96', 'logdir/20240330_110938-onehot_w3', 'logdir/20240325_174437-onehot_w3', 'logdir/20240328_163649-onehot_w3_b96', 'logdir/20240325_175242-onehot_w3_b96', 'logdir/20240329_135320-onehot_w3', 'logdir/20240331_123341-onehot_w3', 'logdir/20240325_075737-onehot_w3', 'logdir/20240328_163635-onehot_w3']
Processing 4 ['logdir/20240331_024449-onehot_w4', 'logdir/20240330_111437-onehot_w4', 'logdir/20240331_160502-onehot_w4']
Processing 5 ['logdir/20240330_110943-onehot_w5', 'logdir/20240331_123827-onehot_w5']


Processing 2 ['logdir/20240330_111422-onehot_w2']
Processing 3 ['logdir/20240329_192841-onehot_w3_b96', 'logdir/20240330_110938-onehot_w3', 'logdir/20240325_174437-onehot_w3', 'logdir/20240328_163649-onehot_w3_b96', 'logdir/20240325_175242-onehot_w3_b96', 'logdir/20240329_135320-onehot_w3', 'logdir/20240331_123341-onehot_w3', 'logdir/20240325_075737-onehot_w3', 'logdir/20240328_163635-onehot_w3']
Processing 4 ['logdir/20240331_024449-onehot_w4', 'logdir/20240330_111437-onehot_w4', 'logdir/20240331_160502-onehot_w4']
Processing 5 ['logdir/20240330_110943-onehot_w5', 'logdir/20240331_123827-onehot_w5']


In [32]:
# Image runs using the "single" image folder, one curve per grid size env__w.
# NOTE(review): there is no env__sparse_rewards filter, so runs with and
# without sparse rewards (if both exist) that share an env__w are averaged
# into one curve - confirm this is intended.
common_query = dict(task="sldp_image", env__image_folder="single")
get_reward_and_success(
    experiment_name="single",
    config_queries=[dict(common_query)],
    key_to_compare="env__w",
    bin_size=100,
)

Processing 2 ['logdir/20240323_175347-single_w2_sparse', 'logdir/20240324_213203-single_w2_sparse']
Processing 3 ['logdir/20240324_183843-single_w3_sparse', 'logdir/20240401_105853-single_w3_r2', 'logdir/20240323_181013-single_w3_sparse']


Processing 2 ['logdir/20240323_175347-single_w2_sparse', 'logdir/20240324_213203-single_w2_sparse']
Processing 3 ['logdir/20240324_183843-single_w3_sparse', 'logdir/20240401_105853-single_w3_r2', 'logdir/20240323_181013-single_w3_sparse']


Processing 2 ['logdir/20240323_175347-single_w2_sparse', 'logdir/20240324_213203-single_w2_sparse']
Processing 3 ['logdir/20240324_183843-single_w3_sparse', 'logdir/20240401_105853-single_w3_r2', 'logdir/20240323_181013-single_w3_sparse']


In [30]:
# Image runs using the "imagenet-1k" image folder, one curve per grid size.
# NOTE(review): there is no env__sparse_rewards filter, so runs with and
# without sparse rewards (if both exist) that share an env__w are averaged
# into one curve - confirm this is intended.
common_query = dict(task="sldp_image", env__image_folder="imagenet-1k")
get_reward_and_success(
    experiment_name="imagenet",
    config_queries=[dict(common_query)],
    key_to_compare="env__w",
    bin_size=100,
)

Processing 2 ['logdir/20240323_235855-imagenet_w2_sparse']
Processing 3 ['logdir/20240324_000431-imagenet_w3_sparse', 'logdir/20240325_053519-imagenet_w3']


Processing 2 ['logdir/20240323_235855-imagenet_w2_sparse']
Processing 3 ['logdir/20240324_000431-imagenet_w3_sparse', 'logdir/20240325_053519-imagenet_w3']


Processing 2 ['logdir/20240323_235855-imagenet_w2_sparse']
Processing 3 ['logdir/20240324_000431-imagenet_w3_sparse', 'logdir/20240325_053519-imagenet_w3']


In [31]:
# Image runs using the "mnist" image folder, one curve per grid size.
# NOTE(review): there is no env__sparse_rewards filter, so runs with and
# without sparse rewards (if both exist) that share an env__w are averaged
# into one curve - confirm this is intended.
common_query = dict(task="sldp_image", env__image_folder="mnist")
get_reward_and_success(
    experiment_name="mnist",
    config_queries=[dict(common_query)],
    key_to_compare="env__w",
    bin_size=100,
)

Processing 2 ['logdir/20240324_061158-mnist_w2_sparse']
Processing 3 ['logdir/20240324_054805-mnist_w3_sparse']


Processing 2 ['logdir/20240324_061158-mnist_w2_sparse']
Processing 3 ['logdir/20240324_054805-mnist_w3_sparse']


Processing 2 ['logdir/20240324_061158-mnist_w2_sparse']
Processing 3 ['logdir/20240324_054805-mnist_w3_sparse']
