In [1]:
import os
import re
import json
import pickle
import blingfire   # fast and relatively accurate sentence bound disambiguation
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm, trange
from heapq import heappop, heappush, heappushpop, heapreplace
from collections import defaultdict
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score
import matplotlib.pyplot as plt

In [2]:
# Use the first GPU when one is available; otherwise fall back to CPU so the
# notebook still runs (more slowly) on machines without CUDA instead of crashing.
SBERT_DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
sbert = SentenceTransformer('paraphrase-MiniLM-L6-v2', device=SBERT_DEVICE)
SBERT_DEVICE

In [3]:
# Load the parsed knowledge graph. The `with` block closes the file handle,
# and JSON is read as UTF-8 regardless of the platform's default encoding.
with open("../data/parsed_kg_info.json", encoding="utf-8") as kg_file:
    parsed_kg = json.load(kg_file)
symp2id = parsed_kg['symp2id']
id2symp = parsed_kg['id2symp']
desc2id = parsed_kg['desc2id']
id2desc = parsed_kg['id2desc']
# For symptom id i (its position in this list), (l, r) is the half-open slice
# of `id2desc` holding that symptom's description strings.
symp_id2desc_range = parsed_kg['symp_id2desc_range']
# symptom name -> list of its KG description strings
symp2descs = {id2symp[symp_id]: id2desc[l:r] for symp_id, (l, r) in enumerate(symp_id2desc_range)}

In [7]:
# Test split: one row per sentence, one label column per symptom.
# NOTE(review): the leading "Unnamed: 0" column looks like a saved pandas
# index; index_col=0 would drop it -- confirm nothing downstream reads it.
test_set = pd.read_csv(
    "../data/symp_data_w_control/test.csv",
    index_col=None,
)
test_set

Unnamed: 0,subreddit_id,post_id,sentence_id,disease,sentence,Anxious_Mood,Autonomic_symptoms,Cardiovascular_symptoms,Catatonic_behavior,Decreased_energy_tiredness_fatigue,...,panic_fear,pessimism,poor_memory,sleep_disturbance,somatic_muscle,somatic_symptoms_others,somatic_symptoms_sensory,weight_and_appetite_change,Anger_Irritability,uncertain
0,0,225,9,adhd,This makes it a drag talking to people and con...,0,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,0
1,0,370,4,adhd,Being unavoidably distracted is sometimes fun.,0,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,0
2,1,108,5,adhd,I have to constantly stop myself from dominati...,0,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,0
3,1,151,8,adhd,This also means people often think I'm ignorin...,0,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,0
4,1,206,21,adhd,I no longer interrupt people when they're spea...,0,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
36938,-1,-1,-1,control,"She slow begins to unlace the front, eager, li...",0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,-1
36939,-1,-1,-1,control,Realm leaders in the sidebar?,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,-1
36940,-1,-1,-1,control,"In the higher activity sub-reddits, these post...",0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,-1
36941,-1,-1,-1,control,Shipping to Serbia Does anyone from Serbia has...,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,-1


In [8]:
# Raw sentences in test-set row order; every score list built later is
# positionally aligned with the rows of `test_set`.
sents = test_set["sentence"].to_list()
len(sents)

36943

### Baseline: score sentences against the original KG symptom descriptions

In [6]:
# Embed each symptom's original KG descriptions once: symptom -> embedding matrix.
symp2desc_embs = dict(
    (symp_name, sbert.encode(symp_descs))
    for symp_name, symp_descs in symp2descs.items()
)
len(symp2desc_embs)

38

In [9]:
# Score every sentence against every symptom. A sentence's score for a
# symptom is its highest cosine similarity to any of that symptom's
# description embeddings. Sentences are embedded in chunks of BS.
BS = 32
symp2probs = defaultdict(list)
for start in trange(0, len(sents), BS):
    chunk_embs = sbert.encode(sents[start:start + BS])
    for symp_name, symp_desc_embs in symp2desc_embs.items():
        best_sims = cosine_similarity(chunk_embs, symp_desc_embs).max(axis=1)
        symp2probs[symp_name].extend(best_sims)

100%|██████████| 1155/1155 [01:06<00:00, 17.32it/s]


In [15]:
# Evaluate each symptom as an independent binary detector, then macro-average.
# Only rows whose label is not -1 are scored; -1 presumably means
# "not annotated for this symptom" -- confirm against the dataset.
# NOTE: the "probs" are cosine similarities, not calibrated probabilities.
THRESHOLD = 0.5  # similarity cut-off above which a symptom is predicted
all_metrics = {
    "p": [],
    "r": [],
    "f1": [],
    "auc": []
}
for symp, probs in symp2probs.items():
    labels = test_set[symp].values
    probs = np.asarray(probs)
    labelled = labels != -1  # boolean mask instead of np.where tuple indexing
    sel_labels = labels[labelled]
    sel_probs = probs[labelled]
    sel_preds = (sel_probs > THRESHOLD).astype(int)
    # zero_division=0 yields the same value sklearn's default gives for an
    # undefined score, without emitting UndefinedMetricWarning.
    p = precision_score(sel_labels, sel_preds, zero_division=0)
    r = recall_score(sel_labels, sel_preds, zero_division=0)
    f1 = f1_score(sel_labels, sel_preds, zero_division=0)
    # roc_auc_score raises ValueError when only one class is present; record
    # NaN for such a symptom instead of aborting the whole evaluation.
    auc = roc_auc_score(sel_labels, sel_probs) if np.unique(sel_labels).size > 1 else np.nan
    all_metrics["p"].append(p)
    all_metrics["r"].append(r)
    all_metrics["f1"].append(f1)
    all_metrics["auc"].append(auc)
