<a href="https://colab.research.google.com/github/k1151msarandega/Lapicque-s-RC/blob/main/Lapicque's_RC_direct_encoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Architecture:** *Lapicque's RC*


**Encoding Scheme:** *Direct encoding*

1. Import necessary libraries

In [None]:
import time

import torch
import torch.nn as nn
import torch.optim as optim
# NOTE(review): `simulators` is not used anywhere in this notebook; confirm the
# installed snntorch version provides it or drop the import.
from snntorch import simulators
from torchvision import datasets, transforms

2. Set device

In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

3. Define the Lapicque RC neuron model

In [None]:
class LapicqueRCNeuron(nn.Module):
    """Element-wise spiking neuron with Lapicque-style RC membrane dynamics.

    Every call advances one time step:

        U[t] = beta * U[t-1] + I[t]
        S[t] = 1 if U[t] >= threshold else 0
        U[t] <- 0 where S[t] == 1            (reset to zero)

    ``beta = 1`` (the default) is a non-leaky integrate-and-fire neuron.
    ``0 < beta < 1`` adds the RC leak; forward-Euler discretisation of
    Lapicque's circuit gives ``beta = 1 - dt / (R * C)``, with the input
    scaling ``dt / C`` absorbed into the input current.

    The forward pass emits exact 0/1 spikes. The backward pass uses the
    derivative of ``sigmoid(slope * (U - threshold))`` as a surrogate, because
    the true step function has zero gradient almost everywhere. Without the
    surrogate, layers feeding this neuron would never learn.

    The membrane potential is kept in ``membrane_potential`` between calls, so
    repeated calls simulate successive time steps. It is (re)initialised to
    zeros when it is ``None`` or does not match the input's shape or device.
    Callers may also assign a tensor to it directly, or call ``reset()``.

    Args:
        threshold: Firing threshold.
        beta: Membrane decay factor per step (1.0 = no leak).
        slope: Steepness of the sigmoid surrogate gradient.
        detach_state: If True (default), the stored potential is detached from
            the autograd graph. Callers that never reset can then still call
            ``backward()`` once per batch. Pass False for full
            backpropagation-through-time when resetting before each sequence.
    """

    def __init__(self, threshold=1.0, beta=1.0, slope=25.0, detach_state=True):
        super(LapicqueRCNeuron, self).__init__()
        # Registered as a buffer so ``module.to(device)`` moves it with the model.
        self.register_buffer("threshold", torch.tensor(float(threshold)))
        self.beta = beta
        self.slope = slope
        self.detach_state = detach_state
        self.membrane_potential = None

    def reset(self):
        """Forget the stored membrane potential so the next call starts at rest."""
        self.membrane_potential = None

    def forward(self, x):
        """Advance the neuron by one time step with input current ``x``.

        Args:
            x: Input current tensor of any shape; dynamics are element-wise.

        Returns:
            Tensor with the shape and dtype of ``x`` holding 0/1 spikes.
        """
        mem = self.membrane_potential
        if mem is None or mem.shape != x.shape or mem.device != x.device:
            # Size the state to the current batch (the last batch may be smaller).
            mem = torch.zeros_like(x)
        mem = self.beta * mem + x

        hard_spike = (mem >= self.threshold).to(mem.dtype)
        soft_spike = torch.sigmoid(self.slope * (mem - self.threshold))
        # Value equals hard_spike exactly; the gradient is that of soft_spike.
        spike = hard_spike + (soft_spike - soft_spike.detach())

        # Reset-to-zero after a spike; no gradient is routed through the reset.
        mem = mem * (1.0 - hard_spike)
        self.membrane_potential = mem.detach() if self.detach_state else mem
        return spike

4. Load the MNIST dataset

In [None]:
# Convert images to tensors, then rescale pixel values from [0, 1] to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST is stored under the 'data' directory, downloading it if missing.
train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='data', train=False, download=True, transform=transform)

# Shuffle only the training set; keep the evaluation order fixed.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

5. Define the SNN model

