In [1]:
import os
import sys
import glob
import time
import math
import datetime
import argparse
import numpy as np
import pandas as pd
import scanpy as sc
import pickle as pk
import mindspore as ms
import mindspore.nn as nn
import mindspore.numpy as mnp
import mindspore.scipy as msc
import mindspore.dataset as ds
from tqdm import tqdm,trange
from mindspore import nn,ops
from scipy.sparse import csr_matrix as csr
from mindspore.ops import operations as P
from mindspore.amp import FixedLossScaleManager,all_finite,DynamicLossScaleManager
from mindspore.train import Model, CheckpointConfig, ModelCheckpoint, LossMonitor, Accuracy
from mindspore.context import ParallelMode
from mindspore.communication import init, get_rank, get_group_size
from mindspore.parallel._utils import _get_parallel_mode
from mindspore.common.initializer import initializer, XavierNormal

In [2]:
sys.path.append('..')
from utils import Wrapper
from config import Config
from metrics import annote_metric
from genefunc_model import *

In [3]:
class SCrna():
    """Per-gene samples for one cross-validation fold of a gene-level task.

    Reads ``{path}/t123.h5ad`` and ``../csv/gene_info.csv``, keeps only the
    genes present in both, and then splits genes using the fold column
    ``train_{data}`` of ``adata.var``:

    * genes whose fold value is not ``> -1`` are dropped
      (NOTE(review): -1 looks like an "excluded gene" marker — confirm);
    * genes whose fold value equals ``fold`` form the held-out split;
    * all remaining kept genes form the training split.

    Args:
        path: Directory containing ``t123.h5ad``.
        data: Task name; selects fold column ``train_{data}`` and label
            column ``{data}`` in ``adata.var``.
        fold: Fold id that is held out.
        mode: ``'train'`` selects every kept gene outside ``fold``; any other
            value selects the genes of ``fold``.

    Each item is ``(gene_token, label)``. ``gene_token`` is the 1-based position
    of the gene in ``gene_info.csv``'s index, stored as int32 (0 is presumably
    reserved, e.g. for padding — confirm against the model).
    """

    def __init__(self,path,data,fold,mode):
        self.mode = mode
        raw_adata = sc.read_h5ad(f'{path}/t123.h5ad')
        self.gene_info = pd.read_csv('../csv/gene_info.csv', index_col=0, header=0)
        shared_genes = np.intersect1d(raw_adata.var_names, self.gene_info.index)
        self.adata = raw_adata[:, shared_genes].copy()

        fold_column = f'train_{data}'
        var_table = self.adata.var
        kept = var_table[var_table[fold_column] > -1]
        in_held_out_fold = kept[fold_column] == fold

        # Vocabulary: gene name -> 1-based token id over the full gene_info index.
        self.geneset = {
            name: token
            for token, name in enumerate(self.gene_info.index, start=1)
        }

        # Training uses every kept gene outside the fold; any other mode uses the fold.
        selection = ~in_held_out_fold if mode == 'train' else in_held_out_fold
        chosen = kept[selection]
        tokens = [self.geneset[name] for name in chosen.index]
        self.gene = np.array(tokens).astype(np.int32)
        self.label = kept[f'{data}'][selection].values

    def __len__(self):
        """Number of genes in the selected split."""
        return len(self.gene)

    def __getitem__(self,idx):
        """Return the ``(gene_token, label)`` pair at position ``idx``."""
        gene_token = self.gene[idx]
        return gene_token, self.label[idx]

In [4]:
def build_dataset(
    data,batch,
    mask_rate=0.2,
    drop=True,
    shuffle=True,
    rank_size=None,
    rank_id=None,
    num_parallel_workers=4,
):
    """Wrap a random-access source in a batched MindSpore ``GeneratorDataset``.

    Args:
        data: Random-access source (``__len__``/``__getitem__``) whose items are
            expected to provide two values matching the column names
            ``"gene"`` and ``"label"`` (e.g. ``SCrna``).
        batch: Batch size passed to ``Dataset.batch``.
        mask_rate: Currently unused by this function; kept only so existing
            call sites that pass it keep working.
        drop: Passed as ``drop_remainder``; when True the final incomplete
            batch is discarded.
        shuffle: Passed to ``GeneratorDataset`` as ``shuffle``.
        rank_size: Number of shards for data-parallel runs (``None`` = no sharding).
        rank_id: Shard index of this process (``None`` = no sharding).
        num_parallel_workers: Worker count for the batch operation. Defaults to
            4, the value that was previously hard-coded.

    Returns:
        The batched dataset.
    """
    dataset = ds.GeneratorDataset(
        data,
        column_names=["gene",'label'],
        shuffle=shuffle,
        num_shards=rank_size,
        shard_id=rank_id
    )
    dataset = dataset.batch(
        batch,
        num_parallel_workers=num_parallel_workers,
        drop_remainder=drop,
    )
    return dataset

