In [1]:
# default_exp analyses

# Study Phase Retrieval Effect
They considered transitions between items following a shared repeated item. They calculated the proportion of those items recalled in $S_j = {j + 1, j + 2}$ of which CMR then recalled an item in the set $S_i = {i + 1, i + 2}$. They also calculated the proportion of recalls $S_i$ of which CMR then transitioned to an item in the set $S_j$. They calculated the proportion of transitions for each of lags $j - i >= 4$, and represented the mean percent of transitions across these lags. 

1. Find items recalled in S_i or S_j across each repeatedly presented item in each trial.
2. For each relevant recall, check for a successively recalled item. 
3. If there is a successively recalled, increment a counter. 
4. If the successively recalled item is in S_j when the previous item was in S_i or vice versa, increment a separate counter.
5. Return the proportion of S_i or S_j transitions over the total number of transitions from S_i or S_j. 

In [64]:
from repetition_cmr.analyses import recall_by_all_study_positions

In [2]:
# export

from numba import njit
import numpy as np

#@njit(nogil=True)
def neighbor_contiguity(trials, presentations, max_repeats=2):

    # 0) initialize variables
    list_length = len(presentations[0])
    lag_range = list_length - 1
    total_actual_lags = np.zeros(lag_range * 2 + 1)
    terminus = np.sum(trials != 0, axis=1) # number of recalls per trial
    recall_by_study_position = recall_by_all_study_positions(trials, presentations, max_repeats)
    trials_shape = np.shape(trials)
    mask = np.zeros((trials_shape[0], max_repeats, trials_shape[1]), dtype=np.bool_)
    reference_positions = np.zeros((trials_shape[0], max_repeats, trials_shape[1]), dtype=np.int32)

    # 1) # create mask identifying recalls in S_i or S_j
    for trial_index, presentation in enumerate(presentations):

        masked_items = [[] for _ in range(max_repeats)]

        for current_index, current_item in enumerate(presentation[:-1]):

            # identify each time current_item occurs in presentation and skip if count is 1
            positions = np.nonzero(presentation == current_item)[0]
            if len(positions) == 1:
                continue

            # also skip if lag between positions is less than 4
            if positions[1] - positions[0] < 4:
                continue

            # identify relative position of current_index in repetitions
            # if repetition_index is 0, then successively presented items are in S_i
            # if repetition_index is 1, then successively presented items are in S_j
            repetition_index = np.nonzero(positions==current_index)[0][0]

            # recall of item at succesive two serial positions should be in mask 
            # reference_positions position should be the repetition position 
            # that's distinct from current_index
            # TODO: figure out policy for when max_repeats > 2
            if current_index+1 < list_length:
                masked_items[repetition_index].append(presentation[current_index+1]+1)

                reference_positions[trial_index, repetition_index, np.nonzero(
                    trials[trial_index] == presentation[current_index+1]+1)[0]] = positions[positions != current_index][0] + 1

            if current_index+2 < list_length:
                masked_items[repetition_index].append(presentation[current_index+2]+1)

                reference_positions[trial_index, repetition_index, np.nonzero(
                    trials[trial_index] == presentation[current_index+2]+1)[0]] = positions[positions != current_index][0] + 1    

            return reference_positions[trial_index]

        for i in range(max_repeats):
            mask[trial_index, i] = np.isin(trials[trial_index], np.array(masked_items[i]))

    total_transitions = 0
    for trial_index in range(len(trials)):
        
        previous_item = 0
        item_count = np.max(presentations[trial_index]) + 1
        possible_items = np.arange(item_count) # initial pool of possible recalls, 1-indexed
        possible_positions = np.zeros((item_count, max_repeats), dtype=np.int32)
        
        # we track possible positions using presentations and alt_presentations
        for item in range(item_count):
            pos = np.nonzero(presentations[trial_index] == item)[0] + 1
            possible_positions[item, :len(pos)] = pos
            
        for recall_index in range(terminus[trial_index]):
            
            current_item = presentations[trial_index][trials[trial_index, recall_index]-1]
            
            # track possible and actual transition lags from selected recalls
            if recall_index > 0 and np.any(mask[trial_index, :, recall_index-1]): 
                
                total_transitions += 1
                repetition_index = np.nonzero(mask[trial_index, :, recall_index-1])[0][0]
                
                # item indices don't help track lags anymore
                # so more complex calculation needed to identify possible lags given previous item
                current_index = np.nonzero(possible_items==current_item)[0]
                possible_lags = np.zeros((max_repeats, len(possible_items)), dtype=np.int32)

                for x in range(len(recall_by_study_position)):
                    # for y in range(len(recall_by_study_position)):
                    #     if reference_positions[trial_index, y, recall_index-1] > 0:
                        
                    possible_lags[x] = possible_positions[
                        possible_items, x] - reference_positions[trial_index, repetition_index, recall_index-1]
                    
                    # if tracked position is 0, then we don't actually want to count it in our lags
                    possible_lags[x][possible_positions[possible_items, x] == 0] = 0

                possible_lags += lag_range
                total_actual_lags[possible_lags[:, current_index].flatten()] += 1

            # update pool to exclude recalled item (updated to still identify 1-indexed item)
            previous_item = current_item
            possible_items = possible_items[possible_items != previous_item]
                    
    # we only care about the +1 and +2 transitions here
    relevant_transitions = total_actual_lags[lag_range+1] + total_actual_lags[lag_range+2]
    
    return relevant_transitions/total_transitions

