In [None]:
%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False

In [None]:
from datetime import datetime
from pathlib import Path
import pickle
import sys
import numpy as np
from scipy.stats import sem
from matplotlib.colors import Normalize 
from scipy.interpolate import interpn
from IPython.display import clear_output
from rdkit import Chem
from PyAstronomy.pyasl import broadGaussFast

import matplotlib.pyplot as plt
from matplotlib import cm
from pymatgen.core.structure import Molecule

In [None]:
# Reference link kept from the original cell (presumably the source of mpl_utils):
# https://gist.github.com/x94carbone/f5201b1c44963ff9453b9cc1d5f768ac
# ~/local is put on the module search path before mpl_utils is imported.
sys.path.append(str(Path.home().joinpath("local")))
from mpl_utils import MPLAdjutant

# Shared plotting helper; later cells call adj.set_grids / set_xlim / set_ylim.
adj = MPLAdjutant()
adj.set_defaults()

In [None]:
import matplotlib

# The LaTeX preamble must be a single string: passing a list of strings was
# deprecated in Matplotlib 3.3 and is no longer accepted by current releases.
# A plain string is also understood by the older versions.
matplotlib.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"

In [None]:
import json

def save_json(d, path):
    """Serialize ``d`` to the file at ``path`` as JSON.

    The output uses a 4-space indent and alphabetically sorted keys. An
    existing file at ``path`` is overwritten.
    """
    with open(path, mode="w") as handle:
        json.dump(d, handle, sort_keys=True, indent=4)

def read_json(path):
    """Parse the JSON file at ``path`` and return the decoded object."""
    with open(path, mode="r") as handle:
        return json.load(handle)

Append the `home` path of this project.

In [None]:
# https://stackoverflow.com/questions/20105364/how-can-i-make-a-scatter-plot-colored-by-density-in-matplotlib
def density_scatter(x, y, ax, sort=True, bins=20, **kwargs):
    """
    Scatter plot colored by a 2d-histogram density estimate.

    Parameters
    ----------
    x, y : array-like
        Coordinates of the points, one entry per point. Plain sequences
        (lists, tuples) are converted to arrays; arrays and array subclasses
        are used as they are.
    ax : object
        Target axes; only its ``scatter`` method is called.
    sort : bool, optional
        If True, points are drawn in order of increasing density so that the
        densest points end up on top.
    bins : optional
        Forwarded to ``np.histogram2d``.
    **kwargs
        Forwarded to ``ax.scatter``.

    Returns
    -------
    The object returned by ``ax.scatter``.
    """

    # The fancy indexing in the sort step needs arrays; asanyarray converts
    # plain sequences while leaving ndarray subclasses untouched.
    x = np.asanyarray(x)
    y = np.asanyarray(y)

    density, x_edges, y_edges = np.histogram2d(x, y, bins=bins, density=True)
    x_centers = 0.5 * (x_edges[1:] + x_edges[:-1])
    y_centers = 0.5 * (y_edges[1:] + y_edges[:-1])

    # Evaluate the binned density at every point by interpolating between
    # the bin centres.
    z = interpn(
        (x_centers, y_centers),
        density,
        np.vstack([x, y]).T,
        method="splinef2d",
        bounds_error=False,
    )

    # With bounds_error=False, points outside the grid of bin centres come
    # back as NaN; give them zero density so that every point is plotted.
    z[np.isnan(z)] = 0.0

    # Sort the points by density, so that the densest points are plotted last
    if sort:
        order = z.argsort()
        x, y, z = x[order], y[order], z[order]

    return ax.scatter(x, y, c=z, **kwargs)

# Load the data and trained ensembles

In [None]:
def load_trained_ensembles(ensemble_root_path="Ensembles"):
    """Load every ensemble saved below ``ensemble_root_path``.

    The directory tree is searched recursively for files named
    ``ensemble.json``. For each file found:

    * the downsampling proportion is the name of the file's parent directory,
      converted to ``float``;
    * the atom key is the text between ``"-ACSF-"`` and the next ``"-"`` in
      the path string;
    * if the path contains ``"TOTAL-ATOMS"``, the dash-separated token right
      before ``"-TOTAL-ATOMS"`` is appended, giving ``"{atom_key}-{n_atoms}"``.

    Returns
    -------
    dict
        ``{atom_key: {downsample_prop: Ensemble}}``.
    """
    loaded = {}

    for json_path in list(Path(ensemble_root_path).rglob("ensemble.json")):
        path_str = str(json_path)
        proportion = float(json_path.parent.parts[-1])

        key = path_str.split("-ACSF-")[1].split("-")[0]
        if "TOTAL-ATOMS" in path_str:
            total_atoms = path_str.split("-TOTAL-ATOMS")[0].split("-")[-1]
            key = f"{key}-{total_atoms}"

        per_atom = loaded.setdefault(key, {})
        per_atom[proportion] = Ensemble.from_dict(read_json(json_path))

    return loaded

