<img src="https://s8.hostingkartinok.com/uploads/images/2018/08/308b49fcfbc619d629fe4604bceb67ac.jpg" width=500, height=450>
<h3 style="text-align: center;"><b>Физтех-Школа Прикладной математики и информатики (ФПМИ) МФТИ</b></h3>

---

# Задание 3

## Классификация текстов

В этом задании вам предстоит попробовать несколько методов, используемых в задаче классификации, а также понять насколько хорошо модель понимает смысл слов и какие слова в примере влияют на результат.

In [1]:
import pandas as pd
import numpy as np
import torch

from torchtext.legacy import datasets

from torchtext.legacy.data import Field, LabelField
from torchtext.legacy.data import BucketIterator

from torchtext.vocab import Vectors, GloVe

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
from tqdm.autonotebook import tqdm

В этом задании мы будем использовать библиотеку torchtext. Она довольна проста в использовании и поможет нам сконцентрироваться на задаче, а не на написании Dataloader-а.

In [2]:
TEXT = Field(sequential=True, lower=True, include_lengths=True)  # Поле текста
LABEL = LabelField(dtype=torch.float)  # Поле метки

In [3]:
SEED = 1234

np.random.seed(SEED)
torch.manual_seed(SEED)
torch.random.manual_seed(SEED)
torch.cuda.random.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True

In [4]:
def calc_f1_score(model, iter, output_calc_func):
    model.eval()

    tp = 0.0
    tn = 0.0
    fp = 0.0
    fn = 0.0
    for batch in iter:
        with torch.no_grad():
            outputs = output_calc_func(model, batch)
            outputs = (torch.sigmoid(outputs).cpu() > 0.5).int().squeeze(1)
            labels = batch.label.cpu().int()

            tp += (labels * outputs).sum()
            tn += ((1 - labels) * (1 - outputs)).sum()
            fp += ((1 - labels) * outputs).sum()
            fn += (labels * (1 - outputs)).sum()

    epsilon = 1e-7
    precision = tp / (tp + fp + epsilon)
    recall = tp / (tp + fn + epsilon)

    f1_score = 2 * (precision * recall) / (precision + recall + epsilon)
    f1_score = f1_score.item()

    return f1_score

Напишем функции для обучения моделей

In [5]:
def train_rnn(model, loss_func, train_iter, val_iter, max_epochs, patience, max_grad_norm=2):
    min_loss = np.inf

    cur_patience = 0

    for epoch in range(1, max_epochs + 1):
        train_loss = 0.0
        model.train()
        pbar = tqdm(enumerate(train_iter), total=len(train_iter), leave=False)
        pbar.set_description(f"Epoch {epoch}")
        for it, batch in pbar: 
            opt.zero_grad()

            outputs = model(batch.text[0], batch.text[1].cpu())
            loss = loss_func(outputs, batch.label.unsqueeze(1))
            loss.backward()
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            opt.step()

            train_loss += loss.cpu().detach()

        train_loss /= len(train_iter)
        val_loss = 0.0
        model.eval()
        pbar = tqdm(enumerate(val_iter), total=len(val_iter), leave=False)
        pbar.set_description(f"Epoch {epoch}")
        for it, batch in pbar:
            with torch.no_grad():
                outputs = model(batch.text[0], batch.text[1].cpu())
                loss = loss_func(outputs, batch.label.unsqueeze(1)).cpu()

                val_loss += loss

        val_loss /= len(val_iter)
        if val_loss < min_loss:
            min_loss = val_loss
            best_model = model.state_dict()
        else:
            cur_patience += 1
            if cur_patience == patience:
                cur_patience = 0
                break
        
        print('Epoch: {}, Training Loss: {}, Validation Loss: {}'.format(epoch, train_loss, val_loss))
    model.load_state_dict(best_model)

def freeze_embeddings(model, req_grad=False):
    for c_p in model.embedding.parameters():
        c_p.requires_grad = req_grad

def train_cnn(model, loss_func, train_iter, val_iter, max_epochs, patience, max_grad_norm=2, num_freeze_epochs=0):
    min_loss = np.inf

    cur_patience = 0

    freeze_embeddings(model)

    for epoch in range(1, max_epochs + 1):
        train_loss = 0.0
        model.train()
        pbar = tqdm(enumerate(train_iter), total=len(train_iter), leave=False)
        pbar.set_description(f"Epoch {epoch}")

        if epoch > num_freeze_epochs:
            freeze_embeddings(model, True)

        for it, batch in pbar:
            opt.zero_grad()

            outputs = model(batch.text)
            loss = loss_func(outputs, batch.label.unsqueeze(1))
            loss.backward()
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            opt.step()

            train_loss += loss.cpu().detach()

        train_loss /= len(train_iter)
        val_loss = 0.0
        model.eval()
        pbar = tqdm(enumerate(val_iter), total=len(val_iter), leave=False)
        pbar.set_description(f"Epoch {epoch}")
        for it, batch in pbar:
            with torch.no_grad():
                outputs = model(batch.text)
                loss = loss_func(outputs, batch.label.unsqueeze(1)).cpu()

                val_loss += loss

        val_loss /= len(val_iter)
        if val_loss < min_loss:
            min_loss = val_loss
            best_model = model.state_dict()
        else:
            cur_patience += 1
            if cur_patience == patience:
                cur_patience = 0
                break
        
        print('Epoch: {}, Training Loss: {}, Validation Loss: {}'.format(epoch, train_loss, val_loss))
    model.load_state_dict(best_model)

