# Example of running jaxlogit with batched draws

By default, jaxlogit generates the random draws for simulation once at the start and then calculates the log-likelihood at each step of the optimization routine with these draws. The corresponding array has size (number_of_observations x number_of_random_variables x number_of_draws), which can get very large. If this is too large for local memory, jaxlogit can dynamically generate the draws on each iteration. The advantage is that calculations can then be batched, i.e., processed on smaller subsets and then added up. This reduces memory load at the cost of runtime. Note that jax still calculates gradients, so this method also has memory limits.
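The batching idea can be illustrated with a small NumPy sketch. This is not jaxlogit's implementation: the function names (`make_draws`, `simulated_loglik`, `batched_loglik`), the toy data, and the simplification of one choice per observation without panels are all assumptions made for illustration. Draws are regenerated per batch from a seed, so only one batch of draws is in memory at a time, and the per-batch log-likelihoods are summed.

```python
import numpy as np

def make_draws(seed, batch_index, n_rows, n_vars, n_draws):
    # Regenerate the same standard-normal draws for a given batch on demand
    rng = np.random.default_rng([seed, batch_index])
    return rng.standard_normal((n_rows, n_vars, n_draws))

def simulated_loglik(X, y, mean, sd, draws):
    # X: (n_obs, n_alts, n_vars), y: (n_obs,) index of the chosen alternative
    # draws: (n_obs, n_vars, n_draws)
    betas = mean[None, :, None] + sd[None, :, None] * draws
    utils = np.einsum("nav,nvd->nad", X, betas)      # (n_obs, n_alts, n_draws)
    utils -= utils.max(axis=1, keepdims=True)        # numerical stability
    probs = np.exp(utils)
    probs /= probs.sum(axis=1, keepdims=True)
    chosen = probs[np.arange(X.shape[0]), y, :]      # (n_obs, n_draws)
    return np.log(chosen.mean(axis=1)).sum()         # average over draws, sum over obs

def batched_loglik(X, y, mean, sd, n_draws, batch_size, seed=0):
    # Process observations in chunks and add up the partial log-likelihoods
    total = 0.0
    for b, start in enumerate(range(0, X.shape[0], batch_size)):
        stop = min(start + batch_size, X.shape[0])
        draws = make_draws(seed, b, stop - start, X.shape[2], n_draws)
        total += simulated_loglik(X[start:stop], y[start:stop], mean, sd, draws)
    return total

# Toy data (hypothetical, unrelated to the electricity dataset)
rng = np.random.default_rng(1)
n_obs, n_alts, n_vars, n_draws, batch_size = 40, 3, 2, 50, 10
X = rng.standard_normal((n_obs, n_alts, n_vars))
y = rng.integers(0, n_alts, size=n_obs)
mean, sd = np.array([0.5, -1.0]), np.array([1.0, 0.3])

# Reference: hold all draws in memory at once and evaluate in a single pass
all_draws = np.concatenate(
    [make_draws(0, b, batch_size, n_vars, n_draws) for b in range(n_obs // batch_size)]
)
full = simulated_loglik(X, y, mean, sd, all_draws)
batched = batched_loglik(X, y, mean, sd, n_draws, batch_size)
print(np.isclose(full, batched))
```

Because the log-likelihood is a sum over observations, the batched and unbatched results agree up to floating-point error as long as the same draws are used. The price is that draws are regenerated on every evaluation, which is where the extra runtime comes from.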

In [1]:
%load_ext memory_profiler

In [2]:
import pandas as pd
import numpy as np
import jax

from jaxlogit.mixed_logit import MixedLogit

In [3]:
# 64-bit precision
jax.config.update("jax_enable_x64", True)

## Electricity Dataset

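The dataset comes from xlogit's examples. Note that we skip the calculation of standard errors here to speed up test times, so the Std.Err., z-val and P>|z| columns in the summaries below are not meaningful.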

In [4]:
df = pd.read_csv("https://raw.github.com/arteagac/xlogit/master/examples/data/electricity_long.csv")

In [5]:
n_obs = df['chid'].unique().shape[0]
n_vars = 6
n_draws = 5000

size_in_ram = (n_obs * n_vars * n_draws * 8) / (1024 ** 3)  # in GB

print(
    f"Data has {n_obs} observations, we use {n_vars} random variables in the model. We work in 64 bit precision, so each element is 8 bytes."
    + f" For {n_draws} draws, the array of draws is about {size_in_ram:.2f} GB."
)

varnames = ['pf', 'cl', 'loc', 'wk', 'tod', 'seas']

Data has 4308 observations, we use 6 random variables in the model. We work in 64 bit precision, so each element is 8 bytes. For 5000 draws, the array of draws is about 0.96 GB.
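The same arithmetic can be applied per batch. The sketch below assumes that batching splits the data along the observation axis, so that the draws held in memory at any one time have shape (batch_size x n_vars x n_draws); the helper `draws_gib` and the list of batch sizes are illustrative and not part of jaxlogit.

```python
import math

def draws_gib(n_rows, n_vars, n_draws, bytes_per_element=8):
    """Size in GiB of an (n_rows x n_vars x n_draws) array of 64-bit floats."""
    return n_rows * n_vars * n_draws * bytes_per_element / 1024**3

n_obs, n_vars, n_draws = 4308, 6, 5000  # values from the cell above

for batch_size in (n_obs, 1077, 539):  # illustrative batch sizes
    n_batches = math.ceil(n_obs / batch_size)
    size = draws_gib(batch_size, n_vars, n_draws)
    print(f"batch_size={batch_size}: {n_batches} batch(es), {size:.2f} GiB of draws per batch")
```

This only accounts for the array of draws. Intermediate arrays and gradient calculations need additional memory, so measured peak usage is expected to be considerably higher than these figures.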


In [7]:
%%memit

model = MixedLogit()
res = model.fit(
    X=df[varnames],
    y=df['choice'],
    varnames=varnames,
    ids=df['chid'],
    panels=df['id'],
    alts=df['alt'],
    n_draws=n_draws,
    randvars={'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'},
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-B",
)
display(model.summary())

2025-08-04 21:42:03,524 INFO jax._src.xla_bridge: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2025-08-04 21:42:06,092 INFO jaxlogit.mixed_logit: Data contains 361 panels.
2025-08-04 21:42:06,093 INFO jaxlogit._optimize: Running minimization with method L-BFGS-B
2025-08-04 21:46:07,598 INFO jaxlogit.mixed_logit: Optimization finished, success = True, final loglike = -3880.18, final gradient max = 1.90e-03, norm = 1.13e-02.
2025-08-04 21:46:07,608 INFO jaxlogit.mixed_logit: Skipping H_inv and grad_n calculation due to skip_std_errs=True
2025-08-04 21:46:07,609 INFO jaxlogit._choice_model: Post fit processing
2025-08-04 21:46:07,903 INFO jaxlogit._choice_model: Optimization terminated successfully.


    Message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH
    Iterations: 95
    Function evaluations: 117
Estimation time= 244.3 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -1.0166227     1.0000000    -1.0166227         0.309    
cl                     -0.2327797     1.0000000    -0.2327797         0.816    
loc                     2.3556730     1.0000000     2.3556730        0.0185 *  
wk                      1.6744819     1.0000000     1.6744819        0.0941 .  
tod                    -9.7531665     1.0000000    -9.7531665      3.03e-22 ***
seas                   -9.9133674     1.0000000    -9.9133674       6.4e-23 ***
sd.pf                  -1.3452424     1.0000000    -1.3452424         0.179    
sd.cl                  -0.6834139     1.0000000    -0.6834139     

None

peak memory: 9810.14 MiB, increment: 9543.38 MiB


In [8]:
%%memit

model = MixedLogit()
res = model.fit(
    X=df[varnames],
    y=df['choice'],
    varnames=varnames,
    ids=df['chid'],
    panels=df['id'],
    alts=df['alt'],
    n_draws=n_draws,
    randvars={'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'},
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=1077,  # should result in 4 batches
    optim_method="L-BFGS-B",  # "trust-region", "L-BFGS-B", "BFGS"
)
display(model.summary())

2025-08-04 21:46:19,366 INFO jaxlogit.mixed_logit: Data contains 361 panels.
2025-08-04 21:46:19,367 INFO jaxlogit._optimize: Running minimization with method L-BFGS-B
2025-08-04 21:51:32,972 INFO jaxlogit.mixed_logit: Optimization finished, success = True, final loglike = -3880.18, final gradient max = 1.90e-03, norm = 1.13e-02.
2025-08-04 21:51:32,974 INFO jaxlogit.mixed_logit: Skipping H_inv and grad_n calculation due to skip_std_errs=True
2025-08-04 21:51:32,975 INFO jaxlogit._choice_model: Post fit processing
2025-08-04 21:51:33,020 INFO jaxlogit._choice_model: Optimization terminated successfully.


    Message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH
    Iterations: 95
    Function evaluations: 117
Estimation time= 315.9 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -1.0166227     1.0000000    -1.0166227         0.309    
cl                     -0.2327797     1.0000000    -0.2327797         0.816    
loc                     2.3556730     1.0000000     2.3556730        0.0185 *  
wk                      1.6744819     1.0000000     1.6744819        0.0941 .  
tod                    -9.7531665     1.0000000    -9.7531665      3.03e-22 ***
seas                   -9.9133674     1.0000000    -9.9133674       6.4e-23 ***
sd.pf                  -1.3452424     1.0000000    -1.3452424         0.179    
sd.cl                  -0.6834139     1.0000000    -0.6834139     

None


peak memory: 4335.09 MiB, increment: 3563.82 MiB


In [9]:
%%memit

model = MixedLogit()
res = model.fit(
    X=df[varnames],
    y=df['choice'],
    varnames=varnames,
    ids=df['chid'],
    panels=df['id'],
    alts=df['alt'],
    n_draws=n_draws,
    randvars={'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'},
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=539,
    optim_method="L-BFGS-B",
)
display(model.summary())

2025-08-04 21:53:53,729 INFO jaxlogit.mixed_logit: Data contains 361 panels.
2025-08-04 21:53:53,730 INFO jaxlogit._optimize: Running minimization with method L-BFGS-B
2025-08-04 22:01:20,661 INFO jaxlogit.mixed_logit: Optimization finished, success = True, final loglike = -3880.18, final gradient max = 1.90e-03, norm = 1.13e-02.
2025-08-04 22:01:20,662 INFO jaxlogit.mixed_logit: Skipping H_inv and grad_n calculation due to skip_std_errs=True
2025-08-04 22:01:20,663 INFO jaxlogit._choice_model: Post fit processing
2025-08-04 22:01:20,704 INFO jaxlogit._choice_model: Optimization terminated successfully.


    Message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH
    Iterations: 95
    Function evaluations: 117
Estimation time= 448.7 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -1.0166227     1.0000000    -1.0166227         0.309    
cl                     -0.2327797     1.0000000    -0.2327797         0.816    
loc                     2.3556730     1.0000000     2.3556730        0.0185 *  
wk                      1.6744819     1.0000000     1.6744819        0.0941 .  
tod                    -9.7531665     1.0000000    -9.7531665      3.03e-22 ***
seas                   -9.9133674     1.0000000    -9.9133674       6.4e-23 ***
sd.pf                  -1.3452424     1.0000000    -1.3452424         0.179    
sd.cl                  -0.6834139     1.0000000    -0.6834139     

None


peak memory: 3402.84 MiB, increment: 2489.12 MiB
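To compare the three runs, the estimation times and peak memory figures reported above can be put side by side. The snippet below only restates those measured numbers and computes ratios relative to the unbatched run; the `runs` dictionary is a hand-copied summary, not output produced by jaxlogit.

```python
# Estimation time (seconds) and peak memory (MiB) as reported in the outputs above
runs = {
    "None": {"seconds": 244.3, "peak_mib": 9810.14},
    "1077": {"seconds": 315.9, "peak_mib": 4335.09},
    "539": {"seconds": 448.7, "peak_mib": 3402.84},
}

base = runs["None"]
relative = {}
for batch_size, r in runs.items():
    time_ratio = r["seconds"] / base["seconds"]
    mem_ratio = r["peak_mib"] / base["peak_mib"]
    relative[batch_size] = (time_ratio, mem_ratio)
    print(f"batch_size={batch_size:>4}: runtime x{time_ratio:.2f}, peak memory x{mem_ratio:.2f}")
```

In these runs, all three settings reach the same final log-likelihood (-3880.18) in the same number of iterations. Batching into 4 batches takes roughly 1.3 times as long at about 44% of the peak memory, and the smaller batch size of 539 takes roughly 1.8 times as long at about 35% of the peak memory.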
