In [None]:
# Useful functions for minor abilities
import pandas as pd
def test_trial_presence(data, trial_name):
    '''
    test to make sure the trial exists before using it. This function is used in topdown_preening.py and check_all_tracking.py.
    '''
    try:
        data.sel(trial=trial_name)
        exists = True
    except ValueError:
        exists = False
    return exists

In [None]:
# data reading functions
import pandas
def read_dlc(dlcfile):
    '''
    Load a DeepLabCut .h5 file and flatten its column labels.

    The file is first read without a key. If pandas raises ValueError (e.g. because the file
    holds more than one stored object), it is re-read with key='data', which is how files
    written by corral_files.py are keyed.

    Each column label is built from elements 1 and 2 of the original column tuple, joined
    with an underscore (presumably bodypart + coordinate, e.g. 'nose' + 'x' -> 'nose_x').

    Returns
    -------
    pts : pandas.DataFrame
        The point data with flattened column names.
    pt_loc_names : numpy.ndarray
        The flattened column names.
    '''
    try:
        pts = pd.read_hdf(dlcfile)
    except ValueError:
        pts = pd.read_hdf(dlcfile, key='data')

    flat_names = []
    for col in pts.columns.values:
        label = ' '.join(col[1:3]).strip()
        flat_names.append(label.replace(' ', '_'))
    pts.columns = flat_names

    return pts, pts.columns.values

####################################################

def read_in_eye(data_input, side, num_points=8):
    '''
    Read in and manage column names of eye data passed in in the form of .h5 files.

    Parameters
    ----------
    data_input : path-like or None
        Eye-camera DLC .h5 file handed to read_dlc; None means no data for this eye.
    side : str
        Eye label used in messages and in the side-specific names (e.g. 'left').
    num_points : int
        Number of DLC points p1..pN expected around the eye.

    Returns
    -------
    (eye_data, eye_names)
        The DataFrame and column names from read_dlc, or (None, None) when no input is
        given or read_dlc cannot be called.
    '''
    # side-agnostic names: 'p1 x', 'p1 y', 'p1 likelihood', 'p2 x', ...
    eye_pts = [
        'p' + str(eye_pt) + ' ' + coord
        for eye_pt in range(1, num_points + 1)
        for coord in ('x', 'y', 'likelihood')
    ]
    # the same names prefixed with the eye they come from: 'left eye p1 x', ...
    new_eye_pts = [str(side) + ' eye ' + str(old_eye_pt) for old_eye_pt in eye_pts]

    # defaults so every path below returns bound values
    eye_data = None
    eye_names = None

    if data_input is None:
        # okay to carry on with top-down data or just one eye
        print('no ' + str(side) + ' eye data given')
        return eye_data, eye_names

    try:
        eye_data, eye_names = read_dlc(data_input)
        # NOTE(review): this maps side-prefixed names back to plain 'pN x' names, but read_dlc
        # already replaced spaces with underscores, so it likely matches no columns. It is kept
        # as-is; confirm the intended direction of the mapping.
        col_corrections = dict(zip(new_eye_pts, eye_pts))
        eye_data = eye_data.rename(columns=col_corrections)
    except NameError:
        print('cannot add ' + str(side) + ' eye because no topdown camera data were given')

    return eye_data, eye_names

####################################################

def read_data(topdown_input=None, lefteye_input=None, righteye_input=None):
    '''
    Read in topdown, left eye, and/or right eye .h5 files by calling above functions.

    Any input may be None. The corresponding outputs are then None and a message is printed.

    Returns
    -------
    topdown_pts, topdown_names, lefteye_pts, lefteye_names, righteye_pts, righteye_names
    '''
    # default to None so a missing top-down file does not leave these names unbound
    topdown_pts = None
    topdown_names = None
    if topdown_input is not None:
        topdown_pts, topdown_names = read_dlc(topdown_input)
    else:
        print('no top-down data given')

    # read in left and right eye (okay if not provided)
    lefteye_pts, lefteye_names = read_in_eye(lefteye_input, 'left')
    righteye_pts, righteye_names = read_in_eye(righteye_input, 'right')

    return topdown_pts, topdown_names, lefteye_pts, lefteye_names, righteye_pts, righteye_names

In [None]:
# data cleaning functions

# imports
import xarray as xr
import numpy as np
import pandas as pd

#############################################
def split_xyl(eye_names, eye_data, thresh):
    '''
    Split DLC point data into x positions, y positions and likelihoods.

    x and y positions are set to NaN wherever the likelihood of the same point
    ('<point>_likelihood', e.g. 'p1_x' -> 'p1_likelihood') is below `thresh`. Coordinates
    with no matching likelihood label are left unthresholded. `eye_data` is not modified.

    Parameters
    ----------
    eye_names : sequence of str
        Point-location labels, classified by substring: '_x', then '_y', then 'likeli'.
    eye_data : xarray.DataArray
        Data with a 'point_loc' dimension whose labels include `eye_names`.
    thresh : float
        Likelihood threshold.

    Returns
    -------
    x_vals, y_vals : pandas.DataFrame
        Thresholded positions, transposed so point locations are columns.
    likeli_pts : xarray.DataArray
        The likelihoods stacked along 'point_loc'.

    Raises
    ------
    ValueError
        If `eye_names` lacks any x, y or likelihood labels.
    '''
    x_locs = []
    y_locs = []
    likeli_locs = []
    for loc in eye_names:
        if '_x' in loc:
            x_locs.append(loc)
        elif '_y' in loc:
            y_locs.append(loc)
        elif 'likeli' in loc:
            likeli_locs.append(loc)

    if not (x_locs and y_locs and likeli_locs):
        raise ValueError('split_xyl needs x, y and likelihood point locations; got ' + str(list(eye_names)))

    likeli_set = set(likeli_locs)

    def _thresholded(pt_loc, suffix):
        # NaN-out one coordinate wherever its point's likelihood is below thresh
        coords = eye_data.sel(point_loc=pt_loc)
        head, _, tail = pt_loc.rpartition(suffix)
        likeli_loc = head + '_likelihood' + tail
        if likeli_loc not in likeli_set:
            return coords
        low = eye_data.sel(point_loc=likeli_loc).values < thresh
        # copy(data=...) keeps coords (incl. the point_loc label) and leaves eye_data untouched
        return coords.copy(data=np.where(low, np.nan, coords.values))

    def _stack(arrays):
        # concatenate one by one along point_loc, as the original pipeline did
        stacked = arrays[0]
        for arr in arrays[1:]:
            stacked = xr.concat([stacked, arr], dim='point_loc', fill_value=np.nan)
        return stacked

    likeli_pts = _stack([eye_data.sel(point_loc=loc) for loc in likeli_locs])
    x_pts = _stack([_thresholded(loc, '_x') for loc in x_locs])
    y_pts = _stack([_thresholded(loc, '_y') for loc in y_locs])

    # drop len=1 dims
    x_pts = x_pts.squeeze()
    y_pts = y_pts.squeeze()

    # convert to dataframe, transpose so points are columns
    x_vals = x_pts.to_pandas().T
    y_vals = y_pts.to_pandas().T

    return x_vals, y_vals, likeli_pts

In [None]:
# time management functions
import pandas as pd
from datetime import datetime
from datetime import timedelta
import xarray as xr
import numpy as np

####################################################
def match_deinterlace(raw_time, timestep):
    '''
    Interleave every timestamp with a second one half a timestep later.

    Deinterlaced videos (and their DLC point structures) are twice as long as the timestamp
    files, so each recorded time t becomes the pair (t, t + timestep / 2).

    Returns
    -------
    list
        2 * len(raw_time) timestamps, in order.
    '''
    half_step = timestep / 2
    return [stamp for base in raw_time for stamp in (base, base + half_step)]

