In [None]:
from itertools import product

import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.patches import Rectangle
from uncertainties import ufloat
from uncertainties import unumpy as unp

import dlqmc.mplext
from deepqmc.ewm import EWMAverage
from deepqmc.wf.paulinet import DistanceBasis

In [None]:
# needs to be in a separate cell, see https://github.com/ipython/ipython/issues/11098
# Publication styling shared by every figure below: figure dpi 150 and 9 pt
# serif text set in STIX (also used for mathtext and axes titles).
mpl.rcParams['figure.dpi'] = 150
mpl.rc('font', family='serif', serif='STIXGeneral', size=9)
mpl.rc('mathtext', fontset='stix')
mpl.rc('axes', titlesize=9)
# Colours of the default property cycle, in order (COLORS[0] == 'C0').
COLORS = plt.rcParams['axes.prop_cycle'].by_key()['color']

In [None]:
def savefig(fig, name, ext='pdf', **kwargs):
    """Write *fig* to ``../pub/figs/<name>.<ext>`` in publication style.

    The file gets a transparent background and is cropped tightly with zero
    padding. Extra keyword arguments are forwarded to ``fig.savefig``. Passing
    one of the fixed options again raises ``TypeError`` (duplicate keyword).
    The path is relative to the notebook's working directory.
    """
    target = f'../pub/figs/{name}.{ext}'
    fig.savefig(target, transparent=True, bbox_inches='tight', pad_inches=0, **kwargs)


def to_corr(x, ref):
    """Map energies *x* onto a fractional correlation-energy scale.

    ``ref`` holds two reference energies taken by position. An energy equal to
    ``ref[0]`` maps to 0 and one equal to ``ref[1]`` maps to 1, i.e. the
    result is ``(ref[0] - x) / (ref[0] - ref[1])``. Presumably ``ref[0]`` is
    the mean-field and ``ref[1]`` the exact energy (TODO confirm against the
    data files).

    ``ref`` may be a tuple, a NumPy array or a pandas row. It is converted to
    an array first so that rows labelled by column names are indexed by
    position. The positional fallback of ``Series[int]`` is deprecated in
    pandas and removed in pandas 3.0. ``x`` may be a scalar, array or Series.
    """
    ref = np.asarray(ref)
    return (ref[0] - x) / (ref[0] - ref[1])


def to_corr_error(x, ref):
    """Scale an energy uncertainty *x* onto the same scale as ``to_corr``.

    Only the width of the reference interval matters here:
    ``x / (ref[0] - ref[1])``. ``ref`` is indexed by position after converting
    it to an array, so pandas rows labelled by column names work without the
    deprecated ``Series[int]`` positional fallback.
    """
    ref = np.asarray(ref)
    return x / (ref[0] - ref[1])

## Table plot

In [None]:
# Literature VMC results indexed by (reference, system), and per-system
# reference energies used to normalise onto the correlation-energy scale.
refs_qmc = pd.read_csv('../data/extern/small-systems-vmc.csv').set_index(
    ['reference', 'system']
)
refs_exact = pd.read_csv('../data/extern/small-systems-exact.csv').set_index('system')
systems = ['H2', 'LiH', 'Li2', 'Be', 'B', 'C']
# data_paulinet[system, ansatz] = (100 * (1 - to_corr(E)), 100 * to_corr_error(dE)),
# i.e. percentage of correlation energy not recovered and its error, with
# ansatzes ordered SD-SJ, SD-SJBF, MD-SJ, MD-SJBF -> shape (6, 4, 2).
# attrs['energy'] holds (energy, error). The file is only read here, so it is
# opened read-only; mode 'a' would open it read/write and create it if missing.
with h5py.File('../data/raw/data_pub_small_systems.h5', 'r') as f:
    data_paulinet = np.array(
        [
            [
                (
                    100
                    - 100
                    * to_corr(
                        f[system][ansatz].attrs['energy'][0], refs_exact.loc[system]
                    ),
                    100
                    * to_corr_error(
                        f[system][ansatz].attrs['energy'][1], refs_exact.loc[system]
                    ),
                )
                for ansatz in ['SD-SJ', 'SD-SJBF', 'MD-SJ', 'MD-SJBF']
            ]
            for system in systems
        ]
    )

