In [10]:
import os
os.chdir('/home/s3/hyeryung/mucoco')

In [1]:
from transformers import AutoModel, AutoTokenizer

In [64]:
import torch
import torch.nn.functional as F
from new_module.losses import BaseLoss, register_loss


In [37]:
device='cuda'

In [14]:
model=AutoModel.from_pretrained('gpt2-large',cache_dir='/shared/s3/lab07/hyeryung/hf_cache')

In [61]:
model=model.to(device)

In [2]:
tokenizer=AutoTokenizer.from_pretrained('gpt2-large')

In [21]:
prompt='abc'
gens=['dsxe','sdvbfe']

In [40]:
num_samples=2

In [46]:
prompt_enc=tokenizer.encode_plus(prompt,add_special_tokens=False, return_tensors="pt", padding=True, truncation=True).to(device)

In [47]:
prompt_enc['input_ids']=prompt_enc['input_ids'].expand(num_samples,-1)

In [48]:
prompt_enc['attention_mask']=prompt_enc['attention_mask'].expand(num_samples,-1)

In [59]:
prompt_enc

{'input_ids': tensor([[39305],
        [39305]], device='cuda:0'), 'attention_mask': tensor([[1],
        [1]], device='cuda:0')}

In [49]:
gens_enc=tokenizer.batch_encode_plus(gens, add_special_tokens=False, return_tensors="pt", padding=True, truncation=True).to(device)

In [55]:
gens_enc

{'input_ids': tensor([[ 9310, 27705, 50256, 50256],
        [21282,    85,    65,  5036]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 0, 0],
        [1, 1, 1, 1]], device='cuda:0')}

In [53]:
input_tokens = torch.cat([prompt_enc.input_ids, gens_enc.input_ids], dim=1)

In [56]:
attention_masks = torch.cat([prompt_enc.attention_mask, gens_enc.attention_mask], dim=1)

In [62]:
model_output = model(input_ids=input_tokens,
                    attention_mask=attention_masks)

In [65]:
lm_logits = model_output[0][:, prompt_enc.input_ids.size(1)-1:-1, :]
lm_logprobs = F.log_softmax(lm_logits, dim=-1)

In [67]:
lm_logprobs.shape

torch.Size([2, 4, 1280])

In [68]:
gens_enc.input_ids.shape

torch.Size([2, 4])

In [70]:
loss = F.nll_loss(lm_logprobs.permute(0,2,1), gens_enc.input_ids, reduction="none")

In [72]:
loss = loss * gens_enc.attention_mask

In [74]:
loss = loss.sum()

In [76]:
loss /= gens_enc.attention_mask.sum() ## 이렇게 하는게 맞을지 조금 고민이다.

In [77]:
loss

tensor(2.569, device='cuda:0', grad_fn=<DivBackward0>)

In [75]:
loss

tensor(15.414, device='cuda:0', grad_fn=<SumBackward0>)

In [18]:
tokenizer.encode_plus(prompt, add_special_tokens=False, return_tensors="pt", padding=True, truncation=True).long()

SyntaxError: invalid non-printable character U+0008 (4036987649.py, line 1)

In [6]:
tokenizer.pad_token_id=tokenizer.eos_token_id

In [7]:
tokenizer.batch_encode_plus(['a','b'],truncation=True, padding=True,return_tensors="pt")

{'input_ids': tensor([[64],
        [65]]), 'attention_mask': tensor([[1],
        [1]])}

In [16]:
tokenizer.batch_encode_plus(['a','bcd'],truncation=True, padding=False,return_tensors="pt")

ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`input_ids` in this case) have excessive nesting (inputs type `list` where type `int` is expected).

In [12]:
class GPT2Loss(BaseLoss):

    def __init__(self, model, tokenizer, args):
        super().__init__() 

        self.model = model
        self.tokenizer = tokenizer 
        self.args = args
        self.device = model.device
        
        self.eos_token_id = self.tokenizer.eos_token_id    
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.config.pad_token_id = self.model.config.eos_token_id # to remove the warning
    
    def compute_gold_loss(self, prompt, prediction, **kwargs):
        '''
        given a discrete target output, this will compute the loss wrt to it. Useful in debugging
        '''
        prompt = self.tokenizer.batch_encode_plus(prompt, add_special_tokens=False, return_tensors="pt", padding=True, truncation=True).to(self.device).long()
        # assuming batch size of 1 (prediction is a string instance.)
        prediction = self.tokenizer.batch_encode_plus(prediction, add_special_tokens=False, return_tensors="pt", padding=True, truncation=True).to(self.device).long()
        input_tokens = torch.cat([prompt.input_ids, prediction.input_ids], dim=1)
        model_output = self.model(input_tokens)

        lm_logits = model_output[0][:, prompt.size(1)-1:-1, :]
        lm_logprobs = F.log_softmax(lm_logits, dim=-1)

        loss = F.nll_loss(lm_logprobs.squeeze(0), prediction.squeeze(0), reduction="none").sum(dim=-1)
        
        if self.args.length_normalize:
            loss /= lm_logprobs.size(1)

        return loss
    
    def generate(self, input_ids, **kwargs):
        prepared_input = self._prepare_input_for_generation(input_ids, **kwargs)
        output = self.model.generate(**prepared_input)
        
        return self._postprocess_output(prepared_input, output)

    def _prepare_input_for_generation(self, input_ids, **kwargs):
        max_output_length = getattr(self.args, "max_output_length", 10)
        batch_size = input_ids.size(0)
        #batch size is 1, padding and stuff needs to be modified for this to work for larger batches

        return_object = {'input_ids': input_ids,
                'max_length': input_ids.size(1) + max_output_length,
                'do_sample': True,
                'temperature': self.args.AR_temperature,
                'top_k': self.args.AR_top_k,
                'top_p': self.args.AR_top_p,
                'num_return_sequences': kwargs.get('num_return_sequences', 1)}
   
        return return_object
    
    def _postprocess_output(self, prepared_input, output_ids):
        return output_ids[:, prepared_input['input_ids'].size(1):, ]

In [15]:
gpt2_loss = GPT2Loss(model,tokenizer,{})