In [1]:
from datasets import load_dataset

dataset = load_dataset("ccdv/arxiv-summarization")
dataset

No config specified, defaulting to: arxiv-summarization/section
Found cached dataset arxiv-summarization (C:/Users/JustinDu/.cache/huggingface/datasets/ccdv___arxiv-summarization/section/1.0.0/fa2c9abf4312afb8660ef8e041d576b8e3943ea96ae771bd3cd091b5798e7cc3)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['article', 'abstract'],
        num_rows: 203037
    })
    validation: Dataset({
        features: ['article', 'abstract'],
        num_rows: 6436
    })
    test: Dataset({
        features: ['article', 'abstract'],
        num_rows: 6440
    })
})

In [2]:
data_train = dataset['train']
data_test = dataset['test']
data_val = dataset['validation']
data_test

Dataset({
    features: ['article', 'abstract'],
    num_rows: 6440
})

In [3]:
data_test = data_test.shard(num_shards=644, index=0)
data_val = data_val.shard(num_shards=806, index=0)

In [4]:
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from transformers import DataCollatorForSeq2Seq
import torch

checkpoint = 'facebook/bart-large-cnn'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [5]:
max_length = 1024

def tokenize_function(data):

    model_inputs = tokenizer(
        data["article"],
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_tensors='pt'
    )

    labels = tokenizer(
        data['abstract'],
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_tensors='pt'
    )
    model_inputs["labels"] = labels["input_ids"]

    return model_inputs

In [6]:
tok = data_test.map(tokenize_function, batched=True)
tok_val = data_val.map(tokenize_function, batched=True)
tok

Loading cached processed dataset at C:/Users/JustinDu/.cache/huggingface/datasets/ccdv___arxiv-summarization/section/1.0.0/fa2c9abf4312afb8660ef8e041d576b8e3943ea96ae771bd3cd091b5798e7cc3\cache-6b52843ce66933f8.arrow


  0%|          | 0/1 [00:00<?, ?ba/s]

Dataset({
    features: ['article', 'abstract', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 10
})

In [7]:
tok = tok.remove_columns(
    data_test.column_names
)

tok_val = tok_val.remove_columns(
    data_val.column_names
)

In [8]:
import evaluate
import nltk

nltk.download("punkt")
rouge_score = evaluate.load('rouge')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\JustinDu\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Data is done preparing

In [9]:
tok.set_format('torch')
tok_val.set_format('torch')


In [10]:
default_args = {
    "output_dir": "accelxiv",
    "evaluation_strategy": "steps",
    "num_train_epochs": 8,
}

In [11]:
from transformers import Seq2SeqTrainingArguments, Trainer, logging
training_args = Seq2SeqTrainingArguments(
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    fp16=True,
    **default_args,
)

In [12]:
from nltk.tokenize import sent_tokenize
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # ROUGE expects a newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels

In [13]:
import torch_optimizer as optim

optimizer = optim.Adafactor(model.parameters(), lr=2e-5)

In [14]:
from accelerate import Accelerator
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm
import torch
import numpy as np

device = torch.device('cuda')

dataloader = DataLoader(tok, batch_size=training_args.per_device_train_batch_size)
dataloader_val = DataLoader(tok_val, batch_size=training_args.per_device_train_batch_size)

num_update_steps_per_epoch = len(dataloader)
num_training_steps = num_update_steps_per_epoch * training_args.num_train_epochs
progress_bar = tqdm(range(num_training_steps))


if training_args.gradient_checkpointing:
    model.gradient_checkpointing_enable()

accelerator = Accelerator(fp16=training_args.fp16)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for epoch in range(training_args.num_train_epochs):
    model.train()
    for step, batch in enumerate(dataloader, start=1):
        loss = model(**batch).loss
        loss = loss / training_args.gradient_accumulation_steps
        accelerator.backward(loss)
        if step % training_args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            progress_bar.update(1)
    model.eval()
    for step, batch in enumerate(dataloader_val, start=1):

        batch = tuple(b.to(device) for b in batch.values())
        
        with torch.no_grad():
            generated_tokens = accelerator.unwrap_model(model).generate(
                batch[0],
                attention_mask=batch[1],
            )

            generated_tokens = accelerator.pad_across_processes(
                generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
            )
            labels = batch[2]

            # If we did not pad to max length, we need to pad the labels too
            labels = accelerator.pad_across_processes(
                batch[2], dim=1, pad_index=tokenizer.pad_token_id
            )

            generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()
            labels = accelerator.gather(labels).cpu().numpy()

            # Replace -100 in the labels as we can't decode them
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
            if isinstance(generated_tokens, tuple):
                generated_tokens = generated_tokens[0]
            decoded_preds = tokenizer.batch_decode(
                generated_tokens, skip_special_tokens=True
            )
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

            decoded_preds, decoded_labels = postprocess_text(
                decoded_preds, decoded_labels
            )

            rouge_score.add_batch(predictions=decoded_preds, references=decoded_labels)

    # Compute metrics
    result = rouge_score.compute()
    # Extract the median ROUGE scores
    result = {key: value * 100 for key, value in result.items()}
    print(f"Epoch {epoch}:", result)
    

  0%|          | 0/80 [00:00<?, ?it/s]



Epoch 0: {'rouge1': 1.0671809816668667, 'rouge2': 0.0, 'rougeL': 1.0959146401141044, 'rougeLsum': 0.9789893269594723}
Epoch 1: {'rouge1': 4.242374667333113, 'rouge2': 0.0, 'rougeL': 3.7452039039241436, 'rougeLsum': 4.1722721464403865}
Epoch 2: {'rouge1': 17.887760476829907, 'rouge2': 1.8652542276891706, 'rougeL': 14.420131240544096, 'rougeLsum': 16.89484288988697}
Epoch 3: {'rouge1': 0.1445086705202312, 'rouge2': 0.0, 'rougeL': 0.1445086705202312, 'rougeLsum': 0.1445086705202312}
Epoch 4: {'rouge1': 2.9475207581914513, 'rouge2': 0.0, 'rougeL': 2.7135159667370337, 'rougeLsum': 2.8405140060462055}
Epoch 5: {'rouge1': 16.090860871594977, 'rouge2': 0.7908727481095902, 'rougeL': 9.860130014210226, 'rougeLsum': 13.229620534591785}
Epoch 6: {'rouge1': 12.33080267374235, 'rouge2': 0.11848341232227488, 'rougeL': 8.16751234054356, 'rougeLsum': 9.73266596311361}
Epoch 7: {'rouge1': 2.273529671579645, 'rouge2': 0.0, 'rougeL': 2.080913458972712, 'rougeLsum': 2.254768351981075}
