# Tutorial 4 - CNN para la clasificación de textos

## 1. Prepración del Dataset para extraer la categoría de las noticias

Tenemos a nuestra disposición el dataset de **CNN Chile** (16.472 noticias).

El dataset toma la forma de archivo CSV con la estructura siguiente:
- ID, country, media_outlet, url, title, body, date

En un inicio, a partir de la URL vamos a extraer la categoría de la noticia.

In [1]:
import pandas as pd

DATASET_CSV="../datasets/CNNCHILE_RAW.csv"

df = pd.read_csv(DATASET_CSV,sep=',',error_bad_lines=False)
df = df.drop(['Unnamed: 0'], axis = 1) # Para suprimir la columna ID
df['date'] = pd.to_datetime(df['date']) # Para convertir la columna date en formato datetime

df

Unnamed: 0,country,media_outlet,url,title,body,date
0,chile,cnnchile,https://www.cnnchile.com/pais/caso-ambar-confi...,Caso Ámbar: Fiscalía confirma que cadáver fue ...,La Fiscalía confirmó este jueves el hallazgo d...,2020-08-06
1,chile,cnnchile,https://www.cnnchile.com/pais/parlamentarios-b...,Parlamentarios latinoamericanos piden a Bachel...,Un grupo de parlamentarios de distintas nacion...,2020-08-06
2,chile,cnnchile,https://www.cnnchile.com/pais/caso-ambar-detie...,Caso Ámbar: Detienen a la madre y su pareja po...,La Policía de Investigaciones (PDI) de Villa A...,2020-08-06
3,chile,cnnchile,https://www.cnnchile.com/pais/diputados-rn-pro...,Diputados RN presentan proyecto para regular r...,(Agencia Uno) – Luego de jurar y hacer oficial...,2020-08-06
4,chile,cnnchile,https://www.cnnchile.com/pais/diputado-mellado...,Mellado (RN) por crisis en La Araucanía: “¿Qui...,(Agencia Uno) – Cómo “insólito” calificó el di...,2020-08-06
5,chile,cnnchile,https://www.cnnchile.com/pais/providencia-plan...,Multas de hasta $50 millones: Providencia prep...,"(Agencia UNO) – La alcaldesa de Providencia, E...",2020-08-06
6,chile,cnnchile,https://www.cnnchile.com/pais/subsecretario-zu...,Subsecretario Zúñiga: “En lugar de ir a un mal...,"Durante la mañana de este jueves, el ministro ...",2020-08-06
7,chile,cnnchile,https://www.cnnchile.com/pais/formalizan-prime...,Formalizan el primer embargo a ex presidente E...,(Agencia UNO) – El ministro de fuero de la Cor...,2020-08-06
8,chile,cnnchile,https://www.cnnchile.com/pais/revocan-prision-...,Corte de Apelaciones revoca prisión preventiva...,"Este jueves, la Corte de Apelaciones de Santia...",2020-08-06
9,chile,cnnchile,https://www.cnnchile.com/pais/bellolio-impuest...,Bellolio por impuesto a los súper ricos: “Esta...,"El vocero de gobierno, Jaime Bellolio, abordó ...",2020-08-06


In [2]:
import re

for index, row in df.iterrows():
    url=row['url']
    obj = re.findall('(\w+)://([\w\-\.]+)/([\w\-]+).([\w\-]+)', url) 
    
    category=obj[0][2]
    
    df.loc[index,'category'] = category

- ¿Cuáles son las categorias?

In [None]:
#!pip install --user pandasql

In [10]:
from pandasql import sqldf

In [11]:
q="""SELECT category, count(*) FROM df GROUP BY category ORDER BY count(*) DESC;"""
result=sqldf(q)
result

Unnamed: 0,category,count(*)
0,pais,3048
1,deportes,2202
2,tendencias,2200
3,tecnologias,2196
4,cultura,2142
5,economia,2133
6,mundo,2128
7,coronavirus,298
8,lodijeronencnn,66
9,futuro360,27


- Guardaremos las categorias que contienen más de 1000 noticias y las noticias que tienen más de 5 caracteres

In [31]:
q="""SELECT * FROM df WHERE category IN ('pais','deportes','tendencias','tecnologias','cultura','economia','mundo') ORDER BY date;"""
df_CNN=sqldf(q)
df_CNN

