## Context-only category classifier

En esta notebook veamos de hacer un clasificador que sólo use el contexto

In [1]:
"""
Script to train hatespeech classifier
"""
import fire
import torch
import transformers
from transformers import (
    Trainer, TrainingArguments, AutoModelForSequenceClassification, AutoTokenizer
)
from hatedetection import BertForSequenceMultiClassification, load_datasets, extended_hate_categories
from hatedetection.metrics import compute_category_metrics


def load_model_and_tokenizer(model_name, max_length):
    """
    Load model and tokenizer
    """

    model = BertForSequenceMultiClassification.from_pretrained(
        model_name, return_dict=True, num_labels=len(extended_hate_categories)
    )

    model.train()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.model_max_length = max_length

    return model, tokenizer

def tokenize(tokenizer, batch, padding='max_length', truncation=True):
    """
    Apply tokenization

    Arguments:
    ---------

    use_context: boolean (default True)
        Whether to add the context to the
    """

    return tokenizer(batch['context'], padding='max_length', truncation=True)





print("Loading datasets... ", end="")
train_dataset, dev_dataset, test_dataset = load_datasets()

train_dataset = train_dataset.filter(lambda x: x["HATEFUL"] > 0)
dev_dataset = dev_dataset.filter(lambda x: x["HATEFUL"] > 0)
test_dataset = test_dataset.filter(lambda x: x["HATEFUL"] > 0)



Loading datasets... 

HBox(children=(FloatProgress(value=0.0, max=37.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))




Acá cambiamos el train dataset

In [2]:
import pandas as pd
from datasets import Dataset, Value, ClassLabel, Features


df = pd.DataFrame(columns=["context"] + extended_hate_categories)
df.set_index("context", inplace=True)

for example in train_dataset:
    for cat in extended_hate_categories:
        if example[cat] > 0:
            df.loc[example["context"], cat] = 1
    
df.fillna(0, inplace=True)

df.reset_index(inplace=True)

features = Features({
    'context': Value('string'),
})


for cat in extended_hate_categories:
    """
    Set for WOMEN, LGBTI...and also for CALLS
    """
    features[cat] = ClassLabel(num_classes=2, names=["NO", "YES"])
    
train_dataset = Dataset.from_pandas(df, features=features)
train_dataset

Dataset({
    features: ['context', 'CALLS', 'WOMEN', 'LGBTI', 'RACISM', 'CLASS', 'POLITICS', 'DISABLED', 'APPEARANCE', 'CRIMINAL'],
    num_rows: 938
})

In [3]:
print("Done")
max_length = 128
model_name = 'dccuchile/bert-base-spanish-wwm-cased'
device = "cuda" if torch.cuda.is_available() else "cpu"


print("")
print("Loading model and tokenizer... ", end="")
model, tokenizer = load_model_and_tokenizer(model_name, max_length)
print("Done")


Done

Loading model and tokenizer... 

Some weights of the model checkpoint at dccuchile/bert-base-spanish-wwm-cased were not used when initializing BertForSequenceMultiClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertForSequenceMultiClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceMultiClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceMultiClassification were not initialized from the model c

Done


In [4]:
batch_size = 32
eval_batch_size = 16

my_tokenize = lambda batch: tokenize(tokenizer, batch)

train_dataset = train_dataset.map(my_tokenize, batched=True, batch_size=batch_size)
dev_dataset = dev_dataset.map(my_tokenize, batched=True, batch_size=eval_batch_size)
test_dataset = test_dataset.map(my_tokenize, batched=True, batch_size=eval_batch_size)


def format_dataset(dataset):
    def get_category_labels(examples):
        return {'labels': torch.Tensor([examples[cat] for cat in extended_hate_categories])}
    dataset = dataset.map(get_category_labels)
    dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
    return dataset

train_dataset = format_dataset(train_dataset)
dev_dataset = format_dataset(dev_dataset)
test_dataset = format_dataset(test_dataset)

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=87.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=113.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=938.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1387.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1797.0), HTML(value='')))




In [5]:
"""
Finally, train!
"""

epochs = 5
warmup_proportion = 0.1

print("\n"*3, "Training...")

total_steps = (epochs * len(train_dataset)) // batch_size
warmup_steps = int(warmup_proportion * total_steps)
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=eval_batch_size,
    warmup_steps=warmup_steps,
    evaluation_strategy="epoch",
    do_eval=False,
    weight_decay=0.01,
    logging_dir='./logs',
    load_best_model_at_end=True,
    metric_for_best_model="mean_f1",
)

results = []

trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_category_metrics,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
)

trainer.train()








 Training...


Epoch,Training Loss,Validation Loss,Calls F1,Women F1,Lgbti F1,Racism F1,Class F1,Politics F1,Disabled F1,Appearance F1,Criminal F1,Mean F1,Mean Precision,Mean Recall,Runtime,Samples Per Second
1,No log,0.45482,0.576618,0.457567,0.475019,0.728045,0.475019,0.350546,0.546817,0.635844,0.450475,0.521772,0.554399,0.568378,5.9753,232.121
2,No log,0.39081,0.709518,0.771063,0.475019,0.794578,0.617639,0.572413,0.580333,0.663977,0.922859,0.6786,0.698546,0.719235,6.0663,228.64
3,No log,0.368213,0.703577,0.777788,0.698703,0.830264,0.78483,0.627151,0.630644,0.715689,0.922682,0.743481,0.751727,0.784538,5.9617,232.653
4,No log,0.361994,0.691619,0.795058,0.722104,0.777589,0.765629,0.606947,0.603272,0.723078,0.924132,0.734381,0.743162,0.774535,5.9317,233.829
5,No log,0.370474,0.695512,0.781806,0.80251,0.802866,0.730062,0.604266,0.625106,0.729309,0.930079,0.744613,0.73732,0.794893,5.9388,233.547


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.3704743567172922,
 'eval_calls_f1': 0.6955116397903768,
 'eval_women_f1': 0.7818064754688584,
 'eval_lgbti_f1': 0.8025095232151991,
 'eval_racism_f1': 0.8028663563117158,
 'eval_class_f1': 0.73006166390101,
 'eval_politics_f1': 0.6042661347740503,
 'eval_disabled_f1': 0.6251062735378481,
 'eval_appearance_f1': 0.7293086035593406,
 'eval_criminal_f1': 0.9300793573356751,
 'eval_mean_f1': 0.7446129322052002,
 'eval_mean_precision': 0.7373201847076416,
 'eval_mean_recall': 0.7948930263519287,
 'eval_runtime': 5.7201,
 'eval_samples_per_second': 242.478,
 'epoch': 5.0,
 'eval_mem_cpu_alloc_delta': 204441,
 'eval_mem_gpu_alloc_delta': 0,
 'eval_mem_cpu_peaked_delta': 271473,
 'eval_mem_gpu_peaked_delta': 69427200}

In [6]:
import pandas as pd
pd.options.display.max_columns = 40
pd.set_option('display.float_format', lambda x: '%.5f' % x)

df_results = pd.DataFrame([trainer.evaluate(dev_dataset)])

df_results.T

Unnamed: 0,0
eval_loss,0.37047
eval_calls_f1,0.69551
eval_women_f1,0.78181
eval_lgbti_f1,0.80251
eval_racism_f1,0.80287
eval_class_f1,0.73006
eval_politics_f1,0.60427
eval_disabled_f1,0.62511
eval_appearance_f1,0.72931
eval_criminal_f1,0.93008
