In [1]:
from configparser import ConfigParser, ExtendedInterpolation

import h5py
from ipywidgets import IntSlider, interact, Dropdown
import numpy as np
import matplotlib.pylab as plt
from matplotlib.gridspec import GridSpec

from spikelib.fitting import fit_temp_sta
from spikelib.utils import check_groups, datasets_to_array

%matplotlib inline

# Load the shared project settings (file locations used below) from config.ini.
config_path = '../../config.ini'
config = ConfigParser(interpolation=ExtendedInterpolation())
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed; fail here instead of with a KeyError later.
config_files_read = config.read(config_path)
if not config_files_read:
    raise FileNotFoundError(f'could not read config file {config_path!r}')
config_files_read

['../../config.ini']

# General information about input/output files

In [2]:
# Path to the processed HDF5 file, read from the [FILES] section of the config.
processed_file = config['FILES']['processed']

# The member names of the /sta group are used as the intensity labels
# (e.g. 'nd2-255', 'nd3-255', 'nd4-255', 'nd5-255').
with h5py.File(processed_file, 'r') as h5_in:
    intensities = [name for name in h5_in['/sta']]
print(f'intensities: {intensities}')

intensities: ['nd3-255']


In [3]:
# Select the intensity to analyse: the first one found under /sta.
if not intensities:
    raise ValueError(f'no intensities found under /sta in {processed_file}')
intensity = intensities[0]

# Groups for raw data in hdf5 file
temp_group = f'/sta/{intensity}/temporal/raw/'
spatial_group = f'/sta/{intensity}/spatial/char/'
sta_group = f'/sta/{intensity}/raw/'
# Named `valid_group` (not `valid`) so it is not overwritten by the list of
# valid unit indices that is built later in this notebook.
valid_group = f'/sta/{intensity}/valid/'

# Groups for fit data in hdf5 file
temp_fit_group = f'/sta/{intensity}/temporal/fit/'
temp_pars_group = f'/sta/{intensity}/temporal/fit_params/'

# Load Data

In [21]:
# Load, for every unit of the selected intensity: the raw temporal STA, its SNR
# (from the spatial characterisation), the fitted temporal STA and its fit
# parameters, and a short window of spatial STA frames around the peak frame.
N_FRAMES_SHOWN = 4  # consecutive spatial STA frames kept per unit for plotting

with h5py.File(processed_file, 'r') as stafile:
    temp_raw_grp = stafile[temp_group]
    temp_fit_grp = stafile[temp_fit_group]
    temp_pars_grp = stafile[temp_pars_group]
    spatial_grp = stafile[spatial_group]
    sta_grp = stafile[sta_group]

    time_raw = temp_raw_grp.attrs['time']
    time_fit = temp_fit_grp.attrs['time']
    unit_ids = list(temp_raw_grp)
    # All STAs are assumed to share the shape of the first one.
    n_frame, y_size, x_size = sta_grp[next(iter(sta_grp))].shape
    nraw_samples = time_raw.size
    nfit_samples = time_fit.size

    # Size the id field from the actual keys (at least 18 chars, as before) so
    # long unit names are not silently truncated; a truncated id would point
    # the HDF5 write-back at the wrong dataset.
    id_dtype = 'U{}'.format(max([18] + [len(k) for k in unit_ids]))
    raw_dtype = np.dtype([
        ('id', id_dtype),
        ('temp_raw', 'f8', (nraw_samples,)),
        ('snr', 'f8'),
    ])
    fit_dtype = np.dtype([
        ('id', id_dtype),
        ('parameters', 'f8', (5,)),
        ('temp_fit', 'f8', (nfit_samples,)),
    ])
    n_units = len(unit_ids)
    temporal_sta = np.zeros(n_units, dtype=raw_dtype)
    temp_fit = np.zeros(n_units, dtype=fit_dtype)
    sta_frame_max = np.zeros((n_units, N_FRAMES_SHOWN, y_size, x_size))

    # Latest start index for which a full window still fits inside the STA
    # (0 if the STA is shorter than the window).
    last_start = max(n_frame - N_FRAMES_SHOWN, 0)
    for kidx, kkey in enumerate(unit_ids):
        # Last two entries of the spatial characterisation are read as
        # (..., snr, peak frame).
        spatial_char = spatial_grp[kkey][...]
        temporal_sta[kidx]['id'] = kkey
        temporal_sta[kidx]['temp_raw'] = temp_raw_grp[kkey][...]
        temporal_sta[kidx]['snr'] = spatial_char[-2]
        temp_fit[kidx]['id'] = kkey
        temp_fit[kidx]['parameters'] = temp_pars_grp[kkey][...]
        temp_fit[kidx]['temp_fit'] = temp_fit_grp[kkey][...]
        # Start the window one frame before the peak, clamped to [0, last_start].
        frame_max = int(min(max(spatial_char[-1] - 1, 0), last_start))
        window = sta_grp[kkey][frame_max:frame_max + N_FRAMES_SHOWN]
        # Short STAs yield fewer frames; the remaining slots stay zero.
        sta_frame_max[kidx, :window.shape[0]] = window

