In [None]:
!pip install nengo nengo-dl

In [2]:
import nengo
import nengo_dl
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import classification_report

In [3]:
class Dataset:
  NUM_CLASSES = 10
  DIM = 28 * 28

  def __init__(self, time_per_data: float):
    (train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()

    self.train_x = train_x.reshape([train_x.shape[0], self.DIM])
    self.train_y = np.eye(self.NUM_CLASSES)[train_y]
    self.test_x = test_x.reshape([test_x.shape[0], self.DIM])
    self.test_y = np.eye(self.NUM_CLASSES)[test_y]

    self.time_per_data = time_per_data
    self._is_test = False
    self._start_from = 0.0

  def __len__(self):
    return len(self.test_y) if self._is_test else len(self.train_y)

  def train(self, start_from: float=0.0):
    self._is_test = False
    self._start_from = start_from
  
  def test(self, start_from: float=0.0):
    self._is_test = True
    self._start_from = start_from

  def get_x(self, t: float) -> np.ndarray:
    t += self._start_from
    x = self.test_x if self._is_test else self.train_x
    idx = int(t // self.time_per_data) % len(x)
    return x[idx]

  def get_y(self, t: float) -> np.ndarray:
    t += self._start_from
    y = self.test_y if self._is_test else self.test_y
    idx = int(t // self.time_per_data) % len(y)
    return y[idx]

  @property
  def duration(self) -> float:
    return len(self) * self.time_per_data

In [4]:
time_per_data = 0.1
dataset = Dataset(time_per_data)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [5]:
model = nengo.Network()
with model:
  input_node = nengo.Node(dataset.get_x, size_out=dataset.DIM)
  target_node = nengo.Node(dataset.get_y, size_out=dataset.NUM_CLASSES)
  input = nengo.Ensemble(n_neurons=768, dimensions=dataset.DIM)
  target = nengo.Ensemble(n_neurons=64, dimensions=dataset.NUM_CLASSES)

  nengo.Connection(input_node, input)
  nengo.Connection(target_node, target)

  pred = nengo.Ensemble(n_neurons=64, dimensions=dataset.NUM_CLASSES)
  pred_conn = nengo.Connection(input, pred, transform=np.ones([dataset.NUM_CLASSES, dataset.DIM]))

  pred_p = nengo.Probe(pred, synapse=0.01)
  target_p = nengo.Probe(target, synapse=0.01)

In [6]:
def evaluate(sim: nengo_dl.Simulator):
  """Evaluate simulate results, calculate and print metrics"""
  timestep_of_data = int(time_per_data / sim.dt)
  timestep = int(sim.time / sim.dt)
  timestep -= timestep % timestep_of_data

  preds = sim.data[pred_p].reshape([timestep // timestep_of_data, timestep_of_data, -1]).mean(axis=1).argmax(-1)
  labels = sim.data[target_p].reshape([timestep // timestep_of_data, timestep_of_data, -1]).mean(axis=1).argmax(-1)

  return labels, preds

# Test Before Training

In [7]:
with nengo_dl.Simulator(model, device='/cpu:0') as sim:
  dataset.test()
  # Use Only 100 examples for speed
  sim.run(dataset.time_per_data * 100)
  print(classification_report(*evaluate(sim)))

Build finished in 0:00:02                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:00:07                                                 
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         8
           1       0.00      0.00      0.00        14
           2       0.00      0.00      0.00         8
           3       0.11      1.00      0.20        11
           4       0.00      0.00      0.00        14
           5       0.00      0.00      0.00         7
           6       0.00      0.00      0.00        10
           7       0.00      0.00      0.00        15
           8       0.00      0.00      0.00         2
           9       0.00      0.00      0.00        11

    accuracy                           0.11       100
   macro avg       0.01      0

  _warn_prf(average, modifier, msg_start, len(result))


# Train & Test

In [8]:
learning_rate = 2e-4
epochs = 5
batch_size = 1000
num_train_data_divider = 12

model_backup_path = "./model.bak"

In [9]:
with model:
  err = nengo.Ensemble(64, dimensions=dataset.NUM_CLASSES)
  nengo.Connection(pred, err)
  nengo.Connection(target, err, transform=-1)

  pred_conn.learning_rule_type = nengo.learning_rules.PES(learning_rate=learning_rate)
  err_conn = nengo.Connection(err, pred_conn.learning_rule)  

In [10]:
with nengo_dl.Simulator(model, device='/cpu:0') as sim:
  sim.save_params(model_backup_path)

for i in range(1, epochs+1):
  train_labels = []
  train_preds = []

  print(f"[+] Start training {i} epoch")
  dataset.train()
  for j in range(len(dataset) // batch_size // num_train_data_divider):
    with nengo_dl.Simulator(model, device="/cpu:0") as sim:
      # Load to ignore learning with test dataset
      sim.load_params(model_backup_path)

      dataset.train(start_from=batch_size * j * dataset.time_per_data)
      sim.run(batch_size * dataset.time_per_data)
      labels, preds = evaluate(sim)
      train_labels.extend(labels)
      train_preds.extend(preds)

      # Save model params
      sim.save_params(model_backup_path)
  print("[+] Training Metrics")
  print(classification_report(train_labels, train_preds))
  
  test_labels = []
  test_preds = []
  dataset.test()
  print(f"[+] Start Test {i} epoch")
  for j in range(len(dataset) // batch_size):
    with nengo_dl.Simulator(model, device="/cpu:0") as sim:
      dataset.test(start_from = batch_size * j * dataset.time_per_data)
      sim.run(batch_size * dataset.time_per_data)
      labels, preds = evaluate(sim)
      test_labels.extend(labels)
      test_preds.extend(preds)
  
  print("[+] Test Metrics")
  print(classification_report(test_labels, test_preds))

Build finished in 0:00:02                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
[+] Start training 1 epoch
Build finished in 0:00:02                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:01:10                                                 
Build finished in 0:00:02                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:01:27                                                 
Build finished in 0:00:02                                                      
Optimization 



Simulation finished in 0:01:10                                                 
Build finished in 0:00:02                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               




Simulation finished in 0:01:11                                                 
[+] Training Metrics
              precision    recall  f1-score   support

           0       0.03      0.10      0.04        21
           1       0.30      0.06      0.11       500
           2       0.00      0.00      0.00         0
           3       0.00      0.00      0.00         0
           4       0.09      0.27      0.14       474
           5       0.23      0.47      0.31       419
           6       0.05      0.01      0.02       654
           7       0.07      0.05      0.06       532
           8       0.00      0.00      0.00         0
           9       0.41      0.14      0.21      2400

    accuracy                           0.15      5000
   macro avg       0.12      0.11      0.09      5000
weighted avg       0.27      0.15      0.16      5000

[+] Start Test 1 epoch
|                     Building network (0%)                    | ETA:  --:--:--

  _warn_prf(average, modifier, msg_start, len(result))


Build finished in 0:00:02                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:01:11                                                 
Build finished in 0:00:02                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:01:09                                                 
Build finished in 0:00:18                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:01:11                                                 
Build finished in 0:00:02               

  _warn_prf(average, modifier, msg_start, len(result))


Build finished in 0:00:02                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:01:12                                                 
Build finished in 0:00:15                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:01:12                                                 
Build finished in 0:00:12                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:01:11                                                 
Build finished in 0:00:02               