# JAX Spiking Neural Network

## 1. Imports

In [1]:
import os
import numpy as np
import jax
import jax.numpy as jnp
import pandas as pd
from jax import jit, grad, value_and_grad
import optax # JAX optimizers
import functools
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


# Module-level pseudo-random key for notebook experimentation (seed 0).
key = jax.random.PRNGKey(seed=0)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


## 2. Configuration

In [2]:
class Config:
    """Hyper-parameters for the 4-layer LIF network and its training loop."""

    def __init__(self):
        # ---- network topology ----
        self.input_size = 16    # features per time step
        self.hidden_size = 24   # neurons per hidden layer; must equal len(tau_1..tau_3)
        self.output_size = 4    # one readout neuron per class; must equal len(tau_4)

        # ---- optimisation ----
        self.learning_rate = 0.0001
        self.batch_size = 32
        self.num_epochs = 100

        # ---- weight initialisation: weight_i = N(0, 1) * weight_i_std + weight_i_mean ----
        self.weight_1_mean = 1.0
        self.weight_2_mean = 0.5
        self.weight_3_mean = 0.5
        self.weight_4_mean = 0.5
        self.weight_1_std = 0.8
        self.weight_2_std = 0.5
        self.weight_3_std = 0.5
        self.weight_4_std = 0.5

        # ---- LIF neuron parameters ----
        # Per-neuron membrane time constants, laid out as contiguous groups of
        # equal values (e.g. tau_3: three neurons each at 2, 4, ..., 256).
        self.neuron_params = {
            'tau_1': jnp.repeat(jnp.array([2, 4]), 12),
            'tau_2': jnp.repeat(jnp.array([2, 4, 8, 16]), 6),
            'tau_3': jnp.repeat(jnp.array([2, 4, 8, 16, 32, 64, 128, 256]), 3),
            'tau_4': jnp.array([2] * 4),
            'V_reset': 0.0,
            'V_th': 1.0,
        }

# Default hyper-parameters for the notebook cells below.
config = Config()
# Adamax optimiser used by the jitted `train` step.
optimizer = optax.adamax(config.learning_rate)

In [3]:
def initialize_weights(key, config):
    """Sample the four weight matrices of the network.

    Layer ``i`` is drawn as ``N(0, 1) * weight_i_std + weight_i_mean``, each
    from its own sub-key obtained by splitting ``key`` four ways.

    Args:
        key: JAX PRNG key.
        config: object exposing ``input_size``, ``hidden_size``,
            ``output_size`` and ``weight_{1..4}_{mean,std}``.

    Returns:
        dict mapping ``'weight_1'``..``'weight_4'`` to arrays of shape
        (input, hidden), (hidden, hidden), (hidden, hidden), (hidden, output).
    """
    n_in, n_hid, n_out = config.input_size, config.hidden_size, config.output_size
    layers = (
        ('weight_1', (n_in, n_hid), config.weight_1_mean, config.weight_1_std),
        ('weight_2', (n_hid, n_hid), config.weight_2_mean, config.weight_2_std),
        ('weight_3', (n_hid, n_hid), config.weight_3_mean, config.weight_3_std),
        ('weight_4', (n_hid, n_out), config.weight_4_mean, config.weight_4_std),
    )
    subkeys = jax.random.split(key, num=len(layers))
    return {
        name: jax.random.normal(subkeys[idx], shape) * std + mean
        for idx, (name, shape, mean, std) in enumerate(layers)
    }

## 3. Data preparation

 - load dataset
 - preprocess (one-hot encode the labels and convert the arrays to JAX arrays)
 - split data into training, validation, test
 - create batches

### Load data from CSV file

