In [None]:
%load_ext autoreload
%autoreload 1
%aimport  dlqmc.sampling, dlqmc.utils, dlqmc.nn.base, dlqmc.fit
%config InlineBackend.figure_format = 'svg' 
%config InlineBackend.print_figure_kwargs = \
    {'bbox_inches': 'tight', 'dpi': 300}

In [None]:
import ipywidgets
import torch.nn as nn
import numpy as np
from scipy import special
import scipy.stats as sps
import matplotlib.pyplot as plt
import torch
#from torch.utils.data import DataLoader, RandomSampler
#from torch.distributions import Normal
from pyscf import gto, scf, dft
import pyscf
from pyscf.data.nist import BOHR
import time
from functools import partial
from tqdm.auto import tqdm, trange
from tensorboardX import SummaryWriter

from dlqmc.nn.base import * 
from dlqmc.geom import *
from dlqmc.gto import *
from dlqmc.nn import *
from dlqmc.sampling import langevin_monte_carlo, hmc ,samples_from
from dlqmc.fit import *
from dlqmc.nn.anti import *
#from dlqmc.utils import assign_where
from dlqmc.physics import (
    local_energy, grad, quantum_force,nuclear_potential,
    nuclear_energy, laplacian, electronic_potential
)
#from dlqmc.analysis import autocorr_coeff, blocking
from dlqmc.nn import ssp
from dlqmc.nn.hannet import HanNet

def normplot(x,y,norm,*args,**kwargs):
    """Plot ``y`` against ``x`` with ``plt.plot``, optionally peak-normalised.

    Parameters
    ----------
    x, y : array-like
        Abscissa and ordinate values.
    norm : bool
        If truthy, ``y`` is divided by ``max(|y|)`` so the curve peaks at +-1.
    *args, **kwargs
        Forwarded unchanged to ``plt.plot`` (line style, label, colour, ...).
    """
    if norm:
        # asarray lets plain Python sequences be normalised as well.
        y = np.asarray(y)
        peak = np.max(np.abs(y))
        # An identically-zero curve has no peak to normalise by; plot it as-is
        # instead of dividing 0/0 and drawing an all-NaN (invisible) line.
        if peak != 0:
            y = y / peak
    plt.plot(x, y, *args, **kwargs)


In [None]:
# Print CUDA allocator statistics, then ask the caching allocator to release unused blocks.
# NOTE(review): `memory_cached` / `max_memory_cached` look like the older spellings of
# `memory_reserved` / `max_memory_reserved` — confirm against the installed PyTorch version.
print(torch.cuda.memory_allocated())
print(torch.cuda.memory_cached(device=None))
print(torch.cuda.max_memory_cached(device=None))
torch.cuda.empty_cache()

In [None]:
# Internuclear distance used for H2+ (the same number is passed to pyscf with unit='bohr' below).
d_ref_h2p=1.9972 
# H2+ geometry: two unit-charge nuclei placed symmetrically about the origin on the x axis.
h2p = Geometry([[-d_ref_h2p/2, 0., 0.], [d_ref_h2p/2, 0., 0.]], [1., 1.])
#h2p = geomdb['H2+']


# Read the H2 bond length from the geometry database (x coordinate of the second nucleus) and
# rebuild the molecule centred on the origin, in the same way as H2+ above.
# NOTE(review): assumes the database entry has its first nucleus at the origin and the bond
# along x — confirm against geomdb.
h2 = geomdb['H2']
d_ref_h2 = h2.coords[1][0]
h2 = Geometry([[-d_ref_h2/2, 0., 0.], [d_ref_h2/2, 0., 0.]], [1., 1.])
print(d_ref_h2)

### Activation function

In [None]:
def soft(x):
    """Softplus activation ``log(1 + exp(x))``, evaluated without overflow.

    ``np.logaddexp(0, x)`` computes ``log(exp(0) + exp(x))``, which is the same
    quantity, but it does not form ``exp(x)`` explicitly, so large inputs
    return ``x`` instead of ``inf``. Accepts scalars or arrays (element-wise).
    """
    return np.logaddexp(0, x)

In [None]:
x=np.linspace(-5,5,100)
def relu(x):
    """Rectified linear unit ``max(x, 0)``, applied element-wise.

    The result has the same shape as ``x``. (Multiplying by the Python list
    ``[x>0]`` instead broadcast the result to shape ``(1, N)``.)
    """
    return np.maximum(x, 0)
# Compare ReLU and softplus on the grid `x`; ticks are removed, so only the shapes are shown.
plt.figure(figsize=(5,3))
# .flatten() collapses whatever shape relu returns to 1-D so it can be plotted against x.
plt.plot(x,relu(x).flatten(),color='grey',ls=':',label='relu')
#plt.plot(x,ssp(torch.from_numpy(x)).numpy()+np.log(2),color='k',label='shifted softplus')
plt.plot(x,soft(x),color='k',label='softplus')
plt.legend(loc='upper left')
plt.xticks([])
plt.yticks([])
plt.xlabel("in")
plt.ylabel("out")
# Written to the current working directory.
plt.savefig('activation.svg')
plt.show()

## $H_2^+$

In [None]:
# Hartree-Fock reference for H2+: same nuclear positions as `h2p`, net charge +1, one unpaired
# electron (spin=1), 6-31G basis with cartesian functions.
mol = gto.M(
    atom=[
        ['H', (-d_ref_h2p/2, 0, 0)],
        ['H', (d_ref_h2p/2, 0, 0)]
    ],
    unit='bohr',
    basis='6-31G',
    charge=1,
    spin=1,
    cart=True
)
# NOTE(review): the molecule is open-shell but the RHF entry point is used — presumably pyscf
# dispatches to a restricted open-shell solver here; confirm.
mf = scf.RHF(mol)
mf.kernel()
# Wrap the solved mean-field object as a torch-callable wavefunction (plotted as "HF" below).
gtowf = TorchGTOSlaterWF(mf)

In [None]:
# Spin partition for H2+: a single spin-up electron, no spin-down electrons.
n_electrons=1
n_up = 1
n_down = n_electrons-n_up
# Trainable wavefunction network on the H2+ geometry, moved to the GPU.
net = HanNet(h2p,n_up,n_down).cuda()

In [None]:
# Build one sampler for the network and one for the HF baseline.
# `if False:` hard-disables the HMC branch, so the Langevin samplers in the `else` branch are
# the ones actually constructed; flip the condition to switch to `hmc`.
# Both branches start 1000 walkers from standard-normal positions scaled by 1/3, on the GPU.
if False:
    sampler = hmc(
        net,
        torch.randn(1000, n_electrons, 3, device='cuda')/3,
        dysteps=3,
        stepsize=0.2,
        tau = 0.1,
        cutoff = 1.0
    )
    sampler_gtowf = hmc(
        gtowf,
        torch.randn(1000, n_electrons, 3, device='cuda')/3,
        dysteps=3,
        stepsize=0.2,
        tau = 0.1,
        cutoff = 1.0
    )
else:
    sampler = langevin_monte_carlo(
        net,
        torch.randn(1000, n_electrons, 3, device='cuda')/3,
        tau=0.1,
    )
    sampler_gtowf = langevin_monte_carlo(
        gtowf,
        torch.randn(1000, n_electrons, 3, device='cuda')/3,
        tau=0.1,
    )

In [None]:
# NOTE(review): `molecule` is not read by anything in this cell — looks like a leftover; confirm.
molecule = h2p
# Stage 1: a single sampling round with the loss pinned to a fixed reference (E_ref=-1.1, p=2).
# NOTE(review): presumably a warm-up that pulls the local energy towards E_ref — confirm in dlqmc.fit.
fit_wfnet(
    net,
    partial(loss_local_energy, E_ref=-1.1,p=2),
    torch.optim.Adam(net.parameters(), lr=1e-3),
    wfnet_fit_driver(
            sampler,
            samplings=range(1),
            n_epochs=1,
            n_sampling_steps=250,
            batch_size=1_000,
            n_discard=50,
            range_sampling=partial(trange, desc='sampling steps', leave=False),
            range_training=partial(trange, desc='training steps', leave=False),
        ),
    clip_grad = None,
    writer = None,
    )

