In [None]:
from scvi.model.utils import mde
import pymde
import scanpy as sc
import scvi
import glob
import os
from functools import reduce
import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
from scipy.io import mmread
import gzip


## Daniel's functions for QC

In [None]:
def read_sc_data(
    counts_file,
    features_file,
    metadata_file
):
    """Assemble an AnnData object from exported counts, features and metadata.

    Parameters
    ----------
    counts_file
        Path to a Matrix Market file. The matrix is transposed after reading
        (see the comment in the body) and stored in CSR format as ``X``.
    features_file
        Path to a text file of whitespace-separated feature names; they
        become the index of ``var``.
    metadata_file
        Path to a tab-separated table; its first column becomes the index
        of ``obs``.

    Any path ending in ``gz`` is opened as gzip-compressed text, every other
    path as plain text.

    Returns
    -------
    The ``ad.AnnData`` object built from the three inputs.
    """
    def open_text(filename):
        # gzip-compressed exports are decompressed on the fly, in text mode
        if str(filename).endswith('gz'):
            return gzip.open(filename, 'rt')
        return open(filename, 'r')

    with open_text(counts_file) as file:
        # transpose due to the way the data was exported to comply with Seurat
        # see also convert_to_raw.ipynb
        counts = mmread(file).T.tocsr()

    with open_text(features_file) as file:
        features = pd.DataFrame(
            index = file.read().rstrip().split()
        )

    with open_text(metadata_file) as file:
        metadata = pd.read_csv(
            file,
            sep = '\t',
            index_col = 0
        )

    return ad.AnnData(
        X = counts,
        obs = metadata,
        var = features
    )


def compute_qc_metrics(adata):
    """Add the per-cell percentage of ``^rp[sl]`` counts to ``adata.obs``.

    For every observation, the counts of all features whose name matches the
    pattern ``^rp[sl]`` are summed, divided by the observation's total counts
    and multiplied by 100. The result is written to
    ``adata.obs['percent_ribo']``. ``adata`` is modified in place and nothing
    is returned.
    """
    # The other metrics are currently disabled here:
    # adata.obs['nFeature_RNA'] = np.array((adata.X > 0).sum(axis = 1)).flatten()
    # adata.obs['nCount_RNA'] = np.array(adata.X.sum(axis = 1)).flatten()
    # adata.obs['percent_mt'] = np.array(
    #     adata[:, adata.var.index.str.match('^mt.')].X.sum(axis = 1) / adata.X.sum(axis = 1) * 100
    # ).flatten()

    is_ribo_feature = adata.var.index.str.match('^rp[sl]')
    ribo_counts = adata[:, is_ribo_feature].X.sum(axis = 1)
    total_counts = adata.X.sum(axis = 1)

    # csr_matrix.sum returns a numpy.matrix object which cannot be broadcast
    # to the obs frame, hence the conversion to a flat array
    percent_ribo = np.array(ribo_counts / total_counts * 100).flatten()
    adata.obs['percent_ribo'] = percent_ribo


def apply_qc_thresholds(adata, sample_id_column, sample_thresholds):
    """Flag the cells that satisfy their sample's QC thresholds.

    ``adata.obs['qc_pass']`` is reset to ``True`` for every cell and then
    overwritten for the cells of each sample listed in ``sample_thresholds``.
    Cells of samples that are not listed keep ``True``. ``adata`` is modified
    in place and nothing is returned.

    Parameters
    ----------
    adata
        Object exposing an ``obs`` data frame.
    sample_id_column
        Name of the ``obs`` column that holds the sample identifier.
    sample_thresholds
        Mapping ``sample_id -> {feature: (lo, hi)}``. A cell passes a feature
        when ``lo < value < hi`` (both bounds exclusive); it passes QC when it
        passes every feature. A bound given as ``None`` leaves that side
        open, matching the bounds ``generate_plots`` accepts. A sample with
        an empty feature mapping passes all of its cells.
    """
    adata.obs['qc_pass'] = True
    for sample_id, thresholds in sample_thresholds.items():
        in_sample = adata.obs[sample_id_column] == sample_id
        # start from "everything passes" so that an empty threshold mapping
        # needs no special case
        passed = np.ones(int(in_sample.sum()), dtype = bool)
        for feature, (lo, hi) in thresholds.items():
            # plain arrays: comparisons are positional, and NaN fails both
            values = adata.obs.loc[in_sample, feature].to_numpy()
            if lo is not None:
                passed &= values > lo
            if hi is not None:
                passed &= values < hi

        adata.obs.loc[in_sample, 'qc_pass'] = passed


