## Geneformer Fine-Tuning for Cell Annotation Application

In [1]:
import os
# GPU indices this notebook is allowed to use.
GPU_NUMBER = [0]
# Restrict CUDA to the selected GPU(s) and enable verbose NCCL logging.
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": ",".join(map(str, GPU_NUMBER)),
        "NCCL_DEBUG": "INFO",
    }
)

In [2]:
# imports
from collections import Counter
import datetime
import pickle
import subprocess
import seaborn as sns; sns.set()
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments

from geneformer import DataCollatorForCellClassification

  from .autonotebook import tqdm as notebook_tqdm
  def twobit_to_dna(twobit: int, size: int) -> str:
  def dna_to_twobit(dna: str) -> int:
  def twobit_1hamming(twobit: int, size: int) -> List[int]:


In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Zero-shot evaluation: predictions without fine-tuning on Intestine data

In [13]:
def compute_metrics(pred):
    """Score a prediction object with accuracy and macro-averaged F1.

    Reference labels are read from ``pred.label_ids``; predicted class ids
    are the argmax over the last axis of ``pred.predictions``.

    Returns:
        dict with the keys ``'accuracy'`` and ``'macro_f1'``.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    # both scores come from sklearn.metrics
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'macro_f1': f1_score(y_true, y_pred, average='macro'),
    }

In [14]:
# ---- model parameters ----
max_input_size = 1 << 11  # max input size: 2048

# ---- training hyperparameters ----
max_lr = 5e-5                # max learning rate
freeze_layers = 0            # how many pretrained layers to freeze
num_gpus = 1                 # number of gpus
num_proc = 16                # number of cpu cores
geneformer_batch_size = 12   # batch size for training and eval
lr_schedule_fn = "linear"    # learning schedule
warmup_steps = 500           # warmup steps
epochs = 10                  # number of epochs
optimizer = "adamw"          # optimizer

In [18]:
# Build a Trainer around an existing classifier checkpoint so it can be used
# for prediction on the Intestine test set.

# set logging steps: about ten log events per epoch; clamp to >= 1 so a small
# training split cannot round the interval down to 0
logging_steps = max(1, round(len(labeled_train_split) / geneformer_batch_size / 10))

# reload the classifier checkpoint
# NOTE(review): the directory name suggests a previously fine-tuned cell
# classifier, and num_labels=95 presumably matches its classification head --
# confirm against the checkpoint's config.
model = BertForSequenceClassification.from_pretrained(
    "/nfs/public/cell_gpt_data/Geneformer_4_recomb/model/231227_geneformer_CellClassifier_L2048_B12_LR5e-05_LSlinear_WU500_E10_Oadamw_F0/",
    num_labels=95,
    output_attentions=False,
    output_hidden_states=False,
).to("cuda")

# define output directory path (YYMMDD datestamp + hyperparameter summary)
current_date = datetime.datetime.now()
datestamp = current_date.strftime("%y%m%d")
output_dir = f"/nfs/public/cell_gpt_data/Geneformer_4_recomb/model/{datestamp}_geneformer_CellClassifier_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"

# ensure not overwriting previously saved model
# NOTE(review): newer transformers releases write `model.safetensors` instead
# of `pytorch_model.bin` by default, so both names are checked -- confirm
# against the installed version.
saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
saved_safetensors_test = os.path.join(output_dir, "model.safetensors")
if os.path.isfile(saved_model_test) or os.path.isfile(saved_safetensors_test):
    raise Exception("Model already saved to this directory.")

# make output directory; unlike shelling out to `mkdir`, this does not fail
# when the directory already exists and needs no shell
os.makedirs(output_dir, exist_ok=True)

# set training arguments
training_args = {
    "learning_rate": max_lr,
    "do_train": True,
    "do_eval": True,
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "logging_steps": logging_steps,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": False,
    "lr_scheduler_type": lr_schedule_fn,
    "warmup_steps": warmup_steps,
    "weight_decay": 0.001,
    "per_device_train_batch_size": geneformer_batch_size,
    "per_device_eval_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    "load_best_model_at_end": True,
    "output_dir": output_dir,
}

training_args_init = TrainingArguments(**training_args)

# create the trainer
trainer = Trainer(
    model=model,
    args=training_args_init,
    data_collator=DataCollatorForCellClassification(),
    train_dataset=labeled_train_split,
    eval_dataset=labeled_eval_split,
    compute_metrics=compute_metrics,
)

In [22]:
# Load the tokenized Intestine test set.
test_dataset = load_from_disk("/nfs/public/cell_gpt_data/Geneformer_4_recomb/dataset/tokenized/Intestine_Test.dataset/")
# Attach a constant placeholder 'label' column (every row gets 1); the real
# cell types stay in the 'cell_type' column.
# NOTE(review): presumably the collator/Trainer requires a 'label' column -- confirm.
test_dataset = test_dataset.add_column('label', [1] * test_dataset.num_rows)
# Later cells read the raw model output as `test_predictions`; `predictions`
# is kept as an alias for the name this cell originally assigned.
test_predictions = trainer.predict(test_dataset)
predictions = test_predictions

  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


In [26]:
import pickle

# Load the lookup used to turn predicted class ids into cell type names.
# NOTE(review): unpickling can execute arbitrary code -- only load an
# id_target.pkl that comes from a trusted source.
with open("./id_target.pkl", "rb") as pkl_file:
    id_target_name_dict = pickle.load(pkl_file)

In [29]:
# Predicted class id per cell: argmax over the last (class) axis in a single
# vectorized call instead of a Python loop over rows.
predictions_ids = test_predictions.predictions.argmax(axis=-1)
# Translate class ids to cell type names.
predictions = [id_target_name_dict[p] for p in predictions_ids]
# Reference cell types of the test set.
labels = np.array(test_dataset['cell_type'])

In [30]:
# Contingency table: rows are predicted cell types, columns are reference labels.
pd.crosstab(index=predictions, columns=labels)

col_0,B cell,Dendritic cell,Endothelial cell,Enteric glial cell,Enterocyte,Enterocyte progenitor,Enteroendocrine cell,Epithelial cell,Fibroblast,Goblet cell,...,Myeloid cell,Myofibroblast,Neuron,Neutrophilic granulocyte,Paneth cell,Plasma B cell,Smooth muscle cell,Stromal cell,T cell,Vascular endothelial cell
row_0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Adventitial fibroblast,0,0,0,0,0,0,0,0,1,0,...,0,0,0,0,0,0,0,56,1,0
Alveolar fibroblast,0,0,3,0,28,0,0,0,130,2,...,0,0,10,0,0,0,31,42,5,0
Artery endothelial cell,0,0,0,0,1,0,0,0,0,0,...,0,0,1,0,0,0,0,0,0,1
B cell,12,0,0,0,2,0,0,0,0,0,...,0,0,0,0,0,3,1,0,2,0
Basal cell,0,0,0,0,5,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Type I alveolar cell,0,0,1,0,65,5,0,1,8,8,...,0,1,3,0,0,0,3,13,0,0
Type II alveolar cell,2,0,0,0,6,8,0,0,0,13,...,0,0,0,0,0,2,0,1,1,0
Vascular endothelial cell,0,0,93,0,13,0,0,0,1,0,...,0,0,0,0,1,0,0,0,0,0
Vascular smooth muscle cell,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,2,0,0


In [31]:
# Pair each predicted cell type with its reference label, one row per cell.
df_result = pd.DataFrame(dict(predictions=predictions, labels=labels))

In [32]:
# Write the prediction/label table to CSV in the working directory.
df_result.to_csv(path_or_buf="GeneFormer_Anno_{}".format("Intestine_ZS.csv"))

## Prepare training and evaluation datasets

In [38]:
# load the tokenized Intestine fine-tuning dataset
train_dataset = load_from_disk(
    "/nfs/public/cell_gpt_data/Geneformer_4_recomb/dataset/tokenized/Intestine_Finetune.dataset/"
)

In [39]:
# shuffle with a fixed seed, then expose the cell type column as "label"
trainset_organ_shuffled = train_dataset.shuffle(seed=42)
trainset_organ_shuffled = trainset_organ_shuffled.rename_column("cell_type", "label")
# create dictionary of cell types : label ids
# (distinct cell types in first-seen order, numbered consecutively from 0)
target_names = list(Counter(trainset_organ_shuffled["label"]))
target_name_id_dict = {name: idx for idx, name in enumerate(target_names)}

# change labels to numerical ids
def classes_to_ids(example):
    """Replace the cell type name in ``example["label"]`` with its integer id.

    The id is looked up in ``target_name_id_dict``; a name missing from that
    dictionary raises ``KeyError``. The example is modified in place and
    returned.
    """
    cell_type = example["label"]
    example["label"] = target_name_id_dict[cell_type]
    return example


# apply the conversion to every example of the shuffled training set
labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)

In [40]:
# create 80/20 train/eval splits
# The dataset was shuffled earlier, so the first 80% of rows form the training
# split and the remainder the evaluation split. Passing `range` objects avoids
# building two full Python index lists.
split_index = round(len(labeled_trainset) * 0.8)
labeled_train_split = labeled_trainset.select(range(split_index))
labeled_eval_split = labeled_trainset.select(range(split_index, len(labeled_trainset)))

# filter dataset for cell types in corresponding training set
trained_labels = list(Counter(labeled_train_split["label"]).keys())
# Frozen set of the same ids, built once, so the per-example membership test
# below is a hash lookup instead of a list scan.
_trained_label_set = frozenset(trained_labels)


def if_trained_label(example):
    """Return True when the example's label id also occurs in the training split."""
    return example["label"] in _trained_label_set


