In [1]:
import matplotlib
%matplotlib inline

In [2]:
import matplotlib
%matplotlib inline
# Font type 42 makes matplotlib embed TrueType fonts in PDF/PS output (see
# matplotlib's rcParams documentation), which keeps text editable in the
# saved figures.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
import matplotlib.pyplot as plt

import seaborn as sns

import os
# NOTE(review): hard-coded, machine-specific working directory. The relative
# paths used later in this notebook ('cite', '../cite/...', 'figures/...')
# are resolved against it.
os.chdir("/home/ec2-user/scVI/")
os.getcwd()

from umap import UMAP

# NOTE(review): this flag is not referenced by any of the visible cells;
# confirm whether it is still needed.
use_cuda = True
import torch

# import data loading functions

In [3]:
from scvi.harmonization.utils_chenling import get_matrix_from_dir
from scvi.dataset.pbmc import PbmcDataset
from scvi.harmonization.utils_chenling import assign_label
import numpy as np
from scvi.dataset.dataset import GeneExpressionDataset
from copy import deepcopy


# import scVI models

In [4]:
from scvi.inference import UnsupervisedTrainer, SemiSupervisedTrainer
from scvi.models.scanvi import SCANVI
from scvi.models.vae import VAE

In [None]:
# --- Dataset 1: PBMC data, restricted to the cells of batch 0. ---
dataset1 = PbmcDataset(filter_out_de_genes=False)
dataset1.update_cells(dataset1.batch_indices.ravel() == 0)
# Called with the full gene count; the captured log below reports the
# resulting gene down-sampling.
dataset1.subsample_genes(dataset1.nb_genes)

# --- Dataset 2: count matrix read from the 'cite' directory. ---
count, geneid, cellid = get_matrix_from_dir('cite')
# Transpose and convert to CSR before handing the matrix to assign_label.
count = count.T.tocsr()
seurat = np.genfromtxt('../cite/cite.seurat.labels', dtype='str', delimiter=',')
# Keep only the part of each cell id that precedes the first '-'.
cellid = np.asarray([x.split('-')[0] for x in cellid])
# NOTE(review): presumably maps the label ids in the Seurat file onto indices
# of `cell_type` (the first two ids collapse onto index 0) -- confirm against
# assign_label.
labels_map = [0, 0, 1, 2, 3, 4, 5, 6]
# Column 4 of the Seurat table, header row skipped. Not passed to
# assign_label (which receives the whole table); reassigned in a later cell.
labels = seurat[1:, 4]
cell_type = ['CD4 T cells', 'NK cells', 'CD14+ Monocytes', 'B cells','CD8 T cells', 'FCGR3A+ Monocytes', 'Other']
dataset2 = assign_label(cellid, geneid, labels_map, count, cell_type, seurat)

# Cell types present in BOTH datasets. The previous version intersected
# dataset2 with itself, which always returns dataset2's own cell types.
set(dataset1.cell_types).intersection(set(dataset2.cell_types))


File data/gene_info_pbmc.csv already downloaded
File data/pbmc_metadata.pickle already downloaded
File data/10X/pbmc8k/filtered_gene_bc_matrices.tar.gz already downloaded
Preprocessing dataset
Finished preprocessing dataset
Downsampling from 33694 to 21425 genes
Downsampling from 8381 to 8381 cells


In [None]:
# Cell type that is removed from the first dataset and is the only type kept
# from the second one when the combined dataset is built in the next cell.
rmCellTypes = 'B cells'

In [None]:
# First dataset: keep every cell type except the held-out one.
newCellType = [name for name in dataset1.cell_types if name != rmCellTypes]
pbmc = deepcopy(dataset1)
pbmc.filter_cell_types(newCellType)

# Second dataset: keep only the held-out cell type.
pbmc2 = deepcopy(dataset2)
pbmc2.filter_cell_types([rmCellTypes])

# Combined dataset built from the two filtered copies.
gene_dataset = GeneExpressionDataset.concat_datasets(pbmc, pbmc2)

In [None]:
# Position of the held-out cell type within the combined dataset's cell types
# (first match).
rm_idx = np.flatnonzero(gene_dataset.cell_types == rmCellTypes)[0]

