In [None]:
# test pipeline for the bruyere dataset
from fn_cfg import *
import params as cfg

In [None]:
# Alternative recording settings (path points at a 'laurel_place' dataset), kept commented out:
#localPath = '/Volumes/Backup Plus/EEG_Datasets/laurel_place/dataset'
#filename = '0002_2_12122019_1225'
#version = 1.0

# Settings of the recording to load; all three values are passed to device.init below.
# NOTE(review): localPath is an absolute, machine-specific path - adjust it before running elsewhere.
version = 1.1
filename = '0_1_12072018_1206'
localPath = '/Users/joshuaighalo/Downloads/raw'

# Load the recording; init returns an indexable container that is unpacked by position below.
device = importFile.neurocatch()
fileObjects = device.init(version,filename,localPath)
rawEEG = fileObjects[0]
rawEOG = fileObjects[1]
rawEEGEOG = fileObjects[2]
time = fileObjects[3]
trigOutput = fileObjects[4]
# Plot the raw EEG against time, then its spectrogram.
# NOTE(review): `fs` (and `line` below) are not defined in this notebook - presumably they
# come from `from fn_cfg import *`; verify.
plots(time,rawEEG,titles=cfg.channelNames,figsize=cfg.figure_size,pltclr=cfg.plot_color)
# NOTE(review): the title says 'Music Therapy Group 11' while the notebook header mentions the
# bruyere dataset - confirm which is correct.
spectogramPlot(rawEEG,fs,nfft=cfg.nfft,nOverlap=cfg.noverlap,figsize=(16,6),subTitles=cfg.channelNames,title='Music Therapy Group 11')
# Notch-filter the raw EEG (presumably to remove mains interference at `line` Hz - TODO confirm)
# and plot the filtered signal; notchFilterOutput is the input of the wavelet cells below.
filtering = filters()
notchFilterOutput = filtering.notch(rawEEG,line,fs)
plots(time,notchFilterOutput,titles=cfg.channelNames,figsize=cfg.figure_size,pltclr=cfg.plot_color)
#spectogramPlot(notchFilterOutput,fs,nfft=cfg.nfft,nOverlap=cfg.noverlap,figsize=(16,6),subTitles=cfg.channelNames,title='Music Therapy Group 11')


In [None]:
#   Reference: "Comparative Study of Wavelet-Based Unsupervised Ocular Artifact Removal Techniques for Single-Channel EEG Data"
#   Signal to Artifact Ratio (SAR) quantifies the amount of artifact removed
#   from a signal after processing it with an algorithm [40].
#   As implemented below: SAR = 10*log10(std(x) / std(x - y)), where
#   x = EEG signal containing artifact
#   y = EEG signal obtained after running an artifact-removal algorithm

def sar(x,y):
    """Signal to Artifact Ratio of a cleaned signal.

    x: signal containing the artifact (NumPy array).
    y: signal after artifact removal, same shape as x.
    Returns 10*log10(std(x)/std(x-y)).
    """
    removed = x-y
    spread_ratio = np.std(x)/np.std(removed)
    return 10*np.log10(spread_ratio)
def mse(x,y):
    """Mean squared error between x and y, delegated to mean_squared_error."""
    error = mean_squared_error(x,y)
    return error

In [None]:
"""
DWT Only
"""

#   Probability Mapping Based Artifact Detection and Wavelet Denoising based 
#   Artifact Removal from Scalp EEG for BCI Applications