Датасет на котором мы будем проводить эксперементы это комментарии к фильмам из сайта IMDB.

In [6]:
train, test = datasets.IMDB.splits(TEXT, LABEL)  # загрузим датасет
train, valid = train.split(random_state=random.seed(SEED))  # разобьем на части

In [7]:
TEXT.build_vocab(train)
LABEL.build_vocab(train)

In [8]:
device = "cuda" if torch.cuda.is_available() else "cpu"

train_iter, valid_iter, test_iter = BucketIterator.splits(
    (train, valid, test), 
    batch_size = 64,
    sort_within_batch = True,
    device = device)

## RNN

Для начала попробуем использовать рекурентные нейронные сети. На семинаре вы познакомились с GRU, вы можете также попробовать LSTM. Можно использовать для классификации как hidden_state, так и output последнего токена.

In [9]:
class RNNBaseline(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, 
                 bidirectional, dropout, pad_idx):
        
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)
        
        self.rnn = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, num_layers=n_layers, bidirectional=bidirectional, dropout=dropout)
        
        self.dropout = nn.Dropout(dropout)

        self.fc = nn.Linear(2 * hidden_dim, output_dim)
        
        
    def forward(self, text, text_lengths):
        
        #text = [sent len, batch size]
        
        embedded = self.embedding(text)
        
        #embedded = [sent len, batch size, emb dim]
        
        #pack sequence
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths)
        
        # cell arg for LSTM, remove for GRU
        packed_output, (hidden, cell) = self.rnn(packed_embedded)
        #unpack sequence
        output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)  

        #output = [sent len, batch size, hid dim * num directions]
        #output over padding tokens are zero tensors
        
        #hidden = [num layers * num directions, batch size, hid dim]
        #cell = [num layers * num directions, batch size, hid dim]
        
        #concat the final forward (hidden[-2,:,:]) and backward (hidden[-1,:,:]) hidden layers
        #and apply dropout
        
        hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), 1)
        hidden = self.dropout(hidden)

        #hidden = [batch size, hid dim * num directions] or [batch_size, hid dim * num directions]
            
        return self.fc(hidden)

Поиграйтесь с гиперпараметрами

In [10]:
vocab_size = len(TEXT.vocab)
emb_dim = 300
hidden_dim = 256
output_dim = 1
n_layers = 2
bidirectional = True
dropout = 0.5
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]

rnn_model = RNNBaseline(
    vocab_size=vocab_size,
    embedding_dim=emb_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
    n_layers=n_layers,
    bidirectional=bidirectional,
    dropout=dropout,
    pad_idx=PAD_IDX
).to(device)

In [11]:
opt = torch.optim.Adam(rnn_model.parameters(), lr=1e-4)
loss_func = nn.BCEWithLogitsLoss()

Обучите сетку! Используйте любые вам удобные инструменты, Catalyst, PyTorch Lightning или свои велосипеды.

In [12]:
max_epochs = 50
patience=15

train_rnn(rnn_model, loss_func, train_iter, valid_iter, max_epochs, patience)

HBox(children=(FloatProgress(value=0.0, max=274.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=118.0), HTML(value='')))

Epoch: 1, Training Loss: 0.6742545366287231, Validation Loss: 0.5824978351593018


HBox(children=(FloatProgress(value=0.0, max=274.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=118.0), HTML(value='')))

Epoch: 2, Training Loss: 0.5109237432479858, Validation Loss: 0.5463767051696777


HBox(children=(FloatProgress(value=0.0, max=274.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=118.0), HTML(value='')))

Epoch: 3, Training Loss: 0.4130497872829437, Validation Loss: 0.4198527932167053


HBox(children=(FloatProgress(value=0.0, max=274.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=118.0), HTML(value='')))

Epoch: 4, Training Loss: 0.3480762541294098, Validation Loss: 0.42421337962150574


HBox(children=(FloatProgress(value=0.0, max=274.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=118.0), HTML(value='')))

Epoch: 5, Training Loss: 0.2857218384742737, Validation Loss: 0.418121874332428


HBox(children=(FloatProgress(value=0.0, max=274.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=118.0), HTML(value='')))

Epoch: 6, Training Loss: 0.24492843449115753, Validation Loss: 0.4250085651874542


HBox(children=(FloatProgress(value=0.0, max=274.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=118.0), HTML(value='')))

Epoch: 7, Training Loss: 0.19758240878582, Validation Loss: 0.4353417754173279


HBox(children=(FloatProgress(value=0.0, max=274.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=118.0), HTML(value='')))

Epoch: 8, Training Loss: 0.15108461678028107, Validation Loss: 0.46530306339263916


