In [3]:
import datasets
import os
import pickle
import random
import re
import scipy
import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from datasets import Dataset

from sklearn.pipeline import make_pipeline

from tqdm.auto import tqdm, trange

from transformers import AutoTokenizer, BigBirdForSequenceClassification, pipeline, \
                            TrainingArguments, Trainer

from torch.utils.data import DataLoader

#### Getting predictions to compare with labels

In [22]:
# Load the tokenizer from the local checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained("bigbird/output/ml/epoch-5")

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the classifier from the same checkpoint and move it onto the chosen device.
model = BigBirdForSequenceClassification.from_pretrained("bigbird/output/ml/epoch-5")
model = model.to(device)
device

device(type='cuda')

In [23]:
def tokenization(batched_text):
    """Tokenize the ``'text'`` entries of a batch with the notebook-level ``tokenizer``.

    Sequences are truncated to at most 2048 tokens and padded to the longest
    sequence within the batch that is passed in.
    """
    texts = batched_text['text']
    encoded = tokenizer(texts, padding="longest", truncation=True, max_length=2048)
    return encoded

def m3_predict_probs(sample, batch_size=2):
    """Return sigmoid probabilities from the notebook-level ``model`` for each text.

    Parameters
    ----------
    sample : sequence of str
        Raw texts to score; they are placed in a single ``'text'`` column.
    batch_size : int, default 2
        Number of texts tokenized and scored together. The same value is used
        for tokenization and for the ``DataLoader`` because padding is applied
        per tokenization batch ("longest"), so both batch boundaries must agree
        for the tensors in a loader batch to share one length.

    Returns
    -------
    numpy.ndarray or None
        Per-batch ``logits.sigmoid()`` outputs concatenated along the first
        axis, in input order. ``None`` if the loader yields no batches.
    """
    dataset = Dataset.from_pandas(pd.DataFrame(sample, columns=['text']))
    dataset = dataset.map(tokenization, batched=True, batch_size=batch_size,
                          remove_columns=['text'])
    dataset.set_format('torch', columns=['input_ids', 'attention_mask'])

    loader = DataLoader(dataset, batch_size=batch_size)
    batch_probs = []

    # Inference only: disable autograd so no graph is built or kept per batch.
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            logits = model(input_ids, attention_mask=attention_mask).logits
            batch_probs.append(logits.sigmoid().cpu().numpy())

    if not batch_probs:
        return None
    # Concatenate once at the end instead of re-copying the growing array each step.
    return np.concatenate(batch_probs)

In [24]:
# Score every conversation text in `df` with the BigBird classifier.
# NOTE(review): `df` is not defined in any cell above this one; it must already
# exist in the kernel (with a `text` column) for this cell to run.
bb_probs = m3_predict_probs(df['text'].tolist())

HBox(children=(FloatProgress(value=0.0, max=27421.0), HTML(value='')))




To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  /opt/conda/conda-bld/pytorch_1623448238472/work/aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)
Attention type 'block_sparse' is not possible if sequence_length: 693 <= num global tokens: 2 * config.block_size + min. num sliding tokens: 3 * config.block_size + config.num_random_blocks * config.block_size + additional buffer: config.num_random_blocks * config.block_size = 704 with config.block_size = 64, config.num_random_blocks = 3.Changing attention type to 'original_full'...


In [27]:
# Wrap the probability matrix in a frame that shares `df`'s index; one
# `*_bb` column per model output.
bb_probs_df = pd.DataFrame(
    data=bb_probs,
    index=df.index,
    columns=[
        'desire_bb',
        'intent_bb',
        'capability_bb',
        'timeframe_bb',
        'substance_bb',
        'depressed_bb',
        'self_harm_bb',
        'anxiety_bb',
        'helpful_bb',
    ],
)