In [None]:
# Sanity check (displayed as the cell output): rm_idx really points at the
# held-out cell type.
gene_dataset.cell_types[rm_idx] == rmCellTypes

In [None]:
# Build a deliberately mislabelled variant of the second dataset: pbmc2 was
# filtered down to the held-out cell type only, and here that type is renamed
# to 'CD4 T cells'. The resulting dataset feeds the model that the plots
# below label as "mis-specified".
pbmc3 = deepcopy(pbmc2)
# NOTE(review): assigns a plain list; confirm that concat_datasets accepts
# this for cell_types.
pbmc3.cell_types = ['CD4 T cells']
false_gene_dataset = GeneExpressionDataset.concat_datasets(pbmc, pbmc3)

In [None]:
# Display the cell types kept in the first (filtered) dataset.
pbmc.cell_types

In [None]:
# Display the (renamed) cell type of the mislabelled copy.
pbmc3.cell_types

In [None]:
# Display the cell type kept in the second (filtered) dataset.
pbmc2.cell_types

In [None]:
# Reduce both combined datasets to 1000 genes, correctly labelled one first.
for _combined in (gene_dataset, false_gene_dataset):
    _combined.subsample_genes(1000)

In [None]:
# Unsupervised VAE trained on the whole combined dataset (train_size=1.0).
vae = VAE(
    gene_dataset.nb_genes,
    n_batch=gene_dataset.n_batches,
    n_hidden=128,
    n_latent=10,
    n_layers=2,
    dispersion='gene',
)
trainer = UnsupervisedTrainer(vae, gene_dataset, train_size=1.0)
trainer.train(n_epochs=250)

# Posterior covering every cell of the dataset, used later for the latent space.
full = trainer.create_posterior(
    trainer.model,
    gene_dataset,
    indices=np.arange(len(gene_dataset)),
)


In [None]:
# SCANVI model on the correctly labelled dataset, started from the trained
# VAE's weights (strict=False: state-dict keys need not match exactly).
scanvi = SCANVI(
    gene_dataset.nb_genes,
    gene_dataset.n_batches,
    gene_dataset.n_labels,
    n_hidden=128,
    n_latent=10,
    n_layers=2,
    dispersion='gene',
)
scanvi.load_state_dict(trainer.model.state_dict(), strict=False)

trainer_scanvi = SemiSupervisedTrainer(
    scanvi,
    gene_dataset,
    classification_ratio=50,
    n_epochs_classifier=1,
    lr_classification=5 * 1e-3,
)

# Batch 0 supplies the labelled cells, batch 1 the unlabelled ones.
labelled_mask = (gene_dataset.batch_indices == 0)
unlabelled_mask = (gene_dataset.batch_indices == 1)
trainer_scanvi.labelled_set = trainer_scanvi.create_posterior(indices=labelled_mask)
trainer_scanvi.unlabelled_set = trainer_scanvi.create_posterior(indices=unlabelled_mask)

trainer_scanvi.train(n_epochs=5)


In [None]:
# Indices of the cells in the unlabelled posterior (built from batch 1 above).
# NOTE(review): not used by any later visible cell.
unlabelled_idx = trainer_scanvi.unlabelled_set.indices

In [None]:
# Posterior over every cell, iterated in sequential order so that the
# prediction rows line up with the dataset's cell order.
full_scanvi = trainer_scanvi.create_posterior(trainer_scanvi.model, gene_dataset, indices=np.arange(len(gene_dataset)))

all_y_pred = []
# Inference only: no_grad avoids building an autograd graph per mini-batch.
with torch.no_grad():
    for tensors in full_scanvi.sequential():
        # The first of the five tensors in each mini-batch is the classifier input.
        sample_batch, _, _, _, _ = tensors
        all_y_pred.append(scanvi.classify(sample_batch))

# detach() + cpu() make the conversion work for tensors that require grad or
# live on a GPU; np.array() on such tensors raises.
all_y_pred = torch.cat(all_y_pred).detach().cpu().numpy()
# Highest class score per cell.
max_prob = np.max(all_y_pred, axis=1)