HBox(children=(FloatProgress(value=0.0, max=274.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=118.0), HTML(value='')))

Epoch: 9, Training Loss: 0.12880533933639526, Validation Loss: 0.5643097758293152


HBox(children=(FloatProgress(value=0.0, max=274.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=118.0), HTML(value='')))

Epoch: 10, Training Loss: 0.09997963905334473, Validation Loss: 0.6360642313957214


HBox(children=(FloatProgress(value=0.0, max=274.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=118.0), HTML(value='')))

Epoch: 11, Training Loss: 0.0769193023443222, Validation Loss: 0.5590224266052246


HBox(children=(FloatProgress(value=0.0, max=274.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=118.0), HTML(value='')))

Epoch: 12, Training Loss: 0.06668514758348465, Validation Loss: 0.6920998692512512


HBox(children=(FloatProgress(value=0.0, max=274.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=118.0), HTML(value='')))

Epoch: 13, Training Loss: 0.048860158771276474, Validation Loss: 0.6697909832000732


HBox(children=(FloatProgress(value=0.0, max=274.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=118.0), HTML(value='')))

Epoch: 14, Training Loss: 0.041297007352113724, Validation Loss: 0.7189866900444031


HBox(children=(FloatProgress(value=0.0, max=274.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=118.0), HTML(value='')))

Epoch: 15, Training Loss: 0.03422073647379875, Validation Loss: 0.7576189041137695


HBox(children=(FloatProgress(value=0.0, max=274.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=118.0), HTML(value='')))

Epoch: 16, Training Loss: 0.03396064043045044, Validation Loss: 0.8378239274024963


HBox(children=(FloatProgress(value=0.0, max=274.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=118.0), HTML(value='')))

Epoch: 17, Training Loss: 0.022049326449632645, Validation Loss: 0.8449418544769287


HBox(children=(FloatProgress(value=0.0, max=274.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=118.0), HTML(value='')))

Epoch: 18, Training Loss: 0.015442287549376488, Validation Loss: 0.8026885390281677


HBox(children=(FloatProgress(value=0.0, max=274.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=118.0), HTML(value='')))



Посчитайте f1-score вашего классификатора на тестовом датасете.

**Ответ**:

In [13]:
calc_f1_score(rnn_model, test_iter, lambda model, batch: model(batch.text[0], batch.text[1].cpu()))

0.8374814987182617

## CNN