def generate_plots(
    axs,
    df,
    qc_pass_idx,
    thresholds = None
):
    """Draw the QC overview of one group of cells onto a 2 x 2 grid of axes.

    Row 0 holds histograms of ``nFeature_RNA`` and ``percent_mt``. Row 1 holds
    scatter plots of ``nCount_RNA`` vs ``nFeature_RNA`` and ``nFeature_RNA``
    vs ``percent_mt``, each overlaid with a two-dimensional KDE of the cells
    selected by ``qc_pass_idx``.

    Parameters
    ----------
    axs
        Grid of axes indexed as ``axs[row, column]``.
    df
        Frame providing the columns ``nCount_RNA``, ``nFeature_RNA`` and
        ``percent_mt``.
    qc_pass_idx
        One boolean per row of ``df``. Unless every entry is true, cells are
        coloured as ``pass`` / ``fail``.
    thresholds
        Optional mapping ``column -> iterable of positions``. Every truthy
        position is drawn as a thin black line; ``None`` and ``0`` are
        skipped.
    """
    def mark_thresholds(column, draw_line):
        # only called when `thresholds` is truthy
        if column in thresholds:
            for bound in thresholds[column]:
                if bound:
                    draw_line(
                        bound,
                        color = 'k',
                        linewidth = 1
                    )

    # colour by QC outcome only if at least one cell failed
    if all(qc_pass_idx):
        hue = None
    else:
        hue = ['pass' if passed else 'fail' for passed in qc_pass_idx]
    palette = {'pass': '#4B72B1', 'fail': 'red'} if hue else None

    # top row: per-metric distributions ('percent_ribo' is currently disabled)
    histogram_columns = ('nFeature_RNA', 'percent_mt')
    for col_idx, column in enumerate(histogram_columns):
        axis = axs[0, col_idx]
        sns.histplot(
            x = df.loc[:, column],
            ax = axis,
            hue = hue,
            palette = palette,
            kde = True,
            fill = True
        )
        if thresholds:
            mark_thresholds(column, axis.axvline)

    # bottom row: pairwise relations
    # (('percent_mt', 'percent_ribo') is currently disabled)
    scatter_pairs = (
        ('nCount_RNA', 'nFeature_RNA'),
        ('nFeature_RNA', 'percent_mt'),
    )
    for col_idx, (xcol, ycol) in enumerate(scatter_pairs):
        axis = axs[1, col_idx]
        sns.scatterplot(
            x = df.loc[:, xcol],
            y = df.loc[:, ycol],
            ax = axis,
            hue = hue,
            palette = palette,
            edgecolor = 'k',
            facecolor = None,
            color = None,
            alpha = 0.5
        )
        # density of the passing cells only
        sns.kdeplot(
            x = df.loc[qc_pass_idx, xcol],
            y = df.loc[qc_pass_idx, ycol],
            ax = axis,
            color = 'lightblue'
        )

        if thresholds:
            # x thresholds become vertical lines, y thresholds horizontal ones
            mark_thresholds(xcol, axis.axvline)
            mark_thresholds(ycol, axis.axhline)

                            
