# Fast-Classifying, High-Accuracy Spiking Deep Networks

In [1]:
import torch
from torchvision.transforms import ToTensor, Compose, Normalize
from torchvision.datasets import MNIST

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(device)

cpu


In [2]:
# The pixel data is already in the 0-1 range once converted by ToTensor,
# so no extra normalisation transform is applied.
transform_data = ToTensor()

# Load the MNIST train/test splits (downloaded into ./mnist/ if missing).
batch_size = 100
mnist_kwargs = dict(root='./mnist/', download=True, transform=transform_data)
train_dataset = MNIST(train=True, **mnist_kwargs)
test_dataset = MNIST(train=False, **mnist_kwargs)

# Shuffle only the training data; keep the test order fixed.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False
)

In [3]:
# fully connected neural network
class FC_Net(torch.nn.Module):
    """Bias-free fully connected ReLU network with two hidden layers.

    Layout: ``n_x -> n_h[0] -> n_h[1] -> n_y``. All three ``Linear`` layers
    are created with ``bias=False``; every layer (including the last) is
    followed by a ReLU, and dropout (default ``p=0.5``) is applied after the
    two hidden activations.

    Args:
        n_x: Number of input features; each sample is flattened to this size.
        n_h: Hidden-layer widths. Only ``n_h[0]`` and ``n_h[1]`` are used.
        n_y: Number of output units.
    """

    def __init__(self, n_x: int, n_h: list, n_y: int):
        super().__init__()

        self.in_layer = torch.nn.Linear(n_x, n_h[0], bias=False)
        self.h1_layer = torch.nn.Linear(n_h[0], n_h[1], bias=False)
        # Despite its name, this is the output layer (n_h[1] -> n_y).
        self.h2_layer = torch.nn.Linear(n_h[1], n_y, bias=False)
        self.dropout = torch.nn.Dropout()
        self.activator = torch.nn.ReLU()

    def forward(self, x):
        """Return the (non-negative) output activations for a batch ``x``.

        ``x`` is reshaped to ``(batch, -1)``, so any input whose trailing
        dimensions multiply to ``n_x`` is accepted.
        """
        # Flatten everything after the batch dimension. ``reshape`` has the
        # same semantics as ``view`` here but, unlike ``view``, also works on
        # non-contiguous inputs (e.g. transposed or sliced tensors).
        x = x.reshape(x.size(0), -1)

        inp = self.dropout(self.activator(self.in_layer(x)))
        h1 = self.dropout(self.activator(self.h1_layer(inp)))
        # NOTE(review): the output layer also goes through ReLU, so returned
        # values are >= 0 even though training feeds them to cross_entropy as
        # logits; presumably intentional (non-negative activations for the
        # spiking-network conversion) - confirm before changing.
        y = self.activator(self.h2_layer(h1))

        return y

    def save_parameters(self, path: str):
        """Save each layer module to its own file with ``torch.save``.

        Writes ``<path>0.pt`` (in_layer), ``<path>1.pt`` (h1_layer) and
        ``<path>2.pt`` (h2_layer). ``path`` is used as a plain filename
        prefix, not a directory. Whole ``Linear`` modules are saved, not
        just their state dicts.
        """
        layers = (self.in_layer, self.h1_layer, self.h2_layer)
        for index, layer in enumerate(layers):
            torch.save(layer, f"{path}{index}.pt")

In [4]:
def init_weights(m):
    """Re-initialise a ``torch.nn.Linear`` module's weights uniformly in [-0.1, 0.1].

    Meant to be passed to ``Module.apply``. Modules that are not ``Linear``
    are left untouched, and biases are never modified.
    """
    if not isinstance(m, torch.nn.Linear):
        return
    bound = 0.1
    torch.nn.init.uniform_(m.weight, a=-bound, b=bound)

# Build the 784-1200-1200-10 network, move it to the selected device, then
# re-initialise every Linear layer's weights with init_weights.
fc_net = FC_Net(784, [1200, 1200], 10)
fc_net = fc_net.to(device)
fc_net.apply(init_weights)

# Adam with its default hyper-parameters. The SGD optimiser from the original
# paper seemed to kill the gradients, so it is kept here only for reference:
# optimiser = torch.optim.SGD(fc_net.parameters(), lr=.01, momentum=0.5)
optimiser = torch.optim.Adam(fc_net.parameters())

fc_net.train()

FC_Net(
  (in_layer): Linear(in_features=784, out_features=1200, bias=False)
  (h1_layer): Linear(in_features=1200, out_features=1200, bias=False)
  (h2_layer): Linear(in_features=1200, out_features=10, bias=False)
  (dropout): Dropout(p=0.5, inplace=False)
  (activator): ReLU()
)

In [5]:
# Training model
num_epochs = 15
# Put the model in training mode (dropout active) inside this cell, so that
# re-running it after the evaluation cell (which calls fc_net.eval()) does not
# silently train with dropout disabled.
fc_net.train()
for epoch in range(num_epochs):
    # Go through all samples in the train dataset
    for i, (images, labels) in enumerate(train_loader):
        # Get from dataloader and send to device
        images = images.to(device)
        labels = labels.to(device)
        # Forward pass
        outputs = fc_net(images)
        # Compute loss
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        # Backward and optimize
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        # Display progress every 100 batches
        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

Epoch [1/15], Step [100/600], Loss: 0.6633
Epoch [1/15], Step [200/600], Loss: 0.5264
Epoch [1/15], Step [300/600], Loss: 0.4307
Epoch [1/15], Step [400/600], Loss: 0.3239
Epoch [1/15], Step [500/600], Loss: 0.3713
Epoch [1/15], Step [600/600], Loss: 0.1365
Epoch [2/15], Step [100/600], Loss: 0.2165
Epoch [2/15], Step [200/600], Loss: 0.1633
Epoch [2/15], Step [300/600], Loss: 0.0998
Epoch [2/15], Step [400/600], Loss: 0.1761
Epoch [2/15], Step [500/600], Loss: 0.1617
Epoch [2/15], Step [600/600], Loss: 0.0589
Epoch [3/15], Step [100/600], Loss: 0.1698
Epoch [3/15], Step [200/600], Loss: 0.0258
Epoch [3/15], Step [300/600], Loss: 0.1525
Epoch [3/15], Step [400/600], Loss: 0.0464
Epoch [3/15], Step [500/600], Loss: 0.1399
Epoch [3/15], Step [600/600], Loss: 0.1295
Epoch [4/15], Step [100/600], Loss: 0.0571
Epoch [4/15], Step [200/600], Loss: 0.0932
Epoch [4/15], Step [300/600], Loss: 0.1227
Epoch [4/15], Step [400/600], Loss: 0.0579
Epoch [4/15], Step [500/600], Loss: 0.1779
Epoch [4/15

In [8]:
# Evaluate model accuracy on test after training
# Set model in eval mode! (disables dropout so predictions are deterministic)
fc_net.eval()
# Evaluate
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        # Get images and labels from test loader
        images = images.to(device)
        labels = labels.to(device)
        # Forward pass; the predicted class is the index of the largest
        # output (argmax returns the first index on ties, like torch.max).
        # No `.data` needed: gradients are already disabled by no_grad().
        outputs = fc_net(images)
        predicted = outputs.argmax(dim=1)
        # Check if predicted class matches label
        # and count number of correct predictions
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
# Compute final accuracy and display
accuracy = correct/total
print(f'Evaluation after training, test accuracy: {accuracy:.4f}')

Evaluation after training, test accuracy: 0.9799


In [9]:
# Writes linear0.pt, linear1.pt and linear2.pt (one file per layer).
fc_net.save_parameters(path="linear")