## Geneformer Fine-Tuning for Cell Annotation Application

In [2]:
import os
# indices of the GPUs this notebook is allowed to use
GPU_NUMBER = [0]
# export the device list as a comma-separated string and turn on NCCL
# info-level logging
os.environ.update({
    "CUDA_VISIBLE_DEVICES": ",".join(map(str, GPU_NUMBER)),
    "NCCL_DEBUG": "INFO",
})

In [3]:
# imports
from collections import Counter
import datetime
import pickle
import subprocess
import seaborn as sns; sns.set()
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments

from geneformer import DataCollatorForCellClassification

  from .autonotebook import tqdm as notebook_tqdm


In [1]:
# Locations (relative to this notebook) of the training dataset, the
# pretrained model directory, and the directory that receives fine-tuned models.
DATA_PATH = '../data/cell_type_train_data.dataset'
PRETRAINED_MODEL_PATH = '../geneformer-12L-30M'
OUTPUT_PATH = '../cell-classification-finetuned-model'

# fail fast when either required input is missing
assert os.path.exists(DATA_PATH), f'DATA_PATH {DATA_PATH} does not exist'
assert os.path.exists(PRETRAINED_MODEL_PATH), f'PRETRAINED_MODEL_PATH {PRETRAINED_MODEL_PATH} does not exist'
# exist_ok=True replaces the check-then-create pattern: no race between the
# existence test and the creation, and re-running the cell is harmless
os.makedirs(OUTPUT_PATH, exist_ok=True)

## Prepare training and evaluation datasets

In [4]:
# read the saved cell type dataset (all tissues together) from DATA_PATH
train_dataset = load_from_disk(DATA_PATH)

In [5]:
# show the dataset summary (feature names and row count)
train_dataset

Dataset({
    features: ['cell_type', 'input_ids', 'length', 'organ_major'],
    num_rows: 249556
})

In [6]:
# Per-tissue accumulators; position i in every list describes the same tissue.
dataset_list = []
evalset_list = []
organ_list = []
target_dict_list = []

# worker processes used by every Dataset.filter / Dataset.map call below
NUM_PROC = 16
# per scDeepsort published method, cell types must exceed this fraction of a
# tissue's cells to be kept (i.e. types representing <0.5% of cells are dropped)
MIN_CELLTYPE_FRACTION = 0.005
# fraction of each shuffled tissue dataset used for training; the rest is eval
TRAIN_FRACTION = 0.8

for organ in Counter(train_dataset["organ_major"]).keys():
    # collect list of tissues for fine-tuning (immune and bone marrow are included together)
    if organ in ["bone_marrow"]:
        # bone marrow cells are folded into the "immune" tissue below
        continue
    elif organ == "immune":
        organ_ids = ["immune", "bone_marrow"]
        organ_list += ["immune"]
    else:
        organ_ids = [organ]
        organ_list += [organ]

    print(organ)

    # filter datasets for given organ
    def if_organ(example):
        return example["organ_major"] in organ_ids
    trainset_organ = train_dataset.filter(if_organ, num_proc=NUM_PROC)

    # drop rare cell types (threshold is relative to this tissue's cell count)
    celltype_counter = Counter(trainset_organ["cell_type"])
    total_cells = sum(celltype_counter.values())
    cells_to_keep = [k for k, v in celltype_counter.items() if v > (MIN_CELLTYPE_FRACTION * total_cells)]
    def if_not_rare_celltype(example):
        return example["cell_type"] in cells_to_keep
    trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=NUM_PROC)

    # shuffle datasets and rename columns
    trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)
    trainset_organ_shuffled = trainset_organ_shuffled.rename_column("cell_type", "label")
    trainset_organ_shuffled = trainset_organ_shuffled.remove_columns("organ_major")

    # create dictionary of cell types : label ids
    # (ids follow first-appearance order in the shuffled dataset)
    target_names = list(Counter(trainset_organ_shuffled["label"]).keys())
    target_name_id_dict = {name: label_id for label_id, name in enumerate(target_names)}
    target_dict_list += [target_name_id_dict]

    # change labels to numerical ids
    def classes_to_ids(example):
        example["label"] = target_name_id_dict[example["label"]]
        return example
    labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=NUM_PROC)

    # create 80/20 train/eval splits; the boundary is computed once and the
    # index ranges are passed directly instead of being expanded into lists
    split_index = round(len(labeled_trainset) * TRAIN_FRACTION)
    labeled_train_split = labeled_trainset.select(range(split_index))
    labeled_eval_split = labeled_trainset.select(range(split_index, len(labeled_trainset)))

    # filter dataset for cell types in corresponding training set
    trained_labels = list(Counter(labeled_train_split["label"]).keys())
    def if_trained_label(example):
        return example["label"] in trained_labels
    labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=NUM_PROC)

    dataset_list += [labeled_train_split]
    evalset_list += [labeled_eval_split_subset]