# Visualization of the temporal fits

In [5]:
# Split unit indices by SNR threshold. Units with a NaN SNR satisfy neither
# comparison and therefore appear in neither list.
thr_snr = 2
unit_snr = temporal_sta['snr']
valid_by_snr = np.flatnonzero(unit_snr >= thr_snr)
no_valid_by_snr = np.flatnonzero(unit_snr < thr_snr)

In [34]:
def plot_fit(kidx):
    """Plot one unit's raw and fitted temporal STA next to its spatial STA frames.

    Reads the module-level arrays ``temporal_sta``, ``temp_fit``,
    ``sta_frame_max``, ``time_raw`` and ``time_fit`` built in the "Load data" cell.

    Parameters
    ----------
    kidx : int
        Row index of the unit in those arrays.
    """
    unit = temporal_sta[kidx]
    frames = sta_frame_max[kidx]

    fig = plt.figure(constrained_layout=True, figsize=(10, 4))
    gs = GridSpec(2, 5, figure=fig)
    # Temporal STA spans the left 3 columns; a 2x2 grid of frames sits on the right.
    ax_temp = fig.add_subplot(gs[:, :3])
    ax_spatial = [fig.add_subplot(gs[row, col]) for row in (0, 1) for col in (3, 4)]

    # Raw samples as a black line with '+' markers, fit overlaid in red.
    ax_temp.plot(time_raw, unit['temp_raw'], 'k+-')
    ax_temp.plot(time_fit, temp_fit[kidx]['temp_fit'], 'r')
    ax_temp.set(ylim=[-1, 1],
                xlabel='{}  ({:2.2f})'.format(unit['id'], unit['snr']))
    # The `b=` keyword was deprecated in Matplotlib 3.5 and removed in 3.7;
    # passing visibility positionally works on old and new versions.
    ax_temp.grid(True, which='major', color='k', linestyle='-', alpha=0.2)
    ax_temp.grid(True, which='minor', color='k', linestyle='-', alpha=0.1)
    ax_temp.minorticks_on()

    # Symmetric colour limits so the diverging colormap is centred on 0.
    vmax = np.abs(frames).max()
    for idx, (ax, frame) in enumerate(zip(ax_spatial, frames)):
        ax.pcolor(frame, vmin=-vmax, vmax=vmax, cmap='RdBu_r')
        ax.set(title=f'relative {idx}')
    plt.show()
    plt.close(fig)

## Units rejected by the SNR filter

In [35]:
snr_rejected_picker = Dropdown(options=no_valid_by_snr)  # units below thr_snr
interact(plot_fit, kidx=snr_rejected_picker);

