In [108]:
# Note that this code uses libraries from the TENNLab Neuromorphic Framework
import eons
import neuro
import risp
from torchvision import datasets, transforms
from random import randint
import json

# some values
NUM_INPUTS = 784
NUM_OUTPUTS = 10
NUM_HIDDEN = 50
NUM_SYNAPSES = 150
NUM_NEURONS = NUM_INPUTS+NUM_OUTPUTS+NUM_HIDDEN 
MOA = neuro.MOA()
MOA.seed(132312, '') # NEED THESE OTHERWISE DOESN'T RANDOMIZE

In [109]:
# Fashion MNIST dataset
transform = transforms.ToTensor() # scale down to 0-1

train_dataset = datasets.FashionMNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.FashionMNIST(root='./data', train=False, transform=transform, download=True)

In [110]:
print(test_dataset[0][0].shape)

torch.Size([1, 28, 28])


In [111]:
risp_config = { #https://bitbucket.org/neuromorphic-utk/framework/src/93a1bbba2c0a31ca8f036ff2a89a4d6ffae5cec3/processors/risp/README.md#markdown-header-params
  "min_weight": -1,
  "max_weight": 1,
  "min_threshold": -1,
  "max_threshold": 1,
  "min_potential": -1,
  "max_delay": 10,
  "discrete": False
}

proc = risp.Processor(risp_config)

# configure eons
eons_param = {
    "starting_nodes": 4,
    "starting_edges": 10,
    "merge_rate": 0,
    "population_size": 1000,
    "multi_edges": 0,
    "crossover_rate": 0.5,
    "mutation_rate": 0.9,
    "selection_type": "tournament",
    "tournament_size_factor": 0.1,
    "tournament_best_net_factor": 0.9,
    "random_factor": 0.05,
    "num_mutations": 5,
    "node_mutations": { "Threshold": 1.0 },
    "net_mutations": { },
    "edge_mutations": { "Weight": 0.65 , "Delay": 0.35 },
    "num_best" : 2
}
eons_inst = eons.EONS(eons_param)

In [112]:
# create the network from which eons will base the others
net = neuro.Network()
net.set_properties(proc.get_network_properties())

# create input, output and hidden neurons
for i in range(NUM_INPUTS):
    node = net.add_node(i)
    node.set("Threshold", 1)
    net.randomize_node_properties(MOA, node)
    net.add_input(i)

for i in range(NUM_INPUTS, NUM_INPUTS+NUM_OUTPUTS):
    node = net.add_node(i)
    node.set("Threshold", 1)
    net.randomize_node_properties(MOA, node)
    net.add_output(i)

for i in range(NUM_INPUTS+NUM_OUTPUTS, NUM_INPUTS+NUM_OUTPUTS+NUM_HIDDEN):
    node = net.add_node(i)
    net.randomize_node_properties(MOA, node)

# create synapses randomly
for i in range(NUM_SYNAPSES):
    synapse = net.add_or_get_edge(randint(0, NUM_NEURONS-1), randint(0, NUM_NEURONS-1)) # randomly get 2 ids
    net.randomize_edge_properties(MOA, synapse)

In [113]:
# IGNORE THIS # IGNORE THIS # IGNORE THIS # IGNORE THIS # IGNORE THIS # IGNORE THIS # IGNORE THIS # IGNORE THIS # IGNORE THIS 

#encoder_params = {
#    "encoder_specs" : {
#            "default" : {
#                "bins"          : n_inputs,
#                "max_spikes"    : 1, 
#                "ov_interval"   : 2,
#                "intrabin"      : "spike_count",
 #               "interbin"      : "simple",
#                "spike_min"     : 1,
#               "spike_max"     : 1
 #           }
 #       },
 #   "dmin" : [0],
 #   "dmax" : [1],
 #   "use_encoders" : ["default"]
#}


# Create the encoder with the appropriate encoder parameters.  
# The first two parameters for Encoder Array correspond to the minimum and maximum values that can be obtained. 
#encoder = neuro.EncoderArray(encoder_params)

## OTHER

# Run one sample through our network
#encoder = neuro.SpikeEncoder('rate')
#encoder.set_overall_interval(10, True)

#encoder.get_spikes(1, 0, 1)
with open('dump.json', 'w') as f:
    f.write(str(net))

In [114]:
# Set up encoder, we'll use rate encoding

