In [1]:
from configparser import ConfigParser, ExtendedInterpolation

import h5py
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy.optimize as opt
from spikelib.fitting import gaussian2d
from spikelib.utils import check_groups

%matplotlib notebook
config = ConfigParser(interpolation=ExtendedInterpolation())
config.read('../../config.ini')

['../../config.ini']

In [2]:
def fit_sta_2d(data_raw, maxfev=None):
    """Fit a 2D Gaussian to the peak frame of a spike-triggered average.

    The frame containing the largest absolute value is selected and fitted
    with ``gaussian2d``. Before fitting, that frame is multiplied by the sign
    of its peak, so negative (OFF-type) peaks are fitted as positive bumps.

    Parameters
    ----------
    data_raw : ndarray
        3-D STA, unpacked as (frames, y, x).
    maxfev : int, optional
        Maximum number of function evaluations forwarded to
        ``scipy.optimize.curve_fit``. ``None`` keeps SciPy's default.

    Returns
    -------
    popt : ndarray
        Fitted parameters in ``gaussian2d`` order:
        (amp, x0, y0, sigma_x, sigma_y, theta, offset). ``amp`` refers to the
        sign-corrected frame.
    pcov : ndarray
        Covariance estimate of ``popt`` returned by ``curve_fit``.
    frame : int
        Index of the frame that was fitted.

    Raises
    ------
    RuntimeError
        If ``curve_fit`` does not converge within ``maxfev`` evaluations.
    """
    (_, y_shape, x_shape) = data_raw.shape
    (frame, y0, x0) = np.unravel_index(np.abs(data_raw).argmax(), data_raw.shape)
    z0 = data_raw[frame, y0, x0]
    # Flip the frame so its peak is positive. The amplitude guess must be
    # flipped the same way, otherwise OFF peaks start from a negative amplitude
    # while the fitted data are positive.
    data = np.sign(z0) * data_raw[frame]

    x = np.linspace(0, x_shape - 1, x_shape)
    y = np.linspace(0, y_shape - 1, y_shape)
    x, y = np.meshgrid(x, y)
    # Parameter order of gaussian2d: (amp, x0, y0, sigma_x, sigma_y, theta, offset)
    initial_guess = (np.abs(z0), x0, y0, 1, 1, 1.5, 0)
    fit_kwargs = {} if maxfev is None else {'maxfev': maxfev}
    popt, pcov = opt.curve_fit(gaussian2d, (x.ravel(), y.ravel()), data.ravel(),
                               p0=initial_guess, **fit_kwargs)

    return popt, pcov, frame


def truncate_center(number, constrains):
    """Clamp a coordinate into the index range ``[min_value, max_value)``.

    Used to turn a fitted (possibly out-of-frame) centre into a value that can
    be floored and used as an array index.

    Parameters
    ----------
    number : float
        Value to clamp.
    constrains : sequence of two numbers
        ``(min_value, max_value)``. ``max_value`` is exclusive, e.g. the
        length of the axis being indexed.

    Returns
    -------
    float
        ``min_value`` if ``number < min_value``; ``max_value - 1`` if
        ``number >= max_value``; otherwise ``number`` unchanged.
    """
    min_value, max_value = constrains
    if number < min_value:
        return min_value
    # max_value is exclusive: a value equal to it would floor to an
    # out-of-range index, so it must be clamped as well.
    if number >= max_value:
        return max_value - 1
    return number

def _write_dataset(group, name, values):
    """Overwrite dataset ``name`` in ``group`` in place, creating it if missing."""
    if name in group:
        group[name][...] = values
    else:
        group.create_dataset(name, data=values, dtype=float)


