In [57]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
import seaborn as sns
import os, re
import glob
import cPickle as pkl
import scipy.signal as signal
import math
import pandas as pd
%matplotlib inline
%load_ext autoreload
%autoreload 2 
sns.set_style("white")
import warnings
warnings.filterwarnings('ignore')

from braintv_ephys_dev.workstation.danield import generalephys as ephys
import braintv_ephys_dev.workstation.danield.continuous_traces as traces
import braintv_ephys_dev.workstation.danield.utils as utils

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Firing-Rate Helper Functions (filtering and peak counting)

In [58]:
def count_all_peaks(file_path, sigma_thresh=5):
    """Count threshold-crossing peaks on every non-skipped channel of a raw recording file.

    Side effect: resets ``traces.skip_channels`` to ``traces.npix_p3_reference_channels``
    (with no extra bad channels) before anything else.

    The file is opened as an int16 memmap, a chunk is taken with
    ``traces.get_chunk(..., 3., 11., n_channels, sampling_rate=30000)``, every channel is
    high-pass filtered by ``get_norm_chunk`` and then ``count_channel_peaks`` is run on each
    channel that is not listed in ``traces.skip_channels``.

    Returns
    -------
    peaks : (n_channels - len(traces.skip_channels), 1) float array of peak counts, one row per
        counted channel, in channel order.
    mean_offset : per-channel means of the unfiltered chunk as returned by ``get_norm_chunk``
        (skipped channels are NOT removed from it).
    """
    extra_bad_channels = []
    traces.skip_channels = np.append(traces.npix_p3_reference_channels, extra_bad_channels)

    raw_samples = np.memmap(file_path, dtype=np.int16, mode='r')

    recording_dir = file_path.rsplit(os.path.sep, 1)[0]
    channel_count = traces.get_channel_count(recording_dir, from_channel_map=False)
    raw_chunk = traces.get_chunk(raw_samples, 3., 11., channel_count, sampling_rate = 30000)

    filtered_chunk, mean_offset = get_norm_chunk(raw_chunk)

    total_channels = filtered_chunk.shape[0]
    peaks = np.zeros((total_channels - len(traces.skip_channels), 1))

    # `kept` is the output row; it only advances on channels that are actually counted,
    # so skipped channels leave no gap in `peaks`
    kept = 0
    for ch in range(total_channels):
        if ch in traces.skip_channels:
            continue
        peaks[kept] = count_channel_peaks(filtered_chunk[ch, :], sigma_thresh)[0]
        kept += 1

    return peaks, mean_offset

def count_channel_peaks(signal, sigma_thresh=5, release_fraction=0.5):
    """Count threshold-crossing peaks in a single 1-D trace.

    A peak starts at the first sample that reaches
    ``threshold = mean(signal) + sigma_thresh * std(signal)``. It ends at the first later sample
    that falls below ``release_fraction * threshold``. This two-level (hysteresis) rule keeps
    noise hovering around the threshold from being counted as several peaks. The largest sample
    inside each excursion is reported as that peak.

    Parameters
    ----------
    signal : 1-D sequence of numbers. (Inside this function the name shadows the module-level
        ``scipy.signal`` import.)
    sigma_thresh : number of standard deviations above the mean needed to start a peak.
    release_fraction : fraction of the threshold the trace must drop below to end a peak
        (default 0.5, the previously hard-coded value).

    Returns
    -------
    num_peaks : int, number of peaks found
    indices : list of int, sample index of each peak's maximum
    peak_values : list, maximum sample value of each peak

    A peak still in progress when the trace ends is included in all three outputs, so
    ``num_peaks == len(indices) == len(peak_values)`` always holds.
    """
    threshold = sigma_thresh * np.std(signal) + np.mean(signal)
    release_level = threshold * release_fraction

    indices = []
    peak_values = []
    in_peak = False
    peak_index = None
    max_point = None

    for index, point in enumerate(signal):
        if in_peak:
            if point < release_level:
                indices.append(peak_index)
                peak_values.append(max_point)
                in_peak = False
            elif point > max_point:
                max_point = point
                peak_index = index
        elif point >= threshold:
            in_peak = True
            peak_index = index
            max_point = point

    if in_peak:
        # the trace ended mid-peak: record it so the count and the lists stay consistent
        indices.append(peak_index)
        peak_values.append(max_point)

    return len(indices), indices, peak_values