####################################################
def read_time(data, len_main):
    '''
    Load camera timestamps from a headerless one-column CSV and match them to the data length.

    Parameters
    ----------
    data : path or file-like
        Anything pandas.read_csv accepts; one timestamp per row.
    len_main : int
        Length of the main point data these timestamps belong to. It is used to detect a
        timestamp file that is too short because the video was deinterlaced.

    Returns
    -------
    list or pandas.Series
        If there are fewer timestamps than samples, a list from match_deinterlace that
        interleaves half-step times. Otherwise the parsed Series; a message is printed when
        there are more timestamps than samples.
    '''
    timestamps = pd.to_datetime(pd.read_csv(data, names=['time'])['time'])
    # make time relative
    # timestamps = timestamps - timestamps[0]

    timestep = timestamps[1] - timestamps[0]
    n_stamps = len(timestamps)

    if len_main > n_stamps:
        return match_deinterlace(timestamps, timestep)
    if len_main < n_stamps:
        print('issue with read_time: more timepoints than there are data')
    return timestamps

In [None]:
# Preening topdown data
import pandas as pd
import numpy as np
import matplotlib
import xarray as xr
import os
from matplotlib import pyplot as plt
import tkinter
import math

def _topdown_fig_dir(savepath_input, current_trial):
    '''Return '<savepath_input>/<current_trial>/', creating the directory if needed.'''
    fig_dir = savepath_input + '/' + current_trial + '/'
    os.makedirs(fig_dir, exist_ok=True)
    return fig_dir


def _save_nose_path(nose_x, nose_y, title, fig_path):
    '''Save a plot of a nose x/y path with its first point in green and last point in red.'''
    nose_x = np.squeeze(nose_x)
    nose_y = np.squeeze(nose_y)
    if np.size(nose_x) == 0 or np.size(nose_y) == 0:
        # nothing to draw (e.g. every nose position fell below the likelihood threshold)
        return
    plt.figure(figsize=(15, 15))
    plt.title(title)
    plt.plot(nose_x, nose_y)
    plt.plot(nose_x[0], nose_y[0], 'go')  # starting point
    plt.plot(nose_x[-1], nose_y[-1], 'ro')  # ending point
    plt.savefig(fig_path, dpi=300)
    plt.close()


def preen_topdown_data(all_topdown_data, trial_list, pt_names, savepath_input, coord_correction_val=1200, num_points=8, thresh=0.99, savefig=False):
    '''
    Clean top-down DLC data trial by trial and stack the cleaned trials along 'trial'.

    Each trial found in `all_topdown_data` is processed as follows:
      1. NaNs are linearly interpolated along 'frame'.
      2. `coord_correction_val` is subtracted from every y coordinate.
      3. x/y coordinates are set to NaN wherever their point's likelihood is below `thresh`.
    With `savefig`, nose-path plots before and after thresholding are written to
    '<savepath_input>/<trial>/'. Trials missing from `all_topdown_data` are skipped.

    Parameters
    ----------
    pt_names : sequence of str
        Point-location labels. Each likelihood label must be preceded by its x then y label
        (order x, y, likelihood), and 'nose_x'/'nose_y' must exist when `savefig` is set.
    num_points : int
        Unused; kept for backward compatibility.

    Returns
    -------
    xarray.DataArray or None
        Thresholded data concatenated along 'trial', or None if no trial in `trial_list`
        is present.

    Raises
    ------
    ValueError
        If `pt_names` contains no 'likelihood' labels.
    '''
    # point-location groups do not depend on the trial, so compute them once
    y_names = [name for name in pt_names if '_y' in name]
    x_names = [name for name in pt_names if '_x' in name]
    l_names = [name for name in pt_names if 'lik' in name]

    all_topdown_output = None
    for current_trial in trial_list:
        if not test_trial_presence(all_topdown_data, current_trial):
            continue

        with all_topdown_data.sel(trial=current_trial) as topdown_data:
            # interpolate across NaNs for each point_loc
            topdown_interp = topdown_data.interpolate_na(dim='frame', use_coordinate='frame', method='linear')

            # fix y coordinates by subtracting the correction value, then piece the data back together
            y_data = topdown_interp.sel(point_loc=y_names) - coord_correction_val
            x_data = topdown_interp.sel(point_loc=x_names)
            l_data = topdown_interp.sel(point_loc=l_names)
            topdown_coordcor = xr.concat([x_data, y_data, l_data], dim='point_loc', fill_value=np.nan)

            if savefig:
                fig_dir = _topdown_fig_dir(savepath_input, current_trial)
                _save_nose_path(topdown_coordcor.sel(point_loc='nose_x'),
                                topdown_coordcor.sel(point_loc='nose_y'),
                                'mouse nose x/y path before likelihood threshold',
                                fig_dir + 'nose_position_over_time.png')

            # threshold x/y of every point by its own likelihood (only want high values)
            likeli_thresh_allpts = None
            for pt_num, current_pt_loc in enumerate(pt_names):
                if 'likelihood' not in current_pt_loc:
                    continue
                # assumes order is x, y, likelihood; will cause problems if that isn't true of the data
                assoc_x_pt = topdown_coordcor.sel(point_loc=pt_names[pt_num - 2])
                assoc_y_pt = topdown_coordcor.sel(point_loc=pt_names[pt_num - 1])
                likeli_pt = topdown_coordcor.sel(point_loc=current_pt_loc)

                low_likelihood = likeli_pt < thresh
                assoc_x_pt[low_likelihood] = np.nan
                assoc_y_pt[low_likelihood] = np.nan

                likeli_thresh_1loc = xr.concat([assoc_x_pt, assoc_y_pt, likeli_pt], dim='point_loc')
                if likeli_thresh_allpts is None:
                    likeli_thresh_allpts = likeli_thresh_1loc
                else:
                    likeli_thresh_allpts = xr.concat([likeli_thresh_allpts, likeli_thresh_1loc], dim='point_loc', fill_value=np.nan)

            if likeli_thresh_allpts is None:
                raise ValueError('preen_topdown_data: no likelihood point locations found in pt_names')

            if savefig:
                fig_dir = _topdown_fig_dir(savepath_input, current_trial)
                nose_x_thresh_pts = likeli_thresh_allpts.sel(point_loc='nose_x')
                nose_y_thresh_pts = likeli_thresh_allpts.sel(point_loc='nose_y')
                # drop NaNs only for the figure (time information is kept for the actual analysis)
                _save_nose_path(nose_x_thresh_pts[np.isfinite(nose_x_thresh_pts)],
                                nose_y_thresh_pts[np.isfinite(nose_y_thresh_pts)],
                                'mouse nose x/y path after likelihood threshold',
                                fig_dir + 'nose_position_over_time_thresh.png')

            likeli_thresh_allpts['trial'] = current_trial

            # append this trial to all others now that processing is done
            if all_topdown_output is None:
                all_topdown_output = likeli_thresh_allpts
            else:
                all_topdown_output = xr.concat([all_topdown_output, likeli_thresh_allpts], dim='trial', fill_value=np.nan)

    return all_topdown_output


In [None]:
# eye tracking
import pandas as pd
import numpy as np
from skimage import measure
import xarray as xr
import matplotlib.pyplot as plt

