# Comparison of biogeme and jaxlogit on swissmetro data

This notebook compares, among other things, the estimation of a panel mixed logit model in Biogeme and in jaxlogit. The Biogeme code is mostly taken from the examples at https://biogeme.epfl.ch/sphinx/auto_examples/swissmetro/index.html

In [1]:
import pandas as pd

import biogeme.biogeme_logging as blog
import biogeme.biogeme as bio
from biogeme import models
from biogeme.expressions import Beta, bioDraws, log, MonteCarlo, PanelLikelihoodTrajectory
import biogeme.database as db
from biogeme.expressions import Variable

In [2]:
# Screen logger at INFO level, so Biogeme's progress messages show up in the cell outputs.
logger = blog.get_screen_logger(level=blog.INFO)

In [3]:
# Tab-separated Swissmetro data set, downloaded from transp-or.epfl.ch on every run.
SWISSMETRO_URL = "http://transp-or.epfl.ch/data/swissmetro.dat"
df = pd.read_table(SWISSMETRO_URL, sep='\t')

In [4]:
# Preview of the exclusion rule applied to the Biogeme database further down:
# a row is dropped when PURPOSE is neither 1 nor 3, or when CHOICE is 0.
is_excluded = ((df.PURPOSE != 1) & (df.PURPOSE != 3)) | (df.CHOICE == 0)
is_excluded.value_counts()

False    6768
True     3960
Name: count, dtype: int64

In [6]:
# Give Biogeme its own copy of the raw table. The row filtering below is then confined
# to `database.data` even if Biogeme edits its frame in place, so `df` stays intact and
# this cell can be re-run without reloading the file.
database = db.Database('swissmetro', df.copy())

# Biogeme expression handles for the columns of the data set.
GROUP = Variable('GROUP')
SURVEY = Variable('SURVEY')
SP = Variable('SP')
ID = Variable('ID')
PURPOSE = Variable('PURPOSE')
FIRST = Variable('FIRST')
TICKET = Variable('TICKET')
WHO = Variable('WHO')
LUGGAGE = Variable('LUGGAGE')
AGE = Variable('AGE')
MALE = Variable('MALE')
INCOME = Variable('INCOME')
GA = Variable('GA')
ORIGIN = Variable('ORIGIN')
DEST = Variable('DEST')
TRAIN_AV = Variable('TRAIN_AV')
CAR_AV = Variable('CAR_AV')
SM_AV = Variable('SM_AV')
TRAIN_TT = Variable('TRAIN_TT')
TRAIN_CO = Variable('TRAIN_CO')
TRAIN_HE = Variable('TRAIN_HE')
SM_TT = Variable('SM_TT')
SM_CO = Variable('SM_CO')
SM_HE = Variable('SM_HE')
SM_SEATS = Variable('SM_SEATS')
CAR_TT = Variable('CAR_TT')
CAR_CO = Variable('CAR_CO')
CHOICE = Variable('CHOICE')

# Exclusion rule as a Biogeme expression, with * and + acting as AND / OR on 0/1
# indicators: drop a row when PURPOSE is neither 1 nor 3, or when CHOICE is 0.
exclude = ((PURPOSE != 1) * (PURPOSE != 3) + (CHOICE == 0)) > 0
# The same rule evaluated on the raw frame, only to report how many rows it matches.
n_excluded = (((df.PURPOSE != 1) & (df.PURPOSE != 3)) | (df.CHOICE == 0)).sum()
print(f"Removing {n_excluded} rows from the database based on the exclusion criteria.")
database.remove(exclude)

# Train and Swissmetro costs are set to zero when GA == 1 (pass holders).
SM_COST = database.define_variable('SM_COST', SM_CO * (GA == 0))
TRAIN_COST = database.define_variable('TRAIN_COST', TRAIN_CO * (GA == 0))
# Car and train availability are switched off for rows with SP == 0.
CAR_AV_SP = database.define_variable('CAR_AV_SP', CAR_AV * (SP != 0))
TRAIN_AV_SP = database.define_variable('TRAIN_AV_SP', TRAIN_AV * (SP != 0))
# Travel times and costs divided by 100; these scaled columns enter the utilities.
TRAIN_TT_SCALED = database.define_variable('TRAIN_TT_SCALED', TRAIN_TT / 100)
TRAIN_COST_SCALED = database.define_variable('TRAIN_COST_SCALED', TRAIN_COST / 100)
SM_TT_SCALED = database.define_variable('SM_TT_SCALED', SM_TT / 100)
SM_COST_SCALED = database.define_variable('SM_COST_SCALED', SM_COST / 100)
CAR_TT_SCALED = database.define_variable('CAR_TT_SCALED', CAR_TT / 100)
# Car cost is scaled directly from CAR_CO (no GA adjustment), hence the different name.
CAR_CO_SCALED = database.define_variable('CAR_CO_SCALED', CAR_CO / 100)

