In [3]:
import torch
from torch import nn
from torch.nn import Identity
import torch.nn.functional as F
import transformers
from transformers import top_k_top_p_filtering
from transformers import PegasusForConditionalGeneration, PegasusTokenizer, PegasusPreTrainedModel, PegasusModel, PegasusConfig

In [4]:
class PegasusWithValueHeadModel(PegasusPreTrainedModel):
    """Pegasus seq2seq language model with an additional scalar value head.

    Besides the usual LM head, a linear ``v_head`` maps every decoder hidden
    state to a single scalar, giving one value estimate per decoder position.
    ``forward`` returns a plain tuple ``(lm_logits, model_outputs[1:], value)``.
    """

    # Checkpoint keys that may be absent when loading pretrained weights.
    _keys_to_ignore_on_load_missing = [
        r"final_logits_bias",
        r"encoder\.version",
        r"decoder\.version",
        r"lm_head\.weight",
        r"embed_positions\.weight",
    ]

    def __init__(self, config: PegasusConfig):
        super().__init__(config)
        self.model = PegasusModel(config)
        vocab_size = self.model.shared.num_embeddings
        # Per-token logit offset; added to the LM logits in forward(), matching
        # how PegasusForConditionalGeneration uses this buffer.
        self.register_buffer("final_logits_bias", torch.zeros((1, vocab_size)))
        self.lm_head = nn.Linear(config.d_model, vocab_size, bias=False)

        # Value head: one scalar output per decoder position.
        config.num_labels = 1
        self.v_head = nn.Linear(config.hidden_size, config.num_labels)
        # When True, forward() feeds the value head a detached copy of the
        # hidden states so value-loss gradients do not reach the transformer.
        self.detach_head = False

        self.init_weights()

    def get_output_embeddings(self):
        """Return the output projection (LM head)."""
        return self.lm_head

    def detach_value_head(self):
        """Stop value-head gradients from flowing into the transformer.

        Sets the ``detach_head`` flag that :meth:`forward` checks.
        """
        self.detach_head = True

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        lm_labels=None,
        mc_labels=None,
        decoder_input_ids=None,
    ):
        """Run Pegasus, then the LM head and the value head.

        Args:
            input_ids, past_key_values, attention_mask, head_mask,
            inputs_embeds, decoder_input_ids: passed through to ``PegasusModel``.
            mc_token_ids, lm_labels, mc_labels: accepted for interface
                compatibility but ignored (no loss is computed here).

        Returns:
            Tuple ``(lm_logits, rest, value)``:
            ``lm_logits`` is ``[batch, decoder_len, vocab]``;
            ``rest`` is ``model_outputs[1:]`` (the remaining PegasusModel
            outputs; hidden states are always requested);
            ``value`` is ``[batch, decoder_len]``.
        """
        model_outputs = self.model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=True,
        )
        hidden_states = model_outputs.last_hidden_state

        # Include final_logits_bias exactly as PegasusForConditionalGeneration does;
        # the (1, vocab) buffer broadcasts over batch and decoder positions.
        lm_logits = self.lm_head(hidden_states) + self.final_logits_bias

        # Optionally cut the graph so the value head trains in isolation.
        value_input = hidden_states.detach() if self.detach_head else hidden_states
        value = self.v_head(value_input).squeeze(-1)

        return lm_logits, model_outputs[1:], value
    
