In [None]:
import larch as lx
import pandas as pd
from larch import P,X

import numpy as np
import xarray as xr
from pytest import approx

In [None]:
raw = pd.read_csv(lx.example_file('swissmetro.csv.gz'))

# Train and Swissmetro costs are zeroed wherever GA != 0 (GA presumably flags a
# travel pass covering these fares -- confirm against the dataset documentation).
# int * bool keeps the integer dtype of the source cost columns.
raw['SM_COST'] = raw['SM_CO'] * (raw['GA'] == 0)
raw['TRAIN_COST'] = raw['TRAIN_CO'] * (raw['GA'] == 0)

# Costs and travel times are divided by 100, so the estimated coefficients
# are expressed per 100 units of the original variables.
raw['TRAIN_COST_SCALED'] = raw['TRAIN_COST'] / 100
raw['TRAIN_TT_SCALED'] = raw['TRAIN_TT'] / 100

raw['SM_COST_SCALED'] = raw['SM_COST'] / 100
raw['SM_TT_SCALED'] = raw['SM_TT'] / 100

raw['CAR_CO_SCALED'] = raw['CAR_CO'] / 100
raw['CAR_TT_SCALED'] = raw['CAR_TT'] / 100

# Car and train are only treated as available in rows where SP != 0.
raw['CAR_AV_SP'] = raw['CAR_AV'] * (raw['SP'] != 0)
raw['TRAIN_AV_SP'] = raw['TRAIN_AV'] * (raw['SP'] != 0)

# Estimation sample flag: PURPOSE is 1 or 3 and a choice was recorded (CHOICE != 0).
raw['keep'] = raw['PURPOSE'].isin([1, 3]) & (raw['CHOICE'] != 0)

In [None]:
# Inspect the prepared frame, including the derived columns added in the cell above.
raw

In [None]:
# Wrap the prepared frame as a larch Dataset, then keep only the cases flagged
# by the boolean `keep` column (presumed semantics of `query_cases` -- confirm).
data = (
    lx.Dataset.construct
    .from_idco(raw)
    .dc.query_cases('keep')
)

In [None]:
m1 = lx.Model(
#     data.dc.set_altids([1,2,3])
)
m1.availability_co_vars = {
    1: "TRAIN_AV_SP",
    2: "SM_AV",
    3: "CAR_AV_SP",
}
m1.choice_co_code = 'CHOICE'

m1.utility_co[1] = P("ASC_TRAIN") + X("TRAIN_COST_SCALED") * P("B_COST")
m1.utility_co[2] = X("SM_COST_SCALED") * P("B_COST")
m1.utility_co[3] = P("ASC_CAR") + X("CAR_CO_SCALED") * P("B_COST")

In [None]:
m1.pnames

In [None]:
m2 = lx.Model(
#     data.dc.set_altids([1,2,3])
)
m2.availability_co_vars = {
    1: "TRAIN_AV_SP",
    2: "SM_AV",
    3: "CAR_AV_SP",
}
m2.choice_co_code = 'CHOICE'

m2.utility_co[1] = P("ASC_TRAIN") + X("TRAIN_TT_SCALED") * P("B_TIME") + X("TRAIN_COST_SCALED") * P("B_COST")
m2.utility_co[2] = X("SM_TT_SCALED") * P("B_TIME") + X("SM_COST_SCALED") * P("B_COST")
m2.utility_co[3] = P("ASC_CAR") + X("CAR_TT_SCALED") * P("B_TIME") + X("CAR_CO_SCALED") * P("B_COST")

In [None]:
m3 = lx.Model(
#     data.dc.set_altids([1,2,3])
)
m3.availability_co_vars = {
    1: "TRAIN_AV_SP",
    2: "SM_AV",
    3: "CAR_AV_SP",
}
m3.choice_co_code = 'CHOICE'

m3.utility_co[1] = X("TRAIN_COST_SCALED") * P("Z_COST")
m3.utility_co[2] = X("SM_COST_SCALED") * P("Z_COST")
m3.utility_co[3] = X("CAR_CO_SCALED") * P("Z_COST")

# m3.groupid = 'ID'

In [None]:
import jax.numpy as jnp

In [None]:
from larch.folding import fold_dataset, _group_breaks, dissolve_zero_variance

In [None]:
# Build per-respondent data for the class-membership model: number each
# observation within its ID (1-based), pivot to an (ID, ingroup) xarray, collapse
# variables via `dissolve_zero_variance` (presumably those constant within an ID),
# then drop whatever still depends on `ingroup`.
df = data.to_dataframe().assign(
    ingroup=lambda frame: frame.groupby("ID").cumcount() + 1,
)
qq = dissolve_zero_variance(
    df.set_index(['ID', 'ingroup'], drop=True).to_xarray(),
    'ingroup',
)
classdata = qq.dc.set_altids([101, 102, 103]).drop_dims("ingroup")

