Please run `seurat_pbmc.ipynb` first; this notebook loads the `pbmc_seurat_*` input files that it generates.

In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(0, '../')

import os
import time
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import figure_utils as utils

from truncated_normal import truncated_normal as tn
from scipy.stats import ttest_ind, ranksums, gmean
from sklearn.svm import SVC

%matplotlib inline

  return f(*args, **kwds)
  return f(*args, **kwds)


In [2]:
# Load data: expression matrix, cluster labels and t-SNE coordinates
# (the pbmc_seurat_* tables produced upstream, space-delimited).
data_dir = './'
data_R, labels_R, tsne_R = (
    pd.read_csv(os.path.join(data_dir, fname), delimiter=' ')
    for fname in ('pbmc_seurat_preprocessed.txt',
                  'pbmc_seurat_labels.txt',
                  'pbmc_seurat_tsne.txt')
)

# Genes are the index of data_R and cells its columns; transpose so that each
# row of `data` is one cell.
genes = np.array(data_R.index)
data = np.array(data_R).T
labels = np.array(labels_R).reshape(-1)
tsne = np.array(tsne_R)

# Seurat DE tests whose results are compared against the TN test below.
tests = ["t", "wilcox", "bimod", "tobit", "poisson", "negbinom", "MAST"]

In [3]:
# Split the datasets: load the first-half cell sample for the ident0-vs-ident1
# comparison together with its labels, and locate those cells in `data`.
ident0, ident1 = 0, 1
cellsamp1 = pd.read_csv(
    os.path.join(data_dir, 'pbmc_seurat_%sv%s_cellsamp1.txt' % (ident0, ident1)),
    delimiter=' ')
# Cell name -> row position in `data` (cells are the columns of data_R).
pdict = dict(zip(data_R.columns, range(len(data_R.columns))))
cellsamp_inds = np.array([pdict[cell] for cell in np.array(cellsamp1).reshape(-1)])
cellsamp1labels = np.array(pd.read_csv(
    os.path.join(data_dir, 'pbmc_seurat_%sv%s_cellsamplabels1.txt' % (ident0, ident1)),
    delimiter=' ')).reshape(-1)

In [None]:
# Fit a linear-SVM hyperplane on the first half of the dataset.
svm = SVC(kernel='linear', C=100)
svm.fit(data[cellsamp_inds], cellsamp1labels)
a = svm.coef_.reshape(-1)
b = svm.intercept_[0]
labels1hat = svm.predict(data[cellsamp_inds])
print('Consistency of new labels with old: %.3f'
      % (np.sum(labels1hat == cellsamp1labels) / float(len(cellsamp1labels))))

# Assign the held-out (second-half) cells of the two clusters under test to a
# side of the hyperplane. Filter on ident0/ident1 (not hard-coded 0/1) so this
# stays consistent with the split cell above and with load_and_split_data.
inds = np.ones(len(data), dtype=bool)
inds[cellsamp_inds] = False
inds[~np.isin(labels, [ident0, ident1])] = False
labels2 = svm.predict(data[inds])
y, z = data[inds][labels2 == 0], data[inds][labels2 == 1]

Consistency of new labels with old: 1.000


In [None]:
# Grab list of genes of interest based on other DE methods
def grab_interesting_de_genes(ident0, ident1):
    genes_list = []
    ngenes = 10
    for i, test in enumerate(tests):
        seurat_de = pd.read_csv(os.path.join(data_dir, 'pbmc_seurat_%sv%s_de_%s.txt' \
                                             %(ident0, ident1, test)), delimiter=' ')
        genes_list.extend(list(seurat_de.index[:10]))
    genes_list = np.unique(genes_list)
    return genes_list

# Run TN test on genes that actually change across the two clusters.
# Filter into new names instead of overwriting y/z/a so that re-running this
# cell is idempotent (re-filtering already-filtered arrays would misalign
# keep_inds with `genes`).
keep_inds = np.var(np.vstack((y, z)), 0) > 0
y_kept = y[:, keep_inds]
z_kept = z[:, keep_inds]
a_kept = a[keep_inds]

