# Visualize motif sequences

In [1]:
# Imports

import logging
logging.getLogger("param.Dimension").setLevel(logging.CRITICAL)
logging.getLogger("param.ParameterizedMetaclass").setLevel(logging.CRITICAL)
logging.getLogger("param.SpreadPlot").setLevel(logging.CRITICAL)
logging.getLogger("param.CurvePlot").setLevel(logging.CRITICAL)
logging.getLogger("param.AdjointLayout").setLevel(logging.CRITICAL)
logging.getLogger("param.HoloMap").setLevel(logging.CRITICAL)
logging.getLogger("param.OverlayPlot").setLevel(logging.CRITICAL)
logging.getLogger("param.BarPlot").setLevel(logging.CRITICAL)
logging.getLogger("param.ErrorPlot").setLevel(logging.CRITICAL)
logging.getLogger("param.RasterPlot").setLevel(logging.CRITICAL)
logging.getLogger("param.Layout").setLevel(logging.CRITICAL)
logging.getLogger("param.PointPlot").setLevel(logging.CRITICAL)
logging.getLogger("param.DynamicMap").setLevel(logging.CRITICAL)
logging.getLogger("param.Callable").setLevel(logging.CRITICAL)
logging.getLogger("param.Image").setLevel(logging.CRITICAL)
logging.getLogger("param.Overlay").setLevel(logging.CRITICAL)
logging.getLogger("param.Scatter").setLevel(logging.CRITICAL)
logging.getLogger("param.LayoutPlot").setLevel(logging.CRITICAL)
logging.getLogger("param.Curve").setLevel(logging.CRITICAL)

import os
import sys
sys.path.insert(0, os.path.abspath(r'D:\Code Repos\prey_capture'))

import panel as pn
import holoviews as hv
from holoviews import opts, dim
hv.extension('bokeh')
from bokeh.resources import INLINE

import paths
import functions_bondjango as bd
import pandas as pd
import numpy as np
import functions_plotting as fp
import functions_data_handling as fd
from scipy.stats import sem
import sklearn.decomposition as decomp
import umap
import sklearn.mixture as mix
from scipy.stats import sem
import importlib
import processing_parameters
import cv2
import pickle as pk
import functions_vame as fv
import functions_io as fi

from pprint import pprint as pp

In [2]:
# Get the path to the involved files
importlib.reload(processing_parameters)

# define the type of VAME
vame_type = 'prey_capture_15'
# define the number of frames to remove at beginning and end (due to VAME interval)
vame_interval = 15
# define the folder
target_folder = os.path.join(r'J:\Drago Guggiana Nilo\Prey_capture\temp_VAME', vame_type)

# load the sorting
motif_sort = np.array(processing_parameters.motif_sort)
motif_revsort = np.array(processing_parameters.motif_revsort)

# get a list of the result folders
result_list = os.listdir(os.path.join(target_folder,'results'))
# load the search string
vame_vis_string = processing_parameters.vame_vis_string

# Load the matching prey capture data

# using the slug, perform serial calls to the database
# (super inefficient, but this is temporary as the VAME data should be includedin the hdf5 file)

# for all the files

# define the search string
# query the database for data to plot
data_all = bd.query_database('analyzed_data', vame_vis_string)
data_path = data_all[0]['analysis_path']
data_vame_name = data_all[0]['slug'].replace('_preprocessing', '')

# load the data
beh_data = pd.read_hdf(data_path, 'full_traces')
beh_data = beh_data.iloc[vame_interval:-vame_interval, :].reset_index(drop=True)
# load the frame bounds
frame_bounds = pd.read_hdf(data_path, 'frame_bounds')

# load the latent and labels
label_list = motif_revsort[np.load(os.path.join(target_folder,'results',data_vame_name,'VAME',
                                  'kmeans-15','15_km_label_'+data_vame_name+'.npy'))]
              
latent_list = np.load(os.path.join(target_folder,'results',data_vame_name,'VAME',
                                   'kmeans-15','latent_vector_'+data_vame_name+'.npy')) 
               

# load the aligned data
data_list = np.load(os.path.join(target_folder,'data',data_vame_name,
                                 data_vame_name+'-PE-seq.npy'))
