# Preparation

`xmen dict conf/xmen.yaml --code src/gazetteer.py`
(extend the SympTEMIST gazetteer with UMLS aliases and save as jsonl file)

`xmen index conf/xmen.yaml --all`
(compute TF-IDF and SapBERT indices)

In [1]:
from pathlib import Path

import pandas as pd
from datasets import load_dataset, DatasetDict, load_from_disk
from xmen.evaluation import *

In [2]:
# Per-user xMEN cache directory for SympTEMIST; the knowledge base file and the
# linker indices are read from below this path later in the notebook.
base_path = Path.home().joinpath('.cache', 'xmen', 'symptemist')

In [3]:
# Load the "symptemist_linking_bigbio_kb" configuration of SympTEMIST through a
# local copy of the BigBIO loader script. The relative path is resolved against
# the notebook's working directory, so the `biomedical` checkout must sit two
# directory levels above it.
symptemist_data = load_dataset(
    path="../../biomedical/bigbio/hub/hub_repos/symptemist/symptemist.py", 
    name="symptemist_linking_bigbio_kb"
)
# The official training split; it is divided into train/validation further below.
train_data = symptemist_data['train']

Found cached dataset symptemist (/dhc/home/florian.borchert/.cache/huggingface/datasets/symptemist/symptemist_linking_bigbio_kb/2.0.0/2542aaab0d6c9963785fca5b4b0712501e06aa5a2e136b7b4d26d1fd7a2c382a)


  0%|          | 0/1 [00:00<?, ?it/s]

In [4]:
#import random
#random.seed(42)

#doc_ids = symptemist_data['train']['document_id']
#n_valid = int(0.2 * len(doc_ids))
#n_valid
#valid_doc_ids = random.sample(doc_ids, n_valid)
#with open('../data/subtrack2_valid_docids.txt', 'w') as fh:
#     for v in valid_doc_ids:
#        fh.write(v + '\n')
# Read the fixed list of validation document IDs (one ID per line). The split
# was sampled once with the commented-out code above and is stored on disk so
# that it stays stable across runs. The context manager makes sure the file
# handle is closed again instead of being left to the garbage collector.
with open('../data/subtrack2_valid_docids.txt', 'r') as fh:
    valid_doc_ids = [doc_id.strip() for doc_id in fh]
len(valid_doc_ids)

60

In [5]:
# Split the official training data by document ID: documents listed in
# `valid_doc_ids` form the validation split, all others the train split.
dataset = DatasetDict({
    'train': train_data.filter(lambda doc: doc['document_id'] not in valid_doc_ids),
    'validation': train_data.filter(lambda doc: doc['document_id'] in valid_doc_ids),
})
dataset

Loading cached processed dataset at /dhc/home/florian.borchert/.cache/huggingface/datasets/symptemist/symptemist_linking_bigbio_kb/2.0.0/2542aaab0d6c9963785fca5b4b0712501e06aa5a2e136b7b4d26d1fd7a2c382a/cache-2fdb3259090cb02d.arrow
Loading cached processed dataset at /dhc/home/florian.borchert/.cache/huggingface/datasets/symptemist/symptemist_linking_bigbio_kb/2.0.0/2542aaab0d6c9963785fca5b4b0712501e06aa5a2e136b7b4d26d1fd7a2c382a/cache-6587fb313b12f484.arrow


DatasetDict({
    train: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 244
    })
    validation: Dataset({
        features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
        num_rows: 60
    })
})

In [6]:
from xmen.data import get_cuis
# Gold concept codes of both splits, concatenated (train first, then validation).
# They are checked against the knowledge base a few cells below.
cuis = get_cuis(dataset['train']) + get_cuis(dataset['validation'])
len(cuis)

3484

In [7]:
from xmen import load_kb
# Load the knowledge base from the jsonl file in the xMEN cache directory
# (presumably the file written by the `xmen dict` step in "Preparation").
kb = load_kb(base_path / 'symptemist.jsonl')

In [8]:
# Distinct gold codes that have no entry in the knowledge base
{code for code in cuis if code not in kb.cui_to_entity}