# Sweep through iteration budgets: 11 log-spaced counts from 1 to 1e5
# (truncated to int), each repeated num_sims times.
genes_list = grab_interesting_de_genes(ident0, ident1)
genes_to_test = [np.where(genes[keep_inds] == gene)[0][0] for gene in genes_list]
iter_sweep = np.logspace(0, 5, 11).astype(int)
num_sims = 10

picklefile = 'pbmc_seurat_%sv%s_de_tntest_numiter_sweep.pickle' % (ident0, ident1)
if os.path.isfile(picklefile):
    # Cache written by this cell; pickle.load can execute arbitrary code, so
    # only load cache files you produced yourself.
    with open(picklefile, 'rb') as fh:
        tn_test_multiple_runs = pickle.load(fh)
else:
    # Keyed by (num_iters, run index) -> (p-values, final likelihood).
    tn_test_multiple_runs = {}
    np.random.seed(0)
    for j in range(num_sims):
        for num_iters in iter_sweep:
            p_tn, likelihood = tn.tn_test(y_kept, z_kept, a=a_kept, b=b, verbose=True,
                                          learning_rate=0.5, eps=1e-1,
                                          genes_to_test=genes_to_test,
                                          return_likelihood=True, num_iters=num_iters)
            tn_test_multiple_runs[(num_iters, j)] = (p_tn, likelihood[-1])
    with open(picklefile, 'wb') as fh:
        pickle.dump(tn_test_multiple_runs, fh)

In [None]:
# Collapse multiple runs into one: for each iteration budget keep the p-values
# of the run with the highest final likelihood, ignoring runs whose likelihood
# is infinite.
plot_results = []
for num_iters in iter_sweep:
    runs = [tn_test_multiple_runs[(num_iters, j)] for j in range(num_sims)]
    finite_runs = [run for run in runs if not np.isinf(run[1])]
    if not finite_runs:
        raise ValueError('every run with num_iters=%d has an infinite likelihood' % num_iters)
    p_tn, likelihood = max(finite_runs, key=lambda run: run[1])
    plot_results.append(p_tn)

# -log10(p) against the iteration budget for the first (up to) 5 genes of
# genes_list; the two smallest budgets (iter_sweep[:2]) are left off the plot.
for i in range(min(5, len(genes_list))):
    curve = [-np.log10(result[i]) for result in plot_results]
    plt.plot(iter_sweep[2:], curve[2:], label=genes_list[i])
plt.xscale('log')
plt.xlabel('number of iterations')
plt.ylabel(r'$-\log_{10}(p)$')
plt.legend(bbox_to_anchor=(1.4, 1))
plt.savefig('figures/pbmc_%sv%s_num_iters.pdf' % (ident0, ident1), format='pdf', dpi=500, bbox_inches='tight')
plt.show()

# Testing other cluster pairs

In [None]:
def load_and_split_data(ident0, ident1, save=True):
    """Fit a linear SVM on the first cell sample and split the held-out cells.

    The first-half sample (cell names and labels) is read from
    ``pbmc_seurat_<ident0>v<ident1>_cellsamp1.txt`` and
    ``..._cellsamplabels1.txt``. A linear SVM (C=100) is trained on it and its
    training-set agreement is printed. The remaining cells that belong to
    cluster ``ident0`` or ``ident1`` are then labeled by the SVM.

    Parameters
    ----------
    ident0, ident1 : cluster identifiers.
    save : bool, default True
        If true, write the held-out cell names and their SVM labels to
        ``pbmc_seurat_<ident0>v<ident1>_cellsamp2.txt`` and
        ``..._cellsamplabels2.txt`` in the working directory.

    Returns
    -------
    y, z : held-out rows of ``data`` predicted as label 0 and label 1.
    a : SVM weight vector (``coef_`` flattened).
    b : SVM intercept.
    """
    sample_path = os.path.join(data_dir, 'pbmc_seurat_%sv%s_cellsamp1.txt' % (ident0, ident1))
    label_path = os.path.join(data_dir, 'pbmc_seurat_%sv%s_cellsamplabels1.txt' % (ident0, ident1))

    cell_names = np.array(pd.read_csv(sample_path, delimiter=' ')).reshape(-1)
    col_index = {name: pos for pos, name in enumerate(data_R.columns)}
    train_rows = np.array([col_index[name] for name in cell_names])
    train_labels = np.array(pd.read_csv(label_path, delimiter=' ')).reshape(-1)

    # Hyperplane from the first half of the dataset.
    clf = SVC(kernel='linear', C=100)
    clf.fit(data[train_rows], train_labels)
    a = clf.coef_.reshape(-1)
    b = clf.intercept_[0]
    train_hat = clf.predict(data[train_rows])
    agreement = np.sum(train_hat == train_labels) / float(len(train_labels))
    print('Consistency of new labels with old: %.3f' % agreement)

    # Held-out cells: not in the training sample and in one of the two clusters.
    held_out = np.ones(len(data), dtype=bool)
    held_out[train_rows] = False
    for lab in np.unique(labels):
        if lab not in (ident0, ident1):
            held_out[labels == lab] = False
    held_sides = clf.predict(data[held_out])
    held_data = data[held_out]
    y, z = held_data[held_sides == 0], held_data[held_sides == 1]

    if save:
        np.savetxt('pbmc_seurat_%sv%s_cellsamp2.txt' % (ident0, ident1), data_R.columns[held_out], fmt='%s')
        np.savetxt('pbmc_seurat_%sv%s_cellsamplabels2.txt' % (ident0, ident1), held_sides, fmt='%s')

    return y, z, a, b