# Stage 2: three sampling rounds with E_ref=None and p=1. A fresh Adam optimizer is created,
# so optimizer state from stage 1 is not carried over.
fit_wfnet(
    net,
    partial(loss_local_energy, E_ref=None,p=1),
    torch.optim.Adam(net.parameters(), lr=1e-3),
    wfnet_fit_driver(
            sampler,
            samplings=range(3),
            n_epochs=1,
            n_sampling_steps=550,
            batch_size=1_000,
            n_discard=50,
            range_sampling=partial(trange, desc='sampling steps', leave=False),
            range_training=partial(trange, desc='training steps', leave=False),
        ),
    clip_grad = None,
    writer = None,
    )


In [None]:
# Line scan: the first coordinate (x of electron 0) sweeps [-3, 3] in 5000 points, every other
# coordinate is zero. Reshaped to (5000, n_electrons, 3) and moved to the GPU.
x_line = torch.cat((torch.linspace(-3, 3, 5000)[:, None], torch.zeros((5000, 3*n_electrons-1))), dim=1)
x_line=x_line.view(-1,n_electrons,3).cuda()
#x_line.requires_grad = True
# Recompute the network's intermediate pieces by hand so they can be plotted individually.
# NOTE(review): presumably mirrors what HanNet.forward does internally — confirm.
dists_elec = pairwise_distance(x_line, x_line)
dists_nuc = pairwise_distance(x_line, net.coords[None, ...])
dists = torch.cat([dists_elec, dists_nuc], dim=2)
dists_basis = net.dist_basis(dists)
xs = net.schnet(dists_basis)
jastrow = net.orbital(xs).squeeze(dim=-1).sum(dim=-1)
# normed=True rescales every curve to unit peak height (see normplot), so only shapes compare.
normed=True
normplot(x_line[:,0 , 0].cpu().detach().numpy(), gtowf(x_line).cpu().detach().numpy(),label="HF",norm=normed,color='grey')

normplot(x_line[:,0 , 0].cpu().detach().numpy(), torch.exp(jastrow).squeeze().cpu().detach().numpy(),ls='--',label="sym",norm=normed)
normplot(x_line[:,0 , 0].cpu().detach().numpy(), net.asymp_nuc(dists_nuc).cpu().detach().numpy(),ls='--',label="asym",norm=normed)
normplot(x_line[:,0 , 0].cpu().detach().numpy(), net(x_line).cpu().detach().numpy(),label="WF",norm=normed,lw=2,color='k')


#plt.axhline(0,ls=':',color='k')
# Dotted vertical lines mark the x positions of the two nuclei.
plt.axvline(h2p.coords[0][0],ls=':',color='k')
plt.axvline(h2p.coords[1][0],ls=':',color='k')
plt.ylabel("wavefunction in arbitrary units")
plt.xlabel("position in $a_0$")
#plt.ylim(-1,2)
# Only matters if the commented-out `requires_grad = True` line above is re-enabled.
x_line.requires_grad = False
plt.legend()
plt.savefig('h2p_wf.svg')
plt.show()


In [None]:
# Move the geometry's tensors to the GPU by overwriting its private fields in place.
# NOTE(review): presumably required because local_energy combines them with CUDA electron
# positions — confirm, and prefer a public API on Geometry if one exists.
h2p._coords = h2p._coords.cuda()
h2p._charges = h2p._charges.cuda()

In [None]:
# Local energy of the trained network along the line scan `x_line`.
# `[0]` keeps the first element of what local_energy returns; the y-range is clipped to
# [-10, 10] so that large values do not dominate the plot.
plt.plot(
    x_line[:, 0, 0].detach().cpu().numpy(),
    local_energy(x_line,lambda x: net(x), h2p)[0].cpu().detach().numpy()#*net(x_line).cpu().detach().numpy()**2
)
plt.ylim((-10, 10));

In [None]:
# Run the network sampler for 1000 steps and report the wall-clock time it took.
t = time.time()
samples = samples_from(sampler, range(1000))[0]
# Merge the two leading axes so that each row is one electron configuration.
samples = samples.flatten(end_dim=1)
print(f"it took: {time.time()-t}")

In [None]:
# 2-D histogram of the sampled electron positions projected onto the x-y plane
# (x and y of electron 0), restricted to [-3, 3] in both directions.
plt.hist2d(
    samples[:,0, 0].cpu().detach().numpy(),
    samples[:,0, 1].cpu().detach().numpy(),
    bins=100,
    range=[[-3, 3], [-3, 3]],
)                                   
# Equal aspect ratio so distances look the same along x and y.
plt.gca().set_aspect(1)
plt.xlabel("x in $a_0$")
plt.ylabel("y in $a_0$")
plt.savefig('h2p_samples.svg')

In [None]:
# Local energy of the network on the samples, evaluated in chunks of 1000 configurations with a
# progress bar; `[0]` keeps the first element of the result.
# NOTE(review): the bare name `dlqmc` is never imported explicitly — it appears to be bound only
# by the `%aimport` magic in the first cell; confirm.
E_loc = dlqmc.utils.batch_eval_tuple(partial(local_energy, wf=lambda x: net(x),geom=h2p),tqdm(samples.view([-1,n_electrons,3]).split(1000)))[0]

In [None]:
# Same as for the network, but for the HF baseline: sample 1000 steps (timed), then evaluate the
# local energy in chunks of 1000 configurations.
# Note that `samples` is overwritten here, so from now on it holds the HF samples.
t=time.time()
samples = samples_from(sampler_gtowf,range(1000))[0].flatten(end_dim=1)
print("it took: "+str(time.time()-t))
E_loc_gtowf = dlqmc.utils.batch_eval_tuple(partial(local_energy, wf=lambda x: gtowf(x),geom=h2p),tqdm(samples.view([-1,n_electrons,3]).split(1000)))[0]

In [None]:
# Overlaid local-energy histograms: HF baseline (grey) versus trained network (red).
plt.figure()
# Energies are clamped into [E_min, E_max] before histogramming, so outliers pile up in the edge bins.
E_min=-1.25
E_max=0.25

# NOTE(review): the mean is taken over the clamped energies, while the variance annotation below
# uses the unclamped ones — confirm this asymmetry is intended.
mean=E_loc_gtowf.detach().clamp(E_min, E_max).mean().item()
c1="grey"

h = plt.hist(E_loc_gtowf.detach().clamp(E_min, E_max).cpu().numpy(), bins=100,alpha = 0.8,color=c1,density=True)
# Annotation positions for the HF curve are hard-coded data coordinates.
plt.annotate("mean = "+str(np.round(mean,4)),(-0.5,18),color=c1)
plt.annotate("var     = "+str(np.round(np.var(E_loc_gtowf.cpu().detach().numpy()),4)),(-0.5,15),color=c1)


mean=E_loc.detach().clamp(E_min, E_max).mean().item()
c2="red"
h = plt.hist(E_loc.detach().clamp(E_min, E_max).cpu().numpy(), bins=100,alpha = 0.8,color=c2,density=True)
# The network annotations are placed relative to the tallest bin of the network histogram.
plt.annotate("mean = "+str(np.round(mean,4)),(-0.5,np.max(h[0])/2),color=c2)
plt.annotate("var     = "+str(np.round(np.var(E_loc.cpu().detach().numpy()),4)),(-0.5,np.max(h[0])/2-np.max(h[0])/15),color=c2)

