In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import anndata
from sklearn.decomposition import TruncatedSVD
import matplotlib
import matplotlib.pyplot as plt

# Note: much of this notebook was copied from the kallisto tutorial available here: https://colab.research.google.com/github/pachterlab/kallistobustools/blob/master/notebooks/kb_analysis_0_python.ipynb#scrollTo=pK3fnX8hCuT-

In [None]:
# Download the pbmc_1k_v3 FASTQ archive (a single .tar file) from the 10x website.
# NOTE(review): nothing here checks whether the archive is already present, so
# re-running this cell downloads it again.
!wget http://cf.10xgenomics.com/samples/cell-exp/3.0.0/pbmc_1k_v3/pbmc_1k_v3_fastqs.tar

# Unpack the downloaded archive. The `kb count` cell further down reads its
# FASTQ files from the pbmc_1k_v3_fastqs/ directory.
!tar -xvf pbmc_1k_v3_fastqs.tar

In [None]:
# Download a transcriptome index with kb-python (`-d human`).
# The file names given to -i and -g (index.idx, t2g.txt) are the ones the
# `kb count` cell passes back in, and t2g.txt is also read with pandas later
# to map gene ids to gene names.
!kb ref -d human -i index.idx -g t2g.txt -f1 transcriptome.fasta

In [None]:
# Run kallisto (through kb-python) on the R1/R2 FASTQ pairs of lanes L001 and L002.
#   -i / -g     : index and transcript-to-gene files produced by the `kb ref` cell
#   -x 10xv3    : technology string for this pbmc_1k_v3 sample
#   -o output   : output directory; the loading cell reads
#                 output/counts_unfiltered/adata.h5ad from it
#   --h5ad      : needed so that the adata.h5ad file above is written
#   -t 2        : presumably the number of threads -- confirm against `kb count --help`
# No comments can be placed between the continuation lines below.
!kb count --h5ad -i index.idx -g t2g.txt -x 10xv3 -o output --filter bustools -t 2 \
pbmc_1k_v3_fastqs/pbmc_1k_v3_S1_L001_R1_001.fastq.gz \
pbmc_1k_v3_fastqs/pbmc_1k_v3_S1_L001_R2_001.fastq.gz \
pbmc_1k_v3_fastqs/pbmc_1k_v3_S1_L002_R1_001.fastq.gz \
pbmc_1k_v3_fastqs/pbmc_1k_v3_S1_L002_R2_001.fastq.gz

In [None]:
# Global scanpy settings for the rest of the notebook.
sc.settings.verbosity = 3             # verbosity: errors (0), warnings (1), info (2), hints (3)
# Print the versions of the packages in use, so the run can be reproduced.
sc.logging.print_versions()
# Figure defaults used by scanpy's plotting; only the resolution (dpi) is set here.
sc.settings.set_figure_params(dpi=80)

In [None]:
# Load the unfiltered count matrix written by `kb count --h5ad` and relabel the
# genes (columns) with gene names instead of gene ids.
results_file = 'pbmc1k.h5ad'  # the file that will store the analysis results
adata = anndata.read_h5ad("output/counts_unfiltered/adata.h5ad")
# Keep the original var index (the gene ids) as a column before the index is replaced.
adata.var["gene_id"] = adata.var.index.values

# Transcript-to-gene table: tab-separated, no header row; the first three
# columns are transcript id, gene id and gene name.
# `usecols` restricts parsing to those three columns. Without it, a t2g file
# that carries additional columns would make pandas turn the leading columns
# into the index and shift the names onto the wrong fields.
t2g = pd.read_csv("t2g.txt", header=None, names=["tid", "gene_id", "gene_name"],
                  sep="\t", usecols=[0, 1, 2])
t2g.index = t2g.gene_id
# Several transcripts share a gene id; keep one row per gene so the mapping is unique.
t2g = t2g.loc[~t2g.index.duplicated(keep='first')]