def compare_clusters_multiple_runs_less_genes(ident0, ident1, num_runs=10):
    """Set up and run the TN-test data-splitting framework for one cluster pair.

    The TN test is run ``num_runs`` times to account for variance in the
    optimization, and only on the genes from ``grab_interesting_de_genes`` to
    save computation. All runs are cached in
    ``pbmc_seurat_<ident0>v<ident1>_de_tntest_multiple_runs_max.pickle``; when
    that file exists it is loaded instead of recomputing (and ``num_runs`` is
    ignored).

    Parameters
    ----------
    ident0, ident1 : cluster identifiers.
    num_runs : int, default 10
        Number of TN-test runs when no cache exists.

    Returns
    -------
    p_tn : p-values of the run with the highest final likelihood, one per gene
        of ``grab_interesting_de_genes(ident0, ident1)`` in that order.
    """
    picklefile = 'pbmc_seurat_%sv%s_de_tntest_multiple_runs_max.pickle' % (ident0, ident1)
    if os.path.isfile(picklefile):
        # Cache written by this function; pickle.load can execute arbitrary
        # code, so only load cache files you produced yourself.
        with open(picklefile, 'rb') as fh:
            tn_test_multiple_runs = pickle.load(fh)
    else:
        y, z, a, b = load_and_split_data(ident0, ident1)

        # Run TN test only on genes that actually change across the two clusters.
        keep_inds = np.var(np.vstack((y, z)), 0) > 0
        np.savetxt('pbmc_seurat_%sv%s_filtered_genes.txt' % (ident0, ident1), keep_inds)
        y, z, a = y[:, keep_inds], z[:, keep_inds], a[keep_inds]

        genes_list = grab_interesting_de_genes(ident0, ident1)
        genes_to_test = [np.where(genes[keep_inds] == gene)[0][0] for gene in genes_list]

        # Repeat the test (honouring the num_runs argument) to evaluate stability.
        tn_test_multiple_runs = []
        for _ in range(num_runs):
            p_tn, likelihood = tn.tn_test(y, z, a=a, b=b, verbose=True, learning_rate=0.5,
                                          eps=1e-1, genes_to_test=genes_to_test,
                                          return_likelihood=True)
            tn_test_multiple_runs.append((p_tn, likelihood[-1]))
        with open(picklefile, 'wb') as fh:
            pickle.dump(tn_test_multiple_runs, fh)

    # Each entry is (p-values, final likelihood); keep the best-likelihood run.
    p_tn, _ = max(tn_test_multiple_runs, key=lambda run: run[1])
    return p_tn


