In [1]:
import numpy as np
import h5py
import cv2
from joblib import load
import matplotlib.pyplot as plt
#from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
#from mpl_toolkits.mplot3d import Axes3D
# supress color warning when making movie
#from matplotlib.axes._axes import _log as matplotlib_axes_logger
#matplotlib_axes_logger.setLevel('ERROR')
%matplotlib widget
import time
%load_ext autoreload
import os
%autoreload 2
import sys
sys.path.append('./lib/')

from ipywidgets import Video, Image, VBox, Text
from sidecar import Sidecar
from IPython.display import display
from scipy.interpolate import interp1d
from scipy import interpolate
import numpy as np
import cv2
import base64

from ipywidgets import interact, interactive, fixed, interact_manual, FloatSlider, Layout, AppLayout

import ipywidgets as widgets

#from bqplot import pyplot as plt

from DataManagement import SplitdataManager, Filepaths, Calibration

In [2]:
# path to the experiment folder (note the trailing slash)
experiment_path = ('/media/stephens-group/guest_drive/labelling/'
                   'complete_shortfin_experiment/FishTank20200130_153857/')
experiment_path

'/media/stephens-group/guest_drive/labelling/complete_shortfin_experiment/FishTank20200130_153857/'

In [3]:
# make a splitdata manager for the experiment folder.
# Later cells use its splitdata_paths (one triplet of camera movie paths per
# splitdata, used with '.mp4' appended), start_stop_frames_for_splitdata and
# return_splitdata_folder_and_local_idx_for_global_frameIdx.
splitman = SplitdataManager(experiment_path)

In [4]:
# calibration folder recorded for this experiment, wrapped in a Calibration object
calibrationFolderPath = os.path.join('/home/liam/Data/experiments/FishTank20200130_153857/',
                                     '20200120_calibration/')
cal = Calibration(calibrationFolderPath)



In [5]:
# Lookup tables for the splitdata folders: each folder's name (the last path
# component of its first movie path), its position in splitman.splitdata_paths,
# and its full path triplet.
splitdata_names = [paths[0].rsplit('/', 1)[-1] for paths in splitman.splitdata_paths]
splitdata_idxs = list(range(len(splitdata_names)))
splitdata_name_to_idx_dict = {name: idx for idx, name in enumerate(splitdata_names)}
splitdata_name_to_paths = {name: paths for name, paths in zip(splitdata_names, splitman.splitdata_paths)}
# stop frame of the final splitdata (reassigned from tracks_3D_test in the data-loading cell)
totalnumFrames = splitman.start_stop_frames_for_splitdata[-1, 1]


In [6]:
def _read_h5_dataset(h5_path, dataset_name):
    ''' Return the full contents of one dataset of an HDF5 file as a numpy array. '''
    with h5py.File(h5_path, 'r') as h5_file:
        return h5_file[dataset_name][:]


# 3D trajectories from sLEAP, tracked the old way
trackTestPath = '/home/liam/Data/results_from_laetitia/FishTank20200130_153857/old_tracking_on_original_instances_sLEAP.h5'
tracks_3D_test = _read_h5_dataset(trackTestPath, 'tracks_3D')

# output of the new networks (trained WITH Tatsuo's annotations), converted to original_instances format
inst_path = '/home/liam/Data/testing_network_accuracy_in_2_fish/original_instances_sLEAP.h5'
original_instances_new = _read_h5_dataset(inst_path, 'original_instances')

# tracked image coordinates that have gone through the calibration
imCoords_path = '/home/liam/Data/results_from_laetitia/FishTank20200130_153857/old_tracking_on_original_instances_sLEAP_imCoords.h5'
tracks_imCoords = _read_h5_dataset(imCoords_path, 'tracks_imCoords')

# total number of frames in the experiment, taken from the 3D tracks
# (overrides the value computed from splitman in the previous cell)
totalnumFrames = tracks_3D_test.shape[0]

# other params
numFish = 2
numBodyPoints = 3
numCams = 3

# Making the Sidecar dashboard

In [7]:
## ---- First panels of the dashboard --- ##

# dropdown to choose a splitdata by name; its value is the splitdata index
splitdata_chooser = widgets.Dropdown(options=splitdata_name_to_idx_dict, value=0)

# box holding the current splitdata index
splitdata_idx = widgets.IntText(value=0, description='splitdata_idx', layout=Layout(width='50%', height='50px'))

