In [1]:
# default_exp analyses

# Study Phase Retrieval Effect
They considered transitions between items following a shared repeated item, where $i$ and $j$ denote the serial positions of the repeated item's first and second presentations. Of the recalls of items in $S_j = \{j + 1, j + 2\}$, they calculated the proportion after which CMR then recalled an item in the set $S_i = \{i + 1, i + 2\}$. They also calculated the proportion of recalls of items in $S_i$ after which CMR then transitioned to an item in the set $S_j$. They calculated the proportion of transitions for each lag $j - i \geq 4$, and reported the mean percent of transitions across these lags.

We'll extend this analysis by performing a masked, reference-shifted lag-CRP analysis. We'll only track transitions from neighbors of repeatedly-presented items, and we'll shift our lag reference from these items to the alternative position of the repeatedly-presented item they neighbor.

To estimate the proportion of transitions that CMR would make at these lags in the absence of repeated items, they considered transitions in control lists matched to the same serial positions considered in the mixed lists. They matched these serial positions to 100 random shuffles of the control lists, and took the mean across the reshuffled datasets. We'll develop that functionality elsewhere.

## Functions
We start with a specification of `picky_lag_crp` that supports item repetitions. Then we need to write a mask that selects transitions from items that neighbor a repeatedly-presented item.

In [1]:
from repetition_cmr.analyses import recall_by_all_study_positions

In [18]:
# export

from numba import njit
import numpy as np

#@njit(nogil=True)
def neighbor_contiguity(trials, presentations, max_repeats=2):
    """
    Masked, reference-shifted lag-CRP for transitions out of neighbors of repeated items.

    Only transitions whose *previous* recall is one of the two items studied
    right after a spaced repetition (the two presentations at least 4
    positions apart) are tallied. Lags for such a transition are measured
    from the *other* presentation of the repeated item rather than from the
    neighbor that was actually recalled.

    Parameters
    ----------
    trials : 2-D array
        One row per trial. Nonzero entries are counted as recalls and are used
        as 1-based indices into the matching row of `presentations`; zeros are
        treated as padding.
    presentations : sequence of 1-D integer arrays
        One row per trial holding the item id shown at each study position.
        The length of the first row defines the list length.
    max_repeats : int, default 2
        Number of presentation slots tracked per item. The neighbor mask only
        compares an item's first two presentations; behaviour for more than
        two is an open TODO carried over from the original implementation.

    Returns
    -------
    1-D float array of length ``2 * list_length - 1``
        Actual / possible transition counts per lag, where index
        ``lag + (list_length - 1)`` holds the value for `lag`. The centre
        (lag 0) bin is always 0.

    Notes
    -----
    NOTE(review): the masking step matches `trials` entries against
    ``item id + 1`` while the tallying step uses `trials` entries as 1-based
    study positions. These agree only if an item's id equals the 0-based
    study position by which `trials` refers to it -- confirm against the
    data format produced upstream.
    """
    list_length = len(presentations[0])
    lag_range = list_length - 1
    bin_count = lag_range * 2 + 1
    actual_counts = np.zeros(bin_count)
    possible_counts = np.zeros(bin_count)

    recalls_per_trial = np.sum(trials != 0, axis=1)
    # only the length of this result is used below (as the number of
    # presentation slots to iterate over)
    by_study_position = recall_by_all_study_positions(trials, presentations, max_repeats)

    shape = np.shape(trials)
    per_recall_shape = (shape[0], max_repeats, shape[1])
    # mask[t, r, k]: recall k of trial t is a neighbor of the r-th presentation
    # of some spaced repeated item
    mask = np.zeros(per_recall_shape, dtype=np.bool_)
    # reference_positions[t, r, k]: 1-based position of the *other* presentation
    # of that repeated item (0 where nothing was recorded)
    reference_positions = np.zeros(per_recall_shape, dtype=np.int32)

    # --- pass 1: flag recalls of neighbors and record their shifted reference ---
    for trial_index, presentation in enumerate(presentations):

        flagged = [[] for _ in range(max_repeats)]

        for study_index, item in enumerate(presentation[:-1]):

            # keep only items presented more than once with a spacing of >= 4
            occurrences = np.nonzero(presentation == item)[0]
            if len(occurrences) == 1 or occurrences[1] - occurrences[0] < 4:
                continue

            # which presentation of the item (0 = first, 1 = second, ...) this is
            which_presentation = np.nonzero(occurrences == study_index)[0][0]

            # the items at the next two serial positions are the neighbors
            # TODO: figure out policy for when max_repeats > 2
            for offset in (1, 2):
                neighbor_index = study_index + offset
                if neighbor_index >= list_length:
                    continue

                neighbor_code = presentation[neighbor_index] + 1
                flagged[which_presentation].append(neighbor_code)

                # lags will be measured from a presentation of the repeated
                # item other than the one this neighbor followed; a later
                # write for the same recall slot overwrites an earlier one
                recall_slots = np.nonzero(trials[trial_index] == neighbor_code)[0]
                reference_positions[trial_index, which_presentation, recall_slots] = (
                    occurrences[occurrences != study_index][0] + 1)

        for repeat in range(max_repeats):
            mask[trial_index, repeat] = np.isin(
                trials[trial_index], np.array(flagged[repeat]))

    # --- pass 2: tally actual and possible lags for the masked transitions ---
    for trial_index in range(len(trials)):

        presentation = presentations[trial_index]
        item_count = np.max(presentation) + 1
        remaining_items = np.arange(item_count)  # item ids not yet recalled

        # 1-based study positions of every item id; unused slots stay 0
        study_positions = np.zeros((item_count, max_repeats), dtype=np.int32)
        for item in range(item_count):
            where_studied = np.nonzero(presentation == item)[0] + 1
            study_positions[item, :len(where_studied)] = where_studied

        for recall_index in range(recalls_per_trial[trial_index]):

            current_item = presentation[trials[trial_index, recall_index] - 1]

            # count this transition only if the previous recall was flagged
            if recall_index > 0 and np.any(mask[trial_index, :, recall_index - 1]):

                previous_flags = mask[trial_index, :, recall_index - 1]
                # when several presentation slots are flagged, the first one wins
                repetition_index = np.nonzero(previous_flags)[0][0]
                reference = reference_positions[
                    trial_index, repetition_index, recall_index - 1]

                chosen = np.nonzero(remaining_items == current_item)[0]
                lags = np.zeros((max_repeats, len(remaining_items)), dtype=np.int32)

                for slot in range(len(by_study_position)):
                    candidate_positions = study_positions[remaining_items, slot]
                    lags[slot] = candidate_positions - reference

                    # a position of 0 means "no such presentation": park it at
                    # lag 0, which is discarded after the loops
                    lags[slot][candidate_positions == 0] = 0

                # shift lags so the most negative lag lands on index 0
                lags += lag_range

                # fancy-indexed `+= 1` bumps each distinct bin at most once per
                # transition, even when several entries share the same lag
                actual_counts[lags[:, chosen].flatten()] += 1
                possible_counts[lags.flatten()] += 1

            # the recalled item is no longer available for later transitions
            remaining_items = remaining_items[remaining_items != current_item]

    # drop the lag-0 bin (it collects the untracked-presentation placeholders)
    # and bump the denominator wherever the numerator is 0 to avoid 0/0
    actual_counts[lag_range] = 0
    possible_counts[actual_counts == 0] += 1

    return actual_counts / possible_counts

