In [7]:
%matplotlib notebook
import mne
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from scipy.signal import butter, cheby1, filtfilt

from scipy.signal import spectrogram, stft, istft, check_NOLA
import pickle

# IMPORT DATA AND PREPROCESS

In [2]:
# 31, 35, 38 and test folders for good data
# odd scalp on left, even on right

# SPECIFY PATIENT AND SCALP TARGET ELECTRODE
patient = 'UFSEEG031'
targetScalpElectrode = 'F7'

# SPECIFY ARTIFACT ELECTRODES FROM WORD FILES
artifactElectrodes = {}
artifactElectrodes['UFSEEG031'] = ['LTP7', 'LTP8', 'LAH11', 'LAH12', 'LPH10', 'LPH11', 'LPH12','LOF15', 'LOF16']

#filepath = '//ahcdfs.ahc.ufl.edu/files/NLGY/Groups/Epilepsy/KalamangalamLab/SEEG/%s/SEEG/EDF/TestClipSleep/TestClipSleep.edf' % patient
#filepath = 'C:/Users/the_m/Data/uFlorida/predictScalp/TestClipSleep.edf'
# NOTE(review): the active path does not use `patient`; keep the two in sync by hand.
filepath = '/blue/gkalamangalam/jmark.ettinger/predictScalp/TestClipSleep.edf'

raw = mne.io.read_raw_edf(filepath,preload=True)

# Later cells use sfreq as the integer base for decimation and filter design,
# so refuse to silently truncate a fractional sampling rate.
assert float(raw.info['sfreq']).is_integer(), 'non-integer sampling rate: %s' % raw.info['sfreq']
sfreq = int(raw.info['sfreq'])

# Validate the configuration against the recording before any channel picking.
assert targetScalpElectrode in raw.ch_names, '%s not found in %s' % (targetScalpElectrode, filepath)
missingArtifactElectrodes = sorted(set(artifactElectrodes[patient]) - set(raw.ch_names))
if missingArtifactElectrodes:
    # A misspelled artifact name would otherwise leave that bad channel in the feature set.
    print('WARNING: artifact electrodes not in recording:', missingArtifactElectrodes)

# Heuristic: two-character channel names are treated as scalp electrodes.
# NOTE(review): in this recording that also captures 'PR' -- confirm whether it is a scalp channel.
scalpElectrodes = {}
scalpElectrodes[patient] = [i for i in raw.ch_names if len(i) == 2]
print(scalpElectrodes)
print(raw)
print(raw.info)

Extracting EDF parameters from /blue/gkalamangalam/jmark.ettinger/predictScalp/TestClipSleep.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 1766015  =      0.000 ...  1724.624 secs...

{'UFSEEG031': ['F7', 'F8', 'F3', 'F4', 'C3', 'C4', 'P7', 'P8', 'P3', 'P4', 'PR']}
<RawEDF | TestClipSleep.edf, 148 x 1766016 (1724.6 s), ~1.95 GB, data loaded>
<Info | 7 non-empty values
 bads: []
 ch_names: LTP1, LTP2, LTP3, LTP4, LTP5, LTP6, LTP7, LTP8, LAM1, LAM2, ...
 chs: 148 EEG
 custom_ref_applied: False
 highpass: 0.0 Hz
 lowpass: 512.0 Hz
 meas_date: 2001-01-01 02:05:22 UTC
 nchan: 148
 projs: []
 sfreq: 1024.0 Hz
>


In [3]:
# DISCARD ALL CHANNELS EXCEPT GOOD SEEG CHANNELS AND THE SINGLE SCALP TARGET
# Every later cell treats the LAST column as the scalp target and all other columns as SEEG inputs.
# Depending on the MNE version, pick_channels may keep the file's channel order rather than the
# order of `channels`, so enforce the order explicitly and check it.

channels = [i for i in raw.ch_names if i not in artifactElectrodes[patient] and i[0] in {'L', 'R'}] + [targetScalpElectrode]
raw.pick_channels(channels)
raw.reorder_channels(channels)  # raises if any requested channel is missing
assert raw.ch_names[-1] == targetScalpElectrode, 'target must be the last channel'
raw  # .plot(duration=5.0, n_channels=20);

