In [None]:
# imports
# import pixiedust
import logging
# Silence the param/holoviews plotting loggers so their warnings do not flood the output
_QUIET_LOGGERS = (
    "param.Dimension",
    "param.ParameterizedMetaclass",
    "param.SpreadPlot",
    "param.CurvePlot",
    "param.AdjointLayout",
    "param.HoloMap",
    "param.OverlayPlot",
    "param.BarPlot",
    "param.ErrorPlot",
    "param.RasterPlot",
    "param.Layout",
    "param.PointPlot",
)
for _logger_name in _QUIET_LOGGERS:
    logging.getLogger(_logger_name).setLevel(logging.CRITICAL)

import os
import sys
sys.path.insert(0, os.path.abspath(r'D:\Code Repos\prey_capture'))

import panel as pn
import holoviews as hv
from holoviews import opts, dim
hv.extension('bokeh')
from bokeh.resources import INLINE

import paths
import functions_bondjango as bd
import pandas as pd
import numpy as np
import functions_plotting as fp
import functions_data_handling as fd
import functions_kinematic as fk
from scipy.stats import sem
import sklearn.decomposition as decomp
import umap
import sklearn.mixture as mix
from scipy.stats import sem
import pickle as pk
import itertools as it
import processing_parameters

from pprint import pprint as pp

# # define the name to be used for the saved figures
# save_name = 'acrossTrials'
# line_width = 5

In [None]:
# Load the analysed VAME outputs for every trial

# VAME project variant; selects the sub-folder under the temp_VAME root
vame_type = 'prey_capture_15'
target_folder = os.path.join(r'J:\Drago Guggiana Nilo\Prey_capture\temp_VAME', vame_type)

# one sub-folder per trial under results/; its name is reused as the trial identifier
result_list = os.listdir(os.path.join(target_folder, 'results'))


def _vame_result_path(trial, file_name):
    """Return the path of `file_name` inside the kmeans-15 VAME results of `trial`."""
    return os.path.join(target_folder, 'results', trial, 'VAME', 'kmeans-15', file_name)


# per-trial motif labels, then per-trial latent vectors
label_list = [np.load(_vame_result_path(trial, '15_km_label_' + trial + '.npy')) for trial in result_list]
latent_list = [np.load(_vame_result_path(trial, 'latent_vector_' + trial + '.npy')) for trial in result_list]

# per-trial aligned data (the '-PE-seq' arrays)
data_list = [np.load(os.path.join(target_folder, 'data', trial, trial + '-PE-seq.npy')) for trial in result_list]

# motif ordering and (presumably) its inverse permutation, used to relabel motifs in sorted order
motif_sort = np.array(processing_parameters.motif_sort)
motif_revsort = np.array(processing_parameters.motif_revsort)


In [None]:
# %%time
# # UMAP embedding of the VAME data

# # compile the data
# compiled_latent = np.vstack(latent_list)
# # embed using UMAP
# # original parameters 0.5 and 10
# # 0.1 and 30 also works
# # 0.05 and 30 works too
# reducer = umap.UMAP(min_dist=0.5, n_neighbors=10)
# embedded_data = reducer.fit_transform(compiled_latent)

# # save the embedding
# np.save(os.path.join(target_folder, 'UMAP_result'), embedded_data)

# # generate the model name
# model_name = os.path.join(target_folder, 'UMAP_model.pk')
# # save the estimator
# with open(model_name, 'wb') as file:
#     pk.dump(reducer, file)

In [None]:
# Reload the stored UMAP artefacts instead of recomputing the embedding
model_name = os.path.join(target_folder, 'UMAP_model.pk')
umap_result_path = os.path.join(target_folder, 'UMAP_result.npy')

embedded_data = np.load(umap_result_path)

# NOTE: pickle.load can execute arbitrary code; only open models generated by this project
with open(model_name, 'rb') as model_file:
    reducer = pk.load(model_file)

In [None]:

# Calculate, per file and motif, the mean and SEM of the aligned data at the frames
# where the motif starts (slot 0) and at the first frame after it ends (slot 1)

# NOTE(review): the motif count is taken from the latent dimension — confirm it equals
# the number of clusters
n_motif_slots = latent_list[0].shape[1]