spleen
kidney
lung
brain
placenta
immune
large_intestine
pancreas
liver


In [7]:
# key each per-tissue list by its organ name (the lists are index-aligned
# with organ_list)
trainset_dict = {name: split for name, split in zip(organ_list, dataset_list)}
traintargetdict_dict = {name: label_map for name, label_map in zip(organ_list, target_dict_list)}
evalset_dict = {name: split for name, split in zip(organ_list, evalset_list)}

In [8]:
# show the training split (features and row count) built for each tissue
trainset_dict

{'spleen': Dataset({
     features: ['label', 'input_ids', 'length'],
     num_rows: 12330
 }),
 'kidney': Dataset({
     features: ['label', 'input_ids', 'length'],
     num_rows: 35199
 }),
 'lung': Dataset({
     features: ['label', 'input_ids', 'length'],
     num_rows: 26098
 }),
 'brain': Dataset({
     features: ['label', 'input_ids', 'length'],
     num_rows: 10656
 }),
 'placenta': Dataset({
     features: ['label', 'input_ids', 'length'],
     num_rows: 7415
 }),
 'immune': Dataset({
     features: ['label', 'input_ids', 'length'],
     num_rows: 20562
 }),
 'large_intestine': Dataset({
     features: ['label', 'input_ids', 'length'],
     num_rows: 39678
 }),
 'pancreas': Dataset({
     features: ['label', 'input_ids', 'length'],
     num_rows: 21934
 }),
 'liver': Dataset({
     features: ['label', 'input_ids', 'length'],
     num_rows: 22427
 })}

In [9]:
# show the class of the loaded dataset object
type(train_dataset)

datasets.arrow_dataset.Dataset

## Fine-Tune With Cell Classification Learning Objective and Quantify Predictive Performance

In [10]:
def compute_metrics(pred):
    """Compute accuracy and macro F1 for one set of predictions.

    Args:
        pred: object exposing ``label_ids`` (reference labels) and
            ``predictions`` (an array reduced with ``argmax(-1)`` to obtain
            the predicted class ids).

    Returns:
        dict with the keys ``'accuracy'`` and ``'macro_f1'``, both computed
        with sklearn's metric functions.
    """
    true_ids = pred.label_ids
    predicted_ids = pred.predictions.argmax(-1)
    return {
        'accuracy': accuracy_score(true_ids, predicted_ids),
        'macro_f1': f1_score(true_ids, predicted_ids, average='macro'),
    }

### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example hyperparameters are defined below, but please see the "hyperparam_optimiz_for_disease_classifier" script for an example of how to tune hyperparameters for downstream applications.

In [20]:
# ---- model parameters ----
max_input_size = 1 << 11  # max input size: 2048

# ---- training hyperparameters ----
max_lr = 5e-5               # max learning rate
freeze_layers = 0           # how many pretrained layers to freeze
num_gpus = 1                # number gpus
num_proc = 16               # number cpu cores
geneformer_batch_size = 2   # batch size for training and eval (default 12 --> OOM)
lr_schedule_fn = "linear"   # learning schedule
warmup_steps = 500          # warmup steps
epochs = 10                 # number of epochs
optimizer = "adamw"         # optimizer

