## Import packages

In [1]:
%load_ext autoreload
%autoreload 2

import sys
import numpy as np
import seaborn as sns
import pandas as pd
import pickle
import scanpy as sc
sc.set_figure_params(dpi=100, dpi_save=300)
import scvi
import anndata as ad
from matplotlib import pyplot, cm
import os
from math import ceil
from scipy.stats import spearmanr
import math
import time

import leidenalg

from anndata import AnnData
import scanpy as sc
from scanpy import read
import pandas as pd
from sciPENN.Preprocessing import preprocess

import matplotlib.pyplot as plt
print(scvi.__version__)

0.9.1


# Read data: Pbmc (train), H1N1 (test)

In [2]:
adata_gene = sc.read("../Data/pbmc/pbmc_gene.h5ad")
adata_protein = sc.read("../Data/pbmc/pbmc_protein.h5ad")

In [3]:
adata_gene_test = sc.read("../Data/H1N1/gene_data.mtx").T
adata_gene_test.var.index = pd.read_csv("../Data/H1N1/gene_names.txt", index_col = 0).iloc[:, 0]
adata_gene_test.obs = pd.read_csv("../Data/H1N1/meta_data.txt", sep = ',', index_col = 0)

In [4]:
adata_protein_test = sc.read("../Data/H1N1/protein_data.mtx").T
adata_protein_test.var.index = [x[:len(x) - 5] for x in pd.read_csv("../Data/H1N1/protein_names.txt", index_col = 0).iloc[:,0]]
adata_protein_test.obs = pd.read_csv("../Data/H1N1/meta_data.txt", sep = ',', index_col = 0)

adata_protein_test.layers["raw"] = adata_protein_test.X

# Selecting highly variable genes - using gene expression measures from test data 

In [6]:
gene_train, protein_train, gene_test, bools, train_keys, categories = preprocess([adata_gene], [adata_protein], adata_gene_test, train_batchkeys = ["donor"], test_batchkey = "sample", gene_list = [], select_hvg = True, cell_normalize = True, log_normalize = True, gene_normalize = True, min_cells = 30, min_genes = 200)


QC Filtering Training Cells
QC Filtering Testing Cells

QC Filtering Training Genes
QC Filtering Testing Genes

Normalizing Training Cells
Normalizing Testing Cells

Log-Normalizing Training Data
Log-Normalizing Testing Data

Finding HVGs


... storing 'orig.ident' as categorical
... storing 'lane' as categorical
... storing 'donor' as categorical
... storing 'time' as categorical
... storing 'celltype.l1' as categorical
... storing 'celltype.l2' as categorical
... storing 'celltype.l3' as categorical
... storing 'Phase' as categorical
... storing 'batch' as categorical
... storing 'Dataset' as categorical
... storing 'barcode_check' as categorical
... storing 'tenx_lane' as categorical
... storing 'cohort' as categorical
... storing 'hash_maxID' as categorical
... storing 'hash_secondID' as categorical
... storing 'hto_classification' as categorical
... storing 'hto_classification_global' as categorical
... storing 'hash_ID' as categorical
... storing 'adjmfc.time' as categorical
... storing 'DMX_GLOBAL_BEST' as categorical
... storing 'DEMUXLET.BARCODE' as categorical
... storing 'sample' as categorical
... storing 'joint_classification_global' as categorical
... storing 'timepoint' as categorical
... storing 'K0' as ca


Normalizing Gene Training Data by Batch


100%|██████████| 8/8 [00:06<00:00,  1.18it/s]



Normalizing Protein Training Data by Batch


100%|██████████| 8/8 [00:02<00:00,  3.14it/s]



Normalizing Gene Testing Data by Batch


100%|██████████| 20/20 [00:01<00:00, 17.56it/s]


In [7]:
hvg = gene_test.var_names
cells_test = gene_test.obs_names
cells_train = gene_train.obs_names

In [8]:
del gene_train
del gene_test
del protein_train

import gc
gc.collect()

4701

# Format data

In [9]:
# What proteins overlap between the test and train data?

def replace(protein):
    if protein == 'CD3':
        return 'CD3-1'
    elif protein == 'CD4':
        return 'CD4-1'
    
    return protein

adata_protein_test.var_names = [replace(x) for x in adata_protein_test.var_names]

train_protein = adata_protein.var_names
test_protein = adata_protein_test.var_names
overlap_protein = train_protein[train_protein.isin(test_protein)]

In [10]:
## Subsetting the data by the HVG - pbmc

adata_gene_pbmc_hvg = adata_gene[cells_train, hvg].copy()

In [11]:

## Subsetting the data by the HVG - h1n1

adata_gene_h1n1_hvg = adata_gene_test[cells_test, hvg].copy()

In [12]:
(adata_gene_pbmc_hvg.var.index == adata_gene_h1n1_hvg.var.index).mean()

1.0

In [13]:
adata_protein = adata_protein[cells_train, :].copy()
adata_protein_test = adata_protein_test[cells_test, :].copy()

In [14]:
# Batches (subject) in training data - pbmc (8 subjects)

