In [None]:
# Install the third-party packages used below: youtokentome (BPE tokenizer),
# rouge (ROUGE metric) and razdel (sentence splitter).
# NOTE(review): versions are not pinned; consider `%pip install pkg==x.y.z`
# so the install targets the running kernel and the notebook is reproducible.
!pip install youtokentome rouge razdel

In [None]:
import os
import youtokentome as yttm
import razdel
import random
import math
import copy
import torch
import numpy as np

from tqdm import tqdm
from collections import Counter
from typing import List, Tuple
from rouge import Rouge
from torch.utils import data

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

In [None]:
# Positional split of the corpus: the first 100000 rows for training, the next
# 15000 for validation and the following 15000 for testing.
# NOTE(review): `cleaned_dataset` is not defined in the visible part of this
# notebook, so it must be created before this cell runs. Later cells access it
# with [['Text', 'Head_title']] and .iloc, so it is presumably a pandas
# DataFrame with those columns — confirm.
train_records = cleaned_dataset[:100000]
val_records = cleaned_dataset[100000:115000]
test_records = cleaned_dataset[115000:130000]

Utils: score calculation

In [None]:
def calc_scores(references, predictions, metric="all"):
    """Print corpus-level quality metrics for predicted texts.

    Args:
        references: sequence of reference strings (must be non-empty, the
            last element is printed).
        predictions: sequence of predicted strings aligned with ``references``
            (must be non-empty, the last element is printed).
        metric: ``"bleu"``, ``"rouge"`` or ``"all"``; any other value prints
            only the header lines.

    Returns:
        None. Everything is written to stdout.
    """
    want_bleu = metric in ("bleu", "all")
    want_rouge = metric in ("rouge", "all")

    print("Count:", len(predictions))
    print("Last true headline:", references[-1])
    print("Last predicted headline:", predictions[-1])

    if want_bleu:
        # NOTE(review): `corpus_bleu` is not imported in the visible part of
        # the notebook; it has to be in scope for this branch to run.
        # Every reference is wrapped in its own single-element list.
        wrapped_references = [[reference] for reference in references]
        print("\nBLEU: ", corpus_bleu(wrapped_references, predictions))
    if want_rouge:
        averaged_scores = Rouge().get_scores(predictions, references, avg=True)
        # One "name:value" line per ROUGE variant.
        report = "".join(
            "\n" + str(score_name) + ":" + str(score_value)
            for score_name, score_value in averaged_scores.items()
        )
        print("ROUGE: ", report, "\n")

Byte Pair Encoding (BPE)

In [None]:
def train_bpe(records, model_path, model_type="bpe", vocab_size=10000, lower=True):
    """Train a YouTokenToMe BPE model on the texts and titles of ``records``.

    Every (text, title) pair where both values are non-empty is written to the
    file ``temp.txt`` in the working directory, one string per line, and that
    file is then used as the training corpus.

    Args:
        records: table accessed as ``records[['Text', 'Head_title']].values``
            (presumably a pandas DataFrame — confirm against the caller).
        model_path: path where the trained model is stored.
        model_type: kept for interface compatibility; not used.
        vocab_size: vocabulary size passed to the trainer.
        lower: lowercase texts and titles before writing them.
    """
    temp_file_name = "temp.txt"
    # Explicit UTF-8: without it the file is written in the platform's locale
    # encoding, which can fail or garble non-ASCII (e.g. Cyrillic) text.
    with open(temp_file_name, "w", encoding="utf-8") as temp:
        for text, title in tqdm(records[['Text', 'Head_title']].values):
            # Check emptiness first so that a None value is skipped instead of
            # failing on .lower().
            if not text or not title:
                continue
            if lower:
                title = title.lower()
                text = text.lower()
            temp.write(text + "\n")
            temp.write(title + "\n")
    yttm.BPE.train(data=temp_file_name, vocab_size=vocab_size, model=model_path)

In [None]:
# Train the BPE model on the training split and save it as "BPE_model.bin".
train_bpe(train_records, "BPE_model.bin")

In [None]:
# Load the trained BPE model and, as a quick check, show how a sample
# sentence is split into subword units.
bpe_processor = yttm.BPE('BPE_model.bin')
bpe_processor.encode(["шустрая бурая лиса прыгает через ленивого пса"], output_type=yttm.OutputType.SUBWORD)

Составим словарь для индексации токенов

