In [1]:
import os 
import re 
import json 
import shutil
import numpy as np 
import pandas as pd 
from tqdm import tqdm
from collections import defaultdict
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.metrics import PrecisionRecallDisplay, precision_recall_curve, roc_curve

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
# Decision threshold applied to predicted probabilities when no tuned
# threshold is used.
default_threshold = 0.5


def row2index(r):
    """Build a string key from a row's subreddit, post and sentence ids.

    `r` is anything indexable by the three field names (e.g. a dict or a
    DataFrame row); the ids are joined with underscores.
    """
    return f"{r['subreddit_id']}_{r['post_id']}_{r['sentence_id']}"


# Position -> disease name lookup.
id2disease = [
    "adhd", "anxiety", "bipolar_disorder", "depression",
    "eating_disorder", "ocd", "ptsd",
]

In [4]:
# Name of the experiment run to evaluate; it selects the sub-directory of
# ./infer_output/ from which the predicted probabilities (test.npy) are loaded.
exp_name = "mbert_label_enhance_bal_sample_050_666"

In [5]:
# Load the reference test labels and the model's predicted probabilities,
# then build `infer_df`: a copy of the reference frame whose symptom columns
# hold the predicted probabilities instead of the labels.

# Collects the best-F1 threshold of every symptom column (filled by the
# evaluation loop).
pseudo_thrs = []

ref_df = pd.read_csv("../data/symp_data_w_control/test.csv", index_col=None)

# Every column from position 5 up to (but excluding) the last one is treated
# as a symptom column.
# NOTE(review): assumes the first five columns and the last column are
# metadata / non-symptom fields -- confirm against the CSV schema.
all_symps = ref_df.columns[5:-1].tolist()

infer_probs = np.load(f"./infer_output/{exp_name}/test.npy")

# Fail loudly on a mismatched prediction file instead of letting pandas
# broadcast (or reject with an opaque message) a wrongly shaped array.
expected_shape = (len(ref_df), len(all_symps))
if infer_probs.shape != expected_shape:
    raise ValueError(
        f"Predictions for '{exp_name}' have shape {infer_probs.shape}, "
        f"expected {expected_shape} (rows x symptom columns)."
    )

infer_df = ref_df.copy()
# Replace the symptom columns wholesale rather than writing through
# `.iloc[:, 5:-1]`: an in-place write of float probabilities into columns
# that were parsed as integers is a deprecated upcast in recent pandas.
infer_df[all_symps] = infer_probs

In [6]:
# Per-symptom evaluation of the predicted probabilities in `infer_df` against
# the reference labels in `ref_df`.
#
# For every symptom column this computes:
#   * ROC AUC,
#   * F1 / precision / recall at `default_threshold`,
#   * the threshold on the precision-recall curve that maximises F1, together
#     with the F1 / precision / recall at that threshold.
# The best thresholds are also collected, in column order, in `pseudo_thrs`.

result_rows = []
# Reset here so that re-running this cell does not keep appending to the list
# created in an earlier cell.
pseudo_thrs = []

for col in all_symps:
    labels = ref_df[col].values
    # Rows whose label is -1 are excluded from scoring for this column.
    keep = labels != -1
    labels_sel = labels[keep]
    probs_sel = infer_df[col].values[keep]

    auc = roc_auc_score(labels_sel, probs_sel)

    # Metrics at the fixed default threshold.
    preds_default = (probs_sel > default_threshold).astype(float)
    f1_default = f1_score(labels_sel, preds_default)
    p_default = precision_score(labels_sel, preds_default)
    r_default = recall_score(labels_sel, preds_default)

    # Threshold sweep: precision/recall carry one trailing point without a
    # threshold, so truncate them to the thresholds' length.
    precisions, recalls, thresholds = precision_recall_curve(labels_sel, probs_sel)
    p_at_thr = precisions[:len(thresholds)]
    r_at_thr = recalls[:len(thresholds)]

    # F1 = 2PR / (P + R); defined as 0 where P + R is not positive so that a
    # 0/0 never produces NaN (which previously relied on warnings being muted).
    denom = p_at_thr + r_at_thr
    f_at_thr = np.divide(
        2 * p_at_thr * r_at_thr,
        denom,
        out=np.zeros_like(denom, dtype=float),
        where=denom > 0,
    )

    # Keep the all-zero defaults unless some threshold achieves F1 > 0;
    # argmax returns the first maximum, matching a strict ">" scan.
    best_f, best_thr, p_at_best, r_at_best = 0, 0, 0, 0
    if f_at_thr.size and f_at_thr.max() > 0:
        best = int(np.argmax(f_at_thr))
        best_f = f_at_thr[best]
        best_thr = thresholds[best]
        p_at_best = p_at_thr[best]
        r_at_best = r_at_thr[best]

    result_rows.append([col, best_thr, auc, f1_default, p_default, r_default, best_f, p_at_best, r_at_best])
    pseudo_thrs.append(best_thr)

