In [1]:
import numpy as np
import pandas as pd
import h5py
import os
import time
from scipy.stats import skew, kurtosis
from scipy import signal, stats
from sklearn.preprocessing import StandardScaler, MinMaxScaler

import torch
from torch_geometric.data import InMemoryDataset, Data
from torch.nn import Linear
import torch.nn.functional as F 
from torch_geometric.nn import GCNConv, TopKPooling, global_mean_pool
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.data import DataLoader
import warnings
warnings.filterwarnings("ignore")

import matplotlib.pyplot as plt

In [2]:
def getCoherenceAdjacency(sample, channel, fs=160, nperseg=80):

    """
    Get average and stdv of pairwise MS-coherence between one channel and
    every other channel on 5 freq ranges, plus the absolute Spearman
    correlation with every other channel.
    @param sample: Array, a sample signal of shape (n_times, n_channels),
                   e.g. (640X64); time along axis 0, channels along axis 1
    @param channel: Integer, current channel for pairwise coherence
    @param fs: Float, sampling frequency in Hz passed to
               scipy.signal.coherence (default 160)
    @param nperseg: Integer, segment length passed to scipy.signal.coherence
                    (default 80)
    return: Array, size (10,): mean (first 5) then stdv (last 5), taken over
            all other channels, of the band-averaged coherence; bands ordered
            [delta, theta, alpha, beta, gamma]
            Array, adjacency vector of size (n_channels,): |Spearman rho|
            between `channel` and each other channel, 0 at index `channel`
    """

    # Band edges are inclusive on both ends (a frequency bin lying exactly on
    # an edge is counted in both neighbouring bands), as in the original
    # feature definition. Gamma is open-ended up to Nyquist.
    bands = (
        (0.5, 4),      # delta
        (4, 8),        # theta
        (8, 12),       # alpha
        (12, 35),      # beta
        (35, np.inf),  # gamma
    )

    n_channels = sample.shape[1]
    x = sample[:, channel]
    adjacency_vector = np.zeros((n_channels,))
    Cxy_pairwise = np.empty((n_channels - 1, len(bands)))

    band_masks = None
    i_Cxy = 0
    for other_channel in range(n_channels):
        if other_channel == channel:
            continue
        y = sample[:, other_channel]
        adjacency_vector[other_channel] = abs(stats.spearmanr(x, y)[0])
        f, Cxy = signal.coherence(x, y, fs, nperseg=nperseg)
        if band_masks is None:
            # The frequency grid depends only on len(x), fs and nperseg, so
            # it is identical for every pair; build the masks once.
            band_masks = [(f >= lo) & (f <= hi) for lo, hi in bands]
        Cxy_pairwise[i_Cxy] = [Cxy[mask].mean() for mask in band_masks]
        i_Cxy += 1

    return np.append(Cxy_pairwise.mean(axis=0), Cxy_pairwise.std(axis=0)), adjacency_vector

In [3]:
def getTimeFeatures(sample):

    """
    Compute ten time-domain descriptors for every channel of one sample, then
    standardize each descriptor across the channels of that sample.
    @param sample: Array, a sample signal of size (640X64), time along axis 0
    return: Array, time-domain feature matrix of size (64X10); column order is
            [mean |x|, variance, mean sqrt|x|, RMS, log detector,
             waveform length, DASDV, zero crossings, skewness, kurtosis],
            each column standardized with StandardScaler
    """

    abs_sample = abs(sample)
    # first-order differences between consecutive time steps
    step = sample[1:, :] - sample[:-1, :]
    abs_step = abs(step)

    # Zero crossings: neighbouring values whose product has sign -1 (opposite
    # signs) AND whose absolute jump is at least 0.01.
    prod_sign = np.sign(sample[1:, :] * sample[:-1, :])
    flips = np.where(prod_sign == 1, 0, prod_sign)
    flips = np.where(flips == -1, 1, flips)
    zero_cross = np.logical_and(flips, abs_step >= 0.01).sum(axis=0)

    columns = [
        abs_sample.mean(axis=0),                  # mean absolute value
        np.var(sample, axis=0),                   # variance
        np.sqrt(abs_sample).mean(axis=0),         # mean of sqrt(|x|)
        np.sqrt(np.mean(sample**2, axis=0)),      # root mean square
        np.exp(np.log(abs_sample).mean(axis=0)),  # log detector (geometric mean of |x|)
        np.sum(abs_step, axis=0),                 # waveform length
        np.sqrt(np.mean(step ** 2, axis=0)),      # difference absolute std value
        zero_cross,                               # zero crossings
        skew(sample, axis=0),                     # skewness
        kurtosis(sample, axis=0),                 # kurtosis
    ]
    stacked = np.column_stack(columns)

    # per-sample normalization of each feature across channels
    return StandardScaler().fit_transform(stacked)

