# XL-Sum Summarization

1. Import dependencies 

In [None]:
import random
import torch
import numpy as np
import wandb

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

from src.util.torch_device import resolve_torch_device
from src.data.xlsum import load_xlsum
from src.metrics.summarization import compute_metrics
from src.definitions import MODELS_FOLDER

2. Initialize Weights & Biases (W&B)

In [None]:
# Authenticate with Weights & Biases; training below logs to it via report_to="wandb".
wandb.login()

2. Config

In [None]:
# One seed for every RNG in play (Python, PyTorch, NumPy) so runs are reproducible.
random_seed = 42
for seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    seed_fn(random_seed)

# Resolve which torch device to use (helper lives in src.util.torch_device).
# NOTE(review): `device` is not referenced in the cells shown here.
device = resolve_torch_device()

In [None]:
# Base checkpoint and dataset selection.
model_checkpoint = "facebook/bart-large"
dataset_name = "csebuetnlp/xlsum"  # NOTE(review): not referenced in the cells shown
language = "english"

# Optimisation settings.
batch_size, num_train_epochs = 4, 3

# Token budgets for source articles and target summaries.
max_input_length, max_target_length = 512, 64

train_size = 0.8  # NOTE(review): not referenced in the cells shown

In [None]:
# W&B run / output-folder name derived from the checkpoint's repo name,
# e.g. "facebook/bart-large" -> "xlsum-bart-large".
# Single quotes inside the f-string: reusing the outer double quotes is only
# valid from Python 3.12 (PEP 701) and is a SyntaxError on older interpreters.
run_name = f"xlsum-{str(model_checkpoint).split('/')[-1]}"

In [None]:
# Model weights and tokenizer both come from the same pretrained checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

3. Load dataset

In [None]:
# Load XL-Sum for `language` via the project helper (src.data.xlsum).
# NOTE(review): presumably tokenizes articles/summaries and truncates them to
# max_input_length / max_target_length tokens — confirm in load_xlsum.
# The result is indexed with "train" and "validation" in the training cell.
ds = load_xlsum(tokenizer, max_input_length, max_target_length, language)

In [None]:
# Render the dataset object (a bare last expression is displayed by the notebook).
ds

4. Train

In [None]:
# Pads inputs and labels per batch; passing the model lets the collator build
# decoder inputs from the labels when the model supports it.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

training_args = Seq2SeqTrainingArguments(
    # Plain string path, as TrainingArguments declares `output_dir: str`.
    output_dir=str(MODELS_FOLDER / f"{run_name}-checkpoint"),
    # Only `eval_strategy`: the old `evaluation_strategy` alias is deprecated and
    # removed in the transformers versions that accept `processing_class` below,
    # so passing it there raises a TypeError.
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,
    # Generate up to the same length passed to load_xlsum for targets instead of
    # the model's default generation length (which may be much shorter), so
    # eval metrics are computed on full-length summaries.
    generation_max_length=max_target_length,
    # NOTE(review): bf16 requires hardware support (e.g. an Ampere+ GPU) — confirm on the target machine.
    bf16=True,
    logging_dir="./logs",
    logging_strategy="steps",
    logging_steps=100,
    report_to="wandb",
    run_name=run_name,
    # The Trainer re-seeds from args.seed; keep it tied to the config cell's seed.
    seed=random_seed,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Release cached-but-unused CUDA memory before training (a no-op without CUDA).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

train_output = trainer.train()
train_output

5. Save weights

In [None]:
# Persist the fine-tuned model and its tokenizer side by side in one folder.
final_model_dir = MODELS_FOLDER / run_name
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)