# Mixed Logit with correlations

Using swissmetro data, comparing results to biogeme (Bierlaire, M. (2018). PandasBiogeme: a short introduction. EPFL (Transport and Mobility Laboratory, ENAC))


In [1]:
import pandas as pd
import numpy as np
import jax

from jaxlogit.mixed_logit import MixedLogit
from jaxlogit.utils import wide_to_long

In [2]:
#  64bit precision
jax.config.update("jax_enable_x64", True)

## Swissmetro Dataset

In [3]:
df_wide = pd.read_table("http://transp-or.epfl.ch/data/swissmetro.dat", sep='\t')

# Keep only observations for commute and business purposes that contain known choices
df_wide = df_wide[(df_wide['PURPOSE'].isin([1, 3]) & (df_wide['CHOICE'] != 0))]

df_wide['custom_id'] = np.arange(len(df_wide))  # Add unique identifier
df_wide['CHOICE'] = df_wide['CHOICE'].map({1: 'TRAIN', 2:'SM', 3: 'CAR'})

# biogeme data prep
#exclude = ((PURPOSE != 1) * (PURPOSE != 3) + (CHOICE == 0)) > 0
exclude = ((df_wide['PURPOSE'] != 1) * (df_wide['PURPOSE'] != 3) + (df_wide['CHOICE'] == 0)) > 0
print(f"Excluding {exclude.sum()} observations from {len(df_wide)} total observations.")
df_wide = df_wide[~exclude]

Excluding 0 observations from 6768 total observations.


In [4]:
#TRAIN_AV_SP = database.define_variable('TRAIN_AV_SP', TRAIN_AV * (SP != 0))
df_wide['TRAIN_AV'] = df_wide['TRAIN_AV'] * (df_wide['SP'] != 0)
# CAR_AV_SP = database.define_variable('CAR_AV_SP', CAR_AV * (SP != 0))
df_wide['CAR_AV'] = df_wide['CAR_AV'] * (df_wide['SP'] != 0)

In [5]:
df = wide_to_long(df_wide, id_col='custom_id', alt_name='alt', sep='_',
                  alt_list=['TRAIN', 'SM', 'CAR'], empty_val=0,
                  varying=['TT', 'CO', 'HE', 'AV', 'SEATS'], alt_is_prefix=True)

In [6]:
df['ASC_TRAIN'] = np.where(df['alt'] == 'TRAIN', 1, 0)
df['ASC_CAR'] = np.where(df['alt'] == 'CAR', 1, 0)
df['ASC_SM'] = np.where(df['alt'] == 'SM', 1, 0)

df['TT'] = df['TT'] / 100.0
df['CO'] = df['CO'] / 100.0

annual_pass = (df['GA'] == 1) & (df['alt'].isin(['TRAIN', 'SM']))
df.loc[annual_pass, 'CO'] = 0  # Cost zero for pass holders

## Now above before long to wide
#CAR_AV_SP = database.define_variable('CAR_AV_SP', CAR_AV * (SP != 0))
#TRAIN_AV_SP = database.define_variable('TRAIN_AV_SP', TRAIN_AV * (SP != 0))
#df.loc[(df['SP'] == 0) & (df['alt'].isin(['CAR', 'TRAIN'])), 'AV'] = 0

In [7]:
df.custom_id.unique().max() + 1

np.int64(6768)

In [8]:
# model.coeff_names
# array(['ASC_CAR', 'ASC_TRAIN', 'ASC_SM', 'CO', 'TT', 'sd.ASC_CAR', 'sd.ASC_TRAIN', 'sd.ASC_SM'], dtype='<U12')
init_vals = [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0]

In [10]:
varnames=['ASC_CAR', 'ASC_TRAIN', 'ASC_SM', 'CO', 'TT']

randvars={'ASC_CAR': 'n', 'ASC_TRAIN': 'n', 'ASC_SM': 'n'}
fixedvars = {'ASC_SM': 0.0, 'sd.ASC_TRAIN': 0.0}  # Mean and variance of SM is fixed to 0

do_panel = False

model = MixedLogit()
res = model.fit(
    X=df[varnames],
    y=df['CHOICE'],
    varnames=varnames,
    alts=df['alt'],
    ids=df['custom_id'],
    avail=df['AV'],
    panels=None if do_panel is False else df["ID"],
    randvars=randvars,
    n_draws=2000,
    fixedvars=fixedvars,
    init_coeff=init_vals,
    include_correlations=False,  # Enable correlation between random parameters
)
model.summary()

