In [1]:
import torch
from datasets import load_from_disk

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed torch's global RNG so that the freshly initialised classification head
# (and anything else drawing from torch's RNG) is reproducible across runs.
torch.manual_seed(42)

# Train on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

Load the data. For testing, we can use a small GTDB sample that contains 1000 bacterial genomes.

In [3]:
# Root directory of the toy GTDB dataset (a saved DatasetDict with 'train' and 'validation' splits).
basedir = "/leonardo_scratch/fast/EUHPC_R04_194/training_datasets/gtdb/toy_dataset1024_bacterias/"
# NOTE(review): `batch_size` is not used by the visible cells; the TrainingArguments
# cell sets per_device_*_batch_size=32 directly. Confirm which value is intended.
batch_size = 64

# Load the saved dataset once and take both splits from it, instead of reading the
# same directory from disk twice.
dataset_dict = load_from_disk(basedir)
dataset_training = dataset_dict['train']
dataset_test = dataset_dict['validation']

For a quick test, we can also keep only a small subset of the classes (e.g. 5–10).

In [4]:
num_classes = 5
# Keep the first `num_classes` labels that occur in the validation split, so every kept
# class has at least one validation example. Fewer labels are kept if the split has fewer.
allowed_labels = dataset_test.unique('labels')[:num_classes]
# Set for O(1) membership tests inside the filter (labels are used as dict keys later,
# so they are hashable).
allowed_label_set = set(allowed_labels)


def keep_allowed_labels(batch):
    """Return a keep/drop mask for a batch: True where the example's label is in `allowed_label_set`."""
    return [label in allowed_label_set for label in batch["labels"]]


# Apply the same predicate to both splits.
dataset_training = dataset_training.filter(keep_allowed_labels, batched=True)
dataset_test = dataset_test.filter(keep_allowed_labels, batched=True)

# Fail early rather than during training if none of the chosen classes occur in train.
assert len(dataset_training) > 0, "no training examples left after label filtering"

print("After filtering: ")
print(f"Size of train {len(dataset_training)}")
print(f"Size of test {len(dataset_test)}")
print(f"Number of unique classes {len(allowed_labels)}")

After filtering: 
Size of train 3570
Size of test 10
Number of unique classes 5


We need to label-encode the GTDB labels, i.e. map each label to an integer ID in the range [0, number of classes).

In [5]:
# Label vocabulary: training labels first (their ids are unchanged from a train-only
# vocabulary), then any label that occurs only in the validation split. Without the
# second part, `encode_batch` raises KeyError on the validation split whenever filtering
# left a class with no training examples.
train_cats = dataset_training.unique('labels')
seen_cats = set(train_cats)
val_only_cats = [c for c in dataset_test.unique('labels') if c not in seen_cats]
if val_only_cats:
    print(f"Warning: {len(val_only_cats)} validation label(s) have no training examples: {val_only_cats}")
unique_cats = list(train_cats) + val_only_cats
cat2id = {cat: i for i, cat in enumerate(unique_cats)}


def encode_batch(batch):
    """Replace the string/category labels of a batch with their integer ids from `cat2id`."""
    batch['labels'] = [cat2id[c] for c in batch['labels']]
    return batch


dataset_training = dataset_training.map(encode_batch, batched=True)
dataset_test = dataset_test.map(encode_batch, batched=True)

# Number of classes actually present in the encoded label space.
num_of_classes = len(unique_cats)
print(f"Size of train {len(dataset_training)}")
print(f"Number of unique classes {num_of_classes}")

Map: 100%|██████████| 10/10 [00:00<00:00, 991.40 examples/s]

Size of train 3570
Number of unique classes 5





Load the model to fine-tune. We can use the version from `models.py` or from `models2.py`.

In [6]:
# Explicit import instead of `import *`, so it is clear where the model class comes from.
from prokbert.models import ProkBertForCurricularClassification

bert_model_path = "neuralbioinfo/prokbert-mini-long"
# Load the pretrained ProkBERT encoder together with the curricular classification head.
# The head weights are not in the checkpoint and are freshly initialised on load.
# NOTE(review): weights are loaded in bfloat16 and later updated by AdamW; optimizer steps
# applied directly to bf16 weights can lose precision. Consider loading in float32 and
# relying on `bf16=True` mixed precision in TrainingArguments. Confirm.
model = ProkBertForCurricularClassification.from_pretrained(
    bert_model_path,
    bert_base_model=bert_model_path,
    torch_dtype=torch.bfloat16,
    # Size the head by the number of classes produced by the label encoding, not by the
    # requested upper bound `num_classes`: `allowed_labels` is a `[:num_classes]` slice
    # and may hold fewer labels.
    curricular_num_labels=num_of_classes,
    curricular_face_m=0.5,
    curricular_face_s=64.0,
    classification_dropout_rate=0.1,
    curriculum_hidden_size=128,
)

model = model.to(device)

