[bugfix] fix cache_dataset truncation_strategy #9210
Conversation
Code Review
This pull request updates the dataset selection logic in `swift/pipelines/utils.py` to support configurable truncation strategies, including a 'delete' strategy and a mapping-based truncation. The review feedback identifies potential runtime errors, such as a `TypeError` when handling scalar length values in the truncation function and a `ValueError` when calling `max()` on an empty list. Additionally, it recommends using `sum()` instead of `max()` when filtering multi-turn conversation segments, so the total sequence length is compared against the maximum limit.
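The filtering concern above can be sketched as a plain predicate. This is a hypothetical helper (`keep_row` is not a name from the PR), assuming rows carry a `lengths` field that may be either a scalar (single-turn) or a list of per-turn lengths: `sum()` compares the total against the limit, and since `sum([]) == 0`, an empty list is kept safely instead of raising the `ValueError` that `max([])` would.

```python
def keep_row(lengths, max_length):
    """Return True if a row fits within max_length under the 'delete' strategy.

    `lengths` may be a scalar (single-turn sample) or a list of per-turn
    lengths (multi-turn sample). Using sum() rather than max() compares the
    total sequence length against the limit; sum([]) is 0, so empty lists
    never raise (unlike max([]), which raises ValueError).
    """
    if isinstance(lengths, list):
        total = sum(lengths)
    else:
        total = lengths
    return total <= max_length
```

A `dataset.filter(...)` call would then apply this per row, e.g. `dataset.filter(lambda row: keep_row(row['lengths'], max_length), num_proc=args.dataset_num_proc)`.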
/gemini review
Code Review
This pull request updates the `_select_dataset` function in `swift/pipelines/utils.py` to support a configurable truncation strategy. When the strategy is set to 'delete', the dataset is filtered by maximum length; otherwise, the lengths are truncated using a mapping function. The review feedback highlights a potential `TypeError` in the mapping logic when handling scalar values and suggests resolving variable shadowing. Additionally, it recommends using `dataset.filter` with multi-processing to improve performance and safely handling empty lists in the deletion logic.
```python
new_dataset = dataset.map(
    lambda rows: {'lengths': [[min(length, max_length) for length in lengths] for lengths in rows['lengths']]},
    num_proc=args.dataset_num_proc,
    load_from_cache_file=args.load_from_cache_file,
    batched=True)
```
There are two issues in this `map` call:

- **Potential Crash:** The lambda assumes `row_lengths` is always a list (via the inner loop `for length in lengths`). However, the 'delete' strategy logic (line 58) indicates that 'lengths' can also be a scalar. If a scalar is encountered, this will raise a `TypeError: 'int' object is not iterable`.
- **Variable Shadowing:** The variable name `lengths` is used for both the batch of lengths and the individual sample's lengths, which reduces readability and can lead to bugs.

The suggested fix handles both scalars and lists while avoiding variable shadowing.
```diff
-new_dataset = dataset.map(
-    lambda rows: {'lengths': [[min(length, max_length) for length in lengths] for lengths in rows['lengths']]},
-    num_proc=args.dataset_num_proc,
-    load_from_cache_file=args.load_from_cache_file,
-    batched=True)
+new_dataset = dataset.map(
+    lambda rows: {'lengths': [([min(l, max_length) for l in row_lengths] if isinstance(row_lengths, list) else min(row_lengths, max_length)) for row_lengths in rows['lengths']]},
+    num_proc=args.dataset_num_proc,
+    load_from_cache_file=args.load_from_cache_file,
+    batched=True)
```
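The suggested one-line lambda can also be expressed as a named helper, which makes the scalar/list branching easier to read. This is a sketch, not code from the PR; `truncate_lengths` is a hypothetical name:

```python
def truncate_lengths(row_lengths, max_length):
    """Clip lengths to max_length.

    Handles both representations the review identifies: a scalar length
    (single-turn sample) and a list of per-turn lengths (multi-turn sample).
    """
    if isinstance(row_lengths, list):
        return [min(length, max_length) for length in row_lengths]
    return min(row_lengths, max_length)
```

The batched `map` call would then become `dataset.map(lambda rows: {'lengths': [truncate_lengths(rl, max_length) for rl in rows['lengths']]}, batched=True)`, with the helper keeping the batch variable (`rows['lengths']`) and the per-sample variable (`row_lengths`) clearly distinct.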