In [1]:
import os
import sys
import glob
import time
import math
import datetime
import argparse
import numpy as np
import pandas as pd
import scanpy as sc
import mindspore as ms
import mindspore.numpy as mnp
import mindspore.scipy as msc
import mindspore.dataset as ds
from tqdm import tqdm,trange
from mindspore import nn,ops
from scipy.sparse import csr_matrix as csm
from mindspore.ops import operations as P
from mindspore.amp import FixedLossScaleManager,all_finite,DynamicLossScaleManager
from mindspore.train import Model, CheckpointConfig, ModelCheckpoint, LossMonitor
from mindspore.context import ParallelMode
from mindspore.communication import init, get_rank, get_group_size
from mindspore.parallel._utils import _get_parallel_mode
from mindspore.common.initializer import initializer, XavierNormal

In [2]:
sys.path.append('..')

In [3]:
from config import Config
from annotation_model import *
from metrics import annote_metric
from utils import Wrapper
from data_process import Prepare

In [4]:
def freeze_module(module,filter_tag=[None]):
    for param in module.trainable_params():
        x=False
        for tag in filter_tag:
            if tag and tag in param.name:
                x=True
                break
        param.requires_grad = x

In [5]:
def read_h5ad(path,var_rate=0,test_rate=0.3):
    suffix=path.split('.')[-1]
    if suffix=='h5ad':
        adata=sc.read_h5ad(path)
    else:
        adata=sc.read_10x_h5(path)
    print('origin shape:',adata.shape,len(adata.obs['cell_type'].unique()))
    adata.obs['train']=0
    adatas=[adata[adata.obs['batch']==i].copy() for i in adata.obs['batch'].unique()]
    for adatai in adatas:
        for i in adatai.obs['cell_type'].unique():
            idx=adatai.obs['cell_type']==i
            size=idx.sum()
            order=np.random.permutation(size)
            num1=int(np.ceil(size*test_rate))
            num2=int(np.ceil(size*var_rate))
            test=order[:num1]
            val=order[num1:num1+num2]
            test=idx.values.nonzero()[0][test]
            val=idx.values.nonzero()[0][val]
            adatai.obs['train'][test]=2
            adatai.obs['train'][val]=1
    adata=sc.concat(adatas,merge='same')
        
    data=adata.X.astype(np.float32)
    T=adata.X.sum(1)
    data=csm(np.round(data/np.maximum(1,T/1e5,dtype=np.float32)))
    data.eliminate_zeros()
    adata.X=data
    
    return adata

In [6]:
class SCrna():
    def __init__(self,adata,mode='train',prep=True):
        self.cls=len(adata.obs['cell_type'].unique())
        if mode=="train":
            adata=adata[adata.obs.train==0]
        elif mode=='val':
            adata=adata[adata.obs.train==1]
        else:
            adata=adata[adata.obs.train==2]
        self.gene_info=pd.read_csv(f'../csv/gene_info.csv',index_col=0,header=0)
        self.geneset={j:i+1 for i,j in enumerate(self.gene_info.index)}
        
        gene=np.intersect1d(adata.var_names,self.gene_info.index)
        adata=adata[:,gene].copy()
        adata.obs['cell_type']=adata.obs['cell_type'].astype('category')
        label=adata.obs['cell_type'].cat.codes.values
        adata.obs['label']=label
        if prep:
            adata.layers['x_normed']=sc.pp.normalize_total(adata,target_sum=1e4,inplace=False)['X']
            adata.layers['x_log1p']=adata.layers['x_normed']
            sc.pp.log1p(adata,layer='x_log1p')
        self.adata=adata
        self.id2label=adata.obs['cell_type'].cat.categories.values
        self.gene=np.array([self.geneset[i] for i in self.adata.var_names]).astype(np.int32)
        self.cls=len(adata.obs['cell_type'].unique())
        self.label=self.adata.obs['label'].values.astype(np.int32)
        print(f'{mode} adata:',adata.shape,self.cls)
        if prep:
            self.data=self.adata.layers['x_log1p'].A.astype(np.float32)
        else:
            self.data=self.adata.X.astype(np.int32)
    def __len__(self):
        return len(self.adata)
    def __getitem__(self,idx):
        data=self.data[idx].reshape(-1)
        label=self.label[idx]
        return data,self.gene,label

