In [1]:
import os
import glob
import time
import math
import datetime
import argparse
import numpy as np
import pandas as pd
import scanpy as sc
import pickle as pk
import mindspore as ms
import mindspore.numpy as mnp
import mindspore.scipy as msc
import mindspore.dataset as ds
from tqdm import tqdm,trange
from mindspore import nn,ops
from scipy.sparse import csr_matrix as csr
from mindspore.ops import operations as P
from mindspore.amp import FixedLossScaleManager,all_finite,DynamicLossScaleManager
from mindspore.train import Model, CheckpointConfig, ModelCheckpoint, LossMonitor, Accuracy
from mindspore.context import ParallelMode
from mindspore.communication import init, get_rank, get_group_size
from mindspore.parallel._utils import _get_parallel_mode
from mindspore.common.initializer import initializer, XavierNormal

import sys
sys.path.append('..')
from config import Config
from metrics import annote_metric


from utils import Wrapper
import warnings  
warnings.filterwarnings("ignore")

[ERROR] ME(1427609:140538876626752,MainProcess):2024-12-16-22:39:06.223.265 [mindspore/run_check/_check_version.py:230] Cuda ['10.1', '11.1', '11.6'] version(libcudart*.so need by mindspore-gpu) is not found. Please confirm that the path of cuda is set to the env LD_LIBRARY_PATH, or check whether the CUDA version in wheel package and the CUDA runtime in current device matches. Please refer to the installation guidelines: https://www.mindspore.cn/install
[ERROR] ME(1427609:140538876626752,MainProcess):2024-12-16-22:39:06.305.193 [mindspore/run_check/_check_version.py:230] Cuda ['10.1', '11.1', '11.6'] version(libcudnn*.so need by mindspore-gpu) is not found. Please confirm that the path of cuda is set to the env LD_LIBRARY_PATH, or check whether the CUDA version in wheel package and the CUDA runtime in current device matches. Please refer to the installation guidelines: https://www.mindspore.cn/install


# Model

In [2]:
def freeze_module(module, filter_tag=(None,)):
    """Freeze every trainable parameter of ``module`` except tagged ones.

    A parameter keeps ``requires_grad=True`` only if its name contains at
    least one of the non-empty substrings in ``filter_tag``; every other
    parameter gets ``requires_grad=False``.

    Args:
        module: object exposing ``trainable_params()`` whose items have
            ``name`` and ``requires_grad`` attributes (e.g. a MindSpore Cell).
        filter_tag: iterable of name substrings to keep trainable. Falsy
            entries (``None``, ``''``) are ignored, so the default freezes
            everything. Any iterable is accepted, including generators.
    """
    # Materialize once: iterating `filter_tag` per parameter would exhaust a
    # generator after the first parameter. Tuple default avoids the shared
    # mutable-default pitfall of `[None]`.
    tags = [tag for tag in filter_tag if tag]
    for param in module.trainable_params():
        param.requires_grad = any(tag in param.name for tag in tags)
        