In [5]:
# Run on Ascend device 0 in MindSpore's static-graph mode (GRAPH_MODE).
ms.set_context(
    device_target='Ascend', 
    mode=ms.GRAPH_MODE,
    device_id=0,
)
# Global MindSpore random seed. NOTE(review): it is set only once, here, so each
# cross-validation fold in the loop below starts from whatever RNG state the
# previous folds left behind; re-seed per fold if folds must be independently
# reproducible.
ms.set_seed(0)

In [6]:
# Load the pretrained checkpoint as a mapping of parameter name -> Parameter.
# In the visible code only the 'gene_emb' entry is used (passed to MLP for each fold).
para=ms.load_checkpoint("../weights/base_weight.ckpt")

In [7]:
# Cross-validation settings, hoisted from inline literals so a run can be
# re-configured in one place. Values are unchanged from the original cell.
N_FOLDS = 5                          # must match the fold ids in the train_{TASK} column
TASK = 't1'
DATA_DIR = '../datasets/processed/'
EPOCHS = 30
TRAIN_BATCH_SIZE = 4
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
LOG_EVERY = 50                       # LossMonitor print interval, in steps

acc=[]
for i in range(N_FOLDS):
    print(f"-----------fold {i}-----------")
    cfg=Config()
    # Model and optimizer are rebuilt every fold so trained weights never carry
    # over between folds. para['gene_emb'] comes from the pretrained checkpoint
    # (presumably it initialises MLP's gene embedding — confirm in genefunc_model).
    model=MLP(para['gene_emb'].value(),cfg)
    optimizer=nn.Adam(model.trainable_params(),LEARNING_RATE,weight_decay=WEIGHT_DECAY)
    wrapper=Wrapper(model,optimizer)
    trainer=Model(
        wrapper,
        amp_level='O0',
        eval_network=model,
        metrics={
            # NOTE(review): 2 is presumably the number of classes — confirm in metrics.annote_metric.
            'accuracy':annote_metric(2,key='accuracy'),
        },
        # Positions of (loss, prediction, label) within eval_network's outputs.
        eval_indexes=[0,1,2]
    )
    scrna=SCrna(DATA_DIR,TASK,fold=i,mode='train')
    trainset=build_dataset(
        scrna,TRAIN_BATCH_SIZE,
        drop=True
    )
    sctest=SCrna(DATA_DIR,TASK,fold=i,mode='test')
    # Evaluate the whole held-out fold as a single, complete batch.
    testset=build_dataset(
        sctest,
        len(sctest),
        drop=False
    )
    loss_cb = LossMonitor(LOG_EVERY)
    cbs=[loss_cb]
    trainer.train(EPOCHS,trainset,callbacks=cbs)
    acci=trainer.eval(testset)['accuracy']
    print(f'Accuracy on fold {i} is {acci:.4f}')
    acc.append(acci)
# Average over the folds actually run instead of a hard-coded divisor.
print(f"Average accuracy on {TASK} is {sum(acc)/len(acc):.4f}")

-----------fold 0-----------
epoch: 1 step: 50, loss is 0.6858367919921875
epoch: 2 step: 15, loss is 0.628882646560669
epoch: 2 step: 65, loss is 0.5618929862976074
epoch: 3 step: 30, loss is 0.4210722744464874
epoch: 3 step: 80, loss is 0.9581413865089417
epoch: 4 step: 45, loss is 0.3028731644153595
epoch: 5 step: 10, loss is 0.942497193813324
epoch: 5 step: 60, loss is 0.7543084621429443
epoch: 6 step: 25, loss is 0.30059099197387695
epoch: 6 step: 75, loss is 0.5268030166625977
epoch: 7 step: 40, loss is 0.3934265673160553
epoch: 8 step: 5, loss is 0.0921880453824997
epoch: 8 step: 55, loss is 0.48324501514434814
epoch: 9 step: 20, loss is 0.30186137557029724
epoch: 9 step: 70, loss is 0.7508274912834167
epoch: 10 step: 35, loss is 0.08719252049922943
epoch: 10 step: 85, loss is 0.7976909279823303
epoch: 11 step: 50, loss is 0.061052706092596054
epoch: 12 step: 15, loss is 0.33150607347488403
epoch: 12 step: 65, loss is 0.06424576044082642
epoch: 13 step: 30, loss is 0.23686097562