In [1]:
import os

# Project root containing the `new_module` package. Override with the
# MUCOCO_ROOT environment variable on other machines
# (previously used: '/home/s3/hyeryung/mucoco').
MUCOCO_ROOT = os.environ.get('MUCOCO_ROOT', '/data/hyeryung/mucoco')
os.chdir(MUCOCO_ROOT)

In [2]:
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM

In [3]:
import torch
import torch.nn.functional as F
from new_module.losses import BaseLoss, register_loss

In [4]:
# Prefer the second GPU; fall back to CPU so the notebook still runs on
# machines that do not have two CUDA devices.
device = 'cuda:1' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else 'cpu'

# GPT2

In [5]:
# Alternative cache dir used previously: '/shared/s3/lab07/hyeryung/hf_cache'
model = AutoModelForCausalLM.from_pretrained(
    'gpt2-large',
    cache_dir='/data/hyeryung/hf_cache',
)

In [6]:
# Move the LM to the target device and switch it to eval mode.
model = model.to(device).eval()

In [7]:
# Load from the same cache directory as the model so both resolve the same files.
tokenizer = AutoTokenizer.from_pretrained('gpt2-large', cache_dir='/data/hyeryung/hf_cache')
# Reuse EOS as the pad id so batch encoding with padding=True has a pad token.
tokenizer.pad_token_id = tokenizer.eos_token_id

In [8]:
# Toy prompt and two toy continuations of different token lengths (exercises padding).
prompt, gens = 'abc', ['dsxe', 'sdvbfe']

In [9]:
num_samples = len(gens)  # one prompt copy per generation
print(num_samples)

2


In [10]:
# Encode the single prompt string (no special tokens) as a (1, prompt_len) batch.
prompt_enc = tokenizer(
    prompt,
    add_special_tokens=False,
    return_tensors="pt",
    padding=True,
    truncation=True,
).to(device)

In [11]:
# Repeat the single prompt row once per generation (expand returns a view, no copy).
prompt_enc['input_ids'] = prompt_enc.input_ids.expand(num_samples, -1)

In [12]:
# Same row repetition for the prompt's attention mask.
prompt_enc['attention_mask'] = prompt_enc.attention_mask.expand(num_samples, -1)

In [13]:
# Batched prompt encoding: the single prompt repeated once per generation.
prompt_enc

{'input_ids': tensor([[39305],
        [39305]], device='cuda:1'), 'attention_mask': tensor([[1],
        [1]], device='cuda:1')}

In [14]:
# Encode all generations as one batch, padded to the longest one.
gens_enc = tokenizer(
    gens,
    add_special_tokens=False,
    return_tensors="pt",
    padding=True,
    truncation=True,
).to(device)

In [15]:
# Generations are right-padded with the pad id (set to EOS above); pad positions have attention_mask 0.
gens_enc

{'input_ids': tensor([[ 9310, 27705, 50256, 50256],
        [21282,    85,    65,  5036]], device='cuda:1'), 'attention_mask': tensor([[1, 1, 0, 0],
        [1, 1, 1, 1]], device='cuda:1')}

In [16]:
# Full sequence per sample: prompt tokens followed by generation tokens.
input_tokens = torch.cat((prompt_enc['input_ids'], gens_enc['input_ids']), dim=1)

In [17]:
# Matching attention mask for the concatenated prompt + generation sequence.
attention_masks = torch.cat((prompt_enc['attention_mask'], gens_enc['attention_mask']), dim=1)

In [18]:
# Forward pass with autograd enabled (the losses below carry a grad_fn).
model_output = model(input_ids=input_tokens, attention_mask=attention_masks)

In [19]:
# Inspect which fields the causal-LM output exposes.
model_output.keys()

odict_keys(['logits', 'past_key_values'])

In [20]:
# Logits at position t predict token t+1. The window therefore starts at the
# last prompt position and drops the final position, which aligns it with gens_enc tokens.
lm_logits = model_output.logits[:, prompt_enc['input_ids'].shape[1] - 1:-1]
lm_logprobs = lm_logits.log_softmax(dim=-1)

In [21]:
# Expect (num_samples, gen_len, vocab_size): one next-token distribution per generated token.
lm_logprobs.shape

torch.Size([2, 4, 50257])

In [22]:
# Must match the first two dims of lm_logprobs for the NLL computation below.
gens_enc.input_ids.shape

torch.Size([2, 4])

In [23]:
# nll_loss expects the class dim second: inputs (N, vocab, L) vs targets (N, L).
loss = F.nll_loss(lm_logprobs.transpose(1, 2), gens_enc['input_ids'], reduction="none")

In [24]:
# Zero the per-token loss at pad positions.
loss = gens_enc['attention_mask'] * loss

In [25]:
# Per-token loss: (num_samples, gen_len).
loss.shape

torch.Size([2, 4])

In [26]:
loss = torch.sum(loss, dim=-1)  # total NLL per generation

