## Use the trained model to extract attention weights and compare protein-coding and non-coding genes

In [1]:
import os
import glob
import time
import math
import datetime
import argparse
import numpy as np
import pandas as pd
import scanpy as sc
import mindspore as ms
import mindspore.numpy as mnp
import mindspore.scipy as msc
import mindspore.dataset as ds
from tqdm import tqdm,trange
from mindspore import nn,ops
from scipy.sparse import csr_matrix as csm
from mindspore.amp import FixedLossScaleManager,all_finite,DynamicLossScaleManager
from mindspore.train import Model, CheckpointConfig, ModelCheckpoint, LossMonitor
from mindspore.context import ParallelMode
from mindspore.communication import init, get_rank, get_group_size
from mindspore.parallel._utils import _get_parallel_mode
from mindspore.common.initializer import initializer, XavierNormal

import sys
sys.path.append('..')
from config import Config
from model import *
from metrics import *
from utils import Wrapper,WrapperWithLossScaleCell,load_dist_model
from utils import WarmCosineDecay,Adam,AdamWeightDecay,set_weight_decay
from data_process import Prepare

import warnings
warnings.filterwarnings("ignore")

[ERROR] ME(1157347:139791510296384,MainProcess):2024-12-12-14:42:20.267.318 [mindspore/run_check/_check_version.py:230] Cuda ['10.1', '11.1', '11.6'] version(libcudart*.so need by mindspore-gpu) is not found. Please confirm that the path of cuda is set to the env LD_LIBRARY_PATH, or check whether the CUDA version in wheel package and the CUDA runtime in current device matches. Please refer to the installation guidelines: https://www.mindspore.cn/install
[ERROR] ME(1157347:139791510296384,MainProcess):2024-12-12-14:42:20.334.103 [mindspore/run_check/_check_version.py:230] Cuda ['10.1', '11.1', '11.6'] version(libcudnn*.so need by mindspore-gpu) is not found. Please confirm that the path of cuda is set to the env LD_LIBRARY_PATH, or check whether the CUDA version in wheel package and the CUDA runtime in current device matches. Please refer to the installation guidelines: https://www.mindspore.cn/install


#### Model and Dataset

In [2]:
def freeze_module(module,filter_tag=None):
    """Freeze every trainable parameter of ``module`` except the tagged ones.

    Args:
        module: object exposing ``trainable_params()``; each returned
            parameter must have a ``name`` and a writable ``requires_grad``.
        filter_tag: substring, or iterable of substrings. A parameter whose
            name contains at least one of them stays trainable. ``None``
            (the default), an empty iterable, and empty/``None`` entries
            select nothing, i.e. everything gets frozen.

    Only the parameters returned by ``trainable_params()`` are visited, so
    this function never re-enables a parameter.
    """
    if isinstance(filter_tag,str):
        # A bare string would otherwise be iterated character by character.
        filter_tag=[filter_tag]
    tags=[tag for tag in (filter_tag or ()) if tag]
    for param in module.trainable_params():
        param.requires_grad=any(tag in param.name for tag in tags)
