In [1]:
import numpy as np
import h5py
import cv2
from joblib import load
import matplotlib.pyplot as plt
%matplotlib widget
import time
%load_ext autoreload
import os
%autoreload 2

from ipywidgets import Video, Image, VBox, Text
from sidecar import Sidecar
from IPython.display import display
from scipy.interpolate import interp1d
from scipy import interpolate
import numpy as np
import os
import cv2
import base64
import warnings

from ipywidgets import interact, interactive, fixed, interact_manual, FloatSlider, Layout, AppLayout
import ipywidgets as widgets


# Our custom-made imports
import sys
sys.path.append('./lib/')
from DataManagement import SplitdataManager, Filepaths, Calibration

# Paths & Parameters (Edit!)

In [2]:
### Paths to input/output files and experiment-specific parameters.
### Everything in this cell is meant to be edited per machine / per experiment.


# experiment folder path: handed to SplitdataManager and used as the parent
# directory of the calibration folder below
# experiment_path = '/media/stephens-group/guest_drive/labelling/complete_shortfin_experiment/FishTank20200130_153857/'
experiment_path = '/home/thomasreus/Documents/zebrafish_labeling_GUI/FishTank20200130_153857/'

# tracks file path: HDF5 file providing the 'tracks_3D' and 'tracks_imCoords' datasets
# tracksPath = '/home/liam/Data/results_from_laetitia/FishTank20200130_153857/interpd_updated_old_tracking_on_original_instances_sLEAP.h5'
tracksPath = '/home/thomasreus/Documents/zebrafish_labeling_GUI/FishTank20200130_153857/interpd_updated_old_tracking_on_original_instances_sLEAP.h5'

# save location: the folder is created when missing, and fileName is the HDF5
# file inside it that stores the edited tracks and the labelling progress
# labelsDataSaveFolder = '/home/liam/Desktop/saveFolder'
labelsDataSaveFolder = '/home/thomasreus/Documents/zebrafish_labeling_GUI/GUI_save_folder'
fileName = 'test.h5'

# calibration folder, given relative to experiment_path
calibrationFolder = '20200120_calibration/'

# other parameters: sizes of the fish, bodypoint and camera axes used when
# looping over the tracking arrays and when bounding the index widgets
numFish = 2
numBodyPoints = 3
numCams = 3

# Boilerplate Stuff

In [3]:
### Load data, prepare/load save file,
### create calibration and splitmanager objects
### and load variable names from splitmanager.


# load 3D and image coordinates data (opened read-only)
with h5py.File(tracksPath, 'r') as hf:
    tracks_3D_raw = hf['tracks_3D'][:]
    tracks_imCoords_raw = hf['tracks_imCoords'][:]
totalnumFrames = tracks_3D_raw.shape[0]

# define the arrays which hold the edited information
tracks_3D = np.copy(tracks_3D_raw)
tracks_imCoords = np.copy(tracks_imCoords_raw)

# make the save folder if it does not already exist
# (exist_ok avoids the check-then-create race of exists() + makedirs())
os.makedirs(labelsDataSaveFolder, exist_ok=True)

# create the labels file on first use, otherwise resume from it
track_save_path = os.path.join(labelsDataSaveFolder, fileName)
if not os.path.exists(track_save_path):
    print('Making a h5 to hold outputs')
    # build the in-memory state once and write exactly those arrays to disk;
    # the camera axis is sized with numCams so it always matches camIdx_widget
    missing_bps = np.zeros((numCams, totalnumFrames, numFish, numBodyPoints), dtype=bool)
    annotated_frames = np.zeros((totalnumFrames,), dtype=bool)
    tracks_3D_edited = np.copy(tracks_3D)
    tracks_imCoords_edited = np.copy(tracks_imCoords)
    with h5py.File(track_save_path, 'w') as hf:
        hf.create_dataset('tracks_3D', data=tracks_3D_edited)
        hf.create_dataset('tracks_imCoords', data=tracks_imCoords_edited)
        hf.create_dataset('missing_bps', data=missing_bps)
        hf.create_dataset('annotated_frames', data=annotated_frames)
else:
    print('found labels file')
    with h5py.File(track_save_path, 'r') as hf:
        tracks_3D_edited = hf['tracks_3D'][:]
        tracks_imCoords_edited = hf['tracks_imCoords'][:]
        annotated_frames = hf['annotated_frames'][:]
        missing_bps = hf['missing_bps'][:]

# set the path to the calibration folder and make the object
calibrationFolderPath = os.path.join(experiment_path, calibrationFolder)
cal = Calibration(calibrationFolderPath)

# make a splitdata manager
splitman = SplitdataManager(experiment_path)

# gather the names of the splitdata folders with corresponding indices
splitdata_names = [triplet[0].split('/')[-1] for triplet in splitman.splitdata_paths]
splitdata_idxs = list(range(len(splitdata_names)))
splitdata_name_to_idx_dict = dict(zip(splitdata_names, splitdata_idxs))
splitdata_name_to_paths = dict(zip(splitdata_names, splitman.splitdata_paths))

# from here on the frame count comes from the movies; the sliders are bounded
# by it, so warn when it disagrees with the length of the tracking arrays
totalnumFrames = splitman.start_stop_frames_for_splitdata[-1, 1]
if totalnumFrames != tracks_3D_raw.shape[0]:
    warnings.warn('Movies report {0} frames but the tracks file holds {1}; '
                  'frame indices may run past the end of the tracking arrays.'
                  .format(totalnumFrames, tracks_3D_raw.shape[0]))

found labels file


# Sidecar Dashboard

In [4]:
### Create Sidecar Dashboard containing:
### Widgets that hold frame specific values and
### create functionality for the labeling GUI.
### Functionalities: 
###    - Changing frame splitdata-wise (splitdata dropdown)
###    - Changing frame index-wise (Global fIdx, Frame slider)
###    - Go to next frame with missing bodypoint (Next NaN)
###    - Change frame specific focus (camIdx, bpIdx, fishIdx)
###    - Lock camIdx (Lock idx)
###    - Save process (Click me to save)
###    - Tick boxes for marking bodypoints as missing and accepting frames as training data
###    - Find closest related bodypoint data for missing bodypoints (Repair Frame).

## Widget Initialisation

In [5]:
### Create the widgets.


# dropdown menu to choose a splitdata by name (name -> index mapping)
splitdata_chooser = widgets.Dropdown(
    options=splitdata_name_to_idx_dict,
    value=0,
)

# read-only box holding the index of the current splitdata
splitdata_idx = widgets.IntText(
    value=0,
    description='splitdata_idx',
    disabled=True,
    layout=Layout(width='50%', height='50px'),
)

# text widgets holding the movie file path of each camera view (xz, xy, yz)
xzMovPath, xyMovPath, yzMovPath = [
    widgets.Text(value=splitman.splitdata_paths[splitdata_idx.value][view] + '.mp4')
    for view in range(3)
]

# global frame index (editable) and local frame index (read-only)
global_index = widgets.IntText(
    value=0,
    description='Global fIdx',
    disabled=False,
    layout=Layout(width='70%', height='50px'),
)
local_index = widgets.IntText(
    value=0,
    description='Local fIdx',
    disabled=True,
    layout=Layout(width='70%', height='50px'),
)

# slider over all frames, kept in sync with the "Global fIdx" box
global_index_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=totalnumFrames-1,
    step=1,
    continuous_update=False,
    description='Frame',
    layout=Layout(width='90%', height='50px'),
)
gf_link = widgets.link((global_index_slider, 'value'), (global_index, 'value'))

