In [1]:
import jVMC
import jax
#from jax.config import config
jax.config.update("jax_enable_x64", True)
import jax.random as jrnd

import jax.numpy as jnp
from functools import partial
import numpy as np


In [2]:
numSamples = 2**4
ldim = 4
L = 10

net = jVMC.nets.bosons.RWKV(L,L,LocalHilDim=ldim,embedding_size=5,num_layers=2)
#net = jVMC.nets.bosons.GPT(L,3,lDim=ldim,embeddingDim=4,depth=2,nHeads=2)

psi = jVMC.vqs.NQS(net)
homFock = jnp.ones((1,1,L),dtype=int)
psi(homFock)

Array([[-6.27901942]], dtype=float64)

In [3]:
shape_samples = (numSamples,ldim,L)
shape_logits = (numSamples,ldim)
shape_gumbel = (numSamples,ldim)
print(shape_samples,shape_logits)
working_space_samples = jnp.full(shape_samples,-2,dtype=jnp.int64)
working_space_logits = jnp.full(shape_logits,-jnp.inf,dtype=jnp.float64)
working_space_gumbel = jnp.full(shape_gumbel,-jnp.inf,dtype=jnp.float64)

working_space_samples = working_space_samples.at[0,0,0].set(-1)
working_space_logits = working_space_logits.at[0,0].set(0.)
working_space_gumbel = working_space_gumbel.at[0,0].set(0.)

key = jrnd.PRNGKey(1)
keys = jrnd.split(key,(numSamples,1))

(16, 4, 10) (16, 4)


