In [1]:
import larix as lx
import pandas as pd
from larch.roles import P,X

import numpy as np
import xarray as xr
from pytest import approx


### larch.numba is experimental, and not feature-complete ###
 the first time you import on a new system, this package will
 compile optimized binaries for your machine, which may take 
 a little while, please be patient 

OMP: Info #273: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


In [2]:
# Load the Swissmetro sample bundled with the package, then derive every
# column referenced by the models further down.
raw = pd.read_csv(lx.example_file('swissmetro.csv.gz'))

# Train / Swissmetro costs are zeroed on rows where GA != 0 (presumably GA
# marks a travel-pass holder -- confirm against the dataset codebook).
# Car cost is deliberately left untouched.
raw['SM_COST'] = raw['SM_CO'] * (raw['GA'] == 0)
raw['TRAIN_COST'] = raw['TRAIN_CO'] * (raw['GA'] == 0)

# Every cost and travel-time attribute is divided by the same factor; keep it
# in one named constant instead of six copies of the literal 100.
# Pairs are (new column, source column), listed in the original column order.
COST_TIME_SCALE = 100
raw = raw.assign(**{
    scaled: raw[source] / COST_TIME_SCALE
    for scaled, source in [
        ('TRAIN_COST_SCALED', 'TRAIN_COST'),
        ('TRAIN_TT_SCALED', 'TRAIN_TT'),
        ('SM_COST_SCALED', 'SM_COST'),
        ('SM_TT_SCALED', 'SM_TT'),
        ('CAR_CO_SCALED', 'CAR_CO'),
        ('CAR_TT_SCALED', 'CAR_TT'),
    ]
})

# Car and train are marked unavailable on rows where SP == 0.
raw['CAR_AV_SP'] = raw['CAR_AV'] * (raw['SP'] != 0)
raw['TRAIN_AV_SP'] = raw['TRAIN_AV'] * (raw['SP'] != 0)

# Rows used for estimation: PURPOSE 1 or 3, and a non-zero CHOICE
# (CHOICE == 0 presumably means no recorded choice -- confirm).
raw['keep'] = raw['PURPOSE'].isin([1, 3]) & (raw['CHOICE'] != 0)

In [3]:
# Display the prepared frame. pandas abbreviates it to head/tail rows and
# elides middle columns (the "..." entries in the rendered output).
raw

Unnamed: 0,GROUP,SURVEY,SP,ID,PURPOSE,FIRST,TICKET,WHO,LUGGAGE,AGE,...,TRAIN_COST,TRAIN_COST_SCALED,TRAIN_TT_SCALED,SM_COST_SCALED,SM_TT_SCALED,CAR_CO_SCALED,CAR_TT_SCALED,CAR_AV_SP,TRAIN_AV_SP,keep
0,2,0,1,1,1,0,1,1,0,3,...,48,0.48,1.12,0.52,0.63,0.65,1.17,1,1,True
1,2,0,1,1,1,0,1,1,0,3,...,48,0.48,1.03,0.49,0.60,0.84,1.17,1,1,True
2,2,0,1,1,1,0,1,1,0,3,...,48,0.48,1.30,0.58,0.67,0.52,1.17,1,1,True
3,2,0,1,1,1,0,1,1,0,3,...,40,0.40,1.03,0.52,0.63,0.52,0.72,1,1,True
4,2,0,1,1,1,0,1,1,0,3,...,36,0.36,1.30,0.42,0.63,0.84,0.90,1,1,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10723,3,1,1,1192,4,1,7,1,0,5,...,13,0.13,1.48,0.17,0.93,0.56,1.56,1,1,False
10724,3,1,1,1192,4,1,7,1,0,5,...,12,0.12,1.48,0.16,0.96,0.70,0.96,1,1,False
10725,3,1,1,1192,4,1,7,1,0,5,...,16,0.16,1.48,0.16,0.93,0.56,0.96,1,1,False
10726,3,1,1,1192,4,1,7,1,0,5,...,16,0.16,1.78,0.17,0.96,0.91,0.96,1,1,False


In [4]:
# Wrap the idco-style frame as a larch Dataset, then keep only the cases
# selected by the boolean 'keep' column (presumably rows where it is True).
data = (
    lx.Dataset.construct
    .from_idco(raw)
    .dc.query_cases('keep')
)