# button to find next frame with missing data
next_widget = widgets.Button(description="Next NaN")

# --- currently selected camera / bodypoint / fish; -1 means "nothing selected" --- #
camIdx_widget = widgets.BoundedIntText(description='camIdx', value=-1, min=-1, max=numCams-1)
bpIdx_widget = widgets.BoundedIntText(description='bpIdx', value=-1, min=-1, max=numBodyPoints-1)
fishIdx_widget = widgets.BoundedIntText(description='fishIdx', value=-1, min=-1, max=numFish-1)

# the fixed camera view for plotting points
lock_camIdx_widget = widgets.RadioButtons(
    description='Lock idx:',
    options=[0, 1, 2],
    value=1,
    disabled=False,
    layout={'width': 'max-content'},
)

# save button
save_widget = widgets.Button(description="Click Me To Save")

# checkbox to mark the selected bodypoint as missing
missing_bp_widget = widgets.Checkbox(
    description='BodyPoint missing',
    value=False,
    disabled=False,
    indent=False,
)

# checkbox to mark the current frame as annotated
annotated_frame_widget = widgets.Checkbox(
    description='Frame has been verified',
    value=False,
    disabled=False,
    indent=False,
)

# button to replace missing data
repair_widget = widgets.Button(description='Repair Frame')

## Widget Functions

In [6]:
### Functionality functions of the dashboard widgets.


def handle_splitdata_name_change(change):
    ''' Bring the dashboard in line with the splitdata picked in the dropdown.

    change.new is the index of the chosen splitdata. The splitdata index box,
    the three movie-path widgets and the global frame slider are updated; the
    current local frame index is kept and re-based onto the new splitdata.
    '''
    splitdata_idx.value = change.new
    # read the index back from the widget rather than from the event
    idx = splitdata_idx.value

    # movie paths, one per camera view
    for movie_widget, view in ((xzMovPath, 0), (xyMovPath, 1), (yzMovPath, 2)):
        movie_widget.value = splitman.splitdata_paths[idx][view] + '.mp4'

    # global index = first frame of the new splitdata + unchanged local index
    first_frame = splitman.start_stop_frames_for_splitdata[idx][0]
    global_index_slider.value = first_frame + local_index.value
    return


def handle_splitdata_idx_box_change(change):
    ''' Keep "splitdata_idx" inside the range of available splitdatas.

    The new value is wrapped around cyclically: -1 maps to the last splitdata
    and len(splitman.splitdata_paths) maps to 0, as before. Any other
    out-of-range value is now wrapped as well instead of being kept and
    failing on the next lookup into splitman.splitdata_paths.
    '''
    numSplitdatas = len(splitman.splitdata_paths)
    if numSplitdatas == 0:
        # nothing to wrap onto; leave the widget untouched
        return
    # Python's % always yields a value in [0, numSplitdatas)
    splitdata_idx.value = change.new % numSplitdatas
    return


def handle_local_index_change(change):
    ''' Wrap "Local fIdx" into the frame range of the current splitdata.

    The value currently held by the local_index widget is reduced modulo the
    number of frames (stop - start) of the splitdata selected in splitdata_idx.
    The change event itself is not inspected.
    '''
    bounds = splitman.start_stop_frames_for_splitdata[splitdata_idx.value]
    frames_in_splitdata = bounds[1] - bounds[0]
    local_index.value = np.mod(local_index.value, frames_in_splitdata)
    return


def handle_global_index_slider_change(change):
    ''' Propagate a new global frame index to the dependent widgets.

    Updates, in this order: local frame index, splitdata index, splitdata
    dropdown, the "BodyPoint missing" and "Frame has been verified"
    checkboxes, and finally the slider below the timeseries plot.
    The slider value is re-read where needed because the assignments above it
    can themselves trigger observers.
    '''
    # which splitdata holds this frame, and where inside it
    spIdx, locIdx = splitman.return_splitdata_folder_and_local_idx_for_global_frameIdx(
        global_index_slider.value,
        return_splitdataIdx=True,
    )
    local_index.value = locIdx
    splitdata_idx.value = spIdx
    splitdata_chooser.label = splitdata_names[spIdx]
    splitdata_chooser.index = spIdx

    # checkbox states stored for the new frame
    flag_position = (camIdx_widget.value, global_index_slider.value,
                     fishIdx_widget.value, bpIdx_widget.value)
    missing_bp_widget.value = bool(missing_bps[flag_position])
    annotated_frame_widget.value = bool(annotated_frames[change.new])

    # keep the plot slider on the same frame
    graph_index_slider.value = global_index_slider.value
    return


def handle_next_press(p):
    ''' Jump to the next frame with missing data when next_widget is pressed.

    Starting just after the current frame, every other frame is visited once
    in increasing order, wrapping around at the last frame. The slider is
    moved to the first frame whose image coordinates contain a NaN. If no
    other frame has a NaN the slider is left where it is.

    The search is bounded to totalnumFrames - 1 steps. The earlier version
    wrapped the index after testing it against the start frame, so it never
    terminated when started on frame 0 with no NaN in the data.
    '''
    start_frame = int(global_index_slider.value)
    num_frames = int(totalnumFrames)

    for offset in range(1, num_frames):
        candidate = (start_frame + offset) % num_frames
        # any NaN over cameras, fish, bodypoints and coordinates of this frame
        if np.any(np.isnan(tracks_imCoords_edited[:, candidate])):
            global_index_slider.value = candidate
            return
    # no other frame with missing data
    return


# handlers that refresh the "BodyPoint missing" checkbox when the selection changes
def _show_missing_bp_flag(camIdx, fishIdx, bpIdx):
    ''' Show the stored missing-flag of (camIdx, fishIdx, bpIdx) on the current frame. '''
    stored_flag = missing_bps[camIdx, global_index_slider.value, fishIdx, bpIdx]
    missing_bp_widget.value = bool(stored_flag)


def update_missing_bp_widget_on_camIdx_change(change):
    ''' Refresh the checkbox for the newly selected camera (change.new). '''
    _show_missing_bp_flag(change.new, fishIdx_widget.value, bpIdx_widget.value)
    return


def update_missing_bp_widget_on_bpIdx_change(change):
    ''' Refresh the checkbox for the newly selected bodypoint (change.new). '''
    _show_missing_bp_flag(camIdx_widget.value, fishIdx_widget.value, change.new)
    return


def update_missing_bp_widget_on_fishIdx_change(change):
    ''' Refresh the checkbox for the newly selected fish (change.new). '''
    _show_missing_bp_flag(camIdx_widget.value, change.new, bpIdx_widget.value)
    return


