In [None]:
# imports
import os
import sys
sys.path.insert(0, os.path.abspath(r'D:\Code Repos\prey_capture'))

import numpy as np
import pandas as pd
import importlib
import sklearn.cluster as cluster
import scipy.stats as stat
import itertools as it

import processing_parameters
import paths
import functions_loaders as fl
import functions_plotting as fp

import panel as pn
import holoviews as hv
from bokeh.io import export_svgs, export_png
from holoviews import opts, dim
from holoviews.operation import histogram
hv.extension('bokeh')


In [None]:
# Figure configuration for this notebook.
# Reload the local modules first so that the names read below come from
# their current on-disk versions.
importlib.reload(fp)
importlib.reload(processing_parameters)

# apply the project plotting theme
fp.set_theme()

# label dictionary and variable list from the reloaded parameters module
label_dict, variable_list = (
    processing_parameters.label_dictionary,
    processing_parameters.variable_list,
)

# output settings: printing mode flag, target document and saving path
save_mode, target_document = True, 'paper'
save_path = os.path.join(paths.figures_path, 'Motif_vis')

In [None]:
%%time
# load the data
# Builds `full_df`: one row per sample, all non-cell columns of every usable
# trial, restricted to rows with complete latents and tagged with a running
# `trial_idx`. `frame_list` and `meta_list` keep the loader's extra outputs
# for every query that contributed at least one trial.

importlib.reload(fl)
importlib.reload(processing_parameters)
# get the paths
path_list, query_list = fl.query_search_list()

# per-query containers
per_query_dfs = []
frame_list = []
meta_list = []
# initialize a trial counter (keeps counting across queries so ids are unique)
trial_idx = 0
# for all the paths and queries
for path, query in zip(path_list, query_list):
#     # get rid of DG_210323_b for now
#     path = [el for el in path if 'DG_210323_b' not in el]
#     query = [el for el in query if 'DG_210323_b' not in el['analysis_path']]
    data, frames, meta = fl.load_preprocessing(path, query, behavior_flag=True)

    # exclude the cells
    without_cells = []
    # for all the trials in data
    for trial in data:
        # TODO: fix this in the actual function
        if 'badFile' in trial.columns:
            continue
        if 'sync_frames' in trial.columns:
            trial = trial.drop(columns=['sync_frames'])
        # trials without latents cannot be used downstream
        if 'latent_0' not in trial.columns:
            continue

        not_cells = [el for el in trial.columns if 'cell' not in el]
        no_cells_trial = trial[not_cells]
        # remove the nan rows in the latents (a NaN in any latent makes the row sum NaN)
        latents = [el for el in no_cells_trial.columns if 'latent' in el]
        no_nan_vector = ~np.isnan(np.sum(no_cells_trial[latents].to_numpy(), axis=1))
        # explicit copy so the column assignment below writes to an independent
        # frame instead of a slice (avoids SettingWithCopyWarning)
        no_cells_trial = no_cells_trial.iloc[no_nan_vector, :].copy()
        # save the trial number
        no_cells_trial['trial_idx'] = trial_idx
        # update the trial counter
        trial_idx += 1
        without_cells.append(no_cells_trial)

    # pd.concat raises on an empty list, so skip queries without usable trials;
    # frames/meta are skipped too so the three lists stay aligned
    if len(without_cells) == 0:
        print('Skipping a query without usable trials')
        continue

    per_query_dfs.append(pd.concat(without_cells, axis=0))
    frame_list.append(frames)
    meta_list.append(meta)

full_df = pd.concat(per_query_dfs, axis=0)
print(full_df.shape)


In [None]:
# calculate average value of each variable during each motif

# exclude nans (any row with a missing value is dropped)
nonan_df = full_df.dropna(axis=0).copy()

# nonan_df['motifs'] = motif_revsort[nonan_df['motifs'].to_numpy().astype(int)]

# compute the average value grouping by motif
latents = [el for el in full_df.columns if 'latent' in el]
# numeric_only=True: non-numeric columns (e.g. 'mouse') cannot be averaged and
# make .mean() raise on recent pandas versions, so they are left out explicitly
average_per_motif = (
    nonan_df
    .drop(columns=['datetime'] + latents)
    .groupby('motifs')
    .mean(numeric_only=True)
)

print(average_per_motif)

In [None]:
# plot the motif averages

