In [None]:
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# Turn off jedi-based tab completion in IPython.
%config Completer.use_jedi = False

The notebook's working directory must be `ScientificValueAgent/figures`; the relative paths used below (`..`, `results/`, `../sva/truth/`) depend on it.

In [None]:
import sys
# Put the parent directory on the import path so the `sva` package imported
# below resolves; this relies on the working directory being `figures/`.
sys.path.append("..")

In [None]:
from collections import Counter
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pickle
from scipy.stats import pearsonr
from tqdm import tqdm
import xarray as xr

In [None]:
from sva import utils

Set some plotting defaults.

In [None]:
# Apply the shared plotting defaults provided by `sva.utils`.
utils.set_defaults()

# BTO results

In [None]:
from sva.postprocessing import read_data, parse_results_by_acquisition_function
from sva.truth.bto import cmf_predicted_mse, bto_compute_metrics_all_acquisition_functions_and_LTB, truth_bto

In [None]:
# Load the stored BTO results; the path is relative to the working directory.
results_Adam = read_data("results/results_23-01-20_bto2")

In [None]:
# Group the loaded results by acquisition function; later cells index this
# with keys such as "UpperConfidenceBound10" and iterate it with `.items()`.
results_by_acqf_Adam = parse_results_by_acquisition_function(results_Adam)

In [None]:
# Directory holding pickled intermediate results; created if it does not exist.
cache = Path("cache")
cache.mkdir(exist_ok=True)

## Core manuscript figure

Load in the NMF weights from Phil's paper: Applied Physics Reviews 8, 041410 (2021); https://doi.org/10.1063/5.0052859

In [None]:
# Component weights; indexed below as `weights[:, k]` for k = 0..3 and
# carrying a "temperature" coordinate that is used as the x-axis grid.
weights = xr.open_dataarray("../sva/truth/bto_xca_weights.nc")

In [None]:
# BTO dataset stored next to the component weights.
data = xr.open_dataarray("../sva/truth/bto_data.nc")
# Absolute finite-difference gradient along axis 1, averaged over axis 0;
# plotted against the temperature grid in the core figure.
grad = np.mean(np.abs(np.gradient(data.data, axis=1)), axis=0)

In [None]:
# Temperature coordinate of the weights; used as the x-axis of the weight plots.
temperature_grid = weights["temperature"]

In [None]:
# Quick look at the four component weights as a function of temperature.
fig, ax = plt.subplots(1, 1, figsize=(3, 2))

for component in range(4):
    ax.plot(temperature_grid, weights.data[:, component])

plt.show()

### Subfigures (a) and (b)

In [None]:
# Acquisition function whose sampling behavior is shown in the core figure.
acquisition_function = "UpperConfidenceBound10"

# One row per experiment: the squeezed `X` array of each stored result.
all_results_Adam = np.array(
    [
        experiment.data.X.squeeze()
        for experiment in results_by_acqf_Adam[acquisition_function]
    ]
)

Resolve by experiment iteration: for each $N \geq 3$, gather the first $N$ sampled points of every run.

In [None]:
# For every N from 3 up to the number of columns, flatten the first N columns
# of all rows into one 1D array (entry k of the list corresponds to N = k + 3).
all_results_Adam_n_resolved = [
    all_results_Adam[:, :n_seen].flatten()
    for n_seen in range(3, all_results_Adam.shape[1] + 1)
]

In [None]:
# Pair every sampled value with the iteration count N it belongs to, giving one
# (value, N) row per sample. List entry `index` holds the first `index + 3`
# columns, hence the offset of 3 in the second column.
all_results_Adam_n_resolved_coordinates = np.concatenate(
    [
        np.column_stack([samples, np.full(len(samples), index + 3)])
        for index, samples in enumerate(all_results_Adam_n_resolved)
    ],
    axis=0,
)

In [None]:
# Upper limit of the color scale shared by the 2D sampling histograms.
vmax = 1500

In [None]:
# Two stacked panels sharing the temperature axis: component weights on top,
# iteration-resolved sampling histogram below.
fig, axs = plt.subplots(
    2, 1, figsize=(3, 3), gridspec_kw={"height_ratios": [1, 2]}, sharex=True
)

