In [1]:
import jVMC
import jax
#from jax.config import config
jax.config.update("jax_enable_x64", True)
import jax.random as jrnd
import jax.random as random

import jax.numpy as jnp
from functools import partial
import numpy as np
import jVMC.global_defs as global_defs
from tqdm import tqdm
import matplotlib.pyplot as plt
@jax.jit
def sorting_gumble(sample,logits,gumbel,states):
    indexes = jnp.argsort((-gumbel),axis=None)#.reshape(shape_gumbel)
    numSamples = sample.shape[0]
    ldim = sample.shape[1]
    L = sample.shape[2]
    #jax.debug.print("shape {x},{y},{z}",x=numSamples,y=ldim,z=L)
    indexes_states = (indexes // ldim)[:numSamples]
    sample = sample.reshape(-1,L)[indexes]
    sample = sample.reshape(ldim,numSamples,L)
    sample = jnp.swapaxes(sample,0,1)
    
    logits = logits.ravel()[indexes]
    logits = logits.reshape(ldim,numSamples).T
    
    gumbel = gumbel.ravel()[indexes]
    gumbel = gumbel.reshape(ldim,numSamples).T

    vals, treedef  = jax.tree_util.tree_flatten(states)
    vals_ordered = [v[indexes_states] for v in vals]
    states = jax.tree_util.tree_unflatten(treedef,vals_ordered)
    return sample,logits,gumbel,states

@partial(jax.jit, static_argnums=1)
def scan_fn(fct,carry, key):
    position = key[1]
    sample = carry[0]
    #jax.debug.print("{x}", x=sample)

    logits = carry[1]
    gumbel = carry[2]
    states = carry[3]
    keys = jrnd.split(key[0],(carry[0].shape[0],1))

    p_workN = partial(fct,position=position)
    sample,logits,gumbel,states = jax.vmap(p_workN)(sample,logits,gumbel,keys,states)
    #jax.debug.print("gumbelnew new {x}", x=gumbel)

    #### sorting gumble value
    return sorting_gumble(sample,logits,gumbel,states),None
    

In [2]:
class MCSampler_gumbel(jVMC.sampler.MCSampler):
    #change name to appeal Jonas
    def __init__(self,*args,**kwargs):
        super().__init__(*args, **kwargs)
        try: 
            callable(self.net.net.forward_with_state)
        except:
            try:
                callable(self.net.net.net.forward_with_state)
            except:
                raise Exception("neural network has no autoregressive stepwise sampling subroutine 'forward_with_state'")
    def setup_gumble(self):
        try:
            self.ldim = self.net.net.LocalHilDim
            self.lowNet = self.net.net
            self.psiNet = self.net

        except:
            self.lowNet = self.net.net.net
            self.psiNet = self.net.net

        self.ldim = self.lowNet.LocalHilDim
        self.Q = self.lowNet.Q
        #self.netcall = self.lowNet.apply
        self.netcall = self.net.net.apply
        
        self.L = self.sampleShape[0]
        
        self.max_particles = jnp.pad((self.ldim-1)*jnp.arange(1,self.L+1)[::-1],(0,1))[1::]
        self.must_mask = 2 * jnp.tril(jnp.ones((self.ldim,self.ldim)),k=-1)
        self.can_mask = jnp.flip(self.must_mask)
        print(self.Q)
    def _workN(self,sample,logits,gumbel,key,states,position):
        
        particles_left = self.Q - jnp.sum(sample[0]+jnp.abs(sample[0]))//2
        #new samples with (0,..,ldim-1) at position
        sample = jnp.array([sample[0].at[position].set(l) for l in jnp.arange(self.ldim)])
        #right shifted input
        inputt = jnp.array([jnp.pad(sample[0,:-1],(1,0))])
        # sampling outside psi --> need to give parameters manually
        para = self.net.parameters
        #if "net" in para:
        #    paraA = para["net"]
        #else:
        #    paraA = para
        logitnew, next_states = self.netcall(para,inputt[:,position],block_states = states, output_state=True)
        #logitnew, next_states = self.net.net(inputt[:,position],block_states = states, output_state=True)
    
        #jax.debug.print("pl: {x}",x=particles_left)
        must_give = jax.nn.relu(particles_left-self.max_particles[position])

        # number of particles that can be assigned
        can_give = jnp.minimum(particles_left, self.ldim-1)
        mask = (self.must_mask[must_give] + 
                    self.can_mask[can_give.astype(int)]) 
        #logit_renorm =  logitnew - mask ** jnp.inf
        #jax.debug.print("mask {x}", x=mask)
        logitnew =  jax.nn.log_softmax(logitnew - mask ** jnp.inf)
        #jax.debug.print("logits new {x}", x=logitnew)
        #jax.debug.print("logits[0] {x}", x=logits[0])
        
        logitnew = logits[0] + logitnew - mask ** jnp.inf
        #jax.debug.print("logits new {x}", x=logitnew)

        gumbelnew = logitnew + jrnd.gumbel(key[0],shape=(self.ldim,))  ## logitnew or logits[0]?
        #gumbelnew = logits[0] + jrnd.gumbel(key[0],shape=(ldim,))
        #jax.debug.print("gumbelnew new {x}", x=gumbelnew)

        Z = jnp.nanmax(gumbelnew)
        gumbelnew = jnp.nan_to_num(-jnp.log(
            jnp.exp(-gumbel[0])-jnp.exp(-Z)+jnp.exp(-gumbelnew) # -sign at gumbelnew??
            ),nan=-jnp.inf)
        #gumbelnew = (-jnp.log(jnp.exp(-gumbel[0])-jnp.exp(-Z)+jnp.exp(gumbelnew)))
        gumbelnew = gumbelnew- mask ** jnp.inf
        return sample, logitnew, gumbelnew, next_states
        #return sample,logits,gumbel,states
    
    def ee_scan_fn(self,carry, key):
        position = key[1]
        sample = carry[0]
        #jax.debug.print("{x}", x=sample)

        logits = carry[1]
        gumbel = carry[2]
        states = carry[3]
        keys = jrnd.split(key[0],(carry[0].shape[0],1))

        p_workN = partial(self._workN,position=position)
        sample,logits,gumbel,states = jax.vmap(p_workN)(sample,logits,gumbel,keys,states)
        #jax.debug.print("gumbelnew new {x}", x=gumbel)

        #### sorting gumble value
        return sorting_gumble(sample,logits,gumbel,states),None
    
    def gumbel(self,numSamples, tmpKey, parameters,outputGumbel=False):
        shape_samples = (numSamples,self.ldim,self.L)
        shape_logits = (numSamples,self.ldim)
        shape_gumbel = (numSamples,self.ldim)
        #print(shape_samples,shape_logits)
        working_space_samples = jnp.full(shape_samples,-2,dtype=jnp.int64)
        working_space_logits = jnp.full(shape_logits,-jnp.inf,dtype=jnp.float64)
        working_space_gumbel = jnp.full(shape_gumbel,-jnp.inf,dtype=jnp.float64)
        
        working_space_samples = working_space_samples.at[0,0,0].set(0)
        working_space_logits = working_space_logits.at[0,0].set(0.)
        working_space_gumbel = working_space_gumbel.at[0,0].set(0.)
        
        keys = jrnd.split(tmpKey[0],(self.L))
        states = None
        init_work = partial(self._workN, position=0,states=states)
        key0=jrnd.split(keys[0],(numSamples,1))
        
        sample,logits,gumbel,states  = jax.vmap(init_work)(working_space_samples,working_space_logits,working_space_gumbel,key0)
        #jax.debug.print("out1 logits {x}",x=logits)

        #jax.debug.print("out1 gumbel {x}",x=gumbel)
        
        init_carry = sorting_gumble(sample,logits,gumbel,states)
        #res, _ = jax.lax.scan(self.scan_fn,init_carry,(keys[1:],jnp.arange(1,self.L)))
        
        scan_fnP = partial(scan_fn,self._workN)
        res, _ = jax.lax.scan(scan_fnP,init_carry,(keys[1:],jnp.arange(1,self.L)))
        if outputGumbel==1:
            return res#,res[1][:,0]
        if outputGumbel==2:
            return jnp.expand_dims(res[0][:,0,:],0), res[2][0,1]
            
        return jnp.expand_dims(res[0][:,0,:],0)#,res[1][:,0]
        
    def _get_samples_gen_gumbel(self,parameters, numSamples, multipleOf):
        #numSamples = mpi.distribute_sampling(numSamples, localDevices=global_defs.device_count(), numChainsPerDevice=multipleOf)

        tmpKeys = random.split(self.key[0], 3 * global_defs.device_count())
        self.key = tmpKeys[:global_defs.device_count()]
        tmpKey = tmpKeys[global_defs.device_count():2 * global_defs.device_count()]
        tmpKey2 = tmpKeys[2 * global_defs.device_count():]

        #### rewrite explicit scanning in here
        samples, kappa = self.gumbel(numSamples, tmpKey, parameters=parameters,outputGumbel=2)
        ####
        if not str(numSamples) in self._randomize_samples_jitd:
            self._randomize_samples_jitd[str(numSamples)] = global_defs.pmap_for_my_devices(self._randomize_samples, static_broadcasted_argnums=(), in_axes=(0, 0, None))

        if not self.orbit is None:
            return self._randomize_samples_jitd[str(numSamples)](samples, tmpKey2, self.orbit)
        
        return samples, kappa
    def sample(self, parameters=None, numSamples=None, multipleOf=1):
        """Generate random samples from wave function.

        If supported by ``net``, direct sampling is peformed. Otherwise, MCMC is run \
        to generate the desired number of samples. For direct sampling the real part \
        of ``net`` needs to provide a ``sample()`` member function that generates \
        samples from :math:`p_{\\mu}(s)`.

        Sampling is automatically distributed accross MPI processes and available \
        devices. In that case the number of samples returned might exceed ``numSamples``.

        Arguments:
            * ``parameters``: Network parameters to use for sampling.
            * ``numSamples``: Number of samples to generate. When running multiple processes \
            or on multiple devices per process, the number of samples returned is \
            ``numSamples`` or more. If ``None``, the default number of samples is returned \
            (see ``set_number_of_samples()`` member function).
            * ``multipleOf``: This argument allows to choose the number of samples returned to \
            be the smallest multiple of ``multipleOf`` larger than ``numSamples``. This feature \
            is useful to distribute a total number of samples across multiple processors in such \
            a way that the number of samples per processor is identical for each processor.

        Returns:
            A sample of computational basis configurations drawn from :math:`p_{\\mu}(s)`.
        """

        if numSamples is None:
            numSamples = self.numSamples

        
        if parameters is not None:
            tmpP = self.net.params
            self.net.set_parameters(parameters)

        
        configs,kappa = self._get_samples_gen_gumbel(self.net.parameters, numSamples, multipleOf)

        #jax.debug.print("{x}", x=configs)
        coeffs = self.net(configs)
        coeffs_reweighted = coeffs /(-jnp.expm1(-jnp.exp(coeffs-kappa)))
        
        if parameters is not None:
            self.net.params = tmpP
        return configs, coeffs, coeffs_reweighted/jnp.sum(coeffs_reweighted)#jnp.ones(configs.shape[:2]) / jnp.prod(jnp.asarray(configs.shape[:2]))


In [3]:
N = 5
L = 5
ldim = 9
homFock = jnp.ones((1,1,L),dtype=int)

emb_RWKV = 4

depth_RWKV = 2
net_RWKV = jVMC.nets.bosons.RWKV(L,N,LocalHilDim=ldim,num_layers=depth_RWKV,embedding_size=emb_RWKV)


psi_RWKV = jVMC.vqs.NQS(net_RWKV)

#dummies for testing
#net_FFN = jVMC.nets.ffn.FFN()
#psi_FFN = jVMC.vqs.NQS(net_FFN)

psi_RWKV(homFock)


Array([[-3.27405247]], dtype=float64)

In [4]:
seed = 12
key = jrnd.PRNGKey(seed)
sampler = jVMC.sampler.MCSampler(psi_RWKV,(L,),key)
sampler_gumbel = MCSampler_gumbel(psi_RWKV,(L,),key)
sampler_gumbel.setup_gumble()


5


In [5]:
a = sampler_gumbel.sample(numSamples=128)
b = sampler.sample(numSamples=128)

ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 1) of type <class 'tuple'> for function scan_fn is non-hashable.

