## scGen batch correction and analysis of Endothelial Cells
By Monika Litvinukova <br/>
Latest update 23.07.20

### Import required modules

In [None]:
import scgen
import anndata
import numpy as np
import scanpy as sc
import pandas as pd
import tensorflow as tf

In [None]:
# Verbose scanpy logging and record package versions for reproducibility
sc.settings.verbosity = 3
sc.logging.print_version_and_date()
# High-resolution, vector-friendly SVG figures with the RdPu colormap by default
sc.settings.set_figure_params(
    dpi=260,
    dpi_save=260,
    color_map='RdPu',
    vector_friendly=True,
    format='svg',
)

### Fixed function to correct batches.

- This was modified so that `scGen` does not reorder the `adata` index and keeps `adata.raw`.
- Once a new `scGen` release includes this fix, this function can be dropped in favour of `scgen.batch_removal`.

In [None]:
import os
from random import shuffle

import anndata
import numpy as np
import scanpy as sc
from matplotlib import pyplot as plt
from scipy import sparse
from sklearn import preprocessing
import pandas as pd

import scgen


def batch_removal_ct5(network, adata, batch_key="batch", cell_label_key="cell_type"):
    """
    Remove the batch effect from ``adata`` with a trained scGen VAE.

    Every cell is embedded with ``network.to_latent``. Within each cell type that
    occurs in at least two batches, every batch is shifted so that its latent mean
    equals the latent mean of the largest batch of that cell type. Ties go to the
    first batch in sorted order. Cell types seen in only one batch are left
    unshifted. The full latent matrix is then passed through
    ``network.reconstruct(..., use_data=True)``.

    Unlike ``scgen.batch_removal``, the result keeps the cell order and
    ``obs_names`` of ``adata`` and carries ``adata.raw`` over.

    # Parameters
    network: `scgen VAE`
        Variational auto-encoder object after training. Must provide
        ``to_latent(X)`` and ``reconstruct(latent, use_data=True)``.
    adata: `~anndata.AnnData`
        Annotated data matrix. Its ``obs`` must contain ``batch_key`` and
        ``cell_label_key``.
    batch_key: str
        ``obs`` column with the batch labels to correct for.
    cell_label_key: str
        ``obs`` column with cell-type labels. Batches are aligned separately
        within each cell type.

    # Returns
    corrected: `~anndata.AnnData`
        Corrected matrix for all cells (whether or not their cell type had a batch
        effect), in the same order as ``adata``. It has:
        - a deep copy of ``adata.obs``;
        - ``adata.var_names`` (other ``var`` columns are not copied);
        - ``raw`` built from ``adata.raw`` when present.

    # Example
    ```python
    import scgen
    import anndata
    train = anndata.read("data/pancreas.h5ad")
    train.obs["cell_type"] = train.obs["celltype"].tolist()
    network = scgen.VAEArith(x_dimension=train.shape[1], model_path="./models/batch")
    network.train(train_data=train, n_epochs=20)
    corrected_adata = batch_removal_ct5(network, train)
    ```
    """
    # `.toarray()` works on every SciPy version; the `.A` shortcut was removed.
    X = adata.X.toarray() if sparse.issparse(adata.X) else adata.X
    # Own copy of the latent matrix. Shifts are written straight into its rows
    # rather than through (version-dependent) writes to AnnData views.
    latent = np.array(network.to_latent(X))
    cell_labels = np.asarray(adata.obs[cell_label_key])
    batch_labels = np.asarray(adata.obs[batch_key])

    for cell_type in np.unique(cell_labels):
        in_type = cell_labels == cell_type
        batches, counts = np.unique(batch_labels[in_type], return_counts=True)
        if len(batches) < 2:
            continue  # single batch: nothing to align against
        # np.unique sorts the batch names and np.argmax picks the first maximum,
        # so ties resolve to the alphabetically first of the largest batches.
        reference_rows = in_type & (batch_labels == batches[np.argmax(counts)])
        reference_mean = np.average(latent[reference_rows], axis=0)
        for batch in batches:
            rows = in_type & (batch_labels == batch)
            latent[rows] += reference_mean - np.average(latent[rows], axis=0)

    # Rows never left their original positions, so no re-ordering is needed and
    # no helper columns (e.g. from concatenation) end up in obs.
    corrected = anndata.AnnData(network.reconstruct(latent, use_data=True))
    corrected.obs = adata.obs.copy(deep=True)
    corrected.var_names = adata.var_names.tolist()
    if adata.raw is not None:
        adata_raw = anndata.AnnData(X=adata.raw.X, var=adata.raw.var)
        adata_raw.obs_names = adata.obs_names
        corrected.raw = adata_raw
    return corrected

