In [1]:
import torch 
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import math

from datasets import load_dataset
dataset = load_dataset("gigaword")

Using custom data configuration default
Reusing dataset gigaword (/tmp/xdg-cache/huggingface/datasets/gigaword/default/1.2.0/c518c578e42a6afe842b09e979ee2907ea42a12b57ba992fae9e9d7347825245)


In [2]:
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import torch
src_text = dataset['train']['document'][0:1000]
target_text = dataset['train']['summary'][0:1000]
model_name = 'google/pegasus-gigaword'
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = PegasusTokenizer.from_pretrained(model_name)

train_data = tokenizer.prepare_seq2seq_batch(src_text, target_text, return_tensors="pt", truncation="only_first", padding="longest", max_length=64)

input_ids_train = train_data['input_ids']
attention_masks_train = train_data['attention_mask']
labels_train = train_data['labels']

In [3]:
from transformers import PegasusForConditionalGeneration, PegasusTokenizer, PegasusConfig, modeling_outputs

In [4]:
config = PegasusConfig.from_pretrained(model_name, output_hidden_states=True, output_attentions=True)

In [5]:
pega_copyback_model = PegasusForConditionalGeneration.from_pretrained(model_name, config=config).to(torch_device)

In [6]:
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import AdamW, get_linear_schedule_with_warmup

dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)

batch_size = 2
dataloader_train = DataLoader(dataset_train,sampler=RandomSampler(dataset_train), batch_size=batch_size)

#freezing the parameters
# for param in pega_copyback_model.model.parameters():
#     param.requires_grad = False
optimizer = AdamW(filter(lambda p: p.requires_grad, pega_copyback_model.parameters()),lr=5e-5)
                  
epochs = 10

In [7]:
from tqdm.notebook import tqdm
import random

seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)  

for epoch in tqdm(range(1, epochs+1)):
    
    pega_copyback_model.train()
    
    loss_train_total = 0

    progress_bar = tqdm(dataloader_train, desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)
    for b in progress_bar:

        optimizer.zero_grad()
        
        b = tuple(x.to(torch_device) for x in b)
        
        inputs = {'input_ids':      b[0],
                  'attention_mask': b[1],
                  'labels':         b[2],
                 }       

        outputs = pega_copyback_model(**inputs)
#         vocab_size =  outputs['logits'].shape[2]
    
        loss = outputs[0]#criterion(outputs.view(-1,vocab_size),b[2].view(-1))

        loss_train_total += loss.item()
        loss.backward()

#         torch.nn.utils.clip_grad_norm_(pega_copyback_model.parameters(), 1.0)

        optimizer.step()
        #scheduler.step()
        
        progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item()/len(b))})
         
        
#     torch.save(pega_copyback_model.state_dict(), f'data/finetuned_pega_epoch_{epoch}.model')
        
    tqdm.write(f'\nEpoch {epoch}')
    
    loss_train_avg = loss_train_total/len(dataloader_train)            
    tqdm.write(f'Training loss: {loss_train_avg}')

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=500.0, style=ProgressStyle(description_widt…


Epoch 1
Training loss: 4.518651569277048


HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=500.0, style=ProgressStyle(description_widt…


Epoch 2
Training loss: 0.7126671251058578


HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=500.0, style=ProgressStyle(description_widt…


Epoch 3
Training loss: 0.5251764520779252


HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=500.0, style=ProgressStyle(description_widt…


Epoch 4
Training loss: 0.38543529219180345


HBox(children=(FloatProgress(value=0.0, description='Epoch 5', max=500.0, style=ProgressStyle(description_widt…


Epoch 5
Training loss: 0.284935523763299


HBox(children=(FloatProgress(value=0.0, description='Epoch 6', max=500.0, style=ProgressStyle(description_widt…


Epoch 6
Training loss: 0.21589056959562003


HBox(children=(FloatProgress(value=0.0, description='Epoch 7', max=500.0, style=ProgressStyle(description_widt…


Epoch 7
Training loss: 0.16221176459919662


HBox(children=(FloatProgress(value=0.0, description='Epoch 8', max=500.0, style=ProgressStyle(description_widt…


Epoch 8
Training loss: 0.11790174465859309


HBox(children=(FloatProgress(value=0.0, description='Epoch 9', max=500.0, style=ProgressStyle(description_widt…


Epoch 9
Training loss: 0.09408257886418142


HBox(children=(FloatProgress(value=0.0, description='Epoch 10', max=500.0, style=ProgressStyle(description_wid…


Epoch 10
Training loss: 0.07352545802830718



In [8]:
pega_copyback_model.eval()

tgt_text = tokenizer.batch_decode(pega_copyback_model.generate(train_data['input_ids'][[0],:].to(torch_device)))#, skip_special_tokens=True)

In [9]:
print(tgt_text)

['australian current account deficit narrows sharply in june quarter on strong commodity price rise.## bln dlr month on month on month in june qtrunk_3']


In [10]:
from datasets import load_metric

pega_copyback_model.eval()


metric = load_metric("rouge")

src_text = dataset['validation']['document'][0:100]
target_text = dataset['validation']['summary'][0:100]

val_data = tokenizer.prepare_seq2seq_batch(src_text, target_text, return_tensors="pt", truncation="only_first", padding="longest", max_length=64)

input_ids_val = val_data['input_ids']
attention_masks_val = val_data['attention_mask']
labels_val = val_data['labels']

dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)

batch_size = 1
dataloader_val = DataLoader(dataset_val, batch_size=batch_size)


shouldPrint = 0
toPrint = 4
with torch.no_grad():
    for b in dataloader_val:        
        b = tuple(x.to(torch_device) for x in b)
        
        inputs = {'input_ids':      b[0],
                  'attention_mask': b[1],
                  'labels':         b[2],
                 }       
        gen = tokenizer.batch_decode(pega_copyback_model.generate(inputs['input_ids']), skip_special_tokens=True)
        if shouldPrint < toPrint:
            print("Source Sentence:", src_text[shouldPrint])
            print("Sample Sentence:", gen[0])
            shouldPrint += 1 
        ref = tokenizer.batch_decode(inputs['labels'])
        metric.add_batch(predictions=gen, references=ref)


print("@@@@@@@@@@@@@@@@@@@ Rogue @@@@@@@@@@@@@@@@@@@@@@@")
print()        

print(metric.compute())

Source Sentence: five-time world champion michelle kwan withdrew from the #### us figure skating championships on wednesday , but will petition us skating officials for the chance to compete at the #### turin olympics .
Sample Sentence: kwan withdraws from us figure skating championships but wants to compete at olympics in #### in beijing olympics despite latest injury call-unk_3
Source Sentence: us business leaders lashed out wednesday at legislation that would penalize companies for employing illegal immigrants .
Sample Sentence: us business leaders blast anti-immigrant bill as unfair to us firms with bc-as-eu-polus-poland-us-mexicounk_3
Source Sentence: general motors corp. said wednesday its us sales fell ##.# percent in december and four percent in #### with the biggest losses coming from passenger car sales .
Sample Sentence: gm's us sales drop ##.# pct in december # pct in #### on weak ; worst in ## pct in ## monthsunk_3
Source Sentence: several thousand people gathered on wedne