# Example of running jaxlogit with batched draws

By default, jaxlogit generates the random draws used for simulation once at the start and then reuses them to calculate the log-likelihood at each step of the optimization routine. The corresponding array has size (number_of_observations x number_of_random_variables x number_of_draws), which can get very large. If this is too large for local memory, jaxlogit can instead generate the draws dynamically on each iteration. The advantage is that calculations can then be batched, i.e., processed on smaller subsets of the draws and then added up. This reduces memory load at the cost of runtime. Note that jax still calculates gradients, so this method also has memory limits.
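
To make the idea of batching concrete, here is a minimal NumPy sketch. It is not jaxlogit's implementation: the toy logistic kernel and the function names (`draws_for_batch`, `kernel`, `simulated_loglik_full`, `simulated_loglik_batched`) are assumptions made purely for illustration. It only shows that regenerating the draws per batch and summing the per-batch contributions gives the same simulated log-likelihood as holding all draws in memory at once.

```python
import numpy as np


def draws_for_batch(seed, batch_idx, n_obs, n_vars, batch_size):
    # Regenerate the same standard-normal draws for a given batch on demand,
    # so the draws never have to be kept in memory between batches.
    rng = np.random.default_rng([seed, batch_idx])
    return rng.standard_normal((n_obs, n_vars, batch_size))


def kernel(x, mean, sd, draws):
    # Toy stand-in for the choice probability: a logistic function of the
    # utility, evaluated for every observation and every draw.
    beta = mean[None, :, None] + sd[None, :, None] * draws  # (n_obs, n_vars, n_batch)
    utility = np.einsum("nk,nkr->nr", x, beta)              # (n_obs, n_batch)
    return 1.0 / (1.0 + np.exp(-utility))


def simulated_loglik_full(x, mean, sd, n_draws, batch_size, seed=0):
    # Reference version: materialise all draws at once (large array).
    n_obs, n_vars = x.shape
    n_batches = n_draws // batch_size
    draws = np.concatenate(
        [draws_for_batch(seed, b, n_obs, n_vars, batch_size) for b in range(n_batches)],
        axis=2,
    )
    return np.log(kernel(x, mean, sd, draws).mean(axis=1)).sum()


def simulated_loglik_batched(x, mean, sd, n_draws, batch_size, seed=0):
    # Batched version: only one (n_obs, n_vars, batch_size) array exists at a time.
    if n_draws % batch_size != 0:
        raise ValueError("this sketch assumes n_draws is a multiple of batch_size")
    n_obs, n_vars = x.shape
    prob_sum = np.zeros(n_obs)
    for b in range(n_draws // batch_size):
        draws = draws_for_batch(seed, b, n_obs, n_vars, batch_size)
        prob_sum += kernel(x, mean, sd, draws).sum(axis=1)  # accumulate over draws
    return np.log(prob_sum / n_draws).sum()


x = np.random.default_rng(1).normal(size=(50, 3))
mean = np.array([0.5, -0.2, 0.1])
sd = np.array([1.0, 0.3, 0.7])
full = simulated_loglik_full(x, mean, sd, n_draws=200, batch_size=50)
batched = simulated_loglik_batched(x, mean, sd, n_draws=200, batch_size=50)
print(np.isclose(full, batched))
```

The batched version trades memory for time: the draws are regenerated whenever they are needed instead of being read from a stored array.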

In [6]:
%load_ext memory_profiler

In [1]:
import pandas as pd
import numpy as np
import jax

from jaxlogit.mixed_logit import MixedLogit

In [2]:
# Use 64-bit precision
jax.config.update("jax_enable_x64", True)

## Electricity Dataset

In [3]:
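df = pd.read_csv("https://raw.github.com/arteagac/xlogit/master/examples/data/electricity_long.csv")
# Explanatory variables used in the fits below; all six are treated as random
varnames = ['pf', 'cl', 'loc', 'wk', 'tod', 'seas']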

In [12]:
n_obs = df['chid'].unique().shape[0]
n_vars = 6
n_draws = 5000

size_in_ram = (n_obs * n_vars * n_draws * 8) / (1024 ** 3)  # in GB

print(
    f"Data has {n_obs} observations, we use {n_vars} random variables in the model. We work in 64 bit precision, so each element is 8 bytes."
    + f" For {n_draws} draws, the array of draws is about {size_in_ram:.2f} GB."
)

Data has 4308 observations, we use 6 random variables in the model. We work in 64 bit precision, so each element is 8 bytes. For 5000 draws, the array of draws is about 0.96 GB.


In [8]:
%%memit

model = MixedLogit()
res = model.fit(
    X=df[varnames],
    y=df['choice'],
    varnames=varnames,
    ids=df['chid'],
    panels=df['id'],
    alts=df['alt'],
    n_draws=n_draws,
    randvars={'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'},
    skip_std_errs=True,  # skip standard errors to speed up the example
)
display(model.summary())

2025-08-02 06:50:58,584 INFO jaxlogit.mixed_logit: Number of draws: 5000.
2025-08-02 06:50:58,585 INFO jaxlogit.mixed_logit: Data contains 361 panels.
2025-08-02 06:50:58,585 INFO jaxlogit._optimize: Running minimization with method trust-region


Loss on this step: 5398.76049331062, Loss on the last accepted step: 0.0, Step size: 1.0
Loss on this step: 228458.86628940978, Loss on the last accepted step: 5398.76049331062, Step size: 0.25
Loss on this step: 160451.46329785584, Loss on the last accepted step: 5398.76049331062, Step size: 0.0625
Loss on this step: 50414.41638847614, Loss on the last accepted step: 5398.76049331062, Step size: 0.015625
Loss on this step: 12744.479542385374, Loss on the last accepted step: 5398.76049331062, Step size: 0.00390625
Loss on this step: 4876.005292732634, Loss on the last accepted step: 5398.76049331062, Step size: 0.00390625
Loss on this step: 12433.663557121343, Loss on the last accepted step: 4876.005292732634, Step size: 0.0009765625
Loss on this step: 4805.104927545437, Loss on the last accepted step: 4876.005292732634, Step size: 0.0009765625
Loss on this step: 4727.152504373395, Loss on the last accepted step: 4805.104927545437, Step size: 0.0009765625
Loss on this step: 4745.130706

2025-08-02 06:53:08,366 INFO jaxlogit.mixed_logit: Optimization finished, success = True, final loglike = -3880.18, final gradient max = 3.52e-05, norm = 3.71e-05.
2025-08-02 06:53:08,367 INFO jaxlogit.mixed_logit: Skipping H_inv and grad_n calculation due to skip_std_errs=True
2025-08-02 06:53:08,367 INFO jaxlogit._choice_model: Post fit processing
2025-08-02 06:53:08,480 INFO jaxlogit._choice_model: Optimization terminated successfully.


    Message: 
    Iterations: 122
    Function evaluations: 135
Estimation time= 131.9 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -1.0166190     1.0000000    -1.0166190         0.309    
cl                     -0.2327734     1.0000000    -0.2327734         0.816    
loc                     2.3556396     1.0000000     2.3556396        0.0185 *  
wk                      1.6744863     1.0000000     1.6744863        0.0941 .  
tod                    -9.7531221     1.0000000    -9.7531221      3.03e-22 ***
seas                   -9.9133439     1.0000000    -9.9133439       6.4e-23 ***
sd.pf                  -1.3452418     1.0000000    -1.3452418         0.179    
sd.cl                  -0.6834247     1.0000000    -0.6834247         0.494    
sd.loc                  1.7530060    

None



peak memory: 16300.37 MiB, increment: 13227.84 MiB


### Now with batching

In [9]:
%%memit

model = MixedLogit()
res = model.fit(
    X=df[varnames],
    y=df['choice'],
    varnames=varnames,
    ids=df['chid'],
    panels=df['id'],
    alts=df['alt'],
    n_draws=n_draws,
    randvars={'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'},
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=2500,
)
display(model.summary())

2025-08-02 06:53:37,059 INFO jaxlogit.mixed_logit: Batch size 2500 for 5000 draws, 2 batches, batch_shape=(361, 6, 2500).
2025-08-02 06:53:37,253 INFO jaxlogit.mixed_logit: Shape of halton_rand_idxs: (2, 902500), last row: [ 902600  902601  902602 ... 1805097 1805098 1805099].
2025-08-02 06:53:37,254 INFO jaxlogit.mixed_logit: Data contains 361 panels.
2025-08-02 06:53:37,255 INFO jaxlogit._optimize: Running minimization with method trust-region


Loss on this step: 5400.574956968836, Loss on the last accepted step: 0.0, Step size: 1.0
Loss on this step: 236205.1987293798, Loss on the last accepted step: 5400.574956968836, Step size: 0.25
Loss on this step: 179328.5462734287, Loss on the last accepted step: 5400.574956968836, Step size: 0.0625
Loss on this step: 60138.966668619876, Loss on the last accepted step: 5400.574956968836, Step size: 0.015625
Loss on this step: 15029.725330006007, Loss on the last accepted step: 5400.574956968836, Step size: 0.00390625
Loss on this step: 4902.1301395825285, Loss on the last accepted step: 5400.574956968836, Step size: 0.00390625
Loss on this step: 23906.99970969664, Loss on the last accepted step: 4902.1301395825285, Step size: 0.0009765625
Loss on this step: 5000.272332071836, Loss on the last accepted step: 4902.1301395825285, Step size: 0.000244140625
Loss on this step: 4818.005552007051, Loss on the last accepted step: 4902.1301395825285, Step size: 0.000244140625
Loss on this step:

2025-08-02 07:01:03,434 INFO jaxlogit.mixed_logit: Optimization finished, success = True, final loglike = -3879.92, final gradient max = 7.33e-06, norm = 3.69e-05.
2025-08-02 07:01:03,435 INFO jaxlogit.mixed_logit: Skipping H_inv and grad_n calculation due to skip_std_errs=True
2025-08-02 07:01:03,435 INFO jaxlogit._choice_model: Post fit processing
2025-08-02 07:01:03,438 INFO jaxlogit._choice_model: Optimization terminated successfully.


    Message: 
    Iterations: 147
    Function evaluations: 157
Estimation time= 446.4 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -1.0109002     1.0000000    -1.0109002         0.312    
cl                     -0.2295223     1.0000000    -0.2295223         0.818    
loc                     2.3477364     1.0000000     2.3477364        0.0189 *  
wk                      1.6751214     1.0000000     1.6751214         0.094 .  
tod                    -9.7162856     1.0000000    -9.7162856      4.32e-22 ***
seas                   -9.8762655     1.0000000    -9.8762655      9.19e-23 ***
sd.pf                  -1.3990437     1.0000000    -1.3990437         0.162    
sd.cl                  -0.6755186     1.0000000    -0.6755186         0.499    
sd.loc                  1.7072542    

None



peak memory: 9674.15 MiB, increment: 6727.03 MiB


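Peak memory usage came down from 16.3 GB to 9.7 GB, but at the cost of runtime, which went from 132 s to 446 s. The two fits agree closely, with a final log-likelihood of -3880.18 without batching and -3879.92 with batching.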
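
The trade-off can be quantified with a few lines of arithmetic. The snippet below only restates the numbers reported in the two runs above (peak memory, estimation time, function evaluations); the variable names are chosen for illustration.

```python
# Figures copied from the two runs above
peak_full_mib, peak_batched_mib = 16300.37, 9674.15
time_full_s, time_batched_s = 131.9, 446.4
evals_full, evals_batched = 135, 157

memory_saving = 1 - peak_batched_mib / peak_full_mib  # fraction of peak memory saved
slowdown = time_batched_s / time_full_s               # total runtime ratio
# The batched run also needed more function evaluations, so compare per evaluation too
time_per_eval_full = time_full_s / evals_full
time_per_eval_batched = time_batched_s / evals_batched

print(f"Peak memory reduced by {memory_saving:.0%}")
print(f"Total runtime increased by a factor of {slowdown:.1f}")
print(
    f"Time per function evaluation: {time_per_eval_full:.2f} s without batching, "
    f"{time_per_eval_batched:.2f} s with batching"
)
```

With two batches, peak memory drops by roughly 40% while the total runtime grows by a factor of about 3.4. Part of that increase comes from the batched run needing more function evaluations (157 versus 135).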