Welcome! This script should generate a brand new spike dataset (numpy arrays) from the .wav home sound dataset. This is to save on memory and time, as we won't have to preprocess the dataset multiple times. Trust me, this script takes a loooooong time to run.

You don't have to run this script, as the dataset is saved to my google drive and I will have zipped it up and shared it with you all.

I took some help from AI to write this code.

## Update: adding noise
I've added functionality for adding white Gaussian noise to each sample at an SNR (Signal-to-Noise Ratio) drawn uniformly at random from a configurable range. Feel free to use this code to generate a dataset with any desired amount of white Gaussian noise.

Using SNR means the noise will be proportional to the signal (original audio). Note that:

SNR = 0 dB: noise and signal have equal power.

SNR < 0 dB: noise is *louder* than the signal (noise power is higher).

SNR > 0 dB: noise is *quieter* than the signal (noise power is lower).

In [1]:
# Install the notebook's dependencies.
# run once, also runtime may ask to be restarted, that's fine
!pip install rockpool
# NOTE(review): numpy/jax/jaxlib are pinned here, presumably for compatibility
# with rockpool - confirm before loosening these versions.
!pip install "numpy<2" "jax==0.4.20" "jaxlib==0.4.20"
!pip install samna
!pip install bitstruct

Collecting samna
  Using cached samna-0.48.2-py3-none-any.whl.metadata (375 bytes)
Using cached samna-0.48.2-py3-none-any.whl (6.8 kB)
Installing collected packages: samna
Successfully installed samna-0.48.2
Collecting bitstruct
  Downloading bitstruct-8.21.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.3 kB)
