# UMLS Entity Linking

In [2]:
import os
# Expose only the GPU with index 1 to this process.
# NOTE(review): presumably this has to run before any CUDA-backed library is
# initialised in the kernel for it to take effect — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [3]:
import datasets
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from xmen import load_config, load_kb

# Read the xMEN configuration, then load the knowledge base file located at
# <cache_dir>/ggponc3/ggponc3.jsonl.
conf = load_config('xmen_ggponc3.yaml')
kb = load_kb(Path(conf.cache_dir).joinpath('ggponc3', 'ggponc3.jsonl'))

## Preparation

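Before running this notebook, build the knowledge base dictionary and the candidate generation indices from the command line. The knowledge base file loaded above is read from the configured cache directory, so these steps need to have completed first:

`xmen dict xmen_ggponc3.yaml`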

`xmen index xmen_ggponc3.yaml --all --overwrite`

In [None]:
# Extract the plain-text archive into data/v3_plain_text
# (-n: never overwrite files that already exist, -q: quiet output).
!unzip -n -q data/v3.0_2024_01_03/plain_text/plain_text.zip -d data/v3_plain_text

## Run NER Model

In [None]:
from spacy import Language

@Language.component('prevent-sbd')
def prevent_sentence_boundary_detection(doc):
    """Mark the whole document as one single sentence.

    Sets ``is_sent_start`` to True on the first token and to False on every
    other token. The input is already one sentence per document, so no
    further sentence boundaries should be introduced.

    Args:
        doc: The document to annotate; its tokens are modified in place.
            A document without tokens is returned unchanged.

    Returns:
        The same ``doc`` object that was passed in.
    """
    # enumerate() instead of doc[0] / doc[1:], so that a zero-token document
    # does not raise IndexError.
    for position, token in enumerate(doc):
        token.is_sent_start = position == 0
    return doc

In [None]:
import spacy

# Request GPU execution before the model is loaded.
spacy.require_gpu()

nlp = spacy.load('de_ggponc_medbertde')
# GGPONC is already split into sentences
# Insert the custom 'prevent-sbd' component ahead of the 'parser' component.
nlp.add_pipe('prevent-sbd', before='parser')

In [None]:
from pathlib import Path
from tqdm.auto import tqdm

def get_sentences(sentence_dir='data/v3_plain_text/sentences/all_files_sentences/'):
    """Yield every non-empty sentence line together with the stem of its file.

    All ``*.txt`` files directly inside ``sentence_dir`` are visited in sorted
    path order (wrapped in a tqdm progress bar). Each file is read as UTF-8,
    one sentence per line; trailing whitespace is stripped and lines that are
    empty afterwards are skipped.

    Args:
        sentence_dir: Directory containing the sentence files. Defaults to the
            location the plain-text archive is extracted to in this notebook.

    Yields:
        ``(sentence, file_stem)`` tuples, where ``file_stem`` is the file name
        without its ``.txt`` suffix.
    """
    sentence_files = sorted(Path(sentence_dir).glob('*.txt'))
    for sentence_file in tqdm(sentence_files):
        with open(sentence_file, 'r', encoding='utf-8') as fh:
            # Iterate the handle lazily rather than loading every line of the
            # file into a list first.
            for raw_line in fh:
                sentence = raw_line.rstrip()
                if sentence:
                    yield sentence, sentence_file.stem

In [None]:
# Materialise all (sentence, file stem) pairs in memory.
sents = list(get_sentences())

In [None]:
# Number of non-empty sentence lines collected across all files.
len(sents)

In [None]:
%%time
# Run the pipeline over the (sentence, file stem) tuples in batches of 256.
# as_tuples=True is passed, and the cells below treat each result as a
# (doc, file stem) pair.
# NOTE(review): get_sentences() is called again here, so the sentence files are
# read a second time instead of reusing the `sents` list built above.
ner_result = list(nlp.pipe(get_sentences(), as_tuples=True, batch_size=256))

In [None]:
from spacy import displacy

# Visual spot check: render the spans stored under the 'entities' key for the
# document at index 10 (element 0 of the tuple is the doc).
displacy.render(ner_result[10][0], style='span', options={'spans_key' : 'entities'})

In [None]:
from itertools import groupby
from spacy.tokens import Doc, DocBin

In [None]:
from typing import List, Tuple, Any

def merge_sentence_docs(sentence_docs : List[Tuple[Doc, Any]], key_name='file_name'):
    """Combine consecutive sentence-level docs that share a key into one doc each.

    Args:
        sentence_docs: Sequence of ``(doc, key)`` tuples. Only *consecutive*
            tuples with an equal key are grouped, so the input is expected to
            be ordered by key already.
        key_name: Name under which the group key is stored in the merged
            doc's ``user_data``.

    Returns:
        A list with one merged doc per group, in input order.

    Raises:
        AssertionError: If, for any span key, the merged doc does not contain
            exactly as many spans as its sentence docs together.
    """
    merged_docs = []
    for group_key, members in groupby(sentence_docs, key=lambda item: item[1]):
        parts = [item[0] for item in members]
        merged = Doc.from_docs(parts)
        # Sanity check: merging must not drop or duplicate any span.
        for span_key, merged_group in merged.spans.items():
            expected_count = sum(len(part.spans[span_key]) for part in parts)
            assert expected_count == len(merged_group)
        merged.user_data[key_name] = group_key
        merged_docs.append(merged)
    return merged_docs

In [None]:
# Merge the per-sentence NER results back into one doc per source file.
docs = merge_sentence_docs(ner_result)

In [None]:
from xmen.data import from_spacy
# Convert the merged docs into a dataset, reading mentions from the 'entities'
# span group and the document id from user_data['file_name'] (the key written
# by merge_sentence_docs).
ds = from_spacy(docs, span_key='entities', doc_id_key='file_name')

In [None]:
# Persist the NER output so that entity linking can start from disk.
ds.save_to_disk('data/ggponc_v3_spacy')

## Run Entity Linker

In [None]:
# Reload the NER dataset written to 'data/ggponc_v3_spacy' above; presumably
# this lets the linking steps be run without repeating the NER stage.
ds = datasets.load_from_disk('data/ggponc_v3_spacy')

In [None]:
from xmen.data import from_spacy
from xmen.linkers import default_ensemble

In [None]:
from xmen.data import AbbreviationExpander
# Apply abbreviation expansion to the whole dataset.
# NOTE(review): `ds` is overwritten with the transformed dataset, so re-running
# this cell applies the transform to already-transformed data — reload `ds`
# first when re-executing.
ds = AbbreviationExpander().transform_batch(ds)

In [None]:
# Build the candidate generator from the `linker.candidate_generation` section
# of the loaded configuration.
linker = default_ensemble(**conf.linker.candidate_generation)

In [None]:
# Generate linking candidates for every mention, processing 128 examples per batch.
candidates_raw = linker.predict_batch(ds, batch_size=128)

In [None]:
# Persist the unfiltered candidates.
candidates_raw.save_to_disk('data/ggponc_v3_candidates_raw')

### Semantic Type Filter

In [None]:
# Reload the unfiltered candidates saved above; presumably this allows resuming
# here without regenerating them.
candidates_raw = datasets.load_from_disk('data/ggponc_v3_candidates_raw')

In [None]:
from xmen.data import SemanticTypeFilter
import pandas as pd

tui_df = pd.read_csv('ggponc2tui.csv')

# For each entity-type column of the CSV, collect the TUI values of all rows
# that are marked with 'x' in that column.
type2tui = {
    entity_type: list(tui_df.TUI[tui_df[entity_type] == 'x'].values)
    for entity_type in [
        'Diagnosis_or_Pathology',
        'Other_Finding',
        'Clinical_Drug',
        'Nutrient_or_Body_Substance',
        'External_Substance',
        'Therapeutic',
        'Diagnostic',
    ]
}

# Filter built from the type -> TUI mapping and the knowledge base.
type_filter = SemanticTypeFilter(type2tui, kb)

In [None]:
# Apply the semantic type filter to the raw candidates.
candidates_all = type_filter.transform_batch(candidates_raw)

In [None]:
# Persist the type-filtered candidates.
candidates_all.save_to_disk('data/ggponc_v3_candidates')

## Re-Ranking

In [5]:
# Reload the type-filtered candidates saved above; presumably this allows
# resuming at the re-ranking stage.
candidates_all = datasets.load_from_disk('data/ggponc_v3_candidates')

In [6]:
# Value passed as the second argument to filter_and_apply_threshold below.
# NOTE(review): presumably the number of candidates kept per mention for
# re-ranking — confirm.
K_RERANKING = 16

In [7]:
from xmen.data import filter_and_apply_threshold
# Reduce the candidate lists using K_RERANKING and a threshold of 0.0.
# NOTE(review): presumably "top K_RERANKING candidates, no score cut-off" —
# confirm against the helper's signature.
candidates = filter_and_apply_threshold(candidates_all, K_RERANKING, 0.0)

Map: 100%|█████████████████████████████████████████████████████████████████████████████████████| 11821/11821 [03:35<00:00, 54.81 examples/s]


In [8]:
from xmen.reranking import CrossEncoderReranker

In [9]:
# Build the cross-encoder inputs from the candidates and the knowledge base,
# with NIL handling disabled (use_nil=False).
# NOTE(review): the second positional argument is None; presumably that slot
# is for ground-truth annotations, which do not exist here — confirm.
ce_candidates = CrossEncoderReranker.prepare_data(candidates, None, kb, use_nil=False)

Context length: 128
Use NIL values: False


100%|████████████████████████████████████████████████████████████████████████████████████████████| 271629/271629 [00:06<00:00, 41458.89it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████| 271629/271629 [00:01<00:00, 219170.48it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 271629/271629 [01:02<00:00, 4325.17it/s]


In [10]:
# Load the pre-trained cross-encoder re-ranker onto device 0.
# NOTE(review): CUDA_VISIBLE_DEVICES is set to '1' at the top of the notebook,
# so device index 0 presumably refers to that single visible GPU — confirm.
rr = CrossEncoderReranker.load("phlobo/xmen-de-ce-medmentions", device=0)

In [None]:
# Re-rank the candidates with the cross-encoder. Long-running: the recorded
# progress output shows an estimated duration of several hours.
reranked = rr.rerank_batch(candidates, ce_candidates)

Batches:   5%|████                                                                                 | 13066/271561 [16:23<4:33:35, 15.75it/s]

In [None]:
# Persist the re-ranked predictions (produced with use_nil=False).
reranked.save_to_disk('data/ggponc_v3_rr_no_nil')

## Final Format

In [None]:
# Reload the re-ranked predictions saved above; presumably this allows building
# the final export without repeating the re-ranking run.
reranked = datasets.load_from_disk('data/ggponc_v3_rr_no_nil')

In [None]:
import random

# Pick one random document index for a manual spot check.
# randrange() excludes the upper bound; randint() includes it and could
# therefore return len(...), which is not a valid index. The range is taken from
# `reranked` because that is the dataset the index is applied to below.
show_indices = [random.randrange(len(reranked))]

In [None]:
# Preview the first 10 annotation rows for the randomly selected document.
# NOTE(review): `get_annotation_dataframe` is not imported in any cell visible
# in this notebook — confirm the import. The positional arguments 3 and 0.0 are
# presumably a per-mention candidate limit and a score threshold — confirm.
get_annotation_dataframe(reranked.select(show_indices), kb, 3, 0.0).iloc[0:10]

In [None]:
# Build the final annotation table for the full dataset. The commented-out line
# is an alternative setting (3 / 0.07) kept for reference; the active call uses
# 1 / 0.0.
# NOTE(review): `get_annotation_dataframe` is not imported in any cell visible
# in this notebook — confirm the import.
#df = get_annotation_dataframe(reranked, kb, 3, 0.07)
df = get_annotation_dataframe(reranked, kb, 1, 0.0)

In [None]:
# Export the annotation table as a tab-separated file without the index column.
df.to_csv('data/entities_with_cuis.tsv', sep='\t', index=False)