In [None]:
import math
import warnings
from functools import partial

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import wandb

%matplotlib inline
from carbs import CARBS
from carbs import ObservationInBasic
from carbs import get_pareto_curve_plot
from carbs import load_latest_checkpoint_from_wandb_run
from carbs import observation_group_cost
from carbs import observation_group_output
from matplotlib import MatplotlibDeprecationWarning
from matplotlib import cm
from matplotlib.ticker import LogFormatter
from scipy.interpolate import interp1d
from sklearn.linear_model import LinearRegression

from research.quarantine.abe.plot_helpers import set_axes_style

In [None]:

# Pull the complete logged history of the CARBS sweep run from W&B.
# `run.history()` returns a *sampled* subset (500 rows by default), which silently drops
# observations on long sweeps; `run.scan_history()` iterates over every logged row.
api = wandb.Api()
run_path = "sourceress/abe__bones/2i8gnlf9"
run = api.run(run_path)
history_df = pd.DataFrame(list(run.scan_history()))
# Some cells come back as the string "Infinity" (presumably how non-finite floats are stored);
# replace them with a float so the numeric columns can be plotted.
# NOTE(review): this maps "Infinity" to *negative* infinity — confirm that is intended rather
# than float("inf").
history_df = history_df.replace("Infinity", float("-inf"))

In [None]:

# Restore the CARBS optimizer state from the newest checkpoint stored with the run.
carbs_checkpoint_path = load_latest_checkpoint_from_wandb_run(run_path)
carbs = CARBS.load_from_file(carbs_checkpoint_path)

# Names of the searched real-valued parameters, and the axis scale each one is plotted on.
real_space_by_name = carbs._real_number_space_by_name
search_vars = [name for name in real_space_by_name]
search_space_scale = {name: space.plot_scale for name, space in real_space_by_name.items()}

In [None]:

# Performance of every observation over the course of the search.

is_best_shown = True  # overlay the best single observation seen so far
is_resampled_shown = True  # overlay mean ± std of the best resampled parameters
is_search_space_shown = False  # NOTE(review): not read by any cell visible in this notebook
performance_min, performance_max = (3, 6)  # shared colour / y-axis range for the performance plots

# Reverse viridis when better_direction_sign <= 0 so that lighter yellow marks better outputs.
cmap = sns.color_palette("viridis" if carbs.params.better_direction_sign > 0 else "viridis_r", as_cmap=True)

output_observation_df = history_df[["observation_count", "observation/output"]].dropna()
observation_x, observation_y = output_observation_df.to_numpy().T

fig, ax = plt.subplots()
if is_resampled_shown:
    resampled_df = history_df[
        ["observation_count", "best_resampled_observation/output_mean", "best_resampled_observation/output_std_dev"]
    ].dropna()
    resampled_x, resampled_mean, resampled_std = resampled_df.to_numpy().T
    ax.plot(resampled_x, resampled_mean, color="green", linewidth=2, label="best parameters mean", linestyle="dotted")
    # The shaded band spans mean ± one standard deviation (it is not the variance).
    ax.fill_between(
        resampled_x,
        resampled_mean - resampled_std,
        resampled_mean + resampled_std,
        color="green",
        alpha=0.1,
        label="best parameters ±1 std dev",
    )
if is_best_shown:
    output_best_observation_df = history_df[["observation_count", "best_observation/output"]].dropna()
    best_observation_x, best_observation_y = output_best_observation_df.to_numpy().T
    ax.plot(best_observation_x, best_observation_y, linestyle="dashed", linewidth=2, label="best single observation")
ax.scatter(
    observation_x,
    observation_y,
    c=observation_y,
    s=20,
    label="observation",
    cmap=cmap,
    vmin=performance_min,
    vmax=performance_max,
)
ax.set_title("Combined performance metric")
ax.set_xlabel("Observation count")
ax.set_ylabel("Performance")
ax.set_ylim(performance_min, performance_max)
ax.legend()
plt.show()

In [None]:

"""
Convergence by parameter

Color matches above plot: lighter yellow = better performance.
"Best" observation shows input parameter for the best output performance so far
"""
for search_var in search_vars[:]:
    search_var_observation_df = history_df[
        ["observation_count", f"observation/{search_var}", "observation/output"]
    ].dropna()
    observation_x, observation_y, observation_z = search_var_observation_df.to_numpy().T

    if is_best_shown:
        search_var_best_observation_df = history_df[["observation_count", f"best_observation/{search_var}"]].dropna()
        best_observation_x, best_observation_y = search_var_best_observation_df.to_numpy().T
        plt.plot(
            best_observation_x, best_observation_y, linestyle="dashed", linewidth=2, label="best single observation"
        )
    if is_resampled_shown:
        search_var_resampled_observation_df = history_df[
            ["observation_count", f"best_resampled_observation/{search_var}"]
        ].dropna()
        resampled_observation_x, resampled_observation_y = search_var_resampled_observation_df.to_numpy().T
        plt.plot(
            resampled_observation_x,
            resampled_observation_y,
            linestyle="dotted",
            color="green",
            linewidth=2,
            label="best parameter value",
        )
    search_var_name = search_var.replace("_", " ").replace("pdrop", "dropout")

    plt.scatter(
        observation_x,
        observation_y,
        c=observation_z,
        s=20,
        label="observation",
        cmap=cmap,
        vmin=performance_min,
        vmax=performance_max,
    )
    plt.title(f"{search_var_name} convergence")
    plt.xlabel("Observation count")
    plt.ylabel(search_var)
    plt.yscale(search_space_scale[search_var])
    plt.legend()
    plt.show()

In [None]:

# Pareto curve plot: the Pareto groups found so far, drawn by carbs.get_pareto_curve_plot.

# NOTE(review): meaning of the positional `True` flag is not visible here — confirm against CARBS.
pareto_groups = carbs._get_pareto_groups(True)

get_pareto_curve_plot(carbs.observations_in_basic, pareto_groups, obs_count=carbs.observation_count)
plt.ylim(performance_min, performance_max)
# Render explicitly so the cell does not also echo the (ymin, ymax) tuple returned by plt.ylim.
plt.show()

In [None]:


# Fit the run's surrogate model on every observation recorded so far (in CARBS basic space).
surrogate_model = carbs.get_surrogate_model()
fit_observations = carbs.observations_in_basic
surrogate_model.fit_observations(fit_observations)

In [None]:

# Sample `num_uniform_inputs` points along the Pareto front, evenly spaced in log(cost).
# The Pareto inputs are approximated by a least-squares linear fit of input vs. log(cost)
# (not an interpolation), so the sampled curve is a smoothed trend through the front.
num_uniform_inputs = 30
num_contour_levels = 10

pareto_costs = [observation_group_cost(group) for group in pareto_groups]
pareto_logcosts = [math.log(cost) for cost in pareto_costs]
pareto_outputs = [observation_group_output(group) for group in pareto_groups]
min_pareto_cost, max_pareto_cost = min(pareto_costs), max(pareto_costs)
uniform_logcosts = np.linspace(min(pareto_logcosts), max(pareto_logcosts), num=num_uniform_inputs)
# Each Pareto group is represented by the input of its first observation.
pareto_inputs = torch.stack([group[0].real_number_input for group in pareto_groups], dim=0)

reg = LinearRegression()
# sklearn needs a NumPy array; convert explicitly (also safe if the tensor lives on a GPU).
reg.fit(np.array(pareto_logcosts)[:, None], pareto_inputs.detach().cpu().numpy())

uniform_pareto_inputs = torch.from_numpy(reg.predict(uniform_logcosts[:, None])).float()

# Evaluate the fitted Pareto inputs on the surrogate.
uniform_surrogate_outputs = surrogate_model.observe_surrogate(uniform_pareto_inputs)

# Keep only observations whose cost lies within the cost range of the Pareto front
# (bounds hoisted so they are not recomputed for every observation).
observations_in_basic = [
    obs for obs in carbs.observations_in_basic if min_pareto_cost <= obs.cost <= max_pareto_cost
]
obs_cost = [obs.cost for obs in observations_in_basic]
obs_output = [obs.output for obs in observations_in_basic]
vmin, vmax = min(pareto_outputs), max(pareto_outputs)
contour_levels = np.linspace(vmin, vmax, num_contour_levels)

In [None]:


# Per-observation marker styling for the parameter plots.
# (matplotlib `mpl` / `cm` / `plt` come from the top-of-notebook import cell.)

