In [None]:
import larix as lx
import pandas as pd
from larch.roles import P,X

import numpy as np
import xarray as xr
from pytest import approx

In [None]:
# Load the Swissmetro example data and derive the model inputs in a single
# chained `assign`, so re-running this cell always starts from the file on disk.
# Callables are evaluated in keyword order, so later columns may use earlier ones.
raw = pd.read_csv(lx.example_file('swissmetro.csv.gz')).assign(
    # Costs are zeroed out whenever GA is non-zero.
    SM_COST=lambda df: df['SM_CO'] * df['GA'].eq(0),
    TRAIN_COST=lambda df: df['TRAIN_CO'] * df['GA'].eq(0),
    # Cost and travel-time columns divided by 100.
    TRAIN_COST_SCALED=lambda df: df['TRAIN_COST'] / 100,
    TRAIN_TT_SCALED=lambda df: df['TRAIN_TT'] / 100,
    SM_COST_SCALED=lambda df: df['SM_COST'] / 100,
    SM_TT_SCALED=lambda df: df['SM_TT'] / 100,
    CAR_CO_SCALED=lambda df: df['CAR_CO'] / 100,
    CAR_TT_SCALED=lambda df: df['CAR_TT'] / 100,
    # Availability flags, forced to zero on rows where SP == 0.
    CAR_AV_SP=lambda df: df['CAR_AV'] * df['SP'].ne(0),
    TRAIN_AV_SP=lambda df: df['TRAIN_AV'] * df['SP'].ne(0),
    # Row filter: PURPOSE is 1 or 3, and CHOICE is not 0.
    keep=lambda df: df['PURPOSE'].isin([1, 3]) & df['CHOICE'].ne(0),
)

In [None]:
# Show the loaded frame, including the derived columns.
raw

In [None]:
# Wrap the frame as a Dataset via `from_idco`, then restrict it to the cases
# selected by the boolean `keep` column.
data = (
    lx.Dataset
    .from_idco(raw)
    .query_cases('keep')
)

In [None]:
# First class-specific model over alternatives 1, 2, 3: utilities use cost only.
m1 = lx.Model(data.set_altids([1, 2, 3]))

# Availability column for each alternative id, then the column holding the choice.
m1.availability_co_vars = {1: "TRAIN_AV_SP", 2: "SM_AV", 3: "CAR_AV_SP"}
m1.choice_co_code = 'CHOICE'

# Alternatives 1 and 3 carry their own constant; alternative 2 has none.
# B_COST is shared across all three alternatives.
m1.utility_co[1] = (
    P("ASC_TRAIN")
    + X("TRAIN_COST_SCALED") * P("B_COST")
)
m1.utility_co[2] = X("SM_COST_SCALED") * P("B_COST")
m1.utility_co[3] = (
    P("ASC_CAR")
    + X("CAR_CO_SCALED") * P("B_COST")
)

In [None]:
# Second class-specific model over alternatives 1, 2, 3: utilities use both
# travel time and cost.
m2 = lx.Model(data.set_altids([1, 2, 3]))

# Availability column for each alternative id, then the column holding the choice.
m2.availability_co_vars = {1: "TRAIN_AV_SP", 2: "SM_AV", 3: "CAR_AV_SP"}
m2.choice_co_code = 'CHOICE'

# Alternatives 1 and 3 carry their own constant; alternative 2 has none.
# B_TIME and B_COST are shared across all three alternatives.
m2.utility_co[1] = (
    P("ASC_TRAIN")
    + X("TRAIN_TT_SCALED") * P("B_TIME")
    + X("TRAIN_COST_SCALED") * P("B_COST")
)
m2.utility_co[2] = (
    X("SM_TT_SCALED") * P("B_TIME")
    + X("SM_COST_SCALED") * P("B_COST")
)
m2.utility_co[3] = (
    P("ASC_CAR")
    + X("CAR_TT_SCALED") * P("B_TIME")
    + X("CAR_CO_SCALED") * P("B_COST")
)

In [None]:
# Model over ids 101 and 102; it is later passed as the first argument to
# LatentClass, with those ids keyed to m1 and m2.
# Only 102 receives a utility term (the constant W_OTHER); 101 is given none.
mk = lx.Model(data.set_altids([101, 102]))
mk.utility_co[102] = P("W_OTHER")

In [None]:
# Inspect m1's `pf` attribute (presumably its parameter frame — TODO confirm).
m1.pf

In [None]:
# Assemble the latent class model: `mk` goes first, followed by the mapping
# from its ids (101, 102) to the per-class models, with the data tree keyed
# on alternatives 1, 2, 3.
b = lx.LatentClass(mk, {101: m1, 102: m2}, datatree=data.set_altids([1, 2, 3]))

In [None]:
# Rebuild the model's data arrays before evaluating it
# (presumably needed once after construction — TODO confirm).
b.reflow_data_arrays()

In [None]:
# The LatentClass model was constructed without any grouping argument,
# so `groupid` is expected to be unset.
assert b.groupid is None

In [None]:
# Evaluate the log likelihood at the current (pre-estimation) parameter values.
b.jax_loglike(b.pvals)

In [None]:
# Evaluate again with the same parameter values
# (presumably to exercise a repeated call — confirm intent).
b.jax_loglike(b.pvals)

In [None]:
# Regression check: log likelihood at the pre-estimation parameter values,
# compared with a relative tolerance of 1e-4.
assert b.jax_loglike(b.pvals) == approx(-6964.6445, rel=1e-4)

In [None]:
# Display the log likelihood once more; the parameter values have not been changed.
b.jax_loglike(b.pvals)

In [None]:
# Evaluate `jax_d_loglike` (presumably the gradient of the log likelihood — confirm)
# at the current parameter values. The trailing note is the author's recorded timing.
b.jax_d_loglike(b.pvals) # 5.25s

In [None]:
# Repeat the same evaluation.
# NOTE(review): the timing note is identical to the previous cell's and may be
# a copy-paste leftover — confirm.
b.jax_d_loglike(b.pvals) # 5.25s

In [None]:
# Optional benchmark, left disabled; the 15ms figure is the author's recorded timing.
#%timeit b.jax_d_loglike(b.pvals)

In [None]:
# Estimate the model by maximizing the log likelihood and keep the returned result.
result = b.jax_maximize_loglike()

In [None]:
# Regression check on the log likelihood reported by the estimation result
# (relative tolerance 1e-5).
assert result.loglike == approx(-5208.49609375, rel=1e-5)

In [None]:
# Show the model's parameters as a DataFrame.
b.parameters.to_dataframe()

In [None]:
# Compute the parameter covariance at the current parameter values; the trailing
# semicolon suppresses the cell output.
# Presumably this also populates the standard errors in `b.pstderr` — TODO confirm.
b.jax_param_cov(b.pvals);

In [None]:
# Show the model's parameters as a DataFrame again, now that `jax_param_cov` has been called.
b.parameters.to_dataframe()

In [None]:
# Regression check on the five standard errors (relative tolerance 5e-3).
assert b.pstderr == approx(
    np.array([0.050481, 0.060852, 0.061178, 0.17538, 0.116116]),
    rel=5e-3,
)

In [None]:
# Display the standard errors.
b.pstderr

In [None]:
# Regression check on the five parameter values (relative tolerance 5e-3).
assert b.pvals == approx(
    np.array([0.124529, -0.397936, -1.264057, -2.799115, 1.091897]),
    rel=5e-3,
)