In [1]:
import torch
import numpy as np
import random
from types import SimpleNamespace as Namespace
from typing import Tuple
from config import config_dict as cfg
import pandas as pd
cfg = Namespace(**cfg)
from model_bert import TSCModel_PL
from train_apply import load_csvs, load_data
from inference import load_data_for_inference
from transformers import BertTokenizer
from captum.attr import IntegratedGradients, LayerIntegratedGradients
from captum.attr import visualization as viz
# The attribution cells run on the CPU; the batch-inference cells further down
# rebind `device` to CUDA when it is available.
device = torch.device("cpu")

# Seed every random number generator in use (Python, NumPy, PyTorch) with the
# same value and request deterministic cuDNN kernels so reruns are repeatable.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True

In [2]:

# NOTE(review): absolute, machine-specific path -- the notebook only runs on a
# machine where this checkpoint exists at exactly this location.
checkpoint_path = "/home/david/TU-Kurse/WS22_23/PML/BERT-TSC/lightning_logs/cluster/difference_weighting_3ep/checkpoints/epoch=2-step=19947.ckpt"
# torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=device)
pretrained_state_dict = checkpoint["state_dict"]

# Rebuild the model from the config and restore the trained weights.
model = TSCModel_PL(cfg)
model.load_pretrained_weights_whole_model(pretrained_state_dict)
model = model.to(device)
# A freshly constructed torch module is in training mode; switch to eval mode so
# that train-only layers (e.g. dropout) do not perturb predictions/attributions.
model.eval()

tokenizer : BertTokenizer = BertTokenizer.from_pretrained("bert-base-cased")

# Integrated gradients are taken with respect to the backbone's embedding layer.
lig = LayerIntegratedGradients(model, model.backbone.embeddings)

# Token ids used to build the attribution baseline: the text is replaced by
# padding tokens while [CLS] and [SEP] are kept in place.
ref_token_id = tokenizer.pad_token_id 
sep_token_id = tokenizer.sep_token_id 
cls_token_id = tokenizer.cls_token_id 


In [7]:
def get_toxic_examples(df: pd.DataFrame):
    """Return the rows of ``df`` whose values from the third column onwards sum to more than zero.

    The first two columns are skipped (in this notebook they are ``id`` and
    ``comment_text``); the remaining columns are summed per row and only rows
    with a strictly positive total are kept. The original index is preserved.
    """
    label_totals = df.iloc[:, 2:].sum(axis=1)
    has_positive_total = label_totals.gt(0)
    return df.loc[has_positive_total]

def confidence(p: torch.Tensor) -> torch.Tensor:
    """Map probabilities to the confidence of the implied binary decision.

    Entries that are at least 0.5 are kept as ``p``; entries below 0.5 are
    replaced by ``1 - p``. The result has the same shape as ``p``.
    """
    is_positive = p >= 0.5
    positive_part = is_positive * p
    negative_part = (p < 0.5) * (1 - p)
    return positive_part + negative_part

