# Public data integration 

In [None]:
import scvi
import scanpy as sc

import scipy

import anndata as ad

import glob

import numpy as np
import pandas as pd

import matplotlib as mpl
import matplotlib.pyplot as plt

import sys
import os

# Reset logging configuration to default

In [None]:
import logging

# Reset the root logger to a plain WARNING-level configuration.
# force=True (Python >= 3.8) removes *and closes* any handlers that imported
# libraries already attached to the root logger before installing the default
# handler, so no manual clearing of the handler list is needed.
logging.basicConfig(level=logging.WARNING, force=True)

## Root directory

In [None]:
# Work from the project root so that the relative paths used later in this
# notebook (e.g. 'bin/') resolve against it.
os.chdir('/research/peer/fdeckert/FD20200109SPLENO')

## Custom modules

In [None]:
# Make the project-local helper modules in bin/ importable. The path is
# relative to the current working directory. Skip the append when the entry is
# already present so that re-running this cell does not keep growing sys.path.
if 'bin/' not in sys.path:
    sys.path.append('bin/')

# Import only the QC plotting helpers this notebook calls, instead of a
# wildcard import that could silently shadow names defined in the notebook.
from adata_qc import plot_qc_density, plot_qc_scatter

## Set up rpy2

In [None]:
# Point R_HOME at the R installation of the project's conda environment.
# NOTE(review): presumably this must be set before rpy2 is imported — confirm.
os.environ['R_HOME']='/nobackup/peer/fdeckert/miniconda3/envs/r.4.4.1-FD20200109SPLENO/lib/R'

In [None]:
# import rpy2
# %load_ext rpy2.ipython

## Figures 

In [None]:
# Global plotting defaults: square 5x5 panels, high-resolution export,
# no axis frame and a white background.
sc.set_figure_params(
    figsize=(5, 5),
    dpi_save=1200,
    fontsize=12,
    frameon=False,
    facecolor='white',
)
mpl.rcParams['figure.facecolor'] = 'white'

# Files and parameters

In [None]:
# QC thresholds applied per cell in the filtering step
nUMI_min = 1000      # keep cells with total_counts >= nUMI_min
nFeature_min = 600   # keep cells with n_genes_by_counts >= nFeature_min
pMT_max = 5          # keep cells with pct_counts_MT <= pMT_max
pRP_min = 0          # keep cells with pct_counts_RP >= pRP_min

# Import public data sets 

## GSE207412

In [None]:
# Load every 10x HDF5 file of GSE207412 and concatenate them into one AnnData.
files = sorted(glob.glob("/nobackup/peer/fdeckert/GEO/GSE207412/*.h5"))
if not files:
    # Fail early with a clear message instead of an opaque error from ad.concat.
    raise FileNotFoundError("No .h5 files found in /nobackup/peer/fdeckert/GEO/GSE207412")

adatas = []
sample_names = []
for i, f in enumerate(files):
    a = sc.read_10x_h5(f)
    a.var_names_make_unique()
    a.obs_names_make_unique()
    # File name without its extension identifies the original GEO sample.
    a.obs["sample"] = os.path.splitext(os.path.basename(f))[0]
    # Short running label; used later as the batch key.
    a.obs["sample_name"] = f"sample_{i}"
    sample_names.append(f"sample_{i}")
    adatas.append(a)

# Barcodes are only unique within a file, so the same barcode can occur in
# several samples. Appending the sample label ("<barcode>-sample_<i>") keeps
# obs_names unique in the combined object. The outer join keeps the union of
# genes across files.
adata = ad.concat(adatas, axis=0, join="outer", keys=sample_names, index_unique="-")

In [None]:
# Keep a snapshot of the object in .raw before X is normalized and
# log-transformed further down; it is restored from .raw after the QC plots.
adata.raw = adata.copy()

# Compute QC metrics

