## Tutorial of DeepAdapter
### A self-adaptive and versatile tool for eliminating multiple undesirable variations from transcriptome
In this notebook, you will learn how to re-train DeepAdapter with the example dataset.

## 1. Installation and requirements
### 1.1. Installation
To run locally, please open a terminal window and download the code with:
```sh
$ # Clone this repository to your local computer
$ git clone https://github.com/mjDelta/DeepAdapter.git
$ cd DeepAdapter
$ # create a new conda environment
$ conda create -n deepAdapter python=3.9
$ # activate environment
$ conda activate deepAdapter
$ # Install dependencies
$ pip install -r requirements.txt
$ # Launch jupyter notebook
$ jupyter notebook
```
### 1.2. Download datasets
Please download the open datasets from [Zenodo](https://zenodo.org/records/10494751).
These datasets were collected from the literature to demonstrate multiple unwanted variations, including:
* batch datasets: LINCS-DToxS ([van Hasselt et al. Nature Communications, 2020](https://www.nature.com/articles/s41467-020-18396-7)) and Quartet project ([Yu, Y. et al. Nature Biotechnology, 2023](https://www.nature.com/articles/s41587-023-01867-9)).
* platform datasets: profiles from microarray ([Iorio, F. et al. Cell, 2016](https://www.cell.com/cell/pdf/S0092-8674(16)30746-2.pdf)) and RNA-seq ([Ghandi, M. et al. Nature, 2019](https://www.nature.com/articles/s41586-019-1186-3)).
* purity datasets: profiles from cancer cell lines ([Ghandi, M. et al. Nature, 2019](https://www.nature.com/articles/s41586-019-1186-3)) and tissues ([Weinstein, J.N. et al. Nature genetics, 2013](https://www.nature.com/articles/ng.2764)).

After downloading, place the datasets in the `data/` directory, which sits at the same level as this tutorial.
* batch datasets: `data/batch_data/`
* platform datasets: `data/platform_data/`
* purity datasets: `data/purity_data/`
  
**Putting datasets in the right directory is important for loading the example datasets successfully.**
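
Before loading anything, it can help to confirm the layout. The following is a minimal sketch using only the Python standard library. The helper `missing_data_dirs` is hypothetical and not part of DeepAdapter, and it only checks that the three directories listed above exist, not that their contents are complete.

```python
from pathlib import Path

# Directory names taken from the list above.
EXPECTED_DIRS = ("batch_data", "platform_data", "purity_data")


def missing_data_dirs(root="data"):
    """Return the expected dataset directories that are absent under `root`."""
    root = Path(root)
    return [name for name in EXPECTED_DIRS if not (root / name).is_dir()]


missing = missing_data_dirs()
if missing:
    print("Missing dataset directories:", ", ".join(missing))
else:
    print("All dataset directories found.")
```

Run it from the directory that contains this tutorial so that the relative path `data/` resolves correctly.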

To execute a cell, please press Shift+Enter.

## 2. Load the datasets and preprocess
### 2.1. Load the modules
There are three modules in DeepAdapter:
* `models`: network structure and training process are defined here.
* `utils`: triplet, decomposition, and other utilities are defined here.
* `params`: you can revise the parameter settings here.

We use the default parameters `dl_params` for this tutorial.

In [None]:
%load_ext autoreload
%autoreload 2

import os, sys
import pandas as pd
import numpy as np

from utils import data_utils as DT
from utils import utils as UT
import utils.triplet as TRP
from models.trainer import Trainer
from models.data_loader import TransData, DataLoader
from models.dl_utils import AE, FBatch


### 2.2. Load the demonstrated datasets
We utilize Batch-LINCS for demonstration. To load the datasets with platform and purity variations, please download them from Zenodo (https://zenodo.org/records/10494751).
  * In the tutorial, we have **data** for gene expression, **batches** for unwanted variations, and **donors** for biological signals.
  * To train DeepAdapter on your own data, please refer to `DeepAdapter-YourOwnData-Tutorial.ipynb`.

In [None]:
loadTransData = DT.LoadTransData()
data, batches, wells, donors, infos, test_infos = loadTransData.load_lincs_lds1593()
ids = np.arange(len(data))

### 2.3. Preprocess the transcriptomic data
The gene expression profiles are preprocessed by sample normalization, gene ranking, and log normalization. Let $S_i = \sum_l x_{i l}$ denote the total expression of sample $i$, summed over all genes $l$. In sample normalization, we divide every sample by its $S_i$ and multiply by a constant $10^4$ ([Xiaokang Yu et al. Nature communications, 2023](https://www.nature.com/articles/s41467-023-36635-5)):
$$x_{i l} = \frac{x_{i l}}{S_i} 10^4.$$
Then, we sort genes by their expression levels and perform the log transformation $x_{i l} = \log {(x_{i l} + 1)}$.

In [None]:
prepTransData = DT.PrepTransData()
raw_df = prepTransData.sample_norm(data)
raw_df, sorted_cols = prepTransData.sort_genes_sgl_df(raw_df)
input_arr = prepTransData.sample_log(raw_df)
bat2label, label2bat, unwanted_labels, unwanted_onehot = prepTransData.label2onehot(batches)
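
The three preprocessing steps described in this section can be sketched with pandas and NumPy. This is an illustration of the formulas only, not the implementation inside `DT.PrepTransData`. The function names are hypothetical, and ordering genes by their mean expression (highest first) is an assumption about what "sorting genes by their expression levels" means.

```python
import numpy as np
import pandas as pd


def normalize_samples(df, scale=1e4):
    """Divide each sample (row) by its total expression S_i, then multiply by `scale`."""
    totals = df.sum(axis=1)
    return df.div(totals, axis=0) * scale


def rank_genes(df):
    """Order gene columns by mean expression, highest first (assumed ranking rule)."""
    order = df.mean(axis=0).sort_values(ascending=False).index
    return df[order], list(order)


def log_transform(df):
    """Apply x -> log(x + 1) and return a NumPy array."""
    return np.log1p(df.to_numpy())


# Toy matrix: 2 samples (rows) x 3 genes (columns).
toy = pd.DataFrame([[1.0, 3.0, 6.0], [2.0, 2.0, 4.0]], columns=["g1", "g2", "g3"])
normed = normalize_samples(toy)
ranked, cols = rank_genes(normed)
logged = log_transform(ranked)

print(cols)                  # gene order after ranking
print(normed.sum(axis=1))    # every sample now sums to about 10^4
print(logged.shape)
```

After sample normalization, every row sums to $10^4$, so the samples become comparable regardless of sequencing depth before the log transformation is applied.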

## 3. Train DeepAdapter
### 3.1. Adjust DeepAdapter's parameters
The parameters for DeepAdapter are as follows (**Note: you can open `params/dl_params.py` and revise the parameters.**):
* **epochs**: the total training epochs of DeepAdapter, default = $150000$
* **ae_epochs**: the warmup epochs of autoencoder in DeepAdapter, default = $400$
* **batch_epochs**: the warmup epochs of discriminator in DeepAdapter, default = $50$
* **batch_size**: the batch size of dataloader, default = $256$
* **hidden_dim**: the hidden units of autoencoder in DeepAdapter, default = $256$
* **z_dim**: the latent units of autoencoder in DeepAdapter, default = $128$
* **drop**: the dropout rate of DeepAdapter, default = $0.3$
* **lr_lower_ae**: the lower learning rate of autoencoder in DeepAdapter, default = $1 \times 10^{-5}$
* **lr_upper_ae**: the upper learning rate of autoencoder in DeepAdapter, default = $5 \times 10^{-4}$
* **lr_lower_batch**: the lower learning rate of discriminator in DeepAdapter, default = $1 \times 10^{-5}$
* **lr_upper_batch**: the upper learning rate of discriminator in DeepAdapter, default = $5 \times 10^{-4}$

In [None]:
from params import dl_params as DLPARAM
net_args = DLPARAM.load_dl_params()
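
The parameters are read by attribute elsewhere in this tutorial (for example `net_args.hidden_dim`), so a quick sanity check after editing them can catch obvious mistakes. The sketch below uses `types.SimpleNamespace` as a stand-in for the object returned by `load_dl_params()`. Both `default_params` and `check_params` are hypothetical helpers, and the rule that warmup epochs should not exceed the total epochs is an assumption, not something DeepAdapter enforces.

```python
from types import SimpleNamespace


def default_params():
    """Stand-in holding the defaults listed above (not DeepAdapter's own loader)."""
    return SimpleNamespace(
        epochs=150000, ae_epochs=400, batch_epochs=50, batch_size=256,
        hidden_dim=256, z_dim=128, drop=0.3,
        lr_lower_ae=1e-5, lr_upper_ae=5e-4,
        lr_lower_batch=1e-5, lr_upper_batch=5e-4,
    )


def check_params(args):
    """Raise ValueError if the parameter set is obviously inconsistent."""
    if not 0.0 <= args.drop < 1.0:
        raise ValueError("drop must be in [0, 1)")
    if args.lr_lower_ae > args.lr_upper_ae or args.lr_lower_batch > args.lr_upper_batch:
        raise ValueError("a lower learning rate exceeds its upper learning rate")
    if max(args.ae_epochs, args.batch_epochs) > args.epochs:
        raise ValueError("warmup epochs exceed the total number of epochs")
    return args


args = check_params(default_params())
args.batch_size = 128  # e.g. a smaller batch size for a GPU with less memory
print(vars(check_params(args)))
```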

### 3.2. Split dataset
* For the tutorial, we hold out the biosamples that occur across all batches as the test set, then randomly split the rest into training and validation sets.<br>
This means that the training data seen by DeepAdapter are not spread across all unwanted variations, while the testing data are.<br>
Actually, this split method makes training more difficult.
* For your own dataset, you can split the data randomly using the function `DT.data_split_random`.<br>
Please refer to `DeepAdapter-YourOwnData-Tutorial.ipynb`.
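
The idea behind the tutorial split can be sketched in NumPy. This is not the code of `DT.data_split_lds1593`. The function `split_by_shared_bios` is hypothetical, and it assumes that "biosamples across all batches" means biological labels that appear in every batch.

```python
import numpy as np


def split_by_shared_bios(batches, bios, val_frac=0.2, seed=0):
    """Hold out samples whose biological label occurs in every batch.

    Returns index arrays (train_idxs, val_idxs, test_idxs).
    """
    batches, bios = np.asarray(batches), np.asarray(bios)
    all_batches = set(np.unique(batches).tolist())
    shared = [b for b in np.unique(bios)
              if set(batches[bios == b].tolist()) == all_batches]
    is_test = np.isin(bios, shared)
    test_idxs = np.flatnonzero(is_test)

    # Randomly divide the remaining samples into training and validation sets.
    rest = np.random.default_rng(seed).permutation(np.flatnonzero(~is_test))
    n_val = int(round(val_frac * len(rest)))
    return np.sort(rest[n_val:]), np.sort(rest[:n_val]), test_idxs


batches = ["b1", "b1", "b2", "b2", "b1", "b2", "b1", "b2"]
bios = ["d1", "d2", "d1", "d3", "d2", "d3", "d4", "d5"]
train_idxs, val_idxs, test_idxs = split_by_shared_bios(batches, bios, val_frac=0.25)
print(test_idxs)  # only donor d1 is measured in both batches
```

Because the held-out donors are the only ones observed in every batch, the model never sees the same biological sample under all unwanted variations during training, which is why this split is harder than a random one.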

In [None]:
train_data, train_labels, train_labels_hot, \
    val_data, val_labels, val_labels_hot, \
    test_data, test_labels, test_labels_hot, \
    train_ids, val_ids, test_ids, \
    tot_train_val_idxs, tot_train_idxs, tot_val_idxs, tot_test_idxs = DT.data_split_lds1593(input_arr, unwanted_labels, unwanted_onehot, ids, infos, test_infos)

In [None]:
train_bios, val_bios, test_bios = donors[tot_train_idxs], donors[tot_val_idxs], donors[tot_test_idxs]

In [None]:
bio_label2bat = {t:t for t in set(train_bios)}

### 3.3. Train DeepAdapter
Two options are provided for training DeepAdapter. If you want to learn the training process, please train it step by step. If you want to skip these initializations, please use the one-line code :)!
* To train DeepAdapter step by step, you need to initialize models, dataloaders, trainer, and the mutual nearest neighbors.
* To train DeepAdapter in one-line code, just utilize `deepAdapter.run.train()`.

#### 3.3.1. Train it step by step

In [None]:
db_name = "LDS1593"
out_dir = os.path.join("model/batch", "deepAligner_LINCS_batch/stepByStep_{}/".format(db_name))
os.makedirs(out_dir, exist_ok = True)

In [None]:
## initialize models (note: `.cuda()` requires a CUDA-capable GPU)
in_dim = input_arr.shape[1]
num_unw_vars = len(bat2label)
ae = AE(in_dim, net_args.hidden_dim, num_unw_vars, net_args.z_dim, net_args.drop).cuda()
fbatch = FBatch(net_args.hidden_dim, num_unw_vars, net_args.z_dim, net_args.drop).cuda()

## initialize dataloaders
train_trans = TransData(train_data, train_labels, train_bios, train_ids, train_labels_hot)
train_loader = DataLoader(train_trans, batch_size = net_args.batch_size, collate_fn = train_trans.collate_fn, shuffle = True, drop_last = False)
val_trans = TransData(val_data, val_labels, val_bios, val_ids, val_labels_hot)
val_loader = DataLoader(val_trans, batch_size = net_args.batch_size, collate_fn = val_trans.collate_fn, shuffle = False, drop_last = False)
test_trans = TransData(test_data, test_labels, test_bios, test_ids, test_labels_hot)
test_loader = DataLoader(test_trans, batch_size = net_args.batch_size, collate_fn = test_trans.collate_fn, shuffle = False, drop_last = False)

## initialize trainer
trainer = Trainer(train_loader, val_loader, test_loader, ae, fbatch, bio_label2bat, label2bat, net_args, out_dir)

## initialize mutual nearest neighbors
train_mutuals = TRP.find_MNN_cosine_kSources(train_data, train_labels, train_ids)
val_mutuals = TRP.find_MNN_cosine_kSources(val_data, val_labels, val_ids)

## begin training!
trainer.fit(train_mutuals, val_mutuals)
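
To make the notion of mutual nearest neighbors concrete, here is a minimal NumPy sketch for two batches using cosine similarity. It is not the implementation of `TRP.find_MNN_cosine_kSources`, whose handling of multiple sources is not covered here. The function name and the parameter `k` are hypothetical.

```python
import numpy as np


def mutual_nearest_neighbors(x_a, x_b, k=1):
    """Return (i, j) pairs where x_a[i] and x_b[j] are among each other's
    k nearest neighbors under cosine similarity."""
    a = x_a / np.linalg.norm(x_a, axis=1, keepdims=True)
    b = x_b / np.linalg.norm(x_b, axis=1, keepdims=True)
    sim = a @ b.T  # cosine similarity, shape (len(x_a), len(x_b))

    # Indices of the k most similar samples in the other batch, in each direction.
    nn_a2b = np.argsort(-sim, axis=1)[:, :k]
    nn_b2a = np.argsort(-sim.T, axis=1)[:, :k]

    pairs = []
    for i, neighbors in enumerate(nn_a2b):
        for j in neighbors:
            if i in nn_b2a[j]:  # keep the pair only if the relation is mutual
                pairs.append((i, int(j)))
    return pairs


batch_a = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 1.0]])
batch_b = np.array([[0.9, 0.1], [0.1, 0.9]])
print(mutual_nearest_neighbors(batch_a, batch_b))  # [(0, 0), (1, 1)]
```

The third sample of `batch_a` is closest to the first sample of `batch_b`, but that sample prefers the first sample of `batch_a`, so the pair is dropped. Only mutual matches are kept.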

#### 3.3.2. Train it in one-line code
Parameters for one-line code training:
* **train_list**: the list of training transcriptomic profiles, unwanted variations, biological signals, data ids, and onehot representations of unwanted variations.
* **val_list**: the list of validation transcriptomic profiles, unwanted variations, biological signals, data ids, and onehot representations of unwanted variations.
* **test_list**: the list of testing transcriptomic profiles, unwanted variations, biological signals, data ids, and onehot representations of unwanted variations.
* **label2unw**: the dictionary which maps unwanted labels (e.g., 0, 1 ...) to unwanted variations (e.g., batch1, batch2 ...)
* **label2wnt**: the dictionary which maps biological labels (e.g., 0, 1 ...) to biological annotations (e.g., donor1, donor2 ...)
* **net_args**: the parameters to construct DeepAdapter
* **out_dir**: the output directory for saved models and logged losses.

In [None]:
db_name = "LDS1593"
out_dir = os.path.join("model/batch", "deepAligner_LINCS_batch/oneLineCode_{}/".format(db_name))
os.makedirs(out_dir, exist_ok = True)

In [None]:
train_list = [train_data, train_labels, train_bios, train_ids, train_labels_hot]
val_list = [val_data, val_labels, val_bios, val_ids, val_labels_hot]
test_list = [test_data, test_labels, test_bios, test_ids, test_labels_hot]

from deepAdapter import run as RUN
trainer = RUN.train(
    train_list = train_list, 
    val_list = val_list, 
    test_list = test_list, 
    label2unw = label2bat, 
    label2wnt = bio_label2bat, 
    net_args = net_args, 
    out_dir = out_dir)

## 4. Align the data
### 4.1. Load trained model & quantitative evaluation
* Step 1: load the best-trained model
* Step 2: utilize `trainer.evaluate()`

In `trainer.evaluate()`, we perform decomposition analysis of the aligned data and quantitative analysis, including the calculation of the alignment score, ASW, NMI, and ARI. The quantitative results are recorded in `record_path`.

In [None]:
trainer.load_trained_ae(os.path.join(out_dir, "ae.tar"))

record_path = os.path.join(out_dir, "test_res.csv")
test_data, test_aligned_data, test_wnt_infs, test_unw_infs = trainer.evaluate(record_path, db_name, test_loader)
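
For readers who want to compute similar clustering metrics on their own embeddings, the sketch below uses scikit-learn to obtain the average silhouette width (ASW), normalized mutual information (NMI), and adjusted Rand index (ARI). It is an assumption-laden illustration: how `trainer.evaluate()` clusters the data and which labels it scores are not stated here, the alignment score is not reproduced, and `clustering_scores` is a hypothetical helper that uses k-means.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import (adjusted_rand_score, normalized_mutual_info_score,
                             silhouette_score)


def clustering_scores(embedding, true_labels, seed=0):
    """Cluster `embedding` with k-means and compare the result with `true_labels`."""
    true_labels = np.asarray(true_labels)
    n_clusters = len(np.unique(true_labels))
    pred = KMeans(n_clusters=n_clusters, n_init=10, random_state=seed).fit_predict(embedding)
    return {
        "ASW": silhouette_score(embedding, true_labels),
        "NMI": normalized_mutual_info_score(true_labels, pred),
        "ARI": adjusted_rand_score(true_labels, pred),
    }


rng = np.random.default_rng(0)
# Two well-separated synthetic groups standing in for two donors.
toy = np.vstack([rng.normal(0.0, 0.1, size=(20, 5)),
                 rng.normal(3.0, 0.1, size=(20, 5))])
labels = np.array([0] * 20 + [1] * 20)
scores = clustering_scores(toy, labels)
print(scores)
```

When scored against biological labels such as donors, higher values indicate that the biological signal is well preserved. When scored against unwanted labels such as batches, lower values indicate that the unwanted variation has been removed.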

Additionally, you can perform any other analysis you like with the aligned data `test_aligned_data`!

### 4.2. Save the aligned data

In [None]:
all_trans = TransData(np.vstack((train_data, val_data, test_data)), 
                      np.array(list(train_labels) + list(val_labels) + list(test_labels)),
                      np.array(list(train_bios) + list(val_bios) + list(test_bios)),
                      np.array(list(train_ids) + list(val_ids) + list(test_ids)), 
                      np.vstack((train_labels_hot, val_labels_hot, test_labels_hot)))
all_loader = DataLoader(all_trans, batch_size = net_args.batch_size, collate_fn = all_trans.collate_fn, shuffle = False, drop_last = False)
record_path = os.path.join(out_dir, "res.csv")
data, aligned_data, wnt_infs, unw_infs = trainer.evaluate(record_path, db_name, all_loader)

save_path = os.path.join(out_dir, "DA_data.csv")
# save the aligned profiles (`aligned_data`), not the unaligned input `data`
df = pd.DataFrame(aligned_data, columns = sorted_cols)
df["ID"] = np.array(list(train_ids) + list(val_ids) + list(test_ids))
df["wantInfo"] = wnt_infs
df["unwantInfo"] = unw_infs
df.to_csv(save_path, index = False)
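
The saved table mixes gene columns with three metadata columns (`ID`, `wantInfo`, `unwantInfo`), so downstream analysis usually starts by separating them again. The following sketch shows one way to do this with pandas on a toy table with the same column layout. The helper `split_saved_table` is hypothetical, and it assumes that no gene is named like one of the metadata columns.

```python
import os
import tempfile

import numpy as np
import pandas as pd

META_COLS = ["ID", "wantInfo", "unwantInfo"]  # metadata columns written above


def split_saved_table(df):
    """Separate the expression matrix from the metadata columns."""
    meta = df[META_COLS]
    expr = df.drop(columns=META_COLS)
    return expr.to_numpy(), list(expr.columns), meta


# Round trip on a toy table with the same column layout as DA_data.csv.
toy = pd.DataFrame(np.arange(6.0).reshape(2, 3), columns=["g1", "g2", "g3"])
toy["ID"] = [0, 1]
toy["wantInfo"] = ["donorA", "donorB"]
toy["unwantInfo"] = ["batch1", "batch2"]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "DA_data.csv")
    toy.to_csv(path, index=False)
    values, genes, meta = split_saved_table(pd.read_csv(path))

print(values.shape, genes)  # (2, 3) ['g1', 'g2', 'g3']
```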