In [1]:
import os
from pathlib import Path

# Preparation

First of all, make sure to download the corresponding medical ontologies to build the term dictionaries.
 - Treatments: [OPS](https://www.bfarm.de/DE/Kodiersysteme/Services/Downloads/_node.html) 
 - Medications: [ATC](https://www.wido.de/publikationen-produkte/arzneimittel-klassifikation/)
 - Diagnosis: [ICD10GM](https://www.bfarm.de/DE/Kodiersysteme/Services/Downloads/_node.html)
 
In the config file for BRONCO, `../conf/bronco.yaml`, modify the paths so that they point to the extracted ontologies.

We can already use xMEN to prepare the term dictionaries. In your terminal, navigate to the xMEN root folder and run:
 - `xmen dict conf/bronco.yaml --code dicts/atc2017_de.py --output temp/ --key atc`
 - `xmen dict conf/bronco.yaml --code dicts/ops2017.py --output temp/ --key ops`
 - `xmen dict conf/bronco.yaml --code dicts/icd10gm2017.py --output temp/ --key icd10gm`
 
Now use these dictionaries to build the indexes. For this example, we will only use the SapBERT indexes and leave the N-Gram ones aside:
 - `xmen index conf/bronco.yaml --dict temp/atc.jsonl --output temp/atc_index --sapbert`
 - `xmen index conf/bronco.yaml --dict temp/ops.jsonl --output temp/ops_index --sapbert`
 - `xmen index conf/bronco.yaml --dict temp/icd10gm.jsonl --output temp/icd10gm_index --sapbert`
 
Finally, load the adapted config file and the BRONCO150 dataset using BigBIO:

In [2]:
from xmen.confhelper import load_config
import datasets

# Load the BRONCO configuration file that was adapted in the preparation step
config = load_config("../conf/bronco.yaml")

# paste here the path to the local data
path_to_data = r"../../BRONCO150"

# Load BRONCO through the BigBIO loader, using its "bronco_bigbio_kb" configuration
bronco = datasets.load_dataset(
    path="bigbio/bronco",
    name="bronco_bigbio_kb",
    data_dir=path_to_data,
)

bronco

Found cached dataset bronco (/dhc/home/ignacio.rodriguez/.cache/huggingface/datasets/bigbio___bronco/bronco_bigbio_kb-data_dir=..%2F..%2FBRONCO150/1.0.0/cab8fc4a62807688cb5b36df7a24eb7f364314862c4196f6ff2db3813f2fe68b)


  0%|          | 0/1 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 5
    })
})

We have to filter out the semantic classes that we are not aiming to predict.

In [3]:
# Choose here TREATMENT, MEDICATION or DIAGNOSIS
label = "DIAGNOSIS"

# Maps each semantic class to the terminology key used in the `xmen dict --key ...`
# commands and in the ../temp/ paths below.
label2dict = dict(
    TREATMENT="ops",
    MEDICATION="atc",
    DIAGNOSIS="icd10gm",
)

In [4]:
def filter_entities(bigbio_entities: list, valid_entities: list) -> list:
    """Return the entities whose ``'type'`` is one of ``valid_entities``.

    Args:
        bigbio_entities: Iterable of entity dicts; each one must have a
            ``'type'`` key.
        valid_entities: Entity types to keep.

    Returns:
        A new list with the matching entities, in their original order.
        The input is not modified.
    """
    return [ent for ent in bigbio_entities if ent['type'] in valid_entities]

# Replace each row's 'entities' with only those of the selected `label` type;
# all other columns are left as they are.
ds = bronco.map(lambda row: {'entities': filter_entities(row['entities'], [label])})

Loading cached processed dataset at /dhc/home/ignacio.rodriguez/.cache/huggingface/datasets/bigbio___bronco/bronco_bigbio_kb-data_dir=..%2F..%2FBRONCO150/1.0.0/cab8fc4a62807688cb5b36df7a24eb7f364314862c4196f6ff2db3813f2fe68b/cache-cce0dc987032834e.arrow


# Run Candidate Generator
We will use `SapBERTLinker`, which uses a Transformer model to retrieve candidates with dense embeddings.

We could have also used `TFIDFNGramLinker` (see `xmen/notebooks/BioASQ_DisTEMIST.ipynb` for an example).

In [5]:
from notebook_util import analyze
from xmen.linkers import TFIDFNGramLinker, SapBERTLinker, EnsembleLinker
from xmen.linkers.util import filter_and_apply_threshold
from datasets import DatasetDict
from xmen.evaluation import entity_linking_error_analysis, evaluate_at_k
from xmen.linkers.util import filter_and_apply_threshold

Your CPU supports instructions that this binary was not compiled to use: SSE3 SSE4.1 SSE4.2 AVX AVX2
For maximum performance, you can install NMSLIB from sources 
pip install --no-binary :all: nmslib


In [7]:
embedding_model_name = 'cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR'

# Clear singleton to free up memory
SapBERTLinker.clear()

