In [None]:
# Reload imported modules automatically before each cell executes (autoreload mode 2).
%load_ext autoreload
%autoreload 2
# Inline figures are emitted as SVG; the kwargs below are forwarded when figures are printed.
%config InlineBackend.figure_format = 'svg' 
%config InlineBackend.print_figure_kwargs = {'bbox_inches': 'tight', 'dpi': 300}
%matplotlib inline

In [None]:
import time
from functools import partial

import numpy as np
import matplotlib.pyplot as plt
from scipy import special
import torch
from torch import nn
import pyscf
from pyscf import gto, scf
from tqdm.auto import tqdm, trange
from tensorboardX import SummaryWriter

%aimport dlqmc.nn.base
from dlqmc.nn import WFNet, HanNet, DistanceBasis
from dlqmc.fit import fit_wfnet_multi, loss_local_energy, wfnet_fit_driver
from dlqmc.sampling import samples_from, langevin_monte_carlo
from dlqmc.physics import local_energy
from dlqmc.gto import TorchGTOSlaterWF, PyscfGTOSlaterWF, electron_density_of
from dlqmc.analysis import autocorr_coeff, blocking
from dlqmc.geom import Geometry, angstrom, geomdb
from dlqmc.utils import plot_func, plot_func_xy, plot_func_x, integrate_on_mesh, batch_eval_tuple
from dlqmc.stats import GaussianKDEstimator
import dlqmc.torchext as torchext

## H2+

### GTO WF

In [None]:
# H2+ in the large aug-cc-pV5Z Cartesian basis: geometry from geomdb (bohr),
# total charge +1, spin=1 (one unpaired electron).
mol = gto.M(
    atom=geomdb['H2+'].as_pyscf(),
    unit='bohr',
    basis='aug-cc-pv5z',
    cart=True,
    charge=1,
    spin=1,
)
mf = scf.RHF(mol)
# kernel() runs the SCF and returns the energy; kept as the large-basis reference.
scf_energy_big = mf.kernel()
# Wrap the mean-field result in the project's TorchGTOSlaterWF wave-function object.
gtowf_big = TorchGTOSlaterWF(mf)

In [None]:
# Same H2+ setup as the large-basis cell, but with the small 6-311G basis.
# NOTE: `mol` and `mf` are rebound here; only `gtowf_big`/`scf_energy_big`
# keep the large-basis results.
mol = gto.M(
    atom=geomdb['H2+'].as_pyscf(),
    unit='bohr',
    basis='6-311g',
    cart=True,
    charge=1,
    spin=1,
)
mf = scf.RHF(mol)
# Small-basis SCF energy, used as the reference for the small-basis sampling below.
scf_energy = mf.kernel()
gtowf = TorchGTOSlaterWF(mf)

In [None]:
# Plot the small-basis atomic orbitals over [-7, 7]
# (presumably along the x axis -- confirm against plot_func_x).
plot_func_x(gtowf.get_aos, [-7, 7]);
plt.ylim(-1, 1)

In [None]:
# Same atomic-orbital plot for the large-basis wave function.
plot_func_x(gtowf_big.get_aos, [-7, 7]);
plt.ylim(-1, 1)

In [None]:
# Integrate |psi|^2 of the small-basis WF over the box [-6, 6] x [-4, 4] x [-4, 4];
# presumably a normalization check (expected value ~1 -- TODO confirm).
integrate_on_mesh(lambda x: gtowf(x.cuda()[:, None])**2, [(-6, 6), (-4, 4), (-4, 4)])

In [None]:
# Same |psi|^2 box integral for the large-basis WF.
integrate_on_mesh(lambda x: gtowf_big(x.cuda()[:, None])**2, [(-6, 6), (-4, 4), (-4, 4)])

In [None]:
# Local energy of both GTO wave functions over [-3, 3]; `[0]` takes the first
# element of the tuple returned by local_energy.
# NOTE(review): the two curves use different `shift` values (1e-4 vs 1e-10);
# the meaning of `shift` is defined in plot_func_x -- confirm this is intended.
plot_func_x(lambda x: local_energy(x[:, None], gtowf_big, geomdb['H2+'])[0], [-3, 3], shift=1e-4)
plot_func_x(lambda x: local_energy(x[:, None], gtowf, geomdb['H2+'])[0], [-3, 3], shift=1e-10)
plt.ylim((-10, 0));