### Read in data
The object stores raw counts in `adata.X` and has no `adata.raw`.

In [None]:
# Load the merged, annotated endothelial-cell object (expected to hold raw counts in .X)
ec_raw = sc.read_h5ad('/home/jovyan/mona_data/2-EC/heart_ec_merged_annotated_ml200423_RAW.h5ad')
# Display (n_cells, n_genes)
ec_raw.shape

In [None]:
# Display the original fine-grained annotation labels stored in obs['cell_types']
ec_raw.obs['cell_types'].cat.categories

In [None]:
# Fine-grained labels (obs['cell_types']) grouped by the coarse type they collapse into;
# trans_from[i] maps to trans_to[i].
# NOTE(review): 'EC_C2_art' is grouped with the capillary labels — confirm this is intended.
trans_from = [['EC_C0_Cap','EC_C2_art', 'EC_C3_cap','EC_D0_cap','EC_D2_cap','EC_H0_cap', 'EC_H1_cap','EC_H4_cap','EC_S0_cap','EC_S2_cap',],
              ['EC_C1_art','EC_D1_art','EC_S4_art','Ec_H2_art'],
              ['EC_C4_ven','EC_C6_ven','EC_D4_ven','EC_S3_ven','EC_S5_ven','Ec_H3_ven'],
              ['EC_C8_lnEC','EC_D6_ln','EC_H9_ln'],
              ['EC_C7_endo','EC_H5_endo','EC_S1_endo'],
              ['EC_C5_stromal','EC_D3_stromal','EC_H6_stromal','EC_S7_stromal'],
              ['EC_H8_fb','EC_S8_fb'],
              ['EC_D5_cmc','EC_H7_cmc', 'EC_S9_cmc'],
              ['EC_C10_meso','Ec_S6_meso'],
              ['EC_C9_RBC']]

trans_to = ['EC_cap', 'EC_art', 'EC_ven', 'lnEC', 'EC_endo', 'EC_stromal', 'FB_like_EC', 'CMC_like_EC', 'Meso', 'RBC']

# Flatten the groups into a single fine-label -> coarse-label lookup.
label_map = {fine: coarse for fines, coarse in zip(trans_from, trans_to) for fine in fines}

# Assign the whole column in one step. The former chained assignment
# `obs['cell_type'][mask] = ...` may write into a temporary copy
# (SettingWithCopyWarning) and has no effect under pandas Copy-on-Write.
# Labels absent from the lookup are kept unchanged, as before.
fine_labels = [str(i) for i in ec_raw.obs['cell_types']]
ec_raw.obs['cell_type'] = [label_map.get(label, label) for label in fine_labels]

# Surface labels that were not mapped (e.g. a typo in trans_from) instead of failing silently.
unmapped = sorted(set(fine_labels) - set(label_map))
if unmapped:
    print('Labels left unmapped:', unmapped)
ec_raw.shape

In [None]:
# Store the coarse labels as a categorical column and display its levels
cell_type_categorical = ec_raw.obs['cell_type'].astype('category')
ec_raw.obs['cell_type'] = cell_type_categorical
cell_type_categorical.cat.categories

### Normalise and scale the dataset

In [None]:
# Work on a copy so ec_raw keeps the untouched counts (used again when exporting)
ec = ec_raw.copy()
# Per-cell library-size normalisation, then log(1 + x)
sc.pp.normalize_per_cell(ec)
sc.pp.log1p(ec)
# Snapshot the normalised, log-transformed matrix into .raw BEFORE scaling;
# rank_genes_groups(use_raw=True) later reads from it
ec.raw = ec.copy()

