# Fitting CMR

## Model

In [1]:
import numpy as np
from numba import float64, int32, boolean
from numba.experimental import jitclass
from compmemlearn.familiarity import familiarity_weighting

# Attribute layout for the Classic_CMR jitclass, written as an ordered
# name -> numba type mapping and converted to the list of (name, type) pairs.
cmr_spec = list({
    # list length
    'item_count': int32,

    # scalar parameters copied from the `parameters` mapping in __init__
    'encoding_drift_rate': float64,
    'delay_drift_rate': float64,
    'start_drift_rate': float64,
    'recall_drift_rate': float64,
    'shared_support': float64,
    'item_support': float64,
    'learning_rate': float64,
    'primacy_scale': float64,
    'primacy_decay': float64,
    'stop_probability_scale': float64,
    'stop_probability_growth': float64,
    'choice_sensitivity': float64,

    # context state and the fixed context inputs (contiguous 1-D arrays)
    'context': float64[::1],
    'start_context_input': float64[::1],
    'delay_context_input': float64[::1],
    'preretrieval_context': float64[::1],

    # retrieval bookkeeping
    'recall': int32[::1],
    'retrieving': boolean,
    'recall_total': int32,

    # per-position encoding weights and the outcome distribution buffer
    'primacy_weighting': float64[::1],
    'probabilities': float64[::1],

    # association matrices (C-contiguous 2-D arrays) and encoding position
    'mfc': float64[:, ::1],
    'mcf': float64[:, ::1],
    'encoding_index': int32,
    'items': float64[:, ::1],

    # familiarity scaling parameters and the sampling rule switch
    'drift_familiarity_scale': float64,
    'mfc_familiarity_scale': float64,
    'mcf_familiarity_scale': float64,
    'sampling_rule': int32,
}.items())

In [2]:
import numpy as np
from numba import float64, int32, boolean
from numba.experimental import jitclass

# Attribute layout for the Instance_CMR jitclass, written as an ordered
# name -> numba type mapping and converted to the list of (name, type) pairs.
icmr_spec = list({
    # list length
    'item_count': int32,

    # scalar parameters copied from the `parameters` mapping in __init__
    'encoding_drift_rate': float64,
    'start_drift_rate': float64,
    'recall_drift_rate': float64,
    'shared_support': float64,
    'item_support': float64,
    'learning_rate': float64,
    'primacy_scale': float64,
    'primacy_decay': float64,
    'stop_probability_scale': float64,
    'stop_probability_growth': float64,
    'choice_sensitivity': float64,
    'context_sensitivity': float64,
    'feature_sensitivity': float64,

    # context state (contiguous 1-D arrays)
    'context': float64[::1],
    'preretrieval_context': float64[::1],

    # retrieval bookkeeping; note that `recall` is a float array in this model
    'recall': float64[::1],
    'retrieving': boolean,
    'recall_total': int32,

    # per-trace activation weights and the outcome distribution buffer
    'item_weighting': float64[::1],
    'context_weighting': float64[::1],
    'all_weighting': float64[::1],
    'probabilities': float64[::1],

    # trace store (C-contiguous 2-D array) and the next row to write
    'memory': float64[:, ::1],
    'encoding_index': int32,
    'items': float64[:, ::1],
}.items())

