In [1]:
import jVMC
import jax
#from jax.config import config
jax.config.update("jax_enable_x64", True)
import jax.random as jrnd

import jax.numpy as jnp
from functools import partial
import numpy as np


In [2]:
numSamples = 2**4
ldim = 4
L = 10

net = jVMC.nets.bosons.RWKV(L,L,LocalHilDim=ldim,embedding_size=5,num_layers=2)
#net = jVMC.nets.bosons.GPT(L,3,lDim=ldim,embeddingDim=4,depth=2,nHeads=2)

psi = jVMC.vqs.NQS(net)
homFock = jnp.ones((1,1,L),dtype=int)
psi(homFock)

Array([[-6.27901942]], dtype=float64)

In [3]:
shape_samples = (numSamples,ldim,L)
shape_logits = (numSamples,ldim)
shape_gumbel = (numSamples,ldim)
print(shape_samples,shape_logits)
working_space_samples = jnp.full(shape_samples,-2,dtype=jnp.int64)
working_space_logits = jnp.full(shape_logits,-jnp.inf,dtype=jnp.float64)
working_space_gumbel = jnp.full(shape_gumbel,-jnp.inf,dtype=jnp.float64)

working_space_samples = working_space_samples.at[0,0,0].set(-1)
working_space_logits = working_space_logits.at[0,0].set(0.)
working_space_gumbel = working_space_gumbel.at[0,0].set(0.)

key = jrnd.PRNGKey(1)
keys = jrnd.split(key,(numSamples,1))

(16, 4, 10) (16, 4)