def respond_to_batch(model, queries, txt_len=20, top_k=0, top_p=1.0, eos_token_id=None):
    """Autoregressively sample ``txt_len`` decoder tokens from a seq2seq model.

    Args:
        model: callable invoked as ``model(**queries)``; element ``[0]`` of its
            result must be logits indexable as ``[batch, position, vocab]``.
        queries: dict of model inputs. Must contain ``'decoder_input_ids'``
            (a ``[batch, prefix_len]`` integer tensor used as the decoding
            prefix). The caller's dict is NOT modified.
        txt_len: number of tokens to sample.
        top_k: if > 0, sample only from the ``top_k`` most likely tokens.
        top_p: if < 1.0, nucleus-sampling probability mass threshold.
        eos_token_id: optional end-of-sequence id. When given, once a row
            samples it, every later position of that row is filled with it.

    Returns:
        ``[batch, txt_len]`` tensor holding only the newly sampled token ids.
    """
    # Shallow copy so the caller's dict keeps its original decoder prefix.
    queries = dict(queries)
    decoder_input_ids = queries['decoder_input_ids']
    prefix_len = decoder_input_ids.shape[1]
    finished = torch.zeros(decoder_input_ids.shape[0], dtype=torch.bool,
                           device=decoder_input_ids.device)

    # Sampled ids are integers, so no autograd graph is needed while decoding.
    with torch.no_grad():
        for _ in range(txt_len):
            queries['decoder_input_ids'] = decoder_input_ids
            next_token_logits = model(**queries)[0][:, -1, :]
            # Filtering with top_k <= 0 and top_p >= 1.0 keeps the distribution as is.
            if top_k > 0 or top_p < 1.0:
                next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
            # Greedy alternative (no exploration): next_token = next_token_logits.argmax(-1)

            if eos_token_id is not None:
                next_token = next_token.masked_fill(finished, eos_token_id)
                finished = finished | (next_token == eos_token_id)

            decoder_input_ids = torch.cat([decoder_input_ids, next_token.unsqueeze(-1)], dim=-1)

    # Slice by prefix length (not -txt_len) so txt_len == 0 yields an empty result.
    return decoder_input_ids[:, prefix_len:]

In [5]:
# Pegasus paraphrase checkpoint. The value head (v_head) is not part of the
# checkpoint, so from_pretrained initializes it fresh (see the warning below).
model_name = 'tuner007/pegasus_paraphrase'

# Use the first GPU when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    torch_device = 'cuda:0'
else:
    torch_device = 'cpu'

model = PegasusWithValueHeadModel.from_pretrained(model_name)
model = model.to(torch_device)
tokenizer = PegasusTokenizer.from_pretrained(model_name)

