In [None]:
!pip install -U transformers datasets accelerate evaluate seqeval

In [150]:
import os
import random
from dataclasses import dataclass
from typing import List, Dict, Any

import numpy as np
import torch
from datasets import Dataset, DatasetDict
from datasets import load_dataset
from datasets import load_from_disk
from evaluate import load as load_metric
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
)

In [133]:
# Paths and checkpoint name shared by the rest of the notebook.
DATA_DIR = "../data"
MODEL_NAME = "bert-base-cased"

# Seed every RNG up front. The token-classification head is freshly
# initialised when the model is loaded, so without a seed each run starts
# from different weights and the reported metrics are not reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)  # kept last: returns None, so the cell prints nothing


#### Load dataset

In [134]:
# Load the preprocessed NER dataset that was saved to disk under DATA_DIR.
# The path is built from DATA_DIR so that moving the data folder only
# requires changing the config cell.
data = load_from_disk(os.path.join(DATA_DIR, "processed", "bio_ner"))

In [135]:

# Read the label vocabulary, one label per line; the line order defines the
# integer ids. A context manager is used so the file handle is closed promptly.
with open(os.path.join(DATA_DIR, "labels.txt"), encoding="utf-8") as labels_file:
    label_list: List[str] = [line.strip() for line in labels_file]

# Bidirectional lookups between integer ids and label strings.
id2label = dict(enumerate(label_list))
label2id = {label: idx for idx, label in id2label.items()}

In [168]:
# Load the dataset that is handed to the Trainer. The copy on disk already
# carries the input_ids, attention_mask and labels columns.
# The path is built from DATA_DIR instead of repeating the literal folder.
tokenized_datasets = load_from_disk(os.path.join(DATA_DIR, "processed", "bio_ner"))


In [170]:

# Load the pretrained checkpoint with a token-classification head and attach
# the project's label maps to the model config.
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_NAME,
    label2id=label2id,
    id2label=id2label,
)


Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


#### Data collator (dynamic padding for token classification)

In [169]:
# Load the tokenizer from the same checkpoint name as the model so the two
# cannot drift apart when MODEL_NAME is changed.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Collator that pads every batch (inputs and labels) to a common length.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

In [138]:
# Inspect the collator's settings (tokenizer, padding, label pad id).
data_collator

DataCollatorForTokenClassification(tokenizer=BertTokenizerFast(name_or_path='bert-base-cased', vocab_size=28996, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
), padding=True, max_length=None, pad_to_multiple_of=None, label_pad_token_id=-100, r

#### Metrics

In [171]:
# Load the seqeval metric through evaluate's load() (imported as load_metric).
seqeval = load_metric("seqeval")

def compute_metrics(eval_pred):
    """Score token-level predictions with seqeval.

    Args:
        eval_pred: A ``(logits, labels)`` pair. The class with the highest
            logit along the last axis is taken as the prediction for each
            position.

    Returns:
        dict with the overall ``precision``, ``recall``, ``f1`` and
        ``accuracy`` reported by seqeval.
    """
    logits, labels = eval_pred
    pred_ids = np.argmax(logits, axis=-1)

    references, predictions = [], []
    for pred_row, label_row in zip(pred_ids, labels):
        # Positions labelled -100 are ignored; keep only the scored pairs.
        scored = [(p, g) for p, g in zip(pred_row, label_row) if g != -100]
        references.append([id2label[g] for _, g in scored])
        predictions.append([id2label[p] for p, _ in scored])

    scores = seqeval.compute(predictions=predictions, references=references)

    # Report only the aggregate metrics.
    return {
        "precision": scores["overall_precision"],
        "recall": scores["overall_recall"],
        "f1": scores["overall_f1"],
        "accuracy": scores["overall_accuracy"],
    }


In [183]:
%%time

# NOTE(review): the output folder is named "distilbert-..." although
# MODEL_NAME is bert-base-cased — confirm whether it should be renamed.
args = TrainingArguments(
    output_dir="../outputs/distilbert-finetuned-ner",
    # Optimisation
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=1,
    fp16=torch.cuda.is_available(),  # mixed precision only when CUDA is available
    # Evaluate and checkpoint once per epoch; reload the best-F1 checkpoint at the end
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,  # a higher F1 is better
    # Disable MLflow/W&B logging
    report_to="none",
)

CPU times: total: 15.6 ms
Wall time: 2.22 s


In [184]:


# Wire the model, data, collator and metric function into a Trainer.
# NOTE(review): args sets num_train_epochs=1 with per-epoch evaluation, so an
# early-stopping patience of 2 presumably can never trigger — confirm intent.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=data_collator,
    processing_class=tokenizer,  # used here instead of the older `tokenizer=` argument
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    compute_metrics=compute_metrics,
    # Stop if the validation metric doesn't improve
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

In [175]:
# Show the available splits with their columns and row counts.
tokenized_datasets


DatasetDict({
    train: Dataset({
        features: ['tokens', 'ner_tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 4648
    })
    validation: Dataset({
        features: ['tokens', 'ner_tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 4657
    })
    test: Dataset({
        features: ['tokens', 'ner_tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 4861
    })
})

In [176]:
%%time
# Train the model, then evaluate on the Trainer's eval_dataset (the validation split).
# The test split is not scored in this cell.
trainer.train()
val_results = trainer.evaluate()
print("Validation:", val_results)




Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
1,0.1769,0.114732,0.761783,0.795365,0.778212,0.95933




Validation: {'eval_loss': 0.11473238468170166, 'eval_precision': 0.7617833191650137, 'eval_recall': 0.7953648915187377, 'eval_f1': 0.7782119940174651, 'eval_accuracy': 0.9593295992546288, 'eval_runtime': 248.3535, 'eval_samples_per_second': 18.752, 'eval_steps_per_second': 2.347, 'epoch': 1.0}
CPU times: total: 2h 38min 44s
Wall time: 30min
