Normalizes, trims, and filters NINFEADB file

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import collections
from scipy.signal import butter, filtfilt, sosfiltfilt, decimate
import sys
sys.path.insert(1, '../src/')
from config import raw_data_path, univariate_data_path, processed_data_path

In [2]:
# Enter datasets you want to preprocess
datasets = ['tpehgt', 'tpehgdb', 'ehgdb1', 'ehgdb2', 'icehgds', 'ninfea', 'nifeadb'] # for when running process_all


In [3]:
def plot_info(dataset): 
    data_path = os.path.join(raw_data_path, dataset + '_data.npy')
    data = np.load(data_path, allow_pickle=True)

    # Print basic dataset info
    print(f"Total number of entries: {len(data)}")
    print(f"First entry keys: {list(data[0].keys())}")  # Check the dictionary structure

    # Extract sequence lengths
    sequence_lengths = np.array([entry['signal'].shape[0] for entry in data])

    # Compute statistics
    print(f"Max sequence length: {np.max(sequence_lengths)}")
    print(f"Min sequence length: {np.min(sequence_lengths)}")
    print(f"Mean sequence length: {np.mean(sequence_lengths):.2f}")
    print(f"Standard deviation of sequence lengths: {np.std(sequence_lengths):.2f}")

    # Check number of channels
    num_channels = set(entry['signal'].shape[1] for entry in data)
    print(f"Unique number of channels in dataset: {num_channels}")

    # Print a sample metadata entry
    print("Sample metadata:", {k: v for k, v in data[0].items() if k != 'signal'})
    print(data[0]['metadata']['fs'])
    print('Number of channels: ', data[0]['signal'].shape[1])


In [4]:
# Function to plot signals
def plot_signals(signal, title='Signals', filename=None):
    num_channels = signal.shape[1]
    fig, axes = plt.subplots(num_channels, 1, figsize=(12, 2 * num_channels), sharex=True)

    for i in range(num_channels):
        axes[i].plot(signal[:, i], label=f'Channel {i+1}')
        axes[i].legend(loc='upper right')
        axes[i].set_ylabel("Amplitude")

    axes[-1].set_xlabel("Time Steps")
    fig.suptitle(title)
    
    # Save the figure
    # plt.savefig(filename, dpi=300, bbox_inches='tight')
    # plt.close()
    plt.show()




In [5]:

def trim_data(data, dataset_name, sec_to_remove=60):
    """
    Trim the first and last `sec_to_remove` seconds from each sequence in the dataset.
    Returns a new dataset with trimmed signals.
    """
    trimmed_data = []
    if dataset_name == 'ninfea':
        sec_to_remove = 10
    
    for entry in data:
        sampling_frequency = int(entry['metadata']['fs'])
        ts_to_remove = sec_to_remove * sampling_frequency
        
        signal = entry['signal']
        trimmed_signal = signal[ts_to_remove:-ts_to_remove] if 2 * ts_to_remove < len(signal) else signal
        
        # Remove specific channels if dataset is 'ninfea'
        if dataset_name == 'ninfea':
            channels_to_remove = [27, 28, 29, 30, 32, 33] # Remove channels containing 0.0 values
            trimmed_signal = np.delete(trimmed_signal, channels_to_remove, axis=1)
        
        # Remove last two channels if dataset is 'tpehgt'
        if dataset_name == 'tpehgt':
            channels_to_remove = [1, 3, 5, 6, 7] # Remove filtered/TOCO channels (TOCO = channel 6 & 7)
            trimmed_signal = np.delete(trimmed_signal, channels_to_remove, axis=1)

        if dataset_name == 'tpehgdb':
            channels_to_remove = [1, 2, 3, 5, 6, 7, 9, 10, 11] # Remove filtered channels
            trimmed_signal = np.delete(trimmed_signal, channels_to_remove, axis=1)

        if dataset_name == 'icehgds':
            channels_to_remove = [1, 3, 5] # Remove filtered channels
            trimmed_signal = np.delete(trimmed_signal, channels_to_remove, axis=1)

        trimmed_data.append({
            'record_name': entry['record_name'],
            'signal': trimmed_signal,
            'metadata': entry['metadata']
        })
    
    return trimmed_data
    