def plot_qc(
    adata,
    thresholds = None,
    sample_id_column = None,
    sharex = False
):
    """Build the QC overview figure, optionally with one row per sample.

    Parameters
    ----------
    adata
        Object whose ``obs`` frame holds the QC columns read by
        ``generate_plots``. If ``obs`` has a ``qc_pass`` column it is used to
        colour the cells; otherwise every cell is treated as passing.
    thresholds
        Without ``sample_id_column``: mapping ``column -> positions`` passed
        straight to ``generate_plots``. With ``sample_id_column``: mapping
        ``sample_id -> {column: positions}``; samples missing from the
        mapping are drawn without threshold lines.
    sample_id_column
        Name of the ``obs`` column identifying samples. If falsy, a single
        2 x 2 figure of all cells is produced; otherwise the figure has one
        row of four panels per sample, labelled with the sample id.
    sharex
        Share the x axis within each column of the per-sample figure.

    Returns
    -------
    The matplotlib figure.
    """
    def pass_flags(obs):
        if 'qc_pass' in obs.columns:
            return obs['qc_pass']
        return [True] * obs.shape[0]

    if not sample_id_column:
        fig, axs = plt.subplots(2, 2)
        generate_plots(
            axs,
            adata.obs,
            qc_pass_idx = pass_flags(adata.obs),
            thresholds = thresholds
        )
        return fig

    # the same list drives the number of rows and the loop, so the two cannot
    # disagree when the column contains missing values
    sample_ids = adata.obs[sample_id_column].dropna().unique()
    fig, axs = plt.subplots(
        len(sample_ids),
        4,
        sharex = 'col' if sharex else 'none',
        # keep a 2-D axes array even when there is only one sample
        squeeze = False
    )
    for i, sample_id in enumerate(sample_ids):
        # only the obs frame is needed, so subset it directly
        sample_obs = adata.obs.loc[adata.obs[sample_id_column] == sample_id, :]
        generate_plots(
            axs[i, :].reshape(2, 2),
            sample_obs,
            qc_pass_idx = pass_flags(sample_obs),
            thresholds = thresholds.get(sample_id) if thresholds else None
        )
        axs[i, 0].set_ylabel(sample_id)

    return fig


def integrate_data_scvi(
    adata,
    batch_key,
    categorical_covariate_keys = None,
    continuous_covariate_keys = None,
    use_highly_variable_genes = True,
    n_top_genes = 4000,
    use_gpu = True,
    max_epochs = None,
    train_size = 0.9

):
    """Train an scVI model on ``adata`` and embed the cells with UMAP.

    ``adata`` is modified in place: ``X`` is copied to ``layers['counts']``,
    the object is assigned to ``adata.raw``, optionally the genes are subset
    to highly variable ones, and the scVI latent representation
    (``obsm['X_scvi']``), the neighbour graph and the UMAP are added.

    Parameters
    ----------
    adata
        Annotated data object.
    batch_key, categorical_covariate_keys, continuous_covariate_keys
        Forwarded to ``scvi.model.SCVI.setup_anndata``.
    use_highly_variable_genes
        If true, run ``sc.pp.highly_variable_genes`` (flavor ``seurat_v3`` on
        the ``counts`` layer, ``subset = True``) before training.
    n_top_genes
        Number of highly variable genes requested.
    use_gpu, max_epochs, train_size
        Forwarded to ``model.train``.
        NOTE(review): confirm that the installed scvi-tools version still
        accepts ``use_gpu``.

    Returns
    -------
    dict
        ``{'data': adata, 'model': model}``.
    """
    # keep the untouched matrix around before anything is subset
    adata.layers['counts'] = adata.X.copy()
    adata.raw = adata

    if use_highly_variable_genes:
        print('computing highly variable genes')
        hvg_options = {
            'n_top_genes': n_top_genes,
            'layer': 'counts',
            'subset': True,
            'flavor': 'seurat_v3',
        }
        sc.pp.highly_variable_genes(adata, **hvg_options)

    registration = {
        'layer': 'counts',
        'batch_key': batch_key,
        'categorical_covariate_keys': categorical_covariate_keys,
        'continuous_covariate_keys': continuous_covariate_keys,
    }
    scvi.model.SCVI.setup_anndata(adata, **registration)

    # non default parameters from scVI tutorial and scIB github
    # see https://docs.scvi-tools.org/en/stable/tutorials/notebooks/harmonization.html
    # and https://github.com/theislab/scib/blob/main/scib/integration.py
    architecture = {'n_layers': 2, 'n_latent': 30, 'gene_likelihood': 'nb'}
    model = scvi.model.SCVI(adata, **architecture)

    training = {
        'use_gpu': use_gpu,
        'max_epochs': max_epochs,
        'train_size': train_size,
    }
    model.train(**training)

    latent = model.get_latent_representation()
    adata.obsm['X_scvi'] = latent

    print('compute umap from scvi embedding')
    sc.pp.neighbors(adata, use_rep = 'X_scvi')
    sc.tl.umap(adata)

    return {'data': adata, 'model': model}

