# DL-QMC jh

In [None]:
# IPython magics: enable the autoreload extension in mode 2 so that imported
# modules are reloaded before each cell execution.
%load_ext autoreload
%autoreload 2

In [None]:
from functools import partial

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
from pyscf import gto, scf
from tensorboardX import SummaryWriter
from torch import nn
from tqdm.auto import tqdm, trange

from deepqmc.analysis import autocorr_coeff, blocking, pair_correlations_from_samples
from deepqmc.fit import (
    fit_wfnet,
    fit_wfnet_multi,
    loss_local_energy,
    wfnet_fit_driver,
    wfnet_fit_driver_simple,
)
from deepqmc.geom import geomdb
from deepqmc.nn import (
    SSP,
    DistanceBasis,
    GTOBasis,
    GTOShell,
    HanNet,
    HFNet,
    WFNet,
    get_custom_dnn,
    get_log_dnn,
    pairwise_diffs,
    pairwise_distance,
    pairwise_self_distance,
)
from deepqmc.physics import local_energy
from deepqmc.pyscfext import electron_density_of
from deepqmc.sampling import langevin_monte_carlo, samples_from
from deepqmc.stats import GaussianKDEstimator
from deepqmc.utils import (
    DebugContainer,
    batch_eval_tuple,
    number_of_parameters,
    plot_func,
    plot_func_x,
    plot_func_xy,
    shuffle_tensor,
)

In [None]:
# needs to be in a separate cell, see https://github.com/ipython/ipython/issues/11098
# Set the default figure resolution for all plots in this notebook.
mpl.rcParams['figure.dpi'] = 100

## H2+

### GTO WF

In [None]:
# H2+ reference: SCF in the cc-pV5Z basis (charge +1, one unpaired electron),
# wrapped as an HFNet without cusp correction and moved to the GPU.
mol = gto.M(
    atom=geomdb['H2+'].as_pyscf(),
    unit='bohr',
    basis='cc-pv5z',
    cart=True,
    charge=1,
    spin=1,
)
mf = scf.RHF(mol)
# the value returned by kernel() is kept as the large-basis reference energy
scf_energy_big = mf.kernel()
gtowf_big = HFNet.from_pyscf(mf, cusp_correction=False).cuda()

In [None]:
# Same H2+ setup in the smaller 6-311G basis; note that `mol` and `mf` are
# rebound here, so they no longer refer to the cc-pV5Z calculation.
mol = gto.M(
    atom=geomdb['H2+'].as_pyscf(),
    unit='bohr',
    basis='6-311g',
    cart=True,
    charge=1,
    spin=1,
)
mf = scf.RHF(mol)
scf_energy = mf.kernel()
gtowf = HFNet.from_pyscf(mf, cusp_correction=False).cuda()

In [None]:
# Plot the small-basis GTO basis functions, evaluated on the differences
# between the sample points and the nuclear coordinates, over [-7, 7].
plot_func_x(lambda x: gtowf.basis(pairwise_diffs(x, gtowf.coords)), [-7, 7], device='cuda')
plt.ylim(-1, 1)

In [None]:
# Same plot for the large-basis wave function.
plot_func_x(lambda x: gtowf_big.basis(pairwise_diffs(x, gtowf_big.coords)), [-7, 7], device='cuda')
plt.ylim(-1, 1)

In [None]:
# Overlay the orbitals of the large-basis and small-basis wave functions.
plot_func_x(gtowf_big.orbitals, [-7, 7], device='cuda')
plot_func_x(gtowf.orbitals, [-7, 7], device='cuda')

In [None]:
# Local energy of both GTO wave functions over [-3, 3]; x[:, None] inserts an
# extra axis after the batch dimension (presumably the electron axis), and
# only the first element of the tuple returned by local_energy is plotted.
plot_func_x(lambda x: local_energy(x[:, None], gtowf_big)[0], [-3, 3], device='cuda')
plot_func_x(lambda x: local_energy(x[:, None], gtowf)[0], [-3, 3], device='cuda')
plt.ylim((-10, 0))

In [None]:
# Langevin Monte Carlo on the small-basis wave function: 1000 walkers started
# from standard-normal positions of shape (n_walker, 1, 3), 500 iterations.
n_walker = 1_000
sampler = langevin_monte_carlo(gtowf, torch.randn(n_walker, 1, 3).cuda(), tau=0.1)
rs, psis, info = samples_from(sampler, trange(500))
# merge the two leading dims of rs for evaluation, then regroup one row per walker
E_loc = local_energy(rs.flatten(end_dim=1), gtowf)[0].view(n_walker, -1)
info.acceptance.mean()

In [None]:
# Trace of the first walker: first 50 entries, particle 0, first two coordinates.
plt.plot(*rs[0][:50, 0, :2].cpu().numpy().T)
plt.gca().set_aspect(1)

In [None]:
# Walker-averaged local energy as a function of the second (step) index.
plt.plot(E_loc.mean(dim=0).cpu().numpy())
plt.ylim(-0.7, -0.5)

In [None]:
# 2D histogram of the first two coordinates of particle 0, skipping the first
# 50 entries along dim 1 (presumably equilibration steps).
plt.hist2d(
    *rs[:, 50:].flatten(end_dim=1)[:, 0, :2].cpu().numpy().T,
    bins=100,
    range=[[-3, 3], [-3, 3]],
)
plt.gca().set_aspect(1)

In [None]:
# Histogram of local energies after the first 50 steps, clamped to [-1.25, 0]
# so that outliers pile up in the edge bins instead of stretching the axis.
_ = plt.hist(E_loc[:, 50:].flatten().clamp(-1.25, 0).cpu().numpy(), bins=100)

In [None]:
# Standard deviation of all local energies after the first 50 steps.
E_loc[:, 50:].std()

In [None]:
# (SCF energy, mean local energy, error estimate); the error is the standard
# deviation of the per-walker means divided by sqrt(number of walkers).
scf_energy, E_loc[:, 50:].mean().item(), (
    E_loc[:, 50:].mean(dim=1).std() / np.sqrt(E_loc.shape[0])
).item()

In [None]:
# Plot the output of the blocking analysis of the post-warm-up local energies.
plt.plot(blocking(E_loc[:, 50:]).cpu().numpy())

In [None]:
# Autocorrelation coefficients of the local energy for lags 0..49, with a
# horizontal line at y=0 for reference.
plt.plot(autocorr_coeff(range(50), E_loc[:, 50:]).cpu().numpy())
plt.axhline()

