First, we create an inference pipeline for abstractive summarization. We use BART-base fine-tuned on XSum.

In [1]:
from transformers.utils import PaddingStrategy
from transformers import pipeline, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "morenolq/bart-base-xsum",
    padding=PaddingStrategy.DO_NOT_PAD,
    use_fast=True
)

summarizer = pipeline(
    "summarization",
    model="morenolq/bart-base-xsum",
    tokenizer=tokenizer,
    device=0,
    framework="pt",
)

Next, we need to load our dataset.

In [2]:
import pandas as pd

df = pd.read_csv("../data/docee/low_resource_augmentation/train_news.csv")
print(f"Loaded {len(df)} examples.")
df.head()

Loaded 2697 examples.


Unnamed: 0,index,title,text,event_type,arguments,date,metadata
0,0,North Korea responds to nearby joint United St...,"North Korea says it will use its ""nuclear dete...",Military Exercise,"[{'start': 74, 'end': 75, 'type': 'Countries p...",July 2010,['(BBC)']
1,3,Severe drought conditions continue in Zimbabwe...,As severe drought conditions continue in Zimba...,Droughts,"[{'start': 41, 'end': 48, 'type': 'Areas Affec...",December 2019,"['(Catholic News Agency)', '(Bloomberg News)']"
2,4,"Saudi Arabian blogger Raif Badawi, who has bee...",Raif Badawi is the laureate of this year's Sak...,Awards ceremony,"[{'start': 0, 'end': 10, 'type': 'Winner', 'te...",October 2015,"['(Reuters)', '(Al Jazeera)', '(EU)']"
3,13,British Prime Minister Theresa May announces t...,EU sources say Theresa May missed chance of sw...,Withdraw from an Organization,"[{'start': 15, 'end': 25, 'type': 'Declarer', ...",March 2017,['(The Guardian)']
4,20,Ansar al-Sharia announces it is formally disso...,CAIRO (Reuters) - Libyan Islamist group Ansar ...,Organization Closed,"[{'start': 18, 'end': 54, 'type': 'Organizatio...",May 2017,['(Reuters)']


Check how many classes are present.

In [3]:
# Count distinct labels directly on the column; this avoids positional [0] indexing on a labelled Series.
num_classes = df["event_type"].nunique()
print(f"Total {num_classes} unique classes.")

Total 37 unique classes.
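The number of classes alone says little about how the examples are spread across them. Below is a minimal sketch of a per-class count using only the standard library. The `labels` list is made up for illustration, and `class_distribution` is a hypothetical helper name; on the real dataframe one would presumably pass `df["event_type"].tolist()` instead.

```python
from collections import Counter

# Hypothetical labels for illustration only; in the notebook this would be
# df["event_type"].tolist().
labels = ["Droughts", "Military Exercise", "Droughts", "Awards ceremony", "Droughts"]


def class_distribution(labels):
    """Return (label, count) pairs sorted from most to least frequent."""
    return Counter(labels).most_common()


distribution = class_distribution(labels)
num_distinct = len(distribution)

for label, count in distribution:
    print(f"{label}: {count}")
```

With pandas already loaded, `df["event_type"].value_counts()` should give the same information in one call.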


Build a dataset.

Run inference on the first example.

In [4]:
first_example = df.iloc[0]
first_example

index                                                         0
title         North Korea responds to nearby joint United St...
text          North Korea says it will use its "nuclear dete...
event_type                                    Military Exercise
arguments     [{'start': 74, 'end': 75, 'type': 'Countries p...
date                                                  July 2010
metadata                                              ['(BBC)']
Name: 0, dtype: object

In [5]:
summary_output = summarizer(
    first_example.text,
    min_length=20,
    max_length=200,
    num_beams=20,
    do_sample=False,
)  # beam search with 20 beams, no sampling
summary_output[0]["summary_text"]

'North Korea has warned that it will launch a war game with the US if necessary to "stifle" the country, state media say.'

Now let's derive a new dataframe from the input one, which will hold summaries in place of the input texts.

In [3]:
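# Copy so that overwriting the "text" column later does not touch (or warn about) the original df.
summary_df = df.loc[:, ["text", "event_type"]].copy()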
summary_df.head()

Unnamed: 0,text,event_type
0,"North Korea says it will use its ""nuclear dete...",Military Exercise
1,As severe drought conditions continue in Zimba...,Droughts
2,Raif Badawi is the laureate of this year's Sak...,Awards ceremony
3,EU sources say Theresa May missed chance of sw...,Withdraw from an Organization
4,CAIRO (Reuters) - Libyan Islamist group Ansar ...,Organization Closed


In [7]:
from torch.utils.data import Dataset


class TextDoceeDataset(Dataset):
    def __init__(self, df):
        self.df = df.loc[:, "text"]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, item):
        return self.df.iloc[item]

dataset = TextDoceeDataset(df[:10])
outs = [out[0]["summary_text"] for out in summarizer(
    dataset,
    min_length=20,
    max_length=200,
    truncation=True,
    batch_size=5
)]
# The pipeline returns results in input order, so batching should not reorder the summaries.
outs

['North Korea has warned of a "fire and fury" if the US goes ahead with its planned war games with South Korea, state media say.',
 'The Catholic Church in Southern Africa is working with farmers in Zimbabwe to help feed tens of thousands of people.',
 'The European Parliament (EP) has announced that it will award a prize to a Saudi blogger who was sentenced to 50 lashes in public last year.',
 'Theresa May’s decision to spell out the starting date for formal EU-UK Brexit negotiations has been met with fury by the European Union.',
 'A rival faction in Libya has said it is split into two rival factions, in the latest twist in a long-running war.',
 "A selection of photos from around Malaysia's capital, Kuala Lumpur, which have been released by the BBC.",
 "A look back at some of the key events in Myanmar's recent turbulent political history, including:",
 'At least five people have been killed and 23 injured in a school bus crash in Chattanooga, Tennessee.',
 'A selection of photos fro
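To make the alignment between inputs and summaries explicit rather than implicit, one option is to chunk the texts manually and check that each batch returns one output per input. The sketch below illustrates the idea with a stand-in summarizer: `iter_batches`, `summarize_in_batches`, and `fake_summarize_batch` are hypothetical names that are not part of the notebook or of `transformers`.

```python
from typing import Callable, Iterator, List, Sequence


def iter_batches(items: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]:
    """Yield consecutive slices of `items`, each with at most `batch_size` elements."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def summarize_in_batches(
    texts: Sequence[str],
    summarize_batch: Callable[[List[str]], List[str]],
    batch_size: int = 5,
) -> List[str]:
    """Apply `summarize_batch` chunk by chunk, keeping outputs in input order."""
    summaries: List[str] = []
    for batch in iter_batches(texts, batch_size):
        batch_summaries = summarize_batch(list(batch))
        # One output per input is what keeps the results aligned with the inputs.
        if len(batch_summaries) != len(batch):
            raise ValueError("summarizer returned a different number of outputs")
        summaries.extend(batch_summaries)
    return summaries


def fake_summarize_batch(batch: List[str]) -> List[str]:
    """Stand-in for the real model: keep only the first sentence of each text."""
    return [text.split(".")[0] for text in batch]


texts = [f"Document {i}. Extra detail." for i in range(12)]
summaries = summarize_in_batches(texts, fake_summarize_batch, batch_size=5)
```

In the notebook, `fake_summarize_batch` would be replaced by a thin wrapper around `summarizer` that extracts `summary_text` from each result.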

In [7]:
from tqdm import tqdm
tqdm.pandas()

def summarize_text(text: str) -> str:
    return summarizer(
        text,
        min_length=20,
        max_length=200,
        num_beams=10,
        do_sample=False,
        truncation=True,
    )[0]["summary_text"]

summary_df.loc[:, "text"] = summary_df["text"].progress_apply(summarize_text)

  0%|          | 7/2697 [00:02<14:53,  3.01it/s]Your max_length is set to 200, but you input_length is only 108. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=54)
  1%|          | 21/2697 [00:07<18:14,  2.44it/s]Your max_length is set to 200, but you input_length is only 153. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=76)
  1%|▏         | 36/2697 [00:12<15:51,  2.80it/s]Your max_length is set to 200, but you input_length is only 154. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=77)
  2%|▏         | 46/2697 [00:16<17:02,  2.59it/s]Your max_length is set to 200, but you input_length is only 160. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=80)
  2%|▏         | 49/2697 [00:17<15:00,  2.94it/s]Your max_length is set to 200, but you input_length is only 83. You might consider decreasing max_length manually, e.g. summarizer('...', ma
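The warnings above are triggered whenever `max_length` exceeds the input length, and each one suggests roughly half the input length as a replacement (54 for 108, 76 for 153, and so on). A minimal sketch of choosing `max_length` per input along those lines is shown below. `pick_max_length` is a hypothetical helper, and the halving rule simply mirrors the values in the log; it is an assumption, not a recommendation from the library.

```python
def pick_max_length(input_length: int, cap: int = 200, min_length: int = 20) -> int:
    """Choose a max_length for one input.

    Uses half the input length (as suggested in the warning messages),
    never exceeds `cap`, and always stays above `min_length` so that the
    two generation bounds remain consistent.
    """
    suggested = input_length // 2
    return max(min_length + 1, min(cap, suggested))


# Input lengths taken from the warnings in the log above.
for length in (108, 153, 154, 160):
    print(length, "->", pick_max_length(length))
```

In the notebook, the input length could presumably be obtained as `len(tokenizer(text, truncation=True)["input_ids"])` and the result passed as `max_length` inside `summarize_text`.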