In [1]:
import os
import random

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm, trange

from torchtext.data import Field
from torchtext.vocab import GloVe
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
from transformers import BertTokenizer, BertModel

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import torch.optim as optim
import torch.nn.functional as F
from torchinfo import summary

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report, f1_score

import gdown
from utils import preprocessing
from utils.evaluation import DataSetText, SexismClassifier, infer

# Global plot theme, and the device every tensor/model is moved to:
# the first CUDA GPU when one is visible, otherwise the CPU.
sns.set_style(style='darkgrid')
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Evaluación

## Modelos

Primero descargamos desde Google Drive el modelo ya entrenado (el clasificador de la tarea 1).

In [4]:
# Local folder for the downloaded checkpoint. exist_ok=True keeps this cell
# idempotent: a plain `mkdir models` errors once the folder already exists,
# which breaks "Restart Kernel -> Run All".
os.makedirs('models', exist_ok=True)

In [3]:
# Fine-tuned task-1 classifier checkpoint, hosted on Google Drive.
url = 'https://drive.google.com/uc?id=1V0VbdwXDcFP6f0GrdCna1SQqqkZpBOLW'
output = 'models/sexism-classifier-task1.pt'

# Only fetch the checkpoint when it is not already on disk, so re-running the
# notebook does not download the weights again every time.
if not os.path.exists(output):
    gdown.download(url, output)

In [2]:
# Rebuild the classifier architecture and load the fine-tuned weights into it.
model = SexismClassifier()
# map_location=device remaps the stored tensors onto the evaluation device, so a
# checkpoint saved on a GPU also loads on a CPU-only machine (where `device` is
# the CPU). Without it, torch.load tries to restore the tensors onto CUDA and fails.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source, since this one is downloaded from an external URL.
model.load_state_dict(torch.load('models/sexism-classifier-task1.pt', map_location=device))
model.to(device)
# Evaluation mode: disables the Dropout layers listed in the summary.
model.eval()

summary(model)

Layer (type:depth-idx)                   Param #
├─BertModel: 1-1                         --
|    └─BertEmbeddings: 2-1               --
|    |    └─Embedding: 3-1               81,315,072
|    |    └─Embedding: 3-2               393,216
|    |    └─Embedding: 3-3               1,536
|    |    └─LayerNorm: 3-4               1,536
|    |    └─Dropout: 3-5                 --
|    └─BertEncoder: 2-2                  --
|    |    └─ModuleList: 3-6              85,054,464
|    └─BertPooler: 2-3                   --
|    |    └─Linear: 3-7                  590,592
|    |    └─Tanh: 3-8                    --
├─Dropout: 1-2                           --
├─Sequential: 1-3                        --
|    └─Linear: 2-4                       1,538
|    └─Softmax: 2-5                      --
Total params: 167,357,954
Trainable params: 167,357,954
Non-trainable params: 0

## Datos

In [3]:
# Load the EXIST 2021 test split (tab-separated file).
test_df = pd.read_csv('../../Data/EXIST2021_test.tsv', sep='\t')

# Light text normalisation with the project's preprocessing helper.
test_df['text'] = test_df['text'].apply(preprocessing.preprocess)

# Encode the task-1 labels as integers; an unexpected label raises KeyError.
labels_dict = {'non-sexist': 0, 'sexist': 1}
test_df['label'] = test_df['task1'].map(lambda tag: labels_dict[tag])

# Per-language subsets, used for the language-level breakdown below.
test_df_es = test_df.loc[test_df['language'] == 'es']
test_df_en = test_df.loc[test_df['language'] == 'en']

test_df.head()

Unnamed: 0,test_case,id,source,language,text,task1,task2,label
0,EXIST2021,6978,gab,en,pennsylvania state rep horrifies with opening ...,non-sexist,non-sexist,0
1,EXIST2021,6979,twitter,en,"he sounds like as ass , and very condescending .",non-sexist,non-sexist,0
2,EXIST2021,6980,twitter,en,"lol ! "" this behavior of not letting men tell ...",sexist,ideological-inequality,1
3,EXIST2021,6981,twitter,en,rights ? i mean yeah most women especially the...,sexist,ideological-inequality,1
4,EXIST2021,6982,twitter,en,the jack manifold appreciation i ’ m seeing is...,non-sexist,non-sexist,0


In [4]:
# Wrap the full test set and each language subset in the project Dataset class.
ds_text_test = DataSetText(test_df)
ds_text_test_en = DataSetText(test_df_en)
ds_text_test_es = DataSetText(test_df_es)

# Report the number of examples in each split, one per line.
print('\n'.join(
    f'{split_label}: {len(split_ds)}'
    for split_label, split_ds in (
        ('Test', ds_text_test),
        ('Test en', ds_text_test_en),
        ('Test es', ds_text_test_es),
    )
))

Test: 4368
Test en: 2208
Test es: 2160


