In [3]:
%matplotlib notebook
import mne
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from scipy.signal import butter, cheby1, filtfilt

from scipy.signal import spectrogram, stft, istft, check_NOLA
import pickle

# IMPORT DATA AND PREPROCESS

In [5]:
# 31, 35, 38 and test folders for good data
# odd scalp on left, even on right

# SPECIFY PATIENT AND SCALP TARGET ELECTRODE
patient = 'UFSEEG031'
mode = 'Wake'
targetScalpElectrodes = [] # assigned below

# SPECIFY ARTIFACT ELECTRODES FROM WORD FILES
artifactElectrodes = {}
artifactElectrodes['UFSEEG031'] = ['LTP7', 'LTP8', 'LAH11', 'LAH12', 'LPH10', 'LPH11', 'LPH12','LOF15', 'LOF16']

#filepath = '/blue/gkalamangalam/ALLDATA/SEEG/%s/SEEG/EDF/TestClipSleep.edf' % patient
filepath = '/blue/gkalamangalam/ALLDATA/SEEG/%s/SEEG/EDF/TestClipAwake/TestClip%sME1.edf' % (patient, mode)

raw = mne.io.read_raw_edf(filepath,preload=True)
sfreq = int(raw.info['sfreq'])

scalpElectrodes = {}
scalpElectrodes[patient] = [i for i in raw.ch_names if len(i) == 2 and i !='PR']
targetScalpElectrodes = scalpElectrodes[patient]
print()
print(scalpElectrodes)
print(raw)
print(raw.info)

Extracting EDF parameters from /blue/gkalamangalam/ALLDATA/SEEG/UFSEEG031/SEEG/EDF/TestClipAwake/TestClipWakeME1.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 241023  =      0.000 ...   235.374 secs...

{'UFSEEG031': ['F7', 'F8', 'F3', 'F4', 'C3', 'C4', 'P7', 'P8', 'P3', 'P4']}
<RawEDF | TestClipWakeME1.edf, 148 x 241024 (235.4 s), ~272.3 MB, data loaded>
<Info | 7 non-empty values
 bads: []
 ch_names: LTP1, LTP2, LTP3, LTP4, LTP5, LTP6, LTP7, LTP8, LAM1, LAM2, ...
 chs: 148 EEG
 custom_ref_applied: False
 highpass: 0.0 Hz
 lowpass: 512.0 Hz
 meas_date: 2001-01-01 10:26:35 UTC
 nchan: 148
 projs: []
 sfreq: 1024.0 Hz
>


In [6]:
# DISCARD ALL CHANNELS EXCEPT GOOD SEEG CHANNELS AND THE SCALP TARGETS

channels = [i for i in raw.ch_names if i not in artifactElectrodes[patient] and i[0] in {'L', 'R'}] + targetScalpElectrodes
raw.pick_channels(channels)#.plot(duration=5.0, n_channels=20);

nScalp = len(targetScalpElectrodes)
nDepth = len(channels) - nScalp

In [7]:
# LOWPASS FILTER THE DATA, SUBSAMPLE THE DATA, SCALE ALL CHANNELS, AND EXTRACT TO NUMPY ARRAY

filterWindow = 64
subsampleFreq = filterWindow * 2   # FINAL FREQUENCY IN HERTZ AFTER SUBSAMPLING
filterOrder = 5
goodTimes = [] # set this below

df = raw.to_data_frame().drop(labels=['time'], axis=1)
data = df.to_numpy()

# SHOULD WE SCALE HERE OR AFTER FILTER AND SUBSAMPLE?
scaler = StandardScaler()
data = scaler.fit_transform(data)
goodTimes = range(int(data.shape[0] * .75))
data = data[goodTimes,:]

b, a = butter(filterOrder, filterWindow, btype='lowpass', fs = sfreq)
data = filtfilt(b, a, data, axis=0)
    