{'NO_CODE'}

In [9]:
# Number of aliases: sum of the number of codes registered per alias
sum(len(alias_cuis) for alias_cuis in kb.alias_to_cuis.values())

1079623

# Candidate Generation

In [None]:
from xmen.linkers import default_ensemble

In [None]:
# Build xMEN's default linker ensemble from the indices stored under
# `base_path / 'index'` (presumably those computed by `xmen index` in "Preparation").
linker = default_ensemble(base_path / 'index')

In [None]:
# Run only the 'ngram' member of the ensemble on all splits of `dataset`;
# the result is evaluated on its own and later passed to the ensemble via `reuse_preds`.
candidates_ngram = linker.linkers_fn['ngram']().predict_batch(dataset)

In [None]:
# Evaluate the candidates of the 'ngram' linker on both splits
for heading, split in (('Training Set:', 'train'), ('Validation Set:', 'validation')):
    print(heading)
    _ = evaluate_at_k(dataset[split], candidates_ngram[split])

In [None]:
# Run only the 'sapbert' member of the ensemble on all splits of `dataset`;
# the result is evaluated on its own and later passed to the ensemble via `reuse_preds`.
candidates_sap = linker.linkers_fn['sapbert']().predict_batch(dataset, batch_size=128)

In [None]:
# Evaluate the candidates of the 'sapbert' linker on both splits
for heading, split in (('Training Set:', 'train'), ('Validation Set:', 'validation')):
    print(heading)
    _ = evaluate_at_k(dataset[split], candidates_sap[split])

In [None]:
# Ensemble prediction with up to 64 candidates per mention. The per-linker
# predictions computed above are handed over through `reuse_preds`.
precomputed_preds = {'sapbert' : candidates_sap, 'ngram' : candidates_ngram}
candidates = linker.predict_batch(
    dataset,
    batch_size=128,
    top_k=64,
    reuse_preds=precomputed_preds,
)

In [None]:
# Evaluate the ensemble candidates on both splits
for heading, split in (('Training Set:', 'train'), ('Validation Set:', 'validation')):
    print(heading)
    _ = evaluate_at_k(dataset[split], candidates[split])

In [None]:
# Persist the ensemble candidates; the reranking section reloads them with
# `load_from_disk`, so candidate generation does not have to be repeated.
candidates.save_to_disk('../data/candidates')

# Prepare Data for Reranking

In [10]:
from xmen.reranking import CrossEncoderReranker

In [11]:
# Reload the candidates written by the "Candidate Generation" section.
candidates = load_from_disk('../data/candidates')

In [13]:
# Error analysis of the raw (not yet reranked) candidates on the validation split.
# NOTE(review): `ea_df` is reassigned in the "Post-Processing" section.
ea_df = error_analysis(dataset['validation'], candidates['validation'])

In [27]:
# Recall@k of the candidate lists on the validation split, i.e. the baseline
# before reranking; the result is assigned to `_` to suppress the cell output.
_ = evaluate_at_k(dataset['validation'], candidates['validation'])

Recall@1 0.4230271668822768
Recall@2 0.5433376455368694
Recall@4 0.6261319534282018
Recall@8 0.6817593790426908
Recall@16 0.7309184993531694
Recall@32 0.7684346701164295
Recall@64 0.8007761966364813


In [28]:
# Build the cross-encoder input from the candidates, the gold annotations and
# the knowledge base; the result is indexed by split name below.
ce_dataset = CrossEncoderReranker.prepare_data(candidates, dataset, kb)

Context length: 128
Use NIL values: True


  0%|          | 0/2711 [00:00<?, ?it/s]

  0%|          | 0/2711 [00:00<?, ?it/s]

  0%|          | 0/2711 [00:00<?, ?it/s]

  0%|          | 0/773 [00:00<?, ?it/s]

  0%|          | 0/773 [00:00<?, ?it/s]

  0%|          | 0/773 [00:00<?, ?it/s]

# Reranker Training

