In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import xarray as xr

import larch as lx


def mtc7():
    """Return the first seven cases of the MTC example data with a ``groupid`` variable.

    The first three cases are labelled group 1 and the remaining four group 2;
    ``groupid`` is indexed along the dataset's case-id dimension.
    """
    mtc = lx.examples.MTC()
    subset = mtc.icase[:7]
    group_labels = [1] * 3 + [2] * 4
    subset['groupid'] = xr.DataArray(group_labels, dims=mtc.dc.CASEID)
    return subset


def model7(n_draws=5000):
    """Build the 7-case test model on the data returned by ``mtc7()``.

    Alternatives 2-6 each get an alternative-specific constant plus an
    alternative-specific ``hhinc`` coefficient in ``utility_co``; ``tottime``
    and ``totcost`` enter ``utility_ca`` via ``PX``. ``tottime`` is passed to
    ``mix_parameter`` together with ``tottime_s`` (presumably its random
    spread — confirm against larch docs), and ``groupid`` is set as the
    model's group identifier.

    Parameters
    ----------
    n_draws : int, default 5000
        Value assigned to ``m.n_draws``. The default equals the previously
        hard-coded value, so ``model7()`` behaves exactly as before.

    Returns
    -------
    lx.Model
        The configured model.
    """
    d = mtc7()
    m = lx.Model(d)
    P, X, PX = lx.P, lx.X, lx.PX
    # Constant name per alternative; iteration order (2..6) matches the
    # original statement order, so parameters are created in the same order.
    asc_names = {
        2: "ASC_SR2",
        3: "ASC_SR3P",
        4: "ASC_TRAN",
        5: "ASC_BIKE",
        6: "ASC_WALK",
    }
    for alt, asc in asc_names.items():
        m.utility_co[alt] = P(asc) + P(f"hhinc#{alt}") * X("hhinc")
    m.utility_ca = PX("tottime") + PX("totcost")
    m.availability_var = 'avail'
    m.choice_ca_var = 'chose'
    # NOTE(review): the title says "Simple MNL", but this model mixes `tottime`
    # and sets a group id — consider retitling.
    m.title = "MTC Example 1 (Simple MNL)"
    m.mix_parameter("tottime", "tottime_s")
    m.groupid = 'groupid'
    m.n_draws = n_draws
    return m

# Build the 7-case test model (with a mixed `tottime` parameter) used by the
# cells below.
m = model7()

In [None]:
# Generate the model's random draws explicitly, using the draw count already
# configured on the model (`m.n_draws`).
m._make_random_draws(n_draws=m.n_draws)

In [None]:
def _get_jnp_array(dataset, name):
    if name not in dataset:
        return None
    return jnp.asarray(dataset[name])


# Pull the model's arrays out of its dataset as JAX arrays; any variable not
# present in the dataset comes back as None.
co = _get_jnp_array(m.dataset, 'co')
ca = _get_jnp_array(m.dataset, 'ca')
av = _get_jnp_array(m.dataset, 'av')
ch = _get_jnp_array(m.dataset, 'ch')


# Full bundle (all cases) and a bundle holding only the first case.
bund = dict(ca=ca, co=co, ch=ch, av=av)
# Index case 0 only where the array exists: a missing variable stays None
# instead of raising TypeError on `None[0]`.
bund0 = {key: (arr[0] if arr is not None else None) for key, arr in bund.items()}

# PRNG bundle seeded with 123.
gbund = dict(rk=jax.random.PRNGKey(123))

In [None]:
# Call the private `_jax_loglike` on the first-case bundle (`bund0`) with its
# own PRNG bundle (seed 122, unlike the 123 used for `gbund`).
# NOTE(review): the meaning of the trailing positional argument 1059 is not
# visible here — confirm against `_jax_loglike`'s signature.
m._jax_loglike(m.pvals, bund0, dict(rk=jax.random.PRNGKey(122)), 1059)

In [None]:
# Evaluate the JAX log-likelihood at the current parameter values.
m.jax_loglike(m.pvals)

In [None]:
# Presumably the gradient of the total log-likelihood w.r.t. the parameters,
# evaluated at the current parameter values.
m.jax_d_loglike_total(m.pvals)