In [None]:
# SCANVI model on the deliberately mislabelled dataset, also started from the
# trained VAE's weights (strict=False: state-dict keys need not match exactly).
scanvi_false = SCANVI(
    false_gene_dataset.nb_genes,
    false_gene_dataset.n_batches,
    false_gene_dataset.n_labels,
    n_hidden=128,
    n_latent=10,
    n_layers=2,
    dispersion='gene',
)
scanvi_false.load_state_dict(trainer.model.state_dict(), strict=False)

trainer_scanvi_false = SemiSupervisedTrainer(
    scanvi_false,
    false_gene_dataset,
    classification_ratio=50,
    n_epochs_classifier=1,
    lr_classification=5 * 1e-3,
)

# Batch 0 supplies the labelled cells, batch 1 the unlabelled ones.
labelled_mask_false = (false_gene_dataset.batch_indices == 0)
unlabelled_mask_false = (false_gene_dataset.batch_indices == 1)
trainer_scanvi_false.labelled_set = trainer_scanvi_false.create_posterior(indices=labelled_mask_false)
trainer_scanvi_false.unlabelled_set = trainer_scanvi_false.create_posterior(indices=unlabelled_mask_false)

trainer_scanvi_false.train(n_epochs=10)


In [None]:
# Posterior over every cell of the mislabelled dataset, iterated in
# sequential order so that prediction rows line up with the cell order.
full_scanvi_false = trainer_scanvi_false.create_posterior(trainer_scanvi_false.model, false_gene_dataset, indices=np.arange(len(false_gene_dataset)))

all_y_pred_false = []
# Inference only: no_grad avoids building an autograd graph per mini-batch.
with torch.no_grad():
    for tensors in full_scanvi_false.sequential():
        # The first of the five tensors in each mini-batch is the classifier input.
        sample_batch, _, _, _, _ = tensors
        all_y_pred_false.append(scanvi_false.classify(sample_batch))

# detach() + cpu() make the conversion work for tensors that require grad or
# live on a GPU; np.array() on such tensors raises.
all_y_pred_false = torch.cat(all_y_pred_false).detach().cpu().numpy()
# Highest class score per cell.
max_prob_false = np.max(all_y_pred_false, axis=1)
# Index of the highest-scoring class per cell (despite the "prob" in the name).
pred_prob_false = np.argmax(all_y_pred_false, axis=1)

In [None]:
# Maximum class score of the batch-1 cells under both models
# (correctly labelled vs. mis-specified).
fig, ax = plt.subplots(figsize=(8, 8))
ax.hist(max_prob[gene_dataset.batch_indices.ravel() == 1], 30, alpha=0.5, label='maximum prob')
ax.hist(max_prob_false[false_gene_dataset.batch_indices.ravel() == 1], 30, alpha=0.5, label='maximum prob (mis-specified model)')
ax.legend(loc='upper left')
# savefig's keyword is `transparent`; `transparency` is not a savefig option.
fig.savefig("figures/SCANVI_stress_probfalse_Bcells.pdf", transparent=True)


In [None]:
from scipy.special import logit

In [None]:
# Logit of the maximum class score under the correctly labelled model,
# comparing batch-1 cells ('B cells') with an equally sized leading slice of
# the remaining cells ('Not B cells').
# NOTE(review): logit(1.0) is +inf; if any score saturates to exactly 1 the
# histogram range becomes non-finite.
in_batch1 = gene_dataset.batch_indices.ravel() == 1
n_batch1 = np.sum(in_batch1)

fig, ax = plt.subplots(figsize=(8, 8))
ax.hist(logit(max_prob[~in_batch1][:n_batch1]), 30, alpha=0.5, label='Not B cells')
# max_prob was computed on gene_dataset, so it is masked with gene_dataset's
# batch indices (previously false_gene_dataset's were used here).
ax.hist(logit(max_prob[in_batch1]), 30, alpha=0.5, label='B cells')
ax.legend(loc='upper left')
plt.show()
# fig.savefig("figures/SCANVI_stress_probfalse_Bcells.pdf", transparent=True)


