In [11]:
import sys

import matplotlib.pyplot as plt
import numpy as np
import snntorch
import snntorch.functional as SF
import torch
from matplotlib import cm
from sklearn.metrics import confusion_matrix
from snntorch import surrogate

sys.path.append('../')
from src import data
from src import plot_params
from src.spiking_neural_network import Snn

In [12]:
def save_checkpoint(checkpoint: dict, database: str, n: int, save=False):
    if save == True:
        torch.save(
            checkpoint, f'../data/trained_model/{n}_{n}_{database}_snn.pth')

In [13]:
database = 'mnist'
batch_size = 128

(train_set, test_set,
 train_loader, test_loader,
 device) = data.set_loader_device(database, batch_size)

In [14]:
n = 300
num_units_layers = [784, n, n, 10]
spk_neuron = snntorch.Leaky(beta=0.5,
                            threshold=1,
                            reset_mechanism='subtract',
                            spike_grad=surrogate.fast_sigmoid(slope=25))

model = Snn(spk_neuron, num_units_layers).to(device=device)
loss = SF.ce_count_loss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
epochs = 1
train_loss, train_acc, test_loss, test_acc, w0, w1, w2 = model.train(
    loss, optimizer, train_loader, test_loader, epochs, batch_size, device)

# Evaluation.
eval = model.evaluation(test_loader, batch_size, device)
print('Evaluation:', eval)

# Confusion matrix.
true_label, pred_label = model.predictions(test_loader, batch_size, device)
cm = confusion_matrix(true_label, pred_label)

# Save the model.
checkpoint = {'model_state_dict': model.state_dict()}

Epoch: 1 batch_idx: 0
Epoch: 1 batch_idx: 20
Epoch: 1 batch_idx: 40
Epoch: 1 batch_idx: 60
Epoch: 1 batch_idx: 80
Epoch: 1 batch_idx: 100
Epoch: 1 batch_idx: 120
Epoch: 1 batch_idx: 140
Epoch: 1 batch_idx: 160
Epoch: 1 batch_idx: 180
Epoch: 1 batch_idx: 200
Epoch: 1 batch_idx: 220
Epoch: 1 batch_idx: 240
Epoch: 1 batch_idx: 260
Epoch: 1 batch_idx: 280
Epoch: 1 batch_idx: 300
Epoch: 1 batch_idx: 320
Epoch: 1 batch_idx: 340
Epoch: 1 batch_idx: 360
Epoch: 1 batch_idx: 380
Epoch: 1 batch_idx: 400
Epoch: 1 batch_idx: 420
Epoch: 1 batch_idx: 440
Epoch: 1 batch_idx: 460
Evaluation: 0.9402043269230769


In [5]:
import random

In [6]:
random.shuffle(checkpoint['model_state_dict']['layers.0.weight'])
random.shuffle(checkpoint['model_state_dict']['layers.1.weight'])
random.shuffle(checkpoint['model_state_dict']['layers.2.weight'])

In [7]:
eval = model.evaluation(test_loader, batch_size, device)
print('Evaluation:', eval)

Evaluation: 0.09815705128205128


In [8]:
true_label, pred_label = model.predictions(test_loader, batch_size, device)
cm = confusion_matrix(true_label, pred_label)

In [15]:
colors = ['navy', '#76b7b2', '#e15759']
#colors = ['skyblue', 'yellowgreen', 'lightsalmon']
colors_2 = ['navy', 'teal', 'crimson']
colors_2 = colors
training_time = np.arange(468)

w0 = w0.squeeze()
w1 = w1.squeeze()
w2 = w2.squeeze()

In [21]:
fig, ax = plt.subplots(ncols=2, nrows=1, figsize=(12,4), constrained_layout=False)

ax[0].plot(training_time[:31], train_loss[:31], ':', color=colors[0])
#ax[0].plot(training_time[:30], test_loss[:30], color=colors[0])

ax[0].plot(training_time[30:200], train_loss[30:200], ':', color=colors[1])
#ax[0].plot(training_time[30:200], test_loss[30:200], color=colors[1])

ax[0].plot(training_time[200:], train_loss[200:], ':',color=colors[2], label='Train set loss')
#ax[0].plot(training_time[200:], test_loss[200:], color=colors[2], label='Test Set')