# Colour scale over the range of Pareto outputs.
scalar_map = cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=vmin, vmax=vmax), cmap=cmap)

# Piecewise-linear lookup: log(cost) -> fitted Pareto input vector in basic space.
interp_pareto_value = interp1d(uniform_logcosts, uniform_pareto_inputs, axis=0, fill_value="extrapolate")
# Euclidean distance of each observation's input from the fitted front at the same cost.
observation_pareto_distance = [
    torch.norm(torch.from_numpy(interp_pareto_value(np.log(obs.cost))).float() - obs.real_number_input).item()
    for obs in observations_in_basic
]
search_radius = carbs.params.initial_search_radius
# Closeness weight r / (r + d): 1 on the front, decaying with distance (assumes r > 0).
rescaled_observation_pareto_distance = [search_radius / (search_radius + d) for d in observation_pareto_distance]
observation_marker_size = [200 * weight for weight in rescaled_observation_pareto_distance]
# RGB encodes the output; alpha is twice the closeness weight, capped at 1.
obs_color = [
    scalar_map.to_rgba(output)[:3] + (min(2 * weight, 1),)
    for output, weight in zip(obs_output, rescaled_observation_pareto_distance)
]

In [None]:

def obs_to_key(obs: ObservationInBasic):
    """Return a hashable identity for an observation: its real_number_input values as a tuple."""
    return tuple(obs.real_number_input.tolist())


# Keys of every observation that belongs to some Pareto group.
pareto_set = {obs_to_key(member) for group in pareto_groups for member in group}

# Pareto members get a black marker outline in the parameter plots; all others get none.
obs_is_in_pareto_set = [obs_to_key(obs) in pareto_set for obs in observations_in_basic]
edgecolors = ["black" if in_front else "none" for in_front in obs_is_in_pareto_set]

In [None]:


class CustomLogFormatter(LogFormatter):
    """LogFormatter that labels ticks as plain numbers (e.g. "8") instead of base**exponent."""

    def _num_to_string(self, x, vmin, vmax) -> str:
        # Integral tick values keep exact integer text at any magnitude (e.g. 1048576);
        # anything else (0.5, or a near-integer from float error such as 7.9999999) uses %g
        # instead of being truncated toward zero by int().
        if float(x).is_integer():
            return f"{int(x)}"
        return f"{x:g}"

In [None]:

# Parameter variation along the Pareto front.
# Layout: one panel per searched parameter, with cost (log scale) on x and the parameter value on y.
# - Contours: the surrogate's predicted output when only that parameter is swept and the others
#   stay at the fitted Pareto inputs.
# - Dashed line: the fitted front.
# - Dots: the observations.
# TODO: add failed observations, predict cost from GP model, make red x

# Parameters drawn on a log2 y-axis with plain-number tick labels.
base_two_search_vars = {
    "model.n_layers",
    "model.n_heads",
    "model.kv_size",
    "model.ffw_size",
}
num_search_var_grid_points = 50  # resolution of each per-parameter sweep
# NOTE(review): absolute local output path — consider making it configurable.
figure_path = "/home/user/hyperspace_appendix_plot_2.pdf"


def _surrogate_grid_for_var(search_var_idx, param_from_basic):
    """Sweep one basic-space coordinate of the fitted Pareto inputs through the surrogate.

    The coordinate is swept over the range spanned by the filtered observations, at each of the
    ``num_uniform_inputs`` fitted Pareto points.

    Returns:
        (cost_grid, search_var_grid, output_grid), each of shape
        (num_search_var_grid_points, num_uniform_inputs); ``search_var_grid`` is mapped back to
        parameter space without rounding.
    """
    basic_values = [obs.real_number_input[search_var_idx].item() for obs in observations_in_basic]
    search_var_linspace_in_basic = torch.linspace(
        min(basic_values), max(basic_values), steps=num_search_var_grid_points
    )
    # (grid_points, num_uniform_inputs, real_dim): every grid row starts as a copy of the Pareto inputs,
    input_grid = uniform_pareto_inputs.repeat(num_search_var_grid_points, 1, 1)
    # then the swept coordinate is overwritten with that row's grid value for every Pareto point.
    input_grid[:, :, search_var_idx] = search_var_linspace_in_basic[:, None]
    surrogate_output = surrogate_model.observe_surrogate(input_grid.view(-1, carbs.real_dim))
    grid_shape = (num_search_var_grid_points, num_uniform_inputs)
    cost_grid = surrogate_output.cost_estimate.view(grid_shape).cpu()
    output_grid = surrogate_output.target_estimate.view(grid_shape).cpu()
    # clone(): .cpu() on a CPU tensor returns the same storage, and apply_ below works in place.
    search_var_grid = input_grid[:, :, search_var_idx].cpu().clone()
    search_var_grid.apply_(partial(param_from_basic, is_rounded=False))
    return cost_grid, search_var_grid, output_grid