data_list = data_list[:, vame_interval:-vame_interval]
print(beh_data.shape)

(26, 31)


In [3]:
# get the dlc coordinates
likelihood_threshold = 0.8
file_path_dlc = os.path.join(paths.videoexperiment_path,
                                    data_all[0]['slug'].replace('_preprocessing', '_dlc.h5'))
# load the bonsai info
raw_h5 = pd.read_hdf(file_path_dlc)
# get the column names
column_names = raw_h5.columns

# DLC in small arena
filtered_traces = pd.DataFrame(raw_h5[[
    [el for el in column_names if ('mouseSnout' in el) and ('x' in el)][0],
    [el for el in column_names if ('mouseSnout' in el) and ('y' in el)][0],
    [el for el in column_names if ('mouseBarL' in el) and ('x' in el)][0],
    [el for el in column_names if ('mouseBarL' in el) and ('y' in el)][0],
    [el for el in column_names if ('mouseBarR' in el) and ('x' in el)][0],
    [el for el in column_names if ('mouseBarR' in el) and ('y' in el)][0],
    [el for el in column_names if ('mouseHead' in el) and ('x' in el)][0],
    [el for el in column_names if ('mouseHead' in el) and ('y' in el)][0],
    [el for el in column_names if ('mouseBody1' in el) and ('x' in el)][0],
    [el for el in column_names if ('mouseBody1' in el) and ('y' in el)][0],
    [el for el in column_names if ('mouseBody2' in el) and ('x' in el)][0],
    [el for el in column_names if ('mouseBody2' in el) and ('y' in el)][0],
    [el for el in column_names if ('mouseBody3' in el) and ('x' in el)][0],
    [el for el in column_names if ('mouseBody3' in el) and ('y' in el)][0],
    [el for el in column_names if ('mouseBase' in el) and ('x' in el)][0],
    [el for el in column_names if ('mouseBase' in el) and ('y' in el)][0],
    [el for el in column_names if ('cricketHead' in el) and ('x' in el)][0],
    [el for el in column_names if ('cricketHead' in el) and ('y' in el)][0],
    [el for el in column_names if ('cricketBody' in el) and ('x' in el)][0],
    [el for el in column_names if ('cricketBody' in el) and ('y' in el)][0],
]].to_numpy(), columns=['mouse_snout_x', 'mouse_snout_y', 'mouse_barl_x', 'mouse_barl_y',
                        'mouse_barr_x', 'mouse_barr_y', 'mouse_head_x', 'mouse_head_y',
                        'mouse_x', 'mouse_y', 'mouse_body2_x', 'mouse_body2_y',
                        'mouse_body3_x', 'mouse_body3_y', 'mouse_base_x', 'mouse_base_y',
                        'cricket_0_head_x', 'cricket_0_head_y', 'cricket_0_x', 'cricket_0_y'])

# get the likelihoods
likelihood_frame = pd.DataFrame(raw_h5[[
    [el for el in column_names if ('mouseSnout' in el) and ('likelihood' in el)][0],
    [el for el in column_names if ('mouseBarL' in el) and ('likelihood' in el)][0],
    [el for el in column_names if ('mouseBarR' in el) and ('likelihood' in el)][0],
    [el for el in column_names if ('mouseHead' in el) and ('likelihood' in el)][0],
    [el for el in column_names if ('mouseBody1' in el) and ('likelihood' in el)][0],
    [el for el in column_names if ('mouseBody2' in el) and ('likelihood' in el)][0],
    [el for el in column_names if ('mouseBody3' in el) and ('likelihood' in el)][0],
    [el for el in column_names if ('mouseBase' in el) and ('likelihood' in el)][0],
    [el for el in column_names if ('cricketHead' in el) and ('likelihood' in el)][0],
    [el for el in column_names if ('cricketBody' in el) and ('likelihood' in el)][0],
]].to_numpy(), columns=['mouse_snout', 'mouse_barl', 'mouse_barr', 'mouse_head', 'mouse', 'mouse_body2',
                        'mouse_body3', 'mouse_base',
                        'cricket_0_head', 'cricket_0'])

