# Quality control 

In [None]:
import scvi
import scanpy as sc

import scipy

# import celltypist
# from celltypist import models

import numpy as np
import pandas as pd

import matplotlib as mpl
import matplotlib.pyplot as plt

import sys
import os

# Reset logging configuration to default

In [None]:
import logging

# Reset the root logger to a clean WARNING-level configuration. Libraries
# imported above may already have attached handlers to the root logger, in
# which case a plain basicConfig() call would do nothing. force=True removes
# AND closes those handlers (rather than just discarding the list) before
# installing a fresh StreamHandler.
logging.basicConfig(level=logging.WARNING, force=True)

## Root directory

In [None]:
# Project root: every relative path used below (bin/, data/...) is resolved
# against this directory.
os.chdir('/research/peer/fdeckert/FD20200109SPLENO')

## Custom modules

In [None]:
# Make the project's helper modules in bin/ importable (relative to the
# current working directory, i.e. the project root set with os.chdir).
sys.path.append('bin/')
# NOTE(review): star import; in this notebook it is presumably the source of
# plot_qc_density and plot_qc_scatter — consider importing them explicitly.
from adata_qc import *

## Set up rpy2

In [None]:
# R installation used by rpy2 (the rpy2 import below is currently commented
# out). NOTE(review): this has to be set before rpy2 is imported to take effect.
os.environ['R_HOME']='/nobackup/peer/fdeckert/miniconda3/envs/r.4.4.1-FD20200109SPLENO/lib/R'

In [None]:
# import rpy2
# %load_ext rpy2.ipython

## Figures 

In [None]:
# Global plotting defaults: scanpy figure parameters plus a white matplotlib
# figure background.
sc.set_figure_params(
    figsize=(5, 5),
    fontsize=12,
    dpi_save=1200,
    frameon=False,
    facecolor='white',
)
mpl.rcParams.update({'figure.facecolor': 'white'})

# Files and parameters

In [None]:
# AnnData (.h5ad) objects: input before QC and QC-filtered output
raw_h5ad_file='data/scRNAseq/object/raw.h5ad'
qc_h5ad_file='data/scRNAseq/object/qc.h5ad'

In [None]:
# QC thresholds (a cell is kept only if it passes all of them)
nUMI_min=1000     # minimum total counts per cell
nFeature_min=600  # minimum number of detected genes per cell
pMT_max=5         # maximum percentage of counts from mitochondrial (mt-) genes
pRP_min=0         # minimum percentage of counts from ribosomal genes; 0 keeps every cell

# Import filtered CellRanger h5ad

In [None]:
# Load the pre-QC object and keep a copy of its unmodified matrix (presumably
# raw counts) in .raw, so it can be restored after the normalization below.
adata=sc.read_h5ad(raw_h5ad_file)
adata.raw=adata.copy()

# Compute QC metrics

In [None]:
# Flag gene groups that are used as QC variables (mouse-style symbols).
adata.var['MT']=adata.var_names.str.startswith('mt-')
adata.var['HB']=adata.var_names.str.startswith(('Hba', 'Hbb', 'Hbq1b', 'Hbq1a'))
# Ribosomal protein genes. The bare 'Rps' prefix also matches non-ribosomal
# genes (the Rps6k* ribosomal protein S6 kinases and Rps19bp1), so those
# are excluded.
adata.var['RP']=adata.var_names.str.startswith(('Rpl', 'Rps')) & ~adata.var_names.str.startswith(('Rps6k', 'Rps19bp'))
adata.var['XIST']=adata.var_names.str.startswith('Xist')

In [None]:
# Display the genes flagged as hemoglobin genes, to sanity-check the prefixes.
adata.var_names[adata.var['HB']]

In [None]:
# Compute per-cell QC metrics into .obs (total_counts, n_genes_by_counts and
# pct_counts_<group>/log1p_total_counts_<group> for each flagged gene group).
sc.pp.calculate_qc_metrics(adata, qc_vars=['MT', 'HB', 'RP', 'XIST'], percent_top=None, inplace=True)

# LogNormalize data 

In [None]:
# Scale each cell to 10,000 total counts and log1p-transform .X. The QC
# metrics above were computed before this step. The matrix stored in .raw is
# restored further below.
sc.pp.normalize_total(adata, target_sum=10000)
sc.pp.log1p(adata)

# Quality control plots