In [5]:
# Latent class 1: a cost-only model over alternatives 1, 2, 3
# (train, Swissmetro, car, judging by the column names used below).
m1 = lx.Model(data.dc.set_altids([1, 2, 3]))

# Alternative id -> availability column; chosen alternative is read from CHOICE.
m1.availability_co_vars = {1: "TRAIN_AV_SP", 2: "SM_AV", 3: "CAR_AV_SP"}
m1.choice_co_code = 'CHOICE'

# Utilities share a single cost coefficient; alternative 2 has no constant
# (presumably the reference alternative).
m1.utility_co[1] = (
    P("ASC_TRAIN")
    + X("TRAIN_COST_SCALED") * P("B_COST")
)
m1.utility_co[2] = X("SM_COST_SCALED") * P("B_COST")
m1.utility_co[3] = (
    P("ASC_CAR")
    + X("CAR_CO_SCALED") * P("B_COST")
)

In [6]:
# Latent class 2: same alternatives and availability as class 1, but the
# utilities also depend on travel time through B_TIME.
m2 = lx.Model(data.dc.set_altids([1, 2, 3]))

# Alternative id -> availability column; chosen alternative is read from CHOICE.
m2.availability_co_vars = {1: "TRAIN_AV_SP", 2: "SM_AV", 3: "CAR_AV_SP"}
m2.choice_co_code = 'CHOICE'

# Shared time and cost coefficients; alternative 2 has no constant.
m2.utility_co[1] = (
    P("ASC_TRAIN")
    + X("TRAIN_TT_SCALED") * P("B_TIME")
    + X("TRAIN_COST_SCALED") * P("B_COST")
)
m2.utility_co[2] = (
    X("SM_TT_SCALED") * P("B_TIME")
    + X("SM_COST_SCALED") * P("B_COST")
)
m2.utility_co[3] = (
    P("ASC_CAR")
    + X("CAR_TT_SCALED") * P("B_TIME")
    + X("CAR_CO_SCALED") * P("B_COST")
)

In [7]:
# Class-membership model: its alternatives 101 and 102 are the latent classes
# (the same keys used for the class-model mapping given to LatentClass).
mk = lx.Model(data.dc.set_altids([101, 102]))

# Only class 102 gets a constant; class 101 has no utility assigned here
# (presumably it acts as the zero-utility reference class).
mk.utility_co[102] = P("W_OTHER")

In [8]:
# Inspect class-1 parameters: as shown, all start at 0.0 and are unbounded.
m1.pf

Unnamed: 0,value,initvalue,nullvalue,minimum,maximum,holdfast,note
ASC_CAR,0.0,0.0,0.0,-inf,inf,0,
ASC_TRAIN,0.0,0.0,0.0,-inf,inf,0,
B_COST,0.0,0.0,0.0,-inf,inf,0,


In [9]:
# Combine the membership model with the two class models. Mapping keys
# (101, 102) match mk's alternative ids; the datatree carries the mode
# alternatives 1, 2, 3 used by m1 and m2.
class_models = None  # placeholder removed below; keeps nothing in namespace
del class_models
b = lx.LatentClass(
    mk,
    {101: m1, 102: m2},
    datatree=data.dc.set_altids([1, 2, 3]),
)

In [10]:
# Refresh the model's internal data arrays before evaluating it
# (NOTE(review): presumably builds the arrays the jax_* methods consume -- confirm in larch docs).
b.reflow_data_arrays()

In [11]:
# Sanity check: no group id is configured on the latent-class model
# (presumably meaning no panel/grouped-case structure -- confirm).
assert b.groupid is None

In [12]:
# Log-likelihood at the current (presumably all-zero, initial) parameter values.
# The result is a float32 JAX array, as the output shows.
b.jax_loglike(b.pvals)



DeviceArray(-6964.699, dtype=float32)

In [13]:
# Repeat evaluation; the output matches the previous cell, showing the result is stable.
b.jax_loglike(b.pvals)

DeviceArray(-6964.699, dtype=float32)

In [14]:
# Check the starting log-likelihood against the reference value. The JAX
# result is float32, hence the relative tolerance of 1e-4. np.testing is used
# instead of a bare `assert ... == approx(...)`: pytest's assertion rewriting is
# not active in a notebook kernel, so a failing bare assert carries no details.
np.testing.assert_allclose(b.jax_loglike(b.pvals), -6964.6445, rtol=1e-4)

