In [None]:
#hide
#default_exp model_analysis
from nbdev.showdoc import *

# Visualizing Model States

We often simulate a simple free recall experiment and visualize model states throughout it to explore each model's
capacity to exhibit classical patterns of primacy, recency, and temporal contiguity. Any configuration of parameters can
be specified for the model. The organizational analyses additionally take an `experiment_count`, which determines the
number of simulations run with the given parameters.

In each experiment:
1. A specified number of unique items are each experienced once,
2. Context is momentarily drifted toward its pre-experimental state, and
3. The model freely recalls items until it stops, with retrieval of previously experienced items disallowed.

To visualize model state, we add to our `model_analysis` submodule three basic categories of visualizations. To
visualize model state throughout encoding, we track the state of `context` and the amount of `support` for recall of
each item based on contextual state. We also prepare a visualization of the final state of `memory` once encoding is
finished. To visualize model state throughout retrieval, we similarly track `context` and `support` at each step of
recall. An additional visualization clarifies the distribution of outcome probabilities at a particular index of
recall (e.g. after a second item has been recalled). While the previous sets of analyses focus on the behavior of a
particular instantiation of the model, a final set of analyses focuses on model behavior across many simulations. We
track recall probability as a function of serial position, the probability of starting recall with each serial position,
and conditional response probability as a function of lag.

## Encoding
First we create simulations and visualizations to track model state throughout the encoding of new memories. To do this,
we produce two parallel functions, `encoding_states` and `plot_states`, that collect and visualize encoding states,
respectively. An additional wrapper function called `encoding_visualizations` plots these states as heatmaps; it does
not currently plot the final overall state of model memory.

In [None]:
#hide 
#export
import numpy as np

def encoding_states(model):
    """
    Tracks state of context, and item supports across encoding. Model is also advanced to a state of fully encoded
    memories.

    **Required model attributes**:  
    - item_count: specifies number of items encoded into memory  
    - context: vector representing an internal contextual state  
    - experience: adding a new trace to the memory model  
    - outcome_probabilities: function returning item supports given a contextual probe (called directly here; any
      use of `activations` happens inside it)

    **Returns** array representations of context and support for retrieval of each item at each increment of item
    encoding. The context array has model.item_count + 1 rows: the pre-encoding context followed by the context after
    each encoded item. The support array always has model.item_count rows (one per encoded item, 2-D even when only
    one item is encoded); its width is whatever `outcome_probabilities` returns (presumably model.item_count + 1 with
    a leading stop entry -- confirm against the model).
    """
    # Probe construction is loop-invariant, so decide it once. CMR is probed with its context directly; other
    # (instance-based) models get item_count + 1 zeroed leading slots -- presumably the item-feature half of a
    # trace -- followed by the context.
    if model.__class__.__name__ == 'CMR':
        activation_cue = lambda model: model.context
    else:
        activation_cue = lambda model: np.hstack((np.zeros(model.item_count + 1), model.context))

    experiences = np.eye(model.item_count, model.item_count + 1, 1)
    cmr_experiences = np.eye(model.item_count, model.item_count)

    # np.copy keeps each stored row fixed: a model that updates `context` (or a returned support vector) in place
    # would otherwise overwrite rows that were already recorded.
    contexts = [np.copy(model.context)]
    supports = []

    # track model state across experiences
    for experience, cmr_experience in zip(experiences, cmr_experiences):
        try:
            model.experience(experience.reshape((1, -1)))
        except ValueError:
            # special case for CMR, which presumably expects item_count-wide feature vectors
            model.experience(cmr_experience.reshape((1, -1)))

        contexts.append(np.copy(model.context))
        supports.append(np.copy(model.outcome_probabilities(activation_cue(model))))

    # Stack once at the end instead of re-stacking every iteration (avoids quadratic copying).
    encoding_contexts = np.vstack(contexts)
    encoding_supports = np.vstack(supports) if supports else np.empty((0, 0))
    return encoding_contexts, encoding_supports

In [None]:
# nbdev: render the documentation entry for `encoding_states` in the notebook.
show_doc(encoding_states, title_level=3)

