In [1]:
import itertools
import os
import pickle
import random
import re

import datasets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly
import plotly.graph_objs as go
import scipy
import seaborn as sns
import torch
from lime import lime_text
from lime.lime_text import LimeTextExplainer
from plotly.offline import init_notebook_mode, iplot
from termcolor import colored
from tqdm.auto import tqdm, trange

In [2]:
def aggregate_exps(exp_list):
    """Aggregate raw (feature, weight) attribution pairs into per-feature statistics.

    Parameters
    ----------
    exp_list : iterable of (feature, weight) pairs
        The same feature may appear many times. Presumably these are LIME
        explanation pairs concatenated over many explained samples
        (TODO confirm against the code that wrote the pickles).

    Returns
    -------
    pandas.DataFrame
        Indexed by feature (unnamed index, in order of first appearance), with
        columns:
        ``value``       -- sum of the feature's weights (float),
        ``count``       -- number of pairs for the feature,
        ``average``     -- value / count,
        ``abs_average`` -- absolute value of ``average``.
        An empty input yields an empty frame with these four columns.
    """
    pairs = pd.DataFrame(list(exp_list), columns=["feature", "weight"])
    # sort=False keeps features in first-seen order, matching the previous
    # Series.unique()-based index.
    grouped = pairs.groupby("feature", sort=False)["weight"]
    attributions = pd.DataFrame(
        {
            # Explicit float dtype: the old version initialised this column to
            # int 0 and relied on pandas silently upcasting on `+=`, which is
            # deprecated in pandas 2.x.
            "value": grouped.sum().astype(float),
            "count": grouped.size(),
        }
    )
    attributions.index.name = None
    attributions["average"] = attributions["value"] / attributions["count"]
    attributions["abs_average"] = attributions["average"].abs()
    return attributions

## Suicidal Capability and Timeframe

In [3]:
# Aggregated attributions for the capability and timeframe tasks.
# `with` closes each pickle file instead of leaking the handle.
# NOTE: pickle.load can execute arbitrary code -- only load files this project wrote.
with open("saved/capability/exp_raw.pickle", "rb") as fh:
    cap = aggregate_exps(pickle.load(fh))
with open("saved/timeframe/exp_raw.pickle", "rb") as fh:
    tf = aggregate_exps(pickle.load(fh))

HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))




In [120]:
# Number of distinct features (rows) in each aggregated attribution table.
for label, frame in (("Capability", cap), ("Timeframe", tf)):
    print(f"{label} features: {frame.shape[0]}")

Capability features: 1925
Timeframe features: 1858


In [123]:
# Features present in both the capability and timeframe tables.
print(f"Overlap: {len(set(cap.index).intersection(tf.index))}")

Overlap: 1146


In [164]:
# Most frequently occurring capability features among those whose mean
# |weight| exceeds 0.04, top 25.
cap.loc[lambda d: d["abs_average"].gt(0.04)].sort_values(by="count", ascending=False).head(25)

Unnamed: 0,value,count,average,abs_average
you,10.093004,100,0.10093,0.10093
have,6.111871,91,0.067163,0.067163
to,4.186618,84,0.049841,0.049841
when,7.177573,78,0.09202,0.09202
this,3.362069,78,0.043103,0.043103
life,4.91682,74,0.066444,0.066444
plan,4.011106,73,0.054947,0.054947
the,2.881202,71,0.04058,0.04058
safe,2.924846,67,0.043654,0.043654
suicidal,4.427445,64,0.069179,0.069179


In [4]:
# Strongest capability features (by mean |weight|) seen in more than 5 pairs.
# NOTE(review): the saved output below lists features with count 1 (e.g.
# "railings"), which this filter excludes, so that output is stale -- re-run
# this cell to refresh it.
cap.loc[lambda d: d["count"].gt(5)].sort_values(by="abs_average", ascending=False).head(25)

Unnamed: 0,value,count,average,abs_average
railings,0.247715,1,0.247715,0.247715
suffocate,0.199962,1,0.199962,0.199962
hang,2.330877,15,0.155392,0.155392
paracetamol,0.997332,7,0.142476,0.142476
railway,0.138242,1,0.138242,0.138242
tracks,0.270167,2,0.135083,0.135083
pills,4.951578,38,0.130305,0.130305
vehicle,0.128947,1,0.128947,0.128947
7am,0.126959,1,0.126959,0.126959
bedroom,0.112313,1,0.112313,0.112313


In [157]:
# Strongest timeframe features (by mean |weight|) seen in more than 5 pairs.
tf.loc[lambda d: d["count"].gt(5)].sort_values(by="abs_average", ascending=False).head(25)