In [None]:
class Vocabulary:
    """Token <-> index mapping built from a BPE processor's vocabulary.

    ``index2word`` is the sequence returned by ``bpe_processor.vocab()``;
    ``word2index`` is its inverse. ``word2count`` is an (initially empty)
    counter that this class itself never updates.
    """

    def __init__(self, bpe_processor):
        self.index2word = bpe_processor.vocab()
        self.word2index = {w: i for i, w in enumerate(self.index2word)}
        self.word2count = Counter()

    def get_pad(self):
        """Return the index of the padding token ``<PAD>``."""
        return self.word2index["<PAD>"]

    def get_sos(self):
        """Return the index of the start-of-sequence token.

        ``<SOS>`` is preferred; ``<BOS>`` is used as a fallback because that is
        the name YouTokenToMe gives its begin-of-sequence token. Raises
        ``KeyError`` when neither is present.
        """
        if "<SOS>" in self.word2index:
            return self.word2index["<SOS>"]
        return self.word2index["<BOS>"]

    def get_eos(self):
        """Return the index of the end-of-sequence token ``<EOS>``."""
        return self.word2index["<EOS>"]

    def get_unk(self):
        """Return the index of the unknown token ``<UNK>``."""
        return self.word2index["<UNK>"]

    def has_word(self, word) -> bool:
        """Tell whether ``word`` is in the vocabulary."""
        return word in self.word2index

    def get_index(self, word):
        """Return the index of ``word``, or the ``<UNK>`` index if absent."""
        if word in self.word2index:
            return self.word2index[word]
        return self.get_unk()

    def get_word(self, index):
        """Return the token stored at ``index``."""
        return self.index2word[index]

    def size(self):
        """Return the number of tokens in the vocabulary."""
        return len(self.index2word)

    def is_empty(self):
        """Tell whether the vocabulary holds at most the four special tokens."""
        empty_size = 4
        return self.size() <= empty_size

    def reset(self):
        """Drop all tokens and counts, keeping only the four special tokens.

        The special tokens use the same upper-case spelling the getters look
        up, so ``get_pad``/``get_sos``/``get_eos``/``get_unk`` keep working
        after a reset.
        """
        self.word2count = Counter()
        self.index2word = ["<PAD>", "<SOS>", "<EOS>", "<UNK>"]
        self.word2index = {word: index for index, word in enumerate(self.index2word)}

In [None]:
# Build the token index from the trained BPE model and display its size.
vocabulary = Vocabulary(bpe_processor)
vocabulary.size()

Кэш oracle summary.
Закэшируем oracle summary, чтобы не пересчитывать их каждый раз.

In [None]:
def add_oracle_summary_to_records(records, max_sentences=30, lower=True, nrows=1000):
    """Attach sentence splits and greedy oracle summaries to the first rows.

    Args:
        records: table with ``Text`` and ``Head_title`` columns, sliced with
            ``.iloc`` (presumably a pandas DataFrame — confirm).
        max_sentences: keep at most this many leading sentences per text.
        lower: lowercase titles and sentences.
        nrows: number of leading rows to process.

    Returns:
        A copy of the first ``nrows`` rows with three added columns:
        ``sentences`` (list of sentence strings), ``oracle_sentences`` (list of
        selected sentence indices) and ``oracle_summary`` (joined string).
    """
    rouge = Rouge()
    records = records.iloc[:nrows].copy()

    def score_against_gold(candidate, gold):
        # Reuse one Rouge instance for every scoring call.
        return calc_single_score(candidate, gold, rouge)

    per_record_sentences = []
    per_record_oracle_indices = []
    per_record_oracle_summaries = []

    for text, title in tqdm(records[['Text', 'Head_title']].values):
        if lower:
            title = title.lower()

        split_text = []
        for sentence in razdel.sentenize(text):
            split_text.append(sentence.text.lower() if lower else sentence.text)
        split_text = split_text[:max_sentences]

        summary, chosen_indices = build_oracle_summary_greedy(
            text,
            title,
            calc_score=score_against_gold,
            lower=lower,
            max_sentences=max_sentences,
        )

        per_record_sentences.append(split_text)
        per_record_oracle_indices.append(list(chosen_indices))
        per_record_oracle_summaries.append(summary)

    records['sentences'] = per_record_sentences
    records['oracle_sentences'] = per_record_oracle_indices
    records['oracle_summary'] = per_record_oracle_summaries
    return records


