In [1]:
# !wget -q https://www.dropbox.com/s/43l702z5a5i2w8j/gazeta_train.txt
# !wget -q https://www.dropbox.com/s/k2egt3sug0hb185/gazeta_val.txt
# !wget -q https://www.dropbox.com/s/3gki5n5djs9w0v6/gazeta_test.txt

In [20]:
import re
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from nltk.translate.bleu_score import corpus_bleu
# from rouge import Rouge

from ast import literal_eval
from tqdm import tqdm

In [24]:
def handle_whitespaces(k):
    return re.sub('\s+', ' ', re.sub('\n+', ' ', k.strip()))


def summarize(article_text, model, dev='cpu'):
    input_ids = tokenizer(
        [handle_whitespaces(article_text)],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512
    )["input_ids"].to(dev)

    output_ids = model.to(dev).generate(
        input_ids=input_ids,
        max_length=84,
        no_repeat_ngram_size=2,
        num_beams=4
    )[0].to('cpu')

    summary = tokenizer.decode(
        output_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )
    return summary

In [3]:
model_name = "csebuetnlp/mT5_multilingual_XLSum"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

Downloading tokenizer_config.json:   0%|          | 0.00/375 [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/730 [00:00<?, ?B/s]

Downloading spiece.model:   0%|          | 0.00/4.11M [00:00<?, ?B/s]

Downloading special_tokens_map.json:   0%|          | 0.00/65.0 [00:00<?, ?B/s]



Downloading pytorch_model.bin:   0%|          | 0.00/2.17G [00:00<?, ?B/s]

In [30]:
def calc_scores(references, predictions, metric="all"):
    print("Count:", len(predictions))
    # print("Ref:", references[-1])
    # print("Hyp:", predictions[-1])

    if metric in ("bleu", "all"):
        print("BLEU: ", corpus_bleu([[r] for r in references], predictions))
    # if metric in ("rouge", "all"):
    #     rouge = Rouge()
    #     scores = rouge.get_scores(predictions, references, avg=True)
    #     print("ROUGE: ", scores)

In [42]:
# read lines
ITEMS_COUNT = 1000

summaries = []
predicts = []
with open('data/gazeta_val.txt', 'r') as file:
    lines = file.readlines()
    for num, line in tqdm(enumerate(lines), total=ITEMS_COUNT if ITEMS_COUNT else len(lines), desc='processing'):
        if num == ITEMS_COUNT:
            break
        item = literal_eval(line)
        summaries.append(item['summary'])
        predicts.append(summarize(item['text'], model, 'cuda'))

processing: 100%|██████████| 1000/1000 [19:09<00:00,  1.15s/it]


In [43]:
calc_scores(summaries, predicts)

Count: 1000
BLEU:  0.1097574634085071


In [None]:
#