def handle_save_press(p):
    ''' Write the state of the current frame into the arrays and save them to disk.

    The image-coordinate and 3D value widgets of the current frame are copied
    into tracks_imCoords_edited / tracks_3D_edited, the two checkboxes are
    copied into missing_bps / annotated_frames, and all four arrays are then
    written into the existing datasets of the labels file (track_save_path).

    The missing-bodypoint flag is only stored when a camera, fish and
    bodypoint are selected. With the "nothing selected" value -1 the earlier
    version indexed missing_bps with -1 and so overwrote the flag of the last
    camera / fish / bodypoint.
    '''
    frame_idx = global_index_slider.value

    # image coordinates shown in the widgets -> array
    for camIdx, fishIdx, bpIdx, imCoordIdx in np.ndindex(numCams, numFish, numBodyPoints, 2):
        val = frame_imCoords_widget_list[camIdx][fishIdx][bpIdx][imCoordIdx].value
        tracks_imCoords_edited[camIdx, frame_idx, fishIdx, bpIdx, imCoordIdx] = val

    # 3D values shown in the widgets -> array
    for fishIdx, bpIdx, dimIdx in np.ndindex(numFish, numBodyPoints, 3):
        val = frame_3D_widget_list[fishIdx][bpIdx][dimIdx].value
        tracks_3D_edited[frame_idx, fishIdx, bpIdx, dimIdx] = val

    # missing bodypoints (skipped while any of the three selectors is at -1)
    sel_cam, sel_fish, sel_bp = camIdx_widget.value, fishIdx_widget.value, bpIdx_widget.value
    if -1 not in (sel_cam, sel_fish, sel_bp):
        missing_bps[sel_cam, frame_idx, sel_fish, sel_bp] = missing_bp_widget.value

    # frame annotations
    annotated_frames[frame_idx] = annotated_frame_widget.value

    # Now save: overwrite the contents of the existing datasets in place
    with h5py.File(track_save_path, 'a') as hf:
        hf['tracks_3D'][:] = tracks_3D_edited
        hf['tracks_imCoords'][:] = tracks_imCoords_edited
        hf['annotated_frames'][:] = annotated_frames
        hf['missing_bps'][:] = missing_bps
    return


def handle_missing_bp_press(change):
    ''' Store the state of the "BodyPoint missing" checkbox for the current selection.

    change.new is written to missing_bps at the selected camera, fish and
    bodypoint on the current frame. Nothing is stored while any of the three
    selectors holds -1 ("nothing selected").

    The guard reads the selector widgets. The earlier version tested the
    module-level names camIdx / bpIdx / fishIdx, which are unrelated loop
    variables, so the guard did not reflect the actual selection.
    '''
    sel_cam = camIdx_widget.value
    sel_fish = fishIdx_widget.value
    sel_bp = bpIdx_widget.value
    if sel_cam == -1 or sel_bp == -1 or sel_fish == -1:
        return
    missing_bps[sel_cam, global_index_slider.value, sel_fish, sel_bp] = change.new
    return


def handle_annotation_frame_press(change):
    ''' Store the "Frame has been verified" checkbox state for the current frame. '''
    current_frame = global_index_slider.value
    annotated_frames[current_frame] = change.new
    return