# one text widget per camera movie (xz, xy, yz) of the current splitdata
xzMovPath, xyMovPath, yzMovPath = [widgets.Text(value=splitman.splitdata_paths[splitdata_idx.value][cam] + '.mp4')
                                   for cam in range(3)]

# global (whole experiment) and local (within the splitdata) frame indices
global_index = widgets.IntText(value=0, description='Global fIdx', layout=Layout(width='70%', height='50px'))
local_index = widgets.IntText(value=0, description='Local fIdx', layout=Layout(width='70%', height='50px'))

# frame slider over the whole experiment, kept in sync with the Global fIdx box
global_index_slider = widgets.IntSlider(value=0, min=0, max=totalnumFrames - 1, step=1,
                                        continuous_update=True, description='Frame',
                                        layout=Layout(width='90%', height='50px'))
gf_link = widgets.link((global_index_slider, 'value'), (global_index, 'value'))

    
def handle_splitdata_name_change(change):
    ''' Dropdown observer: switch the dashboard to the chosen splitdata.

    Copies the chosen index into the splitdata_idx box, points the three movie
    path widgets at that splitdata's xz/xy/yz movies, and moves the global frame
    index to the same local offset inside the newly chosen splitdata.
    '''
    splitdata_idx.value = change.new
    movie_path_widgets = (xzMovPath, xyMovPath, yzMovPath)
    for cam, path_widget in enumerate(movie_path_widgets):
        path_widget.value = splitman.splitdata_paths[splitdata_idx.value][cam] + '.mp4'
    new_start = splitman.start_stop_frames_for_splitdata[splitdata_idx.value][0]
    global_index.value = new_start + local_index.value


# react to selections made in the dropdown
splitdata_chooser.observe(handle_splitdata_name_change, names='value')



def handle_local_index_change(change):
    ''' Make the required changes when we update the "Local fIdx"

    Observer on local_index. Wraps the local frame index into
    [0, number of frames in the current splitdata) with np.mod, then sets
    global_index to the splitdata's start frame plus that local index.
    Setting global_index fires handle_global_index_change, which in turn sets
    local_index again, so these observers re-enter each other; the order of
    the statements below matters.

    Parameters
    ----------
    change : traitlets change; change.new is the requested local frame index.
    '''
    # get the start-stop for the current splitdata index
    splitdata_start_stop = splitman.start_stop_frames_for_splitdata[splitdata_idx.value]
    splitdata_numFrames = splitdata_start_stop[1] - splitdata_start_stop[0]
    # find the new local index
    # (presumably a no-op: traitlets observers run after the trait already holds change.new)
    local_index.value = change.new
    # now make sure that the new index is within the range
    # (np.mod wraps negative and too-large indices; NOTE(review): a splitdata with 0 frames would be a modulo by zero)
    local_index.value = np.mod(local_index.value, splitdata_numFrames)
    # update global_frame
    global_index.value = splitdata_start_stop[0] + local_index.value
# apply the function
local_index.observe(handle_local_index_change, names='value')



def handle_global_index_change(change):
    ''' Global fIdx observer: clamp the global frame index and sync the dashboard.

    The requested index is clamped to [0, totalnumFrames - 1]. The splitdata that
    contains the frame, and the local index within it, are looked up from splitman.
    The Local fIdx box, the splitdata_idx box and the dropdown are then updated to match.

    Parameters
    ----------
    change : traitlets change; change.new is the requested global frame index.
    '''
    # clamp into the valid frame range; re-assigning global_index re-enters
    # this observer with the clamped value
    if change.new >= totalnumFrames:
        global_index.value = totalnumFrames - 1
    elif change.new < 0:
        global_index.value = 0
    else:
        global_index.value = change.new
    # get the current splitdata index and local index for this frame
    spIdx, locIdx = splitman.return_splitdata_folder_and_local_idx_for_global_frameIdx(global_index.value, return_splitdataIdx=True)
    local_index.value = locIdx
    splitdata_idx.value = spIdx
    splitdata_chooser.label = splitdata_names[spIdx]
    splitdata_chooser.index = spIdx