Removing 3960 rows from the database based on the exclusion criteria.


In [9]:
# Declare the database as panel data, with the 'ID' column identifying the panel unit
# (presumably one respondent per ID, each with several choice rows — confirm in Biogeme docs).
database.panel('ID')

In [10]:
# Random cost coefficient: estimated mean plus estimated scale times a normal draw.
B_COST = Beta('B_COST', 0.1, None, None, 0)
B_COST_S = Beta('B_COST_S', 0.75, None, None, 0)
B_COST_RND = B_COST + B_COST_S * bioDraws('b_cost_rnd', 'NORMAL_MLHS_ANTI')

# Random travel-time coefficient, built the same way from its own set of draws.
B_TIME = Beta('B_TIME', 0.1, None, None, 0)
B_TIME_S = Beta('B_TIME_S', 0.75, None, None, 0)
B_TIME_RND = B_TIME + B_TIME_S * bioDraws('b_time_rnd', 'NORMAL_MLHS_ANTI')

# Alternative-specific constants; the train constant is held fixed at 0 (last argument 1)
# and serves as the reference alternative.
ASC_CAR = Beta('ASC_CAR', 0.1, None, None, 0)
ASC_TRAIN = Beta('ASC_TRAIN', 0, None, None, 1)
ASC_SM = Beta('ASC_SM', 0.1, None, None, 0)


def _utility(asc, travel_time, cost):
    """Utility of one alternative: constant + random time term + random cost term."""
    return asc + B_TIME_RND * travel_time + B_COST_RND * cost


V1 = _utility(ASC_TRAIN, TRAIN_TT_SCALED, TRAIN_COST_SCALED)
V2 = _utility(ASC_SM, SM_TT_SCALED, SM_COST_SCALED)
V3 = _utility(ASC_CAR, CAR_TT_SCALED, CAR_CO_SCALED)

# Alternative codes follow the CHOICE column: 1 = train, 2 = Swissmetro, 3 = car.
V = {1: V1, 2: V2, 3: V3}
av = {1: TRAIN_AV_SP, 2: SM_AV, 3: CAR_AV_SP}

# Logit probability of the chosen alternative, combined over each panel's trajectory,
# integrated over the draws by Monte Carlo, then logged.
prob = models.logit(V, av, CHOICE)
logprob = log(MonteCarlo(PanelLikelihoodTrajectory(prob)))

# Fixed seed and draw count so that the simulated likelihood is reproducible.
the_biogeme = bio.BIOGEME(database, logprob, number_of_draws=1000, seed=999)
the_biogeme.modelName = 'test'
# No pickle or HTML report files are wanted from this notebook.
the_biogeme.generate_pickle = False
the_biogeme.generate_html = False

Default values of the Biogeme parameters are used. 
File biogeme.toml has been created 


In [11]:
# Log-likelihood at the starting values, as a sanity check before estimation.
the_biogeme.calculate_init_likelihood()

-5858.285806420752

In [12]:
# Run the estimation; the returned object holds the fitted model.
results = the_biogeme.estimate()
# Parameter table, kept in a variable so that it can be displayed in its own cell.
pandas_results = results.get_estimated_parameters()

# Headline statistics (parameter count, sample size, final log likelihood, AIC/BIC).
print(results.short_summary())

As the model is rather complex, we cancel the calculation of second derivatives. If you want to control the parameters, change the name of the algorithm in the TOML file from "automatic" to "simple_bounds" 
*** Initial values of the parameters are obtained from the file __test.iter 
Cannot read file __test.iter. Statement is ignored. 
The number of draws (1000) is low. The results may not be meaningful. 
As the model is rather complex, we cancel the calculation of second derivatives. If you want to control the parameters, change the name of the algorithm in the TOML file from "automatic" to "simple_bounds" 
Optimization algorithm: hybrid Newton/BFGS with simple bounds [simple_bounds] 
** Optimization: BFGS with trust region for simple bounds 
Iter.         ASC_CAR          ASC_SM          B_COST        B_COST_S          B_TIME        B_TIME_S     Function    Relgrad   Radius      Rho      
    0             1.1             1.1            -0.9             1.8            -0.9            

