# CMR - Single List Length
Now that the relevant dependencies are specified and tested, we'll jump right into fitting the model to larger portions of the dataset. This time, we'll fit the entire 20-item list length subset of the Murdock (1962) dataset.

## Load Data

In [1]:
from instance_cmr.datasets import prepare_repdata
from instance_cmr.model_fitting import cmr_rep_likelihood
from instance_cmr.model_fitting import cmr_rep_objective_function
from instance_cmr.model_fitting import visualize_rep_fit
from instance_cmr.models import CMR

# Load the dataset stored in repFR.mat. prepare_repdata returns seven values,
# unpacked here in the order the function returns them.
data_path = '../../data/repFR.mat'
(trials, events, list_length, presentations,
 list_types, rep_data, subjects) = prepare_repdata(data_path)

# Preview the first rows of `events`.
events.head()

Unnamed: 0,subject,list,item,input,output,study,recall,repeat,intrusion,condition
0,1,1,0,1,1.0,True,True,0,False,4
1,1,1,1,2,2.0,True,True,0,False,4
2,1,1,2,3,3.0,True,True,0,False,4
3,1,1,3,4,4.0,True,True,0,False,4
4,1,1,4,5,5.0,True,True,0,False,4


## Fitting

In [7]:
from scipy.optimize import differential_evolution
from numba.typed import List
import numpy as np

# Search-space limits: machine epsilon keeps parameters strictly inside (0, 1)
# for the unit-interval parameters.
lb = np.finfo(float).eps
ub = 1 - np.finfo(float).eps

# Each free CMR parameter paired with its (lower, upper) search bound.
# Keeping names and bounds together prevents the two sequences handed to the
# optimizer from drifting out of alignment when a parameter is added, removed
# or reordered. Dict insertion order (Python 3.7+) fixes the parameter order.
parameter_bounds = {
    'encoding_drift_rate': (lb, ub),
    'start_drift_rate': (lb, ub),
    'recall_drift_rate': (lb, ub),
    'shared_support': (lb, ub),
    'item_support': (lb, ub),
    'learning_rate': (lb, ub),
    'primacy_scale': (lb, 100),
    'primacy_decay': (lb, 100),
    'stop_probability_scale': (lb, ub),
    'stop_probability_growth': (lb, 10),
    'choice_sensitivity': (lb, 10),
}
free_parameters = list(parameter_bounds)
bounds = list(parameter_bounds.values())

# Cost function to be minimized. It scales inversely with the probability that
# the selected trials could have been generated by the model under the given
# parameters (presumably a negative log-likelihood -- confirm in
# cmr_rep_objective_function). Only lists with list_types > 1 are fit here.
# NOTE(review): the `{}` argument presumably holds fixed (non-fitted) parameter
# values; it is empty, so every parameter above is free -- confirm.
selection = list_types > 1
cost_function = cmr_rep_objective_function(
    trials[selection], presentations[selection], list_types[selection], list_length,
    {}, free_parameters)

# Global optimization over the bounded box; disp=True prints per-generation progress.
result = differential_evolution(cost_function, bounds, disp=True)
print(result)

differential_evolution step 1: f(x)= 92383.6
differential_evolution step 2: f(x)= 79110
differential_evolution step 3: f(x)= 49855.4
differential_evolution step 4: f(x)= 46285.7
differential_evolution step 5: f(x)= 46207
differential_evolution step 6: f(x)= 46207
differential_evolution step 7: f(x)= 44908.7
differential_evolution step 8: f(x)= 44908.7
differential_evolution step 9: f(x)= 44908.7
differential_evolution step 10: f(x)= 44908.7
differential_evolution step 11: f(x)= 44908.7
differential_evolution step 12: f(x)= 43481.9
differential_evolution step 13: f(x)= 43481.9
differential_evolution step 14: f(x)= 43481.9
differential_evolution step 15: f(x)= 43481.9
differential_evolution step 16: f(x)= 43421.7
differential_evolution step 17: f(x)= 42900.7
differential_evolution step 18: f(x)= 42900.7
differential_evolution step 19: f(x)= 42686
differential_evolution step 20: f(x)= 42686
differential_evolution step 21: f(x)= 42686
differential_evolution step 22: f(x)= 42686
differentia

## Results
condition = 1:
```
     fun: 18808.191701458352
     jac: array([-6.54836182e-03, -2.03726812e-02, -2.47382558e-01, -1.08411769e-01,
        1.20271579e+01, -4.45295880e+01, -5.45696818e-02,  0.00000000e+00,
       -6.14818418e-02,  4.27098712e-01,  5.82076613e-03])
 message: 'Optimization terminated successfully.'
    nfev: 6678
     nit: 33
 success: True
       x: array([7.34339302e-01, 5.77280522e-01, 9.32318711e-01, 3.95388799e-01,
       2.22044605e-16, 1.00000000e+00, 5.52186306e-01, 2.65079941e+01,
       2.36119108e-02, 9.12422168e-02, 5.76025967e+00])
```

condition = 2:
```
     fun: 11599.828829510156
     jac: array([-7.34871715e-02, -1.86082615e-01, -3.72710927e-01, -3.34663126e+01,
        2.91904144e+01, -2.33954778e+01, -7.63975554e-03,  0.00000000e+00,
        4.21205186e+00,  7.74525688e-01,  1.45519153e-03])
 message: 'Optimization terminated successfully.'
    nfev: 7929
     nit: 40
 success: True
       x: array([8.67736021e-01, 7.01910644e-01, 9.24538309e-01, 1.00000000e+00,
       2.22044605e-16, 1.00000000e+00, 1.57450284e+00, 6.48561407e+01,
       1.06204366e-02, 2.08793367e-01, 6.61945787e+00])
```

