
Flip classifiers are machine learning models that MoSeq2-Extract uses to ensure that the mouse is always extracted with its nose pointing to the right and its tail to the left. This notebook is a streamlined utility and guide for preparing data and training a model tailored to your data acquisition setup.

To use this notebook, you must first extract some data with MoSeq2-Extract to serve as training data for the flip classifier. Roughly 100K frames is the optimal amount for training the flip classifier.

This can be an iterative process if your extractions contain many flips. On the first iteration, it is acceptable to extract the data without a flip classifier. After training a new flip classifier, you can apply it to your dataset to correct the flips before the PCA step, without re-extracting the data.
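
To make "correcting a flip" concrete, here is a minimal sketch. It assumes that a flipped frame is corrected by rotating it 180 degrees, which this notebook does not state explicitly, and `rotate_180` and `correct_flips` are hypothetical helpers rather than MoSeq2 functions. Frames are plain nested lists so the example runs without extra dependencies.

```python
from typing import List, Sequence

Frame = List[List[int]]  # tiny stand-in for a depth image


def rotate_180(frame: Frame) -> Frame:
    """Rotate a 2D frame by 180 degrees (reverse the rows, then each row)."""
    return [row[::-1] for row in frame[::-1]]


def correct_flips(frames: Sequence[Frame], is_flipped: Sequence[bool]) -> List[Frame]:
    """Rotate only the frames that a classifier marked as flipped."""
    if len(frames) != len(is_flipped):
        raise ValueError("need exactly one flip prediction per frame")
    return [rotate_180(f) if flipped else f for f, flipped in zip(frames, is_flipped)]


# Hypothetical 2x3 frame in which the "nose" (value 9) is on the left.
facing_left = [[9, 1, 0],
               [9, 1, 0]]
corrected = correct_flips([facing_left], [True])[0]
print(corrected)  # the nose value now sits in the rightmost column
```

Frames predicted as correctly oriented pass through unchanged, so applying the correction to an already clean session is harmless under this assumption.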

# Set up the project structure

This initializes a MoSeq project if one has not been created yet; otherwise, it uses the existing progress file.

Instructions:
- Run the cells
- Feel free to change the path names as needed

In [None]:
# Standard library
import os
from dataclasses import fields
from os.path import join
from pathlib import Path

# Third-party
import panel as pn
import yaml

# MoSeq2 app
from moseq2_app.flip.train import CleanParameters, create_training_dataset, train_classifier, save_classifier
from moseq2_app.flip.widget import FlipClassifierWidget, DisplayWidget
from moseq2_app.gui.progress import check_progress, restore_progress_vars

# Directory that holds (or will hold) the MoSeq project; change as needed.
base_dir = './'
progress_filepath = join(base_dir, 'progress.yaml')

# Reuse the existing progress file, or initialize one if it does not exist yet.
progress_paths = restore_progress_vars(progress_filepath, init=True, overwrite=False)
check_progress(progress_filepath)

In [None]:
pn.extension()

# Labeling frames
**Instructions:**
- **Run the following cell** to launch the Data Labeller GUI.

- **Select the target session from the dropdown menu** and start labeling.

- **Drag the slider** to select a frame index to preview.

- **Click `Start Range`** to start selecting the range.
  - **Drag the slider** to the end of the range.
  - **Click `Facing Left` or `Facing Right`** to specify the correct orientation for the range of frames.
  - After specifying the orientation, the selected frames will be added to the dataset used to train the model.

- **Click `Cancel Select`** to cancel the selection.

**Note**: The `Current Total Selected` section turns green when there are enough labeled frames to train the model.

If your frame selection is interrupted and you would like to relaunch the tool with all of your previously selected frame ranges, run the cell again. Feel free to change the value of `flip_path`.

In [None]:
flip_path = "flip_classifier"
FF = FlipClassifierWidget(data_path=progress_paths['base_dir'], flip_path=flip_path)
FF.show()
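
Conceptually, the labeling workflow collects frame ranges with an orientation and counts the frames they cover. The sketch below illustrates that bookkeeping only. `LabeledRange`, `total_labeled_frames`, and the `MIN_FRAMES` threshold are hypothetical; the widget's internal representation and its actual threshold are not documented in this notebook.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class LabeledRange:
    session: str
    start: int   # first frame index, inclusive
    end: int     # last frame index, inclusive
    facing: str  # "left" or "right"

    def __post_init__(self):
        if self.end < self.start:
            raise ValueError("end must not precede start")
        if self.facing not in ("left", "right"):
            raise ValueError("facing must be 'left' or 'right'")

    @property
    def n_frames(self) -> int:
        return self.end - self.start + 1


def total_labeled_frames(ranges: List[LabeledRange]) -> int:
    """Sum the frames covered by all ranges (overlaps are not deduplicated)."""
    return sum(r.n_frames for r in ranges)


MIN_FRAMES = 5000  # assumed threshold, for illustration only

ranges = [
    LabeledRange("session_1", 0, 2999, "right"),
    LabeledRange("session_1", 4000, 6499, "left"),
]
total = total_labeled_frames(ranges)
print(total, total >= MIN_FRAMES)
```

Labeling a mix of both orientations, as in the two ranges above, gives the classifier examples of each class.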

# Train the model on the labeled frames

Here, we train a machine learning model on the labeled frames. You can use either a Random Forest classifier or a Support Vector Machine (SVM). The model learns to identify the correct orientation of the mouse and can be used in later steps of the MoSeq extraction pipeline.

**Instructions:**
- Run the following cells.
- The two classifier options are `"svm"` and `"rf"`. Each model type has the following training parameters:

### Random Forest:
- `rf_n_estimators` (int): Number of trees in the random forest.
- `rf_max_depth` (int): Maximum depth of each tree.
- `rf_min_samples_split` (int): Minimum number of samples required to split an internal node.
- `rf_min_samples_leaf` (int): Minimum number of samples required at a leaf node.
- `rf_max_features` (Union[str, int, float]): Number of features to consider when looking for the best split.

### SVM:
- `svm_C` (float): Regularization parameter.
- `svm_kernel` (str): Kernel type used by the algorithm.
- `svm_gamma` (str or float): Kernel coefficient for `'rbf'`, `'poly'`, and `'sigmoid'`.
- `svm_class_weight` (dict or str): Class weights.
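
With the `rf_` or `svm_` prefix removed, these names match constructor arguments of scikit-learn's `RandomForestClassifier` and `SVC`. Whether `train_classifier` forwards them this way is an assumption, not something this notebook confirms. The hypothetical helper below only illustrates how prefixed options could be grouped per model; the values are placeholders, not recommended settings.

```python
from typing import Any, Dict


def group_by_prefix(options: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    """Return the options that start with `prefix`, with the prefix stripped."""
    return {key[len(prefix):]: value
            for key, value in options.items()
            if key.startswith(prefix)}


# Placeholder values for illustration only.
options = {
    "rf_n_estimators": 100,
    "rf_max_depth": None,
    "svm_C": 1.0,
    "svm_kernel": "rbf",
    "cv_splits": 5,
}

rf_kwargs = group_by_prefix(options, "rf_")
svm_kwargs = group_by_prefix(options, "svm_")
print(rf_kwargs)
print(svm_kwargs)
```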

### Cross-Validation Splits:
- `cv_splits` (int): Number of folds for cross-validation.
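
For intuition about `cv_splits`: k-fold cross-validation partitions the samples into k folds and holds each fold out once while training on the rest. The dependency-free sketch below shows that partitioning. `k_fold_indices` is a hypothetical helper, and how `train_classifier` actually splits or shuffles the data is not specified in this notebook.

```python
from typing import Iterator, List, Tuple


def k_fold_indices(n_samples: int, n_splits: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_indices, test_indices) for contiguous, unshuffled folds."""
    if not 2 <= n_splits <= n_samples:
        raise ValueError("n_splits must be between 2 and n_samples")
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for fold in range(n_splits):
        # The first `extra` folds take one additional sample each.
        size = base + (1 if fold < extra else 0)
        test_idx = list(range(start, start + size))
        train_idx = list(range(0, start)) + list(range(start + size, n_samples))
        yield train_idx, test_idx
        start += size


for train_idx, test_idx in k_fold_indices(n_samples=10, n_splits=3):
    print(len(train_idx), test_idx)
```

More folds mean more training runs, so a larger `cv_splits` trades longer training time for a steadier performance estimate.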
        

In [None]:
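def initialize_clean_params(config_file):
    """Build a CleanParameters instance from the matching keys in the config file."""
    with open(config_file, 'r') as f:
        config_data = yaml.safe_load(f)
    # Keep only the keys that are CleanParameters fields; other config keys are ignored.
    field_names = {field.name for field in fields(CleanParameters)}
    params = {key: value for key, value in config_data.items() if key in field_names}
    return CleanParameters(**params)

training_data_path = FF.train_file
# Use a lowercase name so the imported CleanParameters class is not overwritten,
# which would break this cell when it is run a second time.
clean_parameters = initialize_clean_params(progress_paths['config_file'])

# validation_size=0.1 holds out 10% of the labeled data for validation.
dataset_path, validation_range = create_training_dataset(
    data_index_path=training_data_path,
    clean_parameters=clean_parameters,
    validation_size=0.1
)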
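
The key filtering in `initialize_clean_params` matters because a dataclass constructor rejects unknown keyword arguments, while a project config file usually contains many unrelated keys. The sketch below demonstrates this with a stand-in dataclass. `DemoParameters` and its fields are hypothetical and do not reflect the real `CleanParameters` fields.

```python
from dataclasses import dataclass, fields


@dataclass
class DemoParameters:
    # Hypothetical fields, for illustration only.
    threshold: float = 0.5
    iterations: int = 2


config_data = {"threshold": 0.8, "iterations": 3, "unrelated_key": "ignored"}

# Passing the whole config fails because of the extra key.
try:
    DemoParameters(**config_data)
except TypeError as err:
    print("unfiltered config fails:", err)

# Filtering to the dataclass field names first makes construction succeed.
allowed = {field.name for field in fields(DemoParameters)}
params = DemoParameters(**{k: v for k, v in config_data.items() if k in allowed})
print(params)
```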

In [None]:
dataset_path = Path(dataset_path)
classifier = "svm"  # or "rf"
model_name = "flip_classifier.pkl"  # file name for the saved model
clf = train_classifier(
    data_path=dataset_path,
    classifier=classifier.upper()
)

In [None]:
save_path = os.path.join(FF.model_path, model_name)
save_classifier(clf, save_path)

# Display the results!
This cell applies the trained model to the validation range (the portion of the data the model was not trained on) and displays the results. Depending on the results, you can refine the model or train on more labeled data.

**Instructions:**
- **Run the following cell** to display the results of your model.

In [None]:
validation_widget = DisplayWidget(
    data_path=progress_paths['base_dir'],
    classifier_path=save_path,
    validation_ranges_path=validation_range
)
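
Alongside the visual check, a single number can help when comparing models across iterations. The sketch below computes plain accuracy on held-out labels. The `accuracy` helper and the example labels are hypothetical, and `DisplayWidget` is not claimed to report this metric.

```python
from typing import Sequence


def accuracy(true_labels: Sequence[int], predicted_labels: Sequence[int]) -> float:
    """Fraction of frames whose predicted orientation matches the label."""
    if len(true_labels) != len(predicted_labels):
        raise ValueError("label sequences must have the same length")
    if not true_labels:
        raise ValueError("need at least one labeled frame")
    correct = sum(t == p for t, p in zip(true_labels, predicted_labels))
    return correct / len(true_labels)


# Made-up labels for eight validation frames: 1 = flipped, 0 = correctly oriented.
true_labels = [0, 0, 1, 1, 0, 1, 0, 0]
predicted = [0, 0, 1, 0, 0, 1, 0, 1]
score = accuracy(true_labels, predicted)
print(score)
```

Accuracy can look high when one orientation dominates the validation frames, so it is worth inspecting both classes in the preview as well.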

***