In [1]:
%matplotlib notebook
import mne
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from scipy.signal import butter, cheby1, filtfilt

from scipy.signal import spectrogram, stft, istft, check_NOLA
import pickle

# IMPORT DATA AND PREPROCESS

In [2]:
# 31, 35, 38 and test folders for good data
# odd scalp on left, even on right

# PATIENT AND RECORDING MODE TO PROCESS
patient = 'UFSEEG031'
mode = 'Sleep'
targetScalpElectrodes = []  # placeholder, replaced once the channel names have been read

# ARTIFACT ELECTRODES PER PATIENT (TAKEN FROM THE WORD FILES)
artifactElectrodes = {
    'UFSEEG031': ['LTP7', 'LTP8', 'LAH11', 'LAH12', 'LPH10', 'LPH11', 'LPH12', 'LOF15', 'LOF16'],
}

#filepath = '/blue/gkalamangalam/ALLDATA/SEEG/%s/SEEG/EDF/TestClipSleep.edf' % patient
filepath = '/blue/gkalamangalam/ALLDATA/SEEG/%s/SEEG/EDF/TestClipSleep/TestClip%s.edf' % (patient, mode)

# Load the whole EDF recording into memory and remember its sampling rate as an integer
raw = mne.io.read_raw_edf(filepath, preload=True)
sfreq = int(raw.info['sfreq'])

# Scalp channels are selected purely by name: exactly two characters long and not 'PR'
targetScalpElectrodes = [name for name in raw.ch_names if len(name) == 2 and name != 'PR']
scalpElectrodes = {patient: targetScalpElectrodes}

# One blank line, then the scalp channel map, the Raw summary and its Info, each on its own line
print('', scalpElectrodes, raw, raw.info, sep='\n')

Extracting EDF parameters from /blue/gkalamangalam/ALLDATA/SEEG/UFSEEG031/SEEG/EDF/TestClipSleep/TestClipSleep.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 1766015  =      0.000 ...  1724.624 secs...

{'UFSEEG031': ['F7', 'F8', 'F3', 'F4', 'C3', 'C4', 'P7', 'P8', 'P3', 'P4']}
<RawEDF | TestClipSleep.edf, 148 x 1766016 (1724.6 s), ~1.95 GB, data loaded>
<Info | 7 non-empty values
 bads: []
 ch_names: LTP1, LTP2, LTP3, LTP4, LTP5, LTP6, LTP7, LTP8, LAM1, LAM2, ...
 chs: 148 EEG
 custom_ref_applied: False
 highpass: 0.0 Hz
 lowpass: 512.0 Hz
 meas_date: 2001-01-01 02:05:22 UTC
 nchan: 148
 projs: []
 sfreq: 1024.0 Hz
>


In [3]:
# DISCARD ALL CHANNELS EXCEPT GOOD SEEG CHANNELS AND THE SCALP TARGETS
#
# Depth (SEEG) channels are recognised by a channel name starting with 'L' or 'R';
# the ones listed as artifact electrodes for this patient are left out.
# Later cells split the data columns by position (0:nDepth = depth, nDepth: = scalp),
# so the channel order of `raw` must be exactly the order of `channels`.

channels = [i for i in raw.ch_names if i not in artifactElectrodes[patient] and i[0] in {'L', 'R'}] + targetScalpElectrodes

# reorder_channels drops every channel that is not listed AND puts the remaining ones in the
# listed order, whereas pick_channels does not guarantee the requested order in every MNE version.
raw.reorder_channels(channels)
assert list(raw.ch_names) == channels, 'channel order differs from depth-then-scalp layout expected downstream'

nScalp = len(targetScalpElectrodes)   # number of scalp (target) channels
nDepth = len(channels) - nScalp       # number of depth (predictor) channels

In [4]:
# PREPROCESSING ON A (SAMPLES x CHANNELS) NUMPY ARRAY, IN THIS ORDER:
# STANDARDIZE EVERY CHANNEL -> KEEP THE FIRST 75% OF SAMPLES -> LOWPASS FILTER -> SUBSAMPLE