In [None]:
# Scale each gene, clipping values at 10. No return value is captured, so this relies
# on in-place modification: ec.X is no longer log data afterwards (ec.raw still is).
sc.pp.scale(ec, max_value = 10)

### Run scGen to correct for cell source

- Prepare the object

In [None]:
hvg_scGen_batch_1 = ec.copy()
# batch_removal_ct5 reads its defaults from obs['batch'] and obs['cell_type'];
# in this first pass the cell source is treated as the batch.
for target_col, source_col in (("batch", "cell_source"), ("cell_type", "cell_type")):
    hvg_scGen_batch_1.obs[target_col] = ec.obs[source_col].tolist()

- Build the scGen model

In [None]:
# Start from a clean default TensorFlow graph before building the VAE
tf.reset_default_graph()
# NOTE(review): network_batch_2 below uses the same model_path, so whatever scGen saves
# here is presumably overwritten by the donor model — use separate directories if both must be kept.
network_batch_1 = scgen.VAEArith(x_dimension = hvg_scGen_batch_1.shape[1], model_path = "/home/jovyan/mona_data/2-EC/models/")

- Train the model

In [None]:
# Train the VAE for 20 epochs. NOTE: despite the `hvg_` prefix, no highly-variable-gene
# subsetting is visible in this notebook, so all genes in `ec` are used.
network_batch_1.train(train_data = hvg_scGen_batch_1, n_epochs = 20)

- Correct batches

In [None]:
# Align cell sources within each cell type (default keys: obs['batch'], obs['cell_type'])
corrected_batch_1 = batch_removal_ct5(network_batch_1, hvg_scGen_batch_1)

- Visualise

In [None]:
# kNN graph and UMAP on the source-corrected matrix (fixed seeds for reproducibility)
sc.pp.neighbors(corrected_batch_1, random_state = 1786)
sc.tl.umap(corrected_batch_1, min_dist = 0.3, spread = 1, random_state = 1786)
# Check batch mixing (source, region, donor) alongside annotation and EC marker genes
sc.pl.umap(corrected_batch_1, color = ['cell_source',  'cell_type', 'region', 'donor', 'CDH5', 'SEMA3G', 'ACKR1', 'MSLN'], size = 1, legend_fontsize = 6, color_map = 'RdPu', frameon = False)

### Correct for donor

In [None]:
hvg_scGen_batch_2 = corrected_batch_1.copy()
# Second pass: treat donor as the batch on top of the source-corrected data,
# keeping the same cell-type labels.
for target_col, source_col in (("batch", "donor"), ("cell_type", "cell_type")):
    hvg_scGen_batch_2.obs[target_col] = corrected_batch_1.obs[source_col].tolist()

In [None]:
# Fresh TensorFlow graph for the second (donor) VAE
tf.reset_default_graph()
# NOTE(review): same model_path as network_batch_1 — presumably overwrites the first model's saved files.
network_batch_2 = scgen.VAEArith(x_dimension = hvg_scGen_batch_2.shape[1], model_path = "/home/jovyan/mona_data/2-EC/models/")

In [None]:
# Train the donor-correction VAE for 20 epochs
network_batch_2.train(train_data = hvg_scGen_batch_2, n_epochs = 20)

In [None]:
# Align donors within each cell type
corrected_batch_2 = batch_removal_ct5(network_batch_2, hvg_scGen_batch_2)

In [None]:
# kNN graph and UMAP on the doubly corrected matrix (fixed seeds)
sc.pp.neighbors(corrected_batch_2, random_state = 1712)
sc.tl.umap(corrected_batch_2, min_dist = 0.3, spread = 1, random_state = 1712)
# Batch covariates and per-cell QC metrics on the new embedding
sc.pl.umap(corrected_batch_2, color = ['cell_source', 'region', 'donor', 'n_counts', 'n_genes', 'percent_mito', 'percent_ribo', 'scrublet_score'], size = 1, legend_fontsize = 6, color_map = 'RdPu', frameon = False)

In [None]:
# Same panel set as after the first correction, for a side-by-side comparison
sc.pl.umap(corrected_batch_2, color = ['cell_source',  'cell_type', 'region', 'donor', 'CDH5', 'SEMA3G', 'ACKR1', 'MSLN'], size = 1, legend_fontsize = 6, color_map = 'RdPu', frameon = False)