def load_data(path="data/qm9/ml_ready/XANES-220817-ACSF-C-distorted.pkl"):
    """Load the pickled dataset and return it under the ``"C"`` key.

    Parameters
    ----------
    path : str or path-like, optional
        Location of the pickle file. Defaults to the path that was previously
        hard-coded, so ``load_data()`` behaves as before.

    Returns
    -------
    dict
        ``{"C": <unpickled object>}``.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code while deserializing; only
    point this at files you trust.
    """
    # The context manager closes the file handle deterministically; the
    # previous bare open() call left that to the garbage collector.
    with open(path, "rb") as infile:
        carbon = pickle.load(infile)
    return {"C": carbon}

In [None]:
# Dataset stored under the "C" key; later cells read its "x", "y", "names"
# and "grid" entries.
data = load_data()["C"]

In [None]:
from xas_nne.ml import Ensemble

In [None]:
# Nested dict: ensembles[atom_key][downsample_prop] -> Ensemble
# (built by load_trained_ensembles from the default "Ensembles" directory).
ensembles = load_trained_ensembles()

# Evaluate the ensemble effectiveness on the randomly sampled data

## Get the results compiled

In [None]:
# Full sweep kept for reference; only the 0.9 proportion is active.
# downsample_values = sorted([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], reverse=True)
downsample_values = [0.9]

In [None]:
# Presumably the atom types handled by this notebook; not referenced by the
# other cells shown here.
ATOMS = ["C"]

Each `preds[atom_type][downsample]` has shape `(N_ensemble, N_examples, M)`. These predictions are masked `numpy` arrays: the masked entries correspond to predictions that are outliers relative to the other estimators, or that are totally unphysical.

In [None]:
# Ensemble stored under atom key "C" and downsample proportion 0.9.
ensemble = ensembles["C"][0.9]

In [None]:
# Predictions from Ensemble.predict, i.e. without the outlier-filtering call
# used for `preds`; not referenced by the later cells shown here.
preds_no_filter = ensemble.predict(data["x"])

In [None]:
# Outlier-filtered ensemble predictions; later cells reduce over axis 0 with
# .mean(axis=0) / .std(axis=0). The meaning of the threshold arguments is
# defined by Ensemble.predict_filter_outliers in xas_nne.ml (not shown here).
preds = ensemble.predict_filter_outliers(
    data["x"],
    sd_mult=2.0,
    threshold_sd=0.7,
    max_spectra_value=20.0,
    threshold_zero=0.5,
    min_spectra_value=0.05,
)

In [None]:
def get_indexes(distortion=None, qm9id=None, names=None):
    """Return the positions of the entries of ``names`` that match the filters.

    An entry matches when it contains every requested fragment as a
    substring: ``str(qm9id)`` if ``qm9id`` is given, and ``distortion``
    formatted with two decimals if ``distortion`` is given.

    Parameters
    ----------
    distortion : float, optional
        Distortion value; matched through its two-decimal representation.
    qm9id : optional
        Identifier; matched through ``str(qm9id)``.
    names : iterable of str, optional
        Names to search. When omitted, the notebook-level ``data["names"]``
        is looked up at call time, so the function keeps following ``data``
        if that variable is reloaded after this cell has run.

    Returns
    -------
    numpy.ndarray
        Integer array of matching positions. It has an integer dtype even
        when nothing matches, so it is always usable as an index.

    Raises
    ------
    ValueError
        If neither ``distortion`` nor ``qm9id`` is given.
    """
    if distortion is None and qm9id is None:
        raise ValueError(
            "get_indexes requires at least one of `distortion` or `qm9id`"
        )

    if names is None:
        names = data["names"]

    # Substrings that must all be present in a matching name.
    fragments = []
    if qm9id is not None:
        fragments.append(str(qm9id))
    if distortion is not None:
        fragments.append(f"{distortion:.02f}")

    return np.array(
        [
            ii for ii, name in enumerate(names)
            if all(fragment in name for fragment in fragments)
        ],
        dtype=int,
    )

