In [1]:
from numba import config
config.DISABLE_JIT = True

import numpy as np
from psifr.stats import percentile_rank
from compmempy.helpers.transforming_arrays import njit_apply_along_axis
from compmempy.helpers.handling_data import item_to_study_positions, apply_by_subject
from compmempy.helpers.loading_data import to_numba_typed_dict
from jaxcmr_research.helpers.hdf5 import generate_trial_mask, load_data

In [2]:
#| code-summary: reference implementation of lag-rank analysis in simplest case

def lag_rank(trials: np.ndarray, list_length: int) -> float:
    """Summarize the tendency to group together nearby items by running a lag rank analysis.

    For each recall, this determines the absolute lag of all remaining items available for recall and then calculates their percentile rank. Then the rank of the actual transition made is taken, scaled to vary between 0 (furthest item chosen) and 1 (nearest item chosen). Chance clustering will be 0.5; clustering above that value is evidence of a temporal contiguity effect.

    Args:
        trials (np.ndarray): rows represent trials and columns represents a recall position.
            Nonzero identify study index of recalled items.
        list_length (int): The number of item presentations in each trial.
    """
    terminus = np.sum(trials != 0, axis=1)  # determine where each trial ends
    actual_ranks = []

    for trial_index in range(len(trials)):
        possible_items = np.arange(list_length) + 1
        previous_item = 0

        for recall_index in range(terminus[trial_index]):
            current_item = trials[trial_index, recall_index]
            if recall_index > 0:
                actual_lag = np.abs(current_item - previous_item)
                possible_lags = np.abs(possible_items - previous_item)
                actual_ranks.append(1 - percentile_rank(actual_lag, possible_lags))

            previous_item = current_item
            possible_items = possible_items[possible_items != previous_item]

    return float(np.nanmean(actual_ranks))


In [3]:
class RepetitionLagRank:
    """
    Generalize lag-rank analyses to account for multiple presentations of each item, but exclusively consider transitions from repeated items and separately measure rank percentiles relative to each presentation index of the repeated items.

    Assess lag-rank using the nearest study position to the reference study position used for tabulating lags. 
    """


    def __init__(
        self, presentations: np.ndarray, max_repetitions: int = 2, min_lag: int = 4
    ):
        "Pre-allocate arrays for lag-CRP tabulations."
        list_length = np.max(np.sum(presentations != 0, axis=1))
        self.lag_range = list_length - 1
        self.min_lag = min_lag
        self.max_repetitions = max_repetitions
        self.actual_ranks = [[] for _ in range(max_repetitions)]

    def should_tabulate(
        self,
        prev_study_positions: np.ndarray,
    ) -> bool:
        """Only consider transitions from item with at least two spaced-out study positions"""
        return len(prev_study_positions) > 1 and prev_study_positions[
            -1
        ] - prev_study_positions[-2] >= (self.min_lag + 1)

    def tabulate_lags(
        self,
        previous_item: int,
        current_item: int,
        possible_items: np.ndarray,
        item_study_positions: list[np.ndarray],
    ):
        "Tabulate actual and possible serial lags of current from previous item."

        prev_study_positions = item_study_positions[previous_item - 1]
        current_study_positions = item_study_positions[current_item - 1]
        for repetition_index, prev_study_position in enumerate(prev_study_positions):

            # first track the minimum lags of the actual transition made
            actual_lag = np.inf
            for current_study_position in current_study_positions:
                lag = np.abs(current_study_position - prev_study_position)
                actual_lag = np.minimum(lag, actual_lag)

            # then track minimum lags for each possible transition
            possible_lags = []
            for item in possible_items:
                possible_lag = np.inf
                possible_study_positions = item_study_positions[item - 1]
                for possible_study_position in possible_study_positions:
                    lag = np.abs(possible_study_position - prev_study_position)
                    possible_lag = np.minimum(lag, possible_lag)
                possible_lags.append(possible_lag)

            percent_rank = 1 - percentile_rank(actual_lag, possible_lags)
            # assert not np.isnan(percent_rank)
            self.actual_ranks[repetition_index].append(percent_rank)

    def tabulate_over_transitions(
        self, trials: np.ndarray, presentations: np.ndarray
    ) -> np.ndarray:
        "Tabulate actual and possible lag transitions over a set of trials"

        terminus = np.sum(trials != 0, axis=1)

        for trial_index in range(len(trials)):
            presentation = presentations[trial_index]
            item_count = np.max(presentation)
            possible_items = np.arange(item_count) + 1
            item_study_positions = njit_apply_along_axis(
                item_to_study_positions, possible_items, presentation
            )
            previous_item = 0

            for recall_index in range(terminus[trial_index]):
                current_item = presentation[trials[trial_index, recall_index] - 1]
                if recall_index > 0 and self.should_tabulate(
                    item_study_positions[previous_item - 1],
                ):
                    self.tabulate_lags(
                        previous_item,
                        current_item,
                        possible_items,
                        item_study_positions,
                    )

                previous_item = current_item
                possible_items = possible_items[possible_items != previous_item]

        return np.nanmean(self.actual_ranks, axis=1)
    