## Loading the packages

In [None]:
# Seed for reproducibility
import torch
import numpy as np
import pandas as pd
import scanpy as sc
from typing import Tuple

# scVI imports
import scvi
# from scvi.dataset import AnnDatasetFromAnnData
# from scvi.inference import UnsupervisedTrainer
# from scvi.models.vae import VAE

# scvi-tools keeps its own global seed setting; set it together with the
# torch and numpy seeds so that the scVI training runs below are repeatable
scvi.settings.seed = 0
torch.manual_seed(0)
np.random.seed(0)
sc.settings.verbosity = 0  # verbosity: errors (0), warnings (1), info (2), hints (3)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
# from scvi.external import CellAssign

In [None]:
# Default scanpy figure size for the plots below; a later cell calls
# sc.settings.set_figure_params again with a larger size.
sc.set_figure_params(figsize=(4, 4))

# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
# render inline figures in 'retina' format
%config InlineBackend.figure_format='retina'

### Reading in final and pilot experiments

In [None]:
# Load the two experiments from their h5ad files (paths are relative to the
# notebook's working directory).
bm_final = sc.read_h5ad( "../data/h5ad/" + "bm_final_singlet.h5ad")
bm_pilot = sc.read_h5ad( "../data/h5ad/" + "bm_pilot_singlet.h5ad")

In [None]:
# Make the observation names unique within each dataset.
bm_final.obs_names_make_unique()
bm_pilot.obs_names_make_unique()

In [None]:
# Dimensions of both datasets.
print(bm_final.shape)
print(bm_pilot.shape)

In [None]:
# Peek at the top-left corner of the matrix; .todense() assumes X is sparse.
bm_final.X[:5, :5].todense()


In [None]:
# Summary of the final dataset.
bm_final

In [None]:
# The QC plotting helpers read the column 'percent_mt', so rename the
# 'percent.mt' column accordingly.
bm_final.obs = bm_final.obs.rename(columns={'percent.mt': 'percent_mt'})
bm_pilot.obs = bm_pilot.obs.rename(columns={'percent.mt': 'percent_mt'})

## Preprocessing - done in R

In [None]:
# min_genes = 0
# min_cells = 3

# sc.settings.verbosity = 2
# sc.pp.filter_cells(bm, min_genes=min_genes)
# sc.pp.filter_genes(bm, min_cells=min_cells)
# sc.pp.filter_cells(bm, min_genes=1)

In [None]:
# mito_genes = bm.var_names.str.startswith("mt-")
# bm.obs["percent_mito"] = (
#     np.sum(bm[:, mito_genes].X, axis=1).A1 / np.sum(bm.X, axis=1).A1
# )
# bm.obs["n_counts"] = bm.X.sum(axis=1).A1

In [None]:
# (bm.obs['percent_mito']).shape

In [None]:
# adata = adata[adata.obs["n_genes"] < 2500, :]
# adata = adata[adata.obs["percent_mito"] < 0.05, :]

In [None]:
# bm.layers["counts"] = bm.X.copy() # preserve counts
# sc.pp.normalize_total(bm, target_sum=1e4)
# sc.pp.log1p(bm)
# bm.raw = bm # freeze the state in `.raw`

## Checking QC

In [None]:
# Adds obs['percent_ribo'] to the final dataset (in place).
compute_qc_metrics(bm_final)
# One row of QC panels per value of obs['HTO_maxID'].
fig = plot_qc(
    bm_final,
    sample_id_column = 'HTO_maxID'
)
fig.set_figwidth(15)
# Scale the figure height with the number of rows (one per sample).
fig.set_figheight(bm_final.obs.HTO_maxID.nunique() * 2.5)
fig.tight_layout()

In [None]:
# Adds obs['percent_ribo'] to the pilot dataset (in place).
compute_qc_metrics(bm_pilot)
# One row of QC panels per value of obs['HTO_maxID'].
fig = plot_qc(
    bm_pilot,
    sample_id_column = 'HTO_maxID'
)
fig.set_figwidth(15)
# The figure has one row per *pilot* sample, so its height has to be derived
# from bm_pilot (not bm_final).
fig.set_figheight(bm_pilot.obs.HTO_maxID.nunique() * 2)
fig.tight_layout()

