# Deep Identification of Subject Age from EEG Data

This project aims to reproduce the ES1D model on a new EEG dataset.

## Preprocessing

In [91]:
import pandas as pd
import numpy as np
import random as rd
import os
from scipy import signal
import matplotlib.pyplot as plt

In [92]:
# Experiment-specific parameters shared by all preprocessing steps below.

FS = 128            # sampling frequency in Hz; NOTE(review): confirm it matches the CSV recordings
F_LOW, F_HIGH = 4, 48  # band-pass cutoffs in Hz (F_HIGH must stay below the Nyquist frequency FS / 2)
N_TAPS = 212        # number of FIR filter taps
NPERSEG = 768       # samples per segment / PSD window (768 samples = 6 s at 128 Hz)

In [93]:
# function to generate n number of filepaths from a given directory
def get_filepaths(dir_path, n):

    # n random filepaths
    file_paths = rd.sample(os.listdir(dir_path), n)

    return file_paths

In [94]:
# load training data from the various csv files
# stored in data/train with each subject in a separate file

# first row of each file contains the age in '# Age = 30' format
# so we need to extract that and store it in a separate column
# and then drop the first row and the starting '#' from the second row (remaining headers)

# load function for n number of files, uses random selections

def load_data(filepath):
    # load n random files from data/train
    # return a dataframe with all the data
    
    # append the './data/train/' to the filepath
    filepath = './data/train/' + filepath

    # get rid of the '#' and 'Age = ' from the first row
    # and store the age in a separate column
    with open(filepath, 'r') as f:
        age = f.readline().replace('#', '').replace('Age = ', '').strip()

    # load the data from the file
    df = pd.read_csv(filepath, skiprows=1)

    # add a column for age
    df['age'] = age

    # # reset the index
    # df = df.reset_index(drop=True)

    # return the dataframe 
    return df

In [95]:
# Load one example recording to inspect its layout: one column per CSV field plus
# the 'age' column added by load_data (the DataFrame is displayed below).
df = load_data('00000021_s003_t000.csv')
df

Unnamed: 0,# EEG FP1-REF,EEG FP2-REF,EEG F3-REF,EEG F4-REF,EEG C3-REF,EEG C4-REF,EEG P3-REF,EEG P4-REF,EEG O1-REF,EEG O2-REF,...,EEG ROC-REF,EEG LOC-REF,EEG EKG1-REF,EEG T1-REF,EEG T2-REF,PHOTIC-REF,IBI,BURSTS,SUPPR,age
0,4907.684,4722.900,4908.142,4623.871,2863.313,4999.847,4907.989,4959.869,4999.694,4907.837,...,4958.648,4702.911,4958.648,-2705.378,4818.268,0.0,0.0,0.0,0.0,28
1,4919.281,4715.118,4918.365,4554.443,2626.802,4999.847,4919.128,4952.087,4999.847,4919.433,...,4951.782,4687.958,4952.240,-2510.828,4830.780,0.0,0.0,0.0,0.0,28
2,4906.006,4743.347,4906.769,4631.653,2677.919,4999.847,4906.463,4946.136,4999.847,4906.311,...,4945.526,4724.274,4945.373,-2560.724,4829.864,0.0,0.0,0.0,0.0,28
3,4898.071,4690.704,4897.461,4548.797,2521.974,4999.847,4898.071,4939.422,4999.847,4898.224,...,4938.659,4664.917,4938.965,-2387.537,4804.992,0.0,0.0,0.0,0.0,28
4,4897.613,4727.630,4897.766,4598.846,2487.642,4999.847,4897.766,4932.708,4999.847,4897.766,...,4932.403,4706.268,4932.403,-2369.074,4833.526,0.0,0.0,0.0,0.0,28
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
316995,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,...,0.004,0.004,0.004,0.004,0.004,0.0,0.0,0.0,0.0,28
316996,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,...,0.004,0.004,0.004,0.004,0.004,0.0,0.0,0.0,0.0,28
316997,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,...,0.004,0.004,0.004,0.004,0.004,0.0,0.0,0.0,0.0,28
316998,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,0.004,...,0.004,0.004,0.004,0.004,0.004,0.0,0.0,0.0,0.0,28


In [96]:
# Welch power-spectral-density estimate

