# Processing - Get features of flash response
Classify the light response into two groups, transient and sustained.  
Reference: https://www.ncbi.nlm.nih.gov/pubmed/12966177

In [2]:
import h5py
import numpy as np
import pandas as pd

from spikelib import spiketools as spkt
from spikelib.utils import check_directory, check_groups

%matplotlib inline

# Processing - Flash features


## General parameters

In [3]:
# Experiment identifier; every path below is derived from it.
exp_name = 'MR-0092'
# Sorted spike times (input, opened read-only further down).
fspikes = '../data/sorting/{}t2.result-1.hdf5'.format(exp_name)
# Analysis results (output, opened in append mode further down).
foutput = (
    '../data/processed_protocols/'
    '{}t2_modified_analysis_of_protocols_150um_merge.hdf5'.format(exp_name)
)
# Event list produced by the synchronization step.
fprotocol_times = '../data/sync/{0}t2/event_list_{0}t2.csv'.format(exp_name)



In [4]:
# Selection criteria for the flash protocol inside the sync file
protocol_startwith = 'flash'  # prefix of the protocol names to keep
filter_names = ['nd2', 'nd3', 'nd4', 'nd5']
intensities = [50, 100, 150, 200, 255]
nframes = 24  # only events whose 'n_frames' equals this value are kept

# Load the event list and keep the rows that satisfy both criteria
protocol_times = pd.read_csv(fprotocol_times)
filter_protocols = protocol_times['protocol_name'].str.startswith(
    protocol_startwith
)
filter_frames = protocol_times['n_frames'].eq(nframes)
filter_flash = filter_protocols & filter_frames
flash_time = protocol_times[filter_flash]

# Temporal resolution, all values in seconds
psth_bin = 0.01  # width used for the PSTH
bandwidth_fit = psth_bin  # bandwidth used for the estimated response
fit_resolution = 0.001  # sampling step of the estimated response