![](https://www.researchgate.net/publication/333752473/figure/fig1/AS:769346934673412@1560438011375/Standard-CNN-on-text-classification.png)

Для классификации текстов также часто используют сверточные нейронные сети. Идея в том, что как правило сентимент содержат словосочетания из двух-трех слов, например "очень хороший фильм" или "невероятная скука". Проходясь сверткой по этим словам мы получим какой-то большой скор и выхватим его с помощью MaxPool. Далее идет обычная полносвязная сетка. Важный момент: свертки применяются не последовательно, а параллельно. Давайте попробуем!

In [14]:
TEXT = Field(sequential=True, lower=True, batch_first=True)  # batch_first тк мы используем conv  
LABEL = LabelField(batch_first=True, dtype=torch.float)

train, tst = datasets.IMDB.splits(TEXT, LABEL)
trn, vld = train.split(random_state=random.seed(SEED))

TEXT.build_vocab(trn)
LABEL.build_vocab(trn)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [15]:
train_iter, val_iter, test_iter = BucketIterator.splits(
        (trn, vld, tst),
        batch_sizes=(128, 256, 256),
        sort=False,
        sort_key= lambda x: len(x.src),
        sort_within_batch=False,
        device=device,
        repeat=False,
)

Вы можете использовать Conv2d с `in_channels=1, kernel_size=(kernel_sizes[0], emb_dim))` или Conv1d c `in_channels=emb_dim, kernel_size=kernel_size[0]`. Но хорошенько подумайте над shape в обоих случаях.

In [16]:
class CNN(nn.Module):
    def __init__(
        self,
        vocab_size,
        emb_dim,
        out_channels,
        kernel_sizes,
        dropout=0.5,
    ):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.conv_0 = nn.Conv1d(in_channels=emb_dim, out_channels=out_channels, kernel_size=kernel_sizes[0])
        
        self.conv_1 = nn.Conv1d(in_channels=emb_dim, out_channels=out_channels, kernel_size=kernel_sizes[1])
        
        self.conv_2 = nn.Conv1d(in_channels=emb_dim, out_channels=out_channels, kernel_size=kernel_sizes[2])
        
        self.fc = nn.Linear(len(kernel_sizes) * out_channels, 1)
        
        self.dropout = nn.Dropout(dropout)
        
        
    def forward(self, text):        
        embedded = self.embedding(text)

        embedded = embedded.permute(0, 2, 1)
        
        conved_0 = F.relu(self.conv_0(embedded))  # may be reshape here
        conved_1 = F.relu(self.conv_1(embedded))  # may be reshape here
        conved_2 = F.relu(self.conv_2(embedded))  # may be reshape here

        pooled_0 = F.max_pool1d(conved_0, conved_0.shape[2]).squeeze(2)
        pooled_1 = F.max_pool1d(conved_1, conved_1.shape[2]).squeeze(2)
        pooled_2 = F.max_pool1d(conved_2, conved_2.shape[2]).squeeze(2)

        cat = self.dropout(torch.cat((pooled_0, pooled_1, pooled_2), dim=1))

        return self.fc(cat)

In [17]:
kernel_sizes = [3, 4, 5]
vocab_size = len(TEXT.vocab)
out_channels=64
dropout = 0.5
dim = 300

cnn_model = CNN(vocab_size=vocab_size, emb_dim=dim, out_channels=out_channels, kernel_sizes=kernel_sizes, dropout=dropout).to(device)

In [18]:
opt = torch.optim.Adam(cnn_model.parameters(), lr=1e-4)
loss_func = nn.BCEWithLogitsLoss()

Обучите!

In [19]:
max_epochs = 50
patience = 15

train_cnn(cnn_model, loss_func, train_iter, val_iter, max_epochs, patience)

HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 1, Training Loss: 0.7576324343681335, Validation Loss: 0.6517045497894287


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 2, Training Loss: 0.6899340748786926, Validation Loss: 0.6062830686569214


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 3, Training Loss: 0.6387899518013, Validation Loss: 0.5613199472427368


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 4, Training Loss: 0.5910921692848206, Validation Loss: 0.5337907075881958


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 5, Training Loss: 0.555692732334137, Validation Loss: 0.5114783048629761


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 6, Training Loss: 0.5308852791786194, Validation Loss: 0.496740460395813


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 7, Training Loss: 0.5096702575683594, Validation Loss: 0.48413896560668945


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 8, Training Loss: 0.49173012375831604, Validation Loss: 0.4738166630268097


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 9, Training Loss: 0.4728925824165344, Validation Loss: 0.46500885486602783


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 10, Training Loss: 0.4634188115596771, Validation Loss: 0.4571587145328522


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 11, Training Loss: 0.45103561878204346, Validation Loss: 0.4493470788002014


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 12, Training Loss: 0.4329078495502472, Validation Loss: 0.4436053931713104


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 13, Training Loss: 0.4215008616447449, Validation Loss: 0.4369823634624481


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 14, Training Loss: 0.410084068775177, Validation Loss: 0.43029505014419556


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 15, Training Loss: 0.40883302688598633, Validation Loss: 0.42483919858932495


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 16, Training Loss: 0.39374157786369324, Validation Loss: 0.41871586441993713


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 17, Training Loss: 0.38184672594070435, Validation Loss: 0.41376128792762756


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 18, Training Loss: 0.3776858150959015, Validation Loss: 0.4085441529750824


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 19, Training Loss: 0.3630562424659729, Validation Loss: 0.40397489070892334


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 20, Training Loss: 0.34618082642555237, Validation Loss: 0.3981413245201111


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 21, Training Loss: 0.3446395993232727, Validation Loss: 0.39361831545829773


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 22, Training Loss: 0.3412877321243286, Validation Loss: 0.3911988139152527


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 23, Training Loss: 0.3309648633003235, Validation Loss: 0.3863234221935272


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 24, Training Loss: 0.3180355727672577, Validation Loss: 0.3843051493167877


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 25, Training Loss: 0.3123960793018341, Validation Loss: 0.3800780177116394


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 26, Training Loss: 0.3083990216255188, Validation Loss: 0.3769741356372833


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 27, Training Loss: 0.3016382157802582, Validation Loss: 0.37317579984664917


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 28, Training Loss: 0.29282230138778687, Validation Loss: 0.3705296814441681


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 29, Training Loss: 0.2849389314651489, Validation Loss: 0.36675870418548584


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 30, Training Loss: 0.27899599075317383, Validation Loss: 0.36684614419937134


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 31, Training Loss: 0.26917049288749695, Validation Loss: 0.3610278069972992


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 32, Training Loss: 0.26171180605888367, Validation Loss: 0.35848888754844666


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 33, Training Loss: 0.2526353895664215, Validation Loss: 0.3561694025993347


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 34, Training Loss: 0.24546083807945251, Validation Loss: 0.3540089428424835


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 35, Training Loss: 0.24242974817752838, Validation Loss: 0.3528509736061096


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 36, Training Loss: 0.23450540006160736, Validation Loss: 0.35080963373184204


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 37, Training Loss: 0.22836841642856598, Validation Loss: 0.3484395146369934


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 38, Training Loss: 0.21914775669574738, Validation Loss: 0.3465931713581085


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 39, Training Loss: 0.21657702326774597, Validation Loss: 0.34479638934135437


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 40, Training Loss: 0.20906610786914825, Validation Loss: 0.34354647994041443


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 41, Training Loss: 0.20689404010772705, Validation Loss: 0.3415157198905945


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 42, Training Loss: 0.19985994696617126, Validation Loss: 0.3426453173160553


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 43, Training Loss: 0.19117160141468048, Validation Loss: 0.3394249379634857


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 44, Training Loss: 0.18380019068717957, Validation Loss: 0.3377845585346222


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 45, Training Loss: 0.1847643256187439, Validation Loss: 0.336868554353714


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 46, Training Loss: 0.1741999089717865, Validation Loss: 0.3346464931964874


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 47, Training Loss: 0.17601940035820007, Validation Loss: 0.3336796462535858


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 48, Training Loss: 0.16273018717765808, Validation Loss: 0.33324959874153137


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 49, Training Loss: 0.16219404339790344, Validation Loss: 0.33389589190483093


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 50, Training Loss: 0.15842033922672272, Validation Loss: 0.3318353593349457


Посчитайте f1-score вашего классификатора.

**Ответ**:

In [20]:
calc_f1_score(cnn_model, test_iter, lambda model, batch: model(batch.text))

0.8496467471122742

## Интерпретируемость

Посмотрим, куда смотрит наша модель. Достаточно запустить код ниже.

In [21]:
!pip install -q captum

In [28]:
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization

PAD_IND = TEXT.vocab.stoi['pad']

token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)
lig = LayerIntegratedGradients(cnn_model, cnn_model.embedding)

