In [1]:
from hroov.utils.retrievers import (
  HiTRetriever,
  OnTRetriever
)
from hroov.utils.gpu_retrievers import (
  GPUHiTRetriever,
  GPUOnTRetriever
)
from hroov.utils.math_functools import (
  batch_poincare_dist_with_adaptive_curv_k,
  batch_poincare_dist_with_adaptive_curv_k_torch,
  entity_subsumption,
  concept_subsumption
)
from pathlib import Path

In [2]:
# Shared locations of the precomputed embedding artefacts. Both JSON paths are
# derived from `embeddings_dir`, so relocating the artefacts needs only one edit.
# `embeddings_dir` stays a plain string because later cells interpolate it into
# f-strings.
embeddings_dir = '../embeddings'
common_verbalisations = Path(embeddings_dir) / 'verbalisations.json'
common_map = Path(embeddings_dir) / 'entity_mappings.json'

In [3]:
hit_model_path = Path('../models/snomed_models/HiT_mixed_hard_negatives')

# HiT retriever (mixed hard negatives) over the precomputed "hard" SNOMED
# embeddings, scoring candidates with `entity_subsumption`.
hard_hit_retriever_w_sub_score = HiTRetriever(
  model_fp=hit_model_path,
  embeddings_fp=Path(embeddings_dir) / "hit-snomed-hard-embeddings.npy",
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  score_fn=entity_subsumption,
)

# The model is read from the retriever's private `_model` attribute (no public
# accessor is used here); the fetch helper passes it back via `retrieve(model=...)`.
hit_model = hard_hit_retriever_w_sub_score._model

In [4]:
ont_model_path = Path('../models/snomed_models/OnT-96')

# OnT (full, "96") retriever over the precomputed SNOMED embeddings, scoring
# candidates with `concept_subsumption`.
ont_retriever_w_sub_score = OnTRetriever(
  model_fp=ont_model_path,
  embeddings_fp=Path(embeddings_dir) / "ont-snomed-96-embeddings.npy",
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  score_fn=concept_subsumption,
)

# The model is read from the retriever's private `_model` attribute (no public
# accessor is used here); the fetch helper passes it back via `retrieve(model=...)`.
ont_model = ont_retriever_w_sub_score._model

In [5]:
ont_mini_model_path = Path('../models/snomed_models/OnTr-m-128')

# Minified OnT ("128") retriever over its own precomputed SNOMED embeddings,
# scoring candidates with `concept_subsumption`.
ont_mini_retriever_w_sub_score = OnTRetriever(
  model_fp=ont_mini_model_path,
  embeddings_fp=Path(embeddings_dir) / "ont-snomed-minified-128-embeddings.npy",
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  score_fn=concept_subsumption,
)

# The model is read from the retriever's private `_model` attribute (no public
# accessor is used here); the fetch helper passes it back via `retrieve(model=...)`.
ont_mini_model = ont_mini_retriever_w_sub_score._model



In [6]:
# An example HiT retriever (full model, hard negatives) function.
# NOTE: the centri_weight that is well-tuned for this model is 0.1 (every tuned
# HiT call in this notebook passes it explicitly). The 0.35 default below is
# NOT that tuned value; it is kept unchanged only so existing calls behave the same.

def fetch_results_for_hit(query_string: str, top_k: int = 5, centri_weight: float = 0.35):
  """Retrieve the top ``top_k`` SNOMED candidates for ``query_string`` with HiT.

  Args:
    query_string: free-text query passed to the retriever.
    top_k: number of candidates to return.
    centri_weight: forwarded to ``retrieve`` as ``weight``. The tuned value for
      this model is 0.1; pass it explicitly, because the 0.35 default is untuned.

  Returns:
    Whatever ``hard_hit_retriever_w_sub_score.retrieve`` returns. In the
    recorded outputs this is a list of ``(position, IRI, score, label)`` tuples.
  """
  return hard_hit_retriever_w_sub_score.retrieve(
    query_string=query_string,
    top_k=top_k,
    reverse_candidate_scores=True,
    model=hit_model,
    weight=centri_weight,
  )

In [7]:
# Example OnT (full) retriever helper. The centripetal weight tuned for this
# model is 0.37, which is also the default below.

