# CellFM Zero-Shot Cell Annotation

In [1]:
import os
import sys
import glob
import time
import math
import datetime
import argparse
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
import mindspore as ms
import mindspore.numpy as mnp
import mindspore.scipy as msc
import mindspore.dataset as ds
from tqdm import tqdm,trange
from mindspore import nn,ops
from scipy.sparse import csr_matrix as csm
from mindspore.ops import operations as P
from mindspore.amp import FixedLossScaleManager,all_finite,DynamicLossScaleManager
from mindspore.train import Model, CheckpointConfig, ModelCheckpoint, LossMonitor
from mindspore.context import ParallelMode
from mindspore.communication import init, get_rank, get_group_size
from mindspore.parallel._utils import _get_parallel_mode
from mindspore.common.initializer import initializer, XavierNormal

[ERROR] ME(1189632:140419653146432,MainProcess):2024-12-16-20:45:21.436.362 [mindspore/run_check/_check_version.py:230] Cuda ['10.1', '11.1', '11.6'] version(libcudart*.so need by mindspore-gpu) is not found. Please confirm that the path of cuda is set to the env LD_LIBRARY_PATH, or check whether the CUDA version in wheel package and the CUDA runtime in current device matches. Please refer to the installation guidelines: https://www.mindspore.cn/install
[ERROR] ME(1189632:140419653146432,MainProcess):2024-12-16-20:45:21.503.386 [mindspore/run_check/_check_version.py:230] Cuda ['10.1', '11.1', '11.6'] version(libcudnn*.so need by mindspore-gpu) is not found. Please confirm that the path of cuda is set to the env LD_LIBRARY_PATH, or check whether the CUDA version in wheel package and the CUDA runtime in current device matches. Please refer to the installation guidelines: https://www.mindspore.cn/install


In [2]:
# Add the directory two levels up to the module search path; the project
# modules imported in the next cell presumably live there — TODO confirm.
sys.path.append('../..')

In [3]:
from config import Config
from annotation_model import *
from metrics import annote_metric
from utils import Wrapper
from data_process import Prepare

## Prepare training and test datasets

In [4]:
# Freezing the parameters of the backbone in the context of a zero-shot model.
def freeze_module(module,filter_tag=(None,)):
    """Freeze every trainable parameter of ``module`` except those matching a tag.

    Args:
        module: Object exposing ``trainable_params()``. Each returned parameter
            must have a ``name`` string and a writable ``requires_grad`` flag.
        filter_tag: A substring, or an iterable of substrings, naming the
            parameters that stay trainable. Falsy entries (``None``, ``''``)
            are ignored, so the default freezes everything. Passing ``None``
            is treated as "no tags".

    Returns:
        None. ``requires_grad`` is updated in place on each parameter.
    """
    if filter_tag is None:
        tags = ()
    elif isinstance(filter_tag, str):
        # A bare string would otherwise be iterated character by character.
        tags = (filter_tag,)
    else:
        # Materialise once so one-shot iterables work for every parameter.
        tags = tuple(filter_tag)
    tags = tuple(tag for tag in tags if tag)

    for param in module.trainable_params():
        # A parameter stays trainable only if its name contains one of the tags.
        param.requires_grad = any(tag in param.name for tag in tags)

In [5]:
# Loading training and testing datasets in H5AD format.
def read_h5ad(path):
    """Load ``train.h5ad`` and ``test.h5ad`` from ``path`` and merge them.

    Each split is tagged in ``obs['train']`` (0 for the training file, 2 for
    the test file) and the two are concatenated with an outer join. ``X`` is
    then rescaled row by row: each row is divided by ``max(1, row_sum / 1e5)``,
    rounded, and stored back as a CSR matrix with explicit zeros removed.

    Args:
        path: Directory that contains ``train.h5ad`` and ``test.h5ad``.

    Returns:
        The concatenated AnnData object with the rescaled matrix in ``X``.
    """
    splits = []
    for file_name, flag in (("/train.h5ad", 0), ("/test.h5ad", 2)):
        split = sc.read_h5ad(path+file_name)
        split.obs['train'] = flag
        splits.append(split)

    adata = ad.concat(splits, join='outer')
    n_types = len(adata.obs['cell_type'].unique())
    print('origin shape:',adata.shape,n_types)

    counts = adata.X.astype(np.float32)
    totals = adata.X.sum(1)
    # Rows whose total is at most 1e5 get a divisor of exactly 1 (left unscaled).
    divisor = np.maximum(1,totals/1e5,dtype=np.float32)
    rescaled = csm(np.round(counts/divisor))
    rescaled.eliminate_zeros()
    adata.X = rescaled

    return adata