In [None]:
# Repeat the sampling with the large-basis wave function; the local energy is
# evaluated through batch_eval_tuple on chunks of 50,000 samples.
n_walker = 1_000
sampler = langevin_monte_carlo(gtowf_big, torch.randn(n_walker, 1, 3).cuda(), tau=0.1)
rs, psis, info = samples_from(sampler, trange(500))
E_loc = batch_eval_tuple(
    local_energy, tqdm(rs.flatten(end_dim=1).split(50_000)), gtowf_big
)[0].view(n_walker, -1)
info.acceptance.mean()

In [None]:
# Walker-averaged local energy per step for the large-basis run.
plt.plot(E_loc.mean(dim=0).cpu().numpy())
plt.ylim(-0.7, -0.5)

In [None]:
# 2D histogram of the first two coordinates of particle 0 after the first 50 steps.
plt.hist2d(
    *rs[:, 50:].flatten(end_dim=1)[:, 0, :2].cpu().numpy().T,
    bins=100,
    range=[[-3, 3], [-3, 3]],
)
plt.gca().set_aspect(1)

In [None]:
# Histogram of local energies after the first 50 steps, clamped to [-1.25, 0].
_ = plt.hist(E_loc[:, 50:].flatten().clamp(-1.25, 0).cpu().numpy(), bins=100)

In [None]:
# Standard deviation of the local energies after the first 50 steps.
E_loc[:, 50:].std()

In [None]:
# (large-basis SCF energy, mean local energy, std of per-walker means / sqrt(n_walker))
(
    scf_energy_big,
    E_loc[:, 50:].mean().item(),
    (E_loc[:, 50:].mean(dim=1).std() / np.sqrt(E_loc.shape[0])).item(),
)

In [None]:
# Blocking analysis of the post-warm-up local energies. Tensor.numpy() only
# works on CPU tensors, so move the result to the host first; .cpu() is a
# no-op when the tensor is already on the CPU.
plt.plot(blocking(E_loc[:, 50:]).cpu().numpy())

In [None]:
# Autocorrelation coefficients for lags 0..49. The result is moved to the
# host before conversion because Tensor.numpy() rejects CUDA tensors; .cpu()
# is a no-op for tensors that are already on the CPU.
plt.plot(autocorr_coeff(range(50), E_loc[:, 50:]).cpu().numpy())
plt.axhline()

### DL WFs

In [None]:
# NOTE(review): BaseWFNet and NuclearAsymptotic are not among the names
# imported at the top of this notebook, so this cell raises NameError unless
# they are brought into scope elsewhere — confirm their module and import them.
class AsympNet(BaseWFNet):
    """Wave function made of the nuclear asymptotic factor only.

    The geometry is registered on the module, and the forward pass returns
    the NuclearAsymptotic module evaluated on the distances between the
    input positions and the registered nuclear coordinates.
    """

    def __init__(self, geom, ion_pot=0.5):
        super().__init__()
        self.register_geom(geom)
        self.asymp_nuc = NuclearAsymptotic(self.charges, ion_pot)

    def forward(self, rs):
        # a leading axis is added to the coordinates before taking distances
        dists_nuc = pairwise_distance(rs, self.coords[None, ...])
        asymp_nuc = self.asymp_nuc(dists_nuc)
        return asymp_nuc

In [None]:
# Compare the default DistanceBasis(32) with the 'nocusp' envelope variant on
# two stacked axes over [0, 11]. (The figure handle was misspelled `fix`.)
fig, axes = plt.subplots(2, 1)
_ = plot_func(DistanceBasis(32), [0, 11], ax=axes[0])
_ = plot_func(DistanceBasis(32, envelope='nocusp'), [0, 11], ax=axes[1])

In [None]:
# Rebuild the 6-311G H2+ SCF and create the wave functions compared below:
# GTO without and with cusp correction, a WFNet, and the asymptotics-only net.
mol = gto.M(
    atom=geomdb['H2+'].as_pyscf(),
    unit='bohr',
    basis='6-311g',
    cart=True,
    charge=1,
    spin=1,
)
mf = scf.RHF(mol)
mf.kernel()
gtowf = HFNet.from_pyscf(mf, cusp_correction=False).cuda()
# from_pyscf is called with its default cusp_correction here
gtowf_cusp = HFNet.from_pyscf(mf).cuda()
wfnet = WFNet(geomdb['H2+'], 1, n_orbital_layers=4, ion_pot=0.7).cuda()
asympnet = AsympNet(geomdb['H2+'], ion_pot=0.7).cuda()

class Orbnet(nn.Module):
    """Multiplicative neural-network correction to molecular orbitals.

    A network built by get_log_dnn (SSP activation, 4 layers, no bias on the
    last layer) maps the flattened features to one value per orbital; the
    orbitals are multiplied by the exponential of that value.
    """

    def __init__(self, features_in, n_orbitals):
        super().__init__()
        dnn_options = {'n_layers': 4, 'last_bias': False}
        self.net = get_log_dnn(features_in, n_orbitals, SSP, **dnn_options)

    def forward(self, mos, xs):
        # everything after the batch dimension is collapsed into one feature axis
        flat_features = xs.flatten(start_dim=1)
        log_factor = self.net(flat_features)
        return mos * log_factor.exp()
        

# HFNet whose orbitals are corrected by Orbnet; the MO coefficients are
# excluded from gradient updates. The trailing ';' suppresses the cell output.
hfnet = HFNet.from_pyscf(mf, orbnet_factory=Orbnet).cuda()
hfnet.mo_coeff.weight.requires_grad_(False);

In [None]:
# Display the module representation.
gtowf

In [None]:
# Display the module representation.
gtowf_cusp

In [None]:
# Display the module representation.
wfnet

In [None]:
# Display the module representation.
hfnet

In [None]:
# Log-scale comparison of the GTO orbitals (with and without cusp correction)
# and the 'asymp_nuc' debug output of wfnet, the latter scaled by 0.4.
bounds = [-2, 2]
plot_func_x(gtowf.orbitals, bounds, device='cuda', density=0.002, label='GTO')
plot_func_x(gtowf_cusp.orbitals, bounds, device='cuda', density=0.002, label='GTO w/ cusp')
plot_func_x(lambda x: 0.4*wfnet.debug('asymp_nuc', x[:, None]), bounds, device='cuda', label='asymptotics')
# legend placed below the axes in two columns
plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.35), ncol=2)
plt.yscale('log')
plt.ylim(0.25, None)