# adapted from Captum Tutorial: https://captum.ai/tutorials/Bert_SQUAD_Interpret
def construct_input_ref_pair(text: str):
    """Build the model inputs for ``text`` and a matching attribution baseline.

    The text is tokenized without special tokens, truncated, and wrapped as
    ``[CLS] text [SEP]``. The baseline has the same length but every text token
    is replaced by the padding token, while [CLS] and [SEP] stay in place.

    Args:
        text: Raw sentence/comment to encode.

    Returns:
        A pair ``(input_dict, ref_dict)``. Each dict has the keys
        ``"input_ids"``, ``"token_type_ids"`` and ``"attention_mask"``, holding
        integer tensors of shape ``(1, seq_len)`` on the global ``device``.
    """
    # Truncation has to be switched on for max_length to take effect; otherwise
    # long comments are encoded in full. Two positions are reserved for the
    # [CLS]/[SEP] tokens added below, so the final sequence stays within
    # cfg.truncate_seq_len (assumes that value is an int counting special
    # tokens -- TODO confirm against the training pipeline).
    max_text_tokens = cfg.truncate_seq_len - 2
    text_ids = tokenizer.encode(
        text,
        add_special_tokens=False,
        truncation=True,
        max_length=max_text_tokens,
    )

    # construct input ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    input_ids = torch.tensor([input_ids], device=device, dtype=torch.int)
    input_token_type_ids = torch.zeros_like(input_ids, device=device, dtype=torch.int)
    input_attention_mask = torch.ones_like(input_ids, device=device, dtype=torch.int)

    # construct ref ids
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]
    ref_input_ids = torch.tensor([ref_input_ids], device=device, dtype=torch.int)
    ref_token_type_ids = torch.zeros_like(ref_input_ids, device=device, dtype=torch.int)
    # Derived from the reference ids themselves (same shape as the input ids).
    ref_attention_mask = torch.ones_like(ref_input_ids, device=device, dtype=torch.int)

    input_dict = {
        "input_ids": input_ids,
        "token_type_ids": input_token_type_ids,
        "attention_mask": input_attention_mask,
    }
    ref_dict = {
        "input_ids": ref_input_ids,
        "token_type_ids": ref_token_type_ids,
        "attention_mask": ref_attention_mask,
    }
    return input_dict, ref_dict


