In [None]:
import os
import pandas as pd

import h5py
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
import matplotlib.ticker

import dlqmc.mplext

In [None]:
# Gather every dataset stored in the raw reference file into a dict keyed by
# the dataset's HDF5 path split on '/'. Values are read (``obj[...]``) while
# the file is still open.
refs = {}


def _collect_dataset(name, obj):
    """Copy ``obj`` into ``refs`` when it is an HDF5 dataset (groups are skipped)."""
    if isinstance(obj, h5py.Dataset):
        refs[tuple(name.split('/'))] = obj[...]


with h5py.File('../data/raw/data_pub_diatomics_refereces.h5', 'r') as f:
    # The visitor returns None, so the traversal covers the whole file.
    f.visititems(_collect_dataset)

# Entries whose first path component starts with an upper-case letter
# (presumably the literature QMC references -- confirm against the file) are
# reshaped into one row per (system, ref) pair and exported.
qmc_columns = {key: val for key, val in refs.items() if key[0][0].isupper()}
refs_qmc = (
    pd.DataFrame(qmc_columns, index=['Li2', 'Be2', 'B2', 'C2'])
    .T
    .stack()
    .unstack(-2)
    .rename_axis(index=['ref', 'system'])
    .reorder_levels([1, 0])
    .sort_index()
)
refs_qmc.to_csv('../data/extern/diatomics-qmc.csv')

# Entries stored under 'ref_energies' become one row per system with the
# columns 'HF' and 'exact'.
exact_columns = {key[1]: val for key, val in refs.items() if key[0] == 'ref_energies'}
refs_exact = (
    pd.DataFrame(exact_columns, index=['HF', 'exact'])
    .T
    .rename_axis(index=['system'])
)
refs_exact.to_csv('../data/extern/diatomics-exact.csv')

In [None]:
# Reload the exported reference tables, indexed the way the plotting code
# looks them up: (system, ref) for the QMC table, system for the exact table.
refs_qmc = (
    pd.read_csv('../data/extern/diatomics-qmc.csv')
    .set_index(['system', 'ref'])
)
refs_exact = (
    pd.read_csv('../data/extern/diatomics-exact.csv')
    .set_index(['system'])
)

# Determinant counts, systems and reference labels used throughout the figure.
dets = [1, 3, 10, 30, 100]
systems = ['Li2', 'Be2', 'B2', 'C2']
refs = ['Filippi', 'Toulouse', 'Morales']

# Read the 'energy' attribute of every '<n>det' entry of every system, plus
# the per-system 'ref_energy' attribute, from the raw results file.
with h5py.File('../data/raw/data_pub_diatomics.h5', 'r') as h5:
    energy_rows = []
    for system in systems:
        group = h5[system]
        energy_rows.append([group[f'{d}det'].attrs['energy'] for d in dets])
    Es = np.array(energy_rows)
    ref_energies = np.array([h5[system].attrs['ref_energy'] for system in systems])

In [None]:
def to_corr(x, ref):
    """Rescale an energy relative to a pair of reference energies.

    Parameters
    ----------
    x : scalar or array-like
        Energy or energies to convert.
    ref : sequence of two numbers
        Reference pair, read by position. The notebook's exact-reference
        table stores the pair as ('HF', 'exact').

    Returns
    -------
    ``(ref[0] - x) / (ref[0] - ref[1])``, i.e. 0 when ``x == ref[0]`` and
    1 when ``x == ref[1]``.
    """
    # Coerce to a plain array so the two entries are always taken by position:
    # integer keys on a label-indexed pandas Series (such as a row of the
    # reference table) are deprecated as positional lookups.
    ref = np.asarray(ref)
    return (ref[0] - x) / (ref[0] - ref[1])


def to_corr_error(x, ref):
    """Rescale an energy uncertainty by the spread of a reference pair.

    Parameters
    ----------
    x : scalar or array-like
        Uncertainty (or uncertainties) in the same units as the references.
    ref : sequence of two numbers
        Reference pair, read by position.

    Returns
    -------
    ``x / (ref[0] - ref[1])``.
    """
    # Coerce to a plain array so the two entries are always taken by position:
    # integer keys on a label-indexed pandas Series are deprecated as
    # positional lookups.
    ref = np.asarray(ref)
    return x / (ref[0] - ref[1])


