In [1]:
from pheval_exomiser.prepare.core.chromadb_manager import ChromaDBManager
from pheval_exomiser.prepare.utils.similarity_measures import SimilarityMeasures
manager = ChromaDBManager(similarity=SimilarityMeasures.COSINE)

In [2]:
manager.ont_hp

Collection(name=ont_hp)

In [2]:
manager.list_collections()

[Collection(name=avgDiseaseEmbeddings),
 Collection(name=test),
 Collection(name=newTEST),
 Collection(name=newEmbeddingsFromHpOntology_phase2),
 Collection(name=crazy),
 Collection(name=DiseaseOrganEmbeddings),
 Collection(name=avgDiseaseEmbeddings_phase2),
 Collection(name=hpo),
 Collection(name=DiseaseNewAvgEmbeddingsNew),
 Collection(name=HpEmbeddings),
 Collection(name=DiseaseNewOrganEmbeddings),
 Collection(name=ont_hp),
 Collection(name=average)]

In [4]:
ont = manager.get_collection("ont_hp")
ont.get(limit=1)

{'ids': ['10MinuteAPGARScoreOf0'],
 'embeddings': None,
 'metadatas': [{'_json': '{"id": "10MinuteAPGARScoreOf0", "label": "10-minute APGAR score of 0", "definition": null, "aliases": null, "relationships": [{"predicate": "subClassOf", "target": "Low10MinuteAPGARScore"}], "logical_definition": null, "original_id": "HP:0033468"}',
   'id': '10MinuteAPGARScoreOf0',
   'label': '10-minute APGAR score of 0',
   'original_id': 'HP:0033468'}],
 'documents': ["10-minute APGAR score of 0 None [{'predicate': 'subClassOf', 'target': 'Low10MinuteAPGARScore'}]"],
 'uris': None,
 'data': None}

In [5]:

import json
from typing import Dict
from chromadb.types import Collection


def create_hpo_id_to_embedding(collection: Collection) -> Dict:
    hpo_id_to_data = {}
    results = collection.get(include=["metadatas", "embeddings"])
    for metadata, embedding in zip(results.get("metadatas", []), results.get("embeddings", []), strict=False):
        metadata_json = json.loads(metadata["_json"])
        hpo_id = metadata_json.get("original_id")
        if hpo_id:
            hpo_id_to_data[hpo_id] = {"embeddings": embedding}  # #{'HP:0005872': [1,2,3, ...]}
    return hpo_id_to_data

cachedDict = create_hpo_id_to_embedding(ont)
manager.create_collection("HpEmbeddings")
hpToEmbedding = manager.get_collection("HpEmbeddings")
for hp, data in cachedDict.items():
    embedding_list = data['embeddings']
    hpToEmbedding.upsert(ids=[hp], embeddings=[embedding_list], metadatas=[{"type": "HP"}])

In [6]:
from pheval_exomiser.prepare.core.OMIMHPOExtractor import OMIMHPOExtractor
from pheval_exomiser.prepare.core.data_processor import DataProcessor

data_processor = DataProcessor(db_manager=manager)
file_path = "/Users/carlo/PycharmProjects/chroma_db_playground/phenotype.hpoa"
with open(file_path, 'r') as file:
    data = file.read()

extractor = OMIMHPOExtractor
omimToHPdict = extractor.extract_omim_hpo_mappings(data)

manager.create_collection("avgDiseaseEmbeddings")
diseaseAvgEmbedings = manager.get_collection("avgDiseaseEmbeddings")
for disease, hps in omimToHPdict.items():
    average_embedding = data_processor.calculate_average_embedding(hps, cachedDict)
    diseaseAvgEmbedings.upsert(ids=[disease], embeddings=[average_embedding.tolist()],
                               metadatas=[{"type": "disease"}])

12468


In [74]:
avg = manager.get_collection("avgDiseaseEmbeddings")
x = avg.get(ids="OMIM:619340", include=['embeddings'])
# foo = [x for x['embeddings']['data'] in avg.get(ids="OMIM:619340", include=['embeddings'])]
# print(type(x))
# print(type(foo))
x

