In [None]:
import numba
import time
import numpy as np
import pandas as pd

import cupy as cp
from skcuda.fft import fft, ifft, Plan

from sklearn.cluster import KMeans
from sklearn import mixture
from mne.time_frequency import tfr_array_morlet as morlet

from ethotype import config
import tensorflow as tf

In [None]:
def median_filter(df, 
                  window=config.median_filter_window):
    '''
    Smooth ``df`` with a centered rolling median.

    Parameters
    ----------
    df : pandas.DataFrame or pandas.Series
        Any object exposing ``.rolling``.
    window : int
        Size of the rolling window; defaults to
        ``config.median_filter_window``.

    Returns
    -------
    Same type as ``df``
        The rolling median. Rows near either edge, where the centered
        window is incomplete, are NaN.
    '''
    # The previous implementation used .mean(), i.e. a moving average, which
    # does not reject single-frame outliers the way a median filter does.
    return df.rolling(center=True, window=window).median()

def reject_cross_frame_noise(df, 
                             n_stds=config.n_stds):
    '''
    Drop rows whose change from the previous row is implausibly large.

    For every column, a row is flagged when the absolute difference between
    its value and the value one row earlier exceeds ``n_stds`` times that
    column's standard deviation. A row flagged in any column is removed from
    the whole frame.

    Parameters
    ----------
    df : pandas.DataFrame
        Input frame; it is not modified.
    n_stds : float
        Threshold in units of the per-column standard deviation; defaults to
        ``config.n_stds``.

    Returns
    -------
    df : pandas.DataFrame
        ``df`` without the flagged rows.
    included_frames : numpy.ndarray
        Index labels of the rows that were kept.
    '''
    # TODO: investigate whether this has any effect on downstream results.
    noisy = np.zeros(len(df), dtype=bool)
    for column in df:
        feature = df[column]
        threshold = abs(np.std(feature) * n_stds)
        jump = (feature - feature.shift(1)).abs()
        # Comparisons against NaN (first row, missing data) are False, so
        # those rows are never flagged by this column.
        noisy |= (jump > threshold).values

    # Filter by label so that every row sharing a flagged index label goes.
    keep = ~df.index.isin(df.index[noisy])
    df = df[keep]
    return df, np.asarray(df.index.tolist())

def get_low_var_cutoff(fv, 
                       frame_index):
    '''
    Keep only the columns of ``fv`` that fall in the high-variance cluster.

    The variance of every column of the stacked feature matrix is computed
    and split into two groups with KMeans. Labels are arranged so that ``1``
    marks the group with the larger mean variance; columns labelled ``1``
    are kept.

    Parameters
    ----------
    fv : sequence of array-like
        Rows to stack with ``np.vstack``.
    frame_index : array-like
        One entry per column of the stacked matrix.

    Returns
    -------
    fv : numpy.ndarray
        Stacked matrix restricted to the columns labelled ``1``.
    labels : numpy.ndarray
        Per-column label, ``1`` for kept and ``0`` for dropped.
    frames : numpy.ndarray
        Entries of ``frame_index`` for the kept columns.
    '''
    frame_index = np.asarray(frame_index)
    stacked = np.vstack(fv)
    variances = np.var(stacked, axis=0)
    samples = variances.reshape(-1, 1)

    clusterer = KMeans(2, random_state=0)
    labels = clusterer.fit(samples).predict(samples)

    mean_of_ones = np.mean(variances[np.where(labels == 1)])
    mean_of_zeros = np.mean(variances[np.where(labels == 0)])
    # KMeans numbers its clusters arbitrarily; swap so 1 is high variance.
    if not (mean_of_ones > mean_of_zeros):
        labels = np.where(labels == 0, 1, 0)

    selected = np.where(labels == 1)
    fv = np.vstack(stacked[:, selected])
    frames = frame_index[selected]
    return fv, labels, frames

