In [1]:
import numpy as np
import h5py

import sys
from joblib import Parallel,delayed
import time

import re
import os

import cv2

from scipy.optimize import curve_fit

In [2]:
# Any change to the parameters below must also be made in step 3.

In [3]:
# Locations of the files this notebook reads (relative paths, trailing slash required
# because file names are appended by plain string concatenation)
circuit_mask_path = '../Step_4/'
input_data_path   = '../Step_5/'

In [4]:
# Tunable parameters. Only fs, gauss_blurr, binning_x, binning_y, smooth_y and
# first_x_ms are read by the cells visible here; the remaining values are kept
# as configuration.
fs          = 25   # kHz
gauss_blurr = 0.1  # ms
upper_thr   = 0.1  # Probability
lower_thr   = 0.05 # Probability

binning_y   = 1    # sec
binning_x   = 0.1  # ms
smooth_y    = 45   # Smoothing over time

first_x_ms  = 25   # Save the first 25 ms

thr1        = 15   # Hysteresis threshold top
thr2        = 3    # Hysteresis threshold bottom
sigma       = 2    # starting smoothness of fit

min_ridge_length = 200 # Minimal ridge length

In [5]:
# Sum the 'total' dataset of every listed MEA group; the result is used as the
# upper bound when iterating over circuit indices.
# The context manager closes the file even when one of the datasets is missing.
# NOTE(review): the MEA list skips 1 and 5 - presumably deliberate, confirm.
with h5py.File(circuit_mask_path + 'step_4_percentages.h5','r') as f:
    total_circuit_count = np.sum([np.sum(f['MEA_'+str(i)+'/total']) for i in [2,3,4,6,7,8,9,10]])

In [6]:
# Collect the HDF5 files that sit directly inside input_data_path.
# - Only the top level is used: file names are later appended to input_data_path,
#   so a file found in a sub-directory could not be opened by its bare name.
# - endswith() avoids matching names that merely contain '.h5' (e.g. 'x.h5.bak').
# - Sorting makes the processing order (and filenames[0]) reproducible.
# A missing directory yields an empty list, as os.walk does.
_,_,files = next(os.walk(input_data_path),(input_data_path,[],[]))
filenames = sorted(name for name in files if name.endswith('.h5'))

In [7]:
def fit(x,y,mu,sigma):
    """Fit a Gaussian a*exp(-(x-b)^2/(2c^2)) to (x, y) and return its centre and width.

    Parameters
    ----------
    x, y : array-like
        Sample positions and values passed to scipy's curve_fit.
    mu, sigma : number
        Initial guesses for the centre and the width (the amplitude guess is 1).
        They are also the fallback result when the fit fails.

    Returns
    -------
    tuple
        (centre, width). The centre is clamped to [x[0], x[-1]] and truncated
        to int; the width is the absolute fitted width, capped at 3/binning_x.
        If fitting or post-processing raises, (mu, sigma) is returned unchanged.
    """
    def gaussian(x,a,b,c):
        return a * np.exp(-(x-b)**2/(2*c*c))

    try:
        param,_ = curve_fit(gaussian,x,y,p0=[1,mu,sigma],maxfev=1000)
        return int(min(max(param[1],x[0]),x[-1])),min(abs(param[2]),3/binning_x)
    except Exception:
        # Best effort: any fitting failure falls back to the initial guess.
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
        return mu,sigma

In [8]:
def argmax2d(array):
    """Return the (row, column) index of the largest entry of a 2-D array.

    The array is flattened row by row, so on ties the first maximum in
    row-major order wins (np.argmax semantics).
    """
    n_rows,n_cols = array.shape[0],array.shape[1]
    flat_index = np.argmax(np.reshape(array,n_rows*n_cols))
    # Split the flat position back into its row and column parts
    return divmod(flat_index,n_cols)

In [9]:
def delete_line(x,y,array,thr2):
    """Zero the contiguous run of values above thr2 in row x around column y.

    Starting at column y the row is cleared to the right while the values
    exceed thr2, then to the left starting at column y-1. The array is
    modified in place and also returned.
    """
    # Sweep to the right, starting on the seed column itself
    for col in range(y,array.shape[1]):
        if not (array[x,col] > thr2):
            break
        array[x,col] = 0

    # Sweep to the left, starting one column before the seed
    for col in range(y-1,-1,-1):
        if not (array[x,col] > thr2):
            break
        array[x,col] = 0

    return array

In [10]:
# Select the first discovered file (raises IndexError when no file was found).
# NOTE(review): no cell visible here reads this variable - smoothing_data and the
# Parallel call use their own `filename` - confirm whether it is still needed.
filename = filenames[0]

In [11]:
def _gauss_kernel(sigma,res):
    """Return an unnormalised Gaussian kernel sampled on an integer grid.

    sigma and res are given in the same unit; the kernel has a standard
    deviation of sigma/res samples and extends 6*ceil(sigma/res) samples to
    each side of its centre (peak value 1).
    """
    steps = int(np.ceil(sigma/res)+0.5)*6
    x = np.arange(-steps,steps+1,1)
    return np.exp(-(x**2)/(2*(sigma/res)**2))


