# QIIME 2 Tutorial: Machine Learning

This notebook contains materials accompanying the workshop **Microbiome-Based Tools: From Research to Application**. The notebook and corresponding setup script were adapted from the [**Advanced Block Course: Computational Biology**](https://github.com/bokulich-lab/advanced-comp-bio-tutorial.git); all source code is licensed under the Apache License 2.0.

Save your own local copy of this notebook by using `File > Save a copy in Drive`. At some point you may be prompted to trust the notebook. We promise that it is safe 🤞

**Disclaimer:**

The Google Colab notebook environment will interpret any command as Python code by default. If we want to run bash commands we will have to prefix them by `!`. So any command you see with a leading `!` is a bash command and if you wanted to run it in your terminal you would omit the leading `!`. For example, if in the Colab notebook you ran `!wget` you would just run `wget` in your terminal. 

In this notebook we use the `!` prefix because we run all QIIME 2 commands using the [`q2cli`](https://github.com/qiime2/q2cli/) (QIIME 2 command-line interface). However, QIIME 2 also has a Python API and a Galaxy interface. You can learn more about these and other QIIME 2 interfaces at https://qiime2.org/.

You can run the entire notebook by selecting `Runtime > Run all` from the menu in Google Colab. Some steps are time-consuming and the entire notebook may take 30-60 minutes to run, so start it now and we will inspect the commands and results as we work through it as a class.

## Setup

QIIME 2 is usually installed by following the [official installation instructions](https://docs.qiime2.org/2023.9/install/). However, because we are using Google Colab and there are some caveats to using conda here, we will have to hack around the installation a little. But no worries, we provide a setup script below which does all this work for us. 😌

So let's start by pulling a local copy of the project repository down from GitHub.

In [None]:
! git clone https://github.com/bokulich-lab/uzh-microbiome-tutorial.git materials
! mkdir /content/prefetch_cache

We will switch to working within the `materials` directory for the rest of the notebook.

In [None]:
%cd materials

Now we are ready to set up our environment. This will take about 10 minutes.
**Note:** This setup is only relevant for Google Colaboratory and will not work on your local machine. Please follow the [official installation instructions](https://docs.qiime2.org/2023.9/install/) for that.

In [None]:
%run setup_qiime2

And we will use some Python packages below, so let's load these here:

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns

## Predicting categorical data

*Supervised learning classifiers* predict the categorical metadata classes of unlabeled samples by learning the composition of labeled training samples. For example, we may use a classifier to diagnose or predict disease susceptibility based on stool microbiome composition, or predict sample type as a function of the sequence variants, microbial taxa, or metabolites detected in a sample. In this tutorial, we will use the read processing tutorial data to train a classifier that predicts the body site from which a sample was collected.

Sections:
1. Train and test a categorical classifier
2. Optimize feature selection

### Training/testing classifier

We will train and test a classifier that predicts which body site a sample originated from, based on its microbial composition. We will do so using the `classify-samples` pipeline, which performs a series of steps:

1. The input samples are randomly split into a `training` set and a `test` set. The test set is held out until the end of the pipeline, allowing us to test accuracy on a set of samples that was not used for model training. The fraction of input samples to include in the test set is adjusted with the `--p-test-size` parameter.

2. We train the learning model using the training set samples. The model is trained to predict a specific `target` value for each sample (contained in a metadata column) based on the feature data associated with that sample. A range of different estimators can be selected using the `--p-estimator` parameter; more details on individual estimators can be found in the [scikit-learn documentation](http://scikit-learn.org/stable/supervised_learning.html) (not sure which to choose? See the [estimator selection flowchart](http://scikit-learn.org/stable/tutorial/machine_learning_map/index.html)).

3. The trained model is used to predict the target value of each test sample from its feature data, together with class probabilities. Class probabilities are the model's estimated probability that a sample belongs to each class (i.e., each group of samples with the same `target` value).

4. Model accuracy is calculated by comparing each test sample’s predicted value to the true value for that sample.
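The four steps above can be sketched in plain Python. This is a minimal illustration of the workflow, not how `classify-samples` is implemented: the plugin relies on scikit-learn estimators such as `RandomForestClassifier`, whereas here a simple nearest-centroid model stands in for the learner so the example has no dependencies. The helper names and the toy two-feature data are hypothetical.

```python
import random
from collections import defaultdict


def train_test_split(samples, labels, test_size=0.2, seed=123):
    """Step 1: randomly hold out a fraction of the samples as the test set."""
    order = list(range(len(samples)))
    random.Random(seed).shuffle(order)
    n_test = max(1, round(len(order) * test_size))
    test_ids, train_ids = order[:n_test], order[n_test:]

    def pick(ids, seq):
        return [seq[i] for i in ids]

    return (pick(train_ids, samples), pick(test_ids, samples),
            pick(train_ids, labels), pick(test_ids, labels))


def fit_centroids(X, y):
    """Step 2 (stand-in learner): average feature vector for each class."""
    sums, counts = {}, defaultdict(int)
    for x, label in zip(X, y):
        prev = sums.get(label, [0.0] * len(x))
        sums[label] = [s + v for s, v in zip(prev, x)]
        counts[label] += 1
    return {label: [s / counts[label] for s in sums[label]] for label in sums}


def predict_proba(model, x):
    """Step 3: convert closeness to each class centroid into class probabilities."""
    weights = {
        label: 1.0 / (1e-9 + sum((a - b) ** 2 for a, b in zip(x, centroid)))
        for label, centroid in model.items()
    }
    total = sum(weights.values())
    return {label: w / total for label, w in weights.items()}


def accuracy(y_true, y_pred):
    """Step 4: fraction of test samples whose predicted class matches the truth."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


# Hypothetical toy data: two features per sample, two well-separated classes.
X = [[10, 0], [9, 1], [11, 0], [8, 2], [10, 1],
     [0, 10], [1, 9], [0, 11], [2, 8], [1, 10]]
y = ["gut"] * 5 + ["tongue"] * 5

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
model = fit_centroids(X_train, y_train)
probabilities = [predict_proba(model, x) for x in X_test]
predictions = [max(p, key=p.get) for p in probabilities]
print(accuracy(y_test, predictions))
```

The held-out samples never reach `fit_centroids`, which is what makes the final accuracy an honest estimate of performance on unseen data.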

In [None]:
pd.read_csv("data/moving_pictures/moving_pictures_metadata.tsv",
            sep="\t",
            index_col=0,
            skiprows=[1])

In [None]:
! qiime sample-classifier classify-samples \
    --i-table data/moving_pictures/moving_pictures_table.qza \
    --m-metadata-file data/moving_pictures/moving_pictures_metadata.tsv \
    --m-metadata-column body-site \
    --p-estimator RandomForestClassifier \
    --p-random-state 123 \
    --output-dir rf_classifier

Use [QIIME 2 View](https://view.qiime2.org) to check out `rf_classifier/accuracy_results.qzv`, which presents classification accuracy results in the form of a confusion matrix, as well as [Receiver Operating Characteristic (ROC) curves](https://en.wikipedia.org/wiki/Receiver_operating_characteristic).

**Question: What other metadata can we predict with `classify-samples`?** Take a look at the columns in the sample metadata table displayed above and try some other categorical columns. Not all metadata can be easily learned by the classifier!

In [None]:
! qiime metadata tabulate \
    --m-input-file rf_classifier/predictions.qza \
    --o-visualization rf_classifier/predictions.qzv

In [None]:
! qiime metadata tabulate \
    --m-input-file rf_classifier/probabilities.qza \
    --o-visualization rf_classifier/probabilities.qzv

In [None]:
! qiime metadata tabulate \
    --m-input-file rf_classifier/test_targets.qza \
    --m-input-file rf_classifier/predictions.qza \
    --o-visualization rf_classifier/test_targets_predictions.qzv

A list of all features, and their relative importances (or feature weights or model coefficients, depending on the learning model used), will be reported in `feature_importance.qza`. Features with higher importance scores were more useful for distinguishing classes.

To see which microbial taxa correspond to each feature, open `data/moving_pictures/moving_pictures_taxonomy.qzv` side by side with `rf_classifier/feature_importance.qzv`.

**Question: What are the 5 most important taxa in the model?**

In [None]:
! qiime metadata tabulate \
    --m-input-file rf_classifier/feature_importance.qza \
    --o-visualization rf_classifier/feature_importance.qzv
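One way to answer the question without switching between two visualizations is to join the importances with the taxonomy in pandas. This is a minimal sketch that assumes both artifacts have been exported to TSV (for example with `qiime tools export`) and read into DataFrames indexed by feature ID, with an `importance` column in one and a `Taxon` column in the other; these column names are assumptions, so check the exported headers first. `top_features_with_taxonomy` is a hypothetical helper and the toy values below are made up.

```python
import pandas as pd


def top_features_with_taxonomy(importance: pd.DataFrame,
                               taxonomy: pd.DataFrame,
                               n: int = 5) -> pd.DataFrame:
    """Return the n highest-importance features with their taxonomy labels attached."""
    ranked = importance.sort_values("importance", ascending=False).head(n)
    # Left join keeps every ranked feature even if it has no taxonomy entry.
    return ranked.join(taxonomy[["Taxon"]], how="left")


# Hypothetical example data; real feature IDs are long hashes.
ids = ["f1", "f2", "f3", "f4", "f5", "f6"]
importance = pd.DataFrame(
    {"importance": [0.02, 0.30, 0.11, 0.05, 0.20, 0.01]}, index=ids
)
taxonomy = pd.DataFrame(
    {"Taxon": ["k__Bacteria; p__Firmicutes", "k__Bacteria; p__Bacteroidetes",
               "k__Bacteria; p__Proteobacteria", "k__Bacteria; p__Actinobacteria",
               "k__Bacteria; p__Fusobacteria", "k__Bacteria; p__Firmicutes"]},
    index=ids,
)
print(top_features_with_taxonomy(importance, taxonomy))
```

For real data, the exported taxonomy file can be loaded with `pd.read_csv(path, sep="\t", index_col=0)` so that the feature IDs become the index used by the join.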

### Feature selection

If `--p-optimize-feature-selection` is enabled, only the selected features (i.e., the most important features, which maximize model accuracy, as determined using [recursive feature elimination](http://scikit-learn.org/stable/modules/feature_selection.html#recursive-feature-elimination)) will be reported in `feature_importance.qza`, and all other outputs (e.g., model accuracy and predictions) come from the final, optimized model that uses this reduced feature set. This lets us see not only which features are most important (and hence used by the model), but also how well the model performs using only those features.
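Recursive feature elimination can be sketched as a loop that repeatedly drops the least important features, re-scores the remaining subset, and remembers the best subset seen so far. This is a simplified, hypothetical illustration rather than the plugin's code: in the scikit-learn approach described in the linked documentation, importances come from refitting the estimator and each subset's score comes from cross-validation, whereas here `importance_fn` and `score_fn` are placeholders, and the toy scoring rule is made up.

```python
def recursive_feature_elimination(features, importance_fn, score_fn, step=0.5):
    """Drop the least important features round by round; keep the best-scoring subset.

    importance_fn(subset) -> {feature: importance}, e.g. from a model refit on subset.
    score_fn(subset) -> float, e.g. cross-validated accuracy using only subset.
    """
    current = list(features)
    best_subset, best_score = list(current), score_fn(current)
    while len(current) > 1:
        importances = importance_fn(current)
        # Remove a fraction `step` of the remaining features, but always keep at least one.
        n_drop = min(max(1, int(len(current) * step)), len(current) - 1)
        dropped = set(sorted(current, key=lambda f: importances[f])[:n_drop])
        current = [f for f in current if f not in dropped]
        score = score_fn(current)
        if score > best_score:
            best_subset, best_score = list(current), score
    return best_subset, best_score


# Hypothetical toy setup: features "a" and "b" carry signal, the rest are noise.
informative = {"a", "b"}
fixed_importance = {"a": 0.5, "b": 0.3, "c": 0.1, "d": 0.05, "e": 0.05}


def toy_importance(subset):
    return {f: fixed_importance[f] for f in subset}


def toy_score(subset):
    # Informative features raise the score; noise features lower it slightly.
    hits = len(informative.intersection(subset))
    return 0.5 + 0.2 * hits - 0.05 * (len(subset) - hits)


print(recursive_feature_elimination("abcde", toy_importance, toy_score))
```

Because the best subset is tracked across rounds, eliminating too aggressively in a late round (here, dropping `b`) does not lose the better subset found earlier.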

In [None]:
! qiime sample-classifier classify-samples \
    --i-table data/moving_pictures/moving_pictures_table.qza \
    --m-metadata-file data/moving_pictures/moving_pictures_metadata.tsv \
    --m-metadata-column body-site \
    --p-optimize-feature-selection \
    --p-estimator RandomForestClassifier \
    --p-n-estimators 20 \
    --p-random-state 123 \
    --output-dir rf_opt_classifier

In [None]:
! qiime metadata tabulate \
    --m-input-file rf_opt_classifier/feature_importance.qza \
    --o-visualization rf_opt_classifier/feature_importance.qzv

**Question: What are the 5 most important taxa in this model?** How do they differ from the five most important taxa in the previous model?

We can use that information to filter out uninformative features from our feature table for other downstream analyses outside of `q2-sample-classifier`.

In [None]:
# Optional
! qiime feature-table filter-features \
    --i-table data/moving_pictures/moving_pictures_table.qza \
    --m-metadata-file rf_opt_classifier/feature_importance.qza \
    --o-filtered-table rf_opt_classifier/important_feature_table.qza

We can also use the `heatmap` pipeline to generate an abundance heatmap of the most important features in each sample or group. Let’s make a heatmap of the abundances of the 30 most important features in each of our sample types.

In [None]:
! qiime sample-classifier heatmap \
    --i-table data/moving_pictures/moving_pictures_table.qza \
    --i-importance rf_opt_classifier/feature_importance.qza \
    --m-sample-metadata-file data/moving_pictures/moving_pictures_metadata.tsv \
    --m-sample-metadata-column body-site \
    --p-group-samples \
    --p-feature-count 30 \
    --o-filtered-table rf_opt_classifier/important_feature_table_top_30.qza \
    --o-heatmap rf_opt_classifier/important_feature_heatmap.qzv
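Conceptually, the grouped heatmap keeps the N most important features and summarizes their abundances per sample group. The data preparation can be sketched in pandas, assuming a samples-by-features count table, a Series of importances indexed by feature ID, and a Series mapping each sample to its body site. This sketch averages within groups, which may not match how the plugin aggregates grouped samples or normalizes the table before plotting; the helper name and toy data are hypothetical.

```python
import pandas as pd


def top_features_by_group(table: pd.DataFrame, importance: pd.Series,
                          groups: pd.Series, n: int = 30) -> pd.DataFrame:
    """Keep the n most important features, then average their counts per group."""
    top = importance.sort_values(ascending=False).index[:n]
    # groupby aligns `groups` to the table's sample index.
    return table.loc[:, top].groupby(groups).mean()


# Hypothetical toy data: 4 samples x 3 features.
table = pd.DataFrame(
    {"f1": [5, 7, 0, 1], "f2": [40, 60, 2, 0], "f3": [0, 2, 30, 50]},
    index=["s1", "s2", "s3", "s4"],
)
importance = pd.Series({"f1": 0.1, "f2": 0.6, "f3": 0.3})
body_site = pd.Series({"s1": "gut", "s2": "gut", "s3": "tongue", "s4": "tongue"})
print(top_features_by_group(table, importance, body_site, n=2))
```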

**Note:** The model we trained here is a toy example containing very few samples from a single study and will probably not be useful for predicting other unknown samples. But if you have samples from one of these body sites, it could be a fun exercise to give it a spin!

## Predicting continuous data

Supervised learning models can also predict continuous metadata values of samples; we call these models regressors. For example, we may use a regressor to predict the abundance of a metabolite that will be produced by a microbial community, or a sample’s pH, temperature, or altitude as a function of the sequence variants, microbial taxa, or metabolites detected in a sample.

In this section, we will predict continuous sample data from the previous dataset and from the [ECAM study](https://doi.org/10.1126/scitranslmed.aad7121), a longitudinal study of microbiome development in US infants.

Sections:
1. Predict on previous dataset ("moving pictures")
2. Predict on ECAM dataset

### Regression on moving pictures dataset

In [None]:
! qiime sample-classifier regress-samples \
    --i-table data/moving_pictures/moving_pictures_table.qza \
    --m-metadata-file data/moving_pictures/moving_pictures_metadata.tsv \
    --m-metadata-column days-since-experiment-start \
    --p-estimator RandomForestRegressor \
    --output-dir mp_regressor

### Regression on ECAM dataset

In [None]:
pd.read_csv("data/ecam/ecam_metadata.tsv",
            sep="\t",
            index_col=0,
            skiprows=[1])

In [None]:
! qiime sample-classifier regress-samples \
    --i-table data/ecam/ecam_table.qza \
    --m-metadata-file data/ecam/ecam_metadata.tsv \
    --m-metadata-column month \
    --p-estimator RandomForestRegressor \
    --output-dir ecam_regressor

**Question: How differently did the regression models perform on the moving pictures data vs. on the ECAM data?** Why might that be the case?
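When comparing the two regressors, it helps to know what common summary statistics measure. The `regress-samples` accuracy visualizations report metrics along these lines, but the helper below is a hypothetical plain-Python sketch, not the plugin's code: mean squared error (MSE) is the average squared distance between prediction and truth, and the Pearson correlation r measures how closely predictions track the truth linearly.

```python
import math


def regression_metrics(y_true, y_pred):
    """Return (mean squared error, Pearson r) for true vs. predicted values."""
    n = len(y_true)
    mse = sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / n
    mean_t, mean_p = sum(y_true) / n, sum(y_pred) / n
    cov = sum((t - mean_t) * (p - mean_p) for t, p in zip(y_true, y_pred))
    spread_t = math.sqrt(sum((t - mean_t) ** 2 for t in y_true))
    spread_p = math.sqrt(sum((p - mean_p) ** 2 for p in y_pred))
    return mse, cov / (spread_t * spread_p)


# Hypothetical values: a constant offset inflates MSE but leaves r unchanged.
truth = [1, 2, 3, 4]
print(regression_metrics(truth, [1, 2, 3, 4]))
print(regression_metrics(truth, [2, 3, 4, 5]))
```

Note that MSE is expressed in squared units of the target (days squared for `days-since-experiment-start`, months squared for `month`), so it is only directly comparable within one target, while r is unitless.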

## Nested cross-validation

In the examples above, we split the data sets into training and test sets for model training and testing. It is essential that we keep a test set that the model has never seen before for validating model performance. But what if we want to predict target values for every sample in a data set? For that, we use nested cross-validation (NCV). This can be valuable in a number of cases, e.g., for identifying potentially mislabeled samples (those that are classified incorrectly during NCV) or for assessing estimator variance (since multiple models are trained during NCV, we can look at the variance in their accuracy).
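Nested cross-validation can be sketched as two loops: an outer k-fold loop that produces a prediction for every sample from a model that never saw it, and an inner loop that, when hyperparameters are tuned, chooses them using only each outer training split. In this minimal illustration a 1-D k-nearest-neighbours classifier and its `n_neighbors` value stand in for the random forest and its parameters; the helper names and toy data are hypothetical, and the plugin's own scikit-learn-based implementation differs in detail.

```python
import random


def k_folds(n, k, seed=123):
    """Shuffle sample indices and deal them into k roughly equal folds."""
    order = list(range(n))
    random.Random(seed).shuffle(order)
    return [order[i::k] for i in range(k)]


def knn_predict(X_train, y_train, x, n_neighbors):
    """Majority vote among the n_neighbors closest training points (1-D data)."""
    nearest = sorted(range(len(X_train)), key=lambda i: abs(X_train[i] - x))[:n_neighbors]
    votes = [y_train[i] for i in nearest]
    return max(sorted(set(votes)), key=votes.count)


def inner_cv_score(X, y, n_neighbors, inner_k):
    """Inner loop: count correct validation predictions within the outer training split."""
    correct = 0
    for val_ids in k_folds(len(X), inner_k, seed=7):
        held_out = set(val_ids)
        fit_ids = [i for i in range(len(X)) if i not in held_out]
        X_fit = [X[i] for i in fit_ids]
        y_fit = [y[i] for i in fit_ids]
        correct += sum(knn_predict(X_fit, y_fit, X[i], n_neighbors) == y[i] for i in val_ids)
    return correct


def nested_cv_predictions(X, y, param_grid, outer_k=5, inner_k=3):
    """Outer loop: every sample is predicted by a model that never saw it."""
    predictions = [None] * len(X)
    for test_ids in k_folds(len(X), outer_k):
        held_out = set(test_ids)
        train_ids = [i for i in range(len(X)) if i not in held_out]
        X_tr = [X[i] for i in train_ids]
        y_tr = [y[i] for i in train_ids]
        # The hyperparameter is chosen from the outer training split only.
        best = max(param_grid, key=lambda p: inner_cv_score(X_tr, y_tr, p, inner_k))
        for i in test_ids:
            predictions[i] = knn_predict(X_tr, y_tr, X[i], best)
    return predictions


# Hypothetical 1-D toy data with two clearly separated classes.
X = [0, 1, 2, 3, 4, 10, 11, 12, 13, 14]
y = ["a"] * 5 + ["b"] * 5
print(nested_cv_predictions(X, y, param_grid=[1, 3]))
```

Samples whose out-of-fold prediction disagrees with their recorded label are the candidates for the mislabeling check mentioned above.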

In [None]:
! qiime sample-classifier classify-samples-ncv \
    --i-table data/moving_pictures/moving_pictures_table.qza \
    --m-metadata-file data/moving_pictures/moving_pictures_metadata.tsv \
    --m-metadata-column body-site \
    --p-estimator RandomForestClassifier \
    --p-random-state 123 \
    --output-dir moving_pictures_ncv

In [None]:
! qiime sample-classifier confusion-matrix \
    --i-predictions moving_pictures_ncv/predictions.qza \
    --i-probabilities moving_pictures_ncv/probabilities.qza \
    --m-truth-file data/moving_pictures/moving_pictures_metadata.tsv \
    --m-truth-column body-site \
    --o-visualization moving_pictures_ncv/ncv_confusion_matrix.qzv

In [None]:
! qiime sample-classifier regress-samples-ncv \
    --i-table data/ecam/ecam_table.qza \
    --m-metadata-file data/ecam/ecam_metadata.tsv \
    --m-metadata-column month \
    --p-estimator RandomForestRegressor \
    --p-random-state 123 \
    --output-dir ecam_ncv

In [None]:
! qiime sample-classifier scatterplot \
    --i-predictions ecam_ncv/predictions.qza \
    --m-truth-file data/ecam/ecam_metadata.tsv \
    --m-truth-column month \
    --o-visualization ecam_ncv/ecam_scatterplot.qzv

## Notes

### Warning

Just as with any statistical method, the actions described in this plugin require adequate sample sizes to achieve meaningful results. As a rule of thumb, a minimum of [approximately 50 samples](http://scikit-learn.org/stable/tutorial/machine_learning_map/index.html) should be provided. Categorical metadata columns that are used as classifier targets should have a minimum of 10 samples per unique value, and continuous metadata columns that are used as regressor targets should not contain many outliers or grossly uneven distributions. Smaller counts will result in inaccurate models, and may result in errors.
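The rule-of-thumb thresholds above are easy to check before launching a long run. This is a minimal pandas sketch that assumes the metadata has been loaded into a DataFrame as earlier in this notebook; `check_classifier_target` is a hypothetical helper, and its defaults are just the rule-of-thumb values from the paragraph above, not limits enforced by the plugin.

```python
import pandas as pd


def check_classifier_target(metadata: pd.DataFrame, column: str,
                            min_samples: int = 50,
                            min_per_class: int = 10) -> list:
    """List rule-of-thumb violations for a categorical target column."""
    values = metadata[column].dropna()
    problems = []
    if len(values) < min_samples:
        problems.append(f"only {len(values)} labelled samples (< {min_samples})")
    counts = values.value_counts()
    for label, n in counts[counts < min_per_class].items():
        problems.append(f"class {label!r} has {n} samples (< {min_per_class})")
    return problems


# Hypothetical metadata: 40 samples, one class far too small.
toy = pd.DataFrame({"body-site": ["gut"] * 35 + ["tongue"] * 5})
for problem in check_classifier_target(toy, "body-site"):
    print(problem)
```

To check the real data, pass the DataFrame returned by `pd.read_csv("data/moving_pictures/moving_pictures_metadata.tsv", sep="\t", index_col=0, skiprows=[1])` together with a column name such as `"body-site"`.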


### Best practices for `q2-sample-classifier`

As this tutorial has demonstrated, `q2-sample-classifier` can be extremely powerful for feature selection and metadata prediction. However, with power comes responsibility. Unsuspecting users risk committing grave errors, particularly from overfitting and data leakage. What follows is an (inevitably incomplete) list of ways that users can misuse this plugin and obtain misleading results. Do not do these things. More extensive guides exist for avoiding data leakage and overfitting in general, so this list focuses on bad practices that are particular to this plugin and to biological data analysis.

1. Data leakage occurs whenever a learning model learns (often inadvertently) about test sample data, leading to unduly high performance estimates.

    - Model accuracy should always be assessed on test data that has never been seen by the learning model. The pipelines and nested cross-validation methods in `q2-sample-classifier` (including those described in this tutorial) do this by default. However, care must be taken when using the `fit-*` and `predict-*` methods independently.

    - In some situations, technical replicates could be problematic and lead to pseudo-data leakage, depending on experimental design and technical precision. If in doubt, group your feature table to average technical replicates, or filter technical replicates from your data prior to supervised learning analysis.

2. Overfitting occurs whenever a learning model fits the training data so closely that it cannot generalize well to other data sets. This is particularly problematic on small data sets and whenever input data have been contorted in inappropriate ways.

    - If the learning model is intended to predict values from data that is produced in batches (e.g., to make a diagnosis on microbiome sequence data that will be produced in a future analysis), consider incorporating multiple batches in your training data to reduce the likelihood that learning models will overfit on batch effects and similar noise.

    - Similarly, be aware that batch effects can strongly impact performance, particularly if they covary with the target values that you are attempting to predict. For example, if you wish to classify whether samples belong to one of two different groups and those groups were analyzed on separate sequencing runs (for microbiome amplicon sequence data), training a classifier on these data will likely lead to inaccurate results that will not generalize to other data sets.