In [6]:
class SCrna():
    """Indexable dataset over one split of a combined AnnData object.

    Indexing yields ``(expression_vector, gene_ids, label)`` where
    ``expression_vector`` is one dense row of the expression matrix,
    ``gene_ids`` is the shared array of gene indices for the kept genes and
    ``label`` is the integer code of the cell's ``cell_type``.

    Attributes set by ``__init__``:
        gene_info: DataFrame read from ``../../csv/expand_gene_info.csv``.
        geneset:   Mapping gene name -> 1-based position in ``gene_info.index``.
        adata:     The selected split, restricted to genes found in ``gene_info``.
        id2label:  Array of cell-type names; position == label code.
        gene:      int32 array of ``geneset`` ids for ``adata.var_names``.
        cls:       Number of cell types in the label vocabulary.
        label:     int32 array of label codes, one per cell.
        data:      Dense 2-D array of expression values, one row per cell.
    """
    def __init__(self,adata,mode='train',prep=True):
        """Select a split, align genes and encode labels.

        Args:
            adata: AnnData with ``obs['cell_type']`` and ``obs['train']``
                (0 = train, 1 = validation, 2 = test).
            mode: ``'train'`` or ``'val'``; any other value selects the rows
                with ``obs['train'] == 2``.
            prep: If true, ``data`` holds log1p of the total-count-normalised
                matrix (target sum 1e4) as float32; otherwise it holds ``X``
                cast to int32.
        """
        # Build the label vocabulary from the full object *before* taking the
        # split, so a given cell type gets the same integer code in every
        # split even when one split lacks some types.
        categories=(
            adata.obs['cell_type'].astype('category')
            .cat.remove_unused_categories().cat.categories
        )
        if mode=="train":
            adata=adata[adata.obs.train==0]
        elif mode=='val':
            adata=adata[adata.obs.train==1]
        else:
            adata=adata[adata.obs.train==2]
        self.gene_info=pd.read_csv('../../csv/expand_gene_info.csv',index_col=0,header=0)
        # Gene ids are 1-based positions in the gene-info index.
        self.geneset={j:i+1 for i,j in enumerate(self.gene_info.index)}

        # Keep only genes that have an id in the gene-info table.
        gene=np.intersect1d(adata.var_names,self.gene_info.index)
        adata=adata[:,gene].copy()
        adata.obs['cell_type']=pd.Categorical(adata.obs['cell_type'],categories=categories)
        label=adata.obs['cell_type'].cat.codes.values
        adata.obs['label']=label
        if prep:
            adata.layers['x_normed']=sc.pp.normalize_total(adata,target_sum=1e4,inplace=False)['X']
            # Copy first: log1p below works in place on the layer, and without
            # the copy 'x_normed' would share (and lose) its buffer.
            adata.layers['x_log1p']=adata.layers['x_normed'].copy()
            sc.pp.log1p(adata,layer='x_log1p')
        self.adata=adata
        self.id2label=adata.obs['cell_type'].cat.categories.values
        self.gene=np.array([self.geneset[i] for i in self.adata.var_names]).astype(np.int32)
        self.cls=len(categories)
        self.label=self.adata.obs['label'].values.astype(np.int32)
        print(f'{mode} adata:',adata.shape,self.cls)
        if prep:
            matrix=self.adata.layers['x_log1p']
            dtype=np.float32
        else:
            matrix=self.adata.X
            dtype=np.int32
        # __getitem__ flattens a row with reshape(-1), which needs a dense
        # array; densify sparse storage on both paths.
        if hasattr(matrix,'toarray'):
            matrix=matrix.toarray()
        self.data=np.asarray(matrix).astype(dtype)
    def __len__(self):
        """Return the number of cells in the selected split."""
        return len(self.adata)
    def __getitem__(self,idx):
        """Return ``(flattened expression row, gene ids, label)`` for cell ``idx``."""
        data=self.data[idx].reshape(-1)
        label=self.label[idx]
        return data,self.gene,label