motif_average = []
motif_sem = []
for data, labels in zip(data_list, label_list):
    n_features = data.shape[0]
    # (motif, feature, start/end)
    file_average = np.zeros((n_motif_slots, n_features, 2))
    file_sem = np.zeros((n_motif_slots, n_features, 2))

    for motif in range(n_motif_slots):
        # +1 where the motif begins, -1 where it stops
        transitions = np.diff((labels == motif).astype(int))
        for slot, step in enumerate((1, -1)):
            frames = np.argwhere(transitions == step) + 1
            samples = data[:, frames]
            file_average[motif, :, slot] = np.nanmean(samples, axis=1).flatten()
            file_sem[motif, :, slot] = sem(samples, axis=1, nan_policy='omit').flatten()

    motif_average.append(file_average)
    motif_sem.append(file_sem)

In [None]:
# Calculate the distribution of bout lengths (durations) for every motif
#
# motif_duration[file][motif] holds the duration of every complete bout and
# motif_location[file][motif] its start frame, both as (n_bouts, 1) arrays.
# Bouts already running at the first frame or still running at the last frame are
# dropped because their true length is unknown.

motif_duration = []
motif_location = []
# NOTE(review): the motif count is taken from the latent dimension — confirm it equals
# the number of clusters
num_motifs = latent_list[0].shape[1]


def _complete_bouts(labels, motif):
    """Return (starts, ends) of every complete bout of `motif` in `labels`.

    Both are (n_bouts, 1) integer arrays with starts[i] < ends[i]; an end is the first
    frame after the bout. May be empty (shape (0, 1)).
    """
    transitions = np.diff((labels == motif).astype(int))
    starts = np.argwhere(transitions == 1) + 1
    ends = np.argwhere(transitions == -1) + 1
    # an end before the first start closes a bout that began before frame 0
    if starts.shape[0] > 0 and ends.shape[0] > 0 and ends[0, 0] < starts[0, 0]:
        ends = ends[1:]
    # a start after the last end opens a bout running past the last frame: drop exactly
    # that one start (dropping two would shift every later start/end pairing)
    if starts.shape[0] > 0 and ends.shape[0] > 0 and starts[-1, 0] > ends[-1, 0]:
        starts = starts[:-1]
    # rises and falls alternate, so the counts now match; trim defensively anyway
    n_bouts = min(starts.shape[0], ends.shape[0])
    return starts[:n_bouts], ends[:n_bouts]


for idx, labels in enumerate(label_list):
    motif_perfile = []
    location_perfile = []
    for motif in np.arange(num_motifs):
        starts, ends = _complete_bouts(labels, motif)
        if starts.shape[0] == 0:
            motif_perfile.append(np.empty((0, 1)))
            location_perfile.append(np.empty((0, 1)))
            continue
        # guaranteed by _complete_bouts; fail loudly rather than print and continue
        assert np.all(ends > starts), f'non-positive bout duration in file {idx}, motif {motif}'
        motif_perfile.append(ends - starts)
        location_perfile.append(starts)
    motif_location.append(location_perfile)
    motif_duration.append(motif_perfile)

In [None]:
# Plot the distribution of bout durations of every motif, pooled across files, and
# record each motif's median duration

bar_list = []
median_duration = []
for motif in np.arange(num_motifs):
    # pool this motif's bout durations from every file into one column
    current_motif = np.vstack([el[motif] for el in motif_duration])
    median_duration.append(np.median(current_motif))

    # bin_edges has one more entry than frequencies; hv.Histogram takes
    # (edges, frequencies) directly, whereas hv.Bars needs equal-length columns.
    # (named bin_edges so the notebook-level `edges` window is not overwritten)
    frequencies, bin_edges = np.histogram(current_motif, 50)
    bar = hv.Histogram((bin_edges, frequencies))
    bar.opts(xrotation=45)
    bar_list.append(bar)
print(median_duration)
hv.Layout(bar_list, kdims='x')

In [None]:
# Plot the average progression of each motif
#
# For every motif, all bouts whose duration equals the (rounded) median duration are cut
# out of the aligned data and averaged frame by frame. The first and last frame of that
# average go to motif_average_list (shape (n_features, 2) per motif). An interactive
# time slider shows the average of one motif (display_motif).

# motif shown in the interactive panel (default: the last motif)
display_motif = num_motifs - 1
n_features = data_list[0].shape[0]

motif_vis = []
motif_average_list = []
# per-motif (n_frames, n_features) average traces
motif_traces = []

