In [1]:
"""
=============================================================================
SPIKING HEIDELBERG DIGITS (SHD) CLASSIFICATION WITH SNNTORCH
=============================================================================
Step 1: Environment Setup & Reproducibility Configuration
Author: AI Research Engineer
Dataset: Spiking Heidelberg Digits (SHD)
Framework: snntorch + PyTorch + Tonic
=============================================================================
"""

# ============================================================================
# PART A: Install Required Libraries
# ============================================================================
# Section banner for the dependency-installation step.
print("\n".join(["=" * 80, "INSTALLING DEPENDENCIES", "=" * 80]))

import sys
import subprocess

def install_package(package_name, import_name=None):
    """Install ``package_name`` with pip unless it is already importable.

    Args:
        package_name: Distribution name passed to ``pip install``.
        import_name: Module name used to probe whether the package is already
            present. Defaults to ``package_name``; pass it explicitly when the
            import name differs from the distribution name.

    Raises:
        subprocess.CalledProcessError: If ``pip install`` exits with a
            non-zero status.
    """
    import importlib

    if import_name is None:
        import_name = package_name

    try:
        importlib.import_module(import_name)
    except ImportError:
        print(f"Installing {package_name}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", package_name])
        # The import system caches directory listings of sys.path entries, so a
        # package pip just wrote may be missed by a later `import` in this same
        # process unless those caches are dropped.
        importlib.invalidate_caches()
        print(f"✓ {package_name} installed successfully")
    else:
        print(f"✓ {package_name} already installed")

# Third-party packages used below; each call is a no-op when already importable.
install_package(package_name="snntorch")
install_package(package_name="tonic")
install_package(package_name="h5py")
install_package(package_name="celluloid")  # animation helper

print(f"\n{'=' * 80}\nALL DEPENDENCIES INSTALLED\n{'=' * 80}\n")

# ============================================================================
# PART B: Import All Required Modules
# ============================================================================
# Bring every module the rest of the notebook uses into scope. snntorch,
# tonic and h5py are third-party packages this notebook pip-installs earlier.
print("Importing libraries...")

# Standard Libraries
import os
import random
import warnings
from pathlib import Path
from typing import Tuple, Dict, List, Optional

# Numerical & Data Processing
import numpy as np
import pandas as pd
import h5py

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.gridspec import GridSpec
from IPython.display import clear_output  # requires the IPython package

# PyTorch & Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

# SNNTorch - Spiking Neural Networks
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import utils  # NOTE: binds the generic name `utils` at module level

# Tonic - Neuromorphic Data Loading
import tonic
from tonic import DiskCachedDataset

# Training utilities
from tqdm.auto import tqdm
import time

# Suppress warnings for cleaner output.
# NOTE(review): this ignores every warning category for the rest of the
# session (e.g. NumPy RuntimeWarnings from statistics on empty arrays). The
# imports above already ran, so their warnings were not affected. Consider
# narrowing the filter by category/module.
warnings.filterwarnings('ignore')

print("✓ All libraries imported successfully\n")

# ============================================================================
# PART C: Set Global Random Seeds for Reproducibility
# ============================================================================
# Single source of truth for RNG seeding; the banner reads it so the printed
# value can never drift from the seed actually used.
SEED = 42

print("=" * 80)
print(f"CONFIGURING REPRODUCIBILITY (seed={SEED})")
print("=" * 80)