In [19]:
import pandas as pd
import numpy.matlib
from compmemlearn.datasets import events_metadata, generate_trial_mask

# Load the LohnasKahana2014 events table and unpack its trial-level arrays.
events = pd.read_csv('../../../compmemlearn/data/LohnasKahana2014.csv')
trials, list_lengths, presentations = events_metadata(events)
list_length = list_lengths  # alias kept in case a later cell reads it

# Print the neighbor-contiguity curve for each selected condition (only 4 here;
# presumably the mixed-list condition -- confirm against the dataset docs).
for condition in (4,):
    condition_query = f"condition == {condition}"
    trial_mask = generate_trial_mask(events, condition_query)[0]
    print(neighbor_contiguity(trials[0][trial_mask], presentations[0][trial_mask]))

169.0 2530.0
[0.         0.         0.         0.07142857 0.07317073 0.01538462
 0.03409091 0.02608696 0.01333333 0.03825137 0.02202643 0.01123596
 0.01286174 0.02432432 0.01690821 0.01046025 0.02702703 0.01584507
 0.01771337 0.01783061 0.03034483 0.0405954  0.03090235 0.01853998
 0.03179825 0.04115226 0.03519062 0.04919584 0.05028463 0.0480226
 0.07276119 0.06540284 0.09235075 0.08340573 0.08065915 0.05922166
 0.06193969 0.05942948 0.06942278 0.         0.07097289 0.06269592
 0.05414013 0.05823627 0.05602007 0.05681818 0.0460251  0.03169308
 0.03022453 0.04181185 0.0380868  0.02982293 0.02364532 0.02651113
 0.03097345 0.02719407 0.0129199  0.0210084  0.02391629 0.0192
 0.02166065 0.015625   0.01594533 0.0255102  0.01428571 0.00680272
 0.01321586 0.01685393 0.         0.01834862 0.02597403 0.
 0.025      0.         0.         0.         0.         0.
 0.        ]