# nan the trace where the likelihood is too low
# for all the columns
for col in likelihood_frame.columns:
    # get the vector for nans
    nan_vector = likelihood_frame[col] < likelihood_threshold
    # nan the points
    filtered_traces.loc[nan_vector, col+'_x'] = np.nan
    filtered_traces.loc[nan_vector, col+'_y'] = np.nan
        
# trim to the current bounds
filtered_traces = \
    filtered_traces.iloc[frame_bounds.loc[0, 'start']:frame_bounds.loc[0, 'end']-1, :].reset_index(drop=True)

filtered_traces = filtered_traces.iloc[vame_interval:-vame_interval, :].reset_index(drop=True)
print(filtered_traces.shape)

(26, 20)


In [4]:
%%time
# get the video
# assemble the path
video_path = os.path.join(paths.videoexperiment_path,
                          data_all[0]['slug'].replace('_preprocessing', '.avi'))
# create the video object
cap = cv2.VideoCapture(video_path)
# allocate memory for the corners
frame_list = []
# # define sigma for the edge detection parameters
# sigma = 0.2
# get the frames to mode
for frames in np.arange(frame_bounds.loc[0, 'end']-1):

    # read the image
    frame_list.append(cap.read()[1])

# release the capture
cap.release()

frame_list = frame_list[frame_bounds.loc[0, 'start']:]
# keep this for the motif videos
frames_formotif = frame_list.copy()
# trim to interval
frame_list = frame_list[vame_interval:-vame_interval]
print(len(frame_list))
print(frame_list[0].shape)
print(latent_list.shape)

26
(1024, 1280, 3)
(26, 15)
Wall time: 1.32 s


In [5]:
%%time
# get the UMAP embedding
embedded_data = np.load(os.path.join(target_folder, 'UMAP_result.npy'))

# generate the model name
model_name = os.path.join(target_folder, 'UMAP_model.pk')
with open(model_name, 'rb') as file:
    reducer = pk.load(file)

Wall time: 7.73 s


In [6]:
%%time
# Embed the current data
current_data = reducer.transform(latent_list)
print(current_data.shape)

(26, 2)
Wall time: 5.46 s


In [7]:
# generate an animation showing the frame in the video next to the umap position of the pose

# define the indexes
cols_x = [el for el in beh_data.columns if (('mouse' in el) and ('_x' in el))]
cols_y = [el for el in beh_data.columns if (('mouse' in el) and ('_y' in el))]
cols_cricket_x = [el for el in beh_data.columns if (('cricket' in el) and ('_x' in el))]
cols_cricket_y = [el for el in beh_data.columns if (('cricket' in el) and ('_y' in el))]

index_x = [0, 2, 4, 6, 8, 10, 12, 14]
index_y = [1, 3, 5, 7, 9, 11, 13, 15]
cricket_idx_x = [16, 18]
cricket_idx_y = [17, 19]

def frame_plot(time):
    
    current_frame = frame_list[time]
    return hv.Image(current_frame).opts(
        invert_yaxis=True, invert_xaxis=True, cmap='Gray')#, width=600, height=600)

def umap_trajectory(time):
    current_points = current_data[:time, :]
    return hv.Scatter((current_points[:, 0], current_points[:, 1]))
def umap_current(time):
    current_point = current_data[time, :]
    return hv.Scatter((current_point[0], current_point[1])).opts(color='red')

def skeleton(time):
    current_skeleton_x = beh_data.loc[time, cols_x].to_numpy()/40 - 0.5
    current_skeleton_y = beh_data.loc[time, cols_y].to_numpy()/40 - 0.5
    
#     current_skeleton_x = (filtered_traces.loc[time, index_x].to_numpy()/1200) - 0.5
#     current_skeleton_y = 1-(filtered_traces.loc[time, index_y].to_numpy()/1200) - 0.5
    
    return hv.Curve((current_skeleton_x, current_skeleton_y))

def cricket_skeleton(time):
    current_cricket_x = beh_data.loc[time, cols_cricket_x].to_numpy()/40 - 0.5
    current_cricket_y = beh_data.loc[time, cols_cricket_y].to_numpy()/40 - 0.5
    return hv.Curve((current_cricket_x, current_cricket_y))