In [32]:
# Attach the BigBird probabilities to the existing frame, aligned on the index.
# Any `*_bb` columns already present are dropped first so that re-running this
# cell replaces them instead of appending duplicate columns.
df = pd.concat(
    [df.drop(columns=bb_probs_df.columns, errors='ignore'), bb_probs_df],
    axis=1,
    join='inner',
)

# Use a context manager so the file handle is flushed and closed.
with open('saved/multilabel_text_with_preds.pickle', 'wb') as fh:
    pickle.dump(df, fh)

#### Selecting conversations 

In [4]:
# Reload the frame with labels and predictions written by the cell above.
# NOTE: pickle can execute arbitrary code on load; only open files you created.
with open('saved/multilabel_text_with_preds.pickle', 'rb') as fh:
    df = pickle.load(fh)

In [5]:
# Every column except the first one, thresholded at 0.5 into booleans.
non_first_columns = df.columns[1:]
no_text = df[non_first_columns].gt(0.5)
no_text

Unnamed: 0_level_0,desire,intent,capability,timeframe,substance,depressed,self_harm,anxiety,helpful,desire_rf,...,helpful_rf,desire_bb,intent_bb,capability_bb,timeframe_bb,substance_bb,depressed_bb,self_harm_bb,anxiety_bb,helpful_bb
conversation_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
000087ec0f14337a6124ed7aa770cee1e29fcc9be90ff4c9444544c3b1d2ca48,False,False,False,False,False,False,False,True,True,False,...,True,False,False,False,False,False,False,False,True,True
000111a50f8f14341da6bd35c5574133cf991fd021774f7dd0b519cb69f00a01,False,False,False,False,False,False,False,False,True,False,...,True,True,False,False,False,False,False,False,False,True
00019c048486c231e7ea91371968c7e4b0349f04d1b1fac72f2c73f427ef99d5,False,False,False,False,False,False,False,False,True,False,...,True,False,False,False,False,False,False,False,False,True
0001bbd8070263101e9a7845ce1f9b38895a35ba29cab1721d0508c7ab077bd7,False,False,False,False,False,False,False,False,False,False,...,True,False,False,False,False,False,False,False,False,False
0002be67f4de3c255e297f52dbafc9649ce3ae5fcb325abf243bdc112f550afb,False,False,False,False,False,False,False,False,True,False,...,True,False,False,False,False,False,True,False,False,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
fff7fc9e8ce4b166533fc67b04fe079cce4327bda6a6ee8760153db0993c3f97,False,False,False,False,False,False,False,True,True,False,...,True,False,False,False,False,False,False,False,True,True
fff856ea88444fb982e9072ab257985996f668ab959b497c103ffba64e15cc2c,False,False,False,False,False,True,False,True,True,False,...,True,False,False,False,False,False,True,False,True,True
fffad4b870833f3cae3410c726f5e921fefae49fe6602f9415e7f2056cd3daae,False,False,False,False,False,False,False,False,True,False,...,True,False,False,False,False,False,False,False,False,True
fffc63d0b0ac2bb66ee4c67df1367215b90a54004eb73268804416a095bc0444,True,True,False,False,False,False,False,True,True,False,...,True,True,False,False,False,False,False,False,True,True


#### Selecting individual conversations

In [7]:
# Count, per conversation, how many of the 9 labels each model got wrong.
# Column layout relied on here (positional, as in the original loop):
#   0-8   reference labels, 9-17 compared as the `rf` predictions,
#   18-26 compared as the `bb` predictions.
# The comparison is vectorised over all rows instead of a per-cell Python loop.
N_LABELS = 9
true_labels = no_text.iloc[:, :N_LABELS].to_numpy()
rf_labels = no_text.iloc[:, N_LABELS:2 * N_LABELS].to_numpy()
bb_labels = no_text.iloc[:, 2 * N_LABELS:3 * N_LABELS].to_numpy()

# Positions are read before the counters are (re)assigned, and the counters
# live after position 26, so re-running this cell gives the same result.
no_text['wrong_bb'] = (true_labels != bb_labels).sum(axis=1)
no_text['wrong_rf'] = (true_labels != rf_labels).sum(axis=1)