Downloading bitstruct-8.21.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (83 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m83.7/83.7 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitstruct
Successfully installed bitstruct-8.21.0


In [2]:
from zipfile import ZipFile

from google.colab import drive

# Location of the zipped .wav dataset on Drive, and the local folder it is
# unpacked into.
zip_file_path = '/content/drive/MyDrive/D7046E_SNN_project_dataset/building_106_kitchen.zip'
extract_path = '/content/dataset'

# The archive path lives under the Drive mount point, so mount before reading.
drive.mount('/content/drive')

# Unpack every member of the archive into extract_path.
with ZipFile(zip_file_path) as archive:
    archive.extractall(path=extract_path)

Mounted at /content/drive


In [3]:
# Folder holding the source .wav segments; each sub-directory is treated as one class.
audio_root_dir = '/content/dataset/building_106_kitchen/training_segments'
# Folder on Drive where the generated spike rasters (.npy) and metadata are written.
output_dir = '/content/drive/MyDrive/D7046E_SNN_project_dataset/kitchen_spike_dataset_noisy'

In [1]:
import os
import random
from IPython.display import Image, Audio, display
import numpy as np
import librosa
from rockpool.devices.xylo.syns61201 import AFESim
from rockpool.timeseries import TSContinuous
from tqdm import tqdm
import json
import gc


# add white noise to audio at a given SNR [dB] (Signal to Noise Ratio)
def add_white_noise(audio, snr_db, rng=None):
    """Return `audio` with white Gaussian noise added at a given SNR.

    The noise power is derived from the mean power of `audio`, so the noise
    level scales with the loudness of the signal.

    Args:
        audio: Audio samples (array-like). Integer PCM data is accepted; the
            result is always floating point.
        snr_db: Target signal-to-noise ratio in dB. 0 means equal signal and
            noise power, negative values make the noise louder than the signal.
        rng: Optional `numpy.random.Generator` (or `RandomState`) used to draw
            the noise, for reproducible output. Defaults to NumPy's global
            random state, so `np.random.seed` keeps working as before.

    Returns:
        A new array with the same shape as `audio` holding signal plus noise.
        Empty or silent input comes back without any noise, since its power
        is zero.
    """
    samples = np.asarray(audio)
    if samples.size == 0:
        # Nothing to measure: the mean of an empty array is NaN (with a warning).
        return samples.astype(np.float64)

    # Square in float64: integer samples (e.g. int16 PCM) would silently
    # overflow, and float32 accumulation loses precision on long clips.
    signal_power = np.mean(np.square(samples, dtype=np.float64))
    snr_linear = 10 ** (snr_db / 10)
    noise_power = signal_power / snr_linear

    # from a normal distribution = gaussian noise
    source = np.random if rng is None else rng
    noise = source.normal(0.0, np.sqrt(noise_power), samples.shape)
    return samples + noise


def convert_audio_to_spikes(
    audio_path,
    sample_T,
    raster_period,
    add_noise=False,
    snr_range=(0, 20)  # dB
):
    """Convert one audio file into a fixed-length spike raster.

    The file is loaded at its native sample rate (mono), optionally mixed
    with white noise at an SNR drawn uniformly from `snr_range`, and passed
    through a freshly built AFESim. The resulting events are rasterised with
    bin width `raster_period` and padded with zeros or truncated to exactly
    `sample_T` time steps.

    Args:
        audio_path: Path of the audio file to convert.
        sample_T: Number of time steps in the returned raster.
        raster_period: Raster bin width, also passed to AFESim.
        add_noise: If True, add white noise to the audio before simulation.
        snr_range: (low, high) bounds in dB for the random SNR.

    Returns:
        A float32 array of shape (sample_T, channels). When the raster has no
        time steps at all, an all-zero (sample_T, 16) array is returned.
    """
    # A new simulator is built for every file; its own noise, offset and
    # mismatch models stay switched off so noise only comes from add_noise.
    frontend = AFESim(
        fs=110e3,
        raster_period=raster_period,
        max_spike_per_raster_period=15,
        add_noise=False,
        add_offset=False,
        add_mismatch=False,
        seed=None,
    ).timed()

    waveform, native_sr = librosa.load(audio_path, sr=None, mono=True)

    if add_noise:
        chosen_snr = np.random.uniform(*snr_range)
        waveform = add_white_noise(waveform, chosen_snr)

    # Wrap the samples as a single-channel continuous time series.
    sample_times = np.arange(len(waveform)) / native_sr
    continuous_input = TSContinuous(sample_times, waveform[:, None])

    # Intermediates are dropped as soon as possible to keep memory low.
    del waveform, sample_times

    events, _, _ = frontend(continuous_input)
    del continuous_input, frontend

    # raster format: time_steps x channels
    raster = events.raster(dt=raster_period, add_events=True)
    del events

    n_steps = raster.shape[0]
    if n_steps == 0:
        # nothing was rasterised: hand back an empty raster
        return np.zeros((sample_T, 16), dtype=np.float32)

    # Force exactly sample_T time steps (300 10ms time steps = 3 seconds).
    missing_steps = sample_T - n_steps
    if missing_steps > 0:
        filler = np.zeros((missing_steps, raster.shape[1]), dtype=np.float32)
        raster = np.concatenate([raster, filler], axis=0)
        del filler
    else:
        raster = raster[:sample_T, :]

    return raster.astype(np.float32)


def preprocess_batch(
    audio_root_dir,
    output_dir,
    sample_T,
    raster_period,
    batch_start=0,
    batch_size=4,
    snr_range=(0, 20),
    add_noise=False
):
    """Convert a batch of class folders from .wav files to spike rasters.

    Every sub-directory of `audio_root_dir` is one class. Classes are sorted
    by name and the slice [batch_start, batch_start + batch_size) is
    processed. Each .wav file becomes `<output_dir>/<class>/<name>.npy`.
    Files whose output already exists are not converted again, so an
    interrupted run can be resumed by calling the function again.

    Two bookkeeping files are kept in `output_dir`:
      * metadata.json   - class list, class->index map and raster settings;
                          written on the first call and reused afterwards.
      * sample_info.npy - list of (output_path, class_idx) pairs, rewritten
                          after every class.

    Args:
        audio_root_dir: Directory containing one sub-directory per class.
        output_dir: Directory that receives the rasters and bookkeeping files.
        sample_T: Number of time steps per raster.
        raster_period: Raster bin width passed to the converter.
        batch_start: Index (in sorted class order) of the first class.
        batch_size: Number of classes to process in this call.
        snr_range: (low, high) SNR bounds in dB, used when `add_noise` is True.
        add_noise: Whether to add white noise to the audio before conversion.

    Returns:
        Tuple `(total_number_of_classes, batch_start + batch_size)`. The same
        tuple shape is returned when the batch is empty, so callers can always
        unpack the result.

    Raises:
        KeyError: If a class folder is missing from an existing metadata.json.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Get all classes
    all_classes = sorted(
        d for d in os.listdir(audio_root_dir)
        if os.path.isdir(os.path.join(audio_root_dir, d))
    )

    # Select batch of classes to process
    next_batch_start = batch_start + batch_size
    classes_to_process = all_classes[batch_start:next_batch_start]

    if not classes_to_process:
        print(f"No classes to process at batch_start={batch_start}")
        # Return the usual tuple (not None) so tuple-unpacking callers don't crash.
        return len(all_classes), next_batch_start

    print(f"Processing batch: classes {batch_start} to {batch_start + len(classes_to_process) - 1}")
    print(f"Classes in this batch: {classes_to_process}")

    # Create/load metadata
    metadata_path = os.path.join(output_dir, 'metadata.json')
    if os.path.exists(metadata_path):
        # Reuse the stored mapping so class indices stay stable across batches.
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
        class_to_idx = metadata['class_to_idx']
    else:
        # first batch - create metadata
        class_to_idx = {cls: i for i, cls in enumerate(all_classes)}
        metadata = {
            'classes': all_classes,
            'class_to_idx': class_to_idx,
            'sample_T': sample_T,
            'raster_period': raster_period,
            'num_channels': 16,
        }
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        print(f"Created metadata for all {len(all_classes)} classes")

    # load existing sample info if it exists
    # (allow_pickle is acceptable only because this file is written by this
    # function itself; do not point it at untrusted data)
    sample_info_path = os.path.join(output_dir, 'sample_info.npy')
    if os.path.exists(sample_info_path):
        sample_info = np.load(sample_info_path, allow_pickle=True).tolist()
    else:
        sample_info = []

    # Paths already listed in sample_info, used to repair the index on resume.
    indexed_paths = {entry[0] for entry in sample_info}

    total_processed = 0
    total_reindexed = 0

    # Process each class in this batch
    for class_name in classes_to_process:
        class_path = os.path.join(audio_root_dir, class_name)
        class_idx = class_to_idx[class_name]

        # Create output directory for this class
        class_output_dir = os.path.join(output_dir, class_name)
        os.makedirs(class_output_dir, exist_ok=True)

        # Get all WAV files (sorted: os.listdir order is arbitrary)
        wav_files = sorted(f for f in os.listdir(class_path) if f.endswith('.wav'))

        print(f"\nProcessing '{class_name}' ({len(wav_files)} samples)...")

        for idx, wav_file in enumerate(tqdm(wav_files, desc=f"  {class_name}")):
            audio_path = os.path.join(class_path, wav_file)
            # Swap only the extension; str.replace would also rewrite a
            # '.wav' that appears in the middle of the file name.
            output_filename = os.path.splitext(wav_file)[0] + '.npy'
            output_path = os.path.join(class_output_dir, output_filename)

            # Skip if already processed
            if os.path.exists(output_path):
                # A run interrupted mid-class leaves rasters on disk that were
                # never written to sample_info; add them back to the index.
                if output_path not in indexed_paths:
                    sample_info.append((output_path, class_idx))
                    indexed_paths.add(output_path)
                    total_reindexed += 1
                continue

            spike_raster = convert_audio_to_spikes(
                audio_path,
                sample_T,
                raster_period,
                add_noise,
                snr_range  # CHOOSE SNR [dB] HERE
            )

            np.save(output_path, spike_raster)
            sample_info.append((output_path, class_idx))
            indexed_paths.add(output_path)
            total_processed += 1
            del spike_raster

            # Periodic collection keeps memory down during the long run.
            if idx % 5 == 0:
                gc.collect()

        # Save sample info after each class
        # NOTE(review): np.save turns this list of (path, class_idx) pairs into
        # an array, so class_idx is presumably stored as a string - confirm
        # the dataset loader casts it back to int.
        np.save(sample_info_path, sample_info, allow_pickle=True)
        gc.collect()

    print(f"\nBatch done: {total_processed} samples converted, "
          f"{total_reindexed} existing samples re-indexed")

    return len(all_classes), next_batch_start

In [4]:
# Free whatever earlier cells left in memory before the long conversion run.
gc.collect()

# Raster settings handed to the converter (the AFESim itself is not built
# here - a fresh one is created for each file).
raster_period = 10e-3
sample_T = 300

# CHOOSE SNR HERE
add_noise = True
snr_range = (0,15)

# PROCESS IN BATCHES (of classes)
batch_size = 1  # process n classes at a time
batch_start = 23  # start at class

print(
    "\nConfiguration:",
    f"  Batch start: {batch_start}",
    f"  Batch size: {batch_size}",
    sep="\n",
)

# Everything the batch function needs, gathered in one place.
run_config = {
    'audio_root_dir': audio_root_dir,
    'output_dir': output_dir,
    'sample_T': sample_T,
    'raster_period': raster_period,
    'batch_start': batch_start,
    'batch_size': batch_size,
    'snr_range': snr_range,
    'add_noise': add_noise,
}
total_classes, next_batch = preprocess_batch(**run_config)



Configuration:
  Batch start: 23
  Batch size: 1
Processing batch: classes 23 to 23
Classes in this batch: ['water_tap']

Processing 'water_tap' (115 samples)...


  water_tap: 100%|██████████| 115/115 [07:00<00:00,  3.66s/it]


Because TimedModuleWrapper was giving me trouble, I tried switching to non-timed simulation.

First 4 classes have 10 ms raster period, max 15 spikes: 88.9% empty\
First 4 classes have 20 ms raster period, max 100 spikes: 91.7% empty\
First 4 classes have 50 ms raster period, max 100 spikes: 91.7% empty\
First 4 classes have 5 ms raster period, max 100 spikes: 91.7% empty

conclusion: adjusting these parameters doesn't help make more spikes

Tried scaling voltage of input to expected RMS:\
First 4 classes have 10 ms raster period, max 15 spikes: 11.1% empty (didn't end up using this anyway, while fewer samples are empty, most are still super sparse, this is still worth considering though)

In the end, I switched back to TimedModuleWrapper (.timed()) and created a new AFESim for every single audio sample, making sure to delete each simulation and its variables once they have been used, to save on memory. This is very slow but seems to work fine.


In [None]:
# for testing different noise
# Plays one randomly chosen clip per training class with white noise added
# at 15 dB SNR, so the noise level can be judged by ear.
train_segments_path = '/content/dataset/building_106_kitchen/training_segments'

# every sub-directory of the training segments folder is one class
class_directories = [
    entry for entry in os.listdir(train_segments_path)
    if os.path.isdir(os.path.join(train_segments_path, entry))
]

# one random audio sample from each class
# run this cell again to get different samples
for class_name in class_directories:
    class_path = os.path.join(train_segments_path, class_name)
    wav_files = [name for name in os.listdir(class_path) if name.endswith('.wav')]
    if not wav_files:
        # class folder without any .wav files: nothing to play
        continue

    chosen_path = os.path.join(class_path, random.choice(wav_files))
    print(f"\nClass: {class_name}")
    clean_audio, sr = librosa.load(chosen_path, sr=None, mono=True)
    display(Audio(data=add_white_noise(clean_audio, snr_db=15), rate=sr))
