In [None]:
# Reload imported modules automatically before each cell is executed (autoreload mode 2),
# so edits to project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Turn off jedi-based tab completion in IPython.
%config Completer.use_jedi = False

In [None]:
from datetime import datetime
from pathlib import Path
import pickle
import sys
import numpy as np
from matplotlib.colors import Normalize 
from scipy.interpolate import interpn
from functools import cache
from tqdm import tqdm

import matplotlib.pyplot as plt
from pymatgen.core.structure import Molecule

Custom plotting code. You can ignore this cell if you don't know what it is: it should gracefully do nothing if the `MPLAdjutant` class is not available.

In [None]:
# Make ``~/local`` importable; the optional ``mpl_utils`` module is looked up on
# ``sys.path`` (this directory included) by the import below.
sys.path.append(str(Path.home().joinpath("local")))


class NullClass:
    """Stand-in for ``MPLAdjutant`` used when ``mpl_utils`` cannot be imported.

    Every attribute that is not defined on the class resolves to
    :meth:`do_nothing`, so arbitrary helper calls become silent no-ops.
    """

    def do_nothing(*args, **kwargs):
        """Swallow any arguments (the instance included) and return ``None``."""
        return None

    def add_colorbar(self, im, **kwargs):
        """Draw a plain matplotlib colorbar for ``im``; extra keywords are ignored."""
        colorbar = plt.colorbar(im)
        return colorbar

    def __getattr__(self, name):
        """Fallback for unknown attribute names: hand back the no-op method."""
        return self.do_nothing


# Prefer the real helper when it is installed, otherwise fall back to the stub.
try:
    from mpl_utils import MPLAdjutant
    adj = MPLAdjutant()
except ImportError:
    adj = NullClass()

In [None]:
# Apply the helper's defaults (presumably matplotlib styling; defined in mpl_utils, not
# shown here). With the NullClass fallback this call is a no-op.
adj.set_defaults()

In [None]:
import json

def save_json(d, path):
    """Write ``d`` to ``path`` as pretty-printed JSON (4-space indent, sorted keys).

    The document is serialized in memory *before* the file is opened, so a
    non-serializable ``d`` raises ``TypeError`` without truncating or partially
    overwriting an existing file at ``path``.

    Parameters
    ----------
    d : object
        Any JSON-serializable object.
    path : str or path-like
        Destination file; it is overwritten if it already exists.

    Raises
    ------
    TypeError
        If ``d`` (or something nested in it) cannot be serialized to JSON.
    """
    # Serialize first: opening with 'w' truncates the target immediately, so doing
    # this afterwards would destroy the old file even when serialization fails.
    serialized = json.dumps(d, indent=4, sort_keys=True)
    with open(path, 'w') as outfile:
        outfile.write(serialized)

def read_json(path):
    """Parse the JSON document stored at ``path`` and return the decoded object."""
    with open(path, 'r') as handle:
        return json.load(handle)

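Append the parent of the current working directory (the root of this project) to `sys.path` so that the project's own packages can be imported.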

In [None]:
# Put the parent of the current working directory on sys.path (presumably the project
# root, so that the project's packages can be imported in later cells).
sys.path.append(str(Path.cwd().parent))

In [None]:
# https://stackoverflow.com/questions/20105364/how-can-i-make-a-scatter-plot-colored-by-density-in-matplotlib
def density_scatter(x, y, ax, sort=True, bins=20, **kwargs):
    """
    Scatter plot colored by 2d histogram

    The local point density is estimated from a normalized 2D histogram of the
    data and interpolated (``splinef2d``) back onto every point; that value is
    passed to ``ax.scatter`` as the color.

    Parameters
    ----------
    x, y : array-like
        1D coordinates of the points. Any array-like is accepted (lists,
        NumPy arrays, pandas Series); both are converted with ``np.asarray``.
    ax : object with a ``scatter`` method
        Axes to draw on.
    sort : bool
        If True, points are drawn in order of increasing density so that the
        densest points end up on top.
    bins : int or sequence
        Forwarded to ``np.histogram2d``.
    **kwargs
        Forwarded to ``ax.scatter``.

    Returns
    -------
    The ``ax`` that was passed in.
    """

    # Coerce up front: the positional fancy indexing used for sorting below fails on
    # plain lists and would be interpreted as label lookup on a pandas Series.
    x = np.asarray(x)
    y = np.asarray(y)

    data, x_e, y_e = np.histogram2d(x, y, bins=bins, density=True)
    x_centers = 0.5 * (x_e[1:] + x_e[:-1])
    y_centers = 0.5 * (y_e[1:] + y_e[:-1])
    z = interpn(
        (x_centers, y_centers),
        data,
        np.vstack([x, y]).T,
        method="splinef2d",
        bounds_error=False,
    )

    # Points outside the grid of bin centers come back as NaN; give them zero
    # density so that every point is still plotted.
    z[np.isnan(z)] = 0.0

    # Sort the points by density, so that the densest points are plotted last
    if sort:
        idx = z.argsort()
        x, y, z = x[idx], y[idx], z[idx]

    ax.scatter(x, y, c=z, **kwargs)

    # No colorbar is drawn here.
    return ax

# Load the data and trained ensembles

In [None]:
from xas_nne.ml import Ensemble

In [None]:
# Absorbing elements studied, and the MAX_TRAINING_ABSORBERS value that appears in each
# element's data/ensemble path (the two lists are paired positionally).
ATOM_TYPES = ["C", "N", "O"]
MAX_ABS = [7, 4, 3]


def _load_pickle(path):
    """Unpickle the file at ``path`` and close the file handle deterministically.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading; only use
    this on trusted, locally generated files.
    """
    with open(path, "rb") as infile:
        return pickle.load(infile)


# Per-element ML-ready datasets, keyed by element symbol.
data = {
    key: _load_pickle(f"../data/qm9/ml_ready/XANES-220629-ACSF-{key}-MAX_TRAINING_ABSORBERS-{aa}.pkl")
    for key, aa in zip(ATOM_TYPES, MAX_ABS)
}
# Per-element trained ensembles, rebuilt from their JSON description.
ensembles = {
    key: Ensemble.from_dict(read_json(f"Ensembles/220629-{key}-MAX_TRAINING_ABSORBERS-{aa}/Ensemble.json"))
    for key, aa in zip(ATOM_TYPES, MAX_ABS)
}

In [None]:
@cache
def load_data(atom_type, ensembles=ensembles, data=data):
    """Return ``(pred, gt)`` for the test split of one absorber type.

    ``pred`` is whatever ``ensembles[atom_type].predict`` returns for the test
    inputs (presumably one prediction per estimator stacked along axis 0 -
    TODO confirm against ``Ensemble.predict``); ``gt`` is the matching test target.

    The result is memoized with ``functools.cache``. Because the cache hashes its
    arguments, only the default ``ensembles`` / ``data`` can be used in practice:
    passing a dict explicitly raises ``TypeError`` (unhashable). The defaults are
    also bound when this cell runs, so re-run it after reloading the data.
    """
    test_split = data[atom_type]["test"]
    predictions = ensembles[atom_type].predict(test_split["x"])
    return predictions, test_split["y"]

# Evaluate the ensemble effectiveness

## Error histograms

In [None]:
# One histogram panel per absorbing element
fig, axs = plt.subplots(1, 3, figsize=(6, 6/3/1.6))

# Histogram bin edges in log10(error): 50 ascending edges, 0.05 apart, the last one at 0.2
bins = [0.2 - k * 0.05 for k in reversed(range(50))]

for ii, (ax, atom_type) in enumerate(zip(axs, ATOM_TYPES)):

    ax.set_title(atom_type)

    # Ensemble predictions and ground truth on the test split
    pred, gt = load_data(atom_type)

    # Error of the ensemble-averaged prediction, one value per test spectrum
    ensemble_abs_residual = np.abs(gt - pred.mean(axis=0))
    ensemble_err = np.mean(ensemble_abs_residual, axis=-1)
    log_ensemble_err = np.log10(ensemble_err)

    # Error of each individual estimator; the log-errors are then averaged along axis 0
    individual_errs = np.mean(np.abs(gt - pred), axis=-1)
    log_individual_errs = np.log10(individual_errs)
    avg_log_estimator_err = np.mean(log_individual_errs, axis=0)

    # Baseline: the error made by always predicting the mean spectrum of the test set
    average_spectrum_in_testing_set = np.mean(gt, axis=0)
    baseline_abs_residual = np.abs(average_spectrum_in_testing_set - gt)
    dummy_testing_set_error = np.log10(np.mean(baseline_abs_residual))
    ax.axvline(dummy_testing_set_error, color="blue", linestyle="--", linewidth=0.5, zorder=0)

    print(atom_type)
    print(log_ensemble_err.shape, avg_log_estimator_err.shape)

    # Histograms; only the middle panel carries the legend labels
    is_legend_panel = ii == 1
    ax.hist(
        log_ensemble_err,
        bins=bins,
        color="black",
        label=r"$\varepsilon^{(i)}$" if is_legend_panel else None,
    )
    ax.hist(
        avg_log_estimator_err,
        bins=bins,
        color="red",
        alpha=0.5,
        label=r"$\varepsilon_\mathrm{est}^{(i)}$" if is_legend_panel else None,
    )

    # Medians (black / red) and the baseline value (blue) written onto the panel
    annotation_style = dict(va="center", transform=ax.transAxes, fontsize=8)
    ax.text(0.1, 0.3, r"$%.02f$" % np.median(log_ensemble_err), color="black", ha="left", **annotation_style)
    ax.text(0.1, 0.2, r"$%.02f$" % np.median(avg_log_estimator_err), color="red", ha="left", **annotation_style)
    t = ax.text(0.9, 0.4, r"$%.02f$" % dummy_testing_set_error, color="blue", ha="right", **annotation_style)
    t.set_bbox(dict(facecolor='white', alpha=1, edgecolor='white'))

    # Fine tuning
    adj.set_grids(ax, grid=False)
    ax.set_yticklabels([])
    ax.set_xticks([-3, -2, -1, 0])
    adj.set_xlim(ax, -3, 2)

    # Reference count level per panel: first panel (carbon) 10000, second (nitrogen) 500,
    # any other panel (oxygen) 400. Drawn as a dashed line and labelled in gray.
    val = {0: 10000, 1: 500}.get(ii, 400)
    ax.text(0.1, 0.8, val, ha="left", va="center", transform=ax.transAxes, color="gray")
    ax.axhline(val, color="gray", linestyle="--", linewidth=0.5, zorder=0)

axs[1].set_xlabel(r"$\log_{10} \varepsilon^{(i)}$")
axs[0].set_ylabel("Counts")
axs[1].legend(frameon=False, loc="center left", fontsize=10)

plt.subplots_adjust(wspace=0.1)

# plt.savefig("qm9_hists_generalize.svg", bbox_inches="tight", dpi=300)
# needs post-processing on InkScape
plt.show()

## Plot the correlation between error and std

In [None]:
# Slice step used when selecting spectra below; None keeps every spectrum, an integer
# k keeps every k-th one (handy for quick test renders).
debug = None

fig, axs = plt.subplots(1, 3, figsize=(6, 2), sharex=True, sharey=True)

# Number of 2D-histogram bins used for the density coloring of each element
bins_map = {"C": 30, "N": 15, "O": 25}

for ii, (ax, atom_type) in enumerate(zip(axs, ATOM_TYPES)):

    # load_data is memoized, so calling it again in this cell is cheap
    pred, gt = load_data(atom_type)

    abs_err_of_mean = np.abs(gt - pred.mean(axis=0))
    spread = pred.std(axis=0)

    # Pointwise quantities (not plotted at the moment)
    ensemble_pointwise_err = abs_err_of_mean.flatten()
    ensemble_pointwise_std = spread.flatten()

    # Spectrum-wise quantities: averages over the last axis
    ensemble_err = abs_err_of_mean.mean(axis=1)
    ensemble_std = spread.mean(axis=-1)

    ax.set_title(atom_type)
    x = np.log10(ensemble_err[::debug])
    y = np.log10(ensemble_std[::debug])
    ax = density_scatter(
        x,
        y,
        ax=ax,
        sort=True,
        bins=bins_map[atom_type],
        s=0.4,
        alpha=1,
        rasterized=True,
    )

    # Straight-line fit of log10(std) against log10(error), drawn as a dashed guide
    idx = np.argsort(x)
    p = np.polyfit(x[idx], y[idx], deg=1)
    poly = np.poly1d(p)
    ax.axline(
        (-3, poly(-3)),
        (-2, poly(-2)),
        color="black",
        zorder=0,
        linestyle="--",
        linewidth=0.5,
    )

# Identical limits and ticks on every panel
for ax in axs.ravel():
    ax.set_xlim(-3, 2)
    ax.set_ylim(-3, 2)
    ax.set_xticks([-3, -2, -1, 0, 1])
    ax.set_yticks([-3, -2, -1, 0, 1])
    adj.set_grids(ax, grid=False)
    adj.set_xlim(ax, -3, 2)
    adj.set_ylim(ax, -3, 2)

axs[1].set_xlabel(r"$\log_{10} \varepsilon^{(i)}$")
axs[0].set_ylabel(r"$\log_{10} \sigma$")

plt.subplots_adjust(wspace=0.1)

# plt.savefig("qm9_sigma_parity_generalize.svg", bbox_inches="tight", dpi=300)
plt.show()

## Everything together

In [None]:
# Slice step used when selecting spectra for the scatter plots; None keeps all of them.
debug = None
# Number of 2D-histogram bins used for the density coloring of each element
bins_map = {"C": 30, "N": 15, "O": 25}

# Top row: error histograms. Bottom row: spread-vs-error scatter. One column per element.
fig, axs = plt.subplots(2, 3, figsize=(6, 3.3), gridspec_kw={"height_ratios": [1.3, 2]}, sharex=True)

# Histogram bin edges in log10(error): 50 ascending edges, 0.05 apart, the last one at 0.2
bins = [0.2 - k * 0.05 for k in reversed(range(50))]

for ii, atom_type in enumerate(ATOM_TYPES):

    # Get the data
    pred, gt = load_data(atom_type)
    abs_err_of_mean = np.abs(gt - pred.mean(axis=0))
    spread = pred.std(axis=0)

    #### UPPER PLOTS ####

    ax = axs[0, ii]
    ax.set_title(atom_type)

    # Error of the ensemble-averaged prediction, one value per test spectrum
    ensemble_err = np.mean(abs_err_of_mean, axis=-1)
    log_ensemble_err = np.log10(ensemble_err)

    # Error of each individual estimator; the log-errors are then averaged along axis 0
    individual_errs = np.mean(np.abs(gt - pred), axis=-1)
    log_individual_errs = np.log10(individual_errs)
    avg_log_estimator_err = np.mean(log_individual_errs, axis=0)

    # Baseline: the error made by always predicting the mean spectrum of the test set
    average_spectrum_in_testing_set = np.mean(gt, axis=0)
    baseline_abs_residual = np.abs(average_spectrum_in_testing_set - gt)
    dummy_testing_set_error = np.log10(np.mean(baseline_abs_residual))
    ax.axvline(dummy_testing_set_error, color="blue", linestyle="--", linewidth=0.5, zorder=0)

    print(atom_type)
    print(log_ensemble_err.shape, avg_log_estimator_err.shape)

    # Histograms; only the middle column carries the legend labels
    is_legend_panel = ii == 1
    ax.hist(
        log_ensemble_err,
        bins=bins,
        color="black",
        label=r"$\varepsilon^{(i)}$" if is_legend_panel else None,
    )
    ax.hist(
        avg_log_estimator_err,
        bins=bins,
        color="red",
        alpha=0.5,
        label=r"$\varepsilon_\mathrm{est}^{(i)}$" if is_legend_panel else None,
    )

    # Medians (black / red) and the baseline value (blue) written onto the panel
    annotation_style = dict(va="center", transform=ax.transAxes, fontsize=8)
    ax.text(0.1, 0.3, r"$%.02f$" % np.median(log_ensemble_err), color="black", ha="left", **annotation_style)
    ax.text(0.1, 0.2, r"$%.02f$" % np.median(avg_log_estimator_err), color="red", ha="left", **annotation_style)
    t = ax.text(0.9, 0.4, r"$%.02f$" % dummy_testing_set_error, color="blue", ha="right", **annotation_style)
    t.set_bbox(dict(facecolor='white', alpha=1, edgecolor='white'))

    # Fine tuning
    adj.set_grids(ax, grid=False)
    ax.set_yticklabels([])
    ax.set_xticks([-3, -2, -1, 0])
    adj.set_xlim(ax, -3, 0)

    # Reference count level: 10000 for the first column (carbon), 300 for the others
    val = 10000 if ii == 0 else 300
    ax.text(0.1, 0.8, val, ha="left", va="center", transform=ax.transAxes, color="gray")
    ax.axhline(val, color="gray", linestyle="--", linewidth=0.5, zorder=0)

    #### LOWER PLOTS ####

    # Pointwise quantities (not plotted at the moment)
    ensemble_pointwise_err = abs_err_of_mean.flatten()
    ensemble_pointwise_std = spread.flatten()

    # Spectrum-wise quantities: averages over the last axis
    ensemble_err = abs_err_of_mean.mean(axis=1)
    ensemble_std = spread.mean(axis=-1)

    ax = axs[1, ii]
    x = np.log10(ensemble_err[::debug])
    y = np.log10(ensemble_std[::debug])
    ax = density_scatter(
        x,
        y,
        ax=ax,
        sort=True,
        bins=bins_map[atom_type],
        s=0.4,
        alpha=1,
        rasterized=True,
    )

    # Straight-line fit of log10(std) against log10(error) and its squared correlation
    idx = np.argsort(x)
    p = np.polyfit(x[idx], y[idx], deg=1)
    poly = np.poly1d(p)
    ax.axline(
        (-3, poly(-3)),
        (-2, poly(-2)),
        color="black",
        zorder=0,
        linestyle="--",
        linewidth=0.5,
    )
    r2 = np.corrcoef(x[idx], y[idx])[0, 1]**2
    ax.text(0.1, 0.1, r"$r^2 = %.02f$" % r2, ha="left", va="bottom", transform=ax.transAxes)

    ax.set_xlim(-3, 1)
    ax.set_ylim(-3, 1)
    ax.set_xticks([-3, -1, 1])
    ax.set_yticks([-3, -1, 1])
    adj.set_grids(ax, grid=False)
    adj.set_xlim(ax, -3, 1)
    adj.set_ylim(ax, -3, 1)

# Axis labels and legend live on the first / middle columns only
axs[1, 1].set_xlabel(r"$\log_{10} \varepsilon^{(i)}$")
axs[0, 0].set_ylabel("Counts")
axs[1, 0].set_ylabel(r"$\log_{10} \sigma$")
axs[0, 1].legend(frameon=False, loc="center left", fontsize=10)

# Only the left-most scatter panel keeps its y tick labels
for col in (1, 2):
    axs[1, col].set_yticklabels([])

# Panel letters (a)-(f), assigned in row-major order
letters = [f"({c})" for c in "abcdef"]
for letter, ax in zip(letters, axs.flatten()):
    ax.text(0.1, 0.5, letter, ha="left", va="center", transform=ax.transAxes)

plt.subplots_adjust(wspace=0.1, hspace=0.1)

plt.savefig("qm9_hists_generalize.svg", bbox_inches="tight", dpi=300)
# needs post-processing on InkScape
# plt.show()

## Plot some examples

For the "generalizing" dataset, there is actually a more robust test we can perform. Instead of predicting the site spectra, we can predict the molecular spectra by averaging the predictions and properly propagating errors.

In [None]:
def get_molecular_spectra(atom_type, data=data):
    """Aggregate the test-set predictions of one absorber type per molecule.

    Test rows are grouped by their ``origin_smiles`` entry. For every group the
    ground truth and the ensemble-mean prediction are averaged over its rows, and
    the per-row ensemble standard deviations are combined by error propagation
    for a mean: ``sqrt(sum_i sigma_i**2) / N``.

    NOTE(review): predictions and ground truth come from ``load_data(atom_type)``,
    which uses its own default data, while the SMILES strings come from the
    ``data`` argument here - keep the two consistent.

    Returns
    -------
    tuple
        ``(smiles, gt, preds, spreads)``: a list of SMILES strings in order of
        first appearance, followed by three arrays with one row per molecule.
    """
    pred, gt = load_data(atom_type)
    origin_smiles = data[atom_type]["test"]["origin_smiles"]

    # Row indices of the test set, grouped by the SMILES string they originate from
    rows_by_smiles = {}
    for row, smile in enumerate(origin_smiles):
        rows_by_smiles.setdefault(smile, []).append(row)

    molecular_smiles, molecular_gt = [], []
    molecular_preds, molecular_spreads = [], []
    for smile, rows in tqdm(rows_by_smiles.items()):
        rows = np.array(rows)

        # [E x N_atoms_in_molecule x M] slice of the ensemble predictions
        site_predictions = pred[:, rows, :]
        site_mean = site_predictions.mean(axis=0)
        site_std = site_predictions.std(axis=0)  # [N_atoms... x M]
        n_sites = site_std.shape[0]

        molecular_smiles.append(smile)
        molecular_gt.append(gt[rows, :].mean(axis=0))

        # The molecular prediction is the plain average over the rows
        molecular_preds.append(site_mean.mean(axis=0))

        # The spread needs error propagation:
        # sigma_f = sqrt(sigma_1^2 + sigma_2^2 + ...) / N
        molecular_spreads.append(np.sqrt((site_std**2).sum(axis=0)) / n_sites)

    return molecular_smiles, np.array(molecular_gt), np.array(molecular_preds), np.array(molecular_spreads)

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(3*2.5, 4), sharey=True)
smiles_list = []
n = 5  # number of example spectra stacked in each panel

# Somewhat "random" offsets in order to get spectra that look halfway decent.
# We address the "catastrophic failure" cases in the manuscript separately
push_by_atom_type = {"C": 101, "N": 113}
default_push = 84

for ii, atom_type in enumerate(ATOM_TYPES):

    grid = data[atom_type]["train"]["grid"]

    ax = axs[ii]
    ax.set_title(atom_type)

    # Molecular spectra, ranked from the largest to the smallest mean absolute error
    molecular_smiles, molecular_gt, molecular_preds, molecular_spreads = get_molecular_spectra(atom_type)
    err = np.mean(np.abs(molecular_gt - molecular_preds), axis=1)
    sorted_idx = np.argsort(err)[::-1]
    push = push_by_atom_type.get(atom_type, default_push)

    # The inner index is deliberately NOT called ``ii`` (it used to shadow the panel
    # index), and the per-example error gets its own name instead of overwriting ``err``.
    for jj in range(n):

        # Vertical shift that stacks the examples, and the rank picked for this example:
        # one from each n-th of the error ranking, shifted by ``push``
        offset = jj * 6
        idx = sorted_idx[int(jj / n * len(sorted_idx)) + push]

        ground_truth_spectra = molecular_gt[idx, :] + offset
        predicted_spectra = molecular_preds[idx, :] + offset
        predicted_spread = molecular_spreads[idx, :].copy()
        smiles = molecular_smiles[idx]

        # Legend labels are attached on the nitrogen panel only
        label = r"$\mu_\mathrm{mol}^{(i)}$" if jj == 0 and atom_type == "N" else None
        ax.plot(grid, ground_truth_spectra, "k-", label=label)

        label = r"$\hat{\mu}_\mathrm{mol}^{(i)}$" if jj == 1 and atom_type == "N" else None
        ax.plot(grid, predicted_spectra, color="red", linewidth=1, label=label)

        # The bottom example shows a 1-sigma band (magenta), all others a 3-sigma band (red)
        if jj != 0:
            label = r"$3\sigma$" if jj == 1 and atom_type == "N" else None
            ax.fill_between(grid, predicted_spectra - predicted_spread * 3, predicted_spectra + predicted_spread * 3, color="red", alpha=0.5, linewidth=0, label=label)
        else:
            label = r"$\sigma$" if jj == 0 and atom_type == "N" else None
            ax.fill_between(grid, predicted_spectra - predicted_spread, predicted_spectra + predicted_spread, color="magenta", alpha=0.5, linewidth=0, label=label)

        # log10 of the mean absolute error of this example (the offset cancels out)
        log_err = np.log10(np.mean(np.abs(ground_truth_spectra - predicted_spectra))).item()
        ax.text(0.9, 0.2 + jj / 10.7 * 2, r"$%.02f$" % log_err, ha="right", va="center", transform=ax.transAxes)

        print(f"{log_err:.02f} : {smiles}")
        smiles_list.append(smiles)

    # Only draw a legend where labelled artists exist; calling ax.legend() on the other
    # panels makes matplotlib warn and adds an empty legend artist to the figure.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(frameon=False, ncol=1, loc="upper center", bbox_to_anchor=(1.0, 1.1))

    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.spines["left"].set_linewidth(0.5)
    ax.spines["bottom"].set_linewidth(0.5)
    adj.set_grids(ax, grid=False, top=False)
    ax.set_yticklabels([])
    ax.set_yticks([])

axs[1].set_xlabel(r"$E$~(e.V.)")
axs[0].set_ylabel(r"$\mu(E)$~(a.u.)")


# plt.show()
plt.savefig("qm9_C_random_preds_generalize.svg", bbox_inches="tight", dpi=300)



In [None]:
# Display the SMILES strings collected by the example-spectra cell, in the order they
# were appended (element by element, following ATOM_TYPES).
smiles_list

In [None]:
from rdkit import Chem
# ``Draw`` is a submodule of ``rdkit.Chem``; import it explicitly rather than relying on
# ``Chem.Draw`` having been loaded as a side effect of something else.
from rdkit.Chem import Draw

# Draw every example molecule into a single SVG grid
molecules = [Chem.MolFromSmiles(smile) for smile in smiles_list]
grid = Draw.MolsToGridImage(molecules, useSVG=True)

# NOTE(review): depending on whether RDKit's IPython integration is active, the result is
# either a display object carrying the markup in ``.data`` or the SVG markup itself as a
# ``str`` - handle both so the export does not depend on the notebook environment.
svg_text = grid.data if hasattr(grid, "data") else grid
with open("qm9_C_generalize_list.svg", "w") as f:
    f.write(svg_text)