# Geneformer zero-shot cell type classification with a k-nearest-neighbor (kNN) classifier

In [1]:
import os

# Index (or indices) of the CUDA device(s) this notebook may use.
GPU_NUMBER = [0]  # set cuda number to use

# Expose only the chosen device(s) and turn on NCCL logging; this is set
# before torch is imported in a later cell.
os.environ.update(
    CUDA_VISIBLE_DEVICES=",".join(map(str, GPU_NUMBER)),
    NCCL_DEBUG="INFO",
)

In [None]:
# imports
from collections import Counter
import datetime
import pickle
import subprocess
import seaborn as sns; sns.set()
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from transformers import BertConfig, BertForSequenceClassification, BertForMaskedLM, BertModel
from transformers import Trainer
from transformers.training_args import TrainingArguments

from geneformer import DataCollatorForCellClassification
import sys
import re
import numpy as np
import torch


from sklearn.neighbors import KNeighborsClassifier
import json

In [None]:
# load cell type dataset (includes all tissues)

# select task type ("ctc": cell type classification, "isp": in silico perturbation)
f_type = "ctc"

# dataset_name (xxx.dataset path)
dataset_name = "/path/to/your/dataset/to/analysis/xxx.dataset"

# For each task: (label column the rest of the notebook expects,
#                 column in your dataset to rename to it if it is missing).
# The second entry of each pair is a placeholder -- edit it for your dataset.
label_columns = {
    "isp": ("disease", "column name in diseases infomation"),
    "ctc": ("cell_type", "column name in cell types infomation"),
}

# Validate the task type before spending time loading the dataset.
if f_type not in label_columns:
    print("error: select task type (ctc or isp)")
    sys.exit(1)

train_dataset = load_from_disk(dataset_name)

# check the label column exists; otherwise rename the dataset's own column to it
label_column, source_column = label_columns[f_type]
try:
    print(np.unique(train_dataset[label_column]))
except KeyError as e:
    print("KeyError: {}".format(e))
    print(f"changing to {label_column}")
    train_dataset = train_dataset.rename_column(source_column, label_column)
    print("change finished")
    print(np.unique(train_dataset[label_column]))

print(train_dataset)

In [None]:
import glob
import os

from tqdm.notebook import tqdm

# Delete "cache*" files inside the saved dataset directory (presumably left by
# earlier `datasets` filter/map runs) so this run starts from a clean state.
rmfiles = glob.glob(dataset_name + "/cache*")
if not rmfiles:
    print("not exist cache files")
else:
    for rmfile in tqdm(rmfiles, desc='remove files loop'):
        os.remove(rmfile)
    print("removed cache files!!")

In [None]:
import glob
import os

from tqdm.notebook import tqdm

# Delete "tmp*" files inside the saved dataset directory so stale temporary
# files from earlier runs do not accumulate.
rmfiles = glob.glob(dataset_name + "/tmp*")
if not rmfiles:
    print("not exist tmp... files")
else:
    for rmfile in tqdm(rmfiles, desc='remove files loop'):
        os.remove(rmfile)
    print("removed tmp... files!!")

In [None]:


# Parallel per-organ lists: train split, eval split, organ name, label-id map.
dataset_list = []
evalset_list = []
organ_list = []
target_dict_list = []


for organ in Counter(train_dataset["organ_major"]).keys():
    # collect list of tissues (immune and bone marrow are included together)
    if organ == "bone_marrow":
        # bone_marrow is never evaluated on its own; its cells are only used
        # as part of the "immune" group below.
        continue
    elif organ == "immune":
        organ_ids = ["immune", "bone_marrow"]
    else:
        organ_ids = [organ]
    organ_list += [organ]

    print(organ)

    # filter datasets for given organ
    def if_organ(example):
        return example["organ_major"] in organ_ids
    trainset_organ = train_dataset.filter(if_organ, num_proc=16)

    # per scDeepsort published method, drop cell types representing <0.5% of cells
    celltype_counter = Counter(trainset_organ["cell_type"])
    total_cells = sum(celltype_counter.values())
    # set: this is tested once per example inside the filter below
    cells_to_keep = {k for k, v in celltype_counter.items() if v > 0.005 * total_cells}
    def if_not_rare_celltype(example):
        return example["cell_type"] in cells_to_keep
    trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=16)

    # shuffle datasets and rename columns
    trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)
    trainset_organ_shuffled = trainset_organ_shuffled.rename_column("cell_type", "label")
    trainset_organ_shuffled = trainset_organ_shuffled.remove_columns("organ_major")

    # create dictionary of cell types : label ids (in first-seen order)
    target_names = list(Counter(trainset_organ_shuffled["label"]).keys())
    target_name_id_dict = {name: i for i, name in enumerate(target_names)}
    target_dict_list += [target_name_id_dict]

    # change labels to numerical ids
    def classes_to_ids(example):
        example["label"] = target_name_id_dict[example["label"]]
        return example
    labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)

    # create 80/20 train/eval splits (data is already shuffled above)
    n_train = round(len(labeled_trainset) * 0.8)
    labeled_train_split = labeled_trainset.select(range(n_train))
    labeled_eval_split = labeled_trainset.select(range(n_train, len(labeled_trainset)))

    # keep only eval cells whose label also occurs in the training split
    trained_labels = set(labeled_train_split["label"])
    def if_trained_label(example):
        return example["label"] in trained_labels
    labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)

    dataset_list += [labeled_train_split]
    evalset_list += [labeled_eval_split_subset]