Results for model test
Nbr of parameters:		6
Sample size:			752
Observations:			6768
Excluded data:			3960
Final log likelihood:		-3925.737
Akaike Information Criterion:	7863.473
Bayesian Information Criterion:	7891.21



In [13]:
# Biogeme estimates with robust standard errors, t-tests and p-values.
pandas_results

Unnamed: 0,Value,Rob. Std err,Rob. t-test,Rob. p-value
ASC_CAR,0.727231,0.142007,5.121091,3.037727e-07
ASC_SM,0.346918,0.15398,2.253004,0.0242589
B_COST,-3.951661,0.269035,-14.688254,0.0
B_COST_S,4.778901,0.30768,15.532025,0.0
B_TIME,-4.708162,0.280651,-16.775846,0.0
B_TIME_S,4.40448,0.265559,16.58567,0.0


## jaxlogit

In [19]:
import numpy as np

import jax

from jaxlogit.mixed_logit import MixedLogit
from jaxlogit.utils import wide_to_long

# Enable 64-bit (double) precision in JAX before any arrays are created;
# JAX otherwise works in 32-bit precision.
jax.config.update("jax_enable_x64", True)

In [20]:
# Start from the filtered Biogeme data so that both packages see the same observations.
df_wide = database.data.copy()

df_wide['custom_id'] = np.arange(len(df_wide))  # Add unique identifier

# Recode the numeric choice into the alternative labels used in the long format.
# Series.map silently turns any code missing from the mapping into NaN, so check first.
choice_labels = {1: 'TRAIN', 2: 'SM', 3: 'CAR'}
assert df_wide['CHOICE'].isin(list(choice_labels)).all(), "unexpected CHOICE code in the filtered data"
df_wide['CHOICE'] = df_wide['CHOICE'].map(choice_labels)

# Same availability rule as TRAIN_AV_SP / CAR_AV_SP in the Biogeme model:
# train and car are unavailable for rows with SP == 0.
df_wide['TRAIN_AV'] = df_wide['TRAIN_AV'] * (df_wide['SP'] != 0)
df_wide['CAR_AV'] = df_wide['CAR_AV'] * (df_wide['SP'] != 0)

# Reshape to one row per (observation, alternative). Columns are named <ALT>_<VAR>;
# variables that an alternative lacks in the wide data are filled with empty_val.
df_jxl = wide_to_long(
    df_wide, id_col='custom_id', alt_name='alt', sep='_',
    alt_list=['TRAIN', 'SM', 'CAR'], empty_val=0,
    varying=['TT', 'CO', 'HE', 'AV', 'SEATS'], alt_is_prefix=True
)

# Alternative-specific constants as 0/1 dummy columns.
df_jxl['ASC_TRAIN'] = np.where(df_jxl['alt'] == 'TRAIN', 1, 0)
df_jxl['ASC_CAR'] = np.where(df_jxl['alt'] == 'CAR', 1, 0)
df_jxl['ASC_SM'] = np.where(df_jxl['alt'] == 'SM', 1, 0)

# Same /100 scaling as the *_SCALED variables on the Biogeme side.
df_jxl['TT'] = df_jxl['TT'] / 100.0
df_jxl['CO'] = df_jxl['CO'] / 100.0

df_jxl.loc[(df_jxl['GA'] == 1) & (df_jxl['alt'].isin(['TRAIN', 'SM'])), 'CO'] = 0  # Cost zero for pass holders

In [22]:
# Explanatory variables, in the order in which they are passed to the model.
varnames = ['ASC_SM', 'ASC_CAR', 'ASC_TRAIN', 'TT', 'CO']

# Random coefficients for cost and travel time ('n' presumably selects a normal
# distribution — confirm against the jaxlogit docs).
randvars = {'CO': 'n', 'TT': 'n'}

# Train is the reference alternative: its constant is held at zero, as in the Biogeme model.
fixedvars = {'ASC_TRAIN': 0.0}

# Toggle between panel estimation (grouped by respondent ID) and treating rows independently.
do_panel = True
panel_ids = df_jxl["ID"] if do_panel else None

