<a href="https://colab.research.google.com/github/conglapgit45/Medical_NER/blob/main/Medical_NER.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. Processing

In [1]:
!unzip /content/MACCROBAT2018.zip -d /content/MACCROBAT2018

Archive:  /content/MACCROBAT2018.zip
  inflating: /content/MACCROBAT2018/15939911.ann  
  inflating: /content/MACCROBAT2018/15939911.txt  
  inflating: /content/MACCROBAT2018/16778410.ann  
  inflating: /content/MACCROBAT2018/16778410.txt  
  inflating: /content/MACCROBAT2018/17803823.ann  
  inflating: /content/MACCROBAT2018/17803823.txt  
  inflating: /content/MACCROBAT2018/18236639.ann  
  inflating: /content/MACCROBAT2018/18236639.txt  
  inflating: /content/MACCROBAT2018/18258107.ann  
  inflating: /content/MACCROBAT2018/18258107.txt  
  inflating: /content/MACCROBAT2018/18416479.ann  
  inflating: /content/MACCROBAT2018/18416479.txt  
  inflating: /content/MACCROBAT2018/18561524.ann  
  inflating: /content/MACCROBAT2018/18561524.txt  
  inflating: /content/MACCROBAT2018/18666334.ann  
  inflating: /content/MACCROBAT2018/18666334.txt  
  inflating: /content/MACCROBAT2018/18787726.ann  
  inflating: /content/MACCROBAT2018/18787726.txt  
  inflating: /content/MACCROBAT2018/18815636.

In [2]:
import os
from typing import List, Dict, Tuple

class Preprocessing_Maccrobat:
    def __init__(self, dataset_folder, tokenizer):
        self.file_ids = [f.split(".")[0] for f in os.listdir(dataset_folder) if f.endswith ('.txt')]

        self.text_files = [f + ".txt" for f in self.file_ids]
        self.anno_files = [f + ".ann" for f in self.file_ids]

        self.num_samples = len(self.file_ids)

        self.texts:List[str] = []

        for i in range(self.num_samples):
            file_path = os.path.join(dataset_folder, self.text_files[i])
            with open(file_path, 'r') as f:
                self.texts.append(f.read())

        self.tags:List[Dict[str, str]] = []
        for i in range(self.num_samples):
            file_path = os.path.join(dataset_folder, self.anno_files[i])
            with open(file_path, 'r') as f:
                text_bound_ann = [t.split("\t") for t in f.read().split ("\n") if t.startswith ("T")]
                text_bound_lst = []
                for text_b in text_bound_ann:
                    label = text_b[1].split(" ")
                    try:
                        _ = int(label[1])
                        _ = int(label[2])
                        tag = {
                            "text": text_b[-1],
                            "label": label[0],
                            "start": label[1],
                            "end": label[2]
                        }
                        text_bound_lst.append(tag)
                    except:
                        pass

                self.tags.append(text_bound_lst)
        self.tokenizer = tokenizer

    def process(self) -> Tuple[List[List[str]], List[List[str]]]:
        input_texts = []
        input_labels = []

        for idx in range(self.num_samples):
            full_text = self.texts[idx]
            tags = self.tags[idx]

            label_offset = []
            continuous_label_offset = []
            for tag in tags:
                offset = list(range(int(tag["start"]), int(tag["end"]) + 1))
                label_offset.append(offset)
                continuous_label_offset.extend(offset)

            all_offset = list(range(len(full_text)))
            zero_offset = [offset for offset in all_offset if offset not in continuous_label_offset]
            zero_offset = Preprocessing_Maccrobat.find_continuous_ranges(zero_offset)

            self.tokens = []
            self.labels = []
            self._merge_offset(full_text, tags, zero_offset, label_offset)
            assert len(self.tokens ) == len(self.labels), f"Length of tokens and labelsare not equal"

            input_texts.append(self.tokens)
            input_labels.append(self.labels)

        return input_texts , input_labels

    def _merge_offset(self, full_text, tags, zero_offset, label_offset):
        i = j = 0
        while i < len(zero_offset) and j < len(label_offset):
            if zero_offset[i][0] < label_offset[j][0]:
                self._add_zero(full_text, zero_offset, i)
                i += 1
            else:
                self._add_label(full_text, label_offset, j, tags)
                j += 1

        while i < len(zero_offset):
            self._add_zero(full_text, zero_offset, i)
            i += 1

        while j < len(label_offset):
            self._add_label(full_text, label_offset, j, tags)
            j += 1

    def _add_zero(self, full_text, offset, index):
        start, *_, end = offset[index] if len(offset[index]) > 1 else (offset[index][0], offset[index][0] + 1)
        text = full_text[start:end]
        text_tokens = self.tokenizer.tokenize(text)

        self.tokens.extend(text_tokens)
        self.labels.extend(
            ["O"] * len(text_tokens)
        )

    def _add_label(self, full_text, offset, index, tags):
        start, *_, end = offset[index] if len(offset[index]) > 1 else (offset[index][0], offset[index][0] + 1)
        text = full_text[start:end]
        text_tokens = self.tokenizer.tokenize(text)

        self.tokens.extend(text_tokens)
        self.labels.extend(
            [f"B-{tags[index]['label']}"] + [f"I-{tags[index]['label']}"]*(len(text_tokens)-1)
        )

    @staticmethod
    def build_label2id(tokens: List[List[str]]):
        label2id = {}
        id_counter = 0
        for token in[token for sublist in tokens for token in sublist]:
            if token not in label2id:
                label2id[token] = id_counter
                id_counter += 1
        return label2id

    @staticmethod
    def find_continuous_ranges(data: List[int]):
        if not data:
            return []
        ranges = []
        start = data[0]
        prev = data[0]
        for number in data[1:]:
            if number != prev + 1:
                ranges.append(list(range(start, prev + 1)))
                start = number
            prev = number
        ranges.append(list(range(start, prev + 1)))
        return ranges


