In [None]:
from shared.plotting import (
    plot_result,
    plot_result_rliable,
    plot_metrics_figure,
    condition_name_map,
    label_name_map,
    calculate_correlation,
)
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
from analyze_helpers import *
from pylatex import Document, Tabular, NoEscape
from rliable import plot_utils

# Ask the inline backend to emit figures in the 'retina' format.
%config InlineBackend.figure_format = 'retina'

# NOTE(review): list_recent_subdirs is not imported by name, so it presumably
# comes from the `analyze_helpers` star import. Judging by its name it lists the
# 15 most recent sub-directories -- confirm which directory it scans.
list_recent_subdirs(n=15)

### Figure 2: Basic comparison of baseline, reset-all, and l2-init

In [None]:
# Figure 2 layout: a 2 x 5 grid with one row per reward quantity
# (row 0 = "train_r", row 1 = "test_r"). The first panel of each row spans two
# columns and is drawn with overview=False; the other three panels take one
# column each and are drawn with overview=True.
fig = plt.figure(figsize=(16, 6), dpi=150)
gs = gridspec.GridSpec(2, 5, figure=fig)

conditions = ["baseline", "reset-all", "l2-init"]

# One entry per panel of a row:
#   (experiment name, grid column(s), overview flag, panel title)
# NOTE(review): every panel currently loads the *window* experiment although the
# titles read Permute / Window / Expand. Confirm whether the Permute and Expand
# panels should point at their own result directories.
panels = [
    ("jumper-50k-by100-window-all", slice(0, 2), False, "(A) Permute"),
    ("jumper-50k-by100-window-all", 2, True, "(B) Permute"),
    ("jumper-50k-by100-window-all", 3, True, "(C) Window"),
    ("jumper-50k-by100-window-all", 4, True, "(D) Expand"),
]

# process_data results keyed by experiment name, so an experiment that feeds
# several panels is only loaded once instead of once per panel group.
loaded = {}

for row, quantity in enumerate(("train_r", "test_r")):
    for panel_idx, (experiment_name, columns, overview, title) in enumerate(panels):
        if experiment_name not in loaded:
            loaded[experiment_name] = process_data(experiment_name)
        hyperparams, combined_data, _, _ = loaded[experiment_name]

        ax = fig.add_subplot(gs[row, columns])
        plot_result(
            combined_data[quantity],
            quantity,
            hyperparams,
            f"./results/{experiment_name}",
            current_epoch=None,
            overview=overview,
            print_title=False,
            axis=ax,
            add_legend=False,
            conditions=conditions,
            custom_title=title,
        )
        if panel_idx == 0:
            # Remember the wide panel; after the loop this is the bottom-row one,
            # which supplies the handles for the shared legend.
            ax1 = ax

# Single figure-level legend placed above the grid, labelled with the display
# names from condition_name_map.
condition_labels = [condition_name_map[c] for c in conditions]
handles, _ = ax1.get_legend_handles_labels()
fig.legend(
    handles, condition_labels, loc="upper center", bbox_to_anchor=(0.5, 1.1), ncols=3
)

# Adjust layout and save as PDF
plt.tight_layout()
plt.savefig("combined_figure.pdf", dpi=150, bbox_inches="tight")

### Figure 3: Correlation plots

In [None]:
# Alternative experiment set (CoinRun), kept for quick switching:
# experiment_names = [
#    "coinrun-50k-by100-bg-permute-all",
#    "coinrun-50k-by100-bg-window-all",
#    "coinrun-50k-by100-bg-expand-all",
# ]
experiment_names = [
    "gather-200k-by100-expand-all-mlp",
    "gather-200k-by100-permute-all-mlp",
    "gather-200k-by100-window-all-mlp",
]

# Metrics handed to process_corrs / plot_corr_row.
# Currently disabled: "v_estimate", "p_loss", "v_loss".
metrics = ["entropy", "weight_m", "weight_diff", "grad_norm", "dead_units"]

conditions = [
    "baseline", "reset-all", "sp-5-5",
    "soft-sp-6", "l2-init-4", "l2-norm",
    "reset-final", "inject", "redo-reset-10",
]

# 2 x 5 grid: row 0 is built from "train_r", row 1 from "test_r".
fig = plt.figure(figsize=(18, 7), dpi=150)
gs = gridspec.GridSpec(2, 5, figure=fig)

for row, reward_key in enumerate(("train_r", "test_r")):
    transformed = process_corrs(reward_key, experiment_names, metrics, conditions)
    # After the loop, ax1 holds what plot_corr_row returned for the last row.
    ax1 = plot_corr_row(row, transformed, conditions, metrics, gs, quantity=reward_key)