####################################################
def get_eye_angles(ellipseparams):
    '''
    Estimate a camera center and per-frame eye angles from per-frame ellipse fits.

    Parameters
    ----------
    ellipseparams : numpy.ndarray, shape (n_frames, 5)
        One row per frame. Columns are read as 0-1: ellipse center (x, y), 2-3: the two axis
        parameters, 4: orientation angle in radians. This matches the order of skimage's
        EllipseModel.params (xc, yc, a, b, theta), which calc_ellipse stacks.

    Returns
    -------
    theta : numpy.ndarray, shape (n_frames,)
        arcsin of the x offset of each center from CamCent divided by `scale` (radians);
        NaN where the argument is outside [-1, 1] or the row is NaN.
    phi : numpy.ndarray, shape (n_frames,)
        arcsin of the y offset divided by cos(theta) and `scale` (radians).
    longaxis_all, shortaxis_all : numpy.ndarray, shape (n_frames,)
        Per-frame max / min of columns 2 and 3.
    CamCent : numpy.ndarray, shape (2, 1)
        Least-squares camera center estimate (x, y).

    Notes
    -----
    Only frames with ellipticity (short / long axis) < 0.9 are used to estimate CamCent and
    `scale`. If their orientations do not span two directions (e.g. fewer than two such
    frames), A @ A.T is singular and np.linalg.inv raises LinAlgError.
    '''
    # NOTE(review): R is never used below
    R = np.linspace(0,2*np.pi,100)
    longaxis_all = np.maximum(ellipseparams[:,2],ellipseparams[:,3])
    shortaxis_all = np.minimum(ellipseparams[:,2],ellipseparams[:,3])
    Ellipticity = shortaxis_all/longaxis_all
    # indices of clearly non-circular fits (NaN ellipticity compares False and is excluded)
    lis, = np.where(Ellipticity<.9)
    # 2 x m matrix of unit orientation vectors u_i = (cos angle_i, sin angle_i)
    A = np.vstack([np.cos(ellipseparams[lis,4]),np.sin(ellipseparams[lis,4])])
    # b_i = u_i . center_i, taken from the diagonal of an m x m product
    # NOTE(review): that product grows quadratically with the number of selected frames
    b = np.expand_dims(np.diag(A.T@np.squeeze(ellipseparams[lis,0:2].T)),axis=1)
    # normal-equation solution of min_p sum_i (u_i . p - b_i)^2: the least-squares intersection
    # of the lines through each center perpendicular to its orientation vector
    CamCent = np.linalg.inv(A@A.T)@A@b
    longaxis = np.squeeze(np.maximum(ellipseparams[lis,2],ellipseparams[lis,3]))
    shortaxis = np.squeeze(np.minimum(ellipseparams[lis,2],ellipseparams[lis,3]))
    Ellipticity = shortaxis/longaxis
    # least-squares solution of d_i ~= scale * w_i, with d_i = |center_i - CamCent| and
    # w_i = sqrt(1 - ellipticity_i**2): scale = sum(w * d) / sum(w**2)
    scale = np.sum(np.sqrt(1-(Ellipticity)**2)*(np.linalg.norm(ellipseparams[lis,0:2]-CamCent.T,axis=1)))/np.sum(1-(Ellipticity)**2);
    # angles are computed for every frame, not just the selected ones
    temp = (ellipseparams[:,0]-CamCent[0])/scale
    theta = np.arcsin(temp)
    phi = np.arcsin((ellipseparams[:,1]-CamCent[1])/np.cos(theta)/scale)
    return theta, phi, longaxis_all, shortaxis_all, CamCent

####################################################
def preen_then_get_eye_angles(ellipseparams, pxl_thresh):
    '''
    Replace oversized ellipse fits by linear interpolation, then compute eye angles.

    Any row whose axis parameter in column 2 or 3 exceeds `pxl_thresh` is treated as a bad
    fit. Its values are replaced by linear interpolation (in both directions) over the
    remaining rows. `ellipseparams` is modified in place before being passed to
    get_eye_angles.

    Returns
    -------
    theta, phi, longaxis_all, shortaxis_all, CamCent
        Exactly what get_eye_angles returns for the cleaned parameters.
    '''
    bad_rows = np.flatnonzero((ellipseparams[:, 2:4] > pxl_thresh).any(axis=1))

    smoothed = pd.DataFrame(ellipseparams)
    smoothed.iloc[bad_rows, :] = np.nan
    smoothed = smoothed.interpolate(method='linear', limit_direction='both', axis=0)

    # only the bad rows are overwritten; every other row keeps its original values
    ellipseparams[bad_rows, :] = smoothed.to_numpy()[bad_rows, :]

    return get_eye_angles(ellipseparams)

####################################################
def calc_ellipse(num_frames, x_vals, y_vals, pxl_thresh):
    '''
    Fit an ellipse to the eye points of every frame, then convert the fits to eye angles.

    Parameters
    ----------
    num_frames : int
        Unused; kept for backward compatibility with existing callers.
    x_vals, y_vals : pandas.DataFrame
        Point x / y positions, one row per frame (index) and one column per point. Each label
        of x_vals.index is also looked up in y_vals.
    pxl_thresh : float
        Passed to preen_then_get_eye_angles.

    Returns
    -------
    theta, phi, longaxis_all, shortaxis_all, CamCent
        As returned by preen_then_get_eye_angles. There is one entry per row of x_vals: a
        frame missing from y_vals, or whose fit fails, contributes a row of NaN instead of
        being dropped, so results stay aligned with x_vals.index.
    '''
    emod = measure.EllipseModel()
    nan_row = np.full(5, np.nan)

    rows = []
    for timestamp in x_vals.index.values:
        try:
            x_block = x_vals.loc[timestamp, :]
            y_block = y_vals.loc[timestamp, :]
        except KeyError:
            # timestamp not found: add a filler entry of parameters
            rows.append(nan_row)
            continue
        xy = np.column_stack((x_block, y_block))
        # params are (xc, yc, a, b, theta)
        if emod.estimate(xy):
            rows.append(np.asarray(emod.params, dtype=float))
        else:
            rows.append(nan_row)

    # build the (n_frames, 5) array once instead of growing it with np.append each frame
    ellipseparams = np.vstack(rows) if rows else np.empty((0, 5))

    theta, phi, longaxis_all, shortaxis_all, CamCent = preen_then_get_eye_angles(ellipseparams, pxl_thresh)

    return theta, phi, longaxis_all, shortaxis_all, CamCent

####################################################
def eye_angles(eye_data_input, eye_names, trial_id_list, savepath_input, all_trial_time, savefig=False, thresh=0.99, pxl_thresh=50, side='left'):
    '''
    Fit ellipses to one eye's DLC points for every trial and collect the resulting angles.

    Prepares data for use with get_eye_angles, one eye at a time. pxl_thresh is the max
    number of pixels for radius of pupil; thresh is the likelihood threshold.

    Parameters
    ----------
    eye_data_input : xarray.DataArray or None
        Eye DLC data with a 'trial' dimension. None means no data for this eye.
    eye_names : sequence of str
        Point-location labels passed to split_xyl.
    trial_id_list : sequence of str
        Trials to process. Trials missing from eye_data_input are skipped.
    savepath_input : str
        Root directory for figures; '<savepath_input>/<trial>/' is created when saving.
    all_trial_time :
        Unused; kept for backward compatibility.
    side : str
        Eye label stored as the 'eye_side' coordinate and used in titles.

    Returns
    -------
    xarray.DataArray or None
        theta, phi, longaxis_all and shortaxis_all per frame. The 'frame' coordinate is
        0..n-1 per trial, plus 'trial', 'eye_side', 'cam_center_x' and 'cam_center_y'
        coordinates. Trials are concatenated along 'trial'. Returns None if no trial was
        processed.
    '''
    import os  # local import keeps this function self-contained

    if eye_data_input is None:
        return None

    ellipse_params = ['theta', 'phi', 'longaxis_all', 'shortaxis_all']
    side_ellipse = None

    for current_trial_name in trial_id_list:
        try:
            eye_data = eye_data_input.sel(trial=current_trial_name)
        except (KeyError, ValueError):
            # this trial has no data for this eye; skip it rather than abort the whole run
            continue

        x_vals, y_vals, _ = split_xyl(eye_names, eye_data, thresh)

        # keep only frames in which every point survived the likelihood threshold
        x_vals = x_vals.dropna()
        y_vals = y_vals.dropna()

        # get the number of frames
        num_frames = len(x_vals)

        fig_dir = savepath_input + '/' + current_trial_name + '/'
        if savefig:
            os.makedirs(fig_dir, exist_ok=True)

            # plot an example frame to check that the points are roughly elliptical;
            # skipped if that frame was dropped by thresholding
            frame_slice = 3
            if frame_slice in x_vals.index and frame_slice in y_vals.index:
                x_to_plot = x_vals.loc[[frame_slice]]
                y_to_plot = y_vals.loc[[frame_slice]]
                plt.figure(figsize=(15,10))
                plt.scatter(x_to_plot, y_to_plot, color='r')
                plt.title('dlc points at frame ' + str(frame_slice) + ' of ' + str(side) + ' eye of ' + str(current_trial_name))
                plt.savefig(fig_dir + 'dlc_eye_pts_at_time_' + str(frame_slice) + '.png', dpi=300)
                plt.close()

        # get the ellipse parameters out of the point positional data
        theta, phi, longaxis_all, shortaxis_all, CamCent = calc_ellipse(num_frames, x_vals, y_vals, pxl_thresh)

        if savefig:
            plt.subplots(2, 1, figsize=(30,20))
            plt.subplot(211)
            plt.plot(theta * 180 / np.pi)
            plt.xlabel('frame')
            plt.ylabel('angle')
            plt.title('theta for ' + str(side) + ' eye of ' + str(current_trial_name))
            plt.subplot(212)
            plt.plot(phi * 180 / np.pi)
            plt.xlabel('frame')
            plt.ylabel('angle')
            plt.title('phi for ' + str(side) + ' eye of ' + str(current_trial_name))
            plt.savefig(fig_dir + 'theta_phi_traces_over_time_.png', dpi=300)
            plt.close()

        # CamCent has shape (2, 1); unwrap to plain floats
        cam_center = [np.squeeze(CamCent[0]).tolist(), np.squeeze(CamCent[1]).tolist()]

        trial_ellipse_df = pd.DataFrame({'theta': list(theta), 'phi': list(phi), 'longaxis_all': list(longaxis_all),
                                         'shortaxis_all': list(shortaxis_all)})

        # frames are numbered 0..n-1 by ellipse-fit row, not by the original timestamps
        trial_ellipse_data = xr.DataArray(trial_ellipse_df, coords=[('frame', range(0, len(trial_ellipse_df))), ('ellipse_params', ellipse_params)])
        trial_ellipse_data['trial'] = current_trial_name
        trial_ellipse_data['eye_side'] = side
        trial_ellipse_data['cam_center_x'] = cam_center[0]
        trial_ellipse_data['cam_center_y'] = cam_center[1]

        # append ellipse data from the current trial to the output to be saved
        if side_ellipse is None:
            side_ellipse = trial_ellipse_data
        else:
            side_ellipse = xr.concat([side_ellipse, trial_ellipse_data], dim='trial', fill_value=np.nan)

    return side_ellipse


