In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from pytest import approx

import larix as lx
from larch import P, X, PX


### larch.numba is experimental, and not feature-complete ###
 the first time you import on a new system, this package will
 compile optimized binaries for your machine, which may take 
 a little while, please be patient 

OMP: Info #273: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


In [2]:
# Example data for the model built in the next cell, requested in 'dataset'
# format (presumably the bundled MTC mode-choice data — confirm).
d = lx.examples.MTC(format='dataset')

In [3]:
# Simple MNL on the example data: an alternative-specific constant plus a
# household-income term for alternatives 2-6, and generic time and cost terms.
m = lx.Model(d)

# Alternative code -> name of its alternative-specific constant. Each
# alternative's income coefficient is named "hhinc#<code>".
asc_names = {
    2: "ASC_SR2",
    3: "ASC_SR3P",
    4: "ASC_TRAN",
    5: "ASC_BIKE",
    6: "ASC_WALK",
}
for alt, asc in asc_names.items():
    m.utility_co[alt] = P(asc) + P(f"hhinc#{alt}") * X("hhinc")

m.utility_ca = PX("tottime") + PX("totcost")

m.availability_var = 'avail'
m.choice_ca_var = 'chose'

m.title = "MTC Example 1 (Simple MNL)"
# Estimate with the non-JAX `maximize_loglike`; later cells evaluate the
# `jax_*` methods at these parameter values (and at `r.x / 2`).
r = m.maximize_loglike(quiet=True)

In [4]:
# Log-likelihood from the JAX implementation at the model's current parameter
# values (presumably the estimates set by `maximize_loglike` above).
m.jax_loglike(m.pvals)



DeviceArray(-3626.186, dtype=float32)

In [5]:
# Pin the JAX log-likelihood (a float32 result) to the reference value shown
# above, using pytest.approx's default tolerance.
assert m.jax_loglike(m.pvals) == approx(-3626.186)

In [6]:
# Gradient of the log-likelihood from JAX at the current parameters
# (one entry per model parameter).
m.jax_d_loglike(m.pvals)

DeviceArray([ 8.647442e-04, -4.024506e-04, -1.380444e-03, -1.451683e-02,  3.976822e-03, -4.386186e-02, -9.957599e-02,
             -7.297770e-01,  3.342795e-02,  1.698433e-01, -2.316932e+00, -2.219502e-02], dtype=float32)

In [7]:
# Reference gradient from the non-JAX `d_loglike`, evaluated at half the
# estimated parameter vector so the gradient is far from zero.
half_estimate = r.x / 2
expected_half_gradient = [
    -100.668873, -427.881921, -346.047421,
    -176.396743, -69.950285, -25639.434015,
    -20581.626877, -10752.088732, -5810.403543,
    -4360.946795, 54115.923707, -16552.567557,
]
assert m.d_loglike(half_estimate) == approx(expected_half_gradient)

In [8]:
# Same reference gradient, checked against the JAX implementation. The
# tolerance is looser than in the previous cell because the JAX results in
# this notebook are float32.
jax_half_gradient = m.jax_d_loglike(np.asarray(r.x / 2))
assert jax_half_gradient == approx(
    [
        -100.668873, -427.881921, -346.047421,
        -176.396743, -69.950285, -25639.434015,
        -20581.626877, -10752.088732, -5810.403543,
        -4360.946795, 54115.923707, -16552.567557,
    ],
    rel=1e-5,
)

In [9]:
# Reset every parameter to its null value before re-estimating with JAX.
m.pvals = 'null'

In [10]:
# After the reset, every parameter value should be exactly zero.
assert (m.pf['value'] == 0).all()

In [11]:
# Re-estimate the model from the null parameters with the JAX-based optimizer.
rj = m.jax_maximize_loglike()