In [None]:
len(np.unique(a[0][0],axis=0)),len(np.unique(b[0][0],axis=0)),len(b[0][0])

In [None]:
%timeit sampler_gumbel.sample(numSamples=128)

In [None]:
%timeit sampler.sample(numSamples=128)

In [None]:
sym = jVMC.util.symmetries.get_orbit_1D(L,"reflection","translation")
symNet_RWKV = jVMC.nets.sym_wrapper.SymNet(sym,net_RWKV,avgFun=jVMC.nets.sym_wrapper.avgFun_Coefficients_Sep_real)

sym_psi_RWKV = jVMC.vqs.NQS(symNet_RWKV,batchSize=128)
sym_psi_RWKV(homFock)


In [None]:
sym_sampler = jVMC.sampler.MCSampler(sym_psi_RWKV,(L,),key)
sym_sampler_gumbel = MCSampler_gumbel(sym_psi_RWKV,(L,),key)
sym_sampler_gumbel.setup_gumble()


In [None]:
#sym_sampler_gumbel.sample(numSamples=128) #errorrrororororor

In [None]:
H = jVMC.operator.bosons.BoseHubbard_Hamiltonian1D(L,1.,0.,lDim=ldim)

In [6]:
net_RWKV = jVMC.nets.bosons.RWKV(L,N,LocalHilDim=ldim,num_layers=depth_RWKV,embedding_size=emb_RWKV)
net_RWKV_gumbel = jVMC.nets.bosons.RWKV(L,N,LocalHilDim=ldim,num_layers=depth_RWKV,embedding_size=emb_RWKV)

