In [37]:
from jax import numpy as jnp
import numpy as np
# Enable 64-bit precision in JAX before any arrays are created.
# `config` is imported from the top-level `jax` package: the `jax.config`
# submodule is not importable on recent JAX releases, whereas this form
# binds the same `config` object on both old and new versions.
from jax import config
config.update("jax_enable_x64", True)
from jax import lax, ops, vmap, jit, grad, random, lax
from melange_lite.magnets.utils import *


# Tests
Some consistency tests for the magnet models and their MCMC samplers.

## Test that the MCMC-based SMC sampler's weights match the manually computed MCMC-imposed weights

In [38]:
from melange_lite.magnets.ising_modell import IsingsModellSMCMCMCSampler, TrainableIsingsModellSMCSampler

Define some parameters for the SMC sampler.

In [39]:
# T sets the length of every column of the parameter schedule; T, N and L are
# all forwarded to the sampler constructor under the same keyword names.
# NOTE(review): N is presumably the particle count and L the lattice side
# length — confirm against the sampler class.
T, N, L = 10, 248, 32

In [28]:
# Parameter schedule with T rows and three columns, labelled (J, h, beta):
# J is held at 1, h at 0, and beta ramps linearly from 0 to 1.
IW_parameters = jnp.column_stack((
    jnp.ones(T),            # J
    jnp.zeros(T),           # h
    jnp.linspace(0, 1, T),  # beta
))

Build the SMC sampler factory.

In [40]:
# Construct the sampler from the values defined above.
# NOTE(review): the meaning of `full_scan` and `MCMC` is defined by the
# sampler class, which is not visible in this notebook — confirm there.
smc_factory = TrainableIsingsModellSMCSampler(
    T=T,
    N=N,
    IW_parameters=IW_parameters,
    L=L,
    full_scan=True,
    MCMC=False,
)

Make a parameter dictionary for propagation.

In [41]:
# Inputs handed to the sampler's M0 / M / logG calls below.
param_dict = {
    'seed': random.PRNGKey(2342),            # fixed PRNG key
    'kernel_parameters': IW_parameters[1:],  # every schedule row after the first
}

Generate the positions at time `t=0`.

In [42]:
# Call the factory's M0 with the parameter dictionary to obtain the t=0 state.
# The result is later indexed with ['x'], so it is a dict-like container.
X0s = smc_factory.M0(param_dict)

Then generate the positions at time `t=1`.

In [43]:
# Propagate X0s through the factory's M kernel with the same parameter
# dictionary. NOTE(review): the trailing 0 is presumably the time/iteration
# index expected by M — confirm against the sampler class.
X1s = smc_factory.M(X0s, param_dict, 0)

Compute the log weights at time `t=1`.

In [44]:
# Evaluate the factory's logG on the (X0s, X1s) pair, passing the same
# parameter dictionary and trailing index 0 that were given to M.
logG1s = smc_factory.logG(X0s, X1s, param_dict, 0)

Define the vmapped importance-weight energy function so that we can compute the log weights manually.

In [45]:
# Vectorise the factory's private _IW_energy_fn: in_axes=(0, None) maps over
# the leading axis of the first argument and passes the second one unchanged
# to every call.
vIW_energy_fn = vmap(smc_factory._IW_energy_fn, in_axes=(0,None))

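Manually compute the log weights implied by the $\pi_t$-invariant forward kernel $k_t$: $\log w_1(x_0) = -\left[u_1(x_0) - u_0(x_0)\right]$, where $u_t$ is the importance-weight energy evaluated with the parameters at time $t$.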

In [46]:
# Hand-computed log weights: the negated change in importance-weight energy of
# the t=0 configurations when moving from parameter row 0 to parameter row 1.
MCMC_lws = -(
    vIW_energy_fn(X0s['x'], smc_factory.IW_parameters[1])
    - vIW_energy_fn(X0s['x'], smc_factory.IW_parameters[0])
)

Assert that the two sets of log weights are close; if they are, the test is complete.

In [47]:
# The sampler's log weights must agree with the hand-computed ones (default
# jnp.allclose tolerances). The message is only evaluated on failure and
# reports the worst-case deviation so a mismatch can be diagnosed.
assert jnp.allclose(logG1s, MCMC_lws), (
    "sampler logG does not match the manually computed log weights: "
    f"max |difference| = {jnp.max(jnp.abs(logG1s - MCMC_lws))}"
)

All done.