# Setting up

In [117]:
# All the needed imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import snntorch as snn
import snntorch.spikegen as spikegen
import spikingjelly.activation_based.encoding as encoding
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
import pathlib

**prepocess images**

In [118]:
# Finding the data in the dictionary
data_dir = pathlib.Path("..") / "brain-tumor-mri-dataset"
train_dir = data_dir / "Training"
test_dir = data_dir / "Testing"


transform_train = transforms.Compose([
    transforms.Grayscale(),  # Convert to grayscale
    transforms.Resize((64, 64)),  # Resize for SNN, 64 for saving memory
    transforms.RandomRotation(30),  # Rotation augmentation (randomly rotates the image up to ±30 degrees)
    transforms.RandomHorizontalFlip(), # Flips images horizontally (left to right)
    transforms.ToTensor(),  # Convert to Tensor
    transforms.Normalize((0.5,), (0.5,))  # Normalize, by reducing sharp spikes in activation
])

# Evaluate the model on real, unaltered images, so we dont do image augmantation.
transform_test = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load images to train and test sets
train_dataset = datasets.ImageFolder(root=train_dir, transform=transform_train)
test_dataset = datasets.ImageFolder(root=test_dir, transform=transform_test)

# Group multiple images into batches 
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Test correctness
print("Class Labels:", train_dataset.classes)


Class Labels: ['glioma', 'meningioma', 'notumor', 'pituitary']


**encoding images**

In [119]:
# Encoding transforms each pixel into a series of spikes:

# Ensures your code can run on any machine
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define timesteps (how many times an image generates spikes)
# Instead of feeding the network one static image,
# we feed 10 spike versions of it, to simulate passage of time
timesteps = 10

# rate-based encoding, it interprets the pixel intensity as a spike probability
def encode_batch(batch, timesteps):
    return spikegen.rate(batch, num_steps=timesteps).to(device)


# SNN MODEL

In [120]:
# This class uses spiking neurons, specifically Leaky Integrate-and-Fire(LIF),
# to classify brain tumor MRI images into one of four classes
class TumorSNN(nn.Module):
    def __init__(self):
        super().__init__()

        # 1st convolution layer:
        # Input: 1 grayscale channel
        # Output: 16 feature maps, using a 5x5 kernel
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=2)
        # Leaky Integrate-and-Fire neuron layer after conv1
        self.lif1 = snn.Leaky(beta=0.9)

        # 2nd convolution layer:
        # It takes 16 input channels, outputs 32 feature maps
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, padding=2)
        # Leaky Integrate-and-Fire neuron layer after conv2
        self.lif2 = snn.Leaky(beta=0.9)

        # Max pooling layer to reduce spatial dimensions by half
        self.pool = nn.MaxPool2d(2) 

        # Fully connected layer:
        # After 2 pools, input image size is 16x16
        # With 32 feature maps, total flattened size = 32 * 16 * 16 = 8192
        # Output: 4 classes (tumor types)
        self.fc = nn.Linear(32 * 16 * 16, 4)

    def forward(self, x):
        # Creates an initial state of that membrane potential, which is required to:
        # - Keep track of voltage inside the neuron
        # - Allow the model to "remember" signals from previous time steps
        # for 2 LIF layers
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # To sum up the outputs from each timestep, because:
        # x has shape [timesteps, batch, channels, height, width]
        spk_sum = 0
        for step in range(x.size(0)):  # timesteps
            cur = x[step] # Get image spike frame at current timestep

            # First conv. + pool layer
            cur = self.pool(self.conv1(cur))
            spk1, mem1 = self.lif1(cur, mem1) # LIF neuron returns (spike, membrane)

            # Second conv + pool layer
            cur = self.pool(self.conv2(spk1))
            spk2, mem2 = self.lif2(cur, mem2) # LIF neuron returns (spike, membrane)

            # Flatten the output for the fully connected layer, 
            # turning a 3D tensor (channels, height, width) into a 1D vector
            # so it can be passed into a fully connected (dense) layer ????
            flat = spk2.view(spk2.size(0), -1)

            # Final linear classification layer
            out = self.fc(flat)

            spk_sum += out

        # Return the average output across all timesteps
        return spk_sum / x.size(0)


**initialize the model**

In [121]:
model = TumorSNN().to(device)

**Training function**

In [None]:
# Define Loss function
# It compares predicted class scores to true class labels
criterion = nn.CrossEntropyLoss()

# Define the optimizer
# We use Adam for efficient optimizing that adjusts learning rates dynamically
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training function
def train_snn(model, loader, epochs=5):
    model.train() # set the model into training mode

    for epoch in range(epochs):
        total_loss = 0 # Keep track of total loss in this epoch

        # Loop over all batches in the dataset
        for imgs, labels in loader:

            # Move data to GPU (or CPU), depending on device
            imgs, labels = imgs.to(device), labels.to(device)

            # Encode static images into spike trains over time
            spikes = encode_batch(imgs, timesteps)

            # Reset gradients before backprop
            optimizer.zero_grad()

            # Get predictions from the SNN
            outputs = model(spikes)

            # Compute loss between predictions and true labels
            loss = criterion(outputs, labels)

            # Compute gradients, backpropagation
            loss.backward()
            
            # Optimizer step, update model weights
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss/len(loader):.4f}")

**Testing Function (Full Evaluation)**

In [None]:
def test_snn(model, loader):
    model.eval() # set the model to evaluation mode
    y_true, y_pred = [], []

    # Disables gradient calculations (saves memory and speeds up inference)
    with torch.no_grad():
        for imgs, labels in loader:
            # Move data to the same device as the model
            imgs, labels = imgs.to(device), labels.to(device)
            # Encode static images into spiking format
            spikes = encode_batch(imgs, timesteps)
            # Run the forward pass to get predictions
            outputs = model(spikes)
            # Get predicted class index (0-3) with highest score per sample
            preds = torch.argmax(outputs, dim=1)

            # Saves results for evaluation, move from GPU to CPU.
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())

    print(f"Accuracy: {accuracy_score(y_true, y_pred) * 100:.2f}%")
    print(f"Precision: {precision_score(y_true, y_pred, average='weighted') * 100:.2f}%")
    print(f"Recall: {recall_score(y_true, y_pred, average='weighted') * 100:.2f}%")
    print(f"F1 Score: {f1_score(y_true, y_pred, average='weighted') * 100:.2f}%")
    print("Confusion Matrix:")
    print(confusion_matrix(y_true, y_pred))


**Run the model**

In [124]:
train_snn(model, train_loader, epochs=5)
test_snn(model, test_loader)

Epoch 1/5 | Loss: 1.0248
Epoch 2/5 | Loss: 0.8582
Epoch 3/5 | Loss: 0.7977
Epoch 4/5 | Loss: 0.7667
Epoch 5/5 | Loss: 0.7471
Accuracy: 69.26%
Precision: 69.46%
Recall: 69.26%
F1 Score: 69.22%
Confusion Matrix:
[[216  59   1  24]
 [ 51 162  66  27]
 [ 38  40 301  26]
 [ 43  17  11 229]]