# apply the function
global_index.observe(handle_global_index_change, names='value')
    
    
def handle_splitdata_idx_box_change(change):
    ''' splitdata_idx observer: wrap the index into range and sync the dropdown.

    Any out-of-range index wraps around modulo the number of splitdatas:
    -1 -> last and numSplitdatas -> 0, and larger jumps also wrap instead of
    raising IndexError on splitdata_names. Does nothing if there are no splitdatas.
    '''
    numSplitdatas = len(splitman.splitdata_paths)
    if numSplitdatas == 0:
        # nothing to select (and a modulo by zero would fail)
        return
    # Python's % is non-negative for a positive divisor; an in-range value is
    # re-assigned unchanged, which does not notify observers again
    splitdata_idx.value = change.new % numSplitdatas
    # update the dropdown box
    splitdata_chooser.label = splitdata_names[splitdata_idx.value]
    splitdata_chooser.index = splitdata_idx.value
# apply
splitdata_idx.observe(handle_splitdata_idx_box_change, names='value')



# --- widgets recording which camera / bodypoint / fish the last click selected (-1 = none) --- #
camIdx_widget, bpIdx_widget, fishIdx_widget = [widgets.IntText(value=-1, description=desc)
                                               for desc in ('camIdx', 'bpIdx', 'fishIdx')]

# ---- stack all dashboard controls vertically and show them in a Sidecar panel --- ##
dashboard_children = [splitdata_chooser, splitdata_idx, global_index, local_index,
                      global_index_slider, camIdx_widget, bpIdx_widget, fishIdx_widget]
splitdata_details_box = VBox(dashboard_children)

sc = Sidecar(title='Sidecar Output')
with sc:
    display(splitdata_details_box)
    

# Interactive plot of raw data in a range

In [8]:
# working copy of the 3D tracks, so tracks_3D_test keeps the values as loaded
tracks_3D = np.array(tracks_3D_test, copy=True)

In [9]:

plt.ioff()

# the frame slider that drives this plot
slider = global_index_slider

fig, (ax1, ax2) = plt.subplots(2, sharex=True, gridspec_kw={'hspace': 0.2})
# NOTE: Figure has no `figsize` attribute to assign (use set_size_inches or
# plt.subplots(figsize=...)); the on-screen size is set via the canvas layout below.
fig.suptitle('Trajectory Data Timeseries')
fig.canvas.header_visible = False
fig.canvas.layout.min_height = '400px'
fig.canvas.layout.width = '100%'
for ax in (ax1, ax2):
    ax.grid(which='both')
    ax.ticklabel_format(useOffset=False)
# number of frames shown on each side of the current frame
plot_frame_width = 50

# ---- data for each plot ----- #
# plot 1 - heads (bodypoint 0)
bpIdx = 0
p1_data = np.copy(tracks_3D[:, :, bpIdx, :])

# plot 2 - bodypoint 2 (shown as 'Tails' below)
bpIdx = 2
p2_data = np.copy(tracks_3D[:, :, bpIdx, :])

# reds for fish1 x/y/z, greens for fish2 x/y/z
colors = ['lightcoral', 'brown', 'maroon',
          'forestgreen', 'darkgreen', 'limegreen']


# ----- line plots: one line per fish per x/y/z coordinate ------ #
lines1_x1 = ax1.plot(p1_data[:, 0, 0], label='fish1_head_x', color=colors[0])
lines1_y1 = ax1.plot(p1_data[:, 0, 1], label='fish1_head_y', color=colors[1])
lines1_z1 = ax1.plot(p1_data[:, 0, 2], label='fish1_head_z', color=colors[2])

lines1_x2 = ax1.plot(p1_data[:, 1, 0], label='fish2_head_x', color=colors[3])
lines1_y2 = ax1.plot(p1_data[:, 1, 1], label='fish2_head_y', color=colors[4])
lines1_z2 = ax1.plot(p1_data[:, 1, 2], label='fish2_head_z', color=colors[5])

lines2_x1 = ax2.plot(p2_data[:, 0, 0], label='fish1_x', color=colors[0])
lines2_y1 = ax2.plot(p2_data[:, 0, 1], label='fish1_y', color=colors[1])
lines2_z1 = ax2.plot(p2_data[:, 0, 2], label='fish1_z', color=colors[2])

lines2_x2 = ax2.plot(p2_data[:, 1, 0], label='fish2_x', color=colors[3])
lines2_y2 = ax2.plot(p2_data[:, 1, 1], label='fish2_y', color=colors[4])
lines2_z2 = ax2.plot(p2_data[:, 1, 2], label='fish2_z', color=colors[5])

ax1.set_title('Heads: frame = {0}'.format(slider.value))
ax2.set_title('Tails')

