# Optimizing a memory for the storage of structured semantic pointers

In the previous examples, we have essentially ignored time by defining models that map inputs to outputs in a single forward pass (i.e., we configured the default synapse to be `None`). In this example, we'll introduce a simple process model of information retrieval based on [this standard Nengo tutorial](https://github.com/nengo/nengo/blob/master/examples/spa/question_memory.ipynb). The idea is to sequentially present pairs of items that get bound together and added to a structured semantic pointer that is retained in memory over time. Once the bound pairs have been added to the semantic pointer, we can then query it with a cue as before to test retrieval accuracy.

In [None]:
import nengo
import nengo.spa as spa
import nengo_dl
import numpy as np
import tensorflow as tf

from urllib.request import urlretrieve
import zipfile

import matplotlib.pyplot as plt
%matplotlib inline

## 1. Generate random semantic pointers

To start, we'll define a new function for generating training data that returns arrays of shape `(n_inputs, n_steps, dims)`, where `n_steps` is the number of time steps in the process we want to model. This data will allow us to learn model parameters that help ensure a trajectory of inputs produces an appropriate trajectory of outputs. Initially, we'll generate simple examples in which the input trajectory presents a single semantic pointer to the network for some number of time steps, and the desired output trajectory involves maintaining a representation of this semantic pointer in a recurrently connected ensemble for some further number of time steps.

In [None]:
def get_memory_data(n_inputs, dims, seed, t_int, t_mem=None): 
    t_total = t_int + t_mem if t_mem else t_int * 2
    
    state = np.random.RandomState(seed)
    vocab = spa.Vocabulary(dimensions=dims, rng=state, max_similarity=1)
    
    # initialize arrays for input and output trajectories
    inputs = np.zeros((n_inputs, t_total, dims))
    outputs = np.zeros((n_inputs, t_total, dims))
    
    # iterate through examples to be generated, fill arrays
    for n in range(n_inputs):
        vocab.add('SP' + str(n), vocab.create_pointer())
        
        # create inputs and target memory for first pair
        inputs[n, :t_int, :] = vocab['SP' + str(n)].v
        outputs[n, :, :] = vocab['SP' + str(n)].v

    # make scaling ramp for target output trajectories
    ramp = np.asarray([t / t_int for t in range(t_int)])
    ramp = np.concatenate((ramp, np.ones(t_total - t_int)))
    outputs = outputs * ramp[None, :, None]      
        
    return inputs, outputs, vocab 
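
To make the shape of the target trajectories concrete, the following minimal sketch rebuilds the scaling ramp from `get_memory_data` with toy sizes (`t_int = 4`, `t_total = 8`, and `dims = 3` are illustrative values, not the ones used in this example) and applies it to a constant target via broadcasting.

```python
import numpy as np

t_int, t_total, dims = 4, 8, 3  # toy sizes, chosen only for illustration

# ramp rises linearly during the input interval, then holds at 1
ramp = np.concatenate((np.arange(t_int) / t_int, np.ones(t_total - t_int)))

# a single constant target vector repeated over all time steps
target = np.ones((1, t_total, dims))

# broadcasting scales every dimension at each time step by the ramp value
scaled = target * ramp[None, :, None]

print(ramp)
print(scaled[0, :, 0])
```

The target therefore grows gradually while the input is being presented and must then be held at full strength once the input is removed.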

Now we'll generate some example trajectories for visualization purposes. We'll also make some arrays to define an input process in our model, so that an example input trajectory is presented by default when a simulation is run. 

In [None]:
t_int = 100
n_inputs = 40
dims = 64
seed = 236
default = 0 # which test item to present by default

test_inputs, test_outputs, test_vocab = get_memory_data(n_inputs, dims, seed=seed, t_int=t_int)

# create arrays for PresentInput process in model definition
sp = np.vstack((test_vocab['SP' + str(default)].v, np.zeros((1, dims))))
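
As a rough illustration of what it means to present the rows of an array like `sp` at regular intervals, the sketch below cycles through the rows of an array and holds each one for a fixed number of steps. The helper `present_rows` is hypothetical and only approximates the behaviour of `nengo.processes.PresentInput`; it is not Nengo's implementation.

```python
import numpy as np

def present_rows(rows, steps_per_row, n_steps):
    """Hypothetical helper: hold each row for `steps_per_row` steps, cycling."""
    idx = (np.arange(n_steps) // steps_per_row) % len(rows)
    return rows[idx]

dims = 4
pointer = np.arange(1, dims + 1, dtype=float)     # stand-in for a semantic pointer
rows = np.vstack((pointer, np.zeros((1, dims))))  # pointer first, then silence

signal = present_rows(rows, steps_per_row=3, n_steps=6)
print(signal)
```

With two rows, the pointer is shown during the first interval and the input is silent during the second, which is the pattern the memory is later expected to bridge.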

## 2. Define the model

Initially, we'll build a very simple model containing a single input node and single memory ensemble. The input will by default present the arrays in `sp` at regular intervals specified by the variable `t_int`.  

In [None]:
n_neurons = 10 * dims  # number of neurons for memory ensemble
p_time = t_int / 1000  # presentation interval for input process, in seconds (t_int steps of 1 ms)

with nengo.Network(seed=seed) as net:
    nengo_dl.configure_settings(trainable=False)
    net.config[nengo.Ensemble].neuron_type = nengo.RectifiedLinear()
    
    sp_input = nengo.Node(nengo.processes.PresentInput(sp, p_time))
    memory = nengo.Ensemble(n_neurons, dims)

    nengo.Connection(sp_input, memory)
    conn = nengo.Connection(memory, memory, transform=0.5, synapse=0.1)
    
    net.config[conn].trainable = True
    net.config[nengo.Probe].synapse = nengo.Lowpass(0.01)
    
    sp_probe = nengo.Probe(sp_input)   
    memory_probe = nengo.Probe(memory)

Next, we'll run the model for the specified number of steps, both to see how well the memory works and to see how closely the trajectory of the memory's state matches the target trajectory. Because the initial transform on the memory's recurrent connection is less than one, we should expect the memory to be quite 'leaky' and not function particularly well.

In [None]:
with nengo_dl.Simulator(net, seed=seed) as sim:
    sim.run(2 * p_time)

In [None]:
def plot_memory_example(sim, vocab):
    plt.figure(figsize=(8, 8))
    
    plot_vocab = vocab.create_subset([k for k in vocab.keys[:10]])
    plt.subplot(3, 1, 1)
    plt.plot(sim.trange(), nengo.spa.similarity(sim.data[sp_probe], plot_vocab))
    plt.legend(plot_vocab.keys, fontsize='x-small', loc='right')
    plt.ylabel("Input")

    plt.subplot(3, 1, 2)
    plt.plot(sim.trange(), nengo.spa.similarity(sim.data[memory_probe], plot_vocab))
    plt.legend(plot_vocab.keys, fontsize='x-small', loc='right')
    plt.ylabel("Memory")

    out = test_outputs[default,:,:].reshape(t_int*2, dims)
    plt.subplot(3, 1, 3)
    plt.plot(sim.trange(), nengo.spa.similarity(out, plot_vocab))
    plt.legend(plot_vocab.keys, fontsize='x-small', loc='right')
    plt.ylabel("Target Memory")
    plt.xlabel("time [s]")
    
plot_memory_example(sim, test_vocab)  
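
One way to see why a recurrent transform below one produces a leaky memory is to consider an idealized linear model of the recurrent connection. Assuming the ensemble represents its state perfectly and the input has been removed, a recurrent transform `w` filtered by a lowpass synapse with time constant `tau` gives approximately `tau * dx/dt = (w - 1) * x`. The sketch below integrates this equation with Euler's method; it ignores the neurons and the input synapse, so it is only a qualitative guide, and the function name `decayed_state` is made up for this illustration.

```python
import numpy as np

def decayed_state(transform, tau=0.1, dt=0.001, duration=0.1, x0=1.0):
    """Euler-integrate tau * dx/dt = (transform - 1) * x with no input."""
    x = x0
    for _ in range(int(round(duration / dt))):
        x += dt * (transform - 1.0) * x / tau
    return x

leaky = decayed_state(transform=0.5)  # the initial transform in the model
ideal = decayed_state(transform=1.0)  # a perfect integrator holds its state

print("fraction retained with transform=0.5:", leaky)
print("fraction retained with transform=1.0:", ideal)
print("analytic value exp(-0.5):", np.exp(-0.5))
```

Under these assumptions, roughly 40% of the stored vector is lost within one presentation interval when the transform is 0.5, whereas a transform of one leaves the state unchanged.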

## 3. Train the model

As is apparent in the above plot, the memory ensemble does not maintain a particularly stable representation of the input semantic pointer over time. This is because the transform weight on the memory's recurrent connection is 0.5, which results in the memory state decaying quickly over time. To improve retention, we'll optimize the network parameters to align the trajectories of the memory ensemble with the target trajectories corresponding to a large number of example input semantic pointers. Training on temporally extended trajectories can be slow, so we'll download pretrained parameters by default. You can train your own parameters under varying learning conditions by setting `do_training=True`.

In [None]:
do_training = False

sim = nengo_dl.Simulator(net, minibatch_size=50, seed=seed)

if do_training:
    optimizer = tf.train.MomentumOptimizer(learning_rate=0.00001, momentum=0.8, use_nesterov=True)
    train_inputs, train_outputs, _ = get_memory_data(4000, dims, seed, t_int)
    inputs = {sp_input: train_inputs}
    outputs = {memory_probe: train_outputs}

    print('Training loss before: ', sim.loss(inputs, outputs, 'mse'))
    sim.train(inputs, outputs, optimizer, n_epochs=40, objective='mse')
    print('Training loss after: ', sim.loss(inputs, outputs, 'mse'))

    sim.save_params('./mem_params')

else:
    # download pretrained parameters
    urlretrieve(
        "https://drive.google.com/uc?export=download&id=0BxRAh6Eg1us4STdmenUxV2VPcTg",
        "mem_params.zip")
    with zipfile.ZipFile("mem_params.zip") as f:
        f.extractall()

sim.close()
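
For intuition about what the `'mse'` objective measures here, the sketch below computes a mean squared error between a target trajectory and two hypothetical memory trajectories using plain NumPy. The trajectories are made up for illustration, and the exact reduction NengoDL applies (for example, how it averages over time steps and minibatches) is assumed rather than reproduced.

```python
import numpy as np

def mse(actual, target):
    """Mean squared error over all items, time steps, and dimensions."""
    return np.mean((actual - target) ** 2)

n_steps, dims = 200, 2
target = np.ones((1, n_steps, dims))  # state to be held over the whole trajectory

# hypothetical memory that holds the state vs. one that decays exponentially
stable = np.ones((1, n_steps, dims))
decay = np.exp(-np.arange(n_steps) / 50.0)
leaky = target * decay[None, :, None]

print("stable loss:", mse(stable, target))
print("leaky loss:", mse(leaky, target))
```

A trajectory that decays away from the target accumulates error at every later time step, so minimizing this kind of loss pushes the recurrent connection toward holding its state.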

In [None]:
with nengo_dl.Simulator(net, seed=seed) as sim:
    sim.load_params('./mem_params')
    sim.run(2 * p_time)

plot_memory_example(sim, test_vocab)

As is apparent, the training procedure significantly improves the stability of the memory: the representation of the input semantic pointer is now retained for the remainder of the simulation after the input is removed.

## 4. Introduce input binding and retrieval cues

To make things a bit more complicated, we'll introduce a binding procedure such that the memory ensemble encodes a structured semantic pointer that is generated from a single pair of input semantic pointers. We'll also incorporate input cues that can be used to retrieve information from the memory ensemble. In order to do this, we'll have to make some simple changes to our data generation function and our model definition.

In [None]:
def get_binding_data(n_inputs, dims, seed, t_int, t_mem=None): 
    t_total = t_int + t_mem if t_mem else t_int * 2
    
    state = np.random.RandomState(seed)
    vocab = spa.Vocabulary(dimensions=dims, rng=state, max_similarity=1)
    
    # initialize arrays for input and output trajectories
    roles = np.zeros((n_inputs, t_total, dims))
    fills = np.zeros((n_inputs, t_total, dims))
    pairs = np.zeros((n_inputs, t_total, dims))

    # iterate through examples to be generated, fill arrays
    for n in range(n_inputs):        
        roles[n, :t_int, :] = vocab.parse('ROLE_'+str(n)).v
        fills[n, :t_int, :] = vocab.parse('FILL_'+str(n)).v
        
        pair_key = 'ROLE_' + str(n) + '*' + 'FILL_' + str(n)
        pair_ptr = vocab.parse(pair_key)
        vocab.add(pair_key, pair_ptr.v)
        
        pairs[n, :, :] = pair_ptr.v

    # make scaling ramp for target output trajectories
    ramp = np.asarray([t / t_int for t in range(t_int)])
    ramp = np.concatenate((ramp, np.ones(t_total - t_int)))
    pairs = pairs * ramp[None, :, None]

    return roles, fills, pairs, vocab 
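
The `*` operator in the expressions passed to `vocab.parse` denotes binding by circular convolution. As a minimal sketch of that operation (using random unit vectors as stand-ins for vocabulary items, and hypothetical helpers named `circular_convolution` and `unit_vector`), it can be computed with FFTs and checked against the direct definition:

```python
import numpy as np

def circular_convolution(a, b):
    """Bind two vectors: elementwise product in the Fourier domain."""
    return np.fft.irfft(np.fft.rfft(a) * np.fft.rfft(b), n=len(a))

def unit_vector(rng, dims):
    """Random vector scaled to unit length."""
    v = rng.randn(dims)
    return v / np.linalg.norm(v)

rng = np.random.RandomState(0)
dims = 64
role = unit_vector(rng, dims)
fill = unit_vector(rng, dims)

pair = circular_convolution(role, fill)

# direct definition: pair[i] = sum_j role[j] * fill[(i - j) mod dims]
direct = np.array([
    sum(role[j] * fill[(i - j) % dims] for j in range(dims))
    for i in range(dims)
])

print("FFT result matches direct sum:", np.allclose(pair, direct))
print("similarity of pair to role:", np.dot(pair, role))
print("similarity of pair to fill:", np.dot(pair, fill))
```

The bound vector has the same dimensionality as its constituents but is, in general, dissimilar to both of them, which is why a separate unbinding step is needed for retrieval.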

In [None]:
# create testing data
test_roles, test_fills, test_pairs, test_vocab = get_binding_data(n_inputs, dims, seed=seed, t_int=t_int)

# create arrays for PresentInput processes in model definition
role = np.vstack((test_vocab['ROLE_'+str(default)].v, np.zeros((1, dims))))
fill = np.vstack((test_vocab['FILL_'+str(default)].v, np.zeros((1, dims))))
cue = np.vstack((np.zeros((1, dims)), test_vocab['ROLE_'+str(default)].v))

In this slightly more complicated model, we'll add a circular convolution network and train both the parameters of this network and the recurrent connection on the memory ensemble. We'll also add a second circular convolution network for performing retrieval with an input cue.

In [None]:
with nengo.Network(seed=seed) as net:
    nengo_dl.configure_settings(trainable=False)
    net.config[nengo.Ensemble].neuron_type = nengo.RectifiedLinear()
    net.config[nengo.Connection].synapse = None
    
    role_inp = nengo.Node(nengo.processes.PresentInput(role, p_time))
    fill_inp = nengo.Node(nengo.processes.PresentInput(fill, p_time))
    cue_inp = nengo.Node(nengo.processes.PresentInput(cue, p_time))
    
    cconv = nengo.networks.CircularConvolution(50, dims)
    ccorr = nengo.networks.CircularConvolution(50, dims, invert_b=True)
    memory = nengo.Ensemble(n_neurons, dims)

    nengo.Connection(role_inp, cconv.input_a)
    nengo.Connection(fill_inp, cconv.input_b)
    nengo.Connection(cconv.output, memory)
    
    nengo.Connection(memory, ccorr.input_a)
    nengo.Connection(cue_inp, ccorr.input_b)
    
    mem_conn = nengo.Connection(memory, memory, transform=0.5, synapse=0.1)
    
    net.config[mem_conn].trainable = True
    
    net.config[nengo.Probe].synapse = nengo.Lowpass(0.01)
    role_probe = nengo.Probe(role_inp) 
    fill_probe = nengo.Probe(fill_inp)
    cue_probe = nengo.Probe(cue_inp)
    conv_probe = nengo.Probe(cconv.output)
    memory_probe = nengo.Probe(memory)
    output_probe = nengo.Probe(ccorr.output)
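
The second network (`ccorr`, built with `invert_b=True`) is there to unbind the cue from the memory. A minimal NumPy sketch of that idea follows, assuming unbinding is circular convolution with the approximate inverse of the cue (the cue with all elements but the first in reversed order). The helper names and the choice of 256 dimensions (larger than the 64 used in this example, to keep the illustration less noisy) are assumptions made for the demonstration.

```python
import numpy as np

def circular_convolution(a, b):
    """Bind two vectors: elementwise product in the Fourier domain."""
    return np.fft.irfft(np.fft.rfft(a) * np.fft.rfft(b), n=len(a))

def approx_inverse(v):
    """Involution: keep v[0], reverse the order of the remaining elements."""
    return np.concatenate((v[:1], v[:0:-1]))

rng = np.random.RandomState(1)
dims, n_items = 256, 10

# random unit vectors standing in for role and filler vocabularies
roles = rng.randn(n_items, dims)
roles /= np.linalg.norm(roles, axis=1, keepdims=True)
fills = rng.randn(n_items, dims)
fills /= np.linalg.norm(fills, axis=1, keepdims=True)

# store one bound pair, then query it with the matching role
memory = circular_convolution(roles[3], fills[3])
retrieved = circular_convolution(memory, approx_inverse(roles[3]))

# the noisy result is identified by comparing it against all fillers
similarities = np.dot(fills, retrieved)
print("best matching filler index:", np.argmax(similarities))
```

The retrieved vector is only a noisy approximation of the stored filler, so retrieval is judged by which vocabulary item it is most similar to rather than by exact equality.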

## 5. Test baseline retrieval accuracy

We can use each role item in the collection of generated data as a cue, and each filler as a target for retrieval with this cue. We can therefore test retrieval accuracy using the same accuracy function that was used in the previous example:

In [None]:
def accuracy(sim, probe, vocab, targets, t_step=-1):
    # provide a simulator instance, the probe being evaluated, the vocab,
    # the target vectors, and the time step at which to evaluate

    # get output at the given time step
    output = sim.data[probe][:, t_step, :]

    # compute similarity between each output and vocab item
    sims = np.dot(vocab.vectors, output.T)
    idxs = np.argmax(sims, axis=0)

    # check that the output is most similar to the target
    acc = np.mean(np.all(vocab.vectors[idxs] == targets[:, 0], axis=1))
    return acc
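
To illustrate how this measure behaves, here is a small self-contained sketch of the same argmax-based comparison applied to made-up data (a four-item one-hot vocabulary and hand-written outputs, none of which come from the model). Note that the first output counts as correct even though its similarity to the target is very low.

```python
import numpy as np

vocab_vectors = np.eye(4)       # toy vocabulary: four orthogonal items
targets = vocab_vectors.copy()  # output n should retrieve vocabulary entry n

outputs = np.array([
    [0.05, 0.01, 0.00, 0.00],   # weak, but still closest to item 0
    [0.00, 0.90, 0.10, 0.00],   # clearly item 1
    [0.00, 0.00, 0.20, 0.30],   # closest to item 3, so item 2 is missed
    [0.00, 0.00, 0.00, 1.00],   # clearly item 3
])

# similarity of every output to every vocabulary item, then pick the best
sims = np.dot(vocab_vectors, outputs.T)
idxs = np.argmax(sims, axis=0)

# an output is correct when its best match is exactly the target vector
acc = np.mean(np.all(vocab_vectors[idxs] == targets, axis=1))
print("accuracy:", acc)
```

Three of the four toy outputs are matched to their targets, so the measure reports an accuracy of 0.75 regardless of how strong each match is.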

In [None]:
# use fillers at midpoint of input process as targets
targets = test_fills[:, int(t_int / 2), :][:, None, :]
test_cues = np.flip(test_roles, 1) # use roles as cues

# define an input feed for running the model on the test data
test_inputs = {role_inp: test_roles, fill_inp: test_fills, cue_inp: test_cues}

with nengo_dl.Simulator(net, seed=seed, minibatch_size=n_inputs) as sim:
    sim.run(2 * p_time, input_feeds=test_inputs)

print('Retrieval accuracy: ', accuracy(sim, output_probe, test_vocab, targets))

These results indicate that the model performs accurate retrieval only some of the time. Keep in mind too that the retrieval is assumed to be successful if an output vector is more similar to the correct filler vector than it is to any of the other vectors in the vocabulary, even if the overall similarity is quite low. We can visualize the model's output trajectories to explore this point in more detail.

In [None]:
with nengo_dl.Simulator(net, seed=seed) as sim:
    sim.run(2 * p_time)

def plot_retrieval_example(sim, vocab):
    plot_vocab = vocab.create_subset([k for k in vocab.keys[:8]])
    plt.figure(figsize=(10, 14))
    plt.subplot(7, 1, 1)
    plt.plot(sim.trange(), nengo.spa.similarity(sim.data[role_probe], plot_vocab))
    plt.legend(plot_vocab.keys, fontsize='x-small', loc='right')
    plt.ylabel("Role Input")

    plt.subplot(7, 1, 2)
    plt.plot(sim.trange(), nengo.spa.similarity(sim.data[fill_probe], plot_vocab))
    plt.legend(plot_vocab.keys, fontsize='x-small', loc='right')
    plt.ylabel("Filler Input")

    plt.subplot(7, 1, 3)
    plt.plot(sim.trange(), nengo.spa.similarity(sim.data[conv_probe], plot_vocab))
    plt.legend(plot_vocab.keys, fontsize='x-small', loc='right')
    plt.ylabel("Binding")

    plt.subplot(7, 1, 4)
    plt.plot(sim.trange(), nengo.spa.similarity(sim.data[memory_probe], plot_vocab))
    plt.legend(plot_vocab.keys, fontsize='x-small', loc='right')
    plt.ylabel("Memory")

    out = test_pairs[default,:,:].reshape(t_int*2, dims)
    plt.subplot(7, 1, 5)
    plt.plot(sim.trange(), nengo.spa.similarity(out, plot_vocab))
    plt.legend(plot_vocab.keys, fontsize='x-small', loc='right')
    plt.ylabel("Target Memory")

    plt.subplot(7, 1, 6)
    plt.plot(sim.trange(), nengo.spa.similarity(sim.data[cue_probe], plot_vocab))
    plt.legend(plot_vocab.keys, fontsize='x-small', loc='right')
    plt.ylabel("Cue")

    plt.subplot(7, 1, 7)
    plt.plot(sim.trange(), nengo.spa.similarity(sim.data[output_probe], plot_vocab))
    plt.legend(plot_vocab.keys, fontsize='x-small', loc='right')
    plt.ylabel("Retrieval")
    plt.xlabel("time [s]")
    
plot_retrieval_example(sim, test_vocab)
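
The traces in these plots are similarities between a probed trajectory and each vocabulary item. As a sketch of the underlying computation, assuming the similarity is an unnormalized dot product (which is our understanding of the default behaviour of `nengo.spa.similarity`), the same quantity can be computed with NumPy on toy data:

```python
import numpy as np

rng = np.random.RandomState(0)
n_steps, dims, n_items = 5, 16, 3

# toy vocabulary of unit vectors (stand-ins for semantic pointers)
vectors = rng.randn(n_items, dims)
vectors /= np.linalg.norm(vectors, axis=1, keepdims=True)

# toy trajectory: item 1 scaled by a ramp from 0 to 1
ramp = np.linspace(0, 1, n_steps)
trajectory = ramp[:, None] * vectors[1]

# one row per time step, one column per vocabulary item
similarity = np.dot(trajectory, vectors.T)
print(similarity.shape)
print(similarity[:, 1])
```

Each column of the result corresponds to one line in a plot panel, so a trace that rises toward one indicates that the probed state is approaching that vocabulary item.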

## 6. Train the new model

Since our memory ensemble is doing such a poor job of retaining the bound input items over time, it is unsurprising that the retrieval accuracy is mediocre. Notice too that even if the retrieved output vector is most similar to the correct target vector, it is still not very distinct from most of the other vectors in the vocabulary. We can fix this by optimizing the model parameters as before.

In [None]:
do_training = False

sim = nengo_dl.Simulator(net, minibatch_size=50, seed=seed)

if do_training:
    optimizer = tf.train.MomentumOptimizer(learning_rate=0.00001, momentum=0.8, use_nesterov=True)

    train_roles, train_fills, train_pairs, _ = get_binding_data(4000, dims, seed, t_int)
    train_inputs = {role_inp: train_roles, fill_inp: train_fills}
    train_outputs = {memory_probe: train_pairs}

    print('Training loss before: ', sim.loss(train_inputs, train_outputs, 'mse'))
    sim.train(train_inputs, train_outputs, optimizer, n_epochs=40, objective='mse')
    print('Training loss after: ', sim.loss(train_inputs, train_outputs, 'mse'))
    
    sim.save_params('./mem_binding_params')

else:
    # download pretrained parameters
    urlretrieve(
        "https://drive.google.com/uc?export=download&id=0BxRAh6Eg1us4UlVnSTR2NEtMbkk",
        "mem_binding_params.zip")
    with zipfile.ZipFile("mem_binding_params.zip") as f:
        f.extractall()

sim.close()

## 7. Test improved retrieval accuracy

Recomputing our accuracy measure on the test inputs indicates that our optimization procedure has been very successful: retrieval is now nearly flawless.

In [None]:
with nengo_dl.Simulator(net, seed=seed, minibatch_size=n_inputs) as sim:
    sim.load_params('./mem_binding_params')
    sim.run(2 * p_time, input_feeds=test_inputs)

print('Retrieval accuracy: ', accuracy(sim, output_probe, test_vocab, targets))

We can visualize the change that results from training by plotting the similarities between the states of various network components and the vocabulary of known semantic pointers over time, as we did before.

In [None]:
with nengo_dl.Simulator(net, seed=seed) as sim:
    sim.load_params('./mem_binding_params')
    sim.run(2 * p_time)

plot_retrieval_example(sim, test_vocab)