def set_seed(seed: int = 42):
    """Seed the notebook's RNGs and ask cuDNN for deterministic kernels.

    Seeds Python's ``random`` module, NumPy's legacy global RNG, and PyTorch
    on the CPU and on every visible CUDA device. It then disables cuDNN
    autotuning and requests deterministic cuDNN algorithms.

    This is not a guarantee of bit-exact reproducibility. CUDA ops outside
    cuDNN may still be non-deterministic unless
    ``torch.use_deterministic_algorithms(True)`` is also enabled.

    Args:
        seed: Seed value applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # For multi-GPU setups

    # Without these, cuDNN benchmarks candidate kernels and may pick a
    # non-deterministic one.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # PYTHONHASHSEED is read only at interpreter start-up. Setting it here does
    # NOT change str/bytes hash randomization in this running process; it only
    # affects Python child processes launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)

    print(f"✓ Global seed set to {seed}")
    print("✓ PyTorch configured for deterministic operations")
    print("✓ CUDNN deterministic mode enabled")
    print("✓ CUDNN benchmark mode disabled")

set_seed(seed=SEED)  # seed every RNG once, before any data or model is created

# ============================================================================
# PART D: GPU Detection & Initialization
# ============================================================================
print(f"\n{'=' * 80}\nGPU DETECTION & INITIALIZATION\n{'=' * 80}")

# Use the default CUDA device when one is visible; otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")

if not torch.cuda.is_available():
    print("⚠ WARNING: GPU not available. Training will be slow on CPU.")
    print("Consider enabling GPU acceleration in Kaggle Notebook settings.")
else:
    # Report on device 0 only, even when several GPUs are visible.
    print(f"✓ GPU Detected: {torch.cuda.get_device_name(0)}")
    print(f"✓ CUDA Version: {torch.version.cuda}")
    print(f"✓ Number of GPUs: {torch.cuda.device_count()}")
    print(f"✓ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"✓ Current GPU Memory Allocated: {torch.cuda.memory_allocated(0) / 1e9:.4f} GB")
    print(f"✓ Current GPU Memory Cached: {torch.cuda.memory_reserved(0) / 1e9:.4f} GB")

# ============================================================================
# PART E: Configure Visualization Settings
# ============================================================================
print(f"\n{'=' * 80}\nVISUALIZATION CONFIGURATION\n{'=' * 80}")

# Base look: the bundled seaborn dark-grid style plus the HUSL colour palette.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Notebook-wide figure and font defaults, applied in a single update.
plt.rcParams.update({
    'figure.figsize': (12, 6),
    'figure.dpi': 100,
    'font.size': 10,
    'axes.labelsize': 12,
    'axes.titlesize': 14,
    'legend.fontsize': 10,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
})

print("✓ Visualization settings configured")

# ============================================================================
# PART F: Define Kaggle Directory Paths
# ============================================================================
print(f"\n{'=' * 80}\nDIRECTORY STRUCTURE\n{'=' * 80}")

# Kaggle notebook locations: inputs under /kaggle/input, everything this
# notebook writes under /kaggle/working.
BASE_DIR = Path("/kaggle/working")
DATA_DIR = Path("/kaggle/input")
OUTPUT_DIR = BASE_DIR.joinpath("outputs")
MODEL_DIR = BASE_DIR.joinpath("models")
CACHE_DIR = BASE_DIR.joinpath("cache")

# Idempotently create the writable sub-directories (parents included).
for directory in (OUTPUT_DIR, MODEL_DIR, CACHE_DIR):
    directory.mkdir(parents=True, exist_ok=True)
    print(f"✓ Created/Verified: {directory}")

# ============================================================================
# PART G: System Information Summary
# ============================================================================
print(f"\n{'=' * 80}\nSYSTEM INFORMATION SUMMARY\n{'=' * 80}")

# One "label: value" line per component; print() inserts the separating space.
print("Python Version:", sys.version.split()[0])
print("PyTorch Version:", torch.__version__)
print("snntorch Version:", snn.__version__)
print("Tonic Version:", tonic.__version__)
print("NumPy Version:", np.__version__)
print("Device:", device)
print("Working Directory:", BASE_DIR)
print("Random Seed:", SEED)

print(f"\n{'=' * 80}\nENVIRONMENT SETUP COMPLETE ✓\n{'=' * 80}")
print("\nReady to proceed to data acquisition and preprocessing.\n")

INSTALLING DEPENDENCIES
Installing snntorch...
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 125.6/125.6 kB 2.8 MB/s eta 0:00:00
✓ snntorch installed successfully
Installing tonic...
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 106.2/106.2 kB 3.0 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 50.4/50.4 kB 2.2 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 131.9/131.9 kB 5.2 MB/s eta 0:00:00
✓ tonic installed successfully
✓ h5py already installed
Installing celluloid...
✓ celluloid installed successfully

ALL DEPENDENCIES INSTALLED

Importing libraries...
✓ All libraries imported successfully

CONFIGURING REPRODUCIBILITY (seed=42)
✓ Global seed set to 42
✓ PyTorch configured for deterministic operations
✓ CUDNN deterministic mode enabled
✓ CUDNN benchmark mode disabled

GPU DETECTION & INITIALIZATION
Device: cuda
✓ GPU Detected: Tesla T4
✓ CUDA Version: 12.4
✓ Number of GPUs: 2
✓ GPU Memory: 15.83 GB
✓ Current GPU Memory Allocated: 0.0000 GB
✓ Current GPU Memory

In [3]:
"""
=============================================================================
Step 2: Data Acquisition & Preprocessing
=============================================================================
Download and preprocess the Spiking Heidelberg Digits (SHD) dataset.
The SHD dataset contains spike trains from 700 input channels representing
spoken digits (0-9) in English and German (20 classes), produced by passing
the audio through a cochlear model that emits spike events.
=============================================================================
"""

# ============================================================================
# PART A: Dataset Overview & Information
# ============================================================================
print("=" * 80)
print("SPIKING HEIDELBERG DIGITS (SHD) DATASET")
print("=" * 80)

# Human-readable overview printed before download. SHD's 20 labels are the
# ten digits spoken in two languages (English and German), not two speakers.
dataset_info = """
Dataset: Spiking Heidelberg Digits (SHD)
Task: Audio digit classification (0-9 in English and German)
Classes: 20 classes (10 digits × 2 languages)
Input: Spike trains from 700 audio channels
Encoding: Cochlear model converting audio to spike events
Train samples: ~8,156 samples
Test samples: ~2,264 samples
Duration: Variable (~1 second typical)
Format: HDF5 files with spike times and neuron indices