In [6]:
def butter_bandpass_filter(data, lowcut, highcut, fs, order=4):
    """
    Apply a Butterworth bandpass filter to the signal.
    """
    b, a = butter(order, highcut, 'high', fs=fs)
    y = filtfilt(b, a, data, axis=0)
    b, a = butter(order, lowcut, 'low', fs=fs)
    return filtfilt(b, a, y, axis=0)

def filter_data(data, bandwidth=[0.3, 0.4]):
    """
    Filter each channel of the signal within the specified bandwidth.
    """
    filtered_data = []
    
    for entry in data:
        signal = entry['signal']
        fs = entry['metadata']['fs']
        
        # Apply bandpass filter
        filtered_signal = butter_bandpass_filter(signal, lowcut=bandwidth[0], highcut=bandwidth[1], fs=fs)
        
        filtered_data.append({
            'record_name': entry['record_name'],
            'signal': filtered_signal,
            'metadata': entry['metadata']
        })
    
    return filtered_data


In [7]:

def z_normalize_signals(data, epsilon=1e-8):
    """
    Apply z-normalization to each channel in the multivariate time series dataset.
    """
    normalized_entries = []
    
    for entry in data:
        signal = entry['signal']  # Shape: (sequence_length, num_channels)
        
        if signal.ndim == 1:
            signal = signal[:, np.newaxis]  # Ensure 2D array for consistency
        
        # Compute mean and std for each channel separately
        mu = np.mean(signal, axis=0)
        sigma = np.std(signal, axis=0)
        
        # Normalize each channel
        normalized_signal = (signal - mu) / (sigma + epsilon)
        
        # Store the normalized entry
        normalized_entries.append({
            'record_name': entry['record_name'],
            'signal': normalized_signal
        })
    
    return normalized_entries

In [8]:

def check_normalize(data, tol=1e-2):
    """
    Check if the normalized dataset has mean ~0 and variance ~1 for each channel.
    """
    all_correct = True
    incorrect_entries = []
    
    for entry in data:
        signal = entry['signal']
        mean_per_channel = np.mean(signal, axis=0)
        std_per_channel = np.std(signal, axis=0)
        
        incorrect = np.where((np.abs(mean_per_channel) >= tol) | (np.abs(std_per_channel - 1) >= tol))[0]
        if len(incorrect) > 0:
            all_correct = False
            incorrect_entries.append((entry['record_name'], incorrect, mean_per_channel[incorrect], std_per_channel[incorrect]))
    
    if all_correct:
        print("Normalization check passed: All channels have mean ≈ 0 and std ≈ 1.")
    else:
        print("Normalization check failed: Some channels deviate from expected mean and std.")
        for record_name, incorrect, means, stds in incorrect_entries:
            print(f"Record {record_name}: ")
            for ch, mean, std in zip(incorrect, means, stds):
                print(f"  Channel {ch}: mean = {mean:.4f}, std = {std:.4f}")
    

In [9]:
def process_all(datasets):
    for dataset_name in datasets: 
        print('Now preprocessing ', dataset_name)
        data_path = os.path.join(raw_data_path, dataset_name + '_data.npy')
        data = np.load(data_path, allow_pickle=True)
        print('signal before trimming/deleting channels:', data[0]['signal'].shape)
        data = trim_data(data, dataset_name)
        print('signal after trimming/deleting channels:', data[0]['signal'].shape)
        data = filter_data(data)
        data = z_normalize_signals(data)
        normalized_data_path = os.path.join(processed_data_path, dataset_name + "_preprocessed.npy")
        np.save(normalized_data_path, data) 
        print(dataset_name, 'saved!')

