In [1]:
!pip install torch --quiet
!pip install sinabs==0.3.2 --quiet
!pip install scikit-learn --quiet

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch.nn as nn
import torch.nn.functional as F
from sinabs.from_torch import from_model

In [None]:
# iris = load_iris()
# X = iris.data  # shape (150, 4)
# y = iris.target  # shape (150,)

# # We'll do a train/test split
# X_train, X_test, y_train, y_test = train_test_split(
#     X, y, test_size=0.2, random_state=42, stratify=y
# )

# # Standard scaling for better convergence
# scaler = StandardScaler()
# X_train = scaler.fit_transform(X_train)
# X_test = scaler.transform(X_test)

iris = load_iris()
X = iris.data  # shape (150, 4)
y = iris.target  # shape (150,)

# We'll do a train/test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Instead of standard scaling, let's do min-max scaling directly:
min_vals = X_train.min(axis=0)
max_vals = X_train.max(axis=0)

X_train = (X_train - min_vals) / (max_vals - min_vals)
X_test = (X_test - min_vals) / (max_vals - min_vals)


In [4]:
# class IrisDataset(Dataset):
#     def __init__(self, X, y, is_spiking=False, time_window=50):
#         """
#         :param X: Input features, shape (N, 4)
#         :param y: Targets, shape (N,)
#         :param is_spiking: Whether to convert inputs into spike trains.
#         :param time_window: Number of time steps if is_spiking=True.
#         """
#         self.X = torch.tensor(X, dtype=torch.float32)
#         self.y = torch.tensor(y, dtype=torch.long)
#         self.is_spiking = is_spiking
#         self.time_window = time_window

#     def __len__(self):
#         return len(self.X)

#     def __getitem__(self, index):
#         features = self.X[index]  # shape (4,)
#         target = self.y[index]

#         # If spiking, convert features to spike trains
#         if self.is_spiking:
#             # features in range: after scaling, they could be positive/negative.
#             # We'll ensure all values are positive by normalizing to [0,1] for Poisson sampling.
#             # Let's map features to [0, 1] using min-max normalization across the whole dataset
#             # For simplicity, let's assume features ~ N(0,1) after StandardScaler.
#             # We'll clip them between -3 and +3 and then map to [0,1].
#             features_clipped = torch.clamp(features, -3, 3)
#             features_norm = (features_clipped + 3) / 6.0  # now in [0,1]

#             # Poisson sampling: For each time step, draw a random number and compare:
#             # output shape: (time_window, 4)
#             spike_train = (torch.rand(self.time_window, *features_norm.shape) < features_norm).float()
#             return spike_train, target

#         return features, target

class IrisDataset(Dataset):
    def __init__(self, X, y, is_spiking=False, time_window=50):
        """
        :param X: Input features, shape (N, 4), already scaled to [0,1].
        :param y: Targets, shape (N,)
        :param is_spiking: Whether to convert inputs into spike trains.
        :param time_window: Number of time steps if is_spiking=True.
        """
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
        self.is_spiking = is_spiking
        self.time_window = time_window

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        features = self.X[index]  # shape (4,)
        target = self.y[index]

        # If spiking, use features as probabilities for Poisson-like sampling
        if self.is_spiking:
            # features are already in [0,1], so we can use them directly
            # Poisson sampling: For each time step, draw a random number and compare:
            spike_train = (torch.rand(self.time_window, features.shape[0]) < features).float()
            return spike_train, target

        return features, target

In [5]:
train_dataset = IrisDataset(X_train, y_train, is_spiking=False)
test_dataset = IrisDataset(X_test, y_test, is_spiking=False)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [6]:
ann = nn.Sequential(
    nn.Linear(4, 10, bias=False),
    nn.ReLU(),
    nn.Linear(10, 3, bias=False)
)

device = "cuda" if torch.cuda.is_available() else "cpu"
ann = ann.to(device)

In [7]:
optim = torch.optim.Adam(ann.parameters(), lr=1e-3)
n_epochs = 200