# Shared legend above the grid, using display names from condition_name_map.
handles = ax1.get_legend_handles_labels()[0]
condition_labels = [condition_name_map[name] for name in conditions]
fig.legend(
    handles,
    condition_labels,
    loc="upper center",
    bbox_to_anchor=(0.5, 1.15),
    ncols=4,
)

# tight_layout keeps the legend/labels from being cut off before saving.
plt.tight_layout()
plt.savefig("metrics_corr.pdf", dpi=150, bbox_inches="tight")

### RLIABLE

In [None]:
# Conditions passed to plot_result_rliable for every experiment.
conditions = [
    "baseline",
    "reset-all",
    "l2-norm",
    # "l2-init-4",
    "l2-init",
    "sp-5-5",
    "ssp-6",
    # "soft-sp-6",
    "crelu",
    "inject",
    "reset-final",
    "layernorm",
    "redo-reset-10",
    "l2-init-ln",
    "ssp-6-ln",
]

# Alternative experiment set (Gather), kept for quick switching:
# experiment_names = [
#     "gather-200k-by100-permute-all-mlp",
#     "gather-200k-by100-window-all-mlp",
#     "gather-200k-by100-expand-all-mlp",
# ]

experiment_names = [
    "coinrun-50k-by100-bg-permute-all",
    "coinrun-50k-by100-bg-window-all",
    "coinrun-50k-by100-bg-expand-all",
]

# Alternative condition set (ReDo reset sweep):
# conditions = [
#     "redo-reset-1",
#     "redo-reset-10",
#     "redo-reset-100",
#     "redo-reset-1000",
# ]

# Display labels, paired position-by-position with experiment_names.
experiment_labels = ["Permute", "Window", "Expand"]

quantities = ["train_r", "test_r"]

# Axis labels, paired position-by-position with quantities.
quantity_nice = ["Normalized Reward (Train)", "Normalized Reward (Test)"]

# One interval-estimate figure per quantity, saved as coin_rliable_<j>.pdf.
for j, quantity in enumerate(quantities):
    # condition -> list with one entry per experiment, in experiment order.
    mean_dict = {}
    ste_dict = {}

    for experiment_name, experiment_label in zip(experiment_names, experiment_labels):
        hyperparams, combined_data, _, _ = process_data(experiment_name)
        exp_mean, exp_ste = plot_result_rliable(
            combined_data[quantity],
            quantity,
            hyperparams,
            f"./results/{experiment_name}",
            current_epoch=None,
            overview=True,
            print_title=False,
            axis=None,
            add_legend=False,
            conditions=conditions,
            custom_title=f"{experiment_label} ({quantity_nice[j]})",
        )
        for c, c_mean in exp_mean.items():
            mean_dict.setdefault(c, []).append(c_mean)
            ste_dict.setdefault(c, []).append(exp_ste[c])

    # Use the accumulated keys rather than `exp_mean` from the last experiment:
    # that variable only reflects the final iteration (and is undefined when
    # experiment_names is empty), so conditions it lacked were left unconverted.
    algorithms = list(mean_dict)
    for c in algorithms:
        mean_dict[c] = np.array(mean_dict[c]).squeeze()
        # Transposed so experiments end up on the last axis -- presumably the
        # (2, n_experiments) low/high layout rliable expects; confirm.
        ste_dict[c] = np.array(ste_dict[c]).squeeze().T

    # experiment_labels is passed in the metric-names position.
    fig, axis = plot_utils.plot_interval_estimates(
        mean_dict,
        ste_dict,
        experiment_labels,
        algorithms=algorithms,
        subfigure_width=6,
        row_height=0.5,
        xlabel=quantity_nice[j],
        color_palette=plt.cm.tab20.colors,
    )
    if j == 0:
        # Only the first (train) figure carries the environment title.
        fig.suptitle("CoinRun", fontsize=30, y=1.1, x=0.4)
    fig.savefig(f"coin_rliable_{j}.pdf", dpi=150, bbox_inches="tight")

### Figure 4: Comparisons between different interventions

In [None]:
# Conditions drawn in every panel (entries that are commented out are disabled).
conditions = [
    "baseline",
    "reset-all",
    "l2-norm",
    # "l2-init-4",
    "l2-init",
    "sp-5-5",
    "ssp-6",
    # "soft-sp-6",
    "crelu",
    "inject",
    "reset-final",
    "layernorm",
    "redo-reset-10",
    "l2-init-ln",
    "ssp-6-ln",
]

