In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot
import os
from copy import deepcopy

from time import time
from time import perf_counter

from math import ceil
from scipy.stats import spearmanr, gamma, poisson

from anndata import AnnData, read_h5ad
import scanpy as sc
from scanpy import read
import pandas as pd

from torch.utils.data import DataLoader, TensorDataset
from torch import tensor
from torch.cuda import is_available

from sciPENN.sciPENN_API import sciPENN_API

from os.path import join
import shutil

In [2]:
"""Read in Raw Data"""

# Training reference: paired gene-expression and protein AnnData objects.
adata_gene = sc.read("../Data/pbmc/pbmc_gene.h5ad")
adata_protein = sc.read("../Data/pbmc/pbmc_protein.h5ad")

# Test set: a Matrix Market matrix, transposed so genes end up on the `var`
# axis (the gene names assigned below label that axis).
adata_gene_test = sc.read("../Data/H1N1/gene_data.mtx").transpose()
# Per-cell metadata; the first CSV column is the cell identifier (pandas' default separator is ',').
adata_gene_test.obs = pd.read_csv("../Data/H1N1/meta_data.txt", index_col = 0)
# Gene names are the first non-index column of gene_names.txt.
adata_gene_test.var.index = pd.read_csv("../Data/H1N1/gene_names.txt", index_col = 0).iloc[:, 0]

In [3]:
# Fixed random subsamples of training and test cells for each training fraction.
# They are cached as CSVs under `runtime/` so that every benchmark run (and every
# method compared against these files) uses exactly the same cells.
base = "runtime"
train_path, test_path = join(base, 'train'), join(base, 'test')
fracs = [0.1, 0.2, 0.4, 0.5, 0.8, 1.0]


def _idx_file(frac):
    """File name holding the sampled indices for `frac`, e.g. 0.1 -> 'idx0.1.csv'."""
    return f"idx{frac}.csv"


def _read_idx(directory, frac):
    """Load the cached list of integer cell positions for `frac` from `directory`."""
    return pd.read_csv(join(directory, _idx_file(frac)))['idx'].tolist()


# Only trust the cache when every expected file exists. A partially written cache
# (e.g. an earlier run interrupted right after creating the directories) would
# otherwise yield empty `indices`, and the benchmark below would silently do nothing.
cache_complete = all(
    os.path.isfile(join(directory, _idx_file(frac)))
    for directory in (train_path, test_path)
    for frac in fracs
)

if cache_complete:
    indices = {frac: _read_idx(train_path, frac) for frac in fracs}
    indices_test = {frac: _read_idx(test_path, frac) for frac in fracs}
else:
    os.makedirs(train_path, exist_ok = True)
    os.makedirs(test_path, exist_ok = True)

    indices, indices_test = {}, {}

    # Seed once and draw train/test samples in this exact interleaved order so the
    # generated indices are reproducible across regenerations.
    np.random.seed(342)
    n, n_test = len(adata_gene), len(adata_gene_test)

    for frac in fracs:
        indices[frac] = np.random.choice(range(n), round(frac * n), False).tolist()
        pd.DataFrame(indices[frac], columns = ['idx']).to_csv(join(train_path, _idx_file(frac)))

        indices_test[frac] = np.random.choice(range(n_test), round(frac * n_test), False).tolist()
        pd.DataFrame(indices_test[frac], columns = ['idx']).to_csv(join(test_path, _idx_file(frac)))

In [4]:
# Seconds spent on one complete sciPENN run (construction/preprocessing, training
# and prediction), keyed by the fraction of cells used.
times = {}
# Scratch directory passed to train() as `weights_dir`; it is removed after the benchmark.
weights_dir = "tmp"