ann.train()
for epoch in range(n_epochs):
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optim.zero_grad()
        output = ann(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optim.step()


In [8]:
ann.eval()
correct_predictions = []
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = ann(data)
        pred = output.argmax(dim=1, keepdim=True)
        correct = pred.eq(target.view_as(pred))
        correct_predictions.append(correct)
correct_predictions = torch.cat(correct_predictions)
print(f"ANN accuracy: {correct_predictions.float().mean().item()*100:.2f}%")

ANN accuracy: 86.67%


In [9]:
import sinabs.layers as sl

In [None]:
# 5. Convert the trained ANN to an SNN using sinabs
# For SNNs, we need to know the input shape. For Iris, the input shape is (4,) but sinabs expects NCHW or a shape.
# We can treat the input as (1,4) to mimic a shape. sinabs expects a leading channel dimension for conv layers,
# but we have a linear model only. We'll specify input_shape as (4,) or (1,4) depending on sinabs version.

input_shape = (4, )  # just a single-dimensional input
test_batch_size = 1
sinabs_model = from_model(ann, input_shape=input_shape, add_spiking_output=True, synops=False, batch_size=test_batch_size, min_v_mem=0, spike_threshold=1, membrane_subtract=1.5)
# sinabs_model = from_model(ann, add_spiking_output=True, synops=False, batch_size=test_batch_size)
print(sinabs_model.spiking_model)


# Now, let's test the SNN version. We need spiking inputs:
spiking_test_dataset = IrisDataset(X_test, y_test, is_spiking=True, time_window=64)
spiking_test_loader = DataLoader(spiking_test_dataset, batch_size=test_batch_size, shuffle=False)

# Evaluate the SNN with spiking input
sinabs_model.spiking_model.eval()

outputs_vec = []
targets_vec = []

correct_predictions = []
with torch.no_grad():
    for data, target in spiking_test_loader:
        # data shape is now (time_window, batch_size, 4)
        # sinabs expects input in format (time, batch, features)
        data, target = data.to(device), target.to(device)
        print(data.shape)
        data = sl.FlattenTime()(data)
        with torch.no_grad():
          output = sinabs_model(data)
          outputs_vec.append(output)
          output = output.unflatten(
              0, (test_batch_size, output.shape[0] // test_batch_size)
          )
        # output is the integrated output over time. Let's take the final spiking layer and decode:
        # pred = output.sum(1).argmax(dim=1, keepdim=True)
        integrated_output = output.sum(dim=1)  # [1, 3]

        outputs_vec.append(output)
        targets_vec.append(target)

        # Argmax over classes
        pred = integrated_output.argmax(dim=1) # [1]

        # Compare with target
        correct = pred.eq(target)
        correct_predictions.append(correct)

correct_predictions = torch.cat(correct_predictions)
print(f"SNN accuracy (with spiking input): {correct_predictions.float().mean().item()*100:.2f}%")

Sequential(
  (0): Linear(in_features=4, out_features=10, bias=False)
  (1): IAFSqueeze(spike_threshold=1, min_v_mem=0)
  (2): Linear(in_features=10, out_features=3, bias=False)
  (Spiking output): IAFSqueeze(spike_threshold=1, min_v_mem=0)
)
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
torch.Size([1, 64, 4])
SNN accuracy (with spiking input): 56.67%


In [11]:
sinabs_model.spiking_model

Sequential(
  (0): Linear(in_features=4, out_features=10, bias=False)
  (1): IAFSqueeze(spike_threshold=1, min_v_mem=0)
  (2): Linear(in_features=10, out_features=3, bias=False)
  (Spiking output): IAFSqueeze(spike_threshold=1, min_v_mem=0)
)

In [12]:
outputs_vec

[tensor([[16., -0., -0.],
         [ 0.,  0.,  0.],
         [17., -0., -0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [16., -0., -0.],
         [ 0.,  0.,  0.],
         [17., -0., -0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [16., -0., -0.],
         [ 0.,  0.,  0.],
         [17., -0., -0.],
         [17., -0., -0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [21., -0., -0.],
         [ 0.,  0.,  0.],
         [16., -0., -0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [-0., -0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [16., -0., -0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [17., -0., -0.],
         [16., -0., -0.],
         [17., -0., -0.],
         [-0.,  0.,  0.],
         [16., -0., -0.],
         [17., -0., -0.],
         [ 0

In [13]:
targets_vec

[tensor([0], device='cuda:0'),
 tensor([2], device='cuda:0'),
 tensor([1], device='cuda:0'),
 tensor([1], device='cuda:0'),
 tensor([0], device='cuda:0'),
 tensor([1], device='cuda:0'),
 tensor([0], device='cuda:0'),
 tensor([0], device='cuda:0'),
 tensor([2], device='cuda:0'),
 tensor([1], device='cuda:0'),
 tensor([2], device='cuda:0'),
 tensor([2], device='cuda:0'),
 tensor([2], device='cuda:0'),
 tensor([1], device='cuda:0'),
 tensor([0], device='cuda:0'),
 tensor([0], device='cuda:0'),
 tensor([0], device='cuda:0'),
 tensor([1], device='cuda:0'),
 tensor([1], device='cuda:0'),
 tensor([2], device='cuda:0'),
 tensor([0], device='cuda:0'),
 tensor([2], device='cuda:0'),
 tensor([1], device='cuda:0'),
 tensor([2], device='cuda:0'),
 tensor([2], device='cuda:0'),
 tensor([1], device='cuda:0'),
 tensor([1], device='cuda:0'),
 tensor([0], device='cuda:0'),
 tensor([2], device='cuda:0'),
 tensor([0], device='cuda:0')]