In [4]:
print("space required:",(working_space_samples.nbytes+working_space_logits.nbytes+working_space_gumbel.nbytes)//2**10,"kB")

space required: 6 kB


In [5]:
def _workN(sample,logits,gumbel,states,key,position):
    newkey = jrnd.split(key[0],2)

    sample = jnp.array([sample[0].at[position].set(l) for l in jnp.arange(ldim)])
    #logitnew = logits[0] + jnp.log(jrnd.uniform(newkey[0],shape=(ldim,))) # here network output
    inputt = jnp.array([jnp.pad(sample[0,:-1],(1,0))])
    #jax.debug.print("{x}",x=states)
    para = psi.parameters
    #jax.debug.print("{x}",x=inputt[:,position])
    #logitnew, next_states = psi.net.apply(para,inputt,returnLogAmp=True)
    logitnew, next_states = psi.net.apply(para,inputt[:,position],block_states = states, output_state=True)
    
    #jax.debug.print("hello{x}\n",x=(next_states))
    #jax.debug.print("hello2{x}\n\n",x=next_states[0])
    

    gumbelnew = logitnew + jrnd.gumbel(newkey[1],shape=(ldim,))
    Z = jnp.max(gumbelnew)
    gumbelnew = jnp.nan_to_num(-jnp.log(jnp.exp(-gumbel[0])-jnp.exp(-Z)+jnp.exp(gumbelnew)),copy=True,nan=-jnp.inf)
    
    return sample,logitnew,gumbelnew,next_states

__workN = partial(_workN,position=0)
states = None
#print(states)

In [6]:
test,test_logit,test_gumbel,next_states = jax.vmap(__workN)(working_space_samples,working_space_logits,working_space_gumbel,states,keys)

In [7]:
print(test.shape)


(16, 4, 10)


In [8]:
print(len(next_states),type(next_states))
print(len(next_states[0]),type(next_states[0]))
print(len(next_states[0][0]),type(next_states[0][0]))
print(len(next_states[0][0][0]),type(next_states[0][0][0]))

print((next_states[0][0][0].shape))

a = zip(*next_states)
list(a)

2 <class 'list'>
2 <class 'tuple'>
4 <class 'tuple'>
16 <class 'jaxlib.xla_extension.ArrayImpl'>
(16, 5)


[((Array([[-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],

In [9]:
next_states

[((Array([[-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
          [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],

In [10]:
aa = np.array(next_states,dtype=object)
aa[0,0]

(Array([[-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],
        [-0.23618801, -0.

In [11]:
len(next_states)

2

In [12]:
def _work(sample,logits,gumbel,key,position):
    newkey = jrnd.split(key[0],2)
    
    sample = jnp.array([sample[0].at[position].set(l) for l in jnp.arange(ldim)])
    logitnew = logits[0] + jnp.log(jrnd.uniform(newkey[0],shape=(ldim,))) # here network output
    gumbelnew = logitnew + jrnd.gumbel(newkey[1],shape=(ldim,))
    Z = jnp.max(gumbelnew)
    gumbelnew = jnp.nan_to_num(-jnp.log(jnp.exp(-gumbel[0])-jnp.exp(-Z)+jnp.exp(gumbelnew)),copy=True,nan=-jnp.inf)
    
    return sample,logitnew,gumbelnew
__work = partial(_work,position=0)
test,test_logit,test_gumbel = jax.vmap(__work)(working_space_samples,working_space_logits,working_space_gumbel,keys)

In [13]:
test.shape,test_logit.shape,test_gumbel.shape

((16, 4, 10), (16, 4), (16, 4))

In [14]:
indexes = jnp.argsort((-test_gumbel),axis=None)#.reshape(shape_gumbel)
indexes_states = indexes % numSamples
print(indexes.shape,indexes.reshape(shape_gumbel).shape)

(64,) (16, 4)


In [15]:
indexes

Array([ 2,  3,  1,  0,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
       51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63], dtype=int64)

In [16]:
test2 = test.reshape(-1,L)[indexes]
test2 = test2.reshape(ldim,numSamples,L)
test2 = jnp.swapaxes(test2,0,1)
test2

Array([[[ 2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2]],

       [[ 3, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 1, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 1, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 1, -2, -2, -2, -2, -2, -2, -2, -2, -2]],

       [[ 1, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 2, -2, -2, -2, -2, -2, -2, -2, -2, -2]],

       [[ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 3, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 3, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 3, -2, -2, -2, -2, -2, -2, -2, -2, -2]],

       [[ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 0, -2, -2, -2, -2, -2, -2, -2, -

In [17]:
test3 = test.reshape(-1,L)[indexes]
test3 = test.reshape(numSamples,ldim,L)

test3

Array([[[ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 1, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 3, -2, -2, -2, -2, -2, -2, -2, -2, -2]],

       [[ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 1, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 3, -2, -2, -2, -2, -2, -2, -2, -2, -2]],

       [[ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 1, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 3, -2, -2, -2, -2, -2, -2, -2, -2, -2]],

       [[ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 1, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 3, -2, -2, -2, -2, -2, -2, -2, -2, -2]],

       [[ 0, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 1, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [ 3, -2, -2, -2, -2, -2, -2, -2, -

In [18]:
test = test.reshape(-1,L)[indexes]
test=test.reshape(ldim,numSamples,L)
test = jnp.swapaxes(test,0,1)
#test =jnp.take_along_axis(test.reshape(-1,L),indexes,axis=1)
test_logit =test_logit.ravel()[indexes]

test_logit =test_logit.reshape(ldim,numSamples).T

test_gumbel =test_gumbel.ravel()[indexes]
test_gumbel =test_gumbel.reshape(ldim,numSamples).T

In [19]:
vals, treedef  = jax.tree_util.tree_flatten(next_states)
vals_ordered = [v[indexes_states] for v in vals]
next_states = jax.tree_util.tree_unflatten(treedef,vals_ordered)


In [20]:
for v in vals:
    print(v.shape)

(16, 5)
(16, 5)
(16, 5)
(16, 5)
(16, 5)
(16, 5)
(16, 5)
(16, 5)
(16, 5)
(16, 5)


In [21]:
test.shape,test_logit.shape,test_gumbel.shape

((16, 4, 10), (16, 4), (16, 4))

In [22]:
test_gumbel

Array([[ 4.82767751e+000, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [ 1.18292555e+000, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [ 5.85091973e-001, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [ 4.75331219e-001, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
        -1.79769313e+308],
       [-1.79769313e+

In [23]:
__work = partial(_work,position=1)
test,test_logit,test_gumbel = jax.vmap(__work)(test,test_logit,test_gumbel,keys)

In [24]:
test[0,:]

Array([[ 2,  0, -2, -2, -2, -2, -2, -2, -2, -2],
       [ 2,  1, -2, -2, -2, -2, -2, -2, -2, -2],
       [ 2,  2, -2, -2, -2, -2, -2, -2, -2, -2],
       [ 2,  3, -2, -2, -2, -2, -2, -2, -2, -2]], dtype=int64)

In [25]:
test_gumbel[10]

Array([-1.79769313e+308, -1.79769313e+308, -1.79769313e+308,
       -1.79769313e+308], dtype=float64)

In [26]:
para = psi.parameters

logitnew, next_states = psi.net.apply(para,jnp.array([0]),block_states = None, output_state=True)


In [27]:
len(next_states)
len(next_states[0])
len(next_states[0][0])

4

In [28]:
np.array(next_states,dtype=object)

array([[(Array([-0.23618801, -0.26505651, -0.21270127, -1.234682  ,  0.02869712],      dtype=float64), Array([-0.05376298, -0.09234596, -0.33510117, -0.00319198,  0.32256267],      dtype=float64), Array([1., 1., 1., 1., 1.], dtype=float64), Array([-0.6544959 , -0.06872965,  0.3672761 , -0.24597523, -0.037108  ],      dtype=float64)),
        Array([-0.3263443 , -0.27118838, -0.23936346, -1.17183447, -0.00444612],      dtype=float64)],
       [(Array([-0.40923051, -0.37972206, -0.23819815, -1.33588533, -0.07775159],      dtype=float64), Array([ 0.38243725,  0.41029255,  0.26731984, -0.32824794,  0.67079424],      dtype=float64), Array([1., 1., 1., 1., 1.], dtype=float64), Array([-0.35876541, -0.31570683,  0.12292459,  0.28810201, -0.08455723],      dtype=float64)),
        Array([-0.5166411 , -0.48439983,  0.01450164, -1.70214781,  0.15744088],      dtype=float64)]],
      dtype=object)

In [29]:
logitnew, next_states2 = psi.net.apply(para,jnp.array([0]),block_states = next_states, output_state=True)


In [30]:
b = np.array(next_states2,dtype=object)
b.shape

(2, 2)

In [31]:
max_particles = jnp.pad((ldim-1)*jnp.arange(1,L+1)[::-1],(0,1))[1::]
must_mask = 2 * jnp.tril(jnp.ones((ldim,ldim)),k=-1)
can_mask = jnp.flip(must_mask)
max_particles,ldim

(Array([27, 24, 21, 18, 15, 12,  9,  6,  3,  0], dtype=int64), 4)

In [32]:
@jax.jit
def _workN(sample,logits,gumbel,key,states,position):
    newkey = key #jrnd.split(key[0],2)
    particles_left = psi.net.Q - jnp.sum(sample[0]+jnp.abs(sample[0]))//2
    sample = jnp.array([sample[0].at[position].set(l) for l in jnp.arange(ldim)])
    inputt = jnp.array([jnp.pad(sample[0,:-1],(1,0))])
    #jax.debug.print("{x}",x=inputt.shape)
    para = psi.parameters
    logitnew, next_states = psi.net.apply(para,inputt[:,position],block_states = states, output_state=True)

    
    #must_give = jax.nn.relu(particles_left-psi.net.max_particles[position])
    must_give = jax.nn.relu(particles_left-max_particles[position])
    
    # number of particles that can be assigned
    can_give = jnp.minimum(particles_left, psi.net.LocalHilDim-1)
    # compute mask
    #mask = (psi.net.must_mask[must_give] + 
    #            psi.net.can_mask[can_give.astype(int)]) 
    mask = (must_mask[must_give] + 
                can_mask[can_give.astype(int)]) 
    #logit_renorm =  logitnew - mask ** jnp.inf
    
    logitnew =  jax.nn.log_softmax(logitnew - mask ** jnp.inf)
    logitnew = logits[0] + logitnew #- mask ** jnp.inf

    gumbelnew = logitnew + jrnd.gumbel(newkey[0],shape=(ldim,))
    #gumbelnew = logits[0] + jrnd.gumbel(newkey[0],shape=(ldim,))
    
    Z = jnp.max(gumbelnew)
    gumbelnew = jnp.nan_to_num(-jnp.log(jnp.exp(-gumbel[0])-jnp.exp(-Z)+jnp.exp(gumbelnew)),copy=True,nan=-jnp.inf)
    gumbelnew = gumbelnew - mask ** jnp.inf
    return sample, logitnew, gumbelnew, next_states
@jax.jit
def sorting_gumble(sample,logits,gumbel,states):
    indexes = jnp.argsort((-gumbel),axis=None)#.reshape(shape_gumbel)
    indexes_states = (indexes // ldim)[:numSamples]
    sample = sample.reshape(-1,L)[indexes]
    sample = sample.reshape(ldim,numSamples,L)
    sample = jnp.swapaxes(sample,0,1)
    
    logits = logits.ravel()[indexes]
    logits = logits.reshape(ldim,numSamples).T
    
    gumbel = gumbel.ravel()[indexes]
    gumbel = gumbel.reshape(ldim,numSamples).T

    vals, treedef  = jax.tree_util.tree_flatten(states)
    vals_ordered = [v[indexes_states] for v in vals]
    states = jax.tree_util.tree_unflatten(treedef,vals_ordered)
    return sample,logits,gumbel,states
@jax.jit
def scan_fn(carry, key):
    position = key[1]
    jax.debug.print("{x} position",x=position)
    keys = jrnd.split(key[0],(numSamples,1))
    sample = carry[0]
    logits = carry[1]
    gumbel = carry[2]
    states = carry[3]
    p_workN = partial(_workN,position=position)
    sample,logits,gumbel,states = jax.vmap(p_workN)(sample,logits,gumbel,keys,states)

    #### sorting gumble values
    

    
    return sorting_gumble(sample,logits,gumbel,states),None
    

In [33]:
shape_samples = (numSamples,ldim,L)
shape_logits = (numSamples,ldim)
shape_gumbel = (numSamples,ldim)
#print(shape_samples,shape_logits)
working_space_samples = jnp.full(shape_samples,-2,dtype=jnp.int64)
working_space_logits = jnp.full(shape_logits,-jnp.inf,dtype=jnp.float64)
working_space_gumbel = jnp.full(shape_gumbel,-jnp.inf,dtype=jnp.float64)

working_space_samples = working_space_samples.at[0,0,0].set(0)
working_space_logits = working_space_logits.at[0,0].set(0.)
working_space_gumbel = working_space_gumbel.at[0,0].set(0.)

key = jrnd.PRNGKey(12)
keys = jrnd.split(key,(L))
states = None
init_work = partial(_workN, position=0,states=states)

In [34]:
key0=jrnd.split(keys[0],(numSamples,1))
sample,logits,gumbel,states  = sorting_gumble(*jax.vmap(init_work)(working_space_samples,working_space_logits,working_space_gumbel,key0))
init_carry = sample,logits,gumbel,states

In [35]:
#np.exp(logits[:ldim,0]).sum()

In [36]:
#carry,_ = scan_fn((sample,logits,gumbel,states), (keys[1],1))
#carry[0]

In [38]:
#np.exp(carry[1][:,0]).sum()

In [40]:
#carry,_ = scan_fn(carry, (keys[3],2))

In [42]:
#carry[0][:,0,:],carry[1][:,0],carry[2][:,0],

In [43]:
res, _ = jax.lax.scan(scan_fn,init_carry,(keys[1:],jnp.arange(1,L)))

1 position
2 position
3 position
4 position
5 position
6 position
7 position
8 position
9 position


In [44]:
res[0][:,0,:]

Array([[0, 0, 0, 0, 0, 0, 1, 3, 3, 3],
       [0, 0, 0, 0, 0, 0, 2, 2, 3, 3],
       [0, 0, 0, 0, 0, 0, 2, 3, 2, 3],
       [0, 0, 0, 0, 0, 0, 2, 3, 3, 2],
       [0, 0, 0, 0, 0, 0, 3, 1, 3, 3],
       [0, 0, 0, 0, 0, 0, 3, 2, 2, 3],
       [0, 0, 0, 0, 0, 0, 3, 2, 3, 2],
       [0, 0, 0, 0, 0, 0, 3, 3, 1, 3],
       [0, 0, 0, 0, 0, 0, 3, 3, 2, 2],
       [0, 0, 0, 0, 0, 0, 3, 3, 3, 1],
       [0, 0, 0, 0, 0, 1, 0, 3, 3, 3],
       [0, 0, 0, 0, 0, 1, 1, 2, 3, 3],
       [0, 0, 0, 0, 0, 1, 1, 3, 2, 3],
       [0, 0, 0, 0, 0, 1, 1, 3, 3, 2],
       [0, 0, 0, 0, 0, 1, 2, 1, 3, 3],
       [0, 0, 0, 0, 0, 1, 2, 2, 2, 3]], dtype=int64)

In [45]:
for x,y in zip(res[1][:,0],res[0][:,0,:]):
    #print(y.shape)
    print(psi(jnp.expand_dims(y,(0,1)))[0,0], x/2)
    print(np.isclose(psi(jnp.expand_dims(y,(0,1)))[0,0], x/2))

-4.414934162882797 -4.414934162882797
True
-4.593647360239782 -4.593647360239781
True
-5.650431444156107 -5.650431444156107
True
-5.6805829302411235 -5.6805829302411235
True
-5.059114990879731 -5.059114990879731
True
-5.082922784586143 -5.082922784586143
True
-5.4604820466908155 -5.460482046690815
True
-5.5904007429047935 -5.5904007429047935
True
-5.583372101536961 -5.583372101536961
True
-5.58716605795933 -5.58716605795933
True
-4.642590220279781 -4.642590220279781
True
-4.9359754514687335 -4.9359754514687335
True
-5.453900325310379 -5.453900325310379
True
-5.344224327766906 -5.344224327766905
True
-5.3962779755061225 -5.3962779755061225
True
-4.953925927463026 -4.953925927463025
True


In [None]:
np.exp(res[1][:,0]).sum(),np.exp(2*psi(jnp.expand_dims(res[0][:,0,:],0))).sum()

In [None]:
(0.5*res[1][9:10,0]),psi(jnp.expand_dims(res[0][9:10,0,:],0))

In [None]:
res[0][:,0,:]

In [None]:
psi.net.init(key,homFock[0,0])


In [None]:
psi.net

In [None]:
a = np.arange(24).reshape(4,6)
a[0]

In [None]:
psi.sample(numSamples=12,key=key0[0])