In [None]:
# Key each per-organ list from the previous cell by organ name.
trainset_dict, traintargetdict_dict, evalset_dict = (
    dict(zip(organ_list, per_organ))
    for per_organ in (dataset_list, target_dict_list, evalset_list)
)

# Show train splits, label-id maps, then eval splits.
for summary in (trainset_dict, traintargetdict_dict, evalset_dict):
    print(summary)


In [8]:
def compute_metrics(preds, labels):
    """Score predicted class labels against the ground truth.

    Args:
        preds: predicted class labels.
        labels: ground-truth class labels.

    Returns:
        dict with 'accuracy' and the macro-averaged 'macro_precision',
        'macro_recall' and 'macro_f1'.
    """
    macro = dict(average='macro')
    return {
        'accuracy': accuracy_score(labels, preds),
        'macro_precision': precision_score(labels, preds, **macro),
        'macro_recall': recall_score(labels, preds, **macro),
        'macro_f1': f1_score(labels, preds, **macro),
    }

In [None]:
# ---- model parameters ----
# Maximum number of gene tokens per cell (2**11).
max_input_size = 2048

# ---- run settings ----
# No training happens in this zero-shot notebook; most of these values only
# label the result directory name built in the zero-shot loop.
num_gpus = 1                  # number of GPUs
num_proc = 16                 # number of CPU cores
geneformer_batch_size = 12    # batch size for training and eval
max_lr = 5e-5                 # max learning rate
lr_schedule_fn = "cosine"     # learning schedule: "polynomial", "linear" or "cosine"
warmup_steps = 500            # warmup steps
epochs = 10                   # number of epochs
optimizer = "adamW"           # optimizer
freeze_layers = 0             # how many pretrained layers to freeze



In [None]:

def embed_cells(model, dataset, desc):
    """Embed every cell of *dataset* with *model*, one cell per forward pass.

    Args:
        model: encoder already on CUDA and in eval mode.
        dataset: dataset with "input_ids" and "label" columns.
        desc: label for the progress bar.

    Returns:
        (features, labels): float array of shape (n_cells, hidden_size) with
        each cell's pooled output, and an array of the cells' labels.
    """
    # Read each column exactly once; indexing a Dataset column may materialise
    # the whole column, so the old zip(tqdm(col), col) pattern built it twice.
    input_ids_column = dataset["input_ids"]
    labels = np.array(dataset["label"])

    pooled_embeddings = []
    for input_ids in tqdm(input_ids_column, desc=desc):
        inputs = torch.tensor(input_ids, dtype=torch.long, device="cuda").unsqueeze(0)
        with torch.no_grad():
            outputs = model(inputs)
        # outputs[1] is the pooled output (tuple or ModelOutput alike);
        # shape (1, hidden_size) for this single-cell batch.
        pooled_embeddings.append(outputs[1].cpu().numpy())

    if pooled_embeddings:
        features = np.concatenate(pooled_embeddings, axis=0)
    else:
        features = np.empty((0, model.config.hidden_size), dtype=np.float32)
    return features, labels


pretrain_model = "your mouse-Geneformer name"

# NOTE(review): seed torch so any weights missing from the checkpoint are
# initialised reproducibly. If the checkpoint was saved from BertForMaskedLM it
# likely has no pooler weights, so the pooled output would come from a randomly
# initialised pooler -- check transformers' load warnings and consider
# mean-pooling the last hidden state instead.
torch.manual_seed(42)

