In [1]:
import larix as lx
import pandas as pd
from larch.roles import P,X

import numpy as np
import xarray as xr
from pytest import approx


### larch.numba is experimental, and not feature-complete ###
 the first time you import on a new system, this package will
 compile optimized binaries for your machine, which may take 
 a little while, please be patient 

OMP: Info #273: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


In [2]:
# Load the Swissmetro example file that ships with larix into a flat table.
raw = pd.read_csv(lx.example_file('swissmetro.csv.gz'))

# Swissmetro and train costs are multiplied by the (GA == 0) indicator, so
# they are zero on every row where GA is nonzero.
ga_is_zero = raw['GA'] == 0
raw['SM_COST'] = raw['SM_CO'] * ga_is_zero
raw['TRAIN_COST'] = raw['TRAIN_CO'] * ga_is_zero

# Scaled copies (value / 100) of the cost and TT columns.
for scaled_col, source_col in [
    ('TRAIN_COST_SCALED', 'TRAIN_COST'),
    ('TRAIN_TT_SCALED', 'TRAIN_TT'),
    ('SM_COST_SCALED', 'SM_COST'),
    ('SM_TT_SCALED', 'SM_TT'),
    ('CAR_CO_SCALED', 'CAR_CO'),
    ('CAR_TT_SCALED', 'CAR_TT'),
]:
    raw[scaled_col] = raw[source_col] / 100

# Car and train availability are zeroed on rows where SP is 0.
sp_is_nonzero = raw['SP'] != 0
raw['CAR_AV_SP'] = raw['CAR_AV'] * sp_is_nonzero
raw['TRAIN_AV_SP'] = raw['TRAIN_AV'] * sp_is_nonzero

# Row filter applied when the Dataset is built: PURPOSE is 1 or 3, and
# CHOICE is not 0.
raw['keep'] = raw['PURPOSE'].isin([1, 3]) & (raw['CHOICE'] != 0)

In [3]:
# Preview the prepared table, including the derived columns.
raw

Unnamed: 0,GROUP,SURVEY,SP,ID,PURPOSE,FIRST,TICKET,WHO,LUGGAGE,AGE,...,TRAIN_COST,TRAIN_COST_SCALED,TRAIN_TT_SCALED,SM_COST_SCALED,SM_TT_SCALED,CAR_CO_SCALED,CAR_TT_SCALED,CAR_AV_SP,TRAIN_AV_SP,keep
0,2,0,1,1,1,0,1,1,0,3,...,48,0.48,1.12,0.52,0.63,0.65,1.17,1,1,True
1,2,0,1,1,1,0,1,1,0,3,...,48,0.48,1.03,0.49,0.60,0.84,1.17,1,1,True
2,2,0,1,1,1,0,1,1,0,3,...,48,0.48,1.30,0.58,0.67,0.52,1.17,1,1,True
3,2,0,1,1,1,0,1,1,0,3,...,40,0.40,1.03,0.52,0.63,0.52,0.72,1,1,True
4,2,0,1,1,1,0,1,1,0,3,...,36,0.36,1.30,0.42,0.63,0.84,0.90,1,1,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10723,3,1,1,1192,4,1,7,1,0,5,...,13,0.13,1.48,0.17,0.93,0.56,1.56,1,1,False
10724,3,1,1,1192,4,1,7,1,0,5,...,12,0.12,1.48,0.16,0.96,0.70,0.96,1,1,False
10725,3,1,1,1192,4,1,7,1,0,5,...,16,0.16,1.48,0.16,0.93,0.56,0.96,1,1,False
10726,3,1,1,1192,4,1,7,1,0,5,...,16,0.16,1.78,0.17,0.96,0.91,0.96,1,1,False


In [4]:
# Wrap the flat (idco) table as a larix Dataset, then keep only the cases
# selected by the boolean 'keep' column.
idco_dataset = lx.Dataset.construct.from_idco(raw)
data = idco_dataset.dc.query_cases('keep')

In [5]:
# Class model 1: utilities depend on cost only (one shared B_COST), plus
# alternative-specific constants on alternatives 1 (train) and 3 (car).
m1 = lx.Model(data.dc.set_altids([1, 2, 3]))

# Availability column for each alternative code.
m1.availability_co_vars = dict(
    zip((1, 2, 3), ("TRAIN_AV_SP", "SM_AV", "CAR_AV_SP"))
)
m1.choice_co_code = 'CHOICE'

# Build each alternative's cost term first, then attach constants.
train_cost_term = X("TRAIN_COST_SCALED") * P("B_COST")
sm_cost_term = X("SM_COST_SCALED") * P("B_COST")
car_cost_term = X("CAR_CO_SCALED") * P("B_COST")

m1.utility_co[1] = P("ASC_TRAIN") + train_cost_term
m1.utility_co[2] = sm_cost_term
m1.utility_co[3] = P("ASC_CAR") + car_cost_term

In [6]:
# Inspect the parameters defined on m1.
m1.parameters