def welch_psd(x, fs, nperseg):
    """Estimate the power spectral density of ``x`` with Welch's method.

    Thin wrapper around ``scipy.signal.welch`` (SciPy's default window,
    overlap and detrending) for a signal sampled at ``fs`` Hz, using
    ``nperseg`` samples per Welch segment.

    Returns ``(Pxx, f)``: the PSD first, then the frequency bins. This is
    the reverse of ``scipy.signal.welch``'s own ``(f, Pxx)`` order.
    """
    freqs, power = signal.welch(x, fs, nperseg=nperseg)
    return power, freqs

In [97]:
# Morlet wavelet time-frequency decomposition
# (implemented directly with NumPy/SciPy: scipy.signal.cwt/ricker are removed in
# recent SciPy releases and do not take frequencies or a cycle count.)

def wavelet_tf(x, fs, freqs, n_cycles):
    """Complex Morlet wavelet time-frequency decomposition of a 1-D signal.

    Parameters
    ----------
    x : array_like
        1-D signal.
    fs : float
        Sampling frequency of ``x`` in Hz.
    freqs : array_like
        Centre frequencies in Hz (strictly positive) at which to evaluate
        the decomposition.
    n_cycles : float or array_like
        Number of oscillation cycles per wavelet: one value for every
        frequency, or one value per entry of ``freqs``. More cycles give
        finer frequency resolution and coarser time resolution.

    Returns
    -------
    numpy.ndarray
        Complex array of shape ``(len(freqs), len(x))``; ``np.abs(tf) ** 2``
        is the power and ``np.angle(tf)`` the phase at each frequency/time.

    Raises
    ------
    ValueError
        If any frequency or cycle count is not strictly positive, or if
        ``n_cycles`` cannot be broadcast to the shape of ``freqs``.
    """
    x = np.asarray(x, dtype=float)
    freqs = np.atleast_1d(np.asarray(freqs, dtype=float))
    if np.any(freqs <= 0):
        raise ValueError("freqs must be strictly positive")
    n_cycles = np.broadcast_to(np.asarray(n_cycles, dtype=float), freqs.shape)
    if np.any(n_cycles <= 0):
        raise ValueError("n_cycles must be strictly positive")

    tf = np.empty((freqs.size, x.size), dtype=complex)
    for row, (freq, cycles) in enumerate(zip(freqs, n_cycles)):
        # Width (s) of the Gaussian envelope that holds `cycles` oscillations.
        sigma_t = cycles / (2 * np.pi * freq)
        # Truncate the envelope at +/- 5 sigma, where it has decayed below 4e-6 of its peak.
        half = int(np.ceil(5 * sigma_t * fs))
        t = np.arange(-half, half + 1) / fs
        wavelet = np.exp(2j * np.pi * freq * t) * np.exp(-t ** 2 / (2 * sigma_t ** 2))
        # Unit-energy normalisation keeps power comparable across frequencies.
        wavelet /= np.linalg.norm(wavelet)
        # mode='same' keeps the output aligned with, and as long as, the input.
        tf[row] = signal.fftconvolve(x, wavelet, mode='same')
    return tf

In [98]:
# Band-pass filtering
# - Hamming-windowed FIR band-pass between f_low and f_high
#   (4-48 Hz with 212 taps under the experiment parameters)

def filter_signal(x, fs, f_low, f_high, n_taps):
    """Band-pass ``x`` with a Hamming-windowed FIR filter.

    Parameters
    ----------
    x : array_like
        Signal to filter.
    fs : float
        Sampling frequency in Hz.
    f_low, f_high : float
        Lower and upper cutoff frequencies in Hz.
    n_taps : int
        Number of filter taps.

    Returns
    -------
    numpy.ndarray
        Filtered signal, same length as ``x``. The filter is applied causally
        with ``scipy.signal.lfilter``, so the linear-phase FIR output lags
        the input by ``(n_taps - 1) / 2`` samples.
    """
    nyquist = fs / 2
    band = [f_low / nyquist, f_high / nyquist]
    coefficients = signal.firwin(n_taps, band, window='hamming', pass_zero=False)
    return signal.lfilter(coefficients, 1.0, x)

In [99]:
# Artefact removal