class Backbone(nn.Cell):
    """Encoder that turns (gene id, expression value) sequences into embeddings.

    Each position is embedded as gene embedding + value embedding, a learned
    CLS token is prepended, and the sequence is passed through
    ``cfg.enc_nlayers`` ``RetentionLayer`` cells.

    NOTE(review): ``P``, ``ValueEncoder`` and ``RetentionLayer`` are not
    imported explicitly in this file; presumably they come from
    ``from model import *`` - confirm.
    """
    def __init__(self,n_genes,cfg,shard=None,**kwargs):
        """Build parameters, sub-layers and primitive operators.

        Args:
            n_genes: number of genes in the vocabulary.
            cfg: configuration object; reads ``enc_nlayers``, ``label``,
                ``add_zero``, ``pad_zero``, ``enc_dims``, ``enc_num_heads``,
                ``enc_dropout``, ``lora`` and ``recompute``.
            shard: forwarded unchanged to ``ValueEncoder`` and ``RetentionLayer``.
            **kwargs: accepted but not used here.
        """
        super().__init__()
        self.depth=cfg.enc_nlayers
        # The next four attributes are stored but not read inside this class.
        self.if_cls=cfg.label
        self.n_genes=n_genes
        self.add_zero=cfg.add_zero and not cfg.pad_zero
        self.pad_zero=cfg.pad_zero
        # Learned tensors. The gene table has n_genes+1 rows rounded up to the
        # next multiple of 8; row 0 is zeroed right after initialisation
        # (presumably the padding id - TODO confirm against the data pipeline).
        self.gene_emb=ms.Parameter(
            initializer(XavierNormal(0.5),[n_genes+1+(-n_genes-1)%8,cfg.enc_dims])
        )
        self.cls_token=ms.Parameter(initializer(XavierNormal(0.5),[1,1,cfg.enc_dims]))
        self.gene_emb[0,:]=0
        # Sub-layers. The fourth RetentionLayer argument grows linearly with
        # the layer index: 0 for the first layer, approaching cfg.enc_dropout.
        self.value_enc=ValueEncoder(cfg.enc_dims,shard=shard)
        self.encoder=nn.CellList([
            RetentionLayer(
                cfg.enc_dims,cfg.enc_num_heads,cfg.enc_nlayers,
                cfg.enc_dropout*i/cfg.enc_nlayers, cfg.lora,
                cfg.recompute,shard=shard
            )
            for i in range(cfg.enc_nlayers)
        ])
        # Primitive operators. self.one, self.zero are created but not used
        # in construct().
        self.one=P.Ones()
        self.zero=P.Zeros()
        self.tile=P.Tile()
        self.gather=P.Gather()
        self.maskmul=P.Mul()
        self.posa=P.Add()
        self.rsqrt=P.Rsqrt()
        self.cat1=P.Concat(1)
        self.sum=P.ReduceSum(True)  # keep_dims=True
        self.detach=P.StopGradient()
    def construct(self,expr,gene,zero_idx):
        """Encode one batch.

        Args:
            expr: expression values, fed to ``self.value_enc``.
            gene: gene ids, shape ``(b, l)``; used to gather rows of
                ``self.gene_emb``.
            zero_idx: multiplicative mask over the sequence including the
                prepended CLS slot - presumably shape ``(b, l+1)`` with 1 for
                valid positions and 0 for padding; TODO confirm.

        Returns:
            The output of the last encoder layer; index 0 along axis 1 is the
            CLS position.
        """
        b,l=gene.shape
        gene_emb=self.gather(self.gene_emb,gene,0)
        # value_enc returns a second output (unmask) that is not used here.
        expr_emb,unmask=self.value_enc(expr)
        # 1/sqrt(sum(zero_idx)-1) per sample, excluded from the gradient.
        # The "-1" presumably discounts the CLS slot - TODO confirm.
        len_scale=self.detach(self.rsqrt(self.sum(zero_idx,-1)-1).reshape(b,1,1,1))

        expr_emb=self.posa(gene_emb,expr_emb)
        # Prepend one CLS token per sample along the sequence axis.
        cls_token=self.tile(self.cls_token,(b,1,1))
        expr_emb=self.cat1((cls_token,expr_emb))
        # Zero the embeddings at positions where zero_idx is 0.
        expr_emb=self.maskmul(expr_emb,zero_idx.reshape(b,-1,1))
        mask_pos=zero_idx.reshape(b,1,-1,1)
        for i in range(self.depth):
            expr_emb=self.encoder[i](
                expr_emb,
                v_pos=len_scale,
                attn_mask=mask_pos
            )
        return expr_emb