def build_oracle_summary_greedy(text, gold_summary, calc_score, lower=True, max_sentences=30):
    '''
    Greedily build an oracle summary for ``text``.

    Sentences are added one at a time, each time picking the candidate whose
    joined text gets the highest ``calc_score(candidate, gold_summary)``. The
    search stops when no addition improves the score, and never selects more
    than two sentences.

    Args:
        text: source document; split into sentences with ``razdel.sentenize``.
        gold_summary: reference string the candidates are scored against.
        calc_score: callable ``(candidate_summary, gold_summary) -> score``;
            scores are compared with ``<=`` and via ``max``.
        lower: lowercase the gold summary and every sentence.
        max_sentences: only the first ``max_sentences`` sentences are considered.

    Returns:
        ``(oracle_summary, oracle_summary_sentences)``: the selected sentences
        joined with spaces in document order, and the set of their indices.
        For a text without sentences this is ``("", set())``.
    '''
    gold_summary = gold_summary.lower() if lower else gold_summary
    # Split the text into sentences
    sentences = [sentence.text.lower() if lower else sentence.text for sentence in razdel.sentenize(text)][:max_sentences]
    n_sentences = len(sentences)
    oracle_summary_sentences = set()
    # Start below any score calc_score is expected to return, so the first round always selects something.
    score = -1.0
    # NOTE: this list is not cleared between rounds, so max() below also sees
    # the candidates of earlier rounds.
    summaries = []
    for _ in range(min(n_sentences, 2)):
        for i in range(n_sentences):
            if i in oracle_summary_sentences:
                continue
            current_summary_sentences = copy.copy(oracle_summary_sentences)
            # Add one more sentence to the summary built so far
            current_summary_sentences.add(i)
            current_summary = " ".join([sentences[index] for index in sorted(list(current_summary_sentences))])
            # Compute the score
            current_score = calc_score(current_summary, gold_summary)
            summaries.append((current_score, current_summary_sentences))
        # If adding some sentence improved the score, try to add one more;
        # otherwise stop here.
        # Candidates are (score, index_set) tuples, so equal scores fall back
        # to comparing the index sets.
        best_summary_score, best_summary_sentences = max(summaries)
        if best_summary_score <= score:
            break
        oracle_summary_sentences = best_summary_sentences
        score = best_summary_score
    oracle_summary = " ".join([sentences[index] for index in sorted(list(oracle_summary_sentences))])
    return oracle_summary, oracle_summary_sentences


def calc_single_score(pred_summary, gold_summary, rouge):
    """Return the ROUGE-2 F value of one predicted summary against one reference.

    ``rouge`` must provide ``get_scores(hyps, refs, avg=True)`` returning a
    mapping that contains ``['rouge-2']['f']``.
    """
    hypotheses = [pred_summary]
    references = [gold_summary]
    averaged = rouge.get_scores(hypotheses, references, avg=True)
    rouge_2 = averaged['rouge-2']
    return rouge_2['f']

In [None]:
# Cache sentence splits and oracle summaries for the first 10000 training
# rows and the first 4000 validation and test rows of each split.
ext_train_records = add_oracle_summary_to_records(train_records, nrows=10000)
ext_val_records = add_oracle_summary_to_records(val_records, nrows=4000)
ext_test_records = add_oracle_summary_to_records(test_records, nrows=4000)

Составление батчей

In [None]:
class ExtDataset(data.Dataset):
    """Dataset of documents for sentence-level extractive tagging.

    Each item is a dict with ``'inputs'`` (one list of BPE token ids per
    sentence, truncated to ``max_sentence_length``) and ``'outputs'`` (one
    0/1 label per sentence: 1 if its index is in ``oracle_sentences``).

    ``records`` must expose ``.shape`` and ``.iloc`` and provide the
    ``sentences`` and ``oracle_sentences`` fields. ``lower``, ``rouge``,
    ``vocabulary``, ``max_sentences`` and ``device`` are stored but not used
    by the methods of this class.
    """

    def __init__(self, records, vocabulary, bpe_processor, lower=True, max_sentences=30, max_sentence_length=50, device=torch.device('cpu')):
        # Data.
        self.records = records
        self.num_samples = records.shape[0]
        # Text-processing helpers.
        self.bpe_processor = bpe_processor
        self.vocabulary = vocabulary
        self.rouge = Rouge()
        # Options.
        self.lower = lower
        self.max_sentences = max_sentences
        self.max_sentence_length = max_sentence_length
        self.device = device

    def __len__(self):
        row_count = self.records.shape[0]
        return row_count

    def __getitem__(self, idx):
        row = self.records.iloc[idx]
        sentences = row['sentences']
        token_limit = self.max_sentence_length

        encoded = self.bpe_processor.encode(sentences, output_type=yttm.OutputType.ID)
        inputs = [token_ids[:token_limit] for token_ids in encoded]

        oracle_indices = row['oracle_sentences']
        outputs = []
        for position in range(len(sentences)):
            outputs.append(int(position in oracle_indices))
        return {'inputs': inputs, 'outputs': outputs}