# Alternative experiment set (Gather):
# experiment_names = [
#     "gather-200k-by100-permute-all-mlp",
#     "gather-200k-by100-window-all-mlp",
#     "gather-200k-by100-expand-all-mlp",
# ]

experiment_names = [
    "fruitbot-50k-by100-permute-all",
    "fruitbot-50k-by100-window-all",
    "fruitbot-50k-by100-expand-all",
]

# Alternative condition set (ReDo reset sweep):
# conditions = [
#     "redo-reset-1",
#     "redo-reset-10",
#     "redo-reset-100",
#     "redo-reset-1000",
# ]

experiment_labels = ["Permute", "Window", "Expand"]

quantities = ["train_r", "test_r"]

# 2 x 3 grid: one column per experiment, one row per quantity.
fig = plt.figure(figsize=(16, 8), dpi=150)
gs = gridspec.GridSpec(2, 3, figure=fig)

# Keyword arguments that are identical for every panel.
shared_plot_kwargs = dict(
    current_epoch=None,
    overview=True,
    print_title=False,
    add_legend=False,
    conditions=conditions,
)

for col, (experiment_name, shift_label) in enumerate(
    zip(experiment_names, experiment_labels)
):
    hyperparams, combined_data, _, _ = process_data(experiment_name)
    results_dir = f"./results/{experiment_name}"

    for row, quantity in enumerate(quantities):
        # The panel letter follows the row: "A" for the first quantity, "B" for the second.
        panel_letter = chr(65 + row)
        ax = fig.add_subplot(gs[row, col])
        plot_result(
            combined_data[quantity],
            quantity,
            hyperparams,
            results_dir,
            axis=ax,
            custom_title=f"({panel_letter}) {shift_label}",
            **shared_plot_kwargs,
        )

# Figure-level legend; handles are taken from the last axis that was drawn.
legend_handles = ax.get_legend_handles_labels()[0]
condition_labels = [condition_name_map[name] for name in conditions]
fig.legend(
    legend_handles,
    condition_labels,
    loc="upper center",
    bbox_to_anchor=(0.5, 1.13),
    ncols=5,
)
# Environment name as the figure title.
fig.suptitle("FruitBot", fontsize=20)
# Tighten the layout, then write the PDF.
plt.tight_layout()
plt.savefig("fruit_full.pdf", dpi=150, bbox_inches="tight")

### Appendix Figure: Error Plots

In [None]:
# 2 x 3 grid: one column per experiment, one row per quantity.
fig = plt.figure(figsize=(16, 8), dpi=150)
gs = gridspec.GridSpec(2, 3, figure=fig)

conditions = [
    "baseline", "reset-all", "crelu", "inject",
    "reset-final", "l2-norm", "sp-5-5", "ssp-6",
    "l2-init", "layernorm", "ssp-6-ln", "l2-init-ln",
]

# (experiment name, panel title) pairs.
# NOTE(review): all three titles carry the "(A)" prefix and the same title is
# reused for both rows -- confirm this is intended and not a copy-paste slip.
experiment_info = [
    ("coinrun-50k-by100-bg-permute-all", "(A) Permute"),
    ("coinrun-50k-by100-bg-window-all", "(A) Window"),
    ("coinrun-50k-by100-bg-expand-all", "(A) Expand"),
]

# Not read anywhere in this cell; kept so the notebook namespace is unchanged.
experiment_labels = ["Permute", "Window", "Expand"]

quantities = ["train_r", "test_r"]

for col, (experiment_name, a_title) in enumerate(experiment_info):
    hyperparams, combined_data, _, _ = process_data(experiment_name)

    for row, quantity in enumerate(quantities):
        ax1 = fig.add_subplot(gs[row, col])
        plot_metrics_figure(
            combined_data,
            hyperparams,
            experiment_name,
            conditions,
            quantity=quantity,
            print_title=False,
            add_legend=False,
            axis=ax1,
            custom_title=a_title,
        )

# Legend entries are the display names prefixed with their 0-based position,
# e.g. "(0) <name of first condition>".
condition_labels = [
    f"({idx}) {condition_name_map[name]}" for idx, name in enumerate(conditions)
]
# Handles come from the last axis that was drawn.
handles = ax1.get_legend_handles_labels()[0]
fig.legend(
    handles,
    condition_labels,
    loc="upper center",
    bbox_to_anchor=(0.5, 1.12),
    ncols=4,
)