In [3]:
@jitclass(cmr_spec)
class Classic_CMR:
    """
    Jit-compiled CMR model that keeps its associations in two matrices:
    `mfc` (item features -> context) and `mcf` (context -> item features).

    Context has `item_count + 2` units. Index 0 is the unit active at the
    start of the list, indices 1..item_count are the units the items project
    to through the initial `mfc`, and the last index is the unit driven by
    `delay_context_input`.
    """

    def __init__(self, item_count, presentation_count, parameters):
        """
        Set up parameters, context, and the initial association matrices.

        Args:
            item_count: number of distinct items; sizes context, `mfc`, `mcf`,
                `items`, and the recall buffer.
            presentation_count: number of encoding positions; sets the length
                of `primacy_weighting`.
            parameters: mapping indexed by string keys. The keys read here are
                encoding_drift_rate, delay_drift_rate, start_drift_rate,
                recall_drift_rate, shared_support, item_support, learning_rate,
                primacy_scale, primacy_decay, stop_probability_scale,
                stop_probability_growth, choice_sensitivity,
                drift_familiarity_scale, mfc_familiarity_scale,
                mcf_familiarity_scale, and sampling_rule.
        """

        # store initial parameters
        self.item_count = item_count
        self.encoding_drift_rate = parameters['encoding_drift_rate']
        self.delay_drift_rate = parameters['delay_drift_rate']
        self.start_drift_rate = parameters['start_drift_rate']
        self.recall_drift_rate = parameters['recall_drift_rate']
        self.shared_support = parameters['shared_support']
        self.item_support = parameters['item_support']
        self.learning_rate = parameters['learning_rate']
        self.primacy_scale = parameters['primacy_scale']
        self.primacy_decay = parameters['primacy_decay']
        self.stop_probability_scale = parameters['stop_probability_scale']
        self.stop_probability_growth = parameters['stop_probability_growth']
        self.choice_sensitivity = parameters['choice_sensitivity']
        self.drift_familiarity_scale = parameters['drift_familiarity_scale']
        self.mfc_familiarity_scale = parameters['mfc_familiarity_scale']
        self.mcf_familiarity_scale = parameters['mcf_familiarity_scale']
        self.sampling_rule = parameters['sampling_rule']
        
        # at the start of the list context is initialized with a state 
        # orthogonal to the pre-experimental context
        # associated with the set of items
        self.context = np.zeros(item_count + 2)
        self.context[0] = 1
        # keeps a reference to the current context array; update_context
        # rebinds self.context to a new array rather than writing in place
        self.preretrieval_context = self.context
        self.recall = np.zeros(item_count, int32) # recalls has at most `item_count` entries
        self.retrieving = False
        self.recall_total = 0

        # predefine primacy weighting vectors:
        # primacy_scale * exp(-primacy_decay * position) + 1
        self.primacy_weighting = parameters['primacy_scale'] * np.exp(
            -parameters['primacy_decay'] * np.arange(presentation_count)) + 1

        # preallocate for outcome_probabilities (index 0 is the stop outcome)
        self.probabilities = np.zeros((item_count + 1))

        # predefine contextual input vectors relevant for delay_drift_rate and start_drift_rate parameters
        self.start_context_input = np.zeros((self.item_count+2))
        self.start_context_input[0] = 1
        self.delay_context_input = np.zeros((self.item_count+2))
        self.delay_context_input[-1] = 1

        # The two layers communicate with one another through two sets of 
        # associative connections represented by matrices Mfc and Mcf. 
        # Pre-experimental Mfc is 1-learning_rate and pre-experimental Mcf is
        # item_support for i=j. For i!=j, Mcf is shared_support.
        # mfc has shape (item_count, item_count + 2); row i is nonzero only in
        # column i + 1.
        self.mfc = np.eye(item_count, item_count+2, 1) * (1-self.learning_rate)
        self.mcf = np.ones((item_count+1, item_count)) * self.shared_support
        for i in range(item_count):
            self.mcf[i, i] = self.item_support
        # prepend an all-zero row for context unit 0, giving shape
        # (item_count + 2, item_count)
        self.mcf =  np.vstack((np.zeros((1, item_count)), self.mcf))
        self.encoding_index = 0
        # one-hot feature vector per item
        self.items = np.eye(item_count, item_count)

    def experience(self, experiences):
        """
        Encode each row of `experiences` in order.

        For every row, context is first drifted toward the row's context
        input, then both association matrices are incremented with the outer
        product of the updated context and the row.

        Args:
            experiences: indexable collection of item feature vectors; each
                row is dotted with `context[1:-1]`, so it must have
                `item_count` entries.
        """
        
        for i in range(len(experiences)):
            
            # learning is unscaled unless a familiarity scale is nonzero and
            # some item-related context unit is already active
            mfc_familiarity = 1
            mcf_familiarity = 1
            if np.any(self.context[1:-1] > 0) and ((self.mfc_familiarity_scale != 0) or (self.mcf_familiarity_scale != 0)):
                
                # cosine similarity between the item-related part of context
                # and the experienced feature vector
                similarity = np.dot(self.context[1:-1], experiences[i]) / (
                    np.sqrt(np.dot(self.context[1:-1], self.context[1:-1])) * np.sqrt(
                        np.dot(experiences[i],experiences[i])))
            
                # familiarity_weighting is imported from compmemlearn; its
                # exact mapping is not visible in this notebook
                mfc_familiarity = familiarity_weighting(self.mfc_familiarity_scale, similarity)
                mcf_familiarity = familiarity_weighting(self.mcf_familiarity_scale, similarity)

            self.update_context(self.encoding_drift_rate, experiences[i])
            # mfc learning is scaled by learning_rate, mcf learning by the
            # primacy weight of the current encoding position
            self.mfc += mfc_familiarity * self.learning_rate * np.outer(self.context, experiences[i]).T
            self.mcf += mcf_familiarity * self.primacy_weighting[self.encoding_index] * np.outer(
                self.context, experiences[i])
            self.encoding_index += 1

    def update_context(self, drift_rate, experience=None):
        """
        Drift context toward the input implied by `experience`.

        Args:
            drift_rate: drift strength; values above 1.0 are treated as 1.0.
            experience: either an item feature vector (length equal to the
                number of rows of `mfc`), which is projected through `mfc`,
                or a vector that is used directly as the context input.
                NOTE(review): the `None` default cannot be used as written,
                because `len(experience)` is evaluated unconditionally.
        """

        # first pre-experimental or initial context is retrieved
        familiarity = 1
        if len(experience) == len(self.mfc):
            
            if np.any(self.context[1:-1] > 0) and self.drift_familiarity_scale != 0:
                 
                 # cosine similarity between item-related context and the item
                 similarity = np.dot(self.context[1:-1], experience) / (
                     np.sqrt(np.dot(self.context[1:-1], self.context[1:-1])) * np.sqrt(
                          np.dot(experience,experience)))
                 familiarity = familiarity_weighting(self.drift_familiarity_scale, similarity)

            # project the item through mfc, then raise each entry to the
            # familiarity exponent (1 leaves the input unchanged)
            context_input = np.dot(experience, self.mfc)
            context_input = np.power(context_input, familiarity)

            # if the context is not pre-experimental, the context is retrieved
            context_input = context_input / np.sqrt(np.sum(np.square(context_input))) # norm to length 1
            
        else:
            # vector is already in context space (e.g. start/delay inputs)
            context_input = experience
  
        # new context = rho * context + drift * input, with drift capped at 1.0.
        # NOTE(review): `self.context * context_input` is an elementwise
        # product, so rho is a vector here rather than a scalar derived from a
        # dot product; confirm this is intended, since the result is then not
        # guaranteed to have unit length.
        rho = np.sqrt(1 + np.square(min(drift_rate, 1.0)) * (
            np.square(self.context * context_input) - 1)) - (
                min(drift_rate, 1.0) * (self.context * context_input))
        self.context = (rho * self.context) + (min(drift_rate, 1.0) * context_input)

    def activations(self, probe, use_mfc=False):
        """
        Return `probe` projected through `mfc` (if `use_mfc`) or `mcf`, plus
        a constant offset of 10e-7 (i.e. 1e-6) on every entry.
        """

        if use_mfc:
            return np.dot(probe, self.mfc) + 10e-7
        else:
            return np.dot(probe, self.mcf) + 10e-7
        
    def outcome_probabilities(self):
        """
        Compute the distribution over the next recall outcome.

        Returns:
            Array of length `item_count + 1`. Entry 0 is the probability of
            stopping; entry k (k >= 1) is the probability of recalling item
            k - 1. Already recalled items get probability 0. The array is
            also stored in `self.probabilities`.
        """

        # measure activation for each item
        activation = self.activations(self.context)

        # already recalled items have zero activation
        activation[self.recall[:self.recall_total]] = 0

        # power sampling rule vs modified exponential sampling rule
        if self.sampling_rule == 0:
            activation = np.power(activation, self.choice_sensitivity)
        else:
            # subtracting the maximum before exponentiating keeps exp() from
            # overflowing; note that recalled items end up nonzero again here
            pre_activation = (2 * activation)/ self.choice_sensitivity
            activation = np.exp(pre_activation - np.max(pre_activation))

        # stop probability grows exponentially with the number of recalls and
        # is capped so each remaining item can keep at least 10e-7
        self.probabilities = np.zeros((self.item_count + 1))
        self.probabilities[0] = min(self.stop_probability_scale * np.exp(
            self.recall_total * self.stop_probability_growth), 1.0  - (
                (self.item_count-self.recall_total) * 10e-7))
        
        # distribute the remaining mass in proportion to activation; if all
        # activation is zero the item probabilities stay at zero
        if np.sum(activation) > 0:
            self.probabilities[1:] = (1-self.probabilities[0]) * activation / np.sum(activation)

        return self.probabilities

    def free_recall(self, steps=None):
        """
        Simulate recall by sampling outcomes from `outcome_probabilities`.

        Args:
            steps: maximum number of additional recalls to attempt; when
                None, every item not yet recalled may be attempted.

        Returns:
            The item indices (0-based) recalled so far in this retrieval.
        """

        # some amount of the pre-list context is reinstated before initiating recall
        if not self.retrieving:
            self.recall = np.zeros(self.item_count, int32)
            self.recall_total = 0
            # remember the post-encoding context so it can be restored when
            # recall terminates
            self.preretrieval_context = self.context
            self.update_context(self.delay_drift_rate, self.delay_context_input)
            self.update_context(self.start_drift_rate, self.start_context_input)
            self.retrieving = True

        # number of items to retrieve is # of items left to recall if steps is unspecified
        if steps is None:
            steps = self.item_count - self.recall_total
        # convert the step budget into a target value for recall_total
        steps = self.recall_total + steps
        
        # at each recall attempt
        while self.recall_total < steps:

            # the current state of context is used as a retrieval cue to attempt recall of a studied item
            # compute outcome probabilities and make choice based on distribution
            outcome_probabilities = self.outcome_probabilities()
            if np.any(outcome_probabilities[1:]):
                # the sampled outcome is the number of cumulative
                # probabilities lying below a uniform random draw
                choice = np.sum(np.cumsum(outcome_probabilities) < np.random.rand(), dtype=int)
            else:
                choice = 0

            # resolve and maybe store outcome
            # we stop recall if no choice is made (0)
            if choice == 0:
                self.retrieving = False
                self.context = self.preretrieval_context
                break
            self.recall[self.recall_total] = choice - 1
            self.recall_total += 1
            self.update_context(self.recall_drift_rate, self.items[choice - 1])
        return self.recall[:self.recall_total]

    def force_recall(self, choice=None):
        """
        Advance retrieval with a given outcome instead of sampling one.

        Args:
            choice: None only initializes retrieval if it has not started;
                a positive value k records recall of item k - 1 and drifts
                context toward it; any other value ends retrieval and
                restores the pre-retrieval context.

        Returns:
            The item indices (0-based) recalled so far in this retrieval.
        """

        # same retrieval initialization as in free_recall
        if not self.retrieving:
            self.recall = np.zeros(self.item_count, int32)
            self.recall_total = 0
            self.preretrieval_context = self.context
            self.update_context(self.delay_drift_rate, self.delay_context_input)
            self.update_context(self.start_drift_rate, self.start_context_input)
            self.retrieving = True

        if choice is None:
            pass
        elif choice > 0:
            self.recall[self.recall_total] = choice - 1
            self.recall_total += 1
            self.update_context(self.recall_drift_rate, self.items[choice - 1])
        else:
            self.retrieving = False
            self.context = self.preretrieval_context
        return self.recall[:self.recall_total]

