In [None]:
import os
from itertools import product

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.legend_handler import HandlerTuple
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle
from uncertainties import ufloat

import dlqmc.mplext
from deepqmc.wf.paulinet import DistanceBasis

In [None]:
# Switch between the default figure set and the NATURE_CHEM variant; the flag
# also selects the font configuration applied in the rcParams cell.
NATURE_CHEM = False
# Directory that `savefig` writes into.
FIG_ROOT = '../pub/final-natchem/figs' if NATURE_CHEM else '../pub/figs'

In [None]:
# needs to be in a separate cell, see https://github.com/ipython/ipython/issues/11098
mpl.rcParams['figure.dpi'] = 150
if not NATURE_CHEM:
    # default style: STIX serif fonts for text and math, 9 pt
    mpl.rc('font', family='serif', serif='STIXGeneral', size=9)
    mpl.rc('mathtext', fontset='stix')
    mpl.rc('axes', titlesize=9)
else:
    # NOTE(review): the family is 'sans-serif' but Helvetica is assigned to the
    # 'serif' font list, so it is probably never selected -- confirm whether
    # `**{'sans-serif': 'Helvetica'}` was intended.
    mpl.rc('font', family='sans-serif', serif='Helvetica', size=7)
    mpl.rc('axes', titlesize=7)
    # from https://stackoverflow.com/a/20709149
    # Matplotlib >= 3.3 expects 'text.latex.preamble' to be one string (list
    # support was deprecated and then removed), so the lines are joined here.
    mpl.rcParams['text.latex.preamble'] = '\n'.join(
        [
            r'\usepackage{siunitx}',   # i need upright \micro symbols, but you need...
            r'\sisetup{detect-all}',   # ...this to force siunitx to actually use your fonts
            r'\usepackage{helvet}',    # set the normal font here
            r'\usepackage{sansmath}',  # load up the sansmath so that math -> helvet
            r'\sansmath',              # <- tricky! -- gotta actually tell tex to use!
        ]
    )
# Colours of the active property cycle, reused for manually coloured artists.
COLORS = plt.rcParams['axes.prop_cycle'].by_key()['color']

In [None]:
def savefig(fig, name, ext='pdf', **kwargs):
    """Save ``fig`` to ``{FIG_ROOT}/{name}.{ext}`` with the publication defaults.

    Parameters
    ----------
    fig : object exposing ``savefig(path, **options)`` (a matplotlib figure)
    name : str
        File name without extension, relative to ``FIG_ROOT``.
    ext : str
        File extension, ``'pdf'`` by default.
    **kwargs
        Extra options forwarded to ``fig.savefig``. They take precedence over
        the defaults (``transparent=True``, ``dpi=600``, ``bbox_inches='tight'``,
        ``pad_inches=0.03``), so a caller may override any of them.
    """
    options = dict(transparent=True, dpi=600, bbox_inches='tight', pad_inches=0.03)
    # merge instead of passing both sets as keywords, which would raise
    # TypeError whenever a caller repeats one of the default keys
    options.update(kwargs)
    fig.savefig(f'{FIG_ROOT}/{name}.{ext}', **options)


def to_corr(x, ref):
    """Rescale ``x`` to the fraction of the gap between two reference values.

    Parameters
    ----------
    x : scalar or array-like
        Value(s) to rescale.
    ref : sequence, array or pandas Series
        Its first two entries, taken by position, are the reference values.

    Returns
    -------
    ``(ref[0] - x) / (ref[0] - ref[1])``: 0 where ``x`` equals the first
    reference value and 1 where it equals the second.
    """
    # A pandas Series is read through .iloc so the lookup stays positional even
    # when its index holds labels; other containers are indexed directly.
    ref = getattr(ref, 'iloc', ref)
    first, second = ref[0], ref[1]
    return (first - x) / (first - second)


def to_corr_error(x, ref):
    """Rescale an uncertainty ``x`` by the gap between two reference values.

    Parameters
    ----------
    x : scalar or array-like
        Uncertainty (or uncertainties) to rescale.
    ref : sequence, array or pandas Series
        Its first two entries, taken by position, are the reference values.

    Returns
    -------
    ``x / (ref[0] - ref[1])``.
    """
    # A pandas Series is read through .iloc so the lookup stays positional even
    # when its index holds labels; other containers are indexed directly.
    ref = getattr(ref, 'iloc', ref)
    return x / (ref[0] - ref[1])

## Table plot