def _nearest_complete_frame(fishIdx, start_frame, num_frames):
    ''' Return the frame closest to start_frame with no NaN image coordinate for fishIdx.

    Frames are visited at growing distance from start_frame, the earlier frame
    before the later one, wrapping around both ends of the recording. Each
    frame is visited at most once, so the search always terminates.
    start_frame itself is returned when no other frame qualifies.
    '''
    for distance in range(1, num_frames // 2 + 1):
        backward = (start_frame - distance) % num_frames
        forward = (start_frame + distance) % num_frames
        for candidate in (backward, forward):
            if not np.any(np.isnan(tracks_imCoords_edited[:, candidate, fishIdx, :, :])):
                return candidate
    return start_frame


def handle_repair_press(change):
    ''' Find closest bodypoint data to repair missing data on current frame.

    For every fish, the nearest frame with complete image coordinates is
    located, and each image-coordinate widget that currently shows NaN is
    filled with the value stored for that frame. The figures are redrawn
    afterwards through update_crop_figures.

    The frame search is bounded. The earlier outward search compared its two
    running indices before wrapping them, so for some start frames and frame
    counts it never terminated when a fish had no complete frame.
    '''
    start_frame = int(global_index_slider.value)
    num_frames = int(totalnumFrames)

    for fishIdx in range(numFish):
        found_frame = _nearest_complete_frame(fishIdx, start_frame, num_frames)

        # only widgets that show NaN are overwritten; edited values are kept
        for camIdx in range(numCams):
            for bpIdx in range(numBodyPoints):
                for imCoordIdx in range(2):
                    coord_widget = frame_imCoords_widget_list[camIdx][fishIdx][bpIdx][imCoordIdx]
                    if np.isnan(coord_widget.value):
                        coord_widget.value = float(
                            tracks_imCoords_edited[camIdx, found_frame, fishIdx, bpIdx, imCoordIdx])

    # update figures and data
    update_crop_figures(0)
    return

## Function-Widget Linking & Sidecar Creation

In [7]:
### Add functionality to widgets and create dashboard Vbox.


# value observers: (widget, handler run when its value changes)
for observed_widget, value_handler in (
        (splitdata_chooser, handle_splitdata_name_change),
        (splitdata_idx, handle_splitdata_idx_box_change),
        (local_index, handle_local_index_change),
        (global_index_slider, handle_global_index_slider_change),
        (camIdx_widget, update_missing_bp_widget_on_camIdx_change),
        (bpIdx_widget, update_missing_bp_widget_on_bpIdx_change),
        (fishIdx_widget, update_missing_bp_widget_on_fishIdx_change),
        (missing_bp_widget, handle_missing_bp_press),
        (annotated_frame_widget, handle_annotation_frame_press),
):
    observed_widget.observe(value_handler, names='value')

# button presses: (button, handler run on click)
for pressed_button, click_handler in (
        (next_widget, handle_next_press),
        (save_widget, handle_save_press),
        (repair_widget, handle_repair_press),
):
    pressed_button.on_click(click_handler)


# dashboard column, top to bottom
dashboard_children = [
    splitdata_chooser,
    splitdata_idx,
    global_index,
    local_index,
    global_index_slider,
    next_widget,
    camIdx_widget,
    bpIdx_widget,
    fishIdx_widget,
    lock_camIdx_widget,
    save_widget,
    missing_bp_widget,
    annotated_frame_widget,
    repair_widget,
]
splitdata_details_box = VBox(dashboard_children)

# show the dashboard in a sidecar panel
sc = Sidecar(title='Sidecar Output')
with sc:
    display(splitdata_details_box)

# Visual Widgets

In [8]:
### In this section we create the following widgets:
###    - Interactive line-plot of raw data 
###    - 3d and imagecoordinate values of tracking results
###    - Images of 3 camera views including trackresults
### And add interactivity to these widgets with update and change functions.

## Interactive Line-Plot of Raw Data

In [9]:
### In this section we make a widget that contains a plot of the position timeseries, 
### this can be used for identifying by eye places where the data is poor or missing. 
### The POI of the specific global index range can be changed by an interactive slidebar. 

In [10]:
plt.ioff()

# two stacked axes sharing the frame axis: heads on top, tails below
fig, (ax1, ax2) = plt.subplots(2, sharex=True, gridspec_kw={'hspace': 0.2})
# NOTE: a former `fig.figsize = (18,10)` assignment was dropped here: it only
# created an unused attribute and never resized the figure (the canvas layout
# below controls the displayed size).
#fig.legend(handles, labels, loc='upper center')
fig.suptitle('Trajectory Data Timeseries')
fig.canvas.header_visible = False
fig.canvas.layout.min_height = '400px'
fig.canvas.layout.width = '100%'
ax1.grid(which='both')
ax2.grid(which='both')
ax1.ticklabel_format(useOffset=False)
ax2.ticklabel_format(useOffset=False)
plot_frame_width = 50   # half-width, in frames, of the visible window

#plt.xticks(np.arange(0, totalnumFrames, 100))

# ---- define the info for each plot ----- #
# plot 1 - heads
bpIdx = 0
p1_data = np.copy(tracks_3D_edited[:, :, bpIdx, :])

# plot 2 - bodypoint 2, labelled as the tail on the axes below
bpIdx = 2
p2_data = np.copy(tracks_3D_edited[:, :, bpIdx, :])

# fish 1 in reds, fish 2 in greens; one shade per coordinate
colors = ['lightcoral', 'brown', 'maroon',
          'forestgreen', 'darkgreen', 'limegreen']


# ----- define the line plots ------ #
lines1_x1 = ax1.plot(p1_data[:, 0, 0], label='fish1_head_x', color=colors[0])
lines1_y1 = ax1.plot(p1_data[:, 0, 1], label='fish1_head_y', color=colors[1])
lines1_z1 = ax1.plot(p1_data[:, 0, 2], label='fish1_head_z', color=colors[2])

lines1_x2 = ax1.plot(p1_data[:, 1, 0], label='fish2_head_x', color=colors[3])
lines1_y2 = ax1.plot(p1_data[:, 1, 1], label='fish2_head_y', color=colors[4])
lines1_z2 = ax1.plot(p1_data[:, 1, 2], label='fish2_head_z', color=colors[5])

lines2_x1 = ax2.plot(p2_data[:, 0, 0], label='fish1_x', color=colors[0])
lines2_y1 = ax2.plot(p2_data[:, 0, 1], label='fish1_y', color=colors[1])
lines2_z1 = ax2.plot(p2_data[:, 0, 2], label='fish1_z', color=colors[2])

lines2_x2 = ax2.plot(p2_data[:, 1, 0], label='fish2_x', color=colors[3])
lines2_y2 = ax2.plot(p2_data[:, 1, 1], label='fish2_y', color=colors[4])
lines2_z2 = ax2.plot(p2_data[:, 1, 2], label='fish2_z', color=colors[5])


ax1.set_title('Heads: frame = {0}'.format(global_index_slider.value))
ax2.set_title('Tails')   # was 'Tails)' with a stray parenthesis

ax1.set_ylim(-1, 43)
ax2.set_ylim(-1, 43)

ax1.set_yticks([i for i in range(0,41,5)], minor=True)
ax2.set_yticks([i for i in range(0,41,5)], minor=True)

ax1.set_ylabel('Head Position (cm)')
ax2.set_ylabel('Tail Position (cm)')

# use any axis here to get the labels
handles, labels = ax2.get_legend_handles_labels()
fig.legend(handles, labels, loc='center right')

# initial visible window: 2*plot_frame_width frames around the current frame,
# pinned to the start / end of the recording when the frame is close to either.
# (The last branch used to pass the slider widget itself instead of its
# .value to ax2.set_xlim, which raises a TypeError.)
if global_index_slider.value < plot_frame_width:
    ax1.set_xlim(0, 2*plot_frame_width)
    ax2.set_xlim(0, 2*plot_frame_width)
elif totalnumFrames - global_index_slider.value < plot_frame_width:
    ax1.set_xlim(totalnumFrames-2*plot_frame_width, totalnumFrames)
    ax2.set_xlim(totalnumFrames-2*plot_frame_width, totalnumFrames)
else:
    ax1.set_xlim(global_index_slider.value-plot_frame_width, global_index_slider.value+plot_frame_width)
    ax2.set_xlim(global_index_slider.value-plot_frame_width, global_index_slider.value+plot_frame_width)


def update_lines(change):
    ''' Re-centre the timeseries plot on frame change.new and redraw it.

    The visible window spans 2*plot_frame_width frames. It is centred on the
    new frame, except near the start or end of the recording, where it is
    pinned to [0, 2*plot_frame_width] or to the last 2*plot_frame_width frames.
    The head-axis title shows the frame number (the stray ')' that used to
    trail the number has been removed from the format string).
    '''
    frame_idx = change.new
    ax1.set_title('Heads: frame = {0}'.format(frame_idx))
    ax2.set_title('Tails')

    if frame_idx < plot_frame_width:
        window = (0, 2*plot_frame_width)
    elif totalnumFrames - frame_idx < plot_frame_width:
        window = (totalnumFrames - 2*plot_frame_width, totalnumFrames)
    else:
        window = (frame_idx - plot_frame_width, frame_idx + plot_frame_width)

    for axis in (ax1, ax2):
        axis.set_xlim(*window)

    fig.canvas.draw()
    fig.canvas.flush_events()

    
# slider that pans the timeseries plot; every value change is passed to update_lines
graph_index_slider = widgets.IntSlider(
    description='Frame',
    value=0,
    min=0,
    max=totalnumFrames-1,
    step=1,
    continuous_update=True,
    layout=Layout(width='90%', height='50px'),
)
graph_index_slider.observe(update_lines, names='value')

# figure canvas stacked on top of its slider
layout = widgets.Layout(width='100%')
plot_widget = VBox(children=[fig.canvas, graph_index_slider], layout=layout)

## Tracking Result Widgets

In [11]:
### In this section we make a display of widgets for holding the 3D positions and 
### image coordinates

In [12]:
### Grab the image coordinates and 3D coordinates for this frame


# layout parameters
number_layout = widgets.Layout(width='70px')
text_layout = widgets.Layout(width='120px')

# default positions of bodypoints in cropped coords
def_cropped_pos = [np.array([20, 20]), np.array([80, 80]), np.array([140, 140])]
p1_init, p2_init, p3_init = def_cropped_pos

# frame_imCoords_widget_list[camIdx][fishIdx][bpIdx][imCoordIdx] is a FloatText
# showing that image coordinate (rounded to 2 decimals) on the current frame
frame_imCoords_widget_list = []
for camIdx in range(numCams):
    cam_level = []
    for fishIdx in range(numFish):
        fish_level = []
        for bpIdx in range(numBodyPoints):
            coord_level = []
            for imCoordIdx in range(2):
                stored_val = tracks_imCoords_edited[camIdx, global_index_slider.value, fishIdx, bpIdx, imCoordIdx]
                coord_level.append(widgets.FloatText(value=np.round(stored_val, 2), layout=number_layout))
            fish_level.append(coord_level)
        cam_level.append(fish_level)
    frame_imCoords_widget_list.append(cam_level)

# frame_3D_widget_list[fishIdx][bpIdx][dimIdx] is a FloatText showing that
# 3D coordinate (rounded to 2 decimals) on the current frame
frame_3D_widget_list = []
for fishIdx in range(numFish):
    fish_level = []
    for bpIdx in range(numBodyPoints):
        dim_level = []
        for dimIdx in range(3):
            stored_val = tracks_3D_edited[global_index_slider.value, fishIdx, bpIdx, dimIdx]
            dim_level.append(widgets.FloatText(value=np.round(stored_val, 2), layout=number_layout))
        fish_level.append(dim_level)
    frame_3D_widget_list.append(fish_level)



In [13]:
### Make widgets to visualize the 3D positions
### (?) Implement general function for experiments with different number of fish


# layout parameters
number_layout = widgets.Layout(width='80px')
text_layout = widgets.Layout(width='120px')

# fish 1: a title followed by one row of X/Y/Z boxes per bodypoint
fishIdx = 0
f1_bp_box_list = [
    widgets.HBox(frame_3D_widget_list[fishIdx][bp][:],
                 layout=widgets.Layout(padding=('0px 30px 0 0')))
    for bp in range(numBodyPoints)
]
f1_3D_wid_title = widgets.Text('Fish1 3D bps', layout=text_layout)
f1_3D_wid = widgets.VBox([f1_3D_wid_title] + f1_bp_box_list)

# fish 2: same structure, narrower right padding
fishIdx = 1
f2_bp_box_list = [
    widgets.HBox(frame_3D_widget_list[fishIdx][bp][:],
                 layout=widgets.Layout(padding=('0px 10px 0 0')))
    for bp in range(numBodyPoints)
]
f2_3D_wid_title = widgets.Text('Fish2 3D bps', layout=text_layout)
f2_3D_wid = widgets.VBox([f2_3D_wid_title] + f2_bp_box_list)

# Final widget: both fish side by side
tracks_3D_frame_widget = widgets.HBox(
    [f1_3D_wid, f2_3D_wid],
    layout=widgets.Layout(padding=('0px 0px 30px 200px')),
)

In [14]:
### Make widgets to visualize the image coordinates
### (?) Implement general function for experiments with different number of fish


# layout parameters
number_layout = widgets.Layout(width='80px')
text_layout = widgets.Layout(width='120px')
cam_title_layout = widgets.Layout(width='160px')
title_layout = widgets.Layout(width='480px')
camNames = ['xz', 'xy', 'yz']

# fish 1: per camera a titled column with one row of (x, y) boxes per bodypoint
fishIdx = 0
cam_bp_imcoord_wid_list = []
for camIdx in range(3):
    cam_bp_list = [widgets.HBox(frame_imCoords_widget_list[camIdx][fishIdx][bp][:])
                   for bp in range(numBodyPoints)]
    cam_widget_title = widgets.Text(camNames[camIdx], layout=cam_title_layout)
    cam_widget = widgets.VBox([cam_widget_title] + cam_bp_list)
    cam_bp_imcoord_wid_list.append(cam_widget)
f1_imcoord_wid_title = widgets.Text('Fish1 bodypoint image coordinates', layout=title_layout)
f1_imcoord_wid_data = widgets.HBox(cam_bp_imcoord_wid_list)
f1_imcoord_wid = widgets.VBox([f1_imcoord_wid_title, f1_imcoord_wid_data])

# fish 2: same structure
fishIdx = 1
cam_bp_imcoord_wid_list = []
for camIdx in range(3):
    cam_bp_list = [widgets.HBox(frame_imCoords_widget_list[camIdx][fishIdx][bp][:])
                   for bp in range(numBodyPoints)]
    cam_widget_title = widgets.Text(camNames[camIdx], layout=cam_title_layout)
    cam_widget = widgets.VBox([cam_widget_title] + cam_bp_list)
    cam_bp_imcoord_wid_list.append(cam_widget)
f2_imcoord_wid_title = widgets.Text('Fish2 bodypoint image coordinates', layout=title_layout)
f2_imcoord_wid_data = widgets.HBox(cam_bp_imcoord_wid_list)
f2_imcoord_wid = widgets.VBox([f2_imcoord_wid_title, f2_imcoord_wid_data])

# final widget: both fish side by side
tracks_imCoord_frame_widget = widgets.HBox([f1_imcoord_wid, f2_imcoord_wid])

In [15]:
### Update function for data when changing frame


def update_data(change):
    ''' Swap the value widgets from the previous frame to the new one.

    First every widget value is written back to the arrays at frame
    change.old, then every widget is loaded with the array value of frame
    change.new, rounded to 2 decimals.
    NOTE(review): since the widgets hold rounded values, the write-back also
    stores rounded numbers for coordinates that were never edited.
    '''
    imCoord_slots = (numCams, numFish, numBodyPoints, 2)
    xyz_slots = (numFish, numBodyPoints, 3)

    # ---- write the widgets back to the frame we are leaving ---- #
    for cam, fish, bp, coord in np.ndindex(*imCoord_slots):
        tracks_imCoords_edited[cam, change.old, fish, bp, coord] = \
            frame_imCoords_widget_list[cam][fish][bp][coord].value
    for fish, bp, dim in np.ndindex(*xyz_slots):
        tracks_3D_edited[change.old, fish, bp, dim] = \
            frame_3D_widget_list[fish][bp][dim].value

    # ---- load the widgets with the frame we are entering ---- #
    for cam, fish, bp, coord in np.ndindex(*imCoord_slots):
        frame_imCoords_widget_list[cam][fish][bp][coord].value = \
            np.round(tracks_imCoords_edited[cam, change.new, fish, bp, coord], 2)
    for fish, bp, dim in np.ndindex(*xyz_slots):
        frame_3D_widget_list[fish][bp][dim].value = \
            np.round(tracks_3D_edited[change.new, fish, bp, dim], 2)
    
    
# run update_data on every value change of the global frame slider, so the
# value widgets follow the displayed frame
global_index_slider.observe(update_data, names='value')

In [16]:
# make the final coordinate widget: the 3D-position boxes stacked on top of
# the image-coordinate boxes
coordinate_widget = widgets.VBox([tracks_3D_frame_widget, tracks_imCoord_frame_widget])


## Loading Images

In [17]:
### In this section we make a widget to hold the images of the fish tank.
### And we add the following functionalities to the widget:
###    - move bodypoint by clicking near a bodypoint or dragging a bodypoint
###    - move selected bodypoint by pressing ctrl+arrow keys
###    - zoom into figure by using (shift+) scrollwheel
###    - select previous or next bodypoint by pressing (shift+)tab

In [18]:
### Initial creation of figures


# colours passed to the cv2 drawing calls, and the skeleton definition
no_id_color = (255,255,255)           # white
fish_colors = [(255,0,0), (0,255,0)]
edges = [[0,1], [1,2]]                # how bodypoints are joined


ROIS = {
    "xz": [182, 278, 962, 1011],
    "xy": [157, 281, 868, 999],
    "yz": [186, 290, 958, 1015],
}


# ----- Make the figure ------#
plt.ioff()

fig1, (f1_ax1, f1_ax2, f1_ax3) = plt.subplots(1, 3, figsize=(9,3))

f_axs = [f1_ax1, f1_ax2, f1_ax3]

# one panel per camera view, without axis ticks
for fax, view_title in zip(f_axs, ('XZ', 'XY', 'YZ')):
    fax.set_title(view_title)
    fax.get_xaxis().set_visible(False)
    fax.get_yaxis().set_visible(False)

fig1.canvas.header_visible = False
fig1.canvas.layout.min_height = '400px'
fig1.canvas.layout.width = '100%'

# let the three panels fill the whole figure
fig1.subplots_adjust(top=1, bottom=0, right=1, left=0,
                     hspace=0, wspace=0)


#  -------------- get the initial frame data -------------------#
video_paths = [xzMovPath, xyMovPath, yzMovPath]
caps = [cv2.VideoCapture(video_path.value) for video_path in video_paths]
frames = []
try:
    for view_name, cap in zip(('xz', 'xy', 'yz'), caps):
        # a capture that failed to open is reported like a failed read,
        # instead of leaving `frames` undefined for the code below
        if cap.isOpened():
            cap.set(cv2.CAP_PROP_POS_FRAMES, local_index.value)  # go to current frame
            ret, frame = cap.read()
        else:
            ret, frame = False, None
        if not ret:
            raise TypeError('{0} movie for splitdata_idx={1} not opening'.format(view_name, splitdata_idx.value))
        frames.append(frame)
finally:
    # close everything, also when one of the movies could not be read
    for cap in caps:
        cap.release()

# draw the info
for camIdx in range(3):
    frame = frames[camIdx]
    for fishIdx in range(numFish):
        # one colour per fish, used for its points and its lines
        color = tuple(int(channel) for channel in fish_colors[fishIdx])

        # draw the bodypoints; remember their pixel positions for the lines
        fish_points = []
        for bpIdx in range(numBodyPoints):
            imCoord = np.array([frame_imCoords_widget_list[camIdx][fishIdx][bpIdx][0].value,
                                frame_imCoords_widget_list[camIdx][fishIdx][bpIdx][1].value])
            # skip points with any missing coordinate: int(NaN) would raise
            if np.any(np.isnan(imCoord)):
                fish_points.append(None)
                continue
            point = (int(imCoord[0]), int(imCoord[1]))
            fish_points.append(point)
            cv2.circle(frame, point, radius=2, color=color, thickness=-1)

        # draw the lines between bodypoints that are both present
        for pt1_bpIdx, pt2_bpIdx in edges:
            pt1 = fish_points[pt1_bpIdx]
            pt2 = fish_points[pt2_bpIdx]
            if pt1 is not None and pt2 is not None:
                cv2.line(frame, pt1=pt1, pt2=pt2, color=color, thickness=1, lineType=cv2.LINE_AA)


# ------------ plot the images -------------#

panel1 = f1_ax1.imshow(frames[0])
panel2 = f1_ax2.imshow(frames[1])
panel3 = f1_ax3.imshow(frames[2])


# Set the plot limits so we only see the interior region of the cage
# f1_ax1.set_xlim(ROIS['xz'][1], ROIS['xz'][3])
# f1_ax1.set_ylim(ROIS['xz'][0], ROIS['xz'][2])

# f1_ax2.set_xlim(ROIS['xy'][0], ROIS['xy'][2])
# f1_ax2.set_ylim(ROIS['xy'][1], ROIS['xy'][3])

# f1_ax3.set_xlim(ROIS['yz'][0], ROIS['yz'][2])
# f1_ax3.set_ylim(ROIS['yz'][1], ROIS['yz'][3])

In [19]:
### Update function for figure
### (?) This could maybe be (rewritten a little and) called 
###     in other functions to do the drawing


def update_crop_figures(change):
    '''Reload the current frame of every camera, overlay the skeletons and refresh the figure.

    Used as a widget observer callback. The frame index is read from
    ``local_index.value`` and the bodypoint positions from
    ``frame_imCoords_widget_list[camIdx][fishIdx][bpIdx][0 or 1].value``.

    Parameters
    ----------
    change : object
        Change notification handed over by the observer; it is not used.

    Raises
    ------
    TypeError
        If a movie cannot be opened or its current frame cannot be read.
        All captures are released before the error propagates.
    '''
    # order of the entries of video_paths, as named in the error messages
    view_names = ('xz', 'xy', 'yz')

    # --------------- Grab the current frame of each camera -----------------#
    caps = [cv2.VideoCapture(video_path.value) for video_path in video_paths]
    frames = []
    try:
        for view_name, cap in zip(view_names, caps):
            cap.set(cv2.CAP_PROP_POS_FRAMES, local_index.value)  # go to current frame
            ret, frame = cap.read() if cap.isOpened() else (False, None)
            if not ret:
                raise TypeError('{0} movie for splitdata_idx={1} not opening'.format(view_name, splitdata_idx.value))
            frames.append(frame)
    finally:
        # release the caps, also when one of the reads failed
        for cap in caps:
            cap.release()

    # --------------- Draw the skeletons --------------------#
    for camIdx, frame in enumerate(frames):
        for fishIdx in range(numFish):
            fcolor = fish_colors[fishIdx]
            color = (int(fcolor[0]), int(fcolor[1]), int(fcolor[2]))

            # draw the bodypoints; remember the integer pixel of each one
            # (None when a coordinate is missing) for the edges below
            pixels = []
            for bpIdx in range(numBodyPoints):
                xy_widgets = frame_imCoords_widget_list[camIdx][fishIdx][bpIdx]
                imCoord = np.array([xy_widgets[0].value, xy_widgets[1].value], dtype=float)
                # int() cannot convert NaN/inf, so skip points with any bad component
                if not np.all(np.isfinite(imCoord)):
                    pixels.append(None)
                    continue
                pixels.append((int(imCoord[0]), int(imCoord[1])))
                cv2.circle(frame, pixels[-1], radius=2, color=color, thickness=-1)

            # draw the lines between bodypoints that are both present
            for edge in edges:
                pt1, pt2 = pixels[edge[0]], pixels[edge[1]]
                if pt1 is not None and pt2 is not None:
                    cv2.line(frame, pt1=pt1, pt2=pt2, color=color, thickness=1, lineType=cv2.LINE_AA)

    # now update the figures
    for panel, frame in zip((panel1, panel2, panel3), frames):
        panel.set_data(frame)

    fig1.canvas.draw()
    fig1.canvas.flush_events()
    return

# Call update_crop_figures every time the 'value' trait of global_index_slider changes.
global_index_slider.observe(update_crop_figures, names='value')

In [20]:
### Functionality functions


def click_down_event_handler(event):
    '''Mouse-press callback: record the clicked camera view and select the nearest bodypoint.

    Stores the index of the clicked axes (its position in ``f_axs``) in
    ``camIdx_widget`` and then lets ``find_idx_of_closest_point`` pick the
    bodypoint, using a tolerance of 20 (in the units of the event's data
    coordinates).

    Parameters
    ----------
    event : matplotlib mouse event
        Only ``inaxes``, ``xdata`` and ``ydata`` are read.

    A press that is not inside one of the camera axes is ignored: it carries
    no usable data coordinates, and the previous selection is left untouched.
    '''
    if event.xdata is None or event.ydata is None:
        return

    # find the camIdx of the image we clicked on
    for camIdx in range(numCams):
        if event.inaxes == f_axs[camIdx]:
            camIdx_widget.value = camIdx
            break
    else:
        # the click was in some other axes; its coordinates do not belong to a camera view
        return

    # find the closest image coordinate, within the threshold
    find_idx_of_closest_point(event.xdata, event.ydata, tolerance=20)

    
def find_idx_of_closest_point(x, y, tolerance=15):
    '''Select the bodypoint closest to the position (x, y) in the active camera view.

    The camera is taken from ``camIdx_widget.value``. Nothing is returned;
    the result is stored in ``fishIdx_widget`` and ``bpIdx_widget``. Both are
    set to -1 when no bodypoint lies closer than ``tolerance``, which includes
    the case where every bodypoint of the view is NaN.

    Parameters
    ----------
    x, y : float
        Position to match, in the same coordinates as the widget values.
    tolerance : float
        Maximum distance at which a bodypoint still counts as hit.
    '''
    camIdx = camIdx_widget.value

    # distance from the click to every bodypoint of every fish; NaN where missing
    distances = np.full((numFish, numBodyPoints), np.nan)
    for fishIdx in range(numFish):
        for bpIdx in range(numBodyPoints):
            xp = frame_imCoords_widget_list[camIdx][fishIdx][bpIdx][0].value
            yp = frame_imCoords_widget_list[camIdx][fishIdx][bpIdx][1].value
            distances[fishIdx, bpIdx] = (((x - xp) ** 2) + ((y - yp) ** 2)) ** 0.5

    # nanargmin raises on all-NaN input, so handle "nothing visible" up front
    if np.all(np.isnan(distances)):
        bpIdx_widget.value = -1
        fishIdx_widget.value = -1
        return

    # overall nearest point, ignoring missing ones (ties go to the lowest fish/bodypoint index)
    best_fishIdx, best_bpIdx = np.unravel_index(np.nanargmin(distances), distances.shape)

    if distances[best_fishIdx, best_bpIdx] < tolerance:
        bpIdx_widget.value = int(best_bpIdx)
        fishIdx_widget.value = int(best_fishIdx)
    else:
        bpIdx_widget.value = -1
        fishIdx_widget.value = -1

         
def _read_current_frame(camIdx):
    '''Return the image at frame ``local_index.value`` of the movie of camera ``camIdx``.

    Raises TypeError when the movie cannot be opened or the frame cannot be
    read. The capture is released in every case.
    '''
    cap = cv2.VideoCapture(video_paths[camIdx].value)
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, local_index.value)  # go to current frame
        ret, frame = cap.read() if cap.isOpened() else (False, None)
    finally:
        cap.release()
    if not ret:
        raise TypeError('movie for splitdata_idx={0} not opening'.format(splitdata_idx.value))
    return frame