Unnamed: 0,country,media_outlet,url,title,body,date,category
0,chile,cnnchile,https://www.cnnchile.com/tecnologias/jenyne-bu...,Jenyne Butterfly arrasa con su baile en el caño,"Se trata de una verdader acróbata, que deja co...",2011-12-09 00:00:00.000000,tecnologias
1,chile,cnnchile,https://www.cnnchile.com/tecnologias/osama-bin...,Osama Bin Laden fue lo más compartido en Faceb...,Estos son los diez temas más populares de Face...,2011-12-09 00:00:00.000000,tecnologias
2,chile,cnnchile,https://www.cnnchile.com/tecnologias/aprenda-a...,Aprenda a preparar sushi navideño,"""ochikeron"" se llama la usuaria que publicó un...",2011-12-09 00:00:00.000000,tecnologias
3,chile,cnnchile,https://www.cnnchile.com/tecnologias/vea-un-re...,Vea un resumen semanal de los videos más visto...,"Una vez más, el editor de CNN Chile.com, Ed...",2011-12-10 00:00:00.000000,tecnologias
4,chile,cnnchile,https://www.cnnchile.com/tecnologias/esta-apli...,“Esta aplicación es una mezcla de las matemáti...,,2011-12-11 00:00:00.000000,tecnologias
5,chile,cnnchile,https://www.cnnchile.com/tecnologias/la-silla-...,La silla para comer que mantiene con vida a un...,Megaesófago significa que el esófago se amplía...,2011-12-12 00:00:00.000000,tecnologias
6,chile,cnnchile,https://www.cnnchile.com/tecnologias/se-imagin...,¿Se imagina ir en el metro y que Keanu Reeves ...,Al parecer la fama y el dinero no le han quita...,2011-12-12 00:00:00.000000,tecnologias
7,chile,cnnchile,https://www.cnnchile.com/tecnologias/la-extran...,La extraña y peligrosa broma de la ex número u...,En pleno mes de vacaciones en el circuito prof...,2011-12-12 00:00:00.000000,tecnologias
8,chile,cnnchile,https://www.cnnchile.com/tecnologias/cientific...,Científicos dicen que tienen más indicios de l...,La Organización Europea para la Investigación...,2011-12-13 00:00:00.000000,tecnologias
9,chile,cnnchile,https://www.cnnchile.com/tecnologias/el-escena...,El escenario de la concentración en el mercado...,Actualmente en Chile hay más celulares que ha...,2011-12-13 00:00:00.000000,tecnologias


In [32]:
q="""SELECT * FROM df_CNN WHERE length(body)>5"""
df_CNN=sqldf(q)
df_CNN

Unnamed: 0,country,media_outlet,url,title,body,date,category
0,chile,cnnchile,https://www.cnnchile.com/tecnologias/jenyne-bu...,Jenyne Butterfly arrasa con su baile en el caño,"Se trata de una verdader acróbata, que deja co...",2011-12-09 00:00:00.000000,tecnologias
1,chile,cnnchile,https://www.cnnchile.com/tecnologias/osama-bin...,Osama Bin Laden fue lo más compartido en Faceb...,Estos son los diez temas más populares de Face...,2011-12-09 00:00:00.000000,tecnologias
2,chile,cnnchile,https://www.cnnchile.com/tecnologias/aprenda-a...,Aprenda a preparar sushi navideño,"""ochikeron"" se llama la usuaria que publicó un...",2011-12-09 00:00:00.000000,tecnologias
3,chile,cnnchile,https://www.cnnchile.com/tecnologias/vea-un-re...,Vea un resumen semanal de los videos más visto...,"Una vez más, el editor de CNN Chile.com, Ed...",2011-12-10 00:00:00.000000,tecnologias
4,chile,cnnchile,https://www.cnnchile.com/tecnologias/la-silla-...,La silla para comer que mantiene con vida a un...,Megaesófago significa que el esófago se amplía...,2011-12-12 00:00:00.000000,tecnologias
5,chile,cnnchile,https://www.cnnchile.com/tecnologias/se-imagin...,¿Se imagina ir en el metro y que Keanu Reeves ...,Al parecer la fama y el dinero no le han quita...,2011-12-12 00:00:00.000000,tecnologias
6,chile,cnnchile,https://www.cnnchile.com/tecnologias/la-extran...,La extraña y peligrosa broma de la ex número u...,En pleno mes de vacaciones en el circuito prof...,2011-12-12 00:00:00.000000,tecnologias
7,chile,cnnchile,https://www.cnnchile.com/tecnologias/cientific...,Científicos dicen que tienen más indicios de l...,La Organización Europea para la Investigación...,2011-12-13 00:00:00.000000,tecnologias
8,chile,cnnchile,https://www.cnnchile.com/tecnologias/el-escena...,El escenario de la concentración en el mercado...,Actualmente en Chile hay más celulares que ha...,2011-12-13 00:00:00.000000,tecnologias
9,chile,cnnchile,https://www.cnnchile.com/tecnologias/estamos-y...,“Estamos yendo a la génesis de lo que es la masa”,La expectación científica esperaba para este...,2011-12-13 00:00:00.000000,tecnologias