In [None]:
# Sample the small-basis WF with Langevin Monte Carlo on the CPU:
# 1000 walkers, one electron each, started from a standard normal.
n_walker = 1_000
sampler = langevin_monte_carlo(
    gtowf,
    torch.randn(n_walker, 1, 3),
    tau=0.1,
)
# Draw 500 sampler iterations (with a progress bar).
rs, psis, info = samples_from(sampler, trange(500))
# Evaluate the local energy on all samples at once, then restore one row per walker.
# assumes rs is laid out as (walker, step, electron, xyz) -- TODO confirm in samples_from
E_loc = local_energy(rs.flatten(end_dim=1), gtowf, geomdb['H2+'])[0].view(n_walker, -1)
# Displayed cell output: mean acceptance reported by the sampler.
info.acceptance.mean()

In [None]:
# Path of the first entry of `rs` (presumably walker 0): first 50 positions of
# electron 0, projected onto the first two coordinates.
plt.plot(*rs[0][:50, 0, :2].numpy().T)
plt.gca().set_aspect(1)

In [None]:
# Local energy averaged over dim 0 (one row per walker), plotted against the remaining axis.
plt.plot(E_loc.mean(dim=0).numpy())
plt.ylim(-0.7, -0.5)

In [None]:
# 2D histogram of the first two coordinates of electron 0, skipping the first
# 50 entries along dim 1 (presumably equilibration steps).
plt.hist2d(
    *rs[:, 50:].flatten(end_dim=1)[:, 0, :2].numpy().T,
    bins=100,
    range=[[-3, 3], [-3, 3]],
)                                   
plt.gca().set_aspect(1)

In [None]:
# Histogram of local energies after the first 50 steps; values are clamped to
# [-1.25, 0], so outliers pile up in the edge bins.
plt.hist(E_loc[:, 50:].flatten().clamp(-1.25, 0).numpy(), bins=100);

In [None]:
# Standard deviation of the local energy over all walkers and post-discard steps.
E_loc[:, 50:].std()

In [None]:
# Summary tuple: (SCF reference energy, mean local energy after discarding the
# first 50 steps, spread of the per-walker means divided by sqrt(number of walkers)).
(
    scf_energy,
    E_loc[:, 50:].mean().item(),
    (E_loc[:, 50:].mean(dim=1).std() / np.sqrt(E_loc.shape[0])).item(),
)

In [None]:
# Plot the result of the project's `blocking` analysis on the post-discard local energies.
plt.plot(blocking(E_loc[:, 50:]).numpy())

In [None]:
# Autocorrelation coefficients of the local energy for the values in range(50)
# (presumably lags in sampling steps); axhline() draws the y=0 reference line.
plt.plot(autocorr_coeff(range(50), E_loc[:, 50:]).numpy())
plt.axhline()

In [None]:
# Sample the large-basis WF on the GPU in double precision.
# NOTE: this rebinds n_walker, sampler, rs, psis, info and E_loc from the small-basis run.
n_walker = 1_000
sampler = langevin_monte_carlo(
    gtowf_big,
    torch.randn(n_walker, 1, 3).double().cuda(),
    tau=0.1,
)
rs, psis, info = samples_from(sampler, trange(500))
# Evaluate the local energy in chunks of 50,000 samples (presumably to bound
# GPU memory), then restore one row per walker.
E_loc = batch_eval_tuple(
    local_energy,
    tqdm(rs.flatten(end_dim=1).split(50_000)),
    gtowf_big,
    geomdb['H2+'].as_param_dict().double().cuda(),
)[0].view(n_walker, -1)
# Displayed cell output: mean acceptance reported by the sampler.
info.acceptance.mean()

In [None]:
# Walker-averaged local energy along the remaining axis; moved to the CPU before
# the numpy conversion.
plt.plot(E_loc.mean(dim=0).cpu().numpy())
plt.ylim(-0.7, -0.5)

In [None]:
# 2D histogram of the first two coordinates of electron 0 for the large-basis
# samples, skipping the first 50 entries along dim 1.
plt.hist2d(
    *rs[:, 50:].flatten(end_dim=1)[:, 0, :2].cpu().numpy().T,
    bins=100,
    range=[[-3, 3], [-3, 3]],
)                                   
plt.gca().set_aspect(1)

In [None]:
# Histogram of large-basis local energies, clamped to [-1.25, 0].
plt.hist(E_loc[:, 50:].flatten().clamp(-1.25, 0).cpu().numpy(), bins=100);

In [None]:
# Standard deviation of the large-basis local energy after the first 50 steps.
E_loc[:, 50:].std()

