# Parse the digital channel of a SpikeGLX NIDQ or AP file to extract trial event information

In [None]:
import spikeextractors as se
import spikewidgets as sw
import matplotlib.pyplot as plt
import numpy as np
%matplotlib notebook

In [None]:
# Input recordings. `ap_file` is only a placeholder ("...") and must be set to a
# real .ap.bin path before the AP parsing cell is run.
nidq_file, ap_file = (
    "/Users/abuccino/Documents/Data/catalyst/brody/A256_2020_10_07_g0_t0.nidq.bin",
    "...",
)

In [None]:
# Open the NIDQ recording at module level; later cells reuse `rec_nidq` (for
# plotting the trace) and its sampling frequency `fs` (to size the plot window).
rec_nidq = se.SpikeGLXRecordingExtractor(nidq_file)
fs = rec_nidq.get_sampling_frequency()

In [None]:
def get_events_for_trace(trace, n_bits, lsb):
    """
    Decode a digital-word trace into per-bit TTL transition events.

    Parameters
    ----------

    trace: array-like
        1D trace whose samples are (approximately) integer multiples of `lsb`
    n_bits: int
        Number of bits of the digital word to decode (bits 0 .. n_bits-1)
    lsb: float
        Trace value corresponding to one least-significant-bit step

    Returns
    -------
    events: dict
        Maps bit index to a dict with 'frames' and 'states' (int arrays).
        Only bits that are high at least once are present. 'frames' holds the
        sample indices i where the bit changes between sample i and i + 1,
        sorted ascending; 'states' is +1 for a rising and -1 for a falling edge.
    """
    # Round instead of truncating: scaled float traces can land a hair below
    # the true integer (e.g. 2.9999999), and astype(int) would then drop bits.
    trace_int = np.rint(np.asarray(trace) / lsb).astype(np.int64)

    events = {}
    for chan in range(n_bits):
        bit_value = 1 << chan
        high = (np.bitwise_and(trace_int, bit_value) != 0).astype(np.int8)
        if not high.any():
            continue

        # np.diff index i describes the change between samples i and i + 1.
        edges = np.diff(high)
        rising = np.flatnonzero(edges > 0)
        falling = np.flatnonzero(edges < 0)

        frames = np.concatenate((rising, falling))
        states = np.concatenate(
            (np.ones(len(rising), dtype=int), -np.ones(len(falling), dtype=int))
        )
        # Rising and falling edges of one bit never share an index, so the
        # ordering is unambiguous.
        order = np.argsort(frames, kind="stable")
        events[chan] = {"frames": frames[order], "states": states[order]}

    return events

In [None]:
def parse_event_info_nidq(nidq_file, n_bits=8):
    """
    Parse nidq trace to extract event information

    Parameters
    ----------

    nidq_file: Path
        Path to nidq.bin file
    n_bits: int
        Number of bits in digital word (default 8)

    Returns
    -------
    events: dict
        Dictionary with channel id as key and a dictionary with
        'frames' and 'states' as values. Empty if the trace never leaves 0.
    """
    rec_nidq = se.SpikeGLXRecordingExtractor(nidq_file)
    # NOTE(review): assumes the digital word is the first (only) NIDQ trace.
    digital_trace = rec_nidq.get_traces()[0]

    # One LSB step = smallest non-zero magnitude present in the trace. abs()
    # keeps this valid when the top bit makes stored samples negative.
    # Caveat: if bit 0 never goes high, this overestimates the LSB.
    levels = np.unique(digital_trace)
    nonzero = np.abs(levels[levels != 0])
    if nonzero.size == 0:
        events = {}
    else:
        events = get_events_for_trace(digital_trace, n_bits, nonzero.min())
    print(f"Found events for channels: {list(events.keys())}")

    return events