# The pretrained encoder is identical for every organ, so load it once.
model = BertModel.from_pretrained("/path/to/mouse-Geneformer/model/{}/models/".format(pretrain_model),
                                  output_attentions=False,
                                  output_hidden_states=False).to("cuda")
model.eval()

for organ in organ_list:
    print("=" * 50)
    print(organ)
    organ_trainset = trainset_dict[organ]
    organ_evalset = evalset_dict[organ]
    organ_label_dict = traintargetdict_dict[organ]
    print(organ_label_dict)

    # kNN needs at least one training and one evaluation cell
    if len(organ_trainset) == 0 or len(organ_evalset) == 0:
        print("skip {}: empty train or eval split".format(organ))
        continue

    # define output directory path
    current_date = datetime.datetime.now()
    datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
    output_dir = f"/path/to/your/zero-shot/directory/to/save/result/{datestamp}_mouse-geneformer_zero-shot-CellClassifier_{organ}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}_CTC-{organ}/"
    # An existing directory is reused: a rerun on the same day with the same
    # settings overwrites the previous result file.
    os.makedirs(output_dir, exist_ok=True)

    # get feature vector of each cell in train and eval data
    train_features, train_labels = embed_cells(model, organ_trainset, "train data loop")
    test_features, test_labels = embed_cells(model, organ_evalset, "test data loop")

    # cell classification using k-nearest neighbor methods
    # (k capped at the training-set size so very small organs do not crash)
    knn = KNeighborsClassifier(n_neighbors=min(5, len(train_labels)), metric='cosine')
    knn.fit(train_features, train_labels)

    test_pred_labels = knn.predict(test_features)

    result_dict = compute_metrics(test_pred_labels, test_labels)

    print("accuracy: {}, precision: {}, recall: {}, f1_score: {}".format(result_dict["accuracy"], result_dict["macro_precision"], result_dict["macro_recall"], result_dict["macro_f1"]))

    with open(output_dir + organ + "_zero-shot-knn_result.json", "w") as f:
        json.dump(result_dict, f)
    print("saved zero-shot classification result!")

print("finish")

In [None]:

# Delete "cache*" files written into the dataset directory during this run;
# report honestly when there was nothing to remove.
rmfiles = glob.glob(dataset_name + "/cache*")
if not rmfiles:
    print("not exist cache files")
else:
    for rmfile in tqdm(rmfiles, desc='remove files loop'):
        os.remove(rmfile)
    print("removed cache files!!")


In [None]:
import glob
import json
import os
import re
import sys

import numpy as np
import pandas as pd

# Directory holding one sub-directory per zero-shot run.
ZS_RESULT_ROOT = "/path/to/your/zero-shot/directory/to/save/result/"

# Run directories are named "..._CellClassifier_{organ}_L{max_input_size}_...".
ORGAN_PATTERN = re.compile(r"CellClassifier_(?P<organ>.*?)_L\d+")


def read_zero_shot_result(result_dir):
    """Read the kNN result JSON stored in one run directory.

    Args:
        result_dir: path of a single run directory.

    Returns:
        (organ_name, result) where organ_name is parsed from the path (None if
        the path does not follow the CellClassifier_{organ}_L{n} naming) and
        result is the decoded JSON; None if the directory has no JSON file.
    """
    json_paths = sorted(glob.glob(os.path.join(result_dir, "*.json")))
    if not json_paths:
        return None
    json_path = json_paths[0]
    match = ORGAN_PATTERN.search(json_path)
    organ_name = match.group("organ") if match else None
    with open(json_path, "r", encoding="utf-8") as json_file:
        return organ_name, json.load(json_file)


zs_result_dirs = sorted(glob.glob(os.path.join(ZS_RESULT_ROOT, "*")))
print(len(zs_result_dirs))

# set each organ name to watch zero-shot classification result
# NOTE(review): placeholder list; it is not used below -- every run is printed.
organ = ["set", "each", "organ", "name", "to", "watch", "zero-shot", "classification", "result"]
for zs_result_dir in zs_result_dirs:
    print("=" * 100)
    entry = read_zero_shot_result(zs_result_dir)
    if entry is None:
        print("no result JSON in {}".format(zs_result_dir))
        continue
    organ_name, result = entry
    print(organ_name)
    print(result)
    