# Computing behavioural syllables using keypoint-moseq

In this notebook, we will explore some simple analyses that we can carry out with a trained keypoint-moseq model applied to new data.
Reusing a trained model is useful if you have collected data for new experiments but would like to keep an existing set of syllables.

We assume that you have gone through the `EPM_train_keypoint_moseq.ipynb` notebook, which explains how to create and train a keypoint-moseq model on a set of 10 videos. We also assume that you have downloaded the `mouse-EPM` folder containing the sample data.

This notebook is based on the one at https://github.com/dattalab/keypoint-moseq/blob/main/docs/keypoint_moseq_colab.ipynb. 

## A. Setup
### A1. Create a conda environment and install the required packages
In this notebook, we will use the conda environment defined in the `keypoint-moseq` repository, with an optional additional package.

In a terminal, clone the `keypoint-moseq` repository (ideally somewhere outside the `course-behavioural-analysis` repository):
```bash
git clone https://github.com/dattalab/keypoint-moseq
cd keypoint-moseq
```

Then, create the appropriate conda environment for your platform. For example, for a Linux installation with a GPU we would run:
```bash
# Linux (GPU)
conda env create -f conda_envs/environment.linux_gpu.yml
``` 
For other platforms, please see the full list of commands in the [keypoint-moseq docs](https://keypoint-moseq.readthedocs.io/en/latest/install.html#install-using-conda).


This last command will create a conda environment called `keypoint_moseq`. We can activate this environment by running:
```bash
conda activate keypoint_moseq
```

Optionally, to interact with plots in the notebook, we can install the `ipympl` package in the `keypoint_moseq` environment:
```bash
pip install ipympl 
```

Once all requirements are installed, you can re-open this notebook and select the `keypoint_moseq` kernel.

### A2. Import required packages

In [1]:
import itertools
import keypoint_moseq as kpms
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path





In [2]:
# Optional: enable interactive plots (requires the ipympl package); comment out this line to use static plots
%matplotlib widget

### A3. Specify paths to trained model and new data

You should modify `DATA_DIR` to point to the directory where you downloaded the `mouse-EPM` folder containing the sample data.

In [4]:
DATA_DIR = Path.home() / "Data" / "behav-analysis-course" / "mouse-EPM"

# path to new data (SLEAP predictions for video-1)
new_data = DATA_DIR / "derivatives" / "software-SLEAP_project" / "predictions" / "video-1.predictions.analysis.h5"  

# path to kpt-moseq trained model
project_dir = DATA_DIR / "derivatives" / "software-kptmoseq_n-10_project"
model_name = '2024_09_19-15_54_42'
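Before loading anything, it can help to confirm that these paths actually exist on your machine. The snippet below is a minimal sketch: `find_missing_paths` is a hypothetical helper (not part of keypoint-moseq), demonstrated on a temporary directory that stands in for `DATA_DIR`.

```python
from pathlib import Path
import tempfile


def find_missing_paths(paths):
    """Return the subset of `paths` that do not exist on disk."""
    return [Path(p) for p in paths if not Path(p).exists()]


# Demo on a temporary directory standing in for DATA_DIR
with tempfile.TemporaryDirectory() as tmp:
    demo_dir = Path(tmp)
    existing_file = demo_dir / "video-1.predictions.analysis.h5"
    existing_file.touch()
    absent_dir = demo_dir / "software-kptmoseq_n-10_project"

    missing = find_missing_paths([existing_file, absent_dir])
    print([p.name for p in missing])
```

In the notebook, you could call `find_missing_paths([new_data, project_dir, project_dir / model_name])` and stop if the returned list is not empty.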

## B. Apply trained model to new SLEAP predictions

Load the trained model and config

In [None]:
# load model
model, data, metadata, current_iter = kpms.load_checkpoint(project_dir, model_name)

# load config 
config = lambda: kpms.load_config(project_dir)  # noqa: E731

Parse the new keypoint data and format it for the model

In [None]:

# parse keypoint data
coordinates, confidences, bodyparts = kpms.load_keypoints(str(new_data), "sleap")

# - coordinates: for each video, an array of size (nframes, n_kpts, n_spatial_dims) with the coords of the keypoints in image coord system
# - confidences: for each video, an array of size (nframes, n_kpts)
print(coordinates.keys())
print(coordinates["video-1.predictions.analysis"].shape)


In [None]:
# format data for model
data, metadata = kpms.format_data(coordinates, confidences, **config())

Apply the trained model to the new data

In [None]:
# Apply the trained model to the new data.
# The results are returned as a dictionary and also saved to an h5 file.
# In theory, the results for new experiments are added to the existing
# `results.h5` file. Here we pass `results_path` to write to a separate file,
# because otherwise the existing `results.h5` file is overwritten.
results = kpms.apply_model(
    model,
    data,
    metadata,
    project_dir,
    model_name,
    **config(),
    results_path=str(project_dir / model_name / "results_video-1.h5"),
)

# optionally rerun `save_results_as_csv` to export the new results as a csv
kpms.save_results_as_csv(results, project_dir, model_name)

Inspect results

In [None]:
print(results.keys())
print(results["video-1.predictions.analysis"].keys())

## C. Compute frequency of syllables in the new data

In [None]:
# count the frequency of each syllable across the whole video
syllables_per_frame = results["video-1.predictions.analysis"][
    "syllable"
]  # array: (nframes, )

# count the number of frames assigned to each syllable
syllable_ids, frame_counts = np.unique(syllables_per_frame, return_counts=True)
syllables_count = dict(zip(syllable_ids.tolist(), frame_counts.tolist()))

# sort by count
syllables_count = dict(
    sorted(syllables_count.items(), key=lambda item: item[1], reverse=True)
)

# print top 10
# (the syllable ids are assigned based on their frequency in the training data)
n_frames = results["video-1.predictions.analysis"]["syllable"].shape[0]
for syl, count in list(syllables_count.items())[:10]:
    print(f"Syllable id-{syl}: {(count/n_frames)*100:.2f} % of frames")
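The counting step above can also be sketched with the standard library alone. The example below is a minimal illustration on made-up labels (`toy_syllables` is hypothetical, not output from the model); with real results you would pass the per-frame syllable array instead.

```python
from collections import Counter

# Toy per-frame syllable labels (hypothetical values, not real results)
toy_syllables = [0, 0, 0, 1, 1, 0, 2, 2, 2, 2]

counts = Counter(toy_syllables)
n_toy_frames = len(toy_syllables)

# most_common() sorts by count, highest first
toy_percentages = {
    syl: 100 * count / n_toy_frames for syl, count in counts.most_common()
}

for syl, pct in toy_percentages.items():
    print(f"Syllable id-{syl}: {pct:.2f} % of frames")
```

Because every frame carries exactly one syllable label, the percentages add up to 100, which is a quick sanity check on real data too.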

## D. Plot ethogram for new data

In [None]:
# find the length of each uninterrupted run of the same syllable
# itertools.groupby starts a new group
# every time the value of the key function changes
syllable_chunks = [
    (key, len(list(group_iter)))
    for key, group_iter in itertools.groupby(syllables_per_frame)
]  # list of (syllable_id, len)

list_durations = [dur for syl, dur in syllable_chunks]
start_chunks = np.cumsum([0]+list_durations) - 0.5

list_colors = (
    plt.get_cmap("tab10").colors
    + plt.get_cmap("tab20b").colors
    + plt.get_cmap("Set3").colors
    + plt.get_cmap("Set1").colors
)  # 51 colors

frames_max_to_plot = 1000

fig, ax = plt.subplots(1, 1, figsize=(8,5))
rects = ax.barh(
    y=list(results.keys()),
    width=[syl_dur for syl_id, syl_dur in syllable_chunks],
    left=start_chunks[:-1],  # starting frame of each chunk - 0.5
    height=1,
    color=[
        list_colors[syl_id%len(list_colors)] 
        for syl_id, syl_dur in syllable_chunks
    ],
)
ax.bar_label(
    rects, 
    labels=[syl_id for syl_id, syl_dur in syllable_chunks],
    label_type='center', 
    color='white'
)
ax.set_xlim(0, frames_max_to_plot)
ax.set_xlabel('frames')
ax.yaxis.set_visible(False)

ax.set_aspect(100)
ax.set_title(*results.keys())
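The ethogram relies on run-length encoding the per-frame labels: each uninterrupted run of a syllable becomes one bar whose width is the run duration and whose left edge is the run's first frame. As a small sketch on made-up labels (`toy_syllables` is hypothetical), the encoding and the start frames can be computed as follows.

```python
import itertools

# Toy per-frame syllable labels (hypothetical values)
toy_syllables = [5, 5, 5, 2, 2, 5, 7, 7, 7, 7]

# Run-length encode: one (syllable_id, n_frames) pair per uninterrupted run
toy_chunks = [
    (syl_id, len(list(group)))
    for syl_id, group in itertools.groupby(toy_syllables)
]

# First frame index of each run: cumulative sum of the preceding durations
toy_starts = list(
    itertools.accumulate([dur for _, dur in toy_chunks[:-1]], initial=0)
)

print(toy_chunks)
print(toy_starts)
```

Note that syllable 5 appears in two separate runs, so it produces two bars. The notebook cell additionally subtracts 0.5 from the start frames before plotting.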

## E. Compute median duration per syllable and plot

In [None]:
# compute median syllable duration
median_syllable_duration = np.median([syl_dur for (syl_id, syl_dur) in syllable_chunks])

# compute median duration per syllable ID
median_duration_per_syl = {}
for syl in list(syllables_count.keys()):
    median_duration_per_syl[syl] = np.median(
        [syl_dur for (syl_id, syl_dur) in syllable_chunks if syl_id == syl]
    )  # frames

# plot median duration per syllable
fps = 30  # frame rate of the video (frames per second)

fig, ax = plt.subplots(1, 1)
ax.scatter(
    x=median_duration_per_syl.keys(),
    y=median_duration_per_syl.values(),
)
ax.hlines(
    y=median_syllable_duration,
    xmin=-1,
    xmax=max(median_duration_per_syl) + 1,  # extend the line past the largest syllable ID
    colors="r",
)
ax.set_xlabel("syllable ID")
ax.set_ylabel("median duration (frames)")

print(f"Median syllable duration (frames): {median_syllable_duration}")
print(f"Median syllable duration (ms): {1000*median_syllable_duration/fps}")
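As a worked example of the per-syllable medians and the frame-to-millisecond conversion, the sketch below groups run durations by syllable in a single pass. The values in `toy_chunks` and the frame rate `toy_fps` are assumptions chosen for illustration; use the frame rate of your own videos.

```python
from collections import defaultdict
from statistics import median

# Toy run-length encoded syllables: (syllable_id, duration in frames)
toy_chunks = [(5, 3), (2, 2), (5, 1), (7, 4), (2, 6)]
toy_fps = 30  # assumed frame rate; replace with the rate of your video

# Collect the durations of every run, keyed by syllable ID
durations_by_syl = defaultdict(list)
for syl_id, dur in toy_chunks:
    durations_by_syl[syl_id].append(dur)

# Median duration per syllable, in frames and in milliseconds
toy_median_frames = {s: median(d) for s, d in durations_by_syl.items()}
toy_median_ms = {s: 1000 * m / toy_fps for s, m in toy_median_frames.items()}

print(toy_median_frames)
print(toy_median_ms)
```

Grouping once with a `defaultdict` avoids re-scanning the full list of runs for every syllable ID, which may matter for long recordings.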



## F. Visualise the most frequent syllables in the new data

In [None]:
frames_max_to_plot = len(syllables_per_frame)
list_syllables_to_plot = [1, 3, 2, 4, 42, 31]

for selected_syl in list_syllables_to_plot:

    fig, ax = plt.subplots(1, 1, figsize=(8,5))
    rects = ax.barh(
        y=list(results.keys()),
        width=[syl_dur for syl_id, syl_dur in syllable_chunks],
        left=start_chunks[:-1],  # starting frame of each chunk - 0.5
        height=1,
        color=[
            'red' if syl_id==selected_syl 
            else 'grey' 
            for syl_id, _ in syllable_chunks
        ],
    )
    # ax.bar_label(rects, label_type='center', color='white')
    ax.set_xlim(0, frames_max_to_plot)
    ax.set_xlabel('frames')
    ax.yaxis.set_visible(False)

    ax.set_aspect(int(frames_max_to_plot/10))
    ax.set_title(f'{list(results.keys())[0]} - syllable {selected_syl}')
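The list of syllables to highlight is hard-coded in the cell above. If you prefer to derive it from the data, one possible approach is sketched below; `most_frequent_syllables` is a hypothetical helper and `toy_syllables` is made-up data.

```python
from collections import Counter


def most_frequent_syllables(syllables, k):
    """Return the IDs of the k most frequent syllables, most frequent first."""
    return [syl_id for syl_id, _ in Counter(syllables).most_common(k)]


# Toy per-frame syllable labels (hypothetical values)
toy_syllables = [4, 4, 1, 1, 1, 9, 4, 1, 9, 3]
toy_selection = most_frequent_syllables(toy_syllables, k=2)
print(toy_selection)
```

With a NumPy array of labels, converting it first with `.tolist()` should give plain Python integers as IDs.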

## G. Plot centroid location for the top three syllables

In [None]:
fig, ax = plt.subplots(1,1)
top_k_syllables = 3

centroid_array = results['video-1.predictions.analysis']['centroid']
syllable_array = results['video-1.predictions.analysis']['syllable']

ax.scatter(
    x=centroid_array[:,0],
    y=centroid_array[:,1],
    s=1,
    c=[
        list_colors[int(syl_id)%len(list_colors)] 
        if syl_id in list(range(top_k_syllables)) else (0.5, 0.5, 0.5)
        for syl_id in syllable_array
    ],
)
# ax.legend([f'syllable {syl}' for syl in list(range(top_k_syllables))])
ax.set_title(f'{list(results.keys())[0]} - top {top_k_syllables} syllables')
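The scatter plot pairs each frame's centroid with the syllable active in that frame. The same pairing can be used to summarise where each syllable tends to occur. The sketch below computes a mean position per syllable on made-up data; `toy_centroids` and `toy_syllables` are hypothetical, and the mean is only one of several reasonable summaries.

```python
from collections import defaultdict

# Toy data (hypothetical): one (x, y) centroid and one syllable label per frame
toy_centroids = [(0.0, 0.0), (2.0, 0.0), (10.0, 10.0), (12.0, 14.0), (4.0, 6.0)]
toy_syllables = [0, 0, 1, 1, 0]

# Group the centroid positions by the syllable active in each frame
positions_by_syl = defaultdict(list)
for (x, y), syl_id in zip(toy_centroids, toy_syllables):
    positions_by_syl[syl_id].append((x, y))

# Mean position per syllable
toy_mean_position = {
    syl_id: (
        sum(x for x, _ in pts) / len(pts),
        sum(y for _, y in pts) / len(pts),
    )
    for syl_id, pts in positions_by_syl.items()
}

print(toy_mean_position)
```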

## Appendix

The results in the .h5 file follow the structure below. 
```
    results.h5
    ├──recording_name1
    │  ├──syllable      # syllable labels (z)
    │  ├──latent_state  # inferred low-dim pose state (x)
    │  ├──centroid      # inferred centroid (v)
    │  └──heading       # inferred heading (h)
    ⋮
```

They can be reloaded at a later time using `kpms.load_results`.

Check the docs for an [in-depth explanation of the modeling results](https://keypoint-moseq.readthedocs.io/en/latest/FAQs.html#interpreting-model-outputs).