# Preparation

Make sure to download the corresponding medical ontologies to build the term dictionaries. For each, look for the 2017 version.
 - Treatments: [OPS](https://www.bfarm.de/DE/Kodiersysteme/Services/Downloads/_node.html).
 - Medications: [ATC](https://www.wido.de/publikationen-produkte/arzneimittel-klassifikation/).
 - Diagnoses: [ICD10GM](https://www.bfarm.de/DE/Kodiersysteme/Services/Downloads/_node.html).
 
In the BRONCO config file `../conf/bronco.yaml`, modify the paths so that they point to the extracted folders. We assume they are located in `xmen/local_files`. If they are located elsewhere, change the paths in the config file and adjust the terminal commands below accordingly.

We can now use xMEN to prepare the term dictionaries. <u>This step only has to be performed once.</u>

In your terminal, navigate to the xMEN root folder and run:
 - `xmen dict conf/bronco.yaml --code dicts/atc2017_de.py --key atc`
 - `xmen dict conf/bronco.yaml --code dicts/ops2017.py --key ops`
 - `xmen dict conf/bronco.yaml --code dicts/icd10gm2017.py --key icd10gm`
 
Now use these dictionaries to build the indexes:
 - `xmen index conf/bronco.yaml --dict ~/.cache/xmen/atc/atc.jsonl --output ~/.cache/xmen/atc/ --all`
 - `xmen index conf/bronco.yaml --dict ~/.cache/xmen/ops/ops.jsonl --output ~/.cache/xmen/ops/ --all`
 - `xmen index conf/bronco.yaml --dict ~/.cache/xmen/icd10gm/icd10gm.jsonl --output ~/.cache/xmen/icd10gm/ --all`
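
Before moving on, it can help to confirm that the dictionary files ended up where the commands above expect them. The following is a minimal sketch, assuming the cache layout `~/.cache/xmen/<key>/<key>.jsonl` used in those commands; the helper `missing_dicts` is hypothetical and not part of xMEN.

```python
from pathlib import Path


def missing_dicts(keys, cache_root=None):
    """Return the dictionary keys whose <key>/<key>.jsonl file is absent."""
    # Assumption: dictionaries live under ~/.cache/xmen, as in the commands above.
    root = Path(cache_root) if cache_root is not None else Path.home() / ".cache" / "xmen"
    return [key for key in keys if not (root / key / f"{key}.jsonl").is_file()]


if __name__ == "__main__":
    missing = missing_dicts(["atc", "ops", "icd10gm"])
    print("Missing dictionaries:", missing or "none")
```

If any key is reported as missing, rerun the corresponding `xmen dict` command before building the index.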
 
Now we can load the BRONCO150 dataset using BigBIO:

In [1]:
import datasets

path_to_data = r"../local_files/BRONCO150" # paste here the path to the local data

bronco = datasets.load_dataset(path = "bigbio/bronco", 
                               name = "bronco_bigbio_kb", 
                               data_dir=path_to_data)

bronco

Found cached dataset bronco (/home/Florian.Borchert/.cache/huggingface/datasets/bigbio___bronco/bronco_bigbio_kb-fd41eed48d3255b6/1.0.0/cab8fc4a62807688cb5b36df7a24eb7f364314862c4196f6ff2db3813f2fe68b)

DatasetDict({
    train: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 5
    })
})

Finally, we have to choose the semantic class we want to work on and restructure the dataset into 5 folds for cross-validation, as originally intended.

In [2]:
label = "MEDICATION" # Choose here TREATMENT, MEDICATION or DIAGNOSIS

In [3]:
label2dict = {
    "TREATMENT": "ops",
    "MEDICATION": "atc",
    "DIAGNOSIS": "icd10gm"
}

def filter_entities(bigbio_entities, valid_entities):
    """Keep only the entities whose type is listed in valid_entities."""
    return [ent for ent in bigbio_entities if ent['type'] in valid_entities]

# Restrict every document to entities of the selected semantic class
ds = bronco.map(lambda row: {'entities': filter_entities(row['entities'], [label])})

Loading cached processed dataset at /home/Florian.Borchert/.cache/huggingface/datasets/bigbio___bronco/bronco_bigbio_kb-fd41eed48d3255b6/1.0.0/cab8fc4a62807688cb5b36df7a24eb7f364314862c4196f6ff2db3813f2fe68b/cache-44459d7791d9d192.arrow


In [4]:
from datasets import DatasetDict

# The train split has 5 rows; put each row into its own fold (k1 ... k5)
ground_truth = DatasetDict()
for k in range(5):
    ground_truth[f"k{k+1}"] = ds["train"].select([k])
    
ds = ground_truth
ds

DatasetDict({
    k1: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 1
    })
    k2: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 1
    })
    k3: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 1
    })
    k4: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 1
    })
    k5: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 1
    })
})

# Run Candidate Generator
We will use the default `EnsembleLinker`:

In [5]:
from xmen.linkers import default_ensemble
from xmen.evaluation import evaluate_at_k
import os
from pathlib import Path

index_base_path = Path(f"{os.path.expanduser('~')}/.cache/xmen/{label2dict[label]}/index")
linker = default_ensemble(index_base_path)

In [6]:
candidates = linker.predict_batch(ds)

# Recall for different numbers of candidates (k)
_ = evaluate_at_k(ds['k5'], candidates['k5'])

Perf@1 0.44375
Perf@2 0.525
Perf@4 0.64375
Perf@8 0.7375
Perf@16 0.76875
Perf@32 0.809375
Perf@64 0.825
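
The `Perf@k` values above are reported by `evaluate_at_k`. As a rough illustration of what recall at k measures, here is a minimal sketch in plain Python. It is not xMEN's implementation: the function `recall_at_k` and the toy data are hypothetical, and the sketch assumes a single gold code per mention and a ranked list of candidate codes.

```python
def recall_at_k(gold_codes, ranked_candidates, k):
    """Fraction of mentions whose gold code appears among the top-k candidates."""
    if not gold_codes:
        return 0.0
    hits = sum(
        1 for gold, cands in zip(gold_codes, ranked_candidates) if gold in cands[:k]
    )
    return hits / len(gold_codes)


# Toy example with made-up codes: three mentions and their ranked candidates.
gold = ["A", "B", "C"]
ranked = [["A", "X"], ["X", "B"], ["X", "Y"]]
for k in (1, 2):
    print(f"Recall@{k}: {recall_at_k(gold, ranked, k):.3f}")
```

Under this definition recall can only stay the same or grow as k increases, which matches the pattern in the output above.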


# Train Cross-encoder
We use a cross-encoder to embed each mention with its context together with all potential candidates. This way, we can learn the best ranking of candidates from the training data.
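
To make the idea concrete, the sketch below builds the kind of text pairs a cross-encoder would score: one pair per candidate, each combining the mention in its context with a candidate name. This is an illustrative assumption, not the data format produced by `CrossEncoderReranker.prepare_data`; the function `build_pairs`, the `[START]`/`[END]` markers, the character-based window, and the example sentence are all hypothetical.

```python
def build_pairs(text, start, end, candidate_names, window=128):
    """Pair a marked mention (with surrounding context) with each candidate name."""
    # Assumption: the context window is measured in characters on each side.
    left = text[max(0, start - window):start]
    right = text[end:end + window]
    # Hypothetical markers that highlight the mention inside its context.
    marked = f"{left}[START] {text[start:end]} [END]{right}"
    return [(marked, name) for name in candidate_names]


# Made-up sentence and candidate names, for illustration only.
sentence = "Patient erhielt Ibuprofen 400 mg."
pairs = build_pairs(sentence, 16, 25, ["Ibuprofen", "Ibuprofen, Kombinationen"])
for context, candidate in pairs:
    print(context, "|", candidate)
```

A model then assigns one score to each pair, and the candidates are sorted by that score.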

In [7]:
from xmen.reranking.cross_encoder import CrossEncoderReranker, CrossEncoderTrainingArgs
from xmen.linkers.util import filter_and_apply_threshold
from xmen.kb import load_kb
from xmen.data.indexed_dataset import IndexedDatasetDict, IndexedDataset

In [8]:
K_RERANKING = 64
candidates = filter_and_apply_threshold(candidates, K_RERANKING, 0.0)
kb = load_kb(index_base_path.parent / f"{label2dict[label]}.jsonl")

cross_enc_ds = CrossEncoderReranker.prepare_data(candidates, ds, kb)
cross_enc_ds

Context length: 128


{'k1': [325 items],
 'k2': [333 items],
 'k3': [338 items],
 'k4': [314 items],
 'k5': [320 items]}
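
The call `filter_and_apply_threshold(candidates, K_RERANKING, 0.0)` in the cell above limits the candidate lists before reranking. The sketch below illustrates the general idea for a single mention. It is an assumption about the behavior, not xMEN's code: `top_k_above_threshold` is a hypothetical helper, and whether the real function treats the threshold as inclusive is not confirmed here.

```python
def top_k_above_threshold(scored_candidates, k, threshold):
    """Keep at most k (code, score) pairs with score >= threshold, best first."""
    kept = [(code, score) for code, score in scored_candidates if score >= threshold]
    kept.sort(key=lambda pair: pair[1], reverse=True)
    return kept[:k]


# Made-up scores for one mention.
scored = [("A", 0.2), ("B", 0.9), ("C", 0.5), ("D", 0.05)]
print(top_k_above_threshold(scored, k=2, threshold=0.1))
```

With `K_RERANKING = 64` and a threshold of `0.0`, this kind of filter mainly caps the list length at 64 candidates per mention.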

Now we set the training arguments, define the training and validation splits, and fit the model. Depending on the number of epochs, training can take several hours.

In [9]:
# Choose train and evaluation folds
train_folds = ["k1", "k2", "k3"]
val_fold = "k4"
test_fold = "k5"

train = sum([cross_enc_ds[k].dataset for k in train_folds],[])
val = cross_enc_ds[val_fold].dataset
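
The cell above fixes a single assignment of folds. For a full 5-fold cross-validation, each fold would take the test role once. A minimal sketch of such a rotation follows; the helper `rotate_folds` and the choice of the preceding fold as validation fold are assumptions for illustration, not part of the original setup.

```python
def rotate_folds(fold_names):
    """Yield (train, val, test) so that every fold is the test fold exactly once."""
    for i, test in enumerate(fold_names):
        val = fold_names[i - 1]  # assumption: validate on the preceding fold
        train = [f for f in fold_names if f not in (val, test)]
        yield train, val, test


folds = ["k1", "k2", "k3", "k4", "k5"]
for train, val, test in rotate_folds(folds):
    print(f"train={train} val={val} test={test}")
```

The last rotation reproduces the split used here: training on `k1` to `k3`, validation on `k4`, and testing on `k5`.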

In [10]:
#microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext

In [11]:
train_args = CrossEncoderTrainingArgs(num_train_epochs = 5)

rr = CrossEncoderReranker()
output_dir = f'../outputs/{label2dict[label]}_index/cross_encoder_training/'

In [12]:
rr.fit(
    train_dataset = train,
    val_dataset = val,
    output_dir= output_dir,
    training_args = train_args,
    show_progress_bar = False
)

model_name := bert-base-multilingual-cased
num_train_epochs := 5
fp16 := True
label_smoothing := False
score_regularization := 1.0
train_layers := None
softmax_loss := True
random_seed := 42


Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model ch

2023-06-29 12:55:36 - EntityLinkingEvaluator: Evaluating the model on eval dataset after epoch 0:
2023-06-29 12:58:31 - Accuracy: 0.7006369426751592
2023-06-29 12:58:31 - Accuracy @ 5: 0.7802547770700637
2023-06-29 12:58:31 - Accuracy @ 64: 0.8152866242038217
2023-06-29 12:58:31 - Baseline Accuracy: 0.4394904458598726
2023-06-29 12:58:31 - Save model to ../outputs/atc_index/cross_encoder_training/
2023-06-29 13:14:06 - EntityLinkingEvaluator: Evaluating the model on eval dataset after epoch 1:
2023-06-29 13:17:01 - Accuracy: 0.7165605095541401
2023-06-29 13:17:01 - Accuracy @ 5: 0.7834394904458599
2023-06-29 13:17:01 - Accuracy @ 64: 0.8152866242038217
2023-06-29 13:17:01 - Baseline Accuracy: 0.4394904458598726
2023-06-29 13:17:01 - Save model to ../outputs/atc_index/cross_encoder_training/
2023-06-29 13:32:38 - EntityLinkingEvaluator: Evaluating the model on eval dataset after epoch 2:
2023-06-29 13:35:33 - Accuracy: 0.7292993630573248
2023-06-29 13:35:33 - Accuracy @ 5: 0.78025477707

# Evaluate Cross-encoder
Now we can take our trained model and evaluate it on the validation and test folds, which were held out from the training data.

In [13]:
rr = CrossEncoderReranker.load(output_dir, device=0)

2023-06-29 14:12:39 - Use pytorch device: cuda


In [14]:
cross_enc_pred_val = rr.rerank_batch(candidates[val_fold], cross_enc_ds[val_fold])
_ = evaluate_at_k(ds[val_fold], cross_enc_pred_val)

Perf@1 0.7292993630573248
Perf@2 0.7707006369426752
Perf@4 0.7834394904458599
Perf@8 0.8121019108280255
Perf@16 0.8152866242038217
Perf@32 0.8152866242038217
Perf@64 0.8152866242038217


In [15]:
cross_enc_pred_test = rr.rerank_batch(candidates[test_fold], cross_enc_ds[test_fold])
_ = evaluate_at_k(ds[test_fold], cross_enc_pred_test)

Perf@1 0.696875
Perf@2 0.753125
Perf@4 0.7625
Perf@8 0.80625
Perf@16 0.8125
Perf@32 0.821875
Perf@64 0.825