def dwt_only(data,wavelet,level=10):
    """Suppress artifacts in EEG by zeroing large wavelet coefficients (DWT only).

    Each channel is decomposed with `wavedec`, every coefficient whose magnitude
    exceeds a threshold is set to zero, and the signal is rebuilt with `waverec`.
    Two thresholding rules are evaluated on the same decomposition:

    * global: (median(|c|)/0.6745) * sqrt(2*ln(N)), N = number of samples
    * std:    1.5 * std(c)

    For both rules one threshold is computed from coeffs[0] and applied to
    coeffs[0], and one is computed from coeffs[1] and applied to every remaining
    coefficient array (assuming PyWavelets ordering, coeffs[0] is the
    approximation and coeffs[1] the coarsest detail level - verify if `wavedec`
    is not pywt.wavedec).

    Parameters
    ----------
    data : array-like, shape (samples,) or (samples, channels)
        EEG data with time along the first axis.
    wavelet : wavelet name/object accepted by `wavedec`/`waverec` (e.g. 'coif5').
    level : int, optional
        Decomposition level passed to `wavedec` (default 10, the value that was
        previously hard-coded).

    Returns
    -------
    signal_global, signal_std : ndarray, shape (samples, channels)
        Reconstructed signals for the global and the std threshold. A 1-D input
        yields a single column, so the result is always indexed as [:, channel].

    Raises
    ------
    ValueError
        If `data` has more than two dimensions.
    """

    def dwt(channel,wavelet):
        # Multi-level decomposition of one 1-D channel
        return wavedec(channel,wavelet,level=level)

    def global_threshold(channel,coeffs):
        # (approximation threshold, detail threshold) from the first two coefficient arrays
        scale = np.sqrt(2*np.log(len(channel)))
        return tuple((np.median(np.abs(c))/0.6745)*scale for c in coeffs[:2])

    def std_threshold(coeffs):
        # (approximation threshold, detail threshold) from the first two coefficient arrays
        return tuple(1.5*np.std(c) for c in coeffs[:2])

    def apply_threshold(coeffs,threshold):
        # np.where builds new arrays, so the decomposition stays untouched and can be
        # reused for the other rule (in-place zeroing made both results identical).
        approx_limit,detail_limit = threshold
        approx = np.where(np.abs(coeffs[0])>approx_limit,0.0,coeffs[0])
        details = [np.where(np.abs(c)>detail_limit,0.0,c) for c in coeffs[1:]]
        return [approx]+details

    def inv_dwt(coeffs,wavelet,n_samples):
        # The reconstruction can be one sample longer than the input (odd-length
        # signals); cut back to the input length instead of always dropping a sample.
        return waverec(coeffs,wavelet)[:n_samples]

    channels = np.asarray(data,dtype=float)
    if channels.ndim == 1:
        channels = channels.reshape(-1,1)
    if channels.ndim != 2:
        raise ValueError("data must be 1D (samples) or 2D (samples x channels)")

    n_samples = channels.shape[0]
    signal_global = np.empty_like(channels)
    signal_std = np.empty_like(channels)
    for idx in range(channels.shape[1]):
        channel = channels[:,idx]
        coeffs = dwt(channel,wavelet)
        coeffs_global = apply_threshold(coeffs,global_threshold(channel,coeffs))
        coeffs_std = apply_threshold(coeffs,std_threshold(coeffs))
        signal_global[:,idx] = inv_dwt(coeffs_global,wavelet,n_samples)
        signal_std[:,idx] = inv_dwt(coeffs_std,wavelet,n_samples)
    return signal_global,signal_std


wavelet = 'coif5'
# Decompose channel 0 once and unpack both reconstructions
# (global threshold, std threshold) instead of running dwt_only twice.
dwt_eeg_gt,dwt_eeg_st = dwt_only(notchFilterOutput[:,0],wavelet)


# Only one channel was passed in, so the cleaned signal is column 0 of each result.
plt.plot(time,dwt_eeg_gt[:,0],color='r',label='Global Threshold')
plt.legend()
plt.show()
plt.plot(time,dwt_eeg_st[:,0],color='b',label='STD Threshold')
plt.legend()
plt.show()

In [None]:
# Signal to Artifact Ratio (SAR) and mean squared error (MSE) between the
# notch-filtered channel 0 and its DWT-cleaned versions.
# The cleaned signals are dwt_eeg_gt / dwt_eeg_st from the DWT cell; the names
# new_signal_global / new_signal_std used here before are not defined anywhere
# in this notebook and fail on a fresh kernel.
sar_global_1 = sar(notchFilterOutput[:,0],dwt_eeg_gt[:,0])
sar_std_1 = sar(notchFilterOutput[:,0],dwt_eeg_st[:,0])
print("SAR for global threshold: ",sar_global_1)
print("SAR for standard deviation threshold: ",sar_std_1)

mse_global_1 = mse(notchFilterOutput[:,0],dwt_eeg_gt[:,0])
mse_std_1 = mse(notchFilterOutput[:,0],dwt_eeg_st[:,0])
print("MSE for global threshold: ",mse_global_1)
print("MSE for standard deviation threshold: ",mse_std_1)