In [21]:
# Fine-tune one classifier per tissue, then save its eval-set predictions,
# metrics and final model into a tissue-specific output directory.
for organ in organ_list:
    print(organ)
    organ_trainset = trainset_dict[organ]
    organ_evalset = evalset_dict[organ]
    organ_label_dict = traintargetdict_dict[organ]

    # set logging steps: one tenth of (training examples / batch size);
    # clamped to 1 so a very small training set never yields logging_steps == 0
    logging_steps = max(1, round(len(organ_trainset)/geneformer_batch_size/10))

    # reload pretrained model so every tissue starts from the same checkpoint;
    # the classification head is sized to this tissue's number of cell types
    model = BertForSequenceClassification.from_pretrained(PRETRAINED_MODEL_PATH,
                                                      num_labels=len(organ_label_dict.keys()),
                                                      output_attentions = False,
                                                      output_hidden_states = False).to("cuda")

    # freeze the first `freeze_layers` encoder layers so the value recorded in
    # the output directory name (F{freeze_layers}) is actually applied;
    # a value of 0 (or None) leaves every layer trainable
    if freeze_layers:
        for frozen_layer in model.bert.encoder.layer[:freeze_layers]:
            for param in frozen_layer.parameters():
                param.requires_grad = False

    # define output directory path (datestamp is YYMMDD; the remaining fields
    # record the hyperparameter values defined for this run)
    current_date = datetime.datetime.now()
    datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
    output_dir = f"{OUTPUT_PATH}/{datestamp}_geneformer_CellClassifier_{organ}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"

    # ensure not overwriting previously saved model
    # NOTE(review): newer transformers releases appear to write
    # model.safetensors instead of pytorch_model.bin, so both names are
    # checked -- confirm against the installed version
    saved_model_files = ("pytorch_model.bin", "model.safetensors")
    if any(os.path.isfile(os.path.join(output_dir, fname)) for fname in saved_model_files):
        raise FileExistsError("Model already saved to this directory.")

    # make output directory (exist_ok avoids the check-then-create race)
    os.makedirs(output_dir, exist_ok=True)

    # set training arguments
    training_args = {
        "learning_rate": max_lr,
        "do_train": True,
        "do_eval": True,
        "evaluation_strategy": "epoch",
        "save_strategy": "epoch",
        "logging_steps": logging_steps,
        "group_by_length": True,
        "length_column_name": "length",
        "disable_tqdm": False,
        "lr_scheduler_type": lr_schedule_fn,
        "warmup_steps": warmup_steps,
        "weight_decay": 0.001,
        "per_device_train_batch_size": geneformer_batch_size,
        "per_device_eval_batch_size": geneformer_batch_size,
        "num_train_epochs": epochs,
        "load_best_model_at_end": True,
        "output_dir": output_dir,
    }

    training_args_init = TrainingArguments(**training_args)

    # create the trainer
    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=DataCollatorForCellClassification(),
        train_dataset=organ_trainset,
        eval_dataset=organ_evalset,
        compute_metrics=compute_metrics
    )
    # train the cell type classifier
    trainer.train()

    # predict on the held-out split and persist predictions, metrics and model
    predictions = trainer.predict(organ_evalset)
    with open(os.path.join(output_dir, "predictions.pickle"), "wb") as fp:
        pickle.dump(predictions, fp)
    trainer.save_metrics("eval",predictions.metrics)
    trainer.save_model(output_dir)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ../geneformer-12L-30M and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


spleen


mkdir: cannot create directory ‘/path/to/models/240215_geneformer_CellClassifier_spleen_L2048_B2_LR5e-05_LSlinear_WU500_E10_Oadamw_F0/’: File exists


Epoch,Training Loss,Validation Loss,Accuracy,Macro F1
1,0.1182,0.074267,0.987999,0.95871
2,0.0763,0.100048,0.985728,0.957144
3,0.0513,0.151832,0.979565,0.953563
4,0.0756,0.085918,0.985728,0.955965
5,0.0384,0.083895,0.98962,0.967294
6,0.0361,0.098717,0.988972,0.966407
7,0.0089,0.097984,0.990269,0.971054
8,0.0177,0.093758,0.98962,0.967629
9,0.0386,0.09696,0.98962,0.969677
10,0.0171,0.08898,0.989945,0.970374


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ../geneformer-12L-30M and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


kidney


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Epoch,Training Loss,Validation Loss,Accuracy,Macro F1
1,0.3745,0.418104,0.917614,0.872063
2,0.306,0.339669,0.931591,0.882974
3,0.1838,0.401208,0.934773,0.895995
4,0.0855,0.498755,0.932727,0.885104
5,0.0557,0.575111,0.929545,0.881147
6,0.0558,0.555059,0.936477,0.896368
7,0.0215,0.654156,0.936023,0.893097
8,0.0059,0.705073,0.935114,0.892168
9,0.0078,0.709571,0.939659,0.897715
10,0.002,0.733304,0.937159,0.894743


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ../geneformer-12L-30M and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