def _get_imCoord(camIdx, fishIdx, bpIdx):
    '''Return the (x, y) image coordinate held by the coordinate widgets as a float array.'''
    xy_widgets = frame_imCoords_widget_list[camIdx][fishIdx][bpIdx]
    return np.array([xy_widgets[0].value, xy_widgets[1].value], dtype=float)


def _draw_skeletons(frame, camIdx):
    '''Draw the bodypoints and edges of every fish for camera ``camIdx`` onto ``frame`` (in place).'''
    for fishIdx in range(numFish):
        fcolor = fish_colors[fishIdx]
        color = (int(fcolor[0]), int(fcolor[1]), int(fcolor[2]))

        # draw the bodypoints; keep the integer pixel of each one (None when
        # a coordinate is missing) for the edges below
        pixels = []
        for bpIdx in range(numBodyPoints):
            imCoord = _get_imCoord(camIdx, fishIdx, bpIdx)
            # int() cannot convert NaN/inf, so skip points with any bad component
            if not np.all(np.isfinite(imCoord)):
                pixels.append(None)
                continue
            pixels.append((int(imCoord[0]), int(imCoord[1])))
            cv2.circle(frame, pixels[-1], radius=2, color=color, thickness=-1)

        # draw the lines between bodypoints that are both present
        for edge in edges:
            pt1, pt2 = pixels[edge[0]], pixels[edge[1]]
            if pt1 is not None and pt2 is not None:
                cv2.line(frame, pt1=pt1, pt2=pt2, color=color, thickness=1, lineType=cv2.LINE_AA)