# min-max scale every variable (column of average_per_motif) across motifs
# NOTE: a variable that is constant across motifs yields 0/0 (NaN) here
# alternative normalizations (disabled):
# plot_matrix = (plot_matrix/plot_matrix.sum(axis=0))
# plot_matrix = (plot_matrix - plot_matrix.mean(axis=0))/plot_matrix.std(axis=0)
plot_matrix = average_per_motif.to_numpy().copy()
column_min = plot_matrix.min(axis=0)
column_max = plot_matrix.max(axis=0)
# transpose so that rows are variables and columns are motifs
plot_matrix = ((plot_matrix - column_min)/(column_max - column_min)).T

# cluster the variables and reorder the rows by cluster label
row_idx = cluster.AgglomerativeClustering(n_clusters=10).fit_predict(plot_matrix)
print(row_idx)
row_order = np.argsort(row_idx)
plot_matrix = plot_matrix[row_order, :]

print(plot_matrix.sum(axis=1))

# get the labels (offset by 0.5, presumably to land on the cell centers)
ticks = [(position + 0.5, name) for position, name in enumerate(average_per_motif.columns[row_order])]
xticks = [(motif + 0.5, motif) for motif in np.arange(plot_matrix.shape[1])]

raster_options = dict(
    width=1200, height=1200, cmap='Spectral',
    xlabel='Motifs', xticks=xticks,
    ylabel='', yticks=ticks,
    tools=['hover'], colorbar=True,
)
plot = hv.Raster(plot_matrix)
plot.opts(**raster_options)
plot

In [None]:
# Spearman correlation between the rows of plot_matrix (passed transposed,
# since spearmanr treats columns as variables); first element of the result
# is the correlation matrix
spearman_result = stat.spearmanr(plot_matrix.T)
corr_matrix = spearman_result[0]

# same tick labels on both axes
correlation_options = dict(
    width=1300, height=1200, cmap='RdBu',
    xticks=ticks, yticks=ticks,
    xlabel='', ylabel='', xrotation=45,
)
plot = hv.Raster(corr_matrix)
plot.opts(**correlation_options)
plot


In [None]:
# define the motif sorting
# positions (not labels) of the motifs ordered by their average 'time_vector'
motif_sort = np.argsort(average_per_motif['time_vector'].to_numpy())

# motif_revsort = np.array([np.argwhere(motif_sort==el)[0][0] for el in np.arange(motif_sort.max()+1)])
# identity mapping for now: motifs keep their original ids downstream
motif_revsort = np.arange(motif_sort.max() + 1)
print(motif_revsort)

# motif_sort holds positional indices from argsort, so index by position
# (.iloc); label-based .loc is only equivalent when the motif labels happen
# to be exactly 0..n-1
print(average_per_motif['time_vector'].iloc[motif_sort])

In [None]:
# calculate motif usage
# Result: `motif_counts`, a dataframe with one row per motif id and one column
# per mouse, holding the fraction of that mouse's rows assigned to the motif.

# get a list of the motifs present
motif_list = np.unique(nonan_df['motifs'].fillna(0).to_numpy())
# every histogram gets the same number of bins, even for a mouse that never
# visits the highest motif ids (otherwise the per-mouse arrays are ragged)
usage_bins = int(motif_revsort.max()) + 1
# allocate the per-mouse containers
per_mouse_counts = []
mouse_list = []
# for all mice (scalar key, so the group name is the mouse itself, not a 1-tuple)
for mouse_name, mouse_data in nonan_df.groupby('mouse'):

    # get the motifs
    current_motifs = mouse_data['motifs'].to_numpy()
    # remove the nans and map through the motif sorting
    current_motifs = motif_revsort[current_motifs[~np.isnan(current_motifs)].astype(int)]
    # histogram them, normalized to fractions
    counts = np.bincount(current_motifs, minlength=usage_bins)/current_motifs.shape[0]
    # store in a list
    per_mouse_counts.append(counts)
    mouse_list.append(mouse_name)

# create the output dataframe
motif_counts = pd.DataFrame(np.array(per_mouse_counts).T, columns=mouse_list)
# motif_counts['mouse'] = mouse_list

# print(motif_counts)

In [None]:
# plot motif usage: one scatter of the per-motif fractions for every mouse,
# all drawn on top of each other