In [None]:
# Abscissa shared by all spectra; used as the x-values in the example plots.
grid = data["grid"]

In [None]:
# Point-wise absolute error between the targets and the prediction averaged
# over axis 0 of `preds` (the ensemble axis, per the shape note above).
errors = np.abs(data["y"] - preds.mean(axis=0))

# Plot the errors as a function of the distortion

In [None]:
# The keys are the QM9 ids passed to get_indexes by the figure cells below.
# The integer values are not used by the cells shown here
# (presumably a per-molecule site/atom index -- TODO confirm).
qm9_ids = {129158: 7, 43138: 6, 87244: 1, 67255: 2, 50994: 1, 110619: 4, 108590: 6, 17249: 7, 104189: 0, 65272: 2}

In [None]:
# Figure: log10 of the mean absolute error and of the mean ensemble standard
# deviation, as a function of the distortion level.
fig, ax = plt.subplots(1, 1, figsize=(3, 1.5), sharex=True, sharey=True)

# Distortion levels 0.01 ... 0.10 (the raw values carry float noise).
distortions = [0.01 + 0.01 * ii for ii in range(0, 10)]

mean_errors = []
std_errors = []
mean_model_std = []
std_model_std = []

for distortion in distortions:
    # Remove the float noise before get_indexes formats the value.
    distortion = round(distortion, 2)
    indexes = get_indexes(distortion=distortion)
    # Errors and predictions restricted to the examples at this distortion.
    e = errors[indexes, :]
    p = preds[:, indexes, :]
    mean_errors.append(e.mean())
    std_errors.append(e.std())
    # Spread across axis 0 of `preds`, then averaged over examples and grid.
    mean_model_std.append(p.std(axis=0).mean())
    std_model_std.append(p.std(axis=0).std())

mean_errors = np.array(mean_errors)
std_errors = np.array(std_errors)
mean_model_std = np.array(mean_model_std)
std_model_std = np.array(std_model_std)

# std_errors / std_model_std are only used by the commented-out bands below.
ax.plot(distortions, np.log10(mean_errors), 'ko-', label=r"$\log_{10} \bar{\varepsilon}_\mathrm{dist.}(\delta)$" )
# ax.fill_between(distortions, (mean_errors - std_errors), (mean_errors + std_errors), color='black', alpha=0.1, linewidth=0)

ax.plot(distortions, np.log10(mean_model_std), 'ro-', label=r"$\log_{10} \bar{\sigma}_\mathrm{dist.}(\delta)$")
# ax.fill_between(distortions, (mean_model_std - std_model_std), (mean_model_std + std_model_std), color='red', alpha=0.1, linewidth=0)

adj.set_grids(ax, grid=False)
ax.set_xlabel(r"$\delta$~$(10^{-2}$ \AA)")
ax.legend(frameon=False)

