
Conversation

Jintao-Huang
Collaborator

No description provided.

Contributor

Summary of Changes

Hello @Jintao-Huang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request undertakes a significant refactoring of the data batch generation and iteration mechanisms within the Megatron training framework. The primary goal is to enhance modularity and provide more context-aware batch processing by relocating common and specialized batch preparation logic from global utility functions into their respective trainer classes. This change streamlines how data batches are prepared and supplied to models during both training and evaluation, with a particular focus on improving the handling of KTO (Kahneman-Tversky Optimization) training workflows.

Highlights

  • Batch Generation Refactoring: The logic for generating data batches has been refactored, moving the get_batch and get_kto_batch functions from swift/megatron/trainers/utils.py into the respective trainer classes as methods.
  • Base Trainer Enhancements: The BaseTrainer in swift/megatron/trainers/base.py now includes its own get_batch method and its _replace_data_iterator method has been updated to accept a model argument, providing more context for data iterator replacement.
  • KTO-Specific Batch Handling: The KtoTrainer in swift/megatron/trainers/kto_trainer.py now implements a specialized get_batch method tailored for KTO (Kahneman-Tversky Optimization) batches, incorporating logic for specific tensor keys, pipeline parallelism, padding-free processing, and context parallelism.
  • Unified Batch Access: The DpoTrainer and Trainer classes have been updated to call self.get_batch instead of relying on global utility functions, ensuring that batch generation is handled polymorphically within the trainer instances (see the sketch after this list).
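
To make the pattern concrete, here is a minimal, hypothetical sketch of the structure these highlights describe. The class and method names are taken from the summary above, but the bodies are purely illustrative and are not the actual ms-swift implementation.

class BaseTrainer:

    def get_batch(self, data_iterator):
        # Shared batch preparation: pull the next batch and apply common processing.
        return next(data_iterator)

    def _replace_data_iterator(self, data_iterator, model):
        # Now also receives `model`, giving subclasses more context when
        # swapping in a custom iterator.
        return data_iterator


class KtoTrainer(BaseTrainer):

    def get_batch(self, data_iterator):
        # KTO-specific handling (extra KL tensor keys, padding-free packing,
        # pipeline/context parallelism bookkeeping) layered on the shared logic.
        batch = super().get_batch(data_iterator)
        return batch


class DpoTrainer(BaseTrainer):

    def forward_step(self, data_iterator, model):
        # Trainers call self.get_batch(...) instead of a module-level utility,
        # so the override matching the concrete trainer type is dispatched.
        data = self.get_batch(data_iterator)
        return model(**data)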

Contributor

gemini-code-assist (bot) left a comment


Code Review

This pull request refactors the batch retrieval logic by moving the get_batch and get_kto_batch functions into methods within the trainer classes. This is a good design improvement that enhances polymorphism. However, I've found a critical issue in MegatronKTOTrainer where the get_batch method signature is incorrect, which will lead to a runtime error. My review includes a specific comment and a code suggestion to fix this.

return super().evaluate(forward_step_func, data_iterator, model, process_non_loss_data_func, config, verbose,
                        non_loss_data_func)

def get_batch(data_iterator):
Contributor


critical

The get_batch method is missing the self parameter. As an instance method, it should be defined as def get_batch(self, data_iterator):. Without self, the call self.get_batch(data_iterator) at line 185 will fail at runtime with a TypeError.

Suggested change
def get_batch(data_iterator):
def get_batch(self, data_iterator):
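
For illustration, a self-contained snippet (using a hypothetical Trainer class, not the actual ms-swift one) reproducing the failure mode:

class Trainer:
    def get_batch(data_iterator):  # missing `self`
        return next(data_iterator)

trainer = Trainer()
# Python passes the instance as the implicit first argument of a bound method,
# so this call supplies two arguments to a function that accepts only one:
trainer.get_batch(iter([{'input_ids': [1, 2, 3]}]))
# TypeError: get_batch() takes 1 positional argument but 2 were given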

@Jintao-Huang
Collaborator Author

/gemini review

Contributor

gemini-code-assist (bot) left a comment


Code Review

This pull request adds support for KTO (Kahneman-Tversky Optimization) to Megatron-SWIFT. The changes include updating documentation, refactoring argument parsing, and implementing the KTO trainer. The refactoring of the trainer logic is a good improvement. However, I've identified two critical bugs and one high-severity issue in the new MegatronKTOTrainer implementation that could lead to runtime errors. These issues are related to incorrect function arguments, unconditional logic that should be conditional, and a potential division-by-zero error. Please see the detailed comments for suggestions on how to fix them.

Comment on lines 49 to 52
policy_logps = self._kto_get_logps(output_tensor, data, False, False, length)
ref_logps = self._kto_get_logps(output_tensor, data, False, True, length)
policy_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, False, length)
ref_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, True, length)
Contributor


critical

The computation of policy_KL_logps and ref_KL_logps is unconditional. This will cause a runtime error when self.args.calculate_KL is False, because the output_tensor passed to loss_func will not contain the KL-related parts, leading to an out-of-bounds slice in _kto_get_logps.

The computation of these log probabilities should be guarded by if self.args.calculate_KL:, and they should be set to None otherwise, as the trl.KTOTrainer.kto_loss function can handle None values for these arguments.

Suggested change
Before:
policy_logps = self._kto_get_logps(output_tensor, data, False, False, length)
ref_logps = self._kto_get_logps(output_tensor, data, False, True, length)
policy_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, False, length)
ref_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, True, length)

After:
policy_logps = self._kto_get_logps(output_tensor, data, False, False, length)
ref_logps = self._kto_get_logps(output_tensor, data, False, True, length)
if self.args.calculate_KL:
    policy_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, False, length)
    ref_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, True, length)
else:
    policy_KL_logps = None
    ref_KL_logps = None

Comment on lines 116 to 136
        with torch.no_grad(), self.null_ref_context() as ref_models:
            ref_model = ref_models[vp_stage or 0]
            if self.args.calculate_KL:
                if input_tensor is not None:
                    ref_model.set_input_tensor(self._get_input_tensor(True, True, length))
                ref_KL_output_tensor = ref_model(**kl_data)

            if input_tensor is not None:
                ref_model.set_input_tensor(self._get_input_tensor(True, False, length))
            ref_output_tensor = ref_model(**data)

        if self.args.calculate_KL:
            with torch.no_grad():
                kl_output_tensor = self._forward_step_helper(ref_model, kl_inputs)
                data['reference_KL_logps'] = self.get_logps(kl_output_tensor, data['KL_completion_labels'],
                                                            data.get('KL_completion_packed_seq_params'))
                if input_tensor is not None:
                    unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, True))
                KL_output_tensor = model(**kl_data)

        if input_tensor is not None:
            unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, False))
        with self.stimer:
            output_tensor = model(**data)
Contributor


critical

There are several issues with the calls to _get_input_tensor within the forward_step method, which will lead to runtime errors:

  1. In lines 120 and 124, the first argument passed is the boolean True instead of the input_tensor.
  2. In lines 120, 124, 130, and 134, the dim argument is missing. The _get_input_tensor method expects dim as its last argument, which should be 1 as the tensors are concatenated along dim=1.
  3. In lines 130 and 134, the length argument is missing.

These appear to be typos and omissions that need to be corrected for the logic to work as intended.

        with torch.no_grad(), self.null_ref_context() as ref_models:
            ref_model = ref_models[vp_stage or 0]
            if self.args.calculate_KL:
                if input_tensor is not None:
                    ref_model.set_input_tensor(self._get_input_tensor(input_tensor, True, True, length, dim=1))
                ref_KL_output_tensor = ref_model(**kl_data)

            if input_tensor is not None:
                ref_model.set_input_tensor(self._get_input_tensor(input_tensor, True, False, length, dim=1))
            ref_output_tensor = ref_model(**data)

        if self.args.calculate_KL:
            with torch.no_grad():
                if input_tensor is not None:
                    unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, True, length, dim=1))
                KL_output_tensor = model(**kl_data)

        if input_tensor is not None:
            unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, False, length, dim=1))
        with self.stimer:
            output_tensor = model(**data)

Comment on lines 160 to 164
for key in ['chosen', 'rejected']:
    count = total_loss_dict[f'{prefix}count/{key}']
    res[f'{prefix}logps/{key}'] = total_loss_dict[f'{prefix}logps/{key}_sum'] / count
    res[f'{prefix}rewards/{key}'] = total_loss_dict[f'{prefix}rewards/{key}_sum'] / count
res[f'{prefix}rewards/margins'] = res[f'{prefix}rewards/chosen'] - res[f'{prefix}rewards/rejected']
Contributor


high

In the custom_log method, there is a potential ZeroDivisionError. If a batch (or accumulated batches) contains only desirable or only undesirable samples, the count for the other type will be zero, leading to division by zero when calculating the mean for logps and rewards.

You should add a check to prevent division by zero.

Suggested change
Before:
for key in ['chosen', 'rejected']:
    count = total_loss_dict[f'{prefix}count/{key}']
    res[f'{prefix}logps/{key}'] = total_loss_dict[f'{prefix}logps/{key}_sum'] / count
    res[f'{prefix}rewards/{key}'] = total_loss_dict[f'{prefix}rewards/{key}_sum'] / count
res[f'{prefix}rewards/margins'] = res[f'{prefix}rewards/chosen'] - res[f'{prefix}rewards/rejected']

After:
for key in ['chosen', 'rejected']:
    count = total_loss_dict[f'{prefix}count/{key}']
    if count.item() > 0:
        res[f'{prefix}logps/{key}'] = total_loss_dict[f'{prefix}logps/{key}_sum'] / count
        res[f'{prefix}rewards/{key}'] = total_loss_dict[f'{prefix}rewards/{key}_sum'] / count
    else:
        res[f'{prefix}logps/{key}'] = torch.tensor(0.0, device=count.device)
        res[f'{prefix}rewards/{key}'] = torch.tensor(0.0, device=count.device)
res[f'{prefix}rewards/margins'] = res[f'{prefix}rewards/chosen'] - res[f'{prefix}rewards/rejected']

Jintao-Huang changed the title from "[WIP] [megatron] update megatron kto" to "[megatron] update megatron kto" on Oct 8, 2025
@Jintao-Huang
Collaborator Author

/gemini review

Contributor

Warning

Gemini encountered an error creating the review. You can try again by commenting /gemini review.

@Jintao-Huang
Collaborator Author

/gemini review

Contributor

gemini-code-assist (bot) left a comment


Code Review

This pull request introduces support for KTO (Kahneman-Tversky Optimization) within the Megatron framework and refactors the RLHF trainers for better structure and code reuse. The changes are well-organized, and the refactoring of the KTO trainer is a significant improvement.

I've identified a couple of areas in the documentation that could be improved for clarity and consistency with the code changes. Please see my detailed comments below.

Contributor


medium

The calculate_KL parameter seems to have been removed from the KTO parameters documentation. However, it's still a configurable parameter in the code (swift/megatron/argument/megatron_args.py). It's now optional and can be inferred, but users can still set it. It would be beneficial to document this parameter for clarity.

For example, you could add:

- calculate_KL: Whether to compute the KL divergence. Defaults to `None` and is inferred automatically from `loss_type`; for example, when `loss_type` is `'apo_zero_unpaired'`, `calculate_KL` is set to `False`, otherwise `True`.

Comment on lines 253 to +259
**KTO Parameters**:
- beta: Coefficient for the KL regularization term. Default is `0.1`.
- desirable_weight: Loss weight $\lambda_D$ for desirable response in the KTO algorithm, default is `1.`.
- undesirable_weight: Loss weight $\lambda_U$ for undesirable response in the KTO algorithm, default is `1.`.
- calculate_KL: Whether to calculate KL divergence. Default is `True`.
- ref_load: same meaning as in DPO.
- ref_adapter_load: same meaning as in DPO.
- beta: parameter controlling the deviation from the ref_model. Higher `beta` means less deviation from the ref_model. Default is `0.1`.
- loss_type: default is `'kto'`. See possible values in the TRL docs: https://huggingface.co/docs/trl/main/en/kto_trainer#trl.KTOConfig.loss_type.
- desirable_weight: factor to weight desirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`.
- undesirable_weight: factor to weight undesirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`.
Contributor


medium

The calculate_KL parameter seems to have been removed from the KTO parameters documentation. However, it's still a configurable parameter in the code (swift/megatron/argument/megatron_args.py). It's now optional and can be inferred, but users can still set it. It would be beneficial to document this parameter for clarity.

For example, you could add:

- calculate_KL: Whether to calculate KL divergence. Defaults to `None`, and will be inferred based on `loss_type`. For example, when `loss_type` is `'apo_zero_unpaired'`, `calculate_KL` will be set to `False`, otherwise `True`.
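
For reviewers' reference, a minimal sketch of how that inference from loss_type could be expressed. The field names and structure are assumptions; the actual layout of swift/megatron/argument/megatron_args.py may differ.

from dataclasses import dataclass
from typing import Optional

@dataclass
class KTOArguments:
    # Illustrative sketch only, not the actual ms-swift argument class.
    loss_type: str = 'kto'
    calculate_KL: Optional[bool] = None  # None means: infer from loss_type

    def __post_init__(self):
        if self.calculate_KL is None:
            # 'apo_zero_unpaired' has no KL term, so the extra reference-model
            # KL forward pass can be skipped in that case.
            self.calculate_KL = self.loss_type != 'apo_zero_unpaired'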

Jintao-Huang merged commit e686ae8 into modelscope:main on Oct 8, 2025
1 of 2 checks passed