plt.xlabel("$E_{loc}$ in $E_h$")
plt.ylabel("relative occurrence")
plt.savefig('h2p_Elochist.svg')
plt.show()

## $H_2$ singlet

In [None]:
#D = np.linspace(0.5,2,30)
#E = []
#for d in D:
#    mol = gto.M(
#        atom=[
#            ['H', (-d_ref_h2p, 0, 0)],
#            ['H', (d_ref_h2p, 0, 0)]
#        ],
#        unit='bohr',
#        basis='4-31G',
#        charge=0,
#        spin=0,
#    )
#    mf = scf.RHF(mol)
#    E.append(mf.kernel())
    
# Hartree-Fock reference for neutral singlet H2 (charge=0, spin=0) at the bond length d_ref_h2.
# Note that this cell uses the 4-31G basis, whereas the other systems in this notebook use 6-31G.
mol = gto.M(
    atom=[
        ['H', (-d_ref_h2/2, 0, 0)],
        ['H', (d_ref_h2/2, 0, 0)]
    ],
    unit='bohr',
    basis='4-31G',
    charge=0,
    spin=0,
    cart=True
)
mf = scf.RHF(mol)
mf.kernel()
# Rebinds `gtowf`: from here on it is the H2 singlet baseline, not the H2+ one.
gtowf = TorchGTOSlaterWF(mf)

In [None]:
#plt.plot(2*D,E)
#plt.xlabel("distance nuclei in $a_0$" )
#plt.ylabel("energy in $E_h$ ")
#plt.title("Ground state energy of $H_2^+$ with respect to distance of nuclei")
#plt.show()

In [None]:
# Spin partition for singlet H2: one spin-up and one spin-down electron.
n_electrons=2
n_up = 1
n_down = n_electrons-n_up
# Fresh network on the H2 geometry (rebinds `net`), moved to the GPU.
net = HanNet(h2,n_up,n_down).cuda()

In [None]:
# Samplers for the H2 singlet network and its HF baseline.
# `if False:` disables the HMC branch; the Langevin samplers in the `else` branch are used.
# Walkers start from standard-normal positions scaled by 1/3 (1000 walkers, on the GPU).
if False:
    sampler = hmc(
        net,
        torch.randn(1000, n_electrons, 3, device='cuda')/3,
        dysteps=3,
        stepsize=0.2,
        tau = 0.1,
        cutoff = 1.0
    )
    sampler_gtowf = hmc(
        gtowf,
        torch.randn(1000, n_electrons, 3, device='cuda')/3,
        dysteps=3,
        stepsize=0.2,
        tau = 0.1,
        cutoff = 1.0
    )
else:
    sampler = langevin_monte_carlo(
        net,
        torch.randn(1000, n_electrons, 3, device='cuda')/3,
        tau=0.1,
    )
    sampler_gtowf = langevin_monte_carlo(
        gtowf,
        torch.randn(1000, n_electrons, 3, device='cuda')/3,
        tau=0.1,
    )

In [None]:
# NOTE(review): `molecule` is not read by anything in this cell — looks like a leftover; confirm.
molecule = h2
# Stage 1: one sampling round with the loss pinned to the fixed reference E_ref=-1.1 (p=2).
fit_wfnet(
    net,
    partial(loss_local_energy, E_ref=-1.1,p=2),
    torch.optim.Adam(net.parameters(), lr=1e-3),
    wfnet_fit_driver(
            sampler,
            samplings=range(1),
            n_epochs=1,
            n_sampling_steps=250,
            batch_size=1_000,
            n_discard=50,
            range_sampling=partial(trange, desc='sampling steps', leave=False),
            range_training=partial(trange, desc='training steps', leave=False),
        ),
    clip_grad = None,
    writer = None,
    )

# Stage 2: five sampling rounds with E_ref=None and p=1, using a freshly created Adam optimizer.
fit_wfnet(
    net,
    partial(loss_local_energy, E_ref=None,p=1),
    torch.optim.Adam(net.parameters(), lr=1e-3),
    wfnet_fit_driver(
            sampler,
            samplings=range(5),
            n_epochs=1,
            n_sampling_steps=550,
            batch_size=1_000,
            n_discard=50,
            range_sampling=partial(trange, desc='sampling steps', leave=False),
            range_training=partial(trange, desc='training steps', leave=False),
        ),
    clip_grad = None,
    writer = None,
    )


In [None]:

# Line scan: x of electron 0 sweeps [-3, 3] in 5000 points, all other coordinates are zero.
x_line = torch.cat((torch.linspace(-3, 3, 5000)[:, None], torch.zeros((5000, 3*n_electrons-1))), dim=1)
# Column 3 is the x coordinate of electron 1; it is already 0, this line just makes the
# position of the second electron explicit (marked by the dashed line at x=0 below).
x_line[:,3] = 0
x_line=x_line.view(-1,n_electrons,3).cuda()
#x_line.requires_grad = True
# Recompute the network's intermediate pieces by hand so they can be plotted individually.
# NOTE(review): presumably mirrors what HanNet.forward does internally — confirm.
dists_elec = pairwise_distance(x_line, x_line)
dists_nuc = pairwise_distance(x_line, net.coords[None, ...])
dists = torch.cat([dists_elec, dists_nuc], dim=2)
dists_basis = net.dist_basis(dists)
xs = net.schnet(dists_basis)
# NOTE(review): the H2+ cell uses squeeze(dim=-1); the bare squeeze() here drops every size-1 axis.
jastrow = net.orbital(xs).squeeze().sum(dim=-1)

# normed=True rescales every curve to unit peak height (see normplot).
normed=True
normplot(x_line[:,0 , 0].cpu().detach().numpy(), gtowf(x_line).cpu().detach().numpy(),label="HF",norm=normed,color='grey')

normplot(x_line[:,0 , 0].cpu().detach().numpy(), torch.exp(jastrow).squeeze().cpu().detach().numpy(),ls='--',label="sym",norm=normed)
normplot(x_line[:,0 , 0].cpu().detach().numpy(), net.asymp_nuc(dists_nuc).cpu().detach().numpy(),ls='--',label="asym",norm=normed)
normplot(x_line[:,0 , 0].cpu().detach().numpy(), net(x_line).cpu().detach().numpy(),label="WF",norm=normed,lw=2,color='k')


#plt.axhline(0,ls=':',color='k')
# Dotted lines: nuclei. Thin dashed line: position of the second (fixed) electron.
plt.axvline(h2.coords[0][0],ls=':',color='k')
plt.axvline(h2.coords[1][0],ls=':',color='k')
plt.axvline(0,ls='--',color='grey',lw=0.5)
plt.ylabel("wavefunction in arbitrary units")
plt.xlabel("position in $a_0$")
#plt.ylim(-1,2)
# Only matters if the commented-out `requires_grad = True` line above is re-enabled.
x_line.requires_grad = False
plt.legend()
plt.savefig('h2singlet_wf.svg')
plt.show()

#plt.savefig('lastrunloss.png')


In [None]:
# Move the H2 geometry's tensors to the GPU by overwriting its private fields in place.
# NOTE(review): presumably required because local_energy combines them with CUDA electron
# positions — confirm, and prefer a public API on Geometry if one exists.
h2._coords = h2._coords.cuda()
h2._charges = h2._charges.cuda()

In [None]:
# Local energy of the trained network along the line scan `x_line`; `[0]` keeps the first
# element of what local_energy returns. The y-range is clipped to [-10, 10].
plt.plot(
    x_line[:, 0, 0].detach().cpu().numpy(),
    local_energy(x_line,lambda x: net(x), h2)[0].cpu().detach().numpy()#*net(x_line).cpu().detach().numpy()**2
)
plt.ylim((-10, 10));

In [None]:
# Run the network sampler for 1000 steps and report the wall-clock time it took.
t = time.time()
samples = samples_from(sampler, range(1000))[0]
# Merge the two leading axes so that each row is one electron configuration.
samples = samples.flatten(end_dim=1)
print(f"it took: {time.time()-t}")