In [None]:
# hide
# export 
# collapse_input
import seaborn as sns
import matplotlib.pyplot as plt

def plot_states(matrix, title, figsize=(15, 15), savefig=False):
    """
    Plots an array of model states as a value-annotated heatmap with an arbitrary title.

    **Arguments**:  
    - matrix: an array of model states, ideally with columns representing unique feature indices and rows
        representing unique update indices  
    - title: a title for the generated plot, ideally conveying what array values represent at each entry  
    - figsize: (width, height) of the figure in inches, passed to `plt.figure`  
    - savefig: boolean deciding whether generated figure is saved (True if Yes)

    When savefig is True the figure is written to `figures/<title>.jpeg`, with the title lower-cased and spaces
    replaced by underscores; the `figures` directory is created if it does not exist. The figure is displayed
    whether or not it is saved.
    """
    import os

    plt.figure(figsize=figsize)
    sns.heatmap(matrix, annot=True, linewidths=.5)
    plt.title(title)
    plt.xlabel('Feature Index')
    plt.ylabel('Update Index')
    if savefig:
        # Create the output directory up front so a fresh checkout doesn't fail with FileNotFoundError.
        os.makedirs('figures', exist_ok=True)
        # Only the title is normalized into a file name; the directory and extension are fixed.
        file_name = title.replace(' ', '_').lower()
        plt.savefig('figures/{}.jpeg'.format(file_name), bbox_inches='tight')
    plt.show()

In [None]:
# nbdev: render the documentation entry for `plot_states` in the notebook.
show_doc(plot_states, title_level=3)

In [None]:
# hide
# export
def encoding_visualizations(model, savefig=True):
    """
    Runs `encoding_states` on `model` and plots the resulting encoding contexts and item supports as heatmaps.

    The model is advanced to a fully encoded state as a side effect. The final state of model memory is not
    plotted by this function.

    **Required model attributes** (used via `encoding_states`):  
    - item_count: specifies number of items encoded into memory  
    - context: vector representing an internal contextual state  
    - experience: adding a new trace to the memory model  
    - outcome_probabilities: function returning item supports given a contextual probe

    **Also** requires savefig: boolean deciding if generated figures are saved (passed through to `plot_states`,
    which displays the figures either way)
    """
    encoding_contexts, encoding_supports = encoding_states(model)
    plot_states(encoding_contexts, 'Encoding Contexts', savefig=savefig)
    plot_states(encoding_supports, 'Supports For Each Item At Each Increment of Encoding', savefig=savefig)

In [None]:
try:
    show_doc(encoding_visualizations, title_level=3)
except Exception:
    # Documentation rendering is best-effort; a bare `except:` would also swallow KeyboardInterrupt/SystemExit.
    pass

### Demo

#### ICMR

In [None]:
from instance_cmr.models import InstanceCMR

# Demo configuration for InstanceCMR; later demo cells reuse `parameters` and `model`.
parameters = dict(
    item_count=20,
    encoding_drift_rate=.8,
    start_drift_rate=.7,
    recall_drift_rate=.8,
    shared_support=0.01,
    item_support=1.0,
    learning_rate=.3,
    primacy_scale=1,
    primacy_decay=1,
    stop_probability_scale=0.01,
    stop_probability_growth=0.3,
    choice_sensitivity=2,
)

model = InstanceCMR(**parameters)
encoding_visualizations(model)

![](figures/icmr_encoding_contexts.jpeg)
![](figures/icmr_supports_for_each_item_at_each_increment_of_encoding.jpeg)


#### CMR

In [None]:
from instance_cmr.models import CMR

# Same demo configuration, applied to the prototype-based CMR model.
parameters = dict(
    item_count=20,
    encoding_drift_rate=.8,
    start_drift_rate=.7,
    recall_drift_rate=.8,
    shared_support=0.01,
    item_support=1.0,
    learning_rate=.3,
    primacy_scale=1,
    primacy_decay=1,
    stop_probability_scale=0.01,
    stop_probability_growth=0.3,
    choice_sensitivity=2,
)