In [None]:
# Logit of the maximum class score under the mis-specified model, comparing
# batch-1 cells ('B cells') with an equally sized leading slice of the
# remaining cells ('Not B cells').
# NOTE(review): logit(1.0) is +inf; if any score saturates to exactly 1 the
# histogram range becomes non-finite.
in_batch1_false = false_gene_dataset.batch_indices.ravel() == 1
n_batch1_false = np.sum(in_batch1_false)

fig, ax = plt.subplots(figsize=(8, 8))
# max_prob_false was computed on false_gene_dataset, so it is masked with that
# dataset's batch indices (previously gene_dataset's were used here).
ax.hist(logit(max_prob_false[~in_batch1_false][:n_batch1_false]), 30, alpha=0.5, label='Not B cells')
ax.hist(logit(max_prob_false[in_batch1_false]), 30, alpha=0.5, label='B cells')
ax.legend(loc='upper left')
plt.show()
# fig.savefig("figures/SCANVI_stress_probfalse_Bcells.pdf", transparent=True)


In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
plt.hist(max_prob[gene_dataset.batch_indices.ravel() ==0], 30, alpha=0.5, label='maximum prob')
plt.hist((max_prob_false[false_gene_dataset.batch_indices.ravel() ==0]), 30, alpha=0.5, label='maximum prob (mis-specified model)')
plt.legend(loc='upper left')
fig.savefig("figures/SCANVI_stress_probfalse_all.pdf", transparency=True)


In [None]:
from scvi.metrics.clustering import select_indices_evenly

In [None]:
# Latent representation and batch indices for every cell, taken from the
# VAE posterior in sequential order. The third return value is discarded
# (presumably the labels -- confirm against get_latent).
latent, batch_indices, _ = full.sequential().get_latent()

In [None]:
# UMAP of the VAE latent space, coloured by cell type.
keys = gene_dataset.cell_types
# Plot the cell types in alphabetical order of their names.
key_order = np.argsort(keys)
labels = gene_dataset.labels.ravel()

# Sub-sample cells via select_indices_evenly; `sample`, `label_s` and
# `batch_s` are reused by the later plotting cells.
sample = select_indices_evenly(2000, labels)
colors = sns.color_palette('tab20')
latent_s = latent[sample, :]
label_s = labels[sample]
batch_s = batch_indices.ravel()[sample]

from umap import UMAP
# NOTE(review): no random_state is set, so the embedding differs between runs.
latent_u = UMAP(spread=2).fit_transform(latent_s)

fig, ax = plt.subplots(figsize=(8, 8))
for i, k in enumerate(key_order):
    idx = label_s == k
    # `color=` rather than `c=`: a bare RGB tuple passed as `c` is ambiguous
    # (it can be read as per-point values when the group has 3 points).
    ax.scatter(latent_u[idx, 0], latent_u[idx, 1], color=colors[i % 20], label=keys[k],
               edgecolors='none')
ax.legend(bbox_to_anchor=(1.1, 0.5), borderaxespad=0, fontsize='x-large')
ax.axis("off")
fig.tight_layout()
# savefig's keyword is `transparent`; `transparency` is not a savefig option.
fig.savefig("figures/SCANVI_stress_vaeUMAP_labels.pdf", transparent=True)


In [None]:
# UMAP of the VAE latent space (embedding from the previous cell), coloured
# by batch: index 0 and index 1 get the two legend names below.
fig, ax = plt.subplots(figsize=(8, 8))
for i, k in enumerate(['all others (PBMC8K)','B only (Cite-Seq) ']):
    idx = batch_s == i
    # `color=` rather than `c=`: a bare RGB tuple passed as `c` is ambiguous
    # (it can be read as per-point values when the group has 3 points).
    ax.scatter(latent_u[idx, 0], latent_u[idx, 1], color=colors[i % 20], label=k,
               edgecolors='none')
ax.legend(bbox_to_anchor=(1.1, 0.5), borderaxespad=0, fontsize='x-large')
ax.axis("off")
fig.tight_layout()
# savefig's keyword is `transparent`; `transparency` is not a savefig option.
fig.savefig("figures/SCANVI_stress_vaeUMAP_batch.pdf", transparent=True)