filterWindow = 64                    # lowpass cutoff, in the same units as sfreq (hertz)
subsampleFreq = 2 * filterWindow     # FINAL FREQUENCY IN HERTZ AFTER SUBSAMPLING
filterOrder = 5
goodTimes = [] # set this below

# One column per channel; the 'time' column added by to_data_frame is not needed
df = raw.to_data_frame().drop(columns=['time'])

# SHOULD WE SCALE HERE OR AFTER FILTER AND SUBSAMPLE?
# (as written, the scaler is fitted on the full-rate, unfiltered, untrimmed recording)
scaler = StandardScaler()
data = scaler.fit_transform(df.to_numpy())

# Keep only the leading three quarters of the samples
goodTimes = range(int(len(data) * .75))
data = data[goodTimes, :]

# Zero-phase (forward-backward) Butterworth lowpass along the time axis
b, a = butter(filterOrder, filterWindow, btype='lowpass', fs=sfreq)
data = filtfilt(b, a, data, axis=0)

# Keep every (sfreq // subsampleFreq)-th sample
dataSubsampled = data[::sfreq // subsampleFreq, :]

# SEE QUESTION ABOVE...
#scaler = StandardScaler()
#dataSubsampled = scaler.fit_transform(dataSubsampled)

pd.DataFrame(dataSubsampled, columns=df.columns)

Unnamed: 0,LTP1,LTP2,LTP3,LTP4,LTP5,LTP6,LAM1,LAM2,LAM3,LAM4,...,F7,F8,F3,F4,C3,C4,P7,P8,P3,P4
0,-0.667612,-0.549802,-0.330502,-0.630513,-0.410391,0.427998,-1.039937,-1.478466,-1.065475,-0.822086,...,1.470390,-0.025125,1.390247,0.125900,1.174324,0.110210,0.801922,0.483272,0.688382,0.014404
1,-0.738084,-0.652242,-0.419590,-0.709839,-0.516951,0.324686,-1.225429,-1.600105,-1.230353,-0.989760,...,1.058312,-0.334411,0.700480,-0.204952,0.502855,-0.421465,0.383678,0.059099,0.300875,-0.530017
2,-0.699389,-0.571008,-0.339067,-0.611417,-0.399328,0.375402,-1.236690,-1.564062,-1.189561,-0.916646,...,1.353298,0.152961,1.309031,0.351976,1.052420,0.129003,0.763821,0.433957,0.837160,-0.051506
3,-0.780354,-0.592758,-0.324217,-0.600832,-0.394643,0.179410,-1.170003,-1.549847,-1.217005,-0.944660,...,0.822709,-0.328529,0.589093,-0.180698,0.267894,-0.448504,0.238265,-0.164788,0.363592,-0.537998
4,-0.795788,-0.529620,-0.255734,-0.530419,-0.319918,0.073355,-1.070847,-1.472654,-1.120555,-0.801174,...,1.034713,0.029360,0.954408,0.305535,0.779454,0.147909,0.522795,0.216301,0.770463,-0.077564
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
165559,-0.583109,-0.815102,-0.634595,-0.452057,0.254602,-1.260952,-0.286024,-0.420071,0.938973,0.252986,...,-0.991949,-1.031649,-1.388580,-0.616400,-0.453551,0.229730,-0.204140,-0.331099,0.076937,0.057757
165560,-0.510411,-0.754459,-0.584003,-0.383257,0.308688,-1.234871,-0.314457,-0.457116,0.909562,0.243240,...,-1.039219,-1.042088,-1.346330,-0.495239,-0.467845,0.417203,-0.261942,-0.282041,0.007809,0.203710
165561,-0.529892,-0.761683,-0.553680,-0.359697,0.306900,-1.256808,-0.417568,-0.501714,0.852580,0.183756,...,-1.259204,-1.335702,-1.515425,-0.817458,-0.860090,0.035316,-0.623140,-0.494389,-0.331645,-0.073944
165562,-0.468182,-0.684009,-0.440543,-0.250115,0.376794,-1.165449,-0.385552,-0.467224,0.916943,0.233544,...,-1.048659,-1.152250,-1.102170,-0.431696,-0.540641,0.242423,-0.460304,-0.371930,-0.145180,0.040611


In [5]:
# PARTITION THE TIME SERIES INTO CONTIGUOUS TRAIN AND VALIDATION BLOCKS.
# IF THE VALIDATION SET IS INSTEAD DISTRIBUTED RANDOMLY, IMPLICIT OVERFITTING OCCURS DUE TO TEMPORAL DEPENDENCY.
# EACH BLOCK (LENGTH SPECIFIED IN SECONDS) IS DIVIDED INTO A TRAIN PART (SPECIFIED BY FRACTION, COMES FIRST) AND A VALIDATION PART (COMES LAST).

def timeseriesTrainValidSplit(secondsInBlock, totalSeconds, trainFraction, subsampleFreq):
    """Build sample indices for a block-wise train/validation split of a time series.

    The series is cut into consecutive blocks of ``secondsInBlock`` seconds. Within every
    block the leading ``trainFraction`` of the samples is assigned to training and the
    remaining samples to validation. Samples beyond the last complete block are not used.

    Parameters
    ----------
    secondsInBlock : number
        Length of one block in seconds.
    totalSeconds : number
        Length of the whole series in seconds.
    trainFraction : float
        Share of each block that goes to training.
    subsampleFreq : number
        Sampling frequency of the series in hertz.

    Returns
    -------
    trainIndexBlocks, validIndexBlocks : list of numpy.ndarray
        One integer index array per block.
    trainIndices, validationIndices : numpy.ndarray
        The per-block arrays concatenated in block order.

    Raises
    ------
    ValueError
        If not even one complete block fits into ``totalSeconds``.
    """
    nBlock = int(totalSeconds / secondsInBlock)
    if nBlock < 1:
        raise ValueError('totalSeconds (%s) is shorter than one block of %s seconds'
                         % (totalSeconds, secondsInBlock))

    samplesPerBlock = int(round(subsampleFreq * secondsInBlock))
    # Use ONE integer split point for both halves. Passing a non-integer boundary to
    # np.arange(..., dtype=int) made the train part end on the same sample the
    # validation part starts with, i.e. the two sets overlapped.
    splitPoint = int(round(samplesPerBlock * trainFraction))
    trainIndexProto = np.arange(0, splitPoint, dtype=int)
    validIndexProto = np.arange(splitPoint, samplesPerBlock, dtype=int)

    blockOffsets = [i * samplesPerBlock for i in range(nBlock)]
    trainIndexBlocks = [trainIndexProto + offset for offset in blockOffsets]
    validIndexBlocks = [validIndexProto + offset for offset in blockOffsets]

    trainIndices = np.concatenate(trainIndexBlocks).astype(int)
    validationIndices = np.concatenate(validIndexBlocks).astype(int)
    return trainIndexBlocks, validIndexBlocks, trainIndices, validationIndices

def timeDomainDataMake(indexBlocks, halfWindow, dataSubsampled, nDepthChannels=None):
    """Build sliding-window predictors and centre-sample targets for every indexed sample.

    Each block of rows is extracted from ``dataSubsampled`` and zero-padded with
    ``halfWindow`` rows at both ends, so windows never reach into a neighbouring block.
    For every sample of the block one X row and one Y row are produced:

    * X: the window of ``2 * halfWindow + 1`` consecutive rows centred on the sample,
      restricted to the first ``nDepthChannels`` columns and flattened row by row
      (all channels of the first time step, then all channels of the next, ...).
    * Y: the values of the remaining columns at the centre sample.

    Parameters
    ----------
    indexBlocks : sequence of 1-D integer arrays
        Row indices into ``dataSubsampled``, one array per contiguous block.
    halfWindow : int
        Number of samples taken on each side of the centre sample.
    dataSubsampled : numpy.ndarray
        2-D array, rows are time samples and columns are channels.
    nDepthChannels : int, optional
        Number of leading columns used as predictors. When omitted, the notebook-level
        variable ``nDepth`` is used, which is what the function originally relied on.

    Returns
    -------
    xTimeDomain : numpy.ndarray, shape (nSamples, (2 * halfWindow + 1) * nDepthChannels)
    yTimeDomain : numpy.ndarray, shape (nSamples, nChannels - nDepthChannels)

    Raises
    ------
    ValueError
        If ``indexBlocks`` contains no samples at all.
    """
    if nDepthChannels is None:
        nDepthChannels = nDepth  # fall back to the notebook global

    _, nChannels = dataSubsampled.shape
    windowLength = (2 * halfWindow) + 1
    edgePadding = np.zeros((halfWindow, nChannels))  # identical for every block

    xList = []
    yList = []
    for thisBlock in indexBlocks:
        thisData = np.vstack([edgePadding, dataSubsampled[thisBlock, :], edgePadding])

        # After padding, row t + halfWindow of thisData is sample t of the block,
        # so the window thisData[t : t + windowLength] is centred on that sample.
        for t in range(len(thisBlock)):
            xList.append(thisData[t:t + windowLength, 0:nDepthChannels].flatten())
            yList.append(thisData[t + halfWindow, nDepthChannels:])

    if not xList:
        raise ValueError('indexBlocks contains no samples')

    xTimeDomain = np.stack(xList, axis=0)
    yTimeDomain = np.stack(yList, axis=0)
    return xTimeDomain, yTimeDomain

# INDICES FOR TRAIN/VALIDATION SPLIT

In [6]:
secondsInBlock = 5     # length of every contiguous block, in seconds
trainFraction = 0.8    # leading share of each block used for training
totalSeconds = int(len(dataSubsampled) / subsampleFreq)   # whole seconds covered by the subsampled data

# A split is only computed when part of each block is held out for validation
if trainFraction < 1.0:
    (trainIndexBlocks,
     validIndexBlocks,
     trainIndices,
     validationIndices) = timeseriesTrainValidSplit(secondsInBlock=secondsInBlock,
                                                    totalSeconds=totalSeconds,
                                                    trainFraction=trainFraction,
                                                    subsampleFreq=subsampleFreq)

# TIME DOMAIN DATA

In [7]:
halfWindowSeconds = .25   # context taken on each side of the predicted sample, in seconds

halfWindowSamples = int(halfWindowSeconds * subsampleFreq)
hasValidation = trainFraction < 1.0

if not hasValidation:
    # No hold-out: the whole recording forms a single training block
    trainIndexBlocks = [np.arange(dataSubsampled.shape[0])]

xTrainTimeDomain, yTrainTimeDomain = timeDomainDataMake(trainIndexBlocks, halfWindowSamples, dataSubsampled)

if hasValidation:
    xValidTimeDomain, yValidTimeDomain = timeDomainDataMake(validIndexBlocks, halfWindowSamples, dataSubsampled)
else:
    # Empty placeholders so the names exist either way
    xValidTimeDomain, yValidTimeDomain = np.array([]), np.array([])

print(xTrainTimeDomain.shape, yTrainTimeDomain.shape, xValidTimeDomain.shape, yValidTimeDomain.shape)

(132096, 5655) (132096, 10) (33024, 5655) (33024, 10)


In [8]:
# WRITE THE TIME-DOMAIN TRAIN/VALIDATION ARRAYS TO ONE .npz ARCHIVE
# NOTE: targetScalpElectrodes is a list, so its printed form (brackets, quotes, commas)
# becomes part of the file name.
arraySavePath = '/blue/gkalamangalam/jmark.ettinger/predictScalp/timeDomain_%s_%s_%s.npz' % (patient, targetScalpElectrodes, mode)

timeDomainArrays = {
    'xTrainTimeDomain': xTrainTimeDomain,
    'xValidTimeDomain': xValidTimeDomain,
    'yTrainTimeDomain': yTrainTimeDomain,
    'yValidTimeDomain': yValidTimeDomain,
}
np.savez(arraySavePath, **timeDomainArrays)

# STFT DATA

In [None]:
# APPLY THE SHORT-TIME FOURIER TRANSFORM TO THE DATA AND CHECK THE PARAMETERS FOR INVERTIBILITY

secondsInSTFTWindow = .5
# SciPy documents nperseg and noverlap as integers; subsampleFreq * .5 is a float,
# so convert explicitly instead of relying on an implicit truncation inside SciPy.
nperseg = int(round(subsampleFreq * secondsInSTFTWindow))
noverlap = nperseg - 1   # consecutive windows are shifted by a single sample
windowType = ('tukey', .25)

# The transform runs along axis 0 (time) of dataSubsampled
f, t, S = stft(dataSubsampled, fs=subsampleFreq, window=windowType, nperseg=nperseg, noverlap=noverlap, axis=0)

print('freq, ', 'time, ', 'stft shape')
print(f.shape, t.shape, S.shape)
# check_NOLA reports whether the window/overlap combination satisfies the
# nonzero-overlap-add constraint
print('inverse ok? ',check_NOLA(windowType, nperseg, noverlap))

In [None]:
# SPLIT THE STFT INTO PREDICTORS (DEPTH CHANNELS) AND TARGETS (SCALP CHANNELS)
#
# S is indexed as [frequency, channel, time]. The channels are ordered depth first, scalp
# last, so the split point is nDepth - the same split the time-domain features use.
# (The previous 0:-1 / -1 slicing treated only the LAST channel as target and therefore fed
# the other scalp targets in as predictors. With a single scalp target both versions produce
# the same x_*RTheta / y_*RTheta matrices.)
# After the transpose every *_Complex array is indexed as [time, frequency, channel].

def complexToRTheta(spectra):
    """Turn complex STFT values indexed [time, frequency, channel] into a real 2-D matrix.

    For each channel in turn, the magnitudes of all frequencies are followed by the phase
    angles of all frequencies; the per-channel groups are placed side by side.
    """
    _, _, nChan = spectra.shape
    return np.hstack([np.hstack([np.abs(spectra[:, :, i]),
                                 np.angle(spectra[:, :, i])]) for i in range(nChan)])

x_trainComplex = S[:, 0:nDepth, trainIndices].transpose([2,0,1])
y_trainComplex = S[:, nDepth:, trainIndices].transpose([2,0,1])

x_validComplex = S[:, 0:nDepth, validationIndices].transpose([2,0,1])
y_validComplex = S[:, nDepth:, validationIndices].transpose([2,0,1])

# MAKE REAL-VALUED TRAINING DATA BY CONVERTING STFT COMPLEX NUMBERS TO R,THETA
_,_,numCol = x_trainComplex.shape
x_trainRTheta = complexToRTheta(x_trainComplex)
x_validRTheta = complexToRTheta(x_validComplex)

y_trainRTheta = complexToRTheta(y_trainComplex)
y_validRTheta = complexToRTheta(y_validComplex)

_, nY = y_trainRTheta.shape
x_trainRTheta.shape, x_validRTheta.shape, y_trainRTheta.shape, y_validRTheta.shape

In [None]:
# SHOW THE STFT MAGNITUDE OF ONE CHANNEL AS A TIME-FREQUENCY IMAGE

index = -1 # -1 is the target (the last channel)
vmax = .2  # upper end of the colour scale; the lower end is fixed at 0

fig, ax = plt.subplots()
magnitude = np.abs(S[:,index,:])
ax.pcolormesh(t, f, magnitude, shading='auto', cmap='hot', vmin=0, vmax=vmax)
ax.set_ylabel('Frequency [Hz]')
ax.set_xlabel('Time [sec]')
ax.set_title('Index: %s' % str(index))
plt.show()

In [None]:
# WRITE THE FREQUENCY-DOMAIN (R, THETA) TRAIN/VALIDATION ARRAYS TO ONE .npz ARCHIVE
# The file name uses targetScalpElectrodes, like the time-domain archive; the singular
# name `targetScalpElectrode` used here before is not defined anywhere in the notebook
# and raised a NameError.
arraySavePath = '/blue/gkalamangalam/jmark.ettinger/predictScalp/freqRTheta_%s_%s_%s.npz' % (patient, targetScalpElectrodes, mode)
np.savez(arraySavePath, 
         x_trainRTheta=x_trainRTheta, 
         x_validRTheta=x_validRTheta, 
         y_trainRTheta=y_trainRTheta, 
         y_validRTheta=y_validRTheta)

# COMBINE TIME AND FREQUENCY DOMAIN DATA

In [None]:
# Place the time-domain columns first and the STFT (r, theta) columns after them.
# Both halves must have one row per sample, in the same order.

# training set
x_trainTimeFreq = np.hstack((xTrainTimeDomain, x_trainRTheta))
y_trainTimeFreq = np.hstack((yTrainTimeDomain, y_trainRTheta))

# validation set
x_validTimeFreq = np.hstack((xValidTimeDomain, x_validRTheta))
y_validTimeFreq = np.hstack((yValidTimeDomain, y_validRTheta))

In [None]:
# Shapes of the combined arrays: x train, y train, x valid, y valid
tuple(arr.shape for arr in (x_trainTimeFreq, y_trainTimeFreq, x_validTimeFreq, y_validTimeFreq))

In [None]:
# WRITE THE COMBINED TIME + FREQUENCY TRAIN/VALIDATION ARRAYS TO ONE .npz ARCHIVE
# The file name uses targetScalpElectrodes, like the time-domain archive; the singular
# name `targetScalpElectrode` used here before is not defined anywhere in the notebook
# and raised a NameError.
arraySavePath = '/blue/gkalamangalam/jmark.ettinger/predictScalp/timeFreqRTheta_%s_%s_%s.npz' % (patient, targetScalpElectrodes, mode)
np.savez(arraySavePath, 
         x_trainTimeFreq=x_trainTimeFreq, 
         x_validTimeFreq=x_validTimeFreq, 
         y_trainTimeFreq=y_trainTimeFreq, 
         y_validTimeFreq=y_validTimeFreq)

# SCRATCH

In [None]:
# SCRATCH: try out stft parameters on random data
# CAUTION: this rebinds windowType, f, t and S, i.e. it replaces the STFT of the real
# recording held in those names.

windowType = ('tukey', .25)

# 10000 time samples of 5 random channels
fakeData = np.random.random((10000, 5))

stftTestOptions = dict(fs=1000, window=windowType, nperseg=100, noverlap=0, axis=0, boundary=None)
f, t, S = stft(fakeData, **stftTestOptions)

f.shape, t.shape, S.shape, f, t

In [None]:
# OLD VERSION FOR TIME DOMAIN
# Single-sample features: every column except the last is a predictor, the last column is
# the target (kept as an (n, 1) column).
# CAUTION: this overwrites the windowed xTrain/yTrain/xValid/yValid TimeDomain arrays.

trainRows = dataSubsampled[trainIndices]
xTrainTimeDomain = trainRows[:, :-1]
yTrainTimeDomain = trainRows[:, -1][:, np.newaxis]

validRows = dataSubsampled[validationIndices]
xValidTimeDomain = validRows[:, :-1]
yValidTimeDomain = validRows[:, -1][:, np.newaxis]

tuple(arr.shape for arr in (xTrainTimeDomain, yTrainTimeDomain, xValidTimeDomain, yValidTimeDomain))