In [None]:
# Summary tuple: (large-basis SCF energy, mean local energy after the first 50
# steps, spread of the per-walker means divided by sqrt(number of walkers)).
(
    scf_energy_big,
    E_loc[:, 50:].mean().item(),
    (E_loc[:, 50:].mean(dim=1).std()/np.sqrt(E_loc.shape[0])).item()
)

In [None]:
# Plot the result of the project's `blocking` analysis on the post-discard local energies.
# E_loc was produced from CUDA walkers in this section, so the result is moved
# to the host first: Tensor.numpy() does not accept CUDA tensors. `.cpu()` is a
# no-op if the tensor is already on the CPU.
plt.plot(blocking(E_loc[:, 50:]).cpu().numpy())

In [None]:
# Autocorrelation coefficients of the local energy for the values in range(50).
# E_loc was produced from CUDA walkers in this section, so the result is moved
# to the host before the numpy conversion (no-op if it is already on the CPU).
plt.plot(autocorr_coeff(range(50), E_loc[:, 50:]).cpu().numpy())
# Reference line at y=0.
plt.axhline()

### Net WF

In [None]:
# Plot a DistanceBasis constructed with argument 32 over the interval [1, 11]
# (presumably 32 basis functions of the distance -- confirm in dlqmc.nn).
plot_func(DistanceBasis(32), [1, 11]);

In [None]:
# Neural-network wave function for H2+ on the GPU; the positional `1` is
# presumably the electron count -- confirm against WFNet's signature.
wfnet = WFNet(geomdb['H2+'], 1, n_orbital_layers=4, ion_pot=0.7).cuda()
# Sampler with 1000 CUDA walkers; it is handed to the fit driver in the next cell.
sampler = langevin_monte_carlo(
    wfnet,
    torch.randn(1_000, 1, 3, device='cuda'),
    tau=0.1,
)
# Displayed cell output: the module repr.
wfnet

In [None]:
# Two-stage fit. Each tuple/generator below yields one item per stage
# (presumably consumed in lockstep by fit_wfnet_multi -- confirm in dlqmc.fit):
#   stage 1 ('pretrain'): loss_local_energy with fixed E_ref=-0.5, a single sampling round
#   stage 2 ('variance'): loss_local_energy without E_ref, 10 sampling rounds
fit_wfnet_multi(
    wfnet,
    (partial(loss_local_energy, E_ref=-0.5), loss_local_energy),
    # A fresh Adam optimizer is created lazily for each stage.
    (torch.optim.Adam(wfnet.parameters(), lr=2e-3) for _ in range(2)),
    # Driver settings shared by both stages.
    partial(
        wfnet_fit_driver,
        sampler,
        n_epochs=1,
        n_sampling_steps=550,
        batch_size=1_000,
        n_discard=50,
        range_sampling=partial(trange, desc='sampling steps', leave=False),
        range_training=partial(trange, desc='training steps', leave=False),
    ),
    # Per-stage keyword overrides.
    ({'samplings': range(1)}, {'samplings': trange(10, desc='samplings')}),
    # One TensorBoard writer per stage, created lazily.
    (SummaryWriter(f'runs/wfnet/09/{s}') for s in ['pretrain', 'variance']),
)

In [None]:
# Local energy of the fitted network over [-15, 15], evaluated on the GPU.
plot_func_x(
    lambda x: local_energy(x[:, None], wfnet, wfnet.geom)[0],
    [-15, 15],
    device='cuda'
)
plt.ylim((-1, 0));

In [None]:
# 2D plot over [-10, 10]^2 of the network's 'jastrow' debug output.
plot_func_xy(
    lambda x: wfnet.debug('jastrow', x[:, None]),
    [[-10, 10], [-10, 10]],
    device='cuda'
);

In [None]:
# Sample the fitted network with a fresh set of 1000 CUDA walkers.
# NOTE: rebinds n_walker, sampler, rs, info and E_loc from the GTO runs above.
n_walker = 1_000
sampler = langevin_monte_carlo(
    wfnet,
    torch.randn(n_walker, 1, 3, device='cuda'),
    tau=0.1,
)
rs, _, info = samples_from(sampler, trange(500))
# Local energy on all samples, reshaped back to one row per walker.
E_loc = local_energy(rs.flatten(end_dim=1), wfnet, wfnet.geom)[0].view(n_walker, -1)
# Displayed cell output: mean acceptance reported by the sampler.
info.acceptance.mean()

