In [1]:
import mne
import numpy as np
import numpy.typing as npt
import scipy
import pywt

mne.set_log_level('WARNING')
mne.set_config('MNE_BROWSE_RAW_SIZE','16,8')

# Types
from typing import Annotated, Literal, TypeVar

DType = TypeVar("DType", bound=np.generic)
ArrayN = Annotated[npt.NDArray[DType], Literal['N']]

### Get Bonn University data
Sets A and B are from healthy patients, while Sets C, D, and E are from epileptic patients.

Sets C and D are seizure-free segments while set E is during a seizure.

Each set contains 100 single-channel EEG segments of 23.6-sec duration.

In [2]:
def get_data(channels: int, time_points: int) -> dict[str, np.ndarray]:
    data = dict()

    # maps from set_letter to set_letter_alternate (for the filenames)
    sets = {
        'A': 'Z',
        'B': 'O',
        'C': 'N',
        'D': 'F',
        'E': 'S'
    }

    for set_letter in sets:
        set_letter_alternate = sets[set_letter]

        set_data = np.zeros((channels, time_points))
        for i in range(channels):
            filename = f'data/bonn/SET {set_letter}/{set_letter_alternate}{str(i+1).zfill(3)}.txt'
            z = np.loadtxt(filename)
            set_data[i] = z[:time_points]

        data[set_letter] = set_data

    return data

channels, time_points = 100, 4096
freq = 173.61

set_letter = 'D'
data = get_data(channels, time_points)

info = mne.create_info(
    ch_names=[f'c{i}' for i in range(channels)],
    sfreq=freq,
    ch_types='eeg',
)
raw = mne.io.RawArray(data[set_letter], info)

# Preprocessing
### Filtering

In [3]:
# raw.copy().compute_psd().plot()

# raw.copy().plot(duration=5, n_channels=15, scalings=500)

raw = raw.notch_filter(freqs=50)
raw = raw.filter(l_freq=0.1, h_freq=50)

# raw.copy().compute_psd().plot();
# raw.copy().plot(duration=5, n_channels=15, scalings=500);


  raw = raw.filter(l_freq=0.1, h_freq=50)


### Segmentation into intervals

In [4]:
# TODO: see if the thing is actually split into time inveral??

total_seconds = time_points/freq
interval_length = 5  # in seconds

# create a list of numpy arrays, each containing the data of a single interval
intervals = []
t = 0
while t < time_points/freq:
    intervals.append(data[set_letter][:, int(t*freq):int((t+interval_length)*freq)])
    t += interval_length

# Feature Extraction
### Define 1D multilevel DWT function for each time interval (using Daubechies 4)

In [5]:
def discrete_wavelet_transform(interval_data):
    # 1D multilevel DWT
    cA4, cD4, cD3, cD2, cD1 = pywt.wavedec(interval_data, wavelet='db4', level=4)
    # low frequencies => high time resolution, low freq resolution
    # high frequences => low time resolution, high freq resolution
    # print(cA4.shape)  # 0.1-4 Hz   60
    # print(cD4.shape)  # 4-8 Hz     60
    # print(cD3.shape)  # 8-15 Hz    114
    # print(cD2.shape)  # 15-30 Hz   222
    # print(cD1.shape)  # 30-60 Hz   437
    return cA4, cD4, cD3, cD2, cD1

### Define feature vector functions

In [6]:
def row_to_col_transpose(array):
    return np.atleast_2d(array).T

def variance(data) -> ArrayN[np.float64]:
    variance_data = np.var(data, axis=1)
    return variance_data  # shape (# channels,)

def standard_deviation(data) -> ArrayN[np.float64]:
    std_data = np.std(data, axis=1)
    return std_data

def kurtosis(data) -> ArrayN[np.float64]:
    # fisher = True is default, which subtracts 3 from the final value
    return scipy.stats.kurtosis(data, axis=1, fisher=False)

def nn_shannon_entropy(data) -> ArrayN[np.float64]:
    # non normalized shannon entropy - a measure of how uncertain the data is (or how surprising it is?)
    squared = data**2
    return np.sum(squared * np.log(squared), axis=1)

def logarithmic_band_power(data) -> ArrayN[np.float64]:
    n = data.shape[1]
    return np.log(np.sum(data**2, axis=1) / n)

def compute_features(data):
    features = np.array([
        logarithmic_band_power(data),
        standard_deviation(data),
        variance(data),
        kurtosis(data),
        nn_shannon_entropy(data)
    ], dtype=np.float64)
    return features.T

print(variance(intervals[0]))
print(standard_deviation(intervals[0]))
print(kurtosis(intervals[0]))
print(nn_shannon_entropy(intervals[0]))
print(logarithmic_band_power(intervals[0]))

[  1106.57125963  18057.13239426   5822.6473265    1465.21438454
   7886.70163098    726.69006524   1092.40755347  15418.81284097
 192363.22856957  24736.46406436   3945.20277155   1992.04540751
    675.73189666   1073.18770997   1089.72145668   2867.99609734
   1797.86219571   4869.08348823   1227.14138523   2564.35665525
  15969.24571751   5702.68892716   4798.46200809   6437.22644158
   8589.1930701     459.48913065   4782.77322368   1609.36989875
  11323.58775728  18451.3811191    9024.61437693   4675.99322293
   1964.66406687    664.17162782   2189.62519684   1350.1072371
   1999.19629242   2739.62317687   4290.07288157   1184.52778302
   3677.17976092    935.55260503   4377.23824592   3822.14173083
   2364.1111627    2691.42764855   3763.63615856   3284.52486574
   2604.50388031   4101.3807881    6204.24254429   1675.90763071
   2156.03616072   5071.5168869    3210.16551451   1406.68449357
   2825.60670507   1423.79595239   1748.82725509  72937.32968343
  12957.28611099   3050.12

  return np.sum(squared * np.log(squared), axis=1)
  return np.sum(squared * np.log(squared), axis=1)
