In [1]:
import os
# Pin GPU enumeration to PCI bus order and expose only device "3" to CUDA.
# These must be set before any CUDA-using library initialises the driver.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="3",
)

In [2]:
import torch
import os
from collections import Counter
import sklearn
import psutil
import numpy as np
import time
from loguru import logger

import pandas as pd
# import scanpy as sc
from tqdm import tqdm
import random
import sys
import multiprocessing
import anndata as ad
import gc
import pickle

# Model folder: appended to sys.path so the project packages can be imported.
root_path = os.path.abspath('#model_folder')
sys.path.append(root_path)  # root_path is already absolute

# Data folder layout: <data_root_path>/<dataset>/ with scratch files in tmp_files/.
dataset = 'ECA_GO'
data_root_path = '#dataset_path/'
dataPath = os.path.join(data_root_path, dataset)
assert os.path.exists(dataPath)
tmp_dataPath = os.path.join(dataPath, 'tmp_files')

In [3]:
from utils.hf_tokenizer import cellGenesisTokenizer

In [4]:
# NOTE(review): torch.load unpickles by default -- only load meta_info.pt from a trusted source.
meta_info = torch.load(
    os.path.join(dataPath, 'meta_info.pt')
)

# Number of bins passed to the expression-binning helper defined further down.
n_express_level = 10

# The tokenizer is built from the token set stored alongside the dataset.
chars = meta_info['token_set']
tokenizer = cellGenesisTokenizer(chars)

In [5]:
# Pre-computed cell index splits.
cell_index_gene_mixed = torch.load(
    os.path.join(dataPath, 'cell_index_gene_mixed.pt')
)
cell_train_index_after_amp_mix = cell_index_gene_mixed['train_index_after_amp_mix']
cell_index_test = cell_index_gene_mixed['test_index_after_nonamp_mix']

# Read-only handle on the original expression data; rows are fetched on demand via h5.select.
h5 = pd.HDFStore(os.path.join(tmp_dataPath, 'cellhvg.h5'), mode='r')

# Corresponding metadata; the first CSV column becomes the index.
meta_data = pd.read_csv(
    os.path.join(dataPath, 'meta_data.csv'),
    index_col=0,
    low_memory=False,
)

In [15]:
# Task name -> metadata columns used as prediction targets for that task.
# Insertion order matters: tasks are addressed by position (task_seed) elsewhere.
task_pool = {
    'porg': ['organ', 'region'],
    'pct': ['MCT', 'cell_type'],
}

# Task name -> special token appended after the gene list for that task.
task_spt = {
    'porg': '<SPToken3>',
    'pct': '<PCT>',
}

In [6]:
def prepare_cell_generation_prompt_by_idx_for_test(idx, h5, meta_data, n_express_level,task_pool, task_seed = None):
    """Collect the expressed genes, metadata targets and binned expression of one cell.

    Parameters
    ----------
    idx
        Label of the cell. Used in the ``where`` clause of ``h5.select`` and as
        the row label in both the selected frame and ``meta_data``.
    h5
        Object exposing ``select('data', where=...)`` that returns a DataFrame
        containing row ``idx`` (in this notebook a ``pd.HDFStore``).
    meta_data
        DataFrame indexed by cell label that contains the columns listed in
        ``task_pool``.
    n_express_level
        Number of equal-width bins between 0 and the cell's maximum expression.
    task_pool
        Dict mapping a task name to the list of metadata columns for that task.
    task_seed
        Position of the task inside ``task_pool``. When ``None`` a position is
        drawn uniformly with ``torch.randint``.

    Returns
    -------
    tuple
        ``(expressed_genes, metadata_list, binned_expr, task_seed)`` where
        ``expressed_genes`` lists the column labels whose value is > 0,
        ``metadata_list`` holds the metadata values of the chosen task, and
        ``binned_expr`` contains one level in ``1..n_express_level`` per
        expressed gene followed by ``len(metadata_list) + 1`` zeros.

    Notes
    -----
    A cell without any value > 0 yields an empty gene list and a ``binned_expr``
    made only of the padding zeros, instead of failing inside ``np.max`` on an
    empty array.
    """
    cellhvg = h5.select('data', where=f"index={repr(idx)}")
    cell_row = cellhvg.loc[idx]
    expressed = cell_row[cell_row > 0]
    expressed_genes = expressed.index.tolist()

    if task_seed is None:
        task_seed = torch.randint(0, len(task_pool), (1,)).item()
    task_cols = list(task_pool.values())[task_seed]
    metadata_list = meta_data.loc[idx, task_cols].tolist()

    expr_values = expressed.to_numpy()
    if expr_values.size:
        # right=True puts a value v with bins[i-1] < v <= bins[i] into level i,
        # so positive values land in 1..n_express_level and the maximum gets the top level.
        bins = np.linspace(0, expr_values.max(), n_express_level + 1)
        gene_levels = np.digitize(expr_values, bins, right=True)
    else:
        # Nothing expressed: there is no maximum to bin against.
        gene_levels = np.zeros(0, dtype=np.intp)

    # Level 0 for the non-gene positions: one per metadata value plus one extra
    # slot (process_data_for_test inserts a single task token after the genes).
    binned_expr = np.append(gene_levels, [0] * (len(metadata_list) + 1))
    return expressed_genes, metadata_list, binned_expr, task_seed

