# Moseq2 App

<img src="https://drive.google.com/uc?export=view&id=1PxTnCMsrk3hRHPnEjqGDzq1oPkTYfzj0">

MoSeq2 is a software toolkit for unsupervised modeling and characterization of animal behavior. MoSeq takes depth recordings of animals as input and outputs a rich description of behavior as a series of reused and stereotyped motifs called 'syllables'.

This notebook begins with depth recordings (see the [data acquisition overview](#Data-Acquisition-Overview) below) and transforms this data through the steps of: 

- **[Extraction](#Raw-Data-Extraction)**: The animal is segmented from the background and its position and heading direction are aligned across frames.
- **[Dimensionality reduction](#Principal-Component-Analysis-(PCA))**: Raw video is de-noised and transformed to low-dimensional pose trajectories using principal component analysis (PCA).
- **[Model training](#ARHMM-Modeling)**: Pose trajectories are modeled using an autoregressive hidden Markov model (AR-HMM), producing a sequence of syllable labels.

__The model output can be analyzed using the [Interactive Results Exploration](./Interactive-Model-Results-Exploration.ipynb) notebook.__

### Resources

- [Wiki](https://github.com/dattalab/moseq2-app/wiki) with instructions on data acquisition and command line options
- PDF documentation of all MoSeq functions: [Extract](https://github.com/dattalab/moseq2-extract/blob/release/Documentation.pdf), [PCA](https://github.com/dattalab/moseq2-pca/blob/release/Documentation.pdf), [Model](https://github.com/dattalab/moseq2-model/blob/release/Documentation.pdf)
- Publications:
    - [Mapping Sub-Second Structure in Mouse Behavior](http://datta.hms.harvard.edu/wp-content/uploads/2018/01/pub_23.pdf)
    - [The Striatum Organizes 3D Behavior via Moment-to-Moment Action Selection](http://datta.hms.harvard.edu/wp-content/uploads/2019/06/Markowitz.final_.pdf)
    - [Revealing the structure of pharmacobehavioral space through motion sequencing](https://www.nature.com/articles/s41593-020-00706-3)
    - [Q&A: Understanding the composition of behavior](http://datta.hms.harvard.edu/wp-content/uploads/2019/06/Datta-QA.pdf)
    
### Feedback

For general feedback and feature requests, please fill out [this survey](https://forms.gle/FbtEN8E382y8jF3p6).
    
### Data Acquisition Overview
MoSeq2 takes animal depth recordings as input. We have developed a [data acquisition pipeline](https://github.com/dattalab/moseq2-app/wiki/Setup:-acquisition-software) for the second-generation `Xbox Kinect` depth camera, and we suggest following our [data acquisition tutorial](https://github.com/dattalab/moseq2-app/wiki/Acquisition) when recording. MoSeq2 also accepts depth recordings from the `Azure Kinect` and `Intel RealSense` cameras, acquired with their standard data acquisition pipelines.

**We recommend recording more than 10 hours of depth video (~1 million frames at 30 frames per second) to ensure high-quality MoSeq models.**
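The recommendation above is simple arithmetic: 10 hours at 30 fps is just over a million frames. If you record at a different frame rate, a small helper like the one below can help plan recording time. These functions are hypothetical and not part of MoSeq2.

```python
def frames_for_duration(hours: float, fps: float = 30.0) -> int:
    """Number of depth frames recorded over `hours` at `fps` frames per second."""
    return round(hours * 3600 * fps)


def hours_for_frames(n_frames: int, fps: float = 30.0) -> float:
    """Hours of recording needed to collect `n_frames` frames at `fps`."""
    return n_frames / fps / 3600


print(frames_for_duration(10))                # 10 h at 30 fps -> 1080000 frames
print(round(hours_for_frames(1_000_000), 2))  # 1 million frames at 30 fps -> 9.26 h
```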

## Notebook Setup

<img src="https://drive.google.com/uc?export=view&id=1h2GYECyEuTMlM7Rx3Q3lMVBdWEm1F0S5">

### Check that you're running Python from the correct conda environment

If you performed the recommended installation, you should see the `sys.executable` path point to the python path within the `moseq2-app` environment, i.e.,
```python
import sys
print(sys.executable)
# /Users/username/miniconda3/envs/moseq2-app/bin/python
```
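If you want the notebook to flag a wrong kernel automatically, the check can be scripted. The helper below is a hypothetical sketch, not part of `moseq2-app`. It assumes the standard conda layout, where named environments live under `envs/<name>/`, and that your environment is called `moseq2-app`.

```python
import sys


def in_conda_env(executable: str, env_name: str = "moseq2-app") -> bool:
    """Return True if `executable` sits inside a conda `envs/<env_name>/` folder."""
    # Normalize Windows backslashes so POSIX and Windows paths are handled alike
    parts = executable.replace("\\", "/").split("/")
    return any(a == "envs" and b == env_name for a, b in zip(parts, parts[1:]))


if not in_conda_env(sys.executable):
    print(f"Warning: {sys.executable} does not look like the moseq2-app environment")
```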

### Check if the dependencies are found

Run the following cell to check if `moseq2-app` is installed in your current conda kernel. The latest working version number is `0.2.1`.

In [None]:
import sys
import moseq2_app

print('Python path:', sys.executable)
print('MoSeq2 app version:', moseq2_app.__version__)

## Data file organization

The currently accepted depth data extensions are:
- `.dat` (raw depth files from our kinect2 data acquisition software)
- `.tar.gz` (compressed depth files from our kinect2 data acquisition software)
- `.avi` (compressed depth files from the `moseq2-extract` command line interface)
- `.mkv` (generated from Microsoft's recording software for the Azure Kinect)

To run this notebook, create a master folder containing a copy of this notebook and a separate subfolder for each recording session (see the example directory structure below).

```
.
└── Data_Directory/
    ├── Main-MoSeq2-Notebook.ipynb (running)
    ├── session_1/ ** - the folder containing all of a single session's data
    ├   ├── depth.dat        # depth data - the recording itself
    ├   ├── depth_ts.txt     # timestamps - csv/txt file of the frame timestamps
    ├   └── metadata.json    # metadata - json file that contains the rodent's info (group, subjectName, etc.)
    ...
    ├── session_2/ **
    ├   ├── depth.dat
    ├   ├── depth_ts.txt
    └── └── metadata.json

```
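To check that your folders follow this layout before extracting, you can list which session subfolders contain a depth file with one of the accepted extensions. This is a hypothetical standard-library sketch. It is not how `moseq2-extract` searches for sessions internally. It uses `str.endswith` instead of `Path.suffix` because `.tar.gz` is a two-part extension.

```python
from pathlib import Path

# Depth file extensions listed above
ACCEPTED_EXTENSIONS = (".dat", ".tar.gz", ".avi", ".mkv")


def find_session_files(base_dir, extensions=ACCEPTED_EXTENSIONS):
    """Map each immediate session subfolder of `base_dir` to the depth files it contains."""
    sessions = {}
    for session_dir in sorted(p for p in Path(base_dir).iterdir() if p.is_dir()):
        depth_files = sorted(
            f.name
            for f in session_dir.iterdir()
            if f.is_file() and f.name.endswith(tuple(extensions))
        )
        if depth_files:  # skip folders without any depth recording
            sessions[session_dir.name] = depth_files
    return sessions


# Example: print(find_session_files("./"))
```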

__Note: if your data was acquired using an Azure Kinect or Intel RealSense depth camera, you will not have `depth_ts.txt` or `metadata.json` in your session directories. If you wish to identify sessions by session name or mouse ID, you need to manually create a `metadata.json` file before extraction.__ The `metadata.json` file should minimally contain the following:

```json
{"SessionName": "example session", "SubjectName": "example subject", "StartTime": "optional"}
```
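For Azure Kinect or Intel RealSense sessions, you can write this file from Python instead of by hand. `write_minimal_metadata` is a hypothetical helper, not a MoSeq2 function. It only writes the keys shown above and leaves existing metadata files untouched unless `overwrite=True`.

```python
import json
from pathlib import Path


def write_minimal_metadata(session_dir, session_name, subject_name,
                           start_time=None, overwrite=False):
    """Create `metadata.json` with the minimal keys shown above; return its path."""
    path = Path(session_dir) / "metadata.json"
    if path.exists() and not overwrite:
        return path  # keep any metadata that is already there
    metadata = {"SessionName": session_name, "SubjectName": subject_name}
    if start_time is not None:
        metadata["StartTime"] = start_time  # optional field
    path.write_text(json.dumps(metadata, indent=2))
    return path


# Example: write_minimal_metadata("./session_1", "example session", "example subject")
```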

### Notebook Progress File

This notebook generates a `progress.yaml` file that stores the filepaths to data generated from this notebook, including:
- extraction data files
- PC scores of the extractions
- model results

If your notebook kernel shuts down, you can load the progress file to restore your progress. The progress file does **not** track MoSeq pipeline operations that were executed outside of this notebook (for example, if you were to run PCA using the command line interface). If necessary, you can manually modify the paths in the progress file or the corresponding `progress_paths` dictionary to access the output of these external operations.

__To restore previously computed variables, look for the cells following the `Restore Progress Variables` label.__

### Restore Progress Variables

- Use this cell to load your notebook analysis progress. We recommend running this notebook from the folder where your data is located so the generated media will display properly. In that case, you can specify the `base_dir` as `./` (or the current folder).
- The `base_dir` will be stored in the progress dict and will be reused throughout the notebooks.

The `check_progress` function will print progress bars for each pipeline step in the notebook. 
- The extraction progress bar indicates the total number of extracted sessions detected in the provided `base_dir` path.
- It prints the session names that haven't been extracted. __Note: the progress does not reflect the contents of the aggregate_results/ folder.__
- The remainder of the progress bars are derived from reading the paths in the `progress_paths` dictionary, filling up the bar if the included paths are found.
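The non-extraction progress bars boil down to "how many of the recorded paths exist on disk". The sketch below is a simplified, hypothetical re-implementation of that idea, not the code of `check_progress`. The keys it checks are ones this notebook records with `update_progress`.

```python
import os


def progress_fraction(progress_paths, keys):
    """Fraction of the given progress keys whose stored path exists on disk."""
    if not keys:
        return 0.0
    found = [k for k in keys
             if progress_paths.get(k) and os.path.exists(progress_paths[k])]
    return len(found) / len(keys)


# Keys this notebook records with update_progress() for the PCA stage
pca_keys = ["pca_dirname", "scores_path"]
# Example: print(progress_fraction(progress_paths, pca_keys))
```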

In [None]:
from os.path import join
from moseq2_app.gui.progress import check_progress, restore_progress_vars

# Add the path to your data folder here.
# We recommend that you run this notebook in the same folder as your data. In that case, you don't have to change base_dir
base_dir = './'
progress_filepath = join(base_dir, 'progress.yaml')

progress_paths = restore_progress_vars(progress_filepath, init=True, overwrite=False)
check_progress(progress_filepath)

### Generate Configuration Files

The `config.yaml` will be used to hold all configurable parameters for all steps in the MoSeq pipeline. Parameters will be added to this file as you progress through the notebook. The config file can be used to run an identical pipeline in future analyses. 

In [None]:
from os.path import join
from moseq2_app.gui.progress import update_progress
from moseq2_extract.gui import generate_config_command

config_filepath = join(progress_paths['base_dir'], 'config.yaml')

print(f'generating file in path: {config_filepath}')
generate_config_command(config_filepath)
progress_paths = update_progress(progress_filepath, 'config_file', config_filepath)

A configuration file has been created in the base directory (depicted below).

```
.
└── Data_Directory/
    ├── config.yaml **
    ├── session_1/ 
    ├   ├── depth.dat        
    ├   ├── depth_ts.txt     
    ├   └── metadata.json    
    ...
    ├── session_2/ 
    ├   ├── depth.dat
    ├   ├── depth_ts.txt
    └── └── metadata.json
```

### Download a Flip File

MoSeq2 uses a random forest flip classifier to ensure that the mouse always faces right after the depth videos are cropped and rotationally aligned. The flip classifiers we provide __are trained for experiments run with C57BL/6 mice using Kinect v2 depth cameras__.

If your dataset does not work with our pre-trained flip classifiers, we provide a [flip-classifier training notebook](https://github.com/dattalab/moseq2-app/tree/jupyter/). After using this notebook, add the path of your custom classifier to the `config.yaml` file.

In [None]:
from moseq2_extract.gui import download_flip_command
# selection=0 - large mice with fibers (default)
# selection=1 - adult male C57s
# selection=2 - mice with Inscopix cables
download_flip_command(progress_paths['base_dir'], config_filepath, selection=1)

# Raw Data Extraction

<img src="https://drive.google.com/uc?export=view&id=1XtDo6sVtvG0Grp5pDgLbFcli2_hRcTZK">

## Interactive ROI Detection Tool (optional)

Use this interactive tool to optimize the extraction parameters before extracting all of your data. Most of the parameters relate to detecting the region the mouse occupies (the ROI). This tool can also be used to catch possibly corrupted or inconsistent sessions and to diagnose ROI detection and extraction errors.

<table>
    <tr>
        <td style="width: 45%;">
            <ol>
                <li style="text-align:left; font-size:14px">
                    Execute the code below to launch the ROI Detection Tool.
                </li>
                <li style="text-align:left; font-size:14px">
                    Click on any row in the session selector to load that session's data to the view.
                </li>
                <li style="text-align:left; font-size:14px">
                    If the indicator next to the session's name is green, then the session is considered ready for extraction. A red indicator can either mean the session has not been checked yet, or its extraction parameter set is incorrect.
                </li>
                <li style="text-align:left; font-size:14px">
                    Adjust the Depth Range Selector to include the depth range of the detected bucket floor distance (which can be found by hovering over the Background image with your mouse). You can also manually enter slider values by clicking on the numbers.
                </li>
                <li style="text-align:left; font-size:14px">
                    If the mouse seems to be cropped when at the bucket edge, increase the "dilate iterations" setting to enlarge the included floor area.
                </li>
                <li style="text-align:left; font-size:14px">
                    Use the Rodent Height Threshold Slider to remove any noise/speckle from the bucket floor or walls. 
                    <ul>
                        <li style="text-align:left; font-size:14px">
                            Ensure the min height parameter is small enough to only filter out floor reflections.
                        </li>
                        <li style="text-align:left; font-size:14px">
                            Ensure the max height parameter is large enough to include the largest possible mouse height (i.e., when the mouse is rearing). A reasonable value is around 100 mm for Kinect v2 recordings.
                        </li>
                        <li style="text-align:left; font-size:14px">
                            Hover over the mouse in either of the bottom two plots to explore its height.
                        </li>
                    </ul>
                </li>
                <li style="text-align:left; font-size:14px">
                    Use the "current frame" slider to change the displayed session frame in the bottom 2 plots.
                </li>
                <li style="text-align:left; font-size:14px">
                    Change the frame range slider values to adjust the segments of the video to extract, then click the "Extract Sample" button to trigger an extraction and view the results.
                </li>
                <li style="text-align:left; font-size:14px">
                    Once you have found a satisfactory set of parameters, click "Check All Sessions" to test the parameters on all sessions. A set of filters described <a href="https://github.com/dattalab/moseq2-app/wiki/Analysis:-extraction#check-all-session-protocol">here</a> will be applied to detect possible poor extractions.
                </li>
                <li style="text-align:left; font-size:14px">
                    If no sessions are flagged, click "Save Parameters". The parameters for each session will be written in the overall project config file and in each session-specific config file. 
                </li>
                <li style="text-align:left; font-size:14px">
                    If a session is flagged, click on it in the Session Selector and a text indicator will appear with error details. Readjust the parameters until the session passes and then save the parameters. The "Mark Passing" button can be used to manually accept a session's parameter set.
                </li>
            </ol>
            <p style="text-align:left; font-size:14px">
        </td>
        <td>
            <img src="https://drive.google.com/uc?export=view&id=1iIj92Wl0Uezn_ehjvGnwV2YzGp8Pir6f">
        </td>
    </tr>
</table>

__Note: if the cell seems to run out of memory after first use, set `compute_all_bgs=False` to reduce memory usage.__
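The Depth Range and Rodent Height Threshold settings make more sense with a simplified model of extraction. First, estimate the empty-arena background. Next, convert each frame to height above the floor. Finally, keep only the pixels inside the height window. This numpy sketch is hypothetical and is not `moseq2-extract`'s implementation, which also detects the ROI, filters, and crops and rotates the mouse. The median background and the default thresholds (in mm) are illustrative assumptions.

```python
import numpy as np


def background(frames):
    """Estimate the static arena depth as the per-pixel median over frames (T x H x W)."""
    return np.median(frames, axis=0)


def mouse_mask(frame, bg, min_height=10, max_height=100):
    """Boolean mask of pixels whose height above the floor is in (min_height, max_height) mm."""
    # The camera looks down, so closer objects have smaller depth values;
    # height above the floor is therefore background depth minus current depth.
    height = bg - frame
    return (height > min_height) & (height < max_height)
```

Raising `min_height` removes floor reflections and speckle. Lowering `max_height` too far would cut off a rearing mouse.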

In [None]:
from os.path import join
import ruamel.yaml as yaml
from moseq2_app.gui.progress import update_progress
from moseq2_app.main import interactive_roi_detector

session_config_path = join(progress_paths['base_dir'], 'session_config.yaml')
progress_paths = update_progress(progress_filepath, 'session_config', session_config_path)

with open(progress_paths['config_file'], 'r') as f:
    config_data = yaml.safe_load(f)

config_data['camera_type'] = 'auto' # 'kinect', 'azure' or 'realsense'
config_data['crop_size'] = (80, 80)
config_data['output_dir'] = 'proc' # the subfolder extracted data is saved to

# if using azure or realsense, increase the noise_tolerance for ROI detection
config_data['noise_tolerance'] = 30

# OPTIONAL additional parameters
# config_data['flip_classifier'] = './alternative-flip-classifier.pkl' # updated flip classifier path
# config_data['gaussfilter_space'] = [2.5, 2] # spatial filtering kernel size
# config_data['medfilter_time'] = (3,) # temporal filtering kernel size

# Filtering out head-fixed cables?
# config_data['cable_filter_iters'] = 3 # number of cable filtering iterations
# config_data['cable_filter_size'] = (7, 7) # cable spatial filter kernel size

with open(progress_paths['config_file'], 'w') as f:
    yaml.safe_dump(config_data, f)

compute_all_bgs = True # If False, only computes the first background on launch
                
autodetect_depths = False # If True, will readjust the bg_depth_range for each session 

interactive_roi_detector(progress_paths, compute_all_bgs=compute_all_bgs, autodetect_depths=autodetect_depths)

### Restore Progress Variables

In [None]:
from moseq2_app.gui.progress import restore_progress_vars

progress_filepath = './progress.yaml'

progress_paths = restore_progress_vars(progress_filepath)

## Extract Session(s)

__Note: If sessions are not listed when running the cell, ensure your selected extension matches that of your depth files.__

In [None]:
from moseq2_extract.gui import extract_found_sessions

# include the file extensions for the depth files you would like to search for and extract.
extensions = ['.avi', '.dat'] # .avi, .dat, and/or .mkv

# for the option to extract individual sessions, set extract_all=False
# to overwrite previously extracted sessions, set skip_extracted=False
extract_found_sessions(progress_paths['base_dir'], progress_paths['config_file'], extensions, extract_all=True, skip_extracted=True)

This is what your directory structure should look like once the process is complete:

```
.
├── config.yaml
├── session_1/
├   ...
├   └── proc/ **
├   ├   ├── roi.tiff
├   ├   ...
├   ├   ├── results_00.yaml ** (.yaml file storing extraction parameters)
├   ├   ├── results_00.h5 ** (.h5 file storing extraction)
├   └   └── results_00.mp4 ** (extracted video)
└── session_2/
├   ...
├   └── proc/ **
├   ├   ├── roi.tiff
├   ├   ...
├   ├   ├── results_00.yaml **
├   ├   ├── results_00.h5 **
└   └   └── results_00.mp4 **
        
```

### Run Extraction Validation Tests (optional)

Once all the extractions are complete, use the following cell to run data validation tests. The tests can output either an error or a warning. 
- An __error__ indicates that the session is corrupted in some way and should be excluded from PCA and Modeling.
- A __warning__ indicates that one or more sessions are statistical outliers.
  - A warning can indicate that the session may need to be inspected prior to continuing into the PCA step. 
  - Warnings can be ignored when they are consistent with the experimental design (e.g. abnormally high velocity in an animal that received a stimulant drug).

__Error tests__: 
- Count Dropped Frames: an error is raised if a session is missing >5% of the frames based on timestamps (requires a timestamp file).
- Missing Mouse Check: raises an error if a mouse is missing from the video for any reason for >5% of the session's total frames.
- Scalar Anomaly: raises an error if >5% of a session's computed scalar values are NaN.
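Two of these error tests reduce to simple fractions compared against the 5% cutoff. The sketch below shows one plausible way to compute them. It is hypothetical, not the code behind `validate_extractions`, and it assumes timestamps in milliseconds and a nominal 30 fps frame rate.

```python
import numpy as np

THRESHOLD = 0.05  # the 5% cutoff described above


def dropped_frame_fraction(timestamps_ms, fps=30.0):
    """Estimate the fraction of frames missing, from gaps between frame timestamps."""
    ts = np.asarray(timestamps_ms, dtype=float)
    expected_interval = 1000.0 / fps
    # A gap of ~k frame intervals between consecutive frames means k - 1 frames were dropped
    gaps = np.round(np.diff(ts) / expected_interval) - 1
    dropped = int(np.clip(gaps, 0, None).sum())
    return dropped / (len(ts) + dropped)


def nan_fraction(values):
    """Fraction of NaN entries in a scalar time series (e.g. centroid x)."""
    values = np.asarray(values, dtype=float)
    return float(np.isnan(values).mean())
```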

__Warning tests__:
- Size Anomaly: a warning is raised when a mouse's captured body size falls more than 2 standard deviations below the mean size throughout the session.
- Scalar Anomaly: Warning is raised for a session if the trained [EllipticEnvelope](https://scikit-learn.org/stable/modules/generated/sklearn.covariance.EllipticEnvelope.html) model classifies it as an outlier. 
- Position Anomaly: There are two cases that raise warnings:
    1. Mouse is stationary for >5% of the session.
    2. Mouse's position distribution is at least 2 standard deviations away from the mean of all the sessions, measured using Kullback–Leibler divergence.
     - This anomaly can indicate that a mouse has explored a much larger or smaller region of the arena compared to the other recordings in the dataset.

To diagnose certain scalar anomalies, use the [Scalar Summary Cell](#Compute-Scalar-Summary) below to graph any desired scalar value.
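The KL-divergence position test can be sketched in three steps. Build a normalized occupancy histogram for each session. Compare each histogram to the mean histogram across sessions. Flag sessions whose divergence is more than 2 standard deviations above the average. This numpy sketch is illustrative and is not the moseq2-app implementation. The bin count, the arena extent, and the choice of the mean histogram as the reference are all assumptions.

```python
import numpy as np


def position_histogram(xy, bins=20, extent=((0, 1), (0, 1))):
    """Normalized 2D occupancy histogram of centroid positions (N x 2 array)."""
    h, _, _ = np.histogram2d(xy[:, 0], xy[:, 1], bins=bins, range=extent)
    return h / h.sum()


def kl_divergence(p, q, eps=1e-10):
    """KL(p || q) in nats; bins where p == 0 contribute nothing."""
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / np.maximum(q[mask], eps))))


def position_outliers(sessions, n_sd=2.0, **hist_kwargs):
    """Indices of sessions whose occupancy diverges from the group mean by > n_sd SDs."""
    hists = [position_histogram(xy, **hist_kwargs) for xy in sessions]
    mean_hist = np.mean(hists, axis=0)
    kl = np.array([kl_divergence(h, mean_hist) for h in hists])
    if kl.std() == 0:  # all sessions equally typical: nothing to flag
        return []
    z = (kl - kl.mean()) / kl.std()
    return [i for i, score in enumerate(z) if score > n_sd]
```

With only a handful of sessions, a 2-SD cutoff is hard to exceed at all, so this kind of test is most informative on larger datasets.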


In [None]:
from moseq2_app.main import validate_extractions

validate_extractions(progress_paths['base_dir']) # path to pre-existing aggregate_results/ folder is also permissible.

### Review Extraction Output (optional)

Run the following cell to view the extraction output.

In [None]:
from moseq2_app.main import preview_extractions

preview_extractions(progress_paths['base_dir'])

## Aggregate your results into one folder and generate an index file.

The following cell will search for the `proc/` subfolders containing the extraction output, and copy them to a single `aggregate_results/` folder. An index file called `moseq2-index.yaml` will also be generated with metadata for all extracted sessions. The index file can be used to group recordings by experimental condition for downstream analysis. 
Initially, each session is assigned to a single group called "default". We provide an interface for reassigning group labels below.

The `aggregate_results/` folder contains all the data you need to run the rest of the pipeline. The PCA and modeling steps will use data in this folder.

__Important Note: The index file contains UUIDs to map each session to a specific extraction. These UUIDs are referenced throughout the pipeline, so if you re-extract a session and re-aggregate your data, ensure all the UUIDs in the index file are up-to-date BEFORE running the PCA step.__ Not updating the index file will likely cause `KeyError`s to occur in the PCA and modeling steps.
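Before running PCA after a re-extraction, you can compare the UUIDs in the index against those of the current extractions. The sketch below assumes that the index file has been loaded into a dict (e.g. with `yaml.safe_load`, as elsewhere in this notebook) with a `files` list whose entries carry a `uuid` key. It also assumes that you have collected the current extraction UUIDs separately. Both are assumptions about the file layout, so check them against your own files. `stale_index_entries` is a hypothetical helper.

```python
def stale_index_entries(index, extracted_uuids):
    """Return index entries whose `uuid` does not match any current extraction."""
    extracted = set(extracted_uuids)
    return [entry for entry in index.get("files", [])
            if entry.get("uuid") not in extracted]


# An empty result means every index entry points at a current extraction.
```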

In [None]:
from os.path import join
from moseq2_app.gui.progress import update_progress
from moseq2_extract.gui import aggregate_extract_results_command

recording_format = '{start_time}_{session_name}_{subject_name}' # filename formats for the copied extracted data files

# directory NAME to save all metadata+extracted videos to with above respective name format
aggregate_results_dirname = 'aggregate_results/'

train_data_dir = join(progress_paths['base_dir'], aggregate_results_dirname)
update_progress(progress_filepath, 'train_data_dir', train_data_dir)

# the subpath indicates to only aggregate extracted session paths with that subpath, only change if aggregating data from a different location
index_filepath = aggregate_extract_results_command(progress_paths['base_dir'], recording_format, aggregate_results_dirname)
progress_paths = update_progress(progress_filepath, 'index_file', index_filepath)

The aggregate results folder will be saved in your base directory, e.g.

```
.
├── aggregate_results/ **
├   ├── session_1_results_00.h5 ** # session 1 compressed extraction + metadata 
├   ├── session_1_results_00.yaml **
├   ├── session_1_results_00.mp4 ** # session 1 extracted video
├   ├── session_2_results_00.h5 ** # session 2 compressed extraction + metadata 
├   ├── session_2_results_00.yaml **
├   └── session_2_results_00.mp4 ** # session 2 extracted video
├── config.yaml
├── moseq2-index.yaml ** # index file
├── session_1/
└── session_2/
```
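The `recording_format` string in the cell above is a Python format string whose placeholders are filled from each session's metadata. The sketch below shows how such a string expands. The mapping from `metadata.json` keys to placeholder names is an assumption for illustration, not the exact logic of `aggregate_extract_results_command`. In this sketch, a session without `StartTime` would raise a `KeyError` for the `{start_time}` placeholder.

```python
# Assumed mapping from metadata.json keys to recording_format placeholders
KEY_MAP = {"StartTime": "start_time", "SessionName": "session_name", "SubjectName": "subject_name"}


def aggregated_basename(recording_format, metadata):
    """Fill `recording_format` placeholders from a metadata.json-style dict."""
    fields = {KEY_MAP[k]: v for k, v in metadata.items() if k in KEY_MAP}
    return recording_format.format(**fields)


print(aggregated_basename(
    "{start_time}_{session_name}_{subject_name}",
    {"StartTime": "2021-01-01", "SessionName": "saline", "SubjectName": "mouse1"},
))  # 2021-01-01_saline_mouse1
```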

## Specify Groups

Sessions can be given "group" labels in `moseq2-index.yaml` for analyses comparing different cohorts or experimental conditions. This step requires that all your sessions have a `metadata.json` file containing a session name. Run the cell below to launch the group assignment GUI.

- Click on the column names to sort the index file.
- Enter your desired group name in the text input and click `Set Group` to update all the associated session rows.
- Once all your groups are set, click the `Update Index File` button to save current group assignments.

In [None]:
from moseq2_app.main import interactive_group_setting

interactive_group_setting(progress_paths['index_file'])

### Compute Scalar Summary (optional)

<img src="https://drive.google.com/uc?export=view&id=1wAHj1d5u0GeSa2_iJi7agjbqllbwXFKz">

Use the following command to plot a summary of scalar values for each group, such as average velocity, height, etc.
- Hold [CTRL]/[Command] and click on the Selector rows to select multiple scalars to plot.
- Hover over any of the data points to display the session information.
- Click on the legend items to show/hide groups from the plot. Double click an item to only show a single group.

In [None]:
from moseq2_app.main import interactive_scalar_summary

interactive_scalar_summary(progress_paths['index_file'])

### Plot Position Heatmaps For Each Session (Optional)

In [None]:
from os.path import join
from moseq2_viz.gui import plot_verbose_position_heatmaps

output_file = join(progress_paths['plot_path'], 'session_heatmaps') 
verbose_heatmap_fig = plot_verbose_position_heatmaps(progress_paths['index_file'], output_file)
verbose_heatmap_fig

### Plot Group Mean Position Summary (Optional)

In [None]:
from os.path import join
from moseq2_viz.gui import plot_mean_group_position_heatmaps_command

output_file = join(progress_paths['plot_path'], 'group_heatmaps') 
group_mean_heatmap_fig = plot_mean_group_position_heatmaps_command(progress_paths['index_file'], output_file)
group_mean_heatmap_fig

# Principal Component Analysis (PCA)

<img src="https://drive.google.com/uc?export=view&id=1KdNmEf_BcME5u39-mt75ROCDLDPzO7k-">

### Restore Progress Variables

In [None]:
from moseq2_app.gui.progress import restore_progress_vars

progress_filepath = './progress.yaml'

progress_paths = restore_progress_vars(progress_filepath)

## Fitting PCA

Fit PCA to your extracted data to determine the principal components (PCs) that explain the largest possible variance in your dataset. The PCs should look smooth and well defined, like the examples below. The PCs should explain >90% of the variance in the dataset using around 10 PCs. If this isn't the case, consult the [table of possible pathologies](#Possible-PCA-Pathologies). If running PCA locally, you can monitor progress on the [Dask dashboard](http://localhost:8787/).

In [None]:
from os.path import join
import ruamel.yaml as yaml
from moseq2_pca.gui import train_pca_command
from moseq2_app.gui.progress import update_progress

pca_filename = 'pca' # Name of your PCA model h5 file to be saved
pca_dirname = join(progress_paths['base_dir'], '_pca/') # Directory to save your computed PCA results

with open(progress_paths['config_file'], 'r') as f:
    config_data = yaml.safe_load(f)

# PCA parameters you may need to configure
config_data['overwrite_pca'] = False
config_data['gaussfilter_space'] = (1.5, 1) # Spatial filter for data (Gaussian)
config_data['medfilter_space'] = [0] # Median spatial filter
config_data['medfilter_time'] = [0] # Median temporal filter

# If dataset includes head-attached cables, set missing_data=True
config_data['missing_data'] = False # Set True for dataset with missing/dropped frames to reconstruct respective PCs.
config_data['missing_data_iters'] = 10 # Number of times to iterate over missing data during PCA
config_data['recon_pcs'] = 10 # Number of PCs to use for missing data reconstruction

# Dask Configuration
config_data['dask_port'] = '8787' # port to access Dask Dashboard

# UNCOMMENT to use SLURM
# config_data['cluster_type'] = 'slurm'
# config_data['nworkers'] = 8 # number of spawned jobs
# config_data['queue'] = 'short' # partition
# config_data['memory'] = '40GB' # amount of memory per worker
# config_data['cores'] = 1 # number of cores per worker
# config_data['wall_time'] = '01:00:00' # worker time limit

# UNCOMMENT if recordings contain occlusions (e.g. from overhead cables)
# config_data['missing_data'] = True

with open(progress_paths['config_file'], 'w') as f:
    yaml.safe_dump(config_data, f)

progress_paths = update_progress(progress_filepath, 'pca_dirname', pca_dirname)

# will train on data in aggregate_results/
train_pca_command(progress_paths, pca_dirname, pca_filename)

Once complete, a new directory `_pca` will be created containing the following data:
```
.
├── _pca/ **
├   ├── pca.h5 ** # pca model compressed file
├   ├── pca.yaml  ** # pca model YAML metadata file
├   ├── pca_components.png **
├   └── pca_scree.png **
├── aggregate_results/
├── config.yaml
├── moseq2-index.yaml
├── session_1/
└── session_2/

```

View your `computed PCs` and `scree plot` in the next cell.

In [None]:
from os.path import join
from IPython.display import display, Image
images = [join(progress_paths['pca_dirname'], 'pca_components.png'), 
          join(progress_paths['pca_dirname'], 'pca_scree.png')]
for im in images:
    display(Image(im))

### Possible PCA Pathologies

<table style="width: 100%;">
  <tbody>
    <tr>
      <th></th>
      <th>Good PCA Output Examples</th>
      <th style="text-align:center;">Bad Scree Plot Example</th>
      <th style="text-align:center;">Bad Principal Components Example</th>
    </tr>  
    <tr>
      <th style="text-align:center;">Pathology Description</th>
      <th style="text-align:center;"></th>
      <td style="text-align:center;">Cannot achieve an explained variance of over 90% with fewer than 15 principal components (PCs).</td>
      <td style="text-align:center;">PCs look noisy, or are not representative of realistic mouse body regions.</td>
    </tr>
    <tr>
      <th style="text-align:center;">Reference Examples</th>
      <th style="text-align:center;">
        <ul>
            <li>Components<br>
                <img src="https://drive.google.com/uc?export=view&id=1dX5Gpd3PKL4vfVviLeP0CqBrz9PW37Au" width=350 height=350></li><br><br>
            <li>Scree Plot<br>
                <img src="https://drive.google.com/uc?export=view&id=12uqsBYuWCjpUQ6QrAjo35MnwYDzHqnge" width=350 height=350>
            <br>"90.65% in 7 PCs"</li>
        </ul>
      </th>
      <td><img src="https://drive.google.com/uc?export=view&id=1DazNIPlGLAIPPQNeGF3eLR2l1QBnek4N" width=350 height=350></td>
      <td><img src="https://drive.google.com/uc?export=view&id=1xKHn0kEcs26R78aRRZwtPV3EOZ7h-qa9" width=350 height=350></td>
    </tr>
    <tr>
      <th style="text-align:center;">Image Analysis Solutions</th>
      <th style="text-align:center;"></th>
      <td>
        <ul>
          <li style="text-align:left;">Check if the crop size is too large, if so, decrease it and re-extract your data.</li>
          <li style="text-align:left;">Try (incrementally) adjusting the spatial and temporal filtering kernel sizes in the PCA step. Generally, increasing temporal smoothing will aid in increasing explained variance, but can potentially throw out data.</li>
        </ul>
      </td>
      <td>
          <ul>
              <li style="text-align:left;">Ensure that an appropriate amount of spatial and temporal filtering is applied.</li>
              <li style="text-align:left;">If you set missing_data=True, adjust spatial and temporal filtering, and try adjusting the number of PCs used for reconstruction (the recon_pcs parameter).</li>
          </ul>
    </td>
    </tr>
    <tr>
      <th style="text-align:center;">General Solutions</th>
      <th style="text-align:center;"></th>
      <td style="text-align:center;">
          <ul>
          <li style="text-align:left;">If there are cable occlusions, try setting missing_data=True. Using an iterative PCA to reconstruct the PCs can aid in increasing the explained variance ratio.</li>
          <li style="text-align:left;">Increase the size of your dataset. If your dataset is too small, it may contribute to overfit PCs.</li>
        </ul>
      </td>
      <td style="text-align:center;">Acquire and extract more data, then try again.</td>
    </tr>
  </tbody>
</table>
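The scree plot shows the cumulative fraction of variance explained by the first k PCs. The sketch below computes that quantity and the number of PCs needed to reach a target. It uses a plain SVD in numpy on a small in-memory array, so it only illustrates the calculation. `moseq2-pca` works on the full dataset with dask and applies its own filtering first, and `n_components_for_variance` is a hypothetical helper.

```python
import numpy as np


def n_components_for_variance(data, target=0.90):
    """Smallest number of PCs whose cumulative explained variance reaches `target`.

    `data` is an (n_frames, n_pixels) array, e.g. flattened cropped depth frames.
    Returns (n_components, cumulative_explained_variance).
    """
    centered = data - data.mean(axis=0)
    # Squared singular values of the centered data are proportional to each PC's variance
    s = np.linalg.svd(centered, compute_uv=False)
    explained = s**2 / np.sum(s**2)
    cumulative = np.cumsum(explained)
    # First index where the cumulative curve reaches the target (clamped for float round-off)
    n = min(int(np.searchsorted(cumulative, target)) + 1, len(cumulative))
    return n, cumulative
```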

## Computing Principal Component Scores

In [None]:
from os.path import join
from moseq2_pca.gui import apply_pca_command
from moseq2_app.gui.progress import update_progress

scores_filename = 'pca_scores' # name of the scores file to compute and save

scores_file = join(progress_paths['pca_dirname'], scores_filename + '.h5') # path to input PC scores file to model
progress_paths = update_progress(progress_filepath, 'scores_path', scores_file)

apply_pca_command(progress_paths, scores_filename)

The output of this step is saved in a file called `pca_scores.h5` in the `_pca` directory.
```
.
├── _pca/
├   ├── pca.h5
├   ├── pca.yaml
├   ├── pca_scores.h5  ** # scores file
├   ├── pca_components.png
├   └── pca_scree.png
├── aggregate_results/
├── config.yaml
├── moseq2-index.yaml
├── session_1/
└── session_2/

```

### Computing Model-Free Changepoints (Optional) 

This step can be used to determine a target syllable duration for the modeling step. Typically the distribution of change-point durations is smooth, left-skewed and centered around 0.3 seconds. If that is not the case, consult the [table of possible pathologies](#Possible-Model-Free-Changepoints-Pathologies) below.

__Note: the parameters below are configured for C57 mouse data, and have not been tested for other strains/species.__

In [None]:
import ruamel.yaml as yaml
from moseq2_app.gui.progress import update_progress
from moseq2_pca.gui import compute_changepoints_command

with open(progress_paths['config_file'], 'r') as f:
    config_data = yaml.safe_load(f)

changepoints_filename = 'changepoints' # name of the changepoints images to generate

# Changepoint computation parameters you may want to configure
config_data['threshold'] = 0.5 # Peak threshold to use for changepoints
config_data['dims'] = 300 # Number of random projections to compare the computed principal components with

with open(progress_paths['config_file'], 'w') as f:
    yaml.safe_dump(config_data, f)

progress_paths = update_progress(progress_filepath, 'changepoints_path', changepoints_filename)
compute_changepoints_command(progress_paths['train_data_dir'], progress_paths, changepoints_filename)
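The `dims` and `threshold` parameters above control a model-free changepoint detector. The sketch below illustrates the general idea with a simplified, assumed algorithm: project the PC scores onto `dims` random directions, take the absolute frame-to-frame change of each projection, average those changes, rescale the result to [0, 1], and mark local peaks at or above `threshold` as changepoints. This is not `moseq2-pca`'s actual implementation, whose details may differ, and `model_free_changepoints` is a hypothetical helper.

```python
import random


def model_free_changepoints(scores, dims=300, threshold=0.5, seed=0):
    """Hypothetical, simplified model-free changepoint detector.

    scores: per-frame PC score vectors (lists of equal length).
    dims: number of random projections.
    threshold: minimum peak height after rescaling the change signal to [0, 1].
    Returns the indices of frames that start a new segment.
    """
    if len(scores) < 2:
        return []
    rng = random.Random(seed)
    n_pcs = len(scores[0])
    # Random Gaussian projection directions in PC space
    directions = [[rng.gauss(0.0, 1.0) for _ in range(n_pcs)] for _ in range(dims)]
    projected = [[sum(w * s for w, s in zip(d, frame)) for d in directions] for frame in scores]
    # Mean absolute frame-to-frame change across all projections;
    # change[i] measures the jump from frame i to frame i + 1
    change = [
        sum(abs(b - a) for a, b in zip(prev, cur)) / dims
        for prev, cur in zip(projected, projected[1:])
    ]
    # Rescale to [0, 1] so that `threshold` is relative to the largest change
    lo, hi = min(change), max(change)
    span = (hi - lo) or 1.0
    norm = [(c - lo) / span for c in change]
    # Keep local maxima at or above the threshold
    peaks = []
    for i, value in enumerate(norm):
        rising = i == 0 or value > norm[i - 1]
        falling = i == len(norm) - 1 or value >= norm[i + 1]
        if value >= threshold and rising and falling:
            peaks.append(i + 1)
    return peaks


# Toy PC scores: three constant segments of 10 frames each
scores = [[0.0, 0.0]] * 10 + [[3.0, 4.0]] * 10 + [[7.0, 1.0]] * 10
print(model_free_changepoints(scores))  # [10, 20]
```

Raising `threshold` keeps only the sharpest transitions, and more projections (`dims`) make the averaged change signal less sensitive to any single random direction.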

The changepoints plot will be generated and saved in the ```_pca``` directory.
```
.
├── _pca/
│   ├── pca.h5
│   ├── pca_scores.h5
│   ├── ...
│   └── changepoints_dist.png **
├── aggregate_results/
├── config.yaml
├── moseq2-index.yaml
├── session_1/
└── session_2/
```

View ```changepoints_dist.png``` by running the next cell.

In [None]:
from os.path import join
from IPython.display import display, Image

changepoints_filename = 'changepoints'
display(Image(join(progress_paths['pca_dirname'], changepoints_filename + '_dist.png')))

### Possible Model-Free Changepoints Pathologies

<table style="width: 100%;">
  <tbody>
    <tr>
      <th></th>
      <th style="text-align:center;">Good Changepoint Analysis Example</th>
      <th style="text-align:center;">Poor Changepoint Analysis Example</th>
    </tr>  
    <tr>
      <th style="text-align:center;">Pathology Description</th>
      <td style="text-align:center;"></td>
      <td style="text-align:center;">The distribution of model-free changepoint durations has an unexpected skew or is too sparse, and/or its mode is shorter than 0.2 s</td>
    </tr>
    <tr>
      <th style="text-align:center;">Reference Example</th>
      <td><img src="https://drive.google.com/uc?export=view&id=1zlIaunlhwu0dX-Fw8jk3Xqp9K4FxJLTn" width=350 height=350></td>
      <td><img src="https://drive.google.com/uc?export=view&id=1RhdSyvvy9TgoCv0srfuQWPStqe-N1_3C" width=350 height=350></td>
    </tr>
    <tr>
      <th style="text-align:center;">General Solutions</th>
      <td style="text-align:center;"></td>
      <td>
          <ul>
              <li style="text-align:left;">Try retraining the PCA with adjusted spatial and temporal filtering kernel sizes.</li>
              <li style="text-align:left;">Ensure your extracted data is correct and contains minimal flips. If the extracted mouse is too noisy, the PC trajectories cannot accurately represent the data.</li>
              <li style="text-align:left;">Get more data and try again.</li>
          </ul>
      </td>
    </tr>
  </tbody>
</table>
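The pathology in the table above can also be checked numerically. The sketch below bins changepoint durations into a histogram, takes the center of the fullest bin as an estimate of the mode, and flags the distribution when that mode falls below the 0.2 s cutoff from the table. The bin width and the example durations are assumptions, and `duration_mode` and `looks_pathological` are hypothetical helpers, not part of MoSeq.

```python
from collections import Counter


def duration_mode(durations, bin_width=0.05):
    """Estimate the mode of durations (seconds) as the center of the
    most populated histogram bin of width `bin_width`."""
    if not durations:
        raise ValueError("durations must not be empty")
    counts = Counter(int(d // bin_width) for d in durations)
    best_bin, _ = counts.most_common(1)[0]
    return (best_bin + 0.5) * bin_width


def looks_pathological(durations, min_mode=0.2, bin_width=0.05):
    """Flag a changepoint-duration distribution whose mode is below `min_mode` seconds."""
    return duration_mode(durations, bin_width) < min_mode


# Made-up changepoint durations (seconds)
healthy = [0.27, 0.31, 0.32, 0.33, 0.34, 0.36, 0.52, 0.81]
too_short = [0.07, 0.08, 0.09, 0.11, 0.31]
print(looks_pathological(healthy))    # False
print(looks_pathological(too_short))  # True
```

A histogram-based mode depends on the bin width, so treat the flag as a prompt to inspect the plot rather than as a verdict.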

# ARHMM Modeling

<img src="https://drive.google.com/uc?export=view&id=1pAiffIWGsLtbu6MWJmMjQjRFlZwbv8-8" width=350 height=350>

### Restore Progress Variables

In [None]:
from moseq2_app.gui.progress import restore_progress_vars

progress_filepath = './progress.yaml'

progress_paths = restore_progress_vars(progress_filepath)

## Fitting the ARHMM

Fitting the ARHMM typically requires adjusting the `kappa` hyperparameter to achieve a target syllable duration (higher values of `kappa` lead to longer syllable durations). The target duration can be determined using the model-free changepoint analysis above, or set heuristically to 0.3-0.4 seconds based on prior literature. To choose `kappa`, set it to `'scan'` in the code below to train a family of models with different `kappa` values, fitting each for 100-200 iterations, then use the "Get Best Model Fit" cell to pick an optimal value. For final model fitting, set `kappa` to the chosen value and fit for ~1000 iterations.
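As background for why larger `kappa` values lengthen syllables: MoSeq's AR-HMM places a "sticky" hierarchical Dirichlet process prior on its transition matrix, and `kappa` ($\kappa$) is the stickiness term that adds prior mass to each state's self-transition. The derivation below uses the textbook weak-limit sticky HDP-HMM form, with $L$ states (cf. `max_states`), global state weights $\beta$ satisfying $\sum_j \beta_j = 1$, and concentration $\alpha$. It is shown for intuition only; `moseq2-model` may parameterize the prior differently.

$$
\pi_k \mid \beta \sim \mathrm{Dir}\left(\alpha\beta_1, \ldots, \alpha\beta_k + \kappa, \ldots, \alpha\beta_L\right),
\qquad
\mathbb{E}\left[\pi_{kk} \mid \beta\right] = \frac{\alpha\beta_k + \kappa}{\alpha + \kappa}
$$

A state with self-transition probability $p$ persists for a geometrically distributed number of frames $d$:

$$
P(d) = p^{\,d-1}(1 - p), \qquad \mathbb{E}[d] = \frac{1}{1 - p}
$$

Substituting the expected self-transition probability gives a rough approximation of the expected syllable duration in frames:

$$
\mathbb{E}[d] \approx \frac{1}{1 - \mathbb{E}[\pi_{kk}]} = \frac{\alpha + \kappa}{\alpha\,(1 - \beta_k)}
$$

which grows linearly with $\kappa$. Assuming 30 frames per second, a 0.3 s target corresponds to about 9 frames.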

__Note: if loading a model checkpoint, ensure the modeling parameters (especially the selected groups) are identical to those of the checkpoint; otherwise, training will fail.__
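To catch such mismatches before resuming, you can compare the parameters you are about to train with against those used for the checkpoint. The sketch below diffs two plain dictionaries. How you obtain the checkpoint's parameters depends on how the checkpoint was saved, so both dictionaries here are hypothetical, as is the `param_mismatches` helper.

```python
def param_mismatches(current, checkpoint):
    """Return {name: (current_value, checkpoint_value)} for every parameter
    whose value differs between the two dicts. A key missing from one side
    is reported with None for that side."""
    keys = sorted(set(current) | set(checkpoint))
    return {
        k: (current.get(k), checkpoint.get(k))
        for k in keys
        if current.get(k) != checkpoint.get(k)
    }


# Hypothetical parameter sets
current = {"npcs": 10, "max_states": 100, "robust": True, "groups": ["saline", "amphetamine"]}
saved = {"npcs": 10, "max_states": 100, "robust": True, "groups": ["saline"]}
print(param_mismatches(current, saved))
# {'groups': (['saline', 'amphetamine'], ['saline'])}
```

An empty result means the two parameter sets agree on every key.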


In [None]:
from os.path import join
from moseq2_model.gui import learn_model_command
from moseq2_app.gui.progress import update_progress

modeling_session_path = 'saline-amphetamine/' # directory (inside base_dir) where trained models are saved
model_name = 'model.p' # file name of the trained model

session_path = join(progress_paths['base_dir'], modeling_session_path)
model_path = join(session_path, model_name) # path to save trained model

select_groups = False # select specific groups to model; if False, all data in moseq2-index.yaml is modeled as is

# model saving frequency (in iterations); creates a checkpoints/ directory containing checkpointed models
checkpoint_freq = -1
use_checkpoint = False # resume training from the latest saved checkpoint

# Advanced modeling parameters
hold_out = False # if True, hold out a subset of the data during training
nfolds = 2 # (if hold_out==True): number of folds to hold out during training; 1 fold per session

npcs = 10 # number of PCs to model; as a baseline, choose enough PCs to explain >= ~90% of the variance
max_states = 100 # maximum number of states the ARHMM can use

# use a robust ARHMM with t-distributed noise, which tolerates more noise
robust = True

# learn separate transition graphs per group; set to True to compare multiple groups
separate_trans = True

num_iter = 100 # number of iterations to train the model

# prior on syllable duration (None, int or 'scan'); if None, kappa=nframes
kappa = None

# if kappa == 'scan', optionally set the bounds of the kappa values to scan, on a log or linear scale
scan_scale = 'log' # 'log' or 'linear'
min_kappa = None
max_kappa = None
out_script = 'train_out.sh' # script file to save the kappa-scanning learn_model() commands

# total number of models to spool when kappa == 'scan'
n_models = 15

# Select platform to run models on
cluster_type = 'local' # currently supported cluster_types: 'local' or 'slurm'
run_cmd = False # if True, runs the commands via os.system(...); otherwise the script must be run manually

## SLURM PARAMETERS
## only edit these parameters if cluster_type == 'slurm'
memory = '16GB'
wall_time = '3:00:00'
partition = 'short'

progress_paths = update_progress(progress_filepath, 'model_path', model_path)
progress_paths = update_progress(progress_filepath, 'model_session_path', session_path)

learn_model_command(progress_paths, hold_out=hold_out, nfolds=nfolds, num_iter=num_iter, max_states=max_states,
                    npcs=npcs, kappa=kappa, separate_trans=separate_trans, robust=robust,
                    checkpoint_freq=checkpoint_freq, use_checkpoint=use_checkpoint, select_groups=select_groups,
                    cluster_type=cluster_type, min_kappa=min_kappa, scan_scale=scan_scale,
                    max_kappa=max_kappa, n_models=n_models, run_cmd=run_cmd, output_dir=modeling_session_path,
                    out_script=out_script, memory=memory, wall_time=wall_time, partition=partition)
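When `kappa = 'scan'`, the cell above spools `n_models` models with `kappa` values spread between `min_kappa` and `max_kappa` on a `'log'` or `'linear'` scale. The sketch below shows one way to generate such a grid. It illustrates evenly spaced and log-spaced grids in general, not the code `learn_model_command` uses; the bounds in the example are arbitrary, and `kappa_grid` is a hypothetical helper.

```python
import math


def kappa_grid(min_kappa, max_kappa, n_models, scale="log"):
    """Return `n_models` kappa values from min_kappa to max_kappa (inclusive),
    evenly spaced on a linear or base-10 logarithmic scale."""
    if n_models < 1:
        raise ValueError("n_models must be >= 1")
    if n_models == 1:
        return [float(min_kappa)]
    if scale == "linear":
        step = (max_kappa - min_kappa) / (n_models - 1)
        return [min_kappa + i * step for i in range(n_models)]
    if scale == "log":
        if min_kappa <= 0:
            raise ValueError("log scale requires min_kappa > 0")
        lo, hi = math.log10(min_kappa), math.log10(max_kappa)
        step = (hi - lo) / (n_models - 1)
        return [10 ** (lo + i * step) for i in range(n_models)]
    raise ValueError("scale must be 'log' or 'linear'")


print([round(k) for k in kappa_grid(1e3, 1e7, 5)])  # [1000, 10000, 100000, 1000000, 10000000]
```

A log scale spreads models evenly across orders of magnitude, which helps when a suitable `kappa` could lie anywhere in a wide range; a linear scale is better for refining within a narrow range.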

Once training is complete, your model will be saved in the modeling session directory inside your base directory, as shown below.
```
.
├── _pca/
├── aggregate_results/
├── config.yaml
├── modeling_session/ ***
│   └── model.p ***
├── moseq2-index.yaml
├── session_1/
└── session_2/
```


### Restore Notebook Variables

In [None]:
from moseq2_app.gui.progress import restore_progress_vars

progress_filepath = './progress.yaml'

progress_paths = restore_progress_vars(progress_filepath)

## Get Best Model Fit

Use this step to check whether the median syllable duration of a trained model matches the median duration between the model-free PC changepoints.

This step can also return the best model among the models found in `progress_paths['model_session_path']`.
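A model's syllable durations come from its per-frame label sequence: each run of identical consecutive labels is one syllable instance. The sketch below computes the median syllable duration from such a sequence, assuming 30 frames per second (adjust to your frame rate). The label sequence is made up, and `median_syllable_duration` is a hypothetical helper, not the function `get_best_fit_model` uses internally.

```python
from itertools import groupby
from statistics import median


def median_syllable_duration(labels, fps=30):
    """Median duration (seconds) of runs of identical consecutive labels."""
    run_lengths = [sum(1 for _ in run) for _, run in groupby(labels)]
    return median(run_lengths) / fps


# Made-up syllable labels: runs of 9, 12, 9, 6 and 9 frames
labels = [3] * 9 + [7] * 12 + [3] * 9 + [1] * 6 + [5] * 9
print(median_syllable_duration(labels))  # 0.3
```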

Below are examples of some comparative distributions that you can expect when using this tool:

<table>
    <tr>
        <td>
            <img height=400 width=400 src="https://drive.google.com/uc?export=view&id=1B6R4AGsQHaddwJj-48Pbd_5ZHOZvpttp">
        </td>
        <td>
            <img height=400 width=400 src="https://drive.google.com/uc?export=view&id=1poLAAhNlAdM8T_1Ps6OMs6vNz03NGbgr">
        </td>
    </tr>
</table>

In [None]:
from os.path import join
from moseq2_viz.gui import get_best_fit_model
from moseq2_app.gui.progress import update_progress

output_file = join(progress_paths['plot_path'], 'model_vs_pc_changepoints')

best_model_fit = get_best_fit_model(progress_paths, plot_all=True)
progress_paths = update_progress(progress_filepath, 'model_path', best_model_fit['best model - duration'])
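The cell above stores the model found under the `'best model - duration'` entry of the result. As a rough illustration of a duration-based selection rule, the sketch below picks the model whose median syllable duration is closest to a target such as the median changepoint duration. The model file names and medians are hypothetical, and `get_best_fit_model` may weigh models differently.

```python
def closest_model(model_medians, target):
    """Return the model name whose median syllable duration (seconds)
    is closest to `target` (e.g. the median changepoint duration)."""
    return min(model_medians, key=lambda name: abs(model_medians[name] - target))


# Hypothetical per-model median syllable durations (seconds)
model_medians = {
    "model-kappa-1e4.p": 0.17,
    "model-kappa-1e5.p": 0.27,
    "model-kappa-1e6.p": 0.43,
}
print(closest_model(model_medians, target=0.3))  # model-kappa-1e5.p
```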

***

# Notebook End

## Go to the [Interactive-Model-Results-Exploration.ipynb](./Interactive-Model-Results-Exploration.ipynb) notebook to analyze the model results.