In [37]:
from numba import config

config.DISABLE_JIT = True

import numpy as np
from psifr.stats import percentile_rank
from compmempy.helpers.transforming_arrays import njit_apply_along_axis
from compmempy.helpers.handling_data import item_to_study_positions, apply_by_subject
from compmempy.helpers.loading_data import to_numba_typed_dict
from jaxcmr_research.helpers.hdf5 import generate_trial_mask, load_data

from compmempy.analyses.repcrp import repcrp


def score_rep_crp(
    data: dict[str, np.ndarray],
    trial_mask: np.ndarray,
    max_repetitions: int = 3,
    min_lag: int = 4,
):
    subject_values = apply_by_subject(
        data,
        trial_mask,
        repcrp,
        max_repetitions,
        min_lag,
    )

    list_length = np.max(data["listLength"])
    # return (np.array(subject_values)[:, :, list_length] + np.array(subject_values)[:, :, list_length+1] + np.array(subject_values)[:, :, list_length+2]) / 3
    return np.array(subject_values)[:, :, list_length]

## Base Dataset

In [48]:
data_name = "LohnasKahana2014"
data_path = "data/LohnasKahana2014.h5"
data_query = "data['list_type'] == 4"

data = to_numba_typed_dict(
    {key: np.array(value) for key, value in load_data(data_path).items()}
)
trial_mask = generate_trial_mask(data, data_query)

max_repetitions = 2
min_lag = 4

In [49]:
subject_values = score_rep_crp(data, trial_mask, max_repetitions, min_lag)

subject_values

