## Fine-Tuning Geneformer for Cell-Type Annotation

In [2]:
import os
# Restrict this process to physical GPU 1 and enable verbose NCCL logging.
# Done before the library imports in the next cell.
GPU_NUMBER = [1]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, GPU_NUMBER))
os.environ["NCCL_DEBUG"] = "INFO"

In [3]:
# imports
from collections import Counter
import datetime
import pickle
import subprocess
import seaborn as sns; sns.set()
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments

from geneformer import DataCollatorForCellClassification

  from .autonotebook import tqdm as notebook_tqdm
  def twobit_to_dna(twobit: int, size: int) -> str:
  def dna_to_twobit(dna: str) -> int:
  def twobit_1hamming(twobit: int, size: int) -> List[int]:


In [4]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

## Prepare training and evaluation datasets

In [5]:
# load cell type dataset (includes all tissues)
# Tokenized multi-organ fine-tuning data; its "cell_type" column becomes the
# classification target in the next cell.
train_dataset=load_from_disk("/nfs/public/cell_gpt_data/Geneformer_4_recomb/dataset/tokenized/MultiOrgan_finetune_train.dataset/")

In [6]:
# Shuffle once with a fixed seed (so the contiguous 80/20 cut below is a random,
# reproducible split) and expose the cell type under the column name "label".
trainset_organ_shuffled = (
    train_dataset
    .shuffle(seed=42)
    .rename_column("cell_type", "label")
)

# create dictionary of cell types : label ids
# Ids follow the order in which each cell type first appears in the shuffled data.
target_names = list(dict.fromkeys(trainset_organ_shuffled["label"]))
target_name_id_dict = {name: idx for idx, name in enumerate(target_names)}

# change labels to numerical ids
def classes_to_ids(example):
    """Replace one example's string cell-type label with its integer id."""
    example["label"] = target_name_id_dict[example["label"]]
    return example

labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)

In [7]:
# create 80/20 train/eval splits; the dataset was shuffled above, so a
# contiguous cut yields a random split
n_train = round(len(labeled_trainset) * 0.8)
labeled_train_split = labeled_trainset.select(range(n_train))
labeled_eval_split = labeled_trainset.select(range(n_train, len(labeled_trainset)))

# filter dataset for cell types in corresponding training set
# (a set gives constant-time membership checks for every eval example)
trained_labels = set(labeled_train_split["label"])
def if_trained_label(example):
    """Keep only eval examples whose label id also occurs in the training split."""
    return example["label"] in trained_labels
labeled_eval_split = labeled_eval_split.filter(if_trained_label, num_proc=16)

In [8]:
# Inspect the training split (first 80% of the shuffled data).
labeled_train_split

Dataset({
    features: ['input_ids', 'label', 'organ', 'length'],
    num_rows: 128000
})

In [9]:
# Inspect the eval split (remaining 20%, restricted to labels seen in training).
labeled_eval_split

Dataset({
    features: ['input_ids', 'label', 'organ', 'length'],
    num_rows: 32000
})

## Fine-Tune With Cell Classification Learning Objective and Quantify Predictive Performance

In [9]:
def compute_metrics(pred):
    """Compute accuracy and macro-averaged F1 for a Trainer evaluation pass.

    Args:
        pred: the prediction object handed over by ``transformers.Trainer``;
            ``pred.predictions`` holds per-class logits (or a tuple whose first
            element is the logits) and ``pred.label_ids`` the true label ids.

    Returns:
        dict with ``accuracy`` and ``macro_f1``; the Trainer prefixes these keys
        (e.g. ``test_accuracy``) when reporting.
    """
    labels = pred.label_ids
    logits = pred.predictions
    # Models configured to return extra outputs hand back a tuple; logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.argmax(-1)
    # calculate accuracy and macro f1 using sklearn's function;
    # zero_division=0 scores never-predicted classes as 0 (sklearn's default value)
    # without emitting UndefinedMetricWarning on every evaluation
    acc = accuracy_score(labels, preds)
    macro_f1 = f1_score(labels, preds, average='macro', zero_division=0)
    return {
      'accuracy': acc,
      'macro_f1': macro_f1
    }

In [10]:
# ---- model ----
# maximum input length in tokens; in this notebook it only appears in the
# output-directory name built by the training cell
max_input_size = 2048  # == 2 ** 11

# ---- training hyperparameters ----
max_lr = 5e-5               # peak learning rate
freeze_layers = 0           # how many pretrained layers to freeze
num_gpus = 1                # number of GPUs (not referenced by later cells)
num_proc = 16               # CPU workers (later cells hard-code 16 instead)
geneformer_batch_size = 12  # per-device batch size for training and eval
lr_schedule_fn = "linear"   # learning-rate schedule
warmup_steps = 500          # warm-up steps
epochs = 10                 # number of training epochs
# optimizer tag; NOTE(review): only used in the output-directory name, since no
# `optim` argument is passed to TrainingArguments. Confirm it matches the default.
optimizer = "adamw"

In [None]:
# set logging steps: log roughly ten times per epoch
logging_steps = round(len(labeled_train_split)/geneformer_batch_size/10)

# reload pretrained model; the classification head is sized to the number of
# cell types in target_name_id_dict
model = BertForSequenceClassification.from_pretrained("/data1/chenyx/Geneformer/",
                                                  num_labels=len(target_name_id_dict),
                                                  output_attentions = False,
                                                  output_hidden_states = False).to("cuda")

# freeze the first `freeze_layers` pretrained encoder layers (0 = train all layers)
if freeze_layers > 0:
    for layer in model.bert.encoder.layer[:freeze_layers]:
        for param in layer.parameters():
            param.requires_grad = False

# define output directory path; the name encodes the run's hyperparameters
current_date = datetime.datetime.now()
datestamp = current_date.strftime("%y%m%d")
output_dir = f"/nfs/public/cell_gpt_data/Geneformer_4_recomb/model/{datestamp}_geneformer_CellClassifier_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"

# ensure not overwriting previously saved model; newer transformers versions
# write model.safetensors by default instead of pytorch_model.bin
for weights_file in ("pytorch_model.bin", "model.safetensors"):
    if os.path.isfile(os.path.join(output_dir, weights_file)):
        raise Exception("Model already saved to this directory.")

# make output directory (and any missing parents) without spawning a shell;
# an existing directory that holds no saved weights is reused
os.makedirs(output_dir, exist_ok=True)

# set training arguments
# NOTE(review): no metric_for_best_model is given, so load_best_model_at_end
# reloads the checkpoint chosen by the Trainer's default criterion (eval loss
# per the transformers docs), not by accuracy / macro F1.
training_args = {
    "learning_rate": max_lr,
    "do_train": True,
    "do_eval": True,
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "logging_steps": logging_steps,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": False,
    "lr_scheduler_type": lr_schedule_fn,
    "warmup_steps": warmup_steps,
    "weight_decay": 0.001,
    "per_device_train_batch_size": geneformer_batch_size,
    "per_device_eval_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    "load_best_model_at_end": True,
    "output_dir": output_dir,
}

training_args_init = TrainingArguments(**training_args)

# create the trainer
trainer = Trainer(
    model=model,
    args=training_args_init,
    data_collator=DataCollatorForCellClassification(),
    train_dataset=labeled_train_split,
    eval_dataset=labeled_eval_split,
    compute_metrics=compute_metrics
)
# train the cell type classifier
trainer.train()
# predict on the eval split and persist raw predictions, metrics and the model
predictions = trainer.predict(labeled_eval_split)
with open(os.path.join(output_dir, "predictions.pickle"), "wb") as fp:
    pickle.dump(predictions, fp)
trainer.save_metrics("eval",predictions.metrics)
trainer.save_model(output_dir)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /data1/chenyx/Geneformer/ and are newly initialized: ['classifier.weight', 'classifier.bias', 'bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
mkdir: cannot create directory ‘/nfs/public/cell_gpt_data/Geneformer_4_recomb/model/231227_geneformer_CellClassifier_L2048_B12_LR5e-05_LSlinear_WU500_E10_Oadamw_F0/’: File exists
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Epoch,Training Loss,Validation Loss,Accuracy,Macro F1
1,0.7185,0.678757,0.791469,0.46452
2,0.5605,0.543218,0.828562,0.578318
3,0.4466,0.501964,0.842875,0.625997
4,0.3894,0.493409,0.848781,0.681758
5,0.3731,0.475856,0.854437,0.697536
6,0.3056,0.480031,0.859031,0.711759
7,0.2768,0.493541,0.859781,0.715033
8,0.2406,0.50927,0.861281,0.725465
9,0.219,0.522918,0.860719,0.730006
10,0.1987,0.525999,0.862625,0.734326


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


## Evaluation on the Held-Out MultiOrgan Test Set

In [14]:
# load the fine-tuned classifier produced by the 231227 training run; the path is
# pinned explicitly so this cell does not depend on the date the training cell ran
finetuned_model_dir = "/nfs/public/cell_gpt_data/Geneformer_4_recomb/model/231227_geneformer_CellClassifier_L2048_B12_LR5e-05_LSlinear_WU500_E10_Oadamw_F0"
model = BertForSequenceClassification.from_pretrained(finetuned_model_dir,
                                                  num_labels=len(target_name_id_dict),
                                                  output_attentions = False,
                                                  output_hidden_states = False).to("cuda")

# Inference-only arguments. This Trainer is used only for `trainer.predict`, so
# the optimizer/scheduler/checkpoint settings of the training run are not repeated.
# group_by_length is deliberately left unset so predictions come back in dataset
# order; later cells pair them positionally with `testset['cell_type']`.
training_args = {
    "per_device_eval_batch_size": geneformer_batch_size,
    "disable_tqdm": False,
    "output_dir": finetuned_model_dir,
}

training_args_init = TrainingArguments(**training_args)

# create the trainer (eval_dataset kept so `trainer.evaluate()` also works)
trainer = Trainer(
    model=model,
    args=training_args_init,
    data_collator=DataCollatorForCellClassification(),
    eval_dataset=labeled_eval_split,
    compute_metrics=compute_metrics
)

In [15]:
validset = load_from_disk("/nfs/public/cell_gpt_data/Geneformer_4_recomb/dataset/tokenized/MultiOrgan_finetune_test.dataset/")
validset = validset.rename_column("cell_type","label")
# classes_to_ids raises a bare KeyError (inside a worker process) for any cell
# type the classifier was not trained on; fail early and name the culprits.
unseen_labels = set(validset.unique("label")) - set(target_name_id_dict)
if unseen_labels:
    raise ValueError(f"Test set contains cell types absent from training: {sorted(unseen_labels)}")
labeled_validset = validset.map(classes_to_ids, num_proc=16)

In [16]:
# Inspect the encoded held-out test set.
labeled_validset

Dataset({
    features: ['input_ids', 'label', 'organ', 'length'],
    num_rows: 40000
})

In [17]:
# Logits, true label ids and metrics for the held-out test set.
valid_predictions = trainer.predict(test_dataset=labeled_validset)

  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


In [10]:
# Invert the cell-type -> id mapping so integer predictions can be decoded to names.
id_target_name_dict = {label_id: name for name, label_id in target_name_id_dict.items()}

In [14]:
import pickle 
# Persist the id -> cell-type mapping so saved predictions can be decoded elsewhere.
with open("./id_target.pkl", mode="wb") as handle:
    pickle.dump(id_target_name_dict, handle)

In [48]:
# Most likely class id per cell: row-wise argmax over the logits matrix.
predictions_ids = valid_predictions.predictions.argmax(axis=1)
label_ids = valid_predictions.label_ids

In [49]:
# Translate integer ids back to cell-type names for readable metrics.
predictions = list(map(id_target_name_dict.__getitem__, predictions_ids))
labels = list(map(id_target_name_dict.__getitem__, label_ids))

In [50]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
def cal_metrics(celltypes_labels, predictions):
    """Score predicted cell types against reference annotations.

    Args:
        celltypes_labels: reference cell-type labels, one per cell.
        predictions: predicted cell-type labels, aligned with ``celltypes_labels``.

    Returns:
        dict with overall accuracy and support-weighted precision, recall and F1,
        keyed ``test/<metric>``. Support-weighted recall always equals accuracy.

    Classes that are never predicted contribute a per-class score of 0. This is
    the same value sklearn uses by default, but without the UndefinedMetricWarning.
    """
    accuracy = accuracy_score(celltypes_labels, predictions)
    precision = precision_score(celltypes_labels, predictions, average="weighted", zero_division=0)
    recall = recall_score(celltypes_labels, predictions, average="weighted", zero_division=0)
    weighted_f1 = f1_score(celltypes_labels, predictions, average="weighted", zero_division=0)

    results = {
            "test/accuracy": accuracy,
            "test/precision": precision,
            "test/recall": recall,
            "test/weighted_f1": weighted_f1,
        }

    return results

In [51]:
# Held-out test-set metrics (same label vocabulary as training).
cal_metrics(celltypes_labels=labels, predictions=predictions)

  _warn_prf(average, modifier, msg_start, len(result))


{'test/accuracy': 0.8584,
 'test/precision': 0.8578970893308969,
 'test/recall': 0.8584,
 'test/weighted_f1': 0.8560748129057907}

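## Test on Bone Marrow (AHCA)

Note: the external datasets are annotated with their own cell-type vocabulary, which does not map one-to-one onto the fine-tuning labels (e.g. `T cell` here vs. `CD4 T cell` / `CD8 T cell` in training). Exact-match metrics therefore understate agreement, so inspect the crosstab as well.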

In [54]:
testset = load_from_disk("/nfs/public/cell_gpt_data/Geneformer_4_recomb/dataset/tokenized/BoneMarrow_AHCA.dataset/")
# Constant dummy label so the data collator/Trainer get a "label" column; the
# metrics returned by predict on this set are therefore meaningless.
testset = testset.add_column('label', [1] * testset.num_rows)

In [55]:
# Logits per cell; label-based metrics reflect only the dummy labels.
test_predictions = trainer.predict(test_dataset=testset)

  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


In [56]:
# Raw PredictionOutput: per-class logits; label_ids and metrics come from the dummy labels.
test_predictions

PredictionOutput(predictions=array([[-0.5361285 ,  3.7151241 , -0.81339383, ..., -1.3169482 ,
        -0.40120322,  0.26952046],
       [-2.5385108 , -1.4970512 , -2.2173736 , ...,  1.1953446 ,
        -0.68179995,  2.1621404 ],
       [-1.7400182 ,  3.6705656 ,  0.2633136 , ..., -1.4319061 ,
        -0.2542736 ,  0.41489103],
       ...,
       [-2.3354623 , -0.91758025, -1.6148753 , ...,  1.0879475 ,
        -0.74409795,  1.7601032 ],
       [ 1.2616403 ,  2.7343066 , -1.419966  , ..., -1.5234289 ,
        -0.99613947, -0.65326726],
       [ 0.7887688 ,  3.2840142 , -1.6650503 , ..., -1.6993607 ,
        -0.92037296, -0.7364452 ]], dtype=float32), label_ids=array([1, 1, 1, ..., 1, 1, 1]), metrics={'test_loss': 9.297154426574707, 'test_accuracy': 0.0, 'test_macro_f1': 0.0, 'test_runtime': 17.0371, 'test_samples_per_second': 189.587, 'test_steps_per_second': 15.848})

In [78]:
# Most likely class per cell, decoded to a name, next to the dataset's own annotation.
predictions_ids = test_predictions.predictions.argmax(axis=1)
predictions = [id_target_name_dict[p] for p in predictions_ids]
labels = np.array(testset['cell_type'])

In [59]:
# Exact-match metrics against the dataset's own annotations.
cal_metrics(celltypes_labels=labels, predictions=predictions)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'test/accuracy': 0.17739938080495357,
 'test/precision': 0.1747138022895817,
 'test/recall': 0.17739938080495357,
 'test/weighted_f1': 0.17593282391556797}

In [82]:
# Rows: predicted (training-vocabulary) cell types; columns: the dataset's own
# annotations. The vocabularies differ (e.g. "T cell" vs "CD4 T cell"), so read
# agreement from this table rather than from exact-match accuracy alone.
pd.crosstab(predictions, labels, rownames=["predicted"], colnames=["annotated"])

col_0,B cell,Erythrocyte,Fibroblast,Macrophage,Monocyte,NK T cell,Plasma cell,Secretory cell,T cell
row_0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
B cell,464,0,0,0,0,0,0,0,0
CD4 T cell,0,0,0,0,0,2,0,0,808
CD8 T cell,2,0,0,0,0,48,0,0,1180
Conventional dendritic cell 2,0,0,0,3,0,0,0,0,0
Erythroblast,0,12,0,0,0,0,0,23,4
Erythrocyte,0,0,0,0,0,0,1,0,0
Erythroid progenitor cell,0,21,0,0,0,0,0,0,2
Granulocyte,0,0,0,0,1,0,0,0,0
Granulocyte-monocyte progenitor (GMP),0,1,0,0,0,0,0,0,1
Hepatic stellate cell,0,0,1,0,0,0,0,0,0


In [83]:
# One row per cell: predicted cell type and reference annotation.
df_result = pd.DataFrame(dict(predictions=predictions, labels=labels))

In [85]:
# Save per-cell results; the file name follows the dataset name, as the Heart and
# Liver sections do ("Bonemarraw" was a typo).
df_result.to_csv("GeneFormer_Anno_BoneMarrow_AHCA.csv")

## Test on Heart (Simonson 2023)

In [86]:
testset = load_from_disk("/nfs/public/cell_gpt_data/Geneformer_4_recomb/dataset/tokenized/Heart_Simonson2023.dataset/")
# Constant dummy label so the data collator/Trainer get a "label" column; the
# metrics returned by predict on this set are therefore meaningless.
testset = testset.add_column('label', [1] * testset.num_rows)

In [87]:
# Logits per cell; label-based metrics reflect only the dummy labels.
test_predictions = trainer.predict(test_dataset=testset)

  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


In [88]:
# Most likely class per cell, decoded to a name, next to the dataset's own annotation.
predictions_ids = test_predictions.predictions.argmax(axis=1)
predictions = [id_target_name_dict[p] for p in predictions_ids]
labels = np.array(testset['cell_type'])

In [89]:
# Exact-match metrics against the dataset's own annotations.
cal_metrics(celltypes_labels=labels, predictions=predictions)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'test/accuracy': 0.4994945728726489,
 'test/precision': 0.6583047706859362,
 'test/recall': 0.4994945728726489,
 'test/weighted_f1': 0.5106326438703199}

In [90]:
# Rows: predicted (training-vocabulary) cell types; columns: the dataset's own
# annotations. The vocabularies differ (e.g. "Fibroblast" vs "Basement membrane
# fibroblast"), so read agreement from this table, not exact-match accuracy alone.
pd.crosstab(predictions, labels, rownames=["predicted"], colnames=["annotated"])

col_0,Adipocyte,Arterial endothelial cell,Capillary endothelial cell,Cardiomyocyte cell,Endocardial cell,Endothelial cell,Fibroblast,Lymphoid cell,Lyphatic endothelial cell,Macrophage,Mast cell,Mesothelial cell,Neuron,Pericyte,Unclassified
row_0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
Adipocyte,8,0,0,0,0,0,0,0,0,0,0,0,0,0,0
Adventitial fibroblast,0,0,0,0,0,0,186,0,0,0,0,0,1,0,0
B cell,0,0,0,0,0,0,0,3,0,2,0,0,0,0,0
Basement membrane fibroblast,0,0,0,1,0,0,1995,0,1,0,0,0,1,0,0
CD4 T cell,0,0,0,0,0,0,0,15,0,0,0,0,0,0,0
CD8 T cell,0,0,0,0,0,0,0,864,0,0,1,0,1,0,0
Capillary endothelial cell,0,0,1932,0,1,30,0,0,0,0,0,0,0,0,86
Cardiomyocyte cell,1,0,0,20869,0,0,0,0,0,2,0,1,6,0,4
Cholangiocyte,61,0,0,0,0,0,2,0,1,1,0,0,0,0,0
Conventional dendritic cell 2,0,0,0,0,0,0,0,0,0,5,0,0,0,0,0


In [91]:
# One row per cell: predicted cell type and reference annotation.
df_result = pd.DataFrame(dict(predictions=predictions, labels=labels))

In [92]:
# Save per-cell results for the heart dataset.
df_result.to_csv(path_or_buf="GeneFormer_Anno_" + "Heart_Simonson2023.csv")

## Test on Liver (Suo 2022)

In [93]:
testset = load_from_disk("/nfs/public/cell_gpt_data/Geneformer_4_recomb/dataset/tokenized/Liver_Suo2022.dataset/")
# Constant dummy label so the data collator/Trainer get a "label" column; the
# metrics returned by predict on this set are therefore meaningless.
testset = testset.add_column('label', [1] * testset.num_rows)

In [94]:
# Logits per cell; label-based metrics reflect only the dummy labels.
test_predictions = trainer.predict(test_dataset=testset)

  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


In [95]:
# Most likely class per cell, decoded to a name, next to the dataset's own annotation.
predictions_ids = test_predictions.predictions.argmax(axis=1)
predictions = [id_target_name_dict[p] for p in predictions_ids]
labels = np.array(testset['cell_type'])

In [96]:
# Exact-match metrics against the dataset's own annotations.
cal_metrics(celltypes_labels=labels, predictions=predictions)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'test/accuracy': 0.23601664953601445,
 'test/precision': 0.302260790114065,
 'test/recall': 0.23601664953601445,
 'test/weighted_f1': 0.2277567011805992}

In [97]:
# Rows: predicted (training-vocabulary) cell types; columns: the dataset's own
# annotations. The vocabularies differ (e.g. "CD4 Treg" has no counterpart among
# the columns), so read agreement from this table, not exact-match accuracy alone.
pd.crosstab(predictions, labels, rownames=["predicted"], colnames=["annotated"])

col_0,B cell,CD4 T cell,CD8 T cell,Conventional DC1,Conventional DC2,Dendritic cell,Endothelial cell,Hepatocyte,Lymphoid cell,Mast cell,Monocyte,Myeloid cell,NK cell,Neutrophilic granulocyte,Plasma B cell,T cell
row_0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
B cell,532,0,0,0,0,2,0,0,0,2,0,0,0,0,0,1
CD4 T cell,3,164,46,0,0,0,0,0,291,0,0,0,0,0,0,86
CD4 Treg,19,35,21,0,0,1,0,0,142,0,0,0,652,0,0,115
CD8 T cell,10,11,11,0,0,1,0,0,174,0,0,0,231,0,0,147
Capillary endothelial cell,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0
Cholangiocyte,0,0,0,0,0,0,2,54,0,0,0,1,0,0,0,0
Common lymphoid progenitor (CLP),711,0,1,0,0,0,0,0,1,1,0,0,0,0,0,6
Conventional dendritic cell,0,0,0,0,1,3,0,0,0,0,0,0,0,0,0,0
Conventional dendritic cell 1,0,0,0,23,0,6,0,0,0,0,0,0,0,0,0,0
Conventional dendritic cell 2,7,0,0,314,1569,81,0,0,0,2,26,1146,0,0,0,0


In [98]:
# One row per cell: predicted cell type and reference annotation.
df_result = pd.DataFrame(dict(predictions=predictions, labels=labels))

In [99]:
# Save per-cell results for the liver dataset.
df_result.to_csv(path_or_buf="GeneFormer_Anno_" + "Liver_Suo2022.csv")