**Summary of Changes**

Hello @Yunnglin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request primarily focuses on enhancing the system's capabilities by upgrading the default language model used in several examples and configurations to a more powerful version. It introduces a new, comprehensive reinforcement learning example for solving math problems, showcasing advanced training techniques. Additionally, significant backend improvements have been made to streamline model interaction, optimize resource management for LoRA adapters, and refine data processing for training, leading to a more robust and flexible framework.
Code Review
This pull request introduces a reinforcement learning example using GRPO for the GSM8K dataset and updates various configurations to support it. The changes include upgrading the model size in examples, adding a new sampler service, and refactoring data handling and loss calculation logic. My review focuses on the new gsm8k.py example, with suggestions to improve data representation and code clarity. The other changes appear correct and well-integrated.
```python
return Trajectory(
    messages=messages,
    user_data=[('ground_truth', ground_truth)],
)
```
To improve data representation and make trajectories more self-contained for debugging and logging, it's a good practice to include the original question in the user_data. This avoids using placeholders later when reconstructing the conversation history.
```diff
 return Trajectory(
     messages=messages,
-    user_data=[('ground_truth', ground_truth)],
+    user_data=[('ground_truth', ground_truth), ('question', question)],
 )
```
```python
# Get ground truth from user_data
gt = ''
user_data = trajectory.get('user_data', [])
if isinstance(user_data, list):
    for item in user_data:
        if isinstance(item, (list, tuple)) and len(item) == 2:
            if item[0] == 'ground_truth':
                gt = str(item[1])
                break
```
The logic for extracting ground_truth from user_data can be simplified. Converting the list of tuples to a dictionary first makes the code more readable and robust, especially if more items are added to user_data in the future.
```diff
 # Get ground truth from user_data
-gt = ''
-user_data = trajectory.get('user_data', [])
-if isinstance(user_data, list):
-    for item in user_data:
-        if isinstance(item, (list, tuple)) and len(item) == 2:
-            if item[0] == 'ground_truth':
-                gt = str(item[1])
-                break
+user_data_dict = dict(trajectory.get('user_data', []))
+gt = str(user_data_dict.get('ground_truth', ''))
```
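As a standalone illustration of why the suggested `dict(...)` conversion works (sample values here are hypothetical, not from the actual dataset), `dict()` accepts any iterable of 2-item pairs, which is exactly the shape of `user_data`:

```python
# user_data is stored on the trajectory as a list of (key, value) pairs.
user_data = [('ground_truth', '42'), ('question', 'What is 6 * 7?')]

# dict() consumes any iterable of 2-item pairs directly.
user_data_dict = dict(user_data)

gt = str(user_data_dict.get('ground_truth', ''))
question = user_data_dict.get('question', 'Math problem')  # fallback placeholder

print(gt)        # → 42
print(question)  # → What is 6 * 7?
```

The `.get(..., default)` lookups also make the code robust to missing keys without the explicit `isinstance` checks of the original loop.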
```python
# Use the corresponding user data for this sequence
trajectories.append({
    'messages': [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': 'Math problem'},  # Placeholder
        {'role': 'assistant', 'content': decoded_text}
    ],
    'user_data': all_user_data[idx]
})
```
Instead of using a placeholder for the user's question, you can now retrieve it from the user_data to construct a complete and accurate trajectory. This makes the data representation more robust and easier to debug. This change depends on also adding the 'question' to the user_data in the GSM8KProcessor.
```diff
 # Use the corresponding user data for this sequence
+user_data_dict = dict(all_user_data[idx])
+question = user_data_dict.get('question', 'Math problem')  # Fallback to placeholder
 trajectories.append({
     'messages': [
         {'role': 'system', 'content': SYSTEM_PROMPT},
-        {'role': 'user', 'content': 'Math problem'},  # Placeholder
+        {'role': 'user', 'content': question},
         {'role': 'assistant', 'content': decoded_text}
     ],
     'user_data': all_user_data[idx]
 })
```
Pull request overview
Adds an RL (GRPO) client example and related server-side plumbing/config updates to support sampler-weight syncing and unified forward/backward behavior across Transformers and Megatron backends.
Changes:
- Add a Tinker-compatible GSM8K GRPO training example that periodically saves weights for the sampler and then samples generations.
- Unify the Tinker model server `forward_backward` to pass `loss_fn` through both the Megatron and Transformers compatibility wrappers; adjust datum conversion to ensure weights are present.
- Update cookbook/server configs and defaults to use Qwen2.5-3B-Instruct and extend the supported model list.
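For context on the GSM8K example the changes support, a minimal sketch of a GSM8K-style binary reward (this is an illustration, not the actual code in `gsm8k.py`; GSM8K reference solutions end with `#### <answer>`):

```python
import re

def extract_gsm8k_answer(solution: str) -> str:
    """GSM8K reference solutions end with '#### <answer>'; return that answer."""
    match = re.search(r'####\s*(-?[\d,]+(?:\.\d+)?)', solution)
    return match.group(1).replace(',', '') if match else ''

def extract_last_number(completion: str) -> str:
    """Take the last number in the model's completion as its predicted answer."""
    numbers = re.findall(r'-?\d[\d,]*(?:\.\d+)?', completion)
    return numbers[-1].replace(',', '') if numbers else ''

def reward(completion: str, ground_truth: str) -> float:
    """Binary reward: 1.0 iff the extracted answers match exactly."""
    gt = extract_gsm8k_answer(ground_truth)
    return 1.0 if gt and extract_last_number(completion) == gt else 0.0

print(reward('The answer is 42.', 'Six times seven is 42.\n#### 42'))  # → 1.0
print(reward('I think it is 41.', '#### 42'))                          # → 0.0
```

A reward like this is what GRPO's group-relative advantages are computed from, which is why the `ground_truth` entry in `user_data` must survive the round trip through the sampler.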
Reviewed changes
Copilot reviewed 20 out of 20 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| src/twinkle/server/utils/io_utils.py | Makes save dir absolute; adds sampler checkpoint cleanup before writing sampler metadata. |
| src/twinkle/server/twinkle/model.py | Changes default adapter_config handling in build_model_app. |
| src/twinkle/server/tinker/server.py | Adds Qwen2.5-3B-Instruct to the supported model list. |
| src/twinkle/server/tinker/model.py | Unifies forward_backward call signature; updates sampler-weight save ordering; passes base model into template. |
| src/twinkle/server/tinker/common/transformers_model.py | Adjusts loss selection logic in the Transformers compatibility wrapper. |
| src/twinkle/server/tinker/common/megatron_model.py | Updates Megatron compatibility wrapper to accept loss_fn and RL extras. |
| src/twinkle/server/tinker/common/datum.py | Adjusts masking/weights behavior and ensures weights exist for downstream computations. |
| src/twinkle/model/transformers/multi_lora_transformers.py | Enables gradient checkpointing in MultiLoRA Transformers model initialization. |
| src/twinkle/model/megatron/multi_lora_megatron.py | Changes how the Megatron PEFT patch is applied. |
| cookbook/client/twinkle/transformer/server_config.yaml | Switches example config to Qwen2.5-3B-Instruct. |
| cookbook/client/twinkle/transformer/sampler.py | Updates example MODEL_ID to Qwen2.5-3B-Instruct. |
| cookbook/client/twinkle/transformer/lora.py | Updates base model, introduces use_megatron toggle, and adjusts optimizer/scheduler usage accordingly. |
| cookbook/client/twinkle/transformer/grpo.py | Updates example MODEL_ID to Qwen2.5-3B-Instruct. |
| cookbook/client/twinkle/megatron/server_config.yaml | Updates base model and adds adapter config + mixed precision. |
| cookbook/client/twinkle/megatron/lora.py | Removes the prior Megatron LoRA training example script. |
| cookbook/client/tinker/transformer/server_config.yaml | Switches example config to Qwen2.5-3B-Instruct. |
| cookbook/client/tinker/transformer/self_congnition.py | Updates example base model to Qwen2.5-3B-Instruct. |
| cookbook/client/tinker/transformer/sample.py | Updates example base model to Qwen2.5-3B-Instruct. |
| cookbook/client/tinker/transformer/lora.py | Updates example base model and resume-path example for Qwen2.5-3B-Instruct. |
| cookbook/client/tinker/transformer/gsm8k.py | Adds a new GSM8K GRPO training example using save-weights-for-sampler flow. |
| cookbook/client/tinker/transformer/grpo.py | Updates base model and scales up example training loop parameters. |
| cookbook/client/tinker/megatron/server_config.yaml | Updates base model, adds adapter config, and adds sampler service configuration. |
```python
    model_id=body.model_id,
    is_sampler=True
)

# NOTE: Need to save meta first to ensure only one sample weight exists
tinker_path = checkpoint_manager.save(
    body.model_id, name=checkpoint_name, is_sampler=True)

logger.info(f"Saving weights to {save_dir}")
# Save weights with save_optimizer=False for sampler use
self.model.save(name=checkpoint_name,
```
save_weights_for_sampler writes checkpoint metadata (and deletes previous sampler checkpoints) before self.model.save(...) runs. This can leave a listed checkpoint with metadata but missing/partial weights if model.save fails, and size_bytes in metadata will be 0 because the directory is still empty at metadata time. Safer flow is: delete old sampler checkpoints (if needed) → save weights → write metadata (or write metadata last / mark it as incomplete and finalize after save succeeds).
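The safer ordering can be sketched as follows (a minimal illustration of the suggested flow; the paths, metadata schema, and function name are hypothetical, not the actual checkpoint manager API):

```python
import json
import shutil
import tempfile
from pathlib import Path

def save_sampler_checkpoint(save_dir: Path, weights: bytes) -> dict:
    """Write weights first and metadata last, so a listed checkpoint always has weights."""
    # 1) Remove any previous sampler checkpoint so only one exists.
    if save_dir.exists():
        shutil.rmtree(save_dir)
    save_dir.mkdir(parents=True)

    # 2) Save the weights; if this raises, no metadata is ever written,
    #    so no half-finished checkpoint appears in listings.
    (save_dir / 'adapter_model.bin').write_bytes(weights)

    # 3) Write metadata last, with size_bytes computed from the real files.
    size_bytes = sum(p.stat().st_size for p in save_dir.rglob('*') if p.is_file())
    meta = {'size_bytes': size_bytes, 'complete': True}
    (save_dir / 'metadata.json').write_text(json.dumps(meta))
    return meta

# Usage sketch in a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    meta = save_sampler_checkpoint(Path(tmp) / 'sampler_ckpt', b'\x00' * 128)
    print(meta['size_bytes'])  # → 128
```

Because the size is computed after the weights land on disk, `size_bytes` reflects the real payload instead of an empty directory.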
```python
adapter_config: Dict[str, Any] = {},
**kwargs):
```
adapter_config uses a mutable default ({}), which can be shared across requests/processes and lead to cross-request state leakage if mutated. Use adapter_config: Optional[Dict[str, Any]] = None (or Mapping) and normalize to {} inside the function.
```diff
-adapter_config: Dict[str, Any] = {},
-**kwargs):
+adapter_config: Optional[Dict[str, Any]] = None,
+**kwargs):
+    if adapter_config is None:
+        adapter_config = {}
```
```diff
 use_megatron: bool = False,
-adapter_config: Dict[str, Any] = None,
-queue_config: Optional[Dict[str, Any]] = None,
+adapter_config: Dict[str, Any] = {},
+queue_config: Optional[Dict[str, Any]] = {},
```
adapter_config / queue_config use mutable defaults ({}). If any code mutates these dicts (directly or indirectly), the mutated value will be reused for subsequent app builds. Use None defaults and normalize inside build_model_app / ModelManagement.__init__.
```python
super().set_loss('CrossEntropyLoss',
                 adapter_name=adapter_name)
```
For unknown loss_fn, this now silently falls back to CrossEntropyLoss. That can mask client/server mismatches and make debugging hard (and can produce incorrect training behavior). Prefer validating loss_fn and raising a clear error (or explicitly mapping supported values).
```diff
-super().set_loss('CrossEntropyLoss',
-                 adapter_name=adapter_name)
+raise ValueError(
+    f"Unsupported loss_fn '{loss_fn}'. Expected one of: 'cross_entropy', 'importance_sampling'."
+)
```
```python
    'logprobs' and 'elementwise_loss', and loss is a scalar.
    """
    if loss_fn == 'importance_sampling':
        super().set_loss('GRPOLoss',
                        adapter_name=adapter_name,
```
forward_backward only sets GRPOLoss when loss_fn == 'importance_sampling', but never resets the loss for other values. If a prior call set GRPOLoss, subsequent cross-entropy calls will keep using it and likely miscompute/require missing RL fields. Set CrossEntropyLoss (or the Megatron default loss) explicitly when loss_fn is not RL, and validate/raise on unsupported values.
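One way to address both this and the silent-fallback comment above (a hedged sketch; the loss class names follow the snippets quoted in this review, but the mapping itself is hypothetical) is to resolve the loss explicitly on every call and raise on unknown values:

```python
# Map every supported client-side loss_fn to a server-side loss class,
# so the loss is re-selected on each forward_backward call instead of
# lingering from a previous RL call.
LOSS_FN_MAP = {
    'cross_entropy': 'CrossEntropyLoss',
    'importance_sampling': 'GRPOLoss',
}

def resolve_loss(loss_fn: str) -> str:
    try:
        return LOSS_FN_MAP[loss_fn]
    except KeyError:
        raise ValueError(
            f"Unsupported loss_fn '{loss_fn}'. "
            f"Expected one of: {sorted(LOSS_FN_MAP)}.") from None

print(resolve_loss('cross_entropy'))        # → CrossEntropyLoss
print(resolve_loss('importance_sampling'))  # → GRPOLoss
```

With this shape, `forward_backward` would call `set_loss(resolve_loss(loss_fn), ...)` unconditionally, so a cross-entropy call after an RL call always resets the loss, and a typo in `loss_fn` fails loudly instead of training with the wrong objective.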
```python
                          types.AdamParams(learning_rate=LEARNING_RATE))

fwdbwd_result = fwdbwd_future.result()
optim_result = optim_future.result()
```
Variable optim_result is not used.
```diff
-optim_result = optim_future.result()
+optim_future.result()
```
```python
from twinkle.dataloader import DataLoader
from twinkle.preprocessor import Preprocessor
from twinkle.reward.base import Reward
from twinkle.data_format import Trajectory, InputFeature, Message
```
Import of 'InputFeature' is not used.
```diff
-from twinkle.data_format import Trajectory, InputFeature, Message
+from twinkle.data_format import Trajectory, Message
```