<a href="https://colab.research.google.com/github/k1151msarandega/1st-order/blob/main/1st_order_temporal_encoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Architecture:** *1st-order model*

**Encoding Scheme:** *Temporal encoding*

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from snntorch import spikegen
from snntorch import surrogate
import time

1. Define the SNN model

In [None]:
class SNNModel(nn.Module):
    """Single-layer spiking classifier fed by latency-coded (temporal) input.

    Inputs are turned into spike trains with ``snntorch.spikegen.latency``,
    the temporal encoding this notebook studies: each feature fires at a time
    that is earlier for larger intensities, and features below ``threshold``
    emit no spike (``clip=True``). A linear layer maps every time step's spike
    frame to class scores, and those scores are summed over time to give one
    logit vector per sample.

    Args:
        num_steps: Number of time steps in each encoded spike train.
        tau: Time constant passed to the latency code (10.0 was the value the
            notebook originally chose, kept as the default).
        threshold: Intensity below which a feature produces no spike.
        in_features: Flattened input size (784 = 28 * 28 MNIST pixels).
        num_classes: Number of output classes.
    """

    def __init__(self, num_steps=25, tau=10.0, threshold=0.01,
                 in_features=784, num_classes=10):
        super().__init__()
        # snntorch.spikegen provides encoding *functions* (rate, latency, delta)
        # rather than nn.Modules, so the encoder settings are stored here and
        # applied in forward().
        self.num_steps = num_steps
        self.tau = tau
        self.threshold = threshold
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Encode ``x`` as spikes and return time-summed class scores.

        Args:
            x: Input intensities, expected in [0, 1]; in this notebook the
               flattened MNIST batch of shape (batch, in_features).

        Returns:
            Tensor of shape (batch, num_classes) usable as logits.
        """
        # (batch, in_features) -> (num_steps, batch, in_features): spikegen
        # returns time-first spike frames when num_steps is given.
        spikes = spikegen.latency(
            x,
            num_steps=self.num_steps,
            tau=self.tau,
            threshold=self.threshold,
            clip=True,
        )
        # nn.Linear acts on the last dim -> (num_steps, batch, num_classes);
        # summing over the time axis integrates the per-step scores.
        return self.fc(spikes).sum(dim=0)

2. Set device

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

3. Define training parameters

In [None]:
# Seed torch's global RNG (all devices) so that weight initialisation in the
# model-creation cell and DataLoader shuffling are reproducible on
# Restart & Run All.
seed = 0
torch.manual_seed(seed)

# Training hyperparameters.
batch_size = 64
learning_rate = 0.001
num_epochs = 10

4. Load MNIST dataset

In [None]:
# MNIST train/test splits stored under ./data, each image converted to a tensor
# by ToTensor. Only the training split is created with download=True.
train_dataset = MNIST(
    root='./data',
    train=True,
    download=True,
    transform=ToTensor(),
)
test_dataset = MNIST(
    root='./data',
    train=False,
    transform=ToTensor(),
)

# Shuffle training batches each epoch; keep the test order fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

5. Create the SNN model

In [None]:
# Build the classifier, place its parameters on `device`, and show its summary.
model = SNNModel().to(device)
model

6. Define loss function and optimiser

In [None]:
# Adam over every trainable parameter of the model, at the configured step size.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
# Cross-entropy on the model's raw (unnormalised) class scores.
criterion = nn.CrossEntropyLoss()

7. Function to compute time difference

In [None]:
def get_time_diff(start_time):
    """Return the wall-clock seconds elapsed since ``start_time``.

    Args:
        start_time: A timestamp previously taken with ``time.time()``.

    Returns:
        float: ``time.time() - start_time``.
    """
    return time.time() - start_time

8. Training loop

In [None]:
# The test cell below calls model.eval(); switch back explicitly so that
# re-running this cell always trains in training mode.
model.train()

total_steps = len(train_loader)
start_time = time.time()
for epoch in range(num_epochs):
    # 1-based step counter so the progress message can use it directly.
    for step, (images, labels) in enumerate(train_loader, start=1):
        # Flatten each 28x28 image to a 784-vector and move the batch to `device`.
        images = images.view(-1, 28 * 28).to(device)
        labels = labels.to(device)

        # Forward pass and loss.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step}/{total_steps}], Loss: {loss.item()}")

training_time = get_time_diff(start_time)
print(f"Training Time: {training_time} seconds")

9. Test the model

In [None]:
# Count correct predictions on the held-out split; gradients are not needed.
# `start_time`, `correct` and `total` are read by the evaluation-metrics cell.
model.eval()
with torch.no_grad():
    correct, total = 0, 0
    start_time = time.time()
    for images, labels in test_loader:
        images = images.view(-1, 28 * 28).to(device)
        labels = labels.to(device)
        # Index of the highest class score for each sample.
        predicted = model(images).argmax(dim=1)
        total += labels.shape[0]
        correct += (predicted == labels).sum().item()

10. Evaluation metrics

In [None]:
# Summarise the run. `training_time` was recorded by the training cell, so it
# is reported as-is rather than recomputed from the *test* cell's `start_time`.
# `testing_time` is measured from the test cell's `start_time`, so it also
# includes any pause between running that cell and this one.
testing_time = get_time_diff(start_time)

assert total > 0, "test_loader yielded no samples; cannot compute accuracy"
accuracy = 100 * correct / total

print(f"Training Time: {training_time} seconds")
print(f"Test Accuracy: {accuracy}%")
print(f"Testing Time: {testing_time} seconds")