In [None]:
# Automatically reload edited modules before each cell is executed.
%load_ext autoreload
%autoreload 2
# Turn off the Jedi-based completer for IPython tab completion.
%config Completer.use_jedi = False

In [None]:
from datetime import datetime
from pathlib import Path
import pickle
import sys
import numpy as np
from scipy.stats import sem
from matplotlib.colors import Normalize 
from scipy.interpolate import interpn
from IPython.display import clear_output
from rdkit import Chem
from PyAstronomy.pyasl import broadGaussFast
from tqdm import tqdm

import matplotlib.pyplot as plt
from matplotlib import cm
from pymatgen.core.structure import Molecule

In [None]:
# Plotting helper (MPLAdjutant); source:
# https://gist.github.com/x94carbone/f5201b1c44963ff9453b9cc1d5f768ac
# Put ~/local on the import path so that `mpl_utils` can be imported --
# presumably the gist above was saved there as mpl_utils.py; verify locally.
sys.path.append(str(Path.home() / Path("local")))
from mpl_utils import MPLAdjutant
# `adj` is reused by the plotting cells further down.
adj = MPLAdjutant()
adj.set_defaults()

In [None]:
import matplotlib
# `text.latex.preamble` must be a single string: passing a list has been
# deprecated since Matplotlib 3.3 and is rejected by newer releases.
matplotlib.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"

In [None]:
import json

def save_json(d, path):
    """Write ``d`` to ``path`` as JSON (4-space indent, keys sorted).

    Parameters
    ----------
    d : JSON-serializable object
        The object to dump.
    path : str or path-like
        Destination file; it is opened in text write mode and overwritten.
    """
    with open(path, 'w') as handle:
        json.dump(d, handle, sort_keys=True, indent=4)

def read_json(path):
    """Parse the JSON file at ``path`` and return the decoded object.

    Parameters
    ----------
    path : str or path-like
        File to read; it is opened in text read mode.
    """
    with open(path, 'r') as handle:
        return json.load(handle)

Append the `home` path of this project.

In [None]:
# https://stackoverflow.com/questions/20105364/how-can-i-make-a-scatter-plot-colored-by-density-in-matplotlib
def density_scatter(x, y, ax, sort=True, bins=20, **kwargs):
    """
    Scatter plot colored by 2d histogram

    Parameters
    ----------
    x, y : array_like
        Point coordinates of equal length. Lists and tuples are accepted;
        they are converted to arrays because the density-ordering step
        below relies on array indexing.
    ax : object with a ``scatter`` method
        Target axes; ``ax.scatter(x, y, c=z, **kwargs)`` is called on it.
    sort : bool, optional
        When true, points are reordered by increasing density so that the
        densest points are drawn last.
    bins : optional
        Forwarded to ``np.histogram2d``.
    **kwargs
        Forwarded to ``ax.scatter``.

    Returns
    -------
    Whatever ``ax.scatter`` returns.
    """
    # `asanyarray` leaves ndarray subclasses (e.g. masked arrays) untouched
    # while turning plain sequences into arrays.
    x = np.asanyarray(x)
    y = np.asanyarray(y)

    hist, x_edges, y_edges = np.histogram2d(x, y, bins=bins, density=True)

    # Interpolate the histogram, sampled at the bin centers, at every point.
    x_centers = 0.5 * (x_edges[1:] + x_edges[:-1])
    y_centers = 0.5 * (y_edges[1:] + y_edges[:-1])
    z = interpn(
        (x_centers, y_centers),
        hist,
        np.vstack([x, y]).T,
        method="splinef2d",
        bounds_error=False,
    )

    # To be sure to plot all data: points outside the grid of bin centers
    # come back as NaN (bounds_error=False) and are given zero density.
    z[np.isnan(z)] = 0.0

    # Sort the points by density, so that the densest points are plotted last
    if sort:
        idx = z.argsort()
        x, y, z = x[idx], y[idx], z[idx]

    return ax.scatter(x, y, c=z, **kwargs)

# Load the data and trained ensembles

