# Moseq2 App

<img src="https://drive.google.com/uc?export=view&id=1jCpb4AQKBasne2SoXw_S1kwU7R9esAF6">

MoSeq2 is a software toolkit for unsupervised modeling and characterization of animal behavior. MoSeq transforms depth recordings of animals into a rich description of behavior as a series of reused, stereotyped motifs known as 'syllables'.

This notebook begins with depth recordings (see the [data acquisition overview](#Data-Acquisition-Overview) below) and transforms them through the following steps:

- **[Extraction](#Raw-Data-Extraction)**: The animal is segmented from the background and its position and heading direction are aligned across frames.
- **[Dimensionality reduction](#Principal-Component-Analysis-(PCA))**: Raw video is de-noised and transformed to low-dimensional pose trajectories using principal component analysis (PCA).
- **[Model training](#ARHMM-Modeling)**: Pose trajectories are modeled using an autoregressive hidden Markov model (AR-HMM), producing a sequence of syllable labels.
- **[Analysis](#Visualize-Analysis-Results)**: Model output is reported through visualization and statistical analysis.

Each step name above is a shortcut you can click to jump to that part of the notebook.

### Resources
We've provided links to the MoSeq documentation and recent publications that have used this software.
- Documentation of all MoSeq functions (the module links point to PDFs):
    - [Wiki](https://github.com/dattalab/moseq2-app/wiki)
    - [App](https://github.com/dattalab/moseq2-app/blob/release/Documentation.pdf)
    - [Extract](https://github.com/dattalab/moseq2-extract/blob/release/Documentation.pdf)
    - [PCA](https://github.com/dattalab/moseq2-pca/blob/release/Documentation.pdf)
    - [Model](https://github.com/dattalab/moseq2-model/blob/release/Documentation.pdf)
    - [Viz](https://github.com/dattalab/moseq2-viz/blob/release/Documentation.pdf)
- Publications
    - [Mapping Sub-Second Structure in Mouse Behavior](http://datta.hms.harvard.edu/wp-content/uploads/2018/01/pub_23.pdf)
    - [The Striatum Organizes 3D Behavior via Moment-to-Moment Action Selection](http://datta.hms.harvard.edu/wp-content/uploads/2019/06/Markowitz.final_.pdf)
    - [Revealing the structure of pharmacobehavioral space through motion sequencing](https://www.nature.com/articles/s41593-020-00706-3)
    - [Q&A: Understanding the composition of behavior](http://datta.hms.harvard.edu/wp-content/uploads/2019/06/Datta-QA.pdf)
    
### Feedback

If you would like to give us feedback on this notebook, or to request new features, please fill out [this survey](https://forms.gle/FbtEN8E382y8jF3p6).
    
### Data Acquisition Overview
MoSeq2 takes animal depth recordings as input. We have developed a [data acquisition pipeline](https://github.com/dattalab/moseq2-app/wiki/Setup:-acquisition-software) for the second-generation `Xbox Kinect` depth camera. We suggest following our [data acquisition tutorial](https://github.com/dattalab/moseq2-app/wiki/Acquisition) when recording sessions.

MoSeq2 also accepts depth recordings from an `Azure Kinect` camera as well as the `Intel RealSense` depth camera. These cameras acquire data with the tools built into their respective development kits.

**We recommend recording more than 10 hours of depth video (~1 million frames at 30 frames per second) to train high-quality MoSeq models.**
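
As a quick check on that recommendation, the arithmetic relating recording time to frame count is sketched below. The helper names are hypothetical (not part of MoSeq2), and 30 fps is assumed, as in the sentence above.

```python
# Hypothetical helpers (not part of MoSeq2) relating recording time and frame count.
def frames_for_duration(hours, fps=30):
    """Number of depth frames captured in `hours` of recording at `fps` frames per second."""
    return int(round(hours * 60 * 60 * fps))


def hours_for_frames(n_frames, fps=30):
    """Hours of recording needed to capture `n_frames` frames at `fps` frames per second."""
    return n_frames / (fps * 60 * 60)


print(frames_for_duration(10))      # 1080000 frames in 10 hours
print(hours_for_frames(1_000_000))  # roughly 9.26 hours for one million frames
```

So "more than 10 hours" and "~1 million frames" describe roughly the same amount of data; at a lower frame rate, reaching the same frame count would take proportionally more hours.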

## Notebook Setup

<img src="https://drive.google.com/uc?export=view&id=1Zkd0tATi8r2ENHvN8OczIrEf4K8PFmhM">

### Check that you're running Python from the correct conda environment

If you performed the recommended installation, `sys.executable` should point to the Python interpreter inside the `moseq2-app` environment. For example:
```python
import sys
print(sys.executable)
# /Users/wgillis/miniconda3/envs/moseq2-app/bin/python
```
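
If you prefer a programmatic check, the sketch below tests whether the interpreter path contains a conda `envs/moseq2-app` segment. It assumes the environment was created with the name `moseq2-app`, as in the recommended installation; `in_conda_env` is a hypothetical helper, not part of MoSeq2.

```python
import sys
from pathlib import PurePath


# Hypothetical helper (not part of MoSeq2).
def in_conda_env(executable, env_name="moseq2-app"):
    """Return True if `executable` lives inside <conda root>/envs/<env_name>/."""
    parts = PurePath(executable).parts
    # Named conda environments sit in an `envs/<env_name>` directory pair.
    return any(a == "envs" and b == env_name for a, b in zip(parts, parts[1:]))


print(in_conda_env(sys.executable))
```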

### Check if the dependencies are found

Run the following cell to check if `moseq2-app` is installed in your current conda kernel. The latest working version number is `0.1.0`.

In [None]:
import sys
import moseq2_app

print('Python path:', sys.executable)
print('MoSeq2 app version:', moseq2_app.__version__)

## Data file organization

The currently accepted depth data extensions are `.dat`, `.tar.gz`, `.avi` and `.mkv`. `.dat` files are generated by our Kinect v2 data acquisition software, `.tar.gz` files are depth files compressed by the data acquisition software, `.avi` files are `.dat` files compressed with `moseq2-extract`, and `.mkv` files are generated by Microsoft's recording software for the Azure Kinect.

After data acquisition, store all of your recording folders in the same folder (shown below). We recommend saving a copy of this notebook in that folder to: (1) access your data from this notebook, (2) keep a unique MoSeq pipeline notebook for each project, and (3) allow the videos this notebook generates to load.

```
.
└── Data_Directory/
    ├── Main-MoSeq2-Notebook.ipynb (running)
    ├── session_1/ ** - the folder containing all of a single session's data
    │   ├── depth.dat        # depth data - the recording itself
    │   ├── depth_ts.txt     # timestamps - csv/txt file of the frame timestamps
    │   └── metadata.json    # metadata - json file that contains the rodent's info (group, subjectName, etc.)
    ...
    └── session_2/ **
        ├── depth.dat
        ├── depth_ts.txt
        └── metadata.json
```

__Note: if your data was acquired using an Azure Kinect or Intel RealSense depth camera, your session directories will not contain `depth_ts.txt` or `metadata.json`. MoSeq2-Extract will automatically generate a metadata json file for every session that is missing one.__
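
To spot missing files before extraction, here is a minimal sketch that lists, for each session folder, the depth files it contains and which of `depth_ts.txt` / `metadata.json` are absent. It assumes the layout shown above (one folder per session directly under the data directory) and the extensions listed earlier; `audit_sessions` is a hypothetical helper, not a MoSeq2 function.

```python
from pathlib import Path

# Accepted depth-file extensions listed above; '.tar.gz' is a double suffix, so match on the file name.
DEPTH_EXTENSIONS = (".dat", ".tar.gz", ".avi", ".mkv")


# Hypothetical helper (not part of MoSeq2).
def audit_sessions(base_dir):
    """Map each session folder under base_dir to its depth files and any missing companion files."""
    report = {}
    for session in sorted(p for p in Path(base_dir).iterdir() if p.is_dir()):
        depth_files = sorted(f.name for f in session.iterdir()
                             if f.is_file() and f.name.endswith(DEPTH_EXTENSIONS))
        if not depth_files:
            continue  # not a session folder (e.g. aggregate_results/ or _pca/)
        missing = [name for name in ("depth_ts.txt", "metadata.json")
                   if not (session / name).exists()]
        report[session.name] = {"depth_files": depth_files, "missing": missing}
    return report


for name, info in audit_sessions("./").items():
    print(name, info)
```

For Azure Kinect or RealSense sessions, a missing `depth_ts.txt` or `metadata.json` is expected, per the note above.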

### Notebook Progress File

This notebook generates a `progress.yaml` file that stores the filepaths to the data it generates. For example, it contains paths to:
- aggregated extractions
- PC scores of the extractions
- model results

If your notebook kernel shuts down for any reason, you can load the progress file to restore your progress. The progress file does **not** track MoSeq pipeline operations executed outside of this notebook (for example, running PCA from the CLI). You can manually edit the paths in the progress file, or in the loaded dictionary, to record the filepaths of these external operations.

__To restore previously computed variables, look for the cells following the `Restore Progress Variables` label.__

### Restore Progress Variables

- Use this cell to load your notebook analysis progress. We recommend running this notebook from the folder where your data is located so the generated media will display properly. In that case, you can specify the `base_dir` as `./` (or the current folder).

`check_progress` will print progress bars for each pipeline step in the notebook. 
- The extraction progress bar indicates the total number of extracted sessions detected in the provided `base_dir` path.
- It also prints the names of sessions that haven't been extracted. __Note: the progress does not reflect the contents of the aggregate_results/ folder.__
- The remainder of the progress bars are derived from reading the paths in the `progress_paths` dictionary, filling up the bar if the included paths are found.
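
Conceptually, the path-based progress bars only need to know which recorded paths exist. The sketch below illustrates that idea, assuming `progress_paths` is a plain dictionary mapping step names to filepaths (as the `progress_paths['config_file']` lookups in later cells suggest). `summarize_progress` is hypothetical and is not how `check_progress` is actually implemented.

```python
import os


# Hypothetical helper (not part of MoSeq2); a sketch of the idea, not check_progress itself.
def summarize_progress(progress_paths):
    """Split a progress dict into steps whose recorded path exists on disk and steps whose path does not."""
    found, missing = {}, {}
    for step, path in progress_paths.items():
        # Non-string values (e.g. None for a step that hasn't run) are treated as missing.
        if isinstance(path, str) and path and os.path.exists(path):
            found[step] = path
        else:
            missing[step] = path
    return found, missing
```

For example, after running the cell below, `found, missing = summarize_progress(progress_paths)` would show which outputs have not been generated yet.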

In [None]:
from os.path import join
from moseq2_app.gui.progress import check_progress

base_dir = './' # Add the path to your data folder here.
# We recommend that you run this notebook in the same folder as your data. In that case, you don't have to change base_dir
progress_filepath = join(base_dir, 'progress.yaml')

progress_paths = check_progress(base_dir, progress_filepath)

### Generate Configuration Files

The `config.yaml` file holds all configurable parameters for every step in the MoSeq pipeline. The parameters you use are added to this file as you progress through the notebook. You can then use it to run an identical pipeline in future analyses, or to configure parameters directly when debugging cells.

In [None]:
from os.path import join
from moseq2_extract.gui import generate_config_command
from moseq2_app.gui.progress import update_progress, restore_progress_vars

progress_paths = restore_progress_vars(progress_filepath)

config_filepath = join(base_dir, 'config.yaml')

print(f'generating file in path: {config_filepath}')
generate_config_command(config_filepath)
progress_paths = update_progress(progress_filepath, 'config_file', config_filepath)

A configuration file has been created in the base directory (depicted below).

```
.
└── Data_Directory/
    ├── config.yaml **
    ├── session_1/
    │   ├── depth.dat
    │   ├── depth_ts.txt
    │   └── metadata.json
    ...
    └── session_2/
        ├── depth.dat
        ├── depth_ts.txt
        └── metadata.json
```

### Download a Flip File

MoSeq2 uses a flip classifier to guarantee that the mouse is always oriented facing east in the extractions. The flip classifiers we provide __are trained for experiments run with C57BL/6 mice using Kinect v2 depth cameras__.

If these flip classifiers do not work with your dataset, consider training your own. Click [this link](https://github.com/dattalab/moseq2-app/tree/jupyter/) to view the flip-classifier training notebooks. Once your classifier is trained, add its path to the `flip_classifier` field in `config.yaml`.

In [None]:
from moseq2_extract.gui import download_flip_command
# selection=0 - large mice with fibers (default)
# selection=1 - adult male C57s
# selection=2 - mice with Inscopix cables
download_flip_command(base_dir, config_filepath, selection=1)

# Raw Data Extraction

<img src="https://drive.google.com/uc?export=view&id=1vH3fs0Xu5c4gyzFChzV6r_5JGi6JReW2">

The MoSeq2-Extract module segments the mouse from the background and creates data files for dimensionality reduction and modeling. By default, the resulting `.h5` and `.yaml` data files are stored in a `proc` subfolder created inside the session's folder. `.mp4` videos are also generated, primarily for quality assurance after extraction.

## Interactive ROI Detection Tool

Use this interactive tool to find the extraction parameters for your recordings before extracting all of your data. The tool can also catch corrupted or inconsistent sessions, help diagnose ROI detection and extraction errors, and save session-specific parameter sets for edge cases you have solved.

<center><h3>Widget Guide</h3></center>

<img src="https://drive.google.com/uc?export=view&id=1kgMrz-Py4x1k-WwyoHS95cbPRMdJuwvy">

<h3>Instructions</h3>
<br>

<table>
    <tr>
        <td style="width: 45%;">
            <ol>
                <li style="text-align:left; font-size:15px">
                    Click on any row in the session selector to load that session's data to the view.
                </li>
                <li style="text-align:left; font-size:15px">
                    If the indicator next to the session's name is green, the session is considered ready for extraction. A red indicator means either that the session has not been checked yet or that its parameter set is incorrect.
                </li>
                <li style="text-align:left; font-size:15px">
                    Adjust the Depth Range Selector to include the depth range of the detected bucket floor distance (which can be found by hovering over the Background image with your mouse).
                </li>
                <li style="text-align:left; font-size:15px">
                    If the mouse appears cropped when it is at the bucket edge, increase the dilate iterations to enlarge the included floor area.
                </li>
                <li style="text-align:left; font-size:15px">
                    Use the Rodent Height Threshold Slider to remove any noise/speckle from the bucket floor or walls. 
                    <ul>
                        <li style="text-align:left; font-size:15px">
                            Ensure the Min-Height is small enough to filter out only floor reflections and the mouse tail. (Filtering out the tail helps keep the cropped image oriented facing east.)
                        </li>
                        <li style="text-align:left; font-size:15px">
                            Ensure the Max-Height is large enough to include the highest point the mouse can reach when rearing in the bucket, without including outlier values from the bucket walls.
                        </li>
                        <li style="text-align:left; font-size:15px">
                            To explore the session's mouse heights, hover over either of the bottom 2 plots to view the mouse height.
                        </li>
                    </ul>
                </li>
                <li style="text-align:left; font-size:15px">
                    Use the current frame selector slider to change the displayed session frame in the bottom 2 plots.
                </li>
                <li style="text-align:left; font-size:15px">
                    Change the frame range slider values to adjust the segments of the video to extract, then click the "Extract Sample" button to trigger an extraction and view the results.
                </li>
                <li style="text-align:left; font-size:15px">
                    Once you have found a parameter set that is satisfactory, click the "Check All Sessions" button to test the parameters on all of the found sessions, flagging any outliers.
                </li>
                <li style="text-align:left; font-size:15px">
                    If no sessions were flagged, click the "Save Parameters" button to save the currently displayed parameters to the specified config file, as well as to the individual session config file.
                </li>
                <li style="text-align:left; font-size:15px">
                    Otherwise, if a session is flagged, click on it in the Session Selector, read the text indicator for the error details, adjust the parameters until the session passes, and then save the parameters.
                    <ul>
                        <li style="text-align:left; font-size:15px">
                            If a session appears to have a passing ROI and extraction but is still flagged, you can use the "Mark Passing" button to manually accept the session's parameter set. This happens when a session's ROI area pixel count falls outside the acceptance threshold relative to the pixel count of the most recent passing session.
                        </li>
                    </ul>
                </li>
            </ol>
            <p style="text-align:left; font-size:15px">
                    Note that you can also manually edit the slider values by clicking on the numbers to activate keyboard editing.
            </p>
        </td>
        <td>
            <img style="display:contents" src="https://drive.google.com/uc?export=view&id=1jwLb1Tzpx0iAl89RF7z9bnI-sjFftTDh">
        </td>
    </tr>
</table>

__Notice: if the cell runs out of memory after its first use, set `compute_all_bgs=False` to reduce memory pressure.__

If you are using an alternative `flip_classifier` or `crop_size`, or would like to edit other extraction preprocessing parameters, use the following format:
```python
import ruamel.yaml as yaml  # same YAML library imported by the notebook cells

# Read in the config file
with open(progress_paths['config_file'], 'r') as f:
    config_data = yaml.safe_load(f)
    
# Edit its contents
config_data['crop_size'] = (100, 100) # new crop size
config_data['flip_classifier'] = './alternative-flip-classifier.pkl' # updated flip classifier path
...
config_data['gaussfilter_space'] = [2.5, 2] # new spatial filtering kernel size
config_data['medfilter_time'] = [3] # new temporal filtering kernel size

# Filtering out head-fixed cables?
config_data['cable_filter_iters'] = 10 # new number of cable filtering iterations
config_data['cable_filter_size'] = (7, 7) # new spatial filter kernel size

# Write the changes back to the config file before running the interactive_roi_detector command.
with open(progress_paths['config_file'], 'w') as f:
    yaml.safe_dump(config_data, f)
```

In [None]:
from os.path import join
import ruamel.yaml as yaml
from moseq2_app.main import interactive_roi_detector
from moseq2_app.gui.progress import update_progress, restore_progress_vars

progress_paths = restore_progress_vars(progress_filepath)

session_config_path = join(progress_paths['base_dir'], 'session_config.yaml')
progress_paths = update_progress(progress_filepath, 'session_config', session_config_path)

with open(progress_paths['config_file'], 'r') as f:
    config_data = yaml.safe_load(f)

config_data['camera_type'] = 'kinect' # 'kinect', 'azure' or 'realsense'
config_data['crop_size'] = (80, 80)

# if using azure or realsense, increase the noise_tolerance
config_data['noise_tolerance'] = 30

with open(progress_paths['config_file'], 'w') as f:
    yaml.safe_dump(config_data, f)

interactive_roi_detector(base_dir, progress_paths, compute_all_bgs=True)

### Restore Progress Variables

In [None]:
from os.path import join
from moseq2_app.gui.progress import restore_progress_vars

base_dir = './' # Path to your data folder (relative or absolute)

progress_filepath = join(base_dir, 'progress.yaml')

progress_paths = restore_progress_vars(progress_filepath)

## Extract Session(s)

- If `extract_all=False`, the cell prompts you to select which sessions to extract (enter an empty string to extract all of them). Enter your selection, then wait for the extraction to complete before previewing the results.
- If `skip_extracted=True`, the command will only search for (and list) sessions that have not been previously extracted.

__Note: If sessions are not listed when running the cell, ensure your selected extension matches that of your depth files.__
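
The `skip_extracted=True` behavior can be approximated as shown below, treating a session as extracted once its `proc/results_00.h5` file exists (the file name shown in the directory tree below). This is an illustrative sketch with a hypothetical `find_unextracted` helper, not the logic `extract_found_sessions` actually uses; it also shows why the extension list matters, since only files matching one of `extensions` are considered.

```python
from pathlib import Path


# Hypothetical helper (not part of MoSeq2).
def find_unextracted(base_dir, extensions=(".dat",)):
    """Return depth files in <base_dir>/<session>/ whose session has no proc/results_00.h5 yet."""
    pending = []
    for ext in extensions:
        # Only look one level down: <base_dir>/<session>/<depth file>
        for depth_file in sorted(Path(base_dir).glob(f"*/*{ext}")):
            if not (depth_file.parent / "proc" / "results_00.h5").exists():
                pending.append(depth_file)
    return pending


print(find_unextracted("./", extensions=[".avi"]))
```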

In [None]:
from moseq2_extract.gui import extract_found_sessions

# include the file extensions for the depth files you would like to search for and extract.
extensions = ['.avi'] # and/or .dat, .mkv

extract_found_sessions(base_dir, progress_paths['config_file'], extensions, extract_all=True, skip_extracted=True)

This is what your directory structure should look like once the process is complete:

```
.
├── config.yaml
├── session_1/
│   ├── ...
│   └── proc/ **
│       ├── roi.tiff
│       ├── ...
│       ├── results_00.yaml+h5 ** (represents .h5 and .yaml files)
│       └── results_00.mp4 ** (extracted video)
└── session_2/
    ├── ...
    └── proc/ **
        ├── roi.tiff
        ├── ...
        ├── results_00.yaml+h5 **
        └── results_00.mp4 **
```

## Run Extraction Validation Tests

Once all the extractions are complete, run this cell to perform data validation tests. Each test can emit either an error or a warning.
- An __error__ indicates that the session is corrupted in some way and should not be included in the subsequent pipeline steps.
- The __warning__ tests are primarily meant to show how standardized the dataset is with respect to mouse sizes and certain behavior metrics, so that you can make a more informed decision about whether to include each session in the subsequent steps.
  - Depending on the experimental conditions, a warning can indicate that the session should be inspected before continuing to the PCA step.
  - Experimenters using sedating or stimulating drugs (or similar manipulations) may ignore position- and scalar-related warnings. Because such conditions are expected to change general motility, these warnings can be interpreted as confirmation that the experimental group's data differs from that of the controls.

Error raising tests: 
- Count Dropped Frames: an error is raised if a session is missing more than 5% of its frames (checked only when the session's timestamp file exists).
- Missing Mouse Check: raises an error if a mouse is missing from the video for any reason for >5% of the session's total frames.
- Scalar Anomaly: raises an error if >5% of a session's computed scalar values are NaN.
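
One way to estimate dropped frames from a timestamp file is sketched below: take the median gap between consecutive timestamps as the nominal frame interval, infer how many frames the session should contain, and compare that with how many it has. This is an assumption about the approach, not the validation code MoSeq2 runs; the helper names are hypothetical, and the timestamps may be in any consistent unit.

```python
from statistics import median


# Hypothetical helpers (not part of MoSeq2).
def dropped_frame_fraction(timestamps):
    """Estimate the fraction of expected frames that are missing, from a session's frame timestamps."""
    ts = sorted(timestamps)
    if len(ts) < 2:
        return 0.0
    # Take the typical (median) gap between consecutive frames as the nominal frame interval.
    interval = median(b - a for a, b in zip(ts, ts[1:]))
    if interval <= 0:
        return 0.0  # most consecutive timestamps are identical; no interval can be inferred
    expected = round((ts[-1] - ts[0]) / interval) + 1
    return max(expected - len(ts), 0) / expected


def too_many_dropped(timestamps, threshold=0.05):
    """True if more than `threshold` (5% by default, as above) of the frames appear to be missing."""
    return dropped_frame_fraction(timestamps) > threshold
```

For example, the timestamps `[0, 33, 66, 132, 165]` contain one double-length gap, so the sketch estimates 1 of 6 expected frames as missing, above the 5% cutoff.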

Warning raising tests:
- Size Anomaly: a warning is raised when the mouse's captured body area falls more than 2 standard deviations below its mean size throughout the session.
- Scalar Anomaly: a warning listing the affected scalars is raised if any of their mean values fall outside the interquartile range (1st through 3rd quartile).
- Position Anomaly: There are two cases that raise warnings:
    1. Mouse is stationary for >5% of the session.
    2. Mouse's position PDF is at least 2 standard deviations away from the mean of all the sessions, measured using Kullback–Leibler divergence.
     - This anomaly can indicate that a mouse has explored a much larger or smaller region of the arena compared to the remainder of the dataset.
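
The position check can be illustrated with a small sketch under stated assumptions: each session's positions are binned into a 2D histogram (normalized, with a small constant added so no bin is zero), every session is compared with the average histogram of all sessions using Kullback–Leibler divergence, and sessions whose divergence lies at least 2 standard deviations above the mean divergence are flagged. The binning, smoothing constant, and helper names are hypothetical choices, not MoSeq2's implementation; positions are assumed to be scaled to the range `extent` on both axes.

```python
import math
from statistics import mean, pstdev


# Hypothetical helpers (not part of MoSeq2).
def position_pdf(xs, ys, bins=10, extent=(0.0, 1.0), eps=1e-6):
    """Normalized, flattened 2D histogram of a session's positions; `eps` keeps every bin non-zero."""
    lo, hi = extent
    width = (hi - lo) / bins
    counts = [[0] * bins for _ in range(bins)]
    for x, y in zip(xs, ys):
        # Clamp so positions on or beyond the edges land in the outermost bins.
        i = min(max(int((x - lo) / width), 0), bins - 1)
        j = min(max(int((y - lo) / width), 0), bins - 1)
        counts[i][j] += 1
    flat = [c + eps for row in counts for c in row]
    total = sum(flat)
    return [c / total for c in flat]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions with strictly positive entries."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


def flag_position_outliers(pdfs, n_std=2.0):
    """Sessions whose KL divergence from the mean PDF is >= n_std SDs above the mean divergence."""
    names = list(pdfs)
    n_bins = len(pdfs[names[0]])
    mean_pdf = [mean(pdfs[name][k] for name in names) for k in range(n_bins)]
    kls = {name: kl_divergence(pdfs[name], mean_pdf) for name in names}
    values = list(kls.values())
    mu, sd = mean(values), pstdev(values)
    return [name for name, kl in kls.items() if sd > 0 and kl - mu >= n_std * sd]
```

With only a handful of sessions, a 2-standard-deviation cutoff is hard to exceed, so this kind of check is most informative on larger datasets.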

To diagnose certain scalar anomalies, use the [Scalar Summary Cell](#Compute-Scalar-Summary) below to graph any desired scalar value.
   
The following cell will run the tests and emit the warnings and errors if any are found.

In [None]:
from moseq2_app.main import validate_extractions

validate_extractions(base_dir)

## Preview Extractions

Run this cell to launch an interactive session selection widget to load and preview any extracted session.

In [None]:
from moseq2_app.main import preview_extractions

preview_extractions(base_dir)

### Aggregate your results into one folder and generate an index file.

The following cell searches your base directory for the `proc/` folder in each session and copies them all into a single directory.

It then generates the `moseq2-index.yaml` file by collecting the metadata found in the `results_00.h5`/`results_00.yaml` files and consolidating it into one file, assigning each session to a `default` group.

The `aggregate_results/` folder contains all the data you need to run the rest of the pipeline. Both the PCA and the model are trained only on data included in that folder.

The `moseq2-index.yaml` file lists all the sessions (and their metadata) included in `aggregate_results/`. It is also used heavily in the visualization steps to plot mouse and group statistics.

__Important Note: The index file contains a UUID for each session, which is newly generated during the extraction step. These UUIDs are referenced throughout the pipeline, so if you re-extract a session, re-aggregate your data to make sure all the UUIDs are up to date BEFORE the PCA step.__ Not updating the index file can cause `KeyError`s when referencing the extracted data and/or the PCA scores alongside the model results.

In [None]:
from os.path import join
from moseq2_app.gui.progress import update_progress, restore_progress_vars
from moseq2_extract.gui import aggregate_extract_results_command

progress_paths = restore_progress_vars(progress_filepath)

recording_format = '{start_time}_{session_name}_{subject_name}' # filename formats for the copied extracted data files

# directory NAME to save all metadata+extracted videos to with above respective name format
aggregate_results_dirname = 'aggregate_results/'

train_data_dir = join(base_dir, aggregate_results_dirname)
progress_paths = update_progress(progress_filepath, 'train_data_dir', train_data_dir)

# the subpath indicates to only aggregate extracted session paths with that subpath, only change if aggregating data from a different location
index_filepath = aggregate_extract_results_command(base_dir, recording_format, aggregate_results_dirname)
progress_paths = update_progress(progress_filepath, 'index_file', index_filepath)

The aggregate results folder will be saved in your base directory,
resulting in the following directory (sample) structure where the base directory contains the notebook:

```
.
├── aggregate_results/ **
│   ├── session_1_results_00.h5+yaml ** # session 1 compressed extraction + metadata
│   ├── session_1_results_00.mp4 ** # session 1 extracted video
│   ├── session_2_results_00.h5+yaml ** # session 2 compressed extraction + metadata
│   └── session_2_results_00.mp4 ** # session 2 extracted video
├── config.yaml
├── moseq2-index.yaml ** # index file
├── session_1/
└── session_2/
```

__Notice that your index file has also been generated in your base directory; its general initial structure is shown below.__

```
files:
- group: default
  metadata:
      ...
      SubjectName: control_mouse_1
      SessionName: day1
  path: [/path/to/results_00.h5, /path/to/results_00.yaml]
  uuid: 11dc6c26-0de6-4145-9bcc-a9ec200b667e
- group: default
  metadata:
      ...
      SubjectName: drug_mouse_1
      SessionName: day1
  path: [/path/to/results_00.h5, /path/to/results_00.yaml]
  uuid: 16d76d24-35c3-4ca8-aedc-c12456abb4c4
...
pca_path: ./_pca/pca_scores.h5
```

## Specify Groups

MoSeq uses the groups in the moseq2-index.yaml file to indicate whether your sessions represent a single experimental group or several groups that you would like to compare during modeling and visualization.

Specifying groups also helps distinguish the data points in the scalar and heatmap plots.

The index file requires that all your sessions have a metadata.json file in order to successfully assign each recorded subject or session to a group.

Use this GUI to input the group names associated with all the sessions. 
- You can click on the column names to sort the index file.
- You can use your keyboard to select multiple rows.
- Enter your desired group name in the text input and click `Set Group` to update all the associated session rows.
- Once all your groups are set, click the `Update Index File` button to update the file.
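
The GUI edits the `group` field of each entry under `files:` in the index file (structure shown above). As a scripted alternative for large datasets, the sketch below works on the index after it has been loaded into a Python dict (for example with the same `yaml.safe_load` call used elsewhere in this notebook) and assigns groups by `SubjectName` prefix. `assign_groups_by_prefix` is hypothetical, and you would still need to write the dict back to `moseq2-index.yaml` yourself.

```python
# Hypothetical helper (not part of MoSeq2).
def assign_groups_by_prefix(index, prefix_to_group, default="default"):
    """Set each entry's 'group' in a parsed moseq2-index.yaml dict from its SubjectName prefix."""
    for entry in index["files"]:
        subject = entry.get("metadata", {}).get("SubjectName", "")
        # First matching prefix wins; sessions matching no prefix fall back to `default`.
        entry["group"] = next(
            (group for prefix, group in prefix_to_group.items() if subject.startswith(prefix)),
            default,
        )
    return index


example_index = {
    "files": [
        {"group": "default", "metadata": {"SubjectName": "control_mouse_1", "SessionName": "day1"}},
        {"group": "default", "metadata": {"SubjectName": "drug_mouse_1", "SessionName": "day1"}},
    ]
}
assign_groups_by_prefix(example_index, {"control": "control", "drug": "drug"})
print([entry["group"] for entry in example_index["files"]])  # ['control', 'drug']
```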

In [None]:
from moseq2_app.main import interactive_group_setting
from moseq2_app.gui.progress import restore_progress_vars

progress_paths = restore_progress_vars(progress_filepath)

interactive_group_setting(progress_paths['index_file'])

## Compute Scalar Summary

Use the following command to compute summary information for some scalars in your experimental groups, such as average velocity, height, etc.

This graph is meant to give you an idea of whether your extractions were consistent throughout the sessions. If you have a large standard deviation in mouse length/width when your mice are all the same size in reality, then there may have been an error in the extraction or acquisition.
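
The summary plot is the primary check, but the same idea can be expressed numerically. The sketch below computes the coefficient of variation (standard deviation divided by mean) of per-session mean values for one scalar, such as body length in mm, and flags the dataset if it exceeds a threshold. The 10% threshold and the helper names are hypothetical, not MoSeq2 defaults; how you obtain the per-session means (for example by reading them off the summary plot, or from the `scalar_df` returned by the cell below) is left to you.

```python
from statistics import mean, pstdev


# Hypothetical helpers (not part of MoSeq2).
def coefficient_of_variation(values):
    """Population standard deviation divided by the mean (a unitless measure of spread)."""
    m = mean(values)
    return pstdev(values) / m if m else float("nan")


def flag_inconsistent_scalar(per_session_means, max_cv=0.1):
    """Return (CV across sessions, whether it exceeds `max_cv`); 0.1 is an illustrative threshold."""
    cv = coefficient_of_variation(list(per_session_means.values()))
    return cv, cv > max_cv


# Hypothetical per-session mean body lengths (mm) for mice of similar size.
print(flag_inconsistent_scalar({"session_1": 80.0, "session_2": 82.0, "session_3": 78.0}))
```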

In [None]:
from glob import glob
from os.path import join
from IPython.display import display, Image
from moseq2_viz.gui import plot_scalar_summary_command
from moseq2_app.gui.progress import restore_progress_vars

progress_paths = restore_progress_vars(progress_filepath)

# Prefix name of the saved scalar position and summary graphs
output_file = join(progress_paths['plot_path'], 'scalars') 

# Scalars to display
show_scalars = ['velocity_2d_mm', 'velocity_3d_mm',
                'height_ave_mm', 'width_mm', 'length_mm']

colors = None # None for default colors; otherwise use list

scalar_df = plot_scalar_summary_command(progress_paths['index_file'], output_file, 
                                        show_scalars=show_scalars, 
                                        colors=colors)

# Graph the output
display(Image(output_file+'_summary.png'))

## Plot Position Heatmaps For Each Session
Each heatmap will be titled with the session's subject name and group.

In [None]:
from glob import glob
from os.path import join
from IPython.display import display, Image
from moseq2_viz.gui import plot_verbose_position_heatmaps
from moseq2_app.gui.progress import restore_progress_vars

progress_paths = restore_progress_vars(progress_filepath)

output_file = join(progress_paths['plot_path'], 'session_heatmaps') 
plot_verbose_position_heatmaps(progress_paths['index_file'], output_file)

# Graph the output
display(Image(output_file+'.png'))

## Plot Group Mean Position Summary

These plots give you a good idea of the overall activity level and the extent of arena exploration in each of your experimental groups.

In [None]:
from glob import glob
from os.path import join
from IPython.display import display, Image
from moseq2_app.gui.progress import restore_progress_vars
from moseq2_viz.gui import plot_mean_group_position_heatmaps_command

progress_paths = restore_progress_vars(progress_filepath)

output_file = join(progress_paths['plot_path'], 'group_heatmaps') 
plot_mean_group_position_heatmaps_command(progress_paths['index_file'], output_file)

# Graph the output
display(Image(output_file+'.png'))

***
<center><h1>Principal Component Analysis (PCA)</h1></center>

***

Once the data has been extracted, compute its principal components (PCs) to reduce the dimensionality of the data used in the modeling step.
<img src="https://drive.google.com/uc?export=view&id=1I1WcfEwzpfwIxNYStX7swLAIvjQEVApy">

### Restore Progress Variables

In [None]:
from os.path import join
from moseq2_app.gui.progress import restore_progress_vars

base_dir = './' # Path to your data folder (relative or absolute)

progress_filepath = join(base_dir, 'progress.yaml')

progress_paths = restore_progress_vars(progress_filepath)

## Training PCA

Train a PCA model on your extracted data to obtain the principal components that explain the largest possible variance in your dataset. If the resulting principal components look smooth with well-defined regions, and your scree plot shows an explained variance of 90% or above in fewer than 10 PCs, then the PCA model is properly trained. Otherwise, consult the [pathologies below](#Possible-PCA-Pathologies) to solve any issues.
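
The 90% rule of thumb can be checked directly from the per-component explained-variance ratios (the quantity plotted in the scree plot). The sketch below assumes you already have those ratios as a list ordered from the first PC onward; reading them out of the saved PCA files is not covered here, and `n_pcs_for_variance` is a hypothetical helper.

```python
from itertools import accumulate


# Hypothetical helper (not part of MoSeq2).
def n_pcs_for_variance(explained_variance_ratio, target=0.9):
    """Smallest number of leading PCs whose cumulative explained variance reaches `target`."""
    for n, total in enumerate(accumulate(explained_variance_ratio), start=1):
        if total >= target:
            return n
    return None  # the target is never reached with the available components


print(n_pcs_for_variance([0.5, 0.25, 0.125, 0.0625, 0.03125]))  # 4
```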

- While PCA is running, you can monitor the progress of the distributed data processing on the [dask dashboard](http://localhost:8787/). This is only meant for optimization or debugging.

- If there are occlusions over the rodent in your extractions (shown below), set `config_data['missing_data'] = True` to reconstruct the missing data.
<img src="https://drive.google.com/uc?export=view&id=1y9_aRzrE3PS34GC2LJe3zuEXvms04S90">

__Using SLURM?__ Add and edit the following config parameters to spawn dask workers according to your specifications:
```python
import ruamel.yaml as yaml  # same YAML library imported by the notebook cells

# Read in the config file
with open(progress_paths['config_file'], 'r') as f:
    config_data = yaml.safe_load(f)
    
# Edit its contents
config_data['cluster_type'] = 'slurm'

config_data['nworkers'] = 8 # number of spawned jobs
config_data['queue'] = 'short' # partition
config_data['memory'] = '40GB' # amount of memory per worker
config_data['cores'] = 1 # number of cores per worker
config_data['wall_time'] = '01:00:00' # worker time limit

# Write the changes back to the config file before running the PCA training cell below.
with open(progress_paths['config_file'], 'w') as f:
    yaml.safe_dump(config_data, f)
```

In [None]:
from os.path import join
import ruamel.yaml as yaml
from moseq2_pca.gui import train_pca_command
from moseq2_app.gui.progress import update_progress, restore_progress_vars

progress_paths = restore_progress_vars(progress_filepath)

pca_filename = 'pca' # Name of your PCA model h5 file to be saved
pca_dirname = '_pca/' # Directory to save your computed PCA results

with open(progress_paths['config_file'], 'r') as f:
    config_data = yaml.safe_load(f)

# PCA parameters you may need to configure
config_data['gaussfilter_space'] = (1.5, 1) # Spatial filter for data (Gaussian)
config_data['medfilter_space'] = [0] # Median spatial filter
config_data['medfilter_time'] = [0] # Median temporal filter

# If dataset includes head-attached cables, set missing_data=True
config_data['missing_data'] = False # Set True for dataset with missing/dropped frames to reconstruct respective PCs.
config_data['missing_data_iters'] = 10 # Number of times to iterate over missing data during PCA
config_data['recon_pcs'] = 10 # Number of PCs to use for missing data reconstruction

# Dask Configuration
config_data['dask_port'] = '8787' # port to access Dask Dashboard

with open(progress_paths['config_file'], 'w') as f:
    yaml.safe_dump(config_data, f)

progress_paths = update_progress(progress_filepath, 'pca_dirname', join(base_dir, pca_dirname))

# will train on data in aggregate_results/
train_pca_command(progress_paths, pca_dirname, pca_filename)

Once complete, a new directory titled `_pca` will be created containing all your PCA data.
```
.
├── _pca/ **
├   ├── pca.h5 ** # pca model compressed file
├   ├── pca.yaml  ** # pca model YAML metadata file
├   ├── pca_components.png **
├   └── pca_scree.png **
├── aggregate_results/
├── config.yaml
├── moseq2-index.yaml
├── session_1/
└── session_2/

```

View your computed PCs and scree plot in the next cell.

In [None]:
from os.path import join
from IPython.display import display, Image
images = [join(progress_paths['pca_dirname'], 'pca_components.png'), 
          join(progress_paths['pca_dirname'], 'pca_scree.png')]
for im in images:
    display(Image(im))

### Possible PCA Pathologies

<table style="width: 100%;">
  <tbody>
    <tr>
      <th></th>
      <th>Good PCA Output Examples</th>
      <th style="text-align:center;">Bad Scree Plot Example</th>
      <th style="text-align:center;">Bad Principal Components Example</th>
    </tr>  
    <tr>
      <th style="text-align:center;">Pathology Description</th>
      <th style="text-align:center;"></th>
      <td style="text-align:center;">Cannot achieve an explained variance of over 90% with fewer than 15 principal components (PCs).</td>
      <td style="text-align:center;">Graphed PCs look overprocessed, or are not representative of realistic mouse body regions.</td>
    </tr>
    <tr>
      <th style="text-align:center;">Reference Examples</th>
      <th style="text-align:center;">
        <ul>
            <li>Components<br>
                <img src="https://drive.google.com/uc?export=view&id=1dX5Gpd3PKL4vfVviLeP0CqBrz9PW37Au" width=350 height=350></li><br><br>
            <li>Scree Plot<br>
                <img src="https://drive.google.com/uc?export=view&id=12uqsBYuWCjpUQ6QrAjo35MnwYDzHqnge" width=350 height=350>
            <br>"90.65% in 7 PCs"</li>
        </ul>
      </th>
      <td><img src="https://drive.google.com/uc?export=view&id=14OwThgsf2GXnrl3-9TXEMvF3PDxmRsHE" width=350 height=350></td>
      <td><img src="https://drive.google.com/uc?export=view&id=1d35zKWiT7bkWbNNAon_JdSjKyVgcHHzi" width=350 height=350></td>
    </tr>
    <tr>
      <th style="text-align:center;">Image Analysis Solutions</th>
      <th style="text-align:center;"></th>
      <td>
        <ul>
          <li style="text-align:left;">Check whether the crop size is too large; if so, decrease it and re-extract your data.</li>
          <li style="text-align:left;">Try incrementally adjusting the spatial and temporal filtering kernel sizes in the PCA step. Increasing temporal smoothing generally increases the explained variance; however, overfiltering will hinder ARHMM reliability.</li>
        </ul>
      </td>
      <td>
          <ul>
              <li style="text-align:left;">Ensure that an appropriate amount of spatial and temporal filtering is applied.</li>
              <li style="text-align:left;">If there are missing frames, apply an appropriate amount of temporal filtering and make sure an appropriate number of PCs is being reconstructed (i.e., <code>recon_pcs</code> is set correctly).</li>
          </ul>
    </td>
    </tr>
    <tr>
      <th style="text-align:center;">General Solutions</th>
      <th style="text-align:center;"></th>
      <td style="text-align:center;">
          <ul>
          <li style="text-align:left;">Try setting <code>missing_data=True</code>. Using iterative PCA to reconstruct the PCs can help increase the explained variance ratio.</li>
          <li style="text-align:left;">Increase the size of your dataset. A dataset that is too small can also contribute to overprocessed PCs.</li>
        </ul>
      </td>
      <td style="text-align:center;">Acquire and extract more data, then retrain the PCA.</td>
    </tr>
  </tbody>
</table>
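
The scree-plot criterion in the table above (reaching over 90% explained variance with fewer than 15 PCs) can also be checked numerically. The helper below is a minimal NumPy sketch, not part of MoSeq: `n_pcs_for_variance` is a hypothetical name, and it assumes you already have the per-component explained-variance ratios as an array (this notebook does not show where they are stored).

```python
import numpy as np

def n_pcs_for_variance(explained_variance_ratio, target=0.90):
    """Return the smallest number of leading PCs whose cumulative explained
    variance reaches `target`, or None if all components together fall short.
    (Hypothetical helper for illustration; not a MoSeq function.)"""
    cumulative = np.cumsum(explained_variance_ratio)
    if cumulative[-1] < target:
        return None
    # searchsorted returns the first index where cumulative >= target
    return int(np.searchsorted(cumulative, target)) + 1

# Example with made-up ratios: 3 PCs are needed to reach 85%
ratios = np.array([0.6, 0.2, 0.1, 0.05, 0.05])
print(n_pcs_for_variance(ratios, target=0.85))  # prints 3
```

Following the rule of thumb above, a result of `None`, or of 15 or more at `target=0.90`, points to the bad-scree-plot pathology.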

## Computing Principal Component Scores

Apply your trained PCA model to your extracted data to compute PC scores. PC scores are obtained by projecting each frame of the input videos onto the principal components, i.e., multiplying the flattened (typically mean-subtracted) frames by the transpose of the principal component matrix. Each score indicates how strongly the corresponding principal component is expressed in that frame, and the sequence of scores over time forms the animal's pose trajectory.

These scores are the data the ARHMM trains on: they summarize each frame with a small number of largely uncorrelated pose dimensions, which makes modeling pose dynamics tractable.
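
As a concrete illustration of this projection, here is a minimal NumPy sketch. It assumes the components are stored one per row and that the mean frame is subtracted before projecting, as in standard PCA; `pc_scores` is a hypothetical helper, and `apply_pca_command` in the next cell performs this step (plus any filtering and missing-data handling) for you.

```python
import numpy as np

def pc_scores(frames, components, mean):
    """Project depth frames onto principal components (hypothetical helper).

    frames:     (n_frames, height, width) array of depth frames
    components: (n_components, height * width) array, one PC per row
    mean:       (height * width,) mean frame subtracted before projecting
    Returns an (n_frames, n_components) array of PC scores.
    """
    flat = frames.reshape(len(frames), -1) - mean   # flatten and center each frame
    return flat @ components.T                      # multiply by the transposed components
```

Each row of the result describes one frame's pose; keeping only the leading columns (for example, the `npcs = 10` used in the training cell) gives the low-dimensional trajectory that is modeled.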

In [None]:
from os.path import join
from moseq2_pca.gui import apply_pca_command
from moseq2_app.gui.progress import update_progress, restore_progress_vars

progress_paths = restore_progress_vars(progress_filepath)

scores_filename = 'pca_scores' # name of the scores file to compute and save

scores_file = join(progress_paths['pca_dirname'], scores_filename+'.h5') # path to input PC scores file to model
progress_paths = update_progress(progress_filepath, 'scores_path', scores_file)

apply_pca_command(progress_paths, scores_filename)

Once complete, a `pca_scores.h5` file will be saved in your PCA directory (example below).
```
.
├── _pca/
├   ├── pca.h5
├   ├── pca.yaml
├   ├── pca_scores.h5  ** # scores file
├   ├── pca_components.png
├   └── pca_scree.png
├── aggregate_results/
├── config.yaml
├── moseq2-index.yaml
├── session_1/
└── session_2/

```

## (Optional) Computing Model-Free Syllable Changepoints

This optional step helps estimate model-free syllable durations, which are rough approximations of how long individual behavioral syllables last. The changepoint distribution also offers a check on how well the PCA fits your data.

A good changepoint plot shows a smooth histogram of changepoint durations, with most of its mass at short durations and a tail toward longer ones, and a CPE curve that closely fits the histogram. A distribution with this shape indicates that the principal components accurately represent your extracted data. If that is not the case, consult the [pathologies below](#Possible-Model-Free-Changepoints-Pathologies).

The x-axis shows the duration between consecutive changepoints, i.e., how long the animal holds a pose before transitioning to another. The y-axis shows the probability of each duration. Ideally, the mode of the distribution (the peak of the curve) should be roughly 0.3 seconds (300 ms).
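
To build intuition for these durations, and for the `threshold` and `dims` (number of random projections) settings in the changepoint cell, here is a simplified sketch of model-free changepoint detection: project each frame onto random directions, measure how much the projections change from one frame to the next, and mark local peaks above a threshold as changepoints. This is an illustrative assumption, not moseq2-pca's implementation (which may smooth, normalize, and threshold differently); `changepoint_durations`, its parameters, and the 30 fps default are hypothetical.

```python
import numpy as np

def changepoint_durations(frames, n_proj=300, threshold=0.5, fps=30.0, seed=0):
    """Durations (s) between detected changepoints (simplified, hypothetical sketch)."""
    rng = np.random.default_rng(seed)
    x = frames.reshape(len(frames), -1).astype(float)             # (T, pixels)
    proj = x @ rng.standard_normal((x.shape[1], n_proj))           # random projections, (T, n_proj)
    proj = (proj - proj.mean(axis=0)) / (proj.std(axis=0) + 1e-8)  # z-score each projection
    change = np.square(np.diff(proj, axis=0)).sum(axis=1)          # frame-to-frame change, (T - 1,)
    if change.max() == 0:
        return np.array([])                                        # static input: no changepoints
    change /= change.max()                                         # largest change scaled to 1
    mid = change[1:-1]
    # changepoint = local maximum of the change signal that exceeds the threshold
    is_peak = (mid > change[:-2]) & (mid >= change[2:]) & (mid > threshold)
    peaks = np.flatnonzero(is_peak) + 1
    return np.diff(peaks) / fps
```

For example, a synthetic recording that alternates between two poses every 9 frames yields durations of 9 / 30 = 0.3 s, matching the target mode described above.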

__Note: the parameters below have been preconfigured to best process C57 mouse data, and have not been tested for other species. Configure them at your own risk.__

In [None]:
import ruamel.yaml as yaml
from moseq2_pca.gui import compute_changepoints_command
from moseq2_app.gui.progress import update_progress, restore_progress_vars

progress_paths = restore_progress_vars(progress_filepath)

with open(progress_paths['config_file'], 'r') as f:
    config_data = yaml.safe_load(f)

changepoints_filename = 'changepoints' # name of the changepoints images to generate

# Changepoint computation parameters you may want to configure
config_data['threshold'] = 0.5 # Peak threshold to use for changepoints
config_data['dims'] = 300 # Number of random projections to compare the computed principal components with

with open(progress_paths['config_file'], 'w') as f:
    yaml.safe_dump(config_data, f)

progress_paths = update_progress(progress_filepath, 'changepoints_path', changepoints_filename)
compute_changepoints_command(progress_paths['train_data_dir'], progress_paths, changepoints_filename)

The changepoints plot will be generated and saved in the pca directory (example below).

```
.
├── _pca/ 
├   ├── pca.h5
├   ├── pca_scores.h5
├   ...
├   └── changepoints_dist.png **
├── aggregate_results/ 
├── config.yaml
├── moseq2-index.yaml
├── session_1/
└── session_2/
```

View your changepoint distribution plot (if the text is too small, check the PDF file):

In [None]:
from os.path import join
from IPython.display import display, Image

display(Image(join(progress_paths['pca_dirname'], progress_paths['changepoints_path']+'_dist.png')))

### Possible Model-Free Changepoints Pathologies

<table style="width: 100%;">
  <tbody>
    <tr>
      <th></th>
      <th style="text-align:center;">Good Changepoint Analysis Example</th>
      <th style="text-align:center;">Poor Changepoint Analysis Example</th>
    </tr>  
    <tr>
      <th style="text-align:center;">Pathology Description</th>
      <td style="text-align:center;"></td>
      <td style="text-align:center;">The distribution of model-free changepoint durations is incorrectly skewed or too sparse, and/or the mode changepoint duration is less than 0.2 s.</td>
    </tr>
    <tr>
      <th style="text-align:center;">Reference Example</th>
      <td><img src="https://drive.google.com/uc?export=view&id=1sMkSB34bGbOimumN6Gg1-zV2Hk98v2Zy" width=350 height=350></td>
      <td><img src="https://drive.google.com/uc?export=view&id=1S-ALkPmb8sBZGkKmJ7Q3-RdxAbfS0PWV" width=350 height=350></td>
    </tr>
    <tr>
      <th style="text-align:center;">General Solutions</th>
      <td style="text-align:center;"></td>
      <td>
          <ul>
              <li style="text-align:left;">Try retraining the PCA with adjusted spatial and temporal filtering kernel sizes.</li>
              <li style="text-align:left;">Ensure your extracted data is correct, with minimal flips. If the extracted mouse is too noisy, the PC trajectories cannot be accurately applied to the data.</li>
              <li style="text-align:left;">Get more data and try again.</li>
          </ul>
      </td>
    </tr>
  </tbody>
</table>

***
<center><h1>ARHMM Modeling</h1></center>

***

To train your ARHMM (autoregressive hidden Markov model), you will use your computed PC scores as input data and specify whether you are modeling a single experimental group for observational research or multiple groups (e.g., control vs. experimental) for comparative analysis.

<img src="https://drive.google.com/uc?export=view&id=15uoeWDTn8fbcau2jErJEVFUt2iJyadmo">
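
For intuition about what the ARHMM assumes, here is a toy generative sketch: a hidden syllable label follows a Markov chain, and within each syllable the PC scores evolve according to that syllable's own linear autoregressive dynamics plus Gaussian noise (the `robust` option in the training cell swaps in a heavier-tailed t-distribution). It uses a single lag for brevity and is not moseq2-model's code; the model's actual lag order, priors, and inference are not shown, and `sample_arhmm` and its arguments are hypothetical.

```python
import numpy as np

def sample_arhmm(trans, ar_mats, biases, noise_std, n_steps, seed=0):
    """Sample syllable labels and pose trajectories from a toy first-order AR-HMM.

    trans:   (K, K) syllable transition matrix; each row sums to 1
    ar_mats: (K, D, D) per-syllable autoregressive matrices
    biases:  (K, D) per-syllable offsets
    """
    rng = np.random.default_rng(seed)
    n_states, dim = biases.shape
    labels = np.empty(n_steps, dtype=int)
    poses = np.empty((n_steps, dim))
    labels[0] = rng.integers(n_states)
    poses[0] = rng.standard_normal(dim)
    for t in range(1, n_steps):
        k = rng.choice(n_states, p=trans[labels[t - 1]])  # next syllable label
        labels[t] = k
        # pose_t = A_k @ pose_{t-1} + b_k + Gaussian noise
        poses[t] = ar_mats[k] @ poses[t - 1] + biases[k] + noise_std * rng.standard_normal(dim)
    return labels, poses
```

Fitting the ARHMM runs this logic in reverse: given the observed PC scores, it infers the syllable labels and each syllable's dynamics.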

### Restore Progress Variables

In [None]:
from os.path import join
from moseq2_app.gui.progress import restore_progress_vars

base_dir = './' # User-defined path to your project directory; replace './' with its absolute path

progress_filepath = join(base_dir, 'progress.yaml')

progress_paths = restore_progress_vars(progress_filepath)

## Train ARHMM
__Note: when resuming from a model checkpoint, ensure that all input parameters (especially the selected groups) are identical to those used to create the checkpoint. Otherwise, the model will not train.__

In [None]:
from os.path import join
import ruamel.yaml as yaml
from moseq2_model.gui import learn_model_command
from moseq2_app.gui.progress import update_progress, restore_progress_vars

progress_paths = restore_progress_vars(progress_filepath)

modeling_session_path = 'saline-amphetamine/'
model_name = 'model.p'

session_path = join(base_dir, modeling_session_path)
model_path = join(session_path, model_name) # path to save trained model

select_groups = False # select specific groups to model; if False, will model all data as is in moseq2-index.yaml

# model saving frequency (in iterations); will create a checkpoints/ directory containing checkpointed models
checkpoint_freq = -1
use_checkpoint = False # resume training from latest saved checkpoint

# Advanced modeling parameters
hold_out = False # boolean to hold out data subset during the training process
nfolds = 2 # (if hold_out==True): number of folds to hold out during training; 1 fold per session

npcs = 10  # number of PCs being used
max_states = 100 # maximum number of states the ARHMM can use

# use robust-ARHMM with t-distribution -> yields fewer states/syllables if True,
# used to constrict accepted behavioral variability
robust = True 

# separate group transition graphs; set to True if ngroups > 1
separate_trans = True 

num_iter = 100 # number of iterations to train model

# syllable length probability distribution prior; (None, int or 'scan'); if None, kappa=nframes
kappa = None 

# if kappa == 'scan', optionally set bounds to scan kappa values between, in either a linear or log-scale.
scan_scale = 'log' # or linear
min_kappa = None
max_kappa = None

# total number of models to spool
n_models = 5

# Select platform to run models on
cluster_type = 'local' # currently supported cluster_types = 'local' or 'slurm'
run_cmd = False # if True, runs the commands via os.system(...)

progress_paths = update_progress(progress_filepath, 'model_path', model_path)
progress_paths = update_progress(progress_filepath, 'model_session_path', session_path)

learn_model_command(progress_paths, hold_out=hold_out, nfolds=nfolds, num_iter=num_iter, max_states=max_states,
                    npcs=npcs, kappa=kappa, separate_trans=separate_trans, robust=robust,
                    checkpoint_freq=checkpoint_freq, use_checkpoint=use_checkpoint, select_groups=select_groups,
                    cluster_type=cluster_type, min_kappa=min_kappa, scan_scale=scan_scale,
                    max_kappa=max_kappa, n_models=n_models, run_cmd=run_cmd, output_dir=modeling_session_path)

Once training is complete, your model will be saved in your modeling session directory inside the base directory (example below).
```
.
├── _pca/ 
├── aggregate_results/ 
├── config.yaml
├── modeling_session/ ***
├   └── model.p ***
├── moseq2-index.yaml
├── session_1/
└── session_2/
```
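
The `kappa` setting in the training cell controls how strongly the model prefers to remain in the same syllable, and therefore how long syllables tend to last. As a rough guide, assuming the standard sticky HDP-HMM prior, a state's prior mean self-transition probability is $p = (\alpha\beta_j + \kappa)/(\alpha + \kappa)$, and treating dwell times as geometric gives an expected duration of $1/(1 - p) = (\alpha + \kappa)/(\alpha(1 - \beta_j))$ frames. The helper below is a hypothetical plug-in estimate of that prior expectation, not a moseq2-model function; fitted durations also depend on the data and on hyperparameters not shown in this notebook.

```python
def expected_syllable_duration(kappa, alpha, beta_j, fps=30.0):
    """Plug-in prior estimate of mean syllable duration in seconds (hypothetical helper).

    Assumes a sticky HDP-HMM prior with prior mean self-transition probability
    p = (alpha * beta_j + kappa) / (alpha + kappa) and geometric dwell times,
    so the expected duration is 1 / (1 - p) = (alpha + kappa) / (alpha * (1 - beta_j)) frames.
    Requires alpha > 0 and 0 <= beta_j < 1.
    """
    frames = (alpha + kappa) / (alpha * (1.0 - beta_j))
    return frames / fps

# Larger kappa -> longer expected syllables (illustrative values only)
for kappa in (10, 100, 1000):
    print(kappa, round(expected_syllable_duration(kappa, alpha=5.0, beta_j=0.01), 3))
```

Because this estimate grows roughly linearly with kappa, scanning kappa and then comparing model durations against the model-free changepoints is one natural way to choose it.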


### Restore Notebook Variables

In [None]:
from os.path import join
from moseq2_app.gui.progress import restore_progress_vars

base_dir = './' # User-defined path to your project directory; replace './' with its absolute path

progress_filepath = join(base_dir, 'progress.yaml')

progress_paths = restore_progress_vars(progress_filepath)

## Get Best Model Fit

Use this feature to check whether the trained model's median syllable durations match the model-free changepoint durations computed from the principal components.

This feature can also return the best model from the set of models found in `progress_paths['model_session_path']`.

Below are examples of some comparative distributions that you can expect when using this tool:

<table>
    <tr>
        <td>
            <img height=400 width=400 src="https://drive.google.com/uc?export=view&id=1ENQVOFcM7moN_k6G_hVysIAaH-smRnEd">
        </td>
        <td>
            <img height=400 width=400 src="https://drive.google.com/uc?export=view&id=1rtfzkBGISuu8fpGNLNOTt9881Hgg_rXC">
        </td>
    </tr>
</table>

In [None]:
from os.path import join
from IPython.display import display, Image
from moseq2_viz.gui import get_best_fit_model
from moseq2_app.gui.progress import update_progress, restore_progress_vars

progress_paths = restore_progress_vars(progress_filepath)

output_file = join(progress_paths['plot_path'], 'model_vs_pc_changepoints')

best_model_fit = get_best_fit_model(progress_paths, plot_all=True)
progress_paths = update_progress(progress_filepath, 'model_path', best_model_fit)
display(Image(output_file+'.png'))
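
The comparison behind this step can be sketched as follows. This is a hypothetical illustration, not `get_best_fit_model`'s implementation: it assumes each model's output is a per-frame sequence of syllable labels, that the model-free changepoint durations are available in seconds, and that the best model is the one whose median syllable duration is closest to the median changepoint duration. `syllable_durations` and `closest_model` are illustrative names.

```python
import numpy as np

def syllable_durations(labels, fps=30.0):
    """Durations (s) of consecutive runs of the same syllable label."""
    labels = np.asarray(labels)
    boundaries = np.flatnonzero(np.diff(labels) != 0) + 1      # indices where the label changes
    run_starts = np.concatenate(([0], boundaries, [len(labels)]))
    return np.diff(run_starts) / fps

def closest_model(model_labels, changepoint_durations, fps=30.0):
    """Name of the model whose median syllable duration best matches the changepoints."""
    target = np.median(changepoint_durations)
    gaps = {name: abs(np.median(syllable_durations(lab, fps)) - target)
            for name, lab in model_labels.items()}
    return min(gaps, key=gaps.get)
```

Here `model_labels` maps a model name to its label sequence; with real data you would substitute each trained model's syllable labels and the durations behind your changepoint plot.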

# End and User Survey

Please take some time to tell us your thoughts about this notebook:
**[user feedback survey](https://forms.gle/FbtEN8E382y8jF3p6)**