In [None]:
class SNNLapicqueRC(nn.Module):
    """Fully connected spiking classifier: 784 inputs -> 512 spiking units -> 10 logits.

    Direct encoding: images are not converted into spike trains. Instead, the
    analog output of ``fc1`` is injected as a constant input current into the
    spiking layer at each of ``num_steps`` time steps. ``fc2`` then reads out
    the mean firing rate of the 512 neurons.

    Args:
        num_steps: Simulated time steps per forward pass. The default of 1
            reproduces the original single-step model.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.
    """

    def __init__(self, num_steps=1):
        super(SNNLapicqueRC, self).__init__()
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")
        self.num_steps = num_steps
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 10)
        self.neuron = LapicqueRCNeuron()

    def forward(self, x):
        """Return class logits of shape (batch, 10).

        ``x`` is flattened to (batch, 784). Any layout with 784 values per
        sample works, e.g. batches of 1x28x28 images.
        """
        x = x.reshape(-1, 784)
        current = self.fc1(x)
        # Each batch holds new samples: start from a resting (zero) membrane
        # sized to this batch instead of carrying state over from the last one.
        self.neuron.membrane_potential = torch.zeros_like(current)
        spike_count = torch.zeros_like(current)
        for _ in range(self.num_steps):
            spike_count = spike_count + self.neuron(current)
        return self.fc2(spike_count / self.num_steps)

6. Set up the SNN model

In [None]:
# Build the network, then move its parameters to the selected device
# (nn.Module.to works in place and returns the same module).
model = SNNLapicqueRC()
model.to(device)

7. Define the loss function, optimiser, and training routine

In [None]:
# Cross-entropy over the 10 output logits; Adam over every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Train the SNN model
def train(model, train_loader, optimizer, criterion):
    """Run one training epoch over ``train_loader`` and return the mean loss.

    Each batch is moved to the device holding the model's parameters, so the
    function does not depend on a global device setting.

    Args:
        model: Network to train; switched to training mode.
        train_loader: Iterable of ``(data, target)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss function called as ``criterion(output, target)``.

    Returns:
        Mean of the per-batch losses, weighted by batch size. This equals the
        per-sample mean when ``criterion`` averages over the batch. Returns NaN
        when the loader yields no samples.
    """
    model.train()
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    loss_sum = 0.0
    num_samples = 0
    for data, target in train_loader:
        data, target = data.to(model_device), target.to(model_device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_size = target.size(0)
        # Accumulate as a tensor to avoid a host/device sync on every batch.
        loss_sum = loss_sum + loss.detach() * batch_size
        num_samples += batch_size

    if num_samples == 0:
        return float("nan")
    return float(loss_sum) / num_samples

8. Define the evaluation (test accuracy) routine

In [None]:
def test(model, test_loader):
    """Return the classification accuracy of ``model`` on ``test_loader``, in percent.

    The model is switched to evaluation mode and run without gradient
    tracking. Batches are moved to the device holding the model's parameters.
    The predicted class is the argmax of the model output along dim 1.

    Args:
        model: Network returning per-class scores of shape (batch, classes).
        test_loader: Iterable of ``(data, target)`` batches.

    Returns:
        Accuracy in percent (0-100), or NaN when the loader yields no samples.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(model_device), target.to(model_device)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    if total == 0:
        return float("nan")
    return 100.0 * correct / total

9. Main training loop

In [None]:
num_epochs = 10

# One pass over the training set per epoch, then report test accuracy.
for epoch_idx in range(1, num_epochs + 1):
    train(model, train_loader, optimizer, criterion)
    epoch_accuracy = test(model, test_loader)
    print(f"Epoch {epoch_idx}/{num_epochs} - Accuracy: {epoch_accuracy:.2f}%")

10. Measure inference time

In [None]:
start_time = torch.cuda.Event(enable_timing=True)
end_time = torch.cuda.Event(enable_timing=True)
start_time.record()


11. Run the inference on test dataset

In [None]:
test_accuracy = test(model, test_loader)

end_time.record()
torch.cuda.synchronize()
elapsed_time = start_time.elapsed_time(end_time) / 1000  # Convert to seconds

print(f"Test Accuracy: {test_accuracy:.2f}%")
print(f"Computational Time: {elapsed_time:.4f} seconds")