In [1]:
import numpy as np 
import jVMC
import jax.numpy as jnp
import jax.random as jrnd
import matplotlib.pyplot as plt
from tqdm import tqdm



In [2]:
N = 10
L = 10
ldim = 8
homFock = jnp.ones((1,1,L),dtype=int)
oneSiteFockStates = jnp.expand_dims(jnp.eye(L,dtype=int)*N,0)
J = 1.
U = 0.
H = jVMC.operator.bosons.BoseHubbard_Hamiltonian1D(L,J,U,lDim=ldim)

depth_RWKV = 2
emb_RWKV =16
hidden_size=32
num_heads = 4


net = jVMC.nets.RpxRWKV(L,LocalHilDim=ldim,hidden_size=hidden_size,num_heads=num_heads,embedding_size=emb_RWKV,num_layers=depth_RWKV,bias=True)
seed = 252
key2 = jrnd.PRNGKey(seed)
numSamp = 2**8


In [3]:
sym= jVMC.util.symmetries.get_orbit_1D(L,"reflection","translation")
sym_net = jVMC.nets.sym_wrapper.SymNet(sym,net,avgFun=jVMC.nets.sym_wrapper.avgFun_Coefficients_Sep_real)
gum_net = jVMC.nets.gumbel_wrapper(net)
par_net = jVMC.nets.particle_conservation(net,N)

psi_sym_net= jVMC.vqs.NQS(sym_net,seed=seed)
psi_gum_net= jVMC.vqs.NQS(gum_net,seed=seed)
psi_par_net= jVMC.vqs.NQS(par_net,seed=seed)

psi_sym_net(homFock),psi_gum_net(homFock),psi_par_net(homFock)


(Array([[-9.55769659]], dtype=float64),
 Array([[-9.55769659]], dtype=float64),
 Array([[-6.91764096]], dtype=float64))

In [4]:
sampler_sym_net = jVMC.sampler.MCSampler(psi_sym_net,(L,),key2)
sampler_gum_net = jVMC.sampler.MCSampler(psi_gum_net,(L,),key2)
sampler_par_net = jVMC.sampler.MCSampler(psi_par_net,(L,),key2)


sampler_sym_net.sample(numSamples=numSamp),
sampler_gum_net.sample(numSamples=numSamp),
sampler_par_net.sample(numSamples=numSamp)

