In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

In [2]:
import scanpy as sc
import torch
import scarches
from scarches.models import TRANVAE
from scarches.dataset import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np

In [3]:
sc.settings.set_figure_params(dpi=200, frameon=False)
sc.set_figure_params(dpi=200)
sc.set_figure_params(figsize=(4, 4))
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

In [4]:
label_ratio = 0

In [5]:
condition_key = 'study'
cell_type_key = 'cell_type'


tranvae_epochs = 500
target_conditions = ['Pancreas CelSeq2', 'Pancreas SS2']

early_stopping_kwargs = {
    "early_stopping_metric": "val_unweighted_loss",
    "threshold": 0,
    "patience": 20,
    "reduce_lr": True,
    "lr_patience": 13,
    "lr_factor": 0.1,
}

In [6]:
adata_all = sc.read(os.path.expanduser(f'~/Documents/benchmarking_datasets/pancreas_normalized.h5ad'))
adata = adata_all.raw.to_adata()
adata = remove_sparsity(adata)
source_adata = adata[~adata.obs[condition_key].isin(target_conditions)].copy()
target_adata = adata[adata.obs[condition_key].isin(target_conditions)].copy()

In [7]:
idx = np.arange(len(target_adata))
np.random.shuffle(idx)
n_labeled = int(label_ratio*len(target_adata))
labeled_ind = idx[:n_labeled]

In [8]:
print(target_adata.obs[condition_key][labeled_ind].unique().tolist())
print(target_adata.obs[condition_key][~labeled_ind].unique().tolist())
print(target_adata.obs[cell_type_key][labeled_ind].unique().tolist())
print(target_adata.obs[cell_type_key][~labeled_ind].unique().tolist())

[]
[]
[]
[]


In [9]:
load_path = os.path.expanduser(f'~/Documents/aaa_dev_mars/pancreas_testing/reference_03/')
tranvae = scarches.models.TRANVAE.load_query_data(
    adata=target_adata,
    reference_model = f'{load_path}reference_model/',
    labeled_indices=labeled_ind,
)


INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 1000 128 5
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 5
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 1000 



In [10]:
tranvae.model.condition_encoder

{'Pancreas inDrop': 0,
 'Pancreas CelSeq': 1,
 'Pancreas Fluidigm C1': 2,
 'Pancreas CelSeq2': 3,
 'Pancreas SS2': 4}

In [34]:
unique_labels = []
for key in tranvae.model.cell_type_encoder:
    unique_labels.append(tranvae.model.cell_type_encoder[key])
tensor_l = torch.unique(torch.tensor(unique_labels, requires_grad=False))
x = torch.randn([30,5])
indices = x.eq().nonzero()
y = x[indices]
y.mean(0)
print(y.size())

torch.Size([0, 2, 5])


In [12]:
print(tranvae.model.landmarks_labeled)
print(tranvae.model.landmarks_unlabeled)

[[-0.34  0.05 -0.23 -1.03  0.56 -0.03  0.01 -0.1  -0.34 -0.16]
 [-0.33  0.04 -0.23 -1.01  0.55 -0.04  0.01 -0.11 -0.38 -0.17]
 [-0.19 -0.05 -0.24  0.55  0.71  1.04 -0.59  0.37 -0.14  0.41]
 [-0.25  0.08 -0.24 -0.35 -0.33  0.77  0.34  1.72 -0.34  0.52]
 [-0.35  0.1  -0.11 -0.99  0.58 -0.01  0.04 -0.09 -0.36 -0.19]
 [-0.23  0.02 -0.2   0.03  0.31 -0.19 -0.05  0.37 -0.25  0.25]
 [-0.33  0.11 -0.23  0.05  0.32 -1.08  1.69  0.53 -0.21  0.76]
 [-0.1   0.02 -0.23 -0.31  1.19  0.8   1.28  0.47 -0.11  1.44]]
[[-0.34  0.03 -0.17  0.18 -0.6   0.68  0.8  -0.57 -0.32  0.92]
 [-0.08  0.03 -0.23 -0.32  1.22  0.85  1.3   0.47 -0.11  1.48]
 [-0.19  0.02 -0.21  0.32 -0.26 -1.51 -0.32  0.64 -0.32  0.24]
 [-0.23  0.08 -0.24 -0.36 -0.35  0.79  0.35  1.77 -0.33  0.52]
 [-0.32  0.04 -0.31 -0.43  0.33 -0.58 -0.6   0.34 -0.21  1.96]
 [-0.19 -0.07 -0.23  0.59  0.72  1.08 -0.61  0.37 -0.13  0.41]
 [-0.33  0.07 -0.2  -1.05  0.57 -0.03  0.02 -0.12 -0.37 -0.19]
 [-0.33  0.11 -0.23  0.06  0.32 -1.09  1.71  0.53 -0.2

In [13]:
print(tranvae.labeled_indices_)

[]
