## Contextualized model

This time we use a model that **takes** the context into account: each comment (`text`) is paired with its `context` and fed to the model as a two-segment input.

In [1]:
%load_ext autoreload
%autoreload 2

from hatedetection import load_datasets

train_dataset, dev_dataset, test_dataset = load_datasets()


In [2]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Local checkpoint of the context-aware classifier (path relative to this notebook).
model_name = "../models/bert-contextualized-hate-speech-es/"

# Maximum sequence length. The tokenizer falls back to `model_max_length` when it
# pads to 'max_length' / truncates without an explicit `max_length`.
MAX_LENGTH = 256

id2label = {0: 'Not hateful', 1: 'Hateful'}
label2id = {label: idx for idx, label in id2label.items()}

# Pass the label maps at load time so the config is consistent from the start and
# `num_labels` cannot drift out of sync with `id2label`.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    return_dict=True,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

# Inference only: switch the model to evaluation mode (e.g. dropout disabled).
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.model_max_length = MAX_LENGTH

In [3]:
def tokenize(batch, context=True, padding='max_length', truncation=True, tok=None):
    """
    Apply tokenization to a batch of examples.

    Arguments:
    ---------

    batch: dict-like
        Batch with a 'text' column and, when `context` is True, a 'context' column.
        Presumably a `datasets` batch mapping column names to lists of strings -- TODO confirm.

    context: boolean (default True)
        Whether to add the context to the input. When True, the tokenizer receives
        (context, text) as a sentence pair, with the context as the first segment.

    padding: str or bool (default 'max_length')
        Forwarded to the tokenizer's `padding` argument.

    truncation: bool or str (default True)
        Forwarded to the tokenizer's `truncation` argument.

    tok: callable, optional (default None)
        Tokenizer to call; when None, the module-level `tokenizer` is used.

    Returns:
    -------
    Whatever the tokenizer returns for the batch (for a Hugging Face tokenizer, an
    encoding with fields such as 'input_ids' and 'attention_mask').
    """
    if tok is None:
        tok = tokenizer

    if context:
        args = [batch['context'], batch['text']]
    else:
        args = [batch['text']]

    # Honour the caller's padding/truncation choices (they used to be hard-coded here).
    return tok(*args, padding=padding, truncation=truncation)

# `batch_size` is the per-device training batch size (reused by the Trainer below);
# here it also controls how many rows `map` hands to the tokenizer at a time.
batch_size = 32
eval_batch_size = 16


def my_tokenize(batch):
    """Tokenize `batch` as a (context, text) pair."""
    return tokenize(batch, context=True)


train_dataset = train_dataset.map(my_tokenize, batched=True, batch_size=batch_size)
dev_dataset = dev_dataset.map(my_tokenize, batched=True, batch_size=eval_batch_size)
test_dataset = test_dataset.map(my_tokenize, batched=True, batch_size=eval_batch_size)



HBox(children=(FloatProgress(value=0.0, max=1087.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=544.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=690.0), HTML(value='')))




In [4]:
tokenizer.decode(train_dataset["input_ids"][4])

'[CLS] Potenciar Acompañamiento : lanzan un incentivo de [UNK] 8. 500 por mes para jóvenes en recuperación de adicciones [SEP] usuario " Me recupere " - " toma [UNK] " - una semana despues cuando ya se morfaron el chamuyo pintan un par de bolsas del transa que ellos mismos ya conocen. Negocio redondo boloh. Querian que la guita circule? Ahi tene [UNK] [UNK] y encima ni roban para comprar xq ya la tienen en el bolsillo [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [P

In [5]:

def format_dataset(dataset, label_column='HATEFUL',
                   columns=('input_ids', 'token_type_ids', 'attention_mask', 'labels')):
    """
    Add a 'labels' column and expose the model inputs as torch tensors.

    Arguments:
    ---------

    dataset: Dataset
        Tokenized dataset containing `label_column` (presumably a Hugging Face
        `datasets.Dataset` -- TODO confirm).

    label_column: str (default 'HATEFUL')
        Column copied into 'labels', the keyword Hugging Face models accept for targets.

    columns: sequence of str
        Columns returned (as torch tensors) when the formatted dataset is indexed.

    Returns:
    -------
    A new dataset with the 'labels' column added and torch formatting applied.
    """
    # batched=True copies the whole label list per batch instead of invoking the
    # function once per example; the resulting column is identical.
    dataset = dataset.map(lambda examples: {'labels': examples[label_column]}, batched=True)
    dataset.set_format(type='torch', columns=list(columns))
    return dataset

# Apply the same label/format preparation to every split, in train / dev / test order.
train_dataset, dev_dataset, test_dataset = (
    format_dataset(split) for split in (train_dataset, dev_dataset, test_dataset)
)

HBox(children=(FloatProgress(value=0.0, max=34756.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8689.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=11040.0), HTML(value='')))




We only load the model to evaluate it; no training is done in this notebook 🤗

In [6]:
from hatedetection.metrics import compute_hate_metrics
from transformers import Trainer, TrainingArguments

# Only evaluation happens here, so the arguments just carry an output directory and
# the batch sizes defined when tokenizing.
training_args = TrainingArguments(
    per_device_eval_batch_size=eval_batch_size,
    per_device_train_batch_size=batch_size,
    output_dir='./results',
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    compute_metrics=compute_hate_metrics,
)


In [7]:
import pandas as pd

pd.set_option('display.max_columns', 40)
# Render floats with five decimals in the results table.
pd.set_option('display.float_format', lambda value: '%.5f' % value)

# evaluate() returns a flat dict of metrics; a one-row frame, transposed, reads as a
# metric -> value table.
df_results = pd.DataFrame([trainer.evaluate(dev_dataset)])

df_results.T

Unnamed: 0,0
eval_loss,0.67211
eval_accuracy,0.91472
eval_f1,0.83312
eval_precision,0.85251
eval_recall,0.81685
eval_runtime,69.4009
eval_samples_per_second,125.2
init_mem_cpu_alloc_delta,59564.0
init_mem_gpu_alloc_delta,439938560.0
init_mem_cpu_peaked_delta,18258.0