In [None]:
def load_trained_ensembles(ensemble_root_path="Ensembles"):
    """Collect every ``ensemble.json`` found under ``ensemble_root_path``.

    The result is a nested dict ``{atom_key: {downsample_prop: ensemble}}``:

    * ``downsample_prop`` is the name of the directory holding the file,
      converted to ``float``.
    * ``atom_key`` is the token that follows ``-ACSF-`` in the path. If the
      path contains ``TOTAL-ATOMS``, the token immediately preceding
      ``-TOTAL-ATOMS`` is appended to it as ``"{atom_key}-{n_atoms}"``.
    * ``ensemble`` is ``Ensemble.from_dict(read_json(path))``.

    ``Ensemble`` and ``read_json`` are looked up at call time, so both must
    be defined before this function is called.
    """
    collected = dict()

    for json_path in list(Path(ensemble_root_path).rglob("ensemble.json")):
        path_text = str(json_path)
        proportion = float(json_path.parent.parts[-1])

        key = path_text.split("-ACSF-")[1].split("-")[0]
        if "TOTAL-ATOMS" in path_text:
            total_atoms = path_text.split("-TOTAL-ATOMS")[0].split("-")[-1]
            key = f"{key}-{total_atoms}"

        # Create the per-atom sub-dict on first sight, then store the ensemble.
        per_atom = collected.setdefault(key, dict())
        per_atom[proportion] = Ensemble.from_dict(read_json(json_path))

    return collected

def load_data():
    """Load the pickled random-split datasets for O, N and C.

    Returns
    -------
    dict
        Maps ``"O"``, ``"N"`` and ``"C"`` (in that order) to the object
        unpickled from
        ``data/qm9/ml_ready/random_splits/XANES-220712-ACSF-<atom>-RANDOM-SPLITS.pkl``
        (path relative to the current working directory).

    Raises
    ------
    FileNotFoundError
        If one of the three pickle files is missing.
    """
    loaded = {}
    for atom in ("O", "N", "C"):
        pickle_path = f"data/qm9/ml_ready/random_splits/XANES-220712-ACSF-{atom}-RANDOM-SPLITS.pkl"
        # Context manager so every file handle is closed deterministically.
        # NOTE: unpickling executes arbitrary code -- only load trusted files.
        with open(pickle_path, "rb") as handle:
            loaded[atom] = pickle.load(handle)
    return loaded

In [None]:
# Maps the atom symbols "O", "N" and "C" to their unpickled datasets.
data = load_data()

In [None]:
from xas_nne.ml import Ensemble

In [None]:
# Nested dict: ensembles[atom_key][downsample_prop] -> Ensemble, read from
# the default "Ensembles" directory.
ensembles = load_trained_ensembles()

# Evaluate the ensemble effectiveness on the randomly sampled data

## Get the results compiled

In [None]:
# Downsampling proportions to evaluate; each must be a key of
# ensembles[atom_type]. The full sweep is kept commented out below.
# downsample_values = sorted([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], reverse=True)
downsample_values = [0.1, 0.5, 0.9]

In [None]:
# Atom types evaluated below; each must be a key of both `data` and `ensembles`.
ATOMS = ["C"]

Each `preds[atom_type][downsample]` has shape `(N_ensemble, N_examples, M)`, where `M` is the length of each predicted spectrum. These predictions are masked `numpy` arrays, where the masked entries correspond to outlier predictions (relative to the other estimators) or to totally unphysical ones.

In [None]:
# Output of `Ensemble.predict` on the test inputs, keyed as
# preds_no_filter[atom_type][downsample]; no outlier filtering is requested here.
preds_no_filter = {
    atom_type: {
        downsample: ensembles[atom_type][downsample].predict(data[atom_type]["test"]["x"])
        for downsample in downsample_values
    } for atom_type in ATOMS
}
# Wipe whatever the prediction calls wrote to the cell output.
clear_output()

In [None]:
# Same test inputs as above, but through `predict_filter_outliers`, keyed as
# preds[atom_type][downsample].
# NOTE(review): the meaning of each threshold below is defined by
# `Ensemble.predict_filter_outliers` in xas_nne -- consult that method
# before changing any of these values.
preds = {
    atom_type: {
        downsample: ensembles[atom_type][downsample].predict_filter_outliers(
            data[atom_type]["test"]["x"],
            sd_mult=2.0,
            threshold_sd=0.7,
            max_spectra_value=20.0,
            threshold_zero=0.5,
            min_spectra_value=0.05,
        )
        for downsample in downsample_values
    } for atom_type in ATOMS
}
# Wipe whatever the prediction calls wrote to the cell output.
clear_output()