# Ticks sit at the distortion values; labels 1..10 match the 10^-2 factor
# written in the x-axis label.
ax.set_xticks([(0.01 + ii * 0.01) for ii in range(10)])
ax.set_xticklabels([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

adj.set_ylim(ax, -1.5, -0.5)
ax.tick_params(bottom=False, top=False, which="minor")

# NOTE(review): assumes the "Figures" directory already exists.
plt.savefig("Figures/qm9_distortion_averages.svg", bbox_inches="tight", dpi=300)
# plt.show()

# Plot the errors (scatterplot)

In [None]:
# Discrete 10-colour "rainbow" map. pyplot.get_cmap takes the same arguments
# as matplotlib.cm.get_cmap, which was deprecated in Matplotlib 3.7 and
# removed in 3.9; the pyplot function works on old and new versions alike.
cmap = plt.get_cmap("rainbow", 10)
# Distortion levels 0.01 ... 0.10 (the raw values carry float noise; the
# cells using them round to two decimals).
distortions = [0.01 + 0.01 * ii for ii in range(0, 10)]

In [None]:
# Figure: one panel per (QM9 id, distortion) pair, showing log10 error against
# log10 ensemble standard deviation, coloured by point density.
scale = 0.4
# NOTE(review): both figure dimensions are derived from len(qm9_ids); that
# equals len(distortions) (10) here, so the grid comes out square.
fig, axs = plt.subplots(len(qm9_ids), len(distortions), figsize=(3 * len(qm9_ids) * scale, 3 * len(qm9_ids) * scale), sharex=True, sharey=True)

# Rows follow the qm9_ids keys in reverse order; columns follow distortions.
for ii, qm9id in enumerate(list(qm9_ids.keys())[::-1]):
    for jj, distortion in enumerate(distortions):
        # Remove the float noise before the value is formatted.
        distortion = round(distortion, 2)
        
        ax = axs[ii, jj]
        adj.set_grids(ax, grid=False)
        
        # Column titles (the distortion value) only on the top row.
        if ii == 0:
            ax.set_title(f"{distortion:.02f}", fontsize=16)
        
        indexes = get_indexes(distortion=distortion, qm9id=qm9id)
        # Flatten over examples and grid points: one scatter point per value.
        e = np.log10(errors[indexes, :].flatten())
        p = np.log10(preds[:, indexes, :].std(axis=0).flatten())
        
        # ax.scatter(p, e, color=cmap(jj), alpha=0.5, s=1, rasterized=True)
        density_scatter(p, e, ax=ax, sort=True, bins=20, s=1, alpha=1, rasterized=True)

# Same limits and ticks on both axes of every panel.
ax_min = -6
ax_max = 2
for ax in axs.flatten():
    ax.set_xlim(ax_min, ax_max)
    ax.set_ylim(ax_min, ax_max)
    ax.set_xticks([-6, -2, 2])
    ax.set_yticks([-6, -2, 2])
    adj.set_grids(ax, grid=False)
    adj.set_xlim(ax, ax_min, ax_max)
    adj.set_ylim(ax, ax_min, ax_max)
    
plt.subplots_adjust(wspace=0.1, hspace=0.1)

# Axis labels only on the top-left (y) and bottom-left (x) panels.
axs[0, 0].set_ylabel(r"$\log_{10} \varepsilon^{(i)}_j$", fontsize=16)
axs[-1, 0].set_xlabel(r"$\log_{10} \hat{\sigma}^{(i)}_j$", fontsize=16)

# NOTE(review): assumes the "Figures" directory already exists.
plt.savefig("Figures/qm9_sigma_parity_distortion_test_3.svg", bbox_inches="tight", dpi=300)
# plt.show()

# Examples!

In [None]:
# Figure: one example spectrum per (QM9 id, distortion) pair, showing the
# target (black), the ensemble mean (red) and a +/- 3 standard deviation band.
scale = 0.4
# sharey=False: every panel keeps its own y-range.
fig, axs = plt.subplots(len(qm9_ids), len(distortions), figsize=(3 * len(qm9_ids) * scale, 2 * len(qm9_ids) * scale), sharex=True, sharey=False)

# Rows follow the qm9_ids keys in reverse order; columns follow distortions.
for ii, qm9id in enumerate(list(qm9_ids.keys())[::-1]):
    for jj, distortion in enumerate(distortions):
        # Remove the float noise before the value is formatted.
        distortion = round(distortion, 2)
        
        ax = axs[ii, jj]
        adj.set_grids(ax, grid=False)
        
        if ii == 0:
            ax.set_title(f"{distortion:.02f}", fontsize=16)
        
        # NOTE(review): the hard-coded [5] requires at least six matching
        # examples for every (qm9id, distortion) pair.
        index = get_indexes(distortion=distortion, qm9id=qm9id)[5]  # Random index, essentially
        gt = data["y"][index, :]
        # Spread and mean over axis 0 of `preds` for this single example.
        sd = preds[:, index, :].std(axis=0)
        p = preds[:, index, :].mean(axis=0)
        
        ax.plot(grid, p, 'r', linewidth=0.5)
        ax.fill_between(grid, p - 3*sd, p + 3*sd, color="red", alpha=0.5, linewidth=0)
        ax.plot(grid, gt, "k", linewidth=0.5)

for ax in axs.flatten():
    # ax.set_xticks([285, 315])
    # adj.set_grids(ax, grid=False)
    adj.set_xlim(ax, 275, 305)
    # ax.set_yticks([])
    # ax.set_ylim(bottom=-0.5, top=6.0)
    # NOTE(review): axis('off') hides all axis decorations, so the axis labels
    # set below are probably not rendered -- confirm this is intended.
    ax.axis('off')

plt.subplots_adjust(wspace=0.1, hspace=0.1)

axs[0, 0].set_ylabel(r"$\mu(E)$~(a.u.)", fontsize=16)
axs[-1, 0].set_xlabel(r"$E$~(e.V.)", fontsize=16)

# NOTE(review): assumes the "Figures" directory already exists.
plt.savefig("Figures/qm9_distortion_preds_waterfall.svg", bbox_inches="tight", dpi=300)
# plt.show()