In [None]:
import sys
import matplotlib
matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
from datetime import datetime

import numpy as np
import numpy.random as random
import pandas as pd
import scanpy as sc
import louvain
import torch

from scvi.dataset import Dataset10X, CsvDataset, AnnDatasetFromAnnData, CellMeasurement, LoomDataset, DownloadableAnnDataset
from scvi.dataset.dataset import GeneExpressionDataset
from scvi.inference import TotalPosterior, TotalTrainer, load_posterior
from scvi.models import SCANVI, TOTALVI
from scvi import set_seed

from umap import UMAP

# Control UMAP numba warnings
import warnings; warnings.simplefilter('ignore')

%matplotlib inline

# Fix the random seed through scvi's helper so repeated runs are comparable
set_seed(123)

# Notebook-wide switches; use_cuda is passed to the trainer further down.
# show_plot and test_mode are not referenced in the cells visible here.
use_cuda = True
show_plot = True
test_mode = False
# Base folder for this analysis (placeholder path)
sampleFolder = "/PATH/TO/"

# Load data

In [None]:
### Load dataset1, dataset2 and dataset3 as in script 4b_TotalVI.ipynb (same loading steps); they must exist before the merge below

### Merge

In [None]:
### Concatenate datasets - via union
# dataset1, dataset2 and dataset3 must already be loaded (see the "Load data" cell).
# cell_measurement_intersection is set to False for "protein_expression", i.e. the
# protein columns are not restricted to those shared by every dataset
# (presumably the union is kept -- confirm against the scvi version in use).
all_dataset = GeneExpressionDataset()
all_dataset.populate_from_datasets([dataset1, dataset2, dataset3], 
                                    cell_measurement_intersection={"protein_expression": False})

In [None]:
### Check out merged dataset
all_dataset

In [None]:
### Save the merged dataset before taking HVG
import pickle

# A context manager guarantees the file is flushed and closed, also when pickling fails.
# NOTE(review): the "Load full dataset again" cell reads
# /PATH/TO/results_TotalVI/all_dataset.pkl -- make sure both point to the same file.
with open("/PATH/TO/all_dataset.pkl", "wb") as handle:
    pickle.dump(all_dataset, handle)

In [None]:
### Get HVGenes
# Restrict all_dataset (in place) to 4000 genes chosen with the "seurat_v2" mode and
# batch_correction enabled; the full-gene version was pickled in the previous cell.
all_dataset.subsample_genes(4000, batch_correction = True, mode = "seurat_v2")
# Display the dataset summary after subsetting
all_dataset

In [None]:
all_dataset.protein_expression.shape

In [None]:
all_dataset.gene_names[0:19]

In [None]:
all_dataset.protein_names[0:19]

In [None]:
all_dataset.batch_indices

In [None]:
## Batch layout after the merge:
##   batch index 0 -> Sample1 (non-CiteSeq)
##   batch index 1 -> Sample2 (CiteSeq)
##   batch index 2 -> Sample3 (CiteSeq)

### Non-CiteSeq cells only carry a dummy antibody count table (made in the R script),
### so overwrite their protein counts with zeros.
rna_only_rows = all_dataset.batch_indices.ravel() == 0
protein_counts = all_dataset.protein_expression
protein_counts[rna_only_rows] = np.zeros_like(protein_counts[rna_only_rows])

In [None]:
# One mask per batch over the protein columns; it is passed to TOTALVI as
# protein_batch_mask further down.
batch_mask = all_dataset.get_batch_mask_cell_measurement("protein_expression")
# Inspect the mask of batch 0 (the sample whose protein counts were zeroed above)
batch_mask[0]

In [None]:
all_dataset.protein_names[np.logical_not(batch_mask[0])]

In [None]:
print(batch_mask[1].shape)
print(batch_mask[2].shape)

In [None]:
all_dataset.n_labels

### Save merged dataset

