In [1]:
from chromadb_manager import ChromaDBManager
db_manager = ChromaDBManager("/Users/carlo/Downloads/curate-gpt/db")


In [2]:
omimToHPdict = db_manager.extract_and_use_omim_hpo_mappings("/Users/carlo/PycharmProjects/chroma_db_playground/phenotype.hpoa")


In [3]:
len(omimToHPdict)

12468

In [4]:
ont_hp = db_manager.get_collection("ont_hp")
cachedDict = db_manager.create_hpo_id_to_embedding(ont_hp)


In [5]:
hpoToEmbedding = db_manager.get_collection("HPOtoEmbeddings")


In [6]:
for hp, data in cachedDict.items():
    embedding_list = data['embeddings']  # Extract the embedding list
    hpoToEmbedding.upsert(ids=[hp], embeddings=[embedding_list], metadatas=[{"type": "HP"}])

In [10]:
#Test hpoToEmbedding collection -> HP:0010851, HP:0000478
hpoToEmbedding.get("HP:0010851", include=['embeddings'])


{'ids': ['HP:0010851'],
 'embeddings': [[-0.032349344342947006,
   0.019669702276587486,
   0.01702811010181904,
   -0.02214873395860195,
   -0.01125047355890274,
   0.016093391925096512,
   0.016432058066129684,
   0.0002599259023554623,
   -0.02855628915131092,
   -0.025874055922031403,
   -0.0024807259906083345,
   0.046166904270648956,
   -0.0021420603152364492,
   0.0042908936738967896,
   0.010756021365523338,
   0.011480765417218208,
   0.03635914623737335,
   -0.004236707463860512,
   0.0022521265782415867,
   0.011724605225026608,
   0.0015392354689538479,
   0.043674323707818985,
   -0.024844512343406677,
   -0.013654999434947968,
   -0.0025586190167814493,
   -0.006990059278905392,
   0.0037828953936696053,
   -0.036873918026685715,
   -0.026700399816036224,
   0.025386378169059753,
   -0.00021283020032569766,
   -0.01004482340067625,
   0.0026178855914622545,
   0.000680717988871038,
   -0.012760922312736511,
   -0.021308843046426773,
   0.006478674244135618,
   -0.00307169

In [7]:
diseaseAvgEmbedings = db_manager.get_collection("DiseaseAvgEmbeddings")


In [8]:
import numpy as np
def calculate_average_embedding_from_cachedDict(hps, embeddings_dict):
    embeddings = [embeddings_dict[hp_id]['embeddings'] for hp_id in hps if hp_id in embeddings_dict]
    return np.mean(embeddings, axis=0) if embeddings else []

In [9]:
for disease, hps in omimToHPdict.items():
    average_embedding = calculate_average_embedding_from_cachedDict(hps, cachedDict)
    print(f"{disease}: {average_embedding}")
    diseaseAvgEmbedings.upsert(ids=[disease], embeddings=[average_embedding.tolist()], metadatas=[{"type": "disease"}])

OMIM:619340: [-0.01219648  0.01346831  0.00939877 ... -0.00323286 -0.01156505
 -0.02035956]
OMIM:609153: [-0.01793644 -0.00505215  0.01455887 ... -0.01294534  0.00785382
 -0.01782576]
OMIM:614102: [-0.00249872  0.00447614  0.01253047 ... -0.01868893  0.00047552
 -0.00872546]
OMIM:619426: [-0.00762127  0.01078799  0.01818741 ... -0.00396307 -0.00135435
 -0.02842093]
OMIM:610370: [-0.01016556  0.00656168  0.01020029 ... -0.01063883 -0.00041203
 -0.01758707]
OMIM:609621: [-0.02332948  0.01000722  0.00431693 ...  0.0027278  -0.00280167
 -0.03001673]
OMIM:212790: [-0.00959551  0.007133    0.01445276 ... -0.0207679  -0.00290697
 -0.02809429]
OMIM:612567: [-0.00239858  0.00880057  0.01582952 ... -0.01190322  0.00397301
 -0.02388836]
OMIM:613679: [-0.01436651 -0.00576205  0.02446365 ... -0.02108799  0.00489418
 -0.01430917]
OMIM:614116: [-0.01415545  0.00226648  0.02428199 ... -0.01228464 -0.00380219
 -0.02675991]
OMIM:612201: [-0.00999363  0.01174399  0.00825426 ... -0.01000711  0.00357576
 -

In [11]:
# Test
diseaseAvgEmbedings.get("OMIM:619340", include=['embeddings'])

{'ids': ['OMIM:619340'],
 'embeddings': [[-0.012196474708616734,
   0.013468306511640549,
   0.009398766793310642,
   -0.024764113128185272,
   -0.020572684705257416,
   0.015203890390694141,
   -0.0061103226616978645,
   0.0046564023941755295,
   -0.01347385998815298,
   -0.01658846065402031,
   0.005214905831962824,
   0.03387384116649628,
   0.0068690781481564045,
   0.004080270882695913,
   0.0035723100882023573,
   0.01564258523285389,
   0.03677360713481903,
   0.0012868329649791121,
   0.009110218845307827,
   -0.0016628250014036894,
   0.0012369840405881405,
   0.02088215947151184,
   -0.015826230868697166,
   -0.010065481998026371,
   -0.013123854994773865,
   0.004171126987785101,
   0.004536786582320929,
   -0.02864968590438366,
   -0.012764071114361286,
   -0.011612300761044025,
   0.015040704980492592,
   -0.006631876807659864,
   -0.003404767019674182,
   -0.013999487273395061,
   -0.005562743172049522,
   -0.007567103952169418,
   0.003941751550883055,
   -0.011159550398

In [14]:
x = hpoToEmbedding.get(include=['embeddings'])
len(x['embeddings'])

29912

In [15]:
x = diseaseAvgEmbedings.get(include=['embeddings'])
len(x['embeddings'])

12468