In [1]:
import mne
import os
from mne.preprocessing import (create_eog_epochs, create_ecg_epochs,
                               compute_proj_ecg, compute_proj_eog)
import utils.variables as var
import scipy.io as sio
import numpy as np
from utils.valid_recordings import get_valid_recordings
import matplotlib.pyplot as plt
from math import log
%matplotlib qt

In [2]:
class Filtering:
    '''
    A Python class that handles all of the filtering of data, including saving the filtered data to new .mat-files.

    Constructing an instance loads one recording (``sub-<sub_nr>_ses-<ses_nr>_run-<run_nr>.mat`` in
    ``dir_raw``) and wraps it in an ``mne.io.RawArray``. The channels are renamed with ``var.MAPPING``
    and the standard 10-20 montage is set. A 1-50 Hz band-passed, 50/100 Hz notch-filtered copy is
    stored in ``filtered_arr``.

    The other methods derive the following from ``filtered_arr`` and write them to the directories below:
    band decompositions (``decomp_data``), PSDs (``psd``) and EOG/SSP projectors
    (``eog_evoked(s)``, ``eog_projs``, ``ssp``). ``filtered_arr`` itself is never modified by those steps.
    '''
    root         = '/Users/idaandreassen/Desktop/MASTER/stress_detection_EEG/'
    dir_raw      = root + var.DIR_RAW
    dir_ssp      = root + var.DIR_SSP
    dir_psd      = root + var.DIR_PSD
    dir_decomp   = root + var.DIR_DECOMP

    sfreq = var.SFREQ
    ch_type = 'eeg'
    n_channels = var.NUM_CHANNELS

    # Every channel is created with type 'eeg' (there is no dedicated EOG channel), so EOG-like
    # artifacts are detected on these EEG channels instead.
    eog_channels = ['Fp2', 'FT9']

    # Band edges in Hz used by decompose_data(); a lower edge of None means "low-pass only".
    freq_bands = {
        'gamma': (30, 45),
        'beta':  (12, 30),
        'alpha': (8, 12),
        'theta': (4, 8),
        'delta': (None, 4),
    }

    def __init__(self, sub_nr, ses_nr, run_nr):
        '''
        Load one recording and prepare its raw and filtered MNE representations.

        Parameters
        ----------
        sub_nr, ses_nr, run_nr : str
            Subject, session and run identifiers as they appear in the file name (e.g. 'P002', 'S001', '001').
        '''
        self.sub_nr = sub_nr
        self.ses_nr = ses_nr
        self.run_nr = run_nr

        self.load_data()

        # Create MNE RawArray. An int passed to create_info yields channel names '0'..'n-1',
        # which var.MAPPING then renames to electrode labels.
        info = mne.create_info(self.n_channels, sfreq=self.sfreq, ch_types=self.ch_type, verbose=None)
        self.raw_arr = mne.io.RawArray(self.data, info)

        mne.rename_channels(self.raw_arr.info, var.MAPPING)

        montage = mne.channels.make_standard_montage('standard_1020')
        self.raw_arr.set_montage(montage)

        self.filtered_arr = self.init_filter()

    #-----------------------------------------------FUNCTIONS-----------------------------------------

    def _recording_title(self):
        '''Return the ``sub-<sub>_ses-<ses>_run-<run>`` stem shared by all input and output file names.'''
        return f"sub-{self.sub_nr}_ses-{self.ses_nr}_run-{self.run_nr}"

    @staticmethod
    def _to_channel_matrix(raw):
        '''Return the samples of ``raw`` as a channels-by-samples array (unscaled, time column dropped).'''
        frame = raw.to_data_frame(scalings=1).to_numpy()
        # After transposing, row 0 is the DataFrame's time column, not a channel.
        return np.transpose(frame)[1:, :]

    def _build_ssp(self):
        '''Return a copy of ``filtered_arr`` with ``self.eog_projs`` added (projectors not yet applied).'''
        # add_proj works in place and returns the same object, so copy first to keep filtered_arr clean.
        return self.filtered_arr.copy().add_proj(self.eog_projs)

    def load_data(self):
        '''Load the ``raw_eeg_data`` matrix of this recording from ``dir_raw`` into ``self.data``.'''
        data_key = 'raw_eeg_data'
        path = os.path.join(self.dir_raw, f"{self._recording_title()}.mat")
        self.data = sio.loadmat(path)[data_key]

    def save_ssp_data(self):
        '''
        Apply the EOG projectors and save the cleaned data to ``<dir_ssp>/<title>.mat`` under "Clean_data".

        Uses ``self.ssp`` if plot_ssp() has already built it. Otherwise it is built here from
        ``self.eog_projs``, so compute_SSP_projectors() must have run first.
        '''
        if not hasattr(self, 'ssp'):
            self.ssp = self._build_ssp()
        # apply_proj's only argument is `verbose`; the projection itself needs no flag.
        self.ssp.apply_proj()
        clean_dict = {
            "Clean_data": self._to_channel_matrix(self.ssp)
        }
        sio.savemat(f'{self.dir_ssp}/{self._recording_title()}.mat', clean_dict)

    def save_decomp_data(self):
        '''Save ``self.decomp_data`` (set by decompose_data()) to ``<dir_decomp>/<title>.mat`` under "Decomp_data".'''
        decomp_dict = {
            "Decomp_data": self._to_channel_matrix(self.decomp_data)
        }
        sio.savemat(f'{self.dir_decomp}/{self._recording_title()}.mat', decomp_dict)

    def save_psd(self, data_type):
        '''
        Save ``self.psd`` under ``<dir_psd>/PSD/`` ('raw') or ``<dir_psd>/PSD_filtered/`` ('filtered').

        The .mat file holds "psd_data" (``self.psd.get_data()``) and "freqs" (``self.psd.freqs``).
        The frequencies are included so the PSD can later be plotted against a real frequency axis.

        Raises
        ------
        ValueError
            If data_type is neither 'raw' nor 'filtered' (previously nothing was saved, silently).
        '''
        subdirs = {'raw': 'PSD', 'filtered': 'PSD_filtered'}
        if data_type not in subdirs:
            raise ValueError(f"data_type must be 'raw' or 'filtered', got {data_type!r}")
        psd_dict = {
            "psd_data": self.psd.get_data(),
            "freqs": self.psd.freqs,
        }
        sio.savemat(f'{self.dir_psd}/{subdirs[data_type]}/{self._recording_title()}.mat', psd_dict)

    def init_filter(self):
        '''Return a copy of ``raw_arr`` band-passed to 1-50 Hz and notch-filtered at 50 and 100 Hz.'''
        band_pass = self.raw_arr.copy().filter(1, 50)
        #sav_gol = band_pass.copy().savgol_filter(h_freq=35, verbose=False)
        # band_pass is already a private copy, so notch-filtering it in place is safe.
        return band_pass.notch_filter(freqs=[50, 100], trans_bandwidth=0.5)

    def compute_and_save_filtered_psd(self, data_type):
        '''
        Compute the PSD of the raw ('raw') or filtered ('filtered') recording into ``self.psd``.

        Only the 'filtered' PSD is written to disk (via save_psd). The 'raw' PSD is computed but not saved.
        Any other data_type just prints a message.
        '''
        if data_type == 'raw':
            self.psd = self.raw_arr.compute_psd()
        elif data_type == 'filtered':
            self.psd = self.filtered_arr.compute_psd()
            self.save_psd(data_type)
        else:
            print(f'No data with data_type = {data_type} found')

    def decompose_data(self, freq_band, method):
        '''
        Decompose dataset

        Band-pass ``filtered_arr`` into one EEG frequency band and store the result in ``self.decomp_data``.

        Parameters
        ----------
        freq_band : str
            The frequency band in which to decompose the signal, one of the keys of ``freq_bands``:
            gamma (30-45 Hz), beta (12-30 Hz), alpha (8-12 Hz), theta (4-8 Hz) or delta (low-pass at 4 Hz).
        method : str
            Either fir or iir. Decides which filtering method to use.

        An unknown freq_band prints a message and leaves ``self.decomp_data`` unchanged.
        '''
        if freq_band not in self.freq_bands:
            print(f'No frequency band matching {freq_band} found')
            return
        l_freq, h_freq = self.freq_bands[freq_band]
        self.decomp_data = self.filtered_arr.copy().filter(l_freq, h_freq, method=method)

    def vizualize_artifacts(self):
        '''
        Plot the baseline-corrected average EOG-like artifact detected on each channel in ``eog_channels``.

        The evoked responses are kept per channel in ``self.eog_evokeds``. ``self.eog_evoked`` holds the
        one for the last channel.
        '''
        self.eog_evokeds = {}
        for ch_name in self.eog_channels:
            evoked = create_eog_epochs(self.filtered_arr, ch_name=ch_name).average(picks='all')
            evoked.apply_baseline((None, None))
            evoked.plot_joint()
            self.eog_evokeds[ch_name] = evoked
        self.eog_evoked = self.eog_evokeds[self.eog_channels[-1]]

    def compute_SSP_projectors(self):
        '''
        Compute one EEG SSP projector per channel in ``eog_channels`` and collect them in ``self.eog_projs``.

        The projectors are accumulated across channels (not reassigned per iteration), so each
        channel's projector is kept.
        '''
        self.eog_projs = []
        for ch_name in self.eog_channels:
            projs, _ = compute_proj_eog(self.filtered_arr, n_grad=0, n_mag=0, n_eeg=1, reject=None,
                                        no_proj=True, ch_name=ch_name) #check if grad and mag can be 0
            self.eog_projs.extend(projs)

    def plot_eog_projectors(self):
        '''Plot ``self.eog_projs`` jointly with the EOG evoked of each channel in ``eog_channels``, tracing that channel.'''
        evokeds = getattr(self, 'eog_evokeds', {})
        for ch_name in self.eog_channels:
            # Fall back to eog_evoked if per-channel evokeds were never computed.
            evoked = evokeds[ch_name] if ch_name in evokeds else self.eog_evoked
            fig = mne.viz.plot_projs_joint(self.eog_projs, evoked, ch_name)
            fig.suptitle('EOG projectors')

    def plot_ssp(self):
        '''
        Plot ``filtered_arr`` without and with the EOG projectors, in one titled figure each.

        Sets ``self.ssp`` to a copy of ``filtered_arr`` with ``self.eog_projs`` added, so
        ``filtered_arr`` itself is not modified.
        '''
        self.ssp = self._build_ssp()
        for title, raw in (('Without', self.filtered_arr), ('With', self.ssp)):
            with mne.viz.use_browser_backend('matplotlib'):
                fig = raw.plot() # original contents of plot function: order=artifact_picks, n_channels=len(artifact_picks)
            fig.subplots_adjust(top=0.9)  # make room for title
            fig.suptitle('{} EOG projectors'.format(title), size='xx-large',
                        weight='bold')

    @staticmethod
    def _load_psd_db(filename):
        '''
        Load a .mat file written by save_psd().

        Returns ``(x, psd_db, in_hz)``:
        - ``psd_db`` holds the first ``var.NUM_CHANNELS`` rows of "psd_data", converted to dB.
        - ``x`` is "freqs" when the file has it, otherwise the bin index.
        - ``in_hz`` says which of the two ``x`` is.
        '''
        contents = sio.loadmat(filename)
        # Power in dB is 10*log10(P); using the natural log would overstate every value by ln(10) ~ 2.3x.
        psd_db = 10 * np.log10(contents['psd_data'][:var.NUM_CHANNELS])
        if 'freqs' in contents:
            return contents['freqs'].ravel(), psd_db, True
        return np.arange(psd_db.shape[1]), psd_db, False

    def plot_psd_comparison(self, filename1, filename2, title_psd, title_ch_comparison):
        '''
        Plot two PSD .mat files (as written by save_psd) against each other in dB.

        - The first figure overlays all channels of both recordings.
        - The second figure has one subplot per channel in ``var.CHANNELS`` (2 x 4 grid).
        - filename1 is drawn as "Not stressed" and filename2 as "Stressed".
        - The x axis is in Hz only when both files contain a "freqs" entry; otherwise it is the bin index.
        '''
        x1, psd1_db, in_hz1 = self._load_psd_db(filename1)
        x2, psd2_db, in_hz2 = self._load_psd_db(filename2)
        x_label = 'Hz' if (in_hz1 and in_hz2) else 'Frequency bin'
        color1, color2 = '#3cbfbe', '#b01b81'
        labels = ['Not stressed', 'Stressed']

        # Fresh figures on every call; plt.figure(1)/(2) would draw on top of a previous call's lines.
        fig, ax = plt.subplots(figsize=(20, 10))
        for ch_db1, ch_db2 in zip(psd1_db, psd2_db):
            ax.plot(x1, ch_db1, color=color1)
            ax.plot(x2, ch_db2, color=color2)
        ax.legend(labels)
        ax.set_xlabel(x_label)
        ax.set_ylabel('dB')
        ax.set_title(title_psd)
        plt.show()

        fig, axes = plt.subplots(2, 4)
        fig.suptitle(title_ch_comparison)
        for ax, ch, ch_db1, ch_db2 in zip(axes.flat, var.CHANNELS, psd1_db, psd2_db):
            ax.plot(x1, ch_db1, color=color1)
            ax.plot(x2, ch_db2, color=color2)
            ax.legend(labels)
            ax.set_title('Channel: %s' %ch)


In [3]:
# Batch SSP pass: for every valid recording, inspect the EOG-like artifacts on Fp2/FT9,
# derive SSP projectors from them, plot the result and write the projected data to dir_ssp.
valid_recordings = get_valid_recordings('raw')
for recording_id in valid_recordings:
    sub_nr, ses_nr, run_nr = recording_id.split('_')
    test = Filtering(sub_nr, ses_nr, run_nr)
    ssp_steps = (
        test.vizualize_artifacts,
        test.compute_SSP_projectors,
        test.plot_eog_projectors,
        test.plot_ssp,
        test.save_ssp_data,
    )
    for step in ssp_steps:
        step()
    

Creating RawArray with float64 data, n_channels=8, n_times=75000
    Range : 0 ... 74999 =      0.000 ...   299.996 secs
Ready.
Creating RawArray with float64 data, n_channels=8, n_times=75000
    Range : 0 ... 74999 =      0.000 ...   299.996 secs
Ready.
Creating RawArray with float64 data, n_channels=8, n_times=75000
    Range : 0 ... 74999 =      0.000 ...   299.996 secs
Ready.
Creating RawArray with float64 data, n_channels=8, n_times=75000
    Range : 0 ... 74999 =      0.000 ...   299.996 secs
Ready.
Creating RawArray with float64 data, n_channels=8, n_times=75000
    Range : 0 ... 74999 =      0.000 ...   299.996 secs
Ready.
Creating RawArray with float64 data, n_channels=8, n_times=75000
    Range : 0 ... 74999 =      0.000 ...   299.996 secs
Ready.
Creating RawArray with float64 data, n_channels=8, n_times=75000
    Range : 0 ... 74999 =      0.000 ...   299.996 secs
Ready.
Creating RawArray with float64 data, n_channels=8, n_times=75000
    Range : 0 ... 74999 =      0.000 ..

ERROR:root:Failed to read data for recording P006_S002_001
ERROR:root:Failed to read data for recording P006_S002_002


Creating RawArray with float64 data, n_channels=8, n_times=75000
    Range : 0 ... 74999 =      0.000 ...   299.996 secs
Ready.
Creating RawArray with float64 data, n_channels=8, n_times=75000
    Range : 0 ... 74999 =      0.000 ...   299.996 secs
Ready.
Creating RawArray with float64 data, n_channels=8, n_times=75000
    Range : 0 ... 74999 =      0.000 ...   299.996 secs
Ready.
Creating RawArray with float64 data, n_channels=8, n_times=75000
    Range : 0 ... 74999 =      0.000 ...   299.996 secs
Ready.
Creating RawArray with float64 data, n_channels=8, n_times=75000
    Range : 0 ... 74999 =      0.000 ...   299.996 secs
Ready.
Creating RawArray with float64 data, n_channels=8, n_times=75000
    Range : 0 ... 74999 =      0.000 ...   299.996 secs
Ready.
Creating RawArray with float64 data, n_channels=8, n_times=75000
    Range : 0 ... 74999 =      0.000 ...   299.996 secs
Ready.
Creating RawArray with float64 data, n_channels=8, n_times=75000
    Range : 0 ... 74999 =      0.000 ..

ERROR:root:Failed to read data for recording P028_S002_001
ERROR:root:Failed to read data for recording P028_S002_002


FileNotFoundError: [Errno 2] No such file or directory: '/Users/idaandreassen/Desktop/MASTER/stress_detection_EEG/Data/Raw_eeg/sub-P001_ses-S001_run-001.mat'

In [None]:
# Batch PSD pass: compute the PSD of every filtered recording and save it under dir_psd/PSD_filtered.
valid_recordings = get_valid_recordings('raw')
for recording_id in valid_recordings:
    sub_nr, ses_nr, run_nr = recording_id.split('_')
    test = Filtering(sub_nr, ses_nr, run_nr)
    test.compute_and_save_filtered_psd(data_type='filtered')
    

In [None]:
# Batch band decomposition: delta band (IIR filter) of every filtered recording, saved to dir_decomp.
valid_recordings = get_valid_recordings('raw')
for recording_id in valid_recordings:
    sub_nr, ses_nr, run_nr = recording_id.split('_')
    test = Filtering(sub_nr, ses_nr, run_nr)
    test.decompose_data(freq_band='delta', method='iir')
    test.save_decomp_data()

In [None]:
# Recording inspected by the single-recording cells below.
participant, session, run = 'P002', 'S001', '001'

In [None]:
# Load and filter the inspected recording. The SSP steps (vizualize_artifacts, compute_SSP_projectors,
# plot_eog_projectors, plot_ssp, save_ssp_data) can be run on filt_arr exactly as in the batch cell above.
filt_arr = Filtering(participant, session, run)

In [None]:
# Delta-band decomposition (kept in filt_arr.decomp_data) and filtered PSD (saved to dir_psd/PSD_filtered).
filt_arr.decompose_data(freq_band='delta', method='iir')
filt_arr.compute_and_save_filtered_psd(data_type='filtered')


In [None]:
# Browse the band-pass + notch filtered recording; the trailing ';' suppresses the figure object's repr.
filt_arr.filtered_arr.plot();

In [None]:
# Write filt_arr.decomp_data to dir_decomp; requires decompose_data() to have been called on filt_arr.
filt_arr.save_decomp_data()

In [None]:
# Compare the filtered PSDs of run 001 (not stressed) and run 002 (stressed) for P002, session S001.
# The files are read from the same location save_psd() writes them to. plot_psd_comparison is called
# on filt_arr (defined above) instead of `test`, which only exists if one of the batch loops has run.
fname_not_stressed = f'{Filtering.dir_psd}/PSD_filtered/sub-P002_ses-S001_run-001.mat'
fname_stressed = f'{Filtering.dir_psd}/PSD_filtered/sub-P002_ses-S001_run-002.mat'

title_psd = 'Comparing filtered data PSD between P002_S001_001 and P002_S001_002'
title_ch_comparison = 'Comparing channel by channel for stressed and not-stressed between P002_S001_002 and P002_S001_001'

filt_arr.plot_psd_comparison(fname_not_stressed, fname_stressed, title_psd=title_psd, title_ch_comparison=title_ch_comparison)

In [None]:
# Recompute (and re-save) the filtered PSD, then plot it as topomaps using the 10-20 montage set at construction.
filt_arr.compute_and_save_filtered_psd(data_type='filtered')
filtered_spectrum = filt_arr.psd
filtered_spectrum.plot_topomap()


In [None]:
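# NOTE(review): legacy scratch cell, kept for reference only; nothing here is executed.
# compute_and_save_filtered_psd() takes a required data_type argument ('raw' or 'filtered'),
# so the commented call below would need one before it could be re-enabled.
# valid_recordings = get_valid_recordings('raw')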
# for recordings in valid_recordings:
#     sub_nr, ses_nr, run_nr = recordings.split('_')
#     test = Filtering(sub_nr, ses_nr, run_nr)
#     test.compute_and_save_filtered_psd()

# test.plot_psd_comparison('Data/PSD_data/sub-P002_ses-S001_run-001.mat', 'Data/PSD_data/sub-P002_ses-S001_run-002.mat')


In [None]:
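# NOTE(review): legacy scratch cell (appears to be experiments around save_ssp_data() and re-reading the
# saved SSP .mat files); everything below is commented out and not executed.
# title = f"sub-{self.sub_nr}_ses-{self.ses_nr}_run-{self.run_nr}"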
# clean_data = self.ssp.to_data_frame(scalings=1)
# clean_data = clean_data.to_numpy()
# clean_data = np.transpose(clean_data)
# clean_dict = {
#     "Clean_data" : clean_data[1:, :]  #First column of dataFrames is not data
# }
# sio.savemat(f'{self.dir_ssp}/{title}.mat', clean_dict)
# filt_arr.ssp.apply_proj(True)
# title = f"sub-{participant}_ses-{session}_run-{run}"
# clean_data = filt_arr.ssp.to_data_frame(scalings=1)
# clean_data = clean_data.to_numpy()
# clean_data = np.transpose(clean_data)
# clean_dict = {"Clean_data" : clean_data[1:, :]}

# sio.savemat(f'{var.DIR_SSP}/test2.mat', clean_dict)

# from utils.data_processing import read_eeg_data, read_mat_data
#testing = read_mat_data('Data/SSP_eeg/sub-P001_ses-S001_run-001.mat')['Clean_data']
#info = mne.create_info(8, sfreq=var.SFREQ, ch_types= 'eeg', verbose=None)
#raw_array = mne.io.RawArray(testing, info)
#mapping = {'0':'F4','1':'Fp2','2':'C3','3':'FC6','4':'O1','5':'Oz','6':'FT9','7':'T8'}
#mne.rename_channels(raw_array.info, mapping)
#raw_array.add_proj(filt_arr.eog_projs).plot()
#filt_arr.ssp.plot()
# test_dict = {'Clean_data' : raw_array.to_data_frame(scalings=1)}
# sio.savemat(f'{var.DIR_SSP}/test.mat', test_dict)
# test2 = read_mat_data('Data/SSP_eeg/sub-P001_ses-S001_run-001.mat')['Clean_data']
# info = mne.create_info(8, sfreq=var.SFREQ, ch_types= 'eeg', verbose=None)
# raw_array2 = mne.io.RawArray(test2, info)
# mapping2 = {'0':'F4','1':'Fp2','2':'C3','3':'FC6','4':'O1','5':'Oz','6':'FT9','7':'T8'}
# mne.rename_channels(raw_array2.info, mapping2)
# raw_array2.plot()
#filt_arr.ssp.save()

# testy = read_eeg_data('ssp', 'Data/SSP_eeg/test2.mat')
# testy.plot()