def predict(model:TSCModel_PL, sentence: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the model on one sentence and return per-target probabilities and confidences.

    Args:
        model: Model called as ``model(input_ids, token_type_ids, attention_mask)``;
            the first element of its output is treated as the logits of the
            single encoded sentence.
        sentence: Raw text to classify.

    Returns:
        ``(pred, conf)`` where ``pred`` is the sigmoid of the logits and
        ``conf`` is ``confidence(pred)``.
    """
    # Only the real input is needed here; the attribution baseline is discarded.
    input_dict, _ = construct_input_ref_pair(sentence)

    # Evaluate in eval mode so train-only layers such as dropout cannot make
    # the prediction stochastic, then restore whatever mode the caller had set.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(
                input_dict["input_ids"],
                input_dict["token_type_ids"],
                input_dict["attention_mask"],
            )[0]
            pred = torch.sigmoid(logits)
            conf = confidence(pred)
    finally:
        model.train(was_training)
    return pred, conf

def attribute(lig: "LayerIntegratedGradients", input_dict, ref_dict, n_targets: int = 6, n_steps: int = 50):
    """Compute per-token layer integrated gradients for every output target.

    Args:
        lig: Attribution object exposing ``attribute(inputs=..., baselines=...,
            additional_forward_args=..., target=..., return_convergence_delta=...,
            n_steps=...)`` and returning ``(attributions, delta)``.
        input_dict: Dict with ``"input_ids"``, ``"token_type_ids"`` and
            ``"attention_mask"`` for the sentence being explained.
        ref_dict: Dict whose ``"input_ids"`` entry is the attribution baseline.
        n_targets: Number of output targets to attribute (indices ``0..n_targets-1``).
        n_steps: Number of integration steps passed to ``lig.attribute``.

    Returns:
        ``(attributions, deltas)``: two tensors stacked along a new leading
        target dimension. Each attribution row is the attribution summed over
        the last (embedding) dimension with the batch dimension squeezed out.

    Raises:
        ValueError: If ``n_targets`` is smaller than 1.
    """
    if n_targets < 1:
        raise ValueError(f"n_targets must be at least 1, got {n_targets}")

    # The same auxiliary inputs are forwarded unchanged for every target.
    extra_forward_args = (input_dict["token_type_ids"], input_dict["attention_mask"])

    res_attr, res_delta = [], []
    for target in range(n_targets):
        attributions, delta = lig.attribute(
            inputs=input_dict["input_ids"],
            baselines=ref_dict["input_ids"],
            additional_forward_args=extra_forward_args,
            target=target,
            return_convergence_delta=True,
            n_steps=n_steps,
        )
        # Collapse the embedding dimension to one score per token and drop the batch dim.
        attributions_sum = attributions.sum(dim=-1).squeeze(0)
        res_attr.append(attributions_sum.detach().cpu())
        res_delta.append(delta.detach().cpu())
    return torch.stack(res_attr), torch.stack(res_delta)


def visualize(attributions, deltas, predictions, raw_tokens, labels):
    """Print the tokens and render one Captum text visualization per target.

    Args:
        attributions: Per-target token attributions, indexed as ``attributions[target]``.
        deltas: Per-target convergence deltas, indexed as ``deltas[target]``.
        predictions: Per-target predicted probabilities.
        raw_tokens: Token strings of the explained sentence (including special tokens).
        labels: Per-target ground-truth labels.
    """
    print(raw_tokens)
    # One record per attributed target; the count follows the attributions
    # that were actually computed rather than a hard-coded number.
    for target in range(len(attributions)):
        print(f"Target: {cfg.label_tags[target]}")
        attribution = attributions[target]
        prediction = predictions[target]
        # The attribution is reported for the class the model predicted.
        predicted_class = int(prediction > 0.5)
        pred_vis = viz.VisualizationDataRecord(
            word_attributions=attribution,
            pred_prob=prediction,
            pred_class=predicted_class,
            true_class=labels[target],
            attr_class=predicted_class,
            attr_score=attribution.sum(),
            raw_input_ids=raw_tokens,
            convergence_score=deltas[target],
        )
        viz.visualize_text([pred_vis])

def explain(model: TSCModel_PL, sentence: str, labels):
    """Predict, attribute and visualize a single sentence.

    Args:
        model: Model used for the prediction step. Note that the attributions
            come from the module-level ``lig`` object, i.e. from whichever
            model that object was constructed with.
        sentence: Raw text to explain.
        labels: Per-target ground-truth labels shown next to the predictions.
    """
    # No enclosing torch.no_grad(): integrated gradients need gradients, and
    # predict() already disables them for its own forward pass.
    input_dict, ref_dict = construct_input_ref_pair(sentence)
    raw_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].tolist()[0])
    predictions, _ = predict(model, sentence)
    attributions, deltas = attribute(lig, input_dict, ref_dict)
    visualize(attributions, deltas, predictions, raw_tokens, labels)

In [4]:
# Example sentences taken from the sentence templates at
# https://github.com/conversationai/unintended-ml-bias-analysis/blob/main/sentence_templates/en_sentence_templates.csv
bert_worst_fp = 'You are a kind gay'
bert_worst_fn =  'Anastasiya is a horrible younger'
bert_aug_worst_fp =  'Mohammed is a neat muslim'
bert_aug_worst_fn = 'Hiro is a vile asian'

# Each case pairs a sentence with its expected labels (one 0/1 flag per target).
template_cases = [
    (bert_worst_fp, [0, 0, 0, 0, 0, 0]),
    (bert_worst_fn, [1, 0, 0, 0, 0, 0]),
    (bert_aug_worst_fp, [0, 0, 0, 0, 0, 0]),
    (bert_aug_worst_fn, [1, 0, 0, 0, 0, 1]),
]
for template_sentence, template_labels in template_cases:
    explain(model, template_sentence, labels=template_labels)


['[CLS]', 'You', 'are', 'a', 'kind', 'gay', '[SEP]']
Target: toxic


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,1 (0.97),1.0,11.64,[CLS] You are a kind gay [SEP]
,,,,


Target: severe_toxic


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.04),0.0,9.81,[CLS] You are a kind gay [SEP]
,,,,


Target: obscene


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.16),0.0,7.75,[CLS] You are a kind gay [SEP]
,,,,


Target: threat


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),0.0,5.22,[CLS] You are a kind gay [SEP]
,,,,


Target: insult


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,1 (0.67),1.0,11.86,[CLS] You are a kind gay [SEP]
,,,,


Target: identity_hate


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,1 (0.85),1.0,12.79,[CLS] You are a kind gay [SEP]
,,,,


['[CLS]', 'Ana', '##sta', '##si', '##ya', 'is', 'a', 'horrible', 'younger', '[SEP]']
Target: toxic


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,0 (0.00),0.0,1.89,[CLS] Ana ##sta ##si ##ya is a horrible younger [SEP]
,,,,


Target: severe_toxic


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),0.0,-0.91,[CLS] Ana ##sta ##si ##ya is a horrible younger [SEP]
,,,,


Target: obscene


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),0.0,2.23,[CLS] Ana ##sta ##si ##ya is a horrible younger [SEP]
,,,,


Target: threat


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),0.0,0.13,[CLS] Ana ##sta ##si ##ya is a horrible younger [SEP]
,,,,


Target: insult


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),0.0,4.08,[CLS] Ana ##sta ##si ##ya is a horrible younger [SEP]
,,,,


Target: identity_hate


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),0.0,0.52,[CLS] Ana ##sta ##si ##ya is a horrible younger [SEP]
,,,,


['[CLS]', 'Mohammed', 'is', 'a', 'neat', 'm', '##us', '##lim', '[SEP]']
Target: toxic


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,1 (0.71),1.0,9.1,[CLS] Mohammed is a neat m ##us ##lim [SEP]
,,,,


Target: severe_toxic


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.01),0.0,9.02,[CLS] Mohammed is a neat m ##us ##lim [SEP]
,,,,


Target: obscene


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.25),0.0,10.77,[CLS] Mohammed is a neat m ##us ##lim [SEP]
,,,,


Target: threat


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),0.0,4.79,[CLS] Mohammed is a neat m ##us ##lim [SEP]
,,,,


Target: insult


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.23),0.0,10.72,[CLS] Mohammed is a neat m ##us ##lim [SEP]
,,,,


Target: identity_hate


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.18),0.0,8.7,[CLS] Mohammed is a neat m ##us ##lim [SEP]
,,,,


['[CLS]', 'Hi', '##ro', 'is', 'a', 'v', '##ile', 'as', '##ian', '[SEP]']
Target: toxic


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,0 (0.01),0.0,4.92,[CLS] Hi ##ro is a v ##ile as ##ian [SEP]
,,,,


Target: severe_toxic


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),0.0,5.1,[CLS] Hi ##ro is a v ##ile as ##ian [SEP]
,,,,


Target: obscene


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),0.0,6.11,[CLS] Hi ##ro is a v ##ile as ##ian [SEP]
,,,,


Target: threat


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),0.0,1.3,[CLS] Hi ##ro is a v ##ile as ##ian [SEP]
,,,,


Target: insult


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.01),0.0,5.78,[CLS] Hi ##ro is a v ##ile as ##ian [SEP]
,,,,


Target: identity_hate


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,0 (0.00),0.0,3.69,[CLS] Hi ##ro is a v ##ile as ##ian [SEP]
,,,,


In [5]:

# Load both splits and keep the rows that carry at least one positive label.
train, test = load_csvs()
toxic_train, toxic_test = get_toxic_examples(train), get_toxic_examples(test)

# Explain the first two toxic training comments; everything from the third
# column onwards is passed along as the ground-truth labels.
examples = toxic_train.head(2)
for index, example in examples.iterrows():
    explain(model, example["comment_text"], example.iloc[2:].to_list())




['[CLS]', 'CO', '##C', '##KS', '##UC', '##KE', '##R', 'B', '##EF', '##OR', '##E', 'YOU', 'P', '##IS', '##S', 'AR', '##O', '##UN', '##D', 'ON', 'M', '##Y', 'W', '##OR', '##K', '[SEP]']
Target: toxic


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.99),1.0,10.19,[CLS] CO ##C ##KS ##UC ##KE ##R B ##EF ##OR ##E YOU P ##IS ##S AR ##O ##UN ##D ON M ##Y W ##OR ##K [SEP]
,,,,


Target: severe_toxic


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,0 (0.30),0.0,10.71,[CLS] CO ##C ##KS ##UC ##KE ##R B ##EF ##OR ##E YOU P ##IS ##S AR ##O ##UN ##D ON M ##Y W ##OR ##K [SEP]
,,,,


Target: obscene


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.93),1.0,11.87,[CLS] CO ##C ##KS ##UC ##KE ##R B ##EF ##OR ##E YOU P ##IS ##S AR ##O ##UN ##D ON M ##Y W ##OR ##K [SEP]
,,,,


Target: threat


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),0.0,8.13,[CLS] CO ##C ##KS ##UC ##KE ##R B ##EF ##OR ##E YOU P ##IS ##S AR ##O ##UN ##D ON M ##Y W ##OR ##K [SEP]
,,,,


Target: insult


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.89),1.0,8.98,[CLS] CO ##C ##KS ##UC ##KE ##R B ##EF ##OR ##E YOU P ##IS ##S AR ##O ##UN ##D ON M ##Y W ##OR ##K [SEP]
,,,,


Target: identity_hate


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.24),0.0,5.8,[CLS] CO ##C ##KS ##UC ##KE ##R B ##EF ##OR ##E YOU P ##IS ##S AR ##O ##UN ##D ON M ##Y W ##OR ##K [SEP]
,,,,


Target: toxic


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.53),1.0,2.84,"[CLS] Hey . . . what is it . . @ | talk . What is it . . . an exclusive group of some W ##P T ##AL ##IB ##AN ##S . . . who are good at destroying , self - appointed pu ##rist who GA ##NG UP any one who asks them questions a ##bt their AN ##TI - S ##OC ##IA ##L and DE ##ST ##R ##UC ##TI ##VE ( non ) - contribution at W ##P ? Ask Sit ##yu ##sh to clean up his behavior than issue me non ##sen ##si ##cal warnings . . . [SEP]"
,,,,


Target: severe_toxic


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),0.0,4.46,"[CLS] Hey . . . what is it . . @ | talk . What is it . . . an exclusive group of some W ##P T ##AL ##IB ##AN ##S . . . who are good at destroying , self - appointed pu ##rist who GA ##NG UP any one who asks them questions a ##bt their AN ##TI - S ##OC ##IA ##L and DE ##ST ##R ##UC ##TI ##VE ( non ) - contribution at W ##P ? Ask Sit ##yu ##sh to clean up his behavior than issue me non ##sen ##si ##cal warnings . . . [SEP]"
,,,,


Target: obscene


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.43),0.0,5.24,"[CLS] Hey . . . what is it . . @ | talk . What is it . . . an exclusive group of some W ##P T ##AL ##IB ##AN ##S . . . who are good at destroying , self - appointed pu ##rist who GA ##NG UP any one who asks them questions a ##bt their AN ##TI - S ##OC ##IA ##L and DE ##ST ##R ##UC ##TI ##VE ( non ) - contribution at W ##P ? Ask Sit ##yu ##sh to clean up his behavior than issue me non ##sen ##si ##cal warnings . . . [SEP]"
,,,,


Target: threat


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),0.0,-0.15,"[CLS] Hey . . . what is it . . @ | talk . What is it . . . an exclusive group of some W ##P T ##AL ##IB ##AN ##S . . . who are good at destroying , self - appointed pu ##rist who GA ##NG UP any one who asks them questions a ##bt their AN ##TI - S ##OC ##IA ##L and DE ##ST ##R ##UC ##TI ##VE ( non ) - contribution at W ##P ? Ask Sit ##yu ##sh to clean up his behavior than issue me non ##sen ##si ##cal warnings . . . [SEP]"
,,,,


Target: insult


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.11),0.0,3.35,"[CLS] Hey . . . what is it . . @ | talk . What is it . . . an exclusive group of some W ##P T ##AL ##IB ##AN ##S . . . who are good at destroying , self - appointed pu ##rist who GA ##NG UP any one who asks them questions a ##bt their AN ##TI - S ##OC ##IA ##L and DE ##ST ##R ##UC ##TI ##VE ( non ) - contribution at W ##P ? Ask Sit ##yu ##sh to clean up his behavior than issue me non ##sen ##si ##cal warnings . . . [SEP]"
,,,,


Target: identity_hate


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.01),0.0,3.31,"[CLS] Hey . . . what is it . . @ | talk . What is it . . . an exclusive group of some W ##P T ##AL ##IB ##AN ##S . . . who are good at destroying , self - appointed pu ##rist who GA ##NG UP any one who asks them questions a ##bt their AN ##TI - S ##OC ##IA ##L and DE ##ST ##R ##UC ##TI ##VE ( non ) - contribution at W ##P ? Ask Sit ##yu ##sh to clean up his behavior than issue me non ##sen ##si ##cal warnings . . . [SEP]"
,,,,


In [12]:
# Mean prediction confidence per target over the toxic test comments.
# Inference is moved to the GPU when one is available; construct_input_ref_pair
# reads the module-level `device`, so rebinding it here is what places the
# inputs on the same device as the model.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
predictions, confidences = [], []
print(toxic_test)
# Only the comment text is needed, so iterate that column directly instead of
# materialising every row (and its unused labels) with iterrows().
for sentence in toxic_test["comment_text"]:
    prediction, conf = predict(model, sentence)
    predictions.append(prediction.cpu().detach().numpy())
    confidences.append(conf.cpu().detach().numpy())

# Stack to (n_comments, n_targets) and average over the comments.
confidences = np.stack(confidences)
mean_confidences = confidences.mean(axis=0)
print(mean_confidences)


                     id                                       comment_text  \
8      00091c35fa9d0465  == Arabs are committing genocide in Iraq, but ...   
24     0013fed3aeae76b7  DJ Robinson is gay as hell! he sucks his dick ...   
29     0017d4d47894af05               :Fuck off, you anti-semitic cunt.  |   
38     001d739c97bc2ae4  How dare you vandalize that page about the HMS...   
40     001eff4007dbb65b  ::No, he is an arrogant, self serving, immatur...   
...                 ...                                                ...   
63877  ff91c3d8a3e34398                          NIGEL IS A CRAZY IDIOT!!!   
63935  ffd49b8defd069d0    " \n ::Well, now don't I feel stupid....   ·  "   
63945  ffdf6854b41d9102  ==Fourth Baldrick possibly being cleverer than...   
63962  ffebe90c8d5acaba  " \n\n == IRAN == \n That’s right, Iran. It wa...   
63976  fffac2a094c8e0e2  MEL GIBSON IS A NAZI BITCH WHO MAKES SHITTY MO...   

       toxic  severe_toxic  obscene  threat  insult  identity_h

In [11]:
# Mean prediction confidence per target over the complete test split.
# Inference is moved to the GPU when one is available; construct_input_ref_pair
# reads the module-level `device`, so rebinding it here is what places the
# inputs on the same device as the model.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
all_predictions, all_confidences = [], []
print(test)
# Only the comment text is needed, so iterate that column directly instead of
# materialising every row (and its unused labels) with iterrows().
for sentence in test["comment_text"]:
    prediction, conf = predict(model, sentence)
    all_predictions.append(prediction.cpu().detach().numpy())
    all_confidences.append(conf.cpu().detach().numpy())

# Stack to (n_comments, n_targets) and average over the comments.
all_confidences = np.stack(all_confidences)
all_mean_confidences = all_confidences.mean(axis=0)
print(all_mean_confidences)


                     id                                       comment_text  \
0      0001ea8717f6de06  Thank you for understanding. I think very high...   
1      000247e83dcc1211                   :Dear god this site is horrible.   
2      0002f87b16116a7f  "::: Somebody will invariably try to add Relig...   
3      0003e1cccfd5a40a  " \n\n It says it right there that it IS a typ...   
4      00059ace3e3e9a53  " \n\n == Before adding a new product to the l...   
...                 ...                                                ...   
63973  fff8f64043129fa2  :Jerome, I see you never got around to this…! ...   
63974  fff9d70fe0722906  ==Lucky bastard== \n http://wikimediafoundatio...   
63975  fffa8a11c4378854  ==shame on you all!!!== \n\n You want to speak...   
63976  fffac2a094c8e0e2  MEL GIBSON IS A NAZI BITCH WHO MAKES SHITTY MO...   
63977  fffb5451268fb5ba  " \n\n == Unicorn lair discovery == \n\n Suppo...   

       toxic  severe_toxic  obscene  threat  insult  identity_h