In [None]:
# imports
import os
import sys
sys.path.insert(0, os.path.abspath(r'D:\Code Repos\prey_capture'))

import panel as pn
import holoviews as hv
from bokeh.io import export_svgs, export_png
from holoviews import opts, dim
from holoviews.operation import histogram
hv.extension('bokeh')
from bokeh.resources import INLINE

import importlib
import processing_parameters
import datetime

import functions_bondjango as bd
import functions_plotting as fp
import functions_loaders as fl
import pandas as pd
import numpy as np
from pprint import pprint
import paths
import random
import scipy.stats as stat
import umap
from scipy.signal import medfilt
import sklearn.cluster as clust
from scipy.cluster.hierarchy import dendrogram
from matplotlib import pyplot as plt

pd.options.mode.chained_assignment = None  # default='warn'

In [None]:
# set up the figure config

# pick up any edits made to the plotting helpers and to the parameter file
importlib.reload(fp)
importlib.reload(processing_parameters)

# printing mode and target document used when saving figures
save_mode = True
target_document = 'paper'
# folder that receives the latent-space figures
save_path = os.path.join(paths.figures_path, 'Latent_vis')

# apply the shared plotting theme
fp.set_theme()
# label lookup and list of behavioral variables used throughout the notebook
label_dict, variable_list = processing_parameters.label_dictionary, processing_parameters.variable_list

In [None]:
# Load the weights for the motifs

# order is [clusters, features]
# define the number of clusters
n_cluster = 15
# VAME results folder; os.listdir order is arbitrary, so sort it (and keep only
# directories, since each entry is used as a folder below) to make the choice of
# the "first" session deterministic
results_folder = os.path.join(paths.vame_results, 'results')
session_folders = sorted(el for el in os.listdir(results_folder)
                         if os.path.isdir(os.path.join(results_folder, el)))
if not session_folders:
    raise FileNotFoundError(f'no session folders found in {results_folder}')
first_file = session_folders[0]
template_path = os.path.join(results_folder, first_file, 'VAME', 'kmeans-' + str(n_cluster),
                             'cluster_center_' + first_file + '.npy')
# load the cluster centers
motif_weights = np.load(template_path)

# plot
plot = hv.Raster(motif_weights, kdims=['Latents', 'Motifs'])
plot.opts(width=800, height=800, tools=['hover'], colorbar=True, cmap='RdBu', clim=(-3.5, 3.5))
plot


In [None]:
# importlib.reload(processing_parameters)
# get the search string
# search_string = processing_parameters.search_string

# # get the paths from the database
# all_path = bd.query_database('analyzed_data', search_string)
# input_path = [el['analysis_path'] for el in all_path if '_preproc' in el['slug']]

# assemble the output path
# out_path = os.path.join(paths.analysis_path, 'test_latentconsolidate.hdf5')
# pprint(input_path)

In [None]:
%%time
# load the data

importlib.reload(fl)
importlib.reload(processing_parameters)
# get the paths
path_list, query_list = fl.query_search_list()

# per-query containers
full_df = []
frame_list = []
meta_list = []
# global trial counter, incremented only for the trials that are kept
trial_idx = 0
# for all the paths and queries
for path, query in zip(path_list, query_list):
    data, frames, meta = fl.load_preprocessing(path, query, behavior_flag=True)

    # trials of this query with the cell columns removed
    without_cells = []
    for trial in data:
        # TODO: fix this in the actual function
        # skip files flagged as bad and trials without latents
        if 'badFile' in trial.columns or 'latent_0' not in trial.columns:
            continue
        trial = trial.drop(columns=['sync_frames'], errors='ignore')

        # keep every column except the cell ones
        not_cells = [el for el in trial.columns if 'cell' not in el]
        # remove the frames where any latent is NaN
        latents = [el for el in not_cells if 'latent' in el]
        no_nan_vector = ~np.isnan(trial[latents].to_numpy()).any(axis=1)
        # .assign returns a new frame, so the trial number is not written into a slice
        no_cells_trial = trial.loc[no_nan_vector, not_cells].assign(trial_idx=trial_idx)
        trial_idx += 1
        without_cells.append(no_cells_trial)

    # pd.concat raises on an empty list, so skip queries where every trial was excluded
    if without_cells:
        full_df.append(pd.concat(without_cells, axis=0))
    else:
        print(f'Skipped a query with {len(data)} trials: none had usable latents')
    frame_list.append(frames)
    meta_list.append(meta)