labeled_eval_split = labeled_eval_split.filter(if_trained_label, num_proc=16)

In [41]:
# Display the training split summary (features and row count).
labeled_train_split

Dataset({
    features: ['input_ids', 'label', 'length'],
    num_rows: 35143
})

In [42]:
# Display the evaluation split summary (features and row count).
labeled_eval_split

Dataset({
    features: ['input_ids', 'label', 'length'],
    num_rows: 8786
})

## Fine-Tune With Cell Classification Learning Objective and Quantify Predictive Performance

In [43]:
def compute_metrics(pred):
    """Compute accuracy and macro-averaged F1 for a prediction object.

    ``pred.label_ids`` supplies the reference labels; the predicted class ids
    are the argmax over the last axis of ``pred.predictions``.

    Returns:
        dict with the keys ``'accuracy'`` and ``'macro_f1'``.
    """
    reference = pred.label_ids
    predicted = pred.predictions.argmax(-1)
    # both scores come from sklearn.metrics
    scores = {}
    scores['accuracy'] = accuracy_score(reference, predicted)
    scores['macro_f1'] = f1_score(reference, predicted, average='macro')
    return scores

In [44]:
# ---- model parameters ----
max_input_size = 1 << 11  # max input size: 2048

# ---- training hyperparameters ----
max_lr = 5e-5                # max learning rate
freeze_layers = 0            # how many pretrained layers to freeze
num_gpus = 1                 # number of gpus
num_proc = 16                # number of cpu cores
geneformer_batch_size = 12   # batch size for training and eval
lr_schedule_fn = "linear"    # learning schedule
warmup_steps = 500           # warmup steps
epochs = 10                  # number of epochs
optimizer = "adamw"          # optimizer

