### 0. Import Libraries

In [1]:
import torch
import torchvision

### 1.1 Hyperparameters

In [2]:
inputs=784 + 1 # bias input
hidden=100
outputs=10
ticks = 3

batch_size=64
epochs = 100
lr = 0.001

### 1.2 Dataset

In [3]:
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST("/tmp/mnist_data", train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST("/tmp/mnist_data", train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size, shuffle=True)

### 2 MNN forward definition

In [4]:
def net(s, x):
    s[:,0:inputs-1] = x
    s[:,inputs] = 1
    t = torch.matmul(s, A)
    s = torch.relu(t)
    return s

### 3. Training Loop

In [5]:
neurons = inputs + hidden + outputs
A = torch.rand(neurons, neurons).requires_grad_(True)
optimizer = torch.optim.Adam([A], lr=lr)
loss_fn = torch.nn.CrossEntropyLoss()

for epoch in range(0, epochs):
    for i, (batch_X, batch_y) in enumerate(train_loader):
        state = torch.zeros(batch_X.shape[0], neurons)
        for t in range(0, ticks):
            state = net(state, batch_X.view(-1, 28*28))
            
        outs = state[:,inputs+hidden:neurons]
        loss = loss_fn(outs, batch_y)
        
        # Using pytorch backpropagation becouse our numpy implementaiton is inefficent
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        acc = (outs.max(1).indices == batch_y).float().mean()
        
        #mean_loss = 0.01*loss.item() + 0.99*mean_loss
        #mean_acc = 0.01*acc.item() + 0.99*mean_acc
        
        if i % 200 == 0:
            print("epoch: %d  loss: %f  acc: %f" % (epoch, loss.item(), acc.item()))


epoch: 0  loss: 3372.194092  acc: 0.062500
epoch: 0  loss: 26.474407  acc: 0.359375
epoch: 0  loss: 6.621606  acc: 0.421875
epoch: 0  loss: 8.578157  acc: 0.515625
epoch: 0  loss: 5.717260  acc: 0.390625
epoch: 1  loss: 27.072432  acc: 0.468750
epoch: 1  loss: 8.286741  acc: 0.375000
epoch: 1  loss: 2.454543  acc: 0.468750
epoch: 1  loss: 6.169957  acc: 0.500000
epoch: 1  loss: 1.970924  acc: 0.390625
epoch: 2  loss: 7.954168  acc: 0.421875
epoch: 2  loss: 1.510525  acc: 0.375000
epoch: 2  loss: 30.964554  acc: 0.281250
epoch: 2  loss: 6.005965  acc: 0.343750
epoch: 2  loss: 6.202227  acc: 0.312500
epoch: 3  loss: 25.473427  acc: 0.359375
epoch: 3  loss: 1.553840  acc: 0.390625
epoch: 3  loss: 3.897541  acc: 0.296875
epoch: 3  loss: 10.986524  acc: 0.328125
epoch: 3  loss: 1.484676  acc: 0.375000
epoch: 4  loss: 4.381802  acc: 0.359375
epoch: 4  loss: 12.233021  acc: 0.312500
epoch: 4  loss: 8.849322  acc: 0.203125
epoch: 4  loss: 1.799240  acc: 0.234375
epoch: 4  loss: 19.747335  acc:

epoch: 40  loss: 0.493053  acc: 0.796875
epoch: 40  loss: 0.513560  acc: 0.796875
epoch: 40  loss: 0.422060  acc: 0.812500
epoch: 40  loss: 0.373325  acc: 0.828125
epoch: 41  loss: 0.712073  acc: 0.734375
epoch: 41  loss: 0.823605  acc: 0.656250
epoch: 41  loss: 0.539759  acc: 0.765625
epoch: 41  loss: 0.522847  acc: 0.843750
epoch: 41  loss: 0.630685  acc: 0.718750
epoch: 42  loss: 0.578489  acc: 0.781250
epoch: 42  loss: 0.584081  acc: 0.750000
epoch: 42  loss: 0.532507  acc: 0.765625
epoch: 42  loss: 0.790715  acc: 0.671875
epoch: 42  loss: 0.468103  acc: 0.812500
epoch: 43  loss: 0.463063  acc: 0.812500
epoch: 43  loss: 0.681275  acc: 0.734375
epoch: 43  loss: 0.546149  acc: 0.781250
epoch: 43  loss: 0.657347  acc: 0.750000
epoch: 43  loss: 0.407549  acc: 0.828125
epoch: 44  loss: 0.672337  acc: 0.781250
epoch: 44  loss: 0.586630  acc: 0.812500
epoch: 44  loss: 0.462358  acc: 0.828125
epoch: 44  loss: 0.518188  acc: 0.781250
epoch: 44  loss: 0.469771  acc: 0.843750
epoch: 45  loss:

epoch: 80  loss: 0.094495  acc: 0.953125
epoch: 80  loss: 0.130607  acc: 0.968750
epoch: 80  loss: 0.174940  acc: 0.953125
epoch: 80  loss: 0.077542  acc: 0.968750
epoch: 81  loss: 0.128128  acc: 0.968750
epoch: 81  loss: 0.129849  acc: 0.953125
epoch: 81  loss: 0.142679  acc: 0.937500
epoch: 81  loss: 0.196485  acc: 0.953125
epoch: 81  loss: 0.110845  acc: 0.953125
epoch: 82  loss: 0.060586  acc: 0.968750
epoch: 82  loss: 0.079515  acc: 0.968750
epoch: 82  loss: 0.103322  acc: 0.968750
epoch: 82  loss: 0.130299  acc: 0.937500
epoch: 82  loss: 0.176548  acc: 0.953125
epoch: 83  loss: 0.080915  acc: 0.968750
epoch: 83  loss: 0.175514  acc: 0.937500
epoch: 83  loss: 0.170380  acc: 0.953125
epoch: 83  loss: 0.178971  acc: 0.968750
epoch: 83  loss: 0.094930  acc: 0.984375
epoch: 84  loss: 0.004027  acc: 1.000000
epoch: 84  loss: 0.045177  acc: 0.984375
epoch: 84  loss: 0.041374  acc: 0.984375
epoch: 84  loss: 0.182671  acc: 0.937500
epoch: 84  loss: 0.004714  acc: 1.000000
epoch: 85  loss:

### 4. Testing Loop

In [6]:
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score

y_true = []
y_pred = []

for i, (batch_X, batch_y) in enumerate(test_loader):
    state = torch.zeros(batch_X.shape[0], neurons)
    for t in range(0, ticks):
        state = net(state, batch_X.view(-1, 28*28))

    outs = state[:,inputs+hidden:neurons]
    preds = outs.max(1).indices
    
    y_true.extend(batch_y.detach().numpy().tolist())
    y_pred.extend(preds.detach().numpy().tolist())

print("Accuracy:")
print(accuracy_score(y_true, y_pred))
print("Confusion Matrix:")
print(confusion_matrix(y_true, y_pred))

Accuracy:
0.9414
Confusion Matrix:
[[ 959    1    3    0    1    3    8    3    2    0]
 [   2 1116    3    1    0    1    3    4    5    0]
 [  32    2  965    7    3    1    7    6    8    1]
 [  19    0   12  940    2   11    2   10   12    2]
 [  18    0    5    0  921    1    8    5    5   19]
 [  32    3    1   21    5  796   10    3   15    6]
 [  23    1    4    0    5   10  910    1    4    0]
 [  12    8   11    4    4    0    0  977    3    9]
 [  20    4    3   10    6    4    6   11  904    6]
 [  16    7    1    6   14    3    0   24   12  926]]
