In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot
import os
from copy import deepcopy

from time import time

from math import ceil
from scipy.stats import spearmanr, gamma, poisson

from anndata import AnnData, read_h5ad
import scanpy as sc
from scanpy import read
import pandas as pd

from torch.utils.data import DataLoader, TensorDataset
from torch import tensor
from torch.cuda import is_available

from sciPENN.sciPENN_API import sciPENN_API

In [3]:
# Inspect the gene training matrix (displays its sparse-matrix summary).
# NOTE(review): in the visible notebook `adata_gene_train` is only assigned in the
# data-loading cell further down, so on a fresh kernel (Restart & Run All) this cell
# would raise NameError; it needs to run after the data has been read.
adata_gene_train.X

<35496x1000 sparse matrix of type '<class 'numpy.float32'>'
	with 11070575 stored elements in Compressed Sparse Column format>

In [2]:
# read in data
# Every input file lives in one directory. It defaults to the original location and
# can be redirected with the SCIPENN_DATA_DIR environment variable, so the notebook
# can run on another machine without editing each path.
DATA_DIR = os.environ.get("SCIPENN_DATA_DIR", "/home/asmauger/biostat625final")

# Gene-expression training and test sets.
adata_gene_train = sc.read(os.path.join(DATA_DIR, "rna_train_hvg.h5ad"))
adata_gene_test = sc.read(os.path.join(DATA_DIR, "rna_test_hvg.h5ad"))

# Protein training and test sets.
adata_protein_train = sc.read(os.path.join(DATA_DIR, "prot_train.h5ad"))
adata_protein_test = sc.read(os.path.join(DATA_DIR, "prot_test.h5ad"))

# Additional gene dataset; it is not referenced by the other visible cells.
ref = sc.read(os.path.join(DATA_DIR, "pbmc_gene.h5ad"))

In [3]:
# Batch key: the `donor` and `day` columns are stored as integers and might not work
# as expected with sciPENN, so the combined 'daydonor' column is used instead.
# The inputs are already cell-normalized and log-normalized, so HVG selection and both
# normalization steps are switched off and the min_cells / min_genes thresholds are 0.
sciPENN = sciPENN_API(
    gene_trainsets=[adata_gene_train],
    protein_trainsets=[adata_protein_train],
    gene_test=adata_gene_test,
    train_batchkeys=['daydonor'],
    test_batchkey='daydonor',
    use_gpu=False,
    select_hvg=False,
    cell_normalize=False,
    log_normalize=False,
    min_cells=0,
    min_genes=0,
)

Using CPU

Normalizing Gene Training Data by Batch


100%|██████████| 9/9 [00:01<00:00,  6.12it/s]



Normalizing Protein Training Data by Batch


100%|██████████| 9/9 [00:00<00:00, 18.31it/s]



Normalizing Gene Testing Data by Batch


100%|██████████| 9/9 [00:01<00:00,  5.17it/s]


In [4]:
# Keep load=False to train from scratch; set it to True only when the weights from a
# previous run should be re-used.
sciPENN.train(
    quantiles=[0.1, 0.25, 0.75, 0.9],
    n_epochs=10000,
    ES_max=12,
    decay_max=6,
    decay_step=0.1,
    lr=10 ** (-3),
    load=False,
)

Epoch 0 prediction loss = 1.408
Epoch 1 prediction loss = 0.948
Epoch 2 prediction loss = 0.943
Epoch 3 prediction loss = 0.937
Epoch 4 prediction loss = 0.938
Epoch 5 prediction loss = 0.931
Epoch 6 prediction loss = 0.932
Epoch 7 prediction loss = 0.931
Epoch 8 prediction loss = 0.927
Epoch 9 prediction loss = 0.928
Epoch 10 prediction loss = 0.933
Epoch 11 prediction loss = 0.929
Epoch 12 prediction loss = 0.930
Epoch 13 prediction loss = 0.927
Decaying loss to 0.0001
Epoch 14 prediction loss = 0.920
Epoch 15 prediction loss = 0.923
Epoch 16 prediction loss = 0.920
Epoch 17 prediction loss = 0.921
Epoch 18 prediction loss = 0.921
Epoch 19 prediction loss = 0.922
Decaying loss to 1e-05
Epoch 20 prediction loss = 0.922
Epoch 21 prediction loss = 0.920
Epoch 22 prediction loss = 0.920
Epoch 23 prediction loss = 0.919
Epoch 24 prediction loss = 0.919
Epoch 25 prediction loss = 0.921
Decaying loss to 1.0000000000000002e-06
Epoch 26 prediction loss = 0.920


