In [1]:
import torch 
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import math

from datasets import load_dataset
dataset = load_dataset("gigaword")

Using custom data configuration default
Reusing dataset gigaword (/tmp/xdg-cache/huggingface/datasets/gigaword/default/1.2.0/c518c578e42a6afe842b09e979ee2907ea42a12b57ba992fae9e9d7347825245)


In [2]:
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import torch
src_text = dataset['train']['document'][0:1]
target_text = dataset['train']['summary'][0:1]
model_name = 'google/pegasus-gigaword'
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = PegasusTokenizer.from_pretrained(model_name)

#model = PegasusForConditionalGeneration.from_pretrained(model_name,return_dict=True,output_attentions=True,output_hidden_states=True).to(torch_device)

train_data = tokenizer.prepare_seq2seq_batch(src_text, target_text, return_tensors="pt", truncation="only_first", padding="longest")

input_ids_train = train_data['input_ids']
attention_masks_train = train_data['attention_mask']
labels_train = train_data['labels']

In [3]:
from transformers import PegasusForConditionalGeneration, PegasusTokenizer, PegasusConfig, modeling_outputs

In [4]:
class PegasusGuidedCopyBack(PegasusForConditionalGeneration):
    def __init__(self, config):
        super(PegasusGuidedCopyBack, self).__init__(config)
        num_features = config.d_model
        self.p_gen_w = nn.Linear(num_features*3,1)
        self.softmax = nn.Softmax(dim=2)
        self.p_gen_w.bias = nn.Parameter(torch.ones(1))
        self.model_size = num_features
        self.outdegree_score_w = nn.Parameter(torch.ones(1) * 0.5)
        self.indegree_score_w = nn.Parameter(torch.ones(1) * 0.5)
    
    def forward(
        self,
        input_ids,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        out = super(PegasusGuidedCopyBack, self).forward(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        
        # score centrality addition
        queries_matrix = self.model.decoder.layers[-1].encoder_attn.q_proj(out['decoder_hidden_states'][-2])
        keys_matrix = self.model.encoder.layers[-1].self_attn.k_proj(out['encoder_hidden_states'][-2])
        self_attn_graph = torch.sum(out['encoder_attentions'][-1],dim=1) # (batch_size x s_seq_len x s_seq_len)
        
        
        outdegree_score = torch.sum(self_attn_graph,dim=1).unsqueeze(-1) #(batch_size x s_seq_len x 1)
        temp_matrix = self_attn_graph
        transition_matrix = F.normalize(temp_matrix,p=1,dim=2)
        indegree_score = torch.sum(transition_matrix,dim=1).unsqueeze(-1)
        
        temp = keys_matrix + self.outdegree_score_w * outdegree_score + self.indegree_score_w * indegree_score
        attn = self.softmax(queries_matrix@(temp.transpose(2,1))/math.sqrt(self.model_size))

        if labels is not None:
            dec = self.model.get_input_embeddings()(labels) # (batch_size x t_seq_len x model_size)
        else:
            dec = self.model.get_input_embeddings()(decoder_input_ids[:,[-1]])
        
        hi = out['encoder_last_hidden_state'] # (batch_size x s_seq_len x model_size)      
        st = out['decoder_hidden_states'][-1] # (batch_size x t_seq_len x model_size)
        
        p_gen = torch.sigmoid(self.p_gen_w(torch.cat((attn @ hi,st,dec),dim=-1))) # (batch_size x t_seq_len x 1)

        v_dist = p_gen*out['logits'] # (batch_size x t_seq_len x vocab_size)
        
        attn_dists = (1-p_gen)*attn # (batch_size x t_seq_len x s_seq_len)
        
        if input_ids is not None:            
            src_ids = input_ids.unsqueeze(1).repeat(1, attn_dists.size(1), 1)   # (batch_size x 1 x s_seq_len)  
        else:
            src_ids = self.input_ids.unsqueeze(1).repeat(1, attn_dists.size(1), 1)   # (batch_size x 1 x s_seq_len)      
#         print(src_ids.shape)
#         print(v_dist.shape)
#         print(attn_dists.shape)
#         print(attn_dists.size(1))
        total_score = (self.outdegree_score_w * outdegree_score + self.indegree_score_w * indegree_score).squeeze(-1)
        total_score = F.softmax(total_score,dim=-1)
        #total_score = F.softmax(outdegree_score.squeeze(-1),dim=-1)
        
        enc_dec_attn = torch.sum(attn_dists,dim=1)

        kl_divergence = F.kl_div(total_score,enc_dec_attn)
        pred = v_dist.scatter_add(2, src_ids, attn_dists) #(batch_size x t_seq_len x vocab_size)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # TODO(SS): do we need to ignore pad tokens in labels?
            masked_lm_loss = loss_fct(pred.view(-1, self.config.vocab_size), labels.view(-1)) + kl_divergence
        
        if not return_dict:
            output = (pred,) + out[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
        
        return modeling_outputs.Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=pred,
            past_key_values=out.past_key_values,
            decoder_hidden_states=out.decoder_hidden_states,
            decoder_attentions=out.decoder_attentions,
            cross_attentions=out.cross_attentions,
            encoder_last_hidden_state=out.encoder_last_hidden_state,
            encoder_hidden_states=out.encoder_hidden_states,
            encoder_attentions=out.encoder_attentions,
        )
    
    @torch.no_grad()
    def generate(
        self,
        input_ids=None,
        max_length=None,
        min_length=None,
        do_sample=None,
        early_stopping=None,
        num_beams=None,
        temperature=None,
        top_k=None,
        top_p=None,
        repetition_penalty=None,
        bad_words_ids=None,
        bos_token_id=None,
        pad_token_id=None,
        eos_token_id=None,
        length_penalty=None,
        no_repeat_ngram_size=None,
        num_return_sequences=None,
        attention_mask=None,
        decoder_start_token_id=None,
        use_cache=None,
        **model_specific_kwargs
    ):
        self.input_ids = input_ids
        return super(PegasusWithCopyBack, self).generate(
            input_ids=input_ids,
            max_length=max_length,
            min_length=min_length,
            do_sample=do_sample,
            early_stopping=early_stopping,
            num_beams=num_beams,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            bad_words_ids=bad_words_ids,
            bos_token_id=bos_token_id,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            length_penalty=length_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            num_return_sequences=num_return_sequences,
            attention_mask=attention_mask,
            decoder_start_token_id=decoder_start_token_id,
            use_cache=use_cache,
        )


In [5]:
config = PegasusConfig.from_pretrained(model_name, output_hidden_states=True, output_attentions=True)

In [6]:
pega_copyback_model = PegasusGuidedCopyBack.from_pretrained(model_name, config=config).to(torch_device)

Some weights of PegasusGuidedCopyBack were not initialized from the model checkpoint at google/pegasus-gigaword and are newly initialized: ['outdegree_score_w', 'indegree_score_w', 'p_gen_w.weight', 'p_gen_w.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import AdamW, get_linear_schedule_with_warmup

dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)

batch_size = 4
dataloader_train = DataLoader(dataset_train,sampler=RandomSampler(dataset_train), batch_size=batch_size)

#freezing the parameters
# for param in pega_copyback_model.model.parameters():
#     param.requires_grad = False
optimizer = AdamW(filter(lambda p: p.requires_grad, pega_copyback_model.parameters()),lr=5e-3)
                  
epochs = 10

In [9]:
from tqdm.notebook import tqdm
import random

seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)  

for epoch in tqdm(range(1, epochs+1)):
    
    pega_copyback_model.train()
    
    loss_train_total = 0

    progress_bar = tqdm(dataloader_train, desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)
    for b in progress_bar:

        optimizer.zero_grad()
        
        b = tuple(x.to(torch_device) for x in b)
        
        inputs = {'input_ids':      b[0],
                  'attention_mask': b[1],
                  'labels':         b[2],
                 }       

        outputs = pega_copyback_model(**inputs)
#         vocab_size =  outputs['logits'].shape[2]
    
        loss = outputs[0]#criterion(outputs.view(-1,vocab_size),b[2].view(-1))

        loss_train_total += loss.item()
        loss.backward()

#         torch.nn.utils.clip_grad_norm_(pega_copyback_model.parameters(), 1.0)

        optimizer.step()
        #scheduler.step()
        
        progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item()/len(b))})
         
        
    #torch.save(model.state_dict(), f'data_volume/finetuned_BERT_epoch_{epoch}.model')
        
    tqdm.write(f'\nEpoch {epoch}')
    
    loss_train_avg = loss_train_total/len(dataloader_train)            
    tqdm.write(f'Training loss: {loss_train_avg}')

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=1.0, style=ProgressStyle(description_width=…


Epoch 1
Training loss: 0.3993922472000122


HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=1.0, style=ProgressStyle(description_width=…


Epoch 2
Training loss: 0.34161853790283203


HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=1.0, style=ProgressStyle(description_width=…


Epoch 3
Training loss: 0.07979846000671387


HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=1.0, style=ProgressStyle(description_width=…


Epoch 4
Training loss: 0.04203534126281738


HBox(children=(FloatProgress(value=0.0, description='Epoch 5', max=1.0, style=ProgressStyle(description_width=…


Epoch 5
Training loss: 0.03306841850280762


HBox(children=(FloatProgress(value=0.0, description='Epoch 6', max=1.0, style=ProgressStyle(description_width=…


Epoch 6
Training loss: 0.04416036605834961


HBox(children=(FloatProgress(value=0.0, description='Epoch 7', max=1.0, style=ProgressStyle(description_width=…


Epoch 7
Training loss: 0.017304182052612305


HBox(children=(FloatProgress(value=0.0, description='Epoch 8', max=1.0, style=ProgressStyle(description_width=…


Epoch 8
Training loss: 0.008627176284790039


HBox(children=(FloatProgress(value=0.0, description='Epoch 9', max=1.0, style=ProgressStyle(description_width=…


Epoch 9
Training loss: 0.0020666122436523438


HBox(children=(FloatProgress(value=0.0, description='Epoch 10', max=1.0, style=ProgressStyle(description_width…


Epoch 10
Training loss: 0.0044405460357666016

