In [1]:
from types import CellType
import numpy as np
import pandas as pd
import os
import scanpy as sc
from model.dataset import load_train_dataset,load_test_dataset
from model.model import Class_VAE
from model.layers import LinearAverage,Logit_Linear
from model.process_data import select_peak
from utils.metric_compute import compute_EAS,compute_EAS_EpiAnno
from sklearn import preprocessing
from anndata import AnnData
from utils.utils import set_seed,ForeverDataIterator,ForeverDataIteratorExtension
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
class argument():
    """Plain option holder handed to ``model.fit`` as its first argument.

    The constructor only stores its five parameters as same-named
    attributes; how each one is interpreted is decided by the model code,
    which is not part of this notebook.
    """

    def __init__(self,class_t,hard_weight,sample_weight,class_t_coff,class_coff) -> None:
        # Attribute creation order is kept as class_t, sample_weight,
        # hard_weight so that vars(self) lists them in the same order as before.
        self.class_t, self.sample_weight, self.hard_weight = class_t, sample_weight, hard_weight
        # presumably loss coefficients, judging by the names; confirm in model.fit
        self.class_t_coff, self.class_coff = class_t_coff, class_coff

In [17]:
# ---- run configuration: everything a reader might want to tune ----
binary = True # set False if data is scRNA-seq
batch_size = 64        # used for both data loaders and the memory-bank threshold shape
epoch = 3              # passed to model.fit as `n`
lr = 2e-4              # learning rate passed to model.fit
threshold = 0.95       # passed to model.fit as `threshold`
# option bundle consumed by model.fit
args = argument(
    class_t=True,
    hard_weight=False,
    sample_weight=True,
    class_t_coff=1.0,
    class_coff=1.0,
)

In [4]:
set_seed()
# read data
train_adata = sc.read_h5ad("../../atac_class/EpiAnno_Forebrain.h5ad")
test_adata = sc.read_h5ad("../../atac_class/preprocess_mouse_brain.h5ad")
# Binarise the count matrices only when the run is configured as binary.
# Previously this was unconditional, so setting `binary = False` for
# scRNA-seq data still collapsed every positive count to 1.
if binary:
    train_adata.X[train_adata.X > 0] = 1
    test_adata.X[test_adata.X > 0] = 1
# domain id: 0 = labelled training set, 1 = test set
train_adata.obs['domain'] = 0
test_adata.obs['domain'] = 1
# Use the first GPU when one is available and fall back to the CPU otherwise,
# instead of hard-coding device index 0.
# NOTE(review): confirm the model code has no CUDA-only paths before relying on the CPU fallback.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
# Convert cell types to numeric labels and preprocess the data.
# scikit-learn 1.2 renamed OneHotEncoder's `sparse` keyword to `sparse_output`
# and 1.4 removed the old name; try the new spelling first and fall back for
# older installations.
try:
    one_hot = preprocessing.OneHotEncoder(sparse_output=False)
except TypeError:
    one_hot = preprocessing.OneHotEncoder(sparse=False)
# The label encoder knows every training cell type plus an extra 'unknown' class.
# NOTE(review): it is fitted before select_peak runs. Later cells index arrays by
# label value and pass `class_dim` as the unknown label, so they assume the
# training labels stay contiguous and that 'unknown' sorts last -- confirm.
le = preprocessing.LabelEncoder().fit(np.append(np.unique(train_adata.obs['CellType'].values),'unknown'))
le_domain = one_hot.fit([[0],[1]])
train_adata,test_adata = select_peak(train_adata,test_adata,peak_rate=0.001)
# LabelEncoder.transform expects a 1-D array; the former (n, 1) column was
# flattened internally with a DataConversionWarning. The labels are unchanged.
train_adata.obs['Label'] = le.transform(np.asarray(train_adata.obs['CellType'].values))
test_adata.obs['Label'] = le.transform(np.asarray(test_adata.obs['CellType'].values))

raw train data:AnnData object with n_obs × n_vars = 2088 × 436206
    obs: 'CellType', 'batch', 'domain'
    var: 'chrom', 'chromStart', 'chromEnd', 'name'
raw test data:AnnData object with n_obs × n_vars = 17003 × 436206
    obs: 'cell', 'tissue', 'tissue.replicate', 'cluster', 'subset_cluster', 'tsne_1', 'tsne_2', 'subset_tsne1', 'subset_tsne2', 'id', 'cell_label', 'CellType', 'domain'
    var: 'chrom', 'chromStart', 'chromEnd', 'name'




processed_atac:AnnData object with n_obs × n_vars = 1098 × 20000
    obs: 'CellType', 'batch', 'domain', 'n_genes'
    var: 'chrom', 'chromStart', 'chromEnd', 'name', 'n_cells', 'prop_shared_cells', 'variability_score'
process test data:AnnData object with n_obs × n_vars = 17003 × 20000
    obs: 'cell', 'tissue', 'tissue.replicate', 'cluster', 'subset_cluster', 'tsne_1', 'tsne_2', 'subset_tsne1', 'subset_tsne2', 'id', 'cell_label', 'CellType', 'domain', 'n_genes'
    var: 'chrom', 'chromStart', 'chromEnd', 'name', 'n_cells'




In [6]:
#set data loader
cell_num = train_adata.shape[0] 
input_dim = train_adata.shape[1]
class_sample_counts = np.array(list(range(len(np.unique(train_adata.obs['Label'].values)))))
for i in np.unique(train_adata.obs['Label'].values):
    class_sample_counts[i] = len(train_adata[train_adata.obs['Label'].values == i])