def _show_frame(camIdx, frame):
    '''Put ``frame`` into the image panel of camera ``camIdx`` (0, 1 or 2).'''
    (panel1, panel2, panel3)[camIdx].set_data(frame)


def _redraw_camera(camIdx):
    '''Reload the current frame of camera ``camIdx``, overlay the skeletons and show it.'''
    frame = _read_current_frame(camIdx)
    _draw_skeletons(frame, camIdx)
    _show_frame(camIdx, frame)


def _update_other_camera(camIdx, fishIdx, bpIdx):
    '''Recompute a bodypoint in the remaining view from the moved and the locked view.

    ``camIdx`` is the view in which the point was moved; the locked view is
    read from ``lock_camIdx_widget``. Camera indices 0, 1, 2 correspond to
    the XZ, XY and YZ views. The point's coordinate widgets of the third view
    are overwritten with the result of the matching ``cal.compute_*`` call
    and that view is redrawn.

    Returns True when a third view was updated, False when the lock/moved
    combination does not identify one (e.g. the lock widget holds no camera).
    '''
    # (locked cam, moved cam) -> (cam to update, solver taking (fixed point, moved point));
    # each cal function expects its two points in the order given by its name
    solvers = {
        (0, 1): (2, lambda fixed, moved: cal.compute_YZ_imcoords_from_XZ_XY(fixed, moved)),
        (0, 2): (1, lambda fixed, moved: cal.compute_XY_imcoords_from_XZ_YZ(fixed, moved)),
        (1, 0): (2, lambda fixed, moved: cal.compute_YZ_imcoords_from_XZ_XY(moved, fixed)),
        (1, 2): (0, lambda fixed, moved: cal.compute_XZ_imcoords_from_XY_YZ(fixed, moved)),
        (2, 0): (1, lambda fixed, moved: cal.compute_XY_imcoords_from_XZ_YZ(moved, fixed)),
        (2, 1): (0, lambda fixed, moved: cal.compute_XZ_imcoords_from_XY_YZ(moved, fixed)),
    }
    lock_camIdx = lock_camIdx_widget.value
    if (lock_camIdx, camIdx) not in solvers:
        return False
    other_camIdx, solve = solvers[(lock_camIdx, camIdx)]

    # grab the frame first so that a read failure leaves the widgets untouched
    frame = _read_current_frame(other_camIdx)

    moved_point_imcoords = _get_imCoord(camIdx, fishIdx, bpIdx)
    fixed_point_imcoords = _get_imCoord(lock_camIdx, fishIdx, bpIdx)
    other_point_imcoords = solve(fixed_point_imcoords, moved_point_imcoords)

    # update the image coordinates of the third view
    frame_imCoords_widget_list[other_camIdx][fishIdx][bpIdx][0].value = other_point_imcoords[0][0]
    frame_imCoords_widget_list[other_camIdx][fishIdx][bpIdx][1].value = other_point_imcoords[0][1]

    _draw_skeletons(frame, other_camIdx)
    _show_frame(other_camIdx, frame)
    return True