In [None]:
systems = ['H2', 'LiH', 'Li2', 'Be', 'B', 'C']
# literature VMC results, keyed by (reference, system)
refs_qmc = pd.read_csv('../data/extern/small-systems-vmc.csv').set_index(
    ['reference', 'system']
)
# per-system reference values consumed by to_corr / to_corr_error
refs_exact = pd.read_csv('../data/extern/small-systems-exact.csv').set_index('system')
# own results, keyed by (system, ansatz)
data = pd.read_csv('../data/final/small-systems.csv').set_index(['system', 'ansatz'])
# Array of shape (system, ansatz, 2): for every system and ansatz the pair
# (100 - 100 * to_corr(energy), 100 * to_corr_error(err)).
data_paulinet = np.array(
    [
        [
            (
                100 - 100 * to_corr(row['energy'], exact),
                100 * to_corr_error(row['err'], exact),
            )
            for row in (
                data.loc(0)[(system, ansatz)]
                for ansatz in ['SD-SJ', 'SD-SJBF', 'MD-SJ', 'MD-SJBF']
            )
        ]
        for system, exact in ((name, refs_exact.loc(0)[name]) for name in systems)
    ]
)

In [None]:
def get_ref(references, ref, system):
    """Collect one reference's results for ``system`` as a ``(4, 3)`` array.

    Rows follow the ansatz order SD-SJ, SD-SJBF, MD-SJ, MD-SJBF; the columns
    are ``(100 - 100 * to_corr(energy), 100 * to_corr_error(error), count)``.
    ``count`` is 1 for the single-determinant (``SD``) ansatzes and the
    ``'<ansatz> nCSF'`` entry otherwise. A row whose energy is NaN, and every
    row when ``(ref, system)`` is absent from ``references.index``, is NaN.
    """
    nan = float('NaN')
    # guard clause: nothing tabulated for this reference/system pair
    if (ref, system) not in references.index:
        return np.array([[nan] * 3] * 4)
    entry = references.loc(0)[ref].loc(0)[system].loc(0)
    rows = []
    for ansatz in ['SD-SJ', 'SD-SJBF', 'MD-SJ', 'MD-SJBF']:
        if np.isnan(entry[f'{ansatz} energy']):
            rows.append((nan, nan, nan))
            continue
        energy = 100 - 100 * to_corr(
            entry[f'{ansatz} energy'], refs_exact.loc(0)[system]
        )
        error = 100 * to_corr_error(
            entry[f'{ansatz} error'], refs_exact.loc(0)[system]
        )
        # the nCSF column is only consulted for the multi-determinant ansatzes
        count = 1 if 'SD' in ansatz else entry[f'{ansatz} nCSF']
        rows.append((energy, error, count))
    return np.array(rows)