dataSubsampled = data[::sfreq // subsampleFreq,:]

# SEE QUESTION ABOVE...
#scaler = StandardScaler()
#dataSubsampled = scaler.fit_transform(dataSubsampled)

pd.DataFrame(dataSubsampled, columns=df.columns)

Unnamed: 0,LTP1,LTP2,LTP3,LTP4,LTP5,LTP6,LAM1,LAM2,LAM3,LAM4,...,F7,F8,F3,F4,C3,C4,P7,P8,P3,P4
0,1.269132,0.611309,0.332447,0.513127,0.050132,-0.921070,-0.487324,-0.081578,-0.525729,0.334276,...,0.010991,0.344059,-0.011087,0.166701,0.107314,0.244493,0.181124,0.244384,0.146235,0.207659
1,1.132279,0.446085,0.220560,0.418795,-0.012608,-1.192597,-0.458433,0.084668,-0.519432,0.192957,...,-0.226039,0.071321,-0.286986,-0.133555,-0.177902,-0.070756,-0.120731,0.035725,-0.166100,-0.086214
2,1.519371,0.780325,0.556298,0.783363,0.417908,-0.903523,-0.043603,0.561195,-0.101200,0.637068,...,0.324991,0.702319,0.293988,0.468108,0.383928,0.487604,0.412625,0.569097,0.366216,0.430346
3,1.045590,0.179590,-0.021626,0.243580,-0.178253,-1.434499,-0.418117,0.010637,-0.648773,-0.043092,...,-0.423745,-0.019962,-0.475247,-0.294202,-0.405888,-0.271123,-0.295942,-0.119615,-0.384137,-0.304387
4,1.524571,0.708372,0.451041,0.814309,0.455070,-0.785436,-0.112308,0.524834,-0.020276,0.738102,...,0.487048,0.855929,0.441763,0.629191,0.526852,0.635298,0.623851,0.732582,0.531818,0.571096
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
22591,0.357568,0.311642,0.254523,0.269244,-0.112806,-1.129674,-0.668119,-0.831806,-0.896787,-0.550804,...,-0.690259,-0.616377,-0.617147,-0.623802,-0.627646,-0.555057,-0.595257,-0.547246,-0.616649,-0.532412
22592,1.206071,1.275532,1.193510,1.190172,0.913044,-0.246881,0.140903,0.086183,0.065144,0.618205,...,0.592805,0.733179,0.705584,0.729668,0.707240,0.810024,0.685200,0.734267,0.695917,0.775567
22593,0.226729,0.246565,0.220536,0.218664,-0.202363,-1.391676,-0.762677,-0.872524,-0.948308,-0.616703,...,-0.954188,-0.804301,-0.867815,-0.830987,-0.873775,-0.755873,-0.875619,-0.825106,-0.875512,-0.787598
22594,1.276653,1.440350,1.394966,1.407198,1.077029,-0.262484,0.180379,0.213563,0.239611,0.844164,...,0.757063,0.958811,0.874841,0.958370,0.862809,0.978906,0.824342,0.891225,0.849699,0.927649


In [8]:
# PARTITION TIME SERIES INTO CONTIGUOUS TRAIN AND VALIDATION BLOCKS
# OTHERWISE WHEN VALID SET IS RANDOMLY DISTRIBUTED, IMPLICIT OVERFITTING OCCURS DUE TO TEMPORAL DEPENDENCY
# EACH BLOCK (SPECIFIED IN SECONDS) IS DIVIDED INTO TRAIN (SPECIFIED BY FRACTION, COMES FIRST) AND VALIDIDATION (COMES LAST)

def timeseriesTrainValidSplit(secondsInBlock, totalSeconds, trainFraction, subsampleFreq):
    nBlock = int(totalSeconds / secondsInBlock)
    samplesPerBlock = subsampleFreq * secondsInBlock
    trainIndexProto = np.arange(0, samplesPerBlock * trainFraction, dtype=int)
    validIndexProto = np.arange(samplesPerBlock * trainFraction, samplesPerBlock, dtype=int)

    trainIndexBlocks = [(trainIndexProto + (i * samplesPerBlock)).astype(int) for i in range(nBlock)]
    validIndexBlocks = [(validIndexProto + (i * samplesPerBlock)).astype(int) for i in range(nBlock)]

    trainIndices = np.concatenate(trainIndexBlocks).astype(int)
    validationIndices = np.concatenate(validIndexBlocks).astype(int)
    #return trainIndices, validationIndices
    return trainIndexBlocks, validIndexBlocks, trainIndices, validationIndices

def timeDomainDataMake(indexBlocks, halfWindow, dataSubsampled):
    xList = []
    yList = []
    _, nChannels = dataSubsampled.shape
    for thisBlock in indexBlocks:
        thisData = np.vstack([np.zeros((halfWindow, nChannels)), 
                             dataSubsampled[thisBlock,:], 
                             np.zeros((halfWindow, nChannels))])
        
        for t in range(0, len(thisBlock)):
            thisX = thisData[t:t + (2 * halfWindow) + 1,0:nDepth].flatten()
            thisY = thisData[t + halfWindow, nDepth:]
            xList.append(thisX)
            yList.append(thisY)

    xTimeDomain = np.stack(xList, axis = 0)
    #yTimeDomain = np.expand_dims(np.array(yList), axis=1)
    yTimeDomain = np.stack(yList, axis = 0)
    return xTimeDomain, yTimeDomain

# INDICES FOR TRAIN/VALIDIDATION SPLIT

In [9]:
secondsInBlock = 5
trainFraction = 0.0
totalSeconds = int(dataSubsampled.shape[0]/subsampleFreq)

if trainFraction > 0.0:
    trainIndexBlocks, validIndexBlocks, trainIndices, validationIndices = timeseriesTrainValidSplit(secondsInBlock, 
                                                                                                    totalSeconds, 
                                                                                                    trainFraction, 
                                                                                                    subsampleFreq)

# TIME DOMAIN DATA

In [10]:
halfWindowSeconds = .25

halfWindowSamples = int(halfWindowSeconds * subsampleFreq)

if trainFraction > 0.0:
    xTrainTimeDomain, yTrainTimeDomain = timeDomainDataMake(trainIndexBlocks, halfWindowSamples, dataSubsampled)
    xValidTimeDomain, yValidTimeDomain = timeDomainDataMake(validIndexBlocks, halfWindowSamples, dataSubsampled)

else:
    validIndexBlocks = [np.array(range(0, dataSubsampled.shape[0]))]
    xValidTimeDomain, yValidTimeDomain = timeDomainDataMake(validIndexBlocks, halfWindowSamples, dataSubsampled)
    xTrainTimeDomain, yTrainTimeDomain = np.array([]), np.array([])
    
print(xTrainTimeDomain.shape, yTrainTimeDomain.shape, xValidTimeDomain.shape, yValidTimeDomain.shape)

(0,) (0,) (22596, 5655) (22596, 10)


In [12]:
arraySavePath = '/blue/gkalamangalam/jmark.ettinger/predictScalp/timeDomain_%s_%s_%s.npz' % (patient, targetScalpElectrodes, mode)
np.savez(arraySavePath, 
         xTrainTimeDomain=xTrainTimeDomain, 
         xValidTimeDomain=xValidTimeDomain,
         yTrainTimeDomain=yTrainTimeDomain,
         yValidTimeDomain=yValidTimeDomain)
print(arraySavePath)

/blue/gkalamangalam/jmark.ettinger/predictScalp/timeDomain_UFSEEG031_['F7', 'F8', 'F3', 'F4', 'C3', 'C4', 'P7', 'P8', 'P3', 'P4']_Wake.npz


# STFT DATA

In [None]:
# APPLY SHORT TERM FOURIER TRANSFORM TO THE DATA AND CHECK PARAMETERS FOR INVERTABILITY

secondsInSTFTWindow = .5
nperseg = subsampleFreq * secondsInSTFTWindow
noverlap = nperseg - 1
windowType = ('tukey', .25)

f, t, S = stft(dataSubsampled, fs=subsampleFreq, window=windowType, nperseg=nperseg, noverlap=noverlap, axis=0)

print('freq, ', 'time, ', 'stft shape')
print(f.shape, t.shape, S.shape)
print('inverse ok? ',check_NOLA(windowType, nperseg, noverlap))

In [None]:
x_trainComplex = S[:, 0:-1, trainIndices].transpose([2,0,1])
y_trainComplex = S[:, -1, trainIndices].transpose()

x_validComplex = S[:, 0:-1, validationIndices].transpose([2,0,1])
y_validComplex = S[:, -1, validationIndices].transpose()

# MAKE REAL-VALUED TRAINING DATA BY CONVERTING STFT COMPLEX NUMBERS TO R,THETA
_,_,numCol = x_trainComplex.shape
x_trainRTheta = np.hstack([np.hstack([np.abs(x_trainComplex[:,:,i]), 
                                      np.angle(x_trainComplex[:,:,i])]) for i in range(numCol)])
x_validRTheta = np.hstack([np.hstack([np.abs(x_validComplex[:,:,i]), 
                                      np.angle(x_validComplex[:,:,i])]) for i in range(numCol)])

y_trainRTheta = np.hstack([np.abs(y_trainComplex), np.angle(y_trainComplex)])
y_validRTheta = np.hstack([np.abs(y_validComplex), np.angle(y_validComplex)])

_, nY = y_trainRTheta.shape
x_trainRTheta.shape, x_validRTheta.shape, y_trainRTheta.shape, y_validRTheta.shape

In [None]:
# PLOT THE STFT OF A TIME SERIES (MAGNITUDE ONLY)

index = -1 # -1 is the target
vmax = .2

plt.figure()
plt.pcolormesh(t, f, np.abs(S[:,index,:]), shading='auto', cmap='hot', vmin=0, vmax=vmax)
plt.ylabel('Frequency [Hz]')
plt.xlabel('Time [sec]')
plt.title('Index: %s' % str(index))
plt.show()

In [None]:
arraySavePath = '/blue/gkalamangalam/jmark.ettinger/predictScalp/freqRTheta_%s_%s_%s.npz' % (patient, targetScalpElectrode, mode)
np.savez(arraySavePath, 
         x_trainRTheta=x_trainRTheta, 
         x_validRTheta=x_validRTheta, 
         y_trainRTheta=y_trainRTheta, 
         y_validRTheta=y_validRTheta)

# COMBINE TIME AND FREQUENCY DOMAIN DATA

In [None]:
x_trainTimeFreq = np.hstack([xTrainTimeDomain, x_trainRTheta])
y_trainTimeFreq = np.hstack([yTrainTimeDomain, y_trainRTheta])
x_validTimeFreq = np.hstack([xValidTimeDomain, x_validRTheta])
y_validTimeFreq = np.hstack([yValidTimeDomain, y_validRTheta])

In [None]:
x_trainTimeFreq.shape, y_trainTimeFreq.shape, x_validTimeFreq.shape, y_validTimeFreq.shape

In [None]:
arraySavePath = '/blue/gkalamangalam/jmark.ettinger/predictScalp/timeFreqRTheta_%s_%s_%s.npz' % (patient, targetScalpElectrode, mode)
np.savez(arraySavePath, 
         x_trainTimeFreq=x_trainTimeFreq, 
         x_validTimeFreq=x_validTimeFreq, 
         y_trainTimeFreq=y_trainTimeFreq, 
         y_validTimeFreq=y_validTimeFreq)

# SCRATCH

In [None]:
# stft parameter tests

windowType = ('tukey', .25)

fakeData = np.random.rand(10000, 5)

f, t, S = stft(fakeData, fs=1000, window=windowType, nperseg=100, noverlap=0, axis=0, boundary=None)

f.shape, t.shape, S.shape, f, t

In [None]:
# OLD VERSION FOR TIME DOMAIN

xTrainTimeDomain = dataSubsampled[trainIndices, 0:-1]
yTrainTimeDomain = np.expand_dims(dataSubsampled[trainIndices, -1], axis=1)

xValidTimeDomain = dataSubsampled[validationIndices, 0:-1]
yValidTimeDomain = np.expand_dims(dataSubsampled[validationIndices, -1], axis=1)

xTrainTimeDomain.shape, yTrainTimeDomain.shape, xValidTimeDomain.shape, yValidTimeDomain.shape