In [None]:
# Walker-averaged local energy of the network WF along the remaining axis.
plt.plot(E_loc.mean(dim=0).cpu().numpy())
plt.ylim(-0.7, -0.5)

In [None]:
# Histogram of the network's local energies, clamped to [-1.25, 0]; the x range
# is pinned to the clamp interval.
plt.hist(E_loc[:, 50:].flatten().clamp(-1.25, 0).cpu().numpy(), bins=100);
plt.xlim(-1.25, 0)

In [None]:
# Standard deviation of the network's local energy after the first 50 steps.
E_loc[:, 50:].std()

In [None]:
# Summary tuple: (large-basis SCF energy, mean local energy of the network after
# discarding the first 50 steps, spread of the per-walker means divided by
# sqrt(number of walkers)).
(
    scf_energy_big,
    E_loc[:, 50:].mean().item(),
    (E_loc[:, 50:].mean(dim=1).std() / np.sqrt(E_loc.shape[0])).item(),
)

In [None]:
# Compare log-wave-functions on [-2, 2]: large-basis GTO, small-basis GTO, the
# network, and two of the network's debug components.
# NOTE(review): the constant offsets (-0.2, -0.89) look like manual vertical
# shifts to line the curves up -- confirm they are only cosmetic.
bounds = [-2, 2]
plot_func_x(lambda x: torch.log(gtowf_big(x[:, None])), bounds, label='~exact WF');
plot_func_x(lambda x: torch.log(gtowf(x[:, None])), bounds, label='small-basis WF');
plot_func_x(lambda x: torch.log(wfnet(x[:, None]))-0.2, bounds, device='cuda', label='DL WF');
plot_func_x(lambda x: torch.log(wfnet.debug('asymp_nuc', x[:, None]))-0.89, bounds, device='cuda', label='asymptotics');
# The 'jastrow' output is plotted without a log, unlike the curves above.
plot_func_x(lambda x: wfnet.debug('jastrow', x[:, None])-0.2, bounds, device='cuda', label='NN');
# Legend placed below the axes in two columns.
plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.35), ncol=2)
plt.ylim(-1.5, None)

## H2

### GTO WF

In [None]:
# Neutral closed-shell H2 (charge 0, spin 0) in the large aug-cc-pVQZ basis.
# NOTE: rebinds `mol` and `gtowf_big` from the H2+ section; unlike the H2+ cells,
# `cart=True` is not passed and the SCF energy is not stored.
mol = gto.M(
    atom=geomdb['H2'].as_pyscf(),
    unit='bohr',
    basis='aug-cc-pvqz',
    charge=0,
    spin=0,
)
mf_big = scf.RHF(mol)
mf_big.kernel()
# The large-basis WF uses the PySCF-backed wrapper here, not the torch one.
gtowf_big = PyscfGTOSlaterWF(mf_big)

In [None]:
# Neutral H2 in the small 6-311G basis; rebinds mol, mf, scf_energy and gtowf
# from the H2+ section.
mol = gto.M(
    atom=geomdb['H2'].as_pyscf(),
    unit='bohr',
    basis='6-311g',
    charge=0,
    spin=0,
)
mf = scf.RHF(mol)
# Small-basis SCF energy, used as the reference in the H2 GTO summary below.
scf_energy = mf.kernel()
gtowf = TorchGTOSlaterWF(mf)

In [None]:
# Sample the small-basis H2 WF on the GPU: 1000 walkers with two electrons each.
n_walker = 1_000
sampler = langevin_monte_carlo(
    gtowf,
    torch.randn(n_walker, 2, 3).cuda(),
    tau=0.1,
)
rs, _, info = samples_from(sampler, trange(500))
# Local energy on all samples (geometry parameters moved to the GPU), reshaped
# back to one row per walker.
E_loc = local_energy(rs.flatten(end_dim=1), gtowf, geomdb['H2'].as_param_dict().cuda())[0].view(n_walker, -1)
# Displayed cell output: mean acceptance reported by the sampler.
info.acceptance.mean()

In [None]:
# Walker-averaged local energy of the H2 GTO WF along the remaining axis.
plt.plot(E_loc.mean(dim=0).cpu().numpy())
plt.ylim(-1.2, -1)

In [None]:
# Histogram of H2 local energies, clamped to [-2, 0]. This section discards the
# first 100 steps rather than the 50 used in the H2+ section.
plt.hist(E_loc[:, 100:].flatten().clamp(-2, 0).cpu().numpy(), bins=100);