psi_RWKV = jVMC.vqs.NQS(net_RWKV)
psi_RWKV_gumbel = jVMC.vqs.NQS(net_RWKV_gumbel)

psi_RWKV(homFock)
psi_RWKV_gumbel(homFock)

sampler = jVMC.sampler.MCSampler(psi_RWKV,(L,),key)

sampler_gumbel = MCSampler_gumbel(psi_RWKV_gumbel,(L,),key)
sampler_gumbel.setup_gumble()

lr_SR = 1e-2
minSR_equation = jVMC.util.MinSR(sampler, makeReal='real',diagonalShift=1e-3,diagonalMulti=1e-3)
minSR_equation_gumbel = jVMC.util.MinSR(sampler_gumbel, makeReal='real',diagonalShift=1e-3,diagonalMulti=1e-3)

stepperSR = jVMC.util.stepper.Euler(timeStep=lr_SR)  
    

5


In [None]:
keyN = jrnd.split( jrnd.PRNGKey(1,),(L,))

out =sampler_gumbel.gumbel(16,keyN,None,outputGumbel=True)


In [None]:
out[0][:,0,:],out[1][:,0],out[2][:,0]

In [None]:
out[2][:]

In [None]:
training_steps= 50
resTraining = np.zeros((training_steps,2))
resTraining_gumbel = np.zeros((training_steps,2))

