Let's try something a little more incremental. Let's get a faster initialization first.

## Model

In [1]:
# export
# hide

import numpy as np
import math
from numba import float64, int32, int64, boolean
from numba.experimental import jitclass

cmr_spec = [
    ('item_count', int32), 
    ('encoding_drift_rate', float64),
    ('start_drift_rate', float64),
    ('recall_drift_rate', float64),
    ('shared_support', float64),
    ('item_support', float64),
    ('learning_rate', float64),
    ('primacy_scale', float64),
    ('primacy_decay', float64),
    ('stop_probability_scale', float64),
    ('stop_probability_growth', float64),
    ('choice_sensitivity', float64),
    ('context', float64[:, ::1]),
    ('recall', int32[:, ::1]),
    ('retrieving', boolean),
    ('recall_total', int32),
    ('primacy_weighting', float64[::1]),
    ('probabilities', float64[:, ::1]),
    ('mfc', float64[:, :, ::1]),
    ('mcf', float64[:, :, ::1]),
    ('encoding_index', int32),
    ('items', float64[:,::1]),
    ('trial_count', int64),
    ('context_input', float64[:,::1])
]

In [2]:
# export

@jitclass(cmr_spec)
class CMR:

    def __init__(self, item_count, presentation_count, trial_count, 
                 encoding_drift_rate, start_drift_rate, recall_drift_rate, 
                 shared_support, item_support, learning_rate, primacy_scale, 
                 primacy_decay, stop_probability_scale, 
                 stop_probability_growth, choice_sensitivity):
        
        # store initial parameters
        self.item_count = item_count
        self.encoding_drift_rate = encoding_drift_rate
        self.start_drift_rate = start_drift_rate
        self.recall_drift_rate = recall_drift_rate
        self.shared_support = shared_support
        self.item_support = item_support
        self.learning_rate = learning_rate
        self.primacy_scale = primacy_scale
        self.primacy_decay = primacy_decay
        self.stop_probability_scale = stop_probability_scale
        self.stop_probability_growth = stop_probability_growth
        self.choice_sensitivity = choice_sensitivity
        
        # at the start of the list context is initialized with a state 
        # orthogonal to the pre-experimental context associated with the items
        self.context = np.zeros((trial_count, item_count+1))
        self.context[:, 0] = 1
        self.recall = np.zeros((trial_count, item_count), dtype='int32') # preallocation
        self.retrieving = False
        self.recall_total = 0
        
        # predefine primacy weighting vectors
        self.primacy_weighting = primacy_scale * np.exp(
            -primacy_decay * np.arange(presentation_count)) + 1

        # preallocate for outcome_probabilities
        self.probabilities = np.zeros((trial_count, item_count+1))
        
        # The two layers communicate with one another through two sets of 
        # associative connections represented by matrices Mfc and Mcf. Pre-
        # experimental Mfc is 1-learning_rate and pre-experimental Mcf is 
        # item_support for i=j. For i!=j, Mcf is shared_support.
        self.mfc = np.zeros((trial_count, item_count+1, item_count))
        self.mfc[:,] = np.eye(item_count+1, item_count, -1) * (1 - learning_rate)
        self.mcf = np.zeros((trial_count, item_count+1, item_count))
        self.mcf[:,] = np.ones((item_count+1, item_count)) * shared_support
        for i in range(item_count):
            self.mcf[:, i+1, i] = item_support
        self.mcf[:,1,:] = 0
        self.encoding_index = 0
        self.items = np.eye(item_count, item_count)
        self.trial_count = trial_count
        self.context_input = np.zeros((trial_count, self.item_count+1))
        
    def experience(self, experiences):
        for i in range(len(experiences[0])):
            self.update_context(self.encoding_drift_rate, experiences[i, :])
            
            for j in range(self.trial_count):
                self.mfc[j] += self.learning_rate * np.outer(
                    self.context[j], experiences[i, j])
                self.mcf[j] += self.primacy_weighting[
                    self.encoding_index] * np.outer(
                    self.context[j], experiences[i, j])
                
            self.encoding_index += 1
            
    def update_context(self, drift_rate, experience=None):

        # first pre-experimental or initial context is retrieved
        self.context_input[:] = 0
        if experience is not None:
            self.context_input = np.sum(self.mfc * experience.reshape((-1, *np.shape(experience))), axis=2)
            
            for i in range(self.trial_count):
                #self.context_input[i, :] = np.sum(experience[i] * self.mfc[i], axis=1)
                self.context_input[i, :] = self.context_input[i, :] / np.sqrt(
                    np.sum(np.square(self.context_input[i, :]))) # make len 1
        else:
            self.context_input[:, 0] = 1

        # updated context is sum of context and input, modulated by rho to 
        # have len 1 and some drift_rate
        rho = np.sqrt(1 + np.square(drift_rate) * (np.square(
            self.context * self.context_input) - 1)) - (drift_rate * (
            self.context * self.context_input))
        self.context = (rho * self.context) + (drift_rate * self.context_input)
        
    def activations(self, probe, use_mfc=False):
        
        if use_mfc:
            activation = np.zeros((self.trial_count, len(self.mfc[0, 0])))
            for i in range(self.trial_count):
                activation[i] = np.dot(probe[i], self.mfc[i]) + 10e-7
        else:
            activation = np.zeros((self.trial_count, len(self.mcf[0, 0])))
            for i in range(self.trial_count):
                activation[i] = np.dot(probe[i], self.mcf[i]) + 10e-7
        return activation
        
    def outcome_probabilities(self, activation_cue):

        activation = self.activations(activation_cue)
        activation = np.power(activation, self.choice_sensitivity)

        self.probabilities[:, 1:] = 0
        self.probabilities[:, 0] = min(self.stop_probability_scale * np.exp(
            self.recall_total * self.stop_probability_growth), 1.0  - (
            self.item_count * 10e-7))
        
        # also set stop probability to 1 where recall has terminated
        if self.recall_total > 0:
            self.probabilities[self.recall[:, self.recall_total-1] == 0, 0] = 1
        
        # track for each trial whether recall termination is guaranteed or not
        termination_not_guaranteed = self.probabilities[:, 0] < 1
        
        # suppress activation for already recalled items to 0
        for trial_index in range(self.trial_count):
            if termination_not_guaranteed[trial_index]:
                for each in self.recall[trial_index, :self.recall_total]:
                    activation[trial_index, each-1] = 0
                self.probabilities[trial_index, 1:] = (
                    1-self.probabilities[trial_index, 0]) * activation[trial_index] / np.sum(activation[trial_index])

        return self.probabilities
    
    def force_recall(self, choice=None):
        
        if not self.retrieving:
            self.update_context(self.start_drift_rate)
            self.retrieving = True

        if choice is None:
            pass
        else:
            self.recall[:, self.recall_total] = choice
            self.recall_total += 1
            self.update_context(self.recall_drift_rate, self.items[choice - 1])

        return self.recall[:, :self.recall_total]

