## Keypoint MoSeq Result Visualization

#### This notebook visualizes results from Keypoint MoSeq (e.g., applied to keypoints tracked with DeepLabCut)

- Make a copy of this notebook if you plan to make changes and want them saved
- Download the [example data](https://drive.google.com/drive/folders/1UNHQ_XCQEKLPPSjGspRopWBj6-YNDV6G?usp=share_link) to your drive or create a shortcut to it
- Go to "Runtime" > "Change runtime type" and select "Python 3" and "GPU"
- Before running this notebook, you should have already set up your project folder and fit the model(s) you want to visualize.

### Setup project directory
- Edit `project_dir` so that it points to the project folder where the example data lives.
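A quick sanity check on the path can save a confusing error when the config is loaded. The sketch below assumes the project folder contains a file named `config.yml` (an assumption about the project layout); `check_project_dir` is a hypothetical helper, not part of `keypoint_moseq`.

```python
from pathlib import Path


def check_project_dir(project_dir: str, config_name: str = "config.yml") -> bool:
    """Hypothetical helper: True if project_dir is a folder that contains config_name.

    The default config filename 'config.yml' is an assumption about the project layout.
    """
    path = Path(project_dir)
    return path.is_dir() and (path / config_name).is_file()
```

For example, `if not check_project_dir(project_dir): print('check project_dir')` before calling `kpm.load_config`.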

In [None]:
import keypoint_moseq as kpm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# input project folder path
project_dir = '/Users/sherry/Desktop/keypoint_moseq/test2'
# load the config file
config = kpm.load_config(project_dir)

### Load model
Load the model results as a dictionary for the analysis.
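Later cells index this dictionary as `results_dict[session][syll_key]`, i.e., first by session name and then by a key such as `'syllables_reindexed'`. The toy dictionary below only mimics that access pattern (session names and label values are made up) to show how you might check how many frames and distinct syllables each session contains.

```python
import numpy as np

# Toy stand-in for results_dict: session name -> dict of per-frame arrays.
# Only the 'syllables_reindexed' key is mimicked here; the values are made up.
toy_results = {
    "session_1": {"syllables_reindexed": np.array([0, 0, 1, 1, 1, 2, 0])},
    "session_2": {"syllables_reindexed": np.array([2, 2, 2, 1, 0])},
}


def summarize_sessions(results, syll_key="syllables_reindexed"):
    """Return {session: (n_frames, n_distinct_syllables)} for each session."""
    return {
        session: (len(result[syll_key]), len(np.unique(result[syll_key])))
        for session, result in results.items()
    }


summary = summarize_sessions(toy_results)
print(summary)
```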

In [None]:
# input model name
name = '2022_11_13-16_02_22'
# load the model results as a dictionary keyed by session name
results_dict = kpm.load_results(project_dir=project_dir, name=name)

## Compute moseq_df (scalar_df)

Kinematic values such as velocity and heading are extracted from the keypoints, while the model identifies behavioral motifs (syllables).
The steps below combine these data streams into a `DataFrame`. Learn more about `DataFrames` on the [pandas website](https://pandas.pydata.org/pandas-docs/stable/user_guide/dsintro.html).

`DataFrames` are used by the analysis and visualization cells below and are a convenient way to organize data:
you can save and export them in multiple formats to run custom analyses in the software of your choosing.
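Conceptually, `moseq_df` holds one row per frame, with kinematic columns next to the syllable label. The toy frame below uses hypothetical column names (`velocity`, `heading`, `syllable`) and made-up values, not necessarily the exact columns produced by `compute_moseq_df`, to show the kind of custom analysis this layout enables: the mean velocity while each syllable is expressed.

```python
import numpy as np
import pandas as pd

# Two per-frame data streams for one session (made-up values)
velocity = np.array([1.0, 1.2, 3.0, 3.4, 0.5, 0.7])
heading = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
syllables = np.array([0, 0, 1, 1, 2, 2])

# Combine the streams column-wise: one row per frame
toy_df = pd.DataFrame({
    "session_name": "session_1",
    "velocity": velocity,
    "heading": heading,
    "syllable": syllables,
})

# Custom analysis: average velocity while each syllable is expressed
mean_velocity = toy_df.groupby("syllable")["velocity"].mean()
print(mean_velocity)
```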

In [None]:
from keypoint_moseq.viz.util import compute_moseq_df

# compute the frame by frame dataframe for kinematic values, coordinates and syllable labels
moseq_df = compute_moseq_df(results_dict, config)
print('moseq_df shape is', moseq_df.shape)
moseq_df.head()

#### add group label
This step shows the session names and their associated uuids. Check the session names and assign a group label to each session.
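Assigning labels as a positional list depends on the row order of the session table. A sketch of a more robust alternative, assuming you know which group each session belongs to: map labels by session name with a dictionary and fail loudly if any session is missing. The session names and uuids below are hypothetical.

```python
import pandas as pd

# Hypothetical session table shaped like group_lbl_df (one row per uuid)
sessions = pd.DataFrame({
    "session_name": ["mouse1_day1", "mouse2_day1", "mouse3_day1"],
    "uuid": ["u1", "u2", "u3"],
})

# Explicit session -> group mapping (edit to match your experiment)
group_map = {"mouse1_day1": "a", "mouse2_day1": "b", "mouse3_day1": "a"}
sessions["group"] = sessions["session_name"].map(group_map)

# Sessions absent from group_map get NaN; raise instead of continuing silently
missing = sessions.loc[sessions["group"].isna(), "session_name"].tolist()
if missing:
    raise ValueError(f"No group label for sessions: {missing}")
print(sessions)
```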


In [None]:
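# one row per session; .copy() avoids pandas' SettingWithCopyWarning when a column is added below
group_lbl_df = moseq_df[['session_name', 'uuid']].drop_duplicates(['uuid']).copy()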
print('Session information below')
group_lbl_df

In [None]:
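# assign group labels: one label per row of group_lbl_df, in the same order as the sessions listed above
group_labels = ['a', 'b', 'a', 'b', 'a', 'a', 'b', 'b', 'a', 'b']
assert len(group_labels) == len(group_lbl_df), 'need exactly one group label per session'
group_lbl_df['group'] = group_labels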

In [None]:
# merge the session group information with moseq_df
group_lbl_df=group_lbl_df.drop(['session_name'], axis=1)
moseq_df = pd.merge(moseq_df, group_lbl_df, how = 'right', on='uuid')

In [None]:
print('moseq_df shape is', moseq_df.shape)
moseq_df.head()

#### export moseq_df

You can export the `moseq_df` to a csv file for further analysis using the following cell.

- Uncomment the lines below to save the file. By default, it is saved as `moseq_df.csv` in the model folder (`os.path.join(project_dir, name)`); edit the path to save it elsewhere.
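A slightly more defensive version of the export, as a sketch: create the output folder if it does not exist, skip the index column, and read the file back to confirm the round trip. The toy DataFrame and the temporary folder stand in for `moseq_df` and `os.path.join(project_dir, name)`.

```python
import os
import tempfile

import pandas as pd

# Toy stand-in for moseq_df
toy_df = pd.DataFrame({"uuid": ["u1", "u1", "u2"], "syllable": [0, 1, 1], "group": ["a", "a", "b"]})

# Stand-in for os.path.join(project_dir, name)
out_dir = os.path.join(tempfile.mkdtemp(), "model_name")
os.makedirs(out_dir, exist_ok=True)
save_path = os.path.join(out_dir, "moseq_df.csv")

# index=False keeps an unnamed index column out of the CSV
toy_df.to_csv(save_path, index=False)
reloaded = pd.read_csv(save_path)
```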

In [None]:
# import os
# moseq_df.to_csv(os.path.join(project_dir, name, 'moseq_df.csv'))

## Compute `stats_df`
`stats_df` is a `DataFrame` that contains statistical summaries (min, max, mean, std) of the scalar values associated with each syllable, as well as how frequently each syllable is expressed (usage). By default, it is computed from the features in `moseq_df`, for each session independently. This dataframe is used below to plot syllable statistics and can be used for hypothesis testing.
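As an illustration of this kind of summary (not the implementation of `compute_stats_df`), the sketch below computes the min, max, mean, and standard deviation of one scalar for each syllable within each session using a pandas `groupby`. The column names and values are made up.

```python
import pandas as pd

# Made-up per-frame data for two sessions
toy_df = pd.DataFrame({
    "uuid": ["u1"] * 4 + ["u2"] * 4,
    "syllable": [0, 0, 1, 1, 0, 1, 1, 1],
    "velocity": [1.0, 3.0, 2.0, 2.0, 4.0, 1.0, 2.0, 3.0],
})

# One row per (session, syllable) with summary statistics of the scalar
summary = (
    toy_df.groupby(["uuid", "syllable"])["velocity"]
    .agg(["min", "max", "mean", "std"])
    .reset_index()
)
print(summary)
```

A syllable seen in only one frame of a session gets `NaN` for `std`, since the sample standard deviation is undefined for a single value.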

The function can group data by whichever categories you pass to the `groupby` parameter.
In the cell below, data are grouped by "group" (the experimental cohort) and "uuid" (each unique recording session), with "session_name" included so each row is labeled with its session.
An alternative is to group by experimental cohort and mouse, e.g., `groupby = ['group', 'SubjectName']` (this requires a `SubjectName` column in `moseq_df`).

With this grouping, each row of `stats_df` contains the statistics (including usage) for one syllable in one session (uuid) of one group (experimental cohort). Changing the contents of the `groupby` variable changes what each row of `stats_df` represents.
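Syllable usage per session can be sketched the same way. The example below counts frames per syllable in each (group, uuid) pair, normalizes the counts so each session sums to 1, and then drops syllables whose pooled usage is below a threshold. Counting frames rather than syllable onsets, and applying the threshold to pooled usage, are simplifying assumptions; `compute_stats_df` may define usage and apply `threshold` differently.

```python
import pandas as pd

# Made-up per-frame labels for two sessions in two groups
toy_df = pd.DataFrame({
    "group": ["a"] * 5 + ["b"] * 5,
    "uuid": ["u1"] * 5 + ["u2"] * 5,
    "syllable": [0, 0, 0, 1, 2, 0, 1, 1, 1, 1],
})
groupby = ["group", "uuid"]
threshold = 0.15

# Fraction of frames spent in each syllable, per session (each session sums to 1)
usage = (
    toy_df.groupby(groupby)["syllable"]
    .value_counts(normalize=True)
    .rename("usage")
    .reset_index()
)

# Keep only syllables whose usage pooled over all frames exceeds the threshold
pooled = toy_df["syllable"].value_counts(normalize=True)
keep = pooled[pooled > threshold].index
usage = usage[usage["syllable"].isin(keep)]
print(usage)
```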

In [None]:
from keypoint_moseq.viz.util import compute_stats_df

In [None]:
# key to the syllable column
syll_key = 'syllables_reindexed'
# include only syllables whose usage is above this threshold
threshold = 0.005
# fps: frame rate of the recordings
stats_df = compute_stats_df(moseq_df, threshold=threshold, groupby=['group', 'uuid', 'session_name'],
                            fps=30, syll_key=syll_key, normalize=True)
print('stats_df shape is', stats_df.shape)
stats_df.head()

#### export stats_df


You can export the `stats_df` to a csv file for further analysis using the following cell.

- Uncomment the lines below to save the file. By default, it is saved as `stats_df.csv` in the model folder (`os.path.join(project_dir, name)`); edit the path to save it elsewhere.

In [None]:
# import os
# stats_df.to_csv(os.path.join(project_dir, name, 'stats_df.csv'))

## Generate Behavioral Summary (Fingerprints)
Fingerprints summarize behavior by showing the distributions of kinematic values and syllables for each session.
These plots are a useful diagnostic tool:
sessions where the mouse wasn't extracted properly, or where it moves too much or too little, stand out in these plots.

The following cell generates the summary dataframe and plots the behavioral summary. The fingerprint plot will be automatically saved as png and pdf in the `plots` folder in the model directory.

- Set `n_bins` to an integer to specify the number of bins for the kinematic values, or to `None` to make the number of bins match the number of syllables. `n_bins` is not applied to syllables.
- Set `range_type` to `'robust'` to include data from the 1st to the 99th percentile, or to `'full'` to include all the data.
- Assign an `sklearn.preprocessing` object (e.g., `MinMaxScaler()`) to `preprocessor` to scale the values within each session. If `preprocessor` is `None`, the figure shows the proportion of data falling in each bin.
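The range and binning options above can be sketched in NumPy: a `'robust'` range spans the 1st to 99th percentiles, a `'full'` range spans the min to max, and the histogram is divided by its total so each bin shows a proportion. This mirrors the description above but is not the code inside `create_fingerprint_dataframe`; `value_range` and `binned_proportions` are hypothetical helpers.

```python
import numpy as np


def value_range(values, range_type="robust"):
    """Return (low, high): 1st-99th percentile for 'robust', min-max for 'full'."""
    if range_type == "robust":
        low, high = np.percentile(values, [1, 99])
    elif range_type == "full":
        low, high = np.min(values), np.max(values)
    else:
        raise ValueError("range_type must be 'robust' or 'full'")
    return float(low), float(high)


def binned_proportions(values, low, high, n_bins=100):
    """Histogram over [low, high] normalized to proportions; values outside are not counted."""
    counts, edges = np.histogram(values, bins=n_bins, range=(low, high))
    return counts / counts.sum(), edges


rng = np.random.default_rng(0)
velocity = rng.normal(loc=2.0, scale=0.5, size=1000)  # made-up session data

low, high = value_range(velocity, "robust")
props, edges = binned_proportions(velocity, low, high, n_bins=20)
```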


In [None]:
from keypoint_moseq.viz.util import create_fingerprint_dataframe, plotting_fingerprint
from sklearn.preprocessing import MinMaxScaler, StandardScaler

stat_type = 'mean'
n_bins = 100  # resolution of distribution 
range_type = 'robust'  # robust or full
preprocessor = MinMaxScaler()

summary, range_dict = create_fingerprint_dataframe(moseq_df, stats_df, stat_type=stat_type, n_bins=n_bins, range_type=range_type)
plotting_fingerprint(summary, range_dict, preprocessor=preprocessor)

## Usage plot

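Syllable statistics describe behavioral patterns. The plot below shows how often each syllable is used in each group, with error bars showing the standard error of the mean (SEM), so you can compare syllable usage between the control and experimental groups.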
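The plotted quantity can be sketched with pandas: average each syllable's usage over sessions within a group and compute the SEM over those sessions. The toy `stats_df`-like table has made-up values, and taking the SEM across sessions is our assumption. The ordering shown (descending mean usage in the control group) is one illustrative choice; how `ordering='stat'` is defined inside the library is not shown here.

```python
import pandas as pd

# Made-up stats_df-like table: one row per (session, syllable)
toy_stats = pd.DataFrame({
    "group": ["a", "a", "a", "a", "b", "b", "b", "b"],
    "uuid": ["u1", "u1", "u2", "u2", "u3", "u3", "u4", "u4"],
    "syllable": [0, 1, 0, 1, 0, 1, 0, 1],
    "usage": [0.6, 0.4, 0.8, 0.2, 0.3, 0.7, 0.5, 0.5],
})

# Mean and SEM of usage across sessions, per group and syllable
summary = toy_stats.groupby(["group", "syllable"])["usage"].agg(["mean", "sem"])

# Illustrative ordering: syllables sorted by mean usage in the control group 'a'
order = summary.loc["a"]["mean"].sort_values(ascending=False).index.tolist()
print(summary)
```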

In [None]:
from keypoint_moseq.viz.util import plot_syll_stats_with_sem

# ordering of syllables, could be 'stat' or 'diff'
ordering='stat'
# groups to be plotted
groups = stats_df['group'].unique()
# name of the control group
ctrl_group='a'
# name of the experimental group
exp_group='b'

# whether the dots for each group are connected by lines
join = False

plot_syll_stats_with_sem(stats_df, syll_info=None, sig_sylls=None, stat='usage', ordering=ordering, max_sylls=None,
                         groups=groups, ctrl_group=ctrl_group, exp_group=exp_group, colors=None, join=join, figsize=(10, 5))

## Transition matrices
Transition matrices compactly represent how frequently each syllable transitions into every other syllable, and are one way to describe structure in behavior.
Each row of a transition matrix represents an incoming syllable and each column an outgoing syllable; the value at a given row and column is the frequency with which that incoming syllable transitions into that outgoing syllable.
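A minimal sketch of how such a matrix can be built from one label sequence. It assumes consecutive repeats of a label are collapsed first, so only transitions between different syllables are counted and the diagonal is zero. `transition_matrix` is a hypothetical helper; the normalization names follow the `normalize` options used later in this notebook, interpreted here (our assumption) as: `'bigram'` divides by all transitions, `'rows'` makes each row sum to 1, and `'columns'` makes each column sum to 1.

```python
import numpy as np


def transition_matrix(labels, n_syllables, normalize="bigram"):
    """Count transitions i -> j (row = from, column = to), then normalize."""
    labels = np.asarray(labels)
    # Collapse runs of the same label: [0, 0, 1, 1, 2] -> [0, 1, 2]
    keep = np.insert(labels[1:] != labels[:-1], 0, True)
    onsets = labels[keep]
    counts = np.zeros((n_syllables, n_syllables))
    # Add one count per consecutive pair of syllables
    np.add.at(counts, (onsets[:-1], onsets[1:]), 1)
    # Rows/columns with no transitions become NaN under 'rows'/'columns'
    with np.errstate(invalid="ignore", divide="ignore"):
        if normalize == "bigram":
            return counts / counts.sum()
        if normalize == "rows":
            return counts / counts.sum(axis=1, keepdims=True)
        if normalize == "columns":
            return counts / counts.sum(axis=0, keepdims=True)
    raise ValueError("normalize must be 'bigram', 'rows', or 'columns'")


labels = [0, 0, 1, 1, 2, 0, 0, 1]  # made-up label sequence
bigram = transition_matrix(labels, 3, "bigram")
rows = transition_matrix(labels, 3, "rows")
```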

These plots can help visualize gross changes in the structure of behavior between two experimental groups. For example, certain syllables that frequently transition into one set of syllables in one experimental condition might transition into a completely different set in another experimental condition.
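One simple way to compare two conditions (our assumption, not a step of this notebook) is to subtract the two groups' row-normalized matrices and list the transitions with the largest change. The 3×3 matrices below are made up and stand in for two entries of `trans_mats`.

```python
import numpy as np

# Made-up row-normalized transition matrices (row = from, column = to)
trans_a = np.array([[0.0, 0.9, 0.1],
                    [0.5, 0.0, 0.5],
                    [0.2, 0.8, 0.0]])
trans_b = np.array([[0.0, 0.2, 0.8],
                    [0.5, 0.0, 0.5],
                    [0.3, 0.7, 0.0]])

diff = trans_b - trans_a  # positive: transition is more frequent in group b

# The k largest absolute changes, as (from, to, change) tuples
k = 2
flat = np.argsort(np.abs(diff), axis=None)[::-1][:k]
top = [(int(i), int(j), float(diff[i, j])) for i, j in zip(*np.unravel_index(flat, diff.shape))]
print(top)
```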

In [None]:
from keypoint_moseq.viz.util import get_group_trans_mats

In [None]:
# set maximum syllable to include
max_syllables = int(stats_df.syllable.max())
print('maximum syllable to include:', max_syllables)
# select a transition matrix normalization method
normalize='bigram' # options: bigram, columns, rows

# Get modeled session uuids to compute group-mean transition graph for each group
uuid_groups = stats_df[['uuid', 'group', 'session_name']].drop_duplicates(['uuid']).to_numpy()
label_group, uuids, sessions = uuid_groups[:,1], uuid_groups[:, 0], uuid_groups[:,2]
group = sorted(set(label_group))  # sorted so the group (and panel) order is reproducible
print('Group(s):', ', '.join(group))
model_labels = [results_dict[session][syll_key] for session in sessions]
# compute transition matrices and usages for each group
trans_mats, usages = get_group_trans_mats(model_labels, label_group, group, max_sylls=max_syllables, normalize=normalize)

In [None]:
# plot the transition matrices, one panel per group
fig, ax = plt.subplots(1, len(group), figsize=(12, 15), sharex=False, sharey=True)
ax = np.atleast_1d(ax)  # keep ax indexable even when there is only one group
title_map = dict(bigram='Bigram', columns='Incoming', rows='Outgoing')

# max color threshold for graphs - set to any value 
color_lim = max([x.max() for x in trans_mats])

for i, g in enumerate(group):
    h = ax[i].imshow(trans_mats[i][:max_syllables,:max_syllables], cmap='cubehelix', vmax=color_lim)
    if i == 0:
        ax[i].set_ylabel('Incoming syllable')
        ax[i].set_yticks(np.arange(0, max_syllables, 4))
    cb = fig.colorbar(h, ax=ax[i], fraction=0.046, pad=0.04)
    cb.set_label(f'{title_map[normalize]} transition probability')
    ax[i].set_xlabel('Outgoing syllable')
    ax[i].set_title(g)
    ax[i].set_xticks(np.arange(0, max_syllables, 4))
    