# Preparation

Make sure to download the corresponding medical ontologies to build the term dictionaries. For each, look for the 2017 version.
 - Treatments: [OPS](https://www.bfarm.de/DE/Kodiersysteme/Services/Downloads/_node.html). The required file is `\p1smt2017\Klassifikationsdateien\ops2017syst_kodes.txt` 
 - Medications: [ATC](https://www.wido.de/publikationen-produkte/arzneimittel-klassifikation/). The required file is `ATC GKV AI 2017`
 - Diagnosis: [ICD10GM](https://www.bfarm.de/DE/Kodiersysteme/Services/Downloads/_node.html). The required file is `\x1gmt2017\Klassifikationsdateien\icd10gm2017syst_kodes.txt`
 
In the config file for BRONCO `../conf/bronco.yaml`, modify the paths so they point the extracted ontologies. We assume they are located in `xmen/temp`. Otherwise, change the path here and correct accordingly the terminal commands below:

In [1]:
base_path = "../temp"

We can already use xMEN to prepare the term dictionaries. <u>This step only has to be performed the first time. </u> 

In your terminal, navigate to the xMEN root folder and run:
 - `xmen dict conf/bronco.yaml --code dicts/atc2017_de.py --output temp/ --key atc`
 - `xmen dict conf/bronco.yaml --code dicts/ops2017.py --output temp/ --key ops`
 - `xmen dict conf/bronco.yaml --code dicts/icd10gm2017.py --output temp/ --key icd10gm`
 
Now use such dictionaries to build the indexes. For this example, we will use only SapBERT indexes and leave aside N-Gram:
 - `xmen index conf/bronco.yaml --dict temp/atc.jsonl --output temp/atc_index --sapbert`
 - `xmen index conf/bronco.yaml --dict temp/ops.jsonl --output temp/ops_index --sapbert`
 - `xmen index conf/bronco.yaml --dict temp/icd10gm.jsonl --output temp/icd10gm_index --sapbert`
 
Now we can load the BRONCO150 dataset using BigBIO:

In [2]:
import datasets

path_to_data = r"../../BRONCO150" # paste here the path to the local data

bronco = datasets.load_dataset(path = "bigbio/bronco", 
                               name = "bronco_bigbio_kb", 
                               data_dir=path_to_data)

bronco

Found cached dataset bronco (/dhc/home/ignacio.rodriguez/.cache/huggingface/datasets/bigbio___bronco/bronco_bigbio_kb-data_dir=..%2F..%2FBRONCO150/1.0.0/cab8fc4a62807688cb5b36df7a24eb7f364314862c4196f6ff2db3813f2fe68b)


  0%|          | 0/1 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 5
    })
})

Finally, we have to choose the semantic class we want to work on and reestructure the dataset in 5 folds for cross-validation, as originally intended.

In [3]:
label = "DIAGNOSIS" # Choose here TREATMENT, MEDICATION or DIAGNOSIS

In [4]:
label2dict = {
    "TREATMENT": "ops",
    "MEDICATION": "atc",
    "DIAGNOSIS": "icd10gm"
}

def filter_entities(bigbio_entities, valid_entities):
    filtered_entities = []
    for ent in bigbio_entities:
        if ent['type'] in valid_entities:
            filtered_entities.append(ent)
    return filtered_entities

ds = bronco.map(lambda row: {'entities': filter_entities(row['entities'], [label])})

Loading cached processed dataset at /dhc/home/ignacio.rodriguez/.cache/huggingface/datasets/bigbio___bronco/bronco_bigbio_kb-data_dir=..%2F..%2FBRONCO150/1.0.0/cab8fc4a62807688cb5b36df7a24eb7f364314862c4196f6ff2db3813f2fe68b/cache-a26fddd6429a662a.arrow


In [5]:
from datasets import DatasetDict

ground_truth = DatasetDict()
for k in range(5):
    ground_truth[f"k{k+1}"] = ds["train"].select([k])
    
ds = ground_truth
ds

DatasetDict({
    k1: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 1
    })
    k2: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 1
    })
    k3: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 1
    })
    k4: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 1
    })
    k5: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 1
    })
})

# Run Candidate Generator
We will use `SapBERTLinker`, which uses a Transformer model to retrieve candidates with dense embeddings.

