In [3]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import optim
from torch.autograd import Function
from torch.nn import functional as F
import math

In [4]:
def spike_fn(U, threshold):
    return (U >= threshold).float()


#input X spikes is determined by the probability,
#higher the pixel tensity, higher the probability there is a spike.
def poisson_encode(x, T):
    x = x.unsqueeze(0).repeat(T, 1, 1)  # (T, B, features)
    spikes = (torch.rand_like(x) < x).float()
    return spikes

In [3]:
class LIFLayer(nn.Module):
    def __init__(self, in_features, out_features, tau_trace = 20.0,
                 tau_mem=20.0, tau_syn=5.0, dt=1.0, threshold=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.threshold = threshold

        # use Xavier initialization, because otherwise some neurons were not firing at all, it was unstable.
        std = 1.0 / (in_features**0.5)
        self.W = nn.Parameter(torch.randn(out_features, in_features) * std)


        self.dt = dt #just a theoretical constant

        #Use register_buffer, because it must be linked to the nn.Module
        self.register_buffer('alpha', torch.exp(torch.tensor(-dt / tau_syn))) #Voltage decay
        self.register_buffer('beta', torch.exp(torch.tensor(-dt / tau_mem))) #Current decay
        self.register_buffer('gamma', torch.exp(torch.tensor(-dt / tau_trace)))  #Trace decay


    #input spikes -> (T, B, in_dim)
    #returns -> (T, B, out_dim) as spikes, and (T, B, out_dim) as voltage values
    def forward(self, input_spikes):

        T, B, _ = input_spikes.shape
        device = input_spikes.device

        #current, and voltage must start from 0 in the beginning
        I = torch.zeros(B, self.out_features, device=device)
        U = torch.zeros(B, self.out_features, device=device)

        #keep all spike and voltage values across the time steps.
        out_spikes = []
        U_hist = []

        for t in range(T):
            x_t = input_spikes[t]  # (B, in_features), X value of the current time step.

            #new current
            I = self.alpha * I + x_t @ self.W.t()  # (B, out_features)

            #new membrane voltage
            U = self.beta * U + I

            #if U exceeds threshold get one as a spike.
            S = spike_fn(U, self.threshold)  # (B, out_features)
            U_hist.append(U.clone())
            U = U - S * self.threshold #reset U

            out_spikes.append(S)
            

        out_spikes = torch.stack(out_spikes, dim=0)
        U_hist = torch.stack(U_hist, dim=0)
        return out_spikes, U_hist


In [4]:

class RandomBPSNN(nn.Module):
    def __init__(self, T=20, tau_mem=20.0, tau_syn=5.0, dt=1.0, threshold=1.0):
        super().__init__()
        self.T = T
        self.lif_layers = torch.nn.ModuleList()
        self.G_hiddens = torch.nn.ParameterList()


    def append_LIF(self, LIF, out_dim = 10):
        self.lif_layers.append(LIF)
        if (len(self.lif_layers) > 1):
            hidden_dim = self.lif_layers[-2].out_features #set the projection matrix
            #random projection matrix -> (out_dim(10), hidden_dim)
            G_hidden = nn.Parameter(
                torch.randn(out_dim, hidden_dim) * 0.1, requires_grad = False
            )
            self.G_hiddens.append(G_hidden)

    #here if X is a static no temporal data -> X values are encoded using rate-encoding or poisson encoding
    #generally rate-encoding is used because it is more stable however I implemented poisson encoding above as well.
    #if it has a stime steps, no encoding is used.

    #Static X -> (B, in_dim)
    #Time_step_X -> (T, B, in_dim)
    def forward(self, x):
        device = x.device
        static = len(x.shape) == 2
        if (static):
            B, in_dim = x.shape
            x = x.unsqueeze(0).repeat(self.T, 1, 1)  # (T, B, in_dim), rate encoding

        else:
            _, B, in_dim = x.shape
        o_spk = x
        spikes, voltages = [], []
        for lif in self.lif_layers:
            o_spk, o_U = lif(o_spk)
            spikes.append(o_spk)
            voltages.append(o_U)

        #to calculate accuracy, we generally sum spike numbers of all the time steps, and predict the 
        #class which has maximum number of summed spikies.
        out_rate = o_spk.mean(dim=0)  #(B, O)

        return {
            "Us": voltages,
            "spikes": spikes,
            "o_spk": spikes[-1],
            "o_U": voltages[-1],
            "out_rate": out_rate
        }


In [5]:
def random_bp_step(model, x, target, optimizer, super_spike_B = 25, last_learning_window = 5):
    """
    model -> all_layers(RandomBPSNN)
    x -> (B, in_dim) or (T, B, in_dim) if it is (B, in_dim) timesteps are produced using rate-encoding
    target -> (B, out_dim)
    loss_fn -> Cross entropy derivative is used here -> |Y - y_predicted|
    last_learning_window -> last last_learnin_window number is used for determining loss and derivatives.
    increase it -> it pay attention to more time steps, since the loss signal is stronger, the latency is shorter
    since it might wait to learn
    decreasing it -> network only pays attention to last spikes, the latency is longer because it waits for the last spikes
    """    

    optimizer.zero_grad()
    
    out = model(x)
    out_spikes = out["o_spk"] # (T, B, 10)
    T, B, num_classes = out_spikes.shape
    
    # DEBUG: Spike rate kontrol
    #print(f"Spike rate: {out_spikes.mean().item():.4f}")
    #print(f"Output spike sum per class: {out_spikes.sum(dim=0).mean(dim=0)}")
    #apply one_hot encoding if it hasnt applied yet.
    if (len(target.shape) == 1):
        target = F.one_hot(target, num_classes=num_classes).float()  # (B, 10)


    #Surrogate derivative -> SuperSpike -> 1 / (1 + B*|U - threshold|)^2
    #calculate surrogate gradients
    o_U = out["o_U"]  #(T, B, O)
    T = o_U.shape[0]

    #dL/dy^(L) = (out_rate - target)
    learning_window = out_spikes[T-last_learning_window:].mean(dim = 0, keepdim = False) #(B, O)
    dL_dyL = learning_window - target  # (B, O)
    #probs = torch.softmax(learning_window) #Calculate softmax(probabilities)
    loss = F.cross_entropy(learning_window, target) #calculate cross_entropy loss function
    #print(f"Learning window min: {learning_window.min().item():.4f}")
    #print(f"Learning window max: {learning_window.max().item():.4f}")
    #print(f"Loss: {loss.item():.4f}")
    super_spike_B = super_spike_B
    du = (o_U - model.lif_layers[-1].threshold).abs() #(T, B, O)
    sigma_prime_per_t = 1 / ((1 + super_spike_B * du)**2) # (T, B, O)
    
    #(B, O) - average over time
    sigma_prime_out = sigma_prime_per_t[T-last_learning_window:, :, :].mean(dim=0, keepdim = False)
    #take derivative for the output layer:
    input_of_out_layer = out["spikes"][-2] # (T, B, hidden_units)
    learning_window_hiddens = input_of_out_layer[T-last_learning_window:].mean(dim = 0, keepdim = False)#(B, hidden_units)

    delta_out = dL_dyL * sigma_prime_out  # (B, O)

    dw_output = torch.permute(delta_out, (1, 0)) @ learning_window_hiddens #(O, hidden_units)
    dw_output = dw_output / B #(O, hidden_units)
    model.lif_layers[-1].W.grad = dw_output #set the derivative
 
    voltages = out["Us"]
    spikes = out["spikes"]
    for i in range(len(voltages) - 1):
        voltage = voltages[i] # (T, B, H)
        
        layer_threshold = model.lif_layers[i].threshold
        du = (voltage[T-last_learning_window:, :, :] - layer_threshold).abs()
        sigma_prime_per_t = 1 / ((1 + super_spike_B * du)**2) # (T, B, O)
        sigma_prime_out = sigma_prime_per_t[T-last_learning_window:, :, :].mean(dim=0, keepdim = False) #(B, hidden_dim)
        if (i == 0):
            input_spikes = x[T-last_learning_window:, :, :].mean(dim = 0, keepdim = False)#(B, hidden_dim)
        else:
            spike = spikes[i - 1] # (T, B, H)
            input_spikes = spike[T-last_learning_window:, :, :].mean(dim = 0, keepdim = False)#(B, hidden_dim)
        #calculate loss on the projected vector (for it being biologically plausible, it is more localized)
        random_matrix = model.G_hiddens[i] # (output_dim, hidden_dim)
        delta_hidden = sigma_prime_out*(delta_out @ random_matrix) # (B, O) @ (O, hidden_dim) -> (B, hidden_dim)
        dw = torch.zeros_like(model.lif_layers[i].W)

        dw += torch.permute(delta_hidden, (1, 0)) @ input_spikes #(O, B) @ (B, input_dim) -> (O, input_dim)
        dw = dw/B #batch mean
        #set gradients
        model.lif_layers[i].W.grad = dw


    optimizer.step()

    return loss.item()


In [5]:

# MNIST yükleme
transform = transforms.Compose([
transforms.ToTensor()                # (1, 28, 28), [0,1]
])

train_dataset = datasets.MNIST(root="../data", train=True,
                        download=True, transform=transform)
test_dataset = datasets.MNIST(root="../data", train=False,
                        download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=256, shuffle=False)


In [7]:

def main(train_loader, test_loader, model, optimizer, num_epochs, verbose = True, super_spike_beta = 25):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    spike_means_outer = []

    for epoch in range(1, num_epochs + 1):
        model.train()
        total_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.to(device)   # (B,1,28,28)
            target = target.to(device)

            # Flatten: (B, 784)
            data = data.view(data.size(0), -1)
            data = data.unsqueeze(0).repeat(20, 1, 1)  # (T, B, in_dim), rate encoding

            total_loss += random_bp_step(model, data, target, optimizer, last_learning_window=15, super_spike_B=super_spike_beta)

            if (batch_idx + 1) % 100 == 0:
                if (verbose):
                    print(f"Epoch {epoch} | Batch {batch_idx+1}/{len(train_loader)} | "
                        f"Loss: {total_loss / (batch_idx+1):.4f}")
                data, target = next(iter(train_loader))
                data = data.to(device).view(data.size(0), -1)

                out = model(data)
                spike_means = []
                for i in out["spikes"]:
                    spike_means.append(i.mean().item())
                spike_means_outer.append(torch.tensor(spike_means))
                if (verbose):
                    for i in range(len(spike_means)):
                        print(f"layer {i+1} spike means:", spike_means[i])


        # ---- Test accuracy ----
        acc = evaluate(model, test_loader, device)
        if (verbose):
            print(f"Epoch {epoch} finished. Test accuracy: {acc:.2f}%")
    spike_means_outer = torch.stack(spike_means_outer, dim = 0)
    return model, acc, spike_means_outer
def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)
            target = target.to(device)
            data = data.view(data.size(0), -1)

            out = model(data)
            out_rate = out["out_rate"]  # (B, 10)
            pred = out_rate.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    return 100.0 * correct / total