In [None]:
# deal with videos

# import packages
import cv2
import numpy as np
import xarray as xr

####################################################
def _draw_dlc_points(frame, dlc_data, selector, n_coords, color):
    '''
    Draw a filled circle (radius 6) at each DLC point of one video frame and return the frame.

    `selector` holds the keyword arguments for dlc_data.sel() that pick the current frame.
    The loop walks point_loc entries 0, 3, 6, ... below `n_coords`, reading entry k as x and
    k + 1 as y. A frame that cannot be selected, or a point whose coordinates cannot be
    converted to int (e.g. NaN), is skipped. A dlc_data of None draws nothing.
    '''
    if dlc_data is None:
        return frame
    try:
        pts_ts = dlc_data.sel(**selector)
    except (KeyError, ValueError):
        return frame
    for k in range(0, n_coords, 3):
        try:
            center_xy = (int(pts_ts.isel(point_loc=k)), int(pts_ts.isel(point_loc=k + 1)))
        except ValueError:
            continue
        frame = cv2.circle(frame, center_xy, 6, color, -1)
    return frame


def _draw_eye_ellipse(frame, ell_data, frame_time, color):
    '''
    Draw the ellipse stored for `frame_time` (thickness 4) on `frame` and return the frame.

    Frames missing from ell_data, or whose parameters cannot be converted to int (e.g. NaN),
    are returned unchanged. An ell_data of None draws nothing.
    '''
    if ell_data is None:
        return frame
    try:
        ell_ts = ell_data.sel(frame=frame_time)
        center = (int(ell_ts['cam_center_x'].values), int(ell_ts['cam_center_y'].values))
        axes = (int(ell_ts.sel(ellipse_params='longaxis_all').values),
                int(ell_ts.sel(ellipse_params='shortaxis_all').values))
        angle = int(ell_ts.sel(ellipse_params='theta').values)
    except (KeyError, ValueError):
        return frame
    # NOTE(review): this draws at the camera center, and eye_angles stores theta in radians
    # while cv2.ellipse expects degrees. Kept as before; confirm what should be drawn.
    return cv2.ellipse(frame, center, axes, angle, 0, 360, color, 4)