In [None]:
# Set the dataset accessor's CASEID to "ID" -- presumably marks the respondent
# ID as the case identifier for the class-membership data.
classdata.dc.CASEID = "ID"

In [None]:
# Class-membership model over the per-respondent data. Alternatives 101-103 are
# the latent classes; 101 gets no utility term, so it is presumably the reference.
mk = lx.Model(classdata)
mk.utility_co[102] = P("W_OTHER")
mk.utility_co[103] = P("W_COST")
mk.groupid = "ID"

In [None]:
# Latent class model: `mk` weights the classes (presumably class shares); each
# class id maps to its own choice model, all evaluated on the case-level data.
b = lx.LatentClass(
    mk,
    {
        101: m1,
        102: m2,
        103: m3,
    },
    datatree=data.dc.set_altids([1, 2, 3]),
    groupid="ID",
)

In [None]:
# Hold Z_COST fixed at -10000 (it stays at -1e4 in the final pvals check below).
# With costs divided by 100, class 103's utilities become -100 * cost, which is
# steep enough that the class presumably acts as "pick the cheapest available
# alternative" -- confirm this is the intended interpretation.
b.lock(Z_COST=-10000)

In [None]:
# Log-likelihood at the current parameter values (Z_COST locked).
b.loglike()

In [None]:
# d_loglike: presumably the gradient with respect to the parameters.
b.d_loglike()

In [None]:
# Repeat of the loglike cell -- presumably checks that re-evaluation returns the same value.
b.loglike()

In [None]:
# Repeat of the gradient cell, for the same comparison.
b.d_loglike()

In [None]:
# Estimate with the non-JAX `maximize_loglike`, using method='slsqp'.
b.maximize_loglike(method='slsqp')

In [None]:
# Total weight seen by the latent class model (presumably the sum of case weights).
b.total_weight()

In [None]:
# Parameter covariance at the current parameter values.
b.calculate_parameter_covariance()

In [None]:
b.parameter_summary()

In [None]:
# Debug inspection through private attributes (may break across larch versions):
# the `wt` array of the model stored under "classmodel" (presumably the class-membership model).
b._models["classmodel"]._data_arrays.wt

In [None]:
# Dataset attached to the class-101 choice model.
b._models[101].dataset

In [None]:
# Regression check: JAX log-likelihood at the current parameter values.
# NOTE(review): -6867.245 looks like the log-likelihood at the starting values; this
# only holds if the earlier `b.maximize_loglike(method='slsqp')` cell left `b.pvals`
# unchanged -- confirm under Restart & Run All.
assert b.jax_loglike(b.pvals) == approx(-6867.245, rel=1e-4)

In [None]:
# Display the value asserted above.
b.jax_loglike(b.pvals)

In [None]:
# Gradient via the JAX path.
b.jax_d_loglike(b.pvals) # 5.25s recorded once; NOTE(review): first call may include JAX compilation -- re-time locally

In [None]:
# Repeat call, presumably to compare its timing against the first call.
b.jax_d_loglike(b.pvals) # 5.25s as above; NOTE(review): likely copied from the first cell -- re-time locally

In [None]:
# Estimate the free parameters with the JAX-based optimizer.
result = b.jax_maximize_loglike()

In [None]:
result

In [None]:
# Regression check of the maximized log-likelihood.
assert result.loglike == approx(-4474.478515625, rel=1e-5)

In [None]:
result.loglike

In [None]:
# Parameter table after jax_maximize_loglike.
b.parameters.to_dataframe()

In [None]:
# Unpack jax_param_cov; by name, presumably (std errors, Hessian, inverse Hessian).
se, hess, ihess = b.jax_param_cov(b.pvals)

In [None]:
# Second derivatives at the estimates, for comparison with `hess`.
b.jax_d2_loglike(b.pvals)

In [None]:
hess

In [None]:
# Re-display the parameter table after jax_param_cov (presumably now with std errors).
b.parameters.to_dataframe()

In [None]:
# Regression check of standard errors. The last entry corresponds to the locked
# Z_COST (its pval is -1e4 below); with an expected value of exactly 0, `approx`
# falls back to its tiny default absolute tolerance, so that stderr must be ~0.
# NOTE(review): order presumably follows the sorted parameter names -- confirm vs b.pnames.
assert b.pstderr == approx(np.array([ 
    0.048158,  0.069796,  0.069555,  0.106282,  0.161079,  0.11945 ,  0.
]), rel=5e-3)

In [None]:
b.pstderr

In [None]:
# Regression check of the estimates; the final -1e4 is the locked Z_COST.
assert b.pvals == approx(np.array([ 
    6.079781e-02, -9.362056e-01, -1.159657e+00, -3.095285e+00, -7.734768e-01,  1.155985e+00, -1.000000e+04
]), rel=5e-3)

In [None]:
b.pvals