In [1]:
from numba import config
config.DISABLE_JIT = True
from numba import int32, njit
from numba.experimental import jitclass
from typing import Type
import numpy as np
from psifr.stats import percentile_rank
from compmempy.helpers.transforming_arrays import njit_apply_along_axis
from compmempy.helpers.handling_data import item_to_study_positions, apply_by_subject
from compmempy.helpers.loading_data import to_numba_typed_dict
from jaxcmr_research.helpers.hdf5 import generate_trial_mask, load_data

In [2]:
class RepetitionLagCRP:
    def __init__(
        self, presentations: np.ndarray, max_repetitions: int = 2, min_lag: int = 4
    ):
        "Pre-allocate arrays for lag-CRP tabulations."
        list_length = np.max(np.sum(presentations != 0, axis=1))
        self.lag_range = list_length - 1
        self.min_lag = min_lag
        self.max_repetitions = max_repetitions
        self.actual_lag_transitions = np.zeros(
            (max_repetitions, self.lag_range * 2 + 1), dtype=np.int32
        )
        self.possible_lag_transitions = np.zeros(
            (max_repetitions, self.lag_range * 2 + 1), dtype=np.int32
        )

    def should_tabulate(
        self,
        prev_study_positions: np.ndarray,
    ) -> bool:
        """Only consider transitions from item with at least two spaced-out study positions"""
        return len(prev_study_positions) > 1 and prev_study_positions[
            -1
        ] - prev_study_positions[-2] >= (self.min_lag + 1)

    def tabulate_lags(
        self,
        previous_item: int,
        current_item: int,
        possible_items: np.ndarray,
        item_study_positions: list[np.ndarray],
    ):
        "Tabulate actual and possible serial lags of current from previous item."

        possible_lags = np.zeros(
            (self.max_repetitions, self.lag_range * 2 + 1), dtype=np.bool_
        )
        actual_lags = np.zeros(
            (self.max_repetitions, self.lag_range * 2 + 1), dtype=np.bool_
        )

        prev_study_positions = item_study_positions[previous_item - 1]
        current_study_positions = item_study_positions[current_item - 1]
        for repetition_index, prev_study_position in enumerate(prev_study_positions):
            for current_study_position in current_study_positions:
                lag = current_study_position - prev_study_position
                actual_lags[repetition_index, lag + self.lag_range] = True

            for item in possible_items:
                possible_study_positions = item_study_positions[item - 1]
                for possible_study_position in possible_study_positions:
                    lag = possible_study_position - prev_study_position
                    possible_lags[repetition_index, lag + self.lag_range] = True

        self.actual_lag_transitions += actual_lags
        self.possible_lag_transitions += possible_lags

    def tabulate_over_transitions(
        self, trials: np.ndarray, presentations: np.ndarray
    ) -> np.ndarray:
        "Tabulate actual and possible lag transitions over a set of trials"

        terminus = np.sum(trials != 0, axis=1)

        for trial_index in range(len(trials)):
            presentation = presentations[trial_index]
            item_count = np.max(presentation)
            possible_items = np.arange(item_count) + 1
            item_study_positions = njit_apply_along_axis(
                item_to_study_positions, possible_items, presentation
            )
            previous_item = 0

            for recall_index in range(terminus[trial_index]):
                current_item = presentation[trials[trial_index, recall_index] - 1]
                if recall_index > 0 and self.should_tabulate(
                    item_study_positions[previous_item - 1],
                ):
                    self.tabulate_lags(
                        previous_item,
                        current_item,
                        possible_items,
                        item_study_positions,
                    )

                previous_item = current_item
                possible_items = possible_items[possible_items != previous_item]

        return self.actual_lag_transitions / self.possible_lag_transitions


repetition_crp_spec = [
    ("lag_range", int32),
    ("max_repetitions", int32),
    ("min_lag", int32),
    ("actual_lag_transitions", int32[:, ::1]),
    ("possible_lag_transitions", int32[:, ::1]),
]

numba_RepetitionLagCRP: Type[RepetitionLagCRP] = jitclass(repetition_crp_spec)(
    RepetitionLagCRP
)  # type: ignore


@njit
def repcrp(
    trials: np.ndarray,
    presentations: np.ndarray,
    list_length: int,
    max_repetitions: int = 2,
    min_lag: int = 4,
) -> np.ndarray:
    "Apply the lag-CRP to a set of trials where each item has a single serial position."
    return numba_RepetitionLagCRP(
        presentations, max_repetitions, min_lag
    ).tabulate_over_transitions(trials, presentations)

## Lohnas Kahana 2014 Dataset

In [14]:
data_name = "LohnasKahana2014"
data_path = "data/LohnasKahana2014.h5"
data_query = "data['list_type'] >= 3"

data = to_numba_typed_dict(
    {key: np.array(value) for key, value in load_data(data_path).items()}
)
trial_mask = generate_trial_mask(data, data_query)
list_length = data["listLength"][0, 0]