In [4]:
print("space required:",(working_space_samples.nbytes+working_space_logits.nbytes+working_space_gumbel.nbytes)//2**10,"kB")

space required: 6 kB


In [5]:
def _workN(sample,logits,gumbel,states,key,position):
    newkey = jrnd.split(key[0],2)

    sample = jnp.array([sample[0].at[position].set(l) for l in jnp.arange(ldim)])
    #logitnew = logits[0] + jnp.log(jrnd.uniform(newkey[0],shape=(ldim,))) # here network output
    inputt = jnp.array([jnp.pad(sample[0,:-1],(1,0))])
    #jax.debug.print("{x}",x=states)
    para = psi.parameters
    #jax.debug.print("{x}",x=inputt[:,position])
    #logitnew, next_states = psi.net.apply(para,inputt,returnLogAmp=True)
    logitnew, next_states = psi.net.apply(para,inputt[:,position],block_states = states, output_state=True)
    
    #jax.debug.print("hello{x}\n",x=(next_states))
    #jax.debug.print("hello2{x}\n\n",x=next_states[0])
    

    gumbelnew = logitnew + jrnd.gumbel(newkey[1],shape=(ldim,))
    Z = jnp.max(gumbelnew)
    gumbelnew = jnp.nan_to_num(-jnp.log(jnp.exp(-gumbel[0])-jnp.exp(-Z)+jnp.exp(gumbelnew)),copy=True,nan=-jnp.inf)
    
    return sample,logitnew,gumbelnew,next_states

__workN = partial(_workN,position=0)
states = None
#print(states)

In [6]:
test,test_logit,test_gumbel,next_states = jax.vmap(__workN)(working_space_samples,working_space_logits,working_space_gumbel,states,keys)

In [7]:
print(test.shape)


(16, 4, 10)


In [8]:
print(len(next_states),type(next_states))
print(len(next_states[0]),type(next_states[0]))
print(len(next_states[0][0]),type(next_states[0][0]))
print(len(next_states[0][0][0]),type(next_states[0][0][0]))

print((next_states[0][0][0].shape))

a = zip(*next_states)
list(a)

2 <class 'list'>
2 <class 'tuple'>
4 <class 'tuple'>
16 <class 'jaxlib.xla_extension.ArrayImpl'>
(16, 5)


[((Array([[-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],

In [9]:
next_states

[((Array([[-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],

In [10]:
aa = np.array(next_states,dtype=object)
aa[0,0]

(Array([[-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.

In [11]:
len(next_states)

2

In [12]:
def _work(sample,logits,gumbel,key,position):
    newkey = jrnd.split(key[0],2)
    
    sample = jnp.array([sample[0].at[position].set(l) for l in jnp.arange(ldim)])
    logitnew = logits[0] + jnp.log(jrnd.uniform(newkey[0],shape=(ldim,))) # here network output
    gumbelnew = logitnew + jrnd.gumbel(newkey[1],shape=(ldim,))
    Z = jnp.max(gumbelnew)
    gumbelnew = jnp.nan_to_num(-jnp.log(jnp.exp(-gumbel[0])-jnp.exp(-Z)+jnp.exp(gumbelnew)),copy=True,nan=-jnp.inf)
    
    return sample,logitnew,gumbelnew
__work = partial(_work,position=0)
test,test_logit,test_gumbel = jax.vmap(__work)(working_space_samples,working_space_logits,working_space_gumbel,keys)

In [13]:
test.shape,test_logit.shape,test_gumbel.shape

((16, 4, 10), (16, 4), (16, 4))

In [14]:
indexes = jnp.argsort((-test_gumbel),axis=None)#.reshape(shape_gumbel)
indexes_states = indexes % numSamples
print(indexes.shape,indexes.reshape(shape_gumbel).shape)

(64,) (16, 4)


In [15]:
indexes

Array([ 2,  3,  1,  0,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
       51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63], dtype=int64)

In [16]:
test2 = test.reshape(-1,L)[indexes]
test2 = test2.reshape(ldim,numSamples,L)
test2 = jnp.swapaxes(test2,0,1)
test2

Array([[[ 2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2]],

       [[ 3, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 1, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 1, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 1, -2, -2, -2, -2, -2, -2, -2, -2, -2]],

       [[ 1, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 2, -2, -2, -2, -2, -2, -2, -2, -2, -2]],

       [[ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 3, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 3, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 3, -2, -2, -2, -2, -2, -2, -2, -2, -2]],

       [[ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 0, -2, -2, -2, -2, -2, -2, -2, -

In [17]:
test3 = test.reshape(-1,L)[indexes]
test3 = test.reshape(numSamples,ldim,L)

test3

Array([[[ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 1, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 3, -2, -2, -2, -2, -2, -2, -2, -2, -2]],

       [[ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 1, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 3, -2, -2, -2, -2, -2, -2, -2, -2, -2]],

       [[ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 1, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 3, -2, -2, -2, -2, -2, -2, -2, -2, -2]],

       [[ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 1, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 3, -2, -2, -2, -2, -2, -2, -2, -2, -2]],

       [[ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 1, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 3, -2, -2, -2, -2, -2, -2, -2, -

In [18]:
test = test.reshape(-1,L)[indexes]
test=test.reshape(ldim,numSamples,L)
test = jnp.swapaxes(test,0,1)
#test =jnp.take_along_axis(test.reshape(-1,L),indexes,axis=1)
test_logit =test_logit.ravel()[indexes]

test_logit =test_logit.reshape(ldim,numSamples).T

test_gumbel =test_gumbel.ravel()[indexes]
test_gumbel =test_gumbel.reshape(ldim,numSamples).T

In [19]:
vals, treedef  = jax.tree_util.tree_flatten(next_states)
vals_ordered = [v[indexes_states] for v in vals]
next_states = jax.tree_util.tree_unflatten(treedef,vals_ordered)


In [20]:
for v in vals:
    print(v.shape)

(16, 5)
(16, 5)
(16, 5)
(16, 5)
(16, 5)
(16, 5)
(16, 5)
(16, 5)
(16, 5)
(16, 5)


In [21]:
test.shape,test_logit.shape,test_gumbel.shape

((16, 4, 10), (16, 4), (16, 4))

In [22]:
test_gumbel

Array([[ 4.82767751e+000, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [ 1.18292555e+000, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [ 5.85091973e-001, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [ 4.75331219e-001, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [-1.79769313e+

In [23]:
__work = partial(_work,position=1)
test,test_logit,test_gumbel = jax.vmap(__work)(test,test_logit,test_gumbel,keys)

In [24]:
test[0,:]

Array([[ 2,  0, -2, -2, -2, -2, -2, -2, -2, -2],
       [ 2,  1, -2, -2, -2, -2, -2, -2, -2, -2],
       [ 2,  2, -2, -2, -2, -2, -2, -2, -2, -2],
       [ 2,  3, -2, -2, -2, -2, -2, -2, -2, -2]], dtype=int64)

In [25]:
test_gumbel[10]

Array([-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
       -1.79769313e+308], dtype=float64)

In [26]:
para = psi.parameters

logitnew, next_states = psi.net.apply(para,jnp.array([0]),block_states = None, output_state=True)


In [27]:
len(next_states)
len(next_states[0])
len(next_states[0][0])

4

In [28]:
np.array(next_states,dtype=object)

array([[(Array([-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],      dtype=float64), Array([-0.05376298, -0.09234596, -0.33510117, -0.00319198,  0.32256267],      dtype=float64), Array([1., 1., 1., 1., 1.], dtype=float64), Array([-0.6544959 , -0.06872965,  0.3672761 , -0.24597523, -0.037108  ],      dtype=float64)),
        Array([-0.3263443 , -0.27118838, -0.23936346, -1.17183447, -0.00444612],      dtype=float64)],
       [(Array([-0.40923051, -0.37972206, -0.23819815, -1.33588533, -0.07775159],      dtype=float64), Array([ 0.38243725,  0.41029255,  0.26731984, -0.32824794,  0.67079424],      dtype=float64), Array([1., 1., 1., 1., 1.], dtype=float64), Array([-0.35876541, -0.31570683,  0.12292459,  0.28810201, -0.08455723],      dtype=float64)),
        Array([-0.5166411 , -0.48439983,  0.01450164, -1.70214781,  0.15744088],      dtype=float64)]],
      dtype=object)

In [29]:
logitnew, next_states2 = psi.net.apply(para,jnp.array([0]),block_states = next_states, output_state=True)


In [30]:
b = np.array(next_states2,dtype=object)
b.shape

(2, 2)

In [31]:
max_particles = jnp.pad((ldim-1)*jnp.arange(1,L+1)[::-1],(0,1))[1::]
must_mask = 2 * jnp.tril(jnp.ones((ldim,ldim)),k=-1)
can_mask = jnp.flip(must_mask)
max_particles,ldim

(Array([27, 24, 21, 18, 15, 12,  9,  6,  3,  0], dtype=int64), 4)

In [124]:
@jax.jit
def _workN(sample,logits,gumbel,key,states,position):
    newkey = key #jrnd.split(key[0],2)
    particles_left = psi.net.Q - jnp.sum(sample[0]+jnp.abs(sample[0]))//2
    sample = jnp.array([sample[0].at[position].set(l) for l in jnp.arange(ldim)])
    inputt = jnp.array([jnp.pad(sample[0,:-1],(1,0))])
    #jax.debug.print("{x}",x=inputt.shape)
    para = psi.parameters
    logitnew, next_states = psi.net.apply(para,inputt[:,position],block_states = states, output_state=True)

    
    #must_give = jax.nn.relu(particles_left-psi.net.max_particles[position])
    must_give = jax.nn.relu(particles_left-max_particles[position])
    
    # number of particles that can be assigned
    can_give = jnp.minimum(particles_left, psi.net.LocalHilDim-1)
    # compute mask
    #mask = (psi.net.must_mask[must_give] + 
    #            psi.net.can_mask[can_give.astype(int)]) 
    mask = (must_mask[must_give] + 
                can_mask[can_give.astype(int)]) 
    #logit_renorm =  logitnew - mask ** jnp.inf
    
    logitnew = jax.nn.log_softmax(logitnew - mask ** jnp.inf)
    logitnew = logits[0] + logitnew #- mask ** jnp.inf
    jax.debug.print("logitnew: {x}",x=logitnew)

    #gumbelnew = logitnew + jrnd.gumbel(newkey[0],shape=(ldim,))
    gumbelnew = logits[0] + jrnd.gumbel(newkey[0],shape=(ldim,))
    jax.debug.print("gumbelnew: {x}",x=gumbelnew)

    Z = jnp.nanmax(gumbelnew,)
    jax.debug.print("Z: {x}",x=Z)
    
    jax.debug.print("gumbel[0]: {x}",x=gumbel[0])
    gumbelnew = jnp.nan_to_num(-jnp.log(jnp.exp(-gumbel[0])-jnp.exp(-Z)+jnp.exp(-gumbelnew)),copy=True,nan=-jnp.inf)
    #gumbelnew = (jnp.exp(-gumbel[0])-jnp.exp(-Z)+jnp.exp(-gumbelnew))
    jax.debug.print("gumbelnew: {x}",x=gumbelnew)

    gumbelnew = gumbelnew - mask ** jnp.inf
    return sample, logitnew, gumbelnew, next_states
@jax.jit
def sorting_gumble(sample,logits,gumbel,states):
    indexes = jnp.argsort((-gumbel),axis=None)#.reshape(shape_gumbel)
    indexes_states = (indexes // ldim)[:numSamples]
    sample = sample.reshape(-1,L)[indexes]
    sample = sample.reshape(ldim,numSamples,L)
    sample = jnp.swapaxes(sample,0,1)
    
    logits = logits.ravel()[indexes]
    logits = logits.reshape(ldim,numSamples).T
    
    gumbel = gumbel.ravel()[indexes]
    gumbel = gumbel.reshape(ldim,numSamples).T

    vals, treedef  = jax.tree_util.tree_flatten(states)
    vals_ordered = [v[indexes_states] for v in vals]
    states = jax.tree_util.tree_unflatten(treedef,vals_ordered)
    return sample,logits,gumbel,states
@jax.jit
def scan_fn(carry, key):
    position = key[1]
    #jax.debug.print("{x} position",x=position)
    keys = jrnd.split(key[0],(numSamples,1))
    sample = carry[0]
    logits = carry[1]
    gumbel = carry[2]
    states = carry[3]
    p_workN = partial(_workN,position=position)
    sample,logits,gumbel,states = jax.vmap(p_workN)(sample,logits,gumbel,keys,states)

    #### sorting gumble values
    

    
    return sorting_gumble(sample,logits,gumbel,states),None
    

In [125]:
shape_samples = (numSamples,ldim,L)
shape_logits = (numSamples,ldim)
shape_gumbel = (numSamples,ldim)
#print(shape_samples,shape_logits)
working_space_samples = jnp.full(shape_samples,-2,dtype=jnp.int64)
working_space_logits = jnp.full(shape_logits,-jnp.inf,dtype=jnp.float64)
working_space_gumbel = jnp.full(shape_gumbel,-jnp.inf,dtype=jnp.float64)

working_space_samples = working_space_samples.at[0,0,0].set(0)
working_space_logits = working_space_logits.at[0,0].set(0.)
working_space_gumbel = working_space_gumbel.at[0,0].set(0.)

key = jrnd.PRNGKey(12)
keys = jrnd.split(key,(L))
states = None
init_work = partial(_workN, position=0,states=states)
working_space_gumbel,working_space_logits

(Array([[  0., -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf]], dtype=float64),
 Array([[  0., -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf

In [126]:
key0=jrnd.split(keys[0],(numSamples,1))
sample,logits,gumbel,states  = sorting_gumble(*jax.vmap(init_work)(working_space_samples,working_space_logits,working_space_gumbel,key0))
init_carry = sample,logits,gumbel,states
gumbel,logits

logitnew: [[-1.2955238  -1.33387066 -1.52334739 -1.40727683]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
gumbelnew: [ 0.77154807  0.51677635 -0.94433252  0.89978207]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumb

(Array([[-0.00000000e+000, -1.79769313e+308, -1.79769313e+308,
         -1.79769313e+308],
        [-5.41458654e-002, -1.79769313e+308, -1.79769313e+308,
         -1.79769313e+308],
        [-1.73770010e-001, -1.79769313e+308, -1.79769313e+308,
         -1.79769313e+308],
        [-1.15197559e+000, -1.79769313e+308, -1.79769313e+308,
         -1.79769313e+308],
        [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
         -1.79769313e+308],
        [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
         -1.79769313e+308],
        [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
         -1.79769313e+308],
        [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
         -1.79769313e+308],
        [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
         -1.79769313e+308],
        [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
         -1.79769313e+308],
        [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
         -1.79769313e+308],

In [127]:
#np.exp(logits[:ldim,0]).sum()

In [128]:
#carry,_ = scan_fn((sample,logits,gumbel,states), (keys[1],1))
#carry[0]

In [129]:
#np.exp(carry[1][:,0]).sum()

In [130]:
working_space_samples

Array([[[ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2, -2, -2, -2, -2]],

       [[-2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2, -2, -2, -2, -2]],

       [[-2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2, -2, -2, -2, -2]],

       [[-2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2, -2, -2, -2, -2]],

       [[-2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2, -2, -2, -

In [131]:
%%timeit
key0=jrnd.split(keys[0],(numSamples,1))
sample,logits,gumbel,states  = sorting_gumble(*jax.vmap(init_work)(working_space_samples,working_space_logits,working_space_gumbel,key0))
init_carry = sample,logits,gumbel,states
res, _ = jax.lax.scan(scan_fn,init_carry,(keys[1:],jnp.arange(1,L)))

logitnew: [[-1.2955238  -1.33387066 -1.52334739 -1.40727683]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
gumbelnew: [ 0.77154807  0.51677635 -0.94433252  0.89978207]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumb

In [132]:
res, _ = jax.lax.scan(scan_fn,init_carry,(keys[1:],jnp.arange(1,L)))

logitnew: [[-2.83664688 -2.81561158 -2.71724009 -2.80904987]]
logitnew: [[-2.60132734 -2.62324458 -2.82939931 -2.68877664]]
logitnew: [[-2.7702741  -2.77533414 -2.56185522 -2.79163938]]
logitnew: [[-3.0922986  -3.24943712 -2.37143365 -3.20791935]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
logitnew: [[-inf -inf -inf -inf]]
gumbelnew: [ 1.38755903  0.12912913  0.19384445 -0.02045952]
gumbelnew: [-1.02155533 -0.59969352 -1.38098139 -0.46135827]
gumbelnew: [-0.71541178 -1.34375645 -1.80284841 -1.74750028]
gumbelnew: [-2.21155242 -0.28802837 -1.31341298  1.12105521]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
gumbelnew: [-inf -inf -inf -inf]
g

In [133]:
res[0][:,0,:]

Array([[1, 0, 3, 1, 2, 0, 1, 0, 0, 2],
       [3, 0, 3, 0, 0, 0, 0, 3, 1, 0],
       [0, 0, 2, 2, 2, 1, 1, 1, 0, 1],
       [0, 0, 2, 2, 2, 2, 0, 0, 0, 2],
       [0, 0, 2, 2, 2, 2, 0, 1, 1, 0],
       [0, 0, 2, 2, 2, 2, 0, 2, 0, 0],
       [0, 2, 0, 1, 0, 1, 2, 1, 2, 1],
       [3, 0, 0, 0, 1, 2, 2, 1, 1, 0],
       [0, 0, 2, 2, 2, 2, 1, 0, 0, 1],
       [0, 0, 2, 2, 2, 2, 0, 0, 2, 0],
       [0, 1, 0, 3, 1, 2, 2, 0, 1, 0],
       [0, 0, 2, 2, 2, 2, 0, 0, 1, 1],
       [0, 0, 2, 2, 2, 2, 0, 1, 0, 1],
       [1, 0, 3, 0, 1, 1, 2, 1, 1, 0],
       [0, 1, 0, 3, 1, 2, 0, 1, 2, 0],
       [1, 0, 3, 1, 2, 1, 0, 0, 1, 1]], dtype=int64)

In [134]:
res[2][:,0]

Array([ -1.57623229,  -4.5895654 ,  -8.15912877,  -8.68501153,
        -8.83944618,  -9.10722138,  -9.17772833,  -9.49407268,
        -9.60433575,  -9.73250806, -10.04772755, -10.66734484,
       -11.04830209, -11.64177463, -11.99340332, -12.87251494],      dtype=float64)

In [135]:
for x,y in zip(res[1][:,0],res[0][:,0,:]):
    #print(y.shape)
    print(psi(jnp.expand_dims(y,(0,1)))[0,0], x/2)
    print(np.isclose(psi(jnp.expand_dims(y,(0,1)))[0,0], x/2))

-5.782795408065023 -5.782795408065023
True
-5.817834188637212 -5.817834188637212
True
-5.297448756850904 -5.297448756850904
True
-4.948966578532409 -4.948966578532408
True
-4.779617326816258 -4.779617326816258
True
-4.4150578838405 -4.4150578838405
True
-6.440330790989108 -6.440330790989107
True
-5.518484590079414 -5.518484590079414
True
-4.630338078959005 -4.630338078959004
True
-4.992284390979331 -4.99228439097933
True
-5.127484974392141 -5.127484974392141
True
-4.9571962429599195 -4.9571962429599195
True
-4.7786484596843835 -4.7786484596843835
True
-5.865803890622967 -5.865803890622968
True
-5.847992876685945 -5.847992876685945
True
-5.7962596487921765 -5.796259648792176
True


In [43]:
np.exp(res[1][:,0]).sum(),np.exp(2*psi(jnp.expand_dims(res[0][:,0,:],0))).sum()

(0.0006674570094251267, 0.0006674570094251265)

In [44]:
(0.5*res[1][9:10,0]),psi(jnp.expand_dims(res[0][9:10,0,:],0))

(Array([-5.58716606], dtype=float64), Array([[-5.58716606]], dtype=float64))

In [45]:
res[0][:,0,:]

Array([[0, 0, 0, 0, 0, 0, 1, 3, 3, 3],
       [0, 0, 0, 0, 0, 0, 2, 2, 3, 3],
       [0, 0, 0, 0, 0, 0, 2, 3, 2, 3],
       [0, 0, 0, 0, 0, 0, 2, 3, 3, 2],
       [0, 0, 0, 0, 0, 0, 3, 1, 3, 3],
       [0, 0, 0, 0, 0, 0, 3, 2, 2, 3],
       [0, 0, 0, 0, 0, 0, 3, 2, 3, 2],
       [0, 0, 0, 0, 0, 0, 3, 3, 1, 3],
       [0, 0, 0, 0, 0, 0, 3, 3, 2, 2],
       [0, 0, 0, 0, 0, 0, 3, 3, 3, 1],
       [0, 0, 0, 0, 0, 1, 0, 3, 3, 3],
       [0, 0, 0, 0, 0, 1, 1, 2, 3, 3],
       [0, 0, 0, 0, 0, 1, 1, 3, 2, 3],
       [0, 0, 0, 0, 0, 1, 1, 3, 3, 2],
       [0, 0, 0, 0, 0, 1, 2, 1, 3, 3],
       [0, 0, 0, 0, 0, 1, 2, 2, 2, 3]], dtype=int64)

In [46]:
psi.net.init(key,homFock[0,0])


FrozenDict({
    params: {
        Head: {
            kernel: Array([[-0.73109729,  0.80943144, -0.53429929, -0.07685795],
                   [-0.06833654, -0.24768302,  0.7819772 , -0.64982771],
                   [-0.01642196,  0.1261088 , -0.2181712 ,  0.71729475],
                   [-0.19691397, -0.28328338, -0.64528551,  0.43487701],
                   [-0.19063595,  0.08516858,  0.11654353, -0.07250183]],      dtype=float64),
        },
        Neck: {
            bias: Array([0., 0., 0., 0., 0.], dtype=float64),
            kernel: Array([[-0.2402684 , -0.57522207,  0.74441163, -0.22276035, -0.26046138],
                   [ 0.54618611,  0.90540955, -0.38362381,  0.13328261, -0.04650405],
                   [-0.15561791,  0.56385177, -0.75360583, -0.76249225, -0.64436729],
                   [-0.30468884, -0.36444137,  0.39849115,  0.60606212,  0.72913632],
                   [-0.10827321, -0.23400804,  0.50961735, -0.01590819,  0.83923035]],      dtype=float64),
        },
  

In [47]:
psi.net

RWKV(
    # attributes
    L = 10
    Q = 10
    M = None
    LocalHilDim = 4
    dtype = float64
    order = 1
    num_layers = 2
    embedding_size = 5
    logProbFactor = 0.5
)

In [48]:
a = np.arange(24).reshape(4,6)
a[0]

array([0, 1, 2, 3, 4, 5])

In [49]:
psi.sample(numSamples=12,key=key0[0])

Array([[[2, 0, 2, 1, 2, 2, 0, 1, 0, 0],
        [0, 1, 0, 0, 0, 1, 2, 3, 0, 3],
        [0, 2, 3, 1, 2, 2, 0, 0, 0, 0],
        [1, 1, 3, 1, 0, 3, 1, 0, 0, 0],
        [1, 2, 2, 0, 2, 1, 0, 2, 0, 0],
        [0, 0, 0, 0, 2, 1, 2, 2, 0, 3],
        [1, 3, 3, 3, 0, 0, 0, 0, 0, 0],
        [2, 2, 3, 0, 0, 0, 2, 0, 0, 1],
        [1, 3, 2, 2, 2, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 1, 2, 3, 1, 2],
        [2, 1, 3, 2, 0, 2, 0, 0, 0, 0],
        [2, 2, 1, 0, 1, 1, 0, 0, 2, 1]]], dtype=int64)