In [None]:
# Load the autoreload extension and use mode 2, so edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Turn off the Jedi-based tab completer.
%config Completer.use_jedi = False

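Run this notebook with `ScientificValueAgent/production_figures` as the working directory; the relative paths used below (`..`, `../results/...`, `figures_sine2phase/`) are resolved against it.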

In [None]:
import sys

# Put the parent directory on the import path (presumably so that the `sva`
# package imported below resolves from the repository root -- confirm).
# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from tqdm import tqdm
from scipy.spatial import distance_matrix

In [None]:
from sva import utils

Set some plotting defaults.

In [None]:
# Apply the plotting defaults from sva.utils before any figure is created.
utils.set_defaults()

# Two-phase sine result

In [None]:
from sva.postprocessing import read_data, parse_results_by_acquisition_function
from sva.truth.sine2phase import (
    phase_1_sine_on_2d_raster,
    truth_sine2phase,
    sine2phase_interpolant_2d,
    sine2phase_compute_metrics_all_acquisition_functions_and_LTB,
    sine2phase_residual_2d_phase_relative_mae,
)
from sva.truth.common import limited_time_budget, get_phase_plot_info

Get the results. This is a few GB of raw data and thus could take a little time to load.

In [None]:
# Load the raw results from the 23-05-02 sine2phase directory. The path is
# relative to the notebook's working directory.
results = read_data("../results/results_23-05-02-sine2phase/")

In [None]:
# Regroup the loaded results so they can be looked up by acquisition-function
# name (the cells below index this with e.g. "UpperConfidenceBound10").
results_by_acqf = parse_results_by_acquisition_function(results)

## Core manuscript figure

In [None]:
# Grid axes and the map Z that is drawn as the background of subfigure (a).
x, y, Z = get_phase_plot_info(phase_1_sine_on_2d_raster)
X, Y = np.meshgrid(x, y)

# Magnitude of the numerical gradient of Z: stack the per-axis derivatives
# returned by np.gradient and take the Euclidean norm across them.
gradZ = np.asarray(np.gradient(Z))
gradZ = np.sqrt(np.square(gradZ).sum(axis=0))

In [None]:
# Settings shared by the sub-figure cells below.
scale, lw = 1, 0.5  # figure-size multiplier, line width of the sine curve
extent = (x[0], x[-1], y[0], y[-1])  # passed to imshow as (left, right, bottom, top)

# Abscissa on which the sine curve of subfigure (b) is evaluated.
g = np.linspace(0, 1, num=200)

### Subfigure (a)

In [None]:
# Subfigure (a): the map Z on the (x_1, x_2) raster with the two phase regions
# labelled.
color = "black"

fig, ax = plt.subplots(1, 1, figsize=(2*scale, 2*scale), sharex=True, sharey=True)

# Background image of Z; `extent` places it in data coordinates.
im = ax.imshow(Z, interpolation='bilinear', origin='lower', cmap=mpl.cm.seismic, extent=extent, alpha=0.6)

# Panel tag and region labels, positioned in axes-fraction coordinates.
ax.text(0.05, 0.95, r"(a)", ha="left", va="top", transform=ax.transAxes, color=color)
ax.text(0.95, 0.95, "Phase 1", ha="right", va="top", transform=ax.transAxes, color=color)
ax.text(0.05, 0.05, "Phase 2", ha="left", va="bottom", transform=ax.transAxes, color=color)

ax.set_xticks([0, 1])
ax.set_yticks([0, 1])

utils.set_grids(ax, grid=False)
ax.set_ylabel("$x_2$~[a.u.]")
ax.set_xlabel("$x_1$~[a.u.]")

# savefig does not create missing directories, so make sure the output folder
# exists; otherwise a fresh checkout fails here with FileNotFoundError.
Path("figures_sine2phase").mkdir(parents=True, exist_ok=True)
plt.savefig("figures_sine2phase/sine2phase_subfigure_a.svg", dpi=300, bbox_inches="tight")
# plt.show()

### Subfigure $p(x)$ colorbar

In [None]:
# Stand-alone colorbar for p(x): draw the same image as subfigure (a), attach
# a colorbar to it, then remove the image axes before saving (presumably
# leaving only the colorbar in the output -- depends on utils.add_colorbar).
fig, ax = plt.subplots(1, 1, sharex=True, sharey=True, figsize=(2*scale, 2*scale))

im = ax.imshow(
    Z,
    extent=extent,
    origin="lower",
    interpolation="bilinear",
    cmap=mpl.cm.seismic,
    alpha=0.6,
)

cbar = utils.add_colorbar(im, aspect=20)
cbar.set_ticks([0, 1])
cbar.set_label(r"$p(\mathbf{x})$", labelpad=-10)

ax.remove()

plt.savefig("figures_sine2phase/sine2phase_p_cbar.svg", bbox_inches="tight", dpi=300)
# plt.show()

### Subfigure (b)

In [None]:
# Take the "mu" array from the last record of the first UpperConfidenceBound10
# result, lay it out as a 50 x 50 image and min-max rescale it to [0, 1].
# NOTE(review): this rebinds Z, which the subfigure (a) and p(x)-colorbar
# cells also read; re-running those cells after this one would plot this map
# instead of the phase map.
single_result = results_by_acqf["UpperConfidenceBound10"][0]
Z = single_result._record[-1]["mu"].reshape(50, 50).T
Z = (Z - Z.min()) / (Z.max() - Z.min())
cmap = mpl.cm.magma

In [None]:
# Subfigure (b): rescaled map Z with the sine curve and the sampled points.
fig, ax = plt.subplots(1, 1, sharex=True, sharey=True, figsize=(2*scale, 2*scale))

im = ax.imshow(Z, extent=extent, origin="lower", interpolation="bilinear", cmap=cmap, alpha=0.4)

# Sine curve (presumably the boundary between the two phases -- confirm) and
# a +/- 0.05 band around it; the band is computed but not drawn at present.
sine = 0.5 + np.sin(2.0 * np.pi * g) / 4
sine_upper, sine_lower = sine + 0.05, sine - 0.05

color = "k"
ax.plot(g, sine, color + "--", linewidth=lw)

# Points stored on the selected result, drawn above the image.
# NOTE(review): X rebinds the meshgrid array of the same name created earlier.
X = single_result.data.X
ax.scatter(X[:, 0], X[:, 1], s=0.5, color="black", zorder=3)

ax.text(0.05, 0.95, r"(b)", transform=ax.transAxes, ha="left", va="top")

ax.set(xticks=[0, 1], yticks=[0, 1])

utils.set_grids(ax, grid=False)
ax.set(ylabel="$x_2$~[a.u.]", xlabel="$x_1$~[a.u.]")

plt.savefig("figures_sine2phase/sine2phase_subfigure_b.svg", bbox_inches="tight", dpi=300)
# plt.show()

### Subfigure value colorbar

In [None]:
# Stand-alone colorbar for U(x): draw the same image as subfigure (b), attach
# a colorbar to it, then remove the image axes before saving (presumably
# leaving only the colorbar in the output -- depends on utils.add_colorbar).
fig, ax = plt.subplots(1, 1, sharex=True, sharey=True, figsize=(2*scale, 2*scale))

im = ax.imshow(
    Z,
    extent=extent,
    origin="lower",
    interpolation="bilinear",
    cmap=cmap,
    alpha=0.4,
)

cbar = utils.add_colorbar(im, aspect=20)
cbar.set_ticks([0, 1])
cbar.set_label(r"$U(\mathbf{x})$", labelpad=-10)

ax.remove()

plt.savefig("figures_sine2phase/sine2phase_U_cbar.svg", bbox_inches="tight", dpi=300)
# plt.show()

### Subfigure (c): select metrics

In [None]:
# Result key -> short label used in the legend. Insertion order is kept
# because the plotting loop iterates over this mapping in order.
acquisition_function_name_maps = dict(
    Linear="LTB",
    UpperConfidenceBound10="UCB(10)",
    ExpectedImprovement="EI",
    UpperConfidenceBound1="UCB(1)",
    UpperConfidenceBound20="UCB(20)",
    UpperConfidenceBound100="UCB(100)",
)

In [None]:
# Grid passed as `metrics_grid` to the metrics computation and used as the
# x-axis (N) for the acquisition-function curves: every 6th value from 6 to
# 246, plus 250.
metrics_grid = sorted({*range(6, 251, 6), 250})
# Grid passed as `metrics_grid_linear`; used as the x-axis for the "Linear"
# (LTB / Grid) curve.
metrics_grid_linear = list(range(2, 16))
grid_points = 200
metric = "mse"

In [None]:
# Compute the selected metric for all acquisition functions and the LTB
# baseline, using the two grids defined above and linear interpolation.
# The plotting cell reads the result through its "metrics", "metrics_grid"
# and "metrics_grid_linear" keys. The progress bar is left enabled.
all_metrics_linear = sine2phase_compute_metrics_all_acquisition_functions_and_LTB(
    results_by_acqf,
    metrics_grid=metrics_grid,
    metrics_grid_linear=metrics_grid_linear,
    metric=metric,
    grid_points=grid_points,
    interpolation_method="linear",
    disable_pbar=False,
)

In [None]:
# Legend labels (values of `acquisition_function_name_maps`) to draw in
# subfigure (c); the plotting cell treats None as "draw every entry".
only_plot = ["LTB", "EI", "UCB(10)"]

In [None]:
# Subfigure (c): mean of ln(metric) with a +/- 2 standard deviation band
# (both taken over axis 1 of the metric array) for the selected labels.
fig, ax = plt.subplots(1, 1, figsize=(1.5, 2))

# x-axes for the curves, as returned alongside the metrics.
metrics_grid = all_metrics_linear["metrics_grid"]
metrics_grid_linear = all_metrics_linear["metrics_grid_linear"]

for acquisition_function_name, label in acquisition_function_name_maps.items():

    values = all_metrics_linear["metrics"][acquisition_function_name]

    # Skip entries that were not selected for this figure.
    if only_plot is not None and label not in only_plot:
        continue

    log_values = np.log(values)
    mu = log_values.mean(axis=1)
    sd = 2 * log_values.std(axis=1)

    # The LTB baseline is shown as "Grid" and lives on its own x-axis.
    label = "Grid" if label == "LTB" else label
    x_grid = metrics_grid_linear if label == "Grid" else metrics_grid

    ax.plot(x_grid, mu, label=label)
    ax.fill_between(x_grid, mu - sd, mu + sd, linewidth=0, alpha=0.3)

utils.set_grids(ax)
ax.tick_params(which="minor", bottom=False, top=False)
ax.set_xticks([4, 50, 150, 250])

# The y-axis stays linear because the plotted quantity is already a logarithm.
ax.legend(frameon=False, loc="upper right", fontsize=8)
ax.text(0.05, 0.05, r"(c)", transform=ax.transAxes, ha="left", va="bottom")
ax.set_ylim(bottom=-7.5, top=0.5)

ax.set(xlabel=r"$N$", ylabel=r"$\ln$(MSE)")

plt.savefig("figures_sine2phase/sine2phase_subfigure_c.svg", bbox_inches="tight", dpi=300)
# plt.show()

## Plot of the basis functions

In [None]:
from sva.truth.common import mu_Gaussians

In [None]:
# x-axis for the basis-function plot: 100 evenly spaced points on [-1, 1].
# NOTE(review): E is not passed to mu_Gaussians; the plot below assumes the
# returned arrays are sampled on this same 100-point grid -- confirm.
E = np.linspace(-1.0, 1.0, num=100)

In [None]:
# mu_Gaussians evaluated at arguments 0 and 1; the two results are plotted
# below as "Phase 1" and "Phase 2" respectively.
pure1, pure2 = map(mu_Gaussians, (0, 1))

In [None]:
# Plot both basis functions against E, with the legend outside the axes.
fig, ax = plt.subplots(1, 1, figsize=(3, 2))

for curve, curve_color, curve_label in (
    (pure1, "red", "Phase 1"),
    (pure2, "blue", "Phase 2"),
):
    ax.plot(E, curve, color=curve_color, label=curve_label)

ax.legend(frameon=False, loc="center left", bbox_to_anchor=(1.0, 0.5))
ax.set(xlabel="$Q$~[a.u.]", ylabel="$I(Q)$~[a.u.]")

# NOTE(review): this figure is written as a PDF into the working directory,
# whereas the sub-figures above go to figures_sine2phase/ as SVG -- confirm
# that this is intentional.
plt.savefig("sine2phase_phases.pdf", dpi=300, bbox_inches="tight")