In [1]:
from pathlib import Path
from typing import Union, List

import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer, modeling_utils, GPT2PreTrainedModel, T5ForConditionalGeneration, T5TokenizerFast
from generation.gpt2_generation import GPT2Generation

from utils import utils
from utils.generation_utils import top_k_top_p_filtering

In [2]:
MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop

class DExpertsGeneration: 
    STOP_TOKEN = "</s>"

    def __init__(
        self, 
        base_model: Union[str, Path, T5ForConditionalGeneration],
        antiexpert_model: Union[str, Path, T5ForConditionalGeneration] = None,
        expert_model: Union[str, Path, T5ForConditionalGeneration] = None,
        tokenizer: str = 'gpt2', 
        seed: int = 42,
        expert_prefix: str = None,
        antiexpert_prefix: str = None
        ):
        # Set up device
        self.expert_prefix = expert_prefix
        self.antiexpert_prefix = antiexpert_prefix
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        n_gpu = torch.cuda.device_count()
        utils.set_seed(seed, n_gpu)

        self.base_model = T5ForConditionalGeneration.from_pretrained(base_model).to(self.device)
        
        if antiexpert_model:
            self.antiexpert = T5ForConditionalGeneration.from_pretrained(antiexpert_model).to(self.device)
        else:
            self.antiexpert = None
        
        if expert_model:
            self.expert = T5ForConditionalGeneration.from_pretrained(expert_model).to(self.device)
        else:
            self.expert = None
        
        self.tokenizer = T5TokenizerFast.from_pretrained(base_model)
        self.tokenizer.pad_token = self.STOP_TOKEN
        assert self.tokenizer.eos_token_id == self.tokenizer.pad_token_id

    def __repr__(self):
        return f'<DExpertsGenerator model_name_or_path="{self.model}">'

    def generate(self,
                 prompt: Union[str, List[str]],
                 max_len: int = 20,
                 sample: bool = True,
                 filter_p: float = 0.9,
                 k: int = 0,
                 p: float = 1.0,
                 temperature: float = 1.0,
                 alpha: float = 0.0
                ):
        if isinstance(prompt, str):
            source = [prompt]
        else:
            source = prompt
            
        source_base = ["paraphrase: " + x for x in source]
        source_expert = [self.expert_prefix + x for x in source]
        source_antiexpert = [self.antiexpert_prefix + x for x in source]
        
        target = []
        for x in source:
            target.append("<pad>")
            
        encodings_dict_base = self.tokenizer.batch_encode_plus(source_base, pad_to_max_length=True, return_tensors='pt')
        input_ids_base = encodings_dict_base['input_ids'].to(self.device)
        attention_mask_base = encodings_dict_base['attention_mask'].to(self.device)
        
        encodings_dict_expert = self.tokenizer.batch_encode_plus(source_expert, pad_to_max_length=True, return_tensors='pt')
        input_ids_exper = encodings_dict_expert['input_ids'].to(self.device)
        attention_mask_exper = encodings_dict_expert['attention_mask'].to(self.device)
        
        
        encodings_dict_anti = self.tokenizer.batch_encode_plus(source_antiexpert, pad_to_max_length=True, return_tensors='pt')
        input_ids_anti = encodings_dict_anti['input_ids'].to(self.device)
        attention_mask_anti = encodings_dict_anti['attention_mask'].to(self.device)
        
        decoder_dict = self.tokenizer.batch_encode_plus(target, return_tensors='pt')
        decoder_input_ids = decoder_dict['input_ids'].to(self.device)
        decoder_attention_mask = decoder_dict['attention_mask'].to(self.device)
        
        batch_size, input_seq_len = input_ids_exper.shape

#         position_ids = attention_mask.cumsum(dim=1) - 1
        unfinished_sents = torch.ones(batch_size, dtype=torch.long, device=self.device)

        self.base_model.eval()
        if self.expert:
            self.expert.eval()
        if self.antiexpert:
            self.antiexpert.eval()
        with torch.no_grad():
            for step in range(max_len):
                # base model prediction
                base_logits = self.base_model(input_ids_base, attention_mask = attention_mask_base, 
                                              decoder_input_ids = decoder_input_ids, decoder_attention_mask = decoder_attention_mask)["logits"]
                
                # expert prediction
                if self.expert:
                    expert_logits = self.expert(input_ids_exper, attention_mask = attention_mask_exper, 
                                               decoder_input_ids = decoder_input_ids, decoder_attention_mask = decoder_attention_mask)["logits"]
                else:
                    expert_logits = base_logits
                
                # antiexpert prediction
                if self.antiexpert:
                    antiexpert_logits = self.antiexpert(input_ids_anti, attention_mask = attention_mask_anti, 
                                                       decoder_input_ids = decoder_input_ids, decoder_attention_mask = decoder_attention_mask)["logits"]
                else:
                    antiexpert_logits = base_logits
                
                if filter_p < 1.0:
                    base_logits = top_k_top_p_filtering(base_logits, top_p=filter_p)
                
                # DExperts
                alpha = torch.tensor(alpha).to(self.device)
                ensemble_logits = base_logits + alpha * (expert_logits - antiexpert_logits)

                # in the first decoding step, we want to use the 'real' last position for each sentence
