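## batch_size tests

Trains a two-layer spiking network (snnTorch `Synaptic` neurons) on MNIST for one epoch with `drop_last=False`, checking that the network copes with a final partial batch (96 samples here) that is smaller than the configured `batch_size`.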

In [1]:
import snntorch as snn
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import itertools
import matplotlib.pyplot as plt

In [2]:
# Network Architecture
num_inputs = 28*28      # flattened 28x28 image
num_hidden = 1000
num_outputs = 10        # one output neuron per digit class

# Training Parameters
batch_size = 128
data_path = '/data/mnist'

# Temporal Dynamics
num_steps = 25          # simulation steps per sample
alpha = 0.7             # passed to snn.Synaptic as its synaptic-current decay rate
beta = 0.8              # passed to snn.Synaptic as its membrane decay rate

# Reproducibility: seed torch's global RNG, which drives nn.Linear weight
# initialisation and DataLoader shuffling (shuffle=True, no explicit generator).
seed = 42
torch.manual_seed(seed)

dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Preprocessing applied to every image. Resize((28, 28)) and Grayscale() pin
# the input to 28x28 single-channel; ToTensor maps pixels to [0, 1], and
# Normalize with mean 0 / std 1 leaves those values unchanged.
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ]
)

# Train and test splits, downloaded into data_path on first use.
mnist_train, mnist_test = (
    datasets.MNIST(root=data_path, train=is_train_split, download=True, transform=transform)
    for is_train_split in (True, False)
)

In [4]:
# Both loaders share the same settings. drop_last=False keeps the final partial
# batch (96 samples for the training set here), which is what this notebook tests.
loader_kwargs = dict(batch_size=batch_size, shuffle=True, drop_last=False)
train_loader = DataLoader(mnist_train, **loader_kwargs)
test_loader = DataLoader(mnist_test, **loader_kwargs)

In [5]:
# Define Network
class Net(nn.Module):
    """Two-layer fully connected spiking network built from snnTorch ``Synaptic`` neurons.

    Layer widths, the number of simulation steps and the ``alpha``/``beta``
    constants are read from the notebook-level configuration
    (``num_inputs``, ``num_hidden``, ``num_outputs``, ``num_steps``, ``alpha``, ``beta``).
    """

    def __init__(self):
        super().__init__()

        # Initialize layers: linear projection -> synaptic spiking neurons, twice.
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Synaptic(alpha=alpha, beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Synaptic(alpha=alpha, beta=beta)

    def forward(self, x):
        """Simulate the network for ``num_steps`` steps on a static input.

        Args:
            x: input batch. Callers in this notebook pass ``data.view(batch, -1)``,
                i.e. a (batch, num_inputs) tensor.

        Returns:
            ``(spk2_rec, mem2_rec)``: output-layer spikes and membrane potentials
            stacked along a new leading time dimension,
            shape (num_steps, batch, num_outputs).
        """
        # Hidden states at t=0. Only the layer width is passed, not the batch size.
        # NOTE(review): this relies on the returned states broadcasting against
        # any batch size, which is what the variable-batch test exercises.
        syn1, mem1 = self.lif1.init_synaptic(num_hidden)
        syn2, mem2 = self.lif2.init_synaptic(num_outputs)

        # The input is the same at every time step, so its projection through fc1
        # is loop-invariant: compute it once rather than num_steps times.
        cur1 = self.fc1(x)

        # Record the final layer at every step.
        spk2_rec = []
        mem2_rec = []

        for _ in range(num_steps):
            spk1, syn1, mem1 = self.lif1(cur1, syn1, mem1)
            cur2 = self.fc2(spk1)
            spk2, syn2, mem2 = self.lif2(cur2, syn2, mem2)

            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

In [6]:
# Instantiate the spiking network and move its parameters to the selected device.
net = Net()
net = net.to(device)

`Stein` has been deprecated and will be removed in a future version. Use `Synaptic` instead.
`Stein` has been deprecated and will be removed in a future version. Use `Synaptic` instead.


In [7]:
def print_batch_accuracy(data, targets, train=False):
    """Print the classification accuracy of the global ``net`` on one batch.

    The predicted class for each sample is the output neuron with the largest
    spike count summed over all time steps.

    Args:
        data: input batch; flattened to (batch, -1) before the forward pass.
        targets: ground-truth class indices, one per sample.
        train: label the printout as the train set (True) or test set (False).
    """
    # Flatten using the batch's own size. The global `batch_size` is wrong for a
    # partial final batch, and can even reshape silently into the wrong number of
    # rows (96 * 784 is divisible by 128).
    flat = data.view(data.size(0), -1)

    # Accuracy is a diagnostic only; no autograd graph is needed.
    with torch.no_grad():
        output, _ = net(flat)
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    if train:
        print(f"Train Set Accuracy: {acc}")
    else:
        print(f"Test Set Accuracy: {acc}")

def train_printer():
    """Print a progress report for the current minibatch.

    Reads notebook-level state set by the training loop: ``epoch``,
    ``minibatch_counter``, ``counter``, ``loss_hist``, ``test_loss_hist`` and the
    latest train/test batches (``data_it``/``targets_it``,
    ``testdata_it``/``testtargets_it``). Ends with a blank-line separator.
    """
    report_lines = (
        f"Epoch {epoch}, Minibatch {minibatch_counter}",
        f"Train Set Loss: {loss_hist[counter]}",
        f"Test Set Loss: {test_loss_hist[counter]}",
    )
    for line in report_lines:
        print(line)

    # Per-batch accuracy on the most recent train batch, then the test batch.
    print_batch_accuracy(data_it, targets_it, train=True)
    print_batch_accuracy(testdata_it, testtargets_it, train=False)

    print("\n")

In [8]:
# Loss: log-softmax over the output neurons (last dim), scored with negative log-likelihood.
log_softmax_fn = nn.LogSoftmax(dim=-1)
loss_fn = nn.NLLLoss()

# Adam over every trainable parameter of the network.
optimizer = torch.optim.Adam(params=net.parameters(), lr=2e-4, betas=(0.9, 0.999))

In [9]:
loss_hist = []
test_loss_hist = []
counter = 0

# Outer training loop
for epoch in range(1):
    minibatch_counter = 0

    # Minibatch training loop. With drop_last=False the final batch can be smaller
    # than `batch_size` (96 samples here), so every reshape uses the batch's own
    # size instead of overwriting the global config value.
    for data_it, targets_it in train_loader:
        data_it = data_it.to(device)
        targets_it = targets_it.to(device)

        spk_rec, mem_rec = net(data_it.view(data_it.size(0), -1))
        log_p_y = log_softmax_fn(mem_rec)

        # Sum loss over time steps: BPTT
        loss_val = torch.zeros(1, dtype=dtype, device=device)
        for step in range(num_steps):
            loss_val += loss_fn(log_p_y[step], targets_it)

        # Gradient calculation and weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set: take the first batch of a fresh pass over test_loader
        # (shuffle=True, so this is a new random batch every minibatch).
        testdata_it, testtargets_it = next(iter(test_loader))
        testdata_it = testdata_it.to(device)
        testtargets_it = testtargets_it.to(device)

        # Evaluation only: skip building an autograd graph.
        with torch.no_grad():
            test_spk, test_mem = net(testdata_it.view(testdata_it.size(0), -1))

            # Summing log-probs over time before the (mean-reduced) NLLLoss equals
            # summing the per-step losses, so this matches the training loss.
            log_p_ytest = log_softmax_fn(test_mem).sum(dim=0)
            loss_val_test = loss_fn(log_p_ytest, testtargets_it)
        test_loss_hist.append(loss_val_test.item())

        # Print test/train loss/accuracy
        if counter % 50 == 0:
            train_printer()
        minibatch_counter += 1
        counter += 1

loss_hist_true_grad = loss_hist
test_loss_hist_true_grad = test_loss_hist

Epoch 0, Minibatch 0
Train Set Loss: 76.86495971679688
Test Set Loss: 56.226741790771484
Train Set Accuracy: 0.1875
Test Set Accuracy: 0.1796875


Epoch 0, Minibatch 50
Train Set Loss: 17.487354278564453
Test Set Loss: 14.834465980529785
Train Set Accuracy: 0.890625
Test Set Accuracy: 0.9140625


Epoch 0, Minibatch 100
Train Set Loss: 13.294114112854004
Test Set Loss: 15.876989364624023
Train Set Accuracy: 0.9375
Test Set Accuracy: 0.890625


Epoch 0, Minibatch 150
Train Set Loss: 14.225533485412598
Test Set Loss: 12.264030456542969
Train Set Accuracy: 0.875
Test Set Accuracy: 0.8984375


Epoch 0, Minibatch 200
Train Set Loss: 16.187170028686523
Test Set Loss: 14.389913558959961
Train Set Accuracy: 0.875
Test Set Accuracy: 0.890625


Epoch 0, Minibatch 250
Train Set Loss: 12.327922821044922
Test Set Loss: 12.903144836425781
Train Set Accuracy: 0.9140625
Test Set Accuracy: 0.921875


Epoch 0, Minibatch 300
Train Set Loss: 9.990253448486328
Test Set Loss: 10.313959121704102
Train Set Acc

In [10]:
# Output spikes of the last training minibatch, (num_steps, batch, num_outputs):
# the final batch holds only 96 samples.
spk_rec.shape

torch.Size([25, 96, 10])

In [11]:
# Output spikes of the last test batch, (num_steps, batch, num_outputs).
test_spk.shape

torch.Size([25, 128, 10])