try:
    for frac in sorted(indices):
        idx, idx_test = indices[frac], indices_test[frac]

        # perf_counter is monotonic and high-resolution, unlike the wall clock time().
        start = perf_counter()

        sciPENN = sciPENN_API([adata_gene[idx]], [adata_protein[idx]], adata_gene_test[idx_test],
                              train_batchkeys = ['donor'], test_batchkey = 'sample')

        sciPENN.train(n_epochs = 10000, ES_max = 12, decay_max = 6,
                      decay_step = 0.1, lr = 10**(-3), weights_dir = weights_dir, load = False)

        # Prediction is part of the benchmarked workload, so it stays inside the timed region.
        imputed_test = sciPENN.predict()

        times[frac] = perf_counter() - start
finally:
    # Clean up even if a run fails. The directory is absent when no run ever
    # reached train() (e.g. `indices` is empty), so check before removing.
    if os.path.isdir(weights_dir):
        shutil.rmtree(weights_dir)

Trying to set attribute `.obs` of view, copying.


Searching for GPU
GPU detected, using GPU


Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.



QC Filtering Training Cells
QC Filtering Testing Cells

QC Filtering Training Genes
QC Filtering Testing Genes

Normalizing Training Cells
Normalizing Testing Cells

Log-Normalizing Training Data
Log-Normalizing Testing Data

Finding HVGs


... storing 'orig.ident' as categorical
... storing 'lane' as categorical
... storing 'donor' as categorical
... storing 'time' as categorical
... storing 'celltype.l1' as categorical
... storing 'celltype.l2' as categorical
... storing 'celltype.l3' as categorical
... storing 'Phase' as categorical
... storing 'batch' as categorical
... storing 'Dataset' as categorical
... storing 'barcode_check' as categorical
... storing 'tenx_lane' as categorical
... storing 'cohort' as categorical
... storing 'hash_maxID' as categorical
... storing 'hash_secondID' as categorical
... storing 'hto_classification' as categorical
... storing 'hto_classification_global' as categorical
... storing 'hash_ID' as categorical
... storing 'adjmfc.time' as categorical
... storing 'DMX_GLOBAL_BEST' as categorical
... storing 'DEMUXLET.BARCODE' as categorical
... storing 'sample' as categorical
... storing 'joint_classification_global' as categorical
... storing 'timepoint' as categorical
... storing 'K0' as ca


Normalizing Gene Training Data by Batch


100%|██████████| 8/8 [00:00<00:00, 12.03it/s]



Normalizing Protein Training Data by Batch


100%|██████████| 8/8 [00:00<00:00, 17.45it/s]



Normalizing Gene Testing Data by Batch


100%|██████████| 20/20 [00:00<00:00, 83.85it/s]


Epoch 0 prediction loss = 1.397
Epoch 1 prediction loss = 0.926
Epoch 2 prediction loss = 0.915
Epoch 3 prediction loss = 0.905
Epoch 4 prediction loss = 0.905
Epoch 5 prediction loss = 0.900
Epoch 6 prediction loss = 0.899
Epoch 7 prediction loss = 0.896
Epoch 8 prediction loss = 0.895
Epoch 9 prediction loss = 0.891
Epoch 10 prediction loss = 0.891
Epoch 11 prediction loss = 0.893
Epoch 12 prediction loss = 0.891
Epoch 13 prediction loss = 0.893
Decaying loss to 0.0001
Epoch 14 prediction loss = 0.878
Epoch 15 prediction loss = 0.878
Epoch 16 prediction loss = 0.877
Epoch 17 prediction loss = 0.877
Epoch 18 prediction loss = 0.878
Epoch 19 prediction loss = 0.877
Decaying loss to 1e-05
Epoch 20 prediction loss = 0.876
Epoch 21 prediction loss = 0.876
Epoch 22 prediction loss = 0.876
Epoch 23 prediction loss = 0.876
Epoch 24 prediction loss = 0.876
Epoch 25 prediction loss = 0.876
Decaying loss to 1.0000000000000002e-06
Epoch 26 prediction loss = 0.876


Trying to set attribute `.obs` of view, copying.


Searching for GPU
GPU detected, using GPU


Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.



QC Filtering Training Cells
QC Filtering Testing Cells

QC Filtering Training Genes
QC Filtering Testing Genes