In [None]:
# Local energy of the two GTO wave functions and of the asymptotics-only net.
bounds = [-5, 5]
plot_func_x(lambda x: local_energy(x[:, None], gtowf)[0], bounds, device='cuda', density=0.002, label='GTO')
plot_func_x(lambda x: local_energy(x[:, None], gtowf_cusp)[0], bounds, device='cuda', density=0.002, label='GTO w/ cusp')
plot_func_x(lambda x: local_energy(x[:, None], asympnet)[0], bounds, device='cuda', density=0.002, label='asymptotics')
plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.35), ncol=2)
plt.ylim((-2, 0))

In [None]:
# Fit wfnet on H2+: advance the sampler 50 times before training, then run
# 2000 samplings with Adam (lr=3e-3), logging to TensorBoard.
sampler = langevin_monte_carlo(wfnet, torch.randn(1_000, 1, 3, device='cuda'), tau=0.1)
for _ in range(50):
    next(sampler)
# the context manager closes the writer when the fit finishes or fails
with SummaryWriter('runs/H2+/wfnet/11') as writer:
    fit_wfnet(
        wfnet,
        loss_local_energy,
        torch.optim.Adam(wfnet.parameters(), lr=3e-3),
        wfnet_fit_driver_simple(sampler, n_sampling_steps=1, samplings=trange(2000)),
        writer=writer,
    )

In [None]:
# Local energy of the fitted wfnet over [-15, 15].
plot_func_x(lambda x: local_energy(x[:, None], wfnet)[0], [-15, 15], device='cuda')
plt.ylim((-1, 0))

In [None]:
# 2D plot of wfnet's 'jastrow' debug output over [-10, 10] x [-10, 10].
plot_func_xy(
    lambda x: wfnet.debug('jastrow', x[:, None]), [[-10, 10], [-10, 10]], device='cuda'
)

In [None]:
# Sample from the fitted wfnet with a fresh sampler (1000 walkers, 500
# iterations) and evaluate the local energy, one row per walker.
n_walker = 1_000
sampler = langevin_monte_carlo(
    wfnet, torch.randn(n_walker, 1, 3, device='cuda'), tau=0.1
)
rs, _, info = samples_from(sampler, trange(500))
E_loc = local_energy(rs.flatten(end_dim=1), wfnet)[0].view(n_walker, -1)
info.acceptance.mean()

In [None]:
# Walker-averaged local energy per step.
plt.plot(E_loc.mean(dim=0).cpu().numpy())
plt.ylim(-0.7, -0.5)

In [None]:
# Histogram of local energies after the first 50 steps, clamped to [-1.25, 0].
plt.hist(E_loc[:, 50:].flatten().clamp(-1.25, 0).cpu().numpy(), bins=100)
plt.xlim(-1.25, 0)

In [None]:
# Standard deviation of the local energies after the first 50 steps.
E_loc[:, 50:].std()

In [None]:
# (large-basis SCF energy, mean local energy, std of per-walker means / sqrt(n_walker))
scf_energy_big, E_loc[:, 50:].mean().item(), (
    E_loc[:, 50:].mean(dim=1).std() / np.sqrt(E_loc.shape[0])
).item()

In [None]:
# Compare log wave functions on [-2, 2]; the constants subtracted from some
# curves are vertical offsets chosen by hand.
bounds = [-2, 2]
# NOTE: nn.Module.float() casts gtowf_big's parameters in place, so the module
# stays in float32 for every later cell as well.
plot_func_x(
    lambda x: torch.log(gtowf_big.float()(x[:, None])),
    bounds,
    device='cuda',
    label='~exact WF',
)
plot_func_x(
    lambda x: torch.log(gtowf(x[:, None])),
    bounds,
    device='cuda',
    label='small-basis WF',
)
plot_func_x(
    lambda x: torch.log(wfnet(x[:, None])) - 0.2, bounds, device='cuda', label='DL WF'
)
plot_func_x(
    lambda x: torch.log(wfnet.debug('asymp_nuc', x[:, None])) - 0.89,
    bounds,
    device='cuda',
    label='asymptotics',
)
# unlike the curves above, the 'jastrow' output is plotted without taking a log
plot_func_x(
    lambda x: wfnet.debug('jastrow', x[:, None]) - 0.2,
    bounds,
    device='cuda',
    label='NN',
)
plt.legend(loc='lower center', bbox_to_anchor=(0.5, -0.35), ncol=2)
plt.ylim(-1.5, None)

In [None]:
# Rebuild the 6-311G H2+ SCF and the two GTO wave functions (without and with
# the default cusp correction) for the HFNet experiment below.
mol = gto.M(
    atom=geomdb['H2+'].as_pyscf(),
    unit='bohr',
    basis='6-311g',
    cart=True,
    charge=1,
    spin=1,
)
mf = scf.RHF(mol)
mf.kernel()
gtowf = HFNet.from_pyscf(mf, cusp_correction=False).cuda()
gtowf_cusp = HFNet.from_pyscf(mf).cuda()

class Orbnet(nn.Module):
    """Multiplicative neural-network correction to molecular orbitals.

    A network built by get_log_dnn (SSP activation, 4 layers, no bias on the
    last layer) maps the flattened features to one value per orbital; the
    orbitals are multiplied by the exponential of that value.
    """

    def __init__(self, features_in, n_orbitals):
        super().__init__()
        dnn_options = {'n_layers': 4, 'last_bias': False}
        self.net = get_log_dnn(features_in, n_orbitals, SSP, **dnn_options)

    def forward(self, mos, xs):
        # everything after the batch dimension is collapsed into one feature axis
        flat_features = xs.flatten(start_dim=1)
        log_factor = self.net(flat_features)
        return mos * log_factor.exp()
        

# HFNet with Orbnet-corrected orbitals; the MO coefficients are frozen.
# The trailing ';' suppresses the cell output.
hfnet = HFNet.from_pyscf(mf, orbnet_factory=Orbnet).cuda()
hfnet.mo_coeff.weight.requires_grad_(False);

In [None]:
# Orbitals of the plain GTO, cusp-corrected GTO, and Orbnet-corrected HFNet.
bounds = [-.5, 2]
plot_func_x(gtowf.orbitals, bounds, device='cuda', density=0.002)
plot_func_x(gtowf_cusp.orbitals, bounds, device='cuda', density=0.002)
plot_func_x(hfnet.orbitals, bounds, device='cuda', density=0.002)