class MLP(nn.Cell):
    """Binary classifier over a frozen, pre-trained gene-embedding table.

    Each gene id selects one row of the embedding table; that row goes
    through a bias-free MLP (emb_dims -> emb_dims//2 -> emb_dims//4 -> 2)
    with Dropout(p=0.15) and SiLU between layers. The resulting 2-class
    logits are scored with log-softmax followed by NLL loss.
    """
    def __init__(self,gene_emb,cfg,shard=None):
        # gene_emb: embedding table; its last axis is the embedding width.
        # NOTE(review): assumed to be (num_rows, emb_dims). The dataset code in
        # this notebook maps genes to 1-based ids, so row 0 is presumably a
        # reserved/padding row -- confirm against the checkpoint.
        # cfg: only cfg.enc_nlayers is read (stored as self.depth, which is
        # not used elsewhere in this class).
        # shard: accepted but unused.
        super().__init__()
        self.depth=cfg.enc_nlayers
        self.gene_emb=ms.Parameter(ms.Tensor(gene_emb))
        emb_dims = gene_emb.shape[-1]
        # Keep the embedding table frozen: with requires_grad=False it is not
        # returned by trainable_params(), so the optimizer never updates it.
        self.gene_emb.requires_grad=False
        self.mlp=nn.SequentialCell(
            nn.Dense(emb_dims,emb_dims//2,has_bias=False),
            nn.Dropout(p=0.15),
            nn.SiLU(),
            nn.Dense(emb_dims//2,emb_dims//4,has_bias=False),
            nn.Dropout(p=0.15),
            nn.SiLU(),
            nn.Dense(emb_dims//4,2,has_bias=False),
        )
        self.gather=P.Gather()
        self.logsoftmax=P.LogSoftmax(-1)
        self.nll_loss=nn.NLLLoss()
    def construct(self,gene_id,label):
        # Look up one embedding row per gene id (gather along axis 0) and cast
        # to float32 before the dense layers.
        gene_emb=self.gather(self.gene_emb,gene_id,0).astype(ms.float32)
        func_pred=self.mlp(gene_emb)  # 2-class logits
        # LogSoftmax + NLLLoss together form a 2-class cross-entropy.
        loss=self.nll_loss(self.logsoftmax(func_pred),label)
        # Training returns only the loss; evaluation returns (loss, logits,
        # label), consumed via eval_indexes=[0,1,2] where the Model is built.
        if self.training:
            return loss
        else:
            return loss,func_pred,label

class SCrna():
    """Random-access (gene_id, label) source for one cross-validation split.

    Reads ``{path}/t123.h5ad`` and keeps only genes present in both the
    AnnData ``var_names`` and ``gene_index``. Among those, genes whose
    ``train_{data}`` column is greater than -1 are considered labelled.
    Labelled genes whose fold equals ``fold`` form the held-out split; the
    remaining labelled genes form the training split.

    ``mode == 'train'`` selects the training split; any other value selects
    the held-out split. Each item is ``(gene_id, label)``. ``gene_id`` is the
    gene's 1-based position in ``gene_index`` (int32), and ``label`` comes from
    the ``{data}`` column of ``adata.var``.
    """
    def __init__(self,path,data,fold,mode,gene_index):
        self.mode = mode
        raw = sc.read_h5ad(f'{path}/t123.h5ad')

        # Genes known to both the AnnData object and the vocabulary
        # (np.intersect1d returns them sorted and de-duplicated).
        shared = np.intersect1d(list(raw.var_names), list(gene_index))
        self.adata = raw[:, shared].copy()

        fold_col = f'train_{data}'
        labelled = self.adata.var[self.adata.var[fold_col] > -1]
        in_fold = labelled[fold_col] == fold

        # Vocabulary lookup: gene name -> 1-based id.
        self.geneset = {name: pos for pos, name in enumerate(gene_index, start=1)}

        subset = labelled[~in_fold] if mode == 'train' else labelled[in_fold]
        self.gene = np.array([self.geneset[name] for name in subset.index]).astype(np.int32)
        self.label = subset[f'{data}'].values

    def __len__(self):
        return len(self.gene)

    def __getitem__(self,idx):
        gene_id, label = self.gene[idx], self.label[idx]
        return gene_id, label

def build_dataset(
    data,batch,
    mask_rate=0.2,
    drop=True,
    shuffle=True,
    rank_size=None,
    rank_id=None,
    num_parallel_workers=4,
):
    """Wrap a random-access source in a batched MindSpore ``GeneratorDataset``.

    Args:
        data: source with ``__len__``/``__getitem__`` yielding ``(gene, label)``.
        batch: batch size.
        mask_rate: accepted for call-site compatibility; not used here.
        drop: drop the final incomplete batch (``drop_remainder``).
        shuffle: shuffle the sample order.
        rank_size: number of shards (``num_shards``); ``None`` disables sharding.
        rank_id: this process's shard index (``shard_id``); ``None`` when unsharded.
        num_parallel_workers: worker count for the batch operation
            (defaults to 4, the previously hard-coded value).

    Returns:
        Batched dataset with columns ``"gene"`` and ``"label"``.
    """
    dataset = ds.GeneratorDataset(
        data,
        column_names=["gene",'label'],
        shuffle=shuffle,
        num_shards=rank_size,
        shard_id=rank_id
    )
    dataset = dataset.batch(
        batch,
        num_parallel_workers=num_parallel_workers,
        drop_remainder=drop,
    )
    return dataset



# Config

In [3]:
# Run settings for this notebook; wrapped into an attribute-style object later in the cell.
args = {
    'id': 3,                                    # device id passed to ms.set_context
    'epoch': 30,                                # training epochs
    'batch': 4,                                 # training batch size
    'fold': 2,                                  # held-out cross-validation fold
    'fp16': False,
    'data': "t1",                               # suffix of the label / fold columns
    'datapath': "../datasets/genefunction/",
    'readpath': '../',
    'savepath': '../checkpoint/genefunction/',
    'lr': 1e-4,                                 # Adam learning rate
}
class config:
    """Attribute-style view over a run-settings dict.

    Copies the keys id, epoch, batch, fold, fp16, data, datapath, readpath,
    savepath and lr from ``args`` onto the instance. A missing key raises
    ``KeyError``; any extra keys are ignored.
    """
    def __init__(self, args):
        wanted = ('id', 'epoch', 'batch', 'fold', 'fp16', 'data',
                  'datapath', 'readpath', 'savepath', 'lr')
        for field in wanted:
            setattr(self, field, args[field])

# Swap the raw settings dict for attribute access (args.id, args.lr, ...).
args = config(args)

# Static-graph execution on the GPU selected by args.id.
ms.set_context(
    mode=ms.GRAPH_MODE,
    device_target='GPU',
    device_id=args.id,
)
cfg = Config()
# Single-device run: no dataset sharding (these are passed through as None).
rank_id, rank_size = None, None
ms.set_seed(0)
shard = None

# Loading Dataset

In [4]:
path = args.datapath

# Gene vocabulary: its row order defines the 1-based gene ids used by SCrna.
gene_index = pd.read_csv('../../csv/gene_info.csv', index_col=0, header=0).index
# Pretrained gene-embedding table, located relative to the configured readpath
# (args.readpath was previously defined but ignored in favour of a hard-coded path).
gene_emb = ms.load_checkpoint(
    os.path.join(args.readpath, 'checkpoint', 'base_weight.ckpt')
)['gene_emb'].value()

scrna = SCrna(path, args.data, args.fold, mode='train', gene_index=gene_index)
sctest = SCrna(path, args.data, args.fold, mode='test', gene_index=gene_index)
# Fail early with a readable message; an empty test split would otherwise
# surface as an opaque batch-size error (batch=len(sctest)=0) below.
assert len(scrna) > 0, f'no training genes for data={args.data}, fold={args.fold}'
assert len(sctest) > 0, f'no test genes for data={args.data}, fold={args.fold}'

trainset = build_dataset(
    scrna, args.batch,
    rank_size=rank_size,
    rank_id=rank_id,
    drop=True,
)
# The whole held-out fold is evaluated as a single batch.
testset = build_dataset(
    sctest,
    len(sctest),
    rank_size=rank_size,
    rank_id=rank_id,
    drop=False,
)

# Training and Evaluation

In [5]:
# Binary gene-function classifier over the frozen pretrained gene embeddings.
model = MLP(gene_emb, cfg)
lr = args.lr
optimizer = nn.Adam(model.trainable_params(), lr, weight_decay=1e-5)
# Project helper from utils; presumably builds the train-one-step cell -- see utils.Wrapper.
wrapper = Wrapper(model, optimizer)
trainer = Model(
    wrapper,
    amp_level='O0',
    eval_network=model,
    metrics={
        'metrics': annote_metric(2),
    },
    # In eval mode MLP returns (loss, logits, label); all three feed the metric.
    eval_indexes=[0, 1, 2],
)
# Keep one checkpoint, written at the final training step.
ckpt_config = CheckpointConfig(
    save_checkpoint_steps=args.epoch * len(trainset),
    keep_checkpoint_max=1,
    integrated_save=False,
    async_save=True,
)
# Attach the checkpoint config so the trained weights actually land in
# args.savepath; without this callback nothing is saved.
ckpt_cb = ModelCheckpoint(
    prefix=f'{args.data}_fold{args.fold}',
    directory=args.savepath,
    config=ckpt_config,
)

loss_cb = LossMonitor(1)  # print the loss every step
now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print(f'Begin training {len(trainset)*args.epoch} steps at {now}')
cbs = [loss_cb, ckpt_cb]
# Train for args.epoch epochs, evaluating on testset (fit's default valid_frequency=1).
trainer.fit(args.epoch, trainset, testset, callbacks=cbs)
print(f'dataset {args.data}:', len(scrna) + len(sctest))

epoch: 1 step: 1, loss is 0.6931521892547607
epoch: 1 step: 2, loss is 0.6931703090667725
epoch: 1 step: 3, loss is 0.69302898645401
epoch: 1 step: 4, loss is 0.6930884122848511
epoch: 1 step: 5, loss is 0.6930348873138428
epoch: 1 step: 6, loss is 0.6929506659507751
epoch: 1 step: 7, loss is 0.6930937767028809
epoch: 1 step: 8, loss is 0.6928531527519226
epoch: 1 step: 9, loss is 0.6927726864814758
epoch: 1 step: 10, loss is 0.693138837814331
epoch: 1 step: 11, loss is 0.6926484107971191
epoch: 1 step: 12, loss is 0.6928457021713257
epoch: 1 step: 13, loss is 0.6926879286766052
epoch: 1 step: 14, loss is 0.693278431892395
epoch: 1 step: 15, loss is 0.6921830773353577
epoch: 1 step: 16, loss is 0.6922993659973145
epoch: 1 step: 17, loss is 0.692756175994873
epoch: 1 step: 18, loss is 0.6924155950546265
epoch: 1 step: 19, loss is 0.6921775341033936
epoch: 1 step: 20, loss is 0.6924899816513062
epoch: 1 step: 21, loss is 0.6919459104537964
epoch: 1 step: 22, loss is 0.6922548413276672
ep

[ERROR] CORE(1427609,7fd1c1d18740,python):2024-12-16-22:39:46.857.623 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_1427609/2429778555.py]
[ERROR] CORE(1427609,7fd1c1d18740,python):2024-12-16-22:39:46.857.665 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_1427609/2429778555.py]
[ERROR] CORE(1427609,7fd1c1d18740,python):2024-12-16-22:39:46.857.678 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_1427609/2429778555.py]