for motif in np.arange(num_motifs):
    current_median = median_duration[motif]
    # bout durations are integer frame counts; compare against the rounded median so an
    # even-count median (e.g. 7.5) still matches, and let a NaN median match nothing
    target_length = int(np.round(current_median)) if np.isfinite(current_median) else 0
    # per-file lists of (n_features, target_length) segments
    motif_list = []
    for idx, data in enumerate(data_list):
        durations = motif_duration[idx][motif].ravel()
        locations = motif_location[idx][motif].ravel()
        # start frame of EVERY bout whose duration equals the target length
        starts = locations[durations == target_length]
        if starts.size == 0:
            continue
        motif_list.append([data[:, start:start + target_length] for start in starts])
    motif_vis.append(motif_list)

    segments = [segment for file_segments in motif_list for segment in file_segments]
    if segments:
        # frame-by-frame mean across all matched bouts, as (target_length, n_features)
        motif_trace = np.mean(np.stack(segments, axis=0), axis=0).T
    else:
        # no matching bouts: keep the slot so indices stay aligned with the motifs
        motif_trace = np.full((1, n_features), np.nan)
    motif_traces.append(motif_trace)
    motif_average_list.append(np.stack((motif_trace[0, :], motif_trace[-1, :]), axis=1))

display_trace = motif_traces[display_motif]


def mouse_trajectory(time):
    """Curve through feature columns 0-15 (even = x, odd = y) of the displayed motif at frame `time`."""
    x = display_trace[time, [0, 2, 4, 6, 8, 10, 12, 14]]
    y = display_trace[time, [1, 3, 5, 7, 9, 11, 13, 15]]
    return hv.Curve((x, y))


def cricket_trajectory(time):
    """Curve through feature columns 16-19 (even = x, odd = y) of the displayed motif at frame `time`."""
    x = display_trace[time, [16, 18]]
    y = display_trace[time, [17, 19]]
    return hv.Curve((x, y))


mouse_map = hv.DynamicMap(mouse_trajectory, kdims=['time'])
if 'prey_capture' in vame_type:
    cricket_map = hv.DynamicMap(cricket_trajectory, kdims=['time'])
    both_map = (mouse_map*cricket_map).opts(width=600, height=400, xlim=(-40, 40), ylim=(-40, 40))
else:
    both_map = (mouse_map).opts(width=600, height=400, xlim=(-40, 40), ylim=(-40, 40))

both_panel = pn.panel(both_map.redim.range(time=(0, display_trace.shape[0]-1)),
                      center=True, widget_location='top')
both_panel

In [None]:
# Plot the first-frame and last-frame average pose of every motif, in sorted order
print(motif_average_list[0].shape)

# (motif, feature, first/last frame)
average_motifs = np.array(motif_average_list)

# feature columns drawn as x / y coordinates
index_x = [0, 2, 4, 6, 8, 10, 12, 14]
index_y = [1, 3, 5, 7, 9, 11, 13, 15]
cricket_x = [16, 18]
cricket_y = [17, 19]


def _pose_curve(motif, columns_x, columns_y, frame):
    """Curve through the given x/y feature columns of `motif` at frame slot 0 (first) or 1 (last)."""
    return hv.Curve((average_motifs[motif, columns_x, frame], average_motifs[motif, columns_y, frame]))


has_cricket = 'prey_capture' in vame_type
plot_list = []
for motif in motif_sort:
    overlay = _pose_curve(motif, index_x, index_y, 0) * _pose_curve(motif, index_x, index_y, 1)
    if has_cricket:
        overlay = overlay * _pose_curve(motif, cricket_x, cricket_y, 0) * _pose_curve(motif, cricket_x, cricket_y, 1)
    plot_list.append(overlay)

hv.Layout(plot_list, kdims='time').cols(3)

In [None]:
# Load the per-trial motif usage computed by VAME
usage_list = []
for trial in result_list:
    usage_path = os.path.join(target_folder, 'results', trial, 'VAME', 'kmeans-15',
                              'motif_usage_' + trial + '.npy')
    usage_list.append(np.load(usage_path))

In [None]:
# Count, per trial, how many label entries (presumably frames) fall on each motif,
# with motifs re-indexed into the sorted order via motif_revsort