def _style_search_var_axis(ax, search_var):
    """Label one panel and set its axis scales, x-limits and style."""
    # Drop any dotted prefix, e.g. "model.n_layers" -> "n_layers".
    ax.set_ylabel(search_var.split(".")[-1])
    if search_var in base_two_search_vars:
        ax.set_yscale("log", base=2)
        ax.yaxis.set_major_formatter(CustomLogFormatter(base=2.0, labelOnlyBase=True))
        ax.yaxis.set_minor_formatter(CustomLogFormatter(base=2.0, labelOnlyBase=False, minor_thresholds=(10.0, 0.1)))
    else:
        ax.set_yscale(search_space_scale[search_var])
    # set_xscale installs the log scale's default tick formatters on the x-axis (shared by every
    # panel via sharex=True), so a custom x formatter only takes effect if set after this call.
    ax.set_xscale("log")
    ax.set_xlabel("Cost")
    ax.set_xlim(min(pareto_costs), max(pareto_costs))
    set_axes_style(ax, grid="both")


sel_search_vars = search_vars
# Two panels per row with ceil((n + 1) / 2) rows, so at least one slot is always spare.
# The colour bar goes in the last slot, and spare slots are deleted after plotting.
nrows = math.ceil((1 + len(sel_search_vars)) / 2)
fig, axs = plt.subplots(nrows=nrows, ncols=2, figsize=(14, 4 * nrows), sharex=True)  # 4 in per row
fig.tight_layout()
fig.subplots_adjust(hspace=0.2, wspace=0.2)
axs = axs.flatten()

with warnings.catch_warnings():
    # Scoped to this figure instead of silencing deprecations for the rest of the session.
    warnings.simplefilter("ignore", MatplotlibDeprecationWarning)
    for search_var_idx, search_var in enumerate(sel_search_vars):
        # NOTE(review): assumes search_vars order matches the coordinate order of real_number_input.
        param_from_basic = carbs._real_number_space_by_name[search_var].param_from_basic
        obs_search_var = [
            param_from_basic(obs.real_number_input[search_var_idx].item()) for obs in observations_in_basic
        ]
        cost_grid, search_var_grid, output_grid = _surrogate_grid_for_var(search_var_idx, param_from_basic)

        ax = axs[search_var_idx]
        ax.contour(
            cost_grid,
            search_var_grid,
            output_grid,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            levels=contour_levels,
        )
        pareto_search_var = [
            param_from_basic(value.item(), is_rounded=False) for value in uniform_pareto_inputs[:, search_var_idx]
        ]
        (pareto_line,) = ax.plot(
            np.exp(uniform_logcosts), pareto_search_var, color="black", linewidth=2, linestyle="dashed"
        )
        scatter_plot = ax.scatter(
            obs_cost,
            obs_search_var,
            c=obs_color,
            s=observation_marker_size,
            edgecolors=edgecolors,
        )
        _style_search_var_axis(ax, search_var)

    # Pass the handles explicitly: a labels-only legend pairs labels with the figure's artists in
    # creation order, which attaches them to the contour lines instead of these two artists.
    fig.legend([pareto_line, scatter_plot], ["Pareto front (fit)", "Observations"])
    # A bare ScalarMappable belongs to no Axes, so name the Axes to steal space from. Relying on
    # matplotlib's gca() fallback is deprecated since 3.6 and an error in later versions.
    # gca() here would normally be the last-created Axes, axs[-1], which is always a spare slot.
    cbar = fig.colorbar(mappable=scalar_map, ax=axs[-1], location="bottom")
    for ax in axs[len(sel_search_vars) :]:
        fig.delaxes(ax)
    cbar.set_label("Validation Cross Entropy")
    fig.savefig(figure_path, bbox_inches="tight")
    plt.show()