In [None]:
import jax
import jax.numpy as jnp
import yaml
import pickle
import pandas as pd
from notebooks.visualization_functions import plot_UL_rewards
import matplotlib.pyplot as plt

# Global figure styling shared by every plot in this notebook.
# `linewidth` drives all stroke widths; `colors` is the Dark2 palette that the
# plotting cells index into to give each algorithm a fixed color.
linewidth = 4
colors = plt.cm.Dark2.colors

plt.rcParams.update(
    {
        # Curves, axes spines and major ticks all share the same stroke width.
        **{
            stroke_key: linewidth
            for stroke_key in (
                'lines.linewidth',
                'axes.linewidth',
                'xtick.major.width',
                'ytick.major.width',
            )
        },
        # Major tick length is twice the stroke width.
        'xtick.major.size': 2*linewidth,
        'ytick.major.size': 2*linewidth,
        'font.size': 26,
        # Text is rendered through LaTeX (labels below use \textsc and math mode).
        'text.usetex': True,
        'axes.prop_cycle': plt.cycler(color=colors),
    }
)

def flatten_keys(d):
    """Yield the key of every non-dict value in a nested dict.

    Keys of nested dictionaries are joined to their parent key with ".",
    e.g. ``{"a": {"b": 1}}`` yields ``"a.b"``. Keys are visited in the
    dictionaries' iteration order. Empty nested dictionaries contribute
    nothing, and a non-dict argument yields nothing at all.
    """
    if not isinstance(d, dict):
        return
    for key, value in d.items():
        if not isinstance(value, dict):
            # Plain value: the key itself is the flattened name.
            yield key
            continue
        # Nested dict: prefix each flattened child name with this key.
        yield from (key + "." + nested for nested in flatten_keys(value))

# Load Data

In [None]:
# Read config
# NOTE: `dir` shadows the builtin of the same name; it is kept because the
# plotting cell below reads this variable when building its output paths.
dir = "../data/path/to/experiment/folder*"
configs = {}
dataframes = {}
run_metrics = {}
with open(f"{dir}/config.yaml", "r") as f:
    config = yaml.safe_load(f)
print("Config: ", config)
rng = jax.random.PRNGKey(config["random_seed"])


def _named_leaves(tree, prefix=""):
    """Yield ``(dotted_name, value)`` for every non-dict value of a nested dict.

    Names and values come from the same traversal, so they always line up.
    (Pairing separately flattened keys with ``jax.tree_leaves`` output does
    not: JAX flattens dictionaries in sorted-key order, while plain dict
    iteration follows insertion order.)
    """
    for key, value in tree.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            yield from _named_leaves(value, prefix=f"{name}.")
        else:
            yield name, value


for algorithm in [
    "hpgd",
    "zero_order"
]:
    print("Loading: ", algorithm)
    try:
        # SECURITY: pickle.load can execute arbitrary code; only point `dir`
        # at experiment folders you produced yourself.
        with open(f"{dir}/metrics_{algorithm}.pkl", "rb") as f:
            metrics = pickle.load(f)
        with open(f"{dir}/update_dict_{algorithm}.pkl", "rb") as f:
            update_dict = pickle.load(f)
    except FileNotFoundError:
        print(f"File not found for {algorithm}")
        continue

    # Cast every metric array to half precision (presumably to save memory).
    # float16 cannot represent magnitudes above 65504; larger values become inf.
    # jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
    metrics = jax.tree_util.tree_map(lambda x: x.astype(jnp.float16), metrics)

    # Mean over the last axis of the final entry of V_UL; reused by both
    # the "returns" and "std" columns below.
    final_returns = jnp.mean(metrics["V_UL"][..., -1:], -1)

    has_sweep = len(update_dict) > 0
    if has_sweep:
        # One index level per swept setting, named by its dotted key path.
        named_leaves = list(_named_leaves(update_dict))
        update_dict_names = [name for name, _ in named_leaves]
        # Arrays are converted with .tolist(); plain sequences are copied as-is.
        indices = [
            leaf.tolist() if hasattr(leaf, "tolist") else list(leaf)
            for _, leaf in named_leaves
        ]
        # List-valued entries are stringified so they can serve as index labels.
        indices = [[str(x) if isinstance(x, list) else x for x in l] for l in indices]
        df = pd.DataFrame(
            {
                # NOTE(review): "returns" uses mean while "std" uses nanstd, so
                # NaNs propagate into the former but not the latter — confirm intent.
                "returns": jnp.mean(final_returns, 0),
                "std": jnp.nanstd(final_returns, 0),
                "returns_start": jnp.mean(jnp.mean(metrics["V_UL"][...,:10], -1), 0),
                "returns_end": jnp.mean(jnp.mean(metrics["V_UL"][...,-10:], -1), 0),
            },
            index=pd.MultiIndex.from_arrays(
                indices,
                names=update_dict_names
            )
        )
        # Ratio of the last-10 average to the first-10 average along the last axis.
        df["returns_ratio"] = df["returns_end"] / df["returns_start"]
    else:
        # No swept settings: a single summary row.
        df = pd.DataFrame(
            {
                "returns": jnp.mean(final_returns, 0),
                "std": jnp.nanstd(final_returns, 0)
            },
            index=pd.Index(["all"])
        )
    dataframes[algorithm] = df
    run_metrics[algorithm] = metrics
    configs[algorithm] = config
    print("Algorithm: ", algorithm)
    if has_sweep:
        display(df.sort_values("returns_ratio", ascending=False))
    else:
        display(df)

