In [1]:
import os
import mne
import random
import re
import shutil

import numpy as np
import pandas as pd
import scipy.io as io

from pywt import wavedec
from typing import List

In [2]:
# Source archive and the directory it is (or was) extracted into.
filename = "resources/dane_eeg.zip"
extract_dir = "resources/dane_eeg"

# Everything below lives inside the extracted directory.
seizure_occ_file = f"{extract_dir}/czasy_napadow.m"  # "czasy napadow" = seizure times
healthy_dir = f"{extract_dir}/zdrowi"  # "zdrowi" = healthy subjects
unhealthy_dir = f"{extract_dir}/chorzy"  # "chorzy" = patients
img_dir = f"{extract_dir}/images/"  # note: the trailing '/' is part of the value

# Uncomment to extract the archive on a fresh checkout:
# shutil.unpack_archive(filename, extract_dir)

In [3]:
def get_mat_files(base_path):
    """Collect the ``.mat`` recordings found directly in ``base_path``.

    Args:
        base_path: directory to scan (not recursive).

    Returns:
        A tuple ``(paths, by_id)``:
        - ``paths``: list of ``base_path + "/" + name`` for every ``*.mat`` file,
          in sorted file-name order.
        - ``by_id``: dict mapping the patient id (the part of the file name before
          the first ``_``, with the extension stripped) to the same path. If two
          files share an id, the one later in sorted order wins.
    """
    mat_files_list = []
    filenames = {}

    # os.listdir order is arbitrary; sorting makes patient order (and every
    # downstream epoch list / random sample) reproducible across machines.
    for name in sorted(os.listdir(base_path)):
        if not name.endswith(".mat"):
            continue
        path = base_path + "/" + name
        mat_files_list.append(path)
        # Strip the extension before splitting so "7.mat" yields id "7" rather
        # than "7.mat" (callers later do int(id)); "12_abc.mat" still yields "12".
        patient_id = os.path.splitext(name)[0].split('_')[0]
        filenames[patient_id] = path

    return mat_files_list, filenames


# Paths and id -> path maps for both cohorts.
healthy_paths, healthy_patients_dict = get_mat_files(base_path=healthy_dir)
unhealthy_paths, unhealthy_patients_dict = get_mat_files(base_path=unhealthy_dir)

In [4]:
def get_all_seizure_occurrences(file_path):
    """Parse per-patient seizure onset times from a MATLAB-style ``.m`` file.

    Each relevant line is expected to contain a ``{<patient number>}`` marker and
    a bracketed list of integers, e.g. ``czasy{3} = [12 45 78];``. Lines without
    both parts (blank lines, ``%`` comments, other statements) are skipped.

    Args:
        file_path: path to the ``.m`` file.

    Returns:
        dict mapping patient number (int) to a list of int onset times. Elsewhere
        the notebook multiplies these by the sampling rate, so they appear to be
        seconds.
    """
    all_occurrences = {}
    # Only the ASCII digits matter; errors='replace' keeps a stray non-UTF-8
    # character in a comment from aborting the parse.
    with open(file_path, 'r', encoding='utf-8', errors='replace') as file:
        for line in file:
            if line.lstrip().startswith('%'):  # MATLAB comment line
                continue
            id_match = re.search(r'{([0-9]+)}', line)
            # '.*?' (not '.+?') so an empty list "[]" parses as no seizures.
            list_match = re.search(r'\[(.*?)]', line)
            if id_match is None or list_match is None:
                continue
            # MATLAB allows both whitespace and commas (or ';') between values.
            tokens = re.split(r'[\s,;]+', list_match.group(1).strip())
            all_occurrences[int(id_match.group(1))] = [int(tok) for tok in tokens if tok]
    return all_occurrences


seizure_occurrences = get_all_seizure_occurrences(file_path=seizure_occ_file)  # patient number -> onset times

In [5]:
class Seizure:
    """One EEG recording together with its metadata.

    Attributes:
        id: patient identifier as taken from the file name.
        filename: base name of the source ``.mat`` file.
        samples: the recording matrix loaded from that file.
        seizure_occurrences: list of seizure onset times, or ``None`` when the
            recording has no annotated seizures.
        frequency: sampling rate (defaults to 100).
        channels: number of channels (defaults to 16).
    """

    def __init__(self, id, filename, samples, seizure_occurrences=None, frequency=100, channels=16) -> None:
        super().__init__()
        # Values are stored as given; no validation is performed.
        self.id, self.filename, self.samples = id, filename, samples
        self.seizure_occurrences = seizure_occurrences
        self.frequency, self.channels = frequency, channels