def _smooth_segment(data,start_id,stop_id):
    """Bin the spikes of one segment and smooth the histogram along both axes.

    data is indexed as data[i,0] (y position), data[i,1] (electrode, 0..3)
    and data[i,2] (x position in samples). Returns the smoothed image scaled
    to 0..255 and cropped to the first first_x_ms milliseconds.
    """
    # Create the binned dataset
    n_bins_x = int(np.ceil(250/binning_x)+0.5)
    n_bins_y = int(np.ceil((stop_id-start_id)/(binning_y*4))+0.5)
    binned   = np.zeros((n_bins_y,n_bins_x,4))

    for i in range(data.shape[0]):
        col  = int(data[i,2]/(binning_x*fs))
        row  = int(data[i,0]/(binning_y*4))
        elec = int(data[i,1]+0.5)
        binned[row,col,elec] += 1

    binned = binned/(binning_y*4)

    # Smooth in y direction
    kernel_y = _gauss_kernel(smooth_y,binning_y*4)
    if kernel_y.shape[0] > binned.shape[0]:
        print('Warning: Smoothing more than there are data points.')
        # Keep only the central part of the kernel so it fits the data
        kernel_start = (kernel_y.shape[0] - binned.shape[0])//2
        kernel_y = kernel_y[kernel_start:kernel_start+binned.shape[0]]
    kernel_y /= np.sum(kernel_y)

    smoothed = np.copy(binned)
    for col in range(smoothed.shape[1]):
        for elec in range(4):
            smoothed[:,col,elec] = np.convolve(smoothed[:,col,elec],kernel_y,'same')

    # Correct for data loss in y direction: near both ends part of the kernel
    # falls outside the data, so the rows are rescaled by the kernel weight.
    # NOTE(review): for a 'same' convolution the kernel taps overlapping the
    # data at row i appear to be [:i+half+1]; the slice below stops one tap
    # earlier. Left unchanged because it alters the output scale at the edges
    # - confirm before changing.
    half   = kernel_y.shape[0]//2
    factor = np.ones_like(smoothed[:,0,0])
    for i in range(half):
        edge_weight = np.sum(kernel_y[:i+half])
        factor[i] /= edge_weight
        factor[factor.shape[0]-1-i] /= edge_weight
    smoothed = smoothed * factor[:,np.newaxis,np.newaxis]

    # Smooth in x direction
    # NOTE(review): unlike the y kernel this one is not normalised, so the
    # values are scaled by the kernel sum - confirm this is intended.
    kernel_x = _gauss_kernel(gauss_blurr,binning_x)
    for row in range(smoothed.shape[0]):
        for elec in range(4):
            smoothed[row,:,elec] = np.convolve(smoothed[row,:,elec],kernel_x,'same')

    # Scale to the 8-bit range, clip, and keep only the first first_x_ms ms
    return (np.minimum(smoothed*255,255))[:,:int(first_x_ms/binning_x)]


def smoothing_data(filename,save_images=False):
    """Smooth the spike data of one input file and write the result to a new HDF5 file.

    The file input_data_path+filename is read, and a file with the same name
    is created (overwritten) in the current working directory. The
    DIV / Circuit_<n> / pattern / segment group hierarchy is mirrored and
    every dataset of each segment is copied over.

    Parameters
    ----------
    filename : str
        Name of the HDF5 file inside input_data_path.
    save_images : bool, optional
        When true, the smoothed image of each segment is additionally stored
        as the uint8 dataset 'Filtered'. The smoothing itself is computed in
        either case.

    Returns
    -------
    None
        Progress (circuit indices, then the file name and elapsed seconds) is
        printed to stdout.
    """
    t0 = time.time()

    # The context managers close both files even when a segment raises
    with h5py.File(input_data_path+filename,'r') as f_in, h5py.File(filename,'w') as f_out:
        for DIV in list(f_in.keys()):
            f_in_div  = f_in[DIV]
            f_out_div = f_out.create_group(DIV)
            for circuit in range(total_circuit_count):
                circuit_name = 'Circuit_'+str(circuit)
                try:
                    circuit_in   = f_in_div[circuit_name]
                    circuit_keys = list(circuit_in.keys())
                except Exception:
                    # Exception instead of a bare except: still skips on any
                    # lookup failure but no longer swallows KeyboardInterrupt
                    continue # This circuit has no data for the experiment in question
                print(circuit)
                circuit_group = f_out_div.create_group(circuit_name)
                # 'MEA' and 'circuit' are metadata entries, everything else is a pattern
                patterns = [key for key in circuit_keys if key not in ('MEA','circuit')]
                for pattern in patterns:
                    pattern_in    = circuit_in[pattern]
                    pattern_group = circuit_group.create_group(pattern)

                    for segment in list(pattern_in.keys()):
                        segment_in    = pattern_in[segment]
                        segment_group = pattern_group.create_group(segment)
                        start_id = np.array(segment_in['Start_ID'])
                        stop_id  = np.array(segment_in['Stop_ID'])
                        data     = np.array(segment_in['Spikes'])

                        smoothed = _smooth_segment(data,start_id,stop_id)

                        if save_images:
                            # +0.5 rounds to nearest before the uint8 truncation
                            segment_group.create_dataset(name='Filtered',data=(smoothed+0.5).astype(np.uint8))

                        # Carry every input dataset of the segment over to the output
                        for key in list(segment_in.keys()):
                            segment_group.create_dataset(name=key,data=segment_in[key])

    print(filename,time.time()-t0)
    sys.stdout.flush()

In [12]:
# Process every discovered file with joblib, using up to 10 jobs at once and
# storing the filtered images (save_images=True). The displayed list holds each
# call's return value, which is None.
Parallel(n_jobs=10)(delayed(smoothing_data)(filename,True) for filename in filenames)

[None, None, None, None, None, None]