class Net(nn.Cell):
    """Cell-type classifier on top of a ``Backbone`` feature extractor.

    Two sets of class scores are produced:
      * ``labelpred1``: learned per-class query embeddings attend to the gene
        embeddings through two ``CrossRetentionLayer`` cells, then a bias-free
        ``Dense(enc_dims, 1)`` maps every class query to one score.
      * ``labelpred2``: product of the refined CLS token with the transposed
        class embedding table.
    """
    def __init__(self,backbone,cfg,shard=None,**kwargs):
        """Build the classification head.

        Args:
            backbone: feature extractor, stored as ``self.extractor``.
            cfg: configuration object; reads ``num_cls``, ``enc_dims``,
                ``enc_num_heads`` and ``enc_dropout``.
            shard: accepted but not used in this class.
            **kwargs: optional ``cls_weight`` (per-class loss weights,
                defaults to all ones).
        """
        super().__init__()
        # const
        self.num_class=cfg.num_cls
        self.extractor=backbone
        cls_weight=kwargs.get('cls_weight',np.ones(cfg.num_cls))
        self.weight=ms.Tensor(cls_weight,ms.float32)
        # One learned embedding per class; used both as attention queries and
        # as the weight matrix for labelpred2.
        self.cluster_emb=ms.Parameter(
            initializer(XavierNormal(0.5),[cfg.num_cls,cfg.enc_dims])
        )
        self.query_layer=nn.CellList([
            CrossRetentionLayer(cfg.enc_dims,cfg.enc_num_heads,cfg.enc_dropout,False)
            for i in range(2)
        ])
        self.classifier=nn.Dense(cfg.enc_dims,1,has_bias=False)
        # operator
        self.tile=P.Tile()
        self.slice=P.Slice()
        self.cat1=P.Concat(1)
        self.mm=P.MatMul(transpose_b=True)
        self.logsoftmax=P.LogSoftmax(-1)
        # loss
        self.nll_loss=ops.NLLLoss()
        self.logger=ops.ScalarSummary()
        # AllGather is only created when running under a parallel mode.
        self.parallel_mode=_get_parallel_mode()
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        if self.is_distributed:
            self.allgather=ops.AllGather()
    def forward(self,expr,gene,zero_idx):
        """Compute both sets of class scores for one batch.

        Args:
            expr, gene, zero_idx: passed unchanged to ``self.extractor``.

        Returns:
            ``(labelpred1, labelpred2, cls_token)`` where ``cls_token`` is the
            CLS embedding after the query layers.
        """
        emb=self.extractor(expr,gene,zero_idx)
        # Position 0 is the CLS token, the rest are per-gene embeddings.
        cls_token,expr_emb=emb[:,0],emb[:,1:]
        b,l,d=expr_emb.shape
        # Drop column 0 of zero_idx so the mask lines up with expr_emb.
        attn_mask=self.slice(zero_idx,(0,1),(-1,-1))
        # Query sequence: [CLS token, class embedding 0, class embedding 1, ...].
        clst_emb=self.cat1((cls_token.reshape(-1,1,d),self.tile(self.cluster_emb.astype(cls_token.dtype).reshape(1,-1,d),(b,1,1))))
        for query in self.query_layer:
            clst_emb=query(clst_emb,y=expr_emb,attn_mask=attn_mask.reshape(b,1,-1,1))
        cls_token,cluster=clst_emb[:,0],clst_emb[:,1:]
        # One scalar per class query -> (b, number of class queries).
        labelpred1=self.classifier(cluster).reshape(b,-1)
        labelpred2=self.mm(
            cls_token,self.cluster_emb.astype(cls_token.dtype)
        )
        return labelpred1,labelpred2,cls_token
    def construct(
        self,nonz_data,nonz_gene,zero_idx,label
    ):
        """Compute the loss and, outside training, the predictions.

        Returns:
            In training mode: the summed loss ``loss1 + loss2``.
            Otherwise: ``(loss, labelpred1, label, cls_token)``. In a
            distributed run ``loss``, ``label`` and ``labelpred1`` are
            all-gathered first; ``labelpred2`` is gathered too but not
            returned, and ``cls_token`` is returned without gathering.
        """
        labelpred1,labelpred2,cls_token=self.forward(
            nonz_data,nonz_gene,zero_idx
        )
        # Losses are computed in float32 regardless of the compute dtype.
        logits1=self.logsoftmax(labelpred1.astype(ms.float32))
        logits2=self.logsoftmax(labelpred2.astype(ms.float32))
        # Only the first output of the NLLLoss operator is kept.
        loss1=self.nll_loss(logits1,label,self.weight.astype(ms.float32))[0]
        loss2=self.nll_loss(logits2,label,self.weight.astype(ms.float32))[0]
        self.logger('gw_celoss',loss1)
        self.logger('cw_celoss',loss2)
        loss=loss1+loss2
        if self.training:
            return loss
        else:
            if self.is_distributed:
                loss=self.allgather(loss).mean()
                label=self.allgather(label).reshape(-1,)
                labelpred1=self.allgather(labelpred1).reshape(-1,self.num_class)
                labelpred2=self.allgather(labelpred2).reshape(-1,self.num_class)
            return loss,labelpred1,label,cls_token