In [None]:
# Compare fine-grained (cell_types) and coarse (cell_type) annotations
sc.pl.umap(corrected_batch_2, color = ['cell_types', 'cell_type'], size = 1, legend_fontsize = 6, color_map = 'RdPu')

In [None]:
# Leiden clustering at resolution 0.5, stored in obs['leiden05']
sc.tl.leiden(corrected_batch_2, resolution = 0.5, random_state = 1712, key_added = 'leiden05')
sc.pl.umap(corrected_batch_2, color = ['cell_type', 'leiden05'], size = 1, legend_fontsize = 6, color_map = 'RdPu')

### Call marker genes for Leiden clusters

In [None]:
# Wilcoxon markers per Leiden cluster, tested on the log-normalised data in .raw
sc.tl.rank_genes_groups(corrected_batch_2, 'leiden05', method='wilcoxon', n_genes=500, use_raw=True)
result = corrected_batch_2.uns['rank_genes_groups']
groups = result['names'].dtype.names

# One column per (cluster, statistic). The suffix is the statistic's first letter:
# _n = gene name, _p = adjusted p-value, _l = log fold change.
columns = {}
for group in groups:
    for key in ('names', 'pvals_adj', 'logfoldchanges'):
        columns[f"{group}_{key[0]}"] = result[key][group]
heart_scGen_wilcox = pd.DataFrame(columns)
heart_scGen_wilcox.head(10)

In [None]:
# Export the full marker table (no index column)
heart_scGen_wilcox.to_csv('/home/jovyan/mona_data/2-EC/heart_EC_scGEN_ml200505_wilcox05.csv', index = False, sep = ',')

In [None]:
# Top 10 markers per cluster, each gene standardised across clusters
sc.pl.rank_genes_groups_matrixplot(corrected_batch_2, n_genes = 10, cmap = 'RdPu', standard_scale = 'var')

In [None]:
# Re-display annotation vs. clusters (same plot as right after clustering)
sc.pl.umap(corrected_batch_2, color = ['cell_type', 'leiden05'], size = 1, legend_fontsize = 6, color_map = 'RdPu')

In [None]:
# Curated EC / mural / contaminant marker genes for checking the Leiden clusters
ec_markers = ['PECAM1', 'CDH5', 'VWF', 'SEMA3G', 'HEY1', 'DLL4', 'ACKR1', 'NR2F2', 'DLL1', 'JUN', 'FOS', 'ATF3', 'SMOC1', 'INHBA', 
             'NOVA1', 'NPR3', 'POSTN', 'PROX1', 'PDPN', 'RGS5', 'PDGFRB', 'ABCC9', 'MYH11', 'ACTA2', 'DCN', 'CDH19', 'PTPRC', 'TTN', 'RYR2', 'MYL4', 'MYL7', 'AC020637.1',
              'NRP2','ANO2','DOCK4','PTPRM','CDKN1C','TCF15','MARCKS','GPIHBP1','KCTD12','MCF2L','ITGA6','TMSB10','AQP1','CA4', 'CXCL1','ICAM1','CX3CL1','IL6','CXCL2','CCL2','IL32',
             'NDUFA4L2','RGS5','COX4I2','TMSB4X','MT2A','CALM2','ADIRF','FTH1','TPM2','TIMP1','YBX1','NTRK3','MYH11', 'DOCK8', 'RERGL','ZBTB7C','LINC01568',
             'RGCC','FCN3','IFI27','FABP5','TMEM88','CLDN5','PARD3', 'DLC1','PDGFRB','MALAT1','PLA2G5','EPS8','ABCC9','PDZD2','FRMD3','CALD1','EGFLAM','GUCY1A2', 'HBB']
# RGS5, MYH11, PDGFRB and ABCC9 are listed more than once above. Keep only the first
# occurrence (order preserved) so each gene gets a single matrixplot column.
ec_markers = list(dict.fromkeys(ec_markers))
sc.pl.matrixplot(corrected_batch_2, ec_markers, groupby = 'leiden05', cmap = 'RdPu', dendrogram = False, standard_scale = 'var')

