## Paper
* Dataset for Automatic Summarization of Russian News: https://arxiv.org/abs/2006.11063
* Download dataset: https://github.com/IlyaGusev/gazeta
* Summarization models: https://github.com/IlyaGusev/summarus

## Requirements

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load
!pip install jsonlines
!pip install transformers
!pip install razdel
!pip install rouge
!pip install sentence-transformers

import json
import random
import nltk
import sys
import torch
import numpy as np
import itertools
import torch.nn as nn
nltk.download('punkt')

from transformers.optimization import AdamW
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, BertTokenizer
from torch.utils.data.dataset import Dataset

Collecting jsonlines
  Downloading https://files.pythonhosted.org/packages/d4/58/06f430ff7607a2929f80f07bfd820acbc508a4e977542fefcc522cde9dff/jsonlines-2.0.0-py3-none-any.whl
Installing collected packages: jsonlines
Successfully installed jsonlines-2.0.0
Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/fd/1a/41c644c963249fd7f3836d926afa1e3f1cc234a1c40d80c5f03ad8f6f1b2/transformers-4.8.2-py3-none-any.whl (2.5MB)
[K     |████████████████████████████████| 2.5MB 7.7MB/s 
Collecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/d4/e2/df3543e8ffdab68f5acc73f613de9c2b155ac47f162e725dcac87c521c11/tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3MB)
[K     |████████████████████████████████| 3.3MB 51.6MB/s 
Collecting huggingface-hub==0.0.12
  Downloading https://files.pythonhosted.org/packages/2f/ee/97e253668fda9b17e968b3f97b2f8e53aa0127e8807d24a54768

## Data

In [None]:
# !rm -f gazeta_jsonl.tar.gz
# !wget -nc https://www.dropbox.com/s/cmpfvzxdknkeal4/gazeta_jsonl.tar.gz
# !tar -xzvf gazeta_jsonl.tar.gz
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
def read_gazeta_records(file_name, shuffle=True, sort_by_date=False):
    """Load records from a JSON file and return them shuffled or date-sorted.

    Exactly one of ``shuffle`` and ``sort_by_date`` must be true.

    Args:
        file_name: Path to a JSON file whose top-level value is a sequence
            of records.
        shuffle: If true, shuffle the records with ``random.shuffle``.
        sort_by_date: If true, sort the records by their ``"date"`` field.

    Returns:
        A new list with the loaded records.

    Raises:
        AssertionError: If ``shuffle`` and ``sort_by_date`` are equal.
    """
    assert shuffle != sort_by_date, \
        "exactly one of shuffle / sort_by_date must be enabled"
    # The corpus is Russian text: decode as UTF-8 explicitly instead of
    # depending on the platform's default locale encoding.
    with open(file_name, "r", encoding="utf-8") as r:
        records = list(json.load(r))
    if sort_by_date:
        records.sort(key=lambda x: x["date"])
    if shuffle:
        random.shuffle(records)
    return records

# Load the three dataset splits from the mounted drive; each call relies on
# read_gazeta_records' default of shuffle=True.
train_records = read_gazeta_records("drive/MyDrive/ds/train_edited.json")
val_records = read_gazeta_records("drive/MyDrive/ds/val_edited.json")
test_records = read_gazeta_records("drive/MyDrive/ds/test_edited.json")

In [None]:
# Keep the sentence encoder under its own name: the notebook later rebinds
# `model` to the classifier, which would otherwise break get_embeddings.
sentence_encoder = SentenceTransformer('distiluse-base-multilingual-cased')
# Backward-compatible alias for code that still refers to `model`.
model = sentence_encoder

def get_embeddings(sentence):
    """Return the sentence encoder's embeddings for ``sentence``.

    ``sentence`` is passed unchanged to ``SentenceTransformer.encode``, and
    the result of that call is returned as is.
    """
    return sentence_encoder.encode(sentence)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=690.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2821.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=607.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=122.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=341.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=538971577.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=53.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=112.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1961847.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=528.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=995526.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=190.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=114.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1575975.0, style=ProgressStyle(descript…




In [None]:
# Split text into sentences
def split_text(record, key='text'):
    """Lazily yield the sentences of ``record[key]``.

    Sentence boundaries are the ones reported by ``nltk.sent_tokenize``.
    """
    yield from nltk.sent_tokenize(record[key])

In [None]:
# Load the multilingual BERT tokenizer; later cells use it through this
# module-level name.
tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
# Print every field of one validation record as a quick visual check.
for key, value in val_records[10].items():
    print(key, value)
# Tokenize that record's text and recover the token strings from the ids.
tokenized_data = tokenizer(val_records[10]["text"])
token_indexes = tokenized_data["input_ids"]
tokens = tokenizer.convert_ids_to_tokens(token_indexes)
print("")
# Show each token id next to its token string.
for index, token in zip(token_indexes, tokens):
    print(index, token)

def cut_long_texts(dataset):
    """Truncate, in place, every record whose text exceeds 512 tokens.

    Each record's ``"text"`` is tokenized with the module-level
    ``tokenizer``. When the tokenization has more than 512 entries, tokens
    1..508 are kept, ``'.'`` and ``'[SEP]'`` are appended, and the result is
    converted back to a string that replaces the record's ``"text"``.
    Progress is reported on stdout and the number of truncated records is
    printed at the end.

    Args:
        dataset: Sequence of dict records with a ``"text"`` key; modified
            in place.
    """
    # Imported locally: the notebook's top-level `import tqdm` lives in a
    # later cell, so relying on it raises NameError when cells run in order.
    import tqdm

    counter = 0
    with tqdm.tqdm(total=len(dataset), file=sys.stdout) as pbar:
        for i, article in enumerate(dataset):
            token_indexes = tokenizer(article["text"])["input_ids"]
            if len(token_indexes) > 512:
                counter += 1
                # Token strings are only needed for records being cut.
                tokens = tokenizer.convert_ids_to_tokens(token_indexes)
                # Skip index 0 (presumably the leading [CLS]) and keep 508
                # tokens, then close the text with a full stop and [SEP].
                tokens = tokens[1:509]
                tokens.append('.')
                tokens.append('[SEP]')
                article['text'] = tokenizer.convert_tokens_to_string(tokens)
            pbar.set_description('processed: %d' % (1 + i))
            pbar.update(1)
    print(f"total mistakes: {counter}")

# Switch for the truncation pass; cut_long_texts rewrites the records in place.
need_cut_dataset = True

if need_cut_dataset:
    cut_long_texts(train_records)
    cut_long_texts(test_records)
    cut_long_texts(val_records)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=995526.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=29.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1961828.0, style=ProgressStyle(descript…


text После оглашения вердикта их пустили в конвойное помещение , где они ожидали отправки в СИЗО . [SEP] Признанный виновным в убийстве 93 - летней работницы Мариинского театра Сеник Нароян найден мертвым в конвойном помещении Санкт - Петербургского городского суда , передает ТАСС . [SEP] « Нароян обнаружен мертвым в конвойном помещении Санкт - Петербургского горсуда » , [UNK] передает агентство со ссылкой на источник в правоохранительных органах . [SEP] Накануне коллегия присяжных в Петербурге вынесла Нарояну и еще двум фигурантам дела Галине Гусейновой и Размику Акобяну обвинительный приговор . [SEP] Всех троих признали виновными и не заслуживающими снисхождения . [SEP] После оглашения вердикта их пустили в конвойное помещение , где они ожидали отправки в СИЗО . [SEP] По данным агенства , Гусейнову и Акобяна увезли в изолятор , Нарояна оставили ожидать дальше . [SEP] Ближе к полуночи конвой обнаружил его мертвым . [SEP] По информации « Фонтанки » , вызванные на место происшествия вр

In [None]:
class RTEDataset(Dataset):
    """Map-style dataset that tokenizes a record's text on access.

    Every item is a dict with ``"input_ids"`` (the tokenizer's ids for the
    record's ``"text"``) and, when the record carries a ``"label"``, a
    ``"y"`` entry holding that label unchanged.
    """

    def __init__(self, data, tokenizer, pos_label="entailment"):
        # pos_label is stored but not read by any method of this class.
        self.data, self.tokenizer, self.pos_label = data, tokenizer, pos_label

    def __len__(self):
        """Return the number of underlying records."""
        return len(self.data)

    def __getitem__(self, index):
        """Tokenize record ``index`` and return its model inputs."""
        record = self.data[index]
        encoded = self.tokenizer(record["text"])
        sample = {"input_ids": encoded["input_ids"]}
        if "label" in record:
            sample["y"] = record["label"]
        return sample

In [None]:
# Build the training dataset and print one encoded sample field by field.
train_dataset = RTEDataset(train_records, tokenizer)
# train_dataset = RTEDataset(val_records, tokenizer)
for key, value in train_dataset[10].items():
    print(key, value, sep="\n")

input_ids
[101, 18596, 11448, 16133, 10636, 10122, 43428, 13594, 42207, 117, 30956, 34020, 17257, 14689, 12189, 11602, 10122, 40087, 119, 102, 511, 75780, 10205, 10234, 22605, 48873, 117, 541, 34680, 40585, 46610, 16882, 208, 10439, 11079, 543, 28015, 10205, 220, 119, 102, 14699, 13462, 29627, 38494, 63649, 20298, 117, 12709, 23089, 10385, 23527, 10297, 38355, 103943, 17131, 119, 102, 122, 528, 14736, 12068, 208, 11712, 78790, 102127, 89561, 543, 38494, 63649, 11075, 55015, 11079, 47495, 11905, 220, 119, 102, 18890, 14551, 10956, 33276, 107276, 10191, 543, 528, 108877, 11905, 101674, 12152, 10297, 75780, 11292, 11431, 119, 102, 33933, 10332, 78504, 10191, 117, 10791, 10234, 76674, 84689, 10122, 105765, 10353, 57972, 45709, 118, 557, 80966, 14958, 11075, 28419, 117, 14028, 23807, 33580, 50511, 11602, 96357, 103026, 41779, 549, 13488, 82562, 10384, 32297, 80062, 14360, 84833, 528, 21345, 25956, 22681, 12152, 10297, 75780, 11292, 11431, 119, 102, 208, 511, 26222, 557, 18291, 109558, 68855

In [None]:
def pad_tensor(vec, length, dim, pad_symbol):
    """Extend ``vec`` along ``dim`` to size ``length`` using ``pad_symbol``.

    Example: for ``vec`` of shape [3, 4, 5] with ``length=7`` and ``dim=1``
    a filler block of shape [3, 3, 5] is appended, giving shape [3, 7, 5].
    The filler block is created with dtype ``torch.long``.
    """
    filler_shape = [*vec.shape]
    filler_shape[dim] = length - vec.shape[dim]
    filler = torch.ones(*filler_shape, dtype=torch.long) * pad_symbol
    return torch.cat((vec, filler), dim=dim)

def pad_tensors(tensors, pad=0):
    """Stack a batch of values into one tensor, padding to a common length.

    The type of the first element selects the behavior:

    * integers -> a ``torch.LongTensor`` of the values;
    * floats -> a ``torch.Tensor`` of the values;
    * anything else -> every element is converted to a ``torch.LongTensor``,
      padded along dim 0 with ``pad`` up to the longest element, and the
      results are stacked along a new leading dimension.

    Args:
        tensors: Non-empty sequence of scalars or of integer sequences.
        pad: Value used to fill the shorter sequences.

    Returns:
        A tensor holding the whole batch.

    Raises:
        IndexError: If ``tensors`` is empty.
    """
    first = tensors[0]
    if isinstance(first, (int, np.integer)):
        return torch.LongTensor(tensors)
    # np.float was only an alias of the builtin float and no longer exists
    # in NumPy >= 1.24; np.floating matches NumPy floating-point scalars.
    if isinstance(first, (float, np.floating)):
        return torch.Tensor(tensors)
    tensors = [torch.LongTensor(tensor) for tensor in tensors]
    max_length = max(tensor.shape[0] for tensor in tensors)
    tensors = [pad_tensor(tensor, max_length, dim=0, pad_symbol=pad)
               for tensor in tensors]
    return torch.stack(tensors, dim=0)

class FieldBatchDataLoader:
    """Iterate over a dataset of dict items in padded batches.

    Each batch is a dict with one padded tensor per field of the items
    (moved to ``device``) plus an ``"indexes"`` entry listing the positions
    of the batch's items in ``X``.

    Args:
        X: Indexable dataset whose items are dicts of fields.
        batch_size: Maximum number of items per batch.
        sort_by_length: If true, group items of similar length into the same
            batch and shuffle the batches; otherwise shuffle single items.
        length_field: Field whose length is used for sorting; when ``None``
            the first field of each item is used.
        state: Seed passed to ``np.random.seed``.
        device: Device the batch tensors are moved to.
    """

    def __init__(self, X, batch_size=32, sort_by_length=False,
                 length_field=None, state=115, device="cpu"):
        self.X = X
        self.batch_size = batch_size
        self.sort_by_length = sort_by_length
        self.length_field = length_field
        self.device = device
        # NOTE: this reseeds NumPy's global random generator.
        np.random.seed(state)

    def __len__(self):
        """Return the number of batches per pass (ceil(len(X) / batch_size))."""
        return (len(self.X) - 1) // self.batch_size + 1

    def __iter__(self):
        """Draw a new item order, rewind to the first batch and return self."""
        if self.sort_by_length:
            # Sort item indices by item length, read from the chosen field.
            if self.length_field is not None:
                lengths = [len(x[self.length_field]) for x in self.X]
            else:
                lengths = [len(list(x.values())[0]) for x in self.X]
            order = np.argsort(lengths)
            # Group consecutive indices into batches. They are kept as a
            # list of arrays because the last batch may be shorter, and
            # NumPy >= 1.24 refuses to build an array from ragged rows.
            batches = [order[start:start + self.batch_size]
                       for start in range(0, len(self.X), self.batch_size)]
            if batches:
                # Shuffle every batch except the last one, which stays last.
                permutation = np.random.permutation(len(batches) - 1)
                batches = [batches[j] for j in permutation] + batches[-1:]
            # Flatten back into one sequence of item indices.
            self.order = np.fromiter(
                itertools.chain.from_iterable(batches), dtype=int)
        else:
            self.order = np.arange(len(self.X))
            np.random.shuffle(self.order)
        self.idx = 0
        return self

    def __next__(self):
        """Build and return the next batch; raise StopIteration at the end."""
        if self.idx >= len(self.X):
            raise StopIteration()
        end = min(self.idx + self.batch_size, len(self.X))
        indexes = [self.order[i] for i in range(self.idx, end)]
        batch = dict()
        # Pad and stack every field found in the first item of the batch.
        for field in self.X[indexes[0]]:
            batch[field] = pad_tensors(
                [self.X[i][field] for i in indexes]).to(self.device)
        batch["indexes"] = indexes
        self.idx = end
        return batch

In [None]:
# Smoke test: draw 10 batches and print the shape of every field in each.
train_dataloader = iter(FieldBatchDataLoader(train_dataset, batch_size=8, device="cuda"))
for i in range(10):
    batch = next(train_dataloader)
    for field, data in batch.items():
        print(f"{field}:{np.shape(data)}", end="\t")
    print("")

input_ids:torch.Size([8, 492])	y:torch.Size([8])	indexes:(8,)	
input_ids:torch.Size([8, 487])	y:torch.Size([8])	indexes:(8,)	
input_ids:torch.Size([8, 465])	y:torch.Size([8])	indexes:(8,)	
input_ids:torch.Size([8, 465])	y:torch.Size([8])	indexes:(8,)	
input_ids:torch.Size([8, 465])	y:torch.Size([8])	indexes:(8,)	
input_ids:torch.Size([8, 462])	y:torch.Size([8])	indexes:(8,)	
input_ids:torch.Size([8, 481])	y:torch.Size([8])	indexes:(8,)	
input_ids:torch.Size([8, 506])	y:torch.Size([8])	indexes:(8,)	
input_ids:torch.Size([8, 492])	y:torch.Size([8])	indexes:(8,)	
input_ids:torch.Size([8, 502])	y:torch.Size([8])	indexes:(8,)	


In [None]:
# Load the pretrained multilingual BERT encoder and move it to the GPU.
bert_model = AutoModel.from_pretrained("bert-base-multilingual-cased").to("cuda")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=625.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=714314041.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
class BasicTransformersClassificationModel(nn.Module):
    """Base class for classifiers built on top of a transformer encoder.

    Subclasses implement ``build_network`` (creates the task-specific
    layers) and ``forward`` (returns log-probabilities when
    ``labels_number > 1``, otherwise probabilities). The base class owns the
    loss, the optimizer and the train / validate steps.

    Args:
        model: Encoder module; its ``config.hidden_size`` is exposed as
            ``hidden_size``.
        labels_number: Number of outputs. Values above 1 select
            LogSoftmax + NLLLoss, otherwise Sigmoid + BCELoss.
        lr: Learning rate for AdamW.
        device: Device the module is moved to; ``None`` leaves it in place.
    """

    def __init__(self, model, labels_number, lr=1e-5, device="cpu", **kwargs):
        super(BasicTransformersClassificationModel, self).__init__()
        self.model = model
        self.labels_number = labels_number
        self.build_network(labels_number)
        # Choose the output activation and the matching loss function.
        if self.labels_number > 1:
            self.log_softmax = nn.LogSoftmax(dim=-1)
            self.criterion = nn.NLLLoss(reduction="mean")
        else:
            # The attribute keeps its name, but here it holds a sigmoid.
            self.log_softmax = nn.Sigmoid()
            self.criterion = nn.BCELoss(reduction="mean")
        self.device = device
        if self.device is not None:
            self.to(self.device)
        self.optimizer = AdamW(self.parameters(), lr=lr, weight_decay=0.01)

    @property
    def hidden_size(self):
        """Hidden size of the wrapped encoder."""
        return self.model.config.hidden_size

    def build_network(self, labels_number):
        """Create the task-specific layers; called from ``__init__``."""
        raise NotImplementedError("You should implement build_network in your derived class.")

    def forward(self, input_ids, **kwargs):
        raise NotImplementedError("You should implement forward pass in your derived class.")

    def train_on_batch(self, x, y):
        """Run one optimization step and return the loss and predictions."""
        self.train()
        self.optimizer.zero_grad()
        loss = self._validate(x, y)
        loss["loss"].backward()
        self.optimizer.step()
        return loss

    def validate_on_batch(self, x, y):
        """Compute loss and predictions in eval mode without gradients."""
        self.eval()
        with torch.no_grad():
            return self._validate(x, y)

    def _validate(self, x, y):
        """Run the forward pass on ``x`` and score it against ``y``.

        Returns:
            Dict with ``"loss"`` (scalar tensor) and ``"labels"`` (predicted
            labels).
        """
        if self.device is not None:
            y = y.to(self.device)
        if self.labels_number <= 1:
            # BCELoss requires floating-point targets; integer label tensors
            # would otherwise fail with a dtype error.
            y = y.float()
        # x is a dict of keyword arguments for forward().
        log_probs = self(**x)
        loss = self.criterion(log_probs, y)
        if self.labels_number > 1:
            _, labels = torch.max(log_probs, dim=-1)
        else:
            labels = (log_probs >= 0.5).int()
        return {"loss": loss, "labels": labels}

class TransformersClassificationModel(BasicTransformersClassificationModel):
    """Classifier that projects the encoder's pooled output onto the labels."""

    def build_network(self, labels_number):
        """Create the linear projection from hidden size to label scores."""
        self.proj_layer = torch.nn.Linear(self.hidden_size, self.labels_number)
        return self

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Score a batch of token id sequences.

        Args:
            input_ids: Tensor of token ids.
            attention_mask: Optional mask for the encoder. When omitted it
                is derived from the encoder config's ``pad_token_id`` so
                that padding positions are not attended to.
            **kwargs: Ignored; lets a whole batch dict be passed as keywords.

        Returns:
            Log-probabilities over the labels when ``labels_number > 1``,
            otherwise one probability per example.
        """
        input_ids = input_ids.to(self.device)
        if attention_mask is None:
            pad_id = getattr(self.model.config, "pad_token_id", None)
            if pad_id is not None:
                # Without a mask the encoder treats padding as real tokens.
                attention_mask = (input_ids != pad_id).long()
        else:
            attention_mask = attention_mask.to(input_ids.device)
        cls_output = self.model(
            input_ids, attention_mask=attention_mask)["pooler_output"]
        logits = self.proj_layer(cls_output)
        log_probs = self.log_softmax(logits)
        if self.labels_number == 1:
            # Drop the trailing singleton dimension of the single output.
            log_probs = log_probs[..., 0]
        return log_probs


# Train model

In [None]:
# Wrap the encoder in a single-output classifier.
# NOTE: this rebinds the module-level name `model`, which until this point
# referred to the SentenceTransformer instance.
model = TransformersClassificationModel(
    bert_model, labels_number=1, device="cuda"
)
# List every parameter with its device and shape.
for name, elem in model.named_parameters():
    print(name, elem.device, elem.shape)

model.embeddings.word_embeddings.weight cuda:0 torch.Size([119547, 768])
model.embeddings.position_embeddings.weight cuda:0 torch.Size([512, 768])
model.embeddings.token_type_embeddings.weight cuda:0 torch.Size([2, 768])
model.embeddings.LayerNorm.weight cuda:0 torch.Size([768])
model.embeddings.LayerNorm.bias cuda:0 torch.Size([768])
model.encoder.layer.0.attention.self.query.weight cuda:0 torch.Size([768, 768])
model.encoder.layer.0.attention.self.query.bias cuda:0 torch.Size([768])
model.encoder.layer.0.attention.self.key.weight cuda:0 torch.Size([768, 768])
model.encoder.layer.0.attention.self.key.bias cuda:0 torch.Size([768])
model.encoder.layer.0.attention.self.value.weight cuda:0 torch.Size([768, 768])
model.encoder.layer.0.attention.self.value.bias cuda:0 torch.Size([768])
model.encoder.layer.0.attention.output.dense.weight cuda:0 torch.Size([768, 768])
model.encoder.layer.0.attention.output.dense.bias cuda:0 torch.Size([768])
model.encoder.layer.0.attention.output.LayerNorm.we

In [None]:
# Sanity check: train repeatedly on one fixed batch and watch the loss.
train_dataloader = iter(FieldBatchDataLoader(train_dataset, batch_size=8, device="cuda"))
batch = next(train_dataloader)
labels = batch["y"]
print(labels)
for i in range(100):
    loss = model.train_on_batch(batch, labels)["loss"].item()
    # Report the first five steps, then every tenth step.
    if i < 5 or (i+1) % 10 == 0:
        print(i, loss)
# Loss on the same batch, computed in eval mode without gradients.
print(model.validate_on_batch(batch, labels)["loss"].item())

tensor([1., 0., 1., 0., 1., 1., 0., 0.], device='cuda:0')
0 0.7225401401519775
1 0.7008438110351562
2 0.7027850151062012
3 0.7049773335456848
4 0.6824501156806946
9 0.6523040533065796
19 0.5197445750236511
29 0.2642676830291748
39 0.12907591462135315
49 0.07144463062286377
59 0.034492358565330505
69 0.028062934055924416
79 0.025174222886562347
89 0.011758781969547272
99 0.009894424118101597
0.007287419401109219


In [None]:
def update_metrics(metrics, batch_output, batch_labels):
    """Fold one batch's results into the running ``metrics`` dict, in place.

    Updates the running mean loss, the batch counter, and the correct /
    total / TP / FP / FN counters, then recomputes accuracy, precision
    ("P"), recall ("R") and "F1" from the accumulated counters.

    Args:
        metrics: Dict with "loss", "n_batches", "correct", "total", "TP",
            "FP" and "FN" entries.
        batch_output: Dict with a scalar "loss" tensor and predicted "labels".
        batch_labels: Tensor of gold labels (0 / 1).
    """
    seen = metrics["n_batches"]
    # Running mean of the per-batch losses.
    metrics["loss"] = (metrics["loss"] * seen + batch_output["loss"].item()) / (seen + 1)
    metrics["n_batches"] += 1

    hits = (batch_output["labels"] == batch_labels).float()
    misses = 1.0 - hits
    per_item = {
        "TP": hits * batch_labels,
        "FP": misses * (1 - batch_labels),
        "FN": misses * batch_labels,
    }
    metrics["correct"] += hits.sum().item()
    metrics["total"] += hits.shape[0]
    for name, indicator in per_item.items():
        metrics[name] += indicator.sum().item()

    tp, fp, fn = metrics["TP"], metrics["FP"], metrics["FN"]
    # max(..., 1) keeps the ratios defined while the counters are still zero.
    metrics["accuracy"] = metrics["correct"] / max(metrics["total"], 1)
    metrics["P"] = tp / max(tp + fp, 1)
    metrics["R"] = tp / max(tp + fn, 1)
    metrics["F1"] = tp / max(tp + 0.5 * fn + 0.5 * fp, 1)

In [None]:
import tqdm

def do_epoch(model, dataloader, mode="validate", epoch=1):
    """Run one pass over ``dataloader`` and return the accumulated metrics.

    Args:
        model: Object providing ``train_on_batch`` and ``validate_on_batch``.
        dataloader: Iterable of batch dicts that contain a ``"y"`` entry.
        mode: ``"train"`` to update the model; any other value validates.
        epoch: Label shown in the progress bar description.

    Returns:
        The metrics dict maintained by ``update_metrics``.
    """
    # `import tqdm` alone does not guarantee that the `tqdm.notebook`
    # submodule is loaded, so import it explicitly where it is used.
    import tqdm.notebook

    metrics = {"correct": 0, "total": 0, "loss": 0.0, "n_batches": 0,
               "TP": 0, "FP": 0, "FN": 0}
    func = model.train_on_batch if mode == "train" else model.validate_on_batch
    progress_bar = tqdm.notebook.tqdm(dataloader, leave=True)
    progress_bar.set_description(f"{mode}, epoch={epoch}")
    for batch in progress_bar:
        batch_answers = batch["y"]
        batch_output = func(batch, batch_answers)
        update_metrics(metrics, batch_output, batch_answers)
        # Show the running loss and the percentage metrics next to the bar.
        postfix = {"loss": round(metrics["loss"], 4),
                   "acc": round(100 * metrics["accuracy"], 2)}
        for key in ["P", "R", "F1"]:
            postfix[key] = round(100 * metrics[key], 2)
        progress_bar.set_postfix(postfix)
    return metrics

In [None]:
# Build the datasets and loaders for the full training run.
train_dataset = RTEDataset(train_records, tokenizer)
dev_dataset = RTEDataset(val_records, tokenizer)
train_dataloader = iter(FieldBatchDataLoader(train_dataset, batch_size=8, device="cuda"))
dev_dataloader = iter(FieldBatchDataLoader(dev_dataset, batch_size=8, device="cuda"))

# Reload the pretrained encoder and wrap it in a new classifier; the earlier
# instance has already been trained on a single batch.
bert_model = AutoModel.from_pretrained("bert-base-multilingual-cased").to("cuda")
model = TransformersClassificationModel(bert_model, labels_number=1, device="cuda")
best_val_acc = 0.0
checkpoint = "checkpoint_best.pt"
# Train for 5 epochs, saving the weights whenever validation accuracy improves.
for epoch in range(5):
    do_epoch(model, train_dataloader, mode="train", epoch=epoch+1)
    epoch_metrics = do_epoch(model, dev_dataloader, mode="validate", epoch=epoch+1)
    if epoch_metrics["accuracy"] > best_val_acc:
        best_val_acc = epoch_metrics["accuracy"]
        torch.save(model.state_dict(), checkpoint)
        # print("Saving ")
# Restore the best saved weights and evaluate them on the validation set.
# NOTE(review): if no epoch exceeds best_val_acc = 0.0, nothing is saved in
# this run and the load fails unless the file already exists — confirm.
model.load_state_dict(torch.load(checkpoint))
do_epoch(model, dev_dataloader, mode="validate", epoch="evaluate")

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


HBox(children=(FloatProgress(value=0.0, max=6550.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=659.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=6550.0), HTML(value='')))

I hope the training results will be displayed on GitHub. You can also view them in Colab: https://colab.research.google.com/drive/17eeB2DqkTuvvarNNdMC965U_ZZD_E1lc?usp=sharing