motif_number = latent_list[0].shape[1]
usage_all = np.zeros((len(label_list), motif_number))
for idx, labels in enumerate(label_list):
    # bincount tallies every sorted motif index at once; minlength keeps absent motifs at 0
    usage_all[idx] = np.bincount(motif_revsort[labels], minlength=motif_number)

# mean and SEM across trials
average_usage = usage_all.mean(axis=0)
sem_usage = sem(usage_all, axis=0)
# plot
def motif_usage_plot(data_in, std_in, axis_limits):
    """Bar plot of per-motif usage overlaid with error bars.

    data_in: 1D usage value per motif.
    std_in: matching error values (SEM in this notebook).
    axis_limits: number of motifs; error bars sit at x = 0 .. axis_limits - 1.
    Returns the Bars * ErrorBars overlay (y range defaults to 0-150).
    """
    error_layer = hv.ErrorBars((np.arange(axis_limits), data_in, std_in))
    bar_layer = hv.Bars(data_in, kdims=['Motif'], vdims=['Fraction'])
    bar_layer.opts(width=600, height=400, ylim=(0, 150))
    return bar_layer * error_layer

# Compare motif usage between successful and failed trials (each normalised to its peak)


def _normalized_usage(tag):
    """Return (usages, mean / peak, SEM / peak) for the trials whose name contains `tag`."""
    usages = np.array([row for row, name in zip(usage_all, result_list) if tag in name])
    group_mean = np.mean(usages, axis=0)
    peak = np.max(group_mean)
    return usages, group_mean / peak, sem(usages, axis=0) / peak


succ_usages, succ_average, succ_std = _normalized_usage('succ')
succ_plot = motif_usage_plot(succ_average, succ_std, motif_number).opts(ylim=(0, 1.2))

fail_usages, fail_average, fail_std = _normalized_usage('fail')
fail_plot = motif_usage_plot(fail_average, fail_std, motif_number).opts(ylim=(0, 1.2))

img = succ_plot+fail_plot
img.opts(shared_axes=False).cols(1)
img


In [None]:
%%time

# Load the matching prey capture data
#
# One database query per trial, keyed on the trial slug (inefficient, but temporary
# until the VAME data is included in the hdf5 file).


def _analysis_path(slug):
    """Return the analysis_path of the first 'analyzed_data' entry matching `slug`."""
    matches = bd.query_database('analyzed_data', 'slug:' + slug)
    return matches[0]['analysis_path']


# the 'full_traces' table of every trial, in result_list order
beh_data = [pd.read_hdf(_analysis_path(slug), 'full_traces') for slug in result_list]
    

In [None]:
# Generate average profiles of one behavioural parameter for every motif
#
# The motif order (motif_sort / motif_revsort) comes from processing_parameters, loaded at
# the top of the notebook. To recompute a max-based ordering from this cell instead:
#     motif_sort = np.argsort(motif_max)
#     motif_revsort = np.argsort(motif_sort)

# column of each trial's full_traces table to profile
target_parameter = 'cricket_0_mouse_distance'
# one Curve*Spread overlay per motif, in motif_sort order
motif_parameter = []
# peak of each motif's average profile (NaN when the motif has no matching bouts)
motif_max = []

for motif in motif_sort:
    current_median = median_duration[motif]
    # bout durations are integer frame counts; compare against the rounded median so an
    # even-count median (e.g. 7.5) still matches, and let a NaN median match nothing
    target_length = int(np.round(current_median)) if np.isfinite(current_median) else 0
    # per-file lists of the parameter trace cut out for every matching bout
    motif_list = []
    for idx, data in enumerate(beh_data):
        if target_parameter not in data.keys():
            continue
        durations = motif_duration[idx][motif].ravel()
        locations = motif_location[idx][motif].ravel()
        # start frame of EVERY bout whose duration equals the target length
        starts = locations[durations == target_length]
        if starts.size == 0:
            continue
        # NOTE(review): start frames come from the VAME labels and are applied directly to
        # full_traces; confirm no offset is needed for trimmed trace edges
        trace = data[target_parameter].to_numpy()
        motif_list.append([trace[start:start + target_length] for start in starts])

    segments = [segment for file_segments in motif_list for segment in file_segments]
    if segments:
        stacked = np.stack(segments, axis=0)
        average_parameter = np.mean(stacked, axis=0).T
        std_parameter = sem(stacked, axis=0).T
        motif_max.append(np.max(average_parameter))
    else:
        # no matching bouts: keep the slot so motif_parameter stays aligned with motif_sort
        average_parameter = np.full(target_length, np.nan)
        std_parameter = np.full(target_length, np.nan)
        motif_max.append(np.nan)

    # mean profile with a SEM band
    x_coord = np.arange(average_parameter.shape[0])
    curve = hv.Curve((x_coord, average_parameter)).opts(width=250)
    spread = hv.Spread((x_coord, average_parameter, std_parameter))
    motif_parameter.append(curve*spread)

