In [None]:
import pickle

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import orbax
import orbax.checkpoint  # `import orbax` alone may not expose `orbax.checkpoint` in newer releases
import pandas as pd
import yaml

from src.algorithms.value_iteration_and_prediction import general_value_iteration
from src.environments.ConfigurableFourRooms import map_project
from train_stochastic_bilevel_opt import environment_setup
from visualization_functions import plot_UL_rewards, plot_incentive_grid

# Plot styling used by every figure below: thick axes, ticks and lines, large
# LaTeX-rendered text, and the Dark2 qualitative palette as the default color cycle.
linewidth = 3
colors = plt.cm.Dark2.colors
tick_size = 2 * linewidth
plt.rcParams.update({
    "font.size": 26,
    "text.usetex": True,
    "axes.linewidth": linewidth,
    "xtick.major.width": linewidth,
    "ytick.major.width": linewidth,
    "xtick.major.size": tick_size,
    "ytick.major.size": tick_size,
    "axes.prop_cycle": plt.cycler(color=colors),
    "lines.linewidth": linewidth,
})
# Set to False to only display figures instead of also writing PDFs to the results directory.
save_figures = True

# Load the data

In [None]:
# Results directory of the grid search and analysis settings.
# NOTE: `dir` shadows the builtin of the same name; kept because later cells use it.
dir = "../data/results/4Rooms_grid_search_lambda_0_001"
rolling_window = 100  # rolling-window length passed to the learning-curve plots
plot_top = 10

# Per-algorithm results keyed by algorithm name. An algorithm is added only once all of
# its result files have loaded successfully.
grid_search_outputs = {}
grid_search_params = {}
grid_search_train_states = {}
summary_dfs = {}

with open(f"{dir}/config.yaml", "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
print(config)
rng = jax.random.PRNGKey(config["random_seed"])

# Environment Setup
print("Environment Setup")
env, env_params, incentive_train_state, config_incentive = environment_setup(rng, config)

algorithms = ["bilevel", "benchmark", "zero_order"]
for algo in algorithms:
    print(f"\n--- {algo} ---")
    # The HPGD ("bilevel") run was saved under the on-policy file names.
    if algo == "bilevel":
        metrics_path = f"{dir}/metrics_on_policy.pkl"
        grid_path = f"{dir}/grid_search.pkl"
        ckpt_path = f"{dir}/checkpoint_incentive_on_policy"
    else:
        metrics_path = f"{dir}/metrics_{algo}.pkl"
        grid_path = f"{dir}/grid_search_{algo}.pkl"
        ckpt_path = f"{dir}/checkpoint_incentive_{algo}"

    try:
        # NOTE: pickle.load can execute arbitrary code -- only open result files you trust.
        with open(metrics_path, "rb") as f:
            outputs = pickle.load(f)
        with open(grid_path, "rb") as f:
            grid_params = pd.DataFrame(pickle.load(f))
        print("Output shape: ", outputs["UL_initial_value"].shape)
        print("Grid param keys: ", grid_params.columns)
        train_state = orbax.checkpoint.PyTreeCheckpointer().restore(ckpt_path)
        print("Incentive Train State Loaded")
    except Exception as err:
        # Best effort: report the reason and continue with the remaining algorithms.
        # `Exception` (not a bare except) so that KeyboardInterrupt still stops the cell.
        print(f"Failed to load {algo}: {err!r}")
        continue

    # Register the algorithm only after every artefact loaded, so no dict ever holds a
    # partially loaded algorithm.
    grid_search_outputs[algo] = outputs
    grid_search_params[algo] = grid_params
    grid_search_train_states[algo] = train_state

    # Per run: mean UL value over the last 1000 logged steps. Rows with identical
    # hyperparameters are then averaged, and configurations are sorted best-first.
    param_cols = list(grid_params.columns)
    df = grid_params.copy()
    df["UL_initial_value"] = jnp.mean(outputs["UL_initial_value"][:, -1000:], axis=1)
    df = df.groupby(param_cols).mean().sort_values("UL_initial_value", ascending=False)
    display(df)
    summary_dfs[algo] = df

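# Performance Table

For each value of `incentive_reg_grid`, select each algorithm's best hyperparameter configuration, i.e. the one with the highest mean upper-level value over the last 1000 steps. Then report the mean and standard error of that value across the matching runs.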

In [None]:
# Find the best hyperparameters for each upper-level regularization value (`incentive_reg_grid`).
# Each `summary_dfs[algo]` is sorted by decreasing UL value, so the first row matching a
# regularization value is the best configuration for it. Algorithms that failed to load in
# the previous cell are skipped instead of raising a KeyError.
best_parameters_for_reg_lambda = {}
for algo in algorithms:
    if algo not in summary_dfs:
        continue
    df = summary_dfs[algo]
    # Use pandas' own float64 values instead of `jnp.unique`: with JAX's default 32-bit mode
    # the latter rounds to float32, and e.g. 0.1 would then no longer match the index exactly.
    reg_values = sorted(float(v) for v in grid_search_params[algo]["incentive_reg_grid"].unique())
    # Select by level *name*, so the lookup does not depend on the column order of the grid.
    reg_level = df.index.get_level_values("incentive_reg_grid")
    best_parameters_for_reg_lambda[algo] = {
        reg_lambda: df.index[reg_level == reg_lambda][0] for reg_lambda in reg_values
    }

assert best_parameters_for_reg_lambda, "No algorithm results were loaded."
# Regularization values evaluated by every loaded algorithm (normally one shared grid).
reg_lambdas = sorted(
    set.intersection(*(set(best) for best in best_parameters_for_reg_lambda.values()))
)

# Create a summary table (rows: regularization value, columns: algorithm).
summary_mean = {}
summary_std_error = {}
for algo, best_params in best_parameters_for_reg_lambda.items():
    df_grid_params = grid_search_params[algo]
    ul_values = grid_search_outputs[algo]["UL_initial_value"]
    summary_mean[algo] = {}
    summary_std_error[algo] = {}
    for reg_lambda in reg_lambdas:
        # Rows (runs) that used the best configuration for this regularization value.
        idx = (df_grid_params == best_params[reg_lambda]).all(1).values
        UL_estimate = jnp.mean(ul_values[idx, -1000:], -1)  # one value per matching run
        n_runs = UL_estimate.shape[0]
        summary_mean[algo][reg_lambda] = float(jnp.mean(UL_estimate))  # Average across seeds
        # Standard error of the mean: sample std (ddof=1) / sqrt(n); NaN for a single run.
        summary_std_error[algo][reg_lambda] = float(jnp.std(UL_estimate, ddof=1) / jnp.sqrt(n_runs))
df_summary_mean = pd.DataFrame(summary_mean, index=reg_lambdas)
df_summary_std_error = pd.DataFrame(summary_std_error, index=reg_lambdas)
print("Average Upper-Level Value (Last 1000 steps)")
display(df_summary_mean)
print("Std. Error Upper-Level Value (Last 1000 steps)")
display(df_summary_std_error)

# Visualize the learning curves

In [None]:
# Per-regularization y-axis limits for the learning curves. Values without an entry fall
# back to automatic limits (integer keys match float lookups since 1 == 1.0).
y_lim = {
    1: (0.2, 1.1),
    2: (0.0, 1.4),
    3: (0.0, 0.8),
    4: (0.0, 0.8),
    5: (0.0, 0.7),
}

# Styling shared by every `plot_UL_rewards` call below.
legend_position = {
    "loc": "lower center",
    "bbox_to_anchor": (0.5, -0.02),
    "ncol": 3,
}
legend_names = {
    "bilevel": r"\textsc{HPGD}",
    "benchmark": r"\textsc{AMD}",
    "zero_order": r"\textsc{Zero-order}",
}
line_styles = {
    "bilevel": "-",
    "benchmark": "--",
    "zero_order": (0, (2, 4, 2, 4)),
}
algo_colors = {
    "bilevel": colors[0],
    "benchmark": colors[1],
    "zero_order": colors[2],
}
zorder = {
    "bilevel": 3,
    "benchmark": 2,
    "zero_order": 1,
}

# Iterate over the regularization values available for every loaded algorithm, instead of
# relying on `grid_params` leaked from the last iteration of the loading loop.
assert best_parameters_for_reg_lambda, "Run the performance-table cell first."
shared_reg_lambdas = sorted(
    set.intersection(*(set(best) for best in best_parameters_for_reg_lambda.values()))
)
for reg_lambda in shared_reg_lambdas:
    print(f"\n--- Reg_lambda: {reg_lambda} ---")
    # UL curves of the rows that used each algorithm's best configuration.
    input_data = {}
    for algo, best_params in best_parameters_for_reg_lambda.items():
        idx = (grid_search_params[algo] == best_params[reg_lambda]).all(1).values  # Shape: (n_grid_params,)
        input_data[algo] = grid_search_outputs[algo]["UL_initial_value"][idx]

    plot_UL_rewards(
        input_data,
        figsize=(12, 6),
        rolling_window=rolling_window,
        savefig_path=f"{dir}/UL_rewards_grid_search_reg_lambda_{reg_lambda}.pdf" if save_figures else None,
        xlim=(0, 10_000),
        ylim=y_lim.get(reg_lambda),
        legend_position=legend_position,
        legend_names=legend_names,
        line_styles=line_styles,
        algo_colors=algo_colors,
        zorder=zorder,
    )
    plt.close()
# No trailing plt.clf(): once every figure is closed it would create (and display) an empty figure.

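# Visualize the penalty maps

For each regularization value, plot each algorithm's learned incentive grid for one seed (`seed_idx`). Overlaid on the grid is the greedy path of the lower-level best-response policy, from the first start position towards each goal.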

In [None]:
seed_idx = 0  # which of the runs sharing the best configuration to visualize

# Subplot titles.
label_names = {
    "bilevel": r"\textsc{HPGD}",
    "benchmark": r"\textsc{AMD}",
    "zero_order": r"\textsc{Zero-order}",
}
max_path_steps = 30  # cap on arrows drawn per goal, in case the greedy path never reaches it

# Quantities that depend neither on the algorithm nor on the regularization value.
config_lower_level = config["lower_optimisation"]
init_probs = env.state_initialization_distribution(env_params.state_initialization_params).probs
init_probs_mask = init_probs > 1e-8  # start positions with non-negligible probability
init_positions = env.available_init_pos[init_probs_mask]
n_init_states = int(init_probs_mask.sum())

# Iterate over the regularization values available for every loaded algorithm, instead of
# relying on `grid_params` leaked from the last iteration of the loading loop.
assert best_parameters_for_reg_lambda, "Run the performance-table cell first."
shared_reg_lambdas = sorted(
    set.intersection(*(set(best) for best in best_parameters_for_reg_lambda.values()))
)
for reg_lambda in shared_reg_lambdas:
    # Build new nested dicts rather than `config.copy()`: a shallow copy shares
    # config["upper_optimisation"], so the override would silently mutate `config` itself.
    # NOTE(review): incentive_reg_param is fixed to 1.0 for every reg_lambda -- confirm this is
    # intended (e.g. a common scale for the plotted incentives) rather than `reg_lambda`.
    config_tmp = {
        **config,
        "upper_optimisation": {**config["upper_optimisation"], "incentive_reg_param": 1.0},
    }
    print(f"\n--- Reg_lambda: {reg_lambda} ---")

    # Incentive parameters of the runs that used each algorithm's best configuration.
    input_data = {}
    for algo, best_params in best_parameters_for_reg_lambda.items():
        idx = (grid_search_params[algo] == best_params[reg_lambda]).all(1).values  # Shape: (n_grid_params,)
        input_data[algo] = jax.tree_util.tree_map(
            lambda x: jnp.array(x)[idx], grid_search_train_states[algo]["params"]
        )

    # squeeze=False keeps `axes` an array even when only one algorithm was loaded.
    fig, axes = plt.subplots(
        1, len(input_data), figsize=(21, 7), constrained_layout=True, squeeze=False
    )
    axes = axes[0]
    for ax, (algo, incentive_params) in zip(axes, input_data.items()):
        print(f"\n--- {algo} ---")
        ax.set_title(label_names[algo])
        pcm = plot_incentive_grid(
            env,
            env_params,
            incentive_params["params"]["weights"][seed_idx],
            config_incentive["coordinates"],
            config_tmp,
            verbose=False,
            plot_input=(fig, ax),
            cmap="PuRd_r",
        )

        # Lower-level best response to this algorithm's incentives (selected seed).
        env_params_viz = env_params.replace(
            incentive_params=jax.tree_util.tree_map(lambda x: x[seed_idx], incentive_params)
        )
        q_final, _ = general_value_iteration(
            env,
            env_params_viz,
            config_lower_level["discount_factor"],
            n_policy_iter=config_lower_level["n_policy_iter"],
            n_value_iter=config_lower_level["n_value_iter"],
            regularization=config_lower_level["regularization"],
            reg_lambda=config_lower_level["reg_lambda"],
            return_q_value=True,
        )
        br_policy = jax.nn.softmax(q_final / config_lower_level["reg_lambda"], axis=-1)  # Shape: (n_goals, n_states, n_actions)

        # Draw the greedy path from the first start position towards each goal: at every step
        # take the most probable action that actually changes the position (skipping wall bumps).
        for j in range(len(env.available_goals)):
            goal_pos = jnp.array(config["environment"]["available_goals"][j])
            pos = init_positions[0]
            for _ in range(max_path_steps):
                matches = jnp.where(jnp.all(env.coords == pos[None, :], 1))[0]
                if matches.size == 0:  # position is not a known state; stop this path
                    print(pos)
                    break
                pos_idx = matches[0]
                for action in jnp.argsort(br_policy[j, pos_idx])[::-1]:
                    action_direction = env.directions[action]
                    new_pos = map_project(env.env_map, pos, pos + action_direction)
                    if not jnp.all(new_pos == pos):
                        break
                ax.arrow(
                    pos[1],
                    pos[0],
                    action_direction[1] / 2.0,
                    action_direction[0] / 2.0,
                    head_width=0.1,
                    head_length=0.1,
                    linewidth=linewidth,
                    fc="gray",
                    ec="gray",
                    alpha=0.9,
                )
                pos = new_pos
                if jnp.all(pos == goal_pos):
                    break

        # Goal labels G^1, G^2, ... (braced exponent so multi-digit indices render correctly).
        for j in range(len(env.available_goals)):
            goal_pos = config["environment"]["available_goals"][j]
            ax.annotate(
                rf"$\textbf{{G}}^{{{j + 1}}}$",
                xy=(goal_pos[1], goal_pos[0]),
                xycoords="data",
                xytext=(goal_pos[1] - 0.3, goal_pos[0] + 0.25),
            )
        # Start labels. The exponent must sit outside \textbf: inside it, `^` is in text
        # mode, which is a LaTeX error under usetex.
        for k, pos in enumerate(init_positions, start=1):
            ax.annotate(
                rf"$\textbf{{S}}^{{{k}}}$" if n_init_states > 1 else r"$\textbf{S}$",
                weight="bold",
                xy=(pos[1], pos[0]),
                xycoords="data",
                xytext=(pos[1] - 0.3, pos[0] + 0.25),
            )
        # Upper-level reward marker at the target state.
        reward_cfg = config["upper_optimisation"]["reward_function"]
        target_pos = reward_cfg["target_state"]
        ax.annotate(
            r"$\textbf{+1}$" if reward_cfg["type"] == "positive" else r"$\textbf{-1}$",
            weight="bold",
            xy=(target_pos[1], target_pos[0]),
            xycoords="data",
            xytext=(target_pos[1] - 0.3, target_pos[0] + 0.25),
        )
        # Softmax mass on the last incentive weight, reported as the "unused" percentage.
        unused_pct = 100 * jax.nn.softmax(incentive_params["params"]["weights"][seed_idx])[-1]
        print(f"Unused percentage: {unused_pct:.4f}%")

    cbar = fig.colorbar(pcm, ax=axes.ravel().tolist(), shrink=0.8)
    cbar.outline.set_visible(False)
    # No plt.subplots_adjust here: it is incompatible with constrained_layout (recent
    # matplotlib ignores it with a warning), and the saved PDF is cropped with bbox_inches="tight".
    if save_figures:
        fig.savefig(f"{dir}/incentive_grid_grid_search_seed_{seed_idx}_reg_lambda_{reg_lambda}.pdf", bbox_inches='tight')
    plt.show()
    plt.close(fig)