In [2]:
import torch
import torch.nn as nn
from spikingjelly.datasets.n_mnist import NMNIST
from spikingjelly.datasets import play_frame
from spikingjelly.activation_based import neuron, layer, learning
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from IPython.display import display, clear_output
from time import sleep
from PIL import Image
import mnist
import numpy as np

In [3]:
FILE_LIST = "MNIST/raw/train-images-idx3-ubyte MNIST/raw/train-labels-idx1-ubyte MNIST/raw/t10k-images-idx3-ubyte MNIST/raw/t10k-labels-idx1-ubyte".split()

In [4]:
arrays:list[np.ndarray] = []

for fname in FILE_LIST:
    with open(fname,'rb') as f:
        arrays.append(mnist.parse_idx(f))
assert len(arrays) == 4
train_imgs, train_labels, test_imgs, test_labels = arrays

In [5]:
train_set = NMNIST("data/",train=True, data_type="frame", frames_number=64, split_by="number")
# test_set = NMNIST("data/",train=False, data_type="frame", frames_number=64, split_by="number")

The directory [data/frames_number_64_split_by_number] already exists.


In [7]:
device = ["cpu", "cuda"][torch.cuda.is_available()]

def f_weight(x):
    return torch.clamp(x, -1, 1.)

torch.manual_seed(0)
# plt.style.use(['science'])

def f_pre(x, w_min, alpha=0.):
    return (x - w_min) ** alpha

def f_post(x, w_max, alpha=0.):
    return (w_max - x) ** alpha

w_min, w_max = -1., 1.
tau_pre, tau_post = 2., 2.
N_in, N_out = 2*34*34, 10
T = 64
batch_size = 2
lr = 0.01

def gen_block(in_features:int, out_features:int, **kwargs):
    return nn.Sequential(
        layer.Linear(in_features, out_features, **kwargs),
        neuron.LIFNode()
    )

net = nn.Sequential(
    gen_block(N_in, N_out)
).to(device)
# nn.init.trunc_normal_(net[0].weight.data, 0.4, 1, -1, 0.2)
### TODO: Implement STDP for multilayer.
nn.init.constant_(net[0][0].weight.data, 0.4)
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.)

# in_spike = (torch.rand([T, batch_size, N_in]) > 0.7).float()
learner = learning.STDPLearner(step_mode='s', synapse=net[0], sn=net[1], 
                                tau_pre=tau_pre, tau_post=tau_post,
                                f_pre=f_weight, f_post=f_weight)

out_spike = []
trace_pre = []
trace_post = []
# weight = []

# Plot implementation
fig = plt.figure(figsize=(15, 3))
fig.suptitle(f"0-th iteration")
axes = []
canvas = torch.zeros((2*34, 34*10))
weight = net[0].weight.cpu().detach()
for i_neuron in range(10):
    canvas[:, 34*i_neuron:34*(i_neuron+1)] = weight[i_neuron].reshape(34*2, 34)
ax = fig.add_subplot(1,1,1)
pc = ax.pcolor(canvas)
div = make_axes_locatable(ax)
cax = div.append_axes("right", size="5%", pad=0.05)
fig.colorbar(pc, cax)
fig.tight_layout()
display(fig)
clear_output(wait=True)

with torch.no_grad():
    frame = torch.zeros((2*34, 34), requires_grad=False, device=device) # *2 for ON, OFF events
    for i, (frames, label) in enumerate(train_set, start=1):
        for t in range(T):
            optimizer.zero_grad()
            # out_spike.append(net(in_spike[t]))
            frame.fill_(0)
            frame[::2,:] = torch.tensor(frames[t,0])
            frame[1::2,:] = torch.tensor(frames[t,1])
            net(frame.flatten()[None,...]) # No out spike append
            # out_spike.append(net(frame.flatten()[None,...]))
            
            learner.step(on_grad=True)
            optimizer.step()
            net[0].weight.data.clamp_(w_min, w_max)
            # weight.append(net[0].weight.data.clone())
            trace_pre.append(learner.trace_pre)
            trace_post.append(learner.trace_post)
        
        if i%10 == 0:
            fig.suptitle(f"{i}-th iteration")
            weight = net[0].weight.cpu().detach()
            for i_neuron in range(10):
                canvas[:,34*i_neuron:34*(i_neuron+1)] = weight[i_neuron].reshape(34*2, 34)
            pc = ax.pcolor(canvas)
            
            # ax.clear()
            # ax.set_title(f"{i}-th iteration")
            # ax.pcolor(net[0].weight.cpu().detach())
            fig.colorbar(pc, cax)
            display(fig)
            clear_output(wait=True)
        

out_spike = torch.stack(out_spike)   # [T, batch_size, N_out]
trace_pre = torch.stack(trace_pre)   # [T, batch_size, N_in]
trace_post = torch.stack(trace_post) # [T, batch_size, N_out]
weight = torch.stack(weight)         # [T, N_out, N_in]

IndexError: index 1 is out of range

In [None]:
t = torch.arange(0, T).float()

# in_spike = in_spike[:, 0, 0] # No in_spike generation
out_spike = out_spike[:, 0, 0]
trace_pre = trace_pre[:, 0, 0]
trace_post = trace_post[:, 0, 0]
weight = weight[:, 0, 0]

cmap = plt.get_cmap('tab10')

# No in_spike generation
# plt.subplot(5, 1, 1)
# plt.eventplot((in_spike * t)[in_spike == 1], lineoffsets=0, colors=cmap(0))
# plt.xlim(-0.5, T + 0.5)
# plt.ylabel('$s[i]$', rotation=0, labelpad=10)
# plt.xticks([])
# plt.yticks([])

plt.subplot(5, 1, 2)
plt.plot(t, trace_pre, c=cmap(1))
plt.xlim(-0.5, T + 0.5)
plt.ylabel('$tr_{pre}$', rotation=0)
plt.yticks([trace_pre.min().item(), trace_pre.max().item()])
plt.xticks([])

plt.subplot(5, 1, 3)
plt.eventplot((out_spike * t)[out_spike == 1], lineoffsets=0, colors=cmap(2))
plt.xlim(-0.5, T + 0.5)
plt.ylabel('$s[j]$', rotation=0, labelpad=10)
plt.xticks([])
plt.yticks([])

plt.subplot(5, 1, 4)
plt.plot(t, trace_post, c=cmap(3))
plt.ylabel('$tr_{post}$', rotation=0)
plt.yticks([trace_post.min().item(), trace_post.max().item()])
plt.xlim(-0.5, T + 0.5)
plt.xticks([])

plt.subplot(5, 1, 5)
plt.plot(t, weight, c=cmap(4))
plt.xlim(-0.5, T + 0.5)
plt.ylabel('$w[i][j]$', rotation=0)
plt.yticks([weight.min().item(), weight.max().item()])
plt.xlabel('time-step')

plt.gcf().subplots_adjust(left=0.18)

plt.show()
plt.savefig('./stdp_trace.pdf')