In [2]:
#!/usr/bin/env python
# gene_prompt_trainer.py
# -----------------------------------------------------------
# Fine-tune a prompt-enabled Geneformer model for GENE
# classification using a single .dataset file.
# -----------------------------------------------------------

import argparse, json, yaml, pickle, random, datetime, pathlib
from pathlib import Path
from collections import defaultdict

import numpy as np
import torch, torch.nn as nn
from datasets import load_from_disk
from transformers import (
    BertForTokenClassification,
    Trainer, TrainingArguments, EarlyStoppingCallback,
)
from geneformer import DataCollatorForGeneClassification
import loralib as lora
import pandas as pd

# -----------------------------------------------------------
# 1 ▸ CLI
# -----------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_file", default="/fs/scratch/PCON0022/ch/Geneformer/examples/example_input_files/gc-30M_sample50k.dataset")
parser.add_argument("--gene_class_dict", default="/fs/scratch/PCON0022/ch/Geneformer/examples/example_input_files/dosage_sensitivity_TFs.pickle")
parser.add_argument("--token_dict", 
                default="/fs/scratch/PCON0022/ch/scPEFT_reproduction/Geneformer/geneformer/token_dictionary_gc95M.pkl")
                # default="/fs/scratch/PCON0022/ch/scPEFT_reproduction/Geneformer/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl")
parser.add_argument("--ckpt_dir", 
                default="/fs/scratch/PCON0022/ch/scPEFT_reproduction/geneformer_peft/Pretrain_ckpts/Pretrain_ckpts/geneformer-12L-30M-prompt")
                # default="/fs/scratch/PCON0022/ch/Geneformer/gf-6L-30M-i2048")
parser.add_argument("--output_root", default="/fs/scratch/PCON0022/ch/scPEFT_reproduction/geneformer_peft/example_py/outputs")


parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--lr",        type=float, default=5e-5)
parser.add_argument("--seed",      type=int, default=42)
args = parser.parse_args('')

torch.manual_seed(args.seed); random.seed(args.seed); np.random.seed(args.seed)

# -----------------------------------------------------------
# 2 ▸ Load data & 80/10/10 split
# -----------------------------------------------------------
full_ds = load_from_disk(args.dataset_file).shuffle(seed=args.seed)  # one .dataset only
tmp = full_ds.train_test_split(test_size=0.2, seed=args.seed)
val_test = tmp["test"].train_test_split(test_size=0.5, seed=args.seed + 1)
train_ds, eval_ds, test_ds = tmp["train"], val_test["train"], val_test["test"]

# -----------------------------------------------------------
# 3 ▸ Dict helpers
# -----------------------------------------------------------
def load_dict(pth):
    p = pathlib.Path(pth)
    with open(pth, "rb" if p.suffix == ".pkl" or p.suffix == ".pickle" else "r") as f:
        return (
            pickle.load(f) if p.suffix == ".pkl" or p.suffix == ".pickle"
            else json.load(f) if p.suffix == ".json"
            else yaml.safe_load(f)
        )

gene_class_dict = load_dict(args.gene_class_dict)      # {label: [ENS,…]}
token_dict      = load_dict(args.token_dict)           # {ENS: int_id}

# ↪ map gene token-id ➜ class-label
inverse_gene_dict = {
    token_dict[g]: cls for cls, genes in gene_class_dict.items() for g in genes if g in token_dict
}
class_id_dict = {cls: i for i, cls in enumerate(gene_class_dict.keys())}
id_class_dict = {v: k for k, v in class_id_dict.items()}

def label_example(ex):
    ex["labels"] = [
        class_id_dict.get(inverse_gene_dict.get(tok, None), -100)
        for tok in ex["input_ids"]
    ]
    return ex

# filter out cells without any labelled genes, then add "labels"
target_tokens = set(inverse_gene_dict.keys())
def keep_cell(ex): return not target_tokens.isdisjoint(ex["input_ids"])
train_ds = train_ds.filter(keep_cell, num_proc=16).map(label_example, num_proc=16)
eval_ds  =  eval_ds.filter(keep_cell, num_proc=16).map(label_example, num_proc=16)
test_ds  =  test_ds.filter(keep_cell, num_proc=16).map(label_example, num_proc=16)

# -----------------------------------------------------------
# 4 ▸ Collator 
# -----------------------------------------------------------
data_collator = DataCollatorForGeneClassification(token_dictionary=token_dict)