# Multiplicative nudges for the numeric reference labels; 1 leaves a label at its
# default offset. Axes: system / reference / ansatz / (x, y).
label_reposition = np.ones([6, 6, 4, 2])  # H2 LiH Li2 Be B C / ref / ansatz
# Fancy-indexed assignment: the k-th (system, ref, ansatz) triple formed from the
# three index lists receives the k-th (x, y) factor pair from the array on the right.
# Entries with system index 5 are not read by the loop below, which skips the last system.
label_reposition[
    [2, 2, 3, 4, 4, 5, 5, 5, 5, 5, 5, 5],
    [5, 2, 0, 2, 5, 5, 3, 0, 5, 3, 0, 2],
    [0, 2, 3, 2, 2, 0, 0, 0, 2, 1, 1, 2],
] = np.array(
    [
        [1, 0.85],
        [0.25, 1],
        [1, 1.2],
        [0.25, 1],
        [1, 1.2],
        [1.1, 1],
        [1, 1.25],
        [1.1, 1.2],
        [1.3, 1.0],
        [1.0, 0.9],
        [1.0, 0.9],
        [0.25, 1],
    ]
)
# One panel per system (all but the last entry of `systems`), sharing both axes.
fig, axs = plt.subplots(
    1,
    5,
    sharex=True,
    sharey=True,
    figsize=(3.7, 4.5),
    gridspec_kw=dict(hspace=0.08, wspace=0.09),
)
for s, (system, axi) in enumerate(zip(systems[:-1], axs)):
    axi.set_title(fr'$\mathrm{{{system.replace("2", "_2")}}}$')
    ds = data_paulinet[s]  # rows: the four ansatzes, columns: (value, error)
    # k=0 draws rows 0 and 2 with hollow markers, k=1 rows 1 and 3 with filled
    # markers; both pairs sit at x = 1 and x = 6. Values below 0.05 are clipped
    # onto the axis limit.
    for k in range(2):
        axi.errorbar(
            [1, 1, 6, 6][k::2],
            ds[k::2, 0].clip(0.05),
            ds[k::2, 1],
            ls='',
            marker='o',
            fillstyle=['none', 'full'][k],
            ms=7,
            color='C0',
            clip_on=False,
        )
    # Mark every clipped value with an arrow drawn just past the 0.05 axis limit.
    for i, p in enumerate(ds[:, 0]):
        if p < 0.05:
            axi.annotate(
                '',
                xy=([1, 6][i // 2], 0.05),
                xycoords='data',
                xytext=([1, 6][i // 2], 0.03),
                textcoords='data',
                arrowprops=dict(arrowstyle='<-'),
                annotation_clip=False,
            )
    # Literature results: get_ref returns one (value, error, x-position) row per ansatz.
    for j, ref in enumerate(
        ['Brown', 'Casalengo', 'Morales', 'Rios', 'Seth', 'Toulouse']
    ):
        dj = get_ref(refs_qmc, ref, system)
        # row k + 2 * l: k selects hollow/filled, l selects circle/triangle
        for k, l in product([1, 0], range(2)):
            axi.errorbar(
                dj[k + 2 * l, 2],
                dj[k + 2 * l, 0],
                dj[k + 2 * l, 1],
                fillstyle=['none', 'full'][k],
                ls='',
                marker=['o', '^'][l],
                c='C1',
                ms=6,
            )
        # Label each point with the 1-based reference number, shifted by the
        # default factors (1.9, 1.05) times the per-label correction.
        for i, (y, _, x) in enumerate(dj):
            axi.annotate(j + 1, (x * 1.9, y * 1.05) * label_reposition[s, j, i])
    axi.set_yscale('log')
    axi.set_xscale('log')
    axi.set_yticks([0.1, 1, 10])
    axi.set_yticklabels(['99.9%', '99%', '90%'])
    axi.set_xticks([1, 10, 100])
    axi.set_xticklabels([1, 10, 100])
    axi.xaxis.set_minor_locator(mpl.ticker.LogLocator(subs=(4, 7), numticks=8))
    axi.set_ylim(60, 0.05)  # limits given high-to-low, so the y axis is inverted
    axi.set_xlim(0.6, 1000)
    axi.grid(axis='y', which='major', ls='dotted')
    # The legend is built once, on the first panel, from empty proxy artists.
    if axi is axs.flat[0]:
        handles = [
            axi.errorbar([], [], ls='', marker='o', c='C0')[0],
            axi.errorbar([], [], ls='', marker='o', c='C1')[0],
            axi.errorbar([], [], fillstyle='none', ls='', marker='o', c='black')[0],
            (
                axi.errorbar([], [], ls='', marker='^', c='black')[0],
                axi.errorbar([], [], ls='', fillstyle='none', marker='^', c='black')[0],
            ),
        ]
        axi.legend(
            handles,
            ['PauliNet', 'other works', 'without backflow', 'uses CSFs'],
            numpoints=1,
            handler_map={tuple: HandlerTuple(ndivide=None)},
            loc='lower left',
            bbox_to_anchor=(0.5, 1.06),
            ncol=2,
            handletextpad=0.5,
            columnspacing=1,
        )
# Shared axis captions placed in figure coordinates.
fig.text(0.5, 0.03, 'number of determinants/CSFs', ha='center')
fig.text(-0.02, 0.5, 'correlation energy', va='center', rotation='vertical')
savefig(fig, 'small-systems')

## Learning curves

In [None]:
# Row and column order of the learning-curve grid.
systems = ['H2', 'Be', 'B', 'LiH', 'Li2']
ansatzes = ['SD-SJ', 'SD-SJBF', 'MD-SJ', 'MD-SJBF']
# Training trajectories keyed by (system, ansatz); the index columns are set
# directly while parsing.
results = pd.read_csv(
    '../data/final/learning-curves.csv', index_col=['system', 'ansatz']
)
# Per-system reference values consumed by to_corr / to_corr_error.
refs_exact = pd.read_csv('../data/extern/small-systems-exact.csv', index_col='system')

In [None]:
def plot_mean_err(ax, data, ref_enes, bs, decay):
    """Draw a rescaled energy curve with its error band on ``ax``.

    ``data`` is a ``(step, energy, err)`` triple of arrays. Energies and errors
    are rescaled with ``to_corr`` / ``to_corr_error`` against ``ref_enes`` and
    sampled at 200 geometrically spaced integer positions between 1 and 9999.
    ``bs`` and ``decay`` are accepted but not used.
    """
    steps, raw_energy, raw_err = data
    corr = to_corr(raw_energy, ref_enes)
    corr_err = to_corr_error(raw_err, ref_enes)
    # geometric spacing keeps the samples evenly spread on a log-scaled x axis
    picked = np.geomspace(1, 9_999, 200).astype(int)
    xs = steps[picked]
    upper = (corr + corr_err)[picked]
    lower = (corr - corr_err)[picked]
    ax.plot(xs, corr[picked], color='#444444')
    ax.fill_between(xs, upper, lower, color='grey', alpha=0.5)


# Grid of panels: one row per system, one column per ansatz, all axes shared.
fig, axes = plt.subplots(
    len(systems),
    len(ansatzes),
    figsize=(3.3, 4.3),
    gridspec_kw=dict(hspace=0.1, wspace=0.09),
    sharex=True,
    sharey=True,
)
for (i, system), (j, ansatz) in product(enumerate(systems), enumerate(ansatzes)):
    ax = axes[i, j]
    # The step axis is hard-coded to 0..9999. The trailing 10_000 and 3 fill
    # plot_mean_err's `bs` and `decay` parameters, which its body does not use.
    plot_mean_err(
        ax,
        (
            np.arange(0, 10_000),
            np.array(results.loc(0)[system, ansatz]['energy']),
            np.array(results.loc(0)[system, ansatz]['err']),
        ),
        refs_exact.loc(0)[system],
        10_000,
        3,
    )
    ax.set_xscale('log')
    # 'corr_energy' is not a built-in matplotlib scale; presumably it is
    # registered by the `dlqmc.mplext` import -- confirm.
    ax.set_yscale('corr_energy')
    ax.set_xticks([1, 10, 100, 1000])
    ax.set_xticklabels([1, None, None, 1000])  # label only the outer ticks
    ax.xaxis.set_minor_locator(mpl.ticker.LogLocator(subs=(4, 7), numticks=8))
    ax.set_xticklabels([], minor=True)
    ax.set_xlim(1, 1e4)
    ax.set_ylim(-0.3, 0.9993)
    ax.grid(axis='y', which='major', ls='dotted')
    # column titles on the top row, system names on the left column
    if i == 0:
        ax.set_title(ansatz)
    if j == 0:
        ax.set_ylabel(fr'$\mathrm{{{system.replace("2", "_2")}}}$', labelpad=22)
# Shared axis captions placed in figure coordinates.
fig.text(0.5, 0.03, 'iterations', ha='center')
fig.text(-0.05, 0.5, 'correlation energy', rotation='vertical', va='center')
savefig(fig, 'learning-curves')

## H10

In [None]:
def import_ref(ref):
    """Load the reference file ``ref`` for every entry of the global ``distances``.

    One file ``{root}/R_{d}/{ref}`` is read with ``np.loadtxt`` per distance
    ``d`` and the results are stacked into a single array. An
    ``AssertionError`` is raised when ``ref`` is not listed in the ``R_1.0``
    directory.
    """
    root = '../data/extern/motta-hydrogen/N_10_OBC'
    # sanity check against one known directory before touching the others
    assert ref in os.listdir(f'{root}/R_1.0')
    tables = []
    for d in distances:
        tables.append(np.loadtxt(f'{root}/R_{d}/{ref}'))
    return np.array(tables)

In [None]:
# H-H separations for which results are tabulated; also read by `import_ref`.
distances = np.array([1.2, 1.4, 1.6, 1.8, 2.0, 2.4, 2.8, 3.2, 3.6])
systems = ['H10_d{}'.format(di) for di in distances]
ansatzes = ['SD-SJ', 'SD-SJBF', 'MD-SJBF']
# Reference curves: one array per method, with one entry per distance.
refs = dict(
    (ref, import_ref(ref))
    for ref in ('RHF_CBS', 'MRCI+Q+F12_CBS', 'VMC_AGP_basis-TZ', 'VMC_LDA_basis-TZ')
)
# Own results: per ansatz an array holding the (energy, err) columns.
data = (
    pd.read_csv('../data/final/h10.csv')
    .set_index('ansatz')
    .pipe(
        lambda table: {
            name: np.array(table.loc(0)[name][['energy', 'err']]) for name in ansatzes
        }
    )
)

In [None]:
# Two stacked panels sharing the x axis: ax2 is the taller top panel, ax1 the bottom one.
fig, (ax2, ax1) = plt.subplots(
    2,
    1,
    figsize=(3.63, 3.63),
    sharex=True,
    gridspec_kw=dict(hspace=0.06, height_ratios=(2, 1)),
)
# Artists collected for the legend, in drawing order:
# RHF, MRCI+Q-F12, then one errorbar container per ansatz.
plots = []
plots.append(
    *ax1.plot(distances, refs['RHF_CBS'][:, 0], ls=':', color='r', label='RHF')
)
plots.append(
    *ax1.plot(
        distances,
        refs['MRCI+Q+F12_CBS'][:, 0],
        color='k',
        label='MRCI+Q-F12',
        zorder=10,
    )
)
# Own results: column 0 of data[ansatz] is the energy, column 1 its error.
for i, ansatz in enumerate(ansatzes):
    plots.append(
        ax1.errorbar(
            distances,
            data[ansatz][:, 0],
            data[ansatz][:, 1],
            label=ansatz,
            ls=[':', 'dashed', '-'][i],
            fillstyle=['none', 'full', 'full'][i],
            marker='o',
            ms='4',
            color='C0',
        )
    )
ax1.grid(axis='y', which='major', ls='dotted')
ax1.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.4))
ax1.yaxis.set_minor_locator(mpl.ticker.MultipleLocator(0.1))
ax1.set_ylabel(r'total energy [$E_{\mathrm{h}}$]')
ax1.set_xlabel(r'H–H distance [$r_{\mathrm{Bohr}}$]')
ax1.set_ylim(-5.75, None)
# Proxy artists in colour C1, used only as legend handles.
lines_for_legend = [
    Line2D(
        [0],
        [0],
        ls=['-', 'dotted'][i],
        fillstyle=['none', 'none'][i],
        lw=1.2,
        marker='o',
        ms='4',
        color='C1',
    )
    for i in range(2)
]
# Pick and order the legend handles with plain list indexing. The handles mix
# Line2D objects with the tuple-like containers returned by `errorbar`, which
# NumPy >= 1.24 refuses to pack into an array (inhomogeneous shape).
# Positions 0-4 are the entries of `plots`, position 5 the first proxy line.
ax1.legend(
    [[*plots, *lines_for_legend][i] for i in (0, 1, 5, 2, 3, 4)],
    ['HF', 'MRCI+Q-F12', 'VMC', 'SD-SJ', 'SD-SJBF', 'MD-SJBF'],
    loc='lower center',
    bbox_to_anchor=(0.41, 3.1),
    ncol=3,
    columnspacing=0.75,
)
# Top panel, literature VMC curves: rescaled with to_corr / to_corr_error
# against the (RHF_CBS, MRCI+Q+F12_CBS) pair at every distance.
for i, ref in enumerate(['VMC_AGP_basis-TZ', 'VMC_LDA_basis-TZ']):
    ax2.errorbar(
        distances,
        to_corr(refs[ref][:, 0], (refs['RHF_CBS'][:, 0], refs['MRCI+Q+F12_CBS'][:, 0])),
        to_corr_error(
            refs[ref][:, 1], (refs['RHF_CBS'][:, 0], refs['MRCI+Q+F12_CBS'][:, 0])
        ),
        ls=['dashed', 'dotted'][i],
        lw=1.2,
        fillstyle=['none', 'none'][i],
        marker='o',
        ms='4',
        color='C1',
    )

# Image inset placed in ax1 data coordinates: (x0, y0, width, height).
ax = ax1.inset_axes((1.3, -5.8, 1.75, 2), transform=ax1.transData)
ax.imshow(plt.imread('../assets/h10.png'))
ax.set_axis_off()

# Top panel, own results: same rescaling, styled like the bottom-panel curves.
for i, ansatz in enumerate(ansatzes):
    ax2.errorbar(
        distances,
        to_corr(
            data[ansatz][:, 0], (refs['RHF_CBS'][:, 0], refs['MRCI+Q+F12_CBS'][:, 0])
        ),
        to_corr_error(
            data[ansatz][:, 1], (refs['RHF_CBS'][:, 0], refs['MRCI+Q+F12_CBS'][:, 0])
        ),
        ls=[':', 'dashed', '-'][i],
        fillstyle=['none', 'full', 'full'][i],
        marker='o',
        ms='4',
        color='C0',
    )
ax2.grid(axis='y', which='major', ls='dotted')
# 'corr_energy' is not a built-in matplotlib scale; presumably it is registered
# by the `dlqmc.mplext` import -- confirm.
ax2.set_yscale('corr_energy')
ax2.set_ylabel('correlation energy')
savefig(fig, 'h10-dis-curve')

## Distance basis

In [None]:
# Plot the 32 distance-basis features on 300 sample points in [0, 12].
fig, ax = plt.subplots(figsize=(2, 2))
x = torch.linspace(0, 12, 300)
ax.plot(x.numpy(), DistanceBasis(32, envelope='nocusp')(x).numpy())
# axis labels and ticks applied in one call
ax.set(xlabel=r'$r/a_0$', ylabel=r'$\mathbf{e}(r)$', yticks=[0, 0.4])
savefig(fig, 'dist-features')

## Diatomics

In [None]:
# Literature QMC results keyed by (system, ref), reference values keyed by
# system, and own results keyed by (system, ndet); the index columns are set
# directly while parsing.
refs_qmc = pd.read_csv('../data/extern/diatomics-qmc.csv', index_col=['system', 'ref'])
refs_exact = pd.read_csv('../data/extern/diatomics-exact.csv', index_col='system')
data = pd.read_csv('../data/final/diatomics.csv', index_col=['system', 'ndet'])
# Determinant counts, systems and literature references shown in the figure.
dets = [1, 3, 10, 30, 100]
systems = ['Li2', 'Be2', 'B2', 'C2']
refs = ['Filippi', 'Toulouse', 'Morales']

In [None]:
def plot_vmc_dmc(ax, dets, e_vmc, e_dmc, e_ref, label, color, marker):
    """Plot paired VMC/DMC results of one literature reference on ``ax``.

    Both energy sets are rescaled with ``to_corr`` against ``e_ref`` and drawn
    over ``dets``: VMC with hollow ``marker`` symbols (labelled
    ``'ref. [<label>]'``), DMC with ``'x'`` symbols, and a dotted line joining
    each VMC point to its DMC partner.

    Returns
    -------
    tuple
        The ``(vmc, dmc)`` line artists.
    """
    corr_vmc = to_corr(e_vmc, e_ref)
    corr_dmc = to_corr(e_dmc, e_ref)
    vmc_style = dict(
        label=f'ref. [{label}]',
        ls=' ',
        fillstyle='none',
        marker=marker,
        color=color,
        ms=5,
    )
    dmc_style = dict(ls=' ', marker='x', fillstyle='none', color=color)
    # each call is expected to return exactly one line artist
    [p_vmc] = ax.plot(dets, corr_vmc, **vmc_style)
    [p_dmc] = ax.plot(dets, corr_dmc, **dmc_style)
    # vertical connectors between matching VMC and DMC points
    ax.plot([dets, dets], [corr_vmc, corr_dmc], color=color, ls=':')
    return p_vmc, p_dmc


# 2x2 grid of panels, one per system, sharing both axes.
fig, axes = plt.subplots(
    2,
    2,
    sharex=True,
    sharey=True,
    figsize=(3.5, 2.6),
    gridspec_kw=dict(hspace=0.08, wspace=0.06),
)
for ax, system in zip(axes.flat, systems):
    # One row per determinant count; column 0 is used as the energy and
    # column 1 as its error below.
    ek = np.array([data.loc(0)[(system, d)] for d in dets])
    p_pn = ax.errorbar(
        dets,
        to_corr(ek[:, 0], refs_exact.loc(0)[system]),
        to_corr_error(ek[:, 1], refs_exact.loc(0)[system]),
        ms=5,
        marker='o',
        ls='',
        color='C0',
        linewidth=2,
        label='PauliNet',
    )
    # grey dotted guide line through the same points, drawn underneath (zorder=0)
    ax.plot(
        dets,
        to_corr(ek[:, 0], refs_exact.loc(0)[system]),
        ls=':',
        color='grey',
        linewidth=2,
        zorder=0,
    )
    # Literature references: colour and marker are paired with `refs` by position.
    for ref, color, marker in zip(refs, 'gry', 'o^^'):
        ref_j = refs_qmc.loc(0)[system, ref]
        plot_vmc_dmc(
            ax,
            ref_j['ndet'],
            ref_j['e_vmc'],
            ref_j['e_dmc'],
            refs_exact.loc(0)[system],
            ['FU', 'TU', 'Mo'][refs.index(ref)],
            color,
            marker,
        )
    ax.set_xscale('log')
    ax.set_xlim(0.5, 5_000)
    ax.set_xticks([1, 10, 100, 1000])
    ax.set_xticklabels([1, 10, 100, 1000])
    ax.xaxis.set_minor_locator(mpl.ticker.LogLocator(subs=(4, 7), numticks=8))
    ax.set_xticklabels([], minor=True)
    # 'corr_energy' is not a built-in matplotlib scale; presumably it is
    # registered by the `dlqmc.mplext` import -- confirm.
    ax.set_yscale('corr_energy')
    ax.set_ylim(0.7, 0.9992)
    ax.grid(axis='y', which='major', ls='dotted')
    # system name inside the panel, positioned in axes-fraction coordinates
    ax.annotate(
        fr'$\mathrm{{{system.replace("2", "_2")}}}$',
        (0.05, 0.8),
        xycoords='axes fraction',
    )
    # The legend is built once, on the first panel, from empty proxy artists.
    if ax is axes.flat[0]:
        plots = [
            (
                ax.plot([], [], ls='', color='black', fillstyle='none', marker='o')[0],
                ax.plot([], [], ls='', color='black', fillstyle='none', marker='^')[0],
            ),
            ax.plot([], [], ls='', color='black', fillstyle='none', marker='x')[0],
        ]
        ax.legend(
            [p_pn, *plots],
            ['PauliNet', 'VMC other works', 'DMC other works'],
            numpoints=1,
            handler_map={tuple: HandlerTuple(ndivide=None)},
            loc='lower center',
            bbox_to_anchor=(0.87, 1.03),
            ncol=2,
            columnspacing=0.7,
        )
# Shared axis captions placed in figure coordinates.
fig.text(
    0.5, -0.02, 'number of determinants/CSFs', ha='center', va='center',
)
fig.text(
    -0.04, 0.5, 'correlation energy', ha='center', va='center', rotation='vertical',
)
savefig(fig, 'diatomics')

## Determinants

- NN-QMC: 1s to 10s
- standard MD-QMC: 100s to 100,000s
- NN-CI: 100,000s
- FCI-QMC: 1,000,000s to 100,000,000s
- FCI: 100,000,000s to 1,000,000,000s

In [None]:
def sstep(x):
    """Smooth step: 0 for ``x <= 0``, 1 for ``x >= 1``, a polynomial in between.

    On ``0 < x < 1`` the value is
    ``-20 x**7 + 70 x**6 - 84 x**5 + 35 x**4``, which is 0 at ``x = 0`` and
    1 at ``x = 1``.

    Parameters
    ----------
    x : array-like (or scalar)

    Returns
    -------
    numpy.ndarray shaped like ``x``.
    """
    return np.piecewise(
        x,
        # The conditions must not overlap: np.piecewise lets the last matching
        # condition win, so a bare `x < 1` would overwrite the zero branch for
        # negative inputs with the polynomial.
        [x <= 0, (x > 0) & (x < 1)],
        [0, lambda t: -20 * t ** 7 + 70 * t ** 6 - 84 * t ** 5 + 35 * t ** 4, 1],
    )


def get_bar(ws, vmax=1, dens=100):
    """Build a 1-D ramp-up / plateau / ramp-down profile.

    ``ws`` holds three segment widths; each segment contributes
    ``dens * width`` samples. The profile rises linearly from 0 to ``vmax``,
    stays at ``vmax``, then falls linearly back to 0.
    """
    rise = np.linspace(0, vmax, dens * ws[0])
    plateau = vmax * np.ones(dens * ws[1])
    fall = np.linspace(vmax, 0, dens * ws[2])
    return np.concatenate([rise, plateau, fall])


fig, ax = plt.subplots(figsize=(2.55, 1.6))
# (row label, lower bound, upper bound) of the determinant-count range drawn per row
payload = [
    ('multideterminant\nQMC + NNs', 2, 50),
    ('multideterminant QMC', 100, 1e5),
    ('configuration\ninteraction + NNs', 1e5, 1e6),
    ('configuration interaction', 2e6, 2e9),
]
# One grey bar per row, spanning [fro, to] in x and occupying 0.8 of the row height.
for i, (_, fro, to) in enumerate(payload):
    ax.add_patch(Rectangle((fro, i + 0.1), to - fro, 0.8, color='grey'))
ax.set_xlim(1, 1e10)
ax.set_ylim(0, 4)
ax.set_xscale('log')
ax.set_xlabel('number of determinants')
# dashed divider at 1e5, with a caption above the axes on either side of it
ax.axvline(1e5, color='black', ls='dashed')
ax.text(2.5, 4.3, '1st quantization', fontstyle='italic')
ax.text(2.0e5, 4.3, '2nd quantization', fontstyle='italic')
# tick at the centre of each row, labelled with the row's name
ax.set_yticks([0.5, 1.5, 2.5, 3.5])
ax.set_yticklabels([l for l, *_ in payload], ha='right')
savefig(fig, 'ndets')

## Cyclobutadiene

In [None]:
results = pd.read_csv('../data/final/cyclobutadiene-sample.csv')
# Per (batch, state): mean energy with its standard error (std / sqrt(n)) as a
# ufloat; then, per batch, the scaled ground-minus-transition difference.
# NOTE(review): 632 looks like a Hartree -> kcal/mol conversion (the result is
# drawn on a kcal/mol axis further down), but the usual factor is ~627.51 -- confirm.
enes = (
    results.groupby(['batch', 'state'])
    .apply(lambda x: ufloat(x['energy'].mean(), x['energy'].std() / np.sqrt(len(x))))
    .unstack()
    .pipe(lambda x: 632 * (x['ground'] - x['transition']))
)
# bare expression: displayed as the cell output
enes

In [None]:
# Rebinds `results` (previously the sampling table) to the per-step
# trajectories plotted in the figure below.
results = pd.read_csv('../data/final/cyclobutadiene-fit.csv')

In [None]:
# Layout on a 4x7 grid: ax2 (image, panel 'a') spans the top row, ax (energy
# curves, panel 'b') and ax1 (barrier comparison, panel 'c') share the rows below.
fig = plt.figure(constrained_layout=False, figsize=(3.63, 4.5))
gs = fig.add_gridspec(nrows=4, ncols=7, wspace=2)
ax = fig.add_subplot(gs[1:, 1:4])
ax1 = fig.add_subplot(gs[1:, 5:7])
ax2 = fig.add_subplot(gs[0, :])
# Horizontal reference levels; only the first two carry legend labels.
ax.axhline(-153.71, c='red', ls='dotted', label='HF')
ax.axhline(-154.25, c='black', ls='dotted', label='CCSD(T)')
ax.axhline(-154.45, c='black', ls='dotted')
ax.axhline(-154.55, c='black', ls='dotted')
# Plot the smoothed energy column of the batch == 250 runs only, one curve per state.
for (batch, state), traj in results.groupby(['batch', 'state']):
    if batch != 250:
        continue
    color = {'ground': COLORS[0], 'transition': 'lightskyblue'}[state]
    if state == 'ground':
        state = 'minimum'  # legend wording for the ground-state curve
    ax.plot(traj['step'], traj['energy_ewm'], label=state, color=color)
ax.set_ylim(-154.65, -153.6)
ax.set_xlim(10, None)
ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.5))
ax.yaxis.set_minor_locator(mpl.ticker.MultipleLocator(0.1))
ax.grid(axis='y', which='major')
ax.grid(axis='y', which='minor', ls='dotted')
ax.set_xscale('log')
ax.set_xlabel('iterations')
ax.set_ylabel(r'total energy [$E_{\mathrm{h}}$]')
# panel letter, positioned in axes coordinates to the left of the axes
ax.text(-0.63, 0.98, 'b', transform=ax.transAxes, va='top', weight='bold')
ax.legend(
    loc='upper center',
    bbox_to_anchor=(0.3, -0.18),
    ncol=2,
    columnspacing=0.7,
)