In [5]:
BATCH_SIZE = 8
NUM_WORKERS = 4


def make_eval_loader(dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS):
    """Build a DataLoader that iterates ``dataset`` for evaluation.

    Args:
        dataset: the Dataset to iterate (here, a ``DataSetText``).
        batch_size: number of examples per batch.
        num_workers: number of worker processes used for loading.

    Returns:
        A ``torch.utils.data.DataLoader`` that visits ``dataset`` in its
        natural order.

    Evaluation uses ``shuffle=False``. Shuffling does not change the metrics,
    but it makes the order of the returned labels/predictions differ from run
    to run, so they can no longer be matched back to rows of the DataFrame.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers)


test_dl = make_eval_loader(ds_text_test)
test_en_dl = make_eval_loader(ds_text_test_en)
test_es_dl = make_eval_loader(ds_text_test_es)

## Rendimiento

In [6]:
# Run the fine-tuned model over the full test set and over each language subset.
# `infer` comes from utils.evaluation; judging by how its outputs are used below,
# it presumably returns (gold labels, predicted labels) for the whole loader.
# TODO confirm against utils.evaluation.infer.
# %time prints the CPU and wall-clock cost of each pass.
%time y_test, y_pred = infer(model, test_dl)
%time y_test_en, y_pred_en = infer(model, test_en_dl)
%time y_test_es, y_pred_es = infer(model, test_es_dl)

100%|██████████| 546/546 [00:21<00:00, 25.89it/s]
  0%|          | 0/276 [00:00<?, ?it/s]

CPU times: user 20.6 s, sys: 320 ms, total: 20.9 s
Wall time: 21.1 s


100%|██████████| 276/276 [00:10<00:00, 25.62it/s]
  0%|          | 0/270 [00:00<?, ?it/s]

CPU times: user 10.4 s, sys: 233 ms, total: 10.7 s
Wall time: 10.8 s


100%|██████████| 270/270 [00:10<00:00, 25.40it/s]

CPU times: user 10.3 s, sys: 237 ms, total: 10.5 s
Wall time: 10.6 s





Sobre el conjunto de prueba completo (inglés y español) se obtiene:

In [7]:
# Per-class precision / recall / F1 on the whole test set (English + Spanish).
print(
    classification_report(
        y_test,
        y_pred,
        target_names=['non-sexist', 'sexist'],
    )
)

# Headline numbers as percentages rounded to four decimals;
# the reported "F1 score" is the macro average over both classes.
print(f'Accuracy: {round(accuracy_score(y_test, y_pred) * 100, 4)}')
print(f'F1 score: {round(f1_score(y_test, y_pred, average="macro") * 100, 4)}')

              precision    recall  f1-score   support

  non-sexist       0.75      0.73      0.74      2087
      sexist       0.76      0.77      0.77      2281

    accuracy                           0.75      4368
   macro avg       0.75      0.75      0.75      4368
weighted avg       0.75      0.75      0.75      4368

Accuracy: 75.2747
F1 score: 75.1999


Sobre el subconjunto de textos en inglés se obtiene:

In [8]:
# Per-class precision / recall / F1 on the English subset.
print(
    classification_report(
        y_test_en,
        y_pred_en,
        target_names=['non-sexist', 'sexist'],
    )
)

# Headline numbers as percentages rounded to four decimals;
# the reported "F1 score" is the macro average over both classes.
print(f'Accuracy: {round(accuracy_score(y_test_en, y_pred_en) * 100, 4)}')
print(f'F1 score: {round(f1_score(y_test_en, y_pred_en, average="macro") * 100, 4)}')

              precision    recall  f1-score   support

  non-sexist       0.77      0.68      0.72      1050
      sexist       0.74      0.81      0.77      1158

    accuracy                           0.75      2208
   macro avg       0.75      0.75      0.75      2208
weighted avg       0.75      0.75      0.75      2208

Accuracy: 74.8641
F1 score: 74.5908


Sobre el subconjunto de textos en español se obtiene:

In [9]:
# Per-class precision / recall / F1 on the Spanish subset.
print(
    classification_report(
        y_test_es,
        y_pred_es,
        target_names=['non-sexist', 'sexist'],
    )
)

# Headline numbers as percentages rounded to four decimals;
# the reported "F1 score" is the macro average over both classes.
print(f'Accuracy: {round(accuracy_score(y_test_es, y_pred_es) * 100, 4)}')
print(f'F1 score: {round(f1_score(y_test_es, y_pred_es, average="macro") * 100, 4)}')

              precision    recall  f1-score   support

  non-sexist       0.73      0.78      0.76      1037
      sexist       0.79      0.73      0.76      1123

    accuracy                           0.76      2160
   macro avg       0.76      0.76      0.76      2160
weighted avg       0.76      0.76      0.76      2160

Accuracy: 75.6944
F1 score: 75.6938