def plot_bar_p(gene_names, p_correct, p_seurat=None, label=None, width=0.35,
               set_ylabel=True, p_tn_se=None, mark_red=None):
    """Bar chart of per-gene TN-test scores, optionally beside another test.

    Creates a new 3 x 1.5 inch figure. Callers in this notebook pass -log10
    p-values.

    Parameters
    ----------
    gene_names : sequence of str
        Tick labels, one per bar position.
    p_correct : sequence of float
        TN-test values, one per gene.
    p_seurat : sequence of float, optional
        Values for a competing test; drawn just left of the TN bars and
        labeled ``label`` in the legend.
    label : str, optional
        Legend label for ``p_seurat``.
    width : float
        Bar width.
    set_ylabel : bool
        Whether to draw the y-axis label.
    p_tn_se : sequence of float, optional
        Error bars for the TN bars (passed as ``yerr``).
    mark_red : collection of str, optional
        Gene names whose tick labels are colored red (``None`` instead of a
        shared mutable ``[]`` default).
    """
    if mark_red is None:
        mark_red = ()
    _, ax = plt.subplots(figsize=(3, 1.5))
    ind = np.arange(len(p_correct))
    if p_seurat is not None:
        ax.bar(ind - 0.2, p_seurat, width, align='center', label=label)
        rects = ax.bar(ind - 0.2 + 0.3, p_correct, width, yerr=p_tn_se,
                       align='center', label='TN test')
    else:
        rects = ax.bar(ind, p_correct, width, yerr=p_tn_se, align='center',
                       label='TN test', color='#ff7f0e')
    if set_ylabel:
        ax.set_ylabel('-log($p$)')
    # Tick positions are derived from the TN bars (the last bars drawn).
    xticks_pos = [0.65 * patch.get_width() + patch.get_xy()[0] - 0.1 for patch in rects]
    plt.xticks(np.array(xticks_pos), gene_names, rotation=90)
    for ticklabel in plt.gca().get_xticklabels():
        if ticklabel.get_text() in mark_red:
            ticklabel.set_color('r')
    plt.legend()

        
def get_de_df(ident0, ident1, p_tn=None):
    """Table of -log10 p-values from the TN test and each Seurat DE test.

    Rows are the genes from ``grab_interesting_de_genes(ident0, ident1)``.
    Columns are ``'TN'`` followed by the names in ``tests``. Values are
    rounded to 2 decimals.

    Parameters
    ----------
    ident0, ident1 : cluster identifiers.
    p_tn : array-like, optional
        TN-test p-values aligned with the gene list. Defaults to the per-pair
        result of ``compare_clusters_multiple_runs_less_genes(ident0, ident1)``
        (the module-level ``tn_test_multiple_runs`` is the 0v1 iteration-sweep
        dict, not results for this pair, so it must not be used here).

    Returns
    -------
    de_df : pandas.DataFrame
    interesting_genes : numpy.ndarray
        Genes whose TN column holds the row's smallest (or tied-smallest)
        -log10 p-value, i.e. genes the TN test ranks least significant.
    """
    genes_list = grab_interesting_de_genes(ident0, ident1)
    if p_tn is None:
        p_tn = compare_clusters_multiple_runs_less_genes(ident0, ident1)
    df_dict = {'TN': list(-np.log10(p_tn))}
    for test in tests:
        seurat_de = pd.read_csv(os.path.join(data_dir, 'pbmc_seurat_%sv%s_de_%s.txt' % (ident0, ident1, test)),
                                delimiter=' ')
        df_dict[test] = list(-np.log10(seurat_de.loc[genes_list, 'p_val']))
    de_df = pd.DataFrame(df_dict, index=genes_list).round(2)
    # Column 0 is 'TN' (dict insertion order).
    interesting_genes = genes_list[np.argmin(np.array(de_df), 1) == 0]
    return de_df, interesting_genes


