In [1]:
import torch 
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

In [2]:
from datasets import load_dataset
dataset = load_dataset("gigaword")

Using custom data configuration default
Reusing dataset gigaword (/tmp/xdg-cache/huggingface/datasets/gigaword/default/1.2.0/c518c578e42a6afe842b09e979ee2907ea42a12b57ba992fae9e9d7347825245)


In [3]:
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import torch
import torch.nn as nn
import math

In [4]:
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import torch
src_text = dataset['train']['document'][0:1000]
target_text = dataset['train']['summary'][0:1000]
model_name = 'google/pegasus-gigaword'
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = PegasusTokenizer.from_pretrained(model_name)

train_data = tokenizer.prepare_seq2seq_batch(src_text, target_text, return_tensors="pt", truncation="only_first", padding="longest", max_length=64)

#tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)

In [5]:
input_ids_train = train_data['input_ids']
attention_masks_train = train_data['attention_mask']
labels_train = train_data['labels']

In [6]:
print(train_data)

{'input_ids': tensor([[14937,  1034,   116,  ...,     0,     0,     0],
        [  134,   583,   228,  ...,     0,     0,     0],
        [73013,  2853,  2127,  ...,     0,     0,     0],
        ...,
        [  260,   117, 32078,  ...,     0,     0,     0],
        [  154,   197,   110,  ...,     0,     0,     0],
        [  126,  1034,   116,  ...,     0,     0,     0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([[73013,   582,   728,  ...,     0,     0,     0],
        [  134,   583,   228,  ...,     0,     0,     0],
        [73013,  5446,   686,  ...,     0,     0,     0],
        ...,
        [ 3834,  6010, 12752,  ...,     0,     0,     0],
        [32577,   493, 23765,  ...,     0,     0,     0],
        [94830,  1034,   116,  ...,     0,     0,     0]])}


In [7]:
from transformers import PegasusForConditionalGeneration, PegasusTokenizer, PegasusConfig, modeling_outputs

In [8]:
class PegasusWithCopyBack(PegasusForConditionalGeneration):
    def __init__(self, config):
        super(PegasusWithCopyBack, self).__init__(config)
        num_features = config.d_model
        self.p_gen_w = nn.Linear(num_features*3,1)
        self.softmax = nn.Softmax(dim=2)
        self.p_gen_w.bias = nn.Parameter(torch.ones(1))
    
    def forward(
        self,
        input_ids,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        out = super(PegasusWithCopyBack, self).forward(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        
        hi = out['encoder_last_hidden_state'] # (batch_size x s_seq_len x model_size)
        at = torch.mean(out['cross_attentions'][-1], dim=1) # (batch_size x t_seq_len x s_seq_len)        
        st = out['decoder_hidden_states'][-1] # (batch_size x t_seq_len x model_size)
        
        if labels is not None:
            dec = self.model.get_input_embeddings()(labels) # (batch_size x t_seq_len x model_size)
        else:
#             print(decoder_input_ids.shape)
            dec = self.model.get_input_embeddings()(decoder_input_ids[:,[-1]])

#         print((at@hi).shape)
#         print(st.shape)
#         print(dec.shape)
        
        p_gen = torch.sigmoid(self.p_gen_w(torch.cat((at @ hi,st,dec),dim=-1))) # (batch_size x t_seq_len x 1)
        attn_dists = (1-p_gen)*at # (batch_size x t_seq_len x s_seq_len)
        v_dist = p_gen * out['logits'] # (batch_size x t_seq_len x vocab_size)
        

        if input_ids is not None:            
            src_ids = input_ids.unsqueeze(1).repeat(1, attn_dists.size(1), 1)   # (batch_size x 1 x s_seq_len)  
        else:
            src_ids = self.input_ids.unsqueeze(1).repeat(1, attn_dists.size(1), 1)   # (batch_size x 1 x s_seq_len)
                
        pred = v_dist.scatter_add(2, src_ids, attn_dists) #(batch_size x t_seq_len x vocab_size)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # TODO(SS): do we need to ignore pad tokens in labels?
            masked_lm_loss = loss_fct(pred.view(-1, self.config.vocab_size), labels.view(-1))
        
        if not return_dict:
            output = (pred,) + out[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
        
        return modeling_outputs.Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=pred,
            past_key_values=out.past_key_values,
            decoder_hidden_states=out.decoder_hidden_states,
            decoder_attentions=out.decoder_attentions,
            cross_attentions=out.cross_attentions,
            encoder_last_hidden_state=out.encoder_last_hidden_state,
            encoder_hidden_states=out.encoder_hidden_states,
            encoder_attentions=out.encoder_attentions,
        )
    
    @torch.no_grad()
    def generate(
        self,
        input_ids=None,
        max_length=None,
        min_length=None,
        do_sample=None,
        early_stopping=None,
        num_beams=None,
        temperature=None,
        top_k=None,
        top_p=None,
        repetition_penalty=None,
        bad_words_ids=None,
        bos_token_id=None,
        pad_token_id=None,
        eos_token_id=None,
        length_penalty=None,
        no_repeat_ngram_size=None,
        num_return_sequences=None,
        attention_mask=None,
        decoder_start_token_id=None,
        use_cache=None,
        **model_specific_kwargs
    ):
        self.input_ids = input_ids
        return super(PegasusWithCopyBack, self).generate(
            input_ids=input_ids,
            max_length=max_length,
            min_length=min_length,
            do_sample=do_sample,
            early_stopping=early_stopping,
            num_beams=num_beams,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            bad_words_ids=bad_words_ids,
            bos_token_id=bos_token_id,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            length_penalty=length_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            num_return_sequences=num_return_sequences,
            attention_mask=attention_mask,
            decoder_start_token_id=decoder_start_token_id,
            use_cache=use_cache,
        )

In [9]:
config = PegasusConfig.from_pretrained(model_name, output_hidden_states=True, output_attentions=True)

In [10]:
# pega_copyback_model = PegasusWithCopyBack(config, 1024).to(torch_device) 
pega_copyback_model = PegasusWithCopyBack.from_pretrained(model_name, config=config).to(torch_device) 

Some weights of PegasusWithCopyBack were not initialized from the model checkpoint at google/pegasus-gigaword and are newly initialized: ['p_gen_w.weight', 'p_gen_w.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import AdamW, get_linear_schedule_with_warmup

dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)

batch_size = 2
dataloader_train = DataLoader(dataset_train,sampler=RandomSampler(dataset_train), batch_size=batch_size)

#freezing the parameters
# for param in pega_copyback_model.model.parameters():
#     param.requires_grad = False
optimizer = AdamW(filter(lambda p: p.requires_grad, pega_copyback_model.parameters()),lr=5e-5)
                  
epochs = 10
# #scheduler = get_linear_schedule_with_warmup(optimizer, 
#                                             num_warmup_steps=0,
#                                             num_training_steps=len(dataloader_train)*epochs)


In [12]:
from tqdm.notebook import tqdm
import random

seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

    
pega_copyback_model = pega_copyback_model.to(torch_device)   

for epoch in tqdm(range(1, epochs+1)):
    
    pega_copyback_model.train()
    
    loss_train_total = 0

    progress_bar = tqdm(dataloader_train, desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)
    for b in progress_bar:

        optimizer.zero_grad()
        
        b = tuple(x.to(torch_device) for x in b)
        
        inputs = {'input_ids':      b[0],
                  'attention_mask': b[1],
                  'labels':         b[2],
                 }       

        outputs = pega_copyback_model(**inputs)
#         vocab_size =  outputs['logits'].shape[2]
    
        loss = outputs[0]#criterion(outputs.view(-1,vocab_size),b[2].view(-1))

        loss_train_total += loss.item()
        loss.backward()

#         torch.nn.utils.clip_grad_norm_(pega_copyback_model.parameters(), 1.0)

        optimizer.step()
        #scheduler.step()
        
        progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item()/len(b))})
         
        
#     torch.save(pega_copyback_model.state_dict(), f'data/finetuned_pega_copyback_epoch_{epoch}.model')
        
    tqdm.write(f'\nEpoch {epoch}')
    
    loss_train_avg = loss_train_total/len(dataloader_train)            
    tqdm.write(f'Training loss: {loss_train_avg}')

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=500.0, style=ProgressStyle(description_widt…


Epoch 1
Training loss: 4.616461138725281


HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=500.0, style=ProgressStyle(description_widt…


Epoch 2
Training loss: 0.7418627934753895


HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=500.0, style=ProgressStyle(description_widt…


Epoch 3
Training loss: 0.5312441710308194


HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=500.0, style=ProgressStyle(description_widt…


Epoch 4
Training loss: 0.3824864894151688


HBox(children=(FloatProgress(value=0.0, description='Epoch 5', max=500.0, style=ProgressStyle(description_widt…


Epoch 5
Training loss: 0.27958077045530083


HBox(children=(FloatProgress(value=0.0, description='Epoch 6', max=500.0, style=ProgressStyle(description_widt…


Epoch 6
Training loss: 0.20749787778034806


HBox(children=(FloatProgress(value=0.0, description='Epoch 7', max=500.0, style=ProgressStyle(description_widt…


Epoch 7
Training loss: 0.1514088394846767


HBox(children=(FloatProgress(value=0.0, description='Epoch 8', max=500.0, style=ProgressStyle(description_widt…


Epoch 8
Training loss: 0.10168601999618113


HBox(children=(FloatProgress(value=0.0, description='Epoch 9', max=500.0, style=ProgressStyle(description_widt…


Epoch 9
Training loss: 0.08478897425206378


HBox(children=(FloatProgress(value=0.0, description='Epoch 10', max=500.0, style=ProgressStyle(description_wid…


Epoch 10
Training loss: 0.0695862281427253



In [13]:
pega_copyback_model.eval()

tgt_text = tokenizer.batch_decode(pega_copyback_model.generate(train_data['input_ids'][[0],:].to(torch_device)))#, skip_special_tokens=True)

In [14]:
print(tgt_text)

["australian current account deficit narrows sharply in june quarter on commodity price rise torptn'tn'tn'tn'tnunk_3"]


In [15]:
from datasets import load_metric

pega_copyback_model.eval()


metric = load_metric("rouge")

src_text = dataset['validation']['document'][0:100]
target_text = dataset['validation']['summary'][0:100]

val_data = tokenizer.prepare_seq2seq_batch(src_text, target_text, return_tensors="pt", truncation="only_first", padding="longest", max_length=64)

input_ids_val = val_data['input_ids']
attention_masks_val = val_data['attention_mask']
labels_val = val_data['labels']

dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)

batch_size = 1
dataloader_val = DataLoader(dataset_val, batch_size=batch_size)


shouldPrint = 0
toPrint = 4
with torch.no_grad():
    for b in dataloader_val:        
        b = tuple(x.to(torch_device) for x in b)
        
        inputs = {'input_ids':      b[0],
                  'attention_mask': b[1],
                  'labels':         b[2],
                 }       
        gen = tokenizer.batch_decode(pega_copyback_model.generate(inputs['input_ids']), skip_special_tokens=True)
        if shouldPrint < toPrint:
            print("Source Sentence:", src_text[shouldPrint])
            print("Sample Sentence:", gen[0])
            shouldPrint += 1 
        ref = tokenizer.batch_decode(inputs['labels'])
        metric.add_batch(predictions=gen, references=ref)


print("@@@@@@@@@@@@@@@@@@@ Rogue @@@@@@@@@@@@@@@@@@@@@@@")
print()
        
print(metric.compute())

Source Sentence: five-time world champion michelle kwan withdrew from the #### us figure skating championships on wednesday , but will petition us skating officials for the chance to compete at the #### turin olympics .
Sample Sentence: five-time world champion kwan withdraws from us championships but wants to compete in #### olympics in #### in bid to keep spot forunk_3
Source Sentence: us business leaders lashed out wednesday at legislation that would penalize companies for employing illegal immigrants .
Sample Sentence: us us business leaders blast anti-immigrant bill as unfair to us firms with bc-as-as-me-as-gen us-as-unk_3
Source Sentence: general motors corp. said wednesday its us sales fell ##.# percent in december and four percent in #### with the biggest losses coming from passenger car sales .
Sample Sentence: gm us sales in #### down ##.# pct on year on worst dec # pct on pct on pct slide in passenger car salesunk_3
Source Sentence: several thousand people gathered on wednes