In [1]:
from transformers import BartTokenizerFast, BartForConditionalGeneration, BartTokenizer
from transformers import AdamW, get_linear_schedule_with_warmup
import datasets

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
from sklearn.model_selection import train_test_split

from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import random
import time
import datetime
from sumeval.metrics.rouge import RougeCalculator

In [2]:
batch_size = 8
adam_eps = 1e-8
learning_rate = 2e-5
epochs = 2

In [3]:
def set_seed(seed_val):
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.cuda.manual_seed(seed_val)
    torch.manual_seed(seed_val)

In [4]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [5]:
def format_time(elapsed):
    '''
    Takes a time in seconds and returns a string hh:mm:ss
    '''
    # Round to the nearest second
    elapsed_rounded = int(round((elapsed)))
    
    # Format as hh:mm:ss
    return str(datetime.timedelta(seconds=elapsed_rounded))

In [6]:
set_seed(42)

# Load Data

In [7]:
df = pd.read_csv('data/summarize/news_data.csv')

In [8]:
train_df, eval_df = train_test_split(df, test_size=0.2, random_state=2021)
eval_df, test_df = train_test_split(eval_df, test_size=0.5, random_state=2021)

In [9]:
class SummarizerDataset(Dataset):
    def __init__(self, tokenizer, df, max_length=90):
        self.tokenizer = tokenizer
        self.df = df
        self.max_length = max_length
        
    def __len__(self):
        return self.df.shape[0]
    
    def __getitem__(self, index):
        data = self.df.iloc[index]
        input_text = data['text']
        target_text = data['headlines']
        
        input_ids = self.tokenizer.encode_plus(
            input_text, max_length=self.max_length, padding='max_length',
            return_tensors='pt', truncation=True
        )
        
        target_ids = self.tokenizer.encode_plus(
            target_text, max_length=self.max_length, padding='max_length',
            return_tensors='pt', truncation=True
        )
        
        targets = target_ids["input_ids"].squeeze()
        pad_token_id = tokenizer.pad_token_id
        labels = targets.clone()
        labels[targets == pad_token_id] = -100
        
        return {
            'source_ids': input_ids['input_ids'].squeeze(),
            'source_mask': input_ids['attention_mask'].squeeze(),
            'target_ids': targets,
            'labels': labels,
            'target_text': target_text
        }

In [10]:
tokenizer = BartTokenizer.from_pretrained('sshleifer/distilbart-xsum-6-6')

In [11]:
train_dataset = SummarizerDataset(tokenizer, train_df)
eval_dataset = SummarizerDataset(tokenizer, eval_df)

train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)

eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=batch_size)

# Prepare model

In [12]:
model = BartForConditionalGeneration.from_pretrained('sshleifer/distilbart-xsum-6-6')
model.to(device)

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50264, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50264, 1024, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0): BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
   

In [13]:
optimizer = AdamW(
    model.parameters(),
    lr=learning_rate,
    eps=adam_eps
)

scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * epochs
)

rouge = RougeCalculator(stopwords=True, lang="en")

In [None]:
train_loss_values = []
eval_loss_values = []

