In [31]:
import numpy as np
import jax.numpy as jnp
import fvmc
from fvmc.utils import load_pickle, displace_matrix, gen_kidx, split_spin
from fvmc.ewaldsum import gen_lattice_displacements, gen_pbc_disp_fn
from test_hegspin import _common_spin_config, _common_config
from fvmc.wavefunction.heg import heg_rs

In [32]:
# System size and optimisation settings; the number of k-points equals the
# number of electrons.
n_elec = 12
n_k = n_elec
seed = 42
iterations = 100
cfg = _common_config(
    n_elec=n_elec,
    n_k=n_k,
    seed=seed,
    iterations=iterations,
)

n_up,n_dn 12 0
coeeff shape (12, 12)


In [54]:
# Lattice, k-point list and Gaussian-interaction parameters.
n_lat = 1
cell = latvec = jnp.asarray(cfg.system.cell)
recvec = jnp.linalg.inv(latvec).T
kpts = jnp.asarray(gen_kidx(2, n_k, close_shell=False))  # n_dim = 2
klist = 2 * jnp.pi * kpts @ recvec  # [n_k, n_dim]

disp_fn = gen_pbc_disp_fn(latvec)
lat_disp = gen_lattice_displacements(latvec, n_lat)
lat_norm = jnp.linalg.norm(lat_disp, axis=-1)
lat_norm = lat_norm[lat_norm > 0]  # keep only non-zero lattice displacements

# Gaussian interaction parameters; gaussian_q reads these globals at call time.
# A single assignment each, so the value in effect is unambiguous.
lamb = 0.2  # other values tried: 0.04, 0.001
pref = 1    # other value tried: 1e2
# TODO: self-image term is currently disabled:
# self_img = pref * jnp.sum(jnp.exp(-0.5 * lamb * lat_norm * lat_norm))

In [55]:
# Number of spin components; coeff carries no spin axis, so one of length 1 is added below.
n_sigma = 1
lx = cell[0, 0]
ly = cell[1, 1]
area = lx * ly  # NOTE(review): uses only the diagonal of cell, i.e. assumes a rectangular cell
dkx = 2 * np.pi / lx
dky = 2 * np.pi / ly
# The q grid must reach every difference k1 - k2, i.e. twice the largest |k component|
# measured in units of dk. For k-points on the reciprocal grid this ratio is an integer
# only up to rounding; int() would truncate e.g. 1.9999999 to 1 and silently drop the
# outer q shell, so round up with a small tolerance instead.
k_ratio = float(np.max(np.abs(klist)) / min(dkx, dky))
n_qmax = int(np.ceil(k_ratio - 1e-6)) * 2
q1x = np.arange(-n_qmax, n_qmax + 1) * dkx
q1y = np.arange(-n_qmax, n_qmax + 1) * dky
qlist = np.stack(np.meshgrid(q1x, q1y, indexing='ij'), axis=-1).reshape(-1, 2)
n_q = qlist.shape[0]
print(klist.shape, qlist.shape)

# Orbital coefficients [n_k, n_orb]; prepend the spin axis -> w: [n_sigma, n_k, n_orb].
coeff = cfg.restart.params['params']['orbital_fn']['VmapDense_0']['kernel']
print(coeff.shape)
w = coeff[None, ...]
print('w', w.shape)
n_orb = w.shape[2]
# Orthonormalise the orbitals jointly over (spin, k) with a reduced QR; this needs
# at least as many rows as orbitals, otherwise the reshape below cannot succeed.
if n_orb > n_sigma * n_k:
    raise ValueError(
        f"cannot orthonormalise {n_orb} orbitals in a {n_sigma * n_k}-dimensional basis")
_w = w.reshape(n_sigma * n_k, n_orb)
_q, _r = np.linalg.qr(_w)
w = _q.reshape(n_sigma, n_k, n_orb)
print(w.shape)

(12, 2) (81, 2)
(81, 2)
(12, 12)
w (1, 12, 12)
(1, 12, 12)


In [56]:
def _check_in_arr(el, ar):
        """
        el : dim 2
        ar : dim Nx2
        """
        diff = np.linalg.norm(el - ar, axis=1)
        return np.any(np.isclose(0, diff))

def _get_idx_in_arr(el, ar):
        """
        el : dim 2
        ar : dim Nx2
        """
        diff = np.linalg.norm(el - ar, axis=1)
        idx = np.argmin(diff)
        assert np.isclose(diff[idx], 0)
        return idx

