In [None]:
import os
import logging
import sys
from typing import Union

import pandas as pd
from datasets import load_dataset
from pprint import pprint
from transformers import AutoTokenizer, pipeline, PreTrainedTokenizer
from tqdm import tqdm


# Root logger: every module's records at INFO and above are emitted
log = logging.getLogger()
log.setLevel(logging.INFO)

# Reuse the stdout handler if this cell already ran in the current kernel;
# adding a second one would print every log record twice.
existing_stdout_handlers = [
    h
    for h in log.handlers
    if isinstance(h, logging.StreamHandler) and getattr(h, "stream", None) is sys.stdout
]
if existing_stdout_handlers:
    handler = existing_stdout_handlers[0]
else:
    handler = logging.StreamHandler(sys.stdout)
    log.addHandler(handler)

handler.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)

# registers Series.progress_map / progress_apply used by the cells below
tqdm.pandas()
# some tokenizers require this
os.environ["TOKENIZERS_PARALLELISM"] = "false"

## Data loading

In [None]:
# Fixed seed and sample size so that every model is evaluated on the same rows
# after a kernel restart.
SEED = 42
SAMPLE_SIZE = 1000

# Missing values become empty strings so that the string concatenation below
# never produces NaN.
medmcqa = load_dataset("medmcqa", split="train").to_pandas().fillna("")
# Model input: the question followed by its explanation
medmcqa["text"] = medmcqa["question"] + "\n" + medmcqa["exp"]
medicine = medmcqa[medmcqa["subject_name"] == "Medicine"]
# min() keeps the cell working if fewer than SAMPLE_SIZE rows match the filter
data = medicine.sample(n=min(SAMPLE_SIZE, len(medicine)), random_state=SEED)

## Model comparison

Models to check:
- `jarvisx17/medicine-ner` - doesn't work
- `ukkendane/bert-medical-ner` - works quite well
- `samrawal/bert-large-uncased_med-ner` - too heavy, and the results are questionable
- `samrawal/bert-base-uncased_clinical-ner` - better than the previous one, but some entities still come out broken (possibly fixed by the "first" aggregation strategy?)
- `reginaboateng/clinical_bert_adapter_ner_pico_for_classification_task` - requires the `adapter-transformers` library, so it is skipped here; it is probably a decent model

In [None]:
def test_model(
    model_name: str, data: pd.DataFrame, aggregation_strategy=None
) -> pd.DataFrame:
    """Run an NER pipeline over ``data["text"]`` and keep the rows with entities.

    Args:
        model_name (str): model identifier passed to ``pipeline``.
        data (pd.DataFrame): frame with a ``"text"`` column. It is not modified.
        aggregation_strategy (optional): forwarded unchanged to ``pipeline``.
            Defaults to None.

    Returns:
        pd.DataFrame: independent copy of the rows whose pipeline result is
            non-empty, with the predictions stored in a ``"res"`` column.
    """
    pipe = pipeline(
        task="ner", model=model_name, aggregation_strategy=aggregation_strategy
    )
    predictions = data["text"].progress_map(pipe)
    # assign() builds a new frame: the caller's frame (often a slice such as
    # data[:100]) is never written to, so no SettingWithCopyWarning is raised
    scored = data.assign(res=predictions)
    # at least a single entity is predicted
    data_with_res = scored[scored["res"].map(bool)].copy()
    log.info("Data len: {0}".format(len(scored)))
    log.info("Number of records with entities: {0}".format(len(data_with_res)))
    return data_with_res


def print_some_results(data_with_res: pd.DataFrame, sample: Union[int, float] = 5):
    """Print a random selection of rows together with their predicted entities.

    Args:
        data_with_res (pd.DataFrame): frame with ``"text"`` and ``"res"`` columns.
        sample (Union[int, float], optional): an int (or integer-valued float)
            is the maximum number of rows to print, capped at the frame length.
            A non-integer float is passed to pandas as the fraction of rows to
            print. Defaults to 5.
    """
    if isinstance(sample, float) and not sample.is_integer():
        # DataFrame.sample(n=...) rejects fractional values, so use frac instead
        chosen = data_with_res.sample(frac=sample)
    else:
        chosen = data_with_res.sample(n=min(int(sample), len(data_with_res)))

    for i, row in chosen.iterrows():
        print("Id:", i, "\n")
        # "\n" leaves a blank line after the text (the old "\b" emitted a backspace)
        print("Input text:", row["text"], "\n")
        print("Entities:")
        pprint(row["res"])
        print("-" * 100)

