In [1]:
# | default_exp rnc
%load_ext autoreload
%autoreload 2

# Repetition Neighbor Contiguity Analysis
> Rate of transition between items that follow presentation of the same item

For spaced items (lag >= 4) in mixed lists with study positions $i$ and $j$, we determined the proportion of times, given that a subject made a transition between an item from an item in $S_i = {i + 1, i + 2}$ or $S_j = {j + 1, j + 2}$, that they would then transition to an item in $S_j$ or $S_i$, respectively. 

To control for the proportion of transitions expected in the absence of repeated items, we assigned the same serial positions used in the mixed lists to corresponding items in the control lists. For each subject, we assigned these positions to 100 random shuffles of the control lists, and took the mean across lists to get a baseline expectation of remote transitions.

In [2]:
# | exports
import jax
import jax.numpy as jnp
from jax import lax

## Initial Approach

<!-- In our Repetition Lag-CRP analysis, we included a `should_tabulate` function that focused consideration of transitions from items with at least two spaced-out study positions. In this analysis, we instead identify transitions from +1 and +2 neighbors of any study positions of repeatedly studied items, $S_i = {i + 1, i + 2}$ or $S_j = {j + 1, j + 2}$ where $i$ and $j$ identify the study positions of the repeated items.

In our Repetition Lag-CRP analysis, in our `tabulate` function, we tracked actual versus possible transitions to different serial lags from the last recalled item. In this analysis, we instead focus on tracking actual and possible transitions from items in the $S_i$ or $S_j$ sets to their corresponding $S_j$ or $S_i$ sets, respectively.

Under this analysis, the key challenge of the analysis is consistently identifying applicable $S_i$ and $S_j$ sets across recall transitions in a tabulation.
To accomplish this, at initialization of the tabulation, we track for each studied item the study position(s) of the item studied immediately before it, as well as of the item studied two study positions before it. 

When either the last or second-to-last studied item has multiple study positions, the considered item is part of an $S_i$ or $S_j$ set.

No, I think we can be more efficient than this.
When we get to a recall event, that can be an item with multiple study positions.
When we consider the next recall event, it can also be an item with multiple study positions.
For all applicable study positions of both transitioned-from and transitioned-to items, we should identify the items studied one or two study positions before them.
If these sets intersect and the two items are not themselves neighbors, then that's a transition between $S_i$ and $S_j$. -->

In [3]:
# | exports