epoch: 2 step: 41, loss is 0.6324800848960876
epoch: 2 step: 42, loss is 0.5745666027069092
epoch: 2 step: 43, loss is 0.6076438426971436
epoch: 2 step: 44, loss is 0.5794374942779541
epoch: 2 step: 45, loss is 0.5442765951156616
epoch: 2 step: 46, loss is 0.5221415758132935
epoch: 2 step: 47, loss is 0.6672753095626831
epoch: 2 step: 48, loss is 0.6252565383911133
epoch: 2 step: 49, loss is 0.6101400256156921
epoch: 2 step: 50, loss is 0.5616727471351624
epoch: 2 step: 51, loss is 0.5852972865104675
epoch: 2 step: 52, loss is 0.4993596076965332
epoch: 2 step: 53, loss is 0.4719761610031128
epoch: 2 step: 54, loss is 0.6142653822898865
epoch: 2 step: 55, loss is 0.7408678531646729
epoch: 2 step: 56, loss is 0.5031419992446899
epoch: 2 step: 57, loss is 0.5690963864326477
epoch: 2 step: 58, loss is 0.5873899459838867
epoch: 2 step: 59, loss is 0.6310907602310181
epoch: 2 step: 60, loss is 0.4738936424255371
epoch: 2 step: 61, loss is 0.6972630620002747
epoch: 2 step: 62, loss is 0.61519