model = CMR(**parameters)
encoding_visualizations(model)

![](figures/cmr_encoding_contexts.jpeg)
![](figures/cmr_supports_for_each_item_at_each_increment_of_encoding.jpeg)

## Retrieval
Here we track model state across each step of retrieval. Since retrieval is stochastic, these values change with each
random seed. An additional optional parameter, `first_recall_item`, can control which item is recalled first by
the model (`0` denotes termination of recall, while actual items are 1-indexed); it is useful for testing
hypotheses about model dynamics during recall. For now we leave the parameter set to `None`, indicating no
controlled first recall.

In [None]:
# export
# hide
#collapse_input
import numpy as np

def retrieval_states(model, first_recall_item=None):
    """
    Tracks state of context, and item supports across retrieval. Model is also advanced into a state of
    completed free recall.

    The model is first run through `encoding_states` (it is presumed to be freshly initialized). It then goes
    through pre-retrieval distraction (`free_recall(0)`), an optional forced first recall, and single-step free
    recall until `model.retrieving` becomes falsy. One row of context and one row of supports is recorded after
    encoding and after each of those steps.

    **Required model attributes**:
    - item_count: specifies number of items encoded into memory
    - context: vector representing an internal contextual state
    - experience: adding a new trace to the memory model
    - outcome_probabilities: function returning item supports given a contextual probe
    - free_recall: function that freely recalls a given number of items or until recall stops
    - force_recall: function forcing recall of a given choice (only used when first_recall_item is not None)
    - retrieving: truthy while recall is still ongoing
    - recall, recall_total: recorded recall train and the number of valid entries in it

    **Also** optionally uses first_recall_item: can specify an item for first recall

    **Returns** array representations of context and support for retrieval of each item at each increment of item
    retrieval (row 0: after encoding, row 1: after pre-retrieval distraction, then one row per recall step). Also
    returns recall train associated with simulation.
    """
    if model.__class__.__name__ == 'CMR':
        activation_cue = lambda model: model.context
    else:
        activation_cue = lambda model: np.hstack((np.zeros(model.item_count + 1), model.context))

    contexts, supports = [], []

    def record_state():
        # np.copy keeps recorded rows fixed even if the model updates `context` or its returned supports in place.
        contexts.append(np.copy(model.context))
        supports.append(np.copy(model.outcome_probabilities(activation_cue(model))))

    # encoding items, presuming model is freshly initialized
    encoding_states(model)
    record_state()

    # pre-retrieval distraction
    model.free_recall(0)
    record_state()

    # optional forced first item recall
    if first_recall_item is not None:
        model.force_recall(first_recall_item)
        record_state()

    # actual recall, one item at a time so every step is recorded
    while model.retrieving:
        model.free_recall(1)
        record_state()

    # Stack once at the end instead of re-stacking every step (avoids quadratic copying).
    return np.vstack(contexts), np.vstack(supports), model.recall[:model.recall_total]

In [None]:
try:
    show_doc(retrieval_states, title_level=3)
except Exception:
    # Documentation rendering is best-effort; a bare `except:` would also swallow KeyboardInterrupt/SystemExit.
    pass