In [8]:

#experiment hidden sizes
#hidden_sizes = [32, 64, 128, 256, 512]
# Model
device = torch.device("cuda")
def get_model(hidden_dims, thresholds = None, super_spike_beta = 25, lr = 1e-3):
    if (thresholds is None):
        thresholds = [1 for _ in range(len(hidden_dims) - 1)]

    model = RandomBPSNN()
    weights = []
    for i in range(1, len(hidden_dims)):
        model.append_LIF(LIFLayer(hidden_dims[i-1], hidden_dims[i], threshold=thresholds[i-1]))
        weights.append(model.lif_layers[-1].W)
    model = model.to(device)
    optimizer = torch.optim.Adam(
        weights, lr=lr
    )
    return model, optimizer

In [83]:
#experiment hidden sizes
#hidden_sizes = [32, 64, 128, 256, 512]
# Model

hidden_sizes = [64, 256]
threshold = 1
epoch = 3
T = 25
acc_results = []
for hidden_size in hidden_sizes:
    model, optimizer = get_model([784, hidden_size, 10], thresholds = [0.3, 0.3])
    model, acc = main(train_loader, test_loader, model, optimizer, num_epochs=epoch, verbose=True, super_spike_beta=1)
    print(f"For hidden size {hidden_size} accuracy is --> ", acc)
    for i, lif in enumerate(model.lif_layers):
        print(f"Layer {i} W mean: {lif.W.mean().item():.6f}")
        print(f"Layer {i} W std: {lif.W.std().item():.6f}")
        if lif.W.grad is not None:
            print(f"Layer {i} grad mean: {lif.W.grad.mean().item():.6f}")
            print(f"Layer {i} grad max: {lif.W.grad.abs().max().item():.6f}")



