In [20]:
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, Trainer, TrainingArguments, AutoModelForTokenClassification

In [6]:
model_checkpoint = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
data = {
    'tokens': [
        ['The', 'Patient', 'was', 'diagnosed', 'with', 'diabetes']
    ],
    'labels': [[0, 0, 0, 0, 0, 1]] # 0 for non-medical, 1 for medical terms
}
label_list = ['non-medical', 'medical']
id2label = {i: label for i, label in enumerate(label_list)}
label2id = {label: i for i, label in id2label.items()}
id2label

{0: 'non-medical', 1: 'medical'}

In [11]:
dataset = Dataset.from_dict(data)
dataset

Dataset({
    features: ['tokens', 'labels'],
    num_rows: 1
})

In [45]:
def align_labels_with_tokens(examples):
    tokenized_inputs = tokenizer(
        examples['tokens'],
        truncation=True,
        is_split_into_words=True,
        return_offsets_mapping=True
    )
    labels = []
    for i, label in enumerate(examples['labels']):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)
            else:
                label_ids.append(label[word_idx])
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs

tokenized_datasets = dataset.map(align_labels_with_tokens, batched=True)
tokenized_datasets = dataset.map(align_labels_with_tokens, batched=True)
tokenized_datasets = tokenized_datasets.rename_column("labels", "label")
tokenized_datasets

Map: 100%|██████████| 1/1 [00:00<00:00, 277.92 examples/s]
Map: 100%|██████████| 1/1 [00:00<00:00, 306.58 examples/s]


Dataset({
    features: ['tokens', 'label', 'input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping'],
    num_rows: 1
})

In [46]:
import torch
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
model = AutoModelForTokenClassification.from_pretrained(
    model_checkpoint,
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id
)
model.to(device)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12

In [47]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = logits.argmax(axis=-1)
    accuracy = (predictions == labels).mean()
    return {'accuracy': accuracy}

In [48]:
args = TrainingArguments(
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    output_dir='output',
    num_train_epochs=2,
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets,
    eval_dataset=tokenized_datasets,
    compute_metrics=compute_metrics,
)

In [49]:
trainer.train()



Step,Training Loss


TrainOutput(global_step=2, training_loss=0.8022475838661194, metrics={'train_runtime': 2.761, 'train_samples_per_second': 0.724, 'train_steps_per_second': 0.724, 'total_flos': 8165523648.0, 'train_loss': 0.8022475838661194, 'epoch': 2.0})

In [50]:
trainer.evaluate()



{'eval_loss': 0.5017200112342834,
 'eval_accuracy': 0.625,
 'eval_runtime': 0.0256,
 'eval_samples_per_second': 39.07,
 'eval_steps_per_second': 39.07,
 'epoch': 2.0}

In [51]:
model.save_pretrained('ner-healthcare-sample')
tokenizer.save_pretrained('ner-healthcare-sample')

('ner-healthcare-sample/tokenizer_config.json',
 'ner-healthcare-sample/special_tokens_map.json',
 'ner-healthcare-sample/vocab.txt',
 'ner-healthcare-sample/added_tokens.json',
 'ner-healthcare-sample/tokenizer.json')

In [56]:
model_path = 'ner-healthcare-sample'
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForTokenClassification.from_pretrained(model_path)
model.eval()
sentence = 'The patient had diabetes'
inputs = tokenizer(sentence, return_tensors='pt', is_split_into_words=False)
outputs = model(**inputs)

In [57]:
logits = outputs.logits
predicted_label = torch.argmax(logits, dim=-1)
predicted_label

tensor([[1, 0, 0, 0, 1, 0]])

In [58]:
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
tokens

['[CLS]', 'the', 'patient', 'had', 'diabetes', '[SEP]']

In [59]:
predictions = predicted_label[0].tolist()
for token, prediction in zip(tokens, predictions):
    print(f"{token}: {id2label[prediction]}")

[CLS]: medical
the: non-medical
patient: non-medical
had: non-medical
diabetes: medical
[SEP]: non-medical