<RawEDF | TestClipSleep.edf, 88 x 1766016 (1724.6 s), ~1.16 GB, data loaded>

In [4]:
# LOWPASS FILTER THE DATA, SUBSAMPLE THE DATA, SCALE ALL CHANNELS, AND EXTRACT TO NUMPY ARRAY

filterWindow = 64
subsampleFreq = filterWindow * 2   # FINAL FREQUENCY IN HERTZ AFTER SUBSAMPLING
filterOrder = 5

# Decimating by slicing only produces exactly subsampleFreq Hz when the step divides sfreq evenly;
# otherwise every later "seconds -> samples" conversion (blocks, windows, STFT) is silently wrong.
assert sfreq % subsampleFreq == 0, 'sfreq %d is not a multiple of subsampleFreq %d' % (sfreq, subsampleFreq)
decimationFactor = sfreq // subsampleFreq

df = raw.to_data_frame().drop(labels=['time'], axis=1)
data = df.to_numpy()

# SHOULD WE SCALE HERE OR AFTER FILTER AND SUBSAMPLE?
scaler = StandardScaler()
data = scaler.fit_transform(data)

# TODO(review): the cutoff equals the post-decimation Nyquist frequency (subsampleFreq / 2). A Butterworth
# filter is only -3 dB down at its cutoff (roughly -6 dB after filtfilt's forward+backward pass), so
# content near 64 Hz can alias into the subsampled data. Consider a cutoff somewhat below subsampleFreq / 2.
b, a = butter(filterOrder, filterWindow, btype='lowpass', fs = sfreq)
data = filtfilt(b, a, data, axis=0)

dataSubsampled = data[::decimationFactor, :]

# SEE QUESTION ABOVE...
#scaler = StandardScaler()
#dataSubsampled = scaler.fit_transform(dataSubsampled)

pd.DataFrame(dataSubsampled, columns=df.columns)

Unnamed: 0,LTP1,LTP2,LTP3,LTP4,LTP5,LTP6,LAM1,LAM2,LAM3,LAM4,...,LOF6,LOF7,LOF8,LOF9,LOF10,LOF11,LOF12,LOF13,LOF14,F7
0,-0.667612,-0.549802,-0.330502,-0.630513,-0.410391,0.427998,-1.039937,-1.478466,-1.065475,-0.822086,...,-0.472660,-0.251789,-0.305125,-0.427591,-0.342536,-0.520219,-0.676221,-1.557479,-1.436998,1.470390
1,-0.738084,-0.652242,-0.419590,-0.709839,-0.516951,0.324686,-1.225429,-1.600105,-1.230353,-0.989760,...,-0.602268,-0.367111,-0.381405,-0.524810,-0.499453,-0.721132,-0.893702,-1.772723,-1.607507,1.058312
2,-0.699389,-0.571008,-0.339067,-0.611417,-0.399328,0.375402,-1.236690,-1.564062,-1.189561,-0.916646,...,-0.504535,-0.298947,-0.299513,-0.397127,-0.261086,-0.453411,-0.636030,-1.620408,-1.480281,1.353298
3,-0.780354,-0.592758,-0.324217,-0.600832,-0.394643,0.179410,-1.170003,-1.549847,-1.217005,-0.944660,...,-0.631047,-0.428185,-0.367389,-0.457706,-0.333179,-0.513631,-0.702537,-1.727769,-1.552393,0.822709
4,-0.795788,-0.529620,-0.255734,-0.530419,-0.319918,0.073355,-1.070847,-1.472654,-1.120555,-0.801174,...,-0.614531,-0.417956,-0.368458,-0.462024,-0.315052,-0.458764,-0.581764,-1.622157,-1.376446,1.034713
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
220747,0.599706,0.880566,0.329989,0.170022,-0.593336,1.171366,-0.045117,0.042424,0.944977,-0.167098,...,-0.590951,0.042153,0.886732,-0.173214,0.592052,0.421529,0.223303,0.964764,-0.054071,1.284230
220748,0.796347,1.025520,0.587031,0.394007,-0.302905,1.216613,0.026838,0.057508,1.047731,-0.033600,...,-0.504816,0.082375,0.843882,-0.232989,0.397827,0.121425,-0.030988,0.860772,-0.106194,0.975292
220749,1.036370,1.355208,1.042317,0.785103,0.030696,1.566432,0.065474,0.110856,1.129801,0.020733,...,-0.418505,0.173858,0.908300,-0.371659,0.114995,-0.123640,-0.186584,0.758177,-0.257866,0.655768
220750,0.531788,0.671990,0.579819,0.479165,0.179328,0.754062,0.143671,0.138600,0.557351,0.096973,...,-0.076799,0.150046,0.452227,-0.124155,0.072867,0.000889,-0.010492,0.358378,-0.057188,0.514930