In [None]:
### Save the HVG-subset dataset
import pickle

# A context manager guarantees the file is flushed and closed, also when pickling fails.
# NOTE(review): this path is relative while the reload cell reads the absolute
# /PATH/TO/HVG_dataset.pkl -- make sure both resolve to the same file.
with open("PATH/TO/HVG_dataset.pkl", "wb") as handle:
    pickle.dump(all_dataset, handle)

### Reload merged dataset

In [None]:
### Load the HVG-subset dataset again
import pickle

# Only unpickle files produced by this notebook: pickle.load executes arbitrary code.
with open("/PATH/TO/HVG_dataset.pkl", "rb") as handle:
    all_dataset = pickle.load(handle)

In [None]:
all_dataset

In [None]:
### Load the full (pre-HVG) dataset again
import pickle

# Only unpickle files produced by this notebook: pickle.load executes arbitrary code.
with open("/PATH/TO/results_TotalVI/all_dataset.pkl", "rb") as handle:
    all_dataset_full = pickle.load(handle)

In [None]:
all_dataset_full

# Train model

In [None]:
### Initialize trainer
use_cuda = True
lr = 0.001
n_epochs = 400

# Early stopping is disabled for this run (early_stopping_kwargs=None below).
# Settings kept for reference in case it is re-enabled:
# early_stopping_kwargs = {
#     "early_stopping_metric": "elbo",
#     "save_best_state_metric": "elbo",
#     "patience": 45,
#     "threshold": 0,
#     "reduce_lr_on_plateau": True,
#     "lr_patience": 30,
#     "lr_factor": 0.6,
#     "posterior_class": TotalPosterior,
# }

# Recompute the per-batch protein mask from the dataset currently in memory, so this
# cell also works when all_dataset was restored from the pickle in a fresh kernel
# (batch_mask is otherwise only defined in a cell that runs before the save).
batch_mask = all_dataset.get_batch_mask_cell_measurement("protein_expression")

### Build the model: 20 latent dimensions, batch passed to the encoder
totalvae = TOTALVI(
    all_dataset.nb_genes,
    len(all_dataset.protein_names),
    n_batch=all_dataset.n_batches,
    n_latent=20,
    encoder_batch=True,
    protein_batch_mask=batch_mask,
)

### Prepare trainer (90/10 train/test split, metrics recorded every epoch)
trainer = TotalTrainer(
    totalvae,
    all_dataset,
    train_size=0.90,
    test_size=0.10,
    use_cuda=use_cuda,
    frequency=1,
    batch_size=256,
    early_stopping_kwargs=None,
    # The adversarial loss is only switched on when there is more than one batch
    use_adversarial_loss=bool(all_dataset.n_batches > 1),
)

In [None]:
### Train the model and print wall-clock start/end times
clock_format = "%H:%M:%S"

started_at = datetime.now().strftime(clock_format)
print(f"Start = {started_at}")

trainer.train(lr=lr, n_epochs=n_epochs)

finished_at = datetime.now().strftime(clock_format)
print(f"End = {finished_at}")

In [None]:
### Plot the negative ELBO recorded per epoch for the train and test sets
for history_key, curve_label in (("elbo_train_set", "train"), ("elbo_test_set", "test")):
    plt.plot(trainer.history[history_key], label=curve_label)
plt.title("Negative ELBO over training epochs")
# Fixed y-range chosen for this run; adjust if the curves fall outside 1000-1500
plt.ylim(1000, 1500)
plt.legend()

# Get full posterior

In [None]:
### Build a posterior over every cell of the dataset (not only the train/test split)
every_cell = np.arange(len(all_dataset))
full_posterior = trainer.create_posterior(
    totalvae,
    all_dataset,
    indices=every_cell,
    type_class=TotalPosterior,
)
# Use a smaller batch size for inference
full_posterior = full_posterior.update({"batch_size": 32})