In [None]:
# Local energy of the three wave functions over the wide range [-15, 15].
bounds = [-15, 15]
plot_func_x(lambda x: local_energy(x[:, None], gtowf)[0], bounds, device='cuda', density=0.002)
plot_func_x(lambda x: local_energy(x[:, None], gtowf_cusp)[0], bounds, device='cuda', density=0.002)
plot_func_x(lambda x: local_energy(x[:, None], hfnet)[0], bounds, device='cuda', density=0.002)
plt.ylim((-2, 0))

In [None]:
# Same comparison zoomed in to [-3, 3].
bounds = [-3, 3]
plot_func_x(lambda x: local_energy(x[:, None], gtowf)[0], bounds, device='cuda', density=0.002)
plot_func_x(lambda x: local_energy(x[:, None], gtowf_cusp)[0], bounds, device='cuda', density=0.002)
plot_func_x(lambda x: local_energy(x[:, None], hfnet)[0], bounds, device='cuda', density=0.002)
plt.ylim((-4, 0))

In [None]:
# Fit hfnet on H2+: advance the sampler 50 times before training, then run
# 2000 samplings with Adam (lr=3e-3), logging to TensorBoard.
sampler = langevin_monte_carlo(hfnet, torch.randn(1_000, 1, 3, device='cuda'), tau=0.1)
for _ in range(50):
    next(sampler)
with SummaryWriter('runs/H2+/hfnet/23') as writer:
    fit_wfnet(
        hfnet,
        loss_local_energy,
        # parameters() still yields the frozen MO coefficients; they receive no gradient
        torch.optim.Adam(hfnet.parameters(), lr=3e-3),
        wfnet_fit_driver_simple(sampler, n_sampling_steps=1, samplings=trange(2000)),
        writer=writer,
    )

In [None]:
# Local energy of the fitted hfnet over [-15, 15].
plot_func_x(lambda x: local_energy(x[:, None], hfnet)[0], [-15, 15], device='cuda')
plt.ylim((-1, 0))

In [None]:
# Sample from the fitted hfnet (1000 walkers, 500 iterations) and evaluate
# the local energy, one row per walker.
n_walker = 1_000
sampler = langevin_monte_carlo(
    hfnet, torch.randn(n_walker, 1, 3, device='cuda'), tau=0.1
)
rs, _, info = samples_from(sampler, trange(500))
E_loc = local_energy(rs.flatten(end_dim=1), hfnet)[0].view(n_walker, -1)
info.acceptance.mean()

In [None]:
# Histogram of local energies after the first 50 steps, clamped to [-1.25, 0].
plt.hist(E_loc[:, 50:].flatten().clamp(-1.25, 0).cpu().numpy(), bins=100)
plt.xlim(-1.25, 0)

In [None]:
# Standard deviation of the local energies after the first 50 steps.
E_loc[:, 50:].std()

In [None]:
# (large-basis SCF energy, mean local energy, std of per-walker means / sqrt(n_walker))
scf_energy_big, E_loc[:, 50:].mean().item(), (
    E_loc[:, 50:].mean(dim=1).std() / np.sqrt(E_loc.shape[0])
).item()

## H2

### GTO WF

In [None]:
# Neutral singlet H2 in the cc-pV5Z basis; mf_big is kept for the reference
# electron density plotted later.
mol = gto.M(
    atom=geomdb['H2'].as_pyscf(),
    unit='bohr',
    basis='cc-pv5z',
    charge=0,
    spin=0,
    cart=True,
)
mf_big = scf.RHF(mol)
mf_big.kernel()

In [None]:
# Singlet H2 in the 6-311G basis, wrapped as an HFNet with the default
# from_pyscf options and moved to the GPU.
mol = gto.M(
    atom=geomdb['H2'].as_pyscf(),
    unit='bohr',
    basis='6-311g',
    charge=0,
    spin=0,
    cart=True,
)
mf = scf.RHF(mol)
scf_energy = mf.kernel()
gtowf = HFNet.from_pyscf(mf).cuda()

In [None]:
# Sample the H2 GTO wave function: 1000 walkers with initial positions of
# shape (n_walker, 2, 3), 500 iterations; local energy is one row per walker.
n_walker = 1_000
sampler = langevin_monte_carlo(gtowf, torch.randn(n_walker, 2, 3).cuda(), tau=0.1)
rs, _, info = samples_from(sampler, trange(500))
E_loc = local_energy(rs.flatten(end_dim=1), gtowf)[0].view(n_walker, -1)
info.acceptance.mean()

In [None]:
# Walker-averaged local energy per step.
plt.plot(E_loc.mean(dim=0).cpu().numpy())
plt.ylim(-1.2, -1)

In [None]:
# Histogram of local energies after the first 100 steps, clamped to [-2, 0].
_ = plt.hist(E_loc[:, 100:].flatten().clamp(-2, 0).cpu().numpy(), bins=100)

In [None]:
# Standard deviation of the local energies after the first 100 steps.
E_loc[:, 100:].std()

In [None]:
# (SCF energy, mean local energy, std of per-walker means / sqrt(n_walker))
scf_energy, E_loc[:, 100:].mean().item(), (
    E_loc[:, 100:].mean(dim=1).std() / np.sqrt(E_loc.shape[0])
).item()

In [None]:
# Blocking analysis of the local energies after the first 100 steps. E_loc is
# computed from GPU samples, and Tensor.numpy() only works on CPU tensors, so
# move the result to the host first (.cpu() is a no-op on CPU tensors).
plt.plot(blocking(E_loc[:, 100:]).cpu().numpy())

In [None]:
# Autocorrelation coefficients for lags 0..49. The result is moved to the
# host before conversion because Tensor.numpy() rejects CUDA tensors; .cpu()
# is a no-op for tensors that are already on the CPU.
plt.plot(autocorr_coeff(range(50), E_loc[:, 100:]).cpu().numpy())
plt.axhline()

### Net WF

In [None]:
# Baseline: a freshly constructed WFNet that is never fitted, sampled with
# 10,000 walkers; the *0 variables are compared with the fitted net below.
wfnet0 = WFNet(geomdb['H2'], 2, n_orbital_layers=4, ion_pot=0.7).cuda()
n_walker = 10_000
sampler0 = langevin_monte_carlo(
    wfnet0, torch.randn(n_walker, 2, 3, device='cuda'), tau=0.1
)
rs0, _, info = samples_from(sampler0, trange(500))
# local energy evaluated in chunks of 50,000 samples
E_loc0 = batch_eval_tuple(
    local_energy, tqdm(rs0.flatten(end_dim=1).split(50_000)), wfnet0
)[0].view(n_walker, -1)
info.acceptance.mean()