In [4]:

@jitclass(icmr_spec)
class Instance_CMR:
    """
    Jit-compiled CMR model that stores associations as rows (traces) of a
    single `memory` array instead of two weight matrices.

    Each trace has `2 * item_count + 2` columns: the first `item_count + 1`
    hold item features and the remaining `item_count + 1` hold context
    features. Context has `item_count + 1` units, with index 0 being the unit
    active at the start of the list.
    """

    def __init__(self, item_count, presentation_count, parameters):
        """
        Set up parameters, context, activation weights, and the initial traces.

        Args:
            item_count: number of distinct items; also the number of
                pre-experimental traces written into `memory`.
            presentation_count: number of trace rows reserved for encoding.
            parameters: mapping indexed by string keys. The keys read here are
                encoding_drift_rate, start_drift_rate, recall_drift_rate,
                shared_support, item_support, learning_rate, primacy_scale,
                primacy_decay, stop_probability_scale,
                stop_probability_growth, choice_sensitivity,
                context_sensitivity, and feature_sensitivity.
        """

        # store initial parameters
        self.item_count = item_count
        self.encoding_drift_rate = parameters['encoding_drift_rate']
        self.start_drift_rate = parameters['start_drift_rate']
        self.recall_drift_rate = parameters['recall_drift_rate']
        self.shared_support = parameters['shared_support']
        self.item_support = parameters['item_support']
        self.learning_rate = parameters['learning_rate']
        self.primacy_scale = parameters['primacy_scale']
        self.primacy_decay = parameters['primacy_decay']
        self.stop_probability_scale = parameters['stop_probability_scale']
        self.stop_probability_growth = parameters['stop_probability_growth']
        self.choice_sensitivity = parameters['choice_sensitivity']
        self.context_sensitivity = parameters['context_sensitivity']
        self.feature_sensitivity = parameters['feature_sensitivity']
        
        # at the start of the list context is initialized with a state 
        # orthogonal to the pre-experimental context associated with the set of items
        self.context = np.zeros(item_count + 1)
        self.context[0] = 1
        # keeps a reference to the current context array; update_context
        # rebinds self.context to a new array rather than writing in place
        self.preretrieval_context = self.context
        self.recall = np.zeros(item_count) # recalls has at most `item_count` entries
        self.retrieving = False
        self.recall_total = 0

        # predefine activation weighting vectors, one entry per trace row:
        # pre-experimental traces keep weight 1; encoded traces get
        # learning_rate (item probes) or the primacy function (context probes)
        self.item_weighting = np.ones(item_count+presentation_count)
        self.context_weighting = np.ones(item_count+presentation_count)
        self.item_weighting[item_count:] = self.learning_rate
        self.context_weighting[item_count:] = \
            self.primacy_scale * np.exp(-self.primacy_decay * np.arange(presentation_count)) + 1
        self.all_weighting = self.item_weighting * self.context_weighting

        # preallocate for outcome_probabilities (index 0 is the stop outcome)
        self.probabilities = np.zeros((item_count + 1))

        # initialize memory
        # we now conceptualize it as a pairing of two stores Mfc and Mcf respectively
        # representing feature-to-context and context-to-feature associations
        mfc = np.eye(item_count, item_count + 1, 1) * (1 - self.learning_rate)
        mcf = np.ones((item_count, item_count)) * self.shared_support
        for i in range(item_count):
            mcf[i, i] = self.item_support
        # pad with a zero first column so mcf is also item_count + 1 wide
        mcf = np.hstack((np.zeros((item_count, 1)), mcf))
        self.memory = np.zeros((item_count + presentation_count, item_count * 2 + 2))
        self.memory[:item_count,] = np.hstack((mfc, mcf))
        # the first free row follows the pre-experimental traces
        self.encoding_index = item_count
        # item i is one-hot at feature column i + 1
        self.items = np.eye(item_count, item_count + 1, 1)

    def experience(self, experiences):
        """
        Encode each row of `experiences` as a new trace.

        Args:
            experiences: indexable collection of item feature vectors of
                length `item_count + 1` (they fill the item half of a trace).
        """

        for i in range(len(experiences)):
            # write the item half of the new trace; its context half is still
            # zero when the row is used as the probe below
            self.memory[self.encoding_index, :self.item_count+1] = experiences[i]
            self.update_context(self.encoding_drift_rate, self.memory[self.encoding_index])
            # store the drifted context in the context half of the trace
            self.memory[self.encoding_index, self.item_count+1:] = self.context
            self.encoding_index += 1

    def update_context(self, drift_rate, experience=None):
        """
        Drift context toward the input retrieved by `experience`.

        Args:
            drift_rate: drift strength applied to the context input.
            experience: full-width probe (item half followed by context
                half). When None, the input is the unit vector with its
                first entry set to 1.
        """

        # first pre-experimental or initial context is retrieved
        if experience is not None:
            # context half of the echo, normalized below
            context_input = self.echo(experience)[self.item_count + 1:]
            context_input = context_input / np.sqrt(np.sum(np.square(context_input))) # norm to length 1
        else:
            context_input = np.zeros((self.item_count+1))
            context_input[0] = 1

        # new context = rho * context + drift * input.
        # NOTE(review): `self.context * context_input` is an elementwise
        # product, so rho is a vector here rather than a scalar derived from a
        # dot product; confirm this is intended, since the result is then not
        # guaranteed to have unit length.
        rho = np.sqrt(1 + np.square(drift_rate) * (np.square(self.context * context_input) - 1)) - (
                drift_rate * (self.context * context_input))
        self.context = (rho * self.context) + (drift_rate * context_input)

    def echo(self, probe):
        """
        Return the activation-weighted sum of all traces stored so far
        (rows below `encoding_index`) for the given probe.
        """

        return np.dot(self.activations(probe), self.memory[:self.encoding_index])

    def activations(self, probe):
        """
        Return one activation per stored trace for `probe`.

        Activation is the cosine similarity between the probe and the trace,
        scaled by a weighting vector chosen from the probe's content, raised
        to a sensitivity exponent, and offset by 10e-7 (i.e. 1e-6).
        """

        # cosine similarity between the probe and each stored trace
        activation = np.dot(self.memory[:self.encoding_index], probe) / (
            np.sqrt(np.sum(np.square(self.memory[:self.encoding_index]), axis=1)) * np.sqrt(
                np.sum(np.square(probe))))

        # weight activations based on whether probe contains item or contextual features or both
        if np.any(probe[:self.item_count + 1]):
            if np.any(probe[self.item_count + 1:]):
                # both mfc and mcf weightings, see below
                activation *= self.all_weighting[:self.encoding_index]
            else:
                # mfc weightings - scale by gamma for each experimental trace
                activation *= self.item_weighting[:self.encoding_index]
            activation = np.power(activation, self.context_sensitivity)
        else:
            # mcf weightings - scale by primacy/attention function based on experience position
            activation *= self.context_weighting[:self.encoding_index]
            # context-only probes use feature_sensitivity as the exponent
            # unless it is exactly 1.0, in which case context_sensitivity is
            # used instead
            if self.feature_sensitivity != 1.0:
                activation = np.power(activation, self.feature_sensitivity)
            else:
                activation = np.power(activation, self.context_sensitivity)
            
        return activation + 10e-7

    def outcome_probabilities(self):
        """
        Compute the distribution over the next recall outcome.

        Returns:
            Array of length `item_count + 1`. Entry 0 is the probability of
            stopping; entry k (k >= 1) is the probability of recalling item
            k - 1. The array is also stored in `self.probabilities`.
        """

        # probe with the current context only (item half left at zero)
        activation_cue = np.hstack(
                    (np.zeros(self.item_count + 1), self.context))
        # columns 1..item_count of the echo are the per-item feature support
        echo = self.echo(activation_cue)[1:self.item_count+1]
        echo = np.power(echo, self.choice_sensitivity)
        
        # stop probability grows exponentially with the number of recalls and
        # is capped below 1
        self.probabilities = np.zeros((self.item_count + 1))
        self.probabilities[0] = min(self.stop_probability_scale * np.exp(
            self.recall_total * self.stop_probability_growth), 1.0 - (self.item_count * 10e-7))

        # already recalled items cannot be recalled again
        if self.probabilities[0] < 1:
            for already_recalled_item in self.recall[:self.recall_total]:
                echo[int(already_recalled_item)] = 0
        # NOTE(review): if every item has been recalled, np.sum(echo) is 0 and
        # this division is not guarded; confirm callers never reach that state.
        self.probabilities[1:] = (1-self.probabilities[0]) * echo / np.sum(echo)
        
        return self.probabilities

    def free_recall(self, steps=None):
        """
        Simulate recall by sampling outcomes from `outcome_probabilities`.

        Args:
            steps: maximum number of additional recalls to attempt; when
                None, every item not yet recalled may be attempted.

        Returns:
            The item indices (0-based, stored as floats) recalled so far in
            this retrieval.
        """

        # some pre-list context is reinstated before initiating recall
        if not self.retrieving:
            self.recall = np.zeros(self.item_count)
            self.recall_total = 0
            # remember the post-encoding context so it can be restored when
            # recall terminates
            self.preretrieval_context = self.context
            self.update_context(self.start_drift_rate)
            self.retrieving = True
            
        # number of items to retrieve is the number of items left to recall
        # if steps is unspecified
        if steps is None:
            steps = self.item_count - self.recall_total
        # convert the step budget into a target value for recall_total
        steps = self.recall_total + steps

        # at each recall attempt
        while self.recall_total < steps:

            # the current state of context is used as a retrieval cue to 
            # attempt recall of a studied item compute outcome probabilities 
            # and make choice based on distribution
            outcome_probabilities = self.outcome_probabilities()
            if np.any(outcome_probabilities[1:]):
                # the sampled outcome is the number of cumulative
                # probabilities lying below a uniform random draw
                choice = np.sum(
                    np.cumsum(outcome_probabilities) < np.random.rand())
            else:
                choice = 0

            # resolve and maybe store outcome
            # we stop recall if no choice is made (0)
            if choice == 0:
                self.retrieving = False
                self.context = self.preretrieval_context
                break
            self.recall[self.recall_total] = choice - 1
            self.recall_total += 1
            # probe with the recalled item's features and an empty context half
            self.update_context(self.recall_drift_rate,
                                np.hstack((self.items[choice - 1], 
                                           np.zeros(self.item_count + 1))))
        return self.recall[:self.recall_total]
    
    def force_recall(self, choice=None):
        """
        Advance retrieval with a given outcome instead of sampling one.

        Args:
            choice: None only initializes retrieval if it has not started;
                a positive value k records recall of item k - 1 and drifts
                context toward it; any other value ends retrieval and
                restores the pre-retrieval context.

        Returns:
            The item indices (0-based, stored as floats) recalled so far in
            this retrieval.
        """

        # same retrieval initialization as in free_recall
        if not self.retrieving:
            self.recall = np.zeros(self.item_count)
            self.recall_total = 0
            self.preretrieval_context = self.context
            self.update_context(self.start_drift_rate)
            self.retrieving = True

        if choice is None:
            pass
        elif choice > 0:
            self.recall[self.recall_total] = choice - 1
            self.recall_total += 1
            # probe with the recalled item's features and an empty context half
            self.update_context(
                self.recall_drift_rate, 
                np.hstack((self.items[choice - 1], 
                           np.zeros(self.item_count + 1))))
        else:
            self.retrieving = False
            self.context = self.preretrieval_context
        return self.recall[:self.recall_total]