def get_wavelet_features(data, 
                         sampling_rate = config.sampling_rate, 
                         n_octaves=config.n_octaves, 
                         range_cycles = config.range_cycles):
    '''
    Compute complex-Morlet wavelet power for every column of ``data``.

    Wavelets are built at ``n_octaves`` geometrically spaced frequencies
    between 1 and ``sampling_rate / 2``, with the number of cycles spaced
    geometrically across ``range_cycles``. The convolution is done by
    multiplication in the frequency domain.

    Parameters
    ----------
    data : pandas.DataFrame
        One column per channel.
    sampling_rate : int
        Samples per second; also used as the wavelet length in samples.
    n_octaves : int
        Number of frequencies (and cycle counts) to evaluate.
    range_cycles : sequence of two numbers
        Cycle count at the lowest and highest frequency.

    Returns
    -------
    list of list of numpy.ndarray
        ``result[channel][frequency]`` is the squared magnitude of the
        convolution, trimmed by half the wavelet length on each side.
    '''
    #metadata
    channels = [column for column in data]

    #wavelet parameters
    range_frex = [1, sampling_rate/2] #nyquist limit is half of srate
    frex = np.geomspace(range_frex[0], range_frex[1], num=n_octaves)
    cycles = np.geomspace(range_cycles[0], range_cycles[1], num=n_octaves)

    # Named wavelet_time so the imported `time` module is not shadowed.
    # NOTE(review): this spans 4 s with `sampling_rate` samples, so the step
    # is 4/(sampling_rate-1) rather than 1/sampling_rate -- confirm intended.
    wavelet_time = np.linspace(-2, 2, sampling_rate)
    half_wave = int(len(wavelet_time)/2)

    #FFT parameters
    nkern = len(wavelet_time)
    nconv = nkern + len(data)

    # The wavelet spectra depend only on frequency and cycle count, so build
    # them once instead of once per channel.
    kernels = []
    for f, c in zip(frex, cycles):
        sine_wave = np.exp(1j*2*np.pi*f*wavelet_time)
        s = c / (2*np.pi*f)
        gaus_win = np.exp((-wavelet_time**2)/(2*s**2))
        cmwX = np.fft.fft(sine_wave * gaus_win, nconv)
        # np.max on a complex array orders by real part, not magnitude;
        # normalise by the peak magnitude so every kernel has unit peak gain.
        kernels.append(cmwX / np.max(np.abs(cmwX)))

    #init output (not called `tf`, which is the tensorflow module)
    power = []
    for chan in channels:
        dataX = np.fft.fft(np.asarray(data[chan]), nconv)
        cf = []
        for cmwX in kernels:
            #convolve and trim
            m = np.fft.ifft(dataX * cmwX)
            cf.append(abs(m[half_wave:-half_wave])**2)
        power.append(cf)
    return power


def get_window_features(data, 
                        win_scales=config.win_scales):
    '''
    Compute centered rolling means of every column at several window sizes.

    Parameters
    ----------
    data : pandas.DataFrame
        One column per channel.
    win_scales : iterable of int
        Rolling window sizes; defaults to ``config.win_scales``.

    Returns
    -------
    list of list of pandas.Series
        ``result[channel][scale]`` is the rolling mean of that channel with
        that window size.
    '''
    features = []
    for chan in data:
        column = data[chan]
        features.append([column.rolling(center=True, window=win_scale).mean()
                         for win_scale in win_scales])
    # The previous version returned `tf` (the tensorflow module) instead of
    # the list it had just built.
    return features

def sum_norm(fv):
    '''
    Divide every row of ``fv`` by that row's sum.

    Parameters
    ----------
    fv : numpy.ndarray
        Two-dimensional array.

    Returns
    -------
    numpy.ndarray
        Array of the same shape in which each row has been divided by the
        sum of its own entries.
    '''
    row_totals = fv.sum(axis=1)
    # Add a trailing axis so each total lines up with its own row.
    return fv / row_totals[:, np.newaxis]

@numba.njit(fastmath=True)
def morletConjFT(w, 
                 omega0):
    '''
    Evaluate ``pi**(-1/4) * exp(-0.5 * (w - omega0)**2)``.

    Parameters
    ----------
    w : float or numpy.ndarray
        Value(s) at which to evaluate the expression.
    omega0 : float
        Centre of the Gaussian.

    Returns
    -------
    float or numpy.ndarray
        Same shape as ``w``.
    '''
    #eq. C2
    prefactor = np.pi**(-1/4)
    offset = w - omega0
    return prefactor*np.exp(-.5*offset**2)