def view_genes(gene_names, ident0, ident1, marked_genes=None):
    """Plot per-gene expression histograms, four panels per figure.

    Cells from both halves of the split are pooled. The first half uses the
    labels in ``..._cellsamplabels1.txt`` and the second half uses the SVM
    labels in ``..._cellsamplabels2.txt`` (written by ``load_and_split_data``).
    The pooled cells are grouped into label-0 vs label-1 and drawn with
    ``utils.plot_stacked_hist``. Each gene's mean expression over all cells of
    cluster ``ident0`` and ``ident1`` is also printed. Full rows of panels
    are shown as they fill up. The last figure is left open (tight-layout
    applied) so the caller can save it.

    Parameters
    ----------
    gene_names : iterable of str
        Gene names present in ``genes``.
    ident0, ident1 : cluster identifiers used in file names and printouts.
    marked_genes : collection of str, optional
        Genes whose panel title gets a trailing '*'. Defaults to none.
    """
    if marked_genes is None:
        marked_genes = ()

    labels1 = np.array(pd.read_csv('./pbmc_seurat_%sv%s_cellsamplabels1.txt' % (ident0, ident1),
                                   delimiter=' ')).reshape(-1)
    labels2 = np.loadtxt('./pbmc_seurat_%sv%s_cellsamplabels2.txt' % (ident0, ident1))

    cellsamp1 = pd.read_csv(os.path.join(data_dir, 'pbmc_seurat_%sv%s_cellsamp1.txt' % (ident0, ident1)),
                            delimiter=' ')
    pdict = {j: i for i, j in enumerate(data_R.columns)}
    inds1 = np.array([pdict[i] for i in np.array(cellsamp1).reshape(-1)])
    cellsamp2 = np.loadtxt('./pbmc_seurat_%sv%s_cellsamp2.txt' % (ident0, ident1), dtype='str')
    inds2 = np.array([pdict[i] for i in np.array(cellsamp2).reshape(-1)])

    data1 = data[inds1]
    data2 = data[inds2]
    # Pool both halves once (loop-invariant): label-0 cells vs label-1 cells.
    y_all = np.vstack((data1[labels1 == 0], data2[labels2 == 0]))
    z_all = np.vstack((data1[labels1 == 1], data2[labels2 == 1]))

    for i, gene_name in enumerate(gene_names):
        if i % 4 == 0:
            # Flush the previous row of panels before starting a new figure.
            if i > 0:
                plt.tight_layout()
                plt.show()
            plt.figure(figsize=(2 * 4, 1.8))
        plt.subplot(1, 4, i % 4 + 1)
        g_ind = np.where(genes == gene_name)[0][0]
        utils.plot_stacked_hist(y_all[:, g_ind], z_all[:, g_ind])
        plt.title(gene_name + ('*' if gene_name in marked_genes else ''))
        plt.axis('off')
        print('%10s | type %s with mean %.2f' % (gene_name, ident0, np.mean(data[labels == ident0, g_ind])))
        print('%10s | type %s with mean %.2f' % (gene_name, ident1, np.mean(data[labels == ident1, g_ind])))

    plt.tight_layout()
    
    
def create_and_save_visualization(ident0, ident1, genes_to_view, genes_to_mark_red=[], ngenes=5):
    p_tn = compare_clusters_multiple_runs_less_genes(ident0, ident1)
    plt.savefig('figures/pbmc_%sv%s_multiple_runs_gene_exp.pdf'\
                %(ident0, ident1), format='pdf', dpi=500, bbox_inches='tight')
    genes_list = grab_interesting_de_genes(ident0, ident1)
    if genes_to_view is None:
        genes_to_view = genes_list
    view_genes(genes_to_view, ident0, ident1, marked_genes=[])
    plt.savefig('figures/pbmc_%sv%s_gene_exp.pdf'%(ident0, ident1), format='pdf', dpi=500, bbox_inches='tight')

    for i, test in enumerate(tests):
        seurat_de = -np.log10(pd.read_csv(os.path.join(data_dir, 'pbmc_seurat_%sv%s_de_%s.txt' \
                                                       %(ident0, ident1, test)), delimiter=' '))
        del seurat_de['avg_logFC']
        del seurat_de['pct.1']
        del seurat_de['pct.2']
        del seurat_de['p_val_adj']
        gene_names = seurat_de.index[:ngenes]
        p_seurat = seurat_de['p_val'][:ngenes]
        p_correct = [-np.log10(p_tn)[genes_list == gene][0] for gene in gene_names][:ngenes]
        if test == "t": 
            test = 'Welch\'s $t$'
        plot_bar_p(gene_names, p_correct, p_seurat=p_seurat, label=test,
                   set_ylabel = True if i == 0 or i == 4 else False, mark_red=genes_to_mark_red)
        plt.savefig('figures/pbmc_de_%sv%s_%s.pdf'%(ident0, ident1, test), format='pdf', dpi=500, bbox_inches='tight')

    inds = np.argsort(p_tn)[:ngenes]
    p_correct = -np.log10(p_tn)[inds]
    gene_names = genes_list[inds]
    plot_bar_p(gene_names, p_correct, set_ylabel=False)
    plt.savefig('figures/pbmc_de_%sv%s_tn.pdf'%(ident0, ident1), format='pdf', dpi=500, bbox_inches='tight')
    plt.show()
    
    