In [None]:
# Per-sample QC metric distributions with the thresholds defined above passed
# as cutoffs (plot helpers presumably come from the adata_qc star import).
plot_qc_density(adata, 'total_counts', cutoff=nUMI_min, n_cols=8)

In [None]:
plot_qc_density(adata, 'n_genes_by_counts', cutoff=nFeature_min, n_cols=8)

In [None]:
plot_qc_density(adata, 'pct_counts_MT', cutoff=pMT_max, n_cols=8)

In [None]:
# With pRP_min=0 this threshold removes no cells.
plot_qc_density(adata, 'pct_counts_RP', cutoff=pRP_min, n_cols=8)

In [None]:
# Pairwise QC scatter plots split by obs['sample_name'], with x/y thresholds.
plot_qc_scatter(adata, 'total_counts', 'n_genes_by_counts', x_cutoff=nUMI_min, y_cutoff=nFeature_min, sample_col='sample_name', dot_size=20, n_cols=8)

In [None]:
plot_qc_scatter(adata, 'pct_counts_RP', 'n_genes_by_counts', x_cutoff=pRP_min, y_cutoff=nFeature_min, sample_col='sample_name', dot_size=20, n_cols=8)

In [None]:
plot_qc_scatter(adata, 'pct_counts_RP', 'pct_counts_MT', x_cutoff=pRP_min, y_cutoff=pMT_max, sample_col='sample_name', dot_size=20, n_cols=8)

# Restore raw count matrix

In [None]:
# Discard the normalized matrix by rebuilding the object from .raw, then store
# a fresh copy of that matrix in .raw again for the later steps.
adata=adata.raw.to_adata()
adata.raw=adata.copy()

# Filter by QC thresholds

In [None]:
# Keep cells that pass every QC threshold. Take an explicit copy: indexing
# returns a view, and the later in-place .obs edits (SOLO merge, score
# scaling) would otherwise have to materialise it implicitly, with warnings.
adata=adata[
    (adata.obs['total_counts'] >= nUMI_min)
    & (adata.obs['n_genes_by_counts'] >= nFeature_min)
    & (adata.obs['pct_counts_MT'] <= pMT_max)
    & (adata.obs['pct_counts_RP'] >= pRP_min)
].copy()

# Solo detection on facility subsets

## Train SCVI model

In [None]:
# If True, reload the per-subset SCVI models from disk instead of retraining them.
cache_scvi = True

In [None]:
def model_history(model):
    """Plot the training curves of a trained scvi-tools model.

    Draws two side-by-side panels, reconstruction loss and ELBO. Each panel
    shows the training curve and, when it was logged, the validation curve.

    Parameters
    ----------
    model
        Trained scvi-tools model whose ``history`` maps metric names such as
        ``'elbo_train'`` to DataFrames with a column of the same name.

    Returns
    -------
    tuple
        ``(fig, axes)`` so callers can save or adjust the figure. Callers that
        ignore the return value behave exactly as before.
    """
    fig, axes=plt.subplots(1, 2, figsize=(10, 5))

    panels=[('reconstruction_loss', 'Reconstruction Loss'), ('elbo', 'ELBO')]
    for ax, (metric, title) in zip(axes, panels):
        for split in ('train', 'validation'):
            key=f'{metric}_{split}'
            # Validation metrics exist only when a validation split was used.
            if key in model.history:
                ax.plot(model.history[key][key], label=split)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()

    return fig, axes

In [None]:
def scvi_workflow(adata, subset=None):
    """Train and save an SCVI model on one sample group, for SOLO later on.

    Parameters
    ----------
    adata
        AnnData with counts in ``.X`` and a categorical
        ``obs['sample_group_rep']`` column. Training happens on a copy.
    subset
        Category of ``obs['sample_group_rep']`` to train on. Required; the
        ``None`` default is kept only for signature compatibility.

    Side effects
    ------------
    Plots the training history and writes the model to
    ``data/scRNAseq/object/scvi/qc/model/<subset>`` and the training AnnData
    to ``data/scRNAseq/object/scvi/qc/model/adata_<subset>.h5ad``.

    Raises
    ------
    ValueError
        If ``subset`` is None (would otherwise select no cells and fail later).
    """
    if subset is None:
        raise ValueError('scvi_workflow requires a sample_group_rep subset')

    out_dir='data/scRNAseq/object/scvi/qc/model'

    # Subset cells to the requested sample group.
    adata=adata[adata.obs.sample_group_rep==subset]

    # Subset genes: keep genes with >= 3 counts in >= 3 cells. np.asarray/ravel
    # turns the (1, n_genes) matrix produced for sparse X into a flat mask.
    gene_mask=np.asarray((adata.X>=3).sum(axis=0)>=3).ravel()
    adata=adata[:, gene_mask]

    # Materialise the view into an independent AnnData before registration.
    adata=adata.copy()

    # NOTE(review): after the cell subset 'sample_group_rep' holds a single
    # value here, so it acts as a constant batch covariate.
    scvi.model.SCVI.setup_anndata(adata, batch_key='sample_group_rep')

    model=scvi.model.SCVI(adata, n_latent=30, n_hidden=128, n_layers=2, gene_likelihood='nb')

    # Max epochs heuristic: scales inversely with the number of cells
    # (400 epochs at 20,000 cells), capped at 400.
    max_epochs=int(min(round((20000 / adata.n_obs) * 400), 400))
    model.train(max_epochs=max_epochs, check_val_every_n_epoch=1)

    model_history(model)

    model.save(os.path.join(out_dir, subset), overwrite=True, save_anndata=False)
    adata.write_h5ad(os.path.join(out_dir, 'adata_' + subset + '.h5ad'))

