In [None]:
# Automatically reload edited modules (e.g. xas_nne) before executing each cell.
%load_ext autoreload
%autoreload 2
# Disable jedi-based tab completion (fall back to IPython's own completer).
%config Completer.use_jedi = False

In [None]:
from datetime import datetime
from pathlib import Path
import pickle
import sys
import numpy as np
from scipy.stats import sem
from matplotlib.colors import Normalize 
from scipy.interpolate import interpn
from IPython.display import clear_output
from rdkit import Chem
from PyAstronomy.pyasl import broadGaussFast

import matplotlib.pyplot as plt
from matplotlib import cm
from pymatgen.core.structure import Molecule

In [None]:
# https://gist.github.com/x94carbone/f5201b1c44963ff9453b9cc1d5f768ac
# Put ~/local on the import path so `mpl_utils` (presumably the gist linked above) can be imported.
sys.path.append(str(Path.home() / Path("local")))
from mpl_utils import MPLAdjutant
# Shared plotting helper used throughout the notebook (adj.set_grids, adj.set_xlim, ...).
adj = MPLAdjutant()
# NOTE(review): presumably installs notebook-wide Matplotlib style defaults — confirm in mpl_utils.
adj.set_defaults()

In [None]:
import matplotlib
# Load amsmath in the LaTeX preamble. This only matters when Matplotlib renders text with LaTeX
# (presumably enabled by adj.set_defaults() — TODO confirm). Since Matplotlib 3.3 this rcParam
# is a single string. The old list-of-strings form is deprecated and unsupported in newer releases.
matplotlib.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"

In [None]:
import json

def save_json(d, path):
    """Write ``d`` to ``path`` as JSON, indented by 4 spaces with keys sorted."""
    with open(path, 'w') as fh:
        json.dump(d, fh, sort_keys=True, indent=4)


def read_json(path):
    """Parse the JSON file at ``path`` and return the resulting object."""
    with open(path, 'r') as fh:
        return json.load(fh)

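The project's `home` path (`~/local`) was appended to `sys.path` above so that `mpl_utils` could be imported. The next cell defines a helper for scatter plots colored by point density.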

In [None]:
# https://stackoverflow.com/questions/20105364/how-can-i-make-a-scatter-plot-colored-by-density-in-matplotlib
def density_scatter(x, y, ax, sort=True, bins=20, **kwargs):
    """Scatter plot whose points are colored by the local 2D point density.

    The density is a normalized 2D histogram (``np.histogram2d(..., density=True)``),
    evaluated at the bin centers and interpolated back onto every point with
    ``interpn(method="splinef2d")``. Points outside the grid of bin centers get NaN
    from the interpolation. Those points are given density 0 so they are still drawn.

    Parameters
    ----------
    x, y : array-like
        Point coordinates of equal length. Lists, ndarrays and masked arrays are accepted.
    ax : matplotlib Axes
        Axes to draw on.
    sort : bool
        If True, draw the points in order of increasing density, so the densest points end up on top.
    bins : int or sequence
        Passed to ``np.histogram2d``.
    **kwargs
        Forwarded to ``ax.scatter``. ``c`` is set here and must not be passed.

    Returns
    -------
    Whatever ``ax.scatter`` returns (the scatter collection).
    """
    # asanyarray turns lists into ndarrays, which the fancy indexing below needs.
    # Masked arrays stay masked instead of being silently unmasked.
    x = np.asanyarray(x)
    y = np.asanyarray(y)

    hist, x_e, y_e = np.histogram2d(x, y, bins=bins, density=True)
    x_centers = 0.5 * (x_e[1:] + x_e[:-1])
    y_centers = 0.5 * (y_e[1:] + y_e[:-1])
    z = interpn((x_centers, y_centers), hist, np.vstack([x, y]).T, method="splinef2d", bounds_error=False)

    # To be sure to plot all data: points outside the bin-center grid come back as NaN.
    z[np.isnan(z)] = 0.0

    # Sort the points by density, so that the densest points are plotted last
    if sort:
        idx = z.argsort()
        x, y, z = x[idx], y[idx], z[idx]

    return ax.scatter(x, y, c=z, **kwargs)

# Load the data and trained ensembles

In [None]:
def load_trained_ensembles(ensemble_root_path="Ensembles"):
    """Find and deserialize every ``ensemble.json`` below ``ensemble_root_path``.

    The result is nested as ``ensembles[atom_key][downsample_prop]``:

    * ``downsample_prop`` is the last component of the JSON file's parent directory,
      parsed as a float.
    * ``atom_key`` is the ``-``-delimited token that follows ``-ACSF-`` in the path.
      When the path contains ``TOTAL-ATOMS``, ``-{n}`` is appended, where ``n`` is the
      token immediately before ``-TOTAL-ATOMS``.

    Each value is ``Ensemble.from_dict(read_json(path))``.
    """
    ensembles = {}
    for json_path in Path(ensemble_root_path).rglob("ensemble.json"):
        proportion = float(json_path.parent.parts[-1])
        path_str = str(json_path)
        key = path_str.split("-ACSF-")[1].split("-")[0]
        if "TOTAL-ATOMS" in path_str:
            n_total = path_str.split("-TOTAL-ATOMS")[0].split("-")[-1]
            key = f"{key}-{n_total}"
        ensembles.setdefault(key, {})[proportion] = Ensemble.from_dict(read_json(json_path))
    return ensembles

def load_data(root="data/qm9/ml_ready/random_splits", tag="XANES-220712", atom_types=("O", "N", "C")):
    """Load the pickled random train/test splits for each atom type.

    Reads ``{root}/{tag}-ACSF-{atom}-RANDOM-SPLITS.pkl`` for every ``atom`` in
    ``atom_types``. The defaults reproduce the original hard-coded O/N/C file paths.

    Returns
    -------
    dict
        ``{atom: unpickled object}``, inserted in the order of ``atom_types``. Later cells
        index these as ``data[atom]["train" | "test"][...]``.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code, so only point this at trusted files.
    Each file is opened in a ``with`` block so its handle is closed right after loading.
    """
    splits = {}
    for atom_type in atom_types:
        path = Path(root) / f"{tag}-ACSF-{atom_type}-RANDOM-SPLITS.pkl"
        with open(path, "rb") as f:
            splits[atom_type] = pickle.load(f)
    return splits

In [None]:
# Load the pickled random train/test splits for the O, N and C atom types.
data = load_data()

In [None]:
from xas_nne.ml import Ensemble

In [None]:
# Deserialize every trained ensemble under ./Ensembles as ensembles[atom_key][downsample].
ensembles = load_trained_ensembles()

# Evaluate the ensemble effectiveness on the randomly sampled data

## Get the results compiled

In [None]:
# Downsampling values p (keys of ensembles[atom_type]), listed from largest to smallest.
downsample_values = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]

In [None]:
# Atom types analysed below; used as keys of data, ensembles, preds, ground_truths, ...
ATOMS = ["C", "N", "O"]

Each `preds[atom_type][downsample]` has shape `(N_ensemble, N_examples, M)`, where `M` is the number of energy-grid points. These predictions are masked `numpy` arrays. The masked entries correspond to outlier predictions (relative to the other estimators) or to entirely unphysical ones. `preds_no_filter` holds the same predictions without any outlier filtering.

In [None]:
# Ensemble predictions on the test set with NO outlier filtering
# (used below as the per-estimator baseline in the error histograms).
preds_no_filter = {
    atom_type: {
        downsample: ensembles[atom_type][downsample].predict(data[atom_type]["test"]["x"])
        for downsample in downsample_values
    } for atom_type in ATOMS
}
# Clear whatever output the predict calls produced.
clear_output()

In [None]:
# Outlier-filtered (masked) ensemble predictions on the test set. These are the filter
# settings passed to Ensemble.predict_filter_outliers (see xas_nne.ml for their meaning).
outlier_filter_kwargs = {
    "sd_mult": 2.0,
    "threshold_sd": 0.7,
    "max_spectra_value": 20.0,
    "threshold_zero": 0.5,
    "min_spectra_value": 0.05,
}
preds = {
    atom_type: {
        downsample: ensembles[atom_type][downsample].predict_filter_outliers(
            data[atom_type]["test"]["x"], **outlier_filter_kwargs
        )
        for downsample in downsample_values
    }
    for atom_type in ATOMS
}
clear_output()

In [None]:
# Ground-truth test spectra per atom type, with negative intensities clipped to zero.
# np.clip returns a new array, so the loaded data[atom_type]["test"]["y"] stays unmodified.
# An in-place `y[y < 0] = 0.0` would also alter `data`, because the dict entry and the
# raw array are the same object.
ground_truths = {
    atom_type: np.clip(data[atom_type]["test"]["y"], 0.0, None) for atom_type in ATOMS
}

In [None]:
# Energy grid of each atom type's spectra (taken from the training split).
grids = {atom_type: data[atom_type]["train"]["grid"] for atom_type in ATOMS}

In [None]:
# Pointwise |ground truth - ensemble mean| for every test example and grid point.
# The mean is taken over estimators of the outlier-filtered (masked) predictions.
errors = {
    atom_type: {
        downsample: np.abs(ground_truths[atom_type] - preds[atom_type][downsample].mean(axis=0))
        for downsample in downsample_values
    }
    for atom_type in ATOMS
}

## Plot the average errors $\varepsilon$ for each atom type as a function of the downsampling value $p$

In [None]:
# Mean absolute error of the ensemble mean, averaged over all test examples and grid
# points, versus the downsampling value p (one curve per atom type). Saved to Figures/.
fig, ax = plt.subplots(1, 1, figsize=(3, 2))

colors = {"C": "black", "N": "blue", "O": "red"}

for atom_type in ATOMS:
    e = []
    for downsample in downsample_values:
        e.append(errors[atom_type][downsample].mean())
    e = np.array(e)
    print(e)
    # Plotted as 100 * epsilon, matching the y label.
    ax.plot(downsample_values, e*100, marker="o", color=colors[atom_type], label=atom_type)
    
adj.set_grids(ax, grid=False)
ax.legend(frameon=False)
ax.set_xticks([0.1, 0.5, 0.9])
# Hand-tuned y-range for the current results; adjust if the errors change.
ax.set_ylim(2.9, 6.1)
ax.set_xlabel("$p$")
ax.set_ylabel(r"$100\varepsilon(p)$")

plt.savefig("Figures/qm9_errors_as_p.svg", bbox_inches="tight", dpi=300)
# plt.show()

## Plot some examples

Sort by the errors.

In [None]:
# Atom type and downsampling value used for the single-example plot below.
atom_type = "N"
downsample = 0.9

In [None]:
# Test-example indices ordered by increasing error (averaged over the energy grid).
argsorted = np.argsort( errors[atom_type][downsample].mean(axis=-1) )

Decide on an example and plot it.

In [None]:
# NOTE(review): these full-array copies are not used. The example-plotting cell below
# reassigns both names to a single example before reading them.
predicted_spectra = preds[atom_type][downsample].copy()
ground_truth_spectra = ground_truths[atom_type].copy()

In [None]:
# Number of test examples for the selected atom type.
len(argsorted)

In [None]:
# Pick the 10th-worst test example (argsorted is in increasing error). Plot every
# estimator's prediction against the ground truth, with a +/- 3 sigma ensemble band.
ii = -10
ii = argsorted[ii]

predicted_spectra = preds[atom_type][downsample][:, ii, :]
ground_truth_spectra = ground_truths[atom_type][ii, :]

# Ensemble mean and spread over estimators (masked outlier entries are excluded).
mu = predicted_spectra.mean(axis=0)
sd = predicted_spectra.std(axis=0)
# cond = (predicted_spectra > mu + 3 * sd) | (predicted_spectra < mu - 3 * sd)
# where_keep = np.where(cond.sum(axis=1) < 150)[0]
# predicted_spectra = predicted_spectra[where_keep, :]
# mu = predicted_spectra.mean(axis=0)
# sd = predicted_spectra.std(axis=0)

fig, ax = plt.subplots(1, 1, figsize=(3, 2))

print(data[atom_type]["test"]["origin_smiles"][ii])

ax.plot(grids[atom_type], ground_truth_spectra, "k-")

for prediction in predicted_spectra:
    ax.plot(grids[atom_type], prediction, 'r-', linewidth=0.5, alpha=0.5)

# ax.plot(grids[atom_type], mu, "r-")
ax.fill_between(grids[atom_type], mu - sd * 3, mu + sd * 3, alpha=0.4, color="red", linewidth=0)

# log10 of the mean absolute error of the ensemble mean for this example.
err = np.mean(np.abs(ground_truth_spectra - mu))
print(f"{np.log10(err):.02f}")

plt.show()


## Error histograms

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(6, 6/3/1.6))

downsample = 0.9

# Histogram bin edges in log10(error): 50 edges spaced 0.05 apart, ending at 0.2.
# Named `hist_bins` (not `bins`) so that running this cell does not overwrite the global
# `bins` (log10-sigma edges) read by `make_violin_plot` and the sigma-parity figure.
hist_bins = [0.2 - k * 0.05 for k in range(50)][::-1]

# Height (in counts) of the gray dashed reference line drawn in each panel.
reference_counts = {"C": 3000, "N": 1500, "O": 2000}

for ii, (ax, atom_type) in enumerate(zip(axs, ["C", "N", "O"])):

    ax.set_title(atom_type)

    # Outlier-filtered (masked) and unfiltered predictions, plus the ground truth
    pred = preds[atom_type][downsample].copy()
    pred2 = preds_no_filter[atom_type][downsample].copy()
    gt = ground_truths[atom_type].copy()

    # Ensemble error: per-example mean absolute error of the filtered ensemble mean
    ensemble_err = np.mean(np.abs(gt - pred.mean(axis=0)), axis=-1)
    log_ensemble_err = np.log10(ensemble_err)

    # Per-estimator errors (unfiltered), averaged over estimators in log10 space
    individual_errs = np.mean(np.abs(gt - pred2), axis=-1)
    log_individual_errs = np.log10(individual_errs)
    avg_log_estimator_err = np.mean(log_individual_errs, axis=0)

    # Baseline: always predict the average testing-set spectrum
    average_spectrum_in_testing_set = np.mean(gt, axis=0)
    dummy_testing_set_error = np.log10(np.mean(np.abs(average_spectrum_in_testing_set - gt)))
    ax.axvline(dummy_testing_set_error, color="blue", linestyle="--", linewidth=0.5, zorder=0)

    # Plot (legend labels only on the middle panel) and annotate the mean log errors
    ax.hist(log_ensemble_err, bins=hist_bins, color="black", label=r"$\varepsilon^{(i)}$" if ii==1 else None)
    ax.hist(avg_log_estimator_err, bins=hist_bins, color="red", alpha=0.5, label=r"$\varepsilon_\mathrm{est}^{(i)}$" if ii==1 else None)
    ax.text(0.1, 0.3, r"$%.02f$" % np.mean(log_ensemble_err), color="black", ha="left", va="center", transform=ax.transAxes, fontsize=8)
    ax.text(0.1, 0.2, r"$%.02f$" % np.mean(avg_log_estimator_err), color="red", ha="left", va="center", transform=ax.transAxes, fontsize=8)
    t = ax.text(0.9, 0.4, r"$%.02f$" % dummy_testing_set_error, color="blue", ha="right", va="center", transform=ax.transAxes, fontsize=8)
    t.set_bbox(dict(facecolor='white', alpha=1, edgecolor='white'))

    # Fine tuning
    adj.set_grids(ax, grid=False)
    ax.set_yticklabels([])
    ax.set_xticks([-3, -2, -1, 0])
    adj.set_xlim(ax, -3, 0)

    # Gray reference count line, labelled explicitly because y tick labels are hidden
    val = reference_counts[atom_type]
    ax.text(0.1, 0.8, val, ha="left", va="center", transform=ax.transAxes, color="gray")
    ax.axhline(val, color="gray", linestyle="--", linewidth=0.5, zorder=0)

axs[1].set_xlabel(r"$\log_{10} \varepsilon^{(i)}$")
axs[0].set_ylabel("Counts")
axs[1].legend(frameon=False, loc="center left", fontsize=10)

plt.subplots_adjust(wspace=0.1)

# plt.savefig("Figures/qm9_hists.svg", bbox_inches="tight", dpi=300)
# needs post-processing on InkScape
# clear_output()
plt.show()

## Error histograms (full plot of everything; might not use this)

In [None]:
scale = 1.5

# squeeze=False keeps `axs` 2-D (rows = downsample values, columns = C/N/O), even when
# downsample_values has a single entry.
fig, axs = plt.subplots(len(downsample_values), 3, figsize=(4 * scale, 6 * scale), sharex=True, squeeze=False)

# Histogram bin edges in log10(error): 50 edges spaced 0.05 apart, ending at 0.2.
# Named `hist_bins` (not `bins`) so that running this cell does not overwrite the global
# `bins` (log10-sigma edges) read by `make_violin_plot` and the sigma-parity figure.
hist_bins = [0.2 - k * 0.05 for k in range(50)][::-1]

for ii, (tmp_ax, downsample) in enumerate(zip(axs, downsample_values)):

    for jj, atom_type in enumerate(["C", "N", "O"]):
        ax = tmp_ax[jj]

        # Get the data
        pred = preds[atom_type][downsample].copy()
        gt = ground_truths[atom_type].copy()

        # Ensemble error itself, computed from the outlier-filtered (masked) predictions.
        # NOTE(review): confirm whether this should instead include the outlier
        # predictions (i.e. use `preds_no_filter`).
        ensemble_err = np.mean(np.abs(gt - pred.mean(axis=0)), axis=-1)
        log_ensemble_err = np.log10(ensemble_err)

        # Individual errors from each (unfiltered) estimator, averaged over estimators in log10 space
        pred2 = preds_no_filter[atom_type][downsample].copy()
        individual_errs = np.mean(np.abs(gt - pred2), axis=-1)
        log_individual_errs = np.log10(individual_errs)
        avg_log_estimator_err = np.mean(log_individual_errs, axis=0)

        # Average testing set error as a baseline
        average_spectrum_in_testing_set = np.mean(gt, axis=0)
        dummy_testing_set_error = np.log10(np.mean(np.abs(average_spectrum_in_testing_set - gt)))
        ax.axvline(dummy_testing_set_error, color="blue", linestyle="--", linewidth=0.5, zorder=0)

        # Normalized histograms (legend labels only in the top-middle panel) and their medians
        ax.hist(log_ensemble_err, bins=hist_bins, color="black", label=r"$\varepsilon^{(i)}$" if ii==0 and jj == 1 else None, density=True)
        ax.hist(avg_log_estimator_err, bins=hist_bins, color="red", alpha=0.5, label=r"$\varepsilon_\mathrm{est}^{(i)}$" if ii==0 and jj == 1 else None, density=True)
        ax.text(0.1, 0.4, r"$%.02f$" % np.median(log_ensemble_err), color="black", ha="left", va="center", transform=ax.transAxes, fontsize=8)
        ax.text(0.1, 0.2, r"$%.02f$" % np.median(avg_log_estimator_err), color="red", ha="left", va="center", transform=ax.transAxes, fontsize=8)

        # Labels for the downsample values (first column only)
        if jj == 0:
            t = ax.text(0.1, 0.8, r"$%.01f$" % downsample, color="black", ha="left", va="top", transform=ax.transAxes, fontsize=8)
            # t.set_bbox(dict(facecolor='grey', alpha=0.5, edgecolor='white'))

        # Value of the dummy-model baseline (top row only)
        if ii == 0:
            t = ax.text(0.9, 0.4, r"$%.02f$" % dummy_testing_set_error, color="blue", ha="right", va="center", transform=ax.transAxes, fontsize=8)
            t.set_bbox(dict(facecolor='white', alpha=1, edgecolor='white'))

        # Fine tuning
        adj.set_grids(ax, grid=False)
        ax.set_yticklabels([])
        ax.set_yticks([])
        ax.set_xticks([-3, 0])
        adj.set_xlim(ax, -3, 0)

axs[-1, 1].set_xlabel(r"$\log_{10} \varepsilon^{(i)}$")
# Shared y label on the middle row (row 4 for the default nine downsample values).
axs[len(downsample_values) // 2, 0].set_ylabel("Density")
axs[0, 1].legend(frameon=False, loc="upper left", fontsize=10)
axs[0, 0].set_title("C")
axs[0, 1].set_title("N")
axs[0, 2].set_title("O")

plt.subplots_adjust(wspace=0.05, hspace=0.1)

# plt.savefig("qm9_hists.svg", bbox_inches="tight", dpi=300)
# needs post-processing on InkScape
# clear_output()
plt.show()

## Plot the correlation between error and std

In [None]:
def adjacent_values(vals, q1, q3):
    """Return the (lower, upper) whisker ends for a box/violin plot.

    The whiskers extend 1.5 interquartile ranges beyond ``q1``/``q3``. They are limited
    to the data range (``vals`` must be sorted, so ``vals[0]``/``vals[-1]`` are the
    extremes) and never fall inside the quartile box.
    """
    iqr = q3 - q1
    lower = np.clip(q1 - iqr * 1.5, vals[0], q1)
    upper = np.clip(q3 + iqr * 1.5, q3, vals[-1])
    return lower, upper


def set_axis_style(ax, labels):
    """Set up the x axis of a categorical (e.g. violin) plot.

    Places one labelled tick per entry of ``labels`` at x = 1, 2, ..., with ticks pointing
    outward on the bottom spine only. Pads the x-limits by 0.75 beyond the outer ticks
    and sets the x label to 'Sample name'.
    """
    n_labels = len(labels)
    xaxis = ax.xaxis
    xaxis.set_tick_params(direction='out')
    xaxis.set_ticks_position('bottom')
    positions = np.arange(1, n_labels + 1)
    ax.set_xticks(positions, labels=labels)
    ax.set_xlim(0.25, n_labels + 0.75)
    ax.set_xlabel('Sample name')
    

In [None]:
# Edges in log10(predicted std) used to group points for the violin plots
# (read as a global by make_violin_plot and by the sigma-parity figure below).
bins = [-3.5 + ii for ii in range(5)]

In [None]:
def make_violin_plot(ax, log10_ensemble_pointwise_err, log10_ensemble_pointwise_std, downsample=0.9, std_bins=None):
    """Violin plot of pointwise log10 errors, grouped by log10 predicted standard deviation.

    Points are grouped with ``np.digitize(log10_ensemble_pointwise_std, std_bins)``. One
    violin is drawn per non-empty group, at x = 1, 2, ... in increasing group index.
    Each violin is overlaid with its interquartile range (thick line), the whiskers from
    ``adjacent_values`` (thin line) and its median (white dot). A thin horizontal line is
    also drawn at every group median.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    log10_ensemble_pointwise_err, log10_ensemble_pointwise_std : 1-D arrays
        log10 of the pointwise absolute error and of the ensemble std (same length).
    downsample : float, optional
        Unused; kept so that existing calls keep working.
    std_bins : sequence of float, optional
        Bin edges for the log10 std values. Defaults to the notebook-level ``bins``.

    Notes
    -----
    Each group is colored ``cmap(group_index)``, using its actual ``np.digitize`` index.
    The color therefore identifies the std interval even when lower groups are empty;
    coloring by violin order would shift every color after an empty group.
    """
    if std_bins is None:
        std_bins = bins  # notebook-level log10(sigma) edges defined in an earlier cell

    # plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
    cmap = plt.get_cmap("rainbow", len(std_bins))
    binned_by_std = np.digitize(log10_ensemble_pointwise_std, std_bins)
    group_indices = np.unique(binned_by_std)

    # Sorted errors of each non-empty group (adjacent_values relies on sorted input).
    grouped_errs = [
        sorted(np.array(log10_ensemble_pointwise_err[np.where(binned_by_std == group)[0]]).tolist())
        for group in group_indices
    ]

    parts = ax.violinplot(
        grouped_errs, showmeans=False, showmedians=False, showextrema=False
    )

    for group, pc in zip(group_indices, parts['bodies']):
        pc.set_facecolor(cmap(int(group)))
        pc.set_edgecolor('black')
        pc.set_alpha(1)

    quartile1 = []
    medians = []
    quartile3 = []
    for datum in grouped_errs:
        q1, m, q3 = np.percentile(datum, [25, 50, 75])
        quartile1.append(q1)
        medians.append(m)
        quartile3.append(q3)
    whiskers = np.array([
        adjacent_values(sorted_array, q1, q3)
        for sorted_array, q1, q3 in zip(grouped_errs, quartile1, quartile3)
    ])
    whiskers_min, whiskers_max = whiskers[:, 0], whiskers[:, 1]

    inds = np.arange(1, len(medians) + 1)
    ax.scatter(inds, medians, marker='o', color='white', s=5, zorder=3)
    ax.vlines(inds, quartile1, quartile3, color='k', linestyle='-', lw=5, zorder=2)
    ax.vlines(inds, whiskers_min, whiskers_max, color='k', linestyle='-', lw=1, zorder=2)

    for group, med in zip(group_indices, medians):
        ax.axhline(med, color=cmap(int(group)), linewidth=0.5, zorder=1)

In [None]:
# Slice step used to subsample the scatter points (None = use every point).
debug = None

fig, axs = plt.subplots(2, len(ATOMS), figsize=(6, 3), sharex=False, sharey=True, gridspec_kw={"height_ratios": [1, 2]})

# Number of 2D-histogram bins used for the density coloring, per atom type.
bins_map = {"C": 70, "N": 30, "O": 40}

downsample = 0.9

# (vertical offset in log10 units, linestyle, alpha) of the lines drawn parallel to the fit.
fit_line_styles = [(0.0, "-", 0.9), (0.5, "--", 0.8), (1.0, "--", 0.7), (1.5, "--", 0.6), (2.0, "--", 0.5)]

# x-edges of the shaded log10(sigma) bands. Band k covers the k-th np.digitize(..., bins)
# interval and is colored cmap(k). The outer edges lie beyond the axis limits.
std_band_edges = [-7, -3.5, -2.5, -1.5, -0.5, 3]

ax_min = -6
ax_max = 2
for ax in axs[1, :]:
    ax.set_xlim(ax_min, ax_max)
    ax.set_ylim(ax_min, ax_max)
    ax.set_xticks([-6, -2, 2])
    ax.set_yticks([-6, -2, 2])
    adj.set_grids(ax, grid=False)
    adj.set_xlim(ax, ax_min, ax_max)
    adj.set_ylim(ax, ax_min, ax_max)

for ii, atom_type in enumerate(ATOMS):

    # Get the predictions
    pred = preds[atom_type][downsample].copy()
    gt = ground_truths[atom_type].copy()
    gt[gt < 0] = 0.0

    # Pointwise (every test example x grid point) ensemble error and predicted std
    ensemble_pointwise_err = np.abs(gt - pred.mean(axis=0)).flatten()
    ensemble_pointwise_std = pred.std(axis=0).flatten()

    # Bottom row: density-colored scatter of log10 error vs. log10 std
    ax = axs[1, ii]
    y = np.log10(ensemble_pointwise_err[::debug])
    x = np.log10(ensemble_pointwise_std[::debug])
    scat = density_scatter(x, y, ax=ax, sort=True, bins=bins_map[atom_type], s=0.4, alpha=1, rasterized=True)
    # Linear fit and r^2 computed on every 10th point
    p = np.polyfit(x[::10], y[::10], deg=1)
    r2 = np.corrcoef(x[::10], y[::10])[0, 1]**2
    poly = np.poly1d(p)
    for offset, linestyle, line_alpha in fit_line_styles:
        ax.axline((-3, poly(-3) + offset), (-2, poly(-2) + offset), color="black", linestyle=linestyle, linewidth=0.5, alpha=line_alpha)
    ax.text(0.1, 0.9, r"$r^2 = %.02f$" % r2, ha="left", va="top", transform=ax.transAxes)
    adj.set_xlim(ax, -6, 2)

    # plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
    cmap = plt.get_cmap("rainbow", len(bins))
    alpha = 0.2
    for band_index, (band_lo, band_hi) in enumerate(zip(std_band_edges[:-1], std_band_edges[1:])):
        ax.fill_between(np.linspace(band_lo, band_hi, 100), -7, 3, color=cmap(band_index), alpha=alpha, linewidth=0, zorder=0)

    # if ii == 2:
    #     adj.add_colorbar(scat)

    # Top row: violins of the same errors grouped by std bin
    ax = axs[0, ii]
    make_violin_plot(ax, np.log10(ensemble_pointwise_err), np.log10(ensemble_pointwise_std))
    ax.set_title(atom_type)
    adj.set_grids(ax, grid=False)
    ax.set_xticklabels([])
    ax.tick_params(axis="x", which="both", bottom=False, top=False)


# Invisible figure-wide axes that only carry the shared x/y labels.
ax = fig.add_subplot(111, frameon=False)
# Hide ticks and tick labels of the big axes. Use real booleans: Matplotlib no longer
# interprets the legacy 'on'/'off' strings as booleans.
ax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
ax.set_xticks([])
ax.set_yticks([])

ax.set_ylabel(r"$\log_{10} \varepsilon^{(i)}_j$")
ax.set_xlabel(r"$\log_{10} \hat{\sigma}^{(i)}_j$")

plt.subplots_adjust(wspace=0.1, hspace=0.1)

plt.savefig("Figures/qm9_sigma_parity_with_violins.svg", bbox_inches="tight", dpi=300)
# plt.show()

### Examine some of the outlier predictions in the O database

In [None]:
# Oxygen, downsample 0.9: flattened pointwise log10 error and log10 ensemble std
# (one entry per test example and grid point).
pred = preds["O"][0.9].copy()
gt = ground_truths["O"].copy()
    
# Pointwise
ensemble_pointwise_err = np.log10(np.abs(gt - pred.mean(axis=0)).flatten())
ensemble_pointwise_std = np.log10(pred.std(axis=0).flatten())

In [None]:
# Map each flattened pointwise entry to its test-example index. The flattened arrays are
# (test examples x grid points) in row-major order, so integer-divide by the grid length.
# The length is taken from `gt` rather than hard-coded as 200, so other grid sizes work.
pred_index = np.arange(len(ensemble_pointwise_err)) // gt.shape[-1]

In [None]:
# Flattened indices of overconfident points: small predicted spread (log10 std < -2)
# but large error (log10 error > -1).
where_O_bad = np.where( (ensemble_pointwise_std < -2) & (ensemble_pointwise_err > -1) )[0]

In [None]:
# Test-example index of each such point.
pred_index[where_O_bad]

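Find the training-set spectra closest (in $L_1$ distance) to the ensemble-mean prediction for this example, and look up the corresponding molecules. What's going on here?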

In [None]:
# Hand-picked test example (presumably one of the indices listed above).
ii_star = 14107

In [None]:
# L1 distance from this example's ensemble-mean prediction to every training spectrum.
dists = np.sum(np.abs(pred[:, ii_star, :].mean(axis=0) - data["O"]["train"]["y"]), axis=1)

In [None]:
# Training examples ordered from closest to farthest.
argsorted = np.argsort(dists)

In [None]:
# Display the ordering (first entry = closest training spectrum).
argsorted

In [None]:
# Energy range of the O grid (used to pick the x ticks in the next plot).
print(grids["O"].min())
print(grids["O"].max())

In [None]:
# For the hand-picked O example `ii_star`, plot the ground truth (solid black), the first 30
# estimators' predictions (red), and the training spectrum closest to the ensemble mean (dashed).
# NOTE(review): range(30) assumes the ensemble has at least 30 members.
fig, ax = plt.subplots(1, 1, figsize=(3, 1.5))

ax.plot(grids["O"], gt[ii_star, :], color="black", label=r"{\boldmath$\mu$}$^{(i)}$")
for jj in range(30):
    ax.plot(grids["O"], pred[jj, ii_star, :], alpha=0.5, color="red", label=r"{\boldmath$\hat{\mu}$}$^{(i, k)}$" if jj == 0 else None)

# Rank of the training spectrum to show (0 = closest to the ensemble mean).
as_idx = 0
ax.plot(grids["O"], data["O"]["train"]["y"][argsorted[as_idx]], color="black", linestyle="--", label=r"{\boldmath$\mu$}$^\star$")

# Names of the test example and of the matched training example.
print(data["O"]["test"]["names"][ii_star])
print(data["O"]["train"]["names"][argsorted[as_idx]])

# log10 mean absolute difference between the matched training spectrum and the ensemble
# mean, then log10 of the grid-averaged ensemble std.
avg_pred = pred[:, ii_star, :].mean(axis=0)
std = pred[:, ii_star, :].std(axis=0).mean()
print(np.log10(np.mean(np.abs(data["O"]["train"]["y"][argsorted[as_idx]] - avg_pred))))
print(np.log10(std))


adj.set_grids(ax, grid=False)
ax.set_yticklabels([])
ax.set_xticks([530, 555, 580])
ax.set_xlabel(r"$E$~(e.V.)")
ax.set_ylabel(r"$\mu(E)$~(a.u.)")
ax.legend(frameon=False)

plt.show()
# plt.savefig("Figures/qm9_O_fail_spec.svg", bbox_inches="tight", dpi=300)

In [None]:
from rdkit import Chem
# `from rdkit import Chem` does not necessarily load the Draw submodule; import it explicitly.
from rdkit.Chem import Draw

names = ["CCCC1(C)COC=N1", "CCCC1COC=N1"]

svgs = Draw.MolsToGridImage([Chem.MolFromSmiles(smile) for smile in names], useSVG=True)
# Depending on the RDKit version and whether its IPython renderer is active, this is either an
# IPython SVG object (SVG text in `.data`) or a plain SVG string. Accept both.
svg_text = getattr(svgs, "data", svgs)
with open("Figures/qm9_O_fail.svg", "w") as f:
    f.write(svg_text)

We also ran VASP calculations to try to determine whether this is truly a failure of FEFF.

In [None]:
# Raw FEFF (xmu.dat) and VASP (mu.txt) O-edge spectra for QM9 molecules 013393 and 118981,
# loaded for the side-by-side comparison in the next cell.
feff1_path = "data/qm9/qm9_tests/013393/FEFF/018_O/xmu.dat"
feff1 = np.loadtxt(feff1_path, comments="#")

feff2_path = "data/qm9/qm9_tests/118981/FEFF/021_O/xmu.dat"
feff2 = np.loadtxt(feff2_path, comments="#")

# skiprows=3: the first three lines of mu.txt are skipped (presumably a header — confirm).
vasp1_path = "data/qm9/qm9_tests/013393/VASP/018_O/mu.txt"
vasp1 = np.loadtxt(vasp1_path, skiprows=3)

vasp2_path = "data/qm9/qm9_tests/118981/VASP/021_O/mu.txt"
vasp2 = np.loadtxt(vasp2_path, skiprows=3)

In [None]:
# Compare broadened VASP spectra, the ML predictions for `ii_star` and the FEFF spectra.
# NOTE(review): `pred` is whatever was assigned last; it holds the O predictions only if
# the cells are run in order (later sections reassign `pred`).
fig, ax = plt.subplots(1, 1, figsize=(3, 1.5))



# Gaussian broadening width passed to broadGaussFast, plus an ad-hoc intensity scale and
# energy shift applied to the VASP spectra (presumably chosen by eye to align with FEFF).
sig = 0.5
mul = 10**5 * 2
shift = 18

ax.plot(vasp1[:, 0] + shift, broadGaussFast(vasp1[:, 0], vasp1[:, 3], sig) * mul, "r-")
ax.plot(vasp2[:, 0] + shift, broadGaussFast(vasp2[:, 0], vasp2[:, 3], sig) * mul, "k-")



for jj in range(30):
    ax.plot(grids["O"], pred[jj, ii_star, :], alpha=0.5, color="red", label=r"{\boldmath$\hat{\mu}$}$^{(i, k)}$" if jj == 0 else None)
ax.plot(feff1[:, 0], feff1[:, 3], "k-")
ax.plot(feff2[:, 0], feff2[:, 3], color="grey", linewidth=1)
    
ax.set_xlim(525, 545)

adj.set_grids(ax, grid=False)
ax.set_yticklabels([])
# ax.set_xticks([530, 555, 580])
ax.set_xlabel(r"$E$~(e.V.)")
ax.set_ylabel(r"$\mu(E)$~(a.u.)")
ax.legend(frameon=False)

plt.show()

### Examine some of the outlier predictions in the N database

Some of these turned out to come from cases where the FEFF spectrum itself is slightly negative near the edge onset. This is now fixed: negative ground-truth values are clipped to zero before the errors are computed.

In [None]:
# Nitrogen, downsample 0.9: flattened pointwise log10 error and log10 ensemble std
# (one entry per test example and grid point).
pred = preds["N"][0.9].copy()
gt = ground_truths["N"].copy()
    
# Pointwise
ensemble_pointwise_err = np.log10(np.abs(gt - pred.mean(axis=0)).flatten())
ensemble_pointwise_std = np.log10(pred.std(axis=0).flatten())

In [None]:
# Map each flattened pointwise entry to its test-example index. The flattened arrays are
# (test examples x grid points) in row-major order, so integer-divide by the grid length.
# The length is taken from `gt` rather than hard-coded as 200, so other grid sizes work.
pred_index = np.arange(len(ensemble_pointwise_err)) // gt.shape[-1]

In [None]:
# Flattened indices with very small predicted spread (log10 std < -4) but log10 error > -2.
where_N_bad = np.where( (ensemble_pointwise_std < -4) & (ensemble_pointwise_err > -2) )[0]

In [None]:
# Test-example index of each such point.
pred_index[where_N_bad]

In [None]:
# Quick look at N test example 4 (presumably one of the indices listed above): ensemble-mean
# prediction vs. ground truth, plotted against the grid index.
fig, ax = plt.subplots(1, 1, figsize=(3, 2))

ax.plot(pred[:, 4, :].mean(axis=0))
ax.plot(gt[4, :])

plt.show()

In [None]:
# Percentage of points that "succeed" for each atom type, from hard-coded (count, total) pairs.
# NOTE(review): these counts are not derived in this notebook; presumably they are the number of
# outlier-masked points vs. the total number of points. Confirm their origin.
failure_counts = {
    "C": (972766, 8334600),
    "N": (302591, 2717200),
    "O": (406832, 3669800),
}
for element, (n_bad, n_total) in failure_counts.items():
    xx = 1 - n_bad / n_total
    print(f"{element}: {xx*100:.02f}% succeed")

## Plot some examples (carbon, spanning the error distribution)

In [None]:
def mol_with_atom_index(mol):
    """Set every atom's map number to its own index (mutates ``mol``) and return ``mol``.

    Useful for displaying atom indices when the molecule is drawn.
    """
    atoms = mol.GetAtoms()
    for atom_obj in atoms:
        idx = atom_obj.GetIdx()
        atom_obj.SetAtomMapNum(idx)
    return mol

In [None]:
# Examples from the ensembles trained at downsample value 0.9.
atom_type = "C"
downsample = 0.9
grid = grids[atom_type]
# Indexed by `atom_type` (not a hard-coded "C"), so changing atom_type above switches
# every array consistently.
pred = preds[atom_type][downsample]
gt = ground_truths[atom_type]

In [None]:
# Per-example mean absolute error of the ensemble mean; sorted_idx orders the test
# examples from largest to smallest error.
err = np.mean(np.abs(gt - pred.mean(axis=0)), axis=1)
sorted_idx = np.argsort(err)[::-1]

In [None]:
# Stack n examples taken at evenly spaced ranks of the error ordering (worst first, shifted
# by 4 ranks). Each shows the ground truth, the ensemble mean and a +/- 3 sigma band,
# offset vertically by 6 per example.
fig, ax = plt.subplots(1, 1, figsize=(3, 5))

n = 10
smiles = []
atom_indexes = []
for ii in range(n):
    
    offset = ii * 6
    idx = sorted_idx[int(ii / n * len(sorted_idx)) + 4]
    
    ground_truth_spectra = gt[idx, :]
    predicted_spectra = pred[:, idx, :]

    mu = predicted_spectra.mean(axis=0) + offset
    sd = predicted_spectra.std(axis=0) * 3
    smile = data[atom_type]["test"]["origin_smiles"][idx]
    smiles.append(smile)
    # Second "_"-separated field of the example name (presumably the absorbing-site index).
    atom_index = data[atom_type]["test"]["names"][idx].split("_")[1]
    atom_indexes.append(atom_index)
    print(smile, data[atom_type]["test"]["names"][idx], atom_index)
    
    label = r"{\boldmath$\mu$}$^{(i)}$" if ii == 0 else None
    ax.plot(grid, ground_truth_spectra + offset, "k-", label=label, zorder=2, linewidth=0.5)

    label = r"{\boldmath$\hat{\mu}$}$^{(i)}$" if ii == 0 else None
    ax.plot(grid, mu, color="red", linewidth=1, label=label, zorder=4)
    
#     # for jj, prediction in enumerate(predicted_spectra):
#     #     label = r"$\hat{\mu}^{(i, k)}$" if jj == 0 and ii == 0 else None 
#     #     ax.plot(grid, prediction + offset, 'r-', linewidth=0.5, alpha=0.5, label=label)
#     # label=r"{\boldmath$\mu$}$^\star$"
    label = r"$3${\boldmath$\hat{\sigma}$}$^{(i)}$" if ii == 0 else None
    ax.fill_between(grid, mu - sd, mu + sd, color="red", alpha=0.5, linewidth=0, label=label, zorder=3)
    
    # NOTE(review): this averages the absolute error over ALL individual estimators
    # (`predicted_spectra`), not the ensemble-mean error used to sort the examples.
    # Confirm which one is meant. It also overwrites the global `err` array from the
    # previous cell.
    err = np.log10(np.mean(np.abs(ground_truth_spectra - predicted_spectra))).item()
    ax.text(0.9, 0.09 + ii / 10.5, r"$%.02f$" % err, ha="right", va="center", transform=ax.transAxes)
    ax.text(0.7, 0.09 + ii / 10.5, atom_index, ha="right", va="center", transform=ax.transAxes)
    
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.spines["left"].set_linewidth(0.5)
ax.spines["bottom"].set_linewidth(0.5)
adj.set_grids(ax, grid=False, top=False)
ax.set_yticklabels([])
ax.set_yticks([])
ax.set_xlabel(r"$E$~(e.V.)")
ax.set_ylabel(r"$\mu(E)$~(a.u.)")
ax.legend(frameon=False, ncol=3, loc="upper center", bbox_to_anchor=(0.5, 1.1))

# err = np.log10(np.mean(np.abs(gt[ii] - pred[ii])))
# print(f"{err:.02f}")


# plt.show()
plt.savefig("Figures/qm9_C_random_preds-2.svg", bbox_inches="tight", dpi=300)

In [None]:
from rdkit import Chem
# `from rdkit import Chem` does not necessarily load the Draw submodule; import it explicitly.
from rdkit.Chem import Draw

# Grid image of the molecules plotted above, in the same order.
svgs = Draw.MolsToGridImage([Chem.MolFromSmiles(smile) for smile in smiles], useSVG=True)
# Depending on the RDKit version and whether its IPython renderer is active, this is either an
# IPython SVG object (SVG text in `.data`) or a plain SVG string. Accept both.
svg_text = getattr(svgs, "data", svgs)
with open("Figures/qm9_C_examples-2.svg", "w") as f:
    f.write(svg_text)