# Survival analysis with LFP spectral features

### Stops are events, pops are censored

In [1]:
import numpy as np
import pandas as pd
import physutils
import dbio
import os
from __future__ import division
import matplotlib.pyplot as plt
%matplotlib inline

# fix the NumPy RNG seed so any random draws in this notebook are repeatable
np.random.seed(123456)
# use matplotlib's bundled 'ggplot' style sheet for all figures
plt.style.use('ggplot')

## Load Data

In [2]:
# (patient, dataset) pair selecting which recording to load and plot
dtup = (12, 1)

In [3]:
# load data: resolve the database path (expanduser only matters if the path
# starts with "~"), then fetch the raw LFP; dtup is unpacked into the
# remaining positional arguments of the fetch call
dbname = os.path.expanduser('data/bart.hdf5')
lfpraw = dbio.fetch_all_such_LFP(
    dbname,
    *dtup
)

## Preprocess Data

In [4]:
# When there is more than one channel, remove the global mean across all
# channels at each time; with a single column that step is skipped.
lfpraw = lfpraw.demean_global() if lfpraw.shape[1] > 1 else lfpraw

# then set each channel to mean 0
lfp = lfpraw.demean()

In [5]:
# frequency bands passed to bandlimit, in the order they are requested
filters = [
    'delta',
    'theta',
    'alpha',
    'beta',
    'gamma',
]
lfp = lfp.bandlimit(filters)

In [6]:
# alternative preprocessing (disabled): decimate to 40 Hz, get instantaneous power, censor, and z-score each channel
# lfp = lfp.decimate(5).instpwr().censor().zscore()

In [7]:
# decimate to 10 Hz in two stages (factor 5, then factor 4) ...
lfp = lfp.decimate(5)
lfp = lfp.decimate(4)
# ... then convert to instantaneous power
lfp = lfp.instpwr()

In [8]:
# get events
evt = dbio.fetch(dbname, 'events', *dtup)
cols = ['banked', 'popped', 'start inflating', 'trial_type']

# Keep only non-control trials when the table has an 'is_control' column.
# Take an explicit copy so the column assignments below write to evt_tmp
# itself rather than to a possible view of evt (this is what triggers
# pandas' SettingWithCopyWarning).
if 'is_control' in evt.columns:
    evt_tmp = evt.query('is_control == False')[cols].copy()
else:
    evt_tmp = evt.loc[:, cols].copy()

# add a binary column (1 = voluntary stop, i.e. no pop time recorded)
evt_tmp['event'] = np.isnan(evt_tmp['popped']).astype('int')

# add a column for stop time (regardless of cause): the row-wise mean skips
# NaN, so it returns whichever of 'banked' / 'popped' is present.
# Computed from evt_tmp (not evt) so it covers exactly the retained trials.
evt_tmp['stop'] = evt_tmp[['banked', 'popped']].mean(axis=1)

# drop unneeded columns
evt_tmp = evt_tmp.drop(['banked', 'popped'], axis=1)
evt_tmp = evt_tmp.rename(columns={'start inflating': 'start'})

evt_tmp.head()

Unnamed: 0_level_0,start,trial_type,event,stop
trial,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,17.999,1,1,26.455
1,32.879,1,0,40.688
2,43.654,2,1,47.34
3,50.531,1,1,59.901
4,64.754,1,0,72.825


## Remove unneeded data

- take only non-control trials
- get only time points between trial start and the end of the trial (voluntary stop or pop)

In [18]:
# Cut the LFP into one chunk per trial (rows whose index label lies between
# the trial's start and stop) and tag each row with an event flag, the trial
# type, and the time elapsed since the first row of the chunk.
chunks = []
for trial, row in evt_tmp.iterrows():
    start, stop = row['start'], row['stop']
    this_chunk = lfp.loc[start:stop].copy()
    if this_chunk.empty:
        continue  # no LFP rows fall inside this trial's window

    this_chunk['event'] = 0  # no event until the last bin
    # look up the column position once and reuse it for the positional write
    event_idx = this_chunk.columns.get_loc('event')
    this_chunk.iloc[-1, event_idx] = int(row['event'])  # set last bin correctly
    this_chunk['ttype'] = int(row['trial_type'])
    this_chunk['rel_time'] = this_chunk.index - this_chunk.index[0]

    chunks.append(this_chunk)

