# How Does Using `lax.map` Instead of `jax.vmap` Affect Performance?

I'm curious whether swapping from `lax.map` to `jax.vmap` will make my code more readable and/or performant.

I currently use `lax.map` in:

- `predict_trials`, where I map `predict_and_simulate_recalls` over each trial in `trials`. This is a major loss function for fitting my model.
- `predict_transitions` calls `predict_trials`.
- `present_and_predict_transitions` maps over `present_lists` and `trials` to call `present_and_predict_trial`.
- `simulate_trials` maps over `rngs` to call `simulate_free_recall` for each trial.
- `present_and_simulate_trials` adds `present_lists` to the mapping and calls `present_and_simulate_free_recall`.
- `simulate_transitions` maps `simulate_free_recall_after_first` over `rngs` and `first_recalls`.
- `present_and_simulate_transitions` does the same as `simulate_transitions` but adds `present_lists` as a mapped argument.

I'll focus my research on `predict_trials`. If I don't find a performance effect or an improvement to readability, I won't bother with the other functions.

In [3]:
import jax

# jax.config.update("jax_disable_jit", True)
from jax import numpy as jnp, lax
from jaxcmr.helpers import load_data, generate_trial_mask
from jaxcmr.typing import MemorySearch, MemorySearchCreateFn
from jaxcmr.typing import Integer, Float, Array, Float_, Int_
from typing import Mapping
import json


In [4]:
# load and mask trials
data_name = "LohnasKahana2014"
data_path = f"data/{data_name}.h5"
data = load_data(data_path)
trial_query = 'data["list_type"] == 1'
trial_mask = generate_trial_mask(data, trial_query)  # type: ignore
trials = data["recalls"][trial_mask]
max_list_length = trials.shape[1]

In [5]:
with open("data/base_cmr_parameters.json") as f:
    fit_config = json.load(f)

base_params = fit_config["fixed"].copy()
base_params['choice_sensitivity'] += 0.001
bounds = fit_config["free"].copy()

FileNotFoundError: [Errno 2] No such file or directory: 'data/base_cmr_parameters.json'

In [6]:
def predict_trials(
    model_create_fn: MemorySearchCreateFn,
    list_length: int,
    trials: Integer[Array, " trials recall_events"],
    parameters: Mapping[str, Float_],
) -> Float[Array, " trials recall_events"]:
    """Return the simulation and outcome probabilities of multiple chains of retrieval events.

    Args:
        model_create_fn: constructor for a memory search model
        list_length: the length of the study and recall sequences.
        trials: the indices of the items to retrieve (1-indexed) or 0 to stop.
        parameters: the model parameters.
    """
    model = model_create_fn(list_length, parameters)
    model = lax.fori_loop(1, list_length + 1, lambda i, m: m.experience(i), model)
    model = model.start_retrieving()
    return lax.map(lambda trial: predict_and_simulate_recalls(model, trial)[1], trials)

_predict_trials = jax.jit(predict_trials, static_argnums=(0, 1))
_predict_trials(CMR.init, max_list_length, trials, base_params)


NameError: name 'CMR' is not defined

In [None]:
%timeit _predict_trials(CMR.init, max_list_length, trials, base_params).block_until_ready()

3.96 ms ± 93.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
def predict_trials(
    model_create_fn: MemorySearchCreateFn,
    list_length: int,
    trials: Integer[Array, " trials recall_events"],
    parameters: Mapping[str, Float_],
) -> Float[Array, " trials recall_events"]:
    """Return the simulation and outcome probabilities of multiple chains of retrieval events.

    Args:
        model_create_fn: constructor for a memory search model
        list_length: the length of the study and recall sequences.
        trials: the indices of the items to retrieve (1-indexed) or 0 to stop.
        parameters: the model parameters.
    """
    model = model_create_fn(list_length, parameters)
    model = lax.fori_loop(1, list_length + 1, lambda i, m: m.experience(i), model)
    model = model.start_retrieving()
    # return lax.map(lambda trial: predict_and_simulate_recalls(model, trial)[1], trials)
    return jax.vmap(lambda trial: predict_and_simulate_recalls(model, trial)[1])(trials)

_predict_trials = jax.jit(predict_trials, static_argnums=(0, 1))
_predict_trials(CMR.init, max_list_length, trials, base_params)

Array([[0.03062369, 0.20781308, 0.01182406, ..., 1.        , 1.        ,
        1.        ],
       [0.08700962, 0.21383993, 0.0047741 , ..., 1.        , 1.        ,
        1.        ],
       [0.62229365, 0.6549627 , 0.42818728, ..., 1.        , 1.        ,
        1.        ],
       ...,
       [0.62229365, 0.6549627 , 0.42818728, ..., 1.        , 1.        ,
        1.        ],
       [0.0830027 , 0.0185994 , 0.07063522, ..., 1.        , 1.        ,
        1.        ],
       [0.08700962, 0.48712012, 0.06968244, ..., 1.        , 1.        ,
        1.        ]], dtype=float32)

In [None]:
%timeit _predict_trials(CMR.init, max_list_length, trials, base_params).block_until_ready()

11.9 ms ± 72.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


Kind of mysterious. The vmap version is about three times slower than the map version (11.9 ms vs. 3.96 ms), and I'm not sure why yet.

I currently use vmap in my fitting code and in wrappers around my simulation functions, so I may be losing performance in those places too. I'll check that next.
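One hypothesis worth checking, which I have not verified for this model: if `predict_and_simulate_recalls` contains data-dependent control flow, vmap has to batch it. Under `jax.vmap`, a `lax.while_loop` with a batched predicate keeps iterating until every batch element is finished, and a `lax.cond` with a batched predicate is lowered to a `select` that evaluates both branches. `lax.map` instead runs each element on its own, one after another. The toy function below (`count_halvings` is a hypothetical stand-in, not part of jaxcmr) shows that the two give the same results even though the work they do differs.

```python
import jax
from jax import lax, numpy as jnp


def count_halvings(x):
    """Toy per-item function whose loop length depends on the input value."""

    def cond(state):
        value, _ = state
        return value > 1.0

    def body(state):
        value, steps = state
        return value / 2.0, steps + 1

    _, steps = lax.while_loop(cond, body, (x, jnp.int32(0)))
    return steps


xs = jnp.array([1.0, 3.0, 1000.0, 2.0])

# lax.map: each element runs its own loop for exactly as many steps as it needs.
sequential = jax.jit(lambda v: lax.map(count_halvings, v))(xs)

# jax.vmap: one batched loop that runs until the slowest element is done.
batched = jax.jit(jax.vmap(count_halvings))(xs)

print(sequential, batched)
```

Timing both variants on a function like this, with inputs whose loop lengths vary a lot, would be one way to see whether the batching of control flow accounts for the gap.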

## Integration Test: What about inside our fitting code?

In our fitting code, we currently dynamically parametrize and compress the output of `predict_trials` using `vmap` in the following way:

In [None]:
def predict_trials(
    model_create_fn: MemorySearchCreateFn,
    list_length: int,
    trials: Integer[Array, " trials recall_events"],
    parameters: Mapping[str, Float_],
) -> Float[Array, " trials recall_events"]:
    """Return the simulation and outcome probabilities of multiple chains of retrieval events.

    Args:
        model_create_fn: constructor for a memory search model
        list_length: the length of the study and recall sequences.
        trials: the indices of the items to retrieve (1-indexed) or 0 to stop.
        parameters: the model parameters.
    """
    model = model_create_fn(list_length, parameters)
    model = lax.fori_loop(1, list_length + 1, lambda i, m: m.experience(i), model)
    model = model.start_retrieving()
    return lax.map(lambda trial: predict_and_simulate_recalls(model, trial)[1], trials)

@jax.jit
def loss_fn(x):
    params = {
        key: x[key_index] for key_index, key in enumerate(bounds)
    }
    return log_likelihood(
        predict_trials(
            CMR.init,
            max_list_length,
            trials,
            {**base_params, **params},
        )
    )

mapped_loss_fn = jax.jit(jax.vmap(loss_fn, in_axes=(-1,)))

I repeat the jit compilation here because I sometimes call loss_fn directly in the notebook. This doesn't seem to affect performance even marginally.

In [None]:
x = jnp.array([base_params[key] for key in bounds])
x = jnp.repeat(x[None, :], 10, axis=0).T
mapped_loss_fn(x).block_until_ready()

Array([32965.54, 32965.54, 32965.54, 32965.54, 32965.54, 32965.54,
       32965.54, 32965.54, 32965.54, 32965.54], dtype=float32)

In [None]:
%timeit mapped_loss_fn(x).block_until_ready()

73 ms ± 509 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


Now let's target a version that avoids vmap...

In [None]:
# @jax.jit
def loss_fn(x):
    params = {
        key: x[key_index] for key_index, key in enumerate(bounds)
    }
    return log_likelihood(
        predict_trials(
            CMR.init,
            max_list_length,
            trials,
            {**base_params, **params},
        )
    )

@jax.jit
def mapped_loss_fn(x):
    # use lax.map instead of vmap
    return lax.map(loss_fn, x)

x = jnp.array([base_params[key] for key in bounds])
x = jnp.repeat(x[None, :], 10, axis=0)
mapped_loss_fn(x).block_until_ready()

Array([32965.54, 32965.54, 32965.54, 32965.54, 32965.54, 32965.54,
       32965.54, 32965.54, 32965.54, 32965.54], dtype=float32)

In [None]:
%timeit mapped_loss_fn(x).block_until_ready()

37.8 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


As predicted, the `lax.map` version is faster than the vmap version (37.8 ms vs. 73 ms). Why didn't I catch this sooner?
But I need it to work with the transposed x (batch on the last axis). For now, I'll just transpose x inside the jitted wrapper before calling `lax.map`.

In [None]:
@jax.jit
def mapped_loss_fn(x):
    # use lax.map instead of vmap
    return lax.map(loss_fn, x.T)

x = jnp.array([base_params[key] for key in bounds])
x = jnp.repeat(x[None, :], 10, axis=0).T
mapped_loss_fn(x).block_until_ready()

Array([32965.54, 32965.54, 32965.54, 32965.54, 32965.54, 32965.54,
       32965.54, 32965.54, 32965.54, 32965.54], dtype=float32)

In [None]:
%timeit mapped_loss_fn(x).block_until_ready()

40.8 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


This adds about 3 ms to the runtime (40.8 ms vs. 37.8 ms). I should try to avoid this.
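The transpose is needed because `lax.map` always iterates over the leading axis and has no `in_axes` argument, whereas `jax.vmap` can be pointed at the last axis directly. A minimal sketch of the equivalence, using a hypothetical `toy_loss` in place of the real `loss_fn`:

```python
import jax
from jax import lax, numpy as jnp


def toy_loss(params):
    # Hypothetical stand-in for loss_fn: takes one parameter vector.
    return jnp.sum(params ** 2)


# 3 parameters (rows) by 4 candidate parameter vectors (columns),
# laid out like the transposed x in the cells above.
x = jnp.arange(12.0).reshape(3, 4)

# vmap can map over the last axis directly.
via_vmap = jax.vmap(toy_loss, in_axes=-1)(x)

# lax.map only maps over axis 0, so the batch axis has to be moved to the front.
via_map_transpose = lax.map(toy_loss, x.T)
via_map_moveaxis = lax.map(toy_loss, jnp.moveaxis(x, -1, 0))

print(via_vmap, via_map_transpose, via_map_moveaxis)
```

For a 2-D array `x.T` and `jnp.moveaxis(x, -1, 0)` are the same operation; `moveaxis` is just the form that generalizes to more dimensions.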

## Optimize by Extracting Parametrization

Can I optimize further?
One possibility is to build the parameter dictionary outside the mapped function.

In [None]:
bounds

{'encoding_drift_rate': [2.220446049250313e-16, 0.9999999999999998],
 'delay_drift_rate': [2.220446049250313e-16, 0.9999999999999998],
 'start_drift_rate': [2.220446049250313e-16, 0.9999999999999998],
 'recall_drift_rate': [2.220446049250313e-16, 0.9999999999999998],
 'shared_support': [2.220446049250313e-16, 100.0],
 'item_support': [2.220446049250313e-16, 100.0],
 'learning_rate': [2.220446049250313e-16, 0.9999999999999998],
 'primacy_scale': [2.220446049250313e-16, 100.0],
 'primacy_decay': [2.220446049250313e-16, 100.0],
 'stop_probability_scale': [2.220446049250313e-16, 0.9999999999999998],
 'stop_probability_growth': [2.220446049250313e-16, 10.0]}

In [None]:
def loss_fn(params):
    return log_likelihood(
        predict_trials(
            CMR.init,
            max_list_length,
            trials,
            {**base_params, **params},
        )
    )

@jax.jit
def mapped_loss_fn(x):
    params = {
        key: x[key_index] for key_index, key in enumerate(bounds)
    }
    return lax.map(loss_fn, params)

x = jnp.array([base_params[key] for key in bounds])
x = jnp.repeat(x[None, :], 10, axis=0).T

mapped_loss_fn(x).block_until_ready()

Array([32965.54, 32965.54, 32965.54, 32965.54, 32965.54, 32965.54,
       32965.54, 32965.54, 32965.54, 32965.54], dtype=float32)

In [None]:
x.shape

(11, 10)

In [None]:
%timeit mapped_loss_fn(x).block_until_ready()

39.5 ms ± 608 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


This is slightly faster than the previous version that transposes x (39.5 ms vs. 40.8 ms) while still accepting the transposed x. Can we go further? I could try to avoid merging dictionaries inside the mapped function.

I would need each value in `base_params` to be an array repeated across the batch dimension. I can do this during initialization.

In [None]:
batch_base_params = {key: jnp.array([base_params[key]] * 10) for key in base_params}
batch_base_params

{'encoding_drift_rate': Array([0.80163276, 0.80163276, 0.80163276, 0.80163276, 0.80163276,
        0.80163276, 0.80163276, 0.80163276, 0.80163276, 0.80163276],      dtype=float32),
 'delay_drift_rate': Array([0.99664116, 0.99664116, 0.99664116, 0.99664116, 0.99664116,
        0.99664116, 0.99664116, 0.99664116, 0.99664116, 0.99664116],      dtype=float32),
 'start_drift_rate': Array([0.05112313, 0.05112313, 0.05112313, 0.05112313, 0.05112313,
        0.05112313, 0.05112313, 0.05112313, 0.05112313, 0.05112313],      dtype=float32),
 'recall_drift_rate': Array([0.8666706, 0.8666706, 0.8666706, 0.8666706, 0.8666706, 0.8666706,
        0.8666706, 0.8666706, 0.8666706, 0.8666706], dtype=float32),
 'shared_support': Array([0.01612209, 0.01612209, 0.01612209, 0.01612209, 0.01612209,
        0.01612209, 0.01612209, 0.01612209, 0.01612209, 0.01612209],      dtype=float32),
 'item_support': Array([0.8877853, 0.8877853, 0.8877853, 0.8877853, 0.8877853, 0.8877853,
        0.8877853, 0.8877853, 0.8

In [None]:
def loss_fn(params):
    return log_likelihood(
        predict_trials(
            CMR.init,
            max_list_length,
            trials,
            params,
        )
    )



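# Look up each free parameter's row in x by its position in `bounds`,
# not by its position in `batch_base_params`.
free_index = {key: i for i, key in enumerate(bounds)}

@jax.jit
def mapped_loss_fn(x):
    params = {
        key: x[free_index[key]] if key in free_index else batch_base_params[key]
        for key in batch_base_params
    }
    return lax.map(loss_fn, params)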

x = jnp.array([base_params[key] for key in bounds])
x = jnp.repeat(x[None, :], 10, axis=0).T

mapped_loss_fn(x).block_until_ready()

Array([32965.54, 32965.54, 32965.54, 32965.54, 32965.54, 32965.54,
       32965.54, 32965.54, 32965.54, 32965.54], dtype=float32)

In [None]:
%timeit mapped_loss_fn(x).block_until_ready()

39.4 ms ± 654 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


This doesn't seem to help. Maybe because the `batch_base_params` object is bigger? Anyway.
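For reference, `lax.map` accepts a pytree such as a dict and slices every leaf along its leading axis, which is why each entry of `batch_base_params` has to be repeated across the batch. Values that the mapped function merely closes over are not sliced and need no repetition. A small sketch with made-up names (`fixed`, `toy_loss`, `free_batch` are hypothetical, not from my fitting code):

```python
import jax
from jax import lax, numpy as jnp

# Hypothetical fixed parameter, closed over instead of batched.
fixed = {"scale": 2.0}


def toy_loss(free):
    # `free` arrives as a dict of scalars: one slice of each leaf per step.
    merged = {**fixed, **free}
    return merged["scale"] * (merged["a"] + merged["b"])


# Every leaf shares the same leading (batch) axis of length 3.
free_batch = {
    "a": jnp.array([1.0, 2.0, 3.0]),
    "b": jnp.array([10.0, 20.0, 30.0]),
}

losses = jax.jit(lambda p: lax.map(toy_loss, p))(free_batch)
print(losses)
```

This is the same pattern as the earlier `{**base_params, **params}` version, so it is not surprising that pre-batching the fixed parameters made little difference.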

## Last thing I want to check is whether log-likelihood really needs to apply log before sum.

In [None]:
likelihoods = _predict_trials(CMR.init, max_list_length, trials, base_params)

likelihoods

Array([[0.03062369, 0.20781308, 0.01182406, ..., 1.        , 1.        ,
        1.        ],
       [0.08700962, 0.21383993, 0.0047741 , ..., 1.        , 1.        ,
        1.        ],
       [0.62229365, 0.6549627 , 0.42818728, ..., 1.        , 1.        ,
        1.        ],
       ...,
       [0.62229365, 0.6549627 , 0.42818728, ..., 1.        , 1.        ,
        1.        ],
       [0.0830027 , 0.0185994 , 0.07063522, ..., 1.        , 1.        ,
        1.        ],
       [0.08700962, 0.48712012, 0.06968244, ..., 1.        , 1.        ,
        1.        ]], dtype=float32)

In [None]:
jnp.sum(jnp.log(likelihoods))

Array(-32965.54, dtype=float32)

In [None]:
jnp.log(jnp.sum(likelihoods))

Array(9.315411, dtype=float32)

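Okay, it does. The sum of logs (-32965.54) is the log of the product of the likelihoods, which is what I want; the log of the sum (9.315411) is a different quantity.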
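A small self-contained check of the same point, with made-up probabilities rather than model output. It also shows why the log has to be applied before reducing: taking the product first underflows in float32 once there are many terms.

```python
from jax import numpy as jnp

# Made-up probabilities, not model output.
p = jnp.array([0.5, 0.25, 0.5])

sum_of_logs = jnp.sum(jnp.log(p))      # log-likelihood of independent events
log_of_product = jnp.log(jnp.prod(p))  # same quantity, computed the fragile way
log_of_sum = jnp.log(jnp.sum(p))       # a different quantity altogether

# With many terms the product underflows to zero, so its log is -inf,
# while the sum of logs stays finite.
many = jnp.full(1000, 0.1)
underflowed = jnp.log(jnp.prod(many))
stable = jnp.sum(jnp.log(many))

print(sum_of_logs, log_of_product, log_of_sum)
print(underflowed, stable)
```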

## Final Results
Let's confirm I can speed up fits by using map instead of vmap by actually updating my fitting code and re-running the fitting process.

Current fits return an average fitness of 590.41 with a runtime of 7:27 under .001 tolerance.

New fits return an average fitness of 590.41 with a runtime of 4:50 under .001 tolerance.

So I saved 2:37 (about 35% of the runtime) by switching to map, with the same fitness. Not bad.