Key Characteristics:
- Neuromorphic representation of audio
- Sparse spike events (not dense tensors)
- Temporal information encoded in spike timing
- Requires time-binning for SNN processing
"""

print(dataset_info)

# ============================================================================
# PART B: Download SHD Dataset Using Tonic
# ============================================================================
print(f"{'=' * 80}\nDOWNLOADING SHD DATASET\n{'=' * 80}")

# Raw SHD files are stored under the Kaggle working directory.
SHD_DATA_DIR = BASE_DIR.joinpath("shd_data")
SHD_DATA_DIR.mkdir(exist_ok=True)
print(f"Download location: {SHD_DATA_DIR}\n")

# Event-level (un-binned) train and test splits; constructing each tonic
# dataset is where the "Downloading ..." step happens.
print("Downloading training set...")
train_dataset_raw = tonic.datasets.SHD(save_to=str(SHD_DATA_DIR), train=True)

print("\nDownloading test set...")
test_dataset_raw = tonic.datasets.SHD(save_to=str(SHD_DATA_DIR), train=False)

print("\n✓ Dataset downloaded successfully")
print(f"✓ Train samples: {len(train_dataset_raw)}")
print(f"✓ Test samples: {len(test_dataset_raw)}")

# ============================================================================
# PART C: Inspect Raw Data Structure
# ============================================================================
print("\n" + "=" * 80)
print("INSPECTING RAW DATA STRUCTURE")
print("=" * 80)

# Take the first training recording: a structured NumPy array of events
# (its fields are listed below) together with its class label.
sample_events, sample_label = train_dataset_raw[0]

print(f"\nSample Label (Class): {sample_label}")
print(f"\nEvent Data Structure:")
print(f"Type: {type(sample_events)}")
print(f"Dtype: {sample_events.dtype}")
print(f"Shape: {sample_events.shape}")
print(f"Number of spike events: {len(sample_events)}")

print(f"\nEvent Fields:")
for field_name in sample_events.dtype.names:
    print(f"  - {field_name}: {sample_events[field_name][:5]} ...")

# 'x' holds the input-channel (neuron) index of each event, 't' its timestamp.
neuron_indices = sample_events['x']
spike_times = sample_events['t']

print(f"\nSample Statistics:")
print(f"  Neuron indices range: {neuron_indices.min()} to {neuron_indices.max()}")
# Count the distinct channels this one recording uses. "max index + 1" would
# describe neither the sensor width (a recording need not fire on the
# highest-numbered channel) nor the active set (it includes silent channels).
print(f"  Active input channels (this sample): {len(np.unique(neuron_indices))}")
print(f"  Spike times range: {spike_times.min():.2f} to {spike_times.max():.2f} μs")
print(f"  Duration: {(spike_times.max() - spike_times.min()) / 1e6:.4f} seconds")
print(f"  Total spikes: {len(spike_times)}")
# Spikes per active channel: a per-channel count, not a rate (no division by
# time).
print(f"  Mean spikes per active channel: {len(spike_times) / len(np.unique(neuron_indices)):.2f}")

# ============================================================================
# PART D: Analyze Dataset Statistics
# ============================================================================
print(f"\n{'=' * 80}\nDATASET STATISTICS ANALYSIS\n{'=' * 80}")

def analyze_dataset_stats(dataset, name="Dataset", num_samples=500, seed=None):
    """Collect per-sample event statistics over a random subset of ``dataset``.

    Draws ``min(num_samples, len(dataset))`` distinct indices. For each
    ``(events, label)`` item it records the number of events, the span
    between the first and last timestamp divided by ``1e6`` (seconds when
    ``'t'`` is in microseconds), the number of distinct ``'x'`` values, and
    the label.

    Args:
        dataset: Sized, indexable dataset whose items are ``(events, label)``;
            ``events`` must expose ``'t'`` and ``'x'`` fields.
        name: Label used in progress output.
        num_samples: Maximum number of samples to inspect.
        seed: Seed for choosing the sample indices. Defaults to the
            module-level ``SEED``.

    Returns:
        dict with keys ``'num_spikes'``, ``'duration'``, ``'active_neurons'``
        and ``'labels'``, each a list with one entry per inspected sample.
    """
    if seed is None:
        seed = SEED

    print(f"\nAnalyzing {name} ({num_samples} samples)...")

    stats = {
        'num_spikes': [],
        'duration': [],
        'active_neurons': [],
        'labels': []
    }

    # A private RandomState draws exactly the indices that
    # `np.random.seed(seed); np.random.choice(...)` would, without resetting
    # the global NumPy RNG that set_seed() configured for the whole notebook.
    rng = np.random.RandomState(seed)
    indices = rng.choice(len(dataset), min(num_samples, len(dataset)), replace=False)

    for idx in tqdm(indices, desc=f"Analyzing {name}"):
        events, label = dataset[idx]

        spike_times = events['t']
        neuron_ids = events['x']

        stats['num_spikes'].append(len(events))
        # A recording with no events has no time span; .max() on an empty
        # array would raise ValueError.
        if len(spike_times):
            duration = (spike_times.max() - spike_times.min()) / 1e6  # Convert to seconds
        else:
            duration = 0.0
        stats['duration'].append(duration)
        stats['active_neurons'].append(len(np.unique(neuron_ids)))
        stats['labels'].append(label)

    return stats

# Analyze both datasets
# Subsample both splits (500 train / 200 test recordings) for quick statistics.
train_stats = analyze_dataset_stats(train_dataset_raw, name="Train Set", num_samples=500)
test_stats = analyze_dataset_stats(test_dataset_raw, name="Test Set", num_samples=200)

# Print statistics
def print_stats(stats, name, num_classes=None):
    """Print mean ± std summaries and the label histogram of a stats dict.

    Args:
        stats: Dict with list-valued keys ``'num_spikes'``, ``'duration'``,
            ``'active_neurons'`` and ``'labels'`` (integer class ids), as
            built by ``analyze_dataset_stats``.
        name: Heading printed before the summary.
        num_classes: If given, the class histogram is padded to this many
            bins, so classes absent from a subsample still show up as 0.
            ``None`` keeps the histogram ending at the largest label seen.
    """
    # Explicit integer dtype: np.bincount rejects the float64 array that an
    # empty label list would otherwise become.
    labels = np.asarray(stats['labels'], dtype=np.int64)
    class_counts = np.bincount(labels, minlength=num_classes or 0)

    print(f"\n{name} Statistics:")
    print(f"  Spikes per sample: {np.mean(stats['num_spikes']):.1f} ± {np.std(stats['num_spikes']):.1f}")
    print(f"  Duration (seconds): {np.mean(stats['duration']):.3f} ± {np.std(stats['duration']):.3f}")
    print(f"  Active neurons: {np.mean(stats['active_neurons']):.1f} ± {np.std(stats['active_neurons']):.1f}")
    print(f"  Classes: {len(np.unique(labels))} unique labels")
    print(f"  Class distribution: {class_counts}")

# Report train and test summaries back to back for comparison.
print_stats(train_stats, name="Train Set")
print_stats(test_stats, name="Test Set")

# ============================================================================
# PART E: Define Time-Binning Transform
# ============================================================================
print(f"\n{'=' * 80}\nCONFIGURING TIME-BINNING TRANSFORM\n{'=' * 80}")

# Time-binning parameters
NUM_TIME_BINS = 100  # discrete simulation steps per sample
SENSOR_SIZE = (700, 1, 1)  # sensor geometry handed to tonic: 700 input channels

print(f"""
Time-Binning Configuration:
  Input channels: {SENSOR_SIZE[0]}
  Time bins: {NUM_TIME_BINS}
  