In [None]:
def get_ref(references, ref, system):
    """Collect one literature reference's results for *system*, one row per ansatz.

    ``references`` is indexed by (reference, system), like ``refs_qmc``.
    Returns an array of shape (4, 3) for the ansatzes SD-SJ, SD-SJBF, MD-SJ and
    MD-SJBF (in that order). Each row is::

        (100 * (1 - to_corr(energy)),  # % of correlation energy not recovered
         100 * to_corr_error(error),   # its error in %
         number of CSFs)

    normalised with the ``refs_exact`` row of *system*. Single-determinant
    ansatzes with a finite energy count as one CSF. All others read the count
    from the ``"<ansatz> nCSF"`` column. A row is all-NaN whenever the
    reference, the system or one of the needed columns is absent.

    Note: reads the notebook-global ``refs_exact`` at call time.
    """
    nan_row = (float('NaN'), float('NaN'), float('NaN'))
    data = []
    for ansatz in ['SD-SJ', 'SD-SJBF', 'MD-SJ', 'MD-SJBF']:
        try:
            row = references.loc[ref].loc[system]
            exact = refs_exact.loc[system]
            energy = row[f"{ansatz} energy"]
            error = row[f"{ansatz} error"]
            if ansatz in ('SD-SJ', 'SD-SJBF') and not np.isnan(energy):
                n_csf = 1
            else:
                n_csf = row[f"{ansatz} nCSF"]
            data.append(
                (
                    100 - 100 * to_corr(energy, exact),
                    100 * to_corr_error(error, exact),
                    n_csf,
                )
            )
        except KeyError:
            # Not tabulated for this reference/system/ansatz: leave a gap in
            # the plot. Other errors are real bugs and should surface.
            data.append(nan_row)
    return np.array(data)


# Multiplicative (x, y) nudges for the numeric reference labels in the
# small-systems figure. Index order is [system, reference, ansatz, xy]:
#   system:    0 H2, 1 LiH, 2 Li2, 3 Be, 4 B, 5 C
#   reference: 0 Brown, 1 Casalengo, 2 Morales, 3 Rios, 4 Seth, 5 Toulouse
#   ansatz:    0 SD-SJ, 1 SD-SJBF, 2 MD-SJ, 3 MD-SJBF
# Entries not listed keep the factor (1, 1). The figure loops over
# systems[:-1], so the C entries (system 5) are currently unused.
_label_nudges = {
    (2, 5, 0): (1, 0.85),
    (2, 2, 2): (0.25, 1),
    (3, 0, 3): (1, 1.2),
    (4, 2, 2): (0.25, 1),
    (4, 5, 2): (1, 1.2),
    (5, 5, 0): (1.1, 1),
    (5, 3, 0): (1, 1.25),
    (5, 0, 0): (1.1, 1.2),
    (5, 5, 2): (1.3, 1.0),
    (5, 3, 1): (1.0, 0.9),
    (5, 0, 1): (1.0, 0.9),
    (5, 2, 2): (0.25, 1),
}
label_reposition = np.ones([6, 6, 4, 2])
for _where, _factors in _label_nudges.items():
    label_reposition[_where] = _factors