2025-07-07 16:40:46,247 INFO jaxlogit.mixed_logit: Starting data preparation, including generation of 2000 random draws for each random variable and observation.


2025-07-07 16:40:55,088 INFO jaxlogit.mixed_logit: Compiling log-likelihood function.
2025-07-07 16:40:55,606 INFO jaxlogit.mixed_logit: Compilation finished, init neg_loglike = 6816.59
2025-07-07 16:40:55,608 INFO jaxlogit._optimize: Running minimization with method L-BFGS-B
2025-07-07 16:40:57,770 INFO jaxlogit._optimize: Iter 1, fun = 5656.014, |grad| = 575.180
2025-07-07 16:40:58,994 INFO jaxlogit._optimize: Iter 2, fun = 5464.520, |grad| = 303.995
2025-07-07 16:41:00,138 INFO jaxlogit._optimize: Iter 3, fun = 5346.907, |grad| = 171.981
2025-07-07 16:41:01,489 INFO jaxlogit._optimize: Iter 4, fun = 5301.962, |grad| = 58.851
2025-07-07 16:41:02,703 INFO jaxlogit._optimize: Iter 5, fun = 5291.043, |grad| = 48.407
2025-07-07 16:41:04,149 INFO jaxlogit._optimize: Iter 6, fun = 5279.227, |grad| = 65.026
2025-07-07 16:41:05,471 INFO jaxlogit._optimize: Iter 7, fun = 5267.343, |grad| = 92.726
2025-07-07 16:41:06,782 INFO jaxlogit._optimize: Iter 8, fun = 5260.180, |grad| = 46.869
2025-07-

    Message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH
    Iterations: 19
    Function evaluations: 22
Estimation time= 63.9 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
ASC_CAR                -0.4959902     0.0905947    -5.4748277      4.54e-08 ***
ASC_TRAIN              -1.2494991     0.1094049   -11.4208742      6.18e-30 ***
ASC_SM                  0.0000000     0.0000000           nan           nan    
CO                     -1.7616644     0.1021400   -17.2475416      2.86e-65 ***
TT                     -1.7162396     0.0972969   -17.6392047      4.05e-68 ***
sd.ASC_CAR              0.0034497     0.2152555     0.0160261         0.987    
sd.ASC_TRAIN            0.0000000     0.0000000           nan           nan    
sd.ASC_SM               3.1942063     0.2543524    12.5581909      8

In [None]:
# install xlogit, compare. something has gone wrong here I think, why are results so different?

In [11]:
model.coeff_names

array(['ASC_CAR', 'ASC_TRAIN', 'ASC_SM', 'CO', 'TT', 'sd.ASC_CAR',
       'sd.ASC_TRAIN', 'sd.ASC_SM'], dtype='<U12')

In [15]:
import jax.numpy as jnp

tril_rows, tril_cols = jnp.tril_indices(8)
diag_mask = tril_rows == tril_cols
off_diag_mask = ~diag_mask

In [18]:
diag_vals = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
off_diag_vals = jnp.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6])

In [19]:
tril_vals = jnp.where(
    diag_mask,
    diag_vals[tril_rows],  # diagonal values
    off_diag_vals[jnp.cumsum(off_diag_mask) - 1]  # off-diagonal values
)

In [20]:
tril_vals

Array([1. , 0.1, 2. , 0.2, 0.3, 3. , 0.4, 0.5, 0.6, 4. , 0.7, 0.8, 0.9,
       1. , 5. , 1.1, 1.2, 1.3, 1.4, 1.5, 6. , 1.6, 1.6, 1.6, 1.6, 1.6,
       1.6, 7. , 1.6, 1.6, 1.6, 1.6, 1.6, 1.6, 1.6, 8. ], dtype=float64)

In [21]:
L = jnp.zeros((8, 8))
L = L.at[tril_rows, tril_cols].set(tril_vals)


In [22]:
L

Array([[1. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0.1, 2. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0.2, 0.3, 3. , 0. , 0. , 0. , 0. , 0. ],
       [0.4, 0.5, 0.6, 4. , 0. , 0. , 0. , 0. ],
       [0.7, 0.8, 0.9, 1. , 5. , 0. , 0. , 0. ],
       [1.1, 1.2, 1.3, 1.4, 1.5, 6. , 0. , 0. ],
       [1.6, 1.6, 1.6, 1.6, 1.6, 1.6, 7. , 0. ],
       [1.6, 1.6, 1.6, 1.6, 1.6, 1.6, 1.6, 8. ]], dtype=float64)