In [None]:
# | default_exp simple_crp
%load_ext autoreload
%autoreload 2

# Lag-CRP Analysis


The **Lag-Conditional Response Probability (Lag-CRP)** analysis is a technique commonly used in memory research, particularly in free recall tasks, to measure the lag-contiguity effect: the tendency to successively recall items that were studied close together in time rather than items studied far apart.

A key concept in Lag-CRP analysis is the "serial lag," which refers to the positional distance between two items during the **study phase**. The Lag-CRP analysis considers serial lags between successively recalled items during the **recall phase**. At each recall position beyond the first, serial lag = (study position of the currently recalled item) − (study position of the previously recalled item).

For example, recalling the item studied in position 5 immediately after recalling the item studied in position 4 is a lag of +1. A negative lag indicates a transition to an earlier-studied item; a positive lag indicates a transition to a later-studied item.
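The lag definition above can be sketched in a few lines of plain Python. The helper name `serial_lags` is hypothetical and not part of this module; it assumes recalls are given as 1-indexed study positions.

```python
def serial_lags(recalled_positions):
    """Return the serial lag of each transition between successive recalls."""
    return [
        current - previous
        for previous, current in zip(recalled_positions, recalled_positions[1:])
    ]


# Study positions (1-indexed), listed in the order they were recalled.
recalls = [4, 5, 2, 3]
print(serial_lags(recalls))  # [1, -3, 1]
```

A sequence of n recalls yields n − 1 lags, since the first recall has no predecessor.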

The Lag-CRP measures the probability of making a transition of a particular lag **given the current recall position**, conditional on the items that remain available for recall. A transition is available only if its target item has not yet been recalled. By this definition, the Lag-CRP is calculated by tabulating for each recall transition:

- **Actual transitions**: Number of times participants transitioned from recalling the item studied at position X to recalling the item studied at position Y, counted by lag (Y - X).
- **Available transitions**: Number of times each possible transition lag (Y - X) could have occurred, given the items that had not yet been recalled.
- **Lag-CRP**: Actual transitions divided by available transitions for each lag.

Thus the formula for calculating Lag-CRP across a set of trials is:

$$\text{CRP(Lag)} = \frac{\text{Actual Transitions at Lag}}{\text{Available Transitions at Lag}}$$
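As a worked illustration of the formula, the sketch below tallies actual and available transitions for one toy trial. The function `tally_trial` is a hypothetical plain-Python helper, not the implementation used later in this notebook, and it assumes 1-indexed study positions with no repeated recalls.

```python
from collections import Counter


def tally_trial(recalled_positions, list_length):
    """Count actual and available transitions by lag for a single trial."""
    actual, available = Counter(), Counter()
    remaining = set(range(1, list_length + 1))
    previous = recalled_positions[0]
    remaining.discard(previous)
    for current in recalled_positions[1:]:
        # Every not-yet-recalled item defines one available lag.
        for candidate in remaining:
            available[candidate - previous] += 1
        actual[current - previous] += 1
        remaining.discard(current)
        previous = current
    return actual, available


# Four studied items; items 2, 3, 1 are recalled in that order.
actual, available = tally_trial([2, 3, 1], list_length=4)
crp = {lag: actual[lag] / available[lag] for lag in sorted(available)}
print(crp)  # {-2: 1.0, -1: 0.0, 1: 0.5, 2: 0.0}
```

Lag +1 was available at both transitions but taken only once, so its CRP is 0.5. Lags that were never available (here −3, 0, and +3) have no defined CRP.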

High CRP at small lags (especially ±1, ±2) indicates strong temporal contiguity effects: participants recall items studied near each other.

Lag-CRP curves usually peak at lag = +1 or −1 and decrease gradually as the absolute lag increases, indicating a strong association between temporally adjacent items.

In [None]:
#| exports
from jax import lax, jit
from jax import numpy as jnp
from simple_pytree import Pytree

from jaxcmr.typing import Array, Float, Int_, Integer

## Simple Case: When Study Lists are Uniform
To introduce how to calculate Lag-CRP, we will first consider a simple case where the study lists are uniform (no repeated items). In this case, we initialize for each trial a `Tabulation` object that updates at each recall attempt to track the following:

- The last item recalled
- Which items remain available for recall
- The number of actual transitions made for each lag at each recall attempt
- The number of available transitions for each lag at each recall attempt

In [None]:
#| exports

class Tabulation(Pytree):
    "A tabulation of transitions between items during recall of a study list."

    def __init__(self, list_length: int, first_recall: Int_):
        self.lag_range = list_length - 1
        self.list_length = list_length
        self.all_items = jnp.arange(1, list_length + 1, dtype=int)
        self.actual_transitions = jnp.zeros(self.lag_range * 2 + 1, dtype=int)
        self.avail_transitions = jnp.zeros(self.lag_range * 2 + 1, dtype=int)
        self.avail_items = jnp.ones(list_length, dtype=bool)
        self.avail_items = self.avail_items.at[first_recall - 1].set(False)
        self.previous_item = first_recall

    def _update(self, current_item: Int_) -> "Tabulation":
        "Tabulate actual and possible serial lags of current from previous item."
        actual_lag = current_item - self.previous_item + self.lag_range
        all_lags = self.all_items - self.previous_item + self.lag_range

        return self.replace(
            previous_item=current_item,
            avail_items=self.avail_items.at[current_item - 1].set(False),
            avail_transitions=self.avail_transitions.at[all_lags].add(self.avail_items),
            actual_transitions=self.actual_transitions.at[actual_lag].add(1),
        )

    def update(self, choice: Int_) -> "Tabulation":
        "Tabulate a transition if the choice is non-zero (i.e., a valid item)."
        return lax.cond(choice > 0, lambda: self._update(choice), lambda: self)
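The step that may be least obvious in `_update` is how available transitions are accumulated: each item's lag from the previous item is shifted by `lag_range` to give a non-negative array index, and the boolean availability mask is added at those indices. Below is a plain-Python sketch of that one step; `count_available` is a hypothetical helper that uses lists instead of JAX arrays.

```python
def count_available(previous_item, avail_items):
    """Add the availability mask into lag-indexed slots for one transition."""
    list_length = len(avail_items)
    lag_range = list_length - 1
    counts = [0] * (2 * lag_range + 1)
    for item, is_available in enumerate(avail_items, start=1):
        # Shift by lag_range so the most negative lag lands at index 0.
        index = item - previous_item + lag_range
        counts[index] += int(is_available)
    return counts


# Item 2 was just recalled, so only items 1, 3, and 4 remain available.
print(count_available(2, [True, False, True, True]))
# [0, 0, 1, 0, 1, 1, 0]  -> lags -1, +1, and +2 are available
```

Index `lag_range` (the center slot) corresponds to lag 0, which never receives a count because the item just recalled is always marked unavailable.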

To calculate the Lag-CRP, we tabulate available and actual transitions for each lag at each recall attempt in applicable trials:

In [None]:
#| exports

def tabulate_trial(
    trial: Integer[Array, " recall_events"], list_length: int
) -> Tabulation:
    "Tabulate transitions across a single trial."
    return lax.scan(
        lambda tabulation, recall: (tabulation.update(recall), None),
        Tabulation(list_length, trial[0]),
        trial[1:],
    )[0]

Finally, we aggregate the counts of actual and available transitions across all trials, and divide the actual transitions by the available transitions to get the Lag-CRP for each lag.

In [None]:
#| exports

def simple_crp(
    trials: Integer[Array, "trials recall_events"], list_length: int
) -> Float[Array, " lags"]:
    "Compute the Lag-CRP from transition counts pooled across trials."
    tabulated_trials = lax.map(lambda trial: tabulate_trial(trial, list_length), trials)
    total_actual_transitions = jnp.sum(tabulated_trials.actual_transitions, axis=0)
    total_possible_transitions = jnp.sum(tabulated_trials.avail_transitions, axis=0)
    return total_actual_transitions / total_possible_transitions
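For checking results on small inputs, a plain-Python counterpart to the pooled computation can be sketched as follows. `reference_crp` is a hypothetical helper, not part of `jaxcmr`. It assumes zeros are trailing padding and that no item is recalled twice within a trial.

```python
import math


def reference_crp(trials, list_length):
    """Pool transition counts over trials, then divide once per lag."""
    lag_range = list_length - 1
    actual = [0] * (2 * lag_range + 1)
    available = [0] * (2 * lag_range + 1)
    for trial in trials:
        recalls = [item for item in trial if item > 0]  # drop zero padding
        if not recalls:
            continue
        previous = recalls[0]
        remaining = set(range(1, list_length + 1)) - {previous}
        for current in recalls[1:]:
            for candidate in remaining:
                available[candidate - previous + lag_range] += 1
            actual[current - previous + lag_range] += 1
            remaining.discard(current)
            previous = current
    # Lags that were never available have no defined probability.
    return [a / n if n else math.nan for a, n in zip(actual, available)]


trials = [[1, 2, 3], [2, 1, 0]]
crp = reference_crp(trials, list_length=3)
print(crp)  # [nan, 1.0, nan, 0.6666666666666666, 0.0]
```

The entries correspond to lags −2 through +2. Counts are summed across trials before dividing, so trials with more transitions contribute more to the estimate, and the center entry (lag 0) is always `nan` because the item just recalled is never available.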


This approach forms the basis for the more complex transition analyses used in practice. Examples include analyses of non-uniform study lists (e.g., with repeated items), analyses restricted to particular transitions (e.g., only those from the first recall position), and analyses that bin transition counts by a distance measure other than serial lag.

## Examples

In [None]:
from jaxcmr.helpers import generate_trial_mask, load_data, find_project_root
import os

Uniform study lists case:

In [None]:
# parameters
run_tag = "CRP"
data_name = "LohnasKahana2014"
data_query = "data['list_type'] == 1"
data_path = os.path.join(find_project_root(), "data/LohnasKahana2014.h5")

# set up data structures
data = load_data(data_path)
recalls = data["recalls"]
presentations = data["pres_itemnos"]
list_length = data["listLength"][0].item()
trial_mask = generate_trial_mask(data, data_query)

# compute the Lag-CRP for the selected trials
# plot_crp(data, generate_trial_mask(data, data_query))
jit(simple_crp, static_argnames=("list_length",))(
    recalls[trial_mask], list_length
)

Array([0.10416666, 0.04166667, 0.02147239, 0.0228833 , 0.01654412,
       0.01880878, 0.01775956, 0.02140309, 0.02038627, 0.01629914,
       0.02      , 0.01670644, 0.01305684, 0.01468532, 0.01878238,
       0.017042  , 0.01747234, 0.01930502, 0.01536831, 0.01510574,
       0.01016949, 0.01390176, 0.01393885, 0.01690507, 0.01760131,
       0.02655569, 0.01942117, 0.0177712 , 0.01956599, 0.02229081,
       0.02253148, 0.02098128, 0.02182163, 0.02749459, 0.02875974,
       0.03524229, 0.05237688, 0.06995769, 0.1551922 ,        nan,
       0.25058603, 0.09483766, 0.06154943, 0.04214095, 0.02825979,
       0.02849873, 0.02509804, 0.02086677, 0.02140078, 0.02180149,
       0.01564345, 0.01611922, 0.01162425, 0.01868557, 0.01537433,
       0.01518288, 0.01601424, 0.01365818, 0.01031716, 0.01069731,
       0.01161344, 0.01133391, 0.01545455, 0.01106833, 0.00858152,
       0.01419558, 0.01462317, 0.013261  , 0.014295  , 0.01597222,
       0.00988593, 0.01417848, 0.01458523, 0.014     , 0.00467