numS = 2**10
pbar = tqdm(range(training_steps))
for n,p in enumerate(pbar):
    dpOld = psi_RWKV.get_parameters()        
    print(dpOld.shape)
    dp, _ = stepperSR.step(0, minSR_equation, dpOld, hamiltonian=H, psi=psi_RWKV, numSamples=numS)
    psi_RWKV.set_parameters(jnp.real(dp))
    resTraining[n] = [jnp.real(minSR_equation.ElocMean0) , minSR_equation.ElocVar0 ]
    
    dpOld = psi_RWKV_gumbel.get_parameters()                
    dp, _ = stepperSR.step(0, minSR_equation_gumbel, dpOld, hamiltonian=H, psi=psi_RWKV_gumbel, numSamples=numS)
    psi_RWKV_gumbel.set_parameters(jnp.real(dp))
    resTraining_gumbel[n] = [jnp.real(minSR_equation_gumbel.ElocMean0) , minSR_equation_gumbel.ElocVar0 ]
    
    pbar.set_description(f"energy: {resTraining[n][0]:.2f}+-{np.sqrt(resTraining[n][1]):.4f} __ gumbel "+f"energy: {resTraining_gumbel[n][0]:.2f}+-{np.sqrt(resTraining_gumbel[n][1]):.2f}")


