# Mixed Logit

In [1]:
import pandas as pd
import numpy as np
import jax

from jaxlogit.mixed_logit import MixedLogit

In [2]:
#  64bit precision
jax.config.update("jax_enable_x64", True)

## Swissmetro Dataset


The Swissmetro dataset contains stated preferences for three alternative transportation modes: car, train, and a newly introduced mode, the Swissmetro. This dataset is commonly used in estimation examples for the `Biogeme` and `PyLogit` packages. The dataset is available at http://transp-or.epfl.ch/data/swissmetro.dat, and [Bierlaire et al. (2001)](https://transp-or.epfl.ch/documents/proceedings/BierAxhaAbay01.pdf) provide a detailed discussion of the data as well as its context and collection process. The explanatory variables in this example include the travel time (`TT`) and cost (`CO`) of each of the three alternative modes.

### Read data

The dataset is imported into the Python environment using `pandas`. Then, two types of observations are filtered out: those whose trip purpose is neither commute nor business (only `PURPOSE` codes 1 and 3 are kept) and those with an unknown choice (`CHOICE == 0`). The original dataset contains 10,729 records; after filtering, 6,768 records remain for the following analysis. Finally, a new column that uniquely identifies each sample is added to the dataframe, and the `CHOICE` column, which originally contains a numerical coding of the choices, is mapped to labels that match the alternative names used in the column names (`TRAIN`, `SM`, `CAR`).

In [3]:
df_wide = pd.read_table("http://transp-or.epfl.ch/data/swissmetro.dat", sep='\t')

# Keep only observations for commute and business purposes that contain known choices
df_wide = df_wide[(df_wide['PURPOSE'].isin([1, 3]) & (df_wide['CHOICE'] != 0))]

df_wide['custom_id'] = np.arange(len(df_wide))  # Add unique identifier
df_wide['CHOICE'] = df_wide['CHOICE'].map({1: 'TRAIN', 2:'SM', 3: 'CAR'})
df_wide

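|  | GROUP | SURVEY | SP | ID | PURPOSE | FIRST | TICKET | WHO | LUGGAGE | AGE | ... | TRAIN_CO | TRAIN_HE | SM_TT | SM_CO | SM_HE | SM_SEATS | CAR_TT | CAR_CO | CHOICE | custom_id |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2 | 0 | 1 | 1 | 1 | 0 | 1 | 1 | 0 | 3 | ... | 48 | 120 | 63 | 52 | 20 | 0 | 117 | 65 | SM | 0 |
| 1 | 2 | 0 | 1 | 1 | 1 | 0 | 1 | 1 | 0 | 3 | ... | 48 | 30 | 60 | 49 | 10 | 0 | 117 | 84 | SM | 1 |
| 2 | 2 | 0 | 1 | 1 | 1 | 0 | 1 | 1 | 0 | 3 | ... | 48 | 60 | 67 | 58 | 30 | 0 | 117 | 52 | SM | 2 |
| 3 | 2 | 0 | 1 | 1 | 1 | 0 | 1 | 1 | 0 | 3 | ... | 40 | 30 | 63 | 52 | 20 | 0 | 72 | 52 | SM | 3 |
| 4 | 2 | 0 | 1 | 1 | 1 | 0 | 1 | 1 | 0 | 3 | ... | 36 | 60 | 63 | 42 | 20 | 0 | 90 | 84 | SM | 4 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 8446 | 3 | 1 | 1 | 939 | 3 | 1 | 7 | 3 | 1 | 5 | ... | 13 | 30 | 50 | 17 | 30 | 0 | 130 | 64 | TRAIN | 6763 |
| 8447 | 3 | 1 | 1 | 939 | 3 | 1 | 7 | 3 | 1 | 5 | ... | 12 | 30 | 53 | 16 | 10 | 0 | 80 | 80 | TRAIN | 6764 |
| 8448 | 3 | 1 | 1 | 939 | 3 | 1 | 7 | 3 | 1 | 5 | ... | 16 | 60 | 50 | 16 | 20 | 0 | 80 | 64 | TRAIN | 6765 |
| 8449 | 3 | 1 | 1 | 939 | 3 | 1 | 7 | 3 | 1 | 5 | ... | 16 | 30 | 53 | 17 | 30 | 0 | 80 | 104 | TRAIN | 6766 |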


### Reshape data

The imported dataframe is in wide format and needs to be reshaped to long format (one row per alternative in each choice situation) before estimation. The `wide_to_long` utility in `jaxlogit.utils` performs this reshaping; its parameters match those of the xlogit function of the same name. The user specifies the column that uniquely identifies each sample (`id_col`), the names of the alternatives (`alt_list`), the columns that vary across alternatives (`varying`), and whether the alternative names are a prefix or a suffix of the column names (`alt_is_prefix`). Optionally, `empty_val` sets the value used when a variable is not available for an alternative (for example, the data has no `CAR_HE` column). Additional usage examples for `wide_to_long` are available in xlogit's documentation at https://xlogit.readthedocs.io/en/latest/notebooks/convert_data_wide_to_long.html, and the function parameters are described in the [API reference](https://xlogit.readthedocs.io/en/latest/api/utils.html#xlogit.utils.wide_to_long).
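
For intuition about what the reshaping does, here is a minimal pandas-only sketch. `toy_wide_to_long` is a hypothetical helper written for illustration, not the jaxlogit implementation; it assumes that alternative names are column prefixes separated by `_`, as in the Swissmetro data, and fills variables that are missing for an alternative with `empty_val`.

```python
import pandas as pd


def toy_wide_to_long(df, id_col, alt_list, varying, alt_name="alt", sep="_", empty_val=0):
    """Hypothetical helper: reshape alternative-prefixed columns (e.g. TRAIN_TT) to long format."""
    varying_cols = {f"{alt}{sep}{var}" for alt in alt_list for var in varying}
    fixed_cols = [c for c in df.columns if c not in varying_cols]
    pieces = []
    for alt in alt_list:
        piece = df[fixed_cols].copy()           # columns shared by all alternatives
        piece[alt_name] = alt
        for var in varying:
            col = f"{alt}{sep}{var}"
            # Use the alternative's column if it exists, otherwise the fill value
            piece[var] = df[col] if col in df.columns else empty_val
        pieces.append(piece)
    long_df = pd.concat(pieces, ignore_index=True)
    # Sort by choice situation, keeping alternatives in the order of alt_list
    long_df[alt_name] = pd.Categorical(long_df[alt_name], categories=alt_list, ordered=True)
    long_df = long_df.sort_values([id_col, alt_name]).reset_index(drop=True)
    long_df[alt_name] = long_df[alt_name].astype(str)
    return long_df


wide = pd.DataFrame({
    "custom_id": [0, 1],
    "TRAIN_TT": [112, 103], "SM_TT": [63, 60], "CAR_TT": [117, 117],
    "TRAIN_HE": [120, 30], "SM_HE": [20, 10],    # no CAR_HE column
    "CHOICE": ["SM", "SM"],
})
long_df = toy_wide_to_long(wide, "custom_id", ["TRAIN", "SM", "CAR"], ["TT", "HE"])
print(long_df)
```

Each wide row becomes three long rows, and the missing `CAR_HE` values are filled with `0`, mirroring what `empty_val=0` does in the cell below.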

In [4]:
from jaxlogit.utils import wide_to_long

df = wide_to_long(df_wide, id_col='custom_id', alt_name='alt', sep='_',
                  alt_list=['TRAIN', 'SM', 'CAR'], empty_val=0,
                  varying=['TT', 'CO', 'HE', 'AV', 'SEATS'], alt_is_prefix=True)
df

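|  | custom_id | alt | TT | CO | HE | AV | SEATS | GROUP | SURVEY | SP | ... | TICKET | WHO | LUGGAGE | AGE | MALE | INCOME | GA | ORIGIN | DEST | CHOICE |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | TRAIN | 112 | 48 | 120 | 1 | 0 | 2 | 0 | 1 | ... | 1 | 1 | 0 | 3 | 0 | 2 | 0 | 2 | 1 | SM |
| 1 | 0 | SM | 63 | 52 | 20 | 1 | 0 | 2 | 0 | 1 | ... | 1 | 1 | 0 | 3 | 0 | 2 | 0 | 2 | 1 | SM |
| 2 | 0 | CAR | 117 | 65 | 0 | 1 | 0 | 2 | 0 | 1 | ... | 1 | 1 | 0 | 3 | 0 | 2 | 0 | 2 | 1 | SM |
| 3 | 1 | TRAIN | 103 | 48 | 30 | 1 | 0 | 2 | 0 | 1 | ... | 1 | 1 | 0 | 3 | 0 | 2 | 0 | 2 | 1 | SM |
| 4 | 1 | SM | 60 | 49 | 10 | 1 | 0 | 2 | 0 | 1 | ... | 1 | 1 | 0 | 3 | 0 | 2 | 0 | 2 | 1 | SM |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 20299 | 6766 | SM | 53 | 17 | 30 | 1 | 0 | 3 | 1 | 1 | ... | 7 | 3 | 1 | 5 | 1 | 2 | 0 | 1 | 2 | TRAIN |
| 20300 | 6766 | CAR | 80 | 104 | 0 | 1 | 0 | 3 | 1 | 1 | ... | 7 | 3 | 1 | 5 | 1 | 2 | 0 | 1 | 2 | TRAIN |
| 20301 | 6767 | TRAIN | 108 | 13 | 60 | 1 | 0 | 3 | 1 | 1 | ... | 7 | 3 | 1 | 5 | 1 | 2 | 0 | 1 | 2 | TRAIN |
| 20302 | 6767 | SM | 53 | 21 | 30 | 1 | 0 | 3 | 1 | 1 | ... | 7 | 3 | 1 | 5 | 1 | 2 | 0 | 1 | 2 | TRAIN |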


### Create model specification

Following the reshaping, users can create or update the dataset's columns to accommodate their model specification. The code below creates the columns `ASC_TRAIN` and `ASC_CAR` to include alternative-specific constants in the model (`SM` serves as the reference alternative). It also illustrates an effective way of establishing variable interactions by updating existing columns conditional on the values of other columns: the cost of the train and Swissmetro alternatives is set to zero for travelers who hold an annual pass (`GA == 1`). Although simple, such column operations provide an intuitive and flexible mechanism for incorporating variable transformations, interactions, and alternative-specific coefficients and constants, so any utility specification that is linear in the parameters can be expressed this way. See [this specification example](https://xlogit.readthedocs.io/en/latest/notebooks/multinomial_model.html#Create-model-specification) in xlogit's documentation for a more elaborate case.

In [5]:
df['ASC_TRAIN'] = np.ones(len(df))*(df['alt'] == 'TRAIN')
df['ASC_CAR'] = np.ones(len(df))*(df['alt'] == 'CAR')
df['TT'], df['CO'] = df['TT']/100, df['CO']/100  # Scale variables
annual_pass = (df['GA'] == 1) & (df['alt'].isin(['TRAIN', 'SM']))
df.loc[annual_pass, 'CO'] = 0  # Cost zero for pass holders
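
Under these column definitions, the model estimated in this section can be written as follows. This is a sketch for orientation: `SM` is the reference alternative because no `ASC_SM` column is created, and the `TT` coefficient varies across individuals because of the `randvars={'TT': 'n'}` and `panels` arguments used when fitting.

$$
U_{njt} = \mathrm{ASC}_j + \beta_{CO}\, CO_{njt} + \beta_{TT,n}\, TT_{njt} + \varepsilon_{njt},
\qquad \beta_{TT,n} \sim \mathcal{N}\left(\mu_{TT}, \sigma_{TT}^2\right)
$$

Here $n$ indexes individuals, $t$ choice situations and $j$ alternatives; $\mathrm{ASC}_{SM} = 0$; $TT$ and $CO$ are the values divided by 100; $CO_{njt} = 0$ when individual $n$ holds an annual pass and $j \in \{\text{TRAIN}, \text{SM}\}$; and $\varepsilon_{njt}$ is i.i.d. extreme value, as in any logit model.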

### Estimate model parameters

The `fit` method estimates the model from the data prepared in the previous step, along with additional specification criteria such as the distribution of the random parameters (`randvars`, where `'n'` denotes a normal distribution), the number of random draws used to simulate the log-likelihood (`n_draws`), the availability of alternatives in each choice situation (`avail`), and the individual to whom each choice situation belongs (`panels`). No optimization method is specified here, so the default is used; the log below reports that the minimization ran with the `trust-region` method. Once the estimation is complete, the `summary` method displays the estimation results.
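
To make the roles of `randvars`, `n_draws` and `panels` concrete, the following is a minimal NumPy sketch of a simulated log-likelihood for a mixed logit with one normally distributed coefficient and panel data. It illustrates the standard maximum simulated likelihood idea and is not jaxlogit's implementation; the function `simulated_loglik`, its argument layout, and the use of plain pseudo-random normal draws are assumptions made for illustration.

```python
import numpy as np


def simulated_loglik(beta_fixed, mu, sd, X_fixed, X_rand, chosen, panel_of_obs, n_panels, draws):
    """Hypothetical helper: simulated log-likelihood, one normal random coefficient.

    X_fixed: (N, J, K) attributes with fixed coefficients beta_fixed (K,)
    X_rand:  (N, J) attribute whose coefficient is beta_n = mu + sd * z
    chosen:  (N,) index of the chosen alternative in each choice situation
    panel_of_obs: (N,) panel (individual) index of each situation, 0..n_panels-1
    draws:   (n_panels, R) standard normal draws shared by all situations of a panel
    """
    beta_rand = mu + sd * draws                                   # (P, R)
    v_fixed = X_fixed @ beta_fixed                                # (N, J)
    # Utilities for every draw; each situation uses its panel's coefficient draws
    v = v_fixed[:, :, None] + X_rand[:, :, None] * beta_rand[panel_of_obs][:, None, :]
    v = v - v.max(axis=1, keepdims=True)                          # numerical stability
    p = np.exp(v) / np.exp(v).sum(axis=1, keepdims=True)          # logit probabilities (N, J, R)
    p_chosen = p[np.arange(len(chosen)), chosen, :]               # (N, R)
    # Panel structure: within each draw, multiply (add logs of) one person's choice probabilities
    log_prod = np.zeros((n_panels, draws.shape[1]))
    np.add.at(log_prod, panel_of_obs, np.log(p_chosen))
    # Average the panel probabilities over draws (log-mean-exp), then sum over panels
    m = log_prod.max(axis=1, keepdims=True)
    return np.sum(m[:, 0] + np.log(np.mean(np.exp(log_prod - m), axis=1)))


rng = np.random.default_rng(0)
N, J = 6, 3
X_fixed = rng.normal(size=(N, J, 2))
X_rand = rng.normal(size=(N, J))
chosen = rng.integers(0, J, size=N)
panel_of_obs = np.array([0, 0, 0, 1, 1, 2])    # 3 individuals with unequal numbers of situations
draws = rng.standard_normal((3, 500))
beta_fixed = np.array([0.5, -1.0])

ll_mixed = simulated_loglik(beta_fixed, -1.0, 0.8, X_fixed, X_rand, chosen, panel_of_obs, 3, draws)
# With sd = 0 all draws give the same coefficient, so this reduces to a plain MNL log-likelihood
ll_sd0 = simulated_loglik(beta_fixed, -1.0, 0.0, X_fixed, X_rand, chosen, panel_of_obs, 3, draws)
print(ll_mixed, ll_sd0)
```

The key design point is that draws are indexed by panel, not by choice situation, so a person's choices are evaluated with the same coefficient within each draw before averaging.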

In [6]:
varnames=['ASC_CAR', 'ASC_TRAIN', 'CO', 'TT']
model = MixedLogit()
res = model.fit(
    X=df[varnames],
    y=df['CHOICE'],
    varnames=varnames,
    alts=df['alt'],
    ids=df['custom_id'],
    avail=df['AV'],
    panels=df["ID"],
    randvars={'TT': 'n'},
    n_draws=1500,
)
model.summary()

2025-07-14 00:29:34,451 INFO jaxlogit.mixed_logit: Starting data preparation, including generation of 1500 random draws for each random variable and observation.


2025-07-14 00:29:37,022 INFO jaxlogit.mixed_logit: Data contains 752 panels, using segment_sum for panel-wise log-likelihood.
2025-07-14 00:29:37,023 INFO jaxlogit.mixed_logit: Shape of draws: (6768, 1, 1500), number of draws: 1500
2025-07-14 00:29:37,024 INFO jaxlogit.mixed_logit: Shape of Xdf: (6768, 2, 3), shape of Xdr: (6768, 2, 1)
2025-07-14 00:29:37,026 INFO jaxlogit.mixed_logit: Compiling log-likelihood function.
2025-07-14 00:29:37,459 INFO jaxlogit.mixed_logit: Compilation finished, init neg_loglike = 6260.71, params= [(np.str_('ASC_CAR'), Array(0.1, dtype=float64)), (np.str_('ASC_TRAIN'), Array(0.1, dtype=float64)), (np.str_('CO'), Array(0.1, dtype=float64)), (np.str_('TT'), Array(0.1, dtype=float64)), (np.str_('sd.TT'), Array(0.1, dtype=float64))]
2025-07-14 00:29:37,461 INFO jaxlogit._optimize: Running minimization with method trust-region


Loss on this step: 6260.709404860119, Loss on the last accepted step: 0.0, Step size: 1.0
Loss on this step: 205471.97895501574, Loss on the last accepted step: 6260.709404860119, Step size: 0.25
Loss on this step: 123910.46308477379, Loss on the last accepted step: 6260.709404860119, Step size: 0.0625
Loss on this step: 45412.46942728786, Loss on the last accepted step: 6260.709404860119, Step size: 0.015625
Loss on this step: 13605.640755086313, Loss on the last accepted step: 6260.709404860119, Step size: 0.00390625
Loss on this step: 5957.889047455692, Loss on the last accepted step: 6260.709404860119, Step size: 0.00390625
Loss on this step: 4662.3115816798545, Loss on the last accepted step: 5957.889047455692, Step size: 0.00390625
Loss on this step: 4515.065009751052, Loss on the last accepted step: 4662.3115816798545, Step size: 0.00390625
Loss on this step: 4418.413802576902, Loss on the last accepted step: 4515.065009751052, Step size: 0.00390625
Loss on this step: 4389.38411

2025-07-14 00:29:54,785 INFO jaxlogit.mixed_logit: Optimization finished, success = True, final loglike = -4359.22, final gradient max = 7.47e-05, norm = 6.75e-04.
2025-07-14 00:29:54,786 INFO jaxlogit.mixed_logit: Calculating gradient of individual log-likelihood contributions


Loss on this step: 4359.218229239384, Loss on the last accepted step: 4359.2182292405305, Step size: 0.09817397594451904


2025-07-14 00:29:58,379 INFO jaxlogit.mixed_logit: Calculating H_inv
2025-07-14 00:30:02,320 INFO jaxlogit._choice_model: Post fit processing
2025-07-14 00:30:03,519 INFO jaxlogit._choice_model: Optimization terminated successfully.


    Message: 
    Iterations: 60
    Function evaluations: 68
Estimation time= 29.0 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
ASC_CAR                 0.2831110     0.0560480     5.0512276       4.5e-07 ***
ASC_TRAIN              -0.5722733     0.0794778    -7.2004140      6.65e-13 ***
CO                     -1.6601666     0.0778870   -21.3150710      1.26e-97 ***
TT                     -3.2289976     0.1749805   -18.4534737      3.16e-74 ***
sd.TT                   3.6221628     0.1728452    20.9561137      1.58e-94 ***
---------------------------------------------------------------------------
Significance:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Log-Likelihood= -4359.218
AIC= 8728.436
BIC= 8762.536


The negative signs of the cost and time coefficients indicate that decision makers experience disutility from alternatives with higher travel times and costs, which is consistent with the underlying decision-making theory. These estimates are highly consistent with those returned by Biogeme (https://biogeme.epfl.ch/examples/swissmetro/05normalMixtureIntegral.html).
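
Because `TT` has a normally distributed coefficient, the estimates also imply the share of respondents whose time coefficient is positive. The sketch below computes it from the summary values, assuming (as noted for the electricity example further down) that the reported `sd.TT` is on the softplus scale and must be transformed to obtain the standard deviation.

```python
import math

mu_tt = -3.2289976       # TT estimate from the summary above
sd_tt_raw = 3.6221628    # sd.TT as reported (assumed to be on the softplus scale)
sd_tt = math.log1p(math.exp(sd_tt_raw))   # softplus: log(1 + exp(x))


def normal_cdf(x):
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


# P(beta_TT > 0) for beta_TT ~ N(mu_tt, sd_tt^2)
share_positive = 1.0 - normal_cdf(-mu_tt / sd_tt)
print(f"sd(TT) = {sd_tt:.4f}, share with a positive TT coefficient = {share_positive:.3f}")
```

A non-negligible positive share is a known side effect of using an unbounded normal distribution for a coefficient expected to have one sign; the car example below uses a lognormal distribution for `price` for this kind of reason.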

## Electricity Dataset

The electricity dataset contains 4,308 choices among four electricity suppliers based on the attributes of the offered plans, which include price (`pf`), contract length (`cl`), time-of-day rates (`tod`), and seasonal rates (`seas`), as well as attributes of the suppliers, namely whether the supplier is local (`loc`) and well known (`wk`). The data was collected through a survey in which 12 different choice situations were presented to each participant, and the multiple responses per participant are organized into panels. Because some participants answered fewer than 12 choice situations, the panels are unbalanced, which `jaxlogit` is able to handle. [Revelt and Train (1999)](https://escholarship.org/content/qt1900p96t/qt1900p96t.pdf) provide a detailed description of this dataset.
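
Whether panels are unbalanced can be checked by counting the distinct choice situations per respondent. A minimal sketch on hypothetical toy data laid out like the electricity file (`chid` is the choice situation, `id` the respondent):

```python
import pandas as pd

# Hypothetical toy data: one row per alternative, respondent 1 answers two situations, respondent 2 one
toy = pd.DataFrame({
    "id":     [1, 1, 1, 1, 1, 1, 2, 2, 2],
    "chid":   [1, 1, 1, 2, 2, 2, 3, 3, 3],
    "alt":    [1, 2, 3, 1, 2, 3, 1, 2, 3],
    "choice": [0, 1, 0, 1, 0, 0, 0, 0, 1],
})

situations_per_panel = toy.groupby("id")["chid"].nunique()
print(situations_per_panel.to_dict())
print("unbalanced:", situations_per_panel.nunique() > 1)
```

With the electricity data loaded as `df`, the same check is `df.groupby('id')['chid'].nunique()`.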

### Read data

The dataset is already in long format, so no reshaping is necessary and it can be used directly in `jaxlogit`.
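
Before fitting, it can be worth confirming that long-format data has the expected structure. The sketch below uses a hypothetical helper, `check_long_format`, that verifies each choice situation has exactly one chosen row and no repeated alternatives; it is an illustration, not part of jaxlogit.

```python
import pandas as pd


def check_long_format(df, id_col, alt_col, choice_col):
    """Hypothetical helper: basic sanity checks for long-format choice data."""
    chosen_per_situation = df.groupby(id_col)[choice_col].sum()
    dup_alts = df.duplicated(subset=[id_col, alt_col]).any()
    return {
        "one_choice_per_situation": bool((chosen_per_situation == 1).all()),
        "unique_alts": not bool(dup_alts),
    }


toy = pd.DataFrame({
    "chid":   [1, 1, 1, 1, 2, 2, 2, 2],
    "alt":    [1, 2, 3, 4, 1, 2, 3, 4],
    "choice": [0, 0, 0, 1, 0, 1, 0, 0],
})
print(check_long_format(toy, "chid", "alt", "choice"))
```

For the electricity data, the corresponding call would be `check_long_format(df, 'chid', 'alt', 'choice')`.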

In [16]:
df = pd.read_csv("https://raw.github.com/arteagac/xlogit/master/examples/data/electricity_long.csv")
df

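|  | choice | id | alt | pf | cl | loc | wk | tod | seas | chid |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 1 | 1 | 7 | 5 | 0 | 1 | 0 | 0 | 1 |
| 1 | 0 | 1 | 2 | 9 | 1 | 1 | 0 | 0 | 0 | 1 |
| 2 | 0 | 1 | 3 | 0 | 0 | 0 | 0 | 0 | 1 | 1 |
| 3 | 1 | 1 | 4 | 0 | 5 | 0 | 1 | 1 | 0 | 1 |
| 4 | 0 | 1 | 1 | 7 | 0 | 0 | 1 | 0 | 0 | 2 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 17227 | 0 | 361 | 4 | 0 | 1 | 1 | 0 | 0 | 1 | 4307 |
| 17228 | 1 | 361 | 1 | 9 | 0 | 0 | 1 | 0 | 0 | 4308 |
| 17229 | 0 | 361 | 2 | 7 | 0 | 0 | 0 | 0 | 0 | 4308 |
| 17230 | 0 | 361 | 3 | 0 | 1 | 0 | 1 | 0 | 1 | 4308 |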


### Fit the model

Note that the `panels` argument is passed to `fit` so that the estimation takes the panel structure of this dataset into account, i.e., repeated choices by the same respondent are not treated as independent observations.

In [17]:
varnames = ['pf', 'cl', 'loc', 'wk', 'tod', 'seas']
model = MixedLogit()
res = model.fit(
    X=df[varnames],
    y=df['choice'],
    varnames=varnames,
    ids=df['chid'],
    panels=df['id'],
    alts=df['alt'],
    n_draws=600,
    randvars={'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'}
)
model.summary()

2025-07-14 00:40:23,467 INFO jaxlogit.mixed_logit: Starting data preparation, including generation of 600 random draws for each random variable and observation.
2025-07-14 00:40:23,817 INFO jaxlogit.mixed_logit: Data contains 361 panels, using segment_sum for panel-wise log-likelihood.
2025-07-14 00:40:23,817 INFO jaxlogit.mixed_logit: Shape of draws: (4308, 6, 600), number of draws: 600
2025-07-14 00:40:23,818 INFO jaxlogit.mixed_logit: Shape of Xdf: (4308, 3, 0), shape of Xdr: (4308, 3, 6)
2025-07-14 00:40:23,819 INFO jaxlogit.mixed_logit: Compiling log-likelihood function.
2025-07-14 00:40:23,888 INFO jaxlogit.mixed_logit: Compilation finished, init neg_loglike = 5413.77, params= [(np.str_('pf'), Array(0.1, dtype=float64)), (np.str_('cl'), Array(0.1, dtype=float64)), (np.str_('loc'), Array(0.1, dtype=float64)), (np.str_('wk'), Array(0.1, dtype=float64)), (np.str_('tod'), Array(0.1, dtype=float64)), (np.str_('seas'), Array(0.1, dtype=float64)), (np.str_('sd.pf'), Array(0.1, dtype=flo

Loss on this step: 5413.773124872253, Loss on the last accepted step: 0.0, Step size: 1.0
Loss on this step: 246611.46044373326, Loss on the last accepted step: 5413.773124872253, Step size: 0.25
Loss on this step: 237649.62180272705, Loss on the last accepted step: 5413.773124872253, Step size: 0.0625
Loss on this step: 134257.0137458888, Loss on the last accepted step: 5413.773124872253, Step size: 0.015625
Loss on this step: 34162.8291987072, Loss on the last accepted step: 5413.773124872253, Step size: 0.00390625
Loss on this step: 5359.06864908849, Loss on the last accepted step: 5413.773124872253, Step size: 0.00390625
Loss on this step: 15502.842318962963, Loss on the last accepted step: 5359.06864908849, Step size: 0.0009765625
Loss on this step: 4871.405959053192, Loss on the last accepted step: 5359.06864908849, Step size: 0.0009765625
Loss on this step: 4832.977340051752, Loss on the last accepted step: 4871.405959053192, Step size: 0.0009765625
Loss on this step: 4703.34861

2025-07-14 00:40:43,541 INFO jaxlogit.mixed_logit: Optimization finished, success = True, final loglike = -3888.47, final gradient max = 1.22e-05, norm = 2.12e-05.
2025-07-14 00:40:43,543 INFO jaxlogit.mixed_logit: Calculating gradient of individual log-likelihood contributions
2025-07-14 00:40:45,909 INFO jaxlogit.mixed_logit: Calculating H_inv
2025-07-14 00:40:49,367 INFO jaxlogit._choice_model: Post fit processing
2025-07-14 00:40:50,018 INFO jaxlogit._choice_model: Optimization terminated successfully.


    Message: 
    Iterations: 134
    Function evaluations: 143
Estimation time= 26.5 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -0.9972118     0.0378219   -26.3660068     3.71e-142 ***
cl                     -0.2196750     0.0255168    -8.6090393      1.02e-17 ***
loc                     2.2902246     0.1263057    18.1323913      7.19e-71 ***
wk                      1.6943092     0.0961498    17.6215603      3.63e-67 ***
tod                    -9.6752577     0.3350901   -28.8736009     9.73e-168 ***
seas                   -9.6962299     0.3246517   -29.8665593     2.65e-178 ***
sd.pf                  -1.3984357     0.0964657   -14.4967194      1.56e-46 ***
sd.cl                  -0.6749717     0.0754671    -8.9439189      5.47e-19 ***
sd.loc                  1.6001547     

In [18]:
# By default, jaxlogit passes the sd. parameters through a softplus transform so that the standard deviations
# are always positive; the summary above reports the untransformed values. To compare them with xlogit's results at
# https://github.com/arteagac/xlogit/blob/master/examples/mixed_logit_model.ipynb, apply jax.nn.softplus to the
# sd. parameters, or refit without the softplus transform:
model = MixedLogit()
res = model.fit(
    X=df[varnames],
    y=df['choice'],
    varnames=varnames,
    ids=df['chid'],
    panels=df['id'],
    alts=df['alt'],
    n_draws=600,
    randvars={'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'},
    force_positive_chol_diag=False,
)
model.summary()

2025-07-14 00:40:50,037 INFO jaxlogit.mixed_logit: Starting data preparation, including generation of 600 random draws for each random variable and observation.
2025-07-14 00:40:50,413 INFO jaxlogit.mixed_logit: Data contains 361 panels, using segment_sum for panel-wise log-likelihood.
2025-07-14 00:40:50,413 INFO jaxlogit.mixed_logit: Shape of draws: (4308, 6, 600), number of draws: 600
2025-07-14 00:40:50,414 INFO jaxlogit.mixed_logit: Shape of Xdf: (4308, 3, 0), shape of Xdr: (4308, 3, 6)
2025-07-14 00:40:50,415 INFO jaxlogit.mixed_logit: Compiling log-likelihood function.
2025-07-14 00:40:50,624 INFO jaxlogit.mixed_logit: Compilation finished, init neg_loglike = 5620.08, params= [(np.str_('pf'), Array(0.1, dtype=float64)), (np.str_('cl'), Array(0.1, dtype=float64)), (np.str_('loc'), Array(0.1, dtype=float64)), (np.str_('wk'), Array(0.1, dtype=float64)), (np.str_('tod'), Array(0.1, dtype=float64)), (np.str_('seas'), Array(0.1, dtype=float64)), (np.str_('sd.pf'), Array(0.1, dtype=flo

Loss on this step: 5620.078112126817, Loss on the last accepted step: 0.0, Step size: 1.0
Loss on this step: 237583.03657902256, Loss on the last accepted step: 5620.078112126817, Step size: 0.25
Loss on this step: 186550.21436822342, Loss on the last accepted step: 5620.078112126817, Step size: 0.0625
Loss on this step: 59946.26708452902, Loss on the last accepted step: 5620.078112126817, Step size: 0.015625
Loss on this step: 17214.20567921159, Loss on the last accepted step: 5620.078112126817, Step size: 0.00390625
Loss on this step: 7648.936515317229, Loss on the last accepted step: 5620.078112126817, Step size: 0.0009765625
Loss on this step: 6172.239268016281, Loss on the last accepted step: 5620.078112126817, Step size: 0.000244140625
Loss on this step: 5539.260682475645, Loss on the last accepted step: 5620.078112126817, Step size: 0.000244140625
Loss on this step: 5397.039549904901, Loss on the last accepted step: 5539.260682475645, Step size: 0.0008544921875
Loss on this step

2025-07-14 00:41:09,014 INFO jaxlogit.mixed_logit: Optimization finished, success = True, final loglike = -3888.47, final gradient max = 3.10e-05, norm = 8.86e-05.
2025-07-14 00:41:09,015 INFO jaxlogit.mixed_logit: Calculating gradient of individual log-likelihood contributions


Loss on this step: 3888.4650698021405, Loss on the last accepted step: 3888.465069802795, Step size: 0.6168402913763202


2025-07-14 00:41:11,288 INFO jaxlogit.mixed_logit: Calculating H_inv
2025-07-14 00:41:14,863 INFO jaxlogit._choice_model: Post fit processing
2025-07-14 00:41:15,476 INFO jaxlogit._choice_model: Optimization terminated successfully.


    Message: 
    Iterations: 124
    Function evaluations: 137
Estimation time= 25.4 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -0.9972113     0.0378219   -26.3660027     3.71e-142 ***
cl                     -0.2196750     0.0255168    -8.6090387      1.02e-17 ***
loc                     2.2902239     0.1263057    18.1323892      7.19e-71 ***
wk                      1.6943088     0.0961498    17.6215600      3.63e-67 ***
tod                    -9.6752534     0.3350900   -28.8735984     9.73e-168 ***
seas                   -9.6962259     0.3246516   -29.8665562     2.65e-178 ***
sd.pf                   0.2207270     0.0191064    11.5525033         2e-30 ***
sd.cl                   0.4115603     0.0254614    16.1640699       4.2e-57 ***
sd.loc                  1.7840292     
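
The two runs can be reconciled by hand: applying softplus to the raw `sd.` estimates of the first fit should reproduce the `sd.` estimates of the second fit. A minimal check with NumPy (the `softplus` below computes the same function as `jax.nn.softplus`), using only the `sd.` rows visible in both summaries:

```python
import numpy as np

# sd.pf, sd.cl, sd.loc copied from the two summaries above
sd_raw_softplus_fit = np.array([-1.3984357, -0.6749717, 1.6001547])   # default fit (raw scale)
sd_direct_fit = np.array([0.2207270, 0.4115603, 1.7840292])           # force_positive_chol_diag=False


def softplus(x):
    """softplus(x) = log(1 + exp(x)), the transform applied to the sd. parameters by default."""
    return np.log1p(np.exp(x))


print(softplus(sd_raw_softplus_fit))
print(np.allclose(softplus(sd_raw_softplus_fit), sd_direct_fit, atol=1e-3))
```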

## Fishing Dataset

This example illustrates the estimation of a Mixed Logit model for the sport fishing mode choices of 1,182 individuals using `jaxlogit`. The goal is to analyze the market shares of four alternatives (beach, pier, boat, and charter) based on their cost and fish catch rate. [Cameron and Trivedi (2005)](http://cameron.econ.ucdavis.edu/mmabook/mma.html) provide additional details about this dataset. The following code illustrates how to use `jaxlogit` to estimate the model parameters.

### Read data

The data to be analyzed can be imported to Python using any preferred method. In this example, the data in CSV format was imported using the popular `pandas` Python package. However, it is worth highlighting that `xlogit` does not depend on the `pandas` package, as `xlogit` can take any array-like structure as input. This represents an additional advantage because `xlogit` can be used with any preferred dataframe library, and not only with `pandas`.

In [19]:
import pandas as pd
df = pd.read_csv("https://raw.github.com/arteagac/xlogit/master/examples/data/fishing_long.csv")
df

|  | id | alt | choice | income | price | catch |
|---|---|---|---|---|---|---|
| 0 | 1 | beach | 0 | 7083.33170 | 157.930 | 0.0678 |
| 1 | 1 | boat | 0 | 7083.33170 | 157.930 | 0.2601 |
| 2 | 1 | charter | 1 | 7083.33170 | 182.930 | 0.5391 |
| 3 | 1 | pier | 0 | 7083.33170 | 157.930 | 0.0503 |
| 4 | 2 | beach | 0 | 1249.99980 | 15.114 | 0.1049 |
| ... | ... | ... | ... | ... | ... | ... |
| 4723 | 1181 | pier | 0 | 416.66668 | 36.636 | 0.4522 |
| 4724 | 1182 | beach | 0 | 6250.00130 | 339.890 | 0.2537 |
| 4725 | 1182 | boat | 1 | 6250.00130 | 235.436 | 0.6817 |
| 4726 | 1182 | charter | 0 | 6250.00130 | 260.436 | 2.3014 |


### Fit model

Once the data is in the Python environment, `jaxlogit` can be used to fit the model, as shown below. The `MixedLogit` class is imported from `jaxlogit.mixed_logit`, and its constructor is used to initialize a new model. The `fit` method estimates the model using the input data and the estimation criteria provided as arguments. Each individual makes a single choice in this dataset, so no `panels` argument is passed. The `fit` arguments used here share their names with those of xlogit, which are described in [xlogit's documentation](https://xlogit.readthedocs.io/en/latest/api/).


In [None]:
varnames = ['price', 'catch']
model = MixedLogit()
model.fit(
    X=df[varnames],
    y=df['choice'],
    varnames=varnames,
    alts=df['alt'],
    ids=df['id'],
    n_draws=2000,  # Note: with 1000 draws, sd.catch goes to zero; more draws are needed to find the minimum at a positive std. dev.
    randvars={'price': 'n', 'catch': 'n'},
)
model.summary()

2025-07-14 00:42:22,862 INFO jaxlogit.mixed_logit: Starting data preparation, including generation of 2000 random draws for each random variable and observation.
2025-07-14 00:42:24,039 INFO jaxlogit.mixed_logit: Shape of draws: (1182, 2, 2000), number of draws: 2000
2025-07-14 00:42:24,040 INFO jaxlogit.mixed_logit: Shape of Xdf: (1182, 3, 0), shape of Xdr: (1182, 3, 2)
2025-07-14 00:42:24,040 INFO jaxlogit.mixed_logit: Compiling log-likelihood function.
2025-07-14 00:42:24,193 INFO jaxlogit.mixed_logit: Compilation finished, init neg_loglike = 2342.17, params= [(np.str_('price'), Array(0.1, dtype=float64)), (np.str_('catch'), Array(0.1, dtype=float64)), (np.str_('sd.price'), Array(0.1, dtype=float64)), (np.str_('sd.catch'), Array(0.1, dtype=float64))]
2025-07-14 00:42:24,194 INFO jaxlogit._optimize: Running minimization with method trust-region


Loss on this step: 2342.172165320885, Loss on the last accepted step: 0.0, Step size: 1.0
Loss on this step: 346690.4235844468, Loss on the last accepted step: 2342.172165320885, Step size: 0.25
Loss on this step: 343586.74230861163, Loss on the last accepted step: 2342.172165320885, Step size: 0.0625
Loss on this step: 340345.4710080018, Loss on the last accepted step: 2342.172165320885, Step size: 0.015625
Loss on this step: 186655.5449715904, Loss on the last accepted step: 2342.172165320885, Step size: 0.00390625
Loss on this step: 26460.878950393257, Loss on the last accepted step: 2342.172165320885, Step size: 0.0009765625
Loss on this step: 2188.526926627541, Loss on the last accepted step: 2342.172165320885, Step size: 0.0009765625
Loss on this step: 2132.948497166779, Loss on the last accepted step: 2188.526926627541, Step size: 0.00341796875
Loss on this step: 3066.1216490150573, Loss on the last accepted step: 2132.948497166779, Step size: 0.0008544921875
Loss on this step: 

2025-07-14 00:42:32,323 INFO jaxlogit.mixed_logit: Optimization finished, success = True, final loglike = -1300.58, final gradient max = -4.50e-07, norm = 1.55e-04.
2025-07-14 00:42:32,324 INFO jaxlogit.mixed_logit: Calculating gradient of individual log-likelihood contributions


Loss on this step: 1300.5817132549464, Loss on the last accepted step: 1300.5817132549464, Step size: 8.620749807419167e-08
Loss on this step: 1300.5817132549462, Loss on the last accepted step: 1300.5817132549464, Step size: 3.017262432596709e-07


2025-07-14 00:42:32,994 INFO jaxlogit.mixed_logit: Calculating H_inv
2025-07-14 00:42:33,779 INFO jaxlogit._choice_model: Post fit processing
2025-07-14 00:42:34,077 INFO jaxlogit._choice_model: Optimization terminated successfully.


    Message: 
    Iterations: 111
    Function evaluations: 139
Estimation time= 11.2 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
price                  -0.0272479     0.0022848   -11.9259324      4.83e-31 ***
catch                   1.3258108     0.1738880     7.6245081      5.01e-14 ***
sd.price               -4.5741720     0.2085573   -21.9324510      9.97e-90 ***
sd.catch                1.3316625     0.4742604     2.8078723       0.00507 ** 
---------------------------------------------------------------------------
Significance:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Log-Likelihood= -1300.582
AIC= 2609.163
BIC= 2629.463
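
The comment in the `fit` call notes that 1000 draws were not enough here. Simulation noise in a simulated probability shrinks roughly in proportion to 1/sqrt(`n_draws`), which is one reason results can change with the number of draws. A minimal sketch with illustrative, made-up inputs (not the fishing estimates) shows how much a simulated probability moves across repeated simulations for different numbers of draws:

```python
import numpy as np

X_ALT = np.array([0.1, 0.3, 0.5, 0.05])   # illustrative attribute values for four alternatives


def simulated_prob(n_draws, rng, mu=1.3, sd=1.5, chosen=2):
    """Simulated logit probability of `chosen` with a normal coefficient beta ~ N(mu, sd^2)."""
    beta = mu + sd * rng.standard_normal(n_draws)         # one coefficient per draw
    v = np.outer(beta, X_ALT)                              # utilities, shape (draws, alternatives)
    p = np.exp(v - v.max(axis=1, keepdims=True))
    p /= p.sum(axis=1, keepdims=True)                      # logit probabilities per draw
    return p[:, chosen].mean()                             # average over draws


rng = np.random.default_rng(42)
spread = {}
for n in (100, 1000, 2000):
    # Repeat the simulation with fresh draws and measure the spread of the estimates
    estimates = [simulated_prob(n, rng) for _ in range(200)]
    spread[n] = float(np.std(estimates))
print(spread)
```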


In [None]:
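# The softplus-transformed sd. values agree with xlogit's results except for the sign of sd.catch. xlogit does not restrict
# the standard deviations to be positive, and for uncorrelated normal parameters the log-likelihood is symmetric with
# respect to the sign of the standard deviation. The slice below skips the mean coefficients and keeps the sd. parameters.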
jax.nn.softplus(model.coeff_[len(model._rvidx):])

Array([0.010262  , 1.56597382], dtype=float64)
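
The sign symmetry mentioned in the comment above can be seen directly: for an uncorrelated normal coefficient β = μ + σz, replacing σ with −σ only flips the draws. With symmetric draws (here antithetic pairs z and −z, an assumption made so that the equality is exact) the simulated probability does not change. A minimal NumPy sketch with illustrative values:

```python
import numpy as np


def mean_logit_prob(mu, sd, z, x, chosen):
    """Average logit probability of `chosen` over draws beta = mu + sd * z (one random coefficient)."""
    beta = mu + sd * z
    v = np.outer(beta, x)                                  # utilities, shape (draws, alternatives)
    p = np.exp(v - v.max(axis=1, keepdims=True))
    p /= p.sum(axis=1, keepdims=True)
    return p[:, chosen].mean()


rng = np.random.default_rng(1)
half = rng.standard_normal(500)
z = np.concatenate([half, -half])            # antithetic (symmetric) draws
x = np.array([0.07, 0.26, 0.54, 0.05])       # illustrative attribute values for four modes
p_pos = mean_logit_prob(1.3, 1.5, z, x, chosen=2)
p_neg = mean_logit_prob(1.3, -1.5, z, x, chosen=2)
print(p_pos, p_neg)
```

With ordinary (non-antithetic) draws the two values differ only by simulation noise, which is why the sign of an uncorrelated standard deviation is not identified.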

## Car Dataset

The fourth example uses a stated preference panel dataset on the choice of car. Three alternatives are considered in each choice situation, and the data contains 1,484 choice situations from 100 individuals. As in the electricity example, the panel is unbalanced, since not all individuals face the same number of choice situations. The dataset contains 8 explanatory variables: price, operating cost, range, and binary indicators of whether the car is electric (`ev`), gas (`gas`), or hybrid (`hybrid`), and whether its performance is high (`hiperf`) or medium or high (`medhiperf`). This dataset was taken from Kenneth Train's MATLAB code for the estimation of Mixed Logit models: https://eml.berkeley.edu/Software/abstracts/train1006mxlmsl.html

### Read data

In [30]:
import pandas as pd
import numpy as np

df = pd.read_csv("https://raw.github.com/arteagac/xlogit/master/examples/data/car100_long.csv")

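Since higher price and operating cost are expected to reduce utility, we reverse the signs of these variables in the dataframe (price is also divided by 10,000 to rescale it). For `price`, which is given a lognormal distribution (`'ln'`) in the model below, the reversal is needed because a lognormal coefficient is always positive.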
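
A lognormal coefficient is positive for every individual, so entering the negated price makes its contribution to utility negative for everyone. A minimal NumPy sketch of this, assuming the usual parametrization β = exp(μ + σz) with illustrative μ and σ (not jaxlogit's internal code or estimates):

```python
import numpy as np

# Illustrative (made-up) parameters of the underlying normal: beta = exp(mu + sigma * z)
mu, sigma = -0.5, 1.0
rng = np.random.default_rng(0)
beta_price = np.exp(mu + sigma * rng.standard_normal(200_000))   # lognormal draws

# Illustrative rescaled prices (price / 10000), entered with a reversed sign
neg_price = -np.array([4.68, 5.72, 8.80])
price_utility = np.outer(beta_price, neg_price)   # price contribution to utility, shape (draws, cars)

print("all coefficients positive:", bool((beta_price > 0).all()))
print("all price contributions negative:", bool((price_utility < 0).all()))
print("median coefficient vs exp(mu):", np.median(beta_price), np.exp(mu))
```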

In [31]:
df['price'] = -df['price']/10000
df['opcost'] = -df['opcost']
df

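|  | person_id | choice_id | alt | choice | price | opcost | range | ev | gas | hybrid | hiperf | medhiperf |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 1 | 1 | 0 | -4.6763 | -47.43 | 0.0 | 0 | 0 | 1 | 0 | 0 |
| 1 | 1 | 1 | 2 | 1 | -5.7209 | -27.43 | 1.3 | 1 | 0 | 0 | 1 | 1 |
| 2 | 1 | 1 | 3 | 0 | -8.7960 | -32.41 | 1.2 | 1 | 0 | 0 | 0 | 1 |
| 3 | 1 | 2 | 1 | 1 | -3.3768 | -4.89 | 1.3 | 1 | 0 | 0 | 1 | 1 |
| 4 | 1 | 2 | 2 | 0 | -9.0336 | -30.19 | 0.0 | 0 | 0 | 1 | 0 | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 4447 | 100 | 1483 | 2 | 0 | -2.8036 | -14.45 | 1.6 | 1 | 0 | 0 | 0 | 0 |
| 4448 | 100 | 1483 | 3 | 0 | -1.9360 | -54.76 | 0.0 | 0 | 1 | 0 | 1 | 1 |
| 4449 | 100 | 1484 | 1 | 1 | -2.4054 | -50.57 | 0.0 | 0 | 1 | 0 | 0 | 0 |
| 4450 | 100 | 1484 | 2 | 0 | -5.2795 | -21.25 | 0.0 | 0 | 0 | 1 | 0 | 1 |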


### Fit the model

In [32]:
varnames = ['hiperf', 'medhiperf', 'price', 'opcost', 'range', 'ev', 'hybrid'] 
model = MixedLogit()
model.fit(
    X=df[varnames],
    y=df['choice'],
    varnames=varnames,
    alts=df['alt'],
    ids=df['choice_id'],
    panels=df['person_id'],
    randvars={'price': 'ln', 'opcost': 'n', 'range': 'ln', 'ev': 'n', 'hybrid': 'n'},
    n_draws=1000,
)
model.summary()

2025-07-14 00:47:49,283 INFO jaxlogit.mixed_logit: Starting data preparation, including generation of 1000 random draws for each random variable and observation.
2025-07-14 00:47:49,455 INFO jaxlogit.mixed_logit: Lognormal distributions found for 2 random variables, applying transformation.
2025-07-14 00:47:49,456 INFO jaxlogit.mixed_logit: Data contains 100 panels, using segment_sum for panel-wise log-likelihood.
2025-07-14 00:47:49,457 INFO jaxlogit.mixed_logit: Shape of draws: (1484, 5, 1000), number of draws: 1000
2025-07-14 00:47:49,458 INFO jaxlogit.mixed_logit: Shape of Xdf: (1484, 2, 2), shape of Xdr: (1484, 2, 5)
2025-07-14 00:47:49,458 INFO jaxlogit.mixed_logit: Compiling log-likelihood function.
2025-07-14 00:47:49,536 INFO jaxlogit.mixed_logit: Compilation finished, init neg_loglike = 1737.11, params= [(np.str_('hiperf'), Array(0.1, dtype=float64)), (np.str_('medhiperf'), Array(0.1, dtype=float64)), (np.str_('price'), Array(0.1, dtype=float64)), (np.str_('opcost'), Array(0.

Loss on this step: 1737.1071909985076, Loss on the last accepted step: 0.0, Step size: 1.0
Loss on this step: 68618.60002635523, Loss on the last accepted step: 1737.1071909985076, Step size: 0.25
Loss on this step: 68461.72816998586, Loss on the last accepted step: 1737.1071909985076, Step size: 0.0625
Loss on this step: 66708.17052456817, Loss on the last accepted step: 1737.1071909985076, Step size: 0.015625
Loss on this step: 22075.294280744205, Loss on the last accepted step: 1737.1071909985076, Step size: 0.00390625
Loss on this step: 1588.862771609761, Loss on the last accepted step: 1737.1071909985076, Step size: 0.00390625
Loss on this step: 1546.9474055835506, Loss on the last accepted step: 1588.862771609761, Step size: 0.00390625
Loss on this step: 1514.787075048447, Loss on the last accepted step: 1546.9474055835506, Step size: 0.013671875
Loss on this step: 1414.7004832575594, Loss on the last accepted step: 1514.787075048447, Step size: 0.0478515625
Loss on this step: 68

2025-07-14 00:48:02,912 INFO jaxlogit.mixed_logit: Optimization finished, success = True, final loglike = -1298.18, final gradient max = 1.13e-05, norm = 1.13e-05.
2025-07-14 00:48:02,914 INFO jaxlogit.mixed_logit: Calculating gradient of individual log-likelihood contributions


Loss on this step: 1298.1755539946694, Loss on the last accepted step: 1298.175553994684, Step size: 0.8056689520017244
Loss on this step: 1298.1755539946685, Loss on the last accepted step: 1298.1755539946694, Step size: 0.8056689520017244


2025-07-14 00:48:04,350 INFO jaxlogit.mixed_logit: Calculating H_inv
2025-07-14 00:48:06,617 INFO jaxlogit._choice_model: Post fit processing
2025-07-14 00:48:07,128 INFO jaxlogit._choice_model: Optimization terminated successfully.


    Message: 
    Iterations: 126
    Function evaluations: 137
Estimation time= 17.8 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
hiperf                  0.1057373     0.0967950     1.0923846         0.275    
medhiperf               0.5713820     0.1004442     5.6885539      1.54e-08 ***
price                  -0.7405426     0.1406671    -5.2645058      1.61e-07 ***
opcost                  0.0119861     0.0057168     2.0966349        0.0362 *  
range                  -0.6709001     0.4032763    -1.6636240        0.0964 .  
ev                     -1.5936962     0.3357173    -4.7471372      2.26e-06 ***
hybrid                  0.7057744     0.1624134     4.3455432      1.48e-05 ***
sd.price                0.4547723     0.1965636     2.3136136        0.0208 *  
sd.opcost              -3.2385656     

In [33]:
jax.nn.softplus(model.coeff_[len(model._rvidx):])

Array([0.94616582, 0.03847054, 0.5772908 , 0.99715057, 0.74591881],      dtype=float64)

## References

- Bierlaire, M. (2018). PandasBiogeme: a short introduction. EPFL (Transport and Mobility Laboratory, ENAC).

- Brathwaite, T., & Walker, J. L. (2018). Asymmetric, closed-form, finite-parameter models of multinomial choice. Journal of Choice Modelling, 29, 78–112. 

- Cameron, A. C., & Trivedi, P. K. (2005). Microeconometrics: methods and applications. Cambridge University Press.

- Croissant, Y. (2020). Estimation of Random Utility Models in R: The mlogit Package. Journal of Statistical Software, 95(1), 1–41.

- Revelt, D., & Train, K. (1999). Customer-specific taste parameters and mixed logit. University of California, Berkeley.

