In [2]:
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import seed, trace
import jax
import jax.numpy as jnp
from jax import random
from jax.scipy.integrate import trapezoid

from sim_utils import *

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
key, nworms, box_size = random.PRNGKey(0), 4, 64

2024-07-25 16:33:27.690680: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [4]:
dist.Normal(0, 1).sample(key, (nworms, 2))

Array([[ 0.08086783, -0.38624713],
       [-0.37565565,  1.6689739 ],
       [-1.2758198 ,  2.1192005 ],
       [-0.85821223,  1.1305932 ]], dtype=float32)

In [5]:
@partial(jax.jit, static_argnums=(1, 2))
def normal(key, loc=0.0, scale=1.0):
    key, sample_key = jax.random.split(key)
    samples = loc + jax.random.normal(sample_key, shape=(nworms,)) * scale
    return key, samples

@partial(jax.jit, static_argnums=(1, 2))
def uniform(key, low=0.0, high=1.0):
    key, sample_key = jax.random.split(key)
    samples = jax.random.uniform(sample_key, shape=(nworms,), minval=low, maxval=high)
    return key, samples

In [9]:

def sample(nworms):
    with numpyro.plate('nworms', nworms):
        L = numpyro.sample('L', dist.Uniform(30, 45))
        A = numpyro.sample('A', dist.Normal(1, 0.1))
        T = numpyro.sample('T', dist.Normal(0.8, 0.1))
        kw = numpyro.sample('kw', dist.Uniform(0, 2 * jnp.pi))
        ku = numpyro.sample('ku', dist.Normal(jnp.pi, 1))
        
        inc = numpyro.sample('inc', dist.Uniform(0, 2 * jnp.pi))
        dr = numpyro.sample('dr', dist.Uniform(0.2, 0.8))
        phase_1 = numpyro.sample('phase_1', dist.Uniform(0, 2 * jnp.pi))
        phase_2 = numpyro.sample('phase_2', dist.Uniform(0, 2 * jnp.pi))
        phase_3 = numpyro.sample('phase_3', dist.Normal(0, 0.1))
        alpha = numpyro.sample('alpha', dist.Normal(4, 4))

        alpha = jnp.abs(alpha + 1.0)
        half_box = box_size // 2
        x0 = numpyro.sample('x0', dist.Uniform(-half_box, half_box))
        y0 = numpyro.sample('y0', dist.Uniform(-half_box, half_box))

        params = {'L': L, 'A': A, 'T': T, 'kw': kw, 'ku': ku, 'inc': inc, 'dr': dr, 'phase_1': phase_1, 'phase_2': phase_2, 'phase_3': phase_3, 'alpha': alpha, 'x0': x0, 'y0': y0}
        duration = 0.55
        snapshots = 10
        kpoints = 12

        sim_fn = partial(
            worm_simulation,
            duration=duration,
            snapshots=snapshots,
            kpoints=kpoints,
        )
        worm = jax.vmap(sim_fn, out_axes=1)(params)
        worm = worm + box_size // 2
        numpyro.deterministic('worm', worm)


In [10]:
# seed and trace

key = random.PRNGKey(0)
seeded_model = seed(sample, key)
tr = trace(seeded_model).get_trace(nworms)

In [12]:
tr['worm']['value']

Array([[[[ 3.66111603e+01, -2.42812347e+00],
         [ 3.81319466e+01,  8.24419022e-01],
         [ 4.00192223e+01,  3.87892532e+00],
         [ 4.23425522e+01,  6.61643410e+00],
         [ 4.50739288e+01,  8.94697380e+00],
         [ 4.81155739e+01,  1.08549080e+01],
         [ 5.13465157e+01,  1.24210644e+01],
         [ 5.46573181e+01,  1.38104591e+01],
         [ 5.79527969e+01,  1.52358055e+01],
         [ 6.11287079e+01,  1.69107380e+01],
         [ 6.40488129e+01,  1.89999580e+01],
         [ 6.65506516e+01,  2.15753422e+01]],

        [[ 6.35205841e+01,  5.09313583e+01],
         [ 6.05324860e+01,  5.15613098e+01],
         [ 5.76012192e+01,  5.24175873e+01],
         [ 5.50444870e+01,  5.40875168e+01],
         [ 5.34668465e+01,  5.67022095e+01],
         [ 5.34034424e+01,  5.97553253e+01],
         [ 5.48348236e+01,  6.24528580e+01],
         [ 5.72421341e+01,  6.43318024e+01],
         [ 6.00757599e+01,  6.54702759e+01],
         [ 6.30258636e+01,  6.62592163e+01],
        