In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import norse.torch as norse
import numpy as np

# Define constants
NUM_EPOCHS = 100
BATCH_SIZE = 32
TIME_STEPS = 100  # Number of time steps for spike simulation
ENCODING_SCALING = 10  # Scale to increase firing rates
SPIKE_THRESHOLD = 0.1  # Threshold for converting spikes to output value

def log_scale(values):
    return torch.sign(values) * torch.log1p(torch.abs(values))

# Generate dataset for the normal curve
def generate_dataset(num_samples=1000):
    mu_values = np.random.uniform(-1, 1, num_samples)
    sigma_values = np.random.uniform(0.1, 2, num_samples)  # Avoid zero
    x_values = np.random.uniform(-3, 3, num_samples)
    targets = (1 / (sigma_values * np.sqrt(2 * np.pi))) * np.exp(
        -((x_values - mu_values) ** 2) / (2 * sigma_values ** 2)
    )
    inputs = np.stack([mu_values, sigma_values, x_values], axis=1)
    return torch.tensor(inputs, dtype=torch.float32), torch.tensor(targets, dtype=torch.float32)

# Encode input values as spike trains
def encode_as_spike_trains(values):
    # Scale values and encode as rate-coded spikes
    values = log_scale(values)

    spikes = torch.zeros((values.size(0), values.size(1), TIME_STEPS))
    for t in range(TIME_STEPS):
        spikes[:, :, t] = torch.bernoulli(values * ENCODING_SCALING)
    return spikes

# Decode output spike trains to a scalar value
def decode_spike_train(spikes):
    # Count spikes and normalize
    spike_counts = torch.sum(spikes, dim=-1)  # Sum over time steps
    return spike_counts / TIME_STEPS  # Normalize by total time steps

# SNN definition
class SpikingNormalCurve(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_layer = norse.LIFRecurrentCell(input_size=3, hidden_size=16)
        self.hidden_layer = norse.LIFRecurrentCell(input_size=16, hidden_size=16)
        self.output_layer = norse.LIFRecurrentCell(input_size=16, hidden_size=1)

    def forward(self, spike_trains):
        batch_size, _, time_steps = spike_trains.size()
        state_input = state_hidden = state_output = None
        
        output_spikes = []
        for t in range(time_steps):
            input_spike = spike_trains[:, :, t]
            _, state_input = self.input_layer(input_spike, state_input)
            _, state_hidden = self.hidden_layer(state_input, state_hidden)
            output, state_output = self.output_layer(state_hidden, state_output)
            output_spikes.append(output)

        return torch.stack(output_spikes, dim=-1)  # Stack spikes over time

# Loss function (MSE on decoded values)
def loss_function(predicted, target):
    return torch.mean((predicted - target) ** 2)

# Training procedure
def train_snn():
    # Generate dataset
    inputs, targets = generate_dataset()
    dataset = torch.utils.data.TensorDataset(inputs, targets)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Initialize model, optimizer, and loss function
    model = SpikingNormalCurve()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    for epoch in range(NUM_EPOCHS):
        epoch_loss = 0.0
        for batch_inputs, batch_targets in dataloader:
            # Encode inputs as spike trains
            spike_trains = encode_as_spike_trains(batch_inputs)

            # Forward pass
            output_spike_trains = model(spike_trains)
            decoded_outputs = decode_spike_train(output_spike_trains)

            # Compute loss
            loss = loss_function(decoded_outputs, batch_targets)
            epoch_loss += loss.item()

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}/{NUM_EPOCHS}, Loss: {epoch_loss / len(dataloader)}")

    return model

# Main execution
if __name__ == "__main__":
    trained_model = train_snn()


RuntimeError: Expected p_in >= 0 && p_in <= 1 to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)