Normalizing Training Cells
Normalizing Testing Cells

Log-Normalizing Training Data
Log-Normalizing Testing Data

Finding HVGs


... storing 'orig.ident' as categorical
... storing 'lane' as categorical
... storing 'donor' as categorical
... storing 'time' as categorical
... storing 'celltype.l1' as categorical
... storing 'celltype.l2' as categorical
... storing 'celltype.l3' as categorical
... storing 'Phase' as categorical
... storing 'batch' as categorical
... storing 'Dataset' as categorical
... storing 'barcode_check' as categorical
... storing 'tenx_lane' as categorical
... storing 'cohort' as categorical
... storing 'hash_maxID' as categorical
... storing 'hash_secondID' as categorical
... storing 'hto_classification' as categorical
... storing 'hto_classification_global' as categorical
... storing 'hash_ID' as categorical
... storing 'adjmfc.time' as categorical
... storing 'DMX_GLOBAL_BEST' as categorical
... storing 'DEMUXLET.BARCODE' as categorical
... storing 'sample' as categorical
... storing 'joint_classification_global' as categorical
... storing 'timepoint' as categorical
... storing 'K0' as ca


Normalizing Gene Training Data by Batch


100%|██████████| 8/8 [00:01<00:00,  5.26it/s]



Normalizing Protein Training Data by Batch


100%|██████████| 8/8 [00:00<00:00, 11.82it/s]



Normalizing Gene Testing Data by Batch


100%|██████████| 20/20 [00:00<00:00, 74.82it/s]


Epoch 0 prediction loss = 1.408
Epoch 1 prediction loss = 0.918
Epoch 2 prediction loss = 0.905
Epoch 3 prediction loss = 0.897
Epoch 4 prediction loss = 0.894
Epoch 5 prediction loss = 0.891
Epoch 6 prediction loss = 0.888
Epoch 7 prediction loss = 0.887
Epoch 8 prediction loss = 0.884
Epoch 9 prediction loss = 0.884
Epoch 10 prediction loss = 0.882
Epoch 11 prediction loss = 0.880
Epoch 12 prediction loss = 0.884
Epoch 13 prediction loss = 0.880
Decaying loss to 0.0001
Epoch 14 prediction loss = 0.867
Epoch 15 prediction loss = 0.867
Epoch 16 prediction loss = 0.866
Epoch 17 prediction loss = 0.866
Epoch 18 prediction loss = 0.866
Epoch 19 prediction loss = 0.866
Decaying loss to 1e-05
Epoch 20 prediction loss = 0.865
Epoch 21 prediction loss = 0.865
Epoch 22 prediction loss = 0.864
Epoch 23 prediction loss = 0.865
Epoch 24 prediction loss = 0.865
Epoch 25 prediction loss = 0.865
Decaying loss to 1.0000000000000002e-06
Epoch 26 prediction loss = 0.865
Searching for GPU
GPU detected, 

Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.



QC Filtering Training Cells
QC Filtering Testing Cells

QC Filtering Training Genes
QC Filtering Testing Genes

Normalizing Training Cells
Normalizing Testing Cells

Log-Normalizing Training Data
Log-Normalizing Testing Data

Finding HVGs


... storing 'orig.ident' as categorical
... storing 'lane' as categorical
... storing 'donor' as categorical
... storing 'time' as categorical
... storing 'celltype.l1' as categorical
... storing 'celltype.l2' as categorical
... storing 'celltype.l3' as categorical
... storing 'Phase' as categorical
... storing 'batch' as categorical
... storing 'Dataset' as categorical
... storing 'barcode_check' as categorical
... storing 'tenx_lane' as categorical
... storing 'cohort' as categorical
... storing 'hash_maxID' as categorical
... storing 'hash_secondID' as categorical
... storing 'hto_classification' as categorical
... storing 'hto_classification_global' as categorical
... storing 'hash_ID' as categorical
... storing 'adjmfc.time' as categorical
... storing 'DMX_GLOBAL_BEST' as categorical
... storing 'DEMUXLET.BARCODE' as categorical
... storing 'sample' as categorical
... storing 'joint_classification_global' as categorical
... storing 'timepoint' as categorical
... storing 'K0' as ca