In [None]:
# Test-set targets per atom type.
ground_truths = {
    atom_type: data[atom_type]["test"]["y"] for atom_type in ATOMS
}
# Replace negative target values with zero.
# NOTE(review): each ground_truths[atom_type] is the very same object as
# data[atom_type]["test"]["y"], so (assuming NumPy arrays) this assignment
# also modifies the loaded `data` in place -- confirm that is intended.
for atom_type in ATOMS:
    ground_truths[atom_type][ground_truths[atom_type] < 0] = 0.0

In [None]:
# The "grid" entry stored with each atom type's training split.
grids = {atom_type: data[atom_type]["train"]["grid"] for atom_type in ATOMS}

In [None]:
# Absolute difference between the ground truth and the prediction averaged
# over axis 0 (the ensemble axis, per the markdown note on `preds`), keyed as
# errors[atom_type][downsample].
errors = {
    atom_type: {
        downsample: np.abs(ground_truths[atom_type] - preds[atom_type][downsample].mean(axis=0))
        for downsample in downsample_values
    }
    for atom_type in ATOMS
}

In [None]:
# Sanity check: shape of the carbon test targets.
ground_truths["C"].shape

In [None]:
# Sanity check: shape of the filtered carbon predictions at downsample 0.1.
preds['C'][0.1].shape

## Plot the correlation between error and std

In [None]:
def adjacent_values(vals, q1, q3):
    """Return the ``(lower, upper)`` whisker ends for a box/violin plot.

    Each whisker reaches 1.5 times the inter-quartile range beyond its
    quartile and is then clipped: the upper one to ``[q3, vals[-1]]``, the
    lower one to ``[vals[0], q1]``.

    Parameters
    ----------
    vals : sequence
        The data; only its first and last elements are read, so it has to
        be sorted in ascending order for them to be the minimum and maximum.
    q1, q3 : float
        First and third quartiles of ``vals``.
    """
    reach = (q3 - q1) * 1.5
    upper = np.clip(q3 + reach, q3, vals[-1])
    lower = np.clip(q1 - reach, vals[0], q1)
    return lower, upper


def set_axis_style(ax, labels):
    """Label the x axis of ``ax`` with one tick per entry of ``labels``.

    Ticks point outward and sit on the bottom of the axes, at positions
    ``1 .. len(labels)``; the x limits are ``0.25`` and
    ``len(labels) + 0.75``.
    """
    x_axis = ax.xaxis
    x_axis.set_tick_params(direction='out')
    x_axis.set_ticks_position('bottom')

    n_labels = len(labels)
    tick_positions = np.arange(1, n_labels + 1)
    ax.set_xticks(tick_positions, labels=labels)
    ax.set_xlim(0.25, n_labels + 0.75)
    ax.set_xlabel('Sample name')
    

In [None]:
# Bin edges -3.5, -2.5, -1.5, -0.5, 0.5; `make_violin_plot` reads this
# module-level name and passes it to np.digitize to group points by their
# log10 standard deviation.
bins = [-3.5 + ii for ii in range(5)]