In [None]:
# Network to be fitted, plus a 1000-walker sampler advanced 50 times before
# training; the last line displays the module.
wfnet = WFNet(geomdb['H2'], 2, n_orbital_layers=4, ion_pot=0.7).cuda()
sampler = langevin_monte_carlo(wfnet, torch.randn(1_000, 2, 3, device='cuda'), tau=0.1)
for _ in range(50):
    next(sampler)
wfnet

In [None]:
# Two-stage fit ('pretrain' with E_ref=-1.1 for 200 samplings, then 'variance'
# with the plain loss for 2000 samplings); each stage gets its own Adam
# optimizer and SummaryWriter. The iterables presumably pair up one entry per
# stage — confirm against fit_wfnet_multi.
# NOTE(review): unlike the `with SummaryWriter(...)` cells, the writers created
# by the generator are not closed here; verify that fit_wfnet_multi closes them.
fit_wfnet_multi(
    wfnet,
    (partial(loss_local_energy, E_ref=-1.1), loss_local_energy),
    (torch.optim.Adam(wfnet.parameters(), lr=3e-3) for _ in range(2)),
    partial(wfnet_fit_driver_simple, sampler, n_sampling_steps=1),
    (
        {'samplings': trange(200, desc='pretrain')},
        {'samplings': trange(2000, desc='variance')},
    ),
    (SummaryWriter(f'runs/H2/wfnet/16/{s}') for s in ['pretrain', 'variance']),
)

In [None]:
# Sample the fitted wfnet with 10,000 walkers and evaluate the local energy
# in chunks of 50,000 samples, one row per walker.
n_walker = 10_000
sampler = langevin_monte_carlo(
    wfnet, torch.randn(n_walker, 2, 3, device='cuda'), tau=0.1
)
rs, _, info = samples_from(sampler, trange(500))
E_loc = batch_eval_tuple(
    local_energy, tqdm(rs.flatten(end_dim=1).split(50_000)), wfnet
)[0].view(n_walker, -1)
info.acceptance.mean()

In [None]:
# Walker-averaged local energy per step: unfitted baseline vs fitted network.
plt.plot(E_loc0.mean(dim=0).cpu().numpy())
plt.plot(E_loc.mean(dim=0).cpu().numpy())
plt.ylim(-1.2, -1)

In [None]:
# Overlaid local-energy histograms (baseline, then fitted), clamped to [-2, 0].
plt.hist(E_loc0[:, 50:].flatten().clamp(-2, 0).cpu().numpy(), bins=100)
plt.hist(E_loc[:, 50:].flatten().clamp(-2, 0).cpu().numpy(), bins=100)
plt.xlim(-2, 0)

In [None]:
# Standard deviation of the fitted network's local energies after the first 50 steps.
E_loc[:, 50:].std()

In [None]:
# (mean local energy, std of per-walker means / sqrt(n_walker))
E_loc[:, 50:].mean().item(), (
    E_loc[:, 50:].mean(dim=1).std() / np.sqrt(E_loc.shape[0])
).item()

In [None]:
# Drop the first 50 steps and merge the walker and step dimensions.
rs_flat = rs[:, 50:].flatten(end_dim=1)
rs_flat0 = rs0[:, 50:].flatten(end_dim=1)

In [None]:
# Kernel density estimates built from all sampled positions (the sample and
# particle dimensions are merged); weights=2 presumably normalizes to two
# electrons — confirm against GaussianKDEstimator.
density = GaussianKDEstimator(rs_flat.flatten(end_dim=1), bw=0.05, weights=2)
density0 = GaussianKDEstimator(rs_flat0.flatten(end_dim=1), bw=0.05, weights=2)
plot_func_x(density0, [-2, 3], device='cuda')
plot_func_x(density, [-2, 3], device='cuda')
# reference density from the cc-pV5Z SCF; evaluated outside torch
plot_func_x(lambda x: electron_density_of(mf_big, x), [-2, 3], is_torch=False)

In [None]:
# Pair correlations of the fitted (pair) and unfitted (pair0) samples: the
# 'ud' pair density divided by the 'decorr' density of the SAME sample set.
pair = pair_correlations_from_samples(rs_flat, 1)
pair0 = pair_correlations_from_samples(rs_flat0, 1)
plot_func(
    # normalize the baseline by its own decorrelated density; it was previously
    # divided by pair['decorr'], mixing the two sample sets
    lambda r: pair0['ud'](r[:, None]) / pair0['decorr'](r[:, None]),
    (0, 5),
    device='cuda',
)
plot_func(
    lambda r: pair['ud'](r[:, None]) / pair['decorr'](r[:, None]), (0, 5), device='cuda'
)

### HanNet

In [None]:
# Baseline HanNet for singlet H2 (constructed with 1 and 1 as the second and
# third positional arguments), never fitted, sampled with 10,000 walkers.
wfnet0 = HanNet(
    geomdb['H2'],
    1,
    1,
    ion_pot=0.7,
    basis_dim=32,
    kernel_dim=32,
    embedding_dim=32,
    latent_dim=1,
    n_interactions=2,
    n_orbital_layers=2,
).cuda()
n_walker = 10_000
sampler0 = langevin_monte_carlo(
    wfnet0, torch.randn(n_walker, 2, 3, device='cuda'), tau=0.1
)
rs0, _, info = samples_from(sampler0, trange(500, desc='sampling'))
# local energy evaluated in chunks of 10,000 samples
E_loc0 = batch_eval_tuple(
    local_energy, tqdm(rs0.flatten(end_dim=1).split(10_000), desc='E_loc'), wfnet0
)[0].view(n_walker, -1)
info.acceptance.mean()

In [None]:
# HanNet to be fitted (same hyperparameters as the baseline), plus a
# 1000-walker sampler advanced 50 times; the last line displays the module.
wfnet = HanNet(
    geomdb['H2'],
    1,
    1,
    ion_pot=0.7,
    basis_dim=32,
    kernel_dim=32,
    embedding_dim=32,
    latent_dim=1,
    n_interactions=2,
    n_orbital_layers=2,
).cuda()
sampler = langevin_monte_carlo(wfnet, torch.randn(1_000, 2, 3, device='cuda'), tau=0.1)
for _ in range(50):
    next(sampler)