Using device: cuda
Epoch 1 | Batch 100/938 | Loss: 2.2950
layer 1 spike means: 0.488525390625
layer 2 spike means: 0.010859374888241291
Epoch 1 | Batch 200/938 | Loss: 2.2834
layer 1 spike means: 0.47679445147514343
layer 2 spike means: 0.011015624739229679
Epoch 1 | Batch 300/938 | Loss: 2.2640
layer 1 spike means: 0.4914794862270355
layer 2 spike means: 0.016640624031424522
Epoch 1 | Batch 400/938 | Loss: 2.2532
layer 1 spike means: 0.4742065370082855
layer 2 spike means: 0.014609375037252903
Epoch 1 | Batch 500/938 | Loss: 2.2392
layer 1 spike means: 0.48432618379592896
layer 2 spike means: 0.025390625
Epoch 1 | Batch 600/938 | Loss: 2.2278
layer 1 spike means: 0.508544921875
layer 2 spike means: 0.025156250223517418
Epoch 1 | Batch 700/938 | Loss: 2.2129
layer 1 spike means: 0.504626452922821
layer 2 spike means: 0.03890625014901161
Epoch 1 | Batch 800/938 | Loss: 2.1943
layer 1 spike means: 0.5222534537315369
layer 2 spike means: 0.03218749910593033
Epoch 1 | Batch 900/938 | Loss:

In [84]:
#experiment hidden sizes
#hidden_sizes = [32, 64, 128, 256, 512]
# Model

hidden_sizes = [64, 256]
threshold = 1
epoch = 3
T = 25
acc_results = []
for hidden_size in hidden_sizes:
    model, optimizer = get_model([784, hidden_size, 10], thresholds = [0.3, 0.3], lr = 0.005)
    model, acc = main(train_loader, test_loader, model, optimizer, num_epochs=epoch, verbose=True, super_spike_beta=1)
    print(f"For hidden size {hidden_size} accuracy is --> ", acc)
    for i, lif in enumerate(model.lif_layers):
        print(f"Layer {i} W mean: {lif.W.mean().item():.6f}")
        print(f"Layer {i} W std: {lif.W.std().item():.6f}")
        if lif.W.grad is not None:
            print(f"Layer {i} grad mean: {lif.W.grad.mean().item():.6f}")
            print(f"Layer {i} grad max: {lif.W.grad.abs().max().item():.6f}")



Using device: cuda
Epoch 1 | Batch 100/938 | Loss: 2.3101
layer 1 spike means: 0.4888061583042145
layer 2 spike means: 0.0
Epoch 1 | Batch 200/938 | Loss: 2.3061
layer 1 spike means: 0.4969848692417145
layer 2 spike means: 0.0
Epoch 1 | Batch 300/938 | Loss: 2.2947
layer 1 spike means: 0.49171143770217896
layer 2 spike means: 0.014609375037252903
Epoch 1 | Batch 400/938 | Loss: 2.2765
layer 1 spike means: 0.5290283560752869
layer 2 spike means: 0.010937499813735485
Epoch 1 | Batch 500/938 | Loss: 2.2647
layer 1 spike means: 0.537109375
layer 2 spike means: 0.01249999925494194
Epoch 1 | Batch 600/938 | Loss: 2.2571
layer 1 spike means: 0.560839831829071
layer 2 spike means: 0.008281249552965164
Epoch 1 | Batch 700/938 | Loss: 2.2514
layer 1 spike means: 0.566418468952179
layer 2 spike means: 0.009609375149011612
Epoch 1 | Batch 800/938 | Loss: 2.2466
layer 1 spike means: 0.5447143912315369
layer 2 spike means: 0.010312499478459358
Epoch 1 | Batch 900/938 | Loss: 2.2438
layer 1 spike mea

KeyboardInterrupt: 

In [89]:
#experiment hidden sizes
#hidden_sizes = [32, 64, 128, 256]
# Model

hidden_sizes = [32, 64, 128, 256]
epoch = 2
acc_results = []
for hidden_size in hidden_sizes:
    model, optimizer = get_model([784, hidden_size, 10], thresholds = [0.3, 0.3], lr = 0.005)
    model, acc, spikes_means = main(train_loader, test_loader, model, optimizer, num_epochs=epoch, verbose=False, super_spike_beta=0.1)
    print(f"For hidden size {hidden_size} accuracy is --> ", acc)
    spikes_means = torch.mean(spikes_means, dim = 0, keepdim=False)
    for i in range(len(spikes_means)):
        print(f"Layer {i} density -> ", spikes_means[i])
    print()


Using device: cuda
For hidden size 32 accuracy is -->  88.33
Layer 0 density ->  tensor(0.4982)
Layer 1 density ->  tensor(0.0909)

Using device: cuda
For hidden size 64 accuracy is -->  90.09
Layer 0 density ->  tensor(0.5068)
Layer 1 density ->  tensor(0.0897)

Using device: cuda
For hidden size 128 accuracy is -->  87.1
Layer 0 density ->  tensor(0.5224)
Layer 1 density ->  tensor(0.0915)

Using device: cuda
For hidden size 256 accuracy is -->  88.13
Layer 0 density ->  tensor(0.4924)
Layer 1 density ->  tensor(0.0936)