In [4]:
def getFeaturesAdjacency(sample):

    """
    Extracts temporal and spectral features from sample signals.
    Calculate Spearman's correlation as weighted adjacency matrix.
    @param sample: Array, a sample signal of shape (n_times, n_channels),
                   e.g. (640X64); time along axis 0, channels along axis 1
    return: Array, feature matrix of size (n_channels X 25): the 10
            time-domain columns from getTimeFeatures, then 5 periodogram band
            powers [delta, theta, alpha, beta, gamma], then the 10 coherence
            summaries from getCoherenceAdjacency; the 15 frequency columns are
            standardized across channels
            Array, adjacency matrix of size (n_channels X n_channels); row i
            is the |Spearman rho| vector getCoherenceAdjacency returns for
            channel i
    """

    n_channels = sample.shape[1]
    time_features = getTimeFeatures(sample)

    # Band edges inclusive on both ends; gamma is open-ended up to Nyquist.
    bands = ((0.5, 4), (4, 8), (8, 12), (12, 35), (35, np.inf))

    # One periodogram call for all channels: pxx has shape (n_freqs, n_channels).
    f, pxx = signal.periodogram(sample, 160, axis=0)
    band_power = np.column_stack(
        [pxx[(f >= lo) & (f <= hi)].mean(axis=0) for lo, hi in bands])

    coherence_features = np.empty((n_channels, 10))
    adjacency_matrix = np.zeros((n_channels, n_channels))
    for channel in range(n_channels):
        coherence, adjacency_vector = getCoherenceAdjacency(sample, channel)
        coherence_features[channel] = coherence
        adjacency_matrix[channel, :] = adjacency_vector

    # normalize each frequency-domain feature across channels
    freq_features = np.hstack((band_power, coherence_features))
    freq_features = StandardScaler().fit_transform(freq_features)

    return np.hstack((time_features, freq_features)), adjacency_matrix

In [5]:
def extract(data, threshold, feature_matrices, edge_index_lst, edge_attr_lst):

    """
    Fill the graph inputs of one dataset split in place.
    For every sample, computes its node feature matrix and |Spearman|
    adjacency matrix, and keeps each ordered channel pair (i, j) whose weight
    is strictly greater than `threshold` as an edge.
    @param data: Array, train, validation, or test signals; samples along axis 0
    @param threshold: Float, threshold for correlation (edges where adjacency > threshold)
    @param feature_matrices: Array, preallocated output; entry sample_ind
                             receives that sample's feature matrix
    @param edge_index_lst: List, appended with one (2 X n_edges) array of
                           [row; col] indices per sample
    @param edge_attr_lst: List, appended with one (n_edges X 1) array of edge
                          weights per sample, aligned with edge_index
    return: None; all results are written into the arguments
    """

    n_samples = data.shape[0]
    for idx in range(n_samples):
        if idx % 50 == 0:
            print(f'sample_ind: {idx}')
        features, adjacency = getFeaturesAdjacency(data[idx])
        feature_matrices[idx] = features

        keep = adjacency > threshold
        # argwhere and boolean-mask indexing both walk the matrix in row-major
        # order, so column k of the edge index and row k of the attributes
        # describe the same edge.
        edge_index_lst.append(np.argwhere(keep).T)
        edge_attr_lst.append(adjacency[keep][:, np.newaxis])