wfnet

In [None]:
# Two-stage fit ('pretrain' with E_ref=-1.1 for 200 samplings, then 'variance'
# with the plain loss for 2000 samplings), one Adam optimizer and one
# SummaryWriter per stage.
# NOTE(review): the writers created by the generator are not closed here;
# verify that fit_wfnet_multi closes them.
fit_wfnet_multi(
    wfnet,
    (partial(loss_local_energy, E_ref=-1.1), loss_local_energy),
    (torch.optim.Adam(wfnet.parameters(), lr=3e-3) for _ in range(2)),
    partial(wfnet_fit_driver_simple, sampler, n_sampling_steps=1),
    (
        {'samplings': trange(200, desc='pretrain')},
        {'samplings': trange(2000, desc='variance')},
    ),
    (SummaryWriter(f'runs/H2/hannet/07/{s}') for s in ['pretrain', 'variance']),
)


In [None]:
# Sample the fitted HanNet with 10,000 walkers and evaluate the local energy
# in chunks of 50,000 samples, one row per walker.
n_walker = 10_000
sampler = langevin_monte_carlo(
    wfnet, torch.randn(n_walker, 2, 3, device='cuda'), tau=0.1
)
rs, _, info = samples_from(sampler, trange(500))
E_loc = batch_eval_tuple(
    local_energy, tqdm(rs.flatten(end_dim=1).split(50_000)), wfnet
)[0].view(n_walker, -1)
info.acceptance.mean()

In [None]:
# Walker-averaged local energy per step: unfitted baseline vs fitted network.
plt.plot(E_loc0.mean(dim=0).cpu().numpy())
plt.plot(E_loc.mean(dim=0).cpu().numpy())
plt.ylim(-1.2, -1)

In [None]:
# Overlaid local-energy histograms (baseline, then fitted), clamped to [-2, 0].
plt.hist(E_loc0[:, 50:].flatten().clamp(-2, 0).cpu().numpy(), bins=100)
plt.hist(E_loc[:, 50:].flatten().clamp(-2, 0).cpu().numpy(), bins=100)
plt.xlim(-2, 0)

In [None]:
# (mean local energy, std of per-walker means / sqrt(n_walker))
E_loc[:, 50:].mean().item(), (
    E_loc[:, 50:].mean(dim=1).std() / np.sqrt(E_loc.shape[0])
).item()

In [None]:
# Drop the first 50 steps and merge the walker and step dimensions.
rs_flat = rs[:, 50:].flatten(end_dim=1)
rs_flat0 = rs0[:, 50:].flatten(end_dim=1)

In [None]:
# Kernel density estimates from all sampled positions, compared with the
# density of the large-basis SCF stored in mf_big.
density = GaussianKDEstimator(rs_flat.flatten(end_dim=1), bw=0.05, weights=2)
density0 = GaussianKDEstimator(rs_flat0.flatten(end_dim=1), bw=0.05, weights=2)
plot_func_x(density0, [-2, 3], device='cuda')
plot_func_x(density, [-2, 3], device='cuda')
plot_func_x(lambda x: electron_density_of(mf_big, x), [-2, 3], is_torch=False)

In [None]:
# Pair correlations of the fitted (pair) and unfitted (pair0) samples: the
# 'ud' pair density divided by the 'decorr' density of the SAME sample set.
pair = pair_correlations_from_samples(rs_flat, 1)
pair0 = pair_correlations_from_samples(rs_flat0, 1)
plot_func(
    # normalize the baseline by its own decorrelated density; it was previously
    # divided by pair['decorr'], mixing the two sample sets
    lambda r: pair0['ud'](r[:, None]) / pair0['decorr'](r[:, None]),
    (0, 5),
    device='cuda',
)
plot_func(
    lambda r: pair['ud'](r[:, None]) / pair['decorr'](r[:, None]), (0, 5), device='cuda'
)

## H2 triplet

### GTO WF

In [None]:
# Triplet H2 (spin=2, i.e. two unpaired electrons) in the cc-pV5Z basis;
# mf_big is rebound and used for the reference density in this section.
mol = gto.M(
    atom=geomdb['H2'].as_pyscf(),
    unit='bohr',
    basis='cc-pv5z',
    charge=0,
    spin=2,
    cart=True,
)
mf_big = scf.RHF(mol)
mf_big.kernel()

In [None]:
# Triplet H2 in the 6-311G basis, wrapped as an HFNet on the GPU.
mol = gto.M(
    atom=geomdb['H2'].as_pyscf(),
    unit='bohr',
    basis='6-311g',
    charge=0,
    spin=2,
    cart=True,
)
mf = scf.RHF(mol)
scf_energy = mf.kernel()
gtowf = HFNet.from_pyscf(mf).cuda()

In [None]:
# Sample the triplet GTO wave function (1000 walkers, 500 iterations) and
# evaluate the local energy, one row per walker.
n_walker = 1_000
sampler = langevin_monte_carlo(gtowf, torch.randn(n_walker, 2, 3).cuda(), tau=0.1)
rs, _, info = samples_from(sampler, trange(500))
E_loc = local_energy(rs.flatten(end_dim=1), gtowf)[0].view(n_walker, -1)
info.acceptance.mean()

In [None]:
# Histogram of local energies after the first 100 steps, clamped to [-2, 0].
_ = plt.hist(E_loc[:, 100:].flatten().clamp(-2, 0).cpu().numpy(), bins=100)

In [None]:
# Standard deviation of the local energies after the first 100 steps.
E_loc[:, 100:].std()

In [None]:
# (SCF energy, mean local energy, std of per-walker means / sqrt(n_walker))
scf_energy, E_loc[:, 100:].mean().item(), (
    E_loc[:, 100:].mean(dim=1).std() / np.sqrt(E_loc.shape[0])
).item()

In [None]:
# Blocking analysis of the local energies after the first 100 steps. E_loc is
# computed from GPU samples, and Tensor.numpy() only works on CPU tensors, so
# move the result to the host first (.cpu() is a no-op on CPU tensors).
plt.plot(blocking(E_loc[:, 100:]).cpu().numpy())

In [None]:
# Autocorrelation coefficients for lags 0..49. The result is moved to the
# host before conversion because Tensor.numpy() rejects CUDA tensors; .cpu()
# is a no-op for tensors that are already on the CPU.
plt.plot(autocorr_coeff(range(50), E_loc[:, 100:]).cpu().numpy())
plt.axhline()

