In [None]:
import sys, os
import torch
import numpy as np 
import pandas as pd
from scDisInFact import scdisinfact, create_scdisinfact_dataset
from scDisInFact import utils

import matplotlib.pyplot as plt
import scDisInFact.bmk as bmk

from umap import UMAP
from sklearn.decomposition import PCA
import scipy.sparse as sp
from scipy import sparse
from scipy import stats

import scanpy as sc
# Figure styling: set matplotlib's text colour to black first, then apply the
# scanpy figure parameters (100 dpi, white face colour).
plt.rcParams.update({'text.color': 'black'})
sc.set_figure_params(facecolor='white', dpi=100)

# Compute device used for model training.
# Prefer the second GPU ("cuda:1") when the host actually has one; otherwise
# fall back to the first GPU, and to the CPU when CUDA is unavailable.
# (Unconditionally requesting "cuda:1" is an invalid ordinal on single-GPU hosts.)
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")


In [None]:
# Dataset identifier and on-disk locations (input and output share a folder).
data_name = "pbmc"
data_path = "/home/dongjiayi/workbench/denoise/data/pbmc/"
out_dir = "/home/dongjiayi/workbench/denoise/data/pbmc/"

# Read the AnnData file and convert every batch label to int.
adata = sc.read_h5ad(data_path + "demo.h5ad")
adata.obs['batch'] = adata.obs['batch'].apply(int)

# Expose the cell metadata and the expression matrix under short names.
meta_cells = adata.obs
counts = adata.X

print(adata)

In [None]:
# Build the Wilcoxon-based reference gene scores used to benchmark the model.
# Reload the dataset and place every cell in a single dummy batch.
adata = sc.read_h5ad(data_path + "demo.h5ad")
adata.obs['batch'] = "none"
meta_cells = adata.obs
counts = adata.X
# np.sum(..., keepdims=True) is not supported by scipy sparse matrices, so
# convert to a dense array before normalizing.
if sp.issparse(counts):
    counts = counts.toarray()
# Per-cell library-size normalization (scaled to 100) followed by log1p.
# NOTE(review): this rebinds `counts` to log-normalized values, and the later
# training cell passes `counts` to create_scdisinfact_dataset with
# recon_loss="NB" -- confirm that raw counts are not expected there.
counts = np.log1p(counts/(np.sum(counts, axis = 1, keepdims = True) + 1e-6) * 100)

# Split cells by the "condition" label (0 -> ctrl, 1 -> stim).
is_ctrl = (adata.obs["condition"] == 0).to_numpy()
is_stim = (adata.obs["condition"] == 1).to_numpy()
counts_ctrl = counts[is_ctrl,:]
counts_stim = counts[is_stim,:]

# One p-value per gene; fdr=True requests FDR-corrected values from the helper.
pvals = bmk.wilcoxon_rank_sum(counts_ctrl, counts_stim, fdr = True)
# Check against the actual number of genes instead of a hard-coded 2000.
assert pvals.shape[0] == counts.shape[1], "expected one p-value per gene"
assert np.min(pvals) >= 0
# scale to 0-1: the smallest p-value gets the highest score
score_wilcoxon = 1 - pvals/np.max(pvals)
score_wilcoxon  = score_wilcoxon.squeeze()

In [None]:
# Default hyper-parameter settings.

# Weights of the individual loss terms.
reg_mmd_comm, reg_mmd_diff = 1e-4, 1e-4
reg_kl_comm, reg_kl_diff = 1e-5, 1e-2
reg_class, reg_gl = 1, 1
# The same weights collected in one list.
lambs = [reg_mmd_comm, reg_mmd_diff, reg_kl_comm, reg_kl_diff, reg_class, reg_gl]

# Passed to scdisinfact as `Ks` (presumably latent dimensions -- verify against the library).
Ks = [8, 2]

# Optimisation settings.
lr = 5e-4
batch_size, nepochs, interval = 64, 100, 10

