In [1]:
from numba import config
config.DISABLE_JIT = True

import numpy as np
from psifr.stats import percentile_rank
from compmempy.helpers.transforming_arrays import njit_apply_along_axis
from compmempy.helpers.handling_data import item_to_study_positions, apply_by_subject
from compmempy.helpers.loading_data import to_numba_typed_dict
from jaxcmr_research.helpers.hdf5 import generate_trial_mask, load_data

In [2]:
#| code-summary: reference implementation of lag-rank analysis in simplest case

def lag_rank(trials: np.ndarray, list_length: int) -> float:
    """Summarize the tendency to group together nearby items by running a lag rank analysis.

    For each recall, this determines the absolute lag of all remaining items available for recall and then calculates their percentile rank. Then the rank of the actual transition made is taken, scaled to vary between 0 (furthest item chosen) and 1 (nearest item chosen). Chance clustering will be 0.5; clustering above that value is evidence of a temporal contiguity effect.

    Args:
        trials (np.ndarray): rows represent trials and columns represents a recall position.
            Nonzero identify study index of recalled items.
        list_length (int): The number of item presentations in each trial.
    """
    terminus = np.sum(trials != 0, axis=1)  # determine where each trial ends
    actual_ranks = []

    for trial_index in range(len(trials)):
        possible_items = np.arange(list_length) + 1
        previous_item = 0

        for recall_index in range(terminus[trial_index]):
            current_item = trials[trial_index, recall_index]
            if recall_index > 0:
                actual_lag = np.abs(current_item - previous_item)
                possible_lags = np.abs(possible_items - previous_item)
                actual_ranks.append(1 - percentile_rank(actual_lag, possible_lags))

            previous_item = current_item
            possible_items = possible_items[possible_items != previous_item]

    return float(np.nanmean(actual_ranks))


In [3]:
class RepetitionLagRank:
    """
    Generalize lag-rank analyses to account for multiple presentations of each item, but exclusively consider transitions from repeated items and separately measure rank percentiles relative to each presentation index of the repeated items.

    Assess lag-rank using the nearest study position to the reference study position used for tabulating lags. 
    """


    def __init__(
        self, presentations: np.ndarray, max_repetitions: int = 2, min_lag: int = 4
    ):
        "Pre-allocate arrays for lag-CRP tabulations."
        list_length = np.max(np.sum(presentations != 0, axis=1))
        self.lag_range = list_length - 1
        self.min_lag = min_lag
        self.max_repetitions = max_repetitions
        self.actual_ranks = [[] for _ in range(max_repetitions)]

    def should_tabulate(
        self,
        prev_study_positions: np.ndarray,
    ) -> bool:
        """Only consider transitions from item with at least two spaced-out study positions"""
        return len(prev_study_positions) > 1 and prev_study_positions[
            -1
        ] - prev_study_positions[-2] >= (self.min_lag + 1)

    def tabulate_lags(
        self,
        previous_item: int,
        current_item: int,
        possible_items: np.ndarray,
        item_study_positions: list[np.ndarray],
    ):
        "Tabulate actual and possible serial lags of current from previous item."

        prev_study_positions = item_study_positions[previous_item - 1]
        current_study_positions = item_study_positions[current_item - 1]
        for repetition_index, prev_study_position in enumerate(prev_study_positions):

            # first track the minimum lags of the actual transition made
            actual_lag = np.inf
            for current_study_position in current_study_positions:
                lag = np.abs(current_study_position - prev_study_position)
                actual_lag = np.minimum(lag, actual_lag)

            # then track minimum lags for each possible transition
            possible_lags = []
            for item in possible_items:
                possible_lag = np.inf
                possible_study_positions = item_study_positions[item - 1]
                for possible_study_position in possible_study_positions:
                    lag = np.abs(possible_study_position - prev_study_position)
                    possible_lag = np.minimum(lag, possible_lag)
                possible_lags.append(possible_lag)

            percent_rank = 1 - percentile_rank(actual_lag, possible_lags)
            # assert not np.isnan(percent_rank)
            self.actual_ranks[repetition_index].append(percent_rank)

    def tabulate_over_transitions(
        self, trials: np.ndarray, presentations: np.ndarray
    ) -> np.ndarray:
        "Tabulate actual and possible lag transitions over a set of trials"

        terminus = np.sum(trials != 0, axis=1)

        for trial_index in range(len(trials)):
            presentation = presentations[trial_index]
            item_count = np.max(presentation)
            possible_items = np.arange(item_count) + 1
            item_study_positions = njit_apply_along_axis(
                item_to_study_positions, possible_items, presentation
            )
            previous_item = 0

            for recall_index in range(terminus[trial_index]):
                current_item = presentation[trials[trial_index, recall_index] - 1]
                if recall_index > 0 and self.should_tabulate(
                    item_study_positions[previous_item - 1],
                ):
                    self.tabulate_lags(
                        previous_item,
                        current_item,
                        possible_items,
                        item_study_positions,
                    )

                previous_item = current_item
                possible_items = possible_items[possible_items != previous_item]

        return np.nanmean(self.actual_ranks, axis=1)
    
