# Mixed Logit with correlations

This example uses the Swissmetro data and compares the results to Biogeme (Bierlaire, M. (2018). PandasBiogeme: a short introduction. EPFL (Transport and Mobility Laboratory, ENAC)).


In [1]:
import pandas as pd
import numpy as np
import jax

from jaxlogit.mixed_logit import MixedLogit

In [2]:
# Enable 64-bit precision (JAX defaults to 32-bit floats)
jax.config.update("jax_enable_x64", True)

## Swissmetro Dataset


The Swissmetro dataset contains stated preferences for three alternative transportation modes: car, train, and a newly introduced mode, the Swissmetro. This dataset is commonly used for estimation examples with the `Biogeme` and `PyLogit` packages. The dataset is available at http://transp-or.epfl.ch/data/swissmetro.dat and [Bierlaire et al. (2001)](https://transp-or.epfl.ch/documents/proceedings/BierAxhaAbay01.pdf) provides a detailed discussion of the data as well as its context and collection process. The explanatory variables in this example include the travel time (`TT`) and cost (`CO`) for each of the three alternative modes.

### Read data

The dataset is imported into the Python environment using `pandas`. Two types of records are then filtered out: those with a trip purpose other than commute or business, and those with an unknown choice. The original dataset contains 10,729 records; after filtering, 6,768 records remain for the analysis. Finally, a new column that uniquely identifies each record is added to the dataframe, and the `CHOICE` column, which originally contains a numerical coding of the choices, is mapped to labels that are consistent with the alternative names used in the column names.

In [3]:
df_wide = pd.read_table("http://transp-or.epfl.ch/data/swissmetro.dat", sep='\t')

# Keep only observations for commute and business purposes that contain known choices
df_wide = df_wide[df_wide['PURPOSE'].isin([1, 3]) & (df_wide['CHOICE'] != 0)].copy()

df_wide['custom_id'] = np.arange(len(df_wide))  # Add unique identifier
df_wide['CHOICE'] = df_wide['CHOICE'].map({1: 'TRAIN', 2:'SM', 3: 'CAR'})
df_wide

Unnamed: 0,GROUP,SURVEY,SP,ID,PURPOSE,FIRST,TICKET,WHO,LUGGAGE,AGE,...,TRAIN_CO,TRAIN_HE,SM_TT,SM_CO,SM_HE,SM_SEATS,CAR_TT,CAR_CO,CHOICE,custom_id
0,2,0,1,1,1,0,1,1,0,3,...,48,120,63,52,20,0,117,65,SM,0
1,2,0,1,1,1,0,1,1,0,3,...,48,30,60,49,10,0,117,84,SM,1
2,2,0,1,1,1,0,1,1,0,3,...,48,60,67,58,30,0,117,52,SM,2
3,2,0,1,1,1,0,1,1,0,3,...,40,30,63,52,20,0,72,52,SM,3
4,2,0,1,1,1,0,1,1,0,3,...,36,60,63,42,20,0,90,84,SM,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8446,3,1,1,939,3,1,7,3,1,5,...,13,30,50,17,30,0,130,64,TRAIN,6763
8447,3,1,1,939,3,1,7,3,1,5,...,12,30,53,16,10,0,80,80,TRAIN,6764
8448,3,1,1,939,3,1,7,3,1,5,...,16,60,50,16,20,0,80,64,TRAIN,6765
8449,3,1,1,939,3,1,7,3,1,5,...,16,30,53,17,30,0,80,104,TRAIN,6766


### Reshape data

The imported dataframe is in wide format and needs to be reshaped to long format before estimation. The `wide_to_long` utility, imported here from `jaxlogit.utils`, handles this reshaping. The user needs to specify the column that uniquely identifies each sample, the names of the alternatives, the columns that vary across alternatives, and whether the alternative names are a prefix or suffix of the column names. Additionally, the user can specify a value (`empty_val`) to be used by default when a variable does not exist for a certain alternative. Additional usage examples for the `wide_to_long` function in `xlogit` are available in xlogit's documentation at https://xlogit.readthedocs.io/en/latest/notebooks/convert_data_wide_to_long.html. Details about the function parameters are available in xlogit's [API reference](https://xlogit.readthedocs.io/en/latest/api/utils.html#xlogit.utils.wide_to_long).

In [4]:
from jaxlogit.utils import wide_to_long

df = wide_to_long(df_wide, id_col='custom_id', alt_name='alt', sep='_',
                  alt_list=['TRAIN', 'SM', 'CAR'], empty_val=0,
                  varying=['TT', 'CO', 'HE', 'AV', 'SEATS'], alt_is_prefix=True)
df

Unnamed: 0,custom_id,alt,TT,CO,HE,AV,SEATS,GROUP,SURVEY,SP,...,TICKET,WHO,LUGGAGE,AGE,MALE,INCOME,GA,ORIGIN,DEST,CHOICE
0,0,TRAIN,112,48,120,1,0,2,0,1,...,1,1,0,3,0,2,0,2,1,SM
1,0,SM,63,52,20,1,0,2,0,1,...,1,1,0,3,0,2,0,2,1,SM
2,0,CAR,117,65,0,1,0,2,0,1,...,1,1,0,3,0,2,0,2,1,SM
3,1,TRAIN,103,48,30,1,0,2,0,1,...,1,1,0,3,0,2,0,2,1,SM
4,1,SM,60,49,10,1,0,2,0,1,...,1,1,0,3,0,2,0,2,1,SM
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
20299,6766,SM,53,17,30,1,0,3,1,1,...,7,3,1,5,1,2,0,1,2,TRAIN
20300,6766,CAR,80,104,0,1,0,3,1,1,...,7,3,1,5,1,2,0,1,2,TRAIN
20301,6767,TRAIN,108,13,60,1,0,3,1,1,...,7,3,1,5,1,2,0,1,2,TRAIN
20302,6767,SM,53,21,30,1,0,3,1,1,...,7,3,1,5,1,2,0,1,2,TRAIN
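
To make the reshaping concrete, the following is a minimal sketch of the same idea in plain `pandas` on a tiny made-up frame. It is not the `wide_to_long` implementation; the toy values and the `empty_val` variable are assumptions chosen to mirror the call above, including a variable (`HE`) that has no column for `CAR`.

```python
import pandas as pd

# Toy wide-format data (made up): one row per choice situation
wide = pd.DataFrame({
    "custom_id": [0, 1],
    "TRAIN_TT": [112, 103], "SM_TT": [63, 60], "CAR_TT": [117, 117],
    "TRAIN_HE": [120, 30], "SM_HE": [20, 10],   # note: there is no CAR_HE column
    "CHOICE": ["SM", "SM"],
})

alts = ["TRAIN", "SM", "CAR"]
varying = ["TT", "HE"]
empty_val = 0  # used when a variable does not exist for an alternative

# Build one block of rows per alternative, renaming e.g. TRAIN_TT -> TT
blocks = []
for alt in alts:
    block = wide[["custom_id", "CHOICE"]].copy()
    block["alt"] = alt
    for var in varying:
        col = f"{alt}_{var}"
        block[var] = wide[col] if col in wide.columns else empty_val
    blocks.append(block)

# Stack the blocks, then order rows by choice situation and alternative
long = pd.concat(blocks, ignore_index=True)
long["alt"] = pd.Categorical(long["alt"], categories=alts, ordered=True)
long = long.sort_values(["custom_id", "alt"]).reset_index(drop=True)
long = long[["custom_id", "alt", "TT", "HE", "CHOICE"]]
print(long)
```

Each choice situation now occupies three consecutive rows, one per alternative, which is the layout shown in the output above.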


### Create model specification

Following the reshaping, users can create or update the dataset's columns to accommodate their model specification needs. The code below creates the columns `ASC_TRAIN` and `ASC_CAR` to incorporate alternative-specific constants in the model. It also illustrates an effective way of establishing variable interactions (e.g., trip costs for commuters with an annual pass) by updating existing columns conditional on the values of other columns. Although apparently simple, column operations give users an intuitive and highly flexible mechanism for incorporating model specification aspects such as variable transformations, interactions, and alternative-specific coefficients and constants. [This specification example](https://xlogit.readthedocs.io/en/latest/notebooks/multinomial_model.html#Create-model-specification) in xlogit's documentation shows how highly flexible utility specifications can be built by operating on the dataframe columns.

In [5]:
df['ASC_TRAIN'] = np.ones(len(df))*(df['alt'] == 'TRAIN')
df['ASC_CAR'] = np.ones(len(df))*(df['alt'] == 'CAR')
df['TT'], df['CO'] = df['TT']/100, df['CO']/100  # Scale variables
annual_pass = (df['GA'] == 1) & (df['alt'].isin(['TRAIN', 'SM']))
df.loc[annual_pass, 'CO'] = 0  # Cost zero for pass holders
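
The same column-operation pattern extends to alternative-specific coefficients. The sketch below, on a made-up long-format frame, splits travel time into one column per alternative so that each could receive its own coefficient. The column names `TT_TRAIN`, `TT_SM`, and `TT_CAR` are hypothetical and are not used elsewhere in this example.

```python
import pandas as pd

# Toy long-format data (made up): one choice situation, three alternatives
long = pd.DataFrame({
    "custom_id": [0, 0, 0],
    "alt": ["TRAIN", "SM", "CAR"],
    "TT": [1.12, 0.63, 1.17],
})

# One travel-time column per alternative: the value is kept on rows of
# that alternative and set to zero everywhere else
for alt in ["TRAIN", "SM", "CAR"]:
    long[f"TT_{alt}"] = long["TT"] * (long["alt"] == alt)

print(long)
```

Passing these columns in `varnames` instead of `TT` would, under this assumption, estimate a separate travel-time coefficient for each mode.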

### Estimate model parameters

The `fit` method estimates the model by taking as input the data from the previous step along with additional specification criteria, such as the distribution of the random parameters (`randvars`), the number of random draws (`n_draws`), the availability of alternatives for the choice situations (`avail`), and the column that groups the repeated choice situations of each individual (`panels`). As the log output below shows, the minimization runs with `L-BFGS-B`, a robust routine that usually helps with convergence issues. Once the estimation routine is completed, the `summary` method can be used to display the estimation results.
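
To illustrate what these inputs do, here is a minimal NumPy sketch of a simulated log-likelihood for a model with a normally distributed `TT` coefficient. It is a generic textbook-style illustration on made-up data, not jaxlogit's implementation; the function `simulated_loglik` and all toy arrays are assumptions introduced for this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data (made up): 4 choice situations, 3 alternatives
n_obs, n_alts, n_draws = 4, 3, 500
CO = rng.uniform(0.1, 1.5, size=(n_obs, n_alts))
TT = rng.uniform(0.1, 1.5, size=(n_obs, n_alts))
avail = np.ones((n_obs, n_alts))
avail[0, 2] = 0                    # third alternative unavailable in situation 0
chosen = np.array([0, 1, 1, 2])    # index of the chosen alternative
panel = np.array([0, 0, 1, 1])     # two individuals, two situations each

# One vector of standard-normal draws per individual, reused across
# all of that individual's choice situations
draws = rng.standard_normal((panel.max() + 1, n_draws))


def simulated_loglik(beta_co, mean_tt, sd_tt):
    # Random TT coefficient per situation and draw: shape (n_obs, n_draws)
    beta_tt = mean_tt + sd_tt * draws[panel]
    # Utilities: shape (n_obs, n_alts, n_draws)
    V = beta_co * CO[:, :, None] + beta_tt[:, None, :] * TT[:, :, None]
    # Unavailable alternatives get zero weight in the logit formula
    expV = np.exp(V) * avail[:, :, None]
    probs = expV / expV.sum(axis=1, keepdims=True)
    p_chosen = probs[np.arange(n_obs), chosen, :]      # (n_obs, n_draws)
    loglik = 0.0
    for p in np.unique(panel):
        # Probability of the individual's whole choice sequence, per draw
        seq_prob = p_chosen[panel == p].prod(axis=0)
        # Average over draws approximates the integral over the distribution
        loglik += np.log(seq_prob.mean())
    return loglik


print(simulated_loglik(beta_co=-1.7, mean_tt=-3.2, sd_tt=3.6))
```

More draws reduce the simulation noise in the average at the cost of memory and compute, which is the trade-off behind the `n_draws` setting.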

In [6]:
varnames=['ASC_CAR', 'ASC_TRAIN', 'CO', 'TT']
model = MixedLogit()
res = model.fit(
    X=df[varnames],
    y=df['CHOICE'],
    varnames=varnames,
    alts=df['alt'],
    ids=df['custom_id'],
    avail=df['AV'],
    panels=df["ID"],
    randvars={'TT': 'n'},
    n_draws=1500,
)
model.summary()

2025-07-04 17:58:03,363 INFO jaxlogit.mixed_logit: Starting data preparation, including generation of 1500 random draws for each random variable and observation.
INFO:2025-07-04 17:58:03,448:jax._src.xla_bridge:752: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-07-04 17:58:03,448 INFO jax._src.xla_bridge: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-07-04 17:58:05,537 INFO jaxlogit.mixed_logit: Compiling log-likelihood function.
2025-07-04 17:58:06,113 INFO jaxlogit.mixed_logit: Compilation finished, init neg_loglike = 7309.54
2025-07-04 17:58:06,115 INFO jaxlogit._optimize: Running minimization with method L-BFGS-B
2025-07-04 17:58:06,851 INFO jaxlogit._optimize: Iter 1, fun = 5509.545, |grad| = 1054.174
2025-07-04 17:58:07,320 INFO jaxlogit._optimize: Iter 2, fun = 4969.103, |grad| = 57

    Message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH
    Iterations: 16
    Function evaluations: 17
Estimation time= 18.5 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
ASC_CAR                 0.2831119     0.0560481     5.0512331       4.5e-07 ***
ASC_TRAIN              -0.5722801     0.0794782    -7.2004683      6.65e-13 ***
CO                     -1.6601727     0.0778872   -21.3151023      1.26e-97 ***
TT                     -3.2289922     0.1749822   -18.4532610      3.17e-74 ***
sd.TT                   3.6485578     0.1683478    21.6727368     9.28e-101 ***
---------------------------------------------------------------------------
Significance:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Log-Likelihood= -4359.218
AIC= 8728.436
BIC= 8762.536
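
The summary figures can be cross-checked by hand. The sketch below assumes the standard definitions AIC = 2k − 2·LL and BIC = k·ln(n) − 2·LL, with k = 5 estimated coefficients and n = 6,768 choice situations; with these assumptions the printed values are reproduced to the displayed precision.

```python
import math

loglik = -4359.218   # Log-Likelihood from the summary above
k = 5                # estimated coefficients: ASC_CAR, ASC_TRAIN, CO, TT, sd.TT
n = 6768             # choice situations remaining after filtering

# z-value of ASC_CAR: estimate divided by its standard error
z_asc_car = 0.2831119 / 0.0560481

aic = 2 * k - 2 * loglik
bic = k * math.log(n) - 2 * loglik

print(f"z(ASC_CAR) = {z_asc_car:.4f}")
print(f"AIC = {aic:.3f}")
print(f"BIC = {bic:.3f}")
```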


## Example of fixing parameters

In [7]:
# ASC_SM was left out before; add it here so that its coefficient can be fixed to 0
df['ASC_SM'] = np.ones(len(df))*(df['alt'] == 'SM')
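
A constant for every alternative cannot be estimated freely, because logit probabilities depend only on utility differences. The sketch below is a plain-Python illustration, not jaxlogit code: it uses only the constants (the `ASC_CAR` and `ASC_TRAIN` estimates from the summary above, with `ASC_SM` at 0) and ignores cost and travel time, then shows that adding the same arbitrary shift to all three leaves the probabilities unchanged.

```python
import math


def logit_probs(utilities):
    # Subtract the maximum before exponentiating for numerical stability
    m = max(utilities)
    exps = [math.exp(u - m) for u in utilities]
    total = sum(exps)
    return [e / total for e in exps]


# Constants only; other utility terms are ignored for this illustration
asc = {"CAR": 0.2831119, "TRAIN": -0.5722801, "SM": 0.0}

base = logit_probs(list(asc.values()))
shifted = logit_probs([v + 1.7 for v in asc.values()])  # arbitrary common shift

print(base)
print(shifted)
```

Since any common shift fits the data equally well, one constant has to be pinned to a value, which is what fixing `ASC_SM` to 0 does.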

In [8]:
varnames=['ASC_CAR', 'ASC_TRAIN', 'ASC_SM', 'CO', 'TT']
fixedvars = {'ASC_SM': 0.0}  # Fixing parameters
model = MixedLogit()
res = model.fit(
    X=df[varnames],
    y=df['CHOICE'],
    varnames=varnames,
    alts=df['alt'],
    ids=df['custom_id'],
    avail=df['AV'],
    panels=df["ID"],
    randvars={'TT': 'n'},
    n_draws=1500,
    fixedvars=fixedvars
)
model.summary()

2025-07-04 17:58:55,216 INFO jaxlogit.mixed_logit: Starting data preparation, including generation of 1500 random draws for each random variable and observation.
2025-07-04 17:58:56,375 INFO jaxlogit.mixed_logit: Compiling log-likelihood function.
2025-07-04 17:58:56,875 INFO jaxlogit.mixed_logit: Compilation finished, init neg_loglike = 7309.54
2025-07-04 17:58:56,878 INFO jaxlogit._optimize: Running minimization with method L-BFGS-B
2025-07-04 17:58:57,677 INFO jaxlogit._optimize: Iter 1, fun = 5509.545, |grad| = 1054.174
2025-07-04 17:58:58,140 INFO jaxlogit._optimize: Iter 2, fun = 4969.103, |grad| = 571.378
2025-07-04 17:58:58,609 INFO jaxlogit._optimize: Iter 3, fun = 4611.336, |grad| = 275.850
2025-07-04 17:58:59,065 INFO jaxlogit._optimize: Iter 4, fun = 4470.256, |grad| = 221.282
2025-07-04 17:58:59,537 INFO jaxlogit._optimize: Iter 5, fun = 4395.300, |grad| = 63.354
2025-07-04 17:59:00,010 INFO jaxlogit._optimize: Iter 6, fun = 4380.971, |grad| = 54.077
2025-07-04 17:59:00,45

    Message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH
    Iterations: 16
    Function evaluations: 17
Estimation time= 19.6 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
ASC_CAR                 0.2831119     0.0560481     5.0512331       4.5e-07 ***
ASC_TRAIN              -0.5722801     0.0794782    -7.2004683      6.65e-13 ***
ASC_SM                  0.0000000     0.0000000           nan           nan    
CO                     -1.6601727     0.0778872   -21.3151023      1.26e-97 ***
TT                     -3.2289922     0.1749822   -18.4532610      3.17e-74 ***
sd.TT                   3.6485578     0.1683478    21.6727368     9.28e-101 ***
---------------------------------------------------------------------------
Significance:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Log-Likel

## Now let's add correlations