# The index path points at the SapBERT index built with `xmen index ... --sapbert`
# for the selected label.
# k is presumably the number of candidates retrieved per mention — TODO confirm.
sapbert_linker = SapBERTLinker(
    embedding_model_name=embedding_model_name,
    index_base_path=f"../temp/{label2dict[label]}_index/index/sapbert",
    k=1000,
)

pred_sapbert = sapbert_linker.predict_batch(ds, batch_size=128)

# Save locally to avoid running it every time.
# NOTE: the next cell loads the predictions from this exact path, so this line
# has to be run (uncommented) at least once before that cell can work.
#pred_sapbert.save_to_disk(f"../temp/{label2dict[label]}_index/pred_sapbert")

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

In [6]:
# Recall for different numbers of candidates (k)
# Load the candidate predictions previously stored on disk
candidates = datasets.load_from_disk(f"../temp/{label2dict[label]}_index/pred_sapbert")

# Only the row at index 3 of the train split is evaluated here
eval_rows = [3]
gold_subset = ds['train'].select(eval_rows)
pred_subset = candidates['train'].select(eval_rows)
_ = evaluate_at_k(gold_subset, pred_subset)

Perf@1 0.20959595959595959
Perf@2 0.2916666666666667
Perf@4 0.3560606060606061
Perf@8 0.3939393939393939
Perf@16 0.4671717171717172
Perf@32 0.5202020202020202
Perf@64 0.5782828282828283


# Train Cross-encoder
We use a cross-encoder to embed each mention and its context together with all of its potential candidates. This way, we can learn the best ranking of the candidates from the training data.

BRONCO150 is organized into 5 folds for cross-validation. We split the data by fold before running the cross-encoder pre-processing.

In [7]:
# Build one single-row dataset per fold, keyed "k1" ... "k5", for both the
# candidate predictions and the filtered ground truth.
kandidates, ground_truth = DatasetDict(), DatasetDict()

for k in range(5):
    fold_key = f"k{k+1}"
    fold_row = [k]
    kandidates[fold_key] = candidates["train"].select(fold_row)
    ground_truth[fold_key] = ds["train"].select(fold_row)

kandidates

DatasetDict({
    k1: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 1
    })
    k2: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 1
    })
    k3: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 1
    })
    k4: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 1
    })
    k5: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 1
    })
})

In [9]:
from xmen.reranking.cross_encoder import CrossEncoderReranker, CrossEncoderTrainingArgs
from xmen.linkers.util import filter_and_apply_threshold

In [10]:
# Presumably the number of candidates per mention kept for re-ranking — TODO confirm
K_RERANKING = 64
# NOTE: this rebinds `candidates` (previously the predictions loaded from disk)
# to the filtered version of the per-fold `kandidates`.
# The trailing 0.0 is presumably the score threshold — TODO confirm against
# the signature of filter_and_apply_threshold.
candidates = filter_and_apply_threshold(kandidates, K_RERANKING, 0.0)

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

In [18]:
from xmen.kb import load_kb

# Load the knowledge base from the term dictionary of the selected label
kb_path = f"../temp/{label2dict[label]}.jsonl"
kb = load_kb(kb_path)

# Build the cross-encoder dataset from the candidates and the ground truth
cross_enc_ds = CrossEncoderReranker.prepare_data(
    candidates,
    ground_truth,
    kb,
    encode_sem_type=False,
)

Context length: 128


  0%|          | 0/847 [00:00<?, ?it/s]

  0%|          | 0/847 [00:00<?, ?it/s]

  0%|          | 0/847 [00:00<?, ?it/s]

  0%|          | 0/771 [00:00<?, ?it/s]

  0%|          | 0/771 [00:00<?, ?it/s]

  0%|          | 0/771 [00:00<?, ?it/s]

  0%|          | 0/839 [00:00<?, ?it/s]

  0%|          | 0/839 [00:00<?, ?it/s]

  0%|          | 0/839 [00:00<?, ?it/s]

  0%|          | 0/793 [00:00<?, ?it/s]

  0%|          | 0/793 [00:00<?, ?it/s]

  0%|          | 0/793 [00:00<?, ?it/s]

  0%|          | 0/830 [00:00<?, ?it/s]

  0%|          | 0/830 [00:00<?, ?it/s]

  0%|          | 0/830 [00:00<?, ?it/s]

In [19]:
# Save locally to avoid running it every time
#cross_enc_ds.save_to_disk(f"../temp/{label2dict[label]}_index/cross_encoded_dataset")

In [20]:
from xmen.data.indexed_dataset import IndexedDatasetDict
# Reload the prepared cross-encoder dataset from disk.
# NOTE: this path is only written by the commented-out `save_to_disk` call in the
# previous cell, so that call has to be run once before this cell can succeed.
cross_enc_ds = IndexedDatasetDict.load_from_disk(f"../temp/{label2dict[label]}_index/cross_encoded_dataset")
cross_enc_ds

{'k4': [793 items],
 'k3': [839 items],
 'k1': [847 items],
 'k2': [771 items],
 'k5': [830 items]}