Unnamed: 0,value,count,average,abs_average
paracetamol,1.756677,17,0.103334,0.103334
you,9.466064,100,0.094661,0.094661
pills,2.872949,31,0.092676,0.092676
tablets,2.147469,27,0.079536,0.079536
ambulance,0.582114,8,0.072764,0.072764
when,5.445176,76,0.071647,0.071647
train,0.41979,6,0.069965,0.069965
jumping,0.747309,11,0.067937,0.067937
overdose,2.626548,40,0.065664,0.065664
tomorrow,1.413535,22,0.064252,0.064252


## Helpfulness

In [39]:
# Raw attribution pairs for the helpfulness task, plus the pickled samples
# file (presumably the 100 explained texts, per its name).
# `with` closes each file instead of leaking the handle.
# NOTE: pickle.load can execute arbitrary code -- only load files this project wrote.
with open("saved/helpful/exp_raw.pickle", "rb") as fh:
    exp_list = pickle.load(fh)
with open("saved/helpful/100_text_samples.pickle", "rb") as fh:
    og = pickle.load(fh)
attributions = aggregate_exps(exp_list)

HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))




## Anxiety, Depression, and Suicide

In [40]:
# Aggregated attributions for the anxiety, depression and suicide tasks.
# `with` closes each pickle file instead of leaking the handle.
# NOTE: pickle.load can execute arbitrary code -- only load files this project wrote.
# NOTE(review): the depressed/suicide pickles live under saved/anxiety/, unlike
# the saved/<task>/ layout used elsewhere -- confirm these are the intended paths.
with open("saved/anxiety/exp_raw.pickle", "rb") as fh:
    anxiety = aggregate_exps(pickle.load(fh))
with open("saved/anxiety/depressed/exp_raw.pickle", "rb") as fh:
    depressed = aggregate_exps(pickle.load(fh))
with open("saved/anxiety/suicide/exp_raw.pickle", "rb") as fh:
    suicide = aggregate_exps(pickle.load(fh))

In [216]:
# Inner-join the three per-task tables on feature, so each row holds one
# feature's count/average/abs_average under every task (prefixes a_/d_/s_).
common = pd.concat(
    {"a": anxiety, "d": depressed, "s": suicide}, axis=1, join="inner"
).drop(columns="value", level=1)
# Flatten ("a", "count") -> "a_count" from the labels themselves instead of
# overwriting by position, which would silently mislabel columns if the
# per-task table layout ever changed.
common.columns = [f"{task}_{stat}" for task, stat in common.columns]

print(f"Anxiety and Depression overlap: {len(set(anxiety.index) & set(depressed.index))}")
print(f"Anxiety and Suicide overlap: {len(set(anxiety.index) & set(suicide.index))}")
print(f"Depression and Suicide overlap: {len(set(depressed.index) & set(suicide.index))}")
print(f"Common features: {common.shape[0]}")

Anxiety and Depression overlap: 1520
Anxiety and Suicide overlap: 1450
Depression and Suicide overlap: 1455
Common features: 1224


In [227]:
# Top 25 shared features by mean |weight| for the anxiety task, with their
# depression/suicide statistics alongside.
common.sort_values(by="a_abs_average", ascending=False).head(25)

Unnamed: 0,a_count,a_average,a_abs_average,d_count,d_average,d_abs_average,s_count,s_average,s_abs_average
anxiety,30,0.206459,0.206459,28,-0.071723,0.071723,22,-0.027992,0.027992
disabled,1,-0.192211,0.192211,1,-0.063499,0.063499,1,5.6e-05,5.6e-05
trans,1,-0.17688,0.17688,1,-0.266226,0.266226,1,-0.006672,0.006672
sexuality,1,-0.15435,0.15435,1,-0.006777,0.006777,1,0.000113,0.000113
bullying,1,-0.149778,0.149778,1,-0.137102,0.137102,1,-0.002775,0.002775
anxieties,1,0.118987,0.118987,1,-0.087299,0.087299,1,0.000258,0.000258
stressed,5,0.098432,0.098432,3,-0.038408,0.038408,5,-0.009607,0.009607
visions,1,0.089294,0.089294,1,-0.009533,0.009533,1,-0.16324,0.16324
abusing,3,-0.08833,0.08833,3,-0.015711,0.015711,3,0.005822,0.005822
dread,1,0.082166,0.082166,1,-0.008815,0.008815,1,0.001644,0.001644


In [218]:
# Top 25 shared features by mean |weight| for the depression task.
common.sort_values(by="d_abs_average", ascending=False).head(25)

