In [1]:
import numpy as np
import h5py

import sys
from joblib import Parallel,delayed
import time

import re
import os

from scipy.optimize import curve_fit as fit
from scipy.signal import medfilt
from scipy.ndimage import minimum_filter as minfilt
from scipy.ndimage import maximum_filter as maxfilt

import matplotlib.pyplot as plt

In [2]:
# Directory prefix of the input HDF5 files. It is concatenated directly with
# each file name when a file is opened, so the trailing slash is required.
input_data_path = "../Step_1/"

In [3]:
# Segmentation: consecutive stimulation IDs whose difference is below this
# value are treated as belonging to the same segment.
max_dist = 50

# Minimum number of consecutive close ID gaps for a segment to be kept; also
# the minimum number of candidate spikes required before an artefact is removed.
min_length = 50

# Maximum distance (in units of spike column 2) between a candidate spike and
# the running minimum of its neighbours for it to count as part of the artefact.
fitting_error = 25

# Lower bound on the candidates-per-stimulation-ID ratio needed to accept
# the candidates as an artefact.
epsilon = 0.33

In [4]:
# Pattern number -> list of stimulation identifiers.
# The key is the integer parsed from a 'Pattern_<n>' group name; every listed
# value is handed to remove_single_artefact as `stim`, which then selects the
# spikes whose column 1 equals `stim - 1`.
# NOTE(review): pattern 80 lists 0..3 while every other pattern uses 1..4;
# with stim == 0 the comparison is against -1 -- confirm this is intended.
stimulation = {
    # no stimulation electrode
    0: [],
    # single electrode
    1: [1],
    2: [2],
    3: [3],
    4: [4],
    # no stimulation electrode
    7: [],
    8: [],
    # four identifiers at once (see review note above)
    80: [0, 1, 2, 3],
    # single electrode
    100: [1],
    102: [2],
    104: [3],
    106: [4],
    # electrode pairs
    108: [1, 2],
    110: [2, 3],
    112: [3, 4],
    114: [1, 4],
    116: [1, 3],
    118: [2, 4],
    # no stimulation electrode
    120: [],
}

In [5]:
# Collect the HDF5 files that sit directly inside input_data_path.
# Only the top-level directory is listed: the names are later opened as
# input_data_path + filename, so a file found in a subdirectory could not be
# opened under its bare name. If the directory does not exist, os.walk yields
# nothing and the list stays empty.
_, _, top_level_files = next(os.walk(input_data_path), (None, [], []))

# Match on the extension (not any occurrence of '.h5' in the name) and sort so
# that the processing order is reproducible between runs.
filenames = sorted(filename for filename in top_level_files if filename.endswith('.h5'))