# ---- Panel (a): component weights and normalized gradients ----
ax = axs[0]
component_labels = ("Rhomb", "Ortho", "Tetra", "Cubic")
for component, component_label in enumerate(component_labels):
    ax.plot(temperature_grid, weights[:, component], label=component_label)

# Remember the y-range spanned by the weights alone; it is restored at the end
# so the overlays added below cannot rescale the panel.
axlims = ax.get_ylim()
ax.text(1.05, 1.0, "Component", ha="left", va="bottom", transform=ax.transAxes)
ax.text(0.025, 0.5, "(a)", ha="left", va="center", transform=ax.transAxes)
ax.legend(frameon=False, bbox_to_anchor=(1.0, 0.5), loc="center left")

# Mean absolute gradient of the four weight curves, scaled to a maximum of 1.
d = sum(
    np.abs(np.gradient(weights[:, component]))
    for component in range(len(component_labels))
) / 4.0
d = d / d.max()
ax.plot(temperature_grid, d, "k-")
# Same normalization applied to the gradient of the raw data.
ax.plot(temperature_grid, grad / grad.max(), color="cyan")

# Dashed vertical guide lines at fixed temperatures.
for guide_temperature in (185, 280, 400):
    ax.axvline(guide_temperature, color="grey", linewidth=0.5, linestyle="--")

utils.set_grids(ax)
ax.set_ylabel("$w(T)$")
ax.set_ylim(*axlims)

# ---- Panel (b): 2D histogram of sampled temperature versus iteration N ----
ax = axs[1]
ax.hist2d(
    all_results_Adam_n_resolved_coordinates[:, 0],
    all_results_Adam_n_resolved_coordinates[:, 1],
    bins=[100, 247],
    cmap="viridis",
    vmax=vmax,
    rasterized=True,
)
ax.set_yticks([3, 50, 150, 250])
utils.set_grids(ax)
ax.tick_params(which="minor", left=False, right=False)
ax.set_ylabel(r"$N$")
ax.set_xlabel("$T$~[K]")
ax.text(0.025, 0.9, "(b)", ha="left", va="top", transform=ax.transAxes, color="black")

# plt.savefig("bto_subfigure_a_b.svg", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# Indices of the rows whose iteration column equals 250, and the sampled
# values stored in those rows.
where = np.flatnonzero(all_results_Adam_n_resolved_coordinates[:, 1] == 250)
points = all_results_Adam_n_resolved_coordinates[where, 0]

In [None]:
# Dense grid of 1000 temperatures on [150, 445]; the metrics computation
# further down uses the same limits as xmin/xmax.
dense_T_grid = np.linspace(150, 445, 1000)
# NOTE(review): assumes truth_bto returns a 2D array with the temperature grid
# along axis 0 — confirm against sva.truth.bto.
dense_truth = truth_bto(dense_T_grid)
# Absolute gradient along axis 0, then averaged over axis 1.
dense_gradient = np.abs(np.gradient(dense_truth, axis=0)).mean(axis=1)

In [None]:
# Distribution of the points sampled at N = 250 compared with the gradient
# magnitude evaluated on the dense temperature grid.
fig, ax = plt.subplots(1, 1, figsize=(3, 2))

# ax.hist returns (bin heights, bin edges, patches).
res = ax.hist(points, density=True, bins=1000)
# Rescale the gradient curve so its peak matches the tallest histogram bar.
# The label is a raw string: "\l" is not a valid escape sequence in a normal
# string literal and triggers a SyntaxWarning on recent Python versions.
ax.plot(
    dense_T_grid,
    dense_gradient / dense_gradient.max() * res[0].max(),
    color="black",
    label=r"$\langle |\nabla_T I(Q; T)| \rangle$",
)

# The y-axis is in arbitrary (density) units, so the tick labels are hidden.
ax.set_yticklabels([])
utils.set_grids(ax)
ax.legend(frameon=False)
ax.set_xlabel("$T$")

plt.show()


### Subfigure (b) colorbar

In [None]:
# Throwaway figure whose only purpose is to produce the colorbar matching the
# sampling histogram; the axes themselves are removed before display.
# (sharex/sharey were dropped: they have no effect on a single axes.)
fig, ax = plt.subplots(1, 1, figsize=(2, 2))