In [41]:
# PARTITION TIME SERIES INTO CONTIGUOUS TRAIN AND VALIDATION BLOCKS
# OTHERWISE WHEN VALID SET IS RANDOMLY DISTRIBUTED, IMPLICIT OVERFITTING OCCURS DUE TO TEMPORAL DEPENDENCY
# EACH BLOCK (SPECIFIED IN SECONDS) IS DIVIDED INTO TRAIN (SPECIFIED BY FRACTION, COMES FIRST) AND VALIDATION (COMES LAST)

def timeseriesTrainValidSplit(secondsInBlock, totalSeconds, trainFraction, subsampleFreq):
    """Split a sampled time series into interleaved, contiguous train/validation blocks.

    The first ``int(totalSeconds / secondsInBlock)`` blocks of
    ``subsampleFreq * secondsInBlock`` samples are each divided into a leading
    train segment (``trainFraction`` of the block) and a trailing validation
    segment. Samples after the last full block are not used.

    Parameters
    ----------
    secondsInBlock : float
        Block length in seconds.
    totalSeconds : float
        Duration of the series to cover, in seconds.
    trainFraction : float
        Fraction of every block assigned to training.
    subsampleFreq : int
        Sampling rate (Hz) of the series the indices will address.

    Returns
    -------
    trainIndexBlocks, validIndexBlocks : list of int ndarray
        Per-block sample indices.
    trainIndices, validationIndices : int ndarray
        Concatenations of the per-block indices (empty if there are no blocks).
    """
    nBlock = int(totalSeconds / secondsInBlock)
    samplesPerBlock = int(round(subsampleFreq * secondsInBlock))
    # Fix the train/validation boundary once, as an integer. Passing the float product straight to
    # np.arange(..., dtype=int) let the two ranges round it independently, so a non-integer product
    # (e.g. 100 * 0.29 == 28.999999999999996) could put one sample in BOTH sets and drop another.
    nTrainPerBlock = int(round(samplesPerBlock * trainFraction))
    trainIndexProto = np.arange(0, nTrainPerBlock)
    validIndexProto = np.arange(nTrainPerBlock, samplesPerBlock)

    trainIndexBlocks = [(trainIndexProto + i * samplesPerBlock).astype(int) for i in range(nBlock)]
    validIndexBlocks = [(validIndexProto + i * samplesPerBlock).astype(int) for i in range(nBlock)]

    # np.concatenate([]) raises, so return empty index arrays when there are no full blocks.
    trainIndices = np.concatenate(trainIndexBlocks).astype(int) if trainIndexBlocks else np.empty(0, dtype=int)
    validationIndices = np.concatenate(validIndexBlocks).astype(int) if validIndexBlocks else np.empty(0, dtype=int)
    return trainIndexBlocks, validIndexBlocks, trainIndices, validationIndices

