In [None]:
import mne
import numpy as np
import numpy.typing as npt
import scipy
import pywt

# Only show MNE log messages at WARNING level or above.
mne.set_log_level('WARNING')
# Size setting for MNE's raw-data browser, given as 'width,height'
# (presumably inches -- confirm against the MNE configuration docs).
# NOTE(review): set_config may persist this value to the user's MNE config
# file rather than only affecting this session -- confirm.
mne.set_config('MNE_BROWSE_RAW_SIZE','16,8')

# Types
from typing import Annotated, Literal, TypeVar

# DType: type variable restricted to NumPy scalar types (subclasses of np.generic).
DType = TypeVar("DType", bound=np.generic)
# ArrayN: an NDArray of DType tagged with Literal['N'] as annotation metadata.
# The tag is documentation only and is not enforced at runtime; presumably it
# marks a 1-D array of length N -- confirm intended meaning.
ArrayN = Annotated[npt.NDArray[DType], Literal['N']]

### Get Bonn University data
Sets A and B are from healthy subjects, while sets C, D, and E are from epileptic patients.

Sets C and D contain seizure-free segments, while set E contains segments recorded during a seizure.

Each set contains 100 single-channel EEG segments of 23.6-sec duration.

In [None]:
def get_data(channels: int, time_points: int, data_dir: str = 'data/bonn') -> dict[str, np.ndarray]:
    """Load the Bonn EEG text files into one 2-D array per set.

    For every set letter the files
    ``{data_dir}/SET {letter}/{prefix}001.txt``, ``...002.txt``, ... are read
    with ``np.loadtxt`` and the first ``time_points`` samples of each file are
    stored as one row of that set's array.

    Args:
        channels: Number of files to read per set (numbered from 001); each
            file becomes one row of the result.
        time_points: Number of leading samples to keep from every file.
        data_dir: Directory containing the ``SET A`` ... ``SET E`` folders.
            Defaults to the location that was previously hard-coded.

    Returns:
        Dict mapping each set letter ('A'..'E') to a float array of shape
        ``(channels, time_points)``.

    Raises:
        ValueError: If a file contains fewer than ``time_points`` samples.
        OSError: Propagated from ``np.loadtxt`` when a file cannot be opened.
    """
    # Set letter -> letter used as the prefix of that set's file names.
    filename_prefixes = {
        'A': 'Z',
        'B': 'O',
        'C': 'N',
        'D': 'F',
        'E': 'S'
    }

    data = {}
    for set_letter, prefix in filename_prefixes.items():
        set_data = np.zeros((channels, time_points))
        for i in range(channels):
            # Files are numbered from 1 and zero-padded to three digits.
            filename = f'{data_dir}/SET {set_letter}/{prefix}{i + 1:03d}.txt'
            # ndmin=1 keeps a single-value file as a 1-D array so it can be sliced.
            samples = np.loadtxt(filename, ndmin=1)
            if samples.shape[0] < time_points:
                # Fail with a readable message instead of an opaque
                # broadcasting error on the assignment below.
                raise ValueError(
                    f'{filename} has {samples.shape[0]} samples, '
                    f'expected at least {time_points}'
                )
            set_data[i] = samples[:time_points]

        data[set_letter] = set_data

    return data

# Recording parameters: number of segments read per set (each one is treated
# as a channel), samples kept per segment, and the sampling frequency that is
# handed to MNE as `sfreq`.
channels = 100
time_points = 4096
freq = 173.61

# Set to analyse; all sets are loaded and this one is wrapped in a Raw object.
set_letter = 'D'
data = get_data(channels, time_points)

# One EEG channel per loaded segment, named c0, c1, ...
info = mne.create_info(
    ch_names=[f'c{idx}' for idx in range(channels)],
    sfreq=freq,
    ch_types='eeg',
)
raw = mne.io.RawArray(data[set_letter], info)

# Preprocessing
### Filtering

In [None]:
# Optional diagnostics before filtering:
# raw.copy().compute_psd().plot()
# raw.copy().plot(duration=5, n_channels=15, scalings=500)

# Notch filter at 50, then band-pass between 0.1 and 50. The band-pass runs on
# whatever the notch call returns, and the final result is bound back to `raw`.
raw = (
    raw
    .notch_filter(freqs=50)
    .filter(l_freq=0.1, h_freq=50)
)

# Optional diagnostics after filtering:
# raw.copy().compute_psd().plot();
# raw.copy().plot(duration=5, n_channels=15, scalings=500);


### Segmentation into intervals

In [None]:
# Split the recording into consecutive, non-overlapping time intervals.
# Each interval covers [t, t + interval_length) seconds, converted to sample
# indices by multiplying with the sampling frequency and truncating to int.
# TODO: decide how to handle the final interval, which is shorter than the
# others whenever the total duration is not a multiple of interval_length.

total_seconds = time_points/freq
interval_length = 5  # in seconds

# Take the samples from `raw` so the intervals are guaranteed to contain the
# filtered signal. Slicing data[set_letter] only reflects the filtering if the
# Raw object happens to share memory with that array.
filtered_data = raw.get_data()

# create a list of numpy arrays, each containing the data of a single interval
# (all channels, samples start:stop)
intervals = []
t = 0
while t < total_seconds:
    start, stop = int(t*freq), int((t+interval_length)*freq)
    # A stop index past the end of the signal is clipped by the slice.
    intervals.append(filtered_data[:, start:stop])
    t += interval_length

# Feature Extraction
### Define 1D multilevel DWT function for each time interval (using Daubechies 4)

In [None]:
def discrete_wavelet_transform(interval_data):
    """Placeholder for the 1D multilevel DWT of one interval's data.

    Not implemented yet: ``interval_data`` is ignored and ``None`` is returned.
    """
    # TODO: implement the 1D multilevel DWT.
    return None