## Function
We optimize this with numba (and some creative Python) to speed up the calculation.

In [5]:
import numpy as np
from numba import njit, prange

@njit(fastmath=True, nogil=True, parallel=True)
def murdock_data_likelihood(data_to_fit, item_counts, model_class, parameters):
    """
    Negative log-likelihood of recall data under a model, summed over datasets.

    For each dataset one model is built, the full list is encoded once, and
    every trial is then replayed with `force_recall` while the probability the
    model assigned to each observed outcome is recorded.

    Args:
        data_to_fit: indexable collection; `data_to_fit[i]` holds the trials of
            dataset i. Each trial is a sequence of recalled item numbers where
            item k is coded as k (1-based) and 0 marks the end of recall.
        item_counts: indexable collection; `item_counts[i]` is the list length
            used for dataset i.
        model_class: callable `(item_count, presentation_count, parameters)`
            returning a model with `items`, `experience`, `force_recall`, and
            `outcome_probabilities`.
        parameters: passed unchanged to `model_class`.

    Returns:
        The summed negative log of the recorded outcome probabilities (each
        offset by 10e-7 before taking the log).
    """

    result = 0.0
    for i in range(len(item_counts)):
        item_count = item_counts[i]
        trials = data_to_fit[i]

        # unused cells stay at 1 and therefore contribute log(1) = 0
        likelihood = np.ones((len(trials), item_count))

        model = model_class(item_count, item_count, parameters)
        model.experience(model.items)

        for trial_index in range(len(trials)):
            trial = trials[trial_index]

            # start retrieval without recording an outcome
            model.force_recall()

            # Score at most `item_count` events per trial. A trial row that is
            # as long as (or longer than) the list and contains no terminating
            # 0 would otherwise be read, and `likelihood` written, one past
            # its end; numba does not bounds-check these accesses.
            for recall_index in range(min(len(trial) + 1, item_count)):

                # identify index of item recalled; if zero then recall is over.
                # Reaching the end of a short trial means recall stopped there.
                if recall_index == len(trial):
                    recall = 0
                else:
                    recall = trial[recall_index]

                # store probability of and simulate recall of indexed item 
                likelihood[trial_index, recall_index] = \
                    model.outcome_probabilities()[recall] + 10e-7
                
                if recall == 0:
                    break
                model.force_recall(recall)

            # reset model to its pre-retrieval (but post-encoding) state
            model.force_recall(0)
        
        result -= np.sum(np.log(likelihood))

    return result