def read_h5ad(path,fold):
    """Load an expression matrix and prepare it for a leave-one-batch-out split.

    Args:
        path: file to load. A name whose last dot-separated component is
            ``h5ad`` is read with ``sc.read_h5ad``, anything else with
            ``sc.read_10x_h5``.
        fold: index into the unique values of ``adata.obs.batch``; the
            selected batch is the held-out one.

    Returns:
        ``(adata, batch)``: the AnnData restricted to the cell types that
        occur more than 10 times outside the held-out batch and at least once
        inside it, and the held-out batch value. ``adata.X`` is replaced by a
        rounded CSR matrix in which rows with a total above 1e5 are scaled
        down to a total of 1e5.
    """
    is_h5ad=path.split('.')[-1]=='h5ad'
    adata=sc.read_h5ad(path) if is_h5ad else sc.read_10x_h5(path)
    print('origin shape:',adata.shape,len(adata.obs['cell_type'].unique()))

    # Split by batch only to find out which cell types both sides share.
    batch=adata.obs.batch.unique()[fold]
    train_part=adata[adata.obs.batch!=batch]
    test_part=adata[adata.obs.batch==batch]
    train_names,train_counts=np.unique(train_part.obs['cell_type'],return_counts=True)
    test_names=np.unique(test_part.obs['cell_type'])
    frequent_names=train_names[train_counts>10]
    shared_names=np.intersect1d(frequent_names,test_names)

    # Cap the per-cell total at 1e5 (cells below the cap are left unscaled),
    # round to whole counts and drop the zeros created by rounding.
    row_total=adata.X.sum(1)
    row_scale=np.maximum(1,row_total/1e5,dtype=np.float32)
    counts=csm(np.round(adata.X.astype(np.float32)/row_scale))
    counts.eliminate_zeros()
    adata.X=counts

    adata=adata[adata.obs['cell_type'].isin(shared_names)]
    print('filtered shape:',adata.shape,len(adata.obs['cell_type'].unique()))
    return adata,batch

class SCrna():
    """Random-access dataset over the train or test part of an AnnData object.

    Items are ``(expression vector, gene ids, label)`` tuples, in the form
    expected by ``mindspore.dataset.GeneratorDataset``.

    Attributes set by ``__init__``:
        gene_info: gene table read from ``../csv/expand_gene_info.csv``.
        geneset:   gene name -> 1-based row number in ``gene_info``.
        adata:     the selected cells, restricted to genes found in ``gene_info``.
        id2label:  category names, indexable by label code.
        gene:      int32 gene id of every column of ``adata``.
        cls:       number of distinct cell types among the selected cells.
        label:     int32 label code of every cell.
        data:      dense 2-D array, one row per cell.
    """
    def __init__(self,adata,batch,mode='train',prep=True):
        """Select cells, map genes to ids and materialise the expression matrix.

        Args:
            adata: AnnData with ``obs['batch']`` and ``obs['cell_type']``.
            batch: value of ``obs['batch']`` that marks the held-out cells.
            mode: ``'train'`` keeps every cell outside ``batch``; any other
                value keeps the cells inside it.
            prep: if true, store log1p of the total-count normalised matrix
                (target sum 1e4) as float32; otherwise store ``adata.X`` as
                int32.

        Raises:
            ValueError: if no cell is left after the selection.
        """
        if mode=="train":
            adata=adata[adata.obs.batch!=batch]
        else:
            adata=adata[adata.obs.batch==batch]
        self.gene_info=pd.read_csv('../csv/expand_gene_info.csv',index_col=0,header=0)
        self.geneset={j:i+1 for i,j in enumerate(self.gene_info.index)}

        gene=np.intersect1d(adata.var_names,self.gene_info.index)
        adata=adata[:,gene].copy()
        # Fail before any normalisation is attempted on an empty matrix.
        if len(adata)==0:
            raise ValueError('samples are filtered')
        adata.obs['cell_type']=adata.obs['cell_type'].astype('category')
        label=adata.obs['cell_type'].cat.codes.values
        adata.obs['label']=label
        if prep:
            adata.layers['x_normed']=sc.pp.normalize_total(adata,target_sum=1e4,inplace=False)['X']
            # Copy, so that the log transform of 'x_log1p' cannot also
            # overwrite 'x_normed' through a shared buffer.
            adata.layers['x_log1p']=adata.layers['x_normed'].copy()
            sc.pp.log1p(adata,layer='x_log1p')
        self.adata=adata
        self.id2label=adata.obs['cell_type'].cat.categories.values
        self.gene=np.array([self.geneset[i] for i in self.adata.var_names]).astype(np.int32)
        self.cls=len(adata.obs['cell_type'].unique())
        self.label=self.adata.obs['label'].values.astype(np.int32)
        print(f'{mode} adata:',adata.shape,self.cls,batch)
        if prep:
            matrix,dtype=self.adata.layers['x_log1p'],np.float32
        else:
            matrix,dtype=self.adata.X,np.int32
        # Always store a dense ndarray: __getitem__ flattens one row with
        # reshape(-1), which sparse matrices do not support, and the `.A`
        # shortcut is no longer available on recent SciPy sparse containers.
        if hasattr(matrix,'toarray'):
            matrix=matrix.toarray()
        self.data=np.asarray(matrix).astype(dtype)
    def __len__(self):
        """Return the number of selected cells."""
        return len(self.adata)
    def __getitem__(self,idx):
        """Return ``(flattened expression row, gene ids, label)`` for cell ``idx``."""
        data=self.data[idx].reshape(-1)
        label=self.label[idx]
        return data,self.gene,label