In [90]:
#experiment hidden sizes
#hidden_sizes = [32, 64, 128, 256]
# Model

hidden_sizes = [32, 64, 128, 256]
epoch = 2
acc_results = []
for hidden_size in hidden_sizes:
    model, optimizer = get_model([784, hidden_size, 10], thresholds = [0.1, 0.3], lr = 0.005)
    model, acc, spikes_means = main(train_loader, test_loader, model, optimizer, num_epochs=epoch, verbose=False, super_spike_beta=0.5)
    print(f"For hidden size {hidden_size} accuracy is --> ", acc)
    spikes_means = torch.mean(spikes_means, dim = 0, keepdim=False)
    for i in range(len(spikes_means)):
        print(f"Layer {i} density -> ", spikes_means[i])
    print()


Using device: cuda
For hidden size 32 accuracy is -->  78.48
Layer 0 density ->  tensor(0.5881)
Layer 1 density ->  tensor(0.0547)

Using device: cuda
For hidden size 64 accuracy is -->  73.9
Layer 0 density ->  tensor(0.5333)
Layer 1 density ->  tensor(0.0719)

Using device: cuda
For hidden size 128 accuracy is -->  57.51
Layer 0 density ->  tensor(0.4903)
Layer 1 density ->  tensor(0.0437)

Using device: cuda
For hidden size 256 accuracy is -->  35.22
Layer 0 density ->  tensor(0.5136)
Layer 1 density ->  tensor(0.0217)



In [91]:
#experiment hidden sizes
#hidden_sizes = [32, 64, 128, 256]
# Model

hidden_sizes = [32, 64, 128, 256]
epoch = 2
acc_results = []
for hidden_size in hidden_sizes:
    model, optimizer = get_model([784, hidden_size, 10], thresholds = [0.05, 0.3], lr = 0.005)
    model, acc, spikes_means = main(train_loader, test_loader, model, optimizer, num_epochs=epoch, verbose=False, super_spike_beta=0.1)
    print(f"For hidden size {hidden_size} accuracy is --> ", acc)
    spikes_means = torch.mean(spikes_means, dim = 0, keepdim=False)
    for i in range(len(spikes_means)):
        print(f"Layer {i} density -> ", spikes_means[i])
    print()


Using device: cuda
For hidden size 32 accuracy is -->  88.15
Layer 0 density ->  tensor(0.5238)
Layer 1 density ->  tensor(0.0866)

Using device: cuda
For hidden size 64 accuracy is -->  90.25
Layer 0 density ->  tensor(0.4979)
Layer 1 density ->  tensor(0.0899)

Using device: cuda
For hidden size 128 accuracy is -->  91.07
Layer 0 density ->  tensor(0.4666)
Layer 1 density ->  tensor(0.0907)

Using device: cuda
For hidden size 256 accuracy is -->  90.19
Layer 0 density ->  tensor(0.4974)
Layer 1 density ->  tensor(0.0921)



In [92]:
#experiment hidden sizes
#hidden_sizes = [32, 64, 128, 256]
# Model

hidden_sizes = [32, 64, 128, 256]
epoch = 2
acc_results = []
for hidden_size in hidden_sizes:
    model, optimizer = get_model([784, hidden_size, 10], thresholds = [1, 0.3], lr = 0.005)
    model, acc, spikes_means = main(train_loader, test_loader, model, optimizer, num_epochs=epoch, verbose=False, super_spike_beta=0.7)
    print(f"For hidden size {hidden_size} accuracy is --> ", acc)
    spikes_means = torch.mean(spikes_means, dim = 0, keepdim=False)
    for i in range(len(spikes_means)):
        print(f"Layer {i} density -> ", spikes_means[i])
    print()


Using device: cuda
For hidden size 32 accuracy is -->  80.14
Layer 0 density ->  tensor(0.4716)
Layer 1 density ->  tensor(0.0677)

Using device: cuda
For hidden size 64 accuracy is -->  65.38
Layer 0 density ->  tensor(0.5717)
Layer 1 density ->  tensor(0.0484)

Using device: cuda
For hidden size 128 accuracy is -->  39.09
Layer 0 density ->  tensor(0.4675)
Layer 1 density ->  tensor(0.0220)

Using device: cuda
For hidden size 256 accuracy is -->  47.2
Layer 0 density ->  tensor(0.4533)
Layer 1 density ->  tensor(0.0394)



In [93]:
#experiment hidden sizes
#hidden_sizes = [32, 64, 128, 256]
# Model

hidden_sizes = [32, 64, 128, 256]
epoch = 2
acc_results = []
for hidden_size in hidden_sizes:
    model, optimizer = get_model([784, hidden_size, 10], thresholds = [2, 0.3], lr = 0.005)
    model, acc, spikes_means = main(train_loader, test_loader, model, optimizer, num_epochs=epoch, verbose=False, super_spike_beta=1.2)
    print(f"For hidden size {hidden_size} accuracy is --> ", acc)
    spikes_means = torch.mean(spikes_means, dim = 0, keepdim=False)
    for i in range(len(spikes_means)):
        print(f"Layer {i} density -> ", spikes_means[i])
    print()


Using device: cuda
For hidden size 32 accuracy is -->  55.56
Layer 0 density ->  tensor(0.5825)
Layer 1 density ->  tensor(0.0433)

Using device: cuda
For hidden size 64 accuracy is -->  46.33
Layer 0 density ->  tensor(0.4838)
Layer 1 density ->  tensor(0.0428)

Using device: cuda
For hidden size 128 accuracy is -->  18.49
Layer 0 density ->  tensor(0.4097)
Layer 1 density ->  tensor(0.0089)

Using device: cuda
For hidden size 256 accuracy is -->  12.57
Layer 0 density ->  tensor(0.4510)
Layer 1 density ->  tensor(0.0079)



In [94]:
#experiment hidden sizes
#hidden_sizes = [32, 64, 128, 256]
# Model

