<a href="https://colab.research.google.com/github/k1151msarandega/Lapicque-s-RC/blob/main/Lapicque's_RC_rank_order_encoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Architecture:** *Lapicque's RC*

**Encoding Scheme:** *Rank-order encoding*

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from snntorch import spikegen, surrogate, encoding, data
from snntorch import snn
from snntorch import utils
from snntorch import models
import time

1. Set the device

In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

2. Define the rank-order encoder

In [None]:
NUM_STEPS = 25  # total simulation length (time steps) per sample


def encoder(data, time=0.5, num_steps=NUM_STEPS, threshold=0.01):
    """Rank-order (time-to-first-spike) encode a batch of images.

    Each sample is min-max rescaled to [0, 1], so the code does not depend on
    whether the images were mean/std normalised. Every pixel brighter than
    ``threshold`` then emits exactly one spike. The brightest pixel fires at
    step 0 and dimmer pixels fire progressively later, so the spike *order*
    ranks the pixel intensities. Pixels at or below ``threshold`` (e.g. a flat
    background) stay silent.

    Args:
        data: batch tensor whose first dimension is the batch, e.g. (B, 1, 28, 28).
        time: fraction of the ``num_steps`` window, in (0, 1], over which input
            spikes are emitted. The remaining steps let downstream layers
            integrate the last inputs. (Keyword kept for the existing
            ``encoder(data, time=0.5)`` call sites.)
        num_steps: total number of time steps in the returned spike train.
        threshold: scaled-intensity cutoff in [0, 1); pixels at or below it never fire.

    Returns:
        Float tensor of shape (num_steps, *data.shape) containing 0/1 spikes.

    Raises:
        ValueError: if ``time``, ``num_steps`` or ``threshold`` is out of range.
    """
    if not 0 < time <= 1:
        raise ValueError(f"time must be in (0, 1], got {time}")
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    if not 0 <= threshold < 1:
        raise ValueError(f"threshold must be in [0, 1), got {threshold}")

    flat = data.reshape(data.shape[0], -1).float()
    low = flat.min(dim=1, keepdim=True).values
    high = flat.max(dim=1, keepdim=True).values
    # The clamp avoids 0/0 on constant images; those map to 0 and emit no spikes.
    scaled = (flat - low) / (high - low).clamp_min(1e-12)

    window = max(1, round(time * num_steps))
    fires = scaled > threshold
    # Linear latency: intensity 1 -> step 0, intensity near threshold -> last step of the window.
    spike_step = ((1 - scaled) / (1 - threshold) * (window - 1)).round().long().clamp_(0, window - 1)

    spikes = flat.new_zeros((num_steps,) + tuple(flat.shape))
    # Exactly one write per pixel; silent pixels write a 0, so nothing is overwritten.
    spikes.scatter_(0, spike_step.unsqueeze(0), fires.to(spikes.dtype).unsqueeze(0))
    return spikes.reshape((num_steps,) + tuple(data.shape))

3. Define the model based on Lapicque's RC neuron

In [None]:
class LapicqueRCModel(nn.Module):
    """Two-layer spiking MLP whose neurons follow Lapicque's RC membrane model.

    Each linear layer (``fc1``, ``fc2``, same names and default sizes as
    before) drives a layer of ``snn.Lapicque`` neurons. The network is
    simulated one time step at a time, and the output is the number of spikes
    each output neuron emitted.
    """

    def __init__(self, num_inputs=784, num_hidden=256, num_outputs=10,
                 beta=0.9, spike_grad=None):
        """
        Args:
            num_inputs: flattened input size (784 for 28x28 MNIST).
            num_hidden: number of hidden Lapicque neurons.
            num_outputs: number of output neurons / classes.
            beta: membrane decay rate passed to ``snn.Lapicque``.
            spike_grad: surrogate-gradient callable. Defaults to snnTorch's
                fast sigmoid.
        """
        super(LapicqueRCModel, self).__init__()
        if spike_grad is None:
            # The fast sigmoid has heavier tails than the logistic sigmoid, so
            # neurons well below threshold still receive a usable gradient.
            spike_grad = surrogate.fast_sigmoid(slope=25)
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Lapicque(beta=beta, spike_grad=spike_grad)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Lapicque(beta=beta, spike_grad=spike_grad)

    def forward(self, x):
        """Simulate the network and return per-class output spike counts.

        Args:
            x: time-major input of shape (T, B, ...) whose trailing dims
               flatten to ``num_inputs``, e.g. (T, B, 1, 28, 28) or (T, B, 784).
               A static batch of shape (B, 784) or (B, 1, 28, 28) is treated
               as a single time step. NOTE: a 3-D (B, 28, 28) batch would be
               misread as time-major; pass (B, 1, 28, 28) instead.

        Returns:
            Tensor of shape (B, num_outputs) holding the output spike counts.
            It can be fed to ``nn.CrossEntropyLoss`` as logits (rate decoding).
        """
        if x.dim() in (2, 4):
            x = x.unsqueeze(0)
        num_steps, batch_size = x.shape[0], x.shape[1]
        x = x.reshape(num_steps, batch_size, -1)

        # Fresh membrane state for every forward pass, so batches do not leak into each other.
        mem1 = self.lif1.init_lapicque()
        mem2 = self.lif2.init_lapicque()
        spike_count = torch.zeros(batch_size, self.fc2.out_features, device=x.device)
        for step_input in x:
            spk1, mem1 = self.lif1(self.fc1(step_input), mem1)
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)
            spike_count = spike_count + spk2
        return spike_count

