diff --git a/docs/source/data_utils.md b/docs/source/data_utils.md
index c862aa2cfa1..e4acfbb41b8 100644
--- a/docs/source/data_utils.md
+++ b/docs/source/data_utils.md
@@ -32,10 +32,6 @@
[[autodoc]] maybe_unpair_preference_dataset
-## pack_examples
-
-[[autodoc]] pack_examples
-
## pack_dataset
[[autodoc]] pack_dataset
diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md
index acf5a798571..4dcc97d7a51 100644
--- a/docs/source/reducing_memory_usage.md
+++ b/docs/source/reducing_memory_usage.md
@@ -77,10 +77,16 @@ This technique applies only to SFT.
Packing, introduced in [Raffel et al., 2020](https://huggingface.co/papers/1910.10683), addresses these issues by grouping sequences instead of truncating. It concatenates and splits dataset sequences into the desired lengths.
-

+
-Packing eliminates padding, preserves all sequence information, and allows for flexible sequence lengths, making it a more efficient alternative to truncation. To enable packing, use `packing=True` in the [`SFTConfig`]:
+Packing reduces padding by merging several sequences into one row whenever possible. By default, TRL packs with a first-fit-decreasing strategy that is near-optimal at minimizing padding while keeping each sequence intact. To enable packing, set `packing=True` in the [`SFTConfig`].
+
+In TRL 0.18 and earlier, packing used a more aggressive method that reduced padding to almost nothing, but had the downside of breaking sequence continuity for a large fraction of the dataset. To revert to this strategy, set `packing_strategy="wrapped"` in [`SFTConfig`].
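+
+To make the difference concrete, here is a minimal sketch using the [`pack_dataset`] utility directly (toy sequences, chosen for illustration; the groupings match the unit tests below):
+
+```python
+from datasets import Dataset
+from trl import pack_dataset
+
+dataset = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]]})
+
+# "ffd" keeps every sequence intact: [[4, 5, 6, 7], [1, 2, 3, 8]]
+print(pack_dataset(dataset, seq_length=4, strategy="ffd")[:]["input_ids"])
+
+# "wrapped" splits sequences at chunk boundaries: [[1, 2, 3, 4], [5, 6, 7, 8]]
+print(pack_dataset(dataset, seq_length=4, strategy="wrapped")[:]["input_ids"])
+```
+
+Enabling packing for training: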
+
```python
from trl import SFTConfig
diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py
index f182951a522..fa6d1731fe3 100644
--- a/tests/test_data_utils.py
+++ b/tests/test_data_utils.py
@@ -439,7 +439,7 @@ def test_with_dataset(self):
self.assertEqual(dataset.to_dict(), expected_output)
-class TestPackDataset(unittest.TestCase):
+class TestPackDatasetWrapped(unittest.TestCase):
def test_with_dataset(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
@@ -451,7 +451,7 @@ def test_with_dataset(self):
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
"attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
}
- dataset = pack_dataset(dataset, seq_length)
+ dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
self.assertEqual(dataset.to_dict(), expected_output)
def test_with_iterable_dataset(self):
@@ -465,11 +465,56 @@ def test_with_iterable_dataset(self):
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
"attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
}
- dataset = pack_dataset(dataset, seq_length)
+ dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
num_examples = len(examples[next(iter(examples))])
self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output)
+class TestPackDatasetFfd(unittest.TestCase):
+ def test_simple(self):
+ examples = {
+ "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
+ "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
+ }
+ dataset = Dataset.from_dict(examples)
+ seq_length = 4
+ expected_output = {
+ "input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
+ "attention_mask": [[0, 0, 1, 1], [0, 1, 1, 1]],
+ }
+ dataset = pack_dataset(dataset, seq_length, strategy="ffd")
+ self.assertEqual(dataset.to_dict(), expected_output)
+
+ def test_with_iterable_dataset(self):
+ examples = {
+ "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
+ "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
+ }
+ dataset = Dataset.from_dict(examples).to_iterable_dataset()
+ seq_length = 4
+ expected_output = {
+ "input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
+ "attention_mask": [[0, 0, 1, 1], [0, 1, 1, 1]],
+ }
+ dataset = pack_dataset(dataset, seq_length, strategy="ffd")
+ num_examples = len(examples[next(iter(examples))])
+ self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output)
+
+ def test_with_truncation(self):
+ examples = {
+ "input_ids": [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10, 11], [12]],
+ "attention_mask": [[1, 1, 1, 1, 1], [1, 1], [1, 1, 1, 1], [1]],
+ }
+ dataset = Dataset.from_dict(examples)
+ seq_length = 4
+ expected_output = {
+ "input_ids": [[1, 2, 3, 4], [8, 9, 10, 11], [6, 7, 12]],
+ "attention_mask": [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]],
+ }
+ dataset = pack_dataset(dataset, seq_length, strategy="ffd")
+ self.assertEqual(dataset.to_dict(), expected_output)
+
+
class TestTruncateExamples(unittest.TestCase):
def test_with_dataset(self):
examples = {
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index f7f03afb9ce..7def17bb7da 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -19,6 +19,7 @@
import numpy as np
import torch
from datasets import Dataset, Image, Sequence, load_dataset
+from parameterized import parameterized
from transformers import (
AutoModelForCausalLM,
AutoProcessor,
@@ -812,7 +813,7 @@ def test_only_train_packing(self):
per_device_train_batch_size=2,
gradient_checkpointing=True,
packing=True,
- max_length=16, # make sure there is at least 1 packed sequence
+ max_length=128, # make sure there is at least 1 packed sequence
eval_packing=False,
report_to="none",
)
@@ -824,7 +825,7 @@ def test_only_train_packing(self):
eval_dataset=self.conversational_lm_dataset["test"],
)
- self.assertEqual(len(trainer.train_dataset["input_ids"]), 46) # w/ this dataset, we end up with 46 seqs
+        self.assertEqual(len(trainer.train_dataset["input_ids"]), 7)  # w/ this dataset, we end up with 7 seqs
self.assertEqual(len(trainer.eval_dataset["input_ids"]), len(self.conversational_lm_dataset["test"]))
def test_eval_packing(self):
@@ -832,7 +833,7 @@ def test_eval_packing(self):
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
- max_length=16, # make sure there is at least 1 packed sequence
+ max_length=128, # make sure there is at least 1 packed sequence
packing=True,
report_to="none",
)
@@ -843,15 +844,15 @@ def test_eval_packing(self):
eval_dataset=self.conversational_lm_dataset["test"],
)
- self.assertEqual(len(trainer.train_dataset["input_ids"]), 46) # w/ this dataset, we end up with 46 seqs
- self.assertEqual(len(trainer.eval_dataset["input_ids"]), 6) # w/ this dataset, we end up with 6 seqs
+        self.assertEqual(len(trainer.train_dataset["input_ids"]), 7)  # w/ this dataset, we end up with 7 seqs
+        self.assertEqual(len(trainer.eval_dataset["input_ids"]), 1)  # w/ this dataset, we end up with 1 seq
def test_no_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
- max_length=16, # make sure there is at least 1 packed sequence
+ max_length=128, # make sure there is at least 1 packed sequence
packing=False,
report_to="none",
)
@@ -1229,3 +1230,31 @@ def test_train_padding_free(self):
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+ @parameterized.expand([("ffd",), ("wrapped",)])
+ def test_train_packing(self, packing_strategy):
+ # Get the dataset
+ dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ # Initialize the trainer
+ training_args = SFTConfig(
+ output_dir=tmp_dir, packing=True, packing_strategy=packing_strategy, max_length=10, report_to="none"
+ )
+ trainer = SFTTrainer(
+ model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
+ )
+
+ # Save the initial parameters to compare them later
+ previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+ # Train the model
+ trainer.train()
+
+ # Check that the training loss is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+ # Check the params have changed
+ for n, param in previous_trainable_params.items():
+ new_param = trainer.model.get_parameter(n)
+ self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
diff --git a/trl/data_utils.py b/trl/data_utils.py
index 2612ab2780d..1a3fe523112 100644
--- a/trl/data_utils.py
+++ b/trl/data_utils.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import functools
+import warnings
+from collections import defaultdict
from collections.abc import Sequence
from typing import Any, Callable, Optional, TypeVar, Union
@@ -465,6 +466,11 @@ def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str,
{'input_ids': [[1, 2], [3, 4], [5, 6], [7, 8]], 'attention_mask': [[0, 1], [1, 0], [0, 1], [1, 1]]}
```
"""
+ warnings.warn(
+ "`pack_examples` is deprecated and will be removed in version 0.20.0. Use `pack_dataset` with a dataset "
+ "instead.",
+ DeprecationWarning,
+ )
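+    # Migration sketch: `pack_dataset(dataset, seq_length, strategy="wrapped")`
+    # gives the same chunking as this function, operating on a `Dataset`
+    # rather than a dict of lists.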
# Join all the values into a single list
examples = {k: sum(v, []) for k, v in examples.items()}
# Split the values into chunks of size seq_length
@@ -472,7 +478,105 @@ def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str,
return examples
-def pack_dataset(dataset: DatasetType, seq_length: int, map_kwargs: Optional[dict[str, Any]] = None) -> DatasetType:
+def _pack_ffd(examples: pa.Table, seq_length: int) -> pa.Table:
+    """Pack sequences in a pyarrow Table using a First-Fit Decreasing (FFD) bin-packing strategy."""
+ packed_columns = []
+ for column in examples.columns:
+ if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
+ if isinstance(column, pa.ChunkedArray):
+ column = column.combine_chunks()
+ offsets, values = column.offsets, column.values
+ values = values[offsets[0].as_py() : offsets[-1].as_py()]
+
+ # Extract sequences using numpy for vectorized operations
+ offset_array = offsets.to_numpy()
+ starts = offset_array[:-1]
+ ends = offset_array[1:]
+ seq_lens = ends - starts
+
+ # Vectorized truncation
+ truncated_lens = np.minimum(seq_lens, seq_length)
+ truncated_ends = starts + truncated_lens
+
+ # Create sequences list with truncated values
+            # Build (length, start, end) tuples for each (possibly truncated) sequence
+
+ # Sort by length (decreasing) for First Fit Decreasing
+ sequences.sort(key=lambda x: x[0], reverse=True)
+
+            # Optimized bin packing: index open bins by remaining capacity for fast lookup
+ bins_by_remaining = defaultdict(list) # remaining_space -> [bin_indices]
+ bins = [] # [(current_length, seq_indices)]
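+            # e.g. with seq_length=4 and sorted lengths [4, 3, 1]: the 4 fills
+            # a bin exactly, the 3 opens a second bin, and the 1 fits into that
+            # bin's remaining space, yielding two rows of length 4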
+
+ for i, (seq_len, _start, _end) in enumerate(sequences):
+ # Find bins with enough space using the dictionary
+ placed = False
+ for remaining in range(seq_len, seq_length + 1):
+ if bins_by_remaining[remaining]:
+ # Use the first available bin with this remaining space
+                    bin_idx = bins_by_remaining[remaining].pop(0)
+ current_len, seq_indices = bins[bin_idx]
+
+ # Update bin
+ new_len = current_len + seq_len
+ new_remaining = seq_length - new_len
+ bins[bin_idx] = (new_len, seq_indices + [i])
+
+ # Update the remaining space mapping
+ if new_remaining > 0:
+ bins_by_remaining[new_remaining].append(bin_idx)
+
+ placed = True
+ break
+
+ # If no bin fits, create new bin
+ if not placed:
+ bin_idx = len(bins)
+ bins.append((seq_len, [i]))
+ remaining = seq_length - seq_len
+ if remaining > 0:
+ bins_by_remaining[remaining].append(bin_idx)
+
+            # Reconstruct the packed values and offsets from the bin assignments
+ values_numpy = values.to_numpy()
+ packed_values = []
+ new_offsets = [0]
+
+ for _, seq_indices in bins:
+ for seq_idx in seq_indices:
+ _, start, end = sequences[seq_idx]
+ packed_values.extend(values_numpy[start:end])
+ new_offsets.append(len(packed_values))
+
+ dtype = offsets.type.to_pandas_dtype()
+ new_offsets = np.array(new_offsets, dtype=dtype)
+ packed_values = pa.array(packed_values, type=values.type)
+ column = type(column).from_arrays(new_offsets, packed_values)
+ packed_columns.append(column)
+ return pa.Table.from_arrays(packed_columns, names=examples.column_names)
+
+
+def _pack_wrapped(examples: pa.Table, seq_length: int) -> pa.Table:
+    """Pack sequences in a pyarrow Table by concatenating them and splitting into fixed-size chunks ("wrapped")."""
+ packed_columns = []
+ for column in examples.columns:
+ if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
+ if isinstance(column, pa.ChunkedArray):
+ column = column.combine_chunks()
+ offsets, values = column.offsets, column.values
+ values = values[offsets[0].as_py() : offsets[-1].as_py()]
+ num_elements = len(values)
+ dtype = offsets.type.to_pandas_dtype() # np.int32 or np.int64
+ offsets = np.arange(0, num_elements, seq_length, dtype=dtype)
+ offsets = np.concatenate((offsets, [num_elements]))
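+            # e.g. 8 values with seq_length=3 -> offsets [0, 3, 6, 8], i.e.
+            # rows of lengths 3, 3 and 2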
+ column = type(column).from_arrays(offsets, values)
+ packed_columns.append(column)
+ return pa.Table.from_arrays(packed_columns, names=examples.column_names)
+
+
+def pack_dataset(
+ dataset: DatasetType, seq_length: int, strategy: str = "ffd", map_kwargs: Optional[dict[str, Any]] = None
+) -> DatasetType:
r"""
Pack sequences in a dataset into chunks of size `seq_length`.
@@ -481,6 +585,13 @@ def pack_dataset(dataset: DatasetType, seq_length: int, map_kwargs: Optional[dic
Dataset to pack
seq_length (`int`):
Target sequence length to pack to.
+ strategy (`str`, *optional*, defaults to `"ffd"`):
+ Packing strategy to use. Can be either:
+
+        - `"ffd"` (First Fit Decreasing): Slower but preserves sequence boundaries. Sequences are never cut in the
+          middle, though sequences longer than `seq_length` are still truncated.
+ - `"wrapped"`: Faster but more aggressive. Ignores sequence boundaries and will cut sequences in the middle
+ to completely fill each packed sequence with data.
map_kwargs (`dict` or `None`, *optional*, defaults to `None`):
Additional keyword arguments to pass to the dataset's map method when packing examples.
@@ -491,46 +602,29 @@ def pack_dataset(dataset: DatasetType, seq_length: int, map_kwargs: Optional[dic
Example:
```python
>>> from datasets import Dataset
+ >>> from trl import pack_dataset
>>> examples = {
- ... "input_ids": [[1, 2], [3, 4], [5, 6], [7]],
- ... "attention_mask": [[1, 1], [0, 1], [1, 1], [1]],
+ ... "input_ids": [[1, 2, 3], [4, 5], [6, 7, 8], [9]],
+ ... "attention_mask": [[1, 1, 0], [1, 0], [1, 0, 0], [1]]
... }
>>> dataset = Dataset.from_dict(examples)
- >>> packed_dataset = pack_dataset(dataset, seq_length=4)
+ >>> packed_dataset = pack_dataset(dataset, seq_length=4, strategy="ffd")
>>> packed_dataset[:]
- {'input_ids': [[1, 2, 3, 4], [5, 6, 7]],
- 'attention_mask': [[1, 1, 0, 1], [1, 1, 1]]}
+    {'input_ids': [[1, 2, 3, 9], [6, 7, 8], [4, 5]],
+     'attention_mask': [[1, 1, 0, 1], [1, 0, 0], [1, 0]]}
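+
+    >>> # For comparison, "wrapped" splits sequences at chunk boundaries
+    >>> # instead of keeping them whole:
+    >>> packed_dataset = pack_dataset(dataset, seq_length=4, strategy="wrapped")
+    >>> packed_dataset[:]
+    {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8], [9]],
+     'attention_mask': [[1, 1, 0, 1], [0, 1, 0, 0], [1]]}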
```
"""
if map_kwargs is None:
map_kwargs = {}
- if isinstance(dataset, Dataset):
- # Fast packing with pyarrow
- def pack(examples):
- packed_columns = []
- for column in examples.columns:
- if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
- if isinstance(column, pa.ChunkedArray):
- column = column.combine_chunks()
- offsets, values = column.offsets, column.values
- values = values[offsets[0].as_py() : offsets[-1].as_py()]
- num_elements = len(values)
- dtype = offsets.type.to_pandas_dtype() # np.int32 or np.int64
- offsets = np.arange(0, num_elements, seq_length, dtype=dtype)
- offsets = np.concatenate((offsets, [num_elements]))
- column = type(column).from_arrays(offsets, values)
- packed_columns.append(column)
- return pa.Table.from_arrays(packed_columns, names=examples.column_names)
-
- dataset = dataset.with_format("arrow")
- dataset = dataset.map(pack, batched=True, **map_kwargs)
- dataset = dataset.with_format(None)
+ # Fast packing with pyarrow
+ dataset = dataset.with_format("arrow")
+ if strategy == "ffd":
+ dataset = dataset.map(_pack_ffd, batched=True, fn_kwargs={"seq_length": seq_length}, **map_kwargs)
+ elif strategy == "wrapped":
+ dataset = dataset.map(_pack_wrapped, batched=True, fn_kwargs={"seq_length": seq_length}, **map_kwargs)
else:
- dataset = dataset.map(
- functools.partial(pack_examples, seq_length=seq_length),
- batched=True,
- **map_kwargs,
- )
+ raise ValueError(f"Invalid packing strategy: {strategy}. Use 'ffd' or 'wrapped'.")
+ dataset = dataset.with_format(None)
return dataset
diff --git a/trl/trainer/sft_config.py b/trl/trainer/sft_config.py
index 5615e7733a0..0b1783bd640 100644
--- a/trl/trainer/sft_config.py
+++ b/trl/trainer/sft_config.py
@@ -56,7 +56,10 @@ class SFTConfig(TrainingArguments):
Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right.
If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
packing (`bool`, *optional*, defaults to `False`):
- Whether to pack multiple sequences into a fixed-length format. Uses `max_length` to define sequence length.
+ Whether to group multiple sequences into fixed-length blocks to improve computational efficiency and reduce
+ padding. Uses `max_length` to define sequence length.
+ packing_strategy (`str`, *optional*, defaults to `"ffd"`):
+        Strategy for packing sequences. Can be either `"ffd"` (first-fit decreasing, the default) or `"wrapped"`.
padding_free (`bool`, *optional*, defaults to `False`):
Whether to perform forward passes without padding by flattening all sequences in the batch into a single
continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only
@@ -132,8 +135,15 @@ class SFTConfig(TrainingArguments):
packing: bool = field(
default=False,
metadata={
- "help": "Whether to pack multiple sequences into a fixed-length format. Uses `max_length` to define "
- "sequence length."
+ "help": "Whether to group multiple sequences into fixed-length blocks to improve computational efficiency "
+ "and reduce padding. Uses `max_length` to define sequence length."
+ },
+ )
+ packing_strategy: str = field(
+ default="ffd",
+ metadata={
+            "help": "Strategy for packing sequences. Can be either `'ffd'` (first-fit decreasing, the default) or "
+            "`'wrapped'`."
},
)
padding_free: bool = field(
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index ef0cf79fbf7..f4c21d09923 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -660,7 +660,7 @@ def tokenize(example, processing_class, dataset_text_field, add_special_tokens):
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Packing {dataset_name} dataset"
dataset = dataset.select_columns("input_ids")
- dataset = pack_dataset(dataset, args.max_length, map_kwargs)
+ dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs)
elif args.max_length is not None:
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Truncating {dataset_name} dataset"