def sta_fitting_2d(file_name, intensity):
    """Fit the spatial receptive field of every unit's STA for one intensity.

    Every dataset under ``/sta/{intensity}/raw/`` is a 3-D STA, unpacked as
    (frames, y, x). Each one is fitted with ``fit_sta_2d``, and two results are
    written back to the same HDF5 file:

    * ``/sta/{intensity}/temporal/raw/<unit>``: the STA time course at the
      fitted centre pixel, after clamping with ``truncate_center``.
    * ``/sta/{intensity}/spatial/char/<unit>``:
      ``(angle, a, b, x, y, snr, frame)``. Here ``angle`` is the fitted theta
      in degrees, ``a``/``b`` are sigma_x/sigma_y, and ``snr`` is the variance
      of the centre time course divided by the mean per-pixel variance.

    Units whose fit raises RuntimeError are reported with a printed message.
    They are stored with all-zero fit parameters (frame 0, centre (0, 0)).
    The attributes time, nsamples_before, nsamples_after, fps and nsamples
    of the raw group are copied onto the temporal group.

    Parameters
    ----------
    file_name : str
        Path of the HDF5 file; it is opened in ``'r+'`` mode.
    intensity : str
        Name of the intensity group under ``/sta/``.
    """
    rawsta_group = f'/sta/{intensity}/raw/'
    temp_raw_group = f'/sta/{intensity}/temporal/raw/'
    spacial_group = f'/sta/{intensity}/spatial/char/'

    with h5py.File(file_name, 'r+') as h5file:
        unit_names = list(h5file[rawsta_group].keys())
        # NOTE(review): presumably creates the output groups when missing.
        check_groups(h5file, [temp_raw_group, spacial_group])

        for kunit in unit_names:
            raw_sta = h5file[rawsta_group + kunit][:]
            # Shape is read per unit, so an empty raw group no longer fails.
            (_, y_shape, x_shape) = raw_sta.shape
            try:
                popt, _, frame = fit_sta_2d(raw_sta)
            except RuntimeError:
                # Best effort: keep processing the remaining units.
                popt = (0,) * 7
                frame = 0
                print("Couldn't fit {}".format(kunit))
            (_, x0, y0, sigma_x, sigma_y, theta, _) = popt
            # Wrap theta relative to 0 before converting to degrees.
            theta = np.rad2deg(np.unwrap(np.array([0, theta]))[1])

            y_0 = int(np.floor(truncate_center(y0, [0, y_shape])))
            x_0 = int(np.floor(truncate_center(x0, [0, x_shape])))
            raw_temp = raw_sta[:, y_0, x_0]
            snr = raw_temp.var() / raw_sta.var(axis=0).mean()
            spatial_params = (theta, sigma_x, sigma_y, x0, y0, snr, frame)

            _write_dataset(h5file[temp_raw_group], kunit, raw_temp)
            _write_dataset(h5file[spacial_group], kunit, spatial_params)

        for attr in ('time', 'nsamples_before', 'nsamples_after', 'fps', 'nsamples'):
            h5file[temp_raw_group].attrs[attr] = h5file[rawsta_group].attrs[attr]
        h5file[spacial_group].attrs['col_name'] = 'angle,a,b,x,y,snr,frame'


In [3]:
# Parameters
events_file = config['SYNC']['events']
processed_file = config['FILES']['processed']
protocol_name = config['CHECKERBOARD']['protocol_name']

df = pd.read_csv(events_file)
checkerboard_times = df[df['protocol_name'] == protocol_name]
# A protocol-name mismatch would otherwise make the fitting loop silently do nothing.
assert not checkerboard_times.empty, \
    'No events for protocol {!r} in {}'.format(protocol_name, events_file)

In [4]:
# Fit every checkerboard event; intensity keys are '<nd>-<intensity>', e.g. 'nd3-255'.
nd_values = checkerboard_times['nd']
intensity_values = checkerboard_times['intensity']
for nd_value, intensity_value in zip(nd_values, intensity_values):
    intensity = '{}-{}'.format(nd_value, int(intensity_value))
    print('Computing fitting to {}'.format(intensity))
    sta_fitting_2d(processed_file, intensity)

Computing fitting to nd3-255


RuntimeError: Optimal parameters not found: Number of calls to function has reached maxfev = 1600.