def replagrank(
    trials: np.ndarray,
    presentations: np.ndarray,
    list_length: int,
    max_repetitions: int = 2,
    min_lag: int = 4,
) -> np.ndarray:
    "Summarize the tendency to group together nearby items by running a lag rank analysis."
    lag_rank = RepetitionLagRank(presentations, max_repetitions, min_lag)
    scores = lag_rank.tabulate_over_transitions(trials, presentations)
    # return scores[0] - scores[1]
    return scores

## Lohnas Kahana 2014 Dataset

In [4]:
data_name = "LohnasKahana2014"
data_path = "data/LohnasKahana2014.h5"
data_query = "data['list_type'] >= 3"

data = to_numba_typed_dict({key: np.array(value) for key, value in load_data(data_path).items()})
trial_mask = generate_trial_mask(data, data_query)

max_repetitions = 2
min_lag = 4

In [5]:
subject_values = np.array(apply_by_subject(
    data,
    trial_mask,
    replagrank,
    max_repetitions,
    min_lag,
))

subject_values

array([[0.81361003, 0.7295366 ],
       [0.72933177, 0.62998377],
       [0.69780963, 0.68558622],
       [0.6556871 , 0.64951138],
       [0.64238469, 0.60224452],
       [0.72728644, 0.72827337],
       [0.57413907, 0.66251383],
       [0.59982331, 0.59598916],
       [0.74370313, 0.68040249],
       [0.69351671, 0.66671052],
       [0.85971342, 0.71705638],
       [0.63179574, 0.63085404],
       [0.61232839, 0.60385234],
       [0.78649566, 0.72635669],
       [0.66599945, 0.60788552],
       [0.76838349, 0.6924549 ],
       [0.82239387, 0.75002738],
       [0.79654307, 0.72084637],
       [0.60051983, 0.62020085],
       [0.68834796, 0.6614504 ],
       [0.60875966, 0.58214981],
       [0.59232637, 0.63117534],
       [0.60970931, 0.59833962],
       [0.70855475, 0.65347738],
       [0.69587753, 0.77728374],
       [0.6418735 , 0.61501097],
       [0.81724153, 0.7786852 ],
       [0.6340587 , 0.64829903],
       [0.60041945, 0.60581126],
       [0.75633806, 0.72031164],
       [0.

In [12]:
from jaxcmr_research.helpers.repetition import control_dataset

ctrl_data = control_dataset(to_numba_typed_dict({key: np.array(value) for key, value in data.items()}), "data['list_type'] == 4", "data['list_type'] == 1", 100)

ctrl_subject_values = np.array(apply_by_subject(
    ctrl_data,
    np.ones(ctrl_data['recalls'].shape[0], dtype=bool),
    replagrank,
    max_repetitions,
    min_lag,
))

ctrl_subject_values

35it [00:00, 514.28it/s]


IndexError: index 0 is out of bounds for axis 0 with size 0

In [6]:
import numpy as np
from scipy.stats import ttest_rel

# Example data in an Nx2 array
data = subject_values

# Ensure the data has exactly two columns
if data.shape[1] != 2:
    raise ValueError("Data must have exactly two columns")

# Split the data into two related samples
data1, data2 = data[:, 0], data[:, 1]

# Perform the paired t-test
t_statistic, p_value = ttest_rel(data1, data2, alternative="greater")

print(f"T-statistic: {t_statistic}")
print(f"One-tailed P-value: {p_value}")

# Interpretation
if p_value < 0.05:
    print("There is a statistically significant difference between the two paired samples at the 5% significance level.")
else:
    print("There is no statistically significant difference between the two paired samples at the 5% significance level.")

print(f"Mean First Sample: {np.mean(data1)}")
print(f"Mean Second Sample: {np.mean(data2)}")
print(f"Standard Error First Sample: {np.std(data1) / np.sqrt(len(data1))}")
print(f"Standard Error Second Sample: {np.std(data2) / np.sqrt(len(data2))}")
print(f"Mean Difference: {np.mean(data1 - data2)}")
print(f"Standard Error Difference: {np.std(data1 - data2) / np.sqrt(len(data1))}")

T-statistic: 3.4299583538057674
One-tailed P-value: 0.000800182195425057
There is a statistically significant difference between the two paired samples at the 5% significance level.
Mean First Sample: 0.6902984700510105
Mean Second Sample: 0.6627942211503018
Standard Error First Sample: 0.01371007703540283
Standard Error Second Sample: 0.00962267158588462
Mean Difference: 0.027504248900708624
Standard Error Difference: 0.00790344397299532


## Simulation of CMR Fitted to Lohnas Kahana 2014 Dataset

In [10]:
from jaxcmr_research.helpers.hdf5 import simulate_h5_from_h5
from jax import random
from jaxcmr_research.helpers.hdf5 import generate_trial_mask, load_data
from jaxcmr_research.helpers.misc import summarize_parameters, import_from_string
import numpy as np
from jaxcmr_research.helpers.array import compute_similarity_matrix
from jax import numpy as jnp
import json
from IPython.display import Markdown  # type: ignore

data_name = "Cond34LohnasKahana2014"
data_path = "data/LohnasKahana2014.h5"
data_query = "data['list_type'] >= 3"
connection_path = "data/peers-all-mpnet-base-v2.npy"
experiment_count = 1
seed = 0
fit_result_path = (
    "notebooks/Model_Fitting/Cond34LohnasKahana2014_BaseCMR_Model_Fitting.json"
)


data = load_data(data_path)
trial_mask = generate_trial_mask(data, data_query)
embeddings = np.load(connection_path)
connections = compute_similarity_matrix(embeddings) # unused here
model_factory_path = "jaxcmr_research.cmr.BaseCMRFactory"
model_factory = import_from_string(model_factory_path)
with open(fit_result_path, "r") as f:
    results = json.load(f)
    if "subject" not in results["fits"]:
        results["fits"]["subject"] = results["subject"]


Markdown(
    summarize_parameters([results], None, include_std=True, include_ci=True)
)



| | | Cond34LohnasKahana2014 BaseCMR Model Fitting |
|---|---|---|
| fitness | mean | 817.31 +/- 70.82 |
| | std | 203.20 |
| recall drift rate | mean | 0.91 +/- 0.05 |
| | std | 0.15 |
| encoding drift rate | mean | 0.71 +/- 0.05 |
| | std | 0.15 |
| stop probability growth | mean | 0.23 +/- 0.03 |
| | std | 0.10 |
| learning rate | mean | 0.47 +/- 0.08 |
| | std | 0.23 |
| item support | mean | 14.62 +/- 7.60 |
| | std | 21.82 |
| semantic scale | mean | 0.00 +/- 0.00 |
| | std | 0.00 |
| mfc trace sensitivity | mean | 1.00 +/- 0.00 |
| | std | 0.00 |
| stop probability scale | mean | 0.02 +/- 0.01 |
| | std | 0.02 |
| shared support | mean | 8.95 +/- 5.74 |
| | std | 16.47 |
| semantic choice sensitivity | mean | 0.00 +/- 0.00 |
| | std | 0.00 |
| primacy decay | mean | 20.74 +/- 10.79 |
| | std | 30.97 |
| primacy scale | mean | 11.76 +/- 8.89 |
| | std | 25.52 |
| start drift rate | mean | 0.61 +/- 0.11 |
| | std | 0.32 |
| choice sensitivity | mean | 29.08 +/- 12.52 |
| | std | 35.92 |
| mfc choice sensitivity | mean | 1.00 +/- 0.00 |
| | std | 0.00 |
| mcf trace sensitivity | mean | 1.00 +/- 0.00 |
| | std | 0.00 |


In [11]:
rng = random.PRNGKey(seed)
rng, rng_iter = random.split(rng)
sim = simulate_h5_from_h5(
    model_factory=model_factory,
    dataset=data,
    connections=connections,
    parameters={key: jnp.array(val) for key, val in results["fits"].items()},
    trial_mask=trial_mask,
    experiment_count=experiment_count,
    rng=rng_iter,
)

In [12]:
subject_values = np.array(apply_by_subject(
    sim,
    generate_trial_mask(sim, data_query),
    replagrank,
    max_repetitions,
    min_lag,
))

subject_values

array([[0.77692096, 0.74743774],
       [0.70638134, 0.67180049],
       [0.75945242, 0.73465871],
       [0.65707941, 0.64447968],
       [0.67295288, 0.64746986],
       [0.72448643, 0.74264282],
       [0.75290972, 0.71364864],
       [0.66671027, 0.69113424],
       [0.67982723, 0.70215156],
       [0.62979111, 0.60622805],
       [0.78261856, 0.79155743],
       [0.69979227, 0.6532797 ],
       [0.61556494, 0.59404468],
       [0.7843276 , 0.75754388],
       [0.66870814, 0.66289579],
       [0.72253204, 0.73960228],
       [0.76171796, 0.7447928 ],
       [0.75842882, 0.72868288],
       [0.59033125, 0.63275705],
       [0.65047042, 0.6351033 ],
       [0.61912167, 0.610592  ],
       [0.67443116, 0.68603807],
       [0.48866639, 0.44945337],
       [0.68564902, 0.69379116],
       [0.75800599, 0.69960205],
       [0.60612067, 0.60563539],
       [0.84268733, 0.75617773],
       [0.59443468, 0.69628965],
       [0.67282773, 0.65928771],
       [0.76414394, 0.79114156],
       [0.

In [13]:
import numpy as np
from scipy.stats import ttest_rel

# Example data in an Nx2 array
data = subject_values

# Ensure the data has exactly two columns
if data.shape[1] != 2:
    raise ValueError("Data must have exactly two columns")

# Split the data into two related samples
data1, data2 = data[:, 0], data[:, 1]

# Perform the paired t-test
t_statistic, p_value = ttest_rel(data1, data2, alternative="greater")

print(f"T-statistic: {t_statistic}")
print(f"Two-tailed P-value: {p_value}")

# Interpretation
if p_value < 0.05:
    print("There is a statistically significant difference between the two paired samples at the 5% significance level.")
else:
    print("There is no statistically significant difference between the two paired samples at the 5% significance level.")

print(f"Mean First Sample: {np.mean(data1)}")
print(f"Mean Second Sample: {np.mean(data2)}")
print(f"Standard Error First Sample: {np.std(data1) / np.sqrt(len(data1))}")
print(f"Standard Error Second Sample: {np.std(data2) / np.sqrt(len(data2))}")
print(f"Mean Difference: {np.mean(data1 - data2)}")
print(f"Standard Error Difference: {np.std(data1 - data2) / np.sqrt(len(data1))}")

T-statistic: 1.3167208876141099
Two-tailed P-value: 0.09836951011353126
There is no statistically significant difference between the two paired samples at the 5% significance level.
Mean First Sample: 0.6895850904892817
Mean Second Sample: 0.6814926971478833
Standard Error First Sample: 0.01219909593171219
Standard Error Second Sample: 0.011695876197063684
Mean Difference: 0.008092393341398618
Standard Error Difference: 0.006057434063202858