In [None]:
# For every sample group: either train a fresh SCVI model, or reload the
# cached model (with the AnnData it was trained on) and redraw its curves.
for subset in adata.obs['sample_group_rep'].cat.categories:

    if not cache_scvi:
        scvi_workflow(adata.copy(), subset)
        continue

    model_history(
        scvi.model.SCVI.load(
            'data/scRNAseq/object/scvi/qc/model/' + subset,
            adata=sc.read_h5ad('data/scRNAseq/object/scvi/qc/model/adata_' + subset + '.h5ad'),
        )
    )

## Solo doublet detection 

In [None]:
# If True, skip SOLO training and reuse the per-subset results saved on disk.
cache_solo = True

In [None]:
def solo_workflow(subset=None):
    """Train SOLO on a saved per-subset SCVI model and store its doublet calls.

    Parameters
    ----------
    subset
        Category of ``obs['sample_group_rep']`` whose SCVI model and AnnData
        (written by ``scvi_workflow``) are loaded. Required; the ``None``
        default is kept only for signature compatibility.

    Side effects
    ------------
    Writes the SOLO model to ``data/scRNAseq/object/scvi/qc/model_solo/<subset>``
    and the subset AnnData with ``solo_pred_doublet``, ``solo_pred_singlet``
    and ``solo_label`` added to ``.obs`` to
    ``data/scRNAseq/object/scvi/qc/model_solo/adata_<subset>.h5ad``.

    Raises
    ------
    ValueError
        If ``subset`` is None.
    """
    if subset is None:
        raise ValueError('solo_workflow requires a sample_group_rep subset')

    scvi_dir='data/scRNAseq/object/scvi/qc/model'
    solo_dir='data/scRNAseq/object/scvi/qc/model_solo'

    # Import the SCVI model together with the AnnData it was trained on.
    adata=sc.read_h5ad(os.path.join(scvi_dir, 'adata_' + subset + '.h5ad'))
    model=scvi.model.SCVI.load(os.path.join(scvi_dir, subset), adata=adata)

    # Register the SOLO classifier on top of the SCVI model.
    # NOTE(review): the SCVI model was set up with batch_key='sample_group_rep';
    # confirm that SOLO's restrict_to_batch option is not needed here.
    model_solo=scvi.external.SOLO.from_scvi_model(model)

    # Same max-epochs heuristic as for SCVI: inverse in cell count, capped at 400.
    max_epochs=int(min(round((20000 / adata.n_obs) * 400), 400))
    model_solo.train(max_epochs=max_epochs, check_val_every_n_epoch=1)

    # Doublet/singlet probabilities and hard labels for the real cells only.
    prob=model_solo.predict(soft=True, include_simulated_doublets=False)
    label=model_solo.predict(soft=False, include_simulated_doublets=False)

    adata.obs['solo_pred_doublet']=prob['doublet']
    adata.obs['solo_pred_singlet']=prob['singlet']
    adata.obs['solo_label']=label

    model_solo.save(os.path.join(solo_dir, subset), overwrite=True, save_anndata=False)
    adata.write_h5ad(os.path.join(solo_dir, 'adata_' + subset + '.h5ad'))

In [None]:
if not cache_solo: 

    # Run SOLO on every sample-group subset (trains and saves one model each)
    for subset in adata.obs['sample_group_rep'].cat.categories: 
        solo_workflow(subset)

