In [None]:
from mne import io
from sklearn.utils import shuffle
import scipy.signal
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import json

import pickle
import gzip

def filt_coef(order, cutoff, filt_type, sampleRate):
    """Design a digital Butterworth filter and return its (b, a) coefficients.

    Parameters
    ----------
    order : int
        Filter order passed to ``scipy.signal.butter``.
    cutoff : float or sequence of float
        Cutoff frequency in the same unit as ``sampleRate``. A length-2
        sequence may be given for band-type filters.
    filt_type : str
        Filter type understood by ``scipy.signal.butter``
        (e.g. 'lowpass', 'highpass').
    sampleRate : float
        Sampling rate of the signal the filter will be applied to.

    Returns
    -------
    b, a : ndarray
        Numerator and denominator polynomials of the filter.
    """
    # True division: floor division would truncate the Nyquist frequency for
    # odd sample rates (e.g. 125 -> 62 instead of 62.5) and skew the cutoff.
    nyquist = sampleRate / 2
    # butter() expects the cutoff normalised so that 1.0 is the Nyquist frequency.
    Wn = np.asarray(cutoff, dtype=float) / nyquist
    b, a = scipy.signal.butter(order, Wn, filt_type)
    return b, a

In [None]:
# Directory holding the raw .edf recordings. Built with os.path.join so the
# notebook also runs on non-Windows systems (same string as before on Windows).
home = os.path.join("..", "1230", "data", "data3", "raw")

In [None]:
# Get all edf files under this directory

# os.listdir returns entries in arbitrary, filesystem-dependent order; sort
# them so the dataset built below (and the seeded shuffle applied to it) is
# reproducible from one machine to the next.
newls = sorted(fname for fname in os.listdir(home) if fname.endswith('.edf'))
newls

In [None]:
# # Or uncomment this block to customize the files you want to include in your dataset

# newls = [
#     '0201_11.edf',
#     '0201_12.edf',
#     '0201_13.edf'
# ]

In [None]:
window_length = 1000 # length of the input array needed
sampleRate = 1000 # sample rate of raw data

datacuts = []
labels = []
# This is to reduce noise in data: cut all the data points beyond [-threshold, threshold]
# NOTE(review): MNE normally returns EDF signals in volts, in which case a
# threshold of 100 would never reject anything -- confirm the intended unit.
threshold = 100

# The filter coefficients depend only on the sample rate, so design them once
# instead of once per file.
b, a = filt_coef(2, 30, 'lowpass', sampleRate)
d, c = filt_coef(5, 1, 'highpass', sampleRate)

for fname in newls:
    # Read from the same directory that was scanned to build `newls`;
    # a separately hard-coded folder here can silently point at other data.
    fpath = os.path.join(home, fname)

    r = io.read_raw_edf(fpath)
    r.drop_channels(['ECG', 'VEOG', 'HEOG'])
    # Indexing the Raw object returns (data, times); keep only the data,
    # one row per requested channel.
    unfilter_datacut = np.array(r.load_data()[['Fp1','Fp2']][0])

    # filter the data: zero-phase low-pass then high-pass, applied along the
    # time axis of both channels at once
    filtered = scipy.signal.filtfilt(b, a, unfilter_datacut, axis=-1)
    filtered = scipy.signal.filtfilt(d, c, filtered, axis=-1)
    print(filtered.T.shape)

    # cut all the data points beyond [-threshold, threshold]:
    # a sample is kept only if BOTH channels are within the threshold
    keep = (np.abs(filtered) <= threshold).all(axis=0)
    datacut = filtered[:, keep]
    print("datacut:", datacut.shape)
    if datacut.shape[1] == 0:
        # Nothing survived the amplitude check; skip rather than crash on min().
        print("skipped (no samples within threshold):", fname)
        continue

    # normalise to [0, 1]: shift each channel so its minimum is 0, then divide
    # by the full width of the accepted amplitude range
    scale = 2 * threshold
    datacut = (datacut - datacut.min(axis=1, keepdims=True)) / scale

    # resample the array to 500Hz (keeps every second sample; assumes the
    # recording was sampled at 1000 Hz as configured in sampleRate)
    datacut = datacut[:, ::2]

    # cut the long array to short ones, length: 2s.
    n = datacut.shape[1] // window_length # number of complete windows
    # Triple classification: L/M/H
    # The class is encoded in the second character of the file-name prefix.
    label = int(fname.split("_")[0][1]) - 1.0 # 0,1,2
    for k in range(n):
        start = k * window_length
        datacuts.append(datacut[:, start: start + window_length])
        labels.append(label)

    print(len(datacuts))


In [None]:
# Shuffle windows and labels in unison; the fixed seed keeps the order repeatable.
shuffled_cuts, shuffled_labels = shuffle(datacuts, labels, random_state=0)
datacuts, labels = shuffled_cuts, shuffled_labels

In [None]:
# Stack the list of windows into a single array and display its shape.
datacuts = np.array(datacuts)
np.shape(datacuts)

In [None]:
# Tally how many windows fall in each class to check that the dataset is balanced.
from collections import Counter
class_counts = Counter(labels)
class_counts

In [None]:
# Peek at the first ten labels after shuffling.
labels[:10]

In [None]:
# Convert the label list to an array and display its shape.
labels = np.array(labels)
np.shape(labels)

In [None]:
# Overlay the first channel of every window as a quick visual sanity check.
for idx, _ in enumerate(labels):
    first_channel = datacuts[idx][0]
    plt.plot(first_channel)

In [None]:
# save your dataset, remember to change the file name!

dataset = [datacuts, labels]

# Build the path with os.path.join so it resolves on any OS; on Windows this
# produces exactly the same backslash-separated string as before.
save_file = os.path.join(
    "..", "data", "data3_attention",
    "fatigue_sub1101_low_fp1fp2_2s_resampled500Hz.pkl",
)
with open(save_file, 'wb') as handle:
    # Protocol 2 is kept as in the original (presumably so older Python
    # versions can load the file). Alternative:
    # pickle.dump(dataset, handle, protocol=pickle.HIGHEST_PROTOCOL)
    pickle.dump(dataset, handle, protocol=2)
# Report only after the file has been closed and flushed.
print("saved:", save_file)