In [6]:
from xmen.linkers import SapBERTLinker
from xmen.evaluation import evaluate_at_k

Your CPU supports instructions that this binary was not compiled to use: SSE3 SSE4.1 SSE4.2 AVX AVX2
For maximum performance, you can install NMSLIB from sources 
pip install --no-binary :all: nmslib


<u>This cell can be run just once. For subsequent runs you can load the result with the cell below this one.</u>

In [13]:
embedding_model_name = 'cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR'

SapBERTLinker.clear()
sapbert_linker = SapBERTLinker(
    embedding_model_name = embedding_model_name,
    index_base_path = f"{base_path}/{label2dict[label]}_index/index/sapbert",
    k = 1000
)

pred_sapbert = sapbert_linker.predict_batch(ds, batch_size=128)

# Save locally to avoid running it every time
pred_sapbert.save_to_disk(f"{base_path}/{label2dict[label]}_index/pred_sapbert")

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

In [7]:
candidates = datasets.load_from_disk(f"{base_path}/{label2dict[label]}_index/pred_sapbert")

# Recall for different numbers of candidates (k)
_ = evaluate_at_k(ds['k5'], candidates['k5'])

Perf@1 0.22677925211097708
Perf@2 0.30639324487334135
Perf@4 0.353437876960193
Perf@8 0.4209891435464415
Perf@16 0.4969843184559711
Perf@32 0.5379975874547648
Perf@64 0.6079613992762364


# Train Cross-encoder
We use a cross-encoder to embed the mention with their context together with all potential candidates. This way, we can learn the best ranking of candidates from the training data.

In [8]:
from xmen.reranking.cross_encoder import CrossEncoderReranker, CrossEncoderTrainingArgs
from xmen.linkers.util import filter_and_apply_threshold
from xmen.kb import load_kb
from xmen.data.indexed_dataset import IndexedDatasetDict, IndexedDataset

<u>This cell can be run just once. For subsequent runs you can load the result with the cell below this one.</u>

In [26]:
K_RERANKING = 64
candidates = filter_and_apply_threshold(candidates, K_RERANKING, 0.0)
kb = load_kb(f"{base_path}/{label2dict[label]}.jsonl")

cross_enc_ds = CrossEncoderReranker.prepare_data(candidates, ds, kb)

# Save locally to avoid running it every time
cross_enc_ds.save_to_disk(f"{base_path}/{label2dict[label]}_index/cross_encoded_dataset")
cross_enc_ds

Context length: 128


  0%|          | 0/847 [00:00<?, ?it/s]

  0%|          | 0/847 [00:00<?, ?it/s]

  0%|          | 0/847 [00:00<?, ?it/s]

  0%|          | 0/771 [00:00<?, ?it/s]

  0%|          | 0/771 [00:00<?, ?it/s]

  0%|          | 0/771 [00:00<?, ?it/s]

  0%|          | 0/839 [00:00<?, ?it/s]

  0%|          | 0/839 [00:00<?, ?it/s]

  0%|          | 0/839 [00:00<?, ?it/s]

  0%|          | 0/793 [00:00<?, ?it/s]

  0%|          | 0/793 [00:00<?, ?it/s]

  0%|          | 0/793 [00:00<?, ?it/s]

  0%|          | 0/830 [00:00<?, ?it/s]

  0%|          | 0/830 [00:00<?, ?it/s]

  0%|          | 0/830 [00:00<?, ?it/s]

{'k1': [847 items],
 'k2': [771 items],
 'k3': [839 items],
 'k4': [793 items],
 'k5': [830 items]}

In [9]:
cross_enc_ds = IndexedDatasetDict.load_from_disk(f"{base_path}/{label2dict[label]}_index/cross_encoded_dataset")
cross_enc_ds

{'k4': [793 items],
 'k3': [839 items],
 'k1': [847 items],
 'k2': [771 items],
 'k5': [830 items]}

Now we set the training arguments, train-eval splits and fit the model. Depending on the number of epochs, the training can take several hours.

In [10]:
CROSS_ENC_MODEL = "SCAI-BIO/bio-gottbert-base"#'cross-encoder/msmarco-MiniLM-L6-en-de-v1'
NUM_EPOCHS = 5
train_args = CrossEncoderTrainingArgs(
    CROSS_ENC_MODEL, 
    NUM_EPOCHS,
    score_regularization=True,
)