In [65]:
from numba import njit
import numpy as np

#@njit(nogil=True)
def neighbor_contiguity(trials, presentations, max_repeats=2):

    total_transitions = 0
    total_relevant_transitions = 0
    for trial_index, presentation in enumerate(presentations):

        # track for each item which transitions we're interested in
        relevant_transitions = [[] for _ in range(len(presentation))]
        for current_index, current_item in enumerate(presentation[:-1]):
                
            # identify each time current_item occurs in presentation and skip if count is 1
            positions = np.nonzero(presentation == current_item)[0]
            if len(positions) == 1:
                continue

            # also skip if lag between positions is less than 4
            if np.abs(positions[1] - positions[0]) < 4:
                continue

            # identify what would be relevant si and s_j
            s_i = []
            if current_index+1 < len(presentation):
                s_i.append(presentation[current_index+1])
            if current_index+2 < len(presentation):
                s_i.append(presentation[current_index+2])

            alt_index = positions[positions != current_index][0]
            s_j = []
            if alt_index+1 < len(presentation):
                s_j.append(presentation[alt_index+1])
            if alt_index+2 < len(presentation):
                s_j.append(presentation[alt_index+2])
    
            # track pairing in relevant_transitions
            for s_i_item in s_i:
                for s_j_item in s_j:
                    relevant_transitions[s_i_item].append(s_j_item)
                    relevant_transitions[s_j_item].append(s_i_item)

        # then for each recall, 
        print()
        print(presentation)
        # for i in range(len(presentation)):
        #     print(i, relevant_transitions[i])
        print(trials[trial_index])
        for recall_index, recall_item in enumerate(trials[trial_index]):
            print('recall_item', recall_item-1)

            # 0 means null so we skip
            if recall_item == 0:
                break

            # skip if item has no relevant transitions
            print([each for each in relevant_transitions[recall_item-1] if each+1 not in trials[trial_index, :recall_index]])
            if not [each for each in relevant_transitions[recall_item-1] if each+1 not in trials[trial_index, :recall_index]]:
                continue

            # skip if successively recalled item doesn't exist
            if recall_index+1 < len(trials[trial_index]):
                print('next_item', trials[trial_index][recall_index+1]-1)
                if trials[trial_index][recall_index+1] == 0:
                    continue
            else:
                continue
            total_transitions += 1

            # check if successively recalled item is among relevant transitions
            if trials[trial_index][recall_index+1]-1 in relevant_transitions[recall_item-1]:
                print('MATCH')
                total_relevant_transitions += 1
            else:
                print('NO MATCH')

        #print(total_relevant_transitions/total_transitions)

    return total_relevant_transitions/total_transitions

In [67]:
import pandas as pd
from compmemlearn.datasets import events_metadata, generate_trial_mask

events = pd.read_csv('../../../compmemlearn/data/LohnasKahana2014.csv')
trials, list_lengths, presentations = events_metadata(events)
trials = trials
presentations = presentations
list_length = list_lengths

for condition in [3]:
    trial_mask = generate_trial_mask(events, f"condition == {condition}")[0]
    result = neighbor_contiguity(trials[0][trial_mask], presentations[0][trial_mask])
    print(result)


[ 0  1  2  3  4  5  4  6  1  0  7  2  3  7  5  6  8  9 10 11 10 12 13 14
  9  8 15 12 11 13 15 16 14 17 18 19 16 19 18 17]