In [None]:
def load_data_from_directory(data_dir, expected_shape=(1000, 16)):
    """Load every CSV sample in ``data_dir`` into one stacked array.

    Each ``.csv`` file is one sample. It must contain a ``'Label'`` column
    holding a single class name (one of CAR, STREET, HOME, CAFE) on every
    row; all other columns are treated as features. Non-CSV files are skipped.

    Args:
        data_dir: directory to scan.
        expected_shape: required (rows, feature_columns) shape of each file's
            feature matrix. Defaults to (1000, 16).

    Returns:
        (all_data, all_labels): ``all_data`` has shape
        (num_samples, *expected_shape); ``all_labels`` has shape
        (num_samples,) and holds integer class indices. Files are read in
        sorted filename order so the sample order, and therefore any seeded
        split computed downstream, is reproducible across machines.

    Raises:
        ValueError: if a file mixes labels, has an unknown label or an
            unexpected shape, or if the directory contains no CSV files.
    """
    labels = ['CAR', 'STREET', 'HOME', 'CAFE']
    label_to_index = {label: index for index, label in enumerate(labels)}
    expected_shape = tuple(expected_shape)
    data_list = []
    labels_list = []

    print(f"Loading data from directory: {data_dir}")
    # os.listdir order is arbitrary and filesystem-dependent; sort it so the
    # stacked sample order is deterministic.
    for filename in sorted(os.listdir(data_dir)):
        if not filename.endswith('.csv'):
            continue
        filepath = os.path.join(data_dir, filename)
        df = pd.read_csv(filepath)
        # Drop the label column by name instead of assuming it is the last
        # column; otherwise a differently ordered file silently mixes the
        # label strings into the features.
        data = df.drop(columns='Label').to_numpy()
        label_values = df['Label'].unique()

        if len(label_values) != 1:
            raise ValueError(f"Inconsistent labels in file: {filepath}")

        label = label_values[0]
        label_index = label_to_index.get(label)
        if label_index is None:
            raise ValueError(f"Unknown label '{label}' in file: {filepath}")

        if data.shape != expected_shape:
            raise ValueError(f"Unexpected data shape in file: {filepath}. Expected {expected_shape}, got {data.shape}")

        data_list.append(data)
        labels_list.append(label_index)

    if not data_list:
        raise ValueError(f"No CSV files found in directory: {data_dir}")

    all_data = np.stack(data_list)       # (num_samples, *expected_shape)
    all_labels = np.array(labels_list)   # (num_samples,)

    print(f"Total samples loaded: {all_data.shape[0]}")
    return all_data, all_labels


# Dataset locations on the author's machine (notebook-specific absolute paths).
dir_test = '/home/kumria/Documents/Offline_Datasets/TEST'
dir_complete = '/home/kumria/Documents/Offline_Datasets/4_one_second_samples'

# Load the complete dataset and show the stacked array shapes, one per line.
data, labels = load_data_from_directory(dir_complete)
print(data.shape, labels.shape, sep="\n")

Loading data from directory: /home/kumria/Documents/Offline_Datasets/4_one_second_samples


### Split Data

In [None]:
def split_data(all_data, all_labels, train_ratio=0.65, val_ratio=0.15, test_ratio=0.2, seed=42):
    """Shuffle samples and partition them into train / validation / test splits.

    Args:
        all_data: array indexed by sample along axis 0.
        all_labels: per-sample labels aligned with ``all_data`` on axis 0.
        train_ratio, val_ratio, test_ratio: split fractions. They must sum to
            1 within floating-point tolerance. The train and validation sizes
            are truncated to ints, and the test split receives every
            remaining sample.
        seed: shuffle seed; the same seed reproduces the same split.

    Returns:
        ((train_data, train_labels), (val_data, val_labels),
         (test_data, test_labels))

    Raises:
        ValueError: if the ratios do not sum to 1.
    """
    import math

    # Exact float equality rejects valid inputs (0.7 + 0.2 + 0.1 != 1.0), and
    # `assert` disappears under `python -O`, so validate explicitly.
    if not math.isclose(train_ratio + val_ratio + test_ratio, 1.0, rel_tol=0.0, abs_tol=1e-9):
        raise ValueError("Ratios must sum to 1")

    num_samples = all_data.shape[0]
    indices = np.arange(num_samples)
    # A local RandomState yields the same permutation as
    # np.random.seed(seed) + np.random.shuffle, without clobbering the global
    # NumPy RNG state used elsewhere.
    np.random.RandomState(seed).shuffle(indices)

    train_end = int(train_ratio * num_samples)
    val_end = train_end + int(val_ratio * num_samples)

    train_indices = indices[:train_end]
    val_indices = indices[train_end:val_end]
    test_indices = indices[val_end:]

    train_data = all_data[train_indices]
    train_labels = all_labels[train_indices]
    val_data = all_data[val_indices]
    val_labels = all_labels[val_indices]
    test_data = all_data[test_indices]
    test_labels = all_labels[test_indices]

    print(f"Dataset split: Train {len(train_data)}, Validation {len(val_data)}, Test {len(test_data)}")
    return (train_data, train_labels), (val_data, val_labels), (test_data, test_labels)