full_df = pd.concat(full_df, axis=0)
print(full_df.shape)


In [None]:
# list the columns of the consolidated data frame
print(str(full_df.columns))

In [None]:
# Correlate the latents to the other behavioral variables

# latent columns of the consolidated frame
latents = [column for column in full_df.columns if 'latent' in column]

# one Spearman matrix per trial over the stacked [latents, behaviors] columns;
# undefined correlations (NaN) are counted as 0 before averaging
per_trial_matrices = []
for _, trial in full_df.groupby('trial_idx'):
    rho = stat.spearmanr(trial[latents].to_numpy(), trial[variable_list].to_numpy())[0]
    rho[np.isnan(rho)] = 0
    per_trial_matrices.append(rho)

# average across trials
average_matrix = np.mean(per_trial_matrices, axis=0)
print(average_matrix.shape)

In [None]:
# plot the average latent-behavior correlation matrix

# tick labels: 'latent_3' -> 'Latent 3' on y, behavior labels from the label dictionary on x
yticks = [(idx + 0.5, el.replace('_', ' ').replace('l', 'L')) for idx, el in enumerate(latents)]
xticks = [(idx + 0.5, label_dict[el]) for idx, el in enumerate(variable_list)]
# spearmanr on the stacked inputs gives the full (latents + behaviors) square matrix;
# keep only the latent-rows x behavior-columns block
plot_matrix = average_matrix[:len(latents), len(latents):].copy()

raster = hv.Raster(plot_matrix)
raster.opts(tools=['hover'])
# format the plot
raster.opts(width=950, height=800, yticks=yticks, xticks=xticks, colorbar=True, cmap='RdBu', clim=(-1, 1),
            xrotation=45, xlabel='Behavior', ylabel='Latents')

# save only when the printing mode is on, creating the target folder if needed
if save_mode:
    os.makedirs(save_path, exist_ok=True)
    save_name = os.path.join(save_path, '_'.join((target_document, 'overall_correlation')) + '.png')
    fig = fp.save_figure(raster, save_name, fig_width=15, dpi=1200, fontsize=target_document, target='screen')

In [None]:
def normalize_rows(data_in):
    """Min-max normalize each row of a 2D array to the [0, 1] range, in place.

    NaNs are ignored when computing each row's min/max and stay NaN in the output.
    Constant rows (max == min) are mapped to 0 instead of producing 0/0 = NaN.

    Parameters
    ----------
    data_in : numpy.ndarray
        2D array. It is modified in place, so it should have a floating dtype.

    Returns
    -------
    numpy.ndarray
        The same array object, with every row rescaled.
    """
    for idx, row in enumerate(data_in):
        row_min = np.nanmin(row)
        row_range = np.nanmax(row) - row_min
        if row_range == 0:
            # constant row: subtracting the min gives 0 and keeps NaN entries as NaN
            data_in[idx, :] = row - row_min
        else:
            data_in[idx, :] = (row - row_min) / row_range
    return data_in

In [None]:
# list the available columns before picking the variables to visualize
print(str(full_df.columns))

In [None]:
# visualize latents
# trial to display
target_trial = 11

# select the rows of the target trial
trial_mask = full_df['trial_idx'].to_numpy() == target_trial
current_trial = full_df.iloc[trial_mask, :]
print(current_trial.shape)
print(current_trial.columns[:50])

# time axis of the image
x = current_trial['time_vector']

# behavioral variables and motifs, followed by every latent column
target_columns = ['cricket_0_mouse_distance', 'cricket_0_delta_heading',
                  'mouse_speed', 'cricket_0_speed', 'mouse_heading', 'mouse_x', 'mouse_y',
                  'cricket_0_x', 'cricket_0_y', 'motifs']