hidden_sizes = [32, 64, 128, 256]
epoch = 2
acc_results = []
for hidden_size in hidden_sizes:
    model, optimizer = get_model([784, hidden_size, 10], thresholds = [2, 0.3], lr = 0.005)
    model, acc, spikes_means = main(train_loader, test_loader, model, optimizer, num_epochs=epoch, verbose=False, super_spike_beta=0.02)
    print(f"For hidden size {hidden_size} accuracy is --> ", acc)
    spikes_means = torch.mean(spikes_means, dim = 0, keepdim=False)
    for i in range(len(spikes_means)):
        print(f"Layer {i} density -> ", spikes_means[i])
    print()


Using device: cuda
For hidden size 32 accuracy is -->  91.82
Layer 0 density ->  tensor(0.5716)
Layer 1 density ->  tensor(0.0924)

Using device: cuda
For hidden size 64 accuracy is -->  92.17
Layer 0 density ->  tensor(0.5054)
Layer 1 density ->  tensor(0.0894)

Using device: cuda
For hidden size 128 accuracy is -->  92.96
Layer 0 density ->  tensor(0.4625)
Layer 1 density ->  tensor(0.0936)

Using device: cuda
For hidden size 256 accuracy is -->  93.78
Layer 0 density ->  tensor(0.4687)
Layer 1 density ->  tensor(0.0959)



In [95]:
#experiment hidden sizes
#hidden_sizes = [32, 64, 128, 256]
# Model

hidden_sizes = [32, 64, 128, 256]
epoch = 2
acc_results = []
for hidden_size in hidden_sizes:
    model, optimizer = get_model([784, hidden_size, 10], thresholds = [5, 0.3], lr = 0.005)
    model, acc, spikes_means = main(train_loader, test_loader, model, optimizer, num_epochs=epoch, verbose=False, super_spike_beta=0.02)
    print(f"For hidden size {hidden_size} accuracy is --> ", acc)
    spikes_means = torch.mean(spikes_means, dim = 0, keepdim=False)
    for i in range(len(spikes_means)):
        print(f"Layer {i} density -> ", spikes_means[i])
    print()


Using device: cuda
For hidden size 32 accuracy is -->  91.47
Layer 0 density ->  tensor(0.4994)
Layer 1 density ->  tensor(0.0860)

Using device: cuda
For hidden size 64 accuracy is -->  90.5
Layer 0 density ->  tensor(0.5081)
Layer 1 density ->  tensor(0.0858)

Using device: cuda
For hidden size 128 accuracy is -->  91.76
Layer 0 density ->  tensor(0.3898)
Layer 1 density ->  tensor(0.0862)

Using device: cuda
For hidden size 256 accuracy is -->  92.93
Layer 0 density ->  tensor(0.3861)
Layer 1 density ->  tensor(0.0927)



In [11]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import optim
from torch.autograd import Function
from torch.nn import functional as F
import math

def spike_fn(U, threshold):
    return (U >= threshold).float()


#input X spikes is determined by the probability,
#higher the pixel tensity, higher the probability there is a spike.
def poisson_encode(x, T):
    x = x.unsqueeze(0).repeat(T, 1, 1)  # (T, B, features)
    spikes = (torch.rand_like(x) < x).float()
    return spikes


#Now, I implement it with eligibility, because with the last learning window approach, the networks was enforced to
#produce spikes at the ends, however there was no temporal information.
#So, we either use Backpropagation through time, or eligibility trace.
#backpropagation through time is not biologically plausible, so we use eligibility trace.
#Also, a good thing is that eligibility trace and backpropagation through time is
#roughly equivalent to each other mathematically(roughly, not exactly, works in practice).
#eligibility trace basically -> you put a trace to the neurons at each time stop, basicaly it makes the neuron know the
#gradients of the previous time step. So it is why it is equivalent to backpropagation through time.
#each time step we multiply that trace from the previous time step with a number, so it may cause vanishing gradients if
#that factor(gamma) is < 1 , or exploding gradients if gamma > 1, 
#it is used for holding the temporal context.
#also one last thing to note it makes it online, so no need to hold the history, less memory need and biologicaly more plausible
class LIFLayerEligibility(nn.Module):
    def __init__(self, in_features, out_features, tau_trace = 20.0,
                 tau_mem=20.0, tau_syn=5.0, dt=1.0, threshold=1.0, super_spike_B = 0.03):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.threshold = threshold
        self.super_spike_B = super_spike_B
        # use Xavier initialization, because otherwise some neurons were not firing at all, it was unstable.
        std = 1 / (in_features**0.5)
        self.W = nn.Parameter(torch.randn(out_features, in_features) * std)
        

        self.dt = dt #just a theoretical constant

        #Use register_buffer, because it must be linked to the nn.Module
        self.register_buffer('alpha', torch.exp(torch.tensor(-dt / tau_syn))) #Voltage decay
        self.register_buffer('beta', torch.exp(torch.tensor(-dt / tau_mem))) #Current decay
        self.register_buffer('gamma', torch.exp(torch.tensor(-dt / tau_trace)))  #Trace decay


    #input spikes -> (T, B, in_dim)
    #returns -> (T, B, out_dim) as spikes, and (T, B, out_dim) as voltage values
    def forward(self, input_spikes):

        T, B, _ = input_spikes.shape
        device = self.W.device

        #current, and voltage must start from 0 in the beginning
        I = torch.zeros(B, self.out_features, device=device)
        U = torch.zeros(B, self.out_features, device=device)

        #keep all spike and voltage values across the time steps.
        out_spikes = []
        U_hist = []
        trace = torch.zeros(B, self.out_features, self.in_features, device=device)
        for t in range(T):
            x_t = input_spikes[t]  # (B, in_features), X value of the current time step.

            #new current
            I = self.alpha * I + x_t @ self.W.t()  # (B, out_features)

            #new membrane voltage
            U = self.beta * U + I

            #if U exceeds threshold get one as a spike.
            S = spike_fn(U, self.threshold)  # (B, out_features)
            U_hist.append(U.clone())
            U = U - S * self.threshold #reset U
            # Surrogate gradient
            du = (U - self.threshold).abs()
            sigma_prime = 1 / ((1 + self.super_spike_B * du)**2)  # (B, out_features)

            # Eligibility trace update
            # calculate trace here
            trace = self.gamma * trace + sigma_prime.unsqueeze(2) * x_t.unsqueeze(1)
            
            U_hist.append(U.clone())
            U = U - S * self.threshold
            out_spikes.append(S)

            

        out_spikes = torch.stack(out_spikes, dim=0)
        U_hist = torch.stack(U_hist, dim=0)
        return out_spikes, U_hist, trace