settings_encoder = {
    'dmin': [0 for i in range(784)],
    'dmax': [1 for i in range(784)],
    'default_interval': 10,
    "encoders": [ "rate" ]
}

encoder = neuro.EncoderArray(settings_encoder)

spikes = encoder.get_spikes(test_dataset[0][0].flatten().tolist())

# debug - looks like encoder array automatically creates our spikes at the given rate for each id
# I'm assuming it just guesses input ids based on element of our input list index?
for s in spikes:
    print(f'({s.id} at {s.time:.2f})', end=', ')

(215 at 0.00), (216 at 0.00), (219 at 0.00), (221 at 0.00), (221 at 6.89), (237 at 0.00), (238 at 0.00), (240 at 0.00), (240 at 9.44), (241 at 0.00), (241 at 3.04), (241 at 6.07), (241 at 9.11), (242 at 0.00), (249 at 0.00), (249 at 2.14), (249 at 4.29), (249 at 6.43), (249 at 8.57), (265 at 0.00), (268 at 0.00), (268 at 2.90), (268 at 5.80), (268 at 8.69), (269 at 0.00), (269 at 1.78), (269 at 3.57), (269 at 5.35), (269 at 7.13), (269 at 8.92), (270 at 0.00), (270 at 2.32), (270 at 4.64), (270 at 6.95), (270 at 9.27), (275 at 0.00), (276 at 0.00), (276 at 2.74), (276 at 5.48), (276 at 8.23), (277 at 0.00), (277 at 2.41), (277 at 4.81), (277 at 7.22), (277 at 9.62), (293 at 0.00), (295 at 0.00), (295 at 4.81), (295 at 9.62), (296 at 0.00), (296 at 1.98), (296 at 3.95), (296 at 5.93), (296 at 7.91), (296 at 9.88), (297 at 0.00), (297 at 2.12), (297 at 4.25), (297 at 6.37), (297 at 8.50), (298 at 0.00), (298 at 1.73), (298 at 3.47), (298 at 5.20), (298 at 6.94), (298 at 8.67), (299 at 0.

In [116]:
# Apply spikes 
proc.load_network(net)

for i in range(net.num_nodes()):
    proc.track_neuron_events(i)

proc.apply_spikes(spikes)

proc.run(1000)
v = proc.neuron_vectors()
for i in range(len(v)):
    if len(v[i]) != 0:
        print(f'{i} fired at {v[i]}')

#proc.neuron_last_fires()

19 fired at [8.0, 10.0, 12.0, 14.0, 16.0]
28 fired at [6.0, 7.0, 9.0, 10.0, 12.0, 14.0, 15.0]
51 fired at [5.0, 6.0, 8.0, 9.0, 11.0, 13.0, 14.0]
78 fired at [9.0, 11.0, 14.0, 17.0]
89 fired at [14.0]
102 fired at [4.0, 5.0, 7.0, 9.0, 10.0, 12.0]
179 fired at [6.0, 7.0, 9.0, 10.0, 12.0, 13.0, 15.0]
193 fired at [10.0, 14.0, 17.0]
215 fired at [0.0]
216 fired at [0.0]
219 fired at [0.0]
221 fired at [0.0, 6.0]
237 fired at [0.0]
238 fired at [0.0]
240 fired at [0.0, 9.0]
241 fired at [0.0, 3.0, 6.0, 9.0]
242 fired at [0.0]
249 fired at [0.0, 2.0, 4.0, 6.0, 8.0]
265 fired at [0.0]
268 fired at [0.0, 2.0, 5.0, 8.0]
269 fired at [0.0, 1.0, 3.0, 5.0, 7.0, 8.0]
270 fired at [0.0, 2.0, 4.0, 6.0, 9.0]
275 fired at [0.0]
276 fired at [0.0, 2.0, 5.0, 8.0]
277 fired at [0.0, 2.0, 4.0, 7.0, 9.0]
293 fired at [0.0]
295 fired at [0.0, 4.0, 9.0]
296 fired at [0.0, 1.0, 3.0, 5.0, 7.0, 9.0]
297 fired at [0.0, 2.0, 4.0, 6.0, 8.0]
298 fired at [0.0, 1.0, 3.0, 5.0, 6.0, 8.0]
299 fired at [0.0, 1.0, 2.0, 4.