HBox(children=(FloatProgress(value=0.0, max=54842.0), HTML(value='')))




#### Many wrongly predicted labels

In [5]:
# Look up the text of one specific conversation by its id.
text = df.loc[
    '1a199f93fcd83ccc410c12957ca213c9c95cfb1f6ed93341f5c6ead288fc01e9',
    'text',
]

In [8]:
# Conversations where exactly five labels differ from the `bb` predictions.
no_text.loc[no_text['wrong_bb'].eq(5)]

Unnamed: 0_level_0,desire,intent,capability,timeframe,substance,depressed,self_harm,anxiety,helpful,desire_rf,...,intent_bb,capability_bb,timeframe_bb,substance_bb,depressed_bb,self_harm_bb,anxiety_bb,helpful_bb,wrong_bb,wrong_rf
conversation_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0278a617c708b238fb878b5ee945d31b0970d655e93e801bda14e5dcc9997ba4,False,False,False,False,False,False,True,False,False,False,...,True,True,True,False,False,False,False,False,5,2
063d510135e062d5082a0937df1c79c38ef486d283c7d35157a7ea04e17e7d6d,False,False,False,False,False,False,False,True,True,False,...,True,True,False,False,False,False,False,False,5,1
0d14e3028605cb6e209505ea0670b7211706a36a3d647bbc7f944492f18d8087,True,False,False,False,False,True,True,False,False,False,...,True,True,False,False,False,False,False,True,5,4
0ea08e6982c8aff1580cba5b809c23210e6ef3f43d55cd68b7f91102f368a149,False,False,False,False,False,True,False,False,True,False,...,True,True,False,False,False,True,False,True,5,1
1568c3fcff956ba1b707f4b913f1578edc3bb88e76bd81a89a19d82498921f67,True,True,True,True,False,False,False,False,True,False,...,False,False,False,False,False,True,False,False,5,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
fa688e653d9ec1c83f92f28e08ef3806696ab6a2b2abd8321a74707ff0c04eb7,True,True,True,True,False,False,False,False,False,False,...,False,False,False,False,True,False,True,False,5,5
faa9e3dda0c7e349885a1dd0fed4ba1db966fb8a9835e64057f6d767bc538ed7,False,False,False,False,False,False,False,False,True,False,...,True,True,True,False,False,False,False,False,5,0
fbf35f84829f73e54e2adaac4cfb01ad9ea66ac9de2dff1dc9d081f39fc77037,True,True,True,True,False,True,False,False,True,False,...,False,False,False,False,False,False,False,False,5,5
feda078f509317e658efd6035859ba01593740a358eb42a8eb9b292ba91af4f3,False,False,False,False,False,True,False,True,True,False,...,True,True,False,False,False,False,False,True,5,2


In [171]:
# Save the text of one hand-picked conversation as sample 7.
bb_fp = df.loc['af30982a452168b259508a70b7a89382afc7a3949f8003cd1c4095862b7e2271', 'text']
# Context manager ensures the file is flushed and closed.
with open("saved/text_samples/sample7.pickle", "wb") as fh:
    pickle.dump(bb_fp, fh)

In [172]:
# Save the text of one hand-picked conversation as sample 8.
bb_fp2 = df.loc['1d23d1780ab2e03aee8a38c93852bbbcb837fde18b917a25becd80f18698198d', 'text']
# Context manager ensures the file is flushed and closed.
with open("saved/text_samples/sample8.pickle", "wb") as fh:
    pickle.dump(bb_fp2, fh)

#### All correctly predicted labels

In [9]:
# Conversations where every label matches the `bb` predictions.
no_text.loc[no_text['wrong_bb'].eq(0)]