Some weights of PegasusWithValueHeadModel were not initialized from the model checkpoint at tuner007/pegasus_paraphrase and are newly initialized: ['v_head.weight', 'v_head.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# def get_response(input_text,num_return_sequences,num_beams):
#     batch = tokenizer([input_text],truncation=True,padding='longest',max_length=60, return_tensors="pt").to(torch_device)
#     translated = model.generate(**batch,max_length=60,num_beams=num_beams, num_return_sequences=num_return_sequences, temperature=1.5)
#     tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
#     return tgt_text

In [6]:
input_text = [
    'A massive glacier had crashed down the mountain.',
    'below are some useful links to facilitate your involvement .',
]

batch_input = tokenizer(input_text, truncation=True, padding='longest', max_length=60, return_tensors="pt").to(torch_device)
# Seed each sequence's decoder with the model's configured start token instead of
# a hard-coded 0, so the prefix follows the checkpoint's config.
batch_input['decoder_input_ids'] = torch.full(
    (batch_input["input_ids"].shape[0], 1),
    model.config.decoder_start_token_id,
    dtype=torch.long,
    device=torch_device,
)
batch_input


{'input_ids': tensor([[  202,  2926, 23682,   196, 14726,   308,   109,  2924,   107,     1,
             0,     0],
        [  487,   127,   181,  1498,  1784,   112,  5186,   128,  5597,   110,
           107,     1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0'), 'decoder_input_ids': tensor([[0],
        [0]], device='cuda:0')}

In [7]:
# tokenizer.encode('big')

In [8]:
# One forward pass over the prepared batch. The value-head model returns a
# 3-tuple: (LM logits, remaining Pegasus outputs, value estimates).
(logits, transformer_outputs, values) = model(**batch_input)

In [9]:
# Shapes of the LM logits and the per-position value estimates.
(logits.shape, values.shape)

(torch.Size([2, 1, 96103]), torch.Size([2, 1]))

In [10]:
# output.decoder_hidden_states[-1].shape, output.logits.shape

In [11]:
# Highest-scoring token id at each decoder position.
logits.argmax(dim=-1)

tensor([[ 202],
        [5870]], device='cuda:0')

In [97]:
# Turn the highest-scoring token ids back into text, one string per sequence.
tokenizer.batch_decode(logits.argmax(dim=-1))

['A', 'A', 'A']

In [101]:
# Inspect the decoder prefix as text, keeping special tokens.
prefix_ids = batch_input['decoder_input_ids']
tokenizer.batch_decode(prefix_ids, skip_special_tokens=False)

['', '', '']

In [17]:
# Sample a continuation for each input (a dict copy is passed so batch_input
# itself is not extended), then show it token by token.
resp = respond_to_batch(model, dict(batch_input))
print(resp.shape)

for row in resp:
    print([tokenizer.decode(token_id) for token_id in row])

torch.Size([2, 20])
['A', 'glacier', 'crashed', 'down', 'the', 'mountain', '.', '</s>', '</s>', 'Merrick', 'persevered', 'disastrous', 'ly', '.', '</s>', '</s>', '</s>', '</s>', '</s>', '</s>']
['links', 'that', 'are', 'helpful', 'can', 'be', 'found', 'below', '.', '</s>', '</s>', '</s>', '</s>', '</s>', '</s>', '</s>', 'chette', '</s>', '</s>', '</s>']


In [404]:
# Decode the first sampled response as a single string.
tokenizer.decode(resp[0].tolist())


'A glacier crashed down the mountain.'

In [14]:
# NOTE(review): respond_to_batch is also defined in cell In [4]; this later
# definition shadows it once this cell runs. Consider keeping only one.
def respond_to_batch(model, queries, txt_len=20, top_k=0, top_p=1.0, eos_token_id=None):
    """Autoregressively sample ``txt_len`` decoder tokens from a seq2seq model.

    Args:
        model: callable invoked as ``model(**queries)``; element ``[0]`` of its
            result must be logits indexable as ``[batch, position, vocab]``.
        queries: dict of model inputs. Must contain ``'decoder_input_ids'``
            (a ``[batch, prefix_len]`` integer tensor used as the decoding
            prefix). The caller's dict is NOT modified.
        txt_len: number of tokens to sample.
        top_k: if > 0, sample only from the ``top_k`` most likely tokens.
        top_p: if < 1.0, nucleus-sampling probability mass threshold.
        eos_token_id: optional end-of-sequence id. When given, once a row
            samples it, every later position of that row is filled with it.

    Returns:
        ``[batch, txt_len]`` tensor holding only the newly sampled token ids.
    """
    # Shallow copy so the caller's dict keeps its original decoder prefix.
    queries = dict(queries)
    decoder_input_ids = queries['decoder_input_ids']
    prefix_len = decoder_input_ids.shape[1]
    finished = torch.zeros(decoder_input_ids.shape[0], dtype=torch.bool,
                           device=decoder_input_ids.device)

    # Sampled ids are integers, so no autograd graph is needed while decoding.
    with torch.no_grad():
        for _ in range(txt_len):
            queries['decoder_input_ids'] = decoder_input_ids
            next_token_logits = model(**queries)[0][:, -1, :]
            # Filtering with top_k <= 0 and top_p >= 1.0 keeps the distribution as is.
            if top_k > 0 or top_p < 1.0:
                next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
            # Greedy alternative (no exploration): next_token = next_token_logits.argmax(-1)

            if eos_token_id is not None:
                next_token = next_token.masked_fill(finished, eos_token_id)
                finished = finished | (next_token == eos_token_id)

            decoder_input_ids = torch.cat([decoder_input_ids, next_token.unsqueeze(-1)], dim=-1)

    # Slice by prefix length (not -txt_len) so txt_len == 0 yields an empty result.
    return decoder_input_ids[:, prefix_len:]

In [78]:
# Token-by-token view of the first sequence produced by the built-in generate().
[tokenizer.decode(token_id) for token_id in model.generate(**batch_input)[0]]

['<pad>', 'The', 'glacier', 'crashed', 'down', 'the', 'mountain', '.', '</s>']

In [69]:
# Raw token ids from the built-in generate() on the same batch.
generate_kwargs = batch_input
model.generate(**generate_kwargs)

tensor([[    0,   139, 23682, 14726,   308,   109,  2924,   107,     1]],
       device='cuda:2')

In [412]:
# Scratch check: in-place add of ones into the last column of a (4, 5) zero tensor.
zz = torch.zeros(4, 5)
zz[:, -1] += torch.ones(4)
# Trailing expression so the cell actually displays the tensor shown below
# (an augmented assignment produces no output on its own).
zz

tensor([[0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.]])