class RandomBPSNNEligibility(nn.Module):
    def __init__(self, T=20):
        super().__init__()
        self.T = T
        self.lif_layers = torch.nn.ModuleList()
        self.G_hiddens = torch.nn.ParameterList()


    def append_LIF(self, LIF, out_dim = 10):
        self.lif_layers.append(LIF)
        if (len(self.lif_layers) > 1):
            hidden_dim = self.lif_layers[-2].out_features #set the projection matrix
            #random projection matrix -> (out_dim(10), hidden_dim)
            G_hidden = nn.Parameter(
                torch.randn(out_dim, hidden_dim) * 0.1, requires_grad = False
            )
            self.G_hiddens.append(G_hidden)

    #here if X is a static no temporal data -> X values are encoded using rate-encoding or poisson encoding
    #generally rate-encoding is used because it is more stable however I implemented poisson encoding above as well.
    #if it has a stime steps, no encoding is used.

    #Static X -> (B, in_dim)
    #Time_step_X -> (T, B, in_dim)
    def forward(self, x):
        device = self.lif_layers[0].W.device
        static = len(x.shape) == 2
        x = x.to(device)
        if (static):
            B, in_dim = x.shape
            x = x.unsqueeze(0).repeat(self.T, 1, 1)  # (T, B, in_dim), rate encoding

        else:
            _, B, in_dim = x.shape
        o_spk = x
        spikes, voltages, traces = [], [], []
        for lif in self.lif_layers:
            o_spk, o_U, trace = lif(o_spk)
            spikes.append(o_spk)
            voltages.append(o_U)
            traces.append(trace)

        #to calculate accuracy, we generally sum spike numbers of all the time steps, and predict the 
        #class which has maximum number of summed spikies.
        out_rate = o_spk.mean(dim=0)  #(B, O)

        return {
            "Us": voltages,
            "spikes": spikes,
            "o_spk": spikes[-1],
            "o_U": voltages[-1],
            "out_rate": out_rate,
            "traces": traces
        }
the_device = torch.device("cuda")
def random_bp_step_eligibility(model, x, target, optimizer):
    """
    model -> all_layers(RandomBPSNN)
    x -> (B, in_dim) or (T, B, in_dim) if it is (B, in_dim) timesteps are produced using rate-encoding
    target -> (B, out_dim)
    loss_fn -> Cross entropy derivative is used here -> |Y - y_predicted|
    """    

    optimizer.zero_grad()
    x = x.to(the_device)
    out = model(x)
    spikes = out["spikes"] # (layer_num, T, B, O)
    traces = out["traces"] # (layer_num, ) -> trace values of the last time step for each layers, 
    out_spikes = out["o_spk"] # (T, B, 10)
    T, B, num_classes = out_spikes.shape
    
    #apply one_hot encoding if it hasnt applied yet.
    if (len(target.shape) == 1):
        target = F.one_hot(target, num_classes=num_classes).float()  # (B, 10)

    #calculate errors
    out_rate = spikes[-1].mean(dim=0)  # (B, num_classes)
    error = out_rate - target           # (B, num_classes)
    loss = F.cross_entropy(out_rate, target)

    #calculate gradients from the traces
    for i, lif in enumerate(model.lif_layers):
        trace = traces[i]  #(B, out, in)
        
        #if it is output layer just use normal error value, if it is a hidden layer use projection of the error derivative
        if i == len(model.lif_layers) - 1:
            #Output layer -> dW = error * trace
            dW = torch.einsum('bo,boi->oi', error, trace) / B
        else:
            #Hidden layer: project error with random matrix
            G = model.G_hiddens[i]
            proj_error = error @ G  # (B, hidden)
            dW = torch.einsum('bh,bhi->hi', proj_error, trace) / B
        
        lif.W.grad = dW
    
    optimizer.step()
    return loss.item()


#train the model, with eligibility applied
def train_eligibility(train_loader, test_loader, model, optimizer, num_epochs, verbose = True, apply_poisson = False, T = 20, dont_touch = False):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    spike_means_outer = []

    for epoch in range(1, num_epochs + 1):
        model.train()
        total_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.to(device)   # (B,1,28,28)
            target = target.to(device)
            if not dont_touch:
                # Flatten: (B, 784)
                data = data.view(data.size(0), -1)

                if (not apply_poisson):
                # Flatten: (B, 784)
                    data = data.unsqueeze(0).repeat(T, 1, 1)  # (T, B, in_dim), rate encoding
                else:
                    data = poisson_encode(data, T = T)
                    

            total_loss += random_bp_step_eligibility(model, data, target, optimizer)

            if (batch_idx + 1) % 8 == 0:
                if (verbose):
                    print(f"Epoch {epoch} | Batch {batch_idx+1}/{len(train_loader)} | "
                        f"Loss: {total_loss / (batch_idx+1):.4f}")
                data, target = next(iter(train_loader))
                if (not dont_touch):
                    data = data.to(device).view(data.size(0), -1)

                out = model(data)
                spike_means = []
                for i in out["spikes"]:
                    spike_means.append(i.mean().item())
                spike_means_outer.append(torch.tensor(spike_means))
                if (verbose):
                    for i in range(len(spike_means)):
                        print(f"layer {i+1} spike means:", spike_means[i])


        #calculate accuracy
        acc = evaluate(model, test_loader, device)
        if (verbose):
            print(f"Epoch {epoch} finished. Test accuracy: {acc:.2f}%")
    spike_means_outer = torch.stack(spike_means_outer, dim = 0)
    return model, acc, spike_means_outer
