In [7]:
import numpy as np
from scipy.io import loadmat
from scipy.signal import filtfilt, butter

from pathlib import Path
from re import match
from datetime import datetime

import plotly.graph_objs as go
import ipywidgets as widgets
from IPython.display import display

from config import config
# config() is unpacked as a 3-tuple; only the first element (the data
# directory) is used here. NOTE(review): `/=` below assumes it is a
# pathlib.Path (or supports `/` with str) -- confirm against config.py.
# Binding `_` here also shadows IPython's last-output variable `_`.
data_dir, _, _ = config()
# The per-day .mat files searched by get_dates live in this subdirectory.
data_dir /= 'trial_viz'

In [12]:
def spike_train(spikes, offset, line_height=50):
    """Build one vertical grey tick per spike for a raster row.

    Parameters
    ----------
    spikes : array-like of float
        Spike times (x coordinates). The input is flattened first, so a
        column array such as an (n, 1) unit gives one tick per spike rather
        than ticks whose x values are 1-element arrays.
    offset : float
        y coordinate of the bottom of the row.
    line_height : float, optional
        Tick height. Default 50.

    Returns
    -------
    list of go.Scatter
        One line trace per spike (empty list for no spikes).
    """
    return [go.Scatter(x=[s, s], y=[offset, offset + line_height], mode='lines',
                       line=dict(color='grey', width=1))
            for s in np.ravel(spikes)]
    
def lfp_trace(lfp, fs=1000):
    """Return one trial's LFP as a single black line trace.

    Parameters
    ----------
    lfp : array-like of float
        Samples of one trial.
    fs : float, optional
        Sampling rate in Hz, used to convert sample index to seconds on the
        x axis. Default 1000 (the rate that used to be hard-coded).

    Returns
    -------
    list of go.Scatter
        A one-element list, so callers can ``extend`` it like the other
        trace builders.
    """
    times = np.arange(len(lfp)) / fs
    return [go.Scatter(x=times, y=lfp, mode='lines', line=dict(color='black', width=2))]

def trial_events(events, offset, line_height=100):
    """Return one vertical marker line per event that occurred in the trial.

    Parameters
    ----------
    events : dict
        Event name -> event time (x coordinate). Events stored as ``None``
        or NaN are treated as missing and skipped.
    offset : float
        y coordinate of the bottom of each marker.
    line_height : float, optional
        Marker height. Default 100.

    Returns
    -------
    list of go.Scatter
        One trace per present event, coloured by event name; names without
        an assigned colour are drawn grey instead of raising KeyError.
    """
    color_map = {'fixate': 'red',
                 'noise': 'green',
                 'shape': 'blue',
                 'saccade': 'purple'}
    traces = []
    for name, t in events.items():
        # Test for missing values explicitly: a plain truthiness test would
        # drop an event at t == 0 and keep NaN (NaN is truthy).
        if t is None or np.isnan(t):
            continue
        traces.append(go.Scatter(x=[t, t],
                                 y=[offset, offset + line_height],
                                 mode='lines',
                                 line=dict(color=color_map.get(name, 'grey'))))
    return traces

def get_dates(monkey_name, directory=None):
    """Map recording dates to the per-day ``.mat`` files of one monkey.

    Files are found with the glob ``*<monkey_name lowercased>.mat`` and the
    date is parsed from a ``Y-M-D`` prefix of the file name.

    Parameters
    ----------
    monkey_name : str
        Monkey name; lowercased before matching file names.
    directory : str or Path, optional
        Directory to search. Defaults to the notebook's ``data_dir``.

    Returns
    -------
    dict[datetime.date, Path]
        Dates in sorted-file-name order mapped to their file. Files with no
        date prefix, or one that is not a real calendar date, are skipped.
    """
    root = Path(data_dir if directory is None else directory)
    dates = {}
    for path in sorted(root.glob(f'*{monkey_name.lower()}.mat')):
        prefix = match(r'(\d+-\d+-\d+)', path.name)
        if not prefix:
            continue
        try:
            day = datetime.strptime(prefix.group(0), '%Y-%m-%d').date()
        except ValueError:
            # e.g. '2013-99-99' matches the regex but is not a valid date;
            # skip it instead of breaking the whole date dropdown.
            continue
        dates[day] = path
    return dates

