<a href="https://colab.research.google.com/github/k1151msarandega/Lapicque-s-RC/blob/main/Lapicque's_RC_phase_encoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Architecture:** *Lapicque's RC*

**Encoding Scheme:** *Phase encoding*

1. Install the required libraries (snnTorch and PyTorch; in Colab run `!pip install snntorch` first) and import the necessary modules

In [None]:
import math

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from snntorch import spikegen
from snntorch import data
from snntorch import conversion
from snntorch import surrogate

2. Define the SNN model built from Lapicque RC neurons

In [None]:
class _ATanSpike(torch.autograd.Function):
    """Heaviside spike with an arctangent surrogate gradient.

    Forward: emit 1.0 where the input (membrane potential minus threshold) is
    strictly positive, else 0.0. Backward: the true Heaviside derivative is zero
    almost everywhere, so it is replaced by the derivative of
    (1/pi) * arctan(pi/2 * alpha * x), i.e.
    alpha / 2 / (1 + (pi/2 * alpha * x) ** 2), letting gradients flow through
    the spiking non-linearity.
    """

    @staticmethod
    def forward(ctx, mem_shifted, alpha):
        ctx.save_for_backward(mem_shifted)
        ctx.alpha = alpha
        return (mem_shifted > 0).to(mem_shifted.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (mem_shifted,) = ctx.saved_tensors
        alpha = ctx.alpha
        surrogate_grad = alpha / 2 / (1 + (math.pi / 2 * alpha * mem_shifted) ** 2)
        # alpha is passed in as a plain float, so it gets no gradient.
        return grad_output * surrogate_grad, None


def phase_encode(images, num_steps, num_phases):
    """Phase-encode a batch of images into a weighted binary spike train.

    Each image is flattened and min-max rescaled to [0, 1] on its own, so the
    code works whether or not the pixels were standardised beforehand; a
    constant image (zero range) encodes to all zeros. Every pixel is quantised
    to a ``num_phases``-bit integer. At step ``t`` the phase is
    ``k = t % num_phases`` and the pixel emits bit ``k`` (most significant bit
    first) weighted by ``2 ** -(k + 1)``, so one full period of weighted spikes
    sums to the quantised intensity.

    Args:
        images: Tensor of shape (batch, *dims).
        num_steps: Number of time steps to generate.
        num_phases: Bits per pixel, i.e. the phase period in steps.

    Returns:
        Tensor of shape (num_steps, batch, prod(dims)).
    """
    flat = images.flatten(1)
    if not flat.is_floating_point():
        flat = flat.float()
    low = flat.min(dim=1, keepdim=True).values
    span = flat.max(dim=1, keepdim=True).values - low
    # Clamping only matters for constant images, where flat - low is 0 anyway.
    unit = (flat - low) / span.clamp_min(torch.finfo(flat.dtype).eps)
    levels = torch.round(unit * (2 ** num_phases - 1)).to(torch.int64)

    steps = []
    for t in range(num_steps):
        phase = t % num_phases
        bit = (levels >> (num_phases - 1 - phase)) & 1
        steps.append(bit.to(flat.dtype) * 2.0 ** -(phase + 1))
    return torch.stack(steps)


class SNNModel(nn.Module):
    """Three-layer fully connected SNN of Lapicque RC neurons fed by phase coding.

    ``forward`` phase-encodes each image (see ``phase_encode``) over
    ``num_steps`` time steps. Every layer is an ``nn.Linear`` followed by a
    Lapicque RC neuron with reset-by-subtraction. The output is the number of
    spikes each of the 10 output neurons fires, used as class scores.

    Args:
        num_steps: Simulation time steps per image.
        num_phases: Bits per pixel in the phase code (phase period in steps).
        resistance: Membrane resistance R.
        capacitance: Membrane capacitance C.
        time_step: Simulation step dt; must satisfy 0 < dt <= R * C.
        threshold: Firing threshold on the membrane potential.

    Raises:
        ValueError: If any constructor argument is out of range.
    """

    def __init__(self, num_steps=16, num_phases=8, resistance=5.0,
                 capacitance=1e-3, time_step=1e-3, threshold=1.0):
        super().__init__()
        if num_steps < 1 or num_phases < 1:
            raise ValueError("num_steps and num_phases must be at least 1")
        if resistance <= 0 or capacitance <= 0 or time_step <= 0:
            raise ValueError("resistance, capacitance and time_step must be positive")
        tau = resistance * capacitance
        if time_step > tau:
            # beta would go negative and the discretised membrane would oscillate.
            raise ValueError("time_step must not exceed the RC time constant")

        self.num_steps = num_steps
        self.num_phases = num_phases
        # Forward-Euler discretisation of tau * dU/dt = -U + R * I (tau = R * C):
        #   U[t+1] = beta * U[t] + (dt / C) * I[t],  beta = 1 - dt / tau.
        self.beta = 1.0 - time_step / tau
        self.input_gain = time_step / capacitance
        self.threshold = threshold

        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
        # Sharpness of the arctangent surrogate gradient. It is read as a plain
        # float in forward(), so it receives no gradient and the optimiser
        # leaves it unchanged.
        self.surr_alpha = nn.Parameter(torch.tensor(1.5))

    def _lapicque(self, current, mem, alpha):
        """Advance one Lapicque RC layer by one step; return (spikes, membrane)."""
        mem = self.beta * mem + self.input_gain * current
        spk = _ATanSpike.apply(mem - self.threshold, alpha)
        # Reset by subtraction; detached so the reset adds no gradient path.
        mem = mem - spk.detach() * self.threshold
        return spk, mem

    def forward(self, x):
        """Run the network on a batch of static images.

        Args:
            x: Tensor of shape (batch, *dims) with prod(dims) == 784,
                e.g. (batch, 1, 28, 28) MNIST images.

        Returns:
            Float tensor of shape (batch, 10): output spike counts over
            ``num_steps`` steps, suitable as logits for ``nn.CrossEntropyLoss``.

        Raises:
            ValueError: If ``x`` does not have 784 values per sample.
        """
        if x.dim() < 2 or x.flatten(1).shape[1] != self.fc1.in_features:
            raise ValueError(
                f"expected a batch of images with {self.fc1.in_features} values "
                f"each, got shape {tuple(x.shape)}"
            )
        spike_train = phase_encode(x, self.num_steps, self.num_phases)
        alpha = self.surr_alpha.item()
        batch = spike_train.shape[1]

        mem1 = spike_train.new_zeros(batch, self.fc1.out_features)
        mem2 = spike_train.new_zeros(batch, self.fc2.out_features)
        mem3 = spike_train.new_zeros(batch, self.fc3.out_features)
        spike_counts = spike_train.new_zeros(batch, self.fc3.out_features)

        for step_input in spike_train:
            spk1, mem1 = self._lapicque(self.fc1(step_input), mem1, alpha)
            spk2, mem2 = self._lapicque(self.fc2(spk1), mem2, alpha)
            spk3, mem3 = self._lapicque(self.fc3(spk2), mem3, alpha)
            spike_counts = spike_counts + spk3
        return spike_counts

3. Prepare the dataset

In [None]:
# Convert each MNIST image to a tensor, then standardise it with the dataset's
# pixel mean (0.1307) and standard deviation (0.3081).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)


def _mnist_split(train):
    """Build the MNIST dataset for one split and its DataLoader.

    Only the training split is shuffled; both use batches of 64 and 2 workers.
    """
    dataset = torchvision.datasets.MNIST(
        root='./data', train=train, download=True, transform=transform
    )
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=64, shuffle=train, num_workers=2
    )
    return dataset, loader