def motif(time):
    current_motif = label_list[:time]
    return hv.Curve(current_motif)

def static_pose(time):
    current_pose_x = data_list[index_x, time].T
    current_pose_y = data_list[index_y, time].T
    return hv.Curve((current_pose_x, current_pose_y))

def static_cricket(time):
    cricket_pose_x = data_list[cricket_idx_x, time].T
    cricket_pose_y = data_list[cricket_idx_y, time].T
    return hv.Curve((cricket_pose_x, cricket_pose_y))

frame_map = hv.DynamicMap(frame_plot, kdims=['time'])
umap_map = hv.DynamicMap(umap_trajectory, kdims=['time']).opts(xlim=(-10, 30), ylim=(-10, 30))
current_map = hv.DynamicMap(umap_current, kdims=['time']).opts(xlim=(-10, 30), ylim=(-10, 30))
skeleton_map = hv.DynamicMap(skeleton, kdims=['time'])
c_skeleton_map = hv.DynamicMap(cricket_skeleton, kdims=['time'])
motif_map = hv.DynamicMap(motif, kdims=['time']).opts(ylim=(-1, 31), xlim=(0, 200))
pose_map = hv.DynamicMap(static_pose, kdims=['time']).opts(xlim=(-40, 40), ylim=(-40, 40))
c_pose_map = hv.DynamicMap(static_cricket, kdims=['time'])

# umap_base = hv.Scatter(embedded_data)

# both_map = (frame_map).opts(width=600, height=400)
# sub_map = (umap_map*current_map+motif_map).opts(shared_axes=False)
# both_map = (frame_map*skeleton_map+sub_map).opts(
#     width=1000, height=800, shared_axes=False).cols(1)
both_map = (frame_map*skeleton_map*c_skeleton_map+umap_map*current_map+motif_map+pose_map*c_pose_map)
both_map.opts(width=1000, height=800, shared_axes=False).cols(2)
        
both_panel = pn.panel(both_map.redim.range(time=(0, len(frame_list)-1)), 
                      center=True, widget_location='top')