In [27]:
# Summed (un-normalized) NLL per generation.
loss

tensor([19.280, 26.730], device='cuda:1', grad_fn=<SumBackward1>)

In [28]:
# Length-normalize: mean NLL over each generation's real (non-pad) tokens.
# (Author's open question: not fully sure this is the right normalization.)
# Out-of-place, so the summed tensor displayed above (Out[27]) is not mutated.
# clamp(min=1) avoids 0/0 = nan for an empty generation.
loss = loss / gens_enc.attention_mask.sum(dim=-1).clamp(min=1)

In [29]:
# Length-normalized NLL per generation.
loss

tensor([9.640, 6.683], device='cuda:1', grad_fn=<DivBackward0>)

In [45]:
from typing import List

class GPT2Loss(BaseLoss):
    """Negative log-likelihood of continuations under a causal LM (GPT-2).

    Args:
        model: causal LM whose first output is the logits tensor.
        tokenizer: tokenizer matching `model`; its pad token is set to EOS here.
        args: config object. `compute_gold_loss` reads `length_normalize`;
            generation reads `max_output_length` (default 10), `AR_temperature`,
            `AR_top_k` and `AR_top_p`.
    """

    def __init__(self, model, tokenizer, args):
        super().__init__()

        self.model = model
        self.tokenizer = tokenizer
        self.args = args
        self.device = model.device

        self.eos_token_id = self.tokenizer.eos_token_id
        # Reuse EOS as the pad token so batched predictions can be padded.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.config.pad_token_id = self.model.config.eos_token_id  # to remove the warning

    def compute_gold_loss(self, prompt: str, predictions: List[str], **kwargs):
        """Score each prediction as a continuation of `prompt` (useful in debugging).

        Args:
            prompt: conditioning text shared by all predictions. If it encodes to
                zero tokens, EOS is used as the context token instead.
            predictions: candidate continuations, scored as one batch.

        Returns:
            Tensor of shape (len(predictions),) on the model device. Each entry is the summed
            token NLL of a prediction, divided by its token count when
            `args.length_normalize` is set. Computed without gradients.
        """
        num_samples = len(predictions)
        if num_samples == 0:
            return torch.empty(0, device=self.device)

        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False, truncation=True)
        if not prompt_ids:
            # With no prompt tokens, the first prediction token would have no
            # preceding position to be predicted from; condition it on EOS instead.
            prompt_ids = [self.eos_token_id]
        # One copy of the prompt per prediction (expand is a view, not a copy).
        prompt_input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=self.device).expand(num_samples, -1)
        prompt_attention_mask = torch.ones_like(prompt_input_ids)
        prompt_len = prompt_input_ids.size(1)

        # Pad positions get attention_mask == 0. The logit alignment below assumes
        # right padding (the GPT-2 tokenizer default).
        predictions_enc = self.tokenizer.batch_encode_plus(predictions, add_special_tokens=False, return_tensors="pt", padding=True, truncation=True).to(self.device)

        input_tokens = torch.cat([prompt_input_ids, predictions_enc.input_ids], dim=1)
        attention_masks = torch.cat([prompt_attention_mask, predictions_enc.attention_mask], dim=1)
        with torch.no_grad():
            model_output = self.model(input_ids=input_tokens,
                                      attention_mask=attention_masks)
        # Logits at position t predict token t+1, so the window starting at the
        # last prompt position lines up with the prediction tokens.
        lm_logits = model_output[0][:, prompt_len - 1:-1, :]
        lm_logprobs = F.log_softmax(lm_logits, dim=-1)

        # input dimensions : (N, C, d1), (N, d1)
        loss = F.nll_loss(lm_logprobs.permute(0, 2, 1), predictions_enc.input_ids, reduction="none")
        loss = loss * predictions_enc.attention_mask  # make losses for pad tokens 0.

        loss = loss.sum(dim=-1)
        if self.args.length_normalize:
            # clamp avoids 0/0 = nan when a prediction encodes to zero tokens
            loss = loss / predictions_enc.attention_mask.sum(dim=-1).clamp(min=1)
        return loss  # dimensions: (N)

    def generate(self, input_ids, **kwargs):
        """Sample continuations of `input_ids` and return only the newly generated tokens."""
        prepared_input = self._prepare_input_for_generation(input_ids, **kwargs)
        output = self.model.generate(**prepared_input)

        return self._postprocess_output(prepared_input, output)

    def _prepare_input_for_generation(self, input_ids, **kwargs):
        """Build the keyword arguments for `model.generate` (sampling).

        No attention mask is passed, so this assumes `input_ids` is unpadded,
        e.g. a batch of size 1. Padded batches would need a mask.
        """
        max_output_length = getattr(self.args, "max_output_length", 10)

        return {
            'input_ids': input_ids,
            'max_length': input_ids.size(1) + max_output_length,
            'do_sample': True,
            'temperature': self.args.AR_temperature,
            'top_k': self.args.AR_top_k,
            'top_p': self.args.AR_top_p,
            'num_return_sequences': kwargs.get('num_return_sequences', 1),
        }

    def _postprocess_output(self, prepared_input, output_ids):
        """Drop the prompt positions from the generated sequences."""
        return output_ids[:, prepared_input['input_ids'].size(1):]

