Introducing DataCollatorForChatCompletionOnlyLM #456
younesbelkada merged 12 commits into huggingface:main
Conversation
The documentation is not available anymore as the PR was closed or merged.
younesbelkada
left a comment
Thank you so much for your great work and adding this new data collator!
This looks great, can you just add a simple test to make sure we won't break that collator in the future?
Check out this very simple test: https://github.com/lvwerra/trl/blob/main/tests/test_sft_trainer.py#L410-L424
Also, can you run the styling checks?
`make style && make quality` (cc @lvwerra)
We don't need the check we added here: https://github.com/lvwerra/trl/blob/33f88ead0b31dbfe7d190756cdc8f5fc63b04363/trl/trainer/sft_trainer.py#L133, as in theory this new collator would work for packed datasets as well.
@younesbelkada thanks for the feedback. I added what you asked for. If you need something else, don't hesitate to ask.
younesbelkada
left a comment
Thank you very much for adding the nice test!
I have one comment. Also, after discussing with @lvwerra, we thought it might make sense to unify DataCollatorForCompletionOnlyLM and DataCollatorForChatCompletionOnlyLM into a single data collator. Do you want to give it a try? Otherwise we can merge this PR and I can take care of that right after. Let me know what you think.
```python
labels = batch["labels"]
non_masked_tokens = batch["input_ids"][labels != -100]
result_text = self.tokenizer.decode(non_masked_tokens)
self.assertTrue(result_text == " I have not been masked correctly.### I have not been masked correctly too.")
```
It is a bit surprising that there is `###` in the non-masked part. I think adding a space or a line break between the dots and `###` would solve the issue. Can you try that out and see if it works? 🙏
The problem with this is that if we do not train the model to output the human response token (either an added special token or the first token of it), the model will just spit out text until max_length. We need a way to stop the model from generating text once it reaches a certain token (I have chosen the human response tokens but it is also possible to add an eos token at the end of each assistant response). What do you think? (sorry if I was unclear)
The idea is to help the model know that this is the end of the assistant response.
This seems like a fundamental problem and I see both of your points. Ideally, every turn ends with an end-of-turn token. I believe that the Chat UI does this: it adds a </s> token at the end of each turn. Llama 2 uses some special tokens to indicate the end/start of user queries vs. system queries (https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L44-L45). So the question here becomes how we can accommodate that, I think...
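A minimal sketch of the end-of-turn idea discussed here (the template strings and helper name are illustrative, not this PR's API): append the tokenizer's EOS token after every assistant turn when building the training text, so the model learns an explicit stopping point.

```python
# Illustrative sketch only: append an end-of-turn token after each assistant reply
# when formatting the conversation, so generation has a natural stopping point.
def format_conversation(turns, tokenizer):
    parts = []
    for role, content in turns:
        if role == "human":
            parts.append(f"### Human: {content}")
        else:
            parts.append(f"### Assistant: {content}{tokenizer.eos_token}")
    return "".join(parts)
```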
@younesbelkada Thanks for the suggestions, gonna give it a try this weekend!
```python
if idx != 0:
    labels[i, start+1:end] = -100
```
```python
for assistant_idx in np.where(batch["labels"][i] == assistant_token_ids[0])[0]:
    # find the indexes of the start of an assistant answer.
    if assistant_token_ids == examples[i]["input_ids"][assistant_idx : assistant_idx + len(assistant_token_ids)]:
        assistant_token_ids_idxs.append(assistant_idx + len(assistant_token_ids))
```
Shouldn't this just be append(assistant_idx) (so that the <|assistant|> part is also ignored)?
```python
class DataCollatorForChatCompletionOnlyLM(DataCollatorForLanguageModeling):
    def __init__(self, human_template, assistant_template, *args, **kwargs):
```
Would it make sense to add a docstring explaining that this data collator specifically masks labels for the prompt?
Can I help with this one?
Hello, sorry, I have been pretty busy the last 3 weeks. I will work on this tomorrow.
Thanks so much @gaetanlop, let us know if you need any help! This collator would be really helpful for the community.
Hello @younesbelkada, I merged the two collators and removed the ### symbol from the completion as requested. The new collator works for both the Alpaca and OASST datasets. It should not be used with packing. What do you think?
```python
assistant_token_ids_idxs.append(assistant_idx + len(assistant_token_ids))

if len(assistant_token_ids) == 0:
    raise RuntimeError(
```
A `ValueError` seems better suited for missing expected values.
```python
response_token_ids_end_idx = response_token_ids_start_idx + len(response_token_ids)
if self.instruction_template is not None:
    human_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
```
Maybe add this in `__init__` so that it only has to happen once and not on every call.
Yes, you can add `self.human_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)` in the init, good point.
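Something along these lines is what the suggestion amounts to (class and attribute names are hypothetical, not the PR's final code): encode the fixed templates once at construction time instead of on every call.

```python
from typing import Optional

from transformers import DataCollatorForLanguageModeling


class CompletionOnlyCollatorSketch(DataCollatorForLanguageModeling):
    """Sketch only: cache the template token IDs in __init__ rather than re-encoding them in torch_call."""

    def __init__(
        self,
        response_template: str,
        instruction_template: Optional[str] = None,
        *args,
        mlm: bool = False,
        ignore_index: int = -100,
        **kwargs,
    ):
        super().__init__(*args, mlm=mlm, **kwargs)
        self.ignore_index = ignore_index
        # The templates never change after construction, so their token IDs can be cached here.
        self.response_token_ids = self.tokenizer.encode(response_template, add_special_tokens=False)
        self.human_token_ids = (
            self.tokenizer.encode(instruction_template, add_special_tokens=False)
            if instruction_template is not None
            else None
        )
```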
```python
response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)
assistant_token_ids = self.tokenizer.encode(self.assistant_template, add_special_tokens=False)

labels = batch["labels"].clone()
```
To avoid needlessly increasing memory usage, all the operations on labels can be done directly on `batch["labels"]` so that cloning is not necessary.
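A toy illustration of the point (values are made up): the labels tensor returned by the parent collator can be masked in place, so an extra `clone()` copy is not needed.

```python
import torch

# Mask labels in place instead of cloning first.
batch = {"labels": torch.tensor([[5, 6, 7, 8]])}
ignore_index = -100
batch["labels"][0, :2] = ignore_index  # operate directly on batch["labels"]
print(batch["labels"])  # tensor([[-100, -100,    7,    8]])
```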
```python
# Make pytorch loss function ignore all tokens up through the end of the response key
labels[i, :response_token_ids_end_idx] = self.ignore_index
if len(assistant_token_ids_idxs) < len(human_token_ids_idxs):
```
Thanks a lot for your great work on this!
I really like the fact that we merged both collators! My comment is about backward compatibility - I think we should still keep the behaviour of the previous collator to avoid unexpected behaviour for users that already rely on it. What I suggest is:
1- rename assistant_template to response_template
2- make instruction_template optional (defaults to None)
If instruction_template is None --> behave as the old data collator; if both are not None --> behave as the new collator (sketched below).
Also, make sure to run `make precommit` before pushing!
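A minimal usage sketch of the proposed backward-compatible API (template strings and model name are illustrative, not part of this PR):

```python
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Old behaviour: mask everything up to (and including) the response template.
collator_completion = DataCollatorForCompletionOnlyLM(
    response_template="### Response:\n", tokenizer=tokenizer, mlm=False
)

# New behaviour: with an instruction template, every human turn is masked as well.
collator_chat = DataCollatorForCompletionOnlyLM(
    response_template="### Assistant:",
    instruction_template="### Human:",
    tokenizer=tokenizer,
    mlm=False,
)
```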
Thanks @younesbelkada and @BramVanroy. I made the required changes. For some of @BramVanroy's comments, such as computing the response token IDs in the init, I didn't make the changes, in order to follow what was done in the initial DataCollatorForCompletionOnlyLM, but I also believe that would be better.
```python
text = """### Human: Hello all this should be masked.### Assistant: I should not be masked.### Human: All this should be masked too.### Assistant: I should not be masked too."""
encoded_text = self.tokenizer(text)
encoded_text["input_ids"] = encoded_text["input_ids"]
```
This can be removed I think
Thank you for your work @gaetanlop! I left some comments where it might be useful if @younesbelkada has a look to see whether my comments are worth including or not. I tried the collator but it is failing for me. Full example below, copy-pasted from your tests with the falcon tokenizer.

```python
import os
import random
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import IterableDataset
from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizerBase, TrainerCallback, AutoTokenizer


class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
    """
    Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index'
    when they do not come from the assistant. This ensure that the loss is only
    calculated on the completion made by the assistant.

    Args:
        instruction_template (`Optional[str]`): the template form that indicates the start of the human instruction, typically something like
            '### Human:\n'. Useful for assistant-style conversation datasets
        response_template (`str`): the template form that indicates the start of the response, typically something like
            '### Response:\n'
        mlm (`bool`, *optional*, defaults to `False`): Whether or not to use masked language modeling in the underlying
            `DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present
            for flexibility and backwards-compatibility.
        ignore_index (`int`, *optional*, defaults to `-100`):
            The index to use to ignore the initial tokens with
    """

    def __init__(
        self,
        response_template: str,
        instruction_template: Optional[str] = None,
        *args,
        mlm: bool = False,
        ignore_index: int = -100,
        **kwargs,
    ):
        super().__init__(*args, mlm=mlm, **kwargs)
        self.instruction_template = instruction_template
        self.response_template = response_template
        self.ignore_index = ignore_index

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        batch = super().torch_call(examples)

        response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)
        labels = batch["labels"].clone()

        if self.instruction_template is None:
            for i in range(len(examples)):
                response_token_ids_start_idx = None
                for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
                    # `response_token_ids` is `'### Response:\n'`, here we are just making sure that the token IDs match
                    if response_token_ids == examples[i]["input_ids"][idx : idx + len(response_token_ids)]:
                        response_token_ids_start_idx = idx
                if response_token_ids_start_idx is None:
                    raise RuntimeError(
                        f'Could not find response key {response_token_ids} in token IDs {batch["labels"][i]}'
                    )
                response_token_ids_end_idx = response_token_ids_start_idx + len(response_token_ids)

                # Make pytorch loss function ignore all tokens up through the end of the response key
                labels[i, :response_token_ids_end_idx] = self.ignore_index
        else:
            for i in range(len(examples)):
                response_token_ids_idxs = []
                human_token_ids_idxs = []

                for assistant_idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
                    # find the indexes of the start of a response.
                    if (
                        response_token_ids
                        == examples[i]["input_ids"][assistant_idx : assistant_idx + len(response_token_ids)]
                    ):
                        response_token_ids_idxs.append(assistant_idx + len(response_token_ids))

                if len(response_token_ids) == 0:
                    raise RuntimeError(
                        f'Could not find response key {response_token_ids} in token IDs {batch["labels"][i]}'
                    )

                human_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
                for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]:
                    # find the indexes of the start of a human answer.
                    if human_token_ids == examples[i]["input_ids"][human_idx : human_idx + len(human_token_ids)]:
                        human_token_ids_idxs.append(human_idx)

                if len(human_token_ids_idxs) == 0:
                    raise RuntimeError(
                        f'Could not find response key {human_token_ids} in token IDs {batch["labels"][i]}'
                    )

                for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
                    # Make pytorch loss function ignore all non response tokens
                    if idx != 0:
                        labels[i, start:end] = self.ignore_index
                    else:
                        labels[i, :end] = self.ignore_index

                if len(response_token_ids_idxs) < len(human_token_ids_idxs):
                    labels[i, human_token_ids_idxs[-1] :] = self.ignore_index

        batch["labels"] = labels

        return batch


if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-40b", trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    instruction_template = "### Human:"
    assistant_template = "### Assistant:"
    data_collator = DataCollatorForCompletionOnlyLM(
        instruction_template, assistant_template, tokenizer=tokenizer, mlm=False
    )

    text = """### Human: Hello all this should be masked.### Assistant: I should not be masked.### Human: All this should be masked too.### Assistant: I should not be masked too."""
    encoded_text = tokenizer(text)
    encoded_text["input_ids"] = encoded_text["input_ids"]
    examples = [encoded_text]

    batch = data_collator(examples)
    labels = batch["labels"]
    non_masked_tokens = batch["input_ids"][labels != -100]
    result_text = tokenizer.decode(non_masked_tokens)
    assert result_text == " I should not be masked. I should not be masked too."
```
```diff
@@ -423,6 +423,25 @@ def test_data_collator_completion_lm(self):
         result_text = self.tokenizer.decode(batch["input_ids"][0, last_pad_idx + 1 :])
         self.assertTrue(result_text == "I have not been masked correctly.")
```
`assertEqual` is better here for useful test output. Otherwise you will get "True is not False" or something like that in case of an error.
```python
labels = batch["labels"]
non_masked_tokens = batch["input_ids"][labels != -100]
result_text = self.tokenizer.decode(non_masked_tokens)
self.assertTrue(result_text == " I should not be masked. I should not be masked too.")
```
Thanks for the comments @BramVanroy, your test failed because of a wrong ordering of the arguments in the initialization of the collator. It is fixed now, thanks for pointing this out.
@younesbelkada are there any other steps I should take before you can merge this?
younesbelkada
left a comment
Thank you very much @gaetanlop for your great work and making the tests pass. Thanks a mile also @BramVanroy for the investigation and review. I left some minor comments and I think @BramVanroy has raised very good points. Would you be happy to address them? 🙏
Also, let's add a few lines in the docs here: https://github.com/lvwerra/trl/blob/main/docs/source/sft_trainer.mdx#L53
Happy also to help you, let me know!
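Not the final wording, but the docs addition could be as small as something like this (dataset and model names are placeholders for whatever the docs example ends up using):

```python
from datasets import load_dataset
from transformers import AutoTokenizer
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

# Mask every human turn so the loss is computed only on assistant answers.
collator = DataCollatorForCompletionOnlyLM(
    response_template="### Assistant:",
    instruction_template="### Human:",
    tokenizer=tokenizer,
    mlm=False,
)

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    data_collator=collator,
)
trainer.train()
```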
```diff
@@ -423,6 +423,28 @@ def test_data_collator_completion_lm(self):
         result_text = self.tokenizer.decode(batch["input_ids"][0, last_pad_idx + 1 :])
         self.assertTrue(result_text == "I have not been masked correctly.")
```
| self.assertTrue(result_text == "I have not been masked correctly.") | |
| self.assertEqual(result_text, "I have not been masked correctly.") |
```python
labels = batch["labels"]
non_masked_tokens = batch["input_ids"][labels != -100]
result_text = self.tokenizer.decode(non_masked_tokens)
self.assertTrue(result_text == " I should not be masked. I should not be masked too.")
```
| self.assertTrue(result_text == " I should not be masked. I should not be masked too.") | |
| self.assertEqual(result_text, " I should not be masked. I should not be masked too.") |
Thanks @younesbelkada and @BramVanroy for all your comments. I made the requested changes. @younesbelkada for the doc, I have added an example for chat data with the collator. Feel free to change it if you think there is a better way to showcase how it works.
Thanks a lot for iterating @gaetanlop!
This all looks great now. One last thing: it seems you have removed this nice test by mistake:
```python
def test_data_collator_chat_completion_lm(self):
    instruction_template = "### Human:"
    assistant_template = "### Assistant:"
    data_collator = DataCollatorForCompletionOnlyLM(
        response_template=assistant_template,
        instruction_template=instruction_template,
        tokenizer=self.tokenizer,
        mlm=False,
    )

    text = """### Human: Hello all this should be masked.### Assistant: I should not be masked.### Human: All this should be masked too.### Assistant: I should not be masked too."""
    encoded_text = self.tokenizer(text)
    encoded_text["input_ids"] = encoded_text["input_ids"]
    examples = [encoded_text]

    batch = data_collator(examples)
    labels = batch["labels"]
    non_masked_tokens = batch["input_ids"][labels != -100]
    result_text = self.tokenizer.decode(non_masked_tokens)
    self.assertEqual(result_text, " I should not be masked. I should not be masked too.")
```

Can you add it again please 🙏 After that we can merge, I think.
Thanks again!
@younesbelkada But change the
@BramVanroy, nice catch! Just modified it.
Sorry, I thought it was in the requested changes. @younesbelkada I just added it back.
younesbelkada
left a comment
Thanks for this great contribution!
Thanks for working on this! I'm excited to see this get merged.
Q: Does this PR take the context of the conversation into account? Or does it only train on singular instruction + response pairs?
It should take multi-turn conversations into account and mask the appropriate parts of the conversation.
Hello - thanks for this great library and feature. I have been trying to fine-tune some models using this collator, but the resulting model I created does not generate the stop token. I have been using Mistral, which of course does not have a pad token. Looking into the implementation of this collator, I see that it inherits from DataCollatorForLanguageModeling. When I have been using the model I trained with this flow, I have been seeing that the model rarely (if ever) generates the EOS token. I wanted to flag this, as I believe the underlying collator ends up masking the EOS tokens (treated as padding) out of the loss.
Hello @rsnm2! I also encountered roughly this issue myself, resulting in me subclassing the collator of my choice with one that didn't replace the EOS tokens with -100. Otherwise the model would never train to produce EOS tokens. Perhaps this should be resolved via 1) a documentation change or 2) a change in the collators (though I was also using ones from pure transformers).
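A rough sketch of the kind of subclass described here (not part of TRL; it assumes pad_token == eos_token, right padding, and a single genuine EOS at the end of each sample):

```python
from trl import DataCollatorForCompletionOnlyLM


class KeepEOSCollator(DataCollatorForCompletionOnlyLM):
    """Sketch only: restore the label of the genuine EOS token that the parent
    collator masked as padding, so the model still learns to emit a stop token."""

    def torch_call(self, examples):
        batch = super().torch_call(examples)
        eos_id = self.tokenizer.eos_token_id
        for i in range(batch["input_ids"].shape[0]):
            # With right padding and pad == eos, the first EOS in the row is assumed
            # to be the real end-of-text token; everything after it is padding.
            eos_positions = (batch["input_ids"][i] == eos_id).nonzero(as_tuple=True)[0]
            if len(eos_positions) > 0:
                batch["labels"][i, eos_positions[0]] = eos_id
        return batch
```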
Thanks @tomaarsen. I figured this was the case and I'll give it a try. What do you think of setting
I think I am okay with never backpropagating loss for mispredicting
Anything else you would mention?
Also --- I am going to make an issue to discuss a change to either the docs or this collator, since I've seen this come up in quite a few places on the interwebs. I really wish the Mistral / Llama people just included a pad token in their vocab :)
@rsnm2 see https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/sft_llama2.py#L177, which sets pad_token to be eos_token. This is also done in the DPO example.
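For reference, the pattern that script uses looks roughly like this (model name is just an example); note the earlier comments in this thread about the EOS token then being masked from the loss:

```python
from transformers import AutoTokenizer

# Reuse the EOS token as the padding token when the tokenizer ships without one.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
```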
* added DataCollatorForChatCompletionOnlyLM
* added simple test
* merged the two collators and fixed ### in completion
* fix response template
* fixing ordering in test
* quality
* fixed minor comments & make doc
* chat test back
* Update tests/test_sft_trainer.py

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This PR is a follow-up to #439. It introduces another data collator to mask all the non-assistant responses in a chatbot setting.
In essence, this is a generalization of the DataCollatorForCompletionOnlyLM for chatbot datasets like the OASST one.
It finds the indexes of the start of human responses and the end of assistant responses and masks labels so that human responses are not taken into account in the loss. I made it so that the first token of the human template is not masked (so that people can use it later as a stop token).
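A small illustration of the resulting behaviour, mirroring the unit test added in this PR:

```python
# With instruction_template="### Human:" and response_template="### Assistant:",
# only the assistant spans keep real labels after collation.
text = (
    "### Human: Hello all this should be masked."
    "### Assistant: I should not be masked."
    "### Human: All this should be masked too."
    "### Assistant: I should not be masked too."
)
# Decoding the tokens whose labels are not -100 yields:
# " I should not be masked. I should not be masked too."
```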
I think that's ready for a review.
cc @younesbelkada @vwxyzjn