In [None]:
# Benchmark the gradient evaluation at the current parameter values.
%timeit m.jax_d_loglike_total(m.pvals)

In [None]:
# Benchmark the log-likelihood evaluation at the current parameter values.
%timeit m.jax_loglike(m.pvals)

In [None]:
# Re-set the draw count.
# NOTE(review): model7() already sets n_draws = 5000, so this is a no-op unless
# the value was changed interactively. It is also unclear whether changing
# n_draws regenerates draws already made with `_make_random_draws`.
m.n_draws = 5000

In [None]:
# Deliberate halt so "Run All" stops here: the cells below are exploratory PRNG
# key-splitting experiments. The original bare `stop` only worked because the
# name happened to be undefined (NameError); raise explicitly instead.
raise RuntimeError("Intentional stop: the cells below are scratch experiments.")

In [None]:
# Split one key into a (5, 5) grid of keys (num accepts a shape tuple).
jax.random.split(key=jax.random.PRNGKey(123), num=(5, 5))

In [None]:
# Wrap 5 split keys in a labelled array: one row per key, two key words per row.
xr.DataArray(
    data=jax.random.split(jax.random.PRNGKey(123), num=5),
    dims=('c5', 'two'),
)

In [None]:
# vmap maps over axis 0, so it needs a *batch* of keys. Passing a single raw
# key of shape (2,) would hand scalar halves of the key to `jax.random.split`,
# which rejects them. Add a leading batch axis of size 1 -> result (1, 5, 2).
jax.vmap(jax.random.split, in_axes=(0, None))(jax.random.PRNGKey(123)[None], 5)

In [None]:
# Split each of 3 keys into 5 keys; the count is closed over, not vmapped.
jax.vmap(lambda key: jax.random.split(key, 5))(
    jax.random.split(jax.random.PRNGKey(123), 3)
).shape

In [None]:
def onekey(key, n):
    """Return ``key`` unchanged; the count ``n`` is accepted but ignored."""
    _ = n  # unused
    return key

# Rank of a raw PRNG key.
jnp.ndim(jax.random.PRNGKey(123))

In [None]:
# A batch containing a single key (leading axis of length 1).
rk = jax.random.split(key=jax.random.PRNGKey(123), num=1)


In [None]:
# Split every key in a batch (axis 0) into `num` keys; `num` is not mapped.
vsplit = jax.vmap(lambda key, num: jax.random.split(key, num), in_axes=(0, None))

In [None]:
def ssplit(key, shapes):
    """Split a PRNG key (or an array of keys) by the first entry of ``shapes``.

    Assumes raw key data whose last axis holds the key words (as produced by
    ``jax.random.PRNGKey``). A single key (``key.ndim == 1``) is split into
    ``shapes[0]`` keys. For an array of keys with any number of leading batch
    axes, every key is split into ``shapes[0]`` keys, and a new axis is
    inserted just before the key-word axis.

    Parameters
    ----------
    key : jax.Array
        A single raw key ``(k,)`` or a batch of keys ``(*batch, k)``.
    shapes : sequence of int
        Split counts; only ``shapes[0]`` is consumed.

    Returns
    -------
    tuple
        ``(new_keys, shapes[1:])`` so calls can be chained.
    """
    num = shapes[0]
    if key.ndim == 1:
        new_keys = jax.random.split(key, num)
    else:
        # vmap peels off only one axis, so a single vmap failed for keys nested
        # more than one level deep. Flatten all batch axes first, then restore
        # the batch shape afterwards.
        key_words = key.shape[-1]
        flat_keys = key.reshape(-1, key_words)
        new_keys = jax.vmap(jax.random.split, in_axes=(0, None))(flat_keys, num)
        new_keys = new_keys.reshape(key.shape[:-1] + (num, key_words))
    return new_keys, shapes[1:]

k, s = jax.random.PRNGKey(123), [5,3]

k, s = ssplit(k, s)
k, s = ssplit(k, s)
k, s