In [None]:
plt.plot(resTraining[:,0],'x')
plt.plot(resTraining_gumbel[:,0])
#plt.ylim(-10,10)

In [None]:
plt.plot(resTraining[:,1],'x')
plt.plot(resTraining_gumbel[:,1])
#plt.ylim(-10,10)
plt.semilogy()

In [None]:
sampler.sample(numSamples=16)

In [None]:
a = sampler_gumbel.sample(numSamples=24)
a


In [None]:
import math
math.comb(5,2)

In [None]:
a[0][0,jnp.argmax(a[1][0])]

In [None]:
np.argmax(a[1][0])

In [None]:
a[1][0,76]

In [79]:
from flax import linen as nn
from typing import Any, List, Optional, Tuple
from jax import Array, vmap, jit

class gumbel_sampler(nn.Module):
    """
    Wrapper module for symmetrization.
    This is a wrapper module for the incorporation of lattice symmetries. 
    The given plain ansatz :math:`\\psi_\\theta` is symmetrized as

        :math:`\\Psi_\\theta(s)=\\frac{1}{|\\mathcal S|}\\sum_{\\tau\\in\\mathcal S}\\psi_\\theta(\\tau(s))`

    where :math:`\\mathcal S` denotes the set of symmetry operations (``orbit`` in our nomenclature).

    Initialization arguments:
        * ``orbit``: orbits which define the symmetry operations (instance of ``util.symmetries.LatticeSymmetry``)
        * ``net``: Flax module defining the plain ansatz.
        * ``avgFun``: Different choices for the details of averaging.

    """
    #orbit: LatticeSymmetry
    net: callable
    is_gumbel = True

    #avgFun: callable = avgFun_Coefficients_Exp
    def __setup__(self):
        self.L = self.net.L

    def __post_init__(self):

        super().__post_init__()
        #self.orbit = None
        #self.is_generator = self.net.is_generator
    """
    def __call__(self, x):

        inShape = x.shape
        x = 2 * x - 1
        x = jax.vmap(lambda o, s: jnp.dot(o, s.ravel()).reshape(inShape), in_axes=(0, None))(self.orbit.orbit, x)
        x = (x + 1) // 2

        def evaluate(x):
            return self.net(x)

        res = jax.vmap(evaluate)(x)
        return self.avgFun(res, self.orbit.factor)
    """
    
    def __call__(self,x,**kwargs):
        
        #def evaluate(x):
        #    return self.net(x)

        #res = jax.vmap(evaluate)(x)
        return self.net(x,**kwargs)
        
    def _apply_fun(self, *args,**kwargs):
        return self.net.apply(*args,**kwargs)