print(motif_sort)
print(motif_revsort)
hv.Layout(motif_parameter).cols(3)


In [None]:
# Extract one behavioural variable per trial, trimmed to a fixed frame window

# column of each trial's full_traces table to extract
target_key = 'cricket_0_delta_heading'
# one entry per trial is appended here
distance_list = list()

# frames dropped at the start / end of each trace (bins = 30; use [7, -14] for bins = 14)
edges = [15, -15]

def triangle_area(p1, p2, p3):
    """Area of the triangles with vertices p1, p2, p3 (shoelace formula).

    Each point argument is an (n, >=2) array whose columns 0 and 1 are x and y;
    returns an (n,) array of non-negative areas.
    """
    x1, y1 = p1[:, 0], p1[:, 1]
    x2, y2 = p2[:, 0], p2[:, 1]
    x3, y3 = p3[:, 0], p3[:, 1]
    twice_signed_area = x1*(y2 - y3) + x2*(y3 - y1) + x3*(y1 - y2)
    return np.abs(twice_signed_area)/2

# For every trial, take target_key over the edges window (zeros when the column is missing)
for trial_data in beh_data:
    if target_key in trial_data.keys():
        # explicit positional slice + copy: assigning into a chained slice of the column
        # triggers SettingWithCopyWarning and may silently modify beh_data itself
        temp_values = trial_data[target_key].iloc[edges[0]:edges[1]].copy()
        temp_values[np.isinf(temp_values)] = 0
    else:
        # keep one entry per trial even when the parameter is missing
        temp_values = np.zeros_like(trial_data['mouse_x'].iloc[edges[0]:edges[1]])
    distance_list.append(temp_values)

print(label_list[0].shape)
print(distance_list[0].shape)
print(beh_data[0].keys())

In [None]:
# Create per-frame categorical labels for colouring the UMAP embedding:
# 0 = trial name contains 'succ', 1 = contains 'fail', 2 = neither

# presumably the mouse identity per trial (fields 7-9 of the trial name); not used below
unique_mice, point_id = np.unique(['_'.join(el.split('_')[7:10]) for el in result_list], return_inverse=True)


def _outcome_code(trial_name):
    """Return 0 for 'succ' trials, 1 for 'fail' trials and 2 otherwise."""
    if 'succ' in trial_name:
        return 0
    if 'fail' in trial_name:
        return 1
    return 2


# One array per trial, as long as its trimmed trace. Kept as a list rather than np.array:
# trials may differ in length, np.array() on a ragged list raises in recent NumPy, and
# np.hstack (used to compile the labels for plotting) accepts the list directly.
distance_list = [np.zeros_like(beh_data[idx]['mouse_x'][edges[0]:edges[1]]) + _outcome_code(files)
                 for idx, files in enumerate(result_list)]

In [None]:
# Plot the UMAP embedding coloured by motif label (in sorted motif order)

# per-frame labels of all trials as one column, remapped through motif_revsort
frame_labels = np.hstack(label_list).T[:, np.newaxis]
print(motif_sort)
compiled_labels = motif_revsort[frame_labels]

# keep every sampling_ratio-th frame to lighten the plot
sampling_ratio = 10
subsample = slice(None, None, sampling_ratio)
umap_data = np.concatenate([embedded_data[subsample, :], compiled_labels[subsample, :]], axis=1)
print(umap_data.shape)

umap_plot = hv.Scatter(umap_data, kdims=['Dim 1'], vdims=['Dim 2', 'cluster'])
print(umap_plot)
umap_plot.opts(color='cluster', colorbar=True, cmap='Spectral', tools=['hover'])
umap_plot.opts(opts.Scatter(width=800, height=600))
umap_plot