both_panel
#     current_coord = current_data[time, :]
    
    

  [cmap for cmap in cm.cmap_d if not
  [cmap for cmap in cm.cmap_d if not


In [8]:
# Get the motif locations

# get the motif number
motif_number = latent_list.shape[1]
# turn the movie into an array
movie_array = np.array(frame_list)
# allocate memory for all the locations
location_perfile = []
duration_perfile = []
# for all the motifs
for motif in np.arange(motif_number):

    # find all the starts and ends for this motif
    m_idx = (label_list==motif).astype(int)
    starts = np.argwhere(np.diff(np.pad(m_idx, (1, 1), mode='constant', constant_values=(0, 0)))==1)
    ends = np.argwhere(np.diff(np.pad(m_idx, (1, 1), mode='constant', constant_values=(0, 0)))==-1)

    # skip if any of the arrays is empty
    if (starts.shape[0] == 0) or (ends.shape[0] == 0):
        duration_perfile.append(np.empty((0, 1)))
        location_perfile.append(np.empty((0, 1)))
        continue
    # trim the starts and ends based on ordering
    if starts[0] > ends[0]:
        if ends.shape[0] > 1:
            ends = ends[1:]
        else:
            duration_perfile.append(np.empty((0, 1)))
            location_perfile.append(np.empty((0, 1)))
            continue
    if starts[-1] > ends [-1]:
        if starts.shape[0] > 1:
            starts = starts[:-1]
        else:
            duration_perfile.append(np.empty((0, 1))) 
            location_perfile.append(np.empty((0, 1)))
            continue
    # trim the starts or ends depending on size
    if starts.shape[0] > ends.shape[0]:
        starts = starts[:-1]
    if ends.shape[0] > starts.shape[0]:
        ends = ends[1:]
    # make sure the ends are always bigger than the starts
    try: 
        assert np.all((ends-starts)>0) 
    except AssertionError:
        print(str(idx)+'_'+str(motif))
        print(starts)
        print(ends)

    # save the locations for this motif
    location_perfile.append(starts)
    duration_perfile.append(ends-starts)
        
print(location_perfile[0])
# print(duration_perfile[0])
print(label_list)
# print(location_perfile)
# print(duration_perfile)

[]
[3 3 3 3 8 8 7 7 7 5 5 5 5 5 5 5 5 5 5 7 7 7 7 7 7 7]


In [9]:
# Get the snippets for a single motif

# define the target motif
target_motif = 13

# allocate a list for the frames
video_frames = []
indexes = []

# for all instances of the motif
for idx, instances in enumerate(location_perfile[target_motif]):

    # get the indexes
    instance_idx = np.array(np.arange(instances[0], instances[0]+duration_perfile[target_motif][idx][0]))
    
    # get the video frames
    video_frames.append(movie_array[instance_idx])
    indexes.append(instance_idx)

# plot
video_frames = np.concatenate(video_frames, axis=0)
indexes = np.concatenate(indexes, axis=0)

# create the function for the dynamic map
def show_frame(time):
    current_frame = video_frames[time]
    return hv.Image(current_frame).opts(invert_yaxis=True, invert_xaxis=True, cmap='Gray')
def skeleton(time):
    current_skeleton_x = beh_data.loc[indexes[time], cols_x].to_numpy()/40 - 0.5
    current_skeleton_y = beh_data.loc[indexes[time], cols_y].to_numpy()/40 - 0.5
    return hv.Curve((current_skeleton_x, current_skeleton_y))

def cricket_skeleton(time):
    current_cricket_x = beh_data.loc[indexes[time], cols_cricket_x].to_numpy()/40 - 0.5
    current_cricket_y = beh_data.loc[indexes[time], cols_cricket_y].to_numpy()/40 - 0.5
    return hv.Curve((current_cricket_x, current_cricket_y))

# create the dynamic map
frame_map = hv.DynamicMap(show_frame, kdims=['time'])
skeleton_map = hv.DynamicMap(skeleton, kdims=['time'])
c_skeleton_map = hv.DynamicMap(cricket_skeleton, kdims=['time'])

both_map = (frame_map*skeleton_map*c_skeleton_map).opts(width=800, height=600, shared_axes=False)
        
both_panel = pn.panel(both_map.redim.range(time=(0, len(video_frames)-1)), 
                      center=True, widget_location='top')
both_panel


ValueError: need at least one array to concatenate

In [None]:
# Align the video egocentrically
importlib.reload(fv)
# define the path for saving the movies
temp_path = paths.temp_path
# clean the folder
fi.delete_contents(temp_path)

# create a bounded movie to align later
# assemble the bounded movie path
bounded_path = os.path.join(temp_path, 'bounded.avi')

# save the bounded movie
# get the width and height
width = frame_list[0].shape[1]
height = frame_list[0].shape[0]

# create the writer
out = cv2.VideoWriter(bounded_path,cv2.VideoWriter_fourcc('M','J','P','G'), 10, (1280,1024))
# save the movie
for frames in frames_formotif:
    out.write(frames)

out.release()

# create the egocentric movie
path_dlc = data_path
path_vame = target_folder
file_format = '.avi'
crop_size = (200, 200)
use_video = True
check_video = False
save_align = False

_, video_frames = fv.run_alignment(path_dlc, path_vame, file_format, crop_size,
                                   use_video=use_video, check_video=check_video, 
                                   save_align=save_align, video_path=bounded_path)
# load the egocentric movie

# create the single motif segments as above

In [None]:
# create the egocentric movie

# assemble the bounded movie path
egocentric_path = os.path.join(temp_path, 'egocentric.avi')

# save the bounded movie
# get the width and height
width = frame_list[0].shape[1]
height = frame_list[0].shape[0]
print(np.max(video_frames[0]))
# hv.Image(video_frames[10]).opts(tools=['hover'])

# cv2.imshow('Frame', video_frames[0])

# print(video_frames[2].shape)

# create the writer
out2 = cv2.VideoWriter(egocentric_path,cv2.VideoWriter_fourcc('M','J','P','G'), 10, (2000,2000))
# save the movie
for frames in video_frames:
    out2.write(frames.astype('uint8'))

out2.release()