In [1]:
from sentence_transformers import SentenceTransformer
# Instantiate the pretrained all-MiniLM-L6-v2 sentence encoder; it is used
# further down to embed the severity reference phrases.
model = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2"
)

In [3]:
def _read_column_names(path):
    """Return the lines of the text file at ``path`` without their newline.

    Only the line terminator is stripped, so a final line that has no
    trailing newline keeps its last character (slicing every line with
    ``[:-1]`` would silently truncate it).

    Parameters
    ----------
    path : str
        Location of a text file holding one name per line.

    Returns
    -------
    list of str
        The names in file order.
    """
    with open(path, "r") as f:
        return [line.rstrip("\n") for line in f]


# List order matters: get_sev_index maps a name to its position in cc_cols.
cc_cols = _read_column_names("../training/cc_cols.txt")
pmh_cols = _read_column_names("../training/pmh_cols.txt")

In [4]:
import numpy as np

In [6]:
# Read the precomputed embedding arrays from disk.
# NOTE(review): presumably one row per entry of cc_cols / pmh_cols — confirm.
cc_embeddings = np.load(file="../OpenAIEmbeddings/cc_embeddings.npy")
pmh_embeddings = np.load(file="../OpenAIEmbeddings/pmh_embeddings.npy")

In [7]:
# Reference phrases; their embeddings are the comparison targets for the
# similarity scores computed in the following cell.
sev_words = [
    'heart attack',
    'seizure',
    'stroke',
]

# Embed every reference phrase with the sentence-transformer model.
sev_encodings = model.encode(sev_words)

In [8]:
def _cosine_similarity_matrix(rows, refs):
    """Return all pairwise cosine similarities between two sets of vectors.

    Parameters
    ----------
    rows : array-like, shape (n, d)
        One vector per row.
    refs : array-like, shape (m, d)
        Reference vectors to compare every row against.

    Returns
    -------
    numpy.ndarray, shape (n, m), dtype float64
        Entry ``[j, i]`` equals
        ``dot(rows[j], refs[i]) / (norm(rows[j]) * norm(refs[i]))``.
    """
    rows = np.asarray(rows, dtype=np.float64)
    refs = np.asarray(refs, dtype=np.float64)
    # Each norm is computed once per vector rather than once per (row, ref) pair.
    row_norms = np.linalg.norm(rows, axis=1, keepdims=True)  # shape (n, 1)
    ref_norms = np.linalg.norm(refs, axis=1)                 # shape (m,)
    # (n, 1) * (m,) broadcasts to the (n, m) grid of norm products.
    return (rows @ refs.T) / (row_norms * ref_norms)


# Column i of each score matrix holds the similarity to sev_words[i].
# NOTE(review): the embeddings are loaded from a directory named
# "OpenAIEmbeddings" while sev_encodings comes from the MiniLM model; cosine
# similarity is only meaningful if both were produced by the same encoder —
# confirm.
cc_sev_scores = _cosine_similarity_matrix(cc_embeddings, sev_encodings)
pmh_sev_scores = _cosine_similarity_matrix(pmh_embeddings, sev_encodings)

In [9]:
# Collapse each score matrix to one value per row: the average similarity
# across all severity phrases.
cc_sev_scores = [*cc_sev_scores.mean(axis=1)]
pmh_sev_scores = [*pmh_sev_scores.mean(axis=1)]

In [10]:
# Tag every mean score with its row position so the original index survives
# the reordering below.
cc_sev_scores = list(enumerate(cc_sev_scores))
pmh_sev_scores = list(enumerate(pmh_sev_scores))

# Order the (index, score) pairs from highest to lowest score.
cc_sev_scores.sort(key=lambda pair: pair[1], reverse=True)
pmh_sev_scores.sort(key=lambda pair: pair[1], reverse=True)

In [11]:
def get_sev_index(word, cols=None, sev_scores=None):
    """Return the rank of column ``word`` in a severity-ordered score list.

    Parameters
    ----------
    word : str
        Column name to look up.
    cols : list of str, optional
        Column names; the position of ``word`` in this list is the index that
        is searched for in ``sev_scores``. Defaults to the module-level
        ``cc_cols``.
    sev_scores : sequence of (index, score) pairs, optional
        Ranking to search, already sorted in the desired order. Defaults to
        the module-level ``cc_sev_scores``. Passing ``pmh_cols`` and
        ``pmh_sev_scores`` gives the same lookup for the PMH columns.

    Returns
    -------
    int
        Zero-based position of the column within ``sev_scores``.

    Raises
    ------
    ValueError
        If ``word`` is not in ``cols`` or its index does not appear in
        ``sev_scores``.
    """
    # Resolve the defaults at call time so the function always sees the
    # current notebook state.
    if cols is None:
        cols = cc_cols
    if sev_scores is None:
        sev_scores = cc_sev_scores

    col_idx = cols.index(word)  # raises ValueError for an unknown name
    # Scan the ranking directly instead of materialising a list of indices.
    for rank, entry in enumerate(sev_scores):
        if entry[0] == col_idx:
            return rank
    raise ValueError(
        f"column index {col_idx} for {word!r} is not present in the ranking"
    )

In [18]:
# Show the PMH ranking: (row index, mean similarity) pairs, highest score first.
pmh_sev_scores

[(219, 0.4769665201505025),
 (257, 0.4378529091676076),
 (16, 0.4250001311302185),
 (40, 0.4020433823267619),
 (229, 0.39613211154937744),
 (65, 0.3897954821586609),
 (220, 0.38940643270810443),
 (28, 0.34811728199323017),
 (87, 0.33681730926036835),
 (63, 0.3336651623249054),
 (249, 0.32134700814882916),
 (37, 0.31941765546798706),
 (115, 0.3164737820625305),
 (39, 0.3155514895915985),
 (117, 0.3154146025578181),
 (38, 0.3142701139052709),
 (233, 0.3133492370446523),
 (89, 0.3114054302374522),
 (159, 0.31031256914138794),
 (32, 0.30975181857744855),
 (234, 0.3082517584164937),
 (94, 0.3074887643257777),
 (260, 0.30645618836085003),
 (107, 0.3059845467408498),
 (262, 0.3045850197474162),
 (17, 0.3023215631643931),
 (31, 0.3021031618118286),
 (86, 0.29611314336458844),
 (240, 0.29517558217048645),
 (109, 0.29496438304583233),
 (228, 0.2889227668444316),
 (19, 0.2875792731841405),
 (21, 0.286898136138916),
 (8, 0.2841538190841675),
 (164, 0.2790655493736267),
 (271, 0.27880488832791644),

In [19]:
# Position of the "cc_trauma" column within the severity-ordered CC ranking
# (0 = highest mean similarity).
print(get_sev_index("cc_trauma"))

5