target_columns += [column for column in current_trial.columns if 'latent' in column]
current_trial = current_trial[target_columns]
# one row position per displayed variable, labelled with its name
y = np.arange(len(current_trial.columns))
pprint(current_trial.columns)
y_labels = list(enumerate(target_columns))

# variables as rows, NaNs zeroed, then each row scaled to [0, 1]
current_trial = current_trial.to_numpy().T
current_trial[np.isnan(current_trial)] = 0
current_trial = normalize_rows(current_trial)

# variables over time
trial_plot = hv.Image((x, y, current_trial))
trial_plot.opts(frame_width=600, frame_height=600, tools=['hover'], colorbar=True, yticks=y_labels)

# pairwise correlation between the displayed variables
latent_correlation = np.corrcoef(current_trial)
correlation_plot = hv.Image((y, y, latent_correlation))
correlation_plot.opts(frame_width=600, frame_height=600, tools=['hover'], colorbar=True,
                      yticks=y_labels, xticks=y_labels, xrotation=45, cmap='RdBu', clim=(-1, 1))

(trial_plot + correlation_plot).opts(shared_axes=False).cols(1)

In [None]:
# generate 2D TCs with latents as x and y and a behavioral variable as the Z

# define the target behavioral variable
target_behavior = 'mouse_speed'
# define the latents of interest
first_latent = 'latent_0'
second_latent = 'latent_1'

# number of bins per latent axis
bins = 10
# number of trials to plot (None plots every trial)
n_plot_trials = 1
# allocate a list for the plots
plot_list = []
# get the list of trials
trial_list = np.unique(full_df['trial_idx'].to_numpy())
for trial_idx in trial_list[:n_plot_trials]:
    trial = full_df.iloc[full_df['trial_idx'].to_numpy() == trial_idx, :]
    # skip if the behavior is not there
    if target_behavior not in trial.columns:
        continue
    # get the variables of interest
    feature_0 = trial[first_latent].to_numpy()
    feature_1 = trial[second_latent].to_numpy()
    behavior = trial[target_behavior].to_numpy()
    # drop frames with a non-finite value; zero-filling them would add spurious samples
    # at latent (0, 0) with behavior 0 and bias the bin means
    valid = np.isfinite(feature_0) & np.isfinite(feature_1) & np.isfinite(behavior)
    if not valid.any():
        continue

    # mean behavior in each 2D latent bin
    current_tc, x_edge, y_edge, bin_number = stat.binned_statistic_2d(
        feature_0[valid], feature_1[valid], behavior[valid], statistic='mean', bins=bins)

    # hv.Image takes bin centers and a (y, x)-shaped array, while binned_statistic_2d
    # returns bin edges and an (x, y)-shaped statistic
    x_center = (x_edge[:-1] + x_edge[1:]) / 2
    y_center = (y_edge[:-1] + y_edge[1:]) / 2
    plot = hv.Image((x_center, y_center, current_tc.T), kdims=[first_latent, second_latent])
    plot.opts(tools=['hover'], cmap='Spectral')
    plot_list.append(plot)

hv.Layout(plot_list).cols(5)
    
    

In [None]:
# generate a single map for all trials

# define the target behavioral variable
target_behavior = 'mouse_speed'
# define the latents of interest
first_latent = 'latent_0'
second_latent = 'latent_1'
# number of bins per latent axis
bins = 10

# all trials share the columns of full_df, so pool every frame directly instead of
# looping over trials and concatenating
feature_0 = full_df[first_latent].to_numpy()
feature_1 = full_df[second_latent].to_numpy()
behavior = full_df[target_behavior].to_numpy()
# drop frames with a non-finite value; zero-filling them would add spurious samples
# at latent (0, 0) with behavior 0 and bias the bin means
valid = np.isfinite(feature_0) & np.isfinite(feature_1) & np.isfinite(behavior)
# mean behavior in each 2D latent bin
current_tc, x_edge, y_edge, bin_number = stat.binned_statistic_2d(
    feature_0[valid], feature_1[valid], behavior[valid], statistic='mean', bins=bins)