for epoch in tqdm(range(epochs)):
    model.train()
    total_loss = 0
    t0 = time.time()
    
    for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False):
        input_ids = batch['source_ids'].to(device)
        attention_mask = batch['source_mask'].to(device)
        target_ids = batch['target_ids'].to(device)
        labels = batch['labels'].to(device)
        
        model.zero_grad()
        
        # Forward pass
        outputs = model(
            input_ids, 
            attention_mask=attention_mask, 
            decoder_input_ids=target_ids, 
            labels=labels
        )
        
        loss, logits = outputs[:2]
        
        # Calculate loss
        total_loss += loss.item()
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        
        # Metrics calculation
        if step > 0 and step % 1000 == 0:
            # Calculate elapsed time in minutes.
            elapsed = format_time(time.time() - t0)
            # Report progress.
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))
            
    # Calculate the average loss over the training data.
    avg_train_loss = total_loss / len(train_dataloader)            
    
    # Store the loss value for plotting the learning curve.
    train_loss_values.append(avg_train_loss)

    print("")
    print("  Average training loss: {0:.3f}".format(avg_train_loss))
    
    # Evaluation mode
    model.eval()
    nb_eval_steps = 0
    eval_loss = 0
    
    rogue_scores = {
        'rogue_1': [],
        'rogue_2': [],
        'rogue_l': []
    }
    references = []
    
    for step, batch in tqdm(enumerate(eval_dataloader), total=len(eval_dataloader), leave=False):
        input_ids = batch['source_ids'].to(device)
        attention_mask = batch['source_mask'].to(device)
        target_ids = batch['target_ids'].to(device)
        labels = batch['labels'].to(device)
        
        with torch.no_grad():
            outputs = model(
                input_ids, 
                attention_mask=attention_mask, 
                decoder_input_ids=target_ids, 
                labels=labels
            )
            
            tmp_eval_loss, logits = outputs[:2]
            eval_loss += tmp_eval_loss.item()
        
        # Calculate metrics
        summarized_output = model.generate(input_ids)
        for sum_output, sum_gt in zip(summarized_output, batch['target_text']):
            sum_pred = tokenizer.decode(sum_output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            rogue_scores['rogue_1'].append(rouge.rouge_n(
                summary=sum_pred,
                references=sum_gt,
                n=1
            ))
            rogue_scores['rogue_2'].append(rouge.rouge_n(
                summary=sum_pred,
                references=sum_gt,
                n=2
            ))
            rogue_scores['rogue_l'].append(rouge.rouge_l(
                summary=sum_pred,
                references=sum_gt
            ))
            
    eval_loss = eval_loss / len(eval_dataloader)
    eval_loss_values.append(eval_loss)
    
    print("  Average evaluation loss: {0:.3f}".format(eval_loss))
    print('Metrics:\nRogue_1: {}\nRogue_2: {}\nRogue_L:{}'.format(
        np.mean(rogue_scores['rogue_1']), np.mean(rogue_scores['rogue_2']), np.mean(rogue_scores['rogue_l'])
    ))

HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10292.0), HTML(value='')))

  Batch 1,000  of  10,292.    Elapsed: 0:05:03.
  Batch 2,000  of  10,292.    Elapsed: 0:10:08.
  Batch 3,000  of  10,292.    Elapsed: 0:15:13.
  Batch 4,000  of  10,292.    Elapsed: 0:20:18.
  Batch 5,000  of  10,292.    Elapsed: 0:25:23.
  Batch 6,000  of  10,292.    Elapsed: 0:30:28.
  Batch 7,000  of  10,292.    Elapsed: 0:35:34.
  Batch 8,000  of  10,292.    Elapsed: 0:40:39.
  Batch 9,000  of  10,292.    Elapsed: 0:45:44.
  Batch 10,000  of  10,292.    Elapsed: 0:50:48.

  Average training loss: 0.013


HBox(children=(FloatProgress(value=0.0, max=1287.0), HTML(value='')))

  Average evaluation loss: 0.000
Metrics:
Rogue_1: 0.00038409031864110294
Rogue_2: 0.0
Rogue_L:0.000379107124463467


HBox(children=(FloatProgress(value=0.0, max=10292.0), HTML(value='')))

  Batch 1,000  of  10,292.    Elapsed: 0:05:02.
  Batch 2,000  of  10,292.    Elapsed: 0:10:04.
  Batch 3,000  of  10,292.    Elapsed: 0:15:07.
  Batch 4,000  of  10,292.    Elapsed: 0:20:10.
  Batch 5,000  of  10,292.    Elapsed: 0:25:13.
  Batch 6,000  of  10,292.    Elapsed: 0:30:16.
  Batch 7,000  of  10,292.    Elapsed: 0:35:20.
  Batch 8,000  of  10,292.    Elapsed: 0:40:23.
  Batch 9,000  of  10,292.    Elapsed: 0:45:27.
  Batch 10,000  of  10,292.    Elapsed: 0:50:30.

  Average training loss: 0.000


HBox(children=(FloatProgress(value=0.0, max=1287.0), HTML(value='')))

In [None]:
test = model.generate(input_ids, max_length=10, top_k=1000, top_p=0.9, num_beams=5)
tokenizer.decode(test[2], skip_special_tokens=True, clean_up_tokenization_spaces=True)