In [33]:
def forward_with_softmax(model, inp):
    logits = model(inp)
    return torch.softmax(logits, 0)[0][1]

def forward_with_sigmoid(model, input):
    return torch.sigmoid(model(input))


# accumalate couple samples in this array for visualization purposes
vis_data_records_ig = []

def interpret_sentence(model, sentence, min_len = 7, label = 0):
    model.eval()
    text = [tok for tok in TEXT.tokenize(sentence)]
    if len(text) < min_len:
        text += ['pad'] * (min_len - len(text))
    indexed = [TEXT.vocab.stoi[t] for t in text]

    model.zero_grad()

    input_indices = torch.tensor(indexed, device=device)
    input_indices = input_indices.unsqueeze(0)
    
    # input_indices dim: [sequence_length]
    seq_length = min_len

    # predict
    pred = forward_with_sigmoid(model, input_indices).item()
    pred_ind = round(pred)

    # generate reference indices for each sample
    reference_indices = token_reference.generate_reference(seq_length, device=device).unsqueeze(0)

    # compute attributions and approximation delta using layer integrated gradients
    attributions_ig, delta = lig.attribute(input_indices, reference_indices, \
                                           n_steps=5000, return_convergence_delta=True)

    print('pred: ', LABEL.vocab.itos[pred_ind], '(', '%.2f'%pred, ')', ', delta: ', abs(delta))

    add_attributions_to_visualizer(attributions_ig, text, pred, pred_ind, label, delta, vis_data_records_ig)
    
def add_attributions_to_visualizer(attributions, text, pred, pred_ind, label, delta, vis_data_records):
    attributions = attributions.sum(dim=2).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.cpu().detach().numpy()

    # storing couple samples in an array for visualization purposes
    vis_data_records.append(visualization.VisualizationDataRecord(
                            attributions,
                            pred,
                            LABEL.vocab.itos[pred_ind],
                            LABEL.vocab.itos[label],
                            LABEL.vocab.itos[1],
                            attributions.sum(),       
                            text,
                            delta))

In [34]:
interpret_sentence(cnn_model, 'It was a fantastic performance !', label=1)
interpret_sentence(cnn_model, 'Best film ever', label=1)
interpret_sentence(cnn_model, 'Such a great show!', label=1)
interpret_sentence(cnn_model, 'It was a horrible movie', label=0)
interpret_sentence(cnn_model, 'I\'ve never watched something as bad', label=0)
interpret_sentence(cnn_model, 'It is a disgusting movie!', label=0)

pred:  pos ( 0.99 ) , delta:  tensor([0.0001], device='cuda:0', dtype=torch.float64)
pred:  neg ( 0.30 ) , delta:  tensor([7.6005e-06], device='cuda:0', dtype=torch.float64)
pred:  pos ( 0.96 ) , delta:  tensor([1.3040e-05], device='cuda:0', dtype=torch.float64)
pred:  neg ( 0.09 ) , delta:  tensor([3.4205e-05], device='cuda:0', dtype=torch.float64)
pred:  neg ( 0.12 ) , delta:  tensor([7.6346e-05], device='cuda:0', dtype=torch.float64)
pred:  pos ( 0.69 ) , delta:  tensor([1.7903e-05], device='cuda:0', dtype=torch.float64)


Попробуйте добавить свои примеры!

In [35]:
print('Visualize attributions based on Integrated Gradients')
visualization.visualize_text(vis_data_records_ig)

Visualize attributions based on Integrated Gradients


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
pos,pos (0.99),pos,1.48,It was a fantastic performance ! pad
,,,,
pos,neg (0.30),pos,0.73,Best film ever pad pad pad pad
,,,,
pos,pos (0.96),pos,1.14,Such a great show! pad pad pad
,,,,
neg,neg (0.09),pos,-0.45,It was a horrible movie pad pad
,,,,
neg,neg (0.12),pos,-0.23,I've never watched something as bad pad
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
pos,pos (0.99),pos,1.48,It was a fantastic performance ! pad
,,,,
pos,neg (0.30),pos,0.73,Best film ever pad pad pad pad
,,,,
pos,pos (0.96),pos,1.14,Such a great show! pad pad pad
,,,,
neg,neg (0.09),pos,-0.45,It was a horrible movie pad pad
,,,,
neg,neg (0.12),pos,-0.23,I've never watched something as bad pad
,,,,