In [23]:
# Pre-trained cross-encoder checkpoint used as the starting point for training
CROSS_ENC_MODEL = 'cross-encoder/msmarco-MiniLM-L6-en-de-v1'
NUM_EPOCHS = 10
# Model name and number of epochs are passed positionally; score regularization is enabled
train_args = CrossEncoderTrainingArgs(
    CROSS_ENC_MODEL, 
    NUM_EPOCHS,
    score_regularization=True,
)

rr = CrossEncoderReranker()
# Directory handed to `rr.fit` as `output_dir` in the next cell
output_dir = f'../temp/{label2dict[label]}_index/cross_encoder_training/'

In [None]:
# Depending on the number of epochs, this can take a few hours
# Fold k1 is passed as the first dataset and fold k2 as the second — presumably
# training and evaluation data respectively; TODO confirm against `fit`'s signature.
rr.fit(cross_enc_ds['k1'].dataset, cross_enc_ds['k2'].dataset, output_dir=output_dir, training_args=train_args)

model_name := cross-encoder/msmarco-MiniLM-L6-en-de-v1
num_train_epochs := 10
fp16 := True
label_smoothing := False
score_regularization := True
train_layers := None
softmax_loss := True


Downloading (…)lve/main/config.json:   0%|          | 0.00/840 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/428M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/541 [00:00<?, ?B/s]

Downloading (…)tencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/9.10M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/150 [00:00<?, ?B/s]

Using score regularization: True
Using label smoothing factor: False


Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/847 [00:00<?, ?it/s]

2023-05-29 20:13:05 - EntityLinkingEvaluator: Evaluating the model on eval dataset after epoch 0:
2023-05-29 20:13:13 - Accuracy: 0.20881971465629054
2023-05-29 20:13:13 - Accuracy @ 5: 0.36964980544747084
2023-05-29 20:13:13 - Accuracy @ 64: 0.5927367055771725
2023-05-29 20:13:13 - Baseline Accuracy: 0.185473411154345
2023-05-29 20:13:13 - Save model to ../temp/icd10gm_index/cross_encoder_training/


Iteration:   0%|          | 0/847 [00:00<?, ?it/s]

2023-05-29 20:13:52 - EntityLinkingEvaluator: Evaluating the model on eval dataset after epoch 1:
2023-05-29 20:13:59 - Accuracy: 0.2490272373540856
2023-05-29 20:13:59 - Accuracy @ 5: 0.42412451361867703
2023-05-29 20:13:59 - Accuracy @ 64: 0.5927367055771725
2023-05-29 20:13:59 - Baseline Accuracy: 0.185473411154345
2023-05-29 20:13:59 - Save model to ../temp/icd10gm_index/cross_encoder_training/


Iteration:   0%|          | 0/847 [00:00<?, ?it/s]

In [16]:
from xmen.reranking.cross_encoder import get_flat_candidate_ds
from xmen.reranking.ranking_util import get_candidates

# Manual, step-by-step flattening of candidates and ground truth with
# `get_candidates`, kept next to the commented-out library calls it stands in for.
# Inputs taken from earlier cells: `candidates`, `ds` and `kb`.

# Settings used by the commented-out helper calls below
# (only `expand_abbreviations` is actually used in this cell).
context_length = 128
expand_abbreviations = False
encode_sem_type = True
masking = False

# loop
# NOTE(review): `split = "train"` requires `candidates` to still contain a "train"
# split. After the thresholding cell it is presumably keyed by fold ("k1".."k5"),
# in which case this lookup fails — confirm which `candidates` is intended here.
split = "train"
cand = candidates[split]
gt = ds[split]

# ds, doc_index = create_cross_enc_dataset(cand, gt, kb, context_length, expand_abbreviations, encode_sem_type, masking)
candidate_ds = cand
# NOTE(review): this rebinds the global `ground_truth` (the per-fold DatasetDict
# built earlier) to a single split; re-run the fold-splitting cell before
# calling `prepare_data` again.
ground_truth = gt

flat_candidate_ds = candidate_ds.map(
    lambda e, i: get_candidates(e, i, expand_abbreviations),
    batched=True,
    remove_columns=candidate_ds.column_names,
    with_indices=True,
    load_from_cache_file=False,
)
#flat_candidate_ds, doc_index = get_flat_candidate_ds(cand, gt, expand_abbreviations=expand_abbreviations, kb=kb)
doc_index = flat_candidate_ds["doc_index"]
flat_candidate_ds = flat_candidate_ds.remove_columns(["doc_index"])

flat_ground_truth = ground_truth.map(
    lambda e, i: get_candidates(e, i, False),
    batched=True,
    # Drop the ground-truth dataset's own columns rather than the candidate
    # dataset's, so this does not depend on both having identical columns.
    remove_columns=ground_truth.column_names,
    with_indices=True,
    load_from_cache_file=False,
)
flat_ground_truth = flat_ground_truth.rename_column("candidates", "label")

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Map:   0%|          | 0/5 [00:00<?, ? examples/s]