#                 if step == 0:
#                     last_non_masked_idx = torch.sum(attention_mask, dim=1) - 1
#                     next_token_logits = ensemble_logits[range(batch_size), last_non_masked_idx, :]
#                 else:
                next_token_logits = ensemble_logits[:, -1, :]

                if sample:
                    # Temperature (higher temperature => more likely to sample low probability tokens)
                    if temperature != 1.0:
                        next_token_logits = next_token_logits / temperature
                    if k > 0 or p < 1.0:
                        next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=k, top_p=p)
                    # Sample
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                else:
                    # Greedy decoding
                    next_tokens = torch.argmax(next_token_logits, dim=-1)
                # either append a padding token here if <EOS> has been seen or append next token
                tokens_to_add = next_tokens * unfinished_sents + self.tokenizer.pad_token_id * (1 - unfinished_sents)
                # this updates which sentences have not seen an EOS token so far
                # if one EOS token was seen the sentence is finished
                eos_in_sents = tokens_to_add == self.tokenizer.eos_token_id
                unfinished_sents.mul_((~eos_in_sents).long())

                # stop when there is an EOS in each sentence
                if unfinished_sents.max() == 0:
                    break
                
#                 if step == 0:
#                     decoder_input_ids = tokens_to_add.unsqueeze(-1)
                    
#                 else:
                # Update input_ids, attention_mask and position_ids
                decoder_input_ids = torch.cat([decoder_input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
#                     print(decoder_input_ids)
                decoder_attention_mask = torch.cat([decoder_attention_mask, decoder_attention_mask.new_ones((batch_size, 1))], dim=1)
#                     position_ids = torch.cat([position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=1)

        decoded_outputs = [self.tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                           for output in decoder_input_ids[:, :]]
        return decoded_outputs

In [7]:
generator = DExpertsGeneration(
        base_model="../st5-para/", 
        expert_model="../distributed/st5_mul_joy/checkpoint-93330/",
        antiexpert_model="../distributed/st5_mul_sad/checkpoint-101983/",
        expert_prefix = "joy_CLS: ",
        antiexpert_prefix = "sad_CLS: "
    )

In [12]:
generator.generate(
         ["And just like that , I don't care who wins the tournament .", 
          "Some people still need to learn to grow up and mature . . This isn't high school anymore",
          "It's kinda sad when you see an 9 year-old drop the f-bomb in public and the parents laugh along with them.",
          "My phone is dying yet it's my entertainment for the next 4 hours", "why did secret life have to end like that noo ) :"],
         max_len = 128,
         sample = True,
         filter_p = 0.9,
         k = 0,
         p = 0.9,
         temperature = 1.0,
         alpha = 3.0
                )



['Who won?',
 "it's no longer high school for some people to grow up, not a kind, ladaptive behaviour.",
 'When a 9-year-old becomes a gunman and their parents laugh together at the same time.',
 "I'm gonna die of my phone, but the energy will be safe for all four hours.",
 'Why do some secrets have to end so quickly?']

In [17]:
tokener = T5TokenizerFast.from_pretrained("../st5-para/")

In [18]:
tokener.pad_token = "</s>"

In [19]:
print(tokener.eos_token_id)

1


In [61]:
base_model = T5ForConditionalGeneration.from_pretrained("../distributed/st5_mul_joy/checkpoint-93330/", pad_token_id=tokener.eos_token_id)


In [111]:
encodings_dict = tokener.batch_encode_plus(["joy: I hate my day."], pad_to_max_length=True, return_tensors='pt')
decoder_dict = tokener.batch_encode_plus(["<pad>I absolutely"], add_special_tokens=False, return_tensors='pt')


In [112]:
input_ids = encodings_dict['input_ids']
attention_mask = encodings_dict['attention_mask']
batch_size, input_seq_len = input_ids.shape

In [113]:
decoder_input_ids = decoder_dict['input_ids']
decoder_attention_mask = decoder_dict['attention_mask']

In [114]:
print(tokener.decode(decoder_input_ids[0].tolist()))
print(decoder_input_ids)

<pad> I absolutely
tensor([[   0,    7, 2776]])


In [107]:
tokener.decode([7])

'I'

In [115]:
base_logits = base_model(input_ids, attention_mask= attention_mask,
                         decoder_input_ids = decoder_input_ids, decoder_attention_mask = decoder_attention_mask)["logits"]

In [109]:
base_logits.shape

torch.Size([1, 2, 110080])

In [116]:

next_token_logits = base_logits[:, -1, :]
temperature = 1
p = 0.9
k = 0
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
    next_token_logits = next_token_logits / temperature
if k > 0 or p < 1.0:
    next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=k, top_p=p)
    # Sample
    probs = F.softmax(next_token_logits, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

In [117]:
tokener.decode(next_tokens.data)

'hate'

In [110]:

probabilities, predicted = torch.max(base_logits[0].cpu().data, 1)
tokener.decode(predicted.data.tolist())

'I absolutely'

In [50]:
predicted

tensor([[0, 1, 1,  ..., 0, 1, 1]])