In [6]:
from numba.typed import Dict
from numba.core import types

def murdock_objective_function(data_to_fit, item_counts, model_class, fixed_parameters, free_parameters):
    """
    Builds a jitted objective for a search over `free_parameters`.

    The returned function takes a vector `x`, assigns `x[i]` to the parameter
    named `free_parameters[i]` on top of the fixed values, and returns the
    result of `murdock_data_likelihood` for that parameter set.
    """

    # typed dictionary holding the values that stay constant during the search
    constants = Dict.empty(key_type=types.unicode_type, value_type=types.float64)
    for key, setting in fixed_parameters.items():
        constants[key] = setting

    @njit(fastmath=True, nogil=True)
    def objective_function(x):
        # copy the fixed values, then overlay the candidate vector
        candidate = constants.copy()
        for position in range(len(free_parameters)):
            candidate[free_parameters[position]] = x[position]
        return murdock_data_likelihood(data_to_fit, item_counts, model_class, candidate)

    return objective_function

## Performance Demo

In [7]:
from numba.typed import List
from compmemlearn.datasets import prepare_murdock1962_data

# Load three subsets of the same data file, selected by the second argument.
# NOTE(review): presumably one subset per list-length condition; confirm
# against prepare_murdock1962_data.
murdock_path = '../../data/MurdData_clean.mat'
murd_trials0, murd_events0, murd_length0 = prepare_murdock1962_data(murdock_path, 0)
murd_trials1, murd_events1, murd_length1 = prepare_murdock1962_data(murdock_path, 1)
murd_trials2, murd_events2, murd_length2 = prepare_murdock1962_data(murdock_path, 2)
murd_events0.head()