interactive(children=(Dropdown(description='kidx', options=(0, 3, 5, 7, 9, 10, 11, 16, 18, 19, 20, 21, 22, 23,…

## Units accepted by the SNR filter

In [36]:
snr_accepted_picker = Dropdown(options=valid_by_snr)  # units at or above thr_snr
interact(plot_fit, kidx=snr_accepted_picker);

interactive(children=(Dropdown(description='kidx', options=(1, 2, 4, 6, 8, 12, 13, 14, 15, 17, 24, 27, 32, 35,…

## All units

In [74]:
# Offer every unit index 0 .. len(temporal_sta) - 1 (the previous `- 1` in
# the range dropped the last unit).
interact(plot_fit, kidx=Dropdown(options=range(len(temporal_sta))));

interactive(children=(Dropdown(description='kidx', options=(2, 3, 4, 5, 6, 8, 9, 10, 11, 15, 16, 17, 18, 20, 2…

# Refitting a specific unit


In [95]:
# Get the row index of a unit from its name (edit `unit_name` as needed).
unit_name = 'temp_74'
unit_matches = np.flatnonzero(temp_fit['id'] == unit_name)
if unit_matches.size == 0:
    raise KeyError(f'unit {unit_name!r} not found in temp_fit')
int(unit_matches[0])

223

In [96]:
# Set kidx to a unit index (see the lookup cell above) to refit that unit with
# the initial guesses below; leave it as None to skip. The comparison with None
# is explicit so unit 0 can also be refitted.
kidx = None
if kidx is not None:
    temp_to_fit = temporal_sta[kidx]['temp_raw']

    params, tmp_fit = fit_temp_sta(temp_to_fit,
                                   time_raw,
                                   time_fit,
                                   tau1=-0.1,
                                   tau2=-0.04,
                                   amp1=0.01,
                                   amp2=0.1,
                                   min_time=None,
                                   max_time=None)

    temp_fit[kidx]['parameters'] = np.asarray(params)
    temp_fit[kidx]['temp_fit'] = tmp_fit
    # A bare expression nested in an `if` is not displayed by Jupyter.
    print(params)

In [97]:
# Set kidx_invalid to a unit index to discard its fit; leave it as None to skip.
# The comparison with None is explicit so unit 0 can also be discarded.
kidx_invalid = None
if kidx_invalid is not None:
    # Assign 0 rather than multiplying by 0, so NaN/inf from a failed fit are cleared too.
    temp_fit[kidx_invalid]['parameters'] = 0.
    temp_fit[kidx_invalid]['temp_fit'] = 0.

# Set fit parameters to zero for invalid units

In [37]:
# Manual overrides of the SNR split (unit indices).
# Low-SNR units that are left out of `no_valid`:
snr_low_exceptions = {0, 18, 26, 38, 72, 89, 106, 112, 117, 186, 224, 226, 235, 250}
# High-SNR units that are left out of `valid`:
snr_high_exceptions = {46, 97, 109, 111, 124, 153, 154, 167, 168, 170, 171,
                       184, 192, 201, 258, 261, 262, 264}

no_valid = [k for k in no_valid_by_snr if k not in snr_low_exceptions]
valid = [k for k in valid_by_snr if k not in snr_high_exceptions]

In [38]:
# Discard the fits of the low-SNR units in `no_valid`. Assign 0 instead of
# multiplying by 0 so NaN/inf values from failed fits are cleared as well.
no_valid_idx = np.asarray(no_valid, dtype=int)
temp_fit['parameters'][no_valid_idx] = 0.
temp_fit['temp_fit'][no_valid_idx] = 0.

In [39]:
# Discard the fits of every unit that is not kept. A unit is kept if it is in
# `valid` (high SNR, not manually excluded) or if it is a low-SNR unit that
# was deliberately left out of `no_valid` by the manual override list.
# NOTE(review): the previous version zeroed everything not in `valid`, which
# also discarded those low-SNR overrides and made that exception list a no-op.
# Confirm this is the intended selection.
kept_units = set(valid) | (set(no_valid_by_snr) - set(no_valid))
discarded = [k for k in range(temp_fit.shape[0]) if k not in kept_units]
# Assign 0 (not `*= 0.`) so NaN/inf from failed fits are cleared too.
temp_fit['parameters'][discarded] = 0.
temp_fit['temp_fit'][discarded] = 0.

## Save to HDF5

In [40]:
# Write the (possibly refitted / zeroed) fits back into the processed file,
# one dataset per unit id in each of the two fit groups.
with h5py.File(processed_file, 'r+') as fanalysis:
    check_groups(fanalysis, [temp_fit_group, temp_pars_group])
    for unit in temp_fit:
        unit_id = unit['id']
        fanalysis[temp_fit_group + unit_id][...] = unit['temp_fit']
        fanalysis[temp_pars_group + unit_id][...] = unit['parameters']