train_onehot = np.zeros((cell_num,len(np.unique(train_adata.obs['Label'].values))))
for i in np.unique(train_adata.obs['Label'].values):
    train_onehot[train_adata.obs['Label'].values == i,i] = 1

weight = np.zeros(len(np.unique(train_adata.obs['Label'].values)))
class_weight = np.zeros((cell_num,len(np.unique(train_adata.obs['Label'].values))))
for i in np.unique(train_adata.obs['Label'].values):
    class_weight[train_adata.obs['Label'].values != i,i] = 1
    class_weight[train_adata.obs['Label'].values == i,i] = np.sum(class_sample_counts[list(range(len(np.unique(train_adata.obs['Label'].values)))) != i]) / class_sample_counts[i]
    weight[i] = max(class_sample_counts) / class_sample_counts[i]
train_loader = load_train_dataset(train_adata,weight=class_weight,onehotlabel=train_onehot,batch_size=batch_size,drop_last=False,num_workers=4,shuffle=True,sample=None)
test_adata_loader_1 = load_train_dataset(test_adata,batch_size = batch_size,shuffle=True,drop_last=False,num_workers=4,sample=None)

In [7]:
# Layer sizes handed to Class_VAE.
latent_dim = 32
encode_dim = [3200,1600,800,400]
decode_dim = []
# number of distinct cell types left in the (filtered) training set
class_dim = np.unique(train_adata.obs["CellType"]).size
#VAE encoder and decoder layers
dims = [input_dim, latent_dim, latent_dim, encode_dim, decode_dim]
#open and close classifier layers
c_dim = [latent_dim, class_dim]

In [8]:
# Wrap both loaders in the project's "forever" iterators for model.fit.
train_iter = ForeverDataIteratorExtension(train_loader)
test_iter = ForeverDataIterator(test_adata_loader_1)
# Memory bank sized by the number of test cells; it starts from an all-ones
# threshold array of shape (batch_size, class_dim).
memory_bank = LinearAverage(
    inputSize=class_dim,
    outputSize=len(test_adata),
    device=device,
    threshold=np.ones((batch_size,class_dim)),
    celltype=test_adata.obs['Label'].values,
)

In [9]:
# build the model (training itself is started by the model.fit cell)
model = Class_VAE(
    dims,
    c_dim,
    dropout=0,
    binary=binary,
    finally_activate=None,
    num_class=class_dim,
    device=device,
)



In [19]:
#model save path
# `save_path` is appended to '../save_model/' and '../img/' in the model.fit
# call and to '../final_result/' when the test embedding is written. With the
# empty string the embedding file is '../final_result/_test_embedding.h5ad'.
save_path = ''

In [21]:
# Train the model. Optional keyword arguments that were disabled in this run:
# val_loader, memory_bank_s, logit_save.
model.fit(
    args,
    # data
    train_iter=train_iter,
    test_iter=test_iter,
    memory_bank=memory_bank,
    # optimisation
    lr=lr,
    n=epoch,
    weight_decay=5e-4,
    iter=len(train_loader),
    device=device,
    # model / loss settings
    class_num=class_dim,
    embedding_size=latent_dim,
    weight=weight,
    threshold=threshold,
    # output locations
    savepath='../save_model/' + save_path,
    imgpath='../img/' + save_path,
)

Epochs: 100%|██████████| 3/3 [00:01<00:00,  1.94it/s, recon 5778.181 kl 314.297 o_class=4.715 c_class=2.746 class_t=0.203 supLoss=0.000 center=0.000]


In [22]:
# Drop the training-time loaders, then build ordered (unshuffled, nothing
# dropped) loaders for evaluation.
# NOTE(review): because of the `del`, this cell raises NameError if it is run twice.
del train_loader, test_adata_loader_1
eval_loader_options = dict(shuffle=False, drop_last=False)
test_adata_loader = load_test_dataset(test_adata, **eval_loader_options)
train_all_loader = load_test_dataset(train_adata, **eval_loader_options)

In [26]:
# Predict on the test set: the model returns the predicted labels, the class
# probabilities and the embedding.
test_pred_label, test_prob, test_embedding = model.predict_class(
    test_adata_loader,
    device=device,
)



In [27]:
# Metric computation.
# NOTE(review): this cell rebinds `test_embedding` and edits `test_pred_label`
# in place, so it is only meant to run once per prediction -- confirm before re-running.
origin_label = np.unique(train_adata.obs["Label"].values)

# store the decoded prediction (before the low-confidence relabelling below)
test_adata.obs['pred_label'] = le.inverse_transform(test_pred_label)
test_embedding = AnnData(test_embedding, obs=test_adata.obs)

# probability of the predicted class for every test cell
cell_index = np.arange(len(test_pred_label))
test_prob_select = test_prob[cell_index, test_pred_label]
test_embedding.obs["score"] = test_prob_select

# predictions scoring below 0.5 are relabelled as `class_dim`, the value that
# compute_EAS is told to treat as the unknown class
low_score = test_prob_select < 0.5
test_pred_label[low_score] = class_dim

true_label = test_adata.obs["Label"].values
eas_epianno_result = compute_EAS_EpiAnno(y_pred=test_pred_label, origin_label=origin_label, y_true=true_label)
print("EAS_EpiAnno:{} , same:{} , diff:{}".format(*eas_epianno_result))
eas_result = compute_EAS(y_pred=test_pred_label, y_true=true_label, unknown=class_dim)
print("EAS:{}".format(eas_result))

EAS_EpiAnno:-0.007659714219782954 , same:7352 , diff:4862
EAS:-0.5893876768584285


In [28]:
# Save the test embedding (with its obs annotations) as an .h5ad file.
result_file = '../final_result/' + save_path +"_test_embedding.h5ad"
test_embedding.write(result_file)