{'ids': ['OMIM:619340'],
 'embeddings': [[-0.012174168601632118,
   0.01349125150591135,
   0.009389386512339115,
   -0.024743353947997093,
   -0.02052556350827217,
   0.015198228880763054,
   -0.006085923407226801,
   0.00467239273712039,
   -0.013423134572803974,
   -0.016600506380200386,
   0.0052439444698393345,
   0.033820990473032,
   0.006892082747071981,
   0.0040873391553759575,
   0.003584218444302678,
   0.015599209815263748,
   0.036792851984500885,
   0.0012920359149575233,
   0.009129415266215801,
   -0.0016922313952818513,
   0.0012177551398053765,
   0.02098451368510723,
   -0.015805931761860847,
   -0.010085664689540863,
   -0.013112599961459637,
   0.004157305229455233,
   0.004549398086965084,
   -0.02862929180264473,
   -0.012716548517346382,
   -0.011727680452167988,
   0.01503787375986576,
   -0.00663196062669158,
   -0.0033757921773940325,
   -0.01404667366296053,
   -0.0055846781469881535,
   -0.007513368036597967,
   0.004000385757535696,
   -0.0111718242987990

Shape of average_embedding_cachedDict: (1536,)
Shape of average_embedding_diseaseAvgEmbeddings: (1536,)


In [68]:
avg.get(ids="OMIM:619340", include=['embeddings'])

{'ids': ['OMIM:619340'],
 'embeddings': [[-0.012174168601632118,
   0.01349125150591135,
   0.009389386512339115,
   -0.024743353947997093,
   -0.02052556350827217,
   0.015198228880763054,
   -0.006085923407226801,
   0.00467239273712039,
   -0.013423134572803974,
   -0.016600506380200386,
   0.0052439444698393345,
   0.033820990473032,
   0.006892082747071981,
   0.0040873391553759575,
   0.003584218444302678,
   0.015599209815263748,
   0.036792851984500885,
   0.0012920359149575233,
   0.009129415266215801,
   -0.0016922313952818513,
   0.0012177551398053765,
   0.02098451368510723,
   -0.015805931761860847,
   -0.010085664689540863,
   -0.013112599961459637,
   0.004157305229455233,
   0.004549398086965084,
   -0.02862929180264473,
   -0.012716548517346382,
   -0.011727680452167988,
   0.01503787375986576,
   -0.00663196062669158,
   -0.0033757921773940325,
   -0.01404667366296053,
   -0.0055846781469881535,
   -0.007513368036597967,
   0.004000385757535696,
   -0.0111718242987990

In [67]:
diseaseAvgEmbedings.get(ids="OMIM:619340", include=['embeddings'])

{'ids': ['OMIM:619340'],
 'embeddings': [[-0.012174168601632118,
   0.01349125150591135,
   0.009389386512339115,
   -0.024743353947997093,
   -0.02052556350827217,
   0.015198228880763054,
   -0.006085923407226801,
   0.00467239273712039,
   -0.013423134572803974,
   -0.016600506380200386,
   0.0052439444698393345,
   0.033820990473032,
   0.006892082747071981,
   0.0040873391553759575,
   0.003584218444302678,
   0.015599209815263748,
   0.036792851984500885,
   0.0012920359149575233,
   0.009129415266215801,
   -0.0016922313952818513,
   0.0012177551398053765,
   0.02098451368510723,
   -0.015805931761860847,
   -0.010085664689540863,
   -0.013112599961459637,
   0.004157305229455233,
   0.004549398086965084,
   -0.02862929180264473,
   -0.012716548517346382,
   -0.011727680452167988,
   0.01503787375986576,
   -0.00663196062669158,
   -0.0033757921773940325,
   -0.01404667366296053,
   -0.0055846781469881535,
   -0.007513368036597967,
   0.004000385757535696,
   -0.0111718242987990

In [50]:
import numpy as np
omim_619340_hpo_terms = list(omimToHPdict["OMIM:619340"])

# check if hp temrs from omim619340 retrieved from phenotype.hpoa and compared to inside cachedDict have same average then disAvgEmbeddings collection.get("OMIM:619340)

embeddings_cachedDict = [cachedDict[hp]['embeddings'] for hp in omim_619340_hpo_terms if hp in omim_619340_hpo_terms]
array = np.mean(embeddings_cachedDict, axis=0) if embeddings_cachedDict else []
# embeddings_omim_dict = [omimToHPdict.]
array2 = np.mean from the average of HPs inside the diseaseAvgEmbedings collection
assert array == array2


[[-0.01141765434294939,
  0.02041873149573803,
  0.023188292980194092,
  -0.019943561404943466,
  -0.02606646530330181,
  0.013895326294004917,
  -0.011824943125247955,
  0.013535554520785809,
  -0.008444448933005333,
  -0.030112197622656822,
  0.021654171869158745,
  0.04406861588358879,
  0.011356561444699764,
  -0.015585573390126228,
  0.007996431551873684,
  0.023867106065154076,
  0.020337272435426712,
  -0.0038794230204075575,
  0.012429087422788143,
  0.0002827685384545475,
  0.010582713410258293,
  0.03299036994576454,
  -0.0064996457658708096,
  -0.010317975655198097,
  -0.025985006242990494,
  0.005749556235969067,
  0.011349773034453392,
  -0.01391569059342146,
  -0.0030631490517407656,
  -0.0250753965228796,
  0.03062809631228447,
  -0.002367364475503564,
  0.0005235354183241725,
  -0.013264029286801815,
  -0.010928908362984657,
  0.007914973422884941,
  0.005997323431074619,
  0.0004130160086788237,
  0.021884968504309654,
  0.0036655967123806477,
  0.0011183463502675295,


In [70]:
import numpy as np

# Retrieve HPO terms for OMIM:619340
omim_619340_hpo_terms = list(omimToHPdict["OMIM:619340"])

# Compute average embedding from cachedDict
embeddings_cachedDict = [cachedDict[hp]['embeddings'] for hp in omim_619340_hpo_terms if hp in cachedDict]
average_embedding_cachedDict = np.mean(embeddings_cachedDict, axis=0) if embeddings_cachedDict else np.array([])

# Retrieve precomputed average embedding from diseaseAvgEmbeddings collection
# Assuming `diseaseAvgEmbeddings` is a dict-like collection and has a method `get` to retrieve data
average_embedding_diseaseAvgEmbeddings = diseaseAvgEmbedings.get("OMIM:619340", np.array([]))

# Comparing the two average embeddings
# Using np.allclose for comparison due to potential floating-point precision differences
assert np.allclose(average_embedding_cachedDict, average_embedding_diseaseAvgEmbeddings), "Average embeddings do not match."


TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''

In [76]:
import numpy as np

# Assuming avg is your collection manager
avg = manager.get_collection("avgDiseaseEmbeddings")
data = avg.get(ids="OMIM:619340", include=['embeddings'])

# Extract the embedding for OMIM:619340 and convert to numpy array
average_embedding_diseaseAvgEmbeddings = np.array(data['embeddings'][0])

# Ensure the embedding from cachedDict is also a numpy array
average_embedding_cachedDict = np.array(average_embedding_cachedDict)

# Print shapes of both arrays
print("Shape of average_embedding_cachedDict:", average_embedding_cachedDict.shape)
print("Shape of average_embedding_diseaseAvgEmbeddings:", average_embedding_diseaseAvgEmbeddings.shape)

# Check if both arrays have the same shape and are not empty
if average_embedding_cachedDict.shape == average_embedding_diseaseAvgEmbeddings.shape and average_embedding_cachedDict.size > 0:
    # Compare the arrays
    assert np.allclose(average_embedding_cachedDict, average_embedding_diseaseAvgEmbeddings), "Average embeddings do not match."
else:
    # Handle the case where shapes are different or arrays are empty
    raise ValueError("The shapes of the average embeddings do not match or one/both are empty.")

Shape of average_embedding_cachedDict: (1536,)
Shape of average_embedding_diseaseAvgEmbeddings: (1536,)


In [77]:
# Check if both arrays are equal within a tolerance
are_arrays_equal = np.allclose(average_embedding_cachedDict, average_embedding_diseaseAvgEmbeddings)

# Print the result
print("Are the arrays equal?:", are_arrays_equal)


Are the arrays equal?: True


In [72]:
# Ensure both embeddings are numpy arrays
average_embedding_cachedDict = np.array(average_embedding_cachedDict)
average_embedding_diseaseAvgEmbeddings = np.array(average_embedding_diseaseAvgEmbeddings)

# Print shapes of both arrays
print("Shape of average_embedding_cachedDict:", average_embedding_cachedDict.shape)
print("Shape of average_embedding_diseaseAvgEmbeddings:", average_embedding_diseaseAvgEmbeddings.shape)

# Proceed with the rest of your code...


Shape of average_embedding_cachedDict: (1536,)
Shape of average_embedding_diseaseAvgEmbeddings: ()


In [71]:
# Ensure both embeddings are numpy arrays
average_embedding_cachedDict = np.array(average_embedding_cachedDict)
average_embedding_diseaseAvgEmbeddings = np.array(average_embedding_diseaseAvgEmbeddings)

# Check if both arrays have the same shape before comparing
if average_embedding_cachedDict.shape == average_embedding_diseaseAvgEmbeddings.shape:
    # Using np.allclose for comparison
    assert np.allclose(average_embedding_cachedDict, average_embedding_diseaseAvgEmbeddings), "Average embeddings do not match."
else:
    # Handle the case where the arrays have different shapes
    raise ValueError("The shapes of the average embeddings do not match.")


ValueError: The shapes of the average embeddings do not match.

In [39]:
cachedDict.get("HP:0001518")

{'embeddings': [-0.00836104154586792,
  -0.0014480522368103266,
  0.016945745795965195,
  0.00492716021835804,
  -0.010018778033554554,
  0.005131087731570005,
  -0.016932589933276176,
  0.009091234765946865,
  -0.021274279803037643,
  -0.042680125683546066,
  0.026997415348887444,
  0.03878576308488846,
  0.005134377162903547,
  -0.005604726728051901,
  0.0027908512856811285,
  0.024195052683353424,
  0.02574753575026989,
  0.007255884353071451,
  0.005006100051105022,
  -0.009393838234245777,
  0.007420342415571213,
  0.03157592564821243,
  -0.04170653596520424,
  4.846370211453177e-05,
  -0.024076642468571663,
  0.005101485643535852,
  0.008466294966638088,
  -0.03133910521864891,
  0.0017909470479935408,
  -0.014590708538889885,
  0.01049241703003645,
  -0.012479068711400032,
  -0.022524159401655197,
  -0.0026346163358539343,
  -0.01741938479244709,
  0.007242728024721146,
  -0.0005163979367353022,
  -0.011005525477230549,
  0.0153537942096591,
  -0.027418429031968117,
  0.01076212

In [14]:
hpListOfOMIM619340 = ['HP:0001518',
 'HP:0001522',
 'HP:0010851',
 'HP:0002643',
 'HP:0032792',
 'HP:0002187',
 'HP:0000006',
 'HP:0200134',
 'HP:0011451',
 'HP:0001789',
 'HP:0011097']

OMIM619340 = ['HP:0010851', 'HP:0011097', 'HP:0002643', 'HP:0032792', 'HP:0011451', 'HP:0200134', 'HP:0001518', 'HP:0002187', 'HP:0000006', 'HP:0001522', 'HP:0001789']


diseaseAvgEmbedings = manager.get_collection("avgDiseaseEmbeddings")
# ont_hp = manager.get_collection("ont_hp")

# cachedDict = manager.create_hpo_id_to_embedding(ont_hp)
avg_embedding = data_processor.calculate_average_embedding(OMIM619340, cachedDict)

if avg_embedding is None:
    print("No valid embeddings found for provided HPO terms.")

query_results = diseaseAvgEmbedings.query(
    query_embeddings=[avg_embedding.tolist()],
    n_results=100,
    include=["embeddings", "distances"]
)

disease_ids = query_results['ids'][0] if 'ids' in query_results and query_results['ids'] else []
distances = query_results['distances'][0] if 'distances' in query_results and query_results['distances'] else []
sorted_results = sorted(zip(disease_ids, distances), key=lambda x: x[1])
# distances
sorted_results

[1.1920928955078125e-06,
 0.012330412864685059,
 0.013801753520965576,
 0.014431536197662354,
 0.014481723308563232,
 0.01506197452545166,
 0.01537412405014038,
 0.01538097858428955,
 0.015410244464874268,
 0.015469074249267578,
 0.015494465827941895,
 0.01588517427444458,
 0.015889763832092285,
 0.01601111888885498,
 0.016223132610321045,
 0.01626121997833252,
 0.01631838083267212,
 0.01635289192199707,
 0.01637125015258789,
 0.016381382942199707,
 0.01642155647277832,
 0.016445159912109375,
 0.016647756099700928,
 0.01665431261062622,
 0.0166814923286438,
 0.016730785369873047,
 0.016777634620666504,
 0.016793012619018555,
 0.016864657402038574,
 0.016886770725250244,
 0.016893982887268066,
 0.01705414056777954,
 0.017084479331970215,
 0.017090201377868652,
 0.017149269580841064,
 0.017186462879180908,
 0.01718956232070923,
 0.01723027229309082,
 0.017379939556121826,
 0.01747143268585205,
 0.017500877380371094,
 0.017506837844848633,
 0.01750844717025757,
 0.017543792724609375,
 0.0

In [78]:
def query_disease_avg_embeddings(omim_id, omim_to_hp_dict, cached_dict, data_processor, disease_avg_embeddings_manager):
    """
    Queries the diseaseAvgEmbeddings collection with the average embedding of HPO terms associated with an OMIM ID.

    :param omim_id: The OMIM ID to query.
    :param omim_to_hp_dict: Dictionary mapping OMIM IDs to lists of HPO terms.
    :param cached_dict: Dictionary mapping HPO terms to their embeddings.
    :param data_processor: DataProcessor instance for calculating average embeddings.
    :param disease_avg_embeddings_manager: Manager for the diseaseAvgEmbeddings collection.
    :return: Sorted list of (disease_id, distance) tuples.
    """
    # Retrieve the list of HPO terms for the given OMIM ID
    hpo_terms = omim_to_hp_dict.get(omim_id, [])

    # Calculate the average embedding for these HPO terms
    avg_embedding = data_processor.calculate_average_embedding(hpo_terms, cached_dict)

    # Handle cases where avg_embedding is None or empty
    if avg_embedding is None or not avg_embedding.size:
        print("No valid embeddings found for provided HPO terms.")
        return []

    # Query the diseaseAvgEmbeddings collection
    query_results = diseaseAvgEmbedings.query(
        query_embeddings=[avg_embedding.tolist()],
        n_results=100,
        include=["embeddings", "distances"]
    )

    # Extract and sort the results
    disease_ids = query_results.get('ids', [])
    distances = query_results.get('distances', [])
    sorted_results = sorted(zip(disease_ids, distances), key=lambda x: x[1])

    return sorted_results


In [81]:
omim_id = "OMIM:619340"
sorted_query_results = query_disease_avg_embeddings(omim_id, omimToHPdict, cachedDict, data_processor, diseaseAvgEmbedings)
print(sorted_query_results)


[(['OMIM:619340', 'OMIM:251280', 'OMIM:266100', 'OMIM:620033', 'OMIM:617065', 'OMIM:617105', 'OMIM:612164', 'OMIM:609304', 'OMIM:619881', 'OMIM:300607', 'OMIM:301058', 'OMIM:300672', 'ORPHA:289266', 'OMIM:614558', 'OMIM:617711', 'OMIM:619847', 'OMIM:617929', 'OMIM:617166', 'OMIM:620024', 'OMIM:600721', 'OMIM:308350', 'OMIM:612949', 'OMIM:618298', 'ORPHA:3006', 'OMIM:613720', 'OMIM:220120', 'OMIM:617290', 'OMIM:608097', 'OMIM:619605', 'OMIM:619913', 'ORPHA:95232', 'OMIM:617162', 'OMIM:614231', 'OMIM:619229', 'OMIM:617389', 'OMIM:620145', 'OMIM:619239', 'OMIM:613661', 'OMIM:615282', 'OMIM:300673', 'OMIM:617132', 'OMIM:617933', 'OMIM:616973', 'OMIM:245570', 'OMIM:617601', 'OMIM:619606', 'ORPHA:1934', 'OMIM:617391', 'OMIM:618890', 'OMIM:618663', 'OMIM:615476', 'OMIM:300884', 'OMIM:620359', 'OMIM:609056', 'OMIM:615851', 'OMIM:618012', 'OMIM:618174', 'OMIM:617771', 'OMIM:272300', 'OMIM:620167', 'OMIM:618235', 'OMIM:615338', 'OMIM:604317', 'ORPHA:209370', 'OMIM:617976', 'OMIM:617599', 'OMIM:6

In [15]:
cachedDict.get("HP:0010851")

{'embeddings': [-0.032377034425735474,
  0.019941002130508423,
  0.016879407688975334,
  -0.022094953805208206,
  -0.011067797429859638,
  0.01621561124920845,
  0.016500094905495644,
  0.0002233117847936228,
  -0.028583908453583717,
  -0.025671329349279404,
  -0.002524798968806863,
  0.04611356556415558,
  -0.0020709787495434284,
  0.004355320706963539,
  0.010695258155465126,
  0.011304867453873158,
  0.03617015853524208,
  -0.004277425818145275,
  0.002413037233054638,
  0.011772234924137592,
  0.0013521475484594703,
  0.043918970972299576,
  -0.024844970554113388,
  -0.013865227811038494,
  -0.0023486895952373743,
  -0.006969867739826441,
  0.0037220041267573833,
  -0.03684750199317932,
  -0.026633158326148987,
  0.025441031903028488,
  -0.00021198744070716202,
  -0.010058555752038956,
  0.0025789865758270025,
  0.0006739570526406169,
  -0.01261214166879654,
  -0.021119579672813416,
  0.006377191748470068,
  -0.00298708607442677,
  0.03977362811565399,
  0.019466860219836235,
  0.0

In [13]:
diseaseAvgEmbedings.get("OMIM:619340", include=['embeddings'])
# {'ids': ['OMIM:619340'],
#  'embeddings': [[-0.012196474708616734,
#    0.013468306511640549,
#    0.009398766793310642,
#    -0.024764113128185272,
#    -0.020572684705257416,
#    0.015203890390694141,

{'ids': ['OMIM:619340'],
 'embeddings': [[-0.012174168601632118,
   0.01349125150591135,
   0.009389386512339115,
   -0.024743353947997093,
   -0.02052556350827217,
   0.015198228880763054,
   -0.006085923407226801,
   0.00467239273712039,
   -0.013423134572803974,
   -0.016600506380200386,
   0.0052439444698393345,
   0.033820990473032,
   0.006892082747071981,
   0.0040873391553759575,
   0.003584218444302678,
   0.015599209815263748,
   0.036792851984500885,
   0.0012920359149575233,
   0.009129415266215801,
   -0.0016922313952818513,
   0.0012177551398053765,
   0.02098451368510723,
   -0.015805931761860847,
   -0.010085664689540863,
   -0.013112599961459637,
   0.004157305229455233,
   0.004549398086965084,
   -0.02862929180264473,
   -0.012716548517346382,
   -0.011727680452167988,
   0.01503787375986576,
   -0.00663196062669158,
   -0.0033757921773940325,
   -0.01404667366296053,
   -0.0055846781469881535,
   -0.007513368036597967,
   0.004000385757535696,
   -0.0111718242987990

In [83]:
x = ont.get(include=['metadatas', 'embeddings'])
len(x['embeddings'])

29912