# NOTE(review): superseded earlier version of timeDomainDataMake, kept only as an inert string
# literal (it is never executed). Unlike the live version below, it took a window of
# `windowSamples` rows with no zero padding, so it emitted len(block) - windowSamples rows per
# block and placed the target at offset windowSamples // 2. Consider deleting it.
'''
def timeDomainDataMake(indexBlocks, windowSamples, dataSubsampled):
    xList = []
    yList = []
    for thisBlock in indexBlocks:
        thisData = dataSubsampled[thisBlock,:]
        for t in range(0, len(thisBlock) - windowSamples):
            thisX = thisData[t:t + windowSamples,0:-1].flatten()
            thisY = thisData[t + (windowSamples // 2),-1]
            xList.append(thisX)
            yList.append(thisY)

    xTimeDomain = np.stack(xList, axis = 0)
    yTimeDomain = np.expand_dims(np.array(yList), axis=1)
    return xTimeDomain, yTimeDomain
'''

def timeDomainDataMake(indexBlocks, halfWindow, dataSubsampled):
    """Build centred sliding-window features from contiguous index blocks.

    Each block is zero-padded by ``halfWindow`` rows at both ends, so windows
    never reach across a block boundary. For every sample t in a block, the
    ``2 * halfWindow + 1`` rows centred on t are taken; their input columns
    (all but the last column) are flattened row-major (time-major, then
    channel) into one feature row, and the last column at t is the target.

    Parameters
    ----------
    indexBlocks : iterable of int ndarray
        Row indices into ``dataSubsampled``, one array per contiguous block.
    halfWindow : int
        Number of context rows on each side of the centre sample.
    dataSubsampled : 2-D ndarray, shape (nSamples, nChannels)
        Inputs in columns ``0:-1``, target in the last column.

    Returns
    -------
    xTimeDomain : ndarray, shape (nRows, (2*halfWindow+1) * (nChannels-1))
    yTimeDomain : ndarray, shape (nRows, 1)
        ``nRows`` is the total number of indices across all blocks (0 if none).
    """
    indexBlocks = list(indexBlocks)
    _, nChannels = dataSubsampled.shape
    windowLength = 2 * halfWindow + 1
    nFeatures = windowLength * (nChannels - 1)
    # Same dtype the original vstack-with-float64-zeros produced.
    outDtype = np.result_type(dataSubsampled.dtype, np.float64)

    # Preallocate instead of collecting per-sample rows and stacking, which doubled peak memory.
    nRows = sum(len(thisBlock) for thisBlock in indexBlocks)
    xTimeDomain = np.empty((nRows, nFeatures), dtype=outDtype)
    yTimeDomain = np.empty((nRows, 1), dtype=outDtype)

    padding = np.zeros((halfWindow, nChannels))
    windowOffsets = np.arange(windowLength)
    row = 0
    for thisBlock in indexBlocks:
        nSamples = len(thisBlock)
        thisData = np.vstack([padding, dataSubsampled[thisBlock, :], padding])
        # Rows t .. t + 2*halfWindow of the padded block form the window centred on sample t.
        windowRows = np.arange(nSamples)[:, None] + windowOffsets[None, :]
        xTimeDomain[row:row + nSamples] = thisData[windowRows, 0:-1].reshape(nSamples, nFeatures)
        yTimeDomain[row:row + nSamples, 0] = thisData[halfWindow:halfWindow + nSamples, -1]
        row += nSamples

    return xTimeDomain, yTimeDomain

# INDICES FOR TRAIN/VALIDATION SPLIT

In [6]:
secondsInBlock = 5
# Derive the usable duration from the data (previously hard-coded as 1720) so the split can never
# index past the end of dataSubsampled. For this recording (220,751 samples at 128 Hz) both values
# yield the same 344 blocks.
totalSeconds = len(dataSubsampled) // subsampleFreq
trainFraction = .8

trainIndexBlocks, validIndexBlocks, trainIndices, validationIndices = timeseriesTrainValidSplit(
    secondsInBlock, totalSeconds, trainFraction, subsampleFreq)

# TIME DOMAIN DATA

In [42]:
halfWindowSeconds = .25
# Half-width of the centred context window, in samples at subsampleFreq Hz.
halfWindowSamples = int(halfWindowSeconds * subsampleFreq)

