In [1]:
import numpy as np 
import jVMC

import flax
import jax
jax.config.update("jax_enable_x64", True)

from jax import Array, vmap, jit
import jax.numpy as jnp
import jax.random as jrnd

from flax import linen as nn

from functools import partial
from itertools import repeat
from typing import Any, List, Optional, Tuple

from jax.nn import log_softmax





In [61]:
class particle_conservation(nn.Module):
    """
    Wrapper module for symmetrization.
    This is a wrapper module for the incorporation of lattice symmetries. 
    The given plain ansatz :math:`\\psi_\\theta` is symmetrized as

        :math:`\\Psi_\\theta(s)=\\frac{1}{|\\mathcal S|}\\sum_{\\tau\\in\\mathcal S}\\psi_\\theta(\\tau(s))`

    where :math:`\\mathcal S` denotes the set of symmetry operations (``orbit`` in our nomenclature).

    Initialization arguments:
        * ``orbit``: orbits which define the symmetry operations (instance of ``util.symmetries.LatticeSymmetry``)
        * ``net``: Flax module defining the plain ansatz.
        * ``avgFun``: Different choices for the details of averaging.

    """
    #orbit: LatticeSymmetry
    net: callable
    Q: int
    
    
    #avgFun: callable = avgFun_Coefficients_Exp
    def setup(self):
        self.L = self.net.L
        self.LocalHilDim = self.net.LocalHilDim

        self.must_mask = 2 * jnp.tril(jnp.ones((self.LocalHilDim,self.LocalHilDim)),k=-1)
        self.can_mask = jnp.flip(self.must_mask)
        self.max_particles = jnp.pad((self.LocalHilDim-1)*jnp.arange(1,self.L+1)[::-1],(0,1))[1::]
        self.logProbFactor = self.net.logProbFactor
    def __post_init__(self):

        super().__post_init__()
        #self.orbit = None
        #self.is_generator = self.net.is_generator
    """
    def __call__(self, x):

        inShape = x.shape
        x = 2 * x - 1
        x = jax.vmap(lambda o, s: jnp.dot(o, s.ravel()).reshape(inShape), in_axes=(0, None))(self.orbit.orbit, x)
        x = (x + 1) // 2

        def evaluate(x):
            return self.net(x)

        res = jax.vmap(evaluate)(x)
        return self.avgFun(res, self.orbit.factor)
    """
    
    def __call__(self,*args,**kwargs):
        
        #def evaluate(x):
        #    return self.net(x)

        #res = jax.vmap(evaluate)(x)
        kwargs["output_state"] =True
        s = args[0]
        y = jnp.pad(s[:-1],(1,0),mode='constant',constant_values=0)
        cum_sum = self.Q - jnp.cumsum(y)
        #x, state = self.net(*args,output_state=True,**kwargs)
        x, state = self.net(*args,**kwargs)
        
        must_give = nn.relu(cum_sum-self.max_particles)
        can_give = jnp.minimum(cum_sum, self.LocalHilDim-1)
        mask = (self.must_mask[must_give] + 
                    self.can_mask[can_give.astype(int)]) 


        #jax.debug.print("s {x}", x=s)
        #jax.debug.print("x {x}", x=x)
        #jax.debug.print("must_give {x}", x=must_give)
        #jax.debug.print("can_give {x}", x=can_give)
        #jax.debug.print("mask {x}", x=mask)
        ##############################################
        x = x - mask ** jnp.inf
        #jax.debug.print("x after mask {x}", x=x)

        x = log_softmax(x)
        #jax.debug.print("x after softmax {x}", x=x)

        x *= self.logProbFactor
        # compute the phase in the auotregressive style
        # the log-probs according the state
        return (jnp.take_along_axis(x, jnp.expand_dims(s, -1), axis=-1)
                                .sum(axis=-2)
                                .squeeze(-1))
        return x
    def _apply_fun(self, *args,**kwargs):
        return self.net.apply(*args,**kwargs)