In [None]:
def one_hot_encode(labels, num_classes):
    """Return one-hot rows for integer ``labels`` (trailing axis of size ``num_classes``)."""
    encoded = jax.nn.one_hot(labels, num_classes=num_classes)
    return encoded

### load_data

In [None]:
def load_data(data_dir='/home/kumria/Documents/Offline_Datasets/TEST', num_classes=4):
    """Load a CSV dataset directory, split it, and convert every split to JAX arrays.

    Args:
        data_dir: directory passed to ``load_data_from_directory``. Defaults to
            the author's TEST directory, so zero-argument calls behave as before.
        num_classes: width of the one-hot encoding; 4 for the labels
            'CAR', 'STREET', 'HOME', 'CAFE'.

    Returns:
        Three tuples ``(data, labels_onehot, labels_int)`` for the train,
        validation and test splits, all as JAX arrays. Integer labels are kept
        alongside the one-hot ones for accuracy computation.
    """
    all_data, all_labels = load_data_from_directory(data_dir)
    train_split, val_split, test_split = split_data(all_data, all_labels)

    def _to_jax(split):
        # (numpy features, numpy int labels) -> (features, one-hot, int labels) as JAX arrays
        features, int_labels = split
        return (
            jnp.asarray(features),
            jnp.asarray(one_hot_encode(int_labels, num_classes)),
            jnp.asarray(int_labels),
        )

    return _to_jax(train_split), _to_jax(val_split), _to_jax(test_split)


JAX has no built-in `DataLoader` (unlike PyTorch), so batching is done manually: the training loop shuffles the arrays each epoch and slices them into mini-batches.

## 4. LIF Spiking Neuron

In [None]:
def lif_neuron(V_input, input_current, tau, V_th, V_reset, surrogate_beta=5.0):
    """Advance a layer of leaky integrate-and-fire neurons by one time step.

    Forward-Euler update of dV/dt = (R*I - V) / tau with R = 1 and dt = 1,
    followed by thresholding and reset.

    Args:
        V_input: membrane potentials from the previous step.
        input_current: input current for this step (broadcast against V_input).
        tau: membrane time constant(s); scalar or per-neuron array.
        V_th: firing threshold; a neuron fires when V > V_th.
        V_reset: potential assigned to neurons that fired.
        surrogate_beta: slope of the sigmoid whose derivative replaces the
            (zero almost everywhere) derivative of the hard threshold in the
            backward pass. It has no effect on forward values.

    Returns:
        (V, spike): post-reset membrane potentials, and float32 spikes
        (exactly 1.0 where the neuron fired, 0.0 elsewhere).
    """
    # Euler step with R == 1 and dt == 1.
    V = V_input + (input_current - V_input) / tau

    fired = V > V_th
    hard_spike = fired.astype(jnp.float32)

    # Straight-through surrogate gradient. Differentiating the step function
    # alone yields zero, so no gradient would reach weights that feed spiking
    # layers. `surrogate - stop_gradient(surrogate)` is exactly 0 in the
    # forward pass but carries the sigmoid's derivative backwards.
    surrogate = jax.nn.sigmoid(surrogate_beta * (V - V_th))
    spike = hard_spike + (surrogate - jax.lax.stop_gradient(surrogate)).astype(jnp.float32)

    # Reset neurons that fired; the others keep their integrated potential.
    V = jnp.where(fired, V_reset, V)

    return V, spike

## 5. SNN Model Definition

