### load snomed

In [2]:
import networkx as nx
from tqdm import tqdm
from Snomed import Snomed

In [3]:
# Path to a local SNOMED release directory. SNOMED is not shipped with this
# notebook: you need to download your own SNOMED distribution and point here.
SNOMED_PATH = '../data/SnomedCT_201907'
snomed = Snomed(SNOMED_PATH)
# Presumably parses the release into memory; the following cells read
# `snomed.graph` and `snomed.index_definition` -- TODO confirm in Snomed.py.
snomed.load_snomed()

In [4]:
# Flatten the terminology into (description, SNOMED id) pairs: one pair for
# every description listed in `index_definition` for each node of the graph.
snomed_sf_id_pairs = [
    (description, concept_id)
    for concept_id in tqdm(snomed.graph.nodes)
    for description in snomed.index_definition[concept_id]
]

print(len(snomed_sf_id_pairs))

100%|██████████| 350830/350830 [00:00<00:00, 909733.56it/s]

910823





In [5]:
# Sanity check: show the first ten (description, SNOMED id) pairs.
snomed_sf_id_pairs[:10]

[('Neoplasm of anterior aspect of epiglottis', '126813005'),
 ('Neoplasm of anterior aspect of epiglottis (disorder)', '126813005'),
 ('Neoplasm of junctional region of epiglottis', '126814004'),
 ('Neoplasm of junctional region of epiglottis (disorder)', '126814004'),
 ('Neoplasm of lateral wall of oropharynx', '126815003'),
 ('Neoplasm of lateral wall of oropharynx (disorder)', '126815003'),
 ('Neoplasm of posterior wall of oropharynx', '126816002'),
 ('Neoplasm of posterior wall of oropharynx (disorder)', '126816002'),
 ('Tumour of posterior wall of oropharynx', '126816002'),
 ('Tumor of posterior wall of oropharynx', '126816002')]

In [30]:
# For simplicity, work on the first 100,000 pairs only.
# NOTE: the slice is positional, so every description outside it is absent
# from the embeddings that the nearest-neighbour search runs over later.
snomed_sf_id_pairs_100k = snomed_sf_id_pairs[:100000]

In [31]:
# Split the pairs into two parallel lists: all_names[i] is the description
# text and all_ids[i] is the SNOMED id it belongs to.
all_names = [name for name, _ in snomed_sf_id_pairs_100k]
all_ids = [concept_id for _, concept_id in snomed_sf_id_pairs_100k]

In [32]:
# Sanity check: first ten description strings.
all_names[:10]

['Neoplasm of anterior aspect of epiglottis',
 'Neoplasm of anterior aspect of epiglottis (disorder)',
 'Neoplasm of junctional region of epiglottis',
 'Neoplasm of junctional region of epiglottis (disorder)',
 'Neoplasm of lateral wall of oropharynx',
 'Neoplasm of lateral wall of oropharynx (disorder)',
 'Neoplasm of posterior wall of oropharynx',
 'Neoplasm of posterior wall of oropharynx (disorder)',
 'Tumour of posterior wall of oropharynx',
 'Tumor of posterior wall of oropharynx']

In [33]:
# Sanity check: the ids at the same positions as the names previewed above
# (several descriptions can share one id).
all_ids[:10]

['126813005',
 '126813005',
 '126814004',
 '126814004',
 '126815003',
 '126815003',
 '126816002',
 '126816002',
 '126816002',
 '126816002']

### load sapbert

In [34]:
from transformers import AutoTokenizer, AutoModel  
# Single source of truth for the checkpoint, so the tokenizer and the encoder
# can never be loaded from two different models by accident.
SAPBERT_CHECKPOINT = "cambridgeltl/SapBERT-from-PubMedBERT-fulltext"

tokenizer = AutoTokenizer.from_pretrained(SAPBERT_CHECKPOINT)
# The model is used for inference only; .eval() makes that explicit (dropout
# off). To run on a GPU, move it afterwards, e.g. `model = model.cuda()`.
model = AutoModel.from_pretrained(SAPBERT_CHECKPOINT).eval()

In [35]:
import numpy as np
import torch

#### encode snomed labels

In [36]:
# Encode every description in `all_names` with the model and keep the vector
# at token position 0 (the [CLS] slot) as that description's embedding.
bs = 128  # batch size
# Feed inputs to whatever device the model currently lives on, so the same
# cell works on CPU and after the model has been moved to a GPU.
device = next(model.parameters()).device
all_reps = []
# Inference only: without no_grad() each forward pass records an autograd
# graph, which costs memory and time for no benefit here.
with torch.no_grad():
    for i in tqdm(range(0, len(all_names), bs)):
        # Calling the tokenizer directly is the supported replacement for the
        # deprecated batch_encode_plus(); arguments are unchanged.
        toks = tokenizer(all_names[i:i+bs],
                         padding="max_length",
                         max_length=25,
                         truncation=True,
                         return_tensors="pt")
        toks = {k: v.to(device) for k, v in toks.items()}

        output = model(**toks)
        # output[0] holds the per-token hidden states; take position 0.
        cls_rep = output[0][:, 0, :]

        all_reps.append(cls_rep.cpu().numpy())
# One row per entry of all_names, in the same order.
all_reps_emb = np.concatenate(all_reps, axis=0)

100%|██████████| 782/782 [10:22<00:00,  1.26it/s]


In [47]:
# Expect one embedding row per encoded description.
print(all_reps_emb.shape)

(100000, 768)


#### encode query

In [42]:
query = "cardiopathy"
# Tokenize the query with the same padding / max_length / truncation settings
# used for the SNOMED descriptions so both go through an identical pipeline.
# Calling the tokenizer directly replaces the deprecated batch_encode_plus().
query_toks = tokenizer([query],
                       padding="max_length",
                       max_length=25,
                       truncation=True,
                       return_tensors="pt")

In [43]:
# Encode the query. Inference only, so skip autograd bookkeeping, and place
# the inputs on the same device as the model.
device = next(model.parameters()).device
with torch.no_grad():
    query_output = model(**{k: v.to(device) for k, v in query_toks.items()})
# Same pooling as for the SNOMED descriptions: hidden state at token position 0.
query_cls_rep = query_output[0][:, 0, :]

In [44]:
# Sanity check: a single query should give exactly one embedding row.
query_cls_rep.shape

torch.Size([1, 768])

#### find query's nearest neighbour

In [45]:
# for large-scale search, should switch to faiss
from scipy.spatial.distance import cdist

In [52]:
# Brute-force nearest neighbour: distance from the query embedding to every
# description embedding (cdist's default metric), then take the closest one.
query_vec = query_cls_rep.cpu().detach().numpy()
dist = cdist(query_vec, all_reps_emb)
# dist has a single row (one query), so the flat argmin is the column index,
# i.e. the position of the best match in snomed_sf_id_pairs_100k.
nn_index = dist.argmin()
print("predicted label:", snomed_sf_id_pairs_100k[nn_index])

predicted label: ('Cardiac complication', '40172005')