# hv.Image takes bin centers and a (y, x)-shaped array, while binned_statistic_2d
# returns bin edges and an (x, y)-shaped statistic
x_center = (x_edge[:-1] + x_edge[1:]) / 2
y_center = (y_edge[:-1] + y_edge[1:]) / 2
plot = hv.Image((x_center, y_center, current_tc.T), kdims=[first_latent, second_latent])
plot.opts(tools=['hover'], cmap='Spectral')

In [None]:
def separate_day_time(input_string):
    """Return the date part of a 'date time' string, i.e. everything before the first space."""
    day, *_ = input_string.split(' ')
    return day

In [None]:
# generate averages of each latent over time

# last day (relative to each mouse's first recording day) to include
max_day = 10

# get a list of the latents
latent_list = [el for el in full_df.columns if 'latent' in el]

# allocate a list for the latent plots
latent_plots = []
for latent in latent_list:
    # per-mouse daily means, one data frame per mouse
    overlay_list = []
    # a plain string key makes pandas yield scalar (not tuple) group names
    for _, mouse_data in full_df.groupby('mouse'):
        # get the latent variable along with the date
        current_variable = mouse_data.loc[:, [latent, 'datetime']].copy()
        # keep only the first 10 characters ('YYYY-MM-DD') of the datetime string
        current_variable['datetime'] = [el[:10] for el in current_variable['datetime']]
        # average the latent within each day
        current_mean = current_variable.groupby('datetime', as_index=False).mean()
        # express each day as the number of days since this mouse's first day
        day_data = pd.to_datetime(current_mean['datetime'], format='%Y-%m-%d')
        current_mean['datetime'] = (day_data - day_data.min()).dt.days.to_numpy()
        # keep only the first max_day days
        current_mean = current_mean.loc[current_mean['datetime'].to_numpy() <= max_day, :].copy()
        # subtract the day-0 value so every mouse starts at 0
        day_zero = current_mean['datetime'].to_numpy() == 0
        current_mean[latent] -= current_mean.loc[day_zero, latent].to_numpy()
        overlay_list.append(current_mean)

    # mean and standard deviation across mice for each day
    plot_df = pd.concat(overlay_list, axis=0)
    mouse_mean = plot_df.groupby('datetime', as_index=False).mean()
    mouse_error = plot_df.groupby('datetime', as_index=False).std()

    curve = hv.Curve(mouse_mean, vdims=[latent], kdims=['datetime'])
    curve.opts(width=600, xrotation=45, tools=['hover'], xlabel='Day')
    spread = hv.Spread((mouse_mean['datetime'], mouse_mean[latent], mouse_error[latent]),
                       vdims=[latent, 'Error'], kdims=['datetime'])
    latent_plots.append(curve * spread)

# turn into a layout and plot
latent_layout = hv.Layout(latent_plots).cols(3)
latent_layout


In [None]:
# group by recording datetime and generate the TC

# define the target behavioral variable
target_behavior = 'cricket_0_mouse_distance'
# define the latents of interest
first_latent = 'latent_0'
second_latent = 'latent_1'
# bins per latent axis, defined here so this cell does not depend on an earlier cell's `bins`
bins = 10

# allocate a plot list
plot_list = []
# NOTE(review): 'datetime' appears to hold the full date-time string (elsewhere it is cut to
# its first 10 characters to get the day), so each group here is likely one recording session
# rather than one calendar day; confirm which grouping is intended
for date, current_data in full_df.groupby('datetime'):
    # get the variables of interest
    feature_0 = current_data[first_latent].to_numpy()
    feature_1 = current_data[second_latent].to_numpy()
    behavior = current_data[target_behavior].to_numpy()
    # drop frames with a non-finite value; zero-filling them would add spurious samples
    # at latent (0, 0) with behavior 0 and bias the bin means
    valid = np.isfinite(feature_0) & np.isfinite(feature_1) & np.isfinite(behavior)
    if not valid.any():
        continue

    # mean behavior in each 2D latent bin
    current_tc, x_edge, y_edge, bin_number = stat.binned_statistic_2d(
        feature_0[valid], feature_1[valid], behavior[valid], statistic='mean', bins=bins)

    # hv.Image takes bin centers and a (y, x)-shaped array, while binned_statistic_2d
    # returns bin edges and an (x, y)-shaped statistic
    x_center = (x_edge[:-1] + x_edge[1:]) / 2
    y_center = (y_edge[:-1] + y_edge[1:]) / 2
    im = hv.Image((x_center, y_center, current_tc.T), kdims=[first_latent, second_latent])
    im.opts(tools=['hover'], cmap='Spectral', title=date)
    plot_list.append(im)