rr = CrossEncoderReranker()
output_dir = f'{base_path}/{label2dict[label]}_index/cross_encoder_training/'

In [27]:
# Choose train and evaluation folds
train_folds = [
    "k1",
    "k2",
    "k3",
    "k4",
    #"k5",
]

val_folds = [
    #"k1",
    #"k2",
    #"k3",
    "k4",
    #"k5",
] 

train, val = [], []
train = sum([train + cross_enc_ds[k].dataset for k in train_folds],[])
val = sum([val + cross_enc_ds[k].dataset for k in val_folds],[])

model_name := SCAI-BIO/bio-gottbert-base
num_train_epochs := 5
fp16 := True
label_smoothing := False
score_regularization := True
train_layers := None
softmax_loss := True


Some weights of the model checkpoint at SCAI-BIO/bio-gottbert-base were not used when initializing RobertaForSequenceClassification: ['lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.decoder.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at SCAI-BIO/bio-gottbert-base and are newly initialized: ['classifier.out_proj.bias', 'classifier.dense

Using score regularization: True
Using label smoothing factor: False


Epoch:   0%|          | 0/5 [00:00<?, ?it/s]

Iteration:   0%|          | 0/3250 [00:00<?, ?it/s]

2023-05-30 18:01:54 - EntityLinkingEvaluator: Evaluating the model on eval dataset after epoch 0:
2023-05-30 18:02:32 - Accuracy: 0.39722572509457754
2023-05-30 18:02:32 - Accuracy @ 5: 0.5145018915510718
2023-05-30 18:02:32 - Accuracy @ 64: 0.5775535939470365
2023-05-30 18:02:32 - Baseline Accuracy: 0.20933165195460277
2023-05-30 18:02:32 - Save model to ../temp/icd10gm_index/cross_encoder_training/


Iteration:   0%|          | 0/3250 [00:00<?, ?it/s]

2023-05-30 18:07:04 - EntityLinkingEvaluator: Evaluating the model on eval dataset after epoch 1:
2023-05-30 18:07:42 - Accuracy: 0.4489281210592686
2023-05-30 18:07:42 - Accuracy @ 5: 0.5397225725094578
2023-05-30 18:07:42 - Accuracy @ 64: 0.5775535939470365
2023-05-30 18:07:42 - Baseline Accuracy: 0.20933165195460277
2023-05-30 18:07:42 - Save model to ../temp/icd10gm_index/cross_encoder_training/


Iteration:   0%|          | 0/3250 [00:00<?, ?it/s]

2023-05-30 18:12:14 - EntityLinkingEvaluator: Evaluating the model on eval dataset after epoch 2:
2023-05-30 18:12:52 - Accuracy: 0.47414880201765447
2023-05-30 18:12:52 - Accuracy @ 5: 0.544766708701135
2023-05-30 18:12:52 - Accuracy @ 64: 0.5775535939470365
2023-05-30 18:12:52 - Baseline Accuracy: 0.20933165195460277
2023-05-30 18:12:52 - Save model to ../temp/icd10gm_index/cross_encoder_training/


Iteration:   0%|          | 0/3250 [00:00<?, ?it/s]

2023-05-30 18:17:23 - EntityLinkingEvaluator: Evaluating the model on eval dataset after epoch 3:
2023-05-30 18:18:02 - Accuracy: 0.5006305170239597
2023-05-30 18:18:02 - Accuracy @ 5: 0.5573770491803278
2023-05-30 18:18:02 - Accuracy @ 64: 0.5775535939470365
2023-05-30 18:18:02 - Baseline Accuracy: 0.20933165195460277
2023-05-30 18:18:02 - Save model to ../temp/icd10gm_index/cross_encoder_training/


Iteration:   0%|          | 0/3250 [00:00<?, ?it/s]

2023-05-30 18:22:33 - EntityLinkingEvaluator: Evaluating the model on eval dataset after epoch 4:
2023-05-30 18:23:11 - Accuracy: 0.5094577553593947
2023-05-30 18:23:11 - Accuracy @ 5: 0.5561160151324086
2023-05-30 18:23:11 - Accuracy @ 64: 0.5775535939470365
2023-05-30 18:23:11 - Baseline Accuracy: 0.20933165195460277
2023-05-30 18:23:11 - Save model to ../temp/icd10gm_index/cross_encoder_training/


In [None]:
rr.fit(train, val, output_dir=output_dir, training_args=train_args)

# Evaluate Cross-encoder
Now we can take our trained model and test it on data outside of training. Using the `xmen.evaluation` module, we gain insights into different error types.

In [29]:
rr = CrossEncoderReranker.load(output_dir, device=0)

2023-05-30 18:34:17 - Use pytorch device: cuda


In [30]:
test_fold = "k5"

cross_enc_pred_valid = rr.rerank_batch(candidates[test_fold], cross_enc_ds[test_fold])
_ = evaluate_at_k(candidates[test_fold], cross_enc_pred_valid)

Batches:   0%|          | 0/830 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Map:   0%|          | 0/1 [00:00<?, ? examples/s]

AssertionError: (1000, 64)

In [None]:
from notebook_util import analyze
from xmen.evaluation import entity_linking_error_analysis

In [None]:
cg_kb = load_kb(f"{base_path}/{label2dict[label]}.jsonl")
error_cand = entity_linking_error_analysis(ds[test_fold], candidates[test_fold])
edf_cand, _ = analyze(error_cand, cg_kb, 'eval', None)

In [None]:
# False negative (not part of top 64)
fns_cand = edf_cand[edf_cand.pred_index == -1]
fns_cand_counts = fns_cand.error_type.value_counts() / len(edf_cand)
fns_cand_counts

In [None]:
# Ranking errors during candidate generation
misranked_cand = edf_cand[edf_cand.pred_index > 0]
misranked_cand_counts = misranked_cand.error_type.value_counts() / len(edf_cand)
misranked_cand_counts

In [None]:
# Ranking errors after reranking
error_rr = entity_linking_error_analysis(ds['validation'], cross_enc_pred_valid)
edf_rr, _ = analyze(error_rr, cg_kb, 'eval', None)

misranked_rr = edf_rr[edf_rr.pred_index > 0]
misranked_rr_counts = misranked_rr.error_type.value_counts() / len(edf_rr)
misranked_rr_counts

In [None]:
print('Recall@1', (edf_rr.pred_index == 0).sum() / len(edf_rr))
print('Recall@64', (edf_rr.pred_index >= 0).sum() / len(edf_rr))

In [None]:
import matplotlib
from matplotlib import pyplot as plt

error_types = ['Complex Entity', 'Abbreviation', 'Wrong Type', 'Homonyms', 'Other Error']
keys = ['COMPLEX_ENTITY', 'ABBREV', 'WRONG_SEMANTIC_TYPE', 'SAME_SYNONYMS', 'UNKNOWN_ERROR']
counts = [fns_cand_counts, misranked_cand_counts, misranked_rr_counts]
count_names = ['False Negatives', 'Candidate Generation Errors', 'Reranking Errors']

hatches = ['', '\\\\\\', '///']
colors = ['lightyellow', '#C0F5EC', '#118470']

matplotlib.rcParams['hatch.linewidth'] = 0.3 
matplotlib.rcParams.update({'font.size': 12})

fig, axs = plt.subplots(1, len(keys), figsize=(8, 2))
for i, (key, error) in enumerate(zip(keys, error_types)):
    for j, bar in enumerate(axs[i].bar(x=count_names, height=[c.get(key, 0.0) for c in counts], color=colors, edgecolor = 'black', linewidth=0.2)):
        bar.set_hatch(hatches[j])
    axs[i].grid(axis='y')
    axs[i].set_ylim([0.0, 0.2])
    axs[i].set_xlabel(error, labelpad=10)
    axs[i].get_xaxis().set_ticks([])
handles = [plt.Rectangle((0,0),1,1, facecolor=color, linewidth=0.2, edgecolor='black', hatch=h) for color, h in zip(colors, hatches)]
plt.legend(handles, count_names, loc='upper center', bbox_to_anchor=(-3.1, 1.4, 0, 0), ncol=len(counts))

plt.subplots_adjust(wspace=0.7)