im = ax.hist2d(
    all_results_Adam_n_resolved_coordinates[:, 0],
    all_results_Adam_n_resolved_coordinates[:, 1],
    bins=[100, 247],
    cmap="viridis",
    vmax=vmax,
)

# hist2d returns (counts, xedges, yedges, image); the colorbar needs the image.
cbar = utils.add_colorbar(im[-1], aspect=6)
cbar.set_ticks([0, vmax])
# Raw f-string: "\g" is not a valid escape sequence in a normal string literal,
# and the old "%i" formatting on top of an f-string was redundant.
# NOTE(review): the divisor 300 is presumably the number of runs, turning total
# counts into per-run averages — confirm.
cbar.set_ticklabels([0, rf"$\geq$ {int(vmax / 300)}"])
cbar.set_label(r"Average Counts", labelpad=-10)

ax.remove()

# plt.savefig("bto_cbar.svg", dpi=300, bbox_inches="tight")
plt.show()

### Subfigure (c): metrics

In [None]:
# Map the result keys to the short labels used in the legend. Insertion order
# is the plotting order: LTB, EI, then UCB with increasing beta.
acquisition_function_name_maps = {
    "Linear": "LTB",
    "ExpectedImprovement": "EI",
    **{
        f"UpperConfidenceBound{beta}": f"UCB({beta})"
        for beta in (1, 10, 20, 100)
    },
}

In [None]:
# Iteration counts N at which the metric is evaluated: 3, 13, ..., 243.
metrics_grid = [*range(3, 251, 10)]
# Value passed as `grid_points` to the metric computation.
linspace_points = 10_000

In [None]:
# The metrics are cached on disk and only recalculated when the cache file is
# missing. Delete `cache/bto_all.pkl` to force a recalculation.
path = cache / "bto_all.pkl"
if path.exists():
    # NOTE: unpickling can execute arbitrary code; only load cache files that
    # were written by this notebook.
    with open(path, "rb") as cache_file:
        all_metrics = pickle.load(cache_file)
