<a href="https://colab.research.google.com/github/k1151msarandega/Lapicque-s-RC/blob/main/Lapicque's_RC_temporal_encoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Architecture:** *Lapicque's RC*


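**Encoding Scheme:** *Temporal encoding* (latency coding): each input pixel is converted into at most one spike whose timing carries its intensity, with brighter pixels firing earlier.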

In [None]:
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import snntorch as snn
from snntorch import spikegen

1. Set the random seed for reproducibility

In [None]:
SEED = 0  # seeds PyTorch's global RNG (weight initialisation, DataLoader shuffling)
torch.manual_seed(SEED)

2. Load the MNIST training and test sets

In [None]:
# Both MNIST splits are downloaded into DATA_ROOT on first use and converted
# to tensors by ToTensor (ToTensor is stateless, so one instance is shared).
DATA_ROOT = './data'
mnist_kwargs = dict(root=DATA_ROOT, transform=transforms.ToTensor(), download=True)
train_dataset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

3. Define the training and testing data loaders

In [None]:
BATCH_SIZE = 128
# Shuffle only the training split; evaluation order does not matter.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


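4. Define Lapicque's RC model: a linear input layer, a hidden layer of Lapicque (RC leaky integrate-and-fire) spiking neurons, and a linear output layer. The temporal encoding itself is applied to the images in the training and testing loops.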

In [None]:
class LapicqueRC(nn.Module):
    """Feed-forward spiking network: Linear -> Lapicque (RC) neurons -> Linear.

    The hidden layer is snnTorch's ``Lapicque`` neuron, a leaky
    integrate-and-fire model parameterised by an RC circuit. Here R = 1 and
    C = ``tau``, so the membrane time constant R*C equals ``tau`` in units of
    the simulation time step.

    Args:
        input_size: number of input features per time step (784 for flattened MNIST).
        hidden_size: number of spiking neurons in the hidden layer.
        output_size: number of output classes.
        tau: membrane time constant R*C, in simulation time steps.
        threshold: membrane potential at which a hidden neuron emits a spike.
    """

    def __init__(self, input_size, hidden_size, output_size, tau=10.0, threshold=1.0):
        super().__init__()

        # Input layer: turns each time step's input spikes into synaptic current.
        self.input = nn.Linear(input_size, hidden_size)

        # Hidden layer. snn.Lapicque's first positional parameter is `beta`, not a
        # layer size, and it has no `tau` keyword, so the time constant is passed
        # as the RC product instead (R=1, C=tau, one unit per time step).
        self.hidden = snn.Lapicque(R=1.0, C=tau, time_step=1.0, threshold=threshold)

        # Output layer: maps the hidden firing rates to class logits.
        self.output = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Simulate the network over every time step of ``x``.

        Args:
            x: spike tensor of shape (num_steps, batch, input_size), time-first,
               e.g. the output of ``snntorch.spikegen.latency``. A 2-D tensor of
               shape (batch, input_size) is treated as a single time step.

        Returns:
            Logits of shape (batch, output_size), computed by the output layer
            from the hidden layer's mean firing rate over all time steps.
        """
        if x.dim() == 2:
            x = x.unsqueeze(0)  # single time step

        # nn.Linear acts on the last dimension, so all steps are projected at once.
        currents = self.input(x)  # (num_steps, batch, hidden_size)

        # The membrane potential is state that must be carried across time steps;
        # every sample starts from rest.
        mem = torch.zeros_like(currents[0])
        spikes = []
        for cur_t in currents:
            spk, mem = self.hidden(cur_t, mem)
            spikes.append(spk)

        firing_rate = torch.stack(spikes).mean(dim=0)  # (batch, hidden_size)
        return self.output(firing_rate)

5. Create an instance of the Lapicque RC model

In [None]:
# 28x28 MNIST images flattened to 784 inputs, 256 hidden spiking neurons, 10 digit classes.
model = LapicqueRC(
    input_size=28 * 28,
    hidden_size=256,
    output_size=10,
)

6. Define the loss function and optimiser

In [None]:
LEARNING_RATE = 1e-3
criterion = nn.CrossEntropyLoss()  # takes raw logits and integer class targets
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

7. Training loop

In [None]:
NUM_EPOCHS = 10  # passes over the training set
NUM_STEPS = 25   # length of the latency-encoding window, in time steps

# Keep batches on whatever device the model lives on (CPU unless moved).
device = next(model.parameters()).device

start_time = time.time()
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()

        # Flatten each image to a 784-vector: (batch, 784).
        inputs = inputs.flatten(start_dim=1)

        # Temporal (latency) encoding: each pixel becomes at most one spike in a
        # NUM_STEPS window, brighter pixels firing earlier. normalize/linear
        # spread spike times over the whole window; clip drops spikes for pixels
        # below `threshold`. Result is time-first: (NUM_STEPS, batch, 784).
        spikes = spikegen.latency(inputs, num_steps=NUM_STEPS, tau=5, threshold=0.01,
                                  clip=True, normalize=True, linear=True)

        # Forward pass over all time steps -> (batch, 10) logits.
        outputs = model(spikes)
        loss = criterion(outputs, targets)

        # Backward pass and optimization.
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss averages over the batch; weight by batch size so the
        # epoch loss is a true per-sample mean even when the last batch is short.
        batch_size = targets.size(0)
        total_loss += loss.item() * batch_size
        predicted = outputs.argmax(dim=1)
        total += batch_size
        correct += (predicted == targets).sum().item()

    # Print training statistics
    print('Epoch:', epoch+1)
    print('Loss:', total_loss / total)
    print('Accuracy:', correct / total)

8. Print training time

In [None]:
# Wall-clock seconds since `start_time` was set in the training cell; this also
# counts any idle time between running that cell and this one.
end_time = time.time()
elapsed_time = end_time - start_time
print('Training Time:', elapsed_time, 'seconds')

9. Testing loop

In [None]:
NUM_STEPS = 25  # latency-encoding window; must match the value used for training

# Evaluate on the same device the model's parameters live on.
device = next(model.parameters()).device

start_time = time.time()
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, targets in test_loader:
        # Flatten each image to a 784-vector: (batch, 784).
        inputs = inputs.to(device).flatten(start_dim=1)
        targets = targets.to(device)

        # Same temporal (latency) encoding as in training:
        # (NUM_STEPS, batch, 784) spike tensor.
        spikes = spikegen.latency(inputs, num_steps=NUM_STEPS, tau=5, threshold=0.01,
                                  clip=True, normalize=True, linear=True)

        # Forward pass and predicted class per sample.
        outputs = model(spikes)
        predicted = outputs.argmax(dim=1)

        # Update statistics
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

10. Print testing statistics

In [None]:
test_accuracy = correct / total
print('Test Accuracy:', test_accuracy)

# Stop the clock started at the top of the testing cell.
end_time = time.time()
elapsed_time = end_time - start_time
print('Testing Time:', elapsed_time, 'seconds')