In [6]:
def remove_single_artefact(data,stim,seg_length,flag=False,min_len=None,max_fit_error=None,eps=None):
    """Detect and remove the artefact of one stimulation electrode from a segment.

    Candidate rows are the first row of every run of equal column-0 values
    among the rows whose column 1 equals ``stim - 1``. Walking from the last
    candidate to the first, earlier candidates whose column 2 lies 75 or more
    above the current candidate are discarded. The remaining candidates must
    stay within ``max_fit_error`` of a running minimum (window 7) of column 2.
    If more than ``min_len`` candidates survive and their density along
    column 0 exceeds ``eps``, these rows are removed from ``data``.

    Parameters
    ----------
    data : 2-D numpy array
        One row per spike; columns 0, 1 and 2 are used.
        (Presumably stimulation ID offset, electrode index and time after the
        stimulus -- confirm against the producer of the input files.)
    stim : int
        Stimulation identifier; rows with ``data[:,1] == stim - 1`` are examined.
    seg_length : int
        Unused; kept so existing callers keep working.
    flag : bool
        Unused; kept so existing callers keep working.
    min_len, max_fit_error, eps : optional
        Overrides for the notebook-level ``min_length``, ``fitting_error`` and
        ``epsilon``. ``None`` (default) means "use the notebook-level value".

    Returns
    -------
    (data, count)
        ``data`` without the artefact rows (original row order preserved) and
        ``count`` = 1 if an artefact was removed, otherwise the unchanged
        ``data`` and 0.
    """
    if data.shape[0] <= 1:
        return data,0

    critical = np.argwhere(data[:,1]==(stim-1))[:,0]
    if critical.shape[0] <= 1:
        return data,0

    # Resolve the thresholds only once they are actually needed.
    if min_len is None:
        min_len = min_length
    if max_fit_error is None:
        max_fit_error = fitting_error
    if eps is None:
        eps = epsilon

    # Keep only the first row of every run of identical column-0 values.
    # (The builtin bool replaces np.bool, which was removed in NumPy 1.24.)
    vec      = data[critical,0]
    is_first = np.concatenate([np.ones(1,dtype=bool),(vec[1:]-vec[:-1])!=0])
    critical = critical[is_first]

    # Walk from the last candidate towards the first one.
    critical = np.flip(critical)

    i = 0
    while i < critical.shape[0]:
        x_thr = data[critical[i],0]
        y_thr = data[critical[i],2] + 25*3

        # Keep a candidate if it does not precede the current one in column 0,
        # or if its column 2 stays below the current value plus the margin.
        subset         = data[critical,:]
        still_critical = np.argwhere(np.logical_or(subset[:,0]>=x_thr,subset[:,2]<y_thr))[:,0]

        # The current candidate always survives; continue right after it.
        i              = np.argwhere(still_critical==i)[0,0] + 1

        critical       = critical[still_critical]

    count = 0

    # Candidates must lie close to the running minimum of their neighbours.
    col2       = data[critical,2]
    criteria_3 = col2 - minfilt(col2,7) < max_fit_error
    critical   = critical[criteria_3]

    # There must be enough candidates left ...
    criteria_1 = critical.shape[0] > min_len

    # ... and they must be dense enough along column 0: the j-th candidate
    # (counted from the first one) divided by its column-0 distance to the
    # first candidate; the first min_len//2 ratios are ignored.
    if criteria_1:
        x          = data[critical[::-1],0]
        wave       = np.arange(x.shape[0])/(1e-5+np.abs(x-x[0]))
        criteria_2 = np.max(wave[min_len//2:]) > eps
    else:
        criteria_2 = False

    if criteria_1 and criteria_2:
        # A boolean mask keeps the remaining rows in their original order,
        # which a set difference does not guarantee.
        keep           = np.ones(data.shape[0],dtype=bool)
        keep[critical] = False
        data           = data[keep,:]
        count          = 1

    return data,count

In [7]:
def _find_segments(ids, max_gap, min_len):
    """Split stimulation IDs into runs whose consecutive gaps are below max_gap.

    Returns a list of ``[start_id, stop_id]`` pairs (``stop_id`` exclusive,
    i.e. last ID of the run + 1) for every run that contains at least
    ``min_len`` consecutive gaps smaller than ``max_gap``.
    """
    close_to_next = (ids[1:] - ids[:-1]) < max_gap

    pos       = 0
    locations = []

    while pos < close_to_next.shape[0]:
        # Start of the next run of close IDs
        pos      = np.argmax(close_to_next[pos:]) + pos

        # First gap that is too large after that start
        next_low = np.argmin(close_to_next[pos:]) + pos
        if next_low == pos:
            if close_to_next[pos]:
                # No large gap is left: the run extends to the last ID
                next_low = close_to_next.shape[0]
            else:
                # No run is left at all
                break

        if next_low - pos >= min_len:
            locations.append([ids[pos],ids[next_low]+1])

        pos = next_low

    return locations


def remove_artefact(filename):
    """Segment every pattern of one input file and remove stimulation artefacts.

    Reads ``input_data_path + filename`` and writes a file called ``filename``
    into the current working directory. For each of the 15 circuits and each
    ``Pattern_<n>`` group, the ``Stim_IDs`` dataset is split into segments
    (using ``max_dist`` and ``min_length``); for every segment the matching
    rows of ``Spikes`` are extracted, column 0 is made relative to the segment
    start, rows with column 2 <= 25 are dropped, and
    ``remove_single_artefact`` is applied once per entry of
    ``stimulation[<n>]``.

    Output layout: ``Circuit_<c>/Pattern_<n>/Segment_<i>/`` with the datasets
    ``Spikes``, ``Start_ID``, ``Stop_ID`` and ``Artifacts`` (number of
    artefacts removed in that segment), all stored as uint32.

    Prints the file name, the total number of removed artefacts and the
    elapsed time. Returns None.

    Raises
    ------
    KeyError
        If a pattern number has no entry in ``stimulation``.
    """
    count = 0
    t0    = time.time()

    # The context managers close both files even if processing fails midway.
    with h5py.File(input_data_path+filename,'r') as f_in, h5py.File(filename,'w') as f_out:

        # Over all 15 circuits
        for circuit in range(15):
            circuit_name = 'Circuit_' + str(circuit)

            # Create the circuit group of the current circuit
            circuit_group = f_out.create_group(circuit_name)

            # Pattern numbers of this circuit (group names are 'Pattern_<n>')
            keys = [int(key[8:]) for key in f_in[circuit_name].keys()]

            # For every pattern (key) in this circuit, segment and remove artifact
            for key in keys:
                pattern_in = f_in[circuit_name + '/Pattern_' + str(key)]

                # First, segment the pattern data
                ids       = np.array(pattern_in['Stim_IDs'])
                locations = _find_segments(ids,max_dist,min_length)

                # Patterns without any segment get no output group
                if len(locations) == 0:
                    continue

                pattern_group = circuit_group.create_group('Pattern_'+str(key))

                # Load the spikes once per pattern instead of twice per segment
                spikes         = np.array(pattern_in['Spikes'])
                spike_stim_ids = spikes[:,0,0]

                # For each segment, do artifact removal
                for loc_i,location in enumerate(locations):

                    count_small = 0

                    in_segment = np.argwhere((spike_stim_ids>=location[0])*(spike_stim_ids<location[1]))[:,0]

                    # Index-array selection returns a copy, so the in-place
                    # shift below does not touch the loaded spikes.
                    data      = spikes[in_segment,0,:]
                    data[:,0]-= location[0]

                    # Drop rows with column 2 <= 25 on all electrodes
                    # (original note: "Remove 1 sec from all electrodes").
                    data = data[data[:,2]>25,:]

                    # Do for loop, since some patterns have 0, 1, or multiple stimulation electrodes
                    for stim in stimulation[key]:

                        flag = circuit == 8
                        data,c_s = remove_single_artefact(data,stim,location[1]-location[0],flag=flag)
                        count       += c_s
                        count_small += c_s

                    # Save all the data that is relevant
                    segment_group = pattern_group.create_group('Segment_'+str(loc_i))
                    segment_group.create_dataset("Spikes",data=data,dtype=np.uint32)
                    segment_group.create_dataset("Start_ID",data=location[0],dtype=np.uint32)
                    segment_group.create_dataset("Stop_ID",data=location[1],dtype=np.uint32)
                    segment_group.create_dataset("Artifacts",data=count_small,dtype=np.uint32)

    print(filename,count,time.time()-t0)

In [8]:
# def remove_artefact(filename):
#     count = 0
#     t0    = time.time()
    
#     f_in  = h5py.File(input_data_path+filename,'r')
#     f_out = h5py.File(filename,'w')
    
#     # Over all 15 circuits
#     for circuit in range(15):
        
#         # Create the circuit group of current circuit
#         circuit_group = f_out.create_group('Circuit_' + str(circuit))
        
#         # Find the Patterns, that belong to this circuit
#         keys = []
#         for key in f_in['Circuit_' + str(circuit)].keys():
#             keys.append(int(key[8:]))
                        
        
#         # For every pattern (key) in this circuit, segment and remove artifact
#         for key in keys:
            
#             # First, segment the pattern data
#             ids = np.array(f_in['Circuit_' + str(circuit) + '/Pattern_' + str(key) + '/Stim_IDs'])
#             id_diff = ids[1:] - ids[:-1]
#             mask = (id_diff<max_dist)
            
#             pos       = 0
#             locations = []

#             while pos < mask.shape[0]:
#                 # Start of next high
#                 pos      = np.argmax(mask[pos:]) + pos

#                 # End of next high.
#                 next_low = np.argmin(mask[pos:]) + pos
#                 if next_low == pos:
#                     # End of file
#                     if mask[pos]:
#                         next_low = mask.shape[0]
#                     else:
#                         break

#                 if next_low - pos >= min_length:
#                     locations.append([ids[pos],ids[next_low]+1])

#                 pos = next_low
#             # Result of this block: The segments are saved in locations
            
#             # Create pattern group
#             if len(locations)>0:
#                 pattern_group = circuit_group.create_group('Pattern_'+str(key))
            
#             # For each segment, do artifact removal
#             for loc_i,location in enumerate(locations):
#                 count_small = 0
                
#                 mask = np.array(f_in['Circuit_' + str(circuit) + '/Pattern_' + str(key) + '/Spikes'])[:,0,0]
#                 mask = np.argwhere((mask>=location[0])*(mask<location[1]))[:,0]

#                 data      = np.array(f_in['Circuit_' + str(circuit) + '/Pattern_' + str(key) + '/Spikes'])[mask,0,:]
#                 data[:,0]-= location[0]

#                 # Do for loop, since some patterns have 0, 1, or multiple stimulation electrodes                
#                 for stim in stimulation[key]:
                    
#                     # Repeat until no artifact is left
#                     while True:    
#                         # Find all elements that are about the electrode being stimulated
#                         mapping = np.where(data[:,1]==(stim-1))[0]
#                         stim_electrode = data[mapping,0:4:2]
                        
#                         # Test if enough spikes
#                         if stim_electrode.shape[0] <= 1:
#                             break

#                         # Find the first element in the electrode being stimulated
#                         first_spike_indices  = np.where(stim_electrode[1:,0] - stim_electrode[:-1,0])[0]
#                         first_spike_indices = mapping[np.concatenate([[0],first_spike_indices+1])]

#                         # The spikes need to be in at least eps of the dataset:
#                         if 1.*first_spike_indices.shape[0]/(location[1]-location[0]) < epsilon:
#                             break

#                         # Get the x and y values of the first spikes
#                         fs_x = data[first_spike_indices,0]
#                         fs_y = data[first_spike_indices,2]

#                         if False: 
#                             # Do a data fit
#                             def f(x,a,b,c,d):
#                                 return a * np.exp(-x*b) + c * x + d
#                             param,_ = fit(f,fs_x,medfilt(fs_y,21),p0=[1,0.01,0,10],maxfev=100000)
#                             fs_fit = f(fs_x,param[0],param[1],param[2],param[3])
#                         else:
#                             # Fit data using the median
#                             fs_fit = medfilt(fs_y,21)

#                         # All data that is in the fitting error to the fit gets deleted
#                         to_be_deleted = first_spike_indices[np.where(np.abs(fs_fit-fs_y)<fitting_error)[0]]

#                         # Check, if artifact is still is often enough in dataset:
#                         if 1.*to_be_deleted.shape[0]/(location[1]-location[0]) < epsilon:
#                             break
                            
#                         count += 1
#                         count_small += 1

#                         # Remove the artifacts from the dataset
#                         if to_be_deleted.shape[0] == data.shape[0]:
#                             data = data[:0,:]
#                         else:
#                             data = data[np.array(list(set(np.arange(data.shape[0]))-set(to_be_deleted))),:]
            
#                 # Save all the data that is relevant
#                 segment_group = pattern_group.create_group('Segment_'+str(loc_i))
#                 segment_group.create_dataset("Spikes",data=data,dtype=np.uint32)
#                 segment_group.create_dataset("Start_ID",data=location[0],dtype=np.uint32)
#                 segment_group.create_dataset("Stop_ID",data=location[1],dtype=np.uint32)
#                 segment_group.create_dataset("Artifacts",data=count_small,dtype=np.uint32)
    
#     # Close files
#     f_out.close()
#     f_in.close()
    
#     print(filename,count,time.time()-t0,)

In [9]:
# Process every collected file in its own joblib task, using 8 workers.
# remove_artefact returns None, so the displayed result is a list of None.
Parallel(n_jobs=8)(
    delayed(remove_artefact)(name)
    for name in filenames
)

[None, None, None, None, None, None, None, None]