# Example of running jaxlogit with batched draws

jaxlogit's default way of handling the random draws used for simulation is to generate them once at the beginning and then calculate the log-likelihood at each step of the optimization routine with these same draws. The corresponding array has shape (number_of_observations x number_of_random_variables x number_of_draws), which can get very large. If this is too large for local memory, jaxlogit can instead generate the draws dynamically on each iteration. The advantage is that calculations can then be batched, i.e., processed on smaller subsets of the draws and then added up. This reduces memory usage at the cost of runtime. Note that JAX still has to calculate gradients through these computations, so this method has memory limits of its own.
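The batching idea can be sketched in plain NumPy, assuming a simple panel mixed logit with independent normally distributed coefficients. All names below (`make_draws`, `panel_prob_sum`, `simulated_loglik`) are hypothetical and do not mirror jaxlogit's internals; the sketch also uses seeded pseudo-random normal draws rather than whatever sequence jaxlogit generates. The key point it shows is that the simulated probabilities must be averaged over all draws before taking the log, so each batch contributes a partial *sum of probabilities*, not a partial log-likelihood.

```python
import numpy as np


def make_draws(n_panels, n_vars, draw_idx, seed=0):
    # One generator per draw index, so any subset of draws can be regenerated
    # on demand and matches the corresponding slice of the full array.
    cols = [np.random.default_rng([seed, r]).standard_normal((n_panels, n_vars))
            for r in draw_idx]
    return np.stack(cols, axis=-1)  # (n_panels, n_vars, len(draw_idx))


def panel_prob_sum(mean, sd, X, y, draws):
    # Normally distributed coefficients per panel and draw: (N, K, R)
    beta = mean[None, :, None] + sd[None, :, None] * draws
    # Utilities for panel n, situation t, alternative j, draw r: (N, T, J, R)
    V = np.einsum("ntjk,nkr->ntjr", X, beta)
    expV = np.exp(V - V.max(axis=2, keepdims=True))  # numerically stable softmax
    P = expV / expV.sum(axis=2, keepdims=True)
    # Probability of the chosen alternative in each situation: (N, T, R)
    P_chosen = np.take_along_axis(P, y[:, :, None, None], axis=2)[:, :, 0, :]
    # Product over each panel's situations, summed over this batch's draws: (N,)
    return P_chosen.prod(axis=1).sum(axis=1)


def simulated_loglik(mean, sd, X, y, n_draws, batch_size=None, seed=0):
    n_panels, n_vars = X.shape[0], X.shape[-1]
    batch_size = batch_size or n_draws
    prob_sum = np.zeros(n_panels)
    for start in range(0, n_draws, batch_size):
        idx = range(start, min(start + batch_size, n_draws))
        # Only one batch of draws (and its utilities) is in memory at a time
        draws = make_draws(n_panels, n_vars, idx, seed)
        prob_sum += panel_prob_sum(mean, sd, X, y, draws)
    # Average over all draws first, then take logs and sum over panels
    return np.log(prob_sum / n_draws).sum()


rng = np.random.default_rng(1)
N, T, J, K = 20, 3, 4, 2  # panels, choice situations, alternatives, random vars
X = rng.standard_normal((N, T, J, K))
y = rng.integers(0, J, size=(N, T))
mean, sd = np.array([0.5, -1.0]), np.array([1.0, 0.3])

full = simulated_loglik(mean, sd, X, y, n_draws=200)
batched = simulated_loglik(mean, sd, X, y, n_draws=200, batch_size=50)
print(np.isclose(full, batched))  # True: same draws, same sum, smaller arrays
```

The batched loop trades memory for repeated work: draws are regenerated and the loop runs once per batch on every evaluation. In a JAX implementation, differentiating through such a loop still keeps intermediate values for the backward pass, which is why batching does not remove memory limits entirely.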

In [1]:
%load_ext memory_profiler

In [2]:
import pandas as pd
import numpy as np
import jax

from jaxlogit.mixed_logit import MixedLogit

In [3]:
# Enable 64-bit precision
jax.config.update("jax_enable_x64", True)

## Electricity Dataset

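The data comes from xlogit's examples. Note that we skip the calculation of standard errors here to speed up test times; as a result, the Std.Err. column in the summaries below is a placeholder of 1.0, so the z-values simply repeat the estimates and the p-values are not meaningful.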

In [4]:
df = pd.read_csv("https://raw.github.com/arteagac/xlogit/master/examples/data/electricity_long.csv")

In [5]:
n_obs = df['chid'].unique().shape[0]
n_vars = 6
n_draws = 5000

size_in_ram = (n_obs * n_vars * n_draws * 8) / (1024 ** 3)  # in GB

print(
    f"Data has {n_obs} observations, we use {n_vars} random variables in the model. We work in 64 bit precision, so each element is 8 bytes."
    + f" For {n_draws} draws, the array of draws is about {size_in_ram:.2f} GB."
)

varnames = ['pf', 'cl', 'loc', 'wk', 'tod', 'seas']

Data has 4308 observations, we use 6 random variables in the model. We work in 64 bit precision, so each element is 8 bytes. For 5000 draws, the array of draws is about 0.96 GB.
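The arithmetic in the cell above can be wrapped in a small helper that also shows how batching shrinks the array held in memory at any one time. The helper `draws_memory` is hypothetical and not part of jaxlogit; it follows the notebook's formula and divides by 1024**3, so the figures are strictly GiB. Note that the batched runs below log a per-panel shape, `batch_shape=(361, 6, 2500)`, so passing the number of panels rather than observations as `n_rows` may be closer to what jaxlogit actually allocates.

```python
import math


def draws_memory(n_rows, n_vars, n_draws, batch_size=None, bytes_per_elem=8):
    """Return (n_batches, GiB per batch) for a draws array of shape
    (n_rows, n_vars, batch_size), using 1024**3 bytes per GiB."""
    batch_size = batch_size or n_draws
    n_batches = math.ceil(n_draws / batch_size)
    gib = n_rows * n_vars * min(batch_size, n_draws) * bytes_per_elem / 1024**3
    return n_batches, gib


print(draws_memory(4308, 6, 5000))                   # full array in one batch
print(draws_memory(4308, 6, 5000, batch_size=2500))  # two batches, half the size each
```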


In [6]:
%%memit

model = MixedLogit()
res = model.fit(
    X=df[varnames],
    y=df['choice'],
    varnames=varnames,
    ids=df['chid'],
    panels=df['id'],
    alts=df['alt'],
    n_draws=n_draws,
    randvars={'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'},
    skip_std_errs=True,  # skip standard errors to speed up the example
    # force_positive_chol_diag=False  # do not use softplus transformation for sd. variables
)
display(model.summary())

2025-08-03 09:01:08,019 INFO jaxlogit.mixed_logit: Number of draws: 5000.
2025-08-03 09:01:08,020 INFO jaxlogit.mixed_logit: Data contains 361 panels.
2025-08-03 09:01:08,020 INFO jaxlogit._optimize: Running minimization with method trust-region


Loss on this step: 5398.76049331062, Loss on the last accepted step: 0.0, Step size: 1.0
Loss on this step: 228458.86628940978, Loss on the last accepted step: 5398.76049331062, Step size: 0.25
Loss on this step: 160451.46329785584, Loss on the last accepted step: 5398.76049331062, Step size: 0.0625
Loss on this step: 50414.41638847614, Loss on the last accepted step: 5398.76049331062, Step size: 0.015625
Loss on this step: 12744.479542385374, Loss on the last accepted step: 5398.76049331062, Step size: 0.00390625
Loss on this step: 4876.005292732634, Loss on the last accepted step: 5398.76049331062, Step size: 0.00390625
Loss on this step: 12433.663557121343, Loss on the last accepted step: 4876.005292732634, Step size: 0.0009765625
Loss on this step: 4805.104927545437, Loss on the last accepted step: 4876.005292732634, Step size: 0.0009765625
Loss on this step: 4727.152504373395, Loss on the last accepted step: 4805.104927545437, Step size: 0.0009765625
Loss on this step: 4745.130706

2025-08-03 09:03:15,641 INFO jaxlogit.mixed_logit: Optimization finished, success = True, final loglike = -3880.18, final gradient max = 3.52e-05, norm = 3.71e-05.
2025-08-03 09:03:15,642 INFO jaxlogit.mixed_logit: Skipping H_inv and grad_n calculation due to skip_std_errs=True
2025-08-03 09:03:15,643 INFO jaxlogit._choice_model: Post fit processing
2025-08-03 09:03:15,872 INFO jaxlogit._choice_model: Optimization terminated successfully.


    Message: 
    Iterations: 122
    Function evaluations: 135
Estimation time= 131.0 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -1.0166190     1.0000000    -1.0166190         0.309    
cl                     -0.2327734     1.0000000    -0.2327734         0.816    
loc                     2.3556396     1.0000000     2.3556396        0.0185 *  
wk                      1.6744863     1.0000000     1.6744863        0.0941 .  
tod                    -9.7531221     1.0000000    -9.7531221      3.03e-22 ***
seas                   -9.9133439     1.0000000    -9.9133439       6.4e-23 ***
sd.pf                  -1.3452418     1.0000000    -1.3452418         0.179    
sd.cl                  -0.6834247     1.0000000    -0.6834247         0.494    
sd.loc                  1.7530060    

None

peak memory: 13964.31 MiB, increment: 13699.08 MiB


### Now with batching

In [7]:
%%memit

model = MixedLogit()
res = model.fit(
    X=df[varnames],
    y=df['choice'],
    varnames=varnames,
    ids=df['chid'],
    panels=df['id'],
    alts=df['alt'],
    n_draws=n_draws,
    randvars={'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'},
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=2500,
)
display(model.summary())

2025-08-03 09:03:16,494 INFO jaxlogit.mixed_logit: Batch size 2500 for 5000 draws, 2 batches, batch_shape=(361, 6, 2500).
2025-08-03 09:03:16,712 INFO jaxlogit.mixed_logit: Shape of halton_rand_idxs: (2, 902500), last row: [ 902600  902601  902602 ... 1805097 1805098 1805099].
2025-08-03 09:03:16,712 INFO jaxlogit.mixed_logit: Data contains 361 panels.
2025-08-03 09:03:16,713 INFO jaxlogit._optimize: Running minimization with method trust-region


Loss on this step: 5400.574956968836, Loss on the last accepted step: 0.0, Step size: 1.0
Loss on this step: 236205.1987293798, Loss on the last accepted step: 5400.574956968836, Step size: 0.25
Loss on this step: 179328.5462734287, Loss on the last accepted step: 5400.574956968836, Step size: 0.0625
Loss on this step: 60138.966668619876, Loss on the last accepted step: 5400.574956968836, Step size: 0.015625
Loss on this step: 15029.725330006007, Loss on the last accepted step: 5400.574956968836, Step size: 0.00390625
Loss on this step: 4902.1301395825285, Loss on the last accepted step: 5400.574956968836, Step size: 0.00390625
Loss on this step: 23906.99970969664, Loss on the last accepted step: 4902.1301395825285, Step size: 0.0009765625
Loss on this step: 5000.272332071836, Loss on the last accepted step: 4902.1301395825285, Step size: 0.000244140625
Loss on this step: 4818.005552007051, Loss on the last accepted step: 4902.1301395825285, Step size: 0.000244140625
Loss on this step:

2025-08-03 09:10:27,677 INFO jaxlogit.mixed_logit: Optimization finished, success = True, final loglike = -3879.92, final gradient max = 7.33e-06, norm = 3.69e-05.
2025-08-03 09:10:27,679 INFO jaxlogit.mixed_logit: Skipping H_inv and grad_n calculation due to skip_std_errs=True
2025-08-03 09:10:27,680 INFO jaxlogit._choice_model: Post fit processing
2025-08-03 09:10:27,689 INFO jaxlogit._choice_model: Optimization terminated successfully.


    Message: 
    Iterations: 147
    Function evaluations: 157
Estimation time= 431.2 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -1.0109002     1.0000000    -1.0109002         0.312    
cl                     -0.2295223     1.0000000    -0.2295223         0.818    
loc                     2.3477364     1.0000000     2.3477364        0.0189 *  
wk                      1.6751214     1.0000000     1.6751214         0.094 .  
tod                    -9.7162856     1.0000000    -9.7162856      4.32e-22 ***
seas                   -9.8762655     1.0000000    -9.8762655      9.19e-23 ***
sd.pf                  -1.3990437     1.0000000    -1.3990437         0.162    
sd.cl                  -0.6755186     1.0000000    -0.6755186         0.499    
sd.loc                  1.7072542    

None



peak memory: 7562.27 MiB, increment: 6776.59 MiB


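Peak memory usage came down from about 13,960 MiB to 7,560 MiB (roughly 13.6 GiB to 7.4 GiB), while the estimates barely changed (final log-likelihood -3880.18 vs. -3879.92). This comes at the cost of runtime, which went from about 131 s to 431 s.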

### Now with low-memory BFGS

In [9]:
%%memit

model = MixedLogit()
res = model.fit(
    X=df[varnames],
    y=df['choice'],
    varnames=varnames,
    ids=df['chid'],
    panels=df['id'],
    alts=df['alt'],
    n_draws=n_draws,
    randvars={'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'},
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-B",  # one of "trust-region", "L-BFGS-B", "BFGS"
)
display(model.summary())

2025-08-03 09:13:03,135 INFO jaxlogit.mixed_logit: Number of draws: 5000.
2025-08-03 09:13:03,136 INFO jaxlogit.mixed_logit: Data contains 361 panels.
2025-08-03 09:13:03,136 INFO jaxlogit._optimize: Running minimization with method L-BFGS-B
2025-08-03 09:13:08,855 INFO jaxlogit._optimize: Iter 1, fun = 5026.871, |grad| = 337.015
2025-08-03 09:13:12,607 INFO jaxlogit._optimize: Iter 2, fun = 4595.819, |grad| = 1000.464
2025-08-03 09:13:17,632 INFO jaxlogit._optimize: Iter 3, fun = 4570.965, |grad| = 389.479
2025-08-03 09:13:20,893 INFO jaxlogit._optimize: Iter 4, fun = 4554.501, |grad| = 215.658
2025-08-03 09:13:25,762 INFO jaxlogit._optimize: Iter 5, fun = 4543.517, |grad| = 181.063
2025-08-03 09:13:28,943 INFO jaxlogit._optimize: Iter 6, fun = 4532.249, |grad| = 156.799
2025-08-03 09:13:32,236 INFO jaxlogit._optimize: Iter 7, fun = 4437.262, |grad| = 242.752
2025-08-03 09:13:35,420 INFO jaxlogit._optimize: Iter 8, fun = 4251.585, |grad| = 403.253
2025-08-03 09:

    Message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH
    Iterations: 95
    Function evaluations: 117
Estimation time= 417.3 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -1.0166227     1.0000000    -1.0166227         0.309    
cl                     -0.2327797     1.0000000    -0.2327797         0.816    
loc                     2.3556730     1.0000000     2.3556730        0.0185 *  
wk                      1.6744819     1.0000000     1.6744819        0.0941 .  
tod                    -9.7531665     1.0000000    -9.7531665      3.03e-22 ***
seas                   -9.9133674     1.0000000    -9.9133674       6.4e-23 ***
sd.pf                  -1.3452424     1.0000000    -1.3452424         0.179    
sd.cl                  -0.6834139     1.0000000    -0.6834139     

None



peak memory: 10124.79 MiB, increment: 9254.01 MiB
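For intuition about the `optim_method="L-BFGS-B"` option, here is a minimal, self-contained sketch that fits a plain multinomial logit (not a mixed logit) on simulated data with SciPy's L-BFGS-B. L-BFGS-B keeps only the last few step and gradient differences (SciPy's `maxcor`, default 10) instead of a dense Hessian approximation. The convergence message in the output above matches SciPy's L-BFGS-B wording, but this sketch does not reproduce how jaxlogit wires up its optimizer; the data, `true_beta` and `neg_loglik_and_grad` are illustrative assumptions.

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(0)
n_obs, n_alts, n_vars = 2000, 4, 3
true_beta = np.array([1.0, -0.5, 2.0])
X = rng.standard_normal((n_obs, n_alts, n_vars))
# Simulate choices from a logit model: utility plus i.i.d. Gumbel errors
y = (X @ true_beta + rng.gumbel(size=(n_obs, n_alts))).argmax(axis=1)
rows = np.arange(n_obs)


def neg_loglik_and_grad(beta):
    V = X @ beta                                    # (n_obs, n_alts)
    expV = np.exp(V - V.max(axis=1, keepdims=True))
    P = expV / expV.sum(axis=1, keepdims=True)      # choice probabilities
    loglik = np.log(P[rows, y]).sum()
    # d loglik / d beta = sum_i (x_{i, y_i} - sum_j P_ij x_ij)
    grad = (X[rows, y] - np.einsum("ij,ijk->ik", P, X)).sum(axis=0)
    return -loglik, -grad                           # minimize the negative


res = minimize(neg_loglik_and_grad, x0=np.zeros(n_vars), jac=True,
               method="L-BFGS-B", options={"maxcor": 10})
print(res.success, res.message)
print(res.x)  # should be close to true_beta
```

Raising `maxcor` stores more correction pairs, which can reduce the number of iterations at a modest memory cost.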

### Now with batching and low-memory BFGS

In [6]:
%%memit

model = MixedLogit()
res = model.fit(
    X=df[varnames],
    y=df['choice'],
    varnames=varnames,
    ids=df['chid'],
    panels=df['id'],
    alts=df['alt'],
    n_draws=n_draws,
    randvars={'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'},
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=2500,
    optim_method="L-BFGS-B",  # one of "trust-region", "L-BFGS-B", "BFGS"
)
display(model.summary())

2025-08-03 09:43:59,968 INFO jaxlogit.mixed_logit: Batch size 2500 for 5000 draws, 2 batches, batch_shape=(361, 6, 2500).
2025-08-03 09:44:00,120 INFO jaxlogit.mixed_logit: Shape of halton_rand_idxs: (2, 902500), last row: [ 902600  902601  902602 ... 1805097 1805098 1805099].
2025-08-03 09:44:00,121 INFO jaxlogit.mixed_logit: Data contains 361 panels.
2025-08-03 09:44:00,121 INFO jaxlogit._optimize: Running minimization with method L-BFGS-B
2025-08-03 09:47:41,623 INFO jaxlogit.mixed_logit: Optimization finished, success = True, final loglike = -3879.92, final gradient max = 8.02e-03, norm = 1.49e-02.
2025

    Message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH
    Iterations: 59
    Function evaluations: 67
Estimation time= 222.5 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -1.0109200     1.0000000    -1.0109200         0.312    
cl                     -0.2295175     1.0000000    -0.2295175         0.818    
loc                     2.3476811     1.0000000     2.3476811        0.0189 *  
wk                      1.6751275     1.0000000     1.6751275         0.094 .  
tod                    -9.7164058     1.0000000    -9.7164058      4.32e-22 ***
seas                   -9.8764168     1.0000000    -9.8764168      9.18e-23 ***
sd.pf                  -1.3990568     1.0000000    -1.3990568         0.162    
sd.cl                  -0.6755195     1.0000000    -0.6755195      

None

peak memory: 7210.20 MiB, increment: 6944.94 MiB