# The 'disease' column holds the evaluated column name taken from `all_symps`;
# the header is kept as-is so existing consumers of `results_df` keep working.
results_df = pd.DataFrame(
    result_rows,
    columns=['disease', 'thr', 'auc', 'f1_default', 'p_default', 'r_default', 'f1_best', 'p', 'r'],
)
results_df

Unnamed: 0,disease,thr,auc,f1_default,p_default,r_default,f1_best,p,r
0,Anxious_Mood,0.628077,0.973165,0.776204,0.789625,0.763231,0.778978,0.806259,0.753482
1,Autonomic_symptoms,0.595551,0.988879,0.719486,0.677419,0.767123,0.721739,0.688797,0.757991
2,Cardiovascular_symptoms,0.686787,0.999017,0.900232,0.866071,0.937198,0.91253,0.893519,0.932367
3,Catatonic_behavior,0.988553,0.961521,0.456621,0.387597,0.555556,0.540881,0.623188,0.477778
4,Decreased_energy_tiredness_fatigue,0.162364,0.975905,0.525714,0.666667,0.433962,0.563107,0.58,0.54717
5,Depressed_Mood,0.830102,0.973365,0.514768,0.481579,0.55287,0.534622,0.572414,0.501511
6,Gastrointestinal_symptoms,0.314559,0.998609,0.769231,0.833333,0.714286,0.811189,0.75817,0.87218
7,Genitourinary_symptoms,0.996896,0.99158,0.847162,0.788618,0.915094,0.903846,0.921569,0.886792
8,Hyperactivity_agitation,0.463877,0.971679,0.472973,0.853659,0.327103,0.483221,0.857143,0.336449
9,Impulsivity,0.448567,0.978776,0.613333,0.575,0.657143,0.622517,0.580247,0.671429


In [7]:
# Summary statistics (count, mean, std, min, quartiles, max) of every numeric
# metric column across all evaluated columns.
results_df.describe(percentiles=[0.25, 0.5, 0.75])

Unnamed: 0,thr,auc,f1_default,p_default,r_default,f1_best,p,r
count,38.0,38.0,38.0,38.0,38.0,38.0,38.0,38.0
mean,0.690854,0.985412,0.670268,0.668753,0.693503,0.696202,0.728895,0.67852
std,0.224843,0.012554,0.152085,0.156895,0.172491,0.14896,0.135528,0.170267
min,0.162364,0.941538,0.290909,0.331081,0.235294,0.313725,0.439024,0.235294
25%,0.568616,0.977516,0.554338,0.559178,0.588734,0.594915,0.637139,0.559026
50%,0.660098,0.988998,0.713379,0.691508,0.730658,0.718349,0.75759,0.714133
75%,0.89337,0.994914,0.76791,0.789373,0.821945,0.801799,0.836061,0.790871
max,0.996896,0.999562,0.909091,0.892045,0.951087,0.925134,0.921569,0.940217


In [8]:
# Macro-average of every metric across the evaluated columns.
# `numeric_only=True` skips the string-valued 'disease' column explicitly;
# recent pandas versions no longer drop non-numeric columns silently and
# raise a TypeError instead.
results_df.mean(numeric_only=True)

thr           0.690854
auc           0.985412
f1_default    0.670268
p_default     0.668753
r_default     0.693503
f1_best       0.696202
p             0.728895
r             0.678520
dtype: float64