Unnamed: 0_level_0,desire,intent,capability,timeframe,substance,depressed,self_harm,anxiety,helpful,desire_rf,...,intent_bb,capability_bb,timeframe_bb,substance_bb,depressed_bb,self_harm_bb,anxiety_bb,helpful_bb,wrong_bb,wrong_rf
conversation_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
000087ec0f14337a6124ed7aa770cee1e29fcc9be90ff4c9444544c3b1d2ca48,False,False,False,False,False,False,False,True,True,False,...,False,False,False,False,False,False,True,True,0,1
00019c048486c231e7ea91371968c7e4b0349f04d1b1fac72f2c73f427ef99d5,False,False,False,False,False,False,False,False,True,False,...,False,False,False,False,False,False,False,True,0,0
0001bbd8070263101e9a7845ce1f9b38895a35ba29cab1721d0508c7ab077bd7,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,0,1
0003913c5c6bacb2e1b5711feaf76cb6598f3074fae7956a80bc1352e0656ef4,False,False,False,False,False,True,False,True,True,False,...,False,False,False,False,True,False,True,True,0,2
0004daf966198dfa1885ee66496037f3afdb928bc2a654985702e91c9950bebc,False,False,False,False,False,False,False,False,True,False,...,False,False,False,False,False,False,False,True,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
fff63c45041982c732ba2f3635f977310edfc1fce00dca086a2e35b60ae96e40,False,False,False,False,False,True,False,False,True,False,...,False,False,False,False,True,False,False,True,0,1
fff7b6e089337e63c654a13b65a699012bf5968ee5dbd75d18036811affe19da,False,False,False,False,False,False,False,False,True,False,...,False,False,False,False,False,False,False,True,0,0
fff7fc9e8ce4b166533fc67b04fe079cce4327bda6a6ee8760153db0993c3f97,False,False,False,False,False,False,False,True,True,False,...,False,False,False,False,False,False,True,True,0,1
fff856ea88444fb982e9072ab257985996f668ab959b497c103ffba64e15cc2c,False,False,False,False,False,True,False,True,True,False,...,False,False,False,False,True,False,True,True,0,2


In [179]:
# Save the text of one hand-picked conversation as sample 9.
rf_fn = df.loc['28613325d87170b554f2055c3c6213c22afa2d204974bdc24342e9f094723c85', 'text']
# Context manager ensures the file is flushed and closed.
with open('saved/text_samples/sample9.pickle', 'wb') as fh:
    pickle.dump(rf_fn, fh)

## Texts for Aggregate Explanations

In [6]:
def get_samples(df, per_group, random_state=42):
    """Draw ``per_group`` rows from every distinct combination of column values.

    The frame is grouped by *all* of its columns, ``per_group`` rows are
    sampled from each group with the same ``random_state``, and the group
    keys that ``apply`` adds to the index are removed again so the result is
    indexed like ``df``.
    """
    label_columns = list(df.columns)

    def _draw(group):
        # Each group is sampled with the same seed.
        return group.sample(per_group, random_state=random_state)

    grouped = df.groupby(label_columns)
    drawn = grouped.apply(_draw)
    return drawn.reset_index(level=label_columns, drop=True)

#### Substance abuse

In [6]:
# Sample 100 conversations whose `substance` label is positive and matches
# the `substance_bb` prediction, and save their texts.
random.seed(42)
correct_pred = no_text.query("substance == substance_bb")
sample_index = random.sample(list(correct_pred[correct_pred.substance==1].index), 100)
# `.loc` with a list of columns already returns a DataFrame.
labelled_text = df.loc[sample_index, ['text']]
# Context manager ensures the file is flushed and closed.
with open("saved/substance/100_text_samples.pickle", "wb") as fh:
    pickle.dump(labelled_text, fh)

#### Depression and Anxiety and Suicidal Desire

In [60]:
# Keep conversations where all three labels agree with the `bb` predictions,
# then take 12 per observed (depressed, anxiety, desire) combination.
correct_pred = no_text.query('depressed == depressed_bb and anxiety == anxiety_bb and desire == desire_bb')
samples = get_samples(
    correct_pred[['depressed', 'anxiety', 'desire']],
    per_group=12,
)
# Attach the text of each sampled conversation (inner join on the index).
labelled_text = pd.concat(objs=[samples, df['text']], axis='columns', join='inner')