# ** end class SymNet
        #return sample,logits,gumbel,states
    
    
    def _workN(self,sample,logits,gumbel,key,states,position):        
        particles_left = self.net.Q - jnp.sum(sample[0]+jnp.abs(sample[0]))//2
        #new samples with (0,..,ldim-1) at position
        sample = jnp.array([sample[0].at[position].set(l) for l in jnp.arange(self.net.ldim)])
        #right shifted input
        inputt = jnp.array([jnp.pad(sample[0,:-1],(1,0))])
        logitnew, next_states = self(inputt[:,position],block_states = states, output_state=True)
        #logitnew, next_states = self.net.net(inputt[:,position],block_states = states, output_state=True)
    
        #jax.debug.print("pl: {x}",x=particles_left)
        must_give = jax.nn.relu(particles_left-self.net.max_particles[position])

        # number of particles that can be assigned
        can_give = jnp.minimum(particles_left, self.net.ldim-1)
        mask = (self.net.must_mask[must_give] + 
                    self.net.can_mask[can_give.astype(int)]) 
        #logit_renorm =  logitnew - mask ** jnp.inf
        #jax.debug.print("mask {x}", x=mask)
        logitnew =  jax.nn.log_softmax(logitnew - mask ** jnp.inf)
        #jax.debug.print("logits new {x}", x=logitnew)
        #jax.debug.print("logits[0] {x}", x=logits[0])
        
        logitnew = logits[0] + logitnew - mask ** jnp.inf
        #jax.debug.print("logits new {x}", x=logitnew)

        gumbelnew = logitnew + jrnd.gumbel(key[0],shape=(self.net.ldim,))  ## logitnew or logits[0]?
        #gumbelnew = logits[0] + jrnd.gumbel(key[0],shape=(ldim,))
        #jax.debug.print("gumbelnew new {x}", x=gumbelnew)

        Z = jnp.nanmax(gumbelnew)
        gumbelnew = jnp.nan_to_num(-jnp.log(
            jnp.exp(-gumbel[0])-jnp.exp(-Z)+jnp.exp(-gumbelnew) # -sign at gumbelnew??
            ),nan=-jnp.inf)
        #gumbelnew = (-jnp.log(jnp.exp(-gumbel[0])-jnp.exp(-Z)+jnp.exp(gumbelnew)))
        gumbelnew = gumbelnew- mask ** jnp.inf
        return sample, logitnew, gumbelnew, next_states
    def sample(self, numSamples: int, key) -> Array:
        """Autoregressively sample a spin configuration.

        Args:
            * ``numSamples``: The number of configurations to generate.
            * ``key``: JAX random key.

        Returns:
            A batch of spin configurations.
        """
        
        """
        def generate_sample(key):
            key = jrnd.split(key, self.net.L)
            logits, carry = self(jnp.zeros(1,dtype=int),block_states = None, output_state=True)

            choice = jrnd.categorical(key[0], logits.ravel()) # abide by the indexing convention and apply -1
            s_cumsum = self.net.Q - choice # create cumsum of the quantum number
            _, s = self._scanning_fn((jnp.expand_dims(choice,0),carry,s_cumsum),(key[1:],jnp.arange(1,self.net.L)))
            return jnp.concatenate([jnp.expand_dims(choice,0),s])
        """
        # get the samples
        keys = jrnd.split(key, (self.net.L))
        #keys = jrnd.split(tmpKey[0],(self.net.L))

        


        ## init stap
        shape_samples = (numSamples,self.net.ldim,self.net.L)
        shape_logits = (numSamples,self.net.ldim)
        shape_gumbel = (numSamples,self.net.ldim)
        #print(shape_samples,shape_logits)
        working_space_samples = jnp.full(shape_samples,-2,dtype=jnp.int64)
        working_space_logits = jnp.full(shape_logits,-jnp.inf,dtype=jnp.float64)
        working_space_gumbel = jnp.full(shape_gumbel,-jnp.inf,dtype=jnp.float64)
        
        working_space_samples = working_space_samples.at[0,0,0].set(0)
        working_space_logits = working_space_logits.at[0,0].set(0.)
        working_space_gumbel = working_space_gumbel.at[0,0].set(0.)
        
        states = None
        init_work = partial(self._workN, position=0,states=states)
        key0=jrnd.split(keys[0],(numSamples,1))
        
        sample,logits,gumbel,states  = jax.vmap(init_work)(working_space_samples,working_space_logits,working_space_gumbel,key0)
        #jax.debug.print("out1 logits {x}",x=logits)

        #jax.debug.print("out1 gumbel {x}",x=gumbel)
        
        init_carry = sorting_gumble(sample,logits,gumbel,states)
        #scan_fnP = partial(scan_fn,self._workN)
        #res, _ = jax.lax.scan(scan_fnP,init_carry,(keys[1:],jnp.arange(1,self.L)))

        res,_ = self._scanning_fn(init_carry,(keys[1:],jnp.arange(1,self.net.L)))
        samples, logits,gumbels,_ = res
        
        kappa = gumbels[0,1]
        jax.debug.print("kappa {x}", x=kappa)
        
        re_weights = jnp.nan_to_num(jnp.exp(logits[:,0]) /(-jnp.expm1(-jnp.exp(logits[:,0]-kappa))),0)
        jax.debug.print("re {x}", x=re_weights)


        # return to the spinless representation
        return samples[:,0,:],logits[:,0]*self.net.logProbFactor,re_weights/jnp.sum(re_weights)

    @partial(nn.scan,
             variable_broadcast='params',
             split_rngs={'params': False})
    def _scanning_fn(self, carry, key):
        position = key[1]
        sample = carry[0]
        #jax.debug.print("{x}", x=sample)

        logits = carry[1]
        gumbel = carry[2]
        states = carry[3]
        keys = jrnd.split(key[0],(carry[0].shape[0],1))

        p_workN = partial(self._workN,position=position)
        sample,logits,gumbel,states = jax.vmap(p_workN)(sample,logits,gumbel,keys,states)
        #jax.debug.print("gumbelnew new {x}", x=gumbel)

        #### sorting gumble value
        return sorting_gumble(sample,logits,gumbel,states),None
        """
        logits, next_states = self(carry[0],block_states = carry[1], output_state=True)
        #jax.debug.print("logits {x}\n carry {y}",x=logits,y=carry)

        ##############################################
        # mask logits
        # number of particles that mus assigned
        must_give = nn.relu(carry[2]-self.net.max_particles[key[1]])
        # number of particles that can be assigned
        can_give = jnp.minimum(carry[2], self.net.LocalHilDim-1)
        # compute mask
        mask = (self.net.must_mask[must_give] + 
                    self.net.can_mask[can_give.astype(int)]) 
        ##############################################
        logits = logits - mask ** jnp.inf
        choice = jrnd.categorical(key[0], logits.ravel().real) # abide by the indexing convention
        ##############################################
        s_cumsum = carry[2] - choice
        return (jnp.expand_dims(choice,0), next_states, s_cumsum), choice
        """