In [24]:
def get_flash_response(spks, panalysis, prefix, intensity,
                       event_list, psth_bin, bandwidth_fit,
                       fit_resolution, sr=20000.0, offset_time=0.0):
    """Compute the PSTH and the estimated firing rate of every unit.

    The trials of the protocol ``flash_<prefix>_<intensity>`` are taken from
    `event_list`. For each dataset in ``spks['/spiketimes/']`` the spikes are
    split into trials and two responses are written to `panalysis` under
    ``/flash/<prefix>_<intensity:03d>/``:

    * ``psth/<unit>``: spike histogram divided by the number of trials.
    * ``est_psth/<unit>``: response returned by ``spkt.est_pdf``.

    Datasets that already exist are overwritten in place. The time axes and
    the stimulus bounds are stored as attributes of the groups.

    Parameters
    ----------
    spks: h5py.Group
        Source of spike times; must contain the group ``/spiketimes/``.
    panalysis: h5py.Group
        Destination of the results.
    prefix: str
        Middle part of the protocol name (e.g. the filter name).
    intensity: int
        Last part of the protocol name.
    event_list: pandas.DataFrame
        Events with the columns ``protocol_name``, ``start_event``,
        ``end_event`` and ``start_next_event``.
    psth_bin: float
        Bin size used to build the PSTH, in seconds.
    bandwidth_fit: float
        Bandwidth forwarded to ``spkt.est_pdf``.
    fit_resolution: float
        Time step of the estimated response, in seconds.
    sr: float, default 20000.0
        Divisor that converts event and spike times to seconds.
    offset_time: float, default 0.0
        Time assigned to the stimulus onset inside each trial.

    Raises
    ------
    ValueError
        If `event_list` has no event named ``flash_<prefix>_<intensity>``.

    """
    flash_name = 'flash_{}_{:d}'.format(prefix, intensity)
    event_fields = ['start_event', 'end_event', 'start_next_event']
    is_flash = event_list['protocol_name'] == flash_name
    bound_time = np.asarray(event_list[is_flash][event_fields],
                            dtype=float) / sr
    if bound_time.shape[0] == 0:
        # Without trials the durations below become NaN and the failure
        # would surface later as an unrelated conversion error.
        raise ValueError(
            "No events named '{}' found in event_list".format(flash_name))

    # Stimulus timing: median over trials of (end - start) and
    # (start_next - end), in seconds.
    (on_dur, off_dur) = np.median(np.diff(bound_time, axis=1), axis=0)
    start_on = offset_time
    end_on = offset_time + on_dur
    start_off = offset_time + on_dur
    end_off = offset_time + off_dur + on_dur
    total_dur = off_dur + on_dur
    bounds = (start_on, end_on, start_off, end_off)

    # A trial spans from the start of one event to the start of the next.
    (start_trials, end_trials) = bound_time[:, [0, 2]].T
    ntrials = len(start_trials)

    # NOTE(review): linspace receives the number of points, so the resulting
    # N-1 intervals are slightly wider than the requested resolution. Kept
    # as is so that array lengths match datasets already stored on disk.
    bins_fit = np.linspace(start_on, end_off,
                           int(np.ceil(total_dur / fit_resolution)))
    bins_psth = np.linspace(start_on, end_off,
                            int(np.ceil(total_dur / psth_bin)))

    # Name of group in HDF5 file
    intensityg = '/flash/{}_{:03d}/'.format(prefix, intensity)
    est_respg = '/flash/{}_{:03d}/est_psth/'.format(prefix, intensity)
    psth_respg = '/flash/{}_{:03d}/psth/'.format(prefix, intensity)

    check_groups(panalysis, [est_respg, psth_respg])
    est_group = panalysis[est_respg]
    psth_group = panalysis[psth_respg]
    spike_group = spks['/spiketimes/']

    for key in spike_group:
        spikes = spike_group[key][...] / sr
        trials_flash = spkt.get_trials(spikes, start_trials, end_trials,
                                       offset=offset_time)
        spks_flash = spkt.flatten_trials(trials_flash)

        # Response
        (psth, _) = np.histogram(spks_flash, bins=bins_psth)
        psth = psth / float(ntrials)
        est_resp = spkt.est_pdf(trials_flash, bins_fit,
                                bandwidth=bandwidth_fit,
                                norm_factor=psth.max())

        # Overwrite the dataset when it exists, create it otherwise.
        for group, data in ((est_group, est_resp), (psth_group, psth)):
            if key in group:
                group[key][...] = data
            else:
                # Builtin float: the np.float alias no longer exists in
                # recent NumPy releases.
                group.create_dataset(key, data=data, dtype=float,
                                     compression='gzip')

    panalysis[intensityg].attrs['bounds'] = bounds
    psth_group.attrs['bins'] = bins_psth
    est_group.attrs['time'] = bins_fit
    est_group.attrs['bounds'] = bounds
    est_group.attrs['bounds_name'] = u'start_on,end_on,start_off,end_off'
    
    
