In [None]:
# Notebook setup. %autoreload 1 reloads only the modules registered via
# %aimport (here the dlqmc.nn / sampling / utils modules) before each cell runs,
# so edits to them show up without restarting the kernel.
%load_ext autoreload
%autoreload 1
%aimport dlqmc.nn, dlqmc.sampling, dlqmc.utils
# Inline figures: SVG output, tightly cropped, 300 dpi.
%config InlineBackend.figure_format = 'svg' 
%config InlineBackend.print_figure_kwargs = \
    {'bbox_inches': 'tight', 'dpi': 300}

In [None]:
import ipywidgets

import numpy as np
from scipy import special
import scipy.stats as sps
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, RandomSampler
from torch.distributions import Normal
from pyscf import gto, scf, dft
import pyscf
from pyscf.data.nist import BOHR
import time
from functools import partial
from tqdm.auto import tqdm, trange

from dlqmc.gto import *
from dlqmc.nn import *
from dlqmc.sampling import langevin_monte_carlo, hmc ,samples_from, metropolis
from dlqmc.fit import loss_local_energy
from dlqmc.utils import (
    plot_func, get_flat_mesh, assign_where, plot_func_xy,
    plot_func_x, integrate_on_mesh, assign_where
)
from dlqmc.physics import (
    local_energy, grad, quantum_force,nuclear_potential,
    nuclear_energy, laplacian, electronic_potential
)
from dlqmc.geom import *
from dlqmc.analysis import autocorr_coeff, blocking
from dlqmc.nn import ssp

## Hartree–Fock (HF) wavefunction

In [None]:
# H2 in the aug-cc-pV5Z basis with both nuclei on the x axis at +/-0.742 bohr.
# spin=2 (2S, i.e. two unpaired electrons), presumably chosen to match the
# spatially antisymmetric network ansatz trained below.
# NOTE(review): with spin != 0 PySCF's scf.RHF presumably dispatches to ROHF;
# also R = 1.484 bohr here vs 1.4 for the `h2` Geometry used in training —
# confirm both are intended.
hydrogen_atoms = [['H', (x, 0, 0)] for x in (-0.742, 0.742)]
mol = gto.M(atom=hydrogen_atoms, unit='bohr', basis='aug-cc-pv5z', charge=0, spin=2)
mf = scf.RHF(mol)
mf.kernel()


In [None]:
#gtowf.get_aos(torch.randn(1, 3))

In [None]:
# Wrap the PySCF mean-field result as a callable wavefunction (dlqmc.gto).
# Below it is evaluated on tensors viewed as (batch, 2, 3): batch of
# configurations, 2 electrons, 3 Cartesian coordinates.
gtowf = PyscfGTOSlaterWF(mf)

In [None]:
# One panel per electron coordinate: vary that coordinate over [-5, 5] with all
# other coordinates held at zero, and plot the GTO wavefunction along the cut.
# Row 0 = electron 1 (x, y, z), row 1 = electron 2 (x, y, z).
cut_values = np.linspace(-5, 5, 500)
plt.figure(figsize=(12, 4))
for coord in range(6):
    row, col = divmod(coord, 3)
    plt.subplot2grid((2, 3), (row, col))
    x = torch.zeros(500, 6)
    x[:, coord] = torch.linspace(-5, 5, 500)
    plt.title("electron " + str(row + 1))
    plt.plot(cut_values, gtowf(x.view(-1, 2, 3)).numpy())
    plt.axhline(0, ls=':', color='k')

In [None]:
# Same axis cuts as for the GTO wavefunction, but for the neural-network
# wavefunction `net`. `net` is only created by the training cell further down,
# so on a fresh kernel this cell is skipped with a message. Other errors are
# no longer swallowed by a bare `except`.
try:
    net
except NameError:
    print("net is not defined yet; run the training cell first")
