In [1]:
import numpy as np
import h5py

import sys
from joblib import Parallel,delayed
import time

import re
import os

import cv2

from scipy.optimize import curve_fit

In [2]:
# NOTE: the parameters defined below are duplicated in Step 6 - when you change them here, change them there as well.

In [3]:
# Directory holding the input .h5 files; file names are appended directly to
# this string, so it must end with a path separator.
input_data_path = '../Step_2/'

In [4]:
# Tunable parameters of the binning / smoothing / ridge-tracing pipeline.
fs          = 25   # kHz; multiplied with binning_x to get the x bin width in data units
gauss_blurr = 0.1  # ms; sigma of the Gaussian smoothing along x
upper_thr   = 0.1  # Probability (not referenced in the cells shown here)
lower_thr   = 0.05 # Probability (not referenced in the cells shown here)

binning_y   = 1    # sec
binning_x   = 0.1  # ms
smooth_y    = 45   # Smoothing over time: sigma of the Gaussian along y, rescaled by binning_y*4

first_x_ms  = 25   # Save the first 25 ms

thr1        = 15   # Hysteresis upper threshold: a ridge is only seeded where the 0-255 map is >= thr1
thr2        = 3    # Hysteresis lower threshold: tracing stops once the map drops below thr2
sigma       = 2    # Starting width (in x bins) of the Gaussian fit used while tracing

min_ridge_length = 200 # Minimal ridge length; only ridges with strictly more points are kept

In [5]:
# Collect the HDF5 files to process.
# Only the top level of input_data_path is scanned: files are later opened as
# input_data_path+filename, so names found in sub-directories could never be
# resolved. Names must end in '.h5' (a substring test would also accept e.g.
# 'x.h5.bak'), and the list is sorted so that the processing order is
# reproducible across runs and file systems.
filenames = []
for (_,_,files) in os.walk(input_data_path):
    filenames.extend(sorted(name for name in files if name.endswith('.h5')))
    break  # first os.walk entry is the top-level directory; do not recurse

In [6]:
def fit(x,y,mu,sigma):
    """Fit a Gaussian a*exp(-(x-b)^2/(2c^2)) to (x, y) and return its centre and width.

    Parameters
    ----------
    x : sequence of positions; x[0] and x[-1] are used as clamping bounds for
        the fitted centre, so x is expected to be in ascending order.
    y : values observed at ``x``.
    mu, sigma : initial guesses for centre and width; they are also the
        fallback result when the fit fails.

    Returns
    -------
    (centre, width)
        ``centre`` is the fitted centre clamped to [x[0], x[-1]] and truncated
        with ``int``; ``width`` is the absolute fitted width capped at
        ``3/binning_x`` (module-level ``binning_x``).
        If fitting or the conversion raises, ``(mu, sigma)`` is returned unchanged.
    """
    def f(x,a,b,c):
        return a * np.exp(-(x-b)**2/(2*c*c))

    try:
        param,_ = curve_fit(f,x,y,p0=[1,mu,sigma],maxfev=1000)
        return int(min(max(param[1],x[0]),x[-1])),min(abs(param[2]),3/binning_x)
    except Exception:
        # Best-effort: any fitting failure (no convergence, too few points,
        # non-finite result, ...) keeps the previous estimate. Catching
        # Exception instead of using a bare except lets KeyboardInterrupt and
        # SystemExit propagate so a long run can still be aborted.
        return mu,sigma

In [7]:
def argmax2d(array):
    """Return the (row, column) index of the first maximum of a 2-D array."""
    _, n_cols = array.shape
    # argmax without an axis works on the row-major flattened array
    flat_index = np.argmax(array)
    return divmod(flat_index, n_cols)

In [8]:
def delete_line(x,y,array,thr2):
    """Zero the run of entries above thr2 that surrounds column y in row x.

    Row ``x`` is swept to the right starting at column ``y`` itself and then to
    the left starting at ``y-1``; each sweep sets entries to 0 until it meets
    the first entry that is not strictly greater than ``thr2`` or leaves the
    array. ``array`` is modified in place and also returned.
    """
    n_cols = array.shape[1]

    # Sweep right, beginning on the seed column
    for col in range(y, n_cols):
        if not (array[x,col] > thr2):
            break
        array[x,col] = 0

    # Sweep left, beginning one column before the seed
    for col in range(y - 1, -1, -1):
        if not (array[x,col] > thr2):
            break
        array[x,col] = 0

    return array