In [None]:
# Standard deviation of the H2 local energy after the first 100 steps.
E_loc[:, 100:].std()

In [None]:
# Summary tuple: (small-basis SCF energy, mean local energy after discarding the
# first 100 steps, spread of the per-walker means divided by sqrt(number of walkers)).
(
    scf_energy,
    E_loc[:, 100:].mean().item(),
    (E_loc[:, 100:].mean(dim=1).std() / np.sqrt(E_loc.shape[0])).item(),
)

In [None]:
# Plot the result of the project's `blocking` analysis on the post-discard H2 local energies.
# E_loc was produced from CUDA walkers in this section, so the result is moved
# to the host first: Tensor.numpy() does not accept CUDA tensors. `.cpu()` is a
# no-op if the tensor is already on the CPU.
plt.plot(blocking(E_loc[:, 100:]).cpu().numpy())

In [None]:
# Autocorrelation coefficients of the H2 local energy for the values in range(50).
# E_loc was produced from CUDA walkers in this section, so the result is moved
# to the host before the numpy conversion (no-op if it is already on the CPU).
plt.plot(autocorr_coeff(range(50), E_loc[:, 100:]).cpu().numpy())
# Reference line at y=0.
plt.axhline()

### Net WF

In [None]:
# Baseline: a freshly constructed WFNet for H2 that is sampled without any
# fitting. The `0`-suffixed names are kept for comparison with the fitted
# network further down.
wfnet0 = WFNet(geomdb['H2'], 2, n_orbital_layers=4, ion_pot=0.7).cuda()
n_walker = 10_000
sampler0 = langevin_monte_carlo(
    wfnet0,
    torch.randn(n_walker, 2, 3, device='cuda'),
    tau=0.1,
)
rs0, _, info = samples_from(sampler0, trange(500))
# Local energy evaluated in chunks of 50,000 samples, reshaped to one row per walker.
E_loc0 = batch_eval_tuple(
    local_energy,
    tqdm(rs0.flatten(end_dim=1).split(50_000)),
    wfnet0,
    wfnet0.geom
)[0].view(n_walker, -1)
# Displayed cell output: mean acceptance reported by the sampler.
info.acceptance.mean()

In [None]:
# The network that is actually fitted for H2; rebinds `wfnet` and `sampler`
# from the H2+ section. The sampler's 1000 CUDA walkers feed the fit driver in
# the next cell.
wfnet = WFNet(geomdb['H2'], 2, n_orbital_layers=4, ion_pot=0.7).cuda()
sampler = langevin_monte_carlo(
    wfnet,
    torch.randn(1_000, 2, 3, device='cuda'),
    tau=0.1,
)
# Displayed cell output: the module repr.
wfnet

In [None]:
# Two-stage fit for H2. Each tuple/generator below yields one item per stage
# (presumably consumed in lockstep by fit_wfnet_multi -- confirm in dlqmc.fit):
#   stage 1 ('pretrain'): loss with fixed E_ref=-1.1, 3 sampling rounds, batch size 5000
#   stage 2 ('variance'): loss without E_ref, 10 sampling rounds, batch size 2000
fit_wfnet_multi(
    wfnet,
    (partial(loss_local_energy, E_ref=-1.1), loss_local_energy),
    # A fresh Adam optimizer is created lazily for each stage.
    (torch.optim.Adam(wfnet.parameters(), lr=2e-3) for _ in range(2)),
    # Driver settings shared by both stages; batch_size is supplied per stage below.
    partial(
        wfnet_fit_driver,
        sampler,
        n_epochs=1,
        n_sampling_steps=300,
        n_discard=50,
        range_sampling=partial(trange, desc='sampling steps', leave=False),
        range_training=partial(trange, desc='training steps', leave=False),
    ),
    # Per-stage keyword overrides.
    (
        {'samplings': trange(3, desc='samplings'), 'batch_size': 5_000},
        {'samplings': trange(10, desc='samplings'), 'batch_size': 2_000},
    ),
    # One TensorBoard writer per stage, created lazily.
    (SummaryWriter(f'runs/H2/wfnet/10/{s}') for s in ['pretrain', 'variance']),
)