# ** end class SymNet
        #return sample,logits,gumbel,states
    def sample(self,numSamples: int, key) -> Array:
        """Autoregressively sample a spin configuration.

        Args:
            * ``numSamples``: The number of configurations to generate.
            * ``key``: JAX random key.

        Returns:
            A batch of spin configurations.
        """
        def generate_sample(key):
            key = jrnd.split(key, self.L)
            logits, carry = self.net(jnp.zeros(1,dtype=int),block_states = None, output_state=True)
            must_give = nn.relu(self.Q-self.max_particles[0])
            can_give = jnp.minimum(self.Q, self.LocalHilDim-1)
            mask = (self.must_mask[must_give] + 
                        self.can_mask[can_give.astype(int)]) 
            
            ##############################################
            logits = logits - mask ** jnp.inf
            choice = jrnd.categorical(key[0], logits.ravel()) # abide by the indexing convention and apply -1
            s_cumsum = self.Q - choice # create cumsum of the quantum number
            
            _, s = self._scanning_fn((jnp.expand_dims(choice,0),carry,s_cumsum),(key[1:],jnp.arange(1,self.L)))
            return jnp.concatenate([jnp.expand_dims(choice,0),s])

        # get the samples
        keys = jrnd.split(key, numSamples)
        samples = vmap(generate_sample)(keys)
        # return to the spinless representation
        return samples

    @partial(nn.scan,
             variable_broadcast='params',
             split_rngs={'params': False})
    def _scanning_fn(self, carry, key):
        logits, next_states = self.net(carry[0],block_states = carry[1], output_state=True)
        
        must_give = nn.relu(carry[2]-self.max_particles[key[1]])
        can_give = jnp.minimum(carry[2], self.LocalHilDim-1)
        mask = (self.must_mask[must_give] + 
                    self.can_mask[can_give.astype(int)]) 
        
        ##############################################
        logits = logits - mask ** jnp.inf
        choice = jrnd.categorical(key[0], logits.ravel().real) # abide by the indexing convention
        ##############################################
        s_cumsum = carry[2] - choice
        return (jnp.expand_dims(choice,0), next_states, s_cumsum), choice

In [62]:
N = 10
L = 10
ldim = 13
homFock = jnp.ones((1,1,L),dtype=int)

oneSiteFockStates = jnp.expand_dims(jnp.eye(L,dtype=int)*N,0)

J = 1.
U = 20.

H = jVMC.operator.bosons.BoseHubbard_Hamiltonian1D(L,J,U,lDim=ldim)
print('homogeneous fock:',homFock)
emb_RWKV = 16

depth_RWKV = 2

print('all one sited fock states:',oneSiteFockStates)

homogeneous fock: [[[1 1 1 1 1 1 1 1 1 1]]]
all one sited fock states: [[[10  0  0  0  0  0  0  0  0  0]
  [ 0 10  0  0  0  0  0  0  0  0]
  [ 0  0 10  0  0  0  0  0  0  0]
  [ 0  0  0 10  0  0  0  0  0  0]
  [ 0  0  0  0 10  0  0  0  0  0]
  [ 0  0  0  0  0 10  0  0  0  0]
  [ 0  0  0  0  0  0 10  0  0  0]
  [ 0  0  0  0  0  0  0 10  0  0]
  [ 0  0  0  0  0  0  0  0 10  0]
  [ 0  0  0  0  0  0  0  0  0 10]]]


In [63]:
net_RWKV_old = jVMC.nets.bosons.RWKV(L,N,LocalHilDim=ldim,num_layers=depth_RWKV,embedding_size=emb_RWKV)
seed = 252
psi_RWKV_old= jVMC.vqs.NQS(net_RWKV_old,seed=seed)
psi_RWKV_old(homFock),psi_RWKV_old(2*homFock),psi_RWKV_old(oneSiteFockStates)

(Array([[-7.60829403]], dtype=float64),
 Array([[-inf]], dtype=float64),
 Array([[ -1.28896415,  -2.47297731,  -3.65995463,  -4.84143034,
          -6.02287505,  -7.20273768,  -8.38229959,  -9.5616321 ,
         -10.74082936, -10.70980814]], dtype=float64))

In [64]:
psi_RWKV_old(homFock),psi_RWKV_old(2*homFock),psi_RWKV_old(oneSiteFockStates)