# Macro averages over symptoms; nanmean skips symptoms with undefined AUC.
for k, values in all_metrics.items():
    print(k, np.nanmean(values))

p 0.5036182906297512
r 0.6611689945781819
f1 0.537034538035627
auc 0.9894072405625843


### Score sentences against descriptions derived from posts (KG descriptions as fallback)

In [4]:
# Optional per-symptom descriptions derived from posts: one file per symptom,
# named "<symptom>.<ext>", holding one description per line.
DESC_FROM_POST_DIR = "../data/desc_from_post/"
desc_from_post = {}
for fname in sorted(os.listdir(DESC_FROM_POST_DIR)):
    fpath = os.path.join(DESC_FROM_POST_DIR, fname)
    # Skip hidden entries (e.g. .ipynb_checkpoints) and anything that is not a
    # regular file; open() would otherwise fail on directories.
    if fname.startswith('.') or not os.path.isfile(fpath):
        continue
    symp = os.path.splitext(fname)[0]
    with open(fpath, encoding="utf-8") as fh:
        # Drop blank lines (e.g. from a trailing newline). Otherwise an empty
        # string would be embedded and compete as a spurious "description".
        lines = [line for line in fh.read().split('\n') if line.strip()]
    # An empty file is not recorded, so that symptom keeps its KG descriptions.
    if lines:
        desc_from_post[symp] = lines

In [16]:
# Use post-derived descriptions for symptoms that have them; every other
# symptom keeps its original KG descriptions. Keys of `desc_from_post` that
# are not KG symptoms are ignored.
symp2desc_embs = {
    symp_name: sbert.encode(desc_from_post.get(symp_name, kg_descs))
    for symp_name, kg_descs in symp2descs.items()
}
len(symp2desc_embs)

38

In [17]:
# Re-score every sentence against the (partly post-derived) description
# embeddings. A sentence's score for a symptom is its highest cosine
# similarity to any of that symptom's descriptions. Sentences are embedded
# in chunks of BS.
BS = 32
symp2probs = defaultdict(list)
for start in trange(0, len(sents), BS):
    chunk_embs = sbert.encode(sents[start:start + BS])
    for symp_name, symp_desc_embs in symp2desc_embs.items():
        best_sims = cosine_similarity(chunk_embs, symp_desc_embs).max(axis=1)
        symp2probs[symp_name].extend(best_sims)

100%|██████████| 1155/1155 [01:06<00:00, 17.25it/s]


In [18]:
# Evaluate each symptom as an independent binary detector, then macro-average.
# Only rows whose label is not -1 are scored; -1 presumably means
# "not annotated for this symptom" -- confirm against the dataset.
# NOTE: the "probs" are cosine similarities, not calibrated probabilities.
THRESHOLD = 0.5  # similarity cut-off above which a symptom is predicted
all_metrics = {
    "p": [],
    "r": [],
    "f1": [],
    "auc": []
}
for symp, probs in symp2probs.items():
    labels = test_set[symp].values
    probs = np.asarray(probs)
    labelled = labels != -1  # boolean mask instead of np.where tuple indexing
    sel_labels = labels[labelled]
    sel_probs = probs[labelled]
    sel_preds = (sel_probs > THRESHOLD).astype(int)
    # zero_division=0 yields the same value sklearn's default gives for an
    # undefined score, without emitting UndefinedMetricWarning.
    p = precision_score(sel_labels, sel_preds, zero_division=0)
    r = recall_score(sel_labels, sel_preds, zero_division=0)
    f1 = f1_score(sel_labels, sel_preds, zero_division=0)
    # roc_auc_score raises ValueError when only one class is present; record
    # NaN for such a symptom instead of aborting the whole evaluation.
    auc = roc_auc_score(sel_labels, sel_probs) if np.unique(sel_labels).size > 1 else np.nan
    all_metrics["p"].append(p)
    all_metrics["r"].append(r)
    all_metrics["f1"].append(f1)
    all_metrics["auc"].append(auc)
# Macro averages over symptoms; nanmean skips symptoms with undefined AUC.
for k, values in all_metrics.items():
    print(k, np.nanmean(values))

p 0.4889419978434326
r 0.7709345596486377
f1 0.5523990960760925
auc 0.992086459261943