## Setting up the count layers for scVI

In [None]:
# Keep a copy of X in the 'counts' layer; the scVI setup and the
# highly-variable-gene selection further down read from this layer.
bm_pilot.layers["counts"] = bm_pilot.X.copy()
bm_final.layers["counts"] = bm_final.X.copy()

In [None]:
# Check the object type.
type(bm_pilot)

## Combining the two datasets with all raw counts

In [None]:
# Datasets to integrate: only the pilot cells whose HTO_maxID is one of the
# three listed hashtags, plus the complete final dataset.
adatas = {
    "pilot": bm_pilot[bm_pilot.obs['HTO_maxID'].isin(['HTO-PBS1', 'HTO-PBS2', 'HTO-SA1'])],
    "final": bm_final
}

In [None]:
# Concatenate both datasets; the dictionary keys are written to obs['batch'].
# NOTE(review): `join` and `index_unique` are left at their defaults --
# presumably an inner join on genes, with observation names that may repeat
# across the two batches; confirm against the anndata documentation.
bm = ad.concat(adatas, label="batch")

In [None]:
# Dimensions of the combined dataset.
bm.shape

In [None]:
# Register the combined data with scVI: counts layer as input, 'batch' as
# batch key, HTO_maxID and percent_mt as additional covariates.
scvi.model.SCVI.setup_anndata(bm, 
                              layer="counts", 
                              batch_key="batch", categorical_covariate_keys=["HTO_maxID"],
    continuous_covariate_keys=["percent_mt"])

In [None]:
# Same model settings as in integrate_data_scvi above.
vae = scvi.model.SCVI(bm, n_layers=2, n_latent=30, gene_likelihood="nb")

In [None]:
# Train with the default training arguments.
vae.train()

In [None]:
# Assign the current object to .raw.
bm.raw = bm

In [None]:
# Store the latent representation of the trained model.
bm.obsm["X_scVI"] = vae.get_latent_representation()

In [None]:
# Store the model's normalized expression as an additional layer.
bm.layers["scVI_normalized"] = vae.get_normalized_expression(bm)

In [None]:
# Inspect the stored layer.
bm.layers["scVI_normalized"]

In [None]:
bm.shape

In [None]:
# Maximum normalized value along axis 1.
np.max(bm.layers["scVI_normalized"], axis = 1)

In [None]:
# Neighbour graph on the scVI latent space, followed by Leiden clustering.
sc.pp.neighbors(bm, use_rep="X_scVI")
sc.tl.leiden(bm)

In [None]:
# Embedding of the latent space computed by scvi's `mde` helper (imported at
# the top of the notebook), used for plotting.
bm.obsm["X_mde"] = mde(bm.obsm["X_scVI"])

In [None]:
sc.settings.set_figure_params(dpi=80, frameon=False, figsize=(8, 6), facecolor='white')  # low dpi (dots per inch) yields small inline figures

In [None]:
# Show the working directory.
# NOTE(review): presumably a check of where the `save=` figure below ends up;
# `os` is already imported in the first cell.
import os
os.getcwd()

In [None]:
# Plot the embedding coloured by batch, Leiden cluster and hashtag, and save
# the figure via scanpy's `save` argument.
sc.pl.embedding(
    bm,
    basis="X_mde",
    color=["batch", "leiden", "HTO_maxID"],
    frameon=False,
    ncols=2,
    save = "integration_all_genes.pdf"
)

In [None]:
# # run PCA then generate UMAP plots
# bm.raw = bm
# sc.pp.scale(bm)
# sc.tl.pca(bm)
# sc.pp.neighbors(bm, n_pcs=30, n_neighbors=20)
# sc.pl.umap(
#     bm,
#     color=["batch", "leiden", "HTO_maxID"],
#     frameon=False,
#     ncols=2,
# )

## Combining cells based on highly variable genes

In [None]:
# marker_genes = pd.read_csv('../output/markers_pivot.tsv', sep = "\t", index_col = 0)
# marker_genes.head()