## Эмбэдинги слов

Вы ведь не забыли, как мы можем применить знания о word2vec и GloVe. Давайте попробуем!

In [36]:
TEXT.build_vocab(trn, vectors='glove.840B.300d')
LABEL.build_vocab(trn)

word_embeddings = TEXT.vocab.vectors

In [37]:
train, tst = datasets.IMDB.splits(TEXT, LABEL)
trn, vld = train.split(random_state=random.seed(SEED))

device = "cuda" if torch.cuda.is_available() else "cpu"

train_iter, val_iter, test_iter = BucketIterator.splits(
        (trn, vld, tst),
        batch_sizes=(128, 256, 256),
        sort=False,
        sort_key= lambda x: len(x.src),
        sort_within_batch=False,
        device=device,
        repeat=False,
)

In [38]:
kernel_sizes = [3, 4, 5]
vocab_size = len(TEXT.vocab)
out_channels=64
dropout = 0.5
dim = 300

cnn_and_embeddings_model = CNN(vocab_size=vocab_size, emb_dim=dim, out_channels=out_channels, kernel_sizes=kernel_sizes, dropout=dropout)

word_embeddings = TEXT.vocab.vectors

prev_shape = cnn_and_embeddings_model.embedding.weight.shape

cnn_and_embeddings_model.embedding.weight.data = torch.clone(word_embeddings)

assert prev_shape == cnn_and_embeddings_model.embedding.weight.shape
cnn_and_embeddings_model.to(device)

CNN(
  (embedding): Embedding(202268, 300)
  (conv_0): Conv1d(300, 64, kernel_size=(3,), stride=(1,))
  (conv_1): Conv1d(300, 64, kernel_size=(4,), stride=(1,))
  (conv_2): Conv1d(300, 64, kernel_size=(5,), stride=(1,))
  (fc): Linear(in_features=192, out_features=1, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)

In [39]:
opt = torch.optim.Adam(cnn_and_embeddings_model.parameters(), lr=1e-4)
loss_func = nn.BCEWithLogitsLoss()

Вы знаете, что делать.

In [40]:
max_epochs = 50
patience = 15

train_cnn(cnn_and_embeddings_model, loss_func, train_iter, val_iter, max_epochs, patience, num_freeze_epochs=0)

HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 1, Training Loss: 0.6762555241584778, Validation Loss: 0.6392307877540588


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 2, Training Loss: 0.6033423542976379, Validation Loss: 0.5459020733833313


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 3, Training Loss: 0.49749210476875305, Validation Loss: 0.4511328339576721


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 4, Training Loss: 0.42663270235061646, Validation Loss: 0.40656739473342896


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 5, Training Loss: 0.38875171542167664, Validation Loss: 0.38155120611190796


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 6, Training Loss: 0.36204424500465393, Validation Loss: 0.3653334379196167


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 7, Training Loss: 0.33915168046951294, Validation Loss: 0.3537197709083557


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 8, Training Loss: 0.31842857599258423, Validation Loss: 0.3433770537376404


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 9, Training Loss: 0.30003029108047485, Validation Loss: 0.33546820282936096


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 10, Training Loss: 0.28027257323265076, Validation Loss: 0.3285302519798279


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 11, Training Loss: 0.26262593269348145, Validation Loss: 0.32217657566070557


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 12, Training Loss: 0.2462833672761917, Validation Loss: 0.3174632489681244


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 13, Training Loss: 0.23081234097480774, Validation Loss: 0.31211230158805847


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 14, Training Loss: 0.2127387821674347, Validation Loss: 0.3088242709636688


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 15, Training Loss: 0.19626827538013458, Validation Loss: 0.30575281381607056


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 16, Training Loss: 0.18307611346244812, Validation Loss: 0.3018968999385834


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 17, Training Loss: 0.16972137987613678, Validation Loss: 0.2993110418319702


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 18, Training Loss: 0.15460239350795746, Validation Loss: 0.29740700125694275


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 19, Training Loss: 0.14086130261421204, Validation Loss: 0.2963598966598511


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 20, Training Loss: 0.12892025709152222, Validation Loss: 0.29563918709754944


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 21, Training Loss: 0.11790939420461655, Validation Loss: 0.2952145040035248


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 22, Training Loss: 0.10530146211385727, Validation Loss: 0.2947268486022949


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 23, Training Loss: 0.09470847994089127, Validation Loss: 0.2952643036842346


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 24, Training Loss: 0.08577204495668411, Validation Loss: 0.29551824927330017


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 25, Training Loss: 0.0754757970571518, Validation Loss: 0.29775363206863403


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 26, Training Loss: 0.06945324689149857, Validation Loss: 0.29902732372283936


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 27, Training Loss: 0.06202911585569382, Validation Loss: 0.3005805015563965


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 28, Training Loss: 0.055815860629081726, Validation Loss: 0.30180710554122925


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 29, Training Loss: 0.04888837784528732, Validation Loss: 0.30486446619033813


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 30, Training Loss: 0.043837517499923706, Validation Loss: 0.3083876967430115


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 31, Training Loss: 0.038906726986169815, Validation Loss: 0.30834341049194336


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 32, Training Loss: 0.035062339156866074, Validation Loss: 0.31097838282585144


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 33, Training Loss: 0.030737191438674927, Validation Loss: 0.31603652238845825


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 34, Training Loss: 0.027629932388663292, Validation Loss: 0.31838181614875793


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 35, Training Loss: 0.025365378707647324, Validation Loss: 0.3211582899093628


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 36, Training Loss: 0.022064093500375748, Validation Loss: 0.32449501752853394


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))