# parameter names, in the same order as the values in `fit_values`
free_parameters = [
    'encoding_drift_rate', 'start_drift_rate', 'recall_drift_rate',
    'shared_support', 'item_support', 'learning_rate',
    'primacy_scale', 'primacy_decay',
    'stop_probability_scale', 'stop_probability_growth',
    'choice_sensitivity', 'delay_drift_rate',
    'mfc_familiarity_scale', 'mcf_familiarity_scale', 'drift_familiarity_scale',
]

fit_values = [
    8.36412380e-01, 1.37438014e-01, 9.14231622e-01,
    1.03548867e-01, 1.00000000e+00, 2.70091435e-01,
    3.89549514e+00, 6.04259249e-01,
    2.04970025e-02, 1.20035161e-01,
    1.78302334e+00, 9.96231007e-01,
    0.0, 0.0, 0.0,
]

# typed dictionary so the parameters can be passed into jitted code
cmr_parameters = Dict.empty(key_type=types.unicode_type, value_type=types.float64)
for parameter_name, parameter_value in zip(free_parameters, fit_values):
    cmr_parameters[parameter_name] = parameter_value
cmr_parameters['sampling_rule'] = 0


In [8]:
@njit(fastmath=True, nogil=True)
def init_cmr(item_count, presentation_count, parameters):
    """Jitted factory that constructs a Classic_CMR from the given arguments."""
    model = Classic_CMR(item_count, presentation_count, parameters)
    return model

