In [37]:
from transformers.utils import PaddingStrategy
from transformers import AutoTokenizer
import os

tokenizer_path = "/shared/lovorka/jvidakovic/models/checkpoint-28000"
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_path,
    padding=PaddingStrategy.LONGEST,
    use_fast=True,
    model_max_length=1024
)


In [38]:
tokenizer.model_max_length

1024

In [39]:
model_path = os.path.join(tokenizer_path, "summarization")

In [40]:
from transformers import BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained(
    model_path,
)
model

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50265, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0): BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=1e-05,

In [28]:
type(model.config)

transformers.models.bart.configuration_bart.BartConfig

In [43]:
from transformers import pipeline

summarizer = pipeline(
    "summarization",
    model=model,
    tokenizer=tokenizer,
    device="cpu",
    framework="pt"
)

In [6]:
import pandas as pd
data_path = "/home/jvidakovic/cross_lingual_data_augmentation/data/docee/all/train_all.csv"
df = pd.read_csv(data_path)
df.head()

Unnamed: 0.1,Unnamed: 0,title,text,event_type,arguments,date,metadata
0,0,Vietnam reelects conservative Nguyễn Phú Trọng...,Vietnam's Communist Party Wednesday re-elected...,Government Job change - Election,"[{'start': 0, 'end': 24, 'type': 'Candidates a...",January 2016,"['(AP via ABC News)', '(Channel NewsAsia)']"
1,1,At least 42 people are killed in a bus crash i...,Another 43 people were injured when the bus ca...,Road Crash,"[{'start': 8, 'end': 29, 'type': 'Casualties a...",October 2006,['(BBC)']
2,2,At least 27 migrants die in a shipwreck in the...,At least 27 migrants have died off the Turkish...,Shipwreck,"[{'start': 0, 'end': 29, 'type': 'Casualties a...",February 2016,"['(ANSAmed)', '(Leadership)', '(news.com.au)',..."
3,3,Colten Treu faces charges of vehicular homicid...,"Colten Treu, 21, and his roommate both told au...",Road Crash,"[{'start': 183, 'end': 207, 'type': 'Number of...",November 2018,"['(KSTP)', '(Oxygen)']"
4,4,"Hours after the announcement, Morales resigns ...",Bolivian President Evo Morales has resigned af...,Government Job change - Resignation_Dismissal,"[{'start': 0, 'end': 17, 'type': 'Position', '...",November 2019,"['(BBC News)', '(The Guardian)']"


In [7]:
df = df.loc[:, ["text", "title", "event_type", "date"]]
summary_df = df.copy()

In [16]:
import numpy as np
from typing import Tuple

def get_class_counts(df: pd.DataFrame) -> dict[str, int]:
    """For each class label present in the dataset, returns the
    count of examples for that class.

    :param df:  dataframe
    :return:  dictionary where each key represents a class label
        and each value represents the count of examples belonging to
        that class.
    """

    class_names = set(df["event_type"].tolist())
    print(f"Total of {len(class_names)} class names.")

    class_counts = {
        class_name: np.sum(df.event_type.values == class_name)
        for class_name in class_names
    }
    print(f"Sum of all class counts equals {sum(class_counts.values())}.")
    return class_counts

def low_resource_slice(
        df: pd.DataFrame,
        cutoff: int,
        return_classes: bool = False
) -> pd.DataFrame | Tuple[list[str], pd.DataFrame]:
    """ For a given dataframe, returns all examples which belong to low resource classes.

    Low resource classes include all classes for which the class count (i.e. number of examples)
    is not greater than the given cutoff.

    :param return_classes: whether or not to return classes
    :param df:  dataframe
    :param cutoff:  low resource threshold
    :return:  (low_resource_classes, low_resource_df)
    """

    class_counts = get_class_counts(df)
    low_resource_classes = list(filter(lambda k: class_counts[k] <= cutoff, class_counts))
    low_resource_df = df.loc[df["event_type"].isin(low_resource_classes), :]
    if return_classes:
        return low_resource_classes, low_resource_df
    else:
        return low_resource_df


In [17]:
low_resource_classes, summary_df = low_resource_slice(
    summary_df,
    500,
    return_classes=True
)
print(f"{low_resource_classes = }")
print(f"{len(low_resource_classes) = }")

Total of 59 class names.
Sum of all class counts equals 21949.
low_resource_classes = ['Tsunamis', 'Mudslides', 'Organization Closed', 'Famous Person - Give a speech', 'Hurricanes_Tornado_Storm_Blizzard', 'Join in an Organization', 'Famous Person - Commit Crime - Investigate', 'Diplomatic Visit', 'Withdraw from an Organization', 'Famous Person - Commit Crime - Release', 'Famous Person - Commit Crime - Arrest', 'Floods', 'Famous Person - Recovered', 'Famine', 'Strike', 'Famous Person - Divorce', 'Sign Agreement', 'Famous Person - Marriage', 'Shipwreck', 'Mass Poisoning', 'Diplomatic Talks _ Diplomatic_Negotiation_ Summit Meeting', 'Organization Fine', 'Tear Up Agreement', 'Awards ceremony', 'Train collisions', 'Mine Collapses', 'Financial Aid', 'Financial Crisis', 'Road Crash', 'Environment Pollution', 'Famous Person - Death', 'Famous Person - Sick', 'Military Exercise', 'Volcano Eruption', 'New achievements in aerospace', 'Insect Disaster', 'New archeological discoveries', 'Disease Out

In [18]:
source_doc_ids = summary_df.index.values
source_doc_ids

array([    1,     2,     3, ..., 21945, 21946, 21948])

In [19]:
from torch.utils.data import Dataset
from typing import Optional, Callable, Iterable

def concat_dot_join(tokens: Iterable[str]) -> str:
    return ". ".join(tokens)

class DoceeForInference(Dataset):
    def __init__(
            self,
            df: pd.DataFrame,
            use_title: bool = False,
            concat: Optional[Callable[[Iterable[str]], str]] = concat_dot_join
    ):
        columns = ["title", "text"] if use_title else ["text"]
        self.concat = concat
        self.df = df.loc[:, columns]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, item):
        return self.concat(self.df.iloc[item])

In [20]:
dataset = DoceeForInference(summary_df, use_title=False)
len(dataset)

11370

In [44]:
from tqdm import tqdm

summary_df.loc[:, ["text", "source_doc_id"]] = [
    (out[j]["summary_text"], source_doc_ids[i])
    for i, out in enumerate(tqdm(summarizer(
        dataset,
        truncation=True,
        batch_size=1,
        num_workers=1,
        min_length=20,
        max_length=156,
        num_beams=1,
        early_stopping=True,
        top_k=0,
        top_p=1.0,
        temperature=1.0,
        do_sample=False,
        num_return_sequences=1,
        penalty_alpha=0,
    ), desc=f"Inference loop", total=len(dataset)))
    for j in range(len(out))
]

Inference loop:   0%|▏                                                                                                                                                          | 13/11370 [00:17<4:37:40,  1.47s/it]Your max_length is set to 156, but you input_length is only 154. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=77)
Inference loop:   0%|▏                                                                                                                                                          | 13/11370 [00:17<4:13:42,  1.34s/it]


In [30]:
len(dataset[8])

11931

In [34]:
[len(dataset[i]) for i in range(8)]

[855, 3373, 2692, 1782, 1675, 1027, 4195, 2783]

In [33]:
dataset[:8]

'text'