In [7]:
# Class model 2: utilities depend on both TT (B_TIME) and cost (B_COST), plus
# alternative-specific constants on alternatives 1 (train) and 3 (car).
m2 = lx.Model(data.dc.set_altids([1, 2, 3]))

# Availability column for each alternative code.
m2.availability_co_vars = dict(
    zip((1, 2, 3), ("TRAIN_AV_SP", "SM_AV", "CAR_AV_SP"))
)
m2.choice_co_code = 'CHOICE'

m2.utility_co[1] = (
    P("ASC_TRAIN")
    + X("TRAIN_TT_SCALED") * P("B_TIME")
    + X("TRAIN_COST_SCALED") * P("B_COST")
)
m2.utility_co[2] = (
    X("SM_TT_SCALED") * P("B_TIME")
    + X("SM_COST_SCALED") * P("B_COST")
)
m2.utility_co[3] = (
    P("ASC_CAR")
    + X("CAR_TT_SCALED") * P("B_TIME")
    + X("CAR_CO_SCALED") * P("B_COST")
)

In [8]:
# Class model 3: utility is cost times a single Z_COST parameter, with no
# alternative-specific constants.
m3 = lx.Model(data.dc.set_altids([1, 2, 3]))

# Availability column for each alternative code.
m3.availability_co_vars = dict(
    zip((1, 2, 3), ("TRAIN_AV_SP", "SM_AV", "CAR_AV_SP"))
)
m3.choice_co_code = 'CHOICE'

# Every alternative uses the same Z_COST coefficient on its own cost column.
for alt_code, cost_column in (
    (1, "TRAIN_COST_SCALED"),
    (2, "SM_COST_SCALED"),
    (3, "CAR_CO_SCALED"),
):
    m3.utility_co[alt_code] = X(cost_column) * P("Z_COST")

# Group observations by the ID column.
m3.groupid = 'ID'

In [9]:
import jax.numpy as jnp

In [10]:
# NOTE(review): disabled scratch code. It would set m3's parameter value to
# -10000 directly; `b.lock(Z_COST=-10000)` further down appears to serve the
# same purpose -- confirm, then delete this cell.
# m3.unmangle()
# m3.pvals = [-10000]

In [11]:
# NOTE(review): disabled scratch code. If it is ever re-enabled, the last line
# presumably needs call parentheses (`m3.reflow_data_arrays()`); as written it
# would only reference the attribute without invoking it.
# from larix.folding import fold_dataset
# m3.dataset = fold_dataset(m3.dataset, data['ID'])
# m3.reflow_data_arrays

In [12]:
# Class-membership model over codes 101-103 (these codes key the per-class
# models when the LatentClass is assembled). Code 101 gets no utility here;
# 102 and 103 each get a single constant.
mk = lx.Model(data.dc.set_altids([101, 102, 103]))
for class_code, constant_name in ((102, "W_OTHER"), (103, "W_COST")):
    mk.utility_co[class_code] = P(constant_name)

In [13]:
# Assemble the latent class model: `mk` is the membership model, and each
# membership code maps to the choice model used within that class.
class_models = {101: m1, 102: m2, 103: m3}
choice_tree = data.dc.set_altids([1, 2, 3])
b = lx.LatentClass(mk, class_models, datatree=choice_tree, groupid="ID")

In [14]:
# Display the LatentClass object (plain default repr).
b

<larix.latent_class.LatentClass at 0x14317f760>

In [15]:
# Fix Z_COST at -10000 so it is not estimated; the parameter table later shows
# it with holdfast=1 and value -10000.
# NOTE(review): presumably this makes class 3 choose the cheapest available
# alternative almost deterministically -- confirm the intent.
b.lock(Z_COST=-10000)

In [16]:
# NOTE(review): presumably (re)builds the model's internal data arrays from the
# datatree before likelihood evaluation -- confirm against larix docs.
b.reflow_data_arrays()

In [17]:
# Check the grouping variable configured on the model (expected: 'ID').
b.groupid

'ID'

In [18]:
# Log-likelihood at the current parameter values (before estimation).
b.jax_loglike(b.pvals)



DeviceArray(-6867.245, dtype=float32)

In [19]:
# Log-likelihood with every parameter set to zero. This also zeroes the locked
# Z_COST, which is why the value differs from the evaluation at b.pvals.
b.jax_loglike(np.zeros_like(b.pvals))

DeviceArray(-6964.664, dtype=float32)

In [20]:
# Regression check on the pre-estimation log-likelihood.
starting_loglike = b.jax_loglike(b.pvals)
assert starting_loglike == approx(-6867.245, rel=1e-4)

In [21]:
# Re-evaluate the log-likelihood at the current values; the result matches the
# earlier evaluation.
b.jax_loglike(b.pvals)

DeviceArray(-6867.245, dtype=float32)

In [22]:
# Gradient of the log-likelihood at the current parameter values. This first
# call triggers JAX/XLA compilation (see the "Compiling module" log below);
# the trailing note presumably records an observed run time.
b.jax_d_loglike(b.pvals) # 5.25s