(Array([[[4, 1, 2, ..., 0, 0, 1],
         [3, 3, 0, ..., 0, 0, 0],
         [4, 2, 1, ..., 0, 0, 0],
         ...,
         [7, 0, 0, ..., 0, 0, 0],
         [7, 1, 2, ..., 0, 0, 0],
         [1, 5, 1, ..., 0, 0, 0]]], dtype=int64),
 Array([[-4.25427984, -4.38302612, -3.59688655, -3.45564979, -3.21237142,
         -3.95906369, -2.75755163, -2.54363527, -2.29970936, -3.79131218,
         -3.79742102, -3.9861766 , -4.35102625, -4.89099899, -4.91961068,
         -2.53380661, -3.83722561, -2.7452417 , -4.87902445, -4.49060097,
         -5.07962711, -2.53380661, -4.3695826 , -2.89019481, -3.38656752,
         -2.75269043, -2.6067025 , -2.73556202, -4.06247653, -2.73556202,
         -3.24813742, -2.29970936, -4.42149826, -6.41539808, -3.30678576,
         -4.5991015 , -3.24097709, -3.76586689, -2.75755163, -2.7452417 ,
         -3.20909566, -2.42565637, -2.75755163, -3.02687588, -3.24097709,
         -3.62728828, -2.11224357, -2.83875549, -3.21829511, -3.39620564,
         -4.56168626, -3.4

In [5]:

sym_par_net = jVMC.nets.sym_wrapper.SymNet(sym,par_net,avgFun=jVMC.nets.sym_wrapper.avgFun_Coefficients_Sep_real)
gum_par_net = jVMC.nets.gumbel_wrapper(par_net)
sym_gum_net = jVMC.nets.sym_wrapper.SymNet(sym,gum_net,avgFun=jVMC.nets.sym_wrapper.avgFun_Coefficients_Sep_real)

psi_sym_par_net= jVMC.vqs.NQS(sym_par_net,seed=1)
psi_gum_par_net= jVMC.vqs.NQS(gum_par_net,seed=1)
psi_sym_gum_net= jVMC.vqs.NQS(sym_gum_net,seed=1)

psi_sym_par_net(homFock),psi_gum_par_net(homFock),psi_sym_gum_net(homFock)


(Array([[-7.75558061]], dtype=float64),
 Array([[-7.75558061]], dtype=float64),
 Array([[-9.97147022]], dtype=float64))

In [6]:
type(gum_par_net.net) == "particle_conservation"

False

In [7]:
sampler_sym_par_net = jVMC.sampler.MCSampler(psi_sym_par_net,(L,),key2)
sampler_gum_par_net = jVMC.sampler.MCSampler(psi_gum_par_net,(L,),key2)
sampler_sym_gum_net = jVMC.sampler.MCSampler(psi_sym_gum_net,(L,),key2)


sampler_sym_par_net.sample(numSamples=numSamp),sampler_gum_par_net.sample(numSamples=numSamp),sampler_sym_gum_net.sample(numSamples=numSamp)

((Array([[[0, 2, 4, ..., 0, 0, 0],
          [5, 3, 0, ..., 0, 0, 0],
          [0, 2, 1, ..., 0, 0, 0],
          ...,
          [0, 3, 0, ..., 0, 7, 0],
          [0, 0, 0, ..., 0, 0, 0],
          [0, 3, 1, ..., 0, 0, 0]]], dtype=int64),
  Array([[-3.71108913, -4.90505547, -4.49303221, -6.07461569, -5.39201945,
          -5.6223428 , -4.32299773, -5.20345524, -4.94216978, -3.71108913,
          -3.71108913, -4.55951793, -3.80736862, -5.67792881, -5.74253178,
          -4.12523399, -4.20492811, -4.94996751, -6.35164571, -4.03463884,
          -3.71108913, -3.99895004, -3.60476323, -4.85110188, -3.6740067 ,
          -3.44592623, -3.78818675, -3.85106897, -3.75823389, -3.03824082,
          -4.21559352, -3.85106897, -6.33289696, -6.47665541, -5.3889676 ,
          -4.44445435, -4.05926971, -5.28871567, -3.71622625, -3.71622625,
          -5.44319004, -3.76083489, -3.71622625, -4.521701  , -3.62789358,
          -4.80118261, -3.03824082, -4.03950795, -4.21086863, -4.48032598,
         

In [8]:
sym_gum_par_net = jVMC.nets.sym_wrapper.SymNet(sym,gum_par_net,avgFun=jVMC.nets.sym_wrapper.avgFun_Coefficients_Sep_real)

psi_sym_gum_par_net= jVMC.vqs.NQS(sym_gum_par_net,seed=1)

psi_sym_par_net(homFock)

Array([[-7.75558061]], dtype=float64)

In [9]:
sampler_sym_gum_par_net = jVMC.sampler.MCSampler(psi_sym_gum_par_net,(L,),key2)

sampler_sym_gum_par_net.sample(numSamples=numSamp)

(Array([[[0, 3, 5, ..., 0, 0, 0],
         [4, 6, 0, ..., 0, 0, 0],
         [0, 0, 4, ..., 0, 0, 1],
         ...,
         [2, 2, 0, ..., 0, 5, 1],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 2, 0, ..., 0, 0, 1]]], dtype=int64),
 Array([[-12.7079954 , -13.64668567, -12.80411162, -13.05412547,
         -11.89667107, -12.53145401, -12.18742051, -12.75881141,
         -11.54136596, -12.22002883, -11.39737878, -12.05189932,
         -11.52804932, -13.01022532, -12.7134786 , -11.4758443 ,
         -12.95890863, -11.79237393, -12.41169387, -11.16058227,
         -12.00722392, -12.39897774, -11.31083725, -12.65393687,
         -12.19808642, -12.17946158, -11.95494066, -12.18519629,
         -11.92741434, -11.02943826, -13.2034721 , -12.46731396,
         -12.16006062, -12.22675696, -11.74073327, -13.34242517,
         -13.86797509, -12.10363418, -12.58495965, -11.99624644,
         -13.30321424, -12.54113638, -12.8687098 , -12.66988889,
         -11.27646036, -11.84306573, -12.50245477, 

In [10]:
key = jrnd.PRNGKey(1)
key2 = jrnd.PRNGKey(1)

sampler = sampler_par_net
psi = psi_par_net
lr_SR = 1e-2
minSR_equation = jVMC.util.MinSR(sampler, makeReal='real',diagonalShift=1e-2,diagonalMulti=1e-3)
stepperSR = jVMC.util.stepper.Euler(timeStep=lr_SR)  


In [11]:
training_steps= 100
resTraining = np.zeros((training_steps,2))
resTraining_gumbel = np.zeros((training_steps,2))

numS = 2**10
pbar = tqdm(range(training_steps))
for n,p in enumerate(pbar):
    
    dpOld = psi.get_parameters()        
    
    #print(dpOld)
    dp, _ = stepperSR.step(0, minSR_equation, dpOld, hamiltonian=H, psi=psi, numSamples=numS)
    psi.set_parameters(jnp.real(dp))
    resTraining[n] = [jnp.real(minSR_equation.ElocMean0) , minSR_equation.ElocVar0 ]
    #print(dp)
    
    pbar.set_description(f"energy: {resTraining[n][0]:.4f}+-{np.sqrt(resTraining[n][1]):.4f}")
    


  0%|          | 0/100 [00:00<?, ?it/s]

energy: -18.0202+-1.2522:  43%|████▎     | 43/100 [01:47<01:58,  2.07s/it]