In [1]:
import numpy as np
#!pip install git+https://github.com/aaren/wavelets
#!pip install scipy==1.2.1

In [6]:
# Suite2p output folder for recording TX39.
# NOTE(review): absolute, machine-specific path; consider making it configurable.
data_path = '/media/maria/DATA1/Documents/data_for_suite2p/TX39/'
# Frame offset added to every image-frame index before selecting columns of spks.
dt = 1

spks = np.load(data_path + 'spks.npy')
print('Shape of the data matrix, neurons by timepoints:', spks.shape)

# iframe[n] is the microscope frame for the image frame n
iframe = np.load(data_path + 'iframe.npy')
# Keep only image frames whose shifted microscope frame still lies inside spks.
ivalid = (iframe + dt) < spks.shape[-1]
iframe = iframe[ivalid]

# Every neuron's activity at the shifted microscope frame of each valid image frame.
S = spks[:, iframe + dt]
print(S.shape)

Shape of the data matrix, neurons by timepoints: (18795, 30766)
(18795, 30560)


In [2]:
"""
DeepInsight Toolbox
© Markus Frey
https://github.com/CYHSM/DeepInsight
Licensed under MIT License
"""
from wavelets import WaveletAnalysis
import numpy as np


def wavelet_transform(signal, sampling_rate, average_window=1000, scaling_factor=0.25, wave_highpass=2, wave_lowpass=30000):
    """
    Calculates the wavelet transform for each point in signal, then averages
    each non-overlapping window of ``average_window`` samples and returns the
    result together with the corresponding Fourier frequencies.

    Parameters
    ----------
    signal : (N,) array_like
        Signal to be transformed
    sampling_rate : int
        Sampling rate of signal
    average_window : int, optional
        Average window to downsample wavelet transformed input, by default 1000.
        A value of 1 disables averaging. If N is not a multiple of
        average_window, the trailing ``N % average_window`` samples are
        dropped before averaging.
    scaling_factor : float, optional
        Determines amount of log-spaced frequencies M in output, by default 0.25
    wave_highpass : int, optional
        Frequencies less than or equal to this are removed, by default 2
    wave_lowpass : int, optional
        Frequencies greater than or equal to this are removed, by default 30000

    Returns
    -------
    wavelet_power : (N // average_window, M) ndarray
        Wavelet transformed signal, time along axis 0, frequency along axis 1
    wavelet_frequencies : (M,) ndarray
        Corresponding frequencies to wavelet_power
    """
    wavelet_power, wavelet_frequencies, _ = simple_wavelet_transform(
        signal, sampling_rate, scaling_factor=scaling_factor,
        wave_highpass=wave_highpass, wave_lowpass=wave_lowpass)

    # wavelet_power arrives as (frequencies, time); the result is (time, frequencies).
    # Compare by value: `is not 1` depended on CPython small-int identity.
    if average_window != 1:
        num_frequencies, num_samples = wavelet_power.shape
        num_windows = num_samples // average_window
        # Drop the incomplete trailing window so the reshape below is always valid.
        complete = wavelet_power[:, :num_windows * average_window]
        windowed = np.reshape(complete, (num_frequencies, num_windows, average_window))
        wavelet_power = np.mean(windowed, axis=2).transpose()
    else:
        wavelet_power = wavelet_power.transpose()

    return wavelet_power, wavelet_frequencies