In [29]:
from xmen.reranking.cross_encoder import CrossEncoderTrainingArgs
# Training configuration: Spanish biomedical-clinical RoBERTa base model,
# 20 epochs; all other settings keep their defaults.
args = CrossEncoderTrainingArgs(
    model_name='PlanTL-GOB-ES/roberta-base-biomedical-clinical-es',
    num_train_epochs=20,
)

In [30]:
import datetime
# One output directory per training run, named after the time this cell ran.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
output_dir = Path('output', 'cross_encoder', run_stamp)

In [None]:
import wandb

# Track the training run in the "symptemist" Weights & Biases project
wandb.init(project="symptemist")

try:
    rr = CrossEncoderReranker()
    rr.fit(
        args,
        ce_dataset['train'].dataset,
        ce_dataset['validation'].dataset,
        show_progress_bar=False,
        eval_callback=wandb.log,  # evaluation results are forwarded to W&B
        output_dir=output_dir,
    )
finally:
    # Close the W&B run even when training fails or is interrupted
    run = wandb.run
    if run:
        run.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mphlobo[0m. Use [1m`wandb login --relogin`[0m to force relogin


model_name := PlanTL-GOB-ES/roberta-base-biomedical-clinical-es
num_train_epochs := 20
fp16 := True
label_smoothing := False
rank_regularization := 1.0
train_layers := None
softmax_loss := True
random_seed := 42
learning_rate := 2e-05


Some weights of the model checkpoint at PlanTL-GOB-ES/roberta-base-biomedical-clinical-es were not used when initializing RobertaForSequenceClassification: ['lm_head.decoder.weight', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.decoder.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at PlanTL-GOB-ES/roberta-base-biomedical-clinical-es and are newly initialized: 

2023-09-20 14:05:06 - EntityLinkingEvaluator: Evaluating the model on eval dataset after epoch 0:
2023-09-20 14:10:20 - Accuracy: 0.5795601552393272
2023-09-20 14:10:20 - Accuracy @ 5: 0.7645536869340233
2023-09-20 14:10:20 - Accuracy @ 64: 1.0
2023-09-20 14:10:20 - Baseline Accuracy: 0.4230271668822768
2023-09-20 14:10:20 - Save model to output/cross_encoder/20230920-134756
2023-09-20 14:27:21 - EntityLinkingEvaluator: Evaluating the model on eval dataset after epoch 1:
2023-09-20 14:32:35 - Accuracy: 0.6338939197930142
2023-09-20 14:32:35 - Accuracy @ 5: 0.7930142302716688
2023-09-20 14:32:35 - Accuracy @ 64: 1.0
2023-09-20 14:32:35 - Baseline Accuracy: 0.4230271668822768
2023-09-20 14:32:35 - Save model to output/cross_encoder/20230920-134756
2023-09-20 14:49:37 - EntityLinkingEvaluator: Evaluating the model on eval dataset after epoch 2:
2023-09-20 14:54:51 - Accuracy: 0.6338939197930142
2023-09-20 14:54:51 - Accuracy @ 5: 0.7723156532988357
2023-09-20 14:54:51 - Accuracy @ 64: 1.0

# Prediction on Validation Set

In [32]:
# Load the reranker checkpoint that training saved to `output_dir`.
# NOTE(review): `output_dir` is derived from the current time in the training
# section, so after a kernel restart it points to a new directory without a
# checkpoint; set it to the directory of the finished run before executing this cell.
rr = CrossEncoderReranker.load(output_dir, device=0)

2023-09-20 19:02:50 - Use pytorch device: cuda


In [33]:
# Rerank the validation candidates with the loaded cross-encoder.
pred_validation = rr.rerank_batch(candidates['validation'], ce_dataset['validation'])

Batches:   0%|          | 0/773 [00:00<?, ?it/s]

Map:   0%|          | 0/60 [00:00<?, ? examples/s]

Map:   0%|          | 0/60 [00:00<?, ? examples/s]

In [34]:
# Precision / recall / F-score of the reranked predictions against the validation gold annotations.
evaluate(dataset['validation'], pred_validation)

{'strict': {'precision': 0.7042682926829268,
  'recall': 0.5976714100905562,
  'fscore': 0.6466060181945417,
  'ptp': 462,
  'fp': 194,
  'rtp': 462,
  'fn': 311,
  'n_docs_system': 60,
  'n_annos_system': 656,
  'n_docs_gold': 60,
  'n_annos_gold': 773}}

# Post-Processing

In [35]:
# One (mention text, gold code) record per training entity. The text fragments
# of a mention are joined with a single space; only the first normalization of
# each entity is used. The frame is built in one step so that `str2cui` never
# holds an intermediate list, and the columns are named explicitly so the
# group-by cells below also work when there are no training entities.
# Requires pandas (`pd`) from the notebook's import cell.
str2cui = pd.DataFrame(
    [
        {'text' : ' '.join(e['text']), 'norm' : e['normalized'][0]['db_id']}
        for d in dataset['train']['entities']
        for e in d
    ],
    columns=['text', 'norm'],
)

In [36]:
# For every distinct mention string: the set of codes it was annotated with in the training data
str2norm = str2cui.groupby('text')['norm'].agg(set)

In [37]:
# Sanity check: select mention strings whose set of training codes does not
# contain exactly one element. An empty result means the mapping from mention
# string to code is unique.
str2norm[str2norm.map(len) != 1]

Series([], Name: norm, dtype: object)

In [38]:
# Variant of the lookup restricted to mention strings that occur more than once
# in the training data; the first recorded code is taken for each of them.
# NOTE(review): `lookup` is overwritten by the next cell and `str2cui_list` is
# not referenced again in the visible part of the notebook, so this cell
# currently has no effect on the results.
str2cui_list = str2cui.groupby('text').agg(list).norm
lookup = str2cui_list[str2cui_list.map(len) > 1].map(lambda l: l[0])

In [39]:
# Unwrap the code sets into plain codes: mention string -> first element of its set
# (each set holds a single code according to the uniqueness check above).
lookup = str2norm.map(lambda codes: next(iter(codes)))
lookup

text
5HIAA en orina de 24 horas estaba dentro de los parámetros normales    171250001
A nivel analítico no presentaba alteración                             166315009
ALT y AST mayores de 20 veces el valor normal                          707724006
ALT y AST menos de 3 veces el valor normal                             166642001
AMA (Anticuerpos antimitocondriales) negativos                         310293008
                                                                         ...    
éxitus                                                                 419099009
íleon por engrosamiento parietal                                       312895004
íleon terminal una mucosa extremadamente irregular                     312895004
óbito                                                                  419099009
β-HCG normal                                                            33809001
Name: norm, Length: 1985, dtype: object

In [40]:
def transform_lookup(sample, lookup_table=None):
    """Put the training-set code of exactly matching mentions in first position.

    For every entity of ``sample`` whose text (fragments joined by a single
    space) is contained in the lookup table, an additional normalization with
    score 1.0 and ``predicted_by=['lookup']`` is inserted at index 0 of its
    ``normalized`` list. Existing predictions are kept after it, and entities
    without a match are passed through with an unchanged list.

    The input is not modified: each returned entity is a new dict with a new
    ``normalized`` list, so calling the function repeatedly on the same sample
    does not stack lookup entries on the caller's data.

    Args:
        sample: Mapping with an ``'entities'`` list; each entity provides
            ``'text'`` (list of strings) and ``'normalized'`` (list of dicts).
        lookup_table: pandas Series mapping mention text (index) to a code.
            Defaults to the notebook-level ``lookup`` Series.

    Returns:
        ``{'entities': [...]}`` with the updated entities, in the shape
        expected by ``Dataset.map``.
    """
    # Resolve the table at call time so that the notebook-level `lookup` is only
    # needed when no explicit table is given.
    table = lookup if lookup_table is None else lookup_table
    entities = []
    for e in sample['entities']:
        mention = ' '.join(e['text'])
        # Work on a copy of the prediction list; the shallow copy of the entity
        # list alone would still share (and mutate) the nested lists.
        normalized = list(e['normalized'])
        if mention in table.index:
            entry = {'db_name': 'SNOMED_CT', 'db_id': table.loc[mention], 'score' : 1.0, 'predicted_by' : ['lookup']}
            normalized.insert(0, entry)
        entities.append({**e, 'normalized': normalized})
    return { 'entities' : entities }

In [41]:
# Apply the training-set lookup to every document of the reranked validation predictions.
pred_validation_lookup = pred_validation.map(transform_lookup)

2023-09-20 19:08:15 - Loading cached processed dataset at /dhc/home/florian.borchert/workspace/symptemist_biocreative_2023/data/candidates/validation/cache-713f0c3d98a3022a.arrow


In [42]:
# Same evaluation as for the reranked predictions, now including the lookup post-processing.
evaluate(dataset['validation'], pred_validation_lookup)

{'strict': {'precision': 0.7123493975903614,
  'recall': 0.6119016817593791,
  'fscore': 0.6583159359777315,
  'ptp': 473,
  'fp': 191,
  'rtp': 473,
  'fn': 300,
  'n_docs_system': 60,
  'n_annos_system': 664,
  'n_docs_gold': 60,
  'n_annos_gold': 773}}

In [None]:
from xmen.evaluation import error_analysis

In [43]:
# Error analysis of the post-processed validation predictions.
# NOTE(review): this replaces the `ea_df` computed on the raw candidates in
# "Prepare Data for Reranking".
ea_df = error_analysis(dataset['validation'], pred_validation_lookup)

In [44]:
# Rows with pred_index == -1: presumably mentions whose gold concept does not
# occur among the predicted codes at all (confirm against xmen's error_analysis).
ea_df[ea_df.pred_index == -1]

Unnamed: 0,_word_len,_abbrev,pred_start,pred_end,pred_text,gt_start,gt_end,gt_text,entity_match_type,gold_concept,gold_type,pred_index,pred_index_score,pred_top,pred_top_score,document_id
1,5,False,411,459,[alteraciones tróficas en extremidades inferio...,411,459,[alteraciones tróficas en extremidades inferio...,tp,"{'db_name': 'SNOMED_CT', 'db_id': '373408007'}",SINTOMA,-1,,449917004,0.016326,es-S0210-48062009000300017-1
2,9,False,865,936,[dos tercios anteriores del glande se encuentr...,865,936,[dos tercios anteriores del glande se encuentr...,tp,"{'db_name': 'SNOMED_CT', 'db_id': '44882003'}",SINTOMA,-1,,NIL,,es-S0210-48062009000300017-1
4,5,False,1020,1057,[normalidad de los estudios urológicos],1020,1057,[normalidad de los estudios urológicos],tp,"{'db_name': 'SNOMED_CT', 'db_id': '300561007'}",SINTOMA,-1,,302778005,0.044476,es-S0210-48062009000300017-1
7,8,False,2713,2759,[a nivel del cardias masa mamelonada y ulcerada],2713,2759,[a nivel del cardias masa mamelonada y ulcerada],tp,"{'db_name': 'SNOMED_CT', 'db_id': '126825008'}",SINTOMA,-1,,NIL,,es-S0210-48062009000300017-1
14,3,False,913,942,[Auscultación cardiaca rítmica],913,942,[Auscultación cardiaca rítmica],tp,"{'db_name': 'SNOMED_CT', 'db_id': '64730000'}",SINTOMA,-1,,106068003,0.016546,es-S0212-71992006000700009-1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
740,7,False,675,736,"[exploración sistémica, resultando todos los e...",675,736,"[exploración sistémica, resultando todos los e...",tp,"{'db_name': 'SNOMED_CT', 'db_id': 'NO_CODE'}",SINTOMA,-1,,168699003,0.016326,es-S0365-66912004001000009-1
750,7,False,1123,1169,[engrosamiento de la pared intestinal en yeyuno],1123,1169,[engrosamiento de la pared intestinal en yeyuno],tp,"{'db_name': 'SNOMED_CT', 'db_id': '304370001'}",SINTOMA,-1,,550376571000132101,0.018853,es-S1130-01082008000800011-1
751,4,False,1485,1514,[plastrón localizado en yeyuno],1485,1514,[plastrón localizado en yeyuno],tp,"{'db_name': 'SNOMED_CT', 'db_id': '282050000'}",SINTOMA,-1,,NIL,,es-S1130-01082008000800011-1
766,6,False,317,356,[tumoraciones en valva anterior de riñón],317,356,[tumoraciones en valva anterior de riñón],tp,"{'db_name': 'SNOMED_CT', 'db_id': '309088003'}",SINTOMA,-1,,237783006,0.016709,es-S0210-48062005000800014-1


# Create Submission TSV

## Validation Set

In [None]:
import pandas as pd
# Official subtask 2 training annotations as distributed by the organizers
train_tsv_path = Path('../data/symptemist-train_all_subtasks+gazetteer+multilingual_230919/symptemist_train/subtask2-linking/symptemist_tsv_train_subtask2.tsv')
train_tsv = pd.read_csv(train_tsv_path, sep='\t')
# Keep only the rows of the held-out validation documents
valid_tsv = train_tsv.loc[train_tsv['filename'].isin(valid_doc_ids)]
# Write the gold file for the validation split; the span columns are renamed
# to `start_span` / `end_span` for the export only, `valid_tsv` keeps its names.
valid_gold = valid_tsv.rename(columns={'span_ini' : 'start_span', 'span_end' : 'end_span'})
valid_gold.to_csv('../data/valid_subtask2.tsv', sep='\t', index=False)

In [None]:
# Number of gold annotation rows in the validation split
len(valid_tsv)

In [None]:
# Flat list of all predicted entities across the validation documents.
# NOTE(review): `val_entities` is not referenced again in the visible part of the notebook.
val_entities = [e for d in pred_validation['entities'] for e in d]

In [None]:
# Independent copy of the gold rows reduced to the mention columns (no code column)
pred_tsv = valid_tsv[['filename', 'label', 'span_ini', 'span_end', 'text']].copy()

In [None]:
# Index the predicted entities by (document ID, start, end) of their first
# offset span so they can be matched with the TSV rows.
# NOTE(review): this uses `pred_validation`, not the post-processed
# `pred_validation_lookup` — confirm that this is intended for the submission file.
ents = {}
for doc in pred_validation:
    for entity in doc['entities']:
        first_span = entity['offsets'][0]
        ents[(doc['document_id'], first_span[0], first_span[1])] = entity

In [None]:
# The number of distinct (document, start, end) keys must equal the number of
# TSV rows, otherwise predictions and rows cannot be matched one-to-one.
assert len(ents) == len(pred_tsv), f'{len(ents)} predicted entities vs. {len(pred_tsv)} TSV rows'
# Create the `code` column up front: the export cell filters on it and would
# fail with an AttributeError if no row received a prediction.
output_val = pred_tsv.assign(code=None)
for idx, pred in pred_tsv.iterrows():
    e_idx = (pred['filename'], pred['span_ini'], pred['span_end'])
    assert e_idx in ents, f'No predicted entity for TSV row {e_idx}'
    norms = ents[e_idx]['normalized']
    # Rows without any prediction keep a missing code and are dropped on export
    if len(norms) > 0:
        output_val.loc[idx, 'code'] = norms[0]['db_id']

In [None]:
# Export only the rows that received a code, using the column names of the
# gold file written above (`start_span` / `end_span`).
has_code = output_val['code'].notna()
(
    output_val[has_code]
    .rename(columns={'span_ini' : 'start_span', 'span_end' : 'end_span'})
    .to_csv('../data/valid_pred_xmen.tsv', sep='\t', index=False)
)

In [None]:
#python medprocner_evaluation.py -r ../symptemist_biocreative_2023/data/valid_subtask2.tsv -p ../symptemist_biocreative_2023/data/valid_pred_xmen.tsv -t norm -o ../symptemist_biocreative_2023/data/

# Test Set