def fetch_results_for_ont(query_string: str, top_k: int = 5, centri_weight=0.37):
  """Query the full OnT retriever and return its top ``top_k`` candidates.

  ``centri_weight`` is forwarded to ``retrieve`` as ``weight``; candidate scores
  are requested with ``reverse_candidate_scores=True``. In the recorded outputs
  the result is a list of ``(position, IRI, score, label)`` tuples.
  """
  retriever = ont_retriever_w_sub_score
  return retriever.retrieve(
    query_string=query_string,
    model=ont_model,
    top_k=top_k,
    weight=centri_weight,
    reverse_candidate_scores=True,
  )

In [8]:
# Example OnT (mini) retriever helper. The centripetal weight for this model is
# well-tuned anywhere between 0.05 and 0.15; the default below is 0.1.

def fetch_results_for_ont_mini(query_string: str, top_k: int = 5, centri_weight=0.1):
  """Query the minified OnT retriever and return its top ``top_k`` candidates.

  ``centri_weight`` is forwarded to ``retrieve`` as ``weight``; candidate scores
  are requested with ``reverse_candidate_scores=True``. In the recorded outputs
  the result is a list of ``(position, IRI, score, label)`` tuples.
  """
  retriever = ont_mini_retriever_w_sub_score
  return retriever.retrieve(
    query_string=query_string,
    model=ont_mini_model,
    top_k=top_k,
    weight=centri_weight,
    reverse_candidate_scores=True,
  )

## Example Results (HiT)

In [9]:
# HiT with its tuned centripetal weight (0.1) and the default top_k.
fetch_results_for_hit(query_string="metastatic breast cancer", centri_weight=0.1)

  parent_emb_t = torch.Tensor(parent_emd)


[(0,
  'http://snomed.info/id/145501000119108',
  np.float32(-5.618085),
  'metastatic malignant neoplasm to breast'),
 (1,
  'http://snomed.info/id/1264495000',
  np.float32(-8.248222),
  'metastatic carcinoma to breast'),
 (2,
  'http://snomed.info/id/763479005',
  np.float32(-10.172695),
  'metaplastic carcinoma of breast'),
 (3,
  'http://snomed.info/id/1264115008',
  np.float32(-11.317488),
  'metastatic metaplastic carcinoma to breast'),
 (4,
  'http://snomed.info/id/12246641000119104',
  np.float32(-13.701378),
  'metastatic malignant neoplasm to bilateral breasts')]

## HiT Results for 'primary cartilaginous joint' *(with an untuned centripetal weight of 0.4)*

In [10]:
# Deliberately untuned weight (0.4), for comparison with the tuned 0.1 run.
fetch_results_for_hit(query_string="primary cartilaginous joint", centri_weight=0.4, top_k=5)

[(0,
  'http://snomed.info/id/118954006',
  np.float32(-9.718967),
  'cartilage finding'),
 (1,
  'http://snomed.info/id/373350008',
  np.float32(-9.742492),
  'procedure on cartilage'),
 (2,
  'http://snomed.info/id/771314001',
  np.float32(-9.876467),
  'cartilage structure'),
 (3,
  'http://snomed.info/id/371053008',
  np.float32(-9.9684105),
  'operative procedure on cartilage'),
 (4,
  'http://snomed.info/id/336003007',
  np.float32(-10.148065),
  'cartilaginous joint structure')]

## HiT Results for 'primary cartilaginous joint' *(with the tuned centripetal weight of 0.1, i.e. a minimal depth offset)*

In [11]:
# Same query as the untuned run, now with HiT's tuned weight (0.1).
fetch_results_for_hit(query_string="primary cartilaginous joint", centri_weight=0.1, top_k=5)

[(0,
  'http://snomed.info/id/336003007',
  np.float32(-12.172249),
  'cartilaginous joint structure'),
 (1,
  'http://snomed.info/id/771314001',
  np.float32(-12.464813),
  'cartilage structure'),
 (2,
  'http://snomed.info/id/373350008',
  np.float32(-12.638756),
  'procedure on cartilage'),
 (3,
  'http://snomed.info/id/371053008',
  np.float32(-12.790748),
  'operative procedure on cartilage'),
 (4,
  'http://snomed.info/id/118954006',
  np.float32(-13.072865),
  'cartilage finding')]

