# Hate Speech Detection - Part III: Testing Multiple Datasets

In this notebook we use the general hate speech model trained in the previous notebook [(here)](https://www.kaggle.com/jessedingley/hatespeech-detection-model#Results), which predicts whether or not a tweet conveys hate speech. With this general model we compute accuracy, F1, precision, recall and the Matthews correlation coefficient (MCC) on three test sets:

- general hate speech test set
- racism test set
- sexism test set

For reference, the same metrics are also computed on the general hate speech training set.

These metrics can then be compared against those of the dedicated racism and sexism models. That comparison is not done in this notebook.
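For reference, these are the standard definitions of the metrics in terms of the binary confusion matrix (TP/FP: true/false positives, TN/FN: true/false negatives). Label 1 is taken as the positive class, which is scikit-learn's default for `average='binary'` and presumably the hate speech label here.

```latex
\begin{aligned}
\text{accuracy}  &= \frac{TP + TN}{TP + TN + FP + FN} \\
\text{precision} &= \frac{TP}{TP + FP}, \qquad \text{recall} = \frac{TP}{TP + FN} \\
F_1              &= \frac{2 \cdot \text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}} \\
\text{MCC}       &= \frac{TP \cdot TN - FP \cdot FN}{\sqrt{(TP + FP)(TP + FN)(TN + FP)(TN + FN)}}
\end{aligned}
```

Accuracy counts both classes, so it can stay high when one class dominates even if few positives are found; F1 and MCC depend more directly on how the positive class is handled.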

# 0. Setup

### Imports

In [1]:
# for GPU use, tensors, etc.
import torch

# import the model for sequence classification and its tokenizer
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# for reading the CSV test files
import csv

# for computing metrics
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, matthews_corrcoef

### Model and Tokenizer setup

We load the hate speech model fine-tuned in the previous notebook, a `distilbert-base-uncased` model with a two-class sequence classification head, together with its tokenizer.


In [2]:
MODEL_NAME = "../input/hatespeech-detection-model/hatespeech_model-distilbert-base-uncased"

# load the fine-tuned model and its tokenizer
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="right")

### GPU

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Data

### Retrieve Test Data
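We also need to specify which data we want (hate / racism / sexism). Note that `DATA` is not used further down: the evaluation loop runs over every entry in `test_sets`.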

In [4]:
DATA = "hate" # ("hate" / "racism" / "sexism")
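Since the cells below evaluate every dataset regardless of `DATA`, here is a minimal sketch of how the variable could actually be used to restrict the evaluation. `select_test_sets` and `VALID_DATA` are hypothetical names, not part of the original notebook; passing `None` keeps every set, which matches what the notebook currently does.

```python
# Hypothetical helper: the original notebook does not filter by DATA.
VALID_DATA = ("hate", "racism", "sexism")


def select_test_sets(test_sets, data=None):
    """Return the subset of `test_sets` to evaluate.

    data=None keeps every entry (the notebook's current behaviour);
    otherwise only the chosen dataset is kept.
    """
    if data is None:
        return dict(test_sets)
    if data not in VALID_DATA:
        raise ValueError(f"DATA must be one of {VALID_DATA}, got {data!r}")
    return {data: test_sets[data]}


# usage with two of the paths defined in the next cell
paths = {
    "hate": "../input/hatespeech-detection-data-bertweet/hatespeech_test.csv",
    "racism": "../input/racismdata/racism_test.csv",
}
print(select_test_sets(paths, "racism"))
```

In the notebook this would be called as `select_test_sets(test_sets, DATA)` before the files are opened.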

In [5]:
test_sets = {"hate": "../input/hatespeech-detection-data-bertweet/hatespeech_test.csv",
             "train": "../input/hatespeech-detection-data-bertweet/hatespeech_train.csv",  # training set, for reference
             "racism": "../input/racismdata/racism_test.csv",
             "sexism": "../input/sexismdata/sexism_test.csv"}

In [6]:
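# open each CSV file and read its rows as dictionaries keyed by column name
# (newline="" is what the csv module expects, so quoted fields containing
# line breaks are read correctly)
unformatted_test_data = {}
for data_type, filename in test_sets.items():
    with open(filename, "r", encoding="utf8", newline="") as f:
        unformatted_test_data[data_type] = [dict(row) for row in csv.DictReader(f, skipinitialspace=True)]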
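To show what `csv.DictReader(..., skipinitialspace=True)` produces, here is a small in-memory example. The column names mirror the racism file (`Text`, `oh_label`), but the rows are invented for illustration only.

```python
import csv
import io

# Toy CSV mimicking the racism test file's columns; rows are made up.
# Note the space after each comma and the quoted field spanning two lines.
raw = (
    'Text, oh_label\n'
    '"first example tweet", 0\n'
    '"a tweet, with a comma", 1\n'
    '"line one\nline two", 0\n'
)

# skipinitialspace=True drops the blank after each delimiter, so the header
# becomes "oh_label" (not " oh_label") and the labels "0"/"1" (not " 0"/" 1").
reader = csv.DictReader(io.StringIO(raw), skipinitialspace=True)
rows = [dict(row) for row in reader]
print(rows)
```

Every value comes back as a string, which is why the labels are converted with `int(...)` further down.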

In [7]:
# map every dataset onto a common {"tweet": ..., "label": ...} format;
# the racism and sexism files use different column names
test_data = {}
for data_type, data in unformatted_test_data.items():
    instances = []
    for row in data:
        if data_type == "racism":
            instances.append({"tweet": row["Text"], "label": row["oh_label"]})
        elif data_type == "sexism":
            instances.append({"tweet": row["text"], "label": row["label"]})
        elif data_type in ("hate", "train"):
            # these files already have "tweet" and "label" columns
            instances.append(row)
        else:
            raise ValueError(f"Data {data_type} not valid")
    test_data[data_type] = instances
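The `if`/`elif` chain above can also be written as a lookup table from dataset name to its (tweet column, label column). This is a sketch under the column names used in the cell above; `COLUMNS` and `normalize_rows` are hypothetical names.

```python
# Hypothetical mapping: dataset name -> (tweet column, label column),
# using the column names from the notebook's own loading cell.
COLUMNS = {
    "hate": ("tweet", "label"),
    "train": ("tweet", "label"),
    "racism": ("Text", "oh_label"),
    "sexism": ("text", "label"),
}


def normalize_rows(data_type, rows):
    """Rename each row's columns to the common {"tweet", "label"} schema."""
    if data_type not in COLUMNS:
        raise ValueError(f"Data {data_type} not valid")
    text_col, label_col = COLUMNS[data_type]
    return [{"tweet": row[text_col], "label": row[label_col]} for row in rows]


example = [{"Text": "some tweet", "oh_label": "1"}]
print(normalize_rows("racism", example))
```

One difference from the original loop: for `hate` and `train` this keeps only the `tweet` and `label` columns instead of the whole row, which is all the following cells use.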

### Split into labels and tweets

In [8]:
test_data_tweets = {}
test_data_labels = {}
for data_type, data in test_data.items():
    test_data_tweets[data_type] = [row["tweet"] for row in data]
    test_data_labels[data_type] = [int(row["label"]) for row in data]
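Before testing, it can help to look at how the labels are distributed in each set, because accuracy is hard to interpret without it: always predicting the majority class already scores an accuracy equal to that class's share. `label_distribution` is a hypothetical helper; the toy labels stand in for, e.g., `test_data_labels["racism"]`.

```python
from collections import Counter


def label_distribution(labels):
    """Return {label: fraction of examples} for a list of integer labels."""
    if not labels:
        return {}
    counts = Counter(labels)
    total = len(labels)
    return {label: counts[label] / total for label in sorted(counts)}


toy_labels = [0, 0, 0, 1]
dist = label_distribution(toy_labels)
print(dist)

# accuracy of a baseline that always predicts the most frequent label
majority_baseline = max(dist.values())
print(majority_baseline)
```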

# 2. Testing

In [9]:
model.to(device)
model.eval()

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
       

### Function to predict output of a tweet

In [10]:
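def get_sent_pred(input_str, device=device):
    # tokenize a single tweet and move the tensors to the model's device
    tok = tokenizer(input_str, return_tensors="pt", truncation=True, padding=True).to(device)
    # forward pass without gradient tracking, then pick the class with the highest logit
    with torch.no_grad():
        pred = model(**tok)
    return pred["logits"].argmax(-1).item()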
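`get_sent_pred` runs the model on one tweet at a time. A common alternative is to tokenize several tweets together with `padding=True`, so each forward pass handles a whole batch. The sketch below keeps the batching logic independent of the model; `batches`, `predict_in_batches` and `predict_batch` are hypothetical names, and the lambda in the usage line is only a stand-in for the model.

```python
def batches(items, batch_size):
    """Yield consecutive slices of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def predict_in_batches(texts, predict_batch, batch_size=32):
    """Apply predict_batch (list of str -> list of int) to texts, batch by batch."""
    preds = []
    for batch in batches(texts, batch_size):
        preds.extend(predict_batch(list(batch)))
    return preds


# With the notebook's objects, predict_batch could look like this (untested sketch):
# def predict_batch(texts):
#     tok = tokenizer(texts, return_tensors="pt", truncation=True, padding=True).to(device)
#     with torch.no_grad():
#         logits = model(**tok)["logits"]
#     return logits.argmax(-1).tolist()

# dummy predictor: label = length of the text modulo 2
print(predict_in_batches(["a", "bb", "ccc"], lambda b: [len(t) % 2 for t in b], batch_size=2))
```

Because the tokenizer also returns an attention mask, the padded positions in a batch are masked out when it is passed to the model via `**tok`.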

### Compute metrics

In [11]:
def compute_metrics_test(y_true, y_pred):
    # average="binary" reports precision, recall and F1 for the positive label (1) only
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary")
    acc = accuracy_score(y_true, y_pred)
    mcc = matthews_corrcoef(y_true, y_pred)
    return {
        "accuracy": acc,
        "f1": f1,
        "precision": precision,
        "recall": recall,
        "mcc": mcc,
    }
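With `average='binary'`, scikit-learn reports precision, recall and F1 for the positive label only (label 1 by default), so these numbers describe how well the model finds positive tweets, not how well it recognises the negative ones. The standard-library sketch below reproduces that per-class computation and contrasts it with a macro average, the unweighted mean of the per-class F1 scores over labels 0 and 1 (what `average='macro'` computes). `per_class_prf` and `binary_and_macro_f1` are hypothetical names; zero denominators are mapped to 0.0.

```python
def per_class_prf(y_true, y_pred, positive):
    """Precision, recall and F1 treating `positive` as the positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def binary_and_macro_f1(y_true, y_pred):
    """F1 for label 1 only (like average='binary') and the mean F1 over labels 0 and 1 (like average='macro')."""
    f1_pos = per_class_prf(y_true, y_pred, positive=1)[2]
    f1_neg = per_class_prf(y_true, y_pred, positive=0)[2]
    return f1_pos, (f1_pos + f1_neg) / 2


y_true = [0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 1, 1, 0]
f1_binary, f1_macro = binary_and_macro_f1(y_true, y_pred)
print(f1_binary, f1_macro)
```

Here the positive class has F1 = 0.5 and the negative class 0.75, so the macro average (0.625) is higher than the binary score: the binary setting isolates the harder, positive class.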

In [12]:
for data_type in test_data_tweets.keys():
    print(f"For {data_type.upper()}: the following results: ")
    print("")
    # predict one tweet at a time, then score the predictions against the gold labels
    preds = [get_sent_pred(sent) for sent in test_data_tweets[data_type]]
    metrics = compute_metrics_test(test_data_labels[data_type], preds)
    for metric, val in metrics.items():
        print(f"{metric}: {val}")
    print("")
    print("------------------------------------------")
    print("")
    print("")

For HATE: the following results: 

accuracy: 0.8631232361241769
f1: 0.7144259077526988
precision: 0.8053097345132744
recall: 0.6419753086419753
mcc: 0.6328721067580543

------------------------------------------


For TRAIN: the following results: 

accuracy: 0.9056948364602357
f1: 0.8009113214160533
precision: 0.8785082660515187
recall: 0.7359098228663447
mcc: 0.7447360569604938

------------------------------------------


For RACISM: the following results: 

accuracy: 0.8164717844433147
f1: 0.44716692189892804
precision: 0.40331491712707185
recall: 0.5017182130584192
mcc: 0.3415970546665962

------------------------------------------


For SEXISM: the following results: 

accuracy: 0.7674581005586593
f1: 0.4126984126984127
precision: 0.5763546798029556
recall: 0.32142857142857145
mcc: 0.3007161675288282

------------------------------------------
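To make the comparison with the racism and sexism models easier, the metrics could be collected in a dictionary (e.g. `all_results[data_type] = metrics` inside the loop) and rendered as a single markdown table. `results_to_markdown` is a hypothetical helper; the usage example reuses the RACISM and SEXISM values printed above.

```python
def results_to_markdown(results, digits=3):
    """Render {dataset: {metric: value}} as a markdown table, one row per dataset."""
    metrics = list(next(iter(results.values())).keys())
    lines = [
        "| dataset | " + " | ".join(metrics) + " |",
        "|---" * (len(metrics) + 1) + "|",
    ]
    for name, values in results.items():
        cells = [f"{values[m]:.{digits}f}" for m in metrics]
        lines.append(f"| {name} | " + " | ".join(cells) + " |")
    return "\n".join(lines)


# values copied from the RACISM and SEXISM output above
results = {
    "racism": {"accuracy": 0.8164717844433147, "f1": 0.44716692189892804, "mcc": 0.3415970546665962},
    "sexism": {"accuracy": 0.7674581005586593, "f1": 0.4126984126984127, "mcc": 0.3007161675288282},
}
table = results_to_markdown(results)
print(table)
```

The same function could render the dedicated models' results, so both tables share one layout.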


