# SHD Dataset Processing: Creating Timing-Based Benchmarks

This notebook provides a complete pipeline for processing the Spiking Heidelberg Digits (SHD) dataset to create specialized benchmarks for evaluating temporal information processing in spiking neural networks.

## Research Motivation

The SHD dataset has become a widely used benchmark in neuromorphic computing. However, our research reveals that although it contains substantial temporal information, classification remains surprisingly accurate when using spike counts alone or after various perturbations of spike timing. This suggests that standard SHD may not be ideal for probing the extent to which spike-based models can exploit temporal information.

## Generated Datasets

This notebook creates three dataset variants:

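1. **`shd_whole.mat`**: Training and test sets combined in dense form
   - Binary spike rasters of 700 units × 100 time bins (1.4 s)
   - Contains 9,984 of the 10,420 original samples (the converter keeps only full batches of 256)
   - Serves as the baseline for comparison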

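2. **`shd_part.mat`**: Class-balanced subset with the original spike trains
   - Drawn from the same filtered samples and active neurons as `shd_norm.mat`, with every spike retained
   - Balanced classes for fair evaluation
   - Suitable for standard temporal processing benchmarks and as a direct counterpart to `shd_norm.mat`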

3. **`shd_norm.mat`**: Count-normalized dataset with no spike-count information
   - **By construction, each neuron emits the same number of spikes in every sample**, so spike counts carry no class information
   - Only spike-timing information remains, although the random spike deletion also discards some of it
   - Specifically designed to evaluate timing-based computation capabilities
   - Enables fair assessment of temporal vs. rate-based processing
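
To make "by construction" concrete, let $c_{i,f}$ be the number of spikes neuron $f$ emits in sample $i$. The normalization (implemented by `do_min_count` below) keeps a random subset of each neuron's spikes so that

$$
N_f = \min_{i} c_{i,f}, \qquad c^{\text{norm}}_{i,f} = N_f \quad \text{for every sample } i,
$$

with both the minimum and the equality taken over the samples that survive filtering. Because $c^{\text{norm}}_{i,f}$ does not depend on $i$, neither per-neuron nor total spike counts carry any information about the class label.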

## Dataset Access

We provide a public release of these processed datasets to facilitate research on timing-based computation in spiking neural networks, as described in our paper.

## Download Original SHD Dataset

First, we need to download the original Spiking Heidelberg Digits dataset from the official source. The dataset consists of:
- **Training set**: `shd_train.h5.gz` 
- **Test set**: `shd_test.h5.gz`

All files are downloaded with MD5 hash verification to ensure data integrity.

In [2]:
import os
import urllib.request
import gzip
import shutil
import hashlib

from urllib.error import HTTPError, URLError
from urllib.request import urlretrieve

print("  All libraries imported successfully!")

  All libraries imported successfully!


### Dataset Download Utilities

The following functions download files into a local cache, validate them against published hashes, and decompress them. A cached file whose hash does not match is downloaded again.

In [3]:
def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
    """Calculate the hex digest of a file, reading it in chunks."""
    # 'auto' cannot be resolved here without a reference hash, so treat it as sha256
    if algorithm in ('sha256', 'auto'):
        hasher = hashlib.sha256()
    else:
        hasher = hashlib.md5()

    with open(fpath, 'rb') as fpath_file:
        for chunk in iter(lambda: fpath_file.read(chunk_size), b''):
            hasher.update(chunk)

    return hasher.hexdigest()

def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
    """Validate a file against its hash ('auto' picks sha256 for 64-character hashes, else md5)."""
    if (algorithm == 'sha256') or (algorithm == 'auto' and len(file_hash) == 64):
        hasher = 'sha256'
    else:
        hasher = 'md5'

    return str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash)

print("  Hash validation functions defined.")

  Hash validation functions defined.


In [4]:
def get_file(fname,
             origin,
             md5_hash=None,
             file_hash=None,
             cache_subdir='datasets',
             hash_algorithm='auto',
             extract=False,
             archive_format='auto',
             cache_dir=None):
    """Download a file from a URL if it is not already in the cache."""
    
    if cache_dir is None:
        cache_dir = os.path.join(os.path.expanduser('~'), '.data-cache')
    if md5_hash is not None and file_hash is None:
        file_hash = md5_hash
        hash_algorithm = 'md5'
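    datadir_base = os.path.expanduser(cache_dir)
    # Create the cache root first: os.access() reports False for a directory that does not exist yet
    try:
        os.makedirs(datadir_base, exist_ok=True)
    except OSError:
        pass
    if not os.access(datadir_base, os.W_OK):
        datadir_base = os.path.join('/tmp', '.data-cache')
    datadir = os.path.join(datadir_base, cache_subdir)

    # Create the (possibly fallback) data directory if it doesn't exist
    os.makedirs(datadir, exist_ok=True)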

    fpath = os.path.join(datadir, fname)

    download = False
    if os.path.exists(fpath):
        # File found; verify integrity if a hash was provided.
        if file_hash is not None:
            if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
                print('A local file was found, but it seems to be '
                      'incomplete or outdated because the ' + hash_algorithm +
                      ' file hash does not match the original value of ' + file_hash +
                      ' so we will re-download the data.')
                download = True
    else:
        download = True

    if download:
        print('Downloading data from', origin)

        error_msg = 'URL fetch failure on {}: {} -- {}'
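        try:
            try:
                urlretrieve(origin, fpath)
            except HTTPError as e:
                raise Exception(error_msg.format(origin, e.code, e.msg))
            except URLError as e:
                raise Exception(error_msg.format(origin, e.errno, e.reason))
            # Verify the fresh download too, not only previously cached copies
            if file_hash is not None and not validate_file(fpath, file_hash, algorithm=hash_algorithm):
                raise ValueError('Downloaded file ' + fpath + ' does not match the expected ' +
                                 hash_algorithm + ' hash ' + file_hash)
        except (Exception, KeyboardInterrupt):
            # Remove partial or corrupt downloads before re-raising
            if os.path.exists(fpath):
                os.remove(fpath)
            raise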

    return fpath

print("  File download function defined.")

  File download function defined.


In [5]:
def get_and_gunzip(origin, filename, md5hash=None, cache_dir=None, cache_subdir='datasets'):
    """Download and decompress a gzipped file."""
    gz_file_path = get_file(filename, origin, md5_hash=md5hash, cache_dir=cache_dir, cache_subdir=cache_subdir)
    hdf5_file_path = gz_file_path[:-3]  # Remove .gz extension
    
    if not os.path.isfile(hdf5_file_path) or os.path.getctime(gz_file_path) > os.path.getctime(hdf5_file_path):
        print("Decompressing %s" % gz_file_path)
        with gzip.open(gz_file_path, 'rb') as f_in, open(hdf5_file_path, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
    
    return hdf5_file_path

print("  Decompression function defined.")

  Decompression function defined.


In [6]:
def get_shd_dataset(cache_dir, cache_subdir):
    """Download the Spiking Heidelberg Digits dataset."""
    
    # The remote directory with the data files
    base_url = "https://zenkelab.org/datasets"

    # Retrieve MD5 hashes from remote
    print("Fetching MD5 checksums...")
    response = urllib.request.urlopen("%s/md5sums.txt" % base_url)
    data = response.read() 
    lines = data.decode('utf-8').split("\n")
    file_hashes = {line.split()[1]: line.split()[0] for line in lines if len(line.split()) == 2}

    # Download the Spiking Heidelberg Digits (SHD) dataset
    files = ["shd_train.h5.gz", "shd_test.h5.gz"]
    
    downloaded_files = []
    for fn in files:
        print(f"\nProcessing {fn}...")
        origin = "%s/%s" % (base_url, fn)
        hdf5_file_path = get_and_gunzip(origin, fn, md5hash=file_hashes[fn], 
                                      cache_dir=cache_dir, cache_subdir=cache_subdir)
        print("Available at: %s" % hdf5_file_path)
        downloaded_files.append(hdf5_file_path)
    
    return downloaded_files

print("  Main dataset download function defined.")

  Main dataset download function defined.
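
`get_shd_dataset` parses `md5sums.txt` by splitting each line on whitespace and treating the first field as the digest and the second as the filename, matching the output format of the `md5sum` command. The standalone sketch below isolates that parsing step. `parse_md5sums` is a hypothetical helper name, and the digests in the example are placeholders, not the real checksums.

```python
def parse_md5sums(text):
    """Map filename -> digest for lines of the form '<digest>  <filename>'.

    Mirrors the dict comprehension in get_shd_dataset: lines that do not
    split into exactly two fields (e.g. blank lines) are skipped.
    """
    hashes = {}
    for line in text.splitlines():
        fields = line.split()
        if len(fields) == 2:
            digest, filename = fields
            hashes[filename] = digest
    return hashes


# Placeholder digests, for illustration only
example = "0123abcd  shd_train.h5.gz\n4567ef01  shd_test.h5.gz\n\n"
hashes = parse_md5sums(example)
print(hashes["shd_test.h5.gz"])  # 4567ef01
```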


## Configure Storage Location

Set up the directory structure where the original and processed datasets will be stored.

In [7]:
import os

# Set up cache directories relative to current working directory
cache_dir = "./data"
cache_subdir = "hdspikes"

# Create the full path for easier reference
full_cache_path = os.path.join(cache_dir, cache_subdir)

print(f"Cache directory: {cache_dir}")
print(f"Cache subdirectory: {cache_subdir}")
print(f"Full cache path: {full_cache_path}")

# Create directories if they don't exist
os.makedirs(full_cache_path, exist_ok=True)
print(f"  Directory created/verified: {full_cache_path}")


Cache directory: ./data
Cache subdirectory: hdspikes
Full cache path: ./data/hdspikes
  Directory created/verified: ./data/hdspikes


## Execute Dataset Download

Now we'll download the original SHD dataset. This process will:
1. Fetch MD5 checksums from the server for data integrity
2. Download the compressed training and test files  
3. Verify file integrity using MD5 hash validation
4. Automatically decompress the files to HDF5 format

**Note**: Initial download may take several minutes depending on your internet connection.

In [8]:
# Download the SHD dataset
print("Starting SHD dataset download...")
print("=" * 50)

try:
    downloaded_files = get_shd_dataset(cache_dir, cache_subdir)
    print("\n" + "=" * 50)
    print("Dataset download completed successfully!")
    print(f"Files downloaded to: {full_cache_path}")
    
except Exception as e:
    print(f"\n Error during download: {str(e)}")
    print("Please check your internet connection and try again.")

Starting SHD dataset download...
Fetching MD5 checksums...

Processing shd_train.h5.gz...
Downloading data from https://zenkelab.org/datasets/shd_train.h5.gz
Decompressing ./data/hdspikes/shd_train.h5.gz
Available at: ./data/hdspikes/shd_train.h5

Processing shd_test.h5.gz...
Downloading data from https://zenkelab.org/datasets/shd_test.h5.gz
Decompressing ./data/hdspikes/shd_test.h5.gz
Available at: ./data/hdspikes/shd_test.h5

Dataset download completed successfully!
Files downloaded to: ./data/hdspikes


## Verify Original Dataset

Let's verify that the original SHD files were downloaded correctly and examine their structure.

In [9]:
# Check for the expected files
expected_files = ["shd_train.h5", "shd_test.h5"]

print(" File Verification Report")
print("=" * 40)

for filename in expected_files:
    filepath = os.path.join(full_cache_path, filename)
    
    if os.path.exists(filepath):
        # Get file size
        file_size = os.path.getsize(filepath)
        file_size_mb = file_size / (1024 * 1024)
        
        print(f" {filename}")
        print(f"    Path: {filepath}")
        print(f"    Size: {file_size_mb:.2f} MB ({file_size:,} bytes)")
        print()
    else:
        print(f" {filename} - NOT FOUND")
        print(f"   Expected at: {filepath}")
        print()

# List all files in the cache directory
print(" All files in cache directory:")
print("-" * 40)
try:
    all_files = os.listdir(full_cache_path)
    for file in sorted(all_files):
        file_path = os.path.join(full_cache_path, file)
        if os.path.isfile(file_path):
            size_mb = os.path.getsize(file_path) / (1024 * 1024)
            print(f"   {file} ({size_mb:.2f} MB)")
        else:
            print(f"   {file} (directory)")
except FileNotFoundError:
    print(f"   Directory not found: {full_cache_path}")

print("\n Original SHD dataset is ready for processing!")

 File Verification Report
 shd_train.h5
    Path: ./data/hdspikes/shd_train.h5
    Size: 256.37 MB (268,826,085 bytes)

 shd_test.h5
    Path: ./data/hdspikes/shd_test.h5
    Size: 75.07 MB (78,719,235 bytes)

 All files in cache directory:
----------------------------------------
   shd_test.h5 (75.07 MB)
   shd_test.h5.gz (36.37 MB)
   shd_train.h5 (256.37 MB)
   shd_train.h5.gz (124.78 MB)

 Original SHD dataset is ready for processing!


## Explore Dataset Structure

Let's examine the structure and content of the downloaded HDF5 files to understand the data format.

In [10]:
# Example: How to load the SHD dataset
import h5py

# Open the dataset files
train_file_path = os.path.join(full_cache_path, 'shd_train.h5')
test_file_path = os.path.join(full_cache_path, 'shd_test.h5')

try:
    print(" Loading dataset files for exploration...")
    
    # Open training file
    train_file = h5py.File(train_file_path, 'r')
    print(f" Training file loaded: {train_file_path}")
    print(f"   Keys: {list(train_file.keys())}")
    
    # Open test file  
    test_file = h5py.File(test_file_path, 'r')
    print(f" Test file loaded: {test_file_path}")
    print(f"   Keys: {list(test_file.keys())}")
    
    # Access the data (without loading into memory)
    x_train = train_file['spikes']
    y_train = train_file['labels']
    x_test = test_file['spikes']
    y_test = test_file['labels']
    
    print(f"\n Dataset Statistics:")
    print(f"   Training samples: {len(y_train)}")
    print(f"   Test samples: {len(y_test)}")
    print(f"   Total samples: {len(y_train) + len(y_test)}")
    print(f"   Unique labels: {len(set(y_train[:]))}")
    print(f"   Label range: {min(y_train[:])} - {max(y_train[:])}")
    
    # Remember to close the files when done
    train_file.close()
    test_file.close()
    print("\n Files closed successfully")
    
except Exception as e:
    print(f" Error loading dataset: {str(e)}")

print("\n Ready to proceed with dataset processing pipeline!")

 Loading dataset files for exploration...
 Training file loaded: ./data/hdspikes/shd_train.h5
   Keys: ['extra', 'labels', 'spikes']
 Test file loaded: ./data/hdspikes/shd_test.h5
   Keys: ['extra', 'labels', 'spikes']

 Dataset Statistics:
   Training samples: 8156
   Test samples: 2264
   Total samples: 10420
   Unique labels: 20
   Label range: 0 - 19

 Files closed successfully

 Ready to proceed with dataset processing pipeline!


## Step 1: Generate Combined Dataset (shd_whole.mat)

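We first convert the sparse HDF5 format to a dense MAT format and combine the training and test data. The result, `shd_whole.mat`, is the baseline from which the other two datasets are derived. Note that the dense conversion is lossy: spike times are binned into `nb_steps` binary bins spanning `max_time` seconds, so several spikes of one unit in the same bin collapse to a single 1, and spikes at or after `max_time` are merged into the final bin.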

In [11]:
# Additional imports for MAT file generation
import numpy as np
import h5py
from scipy import io

print("Additional libraries for MAT file generation imported successfully!")

Additional libraries for MAT file generation imported successfully!


In [12]:
def sparse_data_generator_from_hdf5_spikes(X, y, batch_size, nb_steps, nb_units, max_time, shuffle=True):
    """
    Generator function to convert sparse spike data from HDF5 to dense format.
    
    Parameters:
    - X: HDF5 group containing 'times' and 'units' datasets
    - y: Labels array
    - batch_size: Number of samples per batch
    - nb_steps: Number of time steps
    - nb_units: Number of units/neurons
    - max_time: Maximum time duration
    - shuffle: Whether to shuffle the data
    
    Yields:
    - dense_batch: Dense spike tensor of shape (batch_size, nb_units, nb_steps)
    - y_batch: Labels for the batch
    """
    labels_ = np.array(y, dtype=np.int32)
    # Only full batches are emitted: the last len(labels_) % batch_size samples are dropped
    number_of_batches = len(labels_) // batch_size
    sample_index = np.arange(len(labels_))
    firing_times = X['times']
    units_fired = X['units']
    # nb_steps bin edges spanning [0, max_time]
    time_bins = np.linspace(0, max_time, num=nb_steps)

    if shuffle:
        np.random.shuffle(sample_index)

    counter = 0
    while counter < number_of_batches:
        batch_index = sample_index[batch_size * counter:batch_size * (counter + 1)]
        dense_batch = np.zeros((batch_size, nb_units, nb_steps), dtype=np.uint8)
        y_batch = []

        for bc, idx in enumerate(batch_index):
            # np.digitize maps t in [0, time_bins[1]) to 1, so index 0 is never filled
            times = np.digitize(firing_times[idx], time_bins)
            units = units_fired[idx]
            # Clamp spikes at or after max_time into the last bin
            times[times >= nb_steps] = nb_steps - 1
            # Binary raster: several spikes of one unit in the same bin collapse to a single 1
            dense_batch[bc, units, times] = 1
            y_batch.append(labels_[idx])

        yield dense_batch, np.array(y_batch, dtype=np.uint8)
        counter += 1

print("Sparse data generator function defined.")

Sparse data generator function defined.
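
To see what the densification does to a single sample, the sketch below repeats the binning logic of `sparse_data_generator_from_hdf5_spikes` for one made-up spike train. `bin_spikes` is a hypothetical helper, and the spike times and unit indices are invented for illustration. It shows three properties of the conversion: two spikes of the same unit in one bin collapse to a single 1, bin 0 is never filled because `np.digitize` maps $t = 0$ to index 1, and spikes after `max_time` are clamped into the last bin.

```python
import numpy as np


def bin_spikes(times, units, nb_steps=100, nb_units=700, max_time=1.4):
    """Bin one sample's spikes the way the generator does (hypothetical helper).

    Returns a binary uint8 raster of shape (nb_units, nb_steps).
    """
    time_bins = np.linspace(0, max_time, num=nb_steps)
    idx = np.digitize(times, time_bins)      # 1-based bin index for t >= 0
    idx[idx >= nb_steps] = nb_steps - 1      # clamp late spikes into the last bin
    raster = np.zeros((nb_units, nb_steps), dtype=np.uint8)
    raster[units, idx] = 1                   # repeated (unit, bin) pairs stay 1
    return raster


raster = bin_spikes(np.array([0.0, 0.005, 0.02, 1.5]), np.array([3, 3, 3, 7]))
print(raster.sum())        # 3: the two spikes of unit 3 in bin 1 became one entry
print(raster[:, 0].sum())  # 0: the first bin is never used
print(raster[7, 99])       # 1: the spike at 1.5 s was clamped into the last bin
```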


In [13]:
def collect_all(X_h5, Y_h5, batch_size, nb_steps, nb_units, max_time):
    """
    Collect all data from HDF5 format and convert to dense numpy arrays.
    
    Parameters:
    - X_h5: HDF5 group containing spike data
    - Y_h5: Labels array
    - batch_size: Batch size for processing
    - nb_steps: Number of time steps
    - nb_units: Number of units/neurons
    - max_time: Maximum time duration
    
    Returns:
    - X_all: Dense spike data array
    - Y_all: Labels array
    """
    X_all = []
    Y_all = []
    for x_batch, y_batch in sparse_data_generator_from_hdf5_spikes(
            X_h5, Y_h5, batch_size, nb_steps, nb_units, max_time, shuffle=False):
        X_all.append(x_batch)
        Y_all.append(y_batch)
    X_all = np.concatenate(X_all, axis=0)
    Y_all = np.concatenate(Y_all, axis=0)
    return X_all, Y_all

print("Data collection helper function defined.")

Data collection helper function defined.
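
Because `collect_all` consumes the generator with a fixed `batch_size`, each split contributes only whole batches. The arithmetic below reproduces the sample count reported when `shd_whole.mat` is generated; `samples_kept` is a hypothetical helper written only for this check.

```python
def samples_kept(n_samples, batch_size):
    """Samples surviving when only full batches are emitted
    (number_of_batches = n_samples // batch_size in the generator)."""
    return (n_samples // batch_size) * batch_size


n_train, n_test = 8156, 2264  # split sizes printed by the exploration cell
kept = samples_kept(n_train, 256) + samples_kept(n_test, 256)
dropped = n_train + n_test - kept
print(kept, dropped)  # 9984 436
```

A batch size that divides both split sizes (for example 4, since 8156 = 4 × 2039 and 2264 = 4 × 566) keeps every sample, but the result would then differ from the 9,984-sample file generated in this notebook.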


In [14]:
def generate_shd_whole_mat(cache_dir, cache_subdir, save_path=None, 
                          batch_size=256, nb_steps=100, nb_units=700, max_time=1.4):
    """
    Generate a combined MAT file containing both training and test SHD data in dense format.
    
    Parameters:
    - cache_dir: Directory where the dataset is cached
    - cache_subdir: Subdirectory containing the dataset files
    - save_path: Path where to save the MAT file (if None, saves in cache directory)
    - batch_size: Batch size for processing (default: 256)
    - nb_steps: Number of time steps (default: 100)
    - nb_units: Number of units/neurons (default: 700)
    - max_time: Maximum time duration in seconds (default: 1.4)
    
    Returns:
    - save_path: Path where the file was saved
    """
    
    print("Starting MAT file generation...")
    print(f"Parameters: batch_size={batch_size}, nb_steps={nb_steps}, nb_units={nb_units}, max_time={max_time}")
    
    # Construct file paths
    base_path = os.path.join(cache_dir, cache_subdir)
    train_path = os.path.join(base_path, "shd_train.h5")
    test_path = os.path.join(base_path, "shd_test.h5")
    
    # Set default save path if not provided
    if save_path is None:
        save_path = os.path.join(base_path, "shd_whole.mat")
    
    print(f"   Loading data from:")
    print(f"   Training: {train_path}")
    print(f"   Test: {test_path}")
    
    try:
        # Open HDF5 files
        train_file = h5py.File(train_path, "r")
        test_file = h5py.File(test_path, "r")
        
        # Extract data
        x_train = train_file['spikes']
        y_train = train_file['labels']
        x_test = test_file['spikes']
        y_test = test_file['labels']
        
        print("   Converting sparse data to dense format...")
        print("   Processing training data...")
        X_train_all, Y_train_all = collect_all(x_train, y_train, batch_size, nb_steps, nb_units, max_time)
        
        print("   Processing test data...")
        X_test_all, Y_test_all = collect_all(x_test, y_test, batch_size, nb_steps, nb_units, max_time)
        
        # Combine training and test data
        print(" Combining training and test data...")
        X_all = np.concatenate([X_train_all, X_test_all], axis=0)
        Y_all = np.concatenate([Y_train_all, Y_test_all], axis=0)
        
        print(f"   Final dataset shape:")
        print(f"   X shape: {X_all.shape}")
        print(f"   Y shape: {Y_all.shape}")
        print(f"   Total samples: {len(Y_all)}")
        print(f"   Unique labels: {len(np.unique(Y_all))}")
        
        # Save as MAT file
        print(f" Saving to: {save_path}")
        io.savemat(save_path, {'X': X_all, 'Y': Y_all})
        
        # Close files
        train_file.close()
        test_file.close()
        
        print(" MAT file generation completed successfully!")
        return save_path
        
    except Exception as e:
        print(f" Error during MAT file generation: {str(e)}")
        if 'train_file' in locals():
            train_file.close()
        if 'test_file' in locals():
            test_file.close()
        raise

print("MAT file generation function defined.")

MAT file generation function defined.


### Execute Baseline Dataset Generation

Let's create the combined `shd_whole.mat` file that serves as our baseline containing all original spike information.

In [15]:
# Configuration parameters
batch_size = 256
nb_steps = 100
nb_units = 700
max_time = 1.4

# You can customize the save path if needed
# custom_save_path = "/path/to/your/desired/location/shd_whole.mat"
custom_save_path = None  # Will save in the cache directory by default

print(" Generating baseline combined SHD dataset...")
print("=" * 60)

try:
    # Generate the MAT file
    saved_path = generate_shd_whole_mat(
        cache_dir=cache_dir,
        cache_subdir=cache_subdir,
        save_path=custom_save_path,
        batch_size=batch_size,
        nb_steps=nb_steps,
        nb_units=nb_units,
        max_time=max_time
    )
    
    print("=" * 60)
    print(" Baseline dataset generation completed!")
    print(f" File saved at: {saved_path}")
    
    # Verify the saved file
    if os.path.exists(saved_path):
        file_size_mb = os.path.getsize(saved_path) / (1024 * 1024)
        print(f" File size: {file_size_mb:.2f} MB")
        
        # Try to load and verify the MAT file
        try:
            mat_data = io.loadmat(saved_path)
            print(f"   MAT file verification:")
            print(f"   X shape: {mat_data['X'].shape}")
            print(f"   Y shape: {mat_data['Y'].shape}")
            print(f"   Data type X: {mat_data['X'].dtype}")
            print(f"   Data type Y: {mat_data['Y'].dtype}")
        except Exception as e:
            print(f" Could not verify MAT file contents: {str(e)}")
    else:
        print(" File was not created successfully")
        
except Exception as e:
    print(f" Error during baseline dataset generation: {str(e)}")
    print("Please check the dataset files and try again.")

 Generating baseline combined SHD dataset...
Starting MAT file generation...
Parameters: batch_size=256, nb_steps=100, nb_units=700, max_time=1.4
   Loading data from:
   Training: ./data/hdspikes/shd_train.h5
   Test: ./data/hdspikes/shd_test.h5
   Converting sparse data to dense format...
   Processing training data...
   Processing test data...
 Combining training and test data...
   Final dataset shape:
   X shape: (9984, 700, 100)
   Y shape: (9984,)
   Total samples: 9984
   Unique labels: 20
 Saving to: ./data/hdspikes/shd_whole.mat
 MAT file generation completed successfully!
 Baseline dataset generation completed!
 File saved at: ./data/hdspikes/shd_whole.mat
 File size: 666.51 MB

### Customization Options

You can customize the MAT file generation by modifying the parameters:

```python
# Example with custom parameters
custom_saved_path = generate_shd_whole_mat(
    cache_dir="./data",          # Your cache directory
    cache_subdir="hdspikes",     # Subdirectory with HDF5 files
    save_path="./my_shd_data.mat",  # Custom save location
    batch_size=128,              # Different batch size
    nb_steps=200,                # More time steps
    nb_units=700,                # Number of units
    max_time=1.4                 # Maximum time duration
)
```

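**Parameter reference:**
- `batch_size`: Processing batch size. It affects memory usage and, because only full batches are converted, how many trailing samples of each split are dropped
- `nb_steps`: Number of time bins for temporal discretization
- `nb_units`: Number of input units/neurons; SHD has 700 input channels, so this must be at least 700
- `max_time`: Time span in seconds covered by the bins; spikes at or after `max_time` are merged into the last bin
- `save_path`: Custom path for the output MAT file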

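The generated MAT file contains:
- `X`: Binary `uint8` spike raster of shape `(total_samples, nb_units, nb_steps)`
- `Y`: Labels for the `total_samples` samples; `scipy.io.savemat` stores 1-D arrays as row vectors, so after `io.loadmat` this arrives with shape `(1, total_samples)` and should be flattened with `.ravel()`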
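
A minimal loader that accounts for the row-vector layout of `Y` might look like the following. `load_shd_mat` is a hypothetical helper, and the round trip uses a tiny synthetic file rather than the real dataset.

```python
import os
import tempfile

import numpy as np
from scipy import io


def load_shd_mat(path):
    """Return (X, Y) from a MAT file written with io.savemat(path, {'X': ..., 'Y': ...}).

    savemat stores 1-D arrays as 1 x N row vectors by default, so Y is
    flattened back to shape (N,).
    """
    data = io.loadmat(path)
    return data['X'], data['Y'].ravel()


# Round trip with a tiny synthetic file (4 samples, SHD-shaped rasters)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny_demo.mat")
    io.savemat(path, {'X': np.zeros((4, 700, 100), dtype=np.uint8),
                      'Y': np.arange(4, dtype=np.uint8)})
    raw_y_shape = io.loadmat(path)['Y'].shape
    X_demo, Y_demo = load_shd_mat(path)

print(raw_y_shape, Y_demo.shape)  # (1, 4) (4,)
```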

## Step 2: Create Specialized Timing-Based Benchmarks

Now we implement the core contribution of our research: creating specialized datasets that enable fair evaluation of temporal vs. rate-based information processing in spiking neural networks.

In [16]:
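def do_min_count(X, Y):
    """
    Equalize every neuron's spike count across samples by random subsampling.

    For each neuron f, N_f is its minimum spike count over all samples. In every
    sample, N_f of that neuron's spikes are kept (chosen uniformly at random with
    the global NumPy RNG) and the rest are deleted, so neuron f fires exactly N_f
    spikes in every sample. Neurons with N_f = 0 become silent.
    
    Parameters:
    - X: Binary spike data of shape (N, F, T)
    - Y: Labels array
    
    Returns:
    - X_min: Spike data whose per-neuron spike counts are identical across samples
    - Y: Original labels (unchanged)
    """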
    N, F, T = X.shape
    count_all = X.sum(axis=2)  # Count spikes per neuron per sample
    min_counts = count_all.min(axis=0)  # Min count for each neuron across all samples

    X_min = np.zeros_like(X)
    for f_idx in range(F):
        N_f = min_counts[f_idx]
        if N_f == 0:
            continue
        for i_idx in range(N):
            spike_times = np.where(X[i_idx, f_idx, :] == 1)[0]
            if len(spike_times) > N_f:
                # Randomly select N_f spike times
                chosen_times = np.random.choice(spike_times, size=N_f, replace=False)
                X_min[i_idx, f_idx, chosen_times] = 1
            else:
                # Count already equals N_f (it is the minimum), so keep every spike
                X_min[i_idx, f_idx, spike_times] = 1
    return X_min, Y

print("Min-count processing function defined.")

Min-count processing function defined.
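
The double loop in `do_min_count` visits every (sample, neuron) pair in Python. One way to vectorize it, sketched below under the assumption that `X` is a binary array, is to give every time bin a random key, push silent bins to $+\infty$, and keep the $N_f$ bins with the smallest keys. This selects $N_f$ spikes uniformly without replacement, like `np.random.choice`, but it draws different random numbers, so its output is not bit-identical to the loop version. `min_count_vectorized` is a hypothetical name. It allocates several `(N, F, T)` arrays, so for the full 9984 × 700 × 100 dataset you would apply it to slices along the neuron axis (each neuron's $N_f$ depends only on its own counts).

```python
import numpy as np


def min_count_vectorized(X, rng=None):
    """Keep exactly N_f randomly chosen spikes of neuron f in every sample,
    where N_f is that neuron's minimum spike count over all samples.

    Hypothetical vectorized variant of do_min_count for binary X of shape (N, F, T).
    """
    rng = np.random.default_rng(rng)
    X = np.asarray(X)
    min_counts = X.sum(axis=2).min(axis=0).astype(np.int64)  # N_f, shape (F,)
    keys = rng.random(X.shape)            # one random key per time bin
    keys[X == 0] = np.inf                 # silent bins sort last
    # Rank of each bin within its (sample, neuron) row: 0 = smallest key
    ranks = np.argsort(np.argsort(keys, axis=2), axis=2)
    keep = ranks < min_counts[None, :, None]  # the N_f smallest keys are all spikes
    return ((X != 0) & keep).astype(X.dtype)


rng = np.random.default_rng(0)
X_toy = (rng.random((50, 8, 20)) < 0.3).astype(np.uint8)
X_toy_min = min_count_vectorized(X_toy, rng=1)
# 1: every sample now has the same per-neuron count profile
print(np.unique(X_toy_min.sum(axis=2), axis=0).shape[0])
```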


In [17]:
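def create_min_count_dataset_avoid_widespread(X, Y, neuron_threshold=2, max_frac_for_neuron=0.01, max_samples_to_remove=1000):
    """
    Remove a few samples that would drag a neuron's common spike count very low, then apply do_min_count.
    
    Parameters:
    - X: Input spike data of shape (N, F, T)
    - Y: Labels array
    - neuron_threshold: A neuron is flagged if it fires fewer than this many spikes in at least one sample
    - max_frac_for_neuron: A flagged neuron's low-count samples are marked for removal only if they make up
      at most this fraction of all samples; otherwise the shortfall is treated as widespread and no samples
      are marked for that neuron
    - max_samples_to_remove: Marked samples are removed only if their total number is below this limit
    
    Returns:
    - X_processed: Min-count spike data
    - Y_processed: Corresponding labels
    """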
    N, F, T = X.shape
    counts = X.sum(axis=2)  # Count spikes per neuron per sample
    min_counts_per_neuron = counts.min(axis=0)

    # Find problematic neurons
    bad_neurons = np.where(min_counts_per_neuron < neuron_threshold)[0]
    print(f"Found {len(bad_neurons)} neurons with min_count < {neuron_threshold}.")

    if len(bad_neurons) == 0:
        return do_min_count(X, Y)

    # Identify samples to potentially remove
    samples_to_remove = set()
    for f_idx in bad_neurons:
        neuron_counts = counts[:, f_idx]
        i_bad = np.where(neuron_counts < neuron_threshold)[0]
        frac = len(i_bad) / N
        if frac <= max_frac_for_neuron:
            samples_to_remove.update(i_bad)

    # Apply sample filtering if reasonable
    if 0 < len(samples_to_remove) < max_samples_to_remove:
        keep_idxs = np.setdiff1d(np.arange(N), list(samples_to_remove))
        X_filtered = X[keep_idxs]
        Y_filtered = Y[keep_idxs]
        print(f"Removing {len(samples_to_remove)} samples.")
    else:
        X_filtered = X
        Y_filtered = Y
        print("NOT removing any samples.")

    return do_min_count(X_filtered, Y_filtered)

print("Advanced min-count processing function defined.")

Advanced min-count processing function defined.
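
Later in the pipeline, `process_shd_whole_to_datasets` re-derives which samples this function removed so that `shd_part.mat` can use the same ones. An alternative, sketched here, is to factor the filtering rule into a helper that returns the kept indices directly, so that both datasets can be built from `X_all[kept]`. `select_samples_to_keep` is a hypothetical name; its logic mirrors the code above (a neuron is flagged when its minimum count is below `neuron_threshold`, which is equivalent to at least one sample falling below the threshold).

```python
import numpy as np


def select_samples_to_keep(X, neuron_threshold=2, max_frac_for_neuron=0.01,
                           max_samples_to_remove=1000):
    """Sorted indices of the samples kept by the filtering rule in
    create_min_count_dataset_avoid_widespread (hypothetical helper)."""
    N = X.shape[0]
    below = X.sum(axis=2) < neuron_threshold         # (N, F): sample i starves neuron f
    flagged = set()
    for f_idx in np.flatnonzero(below.any(axis=0)):  # neurons whose min count < threshold
        i_bad = np.flatnonzero(below[:, f_idx])
        if len(i_bad) / N <= max_frac_for_neuron:    # rare shortfall: mark those samples
            flagged.update(i_bad.tolist())
        # otherwise the shortfall is widespread and no samples are marked
    if 0 < len(flagged) < max_samples_to_remove:
        return np.setdiff1d(np.arange(N), np.fromiter(flagged, dtype=np.int64))
    return np.arange(N)


X_toy = np.ones((200, 3, 10), dtype=np.uint8)  # every neuron fires 10 spikes
X_toy[5, 0, :] = 0                             # neuron 0 silent in 1 of 200 samples
X_toy[:, 1, :] = 0                             # neuron 1 silent everywhere (widespread)
kept = select_samples_to_keep(X_toy)
print(len(kept), 5 in kept)  # 199 False
```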


In [18]:
def balance_dataset(X, Y):
    """
    Balance the dataset by equalizing the number of samples per class.
    
    Parameters:
    - X: Input spike data of shape (N, F, T)
    - Y: Labels array
    
    Returns:
    - X_balanced: Balanced spike data
    - Y_balanced: Balanced labels
    """
    N, num_neurons, T = X.shape
    Y_flat = Y.ravel()

    print(f"Input dataset: X.shape=({N},{num_neurons},{T}), Y.shape=({len(Y_flat)})")

    # Count samples per class
    unique_labels = np.unique(Y_flat)
    counts_per_class = {c: np.sum(Y_flat == c) for c in unique_labels}
    min_count = min(counts_per_class.values())

    print("--- Sample counts per class ---")
    for c in sorted(unique_labels):
        print(f"Class {c}: {counts_per_class[c]} samples")
    print(f"Using min_count = {min_count}")

    # Balance classes
    X_list = []
    Y_list = []
    for c in sorted(unique_labels):
        idxs_c = np.where(Y_flat == c)[0]
        np.random.shuffle(idxs_c)
        selected = idxs_c[:min_count]
        X_list.append(X[selected])
        Y_list.append(Y_flat[selected])

    X_balanced = np.concatenate(X_list, axis=0)
    Y_balanced = np.concatenate(Y_list, axis=0)

    # Shuffle the balanced dataset
    perm = np.random.permutation(len(Y_balanced))
    X_balanced = X_balanced[perm]
    Y_balanced = Y_balanced[perm]

    print(f"Final balanced shape: X={X_balanced.shape}, Y={Y_balanced.shape}")
    return X_balanced, Y_balanced

print("Class balancing function defined.")

Class balancing function defined.
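
`balance_dataset` undersamples every class to the size of the smallest one using the global NumPy random state, so the selection changes between runs unless `np.random.seed` is called first. The sketch below performs the same undersampling through an explicit `np.random.Generator` and returns indices instead of copies, which also makes it easy to apply one selection to several arrays. `balanced_indices` is a hypothetical helper, not part of the original pipeline.

```python
import numpy as np


def balanced_indices(Y, rng=None):
    """Shuffled indices that undersample every class to the smallest class size.

    Hypothetical helper mirroring balance_dataset, seeded via an explicit Generator.
    """
    rng = np.random.default_rng(rng)
    Y = np.asarray(Y).ravel()
    classes, counts = np.unique(Y, return_counts=True)
    n_per_class = counts.min()
    picked = [rng.choice(np.flatnonzero(Y == c), size=n_per_class, replace=False)
              for c in classes]
    return rng.permutation(np.concatenate(picked))


Y_toy = np.array([0] * 5 + [1] * 3 + [2] * 7)
idx = balanced_indices(Y_toy, rng=0)
print(np.bincount(Y_toy[idx]))  # [3 3 3]
```

Applying it is then a matter of indexing, e.g. `X_bal, Y_bal = X[idx], Y[idx]`.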


In [25]:
def process_shd_whole_to_datasets(whole_mat_path, output_dir=None, 
                                  neuron_threshold=2, max_frac_for_neuron=0.01, 
                                  max_samples_to_remove=2000):
    """
    Process shd_whole.mat to generate shd_part.mat and shd_norm.mat files.
    
    Parameters:
    - whole_mat_path: Path to the shd_whole.mat file
    - output_dir: Output directory (if None, uses same directory as input)
    - neuron_threshold: Minimum spike count threshold per neuron
    - max_frac_for_neuron: Maximum fraction of samples to remove for any neuron
    - max_samples_to_remove: Maximum total samples to remove
    
    Returns:
    - dict: Paths to generated files
    """
    
    print(" Processing shd_whole.mat to generate part and norm datasets...")
    print("=" * 70)
    
    # Set output directory
    if output_dir is None:
        output_dir = os.path.dirname(whole_mat_path)
    
    # Define output paths
    part_path = os.path.join(output_dir, "shd_part.mat")
    norm_path = os.path.join(output_dir, "shd_norm.mat")
    
    try:
        # Load the whole dataset
        print(f" Loading data from: {whole_mat_path}")
        if not os.path.exists(whole_mat_path):
            raise FileNotFoundError(f"File not found: {whole_mat_path}")
            
        data = io.loadmat(whole_mat_path)
        X_all = data['X']
        Y_all = data['Y'].ravel()
        
        print(f"   Loaded dataset:")
        print(f"   X shape: {X_all.shape}")
        print(f"   Y shape: {Y_all.shape}")
        print(f"   Unique labels: {len(np.unique(Y_all))}")
        
        # Step 1: Apply min-count processing and remove problematic samples
        print(f"\n  Step 1: Applying min-count processing...")
        X_min, Y_min = create_min_count_dataset_avoid_widespread(
            X_all, Y_all,
            neuron_threshold=neuron_threshold,
            max_frac_for_neuron=max_frac_for_neuron,
            max_samples_to_remove=max_samples_to_remove
        )
        
        # Step 2: Remove neurons that have zero activity across all samples and time
        print(f"\n Step 2: Removing inactive neurons...")
        sum_over_samples_time = X_min.sum(axis=(0, 2))
        zero_mask = (sum_over_samples_time == 0)
        non_zero_mask = ~zero_mask
        
        print(f"   Found {zero_mask.sum()} inactive neurons out of {len(zero_mask)}")
        
        # Create normalized dataset (with min-count + removed inactive neurons)
        X_norm = X_min[:, non_zero_mask, :]
        print(f"   After removing inactive neurons: X_norm.shape = {X_norm.shape}")
        
        # Create part dataset (original spikes but same active neurons and same removed samples as norm)
        X_part_unbalanced = X_all[:, non_zero_mask, :]
        
        # Important: Apply the same sample filtering to part dataset as was applied to norm dataset
        # We need to use the same samples that were kept for the norm dataset
        if X_min.shape[0] != X_all.shape[0]:
            # Some samples were removed during min-count processing
            print(f"   Applying same sample filtering to part dataset...")
            print(f"   Original samples: {X_all.shape[0]}, After min-count filtering: {X_min.shape[0]}")
            print(f"   Samples removed: {X_all.shape[0] - X_min.shape[0]}")
            
            # Re-derive which samples min-count processing removed by repeating its
            # selection rule here; this copy must stay in sync with
            # create_min_count_dataset_avoid_widespread.
            N, F, T = X_all.shape
            counts = X_all.sum(axis=2)
            min_counts_per_neuron = counts.min(axis=0)
            bad_neurons = np.where(min_counts_per_neuron < neuron_threshold)[0]
            
            samples_to_remove = set()
            if len(bad_neurons) > 0:
                for f_idx in bad_neurons:
                    neuron_counts = counts[:, f_idx]
                    i_bad = np.where(neuron_counts < neuron_threshold)[0]
                    frac = len(i_bad) / N
                    if frac <= max_frac_for_neuron:
                        samples_to_remove.update(i_bad)
            
            if 0 < len(samples_to_remove) < max_samples_to_remove:
                kept_indices = np.setdiff1d(np.arange(N), list(samples_to_remove))
            else:
                kept_indices = np.arange(N)
            
            # The re-derived selection must reproduce the min-count filtering exactly;
            # otherwise part and norm would be built from different recordings.
            if len(kept_indices) != X_min.shape[0]:
                raise RuntimeError(
                    f"Re-derived filtering keeps {len(kept_indices)} samples, "
                    f"but min-count processing kept {X_min.shape[0]}"
                )
            X_part_filtered = X_part_unbalanced[kept_indices]
            Y_part_filtered = Y_all[kept_indices]
            print(f"   Applied same sample filtering: removed {N - len(kept_indices)} samples")
            print(f"   Part dataset after filtering: {X_part_filtered.shape}")
        else:
            # No samples were removed
            X_part_filtered = X_part_unbalanced
            Y_part_filtered = Y_all
        
        # Step 3: Apply class balancing
        print(f"\n Step 3: Applying class balancing...")
        
        print(f"   Balancing part dataset...")
        X_part, Y_part = balance_dataset(X_part_filtered, Y_part_filtered)
        
        print(f"   Balancing norm dataset...")
        X_norm_balanced, Y_norm = balance_dataset(X_norm, Y_min)
        
        # Step 4: Save the processed datasets
        print(f"\n Step 4: Saving specialized benchmark datasets...")
        
        # Save part dataset
        print(f"   Saving part dataset to: {part_path}")
        io.savemat(part_path, {"X": X_part, "Y": Y_part})
        part_size_mb = os.path.getsize(part_path) / (1024 * 1024)
        print(f"    shd_part.mat saved ({part_size_mb:.2f} MB)")
        
        # Save norm dataset  
        print(f"   Saving norm dataset to: {norm_path}")
        io.savemat(norm_path, {"X": X_norm_balanced, "Y": Y_norm})
        norm_size_mb = os.path.getsize(norm_path) / (1024 * 1024)
        print(f"    shd_norm.mat saved ({norm_size_mb:.2f} MB)")
        
        # Summary
        print(f"\n Processing Summary:")
        print(f"   Original dataset: {X_all.shape}")
        print(f"   After sample filtering: {X_part_filtered.shape}")
        print(f"   Part dataset (balanced): {X_part.shape}")
        print(f"   Norm dataset (balanced): {X_norm_balanced.shape}")
        print(f"   Active neurons: {non_zero_mask.sum()}/{len(non_zero_mask)}")
        print(f"   Samples removed: {X_all.shape[0] - X_part_filtered.shape[0]}")
        
        # Verify that both datasets have the same number of samples after balancing
        if X_part.shape[0] == X_norm_balanced.shape[0]:
            print(f"    Both datasets have same sample count after balancing: {X_part.shape[0]}")
        else:
            print(f"     Different sample counts - Part: {X_part.shape[0]}, Norm: {X_norm_balanced.shape[0]}")
        
        return {
            "part_path": part_path,
            "norm_path": norm_path,
            "part_shape": X_part.shape,
            "norm_shape": X_norm_balanced.shape,
            "active_neurons": non_zero_mask.sum()
        }
        
    except Exception as e:
        print(f" Error during processing: {str(e)}")
        raise

print(" Specialized dataset processing function defined.")

 Specialized dataset processing function defined.


### Execute Specialized Dataset Generation

Now let's create the two specialized benchmark datasets from our baseline `shd_whole.mat`:

In [26]:
# Set random seed for reproducibility
np.random.seed(42)

# Configuration parameters
neuron_threshold = 2
max_frac_for_neuron = 0.01
max_samples_to_remove = 2000

# Path to the shd_whole.mat file (generated in previous step)
whole_mat_path = os.path.join(full_cache_path, "shd_whole.mat")

print(" Starting specialized dataset generation...")
print("=" * 70)
print(f"   Input file: {whole_mat_path}")
print(f"   Processing parameters:")
print(f"   neuron_threshold: {neuron_threshold}")
print(f"   max_frac_for_neuron: {max_frac_for_neuron}")
print(f"   max_samples_to_remove: {max_samples_to_remove}")

try:
    # Check if the whole mat file exists
    if not os.path.exists(whole_mat_path):
        print(f" File not found: {whole_mat_path}")
        print("  Please run the baseline MAT file generation step first!")
    else:
        # Process the dataset
        results = process_shd_whole_to_datasets(
            whole_mat_path=whole_mat_path,
            output_dir=None,  # Will save in the same directory
            neuron_threshold=neuron_threshold,
            max_frac_for_neuron=max_frac_for_neuron,
            max_samples_to_remove=max_samples_to_remove
        )
        
        print("=" * 70)
        print("  Specialized dataset generation completed!")
        print(f"  Generated timing-based benchmark files:")
        print(f"     shd_part.mat: {results['part_path']}")
        print(f"     shd_norm.mat: {results['norm_path']}")
        
        # Verify the generated files
        for name, path in [("Part", results['part_path']), ("Norm", results['norm_path'])]:
            if os.path.exists(path):
                file_size_mb = os.path.getsize(path) / (1024 * 1024)
                print(f"    {name} dataset: {file_size_mb:.2f} MB")
                
                # Load and verify
                try:
                    verify_data = io.loadmat(path)
                    print(f"      X shape: {verify_data['X'].shape}")
                    print(f"      Y shape: {verify_data['Y'].shape}")
                    print(f"      Unique labels: {len(np.unique(verify_data['Y']))}")
                except Exception as e:
                    print(f"        Could not verify file: {str(e)}")
            else:
                print(f"     {name} dataset: File not created")
        
        print(f"\n Final Statistics:")
        print(f"   Active neurons used: {results['active_neurons']}")
        print(f"   Part dataset shape: {results['part_shape']}")
        print(f"   Norm dataset shape: {results['norm_shape']}")
        
except Exception as e:
    print(f" Error during specialized dataset generation: {str(e)}")
    print("Please check the input file and parameters.")

 Starting specialized dataset generation...
   Input file: ./data/hdspikes/shd_whole.mat
   Processing parameters:
   neuron_threshold: 2
   max_frac_for_neuron: 0.01
   max_samples_to_remove: 2000
 Processing shd_whole.mat to generate part and norm datasets...
 Loading data from: ./data/hdspikes/shd_whole.mat
   Loaded dataset:
   X shape: (9984, 700, 100)
   Y shape: (9984,)
   Unique labels: 20

  Step 1: Applying min-count processing...
Found 691 neurons with min_count < 2.
Removing 1037 samples.

 Step 2: Removing inactive neurons...
   Found 476 inactive neurons out of 700
   After removing inactive neurons: X_norm.shape = (8947, 224, 100)
   Applying same sample filtering to part dataset...
   Original samples

### Dataset Processing Summary

The processing pipeline produces three datasets:

1. **`shd_whole.mat`**: Complete combined dataset (training + test)
   - Contains all original spike data
   - No preprocessing applied
   - Used as the base for further processing

2. **`shd_part.mat`**: Partial dataset with class balancing
   - Original spike patterns preserved
   - **Same sample filtering as norm dataset applied**
   - Inactive neurons removed
   - Classes balanced (equal samples per class)
   - **Should have same sample count as norm dataset**
   - Suitable for standard training

3. **`shd_norm.mat`**: Normalized dataset with min-count processing
   - Min-count spike normalization applied
   - Inactive neurons removed
   - Classes balanced
   - Problematic samples filtered out
   - Suitable for evaluating timing-based computation, since spike counts no longer differ between samples

**Processing Steps:**
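1. **Min-count processing**: Reduces each neuron's spike count in every sample to that neuron's minimum count across samples, so spike counts no longer vary between samples (applies to norm only)
2. **Sample filtering**: Removes samples in which a neuron fires fewer than `neuron_threshold` spikes, when such samples are rare for that neuron (applies to both datasets equally)
3. **Neuron filtering**: Removes neurons left with zero activity after min-count processing; the same neuron set is kept in both datasets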
4. **Class balancing**: Equalizes the number of samples per class (applies to both datasets)
5. **Shuffling**: Randomizes sample order (applies to both datasets)
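
Min-count processing can be sketched as below. This is a hypothetical reimplementation, not the notebook's `create_min_count_dataset_avoid_widespread` (whose body is not shown in this section). It assumes binned spike counts of shape `(samples, neurons, time bins)` and, as an assumption, keeps the earliest spikes of each neuron up to its budget; the actual function may choose which spikes to keep differently.

```python
import numpy as np

def min_count_normalize(X):
    """Hypothetical sketch of min-count normalization (not the notebook's function).

    X: array of shape (N, F, T) with non-negative integer spike counts per time bin.
    Each neuron f is trimmed to m_f spikes in every sample, where m_f is its minimum
    total count over all samples. Assumption: the earliest m_f spikes are kept.
    """
    X = np.asarray(X, dtype=np.int64)        # signed ints avoid unsigned underflow below
    counts = X.sum(axis=2)                   # (N, F) spikes per sample and neuron
    m = counts.min(axis=0)                   # (F,) minimum count of each neuron
    before = np.cumsum(X, axis=2) - X        # spikes that occurred in earlier bins
    # Each bin may keep only the spikes still allowed by the budget m_f.
    allowed = np.clip(m[None, :, None] - before, 0, None)
    return np.minimum(X, allowed)

# Toy example: 2 samples, 2 neurons, 3 time bins.
X_toy = np.array([[[1, 0, 2], [0, 1, 0]],
                  [[0, 1, 0], [1, 1, 1]]])
X_toy_norm = min_count_normalize(X_toy)
# Every (sample, neuron) pair now carries exactly one spike.
per_neuron_counts = X_toy_norm.sum(axis=2)
```

Neurons whose minimum count is zero end up silent in every sample under this rule, which is why inactive neurons are removed in the following step.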

**Consistency Between Variants:**
- Both `part` and `norm` datasets are built from the same filtered samples and the same active neurons before class balancing
- This keeps the comparison between the two variants consistent
- Both datasets therefore have identical sample counts after balancing (5460 samples each in the run shown below)
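
Equal sample counts do not by themselves guarantee that `part` and `norm` contain the same recordings: depending on how `balance_dataset` picks samples (its implementation is not shown in this section), the two independent calls may select different recordings within a class. One way to guarantee alignment, sketched here under the assumption that both arrays are row-aligned with the same label vector, is to draw the balanced indices once and apply them to both. `balanced_indices` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def balanced_indices(y, rng):
    """Hypothetical helper: indices giving an equal number of samples per class.

    Every class is downsampled to the size of the smallest class, and the
    combined index list is shuffled so classes are interleaved.
    """
    y = np.asarray(y).ravel()
    classes, counts = np.unique(y, return_counts=True)
    n_per_class = counts.min()
    picked = [rng.choice(np.flatnonzero(y == c), size=n_per_class, replace=False)
              for c in classes]
    idx = np.concatenate(picked)
    rng.shuffle(idx)
    return idx

# Usage sketch: one draw, applied to both row-aligned variants.
rng = np.random.default_rng(42)
y = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])
X_part_demo = np.arange(9)[:, None] * 10   # stand-ins for the real arrays
X_norm_demo = np.arange(9)[:, None]
idx = balanced_indices(y, rng)
X_part_bal, X_norm_bal, y_bal = X_part_demo[idx], X_norm_demo[idx], y[idx]
```

Because the same `idx` indexes both arrays, row `i` of the balanced `part` and `norm` datasets always refers to the same original recording.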

**Customization Options:**
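- `neuron_threshold`: Spike count below which a neuron's activity in a sample counts as too low (set to 2 here)
- `max_frac_for_neuron`: Largest fraction of all samples that may be removed because of a single neuron; if a neuron is below threshold in more samples than this, those samples are kept (set to 0.01 here)
- `max_samples_to_remove`: Cap on the total number of removed samples; if the candidate set is not smaller than this, no samples are removed (set to 2000 here)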
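
The sample-removal rule controlled by these three parameters can be factored into one helper that returns the indices of kept samples. Applying that single index array to both `X_all` and `Y_all` would remove the need to re-derive the selection later. `find_kept_samples` is a hypothetical name; its logic mirrors the re-derivation inside `process_shd_whole_to_datasets`, not necessarily every detail of `create_min_count_dataset_avoid_widespread`.

```python
import numpy as np

def find_kept_samples(X, neuron_threshold=2, max_frac_for_neuron=0.01,
                      max_samples_to_remove=2000):
    """Hypothetical helper mirroring the sample-removal rule used in this notebook.

    A neuron is 'bad' if some sample gives it fewer than `neuron_threshold` spikes.
    Its low-count samples are marked for removal only when they make up at most
    `max_frac_for_neuron` of all samples. The removal is applied only if the total
    number of marked samples is positive and below `max_samples_to_remove`.
    """
    N = X.shape[0]
    counts = X.sum(axis=2)                       # (N, F) spikes per sample and neuron
    low = counts < neuron_threshold              # (N, F) low-count flags
    frac_low = low.mean(axis=0)                  # fraction of low samples per neuron
    rare = low.any(axis=0) & (frac_low <= max_frac_for_neuron)
    to_remove = np.flatnonzero(low[:, rare].any(axis=1))
    if 0 < len(to_remove) < max_samples_to_remove:
        return np.setdiff1d(np.arange(N), to_remove)
    return np.arange(N)
```

For example, `kept = find_kept_samples(X_all, neuron_threshold, max_frac_for_neuron, max_samples_to_remove)` followed by `X_all[kept]` and `Y_all[kept]` yields the filtered input for both variants from one selection.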

## Step 3: Verify Generated Timing-Based Benchmarks

Let's load and verify our specialized timing-based benchmark datasets to ensure they meet our research requirements.
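
Beyond printing shapes, the key properties can be checked directly. The sketch below verifies that every class has the same number of samples and, optionally, that each neuron has the same spike count in every sample, which is the property that removes spike-count information from `shd_norm.mat`. `check_benchmark` is a hypothetical helper; it accepts labels in the `(1, N)` shape that `scipy.io.loadmat` returns.

```python
import numpy as np

def check_benchmark(X, Y, counts_must_be_constant=False):
    """Hypothetical sanity checks for one loaded benchmark (X, Y) pair."""
    y = np.ravel(Y)                          # loadmat returns labels with shape (1, N)
    if X.shape[0] != y.shape[0]:
        raise ValueError(f"X has {X.shape[0]} samples but Y has {y.shape[0]}")
    _, per_class = np.unique(y, return_counts=True)
    if np.any(per_class != per_class[0]):
        raise ValueError(f"classes are not balanced: {per_class.tolist()}")
    if counts_must_be_constant:
        counts = X.sum(axis=2)               # (N, F) spikes per sample and neuron
        if np.any(counts != counts[0]):
            raise ValueError("per-neuron spike counts vary across samples")
    return True
```

For instance, `d = io.loadmat(norm_path)` followed by `check_benchmark(d['X'], d['Y'], counts_must_be_constant=True)` turns the expectations stated in the summary above into checks that fail loudly.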

In [23]:
# Load and verify the specialized datasets
data_dir = full_cache_path
part_path = os.path.join(data_dir, 'shd_part.mat')
norm_path = os.path.join(data_dir, 'shd_norm.mat')

print(" Loading specialized timing-based benchmark datasets...")
print("=" * 70)

# Verify files exist
files_to_check = [
    ("Part (spike pattern)", part_path),
    ("Norm (normalized)", norm_path)
]

for name, filepath in files_to_check:
    if os.path.exists(filepath):
        file_size_mb = os.path.getsize(filepath) / (1024 * 1024)
        print(f" {name}: {file_size_mb:.2f} MB")
        
        # Load the data
        data = io.loadmat(filepath)
        
        print(f"    Dataset shape: X={data['X'].shape}, Y={data['Y'].shape}")
        print(f"    Number of classes: {len(np.unique(data['Y']))}")
        print(f"    Total samples: {data['X'].shape[0]}")
        
        # Check spike statistics for this dataset
        total_spikes = np.sum(data['X'])
        avg_spikes_per_sample = total_spikes / data['X'].shape[0]
        print(f"    Total spikes: {total_spikes:,}")
        print(f"    Average spikes per sample: {avg_spikes_per_sample:.2f}")
        
        # Check class distribution
        unique_labels, counts = np.unique(data['Y'], return_counts=True)
        print(f"    Class distribution: {dict(zip(unique_labels.flatten(), counts))}")
        
        print()
    else:
        print(f" {name}: File not found at {filepath}")

print("  Research Insight:")
print("   • Part dataset: Contains original spike patterns (temporal information)")
print("   • Norm dataset: Min-count normalized (removes rate information)")
print("   • Both datasets are class-balanced for fair comparison")
print("   • Same problematic samples removed from both variants")

 Loading specialized timing-based benchmark datasets...
 Part (spike pattern): 116.64 MB
    Dataset shape: X=(5460, 224, 100), Y=(1, 5460)
    Number of classes: 20
    Total samples: 5460
    Total spikes: 17,584,343
    Average spikes per sample: 3220.58
    Class distribution: {0: 273, 1: 273, 2: 273, 3: 273, 4: 273, 5: 273, 6: 273, 7: 273, 8: 273, 9: 273, 10: 273, 11: 273, 12: 273, 13: 273, 14: 273, 15: 273, 16: 273, 17: 273, 18: 273, 19: 273}

 Norm (normalized): 116.64 MB
    Dataset shape: X=(5460, 224, 100), Y=(1, 5460)
    Number of classes: 20
    Total samples: 5460
    Total spikes: 2,571,660
    Average spikes per sample: 471.00
    Class distribution: {0: 273, 1: 273, 2: 273, 3: 273, 4: 273, 5: 273, 6: 273, 7: 273, 8: 273, 9: 273, 10: 273, 11: 273, 12: 273, 13: 273, 14: 273, 15: 273, 16: 273, 17: 273, 18: 273, 19: 273}

  Research Insight:
   • Part dataset: Contains original spike patterns (temporal information)
   • Norm dataset: Min-count normalized (removes rate info

In [24]:
# Comparative analysis between the two benchmark datasets
print("\n Comparative Analysis of Timing-Based Benchmarks")
print("=" * 70)

if os.path.exists(part_path) and os.path.exists(norm_path):
    # Load both datasets
    part_data = io.loadmat(part_path)
    norm_data = io.loadmat(norm_path)
    
    X_part, Y_part = part_data['X'], part_data['Y']
    X_norm, Y_norm = norm_data['X'], norm_data['Y']
    
    # Verify identical shapes and sample counts
    print(f"   Dataset Shape Comparison:")
    print(f"   Part dataset: {X_part.shape}")
    print(f"   Norm dataset: {X_norm.shape}")
    
    if X_part.shape == X_norm.shape:
        print(f"    Both datasets have identical shapes")
    else:
        print(f"     Different shapes detected")
    
    # Spike count comparison
    part_total_spikes = np.sum(X_part)
    norm_total_spikes = np.sum(X_norm)
    
    print(f"\n Spike Count Analysis:")
    print(f"   Part dataset total spikes: {part_total_spikes:,}")
    print(f"   Norm dataset total spikes: {norm_total_spikes:,}")
    print(f"   Spike reduction ratio: {norm_total_spikes/part_total_spikes:.3f}")
    
    # Per-sample spike distribution
    part_spikes_per_sample = np.sum(X_part, axis=(1,2))
    norm_spikes_per_sample = np.sum(X_norm, axis=(1,2))
    
    print(f"\n Spike Distribution Per Sample:")
    print(f"   Part dataset - Mean: {np.mean(part_spikes_per_sample):.2f}, Std: {np.std(part_spikes_per_sample):.2f}")
    print(f"   Norm dataset - Mean: {np.mean(norm_spikes_per_sample):.2f}, Std: {np.std(norm_spikes_per_sample):.2f}")
    
    # Research implications
    print(f"\n Research Implications:")
    print(f"   • Part dataset preserves temporal spike patterns and rate information")
    print(f"   • Norm dataset removes rate information while preserving timing patterns")
    print(f"   • Performance differences indicate reliance on rate vs. temporal coding")
    print(f"   • Both datasets enable fair comparison with identical preprocessing")
    
else:
    print("  Cannot perform comparative analysis - datasets not found")
    print("   Please ensure both datasets have been generated successfully")


 Comparative Analysis of Timing-Based Benchmarks
   Dataset Shape Comparison:
   Part dataset: (5460, 224, 100)
   Norm dataset: (5460, 224, 100)
    Both datasets have identical shapes

 Spike Count Analysis:
   Part dataset total spikes: 17,584,343
   Norm dataset total spikes: 2,571,660
   Spike reduction ratio: 0.146

 Spike Distribution Per Sample:
   Part dataset - Mean: 3220.58, Std: 897.56
   Norm dataset - Mean: 471.00, Std: 0.00

 Research Implications:
   • Part dataset preserves temporal spike patterns and rate information
   • Norm dataset removes rate information while preserving timing patterns
   • Performance differences indicate reliance on rate vs. temporal coding
   • Both datasets enable fair comparison with identical preprocessing
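
The spike totals above can be cross-checked by hand. If min-count processing gives neuron $f$ the same count $m_f$ in every sample (the zero standard deviation confirms this at least for per-sample totals), each normalized sample carries $\sum_f m_f = 471$ spikes, and the reported figures follow:

$$
\begin{aligned}
S_{\text{norm}} &= 5460 \times 471 = 2{,}571{,}660,\\
\bar{s}_{\text{part}} &= \frac{17{,}584{,}343}{5460} \approx 3220.58,\\
\frac{S_{\text{norm}}}{S_{\text{part}}} &= \frac{2{,}571{,}660}{17{,}584{,}343} \approx \frac{471}{3220.58} \approx 0.146.
\end{aligned}
$$

So the normalized benchmark retains roughly 14.6% of the original spikes.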


## Conclusion: Enabling Fair Assessment of Temporal vs. Rate-Based Processing

This notebook provides a complete pipeline for generating specialized timing-based benchmarks from the SHD dataset, enabling researchers to conduct fair assessments of temporal versus rate-based information processing in spiking neural networks.

### 🎯 Key Contributions

1. **Baseline Dataset Generation (`shd_whole.mat`)**
   - Combines training and test data into a unified format
   - Preserves original temporal structure
   - Provides foundation for specialized processing

2. **Timing-Based Benchmark Datasets**
   - **`shd_part.mat`**: Preserves both temporal patterns and rate information
   - **`shd_norm.mat`**: Eliminates rate information via min-count normalization
   - Both datasets share the same sample filtering, neuron selection, and class balancing; only `norm` is min-count normalized

3. **Reproducible Research Pipeline**
   - Standardized processing parameters
   - Consistent sample filtering across dataset variants
   - Class balancing for unbiased evaluation
   - Shape, class-balance, and spike-count checks on the generated files

### 🔬 Research Applications

- **Temporal Coding Research**: Compare model performance between `part` and `norm` datasets
- **Rate vs. Timing Analysis**: The accuracy gap between `part` and `norm` suggests how much a model relies on spike counts versus spike timing
- **Neuromorphic Benchmarking**: Standardized datasets for fair model comparison
- **Algorithm Development**: Controlled environment for testing temporal processing algorithms

### 📈 Usage Recommendations

1. **Baseline Evaluation**: Train models on `shd_part.mat` for standard performance assessment
2. **Temporal Sensitivity Testing**: Compare performance between `part` and `norm` datasets
3. **Algorithm Comparison**: Use identical preprocessing to ensure fair model comparisons
4. **Research Publication**: Reference this pipeline and its parameter values so that results can be reproduced
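
For temporal sensitivity testing, a useful reference point is a classifier that sees only spike counts. The sketch below is a hypothetical nearest-centroid classifier on per-neuron counts, not a method from this notebook or the associated paper. On `shd_part.mat` it measures how much class information counts alone carry. On `shd_norm.mat`, if every sample has identical per-neuron counts as its construction intends, any count-only classifier must predict the same class for every sample, so accuracy above chance on `norm` has to come from spike timing.

```python
import numpy as np

def count_features(X):
    """Per-neuron spike counts; all timing information is discarded."""
    return X.sum(axis=2).astype(np.float64)      # (N, F)

def nearest_centroid_accuracy(X_train, y_train, X_test, y_test):
    """Hypothetical count-only baseline: nearest class centroid in count space."""
    y_train, y_test = np.ravel(y_train), np.ravel(y_test)
    F_train, F_test = count_features(X_train), count_features(X_test)
    classes = np.unique(y_train)
    centroids = np.stack([F_train[y_train == c].mean(axis=0) for c in classes])
    # Squared Euclidean distance from every test sample to every centroid.
    d = ((F_test[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    pred = classes[np.argmin(d, axis=1)]
    return float(np.mean(pred == y_test))
```

`np.ravel` handles the `(1, N)` label shape returned by `scipy.io.loadmat`. The broadcast distance computation allocates an array of shape `(N_test, classes, neurons)`, which for 5460 × 20 × 224 is roughly 200 MB of float64; batching the test set keeps memory lower if needed.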

### 🚀 Future Extensions

- Support for additional neuromorphic datasets
- Extended normalization techniques
- Advanced temporal analysis metrics
- Integration with popular neuromorphic frameworks

---

**Citation**: If you use these timing-based benchmarks in your research, please cite the associated publication and acknowledge this processing pipeline.

**Reproducibility**: The random steps are seeded (`np.random.seed(42)`), so rerunning the pipeline with the same parameters should reproduce the same files across different research groups.