# Xlsum Summarization

1. Import dependencies 

In [1]:
import random
import torch
import numpy as np
import wandb
import evaluate
import tqdm

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    pipeline,
)

from src.util.torch_device import resolve_torch_device
from src.data.xlsum import load_xlsum
from src.metrics.summarization import compute_metrics
from src.definitions import MODELS_FOLDER

2. Init WANB

In [2]:
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malexander-melashchenko[0m ([33malexander-melashchenko-igor-sikorsky-kyiv-polytechnic-in[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

2. Config

In [3]:
random_seed = 42

random.seed(random_seed)
torch.manual_seed(random_seed)
np.random.seed(random_seed)

device = resolve_torch_device()

In [4]:
model_checkpoint = "facebook/bart-large"
dataset_name = "csebuetnlp/xlsum"
language = "english"
batch_size = 4
num_train_epochs = 3
max_input_length = 512
max_target_length = 64
train_size = 0.8

In [5]:
run_name = f"xlsum-{str(model_checkpoint).split("/")[-1]}"
output_dir = MODELS_FOLDER / f"{run_name}-checkpoint"
resume_from_checkpoint = output_dir / "checkpoint-229893"

In [6]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

3. Load dataset

In [7]:
ds = load_xlsum(tokenizer, max_input_length, max_target_length, language)

In [8]:
ds

DatasetDict({
    train: Dataset({
        features: ['id', 'url', 'title', 'summary', 'text', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 306522
    })
    test: Dataset({
        features: ['id', 'url', 'title', 'summary', 'text', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 11535
    })
    validation: Dataset({
        features: ['id', 'url', 'title', 'summary', 'text', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 11535
    })
})

4. Train

In [9]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    evaluation_strategy="epoch",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,
    bf16=True,
    logging_dir="./logs",
    logging_strategy="steps",
    logging_steps=100,
    report_to="wandb",
    run_name=run_name,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics(tokenizer),
)



In [10]:
torch.cuda.empty_cache()

In [None]:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)

5. Save weights

In [None]:
trainer.save_model(MODELS_FOLDER / run_name)
tokenizer.save_pretrained(MODELS_FOLDER / run_name)

6. Test model

In [11]:
model = AutoModelForSeq2SeqLM.from_pretrained(MODELS_FOLDER / run_name)
tokenizer = AutoTokenizer.from_pretrained(MODELS_FOLDER / run_name)

In [12]:
summarizer = pipeline(
    "summarization", model=model, tokenizer=tokenizer, truncation=True
)

Device set to use cuda:0


In [13]:
rouge = evaluate.load("rouge")

predictions = []
references = []

for sample in tqdm.tqdm(ds["test"]):
    pred = summarizer(sample["text"], max_length=max_target_length)[0]["summary_text"]
    predictions.append(pred)
    references.append(sample["summary"])

results = rouge.compute(predictions=predictions, references=references)

f"ROGUE is {results}"

  0%|          | 10/11535 [00:02<37:07,  5.17it/s] You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
100%|██████████| 11535/11535 [44:04<00:00,  4.36it/s] 


"ROGUE is {'rouge1': np.float64(0.42241299994687187), 'rouge2': np.float64(0.1997336748109491), 'rougeL': np.float64(0.34443485056882484), 'rougeLsum': np.float64(0.34429214309547196)}"