# How to Train an Autoencoder for single-cell RNA-seq data
<img src="imgs/CLAIM_VanillaAE_basic.png" class="center" width=500/>

**Authorship:**
Adam Klie, *01/22/2022*
***
**Description:**
Notebook template for building and training an autoencoder for single-cell omics analysis

**Notes:**
 - **Environment:** You need a Jupyter `python3` kernel with PyTorch and scikit-learn installed, along with `torchinfo` and `livelossplot`, which are used below. [See instructions for setup here.](https://www.notion.so/Autoencoder-Workshop-73d10091ac014f8c966a503e02759b11)
 - **GPU Usage:** The default data and model used below are lightweight enough to be trained on a CPU, but if you'd like to train on a larger dataset with more parameters, I recommend opening a GPU-backed notebook. If you are using the `ml_env` kernel described in the environment setup above, you simply need to run the following after logging into the cluster:
 
```bash
module load cuda10.2
jupyter-submit -p carter-gpu -A carter-gpu -t 05-00:00:00 -c 4 -m 16G -g 1 -I
```
***

# Set up packages
Here we load the typical base packages we will use throughout the exercise. We will also import the PyTorch library and check whether we are on a GPU node.

In [None]:
# Classic imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Autoreload extension
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
    
%autoreload 2

If you are on a GPU node, `torch.cuda.is_available()` returns `True`, and the cell below also prints which GPU device is in use and how many GPUs are visible.

In [None]:
import torch
print("Using a GPU? {}".format(torch.cuda.is_available()))
if torch.cuda.is_available():
    print("Device number [0-7]: {}".format(torch.cuda.current_device()))
    print("Device count: {}".format(torch.cuda.device_count()))

In [None]:
# You can use this variable to keep track of how many epochs you've trained a model for
epochs_trained = 0

# Load dataset
Here we load the preprocessed **pbmc3k** dataset of Peripheral Blood Mononuclear Cells (PBMCs), freely available from 10x Genomics. There were originally 2,700 single cells, sequenced on the Illumina NextSeq 500. Here we load the raw counts of the highly variable genes in the high-quality cells. See the `Collect_Datasets_and_Preprocess.ipynb` and `Collect_Cell_Type_Labels.ipynb` notebooks for more details on the dataset and the preprocessing steps applied.

In [None]:
# Load the raw counts of highly variable genes (genes x cells). We have fewer than 2,700 cells due to previous filtering
raw_counts = pd.read_csv("data/pbmc3k_raw_var_genes.tsv", index_col=0, sep="\t")
num_genes = raw_counts.shape[0] # rows
num_cells = raw_counts.shape[1] # columns
print("Dataset contains {} genes across {} cells".format(num_genes, num_cells))

## Standardize inputs
Here we scale each gene's expression to mean 0 and standard deviation 1 across all cells. This typically improves convergence during training.
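
To make the scaling explicit, here is a minimal NumPy sketch of the same per-gene z-score. The helper `standardize_genes` and the toy matrix are hypothetical, for illustration only; the cells below use scikit-learn's `StandardScaler`, which likewise uses the population standard deviation (`ddof=0`) and leaves zero-variance features unscaled.

```python
import numpy as np

def standardize_genes(counts_genes_by_cells):
    """Z-score each gene across cells; returns a cells x genes array.

    Mirrors StandardScaler's defaults: population std (ddof=0), and genes with
    zero variance are centered but not rescaled (scale of 1).
    """
    X = np.asarray(counts_genes_by_cells, dtype=float).T  # transpose to cells x genes
    mean = X.mean(axis=0)
    std = X.std(axis=0)            # ddof=0, like StandardScaler
    std[std == 0] = 1.0            # avoid division by zero for constant genes
    return (X - mean) / std

# Toy example: 3 genes x 4 cells; the last gene is constant
toy = np.array([[0, 2, 4, 6],
                [1, 1, 3, 3],
                [5, 5, 5, 5]])
Z = standardize_genes(toy)
print(Z.shape)  # (4, 3)
```

Each column of `Z` (one gene) now has mean 0 and, unless the gene is constant, standard deviation 1.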

In [None]:
from sklearn.preprocessing import StandardScaler

In [None]:
# StandardScaler expects a samples x features matrix, so we transpose to cells x genes
scaler = StandardScaler()
scaler.fit(raw_counts.T)
scaled_counts = scaler.transform(raw_counts.T)
scaled_counts.shape, scaled_counts.mean(axis=0), scaled_counts.std(axis=0)  # Double check scaling was done correctly

## Instantiate the dataloader
[DataLoaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) are fundamental PyTorch objects that connect the data you want to train on to the model you want to train. A DataLoader is essentially a Python iterable that you loop over to pull "batches" of data at a time. These batches are passed to the model during training.
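
To see what "batches" means concretely, here is a small sketch that loops over a DataLoader built on a random toy tensor (a stand-in for `scaled_counts`, not the real data) and records each batch size. With 600 cells and `batch_size=256`, the final batch holds the 88 leftover cells, because `drop_last` defaults to `False`.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy stand-in for the scaled data: 600 cells x 50 genes of random values
X = torch.randn(600, 50)
toy_loader = DataLoader(TensorDataset(X), batch_size=256, shuffle=False)

batch_sizes = []
for (batch,) in toy_loader:  # each item is a 1-element list because the dataset wraps one tensor
    batch_sizes.append(batch.shape[0])
print(batch_sizes)  # [256, 256, 88]
```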

In [None]:
from torch.utils.data import TensorDataset, DataLoader

In [None]:
# Build a TensorDataset from the scaled array; PyTorch works with Tensors, not NumPy arrays.
# .float() casts from float64 to float32, the dtype of the model's weights
dataset = TensorDataset(torch.from_numpy(scaled_counts).float())

In [None]:
# Build a simple DataLoader from the Dataset object
# shuffle=False keeps cells in a fixed order; shuffle=True is the more common choice for training
loader = DataLoader(dataset, batch_size=256, shuffle=False)

In [None]:
# Check dimensions for correctness (should print n_cells x torch.Size([n_genes]))
print("Dimensions of training set: {} x {}".format(len(loader.dataset), loader.dataset[0][0].shape))

# The autoencoder model
Here we initialize a predefined autoencoder architecture (see `scripts/autoencoders.py`). The encoder takes an input vector and outputs a lower-dimensional latent embedding; the decoder does the opposite, reconstructing the input vector from the latent embedding. The decoder architecture mirrors the encoder's. The encoder's input size equals the number of features (highly variable genes), and its output is a latent embedding of size 10.

We initialize the weights with Kaiming (He) uniform initialization (see `init_weights` in `scripts/utils.py`). Each linear layer uses a ReLU activation.
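
For orientation, here is a hypothetical sketch of what a minimal autoencoder of this kind and a Kaiming initializer could look like. `TinyAE`, `kaiming_init`, and the hidden width of 128 are illustrative assumptions, not the contents of `scripts/autoencoders.py` or `scripts/utils.py`. One deliberate difference: this sketch leaves the final decoder layer linear so it can output the negative values present in standardized data; check the real `VanillaAE` for what it does.

```python
import torch
import torch.nn as nn

class TinyAE(nn.Module):
    """Hypothetical stand-in for VanillaAE: genes -> hidden -> latent -> hidden -> genes."""
    def __init__(self, n_genes, hidden=128, latent=10):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(n_genes, hidden), nn.ReLU(),
            nn.Linear(hidden, latent), nn.ReLU(),
        )
        # Mirror of the encoder; the last layer has no activation (an assumption of this sketch)
        self.decoder = nn.Sequential(
            nn.Linear(latent, hidden), nn.ReLU(),
            nn.Linear(hidden, n_genes),
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

def kaiming_init(module):
    """Kaiming (He) uniform init for Linear weights, zeros for biases."""
    if isinstance(module, nn.Linear):
        nn.init.kaiming_uniform_(module.weight, nonlinearity="relu")
        nn.init.zeros_(module.bias)

ae = TinyAE(n_genes=50)
ae.apply(kaiming_init)  # .apply visits every submodule recursively
x = torch.randn(4, 50)
print(ae.encoder(x).shape, ae(x).shape)  # torch.Size([4, 10]) torch.Size([4, 50])
```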

<div class="alert alert-block alert-info">
<b>TODO</b>

We provide the most basic autoencoder architecture here. This is your opportunity to explore hyperparameters, e.g., the number of hidden layers and the width of each layer, or to try something completely different!
</div>

In [None]:
# Load predefined model and weight initializer
from scripts.utils import init_weights
from scripts.autoencoders import VanillaAE

In [None]:
# Instantiate the model and move it to the GPU if available
model = VanillaAE(num_genes)
model.apply(init_weights)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Executing the model on:", device)
model.to(device)

Let's check what our model looks like, using a handy summary function.

In [None]:
from torchinfo import summary

In [None]:
# We assume a batch size of 256.
# summary() needs the full input size, including the batch dimension
summary(model, input_size=(256, num_genes))
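
The "Total params" figure reported by `torchinfo` is just the sum of weights and biases over the linear layers (ReLU has no parameters). The sketch below does that arithmetic by hand; the layer widths (2,000 genes, 128 hidden units, 10 latent dimensions) are hypothetical, so substitute `num_genes` and the actual widths of `VanillaAE` to check the summary above.

```python
def linear_param_count(widths):
    """Weights + biases for a stack of nn.Linear layers with the given widths."""
    return sum(w_in * w_out + w_out for w_in, w_out in zip(widths[:-1], widths[1:]))

# Hypothetical widths: 2,000 genes -> 128 hidden -> 10 latent, mirrored in the decoder
encoder_params = linear_param_count([2000, 128, 10])
decoder_params = linear_param_count([10, 128, 2000])
print(encoder_params, decoder_params, encoder_params + decoder_params)
# 257418 259408 516826
```

Almost all parameters sit in the two layers that touch the gene dimension, which is why the number of input genes dominates model size.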

Next, let's test our autoencoder with its initialized parameters. You can use the following code block to check that the outputs have the shapes you expect. By default, the encoded dimension should be 10 and the decoded dimension should equal the number of highly variable genes.

In [None]:
# Grab a few random cells to test on
indexes = np.random.choice(scaled_counts.shape[0], size=5)
random_cells = torch.from_numpy(scaled_counts[indexes]).float().to(device)

# Feed through the encoder to get the bottleneck size
encoded_outputs = model.encoder(random_cells).squeeze(dim=1)

# Feed through the encoder and decoder to get the full output size
outputs = model(random_cells).squeeze(dim=1)
print("Encoder output dimension: {}\nDecoder output dimension: {}".format(encoded_outputs.shape[1], outputs.shape[1]))

# Set training optimization parameters

Before we train our model, we need to choose a loss function to optimize and an algorithm for conducting that optimization. Here we use **MSE loss** to measure how well the model reconstructs the standardized gene expression values, and optimize using **Adam (adaptive moment estimation)**. MSE and Adam are popular in deep learning (don't worry, I'm no DanQ), but feel free to play with others or define your own!
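
For intuition about what `torch.optim.Adam` does on each step, here is a NumPy sketch of the textbook Adam update for one parameter vector. It is illustrative only: PyTorch's implementation places `eps` slightly differently and supports extras such as weight decay. The defaults shown (`beta1=0.9`, `beta2=0.999`, `eps=1e-8`) match PyTorch's, and `lr=0.01` matches the optimizer cell in this notebook.

```python
import numpy as np

def adam_step(theta, grad, m, v, t, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
    """One textbook Adam update; returns new (theta, m, v). The step counter t starts at 1."""
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients (1st moment)
    v = beta2 * v + (1 - beta2) * grad ** 2     # running mean of squared gradients (2nd moment)
    m_hat = m / (1 - beta1 ** t)                # bias correction for the zero initialization
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

theta = np.zeros(2)
grad = np.array([2.0, -0.5])
theta, m, v = adam_step(theta, grad, np.zeros(2), np.zeros(2), t=1)
print(theta)  # approximately [-0.01, 0.01]
```

On the first step each parameter moves by roughly `lr` in the direction opposite its gradient, regardless of the gradient's magnitude; this per-parameter normalization is what makes Adam relatively forgiving of poorly scaled gradients.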

<div class="alert alert-block alert-info">
<b>TODO</b>
    
Try out different optimization strategies and loss functions. This is general to neural networks and not specific to autoencoders, but still an important set of hyperparameters to consider.

 - PyTorch optimizers: https://pytorch.org/docs/stable/optim.html
 - PyTorch loss functions: https://pytorch.org/docs/stable/nn.html#loss-functions
</div>

In [None]:
import torch.nn as nn

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss(reduction='sum')
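
The `reduction='sum'` argument matters for how you read the loss curve and pick a learning rate. The NumPy sketch below (a hypothetical helper that mimics PyTorch's `'sum'` and `'mean'` reductions) shows that the summed loss equals the mean loss times the number of elements, so with `'sum'` the loss and its gradients grow with batch size times number of genes, and the smaller last batch contributes less.

```python
import numpy as np

def mse(pred, target, reduction="mean"):
    """MSE with PyTorch-style reductions: 'sum' adds every squared error, 'mean' averages them."""
    sq_err = (np.asarray(pred, dtype=float) - np.asarray(target, dtype=float)) ** 2
    if reduction == "sum":
        return sq_err.sum()
    if reduction == "mean":
        return sq_err.mean()
    raise ValueError(f"unknown reduction: {reduction}")

pred = np.zeros((2, 3))      # a batch of 2 cells x 3 genes
target = np.ones((2, 3))
print(mse(pred, target, "sum"), mse(pred, target, "mean"))  # 6.0 1.0
```

If you switch to `reduction='mean'`, expect to need a larger learning rate for comparable step sizes with plain SGD; Adam is less sensitive because it normalizes gradient scale.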

# Train the model
Here we actually train the model by minimizing the loss between the autoencoder's inputs and their reconstructions. We use `train_autoencoder`, a training function designed for autoencoders that can be found in `scripts/train.py`. We will also use the livelossplot package to visualize the loss during training. If all goes well, you should see the loss decrease with each epoch (one full pass through the dataset), something along the lines of:

![loss_plot](imgs/example_loss.png)

In [None]:
from scripts.train import train_autoencoder
import time

<div class="alert alert-block alert-info">
<b>TODO: Make-it-train!</b>
    
Define the number of epochs and how often to update the loss plot (default every 10 epochs). Note that updating too often will slow down training.
</div>

In [None]:
num_epochs = 100
plot_frequency = 10

In [None]:
# Make-it-train!
tic = time.perf_counter()
loss_history, _ = train_autoencoder(model=model, 
                                    dataloader=loader, 
                                    criterion=criterion, 
                                    optimizer=optimizer,
                                    device=device,
                                    num_epoch=num_epochs,
                                    plot_frequency=plot_frequency)
toc = time.perf_counter()
epochs_trained += num_epochs
print(f"Trained {num_epochs:d} epochs in {toc - tic:0.4f} seconds")
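
If you want to see what happens inside a training call, here is a minimal sketch of one epoch of autoencoder training on toy data. `train_one_epoch` is a hypothetical helper, not the code in `scripts/train.py`, which may differ (e.g. in how it logs and plots). The key autoencoder-specific detail is that the input batch is also the target.

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

def train_one_epoch(model, loader, criterion, optimizer, device):
    """Hypothetical sketch of one epoch; returns the summed loss over all batches."""
    model.train()
    epoch_loss = 0.0
    for (batch,) in loader:
        batch = batch.float().to(device)
        optimizer.zero_grad()                     # clear gradients from the previous step
        reconstruction = model(batch)
        loss = criterion(reconstruction, batch)   # the target is the input itself
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss

# Toy run: random data and a tiny bottlenecked model, for illustration only
torch.manual_seed(0)
X = torch.randn(300, 20)
toy_model = nn.Sequential(nn.Linear(20, 8), nn.ReLU(), nn.Linear(8, 20))
toy_loader = DataLoader(TensorDataset(X), batch_size=64, shuffle=True)
opt = torch.optim.Adam(toy_model.parameters(), lr=0.01)
crit = nn.MSELoss(reduction="sum")
history = [train_one_epoch(toy_model, toy_loader, crit, opt, torch.device("cpu")) for _ in range(20)]
print(f"first epoch: {history[0]:.1f}, last epoch: {history[-1]:.1f}")
```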

## Save trained model
If you are happy with your model, save its parameters. You can always load it later for interpretation or more training.

In [None]:
import os

In [None]:
os.makedirs("models", exist_ok=True)
torch.save(model.state_dict(), "models/prelim_model_{}.pt".format(epochs_trained))
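
A saved `state_dict` only holds parameters, so reloading means rebuilding the same architecture first. The sketch below round-trips a toy `nn.Linear` through a temporary directory; in this notebook you would instead rebuild `VanillaAE(num_genes)` and load the `models/prelim_model_<epochs>.pt` file written above. Passing `map_location` lets a model trained on a GPU be loaded on a CPU-only node.

```python
import os
import tempfile
import torch
import torch.nn as nn

toy = nn.Linear(5, 3)
with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "toy_model_100.pt")
    torch.save(toy.state_dict(), path)

    # Later: rebuild the same architecture, then load the saved parameters into it
    restored = nn.Linear(5, 3)
    restored.load_state_dict(torch.load(path, map_location="cpu"))

same = all(torch.equal(a, b) for a, b in zip(toy.state_dict().values(), restored.state_dict().values()))
print(same)  # True
```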

# Visualize latent space
We are now ready to investigate the latent space our model has learned. We leave coming up with an awesome new analysis to you, but we wrote some code to generate a two-dimensional visualization of your latent space using both PCA and UMAP. We added cell-type labels from the [Seurat guided clustering tutorial](https://satijalab.org/seurat/articles/pbmc3k_tutorial.html). Do you see separation between Seurat's cell-type labels?
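
As a reference point for the PCA half of the visualization, here is a NumPy sketch that projects a latent embedding onto its top two principal components. `pca_2d` and the toy embedding are hypothetical; `visualize()` may use scikit-learn's PCA instead, and the UMAP half (which needs a separate package) is not sketched here.

```python
import numpy as np

def pca_2d(latent):
    """Project an (n_cells x latent_dim) embedding onto its top two principal components."""
    centered = latent - latent.mean(axis=0)
    # Rows of vt are principal axes, ordered by decreasing singular value
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T

rng = np.random.default_rng(0)
toy_latent = rng.normal(size=(100, 10)) * np.arange(10, 0, -1)  # column variances decrease
coords = pca_2d(toy_latent)
print(coords.shape)  # (100, 2)
```

Each row of `coords` is one cell; coloring those points by Seurat's labels gives the same kind of plot as the PCA panel.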

<div class="alert alert-block alert-info">
<b>TODO: Visualize</b>
    
Use and modify the `visualize()` function to plot our model's embedding.
</div>

In [None]:
from scripts.utils import visualize

In [None]:
# Encode every cell, then detach from the autograd graph and move to the CPU for plotting
latent_data = model.encoder(loader.dataset.tensors[0].float().to(device)).detach().cpu().numpy()

In [None]:
# Keep only the barcode before the "-" suffix (e.g. "AAACATACAACCAC-1" -> "AAACATACAACCAC")
cell_ids = [col[0] for col in raw_counts.columns.str.split("-")]

In [None]:
visualize(latent_embedding=latent_data, cellids=cell_ids, metadata_file="data/pbmc3k_SeuratMetadata.tsv")

## So how'd you do?

Since we used Scanpy to download the data, we will start by comparing against Scanpy. Here's what Scanpy's tutorial produces with 40 PCs, 10 nearest neighbors and default UMAP settings. It uses the same preprocessing strategy we used for this tutorial.

![scanpy_pbmc3k](imgs/pbmc3k_100_1.png)

What about Seurat? It uses a slightly different pipeline. These are the labels you used for your visualization as well!

![seurat_pbmc3k](imgs/pbmc3k_Seurat.png)

# References