# concatenate chunks, then pull the non-power columns out into their own objects
meanpwr = pd.concat(chunks)
event = meanpwr['event']
time_in_trial = meanpwr['rel_time']
# one indicator column per trial type, named 'ttype<value>'
ttype = pd.get_dummies(meanpwr['ttype'])
ttype.columns = ['ttype' + str(idx) for idx in ttype.columns]
meanpwr = meanpwr.drop(['event', 'ttype', 'rel_time'], axis=1)

In [19]:
# z-score every column: subtract the column mean, divide by the column std
meanpwr = meanpwr.apply(lambda col: col.sub(col.mean()).div(col.std()))

In [20]:
# Build every pairwise product of power columns (lower triangle, diagonal
# included, so squares are covered too). Each term is named
# "<band_i>.<chan_i>.<band_j>.<chan_j>" and standardized before being stored.
int_terms = []
for i in range(len(meanpwr.columns)):
    icol = meanpwr.iloc[:, i]
    iband, ichan = icol.name.split('.')

    for j in range(i + 1):
        if i != j:
            # cross term between column i and an earlier column j
            jcol = meanpwr.iloc[:, j]
            jband, jchan = jcol.name.split('.')
            col = icol * jcol
            label = "{}.{}.{}.{}".format(iband, ichan, jband, jchan)
        else:
            # diagonal: the column squared, named with its own band/chan twice
            col = icol ** 2
            band, chan = col.name.split('.')
            label = "{}.{}.{}.{}".format(band, chan, band, chan)

        col.name = label
        # z-score the new term; the name carries over to the result
        col = (col - col.mean()) / col.std()
        int_terms.append(col)

In [21]:
# Join the event flag, time in trial, trial-type indicators, standardized
# power, and the interaction/squared terms column-wise, keeping only index
# labels shared by every piece (inner join).
trainset = pd.concat(
    [event, time_in_trial, ttype, meanpwr] + int_terms,
    axis=1,
    join='inner'
)
# trainset = trainset.dropna()  # can't send glmnet any row with a NaN
trainset.head()

Unnamed: 0_level_0,event,rel_time,ttype1,ttype2,ttype3,delta.17,theta.17,alpha.17,beta.17,gamma.17,...,alpha.17.alpha.17,beta.17.delta.17,beta.17.theta.17,beta.17.alpha.17,beta.17.beta.17,gamma.17.delta.17,gamma.17.theta.17,gamma.17.alpha.17,gamma.17.beta.17,gamma.17.gamma.17
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
18.0,0,0.0,1,0,0,0.269038,-0.040771,-0.462809,-0.347926,0.236346,...,-0.38021,-0.405307,-0.464791,-0.403717,-0.423757,-0.235965,-0.465955,-0.546282,-0.53051,-0.473759
18.1,0,0.1,1,0,0,-0.095846,0.354851,-0.332959,-0.16946,0.031133,...,-0.430222,-0.331344,-0.502264,-0.454406,-0.468286,-0.283916,-0.453913,-0.495039,-0.490109,-0.501309
18.2,0,0.2,1,0,0,-0.488489,0.905684,-0.042367,0.14897,-0.505327,...,-0.483009,-0.391278,-0.403913,-0.484808,-0.471433,-0.103962,-0.726813,-0.478596,-0.52686,-0.373628
18.3,0,0.3,1,0,0,0.689783,1.380112,0.252423,0.438889,-0.674309,...,-0.453039,-0.138441,-0.166524,-0.428062,-0.389242,-0.6168,-1.002182,-0.577761,-0.642709,-0.273577
18.4,0,0.4,1,0,0,3.05334,1.535812,0.331905,0.572697,0.078537,...,-0.430561,0.835113,-0.028447,-0.389635,-0.323965,-0.109037,-0.390117,-0.476185,-0.463726,-0.498699


In [22]:
# (rows, columns) of the design matrix; the parenthesized single-argument
# form prints the same on Python 2 and Python 3
print(trainset.shape)

(4110, 25)


In [23]:
# total number of rows flagged as an event
trainset['event'].sum()

54