# Map gene ids to gene names. Ids that are missing from t2g (or have an empty
# name) would otherwise become NaN var names, so fall back to the gene id.
gene_names = adata.var.gene_id.map(t2g["gene_name"])
adata.var["gene_name"] = gene_names.fillna(adata.var.gene_id)
adata.var.index = adata.var["gene_name"]

# Different gene ids can share the same gene name; make the var names unique.
adata.var_names_make_unique()

In [None]:
# Keep only the cells whose total UMI count is above 1070
umis_per_cell = np.asarray(adata.X.sum(axis=1)).ravel()
adata = adata[umis_per_cell > 1070]

# Keep only the genes that have at least one UMI among the remaining cells
umis_per_gene = np.asarray(adata.X.sum(axis=0)).ravel()
adata = adata[:, umis_per_gene > 0]

In [None]:
# Basic scanpy filtering: drop cells with fewer than 200 detected genes and
# genes detected in fewer than 3 cells.
# NOTE(review): no cell in this notebook assigns adata.obs['n_genes'] explicitly,
# so the column the QC mask reads later is presumably added by filter_cells -- confirm.
sc.pp.filter_cells(adata, min_genes=200)
sc.pp.filter_genes(adata, min_cells=3)

In [None]:
# Boolean mask over genes: True for mitochondrial genes (names starting with 'MT-')
mito_genes = adata.var_names.str.startswith('MT-')

# Per-cell UMI totals. X is sparse, so `.sum(axis=1)` returns an (n_cells, 1)
# matrix and `.A1` flattens it into a 1-D array.
total_counts = adata.X.sum(axis=1).A1
mito_counts = adata[:, mito_genes].X.sum(axis=1).A1

# fraction of each cell's counts that comes from mitochondrial genes
adata.obs['percent_mito'] = mito_counts / total_counts
# total counts per cell, stored as an observation annotation
adata.obs['n_counts'] = total_counts

In [None]:
# Mask of cells to KEEP: a cell must pass all three QC criteria at once --
# fewer than 6500 detected genes, more than 200 detected genes, and a
# mitochondrial UMI fraction below 0.2.
# The criteria are combined with AND. np.logical_or(a, b, c) is not a three-way
# OR: its third positional argument is the `out` buffer, so the mito criterion
# was silently ignored, and "n_genes < 6500 OR n_genes > 200" is true for every
# cell, which meant nothing was filtered.
# Note: this changes which cells survive QC compared with earlier runs.
mask = (
    (adata.obs.n_genes < 6500).values
    & (adata.obs.n_genes > 200).values
    & (adata.obs.percent_mito < 0.2).values
)

In [None]:
# Apply the QC mask: keep only the cells (rows) for which `mask` is True, all genes.
adata = adata[mask, :]

In [None]:
# Normalize counts in each cell to be equal: every cell is rescaled so that its
# counts sum to target_sum (10**4 = 10,000).
sc.pp.normalize_total(adata, target_sum=10**4)

In [None]:
# Replace the normalized counts with their logarithm (log1p, i.e. log(1 + x)).
sc.pp.log1p(adata)

In [None]:
# Keep the current state (normalized and log-transformed, before scaling) in
# adata.raw so it remains available after the scaling step below.
adata.raw = adata

In [None]:
# Flag highly variable genes; the result is read later from adata.var.highly_variable.
# flavor="cell_ranger" is consistent with Seurat and flavor="seurat" is not consistent with Seurat
# NOTE(review): per the scanpy documentation, when n_top_genes is given the
# min_mean / max_mean / min_disp cutoffs are ignored -- confirm for the installed version.
sc.pp.highly_variable_genes(adata, min_mean=0.01, max_mean=8, min_disp=1, n_top_genes=2000, flavor="cell_ranger", n_bins=20)

In [None]:
# Scale every gene (per the scanpy docs: zero mean, unit variance) and clip
# the scaled values at max_value=10.
sc.pp.scale(adata, max_value=10)