def build_dataset(
    data,prep,batch,
    rank_size=None,
    rank_id=None,
    drop=True,
    shuffle=True
):
    """Wrap ``data`` in a batched MindSpore dataset pipeline.

    Args:
        data: random-access source yielding ``(data, gene, label)``.
        prep: object providing the per-sample transforms ``seperate``,
            ``sample``, ``compress``, ``attn_mask`` and ``pad_zero``.
        batch: batch size.
        rank_size: number of shards, ``None`` when not sharded.
        rank_id: shard index of this process.
        drop: whether an incomplete last batch is dropped.
        shuffle: whether the samples are shuffled.

    Returns:
        A dataset whose batches have the columns ``nonz_data``,
        ``nonz_gene``, ``zero_idx`` and ``label``.
    """
    pipeline = ds.GeneratorDataset(
        data,
        column_names=['data','gene','label'],
        shuffle=shuffle,
        num_shards=rank_size,
        shard_id=rank_id
    )
    # (transform, columns it reads, columns it writes) - the order matters,
    # because later steps read columns created by earlier ones.
    column_steps = (
        (prep.seperate, ['data'], ['data', 'nonz','zero']),
        (prep.sample, ['data','nonz','zero'], ['data','nonz','cuted','z_sample','seq_len']),
        (prep.compress, ['data','nonz'], ['data','nonz_data', 'nonz']),
        (prep.compress, ['gene','nonz'], ['gene','nonz_gene', 'nonz']),
        (prep.attn_mask, ['seq_len'], ['zero_idx']),
    )
    for transform, read_cols, written_cols in column_steps:
        pipeline = pipeline.map(
            transform, input_columns=read_cols,
            output_columns=written_cols
        )
    # pad_zero rewrites each of these columns in place.
    for column in ('nonz_data', 'nonz_gene'):
        pipeline = pipeline.map(prep.pad_zero, input_columns=[column])
    pipeline = pipeline.project(
        columns=['nonz_data','nonz_gene','zero_idx','label']
    )
    return pipeline.batch(
        batch,
        num_parallel_workers=4,
        drop_remainder=drop,
    )

#### Config and loading data

In [3]:
# Notebook-wide settings; the following cells read these names.
id = 0            # device index, passed as device_id to ms.set_context (NOTE: shadows the builtin id())
dist = False      # when True, the setup cell configures data-parallel execution
enhance = False   # not read by any code visible in this notebook
epoch = 1         # not read by any code visible in this notebook
batch = 1         # rebound by the setup cell to the held-out batch returned by read_h5ad
fold = 1          # index of the held-out batch, passed to read_h5ad
data = 'PBMC'     # dataset name

datapath = '../dataset/PBMC.h5ad'                      # location of the dataset on disk
model_path = "../checkpoint/PBMC-1-30_1055.ckpt"       # trained checkpoint to load