model = MixedLogit()
res = model.fit(
    X=df_jxl[varnames],
    y=df_jxl['CHOICE'],
    varnames=varnames,
    alts=df_jxl['alt'],
    ids=df_jxl['custom_id'],
    avail=df_jxl['AV'],
    panels=panel_ids,
    randvars=randvars,
    n_draws=1000,  # same number of draws as the Biogeme run
    fixedvars=fixedvars,
    init_coeff=None,
    include_correlations=False,
    optim_method='trust-region',
    skip_std_errs=False,
    force_positive_chol_diag=False,  # not using softplus for std devs here for comparability with biogeme
)
model.summary()

2025-07-16 18:21:31,407 INFO jaxlogit.mixed_logit: Starting data preparation, including generation of 1000 random draws for each random variable and observation.


INFO:2025-07-16 18:21:31,478:jax._src.xla_bridge:752: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-07-16 18:21:31,478 INFO jax._src.xla_bridge: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-07-16 18:21:33,634 INFO jaxlogit.mixed_logit: Data contains 752 panels, using segment_sum for panel-wise log-likelihood.
2025-07-16 18:21:33,635 INFO jaxlogit.mixed_logit: Shape of draws: (6768, 2, 1000), number of draws: 1000
2025-07-16 18:21:33,636 INFO jaxlogit.mixed_logit: Shape of Xdf: (6768, 2, 3), shape of Xdr: (6768, 2, 2)
2025-07-16 18:21:33,638 INFO jaxlogit.mixed_logit: Compiling log-likelihood function.
2025-07-16 18:21:33,955 INFO jaxlogit.mixed_logit: Compilation finished, init neg_loglike = 6971.97, params= [(np.str_('ASC_SM'), Array(0.1, dtype=float64)), (np.str_('ASC_CAR'), Array(0.1, 

Loss on this step: 6971.967859038894, Loss on the last accepted step: 0.0, Step size: 1.0
Loss on this step: 395243.9234205277, Loss on the last accepted step: 6971.967859038894, Step size: 0.25
Loss on this step: 344432.87536128884, Loss on the last accepted step: 6971.967859038894, Step size: 0.0625
Loss on this step: 165141.12496215352, Loss on the last accepted step: 6971.967859038894, Step size: 0.015625
Loss on this step: 44265.42190838311, Loss on the last accepted step: 6971.967859038894, Step size: 0.00390625
Loss on this step: 13687.1931113031, Loss on the last accepted step: 6971.967859038894, Step size: 0.0009765625
Loss on this step: 6175.561013144365, Loss on the last accepted step: 6971.967859038894, Step size: 0.0009765625
Loss on this step: 4458.54568314532, Loss on the last accepted step: 6175.561013144365, Step size: 0.0009765625
Loss on this step: 4358.030094798618, Loss on the last accepted step: 4458.54568314532, Step size: 0.0009765625
Loss on this step: 4325.022

2025-07-16 18:21:45,690 INFO jaxlogit.mixed_logit: Optimization finished, success = True, final loglike = -3920.75, final gradient max = 3.65e-06, norm = 7.47e-06.
2025-07-16 18:21:45,691 INFO jaxlogit.mixed_logit: Calculating gradient of individual log-likelihood contributions
2025-07-16 18:21:48,755 INFO jaxlogit.mixed_logit: Calculating H_inv
2025-07-16 18:21:53,240 INFO jaxlogit._choice_model: Post fit processing
2025-07-16 18:21:53,777 INFO jaxlogit._choice_model: Optimization terminated successfully.


    Message: 
    Iterations: 50
    Function evaluations: 63
Estimation time= 22.3 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
ASC_SM                  0.3956112     0.1505933     2.6270168       0.00863 ** 
ASC_CAR                 0.7560334     0.1404524     5.3828439      7.58e-08 ***
ASC_TRAIN               0.0000000     0.0000000           nan           nan    
TT                     -4.6348978     0.2555659   -18.1358243      8.17e-72 ***
CO                     -4.1309352     0.5118050    -8.0713066      8.17e-16 ***
sd.TT                   4.3913653     0.2538132    17.3015621      1.17e-65 ***
sd.CO                   4.6962315     0.6416456     7.3190432      2.79e-13 ***
---------------------------------------------------------------------------
Significance:  0 '***' 0.001 '**' 0.01 '*' 0