### Extract the latent space in the dataset's cell order
latent, batch_indices, label, library_gene = full_posterior.sequential().get_latent()
# Flatten the batch indices to a 1-D vector
batch_indices = batch_indices.ravel()

In [None]:
latent.shape

# Save full posterior

In [None]:
# os is not imported in the notebook's import cell, so import it here
import os

# Show the working directory and the configured sample folder before writing files
print(os.getcwd())
print(sampleFolder)

In [None]:
### Save the full posterior
import pickle

# A context manager guarantees the file is flushed and closed, also when pickling fails.
with open("/PATH/TO/fullPosterior.pkl", "wb") as handle:
    pickle.dump(full_posterior, handle)

# Reload full posterior

In [None]:
### Load the full posterior again
import pickle

# Only unpickle files produced by this notebook: pickle.load executes arbitrary code.
with open("/PATH/TO/fullPosterior.pkl", "rb") as handle:
    full_posterior = pickle.load(handle)

In [None]:
# Re-extract the latent space from the reloaded posterior, in dataset cell order
latent, batch_indices, label, library_gene = full_posterior.sequential().get_latent()
# Flatten to 1-D, matching what the cell that first extracts the latent space does
batch_indices = batch_indices.ravel()

# Clustering 

In [None]:
### Create adata object
# X holds the expression matrix of the (HVG-subset) dataset, genes are named after
# the dataset's gene_names, and the totalVI latent space is stored as an embedding.
post_adata = sc.AnnData(X=all_dataset.X)
post_adata.var.index = all_dataset.gene_names
# latent must come from a sequential posterior so its rows line up with all_dataset.X
post_adata.obsm["X_totalVI"] = latent

In [None]:
post_adata.obsm["X_totalVI"].shape

In [None]:
### table() of batch indices
print(np.array(np.unique(all_dataset.batch_indices, return_counts=True)).T)

In [None]:
### Run umap
# Build the neighbour graph on the totalVI latent space (30 neighbours, correlation metric)
# min_dist is the minimum spacing UMAP allows between embedded points: lower values
# give tighter, more separated clumps; higher values spread the points more evenly.
sc.pp.neighbors(post_adata, use_rep="X_totalVI", n_neighbors=30, metric="correlation")
sc.tl.umap(post_adata, min_dist=0.3)

In [None]:
### Run clustering
# Louvain clustering on the neighbour graph computed above; labels go to obs["louvain"]
sc.tl.louvain(post_adata, key_added="louvain", resolution=1.5)

In [None]:
### Prepare for umap: map each cell's batch index to its sample name
d_names = ["Sample1","Sample2","Sample3"]
# Flatten batch_indices first (the notebook ravels it wherever it is compared) so each
# b is a scalar; int() on a 1-element array is deprecated in recent NumPy versions.
post_adata.obs["sample"] = [d_names[int(b)] for b in all_dataset.batch_indices.ravel()]

In [None]:
# Default every cell to "RnaSeq"; the CiteSeq samples are relabelled in the following cell
post_adata.obs["type"] = "RnaSeq"

In [None]:
# Cells that belong to one of the CiteSeq samples get the type 'Citeseq'
from_citeseq_sample = post_adata.obs['sample'].isin(["Sample2","Sample3"])
post_adata.obs.loc[from_citeseq_sample, 'type']='Citeseq'

In [None]:
post_adata.obs

In [None]:
print(np.array(np.unique(post_adata.obs['type'], return_counts=True)).T)

In [None]:
# Random ordering of all cells, used to index post_adata in the per-sample and per-type
# UMAP plots (presumably so no group is systematically drawn on top -- confirm)
inds = np.random.permutation(np.arange(all_dataset.X.shape[0]))

In [None]:
print(all_dataset.X.shape)
print(inds.shape)

