In [39]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
rng = np.random.default_rng()

%load_ext autoreload
%autoreload 2

from brain import k_cap, idx_to_vec, FFArea, RecurrentArea, RandomChoiceArea, ScaffoldNetwork, FSMNetwork, PFANetwork

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Forming an assembly

### Initialize a brain area

In [40]:
# Input layer and recurrent brain-area sizes
n_inputs, n_neurons = 1000, 1000
cap_size = 30       # neurons per stimulus; presumably also the k of the k-cap
density = 0.25      # presumably the random connection probability -- confirm in brain.py
plasticity = 0.1    # presumably the Hebbian learning rate -- confirm in brain.py

In [41]:
# Single recurrent area fed by n_inputs input neurons
brain_area = RecurrentArea(
    n_inputs, n_neurons, cap_size, density, plasticity,
)

In [42]:
# The stimulus fires the first cap_size input neurons (indices 0 .. cap_size-1)
stimulus = np.arange(0, cap_size)

### Form an assembly by presenting the stimulus several times

In [140]:
n_rounds = 10

# One row per presentation: dense firing pattern of the brain area after that round
activations = np.zeros((n_rounds, n_neurons))

brain_area.inhibit()
for round_idx in range(n_rounds):
    brain_area.forward(stimulus)
    activations[round_idx, :] = brain_area.read(dense=True)

### Plot activations during formation

In [147]:
# Rank brain-area neurons by total firing across rounds and plot only the most active ones.
# n_show is capped at n_neurons so the bar positions always match the sliced heights.
n_show = min(5 * cap_size, n_neurons)
idx = activations.sum(axis=0).argsort()[::-1]

# squeeze=False keeps `axes` indexable even when n_rounds == 1
fig, axes = plt.subplots(n_rounds, figsize=(4, 6), sharex=True, sharey=True, squeeze=False)
axes = axes[:, 0]

for i in range(n_rounds):
    axes[i].bar(np.arange(n_show), activations[i, idx[:n_show]])
    axes[i].set_xticks([])
    axes[i].set_yticks([])
    axes[i].spines['top'].set_visible(False)
    axes[i].spines['right'].set_visible(False)
    axes[i].spines['left'].set_visible(False)

