In [1]:
from mne.datasets import sample
from mne.io import read_raw_fif
import numpy as np 
import mne 
import time
from mne import filter

In [2]:
def calc_band_filters(f_ranges, sample_rate, filter_len="1000ms", l_trans_bandwidth=4, h_trans_bandwidth=4):
    """"Calculate bandpass filters with adjustable length for given frequency ranges.
    This function returns for the given frequency band ranges the filter coefficients with length "filter_len".
    Thus the filters can be sequentially used for band power estimation.
    Parameters
    ----------
    f_ranges : list of lists
        frequency ranges.
    sample_rate : float
        sampling frequency.
    filter_len : str, optional
        length of the filter. Human readable (e.g."1000ms" or "1s"). Default is "1000ms"
    l_trans_bandwidth : int/float, optional
        Length of the lower transition band. The default is 4.
    h_trans_bandwidth : int/float, optional
        Length of the higher transition band. The default is 4.
    Returns
    -------
    filter_fun : array
        filter coefficients stored in array of shape (n_franges, filter_len (in samples))
    """
    filter_list = []
    for a, f_range in enumerate(f_ranges):
        h = mne.filter.create_filter(None, sample_rate, l_freq=f_range[0], h_freq=f_range[1], fir_design='firwin',
                                        l_trans_bandwidth=l_trans_bandwidth, h_trans_bandwidth=h_trans_bandwidth,
                                        filter_length=filter_len)
        filter_list.append(h)
    filter_bank = np.vstack(filter_list)
    return filter_bank

def apply_filter(dat_, filter_bank, fs):
        """Apply previously calculated (bandpass) filters to data.
        Parameters
        ----------
        dat_ : array (n_samples, ) or (n_channels, n_samples)
            segment of data.
        filter_fun : array
            output of calc_band_filters.
        Returns
        -------
        filtered : array
            (n_chan, n_fbands, filter_len) array conatining the filtered signal
            at each freq band, where n_fbands is the number of filter bands used to decompose the signal
        """    
        if dat_.ndim == 1:
            filtered = np.zeros((1, filter_bank.shape[0], fs))
            for filt in range(filter_bank.shape[0]):
                filtered[0, filt, :] = np.convolve(filter_bank[filt,:], dat_)[int(fs-fs/2):int(fs+fs/2)]
        elif dat_.ndim == 2:
            filtered = np.zeros((dat_.shape[0], filter_bank.shape[0], fs))
            for chan in range(dat_.shape[0]):
                for filt in range(filter_bank.shape[0]):
                    filtered[chan, filt, :] = np.convolve(filter_bank[filt, :], \
                                                        dat_[chan,:])[int(fs-fs/2):int(fs+fs/2)] # mode="full"
        return filtered

In [3]:
# setup examplary data 
data_points = 500
channels = 6
data = np.random.random([channels, data_points])

# definition of 
filter_ranges = np.array([[4, 8], [8, 12], [13, 20], [20, 35], [13, 35], [60, 80], [90, 200], \
                          [60, 200]])

fs = 500
filter_bank = calc_band_filters(filter_ranges, fs, filter_len="1000ms", l_trans_bandwidth=4, h_trans_bandwidth=4)

No data specified. Sanity checks related to the length of the signal relative to the filter order will be skipped.
Setting up band-pass filter from 4 - 8 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 4.00 Hz (-6 dB cutoff frequency: 2.00 Hz)
- Upper passband edge: 8.00 Hz
- Upper transition bandwidth: 4.00 Hz (-6 dB cutoff frequency: 10.00 Hz)
- Filter length: 501 samples (1.002 sec)

No data specified. Sanity checks related to the length of the signal relative to the filter order will be skipped.
Setting up band-pass filter from 8 - 12 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB st

In [4]:
start = time.time()
dat_filtered = apply_filter(data, filter_bank, fs)
end = time.time()
print(end - start)

0.005991220474243164


In [5]:
dat_filtered.shape

(6, 8, 500)

In [6]:
start = time.time()
filtered_out = np.zeros([channels, filter_ranges.shape[0], data_points])
for filter_idx, filter_ in enumerate(filter_ranges):
    #print(filter_)
    filtered_out[:,filter_idx,:] = mne.filter.filter_data(data, sfreq=fs, l_freq=filter_[0], \
                                                          h_freq=filter_[1], method="fir", verbose=0)
end = time.time()
print(end - start)

0.0310056209564209


  h_freq=filter_[1], method="fir", verbose=0)
  h_freq=filter_[1], method="fir", verbose=0)
  h_freq=filter_[1], method="fir", verbose=0)
  h_freq=filter_[1], method="fir", verbose=0)


In [7]:
filtered_out.shape

(6, 8, 500)