In [7]:
def build_dataset(
    data,prep,batch,
    rank_size=None,
    rank_id=None,
    drop=True,
    shuffle=True
):
    dataset = ds.GeneratorDataset(
        data, 
        column_names=['data','gene','label'],
        shuffle=shuffle,
        num_shards=rank_size, 
        shard_id=rank_id
    )
    dataset = dataset.map(
        prep.seperate, input_columns=['data'],
        output_columns=['data', 'nonz','zero']
    )
    dataset = dataset.map(
        prep.sample, input_columns=['data','nonz','zero'],
        output_columns=['data','nonz','cuted','z_sample','seq_len']
    )
    dataset = dataset.map(
        prep.compress, input_columns=['data','nonz'],
        output_columns=['data','nonz_data', 'nonz']
    )
    dataset = dataset.map(
        prep.compress, input_columns=['gene','nonz'],
        output_columns=['gene','nonz_gene', 'nonz']
    )
    dataset = dataset.map(
        prep.attn_mask, input_columns=['seq_len'],
        output_columns=['zero_idx']
    )
    dataset = dataset.map(prep.pad_zero, input_columns=['nonz_data'])
    dataset = dataset.map(prep.pad_zero, input_columns=['nonz_gene'])
    dataset=dataset.project(
        columns=['nonz_data','nonz_gene','zero_idx','label']
    )
    dataset = dataset.batch(
        batch,
        num_parallel_workers=4, 
        drop_remainder=drop, 
    )
    return dataset

In [8]:
ms.set_context(
    device_target='Ascend', 
    mode=ms.GRAPH_MODE,
    device_id=0,
)
ms.set_seed(0)

In [9]:
adata=read_h5ad(f"../datasets/processed/Pancrm.h5ad")
trainset=SCrna(adata,mode='train')
testset=SCrna(adata,mode='test')

origin shape: (14767, 15558) 15


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  adatai.obs['train'][test]=2
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  adatai.obs['train'][val]=1


train adata: (10316, 15285) 15
test adata: (4451, 15285) 15


In [10]:
cfg=Config()
cfg.num_cls=trainset.cls

In [11]:
prep=Prepare(
    cfg.nonz_len,pad=1,mask_ratio=0,random=False
)
train_loader=build_dataset(
    trainset,
    prep,
    16,
    drop=True,
    shuffle=True,
)
test_loader=build_dataset(
    testset,
    prep,
    1,
    drop=False,
    shuffle=False,
)

In [12]:
para=ms.load_checkpoint("../weights/base_weight.ckpt")

In [13]:
backbone=Backbone(len(trainset.geneset),cfg)
ms.load_param_into_net(backbone, para)
model=Net(backbone,cfg)

In [14]:
freeze_module(model.extractor)

In [15]:
optimizer=nn.Adam(model.trainable_params(),1e-4,weight_decay=1e-5)
update_cell=nn.DynamicLossScaleUpdateCell(1,2,1000)
wrapper=Wrapper(model,optimizer)
trainer=Model(
    wrapper,
    eval_network=model,
    amp_level='O0',
    metrics={
        'accuracy':annote_metric(trainset.cls,key='accuracy'),
    },
    eval_indexes=[0,1,2]
)
loss_cb = LossMonitor(20)

In [16]:
loss_cb = LossMonitor(20)
ckpt_config = CheckpointConfig(
    save_checkpoint_steps=len(train_loader),
    keep_checkpoint_max=1,
    integrated_save=False,
    async_save=False
)
ckpt_cb = ModelCheckpoint(
    prefix=f'Pancrm_intra', 
    directory=f"../checkpoint/CellAnnotation/", 
    config=ckpt_config
)
cbs=[loss_cb,ckpt_cb]

In [17]:
trainer.train(30,train_loader,callbacks=cbs)

epoch: 1 step: 20, loss is 3.9814298152923584
epoch: 1 step: 40, loss is 4.81403923034668
epoch: 1 step: 60, loss is 4.15134334564209
epoch: 1 step: 80, loss is 3.271254539489746
epoch: 1 step: 100, loss is 3.0359766483306885
epoch: 1 step: 120, loss is 4.056429862976074
epoch: 1 step: 140, loss is 3.7438266277313232
epoch: 1 step: 160, loss is 4.359557151794434
epoch: 1 step: 180, loss is 3.1246750354766846
epoch: 1 step: 200, loss is 4.2687530517578125
epoch: 1 step: 220, loss is 3.5927300453186035
epoch: 1 step: 240, loss is 3.8465421199798584
epoch: 1 step: 260, loss is 4.6303606033325195
epoch: 1 step: 280, loss is 3.776381015777588
epoch: 1 step: 300, loss is 4.1960954666137695
epoch: 1 step: 320, loss is 3.611701011657715
epoch: 1 step: 340, loss is 3.3810300827026367
epoch: 1 step: 360, loss is 3.7454464435577393
epoch: 1 step: 380, loss is 3.231947183609009
epoch: 1 step: 400, loss is 3.565889358520508
epoch: 1 step: 420, loss is 3.692998170852661
epoch: 1 step: 440, loss is 2

In [17]:
ms.load_param_into_net(model, ms.load_checkpoint(ckpt_cb.latest_ckpt_file_name))
trainer.eval(test_loader)

{'accuracy': 0.9483262188272298}