def plot_pts_on_vid(trial_name, camtype, vid_path, savepath, dlc_data=None, ell_data=None):
    '''
    Overlay tracking results on a camera video and save it as an XVID-encoded .avi file.

    Parameters
    ----------
    trial_name : str
        Output goes to '<savepath>/<trial_name>/<trial_name>_<camtype>.avi'; the directory
        is created if needed.
    camtype : str
        't' top-down (DLC points), 'l' / 'r' left / right eye (ellipse + DLC points),
        'w' world camera (frames copied unchanged). Any other value prints a message and
        returns without touching the video.
    vid_path : str
        Input video path.
    dlc_data, ell_data : xarray.DataArray, optional
        DLC points and ellipse parameters; when None, that overlay is skipped.

    Notes
    -----
    NOTE(review): frames are looked up by cv2.CAP_PROP_POS_FRAMES read after each read(),
    i.e. the index of the next frame; confirm it matches the 'frame'/'time' coordinates.
    Top-down points are selected with frame=..., eye points with time=..., as before.
    '''
    import os  # local import keeps this function self-contained

    messages = {
        't': 'plotting points on topdown view',
        'l': 'plotting ellipse and points on left eye view',
        'r': 'plotting ellipse and points on right eye view',
        'w': 'writing worldcam view',
    }
    if camtype not in messages:
        print('unknown camtype argument... exiting')
        return

    vidread = cv2.VideoCapture(vid_path)
    width = int(vidread.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(vidread.get(cv2.CAP_PROP_FRAME_HEIGHT))

    out_dir = str(savepath) + '/' + str(trial_name) + '/'
    os.makedirs(out_dir, exist_ok=True)
    out_path = out_dir + str(trial_name) + '_' + str(camtype) + '.avi'
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out_vid = cv2.VideoWriter(out_path, fourcc, 20.0, (width, height))

    # small aesthetic things to set (BGR)
    plot_color0 = (225, 255, 0)
    plot_color1 = (0, 255, 255)

    print(messages[camtype])
    try:
        while True:
            ret, frame = vidread.read()
            if not ret:
                break

            frame_time = vidread.get(cv2.CAP_PROP_POS_FRAMES)

            if camtype == 't':
                frame = _draw_dlc_points(frame, dlc_data, {'frame': frame_time}, 30, plot_color0)
            elif camtype in ('l', 'r'):
                frame = _draw_eye_ellipse(frame, ell_data, frame_time, plot_color0)
                frame = _draw_dlc_points(frame, dlc_data, {'time': frame_time}, 24, plot_color1)

            out_vid.write(frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # always free both the reader and the writer, even if drawing raises
        out_vid.release()
        vidread.release()
        cv2.destroyAllWindows()


In [None]:
# package imports for main script
from glob import glob
import os.path
import numpy as np
import sys
import xarray as xr
import pandas as pd
import argparse
import warnings

# user inputs

In [None]:
# inputs
# every path lives under one data root, so build them from shared prefixes
_data_root = '/Users/dylanmartins/data/Niell/PreyCapture/'
_session_dir = _data_root + 'Cohort3/J463c(blue)/110719/'

dlcpath = _session_dir + 'CorralledApproachDataDI/'  # DLC .h5 point files and camera timestamp .csv files
vidpath = _session_dir + 'CorralledApproachVids/'    # .avi videos for every camera
savepath = _data_root + 'Cohort3Outputs/J463c(blue)_110719/workspace_test_02/'  # output folder

savefig = True     # passed as savefig to preening / eye-angle steps; also gates video writing
savenc = True      # gates writing the gathered xarrays to .nc files
looplim = 3        # maximum number of topdown trials to load
likthresh = 0.99   # passed to preen_topdown_data as thresh
pxlthresh = 50     # passed to eye_angles as pxl_thresh
coordcor = 0       # passed to preen_topdown_data as coord_correction_val

In [None]:
# find the files wanted from the given dlcpath and vidpath
def _glob_in(folder, pattern):
    '''Return the paths inside folder whose names match the glob pattern.'''
    return glob(os.path.join(folder, pattern))


# DeepLabCut point locations
topdown_file_list = _glob_in(dlcpath, '*top*DeepCut*.h5')
righteye_file_list = _glob_in(dlcpath, '*eye1r*DeInter2*.h5')
lefteye_file_list = _glob_in(dlcpath, '*eye2l*DeInter2*.h5')
# videos that those points were tracked on
righteye_vid_list = _glob_in(vidpath, '*eye1r*.avi')
lefteye_vid_list = _glob_in(vidpath, '*eye2l*.avi')
topdown_vid_list = _glob_in(vidpath, '*top*.avi')
worldcam_vid_list = _glob_in(vidpath, '*world*.avi')
# camera timestamp files
righteye_time_file_list = _glob_in(dlcpath, '*eye1r*TS*.csv')
lefteye_time_file_list = _glob_in(dlcpath, '*eye2l*TS*.csv')
topdown_time_file_list = _glob_in(dlcpath, '*topTS*.csv')

# drop recordings known to cause issues, then sort the topdown files that drive every other lookup
_excluded_recordings = ('1_110719_01', '2_110719_08', '1_110719_11')
topdown_file_list = sorted(
    path for path in topdown_file_list
    if not any(bad in path for bad in _excluded_recordings)
)

In [None]:
# loop through each topdown DLC point .h5 file and stack every trial into xarray objects


def _matching(paths, mouse_key, trial_key):
    '''Return the paths whose name contains both the mouse key and the trial key.'''
    # `mouse_key and trial_key in p` would parse as `mouse_key and (trial_key in p)` and never test the mouse key
    return [p for p in paths if mouse_key in p and trial_key in p]


def _first_or_none(paths):
    '''Return the first path of a list of glob matches, or None when nothing matched.'''
    return paths[0] if paths else None


def _read_time_or_none(time_file, pts):
    '''Read a camera timestamp file sized to len(pts); return None if the file or the points are missing.'''
    if time_file is None or pts is None:
        return None
    return read_time(time_file, len(pts))


def _stack_trial(stack, pts, dims, trial_id):
    '''
    Wrap one trial's points in a DataArray labelled with trial_id and concatenate it onto stack
    along the 'trial' dimension. When stack is None (no trial seen yet) the new DataArray is returned.
    '''
    trial_da = xr.DataArray(pts, dims=dims)
    trial_da['trial'] = trial_id
    if stack is None:
        return trial_da
    return xr.concat([stack, trial_da], dim='trial', fill_value=np.nan)


def _join_time_column(time_df, times, trial_id):
    '''Add one trial's timestamps to time_df as a column named trial_id (or start time_df if it is None).'''
    column = pd.DataFrame(times, columns=[trial_id])
    return column if time_df is None else time_df.join(column)


def _stack_time(all_ts, time_df, trial_id):
    '''Convert time_df to a DataArray labelled trial_id and concatenate it onto all_ts along 'trial'.'''
    ts = xr.DataArray(time_df)
    ts['trial'] = trial_id
    return ts if all_ts is None else xr.concat([all_ts, ts], dim='trial')


# None means "no trial with this kind of data has been stacked yet"; the first trial that has the
# data starts the stack, so a camera missing on the first trial no longer breaks later trials
topdown = lefteye = righteye = None
topdown_time_df = lefteye_time_df = righteye_time_df = None
all_topdownTS = all_lefteyeTS = all_righteyeTS = None

loop_count = 0
trial_id_list = []
for file in topdown_file_list:
    if loop_count >= looplim:
        break

    # get unique sections of filename out so that they can be used to find the associated files
    file_name = os.path.split(file)[1]
    mouse_key = file_name[0:5]
    trial_key = file_name[17:28]

    # matching eye DLC files and camera time files; None when a camera has no file for this trial
    lefteye_file = _first_or_none(_matching(lefteye_file_list, mouse_key, trial_key))
    righteye_file = _first_or_none(_matching(righteye_file_list, mouse_key, trial_key))
    topdown_time_file = _first_or_none(_matching(topdown_time_file_list, mouse_key, trial_key))
    lefteye_time_file = _first_or_none(_matching(lefteye_time_file_list, mouse_key, trial_key))
    righteye_time_file = _first_or_none(_matching(righteye_time_file_list, mouse_key, trial_key))

    # read in the data from file locations
    # NOTE(review): assumes read_data tolerates a None eye path (read_in_eye skips reading when
    # given None) and then returns None eye points — confirm against read_data
    topdown_pts, topdown_names, lefteye_pts, lefteye_names, righteye_pts, righteye_names = read_data(file, lefteye_file, righteye_file)

    # make a unique name for the mouse and the recording trial
    trial_id = 'mouse_' + str(mouse_key) + '_trial_' + str(trial_key)
    trial_id_list.append(trial_id)

    # read in the time stamp data of each camera for this trial (None when unavailable)
    topdown_time = _read_time_or_none(topdown_time_file, topdown_pts)
    lefteye_time = _read_time_or_none(lefteye_time_file, lefteye_pts)
    righteye_time = _read_time_or_none(righteye_time_file, righteye_pts)

    print('building xarrays for trial ' + trial_id)
    # build one DataArray that stacks up all topdown trials along the 'trial' dimension
    if topdown_time is not None:
        topdown = _stack_trial(topdown, topdown_pts, ['frame', 'point_loc'], trial_id)
        topdown_time_df = _join_time_column(topdown_time_df, topdown_time, trial_id)
    else:
        print('trial ' + trial_id + ' has no topdown time data')

    # build one DataArray per eye that stacks up all trials along the 'trial' dimension
    if lefteye_pts is not None and lefteye_time is not None:
        lefteye = _stack_trial(lefteye, lefteye_pts, ['frame', 'point_loc'], trial_id)
        lefteye_time_df = _join_time_column(lefteye_time_df, lefteye_time, trial_id)
    else:
        print('trial ' + trial_id + ' has no left eye camera data')

    # NOTE(review): the right eye uses a 'time' dimension while topdown/left eye use 'frame';
    # kept as-is because the right-eye video overlay appears to select with .sel(time=...)
    if righteye_pts is not None and righteye_time is not None:
        righteye = _stack_trial(righteye, righteye_pts, ['time', 'point_loc'], trial_id)
        righteye_time_df = _join_time_column(righteye_time_df, righteye_time, trial_id)
    else:
        print('trial ' + trial_id + ' has no right eye camera data')

    # turn time pandas objects into xarrays
    # NOTE(review): each pass wraps the whole accumulated time table (all trials so far), not only
    # this trial's column — behaviour kept from the original; confirm it is intended
    if topdown_time is not None:
        all_topdownTS = _stack_time(all_topdownTS, topdown_time_df, trial_id)
    if lefteye_time is not None:
        all_lefteyeTS = _stack_time(all_lefteyeTS, lefteye_time_df, trial_id)
    if righteye_time is not None:
        all_righteyeTS = _stack_time(all_righteyeTS, righteye_time_df, trial_id)

    loop_count = loop_count + 1

In [None]:
# clean up the raw topdown DLC points, then label the resulting DataArray 'topdown'
print('preening top-down points')
preened_topdown = preen_topdown_data(
    topdown,
    trial_id_list,
    topdown_names,
    savepath,
    savefig=savefig,
    coord_correction_val=coordcor,
    thresh=likthresh,
).rename('topdown')

In [None]:
with warnings.catch_warnings():
    # the ellipse-parameter functions emit recurring runtime warnings; silence them for this cell only
    warnings.simplefilter("ignore")

    # (message, points, point names, timestamps, side, output name) for each eye, left first
    _eye_jobs = (
        ('getting left eye angles', lefteye, lefteye_names, lefteye_time_df, 'left', 'left_ellipse'),
        ('getting right eye angles', righteye, righteye_names, righteye_time_df, 'right', 'right_ellipse'),
    )
    _fitted = {}
    for _msg, _pts, _names, _times, _side, _out_name in _eye_jobs:
        print(_msg)
        _fitted[_side] = eye_angles(_pts, _names, trial_id_list, savepath, _times,
                                    savefig=savefig, side=_side, pxl_thresh=pxlthresh).rename(_out_name)
    left_ellipse = _fitted['left']
    right_ellipse = _fitted['right']
print('done getting eye angles')

In [None]:
# STILL WORKING ON THIS
# get out the topdown head angle: rotate every complete frame's head points onto a reference frame

import matplotlib.pyplot as plt
import numpy as np


def rotmat(theta):
    '''
    Return the 2x2 rotation matrix [[cos, -sin], [sin, cos]] for angle theta (radians) as nested lists.

    Right-multiplying an (n_points, 2) array of row-vector points by this matrix rotates the
    points by -theta.
    '''
    m = [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
    return m


all_trial_pts = preened_topdown
pt_names = topdown_names
thresh = 0.99
cricket = True

# candidate rotation angles: 99 evenly spaced values strictly between 0 and 2*pi
theta = np.linspace(0, 2 * np.pi, 101)[1:-1]

for current_trial_name in trial_id_list:
    pt_input = all_trial_pts.sel(trial=current_trial_name)

    # -- function starts here --
    # trial-local name list, so the shared pt_names is not truncated again on every trial
    trial_pt_names = pt_names
    if cricket is True:
        # drop the last three tracked points (presumably the cricket's) so only head points remain
        # NOTE(review): this slices the FIRST axis of pt_input; if that axis is frames rather than
        # point locations, confirm whether [:, :-3] was intended
        pt_input = pt_input[:-3, :]
        trial_pt_names = pt_names[:-3]

    x_vals, y_vals, likeli_pts = split_xyl(trial_pt_names, pt_input, thresh)

    # data is indexed below as (x/y, point, frame)
    data = np.asarray(np.stack([x_vals.T, y_vals.T]), dtype=float)
    num_ideal_points = data.shape[1]
    n_frames = data.shape[2]

    # per-frame mean of the points (NaN when any point is missing); kept 2-D even for one frame
    centroid = np.mean(data, axis=1)
    centered = data - centroid[:, np.newaxis, :]
    theta_good = np.full(n_frames, np.nan)
    aligned_good = np.full(data.shape, np.nan)

    # a frame is complete when every point has a y value
    complete = np.count_nonzero(~np.isnan(data[1]), axis=0) == num_ideal_points
    if not complete.any():
        print('trial ' + current_trial_name + ' has no frame with every head point; skipping alignment')
        continue
    # last complete frame will be used as reference frame
    ref = centered[:, :, np.flatnonzero(complete)[-1]]

    # for every complete frame, try each candidate angle and keep the one that best matches the reference;
    # incomplete frames keep NaN angle and points
    for frame in range(n_frames):
        if not complete[frame]:
            continue
        c = centered[:, :, frame]
        rms = np.zeros(len(theta))
        for i in range(len(theta)):
            c_rot = np.matmul(c.T, rotmat(theta[i]))  # rotation
            rms[i] = np.nansum((ref - c_rot.T) ** 2)  # summed squared error against the reference
        # store the best-fitting angle itself, not its index into theta
        theta_good[frame] = theta[np.argmin(rms)]
        aligned_good[:, :, frame] = np.matmul(c.T, rotmat(theta_good[frame])).T

In [None]:
# average the aligned complete frames (incomplete frames are all-NaN) into one template head
mean_head = np.nanmean(aligned_good, axis=-1)

In [None]:
# rotate mean head to align to x-axis
# mean_head is (x/y, point); the long axis runs from point num_ideal_points-1 to point 0, matching the
# MATLAB reference `meanHead([nPts 1],:)` (MATLAB's 1-based point 1 is index 0 here)
mean_head_f = np.asarray(mean_head, dtype=float)
longaxis = mean_head_f[:, [num_ideal_points - 1, 0]]  # line from middle of head to nose
longtheta = np.arctan2(np.diff(longaxis[1]), np.diff(longaxis[0]))[0]  # angle of line
headrot = rotmat(-longtheta)

# MATLAB applies `aligned * headRot'`: right-multiplying row-vector points by the TRANSPOSE of
# rotmat(-longtheta) rotates them by -longtheta, which brings the long axis onto +x
headrot_t = np.transpose(headrot)
aligned_good_f = np.asarray(aligned_good, dtype=float)
aligned = np.full(aligned_good_f.shape[::-1], np.nan)  # (frame, point, x/y) while filling

for frame in range(aligned_good_f.shape[2]):
    aligned[frame, :, :] = np.matmul(aligned_good_f[:, :, frame].T, headrot_t)

aligned = aligned.T  # back to (x/y, point, frame)

In [None]:
# template head after aligning its long axis to x; rows are x/y, columns are points
mean_head1 = np.nanmean(np.asarray(aligned, dtype=float), axis=2)

# distance of each template point from the origin (the points were centered before alignment)
mean_dist = np.hypot(mean_head1[0, :], mean_head1[1, :])

n_frames = np.size(centroid, axis=1)
cent = np.full((2, n_frames), np.nan)

# get all centroids: for each frame, search a 1-pixel grid covering the tracked head points for the
# position whose distances to those points best match mean_dist
for frame in range(n_frames):
    c = np.asarray(data[:, :, frame], dtype=float)
    visible = ~np.isnan(c[0, :]) & ~np.isnan(c[1, :])
    if not visible.any():
        continue  # nothing tracked in this frame; its centroid stays NaN

    # make a mesh grid over the x/y extent of the visible points (missing points are ignored)
    xs = np.arange(np.floor(c[0, visible].min()), np.ceil(c[0, visible].max()) + 1)
    ys = np.arange(np.floor(c[1, visible].min()), np.ceil(c[1, visible].max()) + 1)
    meshx, meshy = np.meshgrid(xs, ys, sparse=False)

    # for each head point calculate how far the pixels are from it, calculate error of how
    # far this is from where it should be, and add these up
    err = np.zeros(meshx.shape)
    for i in np.flatnonzero(visible):
        r = np.sqrt((meshx - c[0, i])**2 + (meshy - c[1, i])**2)  # distance
        err += (mean_dist[i] - r)**2  # error
    # find minimum, then get x and y values and set as centroid
    indi, indj = np.unravel_index(np.argmin(err), err.shape)
    cent[0, frame] = meshx[indi, indj]
    cent[1, frame] = meshy[indi, indj]


In [None]:
# re-center every head point on its frame's best-fit centroid, broadcasting cent across the point axis
centered[:, :num_ideal_points, :] = data[:, :num_ideal_points, :] - cent[:, np.newaxis, :]

In [None]:
# now, align all timepoints to the x-axis-aligned template head (mean_head1), so each frame's best
# rotation angle is measured against a head whose long axis lies on +x
theta = np.linspace(0, 2 * np.pi, 101)[1:-1]  # candidate angles, excluding 0 and 2*pi
template = np.asarray(mean_head1, dtype=float)
centered_f = np.asarray(centered, dtype=float)

allaligned = np.full(centered_f.shape, np.nan)
alltheta = np.full(np.size(centroid, axis=1), np.nan)

for frame in range(np.size(centroid, axis=1)):
    num_ideal_points = np.size(data[0, :, frame], axis=0)
    num_real_pts = np.count_nonzero(~np.isnan(data[1, :, frame]))
    if num_real_pts < 3:
        continue  # too few points to fit a rotation; angle and points stay NaN
    c = centered_f[:, :, frame]
    rms = np.zeros(len(theta))
    for i in range(len(theta)):
        c_rot = np.matmul(c.T, rotmat(theta[i]))  # rotation
        rms[i] = np.nansum((template - c_rot.T) ** 2)  # summed squared error against the template
    # keep the best-fitting angle itself, not its index into theta
    alltheta[frame] = theta[np.argmin(rms)]
    allaligned[:, :, frame] = np.matmul(c.T, rotmat(alltheta[frame])).T

In [None]:
# head angle was negative of what we want, so this fixes that
alltheta = 2 * np.pi - np.asarray(alltheta, dtype=float)
# wrap into (-pi, pi]: only angles above pi are shifted down by one full turn (NaN stays NaN)
alltheta = np.where(alltheta > np.pi, alltheta - 2 * np.pi, alltheta)

In [None]:
# plot the head angle of every frame
fig, ax = plt.subplots(figsize=(15, 15))
ax.plot(alltheta)
ax.set(xlabel='frame', ylabel='angle', title='head theta over frames')
plt.show()

In [None]:
# notebook display: the per-frame aligned head points
allaligned

In [None]:
# notebook display: the per-frame head angles
alltheta

In [None]:
# MATLAB reference for the "rotate mean head to align to x-axis" step, kept as comments so the
# notebook can run top to bottom (raw MATLAB in a Python cell fails as an unknown IPython cell magic):
#
# %%% rotate mean head to align to x-axis
# longAxis = meanHead([nPts 1],:); %%% line between middle of head and nose points
# longTheta = atan2(diff(longAxis(:,2)), diff(longAxis(:,1)));  %%% angle of line
# headRot= rotmat(-longTheta);  %%% rotation matrix to fix this
# for i = 1:size(aligned,3);
#     aligned(:,:,i) = aligned(:,:,i)*headRot';
# end
# meanHead = nanmean(aligned,3);

In [None]:
# distance of each mean-head point from the origin (the per-point distances the centroid search matches);
# NaN points stay NaN because NaN**2 is NaN
mean_head_f = np.asarray(mean_head, dtype=float)  # float so np.sqrt works even if mean_head is object dtype
xcentsq = mean_head_f[0, :]**2
ycentsq = mean_head_f[1, :]**2  # y row (row 1)
np.sqrt(xcentsq + ycentsq)

In [None]:
# reset the centroid array: one (x, y) column per frame of aligned_good (its last axis)
cent = np.zeros((2, np.shape(aligned_good)[-1]), dtype=object)

In [None]:
# notebook display: shape of the centroid array (x/y rows, one column per frame)
np.shape(cent)

In [None]:

            

            mean_distance = 

            # make a mesh grid that covers x/y position of all head points at this time
            

            # for each head point calculate how far the pixels are from it
            # then calculate error of how far this is from where it should be, then add these up
#             for frame in range(0, len(data)):
#                 err = 0
#                 theta_all = []
#                 aligned_all = []
#                 for pt_time in range(0, num_points):
#                     pt = data[:, i]
#                     if pt != np.nan:
#                         r = np.sqrt((meshx - data(pt, 0))**2 + (meshy - data(pt, 0))**2) # distance
#                         err = err + (mean_distance[i] - r)**2 # error

#                     num_real_pts = num_points - np.count_nonzero(~np.isnan(pt[0]))

#                     # do the alignment if there are at least 4 good points
#                     if num_real_pts >= 4:
#                         c = centered[i]
#                         # if there are no NaNs and it's a perfect timepoint, loop through a range of thetas and rotate the frame by that much
#                         # then, calculate how well it matches the reference
#                         theta = np.linspace(0, (2 * math.pi), 101)
#                         theta = theta[1:-1]
#                         rms = np.zeros(len(theta))
#                         for i in range(0, len(theta)):
#                             c_rot = c * rotmat(theta[i]) # rotation
#                             rms[i] = np.nansum((ref - c_rot) **2) # root mean squared
#                         # find smallest error and rotate by this amount
#                         y, ind = min(rms)
#                         theta_out = 2 * math.pi - theta(ind)
#                         aligned_out = c * rotmat(theta(ind))
#                         theta_all.append(theta_out)
#                         aligned_all.append(aligned_out)
#                     elif num_real_pts < 4:
#                         theta_out = np.nan
#                         aligned_out = np.nan
#                         theta_all.append(theta_out)
#                         aligned_all.append(aligned_out)

In [None]:
# notebook scratch check: pt_input with the last four entries of its first axis dropped
pt_input[:-4, :]

In [None]:
# STILL WORKING ON THIS
# eye_calibration.py

import matplotlib.pyplot as plt
import numpy as np

# Work-in-progress port of the eye-calibration figures; as configured it only processes the left eye.
# NOTE(review): as written, the np.linalg.lstsq line raises AttributeError (A is a list), so no
# calibration figure is produced yet — treat this cell as a draft.
all_eye_ellipses = left_ellipse
all_eye_dlc_pts = lefteye
side = 'left'
savepath_input = savepath
# frames whose ellipse short/long axis ratio is below this threshold are kept as good frames
ell_thresh = 0.90

for trial_num in range(0, len(trial_id_list)):
        current_trial_name = trial_id_list[trial_num]
        # NOTE(review): .sel returns a DataArray, never None, so this test is always true
        if all_eye_ellipses.sel(trial=current_trial_name) is not None:
            ellipse_data = all_eye_ellipses.sel(trial=current_trial_name)
            # dlc_data is not used yet by the code below
            dlc_data = all_eye_dlc_pts.sel(trial=current_trial_name)

            # get out parameters of the selected trial
            thetas = ellipse_data.sel(ellipse_params='theta').values
            phis = ellipse_data.sel(ellipse_params='phi').values
            longaxes = ellipse_data.sel(ellipse_params='longaxis_all').values
            shortaxes = ellipse_data.sel(ellipse_params='shortaxis_all').values
            camcenter = (ellipse_data['cam_center_x'].values, ellipse_data['cam_center_y'].values)
            
            
            
            
            
            
            
            
            
            
            
            # indices of frames whose short/long axis ratio is below ell_thresh
            good_ell_times = np.argwhere((shortaxes / longaxes) < ell_thresh)
            
            # calibration figure 1: eye axes relative to center
            # NOTE(review): camcenter holds the cam_center_x/y values but is passed to cos/sin as if it
            # were an angle — confirm which ellipse quantity this least-squares fit should use
            # NOTE(review): A is a plain Python list, so `A.T` on the lstsq line raises AttributeError;
            # it probably needs to be built with np.array(...)
            A = [np.cos(camcenter), np.sin(camcenter)]
            b = np.diag(np.matmul(A, np.stack([np.array(thetas),np.array(phis)]).T))
            # NOTE(review): np.linalg.lstsq returns a tuple (solution, residuals, rank, singular values),
            # so cent_adj[...] in the plotting loop indexes that tuple, not the fitted values
            cent_adj = np.linalg.lstsq(np.matmul(A.T, A), np.matmul(A.T, b))
            
            plt.figure(figsize=(10,10))
            for i in range(0, len(good_ell_times)):
                plt.plot(thetas[good_ell_times[i]] + [-5 * np.cos(cent_adj[good_ell_times[i]]), 5 * np.cos(cent_adj[good_ell_times[i]])], phis[good_ell_times[i]] + [-5 * np.sin(cent_adj[good_ell_times[i]]), 5 * np.sin(cent_adj[good_ell_times[i]])], 'ko')
            plt.plot(camcenter[0], camcenter[1], 'r*')
            plt.title('eye axes relative to center')
            # NOTE(review): plt.show is referenced but not called, so this line has no effect
            plt.show
            
            # calibration figure 2: example frame's parameters
            # (everything below is a commented-out, unfinished port of figures 2-4 and the save step)
#             ellipticity = shortaxes / longaxes
#             p = 1 - (ellipticity) ** 2
#             q = np.linalg.norm((np.stack([np.array(thetas),np.array(phis)]).T - camcenter), 2, 1).T
#             pix2deg_scale = np.nansum(np.sqrt(p.T) * q) / np.nansum(p)
#             theta_rad = np.asin((thetas - cent_adj[0]) * (1 / pix2deg_scale))
#             theta_deg = np.asind((thetas - cent_adj[0]) * (1 / pix2deg_scale))
#             phi_deg = np.asind((phis - cent_adj[1]) / cos(theta_rad * (1 / pix2deg_scale)))
#             ind = 50
#             R = np.linspace(0, 2*pi, 100)
#             w = cent_adj
#             L = longaxes[ind]
#             l = shortaxes[ind]
#             xc = thetas[ind]
#             yc = phi[ind]
#             rotation1 = [[np.cos(w), -np.sin(w)],[np.sin(w), np.cos(w)]]
#             L1 = [[L,0],[0,l]]
#             c1 = [[xc],[yc]]
#             q = [[np.cos(R)],[np.sin(R)]]
#             qstar = rotation1 * L1 * q + c1
#             qcirc1 = [[L / pix2deg_scale, 0],[0, L / pix2deg_scale]] * q
#             qcirc2 = [[L,0],[0,L]] * q + cent_adj
#             theta2 = np.real(np.asin((thetas - cent_adj)* (1 / pix2deg_scale)))
#             phi2 = np.real(np.asin(phis - cent_adj) / (np.cos(theta2) * (1 / pix2deg_scale)))
#             new_cent = pix2deg_scale * [[np.sin(theta2)],[np.sin(phi2) * np.cos(theta2)]] + cent_adj
#             pointsrot = new_cent + pix2deg_scale * [[np.cos(theta2), 0],[-np.sin(phi2), np.cos(phi2)]] * qcirc1
#             omega_val = camcent * 180 / np.pi
            
#             plt.figure(figsize=(10,10))
#             plt.plot(qstar[1], qstar[2], 'g-')
#             plt.plot(points)
#             # ...
#             plt.title('omega=' + str(omega_val))
#             plt.show
            
#             # calibration figure 3
#             xvals = np.linalg.norm(np.stack([np.array(thetas),np.array(phis)]).T - cent_adj, ord=None)
#             yvals = (pix2deg_scale * np.sqrt(1-(shortaxes / longaxes))**2)
#             calibR, calibM, b = np.regression(xvals, yvals.T)
            
#             plt.figure(figsize=(10,10))
#             plt.plot(xvals, yvals, 'k.')
#             plt.title('scale=' + str(pix2deg_scale) + ' r=' + str(calibR) + ' m=' + str(calibM))
#             plt.xlabel('pupil camera dist')
#             plt.ylabel('scale * ellipticity')
#             plt.show
            
#             # calibration figure 4: camera center calibration
#             delta = cent_adj - np.stack([np.array(thetas),np.array(phis)]).T
            
#             plt.figure(figsize=(10,10))
#             plot(np.linalg.norm(delta,2,1), 
            
            
            
            
            
#             plt.savefig(savepath_input + '/' + current_trial_name + '/' + str(side) + '_side_ellipse_calibration.png', dpi=300)
#             plt.close()

In [None]:
# notebook display: shape of the design matrix A built in the eye-calibration draft cell
np.shape(A)

In [None]:
# confirm that the eye tracking has done an alright job, left eye first, then right
print('checking calibration of eyes')
for _ellipse, _eye_pts, _eye_side in ((left_ellipse, lefteye, 'left'), (right_ellipse, righteye, 'right')):
    plot_check_eye_calibration(_ellipse, _eye_pts, trial_id_list, _eye_side, savepath)

In [None]:
# write annotated videos for every loaded trial when figures are being saved
if savefig is True:
    for current_trial in trial_id_list:
        print('writing video data')
        # trial ids are built as 'mouse_' + mouse_key + '_trial_' + trial_key, so slice the keys back out
        mouse_key = current_trial[6:11]
        trial_key = current_trial[18:]

        # get out the video associated with this trial for every camera viewpoint; both keys must
        # appear in the name (`mouse_key and trial_key in i` would never test the mouse key)
        righteye_vids = [i for i in righteye_vid_list if mouse_key in i and trial_key in i]
        lefteye_vids = [i for i in lefteye_vid_list if mouse_key in i and trial_key in i]
        topdown_vids = [i for i in topdown_vid_list if mouse_key in i and trial_key in i]
        worldcam_vids = [i for i in worldcam_vid_list if mouse_key in i and trial_key in i]

        # plot the points on each camera view and save them out separately
        topdown_vid = topdown_vids[0]
        td_pt_data = preened_topdown.sel(trial=current_trial)
        plot_pts_on_vid(current_trial, 't', topdown_vid, savepath, td_pt_data)

        # eye and world cameras are optional: an empty match list (IndexError) skips that view
        try:
            righteye_vid = righteye_vids[0]
            re_pt_data = righteye.sel(trial=current_trial)
            re_pt_ell = right_ellipse.sel(trial=current_trial)
            plot_pts_on_vid(current_trial, 'r', righteye_vid, savepath, re_pt_data, re_pt_ell)
        except IndexError:
            pass

        try:
            worldcam_vid = worldcam_vids[0]
            plot_pts_on_vid(current_trial, 'w', worldcam_vid, savepath)
        except IndexError:
            pass

        try:
            lefteye_vid = lefteye_vids[0]
            le_pt_data = lefteye.sel(trial=current_trial)
            le_pt_ell = left_ellipse.sel(trial=current_trial)
            plot_pts_on_vid(current_trial, 'l', lefteye_vid, savepath, le_pt_data, le_pt_ell)
        except IndexError:
            pass
print('done writing videos')

In [None]:
# save out the xarrays as .nc files
def _drop_ellipse_coords(ellipse_ds):
    '''
    Re-assign the camera-center values onto an ellipse dataset, then drop eye_side and both
    camera-center variables before the ellipse datasets are merged with the topdown data.
    '''
    for center_name in ('cam_center_x', 'cam_center_y'):
        ellipse_ds = ellipse_ds.assign({center_name: ellipse_ds[center_name].values})
    return xr.Dataset.drop_vars(ellipse_ds, names=['eye_side', 'cam_center_x', 'cam_center_y'])


if savenc is True:
    print('saving out xarray data')

    # fitted parameters: preened topdown points plus both eye ellipses -> params.nc
    ds_topdown = preened_topdown.to_dataset(name='topdown')
    ds_leftellipse = _drop_ellipse_coords(left_ellipse.to_dataset(name='left_ellipse'))
    ds_rightellipse = _drop_ellipse_coords(right_ellipse.to_dataset(name='right_ellipse'))
    gathered = xr.merge([ds_topdown, ds_leftellipse, ds_rightellipse])
    gathered_path = savepath + 'params.nc'
    gathered.to_netcdf(gathered_path)

    # raw (un-preened) DLC points for every camera -> raw_points.nc
    ds_topdown_raw, ds_lefteye_raw, ds_righteye_raw = (
        points.to_dataset(name=label)
        for points, label in ((topdown, 'topdown'), (lefteye, 'lefteye'), (righteye, 'righteye'))
    )
    gathered_raw = xr.merge([ds_topdown_raw, ds_lefteye_raw, ds_righteye_raw])
    gathered_raw_path = savepath + 'raw_points.nc'
    gathered_raw.to_netcdf(gathered_raw_path)

    # per-camera timestamps -> timestamps.nc
    ds_topdown_time, ds_lefteye_time, ds_righteye_time = (
        stamps.to_dataset(name=label)
        for stamps, label in ((all_topdownTS, 'topdown_time'),
                              (all_lefteyeTS, 'lefteye_time'),
                              (all_righteyeTS, 'righteye_time'))
    )
    gathered_time = xr.merge([ds_topdown_time, ds_lefteye_time, ds_righteye_time])
    gathered_time_path = savepath + 'timestamps.nc'
    gathered_time.to_netcdf(gathered_time_path)