In [None]:
# Work on copies so that bm_pilot / bm_final keep all of their genes.
hvg_pilot_ = bm_pilot.copy()
hvg_final_ = bm_final.copy()

In [None]:
# Select 4000 highly variable genes per dataset from the counts layer
# (flavor 'seurat_v3'); subset=True requests that the copies be reduced to
# the selected genes.
sc.pp.highly_variable_genes(
    hvg_pilot_,
    n_top_genes=4000,
    subset=True,
    layer="counts",
    flavor="seurat_v3"
)
sc.pp.highly_variable_genes(
    hvg_final_,
    n_top_genes=4000,
    subset=True,
    layer="counts",
    flavor="seurat_v3"
)

In [None]:
# hvg_plus_marker = hvg_pilot_.var_names.union(hvg_final_.var_names).union(marker_genes.index)
# len(hvg_plus_marker)

# Union of the genes kept in either dataset; only its size is displayed here.
hvg_geneset = hvg_pilot_.var_names.union(hvg_final_.var_names)
len(hvg_geneset)

In [None]:
# hvg_pilot = bm_pilot[:, hvg_plus_marker.intersection(bm_pilot.var_names)].copy()
# hvg_final = bm_final[:, hvg_plus_marker.intersection(bm_final.var_names)].copy()

In [None]:
# Maximum of X along axis 1 in the pilot copy; .toarray() assumes the result
# is sparse.
np.max(hvg_pilot_.X, axis = 1).toarray()

In [None]:
# Compare the shapes of the subset copy and the original pilot dataset.
hvg_pilot_.shape

In [None]:
bm_pilot.shape

In [None]:
adatas_hvg = {
    "pilot": hvg_pilot_[bm_pilot.obs['HTO_maxID'].isin(['HTO-PBS1', 'HTO-PBS2', 'HTO-SA1'])],
    "final": hvg_final_
}

In [None]:
hvg = ad.concat(adatas_hvg, label="batch") 

In [None]:
# Register the HVG data with scVI using the same keys as for the all-genes
# run: counts layer, 'batch' as batch key, HTO_maxID and percent_mt as
# covariates.
scvi.model.SCVI.setup_anndata(hvg, 
                              layer="counts", 
                              batch_key="batch", categorical_covariate_keys=["HTO_maxID"],
    continuous_covariate_keys=["percent_mt"])

In [None]:
# Same model settings as for the all-genes model.
vae_hvg = scvi.model.SCVI(hvg, n_layers=2, n_latent=30, gene_likelihood="nb")

In [None]:
# Train with the default training arguments.
vae_hvg.train()

In [None]:
# Assign the current object to .raw.
hvg.raw = hvg

In [None]:
# Store the model's normalized expression as an additional layer.
hvg.layers['scVI_normalized'] = vae_hvg.get_normalized_expression(hvg)

In [None]:
# List the available layers.
hvg.layers

In [None]:
# Store the latent representation of the trained model.
hvg.obsm["X_scVI"] = vae_hvg.get_latent_representation()

In [None]:
# Neighbour graph on the scVI latent space, followed by Leiden clustering.
sc.pp.neighbors(hvg, use_rep="X_scVI")
sc.tl.leiden(hvg)

In [None]:
# Embedding of the latent space computed by scvi's `mde` helper, used for
# plotting.
hvg.obsm["X_mde"] = mde(hvg.obsm["X_scVI"])

In [None]:
sc.settings.set_figure_params(dpi=80, frameon=False, figsize=(8, 6), facecolor='white')  # low dpi (dots per inch) yields small inline figures

In [None]:
# Plot the embedding coloured by batch, Leiden cluster and hashtag, and save
# the figure via scanpy's `save` argument.
sc.pl.embedding(
    hvg,
    basis="X_mde",
    color=["batch", "leiden", "HTO_maxID"],
    frameon=False,
    ncols=2,
    save = "integration_hvg_genes.pdf"
)

## Saving

In [None]:
# Write the HVG-based integration to disk.
hvg.write("../data/h5ad/hvg_integrated_170824.h5ad")

In [None]:
# Write the all-genes integration to disk.
bm.write("../data/h5ad/bm_integrated_170824.h5ad")