In [None]:
@jax.jit
def SNN_forward(neuron_params, inputs, weight_params):
    """Run the 4-layer LIF network over a batch of input sequences.

    Args:
        neuron_params: dict with per-neuron time constants ``'tau_1'``..``'tau_4'``
            and scalars ``'V_th'`` / ``'V_reset'``. An optional ``'V_th_out'``
            overrides the output-layer threshold (default 1e7, presumably far
            above any potential the readout reaches, so output neurons act as
            non-firing leaky integrators). A plain dict is used instead of the
            Config object because jit arguments must be pytrees of arrays/scalars.
        inputs: array of shape (batch_size, time_steps, input_size).
        weight_params: dict with ``'weight_1'``..``'weight_4'`` matrices
            chaining input -> hidden -> hidden -> hidden -> output.

    Returns:
        (spikes, membrane_potentials) of the output layer, both of shape
        (time_steps, batch_size, output_size).
    """
    # Only the batch size is needed; unpacking all three dims also enforces rank 3.
    batch_size, _, _ = inputs.shape
    hidden_size = neuron_params['tau_1'].shape[0]
    output_size = neuron_params['tau_4'].shape[0]

    tau_1 = neuron_params['tau_1']
    tau_2 = neuron_params['tau_2']
    tau_3 = neuron_params['tau_3']
    tau_4 = neuron_params['tau_4']
    V_th = neuron_params['V_th']
    V_reset = neuron_params['V_reset']
    V_th_out = neuron_params.get('V_th_out', 1e7)

    W1 = weight_params['weight_1']
    W2 = weight_params['weight_2']
    W3 = weight_params['weight_3']
    W4 = weight_params['weight_4']

    def step(carry, x_t):
        # One time step through all four layers; x_t is (batch_size, input_size).
        V1, V2, V3, V4 = carry
        V1, spike_1 = lif_neuron(V1, jnp.dot(x_t, W1), tau_1, V_th, V_reset)
        V2, spike_2 = lif_neuron(V2, jnp.dot(spike_1, W2), tau_2, V_th, V_reset)
        V3, spike_3 = lif_neuron(V3, jnp.dot(spike_2, W3), tau_3, V_th, V_reset)
        V4, spike_4 = lif_neuron(V4, jnp.dot(spike_3, W4), tau_4, V_th_out, V_reset)
        # Carry the potentials forward; emit the readout layer's spikes/potentials.
        return (V1, V2, V3, V4), (spike_4, V4)

    initial_potentials = (
        jnp.zeros((batch_size, hidden_size)),
        jnp.zeros((batch_size, hidden_size)),
        jnp.zeros((batch_size, hidden_size)),
        jnp.zeros((batch_size, output_size)),
    )

    # scan iterates over the leading axis, so move time to the front:
    # (time_steps, batch_size, input_size).
    time_major = inputs.transpose(1, 0, 2)
    _, (spikes, membrane_potentials) = jax.lax.scan(step, initial_potentials, time_major)

    return spikes, membrane_potentials

**Note:** after `jax.lax.scan`, the outputs are stacked along a new leading time axis, so `spikes` and `membrane_potentials` have shape `(time_steps, batch_size, output_size)`.

## 6. Loss Function

In [None]:
def MSE_loss(y_true, y_pred):
    """Mean squared error averaged over every element of the inputs."""
    residual = y_true - y_pred
    return jnp.mean(residual * residual)

## 7. Train, Validate, Test Functions

In [None]:
@jax.jit
def train(weights, opt_state, inputs, targets_onehot, neuron_params):
    """Perform one optimisation step on a single batch.

    The readout is the output-layer membrane potential summed over time and
    compared with the one-hot targets via MSE. The update uses the
    module-level ``optimizer``, captured when this function is traced.

    Returns:
        (updated weights, updated optimiser state, batch loss,
         predicted class index per sample)
    """
    def batch_loss(params):
        _, potentials = SNN_forward(neuron_params, inputs, params)
        readout = jnp.sum(potentials, axis=0)  # integrate over the time axis
        return MSE_loss(targets_onehot, readout), readout

    # has_aux=True lets the loss function also return auxiliary data (the readout).
    loss_and_grad = value_and_grad(batch_loss, has_aux=True)
    (loss, readout), grads = loss_and_grad(weights)

    updates, new_opt_state = optimizer.update(grads, opt_state, weights)
    new_weights = optax.apply_updates(weights, updates)

    # Predicted class = readout neuron with the largest integrated potential.
    predicted = jnp.argmax(readout, axis=-1)
    return new_weights, new_opt_state, loss, predicted