def tsne_select2(ident0, ident1):
    """Scatter the t-SNE embedding highlighting two clusters, then save and show it.

    Cells of cluster ``ident0`` are green, ``ident1`` red, and all other cells
    black ('Other'). Each legend entry shows its group's cell count. The plot
    is saved to ``figures/pbmc_tsne_<ident0>v<ident1>.pdf``.
    """
    group = np.array(labels).astype(str)
    in_pair = (labels == ident0) | (labels == ident1)
    group[~in_pair] = 'Other'

    key0, key1 = str(ident0), str(ident1)
    series = [
        (key0, r'c%s ($n$ = %s)' % (ident0, np.sum(group == key0)), 'g'),
        (key1, r'c%s ($n$ = %s)' % (ident1, np.sum(group == key1)), 'r'),
        ('Other', r'Other ($n$ = %s)' % (np.sum(group == 'Other')), 'k'),
    ]
    for key, legend_text, color in series:
        mask = group == key
        plt.scatter(tsne[mask, 0], tsne[mask, 1], label=legend_text, c=color, s=10)

    plt.legend(bbox_to_anchor=(1.4, 1))
    plt.axis('equal')
    plt.axis('off')
    plt.savefig('figures/pbmc_tsne_%sv%s.pdf' % (ident0, ident1),
                format='pdf', dpi=500, bbox_inches='tight')
    plt.show()

In [None]:
np.random.seed(0)
ident0, ident1 = 0, 1
ngenes = 10
genes_to_view = ['S100A4', 'S100A11', 'B2M', 'HLA-A']
genes_to_mark_red = ['B2M', 'HLA-A']
tsne_select2(ident0, ident1)
create_and_save_visualization(ident0, ident1, genes_to_view,
                              genes_to_mark_red=genes_to_mark_red, ngenes=ngenes)

# For each Seurat test, print its top genes' -log10 p next to the TN test's.
p_tn = compare_clusters_multiple_runs_less_genes(ident0, ident1)
genes_list = grab_interesting_de_genes(ident0, ident1)
neg_log_p_tn = -np.log10(p_tn)

for test in tests:
    seurat_de = pd.read_csv(os.path.join(data_dir, 'pbmc_seurat_%sv%s_de_%s.txt' % (ident0, ident1, test)),
                            delimiter=' ')
    # Positional slice via .iloc: integer [] on a gene-name index is a
    # deprecated positional fallback in pandas.
    p_seurat = -np.log10(seurat_de['p_val'].iloc[:ngenes])
    gene_names = p_seurat.index
    p_correct = [neg_log_p_tn[genes_list == gene][0] for gene in gene_names]

    print(test)
    for gene, p_de, p_tn_gene in zip(gene_names, p_seurat, p_correct):
        print('%10s\tpDE: %.2f\tpTN: %.2f' % (gene, p_de, p_tn_gene))
    print(' ')

In [None]:
# Cluster pair 1 vs 3; genes_to_view=None plots every gene returned by
# grab_interesting_de_genes.
ident0, ident1 = 1, 3
genes_to_view = None
np.random.seed(0)
create_and_save_visualization(ident0, ident1, genes_to_view, ngenes=10)

In [None]:
# Cluster pair 2 vs 5; genes_to_view=None plots every gene returned by
# grab_interesting_de_genes.
ident0, ident1 = 2, 5
genes_to_view = None
np.random.seed(0)
create_and_save_visualization(ident0, ident1, genes_to_view, ngenes=10)

In [None]:
# t-SNE of the five clusters used in the comparisons above (0, 1, 2, 3, 5);
# every other cell is drawn in black as 'Other'.
ident0, ident1, ident2, ident3, ident4 = 0, 1, 2, 3, 5
cluster_colors = [(ident0, 'g'), (ident1, 'r'), (ident2, 'teal'),
                  (ident3, 'purple'), (ident4, 'brown')]