In [45]:
# Fine-tune the pretrained model as a cell type classifier, then store the
# eval-split predictions, their metrics, and the trained model in output_dir.

# set logging steps: about ten log events per epoch; clamp to >= 1 so a small
# training split cannot round the interval down to 0
logging_steps = max(1, round(len(labeled_train_split) / geneformer_batch_size / 10))

# reload pretrained model with one output class per cell type seen in training
model = BertForSequenceClassification.from_pretrained(
    "/data1/chenyx/Geneformer/",
    num_labels=len(target_name_id_dict),
    output_attentions=False,
    output_hidden_states=False,
).to("cuda")

# define output directory path (YYMMDD datestamp + hyperparameter summary)
current_date = datetime.datetime.now()
datestamp = current_date.strftime("%y%m%d")
output_dir = f"/nfs/public/cell_gpt_data/Geneformer_4_recomb/model/{datestamp}_geneformer_CellClassifier_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"

# ensure not overwriting previously saved model
# NOTE(review): newer transformers releases write `model.safetensors` instead
# of `pytorch_model.bin` by default, so both names are checked -- confirm
# against the installed version.
saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
saved_safetensors_test = os.path.join(output_dir, "model.safetensors")
if os.path.isfile(saved_model_test) or os.path.isfile(saved_safetensors_test):
    raise Exception("Model already saved to this directory.")

# make output directory; unlike shelling out to `mkdir`, this does not print
# a "File exists" error when the directory is already there and needs no shell
os.makedirs(output_dir, exist_ok=True)

# set training arguments
training_args = {
    "learning_rate": max_lr,
    "do_train": True,
    "do_eval": True,
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "logging_steps": logging_steps,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": False,
    "lr_scheduler_type": lr_schedule_fn,
    "warmup_steps": warmup_steps,
    "weight_decay": 0.001,
    "per_device_train_batch_size": geneformer_batch_size,
    "per_device_eval_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    "load_best_model_at_end": True,
    "output_dir": output_dir,
}

training_args_init = TrainingArguments(**training_args)

# create the trainer
trainer = Trainer(
    model=model,
    args=training_args_init,
    data_collator=DataCollatorForCellClassification(),
    train_dataset=labeled_train_split,
    eval_dataset=labeled_eval_split,
    compute_metrics=compute_metrics,
)
# train the cell type classifier
trainer.train()
# score the evaluation split and persist the raw prediction output
predictions = trainer.predict(labeled_eval_split)
with open(os.path.join(output_dir, "predictions.pickle"), "wb") as fp:
    pickle.dump(predictions, fp)