In [6]:
def _load_samples(path):
    """Return the ``'x'`` variable stored in the MATLAB file at ``path``."""
    return io.loadmat(path)['x']


def healthy_patients(healthy_people_dict):
    """Load every healthy recording into a ``Seizure`` object without seizure times.

    Args:
        healthy_people_dict: mapping patient id -> path of its ``.mat`` file.

    Returns:
        list of ``Seizure`` objects, in the dict's iteration order.
    """
    return [Seizure(patient_id, os.path.basename(path), _load_samples(path))
            for patient_id, path in healthy_people_dict.items()]


def unhealthy_patients(unhealthy_people_dict, seizure_occurrences):
    """Load every patient recording together with its seizure onset times.

    Args:
        unhealthy_people_dict: mapping patient id (numeric string) -> ``.mat`` path.
        seizure_occurrences: mapping int patient number -> list of onset times.

    Returns:
        list of ``Seizure`` objects, in the dict's iteration order.

    Raises:
        KeyError: a patient has no entry in ``seizure_occurrences``.
        ValueError: a patient id is not an integer string.
    """
    patients = []
    for patient_id, path in unhealthy_people_dict.items():
        # Look the times up before loading the (large) file so a missing entry fails fast.
        try:
            seizure_occ = seizure_occurrences[int(patient_id)]
        except KeyError:
            raise KeyError(f"no seizure times for patient {patient_id!r} ({path})") from None
        patients.append(Seizure(patient_id, os.path.basename(path), _load_samples(path),
                                seizure_occurrences=seizure_occ))
    return patients

In [7]:
# Materialise both cohorts (this loads every .mat file).
healthy_list: List[Seizure] = healthy_patients(healthy_people_dict=healthy_patients_dict)
unhealthy_list: List[Seizure] = unhealthy_patients(unhealthy_people_dict=unhealthy_patients_dict,
                                                   seizure_occurrences=seizure_occurrences)

In [8]:
def saveToPng(figure, filename, img_dir, seizure_time):
    """Save ``figure`` as ``<img_dir>/<stem>/<stem><seizure_time>.png``.

    ``<stem>`` is ``filename`` without its extension. The per-recording
    directory is created if needed.

    Args:
        figure: any object with a ``savefig(path)`` method (e.g. a matplotlib Figure).
        filename: recording file name; only its stem is used.
        img_dir: root output directory (a trailing separator is optional).
        seizure_time: value appended to the stem to make the file name unique.

    Returns:
        The path the figure was written to.
    """
    base_filename = os.path.splitext(filename)[0]
    # os.path.join works whether or not img_dir ends with a separator.
    dir_name = os.path.join(img_dir, base_filename)
    # exist_ok replaces the exists()/makedirs() pair: no race, and re-runs are safe.
    os.makedirs(dir_name, exist_ok=True)
    out_path = os.path.join(dir_name, base_filename + str(seizure_time) + '.png')
    figure.savefig(out_path)
    return out_path


def plot_all_attacks(unhealthy_people, save_to_png=False, img_dir=""):
    """Open an MNE browser around every seizure of every patient.

    For each onset a 10-second window starting 5 s before it (clipped at 0) is
    shown. All of the patient's seizures are drawn as 2-second annotation blocks.

    Args:
        unhealthy_people: iterable of ``Seizure`` objects whose
            ``seizure_occurrences`` is a list of onset times in seconds.
        save_to_png: also write each figure to disk via ``saveToPng``.
        img_dir: root directory for the PNG files.
    """
    def create_annotations(occurrences):
        # One 2-second coloured block per seizure onset.
        return mne.Annotations(onset=list(occurrences), duration=2,
                               description=['seizure ' + str(x) for x in occurrences])

    for person in unhealthy_people:
        # Build the Raw object once per patient rather than once per seizure.
        ch_names = [str(i) for i in range(1, person.channels + 1)]
        info = mne.create_info(ch_names=ch_names, sfreq=person.frequency)
        raw = mne.io.RawArray(np.transpose(person.samples), info)
        raw.set_annotations(create_annotations(person.seizure_occurrences))

        for seizure_time in person.seizure_occurrences:
            figure = raw.plot(
                n_channels=person.channels,
                scalings='auto',
                title=person.filename + ' Time_of_seizure: ' + str(seizure_time),
                block=False,
                start=max(0, seizure_time - 5),
                duration=10,
            )
            if save_to_png:
                saveToPng(figure, person.filename, img_dir, seizure_time)
            