hv.Layout(plot_list).cols(5)

In [None]:
%%time
# UMAP embedding of the VAME data

# latent columns of every frame in the consolidated frame
compiled_latent = full_df.loc[:, [el for el in full_df.columns if 'latent' in el]].to_numpy()

# parameter notes (min_dist, n_neighbors): original 0.5 / 10; 0.1 / 30 and 0.05 / 30 also work
# fixed seed so the embedding, and every figure built on it, is reproducible;
# note that UMAP runs single-threaded when random_state is set
umap_seed = 0
reducer = umap.UMAP(min_dist=0.1, n_neighbors=20, random_state=umap_seed)
embedded_data = reducer.fit_transform(compiled_latent)

In [None]:
# plot the UMAP results

# variable used to color the embedding
label_column = 'cricket_0_delta_heading'
# median-filter length in frames (must be odd)
filter_kernel = 21
# plot every n-th frame
sampling_ratio = 10

label_values = full_df[label_column].to_numpy().astype(float)
# median-filter each trial separately so the window never mixes frames of different trials;
# groupby(...).indices gives the positional rows of each trial
filtered_labels = np.empty_like(label_values)
for trial_rows in full_df.groupby('trial_idx').indices.values():
    filtered_labels[trial_rows] = medfilt(label_values[trial_rows], kernel_size=filter_kernel)
compiled_labels = filtered_labels[:, np.newaxis]

umap_data = np.concatenate((embedded_data[::sampling_ratio, :], compiled_labels[::sampling_ratio, :]), axis=1)
print(umap_data.shape)

umap_plot = hv.Scatter(umap_data, vdims=['Dim 2', 'parameter'], kdims=['Dim 1'])
umap_plot.opts(color='parameter', colorbar=True, cmap='Spectral', tools=['hover'], alpha=0.5)
umap_plot.opts(width=1200, height=1000, size=5)
umap_plot


In [None]:
# add a trajectory on top of the map

# define the target trial
target_trial = 30

# rows of the embedding that belong to the target trial
trial_rows = full_df['trial_idx'].to_numpy() == target_trial
if not trial_rows.any():
    raise ValueError(f'trial {target_trial} is not present in full_df')
traj_data = embedded_data[trial_rows, :]
trajectory = hv.Curve((traj_data[:, 0], traj_data[:, 1]))
# first frame in red, last frame in blue
start = hv.Scatter((traj_data[0, 0], traj_data[0, 1])).opts(color='Red', size=10)
end = hv.Scatter((traj_data[-1, 0], traj_data[-1, 1])).opts(color='Blue', size=10)

umap_plot*trajectory*start*end

In [None]:
# plot aggregated trajectories in UMAP space

# every 20th trial, drawing every 10th frame of each trajectory
trial_list = np.unique(full_df.loc[:, 'trial_idx'])
trial_plots = []
for trial in trial_list[::20]:
    traj_data = embedded_data[full_df.loc[:, 'trial_idx'] == trial, :]
    curve = hv.Curve((traj_data[::10, 0], traj_data[::10, 1])).opts(width=1200, height=1000, alpha=0.2)
    # first frame in red, last frame in blue
    start = hv.Scatter((traj_data[0, 0], traj_data[0, 1])).opts(color='Red', size=10, width=1200, height=1000)
    end = hv.Scatter((traj_data[-1, 0], traj_data[-1, 1])).opts(color='Blue', size=10, width=1200, height=1000)
    trial_plots.append(curve * start * end)

# overlay every trajectory on the embedding, cycling curve colors through the palette
overlay = umap_plot * hv.Overlay(trial_plots)
overlay.opts({'Curve': dict(color=hv.Palette('Spectral'))})
overlay

In [None]:
# list the latent column names
print(str(latent_list))