# -----------------------------------------------------------
# 5 ▸ Model: prompt-enabled checkpoint + PEFT unfreeze
# -----------------------------------------------------------
model = BertForTokenClassification.from_pretrained(
    args.ckpt_dir,
    num_labels=len(class_id_dict),
    ignore_mismatched_sizes=True,
).to("cuda")

prompt_types = [p.strip() for p in model.config.prompt_type.split(",") if p.strip()]

if "lora" in prompt_types:
    lora.mark_only_lora_as_trainable(model, bias="lora_only")
if "Gene_token_prompt" in prompt_types:
    for n,p in model.named_parameters():
        p.requires_grad = ("bert.adapter" in n) or ("classifier" in n)
if "encoder_prompt" in prompt_types:
    for n,p in model.named_parameters():
        p.requires_grad = (("Space_Adapter" in n) or ("MLP_Adapter" in n) or ("classifier" in n))
if "prefix_prompt" in prompt_types:
    for n,p in model.named_parameters():
        p.requires_grad = ("prompt_embeddings" in n) or ("classifier" in n)

# -----------------------------------------------------------
# 6 ▸ Trainer subclass – prefix-mask & token-level CE loss
# -----------------------------------------------------------
class PromptTrainer(Trainer):
    def __init__(self,*a,prompt_types=None,**kw):
        super().__init__(*a,**kw); self.prompt_types=prompt_types
    def compute_loss(self, model, inputs, return_outputs=False):
        if "prefix_prompt" in self.prompt_types:
            bs = inputs["input_ids"].size(0)
            pre = torch.ones(bs, model.config.num_token, device=inputs["input_ids"].device)
            inputs["attention_mask"] = torch.cat(
                (torch.cat((inputs["attention_mask"][:,:1], pre), dim=1),
                 inputs["attention_mask"][:,1:]), dim=1)
        out = model(**inputs)
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(out.logits.view(-1, model.num_labels),
                        inputs["labels"].view(-1))
        return (loss, out) if return_outputs else loss

# -----------------------------------------------------------
# 7 ▸ Metrics
# -----------------------------------------------------------
from sklearn.metrics import balanced_accuracy_score, precision_score, recall_score, f1_score
def compute_metrics(pred):
    lbls = pred.label_ids.reshape(-1)
    prds = pred.predictions.argmax(-1).reshape(-1)
    mask = lbls != -100
    lbls, prds = lbls[mask], prds[mask]
    return {
        "accuracy": balanced_accuracy_score(lbls, prds),
        "precision": precision_score(lbls, prds, average="macro"),
        "recall": recall_score(lbls, prds, average="macro"),
        "f1": f1_score(lbls, prds, average="macro"),
    }

# -----------------------------------------------------------
# 8 ▸ TrainingArguments & run folder
# -----------------------------------------------------------
run_dir = (Path(args.output_root)/("_".join(prompt_types) or "noprompt")/
           Path(args.dataset_file).stem/
           datetime.datetime.now().strftime("%y%m%d_%H%M%S"))
run_dir.mkdir(parents=True, exist_ok=True); model.config.save_pretrained(run_dir)

training_args = TrainingArguments(
    output_dir=str(run_dir),
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    learning_rate=args.lr,
    num_train_epochs=args.epochs,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    warmup_steps=500,
    report_to="none",
)

# -----------------------------------------------------------
# 9 ▸ Train & evaluate
# -----------------------------------------------------------
trainer = PromptTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
    prompt_types=prompt_types,
)

trainer.train()
test_metrics = trainer.evaluate(test_ds, metric_key_prefix="test")
trainer.save_metrics("test", test_metrics)
trainer.save_model(run_dir)
print(test_metrics)


Some weights of BertForTokenClassification were not initialized from the model checkpoint at /fs/scratch/PCON0022/ch/scPEFT_reproduction/geneformer_peft/Pretrain_ckpts/Pretrain_ckpts/geneformer-12L-30M-prompt and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.4706,0.432639,0.665003,0.819234,0.665003,0.691527
2,0.4121,0.392743,0.717448,0.830774,0.717448,0.747786
3,0.3944,0.382963,0.727679,0.835258,0.727679,0.758099


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


early stopping required metric_for_best_model, but did not find eval_loss so early stopping is disabled


{'test_loss': 0.3817903399467468, 'test_accuracy': 0.7267683781159545, 'test_precision': 0.8325678037322718, 'test_recall': 0.7267683781159545, 'test_f1': 0.7568869250096371, 'test_runtime': 64.7934, 'test_samples_per_second': 67.584, 'test_steps_per_second': 2.114, 'epoch': 3.0}