adata_gene_pbmc_hvg.obs['patient'] = pd.DataFrame(adata_gene_pbmc_hvg.obs['donor']).copy()
adata_gene_pbmc_hvg.obs['patient'] = adata_gene_pbmc_hvg.obs['donor'].astype("str")

In [15]:
# Batches (subject) in test data - h1n1 (20 subjects)

adata_gene_h1n1_hvg.obs['patient'] = pd.DataFrame(adata_gene_h1n1_hvg.obs['sample']).copy()
adata_gene_h1n1_hvg.obs['patient'] = adata_gene_h1n1_hvg.obs['patient'].astype("str")

In [16]:
## Combine data

adata = ad.concat([adata_gene_pbmc_hvg.copy(), adata_gene_h1n1_hvg.copy()],
                     join='outer')

### Note: Train on PBMC

In [17]:
train_patients = adata.obs["patient"].unique()[0:8]

In [18]:
test_patients = adata.obs["patient"].unique()[8:]

# Subset data based on HVGs and Hold Out Test Protein Set

In [19]:
adata_final = adata.copy()

In [20]:
held_out_proteins = adata_protein_test[cells_test, overlap_protein].copy()

### Now we hold-out the proteins for the test patients dataset. To do so, we can replace all the values with 0s. We will store the original values to validate after training.

In [21]:
# Modified this code cell to predict all p = 224 proteins

n, p = adata_protein.shape
n_H1N1, p_H1N1 = adata_protein_test.shape

protein_dat = pd.DataFrame(np.zeros(shape = (n + n_H1N1, p), dtype = 'float32'), 
                           index = list(adata_protein.obs_names) + list(adata_protein_test.obs_names),
                           columns = adata_protein.var_names)

protein_dat.iloc[:n] = adata_protein.X.toarray().copy() #fill the protein training data, leave test data as 0s

adata_final.obsm["protein_expression"] = protein_dat


In [22]:
n_train = len(cells_train)
n_test = len(cells_test)
total = n_train+n_test

print(n_train)
print(n_test)

train_index = list(range(0,n_train))
test_index = list(range(n_train,total))

161748
53200


# Remove additional data from memory:

In [23]:
del adata_gene
del adata_protein
del adata_protein_test
del adata_gene_test
del adata_gene_pbmc_hvg
del adata

# Run TotalVI

In [24]:
scvi.data.setup_anndata(adata_final, batch_key="patient", 
                        protein_expression_obsm_key="protein_expression")