def process_one(dataset_name):
    print('Now preprocessing ', dataset_name)
    data_path = os.path.join(raw_data_path, dataset_name + '_data.npy')
    data = np.load(data_path, allow_pickle=True)
    print('signal before trimming/deleting channels:', data[0]['signal'].shape)
    data = trim_data(data, dataset_name)
    print('signal after trimming/deleting channels:', data[0]['signal'].shape)
    data = filter_data(data)
    data = z_normalize_signals(data)
    normalized_data_path = os.path.join(processed_data_path, dataset_name + "_preprocessed.npy")
    np.save(normalized_data_path, data) 
    print(dataset_name, 'saved!')

process_all(datasets)
# process_one('nifeadb')

Now preprocessing  tpehgt
signal before trimming/deleting channels: (35300, 8)
signal after trimming/deleting channels: (32900, 3)


tpehgt saved!
Now preprocessing  tpehgdb
signal before trimming/deleting channels: (35180, 12)
signal after trimming/deleting channels: (32780, 3)
tpehgdb saved!
Now preprocessing  ehgdb1
signal before trimming/deleting channels: (100000, 16)
signal after trimming/deleting channels: (76000, 16)


KeyboardInterrupt: 

In [30]:
print_data = ['icehgds', 'ninfea', 'nifeadb'] # for when running process_all

for dataset_name in print_data: 
    print('Statistics of', dataset_name)
    data_path_original = os.path.join(raw_data_path, dataset_name + '_data.npy')
    data_original = np.load(data_path_original, allow_pickle=True)
    data_path_preprocessed = os.path.join(processed_data_path, dataset_name + "_preprocessed.npy")
    data_preprocessed = np.load(data_path_preprocessed, allow_pickle=True)

    print(f"Total number of entries before: {len(data_original)}")
    print(f"Total number of entries after: {len(data_preprocessed)}")
    # Extract sequence lengths
    sequence_lengths_original = np.array([entry['signal'].shape[0] for entry in data_original])
    sequence_lengths_processed = np.array([entry['signal'].shape[0] for entry in data_preprocessed])

    # Compute statistics
    print(f"Max sequence length: {np.max(sequence_lengths_original)}")
    print(f"Min sequence length: {np.min(sequence_lengths_original)}")
    print(f"Mean sequence length: {np.mean(sequence_lengths_original):.2f}")
    # print(f"Standard deviation of sequence lengths: {np.std(sequence_lengths):.2f}")

    print(f"Max sequence length: {np.max(sequence_lengths_processed)}")
    print(f"Min sequence length: {np.min(sequence_lengths_processed)}")
    print(f"Mean sequence length: {np.mean(sequence_lengths_processed):.2f}")
    # print(f"Standard deviation of sequence lengths: {np.std(sequence_lengths):.2f}")
    # Check number of channels
    # num_channels = set(entry['signal'].shape[1] for entry in data)
    # print(f"Unique number of channels in dataset: {num_channels}")

    # Print a sample metadata entry
    # print("Sample metadata:", {k: v for k, v in data[0].items() if k != 'signal'})
    # print(data[0]['metadata']['fs'])
    # print('Number of channels: ', data[0]['signal'].shape[1])



Statistics of icehgds
Total number of entries before: 126
Total number of entries after: 126
Max sequence length: 38220
Min sequence length: 35040
Mean sequence length: 35505.58
Max sequence length: 35820
Min sequence length: 32640
Mean sequence length: 33105.58
Statistics of ninfea
Total number of entries before: 60
Total number of entries after: 60
Max sequence length: 245306
Min sequence length: 15351
Mean sequence length: 62560.15
Max sequence length: 204346
Min sequence length: 1793
Mean sequence length: 37984.15
Statistics of nifeadb
Total number of entries before: 26
Total number of entries after: 26
Max sequence length: 961521
Min sequence length: 309423
Mean sequence length: 600957.08
Max sequence length: 901521
Min sequence length: 249423
Mean sequence length: 490187.85