scatter_options = dict(width=800, xlabel='Motif', ylabel='Fraction')
plot_list = []
for mouse in mouse_list:
    plot = hv.Scatter(motif_counts[mouse])
    plot.opts(**scatter_options)
    plot_list += [plot]
hv.Overlay(plot_list)

In [None]:
# plot motif average usage fraction 

In [None]:
# quick inspection: dimensions and column names of the NaN-free dataframe
print(nonan_df.shape, nonan_df.columns)

In [None]:
# %%time
# calculate transition matrices
# Result: `transition_matrices`, a dataframe with one row per (trial, back_frame)
# holding the trial id, mouse, datetime, back_frame and the flattened
# (row-major) motif_number x motif_number count matrix in columns m_00, m_01, ...


def _count_lagged_transitions(sequence, lag, n_states):
    """Count how often state sequence[i] is preceded by sequence[i - lag].

    Parameters
    ----------
    sequence : 1D integer numpy array of state (motif) indices.
    lag : number of positions to look back (>= 1).
    n_states : side length of the square output matrix.

    Returns
    -------
    (n_states, n_states) float array; entry [x, y] is the number of positions
    i >= lag with sequence[i] == x and sequence[i - lag] == y. All zeros when
    the sequence is not longer than the lag.
    """
    counts = np.zeros((n_states, n_states))
    if lag < sequence.shape[0]:
        # only pair positions that really have a predecessor `lag` steps back;
        # a negative index would silently wrap around to the end of the trial.
        # np.add.at accumulates repeated (x, y) pairs instead of counting them once
        np.add.at(counts, (sequence[lag:], sequence[:-lag]), 1)
    return counts


# define the number of frames back to look
back_frames = 10
# get the number of motif
#     motif_number = latent_list[0].shape[1]
motif_number = len(motif_list)
# get a list of the trials
trial_ids = nonan_df['trial_idx'].to_numpy()
trial_list = np.unique(trial_ids)
# collect one row per (trial, back_frame)
transition_rows = []
# for all the trials
for trial in trial_list:
    # get the rows of the current trial
    current_trial = nonan_df.iloc[trial_ids == trial, :].reset_index(drop=True)
    mouse = current_trial.loc[0, 'mouse']
    trial_datetime = current_trial.loc[0, 'datetime']
    # remove the repeats: collapse runs of identical consecutive motifs
    motif_sequence = motif_revsort[[el[0] for el in it.groupby(current_trial['motifs'].to_numpy().astype(int))]]

    # for all the back frames (back_frame b compares each entry with the one b + 1 steps earlier)
    for bframes in np.arange(back_frames):
        current_transitions = _count_lagged_transitions(motif_sequence, bframes + 1, motif_number)
        # store on the list along with the mouse, trial id and date
        transition_rows.append([trial, mouse, trial_datetime, bframes] + current_transitions.flatten().tolist())
# turn into a dataframe
matrix_columns = [f'm_{el:02d}' for el in np.arange(motif_number**2)]
transition_matrices = pd.DataFrame(transition_rows, columns=['trial_idx', 'mouse', 'datetime', 'back_frame']+matrix_columns)

In [None]:
# plot the average transition matrix at each time point

# allocate a plot list
plot_list = []

# average across trials for every back_frame; numeric_only=True because the
# non-numeric 'mouse' column (and the dates) cannot be averaged and make
# .mean() raise on recent pandas versions
average_matrices = transition_matrices.groupby('back_frame', as_index=False).mean(numeric_only=True)
# the flattened matrix entries live in the m_XX columns
matrix_entry_columns = [el for el in average_matrices.columns if 'm_' in el]

ticks = [(el+0.5, str(el)) for el in np.arange(motif_number)]

# for all the frames
for bframe in np.arange(back_frames):
    # current matrix
    current_matrix = average_matrices.loc[average_matrices['back_frame'].to_numpy() == bframe, matrix_entry_columns].to_numpy()
    # the counts were flattened row-major, so the column-major reshape yields
    # the transpose of the stored matrix
    current_matrix = current_matrix.reshape((motif_number, motif_number), order='F')
    plot = hv.Raster(current_matrix)
    plot.opts(width=800, height=800, cmap='Viridis', tools=['hover'], xticks=ticks, yticks=ticks, ylabel='Current', xlabel='Next')

    plot_list.append(plot)

layout = hv.Layout(plot_list).cols(2)
layout