def load_day(date_file, lowpass_hz=None, fs=1000):
    """Load every trial of one recording day from a MATLAB ``.mat`` file.

    Reads the top-level variables ``num`` (its length is the trial count),
    ``lfp`` and ``spikes`` (indexed per trial), and the struct ``events``
    (each field holds one value per trial).

    Parameters
    ----------
    date_file : str or Path
        Path of the day's ``.mat`` file.
    lowpass_hz : float, optional
        If given, zero-phase low-pass each trial's LFP with an 8th-order
        Butterworth filter at this cutoff (e.g. 100). Default None returns
        the raw LFP, as before.
    fs : float, optional
        LFP sampling rate in Hz; only used when filtering. Default 1000.

    Returns
    -------
    list of dict
        One dict per trial with keys ``'lfp'`` (1-D array), ``'spikes'`` and
        ``'events'`` (event name -> that trial's value).
    """
    mat = loadmat(date_file)
    n_trials = len(mat['num'])
    lfps = [lfp[0].reshape(-1) for lfp in mat['lfp']]
    if lowpass_hz is not None:
        from scipy.signal import sosfiltfilt
        # Second-order sections are numerically more robust than (b, a)
        # coefficients at order 8.
        sos = butter(8, lowpass_hz, fs=fs, output='sos')
        lfps = [sosfiltfilt(sos, lfp) for lfp in lfps]
    spikes = [s[0].reshape(-1) for s in mat['spikes']]
    event_names = mat['events'].dtype.names
    events = {name: mat['events'][name][0][0].reshape(-1) for name in event_names}
    return [{'lfp': lfps[i],
             'spikes': spikes[i],
             'events': {name: events[name][i] for name in event_names}}
            for i in range(n_trials)]


monkey_name = 'Jaws'   # selected monkey; reassigned by update_monkey_name
day_data = []          # trials of the selected day; reassigned by update_day_data
# Single shared interactive figure; display_trial replaces its traces.
fig = go.FigureWidget(
    layout=go.Layout(
        showlegend=False,
        xaxis={'showgrid': False},
        yaxis={'showgrid': False},
    )
)

def update_monkey_name(change):
    """Observer for ``monkey_selector``: store the chosen monkey in ``monkey_name``.

    Uses ``change['new']``, falling back to ``change['old']`` and then to the
    current ``monkey_name`` when those are missing or empty. A manual call
    may pass only ``'new'``, so missing keys must not raise.
    """
    global monkey_name
    monkey_name = change.get('new') or change.get('old') or monkey_name

def update_day_data(change):
    """Observer for ``date_selector``: load the chosen day's trials into ``day_data``.

    ``change['new']`` is the selected date. When it is falsy (e.g. the
    dropdown was reset while its options changed), the first option of
    ``change['owner']`` is used. If no owner is given, the first date found
    for ``monkey_name`` is used instead. With no dates at all,
    ``day_data`` becomes an empty list.
    """
    global day_data
    dates = get_dates(monkey_name)
    day = change.get('new')
    if not day:
        owner = change.get('owner')
        candidates = tuple(owner.options) if owner is not None else tuple(dates)
        if not candidates:
            day_data = []  # nothing recorded for this monkey
            return
        day = candidates[0]
    day_data = load_day(dates[day])
    
def display_trial(change, show_spikes=False):
    """Redraw ``fig`` with one trial of ``day_data``.

    Observer for ``trial_selector``'s ``value``: ``change['new']`` (or
    ``change['old']`` when that is falsy) is the 1-based trial number.

    Parameters
    ----------
    change : dict
        Traitlets-style change dict; only ``'new'``/``'old'`` are read.
    show_spikes : bool, optional
        Also draw a spike raster above the LFP, one row per unit. Default
        False shows only the LFP and the event markers.
    """
    trial_num = change.get('new') or change.get('old')
    raw = day_data[trial_num - 1]

    # Half-height of the symmetric plot area: the largest finite |LFP|.
    # NaN samples are ignored (built-in max/min give order-dependent results
    # with NaN). A flat or empty trace falls back to 1 so the event markers
    # keep a visible height.
    lfp = np.asarray(raw['lfp'], dtype=float)
    finite_abs = np.abs(lfp[np.isfinite(lfp)])
    lfp_bound = float(finite_abs.max()) if finite_abs.size and finite_abs.max() > 0 else 1.0

    data = []
    if show_spikes:
        spike_offset = lfp_bound * 1.1
        for i, unit in enumerate(raw['spikes']):
            data.extend(spike_train(np.ravel(unit), spike_offset + i * 100, line_height=100))
    data.extend(lfp_trace(raw['lfp']))
    data.extend(trial_events(raw['events'], offset=-lfp_bound, line_height=lfp_bound * 2))
    fig.data = []
    fig.add_traces(data)

# Selector chain: monkey -> date -> trial. Each level's observers reload the
# data or redraw the figure for the next level.
monkey_selector = widgets.Dropdown(options = ['Zorin', 'Jaws'],
                                  value = 'Jaws',
                                  description = 'Monkey:')