In [None]:
### Create umap
# UMAP coloured by louvain cluster, cluster labels drawn on the embedding;
# return_fig=True keeps the figure object so it can be saved later.
figUmap = sc.pl.umap(
    post_adata, 
    color="louvain",
    ncols=1,
    alpha=0.9,
    legend_loc="on data",
#     legend_loc="right margin",  # alternative legend placement
    return_fig=True
)

In [None]:
### Create umap split on sample
# Cells are passed in the shuffled order given by inds; the figure is kept for saving.
figUmapSplitSample = sc.pl.umap(
    post_adata[inds], 
    color=["sample"],
    ncols=1,
    alpha=0.9,
    return_fig=True
)

In [None]:
### Create umap split on type
# Cells are passed in the shuffled order given by inds; the figure is kept for saving.
figUmapSplitType = sc.pl.umap(
    post_adata[inds], 
    color=["type"],
    ncols=1,
    alpha=0.9,
    return_fig=True
)

In [None]:
### Save plots
# Write the three UMAP figures as 200-dpi PNGs with tight bounding boxes.
# NOTE(review): the first path has no results_TotalVI/ component while the other two
# do, and all three are relative -- confirm the intended output folder.
figUmap.savefig("PATH/TO/umap.png", dpi=200, bbox_inches='tight')
figUmapSplitSample.savefig("PATH/TO/results_TotalVI/umapSplitSample.png", dpi=200, bbox_inches='tight')
figUmapSplitType.savefig("PATH/TO/results_TotalVI/umapSplitType.png", dpi=200, bbox_inches='tight')

# Get denoised data

In [None]:
# Number of Monte Carlo samples to average over
n_samples = 10
# Protein names as a plain list, so .index() lookups can be used further down
parsed_protein_names=all_dataset.protein_names.tolist()

# Probability of background for each (cell, protein)
py_mixing = full_posterior.sequential().get_sample_mixing(n_samples=n_samples, give_mean=True)
# Foreground probability is the complement of the background probability
protein_foreground_prob = pd.DataFrame(
    data=(1 - py_mixing), columns=parsed_protein_names
)

In [None]:
## Denoised genes and proteins for all cells, without a transform_batch argument
## (per the scvi docs each cell's own observed batch is then used -- confirm)
denoised_genes_reallyAll, denoised_proteins_reallyAll = full_posterior.sequential().get_normalized_denoised_expression(
    n_samples=n_samples, give_mean=True
)
print(denoised_genes_reallyAll.shape)
print(denoised_proteins_reallyAll.shape)

In [None]:
## Denoised values with transform_batch set to the two CiteSeq batches (1 and 2).
## NOTE(review): per the scvi docs a list means "average over these batches" for every
## cell, not "restrict to cells of these batches" -- the printed shapes show whether
## all cells are still present.
denoised_genes_all, denoised_proteins_all = full_posterior.sequential().get_normalized_denoised_expression(
    n_samples=n_samples, give_mean=True, transform_batch=[1,2] ## sample batch index
)
print(denoised_genes_all.shape)
print(denoised_proteins_all.shape)

In [None]:
## Denoised values with transform_batch set to batch 2 (Sample3) only, because Sample3
## has 20 more ABs than Sample2; those extra ABs are read from this result further down.
denoised_genes_Sample3, denoised_proteins_Sample3 = full_posterior.sequential().get_normalized_denoised_expression(
    n_samples=n_samples, give_mean=True, transform_batch=[2] ## sample batch index
)
print(denoised_genes_Sample3.shape)
print(denoised_proteins_Sample3.shape)

In [None]:
## Load txt files containing overlap and sample3 ABs created in the 4a_pre_TotalVI-ScVI.R script
# ndmin=1 keeps the result a 1-D array even when a file holds a single name;
# without it np.loadtxt returns a 0-d array and the len() calls below fail.
fileName="PATH/TO/overlapABs.txt"
overlapABs = np.loadtxt(fileName, dtype='str', ndmin=1)

fileName="PATH/TO/ABs_Sample3.txt"
ABs_Sample3 = np.loadtxt(fileName, dtype='str', ndmin=1)

