In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pickle
import numpy as np
lang_1 = pickle.load(open('data_pico_da.p', 'rb'))

In [3]:
lang_1['vocab_size'] = len(lang_1['word2idx'])
lang_1['word_dim'] = 200

In [4]:
for i, doc in enumerate(lang_1['da']) :
    new_sent = [sent for sent in doc if len(sent) > 0]
    lang_1['da'][i] = new_sent

In [5]:
import torch
import model
mod = model.Model(lang_1=lang_1, num_aspect=30)

In [7]:
mod.load_values('results/model_MonFeb1912:25:042018_pico_tanh')

In [8]:
embeds = []
for doc in lang_1['da'] :
    doc = [[x for x in sent] for sent in doc]
    mod.train_batch(doc, [doc], update=False, do_tanh=True)
    embeds.append(mod.z_s)

In [9]:
for i, e in enumerate(embeds) :
    embeds[i] = e.sum(dim=0).cpu().numpy()
e = np.array(embeds)

In [10]:
import pandas as pd
import numpy as np
da = pd.read_csv('../data/files/decision_aids_filter.csv')
im_map = {'breast cancer': 'BCT', 
          'healthy women at risk of breast cancer': 'BCS', 
          'type II diabetes': 'D',
          'menopausal women': 'MW',
          'pregnant women, previous C section': 'PWC',
          'pregnant women': 'PW',
          'healthy people, at risk colon cancer': 'CCS',
          'prostate cancer': 'PCT',
          'healthy men, contemplating risk of prostate cancer': 'PCS',
          'AF': 'AF',
          'healthy women at genetic risk of breast cancer': 'BCG'}

nb_studies = len(da)
H = np.zeros((nb_studies, nb_studies))
for i in range(nb_studies) :
    H[i, da[da['IM_population'] == da['IM_population'][i]].index] = 1
np.fill_diagonal(H, 0)

In [11]:
from sklearn.preprocessing import normalize
e = normalize(e, 'l2')
scores = np.dot(e, e.T)
scores[np.arange(nb_studies), np.arange(nb_studies)] = -1000

In [12]:
from sklearn.metrics import roc_auc_score
aucs = [0] * nb_studies
for i in range(nb_studies) :
    aucs[i] = roc_auc_score(H[i], scores[i])
rocs = {}
for key in im_map :
    idxs = da[da['IM_population'] == key].index
    rocs[key] = np.mean(np.array(aucs)[idxs])
rocs['mean'] = np.mean(aucs)
print(pd.Series(rocs))


AF                                                    0.381944
breast cancer                                         0.613463
healthy men, contemplating risk of prostate cancer    0.574461
healthy people, at risk colon cancer                  0.556496
healthy women at genetic risk of breast cancer        0.720482
healthy women at risk of breast cancer                0.583825
mean                                                  0.603923
menopausal women                                      0.584330
pregnant women                                        0.704879
pregnant women, previous C section                    0.980392
prostate cancer                                       0.431928
type II diabetes                                      0.763253
dtype: float64