def plot_healthy(healthy_people, save_to_png=False, img_dir="", start=0, duration=10):
    """Open an MNE browser on one fixed window of each healthy recording.

    Healthy recordings carry no seizure times, so no annotations are drawn and
    a single ``duration``-second window beginning at ``start`` seconds is shown
    per person.

    Args:
        healthy_people: iterable of ``Seizure`` objects.
        save_to_png: also write each figure to disk via ``saveToPng``.
        img_dir: root directory for the PNG files.
        start: window start in seconds (also used as the PNG name suffix).
        duration: window length in seconds.
    """
    for person in healthy_people:
        # Keep only the first `channels` columns. The notebook drops an extra
        # trailing column from healthy recordings (presumably non-EEG); slicing
        # here makes plotting work both before and after that step.
        samples = person.samples[:, :person.channels]
        ch_names = [str(i) for i in range(1, person.channels + 1)]
        info = mne.create_info(ch_names=ch_names, sfreq=person.frequency)
        raw = mne.io.RawArray(np.transpose(samples), info)

        figure = raw.plot(
            n_channels=person.channels,
            scalings='auto',
            title=person.filename + ' start: ' + str(start),
            block=False,
            start=start,
            duration=duration,
        )
        if save_to_png:
            saveToPng(figure, person.filename, img_dir, start)

In [9]:
# plot_all_attacks(unhealthy_list, save_to_png = True, img_dir=img_dir)
# list(map(str,range(1,17))) 

In [10]:
# def plot_seizures_of_patient(unhealthy_people, patient_nr, channels_no =1):
#     patient = next((x for x in unhealthy_people if x.id == str(patient_nr)), 0)
#     info = mne.create_info(ch_names = list(map(str,range(1,17))), sfreq = 100, ch_types ='eeg')
#     raw = mne.io.RawArray(np.transpose(patient.samples), info)
#     annot = mne.Annotations(onset = patient.seizure_matches, duration = 2, description = patient.seizure_matches) 
#     raw.set_annotations(annot)
#     plot_kwargs = {
#         'scalings': dict(eeg=30),
#         'highpass': 1,
#         'lowpass': 40,
#         'n_channels': channels_no,
#         'duration': 10,
#     }

#     for seizure in patient.seizure_occurences:
#         raw.plot(**plot_kwargs, start = seizure - 2)

In [10]:
# Healthy recordings have one column more than the `channels` EEG channels
# (presumably a trailing non-EEG column - confirm against the data source).
# Slicing to `channels` columns is idempotent, unlike np.delete(..., -1):
# re-running this cell no longer removes real EEG channels.
for patient in healthy_list:
    patient.samples = patient.samples[:, :patient.channels]

In [11]:
EPOCH_SECONDS = 2  # length of one epoch (200 samples at 100 Hz)
GUARD_SECONDS = 4  # seconds trimmed before the first onset / after each onset


def _full_windows(samples, size):
    """Yield consecutive non-overlapping ``size``-row windows; a shorter tail is dropped."""
    for start in range(0, samples.shape[0] - size + 1, size):
        yield samples[start:start + size]


# Epochs from healthy subjects.
healthy_epochs = []
for patient in healthy_list:
    epoch_len = EPOCH_SECONDS * patient.frequency
    healthy_epochs.extend(('healthy', chunk) for chunk in _full_windows(patient.samples, epoch_len))


# One epoch starting at each seizure onset.
unhealthy_epochs = []
for patient in unhealthy_list:
    epoch_len = EPOCH_SECONDS * patient.frequency
    for seizure_time in patient.seizure_occurrences:
        # Onsets are in seconds; convert to a row index before slicing.
        seizure_start = seizure_time * patient.frequency
        chunk = patient.samples[seizure_start:seizure_start + epoch_len]
        if chunk.shape[0] == epoch_len:  # skip onsets too close to the end of the recording
            unhealthy_epochs.append(('unhealthy', chunk))


# Seizure-free epochs harvested from the patients' own recordings.
chunks_healthy_from_sick = []
for patient in unhealthy_list:
    fs = patient.frequency
    epoch_len = EPOCH_SECONDS * fs
    guard = GUARD_SECONDS * fs

    # Split the recording at every seizure onset (seconds -> row index).
    segments = np.split(patient.samples, sorted(t * fs for t in patient.seizure_occurrences))

    # Remove `guard` samples before the first onset and after each onset,
    # where seizure activity may still be present.
    seizure_free = [segments[0][:-guard]] + [segment[guard:] for segment in segments[1:]]

    # Window each seizure-free segment exactly once per patient.
    for segment in seizure_free:
        chunks_healthy_from_sick.extend(('healthy', chunk) for chunk in _full_windows(segment, epoch_len))

