In [38]:
import numpy as np
from pathlib import Path

from hroov.utils.retrievers import (
    TFIDFRetriever,
    BM25Retriever,
    SBERTRetriever,
    HiTRetriever,
    OnTRetriever
)

from hroov.utils.math_functools import (
    batch_cosine_similarity,
    batch_poincare_dist_with_adaptive_curv_k,
    entity_subsumption,
    concept_subsumption
)

embeddings_dir = '../embeddings'

In [39]:
sbert_plm_embs = np.load(f"{embeddings_dir}/sbert-plm-embeddings.npy", mmap_mode="r")
hit_snomed_25_embs = np.load(f"{embeddings_dir}/hit-snomed-25-embeddings.npy", mmap_mode="r") # HiT FULL
ont_snomed_96_embs = np.load(f"{embeddings_dir}/ont-snomed-96-embeddings.npy", mmap_mode="r") # SNOMED FULL
ont_minified_128_embs = np.load(f"{embeddings_dir}/ont-snomed-minified-128-embeddings.npy", mmap_mode="r")

In [40]:
embeddings_dir = "../embeddings"
common_map = Path(f"{embeddings_dir}/entity_mappings.json")
common_verbalisations = Path(f"{embeddings_dir}/verbalisations.json")

Retrieval: Lexical Methods

In [41]:
tfidf_ret = TFIDFRetriever(common_verbalisations, common_map)
bm25_ret = BM25Retriever(common_verbalisations, common_map, k1=1.3, b=0.7)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Retrieval: SBERT

In [42]:
sbert_plm_hf_string = "all-MiniLM-L12-v2"

sbert_ret_plm_w_cosine_sim = SBERTRetriever(
  embeddings_fp=Path(f"{embeddings_dir}/sbert-plm-embeddings.npy"),
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  model_str="all-MiniLM-L12-v2",
  score_fn=batch_cosine_similarity
)

Retrieval: HiT

In [43]:
# Hierarchy Transformer-based Retriever (HiT Full)

hit_snomed25_model_path = Path('../models/snomed_models/HiT-mixed-SNOMED-25/final')

hit_ret_snomed_25_w_hyp_dist = HiTRetriever(
  embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-25-embeddings.npy"),
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  model_fp=hit_snomed25_model_path,
  score_fn=batch_poincare_dist_with_adaptive_curv_k
)

hit_ret_snomed_25_w_ent_sub = HiTRetriever(
  embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-25-embeddings.npy"),
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  model_fp=hit_snomed25_model_path,
  score_fn=entity_subsumption
)



In [44]:
ont_snomed_96_model_path = Path('../models/snomed_models/OnT-96')

ont_snomed_96_w_hyp_dist = OnTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/ont-snomed-96-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=ont_snomed_96_model_path,
    score_fn=batch_poincare_dist_with_adaptive_curv_k
)

ont_snomed_96_w_con_sub = OnTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/ont-snomed-96-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=ont_snomed_96_model_path,
    score_fn=concept_subsumption
)

In [45]:
ont_snomed_minified_128_model_fp = Path('../models/snomed_models/OnTr-m-128')

ont_ret_snomed_minified_128_w_hyp_dist = OnTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/ont-snomed-minified-128-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=ont_snomed_minified_128_model_fp,
    score_fn=batch_poincare_dist_with_adaptive_curv_k
)

ont_ret_snomed_minified_128_w_con_sub = OnTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/ont-snomed-minified-128-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=ont_snomed_minified_128_model_fp,
    score_fn=concept_subsumption
)



In [46]:
_model = ont_snomed_96_w_con_sub._model
MAX_K = 20
centri_w = 0.37
query_string = "Fracture of Foot"

ont_snomed_96_w_con_sub.retrieve("fracture of foot", top_k=MAX_K, reverse_candidate_scores=True, model=_model, weight=centri_w)

