In [1]:
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchtext.vocab import vocab as torch_vocab

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from datasets import load_dataset
from rouge import Rouge

In [2]:
rouge_metric = Rouge()
dataset = load_dataset('IlyaGusev/gazeta', revision="v2.0")

No config specified, defaulting to: gazeta/default
Found cached dataset gazeta (/home/goncharovglebig/.cache/huggingface/datasets/IlyaGusev___gazeta/default/2.0.0/c329f0fc1c22ab6e43e0045ee659d0d43c647492baa2a6ab3a5ea7dac98cd552)


  0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

PG_MODEL_PATH = './pointer_gazeta.pth'
PG_VOCAB_PATH = './gazeta_voc.pth'
EXTR_MODEL_PATH = './extractor.pth'

## Top3 baseline

In [8]:
from nltk.tokenize import sent_tokenize


def top3(article):
    return '.'.join(sent_tokenize(article)[:3])

preds = []
for art in tqdm(dataset['test']['text']):
    preds.append(top3(art))

rouge_metric.get_scores(preds, dataset['test']['summary'], avg=True)

  0%|          | 0/6793 [00:00<?, ?it/s]

{'rouge-1': {'r': 0.23926067161957473,
  'p': 0.20639995255196122,
  'f': 0.21514407167065555},
 'rouge-2': {'r': 0.08792294667867649,
  'p': 0.0733570704747428,
  'f': 0.07724965555132388},
 'rouge-l': {'r': 0.21655468997676802,
  'p': 0.18726447349087771,
  'f': 0.19495473475845665}}

## Point Gen model

In [5]:
from predictors import PGenPredictor


ponter_model = PGenPredictor(
    model_path=PG_MODEL_PATH,
    vocab_path=PG_VOCAB_PATH,
    device=device,
)

preds = []
abst_lower = []
for i in tqdm(range(len(dataset['test']['text']))):
    preds.append(ponter_model.predict_one_sample(test_df['text'][i]))
    abst_lower.append(dataset['test']['summary'][i].lower())

rouge_metric.get_scores(preds, abst_lower, avg=True)

  0%|          | 0/6793 [00:00<?, ?it/s]

{'rouge-1': {'r': 0.20913465575729584,
  'p': 0.23785843721588698,
  'f': 0.21628455229856844},
 'rouge-2': {'r': 0.070436885435969,
  'p': 0.07686597047928244,
  'f': 0.07100721675287897},
 'rouge-l': {'r': 0.1882746409242634,
  'p': 0.21415780838697196,
  'f': 0.1947106309076359}}

## Extractor + PGen 

In [5]:
from predictors import ExtractionPGenPredictor


extr_model = ExtractionPGenPredictor(
    ext_model_path=EXTR_MODEL_PATH,
    pg_model_path=PG_MODEL_PATH,
    pg_vocab_path=PG_VOCAB_PATH,
    device=device,
    threshold=0.01
)

preds = []
abst_lower = []
for i in tqdm(range(len(dataset['test']['text']))):
    preds.append(extr_model.predict_one_sample(dataset['test']['text'][i]))
    abst_lower.append(dataset['test']['summary'][i].lower())

rouge_metric.get_scores(preds, abst_lower, avg=True)

  0%|          | 0/6793 [00:00<?, ?it/s]

{'rouge-1': {'r': 0.20907098421096082,
  'p': 0.2378527022928294,
  'f': 0.2162402917341788},
 'rouge-2': {'r': 0.07042001648760098,
  'p': 0.07685955000554483,
  'f': 0.07099325174761899},
 'rouge-l': {'r': 0.18822458148921,
  'p': 0.21416762503221332,
  'f': 0.1946788392502319}}

## MBart Model

In [None]:
from transformers import MBartTokenizer, MBartForConditionalGeneration
from torch.utils.data import DataLoader


def collate_batch(batch):
    texts = batch
    input_ids = tokenizer(
        texts,
        max_length=600,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        )["input_ids"].to(device)
    return input_ids


model_name = "IlyaGusev/mbart_ru_sum_gazeta"
tokenizer = MBartTokenizer.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)

test_loader = DataLoader(dataset['test']['text'],
                         batch_size=1,
                         collate_fn=collate_batch)
preds = []
with torch.no_grad():
    for batch in tqdm(test_loader):
        output_ids = model.generate(
            input_ids=batch,
            no_repeat_ngram_size=4
        )
        preds = preds + [tokenizer.decode(tok, skip_special_tokens=True)
                         for tok in output_ids]

rouge_metric.get_scores(preds, dataset['test']['summary'], avg=True)

  0%|          | 0/6793 [00:00<?, ?it/s]



In [5]:
rouge_metric.get_scores(preds, dataset['test']['summary'], avg=True)

{'rouge-1': {'r': 0.241287149048779,
  'p': 0.23241927564091414,
  'f': 0.2308946759881896},
 'rouge-2': {'r': 0.08984618344552867,
  'p': 0.08505692164663911,
  'f': 0.08500935192604948},
 'rouge-l': {'r': 0.2193604865789741,
  'p': 0.2114669931119583,
  'f': 0.2100055745953222}}