In [4]:
# Execution context: graph mode on the selected GPU.
ms.set_context(
    device_target='GPU', 
    mode=ms.GRAPH_MODE,
    device_id=id,
)
cfg=Config()
rank_id = None
rank_size = None
cfg.nonz_len=2048
if dist:
    # Data-parallel run: every process handles its own shard of the data.
    ms.set_auto_parallel_context(
        parallel_mode=ms.ParallelMode.DATA_PARALLEL, 
        parameter_broadcast=True,
        gradients_mean=True,
        comm_fusion={"allreduce": {"mode": "auto", "config": None}},
    )
    init()
    rank_id = get_rank()
    rank_size = get_group_size()
ms.set_seed(0)
# `datapath` may already name the .h5ad file itself (as configured above) or
# only the directory holding `<data>.h5ad`. Appending the file name
# unconditionally produced '../dataset/PBMC.h5ad/PBMC.h5ad'.
h5ad_path=datapath if os.path.isfile(datapath) else os.path.join(datapath,f"{data}.h5ad")
adata,batch=read_h5ad(h5ad_path,fold)
trainset=SCrna(adata,batch,mode='train')
testset=SCrna(adata,batch,mode='test')
# Architecture settings; they have to match the checkpoint loaded below.
cfg.enc_dims=1536
cfg.enc_nlayers=2
cfg.enc_num_heads=48
cfg.recompute=False
cfg.num_cls=trainset.cls
cfg.pad_zero=True
cut=None
prep=Prepare(
    cfg.nonz_len,pad=1,mask_ratio=0,
    dw=False,zero_len=None,
    cut=cut,random=False
)
# One cell per batch, in dataset order; the attention analysis below relies
# on a deterministic order.
test_loader=build_dataset(
    testset,
    prep,
    1,
    drop=True,
    shuffle=False,
    rank_size=rank_size,
    rank_id=rank_id,
)

origin shape: (18868, 6998) 7
filtered shape: (18868, 6998) 7
train adata: (16893, 5936) 7 test
test adata: (1975, 5936) 7 test


#### Load trained Model

In [None]:
# Rebuild the network with the configuration set above, then restore the
# trained weights from the checkpoint file.
backbone = Backbone(len(trainset.geneset), cfg, shard=None)
model = Net(backbone, cfg, shard=None)

model_path = "../checkpoint/PBMC-1-30_1055.ckpt"
checkpoint_params = ms.load_checkpoint(model_path)
# Left as the last expression so the notebook displays the returned lists of
# names that could not be matched.
ms.load_param_into_net(model, checkpoint_params)