In [None]:
# Dataset over the training records with cached oracle labels (used below for a quick inspection).
train_dataset = ExtDataset(ext_train_records, vocabulary, bpe_processor)

In [None]:
# Show one item: per-sentence token ids ('inputs') and 0/1 oracle labels ('outputs').
print(train_dataset[0])

In [None]:
def collate_fn(records):
    """Pad a list of dataset items into dense batch tensors.

    Args:
        records: list of dicts with ``'inputs'`` (list of token-id lists, one
            per sentence) and ``'outputs'`` (list of 0/1 labels, one per
            sentence).

    Returns:
        dict with
        ``'features'``: int64 tensor of shape (batch, max_sentences, max_length),
        zero-padded along both the sentence and the token axes;
        ``'targets'``: float tensor of shape (batch, max_sentences),
        zero-padded along the sentence axis.
    """
    # default=0 keeps an empty batch, or a batch of sentence-less documents,
    # from failing on max() of an empty sequence.
    max_length = max((len(sentence) for record in records for sentence in record['inputs']), default=0)
    max_sentences = max((len(record['outputs']) for record in records), default=0)

    # Token ids are integers, so allocate the integer tensor directly instead
    # of filling a float tensor and casting it afterwards.
    new_inputs = torch.zeros((len(records), max_sentences, max_length), dtype=torch.long)
    new_outputs = torch.zeros((len(records), max_sentences))
    for i, record in enumerate(records):
        for j, sentence in enumerate(record['inputs']):
            new_inputs[i, j, :len(sentence)] = torch.as_tensor(sentence, dtype=torch.long)
        labels = record['outputs']
        new_outputs[i, :len(labels)] = torch.as_tensor(labels, dtype=torch.float)
    return {'features': new_inputs, 'targets': new_outputs}

Model RNN

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack

In [None]:
class SentenceEncoderRNN(nn.Module):
    """Encode token-id sequences into fixed-size vectors with an LSTM.

    Tokens are embedded, run through a (by default bidirectional) multi-layer
    LSTM, and the LSTM outputs are averaged over the token axis.

    ``hidden_size`` is the size of the resulting vector; it is divided by the
    number of directions to get the per-direction LSTM size, which is what
    ``self.hidden_size`` stores. ``dropout_layer`` is created but not applied
    in ``forward``.
    """

    def __init__(self, input_size, embedding_dim, hidden_size, n_layers=3, dropout=0.3, bidirectional=True):
        super().__init__()

        directions = 2 if bidirectional else 1
        assert hidden_size % directions == 0
        per_direction_size = hidden_size // directions

        self.input_size = input_size
        self.embedding_dim = embedding_dim
        self.hidden_size = per_direction_size
        self.n_layers = n_layers
        self.dropout = dropout
        self.bidirectional = bidirectional

        # Submodules are created in the same order as before so that seeded
        # parameter initialisation is unchanged.
        self.embedding_layer = nn.Embedding(input_size, embedding_dim)
        self.rnn_layer = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=per_direction_size,
            num_layers=n_layers,
            dropout=dropout,
            bidirectional=bidirectional,
            batch_first=True,
        )
        self.dropout_layer = nn.Dropout(dropout)

    def forward(self, inputs, hidden=None):
        """Return the mean LSTM output over dimension 1 for each input row."""
        token_embeddings = self.embedding_layer(inputs)
        token_states, _ = self.rnn_layer(token_embeddings, hidden)
        # Plain mean over the token axis (every position is included).
        return token_states.mean(dim=1)