### HanNet

In [None]:
# Baseline HanNet for triplet H2 (constructed with 2 and 0 as the second and
# third positional arguments), never fitted, sampled with 10,000 walkers.
wfnet0 = HanNet(
    geomdb['H2'],
    2,
    0,
    ion_pot=0.7,
    basis_dim=32,
    kernel_dim=32,
    embedding_dim=32,
    latent_dim=1,
    n_interactions=2,
    n_orbital_layers=2,
).cuda()
n_walker = 10_000
sampler0 = langevin_monte_carlo(
    wfnet0, torch.randn(n_walker, 2, 3, device='cuda'), tau=0.1
)
rs0, _, info = samples_from(sampler0, trange(500, desc='sampling'))
# local energy evaluated in chunks of 10,000 samples
E_loc0 = batch_eval_tuple(
    local_energy, tqdm(rs0.flatten(end_dim=1).split(10_000), desc='E_loc'), wfnet0
)[0].view(n_walker, -1)
info.acceptance.mean()

In [None]:
# HanNet to be fitted (same hyperparameters as the baseline), plus a
# 1000-walker sampler advanced 50 times; the last line displays the module.
wfnet = HanNet(
    geomdb['H2'],
    2,
    0,
    ion_pot=0.7,
    basis_dim=32,
    kernel_dim=32,
    embedding_dim=32,
    latent_dim=1,
    n_interactions=2,
    n_orbital_layers=2,
).cuda()
sampler = langevin_monte_carlo(wfnet, torch.randn(1_000, 2, 3, device='cuda'), tau=0.1)
for _ in range(50):
    next(sampler)
wfnet

In [None]:
# Two-stage fit for the triplet: E_ref=-0.8 for 200 samplings, then the plain
# loss for 2000 samplings, with n_sampling_steps=5 (the singlet fits use 1).
# Both progress bars share the description 'samplings'.
# NOTE(review): the writers created by the generator are not closed here;
# verify that fit_wfnet_multi closes them.
fit_wfnet_multi(
    wfnet,
    (partial(loss_local_energy, E_ref=-0.8), loss_local_energy),
    (torch.optim.Adam(wfnet.parameters(), lr=3e-3) for _ in range(2)),
    partial(wfnet_fit_driver_simple, sampler, n_sampling_steps=5),
    (
        {'samplings': trange(200, desc='samplings')},
        {'samplings': trange(2000, desc='samplings')},
    ),
    (SummaryWriter(f'runs/H2t/hannet/12/{s}') for s in ['pretrain', 'variance']),
)

In [None]:
# Sample the fitted HanNet with 10,000 walkers and evaluate the local energy
# in chunks of 50,000 samples, one row per walker.
n_walker = 10_000
sampler = langevin_monte_carlo(
    wfnet, torch.randn(n_walker, 2, 3, device='cuda'), tau=0.1
)
rs, _, info = samples_from(sampler, trange(500))
E_loc = batch_eval_tuple(
    local_energy, tqdm(rs.flatten(end_dim=1).split(50_000)), wfnet
)[0].view(n_walker, -1)
info.acceptance.mean()

In [None]:
# Walker-averaged local energy per step: unfitted baseline vs fitted network.
plt.plot(E_loc0.mean(dim=0).cpu().numpy())
plt.plot(E_loc.mean(dim=0).cpu().numpy())
plt.ylim(-0.8, -0.5)

In [None]:
# Overlaid local-energy histograms (baseline, then fitted), clamped to [-2, 0].
plt.hist(E_loc0[:, 50:].flatten().clamp(-2, 0).cpu().numpy(), bins=100)
plt.hist(E_loc[:, 50:].flatten().clamp(-2, 0).cpu().numpy(), bins=100)
plt.xlim(-2, 0)

In [None]:
# (mean local energy, std of per-walker means / sqrt(n_walker))
E_loc[:, 50:].mean().item(), (
    E_loc[:, 50:].mean(dim=1).std() / np.sqrt(E_loc.shape[0])
).item()

In [None]:
# Drop the first 50 steps and merge the walker and step dimensions.
rs_flat = rs[:, 50:].flatten(end_dim=1)
rs_flat0 = rs0[:, 50:].flatten(end_dim=1)

In [None]:
# Kernel density estimates from all sampled positions, compared with the
# density of the large-basis triplet SCF stored in mf_big.
density = GaussianKDEstimator(rs_flat.flatten(end_dim=1), bw=0.05, weights=2)
density0 = GaussianKDEstimator(rs_flat0.flatten(end_dim=1), bw=0.05, weights=2)
plot_func_x(density0, [-2, 3], device='cuda')
plot_func_x(density, [-2, 3], device='cuda')
plot_func_x(lambda x: electron_density_of(mf_big, x), [-2, 3], is_torch=False)

In [None]:
# Pair correlations of the fitted (pair) and unfitted (pair0) samples: the
# 'ud' pair density divided by the 'decorr' density of the SAME sample set.
# NOTE(review): the HanNet in this section is constructed with (2, 0), whereas
# the second argument here is 1 and the 'ud' channel is plotted, exactly as in
# the singlet section. Presumably that argument is the number of spin-up
# electrons and this should use 2 with the 'uu' channel — confirm against
# pair_correlations_from_samples before relying on this plot.
pair = pair_correlations_from_samples(rs_flat, 1)
pair0 = pair_correlations_from_samples(rs_flat0, 1)
plot_func(
    # normalize the baseline by its own decorrelated density; it was previously
    # divided by pair['decorr'], mixing the two sample sets
    lambda r: pair0['ud'](r[:, None]) / pair0['decorr'](r[:, None]),
    (0, 5),
    device='cuda',
)
plot_func(
    lambda r: pair['ud'](r[:, None]) / pair['decorr'](r[:, None]), (0, 5), device='cuda'
)

## LiH

### GTO WF

In [None]:
# Singlet LiH in the 6-311G basis, wrapped as an HFNet on the GPU; the cell
# displays the first two molecular-orbital energies.
mol = gto.M(
    atom=geomdb['LiH'].as_pyscf(),
    unit='bohr',
    basis='6-311g',
    charge=0,
    spin=0,
    cart=True,
)
mf = scf.RHF(mol)
scf_energy = mf.kernel()
gtowf = HFNet.from_pyscf(mf).cuda()
mf.mo_energy[:2]