[ 1  2  3  4  5  6 11  8 17 18 22 20 23 32 36 24  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0]
recall_item 0
[2, 3, 8, 9, 2, 3, 8, 9]
next_item 1
NO MATCH
recall_item 1
[7, 2, 8, 9, 7, 2, 8, 9]
next_item 2
MATCH
recall_item 2
[7, 2, 2, 7, 7, 7, 2, 2]
next_item 3
NO MATCH
recall_item 3
[7, 3, 3, 7, 4, 7, 3, 3, 4, 7]
next_item 4
MATCH
recall_item 4
[7, 7, 5, 6, 8, 7, 7, 5, 6, 8]
next_item 5
MATCH
recall_item 5
[7, 5, 5, 7, 5, 5]
next_item 10
NO MATCH
recall_item 10
[15, 12, 8, 15, 13, 15, 8, 15, 15, 12, 13, 15]
next_item 7
NO MATCH
recall_item 7
[]
recall_item 16
[14, 9, 12, 11, 14, 9, 12, 11, 17, 17]
next_item 17
MATCH
recall_item 17
[9, 8, 19, 18, 9, 8, 19, 19, 18, 19]
next_item 21
NO MATCH
recall_item 21
[]
recall_item 19
[14, 14]
next_item 22
NO MATCH
recall_item 22
[]
recall_item 31
[]
recall_item 35
[]
recall_item 23
[]
recall_item -1

[ 0  

In [63]:
from compmemlearn.datasets import prepare_lohnas2014_data

trials, events, list_length, presentations, list_types, rep_data, subjects = prepare_lohnas2014_data('../../../compmemlearn/data/repFR.mat')

results = []
for subject in np.unique(subjects):
    selection = np.logical_and(subjects == subject, list_types == 4)
    results.append(neighbor_contiguity(trials[selection], presentations[selection]))
np.mean(results)


[ 0  1  2  3  4  5  6  7  8  9 10 11 11 12 13 14 15 16  9 17 18 19 18 20
 21 22 19 23 24 25 21 26 27 23 28 29 30 31 32 33]
[ 1  2  3  4  5  6  7  9 10 11 17 14 12 15 25 20 28 30 39 38 37 18  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
recall_item 0
[]
recall_item 1
[]
recall_item 2
[]
recall_item 3
[]
recall_item 4
[]
recall_item 5
[]
recall_item 6
[]
recall_item 8
[]
recall_item 9
[]
recall_item 10
[17, 18, 17, 18]
next_item 16
NO MATCH
recall_item 16
[]
recall_item 13
[]
recall_item 11
[17, 18, 17, 18]
next_item 14
NO MATCH
recall_item 14
[]
recall_item 24
[18, 20, 18, 20, 28, 29, 28, 29]
next_item 19
NO MATCH
recall_item 19
[26, 27, 26, 27]
next_item 27
MATCH
recall_item 27
[22, 22]
next_item 29
NO MATCH
recall_item 29
[25, 25]
next_item 38
NO MATCH
recall_item 38
[]
recall_item 37
[]
recall_item 36
[]
recall_item 17
[]
recall_item -1

[ 0  1  2  3  4  5  6  7  8  9 10 11 12  9 13 14 15 16 17 18 19 20 21 22
 19 17 23 24 25 26 24 27 28 27 23 29 30 31 32 33]
[ 1  3  4  5 1

0.09525665473191955

In [27]:
presentations[0][0]

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 11, 12, 13, 14, 15,
       16,  9, 17, 18, 19, 18, 20, 21, 22, 19, 23, 24, 25, 21, 26, 27, 23,
       28, 29, 30, 31, 32, 33])

In [28]:
trials[0][0]

array([ 1,  2,  3,  4,  5,  6,  7,  9, 10, 11, 17, 14, 12, 15, 25, 20, 28,
       30, 39, 38, 37, 18,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0])

In [31]:
trials[0][0]-1

array([ 0,  1,  2,  3,  4,  5,  6,  8,  9, 10, 16, 13, 11, 14, 24, 19, 27,
       29, 38, 37, 36, 17, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1])

In [29]:
trials[0][0][result[0]]-1

array([10, 11, 24, 19])

In [30]:
trials[0][0][result[1]]-1

array([24, 27, 29, 17])

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 11, 12, 13, 14, 15,
       16,  9, 17, 18, 19, 18, 20, 21, 22, 19, 23, 24, 25, 21, 26, 27, 23,
       28, 29, 30, 31, 32, 33])