In [None]:
# Column position of every overlapping antibody within parsed_protein_names
# (.index raises ValueError if a name is not present)
overlapABs_IDs = [parsed_protein_names.index(ab_name) for ab_name in overlapABs]

In [None]:
# Column position of every Sample3 antibody within parsed_protein_names
# (.index raises ValueError if a name is not present)
ABs_Sample3_IDs = [parsed_protein_names.index(ab_name) for ab_name in ABs_Sample3]

In [None]:
# Overlapping antibodies are taken from the run conditioned on batches 1 and 2
denoised_proteins_overlap=denoised_proteins_all[:,overlapABs_IDs]
# NOTE(review): this overwrites denoised_proteins_Sample3 with a subset of itself, so
# re-running this cell without first re-running the cell that computes it would
# index into the already-subset array.
denoised_proteins_Sample3=denoised_proteins_Sample3[:,ABs_Sample3_IDs]

In [None]:
print(denoised_proteins_overlap.shape)
print(denoised_proteins_Sample3.shape)

In [None]:
# Genes: use the run without transform_batch
denoised_genes=denoised_genes_reallyAll
# Proteins: overlap columns first, then the Sample3 columns, side by side (axis=1);
# the combined protein name list must follow this same order.
denoised_proteins=np.concatenate((denoised_proteins_overlap,
                                  denoised_proteins_Sample3), axis=1)

In [None]:
print(denoised_genes.shape)
print(denoised_proteins.shape)

In [None]:
len(parsed_protein_names)

In [None]:
# Names for the columns of denoised_proteins, in the same order as the concatenation:
# overlapping ABs first, then the Sample3 ABs. (The original referenced an undefined
# name `Sample3`; the array loaded from ABs_Sample3.txt is `ABs_Sample3`.)
parsed_protein_names_New=np.concatenate((overlapABs, ABs_Sample3)).tolist()

In [None]:
len(parsed_protein_names_New)

In [None]:
### Get raw values
combined_protein = all_dataset.protein_expression

In [None]:
print(combined_protein.shape)
print(protein_foreground_prob.shape)
print(denoised_proteins.shape)

In [None]:
### Add per-protein foreground probability and observed (raw) counts to post_adata (via obs)
for i, p in enumerate(parsed_protein_names):
    # "<protein>_fore_prob": foreground probability computed above
    post_adata.obs["{}_fore_prob".format(p)] = protein_foreground_prob[p].values
    # "<protein>_observed": column i of the dataset's protein_expression matrix
    post_adata.obs["{}_observed".format(p)] = combined_protein[:, i]

In [None]:
### Add normalised protein values to post_adata (via obs)
# One obs column per protein, named by the bare protein name; column i of
# denoised_proteins must correspond to entry i of parsed_protein_names_New.
for i, p in enumerate(parsed_protein_names_New):
    post_adata.obs["{}".format(p)] = denoised_proteins[:, i]

In [None]:
### Add normalised gene values to post_adata (via layer)
# denoised_genes must have the same (cells x genes) shape as post_adata.X
post_adata.layers["norm_genes"] = denoised_genes

In [None]:
post_adata.var.index

##  Save post_adata

In [None]:
# Cast the annotation columns to plain str / int dtypes before saving
# (the louvain cluster labels become integers)
post_adata.obs["sample"] = post_adata.obs["sample"].astype("str")
post_adata.obs["louvain"] = post_adata.obs["louvain"].astype("int")
post_adata.obs["type"] = post_adata.obs["type"].astype("str")

In [None]:
### Save post_adata
import pickle

# A context manager guarantees the file is flushed and closed, also when pickling fails.
with open("PATH/TO/post_adata.pkl", "wb") as handle:
    pickle.dump(post_adata, handle)

##  Reload post_adata

In [None]:
### The next steps are the same as in script 4b_TotalVI.ipynb