## Inspection

The code below starts by inspecting one recording from the new dataset. Two limitations stand out:

1. Each file contains only one tremor event, so sliding a window over the data yields a heavily imbalanced set of labels.

2. The total amount of data is very small.
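To make the imbalance concrete, here is a rough count for a single recording. It assumes a 200 s recording (the length of `sub-pd3` in the log below), one 5 s tremor at a hypothetical onset of 10 s, and the 3 s window with a 0.5 s step used later. A window counts as "tremor" if it overlaps the tremor interval at all. The real pipeline also trims each recording after the tremor, so the actual ratio is less extreme than this.

```python
# Rough window count for one recording; all numbers are illustrative assumptions
recording_s = 200.0    # length of sub-pd3 according to the MNE log
window_s = 3.0         # window size used later in load_and_preprocess_bdf
step_s = 0.5           # window step used later
tremor_onset_s = 10.0  # hypothetical onset
tremor_dur_s = 5.0     # assumed tremor duration

# Start times of every full window that fits in the recording
starts = [i * step_s for i in range(int((recording_s - window_s) / step_s) + 1)]
# Windows that overlap [onset, onset + duration)
overlapping = [s for s in starts
               if s < tremor_onset_s + tremor_dur_s and s + window_s > tremor_onset_s]
print(len(starts), len(overlapping))  # 395 15
```

Only about 4% of the windows touch the tremor, which motivates cropping the recording and weighting the classes.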

In [None]:
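import mne

# Load one BioSemi recording into memory
raw_data = mne.io.read_raw_bdf("../new_data/sub-pd3/ses-off/eeg/sub-pd3_ses-off_task-rest_eeg.bdf", preload=True)

# Annotation-based events; only event_id is reused below (for plot_events)
events, event_id = mne.events_from_annotations(raw_data)

# Trigger-based events from the Status channel; initial_event=True also reports a nonzero value at sample 0
events = mne.find_events(raw_data, stim_channel="Status", initial_event=True)
print(events)

fig = mne.viz.plot_events(events, sfreq=raw_data.info["sfreq"], first_samp=raw_data.first_samp, event_id=event_id)

# Interactive browser: 64 channels, 5-second pages, with event markers
raw_data.plot(events=events, n_channels=64, duration=5)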

Extracting EDF parameters from /Users/supremegg/Documents/GitHub/parkinsons-tremor-detection/new_data/sub-pd3/ses-off/eeg/sub-pd3_ses-off_task-rest_eeg.bdf...
BDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 102399  =      0.000 ...   199.998 secs...
[]
2 events found on stim channel Status
Event IDs: [    1 65536]
[[    0     0 65536]
 [ 5571     0     1]]
Using pyopengl with version 3.1.6

Channels marked as bad:
none
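Each row of the events array that MNE returns is `[sample index, previous trigger value, new trigger value]`. The sketch below copies the printed array and converts the tremor onset to seconds. It assumes code `1` marks tremor, matching the `event_id` mapping used further down. The sampling rate is inferred from the log: 102400 samples over 200 s gives 512 Hz.

```python
import numpy as np

# Events array as printed above: [sample index, previous value, new value]
events = np.array([[0, 0, 65536],
                   [5571, 0, 1]])
sfreq = 102400 / 200.0  # 102400 samples over 200 s in the log -> 512 Hz

# Keep rows whose trigger value is 1 (assumed to be the tremor code)
tremor_rows = events[events[:, 2] == 1]
onset_s = tremor_rows[0, 0] / sfreq
print(f"sfreq = {sfreq} Hz, tremor onset at {onset_s:.2f} s")  # sfreq = 512.0 Hz, tremor onset at 10.88 s
```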


## Plan for Execution

### Preprocessing

1. Load the data, apply band-pass filtering (0.5-50 Hz) and ICA-based artifact removal (the ICA step is currently commented out in the code)
2. Find the events and epoch them
3. Slide a window over the data to produce data segments
4. Assign each window one of three labels:
    - 0: No tremor
    - 1: Pre-tremor (the window ends within 2 seconds before tremor onset)
    - 2: Tremor (the window overlaps the 5-second tremor period)
5. Discard the data more than 5 seconds after the tremor ends (the code below uses a 10-second buffer)
6. Z-score each channel over time within each segment
7. Save the data segments and their corresponding labels
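The labeling rule in step 4 can be written as a standalone check on sample indices. `label_window` below is a hypothetical helper that mirrors the conditions in `load_and_preprocess_bdf`. It assumes 512 Hz, the onset of `sub-pd3`, a 2 s pre-tremor span, and a 5 s tremor.

```python
def label_window(start, end, onset, pre_tremor_samples, tremor_samples):
    """Hypothetical helper mirroring the labeling rule in load_and_preprocess_bdf.

    start/end are sample indices of the window (end exclusive).
    Returns 0 (no tremor), 1 (pre-tremor) or 2 (tremor).
    """
    if onset - pre_tremor_samples <= end <= onset:
        return 1  # window ends within the pre-tremor span before onset
    if start <= onset + tremor_samples and end >= onset:
        return 2  # window overlaps the assumed tremor interval
    return 0

sfreq = 512
onset = 5571
pre, dur, win = 2 * sfreq, 5 * sfreq, 3 * sfreq
print(label_window(0, win, onset, pre, dur))              # 0: ends well before onset
print(label_window(onset - win, onset, onset, pre, dur))  # 1: ends exactly at onset
print(label_window(onset, onset + win, onset, pre, dur))  # 2: starts at onset
```

Because the pre-tremor check runs first, a window ending exactly at onset is labeled 1. Any later window that touches the tremor interval, even by a few samples, is labeled 2.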

### Model training

1. Each data sample is one epoch / window with its corresponding label
2. Train a CNN-LSTM on mini-batches of these samples. With three labels, a cross-entropy loss fits; BCE applies only if the labels are collapsed to two classes
3. Apply k-fold cross-validation since the data is limited
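Below is a minimal sketch of a CNN-LSTM for windows shaped `(batch, 41 channels, 1536 samples)`, which matches the processed data at the end of this notebook. The class name `CNNLSTM` and all layer sizes are assumptions, not a tuned architecture. Because the labels have three classes, the sketch uses `nn.CrossEntropyLoss` on raw logits rather than BCE.

```python
import torch
import torch.nn as nn

class CNNLSTM(nn.Module):
    """Hypothetical CNN-LSTM for windows shaped (batch, channels, samples)."""

    def __init__(self, n_channels=41, n_classes=3, conv_dim=64, hidden_dim=64):
        super().__init__()
        # Temporal convolutions mix channels; each MaxPool1d(4) shortens time by 4x
        self.cnn = nn.Sequential(
            nn.Conv1d(n_channels, conv_dim, kernel_size=7, padding=3),
            nn.BatchNorm1d(conv_dim),
            nn.ReLU(),
            nn.MaxPool1d(4),
            nn.Conv1d(conv_dim, conv_dim, kernel_size=7, padding=3),
            nn.BatchNorm1d(conv_dim),
            nn.ReLU(),
            nn.MaxPool1d(4),
        )
        # LSTM reads the downsampled feature sequence
        self.lstm = nn.LSTM(conv_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        feats = self.cnn(x)             # (batch, conv_dim, samples // 16)
        feats = feats.transpose(1, 2)   # (batch, time, conv_dim) for batch_first
        _, (h_n, _) = self.lstm(feats)  # h_n: (num_layers, batch, hidden_dim)
        return self.head(h_n[-1])       # logits: (batch, n_classes)

model = CNNLSTM()
x = torch.randn(8, 41, 1536)            # 8 random windows with the processed shape
y = torch.randint(0, 3, (8,))           # random labels in {0, 1, 2}
logits = model(x)
loss = nn.CrossEntropyLoss()(logits, y)
loss.backward()
print(logits.shape)  # torch.Size([8, 3])
```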

### Inference

1. Take the continuous EEG data and slide a window over it to produce these epochs / windows
2. The model takes in a window and outputs a prediction (0/1 as planned here, or 0/1/2 if the three training labels are kept)
3. Keep sliding the window (increments of 1 second or more)
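The inference loop can be sketched as follows. `sliding_predictions` and its `predict_fn` argument are hypothetical: `predict_fn` stands in for a wrapper around the trained model. Each window is normalized with the same per-channel z-scoring used in training (up to a small epsilon), and the step here is 1 s.

```python
import numpy as np

def sliding_predictions(data, sfreq, predict_fn, window_s=3.0, step_s=1.0):
    """Slide a window over continuous data (channels, samples) and collect predictions.

    predict_fn is a hypothetical callable mapping one (channels, window_samples)
    array to a class label. Returns a list of (window_start_seconds, label).
    """
    win = int(window_s * sfreq)
    step = int(step_s * sfreq)
    results = []
    for start in range(0, data.shape[1] - win + 1, step):
        window = data[:, start:start + win]
        # Per-channel z-scoring over time, as in training (epsilon avoids division by zero)
        window = (window - window.mean(axis=1, keepdims=True)) / (window.std(axis=1, keepdims=True) + 1e-8)
        results.append((start / sfreq, predict_fn(window)))
    return results

# Dummy example: 10 s of random 41-channel data at 512 Hz and a placeholder predictor
rng = np.random.default_rng(0)
dummy = rng.standard_normal((41, 10 * 512))
preds = sliding_predictions(dummy, 512.0, predict_fn=lambda w: 0)
print(len(preds), preds[:2])  # 8 [(0.0, 0), (1.0, 0)]
```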

In [16]:
import torch
import numpy as np
import mne
from sklearn.preprocessing import StandardScaler

In [None]:
def load_and_preprocess_bdf(file_paths, 
                            event_id,
                            window_size=3.0, 
                            window_step=0.5, 
                            pre_tremor_window=2.0, 
                            save_path=None):
    """
    Load and preprocess BDF files with sliding windows and multi-class labeling.
    
    Args:
        file_paths: List of BDF file paths
        event_id: Mapping from event names to trigger codes; must contain a 'tremor' key
        window_size: Window size in seconds (3-4 s recommended for 5 s tremor events)
        window_step: Step size for the sliding window in seconds (smaller = more overlap)
        pre_tremor_window: Time before tremor onset to label as pre-tremor (seconds)
        save_path: Optional path to save the processed data
    
    Returns:
        Tuple of (windows, labels) as PyTorch tensors
    """
    all_windows = []
    all_labels = []

    for file_path in file_paths:
        # 1. Load data and apply basic preprocessing
        raw = mne.io.read_raw_bdf(file_path, preload=True)
        raw.filter(0.5, 50, fir_design='firwin')
        # ica = mne.preprocessing.ICA(n_components=15, random_state=42)
        # ica.fit(raw)
        # eog_indices, eog_scores = ica.find_bads_eog(raw)
        # if eog_indices:
        #     ica.exclude = eog_indices
        # ica.apply(raw)
        
        # 2. Get sampling frequency and calculate window parameters
        sfreq = raw.info['sfreq']
        samples_per_window = int(window_size * sfreq)
        step_samples = int(window_step * sfreq)
        pre_tremor_samples = int(pre_tremor_window * sfreq)
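        # NOTE: events_from_annotations does not return usable event IDs for these files,
        # so events are read from the Status trigger channel and mapped with the event_id argument
        events = mne.find_events(raw, stim_channel='Status')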
        
        if 'tremor' not in event_id:
            print(f"Warning: 'tremor' key missing from event_id {event_id}; skipping {file_path}")
            continue
        
        tremor_events = [evt for evt in events if evt[2] == event_id['tremor']]
        if not tremor_events:
            print(f"Warning: No tremor events found in {file_path}")
            continue
        
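        # NOTE: Each recording contains exactly ONE tremor event, so only the first is used.
        # Subtract first_samp so the onset indexes into raw.get_data() (first_samp is 0 here)
        onset = tremor_events[0][0] - raw.first_samp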
        
        # NOTE: Assuming each tremor lasts about 5 seconds
        tremor_duration_samples = int(5.0 * sfreq)
        tremor_end = onset + tremor_duration_samples
        # NOTE: Trim each recording 10 seconds after the tremor ends
        buffer_samples = int(10 * sfreq)
        analysis_end_sample = tremor_end + buffer_samples
        
        # 3. Get raw data and create sliding windows
        data = raw.get_data()
        n_samples = data.shape[1]
        analysis_end_sample = min(analysis_end_sample, n_samples)
        
        # +1 so that a window ending exactly at analysis_end_sample is included
        for start_sample in range(0, analysis_end_sample - samples_per_window + 1, step_samples):
            end_sample = start_sample + samples_per_window
            
            # Extract window
            window = data[:, start_sample:end_sample]
            
            # 4. Determine the label for this window based on its relation to tremor events
            label = 0 # No tremor

            # Check if window ends right before tremor onset (pre-tremor)
            if onset - pre_tremor_samples <= end_sample <= onset:
                label = 1  # Pre-tremor
                
            # Check if window overlaps with tremor event
            elif start_sample <= onset + tremor_duration_samples and end_sample >= onset:
                label = 2  # Full tremor
            
            # 5. Z-score each channel over time (StandardScaler standardizes the columns of window.T)
            scaler = StandardScaler()
            window = scaler.fit_transform(window.T).T
            
            all_windows.append(window)
            all_labels.append(label)

    if not all_windows:
        raise ValueError("No valid data was processed. Check input files and event markers.")
    
    all_windows = np.array(all_windows)
    all_labels = np.array(all_labels)

    all_windows = torch.tensor(all_windows, dtype=torch.float32)
    all_labels = torch.tensor(all_labels, dtype=torch.long)
    
    # minlength=3 guarantees an entry for every class, even if one is absent
    class_counts = np.bincount(all_labels.numpy(), minlength=3)
    print("Class distribution:")
    print(f"  Class 0 (No tremor): {class_counts[0]}")
    print(f"  Class 1 (Pre-tremor): {class_counts[1]}")
    print(f"  Class 2 (Full tremor): {class_counts[2]}")
    
    if save_path:
        torch.save({'windows': all_windows, 'labels': all_labels}, save_path)
        
    print(f"Processed {len(all_windows)} windows with shape {all_windows.shape}")
    
    return all_windows, all_labels
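Step 5 of the function fits a `StandardScaler` on `window.T`. This standardizes each channel over its own time samples, not across channels. The hypothetical helper `zscore_channels` below is a NumPy sketch of the same operation that avoids creating a scaler per window. It uses the population standard deviation (`ddof=0`), as `StandardScaler` does, and only mean-centres flat channels, mirroring how the scaler handles a zero scale.

```python
import numpy as np

def zscore_channels(window):
    """Z-score each channel (row) of a (channels, samples) window over time."""
    mean = window.mean(axis=1, keepdims=True)
    std = window.std(axis=1, keepdims=True)  # population std (ddof=0)
    std[std == 0] = 1.0                      # flat channels are only mean-centred
    return (window - mean) / std

rng = np.random.default_rng(42)
window = rng.normal(loc=5.0, scale=3.0, size=(41, 1536))  # (channels, samples)
z = zscore_channels(window)
print(np.allclose(z.mean(axis=1), 0), np.allclose(z.std(axis=1), 1))  # True True
```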

In [18]:
import os
import glob

window_size = 3.0
window_step = 0.5
pre_tremor_window = 2.0
event_id = {
    "no-tremor": 65536,
    "tremor": 1
}

dirname = "../new_data"
pattern = os.path.join(dirname, "sub-pd*/ses-off/eeg/sub-pd*_ses-off_task-rest_eeg.bdf")
file_paths = glob.glob(pattern)
print(file_paths)
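Overlapping windows from the same recording are highly correlated. For the planned k-fold cross-validation, it is therefore safer to split by subject than by window. The sketch below extracts subject IDs from the paths and assigns whole subjects to folds round-robin. `subject_folds` is a hypothetical helper. Using it would also require `load_and_preprocess_bdf` to return the source subject of each window, which it currently does not. `sklearn.model_selection.GroupKFold` is an alternative once such per-window group labels exist.

```python
import re
from collections import defaultdict

def subject_folds(file_paths, k=5):
    """Assign each subject (e.g. 'pd3') to one of k folds, round-robin over sorted IDs."""
    subjects = sorted({re.search(r"sub-(pd\d+)", p).group(1) for p in file_paths})
    folds = defaultdict(list)
    for i, subject in enumerate(subjects):
        folds[i % k].append(subject)
    return dict(folds)

paths = [
    "../new_data/sub-pd3/ses-off/eeg/sub-pd3_ses-off_task-rest_eeg.bdf",
    "../new_data/sub-pd5/ses-off/eeg/sub-pd5_ses-off_task-rest_eeg.bdf",
    "../new_data/sub-pd23/ses-off/eeg/sub-pd23_ses-off_task-rest_eeg.bdf",
]
print(subject_folds(paths, k=2))  # {0: ['pd23', 'pd5'], 1: ['pd3']}
```

The IDs are sorted as strings, so `pd23` comes before `pd3`. This ordering is arbitrary but deterministic.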

['../new_data/sub-pd3/ses-off/eeg/sub-pd3_ses-off_task-rest_eeg.bdf', '../new_data/sub-pd5/ses-off/eeg/sub-pd5_ses-off_task-rest_eeg.bdf', '../new_data/sub-pd23/ses-off/eeg/sub-pd23_ses-off_task-rest_eeg.bdf', '../new_data/sub-pd12/ses-off/eeg/sub-pd12_ses-off_task-rest_eeg.bdf', '../new_data/sub-pd13/ses-off/eeg/sub-pd13_ses-off_task-rest_eeg.bdf', '../new_data/sub-pd14/ses-off/eeg/sub-pd14_ses-off_task-rest_eeg.bdf', '../new_data/sub-pd22/ses-off/eeg/sub-pd22_ses-off_task-rest_eeg.bdf', '../new_data/sub-pd9/ses-off/eeg/sub-pd9_ses-off_task-rest_eeg.bdf', '../new_data/sub-pd6/ses-off/eeg/sub-pd6_ses-off_task-rest_eeg.bdf', '../new_data/sub-pd11/ses-off/eeg/sub-pd11_ses-off_task-rest_eeg.bdf', '../new_data/sub-pd16/ses-off/eeg/sub-pd16_ses-off_task-rest_eeg.bdf', '../new_data/sub-pd28/ses-off/eeg/sub-pd28_ses-off_task-rest_eeg.bdf', '../new_data/sub-pd17/ses-off/eeg/sub-pd17_ses-off_task-rest_eeg.bdf', '../new_data/sub-pd19/ses-off/eeg/sub-pd19_ses-off_task-rest_eeg.bdf', '../new_data/

In [29]:
save_path = 'processed_data.pt'
mne.set_log_level('WARNING')
preprocessed_data, labels = load_and_preprocess_bdf(file_paths, 
                                                   event_id=event_id,
                                                   window_size=window_size, 
                                                   window_step=window_step,
                                                   pre_tremor_window=pre_tremor_window,
                                                   save_path=save_path)
print(f"Data shape: {preprocessed_data.shape}, Labels shape: {labels.shape}")

5571
2086
6575
3857
2291
1398
2064
1493
2555
1595
2067
7377
1219
1103
3231
Class distribution:
  Class 0 (No tremor): 266
  Class 1 (Pre-tremor): 37
  Class 2 (Full tremor): 238
Processed 541 windows with shape torch.Size([541, 41, 1536])
Data shape: torch.Size([541, 41, 1536]), Labels shape: torch.Size([541])
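The class distribution above is skewed: only 37 of 541 windows are pre-tremor. One common remedy is to weight the loss by inverse class frequency. The sketch below uses the formula `n_samples / (n_classes * count)`, the same heuristic as scikit-learn's `class_weight="balanced"`. Whether it helps this model is an open question.

```python
import torch

# Class counts from the output above
counts = torch.tensor([266.0, 37.0, 238.0])
# Inverse-frequency weights: a perfectly balanced dataset would give weight 1 to every class
weights = counts.sum() / (len(counts) * counts)
print(weights)  # approx. [0.678, 4.874, 0.758]

# Weighted multi-class loss for training
criterion = torch.nn.CrossEntropyLoss(weight=weights)
```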