for ax in (ax1, ax2):
    ax.set_ylim(-1, 43)
    ax.set_yticks([i for i in range(0,41,5)], minor=True)

ax1.set_ylabel('Head Position (cm)')
ax2.set_ylabel('Tail Position (cm)')

# use any axis here to get the labels
handles, labels = ax2.get_legend_handles_labels()
fig.legend(handles, labels, loc='center right')

# initial x-window: 2*plot_frame_width frames around the current frame,
# pinned to the start/end of the experiment near either edge
if slider.value < plot_frame_width:
    x_window = (0, 2*plot_frame_width)
elif totalnumFrames - slider.value < plot_frame_width:
    x_window = (totalnumFrames - 2*plot_frame_width, totalnumFrames)
else:
    x_window = (slider.value - plot_frame_width, slider.value + plot_frame_width)
ax1.set_xlim(*x_window)
ax2.set_xlim(*x_window)



def update_lines(change):
    ''' Slider observer: retitle the heads plot and scroll both axes to the current frame.

    The visible x-range is 2*plot_frame_width frames wide and centred on the
    current frame. Within plot_frame_width frames of either end of the
    experiment, it is pinned to that end instead.
    '''
    # NOTE(review): short pause left unchanged; presumably throttles redraws while dragging
    time.sleep(0.01)
    # Use the live slider value for both the title and the limits. Nested observer
    # callbacks can move the slider again before this call runs, leaving change.new stale.
    frame = slider.value
    # same title format as when the figure is first drawn (no trailing ')')
    ax1.set_title('Heads: frame = {0}'.format(frame))
    ax2.set_title('Tails')
    if frame < plot_frame_width:
        x_window = (0, 2*plot_frame_width)
    elif totalnumFrames - frame < plot_frame_width:
        x_window = (totalnumFrames - 2*plot_frame_width, totalnumFrames)
    else:
        x_window = (frame - plot_frame_width, frame + plot_frame_width)
    ax1.set_xlim(*x_window)
    ax2.set_xlim(*x_window)
    fig.canvas.draw()
    fig.canvas.flush_events()
    
# redraw the timeseries whenever the frame slider moves
slider.observe(update_lines, names='value')

# full-width container: figure canvas on top, frame slider underneath
layout = widgets.Layout(width='100%')
plot_widget = VBox([fig.canvas, slider], layout=layout)

plot_widget