max_repetitions = 2
min_lag = 4

In [15]:
subject_values = np.array(
    apply_by_subject(
        data,
        trial_mask,
        repcrp,
        max_repetitions,
        min_lag,
    )
)#[:, :, list_length]

subject_values

  return self.actual_lag_transitions / self.possible_lag_transitions


array([[[       nan,        nan,        nan, ..., 0.        ,
         0.        , 0.        ],
        [       nan,        nan,        nan, ...,        nan,
                nan,        nan]],

       [[       nan,        nan,        nan, ..., 0.        ,
         0.        , 0.        ],
        [0.        , 0.25      , 0.        , ...,        nan,
                nan,        nan]],

       [[       nan,        nan,        nan, ..., 0.        ,
                nan,        nan],
        [0.        , 0.        , 0.        , ...,        nan,
                nan,        nan]],

       ...,

       [[       nan,        nan,        nan, ..., 0.        ,
         0.        , 0.        ],
        [0.        , 0.16666667, 0.        , ...,        nan,
                nan,        nan]],

       [[       nan,        nan,        nan, ..., 0.        ,
         0.        , 0.        ],
        [0.        , 0.        , 0.        , ...,        nan,
                nan,        nan]],

       [[       n

In [5]:
import numpy as np
from scipy.stats import ttest_rel

# Example data in an Nx2 array
data = subject_values

# Ensure the data has exactly two columns
if data.shape[1] != 2:
    raise ValueError("Data must have exactly two columns")

# Split the data into two related samples
data1, data2 = data[:, 0], data[:, 1]

# Perform the paired t-test
t_statistic, p_value = ttest_rel(data1, data2)

print(f"T-statistic: {t_statistic}")
print(f"Two-tailed P-value: {p_value}")

# Interpretation
if p_value < 0.05:
    print(
        "There is a statistically significant difference between the two paired samples at the 5% significance level."
    )
else:
    print(
        "There is no statistically significant difference between the two paired samples at the 5% significance level."
    )

print(f"Mean First Sample: {np.mean(data1)}")
print(f"Mean Second Sample: {np.mean(data2)}")
print(f"Standard Error First Sample: {np.std(data1) / np.sqrt(len(data1))}")
print(f"Standard Error Second Sample: {np.std(data2) / np.sqrt(len(data2))}")
print(f"Mean Difference: {np.mean(data1 - data2)}")
print(f"Standard Error Difference: {np.std(data1 - data2) / np.sqrt(len(data1))}")

T-statistic: 5.15935420116019
Two-tailed P-value: 1.0671446175294915e-05
There is a statistically significant difference between the two paired samples at the 5% significance level.
Mean First Sample: 0.24652416043918732
Mean Second Sample: 0.14889906025083952
Standard Error First Sample: 0.020744339470079483
Standard Error Second Sample: 0.007618757463565077
Mean Difference: 0.09762510018834776
Standard Error Difference: 0.01864968860435012


## Simulation of CMR Fitted to Lohnas Kahana 2014 Dataset

In [10]:
from jaxcmr_research.helpers.hdf5 import simulate_h5_from_h5
from jax import random
from jaxcmr_research.helpers.hdf5 import generate_trial_mask, load_data
from jaxcmr_research.helpers.misc import summarize_parameters, import_from_string
import numpy as np
from jaxcmr_research.helpers.array import compute_similarity_matrix
from jax import numpy as jnp
import json
from IPython.display import Markdown  # type: ignore

data_name = "Cond34LohnasKahana2014"
data_path = "data/LohnasKahana2014.h5"
data_query = "data['list_type'] >= 3"
connection_path = "data/peers-all-mpnet-base-v2.npy"
experiment_count = 100
seed = 0
fit_result_path = (
    "notebooks/Model_Fitting/Cond34LohnasKahana2014_BaseCMR_Model_Fitting.json"
)


data = load_data(data_path)
trial_mask = generate_trial_mask(data, data_query)
embeddings = np.load(connection_path)
connections = compute_similarity_matrix(embeddings)  # unused here
model_factory_path = "jaxcmr_research.cmr.BaseCMRFactory"
model_factory = import_from_string(model_factory_path)
with open(fit_result_path, "r") as f:
    results = json.load(f)
    if "subject" not in results["fits"]:
        results["fits"]["subject"] = results["subject"]


Markdown(summarize_parameters([results], None, include_std=True, include_ci=True))

| | | Cond34LohnasKahana2014 BaseCMR Model Fitting |
|---|---|---|
| fitness | mean | 817.31 +/- 70.82 |
| | std | 203.20 |
| primacy scale | mean | 11.76 +/- 8.89 |
| | std | 25.52 |
| mfc choice sensitivity | mean | 1.00 +/- 0.00 |
| | std | 0.00 |
| semantic scale | mean | 0.00 +/- 0.00 |
| | std | 0.00 |
| shared support | mean | 8.95 +/- 5.74 |
| | std | 16.47 |
| semantic choice sensitivity | mean | 0.00 +/- 0.00 |
| | std | 0.00 |
| choice sensitivity | mean | 29.08 +/- 12.52 |
| | std | 35.92 |
| primacy decay | mean | 20.74 +/- 10.79 |
| | std | 30.97 |
| stop probability scale | mean | 0.02 +/- 0.01 |
| | std | 0.02 |
| stop probability growth | mean | 0.23 +/- 0.03 |
| | std | 0.10 |
| mfc trace sensitivity | mean | 1.00 +/- 0.00 |
| | std | 0.00 |
| learning rate | mean | 0.47 +/- 0.08 |
| | std | 0.23 |
| mcf trace sensitivity | mean | 1.00 +/- 0.00 |
| | std | 0.00 |
| item support | mean | 14.62 +/- 7.60 |
| | std | 21.82 |
| encoding drift rate | mean | 0.71 +/- 0.05 |
| | std | 0.15 |
| recall drift rate | mean | 0.91 +/- 0.05 |
| | std | 0.15 |
| start drift rate | mean | 0.61 +/- 0.11 |
| | std | 0.32 |


In [11]:
rng = random.PRNGKey(seed)
rng, rng_iter = random.split(rng)
sim = simulate_h5_from_h5(
    model_factory=model_factory,
    dataset=data,
    connections=connections,
    parameters={key: jnp.array(val) for key, val in results["fits"].items()},
    trial_mask=trial_mask,
    experiment_count=experiment_count,
    rng=rng_iter,
)

In [12]:
subject_values = np.array(
    apply_by_subject(
        to_numba_typed_dict({key: np.array(value) for key, value in sim.items()}),
        generate_trial_mask(sim, data_query),
        repcrp,
        max_repetitions,
        min_lag,
    )
)[:, :, list_length]

subject_values

  return self.actual_lag_transitions / self.possible_lag_transitions


array([[0.37691602, 0.24585006],
       [0.20507914, 0.15685939],
       [0.2324383 , 0.2078476 ],
       [0.14615893, 0.12893448],
       [0.1418526 , 0.12987667],
       [0.26139264, 0.23840021],
       [0.13619954, 0.12609457],
       [0.19366251, 0.18171823],
       [0.25613003, 0.21843415],
       [0.1783601 , 0.18394196],
       [0.38746607, 0.25781398],
       [0.1863354 , 0.16498757],
       [0.19351464, 0.16460851],
       [0.30263329, 0.22005194],
       [0.17774365, 0.14188103],
       [0.29895324, 0.2251898 ],
       [0.34814036, 0.2730383 ],
       [0.28106852, 0.23035675],
       [0.13874346, 0.131178  ],
       [0.16557792, 0.15006693],
       [0.23768206, 0.20293017],
       [0.16209428, 0.17712692],
       [0.05846103, 0.06214786],
       [0.23163217, 0.19284526],
       [0.14168378, 0.12849003],
       [0.11863967, 0.10735954],
       [0.37306226, 0.26053499],
       [0.13778931, 0.14711447],
       [0.13638824, 0.13370119],
       [0.25534163, 0.23407895],
       [0.

In [13]:
import numpy as np
from scipy.stats import ttest_rel

# Example data in an Nx2 array
data = subject_values

# Ensure the data has exactly two columns
if data.shape[1] != 2:
    raise ValueError("Data must have exactly two columns")

# Split the data into two related samples
data1, data2 = data[:, 0], data[:, 1]

# Perform the paired t-test
t_statistic, p_value = ttest_rel(data1, data2)

print(f"T-statistic: {t_statistic}")
print(f"Two-tailed P-value: {p_value}")

# Interpretation
if p_value < 0.05:
    print(
        "There is a statistically significant difference between the two paired samples at the 5% significance level."
    )
else:
    print(
        "There is no statistically significant difference between the two paired samples at the 5% significance level."
    )

print(f"Mean First Sample: {np.mean(data1)}")
print(f"Mean Second Sample: {np.mean(data2)}")
print(f"Standard Error First Sample: {np.std(data1) / np.sqrt(len(data1))}")
print(f"Standard Error Second Sample: {np.std(data2) / np.sqrt(len(data2))}")
print(f"Mean Difference: {np.mean(data1 - data2)}")
print(f"Standard Error Difference: {np.std(data1 - data2) / np.sqrt(len(data1))}")

T-statistic: 5.147517335829471
Two-tailed P-value: 1.1055947734682667e-05
There is a statistically significant difference between the two paired samples at the 5% significance level.
Mean First Sample: 0.2091877481178868
Mean Second Sample: 0.17731821360867364
Standard Error First Sample: 0.0136460584415815
Standard Error Second Sample: 0.008723339539187561
Mean Difference: 0.03186953450921316
Standard Error Difference: 0.006102156456909528