In [None]:
# Benchmark scDisInFact gene scores against the Wilcoxon reference over 10 seeds.
# For every seed a model is trained, its gene scores are extracted, and the
# agreement with the Wilcoxon result is summarised (AUPRC, AUROC, early
# precision, Pearson correlation, and top-k gene overlap). One row per seed is
# collected in `auprc_dict`.

# Imported once at cell scope (not on every iteration) so the names remain
# available to any later cell that relies on them.
from scipy.stats import pearsonr, spearmanr, kendalltau

metric_columns = ["AUPRC", "AUROC", "Eprec", "Pearson", "common_100", "common_200",
                  "common_300", "common_400", "common_500"]

# Seed-independent reference: Wilcoxon-based score per gene, ranked best-first.
wilcoxon_df = pd.DataFrame(score_wilcoxon)
wilcoxon_df.index = adata.var_names
wilcoxon_df.columns = ["infWeight"]
wilcoxon_df = wilcoxon_df.sort_values(['infWeight'], ascending=False)

# Binary ground truth keyed by gene name: 1 where the p-value is <= 0.05,
# else 0. Keying by name matters because `pvals` is in adata.var_names order
# while the comparison table below is in Wilcoxon-rank order; masking one with
# the other positionally would pair labels with the wrong genes.
gt_by_gene = pd.Series((np.asarray(pvals).squeeze() <= 0.05).astype(float),
                       index = adata.var_names)

records = []
for i in range(10):
    # NOTE(review): `counts` is whatever earlier cells left in scope; in this
    # notebook it has been rebound to log-normalized values, while
    # recon_loss="NB" is used below -- confirm raw counts are not expected.
    data_dict = create_scdisinfact_dataset(counts, meta_cells, 
                                            condition_key = ["condition"], batch_key = "batch", log_trans=False)

    # The loop index doubles as the random seed of the run.
    model = scdisinfact(data_dict = data_dict, Ks = Ks, batch_size = batch_size, interval = interval, lr = lr, 
                    reg_mmd_comm = reg_mmd_comm, reg_mmd_diff = reg_mmd_diff, reg_gl = reg_gl, reg_class = reg_class, 
                    reg_kl_comm = reg_kl_comm, reg_kl_diff = reg_kl_diff, seed = i, device = device) 
    model.train() 
    losses = model.train_model(nepochs = nepochs, recon_loss = "NB")
    _ = model.eval()

    # gene_scores[0]: scores for the first condition type, one per gene in
    # adata.var_names order.
    gene_scores = model.extract_gene_scores()

    Inv_df = pd.DataFrame(gene_scores[0])
    Inv_df.index = adata.var_names
    Inv_df.columns = ["infWeight"]
    Inv_df = Inv_df.sort_values(['infWeight'], ascending=False)

    # concat aligns the two frames on gene name.
    data = pd.concat([wilcoxon_df, Inv_df], axis=1)
    data.columns = ['wilcoxon', 'scDisInFact']
    print(data.head())

    # Predictions and labels, both in the row order of `data`.
    inf = data['scDisInFact'].to_numpy()
    gt = gt_by_gene.reindex(data.index).to_numpy()

    AUPRC_value = bmk.compute_auprc(inf, gt)
    AUROC = bmk.compute_auroc(inf, gt)
    Eprec = bmk.compute_earlyprec(inf, gt)
    Pearson = pearsonr(inf, gt)[0]

    # Number of genes shared by the two top-j rankings.
    # Index.intersection is the supported set operation; `&` between indexes
    # no longer performs an intersection in pandas >= 2.0.
    common_gene = {
        j: len(wilcoxon_df.index[:j].intersection(Inv_df.index[:j]))
        for j in [100, 200, 300, 400, 500]
    }

    records.append({
        "AUPRC": AUPRC_value,
        "AUROC": AUROC,
        "Eprec": Eprec,
        "Pearson": Pearson,
        "common_100": common_gene[100],
        "common_200": common_gene[200],
        "common_300": common_gene[300],
        "common_400": common_gene[400],
        "common_500": common_gene[500],
    })

# Build the result table once instead of growing it with concat per iteration.
auprc_dict = pd.DataFrame(records, columns = metric_columns)