<a href="https://colab.research.google.com/github/kevin-rn/Grounding-LM/blob/main/abstractive_summary.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Google Colab setup

In [1]:
# Mount Google Drive at /content/drive so later cells can read from and write to it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Recursively copy the data directory from the mounted Drive into the current
# working directory; later cells load datasets from the relative path 'data/...'.
%cp -R ./drive/MyDrive/data/ ./

# Packages & Imports

In [3]:
# Install the third-party packages imported below (quiet mode).
# NOTE(review): versions are unpinned, so a later run may resolve different
# releases than the ones this notebook was developed with — consider pinning.
%pip install -q pytorch-lightning
%pip install -q transformers
%pip install -q datasets

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m720.6/720.6 kB[0m [31m21.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 kB[0m [31m32.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m44.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.5/114.5 kB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m23.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m149.6/149.6 kB[0m [31m17.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m86.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m236.8/236.8 kB[0m [31m22.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━

In [4]:
from datasets import load_dataset, load_from_disk
import pandas as pd
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from torch.optim import AdamW
from tqdm.auto import tqdm
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# Dataset

In [None]:
# One-time step, intentionally left commented out: download a dataset from the
# hub and save it under data/ so later cells can use load_from_disk() instead.
# Uncomment the pair of lines for the dataset you need.

# xsum_data = load_dataset("xsum")
# xsum_data.save_to_disk('data/xsum')

# cnn_data = load_dataset("cnn_dailymail", "3.0.0")
# cnn_data.save_to_disk('data/cnn_dailymail')

## tldr_data = load_dataset("webis/tldr-17")

In [5]:
# Load the cached dataset from disk (swap the path to use CNN/DailyMail instead).
data = load_from_disk('data/xsum')
# data = load_from_disk('data/cnn_dailymail')

# Materialise each split as a pandas DataFrame.
df_train, df_val, df_test = (
    pd.DataFrame(data=data[split]) for split in ('train', 'validation', 'test')
)

# Rename the columns by position so the rest of the notebook can rely on the
# same three names regardless of which dataset was loaded.
for frame in (df_train, df_val, df_test):
    frame.columns = ['text', 'summary', 'id']

df_train.head()

Unnamed: 0,text,summary,id
0,"The full cost of damage in Newton Stewart, one...",Clean-up operations are continuing across the ...,35232142
1,A fire alarm went off at the Holiday Inn in Ho...,Two tourist buses have been destroyed by fire ...,40143035
2,Ferrari appeared in a position to challenge un...,Lewis Hamilton stormed to pole position at the...,35951548
3,"John Edward Bates, formerly of Spalding, Linco...",A former Lincolnshire Police officer carried o...,36266422
4,Patients and staff were evacuated from Cerahpa...,An armed man who locked himself into a room at...,38826984


In [6]:
class CustomDataset(Dataset):
    """Dataset that tokenizes one (text, summary) DataFrame row per access.

    Args:
        data: DataFrame with 'text' and 'summary' columns; rows are read by
            position via ``iloc``.
        tokenizer: Callable tokenizer returning a mapping with 'input_ids' and
            'attention_mask' tensors (it is called with ``return_tensors='pt'``).
        text_max_len: Length the source text is padded/truncated to.
        summary_max_len: Length the reference summary is padded/truncated to.
    """

    # Label value skipped by the loss (default ignore_index of PyTorch's
    # cross-entropy loss).
    IGNORE_INDEX = -100

    def __init__(self, data, tokenizer, text_max_len = 512, summary_max_len = 128):
        self.data = data
        self.tokenizer = tokenizer
        self.text_max_len = text_max_len
        self.summary_max_len = summary_max_len

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.data)

    def _encode(self, value, max_length):
        """Tokenize ``value``, padding/truncating it to exactly ``max_length``."""
        return self.tokenizer(
            value,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt'
        )

    def __getitem__(self, idx: int):
        """Return the raw strings and the tokenized tensors for row ``idx``.

        Returns:
            dict with keys 'text', 'summary', 'input_ids', 'attention_mask',
            'labels' (flattened, padding positions set to -100) and
            'labels_attention_mask'.
        """
        # Read the row once instead of indexing the frame twice.
        row = self.data.iloc[idx]
        text = row['text']
        summary = row['summary']

        text_encoding = self._encode(text, self.text_max_len)
        summary_encoding = self._encode(summary, self.summary_max_len)

        # Mask padding by attention mask rather than by a hard-coded token id:
        # the pad id is 0 for T5 but differs for other tokenizers (e.g. BART),
        # and a real token whose id happens to be 0 must not be masked.
        # Clone so the tokenizer's own output tensor is not mutated.
        labels = summary_encoding['input_ids'].clone()
        labels[summary_encoding['attention_mask'] == 0] = self.IGNORE_INDEX

        return {
            'text': text,
            'summary': summary,
            # Kept exactly as returned by the tokenizer; collate_fn flattens them.
            'input_ids': text_encoding['input_ids'],
            'attention_mask': text_encoding['attention_mask'],
            'labels': labels.flatten(),
            'labels_attention_mask': summary_encoding['attention_mask']
        }

In [7]:
class CustomDataModule(pl.LightningDataModule):
    """Lightning data module wrapping the train/validation/test DataFrames.

    Args:
        df_train, df_val, df_test: DataFrames handed to ``CustomDataset``.
        tokenizer: Tokenizer forwarded to every ``CustomDataset``.
        batch: Batch size used by all three dataloaders.
        text_max_len: Maximum token length of the source text.
        summary_max_len: Maximum token length of the reference summary.
    """

    def __init__(self, df_train, df_val, df_test, tokenizer, batch = 8, text_max_len = 512, summary_max_len = 128):
        super().__init__()
        self.df_train = df_train
        self.df_val = df_val
        self.df_test = df_test
        self.tokenizer = tokenizer
        self.batch = batch
        self.text_max_len = text_max_len
        self.summary_max_len = summary_max_len

    def setup(self, stage=None):
        """Create the three datasets; ``stage`` is accepted but not used."""
        self.train_dataset = CustomDataset(self.df_train, self.tokenizer, self.text_max_len, self.summary_max_len)
        self.val_dataset = CustomDataset(self.df_val, self.tokenizer, self.text_max_len, self.summary_max_len)
        self.test_dataset = CustomDataset(self.df_test, self.tokenizer, self.text_max_len, self.summary_max_len)

    def collate_fn(self, batch):
        """Merge a list of dataset items into one batch dictionary.

        String fields are collected into lists; tensor fields are flattened
        and stacked along a new leading batch dimension.
        """
        texts = [item['text'] for item in batch]
        summaries = [item['summary'] for item in batch]
        text_input_ids = pad_sequence([item['input_ids'].flatten() for item in batch], batch_first=True)
        text_attention_masks = pad_sequence([item['attention_mask'].flatten() for item in batch], batch_first=True)
        # Any padding added here must use the ignore index (-100), not 0,
        # otherwise padded label positions would be scored as real tokens.
        labels = pad_sequence([item['labels'] for item in batch], batch_first=True, padding_value=-100)
        labels_attention_masks = pad_sequence([item['labels_attention_mask'].flatten() for item in batch], batch_first=True)

        return {
            'text': texts,
            'summary': summaries,
            'input_ids': text_input_ids,
            'attention_mask': text_attention_masks,
            'labels': labels,
            'labels_attention_mask': labels_attention_masks
        }

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the settings shared by all three splits."""
        return DataLoader(dataset, batch_size=self.batch, shuffle=shuffle, num_workers=2, collate_fn=self.collate_fn)

    def train_dataloader(self):
        """Training loader; shuffled every epoch."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Validation loader; kept in DataFrame row order."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Test loader; kept in DataFrame row order.

        The order matters: the notebook later assigns the generated summaries
        to ``df_test`` positionally, so a shuffled loader would pair every
        summary with the wrong row.
        """
        return self._make_loader(self.test_dataset, shuffle=False)

# Pretrained T5 model

In [8]:
# Hub checkpoint shared by the tokenizer and the model so the two cannot drift apart.
T5_CHECKPOINT = "sysresearch101/t5-large-finetuned-xsum-cnn"

tokenizer = AutoTokenizer.from_pretrained(T5_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(T5_CHECKPOINT)

Downloading (…)okenizer_config.json:   0%|          | 0.00/1.92k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.43k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/2.95G [00:00<?, ?B/s]

# Pretrained BART model (alternative to T5)

In [None]:
# Hub checkpoint shared by the tokenizer and the model.
# NOTE: running this cell rebinds `tokenizer` and `model`, replacing whatever
# was loaded into those names before it.
BART_CHECKPOINT = "facebook/bart-large-xsum"

tokenizer = AutoTokenizer.from_pretrained(BART_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(BART_CHECKPOINT)

# Generate Summaries

In [9]:
# Wrap the three splits in a data module and build its datasets right away,
# since the following cells use the dataloaders directly (no Lightning Trainer).
data_module = CustomDataModule(
    df_train=df_train,
    df_val=df_val,
    df_test=df_test,
    tokenizer=tokenizer,
    batch=8,
)
data_module.setup()

In [12]:
# Run generation on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Iterate the test set in DataFrame row order: the generated summaries are
# later attached to df_test positionally, so the loader must not shuffle.
test_loader = DataLoader(
    data_module.test_dataset,
    batch_size=data_module.batch,
    shuffle=False,
    num_workers=2,
    collate_fn=data_module.collate_fn,
)

# One generated summary per test row, in row order.
summaries = []

with torch.no_grad():
    for batch in tqdm(test_loader):
        # Both tensors must live on the same device as the model.
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        summary_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=150,
            num_beams=2,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True
        )

        # Keep one string per example; joining the batch into a single string
        # would yield one entry per batch instead of one per row.
        summaries.extend(
            tokenizer.decode(sum_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for sum_id in summary_ids
        )

  0%|          | 0/1417 [00:00<?, ?it/s]

Process Process-10:
Process Process-9:
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fba7f91c4c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1442, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.10/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
  File "/usr/lib/python3.10/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/usr/lib/python3.10/selectors.py", line 416, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt: 
Traceback (most recent call last):
Traceback (most recent cal

KeyboardInterrupt: ignored

In [None]:
# Attach the generated summaries as a new column (assigned positionally, so
# `summaries` must hold one entry per row of df_test, in row order) and write
# the result to CSV without the index column.
df_test['generated_summaries'] = summaries
df_test.to_csv(path_or_buf="t5_xsum_gen.csv", index=False)

In [None]:
# Copy the generated CSV into the data directory on the mounted Drive.
%cp t5_xsum_gen.csv ./drive/MyDrive/data/