In [None]:
# Independent copy of the per-gene annotation table. Its `highly_variable` and
# `gene_name` columns are used later to look up gene positions by name.
data_var = adata.var.copy()

In [None]:
# Expression matrix restricted to the highly variable genes (cells x genes).
# Selecting columns with a boolean mask already yields a fresh array, so the
# original adata.X is left untouched.
hvg_mask = adata.var.highly_variable.values
data = np.array(adata.X)[:, hvg_mask]

# Centre every gene (column) at zero, then take the full SVD for classical PCA:
# columns of `u` are per-cell scores, rows of `v` are per-gene loadings.
gene_means = data.mean(axis=0)
data = data - gene_means
u, s, v = np.linalg.svd(data, full_matrices=True)

In [None]:
# one pass over the variables doing CAVI on the new VI scheme
def update_step_sparse(X,
                       vi_mu_z,
                       vi_sigma_z,
                       vi_psi_0,
                       vi_mu_w,
                       vi_sigma_w,
                       sigma_sq_e,
                       sigma_sq_1,
                       p_0):
    """Run one coordinate-ascent sweep: update Z first, then W and psi.

    With N = X.shape[0] rows, P = vi_mu_w.shape[0] variables and
    K = vi_mu_w.shape[1] components, the code treats the loading of variable
    i on component k as zero with weight psi[i, k] and as Gaussian with mean
    mu_w[i, k] and variance sigma_w[i, k] otherwise, so that
    E[W] = (1 - psi) * mu_w.

    Parameters
    ----------
    X : ndarray, shape (N, P)
        Data matrix.
    vi_mu_z : ndarray, shape (N, K)
        Current variational means of Z. Only its shape and dtype are used;
        every row is overwritten in the returned copy.
    vi_sigma_z : ndarray
        Accepted for interface symmetry but not read; the covariance of Z is
        recomputed from the W parameters.
    vi_psi_0 : ndarray, shape (P, K)
        Current weights on the zero component of each loading.
    vi_mu_w, vi_sigma_w : ndarray, shape (P, K)
        Current means and variances of the non-zero component of each loading.
    sigma_sq_e : float
        Variance that divides the data terms in both updates (noise variance).
    sigma_sq_1 : float
        Variance whose inverse is added to the precision of each loading
        (variance of the non-zero prior component).
    p_0 : float
        Prior probability of a zero loading; must lie strictly between 0 and 1
        because its log-odds are taken.

    Returns
    -------
    tuple
        ``(new_mu_z, new_sigma_z, new_psi, new_mu_w, new_sigma_w)`` with shapes
        (N, K), (K, K), (P, K), (P, K) and (P, K). The inputs are not modified.
    """
    # ----- Update Z -----
    expected_w = (1 - vi_psi_0) * vi_mu_w
    # Var[w] = E[w^2] - E[w]^2 under the two-component (zero / Gaussian) posterior
    var_w = (1 - vi_psi_0) * (vi_mu_w ** 2 + vi_sigma_w) - expected_w ** 2
    # E[W^T W] = E[W]^T E[W] + diag(sum_i Var[w_ik])
    expected_w_gram = expected_w.T @ expected_w + np.diag(var_w.sum(axis=0))
    new_sigma_z = np.linalg.inv(expected_w_gram / sigma_sq_e
                                + np.eye(expected_w_gram.shape[0]))
    # Row n is new_sigma_z @ (E[W]^T x_n) / sigma_sq_e, computed for all rows at once.
    new_mu_z = np.copy(vi_mu_z)
    new_mu_z[:X.shape[0]] = (X @ expected_w) @ new_sigma_z.T / sigma_sq_e

    # ----- Update W -----
    new_mu_w = np.copy(vi_mu_w)
    new_sigma_w = np.copy(vi_sigma_w)
    new_psi = np.copy(vi_psi_0)
    expected_x_z_sum = X.T @ new_mu_z                      # (P, K): sum_n x_ni E[z_nk]
    expected_z_cov = new_mu_z.T @ new_mu_z                 # (K, K): sum_n E[z_n z_n^T]
    expected_z_cov += new_mu_z.shape[0] * new_sigma_z
    expected_z_sq_sum = np.diag(expected_z_cov)
    # Part of the log-odds that does not depend on (i, k).
    prior_log_odds = np.log(p_0 / (1 - p_0)) + 0.5 * np.log(sigma_sq_1)

    # Components are updated one after another because the update for k uses
    # the already-refreshed values of components < k. Within a component the
    # variables (rows) do not interact, so each step handles all rows at once.
    for k in range(vi_mu_w.shape[1]):
        new_sigma_w[:, k] = (expected_z_sq_sum[k] / sigma_sq_e
                             + 1. / sigma_sq_1) ** -1
        current_expected_w = (1 - new_psi) * new_mu_w
        # Contribution of all the other components l != k to the fit of row i.
        linked_ests = (current_expected_w @ expected_z_cov[k]
                       - current_expected_w[:, k] * expected_z_cov[k, k])
        new_mu_w[:, k] = (new_sigma_w[:, k]
                          * (expected_x_z_sum[:, k] - linked_ests)
                          / sigma_sq_e)
        log_odds = (prior_log_odds
                    - 0.5 * new_mu_w[:, k] ** 2 / new_sigma_w[:, k]
                    - 0.5 * np.log(new_sigma_w[:, k]))
        new_psi[:, k] = 1. / (1 + np.exp(-log_odds))
    return new_mu_z, new_sigma_z, new_psi, new_mu_w, new_sigma_w