# Build windowed features for the train split, then the validation split, with the same window size.
(xTrainTimeDomain, yTrainTimeDomain), (xValidTimeDomain, yValidTimeDomain) = [
    timeDomainDataMake(blocks, halfWindowSamples, dataSubsampled)
    for blocks in (trainIndexBlocks, validIndexBlocks)
]

xTrainTimeDomain.shape, yTrainTimeDomain.shape, xValidTimeDomain.shape, yValidTimeDomain.shape

((176128, 5655), (176128, 1), (44032, 5655), (44032, 1))

In [44]:
# Persist the windowed time-domain arrays as an uncompressed .npz, one named entry per array.
arraySavePath = '/blue/gkalamangalam/jmark.ettinger/predictScalp/timeDomain.npz'
timeDomainArrays = dict(
    xTrainTimeDomain=xTrainTimeDomain,
    xValidTimeDomain=xValidTimeDomain,
    yTrainTimeDomain=yTrainTimeDomain,
    yValidTimeDomain=yValidTimeDomain,
)
np.savez(arraySavePath, **timeDomainArrays)

# STFT DATA

In [35]:
# APPLY SHORT TERM FOURIER TRANSFORM TO THE DATA AND CHECK PARAMETERS FOR INVERTIBILITY

secondsInSTFTWindow = .5
# subsampleFreq * .5 is the float 64.0. Segment lengths must be whole samples (window functions
# such as the one check_NOLA builds reject non-integer lengths), so convert explicitly.
nperseg = int(round(subsampleFreq * secondsInSTFTWindow))
noverlap = nperseg - 1  # hop of one sample: one STFT frame per input sample
windowType = ('tukey', .25)

# With stft's default boundary extension, the first segment is centred on the first input sample.
# With a one-sample hop, frame k is therefore centred on sample k, which is why later cells can index
# S's time axis with the same train/validation sample indices.
f, t, S = stft(dataSubsampled, fs=subsampleFreq, window=windowType, nperseg=nperseg, noverlap=noverlap, axis=0)

print('freq, ', 'time, ', 'stft shape')
print(f.shape, t.shape, S.shape)
print('inverse ok? ',check_NOLA(windowType, nperseg, noverlap))

freq,  time,  stft shape
(33,) (220753,) (33, 88, 220753)
inverse ok?  True


In [43]:
def complexToRTheta(z):
    """Turn a complex (samples, freqs, channels) array into real-valued features.

    For each channel c, in order, the output holds |z[:, :, c]| for every
    frequency followed by angle(z[:, :, c]) for every frequency.
    """
    perChannel = [
        np.concatenate([np.abs(z[:, :, c]), np.angle(z[:, :, c])], axis=1)
        for c in range(z.shape[2])
    ]
    return np.concatenate(perChannel, axis=1)


# Reorder S (freq, channel, frame) so samples come first; inputs are all channels but the last,
# the target is the last channel.
x_trainComplex = S[:, 0:-1, trainIndices].transpose([2, 0, 1])
y_trainComplex = S[:, -1, trainIndices].transpose()

x_validComplex = S[:, 0:-1, validationIndices].transpose([2, 0, 1])
y_validComplex = S[:, -1, validationIndices].transpose()

# MAKE REAL-VALUED TRAINING DATA BY CONVERTING STFT COMPLEX NUMBERS TO R,THETA
numCol = x_trainComplex.shape[2]
x_trainRTheta = complexToRTheta(x_trainComplex)
x_validRTheta = complexToRTheta(x_validComplex)

y_trainRTheta = np.concatenate([np.abs(y_trainComplex), np.angle(y_trainComplex)], axis=1)
y_validRTheta = np.concatenate([np.abs(y_validComplex), np.angle(y_validComplex)], axis=1)

nY = y_trainRTheta.shape[1]
x_trainRTheta.shape, x_validRTheta.shape, y_trainRTheta.shape, y_validRTheta.shape

((176128, 5742), (44032, 5742), (176128, 66), (44032, 66))

