In [13]:
from numba import config

config.DISABLE_JIT = True

import numpy as np
from psifr.stats import percentile_rank
from compmempy.helpers.transforming_arrays import njit_apply_along_axis
from compmempy.helpers.handling_data import item_to_study_positions, apply_by_subject
from compmempy.helpers.loading_data import to_numba_typed_dict
from jaxcmr_research.helpers.hdf5 import generate_trial_mask, load_data

from compmempy.analyses.repcrp import repcrp


def score_rep_crp(
    data: dict[str, np.ndarray],
    trial_mask: np.ndarray,
    max_repetitions: int = 3,
    min_lag: int = 4,
):
    subject_values = apply_by_subject(
        data,
        trial_mask,
        repcrp,
        max_repetitions,
        min_lag,
    )

    list_length = np.max(data["listLength"])
    # return (np.array(subject_values)[:, :, list_length] + np.array(subject_values)[:, :, list_length+1] + np.array(subject_values)[:, :, list_length+2]) / 3
    return np.array(subject_values)[:, :, list_length]

## Base Dataset

In [14]:
data_name = "HowardKahana2005"
data_path = "data/HowardKahana2005.h5"
data_query = "data['condition'] > 0"

data = to_numba_typed_dict(
    {key: np.array(value) for key, value in load_data(data_path).items()}
)
trial_mask = generate_trial_mask(data, data_query)

max_repetitions = 3
min_lag = 4

In [15]:
subject_values = score_rep_crp(data, trial_mask, max_repetitions, min_lag)

subject_values