In [None]:
# Model settings passed to update_step_sparse
sigma_sq_e = 0.5   # noise variance
sigma_sq_1 = 0.5   # variance of the non-zero loadings
p_zero = 0.999     # prior probability that a loading is zero
K = 10             # number of components

# Optimisation settings
max_iters = 5000     # upper bound on the number of CAVI sweeps
check_every = 100    # test (and print) convergence every this many sweeps
tol = 1e-7           # stop once the squared change in vi_mu_z falls below this

# Initialise the variational parameters from the classical SVD of `data`:
# the first K left singular vectors for Z, the first K right singular vectors
# scaled by their singular values for W.
num_vars = data.shape[1]
u, s, v = np.linalg.svd(data)
vi_mu_z = np.copy(u[:, 0:K])
vi_sigma_z = np.eye(K)
vi_mu_w = np.copy(v[0:K, :].T) * s[0:K]
vi_sigma_w = np.ones((num_vars, K))
vi_psi_0 = np.ones((num_vars, K)) * 0.01

for i in range(max_iters):
    (new_vi_mu_z,
     vi_sigma_z,
     vi_psi_0,
     vi_mu_w,
     vi_sigma_w) = update_step_sparse(data,
                                      vi_mu_z,
                                      vi_sigma_z,
                                      vi_psi_0,
                                      vi_mu_w,
                                      vi_sigma_w,
                                      sigma_sq_e,
                                      sigma_sq_1,
                                      p_zero)
    converged = False
    if i % check_every == 0:
        # squared change of the Z means over the last sweep
        error = np.sum((new_vi_mu_z - vi_mu_z)**2)
        print(error)
        converged = error < tol
    # Accept the new Z means BEFORE a possible break, so that vi_mu_z belongs
    # to the same sweep as the W parameters that were just overwritten.
    vi_mu_z = new_vi_mu_z
    if converged:
        break

In [None]:
# Readable names for the fitted quantities used by the plots below.
cell_loadings = vi_mu_z   # per-cell means of Z (one row per cell, one column per component)
gene_loadings = vi_mu_w   # per-gene means of the non-zero loadings
# vi_psi_0 is the weight on a loading being zero (update_step_sparse uses
# (1 - vi_psi_0) * vi_mu_w as the expected loading), so 1 - vi_psi_0 is the
# posterior inclusion probability (PIP) of each gene on each component.
pips = 1 - vi_psi_0

