In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Load the saved experiment results.
# NOTE: unpickling can execute arbitrary code, so only load result files that
# were produced by a trusted run.
# The context manager guarantees the file handle is closed after loading.
with open(
    os.path.join("synthetic_output", "0.003_performance_dict.pkl"), "rb"
) as performance_file:
    performance_dict = pickle.load(performance_file)

In [None]:
# Show the loaded results dictionary as this cell's output for a quick visual check.
performance_dict

In [None]:
def _key_field(field: str) -> str:
    """Return the value half of a ``"name:value"`` key component."""
    return field.split(":")[1]


def _dim_sort_key(dim: str):
    """Sort key that orders dimension labels numerically when they parse as numbers.

    Labels that are not numeric are kept (never rejected) and sorted
    alphabetically after all numeric ones.
    """
    try:
        return (0, float(dim), dim)
    except ValueError:
        return (1, 0.0, dim)


def _collect_series(performance, num_layers, autoregressive, add_bos):
    """Group matching results by time-encoding method.

    Returns a dict mapping each time-encoding method to a list of
    ``(time_feat_dim_label, avg_test_performance, std_test_performance)``
    tuples, in the iteration order of ``performance``.
    """
    series = {}
    for key, stats in performance.items():
        # Non-tuple keys (e.g. "oracle") are not per-configuration results.
        if not isinstance(key, tuple):
            continue
        if num_layers != int(_key_field(key[0])):
            continue
        if autoregressive != (_key_field(key[3]) == "True"):
            continue
        # The BOS flag is only used as a filter for autoregressive runs.
        if autoregressive and add_bos != (_key_field(key[4]) == "True"):
            continue

        dim_label = _key_field(key[1])
        method = _key_field(key[2])
        avg_test_performance = stats[2][0]
        std_test_performance = stats[2][1]
        series.setdefault(method, []).append(
            (dim_label, avg_test_performance, std_test_performance)
        )
    return series


def plot_test_performance_vs_dim(
    num_layers: int,
    autoregressive: bool,
    add_bos: bool = None,
    n_runs: int = 10,
    output_dir: str = "synthetic_output",
):
    """Plot test accuracy against time-feature dimension, one line per encoding method.

    Results are read from the notebook-level ``performance_dict``. Its tuple
    keys are sequences of ``"name:value"`` strings, read positionally as:
    0 = number of layers, 1 = time feature dimension, 2 = time encoding
    method, 3 = autoregressive flag, 4 = add-BOS flag (only consulted when
    ``autoregressive`` is true). Each value holds ``(average, std)`` of the
    test performance at index 2. The ``"oracle"`` entry has the same value
    layout and is drawn as a dashed reference line.

    Each line is surrounded by a band of +/- 1.96 standard errors, where the
    standard error is ``std / sqrt(n_runs)``. Dimensions are placed on evenly
    spaced x positions in ascending numeric order, so the figure does not
    depend on the iteration order of ``performance_dict``.

    The figure is saved as ``<output_dir>/<title>.png`` and then shown.

    Args:
        num_layers: Only results with this number of layers are plotted.
        autoregressive: Only results with this autoregressive flag are plotted.
        add_bos: BOS flag to select among autoregressive results. Must be
            ``None`` when ``autoregressive`` is false.
        n_runs: Number of runs behind each (average, std) pair; used to turn
            the standard deviation into a standard error.
        output_dir: Directory the PNG is written to; created if missing.

    Raises:
        AssertionError: If ``add_bos`` is given for a non-autoregressive plot.
        ValueError: If no entry of ``performance_dict`` matches the filters.
        KeyError: If ``performance_dict`` has no ``"oracle"`` entry.
    """
    if not autoregressive:
        assert add_bos is None, "add_bos only applies to autoregressive runs"

    series_by_method = _collect_series(
        performance_dict, num_layers, autoregressive, add_bos
    )
    if not series_by_method:
        # Previously this case silently reused a stale loop variable for the
        # oracle line; fail loudly instead of drawing a meaningless figure.
        raise ValueError(
            "no results match "
            f"num_layers={num_layers}, autoregressive={autoregressive}, add_bos={add_bos}"
        )

    # Shared, numerically ordered x axis across all methods.
    dim_order = sorted(
        {dim for points in series_by_method.values() for dim, _, _ in points},
        key=_dim_sort_key,
    )
    position_of = {dim: position for position, dim in enumerate(dim_order)}
    positions = list(range(len(dim_order)))
    sqrt_runs = n_runs**0.5

    sns.set_theme()
    fig, ax = plt.subplots(figsize=(14, 8))

    for method, points in series_by_method.items():
        ordered = sorted(points, key=lambda point: position_of[point[0]])
        xs = [position_of[dim] for dim, _, _ in ordered]
        avg_list = [avg for _, avg, _ in ordered]
        half_widths = [1.96 * (std / sqrt_runs) for _, _, std in ordered]
        lower = [avg - half for avg, half in zip(avg_list, half_widths)]
        upper = [avg + half for avg, half in zip(avg_list, half_widths)]
        sns.lineplot(x=xs, y=avg_list, markersize=18, label=method, ax=ax)
        ax.fill_between(xs, lower, upper, alpha=0.07)

    # Oracle reference, drawn across the full x range.
    oracle_stats = performance_dict["oracle"][2]
    avg_oracle_test_acc = oracle_stats[0]
    oracle_half_width = 1.96 * (oracle_stats[1] / sqrt_runs)
    sns.lineplot(
        x=positions,
        y=[avg_oracle_test_acc] * len(positions),
        label="oracle",
        linestyle="--",
        ax=ax,
    )
    ax.fill_between(
        positions,
        [avg_oracle_test_acc - oracle_half_width] * len(positions),
        [avg_oracle_test_acc + oracle_half_width] * len(positions),
        alpha=0.07,
    )

    if autoregressive:
        bos_tag = "add_bos" if add_bos else "wo_bos"
        title = f"{num_layers}_layer-autoregressive-{bos_tag}"
    else:
        title = f"{num_layers}_layer-nonautoregressive"

    ax.set_xticks(positions)
    ax.set_xticklabels(dim_order)
    ax.set_title(title, fontsize=24)
    ax.set_ylabel("Test Accuracy", fontsize=24)
    ax.set_ylim(0.75, 1.0)
    ax.set_xlabel("Time Feature Dimension", fontsize=24)
    ax.legend()
    # Layout must be computed after all artists and labels exist.
    fig.tight_layout()

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, f"{title}.png"))
    plt.show()

In [None]:
# One-layer models: non-autoregressive first, then autoregressive without and with BOS.
for config in (
    dict(autoregressive=False, add_bos=None),
    dict(autoregressive=True, add_bos=False),
    dict(autoregressive=True, add_bos=True),
):
    plot_test_performance_vs_dim(num_layers=1, **config)

In [None]:
# Two-layer models, in the same order: non-autoregressive, then autoregressive
# without BOS, then autoregressive with BOS.
for config in (
    dict(autoregressive=False, add_bos=None),
    dict(autoregressive=True, add_bos=False),
    dict(autoregressive=True, add_bos=True),
):
    plot_test_performance_vs_dim(num_layers=2, **config)