In [18]:
import os
from typing import List, Dict, Tuple
from transformers import AutoTokenizer
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
import evaluate
import numpy as np

# Preprocessing

In [2]:
class Preprocessing_Maccrobat:
    def __init__(self, dataset_folder, tokenizer):
        self.file_ids = [
            f.split(".")[0] for f in os.listdir(dataset_folder) if f.endswith(".txt")
        ]

        self.text_files = [f + ".txt" for f in self.file_ids]
        self.anno_files = [f + ".ann" for f in self.file_ids]

        self.num_samples = len(self.file_ids)

        self.texts: List[str] = []
        for i in range(self.num_samples):
            file_path = os.path.join(dataset_folder, self.text_files[i])
            with open(file_path, "r") as f:
                self.texts.append(f.read())

        self.tags: List[Dict[str, str]] = []
        for i in range(self.num_samples):
            file_path = os.path.join(dataset_folder, self.anno_files[i])
            with open(file_path, "r") as f:
                text_bound_ann = [
                    t.split("\t") for t in f.read().split("\n") if t.startswith("T")
                ]
                text_bound_lst = []
                for text_b in text_bound_ann:
                    label = text_b[1].split(" ")
                    try:
                        _ = int(label[1])
                        _ = int(label[2])
                        tag = {
                            "text": text_b[-1],
                            "label": label[0],
                            "start": label[1],
                            "end": label[2],
                        }
                        text_bound_lst.append(tag)
                    except:
                        pass

                self.tags.append(text_bound_lst)
        self.tokenizer = tokenizer

    def process(self) -> Tuple[List[List[str]], List[List[str]]]:
        input_texts = []
        input_labels = []

        for idx in range(self.num_samples):
            full_text = self.texts[idx]
            tags = self.tags[idx]

            label_offset = []
            continuous_label_offset = []
            for tag in tags:
                offset = list(range(int(tag["start"]), int(tag["end"]) + 1))
                label_offset.append(offset)  # 345
                continuous_label_offset.extend(offset)  #  345

            all_offset = list(range(len(full_text)))
            zero_offset = [
                offset for offset in all_offset if offset not in continuous_label_offset
            ]
            zero_offset = Preprocessing_Maccrobat.find_continuous_ranges(
                zero_offset
            )  # 012 67

            self.tokens = []
            self.labels = []
            self._merge_offset(full_text, tags, zero_offset, label_offset)
            assert len(self.tokens) == len(
                self.labels
            ), f"Length of tokens and labels are not equal"

            input_texts.append(self.tokens)
            input_labels.append(self.labels)

        return input_texts, input_labels

    def _merge_offset(self, full_text, tags, zero_offset, label_offset):
        # zero: [[0,1,2], [6,7]] label: [[3,4,5]]
        i = j = 0
        while i < len(zero_offset) and j < len(label_offset):
            if zero_offset[i][0] < label_offset[j][0]:
                self._add_zero(full_text, zero_offset, i)
                i += 1
            else:
                self._add_label(full_text, label_offset, j, tags)
                j += 1

        while i < len(zero_offset):
            self._add_zero(full_text, zero_offset, i)
            i += 1

        while j < len(label_offset):
            self._add_label(full_text, label_offset, j, tags)
            j += 1

    def _add_zero(self, full_text, offset, index):
        start, *_, end = (
            offset[index]
            if len(offset[index]) > 1
            else (offset[index][0], offset[index][0] + 1)
        )
        text = full_text[start:end]
        text_tokens = self.tokenizer.tokenize(text)

        self.tokens.extend(text_tokens)
        self.labels.extend(["O"] * len(text_tokens))

    def _add_label(self, full_text, offset, index, tags):
        start, *_, end = (
            offset[index]
            if len(offset[index]) > 1
            else (offset[index][0], offset[index][0] + 1)
        )
        text = full_text[start:end]
        text_tokens = self.tokenizer.tokenize(text)

        self.tokens.extend(text_tokens)
        self.labels.extend(
            [f"B-{tags[index]['label']}"]
            + [f"I-{tags[index]['label']}"] * (len(text_tokens) - 1)
        )

    @staticmethod
    def build_label2id(tokens: List[List[str]]):
        label2id = {}
        id_counter = 0
        for token in [token for sublist in tokens for token in sublist]:
            if token not in label2id:
                label2id[token] = id_counter
                id_counter += 1
        return label2id

    @staticmethod
    def find_continuous_ranges(data: List[int]):  # [0, 1, 2, 6, 7]
        if not data:
            return []
        ranges = []
        start = data[0]  # 0
        prev = data[0]  # 0
        for number in data[1:]:  # [1, 2, 6, 7]
            if number != prev + 1:
                ranges.append(list(range(start, prev + 1)))
                start = number
            prev = number
        ranges.append(list(range(start, prev + 1)))
        return ranges

In [4]:
tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")

dataset_folder = "./MACCROBAT2018"

Maccrobat_builder = Preprocessing_Maccrobat(dataset_folder, tokenizer)
input_texts, input_labels = Maccrobat_builder.process()

