# Annotate cells using the Tusi et al. (2018) reference data and scANVI

In [None]:
import scanpy as sc
import numpy as np
import pandas as pd
import os
from scipy import sparse

import torch
import scvi

import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline

In [None]:
# Work from the project directory: every data path used below
# (e.g. 'data/TUSI_2018/...', 'data/scvi/...') is relative to it.
os.chdir('/research/peer/fdeckert/FD20200109SPLENO')

# Import Tusi et al., 2018 reference data

In [None]:
# Raw count matrix, gene names (one per line, no header) and per-cell
# metadata from Tusi et al., 2018.
adata = sc.read_mtx('data/TUSI_2018/basal_bone_marrow.raw_counts.mtx.gz')
adata.var_names = pd.read_table(
    'data/TUSI_2018/genes.txt', header=None, dtype=str
).iloc[:, 0]
adata.obs = pd.read_csv(
    'data/TUSI_2018/basal_bone_marrow.metadata.csv', index_col=0
)

## Filter cells according to the original publication

In [None]:
# Keep only the cells flagged by `pass_filter` in the published metadata.
# `.copy()` turns the subset into a real AnnData object (not a view), so the
# later cells can add `.obs` / `.obsm` entries without implicit view copies.
adata = adata[adata.obs['pass_filter'].eq(True)].copy()

## Plot PBA probability distributions

In [None]:
# PBA probability columns from the published metadata, one per cell type
# (suffixes match the keys of the marker dictionary defined below).
pba_prob = adata.obs[[
    'PBA_Prob_E', 'PBA_Prob_GN', 'PBA_Prob_Ly', 'PBA_Prob_D',
    'PBA_Prob_Meg', 'PBA_Prob_M', 'PBA_Prob_Ba',
]]
pba_prob.plot(kind='kde')

## Label each cell by its highest PBA probability

In [None]:
# For every cell, the name of the PBA column with the largest value
# (e.g. 'PBA_Prob_E'); NaNs are skipped, which is the pandas default.
adata.obs['pba_prob_label'] = pba_prob.idxmax(axis='columns')

## Plot the SPRING embedding

In [None]:
# Store the published SPRING coordinates as a plain (n_cells, 2) array rather
# than a DataFrame: scanpy's embedding plot slices `.obsm` entries
# positionally (`[:, dims]`), which a DataFrame does not support.
adata.obsm['X_spring'] = adata.obs[['SPRING_x', 'SPRING_y']].to_numpy()
sc.pl.embedding(adata, basis='X_spring', color='pba_prob_label')

## Define marker genes 

In [None]:
# Marker genes per cell type. Keys use the same abbreviations as the
# PBA_Prob_* columns, plus 'MPP'.
tusi_marker_dict = {
    'E': ['Hbb-bt', 'Hba-a2', 'Hba-a1', 'Alas2', 'Bpgm'],
    'Ba': ['Lmo4', 'Ifitm1', 'Ly6e', 'Srgn'],
    'Meg': ['Pf4', 'Itga2b', 'Vwf', 'Pbx1', 'Mef2c'],
    'MPP': ['Hlf', 'Gcnt2'],
    'Ly': ['Cd79a', 'Igll1', 'Vpreb3', 'Vpreb1', 'Lef1'],
    'D': ['H2-Aa', 'Cd74', 'H2-Eb1', 'H2-Ab1', 'Cst3'],
    'M': ['Csf1r', 'Ly6c2', 'Ccr2'],
    'GN': ['Lcn2', 'S100a8', 'Ltf', 'Lyz2', 'S100a9'],
}

# Sorted, de-duplicated union of every marker gene above.
tusi_marker_list = sorted(set().union(*tusi_marker_dict.values()))

## Compute cell-type scores

In [None]:
# Library-size normalise and log-transform a *copy*, so the raw counts in
# `adata` stay untouched for the scVI model set up further below.
normalized_adata = adata.copy()
sc.pp.normalize_total(normalized_adata, target_sum=1e4)
sc.pp.log1p(normalized_adata)

