refactor: refactor train engine high level APIs #658
rchardx merged 12 commits into inclusionAI:main
Conversation
Signed-off-by: chenzhenyang <andy271828@163.com>
Summary of Changes

Hello @aaaandychen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request refactors the core API of the train engine.
Code Review
This pull request is a great step towards unifying the training engine's API by introducing a forward_backward_batch interface. The refactoring in areal/api/engine_api.py and areal/engine/megatron_engine.py is well-executed and significantly improves code structure and reduces duplication. However, the implementation for FSDPEngine in areal/engine/fsdp_engine.py doesn't fully adhere to the new API contract, which undermines the goal of a unified interface. I've left detailed comments on this, along with a few other potential bugs and areas for improvement.
```diff
 forward_step_counts = [0] * len(self.model)

 def forward_step(batch_iter, model):
     nonlocal forward_step_counts
-    batch = next(batch_iter)
-    model_vp_stage = getattr(model, "vp_stage", 0)
-    forward_step_count = forward_step_counts[model_vp_stage]
-    padding_length = mb_list.padding_lengths[forward_step_count]
-    orig_input = mb_list.mbs[forward_step_count]
-    cu_seqlens = batch["cu_seqlens"]
-    old_cu_seqlens = mb_list.old_cu_seqlens_list[forward_step_count]
-
-    forward_step_counts[model_vp_stage] += 1
-    output = packed_context_parallel_forward(model, batch)
-
-    if mpu.is_pipeline_last_stage(
-        ignore_virtual=False, vp_stage=model_vp_stage
-    ):
-        output = unpad_logits(
-            output,
-            padding_length=padding_length,
-            cu_seqlens=cu_seqlens,
-            old_cu_seqlens=old_cu_seqlens,
-        )
-
-    def _post_process_fn(input_, output):
-        loss = torch.tensor(1.0, device=output.device)
-        if post_hook is not None:
-            output = post_hook(output, input_)
-        return loss, {"output": output}
-
-    return output, functools.partial(_post_process_fn, orig_input)
+    batch_ctx = next(batch_iter)
+    return self._forward_compute_mb(
+        mb_input=batch_ctx,
+        loss_fn=loss_fn,
+        loss_weight_fn=loss_weight_fn,
+        model=model,
+        forward_step_counts=forward_step_counts
+    )
```
There's an issue in the forward_step function. The forward_step_counts variable is a remnant from a previous implementation and is no longer used, so it and its nonlocal declaration can be removed.
More importantly, output_post_hook and return_outputs are not being passed to _forward_compute_mb. This will cause incorrect behavior for forward_only=True cases, as _forward_compute_mb will not be able to apply the post-processing hook or know that it should return outputs instead of computing a loss.
```python
def forward_step(batch_iter, model):
    batch_ctx = next(batch_iter)
    return self._forward_compute_mb(
        mb_input=batch_ctx,
        loss_fn=loss_fn,
        loss_weight_fn=loss_weight_fn,
        model=model,
        post_hook=output_post_hook,
        return_output=return_outputs,
    )
```
areal/engine/fsdp_engine.py
Outdated
```python
assert total_loss_weight != 0
dist.all_reduce(total_loss_weight, group=self.dp_group)
```
Inviting @ChangyiYang and @zhaochenyang20 for review since you are familiar with the TrainEngine refactoring.

great job!
nuzant
left a comment
This PR still needs some refactoring. There are still some problems to be discussed about the API designs, and the code quality has room for improvement.
Additionally, please format the code according to the instructions in CONTRIBUTING.md and double-check the Gemini reviews.
We should also make sure all related tests (`areal/tests/test_fsdp_*.py` and `areal/tests/test_megatron_*.py`) can pass.
areal/api/engine_api.py
Outdated
```python
    """
    raise NotImplementedError()

def split_micro_batch(
```
Since the way we split the micro batches is not related to engines, we should not expose this API in this class. We can add an API in areal/utils/data.py to assist user-side data handling.
In areal/utils/data, I added create_mb_iterator to convert mb_list to an iterator, returning MB tuples on each iteration. This allows the engine to select specific MB elements for the iterator and wrap metadata for downstream use.
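A minimal sketch of what such a helper could look like. Only the name `create_mb_iterator` comes from the comment above; the `MicroBatchList` container and its field names here are illustrative assumptions, not the merged code:

```python
from dataclasses import dataclass, field
from typing import Any, Iterator

@dataclass
class MicroBatchList:
    """Assumed container for split micro batches (field names are illustrative)."""
    mbs: list[dict[str, Any]]
    padding_lengths: list[int] = field(default_factory=list)

def create_mb_iterator(mb_list: MicroBatchList) -> Iterator[tuple[dict, dict]]:
    """Yield (micro_batch, metadata) tuples so an engine can consume
    micro batches without knowing how they were split."""
    for i, mb in enumerate(mb_list.mbs):
        meta = {
            "index": i,
            "padding_length": mb_list.padding_lengths[i]
            if mb_list.padding_lengths else 0,
        }
        yield mb, meta
```

The engine side then only sees an iterator of `(micro_batch, metadata)` tuples, keeping the splitting policy in `areal/utils/data.py`.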
areal/engine/megatron_engine.py
Outdated
```python
    max_seqlen = data_iterator.max_seqlen
    num_microbatches = data_iterator.num_microbatches
else:
    max_seqlen = self.config.mb_spec.max_tokens_per_mb
```
Since forward_backward_batch is an API exposed to users, please provide a clear definition of the expected input `data_iterator`.
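One way to make that contract explicit is a small protocol class. The attribute names `max_seqlen` and `num_microbatches` are taken from the snippet above; everything else is an illustrative assumption:

```python
from typing import Any, Iterator, Protocol, runtime_checkable

@runtime_checkable
class MicroBatchIterator(Protocol):
    """Assumed shape of the `data_iterator` argument (sketch only)."""
    max_seqlen: int
    num_microbatches: int

    def __iter__(self) -> Iterator[dict[str, Any]]: ...

class ListMicroBatchIterator:
    """Trivial implementation backed by a list of micro batches."""
    def __init__(self, mbs: list[dict[str, Any]], max_seqlen: int):
        self.mbs = mbs
        self.max_seqlen = max_seqlen
        self.num_microbatches = len(mbs)

    def __iter__(self):
        return iter(self.mbs)
```

Documenting the input this way lets implementations fall back to `self.config.mb_spec.max_tokens_per_mb` only when no such iterator is supplied.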
areal/engine/fsdp_engine.py
Outdated
```python
loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
**kwargs,
) -> tuple[torch.Tensor, Callable[[torch.Tensor], tuple[torch.Tensor, dict]]]:
batch_type = kwargs.get("batch_type")
```
Same problem as the implementation in MegatronEngine. _forward_compute_mb should be an atomic operation that takes a micro batch as input and outputs logits or logprobs. There should not be a pile of if-else branches deciding what to output; process the output in forward_backward_batch instead.
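As a rough illustration of the separation being asked for (all names and signatures here are schematic, not the actual engine code):

```python
def _forward_compute_mb(model, mb):
    """Atomic per-micro-batch step: one micro batch in, raw model
    output out. No mode flags, no branching on what to return."""
    return model(mb["input_ids"])

def forward_backward_batch(model, mbs, loss_fn=None, post_hook=None):
    """All branching on how to use the output lives in the caller."""
    results = []
    for mb in mbs:
        out = _forward_compute_mb(model, mb)
        if loss_fn is not None:
            out = loss_fn(out, mb)        # train / eval path
        elif post_hook is not None:
            out = post_hook(out, mb)      # forward-only post-processing
        results.append(out)
    return results
```

Keeping `_forward_compute_mb` mode-free means each engine only has to implement the forward pass, while output handling stays in one place.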
Signed-off-by: chenzhenyang <andy271828@163.com>
Force-pushed from a197ede to b5eb5b7
Hi @nuzant, thank you so much for your patient review and thoughtful comments! I have modified the code and will finish resolving conflicts and testing later today.
Signed-off-by: chenzhenyang <chenzhenyang@moonshot.cn>
# Conflicts:
#	areal/engine/fsdp_engine.py
#	areal/engine/megatron_engine.py
Force-pushed from 5f8224f to 8a8b641
Signed-off-by: chenzhenyang <andy271828@163.com>
Force-pushed from 8a8b641 to 53dcd95
@nuzant I have fixed the code based on your feedback, resolved the conflicts, and ensured all related tests pass. Please review again at your convenience. To address the branch conflicts, the post-processing logic has been updated and the post_hook in forward_batch has been removed.
# Conflicts:
#	areal/engine/fsdp_engine.py
Signed-off-by: chenzhenyang <andy271828@163.com>
Signed-off-by: chenzhenyang <andy271828@163.com>
Force-pushed from b7d8a73 to 42c4f4e
Signed-off-by: chenzhenyang <andy271828@163.com>
Force-pushed from 5751e6e to 77a1aba
@rchardx Hi, I have incorporated your feedback. I've optimized the return-value retrieval by designing hook methods within the API. Additionally, I introduced a BaseTrainEngine to implement the Template Method pattern, which improves the overall usability of the design, and I have adapted the existing implementations accordingly.
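The Template Method arrangement described in the comment above can be sketched as follows; the class name `BaseTrainEngine` comes from the comment, but the hook names and bodies are illustrative assumptions:

```python
from abc import ABC, abstractmethod

class BaseTrainEngine(ABC):
    """Template method: the batch loop is fixed here, while backends
    (FSDP, Megatron, ...) fill in the per-micro-batch hooks."""

    def forward_backward_batch(self, mbs, forward_only=False):
        results = []
        for mb in mbs:
            out = self.forward_mb(mb)        # backend-specific hook
            if not forward_only:
                self.backward_mb(out)        # backend-specific hook
            results.append(self.post_process_mb(out))
        return results

    @abstractmethod
    def forward_mb(self, mb): ...

    @abstractmethod
    def backward_mb(self, out): ...

    def post_process_mb(self, out):
        """Hook with a default; subclasses may override."""
        return out
```

Each backend then only overrides the hooks, and the high-level loop stays in one place.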
rchardx
left a comment
Merge now, improve later. Delaying the merge would create more conflicts and complications. Follow-up PRs will address these quality issues.
Description
This PR refactors the top-level API implementation of TrainEngine and its subclasses (FSDPEngine, MegatronEngine).
Previously, the execution logic was fragmented across train_batch, forward_batch, and eval_batch. This PR introduces a unified forward_backward_batch interface to handle the execution flow. This change significantly reduces code duplication across different engines and provides greater flexibility for custom training loops.
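The shape of that unification can be sketched like this; the method names `train_batch`, `eval_batch`, `forward_batch`, and `forward_backward_batch` follow the description above, while the per-micro-batch hooks and bodies are illustrative assumptions:

```python
class TrainEngine:
    """Sketch: the high-level APIs all delegate to one execution path."""

    def forward_backward_batch(self, data_iterator, loss_fn=None,
                               forward_only=False):
        outputs = []
        for mb in data_iterator:
            out = self.forward_mb(mb)      # hypothetical per-mb forward hook
            if loss_fn is not None:
                out = loss_fn(out, mb)
            if not forward_only:
                self.backward_mb(out)      # hypothetical backward hook
            outputs.append(out)
        return outputs

    def train_batch(self, data_iterator, loss_fn):
        return self.forward_backward_batch(data_iterator, loss_fn,
                                           forward_only=False)

    def eval_batch(self, data_iterator, loss_fn):
        return self.forward_backward_batch(data_iterator, loss_fn,
                                           forward_only=True)

    def forward_batch(self, data_iterator):
        return self.forward_backward_batch(data_iterator,
                                           forward_only=True)
```

With this layout, engine subclasses only supply the per-micro-batch hooks instead of reimplementing three separate batch loops.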
Key Changes
Related Issue
related issue: #601