Description
Describe the bug
When I set Packing=False in the SFT code, the following error appears:
Traceback (most recent call last):
File "/home/user/supervised_finetuning_trainer_test.py", line 145, in <module>
main()
File "/home/user/supervised_finetuning_trainer_test.py", line 131, in main
trainer = ed.SFTTrainer(
File "/home/user/.local/lib/python3.10/site-packages/easydel/trainers/supervised_fine_tuning_trainer/sft_trainer.py", line 78, in __init__
train_dataset = self._prepare_dataset(
File "/home/user/.local/lib/python3.10/site-packages/easydel/trainers/supervised_fine_tuning_trainer/sft_trainer.py", line 175, in _prepare_dataset
return self._prepare_non_packed_dataloader(
File "/home/user/.local/lib/python3.10/site-packages/easydel/trainers/supervised_fine_tuning_trainer/sft_trainer.py", line 282, in _prepare_non_packed_dataloader
tokenized_dataset = dataset.map(tokenize, **map_kwargs)
File "/home/user/.local/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 560, in wrapper
out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
File "/home/user/.local/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3073, in map
for rank, done, content in Dataset._map_single(**dataset_kwargs):
File "/home/user/.local/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3476, in _map_single
batch = apply_function_on_filtered_inputs(
File "/home/user/.local/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3338, in apply_function_on_filtered_inputs
processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
File "/home/user/.local/lib/python3.10/site-packages/easydel/trainers/supervised_fine_tuning_trainer/sft_trainer.py", line 234, in tokenize
else formatting_func(element)
File "/home/user/.local/lib/python3.10/site-packages/easydel/trainers/utils.py", line 80, in _pc
return conversations_formatting_function(processing_class, messages_field="conversation")(to_role_and_content(sample))
File "/home/user/.local/lib/python3.10/site-packages/easydel/trainers/utils.py", line 75, in to_role_and_content
{"role": "user", "content": field["conversation"][0]["input"]},
TypeError: list indices must be integers or slices, not str
There is also another problem with packing: it generally gives much worse results than Packing=False. This issue exists even in version 0.0.80. In the long term, I think fixing it may require adopting a different packing strategy. This paper compares various packing methods and finds that greedy packing generally gives the best results:
https://arxiv.org/pdf/2410.08081v2
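To make "greedy packing" concrete, here is a minimal sketch of one common greedy variant, first-fit decreasing bin packing. It is an illustration only: the function name `greedy_pack` is hypothetical, and this is not necessarily the exact variant evaluated in the paper. The key property is that whole examples are grouped into sequences of at most `max_len` tokens, so no example is cut across two packed sequences.

```python
from typing import List


def greedy_pack(lengths: List[int], max_len: int) -> List[List[int]]:
    """Group example indices into bins whose total token length is <= max_len.

    First-fit decreasing: visit examples from longest to shortest and put each
    one into the first existing bin with enough free space, else open a new bin.
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    bins: List[List[int]] = []   # example indices per packed sequence
    free: List[int] = []         # remaining capacity of each bin
    for i in order:
        n = lengths[i]
        if n > max_len:
            raise ValueError(f"example {i} has length {n} > max_len {max_len}")
        for b, cap in enumerate(free):
            if n <= cap:
                bins[b].append(i)
                free[b] -= n
                break
        else:
            # No existing bin fits this example: open a new one.
            bins.append([i])
            free.append(max_len - n)
    return bins


# Token lengths of five tokenized examples, packed into sequences of 1024 tokens.
print(greedy_pack([700, 300, 600, 200, 500], 1024))  # [[0, 1], [2, 3], [4]]
```

The trade-off is some leftover padding in each bin in exchange for never truncating or splitting an example mid-conversation.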
Also, is the packing method used in TRL the same as the one implemented here?
https://github.com/huggingface/trl/blob/69ad852e5654a77f1695eb4c608906fe0c7e8624/trl/data_utils.py#L435
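For comparison, the simplest packing strategy concatenates all tokenized examples into one stream and cuts it into fixed-length chunks. The sketch below is a generic version of that idea, not code taken from TRL or EasyDeL; whether either library does exactly this (EOS separators, dropping the last partial chunk) should be checked against the linked source. The function name `concat_and_chunk` and the `eos_id` parameter are hypothetical.

```python
from typing import List


def concat_and_chunk(sequences: List[List[int]], seq_length: int, eos_id: int) -> List[List[int]]:
    """Naive packing: join all examples into one token stream, then split it
    into chunks of exactly seq_length tokens."""
    stream: List[int] = []
    for seq in sequences:
        stream.extend(seq)
        stream.append(eos_id)  # mark the boundary between examples
    # Keep only full chunks; the trailing partial chunk is dropped in this sketch.
    n_full = len(stream) // seq_length
    return [stream[i * seq_length:(i + 1) * seq_length] for i in range(n_full)]


print(concat_and_chunk([[1, 2, 3], [4, 5], [6, 7, 8, 9]], seq_length=4, eos_id=0))
# [[1, 2, 3, 0], [4, 5, 0, 6], [7, 8, 9, 0]]
```

Note how the third example `[6, 7, 8, 9]` is split across two chunks; greedy packing avoids this by only grouping whole examples.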
On the other hand, I noticed that when Packing=True, the SFT code sometimes ends up with NaN loss, especially for custom models. I think custom models suffer from this because they are raw and have not seen any SFT data before, so they need significant gradient normalization; trying to learn from huge chunked (packed) examples right away then causes a gradient vanishing issue.
One way to solve this is an adaptive packing approach: start fine-tuning on a small unpacked subset of the dataset for the initial steps (e.g. the first 10 steps), then switch to packed examples. I have done this manually and it works very well to avoid the NaN issue: I fine-tuned my custom model (0.0.80) for 10 steps on unpacked examples, then fine-tuned the saved model from that step with Packing=True, and the NaN issue disappeared.
Thus, I believe that adding an adaptive packing strategy as a new feature would solve this issue.
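A minimal sketch of what such an adaptive schedule could look like inside a single run, instead of two separate fine-tuning runs with a save/reload in between. All names here (`adaptive_batches`, `pad_batch`, `warmup_steps`, `pad_id`) are hypothetical; no EasyDeL API is used or implied, and how the trainer would consume these batches is left open. `packed_rows` can come from any packing function.

```python
from itertools import islice
from typing import Iterator, List, Sequence, Tuple

Example = List[int]  # one tokenized example (or one packed row)


def batched(items: Sequence[Example], size: int) -> Iterator[List[Example]]:
    """Yield consecutive groups of `size` items (the last group may be smaller)."""
    it = iter(items)
    while chunk := list(islice(it, size)):
        yield chunk


def pad_batch(batch: Sequence[Example], seq_length: int, pad_id: int) -> List[Example]:
    """Unpacked mode: one example per row, truncated or right-padded to seq_length."""
    return [ex[:seq_length] + [pad_id] * max(0, seq_length - len(ex)) for ex in batch]


def adaptive_batches(
    unpacked_examples: Sequence[Example],
    packed_rows: Sequence[Example],
    warmup_steps: int,
    batch_size: int,
    seq_length: int,
    pad_id: int,
) -> Iterator[Tuple[str, List[Example]]]:
    """Unpacked, padded batches for the first `warmup_steps` steps, then packed batches."""
    # Phase 1: short warm-up on individual (unpacked) examples.
    for batch in islice(batched(unpacked_examples, batch_size), warmup_steps):
        yield "unpacked", pad_batch(batch, seq_length, pad_id)
    # Phase 2: every remaining step uses pre-packed rows.
    for batch in batched(packed_rows, batch_size):
        yield "packed", batch


steps = list(adaptive_batches(
    unpacked_examples=[[1, 2], [3], [4, 5, 6], [7]],
    packed_rows=[[1, 2, 0, 3], [0, 4, 5, 6]],
    warmup_steps=1, batch_size=2, seq_length=4, pad_id=-1,
))
print(steps)
# [('unpacked', [[1, 2, -1, -1], [3, -1, -1, -1]]), ('packed', [[1, 2, 0, 3], [0, 4, 5, 6]])]
```

With `warmup_steps=10` this mirrors the manual procedure described above (10 unpacked steps, then packing) while keeping optimizer state continuous across the switch.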
To Reproduce
Change the value of Packing to False in this test script:
https://github.com/erfanzar/EasyDeL/blob/main/tests/trainers/supervised_finetuning_trainer_test.py