# Parse a SpikeGLX NIDQ file to extract trial (TTL event) information

In [None]:
import spikeextractors as se
import spikewidgets as sw
import matplotlib.pyplot as plt
import numpy as np
%matplotlib notebook

In [None]:
# Location of the SpikeGLX NIDQ binary to analyse (machine-specific path; edit as needed).
nidq_file = (
    "/Users/abuccino/Documents/Data/catalyst/"
    "A256_2020_10_07_g0_t0.nidq.bin"
)

In [None]:
# Open the NIDQ recording at notebook level; the plotting cells further down
# reuse both this extractor and its sampling frequency.
rec_nidq = se.SpikeGLXRecordingExtractor(nidq_file)
fs = rec_nidq.get_sampling_frequency()

In [None]:
def _trace_to_digital_word(traces, lsb=None):
    """Convert a (possibly scaled) trace into integer digital-word values.

    Returns None when ``lsb`` must be inferred but the trace is all zeros.
    """
    traces = np.asarray(traces)
    if lsb is None:
        nonzero = np.abs(traces[traces != 0])
        if nonzero.size == 0:
            return None
        # smallest non-zero magnitude; unlike np.unique(...)[1] this does not
        # assume the trace minimum is exactly 0
        lsb = nonzero.min()
    # round instead of truncating: scaled floats such as 0.7 / 0.1 evaluate to
    # 6.999..., which astype(int) alone would turn into 6
    return np.rint(traces / lsb).astype(int)


def _bit_transitions(high):
    """Return sorted transition frames and states (+1 rising, -1 falling) of a 0/1 array."""
    edges = np.diff(high)
    # np.diff index i compares frames i and i+1; +1 reports the first frame
    # at the new level rather than the last frame at the old one
    rising = np.flatnonzero(edges > 0) + 1
    falling = np.flatnonzero(edges < 0) + 1
    frames = np.concatenate((rising, falling))
    states = np.concatenate((np.ones(len(rising), dtype=int),
                             -np.ones(len(falling), dtype=int)))
    order = np.argsort(frames, kind="stable")
    return frames[order], states[order]


def parse_event_info(nidq_file, n_bits=8, lsb=None):
    """
    Parse nidq trace to extract event information

    The first channel returned by ``get_traces`` is converted to an integer
    digital word, and each of its lowest ``n_bits`` bits is scanned for
    low->high and high->low transitions.

    Parameters
    ----------

    nidq_file: Path
        Path to nidq file .bin
    n_bits: int
        Number of bits in digital word (default 8)
    lsb: float or None
        Value of the least significant bit, in the units returned by
        ``get_traces``. If None (default) it is inferred as the smallest
        non-zero absolute value in the trace. NOTE: the inference is only
        correct if bit 0 is set at least once in the recording; otherwise
        bits are mis-numbered, so pass ``lsb`` explicitly when it is known.

    Returns
    -------
    events: dict
        Dictionary with channel (bit) id as key and a dictionary with
        'frames' and 'states' as values. 'frames' holds the sorted frame
        indices of every transition (the first frame at the new level) and
        'states' holds +1 for rising and -1 for falling edges. Only bits that
        are high at least once appear. An all-zero trace yields an empty dict.
    """
    rec_nidq = se.SpikeGLXRecordingExtractor(nidq_file)
    traces = rec_nidq.get_traces()[0]  # only the first channel is parsed

    word = _trace_to_digital_word(traces, lsb)
    if word is None:
        # trace is identically zero: no bit is ever high, hence no events
        return {}

    events = {}
    for chan in range(n_bits):
        high = (np.bitwise_and(word, 1 << chan) != 0).astype(int)
        if not high.any():
            continue
        print(f"Found events for channel {chan}")
        frames, states = _bit_transitions(high)
        events[chan] = {'frames': frames, 'states': states}

    return events

In [None]:
# Decode the 8-bit digital word of the NIDQ file into per-bit transition events.
events = parse_event_info(nidq_file, n_bits=8)

## Display events

In [None]:
# Report how many transitions (rising and falling together) each bit produced.
for chan_id in events:
    n_transitions = len(events[chan_id]['frames'])
    print(f"Channel {chan_id}: {n_transitions} events")

In [None]:
# Number of frames shown in the plot below: the first `plot_duration_s`
# seconds of the recording (assuming fs is in samples per second).
# fs is a float, so cast: end_frame is used as an np.arange length and as a
# get_traces slice bound, both of which need an integer.
plot_duration_s = 40
end_frame = int(fs * plot_duration_s)

In [None]:
plt.figure()
# Integer frame count so np.arange and the get_traces slice agree in length.
n_frames = int(end_frame)
# One timestamp per plotted sample: get_traces(end_frame=n) returns frames
# [0, n), so the time axis must also have n entries (the original used n-1,
# which made plt.plot fail and let timestamps[r] go out of range).
timestamps = rec_nidq.frame_to_time(np.arange(n_frames))
plt.plot(timestamps, rec_nidq.get_traces(end_frame=n_frames)[0])

# Mark every rising edge inside the plotted window with a dashed vertical
# line, using one colour per event channel.
for event_channel, ttls in events.items():
    rising = ttls["frames"][ttls["states"] == 1]
    for r in rising[rising < n_frames]:
        plt.axvline(timestamps[r], color=f"C{event_channel}", ls="--")