In [None]:
# Plot the UMAP embedding coloured by the per-frame values collected in distance_list

# keep every sampling_ratio-th frame to lighten the plot
sampling_ratio = 10
compiled_labels = np.hstack(distance_list)[:, np.newaxis]

subsample = slice(None, None, sampling_ratio)
umap_data = np.concatenate([embedded_data[subsample, :], compiled_labels[subsample, :]], axis=1)
print(umap_data.shape)

umap_plot = hv.Scatter(umap_data, kdims=['Dim 1'], vdims=['Dim 2', 'parameter'])
umap_plot.opts(color='parameter', colorbar=True, cmap='Spectral', tools=['hover'])
umap_plot.opts(opts.Scatter(width=800, height=600))
umap_plot

In [None]:
# Calculate motif transition matrices for several lags
#
# Consecutive repeats of a label are collapsed, so each trial becomes a sequence of bouts.
# transition_matrices[lag, trial, post, pre] counts how often a bout of motif `post`
# came lag + 1 bouts after a bout of motif `pre` (motifs in sorted order).

# number of earlier bouts to look back
back_frames = 5
motif_number = latent_list[0].shape[1]
transition_matrices = np.zeros((back_frames, len(label_list), motif_number, motif_number))

for idx, file_labels in enumerate(label_list):
    # one entry per bout: drop consecutive repeats
    clean_label = [label for label, _ in it.groupby(file_labels)]
    for idx_frame, frames in enumerate(clean_label):
        x = motif_revsort[frames]
        for bframes in np.arange(back_frames):
            previous = idx_frame - bframes - 1
            # not enough earlier bouts yet (larger lags are out of range too); a negative
            # index would wrap to the end of the trial and count a spurious transition
            if previous < 0:
                break
            y = motif_revsort[clean_label[previous]]
            transition_matrices[bframes, idx, x, y] += 1

In [None]:
# plot average matrices

# define the plotting function
def plot_matrix(matrix_in, axis_lims, kdims=('Pre motif', 'Post motif')):
    """Plot a square motif-transition matrix as an hv.Image with one tick per motif.

    Parameters
    ----------
    matrix_in : 2D array, (axis_lims, axis_lims)
        Matrix indexed [row, column]; hv.Image draws columns along x and rows along y
        (row 0 at the top).
    axis_lims : int
        Number of motifs, i.e. ticks per axis.
    kdims : pair of str
        Names of the x (column) and y (row) axes. The transition matrices in this
        notebook are indexed [post motif, pre motif], so the columns hold the pre motif;
        the default therefore labels x 'Pre motif' and y 'Post motif'.

    Returns
    -------
    hv.Image with hover tool, one integer tick per motif and the viridis colormap.
    """
    # ticks assume the default Image bounds (-0.5, -0.5, 0.5, 0.5): cell centres sit at
    # (idx + 0.5) / n - 0.5, and y increases upwards, hence the reversed row positions
    x_labels = [((idx + 0.5)/axis_lims - 0.5, idx) for idx in range(axis_lims)]
    y_labels = [((axis_lims - idx - 0.5)/axis_lims - 0.5, idx) for idx in range(axis_lims)]
    target_matrix = hv.Image(matrix_in, kdims=list(kdims))
    target_matrix.opts(tools=['hover'], xticks=x_labels, yticks=y_labels, cmap='viridis')
    return target_matrix

# Average the transition matrices over successful and failed trials for every lag,
# normalising every lag by the peak of the first-lag (bframes == 0) average


def _mean_matrix(matrices, tag):
    """Mean of the per-trial matrices whose trial name contains `tag`."""
    return np.mean([matrix for matrix, name in zip(matrices, result_list) if tag in name], axis=0)


plot_list = []
for bframes in np.arange(back_frames):
    temp_matrix = transition_matrices[bframes, :, :, :]
    succ_matrices = _mean_matrix(temp_matrix, 'succ')
    fail_matrices = _mean_matrix(temp_matrix, 'fail')

    if bframes == 0:
        succ_norm, fail_norm = np.max(succ_matrices), np.max(fail_matrices)

    # succ / fail side by side for each lag
    plot_list.extend([plot_matrix(succ_matrices/succ_norm, motif_number),
                      plot_matrix(fail_matrices/fail_norm, motif_number)])

img = hv.Layout(plot_list).opts(shared_axes=True).cols(2)
img