array([[0.15384615, 0.1       , 0.16216216],
       [0.09615385, 0.05660377, 0.08510638],
       [0.09375   , 0.12307692, 0.05172414],
       [0.17460317, 0.17647059, 0.14925373],
       [0.45833333, 0.29464286, 0.13114754],
       [0.16483516, 0.14285714, 0.09677419],
       [0.07692308, 0.14634146, 0.05882353],
       [0.06521739, 0.09090909, 0.20512821],
       [0.06153846, 0.12698413, 0.18333333],
       [0.07142857, 0.16      , 0.08      ],
       [0.12      , 0.15686275, 0.02777778],
       [0.11290323, 0.12121212, 0.10294118],
       [0.11764706, 0.09433962, 0.0625    ],
       [0.08333333, 0.13043478, 0.13636364],
       [0.15730337, 0.11827957, 0.13580247],
       [0.05      , 0.01754386, 0.16981132],
       [0.10714286, 0.03846154, 0.10344828],
       [0.06521739, 0.16326531, 0.06666667],
       [0.10169492, 0.12068966, 0.09259259],
       [0.        , 0.0952381 , 0.10526316],
       [0.27272727, 0.15      , 0.13636364],
       [0.15625   , 0.07407407, 0.03703704],
       [0.

In [16]:
difference = subject_values
data1, data2 = difference[:, 0], difference[:, 1]

In [17]:
from scipy.stats import ttest_rel

t_statistic, p_value = ttest_rel(data1, data2, alternative="greater")
print(f"T-statistic: {t_statistic}")
print(f"One-tailed P-value: {p_value}")

if p_value < 0.05:
    print(
        "There is a statistically significant difference between the two paired samples at the 5% significance level."
    )
else:
    print(
        "There is no statistically significant difference between the two paired samples at the 5% significance level."
    )


print(f"Mean First Sample: {np.mean(data1)}")
print(f"Mean Second Sample: {np.mean(data2)}")
print(f"Standard Error First Sample: {np.std(data1) / np.sqrt(len(data1))}")
print(f"Standard Error Second Sample: {np.std(data2) / np.sqrt(len(data2))}")
print(f"Mean Difference: {np.mean(data1 - data2)}")
print(f"Standard Error Difference: {np.std(data1 - data2) / np.sqrt(len(data1))}")

T-statistic: 2.9086583323134523
One-tailed P-value: 0.0024814113527138544
There is a statistically significant difference between the two paired samples at the 5% significance level.
Mean First Sample: 0.11812624891416693
Mean Second Sample: 0.09773469670491197
Standard Error First Sample: 0.007758963650018588
Standard Error Second Sample: 0.005670229294661995
Mean Difference: 0.02039155220925494
Standard Error Difference: 0.006957324927386909


## Base CMR

In [19]:
from jaxcmr_research.helpers.hdf5 import simulate_h5_from_h5
from jax import random
from jaxcmr_research.helpers.hdf5 import generate_trial_mask, load_data
from jaxcmr_research.helpers.misc import summarize_parameters, import_from_string
import numpy as np
from jaxcmr_research.helpers.array import compute_similarity_matrix
from jax import numpy as jnp
import json
from IPython.display import Markdown  # type: ignore

data_name = "HowardKahana2005"
data_path = "data/HowardKahana2005.h5"
data_query = "data['condition'] > 0"
connection_path = "data/peers-all-mpnet-base-v2.npy"
experiment_count = 1
seed = 0
fit_result_path = (
    "notebooks/Model_Fitting//HowardKahana2005_BaseCMR_Model_Fitting.json"
)


data = load_data(data_path)
trial_mask = generate_trial_mask(data, data_query)
embeddings = np.load(connection_path)
connections = compute_similarity_matrix(embeddings)  # unused here
model_factory_path = "jaxcmr_research.cmr.BaseCMRFactory"
model_factory = import_from_string(model_factory_path)
with open(fit_result_path, "r") as f:
    results = json.load(f)
    if "subject" not in results["fits"]:
        results["fits"]["subject"] = results["subject"]


Markdown(summarize_parameters([results], None, include_std=True, include_ci=True))

| | | HowardKahana2005 BaseCMR Model Fitting |
|---|---|---|
| fitness | mean | 323.03 +/- 23.84 |
| | std | 96.23 |
| mfc choice sensitivity | mean | 1.00 +/- 0.00 |
| | std | 0.00 |
| start drift rate | mean | 0.47 +/- 0.08 |
| | std | 0.33 |
| semantic scale | mean | 0.00 +/- 0.00 |
| | std | 0.00 |
| primacy scale | mean | 12.77 +/- 5.16 |
| | std | 20.83 |
| choice sensitivity | mean | 51.03 +/- 7.74 |
| | std | 31.26 |
| semantic choice sensitivity | mean | 0.00 +/- 0.00 |
| | std | 0.00 |
| stop probability growth | mean | 0.32 +/- 0.03 |
| | std | 0.11 |
| item support | mean | 30.95 +/- 6.28 |
| | std | 25.34 |
| mcf trace sensitivity | mean | 1.00 +/- 0.00 |
| | std | 0.00 |
| shared support | mean | 23.77 +/- 4.86 |
| | std | 19.63 |
| stop probability scale | mean | 0.02 +/- 0.01 |
| | std | 0.02 |
| mfc trace sensitivity | mean | 1.00 +/- 0.00 |
| | std | 0.00 |
| learning rate | mean | 0.23 +/- 0.08 |
| | std | 0.31 |
| encoding drift rate | mean | 0.61 +/- 0.06 |
| | std | 0.26 |
| primacy decay | mean | 31.36 +/- 7.68 |
| | std | 30.99 |
| recall drift rate | mean | 0.84 +/- 0.05 |
| | std | 0.19 |


In [20]:
rng = random.PRNGKey(seed)
rng, rng_iter = random.split(rng)
sim = simulate_h5_from_h5(
    model_factory=model_factory,
    dataset=data,
    connections=connections,
    parameters={key: jnp.array(val) for key, val in results["fits"].items()},
    trial_mask=generate_trial_mask(data, 'data["subject"] != -1'),
    experiment_count=experiment_count,
    rng=rng_iter,
)

In [21]:
subject_values = score_rep_crp(
    sim, generate_trial_mask(sim, "data['condition'] > 0"), max_repetitions, min_lag
)

  return self.actual_lag_transitions / self.possible_lag_transitions


In [22]:
difference = subject_values
data1, data2 = difference[:, 0], difference[:, 1]

In [23]:
from scipy.stats import ttest_rel

t_statistic, p_value = ttest_rel(data1, data2, alternative="greater")
print(f"T-statistic: {t_statistic}")
print(f"One-tailed P-value: {p_value}")

if p_value < 0.05:
    print(
        "There is a statistically significant difference between the two paired samples at the 5% significance level."
    )
else:
    print(
        "There is no statistically significant difference between the two paired samples at the 5% significance level."
    )


print(f"Mean First Sample: {np.mean(data1)}")
print(f"Mean Second Sample: {np.mean(data2)}")
print(f"Standard Error First Sample: {np.std(data1) / np.sqrt(len(data1))}")
print(f"Standard Error Second Sample: {np.std(data2) / np.sqrt(len(data2))}")
print(f"Mean Difference: {np.mean(data1 - data2)}")
print(f"Standard Error Difference: {np.std(data1 - data2) / np.sqrt(len(data1))}")

T-statistic: 0.9331725226465978
One-tailed P-value: 0.17709222753926546
There is no statistically significant difference between the two paired samples at the 5% significance level.
Mean First Sample: 0.09806671341852317
Mean Second Sample: 0.09234026106294912
Standard Error First Sample: 0.006925938336803495
Standard Error Second Sample: 0.00690530373562172
Mean Difference: 0.00572645235557405
Standard Error Difference: 0.006089875579766138
