In [None]:
import sys, os
import torch
import numpy as np 
import pandas as pd
from scDisInFact import scdisinfact, create_scdisinfact_dataset
from scDisInFact import utils

import matplotlib.pyplot as plt

from umap import UMAP
from sklearn.decomposition import PCA
import scipy.sparse as sp
from scipy import sparse
from scipy import stats
from scipy.sparse import issparse, csr_matrix

import scanpy as sc
# Plot defaults: black text first, then scanpy's figure settings
# (white background, 100 dpi).
plt.rcParams.update({'text.color': 'black'})
sc.set_figure_params(facecolor='white', dpi=100)

# Use the first CUDA GPU when torch reports one; otherwise stay on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Dataset label; it is used again when naming the saved output file.
data_name = "pbmc"

# Input and output directories (currently the same location).
data_path = "/home/dongjiayi/workbench/denoise/data/pbmc/"
out_dir = "/home/dongjiayi/workbench/denoise/data/pbmc/"

# Read the AnnData object from disk.
adata = sc.read_h5ad(os.path.join(data_path, "demo.h5ad"))

# Give every cell the same constant batch label.
adata.obs['batch'] = '0'

# Shorthand handles for the per-cell metadata and the expression matrix.
meta_cells = adata.obs
counts = adata.X

print(adata)

In [None]:
# Distinct cell states in order of first appearance. dict.fromkeys keeps
# insertion order, whereas iterating a set of strings can change between
# kernel restarts; a stable order keeps the per-cell-type loop and the
# order of the saved predictions reproducible.
celltype = list(dict.fromkeys(adata.obs['cellstate']))

# One boolean mask pair per cell state: the held-out cells are those of that
# state with condition == 1, and the training set is every other cell.
test_data_list = []
train_data_list = []
for state in celltype:
    test_idx = (meta_cells["condition"] == 1) & (meta_cells["cellstate"] == state)
    train_idx = ~test_idx
    test_data_list.append(test_idx)
    train_data_list.append(train_idx)

In [None]:
# Default hyper-parameter settings.

# Loss-term weights; each one is passed to the scdisinfact constructor under
# the same keyword name.
reg_mmd_comm, reg_mmd_diff = 1e-4, 1e-4
reg_kl_comm, reg_kl_diff = 1e-5, 1e-2
reg_class, reg_gl = 1, 1

# Passed to the constructor as `Ks`. Presumably the latent sizes for the shared
# factor and the two condition keys — TODO confirm against the scDisInFact docs.
Ks = [8, 2, 2]

# Training settings.
lr = 5e-4
nepochs = 100
batch_size = 64
interval = 10

# The six loss weights gathered into a single list.
lambs = [reg_mmd_comm, reg_mmd_diff, reg_kl_comm, reg_kl_diff, reg_class, reg_gl]

In [None]:
# Accumulator for the predicted AnnData objects, one per cell state.
concat_adata_list = list()

In [None]:
# For every cell state: train a fresh scDisInFact model with that state's
# condition-1 cells held out, then predict condition-1 counts from the
# condition-0 cells of the same state.

# Start from an empty list every time this cell runs, so re-running it does not
# append a second copy of the predictions.
concat_adata_list = []

for i, state in enumerate(celltype):
    # Plain boolean array, so the count matrix and the metadata are indexed the
    # same way here as with input_idx below.
    train_mask = np.asarray(train_data_list[i])

    data_dict = create_scdisinfact_dataset(counts[train_mask,:], meta_cells.loc[train_mask,:],
                                           condition_key = ["condition", "cellstate"], batch_key = "batch", log_trans=False)

    model = scdisinfact(data_dict = data_dict, Ks = Ks, batch_size = batch_size, interval = interval, lr = lr,
                    reg_mmd_comm = reg_mmd_comm, reg_mmd_diff = reg_mmd_diff, reg_gl = reg_gl, reg_class = reg_class,
                    reg_kl_comm = reg_kl_comm, reg_kl_diff = reg_kl_diff, seed = 0, device = device)
    model.train()
    losses = model.train_model(nepochs = nepochs, recon_loss = "NB")

    _ = model.eval()

    # Prediction input: the condition-0 cells of the current cell state.
    input_idx = ((meta_cells["condition"] == 0) & (meta_cells["cellstate"] == state)).values
    counts_input = counts[input_idx,:]
    meta_input = meta_cells.loc[input_idx,:]

    # Request counts for the same cells under condition 1, same cell state, batch '0'.
    counts_predict = model.predict_counts(input_counts = counts_input, meta_cells = meta_input,
                                        condition_keys = ["condition", "cellstate"],
                                        batch_key = "batch", predict_conds = [1, state],
                                        predict_batch = '0')

    # Wrap the prediction in an AnnData that reuses the input cells' names.
    trans_test_data = sc.AnnData(X=counts_predict)
    trans_test_data.obs_names = meta_input.index
    # Give the prediction its own copy of the gene annotations rather than
    # handing over adata.var itself.
    trans_test_data.var = adata.var.copy()
    trans_test_data.obs['cellstate'] = state
    # Condition 2 labels these rows as predictions (the selections above use 0 and 1).
    trans_test_data.obs['condition'] = 2

    concat_adata_list.append(trans_test_data)

In [None]:
# Append every predicted AnnData to the original dataset.
concat_adata = adata.concatenate(concat_adata_list)
print(concat_adata)

# Save the combined dataset (only data is written here, no model) as
# <out_dir>/<data_name>_scDisInFact_adata.h5ad.
output_file = os.path.join(out_dir, f'{data_name}_scDisInFact_adata.h5ad')
concat_adata.write_h5ad(output_file)