In [7]:
# Creating a data loader
def build_dataset(
    data,prep,batch,
    rank_size=None,
    rank_id=None,
    drop=True,
    shuffle=True
):
    """Wrap ``data`` in a batched MindSpore dataset pipeline.

    The source yields the columns ``data``, ``gene`` and ``label``. The
    methods of ``prep`` are then applied as successive ``map`` stages, the
    columns ``nonz_data``, ``nonz_gene``, ``zero_idx`` and ``label`` are kept,
    and the result is batched.

    Args:
        data: Random-access source passed to ``ds.GeneratorDataset``.
        prep: Object providing ``seperate``, ``sample``, ``compress``,
            ``attn_mask`` and ``pad_zero``.
        batch: Batch size.
        rank_size: Forwarded as ``num_shards``.
        rank_id: Forwarded as ``shard_id``.
        drop: Forwarded as ``drop_remainder`` of the batch step.
        shuffle: Whether the source is shuffled.

    Returns:
        The batched dataset.
    """
    pipeline = ds.GeneratorDataset(
        data,
        column_names=['data','gene','label'],
        shuffle=shuffle,
        num_shards=rank_size,
        shard_id=rank_id
    )
    # (operation, input columns, output columns). When the output entry is
    # None, ``output_columns`` is simply not passed to ``map``.
    stages = (
        (prep.seperate, ['data'], ['data', 'nonz','zero']),
        (prep.sample, ['data','nonz','zero'], ['data','nonz','cuted','z_sample','seq_len']),
        (prep.compress, ['data','nonz'], ['data','nonz_data', 'nonz']),
        (prep.compress, ['gene','nonz'], ['gene','nonz_gene', 'nonz']),
        (prep.attn_mask, ['seq_len'], ['zero_idx']),
        (prep.pad_zero, ['nonz_data'], None),
        (prep.pad_zero, ['nonz_gene'], None),
    )
    for operation, inputs, outputs in stages:
        map_kwargs = {'input_columns': inputs}
        if outputs is not None:
            map_kwargs['output_columns'] = outputs
        pipeline = pipeline.map(operation, **map_kwargs)

    kept_columns = ['nonz_data','nonz_gene','zero_idx','label']
    pipeline = pipeline.project(columns=kept_columns)
    return pipeline.batch(
        batch,
        num_parallel_workers=4,
        drop_remainder=drop,
    )

In [8]:
# Select the execution backend: `device_target` is the hardware type
# (this notebook uses 'GPU'; 'Ascend' is the other option the authors mention)
# and `device_id` is the index of the device of that type to run on.
ms.set_context(
    device_target='GPU', 
    mode=ms.GRAPH_MODE,
    device_id=0,
)
# Fix the global random seed.
ms.set_seed(0)

In [9]:
# Read the Liver3 train/test files into one AnnData, then wrap each split as
# a dataset. mode='test' falls into SCrna's final branch (obs['train'] == 2).
# NOTE(review): 'cell_annotion' looks like a misspelling of 'cell_annotation';
# it is left as is because it has to match the directory name on disk.
adata=read_h5ad(f"../datasets/cell_annotion/Inter/Liver3")
trainset=SCrna(adata,mode='train')
testset=SCrna(adata,mode='test')

origin shape: (4014, 20007) 12
train adata: (3402, 15760) 12
test adata: (612, 15760) 12


In [10]:
# Start from the project's default configuration and override two fields.
cfg=Config()
# Number of output classes = number of cell types reported by the training set.
cfg.num_cls=trainset.cls
# Presumably the number of encoder layers used by the backbone — TODO confirm
# against Config / Backbone.
cfg.enc_nlayers=2

In [11]:
# Shared preprocessing object used by both loaders.
prep=Prepare(cfg.nonz_len,pad=1,mask_ratio=0,random=False)

# Training loader: batches of 16, shuffled, incomplete last batch dropped.
train_loader=build_dataset(data=trainset,prep=prep,batch=16,drop=True,shuffle=True)

# Test loader: one cell per batch, original order, nothing dropped.
test_loader=build_dataset(data=testset,prep=prep,batch=1,drop=False,shuffle=False)

## Create the training model for CellFM and freeze the parameters of its backbone layer.

In [12]:
# Load the pretrained CellFM weights from disk into `para`.
para=ms.load_checkpoint("../checkpoint/CellFM_80M_weight.ckpt")