Unnamed: 0,a_count,a_average,a_abs_average,d_count,d_average,d_abs_average,s_count,s_average,s_abs_average
depressed,7,-0.043971,0.043971,8,0.267781,0.267781,6,-0.017329,0.017329
trans,1,-0.17688,0.17688,1,-0.266226,0.266226,1,-0.006672,0.006672
depressive,3,-0.052485,0.052485,3,0.264903,0.264903,1,0.006888,0.006888
depression,16,-0.04462,0.04462,17,0.230524,0.230524,12,0.01037,0.01037
bullying,1,-0.149778,0.149778,1,-0.137102,0.137102,1,-0.002775,0.002775
agressive,1,-0.050654,0.050654,1,-0.096177,0.096177,1,-0.017467,0.017467
despair,1,-0.011209,0.011209,1,0.091552,0.091552,1,0.010101,0.010101
anxieties,1,0.118987,0.118987,1,-0.087299,0.087299,1,0.000258,0.000258
sadness,4,-0.048359,0.048359,5,0.078535,0.078535,2,-0.00792,0.00792
anxiety,30,0.206459,0.206459,28,-0.071723,0.071723,22,-0.027992,0.027992


In [219]:
# Top 25 shared features by mean |weight| for the suicide task.
common.sort_values(by="s_abs_average", ascending=False).head(25)

Unnamed: 0,a_count,a_average,a_abs_average,d_count,d_average,d_abs_average,s_count,s_average,s_abs_average
visions,1,0.089294,0.089294,1,-0.009533,0.009533,1,-0.16324,0.16324
suicidal,44,-0.049901,0.049901,43,-0.022637,0.022637,50,0.152345,0.152345
noose,2,-0.057761,0.057761,2,-0.042272,0.042272,2,0.11651,0.11651
suicide,43,-0.021415,0.021415,52,0.011016,0.011016,56,0.116248,0.116248
drown,1,-0.018196,0.018196,1,0.010227,0.010227,1,0.113709,0.113709
hang,8,-0.037742,0.037742,6,-0.027148,0.027148,8,0.107106,0.107106
window,1,-0.01219,0.01219,1,-0.005309,0.005309,1,0.104682,0.104682
wrists,1,-0.064656,0.064656,1,-0.033427,0.033427,1,0.097497,0.097497
you,93,-0.030458,0.030458,80,0.003247,0.003247,99,0.09513,0.09513
kill,28,-0.027611,0.027611,31,-0.026655,0.026655,36,0.088276,0.088276


## Suicide Risk

In [165]:
# Aggregated attributions for the suicide-risk "desire" task.
# `with` closes the pickle file instead of leaking the handle.
# NOTE: pickle.load can execute arbitrary code -- only load files this project wrote.
with open("saved/suicide/desire/exp_raw.pickle", "rb") as fh:
    desire = aggregate_exps(pickle.load(fh))

HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))




In [166]:
# Aggregated attributions for the remaining suicide-risk tasks.
# `with` closes each pickle file instead of leaking the handle.
# NOTE: pickle.load can execute arbitrary code -- only load files this project wrote.
with open("saved/suicide/intent/exp_raw.pickle", "rb") as fh:
    intent = aggregate_exps(pickle.load(fh))
with open("saved/suicide/capability/exp_raw.pickle", "rb") as fh:
    capability = aggregate_exps(pickle.load(fh))
with open("saved/suicide/timeframe/exp_raw.pickle", "rb") as fh:
    timeframe = aggregate_exps(pickle.load(fh))

HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))




In [167]:
# Number of distinct features (rows) in each suicide-risk attribution table.
risk_tables = {"desire": desire, "intent": intent, "capability": capability, "timeframe": timeframe}
for task_name, table in risk_tables.items():
    print(f"{task_name}: {table.shape[0]}")

desire: 2030
intent: 2004
capability: 1979
timeframe: 1965


In [174]:
# Inner-join the four suicide-risk tables on feature (prefixes d_/i_/c_/t_ for
# desire/intent/capability/timeframe).
# NOTE: this reassigns `common`, which previously held the anxiety/depression/
# suicide frame; later cells that want that frame must rebuild it.
common = pd.concat(
    {"d": desire, "i": intent, "c": capability, "t": timeframe},
    axis=1,
    join="inner",
).drop(columns="value", level=1)
# Flatten ("d", "count") -> "d_count" from the labels rather than by position.
common.columns = [f"{task}_{stat}" for task, stat in common.columns]
print(f"overlap: {common.shape[0]}")

overlap: 1138