@jax.jit
def find_repeated_positions(presentation: jnp.ndarray) -> jnp.ndarray:
    """
    JAX-jittable function to find all (i, j) study positions for items that
    appear *exactly* twice in `presentation`.

    presentation: shape (L,), each element is a 1-based item ID
      - We assume item IDs do not exceed L (the length of the presentation).

    Returns:
      repeated_positions: shape (L//2, 2). Each row is (i, j) for a repeated item;
        the first rows are the valid (i,j), and unused rows remain (0,0).
    """

    L = presentation.shape[0]

    # We'll allocate for all item IDs up to L (i.e. item in [1..L]).
    # item_counts[item] tracks how many times 'item' appears (saturated at 3).
    # positions_buffer[item] = [pos1, pos2] in 1-based indexing, for up to 2 occurrences.
    item_counts = jnp.zeros((L + 1,), dtype=jnp.int32)  # index 0 unused
    positions_buffer = jnp.zeros((L + 1, 2), dtype=jnp.int32)

    # Pass 1: accumulate counts and positions of each item
    def first_pass_body(idx, carry):
        (counts, pos_buf) = carry
        item = presentation[idx]  # 1-based item ID
        old_count = counts[item]
        new_count = jnp.minimum(old_count + 1, 3)  # saturate at 3
        counts = counts.at[item].set(new_count)

        # Store this study position (idx+1 => 1-based) if it's the 1st or 2nd occurrence
        def store_pos(pb):
            return pb.at[item, new_count - 1].set(idx + 1)

        def do_nothing(pb):
            return pb

        pos_buf = lax.cond(new_count <= 2, store_pos, do_nothing, pos_buf)

        return (counts, pos_buf)

    (item_counts, positions_buffer) = lax.fori_loop(
        0, L, first_pass_body, (item_counts, positions_buffer)
    )

    # Pass 2: gather all items with exactly 2 occurrences into an array (L//2, 2).
    repeated_positions_init = jnp.zeros((L // 2, 2), dtype=jnp.int32)

    def second_pass_body(item_id, carry):
        (rp_count, rp_array) = carry
        c = item_counts[item_id]  # how many times item_id appeared (sat @ 3)
        is_repeated = c == 2
        pos1 = positions_buffer[item_id, 0]
        pos2 = positions_buffer[item_id, 1]
        i_min = jnp.minimum(pos1, pos2)
        j_max = jnp.maximum(pos1, pos2)

        # If repeated, store (i_min, j_max) at row rp_count
        rp_array = rp_array.at[rp_count, 0].set(jnp.where(is_repeated, i_min, 0))
        rp_array = rp_array.at[rp_count, 1].set(jnp.where(is_repeated, j_max, 0))

        rp_count = rp_count + jnp.where(is_repeated, 1, 0)
        return (rp_count, rp_array)

    (final_count, repeated_positions) = lax.fori_loop(
        1,
        L + 1,  # item_id in [1..L], ignoring index 0
        second_pass_body,
        (0, repeated_positions_init),
    )

    # repeated_positions is shape (L//2, 2). The first `final_count` rows are valid.
    return repeated_positions


def make_s_set(pos, list_len):
    """
    For a 1-based study position pos, define S_pos = {pos+1, pos+2},
    clipping so that we never exceed list_len.

    Returns a fixed-shape int32[2] array in 1-based indexing.
    """
    s = jnp.array([pos + 1, pos + 2], dtype=jnp.int32)
    # clip so we don't go beyond list_len
    s = jnp.clip(s, 1, list_len)
    return s


@jax.jit
def compute_cross_transition_proportion(
    repeated_positions,  # shape (R, 2), each row = (i, j) in 1-based indexing
    recall,  # shape (L,), each entry is the 1-based position of the item recalled
    list_len,  # an int (length of the study list, e.g. 40)
):
    """
    Given:
      - repeated_positions: shape (R,2), each row is (i,j) with i<j
      - recall: a length L vector of *study positions* (1-based).
                The k-th recalled item was originally studied in position recall[k].
                If there's a "null" recall, that might be 0, etc.
      - list_len: the number of study positions in the list (e.g. 40)

    We want to compute:

      numerator = total number of transitions from S_i->S_j or S_j->S_i
      denominator = total number of transitions that start in S_i or S_j
                    (i.e. from an item in S_i or S_j to anything)
      proportion = numerator / denominator   (0 if denominator=0)

    Then sum or average over all repeated pairs?

    In many analyses, you'd either:
       (A) sum across all repeated pairs and produce a single proportion, or
       (B) compute the proportion per pair (i, j) and average across pairs.

    Below, we do (A) for demonstration: we get the global proportion across *all* repeated pairs.
    Modify as needed to do (B).
    """

    # We'll accumulate counts in arrays of shape (R,).
    # Then at the end we sum across R and do a ratio.
    def body_fun(pair_index, carry):
        """
        Loop body over each repeated pair (i, j).
        carry = (numerator_sum, denominator_sum)
        """
        (numerator_sum, denominator_sum) = carry

        i = repeated_positions[pair_index, 0]
        j = repeated_positions[pair_index, 1]

        # Build S_i and S_j
        s_i = make_s_set(i, list_len)  # shape (2,)
        s_j = make_s_set(j, list_len)  # shape (2,)

        # We'll do a second pass: go over each adjacent pair in the recall sequence
        # to find how many transitions come from S_i or S_j, and how many are cross.

        # We accumulate local counts so that we can add them into the carry afterward.
        def transition_loop(idx, local_carry):
            """
            We'll examine the transition recall[idx] -> recall[idx+1].
            local_carry = (local_numer, local_denom).
            """
            (local_numer, local_denom) = local_carry

            val_from = recall[idx]
            val_to = recall[idx + 1]

            # Check if 'from' is in S_i or S_j
            from_in_i = jnp.any(val_from == s_i)
            from_in_j = jnp.any(val_from == s_j)
            from_in_either = from_in_i | from_in_j

            # If from_in_either, we will increment the denominator
            increment_denom = jnp.where(from_in_either, 1, 0)

            # For the numerator, we need from_in_i & to_in_j or from_in_j & to_in_i
            to_in_i = jnp.any(val_to == s_i)
            to_in_j = jnp.any(val_to == s_j)

            cross_transition = (from_in_i & to_in_j) | (from_in_j & to_in_i)
            increment_numer = jnp.where(cross_transition, 1, 0)

            return (local_numer + increment_numer, local_denom + increment_denom)

        # We'll run a lax.fori_loop over all transitions from 0 to L-2
        # (since we look at recall[idx], recall[idx+1])
        init_local = (0, 0)  # local_numer, local_denom
        local_numer, local_denom = lax.fori_loop(
            0, recall.shape[0] - 1, transition_loop, init_local
        )

        # Add the local counts for this pair into the global sums
        return (numerator_sum + local_numer, denominator_sum + local_denom)

    init_val = (0, 0)
    numerator_total, denominator_total = lax.fori_loop(
        0, repeated_positions.shape[0], body_fun, init_val
    )

    # Combine into a proportion (guarding against division by zero)
    return jnp.where(denominator_total > 0, numerator_total / denominator_total, 0.0)


@jax.jit
def main(presentation, recalls):
    rep_positions_array = find_repeated_positions(presentation)
    return compute_cross_transition_proportion(
        repeated_positions=rep_positions_array,
        recall=recalls,
        list_len=presentation.shape[0],
    )

In [4]:
import numpy as np

# Example presentation (length=40), 1-based item IDs:
presentation = jnp.array([
    1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 
    11, 12, 12, 13, 14, 15, 16, 17, 10, 18,
    19, 20, 19, 21, 22, 23, 20, 24, 25, 26,
    22, 27, 28, 24, 29, 30, 31, 32, 33, 34
], dtype=int)

# Example recall (length=40), each entry is the 1-based position of the item recalled
# or 0 for "no recall" in that slot:
recall = np.array([
    1,  2,  3,  4,  5,  6,  7,  9, 10, 11,
    17, 14, 12, 15, 25, 20, 28, 30, 39, 38,
    37, 18,  0,  0,  0,  0,  0,  0,  0,  0,
     0,  0,  0,  0,  0,  0,  0,  0,  0,  0
], dtype=int)


# 1) Find repeated positions (normal Python)
rep_positions_array = find_repeated_positions(presentation)
rep_positions_array = jnp.array(rep_positions_array)
print("Repeated positions (i,j):", rep_positions_array, len(rep_positions_array), len(presentation))

# 2) JAX-ify the inputs
recall_jax = jnp.array(recall)

# 3) Call the jitted function
proportion = compute_cross_transition_proportion(
    repeated_positions=rep_positions_array,
    recall=recall_jax,
    list_len=40
)

print("Cross-transition proportion = ", main(presentation, recall_jax))

Repeated positions (i,j): [[10 19]
 [12 13]
 [21 23]
 [22 27]
 [25 31]
 [28 34]
 [ 0  0]
 [ 0  0]
 [ 0  0]
 [ 0  0]
 [ 0  0]
 [ 0  0]
 [ 0  0]
 [ 0  0]
 [ 0  0]
 [ 0  0]
 [ 0  0]
 [ 0  0]
 [ 0  0]
 [ 0  0]] 20 40
Cross-transition proportion =  0.3888889


In [5]:
import os

from jaxcmr.helpers import find_project_root, generate_trial_mask, load_data
from jax import vmap

In [6]:
# parameters
run_tag = "CRP"
data_name = "LohnasKahana2014"
data_query = "data['list_type'] >= 3"
data_path = os.path.join(find_project_root(), "data/LohnasKahana2014.h5")

# set up data structures
data = load_data(data_path)
recalls = data["recalls"]
presentations = data["pres_itemnos"]
list_length = data["listLength"][0].item()
trial_mask = generate_trial_mask(data, data_query)

vmap(main, in_axes=(0, 0), out_axes=0)(
    presentations[trial_mask], recalls[trial_mask]
).mean()

Array(0.11464079, dtype=float32)