In [46]:
class Args:
    """Minimal config stub; `GPT2Loss.compute_gold_loss` only reads `length_normalize`."""

    length_normalize = False


gpt2_loss = GPT2Loss(model=model, tokenizer=tokenizer, args=Args())

In [47]:
# length_normalize is False, so this should equal the summed losses computed manually above.
gpt2_loss.compute_gold_loss(prompt=prompt, predictions=gens)

tensor([19.280, 26.730], device='cuda:1')

# Classification loss (prediction text only, no prompt prefix)

In [49]:
from transformers import AutoModelForSequenceClassification

In [58]:
# Sequence-classification checkpoint (the path suggests a RoBERTa-base Jigsaw toxicity
# classifier). Note: this rebinds `model`, replacing the GPT-2 model name.
ckpt_path = '/data/hyeryung/loc_edit/models/roberta-base-jigsaw-toxicity-classifier-energy-training/step_1000_best_checkpoint/'
model = AutoModelForSequenceClassification.from_pretrained(ckpt_path).eval().to(device)

In [51]:
# Tokenizer stored with the classifier checkpoint (rebinds `tokenizer`).
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=ckpt_path)

In [60]:
# Same toy inputs as the GPT-2 section; the classifier ignores the prompt.
prompt, prediction = 'abc', ['dsxe', 'sdvbfe']

In [54]:
# EOS token id of the classifier's tokenizer.
eos_token_id = tokenizer.eos_token_id

In [55]:
# Index of the target class whose log-probability is scored.
# NOTE(review): confirm which class index 0 denotes for this checkpoint.
label_id = 0

In [61]:
# Encode under a new name so `prediction` keeps the raw strings and the cell can be re-run.
prediction_enc = tokenizer.batch_encode_plus(prediction, add_special_tokens=True, return_tensors="pt", padding=True, truncation=True).to(device)

model_output = model(**prediction_enc)
lm_logits = model_output[0]  # classifier logits, one row per prediction
lm_logprobs = F.log_softmax(lm_logits, dim=-1)
loss = -lm_logprobs[:, label_id]  # -log p(label_id | prediction)

In [66]:
# -log p(label_id | text) for each prediction.
loss

tensor([0.142, 0.013], device='cuda:1', grad_fn=<NegBackward0>)

In [67]:
class ClassificationLogProbLoss(BaseLoss):
    """Negative log-probability of a target class under a sequence classifier.

    The prompt is not used: only the prediction text is classified.

    Args:
        model: sequence-classification model whose first output is the logits,
            indexed as (batch, label).
        tokenizer: tokenizer matching `model`.
        args: experiment config (stored but not read by this class).
    """

    def __init__(self, model, tokenizer, args):
        super().__init__()

        self.model = model
        self.tokenizer = tokenizer
        self.args = args
        self.device = model.device

        self.bos_token_id = self.tokenizer.bos_token_id
        self.eos_token_id = self.tokenizer.eos_token_id

    def compute_gold_loss(self, prompt: str, prediction: List[str], label_id, **kwargs):
        """Return -log p(label_id | text) for each string in `prediction` (useful in debugging).

        Args:
            prompt: ignored.
            prediction: texts to classify, encoded as one padded batch with
                the tokenizer's special tokens added.
            label_id: index of the class whose log-probability is scored.

        Returns:
            Tensor of shape (len(prediction),). Autograd is not disabled here,
            so the result carries a grad_fn. NOTE(review): wrap the call in
            torch.no_grad() if it is only used for reporting.
        """
        if len(prediction) == 0:
            return torch.empty(0, device=self.device)

        # Separate name so the `prediction` argument keeps its documented list[str] meaning.
        prediction_enc = self.tokenizer.batch_encode_plus(prediction, add_special_tokens=True, return_tensors="pt", padding=True, truncation=True).to(self.device)
        model_output = self.model(**prediction_enc)
        lm_logits = model_output[0]
        lm_logprobs = F.log_softmax(lm_logits, dim=-1)
        loss = -lm_logprobs[:, label_id]
        return loss


In [68]:
# args is never read by ClassificationLogProbLoss, so an empty dict suffices.
clsf_loss = ClassificationLogProbLoss(model=model, tokenizer=tokenizer, args={})

In [69]:
# `gens` holds the same strings as `prediction`, so this should match the manual result above.
clsf_loss.compute_gold_loss(prompt, gens, label_id=0)

tensor([0.142, 0.013], device='cuda:1', grad_fn=<NegBackward0>)

: 