Guardaremos los datos en tres archivos CSV: CNN_train, CNN_valid, CNN_test

In [61]:
import numpy as np

valid, test, train = np.split(df_CNN, [int(.15*len(df_CNN)), int(.3*len(df_CNN))])

In [63]:
print(df_CNN.shape)
print(train.shape)
print(valid.shape)
print(test.shape)

(15809, 7)
(11067, 7)
(2371, 7)
(2371, 7)


In [68]:
train.to_csv("CNN_train.csv", encoding="UTF-8",index=False)
valid.to_csv("CNN_valid.csv", encoding="UTF-8",index=False)
test.to_csv("CNN_test.csv", encoding="UTF-8",index=False)

## 2. Clasificar textos según su categoria temática con una red convolucional

**Tarea**: Queremos aprender un modelo capaz de distinguir las noticias según su categoria.

### 2.1 Leer el dataset

In [21]:
#!pip install --user torch
#!pip install --user torchtext
#!pip install --user spacy

/usr/bin/python: No module named spacy


In [22]:
import torch
import spacy
import random
import torchtext
from torchtext import data
from torchtext import datasets

In [26]:
spacy_es = spacy.load('es_core_news_sm')

In [27]:
def tokenize_es(sentence):
    return [tok.text for tok in spacy_es.tokenizer(sentence)]

In [28]:
print(torch.__version__,spacy.__version__,torchtext.__version__)

1.5.1 2.1.2 0.6.0


In [29]:
TEXT = data.Field(tokenize=tokenize_es, batch_first = True)
CATEGORY = data.LabelField(dtype = torch.float)

In [30]:
fields = [(None, None),(None, None),(None, None),(None, None),('body', TEXT),(None, None),('category', CATEGORY)]

Se leen los CSV para tokenizarlos con Torchtext.data

In [69]:
import numpy as np

SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

train_data, valid_data, test_data = data.TabularDataset.splits(
                                        path = '.',
                                        train = 'CNN_train.csv',
                                        validation= 'CNN_valid.csv',
                                        test = 'CNN_test.csv',
                                        format = 'csv',
                                        fields = fields,
                                        skip_header = True
)

In [None]:
#for i in range(50):
#    print(vars(valid_data[i])['body'].__len__())

In [None]:
#vars(test_data[12])

In [72]:
BATCH_SIZE = 32

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE, 
    device = device,
    sort_key=lambda x:len(x.toxicity),
    sort_within_batch=False)

### 2.2 Crear la arquitectura CNN

Empezamos por cargar vectores de palabras para el inglés.

