In [1]:
from transformers import BartForSequenceClassification, BartForConditionalGeneration

# Checkpoint identifier used to load both BART models and the tokenizer.
PRETRAINED_MODEL_NAME_OR_PATH = "ainize/bart-base-cnn"

In [2]:
def setup_models():
    """Load the classification and summarization BART models with a shared backbone.

    Both models are instantiated from ``PRETRAINED_MODEL_NAME_OR_PATH``. The
    summarization model's token embedding, encoder and decoder are then
    rebound to the classification model's modules, so both task heads operate
    on the very same module objects.

    Returns:
        dict: ``{"summarization": <BartForConditionalGeneration>,
        "classification": <BartForSequenceClassification>}``.
    """
    # initialize models
    classification_model = BartForSequenceClassification.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    summarization_model = BartForConditionalGeneration.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)

    # share parameters (rebinding by reference, nothing is copied)
    summarization_model.model.shared = classification_model.model.shared
    summarization_model.model.encoder = classification_model.model.encoder
    summarization_model.model.decoder = classification_model.model.decoder

    # The LM head was tied to the summarization model's ORIGINAL embedding when
    # the model was loaded. After swapping `model.shared` it would keep pointing
    # at that discarded embedding and drift away from the shared one during
    # training, so re-tie it to the embedding that is now in place.
    # (tie_weights() is a no-op when the config disables embedding tying.)
    summarization_model.tie_weights()

    return {
        "summarization": summarization_model,
        "classification": classification_model
    }


In [3]:
# Build both models, then check that the backbone pieces are the same objects
# (identity, not mere equality) in the two task models.
models = setup_models()
assert models["summarization"].model.shared is models["classification"].model.shared
assert models["summarization"].model.encoder is models["classification"].model.encoder
assert models["summarization"].model.decoder is models["classification"].model.decoder


Some weights of the model checkpoint at ainize/bart-base-cnn were not used when initializing BartForSequenceClassification: ['final_logits_bias', 'lm_head.weight']
- This IS expected if you are initializing BartForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BartForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BartForSequenceClassification were not initialized from the model checkpoint at ainize/bart-base-cnn and are newly initialized: ['classification_head.dense.bias', 'classification_head.dense.weight', 'classification_head.out_proj.weight', 'classification_head.out_proj.bias']
You should probably TRAIN this model on a down-strea

In [4]:
from transformers import AutoTokenizer

# Still needed at this point:
#   -> a dataframe loaded with docee examples
#   -> a tokenizer (the bart tokenizer, loaded from the same checkpoint as the models)

tokenizer = AutoTokenizer.from_pretrained(
    PRETRAINED_MODEL_NAME_OR_PATH
)

In [6]:
from datasets import load_dataset

# Summarization data: the "validation" split of cnn_dailymail, config "3.0.0".
# The last expression displays the number of examples.
summ_dataset = load_dataset(
    "cnn_dailymail",
    name="3.0.0",
    split="validation",
)
len(summ_dataset)

Found cached dataset cnn_dailymail (/home/jvidakovic/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)


13368

In [7]:
# Display the first example of the summarization dataset.
summ_dataset[0]