In [12]:
# Optimizer result summary (loglike, estimates `x`, gradient `jac`, status).
rj

     jac: array([ 2.272785e-02, -3.063679e-02, -8.840561e-03,  1.739502e-03, -6.566763e-03, -1.223307e+00, -1.308636e-01,
       -7.242671e-01,  1.559882e+00,  9.642785e-01, -5.636771e+00, -6.563418e-01])
 loglike: -3626.18603515625
 message: 'Optimization terminated successfully'
    nfev: 119
     nit: 39
    njev: 39
  status: 0
 success: True
       x: array([-2.376162e+00, -2.177859e+00, -3.724838e+00, -6.710735e-01, -2.063628e-01, -2.171617e-03,  3.547370e-04,
       -5.285348e-03, -1.282182e-02, -9.697886e-03, -4.919956e-03, -5.133474e-02])

In [13]:
# The JAX optimizer must report convergence and reach the same maximum
# log-likelihood as the reference estimate checked earlier (-3626.186).
assert rj.success, rj.message
assert rj.loglike == approx(-3626.18603515625)

In [14]:
# Matrix of second derivatives of the log-likelihood from JAX, evaluated at the
# current parameters (the JAX estimates from `jax_maximize_loglike`).
h = m.jax_d2_loglike(m.pvals)

In [15]:
# Square roots of the diagonal of the pseudo-inverse of `h`; these should
# match `m.pstderr` as displayed further down.
h_pinv = np.linalg.pinv(h)
np.sqrt(np.diag(h_pinv))

array([ 3.045759e-01,  1.046350e-01,  1.776854e-01,  1.325887e-01,  1.941035e-01,  1.553271e-03,  2.537732e-03,
        1.828749e-03,  5.326436e-03,  3.033601e-03,  2.388836e-04,  3.099270e-03], dtype=float32)

In [16]:
# Inverse Hessian of the log-likelihood from JAX at the current parameters.
# NOTE(review): `ihess` is not used by any later cell in view; consider
# asserting it against `np.linalg.pinv(h)` (check the sign convention first).
ihess = m.jax_invhess_loglike(m.pvals)

In [17]:
# Parameter covariance via JAX; the trailing `;` suppresses the returned value.
# Presumably this also populates `m.pstderr` (displayed next) — confirm.
m.jax_param_cov(m.pvals);

In [18]:
# Standard errors stored on the model must agree with those computed directly
# from the pseudo-inverse of the JAX Hessian `h`. Both are float32, so a 1e-5
# relative tolerance absorbs rounding differences.
assert m.pstderr == approx(np.sqrt(np.linalg.pinv(h).diagonal()), rel=1e-5)
m.pstderr

array([ 3.045759e-01,  1.046351e-01,  1.776854e-01,  1.325887e-01,  1.941035e-01,  1.553271e-03,  2.537732e-03,
        1.828749e-03,  5.326436e-03,  3.033601e-03,  2.388836e-04,  3.099270e-03], dtype=float32)

In [19]:
# Parameter frame: current values, initial/null values, bounds, holdfast flags,
# best values, and standard errors.
m.pf

Unnamed: 0,value,initvalue,nullvalue,minimum,maximum,holdfast,note,best,std_err
ASC_BIKE,-2.376162,0.0,0.0,-inf,inf,0,,-2.376328,0.304576
ASC_SR2,-2.177859,0.0,0.0,-inf,inf,0,,-2.178014,0.104635
ASC_SR3P,-3.724838,0.0,0.0,-inf,inf,0,,-3.725078,0.177685
ASC_TRAN,-0.671074,0.0,0.0,-inf,inf,0,,-0.670861,0.132589
ASC_WALK,-0.206363,0.0,0.0,-inf,inf,0,,-0.206775,0.194104
hhinc#2,-0.002172,0.0,0.0,-inf,inf,0,,-0.00217,0.001553
hhinc#3,0.000355,0.0,0.0,-inf,inf,0,,0.000358,0.002538
hhinc#4,-0.005285,0.0,0.0,-inf,inf,0,,-0.005286,0.001829
hhinc#5,-0.012822,0.0,0.0,-inf,inf,0,,-0.012808,0.005326
hhinc#6,-0.009698,0.0,0.0,-inf,inf,0,,-0.009686,0.003034