In [80]:
net_g = gumbel_sampler(net_RWKV_gumbel)
psi_g = jVMC.vqs.NQS(net_g)

In [81]:
sampler_new = jVMC.sampler.MCSampler(psi_g,(L,),12,jrnd.PRNGKey(1))

In [82]:
%timeit sampler_new.sample(numSamples=1024)

kappa -inf
re [0.01209263 0.0958689  0.10400124 ... 0.         0.         0.        ]
kappa -inf
re [0.01242702 0.01090644 0.03048982 ... 0.         0.         0.        ]
kappa -inf
re [0.03568013 0.08759054 0.14228799 ... 0.         0.         0.        ]
kappa -inf
re [0.00124116 0.01678184 0.14228799 ... 0.         0.         0.        ]
kappa -inf
re [0.0026353  0.01690758 0.03674051 ... 0.         0.         0.        ]
kappa -inf
re [0.0275279  0.14228799 0.02378192 ... 0.         0.         0.        ]
kappa -inf
re [0.06508407 0.00085876 0.03674051 ... 0.         0.         0.        ]
kappa -inf
re [0.06508407 0.01392076 0.03674051 ... 0.         0.         0.        ]
19 ms ± 3.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [83]:
s,l,p = sampler_new.sample(numSamples=129)


kappa -inf
re [1.42287990e-01 8.37928854e-03 9.29498365e-03 8.75905352e-02
 9.58689008e-02 3.67405112e-02 2.75375950e-04 6.50840706e-02
 2.43768300e-02 3.04898240e-02 1.69075834e-02 2.37819208e-02
 1.04001239e-01 3.56801268e-02 2.24345058e-03 1.09064383e-02
 3.31360105e-02 2.75278982e-02 1.39207589e-02 6.71446697e-03
 5.02303979e-03 7.22567164e-03 3.20975316e-04 1.35031505e-03
 9.23430428e-03 1.20926261e-02 8.91353318e-03 2.17437550e-02
 2.70084459e-04 1.51132677e-03 1.99989363e-04 3.17434346e-03
 1.67818449e-02 3.15180230e-03 1.34556529e-02 3.87868577e-04
 2.88360841e-04 5.14640132e-03 2.63529663e-03 1.24270250e-02
 1.29656510e-02 3.31123176e-03 2.89445810e-03 1.24594269e-03
 7.34709400e-03 1.07889423e-03 9.87189364e-03 1.11891676e-03
 6.13325426e-04 2.90061655e-03 4.28139769e-03 3.57396492e-03
 3.31010868e-03 2.13343831e-03 4.96624764e-04 2.22598600e-03
 4.85691579e-04 8.58764894e-04 1.24116218e-03 1.35357604e-03
 1.39188659e-03 9.87648769e-05 2.54419469e-03 3.77171706e-03
 6.5261291

In [86]:
jnp.exp(2*l)-p

Array([[-2.77555756e-17, -1.73472348e-18, -1.73472348e-18,
        -1.38777878e-17, -2.77555756e-17, -6.93889390e-18,
        -5.42101086e-20, -1.38777878e-17, -6.93889390e-18,
        -6.93889390e-18, -3.46944695e-18, -6.93889390e-18,
        -2.77555756e-17, -6.93889390e-18, -4.33680869e-19,
        -1.73472348e-18, -6.93889390e-18, -6.93889390e-18,
        -3.46944695e-18, -1.73472348e-18, -8.67361738e-19,
        -1.73472348e-18, -5.42101086e-20, -2.16840434e-19,
        -1.73472348e-18, -3.46944695e-18, -1.73472348e-18,
        -3.46944695e-18, -5.42101086e-20, -4.33680869e-19,
        -5.42101086e-20, -8.67361738e-19, -3.46944695e-18,
        -8.67361738e-19, -3.46944695e-18, -1.08420217e-19,
        -5.42101086e-20, -8.67361738e-19, -4.33680869e-19,
        -3.46944695e-18, -3.46944695e-18, -8.67361738e-19,
        -4.33680869e-19, -2.16840434e-19, -1.73472348e-18,
        -2.16840434e-19, -1.73472348e-18, -2.16840434e-19,
        -1.08420217e-19, -4.33680869e-19, -8.67361738e-1

In [28]:
def jonasFunction(self,x,*args):
    return self(x,block_states=None,output_state=False)

In [None]:
psi_RWKV.net.apply(psi_RWKV.parameters,homFock,method=jonasFunction) 

In [25]:
import math
math.comb(N+L-1,L-1)

126

In [41]:
jnp.exp(-jnp.inf)

Array(0., dtype=float64, weak_type=True)