In [None]:
# 2-D histogram of sampled positions in the x-y plane. Both electrons are pooled: the electron
# axis is flattened together with the sample axis by reshape(-1).
plt.hist2d(
    samples[:,:, 0].cpu().detach().numpy().reshape(-1),
    samples[:,:, 1].cpu().detach().numpy().reshape(-1),
    bins=100,
    range=[[-3, 3], [-3, 3]],
)                                   
# Equal aspect ratio so distances look the same along x and y.
plt.gca().set_aspect(1)
plt.xlabel("x in $a_0$")
plt.ylabel("y in $a_0$")
plt.savefig('h2singlet_samples.svg')

In [None]:
# Distribution of the electron-electron distance r12 = |r_0 - r_1| over the samples,
# plotted raw and divided by the spherical-shell factor 4*pi*r^2.
r12 = torch.norm(samples[:,0]-samples[:,1],dim=1).detach().cpu().numpy()
count,bins = np.histogram(r12,bins=1000,density=True)
centers = (bins[:-1]+bins[1:])/2
# Evaluate the shell factor at the actual bin centres: np.histogram spans [min(r12), max(r12)],
# so a fixed grid such as linspace(0.05, 5.05, 1000) does not line up with the bins.
cor = 4*np.pi*centers**2
plt.plot(centers,count)
plt.plot(centers,count/cor)
#plt.plot(1/cor)

In [None]:
# Local energy of the network on the samples, evaluated in chunks of 1000 configurations with a
# progress bar; `[0]` keeps the first element of the result.
# NOTE(review): the bare name `dlqmc` appears to be bound only by the `%aimport` magic in the
# first cell — confirm.
E_loc = dlqmc.utils.batch_eval_tuple(partial(local_energy, wf=lambda x: net(x),geom=h2),tqdm(samples.view([-1,n_electrons,3]).split(1000)))[0]

In [None]:
# HF baseline: sample 1000 steps (timed), then evaluate the local energy in chunks of 1000.
# Note that `samples` is overwritten here, so from now on it holds the HF samples.
t=time.time()
samples = samples_from(sampler_gtowf,range(1000))[0].flatten(end_dim=1)
print("it took: "+str(time.time()-t))
E_loc_gtowf = dlqmc.utils.batch_eval_tuple(partial(local_energy, wf=lambda x: gtowf(x),geom=h2),tqdm(samples.view([-1,n_electrons,3]).split(1000)))[0]

In [None]:
# Overlaid local-energy histograms: HF baseline (grey) versus trained network (red).
plt.figure()

# Energies are clamped into [E_min, E_max] before histogramming, so outliers pile up in the edge bins.
E_min=-2.
E_max=0

# NOTE(review): the mean is taken over the clamped energies, while the variance annotation below
# uses the unclamped ones — confirm this asymmetry is intended.
mean=E_loc_gtowf.detach().clamp(E_min, E_max).mean().item()
c1="grey"

h = plt.hist(E_loc_gtowf.detach().clamp(E_min, E_max).cpu().numpy(), bins=100,alpha = 0.8,color=c1,density=True)
# Annotation positions for the HF curve are hard-coded data coordinates.
plt.annotate("mean = "+str(np.round(mean,4)),(-1,13),color=c1)
plt.annotate("var     = "+str(np.round(np.var(E_loc_gtowf.cpu().detach().numpy()),4)),(-1,10.5),color=c1)


mean=E_loc.detach().clamp(E_min, E_max).mean().item()
c2="red"
h = plt.hist(E_loc.detach().clamp(E_min, E_max).cpu().numpy(), bins=100,alpha = 0.8,color=c2,density=True)
# The network annotations are placed relative to the tallest bin of the network histogram.
plt.annotate("mean = "+str(np.round(mean,4)),(-1,np.max(h[0])/2),color=c2)
plt.annotate("var     = "+str(np.round(np.var(E_loc.cpu().detach().numpy()),4)),(-1,np.max(h[0])/2-np.max(h[0])/15),color=c2)


plt.xlabel("$E_{loc}$ in $E_h$")
plt.ylabel("relative occurrence")
plt.savefig('h2singlet_Elochist.svg')
plt.show()

## $H_2$ triplet

In [None]:
# Hartree-Fock reference for triplet H2: neutral molecule with two unpaired electrons (spin=2),
# 6-31G basis, at the same bond length d_ref_h2 as the singlet.
mol = gto.M(
    atom=[
        ['H', (-d_ref_h2/2, 0, 0)],
        ['H', (d_ref_h2/2, 0, 0)]
    ],
    unit='bohr',
    basis='6-31G',
    charge=0,
    spin=2,
    cart=True
)
# NOTE(review): the molecule is open-shell but the RHF entry point is used — presumably pyscf
# dispatches to a restricted open-shell solver here; confirm.
mf = scf.RHF(mol)
mf.kernel()
# Rebinds `gtowf` to the triplet baseline.
gtowf = TorchGTOSlaterWF(mf)

In [None]:
# One-electron line scan: x sweeps [-3, 3] in 5000 points, y = z = 0 -> shape (5000, 3).
x_line = torch.cat((torch.linspace(-3, 3, 5000)[:, None], torch.zeros((5000, 3-1))), dim=1)
# Duplicate the coordinates so both electrons sit at the same point -> shape (5000, 6).
# torch.cat takes the tensors as ONE sequence argument; passing them as separate positional
# arguments raises a TypeError.
x_line = torch.cat((x_line, x_line), dim=1)

In [None]:
# Spin partition for triplet H2: both electrons spin-up, none spin-down.
n_electrons=2
n_up = 2
n_down = n_electrons-n_up
# Fresh network on the H2 geometry with latent_dim=30 (the other cells rely on the default).
net = HanNet(h2,n_up,n_down,latent_dim=30).cuda()

In [None]:
# Samplers for the H2 triplet network and its HF baseline.
# `if False:` disables the HMC branch; the Langevin samplers in the `else` branch are used.
# Walkers start from standard-normal positions scaled by 1/3 (1000 walkers, on the GPU).
if False:
    sampler = hmc(
        net,
        torch.randn(1000, n_electrons, 3, device='cuda')/3,
        dysteps=3,
        stepsize=0.2,
        tau = 0.1,
        cutoff = 1.0
    )
    sampler_gtowf = hmc(
        gtowf,
        torch.randn(1000, n_electrons, 3, device='cuda')/3,
        dysteps=3,
        stepsize=0.2,
        tau = 0.1,
        cutoff = 1.0
    )
else:
    sampler = langevin_monte_carlo(
        net,
        torch.randn(1000, n_electrons, 3, device='cuda')/3,
        tau=0.1,
    )
    sampler_gtowf = langevin_monte_carlo(
        gtowf,
        torch.randn(1000, n_electrons, 3, device='cuda')/3,
        tau=0.1,
    )

In [None]:
# NOTE(review): `molecule` is not read by anything in this cell — looks like a leftover; confirm.
molecule = h2
# Stage 1: one sampling round with the loss pinned to the fixed reference E_ref=-0.8 (p=2).
fit_wfnet(
    net,
    partial(loss_local_energy, E_ref=-0.8,p=2),
    torch.optim.Adam(net.parameters(), lr=1e-3),
    wfnet_fit_driver(
            sampler,
            samplings=range(1),
            n_epochs=1,
            n_sampling_steps=250,
            batch_size=1_000,
            n_discard=50,
            range_sampling=partial(trange, desc='sampling steps', leave=False),
            range_training=partial(trange, desc='training steps', leave=False),
        ),
    clip_grad = None,
    writer = None,
    )