# negative log-likelihood of the three data subsets under Classic_CMR
murdock_data_likelihood(
    List([murd_trials0, murd_trials1, murd_trials2]),
    List([murd_length0, murd_length1, murd_length2]),
    init_cmr,
    cmr_parameters,
)

88662.15139967934

> 88662.15139967931

In [9]:
%%timeit
murdock_data_likelihood(List([murd_trials0, murd_trials1, murd_trials2]), List([murd_length0,murd_length1, murd_length2]),init_cmr, cmr_parameters)

238 ms ± 13.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


> 63.9 ms ± 1.02 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

## ICMR Test

In [10]:
# parameter names, in the same order as the values in `fit_values`;
# 'choice_sensitivity' and 'feature_sensitivity' are not fitted here and are
# set to fixed values below
free_parameters = [
    'encoding_drift_rate', 'start_drift_rate', 'recall_drift_rate',
    'shared_support', 'item_support', 'learning_rate',
    'primacy_scale', 'primacy_decay',
    'stop_probability_scale', 'stop_probability_growth',
    'context_sensitivity',
]

fit_values = np.array([
    7.25392222e-01, 2.22044605e-16, 9.04899967e-01,
    1.16637532e-05, 1.35117327e-03, 7.14751007e-03,
    6.26753336e+00, 1.73084858e+00,
    3.05537134e-02, 3.38091912e-01,
    1.54964510e+00,
])