def replagrank(
    trials: np.ndarray, 
    presentations: np.ndarray,
    list_length: int,
    max_repetitions: int = 2,
    min_lag: int = 4,
) -> np.ndarray:
    "Summarize the tendency to group together nearby items by running a lag rank analysis."
    lag_rank = RepetitionLagRank(presentations, max_repetitions, min_lag)
    scores = lag_rank.tabulate_over_transitions(trials, presentations)
    # return scores[0] - scores[1]
    return scores

## Lohnas Kahana 2014 Dataset

In [4]:
data_name = "HowardKahana2005"
data_path = "data/HowardKahana2005.h5"
data_query = "data['condition'] == 1"

data = to_numba_typed_dict({key: np.array(value) for key, value in load_data(data_path).items()})
trial_mask = generate_trial_mask(data, data_query)

max_repetitions = 3
min_lag = 4

In [5]:
subject_values = np.array(apply_by_subject(
    data,
    trial_mask,
    replagrank,
    max_repetitions,
    min_lag,
))

subject_values

array([[0.55458256, 0.55935317, 0.57284621],
       [0.54180173, 0.59890606, 0.57060684],
       [0.53393008, 0.55153121, 0.53496213],
       [0.57366132, 0.63184187, 0.60833569],
       [0.81606475, 0.7471299 , 0.65207017],
       [0.59370536, 0.63382783, 0.62046825],
       [0.56615354, 0.69242874, 0.67287036],
       [0.58799843, 0.65033669, 0.70159359],
       [0.65934723, 0.63793931, 0.63931203],
       [0.5338701 , 0.59965742, 0.62899263],
       [0.57808731, 0.62113122, 0.61312485],
       [0.58809324, 0.61893685, 0.63010304],
       [0.66350915, 0.64017319, 0.61181625],
       [0.54788143, 0.58375505, 0.5761218 ],
       [0.58420092, 0.59557518, 0.61315352],
       [0.54054236, 0.58148119, 0.64388459],
       [0.65721444, 0.64702842, 0.68813806],
       [0.5530729 , 0.64742351, 0.60206152],
       [0.53660104, 0.59874394, 0.57247216],
       [0.52473799, 0.54955018, 0.61529558],
       [0.70289568, 0.71807535, 0.68180693],
       [0.59832981, 0.51317758, 0.55687888],
       [0.

In [7]:
import numpy as np
from scipy.stats import ttest_rel

# Example data in an Nx2 array
data = subject_values

# Split the data into two related samples
data1, data2 = data[:, 0], data[:, 2]

# Perform the paired t-test
t_statistic, p_value = ttest_rel(data1, data2, alternative="greater")

print(f"T-statistic: {t_statistic}")
print(f"Two-tailed P-value: {p_value}")

# Interpretation
if p_value < 0.05:
    print("There is a statistically significant difference between the two paired samples at the 5% significance level.")
else:
    print("There is no statistically significant difference between the two paired samples at the 5% significance level.")

print(f"Mean First Sample: {np.mean(data1)}")
print(f"Mean Second Sample: {np.mean(data2)}")
print(f"Standard Error First Sample: {np.std(data1) / np.sqrt(len(data1))}")
print(f"Standard Error Second Sample: {np.std(data2) / np.sqrt(len(data2))}")
print(f"Mean Difference: {np.mean(data1 - data2)}")
print(f"Standard Error Difference: {np.std(data1 - data2) / np.sqrt(len(data1))}")

T-statistic: -1.0248325864793413
Two-tailed P-value: 0.8453794324837075
There is no statistically significant difference between the two paired samples at the 5% significance level.
Mean First Sample: 0.5869558009074269
Mean Second Sample: 0.5939309807352077
Standard Error First Sample: 0.0071141390595325935
Standard Error Second Sample: 0.00646714770370032
Mean Difference: -0.00697517982778077
Standard Error Difference: 0.006754406482273385


## Simulation of CMR Fitted to Lohnas Kahana 2014 Dataset

In [73]:
from jaxcmr_research.helpers.hdf5 import simulate_h5_from_h5
from jax import random
from jaxcmr_research.helpers.hdf5 import generate_trial_mask, load_data
from jaxcmr_research.helpers.misc import summarize_parameters, import_from_string
import numpy as np
from jaxcmr_research.helpers.array import compute_similarity_matrix
from jax import numpy as jnp
import json
from IPython.display import Markdown  # type: ignore

data_name = "Cond34LohnasKahana2014"
data_path = "data/LohnasKahana2014.h5"
data_query = "data['list_type'] >= 3"
connection_path = "data/peers-all-mpnet-base-v2.npy"
experiment_count = 100
seed = 0
fit_result_path = (
    "notebooks/Model_Fitting/Cond34LohnasKahana2014_BaseCMR_Model_Fitting.json"
)


data = load_data(data_path)
trial_mask = generate_trial_mask(data, data_query)
embeddings = np.load(connection_path)
connections = compute_similarity_matrix(embeddings) # unused here
model_factory_path = "jaxcmr_research.cmr.BaseCMRFactory"
model_factory = import_from_string(model_factory_path)
with open(fit_result_path, "r") as f:
    results = json.load(f)
    if "subject" not in results["fits"]:
        results["fits"]["subject"] = results["subject"]


Markdown(
    summarize_parameters([results], None, include_std=True, include_ci=True)
)



| | | Cond34LohnasKahana2014 BaseCMR Model Fitting |
|---|---|---|
| fitness | mean | 817.31 +/- 70.82 |
| | std | 203.20 |
| item support | mean | 14.62 +/- 7.60 |
| | std | 21.82 |
| recall drift rate | mean | 0.91 +/- 0.05 |
| | std | 0.15 |
| learning rate | mean | 0.47 +/- 0.08 |
| | std | 0.23 |
| stop probability scale | mean | 0.02 +/- 0.01 |
| | std | 0.02 |
| semantic scale | mean | 0.00 +/- 0.00 |
| | std | 0.00 |
| stop probability growth | mean | 0.23 +/- 0.03 |
| | std | 0.10 |
| choice sensitivity | mean | 29.08 +/- 12.52 |
| | std | 35.92 |
| mcf trace sensitivity | mean | 1.00 +/- 0.00 |
| | std | 0.00 |
| shared support | mean | 8.95 +/- 5.74 |
| | std | 16.47 |
| primacy decay | mean | 20.74 +/- 10.79 |
| | std | 30.97 |
| primacy scale | mean | 11.76 +/- 8.89 |
| | std | 25.52 |
| start drift rate | mean | 0.61 +/- 0.11 |
| | std | 0.32 |
| mfc trace sensitivity | mean | 1.00 +/- 0.00 |
| | std | 0.00 |
| mfc choice sensitivity | mean | 1.00 +/- 0.00 |
| | std | 0.00 |
| encoding drift rate | mean | 0.71 +/- 0.05 |
| | std | 0.15 |
| semantic choice sensitivity | mean | 0.00 +/- 0.00 |
| | std | 0.00 |


In [74]:
rng = random.PRNGKey(seed)
rng, rng_iter = random.split(rng)
sim = simulate_h5_from_h5(
    model_factory=model_factory,
    dataset=data,
    connections=connections,
    parameters={key: jnp.array(val) for key, val in results["fits"].items()},
    trial_mask=trial_mask,
    experiment_count=experiment_count,
    rng=rng_iter,
)

In [85]:
subject_values = np.array(apply_by_subject(
    sim,
    generate_trial_mask(sim, data_query),
    replagrank,
    max_repetitions,
    min_lag,
))

subject_values

array([[0.75887673, 0.7302167 ],
       [0.69681988, 0.67291635],
       [0.70955402, 0.7050469 ],
       [0.66869269, 0.65605559],
       [0.66450965, 0.64945109],
       [0.69580354, 0.71301576],
       [0.64725494, 0.65425818],
       [0.62592041, 0.63509403],
       [0.72164041, 0.71636556],
       [0.64914463, 0.66624445],
       [0.80162998, 0.74868433],
       [0.65573286, 0.65381038],
       [0.62425521, 0.61936696],
       [0.77323767, 0.72573166],
       [0.64427101, 0.63551851],
       [0.7450803 , 0.72672831],
       [0.78015708, 0.75350908],
       [0.76767412, 0.74310528],
       [0.64175197, 0.64298237],
       [0.6487697 , 0.64219973],
       [0.63852844, 0.64632898],
       [0.63352686, 0.66216156],
       [0.49074815, 0.50246911],
       [0.67416181, 0.67416415],
       [0.72535477, 0.74980616],
       [0.65746941, 0.65224424],
       [0.78740716, 0.74231676],
       [0.62743774, 0.66551034],
       [0.64185827, 0.65641555],
       [0.73360079, 0.73135758],
       [0.

In [86]:
import numpy as np
from scipy.stats import ttest_rel

# Example data in an Nx2 array
data = subject_values

# Ensure the data has exactly two columns
if data.shape[1] != 2:
    raise ValueError("Data must have exactly two columns")

# Split the data into two related samples
data1, data2 = data[:, 0], data[:, 1]

# Perform the paired t-test
t_statistic, p_value = ttest_rel(data1, data2)

print(f"T-statistic: {t_statistic}")
print(f"Two-tailed P-value: {p_value}")

# Interpretation
if p_value < 0.05:
    print("There is a statistically significant difference between the two paired samples at the 5% significance level.")
else:
    print("There is no statistically significant difference between the two paired samples at the 5% significance level.")

print(f"Mean First Sample: {np.mean(data1)}")
print(f"Mean Second Sample: {np.mean(data2)}")
print(f"Standard Error First Sample: {np.std(data1) / np.sqrt(len(data1))}")
print(f"Standard Error Second Sample: {np.std(data2) / np.sqrt(len(data2))}")
print(f"Mean Difference: {np.mean(data1 - data2)}")
print(f"Standard Error Difference: {np.std(data1 - data2) / np.sqrt(len(data1))}")

T-statistic: 1.7770109995486576
Two-tailed P-value: 0.08451454111037006
There is no statistically significant difference between the two paired samples at the 5% significance level.
Mean First Sample: 0.6856308203240852
Mean Second Sample: 0.6792693290738305
Standard Error First Sample: 0.010549645458655145
Standard Error Second Sample: 0.008753099626656921
Mean Difference: 0.006361491250254665
Standard Error Difference: 0.0035283710856964506