In [None]:
def make_violin_plot(ax, log10_ensemble_pointwise_err, log10_ensemble_pointwise_std, downsample=0.9):
    """Draw one violin of log10 errors per occupied log10-std bin on ``ax``.

    Points are grouped by ``np.digitize(log10_ensemble_pointwise_std, bins)``,
    where ``bins`` is the module-level list of bin edges. One violin is drawn
    for every bin index that actually occurs, at x positions ``1, 2, ...`` in
    increasing bin order. Each violin is overlaid with its median (white
    dot), inter-quartile range (thick bar) and whiskers from
    ``adjacent_values`` (thin bar); a thin horizontal line in the violin's
    color marks each median across the whole axes.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    log10_ensemble_pointwise_err : array_like
        log10 of the pointwise errors, indexable by an integer index array.
    log10_ensemble_pointwise_std : array_like
        log10 of the pointwise standard deviations, aligned with the errors.
    downsample : float, optional
        Not used by this function; kept so existing calls keep working.
    """
    # `cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
    # `plt.get_cmap` takes the same (name, lut) arguments.
    cmap = plt.get_cmap("rainbow", len(bins))
    binned_by_std = np.digitize(log10_ensemble_pointwise_std, bins)

    # One ascending list of errors per occupied bin. The values are sorted
    # because `adjacent_values` reads the first and last element as min/max.
    # NOTE(review): if the inputs are masked arrays, np.array(...) drops the
    # mask, so masked entries are included here -- confirm this is intended.
    data = [
        sorted(np.array(log10_ensemble_pointwise_err[np.where(binned_by_std == bin_index)[0]]).tolist())
        for bin_index in np.unique(binned_by_std)
    ]

    parts = ax.violinplot(
        data, showmeans=False, showmedians=False, showextrema=False
    )

    # NOTE(review): colors are assigned by position among the occupied bins,
    # not by bin index, and np.digitize can yield len(bins) + 1 distinct
    # indices while the colormap only holds len(bins) colors.
    for ii, pc in enumerate(parts['bodies']):
        pc.set_facecolor(cmap(ii))
        pc.set_edgecolor('black')
        pc.set_alpha(1)

    quartile1 = []
    medians = []
    quartile3 = []
    for datum in data:
        q1, m, q3 = np.percentile(datum, [25, 50, 75])
        quartile1.append(q1)
        medians.append(m)
        quartile3.append(q3)
    whiskers = np.array([
        adjacent_values(sorted_array, q1, q3)
        for sorted_array, q1, q3 in zip(data, quartile1, quartile3)
    ])
    whiskers_min, whiskers_max = whiskers[:, 0], whiskers[:, 1]

    # violinplot places the bodies at x = 1 .. n by default.
    inds = np.arange(1, len(medians) + 1)
    ax.scatter(inds, medians, marker='o', color='white', s=5, zorder=3)
    ax.vlines(inds, quartile1, quartile3, color='k', linestyle='-', lw=5, zorder=2)
    ax.vlines(inds, whiskers_min, whiskers_max, color='k', linestyle='-', lw=1, zorder=2)

    for ii, med in enumerate(medians):
        ax.axhline(med, color=cmap(ii), linewidth=0.5, zorder=1)

In [None]:
# Slice end applied to the flattened arrays below; None keeps every point,
# a small integer gives a quick test run.
debug = None
atom_type = "C"
downsamples = [0.1, 0.5, 0.9]

# One panel per entry of `downsamples`, sharing the y axis.
fig, axs = plt.subplots(1, 3, figsize=(3, 1.5), sharex=False, sharey=True)

# NOTE(review): not referenced anywhere in this cell.
bins_map = {"C": 70, "N": 30, "O": 40}


# Wrap the list itself rather than the enumerate() generator so that tqdm
# can read its length and report real progress.
for ii, downsample in enumerate(tqdm(downsamples)):

    # Get the predictions
    pred = preds[atom_type][downsample].copy()
    gt = ground_truths[atom_type].copy()
    gt[gt < 0] = 0.0

    # Pointwise: error of the ensemble mean and spread across the ensemble
    # (axis 0), flattened over all examples and grid points.
    ensemble_pointwise_err = np.abs(gt - pred.mean(axis=0)).flatten()
    ensemble_pointwise_std = pred.std(axis=0).flatten()

    ax = axs[ii]
    make_violin_plot(ax, np.log10(ensemble_pointwise_err)[:debug], np.log10(ensemble_pointwise_std)[:debug])
    ax.set_title(f"$p={downsample:.1f}$")
    adj.set_grids(ax, grid=False)
    ax.set_xticklabels([])
    ax.tick_params(axis="x", which="both", bottom=False, top=False)
    adj.set_ylim(ax, -5, 1)
    ax.set_yticks([1, -2, -5])

axs[0].set_ylabel(r"$\log_{10} \varepsilon^{(i)}_j$")
axs[1].set_xlabel(r"$\log_{10} \hat{\sigma}^{(i)}_j$")

plt.subplots_adjust(wspace=0.1, hspace=0.1)

# Create the output directory first so savefig cannot fail on a fresh checkout.
Path("Figures").mkdir(parents=True, exist_ok=True)
plt.savefig(f"Figures/qm9_sigma_parity_with_violins_p_resolved_{atom_type}.svg", bbox_inches="tight", dpi=300)
# plt.show()