def berman2014_wavelet(x, 
                       f, 
                       omega0, 
                       dt):
    '''
    Morlet wavelet amplitudes of a signal at the requested frequencies.

    The signal is padded to even length, zero-padded by half its length on
    both sides, filtered in the frequency domain with ``morletConjFT`` at one
    scale per frequency, and cropped back to the original samples. The FFTs
    run through TensorFlow.

    Parameters
    ----------
    x : array-like
        Signal; flattened to one dimension.
    f : array-like
        Frequencies to evaluate.
    omega0 : float
        Morlet wavelet parameter passed to ``morletConjFT``.
    dt : float
        Sampling interval.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(f), len(x))``; row ``i`` holds the amplitude
        at ``f[i]`` for every sample.
    '''
    x = np.ravel(x)
    f = np.asarray(f, dtype=float)
    n_samples = len(x)
    L = len(f)

    amp = np.zeros((L, n_samples)) #make amplitude container
    if n_samples % 2 == 1: #make the signal divisible by 2
        x = np.append(x, 0)
    M = len(x)
    half = M // 2

    x = np.concatenate([np.zeros(half), x, np.zeros(half)]) #pad the signal for convolution
    N = len(x)

    scales = (omega0 + np.sqrt(2 + omega0**2)) / (4*np.pi*f) #eq. C3
    # Angular frequencies ordered from most negative to most positive.
    Omegavals = 2*np.pi*np.arange(-(N // 2), N // 2) / (N*dt)

    # Positions of the original samples inside the padded signal. Integer
    # arithmetic replaces the previous float bounds, which only produced
    # these indices through truncation to an unsigned dtype.
    idx = np.arange(half, half + n_samples)

    # Scale-independent part of eq. C5. The whole difference is squared; the
    # previous code squared only the square root.
    correction = (np.pi**-.25)*np.exp(.25*(omega0 - np.sqrt(omega0**2 + 2))**2)

    with tf.Graph().as_default():
        # The context manager closes the session even if a step raises.
        with tf.Session() as sess:
            x_tensor = tf.cast(x, dtype='complex64')
            xHat = tf.fft(x_tensor)

            for i in range(L):
                m = morletConjFT(-Omegavals*scales[i], omega0)
                # xHat is in standard FFT order (zero frequency first) while
                # m follows the centered Omegavals; reorder m so each filter
                # coefficient meets its own frequency bin.
                m = np.fft.ifftshift(m)

                m_tensor = tf.cast(m, dtype='complex64')
                sig_tensor = tf.multiply(m_tensor, xHat)
                # sqrt(scale) factor as in the numpy/cupy reference variants
                # (q = ifft(m*xHat)*sqrt(scales[i])).
                q = sess.run(tf.ifft(sig_tensor))*np.sqrt(scales[i])

                amp[i, :] = np.abs(q[idx])*correction/np.sqrt(2*scales[i]) #eq. C5
    return amp

def compute(df, rej_cross_frame_noise=config.rej_cross_frame_noise, 
            low_var_cutoff=config.low_var_cutoff, 
            features=config.feature_type,window=config.median_filter_window):
    '''
    Run preprocessing, feature extraction and normalisation on ``df``.

    Steps: smooth with ``median_filter`` and fill the resulting gaps,
    optionally drop noisy rows, compute the requested feature type,
    optionally keep only high-variance columns, then normalise each row of
    the feature matrix with ``sum_norm``.

    Parameters
    ----------
    df : pandas.DataFrame
        One column per channel.
    rej_cross_frame_noise : bool
        If true, apply ``reject_cross_frame_noise``.
    low_var_cutoff : bool
        If true, apply ``get_low_var_cutoff``.
    features : str
        One of ``'sliding_window'``, ``'morlet_custom'``,
        ``'morlet_standard'`` or ``'berman2014_wavelet'``.
    window : int
        Window size for the smoothing step.

    Returns
    -------
    frames : numpy.ndarray
        Index labels of the input.
    p_frames : numpy.ndarray
        Index labels left after the optional noise rejection.
    hv_frames : numpy.ndarray
        Labels left after the optional variance cutoff.
    hv_labels : numpy.ndarray
        Per-column labels from the variance cutoff, or zeros if it is off.
    fv : numpy.ndarray
        Row-normalised feature matrix.

    Raises
    ------
    ValueError
        If ``features`` is not one of the supported names.
    '''
    frames = np.asarray(df.index.tolist())

    #standard preprocessing
    df = median_filter(df, window=window)
    df = df.bfill()
    df = df.ffill()
    print('done with preporcessing')

    #optional preprocessing
    if rej_cross_frame_noise:
        print('rejecting crossframe noise')
        df, p_frames = reject_cross_frame_noise(df)
    else:
        p_frames = frames

    #wavelet computation
    if features == 'sliding_window':
        print('calculating sliding_window')
        fv = get_window_features(df)
    elif features == 'morlet_custom':
        print('calculating custom morlet wavelet')
        fv = get_wavelet_features(df)
    elif features == 'morlet_standard':
        print('calculating default morlet wavelet')
        # NOTE(review): get_wavelet_features_mne is not defined in the
        # visible part of this notebook -- confirm it exists elsewhere.
        fv = get_wavelet_features_mne(df)
    elif features == 'berman2014_wavelet':
        print('calculating berman2014 wavelet')
        f = np.linspace(config.min_freq, config.sampling_rate/2, num=config.n_octaves)
        amps = []
        for chan in df:
            print(chan)
            # .values replaces Series.as_matrix, which newer pandas removed.
            x = df[chan].values
            amps.append(berman2014_wavelet(x, f, omega0=5, dt=1/config.sampling_rate))
        fv = np.vstack(amps)
    else:
        # Fail here with a clear message; previously this only printed and
        # then crashed later on the unassigned feature matrix.
        raise ValueError('unknown feature \'{}\''.format(features))

    #optional postprocessing
    if low_var_cutoff:
        fv, hv_labels, hv_frames = get_low_var_cutoff(fv, p_frames)
    else:
        fv = np.vstack(fv)
        hv_frames = p_frames
        # One label per column; shape[1] also works for a single-row matrix.
        hv_labels = np.zeros(fv.shape[1])

    #standard postprocessing
    fv = sum_norm(fv)
    return frames, p_frames, hv_frames, hv_labels, fv

In [None]:
import os

#general
home = os.getcwd()
proj = 'fromNadja' # IMPORTANT: project name; selects the data/ and tmp/ sub-directories below
input_dir = ''.join([home, '/data/', proj, '/'])
input_ext = '.npz' # IMPORTANT: extension of the input files
output_dir = ''.join([home, '/tmp/', proj, '/'])
output_ext = '.npz'
# exist_ok avoids the race between checking for the directory and creating it.
os.makedirs(output_dir, exist_ok=True)
fv_str = '_feature_vector_' #feature vector of processed frames
lb_str = '_labels_' #frame labels of processed frames
fr_str = '_frames_' #all frames
ix_str = '_idx_' #processed frames
ix2_str = '_sub_idx_' #frames which survive pre-processing (low noise) but not post-processing (low variance) [if used]


#wavelet transformation
# Supported values: 'berman2014_wavelet', 'morlet_standard', 'morlet_custom',
# 'sliding_window'. Suggested: 'berman2014_wavelet'.
feature_type = 'berman2014_wavelet'
input_file_list = ''.join([input_dir, '*', input_ext])
posture_channels = ['bone1', 'bone2',
       'bone3', 'bone4', 'bone5', 'bone6', 'bone7', 'bone8', 'bone9', 'bone10', 'bone11']
movement_channels = ['']

min_freq = 1. #Hz; minimum frequency is set to 1/s. Use a Fourier transform to check whether a longer time span is needed.
sampling_rate = 140 #Hz ~2 * max. freq (max. freq == Nyquist limit)
n_octaves = 25
n_features = len(posture_channels)*n_octaves
range_cycles = [3, 15]
n_stds = 5
rej_cross_frame_noise = False
low_var_cutoff = False
median_filter_window = 3
#rolling window
win_scales = [1,11,31,61,101,151,211,281,361,451,551]

#state-space embedding
n_jobs = -1 #all available threads = -1
tsne_alg = 'parametric_tsne' #['bh_tsne', 'fit_sne', 'parametric_tsne']
overwrite_model = False
cosine_dist_batch_size = 50
resample_size = 3000
perplexity = 200
tsne_dims = 2
tsne_batch_size = 30
tsne_training_epochs = 100
model_name = ''.join([output_dir, proj, '_', str(perplexity), '_', feature_type, '_', tsne_alg, '_model.h5'])
feature_vectors_file_list = ''.join([output_dir, '*', fv_str, feature_type, output_ext])


#segmentation
max_k = 40
ws_thresh = 0.0005
ws_bins = 100
ws_sigma = 2.0
ws_kernel = 13
ws_min_dist = 0
# Labels start at 1 when the low-variance cutoff is enabled, otherwise at 0.
label_base = 1 if low_var_cutoff == True else 0
labels_file_list = ''.join([output_dir, '*', lb_str, feature_type, output_ext])
frames_file_list = ''.join([output_dir, '*', fr_str, feature_type, output_ext])
index_file_list = ''.join([output_dir, '*', ix_str, feature_type, output_ext])
sub_index_file_list = ''.join([output_dir, '*', ix2_str, feature_type, output_ext])