def get_norm_chunk(chunk, cutoff=10, fs=30000, order=5):
    """High-pass filter every row (channel) of a 2-D chunk and record each row's raw mean.

    Despite the name, nothing is normalised. Each row is passed through
    ``butter_highpass_filter`` (cutoff ``cutoff`` Hz, sampling rate ``fs`` Hz, filter order
    ``order``) to strip slow drift. The mean of the *unfiltered* row is returned separately.
    The defaults reproduce the previously hard-coded 10 Hz / 30000 Hz / order-5 settings.

    Parameters
    ----------
    chunk : 2-D array indexed as ``chunk[channel, sample]``.
    cutoff, fs, order : forwarded to ``butter_highpass_filter``.

    Returns
    -------
    chunk_detrended : float array, same shape as ``chunk``, the filtered rows.
    mean_offset : (n_rows, 1) float array, the mean of each raw row.
    """
    chunk_detrended = np.zeros(chunk.shape)
    mean_offset = np.zeros((chunk.shape[0], 1))

    for channel in range(chunk.shape[0]):
        chunk_detrended[channel, :] = butter_highpass_filter(chunk[channel, :], cutoff, fs, order=order)
        mean_offset[channel, :] = np.mean(chunk[channel, :])

    return chunk_detrended, mean_offset

def butter_highpass(cutoff, fs, order=5):
    """Design a digital high-pass IIR filter and return its ``(b, a)`` coefficients.

    NOTE: despite the name, this designs a *Bessel* filter (``scipy.signal.bessel``), not a
    Butterworth one. The name is kept so existing callers keep working. Switching to
    ``signal.butter`` would change every filtered trace and every feature derived from them.

    Parameters
    ----------
    cutoff : cutoff frequency in Hz.
    fs : sampling rate in Hz.
    order : filter order.

    Returns
    -------
    b, a : numerator and denominator coefficient arrays (e.g. for ``signal.filtfilt``).
    """
    nyq = 0.5 * fs
    # scipy expects digital cutoffs as a fraction of the Nyquist frequency
    normal_cutoff = cutoff / nyq
    b, a = signal.bessel(order, normal_cutoff, btype='highpass', analog=False)
    return b, a

def butter_highpass_filter(data, cutoff, fs, order=5):
    """Zero-phase high-pass filter ``data`` using the coefficients from ``butter_highpass``.

    ``signal.filtfilt`` applies the filter forward and then backward, so the result has no
    phase shift and has the same length as ``data``.
    """
    coeffs = butter_highpass(cutoff, fs, order=order)
    return signal.filtfilt(coeffs[0], coeffs[1], data)

def running_mean(x, N):
    """Moving average (window ``N``) of the first column of ``x`` along axis 0.

    Parameters
    ----------
    x : 2-D array; only column 0 is smoothed.
    N : window length in samples, must be >= 1.

    Returns
    -------
    (len(x), 1) array. Output row ``j`` is the mean of ``x[j - N + 1 + (N - 1) // 2 :
    j + 1 + (N - 1) // 2, 0]``. For odd ``N`` the window is centred on ``j``; for even ``N`` it
    has one more sample before ``j`` than after. Near the ends, samples outside ``x`` count as
    zero (``np.convolve`` full mode), which pulls edge values toward 0.

    Raises
    ------
    ValueError : if ``N < 1``.
    """
    if N < 1:
        raise ValueError('running_mean window N must be >= 1, got %r' % (N,))
    # explicit integer start works on Python 2 and 3, and N == 1 no longer yields an empty slice
    start = (N - 1) // 2
    smoothed = np.convolve(x[:, 0], np.ones((N,)) / N)
    return np.expand_dims(smoothed[start:start + len(x)], 1)

## Choose Data Files

In [59]:
## Data directory for this mouse; each recording lives in its own subdirectory.
path = os.path.join(r'\\SD1', 'SD1', 'DanD', 'M310016', 'localization')

## glob returns files in an OS-dependent order, but spike-band and LFP files are later paired
## by list position, so sort both lists to make that pairing deterministic.
filenames_spikeband = sorted(glob.glob(os.path.join(path, '*', '*0_0.dat')))
filenames_lfp = sorted(glob.glob(os.path.join(path, '*', '*1_0.dat')))

