# Preparation

Run the following xMEN CLI steps once before executing this notebook:

`xmen dict conf/xmen.yaml --code src/gazetteer.py`
(extends the SympTEMIST gazetteer with UMLS aliases and saves it as a JSONL file)

`xmen index conf/xmen.yaml --all`
(computes the TF-IDF and SapBERT indices)

In [1]:
from datasets import load_dataset, DatasetDict
from pathlib import Path

In [2]:
# Working directory of the xMEN artifacts for SympTEMIST (KB jsonl and candidate indices).
base_path = Path.home().joinpath('.cache', 'xmen', 'symptemist')

In [3]:
# Load the SympTEMIST linking subset in BigBio KB schema through the local hub loader script.
# Only the 'train' split is used; it is divided into train/validation below.
symptemist_data = load_dataset(
    name="symptemist_linking_bigbio_kb",
    path="../../biomedical/bigbio/hub/hub_repos/symptemist/symptemist.py",
)
train_data = symptemist_data["train"]

Found cached dataset symptemist (/dhc/home/florian.borchert/.cache/huggingface/datasets/symptemist/symptemist_linking_bigbio_kb/2.0.0/2542aaab0d6c9963785fca5b4b0712501e06aa5a2e136b7b4d26d1fd7a2c382a)


  0%|          | 0/1 [00:00<?, ?it/s]

In [143]:
# The validation split (20% of the training documents) was sampled once and written to disk,
# so the notebook only reads it back. It was created with:
#
#   import random
#   random.seed(42)
#   doc_ids = symptemist_data['train']['document_id']
#   n_valid = int(0.2 * len(doc_ids))
#   valid_doc_ids = random.sample(doc_ids, n_valid)
#   with open('../data/subtrack2_valid_docids.txt', 'w') as fh:
#       for v in valid_doc_ids:
#           fh.write(v + '\n')
with open('../data/subtrack2_valid_docids.txt', 'r', encoding='utf-8') as fh:
    # One document id per line; blank lines are skipped so they cannot produce an empty id.
    valid_doc_ids = [line.strip() for line in fh if line.strip()]
len(valid_doc_ids)

60

In [5]:
# Split the official training data into train / validation by document id.
# A set gives O(1) membership tests inside `filter` instead of scanning the id list per example.
valid_doc_id_set = set(valid_doc_ids)
dataset = DatasetDict()
dataset['train'] = train_data.filter(lambda d: d['document_id'] not in valid_doc_id_set)
dataset['validation'] = train_data.filter(lambda d: d['document_id'] in valid_doc_id_set)
dataset

Loading cached processed dataset at /dhc/home/florian.borchert/.cache/huggingface/datasets/symptemist/symptemist_linking_bigbio_kb/2.0.0/2542aaab0d6c9963785fca5b4b0712501e06aa5a2e136b7b4d26d1fd7a2c382a/cache-2fdb3259090cb02d.arrow
Loading cached processed dataset at /dhc/home/florian.borchert/.cache/huggingface/datasets/symptemist/symptemist_linking_bigbio_kb/2.0.0/2542aaab0d6c9963785fca5b4b0712501e06aa5a2e136b7b4d26d1fd7a2c382a/cache-6587fb313b12f484.arrow


DatasetDict({
    train: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 244
    })
    validation: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 60
    })
})

In [6]:
from xmen.data import get_cuis
# Gold concept ids of both splits, used below to check that the KB covers them.
train_cuis = get_cuis(dataset['train'])
validation_cuis = get_cuis(dataset['validation'])
cuis = train_cuis + validation_cuis
len(cuis)

3484

In [7]:
from xmen import load_kb
# Extended gazetteer written by `xmen dict` in the preparation step.
kb_file = base_path.joinpath('symptemist.jsonl')
kb = load_kb(kb_file)

In [8]:
# Gold concept ids that have no entry in the KB (the output shows only the 'NO_CODE' placeholder).
{c for c in cuis if c not in kb.cui_to_entity}

{'NO_CODE'}

In [9]:
# Total number of (alias, concept) entries in the KB: each alias contributes len(value) entries,
# so an alias shared by several concepts is counted more than once.
# (Assumes each value of `alias_to_cuis` holds the concept ids the alias maps to.)
sum(len(alias_cuis) for alias_cuis in kb.alias_to_cuis.values())

1079623

# Candidate Generation

In [10]:
from xmen.linkers import default_ensemble
from xmen.evaluation import evaluate_at_k

In [11]:
# Ensemble of the n-gram and SapBERT linkers, presumably over the indices built by `xmen index`.
index_dir = base_path.joinpath('index')
linker = default_ensemble(index_dir)

In [12]:
# Candidates from the n-gram linker of the ensemble on its own.
ngram_linker = linker.linkers_fn['ngram']()
candidates_ngram = ngram_linker.predict_batch(dataset)

Map:   0%|          | 0/244 [00:00<?, ? examples/s]

Map:   0%|          | 0/60 [00:00<?, ? examples/s]

In [13]:
# Perf@k of the n-gram candidates on both splits.
for split, title in (('train', 'Training Set:'), ('validation', 'Validation Set:')):
    print(title)
    _ = evaluate_at_k(dataset[split], candidates_ngram[split])

Training Set:
Perf@1 0.35521947620804134
Perf@2 0.47399483585392843
Perf@4 0.5407598672076724
Perf@8 0.5820730357801549
Perf@16 0.6233862043526374
Perf@32 0.66654371080782
Perf@64 0.6938399114717816
Validation Set:
Perf@1 0.34799482535575677
Perf@2 0.445019404915912
Perf@4 0.5045278137128072
Perf@8 0.5433376455368694
Perf@16 0.5899094437257438
Perf@32 0.6403622250970246
Perf@64 0.6752910737386805


In [14]:
# Candidates from the SapBERT linker of the ensemble on its own.
sapbert_linker = linker.linkers_fn['sapbert']()
candidates_sap = sapbert_linker.predict_batch(dataset, batch_size=128)

Map:   0%|          | 0/244 [00:00<?, ? examples/s]

Map:   0%|          | 0/60 [00:00<?, ? examples/s]

In [15]:
# Perf@k of the SapBERT candidates on both splits.
for split, title in (('train', 'Training Set:'), ('validation', 'Validation Set:')):
    print(title)
    _ = evaluate_at_k(dataset[split], candidates_sap[split])

Training Set:
Perf@1 0.4293618590925858
Perf@2 0.5551457026927333
Perf@4 0.6292880855772778
Perf@8 0.6842493544817411
Perf@16 0.7215049797122833
Perf@32 0.7558096643305053
Perf@64 0.782736997417927
Validation Set:
Perf@1 0.4023285899094437
Perf@2 0.5433376455368694
Perf@4 0.6170763260025873
Perf@8 0.6752910737386805
Perf@16 0.7257438551099612
Perf@32 0.7619663648124192
Perf@64 0.795601552393273


In [16]:
# Combined ensemble candidates (top 64 per mention); the per-linker predictions from above are
# passed via `reuse_preds`, presumably so they are not recomputed.
candidates = linker.predict_batch(
    dataset,
    batch_size=128,
    top_k=64,
    reuse_preds={'sapbert': candidates_sap, 'ngram': candidates_ngram},
)

Map:   0%|          | 0/244 [00:00<?, ? examples/s]

Map:   0%|          | 0/60 [00:00<?, ? examples/s]

In [17]:
# Perf@k of the ensemble candidates on the training split.
train_split = 'train'
_ = evaluate_at_k(dataset[train_split], candidates[train_split])

Perf@1 0.4577646624861675
Perf@2 0.5706381409074143
Perf@4 0.633345628919218
Perf@8 0.6883068978236813
Perf@16 0.7303578015492438
Perf@32 0.7624492807082257
Perf@64 0.7893766137956474


In [18]:
# Perf@k of the ensemble candidates on the validation split.
validation_split = 'validation'
_ = evaluate_at_k(dataset[validation_split], candidates[validation_split])

Perf@1 0.4230271668822768
Perf@2 0.5433376455368694
Perf@4 0.6261319534282018
Perf@8 0.6817593790426908
Perf@16 0.7309184993531694
Perf@32 0.7684346701164295
Perf@64 0.8007761966364813


# Reranker Training

In [19]:
from xmen.reranking import CrossEncoderReranker
from xmen.reranking.cross_encoder import CrossEncoderTrainingArgs

In [20]:
# Build the cross-encoder inputs from the ensemble candidates, the gold annotations and the KB;
# the result is indexed by split ('train' / 'validation') in the cells below.
ce_dataset = CrossEncoderReranker.prepare_data(candidates, dataset, kb)

Context length: 128
Use NIL values: True


  0%|          | 0/2711 [00:00<?, ?it/s]

  0%|          | 0/2711 [00:00<?, ?it/s]

  0%|          | 0/2711 [00:00<?, ?it/s]

  0%|          | 0/773 [00:00<?, ?it/s]

  0%|          | 0/773 [00:00<?, ?it/s]

  0%|          | 0/773 [00:00<?, ?it/s]

In [25]:
# Only the base model and the number of epochs are set explicitly; the full set of options
# (including defaults) is printed by `fit` below.
args = CrossEncoderTrainingArgs(
    model_name='PlanTL-GOB-ES/roberta-base-biomedical-clinical-es',
    num_train_epochs=2,
)

In [26]:
# Fresh, untrained reranker; it is fitted in the next cell.
rr = CrossEncoderReranker()

In [27]:
# Fine-tune the cross-encoder; the log below shows an evaluation on the validation split and a
# save to ./output/cross_encoder after every epoch.
train_ce_data = ce_dataset['train'].dataset
validation_ce_data = ce_dataset['validation'].dataset
rr.fit(args, train_ce_data, validation_ce_data)

model_name := PlanTL-GOB-ES/roberta-base-biomedical-clinical-es
num_train_epochs := 2
fp16 := True
label_smoothing := False
rank_regularization := 1.0
train_layers := None
softmax_loss := True
random_seed := 42
learning_rate := 2e-05


Some weights of the model checkpoint at PlanTL-GOB-ES/roberta-base-biomedical-clinical-es were not used when initializing RobertaForSequenceClassification: ['lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.decoder.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at PlanTL-GOB-ES/roberta-base-biomedical-clinical-es and are newly initialized: 

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Iteration:   0%|          | 0/2711 [00:00<?, ?it/s]

2023-09-19 17:20:37 - EntityLinkingEvaluator: Evaluating the model on eval dataset after epoch 0:
2023-09-19 17:25:51 - Accuracy: 0.5756791720569211
2023-09-19 17:25:51 - Accuracy @ 5: 0.7529107373868047
2023-09-19 17:25:51 - Accuracy @ 64: 1.0
2023-09-19 17:25:51 - Baseline Accuracy: 0.4230271668822768
2023-09-19 17:25:51 - Save model to ./output/cross_encoder


Iteration:   0%|          | 0/2711 [00:00<?, ?it/s]

2023-09-19 17:42:24 - EntityLinkingEvaluator: Evaluating the model on eval dataset after epoch 1:
2023-09-19 17:47:37 - Accuracy: 0.6261319534282018
2023-09-19 17:47:37 - Accuracy @ 5: 0.7852522639068564
2023-09-19 17:47:37 - Accuracy @ 64: 1.0
2023-09-19 17:47:37 - Baseline Accuracy: 0.4230271668822768
2023-09-19 17:47:37 - Save model to ./output/cross_encoder


# Prediction on Validation Set

In [29]:
# Reload the checkpoint that `fit` saved to ./output/cross_encoder, on CUDA device 0.
cross_encoder_dir = './output/cross_encoder'
rr = CrossEncoderReranker.load(cross_encoder_dir, device=0)

2023-09-19 17:48:33 - Use pytorch device: cuda


In [30]:
# Rerank the ensemble's validation candidates with the fine-tuned cross-encoder.
validation_candidates = candidates['validation']
pred_validation = rr.rerank_batch(validation_candidates, ce_dataset['validation'])

Batches:   0%|          | 0/773 [00:00<?, ?it/s]

Map:   0%|          | 0/60 [00:00<?, ? examples/s]

Map:   0%|          | 0/60 [00:00<?, ? examples/s]

In [33]:
from xmen.evaluation import evaluate

In [56]:
# Strict precision / recall / F1 of the reranked predictions against the validation gold standard.
evaluate(dataset['validation'], pred_validation)

{'strict': {'precision': 0.6226138032305433,
  'recall': 0.5485122897800776,
  'fscore': 0.5832187070151307,
  'ptp': 424,
  'fp': 257,
  'rtp': 424,
  'fn': 349,
  'n_docs_system': 60,
  'n_annos_system': 681,
  'n_docs_gold': 60,
  'n_annos_gold': 773}}

# Post-Processing - Training Set Lookup? (TODO: not implemented yet)

# Create Submission TSV

## Validation Set

In [144]:
import pandas as pd

# SympTEMIST subtask 2 (linking) training annotations in TSV form.
train_tsv_path = Path('../data/symptemist-train_all_subtasks+gazetteer+multilingual_230919/symptemist_train/subtask2-linking/symptemist_tsv_train_subtask2.tsv')
train_tsv = pd.read_csv(train_tsv_path, sep='\t')

In [145]:
# Gold annotations of the held-out validation documents.
in_validation = train_tsv['filename'].isin(valid_doc_ids)
valid_tsv = train_tsv.loc[in_validation]

In [146]:
# Write the validation gold standard as reference TSV for the evaluation script, with the span
# columns renamed to start_span / end_span.
(
    valid_tsv
    .rename(columns={'span_ini' : 'start_span', 'span_end' : 'end_span'})
    .to_csv('../data/valid_subtask2.tsv', sep='\t', index=False)
)

In [147]:
# Number of gold validation mentions (rows of the gold TSV).
valid_tsv.shape[0]

773

In [148]:
# Flat list of all predicted validation entities (not referenced by the later cells shown here).
val_entities = []
for doc_entities in pred_validation['entities']:
    val_entities.extend(doc_entities)

In [149]:
# Skeleton of the prediction TSV: the gold mentions without their codes.
pred_tsv = valid_tsv[['filename', 'label', 'span_ini', 'span_end', 'text']].copy()

In [150]:
# Index the reranked validation entities by (document id, start, end) of their first offset
# span so that each gold TSV row can be matched to its prediction.
ents = {}
for doc in pred_validation:
    for entity in doc['entities']:
        first_span = entity['offsets'][0]
        ents[(doc['document_id'], first_span[0], first_span[1])] = entity

In [155]:
# Attach the predicted concept to every gold validation mention. Each TSV row is matched to its
# reranked entity via (filename, span_ini, span_end); the two mention sets must agree exactly.
assert len(ents) == len(pred_tsv), f'{len(ents)} predicted vs. {len(pred_tsv)} gold mentions'
codes = []
for e_idx in pred_tsv[['filename', 'span_ini', 'span_end']].itertuples(index=False, name=None):
    assert e_idx in ents, f'no prediction for mention {e_idx}'
    norms = ents[e_idx]['normalized']
    # Take the first normalization (presumably the highest-ranked one); mentions without any
    # normalization keep a missing code.
    codes.append(norms[0]['db_id'] if len(norms) > 0 else None)
# Assign the whole column at once instead of growing it row by row with .loc: this guarantees
# the `code` column exists even when no mention received a prediction.
output_val = pred_tsv.assign(code=codes)

In [156]:
# Write the predictions in the evaluation script's TSV format; mentions without a predicted code
# are left out.
(
    output_val
    .dropna(subset=['code'])
    .rename(columns={'span_ini' : 'start_span', 'span_end' : 'end_span'})
    .to_csv('../data/valid_pred_xmen.tsv', sep='\t', index=False)
)

In [157]:
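# Score the exported predictions against the gold TSV with the MedProcNER/SympTEMIST evaluation
# script (run from the directory containing medprocner_evaluation.py):
#python medprocner_evaluation.py -r ../symptemist_biocreative_2023/data/valid_subtask2.tsv -p ../symptemist_biocreative_2023/data/valid_pred_xmen.tsv -t norm -o ../symptemist_biocreative_2023/data/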