In [4]:
import numpy as np
import chromadb
from chromadb_manager import ChromaDBManager

In [5]:
db_manager = ChromaDBManager("/Users/carlo/Downloads/curate-gpt/db")


In [6]:
db_manager.list_collections()

[Collection(name=hpo_embedding_collection_test),
 Collection(name=ont_oba),
 Collection(name=disease_average_embedding_test),
 Collection(name=ont_mp),
 Collection(name=hpoa),
 Collection(name=ont_hp),
 Collection(name=ont_fbbt),
 Collection(name=ont_uberon),
 Collection(name=ont_obi),
 Collection(name=HPOtoEmbeddings),
 Collection(name=ont_agro),
 Collection(name=default),
 Collection(name=ont_cl),
 Collection(name=ont_chebi),
 Collection(name=ont_envo),
 Collection(name=ont_mondo),
 Collection(name=ont_po),
 Collection(name=DiseaseAvgEmbeddings),
 Collection(name=ont_go),
 Collection(name=ont_nbo)]

In [32]:


from typing import List


def query_diseases_by_hpo_terms(hpo_ids: List[str]) -> List:
    """
    Queries the 'DiseaseAvgEmbeddings' collection for diseases closest to the average embeddings of given HPO terms.

    :param hpo_ids: List of HPO term IDs.
    :return: List of diseases sorted by closeness to the average HPO embeddings.
    """
    diseaseAvgEmbedings = db_manager.get_collection("DiseaseAvgEmbeddings")
    ont_hp = db_manager.get_collection("ont_hp")

    cachedDict = db_manager.create_hpo_id_to_embedding(ont_hp)
    avg_embedding = db_manager.calculate_average_embedding_from_cachedDict(hpo_ids, cachedDict)
    
    if avg_embedding is None:
        return "No valid embeddings found for provided HPO terms."

    query_results = diseaseAvgEmbedings.query(
        query_embeddings=[avg_embedding.tolist()],
        n_results=10,
        include=["embeddings", "distances"]
    )

    disease_ids = query_results['ids'][0] if 'ids' in query_results and query_results['ids'] else []
    distances = query_results['distances'][0] if 'distances' in query_results and query_results['distances'] else []
    sorted_results = sorted(zip(disease_ids, distances), key=lambda x: x[1])

    return sorted_results    


In [31]:
query_diseases_by_hpo_terms(["HP:0010851", "HP:0000478"])

[('ORPHA:99802', 0.101658396422863),
 ('ORPHA:3006', 0.10190735012292862),
 ('ORPHA:1942', 0.10313200205564499),
 ('ORPHA:168491', 0.10354506224393845),
 ('OMIM:618792', 0.10524667799472809),
 ('ORPHA:101030', 0.10562039911746979),
 ('ORPHA:411986', 0.10577473044395447),
 ('ORPHA:231178', 0.10754810591383239),
 ('ORPHA:1934', 0.10766223073005676),
 ('OMIM:619428', 0.10783161222934723)]

In [38]:
hpListOfOMIM619340 = ['HP:0001518',
 'HP:0001522',
 'HP:0010851',
 'HP:0002643',
 'HP:0032792',
 'HP:0002187',
 'HP:0000006',
 'HP:0200134',
 'HP:0011451',
 'HP:0001789',
 'HP:0011097']

In [39]:
query_diseases_by_hpo_terms(hpListOfOMIM619340)

[('OMIM:619340', 0.0),
 ('OMIM:251280', 0.01930435560643673),
 ('OMIM:266100', 0.021473241969943047),
 ('OMIM:617065', 0.02249375358223915),
 ('OMIM:620033', 0.02262789011001587),
 ('OMIM:617105', 0.02341277524828911),
 ('OMIM:619881', 0.02391430363059044),
 ('OMIM:301058', 0.023992136120796204),
 ('OMIM:612164', 0.024027930572628975),
 ('OMIM:609304', 0.024042846634984016)]

In [21]:
diseaseAvgEmbedings = db_manager.get_collection("DiseaseAvgEmbeddings")


In [34]:
diseaseAvgEmbedings.get("OMIM:619340", include=['embeddings'])

