[megatron] update megatron kto #6036
Conversation
Summary of Changes

Hello @Jintao-Huang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request undertakes a significant refactoring of the data batch generation and iteration mechanisms within the Megatron training framework. The primary goal is to enhance modularity and provide more context-aware batch processing by relocating common and specialized batch preparation logic from global utility functions into their respective trainer classes. This change streamlines how data batches are prepared and supplied to models during both training and evaluation, with a particular focus on improving the handling of KTO (Kahneman-Tversky Optimization) training workflows.
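For readers unfamiliar with this kind of refactor, a minimal, hypothetical sketch of the pattern described above is shown below. The class and function names are illustrative only, not the actual ms-swift code: a module-level `get_batch` helper becomes a trainer method, so subclasses such as a KTO trainer can override it.

```python
# Hypothetical sketch of the refactoring pattern (not the actual ms-swift code).

# Before: a global utility function shared by all trainers.
def get_batch(data_iterator):
    return next(data_iterator)


# After: batch preparation lives on the trainer, so subclasses can specialize it.
class MegatronTrainer:
    def get_batch(self, data_iterator):
        # common batch preparation
        return next(data_iterator)


class MegatronKTOTrainer(MegatronTrainer):
    def get_batch(self, data_iterator):
        # KTO-specific batch preparation (e.g. pairing completion and KL data)
        batch = super().get_batch(data_iterator)
        return batch
```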
Code Review

This pull request refactors the batch retrieval logic by moving the `get_batch` and `get_kto_batch` functions into methods within the trainer classes. This is a good design improvement that enhances polymorphism. However, I've found a critical issue in `MegatronKTOTrainer` where the `get_batch` method signature is incorrect, which will lead to a runtime error. My review includes a specific comment and a code suggestion to fix this.
```python
        return super().evaluate(forward_step_func, data_iterator, model, process_non_loss_data_func, config, verbose,
                                non_loss_data_func)

    def get_batch(data_iterator):
```
The `get_batch` method is missing the `self` parameter. As an instance method, it should be defined as `def get_batch(self, data_iterator):`. Without `self`, the call `self.get_batch(data_iterator)` at line 185 will fail at runtime with a `TypeError`.
Suggested change:

```diff
-    def get_batch(data_iterator):
+    def get_batch(self, data_iterator):
```
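As a quick illustration of the failure mode, here is a minimal standalone sketch (names are invented for the example; this is not the trainer code):

```python
class Trainer:
    def get_batch(data_iterator):          # missing `self`
        return next(data_iterator)

    def train(self, data_iterator):
        # Python inserts the instance as the first positional argument, so this
        # call is effectively get_batch(self, data_iterator): one argument too many.
        return self.get_batch(data_iterator)


try:
    Trainer().train(iter([{"input_ids": [1, 2, 3]}]))
except TypeError as e:
    print(e)  # get_batch() takes 1 positional argument but 2 were given
```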
/gemini review
Code Review

This pull request adds support for KTO (Kahneman-Tversky Optimization) to Megatron-SWIFT. The changes include updating documentation, refactoring argument parsing, and implementing the KTO trainer. The refactoring of the trainer logic is a good improvement. However, I've identified two critical bugs and one high-severity issue in the new `MegatronKTOTrainer` implementation that could lead to runtime errors. These issues are related to incorrect function arguments, unconditional logic that should be conditional, and a potential division-by-zero error. Please see the detailed comments for suggestions on how to fix them.
```python
policy_logps = self._kto_get_logps(output_tensor, data, False, False, length)
ref_logps = self._kto_get_logps(output_tensor, data, False, True, length)
policy_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, False, length)
ref_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, True, length)
```
The computation of `policy_KL_logps` and `ref_KL_logps` is unconditional. This will cause a runtime error when `self.args.calculate_KL` is `False`, because the `output_tensor` passed to `loss_func` will not contain the KL-related parts, leading to an out-of-bounds slice in `_kto_get_logps`.

The computation of these log probabilities should be guarded by `if self.args.calculate_KL:`, and they should be set to `None` otherwise, as the `trl.KTOTrainer.kto_loss` function can handle `None` values for these arguments.
Suggested change:

```diff
 policy_logps = self._kto_get_logps(output_tensor, data, False, False, length)
 ref_logps = self._kto_get_logps(output_tensor, data, False, True, length)
-policy_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, False, length)
-ref_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, True, length)
+if self.args.calculate_KL:
+    policy_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, False, length)
+    ref_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, True, length)
+else:
+    policy_KL_logps = None
+    ref_KL_logps = None
```
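To make the `None` convention concrete, here is a rough, hypothetical sketch of a KL term that tolerates missing KL log-probs. It only illustrates the convention the comment above relies on; it is not the actual `trl.KTOTrainer.kto_loss` implementation:

```python
import torch


def kto_kl_term(policy_KL_logps, ref_KL_logps):
    # Illustration only: with calculate_KL disabled the KL log-probs are None
    # and the KL regularization term degenerates to zero; otherwise it is
    # estimated from the detached mean difference of policy and reference log-probs.
    if policy_KL_logps is None or ref_KL_logps is None:
        return torch.zeros(1)
    return (policy_KL_logps - ref_KL_logps).mean().clamp(min=0).detach()


print(kto_kl_term(None, None))  # tensor([0.])
print(kto_kl_term(torch.tensor([-1.0, -2.0]), torch.tensor([-1.5, -2.5])))  # tensor(0.5000)
```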
```python
with torch.no_grad(), self.null_ref_context() as ref_models:
    ref_model = ref_models[vp_stage or 0]
    if self.args.calculate_KL:
        if input_tensor is not None:
            ref_model.set_input_tensor(self._get_input_tensor(True, True, length))
        ref_KL_output_tensor = ref_model(**kl_data)

    if input_tensor is not None:
        ref_model.set_input_tensor(self._get_input_tensor(True, False, length))
    ref_output_tensor = ref_model(**data)

if self.args.calculate_KL:
    with torch.no_grad():
        kl_output_tensor = self._forward_step_helper(ref_model, kl_inputs)
        data['reference_KL_logps'] = self.get_logps(kl_output_tensor, data['KL_completion_labels'],
                                                    data.get('KL_completion_packed_seq_params'))
    if input_tensor is not None:
        unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, True))
    KL_output_tensor = model(**kl_data)

if input_tensor is not None:
    unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, False))
with self.stimer:
    output_tensor = model(**data)
```
There are several issues with the calls to `_get_input_tensor` within the `forward_step` method, which will lead to runtime errors:

- In lines 120 and 124, the first argument passed is the boolean `True` instead of the `input_tensor`.
- In lines 120, 124, 130, and 134, the `dim` argument is missing. The `_get_input_tensor` method expects `dim` as its last argument, which should be `1` as the tensors are concatenated along `dim=1`.
- In lines 130 and 134, the `length` argument is missing.

These appear to be typos and omissions that need to be corrected for the logic to work as intended.
```python
with torch.no_grad(), self.null_ref_context() as ref_models:
    ref_model = ref_models[vp_stage or 0]
    if self.args.calculate_KL:
        if input_tensor is not None:
            ref_model.set_input_tensor(self._get_input_tensor(input_tensor, True, True, length, dim=1))
        ref_KL_output_tensor = ref_model(**kl_data)
    if input_tensor is not None:
        ref_model.set_input_tensor(self._get_input_tensor(input_tensor, True, False, length, dim=1))
    ref_output_tensor = ref_model(**data)
if self.args.calculate_KL:
    with torch.no_grad():
        if input_tensor is not None:
            unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, True, length, dim=1))
        KL_output_tensor = model(**kl_data)
if input_tensor is not None:
    unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, False, length, dim=1))
with self.stimer:
    output_tensor = model(**data)
```
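For intuition on why a `dim` argument matters here, the sketch below shows what a helper like `_get_input_tensor` could plausibly do when the completion and KL segments are concatenated along `dim=1`. This is a hypothetical illustration only; the real method in the trainer may differ, and `is_ref` is kept only to mirror the call signature discussed above:

```python
import torch


def _get_input_tensor_sketch(input_tensor, is_ref, is_kl, length, dim=1):
    # Hypothetical helper: assume the pipeline-parallel input tensor carries the
    # completion segment and the KL segment concatenated along `dim`, with the
    # first `length` entries belonging to the completion part.
    completion, kl = input_tensor.split([length, input_tensor.size(dim) - length], dim=dim)
    return kl if is_kl else completion


hidden = torch.randn(2, 10, 8)  # e.g. [batch, seq, hidden]
kl_part = _get_input_tensor_sketch(hidden, False, True, length=6)
print(kl_part.shape)  # torch.Size([2, 4, 8])
```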
```python
for key in ['chosen', 'rejected']:
    count = total_loss_dict[f'{prefix}count/{key}']
    res[f'{prefix}logps/{key}'] = total_loss_dict[f'{prefix}logps/{key}_sum'] / count
    res[f'{prefix}rewards/{key}'] = total_loss_dict[f'{prefix}rewards/{key}_sum'] / count
res[f'{prefix}rewards/margins'] = res[f'{prefix}rewards/chosen'] - res[f'{prefix}rewards/rejected']
```
In the `custom_log` method, there is a potential `ZeroDivisionError`. If a batch (or accumulated batches) contains only desirable or only undesirable samples, the `count` for the other type will be zero, leading to division by zero when calculating the mean for `logps` and `rewards`.

You should add a check to prevent division by zero.
Suggested change:

```diff
 for key in ['chosen', 'rejected']:
     count = total_loss_dict[f'{prefix}count/{key}']
-    res[f'{prefix}logps/{key}'] = total_loss_dict[f'{prefix}logps/{key}_sum'] / count
-    res[f'{prefix}rewards/{key}'] = total_loss_dict[f'{prefix}rewards/{key}_sum'] / count
+    if count.item() > 0:
+        res[f'{prefix}logps/{key}'] = total_loss_dict[f'{prefix}logps/{key}_sum'] / count
+        res[f'{prefix}rewards/{key}'] = total_loss_dict[f'{prefix}rewards/{key}_sum'] / count
+    else:
+        res[f'{prefix}logps/{key}'] = torch.tensor(0.0, device=count.device)
+        res[f'{prefix}rewards/{key}'] = torch.tensor(0.0, device=count.device)
 res[f'{prefix}rewards/margins'] = res[f'{prefix}rewards/chosen'] - res[f'{prefix}rewards/rejected']
```
/gemini review
Warning: Gemini encountered an error creating the review. You can try again by commenting /gemini review.
/gemini review
Code Review
This pull request introduces support for KTO (Kahneman-Tversky Optimization) within the Megatron framework and refactors the RLHF trainers for better structure and code reuse. The changes are well-organized, and the refactoring of the KTO trainer is a significant improvement.
I've identified a couple of areas in the documentation that could be improved for clarity and consistency with the code changes. Please see my detailed comments below.
docs/source/Megatron-SWIFT/命令行参数.md (Outdated)
The `calculate_KL` parameter seems to have been removed from the KTO parameters documentation. However, it's still a configurable parameter in the code (`swift/megatron/argument/megatron_args.py`). It's now optional and can be inferred, but users can still set it. It would be beneficial to document this parameter for clarity.

For example, you could add:

- calculate_KL: 是否计算KL散度。默认为`None`，会根据`loss_type`自动推断。例如，当`loss_type`为`'apo_zero_unpaired'`时，`calculate_KL`会设置为`False`，否则为`True`。
```diff
 **KTO Parameters**:
-- beta: Coefficient for the KL regularization term. Default is `0.1`.
-- desirable_weight: Loss weight $\lambda_D$ for desirable response in the KTO algorithm, default is `1.`.
-- undesirable_weight: Loss weight $\lambda_U$ for undesirable response in the KTO algorithm, default is `1.`.
-- calculate_KL: Whether to calculate KL divergence. Default is `True`.
+- ref_load: same meaning as in DPO.
+- ref_adapter_load: same meaning as in DPO.
+- beta: parameter controlling the deviation from the ref_model. Higher `beta` means less deviation from the ref_model. Default is `0.1`.
+- loss_type: default is `'kto'`. See possible values in the TRL docs: https://huggingface.co/docs/trl/main/en/kto_trainer#trl.KTOConfig.loss_type.
+- desirable_weight: factor to weight desirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`.
+- undesirable_weight: factor to weight undesirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`.
```
The `calculate_KL` parameter seems to have been removed from the KTO parameters documentation. However, it's still a configurable parameter in the code (`swift/megatron/argument/megatron_args.py`). It's now optional and can be inferred, but users can still set it. It would be beneficial to document this parameter for clarity.

For example, you could add:

- calculate_KL: Whether to calculate KL divergence. Defaults to `None`, and will be inferred based on `loss_type`. For example, when `loss_type` is `'apo_zero_unpaired'`, `calculate_KL` will be set to `False`, otherwise `True`.
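A minimal sketch of the inference behavior described in this suggested doc line (hypothetical; the actual logic lives in `swift/megatron/argument/megatron_args.py` and may differ in detail):

```python
def infer_calculate_kl(calculate_KL, loss_type):
    # A user-provided value wins; otherwise infer from the loss variant.
    if calculate_KL is not None:
        return calculate_KL
    # Loss variants that do not use a KL reference term skip the KL forward pass.
    return loss_type != 'apo_zero_unpaired'


assert infer_calculate_kl(None, 'kto') is True
assert infer_calculate_kl(None, 'apo_zero_unpaired') is False
assert infer_calculate_kl(False, 'kto') is False
```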