Посчитайте f1-score вашего классификатора.

**Ответ**:

In [41]:
calc_f1_score(cnn_and_embeddings_model, test_iter, lambda model, batch: model(batch.text))

0.872462272644043

Проверим насколько все хорошо!

In [43]:
PAD_IND = TEXT.vocab.stoi['pad']

token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)
lig = LayerIntegratedGradients(cnn_and_embeddings_model, cnn_and_embeddings_model.embedding)
vis_data_records_ig = []

interpret_sentence(cnn_and_embeddings_model, 'It was a fantastic performance !', label=1)
interpret_sentence(cnn_and_embeddings_model, 'Best film ever', label=1)
interpret_sentence(cnn_and_embeddings_model, 'Such a great show!', label=1)
interpret_sentence(cnn_and_embeddings_model, 'It was a horrible movie', label=0)
interpret_sentence(cnn_and_embeddings_model, 'I\'ve never watched something as bad', label=0)
interpret_sentence(cnn_and_embeddings_model, 'It is a disgusting movie!', label=0)

pred:  pos ( 0.98 ) , delta:  tensor([0.0001], device='cuda:0', dtype=torch.float64)
pred:  neg ( 0.01 ) , delta:  tensor([2.5854e-05], device='cuda:0', dtype=torch.float64)
pred:  neg ( 0.35 ) , delta:  tensor([0.0002], device='cuda:0', dtype=torch.float64)
pred:  neg ( 0.00 ) , delta:  tensor([5.4771e-05], device='cuda:0', dtype=torch.float64)
pred:  neg ( 0.43 ) , delta:  tensor([0.0002], device='cuda:0', dtype=torch.float64)
pred:  neg ( 0.00 ) , delta:  tensor([1.0463e-05], device='cuda:0', dtype=torch.float64)


In [44]:
print('Visualize attributions based on Integrated Gradients')
visualization.visualize_text(vis_data_records_ig)

Visualize attributions based on Integrated Gradients


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
pos,pos (0.98),pos,1.85,It was a fantastic performance ! pad
,,,,
pos,neg (0.01),pos,1.12,Best film ever pad pad pad pad
,,,,
pos,neg (0.35),pos,1.57,Such a great show! pad pad pad
,,,,
neg,neg (0.00),pos,0.06,It was a horrible movie pad pad
,,,,
neg,neg (0.43),pos,1.65,I've never watched something as bad pad
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
pos,pos (0.98),pos,1.85,It was a fantastic performance ! pad
,,,,
pos,neg (0.01),pos,1.12,Best film ever pad pad pad pad
,,,,
pos,neg (0.35),pos,1.57,Such a great show! pad pad pad
,,,,
neg,neg (0.00),pos,0.06,It was a horrible movie pad pad
,,,,
neg,neg (0.43),pos,1.65,I've never watched something as bad pad
,,,,


## CNN + Embeddings + Frozen N first epochs



In [45]:
cnn_and_frozen_embeddings_model = CNN(vocab_size=vocab_size, emb_dim=dim, out_channels=out_channels, kernel_sizes=kernel_sizes, dropout=dropout)

word_embeddings = TEXT.vocab.vectors

prev_shape = cnn_and_frozen_embeddings_model.embedding.weight.shape

cnn_and_frozen_embeddings_model.embedding.weight.data = torch.clone(word_embeddings)

assert prev_shape == cnn_and_frozen_embeddings_model.embedding.weight.shape
cnn_and_frozen_embeddings_model.to(device)

opt = torch.optim.Adam(cnn_and_frozen_embeddings_model.parameters(), lr=1e-4)
loss_func = nn.BCEWithLogitsLoss()

max_epochs = 50
patience = 15

train_cnn(cnn_and_frozen_embeddings_model, loss_func, train_iter, val_iter, max_epochs, patience, num_freeze_epochs=20)

HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 1, Training Loss: 0.6767279505729675, Validation Loss: 0.6421321034431458


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 2, Training Loss: 0.6141664385795593, Validation Loss: 0.5650098919868469


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 3, Training Loss: 0.5261306762695312, Validation Loss: 0.47656527161598206


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 4, Training Loss: 0.4586566686630249, Validation Loss: 0.42748555541038513


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 5, Training Loss: 0.42121607065200806, Validation Loss: 0.4015980660915375


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 6, Training Loss: 0.39548027515411377, Validation Loss: 0.38501882553100586


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 7, Training Loss: 0.3787156343460083, Validation Loss: 0.37361207604408264


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 8, Training Loss: 0.36894601583480835, Validation Loss: 0.36509644985198975


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 9, Training Loss: 0.35644420981407166, Validation Loss: 0.3579086661338806


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 10, Training Loss: 0.3496182858943939, Validation Loss: 0.3519825041294098


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 11, Training Loss: 0.3371344208717346, Validation Loss: 0.3468288481235504


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 12, Training Loss: 0.3277920186519623, Validation Loss: 0.34187641739845276


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 13, Training Loss: 0.32094889879226685, Validation Loss: 0.3384256958961487


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 14, Training Loss: 0.3118315041065216, Validation Loss: 0.334312379360199


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 15, Training Loss: 0.3036172688007355, Validation Loss: 0.3307430148124695


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 16, Training Loss: 0.29895299673080444, Validation Loss: 0.32824885845184326


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 17, Training Loss: 0.292708158493042, Validation Loss: 0.3250955045223236


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 18, Training Loss: 0.28394633531570435, Validation Loss: 0.3225109875202179


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 19, Training Loss: 0.27842479944229126, Validation Loss: 0.3203069567680359


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 20, Training Loss: 0.2723568379878998, Validation Loss: 0.31801366806030273


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 21, Training Loss: 0.2643948495388031, Validation Loss: 0.31352129578590393


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 22, Training Loss: 0.2518346309661865, Validation Loss: 0.3098350167274475


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 23, Training Loss: 0.23830914497375488, Validation Loss: 0.3071346879005432


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 24, Training Loss: 0.2230561226606369, Validation Loss: 0.3037700057029724


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 25, Training Loss: 0.21227459609508514, Validation Loss: 0.3005315661430359


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 26, Training Loss: 0.19739606976509094, Validation Loss: 0.29796603322029114


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 27, Training Loss: 0.18485459685325623, Validation Loss: 0.2954501807689667


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 28, Training Loss: 0.1732412576675415, Validation Loss: 0.29324138164520264


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 29, Training Loss: 0.16174115240573883, Validation Loss: 0.29129117727279663


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 30, Training Loss: 0.14885404706001282, Validation Loss: 0.2894691228866577


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 31, Training Loss: 0.13797134160995483, Validation Loss: 0.28846704959869385


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 32, Training Loss: 0.12843380868434906, Validation Loss: 0.2867271304130554


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 33, Training Loss: 0.11915233731269836, Validation Loss: 0.2870180010795593


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 34, Training Loss: 0.10873080790042877, Validation Loss: 0.2850891351699829


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 35, Training Loss: 0.10037156194448471, Validation Loss: 0.2849443256855011


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 36, Training Loss: 0.09172287583351135, Validation Loss: 0.28484204411506653


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 37, Training Loss: 0.08283637464046478, Validation Loss: 0.28615668416023254


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 38, Training Loss: 0.07585844397544861, Validation Loss: 0.2845083773136139


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 39, Training Loss: 0.07033748179674149, Validation Loss: 0.287023663520813


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 40, Training Loss: 0.06372224539518356, Validation Loss: 0.2853008508682251


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 41, Training Loss: 0.05621509999036789, Validation Loss: 0.288138747215271


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 42, Training Loss: 0.05207853019237518, Validation Loss: 0.2900582551956177


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 43, Training Loss: 0.04726078733801842, Validation Loss: 0.2891470193862915


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 44, Training Loss: 0.0423966683447361, Validation Loss: 0.2912187874317169


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 45, Training Loss: 0.037737924605607986, Validation Loss: 0.29307907819747925


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 46, Training Loss: 0.034464605152606964, Validation Loss: 0.2945287525653839


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 47, Training Loss: 0.031634438782930374, Validation Loss: 0.2983073592185974


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 48, Training Loss: 0.028561849147081375, Validation Loss: 0.2984580993652344


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 49, Training Loss: 0.02543807215988636, Validation Loss: 0.30106598138809204


HBox(children=(FloatProgress(value=0.0, max=137.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 50, Training Loss: 0.02351159229874611, Validation Loss: 0.3031412959098816


In [46]:
calc_f1_score(cnn_and_frozen_embeddings_model, test_iter, lambda model, batch: model(batch.text))

0.8806599974632263

## Score для всех эксперментов:
- RNN (LSTM) - 0.8374814987182617
- CNN - 0.8496467471122742
- CNN + Glove Embeddings - 0.872462272644043
- CNN + Glove Embeddings (frozen first N epochs) - 0.8806599974632263

## Выводы
- CNN работает лучше чем RNN
- Предобученные эмбеддинги помогают улучшить результат
- Замораживание весов эмбеддингов на первые N эпох помогает улучшить результат еще лучше