# Experiment plots

In [None]:
# NOTE(review): `rolling` is not referenced anywhere in this cell — confirm
# whether the smoothing window is meant to be passed to plot_UL_rewards.
rolling = 100  # Rolling mean smoothing
# Only the keys (and their order) are used below, to decide which algorithms
# to look up in `run_metrics`.
# NOTE(review): the "hpgd" label here ("Bilevel") differs from the legend text
# actually drawn ("HPGD", see ALGO_LEGEND) — confirm which one is intended.
algorithms = {
    "hpgd":  r"\textsc{Bilevel}",
    "benchmark": r"\textsc{AMD}",
    "zero_order": r"\textsc{Zero-order}",
}
y_label_map = {
    "V_UL": r"Upper-level objective: $F$",
    "V_LL": r"Lower-level objective: $J(\pi)$",
    "tax": r"Tax rates",
}

# Per-algorithm styling: legend text, line style, color and draw order.
ALGO_LEGEND = {
    "hpgd": r"\textsc{HPGD}",
    "benchmark": r"\textsc{AMD}",
    "zero_order": r"\textsc{Zero-order}",
}
ALGO_LINE_STYLE = {"hpgd": "-", "benchmark": "--", "zero_order": "-."}
ALGO_COLOR = {"hpgd": colors[0], "benchmark": colors[1], "zero_order": colors[2]}
ALGO_ZORDER = {"hpgd": 3, "benchmark": 2, "zero_order": 1}

# Tax-rate curves: key suffix -> (legend symbol, line style). In the tax plot
# the color and draw order identify the algorithm, the line style the tax.
TAX_SERIES = {
    "income_tax": (r"$x$", "-"),
    "vat_1": (r"$y_1$", "--"),
    "vat_2": (r"$y_2$", "-."),
    "vat_3": (r"$y_3$", ":"),
}
# (algorithm, series) pairs in legend order: all series of one algorithm first.
TAX_CURVES = [(algo, series) for algo in ALGO_LEGEND for series in TAX_SERIES]

for metric_name in [
    "V_LL",
    "V_UL",
    "tax",
]:
    # Curves to draw for this metric, keyed by the names the style dicts use.
    UL_rewards = {}
    for algorithm in algorithms:
        if algorithm not in run_metrics:
            print("Skipping: ", algorithm)
            continue
        metrics = run_metrics[algorithm]
        if metric_name in ["V_UL", "V_LL"]:
            # NOTE(review): the array layout is not visible from this notebook;
            # the slice below keeps the last 10 entries of axis 1 and averages
            # over axes 0 and 1 — confirm axis 1 is the step axis.
            arr = metrics[metric_name]
            print(f"{algorithm} {metric_name}: ", jnp.mean(arr[:,-10:,...], axis=(0, 1)))
            UL_rewards[algorithm] = arr
        elif metric_name == "tax":
            arr = metrics["income_tax"]
            UL_rewards[f"{algorithm}_income_tax"] = arr
            # One curve per entry of the trailing axis of "vat" (1-based names).
            arr = metrics["vat"]
            for i in range(arr.shape[-1]):
                UL_rewards[f"{algorithm}_vat_{i+1}"] = arr[...,i]
        else:
            # Explicit error instead of `assert False`, which is stripped under -O.
            raise ValueError(f"Unknown metric: {metric_name}")

    if not UL_rewards:
        # None of the algorithms were loaded; there is nothing to draw or save.
        print("No data to plot for: ", metric_name)
        continue

    if metric_name in ["V_UL", "V_LL"]:
        # Fresh dict copies are passed so a call cannot alter the shared tables.
        plot_UL_rewards(
            UL_rewards,
            figsize=(10, 6),
            legend_position={
                "loc": "lower center",
                "bbox_to_anchor": (0.5, -0.02),
                "ncol": 3,
            },
            legend_names=dict(ALGO_LEGEND),
            line_styles=dict(ALGO_LINE_STYLE),
            algo_colors=dict(ALGO_COLOR),
            zorder=dict(ALGO_ZORDER),
            y_ticks=None,
            ylabel= y_label_map[metric_name],
            savefig_path=f"{dir}/{metric_name}.pdf"
        )
    else:
        plot_UL_rewards(
            UL_rewards,
            figsize=(11, 6),
            legend_names={
                f"{algo}_{series}": f"{ALGO_LEGEND[algo]} {TAX_SERIES[series][0]}"
                for algo, series in TAX_CURVES
            },
            line_styles={
                f"{algo}_{series}": TAX_SERIES[series][1]
                for algo, series in TAX_CURVES
            },
            algo_colors={
                f"{algo}_{series}": ALGO_COLOR[algo]
                for algo, series in TAX_CURVES
            },
            zorder={
                f"{algo}_{series}": ALGO_ZORDER[algo]
                for algo, series in TAX_CURVES
            },
            y_ticks=None,
            ylim=(0.1, 0.8),
            ylabel= y_label_map[metric_name],
            savefig_path=f"{dir}/{metric_name}.pdf"
        )