In [None]:
# Sample the fitted H2 network with the same walker count and step count as the
# untrained baseline, so the two runs have matching shapes.
n_walker = 10_000
sampler = langevin_monte_carlo(
    wfnet,
    torch.randn(n_walker, 2, 3, device='cuda'),
    tau=0.1,
)
rs, _, info = samples_from(sampler, trange(500))
# Local energy evaluated in chunks of 50,000 samples, reshaped to one row per walker.
E_loc = batch_eval_tuple(
    local_energy,
    tqdm(rs.flatten(end_dim=1).split(50_000)),
    wfnet,
    wfnet.geom
)[0].view(n_walker, -1)
# Displayed cell output: mean acceptance reported by the sampler.
info.acceptance.mean()

In [None]:
# Walker-averaged local energy: untrained network (first curve) vs fitted network (second).
plt.plot(E_loc0.mean(dim=0).cpu().numpy())
plt.plot(E_loc.mean(dim=0).cpu().numpy())
plt.ylim(-1.2, -1)

In [None]:
# Overlaid local-energy histograms (clamped to [-2, 0]): untrained first, fitted second.
plt.hist(E_loc0[:, 50:].flatten().clamp(-2, 0).cpu().numpy(), bins=100);
plt.hist(E_loc[:, 50:].flatten().clamp(-2, 0).cpu().numpy(), bins=100);
plt.xlim(-2, 0)

In [None]:
# Standard deviation of the fitted network's local energy after the first 50 steps.
E_loc[:, 50:].std()

In [None]:
# Summary tuple for the fitted H2 network: (mean local energy after discarding
# the first 50 steps, spread of the per-walker means divided by sqrt(number of walkers)).
(
    E_loc[:, 50:].mean().item(),
    (E_loc[:, 50:].mean(dim=1).std() / np.sqrt(E_loc.shape[0])).item(),
)

In [None]:
# Drop the first 50 entries along dim 1 and merge the first two axes, for the
# fitted (rs) and untrained (rs0) runs.
# assumes the result is (sample, electron, xyz) -- TODO confirm in samples_from
rs_flat = rs[:, 50:].flatten(end_dim=1)
rs_flat0 = rs0[:, 50:].flatten(end_dim=1)

In [None]:
# Kernel-density estimates built from all electron positions: the second
# flatten merges the electron axis into the sample axis.
# NOTE(review): weights=2 presumably scales the estimate to the two-electron density -- confirm.
density = GaussianKDEstimator(rs_flat.flatten(end_dim=1), bw=0.05, weights=2)
density0 = GaussianKDEstimator(rs_flat0.flatten(end_dim=1), bw=0.05, weights=2)
# Untrained network, fitted network, then the density from the large-basis
# mean-field object (evaluated outside torch, hence is_torch=False).
plot_func_x(density0, [-2, 3], device='cuda');
plot_func_x(density, [-2, 3], device='cuda');
plot_func_x(lambda x: electron_density_of(mf_big, x), [-2, 3], is_torch=False)

In [None]:
# Electron-electron distances for every sample of the fitted network (`samples`)
# and of the untrained baseline (`samples0`).
samples = (rs_flat[:, 0] - rs_flat[:, 1]).norm(dim=-1)
samples0 = (rs_flat0[:, 0] - rs_flat0[:, 1]).norm(dim=-1)
# 1D kernel-density estimates over the distance, each weighted by 1/r^2 of its
# own samples.
radial_pair = GaussianKDEstimator(
    samples[:, None],
    bw=0.01,
    weights=1/samples**2
)
# Fix: the baseline estimator must be weighted with its own distances
# (`samples0`); it previously reused the fitted network's `samples`.
radial_pair0 = GaussianKDEstimator(
    samples0[:, None],
    bw=0.01,
    weights=1/samples0**2
)

In [None]:
# Plot both radial pair estimates on (0, 5): untrained baseline first, fitted network second.
plot_func(lambda r: radial_pair0(r[:, None]), (0, 5), device='cuda');
plot_func(lambda r: radial_pair(r[:, None]), (0, 5), device='cuda');

### HanNet

In [None]:
# Construct a HanNet for H2 with default hyperparameters and display its repr.
# NOTE: rebinds `wfnet`; unlike the WFNet cells above, the model is not moved to the GPU.
# NOTE(review): the two positional `1`s are presumably the spin-up and spin-down
# electron counts -- confirm against HanNet's signature.
wfnet = HanNet(
    geomdb['H2'],
    1,
    1,
    ion_pot=0.7,
    # Optional size overrides, left disabled:
    # basis_dim=2,
    # kernel_dim=1,
    # embedding_dim=1,
    # latent_dim=1,
    # n_interactions=1,
    # n_orbital_layers=1,
)
wfnet