def evaluate(model, loader, device, dont_touch = False):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)
            target = target.to(device)
            if (not dont_touch):
                data = data.view(data.size(0), -1) #it is not already flattened, so flatten it.

            out = model(data)
            out_rate = out["out_rate"]  # (B, out_dim)
            pred = out_rate.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    return 100.0 * correct / total


device = torch.device("cuda")
#returns a model and optimizer,
#you can determine the hidden layer number with hidden_dims
#hidden dims -> [X, Y, Z, T, U] it means that input layer is X; Y, Z, T are neuron numbers of the hidden layers, and U is out_dim
#you can set different thresholds for each of the hidden_layer, default is 1 threshold
def get_model_eligibility(hidden_dims, thresholds = None, super_spike_beta = 25, lr = 1e-3, out_dims = 10):
    if (thresholds is None):
        thresholds = [1 for _ in range(len(hidden_dims) - 1)]

    model = RandomBPSNNEligibility()
    weights = []
    for i in range(1, len(hidden_dims)):
        model.append_LIF(LIFLayerEligibility(hidden_dims[i-1], hidden_dims[i], threshold=thresholds[i-1], super_spike_B=super_spike_beta), out_dim=out_dims)
        weights.append(model.lif_layers[-1].W)
    model = model.to(device)
    optimizer = torch.optim.Adam(
        weights, lr=lr
    )
    return model, optimizer

In [12]:
#experiment hidden sizes
#hidden_sizes = [32, 64, 128, 256]
# Model

hidden_sizes = [32, 64, 128, 256]
epoch = 2
acc_results = []
for hidden_size in hidden_sizes:
    model, optimizer = get_model_eligibility([784, hidden_size, 10], thresholds = [5, 0.3], lr = 0.005, out_dims=10)
    model, acc, spikes_means = train_eligibility(train_loader, test_loader, model, optimizer, num_epochs=epoch, verbose=False, dont_touch=False)
    print(f"For hidden size {hidden_size} accuracy is --> ", acc)
    spikes_means = torch.mean(spikes_means, dim = 0, keepdim=False)
    for i in range(len(spikes_means)):
        print(f"Layer {i} density -> ", spikes_means[i])
    print()


Using device: cuda
For hidden size 32 accuracy is -->  89.99
Layer 0 density ->  tensor(0.4973)
Layer 1 density ->  tensor(0.0713)

Using device: cuda
For hidden size 64 accuracy is -->  90.94
Layer 0 density ->  tensor(0.4743)
Layer 1 density ->  tensor(0.0826)

Using device: cuda
For hidden size 128 accuracy is -->  89.81
Layer 0 density ->  tensor(0.4545)
Layer 1 density ->  tensor(0.0892)

Using device: cuda
For hidden size 256 accuracy is -->  91.5
Layer 0 density ->  tensor(0.4784)
Layer 1 density ->  tensor(0.0861)



In [None]:

epoch = 2
model, optimizer = get_model_eligibility([784, 256, 64, 10], thresholds = [5, 4, 0.3], lr = 0.005)
model, acc, spikes_means = train_eligibility(train_loader, test_loader, model, optimizer, num_epochs=epoch, verbose=False, apply_poisson=True)
print(f"accuracy is --> ", acc)
spikes_means = torch.mean(spikes_means, dim = 0, keepdim=False)
for i in range(len(spikes_means)):
    print(f"Layer {i} density -> ", spikes_means[i])
print()


Using device: cuda
accuracy is -->  89.19
Layer 0 density ->  tensor(0.5298)
Layer 1 density ->  tensor(0.5233)
Layer 2 density ->  tensor(0.0674)



In [39]:

epoch = 2
model, optimizer = get_model_eligibility([784, 64, 10], thresholds = [3, 0.09], lr = 0.005)

model, acc, spikes_means = train_eligibility(train_loader, test_loader, model, optimizer, num_epochs=epoch, verbose=True, apply_poisson=False)

print(f"accuracy is --> ", acc)
spikes_means = torch.mean(spikes_means, dim = 0, keepdim=False)
for i in range(len(spikes_means)):
    print(f"Layer {i} density -> ", spikes_means[i])
print()


Using device: cuda
Epoch 1 | Batch 400/938 | Loss: 1.8441
layer 1 spike means: 0.44755861163139343
layer 2 spike means: 0.1049218699336052
Epoch 1 | Batch 800/938 | Loss: 1.7337
layer 1 spike means: 0.46051025390625
layer 2 spike means: 0.09765625
Epoch 1 finished. Test accuracy: 88.97%
Epoch 2 | Batch 400/938 | Loss: 1.5974
layer 1 spike means: 0.48680421710014343
layer 2 spike means: 0.09023437649011612
Epoch 2 | Batch 800/938 | Loss: 1.5966
layer 1 spike means: 0.5238281488418579
layer 2 spike means: 0.09242187440395355
Epoch 2 finished. Test accuracy: 90.02%
accuracy is -->  90.02
Layer 0 density ->  tensor(0.4797)
Layer 1 density ->  tensor(0.0963)



In [105]:
#experiment hidden sizes
#hidden_sizes = [32, 64, 128, 256]
# Model

hidden_sizes = [32, 64, 128, 256]
epoch = 2
acc_results = []
for hidden_size in hidden_sizes:
    model, optimizer = get_model_eligibility([784, hidden_size, 10], thresholds = [50, 0.3], lr = 0.005)
    model, acc, spikes_means = train_eligibility(train_loader, test_loader, model, optimizer, num_epochs=epoch, verbose=False)
    print(f"For hidden size {hidden_size} accuracy is --> ", acc)
    spikes_means = torch.mean(spikes_means, dim = 0, keepdim=False)
    for i in range(len(spikes_means)):
        print(f"Layer {i} density -> ", spikes_means[i])
    print()


Using device: cuda
For hidden size 32 accuracy is -->  79.7
Layer 0 density ->  tensor(0.5751)
Layer 1 density ->  tensor(0.0783)

Using device: cuda
For hidden size 64 accuracy is -->  88.27
Layer 0 density ->  tensor(0.4855)
Layer 1 density ->  tensor(0.0893)

