In [None]:
%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False

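This notebook expects its working directory to be `ScientificValueAgent/figures`; the relative paths used below (`..`, `../results/...`, `cache`, `figures_xrd1dim`) are resolved from there.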

In [None]:
import sys
sys.path.append("..")

In [None]:
from collections import Counter
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pickle
from scipy.spatial import distance_matrix
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm
import warnings

In [None]:
from sva import utils

In [None]:
# Report the interpreter that is actually running this kernel. `!which python3`
# only shows the first `python3` on the shell's PATH, which can differ from the
# kernel's interpreter and is not available on every platform.
print(sys.executable)

Set some plotting defaults.

In [None]:
# Apply the project's plotting defaults (sva.utils.set_defaults) before any figure is created.
utils.set_defaults()

# Multi-phase one-dimensional XRD results

In [None]:
from sva.postprocessing import read_data, parse_results_by_acquisition_function
from sva.truth.xrd1dim import (
    _get_1d_phase_data,
    residual_1d_phase_relative_mae,
    xrd1dim_compute_metrics_all_acquisition_functions_and_LTB,
    _get_1d_phase_fractions,
    truth_xrd1dim,
    residual_1d_phase_mse
)

In [None]:
# Load the stored xrd1dim results; the path is relative to the notebook's working directory.
results_Adam = read_data("../results/results_23-05-02-xrd1dim/")

In [None]:
# Group the loaded results by acquisition function; indexed below with keys
# such as "UpperConfidenceBound10".
results_by_acqf_Adam = parse_results_by_acquisition_function(results_Adam)

In [None]:
# Directory (relative to the working directory) holding cached, expensive-to-compute
# results; created on first use.
cache = Path("cache")
cache.mkdir(exist_ok=True)

## Core manuscript figure

### Subfigures (a) and (b)

In [None]:
# Acquisition function whose sampling behaviour is shown in the core figure.
acquisition_function = "UpperConfidenceBound10"

# One row per stored run: that run's sampled X with singleton axes squeezed out.
all_results_Adam = np.array(
    [
        run.data.X.squeeze()
        for run in results_by_acqf_Adam[acquisition_function]
    ]
)

Get the phases...

In [None]:
# Evaluation grid: 1000 evenly spaced points on [0, 100].
x_grid = np.linspace(0, 100, 1000)
# Transposed so that phases[:, k] can be plotted against x_grid (see the figure cell below).
phases = _get_1d_phase_fractions(x_grid).T

Resolve by the experiment iteration...

In [None]:
# Entry k pools, across all runs, the first (k + 3) sampled positions of each run.
all_results_Adam_n_resolved = [
    all_results_Adam[:, :n_taken].flatten()
    for n_taken in range(3, all_results_Adam.shape[1] + 1)
]

In [None]:
# Build an (n_points, 2) array of (position, N) pairs for the 2D histogram.
# Entry `offset` of the resolved list corresponds to N = offset + 3, so the
# second column is filled with that value directly.
_position_and_n = [
    np.column_stack([positions, np.full(len(positions), offset + 3)])
    for offset, positions in enumerate(all_results_Adam_n_resolved)
]
all_results_Adam_n_resolved_coordinates = np.concatenate(_position_and_n, axis=0)

In [None]:
# Upper colour limit shared by the 2D histogram and its standalone colorbar.
vmax = 1500

In [None]:
# Core figure, panels (a) and (b): phase fractions on top, sampling density below.
fig, axs = plt.subplots(2, 1, figsize=(3, 3), gridspec_kw={'height_ratios':[1, 2]}, sharex=True)

# --- Panel (a): the four phase-fraction curves p(x) ---
ax = axs[0]
for phase_index in range(4):
    ax.plot(x_grid, phases[:, phase_index], label=str(phase_index + 1))
# Record the autoscaled y-limits now so the shaded bands below cannot change them.
axlims = ax.get_ylim()
ax.text(1.05, 1.0, "Phase", ha="left", va="bottom", transform=ax.transAxes)
ax.text(0.025, 0.9, "(a)", ha="left", va="top", transform=ax.transAxes)
ax.legend(frameon=False, bbox_to_anchor=(1.0, 0.5), loc="center left")

# Shade three x-intervals over the full height of the panel.
band_heights = np.linspace(*axlims, 10)
for x_low, x_high in [(10, 50), (60, 80), (88.5, 91.5)]:
    ax.fill_betweenx(band_heights, x_low, x_high, color="black", alpha=0.1, linewidth=0)

utils.set_grids(ax)
ax.set_ylabel("$p(x)$")
ax.set_ylim(*axlims)

# --- Panel (b): 2D histogram of sampled positions x versus iteration count N ---
# NOTE(review): the N column holds integers starting at 3. If it spans 3..250
# (248 distinct values), 247 equal-width bins place the two largest N values in
# the same (top) bin -- confirm this is intended.
ax = axs[1]
ax.hist2d(
    all_results_Adam_n_resolved_coordinates[:, 0],
    all_results_Adam_n_resolved_coordinates[:, 1],
    bins=[100, 247],
    cmap="viridis",
    vmax=vmax,
    rasterized=True,
)
ax.set_yticks([3, 50, 150, 250])
ax.set_xticks([0, 20, 40, 60, 80, 100])
utils.set_grids(ax)
ax.tick_params(which="minor", left=False, right=False)
ax.set_ylabel(r"$N$")
ax.set_xlabel("$x$~[a.u.]")
ax.text(0.025, 0.9, "(b)", ha="left", va="top", transform=ax.transAxes, color="white")

# savefig does not create missing directories; make sure the target exists so
# the cell also works on a fresh checkout.
Path("figures_xrd1dim").mkdir(exist_ok=True)
plt.savefig("figures_xrd1dim/xrd1dim_subfigure_a.svg", dpi=300, bbox_inches="tight")

### Colorbar for subfigure (b)

In [None]:
# Standalone colorbar: draw the histogram on a throwaway axis only to obtain a
# mappable, attach the colorbar, then remove the axis before saving.
fig, ax = plt.subplots(1, 1, figsize=(2, 2), sharex=True, sharey=True)

im = ax.hist2d(
    all_results_Adam_n_resolved_coordinates[:, 0],
    all_results_Adam_n_resolved_coordinates[:, 1],
    bins=[150, 247],
    cmap="viridis",
    vmax=vmax,
)

# hist2d returns (counts, xedges, yedges, image); the image is the mappable.
cbar = utils.add_colorbar(im[-1], aspect=20)
cbar.set_ticks([0, vmax])
# Raw string: "\g" is not a valid escape sequence in a normal (or f-) string.
# The divisor 300 rescales raw counts to the "Average Counts" shown on the label
# (presumably the number of pooled runs -- TODO confirm).
cbar.set_ticklabels([0, r"$\geq$ %i" % int(vmax / 300)])
cbar.set_label(r"Average Counts", labelpad=-10)

ax.remove()

# savefig does not create missing directories; make sure the target exists.
Path("figures_xrd1dim").mkdir(exist_ok=True)
plt.savefig("figures_xrd1dim/xrd1dim_cbar.svg", dpi=300, bbox_inches="tight")

### Subfigure (c): select metrics

The metrics for this part take a long time to calculate, so we cache them.

In [None]:
# Maps the keys used in the results/metrics dictionaries to the short labels
# shown in figure legends. The plotting loops iterate over this dict, so the
# insertion order here determines the order in which curves are drawn.
acquisition_function_name_maps = {
    "Linear": "LTB",
    "UpperConfidenceBound10": "UCB(10)",
    "ExpectedImprovement": "EI",
    "UpperConfidenceBound1": "UCB(1)",
    "UpperConfidenceBound20": "UCB(20)",
    "UpperConfidenceBound100": "UCB(100)"
}

In [None]:
# N values at which the metric is evaluated: 3, 13, ..., 243 (also the x-axis of the metric plots).
metrics_grid = list(range(3, 251, 10))
# Passed as `grid_points` to the metrics computation below.
linspace_points = 10000

In [None]:
# Load the metrics from the on-disk cache, or compute and cache them.
# NOTE: the cache file is not keyed on `metrics_grid` / `linspace_points`;
# delete it manually after changing either of them.
path = cache / "xrd1dim_all.pkl"
if path.exists():
    # Only load pickles you produced yourself: unpickling executes arbitrary code.
    with open(path, "rb") as cache_file:
        all_metrics = pickle.load(cache_file)
else:
    print("Recalculating...")
    _m = xrd1dim_compute_metrics_all_acquisition_functions_and_LTB(
        results_by_acqf_Adam,
        metrics_grid=metrics_grid,
        metrics_grid_linear=metrics_grid,
        metric="mse",
        grid_points=linspace_points,
        disable_pbar=False,
        xmin=0.0,
        xmax=100.0,
    )
    all_metrics = _m["metrics"]
    # Context manager guarantees the file is flushed and closed.
    with open(path, "wb") as cache_file:
        pickle.dump(all_metrics, cache_file, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Legend labels (values of acquisition_function_name_maps) to include in the
# metric plots; set to None to plot every acquisition function.
only_plot = ["LTB", "EI", "UCB(10)"]

In [None]:
# Panel (c): ln(MSE) versus N for the selected acquisition functions.
fig, ax = plt.subplots(1, 1, figsize=(2, 3))

for result_key, label in acquisition_function_name_maps.items():
    values = all_metrics[result_key]
    if only_plot is not None and label not in only_plot:
        continue
    # Mean of the log-metric along axis 1, with a band of two standard deviations.
    log_values = np.log(values)
    center = np.nanmean(log_values, axis=1)
    half_width = 2 * np.nanstd(log_values, axis=1)
    legend_label = "Grid" if label == "LTB" else label
    ax.plot(metrics_grid, center, label=legend_label)
    ax.fill_between(
        metrics_grid,
        center - half_width,
        center + half_width,
        linewidth=0,
        alpha=0.3,
    )

utils.set_grids(ax)
ax.tick_params(which="minor", bottom=False, top=False)
ax.set_xticks([3, 50, 150, 250])

ax.legend(frameon=False, loc="upper right")
ax.text(0.1, 0.95, r"(c)", ha="left", va="top", transform=ax.transAxes)

# The y-axis is linear in ln(MSE); tick labels are typeset as math.
yticks = np.array([-2, -5, -8, -11, -14])
ax.set_yticks(yticks.tolist())
ax.set_yticklabels([f"${tick}$" for tick in yticks])
ax.tick_params(axis='y', which='minor', left=True, right=True)

ax.set_xlabel(r"$N$")
ax.set_ylabel(r"$\ln$(MSE)")

# plt.savefig("figures_xrd1dim/xrd1dim_subfigure_c.svg", dpi=300, bbox_inches="tight")
plt.show()

# Bayes clustering (SI figure)

In [None]:
# Settings for the clustering-based sampling loop below.
max_queries = 250  # total number of sampled points per repetition
grid_points = 10000  # passed as `linspace_points` to residual_1d_phase_mse
N_exp = 10  # number of repetitions per cluster count

In [None]:
# Clustering-based baseline: repeatedly cluster the observations with k-means,
# fit a classifier on (x -> cluster label), and sample next where the
# classifier's predictive entropy is largest.
np.random.seed(123)

# Candidate locations scored by the classifier. They never change, so the grid
# is built once instead of in every iteration of the innermost loop.
candidate_x = np.linspace(0, 100, 1000).reshape(-1, 1)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    for k_clusters in [3, 4, 5]:
        # Use a dedicated name here. Reassigning the notebook-level
        # `metrics_grid` would leave it at range(5, ...) after this loop and
        # silently shift the x-values of every curve that is later plotted
        # against `metrics_grid`.
        bayes_metrics_grid = list(range(k_clusters, max_queries + 1, 10))
        all_metrics[f"Bayesian_{k_clusters}"] = []

        for exp in tqdm(range(N_exp)):
            k_means = KMeans(k_clusters, n_init="auto")
            clf = LogisticRegression()
            # Start from k_clusters evenly spaced points on [0, 100].
            x = list(np.linspace(0, 100, k_clusters))
            # truth_xrd1dim is evaluated on the full array of sampled points
            # every time (per the original author it does not accept single items).
            y = truth_xrd1dim(np.array(x))

            for _step in range(k_clusters + 1, max_queries + 1):
                labels = k_means.fit_predict(y)
                clf.fit(np.array(x).reshape(-1, 1), labels)
                proby = clf.predict_proba(candidate_x)
                # Shannon entropy per candidate. A probability of exactly 0
                # yields 0 * inf = NaN; nansum treats that term as 0 (its
                # limit) so argmax is not hijacked by NaN entries.
                shannon = np.nansum(proby * np.log(1 / proby), axis=-1)
                # Index the single column explicitly: float() on a 1-element
                # array with ndim > 0 is deprecated in recent NumPy.
                max_entropy_loc = float(candidate_x[np.argmax(shannon), 0])

                x.append(max_entropy_loc)
                y = truth_xrd1dim(np.array(x))

            # Evaluate the metric using only the first N sampled points.
            _metrics = []
            for N in bayes_metrics_grid:
                res = residual_1d_phase_mse(
                    np.array(x)[:N].reshape(-1, 1),
                    linspace_points=grid_points,
                    use_only=None,
                )
                _metrics.append(res)
            all_metrics[f"Bayesian_{k_clusters}"].append(_metrics)

# Convert the per-repetition lists to arrays, one row per repetition.
for k_clusters in [3, 4, 5]:
    all_metrics[f"Bayesian_{k_clusters}"] = np.array(all_metrics[f"Bayesian_{k_clusters}"])

In [None]:
# x-values (N) of the Bayesian curves, one grid per cluster count (3, 4, 5);
# each grid starts at its own k_clusters.
metric_grids_bayesian = [
    list(range(first_n, max_queries + 1, 10)) for first_n in [3, 4, 5]
]

In [None]:
# SI figure: ln(MSE) versus N for the selected acquisition functions together
# with the clustering-based baselines.
fig, ax = plt.subplots(1, 1, figsize=(2, 3))

# NOTE(review): `metrics_grid` must still be the grid the cached metrics were
# computed on (range(3, 251, 10)); make sure no cell executed before this one
# has reassigned it, otherwise these curves are drawn at shifted x-values.
for acquisition_function_name, label in acquisition_function_name_maps.items():
    values = all_metrics[acquisition_function_name]
    if only_plot is None or label in only_plot:
        v = np.log(values)
        mu = np.nanmean(v, axis=1)
        sd = np.nanstd(v, axis=1) * 2
        ax.plot(metrics_grid, mu, label=label if label != "LTB" else "Grid")
        ax.fill_between(metrics_grid, mu - sd, mu + sd, linewidth=0, alpha=0.3)


acquisition_function_name_maps_bayesian = {
    f"Bayesian_{k_clusters}": f"Bayesian_{k_clusters}_clusters" for k_clusters in [3, 4, 5]
}
# The Bayesian arrays hold one row per repetition, so statistics are taken
# along axis 0 here. Each cluster count has its own x-grid.
for bayes_grid, (acquisition_function_name, label) in zip(
    metric_grids_bayesian, acquisition_function_name_maps_bayesian.items()
):
    v = np.log(all_metrics[acquisition_function_name])
    mu = np.nanmean(v, axis=0)
    sd = np.nanstd(v, axis=0) * 2
    ax.plot(bayes_grid, mu, label=label.replace("_", " "))
    ax.fill_between(bayes_grid, mu - sd, mu + sd, linewidth=0, alpha=0.3)


utils.set_grids(ax)
ax.tick_params(which="minor", bottom=False, top=False)
ax.set_xticks([3, 50, 150, 250])

ax.legend(frameon=False, loc="center left", bbox_to_anchor=(1, 0.5))

# The y-axis is linear in ln(MSE); tick labels are typeset as math.
yticks = np.array([-2, -5, -8, -11, -14])
ax.set_yticks((yticks).tolist())
ax.set_yticklabels([f"${ii}$" for ii in yticks])
ax.tick_params(axis='y', which='minor', left=True, right=True)

ax.set_xlabel(r"$N$")
ax.set_ylabel(r"$\ln$(MSE)")

# savefig does not create missing directories; make sure the target exists.
Path("figures_xrd1dim").mkdir(exist_ok=True)
plt.savefig("figures_xrd1dim/SI_xrd1dim_Bayesian.svg", dpi=300, bbox_inches="tight")