(Array([[-7.60829403]], dtype=float64),
 Array([[-inf]], dtype=float64),
 Array([[ -1.28896415,  -2.47297731,  -3.65995463,  -4.84143034,
          -6.02287505,  -7.20273768,  -8.38229959,  -9.5616321 ,
         -10.74082936, -10.70980814]], dtype=float64))

In [65]:
depth_RWKV = 2
emb_RWKV =16
hidden_size=32
num_heads = 4


net_RWKV = jVMC.nets.RpxRWKV(L,LocalHilDim=ldim,hidden_size=hidden_size,num_heads=num_heads,embedding_size=emb_RWKV,num_layers=depth_RWKV,bias=True)
seed = 252
psi_RWKV_wo_part= jVMC.vqs.NQS(net_RWKV,seed=seed)
psi_RWKV_wo_part(homFock),psi_RWKV_wo_part(2*homFock)

(Array([[-11.88401534]], dtype=float64),
 Array([[-13.79210809]], dtype=float64))

In [66]:
net_RWKV_wrap_part = particle_conservation(net_RWKV,N)
psi_RWKV_w_part= jVMC.vqs.NQS(net_RWKV_wrap_part,seed=seed)

psi_RWKV_w_part(homFock),psi_RWKV_w_part(2*homFock),psi_RWKV_w_part(oneSiteFockStates)

(Array([[-7.30447865]], dtype=float64),
 Array([[-inf]], dtype=float64),
 Array([[ -1.34503637,  -2.84606767,  -4.37602534,  -5.90778971,
          -7.43979482,  -8.97198267, -10.50431147, -12.03677132,
         -13.56935075, -13.75126752]], dtype=float64))

In [67]:
variable = net_RWKV.init(jrnd.PRNGKey(12),homFock[0,0])
variable