([],
 ['global_step',
  'learning_rate',
  'beta1_power',
  'beta2_power',
  'moment1.cluster_emb',
  'moment1.query_layer.0.attn1.q_proj.weight',
  'moment1.query_layer.0.attn1.k_proj.weight',
  'moment1.query_layer.0.attn1.v_proj.weight',
  'moment1.query_layer.0.attn1.u_proj.weight',
  'moment1.query_layer.0.attn1.o_proj.weight',
  'moment1.query_layer.0.attn2.q_proj.weight',
  'moment1.query_layer.0.attn2.k_proj.weight',
  'moment1.query_layer.0.attn2.v_proj.weight',
  'moment1.query_layer.0.attn2.u_proj.weight',
  'moment1.query_layer.0.attn2.o_proj.weight',
  'moment1.query_layer.0.ffn.u_proj.weight',
  'moment1.query_layer.0.ffn.v_proj.weight',
  'moment1.query_layer.0.ffn.o_proj.weight',
  'moment1.query_layer.0.post_norm1.gamma',
  'moment1.query_layer.0.post_norm1.beta',
  'moment1.query_layer.0.post_norm2.gamma',
  'moment1.query_layer.0.post_norm2.beta',
  'moment1.query_layer.0.post_norm3.gamma',
  'moment1.query_layer.0.post_norm3.beta',
  'moment1.query_layer.1.attn1.q_p

##### Retrieve CellFM's attention weights

In [None]:
from tqdm import tqdm
from collections import defaultdict

# attention_map[cell_type][gene_id] -> list with one rank-normalised attention
# score for every test cell of that type in which the gene occurs.
attention_map = {i: defaultdict(list) for i in range(cfg.num_cls)}
attention_genes = {i: [] for i in range(cfg.num_cls)}

pred_list = []
label_list = []

# Loop invariants: operators and the cross-attention block that is inspected.
reshape = P.Reshape()
reduce_mean = ops.ReduceMean(keep_dims=False)
last_query = model.query_layer[-1]
attn2 = last_query.attn2
n_heads = cfg.enc_num_heads   # was hard-coded as 48
rank_norm = cfg.nonz_len      # was hard-coded as 2048

# The loop variable is not called `data`, so the global dataset name set in
# the configuration cell is no longer overwritten.
for batch_data in tqdm(test_loader, total=len(test_loader)):
    nonz_data,nonz_gene,zero_idx,label = batch_data
    emb = model.extractor(nonz_data, nonz_gene, zero_idx)
    cls_token,expr_emb=emb[:,0],emb[:,1:]
    b,l,d=expr_emb.shape
    attn_mask=model.slice(zero_idx,(0,1),(-1,-1))
    key_mask=attn_mask.reshape(b,1,-1,1)
    clst_emb=model.cat1((cls_token.reshape(-1,1,d),model.tile(model.cluster_emb.astype(cls_token.dtype).reshape(1,-1,d),(b,1,1))))
    # Run every query layer except the last one as usual.
    for query in model.query_layer[:-1]:
        clst_emb=query(clst_emb,y=expr_emb,attn_mask=key_mask)

    # Redo the query/key part of the last cross-attention by hand to get at
    # the scores. The value and gate projections are not needed for that.
    q,k,_,_ = attn2.qkvu_compute(clst_emb, expr_emb)
    l1=q.shape[1]
    l2=k.shape[1]
    Q = attn2.transpose1(reshape(q,(-1,l1,n_heads,attn2.head_dims)),(0,2,1,3))
    K = attn2.transpose1(reshape(k,(-1,l2,n_heads,attn2.head_dims)),(0,2,1,3))
    Q=attn2.kernelQ(Q)
    K=attn2.kernelK(K)
    K=attn2.maskmul(K,key_mask)
    Q=attn2.div(Q,attn2.scale)
    K=attn2.div(K,attn2.scale)

    attn_scores = Q @ K.permute(0,1,3,2)
    # Query position 0 is the CLS token.
    attn_scores = attn_scores[:, :, 0, :]

    # argsort of argsort turns the scores into ranks per head; the ranks are
    # normalised and then averaged over the heads.
    order = ms.ops.argsort(attn_scores, axis=-1)
    attn_scores = ms.ops.argsort(order, axis=-1).astype(ms.float32) / rank_norm
    attn_scores = reduce_mean(attn_scores, axis=1)
    scores_np = attn_scores.asnumpy()

    clst_emb=last_query(clst_emb,y=expr_emb,attn_mask=key_mask)
    cls_token,cluster=clst_emb[:,0],clst_emb[:,1:]
    labelpred1=model.classifier(cluster).reshape(b,-1)

    pred_list.append(labelpred1)
    label_list.append(label)

    labels_np = label.asnumpy().reshape(-1)
    genes_np = nonz_gene.asnumpy()
    valid_np = attn_mask.asnumpy().reshape(b,-1) != 0
    # Works for any batch size, not only for the batch size 1 used here.
    for sample in range(b):
        per_gene = attention_map[int(labels_np[sample])]
        for gene_idx, score, valid in zip(genes_np[sample], scores_np[sample], valid_np[sample]):
            # Masked positions have their keys zeroed above, i.e. they are not
            # genes of this cell (presumably padding); keep them out of the
            # per-gene statistics.
            if valid:
                per_gene[int(gene_idx)].append(score)


pred_list = ops.operations.Concat(axis=0)(pred_list)
label_list = ops.operations.Concat(axis=0)(label_list)



  0%|          | 0/1975 [00:00<?, ?it/s][ERROR] CORE(1157347,7f23bf505740,python):2024-12-12-14:42:52.324.374 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_1157347/1715952173.py]
[ERROR] CORE(1157347,7f23bf505740,python):2024-12-12-14:42:52.324.414 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_1157347/1715952173.py]
[ERROR] CORE(1157347,7f23bf505740,python):2024-12-12-14:42:52.325.316 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_1157347/1715952173.py]
100%|██████████| 1975/1975 [04:14<00:00,  7.75it/s]