# Restrict to the marker genes, then scale each to zero mean / unit variance,
# the input expected by get_score / get_cell_mask.
normalized_adata = normalized_adata[:, tusi_marker_list].copy()
sc.pp.scale(normalized_adata)

In [None]:
def get_score(normalized_adata, gene_set):
    """Returns the score per cell given a dictionary of + and - genes

    The score of a cell is the sum of its expression over
    ``gene_set['positive']`` minus the sum over ``gene_set['negative']``.

    Parameters
    ----------
    normalized_adata
      anndata dataset that has been log normalized and scaled to mean 0, std 1.
      ``.X`` may be a dense array or a scipy sparse matrix.
    gene_set
      a dictionary with two keys: 'positive' and 'negative'
      each key should contain a list of genes (either list may be empty)
      for each gene in gene_set['positive'], its expression will be added to the score
      for each gene in gene_set['negative'], its expression will be subtracted from its score

    Returns
    -------
    float numpy array of length n_cells containing the score per cell
    """
    def _expression(gene):
        # One gene column as a flat dense vector. ``np.array`` on a sparse
        # matrix yields a 0-d object array, so densify sparse input first.
        values = normalized_adata[:, gene].X
        if sparse.issparse(values):
            values = values.toarray()
        return np.asarray(values).ravel()

    score = np.zeros(normalized_adata.n_obs)
    for gene in gene_set['positive']:
        score += _expression(gene)
    for gene in gene_set['negative']:
        score -= _expression(gene)
    return score

def get_cell_mask(normalized_adata, gene_set, n_top=20):
    """Calculates the score per cell for a gene set, then returns a mask for
    the ``n_top`` cells with the highest scores.

    Parameters
    ----------
    normalized_adata
      anndata dataset that has been log normalized and scaled to mean 0, std 1
    gene_set
      a dictionary with two keys: 'positive' and 'negative'
      each key should contain a list of genes
      for each gene in gene_set['positive'], its expression will be added to the score
      for each gene in gene_set['negative'], its expression will be subtracted from its score
    n_top
      number of highest-scoring cells to select (default 20). Values <= 0
      select no cells; values above the number of cells select every cell.

    Returns
    -------
    Boolean numpy array of length n_cells that is True for the ``n_top``
    highest-scoring cells over the entire dataset
    """
    score = get_score(normalized_adata, gene_set)
    mask = np.zeros(normalized_adata.n_obs, dtype=bool)
    # Guard n_top <= 0: ``argsort()[-0:]`` would select *every* cell.
    if n_top > 0:
        mask[score.argsort()[-n_top:]] = True
    return mask

In [None]:
# Hand-curated one-vs-rest gene sets for identifying ground-truth seed cells:
# a cell type's own markers count positively, every other marker in
# `tusi_marker_list` counts negatively.
def _one_vs_rest_geneset(cell_type):
    own_markers = tusi_marker_dict[cell_type]
    return {
        "positive": own_markers,
        "negative": [gene for gene in tusi_marker_list if gene not in own_markers],
    }


E_geneset = _one_vs_rest_geneset('E')
Ba_geneset = _one_vs_rest_geneset('Ba')
Meg_geneset = _one_vs_rest_geneset('Meg')
MPP_geneset = _one_vs_rest_geneset('MPP')
Ly_geneset = _one_vs_rest_geneset('Ly')
D_geneset = _one_vs_rest_geneset('D')
M_geneset = _one_vs_rest_geneset('M')
GN_geneset = _one_vs_rest_geneset('GN')

In [None]:
# Top-scoring seed cells for each cell type, in the same order as the
# gene sets (E, Ba, Meg, MPP, Ly, D, M, GN).
(E_mask, Ba_mask, Meg_mask, MPP_mask,
 Ly_mask, D_mask, M_mask, GN_mask) = (
    get_cell_mask(normalized_adata, geneset)
    for geneset in (E_geneset, Ba_geneset, Meg_geneset, MPP_geneset,
                    Ly_geneset, D_geneset, M_geneset, GN_geneset)
)

