In [None]:
import os
import h5py
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rc('font', family='serif', serif='STIXGeneral', size=10)

def to_corr(x, ref):
    """Return the fraction of the reference correlation energy recovered by ``x``.

    ``ref`` must be indexable with ``ref[0]`` and ``ref[1]``; in this notebook
    these are the RHF and MRCI+Q-F12 reference energies of one geometry (or
    whole columns of them). The result is ``(ref[0] - x) / (ref[0] - ref[1])``:
    0 when ``x == ref[0]`` and 1 when ``x == ref[1]``. Only plain arithmetic is
    used, so scalars and broadcastable NumPy arrays both work.
    """
    e_hf, e_best = ref[0], ref[1]
    recovered = e_hf - x
    return recovered / (e_hf - e_best)
def to_corr_error(x, ref):
    """Rescale an energy uncertainty ``x`` into correlation-fraction units.

    Divides by the same span ``ref[0] - ref[1]`` that ``to_corr`` normalises
    by, so the result is the error bar that goes with ``to_corr`` of the
    corresponding energy. Scalars and broadcastable NumPy arrays both work.
    """
    span = ref[0] - ref[1]
    return x / span

In [None]:
# Raw publication data for the H10 chain, located under data/raw/ in the
# parent of the current working directory (assumes the notebook is run from a
# directory that sits next to data/ — TODO confirm for other layouts).
# Opened read-only: this notebook only reads attributes, and h5py's 'a' mode
# would create an empty file if the path were wrong and keep the raw data open
# for writing. Switch to 'a' only when deliberately adding results to the file.
h10_data_path = os.path.join(os.path.dirname(os.getcwd()), 'data', 'raw', 'data_pub_h10.h5')
f = h5py.File(h10_data_path, 'r')

In [None]:
# Quick look at the stored reference energies per geometry group (the plotting
# cell below uses column 0 as RHF and column 1 as MRCI+Q-F12). Read straight
# from the file instead of from `ref_energies`, which is only defined by the
# plotting cell below, so this cell also runs on a fresh kernel.
{name: f[name].attrs['ref_energy'] for name in sorted(f) if 'ref_energy' in f[name].attrs}

In [None]:
# H10 dissociation curve: total energies (bottom panel) and the fraction of the
# correlation energy still missing (top panel) for each ansatz.
distances = np.array([1.2, 1.4, 1.6, 1.8, 2.0, 2.4, 2.8, 3.2, 3.6])
systems = [f'H10_d{di}' for di in distances]
# ref_energy columns: 0 is plotted as RHF, 1 as MRCI+Q-F12 (see labels below)
ref_energies = np.array([f[system].attrs['ref_energy'] for system in systems])

# Line style per ansatz; insertion order fixes the legend order.
ansatz_styles = {
    'SD-SJ': dict(ls=':', fillstyle='none'),
    'SD-SJBF': dict(ls='-.', fillstyle='full'),
    'MD-SJBF': dict(ls='-', fillstyle='full'),
}
# Read each ansatz once; both panels reuse it. Column 0 is the energy,
# column 1 its error bar. TODO add data to data_pub_h10.h5
ansatz_energies = {
    ansatz: np.array([f[system][ansatz].attrs['energy'] for system in systems])
    for ansatz in ansatz_styles
}

SAVE_FIG = False  # set True to write the figure to disk
width = 3.63  # figure side length in inches
fig = plt.figure(figsize=(width, width))
spec = fig.add_gridspec(ncols=1, nrows=2, height_ratios=[2, 1])
ax2 = fig.add_subplot(spec[0])               # top: missing correlation energy
ax1 = fig.add_subplot(spec[1], sharex=ax2)   # bottom: total energies

# --- bottom panel: total energies ---
ax1.plot(distances, ref_energies[:, 0], ls='--', color='r', label='RHF')
ax1.plot(distances, ref_energies[:, 1], ls='--', color='k', label='MRCI+Q-F12', zorder=10)
# invisible, blank legend entry — presumably to pad the two-column legend layout
ax1.plot([1.2, 1.2], [0, 0], ls='', label=' ')
for ansatz, style in ansatz_styles.items():
    energy = ansatz_energies[ansatz]
    ax1.errorbar(distances, energy[:, 0], energy[:, 1], label=ansatz,
                 marker='o', ms=4, color='C0', **style)

ax1.set_ylim(-5.75, -4.75)
ax1.set_yticks([-5.5, -5.0])
ax1.set_yticks(np.linspace(-5.7, -4.8, 10), minor=True)
ax1.set_xticks([1.2, 1.6, 2.0, 2.4, 2.8, 3.2, 3.6])
ax1.set_ylabel('total energy [a.u.]', labelpad=5)
ax1.set_xlabel('H–H distance [a.u.]')
# legend belongs to the bottom axes but is anchored above the top panel
ax1.legend(loc='center', bbox_to_anchor=(0.5, 3.5), ncol=2)
ax1.grid(axis='y', which='both', ls=':', color='grey')

# --- top panel: fraction of the correlation energy still missing ---
# to_corr / to_corr_error broadcast, so pass the references as rows
# (row 0: RHF, row 1: MRCI+Q-F12) and convert a whole curve at once.
ref_rows = ref_energies.T
for ansatz, style in ansatz_styles.items():
    energy = ansatz_energies[ansatz]
    missing = 1 - to_corr(energy[:, 0], ref_rows)
    missing_err = to_corr_error(energy[:, 1], ref_rows)
    ax2.errorbar(distances, missing, missing_err,
                 marker='o', ms=4, color='C0', **style)

ax2.grid(axis='y', ls=':', color='grey')
ax2.set_ylabel('correlation energy', labelpad=7)
# x axis is shared with the bottom panel: hide both the tick marks and the tick
# labels here, otherwise the labels collide with the bottom panel across the
# narrow hspace gap.
ax2.tick_params(axis='x', which='both', bottom=False, labelbottom=False)
ax2.set_yscale('log')
# inverted log axis: a smaller missing fraction (better result) plots higher
ax2.set_ylim(0.2, 0.008)
ax2.set_yticks([1e-1, 1e-2])
ax2.set_yticklabels(['90%', '99%'])  # labelled as recovered fraction = 1 - missing
fig.subplots_adjust(hspace=0.05)

if SAVE_FIG:
    fig.savefig("h10-dis-curve.pdf", bbox_inches='tight')