def simple_wavelet_transform(signal, sampling_rate, scaling_factor=0.25, wave_lowpass=None, wave_highpass=None):
    """
    Simple wavelet transformation of signal

    Parameters
    ----------
    signal : (N,) array_like
        Signal to be transformed
    sampling_rate : int
        Sampling rate of signal
    scaling_factor : float, optional
        Determines amount of log-space frequencies M in output, by default 0.25
    wave_lowpass : int or None, optional
        If given, keep only frequencies strictly below this value, by default None
    wave_highpass : int or None, optional
        If given, keep only frequencies strictly above this value, by default None

    Returns
    -------
    wavelet_power : (M, N) array_like
        Wavelet transformed signal, one row per retained frequency
    wavelet_frequencies : (M,) array_like
        Corresponding frequencies to wavelet_power
    wavelet_obj : object
        WaveletTransform Object
    """
    wavelet_obj = WaveletAnalysis(signal, dt=1 / sampling_rate, dj=scaling_factor)
    wavelet_power = wavelet_obj.wavelet_power
    wavelet_frequencies = wavelet_obj.fourier_frequencies

    if wave_lowpass is not None or wave_highpass is not None:
        # Apply each bound independently: comparing an array against None
        # would raise TypeError when only one bound is supplied.
        keep = np.ones_like(wavelet_frequencies, dtype=bool)
        if wave_lowpass is not None:
            keep &= wavelet_frequencies < wave_lowpass
        if wave_highpass is not None:
            keep &= wavelet_frequencies > wave_highpass
        wavelet_power = wavelet_power[keep, :]
        wavelet_frequencies = wavelet_frequencies[keep]

    return (wavelet_power, wavelet_frequencies, wavelet_obj)

In [4]:
import time
from joblib import Parallel, delayed
import numpy as np
import h5py

#import deepinsight.util.wavelet_transform as wt


