In [None]:
# Automatically re-import edited modules (e.g. the project package xas_nne) before each cell runs.
%load_ext autoreload
%autoreload 2
# Disable jedi-based tab completion (presumably to work around slow/broken completion here).
%config Completer.use_jedi = False

In [None]:
from datetime import datetime
from pathlib import Path
import pickle
import sys
import numpy as np
from tqdm import tqdm
from scipy.stats import sem
from matplotlib.colors import Normalize 
from scipy.interpolate import interpn
from IPython.display import clear_output

import matplotlib.pyplot as plt
from pymatgen.core.structure import Molecule

In [None]:
# MPLAdjutant figure-styling helpers (from the gist below) are kept in ~/local,
# so that directory is put on the import path before importing them.
# https://gist.github.com/x94carbone/f5201b1c44963ff9453b9cc1d5f768ac
sys.path.append(str(Path.home().joinpath("local")))
from mpl_utils import MPLAdjutant

# Shared styling helper used by the plotting cells below.
adj = MPLAdjutant()
adj.set_defaults()

In [None]:
import json

def save_json(d, path):
    """Serialize ``d`` to ``path`` as pretty-printed JSON (4-space indent, sorted keys)."""
    dump_options = dict(indent=4, sort_keys=True)
    with open(path, "w") as handle:
        json.dump(d, handle, **dump_options)

def read_json(path):
    """Parse the JSON file at ``path`` and return the decoded Python object."""
    with open(path, "r") as handle:
        return json.load(handle)

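Helper functions: a scatter plot colored by point density, and the aggregation of atom-level spectra into molecular spectra. (The `~/local` directory holding `mpl_utils` was appended to `sys.path` above.)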

In [None]:
# https://stackoverflow.com/questions/20105364/how-can-i-make-a-scatter-plot-colored-by-density-in-matplotlib
def density_scatter(x, y, ax, sort=True, bins=20, **kwargs):
    """
    Scatter plot colored by 2d histogram density.

    Parameters
    ----------
    x, y : array-like
        Point coordinates of equal length. Lists are accepted and converted
        to numpy arrays.
    ax : matplotlib Axes
        Axes to draw on (only ``ax.scatter`` is called).
    sort : bool
        If True, draw the points in order of increasing density so that the
        densest points are plotted last (on top).
    bins : int or sequence
        Passed to ``np.histogram2d``.
    **kwargs
        Forwarded to ``ax.scatter`` (e.g. ``s``, ``alpha``, ``rasterized``).

    Returns
    -------
    ax
        The same Axes, for chaining.
    """
    # Accept lists/tuples: fancy indexing with `idx` below requires arrays.
    x = np.asarray(x)
    y = np.asarray(y)

    data, x_e, y_e = np.histogram2d(x, y, bins=bins, density=True)
    x_centers = 0.5 * (x_e[1:] + x_e[:-1])
    y_centers = 0.5 * (y_e[1:] + y_e[:-1])

    # Interpolate the histogram density back onto every point; points outside
    # the grid of bin centers come back as NaN because bounds_error=False.
    z = interpn((x_centers, y_centers), data, np.vstack([x, y]).T, method="splinef2d", bounds_error=False)

    # To be sure to plot all data: give out-of-grid points zero density.
    z[np.isnan(z)] = 0.0

    # Sort the points by density, so that the densest points are plotted last
    if sort:
        idx = z.argsort()
        x, y, z = x[idx], y[idx], z[idx]

    ax.scatter(x, y, c=z, **kwargs)

    return ax


def get_molecular_spectra(pred, gt, origin_smiles):
    """Aggregate atom-level spectra into one spectrum per molecule.

    Atoms are grouped by the SMILES string of the molecule they came from;
    molecules are returned in order of first appearance in ``origin_smiles``.

    Parameters
    ----------
    pred : array or masked array, shape [E x N x M]
        Ensemble predictions (E estimators, N atoms, M grid points). If it is a
        masked array (e.g. output of ``predict_filter_outliers``), masked
        entries are ignored by the means/stds below.
    gt : array, shape [N x M]
        Atom-level ground-truth spectra.
    origin_smiles : sequence of str, length N
        SMILES of the parent molecule of each atom.

    Returns
    -------
    molecular_smiles : list of str
    molecular_gt : ndarray [N_molecules x M]
        Mean ground-truth spectrum over the molecule's atoms.
    molecular_preds : ndarray [N_molecules x M]
        Mean over atoms of the per-atom ensemble-mean prediction.
    molecular_spreads : ndarray [N_molecules x M]
        Uncertainty of that mean, propagated from the per-atom ensemble
        standard deviations assuming independent errors.
    """
    # Group atom indices by parent molecule (dicts keep first-seen order).
    index_dict = dict()
    for ii, smile in enumerate(origin_smiles):
        index_dict.setdefault(smile, []).append(ii)

    molecular_smiles = []
    molecular_gt = []
    molecular_preds = []
    molecular_spreads = []
    for smile, list_of_idx in tqdm(index_dict.items()):
        list_of_idx = np.array(list_of_idx)
        molecular_smiles.append(smile)
        # Molecular spectrum ground truth
        molecular_gt.append(gt[list_of_idx, :].mean(axis=0))

        # This is a [E x N_atoms_in_molecule x M] array
        ensemble_predictions = pred[:, list_of_idx, :]
        mu = ensemble_predictions.mean(axis=0)  # [N_atoms_in_molecule x M]
        sd = ensemble_predictions.std(axis=0)   # [N_atoms_in_molecule x M]

        # Average is just the average (masked atoms are skipped per grid point)
        molecular_preds.append(mu.mean(axis=0))

        # Std requires propagation of error for a mean of k terms:
        #     sigma_f = sqrt(sigma_1^2 + ... + sigma_k^2) / k
        # k must be the number of atoms that actually entered the mean at each
        # grid point. With masked predictions that can be fewer than the atoms
        # in the molecule; for plain arrays np.ma.count returns the full count.
        n_contributing = np.ma.count(sd, axis=0)
        molecular_spreads.append(np.sqrt((sd**2).sum(axis=0)) / n_contributing)

    return molecular_smiles, np.array(molecular_gt), np.array(molecular_preds), np.array(molecular_spreads)

# Load the data and trained ensembles

In [None]:
def load_trained_ensembles(ensemble_root_path="Ensembles", atom_types=("C", "N", "O"), n_atoms=(5, 6, 7, 8)):
    """Load one serialized ``Ensemble`` per (atom type, training cutoff).

    Parameters
    ----------
    ensemble_root_path : str or Path
        Directory containing the ``XANES-220711-ACSF-...`` ensemble folders.
    atom_types : iterable of str
        Atom types to load.
    n_atoms : iterable of int
        Training cutoffs, i.e. the ``ATMOST-{n}-TOTAL-ATOMS`` part of the folder name.

    Returns
    -------
    dict
        ``ensembles[atom_type][n_atom]`` -> ``Ensemble`` built from ``ensemble.json``.
    """
    # Imported here so this function does not depend on a later cell having run.
    from xas_nne.ml import Ensemble

    # Defaults are tuples (not lists) to avoid shared mutable default arguments.
    ensembles = dict()
    for atom_type in atom_types:
        ensembles[atom_type] = dict()
        for n_atom in n_atoms:
            # NOTE(review): meaning of the "0.9" subdirectory is not visible here — confirm.
            ensemble_path = f"{ensemble_root_path}/XANES-220711-ACSF-{atom_type}-TRAIN-ATMOST-{n_atom}-TOTAL-ATOMS/0.9/ensemble.json"
            ensembles[atom_type][n_atom] = Ensemble.from_dict(read_json(ensemble_path))
    return ensembles

def load_data(root="data/qm9/ml_ready/by_total_atoms", atom_types=("C", "N", "O"), n_atoms=5):
    """Load the pickled ML-ready QM9 XANES datasets, keyed by atom type.

    Per the original note this loader is "just used for testing": by default it
    loads the ``ATMOST-5`` file for every atom type. Downstream cells read its
    ``"test"`` split and the ``"train"`` energy grid.

    Parameters
    ----------
    root : str or Path
        Directory holding the ``XANES-220711-ACSF-*.pkl`` files.
    atom_types : iterable of str
        Atom types to load.
    n_atoms : int
        Cutoff used in the file name (``TRAIN-ATMOST-{n_atoms}-TOTAL-ATOMS``).

    Returns
    -------
    dict
        ``{atom_type: unpickled_dataset}``.
    """
    loaded = dict()
    for atom_type in atom_types:
        path = Path(root) / f"XANES-220711-ACSF-{atom_type}-TRAIN-ATMOST-{n_atoms}-TOTAL-ATOMS.pkl"
        # `with` closes the handle (the original left three files open).
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(path, "rb") as infile:
            loaded[atom_type] = pickle.load(infile)
    return loaded

In [None]:
# Pickled datasets keyed by atom type; later cells read their "train" and "test" splits.
data = load_data()

In [None]:
from xas_nne.ml import Ensemble

In [None]:
# ensembles[atom_type][n_atoms] -> trained Ensemble, one per training cutoff.
ensembles = load_trained_ensembles()

# Evaluate the ensemble effectiveness on the generalization test data

## Compile the results

In [None]:
# Training cutoffs (the ATMOST-{n}-TOTAL-ATOMS ensembles loaded above).
n_training_atoms = [5, 6, 7, 8]

In [None]:
# Atom types analyzed throughout (one dataset and one family of ensembles each).
ATOMS = ["C", "N", "O"]

Each ensemble prediction (the output of `predict_filter_outliers`, computed for every `(atom_type, n_atoms)` pair in the results cell below) has shape `(N_ensemble, N_examples, M)`. These predictions are masked `numpy` arrays, where the masked values correspond to outlier predictions (relative to the other estimators) or to totally unphysical ones.

In [None]:
# path = Path("Ensembles/preds-gen-downsample-5.pkl")
# if path.exists():
#     print(f"loading from path {path}")
#     preds = pickle.load(open(path, "rb"))
# else:
#     preds = {
#         atom_type: {
#             n_atoms: ensembles[atom_type][n_atoms].predict_filter_outliers(
#                 data[atom_type]["test"]["x"][::5, :],
#                 sd_mult=2.0,
#                 threshold=0.7,
#                 max_spectra_value=20.0
#             )
#             for n_atoms in n_training_atoms
#         } for atom_type in ATOMS
#     }
#     clear_output()
#     pickle.dump(preds, open(path, "wb"), protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Atom-level ground-truth test spectra for every atom type.
ground_truths = dict((atom, data[atom]["test"]["y"]) for atom in ATOMS)

In [None]:
# path = Path("Ensembles/gt-gen-downsample-5.pkl")
# if path.exists():
#     print(f"loading from path {path}")
#     ground_truths = pickle.load(open(path, "rb"))
# else:
#     ground_truths = {
#         atom_type: data[atom_type]["test"]["y"][::5, :] for atom_type in ATOMS
#     }
#     pickle.dump(ground_truths, open(path, "wb"), protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Energy grid on which the spectra of each atom type are plotted (taken from the training split).
grids = {atom_type: data[atom_type]["train"]["grid"] for atom_type in ATOMS}

In [None]:
def _compute_molecular_results(atom_types, training_cutoffs):
    """Predict the test set with every ensemble and aggregate the spectra per molecule.

    Returns ``results[atom_type][n_atoms]`` -> dict with keys
    ``molecular_smiles``, ``molecular_gt``, ``molecular_preds`` and
    ``molecular_spreads`` (as returned by ``get_molecular_spectra``).
    """
    computed = dict()
    for atom_type in atom_types:
        computed[atom_type] = dict()
        for n_atoms in training_cutoffs:
            pred = ensembles[atom_type][n_atoms].predict_filter_outliers(
                data[atom_type]["test"]["x"],
                sd_mult=2.0,
                threshold_sd=0.7,
                max_spectra_value=20.0,
                threshold_zero=0.5,
                min_spectra_value=0.05,
            )
            molecular_smiles, molecular_gt, molecular_preds, molecular_spreads = get_molecular_spectra(
                pred, ground_truths[atom_type], data[atom_type]["test"]["origin_smiles"]
            )
            computed[atom_type][n_atoms] = {
                "molecular_smiles": molecular_smiles,
                "molecular_gt": molecular_gt,
                "molecular_preds": molecular_preds,
                "molecular_spreads": molecular_spreads,
            }
    return computed


# Cache of the expensive molecular results. NOTE: the cache is keyed only by the
# file name, so delete it after changing the ensembles or the filter settings.
fname = Path("molecule_results.pkl")
if fname.exists():
    print(f"{fname} exists")
    # pickle.load can execute arbitrary code; only load a cache this notebook wrote.
    with open(fname, "rb") as infile:
        results = pickle.load(infile)
else:
    results = _compute_molecular_results(ATOMS, n_training_atoms)
    with open(fname, "wb") as outfile:
        pickle.dump(results, outfile, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Absolute error |ground truth - prediction| per molecule and grid point, for
# every (atom type, training cutoff) pair.
errors = {
    atom: dict(
        (n, np.abs(results[atom][n]["molecular_gt"] - results[atom][n]["molecular_preds"]))
        for n in n_training_atoms
    )
    for atom in ATOMS
}

In [None]:
# Shape check: one row per test molecule, one column per grid point.
errors["C"][5].shape

In [None]:
# x-axis values of the scaling plot (labelled N_train there), one per cutoff in
# n_training_atoms. Hard-coded — presumably training-set sizes; TODO confirm.
C_x = [437, 2503, 15424, 102253]
N_x = [88, 521, 2938, 18039]
O_x = [113, 626, 3682, 23305]

# Test error averaged over all molecules and grid points, one value per cutoff.
C_errors, N_errors, O_errors = (
    [errors[atom][n].mean() for n in n_training_atoms] for atom in ("C", "N", "O")
)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(3, 2))

# One learning curve per atom type; errors are shown as 100 * epsilon.
for sizes, errs, fmt, symbol in (
    (C_x, C_errors, "ko-", "C"),
    (N_x, N_errors, "bo-", "N"),
    (O_x, O_errors, "ro-", "O"),
):
    ax.plot(np.log10(sizes), 100.0 * np.array(errs), fmt, label=r"$\mathrm{%s}$" % symbol)

ax.set_xlabel(r"$\log_{10} N_\mathrm{train}$")
ax.set_ylabel(r"$100\varepsilon(N_\mathrm{train})$")

ax.set_yticks([2, 8, 14])

adj.set_xlim(ax, 2, 5, threshold=0.05)
adj.set_ylim(ax, 2, 14, threshold=0.05)

adj.set_grids(ax, grid=False)

ax.legend(frameon=False)

# plt.show()
plt.savefig("Figures/qm9_scaling_M.svg", bbox_inches="tight", dpi=300)

## Plot the average errors $\varepsilon$ for each atom type

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(3, 1.5))

colors = {"C": "black", "N": "blue", "O": "red"}

# Mean test error (x100) versus the training cutoff |M|, one line per atom type.
# The comprehension keeps the cutoff loop variable out of the notebook namespace.
for atom_type in ATOMS:
    e = 100 * np.array([errors[atom_type][n_atom].mean() for n_atom in n_training_atoms])
    ax.plot(n_training_atoms, e, marker="o", color=colors[atom_type], label=atom_type)

adj.set_grids(ax, grid=False)
ax.legend(frameon=False)
ax.set_ylim(2, 16)
adj.set_ylim(ax, 2, 16)
ax.tick_params(axis="x", which="minor", top=False, bottom=False)
# Raw strings: "\m" in a normal string literal is an invalid escape sequence.
ax.set_xlabel(r"$|\mathcal{M}|$")
ax.set_ylabel(r"$100\varepsilon(|\mathcal{M}|)$")
ax.set_yticks([2, 9, 16])

# plt.show()
plt.savefig("Figures/qm9_generalized_epsilon.svg", bbox_inches="tight", dpi=300)

## Plot some examples

Sort the molecules by their mean absolute error (ascending).

In [None]:
# Selection for the example plots below: carbon, ensemble trained with the 8-atom cutoff.
atom_type = "C"
n_atoms = 8

In [None]:
# Molecule indices ordered from lowest to highest mean (over the grid) absolute error.
argsorted = errors[atom_type][n_atoms].mean(axis=-1).argsort()

Decide on an example and plot it.

In [None]:
# Molecule-level spectra for the selected (atom_type, n_atoms). `argsorted` indexes
# rows of these arrays (molecules); the atom-level `preds` dict no longer exists.
predicted_spectra = results[atom_type][n_atoms]["molecular_preds"].copy()
ground_truth_spectra = results[atom_type][n_atoms]["molecular_gt"].copy()

In [None]:
# Position in the error-sorted molecule list (ascending error), not a molecule index.
example_rank = 1002
mol_idx = argsorted[example_rank]

# `argsorted` ranks *molecules*, so read the molecule-level arrays. Indexing the
# atom-level test set (or the undefined `preds`) with it would pick unrelated rows.
mol_results = results[atom_type][n_atoms]
ground_truth_spectra = mol_results["molecular_gt"][mol_idx, :]
mu = mol_results["molecular_preds"][mol_idx, :]
sd = mol_results["molecular_spreads"][mol_idx, :]

fig, ax = plt.subplots(1, 1, figsize=(3, 2))

print(mol_results["molecular_smiles"][mol_idx])

ax.plot(grids[atom_type], ground_truth_spectra, "k-")

# Individual ensemble members are not stored in `results`, so only the ensemble
# mean and its propagated 3-sigma band are shown.
ax.plot(grids[atom_type], mu, "r-")
ax.fill_between(grids[atom_type], mu - sd * 3, mu + sd * 3, alpha=0.4, color="red", linewidth=0)

example_err = np.mean(np.abs(ground_truth_spectra - mu))
print(f"{np.log10(example_err):.02f}")

plt.show()


## Error histograms

In [None]:
from matplotlib import cm

In [None]:
# Four discrete viridis colors, one per training cutoff in the histograms below.
# matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9; pyplot.get_cmap remains.
cmap = plt.get_cmap("viridis", 4)

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(6, 6/3/1.6))

training_cutoffs = [5, 6, 7, 8]
hist_alphas = [1.0, 0.9, 0.8, 0.7]
# log10-error bin edges from -2.25 to 0.2 in steps of 0.05. NOTE: errors below
# 10^-2.25 fall outside the bins and are not counted, although the x-axis starts at -3.
hist_bins = [0.2 - kk * 0.05 for kk in range(50)][::-1]
# Reference count drawn as a gray dashed guide line in each panel.
reference_counts = {"C": 10000, "N": 5000, "O": 6000}

for ii, (ax, atom_type) in enumerate(zip(axs, ["C", "N", "O"])):

    ax.set_title(atom_type)

    # Baseline: error of always predicting the average test spectrum. molecular_gt
    # is built from the test set only, so it is the same for every training
    # cutoff; compute and draw it once per panel instead of once per cutoff.
    gt_baseline = results[atom_type][training_cutoffs[-1]]["molecular_gt"].copy()
    gt_baseline[gt_baseline < 0.0] = 0.0
    average_spectrum_in_testing_set = np.mean(gt_baseline, axis=0)
    dummy_testing_set_error = np.log10(np.mean(np.abs(average_spectrum_in_testing_set - gt_baseline)))
    ax.axvline(dummy_testing_set_error, color="blue", linestyle="--", linewidth=0.5, zorder=0)
    t = ax.text(0.9, 0.4, r"$%.02f$" % dummy_testing_set_error, color="blue", ha="right", va="center", transform=ax.transAxes, fontsize=8)
    t.set_bbox(dict(facecolor='white', alpha=1, edgecolor='white'))

    for jj, (n_atom, alpha) in enumerate(zip(training_cutoffs, hist_alphas)):
        pred = results[atom_type][n_atom]["molecular_preds"].copy()
        gt = results[atom_type][n_atom]["molecular_gt"].copy()
        gt[gt < 0.0] = 0.0  # negative ground-truth intensities are clipped to zero

        # Per-molecule mean absolute error of the ensemble prediction
        log_ensemble_err = np.log10(np.mean(np.abs(pred - gt), axis=1))
        ax.hist(log_ensemble_err, bins=hist_bins, color=cmap(jj), alpha=alpha, label=n_atom if ii == 1 else None)

    # Fine tuning
    adj.set_grids(ax, grid=False)
    ax.set_yticklabels([])
    ax.set_xticks([-3, -2, -1, 0])
    adj.set_xlim(ax, -3, 0)

    val = reference_counts[atom_type]
    ax.text(0.1, 0.8, val, ha="left", va="center", transform=ax.transAxes, color="gray")
    ax.axhline(val, color="gray", linestyle="--", linewidth=0.5, zorder=0)

axs[1].set_xlabel(r"$\log_{10} \varepsilon^{(\mathcal{M})}$")
axs[0].set_ylabel("Counts")
axs[1].legend(frameon=False, loc="center left", fontsize=10)

plt.subplots_adjust(wspace=0.1)

# plt.savefig("Figures/qm9_generalize_hists.svg", bbox_inches="tight", dpi=300)
# needs post-processing on InkScape
plt.show()

## Plot the correlation between error and std

In [None]:
fig, axs = plt.subplots(1, len(ATOMS), figsize=(6, 2), sharex=True, sharey=True)

bins_map = {"C": 70, "N": 30, "O": 50}

n_atom = 7

# Stride for subsampling points while debugging (None = use every point).
debug = None

# (offset in log10 units, linestyle, alpha) of the guide lines parallel to the fit.
guide_lines = [(0.0, "-", 0.9), (0.5, "--", 0.8), (1.0, "--", 0.7), (1.5, "--", 0.6), (2.0, "--", 0.5)]

for ii, atom_type in enumerate(ATOMS):

    # Get the predictions
    pred = results[atom_type][n_atom]["molecular_preds"].copy()
    gt = results[atom_type][n_atom]["molecular_gt"].copy()
    gt[gt < 0] = 0.0

    # Pointwise (every molecule x grid point) error and ensemble spread
    ensemble_pointwise_err = np.abs(gt - pred).flatten()
    ensemble_pointwise_std = results[atom_type][n_atom]["molecular_spreads"].copy().flatten()

    # Exact zeros (or NaNs) become non-finite in log space, which makes
    # np.histogram2d fail on its auto-detected range and corrupts np.polyfit;
    # keep only points finite on both axes.
    with np.errstate(divide="ignore", invalid="ignore"):
        y = np.log10(ensemble_pointwise_err[::debug])
        x = np.log10(ensemble_pointwise_std[::debug])
    finite = np.isfinite(x) & np.isfinite(y)
    x, y = x[finite], y[finite]

    ax = axs[ii]
    ax.set_title(atom_type)
    ax = density_scatter(x, y, ax=ax, sort=True, bins=bins_map[atom_type], s=0.4, alpha=1, rasterized=True)

    # Linear fit in log-log space; carbon is fit on every 10th point only.
    fit_stride = 10 if atom_type == "C" else 1
    p = np.polyfit(x[::fit_stride], y[::fit_stride], deg=1)
    r2 = np.corrcoef(x[::fit_stride], y[::fit_stride])[0, 1]**2
    poly = np.poly1d(p)
    for offset, linestyle, alpha in guide_lines:
        ax.axline((-3, poly(-3) + offset), (-2, poly(-2) + offset), color="black", linestyle=linestyle, linewidth=0.5, alpha=alpha)
    ax.text(0.1, 0.9, r"$r^2 = %.02f$" % r2, ha="left", va="top", transform=ax.transAxes)

ax_min = -6
ax_max = 2
for ax in axs.flatten():
    ax.set_xlim(ax_min, ax_max)
    ax.set_ylim(ax_min, ax_max)
    ax.set_xticks([-6, -2, 2])
    ax.set_yticks([-6, -2, 2])
    adj.set_grids(ax, grid=False)
    adj.set_xlim(ax, ax_min, ax_max)
    adj.set_ylim(ax, ax_min, ax_max)

axs[0].set_ylabel(r"$\log_{10} \varepsilon^{(\mathcal{M})}_j$")
axs[1].set_xlabel(r"$\log_{10} \hat{\sigma}^{(\mathcal{M})}_j$")

plt.subplots_adjust(wspace=0.1)

plt.savefig("Figures/qm9_generalized_sigma_parity.svg", bbox_inches="tight", dpi=300)
# plt.show()

## Plot some atom-level examples

In [None]:
atom_type = "C"
# ensembles[atom_type] is a dict keyed by training cutoff, not an Ensemble, so
# one ensemble must be picked explicitly (here the largest cutoff).
example_n_atoms = 8
grid = data[atom_type]["train"]["grid"]
# Atom-level predictions; presumably [E x N_test x M] (the next cell averages axis 0) — confirm.
pred = ensembles[atom_type][example_n_atoms].predict(data[atom_type]["test"]["x"])
gt = data[atom_type]["test"]["y"]

In [None]:
# Mean absolute error of the ensemble-mean prediction for every test example,
# and the example indices ordered from worst to best.
err = np.abs(gt - pred.mean(axis=0)).mean(axis=1)
sorted_idx = np.argsort(err)[::-1]

In [None]:
# Number of test examples for this atom type.
len(sorted_idx)

In [None]:
# Mean absolute errors of the ten worst-predicted examples (largest first).
err[sorted_idx[:10]]

In [None]:
# Mean 3-sigma ensemble spread of the ten worst-predicted examples. Only the
# selected examples are reduced, instead of the std over the whole test set.
(pred[:, sorted_idx[:10], :].std(axis=0) * 3).mean(axis=1)

In [None]:
# Dataset "names" of the ten worst-predicted test examples (sorted_idx is worst-first).
# NOTE(review): assumes data[atom_type]["test"] has a "names" entry — confirm.
names = [data[atom_type]["test"]["names"][ii] for ii in sorted_idx[:10]]
names

In [None]:
# SMILES of the molecules the ten worst-predicted examples belong to (overwrites `names`).
names = [data[atom_type]["test"]["origin_smiles"][ii] for ii in sorted_idx[:10]]
names

In [None]:
from rdkit import Chem

In [None]:
# `from rdkit import Chem` does not load the Chem.Draw submodule; import it explicitly.
from rdkit.Chem import Draw

# SVG grid of the three worst-predicted molecules. NOTE: this rebinds `grid`,
# which previously held the energy grid.
grid = Draw.MolsToGridImage([Chem.MolFromSmiles(smile) for smile in names[:3]], useSVG=True)

In [None]:
# Write the SVG molecule grid built in the previous cell to disk.
with open("qm9_C_fail.svg", "w") as f:
    f.write(grid.data)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(3, 5))

# Re-read the energy grid locally: the name `grid` is rebound to an RDKit SVG
# image by an earlier cell.
energy_grid = data[atom_type]["train"]["grid"]

n = 10
for ii in range(n):

    offset = ii * 6
    # Spread the examples across the error ranking (sorted_idx is worst-first);
    # the +11 shift skips the 11 worst examples. Clamp to stay in bounds.
    rank = min(int(ii / n * len(sorted_idx)) + 11, len(sorted_idx) - 1)
    idx = sorted_idx[rank]

    ground_truth_spectra = gt[idx, :]
    predicted_spectra = pred[:, idx, :]

    ensemble_mean = predicted_spectra.mean(axis=0)
    mu = ensemble_mean + offset
    sd = predicted_spectra.std(axis=0) * 3

    label = r"$\mu^{(i)}$" if ii == 0 else None
    ax.plot(energy_grid, ground_truth_spectra + offset, "k-", label=label)

    for jj, prediction in enumerate(predicted_spectra):
        label = r"$\hat{\mu}^{(i, k)}$" if jj == 0 and ii == 0 else None
        ax.plot(energy_grid, prediction + offset, 'r-', linewidth=0.5, alpha=0.5, label=label)

    label = r"$3\sigma$" if ii == 0 else None
    ax.fill_between(energy_grid, mu - sd, mu + sd, color="red", alpha=0.5, linewidth=0, label=label)

    # log10 MAE of the ensemble-mean prediction, the quantity used to rank the examples.
    example_err = np.log10(np.mean(np.abs(ground_truth_spectra - ensemble_mean))).item()
    ax.text(0.9, 0.09 + ii / 10.5, r"$%.02f$" % example_err, ha="right", va="center", transform=ax.transAxes)

ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.spines["left"].set_linewidth(0.5)
ax.spines["bottom"].set_linewidth(0.5)
adj.set_grids(ax, grid=False, top=False)
ax.set_yticklabels([])
ax.set_yticks([])
ax.set_xlabel(r"$E$~(e.V.)")
ax.set_ylabel(r"$\mu(E)$~(a.u.)")
ax.legend(frameon=False, ncol=3, loc="upper center", bbox_to_anchor=(0.5, 1.1))

plt.show()
# plt.savefig("qm9_C_random_preds.svg", bbox_inches="tight", dpi=300)
# plt.savefig("qm9_C_random_preds.svg", bbox_inches="tight", dpi=300)

## Final waterfall plots

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(3*2.5, 4), sharey=True)
smiles_list = []

n = 5

n_atom = 6

for panel_idx, atom_type in enumerate(ATOMS):

    # Local names (energy_grid, mol_*) avoid clobbering the notebook-level
    # `grid`, `pred`, `gt`, `errors` and `sorted_idx` used by other cells.
    energy_grid = data[atom_type]["train"]["grid"]

    ax = axs[panel_idx]

    mol = results[atom_type][n_atom]
    mol_pred = mol["molecular_preds"].copy()
    mol_gt = mol["molecular_gt"].copy()
    mol_spreads = mol["molecular_spreads"].copy()
    mol_gt[mol_gt < 0] = 0.0
    mol_errors = np.abs(mol_gt - mol_pred).mean(axis=1)
    mol_order = np.argsort(mol_errors)[::-1]  # worst first
    molecular_smiles = mol["molecular_smiles"].copy()
    L = len(mol_order)

    ax.set_title(atom_type)

    for jj in range(n):

        offset = jj * 6
        # Pick the molecule at the center of the jj-th of n equal slices of the ranking.
        idx = mol_order[int(jj / n * L) + L // (2*n)]

        ground_truth_spectra = mol_gt[idx, :].copy() + offset
        predicted_spectra = mol_pred[idx, :].copy() + offset
        predicted_spread = mol_spreads[idx, :].copy()
        smiles = molecular_smiles[idx]

        label = r"{\boldmath$\mu$}$^{(\mathcal{M})}$" if jj == 0 and atom_type == "N" else None
        ax.plot(energy_grid, ground_truth_spectra, "k-", label=label)

        label = r"{\boldmath$\hat{\mu}$}$^{(\mathcal{M})}$" if jj == 1 and atom_type == "N" else None
        ax.plot(energy_grid, predicted_spectra, color="red", linewidth=1, label=label, zorder=4)

        label = r"$3${\boldmath$\hat{\sigma}$}$^{(\mathcal{M})}$" if jj == 1 and atom_type == "N" else None
        ax.fill_between(energy_grid, predicted_spectra - predicted_spread * 3, predicted_spectra + predicted_spread * 3, color="red", alpha=0.5, linewidth=0, label=label, zorder=3)

        # The offsets cancel in the difference: this is the molecule's log10 MAE.
        err = np.log10(np.mean(np.abs(ground_truth_spectra - predicted_spectra))).item()
        ax.text(0.9, 0.2 + jj / 10.7 * 2, r"$%.02f$" % err, ha="right", va="center", transform=ax.transAxes)

        print(f"{err:.02f} : {smiles}")
        smiles_list.append(smiles)

    # Only the N panel has labelled artists, so only it gets a legend (calling
    # legend() on the other panels just warns about missing labels).
    if atom_type == "N":
        ax.legend(frameon=False, ncol=1, loc="upper center", bbox_to_anchor=(1.0, 1.1))

    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.spines["left"].set_linewidth(0.5)
    ax.spines["bottom"].set_linewidth(0.5)
    adj.set_grids(ax, grid=False, top=False)
    ax.set_yticklabels([])
    ax.set_yticks([])

axs[1].set_xlabel(r"$E$~(e.V.)")
axs[0].set_ylabel(r"$\mu(E)$~(a.u.)")

# plt.show()
plt.savefig("Figures/qm9_CNO_examples.svg", bbox_inches="tight", dpi=300)




In [None]:
from rdkit import Chem
# Chem.Draw is a submodule that `from rdkit import Chem` does not load; import it explicitly.
from rdkit.Chem import Draw

# SVG grid of the molecules shown in the final waterfall plots, in plotting order.
svgs = Draw.MolsToGridImage([Chem.MolFromSmiles(smile) for smile in smiles_list], useSVG=True)
with open("Figures/qm9_CNO_molecules.svg", "w") as f:
    f.write(svgs.data)