else:
    net.cuda()
    plt.figure(figsize=(12,4))
    for i in range(6):
        plt.subplot2grid((2,3),(i//3,i%3))
        x = torch.zeros(500, 6)
        x[:,i] = torch.linspace(-5, 5, 500)
        x = x.view(-1,2,3)
        plt.title("electron " +str(i//3+1))
        # Plotting only: no autograd graph is needed.
        with torch.no_grad():
            psi = net(x.cuda())
        plt.plot(np.linspace(-5, 5, 500), psi.cpu().numpy())
        plt.axhline(0,ls=':',color='k')

In [None]:
# 2D cut psi(x1, x2) over [-5, 5]^2: electron 1 sits at (x1, 1, 1) and
# electron 2 at (x2, 1, 1). Shown side by side for the GTO wavefunction and
# the network. Skipped with a message when `net` does not exist yet; other
# errors are no longer silently swallowed.
try:
    net
except NameError:
    print("net is not defined yet; run the training cell first")
else:
    n_grid = 500
    axis_values = np.linspace(-5, 5, n_grid)
    # G rows: (x1, x2) pairs; F appends y = z = 1 for both electrons;
    # H reorders to (x1, y1, z1, x2, y2, z2).
    G = np.array(np.meshgrid(axis_values, axis_values)).T.reshape(-1,2)
    F = np.append(G,np.ones((len(G),4)),axis=-1)
    H = np.append(F[:,[0,2,4]],F[:,[1,3,5]],axis=-1)
    W1 = gtowf(torch.from_numpy(H).view(-1,2,3)).view(n_grid,n_grid).numpy()
    # Evaluate on whatever device the network currently lives on. Later cells
    # move it between CPU and GPU; a hard-coded .cuda() failed after `net.cpu()`.
    device = next(net.parameters()).device
    with torch.no_grad():
        W2 = net(torch.from_numpy(H).view(-1,2,3).float().to(device)).view(n_grid,n_grid).cpu().numpy()
    levels=30
    plt.figure(figsize=(8,3))
    plt.subplot2grid((1,2),(0,0))
    plt.title("gtowf")
    plt.contourf(W1,levels)
    plt.colorbar()
    plt.subplot2grid((1,2),(0,1))
    plt.title("netwf")
    plt.contourf(W2,levels)
    plt.colorbar()
    plt.show()

## Neural-network wavefunction

In [None]:
#H2+     Energy = -0.6023424   for R = 1.9972
#fit(batch_size=10000, n_el=1, steps=500, epochs=1, RR=[[-1, 0, 0], [1., 0, 0]])

#H2		 Energy = -1.173427    for R = 1.40
#fit(batch_size=10000,n_el=2,steps=100,epochs=5,RR=torch.tensor([[-0.7,0,0],[0.7,0,0]]))

#He+	 Energy = -1.9998
#fit(batch_size=10000,n_el=1,steps=100,epochs=5,RR=torch.tensor([[0.,0,0]]),RR_charges=[2])

#He		 Energy = -2.90338583
#fit(batch_size=10000,n_el=2,steps=300,epochs=5,RR=torch.tensor([[0.3,0,0]]),RR_charges=[2])

In [None]:
# Two-centre test geometries on the x axis, both with unit nuclear charges:
# H2+ with the nuclei 2.0 apart and H2 with the nuclei 1.4 apart
# (units as expected by Geometry — presumably bohr, TODO confirm).
h2p, h2 = (
    Geometry([[-half, 0., 0.], [half, 0., 0.]], [1., 1.])
    for half in (1., 0.7)
)
print(h2p)

In [None]:
class Net_pair(nn.Module):
    """Small MLP applied to a pair of positions, giving 10 features per pair.

    ``forward(x1, x2)`` concatenates the inputs along the last axis; the
    concatenation must have 6 entries there (first layer is Linear(6, 10)).
    The result is reshaped to (-1, 10).

    ``dist_basis`` and ``geom`` are registered on the module but are not used
    by ``forward``.
    """

    def __init__(self, geom, n_dist_feats=32):
        super().__init__()
        # Kept although unused in forward, so the registered submodules (and any
        # parameters they carry) stay exactly as before.
        self.dist_basis = DistanceBasis(n_dist_feats)
        self.geom = geom.as_param_dict()
        pair_layers = [
            torch.nn.Linear(6, 10),
            SSP(),
            torch.nn.Linear(10, 10),
        ]
        self.NN1 = nn.Sequential(*pair_layers)

    def forward(self, x1, x2):
        joined = torch.cat([x1, x2], dim=-1)
        pair_features = self.NN1(joined)
        return pair_features.view(-1, 10)
    
class Net(nn.Module):
    """Small MLP 10 -> 10 -> 1 with an SSP activation and a final sigmoid.

    ``forward`` flattens the scalar outputs into a 1D tensor whose values lie
    in (0, 1).
    """

    def __init__(self):
        super().__init__()
        self.NN1 = nn.Sequential(
            *[torch.nn.Linear(10, 10), SSP(), torch.nn.Linear(10, 1)]
        )

    def forward(self, x):
        logits = self.NN1(x)
        return logits.flatten().sigmoid()

    


In [None]:
class WFNetAnti(nn.Module):
    """Neural-network wavefunction made of three multiplicative factors.

    For a batch of electron configurations ``rs``, ``forward`` returns

        nuc_asymp(electron-nucleus distances)
        * exp(deep_lin(distance features))
        * antisym(rs)

    The distance features are ``dist_basis`` expansions of all
    electron-nucleus distances followed by the electron-electron distances.

    Args:
        geom: geometry providing ``charges`` and ``as_param_dict()``; the
            latter's ``coords`` are used for electron-nucleus distances.
        n_electrons: number of electrons; fixes the MLP input width.
        net, net_pair: sub-networks handed to ``AntisymmetricPart``.
        ion_pot: forwarded to ``NuclearAsymptotic``.
        cutoff: accepted for interface compatibility; not used here.
        n_dist_feats: number of basis functions per distance.
        alpha: forwarded to ``NuclearAsymptotic``.
    """

    def __init__(
        self,
        geom,
        n_electrons,
        net,
        net_pair,
        ion_pot=0.5,
        cutoff=10.0,
        n_dist_feats=32,
        alpha=1.0,
    ):
        super().__init__()
        self.dist_basis = DistanceBasis(n_dist_feats)
        self.nuc_asymp = NuclearAsymptotic(geom.charges, ion_pot, alpha=alpha)
        self.geom = geom.as_param_dict()
        n_atoms = len(geom.charges)
        # One distance per electron-nucleus pair plus one per unordered
        # electron pair; each is expanded into n_dist_feats features.
        n_pairs = n_electrons * n_atoms + n_electrons * (n_electrons - 1) // 2
        width = 64
        mlp_layers = [nn.Linear(n_pairs * n_dist_feats, width), SSP()]
        for _ in range(3):
            mlp_layers.extend([nn.Linear(width, width), SSP()])
        mlp_layers.append(nn.Linear(width, 1))
        self.deep_lin = nn.Sequential(*mlp_layers)
        self.antisym = AntisymmetricPart(net, net_pair)
        self._pdist = PairwiseDistance3D()
        self._psdist = PairwiseSelfDistance3D()

    def _featurize(self, rs):
        """Return (flattened distance features, (nuclear dists, electron dists))."""
        nuc_dists = self._pdist(rs, self.geom.coords[None, ...])
        el_dists = self._psdist(rs)
        all_dists = torch.cat([nuc_dists.flatten(start_dim=1), el_dists], dim=1)
        features = self.dist_basis(all_dists).flatten(start_dim=1)
        return features, (nuc_dists, el_dists)

    def forward(self, rs):
        features, (nuc_dists, _) = self._featurize(rs)
        log_sym = self.deep_lin(features).squeeze(dim=1)
        return self.nuc_asymp(nuc_dists) * torch.exp(log_sym) * self.antisym(rs)

In [None]:
#x_line = torch.cat((torch.linspace(-3, 3, 500)[:, None], torch.zeros((500, 3*n_electrons-1))), dim=1)
#x_line=x_line.view(-1,n_electrons,3)
#Pnet= Net_pair(molecule)
#APnet = NetPairwiseAntisymmetry(Pnet)
#Onet = NetOdd(Net())
#Anet = AntisymmetricPart(Onet, Pnet)
#Anet(x_line)
#APnet(x_line[:,0],x_line[:,1])


In [None]:
# Build the antisymmetric network wavefunction for `molecule` and optimise it
# with `loss_local_energy` on samples drawn by Langevin Monte Carlo. Every
# `steps // n_resamplings` steps a fresh pool of samples is drawn; each step
# consumes one batch of `batchsize` configurations from that pool.
n_electrons=2
molecule = h2

Onet = Net()
Pnet = Net_pair(molecule)
net = WFNetAnti(molecule,n_electrons,Onet,Pnet,ion_pot=0.7).cuda()

L = []  # loss per optimisation step
V = []  # unweighted variance of the local energy per step

# Cut along electron 1's x coordinate (all other coordinates zero), used to
# snapshot psi^2 during training.
x_line = torch.cat((torch.linspace(-3, 3, 500)[:, None], torch.zeros((500, 3*n_electrons-1))), dim=1)
x_line=x_line.view(-1,n_electrons,3).cuda()

opt = torch.optim.Adam(net.parameters(), lr=1e-2)
t_start=time.time()
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.999)

steps = 2_000
batchsize = 10_000
n_resamplings = 100
n_walker = 1_000
# Optimisation steps served by one sample pool, and the number of sampler steps
# needed so that one pool (n_walker walkers per step) holds exactly that many batches.
steps_per_resampling = steps // n_resamplings
n_sampler_steps = int(batchsize*steps/(n_resamplings*n_walker))

sampler = langevin_monte_carlo(
    net,
    torch.randn(n_walker, n_electrons, 3, device='cuda'),
    tau=0.1,
)


def plot_psi2_snapshot(label):
    """Plot psi^2 of the current `net` along `x_line` with the given legend label."""
    with torch.no_grad():
        Psi2 = net(x_line)**2
        plt.plot(x_line[:,0 , 0].cpu().detach().numpy(), Psi2.cpu().detach().numpy(),label=label)


for i_step in range(steps):

    if i_step%(steps//4) == 0:
        plot_psi2_snapshot(i_step)

    if i_step%steps_per_resampling==0:
        print("resample                                                                        ",end="\r")
        rs,rs_psis  = samples_from(sampler,range(n_sampler_steps))[0:-1]
        rs = rs.flatten(end_dim=1).cuda()
        rs_psis = rs_psis.flatten(end_dim=1).cuda()

    # Slice of the current sample pool that forms this step's batch.
    i_batch = i_step % steps_per_resampling
    batch = slice(i_batch*batchsize, (i_batch+1)*batchsize)
    r = rs[batch]

    E_loc,psi = local_energy(r,net,net.geom,create_graph=True)
    # Reweight the batch from the sampling-time wavefunction to the current one;
    # rs_psis presumably holds psi at sampling time — TODO confirm in dlqmc.sampling.
    weights = psi**2/rs_psis[batch]**2

    # NOTE(review): the third argument of loss_local_energy is -1 for the first
    # fifth of training and None afterwards; its meaning lives in dlqmc.fit.
    if i_step<steps//5:
        loss = loss_local_energy(E_loc,weights,-1)
    else:
        loss = loss_local_energy(E_loc,weights,None)

    print("Progress {:2.0%}".format(i_step /steps)+"   ->"+"I"*(int(i_step/steps*100)//10)+"i"*(int(i_step/steps*100)%10)+"  "+"current loss = "+str(np.round(loss.item(),4))+"        ", end="\r")

    loss.backward()
    L.append(loss.cpu().detach().numpy())
    V.append(((E_loc**2-E_loc.mean()**2).mean()).cpu().detach().numpy())

    torch.nn.utils.clip_grad_norm_(net.parameters(),1000)

    opt.step()
    opt.zero_grad()
    # Decay the learning rate *after* the optimizer update (PyTorch >= 1.1 order).
    # Calling it first skipped the initial learning rate and triggers a warning.
    scheduler.step()

# Final snapshot of the trained state. The old in-loop condition
# `i_step == steps` could never be true inside range(steps).
plot_psi2_snapshot(steps)
plt.legend()
print("it took ="+str(np.round(time.time()-t_start,5))+"                    ")
    


In [None]:
def normplot(x,y,norm,*args,**kwargs):
    """Plot ``y`` against ``x``, optionally rescaled so that ``max |y| == 1``.

    Args:
        x, y: array-like data passed to ``plt.plot``.
        norm: if true, divide ``y`` by its largest absolute value. An empty or
            all-zero ``y`` is plotted unscaled instead of turning into NaNs.
        *args, **kwargs: forwarded to ``plt.plot`` (format string, label, ...).
    """
    y = np.asarray(y)
    if norm and y.size:
        scale = np.max(np.abs(y))
        # Guard against division by zero for an identically-zero curve.
        if scale > 0:
            y = y/scale
    plt.plot(x,y,*args,**kwargs)

# Factor-by-factor view of the trained wavefunction along electron 1's x axis
# (x in [-1, 1], all other coordinates zero). Each curve is scaled to
# max |y| = 1 when `normed` is true. This is followed by the loss history,
# split at steps // 5 where the training loss switches mode.
n_line = 5000
x_line = torch.cat((torch.linspace(-1, 1, n_line)[:, None], torch.zeros((n_line, 3*n_electrons-1))), dim=1)
x_line = x_line.view(-1, n_electrons, 3).cuda()
x_line.requires_grad = True
net.cuda()
f_line = net._featurize(x_line)
normed = True

x_axis = x_line[:, 0, 0].cpu().detach().numpy()
sym_part = torch.exp(net.deep_lin(f_line[0])).squeeze().cpu().detach().numpy()
anti_part = net.antisym(x_line).cpu().detach().numpy()
N = net.nuc_asymp(f_line[1][0]).cpu().detach().numpy()
wf_values = net(x_line).cpu().detach().numpy()
curves = [
    (sym_part, dict(label="sym")),
    (anti_part, dict(label="anti")),
    (N, dict(label="asym")),
    (-1*(N*x_axis), dict(label="asym*line")),
    (wf_values, dict(label="WF", lw=2, color='k')),
]
for curve, style in curves:
    normplot(x_axis, curve, norm=normed, **style)

plt.axhline(0, ls=':', color='k')
plt.axvline(0, ls=':', color='k')

x_line.requires_grad = False
plt.legend()
plt.savefig('lastrunwf.svg')
plt.show()

n_first_phase = steps // 5
for row, losses in enumerate((L[:n_first_phase], L[n_first_phase:])):
    plt.subplot2grid((2, 1), (row, 0))
    plt.plot(losses)
    plt.yscale('log')
plt.savefig('lastrunloss.svg')


In [None]:
#plt.plot(x_line[:,0 , 0].cpu().detach().numpy(),net.antisym.net_pair_anti(x_line[:,0],x_line[:,1]).cpu().detach().numpy())
#plt.show()

In [None]:
#tmp = net.antisym.net_pair_anti(torch.from_numpy(H[:,0:3]).type(torch.FloatTensor).cuda(),torch.from_numpy(H[:,3:]).type(torch.FloatTensor).cuda()).cpu().detach().numpy()[:,9].reshape(500,500)
#plt.contourf(tmp)
#plt.colorbar()

In [None]:
# Local energy of the trained network along x_line (electron 1 on the x axis).
line_x = x_line[:, 0, 0].detach().cpu().numpy()
E_line = local_energy(x_line, lambda x: net(x), net.geom)[0]
plt.plot(line_x, E_line.cpu().detach().numpy())
#plt.ylim((-10, 20));

In [None]:
# Draw samples from the trained network: 1000 walkers advanced by the chosen
# sampler, with range(1000) handed to samples_from. The collected positions are
# flattened into one batch of configurations. HMC is the active choice; set
# `use_hmc = False` to use Langevin Monte Carlo instead.
use_hmc = True
walkers_init = torch.randn(1000, n_electrons, 3, device='cuda')
if use_hmc:
    sampler = hmc(net, walkers_init, dysteps=3, stepsize=0.1)
else:
    sampler = langevin_monte_carlo(net, walkers_init, tau=0.1)
t = time.time()
samples = samples_from(sampler, range(1000))[0].flatten(end_dim=1)
print("it took: " + str(time.time() - t))

In [None]:
# 2D histogram of electron 1's (x, y) coordinates over all samples.
el1_x = samples[:, 0, 0].cpu().detach().numpy()
el1_y = samples[:, 0, 1].cpu().detach().numpy()
plt.hist2d(el1_x, el1_y, bins=100, range=[[-3, 3], [-3, 3]])
plt.gca().set_aspect(1)

In [None]:
# Move the network and the samples to the CPU for the local-energy analysis.
net, samples = net.cpu(), samples.cpu()

In [None]:
# Local energy of every sample (first element of local_energy's result).
sample_configs = samples.view([-1, n_electrons, 3])
E_loc = local_energy(sample_configs, lambda x: net(x), net.geom)[0]

In [None]:
# Local-energy outlier diagnostics: counts beyond +/-100, overall min/max, then
# psi^2 at the samples whose local energy exceeds 10.
e_loc_np = E_loc.detach().numpy()
for outlier_mask in (e_loc_np > 100, e_loc_np < -100):
    print(np.where(outlier_mask)[0].shape)
print(np.min(e_loc_np))
print(np.max(e_loc_np))
net(samples[np.where(e_loc_np > 10)])**2

In [None]:
# Histogram of the local energies of `samples`, annotated with their mean and
# variance. Only the histogram is clamped to [-1.5, 1] (outliers pile up in the
# edge bins); the statistics use the unclamped values.
mean=E_loc.mean().item()
var = ((E_loc-mean)**2).mean().item()

plt.hist(E_loc.detach().clamp(-1.5, 1).cpu().numpy(), bins=100)
# Anchor the text in axes-fraction coordinates. The former data coordinates
# (0, 80000)/(0, 72000) depended on the bin counts; when the tallest bin was
# lower, the points fell outside the axes and the annotations were clipped.
plt.annotate("mean = "+str(np.round(mean,4)),(0.6, 0.92),xycoords='axes fraction')
plt.annotate("var     = "+str(np.round(var,4)),(0.6, 0.84),xycoords='axes fraction')
plt.savefig('lastruneloc.svg')
plt.show()

In [None]:
#del samples

In [None]:
# Move the network back to the (current) CUDA device.
net.to('cuda')