In [5]:
# Run prediction with the trained model. Later cells read `.obs`, `.var`, `.X` and
# `.layers` from the result, so it is presumably an AnnData -- TODO confirm.
imputed_test = sciPENN.predict()


In [6]:
# Store the result of sciPENN_API.embed(); its return type is not visible in this
# notebook -- TODO confirm before accessing attributes on it.
embedding = sciPENN.embed()

In [11]:
# NOTE(review): the output recorded for this cell is an AttributeError saying that a
# 'sciPENN_API' object has no attribute 'X', and its execution count (11) is out of
# order with the surrounding cells. `embedding` apparently did not hold the embed()
# result when this ran; re-run after `embedding = sciPENN.embed()` or remove the cell.
embedding.X

AttributeError: 'sciPENN_API' object has no attribute 'X'

In [28]:
# Count positions where the measured and imputed test sets carry the same label:
# first the cell (obs) index, then the protein (var) index.
obs_matches = (adata_protein_test.obs.index == imputed_test.obs.index).sum()
var_matches = (imputed_test.var.index == adata_protein_test.var.index).sum()
print(obs_matches)
print(var_matches)

35492
134


In [29]:
# Densify the measured protein matrix. The guard keeps this cell re-runnable: once X is
# already a dense array it has no `.toarray()` and the unguarded call would fail.
if hasattr(adata_protein_test.X, "toarray"):
    adata_protein_test.X = adata_protein_test.X.toarray()

# Store the sciPENN prediction next to the measured values: its X as the 'imputed'
# layer, plus every layer the prediction object carries.
# NOTE(review): assumes imputed_test rows/columns are in the same order as
# adata_protein_test -- confirm with the index comparison cell.
adata_protein_test.layers['imputed'] = imputed_test.X
adata_protein_test.layers.update(imputed_test.layers)

# scaling by batch
batch_column = adata_protein_test.obs['daydonor']
for batch in np.unique(batch_column.values):
    # Vectorized boolean row mask instead of a Python-level loop over every cell.
    in_batch = (batch_column == batch).to_numpy()

    # Work on an explicit copy: sc.pp.scale modifies its argument in place, and handing
    # it a view forces an implicit view-to-copy conversion (presumably the source of the
    # `view_to_actual` warning recorded under this cell).
    batch_adata = adata_protein_test[in_batch].copy()
    sc.pp.scale(batch_adata)

    # Write the scaled values for this batch back into the full object.
    adata_protein_test[in_batch] = batch_adata.X


  view_to_actual(adata)


In [30]:
# Per-column error between the measured X and the 'imputed' layer: squared difference,
# mean over axis 0, then the square root.
# NOTE(review): because of the final **(1/2) these are root-mean-square errors (RMSE),
# not MSEs as the variable name suggests; the name is kept because later cells use it.
MSEs= ((adata_protein_test.X - adata_protein_test.layers["imputed"])**2).mean(axis = 0)**(1/2)

In [31]:
# Show the per-column error values held in `MSEs`.
print(MSEs)

[0.94374955 0.9375769  0.82664925 0.6797544  0.49247658 0.57803774
 0.5454278  0.94394684 0.77229065 0.8223677  0.96609485 1.0005746
 0.92133415 0.98532623 0.75585854 0.9315182  0.5709761  0.528932
 0.6900277  0.9970452  0.701812   0.735463   0.9938699  0.9363315
 0.5313256  0.97555137 0.9906128  0.9602575  0.79632884 0.75971913
 0.88762254 0.89638966 0.96977484 0.52683294 0.94587386 0.9842925
 0.90976167 0.99282086 0.94975656 0.65677994 0.94174635 0.87124735
 0.99618775 0.9277489  0.4401561  0.92621505 0.8287864  0.891454
 0.9467992  0.9930261  0.94243157 0.59652823 0.94325566 0.6196285
 0.9884349  0.9888636  0.9362089  0.6987578  0.87448055 0.8599059
 0.93878615 0.9894791  0.9974234  0.94126743 0.84901834 0.9408808
 0.672948   0.68342376 0.9967339  0.5417501  0.82834446 0.5358829
 0.98286194 0.92199916 0.990635   0.9642353  0.7346427  0.95062196
 0.8093105  0.96905243 0.85875595 0.93851256 0.90955013 0.5742118
 0.9252047  0.7999241  0.6919029  0.88620704 0.58177745 0.99424887
 0.9884