Why are instructions not masked when performing VSFT for LLaVa? #1880

Open
shijian2001 opened this issue Jul 27, 2024 · 6 comments

Labels: vlm (Related to Visual Language Model)

@shijian2001

I have some questions about the LLavaDataCollator in vsft_llava.py:

https://github.com/huggingface/trl/blob/main/examples/scripts/vsft_llava.py

class LLavaDataCollator:
    def __init__(self, processor):
        self.processor = processor

    def __call__(self, examples):
        texts = []
        images = []
        for example in examples:
            if len(example["images"]) > 1:
                raise ValueError("This collator only supports one image per example")
            messages = example["messages"]
            text = self.processor.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            texts.append(text)
            images.append(example["images"][0])

        batch = self.processor(texts, images, return_tensors="pt", padding=True)

        labels = batch["input_ids"].clone()
        if self.processor.tokenizer.pad_token_id is not None:
            labels[labels == self.processor.tokenizer.pad_token_id] = -100
        batch["labels"] = labels

        return batch

I noticed that you copy the input_ids (the image token and the question concatenated with the answer) to the labels, and then only set the labels of the pad tokens to -100 (so no loss is computed on them). However, as far as I understand SFT, only the loss of the answer part should be calculated, which means we should also set the labels of all the question tokens to -100, shouldn't we?
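
For concreteness, here is a toy sketch of what I mean (the token IDs and positions are made up purely for illustration):

# Fake token IDs for: <s> USER: <image> What is this? ASSISTANT: A cat.</s>
input_ids = [1, 10, 32000, 11, 12, 13, 14, 15, 20, 21, 22, 2]

# Current collator: the labels are a copy of input_ids (only pad tokens become -100),
# so the loss is also computed on the question tokens.
labels_now = list(input_ids)

# What I am asking about: everything up to and including "ASSISTANT:" is masked,
# so only the answer tokens (and the EOS) contribute to the loss.
answer_start = 8  # position right after "ASSISTANT:" in this toy example
labels_masked = [-100] * answer_start + input_ids[answer_start:]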

Looking forward to your reply!

@shijian2001
Author

@qgallouedec When performing SFT on a VLM, it may be a better choice only to calculate the loss of the response part. Does trl provide a direct implementation for this? Can you give an example? Thanks!

qgallouedec self-assigned this Aug 1, 2024
qgallouedec added the vlm (Related to Visual Language Model) label Aug 1, 2024
@shijian2001
Author

@qgallouedec Sorry to bother you, I would like to ask whether SFTTrainer can directly compute the loss on only the response part, and whether you have plans to implement a VSFT script that only computes the response loss. Thank you!

@qgallouedec
Member

qgallouedec commented Aug 5, 2024

Hi, sorry for the delay, I'm addressing the issues in order, and there are a lot these days.

only the loss of the answer part should be calculated

Can you justify this?
In general, loss is calculated over the entire text input, including the prompt and the answer.

When performing SFT on a VLM, it may be a better choice only to calculate the loss of the response part.

I'm not sure about this. Have you tried it? It would be good to have some results to confirm or refute this statement.

Does trl provide a direct implementation for this?

A small modification of the data collator should be enough. Just set labels to -100 for the prompt part.
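
For example, roughly something like this (an untested sketch that reuses the LLavaDataCollator above; it assumes single-turn conversations where the last message is the assistant answer, a chat template that supports add_generation_prompt, and right padding; verify the boundary by decoding the unmasked labels, since SentencePiece tokenizers can shift it by a token or two):

class MaskedPromptLLavaDataCollator(LLavaDataCollator):
    def __call__(self, examples):
        batch = super().__call__(examples)  # pad tokens are already set to -100
        labels = batch["labels"]
        for i, example in enumerate(examples):
            # Rebuild just the prompt (everything before the assistant answer)
            # and count its tokens, then mask that many leading positions.
            prompt_messages = example["messages"][:-1]
            prompt_text = self.processor.tokenizer.apply_chat_template(
                prompt_messages, tokenize=False, add_generation_prompt=True
            )
            prompt_len = len(self.processor.tokenizer(prompt_text)["input_ids"])
            labels[i, :prompt_len] = -100  # assumes right padding
        batch["labels"] = labels
        return batch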

@shijian2001
Author

In the implementation in the LLaVA repository, the padding tokens and the instruction tokens are all set to -100. For reference, see the preprocess_v1 function in https://github.com/haotian-liu/LLaVA/blob/main/llava/train/train.py:

if has_image:
    round_len = len(tokenizer_image_token(rou, tokenizer))
    instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
else:
    round_len = len(tokenizer(rou).input_ids)
    instruction_len = len(tokenizer(parts[0]).input_ids) - 2

if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
    round_len -= 1
    instruction_len -= 1

target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

@shijian2001
Author

I roughly implemented this idea as follows:

import torch

class DataCollator:
    def __init__(self, processor, enable_mask_instructions: bool=False):
        self.processor = processor
        self.processor.tokenizer.model_max_length = 2048
        self.enable_mask_instructions = enable_mask_instructions
        self.IGNORE_INDEX = -100

    def _mask_padding_tokens(self, labels: torch.Tensor):
        """Only mask padding tokens"""
        pad_token_id = self.processor.tokenizer.pad_token_id
        labels[labels == pad_token_id] = self.IGNORE_INDEX
        return labels

    def _prepare_vsft_labels(self, labels: torch.Tensor):
        """Mask instructions and padding tokens"""

        # [Note] EOS token and assistant_token may be different for different chat_templates
        eos_token_id = self.processor.tokenizer.convert_tokens_to_ids("</s>")
        assistant_token_id = self.processor.tokenizer.encode("ASSISTANT:", add_special_tokens=False)

        batch_size, _ = labels.shape
        
        for i in range(batch_size):

            # Get positions of all eos tokens
            eos_positions = (labels[i] == eos_token_id).nonzero(as_tuple=True)[0]
            # Add 0 to eos_positions; Helpful for following loop
            eos_positions = torch.cat([torch.tensor([0], device=labels.device), eos_positions])
            
            # Consider the first special token <s>
            cur_len = 1
            labels[i, :cur_len] = self.IGNORE_INDEX

            for j in range(len(eos_positions) - 1):
                start = eos_positions[j].item()
                end = eos_positions[j + 1].item()
                
                assistant_pos = None
                for k in range(start, end - len(assistant_token_id) + 1):
                    if torch.equal(labels[i, k:k+len(assistant_token_id)], torch.tensor(assistant_token_id, device=labels.device)):
                        assistant_pos = k
                        break
                    
                if assistant_pos is not None:
                    labels[i, cur_len:assistant_pos + len(assistant_token_id)] = self.IGNORE_INDEX
                    cur_len = end + 1
        
        masked_labels = self._mask_padding_tokens(labels)
        
        return masked_labels

    def __call__(self, examples):
        texts = []
        images = []
        for example in examples:
            image = example["images"][0]
            messages = example["messages"]
            text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
            texts.append(text.strip())
            images.append(image)

        batch = self.processor(text=texts, images=images, return_tensors="pt", truncation=True, padding=True)  # truncate to model_max_length

        labels = batch["input_ids"].clone()
        if self.enable_mask_instructions:
            # Mask instructions and padding tokens
            mask_labels = self._prepare_vsft_labels(labels)
        else:
            # Only mask padding tokens
            mask_labels = self._mask_padding_tokens(labels)
            
        batch["labels"] = mask_labels

        return batch
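
A quick way to sanity-check the masking (assuming a processor and a train_dataset with "images" and "messages" columns, as in the script; the names here are just illustrative):

collator = DataCollator(processor, enable_mask_instructions=True)
batch = collator([train_dataset[0], train_dataset[1]])

# Decode only the positions that still contribute to the loss; this should print
# just the assistant answers (plus their EOS tokens), nothing from the prompts.
for row in batch["labels"]:
    kept = row[row != -100]
    print(repr(processor.tokenizer.decode(kept)))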

@qgallouedec
Member

Thanks for the reference and for the piece of code, which can certainly be useful.
My position is to keep the SFT example for VLMs as it is (i.e. not masking the instructions). If at some point we manage to show that, in the general case, instruction masking gives faster convergence or better results, then we'll modify the example along those lines.
Feel free to add to this conversation if you find interesting results.
