"""Train a one-ensemble NengoDL network that maps flattened images to 3 outputs.

All images and targets below are all-zero placeholders, so this script only
exercises the build / compile / fit / evaluate pipeline; substitute real data
to obtain meaningful error values.
"""

import nengo
import nengo_dl
import numpy as np
import tensorflow as tf


def test_mse(y_true, y_pred):
    """Return the mean squared error over the last 10 timesteps.

    The data dicts below are shaped ``(examples, timesteps, outputs)``, so
    ``[:, -10:]`` keeps only the final 10 simulation steps of each example
    (presumably to skip the start-up transient of the probe's synaptic filter).
    """
    return tf.reduce_mean(tf.square(y_pred[:, -10:] - y_true[:, -10:]))


# number of training and validation images (batch size)
n_training = 1000
n_validation = 100

# network parameters
seed = 0
n_neurons = 1000
minibatch_size = 10
n_steps = 1  # timesteps per training example (synapses are removed for training)
# Validation scores the low-pass-filtered probe, whose output needs many
# timesteps to rise towards its steady-state value, and ``test_mse`` only looks
# at the last 10 steps.  Presenting each validation image for a single step (as
# for training) would therefore measure mostly filter lag, so each validation
# example is held for n_test_steps timesteps instead.
n_test_steps = 30

# image parameters: 64x64 pixels with 3 channels, flattened to one vector
res = np.array([64, 64])
img_size = int(res[0] * res[1] * 3)

# dummy data for simplification; float32 halves the memory of the
# time-tiled validation arrays built below compared with the float64 default
training_images = np.zeros((n_training, img_size), dtype=np.float32)
training_targets = np.zeros((n_training, 3), dtype=np.float32)
validation_images = np.zeros((n_validation, img_size), dtype=np.float32)
validation_targets = np.zeros((n_validation, 3), dtype=np.float32)

with nengo.Network(seed=seed) as net:
    image_input = nengo.Node(size_out=img_size, output=np.zeros(img_size))
    dense_layer = nengo.Ensemble(n_neurons=n_neurons, dimensions=img_size)
    image_output = nengo.Node(size_in=3)

    nengo.Connection(image_input, dense_layer)

    # Linear readout from neuron activities to the 3 outputs, starting at zero.
    connection_weights = np.zeros((3, dense_layer.n_neurons))
    readout = nengo.Connection(
        dense_layer.neurons, image_output, transform=connection_weights
    )

    output_probe = nengo.Probe(image_output, synapse=0.01)

# turn off synapses for training to simplify
for conn in net.all_connections:
    conn.synapse = None

with net:
    output_probe_no_filter = nengo.Probe(image_output)

    # increase probe filter to account for removing internal filters
    output_probe.synapse = 0.04

    # Freeze everything by default, then opt specific objects back in.
    nengo_dl.configure_settings(trainable=False)
    net.config[nengo.Ensemble].trainable = True
    # The readout must be trainable as well: its weights start at zero, so if
    # it stayed frozen the network output would be identically zero, every
    # gradient flowing back through it into the ensemble would be zero, and
    # training could never change the error.
    net.config[readout].trainable = True

# create data dicts; every array is shaped (examples, timesteps, dimensions)
training_images_dict = {
    image_input: training_images.reshape((n_training, n_steps, img_size))
}
training_targets_dict = {
    output_probe_no_filter: training_targets.reshape((n_training, n_steps, 3))
}
# Hold each validation image (and its target) constant for n_test_steps steps.
validation_images_dict = {
    image_input: np.tile(validation_images[:, None, :], (1, n_test_steps, 1))
}
validation_targets_dict = {
    output_probe: np.tile(validation_targets[:, None, :], (1, n_test_steps, 1))
}

with nengo_dl.Simulator(net, minibatch_size=minibatch_size, seed=seed) as sim:
    print("Error before training:")
    sim.compile(loss={output_probe: test_mse})
    sim.evaluate(validation_images_dict, validation_targets_dict)

    # run the training against the unfiltered probe (connection synapses are
    # disabled above, so a single timestep per example suffices)
    sim.compile(
        optimizer=tf.optimizers.RMSprop(0.01),
        loss={output_probe_no_filter: tf.losses.mse},
    )
    sim.fit(training_images_dict, training_targets_dict, epochs=25)

    print("Error after training:")
    sim.compile(loss={output_probe: test_mse})
    sim.evaluate(validation_images_dict, validation_targets_dict)