4. Load the MNIST dataset

In [None]:
BATCH_SIZE = 64

# Standard MNIST mean/std normalisation, shared by both splits.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True, transform=mnist_transform),
    batch_size=BATCH_SIZE, shuffle=True)

# download=True so this cell does not rely on the train split having fetched the files;
# evaluation does not need shuffling, so keep the order deterministic.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, download=True, transform=mnist_transform),
    batch_size=BATCH_SIZE, shuffle=False)

5. Create model instance

In [None]:
# Build the network, then move its parameters to the selected device.
model = LapicqueRCModel()
model = model.to(device)

6. Define the spike generation function

In [None]:
# snnTorch has no `spikegen.probability`; `spikegen.latency` is its
# time-to-first-spike generator, i.e. the rank-order style of coding used here.
spike_fn = spikegen.latency

7. Define the surrogate gradient function

In [None]:
# `surrogate.Sigmoid` is the underlying autograd Function class; the factory
# `surrogate.sigmoid(...)` returns the callable expected as a neuron's `spike_grad`.
surrogate_fn = surrogate.sigmoid(slope=25)

8. Define the optimizer and loss function

In [None]:
LEARNING_RATE = 0.01
MOMENTUM = 0.9

# SGD with momentum over every trainable parameter; cross-entropy over class scores.
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
loss_fn = nn.CrossEntropyLoss()

9. Define the SNN optimizer

In [None]:
class SNNTrainingStep:
    """One-call supervised update: replaces the non-existent ``snn.Adam`` wrapper.

    ``step(outputs, targets)`` clears stale gradients, computes
    ``loss_fn(outputs, targets)``, backpropagates, and lets the wrapped
    optimizer update the parameters.
    """

    def __init__(self, optimizer, loss_fn):
        self.optimizer = optimizer
        self.loss_fn = loss_fn

    def step(self, outputs, targets):
        """Apply one optimisation step and return the loss as a Python float."""
        self.optimizer.zero_grad()
        loss = self.loss_fn(outputs, targets)
        loss.backward()
        self.optimizer.step()
        return loss.item()


snn_optimizer = SNNTrainingStep(optimizer, loss_fn)

10. Train the SNN

In [None]:
num_epochs = 10
log_interval = 100  # batches between progress lines
start_time = time.time()
for epoch in range(num_epochs):
    # Training
    model.train()
    # `images` (not `data`) so the loop does not shadow the `snntorch.data` import.
    for batch_idx, (images, target) in enumerate(train_loader):
        images, target = images.to(device), target.to(device)

        # Encode the images as rank-order input spike trains
        input_spikes = encoder(images, time=0.5)

        # Run the SNN
        output_spikes = model(input_spikes)

        # Backpropagate the loss and update the model parameters
        snn_optimizer.step(output_spikes, target)

        if batch_idx % log_interval == 0:
            print('Epoch: {} [{}/{} ({:.0f}%)]'.format(
                epoch, batch_idx * len(images), len(train_loader.dataset),
                100. * batch_idx / len(train_loader)))


11. Testing

In [None]:
model.eval()
test_loss = 0.0
correct = 0
with torch.no_grad():
    for images, target in test_loader:
        images, target = images.to(device), target.to(device)
        input_spikes = encoder(images, time=0.5)
        output_spikes = model(input_spikes)
        # loss_fn averages over the batch; re-weight by batch size so the
        # dataset-level mean is exact even when the last batch is smaller.
        test_loss += loss_fn(output_spikes, target).item() * target.size(0)
        predicted = output_spikes.argmax(dim=1)
        correct += predicted.eq(target).sum().item()
test_loss /= len(test_loader.dataset)
print('Test loss: {:.4f}'.format(test_loss))

12. Evaluation metrics

In [None]:
# Percentage of test samples whose predicted class matched the label.
num_test_samples = len(test_loader.dataset)
accuracy = 100. * correct / num_test_samples
print(f'Accuracy: {accuracy:.2f}%')

# Wall-clock time since the training cell started the timer.
end_time = time.time()
execution_time = end_time - start_time
print(f'Execution Time: {execution_time:.2f} seconds')