## Dimensionality reduction on latent space

In [None]:
# If True, skip recomputing the per-subset embeddings and only replot the saved UMAPs.
cache_dim = True

In [None]:
def dim_workflow(subset=None):
    """Embed one sample-group subset on its SCVI latent space and plot it.

    Loads the subset AnnData written by ``solo_workflow`` and the subset's
    SCVI model. It stores the latent representation in ``obsm['latent']`` and
    builds a 30-neighbour graph on it, Leiden clusters (``obs['leiden']``) and
    a UMAP. The result overwrites the subset file, and a grid of UMAP panels
    is drawn.

    Parameters
    ----------
    subset
        Category of ``obs['sample_group_rep']`` to process. Required; the
        ``None`` default is kept only for signature compatibility.

    Returns
    -------
    AnnData
        The processed subset.

    Raises
    ------
    ValueError
        If ``subset`` is None.
    """
    if subset is None:
        raise ValueError('dim_workflow requires a sample_group_rep subset')

    solo_file=os.path.join('data/scRNAseq/object/scvi/qc/model_solo', 'adata_' + subset + '.h5ad')

    # Import the SOLO-annotated subset and its SCVI model.
    adata=sc.read_h5ad(solo_file)
    model=scvi.model.SCVI.load(os.path.join('data/scRNAseq/object/scvi/qc/model', subset), adata=adata)

    adata.obsm['latent']=model.get_latent_representation()

    sc.pp.neighbors(adata, n_neighbors=30, use_rep='latent')
    sc.tl.leiden(adata, resolution=1, flavor='igraph', n_iterations=2, key_added='leiden')
    sc.tl.umap(adata, min_dist=1)

    # Overwrite the SOLO output so the embedding is cached alongside it.
    adata.write_h5ad(solo_file)

    # UMAP
    sc.pl.umap(adata, color=['sample_group_rep', 'S_score', 'G2M_score', 'solo_label', 'solo_pred_doublet', 'log1p_total_counts', 'log1p_n_genes_by_counts', 'pct_counts_MT', 'pct_counts_RP', 'pct_counts_HB'], frameon=False, ncols=5, wspace=0.1, size=50, legend_loc='on data', use_raw=False)

    return adata

In [None]:
# For every sample group: compute, save and plot the embedding, or, when
# cached, just redraw the UMAP from the saved per-subset object.
for subset in adata.obs['sample_group_rep'].cat.categories:

    if not cache_dim:
        dim_workflow(subset)
        continue

    sc.pl.umap(
        sc.read_h5ad('data/scRNAseq/object/scvi/qc/model_solo/adata_' + subset + '.h5ad'),
        color=['sample_group_rep', 'S_score', 'G2M_score', 'solo_label', 'solo_pred_doublet', 'log1p_total_counts', 'log1p_n_genes_by_counts', 'pct_counts_MT', 'pct_counts_RP', 'pct_counts_HB'],
        frameon=False,
        ncols=5,
        wspace=0.1,
        size=50,
        legend_loc='on data',
        use_raw=False,
    )

## Fetch SOLO results

In [None]:
# Stack the SOLO columns of every per-subset object into one DataFrame,
# indexed by the subsets' obs_names.
solo = pd.concat(
    sc.read_h5ad('data/scRNAseq/object/scvi/qc/model_solo/adata_' + subset + '.h5ad').obs[['solo_pred_singlet', 'solo_pred_doublet', 'solo_label']]
    for subset in adata.obs['sample_group_rep'].cat.categories
)

In [None]:
# Attach the SOLO calls to each cell by obs_name (cells without a call get
# NaN). Existing SOLO columns are dropped first so that re-running this cell
# replaces them instead of producing *_x / *_y duplicate columns.
adata.obs = adata.obs.drop(columns=solo.columns, errors='ignore').join(solo, how='left')

# Intersect detected genes across facilities

In [None]:
# Genes with a nonzero count in at least one BSF cell.
adata_bsf = adata[adata.obs['facility'] == "BSF"]
var_bsf = adata_bsf.var_names[np.asarray((adata_bsf.X > 0).sum(axis=0) >= 1).ravel()]

In [None]:
# Genes with a nonzero count in at least one VBC cell.
adata_vbc = adata[adata.obs['facility'] == "VBC"]
var_vbc = adata_vbc.var_names[np.asarray((adata_vbc.X > 0).sum(axis=0) >= 1).ravel()]