Using device: cuda
For hidden size 128 accuracy is -->  89.65
Layer 0 density ->  tensor(0.4396)
Layer 1 density ->  tensor(0.0884)

Using device: cuda
For hidden size 256 accuracy is -->  81.75
Layer 0 density ->  tensor(0.4485)
Layer 1 density ->  tensor(0.0795)



In [46]:

epoch = 15
model, optimizer = get_model_eligibility([784, 64, 32, 10], thresholds = [3, 2, 0.09], lr = 5e-4)

model, acc, spikes_means = train_eligibility(train_loader, test_loader, model, optimizer, num_epochs=epoch, verbose=True, apply_poisson=False)

print(f"Num layers: {len(model.lif_layers)}")
print(f"Num G matrices: {len(model.G_hiddens)}")

for i, lif in enumerate(model.lif_layers):
    print(f"Layer {i}: {lif.in_features} → {lif.out_features}")

for i, G in enumerate(model.G_hiddens):
    print(f"G[{i}]: {G.shape}")

for i, lif in enumerate(model.lif_layers):
    if lif.W.grad is not None:
        print(f"Layer {i} grad mean: {lif.W.grad.abs().mean().item():.6f}")
        print(f"Layer {i} grad max: {lif.W.grad.abs().max().item():.6f}")


print(f"accuracy is --> ", acc)
spikes_means = torch.mean(spikes_means, dim = 0, keepdim=False)
for i in range(len(spikes_means)):
    print(f"Layer {i} density -> ", spikes_means[i])

print()


Using device: cuda
Epoch 1 | Batch 400/938 | Loss: 2.3384
layer 1 spike means: 0.5126953125
layer 2 spike means: 0.4975341856479645
layer 3 spike means: 0.014843749813735485
Epoch 1 | Batch 800/938 | Loss: 2.2484
layer 1 spike means: 0.591748058795929
layer 2 spike means: 0.5405029654502869
layer 3 spike means: 0.055937498807907104
Epoch 1 finished. Test accuracy: 54.11%
Epoch 2 | Batch 400/938 | Loss: 1.9369
layer 1 spike means: 0.5823730826377869
layer 2 spike means: 0.5014892816543579
layer 3 spike means: 0.10312499850988388
Epoch 2 | Batch 800/938 | Loss: 1.8860
layer 1 spike means: 0.617846667766571
layer 2 spike means: 0.4841064512729645
layer 3 spike means: 0.09921874850988388
Epoch 2 finished. Test accuracy: 82.28%
Epoch 3 | Batch 400/938 | Loss: 1.7000
layer 1 spike means: 0.609375
layer 2 spike means: 0.47395020723342896
layer 3 spike means: 0.10468749701976776
Epoch 3 | Batch 800/938 | Loss: 1.6821
layer 1 spike means: 0.6128906607627869
layer 2 spike means: 0.46958008408546

In [43]:

epoch = 5
model, optimizer = get_model_eligibility([784, 64, 32, 10], thresholds = [3, 2, 0.09], lr = 0.05)#increase lrto increase convergence speed

model, acc, spikes_means = train_eligibility(train_loader, test_loader, model, optimizer, num_epochs=epoch, verbose=True, apply_poisson=False)

print(f"accuracy is --> ", acc)
spikes_means = torch.mean(spikes_means, dim = 0, keepdim=False)
for i in range(len(spikes_means)):
    print(f"Layer {i} density -> ", spikes_means[i])
print()


Using device: cuda
Epoch 1 | Batch 400/938 | Loss: 2.3035
layer 1 spike means: 0.48603516817092896
layer 2 spike means: 0.44023439288139343
layer 3 spike means: 0.0
Epoch 1 | Batch 800/938 | Loss: 2.2954
layer 1 spike means: 0.5793213248252869
layer 2 spike means: 0.4284912049770355
layer 3 spike means: 0.019453125074505806
Epoch 1 finished. Test accuracy: 9.90%
Epoch 2 | Batch 400/938 | Loss: 2.2277
layer 1 spike means: 0.5710693597793579
layer 2 spike means: 0.43400880694389343
layer 3 spike means: 0.009531250223517418
Epoch 2 | Batch 800/938 | Loss: 2.2366
layer 1 spike means: 0.5257934927940369
layer 2 spike means: 0.46015626192092896
layer 3 spike means: 0.012578125111758709
Epoch 2 finished. Test accuracy: 9.80%
Epoch 3 | Batch 400/938 | Loss: 2.2305
layer 1 spike means: 0.53448486328125
layer 2 spike means: 0.4595947265625
layer 3 spike means: 0.0015624999068677425
Epoch 3 | Batch 800/938 | Loss: 2.2329
layer 1 spike means: 0.522143542766571
layer 2 spike means: 0.469970703125
l

In [None]:

epoch = 5
model, optimizer = get_model_eligibility([784, 256, 128, 10], thresholds = [3, 2, 0.09], lr = 0.05)#increase lrto increase convergence speed

model, acc, spikes_means = train_eligibility(train_loader, test_loader, model, optimizer, num_epochs=epoch, verbose=True, apply_poisson=False)

print(f"accuracy is --> ", acc)
spikes_means = torch.mean(spikes_means, dim = 0, keepdim=False)
for i in range(len(spikes_means)):
    print(f"Layer {i} density -> ", spikes_means[i])
print()


Overall, SNNs are a great way to reduce the energy needs of models in both training and test phases. However, there are lots of variety on how we construct our layers, for example which cell do we choose, or which surrogate gradients we should use etc etc.. and parameter tuning is an important thing in SNNs.

For example in the https://www.pnas.org/doi/10.1073/pnas.2109194119 this article you can see the accuracy results of different implementations, ranging from
71.2% to 98.1%. 
According to the article, they used augmentation, also they used BPTT for this results trained it for 100 epochs. I trained a model above up to 15 epochs, which was achieved 93% rate, if I tune the values even a little bit I might have approached that level too. In addition, that article used sparsity regularization, increases the sparsity finding the tradeoff between performance and sparsity.
