In [1]:
import numpy as np 
import jVMC
import jax.numpy as jnp
import jax.random as jrnd
import matplotlib.pyplot as plt
from tqdm import tqdm



In [2]:
N = 10
L = 10
ldim = 8
homFock = jnp.ones((1,1,L),dtype=int)
oneSiteFockStates = jnp.expand_dims(jnp.eye(L,dtype=int)*N,0)
J = 1.
U = 0.
H = jVMC.operator.bosons.BoseHubbard_Hamiltonian1D(L,J,U,lDim=ldim)

depth_RWKV = 2
emb_RWKV =16
hidden_size=32
num_heads = 4


net = jVMC.nets.RpxRWKV(L,LocalHilDim=ldim,hidden_size=hidden_size,num_heads=num_heads,embedding_size=emb_RWKV,num_layers=depth_RWKV,bias=True)
seed = 252
key2 = jrnd.PRNGKey(seed)
numSamp = 2**8


In [3]:

sym= jVMC.util.symmetries.get_orbit_1D(L,"reflection","translation")
sym_net = jVMC.nets.sym_wrapper.SymNet(sym,net,avgFun=jVMC.nets.sym_wrapper.avgFun_Coefficients_Sep_real)
gum_net = jVMC.nets.gumbel_wrapper(net)
par_net = jVMC.nets.particle_conservation(net,N)

psi_sym_net= jVMC.vqs.NQS(sym_net,seed=seed)
psi_gum_net= jVMC.vqs.NQS(gum_net,seed=seed)
psi_par_net= jVMC.vqs.NQS(par_net,seed=seed)

psi_sym_net(homFock),psi_gum_net(homFock),psi_par_net(homFock)


In [4]:
sampler_sym_net = jVMC.sampler.MCSampler(psi_sym_net,(L,),key2)
sampler_gum_net = jVMC.sampler.MCSampler(psi_gum_net,(L,),key2)
sampler_par_net = jVMC.sampler.MCSampler(psi_par_net,(L,),key2)


sampler_sym_net.sample(numSamples=numSamp),
sampler_gum_net.sample(numSamples=numSamp),
sampler_par_net.sample(numSamples=numSamp)

ddd
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd


(Array([[[4, 1, 2, ..., 0, 0, 1],
         [3, 3, 0, ..., 0, 0, 0],
         [4, 2, 1, ..., 0, 0, 0],
         ...,
         [7, 0, 0, ..., 0, 0, 0],
         [7, 1, 2, ..., 0, 0, 0],
         [1, 5, 1, ..., 0, 0, 0]]], dtype=int64),
 Array([[-4.19908259, -4.8489091 , -3.25011363, -2.64333904, -3.68303335,
         -3.71974545, -2.78519585, -2.49714454, -2.11536222, -3.39915681,
         -3.23016532, -3.88882351, -4.18678252, -5.4476143 , -5.17618972,
         -1.68036564, -4.30306016, -1.71812263, -5.41117489, -4.28778551,
         -6.11046554, -1.68036564, -4.72351633, -2.89768293, -3.10027164,
         -2.24520924, -2.88307923, -2.71378959, -4.14176517, -2.71378959,
         -3.27489602, -2.11536222, -4.14675879, -6.88517442, -2.80770231,
         -4.40028294, -2.65247853, -4.26726745, -2.78519585, -1.71812263,
         -3.45738749, -2.65000042, -2.78519585, -2.86220538, -2.65247853,
         -4.37012194, -2.3876904 , -3.42100194, -3.63192864, -3.57677257,
         -4.76191076, -3.5

In [5]:

sym_par_net = jVMC.nets.sym_wrapper.SymNet(sym,par_net,avgFun=jVMC.nets.sym_wrapper.avgFun_Coefficients_Sep_real)
gum_par_net = jVMC.nets.gumbel_wrapper(par_net)
sym_gum_net = jVMC.nets.sym_wrapper.SymNet(sym,gum_net,avgFun=jVMC.nets.sym_wrapper.avgFun_Coefficients_Sep_real)

psi_sym_par_net= jVMC.vqs.NQS(sym_par_net,seed=1)
psi_gum_par_net= jVMC.vqs.NQS(gum_par_net,seed=1)
psi_sym_gum_net= jVMC.vqs.NQS(sym_gum_net,seed=1)

psi_sym_par_net(homFock),psi_gum_par_net(homFock),psi_sym_gum_net(homFock)


ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


(Array([[-7.73128439]], dtype=float64),
 Array([[-7.73128439]], dtype=float64),
 Array([[-9.97147022]], dtype=float64))

In [6]:
type(gum_par_net.net) == "particle_conservation"

False

In [7]:
sampler_sym_par_net = jVMC.sampler.MCSampler(psi_sym_par_net,(L,),key2)
sampler_gum_par_net = jVMC.sampler.MCSampler(psi_gum_par_net,(L,),key2)
sampler_sym_gum_net = jVMC.sampler.MCSampler(psi_sym_gum_net,(L,),key2)


sampler_sym_par_net.sample(numSamples=numSamp),sampler_gum_par_net.sample(numSamples=numSamp),sampler_sym_gum_net.sample(numSamples=numSamp)

ddd
ddd
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee


((Array([[[0, 2, 4, ..., 0, 0, 0],
          [5, 3, 0, ..., 0, 0, 0],
          [0, 2, 1, ..., 0, 0, 0],
          ...,
          [0, 3, 0, ..., 0, 7, 0],
          [0, 0, 0, ..., 0, 0, 0],
          [0, 3, 1, ..., 0, 0, 0]]], dtype=int64),
  Array([[-3.85815376, -5.31880522, -4.56057926, -6.90613332, -5.25470245,
          -5.7499025 , -4.73238433, -5.62008373, -4.74057985, -3.85815376,
          -3.85815376, -4.94400121, -3.51397006, -5.67579463, -5.67029128,
          -4.18274784, -4.31422449, -4.91049585, -6.75591418, -3.70713157,
          -3.85815376, -4.09802618, -3.54048418, -5.02828572, -3.74795889,
          -3.60073298, -3.80190115, -3.35282927, -4.29720406, -3.12172309,
          -4.0755143 , -3.35282927, -6.83150639, -7.3546048 , -5.52259966,
          -4.71854332, -4.3003053 , -5.53405647, -3.79779807, -3.79779807,
          -5.48530187, -3.78953634, -3.79779807, -4.2509559 , -4.05212976,
          -5.08321458, -3.12172309, -4.58760845, -4.2426468 , -4.46747829,
         

In [8]:
sym_gum_par_net = jVMC.nets.sym_wrapper.SymNet(sym,gum_par_net,avgFun=jVMC.nets.sym_wrapper.avgFun_Coefficients_Sep_real)

psi_sym_gum_par_net= jVMC.vqs.NQS(sym_gum_par_net,seed=1)

psi_sym_par_net(homFock)

ddd


Array([[-7.73128439]], dtype=float64)

In [9]:
sampler_sym_gum_par_net = jVMC.sampler.MCSampler(psi_sym_gum_par_net,(L,),key2)

sampler_sym_gum_par_net.sample(numSamples=numSamp)

ddd
ddd
ddd
ddd
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee


(Array([[[0, 3, 5, ..., 0, 0, 0],
         [4, 6, 0, ..., 0, 0, 0],
         [0, 0, 4, ..., 0, 0, 1],
         ...,
         [1, 0, 2, ..., 0, 5, 1],
         [0, 0, 0, ..., 0, 1, 0],
         [0, 2, 2, ..., 0, 0, 0]]], dtype=int64),
 Array([[-2.30894655, -1.92093929, -4.27084127, -3.93503066, -2.27548233,
         -2.89680184, -4.45756688, -2.0025845 , -3.59227293, -4.25890079,
         -4.9551252 , -4.07401432, -2.17798461, -3.69749724, -4.85048758,
         -5.02430214, -2.98377138, -3.0750483 , -4.39774119, -2.35980781,
         -2.57710987, -2.53004455, -2.27417264, -3.09664276, -3.6636301 ,
         -3.29943239, -4.27133367, -5.80749121, -4.90623616, -3.55692201,
         -2.52841354, -6.28434773, -1.99438203, -1.73985631, -2.99915267,
         -2.83013395, -3.63863745, -2.44789913, -6.27105057, -2.2673498 ,
         -2.59522982, -2.19360038, -3.42411612, -4.84921271, -5.16949447,
         -2.87620443, -5.41210236, -2.82776929, -3.95265545, -3.19432716,
         -3.0313257 , -3.0

In [10]:
key = jrnd.PRNGKey(1)
key2 = jrnd.PRNGKey(1)

sampler = sampler_gum_par_net
psi = psi_par_net
lr_SR = 1e-5
minSR_equation = jVMC.util.MinSR(sampler, makeReal='real',diagonalShift=1e-2,diagonalMulti=1e-3)
stepperSR = jVMC.util.stepper.Euler(timeStep=lr_SR)  


In [11]:
training_steps= 10
resTraining = np.zeros((training_steps,2))
resTraining_gumbel = np.zeros((training_steps,2))

numS = 2**10
pbar = tqdm(range(training_steps))
for n,p in enumerate(pbar):
    
    dpOld = psi.get_parameters()        
    
    #print(dpOld)
    dp, _ = stepperSR.step(0, minSR_equation, dpOld, hamiltonian=H, psi=psi, numSamples=numS)
    psi.set_parameters(jnp.real(dp))
    resTraining[n] = [jnp.real(minSR_equation.ElocMean0) , minSR_equation.ElocVar0 ]
    #print(dp)
    
    pbar.set_description(f"energy: {resTraining[n][0]:.4f}+-{np.sqrt(resTraining[n][1]):.4f}")
    


  0%|          | 0/100 [00:00<?, ?it/s]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5299+-4.6697:   1%|          | 1/100 [00:16<27:10, 16.47s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5028+-4.6667:   2%|▏         | 2/100 [00:17<12:31,  7.67s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5484+-4.9231:   3%|▎         | 3/100 [00:19<07:50,  4.85s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.4102+-4.6265:   4%|▍         | 4/100 [00:20<05:37,  3.51s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5881+-4.8739:   5%|▌         | 5/100 [00:23<05:11,  3.28s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.4953+-4.6997:   6%|▌         | 6/100 [00:25<04:12,  2.69s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5800+-4.8174:   7%|▋         | 7/100 [00:26<03:32,  2.29s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5781+-4.9518:   8%|▊         | 8/100 [00:28<03:08,  2.05s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.6076+-4.9060:   9%|▉         | 9/100 [00:29<02:48,  1.85s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5664+-4.6860:  10%|█         | 10/100 [00:31<02:34,  1.72s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5761+-4.8448:  11%|█         | 11/100 [00:32<02:23,  1.61s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5597+-4.7777:  12%|█▏        | 12/100 [00:33<02:15,  1.54s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5775+-4.8248:  13%|█▎        | 13/100 [00:35<02:08,  1.48s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5432+-4.6511:  14%|█▍        | 14/100 [00:36<02:03,  1.44s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5324+-4.8005:  15%|█▌        | 15/100 [00:37<02:00,  1.42s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5720+-4.7367:  16%|█▌        | 16/100 [00:39<01:58,  1.41s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.6438+-4.9543:  17%|█▋        | 17/100 [00:40<01:57,  1.41s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5015+-4.6463:  18%|█▊        | 18/100 [00:42<01:55,  1.41s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5035+-4.8197:  19%|█▉        | 19/100 [00:43<01:54,  1.41s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5284+-4.7681:  20%|██        | 20/100 [00:45<01:52,  1.41s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5391+-4.7373:  21%|██        | 21/100 [00:46<01:50,  1.40s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5851+-4.9044:  22%|██▏       | 22/100 [00:47<01:49,  1.40s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5149+-4.8488:  23%|██▎       | 23/100 [00:49<01:46,  1.39s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5456+-4.7798:  24%|██▍       | 24/100 [00:50<01:43,  1.37s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5224+-4.7263:  25%|██▌       | 25/100 [00:51<01:41,  1.35s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5924+-4.8404:  26%|██▌       | 26/100 [00:53<01:38,  1.33s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5697+-4.8612:  27%|██▋       | 27/100 [00:54<01:35,  1.30s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5790+-4.7397:  28%|██▊       | 28/100 [00:55<01:33,  1.29s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.6215+-4.9360:  29%|██▉       | 29/100 [00:56<01:33,  1.31s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.6178+-4.9195:  30%|███       | 30/100 [00:58<01:31,  1.31s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.4756+-4.6643:  31%|███       | 31/100 [00:59<01:30,  1.31s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5930+-4.8538:  32%|███▏      | 32/100 [01:00<01:29,  1.32s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.4561+-4.6100:  33%|███▎      | 33/100 [01:02<01:28,  1.31s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.6260+-4.8187:  34%|███▍      | 34/100 [01:03<01:26,  1.31s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5077+-4.8593:  35%|███▌      | 35/100 [01:04<01:24,  1.31s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5920+-4.7863:  36%|███▌      | 36/100 [01:06<01:23,  1.31s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5603+-4.7721:  37%|███▋      | 37/100 [01:08<01:48,  1.72s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5059+-4.7073:  38%|███▊      | 38/100 [01:10<01:40,  1.62s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5416+-4.7976:  39%|███▉      | 39/100 [01:11<01:32,  1.52s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5863+-4.8251:  40%|████      | 40/100 [01:12<01:28,  1.47s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5217+-4.8539:  41%|████      | 41/100 [01:14<01:25,  1.44s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5382+-4.7950:  42%|████▏     | 42/100 [01:15<01:20,  1.38s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5522+-4.7705:  43%|████▎     | 43/100 [01:16<01:16,  1.34s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.4669+-4.7542:  44%|████▍     | 44/100 [01:17<01:13,  1.32s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5776+-4.9009:  45%|████▌     | 45/100 [01:19<01:12,  1.32s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5317+-4.7496:  46%|████▌     | 46/100 [01:20<01:10,  1.31s/it]

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


energy: -12.5234+-4.8624:  47%|████▋     | 47/100 [01:21<01:32,  1.74s/it]


eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd
ddd


KeyboardInterrupt: 

In [12]:
observalbes_dict = {
                    "H": [H], 
                    }
out_dict = jVMC.util.util.measure(observalbes_dict, psi_gum_par_net, sampler_gum_par_net)

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd


In [13]:
out_dict["H"]["mean"]

Array([-11.73034296], dtype=float64)

In [14]:
s, log_psi, psi_p = sampler_gum_par_net.sample()

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee


In [15]:
psi_p.sum(-1)

Array([1.], dtype=float64)

In [16]:
s, log_psi, psi_p = sampler_gum_par_net.sample()

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee


In [17]:
matel = H.get_O_loc(s,psi_gum_par_net,log_psi,0.)

ddd
ddd


In [18]:
sum(matel.flatten() * psi_p.flatten())

Array(-11.6678466+0.j, dtype=complex128)

In [19]:
s, log_psi, psi_p = sampler_gum_net.sample()
log_psi-psi_gum_net(s)

Array([[-1.77635684e-15,  1.77635684e-15,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00, -1.77635684e-15,
         0.00000000e+00,  0.00000000e+00,  1.77635684e-15,
         1.77635684e-15,  0.00000000e+00,  0.00000000e+00,
         1.77635684e-15,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.77635684e-15,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.77635684e-15, -1.77635684e-15,
         0.00000000e+00,  1.77635684e-15,  0.00000000e+00,
         1.77635684e-15,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00, -3.55271368e-15,
         0.00000000e+00,  1.77635684e-15,  0.00000000e+00,
        -1.77635684e-15,  0.00000000e+00,  0.00000000e+0

In [20]:
s, log_psi, psi_p = sampler_par_net.sample()
log_psi-psi_par_net(s)

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd


Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.]], dtype=float64)

In [21]:
s, log_psi, psi_p = sampler_sym_par_net.sample()
log_psi-psi_sym_par_net(s)

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd


Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.]], dtype=float64)

In [42]:
s, log_psi, psi_p = sampler_gum_par_net.sample()
log_psi-psi_gum_par_net(s),jnp.linalg.norm(log_psi-psi_gum_par_net(s))

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd


(Array([[ 0.00000000e+00, -4.44089210e-16,  0.00000000e+00,
         -8.88178420e-16, -2.22044605e-16,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         -4.44089210e-16,  8.88178420e-16,  0.00000000e+00,
         -8.88178420e-16,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00, -4.44089210e-16,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00, -4.44089210e-16,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         -4.44089210e-16,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  8.88178420e-16,  0.00000000e+00,
          0.00000000e+00,  4.44089210e-16,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  1.77635684e-15,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          4.44089210e-16,  0.00000000e+0

In [43]:

s, log_psi, psi_p = sampler_sym_gum_par_net.sample()
log_psi-psi_sym_gum_par_net(s),jnp.linalg.norm(log_psi-psi_sym_gum_par_net(s))


eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd


(Array([[ 1.4530292 ,  1.43941122,  1.38244139,  1.45216817,  1.43884892,
          1.21648226,  1.40919358,  1.11632493,  0.4410388 ,  1.40438142,
          1.45258508,  1.28532687,  1.24149338, -0.18225029, -0.73712938,
          1.33152887,  1.46638469,  1.33499006,  1.01330861,  1.26166614,
         -0.24177383,  1.31984993,  1.37465041,  1.21252696,  0.77351502,
          1.4613435 ,  1.4288821 ,  1.1681246 ,  0.92827494,  1.387197  ,
          1.30385189,  1.45497334,  0.07376508,  1.40592058,  1.44576337,
          0.06981864,  1.42598854, -0.01871426,  1.45392503,  1.08375212,
          1.29917717,  1.45951135, -0.21863893,  1.45437278,  1.32784683,
          1.15095503,  1.38352594,  1.40082638,  1.44279858,  1.44809787,
          1.10562508,  1.19043026,  1.42886742,  1.0915988 ,  1.27194684,
          1.37807778,  1.03163522,  1.44825489,  0.6821156 ,  1.20473351,
          0.99887232, -0.0695137 , -0.23785294,  1.26370556,  0.99762031,
          0.80717161,  1.43475284,  1.

In [36]:

psi_sym_gum_par_net(s),log_psi


ddd


(Array([[-3.56453123, -5.71490823, -5.18563143, -2.85728833, -3.86526082,
         -6.66995855, -6.30585801, -5.49502102, -4.98748592, -4.33754793,
         -6.66995855, -5.08078779, -4.03662122, -3.02656437, -4.98205329,
         -4.50468837, -3.20690899, -5.44963582, -3.78318421, -5.05330341,
         -3.98765535, -3.28620598, -3.80773022, -3.97146835, -3.02656437,
         -4.58946522, -3.45324181, -3.67932457, -4.68345248, -3.41177808,
         -5.23270328, -3.87032932, -4.14155936, -4.50476669, -3.56453123,
         -3.81722952, -4.18408824, -4.50265334, -5.5018704 , -3.94904738,
         -3.99015802, -7.74734369, -5.86417708, -4.59385522, -3.64393662,
         -5.17203789, -4.40049743, -4.5924636 , -3.4818365 , -4.28517245,
         -3.90142102, -6.88072226, -4.5059374 , -3.67941092, -4.60308581,
         -3.73212642, -4.00359189, -7.20214702, -3.73212642, -2.85728833,
         -4.38951603, -6.06212064, -5.50125011, -6.05315878, -4.901238  ,
         -4.23571951, -6.00682215, -7.

In [37]:

s, log_psi, psi_p = sampler_sym_gum_net.sample()
log_psi-psi_sym_gum_net(s),jnp.linalg.norm(log_psi-psi_sym_gum_net(s))


(Array([[ 0.38905178,  0.10369599, -0.10341643,  0.23546912,  0.09422586,
          0.01196003,  0.22954125,  0.22993449,  0.36646854,  0.05675554,
         -0.08525708,  0.18829086, -0.07986208,  0.11607538,  0.27442611,
          0.43634784, -0.15953018, -0.06463051, -0.02719975, -0.03175168,
          0.03021513,  0.45963807,  0.18444397,  0.06075557,  0.26223037,
          0.1000149 ,  0.55714943,  0.09839882,  0.08646461, -0.32715463,
         -0.44926808,  0.18665019, -0.34273435,  0.11635797,  0.14821961,
         -0.8045634 , -0.36953056,  0.12646108,  0.26922058, -0.0193669 ,
          0.43227823, -0.06002791, -0.06462168, -0.38497428,  0.0047674 ,
          0.34959672,  0.17159901,  0.3310727 ,  0.08374074,  0.35113672,
         -0.14444387, -0.11409634,  0.04708134,  0.05143834,  0.09525678,
          0.15001612,  0.05491732,  0.0273707 ,  0.13651502,  0.24730119,
         -0.19738408, -0.40224497,  0.01102439,  0.09071146, -0.03815406,
          0.01204656,  0.11277906,  0.

In [26]:
s, log_psi, psi_p = sampler_gum_net.sample()
log_psi-psi_gum_net(s),jnp.linalg.norm(log_psi-psi_gum_net(s))

(Array([[ 1.77635684e-15,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  3.55271368e-15,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         -3.55271368e-15,  1.77635684e-15,  0.00000000e+00,
          1.77635684e-15,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00, -1.77635684e-15,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         -1.77635684e-15, -1.77635684e-15,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00, -1.77635684e-15,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+0

In [40]:
s, log_psi, psi_p = sampler_gum_par_net.sample()
log_psi,psi_gum_par_net(s),jnp.linalg.norm(log_psi-psi_gum_par_net(s))

eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
eeee
ddd
ddd


(Array([[-3.94210889, -1.79668022, -1.96539964, -2.99511718, -2.53016665,
         -2.88207403, -1.94109619, -4.37786981, -3.44040408, -7.0760002 ,
         -4.91612725, -2.78711886, -3.21402727, -4.64044913, -1.91776038,
         -4.80297233, -2.65310981, -2.56114695, -4.95366704, -3.83374573,
         -2.21832614, -3.46833455, -5.54357904, -1.98568849, -3.90094952,
         -3.07938888, -2.85943676, -2.71080492, -2.91032711, -2.95444784,
         -2.88161107, -2.64404931, -3.06655018, -5.1953147 , -4.08403321,
         -3.03849127, -2.43592063, -4.90354355, -3.2385217 , -3.35825654,
         -3.03371354, -2.15359106, -4.25554717, -5.04980893, -3.50398066,
         -3.05715471, -5.52018569, -5.5810453 , -2.87149806, -4.54244574,
         -2.95227474, -3.00181782, -2.67122153, -2.05090189, -4.10060529,
         -4.68474971, -3.83931713, -3.67850962, -3.31035292, -2.9061601 ,
         -2.95508743, -4.27173237, -3.48437551, -6.2798976 , -3.68196784,
         -3.61373302, -2.90616105, -2.

In [28]:
psi_gum_par_net(s)

ddd


Array([[-3.4897816 , -2.64404931, -5.86110948, -2.41485172, -5.29769603,
        -2.30930369, -2.15359106, -5.02622702, -2.21832614, -1.96539964,
        -1.98568849, -4.14406984, -2.91032711, -2.72549654, -6.02850292,
        -2.82093517, -1.94109619, -2.36758868, -3.22002327, -1.79668022,
        -4.55030577, -2.56114695, -3.71851512, -4.60756634, -5.08754813,
        -2.65310981, -3.47939066, -3.34403749, -3.34544123, -3.43802157,
        -3.74589068, -2.94164414, -2.28221522, -3.11820872, -3.17529649,
        -5.14583303, -2.72127173, -3.08576862, -3.40762833, -3.27080102,
        -4.71405424, -3.05715471, -2.54425594, -3.47400008, -4.43985968,
        -5.18188461, -4.23356146, -2.85943676, -2.53016665, -3.68523031,
        -3.13213836, -5.33493286, -3.05212227, -2.90616105, -3.62090676,
        -3.31896121, -6.89660526, -3.19808622, -4.24491715, -3.91139786,
        -2.67253245, -3.31086898, -6.26471313, -3.94924487, -5.9333759 ,
        -3.72872124, -1.91776038, -3.44040408, -3.8