def remove_artefacts(x, fs, f_low, f_high, n_taps, nperseg):
    """Band-pass ``x`` and suppress residual power above ``f_high``.

    The signal is band-passed with ``filter_signal``. Its Welch PSD
    (``nperseg`` samples per segment) is then scanned for outlier bins, i.e.
    bins whose power exceeds mean + 3 * std of the whole PSD. If the highest
    outlier lies above ``f_high``, meaning out-of-band power the FIR filter
    did not remove, every Fourier component of the filtered signal above
    ``f_high`` is set to zero.

    Parameters
    ----------
    x : array_like
        Raw 1-D signal.
    fs : float
        Sampling frequency in Hz.
    f_low, f_high : float
        Band-pass cutoff frequencies in Hz.
    n_taps : int
        Number of FIR filter taps.
    nperseg : int
        Welch segment length used for outlier detection.

    Returns
    -------
    numpy.ndarray
        Cleaned signal with the same length as ``x``.
    """
    x_filt = filter_signal(x, fs, f_low, f_high, n_taps)

    Pxx, f = welch_psd(x_filt, fs, nperseg)
    threshold = np.mean(Pxx) + 3 * np.std(Pxx)
    outlier_freqs = f[Pxx > threshold]

    # A flat PSD may have no outlier at all; np.max would raise on the empty array.
    if outlier_freqs.size == 0 or outlier_freqs.max() <= f_high:
        return x_filt

    # Edit the spectrum of the *signal*, not its PSD: inverting a PSD yields an
    # autocorrelation estimate of about nperseg samples, not the cleaned signal.
    spectrum = np.fft.rfft(x_filt)
    spectrum_freqs = np.fft.rfftfreq(len(x_filt), d=1 / fs)
    spectrum[spectrum_freqs > f_high] = 0
    return np.fft.irfft(spectrum, n=len(x_filt))

In [100]:
# Split a recording into consecutive, non-overlapping segments

def segment_signal(x, fs, nperseg):
    """Cut ``x`` into back-to-back rectangular windows of ``nperseg`` samples.

    Segments do not overlap, and trailing samples that do not fill a whole
    segment are dropped. ``fs`` is accepted for signature consistency with
    the other preprocessing helpers but is not used.

    Returns a list of slices of ``x`` (arrays for NumPy input).
    """
    n_full = int(np.floor(len(x) / nperseg))
    return [x[k * nperseg:(k + 1) * nperseg] for k in range(n_full)]

In [101]:
# PSD estimation on 50%-overlapping windows

def segment_psd(x, fs, nperseg):
    """Compute the PSD of each 50%-overlapping ``nperseg``-sample window of ``x``.

    Windows start every ``nperseg // 2`` samples, and only windows that fit
    entirely inside ``x`` are used. Each window goes to ``welch_psd`` with
    ``nperseg`` equal to its own length, so each estimate uses a single
    Welch segment (one windowed periodogram).

    Parameters
    ----------
    x : array_like
        1-D signal.
    fs : float
        Sampling frequency in Hz.
    nperseg : int
        Window length in samples.

    Returns
    -------
    list of numpy.ndarray
        One PSD per window; the frequency bins are discarded.
    """
    # Use the same integer hop for counting and slicing. A float hop
    # (nperseg / 2) undercounts windows for odd nperseg. The minimum of 1
    # keeps nperseg == 1 from producing a zero step.
    hop = max(nperseg // 2, 1)
    if len(x) < nperseg:
        return []
    n_windows = (len(x) - nperseg) // hop + 1

    psds = []
    for i in range(n_windows):
        start = i * hop
        Pxx, _ = welch_psd(x[start:start + nperseg], fs, nperseg)
        psds.append(Pxx)
    return psds

In [102]:
# Full preprocessing for one recording: load, clean, segment and estimate PSDs
# for the EEG P3 and P4 channels.

def preprocess(filepath, fs, f_low, f_high, n_taps, nperseg):
    """Run the preprocessing chain on one recording.

    For each of the 'EEG P3-REF' and 'EEG P4-REF' columns of the file
    loaded by ``load_data``, this function:

    - removes artefacts with ``remove_artefacts``;
    - cuts the cleaned signal into non-overlapping ``nperseg``-sample
      segments with ``segment_signal``;
    - computes PSDs on 50%-overlapping ``nperseg``-sample windows with
      ``segment_psd``.

    NOTE(review): the PSD windows overlap while the segments do not, so
    there are roughly twice as many PSDs as segments. The two lists are not
    index-aligned; confirm this is intended.

    Returns
    -------
    tuple
        ``(psds_P3, psds_P4, segments_P3, segments_P4)``.
    """
    recording = load_data(filepath)

    psds, segments = {}, {}
    for channel, column in (('P3', 'EEG P3-REF'), ('P4', 'EEG P4-REF')):
        cleaned = remove_artefacts(recording[column].values,
                                   fs, f_low, f_high, n_taps, nperseg)
        segments[channel] = segment_signal(cleaned, fs, nperseg)
        psds[channel] = segment_psd(cleaned, fs, nperseg)

    return psds['P3'], psds['P4'], segments['P3'], segments['P4']

In [103]:
# Plot the PSDs of all windows of one recording

def plot_psd(psds, filename, fs=None, out_dir="./figs"):
    """Overlay the PSD of every window and save the figure as a PNG.

    Parameters
    ----------
    psds : sequence of 1-D arrays
        One-sided PSDs as returned by ``segment_psd`` (one per window).
        Assumes an even window length, as with NPERSEG = 768.
    filename : str
        Used as the plot title and as the base name of the saved PNG.
    fs : float, optional
        Sampling frequency in Hz; defaults to the module-level ``FS``.
    out_dir : str, optional
        Directory the figure is written to; created if missing.
    """
    if fs is None:
        fs = FS

    fig, ax = plt.subplots()
    for psd in psds:
        # A one-sided PSD with m bins comes from a 2 * (m - 1)-sample window.
        # rfftfreq needs that window length, not the number of PSD bins.
        freqs = np.fft.rfftfreq(2 * (len(psd) - 1), d=1 / fs)
        ax.plot(freqs, psd)

    ax.set_title(filename)
    ax.set_xlabel("Frequency (Hz)")
    # welch returns linear power density, not decibels.
    ax.set_ylabel("Power spectral density")

    ax.set_xlim([0, 50])
    # NOTE(review): this fixed y-limit assumes very small PSD values; adjust to the data's units.
    ax.set_ylim([0, 0.0005])

    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, filename + ".png"))