def plot_bar(ax, y, **kwargs):
    """Draw a horizontal marker at height ``y`` across the middle 80% of ``ax``.

    The line runs from 0.1 to 0.9 in axes-fraction x coordinates with line
    width 1.5; ``kwargs`` are forwarded to ``ax.axhline``.
    """
    line_options = dict(lw=1.5, **kwargs)
    ax.axhline(y, 0.1, 0.9, **line_options)


def plot_rect(ax, fro, to, w, **kwargs):
    """Add a rectangle of width ``w`` centred on x = 0.5, spanning ``fro``..``to`` in y.

    The patch is created with ``ec=None``; ``kwargs`` are forwarded to the
    ``Rectangle`` constructor.
    """
    left_edge = 0.5 - w / 2
    height = to - fro
    patch = mpl.patches.Rectangle((left_edge, fro), w, height, ec=None, **kwargs)
    ax.add_patch(patch)


# Panel 'c': barrier values as horizontal markers. Only the first black marker
# carries the 'MR-CC' legend label; the trailing comments name the method
# behind each value.
plot_bar(ax1, 18.3, color='red', ls='dashed', label='CCSD(T)')
plot_bar(ax1, 6.8, color='black', label='MR-CC')  # , label='BW-MRCCSD(T)')
plot_bar(ax1, 8.75, color='black')  # , label='MRCISD+Q')
plot_bar(ax1, 8.95, color='black')  # , label='Mk-MRCCSD(T)')
plot_bar(ax1, 9.5, color='black')  # , label='RMRCCSD(T)')
plot_bar(ax1, 10.7, color='black')  # , label='MR-DI-EOMCCSD')
# Half-transparent band for the experimental range, drawn behind everything else.
plot_rect(
    ax1,
    1.6,
    10,
    0.75,
    color=mpl.colors.to_rgb(COLORS[1]) + (0.5,),
    zorder=-120,
    label='experiment',
)
# Opaque band spanning the sign-flipped extremes of `enes` (nominal values only, `.n`).
plot_rect(
    ax1,
    -enes.max().n,
    -enes.min().n,
    0.65,
    color=mpl.colors.to_rgb(COLORS[0]) + (1,),
    zorder=-100,
    label='PauliNet',
)
ax1.set_xlim(0, 1)
ax1.set_xticks([])
ax1.set_ylim(-1, 21)
ax1.yaxis.set_major_locator(mpl.ticker.MultipleLocator(5))
ax1.yaxis.set_minor_locator(mpl.ticker.MultipleLocator(1))
ax1.grid(axis='y', which='major', ls='dotted')
ax1.set_ylabel('transition barrier [kcal/mol]')
ax1.text(-0.7, 0.98, 'c', transform=ax1.transAxes, va='top', weight='bold')
# Create a default legend, then rebuild it with the (handle, label) pairs
# reordered so that the third entry comes first.
ax1.legend()
items = list(zip(*ax1.get_legend_handles_labels()))
items = items[2], items[0], items[1], items[3]
ax1.legend(
    *zip(*items), loc='upper center', bbox_to_anchor=(0.1, -0.05),
)

# Panel 'a': image shown without axes.
ax2.imshow(plt.imread('../assets/cclbd.png'))
ax2.set_axis_off()
ax2.text(-0.08, 0.98, 'a', transform=ax2.transAxes, va='top', weight='bold')

savefig(fig, 'cyclobutadiene-training')