# typed dictionary so the parameters can be passed into jitted code
icmr_parameters = Dict.empty(key_type=types.unicode_type, value_type=types.float64)
for parameter_name, parameter_value in zip(free_parameters, fit_values):
    icmr_parameters[parameter_name] = parameter_value
icmr_parameters['choice_sensitivity'] = 1.0
icmr_parameters['feature_sensitivity'] = 1.0

In [11]:
@njit(fastmath=True, nogil=True)
def init_icmr(item_count, presentation_count, parameters):
    """Jitted factory that constructs an Instance_CMR from the given arguments."""
    model = Instance_CMR(item_count, presentation_count, parameters)
    return model

# negative log-likelihood of the three data subsets under Instance_CMR
murdock_data_likelihood(
    List([murd_trials0, murd_trials1, murd_trials2]),
    List([murd_length0, murd_length1, murd_length2]),
    init_icmr,
    icmr_parameters,
)

87081.77648337423

## Time Profiler
The bottlenecks in the main likelihood function are its calls to the model's `outcome_probabilities` and `force_recall` methods.

Within `outcome_probabilities`, the main bottleneck is the call to `self.activations`, which takes 74.4% of the time. Maybe I can simplify by only computing the activation for a single item? Alternatively, it would be nice to sidestep the dot product.

In [12]:
%load_ext line_profiler

In [13]:
%lprun -f cmr_murd_likelihood cmr_murd_likelihood(List([murd_trials0, murd_trials1, murd_trials2]), List([murd_length0,murd_length1, murd_length2]), **parameters)

UsageError: Could not find function 'cmr_murd_likelihood'.
NameError: name 'cmr_murd_likelihood' is not defined
