diff --git a/docs/source/data_utils.md b/docs/source/data_utils.md
index c862aa2cfa1..e4acfbb41b8 100644
--- a/docs/source/data_utils.md
+++ b/docs/source/data_utils.md
@@ -32,10 +32,6 @@
 
 [[autodoc]] maybe_unpair_preference_dataset
 
-## pack_examples
-
-[[autodoc]] pack_examples
-
 ## pack_dataset
 
 [[autodoc]] pack_dataset
diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md
index acf5a798571..4dcc97d7a51 100644
--- a/docs/source/reducing_memory_usage.md
+++ b/docs/source/reducing_memory_usage.md
@@ -77,10 +77,16 @@ This technique applies only to SFT.
 
 Packing, introduced in [Raffel et al., 2020](https://huggingface.co/papers/1910.10683), addresses these issues by grouping sequences instead of truncating. It concatenates and splits dataset sequences into the desired lengths.
- Packing
+ Packing
-Packing eliminates padding, preserves all sequence information, and allows for flexible sequence lengths, making it a more efficient alternative to truncation. To enable packing, use `packing=True` in the [`SFTConfig`]:
+Packing reduces padding by merging several sequences into one row when possible. TRL packs the dataset with a first-fit decreasing strategy, which is near-optimal at minimizing padding. To enable packing, use `packing=True` in the [`SFTConfig`]:
+
+<Tip>
+
+In TRL 0.18 and earlier, packing used a more aggressive method that reduced padding to almost nothing, but had the downside of breaking sequence continuity for a large fraction of the dataset. To revert to this strategy, use `packing_strategy="wrapped"` in `SFTConfig`.
+
+</Tip>
 
 ```python
 from trl import SFTConfig
diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py
index f182951a522..fa6d1731fe3 100644
--- a/tests/test_data_utils.py
+++ b/tests/test_data_utils.py
@@ -439,7 +439,7 @@ def test_with_dataset(self):
         self.assertEqual(dataset.to_dict(), expected_output)
 
-class TestPackDataset(unittest.TestCase):
+class TestPackDatasetWrapped(unittest.TestCase):
     def test_with_dataset(self):
         examples = {
             "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
             "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
@@ -451,7 +451,7 @@ def test_with_dataset(self):
             "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
             "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
         }
-        dataset = pack_dataset(dataset, seq_length)
+        dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
         self.assertEqual(dataset.to_dict(), expected_output)
 
     def test_with_iterable_dataset(self):
@@ -465,11 +465,56 @@ def test_with_iterable_dataset(self):
             "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
             "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
         }
-        dataset = pack_dataset(dataset, seq_length)
+        dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
         num_examples = len(examples[next(iter(examples))])
         self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output)
 
 
+class TestPackDatasetFfd(unittest.TestCase):
+    def test_simple(self):
+        examples = {
+            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
+            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
+        }
+        dataset = Dataset.from_dict(examples)
+        seq_length = 4
+        expected_output = {
+            "input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
+            "attention_mask": [[0, 0, 1, 1], [0, 1, 1, 1]],
+        }
+        dataset = pack_dataset(dataset, seq_length, strategy="ffd")
+        self.assertEqual(dataset.to_dict(), expected_output)
+
+    def test_with_iterable_dataset(self):
+        examples = {
+            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
+            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
+        }
+        dataset = Dataset.from_dict(examples).to_iterable_dataset()
+        seq_length = 4
+        expected_output = {
+            "input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
+            "attention_mask": [[0, 0, 1, 1], [0, 1, 1, 1]],
+        }
+        dataset = pack_dataset(dataset, seq_length, strategy="ffd")
+        num_examples = len(examples[next(iter(examples))])
+        self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output)
+
+    def test_with_truncation(self):
+        examples = {
+            "input_ids": [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10, 11], [12]],
+            "attention_mask": [[1, 1, 1, 1, 1], [1, 1], [1, 1, 1, 1], [1]],
+        }
+        dataset = Dataset.from_dict(examples)
+        seq_length = 4
+        expected_output = {
+            "input_ids": [[1, 2, 3, 4], [8, 9, 10, 11], [6, 7, 12]],
+            "attention_mask": [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]],
+        }
+        dataset = pack_dataset(dataset, seq_length, strategy="ffd")
+        self.assertEqual(dataset.to_dict(), expected_output)
+
+
 class TestTruncateExamples(unittest.TestCase):
     def test_with_dataset(self):
         examples = {
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index f7f03afb9ce..7def17bb7da 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -19,6 +19,7 @@
 import numpy as np
 import torch
 from datasets import Dataset, Image, Sequence, load_dataset
+from parameterized import parameterized
 from transformers import (
     AutoModelForCausalLM,
     AutoProcessor,
@@ -812,7 +813,7 @@ def test_only_train_packing(self):
                 per_device_train_batch_size=2,
                 gradient_checkpointing=True,
                 packing=True,
-                max_length=16,  # make sure there is at least 1 packed sequence
+                max_length=128,  # make sure there is at least 1 packed sequence
                 eval_packing=False,
                 report_to="none",
             )
@@ -824,7 +825,7 @@ def test_only_train_packing(self):
                 eval_dataset=self.conversational_lm_dataset["test"],
             )
 
-            self.assertEqual(len(trainer.train_dataset["input_ids"]), 46)  # w/ this dataset, we end up with 46 seqs
+            self.assertEqual(len(trainer.train_dataset["input_ids"]), 7)  # w/ this dataset, we end up with 7 seqs
             self.assertEqual(len(trainer.eval_dataset["input_ids"]), len(self.conversational_lm_dataset["test"]))
 
     def test_eval_packing(self):
@@ -832,7 +833,7 @@ def test_eval_packing(self):
             training_args = SFTConfig(
                 output_dir=tmp_dir,
                 per_device_train_batch_size=2,
-                max_length=16,  # make sure there is at least 1 packed sequence
+                max_length=128,  # make sure there is at least 1 packed sequence
                 packing=True,
                 report_to="none",
             )
@@ -843,15 +844,15 @@ def test_eval_packing(self):
                 eval_dataset=self.conversational_lm_dataset["test"],
             )
 
-            self.assertEqual(len(trainer.train_dataset["input_ids"]), 46)  # w/ this dataset, we end up with 46 seqs
-            self.assertEqual(len(trainer.eval_dataset["input_ids"]), 6)  # w/ this dataset, we end up with 6 seqs
+            self.assertEqual(len(trainer.train_dataset["input_ids"]), 7)  # w/ this dataset, we end up with 7 seqs
+            self.assertEqual(len(trainer.eval_dataset["input_ids"]), 1)  # w/ this dataset, we end up with 1 seq
 
     def test_no_packing(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = SFTConfig(
                 output_dir=tmp_dir,
                 per_device_train_batch_size=2,
-                max_length=16,  # make sure there is at least 1 packed sequence
+                max_length=128,  # make sure there is at least 1 packed sequence
                 packing=False,
                 report_to="none",
             )
@@ -1229,3 +1230,31 @@ def test_train_padding_free(self):
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
             self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    @parameterized.expand([("ffd",), ("wrapped",)])
+    def test_train_packing(self, packing_strategy):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # Initialize the trainer
+            training_args = SFTConfig(
+                output_dir=tmp_dir, packing=True, packing_strategy=packing_strategy, max_length=10, report_to="none"
+            )
+            trainer = SFTTrainer(
+                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
+            )
+
+            # Save the initial parameters to compare them later
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            # Train the model
+            trainer.train()
+
+            # Check that the training loss is not None
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
diff --git a/trl/data_utils.py b/trl/data_utils.py
index 2612ab2780d..1a3fe523112 100644
--- a/trl/data_utils.py
+++ b/trl/data_utils.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import functools
+import warnings
+from collections import defaultdict
 from collections.abc import Sequence
 from typing import Any, Callable, Optional, TypeVar, Union
 
@@ -465,6 +466,11 @@ def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str, list[list]]:
     {'input_ids': [[1, 2], [3, 4], [5, 6], [7, 8]], 'attention_mask': [[0, 1], [1, 0], [0, 1], [1, 1]]}
     ```
     """
+    warnings.warn(
+        "`pack_examples` is deprecated and will be removed in version 0.20.0. Use `pack_dataset` with a dataset "
+        "instead.",
+        DeprecationWarning,
+    )
     # Join all the values into a single list
     examples = {k: sum(v, []) for k, v in examples.items()}
     # Split the values into chunks of size seq_length
@@ -472,7 +478,105 @@ def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str, list[list]]:
     return examples
 
 
-def pack_dataset(dataset: DatasetType, seq_length: int, map_kwargs: Optional[dict[str, Any]] = None) -> DatasetType:
+def _pack_ffd(examples: pa.Table, seq_length: int) -> pa.Table:
+    """Pack sequences in a pyarrow Table using the First Fit Decreasing strategy."""
+    packed_columns = []
+    for column in examples.columns:
+        if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
+            if isinstance(column, pa.ChunkedArray):
+                column = column.combine_chunks()
+            offsets, values = column.offsets, column.values
+            values = values[offsets[0].as_py() : offsets[-1].as_py()]
+
+            # Extract sequence boundaries using numpy for vectorized operations
+            offset_array = offsets.to_numpy()
+            starts = offset_array[:-1]
+            ends = offset_array[1:]
+            seq_lens = ends - starts
+
+            # Vectorized truncation to seq_length
+            truncated_lens = np.minimum(seq_lens, seq_length)
+            truncated_ends = starts + truncated_lens
+
+            # Represent each sequence as (truncated_length, start_offset, truncated_end_offset)
+            sequences = list(zip(truncated_lens, starts, truncated_ends))
+
+            # Sort by length (decreasing) for First Fit Decreasing
+            sequences.sort(key=lambda x: x[0], reverse=True)
+
+            # Bin packing, indexing open bins by their remaining capacity
+            bins_by_remaining = defaultdict(list)  # remaining_space -> [bin_indices]
+            bins = []  # [(current_length, seq_indices)]
+
+            for i, (seq_len, _start, _end) in enumerate(sequences):
+                # Find a bin with enough space using the dictionary
+                placed = False
+                for remaining in range(seq_len, seq_length + 1):
+                    if bins_by_remaining[remaining]:
+                        # Use the first available bin with this remaining space
+                        bin_idx = bins_by_remaining[remaining].pop()
+                        current_len, seq_indices = bins[bin_idx]
+
+                        # Update bin
+                        new_len = current_len + seq_len
+                        new_remaining = seq_length - new_len
+                        bins[bin_idx] = (new_len, seq_indices + [i])
+
+                        # Update the remaining space mapping
+                        if new_remaining > 0:
+                            bins_by_remaining[new_remaining].append(bin_idx)
+
+                        placed = True
+                        break
+
+                # If no bin fits, create a new bin
+                if not placed:
+                    bin_idx = len(bins)
+                    bins.append((seq_len, [i]))
+                    remaining = seq_length - seq_len
+                    if remaining > 0:
+                        bins_by_remaining[remaining].append(bin_idx)
+
+            # Reconstruct the packed values, writing one offset entry per bin
+            values_numpy = values.to_numpy()
+            packed_values = []
+            new_offsets = [0]
+
+            for _, seq_indices in bins:
+                for seq_idx in seq_indices:
+                    _, start, end = sequences[seq_idx]
+                    packed_values.extend(values_numpy[start:end])
+                new_offsets.append(len(packed_values))
+
+            dtype = offsets.type.to_pandas_dtype()
+            new_offsets = np.array(new_offsets, dtype=dtype)
+            packed_values = pa.array(packed_values, type=values.type)
+            column = type(column).from_arrays(new_offsets, packed_values)
+        packed_columns.append(column)
+    return pa.Table.from_arrays(packed_columns, names=examples.column_names)
+
+
+def _pack_wrapped(examples: pa.Table, seq_length: int) -> pa.Table:
+    """Pack sequences in a pyarrow Table using a wrapped strategy."""
+    packed_columns = []
+    for column in examples.columns:
+        if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
+            if isinstance(column, pa.ChunkedArray):
+                column = column.combine_chunks()
+            offsets, values = column.offsets, column.values
+            values = values[offsets[0].as_py() : offsets[-1].as_py()]
+            num_elements = len(values)
+            dtype = offsets.type.to_pandas_dtype()  # np.int32 or np.int64
+            offsets = np.arange(0, num_elements, seq_length, dtype=dtype)
+            offsets = np.concatenate((offsets, [num_elements]))
+            column = type(column).from_arrays(offsets, values)
+        packed_columns.append(column)
+    return pa.Table.from_arrays(packed_columns, names=examples.column_names)
+
+
+def pack_dataset(
+    dataset: DatasetType, seq_length: int, strategy: str = "ffd", map_kwargs: Optional[dict[str, Any]] = None
+) -> DatasetType:
     r"""
     Pack sequences in a dataset into chunks of size `seq_length`.
 
@@ -481,6 +585,13 @@ def pack_dataset(dataset: DatasetType, seq_length: int, map_kwargs: Optional[dict[str, Any]] = None) -> DatasetType:
             Dataset to pack
         seq_length (`int`):
             Target sequence length to pack to.
+        strategy (`str`, *optional*, defaults to `"ffd"`):
+            Packing strategy to use. Can be either:
+
+            - `"ffd"` (First Fit Decreasing): Slower but preserves sequence boundaries. Sequences are never cut in
+              the middle.
+            - `"wrapped"`: Faster but more aggressive. Ignores sequence boundaries and will cut sequences in the
+              middle to completely fill each packed sequence with data.
         map_kwargs (`dict` or `None`, *optional*, defaults to `None`):
             Additional keyword arguments to pass to the dataset's map method when packing examples.
 
@@ -491,46 +602,29 @@ def pack_dataset(dataset: DatasetType, seq_length: int, map_kwargs: Optional[dict[str, Any]] = None) -> DatasetType:
     Example:
     ```python
     >>> from datasets import Dataset
+    >>> from trl import pack_dataset
 
     >>> examples = {
-    ...     "input_ids": [[1, 2], [3, 4], [5, 6], [7]],
-    ...     "attention_mask": [[1, 1], [0, 1], [1, 1], [1]],
+    ...     "input_ids": [[1, 2, 3], [4, 5], [6, 7, 8], [9]],
+    ...     "attention_mask": [[1, 1, 0], [1, 0], [1, 0, 0], [1]],
     ... }
     >>> dataset = Dataset.from_dict(examples)
-    >>> packed_dataset = pack_dataset(dataset, seq_length=4)
+    >>> packed_dataset = pack_dataset(dataset, seq_length=4, strategy="ffd")
     >>> packed_dataset[:]
-    {'input_ids': [[1, 2, 3, 4], [5, 6, 7]],
-     'attention_mask': [[1, 1, 0, 1], [1, 1, 1]]}
+    {'input_ids': [[1, 2, 3], [6, 7, 8, 9], [4, 5]],
+     'attention_mask': [[1, 1, 0], [1, 0, 0, 1], [1, 0]]}
     ```
     """
     if map_kwargs is None:
         map_kwargs = {}
-    if isinstance(dataset, Dataset):
-        # Fast packing with pyarrow
-        def pack(examples):
-            packed_columns = []
-            for column in examples.columns:
-                if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
-                    if isinstance(column, pa.ChunkedArray):
-                        column = column.combine_chunks()
-                    offsets, values = column.offsets, column.values
-                    values = values[offsets[0].as_py() : offsets[-1].as_py()]
-                    num_elements = len(values)
-                    dtype = offsets.type.to_pandas_dtype()  # np.int32 or np.int64
-                    offsets = np.arange(0, num_elements, seq_length, dtype=dtype)
-                    offsets = np.concatenate((offsets, [num_elements]))
-                    column = type(column).from_arrays(offsets, values)
-                packed_columns.append(column)
-            return pa.Table.from_arrays(packed_columns, names=examples.column_names)
-
-        dataset = dataset.with_format("arrow")
-        dataset = dataset.map(pack, batched=True, **map_kwargs)
-        dataset = dataset.with_format(None)
+    # Fast packing with pyarrow
+    dataset = dataset.with_format("arrow")
+    if strategy == "ffd":
+        dataset = dataset.map(_pack_ffd, batched=True, fn_kwargs={"seq_length": seq_length}, **map_kwargs)
+    elif strategy == "wrapped":
+        dataset = dataset.map(_pack_wrapped, batched=True, fn_kwargs={"seq_length": seq_length}, **map_kwargs)
     else:
-        dataset = dataset.map(
-            functools.partial(pack_examples, seq_length=seq_length),
-            batched=True,
-            **map_kwargs,
-        )
+        raise ValueError(f"Invalid packing strategy: {strategy}. Use 'ffd' or 'wrapped'.")
+    dataset = dataset.with_format(None)
     return dataset
diff --git a/trl/trainer/sft_config.py b/trl/trainer/sft_config.py
index 5615e7733a0..0b1783bd640 100644
--- a/trl/trainer/sft_config.py
+++ b/trl/trainer/sft_config.py
@@ -56,7 +56,10 @@ class SFTConfig(TrainingArguments):
             Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right.
             If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
         packing (`bool`, *optional*, defaults to `False`):
-            Whether to pack multiple sequences into a fixed-length format. Uses `max_length` to define sequence length.
+            Whether to group multiple sequences into fixed-length blocks to improve computational efficiency and
+            reduce padding. Uses `max_length` to define sequence length.
+        packing_strategy (`str`, *optional*, defaults to `"ffd"`):
+            Strategy for packing sequences. Can be either `"ffd"` (first-fit decreasing, default), or `"wrapped"`.
         padding_free (`bool`, *optional*, defaults to `False`):
             Whether to perform forward passes without padding by flattening all sequences in the batch into a single
             continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only
@@ -132,8 +135,15 @@ class SFTConfig(TrainingArguments):
     packing: bool = field(
         default=False,
         metadata={
-            "help": "Whether to pack multiple sequences into a fixed-length format. Uses `max_length` to define "
-            "sequence length."
+            "help": "Whether to group multiple sequences into fixed-length blocks to improve computational efficiency "
+            "and reduce padding. Uses `max_length` to define sequence length."
+        },
+    )
+    packing_strategy: str = field(
+        default="ffd",
+        metadata={
+            "help": "Strategy for packing sequences. Can be either `'ffd'` (first-fit decreasing, default), or "
+            "`'wrapped'`."
         },
     )
     padding_free: bool = field(
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index ef0cf79fbf7..f4c21d09923 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -660,7 +660,7 @@ def tokenize(example, processing_class, dataset_text_field, add_special_tokens):
             if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                 map_kwargs["desc"] = f"Packing {dataset_name} dataset"
             dataset = dataset.select_columns("input_ids")
-            dataset = pack_dataset(dataset, args.max_length, map_kwargs)
+            dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs)
         elif args.max_length is not None:
             if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                 map_kwargs["desc"] = f"Truncating {dataset_name} dataset"
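The `_pack_ffd` implementation above works on pyarrow offsets for speed, but the underlying idea is classic bin packing. Below is a minimal pure-Python sketch of the first-fit decreasing idea under the same truncate-then-pack rules; `ffd_pack` is an illustrative helper, not part of this diff, and it uses a simple linear scan rather than the remaining-capacity dictionary of the real code.

```python
def ffd_pack(sequences: list[list[int]], seq_length: int) -> list[list[int]]:
    """Truncate each sequence to seq_length, sort longest-first, then drop each
    sequence into the first bin with enough remaining space; open a new bin otherwise."""
    sequences = sorted((seq[:seq_length] for seq in sequences), key=len, reverse=True)
    bins: list[list[int]] = []
    for seq in sequences:
        for bin_ in bins:
            if len(bin_) + len(seq) <= seq_length:
                bin_.extend(seq)  # fits into an existing bin
                break
        else:
            bins.append(list(seq))  # no existing bin has room
    return bins


# Mirrors TestPackDatasetFfd.test_simple above:
print(ffd_pack([[1, 2, 3], [4, 5, 6, 7], [8]], seq_length=4))
# [[4, 5, 6, 7], [1, 2, 3, 8]]
```

Unlike `"wrapped"`, no sequence is ever split across rows; the trade-off is that some rows may keep residual padding.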
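For contrast, the `"wrapped"` strategy implemented by `_pack_wrapped` amounts to concatenating all tokens and re-slicing at fixed boundaries. A sketch of that behavior (again illustrative; `wrapped_pack` is not part of this diff):

```python
def wrapped_pack(sequences: list[list[int]], seq_length: int) -> list[list[int]]:
    """Concatenate all tokens, then slice into fixed-size chunks.
    Sequence boundaries are ignored; only the final chunk may be short."""
    flat = [token for seq in sequences for token in seq]
    return [flat[i : i + seq_length] for i in range(0, len(flat), seq_length)]


# Mirrors TestPackDatasetWrapped.test_with_dataset above (seq_length=3):
print(wrapped_pack([[1, 2, 3], [4, 5, 6, 7], [8]], seq_length=3))
# [[1, 2, 3], [4, 5, 6], [7, 8]]
```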
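End to end, the new behavior is driven from `SFTConfig` via `packing` and `packing_strategy`. A usage sketch based on the options introduced in this diff; the model and dataset names are illustrative placeholders, not taken from the patch:

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Illustrative model and dataset; any causal LM and SFT dataset work the same way.
dataset = load_dataset("trl-lib/Capybara", split="train")

training_args = SFTConfig(
    output_dir="Qwen2.5-0.5B-SFT",
    packing=True,            # group sequences into fixed-length blocks
    packing_strategy="ffd",  # default; "wrapped" restores the more aggressive pre-0.19 behavior
    max_length=512,          # length of each packed block
)
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```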