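# Multilingual summarization with mT5 (PyTorch)

Fine-tunes `google/mt5-small` on a mix of English and Vietnamese WikiLingua articles using a custom 🤗 Accelerate training loop, evaluates each epoch with ROUGE, saves the model, and runs a sample Vietnamese inference.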

In [None]:
!pip install datasets evaluate transformers[sentencepiece] -qq
!pip install accelerate -qq
!pip install rouge_score -qq
!pip install nltk -qq

## Libraries

In [None]:
from datasets import load_dataset, Dataset, concatenate_datasets, DatasetDict
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from transformers import DataCollatorForSeq2Seq
from accelerate import Accelerator
from torch.optim import AdamW
from transformers import get_scheduler

import torch
from torch.utils.data import DataLoader

import evaluate
import nltk
from nltk.tokenize import sent_tokenize

import pandas as pd
import numpy as np

## Hyperparameters

In [None]:
import random

max_input_length = 512    # max tokens kept per source document (longer inputs are truncated)
max_target_length = 30    # max tokens kept per reference summary (labels are truncated to this)
batch_size = 8
num_train_epochs = 10
seed = 42                 # global RNG seed so training runs are reproducible

# Seed every RNG the notebook relies on (Python, NumPy, PyTorch) before any
# stochastic step (DataLoader shuffling, dropout) happens.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

## Download + Preprocess data

In [None]:
train_ds_size = 5000
test_ds_size = 1000

# Both languages are sliced the same way out of wiki_lingua's single "train" split:
# the first `train_ds_size` rows for training, the following `test_ds_size` rows for testing.
splits = [
    f'train[:{train_ds_size}]',
    f'train[{train_ds_size}:{train_ds_size+test_ds_size}]',
]

vietnamese_train_ds, vietnamese_test_ds = load_dataset("wiki_lingua", "vietnamese", split=splits)
english_train_ds, english_test_ds = load_dataset("wiki_lingua", "english", split=splits)

def preporcess_columns_data_type(ds):
  """Flatten the ``article`` records of a wiki_lingua split into a DataFrame.

  Each record in ``ds['article']`` becomes one row. Its ``document`` and
  ``summary`` lists are each joined with single spaces into one string. Any
  other keys of the record (e.g. ``section_name``) are kept unchanged.

  Args:
    ds: object indexable by ``'article'``, yielding a list of per-example dicts.

  Returns:
    A ``pandas.DataFrame`` with one row per example.
  """
  frame = pd.DataFrame(data=ds['article'])
  for column in ('document', 'summary'):
    frame[column] = frame[column].map(' '.join)
  return frame

def parse_dataset(ds_train, ds_test):
  """Convert raw wiki_lingua train/test splits into flat ``datasets.Dataset`` objects.

  Each split is flattened with ``preporcess_columns_data_type`` and rebuilt
  from the resulting DataFrame, so every row carries plain-string
  ``document`` and ``summary`` fields.

  Returns:
    ``(train, test)`` tuple of ``Dataset``.
  """
  converted = [
      Dataset.from_pandas(preporcess_columns_data_type(split))
      for split in (ds_train, ds_test)
  ]
  return converted[0], converted[1]

# Flatten both languages into plain-text document/summary pairs.
vietnamese_train_ds, vietnamese_test_ds = parse_dataset(vietnamese_train_ds, vietnamese_test_ds)
english_train_ds, english_test_ds = parse_dataset(english_train_ds, english_test_ds)
english_train_ds

In [None]:
def show_samples(dataset, num_samples=3, seed=42):
  """Print a few randomly chosen document/summary pairs.

  Args:
    dataset: object with ``shuffle(seed=...)`` and ``select(indices)`` methods
      whose rows expose "document" and "summary" fields.
    num_samples: number of rows to print.
    seed: shuffle seed, so repeated calls show the same rows.
  """
  picked = dataset.shuffle(seed=seed).select(range(num_samples))
  for row in picked:
    print(f"\n'>> Document: {row['document']}'")
    print(f"'>> Summary: {row['summary']}'")

show_samples(english_train_ds, num_samples=3, seed=42)  # peek at the flattened English data

In [None]:
def concat_two_datasets(dataset1, dataset2, seed=42):
  """Concatenate two datasets row-wise and shuffle the result.

  Args:
    dataset1, dataset2: ``datasets.Dataset`` objects, presumably with the same
      columns (required by ``concatenate_datasets``).
    seed: shuffle seed (defaults to 42, the previously hard-coded value) so the
      interleaving of the two inputs is reproducible.

  Returns:
    A new shuffled ``Dataset`` containing the rows of both inputs.
  """
  # Use a distinct local name so the module-level `dataset` dict is not shadowed.
  combined = concatenate_datasets([dataset1, dataset2])
  return combined.shuffle(seed=seed)

# Mix English and Vietnamese examples so batches are not single-language runs.
dataset = DatasetDict({
    'train': concat_two_datasets(english_train_ds, vietnamese_train_ds),
    'test': concat_two_datasets(english_test_ds, vietnamese_test_ds),
})

show_samples(dataset['train'])

## EDA + Filter data

In [None]:
def get_len(x):
  """Return the number of whitespace-separated words in the string ``x``."""
  words = x.split()
  return len(words)

# Inspect training-document lengths in whitespace-separated words (the same
# measure get_len applies when filtering).
# to_pandas() returns a standalone DataFrame, so the dataset's output format
# is left untouched for later cells.
df = dataset['train'].to_pandas()
df["length"] = df['document'].apply(get_len)

ax = df["length"].plot(kind='hist', bins=50, title="Training documents: length distribution")
ax.set_xlabel("document length (words)");

In [None]:
# Make sure rows come back as plain Python values (no pandas formatting) before filtering.
dataset.reset_format()

# Keep documents under 550 words, then summaries of 3 to 44 words.
dataset = dataset.filter(lambda row: get_len(row['document']) < 550)
dataset = dataset.filter(lambda row: 2 < get_len(row['summary']) < 45)
dataset

## Metric (ROUGE)

In [None]:
rouge_score = evaluate.load("rouge")

# nltk.sent_tokenize (used in postprocess_text) needs the Punkt sentence models.
# Recent NLTK releases look them up under "punkt_tab" rather than "punkt", so
# fetch both to work on old and new versions alike.
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)

## Model

In [None]:
# Pretrained seq2seq checkpoint that the rest of the notebook fine-tunes.
model_checkpoint = "google/mt5-small"
model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_checkpoint)

## Tokenizer

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

def preprocess_function(examples):
    """Tokenize a batch of document/summary pairs for seq2seq training.

    Args:
        examples: batched mapping, as passed by ``Dataset.map(batched=True)``,
            with "document" and "summary" lists of strings.

    Returns:
        The tokenizer's encoding of the documents, truncated to
        ``max_input_length``, plus a ``labels`` entry holding the summaries'
        token ids truncated to ``max_target_length``.
    """
    model_inputs = tokenizer(
        examples["document"],
        max_length=max_input_length,
        truncation=True,
    )
    # `text_target` tokenizes the summaries as decoder targets. This is the
    # documented way to build seq2seq labels; some tokenizers treat targets
    # differently from sources.
    labels = tokenizer(
        text_target=examples["summary"],
        max_length=max_target_length,
        truncation=True,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

## Dataset

In [None]:
# Tokenize 64 rows at a time, then drop every raw text column so only the
# fields produced by preprocess_function remain.
tokenized_datasets = (
    dataset
    .map(preprocess_function, batched=True, batch_size=64)
    .remove_columns(dataset["train"].column_names)
)
tokenized_datasets.set_format("torch")

## DataLoader

In [None]:
# Pads each batch on the fly. Label padding uses -100, which the evaluation
# loop maps back to the pad token id before decoding.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

train_dataloader = DataLoader(
    tokenized_datasets["train"],
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data_collator,
)
eval_dataloader = DataLoader(
    tokenized_datasets["test"],
    batch_size=batch_size,
    collate_fn=data_collator,
)

## Optimizer

In [None]:
# AdamW over all model parameters; the 2e-5 peak rate is decayed by the scheduler below.
optimizer = AdamW(params=model.parameters(), lr=2e-5)


In [None]:
accelerator = Accelerator()

# prepare() places each object on the right device and wraps it for
# (possibly distributed) training. The returned objects replace the originals.
(model, optimizer,
 train_dataloader, eval_dataloader) = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader)

## Scheduler

In [55]:
# One optimizer step per training batch, so the schedule spans every batch of every epoch.
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_update_steps_per_epoch * num_train_epochs

# Linear learning-rate decay over the whole run, without warm-up.
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [None]:
def postprocess_text(preds, labels):
    """Normalize decoded predictions and references before scoring with ROUGE.

    Strips surrounding whitespace from each string, then puts every sentence
    found by ``nltk.sent_tokenize`` on its own line, since ROUGE(-Lsum)
    expects one sentence per line.

    Returns:
        A ``(preds, labels)`` tuple of lists of strings.
    """
    def _one_sentence_per_line(text):
        return "\n".join(nltk.sent_tokenize(text.strip()))

    return (
        [_one_sentence_per_line(p) for p in preds],
        [_one_sentence_per_line(l) for l in labels],
    )

## Train + Eval

In [56]:
from tqdm.auto import tqdm


def train_one_epoch():
    """Run one pass over train_dataloader, stepping the optimizer and LR scheduler per batch."""
    model.train()
    for batch in train_dataloader:
        loss = model(**batch).loss
        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)


def evaluate_rouge():
    """Generate summaries for eval_dataloader and return ROUGE scores (x100, rounded to 4 decimals)."""
    model.eval()
    # generate() is called on the underlying model, not the wrapper from prepare().
    generator = accelerator.unwrap_model(model)
    for batch in eval_dataloader:
        with torch.no_grad():
            generated_tokens = generator.generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                max_new_tokens=max_target_length,
            )

        # Pad to a common length so tensors from different processes can be gathered.
        generated_tokens = accelerator.pad_across_processes(
            generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
        )
        labels = accelerator.pad_across_processes(
            batch["labels"], dim=1, pad_index=tokenizer.pad_token_id
        )

        # gather_for_metrics drops the samples duplicated to even out the last
        # batch across processes, so each eval example is scored exactly once.
        generated_tokens = accelerator.gather_for_metrics(generated_tokens).cpu().numpy()
        labels = accelerator.gather_for_metrics(labels).cpu().numpy()

        # -100 marks ignored label positions and cannot be decoded; map it to the pad id.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
        rouge_score.add_batch(predictions=decoded_preds, references=decoded_labels)

    result = rouge_score.compute()
    return {key: round(value * 100, 4) for key, value in result.items()}


progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_train_epochs):
    train_one_epoch()
    result = evaluate_rouge()
    print(f"Epoch {epoch+1} / {num_train_epochs}:", result)

  0%|          | 0/710 [00:00<?, ?it/s]

Epoch 1 / 10: {'rouge1': 16.9757, 'rouge2': 5.1833, 'rougeL': 14.9742, 'rougeLsum': 15.7201}
Epoch 2 / 10: {'rouge1': 16.9899, 'rouge2': 5.224, 'rougeL': 15.0307, 'rougeLsum': 15.845}
Epoch 3 / 10: {'rouge1': 17.8864, 'rouge2': 5.6078, 'rougeL': 15.2309, 'rougeLsum': 16.4326}
Epoch 4 / 10: {'rouge1': 17.641, 'rouge2': 5.5965, 'rougeL': 15.0372, 'rougeLsum': 16.29}
Epoch 5 / 10: {'rouge1': 18.0917, 'rouge2': 5.645, 'rougeL': 15.3974, 'rougeLsum': 16.6957}
Epoch 6 / 10: {'rouge1': 17.8797, 'rouge2': 5.4575, 'rougeL': 15.2241, 'rougeLsum': 16.4437}
Epoch 7 / 10: {'rouge1': 17.9234, 'rouge2': 5.5306, 'rougeL': 15.1526, 'rougeLsum': 16.48}
Epoch 8 / 10: {'rouge1': 17.9172, 'rouge2': 5.5306, 'rougeL': 15.0623, 'rougeLsum': 16.3199}
Epoch 9 / 10: {'rouge1': 17.9391, 'rouge2': 5.6348, 'rougeL': 15.2106, 'rougeLsum': 16.0324}
Epoch 10 / 10: {'rouge1': 18.2028, 'rouge2': 5.72, 'rougeL': 15.3121, 'rougeLsum': 16.1298}


## Save model

In [None]:
output_dir = 'mt5-finetuned-summarization'

# Barrier: every process must reach this point before the weights are written.
accelerator.wait_for_everyone()

# Save the bare model rather than the wrapper returned by prepare().
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(save_directory=output_dir, save_function=accelerator.save)

# Write the tokenizer from a single process only.
if accelerator.is_main_process:
    tokenizer.save_pretrained(output_dir)

## Inference

In [60]:
doc = """Tìm máy nghiền gốc cây. Đào xung quanh rễ cây. 
Tìm hiểu xem có được phép đốt gốc cây trong khu vực bạn ở không. Giữ cho trẻ em và vật nuôi tránh xa gốc cây. Thay thế tro bằng đất mùn."""

def preprocess_txt(examples, max_length=max_input_length):
    """Tokenize raw text for inference, using the same input length limit as training.

    Args:
        examples: a string (or list of strings) to summarize.
        max_length: truncation/padding length; defaults to ``max_input_length``.

    Returns:
        A ``BatchEncoding`` of PyTorch tensors including ``input_ids`` and
        ``attention_mask``.
    """
    return tokenizer(
        examples,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt",
    )

# Move the inputs to whichever device Accelerate placed the model on. A
# hard-coded "cuda" would crash on CPU-only machines.
input_example = preprocess_txt(doc).to(accelerator.device)

model.eval()
with torch.no_grad():
    generated_tokens = accelerator.unwrap_model(model).generate(
        input_ids=input_example["input_ids"],
        attention_mask=input_example["attention_mask"],
        num_beams=2,
        max_length=max_target_length,
        repetition_penalty=2.5,
        length_penalty=2.0,
        early_stopping=True,
        use_cache=True,
    )

preds = tokenizer.batch_decode(
    generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
)

summary = "".join(preds)
print(summary)

Đốt gốc cây.
