# Searching nearest neighbour in UMLS knowledge base


The goal of this notebook is to find, for each predicted mention (from the mention-detection step), its nearest neighbour in the UMLS knowledge base. Since every UMLS concept is mapped to an STY (semantic type) label, we take the label of the nearest concept/alias as the label of the predicted mention.

In [1]:
%%capture

import json
import numpy as np
from numpy import save
import pandas as pd

!pip install sentence_transformers
!pip install faiss
!apt-get install libomp-dev --yes
!pip install seqeval

import faiss
from sentence_transformers import SentenceTransformer

from transformers import BertModel, BertTokenizer, BertConfig

from tqdm.auto import tqdm

import tensorflow as tf

from seqeval.metrics import classification_report, f1_score, accuracy_score
from seqeval.scheme import IOB2

import gc
import csv
from IPython.display import FileLink

In [2]:
def build_concepts_arr(kb):
    """Flatten the knowledge base into three parallel lists.

    For every CUI key in ``kb`` the lower-cased ``Name`` comes first, followed
    by each lower-cased alias from ``STR``. Each surface form is paired with
    its CUI and with the first entry of that CUI's ``STY`` list.

    Args:
        kb: dict mapping a CUI to a dict with the keys ``Name`` (str),
            ``STR`` (iterable of str) and ``STY`` (non-empty sequence).

    Returns:
        Tuple ``(umls_concepts, umls_cui, umls_sty)`` of equally long lists.
    """
    umls_concepts, umls_cui, umls_sty = [], [], []

    for cui, entry in kb.items():
        # Preferred name first, then the aliases, all lower-cased.
        surface_forms = [entry['Name'].lower()]
        first_sty = entry['STY'][0]  # only the first semantic type is kept
        surface_forms.extend(alias.lower() for alias in entry['STR'])

        umls_concepts.extend(surface_forms)
        umls_cui.extend([cui] * len(surface_forms))
        umls_sty.extend([first_sty] * len(surface_forms))

    return umls_concepts, umls_cui, umls_sty

We map every concept and alias stored in the knowledge-base dictionary to its CUI code (and semantic type), so that CUI codes can be looked up quickly later.

In [3]:
def create_maps(kb_path='../input/umls-kb/umls.2017AA.active.st21pv.json',
                sty_path='../input/thesis/SemanticTypes_2018AB.txt'):
    """Load the UMLS knowledge base and build the lookup tables used downstream.

    Args:
        kb_path: JSON file mapping each CUI to its ``Name``, ``STR`` and ``STY``
            fields (the structure consumed by ``build_concepts_arr``).
        sty_path: pipe-separated text file; the second field of each line is
            used as the semantic-type code and the third as its description.

    Returns:
        Tuple ``(alias2cui, alias2sty, cui2sty, cui2alias, sty2interpretation,
        interpretation2sty)``:
          * alias2cui / alias2sty: lower-cased alias -> CUI / semantic type
            (for an alias shared by several CUIs the last one seen wins),
          * cui2sty: CUI -> semantic type,
          * cui2alias: CUI -> list of aliases that resolve to it in alias2cui,
          * sty2interpretation: semantic-type code -> description, lower-cased
            with spaces replaced by underscores,
          * interpretation2sty: the inverse of sty2interpretation.
    """
    with open(kb_path, encoding='utf-8') as f:
        kb = json.load(f)

    umls_concepts, umls_cui, umls_sty = build_concepts_arr(kb)
    # The raw knowledge base is no longer needed once it has been flattened.
    del kb

    alias2cui = {}
    alias2sty = {}
    cui2sty = {}
    for concept, cui, sty in zip(umls_concepts, umls_cui, umls_sty):
        alias2cui[concept] = cui
        alias2sty[concept] = sty
        cui2sty[cui] = sty

    del umls_concepts
    del umls_cui
    del umls_sty

    # alias2cui keys are unique, so every alias is appended at most once and
    # no membership test against the (possibly long) list is needed.
    cui2alias = {}
    for alias, cui in alias2cui.items():
        cui2alias.setdefault(cui, []).append(alias)

    sty2interpretation = {}
    with open(sty_path, 'r', encoding='utf-8') as file:
        for line in file:
            # Strip only the line terminator: slicing off the last character
            # would truncate the description when the final line has no
            # trailing newline.
            fields = line.rstrip('\r\n').split("|")
            if len(fields) < 3:
                # Blank or malformed line: nothing to map.
                continue
            sty2interpretation[fields[1]] = fields[2].lower().replace(" ", "_")

    interpretation2sty = {v: k for k, v in sty2interpretation.items()}

    gc.collect()

    return alias2cui, alias2sty, cui2sty, cui2alias, sty2interpretation, interpretation2sty

# Build the lookup tables once; later cells read them as notebook-level globals.
alias2cui, alias2sty, cui2sty, cui2alias, sty2interpretation, interpretation2sty = create_maps()

In [4]:
# Spot-check the maps on a single CUI: its semantic type (code and readable
# name) and the aliases that resolve to it.
print("Semantic Type of the CUI C0086930 are: ", cui2sty['C0086930'], " or ", sty2interpretation[cui2sty['C0086930']])
print("Aliases of the CUI C0086930 are: ", cui2alias['C0086930'])

Semantic Type of the CUI C0086930 are:  T058  or  health_care_activity
Aliases of the CUI C0086930 are:  ['risk assessment', 'assessments, risk', 'risk assessments', 'assessment, risk']


## Predicted Mentions

This CSV file contains the mention predictions produced by the mention-detection step.

In [5]:
# Token-level predictions for the test split. The cells below read the columns
# 'Tokens', 'Span' and 'Detection Prediction' from this frame.
test_predictions = pd.read_csv('../input/predictions/test_prediction.csv')

def build_sentences(predictions=None):
    """Return, for every token row, the token list of the sentence it belongs to.

    Rows are accumulated until a ``'[SEP]'`` token closes the sentence. The
    first and last accumulated tokens (the leading marker and ``'[SEP]'``) are
    dropped from the stored sentence, and the sentence is repeated once per row
    it was built from, so the result lines up with the frame row by row.

    Args:
        predictions: DataFrame with a ``'Tokens'`` column. Defaults to the
            notebook-level ``test_predictions`` frame.

    Returns:
        List of token lists, one entry per consumed row. Rows that follow the
        last ``'[SEP]'`` do not belong to a closed sentence and get no entry.
    """
    if predictions is None:
        predictions = test_predictions

    rep_sentences = []
    single_sent = []

    # Iterating the column directly avoids building a Series per row.
    for token in predictions['Tokens']:
        single_sent.append(token)
        if token == '[SEP]':
            sent = single_sent[1:-1]
            # Repeat by the number of rows actually consumed; len(sent) + 2
            # over-counts when a sentence has fewer than two rows, which would
            # shift every following row.
            rep_sentences.extend([sent] * len(single_sent))
            single_sent = []

    return rep_sentences

# One sentence (token list) per row of test_predictions.
df_sentences = build_sentences()

# The assignment requires df_sentences to have exactly one entry per row.
test_predictions['Sentences'] = df_sentences

In [6]:
# Token-level predictions for the train split. The cells below read the columns
# 'Tokens', 'Span' and 'Detection Prediction' from this frame.
train_predictions = pd.read_csv('../input/predictions/train_prediction.csv')

def build_sentences(predictions=None):
    """Return, for every token row, the token list of the sentence it belongs to.

    NOTE: this definition replaces any earlier ``build_sentences`` in the
    notebook; without an argument it works on ``train_predictions``.

    Rows are accumulated until a ``'[SEP]'`` token closes the sentence. The
    first and last accumulated tokens (the leading marker and ``'[SEP]'``) are
    dropped from the stored sentence, and the sentence is repeated once per row
    it was built from, so the result lines up with the frame row by row.

    Args:
        predictions: DataFrame with a ``'Tokens'`` column. Defaults to the
            notebook-level ``train_predictions`` frame.

    Returns:
        List of token lists, one entry per consumed row. Rows that follow the
        last ``'[SEP]'`` do not belong to a closed sentence and get no entry.
    """
    if predictions is None:
        predictions = train_predictions

    rep_sentences = []
    single_sent = []

    # Iterating the column directly avoids building a Series per row.
    for token in predictions['Tokens']:
        single_sent.append(token)
        if token == '[SEP]':
            sent = single_sent[1:-1]
            # Repeat by the number of rows actually consumed; len(sent) + 2
            # over-counts when a sentence has fewer than two rows, which would
            # shift every following row.
            rep_sentences.extend([sent] * len(single_sent))
            single_sent = []

    return rep_sentences

# One sentence (token list) per row of train_predictions. Note that this
# rebinds the df_sentences name used for the test split above.
df_sentences = build_sentences()

# The assignment requires df_sentences to have exactly one entry per row.
train_predictions['Sentences'] = df_sentences

In [7]:
# Unique 'Span' values of all rows whose detection label is not "O", i.e. the
# distinct mention strings that have to be embedded for each split.
test_mentions_candidates = np.unique(test_predictions[test_predictions['Detection Prediction'] != "O"]['Span'])
train_mentions_candidates = np.unique(train_predictions[train_predictions['Detection Prediction'] != "O"]['Span'])

## Extract Embedding of candidates and mentions using the UMLSBert model

The next functions compute a weighted combination of the context (sentence) embedding and the mention embedding.

In [8]:
def combine_mention_context_test(query_emb_dict, sent_emb_dict, predictions=None,
                                 mention_weight=0.7, context_weight=0.3):
    """Blend each detected mention's embedding with its sentence embedding.

    For every row whose detection label is not ``'O'`` and whose span has an
    entry in ``query_emb_dict``, the result is the L2-normalised vector
    ``mention_weight * mention + context_weight * sentence``.

    Args:
        query_emb_dict: maps ``str(span)`` to its embedding vector.
        sent_emb_dict: maps the space-joined sentence tokens to the sentence
            embedding vector.
        predictions: frame with the columns ``'Span'``, ``'Sentences'`` and
            ``'Detection Prediction'``. Defaults to the notebook-level
            ``test_predictions`` frame.
        mention_weight: weight of the mention embedding.
        context_weight: weight of the sentence embedding.

    Returns:
        Dict mapping the positional row index to the combined unit vector.
        Rows labelled ``'O'`` or with an unknown span are omitted.

    Raises:
        KeyError: if the sentence of a selected row is missing from
            ``sent_emb_dict``.
    """
    if predictions is None:
        predictions = test_predictions

    combined_emb = {}
    rows = zip(predictions['Span'],
               predictions['Sentences'],
               predictions['Detection Prediction'])

    for idx, (mention, sent, detect_label) in enumerate(rows):
        if detect_label == 'O':
            continue

        # Spans and tokens can be non-strings (e.g. numbers parsed from the
        # CSV); the embedding dictionaries are keyed by their string form.
        mention_emb = query_emb_dict.get(str(mention))
        if mention_emb is None:
            continue

        sentence_key = " ".join(str(tok) for tok in sent)
        combined = mention_emb * mention_weight + sent_emb_dict[sentence_key] * context_weight
        # Normalise so that an inner-product search equals cosine similarity.
        combined_emb[idx] = combined / np.linalg.norm(combined)

    return combined_emb

In [9]:
def combine_mention_context_train(query_emb_dict, sent_emb_dict, predictions=None,
                                  mention_weight=0.7, context_weight=0.3):
    """Blend each detected mention's embedding with its sentence embedding.

    For every row whose detection label is not ``'O'`` and whose span has an
    entry in ``query_emb_dict``, the result is the L2-normalised vector
    ``mention_weight * mention + context_weight * sentence``.

    Args:
        query_emb_dict: maps ``str(span)`` to its embedding vector.
        sent_emb_dict: maps the space-joined sentence tokens to the sentence
            embedding vector.
        predictions: frame with the columns ``'Span'``, ``'Sentences'`` and
            ``'Detection Prediction'``. Defaults to the notebook-level
            ``train_predictions`` frame.
        mention_weight: weight of the mention embedding.
        context_weight: weight of the sentence embedding.

    Returns:
        Dict mapping the positional row index to the combined unit vector.
        Rows labelled ``'O'`` or with an unknown span are omitted.

    Raises:
        KeyError: if the sentence of a selected row is missing from
            ``sent_emb_dict``.
    """
    if predictions is None:
        predictions = train_predictions

    combined_emb = {}
    rows = zip(predictions['Span'],
               predictions['Sentences'],
               predictions['Detection Prediction'])

    for idx, (mention, sent, detect_label) in enumerate(rows):
        if detect_label == 'O':
            continue

        # Spans and tokens can be non-strings (e.g. numbers parsed from the
        # CSV); the embedding dictionaries are keyed by their string form.
        mention_emb = query_emb_dict.get(str(mention))
        if mention_emb is None:
            continue

        sentence_key = " ".join(str(tok) for tok in sent)
        combined = mention_emb * mention_weight + sent_emb_dict[sentence_key] * context_weight
        # Normalise so that an inner-product search equals cosine similarity.
        combined_emb[idx] = combined / np.linalg.norm(combined)

    return combined_emb

We use this function to save the embeddings of the mentions, their context, and the UMLS aliases to binary files. The embeddings are normalized so that cosine similarity can be computed later.

In [10]:
def save_emb(calculate_kb=False):
    """Embed the test mentions and their sentences and save the search queries.

    Steps:
      1. Encode every unique mention string in ``test_mentions_candidates``.
      2. Encode every unique sentence of ``test_predictions['Sentences']``
         (tokens joined by a single space).
      3. Combine both with ``combine_mention_context_test`` and write the
         vectors to ``test_queries.npy`` and the matching row indices to
         ``test_query_indices.npy``.
      4. If ``calculate_kb`` is true, encode all keys of ``alias2cui`` in
         chunks of 100,000, L2-normalise each chunk and write it to
         ``./kb<i>.npy``.

    Args:
        calculate_kb: also (re)compute the knowledge-base alias embeddings.
    """
    kb_chunk_size = 100000

    coder_embedder = SentenceTransformer("GanjinZero/UMLSBert_ENG")

    # --- mention embeddings -------------------------------------------------
    # dict.fromkeys de-duplicates while keeping first-seen order.
    unique_queries = list(dict.fromkeys(str(q) for q in test_mentions_candidates))
    query_emb = coder_embedder.encode(unique_queries, batch_size=500)
    query_emb_dict = dict(zip(unique_queries, query_emb))
    del query_emb
    gc.collect()

    # --- sentence embeddings ------------------------------------------------
    # Join first, then de-duplicate the strings. Calling np.unique on the
    # column of token lists has to sort lists element by element, which is
    # slow and raises TypeError as soon as a token is not a string.
    merged_sents = list(dict.fromkeys(
        " ".join(str(tok) for tok in sent) for sent in test_predictions['Sentences']
    ))
    sent_emb = coder_embedder.encode(merged_sents, batch_size=500)
    sent_emb_dict = dict(zip(merged_sents, sent_emb))
    del sent_emb

    combined_mention_context = combine_mention_context_test(query_emb_dict, sent_emb_dict)

    # Write vectors and their row indices together so the two files cannot get
    # out of sync if the optional knowledge-base step below fails.
    np.save('test_queries.npy', list(combined_mention_context.values()))
    np.save('test_query_indices.npy', list(combined_mention_context.keys()))

    del query_emb_dict
    del sent_emb_dict
    del merged_sents
    del combined_mention_context
    gc.collect()

    # --- knowledge-base embeddings (optional) -------------------------------
    if calculate_kb:
        kb_unique_concepts = list(alias2cui.keys())

        chunks = (len(kb_unique_concepts) - 1) // kb_chunk_size + 1

        for i in tqdm(range(chunks)):
            batch = kb_unique_concepts[i*kb_chunk_size:(i+1)*kb_chunk_size]
            kb_concepts_emb = coder_embedder.encode(batch)
            # Row-wise L2 normalisation.
            kb_concepts_emb /= np.linalg.norm(kb_concepts_emb, axis=1, keepdims=True)
            np.save('./kb'+str(i)+'.npy', kb_concepts_emb)
            del kb_concepts_emb
            gc.collect()

In [11]:
# Disabled: presumably the embeddings loaded in the next cell were produced by
# an earlier run of save_emb() and uploaded as a dataset — confirm. Uncomment
# to recompute them and expose the result for download.
#save_emb()
#FileLink(r'./test_queries.npy')

In [13]:
# Load the saved query vectors and their row indices memory-mapped (read-only).
queries_emb = np.load('../input/d/mhmdrdwn/saved-embeddings/test_queries.npy', mmap_mode='r')
queries_indices = np.load('../input/d/mhmdrdwn/saved-embeddings/test_query_indices.npy', mmap_mode='r') 
# Only the first knowledge-base chunk is loaded here; it is used below to read
# the embedding dimension.
kb_concepts_emb = np.load('../input/d/mhmdrdwn/saved-embeddings/kb'+str(0)+'.npy', mmap_mode='r')

In [14]:
# Shapes of the first knowledge-base chunk and of the query matrix.
kb_concepts_emb.shape, queries_emb.shape

((100000, 768), (63567, 768))

In [15]:
# Embedding dimension, taken from the knowledge-base chunk loaded above.
dimension = kb_concepts_emb.shape[1]
# Number of nearest neighbours requested per query in search().
k = 1
# NOTE(review): nlist is not passed to the index constructor below and is not
# read anywhere else in this notebook.
nlist = 21 #number of clusters
# Flat inner-product index; the stored vectors are L2-normalised, so the inner
# product is presumably meant to act as cosine similarity — confirm.
index = faiss.IndexFlatIP(dimension)

We build a search index using Faiss: the index is trained on, and populated with, the embeddings of all UMLS aliases.

In [17]:
# Add all UMLS alias embeddings to the index, one saved chunk at a time.
# NOTE(review): the chunk count 26 is hard-coded; it assumes kb0.npy .. kb25.npy
# exist in the dataset — confirm against the number of chunks save_emb() wrote.
for i in tqdm(range(26)):
    kb_concepts_emb = np.load('../input/d/mhmdrdwn/saved-embeddings/kb'+str(i)+'.npy', mmap_mode='r')
    # NOTE(review): per the Faiss docs a flat index needs no training, so this
    # call is presumably a no-op for the IndexFlatIP created above — confirm.
    index.train(kb_concepts_emb)
    index.add(kb_concepts_emb)
    # Drop the reference to the chunk before loading the next one.
    kb_concepts_emb = None

  0%|          | 0/26 [00:00<?, ?it/s]

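Now we search for every detected mention embedding in the UMLS index. We set `nprobe` to 21, which equals the number of clusters, meaning that the search is exhaustive. (Note that the index created above is a flat inner-product index, which compares each query against every stored vector regardless of `nprobe`.)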

In [18]:
def search():
    """Look up the nearest knowledge-base alias for every query embedding.

    Reads the notebook-level ``index``, ``queries_emb``, ``queries_indices``,
    ``k``, ``alias2cui``, ``alias2sty`` and ``sty2interpretation``. The queries
    are searched in slices of 1000 rows.

    Returns:
        Tuple ``(idx2nn, idx2dis)``: both are keyed by the entries of
        ``queries_indices``; ``idx2nn`` holds the readable semantic type of the
        top-ranked alias and ``idx2dis`` the first score returned by
        ``index.search`` for that query.
    """
    # NOTE(review): nprobe presumably has no effect on the flat index created
    # earlier in the notebook — confirm.
    index.nprobe = 21
    batch_size = 1000
    n_batches = (len(queries_emb) - 1) // batch_size + 1

    all_neighbors, all_scores = [], []
    # Position i in this list corresponds to row i added to the index, assuming
    # the saved chunks were produced from alias2cui in the same key order.
    kb_unique_concepts = list(alias2cui.keys())

    for batch_no in tqdm(range(n_batches)):
        start = batch_no * batch_size
        scores, neighbors = index.search(queries_emb[start:start + batch_size], k)
        all_scores.extend(scores)
        all_neighbors.extend(neighbors)

    idx2nn, idx2dis = {}, {}
    for neighbor_row, score_row, q_idx in zip(all_neighbors, all_scores, queries_indices):
        best_alias = kb_unique_concepts[neighbor_row[0]]
        idx2nn[q_idx] = sty2interpretation[alias2sty[best_alias]]
        idx2dis[q_idx] = score_row[0]

    return idx2nn, idx2dis

# Row index -> predicted semantic type, and row index -> similarity score.
idx2nn, idx2dis = search()

  0%|          | 0/64 [00:00<?, ?it/s]

In [19]:
def build_prediction_list():
    """Build one label and one similarity value per row of ``test_predictions``.

    A row gets ``pred[:2] + <semantic type>`` (the first two characters of its
    detection label followed by the type stored in ``idx2nn``) together with
    the value stored in ``idx2dis`` when its detection label is not ``'O'`` and
    its positional index is present in ``idx2nn``. Every other row gets
    ``'O'`` and ``None``.

    Returns:
        Tuple ``(extracted_sty, sty_sim)`` of lists with one entry per row.
    """
    extracted_sty, sty_sim = [], []

    rows = zip(tqdm(test_predictions['Span']),
               test_predictions['Detection Prediction'],
               test_predictions['Sentences'])

    for idx, (_mention, pred, _sent) in enumerate(rows):
        # Only rows flagged as part of a mention can carry a looked-up type.
        semantic_type = idx2nn.get(idx) if pred != 'O' else None

        if semantic_type is None:
            extracted_sty.append('O')
            sty_sim.append(None)
            continue

        extracted_sty.append(pred[:2] + semantic_type)
        sty_sim.append(idx2dis.get(idx))

    return extracted_sty, sty_sim

In [20]:
# Attach the nearest-neighbour label and its similarity score to every row of
# the test frame. The train-split assignments are left disabled.
extracted_sty, sty_sim = build_prediction_list()
#train_predictions['Predicted Label'] = extracted_sty
#train_predictions['Cosine Similarity'] = sty_sim
test_predictions['Predicted Label'] = extracted_sty
test_predictions['Cosine Similarity'] = sty_sim

  0%|          | 0/270657 [00:00<?, ?it/s]

In [21]:
# Persist the labelled test frame; the train-split export is left disabled.
test_predictions.to_csv('test_nn_prediction.csv')
#train_predictions.to_csv('train_nn_prediction.csv')

In [22]:
# Render a link to the CSV written in the previous cell.
FileLink(r'./test_nn_prediction.csv')