In [13]:
# Build the backbone sized by the number of entries in the gene vocabulary.
backbone=Backbone(len(trainset.geneset),cfg)
# Copy the pretrained weights into the backbone.
# NOTE(review): the return value of load_param_into_net is discarded here, so
# parameters that failed to load would go unnoticed — consider checking it.
ms.load_param_into_net(backbone, para)
# Wrap the backbone in the annotation network (Net comes from annotation_model).
model=Net(backbone,cfg)

In [14]:
# With the default filter_tag no parameter name matches, so every trainable
# parameter of model.extractor gets requires_grad = False.
freeze_module(model.extractor)

In [15]:
# Optimise only the parameters the model still reports as trainable
# (freeze_module was applied to model.extractor in the previous cell).
# Learning rate 1e-4, weight decay 1e-5.
optimizer=nn.Adam(model.trainable_params(),1e-4,weight_decay=1e-5)
# NOTE(review): update_cell is not passed to anything in the visible cells.
update_cell=nn.DynamicLossScaleUpdateCell(1,2,1000)
# Wrapper (from utils) combines the model and optimizer into the training network.
wrapper=Wrapper(model,optimizer)
# Train through `wrapper`, evaluate with the bare `model`; amp_level 'O0'.
trainer=Model(
    wrapper,
    eval_network=model,
    amp_level='O0',
    metrics={
        'accuracy':annote_metric(trainset.cls,key='accuracy'),
    },
    eval_indexes=[0,1,2]
)
# NOTE(review): loss_cb is created again with the same argument in the next cell.
loss_cb = LossMonitor(20)

In [16]:
# Print the loss every 20 steps (matches the step numbers in the training log).
loss_cb = LossMonitor(20)
# Save a checkpoint every len(train_loader) steps and keep only the newest file.
# presumably len(train_loader) is the number of batches per epoch, i.e. one
# checkpoint per epoch — TODO confirm.
ckpt_config = CheckpointConfig(
    save_checkpoint_steps=len(train_loader),
    keep_checkpoint_max=1,
    integrated_save=False,
    async_save=False
)
# Checkpoint files are written to ../checkpoint/CellAnnotation/ with the
# prefix 'Liver3_zeroshot'.
ckpt_cb = ModelCheckpoint(
    prefix=f'Liver3_zeroshot', 
    directory=f"../checkpoint/CellAnnotation/", 
    config=ckpt_config
)
# Callbacks handed to trainer.train.
cbs=[loss_cb,ckpt_cb]

## Model training and evaluation

In [17]:
# Train for 30 epochs on the training loader with the loss and checkpoint callbacks.
trainer.train(30,train_loader,callbacks=cbs)

epoch: 1 step: 20, loss is 2.5083718299865723
epoch: 1 step: 40, loss is 1.8856178522109985
epoch: 1 step: 60, loss is 0.8571237325668335
epoch: 1 step: 80, loss is 0.25524070858955383
epoch: 1 step: 100, loss is 1.1458042860031128
epoch: 1 step: 120, loss is 0.6305822134017944
epoch: 1 step: 140, loss is 0.4231998324394226
epoch: 1 step: 160, loss is 0.7089542746543884
epoch: 1 step: 180, loss is 0.9662556648254395
epoch: 1 step: 200, loss is 0.2579959034919739
epoch: 2 step: 8, loss is 0.08099916577339172
epoch: 2 step: 28, loss is 0.10933481156826019
epoch: 2 step: 48, loss is 0.19778043031692505
epoch: 2 step: 68, loss is 0.05334336683154106
epoch: 2 step: 88, loss is 0.0036920057609677315
epoch: 2 step: 108, loss is 0.026573309674859047
epoch: 2 step: 128, loss is 0.8848670721054077
epoch: 2 step: 148, loss is 0.9530234336853027
epoch: 2 step: 168, loss is 1.1099324226379395
epoch: 2 step: 188, loss is 0.3019728660583496
epoch: 2 step: 208, loss is 0.5893732309341431
epoch: 3 step

In [18]:
# Reload the most recent checkpoint written by ckpt_cb into the model, then
# evaluate on the test loader with the metrics registered on the trainer.
ms.load_param_into_net(model, ms.load_checkpoint(ckpt_cb.latest_ckpt_file_name))
trainer.eval(test_loader)

{'accuracy': 0.9477124183006536}