def get_flash_features(panalysis, prefix, intensity, kwargs_fit, kind='estimated'):
    """Compute a set of features from the estimated PSTH of every unit.

    Each dataset in ``/flash/<prefix>_<intensity:03d>/est_psth/`` is passed
    to ``spkt.get_features_flash`` together with the ``time`` and ``bounds``
    attributes of that group. The two values it returns are stored under
    ``type/<unit>`` and ``char/<unit>`` of the same flash group, overwriting
    datasets that already exist. Descriptive attributes (column names, type
    names and the keyword arguments used) are written once to the ``char``
    and ``type`` groups, even when there are no units to process.

    Parameters
    ----------
    panalysis: h5py.Group
        Group that holds the estimated responses and receives the results.
    prefix: str
        First part of the flash group name (e.g. the filter name).
    intensity: int
        Second part of the flash group name, zero padded to three digits.
    kwargs_fit: dict
        Keyword arguments forwarded to ``spkt.get_features_flash``.
    kind: str
        Source to get flash features. Currently unused: the estimated
        response is always the source.

    """
    est_respg = '/flash/{}_{:03d}/est_psth/'.format(prefix, intensity)
    psth_respg = '/flash/{}_{:03d}/psth/'.format(prefix, intensity)
    type_respg = '/flash/{}_{:03d}/type/'.format(prefix, intensity)
    char_respg = '/flash/{}_{:03d}/char/'.format(prefix, intensity)

    check_groups(panalysis, [est_respg, psth_respg, type_respg, char_respg])
    est_group = panalysis[est_respg]
    type_group = panalysis[type_respg]
    char_group = panalysis[char_respg]

    for key in est_group:
        # TODO: add a if clause to select source from estimated or psth
        est_resp = est_group[key]
        # Read inside the loop so that a group without units (and therefore
        # possibly without these attributes) is skipped silently.
        est_time = est_group.attrs['time']
        bounds = est_group.attrs['bounds']
        type_fit, char_fit = spkt.get_features_flash(
            est_resp, est_time, bounds, **kwargs_fit)

        # Builtin float/int: the np.float and np.int aliases no longer
        # exist in recent NumPy releases.
        if key in char_group:
            char_group[key][...] = char_fit
        else:
            char_group.create_dataset(key, data=char_fit, dtype=float,
                                      compression='gzip')

        if key in type_group:
            type_group[key][...] = type_fit
        else:
            type_group.create_dataset(key, data=type_fit, dtype=int)

    # Adjacent literals instead of backslash continuations, which would
    # embed the indentation spaces in the stored attribute.
    col_name = (u'latency_on,latency_off,bias_idx,decay_on,decay_off,'
                u'resp_index_on,resp_index_off,sust_index_on,sust_index_off,'
                u'peakresp_on,peakresp_off')
    char_group.attrs['col_name'] = col_name
    type_group.attrs['type_name'] = u'null:0,on:1,off:2,onoff:3'
    char_group.attrs['kwargs'] = str(kwargs_fit)

## Get PSTH and estimated firing rate

In [25]:
# Open both files once for the whole sweep instead of once per
# (prefix, intensity) pair; the context manager still closes them on error.
with h5py.File(fspikes, 'r') as spks, \
     h5py.File(foutput, 'a') as panalysis:
    for prefix in filter_names:
        for intensity in intensities:
            get_flash_response(
                spks=spks,
                panalysis=panalysis,
                prefix=prefix,
                intensity=intensity,
                event_list=flash_time,
                psth_bin=psth_bin,
                bandwidth_fit=bandwidth_fit,
                fit_resolution=fit_resolution,
                sr=20000.0
            )
    

## Compute flash features

In [28]:
# Keywords for spkt.get_features_flash()
fpeak_min_time = 0.01  # min time between peaks in sec
kwargs_fit = {
    'resp_thr': 1.0/3,  # Threshold to select valid unit based on number of trials
    'bias_thr': 0.65,  # Threshold to classify into on, off, onoff
    'ri_thr': 0.3,  # Threshold for Response Index (RI)
    'ri_span': 0.1,  # Span for Response Index (RI)
    'fpeak_thr': 0.5,  # Threshold of max response to find the first peak of response
    # Minimum peak distance in samples. round() guards against float
    # quotients that land just below an integer and would be truncated.
    'fpeak_min_dist': int(round(fpeak_min_time/fit_resolution)),
    'sust_time': 0.4,  # Time window to compute the Sustained index, in sec
    'decrease_factor': np.e,
}


# Open the output file once for the whole sweep instead of once per
# (prefix, intensity) pair.
with h5py.File(foutput, 'a') as panalysis:
    for prefix in filter_names:
        for intensity in intensities:
            get_flash_features(
                panalysis=panalysis,
                prefix=prefix,
                intensity=intensity,
                kwargs_fit=kwargs_fit
            )