phase = 3
bad_channels = []
## channels skipped in the per-channel processing below: traces' reference-channel list
## (presumably the Neuropixels phase-3 reference channels, per its name) plus any bad channels
traces.skip_channels = np.append(traces.npix_p3_reference_channels, bad_channels)

In [60]:
## Channels that the per-channel processing below will skip.
print(traces.skip_channels)

[  36.   75.  112.  151.  188.  227.  264.  303.  340.  379.]


## Get Labels

In [61]:
## labels.pkl maps each recording to {structure name: [first_channel, last_channel]}
## (inclusive channel range, as used when labelling channels below).
labels_path = glob.glob(os.path.join(path, 'labels.pkl'))[0]
## NOTE(review): unpickling can execute arbitrary code; only load labels files from a trusted source.
## Default (text) open mode kept as before, since it is unknown how the pickle was written.
with open(labels_path) as labels_file:
    labels = pkl.load(labels_file)

## Shorten the keys from full paths to just the recording directory name, e.g.
## M310016_2017-06-15_08-10-38_1 instead of the whole filename. Build a new dict rather than
## popping/inserting while iterating, which can skip or revisit keys when the dict resizes.
labels = dict((key.split('/')[-1], value) for key, value in labels.items())

In [62]:
## Display all the annotations for a quick reference when I need to look at something
for i, dirname in enumerate(filenames_lfp):
    dirkey = dirname.split(os.path.sep)[-2]
    print dirkey
    for structure in labels[dirkey].iterkeys():
        print '             ', structure, labels[dirkey][structure][0], labels[dirkey][structure][1]

M310016_2017-06-15_08-10-38_1
              nucleus accumbens 0 20
              caudate putamen 21 171
              primary motor cortex 202 350
              above 351 383
              white matter 172 201
M310016_2017-06-15_08-43-17_2
              caudate putamen 72 185
              primary motor cortex 215 345
              above 346 383
              white matter 186 214
              globus pallidus 0 71
M310016_2017-06-15_09-11-15_3
              secondary motor cortex 184 331
              lateral ventricle 44 160
              above 332 383
              white matter 161 183
              fimbria 0 43
M310016_2017-06-15_09-44-15_4
              white matter 244 267
              LDVL 69 104
              fimbria 105 146
              primary somatosensory cortex 268 339
              lateral ventricle 147 243
              above 340 383
              VPM 0 37
              Po 38 68
M310016_2017-06-15_09-55-55_4b
              white matter 244 267
              LDVL 69 104


## GETTING DATA (NOT CHUNKED)

Features are computed for each channel individually; averaging channels by structure (chunking) comes later.

In [43]:
target_all_ = np.empty((0,1))
rms_spike_all_ = np.empty((0,1))
rms_lfp_all_ = np.empty((0,1))
gamma_all_ = np.empty((0,1))
alpha_all_ = np.empty((0,1))
beta_all_ = np.empty((0,1))
delta_all_ = np.empty((0,1))
theta_all_ = np.empty((0,1))
peaks_all_ = np.empty((0,1))
mean_offset_all_ = np.empty((0,1))
row_all_ = np.empty((0,1))