In [None]:
def parse_event_info_ap(ap_file, n_bits=8, chunk_size=100000):
    """
    Parse ap digital trace (385th trace) to extract event information

    The trace is processed in chunks to bound memory. Consecutive chunks
    overlap by one frame so that transitions exactly on a chunk boundary are
    still detected.

    Parameters
    ----------

    ap_file: Path
        Path to ap.bin file
    n_bits: int
        Number of bits in digital word (default 8)
    chunk_size: int
        Chunk size in number of frames (must be positive)

    Returns
    -------
    events: dict
        Dictionary with channel id as key and a dictionary with
        'frames' and 'states' (int arrays) as values. Empty if the trace
        never leaves 0.

    Raises
    ------
    ValueError
        If `chunk_size` is not positive.
    """
    from tqdm import tqdm

    chunk_size = int(chunk_size)
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    rec_ap = se.SpikeGLXRecordingExtractor(ap_file)
    # NOTE(review): relies on the private `_raw` buffer and assumes channels are
    # on axis 0 with the digital word as the last row — confirm.
    digital_trace = rec_ap._raw[-1]
    num_frames = len(digital_trace)
    chunk_starts = range(0, num_frames, chunk_size)

    # Pass 1: gather the distinct digital levels without loading the whole trace.
    levels = np.array([])
    for start in tqdm(chunk_starts, ascii=True, desc="Extracting digital levels"):
        levels = np.union1d(levels, digital_trace[start:start + chunk_size])

    # One LSB step = smallest non-zero magnitude (see parse_event_info_nidq).
    nonzero = np.abs(levels[levels != 0])
    if nonzero.size == 0:
        return {}
    lsb = nonzero.min()

    # Pass 2: decode chunk by chunk. Every chunk after the first starts one
    # frame early, so np.diff also sees the step from the previous chunk's
    # last frame into this chunk; no transition is counted twice.
    frames_parts = {}
    states_parts = {}
    for start in tqdm(chunk_starts, ascii=True, desc="Decoding digital input"):
        offset = max(start - 1, 0)
        trace_chunk = digital_trace[offset:start + chunk_size]
        for chan, chunk_events in get_events_for_trace(trace_chunk, n_bits, lsb).items():
            frames_parts.setdefault(chan, []).append(chunk_events["frames"] + offset)
            states_parts.setdefault(chan, []).append(chunk_events["states"])

    # Concatenate once per channel (instead of per chunk) and keep integer
    # dtypes so the frames can be used directly as indices.
    return {
        chan: {
            "frames": np.concatenate(frames_parts[chan]).astype(np.int64),
            "states": np.concatenate(states_parts[chan]).astype(np.int64),
        }
        for chan in frames_parts
    }

In [None]:
# Decode the AP digital word in 1M-frame chunks (needs a real `ap_file` path).
events_ap = parse_event_info_ap(ap_file, chunk_size=1_000_000)

In [None]:
# Decode the 8-bit NIDQ digital word (8 is the function's default width).
events_nidq = parse_event_info_nidq(nidq_file, n_bits=8)

## Display events

In [None]:
# Summarise how many transitions were decoded per NIDQ digital channel.
for channel, info in events_nidq.items():
    n_events = len(info["frames"])
    print(f"Channel {channel}: {n_events} events")

In [None]:
# Summarise how many transitions were decoded per AP digital channel.
for channel, info in events_ap.items():
    n_events = len(info["frames"])
    print(f"Channel {channel}: {n_events} events")

In [None]:
# Number of frames to plot: the first `plot_duration_s` seconds. The sampling
# frequency may be a float, so convert to an int usable as a frame bound.
plot_duration_s = 40
end_frame = int(fs * plot_duration_s)

In [None]:
# Plot the start of the NIDQ digital trace and mark every rising edge with a
# dashed vertical line, coloured by digital channel.
n_plot = int(end_frame)
trace = rec_nidq.get_traces(end_frame=n_plot)[0]
# Derive timestamps from the samples actually returned so x and y match.
timestamps = rec_nidq.frame_to_time(np.arange(len(trace)))

plt.figure()
plt.plot(timestamps, trace)

for event_channel, ttls in events_nidq.items():
    ttl = np.asarray(ttls["frames"], dtype=int)
    states = np.asarray(ttls["states"])
    # Keep rising edges that fall inside the plotted window.
    rising = ttl[(states == 1) & (ttl < len(timestamps))]
    for r in rising:
        plt.axvline(timestamps[r], color=f"C{event_channel}", ls="--")