In [10]:
from sentence_transformers import SentenceTransformer
# Sentence-embedding model; it is used further down to encode the reference phrases.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
model = SentenceTransformer(MODEL_NAME)

In [11]:
def _read_column_names(path):
    """Read a text file and return one entry per line, without line terminators.

    Only the trailing newline is removed. Blank lines are kept as empty
    strings so that list positions keep matching the file's line numbers,
    and a final line that lacks a trailing newline is returned intact
    (slicing with ``[:-1]`` would cut off its last character).
    """
    with open(path, "r") as f:
        return [line.rstrip("\n") for line in f]


cc_cols = _read_column_names("../training/cc_cols.txt")
pmh_cols = _read_column_names("../training/pmh_cols.txt")

In [12]:
import numpy as np

In [13]:
# Load the precomputed embedding matrices from disk (one .npy file each).
cc_embeddings, pmh_embeddings = (
    np.load(npy_path)
    for npy_path in (
        "../OpenAIEmbeddings/cc_embeddings.npy",
        "../OpenAIEmbeddings/pmh_embeddings.npy",
    )
)

In [14]:
# Reference phrases every stored embedding is compared against
# (presumably "severe" conditions, judging by the variable name).
sev_words = [
    "heart attack",
    "seizure",
    "stroke",
]

# One encoding per phrase, in the same order as `sev_words`.
sev_encodings = model.encode(sev_words)

In [15]:
def _cosine_similarity_matrix(embeddings, queries):
    """Return pairwise cosine similarities between two sets of row vectors.

    Parameters
    ----------
    embeddings : array-like, 2-D
        One vector per row.
    queries : array-like, 2-D
        One vector per row; must have the same number of columns as
        ``embeddings``.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(len(embeddings), len(queries))`` whose
        ``[j, i]`` entry is the cosine similarity between ``embeddings[j]``
        and ``queries[i]``. A zero-length vector yields ``nan`` for its
        entries, as plain division by a zero norm would.
    """
    embeddings = np.asarray(embeddings, dtype=np.float64)
    queries = np.asarray(queries, dtype=np.float64)
    # Outer product of the row norms is the denominator for every (row, query) pair.
    norm_products = np.outer(
        np.linalg.norm(embeddings, axis=1),
        np.linalg.norm(queries, axis=1),
    )
    return (embeddings @ queries.T) / norm_products


# NOTE(review): the stored vectors are loaded from a directory named
# "OpenAIEmbeddings" while `sev_encodings` comes from the MiniLM model above.
# Cosine similarity is only meaningful if both were produced by the same
# embedding model (and they must share dimensionality) -- TODO confirm.
cc_sev_scores = _cosine_similarity_matrix(cc_embeddings, sev_encodings)
pmh_sev_scores = _cosine_similarity_matrix(pmh_embeddings, sev_encodings)

In [16]:
# Collapse the per-phrase similarities into one mean score per row.
cc_sev_scores = list(cc_sev_scores.mean(axis=1))
pmh_sev_scores = list(pmh_sev_scores.mean(axis=1))

In [17]:
# Pair every mean score with its original row index, then order the pairs
# from highest to lowest score (ties keep their original relative order).
cc_sev_scores = sorted(
    enumerate(cc_sev_scores),
    key=lambda pair: pair[1],
    reverse=True,
)
pmh_sev_scores = sorted(
    enumerate(pmh_sev_scores),
    key=lambda pair: pair[1],
    reverse=True,
)

In [18]:
def get_sev_index(word, scores=None, cols=None):
    """Return the rank of a column within a sorted score list.

    Parameters
    ----------
    word : str
        Column name to look up.
    scores : sequence of (column_index, score), optional
        Ranking to search, already ordered from first to last rank.
        Defaults to the notebook-level ``cc_sev_scores``.
    cols : list of str, optional
        Column names; the position of ``word`` in this list is the column
        index searched for in ``scores``. Defaults to the notebook-level
        ``cc_cols``.

    Returns
    -------
    int
        Zero-based position in ``scores`` of the entry whose column index
        matches ``word``.

    Raises
    ------
    ValueError
        If ``word`` is not in ``cols`` or its index has no entry in ``scores``.
    """
    if scores is None:
        scores = cc_sev_scores
    if cols is None:
        cols = cc_cols

    try:
        col_idx = cols.index(word)
    except ValueError:
        raise ValueError(f"{word!r} is not a known column name") from None

    # Single scan; avoids materialising a list of all indices on every call.
    for rank, entry in enumerate(scores):
        if entry[0] == col_idx:
            return rank
    raise ValueError(f"column {word!r} (index {col_idx}) has no entry in the ranking")

In [20]:
# Show the ranking: (row index, mean similarity) pairs, highest similarity first.
cc_sev_scores

[(162, 0.5047700057427088),
 (139, 0.44079627593358356),
 (96, 0.42321041226387024),
 (160, 0.4203548381725947),
 (129, 0.4133596122264862),
 (183, 0.4034417967001597),
 (172, 0.4009161392847697),
 (145, 0.3892795940240224),
 (15, 0.37788426876068115),
 (53, 0.3673965136210124),
 (177, 0.3631592094898224),
 (149, 0.3625971774260203),
 (161, 0.36120906472206116),
 (173, 0.3487200637658437),
 (49, 0.33431461453437805),
 (46, 0.3336651523907979),
 (143, 0.32858828206857044),
 (29, 0.3247840801874797),
 (93, 0.31928780674934387),
 (144, 0.3143913100163142),
 (152, 0.31439074873924255),
 (198, 0.31378698348999023),
 (43, 0.3122790952523549),
 (107, 0.312131588657697),
 (73, 0.3074888934691747),
 (22, 0.30507340530554455),
 (176, 0.3045850495497386),
 (135, 0.30172261595726013),
 (57, 0.29690849781036377),
 (117, 0.2960127890110016),
 (28, 0.29274306694666546),
 (138, 0.29203397532304126),
 (199, 0.29138035078843433),
 (7, 0.28879011670748395),
 (13, 0.2870972951253255),
 (20, 0.286898126204

In [13]:
# Position of the "cc_trauma" column in the ranking (0 = highest mean similarity).
print(get_sev_index("cc_trauma"))

5