def _step_bodypoint_selection(forward):
    '''Select the next (or previous) bodypoint and centre all camera views on it.

    Bodypoints are visited fish by fish, wrapping around at both ends. The
    zoom level of the active view (``camIdx_widget``) is applied to all views;
    views in which the selected bodypoint has a NaN coordinate are left as is.
    '''
    fishIdx = fishIdx_widget.value
    bpIdx = bpIdx_widget.value
    if forward:
        bpIdx += 1
        if bpIdx == numBodyPoints:
            bpIdx = 0
            fishIdx += 1
        if fishIdx == numFish:
            fishIdx = 0
    else:
        bpIdx -= 1
        if bpIdx == -1:
            bpIdx = numBodyPoints - 1
            fishIdx -= 1
        if fishIdx == -1:
            fishIdx = numFish - 1
    # stepping from "nothing selected" (-1) can leave an index negative
    bpIdx = max(0, bpIdx)
    fishIdx = max(0, fishIdx)
    bpIdx_widget.value = bpIdx
    fishIdx_widget.value = fishIdx

    axes = [f1_ax1, f1_ax2, f1_ax3]
    ref_ax = axes[camIdx_widget.value]
    # half-width/height of the currently visible region of the active view
    cur_xlim = ref_ax.get_xlim()
    cur_ylim = ref_ax.get_ylim()
    cur_xrange = (cur_xlim[1] - cur_xlim[0]) * .5
    cur_yrange = (cur_ylim[1] - cur_ylim[0]) * .5

    # make sure home button resets to full image; record the view once,
    # before any limit is changed
    toolbar = ref_ax.figure.canvas.toolbar
    if toolbar is not None:
        toolbar.push_current()

    # center and crop all cameraviews
    for viewIdx, ax in enumerate(axes):
        imCoord = _get_imCoord(viewIdx, fishIdx, bpIdx)
        if np.any(np.isnan(imCoord)):
            continue
        ax.set_xlim([imCoord[0] - cur_xrange, imCoord[0] + cur_xrange])
        ax.set_ylim([imCoord[1] - cur_yrange, imCoord[1] + cur_yrange])
    fig1.canvas.draw()
    fig1.canvas.flush_events()