Normalizing Gene Training Data by Batch


100%|██████████| 8/8 [00:02<00:00,  3.41it/s]



Normalizing Protein Training Data by Batch


100%|██████████| 8/8 [00:01<00:00,  7.41it/s]



Normalizing Gene Testing Data by Batch


100%|██████████| 20/20 [00:00<00:00, 37.49it/s]


Epoch 0 prediction loss = 1.388
Epoch 1 prediction loss = 0.895
Epoch 2 prediction loss = 0.886
Epoch 3 prediction loss = 0.880
Epoch 4 prediction loss = 0.876
Epoch 5 prediction loss = 0.872
Epoch 6 prediction loss = 0.870
Epoch 7 prediction loss = 0.870
Epoch 8 prediction loss = 0.866
Epoch 9 prediction loss = 0.866
Epoch 10 prediction loss = 0.868
Epoch 11 prediction loss = 0.865
Epoch 12 prediction loss = 0.862
Epoch 13 prediction loss = 0.864
Decaying loss to 0.0001
Epoch 14 prediction loss = 0.852
Epoch 15 prediction loss = 0.852
Epoch 16 prediction loss = 0.851
Epoch 17 prediction loss = 0.850
Epoch 18 prediction loss = 0.850
Epoch 19 prediction loss = 0.850
Decaying loss to 1e-05
Epoch 20 prediction loss = 0.850
Epoch 21 prediction loss = 0.849
Epoch 22 prediction loss = 0.848
Epoch 23 prediction loss = 0.849
Epoch 24 prediction loss = 0.849
Epoch 25 prediction loss = 0.849
Decaying loss to 1.0000000000000002e-06
Epoch 26 prediction loss = 0.849


Trying to set attribute `.obs` of view, copying.


Searching for GPU
GPU detected, using GPU


Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.



QC Filtering Training Cells
QC Filtering Testing Cells

QC Filtering Training Genes
QC Filtering Testing Genes

Normalizing Training Cells
Normalizing Testing Cells

Log-Normalizing Training Data
Log-Normalizing Testing Data

Finding HVGs


... storing 'orig.ident' as categorical
... storing 'lane' as categorical
... storing 'donor' as categorical
... storing 'time' as categorical
... storing 'celltype.l1' as categorical
... storing 'celltype.l2' as categorical
... storing 'celltype.l3' as categorical
... storing 'Phase' as categorical
... storing 'batch' as categorical
... storing 'Dataset' as categorical
... storing 'barcode_check' as categorical
... storing 'tenx_lane' as categorical
... storing 'cohort' as categorical
... storing 'hash_maxID' as categorical
... storing 'hash_secondID' as categorical
... storing 'hto_classification' as categorical
... storing 'hto_classification_global' as categorical
... storing 'hash_ID' as categorical
... storing 'adjmfc.time' as categorical
... storing 'DMX_GLOBAL_BEST' as categorical
... storing 'DEMUXLET.BARCODE' as categorical
... storing 'sample' as categorical
... storing 'joint_classification_global' as categorical
... storing 'timepoint' as categorical
... storing 'K0' as ca


Normalizing Gene Training Data by Batch


100%|██████████| 8/8 [00:02<00:00,  2.77it/s]



Normalizing Protein Training Data by Batch


100%|██████████| 8/8 [00:01<00:00,  5.84it/s]



Normalizing Gene Testing Data by Batch


100%|██████████| 20/20 [00:00<00:00, 30.81it/s]