# Preprocessing
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")

dataset_folder = "/content/MACCROBAT2018"

Maccrobat_builder = Preprocessing_Maccrobat(dataset_folder, tokenizer)
input_texts, input_labels = Maccrobat_builder.process()

label2id = Preprocessing_Maccrobat.build_label2id(input_labels)
id2label = {v: k for k, v in label2id.items()}


# Split
from sklearn.model_selection import train_test_split

inputs_train, inputs_val, labels_train, labels_val = train_test_split(
    input_texts,
    input_labels,
    test_size = 0.2,
    random_state = 42
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/373 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

# 2. Dataloader

In [3]:
import torch
from torch.utils.data import Dataset

MAX_LEN = 512

class NER_Dataset(Dataset):
    def __init__(self, input_texts, input_labels, tokenizer, label2id, max_len = MAX_LEN):
        super().__init__()
        self.tokens = input_texts
        self.labels = input_labels
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.max_len = max_len

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        input_token = self.tokens[idx]
        label_token = [self.label2id[label] for label in self.labels[idx]]

        input_token = self.tokenizer.convert_tokens_to_ids(input_token)
        attention_mask = [1] * len(input_token)

        input_ids = self.pad_and_truncate(input_token, pad_id = self.tokenizer.pad_token_id)

        labels = self.pad_and_truncate(label_token, pad_id=0)
        attention_mask = self.pad_and_truncate(attention_mask, pad_id=0)

        return {
            "input_ids": torch.as_tensor(input_ids),
            "labels": torch.as_tensor(labels),
            "attention_mask": torch.as_tensor(attention_mask)
        }

    def pad_and_truncate(self, inputs: List[int], pad_id: int):
        if len(inputs) < self.max_len:
            padded_inputs = inputs + [pad_id] * (self.max_len - len(inputs))
        else:
            padded_inputs = inputs[:self.max_len]
        return padded_inputs

    def label2id(self, labels: List[str]):
        return [self.label2id[label] for label in labels]

train_set = NER_Dataset(inputs_train, labels_train, tokenizer, label2id)
val_set = NER_Dataset(inputs_val, labels_val, tokenizer, label2id)

# 3. Modeling

In [4]:
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained(
    "d4data/biomedical-ner-all",
    label2id = label2id,
    id2label = id2label,
    ignore_mismatched_sizes = True
)

config.json:   0%|          | 0.00/5.00k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/266M [00:00<?, ?B/s]

Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at d4data/biomedical-ner-all and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([84]) in the checkpoint and torch.Size([83]) in the model instantiated
- classifier.weight: found shape torch.Size([84, 768]) in the checkpoint and torch.Size([83, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# 4. Metric

In [5]:
!pip install evaluate

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting datasets>=2.0.0 (from evaluate)
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill (from evaluate)
  Downloading dill-0.3.9-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from evaluate)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from evaluate)
  Downloading multiprocess-0.70.17-py311-none-any.whl.metadata (7.2 kB)
Collecting dill (from evaluate)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting multiprocess (from evaluate)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec>=2021.05.0 (from fsspec[http]>=2021.05.0->evaluate)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [

In [6]:
import evaluate
import numpy as np

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    mask = labels != 0
    predictions = np.argmax(predictions, axis=-1)
    return accuracy.compute(predictions = predictions[mask], references=labels[mask])

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

# 5. Trainer

In [7]:
from transformers import TrainingArguments , Trainer

training_args = TrainingArguments(
    output_dir = "out_dir",
    learning_rate = 1e-4,
    per_device_train_batch_size = 16,
    per_device_eval_batch_size = 16,
    num_train_epochs = 20,
    eval_strategy = "epoch",
    save_strategy = "epoch",
    load_best_model_at_end = True,
    optim = "adamw_torch"
)

trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = train_set,
    eval_dataset = val_set,
    tokenizer = tokenizer,
    compute_metrics = compute_metrics,
)

trainer.train()

  trainer = Trainer(


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mconglap45[0m ([33mconglap45-cong-lap[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,1.711359,0.26625
2,No log,1.098727,0.552858
3,No log,0.881726,0.646433
4,No log,0.789557,0.692099
5,No log,0.7568,0.709936
6,No log,0.740555,0.716847
7,No log,0.729252,0.715446
8,No log,0.738914,0.723571
9,No log,0.749758,0.728708
10,No log,0.75586,0.735805


TrainOutput(global_step=200, training_loss=0.4138778305053711, metrics={'train_runtime': 523.9878, 'train_samples_per_second': 6.107, 'train_steps_per_second': 0.382, 'total_flos': 418702245888000.0, 'train_loss': 0.4138778305053711, 'epoch': 20.0})

# 6. Inference

In [9]:
test_sentence = """
A 48 year-old female presented with vaginal bleeding and abnormal
Pap smears. Upon diagnosis of invasive non-keratinizing SCC of the cervix, she
underwent a radical hysterectomy with salpingo-oophorectomy which demonstrated
positive spread to the pelvic lymph nodes and the parametrium. Pathological
examination revealed that the tumour also extensively involved the lower uterine
segment."""

# tokenization
input = torch.as_tensor([tokenizer.convert_tokens_to_ids(test_sentence.split())])
input = input.to("cuda")

# prediction
outputs = model(input)
_, preds = torch.max(outputs.logits, -1)
preds = preds[0].cpu().numpy()

# decode
for token, pred in zip(test_sentence.split(), preds):
    print(f"{token}\t{id2label[pred]}")

A	O
48	O
year-old	I-Lab_value
female	B-Sex
presented	O
with	O
vaginal	O
bleeding	B-Sign_symptom
and	O
abnormal	B-Sign_symptom
Pap	O
smears.	O
Upon	O
diagnosis	O
of	O
invasive	B-Detailed_description
non-keratinizing	O
SCC	O
of	O
the	O
cervix,	O
she	O
underwent	O
a	O
radical	O
hysterectomy	O
with	O
salpingo-oophorectomy	O
which	O
demonstrated	O
positive	B-Lab_value
spread	I-Lab_value
to	O
the	O
pelvic	O
lymph	O
nodes	O
and	O
the	O
parametrium.	O
Pathological	O
examination	O
revealed	O
that	O
the	O
tumour	O
also	O
extensively	O
involved	O
the	O
lower	O
uterine	O
segment.	O