In [None]:
class SentenceTaggerRNN(nn.Module):
    """Hierarchical LSTM that scores every sentence of a document.

    Each sentence is encoded from its tokens by ``SentenceEncoderRNN``; a
    document-level LSTM then runs over the sentence vectors. The score of a
    sentence is the sum of a "content" term (linear in its LSTM state) and a
    "salience" term (dot product of its state with a projection of the
    document embedding).

    Args:
        vocabulary_size: number of token ids for the embedding table.
        token_embedding_dim: token embedding size.
        sentence_encoder_hidden_size: size of a sentence vector.
        hidden_size: size of a document-LSTM output vector (split across
            directions when ``bidirectional``).
        bidirectional: whether the document LSTM is bidirectional.
        sentence_encoder_n_layers / sentence_encoder_dropout /
        sentence_encoder_bidirectional: settings of the sentence encoder.
        n_layers: number of document-LSTM layers.
        dropout: dropout probability for the document LSTM and its outputs.
    """

    def __init__(self,
                 vocabulary_size,
                 token_embedding_dim=256,
                 sentence_encoder_hidden_size=256,
                 hidden_size=256,
                 bidirectional=True,
                 sentence_encoder_n_layers=2,
                 sentence_encoder_dropout=0.3,
                 sentence_encoder_bidirectional=True,
                 n_layers=1,
                 dropout=0.3):
        super(SentenceTaggerRNN, self).__init__()

        num_directions = 2 if bidirectional else 1
        assert hidden_size % num_directions == 0
        hidden_size = hidden_size // num_directions
        # Width of one document-LSTM output vector. Using num_directions
        # instead of a hard-coded 2 keeps the linear layers consistent with
        # the LSTM when bidirectional=False.
        rnn_output_size = hidden_size * num_directions

        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout = dropout
        self.bidirectional = bidirectional

        self.sentence_encoder = SentenceEncoderRNN(vocabulary_size, token_embedding_dim,
                                                   sentence_encoder_hidden_size, sentence_encoder_n_layers,
                                                   sentence_encoder_dropout, sentence_encoder_bidirectional)
        # nn.LSTM applies dropout only between stacked layers, so it has no
        # effect (and triggers a warning) with a single layer.
        self.rnn_layer = nn.LSTM(sentence_encoder_hidden_size, hidden_size, n_layers,
                                 dropout=dropout if n_layers > 1 else 0.0,
                                 bidirectional=bidirectional, batch_first=True)
        self.dropout_layer = nn.Dropout(dropout)
        self.content_linear_layer = nn.Linear(rnn_output_size, 1)
        self.document_linear_layer = nn.Linear(rnn_output_size, rnn_output_size)
        self.salience_linear_layer = nn.Linear(rnn_output_size, rnn_output_size)
        self.tanh_layer = nn.Tanh()

    def forward(self, inputs, hidden=None):
        """Return one logit per sentence.

        Args:
            inputs: 3-D tensor of token ids indexed as
                (batch, sentence, token).
            hidden: optional initial state for the document LSTM.

        Returns:
            Tensor of shape (batch, sentences) with unnormalised scores.
        """
        batch_size = inputs.size(0)
        sentences_count = inputs.size(1)
        tokens_count = inputs.size(2)
        # Encode all sentences of all documents in one pass.
        inputs = inputs.reshape(-1, tokens_count)
        embedded_sentences = self.sentence_encoder(inputs)
        embedded_sentences = embedded_sentences.reshape(batch_size, sentences_count, -1)
        outputs, _ = self.rnn_layer(embedded_sentences, hidden)
        outputs = self.dropout_layer(outputs)
        # Document embedding: projected mean of the sentence states.
        document_embedding = self.tanh_layer(self.document_linear_layer(torch.mean(outputs, 1)))
        content = self.content_linear_layer(outputs).squeeze(2)
        # Batched dot product of every sentence state with the projected document embedding.
        salience = torch.bmm(outputs, self.salience_linear_layer(document_embedding).unsqueeze(2)).squeeze(2)
        return content + salience

Trainer

In [None]:
def fit_epoch(model, train_loader, criterion, optimizer):
    """Run one training pass over ``train_loader``.

    Each batch is a dict with ``'features'`` and ``'targets'`` tensors, which
    are moved to the global ``DEVICE``. Returns the mean of the per-batch
    losses over the epoch.
    """
    model.train()
    epoch_loss = 0
    for batch in tqdm(train_loader):
        features = batch['features'].to(DEVICE)
        labels = batch['targets'].to(DEVICE)

        optimizer.zero_grad()
        batch_loss = criterion(model(features), labels)
        batch_loss.backward()
        optimizer.step()

        # Each batch contributes loss / number_of_batches to the epoch mean.
        epoch_loss += batch_loss.item() / len(train_loader)
    return epoch_loss