{'article': '(CNN)Share, and your gift will be multiplied. That may sound like an esoteric adage, but when Zully Broussard selflessly decided to give one of her kidneys to a stranger, her generosity paired up with big data. It resulted in six patients receiving transplants. That surprised and wowed her. "I thought I was going to help this one person who I don\'t know, but the fact that so many people can have a life extension, that\'s pretty big," Broussard told CNN affiliate KGO. She may feel guided in her generosity by a higher power. "Thanks for all the support and prayers," a comment on a Facebook page in her name read. "I know this entire journey is much bigger than all of us. I also know I\'m just the messenger." CNN cannot verify the authenticity of the page. But the power that multiplied Broussard\'s gift was data processing of genetic profiles from donor-recipient pairs. It works on a simple swapping principle but takes it to a much higher level, according to California Pacifi

In [23]:
# Imported in this cell because no earlier cell brings DataCollatorWithPadding
# into scope; without it the cell raises NameError on a fresh kernel.
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Slicing takes the FIRST 8 examples (a fixed batch, not a random one).
# The result is a dict keyed by column name, so `.keys()` lists the columns.
batch = summ_dataset[:8]
batch.keys()

dict_keys(['article', 'highlights', 'id', 'input_ids', 'attention_mask'])

In [8]:
def compose2(f, g):
    """Return the two-function composition ``f(g(...))``.

    The returned callable forwards every positional and keyword argument to
    ``g`` and then passes ``g``'s result, as a single argument, to ``f``.
    """
    return lambda *args, **kwargs: f(g(*args, **kwargs))

def c(*fs):
    """Compose functions right-to-left: ``c(f, g, h)(x) == f(g(h(x)))``.

    Args:
        *fs: One or more callables. The last one receives the positional and
            keyword arguments of the call; every earlier one receives the
            previous result as its single argument.

    Returns:
        A callable that applies the given functions from last to first.

    Raises:
        ValueError: If no function is supplied. (Previously this only showed
            up as an ``IndexError`` once the composition was called.)
    """
    if not fs:
        raise ValueError("c() requires at least one function to compose")

    # Split once up front instead of re-slicing and re-reversing on every call.
    *outer, innermost = fs
    outer.reverse()

    def composition(*args, **kwargs):
        output = innermost(*args, **kwargs)
        for f in outer:
            output = f(output)
        return output
    return composition


In [9]:
# Load the docee classification data from CSV.
# `data_files` may be a mapping from split name to the path of that split's file.
cls_dataset = load_dataset(
    "csv",
    data_files=dict(
        train="../data/docee/18091999/train.csv",
        validation="../data/docee/18091999/early_stopping.csv",
    ),
)
cls_dataset

Using custom data configuration default-0720af0f377253e9
Found cached dataset csv (/home/jvidakovic/.cache/huggingface/datasets/csv/default-0720af0f377253e9/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['index', 'title', 'text', 'event_type', 'arguments', 'date', 'metadata'],
        num_rows: 17559
    })
    validation: Dataset({
        features: ['index', 'title', 'text', 'event_type', 'arguments', 'date', 'metadata'],
        num_rows: 2195
    })
})

In [10]:
# Shuffle the training split with seed 42, keep 100 rows, and display the first three.
(
    cls_dataset["train"]
    .shuffle(seed=42)
    .select(range(100))
)[:3]

Loading cached shuffled indices for dataset at /home/jvidakovic/.cache/huggingface/datasets/csv/default-0720af0f377253e9/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-1b432e07d6b975f8.arrow