# store the prediction metrics under the "eval" split name, then the model
trainer.save_metrics("eval", predictions.metrics)
trainer.save_model(output_dir)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /data1/chenyx/Geneformer/ and are newly initialized: ['bert.pooler.dense.bias', 'classifier.bias', 'bert.pooler.dense.weight', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
mkdir: cannot create directory ‘/nfs/public/cell_gpt_data/Geneformer_4_recomb/model/231230_geneformer_CellClassifier_L2048_B12_LR5e-05_LSlinear_WU500_E10_Oadamw_F0/’: File exists
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Epoch,Training Loss,Validation Loss,Accuracy,Macro F1
1,0.6555,0.631545,0.78682,0.497898
2,0.4886,0.484145,0.83064,0.588218
3,0.4077,0.444231,0.843387,0.617303
4,0.335,0.450034,0.847712,0.655733
5,0.2789,0.444588,0.855452,0.666561
6,0.2486,0.46596,0.855224,0.712657
7,0.1987,0.486606,0.86285,0.705214
8,0.2001,0.53134,0.862167,0.710227
9,0.1601,0.549774,0.863078,0.727442
10,0.1419,0.553988,0.868313,0.717084


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


In [46]:
# Load the tokenized Intestine test set and expose its cell type column as "label".
validset = load_from_disk(
    "/nfs/public/cell_gpt_data/Geneformer_4_recomb/dataset/tokenized/Intestine_Test.dataset/"
).rename_column("cell_type", "label")
# Convert the cell type names to the integer ids assigned during training;
# `validset` itself keeps the names.
labeled_validset = validset.map(classes_to_ids, num_proc=16)

Map (num_proc=16): 100%|██████████| 10983/10983 [00:00<00:00, 25147.24 examples/s]


In [47]:
# Display the held-out set summary (features and row count).
labeled_validset

Dataset({
    features: ['input_ids', 'label', 'length'],
    num_rows: 10983
})

In [48]:
# Run the trainer's model on the held-out set.
valid_predictions = trainer.predict(test_dataset=labeled_validset)

  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


In [55]:
# Invert the name -> id mapping so predicted ids can be turned back into names.
id_target_name_dict = {
    label_id: name for name, label_id in target_name_id_dict.items()
}

In [57]:
# Predicted class id per cell: argmax over the last (class) axis in a single
# vectorized call instead of a Python loop over rows.
predictions_ids = valid_predictions.predictions.argmax(axis=-1)
# Translate class ids to cell type names.
predictions = [id_target_name_dict[p] for p in predictions_ids]
# Reference cell type names come from the dataset that was actually scored:
# `validset` still holds the names in "label" (only `labeled_validset` was
# mapped to ids), so this cell no longer depends on `test_dataset` from the
# zero-shot section.
labels = np.array(validset["label"])

In [58]:
# Contingency table: rows are predicted cell types, columns are reference labels.
pd.crosstab(index=predictions, columns=labels)

col_0,B cell,Dendritic cell,Endothelial cell,Enteric glial cell,Enterocyte,Enterocyte progenitor,Enteroendocrine cell,Epithelial cell,Fibroblast,Goblet cell,...,Myeloid cell,Myofibroblast,Neuron,Neutrophilic granulocyte,Paneth cell,Plasma B cell,Smooth muscle cell,Stromal cell,T cell,Vascular endothelial cell
row_0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
B cell,441,0,0,0,7,3,0,0,0,3,...,0,0,0,0,0,61,0,0,10,0
Dendritic cell,0,92,0,0,3,0,0,0,0,0,...,0,5,0,0,0,0,0,0,2,0
Endothelial cell,2,0,161,0,3,0,0,0,1,0,...,0,0,0,0,2,0,1,5,0,6
Enterocyte,19,0,10,3,3854,317,0,13,46,44,...,0,2,11,0,4,9,9,17,22,0
Enterocyte progenitor,0,0,3,0,134,978,1,0,5,15,...,0,0,2,0,2,0,1,8,0,0
Epithelial cell,0,0,0,0,9,9,0,39,0,0,...,0,0,0,0,0,0,0,0,0,0
Fibroblast,1,0,2,0,6,4,0,0,349,1,...,0,0,2,0,2,0,11,15,1,0
Goblet cell,3,0,0,0,57,23,0,0,4,565,...,0,2,0,0,5,2,1,0,2,0
Macrophage,1,15,0,0,2,1,0,0,1,2,...,0,0,0,0,1,1,0,1,1,0
Mast cell,0,1,0,0,3,0,0,0,0,2,...,0,0,0,0,3,1,1,0,2,0


In [59]:
# Pair each predicted cell type with its reference label, one row per cell.
df_result = pd.DataFrame(dict(predictions=predictions, labels=labels))

In [60]:
# Write the prediction/label table to CSV in the working directory.
df_result.to_csv(path_or_buf="GeneFormer_Anno_{}".format("Intestine_finetune.csv"))