lung


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Epoch,Training Loss,Validation Loss,Accuracy,Macro F1
1,0.285,0.355553,0.933793,0.854257
2,0.2439,0.298663,0.944215,0.897322
3,0.1737,0.35536,0.944674,0.885211
4,0.0932,0.400015,0.944521,0.8942
5,0.042,0.481196,0.941149,0.874735
6,0.0412,0.47048,0.948506,0.906068
7,0.0208,0.571224,0.944981,0.898732
8,0.0112,0.599385,0.946207,0.897745
9,0.0037,0.62273,0.946207,0.885586
10,0.0062,0.619837,0.947586,0.900416


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ../geneformer-12L-30M and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


brain


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Epoch,Training Loss,Validation Loss,Accuracy,Macro F1
1,0.1355,0.171991,0.967718,0.713007
2,0.1382,0.130009,0.977102,0.863523
3,0.0771,0.140143,0.979354,0.86945
4,0.0445,0.149834,0.976351,0.885812
5,0.0237,0.23425,0.973348,0.822678
6,0.0011,0.231824,0.975601,0.85093
7,0.0,0.272678,0.975225,0.842565
8,0.0,0.302681,0.975601,0.838271
9,0.0,0.320088,0.97485,0.83304
10,0.0,0.334539,0.973724,0.822007


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ../geneformer-12L-30M and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


placenta


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Epoch,Training Loss,Validation Loss,Accuracy,Macro F1
1,0.1279,0.150387,0.97411,0.959393
2,0.0769,0.08232,0.985437,0.977031
3,0.0722,0.107004,0.983819,0.975251
4,0.0628,0.130626,0.982201,0.972853
5,0.0376,0.147633,0.983819,0.974833
6,0.0186,0.145511,0.984358,0.976609
7,0.0,0.176854,0.982201,0.973161
8,0.0186,0.190506,0.980583,0.969789
9,0.0,0.187301,0.98274,0.973205
10,0.0,0.191507,0.98274,0.973801


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ../geneformer-12L-30M and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


immune


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Epoch,Training Loss,Validation Loss,Accuracy,Macro F1
1,0.3158,0.34685,0.92821,0.857464
2,0.2192,0.314662,0.947471,0.870045
3,0.1524,0.315866,0.944747,0.896177
4,0.0664,0.314146,0.953891,0.91081
5,0.034,0.332014,0.958171,0.922297
6,0.0203,0.388642,0.955253,0.918228
7,0.016,0.494645,0.950973,0.906555
8,0.0115,0.453396,0.956226,0.91912
9,0.0011,0.447169,0.958366,0.923812
10,0.0,0.470735,0.958366,0.925349


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ../geneformer-12L-30M and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


large_intestine


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Epoch,Training Loss,Validation Loss,Accuracy,Macro F1
1,0.364,0.394132,0.91754,0.856455
2,0.255,0.33699,0.934577,0.86597
3,0.1788,0.417955,0.929335,0.873945
4,0.1179,0.421227,0.93871,0.88586
5,0.063,0.44029,0.936794,0.885499
6,0.04,0.521299,0.940625,0.882971
7,0.0165,0.608636,0.938206,0.887446
8,0.0159,0.613944,0.941028,0.889703
9,0.0067,0.679225,0.940524,0.891661
10,0.0,0.692274,0.94123,0.892477


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ../geneformer-12L-30M and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


pancreas


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Epoch,Training Loss,Validation Loss,Accuracy,Macro F1
1,0.3622,0.400625,0.91612,0.85865
2,0.2575,0.360589,0.93016,0.86313
3,0.136,0.346444,0.942013,0.887012
4,0.0835,0.437124,0.940372,0.894332
5,0.0526,0.442392,0.941831,0.892462
6,0.0478,0.551121,0.937272,0.889416
7,0.0053,0.563099,0.943837,0.889084
8,0.0086,0.629666,0.939825,0.895991
9,0.0013,0.677695,0.941284,0.893701
10,0.0,0.6844,0.944384,0.897918


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ../geneformer-12L-30M and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


liver


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Epoch,Training Loss,Validation Loss,Accuracy,Macro F1
1,0.4384,0.493172,0.896558,0.778715
2,0.3321,0.375583,0.924202,0.821743
3,0.1875,0.450893,0.920278,0.807853
4,0.149,0.506308,0.923667,0.830291
5,0.1103,0.554555,0.923132,0.84196
6,0.0712,0.575345,0.926164,0.833691
7,0.0853,0.642671,0.918851,0.829214
8,0.0365,0.628874,0.924202,0.842332
9,0.0072,0.658359,0.92545,0.847783
10,0.0103,0.676417,0.924559,0.844086


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