VBox(children=(Canvas(header_visible=False, layout=Layout(min_height='400px', width='100%'), toolbar=Toolbar(t…

# Make widgets to hold tracking results

In [10]:
# ----- Widgets holding the image coordinates and 3D coordinates of the current frame ---- #
number_layout = widgets.Layout(width='70px')
text_layout = widgets.Layout(width='120px')

# default positions of bodypoints in cropped coords
# NOTE(review): p*_init / def_cropped_pos are not used anywhere else in this notebook
p1_init = np.array([20, 20])
p2_init = np.array([80, 80])
p3_init = np.array([140, 140])
def_cropped_pos = [p1_init, p2_init, p3_init]

current_frame = global_index_slider.value

# image coordinates, indexed [camIdx][fishIdx][bpIdx][x=0 / y=1], rounded to 2 decimals
imCoords_now = np.round(tracks_imCoords[:, current_frame], 2)
frame_imCoords_widget_list = [[[[widgets.FloatText(value=imCoords_now[cam, fish, bp, k], layout=number_layout)
                                 for k in range(2)]
                                for bp in range(numBodyPoints)]
                               for fish in range(numFish)]
                              for cam in range(numCams)]

# 3D positions, indexed [fishIdx][bpIdx][dimIdx], rounded to 2 decimals
coords3D_now = np.round(tracks_3D[current_frame], 2)
frame_3D_widget_list = [[[widgets.FloatText(value=coords3D_now[fish, bp, dim], layout=number_layout)
                          for dim in range(3)]
                         for bp in range(numBodyPoints)]
                        for fish in range(numFish)]


In [11]:
# ---- Widgets to visualize the 3D positions --- #
number_layout = widgets.Layout(width='80px')
text_layout = widgets.Layout(width='120px')


def _fish_3D_panel(fish, title, padding):
    ''' Column for one fish: a title box above one X/Y/Z row of widgets per bodypoint. '''
    bp_rows = [widgets.HBox(frame_3D_widget_list[fish][bp][:], layout=widgets.Layout(padding=padding))
               for bp in range(numBodyPoints)]
    title_box = widgets.Text(title, layout=text_layout)
    return widgets.VBox([title_box] + bp_rows)


f1_3D_wid = _fish_3D_panel(0, 'Fish1 3D bps', '0px 30px 0 0')
f2_3D_wid = _fish_3D_panel(1, 'Fish2 3D bps', '0px 10px 0 0')

# Final widget: the two fish side by side
tracks_3D_frame_widget = widgets.HBox([f1_3D_wid, f2_3D_wid], layout=widgets.Layout(padding=('0px 0px 30px 200px')))
#tracks_3D_frame_widget

In [12]:
# ---- Widgets to visualize the image coordinates --- #
number_layout = widgets.Layout(width='80px')
text_layout = widgets.Layout(width='120px')
cam_title_layout = widgets.Layout(width='160px')
title_layout = widgets.Layout(width='480px')
camNames = ['xz', 'xy', 'yz']


def _fish_imcoord_panel(fish, title):
    ''' Panel for one fish: a title box above one column per camera.

    Each camera column has the camera name on top, then one (x, y) row of
    coordinate widgets per bodypoint.
    '''
    cam_columns = []
    for cam in range(3):
        bp_rows = [widgets.HBox(frame_imCoords_widget_list[cam][fish][bp][:]) for bp in range(numBodyPoints)]
        cam_title = widgets.Text(camNames[cam], layout=cam_title_layout)
        cam_columns.append(widgets.VBox([cam_title] + bp_rows))
    title_box = widgets.Text(title, layout=title_layout)
    return widgets.VBox([title_box, widgets.HBox(cam_columns)])


f1_imcoord_wid = _fish_imcoord_panel(0, 'Fish1 bodypoint image coordinates')
f2_imcoord_wid = _fish_imcoord_panel(1, 'Fish2 bodypoint image coordinates')

# final widget: the two fish side by side
tracks_imCoord_frame_widget = widgets.HBox([f1_imcoord_wid, f2_imcoord_wid])
#tracks_imCoord_frame_widget

In [13]:
# Keep the coordinate widgets showing the tracking values of the slider's frame


def update_data(change):
    ''' Slider observer: write frame change.new's tracked coordinates into the widgets.

    Image coordinates (per camera / fish / bodypoint / x-y) and 3D positions
    (per fish / bodypoint / dimension) are rounded to 2 decimals first.
    '''
    time.sleep(0.01)
    frame = change.new

    imCoords_now = np.round(tracks_imCoords[:, frame], 2)
    for cam, fish, bp, k in np.ndindex(numCams, numFish, numBodyPoints, 2):
        frame_imCoords_widget_list[cam][fish][bp][k].value = imCoords_now[cam, fish, bp, k]

    coords3D_now = np.round(tracks_3D[frame], 2)
    for fish, bp, dim in np.ndindex(numFish, numBodyPoints, 3):
        frame_3D_widget_list[fish][bp][dim].value = coords3D_now[fish, bp, dim]


global_index_slider.observe(update_data, names='value')

In [14]:
# final coordinate widget: 3D positions stacked above the per-camera image coordinates
coordinate_widget = widgets.VBox(children=[tracks_3D_frame_widget, tracks_imCoord_frame_widget])
coordinate_widget

VBox(children=(HBox(children=(VBox(children=(Text(value='Fish1 3D bps', layout=Layout(width='120px')), HBox(ch…

# Loading images

In [15]:
## ---- Interpolate the heads for making movies without gaps ----- ##

def fill_nan(A):
    '''
    Linearly interpolate over the non-finite (NaN/inf) entries of a 1-D array.

    Adapted from https://stackoverflow.com/a/9815522

    Parameters
    ----------
    A : 1-D array_like
        Samples; the array index is used as the abscissa.

    Returns
    -------
    numpy.ndarray of float
        A new array; the input is not modified. Gaps between finite samples are
        filled by linear interpolation between the nearest finite neighbours.
        Gaps before the first or after the last finite sample stay NaN.
        Input with fewer than two finite samples no longer raises: a single
        finite sample is kept as-is, and an input with no finite samples is
        returned unchanged as a float copy.
    '''
    A = np.asarray(A)
    inds = np.arange(A.shape[0])
    good = np.isfinite(A)
    if not np.any(good):
        # nothing to interpolate from
        return A.astype(float)
    # left/right=NaN leaves positions outside the known range as NaN, matching
    # interp1d(..., bounds_error=False) with its default fill value
    filled = np.interp(inds, inds[good], A[good], left=np.nan, right=np.nan)
    return np.where(good, A, filled)

# Interpolate the head tracks (bodypoint 0) in every camera so the crop window
# does not jump around ("blink") on frames where a head detection is missing.
def _interpolate_heads(heads):
    ''' Return a copy of a per-camera head track with every fish's x and y gaps filled by fill_nan.

    heads is indexed [frame, fishIdx, coord]; only coords 0 (x) and 1 (y) are interpolated.
    '''
    heads_interpd = np.copy(heads)
    for fish in range(numFish):
        for coord in range(2):
            heads_interpd[:, fish, coord] = fill_nan(np.copy(heads[:, fish, coord]))
    return heads_interpd


xz_heads = np.copy(original_instances_new[0, :, :, 0, :])
xy_heads = np.copy(original_instances_new[1, :, :, 0, :])
yz_heads = np.copy(original_instances_new[2, :, :, 0, :])

xz_heads_interpd = _interpolate_heads(xz_heads)
xy_heads_interpd = _interpolate_heads(xy_heads)
yz_heads_interpd = _interpolate_heads(yz_heads)

# crop centres for the frames, indexed [camIdx][frame, fishIdx, :]
crop_centers = [xz_heads_interpd, xy_heads_interpd, yz_heads_interpd]

In [16]:
# ---- drawing / cropping parameters ---- #
no_id_color = (255,255,255)  # white; the colour every skeleton is drawn in
edges = [[0,1], [1,2]]       # skeleton edges as pairs of bodypoint indices
padding_depth = 70           # NOTE(review): not used anywhere in this notebook
bw = 20                      # NOTE(review): not used anywhere in this notebook
cropSize = (80,80)           # half-width / half-height (pixels) of the zoom window around each head

plt.ioff()

# one figure per fish, each with an xz / xy / yz panel
fig1, (f1_ax1, f1_ax2, f1_ax3) = plt.subplots(1, 3, figsize=(9,3))
fig2, (f2_ax1, f2_ax2, f2_ax3) = plt.subplots(1, 3, figsize=(9,3))

f1_axs = [f1_ax1, f1_ax2, f1_ax3]
f2_axs = [f2_ax1, f2_ax2, f2_ax3]
f_axs = [f1_axs, f2_axs]

panel_titles = [['Fish1_xz', 'Fish1_xy', 'Fish1_yz'],
                ['Fish2_xz', 'Fish2_xy', 'Fish2_yz']]
for fish_axs, fish_titles in zip(f_axs, panel_titles):
    for ax, title in zip(fish_axs, fish_titles):
        ax.set_title(title)

for crop_fig in (fig1, fig2):
    crop_fig.canvas.header_visible = False
    crop_fig.canvas.layout.min_height = '400px'
    crop_fig.canvas.layout.width = '100%'

#---- TEMP: click positions on the crop figures ------ #
# new_coord_* receive the latest button-release position; on each press the
# previous release position is copied into old_coord_*
old_coord_widget_x = widgets.FloatText(value=0)
old_coord_widget_y = widgets.FloatText(value=0)
new_coord_widget_x = widgets.FloatText(value=0)
new_coord_widget_y = widgets.FloatText(value=0)


def click_down_event_handler(event):
    ''' Mouse-press handler for the crop figures.

    Copies the last release position into old_coord_widget_x/y. It then records
    which camera panel was clicked in camIdx_widget and selects the nearest
    tracked bodypoint via find_idx_of_closest_point. A press that does not land
    on one of the crop panels clears the bodypoint/fish selection (-1) instead
    of failing. This includes presses outside any axes, where event.xdata is None.
    '''
    old_coord_widget_x.value = new_coord_widget_x.value
    old_coord_widget_y.value = new_coord_widget_y.value

    # find the camIdx of the panel we clicked on (panels are indexed [fishIdx][camIdx])
    clicked_cam = None
    for fishIdx in range(numFish):
        for camIdx in range(numCams):
            if event.inaxes is f_axs[fishIdx][camIdx]:
                clicked_cam = camIdx

    if clicked_cam is None or event.xdata is None or event.ydata is None:
        bpIdx_widget.value = -1
        fishIdx_widget.value = -1
        return
    camIdx_widget.value = clicked_cam

    # find the closest image coordinate, within the threshold
    find_idx_of_closest_point(event.xdata, event.ydata, tolerance=20)

    
def find_idx_of_closest_point(x, y, tolerance=15):
    ''' Select the tracked bodypoint nearest to a click in the current camera view.

    (x, y) is compared against the image coordinates held in
    frame_imCoords_widget_list for camera camIdx_widget.value, for every fish and
    bodypoint. If the nearest point is closer than `tolerance`, bpIdx_widget and
    fishIdx_widget are set to it. Otherwise both are set to -1. Untracked (NaN)
    points are ignored, so a fish with a missing bodypoint can still be selected
    by its other points.

    Parameters
    ----------
    x, y : float or None
        Click position in image (data) coordinates; None selects nothing.
    tolerance : float
        Maximum distance for a click to count as selecting a point.

    Returns
    -------
    (fishIdx, bpIdx) : tuple of int
        The selected indices, or (-1, -1) when nothing was selected.
    '''
    camIdx = camIdx_widget.value
    best_fishIdx, best_bpIdx = -1, -1

    # camIdx_widget starts at -1, which would otherwise silently index the last camera
    if x is not None and y is not None and 0 <= camIdx < numCams:
        coords = np.array([[[frame_imCoords_widget_list[camIdx][fishIdx][bpIdx][k].value
                             for k in range(2)]
                            for bpIdx in range(numBodyPoints)]
                           for fishIdx in range(numFish)], dtype=float)
        # distances[fishIdx, bpIdx]; NaN where the point is untracked
        distances = np.hypot(coords[..., 0] - x, coords[..., 1] - y)
        if not np.all(np.isnan(distances)):
            fishIdx, bpIdx = np.unravel_index(np.nanargmin(distances), distances.shape)
            if distances[fishIdx, bpIdx] < tolerance:
                best_fishIdx, best_bpIdx = int(fishIdx), int(bpIdx)

    bpIdx_widget.value = best_bpIdx
    fishIdx_widget.value = best_fishIdx
    return best_fishIdx, best_bpIdx

         
def click_release_event_handler(event):
    ''' Mouse-release handler: store the release position in new_coord_widget_x/y.

    Releases outside any axes have event.xdata/ydata == None. Those are ignored,
    so the widgets keep the last valid position (a FloatText value cannot be None).
    '''
    if event.xdata is None or event.ydata is None:
        return
    new_coord_widget_x.value = event.xdata
    new_coord_widget_y.value = event.ydata


# route mouse presses and releases on both crop figures to the click handlers
for crop_fig in (fig1, fig2):
    crop_fig.canvas.mpl_connect('button_press_event', click_down_event_handler)
    crop_fig.canvas.mpl_connect('button_release_event', click_release_event_handler)





#  -------------- helpers: read, annotate and zoom the current frames -------------------#
video_paths = [xzMovPath, xyMovPath, yzMovPath]


def _read_current_frames():
    ''' Read frame local_index.value from the current xz, xy and yz movies.

    The movie paths come from the widgets in video_paths. Every capture is
    released before returning, including when an error is raised.

    Returns
    -------
    list of the three images [xz, xy, yz] as returned by cv2.VideoCapture.read.
    NOTE(review): OpenCV returns BGR channel order while imshow assumes RGB;
    this only matters if the movies are in colour.

    Raises
    ------
    OSError
        If any of the three movies cannot be opened.
    TypeError
        If a frame cannot be read from one of the movies.
    '''
    caps = [cv2.VideoCapture(video_path.value) for video_path in video_paths]
    try:
        if not all(cap.isOpened() for cap in caps):
            raise OSError('movies for splitdata_idx={0} could not be opened'.format(splitdata_idx.value))
        frames = []
        for cam_name, cap in zip(('xz', 'xy', 'yz'), caps):
            cap.set(cv2.CAP_PROP_POS_FRAMES, local_index.value)  # go to current frame
            ret, frame = cap.read()
            if not ret:
                raise TypeError('{0} movie for splitdata_idx={1} not opening'.format(cam_name, splitdata_idx.value))
            frames.append(frame)
        return frames
    finally:
        # close everything before we finish, even if a read failed
        for cap in caps:
            cap.release()


def _draw_skeletons(frames):
    ''' Draw each fish's bodypoints and skeleton edges onto the frames, in place.

    Coordinates are read from frame_imCoords_widget_list, so the drawing shows
    whatever values the coordinate widgets currently hold. A point with any NaN
    coordinate is skipped, along with every edge that touches it.
    '''
    color = tuple(int(c) for c in no_id_color)
    for camIdx, frame in enumerate(frames):
        for fishIdx in range(numFish):
            # pixel position of each bodypoint, or None if untracked
            pts = []
            for bpIdx in range(numBodyPoints):
                x = frame_imCoords_widget_list[camIdx][fishIdx][bpIdx][0].value
                y = frame_imCoords_widget_list[camIdx][fishIdx][bpIdx][1].value
                pts.append(None if np.isnan(x) or np.isnan(y) else (int(x), int(y)))
            # draw the bodypoints
            for pt in pts:
                if pt is not None:
                    cv2.circle(frame, pt, radius=2, color=color, thickness=-1)
            # draw the lines
            for bp1, bp2 in edges:
                if pts[bp1] is not None and pts[bp2] is not None:
                    cv2.line(frame, pt1=pts[bp1], pt2=pts[bp2], color=color, thickness=1, lineType=cv2.LINE_AA)


def _set_crop_limits():
    ''' Zoom every crop panel onto its fish's interpolated head at the current frame.

    Panel f_axs[fishIdx][camIdx] shows +/- cropSize pixels around
    crop_centers[camIdx][global frame, fishIdx]. A panel whose crop centre has a
    NaN coordinate keeps its previous limits, because matplotlib rejects NaN limits.
    The y-limits are passed as (bottom, top) = (larger, smaller). This keeps the
    downward-pointing y axis that imshow sets up, so the crop is not flipped.
    '''
    frameIdx = global_index_slider.value
    half_w, half_h = int(cropSize[0]), int(cropSize[1])
    for fishIdx in range(numFish):
        for camIdx in range(numCams):
            cc = crop_centers[camIdx][frameIdx, fishIdx, :]
            if np.any(np.isnan(cc)):
                continue
            ax = f_axs[fishIdx][camIdx]
            ax.set_xlim(cc[0] - half_w, cc[0] + half_w)
            ax.set_ylim(cc[1] + half_h, cc[1] - half_h)


#  -------------- initial frames: read, annotate, show and zoom -------------------#
frames = _read_current_frames()
_draw_skeletons(frames)

f1_panel1 = f1_ax1.imshow(frames[0])
f1_panel2 = f1_ax2.imshow(frames[1])
f1_panel3 = f1_ax3.imshow(frames[2])
f1_panels = [f1_panel1, f1_panel2, f1_panel3]

f2_panel1 = f2_ax1.imshow(frames[0])
f2_panel2 = f2_ax2.imshow(frames[1])
f2_panel3 = f2_ax3.imshow(frames[2])
f2_panels = [f2_panel1, f2_panel2, f2_panel3]

f_panels = [f1_panels, f2_panels]

# the default zoom around each fish's head
_set_crop_limits()


def update_crop_figures(change):
    ''' Slider observer: reload, annotate and re-zoom the crop figures.

    Reads the current frame from each movie and draws the skeletons held in the
    coordinate widgets. It then swaps the images into both fish figures,
    re-centres the crops and redraws. `change` is unused: the frame comes from
    local_index and global_index_slider.
    '''
    frames = _read_current_frames()
    _draw_skeletons(frames)
    for fish_panels in f_panels:
        for panel, frame in zip(fish_panels, frames):
            panel.set_data(frame)
    _set_crop_limits()
    for crop_fig in (fig1, fig2):
        crop_fig.canvas.draw()
        crop_fig.canvas.flush_events()


global_index_slider.observe(update_crop_figures, names='value')


image_widget = widgets.VBox([fig1.canvas, fig2.canvas])
#image_widget

# Combined view

In [17]:
# Combined view, stacked in this order: the timeseries plot with its frame slider,
# the coordinate widgets for the current frame, and the two per-fish crop figures.
display(plot_widget, coordinate_widget, image_widget)

VBox(children=(Canvas(header_visible=False, layout=Layout(min_height='400px', width='100%'), toolbar=Toolbar(t…

VBox(children=(HBox(children=(VBox(children=(Text(value='Fish1 3D bps', layout=Layout(width='120px')), HBox(ch…

VBox(children=(Canvas(header_visible=False, layout=Layout(min_height='400px', width='100%'), toolbar=Toolbar(t…