Rationale:
  - SNNs process data in discrete time steps
  - Time-binning converts continuous spike times into discrete bins
  - Each bin represents a simulation time step
  - Output shape: (Time_bins, Channels) = ({NUM_TIME_BINS}, {SENSOR_SIZE[0]})
""")

# NOTE(review): per tonic's ToFrame docs, n_time_bins slices each recording's
# own time span into NUM_TIME_BINS pieces, so bin width differs between
# samples; confirm this is intended rather than a fixed `time_window`.
transform = tonic.transforms.Compose(
    [tonic.transforms.ToFrame(sensor_size=SENSOR_SIZE, n_time_bins=NUM_TIME_BINS)]
)

print("✓ Time-binning transform configured")

# ============================================================================
# PART F: Create Cached Datasets
# ============================================================================
print(f"\n{'=' * 80}\nCREATING DISK-CACHED DATASETS\n{'=' * 80}")

print("""
Why Disk Caching?
  - Transform operations are expensive (time-binning for each sample)
  - Caching pre-processes data once and saves to disk
  - Dramatically speeds up training (no repeated transformations)
  - Essential for iterative experimentation
""")

# One cache directory per split, both under CACHE_DIR.
TRAIN_CACHE = CACHE_DIR.joinpath("shd_train_cache")
TEST_CACHE = CACHE_DIR.joinpath("shd_test_cache")

print(f"\nTrain cache: {TRAIN_CACHE}")
print(f"Test cache: {TEST_CACHE}")

# Wrap each raw split so its time-binned frames are persisted in its cache dir.
print("\nCreating cached training dataset...")
train_dataset = DiskCachedDataset(train_dataset_raw, transform=transform, cache_path=str(TRAIN_CACHE))

print("Creating cached test dataset...")
test_dataset = DiskCachedDataset(test_dataset_raw, transform=transform, cache_path=str(TEST_CACHE))

print("\n✓ Cached datasets created")
print(f"✓ Train samples: {len(train_dataset)}")
print(f"✓ Test samples: {len(test_dataset)}")

# ============================================================================
# PART G: Verify Preprocessed Data Shape
# ============================================================================
print("\n" + "=" * 80)
print("VERIFYING PREPROCESSED DATA")
print("=" * 80)

# Load one cached, time-binned sample
sample_frames, sample_label = train_dataset[0]

print(f"\nPreprocessed Sample:")
print(f"  Label: {sample_label}")
print(f"  Frames shape: {sample_frames.shape}")
print(f"  Frames dtype: {sample_frames.dtype}")
print(f"  Memory size: {sample_frames.nbytes / 1024:.2f} KB")

# Interpret the shape from its actual dimensions. The axis layout comes from
# tonic, so the 3-D case checks which non-time axis is the singleton instead
# of assuming axis 1 holds the channels.
print(f"\nShape interpretation:")
if len(sample_frames.shape) == 4:
    # Shape: (time_bins, channels, height, width)
    print(f"  Time bins: {sample_frames.shape[0]}")
    print(f"  Channels: {sample_frames.shape[1]}")
    print(f"  Height: {sample_frames.shape[2]}")
    print(f"  Width: {sample_frames.shape[3]}")
elif len(sample_frames.shape) == 3:
    print(f"  Time bins: {sample_frames.shape[0]}")
    if sample_frames.shape[1] == 1:
        # Shape: (time_bins, 1, channels)
        print(f"  Singleton axis: {sample_frames.shape[1]}")
        print(f"  Channels: {sample_frames.shape[2]}")
    elif sample_frames.shape[2] == 1:
        # Shape: (time_bins, channels, 1)
        print(f"  Channels: {sample_frames.shape[1]}")
        print(f"  Extra dimension: {sample_frames.shape[2]}")
    else:
        # Neither non-time axis is a singleton; report both without guessing roles.
        print(f"  Dim 1: {sample_frames.shape[1]}")
        print(f"  Dim 2: {sample_frames.shape[2]}")
elif len(sample_frames.shape) == 2:
    # Shape: (time_bins, channels)
    print(f"  Time bins: {sample_frames.shape[0]}")
    print(f"  Channels: {sample_frames.shape[1]}")
else:
    print(f"  Unexpected shape with {len(sample_frames.shape)} dimensions")

# Fold every non-time axis together so the sample is (time_bins, channels).
if len(sample_frames.shape) > 2:
    print(f"\nReshaping from {sample_frames.shape} to (time_bins, channels)...")
    sample_frames = sample_frames.reshape(sample_frames.shape[0], -1)
    print(f"  New shape: {sample_frames.shape}")

# Check spike statistics in binned data
print(f"\nBinned Data Statistics:")
print(f"  Total spikes in bins: {sample_frames.sum():.0f}")
print(f"  Non-zero bins: {(sample_frames > 0).sum():.0f}")
print(f"  Sparsity: {100 * (1 - (sample_frames > 0).sum() / sample_frames.size):.2f}%")
print(f"  Max spikes per bin: {sample_frames.max():.0f}")

# ============================================================================
# PART H: Dataset Summary
# ============================================================================
print("\n" + "=" * 80)
print("DATA PREPROCESSING COMPLETE ✓")
print("=" * 80)

# Report the frame shape the transform actually produces (read back from the
# cached dataset) instead of a hard-coded guess at tonic's axis layout.
# sample_frames may already have been flattened, so it is not reused here.
summary = f"""
Dataset Summary:
  ✓ Train samples: {len(train_dataset)}
  ✓ Test samples: {len(test_dataset)}
  ✓ Input channels: {SENSOR_SIZE[0]}
  ✓ Time bins: {NUM_TIME_BINS}
  ✓ Number of classes: 20
  ✓ Data format: Spike frames {tuple(train_dataset[0][0].shape)}
  ✓ Caching: Enabled (disk-cached for fast loading)
  
