In [None]:
#| default_exp repetition

# Helper Functions for Dealing with Item Repetitions in Study Lists

In [None]:
# | exports
import math
from functools import partial

import jax
import numpy as np
from jax import lax, random
from jax import numpy as jnp

from jaxcmr.helpers import generate_trial_mask
from jaxcmr.typing import Array, Int_, Integer, RecallDataset


## Identifying All Study Positions of an Item

We address two tasks: 
- having an item-unique index and a presentation vector and identifying applicable study positions for the item
- having a study position and a presentation vector and identifying all study positions of the item presented at that position

We decompose the latter problem into two steps: identifying the item-unique index at that study position, and then solving the first problem for that item-unique index.

Provided functions are jit and vmap-compatible, making them efficient within JAX's functional programming paradigm.

In [None]:
# | exports


def item_to_study_positions(
    item: Int_,
    presentation: Integer[Array, " list_length"],
    size: int,
):
    """Look up where an item was studied within a 1D presentation sequence.

    The result always has length `size`: matching positions are reported
    one-indexed, and unused slots are filled with 0. Item index 0 is treated
    as "no item" and yields an all-zero result.

    Args:
        item: the item index to look up.
        presentation: the 1D presentation sequence.
        size: length of the returned vector, i.e. the max number of study
            positions that could be returned.
    """

    def _no_positions():
        return jnp.zeros(size, dtype=int)

    def _matching_positions():
        # Unused slots are filled with -1 so that the shift to one-indexing
        # turns them into 0.
        (zero_based,) = jnp.nonzero(presentation == item, size=size, fill_value=-1)
        return zero_based + 1

    return lax.cond(item == 0, _no_positions, _matching_positions)


def all_study_positions(
    study_position: Int_,
    presentation: Integer[Array, " list_length"],
    size: int,
):
    """Return every one-indexed study position of the item shown at a study position.

    The item occupying `study_position` is looked up first, and then all of
    that item's study positions are returned. A `study_position` that is not
    positive is mapped to item index 0, which produces an all-zero result.

    Args:
        study_position: the one-indexed study position to look up.
        presentation: the 1D presentation sequence.
        size: max non-zero entries that could be returned.
    """

    def _item_at_position():
        # Convert the one-indexed position into a zero-indexed array lookup.
        return presentation[study_position - 1]

    def _no_item():
        return 0

    studied_item = lax.cond(study_position > 0, _item_at_position, _no_item)
    return item_to_study_positions(studied_item, presentation, size)

## Setting Up Control Analyses

> To control for the proportion of transitions expected in the absence of repeated items, we assigned the same serial positions used in the mixed lists to corresponding items in the control lists. For each subject, we assigned these positions to 100 random shuffles of the control lists, and took the mean across lists to get a baseline expectation of remote transitions.

A control analysis is essential for measuring the effect of repeated items on transitions.
In the full control analysis, we apply the analysis on 100 random shuffles of these control trial vectors per subject in order to get a baseline expectation of remote transitions.
Here we specify an implementation that works reasonably for given pairings of control trial response vectors and mixed list presentation vectors.

An important consideration that strongly affects analysis results is whether, for an item studied at two positions in a mixed list, recall of the control-list item at *either* position counts for the purposes of tabulating transitions between $S_i$ and $S_j$, or only recall of the item at the specific study position designated by $S_i$ or $S_j$.
If we accept that both positions are applicable, the percentage of valid transitions in our control analysis will be higher than if we don't, potentially tempering conclusions about the effects of item repetitions on recall performance.
It's possible that differences in decision-making about these issues help explain differences in reported outcomes across studies.

Here we take the stance that the proper implementation of the null hypothesis treats recall of *either* position in control list trials as a recall of the repeater item they were paired with from the mixed list. This helps examine the question of whether recall rates for the specified study positions really differ if the items in those study positions were identical to one another or distinct. To instead only treat one study position of a repeater item as valid for the control analysis necessarily produces recall rates that are lower than we would obtain under the main analysis, effectively always detecting an effect of item repetitions on recall performance even if participants treated the two study positions as two distinct items across both encoding and recall. 

Following our methodological commitment, a remaining problem is to determine how to handle recalls of distinct items from control lists that under the control analysis are treated as the same item due to matching study positions with mixed lists. Here, we take the same approach we would take when analyzing mixed lists: we filter out repeated recalls of the same "item" when analyzing recall performance across both our main and control analyses. By then conditioning tabulation of transition rates upon whether a given item has already been recalled or not, we can properly focus both versions of the analysis (main and control) on transitions where neither study position has yet been recalled in the current trial.

In [None]:
#| exports


def filter_repeated_recalls(recalls: jnp.ndarray) -> jnp.ndarray:
    """
    Zero out every recall that repeats an earlier recall within the same trial.

    Each row of `recalls` is one trial. The first occurrence of a value is
    kept in place; later occurrences of the same value are replaced by 0.
    Entries that are already 0 stay 0. The output has the same shape as
    the input.
    """
    order = jnp.arange(recalls.shape[1])
    # earlier[i, j] is True when output position j comes before position i.
    earlier = order[None, :] < order[:, None]
    # same_item[t, i, j] is True when positions i and j of trial t hold equal values.
    same_item = recalls[:, :, None] == recalls[:, None, :]
    is_repeat = (same_item & earlier).any(axis=-1)
    is_first_recall = jnp.logical_not(is_repeat | (recalls == 0))
    return recalls * is_first_recall