array([[0.26315789, 0.15      ],
       [0.16666667, 0.11111111],
       [0.15      , 0.        ],
       [0.18181818, 0.        ],
       [0.11111111, 0.05263158],
       [0.28571429, 0.16666667],
       [0.        , 0.22222222],
       [0.16666667, 0.13043478],
       [0.25      , 0.        ],
       [0.19047619, 0.15384615],
       [0.46875   , 0.03125   ],
       [0.22222222, 0.11111111],
       [0.05263158, 0.04347826],
       [0.23809524, 0.13043478],
       [0.14285714, 0.05      ],
       [0.41176471, 0.        ],
       [0.17391304, 0.04545455],
       [0.16666667, 0.14285714],
       [0.        , 0.        ],
       [0.2       , 0.04761905],
       [0.        , 0.        ],
       [0.19047619, 0.11764706],
       [0.        , 0.05263158],
       [0.33333333, 0.        ],
       [0.        , 0.08333333],
       [0.11111111, 0.        ],
       [0.35      , 0.09090909],
       [0.        , 0.09090909],
       [0.125     , 0.06666667],
       [0.17647059, 0.05      ],
       [0.

In [50]:
difference = subject_values
data1, data2 = difference[:, 0], difference[:, 1]

In [51]:
from scipy.stats import ttest_rel

t_statistic, p_value = ttest_rel(data1, data2, alternative="greater")
print(f"T-statistic: {t_statistic}")
print(f"One-tailed P-value: {p_value}")

if p_value < 0.05:
    print(
        "There is a statistically significant difference between the two paired samples at the 5% significance level."
    )
else:
    print(
        "There is no statistically significant difference between the two paired samples at the 5% significance level."
    )


print(f"Mean First Sample: {np.mean(data1)}")
print(f"Mean Second Sample: {np.mean(data2)}")
print(f"Standard Error First Sample: {np.std(data1) / np.sqrt(len(data1))}")
print(f"Standard Error Second Sample: {np.std(data2) / np.sqrt(len(data2))}")
print(f"Mean Difference: {np.mean(data1 - data2)}")
print(f"Standard Error Difference: {np.std(data1 - data2) / np.sqrt(len(data1))}")

T-statistic: 3.7039144020051795
One-tailed P-value: 0.00037491219571755445
There is a statistically significant difference between the two paired samples at the 5% significance level.
Mean First Sample: 0.17126714961640352
Mean Second Sample: 0.08479366241223729
Standard Error First Sample: 0.01979637485332103
Standard Error Second Sample: 0.01289293099460334
Mean Difference: 0.08647348720416619
Standard Error Difference: 0.023010574825834136


In [40]:
from jaxcmr_research.helpers.repetition import control_dataset

ctrl_data = control_dataset(
    to_numba_typed_dict({key: np.array(value) for key, value in data.items()}),
    "data['list_type'] == 4",
    "data['list_type'] == 1",
    100,
)

ctrl_subject_values = score_rep_crp(
    ctrl_data,
    generate_trial_mask(ctrl_data, "data['subject'] != -1"),
    max_repetitions,
    min_lag,
)

ctrl_subject_values

35it [00:00, 1373.29it/s]


array([[0.21573698, 0.19689462],
       [0.09202013, 0.16529492],
       [0.1202211 , 0.15475649],
       [0.18223235, 0.06800446],
       [0.08936826, 0.13654189],
       [0.15721531, 0.15931907],
       [0.17507418, 0.17907445],
       [0.09120699, 0.08237644],
       [0.15384615, 0.31143399],
       [0.1085297 , 0.07787934],
       [0.41376812, 0.23415265],
       [0.09529652, 0.11083229],
       [0.10440395, 0.12709832],
       [0.32299084, 0.22008253],
       [0.05882353, 0.03554377],
       [0.21194503, 0.25513196],
       [0.20602767, 0.10147601],
       [0.2166157 , 0.14165103],
       [0.06846673, 0.06994329],
       [0.10475651, 0.08972353],
       [0.09467456, 0.17736185],
       [0.13864307, 0.1046832 ],
       [0.06886228, 0.05563798],
       [0.1671415 , 0.18703108],
       [0.10068493, 0.09970015],
       [0.16198126, 0.02203182],
       [0.22889007, 0.18959811],
       [0.13362069, 0.06505421],
       [0.11565696, 0.15073815],
       [0.1191446 , 0.11994003],
       [0.

In [41]:
difference = subject_values - ctrl_subject_values
data1, data2 = difference[:, 0], difference[:, 1]

In [42]:
from scipy.stats import ttest_rel

t_statistic, p_value = ttest_rel(data1, data2, alternative="greater")
print(f"T-statistic: {t_statistic}")
print(f"One-tailed P-value: {p_value}")

if p_value < 0.05:
    print(
        "There is a statistically significant difference between the two paired samples at the 5% significance level."
    )
else:
    print(
        "There is no statistically significant difference between the two paired samples at the 5% significance level."
    )


print(f"Mean First Sample: {np.mean(data1)}")
print(f"Mean Second Sample: {np.mean(data2)}")
print(f"Standard Error First Sample: {np.std(data1) / np.sqrt(len(data1))}")
print(f"Standard Error Second Sample: {np.std(data2) / np.sqrt(len(data2))}")
print(f"Mean Difference: {np.mean(data1 - data2)}")
print(f"Standard Error Difference: {np.std(data1 - data2) / np.sqrt(len(data1))}")

T-statistic: 2.8377861275499554
One-tailed P-value: 0.0038033671943909953
There is a statistically significant difference between the two paired samples at the 5% significance level.
Mean First Sample: 0.019395124297992707
Mean Second Sample: -0.05218501489244075
Standard Error First Sample: 0.014968251854968036
Standard Error Second Sample: 0.015492476579109554
Mean Difference: 0.07158013919043346
Standard Error Difference: 0.024860983971650943


## Base CMR

In [43]:
from jaxcmr_research.helpers.hdf5 import simulate_h5_from_h5
from jax import random
from jaxcmr_research.helpers.hdf5 import generate_trial_mask, load_data
from jaxcmr_research.helpers.misc import summarize_parameters, import_from_string
import numpy as np
from jaxcmr_research.helpers.array import compute_similarity_matrix
from jax import numpy as jnp
import json
from IPython.display import Markdown  # type: ignore

data_name = "Cond34LohnasKahana2014"
data_path = "data/LohnasKahana2014.h5"
data_query = "data['list_type'] >= 3"
connection_path = "data/peers-all-mpnet-base-v2.npy"
experiment_count = 1
seed = 0
fit_result_path = (
    "notebooks/Model_Fitting/Cond4LohnasKahana2014_BaseCMR_Model_Fitting.json"
)


data = load_data(data_path)
trial_mask = generate_trial_mask(data, data_query)
embeddings = np.load(connection_path)
connections = compute_similarity_matrix(embeddings)  # unused here
model_factory_path = "jaxcmr_research.cmr.BaseCMRFactory"
model_factory = import_from_string(model_factory_path)
with open(fit_result_path, "r") as f:
    results = json.load(f)
    if "subject" not in results["fits"]:
        results["fits"]["subject"] = results["subject"]


Markdown(summarize_parameters([results], None, include_std=True, include_ci=True))

| | | Cond4LohnasKahana2014 BaseCMR Model Fitting |
|---|---|---|
| fitness | mean | 470.79 +/- 50.44 |
| | std | 144.72 |
| mfc trace sensitivity | mean | 1.00 +/- 0.00 |
| | std | 0.00 |
| start drift rate | mean | 0.45 +/- 0.12 |
| | std | 0.35 |
| mcf trace sensitivity | mean | 1.00 +/- 0.00 |
| | std | 0.00 |
| recall drift rate | mean | 0.89 +/- 0.06 |
| | std | 0.17 |
| primacy decay | mean | 13.34 +/- 9.57 |
| | std | 27.46 |
| semantic choice sensitivity | mean | 0.00 +/- 0.00 |
| | std | 0.00 |
| encoding drift rate | mean | 0.73 +/- 0.06 |
| | std | 0.16 |
| learning rate | mean | 0.39 +/- 0.07 |
| | std | 0.21 |
| item support | mean | 8.18 +/- 2.96 |
| | std | 8.49 |
| stop probability scale | mean | 0.02 +/- 0.01 |
| | std | 0.03 |
| mfc choice sensitivity | mean | 1.00 +/- 0.00 |
| | std | 0.00 |
| primacy scale | mean | 18.93 +/- 9.81 |
| | std | 28.15 |
| semantic scale | mean | 0.00 +/- 0.00 |
| | std | 0.00 |
| choice sensitivity | mean | 35.07 +/- 12.97 |
| | std | 37.22 |
| stop probability growth | mean | 0.24 +/- 0.04 |
| | std | 0.11 |
| shared support | mean | 7.86 +/- 5.05 |
| | std | 14.49 |


In [44]:
rng = random.PRNGKey(seed)
rng, rng_iter = random.split(rng)
sim = simulate_h5_from_h5(
    model_factory=model_factory,
    dataset=data,
    connections=connections,
    parameters={key: jnp.array(val) for key, val in results["fits"].items()},
    trial_mask=generate_trial_mask(data, 'data["subject"] != -1'),
    experiment_count=experiment_count,
    rng=rng_iter,
)

In [45]:
subject_values = score_rep_crp(
    sim, generate_trial_mask(sim, "data['list_type'] == 4"), max_repetitions, min_lag
)

ctrl_sim = control_dataset(
    to_numba_typed_dict({key: np.array(val) for key, val in sim.items()}),
    "data['list_type'] == 4",
    "data['list_type'] == 1",
    100,
)

ctrl_subject_values = score_rep_crp(
    ctrl_sim,
    generate_trial_mask(ctrl_sim, "data['subject'] != -1"),
    max_repetitions,
    min_lag,
)

  return self.actual_lag_transitions / self.possible_lag_transitions
35it [00:00, 1379.58it/s]
  return self.actual_lag_transitions / self.possible_lag_transitions


In [46]:
difference = subject_values - ctrl_subject_values
data1, data2 = difference[:, 0], difference[:, 1]

In [47]:
from scipy.stats import ttest_rel

t_statistic, p_value = ttest_rel(data1, data2, alternative="greater")
print(f"T-statistic: {t_statistic}")
print(f"One-tailed P-value: {p_value}")

if p_value < 0.05:
    print(
        "There is a statistically significant difference between the two paired samples at the 5% significance level."
    )
else:
    print(
        "There is no statistically significant difference between the two paired samples at the 5% significance level."
    )


print(f"Mean First Sample: {np.mean(data1)}")
print(f"Mean Second Sample: {np.mean(data2)}")
print(f"Standard Error First Sample: {np.std(data1) / np.sqrt(len(data1))}")
print(f"Standard Error Second Sample: {np.std(data2) / np.sqrt(len(data2))}")
print(f"Mean Difference: {np.mean(data1 - data2)}")
print(f"Standard Error Difference: {np.std(data1 - data2) / np.sqrt(len(data1))}")

T-statistic: 1.0720050865921524
One-tailed P-value: 0.14563413572570685
There is no statistically significant difference between the two paired samples at the 5% significance level.
Mean First Sample: 0.04147011354341045
Mean Second Sample: 0.014878364318236576
Standard Error First Sample: 0.01500577943596517
Standard Error Second Sample: 0.015600549327887013
Mean Difference: 0.02659174922517387
Standard Error Difference: 0.024448684532896305