else:
    print("Recalculating...")
    _m = bto_compute_metrics_all_acquisition_functions_and_LTB(
        results_by_acqf_Adam,
        metrics_grid=metrics_grid,
        metrics_grid_linear=metrics_grid,
        metric="mse",
        grid_points=linspace_points,
        disable_pbar=False,
        xmin=150.0,
        xmax=445.0,
    )
    all_metrics = _m["metrics"]
    # Context manager so the file is flushed and closed even if dumping fails.
    with open(path, "wb") as cache_file:
        pickle.dump(all_metrics, cache_file, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Legend labels to draw in the metrics figure; set to None to draw every entry.
only_plot = ["LTB", "EI", "UCB(1)", "UCB(10)", "UCB(20)", "UCB(100)"]

In [None]:
# Panel (c): metric versus iteration count for each acquisition function,
# drawn as the mean over axis 1 with a band of one third of the standard deviation.
fig, ax = plt.subplots(1, 1, figsize=(2, 3))

for acquisition_function_name, label in acquisition_function_name_maps.items():
    values = all_metrics[acquisition_function_name]
    if only_plot is not None and label not in only_plot:
        continue
    mu = np.nanmean(values, axis=1)
    sd = np.nanstd(values, axis=1) / 3
    ax.plot(metrics_grid, mu, label=label)
    ax.fill_between(metrics_grid, mu - sd, mu + sd, linewidth=0, alpha=0.3)

utils.set_grids(ax)
ax.tick_params(which="minor", bottom=False, top=False)
ax.set_xticks([3, 50, 150, 250])

ax.legend(frameon=False, loc="upper right")
ax.text(0.1, 0.05, r"$\mu \pm \sigma / 3$", ha="left", va="bottom", transform=ax.transAxes)
ax.text(0.1, 0.95, r"(c)", ha="left", va="top", transform=ax.transAxes)

ax.set_yscale("log")
# Disabled alternative that relabels the ticks with their base-10 exponents.
# NOTE(review): with it disabled, the log-scaled ticks show J itself while the
# y-label below still reads log10 J — confirm which is intended.
# yticks = np.array([-1, -2, -3, -4, -5])
# ax.set_yticks((10.0**yticks).tolist())
# ax.set_yticklabels([f"${ii}$" for ii in yticks])
# ax.set_ylim(10**-5.3, 10**-0.7)
ax.tick_params(axis="y", which="minor", left=True, right=True)

ax.set_xlabel(r"$N$")
ax.set_ylabel(r"$\log_{10} J$")

# plt.savefig("bto_subfigure_c.svg", dpi=300, bbox_inches="tight")
plt.show()

## SI Figure

In [None]:
# Label map for the SI figure. This rebinds the name used by the core-figure
# cells and omits the "Linear" entry, so re-run those cells' own map first if
# they are executed again afterwards.
acquisition_function_name_maps = {
    "ExpectedImprovement": "EI",
    **{
        f"UpperConfidenceBound{beta}": f"UCB({beta})"
        for beta in (1, 10, 20, 100)
    },
}

In [None]:
# For each acquisition function, stack the `X` arrays of all its experiments
# into a single array.
all_points = {
    name: np.array([experiment.data.X for experiment in experiments])
    for name, experiments in results_by_acqf_Adam.items()
}

In [None]:
def _n_resolved_coordinates(samples, n_start=3):
    """Pair every sampled value with the iteration count it belongs to.

    Parameters
    ----------
    samples : np.ndarray
        2D array; for each ``N`` the first ``N`` columns of all rows are used.
    n_start : int
        Smallest ``N`` to include.

    Returns
    -------
    np.ndarray
        Array of shape ``(M, 2)`` whose first column holds the sampled values
        and whose second column holds the matching ``N``.
    """
    blocks = []
    for n_seen in range(n_start, samples.shape[1] + 1):
        seen = samples[:, :n_seen].flatten()
        blocks.append(np.column_stack([seen, np.full(len(seen), n_seen)]))
    return np.concatenate(blocks, axis=0)


# One panel for the component weights plus one per acquisition function.
L = len(acquisition_function_name_maps) + 1

fig, axs = plt.subplots(L, 1, figsize=(3, L), sharex=True, sharey=False)

# Top panel: component weights for reference.
ax = axs[0]
for component, component_label in enumerate(("Rhomb", "Ortho", "Tetra", "Cubic")):
    ax.plot(temperature_grid, weights[:, component], label=component_label)
utils.set_grids(ax)
ax.set_ylabel(r"$w(T)$")

# Remaining panels: sampling histogram of each acquisition function. Local
# names are used on purpose so that the notebook-level
# `all_results_Adam_n_resolved*` arrays read by the core-figure cells are not
# overwritten, and so that no inner loop index shadows the panel index.
for panel, (acquisition_function_name, label) in enumerate(
    acquisition_function_name_maps.items(), start=1
):
    ax = axs[panel]

    panel_coordinates = _n_resolved_coordinates(
        all_points[acquisition_function_name].squeeze()
    )

    ax.hist2d(
        panel_coordinates[:, 0],
        panel_coordinates[:, 1],
        bins=[100, 247], cmap="viridis", vmax=vmax, rasterized=True
    )

    utils.set_grids(ax)
    ax.set_yticks([3, 100, 250])
    ax.text(1.05, 0.5, label, ha="left", va="center", transform=ax.transAxes, rotation=90)

# Shared axis labels: x on the bottom panel, y on the middle one.
axs[-1].set_xlabel("$T$~[K]")
axs[L // 2].set_ylabel(r"$N$")

# Panel letters (a), (b), ...; white on the histogram panels for contrast, and
# the default text color on the weights panel.
for panel, panel_ax in enumerate(axs):
    text_kwargs = {} if panel == 0 else {"color": "white"}
    panel_ax.text(
        0.025, 0.5, f"({chr(ord('a') + panel)})",
        ha="left", va="center", transform=panel_ax.transAxes, **text_kwargs
    )

plt.subplots_adjust(hspace=0.4, wspace=0.03)

plt.savefig("SI_bto_hist.pdf", dpi=300, bbox_inches="tight")
# plt.show()