In [None]:
#export
#collapse_input
def outcome_probs_at_index(model, support_index_to_plot=1, savefig=True):
    """
    Plots outcome probability distribution at a specific index of free recall.

    Runs `retrieval_states` on `model` (so the model should be freshly initialized) and plots one row of the
    resulting supports.

    **Required model attributes** (used via `retrieval_states`):
    - item_count: specifies number of items encoded into memory  
    - context: vector representing an internal contextual state  
    - experience: adding a new trace to the memory model  
    - outcome_probabilities: function returning item supports given a contextual probe  
    - free_recall: function that freely recalls a given number of items or until recall stops  
    - force_recall, retrieving, recall, recall_total: recall bookkeeping used by `retrieval_states`

    **Other arguments**:  
    - support_index_to_plot: row of the retrieval supports to plot (0: after encoding, 1: after pre-retrieval
      distraction, i.e. just before the first recall; each later row follows one more recall step)  
    - savefig: if True, the figure is also saved to `figures/<title>.jpeg` (title lower-cased, spaces replaced by
      underscores); the figure is displayed either way

    **Generates** a plot of outcome probabilities as a line graph. Also returns vector representation of the
    generated probabilities.
    """
    import os

    retrieval_supports = retrieval_states(model)[1]
    probabilities = retrieval_supports[support_index_to_plot]
    title = 'Outcome Probabilities At Recall Index {}'.format(support_index_to_plot)

    # assumes outcome_probabilities returns item_count + 1 entries (stop choice + one per item) -- TODO confirm
    plt.plot(np.arange(model.item_count + 1), probabilities)
    plt.xlabel('Choice Index')
    plt.ylabel('Outcome Probability')
    plt.title(title)
    if savefig:
        # savefig was previously documented but ignored; honour it using plot_states' naming convention.
        os.makedirs('figures', exist_ok=True)
        plt.savefig('figures/{}.jpeg'.format(title.replace(' ', '_').lower()), bbox_inches='tight')
    plt.show()
    return probabilities

In [None]:
try:
    show_doc(outcome_probs_at_index, title_level=3)
except Exception:
    # Documentation rendering is best-effort; a bare `except:` would also swallow KeyboardInterrupt/SystemExit.
    pass

In [None]:
#export
#collapse_input
def retrieval_visualizations(model, savefig=True):
    """
    Plots incremental retrieval contexts and supports as heatmaps and returns the recalled items.

    Runs `retrieval_states` on `model`, so the model should be freshly initialized; it is left in a state of
    completed free recall.

    **Required model attributes** (used via `retrieval_states`):
    - item_count: specifies number of items encoded into memory
    - context: vector representing an internal contextual state
    - experience: adding a new trace to the memory model
    - outcome_probabilities: function returning item supports given a contextual probe
    - free_recall: function that freely recalls a given number of items or until recall stops
    - force_recall, retrieving, recall, recall_total: recall bookkeeping used by `retrieval_states`

    **Also** uses savefig: boolean deciding whether figures are also saved; it is passed to `plot_states`, which
    displays the figures either way

    **Returns** the recall train produced by the simulation.
    """
    retrieval_contexts, retrieval_supports, recall = retrieval_states(model)
    plot_states(retrieval_contexts, 'Retrieval Contexts', savefig=savefig)
    plot_states(retrieval_supports, 'Supports For Each Item At Each Increment of Retrieval', 
                savefig=savefig)
    return recall

In [None]:
try:
    show_doc(retrieval_visualizations, title_level=3)
except Exception:
    # Documentation rendering is best-effort; a bare `except:` would also swallow KeyboardInterrupt/SystemExit.
    pass

### Demo

#### ICMR

In [None]:
# Fresh InstanceCMR instance: retrieval_states encodes from scratch and presumes a freshly initialized model.
model = InstanceCMR(**parameters)
# Plots retrieval contexts/supports; the returned recall train is shown as this cell's output.
retrieval_visualizations(model)

Outputs can look like...

![](figures/retrieval_contexts.jpeg)
![](figures/supports_for_each_item_at_each_increment_of_retrieval.jpeg)

#### CMR

In [None]:
# Fresh CMR instance: retrieval_states encodes from scratch and presumes a freshly initialized model.
model = CMR(**parameters)
# Plots retrieval contexts/supports; the returned recall train is shown as this cell's output.
retrieval_visualizations(model)

![](figures/retrieval_contexts.jpeg)
![](figures/supports_for_each_item_at_each_increment_of_retrieval.jpeg)

## Organizational Analyses
Once the simulations are complete, the `psifr` toolbox is used to generate three plots corresponding to the contents of
Figure 4 in Morton & Polyn (2016):
1. Recall probability as a function of serial position
2. Probability of starting recall with each serial position
3. Conditional response probability as a function of lag

Whereas previous visualizations were based on a single arbitrary model simulation, the current figures are based on
averages over a specified number of simulations of the model (`experiment_count`).

