In [None]:
# Reload edited modules (e.g. the local `sva` package) automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# Turn off Jedi-based tab completion in IPython (presumably for speed/reliability — TODO confirm).
%config Completer.use_jedi = False

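Run this notebook with `ScientificValueAgent/figures` as the working directory; the next cell adds the parent directory (`..`) to `sys.path` so that the `sva` package can be imported.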

In [None]:
import sys

# Make the parent directory importable so `from sva import ...` resolves when the
# notebook runs from `ScientificValueAgent/figures` (presumably the repo root holds `sva`).
# Guarded so re-running this cell does not keep appending duplicate ".." entries.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
from collections import Counter
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pickle
from tqdm import tqdm

In [None]:
from sva import utils

Set some plotting defaults.

In [None]:
# Apply the project's plotting defaults from `sva.utils` (presumably matplotlib rcParams
# such as fonts / LaTeX rendering — TODO confirm against sva/utils).
utils.set_defaults()

# Multi-phase one-dimensional XRD results

In [None]:
from sva.postprocessing import read_data, parse_results_by_acquisition_function
from sva.truth.xrd1dim import (
    residual_1d_phase_relative_mae,
    xrd1dim_compute_metrics_all_acquisition_functions_and_LTB,
    _get_1d_phase_fractions,
)

In [None]:
# Load the raw campaign results from the `LGBFS` and `Adam` result directories
# (read in that order).
results_LGBFS, results_Adam = (
    read_data("results/results_22-12-21_xrd1dim"),
    read_data("results/results_22-12-21_xrd1dim_Adam"),
)

In [None]:
# Group each result set by acquisition function; the results are indexed below by
# acquisition-function name and iterated as a sequence of campaigns.
results_by_acqf_LGBFS, results_by_acqf_Adam = (
    parse_results_by_acquisition_function(raw_results)
    for raw_results in (results_LGBFS, results_Adam)
)

In [None]:
# Directory (relative to the working directory) holding the pickled metric caches below.
Path("cache").mkdir(exist_ok=True)
cache = Path("cache")

## Core manuscript figure

### Subfigure (a) and (b)

In [None]:
acquisition_function = "UpperConfidenceBound10"
# Stack the sampled positions into a 2D array, one row per campaign. Columns are
# presumably the experiments in acquisition order; later cells slice them by N.
ucb10_campaigns = results_by_acqf_Adam[acquisition_function]
all_results_Adam = np.array([campaign.data.X.squeeze() for campaign in ucb10_campaigns])

Compute the ground-truth phase fractions on a dense 1000-point grid over $x \in [0, 100]$.

In [None]:
# Dense grid over x in [0, 100] and the phase fractions evaluated on it, transposed so
# that column k corresponds to phase k (plotted below as phases[:, k]).
x_grid = np.linspace(start=0, stop=100, num=1000)
phase_fractions_raw = _get_1d_phase_fractions(x_grid)
phases = phase_fractions_raw.T

Resolve the sampled positions by experiment count $N$: for each $N \geq 3$, pool the first $N$ sampled $x$ values of every campaign.

In [None]:
# For every experiment count N = 3 .. n_steps, pool the first N sampled x values of
# each campaign (row) into one flat array.
n_steps = all_results_Adam.shape[1]
all_results_Adam_n_resolved = []
for n_first in range(3, n_steps + 1):
    all_results_Adam_n_resolved.append(all_results_Adam[:, :n_first].flatten())

In [None]:
# Build one (total_points, 2) float array of (x, N) pairs: column 0 is a sampled x,
# column 1 the experiment count N it belongs to (the n-resolved list starts at N = 3).
all_results_Adam_n_resolved_coordinates = np.concatenate(
    [
        np.column_stack([pooled_x, np.full(len(pooled_x), n_count, dtype=int)])
        for n_count, pooled_x in enumerate(all_results_Adam_n_resolved, start=3)
    ],
    axis=0,
)

In [None]:
# Upper colour limit (counts per bin) for the 2D histogram; bins above it saturate.
vmax = 1_500

In [None]:
fig, axs = plt.subplots(
    2, 1,
    figsize=(3, 3),
    gridspec_kw={"height_ratios": [1, 2]},
    sharex=True,
)

# --- Panel (a): ground-truth phase fractions p(x), one curve per phase ---------------
ax = axs[0]
for phase_idx in range(4):
    ax.plot(x_grid, phases[:, phase_idx], label=str(phase_idx + 1))
# Remember the auto-scaled y-limits: the bands below span them, and they are restored at the end.
axlims = ax.get_ylim()
ax.text(1.05, 1.0, "Phase", ha="left", va="bottom", transform=ax.transAxes)
ax.text(0.025, 0.9, "(a)", ha="left", va="top", transform=ax.transAxes)
ax.legend(frameon=False, bbox_to_anchor=(1.0, 0.5), loc="center left")

# Grey bands over three x-intervals (presumably the linear / quadratic / sharp regions
# whose metrics are computed in the SI over [9, 51], [59, 81] and [88, 92]).
band_y = np.linspace(*axlims, 10)
for band_lo, band_hi in [(10, 50), (60, 80), (88.5, 91.5)]:
    ax.fill_betweenx(band_y, band_lo, band_hi, color="black", alpha=0.1, linewidth=0)

utils.set_grids(ax)
ax.set_ylabel("$p(x)$")
ax.set_ylim(*axlims)

# --- Panel (b): 2D histogram of sampled x vs. experiment count N --------------------
ax = axs[1]
# NOTE(review): 247 y-bins — confirm this equals len(all_results_Adam_n_resolved), the
# number of distinct N values, so each bin holds exactly one N.
ax.hist2d(
    all_results_Adam_n_resolved_coordinates[:, 0],  # sampled x
    all_results_Adam_n_resolved_coordinates[:, 1],  # experiment count N
    bins=[100, 247],
    cmap="viridis",
    vmax=vmax,
    rasterized=True,
)
ax.set_yticks([3, 50, 150, 250])
ax.set_xticks([0, 20, 40, 60, 80, 100])
utils.set_grids(ax)
ax.tick_params(which="minor", left=False, right=False)
ax.set_ylabel(r"$N$")
ax.set_xlabel("$x$~[a.u.]")
ax.text(0.025, 0.9, "(b)", ha="left", va="top", transform=ax.transAxes, color="white")

# plt.savefig("xrd1dim_subfigure_a.svg", dpi=300, bbox_inches="tight")
plt.show()

### Colorbar for subfigure (b)

In [None]:
# Stand-alone colorbar for the 2D histogram in panel (b). The histogram is redrawn only
# to obtain a mappable with the same colormap and normalisation; its axes are then removed.
fig, ax = plt.subplots(1, 1, figsize=(2, 2))

# Use the same binning as panel (b): vmax is fixed, but the norm's lower limit is taken
# from the histogram counts, so different bins could give a colorbar that doesn't match (b).
im = ax.hist2d(
    all_results_Adam_n_resolved_coordinates[:, 0],
    all_results_Adam_n_resolved_coordinates[:, 1],
    bins=[100, 247],
    cmap="viridis",
    vmax=vmax,
)

# hist2d returns (counts, xedges, yedges, image); the image is the colorbar mappable.
cbar = utils.add_colorbar(im[-1], aspect=20)
cbar.set_ticks([0, vmax])
# Raw string: "\g" is an invalid escape in a regular (f-)string and emits a
# Deprecation/SyntaxWarning; the f-prefix had no placeholders anyway.
# NOTE(review): 300 presumably is the number of campaigns (turning raw counts into an
# average count per campaign) — TODO confirm / derive from all_results_Adam.shape[0].
cbar.set_ticklabels([0, r"$\geq$ %i" % int(vmax / 300)])
cbar.set_label(r"Average Counts", labelpad=-10)

ax.remove()

# plt.savefig("xrd1dim_cbar.svg", dpi=300, bbox_inches="tight")
plt.show()

### Subfigure (c): select metrics

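Computing these metrics takes a long time, so the results are cached as pickles in `cache/`. The cache is not keyed on `metrics_grid` or `linspace_points`: after changing either, delete the corresponding `.pkl` file to force recomputation.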

In [None]:
# Short display labels for each acquisition-function key ("Linear" maps to the LTB baseline).
acquisition_function_name_maps = dict([
    ("Linear", "LTB"),
    ("ExpectedImprovement", "EI"),
    ("UpperConfidenceBound1", "UCB(1)"),
    ("UpperConfidenceBound10", "UCB(10)"),
    ("UpperConfidenceBound20", "UCB(20)"),
    ("UpperConfidenceBound100", "UCB(100)"),
])

In [None]:
# Experiment counts N at which the metrics are evaluated (3, 13, ..., 243), and the
# number of grid points passed to the metric computation as `grid_points`.
metrics_grid = [*range(3, 251, 10)]
linspace_points = 10_000

In [None]:
# Full-domain (x in [0, 100]) MSE metrics for every acquisition function, cached on disk.
# NOTE: the cache is not keyed on metrics_grid / linspace_points — delete the pickle
# after changing either so the metrics are recomputed.
path = cache / "xrd1dim_all.pkl"
if not path.exists():
    print("Recalculating...")
    _m = xrd1dim_compute_metrics_all_acquisition_functions_and_LTB(
        results_by_acqf_Adam,
        metrics_grid=metrics_grid,
        metrics_grid_linear=metrics_grid,
        metric="mse",
        grid_points=linspace_points,
        disable_pbar=False,
        xmin=0.0,
        xmax=100.0,
    )
    all_metrics = _m["metrics"]
    # `with` flushes and closes the file deterministically (a bare open() leaked the handle).
    with open(path, "wb") as fh:
        pickle.dump(all_metrics, fh, protocol=pickle.HIGHEST_PROTOCOL)
else:
    # Trusted local cache written by this notebook (pickle.load can execute arbitrary code).
    with open(path, "rb") as fh:
        all_metrics = pickle.load(fh)

In [None]:
# Labels (values of acquisition_function_name_maps) drawn in subfigure (c).
only_plot = [
    "LTB",
    "EI",
    "UCB(10)",
]

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(2, 3))

# Mean metric J vs. experiment count N for the labels in `only_plot`. The band is
# mean +/- std/3, reduced over axis 1 (presumably independent campaigns), ignoring NaNs.
for acquisition_function_name, label in acquisition_function_name_maps.items():
    values = all_metrics[acquisition_function_name]
    if only_plot is not None and label not in only_plot:
        continue
    mu = np.nanmean(values, axis=1)
    sd = np.nanstd(values, axis=1) / 3
    ax.plot(metrics_grid, mu, label=label)
    ax.fill_between(metrics_grid, mu - sd, mu + sd, linewidth=0, alpha=0.3)

utils.set_grids(ax)
ax.tick_params(which="minor", bottom=False, top=False)
ax.set_xticks([3, 50, 150, 250])

ax.legend(frameon=False, loc="upper right")
ax.text(0.1, 0.05, r"$\mu \pm \sigma / 3$", ha="left", va="bottom", transform=ax.transAxes)
ax.text(0.1, 0.95, r"(c)", ha="left", va="top", transform=ax.transAxes)

# Log-scaled y axis whose tick labels show only the base-10 exponent (hence "log10 J").
ax.set_yscale("log")
yticks = np.arange(-1, -6, -1)
ax.set_yticks((10.0**yticks).tolist())
ax.set_yticklabels([f"${exponent}$" for exponent in yticks])
ax.set_ylim(10**-5.3, 10**-0.7)
ax.tick_params(axis="y", which="minor", left=True, right=True)

ax.set_xlabel(r"$N$")
ax.set_ylabel(r"$\log_{10} J$")

plt.savefig("xrd1dim_subfigure_c.svg", dpi=300, bbox_inches="tight")
# plt.show()

## Supplementary information

### Compare Adam and L-BFGS (labelled `LGBFS` in the code)

In [None]:
acquisition_function = "UpperConfidenceBound100"
# NOTE(review): this rebinds `all_results_Adam` (UCB(10) in the core-figure cells) to the
# UCB(100) results; re-running the core-figure cells that read it after this one would
# silently use UCB(100) data. Consider distinct names.
all_results_LGBFS, all_results_Adam = (
    np.array([campaign.data.X.squeeze() for campaign in by_acqf[acquisition_function]])
    for by_acqf in (results_by_acqf_LGBFS, results_by_acqf_Adam)
)

In [None]:
# Number of histogram bins for the LGBFS vs. Adam comparison of sampled x positions.
bins = 50

In [None]:
fig, axs = plt.subplots(2, 1, sharex=True)

# Top panel: LGBFS run, bottom panel: Adam run — histogram of every sampled x.
for ax, optimizer_label, sampled_x in zip(
    axs, ("LGBFS", "Adam"), (all_results_LGBFS, all_results_Adam)
):
    ax.text(0.5, 0.9, optimizer_label, ha="center", va="top", transform=ax.transAxes)
    ax.hist(sampled_x.flatten(), bins=bins)
axs[0].set_title(acquisition_function)

# y ticks are hidden: only the shape of the two distributions is shown.
for ax in axs:
    utils.set_grids(ax)
    ax.set_yticks([])

plt.show()

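### Phase-resolved metrics

The metrics are restricted to three windows in $x$: $[9, 51]$ (linear), $[59, 81]$ (quadratic) and $[88, 92]$ (sharp).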

In [None]:
# Display labels per acquisition-function key, repeated here (identical to the mapping
# in the core-figure section) so this SI section can presumably be run on its own.
acquisition_function_name_maps = dict(
    zip(
        [
            "Linear",
            "ExpectedImprovement",
            "UpperConfidenceBound1",
            "UpperConfidenceBound10",
            "UpperConfidenceBound20",
            "UpperConfidenceBound100",
        ],
        ["LTB", "EI", "UCB(1)", "UCB(10)", "UCB(20)", "UCB(100)"],
    )
)

In [None]:
# Same evaluation settings as the core figure: N = 3, 13, ..., 243 and 10k metric grid points.
metrics_grid = list(range(3, 250 + 1, 10))
linspace_points = int(1e4)

In [None]:
# MSE metrics restricted to x in [9, 51] (the "linear" panel below), cached on disk.
# NOTE: the cache is not keyed on metrics_grid / linspace_points — delete the pickle
# after changing either so the metrics are recomputed.
path = cache / "xrd1dim_linear.pkl"
if not path.exists():
    print("Recalculating...")
    _m = xrd1dim_compute_metrics_all_acquisition_functions_and_LTB(
        results_by_acqf_Adam,
        metrics_grid=metrics_grid,
        metrics_grid_linear=metrics_grid,
        metric="mse",
        grid_points=linspace_points,
        disable_pbar=False,
        xmin=9.0,
        xmax=51.0,
    )
    all_metrics_linear = _m["metrics"]
    # `with` flushes and closes the file deterministically (a bare open() leaked the handle).
    with open(path, "wb") as fh:
        pickle.dump(all_metrics_linear, fh, protocol=pickle.HIGHEST_PROTOCOL)
else:
    # Trusted local cache written by this notebook (pickle.load can execute arbitrary code).
    with open(path, "rb") as fh:
        all_metrics_linear = pickle.load(fh)

In [None]:
# MSE metrics restricted to x in [59, 81] (the "quadratic" panel below), cached on disk.
# NOTE: the cache is not keyed on metrics_grid / linspace_points — delete the pickle
# after changing either so the metrics are recomputed.
path = cache / "xrd1dim_quad.pkl"
if not path.exists():
    print("Recalculating...")
    _m = xrd1dim_compute_metrics_all_acquisition_functions_and_LTB(
        results_by_acqf_Adam,
        metrics_grid=metrics_grid,
        metrics_grid_linear=metrics_grid,
        metric="mse",
        grid_points=linspace_points,
        disable_pbar=False,
        xmin=59.0,
        xmax=81.0,
    )
    all_metrics_quad = _m["metrics"]
    # `with` flushes and closes the file deterministically (a bare open() leaked the handle).
    with open(path, "wb") as fh:
        pickle.dump(all_metrics_quad, fh, protocol=pickle.HIGHEST_PROTOCOL)
else:
    # Trusted local cache written by this notebook (pickle.load can execute arbitrary code).
    with open(path, "rb") as fh:
        all_metrics_quad = pickle.load(fh)

In [None]:
# MSE metrics restricted to x in [88, 92] (the "sharp" panel below), cached on disk.
# NOTE: the cache is not keyed on metrics_grid / linspace_points — delete the pickle
# after changing either so the metrics are recomputed.
path = cache / "xrd1dim_sharp.pkl"
if not path.exists():
    print("Recalculating...")
    _m = xrd1dim_compute_metrics_all_acquisition_functions_and_LTB(
        results_by_acqf_Adam,
        metrics_grid=metrics_grid,
        metrics_grid_linear=metrics_grid,
        metric="mse",
        grid_points=linspace_points,
        disable_pbar=False,
        xmin=88.0,
        xmax=92.0,
    )
    all_metrics_sharp = _m["metrics"]
    # `with` flushes and closes the file deterministically (a bare open() leaked the handle).
    with open(path, "wb") as fh:
        pickle.dump(all_metrics_sharp, fh, protocol=pickle.HIGHEST_PROTOCOL)
else:
    # Trusted local cache written by this notebook (pickle.load can execute arbitrary code).
    with open(path, "rb") as fh:
        all_metrics_sharp = pickle.load(fh)

In [None]:
# Every label is drawn in the SI phase-resolved figure.
only_plot = [
    "LTB",
    "EI",
    "UCB(1)",
    "UCB(10)",
    "UCB(20)",
    "UCB(100)",
]

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(6, 3), sharex=True, sharey=False)

# One panel per x-window. The loop variable is deliberately NOT named `all_metrics`:
# that name holds the full-domain metrics plotted in subfigure (c), and rebinding it
# here made a later re-run of that cell silently plot the "sharp" window instead.
region_metrics_list = [all_metrics_linear, all_metrics_quad, all_metrics_sharp]
for ax, region_metrics in zip(axs, region_metrics_list):
    utils.set_grids(ax)
    ax.tick_params(which="minor", bottom=False, top=False)
    ax.set_xticks([3, 50, 150, 250])
    ax.set_yscale("log")
    # Dashed reference line at J = 1e-3.
    ax.axhline(10**-3, color="black", linestyle="--")

    for acquisition_function_name, label in acquisition_function_name_maps.items():
        if only_plot is not None and label not in only_plot:
            continue
        values = region_metrics[acquisition_function_name]
        # NaN-aware statistics, consistent with subfigure (c), so a single NaN entry
        # does not turn a whole curve into NaN.
        mu = np.nanmean(values, axis=1)
        sd = np.nanstd(values, axis=1) / 10.0
        ax.plot(metrics_grid, mu, label=label)
        ax.fill_between(metrics_grid, mu - sd, mu + sd, linewidth=0, alpha=0.3)

axs[2].legend(frameon=False, bbox_to_anchor=(1, 0.5), loc="center left")
axs[0].text(0.5, 0.95, r"$\mu \pm \sigma / 10$", ha="center", va="top", transform=axs[0].transAxes)

for ax, panel_tag, title in zip(axs, ["(a)", "(b)", "(c)"], ["linear", "quadratic", "sharp"]):
    ax.text(0.05, 0.05, panel_tag, ha="left", va="bottom", transform=ax.transAxes)
    ax.set_title(title)

axs[1].set_xlabel(r"$N$")
axs[0].set_ylabel(r"$\log_{10} J$")

# Per-panel y ticks whose labels show only the base-10 exponent (hence "log10 J").
for ax, exponents in zip(axs, ([-2, -5, -8], [-2, -4, -6], [-1, -2, -3])):
    yticks = np.array(exponents)
    ax.set_yticks((10.0**yticks).tolist())
    ax.set_yticklabels([f"${ii}$" for ii in yticks])
axs[0].set_ylim(10**-9, 10**0.0)
axs[2].tick_params(axis="y", which="minor", left=False, right=False)

plt.subplots_adjust(wspace=0.3)

plt.savefig("SI_xrd1dim_phase_resolved_metric.pdf", dpi=300, bbox_inches="tight")
# plt.show()