In [None]:
import torch
import random
from torch.nn import PairwiseDistance

In [None]:
class NeuronEnsembles:
  '''
  A bank of random sparse binary neuron ensembles.

  The ensembles are stored as a (field_size x num_ensembles) tensor: each column is
  one ensemble, and each entry is 1 when its uniform random draw is <= `sparsity`
  and 0 otherwise.
  '''

  def __init__(self, field_size=20000, num_ensembles=5000, sparsity=0.003):
    self.field_size = field_size
    self.num_ensembles = num_ensembles
    self.sparsity = sparsity
    self.generate_random_ensembles()

  def generate_random_ensembles(self):
    '''(Re)draw every ensemble and store the 0/1 tensor in `self.ensembles`.'''
    draws = torch.rand(self.field_size, self.num_ensembles)
    # Draws at or below `sparsity` become active (1); the rest are silent (0).
    self.ensembles = (draws <= self.sparsity).to(draws.dtype)

  def get_ensembles(self):
    '''Return the full (field_size x num_ensembles) ensemble tensor.'''
    return self.ensembles

  def get_random_ensemble(self):
    '''Return one uniformly chosen ensemble (a column view of the stored tensor).'''
    column = random.randrange(self.num_ensembles)
    return self.ensembles[:, column]
  
class Network:
  '''
  Two-field binary associative memory.

  `train` builds `self.weights`, a (field_2_size x field_1_size) matrix. A forward
  pass maps field_1 activity to field_2 through `weights`; a backward pass maps
  field_2 activity back to field_1 through `weights.t()`. Both passes binarize
  the result with `activate`.
  '''

  def __init__(self, field_1_size, field_2_size):
    self.field_1 = torch.zeros(field_1_size)
    self.field_2 = torch.zeros(field_2_size)

  def train(self, field_1_ensembles, field_2_ensembles):
    '''
    Store paired ensembles in the weight matrix.

    Args:
      field_1_ensembles: (field_1_size x num_ensembles) tensor, one pattern per column.
      field_2_ensembles: (field_2_size x num_ensembles) tensor; column k is paired
        with column k of `field_1_ensembles`.

    Sets `self.weights[i, j] = sum_b field_2[i, b] * field_1[j, b]`, with every
    positive entry clipped to 1 (a 0/1 matrix when the ensembles are 0/1).
    '''
    weights = torch.einsum('ib,jb->ij', field_2_ensembles, field_1_ensembles)
    weights[weights > 0] = 1
    self.weights = weights

  def set_field(self, activity, which):
    '''
    Replace the activity of field `which` ('field_1' or 'field_2').

    The tensor is stored by reference (not copied).

    Raises:
      ValueError: if `which` is not 'field_1' or 'field_2'.
    '''
    if which == 'field_1':
      self.field_1 = activity
    elif which == 'field_2':
      self.field_2 = activity
    else:
      raise ValueError("'which' must be 'field_1' or 'field_2'")

  def read_field(self, which):
    '''
    Return the current activity of field `which` ('field_1' or 'field_2').

    Raises:
      ValueError: if `which` is not 'field_1' or 'field_2'.
    '''
    if which == 'field_1':
      return self.field_1
    elif which == 'field_2':
      return self.field_2
    else:
      raise ValueError("'which' must be 'field_1' or 'field_2'")

  def forward(self, threshold=3):
    '''field_2 <- activate(weights @ field_1). Requires `train` to have been called.'''
    self.field_2 = self.activate(torch.mv(self.weights, self.field_1), threshold)

  def backward(self, threshold=3):
    '''field_1 <- activate(weights.T @ field_2). Requires `train` to have been called.'''
    self.field_1 = self.activate(torch.mv(self.weights.t(), self.field_2), threshold)

  def cycle(self, threshold=3):
    '''Forward and backward pass of network'''
    self.forward(threshold)
    self.backward(threshold)

  def cycle_with(self, activity, threshold=3):
    '''
    Load `activity` into field_1, run one forward/backward cycle with `threshold`,
    and return the recalled field_1 activity.
    '''
    self.set_field(activity, 'field_1')
    # Pass the threshold through; previously it was silently dropped.
    self.cycle(threshold)
    return self.read_field('field_1')

  def reset_activity(self):
    '''
    Set both fields to zero.

    Fresh zero tensors are assigned instead of zeroing in place. After `set_field`,
    a field may alias a caller's tensor (e.g. a column view of stored ensembles),
    and in-place zeroing would wipe that tensor too.
    '''
    self.field_1 = torch.zeros_like(self.field_1)
    self.field_2 = torch.zeros_like(self.field_2)

  @staticmethod
  def activate(field, threshold):
    '''
    Binarize `field`: entries greater than (max(field) - threshold) become 1 and
    all others 0.

    Returns a new tensor with the same dtype as `field`; `field` is not modified.
    '''
    return (field > field.max() - threshold).to(field.dtype)
  
def dist(x1, x2, p=1):
  '''Return the p-norm of (x1 - x2) as a Python number (L1 by default).'''
  distance = torch.dist(x1, x2, p=p)
  return distance.item()

# Train and test the attractor network

In [None]:
# train the network
# Seed every RNG the notebook draws from (torch for the ensembles and corruption,
# `random` for picking an ensemble) so results are reproducible on Restart & Run All.
seed = 0
random.seed(seed)
torch.manual_seed(seed)

field_1_size = 20000
field_2_size = 20000
num_ensembles = 5000
sparsity = 0.003  # probability that a neuron is active in an ensemble (NeuronEnsembles default)
field_1_ensembles = NeuronEnsembles(field_1_size, num_ensembles, sparsity)
field_2_ensembles = NeuronEnsembles(field_2_size, num_ensembles, sparsity)
# NOTE: the weight matrix is field_2_size x field_1_size, i.e. ~1.6 GB at these sizes
# assuming the default float32 dtype.
net = Network(field_1_size, field_2_size)
net.train(field_1_ensembles.get_ensembles(), field_2_ensembles.get_ensembles())

In [None]:
# test the network: present a stored field-1 ensemble, run one forward/backward
# cycle, and report the L1 distance between the input and the recalled pattern.
x = field_1_ensembles.get_random_ensemble()
x_ = net.cycle_with(x)
print('distance between original and recovered ensemble: {}'.format(dist(x, x_)))

distance between original and recovered ensemble: 0.0


# Experiment #1: Pattern Corruption & Recovery
How heavily can we corrupt an initial ensemble and still reliably recover the correct one?

In [None]:
def corrupt_ensemble(ensemble, pct_disable=0.25, pct_activate=0.25):
  '''
  Return a corrupted copy of a neural ensemble.

  A fraction `pct_disable` of the active neurons is switched off, and a number of
  inactive neurons equal to `pct_activate` times the active count is switched on.
  Both counts are based on the number of active neurons in the ensemble passed
  in. Both index sets are drawn from that unmodified input, so a neuron disabled
  here is never re-enabled in the same call. The input tensor is not modified.

  Args:
    ensemble: 0/1 activity tensor, assumed 1-D (a single ensemble column, as
      returned by NeuronEnsembles.get_random_ensemble).
    pct_disable: fraction of active neurons to silence.
    pct_activate: number of inactive neurons to activate, as a fraction of the
      active count (may exceed 1; capped by the number of inactive neurons).

  Returns:
    A new tensor with the same shape and dtype as `ensemble`.
  '''
  corrupted = ensemble.clone()

  # Index sets come from the input argument (never a notebook global).
  active_idx = (ensemble != 0).nonzero(as_tuple=True)[0]
  inactive_idx = (ensemble == 0).nonzero(as_tuple=True)[0]

  num_active = active_idx.numel()
  num_disable = int(num_active * pct_disable)
  num_activate = int(num_active * pct_activate)

  # randomly disable active neurons
  disable_idx = active_idx[torch.randperm(active_idx.numel())[:num_disable]]
  corrupted[disable_idx] = 0

  # randomly enable inactive neurons
  activate_idx = inactive_idx[torch.randperm(inactive_idx.numel())[:num_activate]]
  corrupted[activate_idx] = 1
  return corrupted

# Sanity-check corrupt_ensemble on one stored ensemble. Each call bases its counts
# on the active neurons of the tensor it receives, so the second call adds half of
# the already-reduced active count.
x = field_1_ensembles.get_random_ensemble()
print('Num active original: {}'.format(torch.sum(x).item()))

x = corrupt_ensemble(x, pct_activate=0, pct_disable=0.5)
print('Num active after disable: {}'.format(torch.sum(x).item()))

x = corrupt_ensemble(x, pct_activate=0.5, pct_disable=0)
print('Num active after enable: {}'.format(torch.sum(x).item()))

Num active original: 68.0
Num active after disable: 34.0
Num active after enable: 51.0


In [None]:
# Heavily corrupt a stored ensemble: silence half of its active neurons and switch
# on 1.5x the original active count among silent ones. Then run one forward/backward
# cycle and show the L1 distance between the original and the recalled pattern.
x = field_1_ensembles.get_random_ensemble()
x_c = corrupt_ensemble(x, pct_activate=1.5, pct_disable=0.5)
x_ = net.cycle_with(activity=x_c)
dist(x, x_)

2.0

# Experiment #2: Capacity
What is the capacity of this network, i.e., how many patterns can it store and still recover reliably?

# Discussion
Let's discuss the intuition behind why this simple model works.

Suppose our network consists of only one pair $(\vec{x},\vec{y})$ of neuron ensembles. We construct a weight matrix $\vec{W} = \vec{y} \vec{x}^T$.

During the forward pass, we compute $\vec{y}^* = a(\vec{W}\vec{x}) \approx \vec{y}$, where $a(\cdot)$ is our neural-field activation function.

Then, during the backward pass, we compute $\vec{x}^* = a(\vec{W}^T\vec{y}^*) \approx \vec{x}$.

It's easy to see how this works when there is just one pair of ensembles to learn, but why does it work with many? The magic is that we are taking advantage of high-dimensional spaces and sparse distributed representations.

>$\let\vec\mathbf$
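>$\vec{W} = \min\left(1, \sum_{k} \vec{y}_k \vec{x}_k^T\right)$ (element-wise), i.e. the binarized sum of the outer products of every stored pair $(\vec{x}_k, \vec{y}_k)$.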



# References
https://redwood.berkeley.edu/wp-content/uploads/2018/01/knoblauch2010memory.pdf