@partial(jax.jit, static_argnums=(2,))
def _shuffle_and_tile_controls(
    control_recalls: jnp.ndarray,  # [n_pure_trials, n_recalls]
    mixed_presentations: jnp.ndarray,  # [n_mixed_trials, n_pres]
    n_shuffles: int,  # static
    prng_key: jnp.ndarray,  # single PRNGKey
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Shuffle-and-tile helper for one subject.

    Stacks independent row permutations of `control_recalls` until there are
    at least `n_shuffles` recall rows per mixed trial, then trims the stack so
    that it lines up row-for-row with the tiled presentations.

    Args:
        control_recalls: control-list recalls, [n_pure_trials, n_recalls];
            must contain at least one trial.
        mixed_presentations: mixed-list presentations, [n_mixed_trials, n_pres].
        n_shuffles: number of shuffled control rows paired with each mixed
            trial (static under jit).
        prng_key: a single PRNG key used to derive one key per permutation.

    Returns:
        A `(shuffled_recalls, tiled_presentations)` pair, each with
        `n_mixed_trials * n_shuffles` rows. Every mixed presentation row
        appears `n_shuffles` times consecutively.
    """
    n_pure, n_recalls = control_recalls.shape
    n_mixed, _ = mixed_presentations.shape
    n_rows = n_mixed * n_shuffles

    # Each permutation contributes n_pure rows, so round up to cover all
    # n_rows output rows even when n_mixed is not a multiple of n_pure.
    repeat_factor = math.ceil(n_mixed / n_pure)
    n_permutations = n_shuffles * repeat_factor

    keys = random.split(prng_key, n_permutations)
    batched = jax.vmap(lambda k: random.permutation(k, control_recalls, axis=0))(keys)
    flat_shuffled = batched.reshape((n_permutations * n_pure, n_recalls))

    # Rounding up can overshoot; drop the surplus so recalls and presentations
    # have the same number of rows.
    flat_shuffled = flat_shuffled[:n_rows]

    tiled_pres = jnp.repeat(mixed_presentations, repeats=n_shuffles, axis=0)
    return flat_shuffled, tiled_pres


def make_control_dataset(
    data: RecallDataset,
    mixed_query: str,
    control_query: str,
    n_shuffles: int,
    remove_repeats: bool = True,
    seed: int = 0,
) -> RecallDataset:
    """
    Build a control dataset pairing shuffled control recalls with mixed-list presentations.

    For each subject that has trials matching `mixed_query`:
      - pick their pure-list recalls & mixed presentations,
      - call the jitted shuffle helper,
      - optionally zero out repeated recalls,
      - accumulate the new blocks plus any other fields.

    Subjects without any trial matching `control_query` are skipped. Every
    field of the result has `n_shuffles` rows per mixed trial of each
    included subject.

    Args:
        data: the source dataset; must provide "subject", "recalls" and
            "pres_itemnos".
        mixed_query: query passed to `generate_trial_mask` to select mixed trials.
        control_query: query passed to `generate_trial_mask` to select control trials.
        n_shuffles: number of shuffled control rows paired with each mixed trial.
        remove_repeats: if True, apply `filter_repeated_recalls` to the
            shuffled recalls.
        seed: seed for the PRNG key from which per-subject keys are split.

    Returns:
        A dataset with "subject", "recalls" and "pres_itemnos" rebuilt from
        the shuffled pairings, plus every other field of `data` taken from
        each subject's mixed trials and repeated to match.

    Raises:
        ValueError: if no subject has both mixed and control trials.
    """

    # 1) find which subjects actually have mixed trials
    all_subject_ids = jnp.array(data["subject"]).flatten()
    mixed_mask = generate_trial_mask(data, mixed_query)
    pure_mask = generate_trial_mask(data, control_query)
    subjects = np.unique(all_subject_ids[mixed_mask])
    prng_keys = random.split(random.PRNGKey(seed), subjects.size)

    # Convert the carried-along fields once rather than once per subject.
    other_fields = {
        key: jnp.array(data[key])
        for key in data
        if key not in ("recalls", "pres_itemnos", "subject")
    }

    recalls_blocks = []
    pres_blocks = []
    subject_id_blocks = []
    other_fields_acc = {key: [] for key in other_fields}

    for i, subj in enumerate(subjects):
        subject_rows = all_subject_ids == subj
        sel_pure = subject_rows & pure_mask
        sel_mixed = subject_rows & mixed_mask

        pure_recalls = jnp.array(data["recalls"][sel_pure])
        mixed_pres = jnp.array(data["pres_itemnos"][sel_mixed])
        if pure_recalls.shape[0] == 0 or mixed_pres.shape[0] == 0:
            continue

        new_recalls, new_pres = _shuffle_and_tile_controls(
            pure_recalls, mixed_pres, n_shuffles, prng_keys[i]
        )

        # Keep recalls row-aligned with the tiled presentations (n_shuffles
        # rows per mixed trial); any surplus shuffled rows are dropped.
        n_rows = new_pres.shape[0]
        new_recalls = new_recalls[:n_rows]
        if remove_repeats:
            new_recalls = filter_repeated_recalls(new_recalls)

        recalls_blocks.append(new_recalls)
        pres_blocks.append(new_pres)
        subject_id_blocks.append(jnp.full((n_rows, 1), subj, dtype=int))

        # carry along all other fields, repeated the same way the
        # presentations are tiled so that rows stay aligned
        for field, acc in other_fields_acc.items():
            acc.append(
                jnp.repeat(other_fields[field][sel_mixed], repeats=n_shuffles, axis=0)
            )

    if not recalls_blocks:
        raise ValueError(
            "No subject has trials matching both mixed_query and control_query."
        )

    return {
        "subject": jnp.vstack(subject_id_blocks),
        "recalls": jnp.vstack(recalls_blocks),
        "pres_itemnos": jnp.vstack(pres_blocks),
        **{f: jnp.vstack(lst) for f, lst in other_fields_acc.items()},
    }  # type: ignore