In [189]:
# Pairwise and three-way feature overlaps between the suicide-risk tasks.
# Generated with itertools.combinations instead of ten hand-written lines, so
# every combination appears exactly once with consistently ordered labels (the
# hand-written list labelled the last triple "timeframe x intent x capability").
risk_feature_sets = {
    "desire": set(desire.index),
    "intent": set(intent.index),
    "capability": set(capability.index),
    "timeframe": set(timeframe.index),
}
for group_size in (2, 3):
    for names in itertools.combinations(risk_feature_sets, group_size):
        shared = set.intersection(*(risk_feature_sets[name] for name in names))
        print(f"{' x '.join(names)}: {len(shared)}")
    print()

desire x intent: 1469
desire x capability: 1455
desire x timeframe: 1453


intent x capability: 1475
intent x timeframe: 1455


capability x timeframe: 1470


desire x intent x capability: 1251
desire x intent x timeframe: 1239
desire x capability x timeframe: 1245


timeframe x intent x capability: 1259


In [203]:
# Strongest desire features by mean |weight| (no minimum-count filter).
desire.sort_values(by="abs_average", ascending=False).head(25)

Unnamed: 0,value,count,average,abs_average
visions,-0.171154,1,-0.171154,0.171154
suicidal,7.803215,50,0.156064,0.156064
suicide,6.600445,56,0.117865,0.117865
wrists,0.115037,1,0.115037,0.115037
noose,0.223809,2,0.111905,0.111905
drown,0.109437,1,0.109437,0.109437
hang,0.844506,8,0.105563,0.105563
window,0.100955,1,0.100955,0.100955
you,9.434922,99,0.095302,0.095302
kill,3.137589,34,0.092282,0.092282


In [202]:
# Top 25 features shared by all four risk tasks, ranked by desire mean |weight|,
# with intent/capability/timeframe statistics alongside.
common.sort_values(by="d_abs_average", ascending=False).head(25)

Unnamed: 0,d_count,d_average,d_abs_average,i_count,i_average,i_abs_average,c_count,c_average,c_abs_average,t_count,t_average,t_abs_average
visions,1,-0.171154,0.171154,1,-0.016993,0.016993,1,-0.001477,0.001477,1,-0.000279,0.000279
suicidal,50,0.156064,0.156064,46,0.061982,0.061982,38,0.035058,0.035058,41,0.012924,0.012924
suicide,56,0.117865,0.117865,53,0.063759,0.063759,47,0.042325,0.042325,39,0.013229,0.013229
wrists,1,0.115037,0.115037,1,0.015409,0.015409,1,0.00177,0.00177,1,0.000266,0.000266
noose,2,0.111905,0.111905,2,0.171075,0.171075,2,0.148547,0.148547,2,0.057403,0.057403
drown,1,0.109437,0.109437,1,0.160206,0.160206,1,0.079693,0.079693,1,0.013002,0.013002
hang,8,0.105563,0.105563,8,0.125424,0.125424,8,0.091329,0.091329,8,0.040992,0.040992
window,1,0.100955,0.100955,1,0.158441,0.158441,1,0.234933,0.234933,1,0.192939,0.192939
you,99,0.095302,0.095302,97,0.07433,0.07433,98,0.056998,0.056998,97,0.03236,0.03236
kill,34,0.092282,0.092282,30,0.083511,0.083511,30,0.058304,0.058304,29,0.02569,0.02569


In [200]:
# Timeframe features missing from the four-way inner join in `common` (i.e.
# absent from at least one other risk task), strongest first.
# A boolean isin-mask replaces `.loc[set ^ set]`: pandas >= 2.0 rejects sets as
# indexers, and the symmetric difference only worked because common.index is
# a subset of timeframe.index.
timeframe[~timeframe.index.isin(common.index)].sort_values("abs_average", ascending=False).head(25)

Unnamed: 0,value,count,average,abs_average
vodka,0.042763,1,0.042763,0.042763
counter,0.037476,1,0.037476,0.037476
roads,0.035021,1,0.035021,0.035021
7pm,0.033798,1,0.033798,0.033798
late,0.032901,1,0.032901,0.032901
gown,0.030876,1,0.030876,0.030876
miss,-0.027182,1,-0.027182,0.027182
Overdosing,-0.025557,1,-0.025557,0.025557
fill,0.023644,1,0.023644,0.023644
prefer,-0.023296,1,-0.023296,0.023296


In [204]:
# Strongest timeframe features by mean |weight| (no minimum-count filter).
timeframe.sort_values(by="abs_average", ascending=False).head(25)