In [112]:
# PSD report

def psd_report(P3_scores, P4_scores, out_path="./reports/psd_report.csv"):
    """Summarise each recording's PSDs per channel, then save and print the table.

    Parameters
    ----------
    P3_scores, P4_scores : dict
        Map recording file name -> ``(psds, segments)``, as built by
        ``full_pre``; only ``psds`` (element 0) is used. Rows follow the key
        order of ``P3_scores``, and ``P4_scores`` must contain every one of
        those keys.
    out_path : str, optional
        Destination CSV file; its directory is created if missing.

    The table has one row per (recording, channel), with these columns:

    - ``recording`` and ``channel``;
    - ``mean_PSD`` and ``std_PSD``: mean and standard deviation over all PSD
      values of all windows;
    - ``number_of_segments``: the number of PSD windows.
    """
    rows = []
    for filename in P3_scores:
        for channel, scores in (("P3", P3_scores), ("P4", P4_scores)):
            psds = scores[filename][0]
            rows.append([filename, channel, np.mean(psds), np.std(psds), len(psds)])

    report = pd.DataFrame(
        rows,
        columns=["recording", "channel", "mean_PSD", "std_PSD", "number_of_segments"],
    )

    # to_csv does not create missing directories.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    report.to_csv(out_path, index=False)

    print(report)



In [113]:
# Full preprocessing pipeline

def full_pre(n_files=1):
    """Preprocess ``n_files`` randomly chosen training recordings.

    Parameters
    ----------
    n_files : int, optional
        Number of recordings sampled without replacement from data/train/;
        defaults to 1.

    Returns
    -------
    tuple of dict
        ``(P3_scores, P4_scores)``. Each maps a recording's file name to that
        channel's ``(psds, segments)``, as returned by ``preprocess``.
    """
    P3_scores = {}
    P4_scores = {}

    for filepath in get_filepaths("data/train/", n_files):
        # get_filepaths returns bare os.listdir names, so basename is only a safeguard.
        filename = os.path.basename(filepath)

        psds_P3, psds_P4, segments_P3, segments_P4 = preprocess(
            filepath, FS, F_LOW, F_HIGH, N_TAPS, NPERSEG)

        P3_scores[filename] = (psds_P3, segments_P3)
        P4_scores[filename] = (psds_P4, segments_P4)
        # To also save PSD figures, call plot_psd(psds_P3, filename) here.

    return P3_scores, P4_scores


# Preprocess the sampled recording(s) and summarise their per-channel PSD statistics.
P3_SCORES, P4_SCORES = full_pre()
psd_report(P3_SCORES, P4_SCORES)

                recording channel  mean_PSD     std_PSD  number_of_segments
0  00000021_s003_t000.csv      P3  4.535239  176.379801                 824
1  00000021_s003_t000.csv      P4  2.499747   91.023724                 824


## The ES1D Model