In [None]:
# conda activate python38
import scanpy as sc
import squidpy as sq

import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt

import warnings

# Global figure defaults: black text, white background, 100 dpi.
plt.rcParams.update({'text.color': 'black'})
sc.set_figure_params(dpi=100, facecolor='white')

# Silence every warning category for the rest of the session
# (this also hides anndata/scanpy deprecation and view-modification warnings).
warnings.simplefilter("ignore")

In [None]:
from pertvi.model import PertVI

In [None]:
# Dataset name (used in the output filename) and I/O locations.
# The directories may be overridden via environment variables so the notebook
# is not tied to one machine; the defaults reproduce the original paths
# (input and output in the same directory).
data_name = 'pbmc'
data_path = os.environ.get("PBMC_DATA_DIR", "/home/dongjiayi/workbench/denoise/data/pbmc/")
out_dir = os.environ.get("PBMC_OUT_DIR", data_path)

# os.path.join works whether or not data_path ends with a separator.
adata = sc.read_h5ad(os.path.join(data_path, "demo.h5ad"))

# Store condition labels as strings; later cells compare against '0'.
adata.obs['condition'] = adata.obs['condition'].map(str)

In [None]:
# Register the full dataset with PertVI before it is split into folds.
# get_pert is called on the 'condition' column first, then setup_anndata uses
# 'condition' as the labels key. Keep this order; presumably setup_anndata
# depends on what get_pert adds (TODO confirm against pertvi docs).
# The per-fold loop below repeats both calls on each train/test subset.
PertVI.get_pert(adata, drug_label='condition')
PertVI.setup_anndata(adata, labels_key ='condition')

In [None]:
# Leave-one-cell-state-out splits. For every cell state `ct`:
#   * train: all cells whose 'cellstate' is not `ct` (every condition),
#   * test:  cells of `ct` with condition '0' (presumably the unperturbed
#            control group whose response is predicted below).
# Sorting makes the fold order, and therefore the order of predictions in the
# final concatenation, reproducible. Iterating a raw `set` of strings varies
# between interpreter runs because of hash randomization. key=str keeps the
# sort safe even if labels are of mixed types.
celltype = sorted(set(adata.obs['cellstate']), key=str)
train_data_list = []
test_data_list = []

for ct in celltype:
    in_state = adata.obs['cellstate'] == ct
    is_control = adata.obs['condition'] == '0'
    # .copy() materializes real AnnData objects instead of views, because
    # the next cell registers and mutates each subset.
    test_data_list.append(adata[in_state & is_control].copy())
    train_data_list.append(adata[~in_state].copy())

In [None]:
# Train one PertVI model per fold and predict the held-out cell state's
# response. Results are collected in `concat_adata_list`. It is initialized
# here so the cell runs on a fresh kernel and does not accumulate duplicates
# when re-run.
concat_adata_list = []

# Model architecture / regularization hyperparameters.
model_kwargs = dict(
    n_layers=2,
    n_latent=100,
    n_hidden=1000,
    lam_l0=0.2,
    lam_l1=1e-4,
    kl_weight=0.5,
    lam_corr=1.5,
    use_observed_lib_size=True,
)
# Optimization settings passed to model.train().
train_kwargs = dict(
    train_size=0.9,
    use_gpu=1,  # needs to be adjusted based on your own device
    batch_size=400,
    early_stopping=False,
    max_epochs=400,
    lr=1e-3,
    weight_decay=1e-5,
    n_samples_per_label=2,
)

for train_adata, test_adata in zip(train_data_list, test_data_list):
    PertVI.get_pert(train_adata, drug_label='condition')
    PertVI.get_pert(test_adata, drug_label='condition')
    PertVI.setup_anndata(train_adata, labels_key='condition')

    model_test = PertVI(train_adata, **model_kwargs)
    model_test.train(**train_kwargs)

    # NOTE(review): the test set is registered with batch_key='batch' but the
    # training set above is not. Confirm this asymmetry is intended by PertVI.
    PertVI.setup_anndata(test_adata, batch_key='batch', labels_key='condition')
    # One ['0', '1'] pair per cell, passed to get_response via pert_key='pred'.
    test_adata.obsm['pred'] = np.array([['0', '1']] * test_adata.shape[0])
    test_adata.obsm['X_trans'] = model_test.get_response(test_adata, pert_key='pred')

    # Wrap the predicted expression as a new AnnData. obs is copied so that
    # relabeling 'condition' below cannot write through to test_adata.
    trans_train_adata = sc.AnnData(
        X=test_adata.obsm['X_trans'],
        obs=test_adata.obs.copy(),
        var=test_adata.var,
    )
    # Tag predicted cells with the string label '2'. This matches the
    # str-typed 'condition' column of `adata`; an int here would leave a
    # mixed str/int column after concatenation.
    trans_train_adata.obs['condition'] = ['2'] * trans_train_adata.shape[0]
    print(trans_train_adata)
    concat_adata_list.append(trans_train_adata)


In [None]:
# Stack the observed cells and all per-fold predictions into one AnnData.
# The anndata defaults are spelled out:
#   * join='inner' keeps only variables shared by all inputs,
#   * batch_key='batch' records which input object each cell came from,
#   * index_unique='-' appends that batch label to obs names to keep them unique.
# NOTE(review): if `adata.obs` already has a 'batch' column, concatenate
# presumably overwrites it. Confirm, and consider a distinct batch_key.
concat_adata = adata.concatenate(
    concat_adata_list,
    join='inner',
    batch_key='batch',
    index_unique='-',
)
print(concat_adata)

In [None]:
# Persist the combined observed + predicted data to out_dir.
output_file = os.path.join(out_dir, f'{data_name}_scShift_adata.h5ad')
concat_adata.write_h5ad(output_file)