def click_release_event_handler(event):
    '''Mouse-release callback: move the selected bodypoint to the release position.

    The selected bodypoint (``camIdx_widget``, ``fishIdx_widget``,
    ``bpIdx_widget``) gets the release position as its new image coordinate,
    the clicked view is redrawn, and the point is recomputed and redrawn in
    the view that is neither the clicked nor the locked one.

    Nothing happens when no bodypoint is selected, when the clicked view is
    the locked one, or when the mouse is released outside the view in which
    the point was picked (the coordinates would belong to another axes).

    Raises
    ------
    TypeError
        If the movie frame of a view to redraw cannot be read.
    '''
    camIdx = camIdx_widget.value
    fishIdx = fishIdx_widget.value
    bpIdx = bpIdx_widget.value

    # do nothing if we dont have an active camIdx, bpIdx and fishIdx
    if camIdx == -1 or bpIdx == -1 or fishIdx == -1:
        return

    # do nothing if cam is locked
    if camIdx == lock_camIdx_widget.value:
        return

    # do nothing if the release did not happen inside the picked view
    if event.xdata is None or event.ydata is None or event.inaxes is not f_axs[camIdx]:
        return

    # update the image coordinates using the mouse click
    frame_imCoords_widget_list[camIdx][fishIdx][bpIdx][0].value = event.xdata
    frame_imCoords_widget_list[camIdx][fishIdx][bpIdx][1].value = event.ydata

    # redraw the image you clicked on, then the other (non-locked) image
    _redraw_camera(camIdx)
    _update_other_camera(camIdx, fishIdx, bpIdx)

    fig1.canvas.draw()
    fig1.canvas.flush_events()
    return


def key_press_event_handler(event):
    '''Key-press callback.

    tab / shift+tab
        select the next / previous bodypoint and centre the views on it
    ctrl + arrow keys
        move the selected bodypoint by one unit in the active view, then
        recompute it in the view that is neither active nor locked

    Nothing happens when no camera view is active; the arrow keys also need
    a selected bodypoint and an active view that is not the locked one.

    Raises
    ------
    TypeError
        If the movie frame of a view to redraw cannot be read.
    '''
    camIdx = camIdx_widget.value
    # return if no camera selected
    if camIdx == -1:
        return

    if event.key == 'tab' or event.key == 'shift+tab':
        _step_bodypoint_selection(forward=(event.key == 'tab'))
        return

    fishIdx = fishIdx_widget.value
    bpIdx = bpIdx_widget.value
    # return if no bodypoint or fish is selected
    if bpIdx == -1 or fishIdx == -1:
        return
    # do nothing if cam is locked
    if camIdx == lock_camIdx_widget.value:
        return

    # key -> [coordinate index (0 = x, 1 = y), step]
    key_actions = {'ctrl+right': [0, 1], 'ctrl+left': [0, -1], 'ctrl+up': [1, -1], 'ctrl+down': [1, 1]}
    key_action = key_actions.get(event.key)
    if key_action is None:
        return
    frame_imCoords_widget_list[camIdx][fishIdx][bpIdx][key_action[0]].value += key_action[1]

    # redraw the active image, then the other (non-locked) image
    _redraw_camera(camIdx)
    _update_other_camera(camIdx, fishIdx, bpIdx)

    fig1.canvas.draw()
    fig1.canvas.flush_events()
    return


def scroll_event_handler(event):
    '''Scroll callback: zoom the camera view under the cursor around the cursor position.

    Scrolling 'up' shrinks the visible range by a factor of 1.5 (zoom in),
    scrolling 'down' enlarges it by the same factor (zoom out). The new view
    is centred on the cursor.

    Parameters
    ----------
    event : matplotlib scroll event
        ``inaxes``, ``xdata``, ``ydata`` and ``button`` are read.

    The event is ignored when the cursor is not over one of the three camera
    axes.
    '''
    base_scale = 1.5

    # xdata/ydata are expressed in the coordinates of event.inaxes, so that is
    # the axes that has to be rescaled
    ax = event.inaxes
    if ax is None or not any(ax is view_ax for view_ax in (f1_ax1, f1_ax2, f1_ax3)):
        return
    xdata = event.xdata
    ydata = event.ydata
    if xdata is None or ydata is None:
        return

    if event.button == 'up':
        scale_factor = 1 / base_scale
    elif event.button == 'down':
        scale_factor = base_scale
    else:
        # not a scroll direction; nothing to do
        return

    # half-width/height of the currently visible region
    cur_xlim = ax.get_xlim()
    cur_ylim = ax.get_ylim()
    cur_xrange = (cur_xlim[1] - cur_xlim[0]) * .5
    cur_yrange = (cur_ylim[1] - cur_ylim[0]) * .5

    # make sure home button resets to full image
    toolbar = ax.figure.canvas.toolbar
    if toolbar is not None:
        toolbar.push_current()

    # set new limits
    ax.set_xlim([xdata - cur_xrange * scale_factor,
                 xdata + cur_xrange * scale_factor])
    ax.set_ylim([ydata - cur_yrange * scale_factor,
                 ydata + cur_yrange * scale_factor])
    fig1.canvas.draw()
    fig1.canvas.flush_events()
    return

# connect functions to canvas
# Re-running this cell would otherwise add a second copy of every handler to
# the canvas (each event would then be processed twice), so first drop the
# callbacks a previous run registered on this same canvas.
_prev_canvas, _prev_cids = globals().get('_fig1_handler_connections', (None, []))
if _prev_canvas is fig1.canvas:
    for _cid in _prev_cids:
        fig1.canvas.mpl_disconnect(_cid)

# remember the canvas together with the connection ids returned by mpl_connect
_fig1_handler_connections = (
    fig1.canvas,
    [
        fig1.canvas.mpl_connect(_event_name, _handler)
        for _event_name, _handler in (
            ('button_press_event', click_down_event_handler),
            ('button_release_event', click_release_event_handler),
            ('key_press_event', key_press_event_handler),
            ('scroll_event', scroll_event_handler),
        )
    ],
)

12

In [21]:
# Make the final widget
# The figure canvas is wrapped in a VBox so that it can be passed to display()
# together with the other widgets.
image_widget = widgets.VBox([fig1.canvas])

# Combined View (Labeling Zone)

In [22]:
# In this section we give an example of sticking all the widgets together

In [23]:
# Show the three widgets in this order: plot_widget, coordinate_widget, image_widget.
display(plot_widget, coordinate_widget, image_widget)

VBox(children=(Canvas(header_visible=False, layout=Layout(min_height='400px', width='100%'), toolbar=Toolbar(t…

VBox(children=(HBox(children=(VBox(children=(Text(value='Fish1 3D bps', layout=Layout(width='120px')), HBox(ch…

VBox(children=(Canvas(header_visible=False, layout=Layout(min_height='400px', width='100%'), toolbar=Toolbar(t…