## Example Results for OnT (full and mini), with HiT comparisons

In [12]:
# OnT (full) at its tuned weight (0.37), six candidates.
fetch_results_for_ont(query_string="tingling pins", centri_weight=0.37, top_k=6)

[(0,
  'http://snomed.info/id/62507009',
  np.float32(-8.398063),
  'pins and needles'),
 (1, 'http://snomed.info/id/786837007', np.float32(-8.6513), 'tingling pain'),
 (2, 'http://snomed.info/id/44077006', np.float32(-8.885715), 'numbness'),
 (3,
  'http://snomed.info/id/246605000',
  np.float32(-9.122135),
  'peripheral nerve finding'),
 (4,
  'http://snomed.info/id/85972008',
  np.float32(-9.233605),
  'sensory disorder'),
 (5,
  'http://snomed.info/id/404684003',
  np.float32(-9.500606),
  'clinical finding')]

In [13]:
# OnT-mini at weight 0.1, three candidates.
fetch_results_for_ont_mini(query_string="primary cartilaginous joint", centri_weight=0.1, top_k=3)

[(0,
  'http://snomed.info/id/336003007',
  np.float32(-8.474541),
  'cartilaginous joint structure'),
 (1,
  'http://snomed.info/id/118954006',
  np.float32(-9.453409),
  'cartilage finding'),
 (2,
  'http://snomed.info/id/50927007',
  np.float32(-9.745896),
  'cartilage disorder')]

In [14]:
# OnT (full) on the same query, for comparison with OnT-mini.
fetch_results_for_ont(query_string="primary cartilaginous joint", centri_weight=0.37, top_k=3)

[(0,
  'http://snomed.info/id/336003007',
  np.float32(-7.7745466),
  'cartilaginous joint structure'),
 (1,
  'http://snomed.info/id/118954006',
  np.float32(-8.738792),
  'cartilage finding'),
 (2,
  'http://snomed.info/id/50927007',
  np.float32(-9.1194105),
  'cartilage disorder')]

In [15]:
# OnT-mini at weight 0.1, three candidates.
fetch_results_for_ont_mini(query_string="metastatic breast cancer", centri_weight=0.1, top_k=3)

[(0,
  'http://snomed.info/id/145501000119108',
  np.float32(-6.512161),
  'metastatic malignant neoplasm to breast'),
 (1,
  'http://snomed.info/id/1264495000',
  np.float32(-6.709763),
  'metastatic carcinoma to breast'),
 (2,
  'http://snomed.info/id/1264115008',
  np.float32(-6.762319),
  'metastatic metaplastic carcinoma to breast')]

In [16]:
# OnT (full) on the same query, for comparison with OnT-mini.
fetch_results_for_ont(query_string="metastatic breast cancer", centri_weight=0.37, top_k=3)

[(0,
  'http://snomed.info/id/763479005',
  np.float32(-7.8551884),
  'metaplastic carcinoma of breast'),
 (1,
  'http://snomed.info/id/145501000119108',
  np.float32(-8.114372),
  'metastatic malignant neoplasm to breast'),
 (2,
  'http://snomed.info/id/1264495000',
  np.float32(-8.398518),
  'metastatic carcinoma to breast')]

In [17]:
# Repeats the earlier OnT 'primary cartilaginous joint' query (output is identical).
fetch_results_for_ont(query_string="primary cartilaginous joint", centri_weight=0.37, top_k=3)

[(0,
  'http://snomed.info/id/336003007',
  np.float32(-7.7745466),
  'cartilaginous joint structure'),
 (1,
  'http://snomed.info/id/118954006',
  np.float32(-8.738792),
  'cartilage finding'),
 (2,
  'http://snomed.info/id/50927007',
  np.float32(-9.1194105),
  'cartilage disorder')]

In [18]:
# OnT-mini with all defaults (top_k=5, centri_weight=0.1).
fetch_results_for_ont_mini(query_string="bile duct sludge")