In [63]:
# Persist the current `labelled_text` frame.
# NOTE(review): `labelled_text` is reassigned by several cells in this notebook;
# make sure the cell directly above was the last one to set it before saving.
with open("saved/anxiety/96_text_samples.pickle", "wb") as fh:
    pickle.dump(labelled_text, fh)

#### Suicidal capability risk

In [14]:
# Sample 100 conversation ids from `capability` and save their texts.
# NOTE(review): `capability` is not defined in any visible cell of this
# notebook; it must already exist in the kernel for this cell to run.
random.seed(42)
cap_index = random.sample(list(capability.index), 100)
# `.loc` with a list of columns already returns a DataFrame.
cap_samples = df.loc[cap_index, ['text']]
# Context manager ensures the file is flushed and closed.
with open("saved/capability/100_text_samples.pickle", "wb") as fh:
    pickle.dump(cap_samples, fh)

In [17]:
# Sample 102 conversation ids from `timeframe`, remove any that were already
# drawn for `cap_samples`, and save the remaining texts.
# NOTE(review): `timeframe` is not defined in any visible cell of this notebook.
# NOTE(review): no seed is set in this cell, so the draw depends on the RNG
# state left behind by whichever cells ran before it.
tf_index = random.sample(list(timeframe.index), 102)
tf_samples = df.loc[tf_index, ['text']]

# Ids present in both sample sets.
overlap = tf_samples.index.intersection(cap_samples.index)
tf_samples = tf_samples.drop(index=overlap)

# NOTE(review): the file name says 100, but the number of rows saved is
# 102 minus the size of the overlap, which is not guaranteed to be 100.
with open("saved/timeframe/100_text_samples.pickle", "wb") as fh:
    pickle.dump(tf_samples, fh)

In [14]:
# Keep conversations where all four risk labels agree with the `bb`
# predictions, then take 20 per observed label combination.
suicide = no_text.query("capability==capability_bb and timeframe==timeframe_bb and desire==desire_bb and intent==intent_bb")
samples = get_samples(
    suicide[['desire', 'intent', 'capability', 'timeframe']],
    per_group=20,
)
# Attach the text of each sampled conversation (inner join on the index).
labelled_text = pd.concat(objs=[samples, df['text']], axis='columns', join='inner')

In [21]:
# Count rows per combination of all columns except the last one.
labelled_text.iloc[:, :-1].value_counts()

desire  intent  capability  timeframe
False   False   False       False        20
True    False   False       False        20
        True    False       False        20
                True        False        20
                            True         20
dtype: int64

In [82]:
# Persist the current `labelled_text` frame.
# NOTE(review): `labelled_text` is reassigned by several cells in this notebook;
# make sure the intended cell was the last one to set it before saving.
with open("saved/suicide/100_text_samples.pickle", "wb") as fh:
    pickle.dump(labelled_text, fh)

#### Helpfulness

In [6]:
# Rows where label and `bb` prediction are both True / both False for `helpful`.
is_helpful = no_text['helpful']
is_helpful_bb = no_text['helpful_bb']
helpful = no_text.loc[is_helpful & is_helpful_bb]
unhelpful = no_text.loc[~(is_helpful | is_helpful_bb)]

In [19]:
# Sample 50 conversations from `helpful` and 50 from `unhelpful`, then save
# their text together with the `helpful` column.
random.seed(42)
indexes = random.sample(list(helpful.index), 50)
indexes.extend(random.sample(list(unhelpful.index), 50))
# `.loc` with a list of columns already returns a DataFrame.
samples = df.loc[indexes, ['text', 'helpful']]
# Context manager ensures the file is flushed and closed.
with open("saved/helpful/100_text_samples.pickle", "wb") as fh:
    pickle.dump(samples, fh)