Ready for:
  → Exploratory Data Analysis (EDA)
  → DataLoader creation
  → SNN model development
"""

print(summary)

SPIKING HEIDELBERG DIGITS (SHD) DATASET

Dataset: Spiking Heidelberg Digits (SHD)
Task: Audio digit classification (0-9 in German)
Classes: 20 classes (10 digits × 2 speakers)
Input: Spike trains from 700 audio channels
Encoding: Cochlear model converting audio to spike events
Train samples: ~8,156 samples
Test samples: ~2,264 samples
Duration: Variable (~1 second typical)
Format: HDF5 files with spike times and neuron indices

Key Characteristics:
- Neuromorphic representation of audio
- Sparse spike events (not dense tensors)
- Temporal information encoded in spike timing
- Requires time-binning for SNN processing

DOWNLOADING SHD DATASET
Download location: /kaggle/working/shd_data

Downloading training set...

Downloading test set...

✓ Dataset downloaded successfully
✓ Train samples: 8156
✓ Test samples: 2264

INSPECTING RAW DATA STRUCTURE

Sample Label (Class): 11

Event Data Structure:
Type: <class 'numpy.ndarray'>
Dtype: [('t', '<i8'), ('x', '<i8'), ('p', '<i8')]
Shape: (4278,)


Analyzing Train Set:   0%|          | 0/500 [00:00<?, ?it/s]


Analyzing Test Set (200 samples)...


Analyzing Test Set:   0%|          | 0/200 [00:00<?, ?it/s]


Train Set Statistics:
  Spikes per sample: 7934.4 ± 2423.0
  Duration (seconds): 0.711 ± 0.134
  Active neurons: 564.9 ± 41.7
  Classes: 20 unique labels
  Class distribution: [27 29 25 32 29 20 23 27 29 22 22 26 18 24 19 23 29 29 27 20]

Test Set Statistics:
  Spikes per sample: 8313.1 ± 2291.3
  Duration (seconds): 0.718 ± 0.113
  Active neurons: 571.4 ± 39.7
  Classes: 20 unique labels
  Class distribution: [ 9  9 10  5 11  6  8  8  6 13 14 13  9  8 10 12 14 15 12  8]

CONFIGURING TIME-BINNING TRANSFORM

Time-Binning Configuration:
  Input channels: 700
  Time bins: 100
  
Rationale:
  - SNNs process data in discrete time steps
  - Time-binning converts continuous spike times into discrete bins
  - Each bin represents a simulation time step
  - Output shape: (Time_bins, Channels) = (100, 700)

✓ Time-binning transform configured

CREATING DISK-CACHED DATASETS

Why Disk Caching?
  - Transform operations are expensive (time-binning for each sample)
  - Caching pre-processes data once