{'index': [8677, 13606, 10423],
 'title': ['Simulation of glacial calving and tsunami waves predicts climate change consequences',
  'Four protesters and two police officers are killed during clashes in Baghdad.',
  'Italian firm Fiat Chrysler  proposes a merger with French carmaker Renault. The new company will be based in the Netherlands and will be listed on the Milan, Paris and New York stock exchanges.'],
 'text': ['As natural disasters intensify due to climate change, accurate predictions of weather patterns and mechanisms are greatly needed to mitigate damage. Coastal regions will be the most affected by changing weather, with events such as tsunamis and hurricanes becoming more frequent and life-threatening. While most tsunamis are caused by earthquakes and tectonic activity, the warming of the planet is now increasing the occurrence of tsunamis caused by glacier calving, when chunks of glacier break off and become icebergs. Additionally, glacier calving is predicted to be the 

In [13]:
from transformers.utils import PaddingStrategy

# Maximum number of tokens kept per article (model input); longer texts are truncated.
max_input_length = 512
# Maximum number of tokens kept per highlights text (the labels); longer texts are truncated.
max_target_length = 100

def process_summary_example(examples):
    """Tokenize articles and their highlights for summarization training.

    Args:
        examples: Mapping with an ``"article"`` and a ``"highlights"`` entry
            (a batch of examples when used with ``Dataset.map(batched=True)``).

    Returns:
        The tokenizer output for the articles, with the token ids of the
        highlights stored under the ``"labels"`` key.
    """
    def encode(texts, limit):
        # Truncate to at most `limit` tokens; no padding is requested here.
        return tokenizer(texts, max_length=limit, truncation=True)

    model_inputs = encode(examples["article"], max_input_length)
    model_inputs["labels"] = encode(examples["highlights"], max_target_length)["input_ids"]
    return model_inputs

In [20]:
# Tokenize the whole dataset in batches and drop the raw text/id columns.
tokenized_cnn = summ_dataset.map(
    process_summary_example,
    batched=True,
    remove_columns=["id", "article", "highlights"],
)


  0%|          | 0/14 [00:00<?, ?ba/s]

In [15]:
# Display the column schema (feature names and dtypes) of the tokenized dataset.
tokenized_cnn.features

{'article': Value(dtype='string', id=None),
 'highlights': Value(dtype='string', id=None),
 'id': Value(dtype='string', id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'labels': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None)}

In [79]:
import evaluate

# Load the ROUGE metric and sanity-check it on one toy prediction/reference pair.
rouge_score = evaluate.load("rouge")

generated_summary = "I absolutely loved reading the Hunger Games"
reference_summary = "I loved reading the Hunger Games"

scores = rouge_score.compute(predictions=[generated_summary], references=[reference_summary])
scores

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

{'rouge1': 0.923076923076923,
 'rouge2': 0.7272727272727272,
 'rougeL': 0.923076923076923,
 'rougeLsum': 0.923076923076923}

In [16]:
from transformers import Seq2SeqTrainingArguments

batch_size = 1
num_train_epochs = 1
# One logging event per full pass over `summ_dataset` at this batch size.
logging_steps = len(summ_dataset) // batch_size

# NOTE(review): `args` does not appear to be consumed by the hand-written
# training loop in the visible cells — confirm whether it is still needed.
args = Seq2SeqTrainingArguments(
    # Output / checkpointing / hub
    output_dir=f"{PRETRAINED_MODEL_NAME_OR_PATH}-finetuned-cnn",
    save_total_limit=3,
    push_to_hub=True,
    # Optimization
    learning_rate=5.6e-5,
    weight_decay=0.01,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    sortish_sampler=True,
    # Evaluation / logging
    evaluation_strategy="epoch",
    logging_steps=logging_steps,
    # Generation during evaluation
    predict_with_generate=True,
    generation_max_length=150,
    generation_num_beams=1,
)


In [17]:
from nltk import sent_tokenize
import numpy as np


def compute_metrics(eval_pred):
    """Score generated summaries against references with ROUGE.

    Args:
        eval_pred: A ``(predictions, labels)`` pair of token-id arrays. Label
            positions equal to -100 are replaced by the pad token id before
            decoding.

    Returns:
        dict: Every entry returned by ``rouge_score.compute``, multiplied by
        100 and rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    # -100 cannot be decoded, so swap it for the pad token id.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    def to_rouge_text(token_ids):
        # Decode to text, then put each sentence on its own line
        # (ROUGE expects a newline after each sentence).
        texts = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
        return ["\n".join(sent_tokenize(text.strip())) for text in texts]

    result = rouge_score.compute(
        predictions=to_rouge_text(predictions),
        references=to_rouge_text(labels),
        use_stemmer=True,
    )
    # Express the scores as rounded percentages.
    return {key: round(value * 100, 4) for key, value in result.items()}



In [21]:
from transformers import DataCollatorForSeq2Seq

# Collator for the summarization batches; it is given both the tokenizer and
# the summarization model.
summ_data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=models["summarization"],
)

In [22]:
# Two tokenized examples to try the collator on.
features = [tokenized_cnn[0], tokenized_cnn[1]]
features

[{'input_ids': [0,
   1640,
   16256,
   43,
   11957,
   6,
   8,
   110,
   4085,
   40,
   28,
   39582,
   4,
   280,
   189,
   2369,
   101,
   41,
   43962,
   2329,
   1580,
   6,
   53,
   77,
   525,
   19678,
   163,
   8508,
   7485,
   1120,
   1403,
   12445,
   1276,
   7,
   492,
   65,
   9,
   69,
   33473,
   7,
   10,
   12443,
   6,
   69,
   19501,
   11153,
   62,
   19,
   380,
   414,
   4,
   85,
   4596,
   11,
   411,
   1484,
   2806,
   28748,
   3277,
   4,
   280,
   3911,
   8,
   885,
   9725,
   69,
   4,
   22,
   100,
   802,
   38,
   21,
   164,
   7,
   244,
   42,
   65,
   621,
   54,
   38,
   218,
   75,
   216,
   6,
   53,
   5,
   754,
   14,
   98,
   171,
   82,
   64,
   33,
   10,
   301,
   5064,
   6,
   14,
   18,
   1256,
   380,
   60,
   163,
   8508,
   7485,
   1120,
   174,
   3480,
   10515,
   229,
   14740,
   4,
   264,
   189,
   619,
   10346,
   11,
   69,
   19501,
   30,
   10,
   723,
   476,
   4,
   22,
   22086,
 

In [23]:
# Collate the two sample examples into one batch and display the result.
summ_data_collator(features)

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'input_ids': tensor([[    0,  1640, 16256,  ...,    39,  2761,     2],
        [    0,  1640, 16256,  ...,    95,    15,     2]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]]), 'labels': tensor([[    0,  1301, 19678,   163,  8508,  7485,  1120,  1276,     7,   492,
            10, 12855,     7,    10, 12443,   479, 50118,   250,    92,  3034,
           586,  1147,    69,  7096, 15220, 28748,  3277,    13,   411, 12855,
          1484,   479,     2,  -100,  -100,  -100],
        [    0,   133,   291,   212, 13989,   191,  3772,    42,   983,   479,
         50118, 17608,    34,  1714,  8617,   187,    63, 17692,    11,  8008,
           479, 50118,  6323,   864,   549,  1492,  2624,  5391,  9686,     8,
         12291,   240,     7,   464,   479,     2]]), 'decoder_input_ids': tensor([[    2,     0,  1301, 19678,   163,  8508,  7485,  1120,  1276,     7,
           492,    10, 12855,     7,    10, 12443,   479, 50118,   250,    92,
          3

In [27]:
# Switch the dataset's output format to "torch" (this changes `tokenized_cnn` in place).
tokenized_cnn.set_format(type="torch")

In [29]:
from torch.utils.data import DataLoader

batch_size = 1

# NOTE(review): both loaders wrap the same `tokenized_cnn` dataset, so the
# model is trained and evaluated on identical examples — confirm this is only
# meant as a smoke test of the loop.

# Training batches: shuffled.
train_dataloader = DataLoader(
    tokenized_cnn,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=summ_data_collator,
)

# Evaluation batches: same data, original order.
eval_dataloader = DataLoader(
    tokenized_cnn,
    batch_size=batch_size,
    collate_fn=summ_data_collator,
)


In [30]:
from torch.optim import AdamW

# Optimize every parameter exposed by the summarization model.
optimizer = AdamW(
    params=models["summarization"].parameters(),
    lr=2e-5,
)

In [32]:
from accelerate import Accelerator
# Hand the model, optimizer and both dataloaders to Accelerate; the objects it
# returns replace the originals under the same names (and `model` is introduced).
accelerator = Accelerator()
(
    model,
    optimizer,
    train_dataloader,
    eval_dataloader,
) = accelerator.prepare(
    models["summarization"],
    optimizer,
    train_dataloader,
    eval_dataloader,
)


In [33]:
from transformers import get_scheduler

num_train_epochs = 1
# The training loop steps the scheduler once per batch, so the total number of
# scheduler steps is epochs x batches-per-epoch.
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

# "linear" schedule without warmup.
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)


In [34]:
import nltk

def postprocess_text(preds, labels):
    """Reformat decoded predictions and references for ROUGE.

    Each text is stripped of surrounding whitespace and split into sentences
    with ``nltk.sent_tokenize``; the sentences are re-joined with newlines,
    because ROUGE expects a newline after each sentence.

    Args:
        preds: Iterable of decoded prediction strings.
        labels: Iterable of decoded reference strings.

    Returns:
        tuple: ``(preds, labels)`` as two lists of reformatted strings.
    """
    def one_sentence_per_line(texts):
        return ["\n".join(nltk.sent_tokenize(text.strip())) for text in texts]

    return one_sentence_per_line(preds), one_sentence_per_line(labels)

In [None]:
from tqdm.auto import tqdm
import torch

# Hand-written training/evaluation loop. Each epoch: train on every batch,
# generate summaries for the evaluation batches, report ROUGE, then save the
# model and tokenizer.
# NOTE(review): the train and eval dataloaders are built from the same dataset
# in the visible cells, so the reported ROUGE is measured on training data.

# Loop-invariant: directory the model and tokenizer are written to after every epoch.
output_dir = "./test_summ_train"

for epoch in tqdm(range(num_train_epochs), total=num_train_epochs, desc="Epoch progress"):
    # Training
    model.train()
    for batch in tqdm(train_dataloader, total=len(train_dataloader), desc="Epoch step", leave=False):
        # pass through model
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    # Evaluation
    model.eval()
    for batch in eval_dataloader:
        with torch.no_grad():
            generated_tokens = accelerator.unwrap_model(model).generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
            )  # generation parameters can be plugged in here

            # Sequences may differ in length between processes, so pad both
            # the generations and the labels before gathering them.
            generated_tokens = accelerator.pad_across_processes(
                generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
            )
            labels = accelerator.pad_across_processes(
                batch["labels"], dim=1, pad_index=tokenizer.pad_token_id
            )

            generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()
            labels = accelerator.gather(labels).cpu().numpy()

            # Replace -100 in the labels as we can't decode them
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
            decoded_preds = tokenizer.batch_decode(
                generated_tokens, skip_special_tokens=True
            )
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

            decoded_preds, decoded_labels = postprocess_text(
                decoded_preds, decoded_labels
            )

            rouge_score.add_batch(predictions=decoded_preds, references=decoded_labels)

    # Compute metrics over everything added during this epoch's evaluation
    result = rouge_score.compute()
    # Express the scores as percentages rounded to 4 decimals
    result = {key: round(value * 100, 4) for key, value in result.items()}
    print(f"Epoch {epoch}:", result)

    # Save the model (and, on the main process only, the tokenizer)
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)

Epoch progress:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch step:   0%|          | 0/13368 [00:00<?, ?it/s]