[dataset] support "loss_scale" in dataset #9214
Conversation
Code Review
This pull request introduces several enhancements to dataset handling and loss scaling. Key changes include the addition of a loss_scale field in datasets, a new disable_auto_column_mapping parameter, and support for is_binary_loss_scale to optimize memory usage with liger_kernel. The LossScale and Template classes were updated to handle aggregated message fields when merging consecutive assistant responses. Feedback highlights a potential risk where merging message content into lists might break downstream string-based processing, and identifies an unused import in swift/template/base.py.
/gemini review
Code Review
This pull request introduces support for a per-message loss_scale field in datasets and adds a disable_auto_column_mapping parameter to control automatic dataset column renaming. It also implements is_binary_loss_scale to optimize memory usage via liger_kernel when loss scales are binary. Key changes include updating the dataset preprocessor schema, refining message merging logic to handle multiple loss scales, and implementing lazy loading for templates. Feedback highlights a potential TypeError when processing merged assistant messages and concerns regarding backward compatibility for the loss field schema change.
```python
if self.template_backend == 'swift' and pre_role == 'assistant':
    new_message = {}
    for key in ['content', 'loss', 'loss_scale']:
        values = []
        for msg in [pre_message, message]:
            value = msg.get(key)
            if not isinstance(value, list):
                value = [value]
            values += value
        new_message[key] = values
    pre_message.update(new_message)
```
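The merge loop can be exercised in isolation. A minimal sketch (standalone; the message dicts are hypothetical examples, not taken from the PR's test data) showing how two consecutive assistant messages are aggregated into per-field lists:

```python
def merge_assistant_messages(pre_message, message):
    """Aggregate 'content', 'loss', and 'loss_scale' of two consecutive
    assistant messages into lists, mirroring the PR's merge loop."""
    new_message = {}
    for key in ['content', 'loss', 'loss_scale']:
        values = []
        for msg in [pre_message, message]:
            value = msg.get(key)
            # Wrap scalars so a previously merged list and a fresh scalar
            # concatenate uniformly.
            if not isinstance(value, list):
                value = [value]
            values += value
        new_message[key] = values
    pre_message.update(new_message)
    return pre_message

merged = merge_assistant_messages(
    {'role': 'assistant', 'content': 'Hi', 'loss': True, 'loss_scale': 1.0},
    {'role': 'assistant', 'content': ' there', 'loss': False, 'loss_scale': 0.0})
# merged['content'] == ['Hi', ' there']
```

Note that after one merge, `content` is a list of strings, which is exactly the shape the review comment below flags as risky for downstream consumers.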
When merging consecutive assistant messages, the content field is now converted into a list of strings. However, the downstream logic in _swift_encode (specifically around line 1240) expects response to be either a string or a list of token IDs (integers). If it receives a list of strings, tokenizer.decode will be called on strings, raising a TypeError. Ensure that a list of strings is handled correctly, e.g., by joining the strings or updating the decoding logic.
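One possible guard, as a sketch only: the helper name and the bare `tokenizer` parameter are assumptions for illustration, not the PR's actual code. It distinguishes the three shapes a merged response can take before any decoding happens:

```python
def normalize_response(response, tokenizer=None):
    """Return response as a single string, handling the three shapes the
    merged messages can produce: str, list of str, or list of token IDs."""
    if isinstance(response, str):
        return response
    if isinstance(response, list):
        if all(isinstance(x, str) for x in response):
            # Merged assistant contents: join rather than decode.
            return ''.join(response)
        if all(isinstance(x, int) for x in response):
            # Token IDs: decode via the tokenizer as before.
            return tokenizer.decode(response)
    raise TypeError(f'Unsupported response type: {type(response)}')

normalize_response(['Hi', ' there'])  # 'Hi there'
```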
```diff
 'role': Value(dtype='string'),
 'content': Value(dtype='string'),
-'loss': Value(dtype='float64'),
+'loss': Value(dtype='bool'),
```
Changing the loss field type from float64 to bool in the Arrow schema might cause issues with existing cached datasets or datasets that use 0.0/1.0 for this field. While bool is more semantic for a toggle, ensure that this change doesn't break compatibility with existing data that might be loaded as floats.
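A defensive coercion could bridge the schema change. This is a sketch under the assumption that legacy data may carry `loss` as a bool, a 0.0/1.0 float, or None; the helper name is hypothetical:

```python
def coerce_loss_flag(value):
    """Map legacy float (0.0/1.0) or int (0/1) 'loss' values onto the new
    bool schema; pass bools and None through unchanged."""
    if value is None:
        return None
    # Check bool before int/float: isinstance(True, int) is True in Python.
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        if value in (0, 1):
            return bool(value)
        raise ValueError(f'Ambiguous loss value: {value!r}')
    raise TypeError(f'Unsupported loss type: {type(value)}')
```

Applying such a coercion during preprocessing would let datasets cached under the old float64 schema load cleanly against the new bool field.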
/gemini review
Code Review
This pull request introduces several enhancements to dataset handling and loss calculation, including the addition of a loss_scale field for individual messages and a disable_auto_column_mapping parameter to control automatic dataset column mapping. It also adds the is_binary_loss_scale parameter to optimize memory usage when compatible with liger_kernel. The internal logic for merging consecutive assistant messages and calculating loss scales was refactored to support these new fields. Review feedback identified critical issues in the handling of token ID lists during context processing and message merging, where broad type checks could lead to data corruption or crashes.
No description provided.