In [None]:
# Leiden clusters plus marker genes on the UMAP
umap_keys = ['leiden05', 'PECAM1', 'CDH5', 'VWF', 'SEMA3G', 'HEY1', 'DLL4', 'ACKR1', 'NR2F2', 'DLL1', 'JUN', 'FOS', 'ATF3', 'SMOC1', 'INHBA', 
             'NOVA1', 'NPR3', 'POSTN', 'PROX1', 'PDPN', 'RGS5', 'PDGFRB','MYL2', 'LGALS1','RGCC', 'TCF15', 'IL1RL1', 'ITGA1', 'ITGA6','CDKN1C','HIGD1B','NDUFA4L2','AGT',
             'RGS5','STEAP4','CPE','PDGFRB','ABCC9','GUCY1A2','COX4I2','ACTA2','TAGLN','TEKT4','FCER1G','S100A8','S100A9','S100A4','TMSB4X','CXCL11', 'CXCL10', 'CXCL8', 'CX3CL1', 'VCAM1', 'ICAM1' ]
# RGS5 and PDGFRB appear twice above; de-duplicate (order preserved) so each panel is drawn once.
sc.pl.umap(corrected_batch_2, color = list(dict.fromkeys(umap_keys)), size = 1, legend_fontsize = 4, legend_loc = 'on data', frameon = False)

### Annotate the dataset

In [None]:
# Leiden cluster IDs drawn on the embedding, as a reference for annotation
sc.pl.umap(corrected_batch_2, color = ['leiden05'], size = 1, legend_fontsize = 4, legend_loc = 'on data', frameon = False)

In [None]:
# Display the cluster IDs; the cell-state names assigned next must follow this order
corrected_batch_2.obs['leiden05'].cat.categories

In [None]:
# Name each Leiden cluster. Names are matched positionally to
# obs['leiden05'].cat.categories, so the list must have one entry per cluster
# (rename_categories raises ValueError otherwise). This replaces assigning to
# `.cat.categories`, which pandas 2.0 no longer allows.
corrected_batch_2.obs['cell_states'] = corrected_batch_2.obs['leiden05'].cat.rename_categories(
    ['EC0_cap', 'EC1_cap', 'EC2_art', 'EC3_activ', 'EC4_ven', 'EC5_cap', 'EC6_endo', 'EC7_stromal', 'EC8_CMC-like', 'EC9_stromal', 'Meso10', 
     'EC11_AgP', 'EC12_FB-like', 'EC13_ln', 'EC14'])
sc.pl.umap(corrected_batch_2, color = ['cell_states'], size = 1, legend_fontsize = 5, legend_loc = 'on data', frameon = False)

### Save and export the object

In [None]:
## Create a new object with main matrix containing all the genes, log-transformed.
## `ec.X` was scaled in place by `sc.pp.scale`, so take the normalised, log-transformed
## matrix (and matching var) from `ec.raw`, which was snapshotted before scaling.
scGEN_website = anndata.AnnData(X = ec.raw.X, obs = corrected_batch_2.obs , var = ec.raw.var, obsm = corrected_batch_2.obsm)
scGEN_website

In [None]:
# Save the website object to disk as .h5ad
scGEN_website.write('/home/jovyan/mona_data/2-EC/heart_EC_scGEN_ml200511_website.h5ad')

In [None]:
## Create a new object with main matrix containing all the genes, raw
## Rows of ec_raw.X and corrected_batch_2.obs are assumed to be in the same cell order
## (batch_removal_ct5 returns cells in input order; ec was a copy of ec_raw).
scGEN_raw = anndata.AnnData(X = ec_raw.X, obs = corrected_batch_2.obs , var = ec_raw.var, obsm = corrected_batch_2.obsm)
scGEN_raw

In [None]:
# Sanity check: the corrected embedding is attached; gene colours here are raw counts
sc.pl.umap(scGEN_raw, color = ['VWF', 'RGS5', 'MSLN'], size = 1, legend_fontsize = 6, color_map = 'RdPu', frameon = False)

In [None]:
# Save the raw-count object to disk as .h5ad
scGEN_raw.write('/home/jovyan/mona_data/2-EC/heart_EC_scGEN_ml200511_raw.h5ad')