axes[n_rounds // 2].set_ylabel('Round')
axes[-1].set_xlabel('Firing neurons')
fig.tight_layout()

<IPython.core.display.Javascript object>

## Classifying stimulus classes

### Generate some samples from each stimulus class

In [77]:
n_classes = 3
n_samples_train = 10
n_samples_test = 200

# Samples are presented to `brain_area.forward`, so they live in the *input* space
# (n_inputs neurons), not in the brain area (n_neurons).
# Each class fires its own block of cap_size input neurons with probability p_class, and
# every other input neuron fires at a low background rate p_background.
p_class = 0.9
p_background = 1.8 * cap_size / n_inputs
assert n_classes * cap_size <= n_inputs, 'class blocks must fit in the input layer'

class_vecs = np.full((n_classes, n_inputs), p_background)
for c in range(n_classes):
    class_vecs[c, c * cap_size:(c + 1) * cap_size] = p_class

# Boolean firing patterns: samples_*[c, s] is sample s of class c
samples_train = rng.random((n_classes, n_samples_train, n_inputs)) < class_vecs[:, np.newaxis, :]
samples_test = rng.random((n_classes, n_samples_test, n_inputs)) < class_vecs[:, np.newaxis, :]

# NOTE(review): presumably clears the weights learned in the assembly-formation demo
# so classification starts from a fresh area -- confirm in brain.py.
brain_area.reset()

### Visualize sample means

In [78]:
# Mean firing rate of each input neuron per class (only the first few class blocks are shown).
# n_show is capped at the input size so bar positions always match the sliced heights.
n_show = min(5 * cap_size, samples_test.shape[-1])

# A single Axes: sharex/sharey would have no effect, so they are omitted
fig, ax = plt.subplots(figsize=(8, 2))
for i in range(n_classes):
    ax.bar(np.arange(n_show), samples_test[i].mean(axis=0)[:n_show], label='Class {}'.format(i))
ax.legend()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_xticks([])
ax.set_xlabel('Input neuron')
ax.set_ylabel('Fraction of firing')
fig.tight_layout()

<IPython.core.display.Javascript object>

Text(0, 0.5, 'Fraction of firing')

In [79]:
# Train one assembly per class by presenting that class's training samples in sequence;
# the cap after the last sample is kept as the class's assembly.
assembly_support = np.zeros((n_classes, n_neurons))

for i in range(n_classes):
    brain_area.inhibit()
    for j in range(n_samples_train):
        # flatnonzero yields a 1-D array of firing input indices, the same form as `stimulus`
        # (np.nonzero would return a 1-tuple of arrays instead)
        brain_area.forward(np.flatnonzero(samples_train[i, j]))
    assembly_support[i] = brain_area.read(dense=True)

### Visualize assemblies

In [80]:
# Weight membership by class (class 0 highest) so neurons of class 0's assembly are listed
# first, then class 1's, and so on
class_weights = np.arange(n_classes, 0, -1)
idx = np.argsort(assembly_support.T @ class_weights)[::-1]
top_neurons = idx[:5 * cap_size]

fig, axes = plt.subplots(n_classes, figsize=(8, 2 * n_classes), sharex=True, sharey=True)
for c, ax in enumerate(axes):
    ax.bar(np.arange(5 * cap_size), assembly_support[c, top_neurons],
           label='Class {}'.format(c), color='C{}'.format(c))
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xticks([])

axes[-1].set_xlabel('Brain area neuron')

<IPython.core.display.Javascript object>

Text(0.5, 0, 'Brain area neuron')

In [81]:
# test_overlaps[c, s, k]: overlap between the cap evoked by test sample s of class c and the
# assembly of class k. Testing uses update=False (presumably no plasticity -- confirm in brain.py).
test_overlaps = np.zeros((n_classes, n_samples_test, n_classes))

for i in range(n_classes):
    for j in range(n_samples_test):
        brain_area.inhibit()
        # 1-D index array, the same input form as `stimulus`
        brain_area.forward(np.flatnonzero(samples_test[i, j]), update=False)
        test_overlaps[i, j] = brain_area.read(dense=True) @ assembly_support.T

In [82]:
# Predicted class = assembly with the largest overlap; compare with each sample's true class
predicted_class = np.argmax(test_overlaps, axis=-1)      # (n_classes, n_samples_test)
true_class = np.arange(n_classes)[:, np.newaxis]
accuracy = (predicted_class == true_class).mean(axis=-1)
for class_idx, class_acc in enumerate(accuracy):
    print('Class {:d} accuracy: {:%}'.format(class_idx, class_acc))

Class 0 accuracy: 99.000000%
Class 1 accuracy: 100.000000%
Class 2 accuracy: 100.000000%


In [83]:
# Mean overlap (as a fraction of cap_size) of each class's test caps with every class assembly.
# Bars are grouped by assembly and offset per test class, for any number of classes.
bar_width = 0.75 / n_classes
fig, ax = plt.subplots()
for c in range(n_classes):
    offset = (c - (n_classes - 1) / 2) * bar_width
    ax.bar(np.arange(n_classes) + offset, test_overlaps[c].mean(axis=0) / cap_size,
           width=bar_width, label='Class {}'.format(c))
ax.set_xticks(np.arange(n_classes))
ax.legend(loc=(1., 0.05), title='Test samples')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_xlabel('Assembly class')
ax.set_ylabel('Overlap with assembly')
fig.tight_layout()

<IPython.core.display.Javascript object>

## Memorizing sequences of inputs

### Initialize simple and scaffolded networks

In [2]:
# Same sizes as the classification experiment, but with denser connectivity
n_inputs, n_neurons = 1000, 1000
cap_size = 30
density = 0.4
plasticity = 0.1

# Baseline: a single recurrent area; comparison: the scaffolded network
area_args = (n_inputs, n_neurons, cap_size, density, plasticity)
simple_seq_area = RecurrentArea(*area_args)
scaff_seq_net = ScaffoldNetwork(*area_args)

### Define a sequence of inputs

In [3]:
seq_len = 25
# Item i of the sequence fires input neurons [i*cap_size, (i+1)*cap_size); all items
# must fit in the input layer, otherwise indexing fails deep inside the brain module.
assert seq_len * cap_size <= n_inputs, 'sequence items must fit in the input layer'
sequence = np.arange(seq_len * cap_size).reshape(seq_len, cap_size)

### Train the models by repeatedly presenting the sequence, testing recall after each presentation

In [4]:
n_presentations = 10

# Assemblies recorded during the first presentation are the reference for recall
simple_seq_assemblies = np.zeros((seq_len, n_neurons))
scaff_seq_assemblies = np.zeros((seq_len, n_neurons))
# Recall score per (presentation, item): dot product of the recalled dense cap with the
# reference assembly (presumably the number of shared firing neurons)
simple_seq_recall = np.zeros((n_presentations, seq_len))
scaff_seq_recall = np.zeros((n_presentations, seq_len))

for j in range(n_presentations):
    # Training: present the whole sequence once (default forward, presumably plastic)
    simple_seq_area.inhibit()
    scaff_seq_net.inhibit()
    for i in range(seq_len):
        simple_seq_area.forward(sequence[i])
        scaff_seq_net.forward(sequence[i])
        if j == 0:
            simple_seq_assemblies[i] = simple_seq_area.read(dense=True)
            scaff_seq_assemblies[i] = scaff_seq_net.read(dense=True)

    # Recall: cue with the first item only and let both networks step without weight updates
    simple_seq_area.inhibit()
    scaff_seq_net.inhibit()
    simple_seq_area.set_input(sequence[0])
    scaff_seq_net.set_input(sequence[0])
    for i in range(seq_len):
        simple_seq_area.step(update=False)
        scaff_seq_net.step(update=False)

        simple_seq_recall[j, i] = simple_seq_area.read(dense=True) @ simple_seq_assemblies[i]
        scaff_seq_recall[j, i] = scaff_seq_net.read(dense=True) @ scaff_seq_assemblies[i]

### Plot the results

In [5]:
# Presentations (1-based) whose per-item recall curves are shown; any not reached are skipped
snapshot_presentations = [p for p in (3, 6) if p <= n_presentations]

# squeeze=False keeps `axes` a 1-D array even if only the last panel remains
fig, axes = plt.subplots(1, len(snapshot_presentations) + 1, figsize=(10, 4),
                         sharey=True, squeeze=False)
axes = axes[0]
for ax, p in zip(axes, snapshot_presentations):
    ax.plot(np.arange(seq_len), simple_seq_recall[p - 1] / cap_size, label='Simple')
    ax.plot(np.arange(seq_len), scaff_seq_recall[p - 1] / cap_size, label='Scaffold')
    ax.set_title('Recall after {} presentations'.format(p))
    ax.set_xlabel('Sequence item')
axes[0].set_ylabel('Recall fraction')

axes[-1].plot(np.arange(n_presentations) + 1, simple_seq_recall[:, -1] / cap_size, label='Simple')
axes[-1].plot(np.arange(n_presentations) + 1, scaff_seq_recall[:, -1] / cap_size, label='Scaffold')
axes[-1].set_title('Recall of last element during training')
axes[-1].set_xlabel('Presentation')
axes[-1].legend()

fig.tight_layout()

<IPython.core.display.Javascript object>

## Simulate an FSM (DFA) that recognizes numbers divisible by 3

We will simulate the following FSM, which recognizes numbers divisible by 3. It reads the number one digit at a time and tracks the sum of the digits mod 3. When it reads the end-of-input symbol □, it accepts if that sum is 0 and rejects otherwise.

![fsm_0modthree.png](attachment:fsm_0modthree.png)

### Define a network to simulate the FSM

In [43]:
n_symbol_neurons = 1000
n_state_neurons = 500
n_arc_neurons = 5000
cap_size = 70
density = 0.2
plasticity = 1e-1

fsm_net = FSMNetwork(n_symbol_neurons, n_state_neurons, n_arc_neurons, cap_size, density, plasticity)

n_digits = 10
n_mod_states = 3
n_symbols = n_digits + 1        # digits 0-9 plus the end-of-input symbol (index 10)
n_states = n_mod_states + 2     # mod 0/1/2 plus Accept (3) and Reject (4)

# Each symbol/state is a disjoint block of cap_size neurons; all blocks must fit in their area
assert n_symbols * cap_size <= n_symbol_neurons, 'too many symbols for the symbol area'
assert n_states * cap_size <= n_state_neurons, 'too many states for the state area'

symbols = np.arange(n_symbols * cap_size).reshape(n_symbols, cap_size)
states = np.arange(n_states * cap_size).reshape(n_states, cap_size)

# One arc per (mod state, symbol) pair
n_arcs = n_symbols * n_mod_states

### Define the FSM (via its transitions)

In [44]:
# Each transition is [current mod state, input symbol, next state].
# A digit moves the running digit sum mod 3 to (mod + digit) % 3.
transition_list = [[mod, digit, (mod + digit) % 3] for mod in range(3) for digit in range(10)]

# End symbol (10): Accept (3) from mod 0, Reject (4) from mod 1 and mod 2
transition_list += [[0, 10, 3], [1, 10, 4], [2, 10, 4]]

### Train the model by repeatedly presenting each transition

In [45]:
n_presentations = 15

# arcs[j]: arc-area cap of transition j after the most recent presentation
arcs = np.zeros((len(transition_list), cap_size), dtype=int)

for _ in range(n_presentations):
    for j, transition in enumerate(transition_list):
        fsm_net.train(symbols[transition[1]], states[transition[0]], states[transition[2]])
        arcs[j] = fsm_net.arc_area.read()

# Check every learned transition: drive the arc area with (symbol, source state) and see
# which state assembly the state area lands on (weights frozen via update=False).
state_vecs = idx_to_vec(states, n_state_neurons)   # loop-invariant, hoisted
state_overlaps = np.zeros((len(transition_list), n_states))
for i, (src_state, symbol, _) in enumerate(transition_list):
    fsm_net.inhibit()
    fsm_net.arc_area.forward([symbols[symbol], states[src_state]], update=False)
    fsm_net.state_area.forward(fsm_net.arc_area.read(), update=False)
    state_overlaps[i] = state_vecs @ fsm_net.read(dense=True)

target_states = np.array([transition[2] for transition in transition_list])
n_correct = int(np.sum(state_overlaps.argmax(axis=1) == target_states))
print('Correctly learned transitions: {}/{}'.format(n_correct, len(transition_list)))

### Test the model by presenting a string of digits

Enter a string of digits separated by a comma and a space (`, `), for example `3, 0, 4, 7, 1`.

In [None]:
# Read the digits of the number to test. Digits may be separated by commas and/or spaces:
# "3, 0, 4", "3,0,4", "3 0 4" and "304" all give [3, 0, 4]. Pressing Enter without typing
# anything uses the example 3, 0, 4, 7, 1.
raw = input('Enter a string of digits: ')
digit_text = raw.strip() or '3, 0, 4, 7, 1'
invalid_chars = sorted(set(digit_text) - set('0123456789, '))
if invalid_chars:
    raise ValueError('Expected digits separated by commas/spaces, got invalid characters: {}'.format(invalid_chars))
digits = [int(ch) for ch in digit_text if ch in '0123456789']
if not digits:
    raise ValueError('Enter at least one digit')
# Append the end-of-input symbol (index 10, drawn as '□'), which leads to Accept or Reject
sequence = digits + [10]

In [9]:
# outputs[0]: state area after firing the start state (mod 0);
# outputs[k]: state area after consuming the k-th symbol (update=False throughout).
outputs = np.zeros((len(sequence) + 1, n_state_neurons))
fsm_net.state_area.fire(states[0], update=False)
outputs[0] = fsm_net.read(dense=True)
for step, symbol in enumerate(sequence, start=1):
    fsm_net.forward(symbols[symbol], update=False)
    outputs[step] = fsm_net.read(dense=True)

### Plot the result

In [10]:
n_rows = len(sequence) + 1

# Row i shows symbol sequence[i] next to the state active before it is consumed;
# the last row has no symbol and shows the state after the end symbol.
symbol_overlaps = np.zeros((n_rows, n_symbols))
symbol_overlaps[np.arange(len(sequence)), sequence] = 1.

state_overlaps = outputs @ idx_to_vec(states, n_state_neurons).T / cap_size

# Scale the height with the number of rows so long inputs stay readable
fig, axes = plt.subplots(n_rows, 2, figsize=(10, max(4, 0.6 * n_rows)), sharey=True, squeeze=False)
for i in range(n_rows):
    axes[i, 0].bar(np.arange(n_symbols), symbol_overlaps[i])
    axes[i, 1].bar(np.arange(n_states), state_overlaps[i])
    axes[i, 0].set_xticks(np.arange(n_symbols))
    axes[i, 1].set_xticks(np.arange(n_states))
    # Only the bottom row carries tick labels
    axes[i, 0].set_xticklabels([])
    axes[i, 1].set_xticklabels([])

axes[-1, 0].set_xticklabels([str(d) for d in range(n_symbols - 1)] + ['□'])
axes[-1, 1].set_xticklabels(['mod 0', 'mod 1', 'mod 2', 'Accept', 'Reject'])

axes[n_rows // 2, 0].set_ylabel('Round')
axes[0, 0].set_title('Symbol Area')
axes[0, 1].set_title('State Area')

for ax in axes.flatten():
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.set_yticks([])

fig.tight_layout()

<IPython.core.display.Javascript object>

## Simulate a PFA to generate simple sentences

We will train a NEMO network to simulate the following PFA, which generates sentences such as "the boy throws a ball" and "a dog chases the stick".

![simple_sentence_pfa.png](attachment:simple_sentence_pfa.png)

At each state, one of the outgoing transitions is sampled uniformly at random.

### Define a network of brain areas

In [102]:
# Area sizes for the PFA network
n_symbol_neurons, n_state_neurons = 1000, 500
n_arc_neurons, n_random_neurons = 5000, 1000
cap_size = 70
density = 0.25
plasticity = 0.1

pfa_net = PFANetwork(
    n_symbol_neurons, n_state_neurons, n_arc_neurons, n_random_neurons,
    cap_size, density, plasticity,
)

### Define the PFA (via its transitions)

In [103]:
lexicon = np.array(['the', 'a', 'boy', 'dog', 'throws', 'chases', 'ball', 'stick'])

# Each transition is [source state, choice, target state, emitted symbol (index into lexicon)].
# NOTE(review): `choice` (0/1) is presumably the outcome of the random-choice area -- confirm in brain.py.
transition_list = [
    [0, 0, 1, 0],  # subject article: 'the'
    [0, 1, 1, 1],  # subject article: 'a'
    [1, 0, 2, 2],  # subject: 'boy'
    [1, 1, 3, 3],  # subject: 'dog'
    [2, 0, 4, 4],  # verb after 'boy': 'throws'
    [2, 1, 4, 4],
    [3, 0, 4, 5],  # verb after 'dog': 'chases'
    [3, 1, 4, 5],
    [4, 0, 5, 0],  # object article: 'the'
    [4, 1, 5, 1],  # object article: 'a'
    [5, 0, 6, 6],  # object: 'ball'
    [5, 1, 6, 7],  # object: 'stick'
]

In [104]:
# Derive the state count from the transitions so the two cannot drift apart
# (highest state index referenced is 6, the end state, giving 7 states)
n_states = 1 + max(max(src, dst) for src, _, dst, _ in transition_list)
n_symbols = len(lexicon)
n_arcs = len(transition_list)

In [105]:
# Each state / symbol is a disjoint block of cap_size neurons; all blocks must fit in their area
assert n_states * cap_size <= n_state_neurons, 'too many states for the state area'
assert n_symbols * cap_size <= n_symbol_neurons, 'too many symbols for the symbol area'
states = np.arange(n_states * cap_size).reshape(n_states, cap_size)
symbols = np.arange(n_symbols * cap_size).reshape(n_symbols, cap_size)

### Train the model by repeatedly presenting each transition

In [106]:
n_presentations = 15

# Present every transition once per round, n_presentations rounds in total
for _ in range(n_presentations):
    for src_state, choice, dst_state, symbol in transition_list:
        pfa_net.train(states[src_state], choice, states[dst_state], symbols[symbol])

### Sample from the model

In [125]:
n_trials = 10
n_steps = 5  # every path through the PFA (state 0 -> ... -> state 6) emits exactly 5 symbols

symbol_outputs = np.zeros((n_trials, n_steps, n_symbol_neurons))
state_outputs = np.zeros((n_trials, n_steps, n_state_neurons))

for i in range(n_trials):
    # Sampling should not change the learned weights; otherwise each sampled sentence
    # presumably reinforces its own path and biases later trials. update=False is used for
    # the same purpose in the FSM and sequence-recall cells.
    # NOTE(review): assumes PFANetwork.step accepts `update` like the other networks -- confirm in brain.py.
    pfa_net.state_area.fire(states[0], update=False)
    for j in range(n_steps):
        pfa_net.step(update=False)
        symbol_outputs[i, j] = pfa_net.read(dense=True)
        state_outputs[i, j] = pfa_net.state_area.read(dense=True)

In [126]:
# Sum each step's dense output over the cap_size-neuron block of every symbol / state,
# giving shapes (n_trials, n_steps, n_symbols) and (n_trials, n_steps, n_states).
# The step count is taken from the data instead of being hard-coded.
n_steps = symbol_outputs.shape[1]
symbol_overlaps = symbol_outputs[:, :, :cap_size * n_symbols].reshape(n_trials, n_steps, n_symbols, cap_size).sum(axis=-1)
state_overlaps = state_outputs[:, :, :cap_size * n_states].reshape(n_trials, n_steps, n_states, cap_size).sum(axis=-1)

In [127]:
# Decode each step as the lexicon entry whose symbol block has the largest overlap
output_symbols = np.take(lexicon, np.argmax(symbol_overlaps, axis=-1))

In [128]:
# Decoded sentences: one row per trial, one word per sampling step
output_symbols

array([['a', 'boy', 'throws', 'the', 'ball'],
       ['a', 'boy', 'throws', 'the', 'stick'],
       ['the', 'boy', 'throws', 'a', 'ball'],
       ['a', 'dog', 'chases', 'a', 'stick'],
       ['the', 'boy', 'throws', 'a', 'stick'],
       ['the', 'boy', 'throws', 'the', 'ball'],
       ['the', 'dog', 'chases', 'a', 'ball'],
       ['a', 'boy', 'throws', 'a', 'ball'],
       ['a', 'boy', 'throws', 'the', 'stick'],
       ['the', 'boy', 'throws', 'a', 'ball']], dtype='<U6')

### Plot the activations for one sample

In [136]:
trial = 0
n_steps = symbol_overlaps.shape[1]
# State names indexed like `states`: state 2 (after 'boy') emits 'throws',
# state 3 (after 'dog') emits 'chases'
state_names = ['subj\nart', 'subj', 'verb\nthrow', 'verb\nchase', 'obj\nart', 'obj', 'end']

fig, axes = plt.subplots(n_steps + 1, 2, figsize=(10, 4), sharey=True)

# Row 0: initial condition -- no symbol yet, only state 0 firing (its cap_size neurons)
initial_state = np.zeros(n_states)
initial_state[0] = cap_size
axes[0, 0].bar(np.arange(n_symbols), np.zeros(n_symbols))
axes[0, 1].bar(np.arange(n_states), initial_state)
axes[0, 0].set_xticks(np.arange(n_symbols))
axes[0, 1].set_xticks(np.arange(n_states))

for i in range(n_steps):
    axes[i + 1, 0].bar(np.arange(n_symbols), symbol_overlaps[trial, i])
    axes[i + 1, 1].bar(np.arange(n_states), state_overlaps[trial, i])
    axes[i + 1, 0].set_xticks(np.arange(n_symbols))
    axes[i + 1, 1].set_xticks(np.arange(n_states))

# Only the bottom row carries tick labels
for ax in axes:
    ax[0].set_xticklabels([])
    ax[1].set_xticklabels([])

axes[-1, 0].set_xticklabels(lexicon)
axes[-1, 1].set_xticklabels(state_names)

# Label the middle row of this figure
axes[axes.shape[0] // 2, 0].set_ylabel('Round')
axes[0, 0].set_title('Symbol Area')
axes[0, 1].set_title('State Area')

for ax in axes.flatten():
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.set_yticks([])

fig.tight_layout()

<IPython.core.display.Javascript object>