## Test

In [3]:
import numpy as np
from numba import njit
from numba.typed import List
from numba.typed import List

@njit(fastmath=True, nogil=True)
def cmr_murd_likelihood(
    data_to_fit, item_counts, encoding_drift_rate, start_drift_rate, 
    recall_drift_rate, shared_support, item_support, learning_rate, 
    primacy_scale, primacy_decay, stop_probability_scale, 
    stop_probability_growth, choice_sensitivity):
    
    result = 0.0
    for i in range(len(item_counts)):
        item_count = item_counts[i]
        trials = data_to_fit[i]
        
        model = CMR(item_count, item_count, len(trials), encoding_drift_rate, 
                    start_drift_rate, recall_drift_rate, shared_support,
                    item_support, learning_rate, primacy_scale, 
                    primacy_decay, stop_probability_scale, 
                    stop_probability_growth, choice_sensitivity)
        
        # same sequence of experiences across trials
        experiences = np.zeros((len(trials), item_count, item_count))
        experiences[:,] = np.eye(item_count, item_count)
        model.experience(experiences.T.copy())

In [4]:
from instance_cmr.datasets import *

murd_trials0, murd_events0, murd_length0 = prepare_murddata(
    '../data/MurdData_clean.mat', 0)
print(murd_length0, np.shape(murd_trials0))

murd_events0.head()

20 (1200, 15)


Unnamed: 0,subject,list,item,input,output,study,recall,repeat,intrusion
0,1,1,1,1,5.0,True,True,0,False
1,1,1,2,2,7.0,True,True,0,False
2,1,1,3,3,,True,False,0,False
3,1,1,4,4,,True,False,0,False
4,1,1,5,5,,True,False,0,False


In [5]:
lb = np.finfo(float).eps
hand_fit_parameters = {
    'item_counts': List([murd_length0]),
    'encoding_drift_rate': .8,
    'start_drift_rate': .7,
    'recall_drift_rate': .8,
    'shared_support': 0.01,
    'item_support': 1.0,
    'learning_rate': .3,
    'primacy_scale': 1,
    'primacy_decay': 1,
    'stop_probability_scale': 0.01,
    'stop_probability_growth': 0.3,
    'choice_sensitivity': 2
}
cmr_murd_likelihood(List([murd_trials0[:80]]), **hand_fit_parameters)

ValueError: unable to broadcast argument 1 to output array
File "<ipython-input-2-8bfbb419f176>", line 1, 

In [None]:
%%timeit
cmr_murd_likelihood(List([murd_trials0[:80]]), **hand_fit_parameters)

experience (80, 20)
mfc (80, 21, 20)
context_input (80, 21)

In [None]:
mfc = np.arange(24).reshape((2, 3, 4))
mfc

In [None]:
experience = np.arange(8).reshape((2, 4))
experience

In [None]:
context_input = np.zeros((2, 3))

In [None]:
%%timeit

for i in range(2):
    context_input[i, :] = np.sum(experience[i] * mfc[i], axis=1)

In [None]:
%%timeit
np.sum(mfc * experience[:, np.newaxis], axis=2)