In [None]:
#export
#collapse_input
import pandas as pd
from psifr import fr

def _finish_plot(path, savefig):
    """
    Saves the current matplotlib figure to `path` (creating its parent directory if needed) when `savefig` is
    True; otherwise displays it with `plt.show()`.
    """
    if savefig:
        import os
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
        plt.savefig(path, bbox_inches='tight')
    else:
        plt.show()


def temporal_organization_analyses(model, experiment_count, savefig=False, figsize=(15, 15), first_recall_item=None):
    """
    Visualization of the outcomes of a trio of organizational analyses of model performance on a free recall
    task.

    The model encodes the list once, then free recall is run `experiment_count` times. Each recall period is
    recorded as its own psifr "subject" (with a single list 0) alongside the same study sequence.

    **Required model attributes**:
    - item_count: specifies number of items encoded into memory  
    - experience: adding a new trace to the memory model  
    - free_recall: function that freely recalls items until recall stops, returning the recalled items  
    - force_recall: forces recall of a given choice (only used when first_recall_item is not None)

    **Other arguments**:  
    - experiment_count: number of simulated recall periods to compute curves over  
    - savefig: if True, each figure is saved under `figures/` (spc.jpeg, pfr.jpeg, crp.jpeg) instead of displayed  
    - figsize: currently unused; kept for backward compatibility  
    - first_recall_item: if not None, forced as the first recall of every simulated recall period

    **Generates** (returns None) three plots corresponding to the contents of Figure 4 in Morton & Polyn, 2016:  
    1. Recall probability as a function of serial position  
    2. Probability of starting recall with each serial position  
    3. Conditional response probability as a function of lag  
    """
    # encode items
    try:
        model.experience(np.eye(model.item_count, model.item_count + 1, 1))
    except ValueError:
        # so we can apply to CMR
        model.experience(np.eye(model.item_count, model.item_count))

    # simulate retrieval for the specified number of times, tracking results in df
    # NOTE(review): the same model instance is reused for every recall period with no visible reset in between;
    # confirm that free_recall() restarts retrieval on each call.
    # NOTE(review): study items are recorded 0-indexed (item i at position i + 1); confirm free_recall() returns
    # items in the same convention (the retrieval section describes choice indices as 1-indexed with 0 = stop).
    data = []
    for experiment in range(experiment_count):
        data += [[experiment, 0, 'study', i + 1, i] for i in range(model.item_count)]
    for experiment in range(experiment_count):
        if first_recall_item is not None:
            model.force_recall(first_recall_item)
        data += [[experiment, 0, 'recall', i + 1, o] for i, o in enumerate(model.free_recall())]
    data = pd.DataFrame(data, columns=['subject', 'list', 'trial_type', 'position', 'item'])
    merged = fr.merge_free_recall(data)

    # serial position curve
    recall = fr.spc(merged)
    fr.plot_spc(recall)
    plt.title('Serial Position Curve')
    _finish_plot('figures/spc.jpeg', savefig)

    # P(Start Recall) For Each Serial Position
    prob = fr.pnr(merged)
    pfr = prob.query('output <= 1')
    fr.plot_spc(pfr).add_legend()
    plt.title('Probability of Starting Recall With Each Serial Position')
    _finish_plot('figures/pfr.jpeg', savefig)

    # Conditional response probability as a function of lag
    crp = fr.lag_crp(merged)
    fr.plot_lag_crp(crp)
    plt.title('Conditional Response Probability')
    _finish_plot('figures/crp.jpeg', savefig)

In [None]:
try:
    show_doc(temporal_organization_analyses, title_level=3)
except Exception:
    # Documentation rendering is best-effort; a bare `except:` would also swallow KeyboardInterrupt/SystemExit.
    pass

### Demo

In [None]:
from instance_cmr.models import InstanceCMR

model = InstanceCMR(**parameters)
# 100 simulated recall periods; figures are written to figures/ rather than displayed.
temporal_organization_analyses(model, 100, savefig=True)

![](figures/spc.jpeg)

![](figures/pfr.jpeg)

![](figures/crp.jpeg)