In [16]:
import nengo
import numpy as np

from nengo_extras.data import load_mnist, one_hot_from_labels
from nengo_extras.matplotlib import tile
from nengo_extras.vision import Gabor, Mask

from torchvision.datasets import MNIST
import matplotlib.pyplot as plt

# Seeded legacy NumPy RandomState so any random draws made with it are reproducible.
rng = np.random.RandomState(seed=9)

# Experimenting with the nengo-extras [image encoding](https://www.nengo.ai/nengo-extras/examples/mnist_single_layer.html) tutorial

In [81]:
from torchvision.transforms import v2
import torch

# Fetch MNIST through torchvision instead of nengo_extras.data.load_mnist,
# whose download URL appears to be unavailable.
# NOTE: `transform` only applies when indexing `dataset[i]`; reading
# `dataset.data` directly bypasses it.
dataset = MNIST(
    root='data3/',
    train=True,  # training split (torchvision's default), 60000 images
    download=True,
    transform=v2.ToDtype(torch.float32),
)


65.8%

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data3/MNIST/raw/train-images-idx3-ubyte.gz


100.0%


Extracting data3/MNIST/raw/train-images-idx3-ubyte.gz to data3/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data3/MNIST/raw/train-labels-idx1-ubyte.gz


100.0%
100.0%
100.0%


Extracting data3/MNIST/raw/train-labels-idx1-ubyte.gz to data3/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data3/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data3/MNIST/raw/t10k-images-idx3-ubyte.gz to data3/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data3/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data3/MNIST/raw/t10k-labels-idx1-ubyte.gz to data3/MNIST/raw



In [116]:
# Raw training images as a NumPy array of shape (60000 images, 28, 28),
# with integer pixel values in [0, 255].
raw_pixels = dataset.data.numpy()

# Rescale pixel intensities from [0, 255] to [-1, 1].
X = raw_pixels / 255 * 2 - 1

# One-hot labels Y of shape (60000 images, 10 categories).
Y = one_hot_from_labels(dataset.targets.numpy(), classes=10)

In [126]:
# Flatten each 28x28 image into one row vector (784 values). The image count
# comes from X itself rather than a hard-coded 60000, so this works for any
# subset or split and is safe to re-run on already-flattened data.
X = X.reshape((X.shape[0], -1))


In [127]:
# Sanity check: one flattened row per image.
np.shape(X)

(60000, 784)

In [128]:
# --- set up network parameters
# Input size = values per flattened image; output size = number of label classes.
n_vis, n_out = X.shape[1], Y.shape[1]
print(n_vis, n_out, sep="\n")

784
10


In [129]:


# --- model hyperparameters (named so they are easy to find and tune)
N_NEURONS = 1000     # ensemble size; more neurons generally give a better fit
INTERCEPT = 0.1      # same turn-on point (intercept) for every neuron
MAX_RATE = 100       # same maximum firing rate for every neuron
DECODER_REG = 0.01   # L2 regularization strength for the decoder solve

with nengo.Network(seed=0) as model:
    a = nengo.Ensemble(
        n_neurons=N_NEURONS,
        # One dimension per input value (n_vis == 784 for flattened MNIST);
        # this is also the dimensionality the ensemble represents.
        dimensions=n_vis,
        # Sample the represented space at the training images themselves.
        # X has n_vis columns, so it matches `dimensions`.
        eval_points=X,
        neuron_type=nengo.LIFRate(),   # leaky integrate-and-fire, rate (non-spiking) model
        intercepts=nengo.dists.Choice([INTERCEPT]),
        max_rates=nengo.dists.Choice([MAX_RATE]),
    )

    # Node collecting the n_out decoded class scores.
    v = nengo.Node(size_in=n_out)

    conn = nengo.Connection(
        a, v, synapse=None, eval_points=X,
        # Target values rather than a callable: row i of Y is the desired
        # output when the ensemble is evaluated at row i of X.
        function=Y,
        # Decoders are fit by L2-regularized least squares.
        solver=nengo.solvers.LstsqL2(reg=DECODER_REG),
    )

def get_outs(simulator, images, ensemble=None, connection=None):
    """Decode outputs for `images` from an ensemble's rate activities.

    Parameters
    ----------
    simulator : nengo.Simulator
        Simulator built from the model containing `ensemble` and `connection`.
    images : array_like
        Inputs to encode, one row per image; each row must have the
        ensemble's dimensionality.
    ensemble : nengo.Ensemble, optional
        Ensemble whose tuning curves are evaluated. Defaults to the global `a`.
    connection : nengo.Connection, optional
        Connection whose solved decoders are applied. Defaults to the global `conn`.

    Returns
    -------
    numpy.ndarray
        Decoded outputs, one row per image and one column per connection
        output dimension.
    """
    # Resolve defaults at call time so the globals are looked up lazily,
    # which keeps existing `get_outs(sim, X)` calls working unchanged.
    if ensemble is None:
        ensemble = a
    if connection is None:
        connection = conn

    # Encode the images: activity of every neuron for every image.
    _, acts = nengo.utils.ensemble.tuning_curves(ensemble, simulator, inputs=images)

    # Apply the connection's decoders, shape (size_out, n_neurons), to the activities.
    return np.dot(acts, simulator.data[connection].weights.T)

In [130]:
# Build the simulator (this is where the connection's decoders are solved)
# and decode every training image. Keep the result instead of discarding it;
# decoding all images is the expensive step.
with nengo.Simulator(model) as sim:
    outs = get_outs(sim, X)

In [132]:
# Decoded output: one row per image, one column per class.
np.shape(get_outs(sim, X))

(60000, 10)

In [133]:
# Inspect the encoder distribution the ensemble was configured with.
encoder_dist = a.encoders
encoder_dist

ScatteredHypersphere(surface=True)