### SCANVI latent space

In [None]:
# UMAP of the SCANVI latent space, using the same cell sub-sample (`sample`,
# `label_s`, `batch_s`) as the VAE plots.
latent_scanvi, _, _ = full_scanvi.sequential().get_latent()
latent_s = latent_scanvi[sample, :]

from umap import UMAP
# NOTE(review): no random_state is set, so the embedding differs between runs.
latent_scanvi_u = UMAP(spread=2).fit_transform(latent_s)

# Panel 1: coloured by cell type (legend disabled).
fig, ax = plt.subplots(figsize=(8, 8))
for i, k in enumerate(key_order):
    idx = label_s == k
    # `color=` rather than `c=`: a bare RGB tuple passed as `c` is ambiguous
    # (it can be read as per-point values when the group has 3 points).
    ax.scatter(latent_scanvi_u[idx, 0], latent_scanvi_u[idx, 1], color=colors[i % 20], label=keys[k],
               edgecolors='none')
# ax.legend(bbox_to_anchor=(1.1, 0.5), borderaxespad=0, fontsize='x-large')
ax.axis("off")
fig.tight_layout()
# savefig's keyword is `transparent`; `transparency` is not a savefig option.
fig.savefig("figures/SCANVI_stress_scanviUMAP_labels.pdf", transparent=True)

# Panel 2: coloured by batch.
fig, ax = plt.subplots(figsize=(8, 8))
for i, k in enumerate(['all others (PBMC8K)','B only (Cite-Seq) ']):
    idx = batch_s == i
    ax.scatter(latent_scanvi_u[idx, 0], latent_scanvi_u[idx, 1], color=colors[i % 20], label=k,
               edgecolors='none')
ax.legend(bbox_to_anchor=(1.1, 0.5), borderaxespad=0, fontsize='x-large')
ax.axis("off")
fig.tight_layout()
fig.savefig("figures/SCANVI_stress_scanviUMAP_batch.pdf", transparent=True)


### SCANVI with an incorrect number of classes

In [None]:
# UMAP of the latent space of the mis-specified SCANVI model, using the same
# cell sub-sample (`sample`, `label_s`, `batch_s`) as the earlier plots.
latent_scanvi_false, _, _ = full_scanvi_false.sequential().get_latent()

latent_s = latent_scanvi_false[sample, :]

from umap import UMAP
# NOTE(review): no random_state is set, so the embedding differs between runs.
latent_scanvi_false_u = UMAP(spread=2).fit_transform(latent_s)

# Panel 1: coloured by cell type (legend disabled).
fig, ax = plt.subplots(figsize=(8, 8))
for i, k in enumerate(key_order):
    idx = label_s == k
    # `color=` rather than `c=`: a bare RGB tuple passed as `c` is ambiguous
    # (it can be read as per-point values when the group has 3 points).
    ax.scatter(latent_scanvi_false_u[idx, 0], latent_scanvi_false_u[idx, 1], color=colors[i % 20], label=keys[k],
               edgecolors='none')
# ax.legend(bbox_to_anchor=(1.1, 0.5), borderaxespad=0, fontsize='x-large')
ax.axis("off")
fig.tight_layout()
# savefig's keyword is `transparent`; `transparency` is not a savefig option.
fig.savefig("figures/SCANVI_stress_falsescanviUMAP_labels.pdf", transparent=True)

# Panel 2: coloured by batch (legend disabled).
fig, ax = plt.subplots(figsize=(8, 8))
for i, k in enumerate(['all others (PBMC8K)','B only (Cite-Seq) ']):
    idx = batch_s == i
    ax.scatter(latent_scanvi_false_u[idx, 0], latent_scanvi_false_u[idx, 1], color=colors[i % 20], label=k,
               edgecolors='none')
# ax.legend(bbox_to_anchor=(1.1, 0.5), borderaxespad=0, fontsize='x-large')
ax.axis("off")
fig.tight_layout()
fig.savefig("figures/SCANVI_stress_falsescanviUMAP_batch.pdf", transparent=True)