In [12]:
# Fixed seed so the healthy subsample (and the whole downstream pipeline) is
# reproducible. Capping at the available count keeps random.sample from
# raising ValueError on small datasets.
N_HEALTHY_EPOCHS = 1000
true_healthy = random.Random(42).sample(healthy_epochs, min(N_HEALTHY_EPOCHS, len(healthy_epochs)))
all_epochs = true_healthy + unhealthy_epochs

In [16]:
def wavelet_decompose_channels(data, level, wavelet='db4'):
    """Multilevel DWT of every channel (column) of ``data``, keeping detail coefficients.

    The rows are first decimated by 2 (every second row is kept). Each channel
    is then decomposed with ``wavedec`` down to ``level``.

    Args:
        data: DataFrame with rows = time samples and columns = channels.
        level: decomposition depth.
        wavelet: wavelet name passed to ``wavedec`` (default ``'db4'``).

    Returns:
        DataFrame whose columns are a ``(channel, level)`` MultiIndex holding
        ``D1`` .. ``D<level>``, sorted by channel then level. The approximation
        ``A<level>`` is discarded. The row index is named ``'sample'`` and spans
        the longest coefficient vector; shorter levels are NaN-padded.
    """
    # rename_axis returns a new frame. Assigning `.columns.name` on the slice
    # could mutate the Index object shared with the caller's DataFrame.
    decimated = data.iloc[0::2].rename_axis(columns='channel')

    # wavedec works along the last axis, so pass a (channels, samples) array.
    coeffs_list = wavedec(decimated.to_numpy().T, wavelet=wavelet, level=level)

    # wavedec returns [cA_level, cD_level, ..., cD1].
    names = ['A' + str(level)] + ['D' + str(k) for k in range(level, 0, -1)]

    per_level = [
        pd.DataFrame(
            coeffs.T,
            columns=pd.MultiIndex.from_product([decimated.columns, [name]], names=['channel', 'level']),
        )
        for name, coeffs in zip(names, coeffs_list)
    ]
    # One concat instead of growing a frame in a loop; sort=True unions the
    # (different-length) row indexes.
    wavelets = pd.concat(per_level, axis=1, sort=True).sort_index(axis=1, level=['channel', 'level'])

    keep_details = [name.startswith('D') for name in wavelets.columns.get_level_values('level')]
    return wavelets.loc[:, keep_details].rename_axis('sample')


class Feature:
    """Five per-column summary statistics of a frame (see ``get_features``)."""

    def __init__(self, max, min, mean, std, mean_abs) -> None:
        super().__init__()
        self.max, self.min, self.mean = max, min, mean
        self.std, self.mean_abs = std, mean_abs

    def getall(self):
        """Flatten the statistics into one list: all max, then min, mean, std, mean_abs."""
        groups = (self.max, self.min, self.mean, self.std, self.mean_abs)
        return [value for group in groups for value in group]

all_label_features = []


def get_features(data):
    """Return a ``Feature`` of per-column max, min, mean, std and mean absolute value."""
    stats = (data.max(), data.min(), data.mean(), data.std(), data.abs().mean())
    return Feature(*(series.to_numpy() for series in stats))


# Wavelet-decompose every (label, chunk) epoch and keep its summary features.
for label, chunk in all_epochs:
    details = wavelet_decompose_channels(pd.DataFrame(chunk), level=3)
    all_label_features.append((label, get_features(details)))

In [17]:
from sklearn.model_selection import train_test_split

SPLIT_SEED = 42


def convert_y_to_number(str):
    """Map a class label to a float target: ``'unhealthy'`` -> 1.0, anything else -> 0.0."""
    return 1.0 if str == 'unhealthy' else 0.0


X = [features.getall() for _, features in all_label_features]  # flattened feature vectors
y = [convert_y_to_number(label) for label, _ in all_label_features]  # 1.0 = seizure epoch

# A fixed integer seed makes the split reproducible; a fresh, unseeded
# RandomState() gave a different split on every run. Stratifying keeps the
# healthy/unhealthy ratio the same in both splits. The classes are likely
# imbalanced, since healthy epochs are sampled separately from seizure epochs.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=SPLIT_SEED, stratify=y
)

In [21]:
import tensorflow as tf
from tensorflow import keras

SyntaxError: invalid syntax (pywrap_tensorflow_internal.py, line 114)