# print
print(input_texts[0])
print(input_labels[0])

['a', '70', '-', 'year', '-', 'old', 'man', 'was', 'referred', 'to', 'our', 'hospital', 'for', 'gas', '##tric', 'cancer', 'that', 'was', 'detected', 'during', 'screening', 'by', 'es', '##op', '##ha', '##go', '##gas', '##tro', '##du', '##ode', '##nos', '##co', '##py', '(', 'e', '##g', '##d', ')', 'no', 'significant', 'medical', 'history', 'was', 'identified', ',', 'except', 'd', '##ys', '##uria', 'caused', 'by', 'bladder', 'contraction', 'initial', 'laboratory', 'data', 'showed', 'a', 'serum', 'level', 'of', 'af', '##p', 'of', '32', '.', '3', 'ng', '/', 'ml', '(', 'normal', 'range', ':', '0', '-', '15', 'ng', '/', 'ml', ')', ',', 'but', 'which', 'included', 'other', 'tumor', 'markers', ',', 'such', 'as', ',', 'car', '##cino', '##em', '##bry', '##onic', 'antigen', '(', 'ce', '##a', ')', 'and', 'car', '##bo', '##hy', '##dra', '##te', 'antigen', '19', '-', '9', '(', 'ca', '##19', '-', '9', ')', 'no', 'other', 'abnormal', '##ity', 'e', '##g', '##d', 'revealed', 'a', 'mass', 'ul', '##cer', '

In [5]:
label2id = Preprocessing_Maccrobat.build_label2id(input_labels)
id2label = {v: k for k, v in label2id.items()}

# print
print(label2id)
print(id2label)

{'O': 0, 'B-Age': 1, 'I-Age': 2, 'B-Sex': 3, 'B-Clinical_event': 4, 'B-Nonbiological_location': 5, 'B-Biological_structure': 6, 'I-Biological_structure': 7, 'B-Disease_disorder': 8, 'B-Diagnostic_procedure': 9, 'I-Diagnostic_procedure': 10, 'B-History': 11, 'I-History': 12, 'B-Sign_symptom': 13, 'I-Sign_symptom': 14, 'B-Detailed_description': 15, 'I-Detailed_description': 16, 'B-Lab_value': 17, 'I-Lab_value': 18, 'B-Coreference': 19, 'I-Coreference': 20, 'B-Distance': 21, 'I-Distance': 22, 'I-Disease_disorder': 23, 'B-Therapeutic_procedure': 24, 'I-Therapeutic_procedure': 25, 'B-Date': 26, 'I-Date': 27, 'B-Medication': 28, 'I-Medication': 29, 'B-Frequency': 30, 'I-Frequency': 31, 'B-Dosage': 32, 'I-Dosage': 33, 'I-Clinical_event': 34, 'B-Color': 35, 'I-Color': 36, 'B-Shape': 37, 'I-Shape': 38, 'B-Severity': 39, 'B-Duration': 40, 'I-Duration': 41, 'B-Administration': 42, 'I-Administration': 43, 'B-Personal_background': 44, 'B-Activity': 45, 'I-Activity': 46, 'B-Outcome': 47, 'I-Nonbiolo

In [7]:
inputs_train, inputs_val, labels_train, labels_val = train_test_split(
    input_texts, input_labels, test_size=0.2, random_state=42
)

print(len(inputs_train), len(inputs_val), len(labels_train), len(labels_val))

160 40 160 40


# Dataloader

In [9]:
MAX_LEN = 512


class NER_Dataset(Dataset):
    def __init__(self, input_texts, input_labels, tokenizer, label2id, max_len=MAX_LEN):
        super().__init__()
        self.tokens = input_texts
        self.labels = input_labels
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.max_len = max_len

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        input_token = self.tokens[idx]
        label_token = [self.label2id[label] for label in self.labels[idx]]

        input_token = self.tokenizer.convert_tokens_to_ids(input_token)
        attention_mask = [1] * len(input_token)

        input_ids = self.pad_and_truncate(
            input_token, pad_id=self.tokenizer.pad_token_id
        )
        labels = self.pad_and_truncate(label_token, pad_id=0)
        attention_mask = self.pad_and_truncate(attention_mask, pad_id=0)

        return {
            "input_ids": torch.as_tensor(input_ids),
            "labels": torch.as_tensor(labels),
            "attention_mask": torch.as_tensor(attention_mask),
        }

    def pad_and_truncate(self, inputs: List[int], pad_id: int):
        if len(inputs) < self.max_len:
            padded_inputs = inputs + [pad_id] * (self.max_len - len(inputs))
        else:
            padded_inputs = inputs[: self.max_len]
        return padded_inputs

    def label2id(self, labels: List[str]):
        return [self.label2id[label] for label in labels]

In [10]:
train_set = NER_Dataset(inputs_train, labels_train, tokenizer, label2id)
val_set = NER_Dataset(inputs_val, labels_val, tokenizer, label2id)

In [13]:
# print
print(train_set[0]["input_ids"])
print(train_set[0]["labels"])

tensor([ 1999,  2254,  2268,  1037,  5401,  1011,  2095,  1011,  2214,  2450,
        12636, 20630,  3591,  2007,  2460,  2791,  1997,  3052,  2016,  2018,
         1037,  1015,  1011,  3204,  2381,  1997,  2460,  2791,  1997,  3052,
         6555,  1998,  1037,  1999,  6912,  3977, 16612,  9885,  3905,  2000,
         1040,  7274,  2361, 22084, 10256,  2016,  2988,  2053,  3176,  8030,
         2012,  2287,  2871,  2016,  2018,  2042, 11441,  2007,  1037,  7388,
         4456,  2187,  1011, 11536,  1056,  2487, 24700,  2487,  2754,  2462,
         2050,  2754,  2462,  2050,  1010,  1056,  2487, 24700,  2487,  1010,
         2187,  1011, 11536,  7388,  4456,  3988,  3949,  2018,  2443,  1037,
        15116, 22471, 16940,  1998, 13045,  4487, 11393,  7542, 22260,  9386,
         2854,  2016,  3525,  9601,  1997,  2079,  2595,  7242,  1006,  4293,
        11460,  1013, 25525,  1018, 12709,  1010,  2628,  2011,  1022, 12709,
         1997, 22330, 20464,  7361, 15006, 21890, 24284,  2777, 

# Modeling

In [15]:
model = AutoModelForTokenClassification.from_pretrained(
    "d4data/biomedical-ner-all",
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)

model

Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at d4data/biomedical-ner-all and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([84]) in the checkpoint and torch.Size([83]) in the model instantiated
- classifier.weight: found shape torch.Size([84, 768]) in the checkpoint and torch.Size([83, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DistilBertForTokenClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
   

# Metric

In [17]:
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    mask = labels != 0
    predictions = np.argmax(predictions, axis=-1)
    return accuracy.compute(predictions=predictions[mask], references=labels[mask])

# Trainer

In [19]:
training_args = TrainingArguments(
    output_dir="ner-biomedical-maccrobat2018",
    learning_rate=1e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=20,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    load_best_model_at_end=True,
    optim="adamw_torch",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=val_set,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

  trainer = Trainer(


In [20]:
%%time
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,2.774,1.69574,0.337673
2,1.4472,0.9843,0.591136
3,0.9143,0.718306,0.703509
4,0.6507,0.618407,0.752539
5,0.5017,0.57702,0.763989
6,0.4034,0.552041,0.784118
7,0.3268,0.536559,0.785596
8,0.271,0.539278,0.786057
9,0.2259,0.542826,0.793536
10,0.192,0.539999,0.796953


CPU times: user 2min 37s, sys: 8.4 s, total: 2min 46s
Wall time: 2min 46s


TrainOutput(global_step=200, training_loss=0.4418525362014771, metrics={'train_runtime': 166.2238, 'train_samples_per_second': 19.251, 'train_steps_per_second': 1.203, 'total_flos': 418702245888000.0, 'train_loss': 0.4418525362014771, 'epoch': 20.0})

In [21]:
# trainer.push_to_hub(
#     commit_message="Training complete",
#     token="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"
# )

# Inference

In [22]:
%%time
test_sentence = """A 48 year-old female presented with vaginal bleeding and abnormal 
Pap smears. Upon diagnosis of invasive non-keratinizing SCC of the cervix, she 
underwent a radical hysterectomy with salpingo-oophorectomy which demonstrated 
positive spread to the pelvic lymph nodes and the parametrium. Pathological 
examination revealed that the tumour also extensively involved the lower uterine 
segment.
"""

# tokenization
input = torch.as_tensor([tokenizer.convert_tokens_to_ids(test_sentence.split())])
input = input.to("cuda")

# prediction
outputs = model(input)
_, preds = torch.max(outputs.logits, -1)
preds = preds[0].cpu().numpy()

# decode
for token, pred in zip(test_sentence.split(), preds):
    print(f"{token}\t{id2label[pred]}")

A	O
48	B-Lab_value
year-old	I-Age
female	B-Sex
presented	O
with	O
vaginal	O
bleeding	B-Sign_symptom
and	O
abnormal	B-Lab_value
Pap	B-Lab_value
smears.	I-Detailed_description
Upon	O
diagnosis	O
of	O
invasive	B-Detailed_description
non-keratinizing	O
SCC	O
of	O
the	O
cervix,	O
she	O
underwent	O
a	O
radical	O
hysterectomy	B-Lab_value
with	O
salpingo-oophorectomy	O
which	O
demonstrated	O
positive	B-Lab_value
spread	I-Lab_value
to	O
the	O
pelvic	B-Biological_structure
lymph	I-Biological_structure
nodes	O
and	O
the	O
parametrium.	B-Diagnostic_procedure
Pathological	I-Diagnostic_procedure
examination	O
revealed	O
that	O
the	O
tumour	O
also	O
extensively	O
involved	O
the	O
lower	O
uterine	I-Detailed_description
segment.	I-Detailed_description
CPU times: user 10.8 ms, sys: 13.9 ms, total: 24.7 ms
Wall time: 31 ms