In [None]:
# Genes detected in both facilities. Index.intersection(sort=False) keeps
# var_vbc's (i.e. adata.var) order, unlike a Python set, whose iteration
# order varies between runs, so the list saved in .uns is reproducible.
adata.uns['var_intersect'] = var_vbc.intersection(var_bsf, sort=False).tolist()

## SCVI model 

In [None]:
# Min-max scale the cell-cycle scores to [0, 1] within each sample group, for
# use as continuous SCVI covariates below. A group with a constant score
# would give 0/0 = NaN and feed NaN covariates into SCVI; such groups are set
# to 0 instead. observed=True groups only by categories that actually occur.
for score in ['S_score', 'G2M_score']:
    by_group = adata.obs.groupby('sample_group_rep', observed=True)[score]
    group_min = by_group.transform('min')
    group_range = by_group.transform('max') - group_min
    adata.obs[score + '_scale'] = ((adata.obs[score] - group_min) / group_range).mask(group_range == 0, 0.0)

In [None]:
# Keep genes with >= 3 counts in >= 3 cells (same rule as the per-subset
# models); the mask is flattened explicitly for sparse X.
adata = adata[:, np.asarray((adata.X >= 3).sum(axis=0) >= 3).ravel()]

In [None]:
# Materialise the gene-subset view into an independent AnnData before
# registering it with scvi-tools.
adata = adata.copy()

In [None]:
# Joint model: train it below (False) or reload it from disk (True).
# Reuses the name of the per-subset flag defined earlier.
cache_scvi = False

In [None]:
if cache_scvi:

    # Reuse the joint model saved by a previous run.
    model = scvi.model.SCVI.load('data/scRNAseq/object/scvi/qc/model/model', adata=adata)

else:

    # Joint SCVI model over all QC-passing cells: sample group as batch, with
    # facility and the scaled cell-cycle scores as extra covariates.
    scvi.model.SCVI.setup_anndata(
        adata,
        batch_key='sample_group_rep',
        categorical_covariate_keys=['facility'],
        continuous_covariate_keys=['S_score_scale', 'G2M_score_scale'],
    )

    model = scvi.model.SCVI(adata, n_latent=30, n_hidden=128, n_layers=2, gene_likelihood='nb')

    # Max epochs heuristic: inverse in cell count, capped at 400.
    max_epochs = int(min(round((20000 / adata.n_obs) * 400), 400))
    model.train(max_epochs=max_epochs, check_val_every_n_epoch=1)

    model.save('data/scRNAseq/object/scvi/qc/model/model', overwrite=True)

In [None]:
# Joint embedding: compute it below (False) or reload it from disk (True).
cache_dim = False

In [None]:
if cache_dim:

    # Reuse the embedding saved by a previous run.
    adata = sc.read_h5ad('data/scRNAseq/object/scvi/qc/model/adata.h5ad')

else:

    # Neighbour graph, Leiden clusters and UMAP on the joint SCVI latent
    # space, then cache the result on disk.
    adata.obsm['latent'] = model.get_latent_representation()
    sc.pp.neighbors(adata, use_rep='latent', n_neighbors=30)
    sc.tl.leiden(adata, key_added='leiden', resolution=1, flavor='igraph', n_iterations=2)
    sc.tl.umap(adata, min_dist=1)
    adata.write_h5ad('data/scRNAseq/object/scvi/qc/model/adata.h5ad')

In [None]:
# Joint UMAP colored by cluster, facility, the label_main_* annotations
# (presumably reference-based labels present in the input object), SOLO calls
# and QC metrics.
sc.pl.umap(adata, color=['leiden', 'facility', 'label_main_immgen', 'label_main_haemopedia', 'solo_label', 'solo_pred_doublet', 'log1p_total_counts', 'log1p_n_genes_by_counts', 'log1p_total_counts_MT', 'log1p_total_counts_RP', 'log1p_total_counts_HB', 'log1p_total_counts_XIST'], frameon=False, ncols=6, wspace=0.25, size=20, legend_loc='on data', use_raw=False)

# Save result 

In [None]:
# Rebuild the object from .raw, which holds all genes (the SCVI gene filter
# above only subset .X), and store a copy of that matrix as .raw again. An
# explicit copy is used, as in the earlier raw reset, so .X and .raw are
# independent objects.
adata = adata.raw.to_adata()
adata.raw = adata.copy()

In [None]:
# Save the QC-filtered object to qc_h5ad_file.
adata.write_h5ad(qc_h5ad_file)