[34mINFO    [0m Using batches from adata.obs[1m[[0m[32m"patient"[0m[1m][0m                                             
[34mINFO    [0m No label_key inputted, assuming all cells have same label                           
[34mINFO    [0m Using data from adata.X                                                             
[34mINFO    [0m Computing library size prior per batch                                              
[34mINFO    [0m Using protein expression from adata.obsm[1m[[0m[32m'protein_expression'[0m[1m][0m                      
[34mINFO    [0m Using protein names from columns of adata.obsm[1m[[0m[32m'protein_expression'[0m[1m][0m                
[34mINFO    [0m Found batches with missing protein expression                                       
[34mINFO    [0m Successfully registered anndata object containing [1;36m214948[0m cells, [1;36m1000[0m vars, [1;36m28[0m       
         batches, [1;36m1[0m labels, and [1;36m224[0m proteins. 

In [25]:
scvi.data.view_anndata_setup(adata_final)

In [26]:
totalvae = scvi.model.TOTALVI(
    adata_final,
    latent_distribution = "normal",
    n_layers_decoder = 2)

In [27]:
# Training with the default number of epochs 

# Training with the default number of epochs 
n_epochs = 400
lr = 4e-3

if os.path.isdir('weights_dir/totalvi_seurattoh1n1'):
    totalvae = totalvae.load("weights_dir/totalvi_seurattoh1n1", adata = adata_final)
else:
    totalvae.train(max_epochs=400)
    plt.plot(totalvae.history["elbo_validation"], label="test")
    plt.title("Negative ELBO over training epochs")
    plt.legend()
    
    totalvae.save("weights_dir/totalvi_seurattoh1n1")

  and should_run_async(code)


[34mINFO    [0m Found batches with missing protein expression                                       
[34mINFO    [0m Using data from adata.X                                                             
[34mINFO    [0m Computing library size prior per batch                                              
[34mINFO    [0m Registered keys:[1m[[0m[32m'X'[0m, [32m'batch_indices'[0m, [32m'local_l_mean'[0m, [32m'local_l_var'[0m, [32m'labels'[0m,     
         [32m'protein_expression'[0m[1m][0m                                                               
[34mINFO    [0m Successfully registered anndata object containing [1;36m214948[0m cells, [1;36m1000[0m vars, [1;36m28[0m       
         batches, [1;36m1[0m labels, and [1;36m224[0m proteins. Also registered [1;36m0[0m extra categorical covariates 
         and [1;36m0[0m extra continuous covariates.                                                  


# Analyze output - Results on training data

In [28]:
_, protein_means = totalvae.get_normalized_expression(
    transform_batch=train_patients,
    include_protein_background=True,
    sample_protein_mixing=False,
    return_mean=True,
)

  and should_run_async(code)


In [29]:
protein_means

Unnamed: 0,CD80,CD86,CD274,CD273,CD275-1,CD11b-1,Galectin-9,CD270,CD252,CD155,...,CD161,CCR10,CD271,GP130,CD199,CD45RB,CD46,VEGFR-3,CLEC2,CD26-2
L1_AAACCCAAGAAACTCA,2.122579,42.603584,3.542659,1.900590,5.734207,252.631058,8.444075,6.467201,2.292355,23.216591,...,2.250927,15.861979,9.327844,9.159569,12.023575,4.953502,25.004738,3.501064,85.095512,8.973317
L1_AAACCCAAGACATACA,1.944327,3.131711,2.883026,2.501685,4.673710,43.982746,8.101368,5.943116,1.780041,1.220875,...,5.135927,10.778120,5.622022,6.469916,9.121045,11.672628,19.283367,2.721327,29.697044,20.427986
L1_AAACCCACAACTGGTT,1.730093,3.039639,2.562512,1.408242,6.178349,42.735741,7.498310,5.121841,1.813404,1.195522,...,1.565825,9.835451,5.074905,6.924809,9.119061,26.926096,15.520886,2.731612,30.029915,15.653356
L1_AAACCCACACGTACTA,1.277338,3.048873,2.309936,1.289405,4.371858,48.280956,6.546273,4.899028,1.600756,1.463000,...,10.583746,9.258512,10.843285,3.101491,8.288669,6.429509,11.257324,2.385475,30.942623,7.496638
L1_AAACCCACAGCATACT,1.813619,3.200383,2.561083,1.378076,4.037916,46.464508,7.602444,5.516404,1.816808,1.389404,...,1.555040,11.113550,5.503995,7.180275,8.959747,27.838329,17.474224,2.702652,38.023842,16.081009
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTCAGTACGACCC_H1B2ln6,1.058843,3.002409,2.087906,1.097152,8.679170,33.836662,5.503890,3.256434,1.317475,1.078937,...,8.151533,5.980213,6.574310,2.477416,6.724470,11.029349,8.542373,1.960657,23.119783,10.067211
TTTGTCAGTCAAACTC_H1B2ln6,1.184887,2.999192,2.110954,1.185924,8.520926,33.671787,5.722351,3.208076,1.387711,1.087927,...,10.776732,6.141523,6.814425,2.537821,6.993121,11.813402,8.822578,2.252738,22.407648,24.358059
TTTGTCATCCCATTTA_H1B2ln6,1.131335,3.067525,2.112303,1.148918,6.918774,39.463673,5.639174,3.264669,1.445550,1.129810,...,4.250178,6.910967,7.846093,2.637681,7.319806,7.544795,8.754487,1.988150,25.133907,8.187986
TTTGTCATCGAGAACG_H1B2ln6,1.491890,22.897707,3.211294,1.417611,11.139934,174.541382,6.787586,4.410346,1.906893,12.960753,...,1.619370,9.131236,5.693849,6.869630,9.849549,3.717112,14.070439,2.712232,64.030907,8.808941


#### Note that: transform_batch is a power parameter. Setting this allows one to predict the expression of cells as if they came from the inputted batch. In this case, we’ve observed protein expression in the training batchs “RPM211 and RPM232” (batch categories from original adata object), but we have no protein expression in the test batchs “RPM215 and RPM218”. We’d like to take the cells of the trainig batch and make a counterfactual prediction: “What would the expression look like if my batch "RPM211 and RPM232" cells came from batch “RPM215 and RPM218”?”

# Imputed protein expression: 

In [30]:
true_protein_test = pd.DataFrame(held_out_proteins.X.toarray(), index = held_out_proteins.obs.index, columns = held_out_proteins.var.index)

  and should_run_async(code)


In [31]:
imputed_proteins_test = protein_means[adata_final.obs.patient.isin(test_patients)]

pat_names = adata_final.obs['patient'].isin(test_patients)
patients = adata_final.obs.patient[pat_names].values
# imputed_proteins_test = imputed_pros[overlap_protein] # Subset totalvi output to only include overlapping proteins

In [32]:
def corr2_coeff(A, B, pearson = True):
    if pearson:
        # Rowwise mean of input arrays & subtract from input arrays themeselves
        A_mA = A - A.mean(1)[:, None]
        B_mB = B - B.mean(1)[:, None]

        # Sum of squares across rows
        ssA = (A_mA**2).sum(1)
        ssB = (B_mB**2).sum(1)

        # Finally get corr coeff
        corr_mat = np.dot(A_mA, B_mB.T) / np.sqrt(np.dot(ssA[:, None],ssB[None]))
        
        return corr_mat[range(corr_mat.shape[0]), range(corr_mat.shape[0])]
    
    else:
        corrs = [0.] * A.shape[0]
        
        for i in range(A.shape[0]):
            corrs[i] = spearmanr(A[i], B[i])[0]
            
        return corrs

In [35]:
# Normalize totalvi output, and gold standard counts

true_protein_test = AnnData(true_protein_test)
features = imputed_proteins_test.columns
imputed_proteins_test = AnnData(imputed_proteins_test)

sc.pp.normalize_total(true_protein_test)
sc.pp.log1p(true_protein_test)

sc.pp.normalize_total(imputed_proteins_test)
sc.pp.log1p(imputed_proteins_test)

for patient in test_patients:
    indices = [x == patient for x in patients]
    sub_adata = imputed_proteins_test[indices]
    sc.pp.scale(sub_adata)
    imputed_proteins_test[indices] = sub_adata.X
    
    sub_adata = true_protein_test[indices]
    sc.pp.scale(sub_adata)
    true_protein_test[indices] = sub_adata.X
    
imputed_proteins_test.var.index = features
true_protein_test = pd.DataFrame(true_protein_test.X)
imputed_proteins_test = pd.DataFrame(imputed_proteins_test[:, overlap_protein].X, columns = overlap_protein)

  and should_run_async(code)
  view_to_actual(adata)


In [36]:
true_protein_test

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,51,52,53,54,55,56,57,58,59,60
0,-0.856749,-0.799429,-0.908386,1.066265,-0.088338,-0.875649,-0.827785,-0.549446,1.427857,-0.358665,...,-1.549895,0.504503,0.911077,-0.361428,0.835827,-0.924783,-0.389847,-1.155132,3.531978,1.457820
1,1.261656,-0.637372,0.479529,-0.916338,-1.062947,-0.432882,0.280367,-1.267646,2.101276,-0.667196,...,0.502855,0.362364,-1.450839,-1.119216,-0.501357,0.042378,0.694279,-0.338780,-0.403490,-0.673463
2,2.900027,-0.799429,-0.908386,0.774333,0.270856,0.468479,0.977750,1.274535,-0.904078,-0.334997,...,-1.167870,0.538626,-1.416779,-0.127589,-0.562933,0.828592,0.814127,0.156215,-0.459826,-1.458512
3,-0.879315,0.199079,-0.917188,-0.873633,-1.041756,-0.507675,0.814200,1.315016,-0.562525,-0.355751,...,-0.387379,0.270822,-1.318072,0.131795,-0.636566,1.034567,0.713609,0.400571,0.150326,-0.142480
4,-0.823877,2.448060,0.620528,-0.868077,1.045721,1.285337,-0.983535,-0.189732,-0.318040,0.107841,...,-0.322182,-1.271282,0.811585,0.019888,1.348030,-0.972814,-1.681144,1.226925,-0.504852,-0.470269
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
53195,0.160563,-0.759005,0.559321,0.164986,-0.383920,-0.674692,1.265615,-0.529768,-0.772042,-0.147527,...,-1.913172,1.004717,0.413846,0.104888,1.240839,1.128798,0.090976,-0.935341,-0.317434,1.098921
53196,0.326723,-0.759005,-0.890973,-1.151192,-1.514158,-0.160409,1.232731,-0.956962,0.105984,0.134345,...,-0.086974,0.693600,-0.532668,-1.188730,1.078710,0.763189,-0.212628,-0.331338,-0.227680,1.770346
53197,-1.096726,-0.315758,0.448617,-1.150636,-1.547090,-0.413327,-0.905030,-0.895673,1.035224,-0.412564,...,-0.251097,0.921159,0.230813,-0.020701,0.258502,-0.589080,-0.664109,-0.162576,1.690028,0.113766
53198,-1.129333,1.331117,0.920766,-0.417793,-0.933391,-0.481890,-0.821922,0.060114,-0.801555,-0.733166,...,-0.137300,-0.896412,-1.185782,-0.292678,0.652033,-1.203004,-0.867004,0.130169,-0.279391,-0.485385


In [37]:
sq = lambda x, y: (x - y)**2

  and should_run_async(code)


In [38]:
corrs_table = np.zeros((imputed_proteins_test.shape[1], len(np.unique(patients))))
sq_table = corrs_table.copy()

for i, patient in enumerate(np.unique(patients)):
    truth = true_protein_test[patients == patient].to_numpy()
    imputed = imputed_proteins_test[patients == patient].to_numpy()

    corrs_table[:, i] = corr2_coeff(truth.T, imputed.T)
    sq_table[:, i] = sq(truth, imputed).mean(axis = 0)

if np.isnan(corrs_table).sum() > 0:
    corrs_table[np.isnan(corrs_table)] = 0
    
corrs_table = pd.DataFrame(corrs_table)
sq_table = pd.DataFrame(sq_table)
corrs_table.index, corrs_table.columns = imputed_proteins_test.columns, np.unique(patients)
sq_table.index, sq_table.columns = imputed_proteins_test.columns, np.unique(patients)

In [39]:
corrs_table.mean(axis = 1)

index
CD80     0.104965
CD86     0.851459
CD274    0.026103
CD273    0.037045
CD70     0.043259
           ...   
CD28     0.906934
CD127    0.842853
CD71     0.563087
CD16     0.838444
CD161    0.750554
Length: 61, dtype: float64

In [40]:
corrs_table.mean(axis = 1)

index
CD80     0.104965
CD86     0.851459
CD274    0.026103
CD273    0.037045
CD70     0.043259
           ...   
CD28     0.906934
CD127    0.842853
CD71     0.563087
CD16     0.838444
CD161    0.750554
Length: 61, dtype: float64

In [41]:
#here are correlations

corrs_table.mean()

200_d0    0.552016
201_d0    0.508506
205_d0    0.518476
207_d0    0.503090
209_d0    0.528032
212_d0    0.518836
215_d0    0.526223
229_d0    0.513464
233_d0    0.512974
234_d0    0.548481
236_d0    0.501280
237_d0    0.519196
245_d0    0.505793
250_d0    0.527070
256_d0    0.552442
261_d0    0.521052
268_d0    0.531649
273_d0    0.501960
277_d0    0.537550
279_d0    0.500678
dtype: float64

In [42]:
#here are correlations

corrs_table.mean().mean()

0.521438385712481

In [43]:
corrs_table

Unnamed: 0_level_0,200_d0,201_d0,205_d0,207_d0,209_d0,212_d0,215_d0,229_d0,233_d0,234_d0,236_d0,237_d0,245_d0,250_d0,256_d0,261_d0,268_d0,273_d0,277_d0,279_d0
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
CD80,0.067830,0.102745,0.097995,0.038045,0.092374,0.072395,0.151426,0.095299,0.085692,0.122360,0.133657,0.137003,0.056307,0.181671,0.128072,0.091860,0.128460,0.088199,0.130436,0.097472
CD86,0.846862,0.802574,0.835720,0.872993,0.886083,0.862786,0.837629,0.826333,0.856093,0.882573,0.835255,0.835984,0.835322,0.849250,0.900259,0.855067,0.873499,0.800275,0.903890,0.830742
CD274,0.063382,0.007885,0.061903,0.002642,0.093262,-0.049200,0.041923,-0.023651,-0.004834,0.084252,0.010749,0.033781,0.028412,-0.053071,0.126692,-0.024628,0.003610,0.028629,-0.008858,0.099171
CD273,0.023464,0.008926,0.055787,0.049147,0.003316,0.012657,0.040714,0.035787,0.046929,0.037032,0.076803,0.033634,0.041652,0.019952,0.090911,0.015314,0.056135,0.082130,0.047145,-0.036534
CD70,0.063288,0.079553,0.070544,-0.003878,0.044235,-0.013428,0.022878,0.002292,0.024651,0.064813,0.020768,0.072505,0.069121,0.066704,0.103654,0.027869,0.045426,0.041417,0.058445,0.004323
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
CD28,0.904603,0.912193,0.895245,0.909690,0.898745,0.925953,0.884411,0.920440,0.907860,0.907743,0.906741,0.903991,0.910266,0.911136,0.926030,0.898768,0.887611,0.908468,0.894643,0.924151
CD127,0.847347,0.838587,0.843632,0.855256,0.838631,0.838462,0.815328,0.821660,0.858348,0.843002,0.842489,0.848810,0.848484,0.835072,0.877162,0.851826,0.838979,0.828323,0.822266,0.863393
CD71,0.598493,0.481297,0.493003,0.514387,0.568792,0.625164,0.592932,0.612026,0.544021,0.561750,0.493898,0.660973,0.500061,0.687291,0.562354,0.548431,0.494674,0.656822,0.566133,0.499233
CD16,0.806394,0.857252,0.825018,0.879405,0.787396,0.933000,0.891169,0.885719,0.920662,0.848553,0.845450,0.866403,0.760492,0.830763,0.835204,0.860066,0.920356,0.861716,0.646162,0.707693


In [44]:
sq_table.mean()

  and should_run_async(code)


200_d0    0.895662
201_d0    0.982525
205_d0    0.962848
207_d0    0.993472
209_d0    0.943621
212_d0    0.962040
215_d0    0.947200
229_d0    0.972533
233_d0    0.973548
234_d0    0.902608
236_d0    0.996718
237_d0    0.961256
245_d0    0.988019
250_d0    0.945409
256_d0    0.894740
261_d0    0.957472
268_d0    0.936443
273_d0    0.995781
277_d0    0.924571
279_d0    0.998247
dtype: float64

In [45]:
sq_table.mean().mean()

0.9567355702585372

In [46]:
corrs_table

Unnamed: 0_level_0,200_d0,201_d0,205_d0,207_d0,209_d0,212_d0,215_d0,229_d0,233_d0,234_d0,236_d0,237_d0,245_d0,250_d0,256_d0,261_d0,268_d0,273_d0,277_d0,279_d0
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
CD80,0.067830,0.102745,0.097995,0.038045,0.092374,0.072395,0.151426,0.095299,0.085692,0.122360,0.133657,0.137003,0.056307,0.181671,0.128072,0.091860,0.128460,0.088199,0.130436,0.097472
CD86,0.846862,0.802574,0.835720,0.872993,0.886083,0.862786,0.837629,0.826333,0.856093,0.882573,0.835255,0.835984,0.835322,0.849250,0.900259,0.855067,0.873499,0.800275,0.903890,0.830742
CD274,0.063382,0.007885,0.061903,0.002642,0.093262,-0.049200,0.041923,-0.023651,-0.004834,0.084252,0.010749,0.033781,0.028412,-0.053071,0.126692,-0.024628,0.003610,0.028629,-0.008858,0.099171
CD273,0.023464,0.008926,0.055787,0.049147,0.003316,0.012657,0.040714,0.035787,0.046929,0.037032,0.076803,0.033634,0.041652,0.019952,0.090911,0.015314,0.056135,0.082130,0.047145,-0.036534
CD70,0.063288,0.079553,0.070544,-0.003878,0.044235,-0.013428,0.022878,0.002292,0.024651,0.064813,0.020768,0.072505,0.069121,0.066704,0.103654,0.027869,0.045426,0.041417,0.058445,0.004323
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
CD28,0.904603,0.912193,0.895245,0.909690,0.898745,0.925953,0.884411,0.920440,0.907860,0.907743,0.906741,0.903991,0.910266,0.911136,0.926030,0.898768,0.887611,0.908468,0.894643,0.924151
CD127,0.847347,0.838587,0.843632,0.855256,0.838631,0.838462,0.815328,0.821660,0.858348,0.843002,0.842489,0.848810,0.848484,0.835072,0.877162,0.851826,0.838979,0.828323,0.822266,0.863393
CD71,0.598493,0.481297,0.493003,0.514387,0.568792,0.625164,0.592932,0.612026,0.544021,0.561750,0.493898,0.660973,0.500061,0.687291,0.562354,0.548431,0.494674,0.656822,0.566133,0.499233
CD16,0.806394,0.857252,0.825018,0.879405,0.787396,0.933000,0.891169,0.885719,0.920662,0.848553,0.845450,0.866403,0.760492,0.830763,0.835204,0.860066,0.920356,0.861716,0.646162,0.707693


In [47]:
corrs_table.to_csv('corrs_results/totalvi_pbmctoh1n1.csv')
sq_table.to_csv('mse_results/totalvi_pbmctoh1n1.csv')

  and should_run_async(code)


# Interval Coverage: 1000 Samples - Normalized, scaled, and log-transformed

In [48]:
# normalized, scaled by patient, and log-transformed gold standard counts
true_protein_test.columns = overlap_protein
true_protein_test

index,CD80,CD86,CD274,CD273,CD70,CD40,CD3-1,CD4-1,CD8,CD19,...,CD184,CD2,CD303,IgD,CD18,CD28,CD127,CD71,CD16,CD161
0,-0.856749,-0.799429,-0.908386,1.066265,-0.088338,-0.875649,-0.827785,-0.549446,1.427857,-0.358665,...,-1.549895,0.504503,0.911077,-0.361428,0.835827,-0.924783,-0.389847,-1.155132,3.531978,1.457820
1,1.261656,-0.637372,0.479529,-0.916338,-1.062947,-0.432882,0.280367,-1.267646,2.101276,-0.667196,...,0.502855,0.362364,-1.450839,-1.119216,-0.501357,0.042378,0.694279,-0.338780,-0.403490,-0.673463
2,2.900027,-0.799429,-0.908386,0.774333,0.270856,0.468479,0.977750,1.274535,-0.904078,-0.334997,...,-1.167870,0.538626,-1.416779,-0.127589,-0.562933,0.828592,0.814127,0.156215,-0.459826,-1.458512
3,-0.879315,0.199079,-0.917188,-0.873633,-1.041756,-0.507675,0.814200,1.315016,-0.562525,-0.355751,...,-0.387379,0.270822,-1.318072,0.131795,-0.636566,1.034567,0.713609,0.400571,0.150326,-0.142480
4,-0.823877,2.448060,0.620528,-0.868077,1.045721,1.285337,-0.983535,-0.189732,-0.318040,0.107841,...,-0.322182,-1.271282,0.811585,0.019888,1.348030,-0.972814,-1.681144,1.226925,-0.504852,-0.470269
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
53195,0.160563,-0.759005,0.559321,0.164986,-0.383920,-0.674692,1.265615,-0.529768,-0.772042,-0.147527,...,-1.913172,1.004717,0.413846,0.104888,1.240839,1.128798,0.090976,-0.935341,-0.317434,1.098921
53196,0.326723,-0.759005,-0.890973,-1.151192,-1.514158,-0.160409,1.232731,-0.956962,0.105984,0.134345,...,-0.086974,0.693600,-0.532668,-1.188730,1.078710,0.763189,-0.212628,-0.331338,-0.227680,1.770346
53197,-1.096726,-0.315758,0.448617,-1.150636,-1.547090,-0.413327,-0.905030,-0.895673,1.035224,-0.412564,...,-0.251097,0.921159,0.230813,-0.020701,0.258502,-0.589080,-0.664109,-0.162576,1.690028,0.113766
53198,-1.129333,1.331117,0.920766,-0.417793,-0.933391,-0.481890,-0.821922,0.060114,-0.801555,-0.733166,...,-0.137300,-0.896412,-1.185782,-0.292678,0.652033,-1.203004,-0.867004,0.130169,-0.279391,-0.485385


In [49]:
# Create empty vectors to store percentiles
n, p = true_protein_test.shape
s = 1000

q10 = np.empty(shape=(n,p))
q90 = np.empty(shape=(n,p))
q25 = np.empty(shape=(n,p))
q75 = np.empty(shape=(n,p))


pat_names = adata_final.obs['patient'].isin(test_patients)
patients = adata_final.obs.patient[pat_names].values

  and should_run_async(code)


In [50]:
class generator:
    def __init__(self, test_idx, batch_size, seed = 123):
        self.test_idx = test_idx
        self.batch_size = batch_size
        self.seed = seed
        
    def __iter__(self):
        np.random.seed(self.seed)
        idx = np.random.choice(range(len(self.test_idx)), len(self.test_idx), False)
        
        batch, test_batch = [], []
        
        for index in idx:
            batch.append(index)
            test_batch.append(self.test_idx[index])
            
            if len(batch) == self.batch_size:
                yield batch, test_batch
                batch, test_batch = [], []
                
        if batch:
            yield batch, test_batch
    
    def get(self):
        return self

In [51]:
base_path = 'totalVI_quantiles_h1n1'

if not os.path.isdir(base_path):
    os.mkdir(base_path)

In [52]:
if all([os.path.isfile(os.path.join(base_path, path)) for path in ('q10.csv', 'q25.csv', 'q75.csv', 'q90.csv')]):
    q10 = pd.read_csv(os.path.join(base_path, 'q10.csv'), index_col = 0)
    q25 = pd.read_csv(os.path.join(base_path, 'q25.csv'), index_col = 0)
    q75 = pd.read_csv(os.path.join(base_path, 'q75.csv'), index_col = 0)
    q90 = pd.read_csv(os.path.join(base_path, 'q90.csv'), index_col = 0)
    
else:    
    for idx, test_idx in generator(test_index, 5000).get(): 
        start_time = time.time()

        ######################################################
        # (1) sequentially sample 5000 from the test index #
        #####################################################

        _, protein_means_samples = totalvae.get_normalized_expression(
            transform_batch=train_patients,
            n_samples=1000,
            include_protein_background=True,
            sample_protein_mixing=False,
            return_mean = False,
            indices = test_idx,
        )

        # Subset to overlapping proteins for evaluations

        # index of overlaping_proteins
        name_map = {protein: i for i, protein in enumerate(imputed_proteins_test.columns)}
        index_overlap = [name_map[protein] for protein in overlap_protein]


        ###################################
        # (2) normalize and scale samples #
        ###################################

        # normalize totalvi sample output
        protein_means_samples = protein_means_samples[:, index_overlap]
        sf = protein_means_samples.sum(axis = 1)
        sf = np.median(sf, axis = 0)[None, :]/sf

        protein_means_samples_norm = protein_means_samples * sf[:, None, :]
        protein_means_samples_norm = np.log(protein_means_samples_norm + 1)
        protein_means_samples_norm = protein_means_samples_norm[:, index_overlap]

        
        patient_indices = patients[idx]
        for patient in test_patients:
            indices = [x == patient for x in patient_indices]
            sub_data = protein_means_samples_norm[indices]
            mean, sd = sub_data.mean(axis = 0), sub_data.std(axis = 0)
            sub_data = (sub_data - mean)/sd
            protein_means_samples_norm[indices] = sub_data

        #########################################################
        # (3) calculate percentiles 0.10, 0.90, 0.25, and 0.75 #
        ########################################################

        # loop for each protein
        q10[idx] = np.percentile(protein_means_samples_norm, 10, axis = 2)
        q90[idx] = np.percentile(protein_means_samples_norm, 90, axis = 2)
        q25[idx] = np.percentile(protein_means_samples_norm, 25, axis = 2)
        q75[idx] = np.percentile(protein_means_samples_norm, 75, axis = 2)


        ##################################
        # (4) remove samples from memory #
        ###################################

        del protein_means_samples
        del protein_means_samples_norm

        end_time = time.time()

        print("Run time for loop: --- %s seconds ---" % (end_time - start_time)) #

    cols = imputed_proteins_test[overlap_protein].columns
    names = imputed_proteins_test[overlap_protein].index
    q10 = pd.DataFrame(q10, columns=cols, index = adata_final.obs.index[n_train:])
    q90 = pd.DataFrame(q90, columns=cols, index = adata_final.obs.index[n_train:])
    q25 = pd.DataFrame(q25, columns=cols, index = adata_final.obs.index[n_train:])
    q75 = pd.DataFrame(q75, columns=cols, index = adata_final.obs.index[n_train:])

    q10.to_csv(os.path.join(base_path, 'q10.csv'))
    q25.to_csv(os.path.join(base_path, 'q25.csv'))
    q75.to_csv(os.path.join(base_path, 'q75.csv'))
    q90.to_csv(os.path.join(base_path, 'q90.csv'))

Run time for loop: --- 324.7294840812683 seconds ---
Run time for loop: --- 326.95079922676086 seconds ---
Run time for loop: --- 335.30950570106506 seconds ---
Run time for loop: --- 338.5204451084137 seconds ---
Run time for loop: --- 336.64714765548706 seconds ---
Run time for loop: --- 331.2415916919708 seconds ---
Run time for loop: --- 335.9373321533203 seconds ---
Run time for loop: --- 335.1538460254669 seconds ---
Run time for loop: --- 335.31837701797485 seconds ---
Run time for loop: --- 329.8086862564087 seconds ---
Run time for loop: --- 223.7840359210968 seconds ---


In [54]:
true_protein_test.index = held_out_proteins.obs.index

In [55]:
true_protein_test.to_csv(os.path.join(base_path, "truth.csv"))

In [56]:
r50 = (true_protein_test < q75)
l50 = (true_protein_test > q25)

print(f"Effective Coverage Probability for Nominal 50% PIs: {(r50*l50).mean()}")
print(f"Mean effective Coverage Probability for Nominal 50% PI: {(r50*l50).mean().mean()}")

Effective Coverage Probability for Nominal 50% PIs: index
CD80     0.093947
CD86     0.106729
CD274    0.100244
CD273    0.071748
CD70     0.123590
           ...   
CD28     0.013816
CD127    0.039041
CD71     0.070752
CD16     0.054586
CD161    0.054474
Length: 61, dtype: float64
Mean effective Coverage Probability for Nominal 50% PI: 0.05335726611610995


  f"evaluating in Python space because the {repr(op_str)} "


In [57]:
lengths = np.subtract(q75,q25)
print(f"Mean 50% interval lengths: {lengths.mean()}")
print(f"Overall mean 50% interval length: {lengths.mean().mean()}")

Mean 50% interval lengths: index
CD80     0.187028
CD86     0.073804
CD274    0.193295
CD273    0.206771
CD70     0.412705
           ...   
CD28     0.164624
CD127    0.235965
CD71     0.205954
CD16     0.117590
CD161    0.166084
Length: 61, dtype: float64
Overall mean 50% interval length: 0.17023453237704683


  and should_run_async(code)


In [58]:
r80 = (true_protein_test < q90)
l80 = (true_protein_test > q10)

print(f"Effective Coverage Probability for Nominal 80% PIs: {(r80*l80).mean()}")
print(f"Mean effective Coverage Probability for Nominal 80% PI: {(r80*l80).mean().mean()}")

Effective Coverage Probability for Nominal 80% PIs: index
CD80     0.168553
CD86     0.190263
CD274    0.183853
CD273    0.134436
CD70     0.221936
           ...   
CD28     0.026391
CD127    0.072594
CD71     0.130752
CD16     0.102350
CD161    0.106898
Length: 61, dtype: float64
Mean effective Coverage Probability for Nominal 80% PI: 0.10188524590163936


  f"evaluating in Python space because the {repr(op_str)} "


In [59]:
lengths = np.subtract(q90,q10)
print(f"Mean 80% interval lengths: {lengths.mean()}")
print(f"Overall mean 80% interval length: {lengths.mean().mean()}")

  and should_run_async(code)


Mean 80% interval lengths: index
CD80     0.356352
CD86     0.140541
CD274    0.367524
CD273    0.393052
CD70     0.784078
           ...   
CD28     0.314253
CD127    0.449074
CD71     0.391809
CD16     0.223784
CD161    0.316557
Length: 61, dtype: float64
Overall mean 80% interval length: 0.3245049394860531


## Double checking intervals using the protein mean estimates from totalVI

### Normalized:

In [60]:
imputed_proteins_test.index = held_out_proteins.obs.index

In [61]:
imputed_proteins_test.columns = overlap_protein
r50 = (imputed_proteins_test < q75)
l50 = (imputed_proteins_test > q25)

print((r50*l50).mean())
print(f"Effective Coverage Probability for Nominal 50% PI: {(r50*l50).mean().mean()}")

index
CD80     0.058045
CD86     0.124417
CD274    0.108139
CD273    0.070056
CD70     0.221034
           ...   
CD28     0.012350
CD127    0.040526
CD71     0.162237
CD16     0.054211
CD161    0.070132
Length: 61, dtype: float64
Effective Coverage Probability for Nominal 50% PI: 0.08194071243682978


  f"evaluating in Python space because the {repr(op_str)} "


In [62]:
r80 = (imputed_proteins_test < q90)
l80 = (imputed_proteins_test > q10)

print((r80*l80).mean())
print(f"Effective Coverage Probability for Nominal 80% PI: {(r80*l80).mean().mean()}")

index
CD80     0.115883
CD86     0.240902
CD274    0.206823
CD273    0.138816
CD70     0.393628
           ...   
CD28     0.023289
CD127    0.074455
CD71     0.307613
CD16     0.119624
CD161    0.142293
Length: 61, dtype: float64
Effective Coverage Probability for Nominal 80% PI: 0.15578854924195737


  and should_run_async(code)
  f"evaluating in Python space because the {repr(op_str)} "