# Stage 2: five sampling rounds with E_ref=None and p=1, using a freshly created Adam optimizer.
fit_wfnet(
    net,
    partial(loss_local_energy, E_ref=None,p=1),
    torch.optim.Adam(net.parameters(), lr=1e-3),
    wfnet_fit_driver(
            sampler,
            samplings=range(5),
            n_epochs=1,
            n_sampling_steps=550,
            batch_size=1_000,
            n_discard=50,
            range_sampling=partial(trange, desc='sampling steps', leave=False),
            range_training=partial(trange, desc='training steps', leave=False),
        ),
    clip_grad = None,
    writer = None,
    )


In [None]:
# Line scan: x of electron 0 sweeps [-3, 3] in 5000 points, all other coordinates are zero.
x_line = torch.cat((torch.linspace(-3, 3, 5000)[:, None], torch.zeros((5000, 3*n_electrons-1))), dim=1)
# Column 3 is the x coordinate of electron 1; it is already 0, this line just makes the
# position of the second electron explicit (marked by the dashed line at x=0 below).
x_line[:,3] = 0
x_line=x_line.view(-1,n_electrons,3).cuda()
#x_line.requires_grad = True
# Recompute the network's intermediate pieces by hand so they can be plotted individually.
# NOTE(review): presumably mirrors what HanNet.forward does internally — confirm.
dists_elec = pairwise_distance(x_line, x_line)
dists_nuc = pairwise_distance(x_line, net.coords[None, ...])
dists = torch.cat([dists_elec, dists_nuc], dim=2)
dists_basis = net.dist_basis(dists)
xs = net.schnet(dists_basis)
jastrow = net.orbital(xs).squeeze().sum(dim=-1)

# normed=True rescales every curve to unit peak height (see normplot).
normed=True
normplot(x_line[:,0 , 0].cpu().detach().numpy(), gtowf(x_line).cpu().detach().numpy(),label="HF",norm=normed,color='grey')

normplot(x_line[:,0 , 0].cpu().detach().numpy(), torch.exp(jastrow).squeeze().cpu().detach().numpy(),label="sym",norm=normed,ls='--')
normplot(x_line[:,0 , 0].cpu().detach().numpy(), net.asymp_nuc(dists_nuc).cpu().detach().numpy(),label="asym",norm=normed,ls='--')
# Additional curve compared with the singlet plot: the output of the network's `anti_up` module.
normplot(x_line[:,0 , 0].cpu().detach().numpy(), net.anti_up(x_line[:,:], dists_elec[:, :, :, None]).squeeze(dim=-1).cpu().detach().numpy(),label="antisym",norm=normed,ls='--')
normplot(x_line[:,0 , 0].cpu().detach().numpy(), net(x_line).cpu().detach().numpy(),label="WF",norm=normed,lw=2,color='k')

#plt.axhline(0,ls=':',color='k')
# Dotted lines: nuclei. Thin dashed line: position of the second (fixed) electron.
plt.axvline(h2.coords[0][0],ls=':',color='grey')
plt.axvline(h2.coords[1][0],ls=':',color='grey')
plt.axvline(0,ls='--',color='grey',lw=0.5)
plt.ylabel("wavefunction in arbitrary units")
plt.xlabel("position in $a_0$")
#plt.ylim(-1,2)
# Only matters if the commented-out `requires_grad = True` line above is re-enabled.
x_line.requires_grad = False
plt.legend()
plt.savefig('h2triplet_wf.svg')
plt.show()

#plt.savefig('lastrunloss.png')


In [None]:
# Ensure the H2 geometry's tensors live on the GPU (overwrites private fields in place).
# NOTE(review): the singlet section already performs the same move, so on a top-to-bottom run
# this should change nothing — confirm.
h2._coords = h2._coords.cuda()
h2._charges = h2._charges.cuda()

In [None]:
# Local energy of the trained network along the line scan `x_line`; `[0]` keeps the first
# element of what local_energy returns. The y-range is clipped to [-10, 10].
plt.plot(
    x_line[:, 0, 0].detach().cpu().numpy(),
    local_energy(x_line,lambda x: net(x), h2)[0].cpu().detach().numpy()#*net(x_line).cpu().detach().numpy()**2
)
plt.ylim((-10, 10));

In [None]:
# Run the network sampler for 1000 steps and report the wall-clock time it took.
t = time.time()
samples = samples_from(sampler, range(1000))[0]
# Merge the two leading axes so that each row is one electron configuration.
samples = samples.flatten(end_dim=1)
print(f"it took: {time.time()-t}")

In [None]:
# 2-D histogram of sampled positions in the x-y plane. Both electrons are pooled: the electron
# axis is flattened together with the sample axis by reshape(-1).
plt.hist2d(
    samples[:,:, 0].cpu().detach().numpy().reshape(-1),
    samples[:,:, 1].cpu().detach().numpy().reshape(-1),
    bins=100,
    range=[[-3, 3], [-3, 3]],
)                                   
# Equal aspect ratio so distances look the same along x and y.
plt.gca().set_aspect(1)
plt.xlabel("x in $a_0$")
plt.ylabel("y in $a_0$")
plt.savefig('h2triplet_samples.svg')

In [None]:
# Local energy of the network on the samples, evaluated in chunks of 1000 configurations with a
# progress bar; `[0]` keeps the first element of the result.
# NOTE(review): the bare name `dlqmc` appears to be bound only by the `%aimport` magic in the
# first cell — confirm.
E_loc = dlqmc.utils.batch_eval_tuple(partial(local_energy, wf=lambda x: net(x),geom=h2),tqdm(samples.view([-1,n_electrons,3]).split(1000)))[0]

In [None]:
# HF baseline: sample 1000 steps (timed), then evaluate the local energy in chunks of 1000.
# Note that `samples` is overwritten here, so from now on it holds the HF samples.
t=time.time()
samples = samples_from(sampler_gtowf,range(1000))[0].flatten(end_dim=1)
print("it took: "+str(time.time()-t))
E_loc_gtowf = dlqmc.utils.batch_eval_tuple(partial(local_energy, wf=lambda x: gtowf(x),geom=h2),tqdm(samples.view([-1,n_electrons,3]).split(1000)))[0]

In [None]:
# Overlaid local-energy histograms: HF baseline (grey) versus trained network (red).
plt.figure()
# Drop samples with local energy <= -4 before any statistics. This overwrites the notebook-level
# E_loc / E_loc_gtowf, so the unfiltered values are gone after this cell has run.
E_loc_gtowf=E_loc_gtowf[E_loc_gtowf>-4]
E_loc=E_loc[E_loc>-4]
# Remaining energies are clamped into [E_min, E_max]; here both mean and variance use the
# clamped values (unlike the H2+ and H2 singlet cells, where the variance is unclamped).
E_min=-1.5
E_max=0.5

mean=E_loc_gtowf.detach().clamp(E_min, E_max).mean().item()
c1="grey"

h = plt.hist(E_loc_gtowf.detach().clamp(E_min, E_max).cpu().numpy(), bins=100,alpha = 0.8,color=c1,density=True)
# Annotation positions for the HF curve are hard-coded data coordinates.
plt.annotate("mean = "+str(np.round(mean,4)),(-0.5,5.8),color=c1)
plt.annotate("var     = "+str(np.round(np.var(E_loc_gtowf.cpu().detach().clamp(E_min, E_max).numpy()),4)),(-0.5,4.2),color=c1)


mean=E_loc.detach().clamp(E_min, E_max).mean().item()
c2="red"
h = plt.hist(E_loc.detach().clamp(E_min, E_max).cpu().numpy(), bins=100,alpha = 0.8,color=c2,density=True)
# The network annotations are placed relative to the tallest bin of the network histogram.
plt.annotate("mean = "+str(np.round(mean,4)),(-0.5,np.max(h[0])/2),color=c2)
plt.annotate("var     = "+str(np.round(np.var(E_loc.cpu().detach().clamp(E_min, E_max).numpy()),4)),(-0.5,np.max(h[0])/2-np.max(h[0])/15),color=c2)