# Legend text for the literature markers, keyed by (l, k) with
# l: 0 single-determinant ('o'), 1 multi-determinant ('^');
# k: 0 open marker, 1 filled marker.
labels = {
    (0, 1): 'other works',
    (1, 1): None,
    (0, 0): 'w/o backflow',
    (1, 0): 'uses CSFs',
}
# Small-systems comparison: one panel per system in systems[:-1] (C is not
# shown). x is the number of determinants/CSFs, y the percentage of
# correlation energy not recovered, on an inverted log axis.
fig, axs = plt.subplots(
    1, 5,
    sharex=True, sharey=True,
    figsize=(6, 4.5),
    gridspec_kw=dict(hspace=0.08, wspace=0.09),
)
for s, (system, axi) in enumerate(zip(systems[:-1], axs)):
    axi.set_title(fr'$\mathrm{{{system.replace("2", "_2")}}}$')
    ds = data_paulinet[s]
    # PauliNet points. k=0 takes rows 0, 2 (SD-SJ, MD-SJ: open markers);
    # k=1 takes rows 1, 3 (SD-SJBF, MD-SJBF: filled, labelled 'PauliNet').
    # They are placed at x=1 (single-det) and x=6 (multi-det). Values are
    # clipped at the 0.05 axis limit so that off-scale points stay visible.
    for k in range(2):
        axi.errorbar(
            [1, 1, 6, 6][k::2],
            ds[k::2, 0].clip(0.05),
            ds[k::2, 1],
            label=(None, 'PauliNet')[k],
            ls='',
            marker='o',
            fillstyle=['none', 'full'][k],
            ms=7,
            color=f"C0",
            clip_on=False,
        )
    # Mark clipped points with an arrow pointing beyond the axis limit.
    for i, p in enumerate(ds[:, 0]):
        if p < 0.05:
            axi.annotate(
                "",
                xy=([1, 6][i // 2], 0.05),
                xycoords='data',
                xytext=([1, 6][i // 2], 0.03),
                textcoords='data',
                arrowprops=dict(arrowstyle="<-"),
                annotation_clip=False,
            )
    # Literature points. Reference j is tagged with the number j + 1 next to
    # each point, shifted by the label_reposition factors.
    for j, ref in enumerate(
        ['Brown', 'Casalengo', 'Morales', 'Rios', 'Seth', 'Toulouse']
    ):
        dj = get_ref(refs_qmc, ref, system)
        # l selects SD ('o') vs MD ('^'); k selects open vs filled marker.
        # Legend entries are only added for the first reference.
        for k, l in product([1, 0], range(2)):
            axi.errorbar(
                dj[k + 2 * l, 2],
                dj[k + 2 * l, 0],
                dj[k + 2 * l, 1],
                fillstyle=['none', 'full'][k],
                ls='',
                marker=['o', '^'][l],
                c=f"C1",
                ms=6,
                label=labels[l, k] if j == 0 else None
            )
        for i, (y, _, x) in enumerate(dj):
            axi.annotate(j + 1, (x * 1.5, y * 1.05) * label_reposition[s, j, i])
    axi.set_yscale('log')
    axi.set_xscale('log')
    # Missing-correlation percentages shown as recovered percentages.
    axi.set_yticks([0.1, 1, 10])
    axi.set_yticklabels(['99.9%', '99%', '90%'])
    axi.set_xticks([1, 10, 100])
    axi.set_xticklabels([1, 10, 100])
    axi.xaxis.set_minor_locator(mpl.ticker.LogLocator(subs=(4, 7), numticks=8))
    # Inverted y axis: better results (smaller missing fraction) at the top.
    axi.set_ylim(60, 0.05)
    axi.set_xlim(0.6, 1000)
    axi.grid(axis='y', which='major', ls='dotted')
    if axi is axs.flat[0]:
        # single shared legend placed above the first panel
        axi.legend(
            loc='lower left',
            bbox_to_anchor=(0.5, 1.06),
            ncol=6,
            handletextpad=0.5,
            columnspacing=1,
        )
fig.text(0.5, 0.03, 'number of determinants/CSFs', ha='center')
fig.text(0.05, 0.5, 'correlation energy', va='center', rotation='vertical')
savefig(fig, 'small-systems')

## Learning curves

In [None]:
# Learning-curve data: per (system, ansatz), a running EWMAverage of the
# training energies plus the system's reference energies.
systems = ['H2', 'Be', 'B', 'LiH', 'Li2']
ansatzes = ['SD-SJ', 'SD-SJBF', 'MD-SJ', 'MD-SJBF']
# results[system, ansatz] = (E_ewm, ref_energy). E_ewm is an uncertainties
# array with one smoothed value per entry of E_mean.
results = {}
with h5py.File('../data/raw/data_pub_small_systems.h5', 'r') as f:
    for system, ansatz in product(systems, ansatzes):
        # average the stored training energies over axis 1
        # (presumably axis 0 = iteration, axis 1 = sample batch; TODO confirm)
        E_mean = f[system][ansatz]['train'][...].mean(axis=1)
        # NOTE(review): meaning of outlier_maxlen/outlier/decay_alpha is
        # defined in deepqmc.ewm, not here
        ewm = EWMAverage(outlier_maxlen=3, outlier=3, decay_alpha=10)
        E_ewm = []
        for e in E_mean:
            ewm.update(e)
            # .n / .s: nominal value and std. dev. of the current average
            # (presumably an uncertainties ufloat)
            E_ewm.append((ewm.mean.item().n, ewm.mean.item().s))
        E_ewm = unp.uarray(*zip(*E_ewm))
        results[system, ansatz] = (E_ewm, f[system].attrs['ref_energy'])

In [None]:
def plot_mean_err(ax, data, ref_enes, bs, decay, n_points=200):
    """Draw a smoothed training curve on the correlation-energy scale.

    Parameters
    ----------
    ax : matplotlib Axes to draw into.
    data : pair ``(step, E_ewm)``. ``step`` is an index array of iteration
        numbers. ``E_ewm`` is an uncertainties array of at least ``len(step)``
        smoothed energies.
    ref_enes : the two reference energies passed to ``to_corr`` /
        ``to_corr_error`` (``ref_enes[0]`` -> 0, ``ref_enes[1]`` -> 1).
    bs, decay : unused; kept so existing calls keep working.
    n_points : number of log-spaced iterations to sample (duplicates that
        arise from integer rounding are dropped).

    The mean is drawn as a line and the +-1 sigma band as grey fill. Iteration
    0 is skipped so that the curve suits a log-scaled x axis.
    """
    step, E_ewm = data
    # Log-spaced sample of iterations 1 .. last, derived from the data length
    # instead of a fixed 10,000. Casting to int repeats the dense low end, so
    # keep only unique indices.
    inds = np.unique(np.geomspace(1, len(step) - 1, n_points).astype(int))
    energy = to_corr(unp.nominal_values(E_ewm), ref_enes)
    err = to_corr_error(unp.std_devs(E_ewm), ref_enes)
    ax.plot(step[inds], energy[inds])
    ax.fill_between(
        step[inds], (energy + err)[inds], (energy - err)[inds], color='grey', alpha=0.5
    )


# Grid of learning curves: rows are systems, columns are ansatzes.
fig, axes = plt.subplots(
    len(systems),
    len(ansatzes),
    figsize=(3.6, 4.3),
    gridspec_kw=dict(hspace=0.08, wspace=0.09),
    sharex=True,
    sharey=True,
)
for (i, system), (j, ansatz) in product(enumerate(systems), enumerate(ansatzes)):
    ax = axes[i, j]
    # Steps are fixed to 0..9999, so each run is assumed to have at least
    # 10,000 smoothed values. 10_000 and 3 fill plot_mean_err's (bs, decay)
    # arguments.
    plot_mean_err(
        ax,
        (np.arange(0, 10_000), results[system, ansatz][0]),
        results[system, ansatz][1],
        10_000,
        3,
    )
    ax.set_xscale('log')
    # 'corr_energy' is not a built-in matplotlib scale; presumably it is
    # registered by the dlqmc.mplext import in the first cell
    ax.set_yscale('corr_energy')
    ax.set_xticks([1, 10, 100, 1000])
    ax.set_xticklabels([1, None, None, 1000])
    ax.xaxis.set_minor_locator(mpl.ticker.LogLocator(subs=(4, 7), numticks=8))
    ax.set_xticklabels([], minor=True)
    ax.set_xlim(1, 1e4)
    ax.set_ylim(-0.3, 0.999)
    ax.grid(axis='y', which='major', ls='dotted')
    # ansatz names across the top row, system names down the first column
    if i == 0:
        ax.set_title(ansatz)
    if j == 0:
        ax.set_ylabel(fr'$\mathrm{{{system.replace("2", "_2")}}}$', labelpad=16)
fig.text(0.5, 0.06, 'iterations', ha='center')
fig.text(-0.015, 0.5, 'correlation energy', rotation='vertical', va='center')
savefig(fig, 'learning-curves')

## H10

In [None]:
# H10 chain at several H-H spacings (a.u.), one HDF5 group per geometry.
distances = np.array([1.2, 1.4, 1.6, 1.8, 2.0, 2.4, 2.8, 3.2, 3.6])
systems = [f'H10_d{di}' for di in distances]
ansatzes = ['SD-SJ', 'SD-SJBF', 'MD-SJBF']
data_corr = {}
with h5py.File('../data/raw/data_pub_h10.h5', 'r') as f:
    # one row of reference energies per geometry
    ref_energies = np.array([f[system].attrs['ref_energy'] for system in systems])
    # Transposed so that to_corr sees ref[0] / ref[1] as per-geometry columns.
    ref_cols = ref_energies.T
    for ansatz in ansatzes:
        # rows: geometries; column 0 energy, column 1 its error
        data = np.array([f[system][ansatz].attrs['energy'] for system in systems])
        # data_corr[ansatz][:, 0] = correlation fraction, [:, 1] = its error
        data_corr[ansatz] = np.stack(
            [to_corr(data[:, 0], ref_cols), to_corr_error(data[:, 1], ref_cols)],
            axis=1,
        )

In [None]:
# H10 dissociation curve: correlation fraction vs. H-H distance per ansatz.
fig, ax = plt.subplots(figsize=(3.63, 2.7))
for i, ansatz in enumerate(ansatzes):
    # dotted / dash-dotted / solid per ansatz; only the first (SD-SJ) is
    # drawn with open markers
    ax.errorbar(
        distances,
        data_corr[ansatz][:, 0],
        data_corr[ansatz][:, 1],
        label=ansatz,
        ls=[':', '-.', '-'][i],
        fillstyle=['none', 'full', 'full'][i],
        marker='o',
        ms=4,  # numeric marker size, as in the other figures
        color='C0',
    )

ax.legend(loc='lower center', bbox_to_anchor=(0.43, 1), ncol=3)
ax.set_ylabel('correlation energy', labelpad=7)
ax.set_xlabel('H–H distance [a.u.]')
# custom scale, presumably registered by the dlqmc.mplext import
ax.set_yscale('corr_energy')
ax.set_xticks(distances[::2])
ax.grid(axis='y', which='major', ls='dotted')
savefig(fig, 'h10-dis-curve')

## Distance basis

In [None]:
# Distance features of DistanceBasis(32, envelope='nocusp') on r in [0, 12] a0.
x = torch.linspace(0, 12, 300)
fig, ax = plt.subplots(figsize=(2, 2))
ax.plot(x.numpy(), DistanceBasis(32, envelope='nocusp')(x).numpy())
ax.set(xlabel=r'$r/a_0$', ylabel=r'$\mathbf{e}(r)$', yticks=[0, 0.4])
savefig(fig, 'dist-features')

## Diatomics

In [None]:
# Diatomics: literature VMC/DMC values indexed by (system, ref), and per-system
# reference energies. Note: this replaces the small-systems refs_qmc/refs_exact.
refs_qmc = pd.read_csv('../data/extern/diatomics-qmc.csv').set_index(['system', 'ref'])
refs_exact = pd.read_csv('../data/extern/diatomics-exact.csv').set_index(['system'])

dets = [1, 3, 10, 30, 100]
systems = ['Li2', 'Be2', 'B2', 'C2']
refs = ['Filippi', 'Toulouse', 'Morales']
with h5py.File('../data/raw/data_pub_diatomics.h5', 'r') as f:
    # Es[system, det-count] = stored PauliNet energy attribute (energy, error)
    Es = np.array(
        [[f[mol][f'{d}det'].attrs['energy'] for d in dets] for mol in systems]
    )
    ref_energies = np.array([f[mol].attrs['ref_energy'] for mol in systems])

In [None]:
def plot_vmc_dmc(ax, dets, e_vmc, e_dmc, e_ref, label, color, marker):
    """Draw one literature reference's VMC and DMC results on *ax*.

    Both energy sets are mapped with ``to_corr(..., e_ref)``. VMC values get
    open markers of the given shape and the legend entry ``ref. [<label>]``.
    DMC values get unlabelled 'x' markers. Each VMC/DMC pair at the same
    determinant count is joined by a dotted segment in *color*.
    """
    corr_vmc = to_corr(e_vmc, e_ref)
    corr_dmc = to_corr(e_dmc, e_ref)
    vmc_style = dict(
        label=f'ref. [{label}]', ls=' ', fillstyle='none', marker=marker, color=color, ms=5
    )
    ax.plot(dets, corr_vmc, **vmc_style)
    ax.plot(dets, corr_dmc, ls=' ', marker='x', fillstyle='none', color=color)
    # 2-row inputs: matplotlib draws one segment per column (VMC -> DMC)
    ax.plot([dets, dets], [corr_vmc, corr_dmc], color=color, ls=':')


# 2 x 2 panels, one per diatomic: PauliNet vs. literature VMC/DMC as a
# function of the number of determinants/CSFs.
fig, axes = plt.subplots(
    2,
    2,
    sharex=True,
    sharey=True,
    figsize=(3.5, 2.6),
    gridspec_kw=dict(hspace=0.08, wspace=0.06),
)
for i, ax, s, ek in zip(range(4), axes.flat, systems, Es):
    # PauliNet: ek[:, 0] energies and ek[:, 1] errors, normalised with the
    # HDF5 ref_energy of this system.
    # NOTE(review): literature points below use refs_exact from the CSV
    # instead; confirm both references agree.
    ax.errorbar(
        dets,
        to_corr(ek[:, 0], ref_energies[i]),
        to_corr_error(ek[:, 1], ref_energies[i]),
        ms=5,
        marker='o',
        ls='',
        color='C0',
        linewidth=2,
        label='PauliNet',
    )
    # grey dotted guide line behind the PauliNet points
    ax.plot(
        dets,
        to_corr(ek[:, 0], ref_energies[i]),
        ls=':',
        color='grey',
        linewidth=2,
        zorder=0,
    )
    # Literature: Filippi (green circles, 'FU'), Toulouse (red triangles,
    # 'TU'), Morales (yellow triangles, 'Mo')
    for ref, color, marker in zip(refs, 'gry', 'o^^'):
        # presumably one row per determinant count (TODO confirm)
        ref_j = refs_qmc.loc(0)[systems[i], ref]
        plot_vmc_dmc(
            ax,
            ref_j['ndet'],
            ref_j['e_vmc'],
            ref_j['e_dmc'],
            refs_exact.loc()[s],
            ['FU', 'TU', 'Mo'][refs.index(ref)],
            color,
            marker,
        )
    ax.set_xscale('log')
    ax.set_xlim(0.5, 5_000)
    ax.set_xticks([1, 10, 100, 1000])
    ax.set_xticklabels([1, 10, 100, 1000])
    ax.xaxis.set_minor_locator(mpl.ticker.LogLocator(subs=(4, 7), numticks=8))
    ax.set_xticklabels([], minor=True)
    # custom scale, presumably registered by the dlqmc.mplext import
    ax.set_yscale('corr_energy')
    ax.set_ylim(0.7, 0.9992)
    ax.grid(axis='y', which='major', ls='dotted')
    # system name in the upper-left corner of each panel
    ax.annotate(
        fr'$\mathrm{{{s.replace("2", "_2")}}}$', (0.05, 0.8), xycoords='axes fraction'
    )
    if ax is axes.flat[0]:
        # single shared legend above the first panel
        ax.legend(
            loc='lower center',
            bbox_to_anchor=(0.8, 1),
            ncol=4,
            handletextpad=0.5,
            columnspacing=1,
        )
fig.text(
    0.5, -0.02, 'number of determinants/CSFs', ha='center', va='center',
)
fig.text(
    -0.02, 0.5, 'correlation energy', ha='center', va='center', rotation='vertical',
)
savefig(fig, 'diatomics')

## Determinants

- NN-QMC: 1s to 10s
- standard MD-QMC: 100s to 100,000s
- NN-CI: 100,000s
- FCI-QMC: 1,000,000s to 100,000,000s
- FCI: 100,000,000s to 1,000,000,000s

In [None]:
def sstep(x):
    """Smooth step from 0 to 1.

    Returns 0 for ``x <= 0`` and 1 for ``x >= 1``. In between it returns
    ``-20 x^7 + 70 x^6 - 84 x^5 + 35 x^4``, which equals 0 at x=0 and 1 at
    x=1. Accepts scalars or array-likes and returns a float array of the same
    shape.
    """
    x = np.asarray(x, dtype=float)
    return np.piecewise(
        x,
        # np.piecewise lets later conditions overwrite earlier ones, so the
        # polynomial branch must exclude x <= 0 explicitly; x >= 1 falls
        # through to the trailing "otherwise" value 1.
        [x <= 0, (x > 0) & (x < 1)],
        [0, lambda t: -20 * t ** 7 + 70 * t ** 6 - 84 * t ** 5 + 35 * t ** 4, 1],
    )


def get_bar(ws, vmax=1, dens=100):
    """Return a 1-D trapezoidal profile: linear rise, plateau at *vmax*, linear fall.

    ``ws[0]``, ``ws[1]`` and ``ws[2]`` are the relative lengths of the rise,
    plateau and fall. Each segment has ``dens * ws[k]`` samples, which must be
    an integer.
    """
    rise = np.linspace(0, vmax, dens * ws[0])
    plateau = vmax * np.ones(dens * ws[1])
    fall = np.linspace(vmax, 0, dens * ws[2])
    return np.concatenate((rise, plateau, fall))


# Schematic of the determinant counts typically used by each method family.
fig, ax = plt.subplots(figsize=(2.7, 1.6))
# (label, first, last number of determinants), bottom row first
payload = [
    ('multideterminant\nQMC + NNs', 2, 50),
    ('multideterminant QMC', 100, 1e5),
    ('configuration\ninteraction + NNs', 1e5, 1e6),
    ('configuration interaction', 2e6, 2e9),
]
# one grey bar per method in row i, spanning [first, last] on the x axis
for i, (_, fro, to) in enumerate(payload):
    ax.add_patch(Rectangle((fro, i + 0.1), to - fro, 0.8, color='grey'))
ax.set_xlim(1, 1e10)
ax.set_ylim(0, 4)
ax.set_xscale('log')
ax.set_xlabel('number of determinants')
# dashed divider at 1e5; the two headings sit above the axes (y=4.3 > ylim)
ax.axvline(1e5, color='black', ls='dashed')
ax.text(2.5, 4.3, '1st quantization', fontstyle='italic')
ax.text(2.0e5, 4.3, '2nd quantization', fontstyle='italic')
# method names as y tick labels at the row centres
ax.set_yticks([0.5, 1.5, 2.5, 3.5])
ax.set_yticklabels([l for l, *_ in payload], ha='right')
savefig(fig, 'ndets')

## Cyclobutadiene

In [None]:
# Cyclobutadiene training trajectories (columns used below: batch, state,
# step, energy_ewm). Note: rebinds `results`, previously the learning-curve dict.
results = pd.read_csv('../data/final/cyclobutadiene-fit.csv')

In [None]:
# Smoothed training energy vs. iteration for both cyclobutadiene states.
fig, ax = plt.subplots(figsize=(3, 1.7))
# state -> (legend label, line colour); 'ground' is shown as 'minimum'
state_style = {
    'ground': ('minimum', COLORS[0]),
    'transition': ('transition', 'lightskyblue'),
}
# only the runs with batch == 250 are plotted
for state, traj in results.loc[results['batch'] == 250].groupby('state'):
    state_label, color = state_style[state]
    ax.plot(traj['step'], traj['energy_ewm'], label=state_label, color=color)
ax.set_ylim(-154.65, None)
ax.set_xlim(10, None)
ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.5))
ax.yaxis.set_minor_locator(mpl.ticker.MultipleLocator(0.1))
ax.grid(axis='y', which='major')
ax.grid(axis='y', which='minor', ls='dotted')
ax.set_xscale('log')
ax.set_xlabel('iterations')
ax.set_ylabel('total energy [a.u.]')
ax.legend()
savefig(fig, 'cyclobutadiene-training')

In [None]:
# Sampled energies per (batch, state). Note: rebinds `results` again.
results = pd.read_csv('../data/final/cyclobutadiene-sample.csv')
# For each batch: mean energy of each state with its standard error of the
# mean, then 632 * (ground - transition). Only the 'energy' column is passed
# to apply, which avoids the pandas >= 2.2 deprecation of
# DataFrameGroupBy.apply operating on the grouping columns.
# NOTE(review): 1 hartree = 627.509 kcal/mol; confirm whether the factor 632
# is intended.
# NOTE(review): ground - transition is presumably negative (the transition
# state lies higher); flip the operands if a positive barrier is wanted.
(
    results.groupby(['batch', 'state'])['energy']
    .apply(lambda e: ufloat(e.mean(), e.std() / np.sqrt(len(e))))
    .unstack()
    .pipe(lambda x: 632 * (x['ground'] - x['transition']))
)