def eval_epoch(model, val_loader, criterion):
    """Compute the mean per-batch loss over ``val_loader`` without gradients.

    The model is switched to evaluation mode. Each batch is a dict with
    ``'features'`` and ``'targets'`` tensors, which are moved to the global
    ``DEVICE``. Returns 0 for an empty loader.
    """
    model.eval()
    epoch_loss = 0
    for batch in val_loader:
        features = batch['features'].to(DEVICE)
        labels = batch['targets'].to(DEVICE)

        with torch.no_grad():
            batch_loss = criterion(model(features), labels)
            # Each batch contributes loss / number_of_batches to the mean.
            epoch_loss += batch_loss.item() / len(val_loader)
    return epoch_loss



def train(model, opt, criterion, epochs, train_loader, val_loader, verbose=True):
    """Train ``model`` for ``epochs`` epochs and collect the loss history.

    Returns:
        A list with one ``(train_loss, val_loss)`` tuple per epoch. With
        ``verbose`` the epoch number and the training loss are printed.
    """
    history = []
    for epoch_number in range(1, epochs + 1):
        if verbose:
            print('* Epoch %d/%d' % (epoch_number, epochs))

        # One optimisation pass, then a gradient-free pass on validation data.
        train_loss = fit_epoch(model, train_loader, criterion, opt)
        val_loss = eval_epoch(model, val_loader, criterion)

        # Keep both losses so they can be plotted later.
        history.append((train_loss, val_loss))

        if verbose:
            print('Train loss: %f' % train_loss)
    return history

Тренировка модели

In [None]:
# Batch the training and validation records 200 documents at a time;
# collate_fn pads every batch to its longest document and sentence.
# No `shuffle` argument is passed, so both loaders use the DataLoader default order.
train_loader = data.DataLoader(ExtDataset(ext_train_records, vocabulary, bpe_processor=bpe_processor), batch_size=200, collate_fn=collate_fn)
val_loader = data.DataLoader(ExtDataset(ext_val_records, vocabulary, bpe_processor=bpe_processor), batch_size=200, collate_fn=collate_fn)

In [None]:
# Fix the CPU and CUDA random seeds and request deterministic cuDNN
# algorithms so that repeated runs are comparable.
torch.manual_seed(13)
torch.cuda.manual_seed(13)
torch.backends.cudnn.deterministic = True
model = SentenceTaggerRNN(vocabulary.size()).to(DEVICE)

# Train for 5 epochs with Adam (lr=1e-3). BCEWithLogitsLoss is given the raw
# per-sentence logits returned by the model.
# NOTE(review): the zero-padded sentence slots added by collate_fn are not
# masked, so they also enter the loss with target 0.
history = train(model = model,
                opt = torch.optim.Adam(model.parameters(), lr=1e-3),
                criterion = nn.BCEWithLogitsLoss(),
                epochs = 5,
                train_loader = train_loader,
                val_loader = val_loader,
                verbose = True)

Score model

In [None]:
# Build an extractive summary for every test document and score it.
references = []
predictions = []
model.eval()
test_loader = data.DataLoader(ExtDataset(ext_test_records, vocabulary, bpe_processor=bpe_processor), batch_size=1, collate_fn=collate_fn)
# Inference only: no gradients are needed.
with torch.no_grad():
    for record_index, item in tqdm(enumerate(test_loader), total=ext_test_records.shape[0]):
        # batch_size=1, so [0] takes the logits of the single document.
        # The tensor goes to DEVICE, the device the model was moved to.
        logits = model(item["features"].to(DEVICE))[0]  # forward
        record = ext_test_records.iloc[record_index]
        # Keep every sentence with a positive logit (probability above 0.5).
        predicted_summary = []
        for sentence_index, logit in enumerate(logits):
            if logit > 0.0:
                predicted_summary.append(record['sentences'][sentence_index])
        # If nothing was selected, fall back to the highest-scoring sentence.
        if not predicted_summary:
            predicted_summary.append(record['sentences'][torch.max(logits, dim=0)[1].item()])
        predicted_summary = " ".join(predicted_summary)
        # NOTE(review): the rest of the notebook reads only the 'Text' and
        # 'Head_title' columns; confirm that a 'summary' column exists, or
        # whether 'Head_title' was intended as the reference here.
        references.append(record['summary'].lower())
        predictions.append(predicted_summary)

calc_scores(references, predictions)