class EntityDecoder:
    def __init__(
        self,
        tokenizer: "PreTrainedTokenizer",
        tag_sep: str = "-",
        entity_key: str = "entity",
        max_gap: int = 1,
    ):
        """Replacement for hf ner pipeline if it doesn't work

        Use to assemble human-readable entities from classified tokens.

        Args:
            tokenizer (transformers.PreTrainedTokenizer): tokenizer whose
                ``convert_tokens_to_string`` is used to join token texts
            tag_sep (str, optional): separator in class names for bio tags
                (in B-PERSON "-" is a separator). Defaults to "-".
            entity_key (str, optional): key of each token dict that holds the
                bio tag. Defaults to "entity".
            max_gap (int, optional): largest number of characters allowed
                between the end of one token and the start of the next for
                them to be merged into one entity. Defaults to 1.
        """
        self.tokenizer = tokenizer
        self.tag_sep = tag_sep
        self.entity_key = entity_key
        self.max_gap = max_gap

    def __call__(self, ents: list) -> list:
        """Transform list of token entities from NER model to words

        Args:
            ents (list): list of dicts with entities; each dict is read for
                its ``entity_key``, "word", "start" and "end" items

        Returns:
            list: list of dicts with "tag", "text", "start" and "end" keys,
                one per assembled entity
        """
        grouped_ents = self._group(ents)
        return self._merge(grouped_ents)

    def _group(self, ents):
        """Split the token list into runs that each form one entity.

        A run starts at a B tag and is extended by following I tags of the
        same class that begin at most ``max_gap`` characters after the end of
        the previous token in the run. Every other token is skipped.

        Args:
            ents (list): list of token dicts

        Returns:
            list: list of runs, each a non-empty list of token dicts
        """
        if not ents:
            return []

        begin_prefix = f"B{self.tag_sep}"
        res_ents = []
        current_entity = None
        current_entity_class = None

        for ent in ents:
            label = ent[self.entity_key]
            if label.startswith(begin_prefix):
                # a new B tag closes the entity collected so far
                if current_entity is not None:
                    res_ents.append(current_entity)
                current_entity = [ent]
                current_entity_class = label.split(self.tag_sep, maxsplit=1)[1]
            elif (
                current_entity is not None
                and label == f"I{self.tag_sep}{current_entity_class}"
                # gap 0: token starts right where the previous one ended
                # (e.g. a word piece); gap 1: separated by a single character
                and 0 <= ent["start"] - current_entity[-1]["end"] <= self.max_gap
            ):
                current_entity.append(ent)
            # anything else is skipped: I tags before the first B tag, I tags
            # of another class and I tags that are too far from the entity

        if current_entity is not None:
            res_ents.append(current_entity)
        return res_ents

    def _merge(self, grouped_ents):
        """Collapse every run of tokens into a single entity dict.

        Args:
            grouped_ents (list): output of ``_group``

        Returns:
            list: dicts with the class name ("tag"), the joined token text
                ("text") and the span of the run ("start", "end")
        """
        res = []
        for ent in grouped_ents:
            temp_ent = {
                # class name is the tag of the first token without its "B" prefix
                "tag": ent[0][self.entity_key].split(self.tag_sep, maxsplit=1)[1],
                "text": self.tokenizer.convert_tokens_to_string(
                    [x["word"] for x in ent]
                ),
                "start": ent[0]["start"],
                "end": ent[-1]["end"],
            }
            res.append(temp_ent)
        return res

In [None]:
# No aggregation strategy is passed here, so test_model uses its default (None)
model_name = "jarvisx17/medicine-ner"
test_jarvisx17 = test_model(model_name=model_name, data=data)
print_some_results(data_with_res=test_jarvisx17)

In [None]:
model_name = "ukkendane/bert-medical-ner"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# tags are read from "entity_group" and split on "_" for this model
ent_decoder_underscore = EntityDecoder(tokenizer=tokenizer, entity_key="entity_group", tag_sep="_")

# .copy() hands test_model an independent frame: writing the "res" column into
# a bare slice of `data` raises SettingWithCopyWarning
test_ukkendane = test_model(model_name, data.iloc[:100].copy(), aggregation_strategy="first")

# assign() returns a new frame instead of setting a column on a filtered frame
test_ukkendane = test_ukkendane.assign(
    res=test_ukkendane["res"].map(ent_decoder_underscore)
)

# print_some_results(test_ukkendane)

In [None]:
model_name = "samrawal/bert-base-uncased_clinical-ner"

# Decoder for "-"-separated tags; it is only constructed in this cell, the
# results below are printed exactly as the pipeline returned them.
tokenizer = AutoTokenizer.from_pretrained(model_name)
ent_decoder_dash = EntityDecoder(tag_sep="-", tokenizer=tokenizer)

test_samrawal_base = test_model(
    model_name=model_name,
    data=data,
    aggregation_strategy="first",
)
print_some_results(data_with_res=test_samrawal_base)

In [None]:
# Same evaluation as for the base model, run on the large checkpoint
model_name = "samrawal/bert-large-uncased_med-ner"
test_samrawal = test_model(
    model_name=model_name,
    data=data,
    aggregation_strategy="first",
)
print_some_results(data_with_res=test_samrawal)