def process_data_for_test(idx,h5,meta_data,task_pool, task_seed=1):
    """Encode one cell as ``genes + task token`` (prefix) followed by its metadata (suffix).

    Relies on the notebook-level names ``n_express_level``, ``task_spt`` and
    ``tokenizer`` in addition to its arguments.

    Parameters
    ----------
    idx
        Label of the cell to encode.
    h5, meta_data, task_pool
        Forwarded unchanged to ``prepare_cell_generation_prompt_by_idx_for_test``.
    task_seed
        Position of the task inside ``task_pool``. Defaults to ``1``, the value
        that used to be hard-coded here; ``None`` lets the helper draw one at random.

    Returns
    -------
    tuple
        ``(token_ids, prefix_len, data_len, binned_expr)`` where ``token_ids``
        is the encoded prefix concatenated with the encoded suffix,
        ``prefix_len`` is the encoded prefix length minus one, ``data_len`` is
        the combined length of both encodings, and ``binned_expr`` is the
        expression-level array returned by the helper.
    """
    expressed_genes, metadata_list, binned_expr, task_seed = prepare_cell_generation_prompt_by_idx_for_test(
        idx, h5, meta_data, n_express_level, task_pool, task_seed=task_seed
    )

    # The task's special token closes the prefix, right after the gene tokens.
    task_name = list(task_pool.keys())[task_seed]
    prefix = expressed_genes + [task_spt[task_name]]
    ec_prefix = tokenizer.encode(prefix)
    ec_suffix = tokenizer.encode(metadata_list)

    # NOTE(review): presumably the encoder emits one id per input token, which
    # would make prefix_len the number of gene tokens -- confirm against the tokenizer.
    prefix_len = len(ec_prefix) - 1
    data_len = len(ec_prefix) + len(ec_suffix)

    return (ec_prefix + ec_suffix, prefix_len, data_len, binned_expr)

In [7]:
from model.model import GPTConfig, cellGPTModel
import torch.nn.functional as F

In [8]:
from tools import eccosimple_ele

In [9]:
# Checkpoint handed to eccosimple_ele.from_pretrained below.
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
ckp_path = '/nfs/public/cell_gpt_data/dataHub/datasets/datasets/ECA_GO/model_hub/ckpt245000.pt'

In [10]:
# Load the checkpoint into the attribution wrapper used for generation below.
lm2 = eccosimple_ele.from_pretrained(
    ckp_path=ckp_path,
    meta_info=meta_info,
    verbose=False,
)