In [None]:
# Sample the LiH GTO wave function: 1000 walkers with initial positions of
# shape (n_walker, 4, 3), 500 iterations; local energy in chunks of 50,000.
n_walker = 1_000
sampler = langevin_monte_carlo(gtowf, torch.randn(n_walker, 4, 3).cuda(), tau=0.1)
rs, _, info = samples_from(sampler, trange(500))
E_loc = batch_eval_tuple(
    local_energy, tqdm(rs.flatten(end_dim=1).split(50_000)), gtowf
)[0].view(n_walker, -1)
info.acceptance.mean()

In [None]:
# Walker-averaged local energy per step.
plt.plot(E_loc.mean(dim=0).cpu().numpy())
plt.ylim(-8.5, -7.5)

In [None]:
# Histogram of local energies after the first 100 steps, clamped to [-11, -3].
_ = plt.hist(E_loc[:, 100:].flatten().clamp(-11, -3).cpu().numpy(), bins=100)

In [None]:
# Standard deviation of the local energies after the first 100 steps.
E_loc[:, 100:].std()

In [None]:
# (SCF energy, mean local energy, std of per-walker means / sqrt(n_walker))
scf_energy, E_loc[:, 100:].mean().item(), (
    E_loc[:, 100:].mean(dim=1).std() / np.sqrt(E_loc.shape[0])
).item()

In [None]:
# Blocking analysis of the local energies after the first 100 steps.
# Tensor.numpy() only works on CPU tensors, so move the result to the host
# first; .cpu() is a no-op when the tensor is already on the CPU.
plt.plot(blocking(E_loc[:, 100:]).cpu().numpy())

In [None]:
# Autocorrelation coefficients for lags 0..49. The result is moved to the
# host before conversion because Tensor.numpy() rejects CUDA tensors; .cpu()
# is a no-op for tensors that are already on the CPU.
plt.plot(autocorr_coeff(range(50), E_loc[:, 100:]).cpu().numpy())
plt.axhline()

In [None]:
# Drop the first 50 steps and merge the walker and step dimensions. Note that
# the energy cells of this section skip 100 steps instead.
rs_hf_flat = rs[:, 50:].flatten(end_dim=1)

In [None]:
# Kernel density estimate from the samples versus gtowf.density, on a log scale.
density = GaussianKDEstimator(rs_hf_flat.flatten(end_dim=1), bw=0.05, weights=4)
plot_func_x(density, [-2, 5], device='cuda')
plot_func_x(gtowf.density, [-2, 5], device='cuda')
plt.gca().set_yscale('log')

In [None]:
# Orbitals of the LiH GTO wave function over [-2, 5].
_ = plot_func_x(gtowf.orbitals, [-2, 5], device='cuda')

In [None]:
# Pair correlations from the LiH samples: the 'uu', 'dd' and 'ud' channels,
# each divided by the 'decorr' density, over (0, 8).
pair = pair_correlations_from_samples(rs_hf_flat, 2)
plot_func(
    lambda r: pair['uu'](r[:, None]) / pair['decorr'](r[:, None]), (0, 8), device='cuda'
)
plot_func(
    lambda r: pair['dd'](r[:, None]) / pair['decorr'](r[:, None]), (0, 8), device='cuda'
)
plot_func(
    lambda r: pair['ud'](r[:, None]) / pair['decorr'](r[:, None]), (0, 8), device='cuda'
)

In [None]:
# Left: histogram of the first coordinate of all particles. Middle and right:
# joint histograms of the first coordinate for the particle pairs (0, 2) and (0, 1).
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].hist(rs_hf_flat[..., 0].flatten().cpu().numpy(), bins=100, range=[-2, 5])
for idx, ax in zip([(0, 2), (0, 1)], axes[1:]):
    # indexing with the tuple idx selects the two particles; .T yields (x, y)
    ax.hist2d(
        *rs_hf_flat[:, idx, 0].cpu().numpy().T, bins=100, range=[[-2, 5], [-2, 5]]
    )
    ax.set_aspect(1)

### GTO WF training

In [None]:
# LiH again, but the HFNet is constructed directly from the geometry and a GTO
# basis taken from `mol` instead of via HFNet.from_pyscf(mf).
mol = gto.M(
    atom=geomdb['LiH'].as_pyscf(),
    unit='bohr',
    basis='6-311g',
    charge=0,
    spin=0,
    cart=True,
)
mf = scf.RHF(mol)
mf.kernel()
wfnet = HFNet(geomdb['LiH'], 2, 2, GTOBasis.from_pyscf(mol)).cuda()

In [None]:
# Create a 1000-walker sampler and advance it 200 times before training.
sampler = langevin_monte_carlo(wfnet, torch.randn(1_000, 4, 3, device='cuda'), tau=0.1)
for _ in range(200):
    next(sampler)

In [None]:
# Fit the GTO-based HFNet with E_ref=-8 for 2000 samplings (Adam, lr=3e-3),
# logging to TensorBoard; a DebugContainer is passed to fit_wfnet via `debug`.
debug = DebugContainer()
with SummaryWriter('runs/LiH/gtowf/18/pretrain') as writer:
    fit_wfnet(
        wfnet,
        partial(loss_local_energy, E_ref=-8),
        torch.optim.Adam(wfnet.parameters(), lr=3e-3),
        wfnet_fit_driver_simple(sampler, n_sampling_steps=1, samplings=trange(2000)),
        writer=writer,
        debug=debug,
    )

In [None]:
# Orbitals of the fitted network over [-2, 5].
plot_func_x(wfnet.orbitals, [-2, 5], device='cuda')

In [None]:
# Sample the fitted network (1000 walkers, 500 iterations) and evaluate the
# local energy in chunks of 50,000 samples, one row per walker.
n_walker = 1_000
sampler = langevin_monte_carlo(wfnet, torch.randn(n_walker, 4, 3).cuda(), tau=0.1)
rs, _, info = samples_from(sampler, trange(500))
E_loc = batch_eval_tuple(
    local_energy, tqdm(rs.flatten(end_dim=1).split(50_000)), wfnet
)[0].view(n_walker, -1)
info.acceptance.mean()

In [None]:
# (mean local energy, std of per-walker means / sqrt(n_walker)), after the
# first 100 steps.
E_loc[:, 100:].mean().item(), (
    E_loc[:, 100:].mean(dim=1).std() / np.sqrt(E_loc.shape[0])
).item()