In [15]:
# Third evaluation at the same parameters; identical to the earlier outputs.
b.jax_loglike(b.pvals)

DeviceArray(-6964.699, dtype=float32)

In [16]:
# Gradient of the log-likelihood: one entry per parameter (5 here, in the
# parameter-table order ASC_CAR, ASC_TRAIN, B_COST, B_TIME, W_OTHER).
b.jax_d_loglike(b.pvals) # 5.25s on first call (presumably includes JIT compilation)

DeviceArray([  -98.999954, -1541.4995  ,  -224.60873 ,  -923.5082  ,     0.      ], dtype=float32)

In [17]:
# Repeat call: should reuse the already-compiled function, so it is presumably far
# faster than the first call (the commented-out %timeit cell records ~15ms).
b.jax_d_loglike(b.pvals)

DeviceArray([  -98.999954, -1541.4995  ,  -224.60873 ,  -923.5082  ,     0.      ], dtype=float32)

In [18]:
# Disabled benchmark (recorded: ~15ms per call): %timeit b.jax_d_loglike(b.pvals)

In [19]:
# Estimate all parameters. As the following parameter tables show, this updates
# b's parameter values in place; the returned object exposes `.loglike`.
result = b.jax_maximize_loglike()

In [20]:
# Check the maximized log-likelihood against the reference value (rtol=1e-5).
# np.testing reports the actual vs. expected values on failure, which a bare
# `assert ... == approx(...)` does not do outside pytest.
np.testing.assert_allclose(result.loglike, -5208.49609375, rtol=1e-5)

In [21]:
# Estimated parameter values (no standard errors yet -- those appear after jax_param_cov).
b.parameters.to_dataframe()

Unnamed: 0_level_0,value,initvalue,nullvalue,minimum,maximum,holdfast
param_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
ASC_CAR,0.12459,0.0,0.0,-inf,inf,0
ASC_TRAIN,-0.397589,0.0,0.0,-inf,inf,0
B_COST,-1.264065,0.0,0.0,-inf,inf,0
B_TIME,-2.797749,0.0,0.0,-inf,inf,0
W_OTHER,1.094482,0.0,0.0,-inf,inf,0


In [22]:
# Compute the parameter covariance at the estimates; the trailing ';' suppresses
# the displayed return value. Afterwards the parameter table gains a std_err column.
b.jax_param_cov(b.pvals);

2022-03-05 13:27:57.085315: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:55] 
********************************
Slow compile?  XLA was built without compiler optimizations, which can be slow.  Try rebuilding with -c opt.
Compiling module jit_func__16.7038
********************************


In [23]:
# Parameter table again, now including the std_err column.
b.parameters.to_dataframe()

Unnamed: 0_level_0,value,initvalue,nullvalue,minimum,maximum,holdfast,std_err
param_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
ASC_CAR,0.12459,0.0,0.0,-inf,inf,0,0.050483
ASC_TRAIN,-0.397589,0.0,0.0,-inf,inf,0,0.060846
B_COST,-1.264065,0.0,0.0,-inf,inf,0,0.061178
B_TIME,-2.797749,0.0,0.0,-inf,inf,0,0.175583
W_OTHER,1.094482,0.0,0.0,-inf,inf,0,0.116456


In [24]:
# Standard errors vs. reference values, in parameter-table order
# (ASC_CAR, ASC_TRAIN, B_COST, B_TIME, W_OTHER). np.testing reports which
# elements mismatch; a bare assert with pytest.approx gives no detail outside pytest.
np.testing.assert_allclose(
    b.pstderr,
    [0.050481, 0.060852, 0.061178, 0.17538, 0.116116],
    rtol=5e-3,
)

In [25]:
# Raw standard-error array (float32), same order as the parameter table.
b.pstderr

array([ 0.050483,  0.060846,  0.061178,  0.175583,  0.116456], dtype=float32)

In [26]:
# Estimated parameter values vs. reference values, in parameter-table order
# (ASC_CAR, ASC_TRAIN, B_COST, B_TIME, W_OTHER); np.testing reports per-element
# mismatches, unlike a bare `assert ... == approx(...)` in a notebook kernel.
np.testing.assert_allclose(
    b.pvals,
    [0.124529, -0.397936, -1.264057, -2.799115, 1.091897],
    rtol=5e-3,
)