In [None]:
# Boolean gene flags by name prefix: mitochondrial, hemoglobin, ribosomal
# protein and Xist. Each flag becomes a column in adata.var.
gene_prefixes = {
    'MT': 'mt-',
    'HB': ('Hba', 'Hbb', 'Hbq1b', 'Hbq1a'),
    'RP': ('Rpl', 'Rps'),
    'XIST': 'Xist',
}
for flag, prefix in gene_prefixes.items():
    adata.var[flag] = adata.var_names.str.startswith(prefix)

In [None]:
# Compute QC metrics in place. Besides the global columns (total_counts,
# n_genes_by_counts, ...), every flag listed in qc_vars yields per-cell columns
# such as pct_counts_MT and log1p_total_counts_MT that are used by the plots
# and the filter later in this notebook. percent_top=None skips the
# "percentage of counts in the top N genes" columns.
sc.pp.calculate_qc_metrics(adata, qc_vars=['MT', 'HB', 'RP', 'XIST'], percent_top=None, inplace=True)

# LogNormalize data 

In [None]:
# Scale every cell to 10,000 total counts, then apply log(1 + x).
# This only changes adata.X; the snapshot stored in adata.raw earlier is
# restored into X after the QC plots.
sc.pp.normalize_total(adata, target_sum=10000)
sc.pp.log1p(adata)

# Quality control plots

In [None]:
# Density of total counts per cell with the nUMI_min cutoff.
# NOTE(review): plot_qc_density comes from the project-local adata_qc module;
# presumably it draws one panel per sample, 8 panels per row — confirm.
plot_qc_density(adata, 'total_counts', cutoff=nUMI_min, n_cols=8)

In [None]:
# Density of detected genes per cell with the nFeature_min cutoff
# (helper from the project-local adata_qc module).
plot_qc_density(adata, 'n_genes_by_counts', cutoff=nFeature_min, n_cols=8)

In [None]:
# Density of the mitochondrial count percentage with the pMT_max cutoff
# (helper from the project-local adata_qc module).
plot_qc_density(adata, 'pct_counts_MT', cutoff=pMT_max, n_cols=8)

In [None]:
# Density of the ribosomal-protein count percentage with the pRP_min cutoff
# (helper from the project-local adata_qc module).
plot_qc_density(adata, 'pct_counts_RP', cutoff=pRP_min, n_cols=8)

In [None]:
# Total counts (x) against detected genes (y) with both cutoffs, grouped by
# the sample_name column (helper from the project-local adata_qc module).
plot_qc_scatter(adata, 'total_counts', 'n_genes_by_counts', x_cutoff=nUMI_min, y_cutoff=nFeature_min, sample_col='sample_name', dot_size=20, n_cols=8)

In [None]:
# Ribosomal-protein percentage (x) against detected genes (y) with both
# cutoffs, grouped by the sample_name column (helper from adata_qc).
plot_qc_scatter(adata, 'pct_counts_RP', 'n_genes_by_counts', x_cutoff=pRP_min, y_cutoff=nFeature_min, sample_col='sample_name', dot_size=20, n_cols=8)

In [None]:
# Ribosomal-protein percentage (x) against mitochondrial percentage (y) with
# both cutoffs, grouped by the sample_name column (helper from adata_qc).
plot_qc_scatter(adata, 'pct_counts_RP', 'pct_counts_MT', x_cutoff=pRP_min, y_cutoff=pMT_max, sample_col='sample_name', dot_size=20, n_cols=8)

# Set raw count matrix

In [None]:
# Replace the log-normalized object with the one rebuilt from .raw (the
# snapshot taken before normalization), then store a fresh snapshot in .raw.
# The QC columns in adata.obs are still read by the filter that follows.
adata=adata.raw.to_adata()
adata.raw=adata.copy()

# Filter by QC metrics

In [None]:
# Keep only the cells that satisfy all four QC thresholds.
adata = adata[
    (adata.obs['total_counts'] >= nUMI_min)
    & (adata.obs['n_genes_by_counts'] >= nFeature_min)
    & (adata.obs['pct_counts_MT'] <= pMT_max)
    & (adata.obs['pct_counts_RP'] >= pRP_min)
]

