<a href="https://colab.research.google.com/github/k1151msarandega/1st-order/blob/main/1st_order_rate_encoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Architecture:** *1st-order model*


**Encoding Scheme:** *Rate encoding*

In [None]:
pip install snntorch

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import snntorch as snn
import time

1. Define the SNN model

In [None]:
class SNN(torch.nn.Module):
    """Single fully connected layer (784 -> 10) followed by a sigmoid.

    The layer expects 784 features per sample, i.e. a flattened 28x28 image.
    Image-shaped batches such as ``(batch, 1, 28, 28)`` are flattened
    automatically; inputs whose last dimension is already 784 are passed
    through unchanged, so existing callers keep their output shapes.

    Note: the network contains no spiking neurons; it is a plain linear
    layer with a sigmoid on top.
    """

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(784, 10)

    def forward(self, x):
        """Return the sigmoid of the linear layer's output for ``x``.

        Args:
            x: Tensor whose last dimension is 784, or an image-shaped tensor
                with more than two dimensions and 784 elements per sample.

        Returns:
            Tensor with 10 values between 0 and 1 per sample.
        """
        # A (batch, 1, 28, 28) batch has a last dimension of 28, which
        # Linear(784, 10) rejects; collapse everything after the batch axis.
        if x.dim() > 2 and x.shape[-1] != self.fc.in_features:
            x = torch.flatten(x, start_dim=1)
        return torch.sigmoid(self.fc(x))


2. Define the rate encoding function

In [None]:
def rate_encoding(x):
    """Return ``x`` multiplied by 255.

    This is a plain scaling step: no spike train is generated and no time
    dimension is added to ``x``.
    """
    scale = 255
    return x * scale

3. Load the MNIST dataset

In [None]:
# Per-image preprocessing: convert to a tensor, then apply rate_encoding.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Lambda(rate_encoding)]
)

# Both MNIST splits use the same root directory and transform; only the
# `train` flag differs. The training split is built first, then the test split.
train_dataset, test_dataset = (
    datasets.MNIST('./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)


4. Create data loaders

In [None]:
# Mini-batch size shared by both loaders.
batch_size = 64

# Training batches are shuffled; the test split is read in its stored order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)


5. Instantiate the SNN model and optimiser

In [None]:
# Build the network, then hand its parameters to an Adam optimiser.
snn_model = SNN()
snn_optimizer = torch.optim.Adam(params=snn_model.parameters(), lr=1e-3)

6. Define the SNN simulator

In [None]:
class _Simulator:
    """Local helper providing the ``reset``/``input``/``loss`` calls used below.

    snnTorch does not provide a ``Simulator`` class, so ``snn.Simulator()``
    raises ``AttributeError``. This stand-in implements the three methods the
    training and evaluation loops call, using only ``torch``.
    It does not generate spike trains.
    """

    def reset(self):
        """Do nothing; this helper holds no per-batch state to clear."""
        return None

    def input(self, data):
        """Flatten every sample of ``data`` into a vector.

        A ``(batch, 1, 28, 28)`` image batch becomes ``(batch, 784)``;
        a batch that is already two-dimensional is returned with the same shape.
        """
        return torch.flatten(data, start_dim=1)

    def loss(self, outputs, targets):
        """Return the mean cross-entropy of ``outputs`` against integer class ``targets``."""
        return torch.nn.functional.cross_entropy(outputs, targets)


sim = _Simulator()

7. Train the SNN model

In [None]:
# Number of full passes over train_loader.
num_epochs = 5

# Wall-clock start; the elapsed training time is taken after the outer loop.
start_time = time.time()

for epoch in range(num_epochs):
    # Switch the model to training mode at the start of every epoch.
    snn_model.train()
    for batch_idx, (data, targets) in enumerate(train_loader):
        # `sim` is a helper object created in an earlier cell; it is reset
        # once per batch before any other call on it.
        sim.reset()

        # Let `sim` turn the raw batch into the model's input.
        inputs = sim.input(data)

        # Forward pass through the SNN model
        outputs = snn_model(inputs)

        # Loss comes from `sim`; gradients are cleared before backward() so
        # they do not accumulate across batches, then the optimiser steps.
        loss = sim.loss(outputs, targets)
        snn_optimizer.zero_grad()
        loss.backward()
        snn_optimizer.step()

        # Report the current loss every 100 batches (including batch 0).
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch+1}/{num_epochs}, Step {batch_idx}/{len(train_loader)}, Loss: {loss.item()}')

# Total wall-clock seconds spent in the loop above, printing included.
training_time = time.time() - start_time


8. Test the SNN model

In [None]:
# Evaluation mode for the pass over the test split.
snn_model.eval()
correct, total = 0, 0

# Wall-clock start of the evaluation pass.
start_time = time.time()

with torch.no_grad():  # no gradients are needed while scoring
    for data, targets in test_loader:
        # Same per-batch sequence as in training: reset, encode, forward.
        sim.reset()
        inputs = sim.input(data)
        outputs = snn_model(inputs)

        # The predicted class is the index of the largest output per sample.
        _, predicted = outputs.data.max(dim=1)

        # Running totals used for the accuracy figure.
        total += targets.shape[0]
        correct += int(predicted.eq(targets).sum())


9. Evaluation metrics

In [None]:
# Elapsed seconds since start_time was set at the top of the evaluation cell.
# NOTE: because it is measured here, the figure also includes any pause
# between running the evaluation cell and this one.
testing_time = time.time() - start_time

# Percentage of test samples whose predicted label matched the target.
accuracy = (100 * correct) / total
print(
    f'Accuracy: {accuracy}%',
    f'Training Time: {training_time} seconds',
    f'Testing Time: {testing_time} seconds',
    sep='\n',
)