plt.xlabel("$E_{loc}$ in $E_h$")
plt.ylabel("relative occurrence")
plt.savefig('h2triplet_Elochist.svg')
plt.show()

## Helium

In [None]:
# Hartree-Fock reference for the helium atom: one He nucleus at the origin, neutral,
# closed-shell (spin=0), 6-31G basis.
mol = gto.M(
    atom=[
        ['He', (0, 0, 0)]
    ],
    unit='bohr',
    basis='6-31G',
    charge=0,
    spin=0,
    cart=True
)
mf = scf.RHF(mol)
mf.kernel()
# Rebinds `gtowf` to the helium baseline.
gtowf = TorchGTOSlaterWF(mf)

In [None]:
# Helium atom: a single nucleus of charge 2 at the origin.
# Built with the same Geometry(coords, charges) constructor used for h2p / h2 above, rather than
# by overwriting the private fields of geomdb['H'], which could leave the database's hydrogen
# entry silently altered for any later lookup.
he = Geometry([[0., 0., 0.]], [2.])
print(he)

In [None]:
# Spin partition for helium: one spin-up and one spin-down electron.
n_electrons=2
n_up = 1
n_down = n_electrons-n_up
# Fresh network on the helium geometry with kernel_dim=128 (the other cells rely on the default).
net = HanNet(he,n_up,n_down,kernel_dim=128).cuda()

In [None]:
# Samplers for the helium network and its HF baseline.
# `if False:` disables the HMC branch; the Langevin samplers in the `else` branch are used.
# Unlike the hydrogen cells, the initial standard-normal positions are NOT scaled by 1/3.
if False:
    sampler = hmc(
        net,
        torch.randn(1000, n_electrons, 3, device='cuda'),
        dysteps=3,
        stepsize=0.2,
        tau = 0.1,
        cutoff = 1.0
    )
    sampler_gtowf = hmc(
        gtowf,
        torch.randn(1000, n_electrons, 3, device='cuda'),
        dysteps=3,
        stepsize=0.2,
        tau = 0.1,
        cutoff = 1.0
    )
else:
    sampler = langevin_monte_carlo(
        net,
        torch.randn(1000, n_electrons, 3, device='cuda'),
        tau=0.1,
    )
    sampler_gtowf = langevin_monte_carlo(
        gtowf,
        torch.randn(1000, n_electrons, 3, device='cuda'),
        tau=0.1,
    )

In [None]:
# NOTE(review): `molecule` is not read by anything in this cell — looks like a leftover; confirm.
molecule = he
# Stage 1: one sampling round with the loss pinned to the fixed reference E_ref=-3 (p=2).
# From this system on, clip_grad=1 is passed instead of None.
fit_wfnet(
    net,
    partial(loss_local_energy, E_ref=-3,p=2),
    torch.optim.Adam(net.parameters(), lr=1e-3),
    wfnet_fit_driver(
            sampler,
            samplings=range(1),
            n_epochs=1,
            n_sampling_steps=200,
            batch_size=1_000,
            n_discard=50,
            range_sampling=partial(trange, desc='sampling steps', leave=False),
            range_training=partial(trange, desc='training steps', leave=False),
        ),
    clip_grad = 1,
    writer = None,
    )

# Stage 2: five sampling rounds (1050 sampling steps each) with E_ref=None and p=1, using a
# freshly created Adam optimizer.
fit_wfnet(
    net,
    partial(loss_local_energy, E_ref=None,p=1),
    torch.optim.Adam(net.parameters(), lr=1e-3),
    wfnet_fit_driver(
            sampler,
            samplings=range(5),
            n_epochs=1,
            n_sampling_steps=1050,
            batch_size=1_000,
            n_discard=50,
            range_sampling=partial(trange, desc='sampling steps', leave=False),
            range_training=partial(trange, desc='training steps', leave=False),
        ),
    clip_grad = 1,
    writer = None,
    )


In [None]:
# Line scan: x of electron 0 sweeps [-3, 3] in 5000 points, all other coordinates start at zero.
x_line = torch.cat((torch.linspace(-3, 3, 5000)[:, None], torch.zeros((5000, 3*n_electrons-1))), dim=1)
# Column 3 is the x coordinate of electron 1: park the second electron at x = 1
# (marked by the thin dashed line below).
x_line[:,3] = 1
x_line=x_line.view(-1,n_electrons,3).cuda()
#x_line.requires_grad = True
# Recompute the network's intermediate pieces by hand so they can be plotted individually.
# NOTE(review): presumably mirrors what HanNet.forward does internally — confirm.
dists_elec = pairwise_distance(x_line, x_line)
dists_nuc = pairwise_distance(x_line, net.coords[None, ...])
dists = torch.cat([dists_elec, dists_nuc], dim=2)
dists_basis = net.dist_basis(dists)
xs = net.schnet(dists_basis)
jastrow = net.orbital(xs).squeeze().sum(dim=-1)

# normed=True rescales every curve to unit peak height (see normplot).
normed=True
normplot(x_line[:,0 , 0].cpu().detach().numpy(), gtowf(x_line).cpu().detach().numpy(),label="HF",norm=normed,color='grey')

normplot(x_line[:,0 , 0].cpu().detach().numpy(), torch.exp(jastrow).squeeze().cpu().detach().numpy(),label="sym",norm=normed,ls='--')
normplot(x_line[:,0 , 0].cpu().detach().numpy(), net.asymp_nuc(dists_nuc).cpu().detach().numpy(),label="asym",norm=normed,ls='--')
#normplot(x_line[:,0 , 0].cpu().detach().numpy(), net.anti_up(x_line[:,:], dists_elec[:, :, :, None]).squeeze(dim=-1).cpu().detach().numpy(),label="antisym",norm=normed,ls='--')
normplot(x_line[:,0 , 0].cpu().detach().numpy(), net(x_line).cpu().detach().numpy(),label="WF",norm=normed,lw=2,color='k')

#plt.axhline(0,ls=':',color='k')
# Dotted line: nucleus at the origin. Thin dashed line: the second electron at x = 1.
plt.axvline(0,ls=':',color='grey')
plt.axvline(1,ls='--',color='grey',lw=0.5)
plt.ylabel("wavefunction in arbitrary units")
plt.xlabel("position in $a_0$")
#plt.ylim(-1,2)
# Only matters if the commented-out `requires_grad = True` line above is re-enabled.
x_line.requires_grad = False
plt.legend()
plt.savefig('hesinglet_wf.svg')
plt.show()

#plt.savefig('lastrunloss.png')


In [None]:
# Move the helium geometry's tensors to the GPU by overwriting its private fields in place.
# NOTE(review): presumably required because local_energy combines them with CUDA electron
# positions — confirm, and prefer a public API on Geometry if one exists.
he._coords = he._coords.cuda()
he._charges = he._charges.cuda()

In [None]:
# Local energy of the trained network along the line scan `x_line`; `[0]` keeps the first
# element of what local_energy returns. The y-range is clipped to [-20, 10].
plt.plot(
    x_line[:, 0, 0].detach().cpu().numpy(),
    local_energy(x_line,lambda x: net(x), he)[0].cpu().detach().numpy()#*net(x_line).cpu().detach().numpy()**2
)
plt.ylim((-20, 10));

In [None]:
# Run the network sampler for 1000 steps and report the wall-clock time it took.
t = time.time()
samples = samples_from(sampler, range(1000))[0]
# Merge the two leading axes so that each row is one electron configuration.
samples = samples.flatten(end_dim=1)
print(f"it took: {time.time()-t}")