Unnamed: 0,value,count,average,abs_average
cliffs,0.202429,1,0.202429,0.202429
window,0.192939,1,0.192939,0.192939
painkillers,0.14021,1,0.14021,0.14021
cupboard,0.138545,1,0.138545,0.138545
nearest,0.177173,2,0.088586,0.088586
road,0.08782,1,0.08782,0.08782
paracetamol,0.372181,5,0.074436,0.074436
drugs,0.145319,2,0.07266,0.07266
cyanide,0.064226,1,0.064226,0.064226
stash,0.127063,2,0.063531,0.063531


In [220]:
# The saved output of this cell (a_/d_/s_ columns) comes from the anxiety/
# depression/suicide frame. It was produced only because the cell that built
# that frame ([216]) ran after the suicide-risk cell ([174]). Under
# Restart & Run All, `common` here is the suicide-risk frame, where d_ means
# desire, so the cell would silently show a different table. Rebuild the
# anxiety/depression/suicide frame explicitly instead of relying on execution order.
mood_common = pd.concat(
    {"a": anxiety, "d": depressed, "s": suicide}, axis=1, join="inner"
).drop(columns="value", level=1)
mood_common.columns = [f"{task}_{stat}" for task, stat in mood_common.columns]
mood_common.sort_values("d_abs_average", ascending=False).head(25)

Unnamed: 0,a_count,a_average,a_abs_average,d_count,d_average,d_abs_average,s_count,s_average,s_abs_average
depressed,7,-0.043971,0.043971,8,0.267781,0.267781,6,-0.017329,0.017329
trans,1,-0.17688,0.17688,1,-0.266226,0.266226,1,-0.006672,0.006672
depressive,3,-0.052485,0.052485,3,0.264903,0.264903,1,0.006888,0.006888
depression,16,-0.04462,0.04462,17,0.230524,0.230524,12,0.01037,0.01037
bullying,1,-0.149778,0.149778,1,-0.137102,0.137102,1,-0.002775,0.002775
agressive,1,-0.050654,0.050654,1,-0.096177,0.096177,1,-0.017467,0.017467
despair,1,-0.011209,0.011209,1,0.091552,0.091552,1,0.010101,0.010101
anxieties,1,0.118987,0.118987,1,-0.087299,0.087299,1,0.000258,0.000258
sadness,4,-0.048359,0.048359,5,0.078535,0.078535,2,-0.00792,0.00792
anxiety,30,0.206459,0.206459,28,-0.071723,0.071723,22,-0.027992,0.027992


## Substance Use

In [63]:
# Aggregated substance-use attributions (presumably the original, pre-fine-tuning
# model, per the file name). `with` closes the file instead of leaking the handle.
# NOTE: pickle.load can execute arbitrary code -- only load files this project wrote.
with open("saved/substance/og_ml_exp_raw.pickle", "rb") as fh:
    og_ml = aggregate_exps(pickle.load(fh))

HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))




In [6]:
# Aggregated substance-use attributions (presumably the fine-tuned model, per the
# file name). `with` closes the file instead of leaking the handle.
# NOTE: pickle.load can execute arbitrary code -- only load files this project wrote.
with open("saved/substance/finetune_ml_exp_raw.pickle", "rb") as fh:
    finetuned_ml = aggregate_exps(pickle.load(fh))

HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))




In [32]:
# Outer-join the og and fine-tuned attribution tables so features seen by only
# one model are kept; their other-model columns are NaN, which is why the
# counts display as floats.
everything = pd.concat({"og": og_ml, "ft": finetuned_ml}, axis=1, join="outer")
everything = everything.drop(columns=["value", "abs_average"], level=1)
# Derive flat names (og_count, og_avg, ft_count, ft_avg) from the labels
# instead of overwriting columns by position.
everything.columns = [
    f"{model}_{'avg' if stat == 'average' else stat}" for model, stat in everything.columns
]
everything.sort_values("ft_avg", ascending=False).head(25)

Unnamed: 0,og_count,og_avg,ft_count,ft_avg
cannabis,7.0,0.264703,7.0,0.205985
drinking,43.0,0.153517,45.0,0.15705
cocaine,11.0,0.143026,9.0,0.128862
alcohol,46.0,0.104702,48.0,0.126127
smoking,8.0,0.129075,9.0,0.122774
addiction,23.0,0.138053,23.0,0.120776
drugs,34.0,0.12424,35.0,0.1099
fentanyl,1.0,0.122019,1.0,0.103016
drug,28.0,0.08651,29.0,0.096544
nicotine,2.0,0.093729,2.0,0.080204


In [60]:
# Count and mean weight of the feature "quitting" under both models.
everything.loc["quitting", :]

og_count    5.000000
og_avg      0.069769
ft_count    5.000000
ft_avg      0.030224
Name: quitting, dtype: float64