trainset, trainloader = _mnist_split(train=True)
testset, testloader = _mnist_split(train=False)

4. Initialise the SNN model and define the loss function and optimiser

In [None]:
# Build the network, the classification loss (it takes unnormalised class
# scores) and an SGD-with-momentum optimiser over all model parameters.
snn_model = SNNModel()
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=snn_model.parameters(), momentum=0.9, lr=0.01)

5. Train the SNN model

In [None]:
num_epochs = 10

for epoch in range(num_epochs):
    snn_model.train()
    running_loss = 0.0
    train_correct = 0
    for inputs, labels in trainloader:
        # NOTE(review): snntorch.spikegen does not appear to provide a
        # phase_encode function, so the raw image batch is passed straight to
        # snn_model. This loop relies on SNNModel phase-encoding the images
        # itself and returning per-class scores that CrossEntropyLoss accepts.
        optimizer.zero_grad()  # clear gradients accumulated by the previous batch

        outputs = snn_model(inputs)  # forward pass through the SNN

        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predicted = outputs.detach().argmax(dim=1)
        train_correct += (predicted == labels).sum().item()

    train_accuracy = 100.0 * train_correct / len(trainset)
    print(f"Epoch: {epoch + 1}, Loss: {running_loss / len(trainloader)}, Accuracy: {train_accuracy}%")

print("Training complete!")

6. Evaluate the SNN model on the test set

In [None]:
correct = 0
total = 0

# Evaluation only: switch to inference mode and disable autograd. There is no
# need to zero gradients here because none are computed.
snn_model.eval()
with torch.no_grad():
    for inputs, labels in testloader:
        # NOTE(review): snntorch.spikegen does not appear to provide
        # phase_encode; the raw image batch goes straight to snn_model, which is
        # expected to phase-encode it internally.
        outputs = snn_model(inputs)

        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = 100.0 * correct / total
print(f"Test Accuracy: {test_accuracy}%")


7. Evaluation metrics (time)

In [None]:
import time

# Wall-clock time of one full inference pass over the test set (data loading
# included). The training and evaluation loops run in earlier cells and cannot
# be timed from here; timing an empty region, as before, always reported ~0 s.
# perf_counter is the monotonic clock intended for measuring intervals.
snn_model.eval()
start_time = time.perf_counter()
with torch.no_grad():
    for inputs, _ in testloader:
        snn_model(inputs)
end_time = time.perf_counter()
computational_time = end_time - start_time
print(f"Computational time: {computational_time} seconds")