[(0,
  'http://snomed.info/id/235568004',
  np.float32(-8.949597),
  'bile duct drainage'),
 (1,
  'http://snomed.info/id/27123005',
  np.float32(-9.202466),
  'biliary sludge'),
 (2, 'http://snomed.info/id/70150004', np.float32(-9.4263), 'bile'),
 (3,
  'http://snomed.info/id/20239009',
  np.float32(-9.515549),
  'bile duct proliferation'),
 (4,
  'http://snomed.info/id/118926004',
  np.float32(-9.743823),
  'disorder of bile duct')]

In [19]:
# HiT with its tuned weight (0.1) and the default top_k.
fetch_results_for_hit(query_string="bile duct sludge", centri_weight=0.1)

[(0,
  'http://snomed.info/id/27123005',
  np.float32(-17.908875),
  'biliary sludge'),
 (1, 'http://snomed.info/id/44901006', np.float32(-18.012564), 'sludge'),
 (2,
  'http://snomed.info/id/28273000',
  np.float32(-19.579191),
  'bile duct structure'),
 (3,
  'http://snomed.info/id/118926004',
  np.float32(-19.695902),
  'disorder of bile duct'),
 (4,
  'http://snomed.info/id/372166008',
  np.float32(-19.824635),
  'bile duct part')]

In [20]:
# OnT (full) with its tuned weight (0.37) and the default top_k.
fetch_results_for_ont(query_string="bile duct sludge", centri_weight=0.37)

[(0,
  'http://snomed.info/id/27123005',
  np.float32(-9.168112),
  'biliary sludge'),
 (1, 'http://snomed.info/id/70150004', np.float32(-9.379118), 'bile'),
 (2,
  'http://snomed.info/id/118926004',
  np.float32(-10.467894),
  'disorder of bile duct'),
 (3, 'http://snomed.info/id/44901006', np.float32(-10.478819), 'sludge'),
 (4,
  'http://snomed.info/id/119341000',
  np.float32(-10.500213),
  'bile specimen')]

In [21]:
# OnT (full) with all defaults (top_k=5, centri_weight=0.37).
fetch_results_for_ont(query_string="intravenous methylprednisolone")

[(0,
  'http://snomed.info/id/116593003',
  np.float32(-9.438526),
  'methylprednisolone'),
 (1,
  'http://snomed.info/id/763158003',
  np.float32(-9.829003),
  'medicinal product'),
 (2,
  'http://snomed.info/id/469667004',
  np.float32(-9.857391),
  'intravenous solution compounder'),
 (3,
  'http://snomed.info/id/410942007',
  np.float32(-9.876816),
  'drug or medicament'),
 (4,
  'http://snomed.info/id/427324005',
  np.float32(-10.049192),
  'intravenous nutrition agent')]

In [22]:
# HiT with its tuned weight (0.1), ten candidates.
fetch_results_for_hit(query_string="intravenous methylprednisolone", centri_weight=0.1, top_k=10)