Epoch 0 prediction loss = 1.390
Epoch 1 prediction loss = 0.897
Epoch 2 prediction loss = 0.885
Epoch 3 prediction loss = 0.880
Epoch 4 prediction loss = 0.875
Epoch 5 prediction loss = 0.873
Epoch 6 prediction loss = 0.871
Epoch 7 prediction loss = 0.869
Epoch 8 prediction loss = 0.868
Epoch 9 prediction loss = 0.866
Epoch 10 prediction loss = 0.867
Epoch 11 prediction loss = 0.866
Epoch 12 prediction loss = 0.866
Decaying loss to 0.0001
Epoch 13 prediction loss = 0.855
Epoch 14 prediction loss = 0.855
Epoch 15 prediction loss = 0.854
Epoch 16 prediction loss = 0.854
Epoch 17 prediction loss = 0.854
Epoch 18 prediction loss = 0.854
Decaying loss to 1e-05
Epoch 19 prediction loss = 0.853
Epoch 20 prediction loss = 0.853
Epoch 21 prediction loss = 0.853
Epoch 22 prediction loss = 0.853
Epoch 23 prediction loss = 0.853
Epoch 24 prediction loss = 0.853
Decaying loss to 1.0000000000000002e-06
Epoch 25 prediction loss = 0.853
Searching for GPU
GPU detected, using GPU


Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.



QC Filtering Training Cells
QC Filtering Testing Cells

QC Filtering Training Genes
QC Filtering Testing Genes

Normalizing Training Cells
Normalizing Testing Cells

Log-Normalizing Training Data
Log-Normalizing Testing Data

Finding HVGs


... storing 'orig.ident' as categorical
... storing 'lane' as categorical
... storing 'donor' as categorical
... storing 'time' as categorical
... storing 'celltype.l1' as categorical
... storing 'celltype.l2' as categorical
... storing 'celltype.l3' as categorical
... storing 'Phase' as categorical
... storing 'batch' as categorical
... storing 'Dataset' as categorical
... storing 'barcode_check' as categorical
... storing 'tenx_lane' as categorical
... storing 'cohort' as categorical
... storing 'hash_maxID' as categorical
... storing 'hash_secondID' as categorical
... storing 'hto_classification' as categorical
... storing 'hto_classification_global' as categorical
... storing 'hash_ID' as categorical
... storing 'adjmfc.time' as categorical
... storing 'DMX_GLOBAL_BEST' as categorical
... storing 'DEMUXLET.BARCODE' as categorical
... storing 'sample' as categorical
... storing 'joint_classification_global' as categorical
... storing 'timepoint' as categorical
... storing 'K0' as ca


Normalizing Gene Training Data by Batch


100%|██████████| 8/8 [00:04<00:00,  1.88it/s]



Normalizing Protein Training Data by Batch


100%|██████████| 8/8 [00:02<00:00,  3.99it/s]



Normalizing Gene Testing Data by Batch


100%|██████████| 20/20 [00:00<00:00, 29.61it/s]


Epoch 0 prediction loss = 1.396
Epoch 1 prediction loss = 0.891
Epoch 2 prediction loss = 0.881
Epoch 3 prediction loss = 0.877
Epoch 4 prediction loss = 0.875
Epoch 5 prediction loss = 0.871
Epoch 6 prediction loss = 0.871
Epoch 7 prediction loss = 0.870
Epoch 8 prediction loss = 0.867
Epoch 9 prediction loss = 0.869
Epoch 10 prediction loss = 0.870
Epoch 11 prediction loss = 0.866
Epoch 12 prediction loss = 0.866
Epoch 13 prediction loss = 0.867
Epoch 14 prediction loss = 0.867
Epoch 15 prediction loss = 0.867
Epoch 16 prediction loss = 0.865
Decaying loss to 0.0001
Epoch 17 prediction loss = 0.857
Epoch 18 prediction loss = 0.856
Epoch 19 prediction loss = 0.856
Epoch 20 prediction loss = 0.856
Epoch 21 prediction loss = 0.855
Epoch 22 prediction loss = 0.855
Decaying loss to 1e-05
Epoch 23 prediction loss = 0.854
Epoch 24 prediction loss = 0.855
Epoch 25 prediction loss = 0.855
Epoch 26 prediction loss = 0.854
Epoch 27 prediction loss = 0.855
Epoch 28 prediction loss = 0.854
Decayi

Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.
Trying to set attribute `.obs` of view, copying.