num_params = sum(p.numel() for p in model.parameters())
print(f"Number of params of the model: {num_params}")

Some weights of ProkBertForCurricularClassification were not initialized from the model checkpoint at neuralbioinfo/prokbert-mini-long and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'curricular_face.kernel', 'curricular_face.t', 'linear.bias', 'linear.weight', 'weighting_layer.bias', 'weighting_layer.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Number of params of the model: 26630017


Next, split the model parameters into two groups and freeze everything except the classification head, to speed up training. In a real training run we train both groups, but with different learning rates.

In [7]:
# `transformers.AdamW` is deprecated (and removed in recent transformers releases) in
# favour of PyTorch's implementation.
from torch.optim import AdamW

print("Set up learning utilities")
bert_params = []  # encoder parameters: frozen here, kept for a two-learning-rate setup
head_params = []  # all remaining parameters: the only ones the optimizer updates

for name, param in model.named_parameters():
    # Substring match: every parameter whose name contains "bert" counts as encoder.
    # NOTE(review): this also freezes `bert.pooler.*`, which the load log reports as newly
    # initialized. Confirm the head does not depend on the (random) pooler output.
    if "bert" in name:
        param.requires_grad = False
        bert_params.append(param)
    else:
        head_params.append(param)

# eps=1e-6 and weight_decay=0.0 reproduce the defaults of the former `transformers.AdamW`
# (torch's own defaults are eps=1e-8 and weight_decay=0.01).
optimizer = AdamW(
    [{'params': head_params, 'lr': 0.001}],
    eps=1e-6,
    weight_decay=0.0,
)
print(f"Num trainable params: {sum(p.numel() for p in head_params)}")

Set up learning utilities
Num trainable params: 50305




Set up the learning-rate scheduler and the data collator.

In [8]:
from prokbert.tokenizer import LCATokenizer
from transformers import DataCollatorWithPadding, get_scheduler

# Schedule length. `max_steps` is reused by TrainingArguments, so the scheduler spans
# exactly the optimizer steps the Trainer will take.
num_warmup = 0
max_steps = 10

# Cosine learning-rate schedule over `max_steps`, with `num_warmup` warm-up steps.
scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=num_warmup,
    num_training_steps=max_steps,
)

# k-mer tokenizer (kmer=6, shift=2) built from the prokbert-base-dna6 vocabulary file.
tokenizer = LCATokenizer(
    kmer=6,
    shift=2,
    vocab_file="/leonardo_work/EUHPC_R04_194/prokbert/src/prokbert/data/prokbert_vocabs/prokbert-base-dna6/vocab.txt",
)

# Collator that pads the examples of each batch using this tokenizer.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

Set up the training process with the Hugging Face `Trainer`.

In [9]:
from transformers import TrainingArguments, Trainer

output_model_path = "./"  # path to save trained models
training_args = TrainingArguments(
    output_dir=output_model_path,
    eval_strategy="steps",
    overwrite_output_dir=False,
    logging_strategy="steps",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=1,
    # Same variable as the scheduler's `num_training_steps`, so the two stay in sync.
    max_steps=max_steps,
    # NOTE(review): the Trainer receives a pre-built optimizer via `optimizers=`, so this
    # weight decay is presumably not applied; the AdamW instance defines its own.
    weight_decay=0.001,
    logging_steps=1,
    # "none" disables all experiment-tracking integrations. `None` does NOT: in
    # transformers 4.x it falls back to every installed integration (e.g. wandb).
    report_to="none",
    eval_steps=1,
    eval_accumulation_steps=1,
    dataloader_num_workers=1,
    dataloader_prefetch_factor=1,
    torch_compile=False,
    bf16=True,
    save_total_limit=1,              # limit the total amount of checkpoints
    # With load_best_model_at_end=True, save_steps must be a round multiple of eval_steps.
    save_steps=10,
    load_best_model_at_end=True,
    max_grad_norm=1.0,  # <- this enables gradient clipping!
    ddp_find_unused_parameters=True,
)

In [10]:
# Everything the Trainer needs, grouped by role. The head-only AdamW and the cosine
# scheduler built in earlier cells are handed over explicitly through `optimizers`.
trainer_kwargs = dict(
    # model and run configuration
    model=model,
    args=training_args,
    # data: training split, validation split, padding collator and tokenizer
    train_dataset=dataset_training,
    eval_dataset=dataset_test,
    data_collator=data_collator,
    processing_class=tokenizer,
    # optimisation
    optimizers=(optimizer, scheduler),
)
trainer = Trainer(**trainer_kwargs)

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [11]:
# Run fine-tuning for `max_steps` optimizer steps, evaluating every `eval_steps` and
# logging every `logging_steps` (both set in TrainingArguments). The return value is left
# as the cell's last expression so the notebook displays it.
trainer.train()

  ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)


Step,Training Loss,Validation Loss
1,32.0113,15.271637
2,15.6335,21.0742


: 