monkey_selector.observe(update_monkey_name, names='value')
update_monkey_name({'new': monkey_selector.value})

# Date options and selection follow the monkey. dlink syncs immediately
# and again on every change of the monkey.
date_selector = widgets.Dropdown(description = 'Date:')
widgets.dlink((monkey_selector, 'value'), (date_selector, 'options'), lambda name: get_dates(name).keys())
widgets.dlink((monkey_selector, 'value'), (date_selector, 'value'), lambda name: list(get_dates(name).keys())[0])
# Registered before reset_trial_selector below, so day_data is already
# reloaded when the trial slider is reset.
date_selector.observe(update_day_data, names='value')
update_day_data({'new': date_selector.value})

# Keep max >= min even for a day without trials.
trial_selector = widgets.IntSlider(description = 'Trial number:', min=1, max=max(1, len(day_data)))
trial_selector.observe(display_trial, names='value')

def reset_trial_selector(change):
    """Show trial 1 of the newly selected day and fit the slider to its trial count.

    Setting ``trial_selector.value = 1`` fires no change event when the
    slider is already at 1, so in that case the redraw is triggered
    explicitly. Without this, switching date while on trial 1 left the
    previous day's plot on screen. The value is reset before ``max`` is
    lowered so the slider never clamps to (and draws) an intermediate trial.
    """
    if trial_selector.value == 1:
        display_trial({'new': 1})
    else:
        trial_selector.value = 1  # the value observer calls display_trial
    trial_selector.max = max(1, len(day_data))

date_selector.observe(reset_trial_selector, names='value')
display_trial({'new': trial_selector.value})

def update_trial_selector(newval):
    """Move the trial slider to ``newval``.

    The slider's ``value`` observer (``display_trial``) performs the redraw.
    """
    setattr(trial_selector, 'value', newval)

prev_trial = widgets.Button(description='Previous')
next_trial = widgets.Button(description='Next')
# Step one trial, clamped to the slider's [min, max] range. "Next" must clamp
# from above; clamping it with max(1, ...) only guarded the lower bound.
prev_trial.on_click(lambda b: update_trial_selector(max(trial_selector.min, trial_selector.value - 1)))
next_trial.on_click(lambda b: update_trial_selector(min(trial_selector.max, trial_selector.value + 1)))

display(monkey_selector)
display(date_selector)
display(trial_selector)
display(prev_trial, next_trial)
display(fig)

Dropdown(description='Monkey:', index=1, options=('Zorin', 'Jaws'), value='Jaws')

Dropdown(description='Date:', options=(datetime.date(2013, 5, 1), datetime.date(2013, 5, 2), datetime.date(201…

IntSlider(value=1, description='Trial number:', max=676, min=1)

Button(description='Previous', style=ButtonStyle())

Button(description='Next', style=ButtonStyle())

FigureWidget({
    'data': [{'line': {'color': 'black', 'width': 2},
              'mode': 'lines',
          …

In [13]:
def trial_to_str_rep(trial):
    """Convert one trial dict (as built by load_day) to plain strings.

    The LFP is rounded to 2 decimals before conversion. Each unit's spike
    array is flattened, and every event value is passed through ``str``.
    Returns a dict with the same ``'lfp'``/``'spikes'``/``'events'`` keys.
    """
    rounded_lfp = trial['lfp'].round(2)
    lfp_strings = list(map(str, rounded_lfp))

    unit_strings = []
    for unit in trial['spikes']:
        unit_strings.append(list(map(str, unit.reshape(-1))))

    event_strings = dict((name, str(value)) for name, value in trial['events'].items())
    return dict(lfp=lfp_strings, spikes=unit_strings, events=event_strings)

# Event values of the first trial of the loaded day, as strings (NaN values print as 'nan').
trial_to_str_rep(day_data[0])['events']

{'fixate': 'nan', 'noise': 'nan', 'shape': 'nan', 'saccade': 'nan'}

In [6]:
# Last entry of every unit's spike array in the first trial (presumably each unit's last spike time).
[unit[-1] for unit in day_data[0]['spikes']]

[array([2.511], dtype=float32),
 array([2.507], dtype=float32),
 array([2.525], dtype=float32),
 array([2.501], dtype=float32),
 array([2.402], dtype=float32),
 array([2.455], dtype=float32),
 array([2.499], dtype=float32),
 array([0.815], dtype=float32),
 array([2.462], dtype=float32),
 array([2.434], dtype=float32),
 array([2.539], dtype=float32),
 array([2.391], dtype=float32)]