In [6]:
class EEGDataset(InMemoryDataset):

    """
    In-memory PyTorch Geometric dataset holding one graph per EEG sample.

    @param root: String, dataset root directory (passed to InMemoryDataset)
    @param eeg: Object exposing per-sample sequences `feature_matrix`,
                `edge_index` and `edge_attr` (e.g. an Xprep instance)
    @param tasks: Array, per-sample task labels; each graph stores
                  `tasks[i] - 1` as `y` (presumably 1-based task ids shifted
                  to 0-based class indices -- confirm against the data)
    @param transform: Callable or None, forwarded to InMemoryDataset
    @param pre_transform: Callable or None, applied to every graph once in
                          process(), before collation and caching

    NOTE: the collated graphs are saved to processed_paths[0] ('data.pt').
    PyG skips process() when that file already exists, in which case the
    `eeg` and `tasks` arguments are not used; delete the processed directory
    to force a rebuild.
    """

    def __init__(self, root, eeg, tasks, transform=None, pre_transform=None):
        # Must be set before super().__init__, which may call process().
        self.eeg = eeg
        self.y = tasks
        super(EEGDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Data is supplied in memory, so there are no raw files to look for.
        return []

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Nothing to download; see raw_file_names.
        pass

    def process(self):

        data_list = []
        for i in range(len(self.eeg.edge_index)):
            data = Data(x=torch.FloatTensor(self.eeg.feature_matrix[i]),
                        edge_index=torch.tensor(self.eeg.edge_index[i], dtype=torch.long),
                        edge_attr=torch.FloatTensor(self.eeg.edge_attr[i]),
                        y=torch.tensor(self.y[i]-1, dtype=torch.long))    # save tasks
            data_list.append(data)

        # Apply the pre_transform supplied to __init__ once, before caching,
        # following the standard PyG InMemoryDataset pattern.
        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        self.data, self.slices = self.collate(data_list)
        torch.save((self.data, self.slices), self.processed_paths[0])

In [7]:
class Xprep:

    """
    Plain container bundling one dataset split's per-sample node feature
    matrices, edge indices and edge attributes, in the form EEGDataset.process
    reads them (`feature_matrix`, `edge_index`, `edge_attr`).
    """

    def __init__(self, feature_matrix, edge_index, edge_attr):
        # Stored by reference: no copying and no validation.
        self.feature_matrix, self.edge_index, self.edge_attr = (
            feature_matrix, edge_index, edge_attr)

In [8]:
# Input (raw HDF5 splits) and output (processed GCN datasets) locations.
# Set EEG_PROJECT_DIR to run outside the original cluster without editing this
# cell; the default reproduces the original paths exactly.
PROJECT_DIR = os.environ.get('EEG_PROJECT_DIR', '/scratch/qh503/deepLearningProject')
loadPath = os.path.join(PROJECT_DIR, 'data_h5', '')
savePath = os.path.join(PROJECT_DIR, 'data_GCN',
                        'timeFreq25_tr1800val360ts_normExtractNorm_normTimeFreq', '')

In [9]:
def _load_split(filename):
    """Read a split's 'data', 'tasks' and 'subjects' datasets into memory, then close the file."""
    with h5py.File(os.path.join(loadPath, filename), "r") as h5:
        # [:] materializes each HDF5 dataset as a NumPy array, so nothing
        # below needs the file to stay open.
        return h5['data'][:], h5['tasks'][:], h5['subjects'][:]


tr_data, ytr, tr_subjects = _load_split("train1800_raw_EEG.h5")
val_data, yval, val_subjects = _load_split("valid360_raw_EEG.h5")
ts_data, yts, ts_subjects = _load_split("test360_raw_EEG.h5")

## Normalization

In [10]:
# Time/channel sizes come from the data rather than hard-coded (640, 64);
# channels are assumed to be the last axis, as in the original reshape.
n_times, n_channels = np.squeeze(tr_data).shape[-2:]


def _flatten_channels(x):
    """Reshape (n_samples, n_times, n_channels) to (n_samples * n_times, n_channels)."""
    return np.squeeze(x).reshape((-1, n_channels))


# flatten and reshape data
xtr_s_flattened = _flatten_channels(tr_data)
xval_s_flattened = _flatten_channels(val_data)
xts_s_flattened = _flatten_channels(ts_data)
print(xtr_s_flattened.shape)
print(xval_s_flattened.shape)
print(xts_s_flattened.shape)

# normalize data: per-channel standardization fit on the training split only
# and reused for validation/test, so their own statistics never leak in
scaler = StandardScaler()
Ztr_temp = scaler.fit_transform(xtr_s_flattened)
Zval_temp = scaler.transform(xval_s_flattened)
Zts_temp = scaler.transform(xts_s_flattened)

# reshape data back to (n_samples, n_times, n_channels)
Ztr = Ztr_temp.reshape((-1, n_times, n_channels))
Zval = Zval_temp.reshape((-1, n_times, n_channels))
Zts = Zts_temp.reshape((-1, n_times, n_channels))
print(Ztr.shape)
print(Zval.shape)
print(Zts.shape)

(1152000, 64)
(230400, 64)
(230400, 64)
(1800, 640, 64)
(360, 640, 64)
(360, 640, 64)


## Extract features for each sample

In [11]:
# %%time
# |Spearman rho| strictly above this value becomes a graph edge.
CORR_THRESHOLD = 0.4
# Per-channel features: 10 time-domain + 5 band powers + 10 coherence summaries.
N_NODE_FEATURES = 25
n_nodes = Ztr.shape[-1]  # one graph node per EEG channel

Xtr = np.empty((Ztr.shape[0], n_nodes, N_NODE_FEATURES))
Xval = np.empty((Zval.shape[0], n_nodes, N_NODE_FEATURES))
Xts = np.empty((Zts.shape[0], n_nodes, N_NODE_FEATURES))
tr_edge_index, tr_edge_attr = [], []
val_edge_index, val_edge_attr = [], []
ts_edge_index, ts_edge_attr = [], []
print('trainset processing...')
extract(Ztr, CORR_THRESHOLD, Xtr, tr_edge_index, tr_edge_attr)
print('valset processing...')
extract(Zval, CORR_THRESHOLD, Xval, val_edge_index, val_edge_attr)
print('testset processing...')
extract(Zts, CORR_THRESHOLD, Xts, ts_edge_index, ts_edge_attr)

trainset processing...
sample_ind: 0


## GCN dataset generation

In [12]:
# Bundle each split's node features and edge lists for EEGDataset.
tr_eeg = Xprep(feature_matrix=Xtr, edge_index=tr_edge_index, edge_attr=tr_edge_attr)
val_eeg = Xprep(feature_matrix=Xval, edge_index=val_edge_index, edge_attr=val_edge_attr)
ts_eeg = Xprep(feature_matrix=Xts, edge_index=ts_edge_index, edge_attr=ts_edge_attr)

In [13]:
# Each split gets its own root directory and MUST use its own labels. The test
# set needs yts: val and test both have 360 samples, so passing yval here
# would silently attach validation labels to test graphs.
# If '<savePath>/test' was already processed with wrong labels, delete it:
# PyG reuses an existing processed file instead of re-running process().
tr_dataset = EEGDataset(root=os.path.join(savePath, 'train'), eeg=tr_eeg, tasks=ytr)
val_dataset = EEGDataset(root=os.path.join(savePath, 'valid'), eeg=val_eeg, tasks=yval)
ts_dataset = EEGDataset(root=os.path.join(savePath, 'test'), eeg=ts_eeg, tasks=yts)