[(0,
  'http://snomed.info/id/293169001',
  np.float32(-18.370144),
  'methylprednisolone adverse reaction'),
 (1,
  'http://snomed.info/id/121375002',
  np.float32(-18.432688),
  'methylprednisolone measurement'),
 (2,
  'http://snomed.info/id/116593003',
  np.float32(-19.45129),
  'methylprednisolone'),
 (3,
  'http://snomed.info/id/350449009',
  np.float32(-19.689938),
  'product containing methylprednisolone in oral dose form'),
 (4,
  'http://snomed.info/id/294706006',
  np.float32(-19.770927),
  'allergy to methylprednisolone'),
 (5,
  'http://snomed.info/id/350450009',
  np.float32(-19.873163),
  'product containing methylprednisolone in parenteral dose form'),
 (6,
  'http://snomed.info/id/27242001',
  np.float32(-20.301685),
  'product containing methylprednisolone'),
 (7,
  'http://snomed.info/id/429453000',
  np.float32(-22.261164),
  'product containing methyltransferase inhibitor'),
 (8,
  'http://snomed.info/id/427911001',
  np.float32(-22.450594),
  'substance with methy

In [23]:
# Query is plain text, free of invisible (zero-width) characters.
# NOTE: this call uses the default centri_weight (0.35), not HiT's tuned 0.1.
fetch_results_for_hit("intravenous injection of methylprednisolone")

[(0,
  'http://snomed.info/id/18629005',
  np.float32(-17.335476),
  'administration of drug or medicament'),
 (1,
  'http://snomed.info/id/350449009',
  np.float32(-17.568481),
  'product containing methylprednisolone in oral dose form'),
 (2,
  'http://snomed.info/id/410942007',
  np.float32(-17.590736),
  'drug or medicament'),
 (3,
  'http://snomed.info/id/350450009',
  np.float32(-17.667301),
  'product containing methylprednisolone in parenteral dose form'),
 (4,
  'http://snomed.info/id/405679004',
  np.float32(-17.702347),
  'drug administration device')]

In [24]:
# Query is plain text, free of invisible (zero-width) characters.
fetch_results_for_ont("bile duct stones", top_k=10, centri_weight=0.37)

[(0,
  'http://snomed.info/id/118926004',
  np.float32(-8.865848),
  'disorder of bile duct'),
 (1,
  'http://snomed.info/id/235923000',
  np.float32(-9.071082),
  'retained bile duct stone'),
 (2, 'http://snomed.info/id/70150004', np.float32(-9.128838), 'bile'),
 (3,
  'http://snomed.info/id/20239009',
  np.float32(-9.295712),
  'bile duct proliferation'),
 (4,
  'http://snomed.info/id/372166008',
  np.float32(-9.376175),
  'bile duct part'),
 (5,
  'http://snomed.info/id/57259009',
  np.float32(-10.087494),
  'gallbladder bile'),
 (6,
  'http://snomed.info/id/28273000',
  np.float32(-10.127083),
  'bile duct structure'),
 (7,
  'http://snomed.info/id/110818007',
  np.float32(-10.190354),
  'bile duct and stomach'),
 (8,
  'http://snomed.info/id/110617004',
  np.float32(-10.254094),
  'gallbladder and bile ducts'),
 (9,
  'http://snomed.info/id/110817002',
  np.float32(-10.256865),
  'bile duct and liver')]

In [25]:
# Query is plain text, free of invisible (zero-width) characters.
fetch_results_for_ont("bile duct sludge", top_k=10, centri_weight=0.37)

[(0,
  'http://snomed.info/id/27123005',
  np.float32(-9.168112),
  'biliary sludge'),
 (1, 'http://snomed.info/id/70150004', np.float32(-9.379118), 'bile'),
 (2,
  'http://snomed.info/id/118926004',
  np.float32(-10.467894),
  'disorder of bile duct'),
 (3, 'http://snomed.info/id/44901006', np.float32(-10.478819), 'sludge'),
 (4,
  'http://snomed.info/id/119341000',
  np.float32(-10.500213),
  'bile specimen'),
 (5,
  'http://snomed.info/id/110928002',
  np.float32(-10.700028),
  'bile duct cytologic material'),
 (6, 'http://snomed.info/id/39477002', np.float32(-10.727328), 'feces'),
 (7,
  'http://snomed.info/id/372166008',
  np.float32(-10.812213),
  'bile duct part'),
 (8,
  'http://snomed.info/id/235569007',
  np.float32(-10.819471),
  'collection of bile'),
 (9,
  'http://snomed.info/id/110818007',
  np.float32(-10.85163),
  'bile duct and stomach')]

In [26]:
# Query is plain text, free of invisible (zero-width) characters.
fetch_results_for_ont_mini("bile duct stones", top_k=10, centri_weight=0.1)

[(0,
  'http://snomed.info/id/235923000',
  np.float32(-8.842586),
  'retained bile duct stone'),
 (1,
  'http://snomed.info/id/20239009',
  np.float32(-8.888088),
  'bile duct proliferation'),
 (2,
  'http://snomed.info/id/118926004',
  np.float32(-9.103369),
  'disorder of bile duct'),
 (3,
  'http://snomed.info/id/235548005',
  np.float32(-9.515656),
  'chemodissolution of bile duct stone'),
 (4,
  'http://snomed.info/id/235568004',
  np.float32(-10.090011),
  'bile duct drainage'),
 (5,
  'http://snomed.info/id/235932003',
  np.float32(-10.331136),
  'bile duct leakage'),
 (6,
  'http://snomed.info/id/384647004',
  np.float32(-10.433599),
  'bile duct stone removal'),
 (7, 'http://snomed.info/id/70150004', np.float32(-10.715321), 'bile'),
 (8,
  'http://snomed.info/id/43030007',
  np.float32(-10.80925),
  'stenosis of bile duct'),
 (9,
  'http://snomed.info/id/265447007',
  np.float32(-10.834778),
  'bile duct operation')]

In [27]:
# Repeats the earlier OnT 'bile duct stones' query (top_k=10, weight 0.37).
# Query is plain text, free of invisible (zero-width) characters.
fetch_results_for_ont("bile duct stones", top_k=10, centri_weight=0.37)

[(0,
  'http://snomed.info/id/118926004',
  np.float32(-8.865848),
  'disorder of bile duct'),
 (1,
  'http://snomed.info/id/235923000',
  np.float32(-9.071082),
  'retained bile duct stone'),
 (2, 'http://snomed.info/id/70150004', np.float32(-9.128838), 'bile'),
 (3,
  'http://snomed.info/id/20239009',
  np.float32(-9.295712),
  'bile duct proliferation'),
 (4,
  'http://snomed.info/id/372166008',
  np.float32(-9.376175),
  'bile duct part'),
 (5,
  'http://snomed.info/id/57259009',
  np.float32(-10.087494),
  'gallbladder bile'),
 (6,
  'http://snomed.info/id/28273000',
  np.float32(-10.127083),
  'bile duct structure'),
 (7,
  'http://snomed.info/id/110818007',
  np.float32(-10.190354),
  'bile duct and stomach'),
 (8,
  'http://snomed.info/id/110617004',
  np.float32(-10.254094),
  'gallbladder and bile ducts'),
 (9,
  'http://snomed.info/id/110817002',
  np.float32(-10.256865),
  'bile duct and liver')]

In [28]:
# Query is plain text, free of invisible (zero-width) characters.
fetch_results_for_hit("bile duct stones", top_k=10, centri_weight=0.1)

[(0,
  'http://snomed.info/id/28273000',
  np.float32(-12.790095),
  'bile duct structure'),
 (1,
  'http://snomed.info/id/372166008',
  np.float32(-13.01217),
  'bile duct part'),
 (2,
  'http://snomed.info/id/384647004',
  np.float32(-13.058372),
  'bile duct stone removal'),
 (3,
  'http://snomed.info/id/265447007',
  np.float32(-13.084067),
  'bile duct operation'),
 (4,
  'http://snomed.info/id/118926004',
  np.float32(-13.195814),
  'disorder of bile duct'),
 (5, 'http://snomed.info/id/70150004', np.float32(-13.419483), 'bile'),
 (6,
  'http://snomed.info/id/366741003',
  np.float32(-13.938126),
  'repair of bile duct'),
 (7,
  'http://snomed.info/id/118824002',
  np.float32(-13.999963),
  'procedure on biliary tract'),
 (8,
  'http://snomed.info/id/364163003',
  np.float32(-14.066526),
  'biliary tract observable'),
 (9,
  'http://snomed.info/id/447881000',
  np.float32(-14.078251),
  'specimen from biliary system')]

In [29]:
# Repeats the earlier OnT 'tingling pins' query (output is identical).
fetch_results_for_ont(query_string="tingling pins", centri_weight=0.37, top_k=6)

[(0,
  'http://snomed.info/id/62507009',
  np.float32(-8.398063),
  'pins and needles'),
 (1, 'http://snomed.info/id/786837007', np.float32(-8.6513), 'tingling pain'),
 (2, 'http://snomed.info/id/44077006', np.float32(-8.885715), 'numbness'),
 (3,
  'http://snomed.info/id/246605000',
  np.float32(-9.122135),
  'peripheral nerve finding'),
 (4,
  'http://snomed.info/id/85972008',
  np.float32(-9.233605),
  'sensory disorder'),
 (5,
  'http://snomed.info/id/404684003',
  np.float32(-9.500606),
  'clinical finding')]