In [None]:
# PLOT THE STFT OF A TIME SERIES (MAGNITUDE ONLY)
# Reads the global f, t, S from the STFT cell; every time frame is drawn, so this can be slow.

index = -1  # -1 is the target
vmax = .2

fig, ax = plt.subplots()
ax.pcolormesh(t, f, np.abs(S[:, index, :]), shading='auto', cmap='hot', vmin=0, vmax=vmax)
ax.set_ylabel('Frequency [Hz]')
ax.set_xlabel('Time [sec]')
ax.set_title('Index: %s' % str(index))
plt.show()

In [45]:
# Persist the STFT magnitude/phase feature arrays as an uncompressed .npz.
arraySavePath = '/blue/gkalamangalam/jmark.ettinger/predictScalp/freqRTheta.npz'
rThetaArrays = dict(
    x_trainRTheta=x_trainRTheta,
    x_validRTheta=x_validRTheta,
    y_trainRTheta=y_trainRTheta,
    y_validRTheta=y_validRTheta,
)
np.savez(arraySavePath, **rThetaArrays)

# COMBINE TIME AND FREQUENCY DOMAIN DATA

In [48]:
# Join time-domain and STFT (R, theta) features column-wise. Rows line up because both were built
# from the same train/validation index blocks, iterated in the same order.
x_trainTimeFreq = np.concatenate((xTrainTimeDomain, x_trainRTheta), axis=1)
y_trainTimeFreq = np.concatenate((yTrainTimeDomain, y_trainRTheta), axis=1)
x_validTimeFreq = np.concatenate((xValidTimeDomain, x_validRTheta), axis=1)
y_validTimeFreq = np.concatenate((yValidTimeDomain, y_validRTheta), axis=1)

In [50]:
tuple(arr.shape for arr in (x_trainTimeFreq, y_trainTimeFreq, x_validTimeFreq, y_validTimeFreq))  # shapes of combined features/targets

((176128, 11397), (176128, 67), (44032, 11397), (44032, 67))

In [49]:
# Persist the combined time + frequency feature arrays as an uncompressed .npz.
arraySavePath = '/blue/gkalamangalam/jmark.ettinger/predictScalp/timeFreqRTheta.npz'
timeFreqArrays = dict(
    x_trainTimeFreq=x_trainTimeFreq,
    x_validTimeFreq=x_validTimeFreq,
    y_trainTimeFreq=y_trainTimeFreq,
    y_validTimeFreq=y_validTimeFreq,
)
np.savez(arraySavePath, **timeFreqArrays)

# SCRATCH

In [None]:
# stft parameter tests
# Uses its own names so that running this scratch cell cannot overwrite the real f, t, S and
# windowType computed above, which the plotting and feature cells read.

scratchWindowType = ('tukey', .25)

scratchRng = np.random.default_rng(0)  # seeded so the scratch output is reproducible
fakeData = scratchRng.random((10000, 5))

fScratch, tScratch, SScratch = stft(fakeData, fs=1000, window=scratchWindowType, nperseg=100,
                                    noverlap=0, axis=0, boundary=None)

fScratch.shape, tScratch.shape, SScratch.shape, fScratch, tScratch

In [None]:
# OLD VERSION FOR TIME DOMAIN
# Pointwise features (no context window): one row per sample, input channels only.
# Stored under distinct names so running this cell no longer overwrites the windowed
# xTrainTimeDomain / yTrainTimeDomain / xValidTimeDomain / yValidTimeDomain arrays that the
# saving and combining cells use.

xTrainTimeDomainPointwise = dataSubsampled[trainIndices, 0:-1]
yTrainTimeDomainPointwise = np.expand_dims(dataSubsampled[trainIndices, -1], axis=1)

xValidTimeDomainPointwise = dataSubsampled[validationIndices, 0:-1]
yValidTimeDomainPointwise = np.expand_dims(dataSubsampled[validationIndices, -1], axis=1)

xTrainTimeDomainPointwise.shape, yTrainTimeDomainPointwise.shape, xValidTimeDomainPointwise.shape, yValidTimeDomainPointwise.shape