In [9]:
def smoothing_data(filename,save_images=False):
    """Bin, smooth and ridge-trace the spike data of one input HDF5 file.

    Reads ``input_data_path+filename`` and writes a file called ``filename``
    into the current working directory. For every ``Circuit_<0..14>/<pattern>/
    <segment>`` group of the input the function

    1. bins the rows of ``Spikes`` into a (y, x, electrode) histogram with 4
       electrode channels,
    2. smooths the histogram with a Gaussian along y (with a correction near
       both ends) and then along x,
    3. scales the result to 0..255 and keeps only the first ``first_x_ms`` ms
       along x,
    4. traces ridges per electrode with a hysteresis scheme (seed >= ``thr1``,
       continue while >= ``thr2``), following the ridge row by row with a
       Gaussian fit,
    5. stores every ridge longer than ``min_ridge_length`` as ``Ridge_<i>``
       (rows of ``[y_bin, x_bin, electrode]``) and copies ``Artifacts``,
       ``Spikes``, ``Start_ID`` and ``Stop_ID`` through unchanged.

    Parameters
    ----------
    filename : str
        Name of the input file inside ``input_data_path``; also the output name.
    save_images : bool, optional
        When true, the smoothed 0..255 map is additionally stored as a uint8
        dataset ``Filtered`` in each segment group.

    Returns
    -------
    None. Progress (circuit index, then file name and elapsed seconds) is printed.

    Notes
    -----
    Column use of ``Spikes``: column 0 selects the y bin, column 1 the
    electrode, column 2 the x bin. Presumably these are stimulus index,
    electrode id and latency in samples - TODO confirm against Step 2.
    """
    t0 = time.time()

    def gauss(sigma,res):
        # Unnormalised Gaussian sampled at integer bin offsets; sigma/res is the width in bins
        steps = int(np.ceil(sigma/res)+0.5)*6
        x = np.arange(-steps,steps+1,1)
        return np.exp(-(x**2)/(2*(sigma/res)**2))

    # The context managers guarantee that both HDF5 files are closed even when
    # processing raises part-way through (previously the handles leaked).
    with h5py.File(input_data_path+filename,'r') as f_in, h5py.File(filename,'w') as f_out:
        for circuit in range(15):
            print(circuit)
            circuit_name  = 'Circuit_'+str(circuit)
            circuit_group = f_out.create_group(circuit_name)
            patterns = list(f_in[circuit_name].keys())
            for pattern in patterns:
                pattern_group = circuit_group.create_group(pattern)
                segments = list(f_in[circuit_name + '/' + pattern].keys())
                for segment in segments:
                    segment_group = pattern_group.create_group(segment)
                    segment_path  = circuit_name + '/' + pattern + '/' + segment
                    start_id  = np.array(f_in[segment_path + '/Start_ID'])
                    stop_id   = np.array(f_in[segment_path + '/Stop_ID'])
                    data      = np.array(f_in[segment_path + '/Spikes'])
                    artifacts = np.array(f_in[segment_path + '/Artifacts'])

                    # Create the binned dataset
                    # (250 and 4 look like a 250 ms window at 4 repetitions per second - TODO confirm)
                    bins_x   = int(np.ceil(250/binning_x)+0.5)
                    bins_y   = int(np.ceil((stop_id-start_id)/(binning_y*4))+0.5)
                    binned   = np.zeros((bins_y,bins_x,4))

                    for i in range(data.shape[0]):
                        bin_x = int(data[i,2]/(binning_x*fs))
                        bin_y = int(data[i,0]/(binning_y*4))
                        elec  = int(data[i,1]+0.5)
                        binned[bin_y,bin_x,elec] += 1

                    binned   = binned/(binning_y*4)

                    # Smooth in y direction
                    gauss_kernel  = gauss(smooth_y,binning_y*4)
                    if gauss_kernel.shape[0] > binned.shape[0]:
                        print('Warning: Smoothing more than there are data points.')
                        # Keep only the central part of the kernel so it is not longer than the data
                        kernel_start = (gauss_kernel.shape[0]  - binned.shape[0])//2
                        gauss_kernel = gauss_kernel[kernel_start:kernel_start+binned.shape[0]]
                    gauss_kernel /= np.sum(gauss_kernel)

                    smoothed = np.copy(binned)
                    for i in range(smoothed.shape[1]):
                        for j in range(4):
                            smoothed[:,i,j] = np.convolve(smoothed[:,i,j],gauss_kernel,'same')

                    # Correct for data loss in y direction: rows near either end only
                    # overlap part of the kernel, so rescale them by the covered kernel mass.
                    # NOTE(review): the slice [:i+half] holds i+half taps, whereas a 'same'
                    # convolution with an odd kernel overlaps i+half+1 taps at row i - confirm.
                    factor = np.ones_like(smoothed[:,0,0])
                    for i in range(gauss_kernel.shape[0]//2):
                        factor[i]  /= np.sum(gauss_kernel[:i+gauss_kernel.shape[0]//2])
                        factor[factor.shape[0]-1-i] /= np.sum(gauss_kernel[:i+gauss_kernel.shape[0]//2])
                    smoothed = smoothed * factor[:,np.newaxis,np.newaxis]

                    # Smooth in x direction (kernel is intentionally left unnormalised, as before)
                    gauss_kernel = gauss(gauss_blurr,binning_x)
                    for i in range(smoothed.shape[0]):
                        for j in range(4):
                            smoothed[i,:,j] = np.convolve(smoothed[i,:,j],gauss_kernel,'same')

                    # Scale to 0..255 (saturating) and keep only the first first_x_ms ms
                    smoothed = (np.minimum(smoothed*255,255))[:,:int(first_x_ms/binning_x)]

                    ridges = []

                    for i in range(4):
                        # Here:   | x   and   -- y
                        # (inside the tracer "x" is the row index and "y" the column index)
                        traced = np.copy(smoothed[:,:,i])
                        while True:
                            max_x,max_y = argmax2d(traced)
                            if traced[max_x,max_y] < thr1:
                                break

                            # This means, we have a new point that is relevant
                            ridge  = [[max_x,max_y,i]]
                            traced = delete_line(max_x,max_y,traced,thr2)

                            # Follow the ridge towards increasing row index
                            up_x   = max_x+1
                            up_y   = max_y
                            up_s   = sigma
                            while True:
                                # NOTE(review): this stops before the last row (index shape[0]-1)
                                # whereas the downward walk reaches row 0 - confirm intentional.
                                if up_x >= (traced.shape[0]-1):
                                    break

                                # Fit inside a window of +-2 widths around the previous column
                                x = np.arange(max(0,up_y-int(up_s*2+0.5)),min(traced.shape[1],up_y+int(up_s*2+0.5)+1),1)
                                y = traced[up_x,x]
                                up_y,up_s = fit(x,y,up_y,up_s)
                                up_y = max(0,min(up_y,traced.shape[1]-1))
                                if traced[up_x,up_y] < thr2:
                                    break
                                traced = delete_line(up_x,up_y,traced,thr2)
                                ridge.append([up_x,up_y,i])
                                up_x += 1

                            # Follow the ridge towards decreasing row index
                            down_x   = max_x-1
                            down_y   = max_y
                            down_s   = sigma
                            while True:
                                if down_x < 0:
                                    break

                                x = np.arange(max(0,down_y-int(down_s*2+0.5)),min(traced.shape[1],down_y+int(down_s*2+0.5)+1),1)
                                y = traced[down_x,x]
                                down_y,down_s = fit(x,y,down_y,down_s)
                                down_y = max(0,min(down_y,traced.shape[1]-1))
                                if traced[down_x,down_y] < thr2:
                                    break
                                traced = delete_line(down_x,down_y,traced,thr2)
                                # Prepend so the ridge stays ordered by row index
                                ridge.insert(0,[down_x,down_y,i])
                                down_x -= 1

                            ridge = np.array(ridge)
                            if ridge.shape[0] > min_ridge_length:
                                ridges.append(ridge)

                    if save_images:
                        segment_group.create_dataset(name='Filtered',data=(smoothed+0.5).astype(np.uint8))
                    for i in range(len(ridges)):
                        segment_group.create_dataset(name='Ridge_'+str(i),data=ridges[i].astype(np.uint16))
                    segment_group.create_dataset(name='Artifacts',data=artifacts)
                    segment_group.create_dataset(name='Spikes',data=data)
                    segment_group.create_dataset(name='Start_ID',data=start_id)
                    segment_group.create_dataset(name='Stop_ID',data=stop_id)

    print(filename,time.time()-t0)
    sys.stdout.flush()

In [10]:
# Run smoothing_data on every collected file, using up to 10 parallel workers
Parallel(n_jobs=10)(map(delayed(smoothing_data),filenames))

[None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None]