{'ids': ['OMIM:619340'],
 'embeddings': [[-0.012196474708616734,
   0.013468306511640549,
   0.009398766793310642,
   -0.024764113128185272,
   -0.020572684705257416,
   0.015203890390694141,
   -0.0061103226616978645,
   0.0046564023941755295,
   -0.01347385998815298,
   -0.01658846065402031,
   0.005214905831962824,
   0.03387384116649628,
   0.0068690781481564045,
   0.004080270882695913,
   0.0035723100882023573,
   0.01564258523285389,
   0.03677360713481903,
   0.0012868329649791121,
   0.009110218845307827,
   -0.0016628250014036894,
   0.0012369840405881405,
   0.02088215947151184,
   -0.015826230868697166,
   -0.010065481998026371,
   -0.013123854994773865,
   0.004171126987785101,
   0.004536786582320929,
   -0.02864968590438366,
   -0.012764071114361286,
   -0.011612300761044025,
   0.015040704980492592,
   -0.006631876807659864,
   -0.003404767019674182,
   -0.013999487273395061,
   -0.005562743172049522,
   -0.007567103952169418,
   0.003941751550883055,
   -0.011159550398

In [23]:
ont_hp = db_manager.get_collection("ont_hp")
cachedDict = db_manager.create_hpo_id_to_embedding(ont_hp)


In [35]:
cachedDict['OMIM:619340']

KeyError: 'OMIM:619340'

In [25]:
cachedDict['HP:0010851']

{'embeddings': [-0.032349344342947006,
  0.019669702276587486,
  0.01702811010181904,
  -0.02214873395860195,
  -0.01125047355890274,
  0.016093391925096512,
  0.016432058066129684,
  0.0002599259023554623,
  -0.02855628915131092,
  -0.025874055922031403,
  -0.0024807259906083345,
  0.046166904270648956,
  -0.0021420603152364492,
  0.0042908936738967896,
  0.010756021365523338,
  0.011480765417218208,
  0.03635914623737335,
  -0.004236707463860512,
  0.0022521265782415867,
  0.011724605225026608,
  0.0015392354689538479,
  0.043674323707818985,
  -0.024844512343406677,
  -0.013654999434947968,
  -0.0025586190167814493,
  -0.006990059278905392,
  0.0037828953936696053,
  -0.036873918026685715,
  -0.026700399816036224,
  0.025386378169059753,
  -0.00021283020032569766,
  -0.01004482340067625,
  0.0026178855914622545,
  0.000680717988871038,
  -0.012760922312736511,
  -0.021308843046426773,
  0.006478674244135618,
  -0.0030716974288225174,
  0.03979998826980591,
  0.019466502591967583,
  

In [29]:

hpo_ids = ["HP:0010851", "HP:0000478"]
diseaseAvgEmbedings = db_manager.get_collection("DiseaseAvgEmbeddings")
ont_hp = db_manager.get_collection("ont_hp")

# Calculate the average embedding for the HPO terms
cachedDict = db_manager.create_hpo_id_to_embedding(ont_hp)
avg_embedding = db_manager.calculate_average_embedding_from_cachedDict(hpo_ids, cachedDict)

if avg_embedding is None:
    print("NONE")

# Query the DiseaseAvgEmbeddings collection
query_results = diseaseAvgEmbedings.query(
    query_embeddings=[avg_embedding.tolist()],
    n_results=10,
    include=["embeddings", "distances"]
    )




print(query_results)


{'ids': [['ORPHA:99802', 'ORPHA:3006', 'ORPHA:1942', 'ORPHA:168491', 'OMIM:618792', 'ORPHA:101030', 'ORPHA:411986', 'ORPHA:231178', 'ORPHA:1934', 'OMIM:619428']], 'distances': [[0.101658396422863, 0.10190735012292862, 0.10313200205564499, 0.10354506224393845, 0.10524667799472809, 0.10562039911746979, 0.10577473044395447, 0.10754810591383239, 0.10766223073005676, 0.10783161222934723]], 'metadatas': None, 'embeddings': [[[-0.0124784205108881, 0.01214809063822031, 0.021168019622564316, -0.017344815656542778, -0.025811072438955307, 0.014329575002193451, 0.0021200154442340136, -0.0014365189708769321, -0.008101856335997581, -0.018051862716674805, 0.006674862466752529, 0.028887003660202026, 0.005537147633731365, 0.007171448320150375, -0.0021199581678956747, 0.007899182848632336, 0.03775918483734131, 0.003976758569478989, 0.00475371303036809, -0.004277605097740889, 0.0052110617980360985, 0.02193157747387886, -0.012384295463562012, -0.008787031285464764, -0.003965799696743488, 0.007441060151904