ax[0].spines[['right', 'top']].set_visible(False)
#ax[0].set_xticklabels([])
ax[0].set_ylabel('Loss')
ax[0].set_xlabel('Training Time')

ax[0].legend(bbox_to_anchor=(0.6, 0.25))
ax[0].set_title('(A)')

ax_twinx = ax[0].twinx()

alpha = 0.5
#ax_twinx.plot(training_time[:30], train_acc[:30], color=colors[0])
ax_twinx.plot(training_time[:31], test_acc[:31], color=colors_2[0])

#ax_twinx.plot(training_time[30:200], train_acc[30:200], color=colors[1])
ax_twinx.plot(training_time[30:200], test_acc[30:200], color=colors_2[1])

#ax_twinx.plot(training_time[200:], train_acc[200:],color=colors[2], label='Train Set')
ax_twinx.plot(training_time[200:], test_acc[200:], color=colors_2[2], label='Test set accuracy')


ax_twinx.set_ylabel('Accuracy')
ax_twinx.legend(bbox_to_anchor=(1, 0.6))


place = [-0.8, 0.6, 0.45]

axin1 = ax[0].inset_axes([0.1] +place)

axin1.set_ylabel('Count')
range_hist = (-0.1, 0.1)
axin1.hist(w0[0], bins=50, range= range_hist, histtype='stepfilled', color=colors_2[0], alpha=0.8, label=' ')
axin1.hist(w0[80], bins=50, range= range_hist, histtype='stepfilled', color=colors_2[1], alpha=0.65, label=' ')
axin1.hist(w0[467], bins=50, range= range_hist, histtype='stepfilled', color=colors_2[2], alpha=0.4, label=' ')

#axin1.legend(title=r'$\boldsymbol{w}^{(1)}$')
axin1.set_xlabel(r'$w_{ij}^{(1)}$')
axin1.legend(title='Learning phases',loc='upper center', ncols=3 , bbox_to_anchor=(0., 1.55))
axin1.set_title('(C)')


axin2 = ax[0].inset_axes([0.8] +place)

axin2.hist(w1[0], bins=50, range= range_hist, histtype='stepfilled', color=colors_2[0], alpha=0.8)
axin2.hist(w1[80], bins=50, range= range_hist, histtype='stepfilled', color=colors_2[1], alpha=0.65)
axin2.hist(w1[467], bins=50, range= range_hist, histtype='stepfilled', color=colors_2[2], alpha=0.4)

#axin2.legend(title=r'$\boldsymbol{w}^{(1)}$')
axin2.set_xlabel(r'$w_{ij}^{(2)}$')

axin2.set_title('(D)')


axin3 = ax[0].inset_axes([1.5]+ place)

axin3.hist(w2[0], bins=50, range= range_hist, histtype='stepfilled', color=colors_2[0], alpha=0.8)
axin3.hist(w2[80], bins=50, range= range_hist, histtype='stepfilled', color=colors_2[1], alpha=0.65)
axin3.hist(w2[467], bins=50, range= range_hist, histtype='stepfilled', color=colors_2[2], alpha=0.4)

#axin3.legend(title=r'$\boldsymbol{w}^{(1)}$')
axin3.set_xlabel(r'$w_{ij}^{(3)}$')

axin3.set_title('(E)')


ax3 = ax[1]
img = ax3.imshow(cm, cmap=plt.cm.Blues)

label = np.arange(10)
ax3.set_xticks(label)
ax3.set_yticks(label)

ax3.set_xlabel('Predicted Label')
ax3.set_ylabel('True Label')
ax3.grid(False)
ax3.set_title('(B)')

for i in range(label.size):
    for j in range(label.size):
        if cm[i,j] < 670:
            text = ax3.text(j, i, cm[i, j], ha="center", va="center", color="k", fontsize=12)
        else:
            text = ax3.text(j, i, cm[i, j], ha="center", va="center", color="w", fontsize=10)

#fig.tight_layout()
cb = fig.colorbar(img)

fig.savefig('../plots/fit_mnist.pdf')

Error in callback <function _draw_all_if_interactive at 0x10bff4360> (for post_execute), with arguments args (),kwargs {}:


ZeroDivisionError: float division by zero

ZeroDivisionError: float division by zero

<Figure size 1200x400 with 4 Axes>

In [44]:
save_checkpoint(checkpoint, database, n, save=True)