This notebook shows how to set up a new project, train a keypoint-MoSeq model, and visualize the resulting syllables. You can load keypoint tracking results from SLEAP or DeepLabCut, or supply them in your own custom format. We provide an [example DeepLabCut dataset](https://drive.google.com/drive/folders/1UNHQ_XCQEKLPPSjGspRopWBj6-YNDV6G?usp=share_link) that can be used for the tutorial.

# Colab setup

- Make a copy of this notebook if you plan to make changes and want them saved
- To use the [example data](https://drive.google.com/drive/folders/1UNHQ_XCQEKLPPSjGspRopWBj6-YNDV6G?usp=share_link), download it to your Google Drive or create a shortcut to it
- To use your own data, upload it to Google Drive
- Go to "Runtime" > "Change runtime type" and select "Python 3" and "GPU"

### Install keypoint-MoSeq and mount your Google Drive

In [None]:
! pip install --upgrade "jax[cuda]==0.3.22" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
! pip install keypoint-moseq

from google.colab import drive
drive.mount('/content/drive')

# Project setup
Create a new project directory with a keypoint-MoSeq `config.yml` file.

In [None]:
import keypoint_moseq as kpms

project_dir = '/content/drive/MyDrive/demo_project/'
config = lambda: kpms.load_config(project_dir)

### Option 1: Setup from DeepLabCut using example data

In [None]:
dlc_config = '/content/drive/MyDrive/dlc_project/config.yaml'
kpms.setup_project(project_dir, deeplabcut_config=dlc_config)

### Option 2: Setup from SLEAP

In [None]:
sleap_file = 'XXX' # choose a .h5 file for one of your recordings
kpms.setup_project(project_dir, sleap_file=sleap_file)

### Option 3: Manual setup

In [None]:
bodyparts=[
    'tail', 'spine4', 'spine3', 'spine2', 'spine1',
    'head', 'nose', 'right ear', 'left ear']

skeleton=[
    ['tail', 'spine4'],
    ['spine4', 'spine3'],
    ['spine3', 'spine2'],
    ['spine2', 'spine1'],
    ['spine1', 'head'],
    ['nose', 'head'],
    ['left ear', 'head'],
    ['right ear', 'head']]

video_dir='/content/drive/MyDrive/dlc_project/videos/'

kpms.setup_project(
    project_dir,
    video_dir=video_dir,
    bodyparts=bodyparts,
    skeleton=skeleton)

## Edit the config file

The config can be edited in a text editor or using the function `kpms.update_config`, as shown below. In general, the following parameters should be specified for each project:

- `bodyparts` (name of each keypoint; automatically imported from SLEAP/DeepLabCut)
- `use_bodyparts` (subset of bodyparts to use for modeling, set to all bodyparts by default; for mice we recommend excluding the tail)
- `anterior_bodyparts` and `posterior_bodyparts` (used for rotational alignment)
- `video_dir` (directory with videos of each experiment)
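
To illustrate what `anterior_bodyparts` and `posterior_bodyparts` are for, here is a minimal sketch of one way a heading angle could be derived from them: take the vector from the mean posterior keypoint to the mean anterior keypoint. This is an assumption made for illustration, not the library's implementation, and the function name `estimate_heading` is hypothetical.

```python
import numpy as np

def estimate_heading(coords, bodyparts, anterior, posterior):
    """Heading angle in radians for each frame of a (T, K, 2) array.

    Hypothetical helper for illustration; not part of keypoint-MoSeq.
    """
    ant_idx = [bodyparts.index(bp) for bp in anterior]
    post_idx = [bodyparts.index(bp) for bp in posterior]
    ant = coords[:, ant_idx].mean(axis=1)    # mean anterior position, shape (T, 2)
    post = coords[:, post_idx].mean(axis=1)  # mean posterior position, shape (T, 2)
    direction = ant - post                   # vector pointing from back to front
    return np.arctan2(direction[:, 1], direction[:, 0])

bodyparts = ['spine4', 'head', 'nose']
coords = np.array([
    [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]],  # animal facing along +x
    [[0.0, 0.0], [0.0, 1.0], [0.0, 2.0]],  # animal facing along +y
])
heading = estimate_heading(coords, bodyparts, anterior=['nose'], posterior=['spine4'])
print(np.degrees(heading))
```

Choosing keypoints that are far apart along the body axis (such as `nose` and `spine4` in the example dataset) makes this direction less sensitive to tracking jitter.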

Edit the config as follows for the [example DeepLabCut dataset](https://drive.google.com/drive/folders/1UNHQ_XCQEKLPPSjGspRopWBj6-YNDV6G?usp=share_link):

In [None]:
kpms.update_config(
    project_dir,
    video_dir='/content/drive/MyDrive/dlc_project/videos/',
    anterior_bodyparts=['nose'],
    posterior_bodyparts=['spine4'],
    use_bodyparts=[
        'spine4', 'spine3', 'spine2', 'spine1',
        'head', 'nose', 'right ear', 'left ear'])

## Load data

Data can be loaded from DeepLabCut, SLEAP, or any other source as long as it has the following format, where K is the number of keypoints and D is 2 or 3:
- `coordinates`: dict from session names to keypoint coordinate arrays of shape (T,K,D). Each key should start with its video name (e.g. `coordinates["experiment1_etc"]` would correspond to `experiment1.avi`). In general this will already be true if importing from SLEAP or DeepLabCut.
    
- `confidences`: dict from session names to **nonnegative** keypoint confidence arrays of shape (T,K). Confidences are optional (they are used to set the error prior for each observation).
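
If you are building these dicts yourself, it can help to check them before calling `kpms.format_data`. The following is a minimal sketch using synthetic arrays; the helper `check_keypoint_data` and the session names are hypothetical and not part of keypoint-MoSeq.

```python
import numpy as np

def check_keypoint_data(coordinates, confidences=None):
    """Validate the dict-of-arrays format described above.

    Hypothetical helper for illustration; not part of keypoint-MoSeq.
    Returns the (K, D) pair shared by all sessions.
    """
    shapes = set()
    for session, coords in coordinates.items():
        if coords.ndim != 3 or coords.shape[2] not in (2, 3):
            raise ValueError(f"{session}: expected shape (T, K, D) with D in (2, 3), got {coords.shape}")
        shapes.add(coords.shape[1:])
        if confidences is not None:
            conf = confidences[session]
            if conf.shape != coords.shape[:2]:
                raise ValueError(f"{session}: confidences must have shape (T, K), got {conf.shape}")
            if np.any(conf < 0):
                raise ValueError(f"{session}: confidences must be nonnegative")
    if len(shapes) != 1:
        raise ValueError("all sessions must share the same K and D")
    return shapes.pop()

# synthetic example: two sessions, 9 keypoints, 2D coordinates
rng = np.random.default_rng(0)
coordinates = {
    "experiment1_tracking": rng.normal(size=(100, 9, 2)),
    "experiment2_tracking": rng.normal(size=(250, 9, 2)),
}
confidences = {name: rng.uniform(size=arr.shape[:2]) for name, arr in coordinates.items()}

K, D = check_keypoint_data(coordinates, confidences)
print(K, D)
```

Sessions may differ in length T, but the number of keypoints K and the dimension D must match across sessions.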

In [None]:
# load data from DeepLabCut
dlc_results_directory = '/content/drive/MyDrive/dlc_project/videos/'
coordinates, confidences, bodyparts = kpms.load_deeplabcut_results(dlc_results_directory)

# load data from SLEAP
# sleap_results_directory = '...'
# coordinates, confidences, bodyparts = kpms.load_sleap_results(sleap_results_directory)

# format data for modeling
data, labels = kpms.format_data(coordinates, confidences=confidences, **config())

## Calibration [disabled in colab]

The purpose of calibration is to learn the relationship between error and keypoint confidence scores. The resulting regression coefficients (`slope` and `intercept`) are used during modeling to set the noise prior on a per-frame, per-keypoint basis. **This step is disabled in Colab**. It can safely be skipped in any case, since the default parameters are fine for most datasets.
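
As a rough illustration of what learning such a relationship can look like, the sketch below fits a straight line to log error versus log confidence on synthetic data. The log-log form, the synthetic data, and all variable names are assumptions made for this example; they are not a description of the library's calibration routine.

```python
import numpy as np

rng = np.random.default_rng(0)

# synthetic confidences and errors (assumption: error falls as confidence rises)
confidence = rng.uniform(0.05, 1.0, size=500)
true_slope, true_intercept = -1.0, 0.5
log_error = (true_intercept
             + true_slope * np.log10(confidence)
             + rng.normal(scale=0.05, size=500))

# least-squares fit of log error against log confidence
slope, intercept = np.polyfit(np.log10(confidence), log_error, deg=1)
print(f"slope={slope:.2f}, intercept={intercept:.2f}")
```

A negative slope means that low-confidence keypoints are expected to have larger errors, so they would be given a wider noise prior.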

## Fit PCA

Run the cell below to fit a PCA model to aligned and centered keypoint coordinates. The model is saved to `{project_dir}/pca.p` and can be reloaded using `kpms.load_pca`. Two plots are generated: a cumulative [scree plot](https://en.wikipedia.org/wiki/Scree_plot) and a depiction of each PC, where translucent nodes/edges represent the mean pose and opaque nodes/edges represent a perturbation in the direction of the PC.

- After fitting, edit `latent_dim` in the config.
- A good heuristic is the number of dimensions needed to explain 90% of variance, or 10 dimensions - whichever is lower.  
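
The heuristic above can be sketched as follows. The function `suggest_latent_dim` is hypothetical, and the input is assumed to be a list of per-component explained-variance ratios (for example, the `explained_variance_ratio_` attribute of a scikit-learn PCA object).

```python
import numpy as np

def suggest_latent_dim(explained_variance_ratio, target=0.9, max_dim=10):
    """Smallest number of PCs reaching `target` cumulative variance, capped at `max_dim`.

    Hypothetical helper for illustration; not part of keypoint-MoSeq.
    """
    cumulative = np.cumsum(explained_variance_ratio)
    # index of the first component whose cumulative variance reaches the target
    n_for_target = int(np.searchsorted(cumulative, target)) + 1
    # never suggest more components than exist or than the cap allows
    return min(n_for_target, len(cumulative), max_dim)

# three components are enough here: 0.60 + 0.20 + 0.15 = 0.95
print(suggest_latent_dim([0.60, 0.20, 0.15, 0.03, 0.02]))

# variance spread evenly over 20 components: the cap of 10 applies
print(suggest_latent_dim([0.05] * 20))
```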

In [None]:
pca = kpms.fit_pca(**data, **config())
kpms.save_pca(pca, project_dir)

kpms.print_dims_to_explain_variance(pca, 0.9)
kpms.plot_scree(pca, project_dir=project_dir)
kpms.plot_pcs(pca, project_dir=project_dir, **config())

# use the following to load an already fitted PCA model
# pca = kpms.load_pca(project_dir)

In [None]:
kpms.update_config(project_dir, latent_dim=4)

# Model fitting

Fitting a keypoint-MoSeq model involves:
1. **Initialization:** Auto-regressive (AR) parameters and syllable sequences are randomly initialized using pose trajectories from PCA.
2. **Fitting an AR-HMM:** The AR parameters, transition probabilities and syllable sequences are iteratively updated through Gibbs sampling. 
3. **Fitting the full model:** All parameters, including the AR-HMM parameters as well as the centroid, heading, noise estimates and continuous latent states (i.e. PCA trajectories), are iteratively updated through Gibbs sampling. This step is especially useful for noisy data.
4. **Applying the trained model:** The learned model parameters are used to infer a syllable sequence for each experiment. This step should always be applied at the end of model fitting, and it can also be used later on to infer syllable sequences for newly added data.

## Setting hyperparameters

There are two ways to change hyperparameters:
1. Update the config using `kpms.update_config` and then re-initialize the model
2. Change the model directly via `kpms.update_hypparams`

In general, the main hyperparameter that needs to be adjusted is **kappa**, which sets the time-scale of syllables. Higher kappa leads to longer syllables. For this tutorial we chose kappa values that yielded a median syllable duration of 400 ms (12 frames). You will need to tune kappa for each new dataset based on the intended syllable time-scale. **You will need to pick two kappas: one for AR-HMM fitting and one for the full model.**
- We recommend iteratively updating kappa and refitting the model until the target syllable time-scale is attained.  
- Model fitting can be stopped at any time by interrupting the kernel, and kappa can be adjusted as described above.
- The full model will generally require a lower value of kappa to yield the same target syllable durations.
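
To build intuition for why higher kappa gives longer syllables, here is a simplified sketch: a plain Markov chain whose transition matrix gets `kappa` extra weight on the diagonal, followed by a run-length computation of syllable durations. This is a toy model, not the library's sampler. The function names are hypothetical, and the frame rate of 30 fps is inferred from the tutorial's figure of 12 frames per 400 ms.

```python
import numpy as np
from itertools import groupby

def syllable_durations(labels):
    """Run lengths (in frames) of consecutive identical labels."""
    return [len(list(run)) for _, run in groupby(labels)]

def sample_sticky_sequence(num_states, kappa, num_frames, seed=0):
    """Sample labels from a Markov chain with a self-transition bonus.

    Each row starts uniform (total weight 1 spread over all states) and
    receives `kappa` extra weight on the diagonal before normalizing.
    """
    rng = np.random.default_rng(seed)
    weights = np.full((num_states, num_states), 1.0 / num_states)
    weights += kappa * np.eye(num_states)
    trans = weights / weights.sum(axis=1, keepdims=True)
    labels = [0]
    for _ in range(num_frames - 1):
        labels.append(int(rng.choice(num_states, p=trans[labels[-1]])))
    return labels

fps = 30  # assumed frame rate: 12 frames at 30 fps is 400 ms
for kappa in (0.0, 5.0, 50.0):
    labels = sample_sticky_sequence(num_states=5, kappa=kappa, num_frames=5000)
    median_frames = float(np.median(syllable_durations(labels)))
    print(f"kappa={kappa}: median duration {median_frames} frames "
          f"({1000 * median_frames / fps:.0f} ms)")
```

The same run-length computation can be applied to an inferred syllable sequence to check whether the current kappa yields the target median duration.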


## Initialization

In [None]:
# optionally update kappa in the config before initializing 
# kpms.update_config(project_dir=project_dir, kappa=NUMBER)

# initialize the model
model = kpms.init_model(data, pca=pca, **config())

## Fitting an AR-HMM

In addition to fitting an AR-HMM, the function below:
- generates a name for the model and a new directory in `project_dir`
- saves a checkpoint every 10 iterations from which fitting can be restarted
    - a single checkpoint file contains the full history of fitting, and can be used to restart fitting from any iteration
- plots the progress of fitting every 10 iterations, including
    - the distributions of syllable frequencies and durations for the most recent iteration
    - the change in median syllable duration across fitting iterations
    - the syllable sequence across iterations in a random window

In [None]:
model, history, name = kpms.fit_model(model, data, labels, ar_only=True, 
                                      num_iters=50, project_dir=project_dir)

## Fitting the full model

The following code fits a full keypoint-MoSeq model, using the results of AR-HMM fitting for initialization.
- If using your own data, you may need to try a few values of kappa at this step.
- Use `kpms.revert` to resume from the same starting point each time you restart fitting.

In [None]:
# load model checkpoint generated during step 2 (AR-HMM fitting)
checkpoint = kpms.load_checkpoint(project_dir=project_dir, name=name)

# the following will cause fitting to resume from iteration 50, rather than the most recent iteration
# checkpoint = kpms.revert(checkpoint, 50)

# update kappa to maintain the desired syllable time-scale
checkpoint = kpms.update_hypparams(checkpoint, kappa=9e4)

model, history, name = kpms.resume_fitting(**checkpoint, project_dir=project_dir, 
                                           ar_only=False, num_iters=200)

## Extract model results

Extract modeling results for each session and save the results to `{project_dir}/{name}/results.h5`. The results are stored as follows, and can be reloaded at a later time using `kpms.load_results`:
```
    results.h5
    ├──session_name1
    │  ├──estimated_coordinates  # denoised coordinates
    │  ├──syllables_reindexed    # syllables reindexed by frequency
    │  ├──syllables              # non-reindexed syllables labels (z)
    │  ├──latent_state           # inferred low-dim pose state (x)
    │  ├──centroid               # inferred centroid (v)
    │  └──heading                # inferred heading (h)
    ⋮
```
Check out the docs for an [in-depth explanation of the modeling results](https://keypoint-moseq.readthedocs.io/en/latest/FAQs.html#interpreting-model-outputs).
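
To make the difference between `syllables` and `syllables_reindexed` concrete, the sketch below relabels a toy sequence so that the most frequent syllable becomes 0, the next becomes 1, and so on. Frequency is counted here as the number of instances (runs) of each syllable, which is an assumption; counting frames instead would be an equally plausible choice. The function is a stand-in written for this example, not the library's implementation.

```python
import numpy as np

def reindex_by_frequency(syllables):
    """Relabel syllables so that 0 is the most frequent, 1 the next, etc.

    Illustrative stand-in; frequency is counted per instance (run), not per frame.
    """
    syllables = np.asarray(syllables)
    # keep only the first frame of each run so that each instance counts once
    onsets = np.concatenate([[True], syllables[1:] != syllables[:-1]])
    ids, counts = np.unique(syllables[onsets], return_counts=True)
    order = ids[np.argsort(-counts, kind="stable")]
    mapping = {int(old): new for new, old in enumerate(order)}
    return np.array([mapping[int(s)] for s in syllables])

syllables = [7, 7, 7, 2, 2, 7, 5, 5, 2, 2, 2, 7]
reindexed = reindex_by_frequency(syllables)
print(reindexed)  # syllable 7 (3 instances) -> 0, syllable 2 (2) -> 1, syllable 5 (1) -> 2
```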

In [None]:
# load saved model checkpoint
checkpoint = kpms.load_checkpoint(project_dir=project_dir, name=name)

# extract results
results = kpms.extract_results(project_dir=project_dir, **config(), **checkpoint)

### Save results in csv format

After extracting to an h5 file, the results can optionally be saved in csv format. A separate csv file will be created for each session and saved to `{project_dir}/{name}/results/`. 

In [None]:
# optionally save results as csv
kpms.save_results_as_csv(project_dir=project_dir, name=name, **config())

## Apply to new data

The code below shows how to apply a trained model to new data. This is useful if you have performed new experiments and would like to maintain an existing set of syllables. The results for the new experiments will be added to the existing `results.h5` file. **This step is optional and can be skipped if you do not have new data to add**.

In [None]:
# load saved model checkpoint
checkpoint = kpms.load_checkpoint(project_dir=project_dir, name=name)

# load new data (e.g. from deeplabcut)
new_data = 'path/to/new/data/' # can be a file, a directory, or a list of files
coordinates, confidences, bodyparts = kpms.load_deeplabcut_results(new_data)

results = kpms.apply_model(
    coordinates=coordinates, 
    confidences=confidences, 
    project_dir=project_dir, 
    name=name, 
    pca=kpms.load_pca(project_dir),
    params=checkpoint['params'],
    hypparams=checkpoint['hypparams'],
    **config())

# optionally rerun `save_results_as_csv` to export the new results
kpms.save_results_as_csv(project_dir=project_dir, name=name, **config())

# Visualize syllables

## Trajectory plots
Generate plots showing the average trajectory of poses associated with each syllable.

In [None]:
kpms.generate_trajectory_plots(coordinates=coordinates, name=name, project_dir=project_dir, **config())

## Crowd & grid movies
Generate video clips showing examples of each syllable.

In [None]:
kpms.generate_grid_movies(name=name, project_dir=project_dir, coordinates=coordinates, **config())
kpms.generate_crowd_movies(name=name, project_dir=project_dir, coordinates=coordinates, **config())