In [57]:
def gaussian_q(qvec, width=None, amplitude=None):
    """Momentum-space Gaussian interaction.

    Returns ``amplitude * 2*pi/width * exp(-|q|**2 / (2*width))``, which is the
    2D Fourier transform of ``amplitude * exp(-width * r**2 / 2)``.

    Parameters
    ----------
    qvec : array_like, shape (..., 2)
        Momentum vector(s); the norm is taken over the last axis.
    width : float, optional
        Gaussian exponent. Defaults to the notebook global ``lamb``, read at
        call time (the original behaviour), so existing calls are unchanged.
    amplitude : float, optional
        Interaction prefactor. Defaults to the notebook global ``pref``, read
        at call time.

    Returns
    -------
    ndarray of shape ``qvec.shape[:-1]``.
    """
    # Fall back to the notebook globals only when not given explicitly, so the
    # function can also be used without hidden notebook state.
    if width is None:
        width = lamb
    if amplitude is None:
        amplitude = pref
    qnorm = np.linalg.norm(qvec, axis=-1)
    return 2 * np.pi / width * amplitude * np.exp(-qnorm**2 / (2 * width))

In [58]:
def _gather_shifted_w(w, sign):
    """Gather orbital coefficients at shifted momenta ``klist[ik] + sign * qlist[iq]``.

    Shared implementation of get_w_plus (sign=+1) and get_w_minus (sign=-1).

    Parameters
    ----------
    w : array, shape (n_sigma, n_k, n_orb)
    sign : +1 or -1

    Returns
    -------
    complex ndarray ``out`` of shape (n_sigma, len(klist), len(qlist), n_orb) with
    ``out[:, ik, iq, :] = w[:, j, :]`` where ``klist[j]`` coincides with the shifted
    momentum (``np.isclose(0, distance)``), and zeros where no k-point coincides.
    """
    kvecs = np.asarray(klist)   # [n_k, 2]
    qvecs = np.asarray(qlist)   # [n_q, 2]
    coeffs = np.asarray(w)      # [n_sigma, n_k, n_orb]
    # All shifted momenta at once: [n_k, n_q, 2]
    shifted = kvecs[:, None, :] + sign * qvecs[None, :, :]
    # Distance from every shifted momentum to every k-point: [n_k, n_q, n_k]
    dist = np.linalg.norm(shifted[:, :, None, :] - kvecs[None, None, :, :], axis=-1)
    src = np.argmin(dist, axis=-1)                                     # nearest k-point
    nearest = np.take_along_axis(dist, src[..., None], axis=-1)[..., 0]
    # Same criterion as the per-element search: the nearest k-point must be isclose.
    ik, iq = np.nonzero(np.isclose(0, nearest))
    out = np.zeros((coeffs.shape[0], kvecs.shape[0], qvecs.shape[0], coeffs.shape[2]),
                   dtype=complex)
    out[:, ik, iq, :] = coeffs[:, src[ik, iq], :]
    return out

def get_w_plus(w):
    """w_plus[s, k, q, a] = w[s, index of k + q, a], zero where k + q is not in klist."""
    return _gather_shifted_w(w, +1)

def get_w_minus(w):
    """w_minus[s, k, q, a] = w[s, index of k - q, a], zero where k - q is not in klist."""
    return _gather_shifted_w(w, -1)

In [59]:
def calc_e_kin(w):
    """Total kinetic energy of the orbitals ``w``.

    Computes sum over (s, k, a) of |w[s, k, a]|**2 * |klist[k]|**2 / 2.

    Parameters
    ----------
    w : array, shape (n_sigma, n_k, n_orb)

    Returns
    -------
    Real scalar. ``w.conj() * w`` has an exactly zero imaginary part, so only
    its real part is used; this keeps the result real for complex ``w``, as
    calc_e_pot already does for its terms.
    """
    ekin_pref = (klist[:, 0]**2 + klist[:, 1]**2) / 2  # |k|^2 / 2 per k-point
    density = (w.conj() * w).real                     # |w|^2
    return np.einsum('ijk, j ->', density, ekin_pref)

In [60]:
def calc_e_pot(w):
    """Potential-energy terms for the orbitals ``w``.

    Parameters
    ----------
    w : array, shape (n_sigma, n_k, n_orb)

    Returns
    -------
    (e_ha, e_fo, e_bg)
        e_ha : Hartree-type term, sum_q v(q)/(2A) rho(q) rho(-q)
        e_fo : exchange-type term (with a minus sign)
        e_bg : -v(0) * n_elec**2 / (2A), returned as a Python float
        where A is ``area`` and v is ``gaussian_q``.
    """
    w_plus = get_w_plus(w)     # w_plus[s, i, j, a]  = w[s, k_i + q_j, a]
    w_minus = get_w_minus(w)   # w_minus[t, l, j, b] = w[t, k_l - q_j, b]
    v_pref = 1 / (2 * area) * gaussian_q(qlist)
    # optimize=True lets einsum contract pairwise rather than loop over all
    # seven indices jointly; the result is the same up to rounding.
    e_ha = np.einsum('sija, sia, j, tljb, tlb',
                     w_plus.conj(), w, v_pref, w_minus.conj(), w,
                     optimize=True).real
    e_fo = -np.einsum('sija, sib, j, tljb, tla',
                      w_plus.conj(), w, v_pref, w_minus.conj(), w,
                      optimize=True).real
    e_bg = -0.5 * gaussian_q(np.array([0, 0])) * n_elec**2 / area
    # area may be a jax scalar; return a plain float like the other terms.
    return e_ha, e_fo, float(e_bg)