In [None]:
# 2-D histogram of sampled positions in the x-y plane (both electrons pooled), restricted to the
# window [-1, 1] x [-1, 1] around the nucleus.
plt.hist2d(
    samples[:,:, 0].cpu().detach().numpy().reshape(-1),
    samples[:,:, 1].cpu().detach().numpy().reshape(-1),
    bins=100,
    range=[[-1, 1], [-1, 1]],
)                                   
# Equal aspect ratio so distances look the same along x and y.
plt.gca().set_aspect(1)
plt.xlabel("x in $a_0$")
plt.ylabel("y in $a_0$")
plt.savefig('hesinglet_samples.svg')

In [None]:
# Local energy of the network on the samples, evaluated in chunks of 1000 configurations with a
# progress bar; `[0]` keeps the first element of the result.
# NOTE(review): the bare name `dlqmc` appears to be bound only by the `%aimport` magic in the
# first cell — confirm.
E_loc = dlqmc.utils.batch_eval_tuple(partial(local_energy, wf=lambda x: net(x),geom=he),tqdm(samples.view([-1,n_electrons,3]).split(1000)))[0]

In [None]:
# HF baseline: sample 1000 steps (timed), then evaluate the local energy in chunks of 1000.
# Note that `samples` is overwritten here, so from now on it holds the HF samples.
t=time.time()
samples = samples_from(sampler_gtowf,range(1000))[0].flatten(end_dim=1)
print("it took: "+str(time.time()-t))
E_loc_gtowf = dlqmc.utils.batch_eval_tuple(partial(local_energy, wf=lambda x: gtowf(x),geom=he),tqdm(samples.view([-1,n_electrons,3]).split(1000)))[0]

In [None]:
# Overlaid local-energy histograms: HF baseline (grey) versus trained network (red).
plt.figure()
# Drop samples with local energy <= -4 (overwrites the notebook-level E_loc / E_loc_gtowf).
# NOTE(review): this cut is tighter than the lower clamp bound E_min=-5.5 below, so the lower
# clamp can never trigger; the -4 threshold matches the H2 triplet cell — confirm it is
# intended for helium.
E_loc_gtowf=E_loc_gtowf[E_loc_gtowf>-4]
E_loc=E_loc[E_loc>-4]
E_min=-5.5
E_max=-1.5
# x position (data coordinates) shared by all annotations.
loc=-2.5

mean=E_loc_gtowf.detach().clamp(E_min, E_max).mean().item()
c1="grey"

h = plt.hist(E_loc_gtowf.detach().clamp(E_min, E_max).cpu().numpy(), bins=100,alpha = 0.8,color=c1,density=True)
# y positions of the HF annotations are hard-coded data coordinates.
plt.annotate("mean = "+str(np.round(mean,4)),(loc,5.5),color=c1)
plt.annotate("var     = "+str(np.round(np.var(E_loc_gtowf.cpu().detach().clamp(E_min, E_max).numpy()),4)),(loc,4.2),color=c1)


mean=E_loc.detach().clamp(E_min, E_max).mean().item()
c2="red"
h = plt.hist(E_loc.detach().clamp(E_min, E_max).cpu().numpy(), bins=100,alpha = 0.8,color=c2,density=True)
# The network annotations are placed relative to the tallest bin of the network histogram.
plt.annotate("mean = "+str(np.round(mean,4)),(loc,np.max(h[0])/2),color=c2)
plt.annotate("var     = "+str(np.round(np.var(E_loc.cpu().detach().clamp(E_min, E_max).numpy()),4)),(loc,np.max(h[0])/2-np.max(h[0])/15),color=c2)


plt.xlabel("$E_{loc}$ in $E_h$")
plt.ylabel("relative occurrence")
plt.savefig('hesinglet_Elochist.svg')
plt.show()

## LiH

In [None]:
# LiH geometry straight from the geometry database. `molecule` is the alias that the
# local-energy cells further down pass as `geom`.
lih=geomdb['LiH']
molecule = lih


In [None]:
# Hartree-Fock reference for LiH: Li at the origin, H at x = 3.0141 bohr, neutral, closed-shell.
# NOTE(review): these coordinates are hard-coded, whereas the network uses geomdb['LiH'] —
# confirm that the two geometries agree.
mol = gto.M(
    atom=[
        ['Li', (0, 0, 0)],
        ['H', (3.0141, 0, 0)]
    ],
    unit='bohr',
    basis='6-31G',
    charge=0,
    spin=0,
    cart=True
)
mf = scf.RHF(mol)
mf.kernel()
# Rebinds `gtowf` to the LiH baseline.
gtowf = TorchGTOSlaterWF(mf)

In [None]:
# Spin partition for LiH: four electrons, two spin-up and two spin-down.
n_electrons=4
n_up = 2
n_down = n_electrons-n_up
# Fresh network on the LiH geometry; this is the only cell that sets the cusp parameters
# explicitly (cusp_same=-0.5, cusp_anti=-0.25).
net = HanNet(lih,n_up,n_down,cusp_same=-0.5,cusp_anti=-0.25).cuda()

In [None]:
# Samplers for the LiH network and its HF baseline.
# Unlike the earlier systems the switch is `if True:`, so the HMC samplers are the ones built;
# the Langevin branch (initial positions scaled by 3) is currently dead code.
if True:
    sampler = hmc(
        net,
        torch.randn(1000, n_electrons, 3, device='cuda'),
        dysteps=3,
        stepsize=0.2,
        tau = 0.1,
        cutoff = 1.0
    )
    sampler_gtowf = hmc(
        gtowf,
        torch.randn(1000, n_electrons, 3, device='cuda'),
        dysteps=3,
        stepsize=0.2,
        tau = 0.1,
        cutoff = 1.0
    )
else:
    sampler = langevin_monte_carlo(
        net,
        torch.randn(1000, n_electrons, 3, device='cuda')*3,
        tau=0.1,
    )
    sampler_gtowf = langevin_monte_carlo(
        gtowf,
        torch.randn(1000, n_electrons, 3, device='cuda')*3,
        tau=0.1,
    )

In [None]:
# Stage 1: one sampling round with the loss pinned to the fixed reference E_ref=-9 (p=2).
# NOTE(review): only 4 sampling steps with 2 discarded — far fewer than in the other systems;
# looks like a quick smoke-test setting, confirm before relying on the results.
fit_wfnet(
    net,
    partial(loss_local_energy, E_ref=-9,p=2),
    torch.optim.Adam(net.parameters(), lr=1e-3),
    wfnet_fit_driver(
            sampler,
            samplings=range(1),
            n_epochs=1,
            n_sampling_steps=4,
            batch_size=1_000,
            n_discard=2,
            range_sampling=partial(trange, desc='sampling steps', leave=False),
            range_training=partial(trange, desc='training steps', leave=False),
        ),
    clip_grad = 1,
    writer = None,
    )

# Stage 2: a single sampling round (150 steps, 50 discarded) with E_ref=None and p=1, using a
# freshly created Adam optimizer.
fit_wfnet(
    net,
    partial(loss_local_energy, E_ref=None,p=1),
    torch.optim.Adam(net.parameters(), lr=1e-3),
    wfnet_fit_driver(
            sampler,
            samplings=range(1),
            n_epochs=1,
            n_sampling_steps=150,
            batch_size=1_000,
            n_discard=50,
            range_sampling=partial(trange, desc='sampling steps', leave=False),
            range_training=partial(trange, desc='training steps', leave=False),
        ),
    clip_grad = 1,
    writer = None,
    )


In [None]:
# Fixed positions of electrons 2-4 while electron 1 is scanned along x
# (only x and y are set; every coordinate that is not assigned stays 0).
e_x2=3
e_y2=1
e_x3=0.3
e_y3=-1
e_x4=-0.3
# Line scan: x of the first electron sweeps [-8, 8] in 5000 points.
x_line = torch.cat((torch.linspace(-8, 8, 5000)[:, None], torch.zeros((5000, 3*n_electrons-1))), dim=1)
# Flat column layout is (x, y, z) per electron: columns 3-5 -> second electron, 6-8 -> third,
# 9-11 -> fourth.
x_line[:,3] = e_x2
x_line[:,4] = e_y2
x_line[:,6] = e_x3
x_line[:,7] = e_y3
x_line[:,9] = e_x4