# Tighten the layout, then write the PDF.
plt.tight_layout()
plt.savefig("coinrun.pdf", dpi=150, bbox_inches="tight")

### Appendix Table: t-test statistics and p-values

In [None]:
# Build a LaTeX table from the t-test results that plot_metrics_figure returns
# for the "test_r" quantity of one experiment.
quantity = "test_r"
conditions = [
    "baseline", "reset-all", "crelu", "inject",
    "reset-final", "l2-norm", "sp-5-5", "soft-sp-6",
    "l2-init-4", "layernorm", "ssp-6-ln", "l2-init-ln",
]

experiment_name = "gather-200k-by100-expand-all-mlp"
hyperparams, combined_data, quantities, _ = process_data(experiment_name)

# plot_figure=False / ttest=True: request the test results without drawing.
results = plot_metrics_figure(
    combined_data,
    hyperparams,
    experiment_name,
    conditions,
    quantity=quantity,
    print_title=False,
    add_legend=False,
    plot_figure=False,
    ttest=True,
)

doc = Document()

# Three centred columns with vertical rules; every row is followed by a rule.
with doc.create(Tabular("|c|c|c|")) as table:
    # Two rules are emitted above the header row.
    table.add_hline()
    table.add_hline()
    table.add_row(("Method", "t-value", "p-value"))
    table.add_hline()
    for test, values in results.items():
        # Every number is rounded to 3 decimal places before formatting.
        rounded = [round(v, 3) for v in values]
        # Judging by the row format, the entries are presumably
        # (t statistic, degrees of freedom, p-value) -- confirm against
        # plot_metrics_figure.
        t_cell = f"t({rounded[1]}) = {rounded[0]}"
        table.add_row([test, t_cell, rounded[2]])
        table.add_hline()

# LaTeX source of the whole document as a string.
latex_str = doc.dumps()

# Also write the document to a .tex file named after "my_latex_table".
doc.generate_tex("my_latex_table")

# Show the LaTeX source in the cell output.
print(latex_str)

### Appendix Figure: Additional Metric Plots

In [None]:
# Conditions drawn in every panel.
conditions = [
    "baseline",
    "reset-all",
    "sp-5-5",
    "soft-sp-6",
    "l2-init-4",
    "l2-norm",
    "reset-final",
    "inject",
    "redo-reset-10",
]

# Quantities to plot, one grid column each (commented entries are disabled).
metrics = [
    # "entropy",
    # "v_estimate",
    # "p_loss",
    # "v_loss",
    "weight_m",
    # "weight_diff",
    "grad_norm",
    "dead_units",
    "train_r",
    # "test_r",
]

# Alternative: all three Gather experiments, one grid row each.
# experiment_names = [
#    "gather-200k-by100-permute-all-mlp",
#    "gather-200k-by100-window-all-mlp",
#    "gather-200k-by100-expand-all-mlp",
# ]
experiment_names = ["gather-200k-by100-window-all-mlp"]
# experiment_titles = ["Permute", "Window", "Expand"]

# One grid row per experiment and one column per metric. The row count and the
# figure height (4 per row) follow experiment_names, so switching to the
# three-experiment list above no longer indexes past a hard-coded single row.
n_rows = len(experiment_names)
fig = plt.figure(figsize=(16, n_rows * 4), dpi=150)
gs = gridspec.GridSpec(n_rows, len(metrics), figure=fig)

for idx, experiment_name in enumerate(experiment_names):
    hyperparams, combined_data, quantities, _ = process_data(experiment_name)
    for j, quantity in enumerate(metrics):
        ax1 = fig.add_subplot(gs[idx, j])
        plot_result(
            combined_data[quantity],
            quantity,
            hyperparams,
            f"./results/{experiment_name}",
            current_epoch=None,
            overview=True,
            print_title=False,
            axis=ax1,
            add_legend=False,
            conditions=conditions,
            # custom_title=f"{experiment_titles[idx]}",
            custom_title=f"{label_name_map[quantity]}",
        )


# Legend inputs are still prepared so the (currently disabled) figure legend
# below can be re-enabled by uncommenting it.
condition_labels = [condition_name_map[c] for c in conditions]
handles, _ = ax1.get_legend_handles_labels()
# fig.legend(
#    handles, condition_labels, loc="upper center", bbox_to_anchor=(0.5, 1.1), ncols=4
# )

# Adjust layout and save as PDF
plt.tight_layout()
plt.savefig(
    "extra_metrics_part.pdf",
    dpi=150,
    bbox_inches="tight",
)