condition = 3:
```
     fun: 12467.28845030914
     jac: array([-7.06495481e-01,  2.79451341e+00,  4.65661285e-01,  8.89849617e-01,
       -3.12502378e-01, -4.91127139e-03,  3.10319592e-01,  0.00000000e+00,
        1.41712007e+01, -2.56841304e-01, -6.36282497e-01])
 message: 'Optimization terminated successfully.'
    nfev: 7977
     nit: 36
 success: True
       x: array([8.04918058e-01, 9.40206223e-01, 9.09553576e-01, 4.28407588e-01,
       6.78013841e-01, 3.82057987e-01, 2.86219265e-01, 1.30715221e+01,
       8.03299670e-03, 2.19058557e-01, 3.58487757e+00])
```

condition = 4:
```
     fun: 17288.22393131492
     jac: array([ 6.68660501e-01, -1.22599885e-01, -5.09317030e-03,  1.13141141e-01,
        3.52516508e+01, -1.67347025e-01,  1.91539585e+00,  0.00000000e+00,
       -2.75831553e+00, -3.88172339e-01, -9.68066166e-01])
 message: 'Optimization terminated successfully.'
    nfev: 7344
     nit: 39
 success: True
       x: array([8.46529122e-01, 5.31498507e-01, 9.75056709e-01, 4.26670303e-02,
       2.22044605e-16, 4.83256935e-01, 2.71210420e+00, 4.16117716e+01,
       2.13758709e-02, 1.06864959e-01, 1.23391835e+00])
```

condition = all:
```
     fun: 60478.103550051164
     jac: array([-5.85423547,  0.99098543,  0.13242243,  2.06637196,  5.32672857,
        0.6868504 , -0.06111804, -0.45693014,  3.69254849,  4.03961167,
       -0.98516467])
 message: 'Optimization terminated successfully.'
    nfev: 6753
     nit: 28
 success: True
       x: array([8.44186613e-01, 4.82922372e-01, 9.64152301e-01, 4.66982063e-02,
       2.22044605e-16, 4.11208644e-01, 4.38262927e+00, 3.66252400e-01,
       2.51594034e-02, 1.01416573e-01, 1.14461246e+00])
```

condition = 2, 3, 4:
```
     fun: 41497.26322521102
     jac: array([ 1.52795109e-01,  1.16415321e-02, -6.57018969e-01, -8.65838957e-02,
        2.23917596e+01,  6.40284270e-02,  1.81898941e-02, -1.12049747e-01,
       -5.89352567e-02, -4.24188329e-01, -5.09317036e-03])
 message: 'Optimization terminated successfully.'
    nfev: 9627
     nit: 42
 success: True
       x: array([8.29021569e-01, 5.34374726e-01, 9.57772443e-01, 8.05158344e-02,
       2.22044605e-16, 3.56257191e-01, 3.86117273e+00, 3.13361270e-01,
       2.18375575e-02, 1.21196998e-01, 1.31711611e+00])

```

In [10]:
# Fit the model separately for each subject under three list-type selections:
#   condition_index 0 -> list_types == 1
#   condition_index 1 -> list_types > 1
#   condition_index 2 -> list_types > 0
# Every optimization result is kept in `subject_fits`, keyed by
# (subject, condition_index). Fits are expensive, so the results stay available
# for later analysis instead of only being printed and then overwritten.
subject_fits = {}
for subject in np.unique(subjects):
    for condition_index, condition in enumerate([list_types == 1, list_types > 1, list_types > 0]):

        print(subject, condition_index)
        # Restrict to this subject's lists that match the current list-type selection.
        selection = np.logical_and(condition, subjects == subject)
        cost_function = cmr_rep_objective_function(
            trials[selection], presentations[selection], list_types[selection], list_length,
            {}, free_parameters)

        result = differential_evolution(cost_function, bounds, disp=False)
        subject_fits[subject, condition_index] = result

        print(result)

1 0
     fun: 689.990524765121
     jac: array([ 8.98808135e-02,  1.01374553e+00,  3.81567132e-01, -1.23918653e-03,
        6.29133865e+00,  4.58155680e+00,  6.45661658e-01,  0.00000000e+00,
        1.36695576e+01,  1.30398803e-02, -7.98706883e-01])
 message: 'Optimization terminated successfully.'
    nfev: 11937
     nit: 56
 success: True
       x: array([7.11936485e-01, 7.80019478e-01, 9.69824091e-01, 1.59549295e-01,
       2.22044605e-16, 8.02175443e-01, 2.53588615e-01, 5.80197566e+01,
       9.51171097e-05, 3.64171832e-01, 3.21611156e+00])
1 1
     fun: 1153.8934664606459
     jac: array([ 0.44194621, -0.03992682,  0.06675691,  0.30099727, -0.16320882,
       -0.50226845, -0.00613909,  0.10427357, -3.73479452, -0.12232704,
        0.01032276])
 message: 'Optimization terminated successfully.'
    nfev: 15702
     nit: 45
 success: True
       x: array([7.52535469e-01, 6.91981291e-01, 9.48399975e-01, 2.01664114e-01,
       8.25163326e-01, 1.60789372e-01, 6.64969317e+00, 1.95623306