In [None]:
def vsplit(key_array, shapes):
    """Split every key in ``key_array`` into ``shapes[0]`` new keys.

    Assumes raw key data whose last axis holds the key words. ``key_array``
    may be a single key ``(k,)`` or carry any number of leading batch axes
    ``(*batch, k)``; the new keys have shape ``(*batch, shapes[0], k)``.
    ``shapes`` may be a sequence of ints or a bare int (treated as ``(n,)``).

    Returns
    -------
    tuple
        ``(new_keys, shapes[1:])`` so calls can be chained.
    """
    if isinstance(shapes, int):
        shapes = (shapes,)
    num = shapes[0]
    key_words = key_array.shape[-1]
    # Flatten all batch axes so a single vmap handles any depth; a lone key
    # becomes a batch of one. Then restore the original batch shape.
    flat_keys = key_array.reshape(-1, key_words)
    new_keys = jax.vmap(jax.random.split, in_axes=(0, None))(flat_keys, num)
    new_keys = new_keys.reshape(key_array.shape[:-1] + (num, key_words))
    return new_keys, shapes[1:]

In [None]:
# NOTE(review): this passes a single, unbatched PRNG key; it only works if
# `vsplit` accepts keys without a leading batch axis.
vsplit(jax.random.PRNGKey(123), [5,3])

In [None]:
# Batch `vsplit` over axis 0 of its first argument; `shapes` is passed through
# unmapped.
# NOTE(review): this rebinds `vsplit` to a vmapped version of itself, so
# re-running this cell nests another vmap (not idempotent).
vsplit = jax.vmap(vsplit, in_axes=(0, None))

In [None]:
# Split each key in `rk` into 5 new keys using the batched `vsplit`.
# NOTE(review): `shapes` is passed as a bare int here, while `vsplit` indexes
# `shapes[0]`; confirm `vsplit` accepts an int rather than a sequence.
rk1 = vsplit(rk, 5)


In [None]:
# Split a batch holding a single key into 5 keys; the count is closed over.
jax.vmap(lambda key: jax.random.split(key, 5))(
    jax.random.split(jax.random.PRNGKey(123), 1)
).shape

In [None]:
# vmap the private log-likelihood over axis 0 of every array in `bund`,
# holding the parameter vector fixed.
# NOTE(review): elsewhere `_jax_loglike` is called with four positional
# arguments (params, data bundle, PRNG bundle, an integer); here only two are
# passed — confirm the rest have defaults.
jax.vmap(m._jax_loglike, in_axes=(None, 0))(m.pvals, bund)

In [None]:
# Reset the parameters (presumably to their null values), then set only
# `tottime` to -1 and evaluate the log-likelihood there.
m.pvals = 'null'
m.pvals = {'tottime': -1}

m.jax_loglike(m.pvals)

In [None]:
# Per-case log-likelihood values (presumably one per case) at the same parameters.
m.jax_loglike_casewise(m.pvals)

In [None]:
# Observed choices for the same 7-case subset the model is built on.
mtc7()['chose']

In [None]:
# Reset the parameters (presumably to their null values), then set only the
# `hhinc#5` coefficient to -1.
m.pvals = 'null'
m.pvals = {"hhinc#5": -1}

In [None]:
# Utilities at the current parameters, converted to a NumPy array for display.
np.asarray(m.jax_utility(m.pvals))

In [None]:
# Raw `tottime` values from the model's root dataset, as a NumPy array.
m.datatree.root_dataset['tottime'].to_numpy()

In [None]:
# Hard-coded 5x7 float32 array of 0 and -inf entries, displayed for reference.
# NOTE(review): this looks like a hand-copied expected result (e.g. a mask with
# -inf where an alternative is unavailable). Confirm its source; it is not
# compared against anything here.
inf = np.inf
np.array([
    [  0.,   0.,   0.,   0.,   0., -inf,   0.],
    [  0.,   0.,   0.,   0.,   0., -inf,   0.],
    [  0.,   0.,   0.,   0., -inf, -inf,   0.],
    [  0.,   0.,   0.,   0., -inf, -inf,   0.],
    [-inf,   0.,   0.,   0.,   0., -inf,   0.],
], dtype=np.float32)