def preprocess_input(fp_hdf_out, raw_data, average_window=1000, channels=None, window_size=100000,
                     gap_size=50000, sampling_rate=30000, scaling_factor=0.5, num_cores=4):
    """
    Transforms raw neural data to frequency space, via wavelet transform implemented currently with aaren-wavelets (https://github.com/aaren/wavelets)
    Saves wavelet transformed data to HDF5 file (N, P, M) - (Number of timepoints, Number of frequencies, Number of channels)

    The data is processed in overlapping chunks of ``window_size`` samples that
    start every ``gap_size`` samples. The output length is computed as
    ``(num_chunks + 1) * gap_size // average_window``, which assumes
    ``window_size == 2 * gap_size`` (as with the defaults).

    Parameters
    ----------
    fp_hdf_out : str
        File path to HDF5 file
    raw_data : (N, M) file or array_like
        Variable storing the raw_data (N data points, M channels), should allow indexing
    average_window : int, optional
        Average window to downsample wavelet transformed input, by default 1000
    channels : array_like, optional
        Which channels from raw_data to use, by default None (all channels)
    window_size : int, optional
        Window size for calculating wavelet transformation, by default 100000
    gap_size : int, optional
        Gap size for calculating wavelet transformation, by default 50000
    sampling_rate : int, optional
        Sampling rate of raw_data, by default 30000
    scaling_factor : float, optional
        Determines amount of log-spaced frequencies P in output, by default 0.5
    num_cores : int, optional
        Number of paralell cores to use to calculate wavelet transformation, by default 4

    Raises
    ------
    ValueError
        If raw_data has fewer than ``2 * gap_size`` timepoints, i.e. not even
        one chunk can be formed. Raised before the output file is touched.

    Notes
    -----
    The file is opened in append mode; NOTE(review): h5py's create_dataset
    fails if "inputs/wavelets" already exists in fp_hdf_out (e.g. on re-run).
    """
    # Get number of chunks
    if channels is None:
        channels = np.arange(0, raw_data.shape[1])
    num_points = raw_data.shape[0]
    num_chunks = (num_points // gap_size) - 1
    if num_chunks < 1:
        # Previously this silently produced an empty dataset.
        raise ValueError(
            'raw_data has {} timepoints, but at least 2 * gap_size = {} are needed '
            'to form one chunk'.format(num_points, 2 * gap_size))

    (_, wavelet_frequencies) = wavelet_transform(np.ones(window_size), sampling_rate, average_window, scaling_factor)
    num_fourier_frequencies = len(wavelet_frequencies)
    num_output_points = ((num_chunks + 1) * gap_size) // average_window
    # Averaged samples discarded on each edge a chunk shares with a neighbour
    # (presumably to suppress wavelet edge effects).
    index_gap = gap_size // 2 // average_window

    # Prepare par pool
    par = Parallel(n_jobs=num_cores, verbose=0)

    # The context manager closes the file even if a chunk raises.
    with h5py.File(fp_hdf_out, mode='a') as hdf5_file:
        wavelets_ds = hdf5_file.create_dataset(
            "inputs/wavelets", [num_output_points, num_fourier_frequencies, len(channels)], np.float32)
        frequencies_ds = hdf5_file.create_dataset("inputs/fourier_frequencies", [num_fourier_frequencies], np.float16)

        # Start parallel wavelet transformation
        print('Number of chunks {}'.format(num_chunks))
        for c in range(num_chunks):
            t_chunk = time.time()
            print('Starting chunk {}'.format(c))

            # Cut ephys
            start = gap_size * c
            end = start + window_size
            print('Start {} - End {}'.format(start, end))
            raw_chunk = raw_data[start:end, channels]

            # Process raw chunk
            raw_chunk = preprocess_chunk(raw_chunk, subtract_mean=True, convert_to_milivolt=False)

            # Calculate wavelet transform, one parallel job per channel
            wavelet_transformed = np.zeros((raw_chunk.shape[0] // average_window, num_fourier_frequencies, len(channels)))
            results = par(delayed(wavelet_transform)(raw_chunk[:, i], sampling_rate, average_window, scaling_factor)
                          for i in range(len(channels)))
            for ind, (wavelet_power, _) in enumerate(results):
                wavelet_transformed[:, :, ind] = wavelet_power

            # Trim interior edges only: the first chunk keeps its start, the last
            # chunk keeps its end, and a single chunk keeps both. Explicit end
            # indices avoid the empty `[x:-0]` slice when index_gap == 0.
            trim_front = 0 if c == 0 else index_gap
            trim_back = 0 if c == num_chunks - 1 else index_gap
            num_transformed = wavelet_transformed.shape[0]
            this_index_start = start // average_window + trim_front
            this_index_end = end // average_window - trim_back
            wavelets_ds[this_index_start:this_index_end, :, :] = \
                wavelet_transformed[trim_front:num_transformed - trim_back, :, :]
            hdf5_file.flush()
            print('This chunk time {}'.format(time.time() - t_chunk))

        # Put frequencies in; these match the frequency axis of the dataset shape.
        frequencies_ds[:] = wavelet_frequencies
        hdf5_file.flush()


def preprocess_chunk(raw_chunk, subtract_mean=True, convert_to_milivolt=False):
    """
    Preprocesses a chunk of data.

    Parameters
    ----------
    raw_chunk : array_like
        Chunk of raw_data to preprocess. When called from preprocess_input it
        is laid out as (timepoints, channels).
    subtract_mean : bool, optional
        At every timepoint, subtract the mean taken across all channels
        (including the channel itself), by default True. For a 1-D input this
        subtracts the mean over all samples.
    convert_to_milivolt : bool, optional
        Multiply the chunk by 0.195 / 1000, by default False

    Returns
    -------
    raw_chunk : array_like
        preprocessed_chunk
    """
    processed = raw_chunk
    if subtract_mean:
        # Transposed, the layout is (channels, timepoints): the axis-0 mean is
        # the per-timepoint average over channels, broadcast over every channel.
        channels_first = processed.T
        processed = (channels_first - np.mean(channels_first, axis=0)).T
    if convert_to_milivolt:
        # NOTE(review): presumably 0.195 uV per raw unit converted to mV; confirm for the amplifier used.
        processed = processed * (0.195 / 1000)
    return processed


In [7]:
# Output HDF5 file for the wavelet-transformed activity.
# NOTE(review): absolute, machine-specific path with no file extension; confirm it names a file, not a directory.
path = '/media/maria/DATA1/Documents/DeepInsightAn'
# S is (neurons, timepoints); preprocess_input expects (timepoints, channels), hence S.T.
# NOTE(review): with the default gap_size=50000 and the 30560 timepoints printed above,
# num_chunks = 30560 // 50000 - 1 = -1, so no chunk is transformed.
# TODO: choose window_size / gap_size / average_window / sampling_rate suited to this recording.
preprocess_input(fp_hdf_out=path, raw_data=S.T)

NameError: name 'wt' is not defined