In [1]:
import os
import re
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation
import torch

from model_spytorch import SNNModel

Matplotlib created a temporary cache directory at /tmp/matplotlib-e__unvz_ because the default path (/home/jupyter-ikharitonov/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [2]:
# Load the MNIST frame sequences and flatten each frame into a vector:
# (n_sequences, n_frames, H, W) -> (n_sequences, n_frames, H*W).
SHUFFLE_SEED = 0  # fixed seed so the row order (and therefore data[:100]) is reproducible

sequences_raw = np.load(Path.home() / 'RANCZLAB-NAS/iakov/data/mnist_sequences.npy')
data = torch.as_tensor(
    sequences_raw.reshape((sequences_raw.shape[0], sequences_raw.shape[1],
                           sequences_raw.shape[2] * sequences_raw.shape[3])),
    dtype=torch.float32,
)
del sequences_raw  # the float tensor above is a copy; drop the raw array

# preprocessing
shuffle_generator = torch.Generator().manual_seed(SHUFFLE_SEED)
data = data[torch.randperm(data.shape[0], generator=shuffle_generator)]  # row shuffling
data /= 255  # 0-1 normalisation (assumes raw pixel values lie in [0, 255] -- TODO confirm dtype)
data *= 0.5  # scaling down by 1/2, so inputs lie in [0, 0.5]

In [3]:
# Build the spiking network with one hidden unit per flattened pixel and one
# simulation step per frame; the step length makes a whole sequence span 1 time unit.
model = SNNModel(
    batch_size=100,
    hidden_units=data.shape[2],
    num_timesteps=data.shape[1],
    step_length=1 / data.shape[1],
    device=torch.device('cuda'),
    dtype=torch.float,
)
model.init_parameters()

In [4]:
# Load the v1 weight matrix saved after the most recent training epoch.
save_weights_path = Path.home() / 'RANCZLAB-NAS/iakov/v1_weights_22_jan'

# Locate checkpoints by parsing epoch numbers from the filenames instead of counting
# directory entries: a bare len(os.listdir(...)) picks the wrong (or a non-existent)
# file as soon as the folder holds anything else, e.g. `.ipynb_checkpoints`, or an
# epoch file is missing.
checkpoints_by_epoch = {}
for checkpoint_file in save_weights_path.glob('epoch*_v1_matrix.npy'):
    epoch_match = re.fullmatch(r'epoch(\d+)_v1_matrix\.npy', checkpoint_file.name)
    if epoch_match:
        checkpoints_by_epoch[int(epoch_match.group(1))] = checkpoint_file
if not checkpoints_by_epoch:
    raise FileNotFoundError(f'No epoch*_v1_matrix.npy files found in {save_weights_path}')

last_epoch = max(checkpoints_by_epoch)
epochs_so_far = last_epoch + 1  # number of epochs saved so far (epochs are numbered from 0)
np_weights = np.load(checkpoints_by_epoch[last_epoch])

# Convert straight to the model's dtype: the legacy torch.Tensor(...) constructor always
# produces float32 first, which would lose precision whenever model.dtype is wider.
# NOTE(review): an earlier version called model.load_weights(np_weights) instead.
model.v1 = torch.as_tensor(np_weights, dtype=model.dtype).cuda()

In [1]:
# Visualise the loaded v1 weight matrix. 'seismic' is a diverging colormap, so the colour
# limits are made symmetric about zero: white then marks 0, red positive, blue negative.
v1_weights = model.v1.detach().cpu().numpy()
v1_abs_max = float(np.abs(v1_weights).max()) or 1.0  # avoid a degenerate [0, 0] range for an all-zero matrix

v1_fig, v1_ax = plt.subplots(figsize=(15, 15))
v1_image = v1_ax.imshow(v1_weights, cmap='seismic', vmin=-v1_abs_max, vmax=v1_abs_max)
v1_ax.set_title('model.v1 weights')
v1_fig.colorbar(v1_image, ax=v1_ax)
plt.show()

NameError: name 'plt' is not defined

In [6]:
# Simulate the first 100 (shuffled) sequences, one full batch for batch_size=100,
# recording membrane voltages (mem_rec) and spikes (spk_rec) for the animation below.
mem_rec, spk_rec = model.run_snn(data[0:100].to('cuda'))

In [7]:
# Animate one example sequence: input frame, membrane voltage and unit spiking side by side.
# Rendering the jshtml animation still takes a while for a long sequence.
plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 150
plt.ioff()  # suppress the static figure; only the animation is displayed

FRAME_SHAPE = (28, 28)  # each flattened frame is reshaped back to a 28x28 image
ANIMATED_SAMPLE = 0     # index of the sequence (within the batch) to animate
n_frames = data.shape[1]


def sample_frames(x, sample=ANIMATED_SAMPLE):
    """Return sequence `sample` of `x` as a NumPy array of shape (n_frames, 28, 28).

    Called once up front rather than inside `animate`, so the full tensors are not
    reshaped and copied off the GPU on every frame; only the chosen sequence is copied.
    """
    return x.detach().reshape((-1, n_frames) + FRAME_SHAPE)[sample].cpu().numpy()


panels = [
    ('input', sample_frames(data)),
    ('membrane voltage', sample_frames(mem_rec)),
    ('unit spiking', sample_frames(spk_rec)),
]

fig, ax = plt.subplots(nrows=1, ncols=len(panels))
# One image per panel, updated in place each frame. Calling imshow every frame would stack
# a new image on each axes, and plt.cla() only clears the current axes (ax[-1]).
# Colour limits are fixed over the whole sequence so frames are directly comparable.
images = [
    axis.imshow(frames[0], vmin=frames.min(), vmax=frames.max())
    for axis, (_, frames) in zip(ax, panels)
]


def animate(t):
    """Show frame `t` in every panel and return the updated images."""
    for axis, image, (label, frames) in zip(ax, images, panels):
        axis.set_title(f'Frame {t}: {label}', fontsize=5)
        image.set_data(frames[t])
    return images


matplotlib.animation.FuncAnimation(fig, animate, frames=n_frames)