In [61]:
# Load a checkpoint file from a hard-coded absolute cluster path.
# NOTE(review): absolute path is not portable — consider a configurable directory.
# NOTE(review): load_pickle presumably unpickles the file, which can execute
# arbitrary code — only use trusted checkpoints; confirm.
fchk = '/mnt/home/csmith1/ceph/excitedStates/fvmc/tests/wavefunction/tmp/checkpoint.pkl'
params = load_pickle(fchk)

In [62]:
# Orbital kernel from the loaded checkpoint (element 1 of params, indexed as a params dict).
# NOTE(review): `coeffs` is not referenced by any cell shown here; the energies are
# computed from `w`, which comes from cfg.restart.params — confirm which source is intended.
coeffs = params[1]['params']['orbital_fn']['VmapDense_0']['kernel']
#print(coeffs)

In [63]:
#print(w)

In [66]:
# Reference energies for the orthonormalised orbitals w.
# calc_e_pot rebuilds w_plus / w_minus on every call, so evaluate it once and reuse.
ref_epot_terms = calc_e_pot(w)  # (e_ha, e_fo, e_bg)
print(ref_epot_terms)
ref_epot = np.sum(ref_epot_terms)
ref_ekin = calc_e_kin(w)
ref_etot = ref_ekin + ref_epot
print("reference:")
print("ekin:", ref_ekin)
print("epot:", ref_epot)
print("etot:", ref_etot)

# Uncontracted Hartree / exchange tensors exported for the pyscf comparison.
# TODO: absolute output directory; make configurable.
PYSCF_DIR = '/mnt/home/csmith1/ceph/excitedStates/pyscf'
w_plus = get_w_plus(w)
w_minus = get_w_minus(w)
v_pref = 1 / (2 * area) * gaussian_q(qlist)
print(w_plus.shape, w.shape, v_pref.shape, w_minus.shape, w.shape)
e_ha = np.einsum('sija, sia, j, tljb, tlb->ialb',
                 w_plus.conj(), w, v_pref, w_minus.conj(), w,
                 optimize=True).real
print('eha sum:', e_ha.sum())
np.save(f'{PYSCF_DIR}/e_ha.npy', e_ha)

e_fo = -np.einsum('sija, sib, j, tljb, tla->iabl',
                  w_plus.conj(), w, v_pref, w_minus.conj(), w,
                  optimize=True).real
print(e_fo.shape)
print(e_fo.sum())
np.save(f'{PYSCF_DIR}/e_fo.npy', e_fo)

# Background term, taken from the single calc_e_pot evaluation above.
print('ebg', ref_epot_terms[2])


(0.0666666666666667, -0.06592112202089512, Array(-0.06666667, dtype=float64))
reference:
ekin: 0.01377143961733388
epot: -0.0659211220208951
etot: -0.05214968240356122
(1, 12, 81, 12) (1, 12, 12) (81,) (1, 12, 81, 12) (1, 12, 12)
eha sum: 0.06666666666666671
(12, 12, 12, 12)
-0.06592112202089509
ebg -0.06666666666666668


In [65]:
# Per-orbital kinetic energy: hcore[a] = sum_{s,k} |w[s,k,a]|^2 * |k|^2 / 2.
# Its sum should match calc_e_kin(w) (compare the printed values).
# NOTE(review): this saves only a length-n_orb vector (the diagonal). If the pyscf side
# expects an (n_orb, n_orb) core Hamiltonian, the off-diagonal elements
# sum_{s,k} conj(w[s,k,a]) * w[s,k,b] * |k|^2 / 2 are missing — confirm.
ekin_pref = (klist[:, 0]**2 + klist[:, 1]**2) / 2
hcore = jnp.einsum('ijk, j -> k', w.conj() * w, ekin_pref)
print(hcore.sum())
print((w.conj() * w).shape)
np.save('/mnt/home/csmith1/ceph/excitedStates/pyscf/hcore.npy',hcore)

0.013771439617333882
(1, 12, 12)