FrozenDict({
    params: {
        Head: {
            kernel: Array([[ 0.01079601, -0.27497497,  0.1394813 ,  0.31337826, -0.01639144,
                     0.09474495, -0.08647541, -0.23903588, -0.16516423,  0.10456421,
                     0.13778455,  0.20751816, -0.16281643],
                   [-0.02399634, -0.01461839, -0.08125063, -0.13205897,  0.21446964,
                    -0.16115794, -0.05929874,  0.14554846,  0.11821388, -0.04838565,
                    -0.11430673, -0.30064014,  0.19928252],
                   [ 0.1439655 , -0.2495926 ,  0.07319328,  0.06646912, -0.06910193,
                    -0.0580003 ,  0.13067494,  0.24328205,  0.17691969, -0.31654338,
                    -0.28949719,  0.13242956, -0.13046143],
                   [ 0.01519505, -0.37661995,  0.0758131 ,  0.31655831,  0.07716328,
                    -0.19118639, -0.13695186,  0.00582176, -0.21998471, -0.05380604,
                     0.09692861,  0.35929888, -0.12828338],
                   [ 0.053972

In [68]:

x, state = net_RWKV.apply(variable,homFock[0,0],output_state=True)

In [69]:
sym= jVMC.util.symmetries.get_orbit_1D(L,"reflection","translation")
symNet = jVMC.nets.sym_wrapper.SymNet(sym,net_RWKV_wrap_part,avgFun=jVMC.nets.sym_wrapper.avgFun_Coefficients_Sep_real)
psi_RWKV_w_part_sym= jVMC.vqs.NQS(symNet,seed=seed)
psi_RWKV_w_part_sym(homFock),psi_RWKV_w_part_sym(2*homFock),psi_RWKV_w_part_sym(oneSiteFockStates)

(Array([[-8.71411702]], dtype=float64),
 Array([[-inf]], dtype=float64),
 Array([[-2.21454926, -2.21454926, -2.21454926, -2.21454926, -2.21454926,
         -2.21454926, -2.21454926, -2.21454926, -2.21454926, -2.21454926]],      dtype=float64))

In [70]:
key2 = jrnd.PRNGKey(12)
sampler = jVMC.sampler.MCSampler(psi_RWKV_old,(L,),key2)
sampler_w_part = jVMC.sampler.MCSampler(psi_RWKV_w_part,(L,),key2)
sampler_w_part_sym = jVMC.sampler.MCSampler(psi_RWKV_w_part_sym,(L,),key2)


In [71]:
numSamp = 2**8
sampler.sample(numSamples=numSamp)

(Array([[[10,  0,  0, ...,  0,  0,  0],
         [10,  0,  0, ...,  0,  0,  0],
         [ 6,  4,  0, ...,  0,  0,  0],
         ...,
         [ 1,  2,  6, ...,  0,  0,  0],
         [ 1,  0,  5, ...,  0,  0,  0],
         [ 9,  1,  0, ...,  0,  0,  0]]], dtype=int64),
 Array([[-1.28896415, -1.28896415, -2.19339018, -3.85463611, -1.94037556,
         -1.28896415, -1.28896415, -2.26643646,        -inf,        -inf,
         -5.00892216, -4.08783269, -3.6173196 , -3.39068573, -4.65251   ,
         -3.76005808, -1.82359761,        -inf, -3.34512308, -2.68100565,
                -inf, -2.12701905, -1.82359761, -4.66927137, -2.58481351,
         -2.42405583, -3.16288205,        -inf,        -inf, -1.28896415,
                -inf, -4.06714359, -3.29534459, -3.65995463, -2.3293728 ,
         -3.79518599, -4.05622957, -3.28636606, -1.28896415, -1.94037556,
         -3.28482648, -2.9159504 , -2.50242183, -3.34512308, -2.47297731,
         -4.33761953,        -inf, -4.43852046, -3.17224214, -4.

In [72]:
sampler_w_part.sample(numSamples=numSamp)

(Array([[[ 9,  1,  0, ...,  0,  0,  0],
         [10,  0,  0, ...,  0,  0,  0],
         [ 6,  4,  0, ...,  0,  0,  0],
         ...,
         [ 1,  2,  6, ...,  0,  0,  0],
         [ 1,  0,  5, ...,  0,  0,  0],
         [ 9,  1,  0, ...,  0,  0,  0]]], dtype=int64),
 Array([[-1.5838947 , -1.34503637, -2.17336047, -3.80625988, -1.88830848,
         -1.34503637, -1.34503637, -2.96209031, -2.42223833, -1.5838947 ,
         -4.35609669, -4.20468731, -3.10844926, -2.34177747, -2.13732351,
         -3.7330561 , -2.67869019, -3.04401872, -2.97207734, -2.07817369,
         -5.229104  , -1.64340836, -1.66355012, -3.62818472, -2.09802702,
         -1.66355012, -3.04401872, -1.34503637, -1.34503637, -1.34503637,
         -1.5838947 , -4.57296226, -3.04248935, -3.73866039, -2.16498235,
         -3.8628361 , -4.37601457, -3.44051766, -1.34503637, -1.88830848,
         -3.03077776, -4.77813615, -2.19229847, -2.97207734, -2.84606767,
         -4.98950854, -3.3729789 , -4.6815263 , -3.25700343, -4.

In [73]:
sampler_w_part_sym.sample(numSamples=numSamp)

(Array([[[ 0,  0,  0, ...,  0,  0,  0],
         [ 0,  0,  0, ..., 10,  0,  0],
         [ 0,  0,  0, ...,  0,  0,  0],
         ...,
         [ 1,  0,  0, ...,  0,  0,  0],
         [ 0,  1,  0, ...,  2,  2,  5],
         [ 0,  0,  0, ...,  0,  1,  9]]], dtype=int64),
 Array([[-2.21454926, -2.21454926, -3.60187074, -5.48983565, -3.56624891,
         -2.21454926, -2.21454926, -2.21454926, -3.63090502, -2.85487978,
         -5.10688323, -5.58284947, -4.87157477, -3.60187074, -5.98696285,
         -5.3580046 , -4.36703966, -3.71705356, -4.60040677, -4.29941679,
         -6.01003361, -3.55176971, -3.01674431, -5.45904542, -4.09557494,
         -3.63090502, -4.39695293, -2.21454926, -2.21454926, -2.21454926,
         -4.90769248, -5.1717748 , -3.55176971, -2.21454926, -3.56624891,
         -5.13198958, -6.21096689, -4.67133045, -2.21454926, -3.56624891,
         -4.97679743, -3.63090502, -3.01674431, -4.60040677, -2.21454926,
         -5.71217584, -3.56624891, -5.89624683, -4.30632442, -5.