(Para cargar sus propios vectores, por ejemplo para procesor otros idiomas, se puede inspirarse de: https://www.innoq.com/en/blog/handling-german-text-with-torchtext/)

In [74]:
MAX_VOCAB_SIZE = 50000

vec = torchtext.vocab.Vectors('glove-sbwc.i25.vec.gz', cache='.')
TEXT.build_vocab(train_data, vectors=vec, max_size = MAX_VOCAB_SIZE, unk_init = torch.Tensor.normal_)

#TEXT.build_vocab(train_data, 
#                max_size = MAX_VOCAB_SIZE, 
#                 vectors = "glove.6B.100d", 
#                 unk_init = torch.Tensor.normal_)

CATEGORY.build_vocab(train_data)


  0%|          | 0/855380 [00:00<?, ?it/s][ASkipping token b'855380' with 1-dimensional vector [b'300']; likely a header

  0%|          | 1/855380 [00:00<104:09:42,  2.28it/s][A
  0%|          | 2/855380 [00:05<432:45:24,  1.82s/it][A
  0%|          | 39/855380 [00:05<303:11:14,  1.28s/it][A
  0%|          | 211/855380 [00:05<212:16:55,  1.12it/s][A
  0%|          | 440/855380 [00:06<148:38:42,  1.60it/s][A
  0%|          | 670/855380 [00:06<104:06:39,  2.28it/s][A
  0%|          | 899/855380 [00:06<72:53:36,  3.26it/s] [A
  0%|          | 1128/855380 [00:06<51:03:03,  4.65it/s][A
  0%|          | 1357/855380 [00:06<35:46:18,  6.63it/s][A
  0%|          | 1622/855380 [00:06<25:03:33,  9.46it/s][A
  0%|          | 1930/855380 [00:07<17:35:01, 13.48it/s][A
  0%|          | 2281/855380 [00:07<12:19:25, 19.23it/s][A
  0%|          | 2643/855380 [00:07<8:38:33, 27.41it/s] [A
  0%|          | 2907/855380 [00:07<6:11:36, 38.23it/s][A
  0%|          | 3154/855380 [00:07<4:21:4

 12%|█▏        | 103280/855380 [00:37<03:17, 3808.11it/s][A
 12%|█▏        | 103712/855380 [00:37<03:10, 3947.80it/s][A
 12%|█▏        | 104149/855380 [00:37<03:04, 4061.73it/s][A
 12%|█▏        | 104558/855380 [00:37<03:05, 4041.45it/s][A
 12%|█▏        | 104964/855380 [00:37<03:11, 3924.12it/s][A
 12%|█▏        | 105390/855380 [00:37<03:06, 4015.75it/s][A
 12%|█▏        | 105794/855380 [00:38<03:06, 4017.05it/s][A
 12%|█▏        | 106197/855380 [00:38<03:09, 3962.15it/s][A
 12%|█▏        | 106616/855380 [00:38<03:05, 4025.86it/s][A
 13%|█▎        | 107020/855380 [00:38<03:33, 3509.75it/s][A
 13%|█▎        | 107384/855380 [00:38<03:50, 3238.47it/s][A
 13%|█▎        | 107721/855380 [00:38<04:40, 2668.86it/s][A
 13%|█▎        | 108014/855380 [00:38<05:02, 2469.46it/s][A
 13%|█▎        | 108282/855380 [00:38<05:37, 2213.50it/s][A
 13%|█▎        | 108606/855380 [00:39<05:15, 2370.51it/s][A
 13%|█▎        | 108950/855380 [00:39<05:54, 2105.30it/s][A
 13%|█▎        | 109299/

 24%|██▍       | 205619/855380 [01:09<02:36, 4163.47it/s][A
 24%|██▍       | 206037/855380 [01:09<02:42, 4005.59it/s][A
 24%|██▍       | 206446/855380 [01:09<02:41, 4028.54it/s][A
 24%|██▍       | 206876/855380 [01:09<03:05, 3505.31it/s][A
 24%|██▍       | 207248/855380 [01:09<03:01, 3566.96it/s][A
 24%|██▍       | 207660/855380 [01:09<02:54, 3715.14it/s][A
 24%|██▍       | 208061/855380 [01:10<02:50, 3797.21it/s][A
 24%|██▍       | 208475/855380 [01:10<02:46, 3891.70it/s][A
 24%|██▍       | 208888/855380 [01:10<02:43, 3958.30it/s][A
 24%|██▍       | 209304/855380 [01:10<02:40, 4016.36it/s][A
 25%|██▍       | 209727/855380 [01:10<02:38, 4075.17it/s][A
 25%|██▍       | 210155/855380 [01:10<02:36, 4134.48it/s][A
 25%|██▍       | 210583/855380 [01:10<02:34, 4176.69it/s][A
 25%|██▍       | 211009/855380 [01:10<02:33, 4200.80it/s][A
 25%|██▍       | 211432/855380 [01:10<02:33, 4208.52it/s][A
 25%|██▍       | 211854/855380 [01:10<02:43, 3925.33it/s][A
 25%|██▍       | 212251/

 37%|███▋      | 317058/855380 [01:37<02:06, 4247.87it/s][A
 37%|███▋      | 317491/855380 [01:37<02:05, 4271.43it/s][A
 37%|███▋      | 317932/855380 [01:38<02:04, 4310.51it/s][A
 37%|███▋      | 318371/855380 [01:38<02:03, 4332.15it/s][A
 37%|███▋      | 318805/855380 [01:38<02:05, 4286.31it/s][A
 37%|███▋      | 319241/855380 [01:38<02:04, 4306.40it/s][A
 37%|███▋      | 319682/855380 [01:38<02:03, 4336.93it/s][A
 37%|███▋      | 320123/855380 [01:38<02:02, 4357.87it/s][A
 37%|███▋      | 320564/855380 [01:38<02:02, 4372.19it/s][A
 38%|███▊      | 321006/855380 [01:38<02:01, 4384.65it/s][A
 38%|███▊      | 321445/855380 [01:38<02:01, 4383.33it/s][A
 38%|███▊      | 321884/855380 [01:38<02:01, 4382.65it/s][A
 38%|███▊      | 322323/855380 [01:39<02:01, 4382.39it/s][A
 38%|███▊      | 322762/855380 [01:39<02:01, 4383.34it/s][A
 38%|███▊      | 323201/855380 [01:39<02:01, 4380.75it/s][A
 38%|███▊      | 323640/855380 [01:39<02:01, 4377.51it/s][A
 38%|███▊      | 324079/

 50%|█████     | 431601/855380 [02:05<01:40, 4221.37it/s][A
 51%|█████     | 432025/855380 [02:05<01:40, 4222.51it/s][A
 51%|█████     | 432449/855380 [02:06<01:46, 3977.50it/s][A
 51%|█████     | 432879/855380 [02:06<01:43, 4067.97it/s][A
 51%|█████     | 433318/855380 [02:06<01:41, 4158.44it/s][A
 51%|█████     | 433755/855380 [02:06<01:39, 4219.43it/s][A
 51%|█████     | 434192/855380 [02:06<01:38, 4262.04it/s][A
 51%|█████     | 434628/855380 [02:06<01:38, 4289.73it/s][A
 51%|█████     | 435059/855380 [02:06<01:37, 4294.92it/s][A
 51%|█████     | 435490/855380 [02:06<01:43, 4043.83it/s][A
 51%|█████     | 435898/855380 [02:06<01:51, 3766.15it/s][A
 51%|█████     | 436320/855380 [02:07<01:47, 3889.78it/s][A
 51%|█████     | 436743/855380 [02:07<01:45, 3984.45it/s][A
 51%|█████     | 437169/855380 [02:07<01:42, 4061.26it/s][A
 51%|█████     | 437596/855380 [02:07<01:41, 4120.79it/s][A
 51%|█████     | 438024/855380 [02:07<01:40, 4165.69it/s][A
 51%|█████▏    | 438452/

 63%|██████▎   | 535586/855380 [02:37<01:24, 3790.98it/s][A
 63%|██████▎   | 535998/855380 [02:37<01:37, 3288.16it/s][A
 63%|██████▎   | 536404/855380 [02:37<01:31, 3486.16it/s][A
 63%|██████▎   | 536818/855380 [02:37<01:27, 3659.13it/s][A
 63%|██████▎   | 537206/855380 [02:37<01:32, 3432.15it/s][A
 63%|██████▎   | 537608/855380 [02:37<01:28, 3588.30it/s][A
 63%|██████▎   | 538013/855380 [02:37<01:25, 3712.71it/s][A
 63%|██████▎   | 538396/855380 [02:38<02:05, 2521.81it/s][A
 63%|██████▎   | 538745/855380 [02:38<01:55, 2749.22it/s][A
 63%|██████▎   | 539068/855380 [02:38<02:06, 2491.63it/s][A
 63%|██████▎   | 539446/855380 [02:38<01:53, 2773.42it/s][A
 63%|██████▎   | 539760/855380 [02:38<03:09, 1666.49it/s][A
 63%|██████▎   | 540032/855380 [02:39<02:47, 1885.23it/s][A
 63%|██████▎   | 540286/855380 [02:39<02:48, 1871.94it/s][A
 63%|██████▎   | 540519/855380 [02:39<02:53, 1816.85it/s][A
 63%|██████▎   | 540733/855380 [02:39<02:56, 1783.18it/s][A
 63%|██████▎   | 540934/

 75%|███████▌  | 643973/855380 [03:07<01:04, 3268.41it/s][A
 75%|███████▌  | 644322/855380 [03:07<01:05, 3240.08it/s][A
 75%|███████▌  | 644735/855380 [03:07<01:00, 3463.13it/s][A
 75%|███████▌  | 645097/855380 [03:07<01:03, 3306.82it/s][A
 75%|███████▌  | 645497/855380 [03:07<01:00, 3486.08it/s][A
 76%|███████▌  | 645922/855380 [03:07<00:56, 3683.47it/s][A
 76%|███████▌  | 646344/855380 [03:07<00:54, 3828.43it/s][A
 76%|███████▌  | 646768/855380 [03:07<00:52, 3941.39it/s][A
 76%|███████▌  | 647170/855380 [03:08<00:57, 3616.03it/s][A
 76%|███████▌  | 647605/855380 [03:08<00:54, 3807.07it/s][A
 76%|███████▌  | 648026/855380 [03:08<00:52, 3918.21it/s][A
 76%|███████▌  | 648426/855380 [03:08<01:05, 3158.96it/s][A
 76%|███████▌  | 648846/855380 [03:08<01:00, 3410.51it/s][A
 76%|███████▌  | 649273/855380 [03:08<00:56, 3629.68it/s][A
 76%|███████▌  | 649699/855380 [03:08<00:54, 3798.02it/s][A
 76%|███████▌  | 650097/855380 [03:08<01:09, 2960.87it/s][A
 76%|███████▌  | 650505/

 88%|████████▊ | 752739/855380 [03:36<00:24, 4208.36it/s][A
 88%|████████▊ | 753161/855380 [03:36<00:24, 4192.00it/s][A
 88%|████████▊ | 753597/855380 [03:36<00:24, 4239.16it/s][A
 88%|████████▊ | 754022/855380 [03:36<00:25, 3988.94it/s][A
 88%|████████▊ | 754425/855380 [03:36<00:26, 3797.76it/s][A
 88%|████████▊ | 754810/855380 [03:36<00:26, 3780.42it/s][A
 88%|████████▊ | 755240/855380 [03:36<00:25, 3922.38it/s][A
 88%|████████▊ | 755668/855380 [03:36<00:24, 4021.46it/s][A
 88%|████████▊ | 756100/855380 [03:36<00:24, 4105.89it/s][A
 88%|████████▊ | 756536/855380 [03:36<00:23, 4175.69it/s][A
 88%|████████▊ | 756970/855380 [03:37<00:23, 4220.10it/s][A
 89%|████████▊ | 757407/855380 [03:37<00:22, 4262.54it/s][A
 89%|████████▊ | 757835/855380 [03:37<00:22, 4256.18it/s][A
 89%|████████▊ | 758262/855380 [03:37<00:22, 4245.53it/s][A
 89%|████████▊ | 758688/855380 [03:37<00:22, 4248.98it/s][A
 89%|████████▊ | 759127/855380 [03:37<00:22, 4289.00it/s][A
 89%|████████▉ | 759566/

In [75]:
import torch.nn as nn
import torch.nn.functional as F

class CNN1d(nn.Module):
    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, 
                 dropout, pad_idx):
        
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)
        
        self.convs = nn.ModuleList([
                                    nn.Conv1d(in_channels = embedding_dim, 
                                              out_channels = n_filters, 
                                              kernel_size = fs)
                                    for fs in filter_sizes
                                    ])
        
        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, text):
        
        #text = [batch size, sent len]
        
        embedded = self.embedding(text)
                
        #embedded = [batch size, sent len, emb dim]
        
        embedded = embedded.permute(0, 2, 1)
        
        #embedded = [batch size, emb dim, sent len]
        
        conved = [F.relu(conv(embedded)) for conv in self.convs]
            
        #conved_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]
        
        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]
        
        #pooled_n = [batch size, n_filters]
        
        cat = self.dropout(torch.cat(pooled, dim = 1))
        
        #cat = [batch size, n_filters * len(filter_sizes)]
            
        return self.fc(cat)

In [79]:
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 300
N_FILTERS = 100
FILTER_SIZES = [3,4,5]
OUTPUT_DIM = 1
DROPOUT = 0.5
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]

model = CNN1d(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, DROPOUT, PAD_IDX)

In [80]:
INPUT_DIM

50002

In [81]:
pretrained_embeddings = TEXT.vocab.vectors
UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]

model.embedding.weight.data.copy_(pretrained_embeddings)
model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)
model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)

### 2.3 Funciones para optimizar el modelo

In [82]:
import torch.optim as optim

optimizer = optim.Adam(model.parameters())

criterion = nn.BCEWithLogitsLoss()

model = model.to(device)
criterion = criterion.to(device)

In [97]:
def train(model, iterator, optimizer, criterion):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
    for batch in iterator:
        
        optimizer.zero_grad()
        
        predictions = model(batch.body).squeeze(1)
        
        loss = criterion(predictions, batch.category)
        
        acc = binary_accuracy(predictions, batch.category)
        
        loss.backward()
        
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [98]:
import time

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

### 2.4 Funciones para evaluar el modelo

In [99]:
def binary_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """

    #round predictions to the closest integer
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float() #convert into float for division 
    acc = correct.sum() / len(correct)
    return acc

In [100]:
def evaluate(model, iterator, criterion):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.eval()
    
    with torch.no_grad():
    
        for batch in iterator:
           
            predictions = model(batch.comment_text).squeeze(1)
            
            loss = criterion(predictions, batch.toxicity)
            
            acc = binary_accuracy(predictions, batch.toxicity)

            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

### 2.5 Optimización del modelo

In [None]:
N_EPOCHS = 2

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        nombre = './tematic-model-CNN'+'_ep'+str(epoch+1)+'.pt'
        torch.save({'epoca': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'Valid_loss': best_valid_loss}, nombre)
    
    print("Epoch:"+str(epoch+1:02)+" | Epoch Time: "+str(epoch_mins)+"m "+str(epoch_secs)+"s")
    print("\tTrain Loss: "+str(train_loss:.3f)+" | Train Acc: "+str(train_acc*100:.2f)+"%")
    print("\t Val. Loss: "+str(valid_loss:.3f)+" |  Val. Acc: "+str(valid_acc*100:.2f)+"%"')

### 2.6 Evaluación del modelo

In [57]:
best_model = CNN1d(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, DROPOUT, PAD_IDX)


In [58]:
pretrained_embeddings = TEXT.vocab.vectors
UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]

best_model.embedding.weight.data.copy_(pretrained_embeddings)
best_model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)
best_model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)

In [59]:
name = './tematic-model-CNN'+'_ep'+str(2)+'.pt'
best_model.load_state_dict(torch.load(name, map_location=torch.device('cpu'))['model_state_dict'])

<All keys matched successfully>

In [60]:
from sklearn.metrics import f1_score,confusion_matrix, classification_report

In [61]:
prediction_test = []
labels_test=[]
for batch in test_iterator:
    labels_test.append(batch.category.cpu().detach().numpy())
    predictions = best_model(batch.body.cpu()).squeeze(1)
    rounded_preds = torch.round(torch.sigmoid(predictions))
    prediction_test.append(rounded_preds.detach().numpy())
    

y_true = np.concatenate(labels_test)
y_pred = np.concatenate(prediction_test)

In [62]:
display(y_pred,y_true)

array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)

array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)

In [63]:
cm = confusion_matrix(y_true, y_pred)
display(cm)

print(classification_report(y_true, y_pred))

array([[1760,   15],
       [ 119,  105]])

              precision    recall  f1-score   support

         0.0       0.94      0.99      0.96      1775
         1.0       0.88      0.47      0.61       224

    accuracy                           0.93      1999
   macro avg       0.91      0.73      0.79      1999
weighted avg       0.93      0.93      0.92      1999