In [None]:
# calculate transition matrices


def count_motif_transitions(motif_sequence, motif_number, back_frames=5):
    """Count motif-to-motif transitions at several lags.

    Consecutive repeats are collapsed first, so only changes of motif are counted.

    Parameters
    ----------
    motif_sequence : iterable of int
        Motif label per frame, with values in [0, motif_number).
    motif_number : int
        Number of motifs (size of each square matrix).
    back_frames : int
        Number of lags to count.

    Returns
    -------
    numpy.ndarray
        Array of shape (back_frames, motif_number, motif_number). Entry [lag - 1, x, y]
        counts how often motif x was preceded `lag` steps earlier (after collapsing
        repeats) by motif y.
    """
    import itertools

    clean_label = [key for key, _ in itertools.groupby(motif_sequence)]
    transitions = np.zeros((back_frames, motif_number, motif_number))
    for idx_frame, current_motif in enumerate(clean_label):
        # lags that would reach before the first label are skipped; a negative index
        # would otherwise silently wrap around to the end of the sequence
        for lag in range(1, min(back_frames, idx_frame) + 1):
            transitions[lag - 1, current_motif, clean_label[idx_frame - lag]] += 1
    return transitions


print(full_df.columns, full_df.shape)

# number of lags to count
back_frames = 5
# motif labels come from the k-means clustering with n_cluster clusters loaded above
motif_number = n_cluster
# transition counts per (mouse, datetime) group, shape (back_frames, trials, motifs, motifs)
transition_dict = {}
for (mouse, day), data in full_df.groupby(['mouse', 'datetime']):
    # NOTE(review): the previous version of this cell also reordered motifs through
    # `motif_revsort`, which is not defined in this notebook; raw motif ids are used here
    per_trial = []
    for _, trial_data in data.groupby('trial_idx'):
        motif_sequence = trial_data['motifs'].dropna().to_numpy().astype(int)
        per_trial.append(count_motif_transitions(motif_sequence, motif_number, back_frames))
    transition_matrices = np.stack(per_trial, axis=1)
    transition_dict[(mouse, day)] = transition_matrices

In [None]:
# plot the clustering result

def plot_dendrogram(model, **kwargs):
    """Plot a dendrogram for a fitted agglomerative clustering model.

    Converts the model's merge tree into a SciPy linkage matrix and passes it to
    scipy.cluster.hierarchy.dendrogram.

    Parameters
    ----------
    model : fitted scikit-learn AgglomerativeClustering (or any object with the same attributes)
        Must expose `children_`, `distances_` and `labels_`. Per the scikit-learn docs,
        `distances_` is only available when fit with `distance_threshold` set or
        `compute_distances=True`.
    **kwargs
        Forwarded to `dendrogram` (e.g. truncate_mode, p, no_plot).

    Returns
    -------
    dict
        The dictionary returned by `dendrogram`; its 'leaves' / 'ivl' entries give the
        leaf order, which callers can reuse to reorder motifs.
    """
    n_samples = len(model.labels_)
    # number of original samples under each merged node; child ids >= n_samples refer
    # to earlier merges, whose sizes are already in `counts`
    counts = np.zeros(model.children_.shape[0])
    for merge_idx, merge in enumerate(model.children_):
        counts[merge_idx] = sum(1 if child_idx < n_samples else counts[child_idx - n_samples]
                                for child_idx in merge)

    # linkage rows are [child_a, child_b, distance, size]
    linkage_matrix = np.column_stack(
        [model.children_, model.distances_, counts]
    ).astype(float)

    # Plot the corresponding dendrogram
    return dendrogram(linkage_matrix, **kwargs)

# NOTE(review): `cluster_element` was not defined anywhere in this notebook, so this cell
# raised NameError on a fresh kernel. Assumption to confirm: build the full merge tree over
# the motif cluster centers loaded above (distance_threshold=0 with n_clusters=None makes
# scikit-learn store the merge distances that plot_dendrogram needs).
cluster_element = clust.AgglomerativeClustering(distance_threshold=0, n_clusters=None).fit(motif_weights)
plot_dendrogram(cluster_element, truncate_mode="level", p=3);