[(0,
  'http://snomed.info/id/15574005',
  np.float32(-9.34402e-06),
  'fracture of foot'),
 (1,
  'http://snomed.info/id/424644003',
  np.float32(-4.260316),
  'pathological fracture of foot'),
 (2,
  'http://snomed.info/id/125604000',
  np.float32(-6.1606207),
  'injury of foot'),
 (3,
  'http://snomed.info/id/208733004',
  np.float32(-7.205575),
  'multiple fractures of foot'),
 (4,
  'http://snomed.info/id/704065008',
  np.float32(-8.02222),
  'stress fracture of foot'),
 (5,
  'http://snomed.info/id/287070000',
  np.float32(-8.655782),
  'pathological fracture - ankle and/or foot'),
 (6, 'http://snomed.info/id/72704001', np.float32(-8.855542), 'fracture'),
 (7, 'http://snomed.info/id/283439004', np.float32(-8.86903), 'cut of foot'),
 (8,
  'http://snomed.info/id/38961000087108',
  np.float32(-9.004515),
  'fracture of right foot'),
 (9,
  'http://snomed.info/id/367527001',
  np.float32(-9.015572),
  'open fracture of foot'),
 (10,
  'http://snomed.info/id/735901006',
  np.float32(

In [47]:
_model = hit_ret_snomed_25_w_ent_sub._model
MAX_K = 10 or len(hit_ret_snomed_25_w_ent_sub._verbalisations)
centri_w = 0.1
query_string = "fracture of foot"

hit_ret_snomed_25_w_ent_sub.retrieve("fracture of foot", top_k=MAX_K, reverse_candidate_scores=True, model=_model, weight=centri_w)

[(0,
  'http://snomed.info/id/15574005',
  np.float32(-2.803206e-05),
  'fracture of foot'),
 (1,
  'http://snomed.info/id/424644003',
  np.float32(-11.7800455),
  'pathological fracture of foot'),
 (2,
  'http://snomed.info/id/38961000087108',
  np.float32(-15.377461),
  'fracture of right foot'),
 (3,
  'http://snomed.info/id/38951000087105',
  np.float32(-16.080738),
  'fracture of left foot'),
 (4,
  'http://snomed.info/id/704065008',
  np.float32(-20.466463),
  'stress fracture of foot'),
 (5,
  'http://snomed.info/id/287070000',
  np.float32(-20.531363),
  'pathological fracture - ankle and/or foot'),
 (6, 'http://snomed.info/id/64572001', np.float32(-20.5541), 'disease'),
 (7,
  'http://snomed.info/id/342070009',
  np.float32(-20.958042),
  'closed fracture of foot'),
 (8,
  'http://snomed.info/id/208733004',
  np.float32(-21.201408),
  'multiple fractures of foot'),
 (9,
  'http://snomed.info/id/125604000',
  np.float32(-21.232725),
  'injury of foot')]

In [48]:
_model = hit_ret_snomed_25_w_ent_sub._model
MAX_K = 10 or len(hit_ret_snomed_25_w_ent_sub._verbalisations)
centri_w = 0.1

hit_ret_snomed_25_w_ent_sub.retrieve("Fracture of hand disorder", top_k=MAX_K, reverse_candidate_scores=True, model=_model, weight=centri_w)

[(0,
  'http://snomed.info/id/20511007',
  np.float32(-14.9000435),
  'fracture of hand'),
 (1,
  'http://snomed.info/id/287067004',
  np.float32(-15.449109),
  'pathological fracture - hand'),
 (2,
  'http://snomed.info/id/118933004',
  np.float32(-19.775383),
  'disorder of hand'),
 (3, 'http://snomed.info/id/64572001', np.float32(-19.857185), 'disease'),
 (4,
  'http://snomed.info/id/782964007',
  np.float32(-20.855976),
  'genetic disease'),
 (5,
  'http://snomed.info/id/208388003',
  np.float32(-20.943039),
  'fracture at wrist and/or hand level'),
 (6,
  'http://snomed.info/id/281466008',
  np.float32(-21.24665),
  'disorder of fracture healing'),
 (7,
  'http://snomed.info/id/287075005',
  np.float32(-21.266232),
  'fracture malunion - hand'),
 (8,
  'http://snomed.info/id/404684003',
  np.float32(-21.339437),
  'clinical finding'),
 (9,
  'http://snomed.info/id/263211006',
  np.float32(-21.694853),
  'fracture of sesamoid bone of hand')]