In [None]:
# Column position of each marker gene within the highly-variable-gene matrix `data`
hvg_gene_names = data_var[data_var.highly_variable].gene_name
for marker in ('GATA3', 'SPI1', 'BCL2'):
    position = np.where(hvg_gene_names == marker)[0][0]
    print('index for ' + marker, position)

In [None]:
# Classical PCA (left column) versus sparse pPCA (right column): cells plotted
# on the first two components, coloured by the expression of one marker gene
# per row.
marker_genes = ['GATA3', 'SPI1', 'BCL2']

# Gene names in the same order as the columns of `data` (highly variable genes only).
# Positions are looked up by name rather than hard-coded, because they shift
# whenever the filtering or the highly-variable-gene selection changes.
hvg_gene_names = data_var[data_var.highly_variable].gene_name

fig, ax = plt.subplots(nrows=len(marker_genes), ncols=2, figsize=(7, 7))

for row, gene in enumerate(marker_genes):
    matches = np.flatnonzero((hvg_gene_names == gene).values)
    if matches.size == 0:
        raise ValueError(gene + ' is not among the highly variable genes')
    gene_idx = matches[0]

    # Draw cells in order of increasing expression so that the highest
    # expressing cells end up on top.
    sorter = np.argsort(data[:, gene_idx])
    ax[row, 0].scatter(u[sorter, 0], u[sorter, 1], c=data[sorter, gene_idx])
    ax[row, 1].scatter(cell_loadings[sorter, 0], cell_loadings[sorter, 1], c=data[sorter, gene_idx])
    ax[row, 0].set_ylabel(gene)
    print('PIPs on first 2 PCs for ' + gene + ':', pips[gene_idx, 0], pips[gene_idx, 1])

# Small PC1 / PC2 axis indicator in the bottom-right panel (coordinates are in
# that panel's data units).
ax[-1, 1].arrow(1, -2.2, .4, 0, width=0.05, head_length=0.05)
ax[-1, 1].arrow(1, -2.2, 0, 0.9, width=0.025, head_length=0.08)
ax[-1, 1].text(1.06, -2.55, 'PC1', fontsize=10)
ax[-1, 1].text(0.82, -1.95, 'PC2', fontsize=10, rotation=90)
ax[0, 0].set_title('Classical PCA')
ax[0, 1].set_title('Sparse pPCA')
# Panels share no gaps and carry no ticks.
plt.subplots_adjust(wspace=0, hspace=0)
plt.setp(fig.get_axes(), xticks=[], yticks=[]);
plt.savefig('../figs/pbmc_pca.pdf', bbox_inches='tight', pad_inches=0)

In [None]:
# Average posterior inclusion probability (mean over genes) for each component.
# The x positions are derived from the number of columns of `pips`, so the plot
# keeps working when the number of components is changed.
component_numbers = np.arange(pips.shape[1]) + 1
plt.scatter(component_numbers, pips.mean(axis=0))
plt.xlabel('Principal Component')
plt.ylabel('Average PIP')
plt.savefig('../figs/pbmc_pips.pdf', bbox_inches='tight', pad_inches=0)

In [None]:
# For the first two components, compare each gene's classical PCA loading
# (x axis) with its sparse-pPCA posterior mean loading, PIP * mean (y axis).
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
for pc, panel in enumerate(ax):
    posterior_mean = pips[:, pc] * gene_loadings[:, pc]
    panel.scatter(v[pc, :], posterior_mean, alpha=0.05)
    panel.set_title('PC ' + str(pc + 1))
    panel.set_xlabel('Classical PCA loading')
    if pc == 0:
        # only the left panel carries the y-axis label
        panel.set_ylabel('VI pPCA posterior mean')
plt.savefig('../figs/pbmc_posterior_mean.pdf', bbox_inches='tight', pad_inches=0)