for i in range(len(filenames_lfp)):
    spike_file = filenames_spikeband[i]
    lfp_file = filenames_lfp[i]
    dirkey = spike_file.split(os.path.sep)[-2]
    print dirkey, i
    
    rms_spike_ = np.expand_dims(traces.get_probe_freq(spike_file, frequency_range=[300, 30000]), 1)
    rms_lfp_ = np.expand_dims(traces.get_probe_freq(lfp_file, frequency_range=[0, 300]), 1)
    gamma_ = np.expand_dims(traces.get_probe_freq(lfp_file, frequency_range=[30, 40]), 1)
    alpha_ = np.expand_dims(traces.get_probe_freq(lfp_file, frequency_range=[8, 12]), 1)
    beta_ = np.expand_dims(traces.get_probe_freq(lfp_file, frequency_range=[12, 30]), 1)
    delta_ = np.expand_dims(traces.get_probe_freq(lfp_file, frequency_range=[0, 4]), 1)
    theta_ = np.expand_dims(traces.get_probe_freq(lfp_file, frequency_range=[4, 8]), 1)
    peaks_, mean_offset_ = count_all_peaks(spike_file, 4)
    peaks_ = running_mean(peaks_, 10)
    row_ = np.expand_dims([pix // 2 for pix in range(384)], 1)
    
    target_ = ["" for j in range(384)]
    
    for structure in labels[dirkey].iterkeys():
        for index in np.arange(labels[dirkey][structure][0], labels[dirkey][structure][1] + 1):
            target_[index] = structure
    
    for skip in sorted(traces.skip_channels, reverse=True):
        target_.pop(int(skip))
        row_ = np.delete(row_, int(skip), 0)
        mean_offset_ = np.delete(mean_offset_, int(skip), 0)
    
    ## 82nd channel is always weird and can't add it to bad channels because then a different channel becomes weird, and so on
    rms_spike_[82,:] = (rms_spike_[81,:] + rms_spike_[83,:]) / 2.0
    rms_lfp_[82,:] = (rms_lfp_[81,:] + rms_lfp_[83,:]) / 2.0
    gamma_[82,:] = (gamma_[81,:] + gamma_[83,:]) / 2.0
    alpha_[82,:] = (alpha_[81,:] + alpha_[83,:]) / 2.0
    beta_[82,:] = (beta_[81,:] + beta_[83,:]) / 2.0
    delta_[82,:] = (delta_[81,:] + delta_[83,:]) / 2.0
    theta_[82,:] = (theta_[81,:] + theta_[83,:]) / 2.0
    
    rms_spike_all_ = np.append(rms_spike_all_, rms_spike_, axis=0)
    rms_lfp_all_ = np.append(rms_lfp_all_, rms_lfp_, axis=0)
    gamma_all_ = np.append(gamma_all_, gamma_, axis=0)
    alpha_all_ = np.append(alpha_all_, alpha_, axis=0)
    beta_all_ = np.append(beta_all_, beta_, axis=0)
    delta_all_ = np.append(delta_all_, delta_, axis=0)
    theta_all_ = np.append(theta_all_, theta_, axis=0)
    peaks_all_ = np.append(peaks_all_, peaks_, axis=0)
    mean_offset_all_ = np.append(mean_offset_all_, mean_offset_, axis=0)
    row_all_ = np.append(row_all_, row_, axis=0)
    target_all_ = np.append(target_all_, np.expand_dims(target_, 1), axis=0)

all_bands_data = np.concatenate((rms_spike_all_, rms_lfp_all_, gamma_all_, alpha_all_, 
                                 beta_all_, delta_all_, theta_all_, peaks_all_, 
                                 mean_offset_all_, row_all_), axis=1)

## fixes the double labeling of white matter... kinda hacky though (where there were two dict entries for white matter)
for index in np.where(target_all_ == '')[0]:
    target_all_[index] = 'white matter'

M310016_2017-06-15_08-10-38_1 0
M310016_2017-06-15_08-43-17_2 1
M310016_2017-06-15_09-11-15_3 2
M310016_2017-06-15_09-44-15_4 3
M310016_2017-06-15_09-55-55_4b 4
M310016_2017-06-15_10-25-04_5 5
M310016_2017-06-15_10-52-16_6 6
M310016_2017-06-15_11-21-18_7 7
M310016_2017-06-15_11-51-37_8 8


In [44]:
## Sanity check: (channels kept across all recordings, number of feature columns)
print(all_bands_data.shape)

(3366L, 10L)


## CHUNKING IT

Average consecutive channels that share the same structure label into a single row.

In [47]:
## Average each run of consecutive channels that share a structure label into one row.
## A run also ends where a new recording starts. That point is detected by the 'row' feature
## (last column of all_bands_data) decreasing, since it is rebuilt from pix // 2 for every
## recording. This way a label that ends one recording and begins the next is not averaged
## across two separate insertions.
ROW_COL = all_bands_data.shape[1] - 1

labels_col = target_all_[:, 0]
rows_col = all_bands_data[:, ROW_COL]
n_channels_total = len(labels_col)

run_starts = [k for k in range(n_channels_total)
              if k == 0 or labels_col[k] != labels_col[k - 1] or rows_col[k] < rows_col[k - 1]]
run_ends = run_starts[1:] + [n_channels_total]

all_bands_chunked_data = np.array(
    [np.mean(all_bands_data[start:end], axis=0) for start, end in zip(run_starts, run_ends)]
).reshape(-1, all_bands_data.shape[1])
targets_chunked_data = np.array([labels_col[start] for start in run_starts]).reshape(-1, 1)

## SAVING AS CSV

In [49]:
## Column names, in the order the features were concatenated into all_bands_data.
columns = ['rms_spike', 'rms_lfp', 'gamma', 'alpha', 'beta', 'delta', 'theta', 'peaks', 'mean_offset', 'row']

## One frame per level of aggregation; the structure label goes in a final 'targets' column.
chunked_banded_data = pd.DataFrame(all_bands_chunked_data, columns=columns).assign(
    targets=np.ravel(targets_chunked_data))
banded_data = pd.DataFrame(all_bands_data, columns=columns).assign(
    targets=np.ravel(target_all_))

In [53]:
## Write both tables without the pandas index column.
for frame, csv_name in [(chunked_banded_data, 'chunked_banded_data.csv'),
                        (banded_data, 'not_chunked_banded_data.csv')]:
    frame.to_csv(csv_name, index=False)

In [54]:
## Summary table: one row per contiguous run of channels sharing a structure label (features averaged).
chunked_banded_data

Unnamed: 0,rms_spike,rms_lfp,gamma,alpha,beta,delta,theta,peaks,mean_offset,row,targets
0,0.101283,9.261163,21.87188,112.888103,50.774044,50.729492,88.434784,23.671429,-222.140674,4.761905,nucleus accumbens
1,0.09513,6.771591,11.278967,96.148983,38.25924,42.515525,75.862852,26.408844,-141.787697,47.782313,caudate putamen
2,0.089123,5.321826,5.639884,86.472363,31.429134,37.815071,68.72194,25.234483,-112.707936,92.965517,white matter
3,0.07413,5.291393,4.128761,95.439975,28.738385,42.346282,78.121547,27.545517,-92.783981,137.648276,primary motor cortex
4,0.038933,0.206789,0.198206,1.43869,0.650116,0.663794,1.15093,9.678125,-57.989478,183.0625,above
5,0.0,9.517997,11.340382,156.73127,58.010387,65.771295,121.746655,22.749296,-89.943596,17.492958,globus pallidus
6,0.0,10.066504,9.401824,169.777814,59.28108,77.930793,135.888588,17.487387,-87.073609,64.216216,caudate putamen
7,0.0,9.858323,8.1203,170.975109,58.007301,84.156709,141.054339,16.989286,-61.933477,99.964286,white matter
8,0.0,8.316975,7.218368,152.577432,45.779511,59.673571,122.182341,31.574803,-9.500173,139.692913,primary motor cortex
9,0.0,0.374178,0.51253,3.012869,1.458361,1.60014,2.504281,8.513514,34.913635,181.810811,above


In [55]:
## The per-channel table has thousands of rows (one per kept channel); show only the first few.
banded_data.head(10)

Unnamed: 0,rms_spike,rms_lfp,gamma,alpha,beta,delta,theta,peaks,mean_offset,row,targets
0,0.106462,9.859172,23.575648,119.878515,53.506656,53.313478,93.857327,12.6,-641.597256,0.0,nucleus accumbens
1,0.105721,8.984972,23.073113,101.585668,48.968343,44.683833,78.609154,14.3,-547.935525,0.0,nucleus accumbens
2,0.106754,10.190824,24.035342,125.632371,55.169640,56.694359,98.693777,16.9,-231.857207,1.0,nucleus accumbens
3,0.095615,9.390558,23.133328,110.594206,51.042836,49.411583,86.354161,19.1,-255.021822,1.0,nucleus accumbens
4,0.103743,8.453562,21.612603,95.285914,45.893032,42.715032,73.900220,21.0,-198.052282,2.0,nucleus accumbens
5,0.116564,9.199566,22.766959,106.761756,50.059377,47.904082,83.312473,24.9,-211.575669,2.0,nucleus accumbens
6,0.104068,9.347468,22.192652,111.788921,51.457651,50.819164,87.486795,24.5,-200.024766,3.0,nucleus accumbens
7,0.105680,8.892688,21.637862,104.135391,48.218347,46.429322,81.148240,25.1,-210.064767,3.0,nucleus accumbens
8,0.096895,9.192317,21.342597,112.093293,50.612174,50.408794,87.734038,25.0,-157.479307,4.0,nucleus accumbens
9,0.105884,9.312022,22.058888,112.850188,51.132626,50.623345,88.413182,25.4,-192.606698,4.0,nucleus accumbens