def plot_vmc_dmc(ax, dets, e_vmc, e_dmc, e_ref, label, color, marker):
    """Draw one reference's VMC and DMC results on ``ax``.

    Both energies are converted with ``to_corr`` against ``e_ref``. The VMC
    value is drawn with an unfilled ``marker`` and carries the legend entry
    ``Ref. [<label>]``; the DMC value is drawn with an 'x'; a dotted line in
    the same ``color`` joins the two at ``dets``.
    """
    corr_vmc = to_corr(e_vmc, e_ref)
    corr_dmc = to_corr(e_dmc, e_ref)

    # Settings common to both marker-only plots.
    points = dict(ls=' ', fillstyle='none', color=color)

    ax.plot(dets, corr_vmc, label=f'Ref. [{label}]', marker=marker, ms=5, **points)
    ax.plot(dets, corr_dmc, marker='x', **points)
    # Connector between the VMC and DMC points.
    ax.plot([dets, dets], [corr_vmc, corr_dmc], color=color, ls=':')


# 2x2 grid of panels, one per system, sharing both axes.
fig, axes = plt.subplots(
    2, 2,
    sharex=True, sharey=True,
    figsize=(3.5, 2.6),
    gridspec_kw=dict(hspace=0.08, wspace=0.06),
)
for i, (ax, s, ek) in enumerate(zip(axes.flat, systems, Es)):
    # Column 0 of ``ek`` is plotted as the value and column 1 as its error bar
    # (presumably energy and statistical error -- confirm against the data).
    corr = to_corr(ek[:, 0], ref_energies[i])
    corr_err = to_corr_error(ek[:, 1], ref_energies[i])
    ax.errorbar(
        dets,
        corr,
        corr_err,
        ms=5,
        marker='o',
        ls='',
        color='C0',
        linewidth=2,
        label='PauliNet',
    )
    # Dotted grey guide line through the same points, drawn underneath.
    ax.plot(
        dets,
        corr,
        ls=':',
        color='grey',
        linewidth=2,
        zorder=0,
    )
    # The conversion helpers index the reference pair by position, so pass the
    # ('HF', 'exact') row as a plain array rather than a label-indexed Series
    # (integer keys on such a Series are deprecated as positional lookups).
    e_ref = refs_exact.loc[s].to_numpy()
    for ref, color, marker in zip(refs, 'gry', 'o^^'):
        # axis=0 makes the (system, ref) tuple a row key of the MultiIndex.
        ref_j = refs_qmc.loc(axis=0)[s, ref]
        plot_vmc_dmc(
            ax,
            ref_j['ndet'],
            ref_j['e_vmc'],
            ref_j['e_dmc'],
            e_ref,
            ref,
            color,
            marker,
        )
    ax.set_xscale('log')
    ax.set_xlim(0.5, 5_000)
    ax.set_xticks([1, 10, 100, 1000])
    # 'corr_energy' is not a built-in matplotlib scale; presumably registered
    # by the dlqmc.mplext import -- confirm.
    ax.set_yscale('corr_energy')
    ax.set_ylim(0.7, 0.9992)
    ax.grid(axis='y', which='major', ls='dotted')
    ax.annotate(fr'$\mathrm{{{s.replace("2", "_2")}}}$', (0.05, 0.8), xycoords='axes fraction')
    if i == 0:
        # Single legend, anchored above the top row from the first panel.
        ax.legend(loc='lower center', bbox_to_anchor=(1, 1), ncol=2)
# Shared axis labels placed in figure coordinates, just outside the panels.
fig.text(
    0.5, -0.02,
    'number of determinants/CSFs',
    ha='center', va='center',
)
fig.text(
    -0.02, 0.5,
    'correlation energy',
    ha='center',
    va='center',
    rotation='vertical',
);