x_line=x_line.view(-1,n_electrons,3).cuda()
#x_line.requires_grad = True
# Recompute the network's intermediate pieces by hand so they can be plotted individually.
# NOTE(review): presumably mirrors what HanNet.forward does internally — confirm.
dists_elec = pairwise_distance(x_line, x_line)
dists_nuc = pairwise_distance(x_line, net.coords[None, ...])
dists = torch.cat([dists_elec, dists_nuc], dim=2)
dists_basis = net.dist_basis(dists)
xs = net.schnet(dists_basis)
jastrow = net.orbital(xs).squeeze().sum(dim=-1)

# normed=True rescales every curve to unit peak height (see normplot).
normed=True
normplot(x_line[:,0 , 0].cpu().detach().numpy(), gtowf(x_line).cpu().detach().numpy(),label="HF",norm=normed,color='grey')

normplot(x_line[:,0 , 0].cpu().detach().numpy(), torch.exp(jastrow).squeeze().cpu().detach().numpy(),label="sym",norm=normed,ls='--')
normplot(x_line[:,0 , 0].cpu().detach().numpy(), net.asymp_nuc(dists_nuc).cpu().detach().numpy(),label="asym",norm=normed,ls='--')
# NOTE(review): all four electrons are passed to `anti_up` although n_up is 2 — confirm that
# this is the intended input for the spin-up antisymmetric part.
normplot(x_line[:,0 , 0].cpu().detach().numpy(), net.anti_up(x_line[:,:], dists_elec[:, :, :, None]).squeeze(dim=-1).cpu().detach().numpy(),label="antisym",norm=normed,ls='--')
normplot(x_line[:,0 , 0].cpu().detach().numpy(), net(x_line).cpu().detach().numpy(),label="WF",norm=normed,lw=2,color='k')

#plt.axhline(0,ls=':',color='k')
# Dotted line at x = 0 (where the pyscf reference above places the Li nucleus).
plt.axvline(0,ls=':',color='grey')
plt.ylabel("wavefunction in arbitrary units")
plt.xlabel("position in $a_0$")
#plt.ylim(-1,2)
# Only matters if the commented-out `requires_grad = True` line earlier in this cell is re-enabled.
x_line.requires_grad = False
plt.legend()
# LiH-specific file name, following the '<system>singlet_wf.svg' pattern of the other sections;
# the previous 'hesinglet_wf.svg' silently overwrote the helium figure.
plt.savefig('lihsinglet_wf.svg')
plt.show()

#plt.savefig('lastrunloss.png')


In [None]:
# Run the network sampler for 1500 steps (timed) and drop the first 500 rows of the result.
# NOTE(review): `[500:]` is applied AFTER flatten(end_dim=1), i.e. it removes the first 500
# flattened configurations, not the first 500 sampling steps of every walker. If it is meant as
# a burn-in cut, it presumably has to be applied along the step axis before flattening — confirm
# the axis order returned by samples_from.
t=time.time()
samples = samples_from(sampler,range(1500))[0].flatten(end_dim=1)[500:]
print("it took: "+str(time.time()-t))

In [None]:
# 2-D histogram of sampled positions in the x-y plane (all electrons pooled). The window
# [-2, 6] x [-2, 2] covers both nuclei of the reference geometry (Li at x = 0, H at x = 3.0141).
plt.hist2d(
    samples[:,:, 0].cpu().detach().numpy().reshape(-1),
    samples[:,:, 1].cpu().detach().numpy().reshape(-1),
    bins=100,
    range=[[-2, 6], [-2, 2]],
)
# Equal aspect ratio so distances look the same along x and y.
plt.gca().set_aspect(1)
plt.xlabel("x in $a_0$")
plt.ylabel("y in $a_0$")
# Fixed the file-name typo ('lihinglet_samples.svg') so the output follows the
# '<system>singlet_samples.svg' pattern of the other sections.
plt.savefig('lihsinglet_samples.svg')

In [None]:
# Move the LiH geometry's tensors to the GPU by overwriting its private fields in place.
# `molecule` is an alias of `lih`, so the local-energy cells below see the moved tensors.
# NOTE(review): this mutates the object returned by geomdb['LiH'] — confirm the database hands
# out copies, otherwise the shared entry is altered as well.
lih._coords = lih._coords.cuda()
lih._charges = lih._charges.cuda()

In [None]:
# Local energy of the network on the samples, evaluated in chunks of 1000 configurations with a
# progress bar; `[0]` keeps the first element of the result. `molecule` is the LiH geometry.
# NOTE(review): the bare name `dlqmc` appears to be bound only by the `%aimport` magic in the
# first cell — confirm.
E_loc = dlqmc.utils.batch_eval_tuple(partial(local_energy, wf=lambda x: net(x),geom=molecule),tqdm(samples.view([-1,n_electrons,3]).split(1000)))[0]

In [None]:
# Run the HF-baseline sampler for 1000 steps and report the wall-clock time it took.
# `samples` is rebound here and holds the HF samples from now on.
t = time.time()
samples = samples_from(sampler_gtowf, range(1000))[0]
# Merge the two leading axes so that each row is one electron configuration.
samples = samples.flatten(end_dim=1)
print(f"it took: {time.time()-t}")


In [None]:
# Local energy of the HF baseline on its own samples, evaluated in chunks of 1000.
# Note that the histogram cell below currently has its HF part commented out, so this result is
# not plotted.
E_loc_gtowf = dlqmc.utils.batch_eval_tuple(partial(local_energy, wf=lambda x: gtowf(x),geom=molecule),tqdm(samples.view([-1,n_electrons,3]).split(1000)))[0]

In [None]:
# Local-energy histogram of the trained LiH network (red). The HF-baseline part is kept below
# but disabled.
plt.figure()
#E_loc_gtowf=E_loc_gtowf[E_loc_gtowf>-4]
#E_loc=E_loc[E_loc>-4]
# Energies are clamped into [E_min, E_max] before histogramming and before mean/variance.
E_min=-10.5
E_max=0
# x position (data coordinates) shared by all annotations.
loc=-2.5

#mean=E_loc_gtowf.detach().clamp(E_min, E_max).mean().item()
#c1="grey"
#h = plt.hist(E_loc_gtowf.detach().clamp(E_min, E_max).cpu().numpy(), bins=100,alpha = 0.8,color=c1,density=True)
#plt.annotate("mean = "+str(np.round(mean,4)),(loc,5.5),color=c1)
#plt.annotate("var     = "+str(np.round(np.var(E_loc_gtowf.cpu().detach().clamp(E_min, E_max).numpy()),4)),(loc,4.2),color=c1)


mean=E_loc.detach().clamp(E_min, E_max).mean().item()
c2="red"
h = plt.hist(E_loc.detach().clamp(E_min, E_max).cpu().numpy(), bins=100,alpha = 0.8,color=c2,density=True)
# Annotations are placed relative to the tallest bin of the histogram.
plt.annotate("mean = "+str(np.round(mean,4)),(loc,np.max(h[0])/2),color=c2)
plt.annotate("var     = "+str(np.round(np.var(E_loc.cpu().detach().clamp(E_min, E_max).numpy()),4)),(loc,np.max(h[0])/2-np.max(h[0])/15),color=c2)


plt.xlabel("$E_{loc}$ in $E_h$")
plt.ylabel("relative occurrence")
# LiH-specific file name; the previous 'hesinglet_Elochist.svg' silently overwrote the helium
# histogram written earlier in the notebook.
plt.savefig('lihsinglet_Elochist.svg')
plt.show()