In [None]:
def model_history(model):
    """Plot training and validation curves of a fitted model.

    Draws two side-by-side panels, reconstruction loss and ELBO, each with a
    'train' and a 'validation' line read from ``model.history``.

    Parameters
    ----------
    model
        Object whose ``history`` mapping provides the entries
        ``reconstruction_loss_train``, ``reconstruction_loss_validation``,
        ``elbo_train`` and ``elbo_validation``; each entry is indexed again
        by its own name to obtain the series to plot.
    """
    panels = (
        ('reconstruction_loss', 'Reconstruction Loss'),
        ('elbo', 'ELBO'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    for ax, (metric, title) in zip(axes, panels):
        for split in ('train', 'validation'):
            key = f'{metric}_{split}'
            ax.plot(model.history[key][key], label=split)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()

## SCVI model 

In [None]:
# Keep genes that have a count of at least 3 in at least 3 cells. The per-gene
# tally is flattened to a 1-D boolean mask before it is used for indexing.
adata = adata[:, np.asarray((adata.X >= 3).sum(axis=0)).ravel() >= 3]

In [None]:
# The preceding subsetting steps leave adata as a view; copy it into an
# independent object before it is registered with scvi and modified.
adata = adata.copy()

In [None]:
# False: train the SCVI model and save it. True: load the saved model instead.
cache_scvi = False

In [None]:
if cache_scvi:

    # Reuse the model saved by an earlier run.
    model = scvi.model.SCVI.load('/nobackup/peer/fdeckert/GEO/GSE207412/model', adata=adata)

else:

    # Register the data with sample_name as the batch covariate.
    scvi.model.SCVI.setup_anndata(adata, batch_key='sample_name')

    model = scvi.model.SCVI(adata, n_latent=30, n_hidden=128, n_layers=2, gene_likelihood='nb')

    # Epoch budget: 400 for up to 20,000 cells, scaled down proportionally
    # for larger data sets.
    max_epochs = min(round((20000 / adata.n_obs) * 400), 400)

    # Validate after every epoch so the history holds validation curves.
    model.train(max_epochs=max_epochs, check_val_every_n_epoch=1)

    model.save('/nobackup/peer/fdeckert/GEO/GSE207412/model', overwrite=True)

In [None]:
# False: compute latent space, neighbors, clusters and UMAP, then write the
# result to disk. True: read the previously written h5ad file instead.
cache_dim = False

In [None]:
if cache_dim:

    # Reuse the embedding written by an earlier run.
    adata = sc.read_h5ad('/nobackup/peer/fdeckert/GEO/GSE207412/model/adata.h5ad')

else:

    # Latent representation from the model, stored for the neighbor graph.
    adata.obsm['latent'] = model.get_latent_representation()

    # Neighbor graph on the latent space, then clustering and 2-D embedding.
    sc.pp.neighbors(adata, n_neighbors=30, use_rep='latent')
    sc.tl.leiden(
        adata,
        resolution=1,
        flavor='igraph',
        n_iterations=2,
        key_added='leiden',
    )
    sc.tl.umap(adata, min_dist=1)

    adata.write_h5ad('/nobackup/peer/fdeckert/GEO/GSE207412/model/adata.h5ad')

In [None]:
# UMAP coloured by sample and by the log1p-transformed QC totals.
sc.pl.umap(
    adata,
    color=[
        'sample_name',
        'log1p_total_counts',
        'log1p_n_genes_by_counts',
        'log1p_total_counts_MT',
        'log1p_total_counts_RP',
        'log1p_total_counts_HB',
        'log1p_total_counts_XIST',
    ],
    frameon=False,
    ncols=6,
    wspace=0.25,
    size=20,
    legend_loc='on data',
    use_raw=True,
)

# Save result 

In [None]:
# adata = adata.raw.to_adata()
# adata.raw = adata

In [None]:
# adata.write_h5ad(qc_h5ad_file)