[32m2023-11-04 01:00:17.316[0m | [1mINFO    [0m | [36mmodel.model[0m:[36m__init__[0m:[36m222[0m - [1mCurrent ele mode is: 1[0m
[32m2023-11-04 01:00:23.284[0m | [1mINFO    [0m | [36mmodel.model[0m:[36m__init__[0m:[36m261[0m - [1mnumber of parameters: 368.80M[0m


Using GPU


In [11]:
def get_top_k(scores, k=10):
    """Return the positions of the ``k`` largest entries of ``scores``, best first.

    The positions come from an ascending ``np.argsort`` that is reversed and
    truncated, so fewer than ``k`` positions are returned when ``scores`` is shorter.
    """
    descending = np.argsort(scores)[::-1]
    return descending[:k]

## Find cell-type-specific markers in Heart

In [12]:
# Metadata rows whose organ is 'Heart'.
is_heart = meta_data['organ'] == 'Heart'
meta_sub = meta_data.loc[is_heart]

In [24]:
# For every Heart cell type: sample cells, let the model predict the cell type,
# and for correctly predicted cells record the genes with the highest
# grad-x-input attribution. The most frequently recorded genes per cell type
# are kept as that type's marker list.
N_CELLS_PER_TYPE = 500      # upper bound on cells sampled per cell type
TOP_GENES_PER_CELL = 5      # highest-attribution genes recorded per cell
TOP_MARKERS_PER_TYPE = 10   # most frequent genes kept per cell type
SAMPLE_SEED = 0             # fixed seed so the sampled cells are reproducible

# Dedicated generator: reproducible sampling without touching the global `random` state.
sampler = random.Random(SAMPLE_SEED)

labels_ = []
cts_ = []
for ct in meta_sub.cell_type.unique():
    ct_cell_ids = meta_sub[meta_sub['cell_type'] == ct].index.tolist()
    # random.sample raises ValueError when asked for more items than exist,
    # so cap the request at the number of available cells.
    meta_sample_index = sampler.sample(ct_cell_ids, min(N_CELLS_PER_TYPE, len(ct_cell_ids)))
    if not meta_sample_index:
        continue

    target_ct_idx = tokenizer.encode([ct])[0]  # same for every cell of this type
    n_corr = 0
    gene_votes = Counter()
    for idx in tqdm(meta_sample_index):
        prefix_sufix, prefix_len, _, expr = process_data_for_test(idx, h5, meta_data, task_pool)
        prompt_len = prefix_len + 1  # prompt = everything up to and including the last prefix token
        prompt_ele = torch.tensor(expr[:prompt_len]).unsqueeze(0)
        # NOTE(review): the meaning of generate=1560 is defined by lm2.generate -- confirm.
        output_ = lm2.generate(tokenizer.decode(prefix_sufix[:prompt_len]),
                               expr_level=prompt_ele,
                               generate=1560,
                               top_k=1,
                               attribution=['grad_x_input'])
        # NOTE(review): position -3 of the returned ids is read as the predicted
        # cell-type token; this depends on the output layout of lm2.generate -- confirm.
        pred_ct_idx = int(output_.token_ids.cpu()[0][-3])
        if pred_ct_idx != target_ct_idx:
            continue

        n_corr += 1
        # Drop the last attribution entry (presumably the task token -- confirm)
        # before ranking the remaining positions.
        attri_scores = output_.attribution['grad_x_input'][0][:-1]
        top_positions = get_top_k(attri_scores, k=TOP_GENES_PER_CELL)
        gene_names = tokenizer.convert_ids_to_tokens(np.array(prefix_sufix)[top_positions])
        gene_votes.update(gene_names)

    # Cell types without a single correct prediction contribute nothing.
    if gene_votes:
        labels = tuple(gene for gene, _count in gene_votes.most_common(TOP_MARKERS_PER_TYPE))
        cts_.append(ct)
        labels_.append(labels)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [02:15<00:00,  3.69it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [01:37<00:00,  5.12it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [01:35<00:00,  5.26it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [01:35<00:00,  5.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

In [26]:
# Cell type -> tuple of its most frequent high-attribution genes.
marker_dict = dict(zip(cts_, labels_))

In [27]:
# Persist the marker table next to the notebook.
torch.save(obj=marker_dict, f='./markerdict.pt')