#### Select genes with high attention weights

In [7]:

# Reduce every gene's list of scores to its mean, then rank the genes of each
# cell type from the highest to the lowest mean score. Afterwards
# attention_map[i] is a list of (gene_id, mean_score) pairs.
for i in range(cfg.num_cls):
    per_gene = attention_map[i]
    if not isinstance(per_gene, dict):
        # Already ranked by an earlier run of this cell. Running the cell a
        # second time used to raise because the dicts had become lists.
        continue
    mean_score = {
        gene_id: sum(scores) / len(scores)
        for gene_id, scores in per_gene.items()
    }
    attention_map[i] = sorted(mean_score.items(), key=lambda x:x[1], reverse=True)


# attention_map[0]

In [None]:
import pandas as pd
df = pd.read_csv('../csv/expand_gene_info.csv',index_col=0,header=0)

# gene id (1-based row number in the table) -> gene name
geneset={i+1:j for i,j in enumerate(df.index)}
# gene name -> value of the 'feature' column (e.g. 'protein coding')
feature = dict(zip(df.index, df['feature']))

# Label codes in attention_map were produced by `testset`, so its own
# category table is the one guaranteed to decode them. The categories of
# `adata` only match as long as both objects hold exactly the same cell types.
id2label=testset.id2label


In [9]:
# For every cell type, print the `topk` genes with the highest mean attention
# rank, split into protein-coding genes and all other genes.
topk = 100
for i in range(cfg.num_cls):
    print(id2label[i])
    # Ids without an entry in the gene table (e.g. a padding id) are skipped
    # instead of raising KeyError in the name lookup.
    ranked_ids = [gene_id for gene_id, _ in attention_map[i] if gene_id in geneset]
    topk_gene = [geneset[gene_id] for gene_id in ranked_ids[:topk]]
    encoded_gene = [gene for gene in topk_gene if feature[gene] == 'protein coding']
    non_encoded_gene = [gene for gene in topk_gene if feature[gene] != 'protein coding']
    print('encoded gene:', encoded_gene)
    print('non encoded gene:', non_encoded_gene)
    print('--------------')
    

CD4T
encoded gene: ['CDO1', 'FAM50B', 'RAPGEF6', 'CCDC51', 'MAP9', 'PRR12', 'SMNDC1', 'PMM2', 'DPM3', 'ZNF599', 'TMTC3', 'CNPY4', 'F5', 'PIAS2', 'ZNF771', 'PARN', 'IL18R1', 'PCP2', 'WARS2', 'SHPRH', 'SEC14L2', 'TPRG1', 'ZNF580', 'PGAM5', 'EPS15L1', 'MAGEE1', 'RMDN3', 'CBL', 'NMB', 'HERC2', 'FBLN5', 'MYBL1', 'GJB6', 'NIFK', 'TBL1X', 'HSPH1', 'TXLNG', 'RPS6KB1', 'RAB11FIP4', 'SMC1A', 'TOMM34', 'CENPE', 'CLEC11A', 'NR2C2AP', 'LTN1', 'ZNF256', 'COQ7', 'CRY2', 'GNB1L', 'SCCPDH', 'CEACAM21', 'ECE2', 'HIRIP3', 'ZBTB48', 'TRIM41', 'XXYLT1', 'ABCB8', 'ZNF700', 'PSPH', 'TUBB1', 'TMEM129', 'STK19', 'UBASH3A', 'ZBTB2', 'ASPSCR1', 'SPP1', 'PJA1', 'PDLIM2', 'DEF6', 'PEX6', 'FLT1', 'PCIF1', 'GDPD5', 'ATG9B', 'ST6GALNAC1', 'PHF6', 'MLLT11', 'ATG9A', 'ELOVL4', 'SFMBT2', 'PICK1', 'ZNF85', 'LIN9', 'NQO1', 'ATR', 'SEL1L', 'NCKIPSD', 'FCGR3A', 'ZMYM2', 'PRC1', 'RCC1', 'HBG2', 'STXBP4', 'RORC', 'MAP2K3']
non encoded gene: ['MIR29A', 'IQCH-AS1', 'CD27-AS1', 'ENTPD1-AS1', 'ZNF337-AS1']
--------------
CD14+Mon