2022-03-05 13:33:17.076944: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:55] 
********************************
Slow compile?  XLA was built without compiler optimizations, which can be slow.  Try rebuilding with -c opt.
Compiling module jit_func__20.7357
********************************


DeviceArray([-1.104774e+02, -1.545956e+03, -2.189165e+01, -9.183448e+02, -1.658510e+02,  8.292550e+01, -1.490116e-08],            dtype=float32)

In [23]:
# Repeat the gradient evaluation.
# NOTE(review): the "5.25s" note is identical to the previous cell's and looks
# copy-pasted; re-time this call (e.g. with %%time) or drop the note.
b.jax_d_loglike(b.pvals) # 5.25s

DeviceArray([-1.104774e+02, -1.545956e+03, -2.189165e+01, -9.183448e+02, -1.658510e+02,  8.292550e+01, -1.490116e-08],            dtype=float32)

In [24]:
# Estimate the model by maximizing the log-likelihood; keep the optimizer
# result for inspection and the checks that follow.
result = b.jax_maximize_loglike()

In [25]:
# Show the optimizer result (gradient, log-likelihood, status, solution).
result

     jac: array([ 1.723166e-01, -3.695679e-02, -7.270813e-02,  7.333374e-02, -1.975441e-02,  2.137756e-02, -2.607703e-08])
 loglike: -4474.47900390625
 message: 'Optimization terminated successfully'
    nfev: 44
     nit: 15
    njev: 15
  status: 0
 success: True
       x: array([ 6.079767e-02, -9.362056e-01, -1.159657e+00, -3.095285e+00, -7.734761e-01,  1.155985e+00, -1.000000e+04])

In [26]:
# Regression check on the log-likelihood at convergence.
fitted_loglike = result.loglike
assert fitted_loglike == approx(-4474.478515625, rel=1e-5)

In [27]:
# Log-likelihood at the optimum.
result.loglike

-4474.47900390625

In [28]:
# Parameter table after estimation (no standard errors yet).
b.parameters.to_dataframe()

Unnamed: 0_level_0,value,initvalue,nullvalue,minimum,maximum,holdfast
param_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
ASC_CAR,0.060798,0.0,0.0,-inf,inf,0
ASC_TRAIN,-0.936206,0.0,0.0,-inf,inf,0
B_COST,-1.159657,0.0,0.0,-inf,inf,0
B_TIME,-3.095285,0.0,0.0,-inf,inf,0
W_COST,-0.773476,0.0,0.0,-inf,inf,0
W_OTHER,1.155985,0.0,0.0,-inf,inf,0
Z_COST,-10000.0,-10000.0,0.0,-inf,inf,1


In [29]:
# Compute parameter covariance information at the fitted values. Going by the
# names, the three results are standard errors, the Hessian and its inverse.
# After this call the parameter table gains a `std_err` column.
se, hess, ihess = b.jax_param_cov(b.pvals)

2022-03-05 13:33:35.991804: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:55] 
********************************
Slow compile?  XLA was built without compiler optimizations, which can be slow.  Try rebuilding with -c opt.
Compiling module jit_func__21.11151
********************************


In [30]:
# Parameter table again, now including the `std_err` column.
b.parameters.to_dataframe()

Unnamed: 0_level_0,value,initvalue,nullvalue,minimum,maximum,holdfast,std_err
param_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
ASC_CAR,0.060798,0.0,0.0,-inf,inf,0,0.048158
ASC_TRAIN,-0.936206,0.0,0.0,-inf,inf,0,0.069796
B_COST,-1.159657,0.0,0.0,-inf,inf,0,0.069555
B_TIME,-3.095285,0.0,0.0,-inf,inf,0,0.106282
W_COST,-0.773476,0.0,0.0,-inf,inf,0,0.161079
W_OTHER,1.155985,0.0,0.0,-inf,inf,0,0.11945
Z_COST,-10000.0,-10000.0,0.0,-inf,inf,1,0.0


In [31]:
# Regression check on the standard errors (the locked Z_COST has 0).
expected_stderr = np.array(
    [0.048158, 0.069796, 0.069555, 0.106282, 0.161079, 0.11945, 0.]
)
assert b.pstderr == approx(expected_stderr, rel=5e-3)

In [32]:
# Standard errors as an array, in the model's parameter order.
b.pstderr

array([ 0.048158,  0.069796,  0.069555,  0.106282,  0.161079,  0.11945 ,  0.      ], dtype=float32)

In [33]:
# Regression check on the fitted parameter values (Z_COST stays at -10000).
expected_pvals = np.array([
    6.079781e-02,
    -9.362056e-01,
    -1.159657e+00,
    -3.095285e+00,
    -7.734768e-01,
    1.155985e+00,
    -1.000000e+04,
])
assert b.pvals == approx(expected_pvals, rel=5e-3)

In [34]:
# Fitted parameter values as an array, in the model's parameter order.
b.pvals

array([ 6.079767e-02, -9.362056e-01, -1.159657e+00, -3.095285e+00, -7.734761e-01,  1.155985e+00, -1.000000e+04])