In [None]:
def validate(weights, inputs, targets_onehot, targets_int, neuron_params):
    """Evaluate the network on a whole split without updating weights.

    Returns:
        (MSE between the time-summed readout and ``targets_onehot``,
         fraction of samples whose argmax readout equals ``targets_int``)
    """
    potentials = SNN_forward(neuron_params, inputs, weights)[1]
    readout = potentials.sum(axis=0)  # integrate over the time axis
    loss = MSE_loss(targets_onehot, readout)
    hits = readout.argmax(axis=-1) == targets_int
    return loss, hits.mean()

In [None]:
def test(weights, inputs, targets_onehot, targets_int, neuron_params):
    """Evaluate on the held-out test split; returns ``(loss, accuracy)`` exactly like ``validate``."""
    return validate(weights, inputs, targets_onehot, targets_int, neuron_params)

## - main -

In [None]:
def main():
    """Train the SNN, print per-epoch train/validation metrics, then report test metrics.

    Each epoch reshuffles the training split and visits every sample in
    mini-batches of ``config.batch_size``; the last batch may be smaller.
    Reported epoch loss and accuracy are per-sample averages over the epoch.

    Raises:
        ValueError: if the training split is empty.
    """
    config = Config()
    key = jax.random.PRNGKey(0)
    weights = initialize_weights(key, config)
    # `train` is jitted and applies the module-level `optimizer`. Initialise the
    # optimiser state from that same object rather than constructing a second,
    # independently configured optimizer that could silently diverge from it.
    opt_state = optimizer.init(weights)

    # Load data
    (train_data, train_labels_onehot, train_labels), \
    (validation_data, validation_labels_onehot, validation_labels), \
    (test_data, test_labels_onehot, test_labels) = load_data()

    num_epochs = config.num_epochs
    batch_size = config.batch_size
    num_train = train_data.shape[0]
    if num_train == 0:
        raise ValueError("Training split is empty; nothing to train on.")

    for epoch in range(num_epochs):
        # Reshuffle the training split with a fresh sub-key each epoch.
        key, subkey = jax.random.split(key)
        perm = jax.random.permutation(subkey, num_train)
        train_data_shuffled = train_data[perm]
        train_labels_shuffled = train_labels[perm]
        train_labels_onehot_shuffled = train_labels_onehot[perm]

        loss_sum = 0.0
        num_correct = 0
        # Include the final partial batch so every sample is used and a split
        # smaller than batch_size still trains (a differently sized last batch
        # costs one extra jit compilation).
        for start in range(0, num_train, batch_size):
            stop = start + batch_size
            batch_data = train_data_shuffled[start:stop]
            batch_labels = train_labels_shuffled[start:stop]
            batch_labels_onehot = train_labels_onehot_shuffled[start:stop]

            weights, opt_state, loss, predictions = train(weights, opt_state, batch_data, batch_labels_onehot, config.neuron_params)
            # Weight by batch size so the epoch average stays per-sample even
            # when the last batch is short.
            loss_sum += loss * batch_data.shape[0]
            num_correct += jnp.sum(predictions == batch_labels)

        epoch_loss = loss_sum / num_train
        epoch_accuracy = num_correct / num_train

        validation_loss, validation_accuracy = validate(weights, validation_data, validation_labels_onehot, validation_labels, config.neuron_params)

        print(f"Epoch {epoch+1}/{num_epochs}   ||   Loss: {epoch_loss:.4f}   ||   Accuracy: {100*epoch_accuracy:.4f}%   ||   Validation Loss: {validation_loss:.4f}   ||   Validation Accuracy: {100*validation_accuracy:.4f}%")

    test_loss, test_accuracy = test(weights, test_data, test_labels_onehot, test_labels, config.neuron_params)
    print(f"\nTest Loss: {test_loss:.4f}   ||   Test Accuracy: {100*test_accuracy:.4f}%")


if __name__ == "__main__":
    main()