QC Filtering Training Cells
QC Filtering Testing Cells

QC Filtering Training Genes
QC Filtering Testing Genes

Normalizing Training Cells
Normalizing Testing Cells

Log-Normalizing Training Data
Log-Normalizing Testing Data

Finding HVGs


... storing 'orig.ident' as categorical
... storing 'lane' as categorical
... storing 'donor' as categorical
... storing 'time' as categorical
... storing 'celltype.l1' as categorical
... storing 'celltype.l2' as categorical
... storing 'celltype.l3' as categorical
... storing 'Phase' as categorical
... storing 'batch' as categorical
... storing 'Dataset' as categorical
... storing 'barcode_check' as categorical
... storing 'tenx_lane' as categorical
... storing 'cohort' as categorical
... storing 'hash_maxID' as categorical
... storing 'hash_secondID' as categorical
... storing 'hto_classification' as categorical
... storing 'hto_classification_global' as categorical
... storing 'hash_ID' as categorical
... storing 'adjmfc.time' as categorical
... storing 'DMX_GLOBAL_BEST' as categorical
... storing 'DEMUXLET.BARCODE' as categorical
... storing 'sample' as categorical
... storing 'joint_classification_global' as categorical
... storing 'timepoint' as categorical
... storing 'K0' as ca


Normalizing Gene Training Data by Batch


100%|██████████| 8/8 [00:05<00:00,  1.48it/s]



Normalizing Protein Training Data by Batch


100%|██████████| 8/8 [00:02<00:00,  2.91it/s]



Normalizing Gene Testing Data by Batch


100%|██████████| 20/20 [00:01<00:00, 18.21it/s]


Epoch 0 prediction loss = 1.392
Epoch 1 prediction loss = 0.887
Epoch 2 prediction loss = 0.879
Epoch 3 prediction loss = 0.873
Epoch 4 prediction loss = 0.871
Epoch 5 prediction loss = 0.867
Epoch 6 prediction loss = 0.867
Epoch 7 prediction loss = 0.866
Epoch 8 prediction loss = 0.866
Epoch 9 prediction loss = 0.865
Epoch 10 prediction loss = 0.864
Decaying loss to 0.0001
Epoch 11 prediction loss = 0.852
Epoch 12 prediction loss = 0.852
Epoch 13 prediction loss = 0.852
Epoch 14 prediction loss = 0.852
Epoch 15 prediction loss = 0.852
Epoch 16 prediction loss = 0.852
Decaying loss to 1e-05
Epoch 17 prediction loss = 0.851
Epoch 18 prediction loss = 0.851
Epoch 19 prediction loss = 0.850
Epoch 20 prediction loss = 0.851
Epoch 21 prediction loss = 0.851
Epoch 22 prediction loss = 0.850
Decaying loss to 1.0000000000000002e-06
Epoch 23 prediction loss = 0.851


In [5]:
# Wrap each runtime in a one-element list so it becomes a single row of the
# runtime table. The dict is rebuilt rather than mutated, and values that are
# already lists are kept as-is, so re-running this cell does not nest them again.
times = {frac: (seconds if isinstance(seconds, list) else [seconds])
         for frac, seconds in times.items()}

In [6]:
# Persist the runtimes as one row per training fraction with a single 'sciPENN' column.
(pd.DataFrame
   .from_dict(times, orient = 'index', columns = ['sciPENN'])
   .to_csv(join(base, "scipenn.csv")))

In [7]:
# Reload the saved table (runtime in seconds per training fraction) to confirm what was written.
pd.read_csv(filepath_or_buffer = join(base, "scipenn.csv"), index_col = 0)

Unnamed: 0,sciPENN
0.1,76.619142
0.2,139.741037
0.4,279.383748
0.5,352.099309
0.8,585.513345
1.0,664.257289