In [None]:
# Every cell starts as "Unknown"; seed cells of each type are then overwritten.
# Order matters: a cell selected for several types keeps the label written
# last (GN is applied last).
seed_labels = np.array(["Unknown"] * E_mask.shape[0])
for seed_label, seed_mask in (
    ("E", E_mask),
    ("Ba", Ba_mask),
    ("Meg", Meg_mask),
    ("MPP", MPP_mask),
    ("Ly", Ly_mask),
    ("D", D_mask),
    ("M", M_mask),
    ("GN", GN_mask),
):
    seed_labels[seed_mask] = seed_label

adata.obs["seed_labels"] = seed_labels

In [None]:
# Number of cells per seed label (including "Unknown").
adata.obs['seed_labels'].value_counts()

In [None]:
# Seed labels on the SPRING layout. Nine colours, one per label; the last
# (light grey) is presumably meant for "Unknown", which sorts last.
sc.pl.embedding(
    adata,
    basis='X_spring',
    color='seed_labels',
    palette=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
             '#8c564b', '#e377c2', '#7f7f7f', '#d3d3d3'],
    sort_order=False,
    legend_loc='on data',
)

# Transfer of annotation with scANVI

In [None]:
# Register the (un-normalised) `adata` with scvi-tools: no batch covariate,
# marker-based seed labels as the label column.
scvi.data.setup_anndata(adata, labels_key="seed_labels", batch_key=None)

In [None]:
# scVI model with a 30-dimensional latent space and 2 layers.
scvi_model = scvi.model.SCVI(adata, n_layers=2, n_latent=30)

In [None]:
# To retrain instead of loading the saved model, run:
# scvi_model.train(100)
# scvi_model.save('data/scvi/model_tusi_bBM/', overwrite = True)

# `SCVI.load` is a class method that *returns* the restored model; it does not
# update an existing instance in place, so the result must be bound to
# `scvi_model`. Otherwise scANVI would be initialised from untrained weights.
scvi_model = scvi.model.SCVI.load('data/scvi/model_tusi_bBM/', adata)

In [None]:
# Initialise scANVI from the pretrained scVI model. 'Unknown' (the placeholder
# for non-seed cells) is passed as the unlabelled category.
scanvi_model = scvi.model.SCANVI.from_scvi_model(scvi_model, 'Unknown')
scanvi_model.train(25)
scanvi_model.save('data/scvi/model_tusi_bBM_scanvi/', overwrite = True)
# To reuse the saved scANVI model instead of retraining, replace the three
# lines above with (note: `load` returns the model, so rebind the name):
# scanvi_model = scvi.model.SCANVI.load('data/scvi/model_tusi_bBM_scanvi/', adata)

In [None]:
# Store scANVI's predicted label per cell and its latent representation.
adata.obs['C_scANVI'] = scanvi_model.predict(adata)
adata.obsm['X_scANVI'] = scanvi_model.get_latent_representation(adata)

In [None]:
# Neighbour graph on the scANVI latent space, then a UMAP of that graph.
sc.pp.neighbors(adata, use_rep='X_scANVI')
sc.tl.umap(adata)

In [None]:
# scANVI predictions on the UMAP computed from the scANVI latent space.
sc.pl.umap(
    adata,
    color='C_scANVI',
    palette=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
             '#8c564b', '#e377c2', '#7f7f7f', '#d3d3d3'],
    sort_order=False,
    legend_loc='on data',
)

In [None]:
# scANVI predictions on the published SPRING layout, for comparison with the
# PBA-based labels plotted earlier.
sc.pl.embedding(
    adata,
    basis='X_spring',
    color='C_scANVI',
    palette=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
             '#8c564b', '#e377c2', '#7f7f7f', '#d3d3d3'],
    sort_order=False,
    legend_loc='on data',
)