in_selection = np.zeros(len(labels), dtype=bool)
for ident, _ in cluster_colors:
    in_selection |= labels == ident
labels_ = np.array(labels).astype(str)
labels_[~in_selection] = 'Other'

for ident, color in cluster_colors:
    mask = labels_ == str(ident)
    plt.scatter(tsne[mask, 0], tsne[mask, 1],
                label=r'c%s ($n$ = %s)' % (ident, np.sum(mask)), c=color, s=10)
other = labels_ == 'Other'
plt.scatter(tsne[other, 0], tsne[other, 1],
            label=r'Other ($n$ = %s)' % (np.sum(other)), c='k', s=10)

plt.legend(bbox_to_anchor=(1.4, 1))
plt.axis('equal')
plt.axis('off')
plt.savefig('figures/pbmc_tsne_supp.pdf', format='pdf', dpi=500, bbox_inches='tight')
plt.show()

# Extra plots for the RECOMB talk

In [None]:
# Slide figures of the t-SNE embedding: first unlabeled, then with cluster
# labels drawn by utils.plot_labels_legend (legend suppressed).
x_tsne, y_tsne = tsne[:, 0], tsne[:, 1]
plt.scatter(x_tsne, y_tsne, s=10)
plt.savefig('figures/slides_pbmc_visualization.pdf', format='pdf', dpi=500, bbox_inches='tight')

plt.figure()
utils.plot_labels_legend(x_tsne, y_tsne, labels, legend=False)
plt.savefig('figures/slides_pbmc_visualization_clust.pdf', format='pdf', dpi=500, bbox_inches='tight')

In [None]:
# Slide figure: clusters 0 (default color) and 3 (red) on top of all other
# cells drawn in black.
is_c0 = labels == 0
is_c3 = labels == 3
inds = ~(is_c3 | is_c0)
plt.scatter(tsne[inds, 0], tsne[inds, 1], s=10, c='k')
plt.scatter(tsne[is_c0, 0], tsne[is_c0, 1], s=10)
plt.scatter(tsne[is_c3, 0], tsne[is_c3, 1], s=10, c='#d62728')
plt.savefig('figures/slides_pbmc_visualization_twoclust.pdf', format='pdf', dpi=500, bbox_inches='tight')

In [None]:
def plot_stacked_hist(v0, v1, title=None, label=None):
    """Draw normalized histograms of two samples on shared bins.

    Both samples share 20 bin edges computed from their combined values.
    They are passed together to ``plt.hist`` with the default ``histtype``
    ('bar'), which places the two bars of each bin side by side, in blue
    (``v0``) and red (``v1``) with ``density=True``.

    Parameters
    ----------
    v0, v1 : array-like
        The two samples.
    title : str, optional
        Plot title.
    label : list of str, optional
        Legend labels; defaults to ``['0', '1']``.
    """
    if label is None:
        label = ['0', '1']
    combined = np.hstack((v0, v1))
    _, edges = np.histogram(combined, bins=20)
    plt.hist([v0, v1], edges, label=label, alpha=0.8,
             color=['#1f77b4', '#d62728'], density=True,
             edgecolor='none', rwidth=1.)
    if title is not None:
        plt.title(title)

In [None]:
# Slide illustration of the null: both samples drawn from the same N(0, 1).
np.random.seed(0)  # fixed seed so the slide figure is reproducible on re-run
v0 = np.random.normal(0, 1, 100)
v1 = np.random.normal(0, 1, 100)
plot_stacked_hist(v0, v1)
plt.savefig('figures/slides_ttest_null.pdf', format='pdf', dpi=500, bbox_inches='tight')

In [None]:
# Slide illustration of the alternative: samples from N(2, 1) and N(-2, 1).
np.random.seed(0)  # fixed seed so the slide figure is reproducible on re-run
v0 = np.random.normal(2, 1, 100)
v1 = np.random.normal(-2, 1, 100)
plot_stacked_hist(v0, v1)
plt.savefig('figures/slides_ttest_alt.pdf', format='pdf', dpi=500, bbox_inches='tight')