
Add rl example#49

Merged
Yunnglin merged 24 commits into dev from add_rl_example on Feb 10, 2026

Conversation

@Yunnglin
Collaborator

No description provided.

Copilot AI review requested due to automatic review settings February 10, 2026 07:32
@gemini-code-assist
Contributor

Summary of Changes

Hello @Yunnglin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request primarily focuses on enhancing the system's capabilities by upgrading the default language model used in several examples and configurations to a more powerful version. It introduces a new, comprehensive reinforcement learning example for solving math problems, showcasing advanced training techniques. Additionally, significant backend improvements have been made to streamline model interaction, optimize resource management for LoRA adapters, and refine data processing for training, leading to a more robust and flexible framework.

Highlights

  • Model Upgrade: Upgraded the default model from 'Qwen2.5-0.5B-Instruct' to 'Qwen2.5-3B-Instruct' across various configurations and examples, enhancing model capabilities.
  • New RL Example: Introduced a new reinforcement learning example for the GSM8K math problem dataset, utilizing the GRPO algorithm for training and demonstrating client-side reward computation and weight synchronization for sampling.
  • Backend Unification & Improvements: Refactored the forward_backward logic in the Tinker server to unify handling for both Megatron and Transformers backends, ensuring consistent API usage. Also, enabled gradient checkpointing for MultiLoraTransformersModel and implemented automatic deletion of old sampler weights upon saving new ones.
  • Configuration Enhancements: Added mixed_precision: bf16 and adapter_config for LoRA management to Megatron model service configurations, providing more control over training precision and adapter lifecycle.
  • Datum Processing Refinement: Updated datum_to_input_feature and input_feature_to_datum to improve how weights are handled for labels, ensuring weights != 0 is used for masking and that weights are consistently present for loss calculations.
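The weight-handling convention in the last bullet can be sketched in isolation. The helpers below (`mask_labels`, `ensure_weights`) and the `IGNORE_INDEX` sentinel are hypothetical illustrations of the idea, not the actual functions in `datum.py`:

```python
IGNORE_INDEX = -100  # common sentinel for ignored label positions (assumed here)

def mask_labels(token_ids, weights):
    """Mask labels wherever the per-token weight is exactly zero.

    Checking `w != 0` (rather than a boolean mask) lets fractional
    weights still contribute to the loss, while zero-weight positions
    are excluded from label supervision entirely.
    """
    return [tok if w != 0 else IGNORE_INDEX
            for tok, w in zip(token_ids, weights)]

def ensure_weights(loss_fn_inputs, num_tokens):
    """Guarantee a 'weights' entry so downstream loss code can rely on it."""
    loss_fn_inputs.setdefault('weights', [1.0] * num_tokens)
    return loss_fn_inputs

labels = mask_labels([11, 12, 13, 14], [0.0, 1.0, 0.5, 0.0])
inputs = ensure_weights({}, 4)
```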


Changelog
  • cookbook/client/tinker/megatron/server_config.yaml
    • Updated model name and route prefix to 'Qwen2.5-3B-Instruct'.
    • Added 'mixed_precision: bf16' for training.
    • Included 'adapter_config' for LoRA management.
    • Added a new sampler service configuration for 'Qwen2.5-3B-Instruct'.
  • cookbook/client/tinker/transformer/grpo.py
    • Updated 'BASE_MODEL' to 'Qwen/Qwen2.5-3B-Instruct'.
    • Adjusted 'MAX_STEPS', 'BATCH_SIZE', and 'SYNC_INTERVAL' parameters.
  • cookbook/client/tinker/transformer/gsm8k.py
    • Added a new file implementing the GSM8K GRPO training example.
  • cookbook/client/tinker/transformer/lora.py
    • Updated 'base_model' references to 'Qwen/Qwen2.5-3B-Instruct'.
  • cookbook/client/tinker/transformer/sample.py
    • Updated 'base_model' to 'Qwen/Qwen2.5-3B-Instruct'.
  • cookbook/client/tinker/transformer/self_congnition.py
    • Updated 'base_model' to 'Qwen/Qwen2.5-3B-Instruct'.
  • cookbook/client/tinker/transformer/server_config.yaml
    • Updated model name, route prefix, and 'model_id' for both model and sampler services to 'Qwen2.5-3B-Instruct'.
  • cookbook/client/twinkle/megatron/lora.py
    • Removed the file.
  • cookbook/client/twinkle/megatron/server_config.yaml
    • Updated model name, route prefix, and 'model_id' to 'Qwen2.5-3B-Instruct'.
    • Added 'mixed_precision: bf16' and 'adapter_config'.
  • cookbook/client/twinkle/transformer/grpo.py
    • Updated 'MODEL_ID' to 'ms://Qwen/Qwen2.5-3B-Instruct'.
  • cookbook/client/twinkle/transformer/lora.py
    • Added 'use_megatron' flag for conditional logic.
    • Updated 'model_id' references to 'ms://Qwen/Qwen2.5-3B-Instruct'.
    • Changed optimizer from 'AdamW' to 'Adam' and conditionally applied LR scheduler based on 'use_megatron'.
  • cookbook/client/twinkle/transformer/sampler.py
    • Updated 'MODEL_ID' to 'Qwen/Qwen2.5-3B-Instruct'.
  • cookbook/client/twinkle/transformer/server_config.yaml
    • Updated model name, route prefix, and 'model_id' for both model and sampler services to 'Qwen2.5-3B-Instruct'.
  • src/twinkle/model/megatron/multi_lora_megatron.py
    • Changed 'MegatronPeft().patch()' to 'MegatronPeft()()' for consistency.
  • src/twinkle/model/transformers/multi_lora_transformers.py
    • Enabled 'gradient_checkpointing' for the model to optimize memory usage.
  • src/twinkle/server/tinker/common/datum.py
    • Modified label handling to use 'weights != 0' for masking.
    • Ensured 'weights' are always added to 'loss_fn_inputs' if not present.
    • Removed redundant 'TensorData' assignment.
  • src/twinkle/server/tinker/common/megatron_model.py
    • Imported 'extract_rl_feature' for RL-specific data processing.
    • Modified 'forward_backward' to set 'GRPOLoss' for 'importance_sampling' and extract RL features.
    • Unified 'forward_backward' call for Megatron backend.
  • src/twinkle/server/tinker/common/transformers_model.py
    • Modified 'forward_backward' to default to 'CrossEntropyLoss' if 'loss_fn' is not 'importance_sampling'.
  • src/twinkle/server/tinker/model.py
    • Changed default 'adapter_config' and 'queue_config' to empty dictionaries.
    • Stored 'base_model' in 'self.base_model' for easier access.
    • Passed 'model_id' to 'set_template' for proper template initialization.
    • Unified 'forward_backward' call logic for different backends.
    • Reordered save operations in 'save_weights_for_sampler' to save metadata first.
  • src/twinkle/server/tinker/server.py
    • Added 'Qwen/Qwen2.5-3B-Instruct' to the list of supported models.
  • src/twinkle/server/twinkle/model.py
    • Changed default 'adapter_config' to an empty dictionary.
  • src/twinkle/server/utils/io_utils.py
    • Added a logger for improved debugging and monitoring.
    • Changed 'base_path' to be absolute for consistent path resolution.
    • Implemented '_delete_existing_sampler_weights' to manage sampler checkpoints.
Activity
  • Yunnglin opened the pull request to add a reinforcement learning example and update model configurations.
  • Backend logic for model handling and checkpoint management has been refined.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a reinforcement learning example using GRPO for the GSM8K dataset and updates various configurations to support it. The changes include upgrading the model size in examples, adding a new sampler service, and refactoring data handling and loss calculation logic. My review focuses on the new gsm8k.py example, with suggestions to improve data representation and code clarity. The other changes appear correct and well-integrated.

Comment on lines +84 to +87
    return Trajectory(
        messages=messages,
        user_data=[('ground_truth', ground_truth)],
    )
Contributor


medium

To improve data representation and make trajectories more self-contained for debugging and logging, it's a good practice to include the original question in the user_data. This avoids using placeholders later when reconstructing the conversation history.

Suggested change
    return Trajectory(
        messages=messages,
        user_data=[('ground_truth', ground_truth)],
    )
→
    return Trajectory(
        messages=messages,
        user_data=[('ground_truth', ground_truth), ('question', question)],
    )

Comment on lines +121 to +129
    # Get ground truth from user_data
    gt = ''
    user_data = trajectory.get('user_data', [])
    if isinstance(user_data, list):
        for item in user_data:
            if isinstance(item, (list, tuple)) and len(item) == 2:
                if item[0] == 'ground_truth':
                    gt = str(item[1])
                    break
Contributor


medium

The logic for extracting ground_truth from user_data can be simplified. Converting the list of tuples to a dictionary first makes the code more readable and robust, especially if more items are added to user_data in the future.

Suggested change
    # Get ground truth from user_data
    gt = ''
    user_data = trajectory.get('user_data', [])
    if isinstance(user_data, list):
        for item in user_data:
            if isinstance(item, (list, tuple)) and len(item) == 2:
                if item[0] == 'ground_truth':
                    gt = str(item[1])
                    break
→
    # Get ground truth from user_data
    user_data_dict = dict(trajectory.get('user_data', []))
    gt = str(user_data_dict.get('ground_truth', ''))
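For illustration, the dict-based lookup works because `dict()` accepts a list of 2-tuples directly. The `trajectory` value below is a made-up example, not data from this repository:

```python
# Hypothetical trajectory mirroring the suggested user_data shape.
trajectory = {
    'messages': [],
    'user_data': [('ground_truth', '42'), ('question', 'What is 6 * 7?')],
}

# dict() over a list of 2-tuples gives keyed lookup instead of a manual
# loop over the pairs, and stays correct as more items are added.
user_data_dict = dict(trajectory.get('user_data', []))
gt = str(user_data_dict.get('ground_truth', ''))
question = user_data_dict.get('question', 'Math problem')
```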

Comment on lines +289 to +297
    # Use the corresponding user data for this sequence
    trajectories.append({
        'messages': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': 'Math problem'},  # Placeholder
            {'role': 'assistant', 'content': decoded_text}
        ],
        'user_data': all_user_data[idx]
    })
Contributor


medium

Instead of using a placeholder for the user's question, you can now retrieve it from the user_data to construct a complete and accurate trajectory. This makes the data representation more robust and easier to debug. This change depends on also adding the 'question' to the user_data in the GSM8KProcessor.

Suggested change
    # Use the corresponding user data for this sequence
    trajectories.append({
        'messages': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': 'Math problem'},  # Placeholder
            {'role': 'assistant', 'content': decoded_text}
        ],
        'user_data': all_user_data[idx]
    })
→
    # Use the corresponding user data for this sequence
    user_data_dict = dict(all_user_data[idx])
    question = user_data_dict.get('question', 'Math problem')  # Fallback to placeholder
    trajectories.append({
        'messages': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': question},
            {'role': 'assistant', 'content': decoded_text}
        ],
        'user_data': all_user_data[idx]
    })


Copilot AI left a comment


Pull request overview

Adds an RL (GRPO) client example and related server-side plumbing/config updates to support sampler-weight syncing and unified forward/backward behavior across Transformers and Megatron backends.

Changes:

  • Add a Tinker-compatible GSM8K GRPO training example that periodically saves weights for the sampler and then samples generations.
  • Unify Tinker model server forward_backward to pass loss_fn for both Megatron and Transformers compatibility wrappers; adjust datum conversion to ensure weights are present.
  • Update cookbook/server configs and defaults to use Qwen2.5-3B-Instruct and extend supported model list.

Reviewed changes

Copilot reviewed 20 out of 20 changed files in this pull request and generated 7 comments.

Summary per file:
src/twinkle/server/utils/io_utils.py Makes save dir absolute; adds sampler checkpoint cleanup before writing sampler metadata.
src/twinkle/server/twinkle/model.py Changes default adapter_config handling in build_model_app.
src/twinkle/server/tinker/server.py Adds Qwen2.5-3B-Instruct to the supported model list.
src/twinkle/server/tinker/model.py Unifies forward_backward call signature; updates sampler-weight save ordering; passes base model into template.
src/twinkle/server/tinker/common/transformers_model.py Adjusts loss selection logic in the Transformers compatibility wrapper.
src/twinkle/server/tinker/common/megatron_model.py Updates Megatron compatibility wrapper to accept loss_fn and RL extras.
src/twinkle/server/tinker/common/datum.py Adjusts masking/weights behavior and ensures weights exist for downstream computations.
src/twinkle/model/transformers/multi_lora_transformers.py Enables gradient checkpointing in MultiLoRA Transformers model initialization.
src/twinkle/model/megatron/multi_lora_megatron.py Changes how the Megatron PEFT patch is applied.
cookbook/client/twinkle/transformer/server_config.yaml Switches example config to Qwen2.5-3B-Instruct.
cookbook/client/twinkle/transformer/sampler.py Updates example MODEL_ID to Qwen2.5-3B-Instruct.
cookbook/client/twinkle/transformer/lora.py Updates base model, introduces use_megatron toggle, and adjusts optimizer/scheduler usage accordingly.
cookbook/client/twinkle/transformer/grpo.py Updates example MODEL_ID to Qwen2.5-3B-Instruct.
cookbook/client/twinkle/megatron/server_config.yaml Updates base model and adds adapter config + mixed precision.
cookbook/client/twinkle/megatron/lora.py Removes the prior Megatron LoRA training example script.
cookbook/client/tinker/transformer/server_config.yaml Switches example config to Qwen2.5-3B-Instruct.
cookbook/client/tinker/transformer/self_congnition.py Updates example base model to Qwen2.5-3B-Instruct.
cookbook/client/tinker/transformer/sample.py Updates example base model to Qwen2.5-3B-Instruct.
cookbook/client/tinker/transformer/lora.py Updates example base model and resume-path example for Qwen2.5-3B-Instruct.
cookbook/client/tinker/transformer/gsm8k.py Adds a new GSM8K GRPO training example using save-weights-for-sampler flow.
cookbook/client/tinker/transformer/grpo.py Updates base model and scales up example training loop parameters.
cookbook/client/tinker/megatron/server_config.yaml Updates base model, adds adapter config, and adds sampler service configuration.


Comment on lines 534 to 543
        model_id=body.model_id,
        is_sampler=True
    )

    # NOTE: Need to save meta first to ensure only one sample weight exists
    tinker_path = checkpoint_manager.save(
        body.model_id, name=checkpoint_name, is_sampler=True)

    logger.info(f"Saving weights to {save_dir}")
    # Save weights with save_optimizer=False for sampler use
    self.model.save(name=checkpoint_name,

Copilot AI Feb 10, 2026


save_weights_for_sampler writes checkpoint metadata (and deletes previous sampler checkpoints) before self.model.save(...) runs. This can leave a listed checkpoint with metadata but missing/partial weights if model.save fails, and size_bytes in metadata will be 0 because the directory is still empty at metadata time. Safer flow is: delete old sampler checkpoints (if needed) → save weights → write metadata (or write metadata last / mark it as incomplete and finalize after save succeeds).
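A minimal sketch of the safer ordering the comment describes, with weights written first and metadata last. `save_weights_fn` is a stand-in for the real `model.save(...)`; nothing here reflects the project's actual checkpoint API:

```python
import json
import os
import tempfile

def save_sampler_checkpoint(save_dir, name, save_weights_fn):
    """Write weights first, metadata last.

    If the weight save raises, no metadata has been written, so a
    listed checkpoint can never be missing its weights; size_bytes is
    computed after the weight files actually exist, so it is accurate.
    """
    ckpt_dir = os.path.join(save_dir, name)
    os.makedirs(ckpt_dir, exist_ok=True)
    save_weights_fn(ckpt_dir)  # may raise; metadata not yet written
    size = sum(os.path.getsize(os.path.join(ckpt_dir, f))
               for f in os.listdir(ckpt_dir))
    meta = {'name': name, 'size_bytes': size}
    with open(os.path.join(ckpt_dir, 'meta.json'), 'w') as fh:
        json.dump(meta, fh)  # metadata appears only after weights succeed
    return meta

# Demo with a throwaway directory and a fake weight writer.
tmp = tempfile.mkdtemp()

def fake_weights(path):
    with open(os.path.join(path, 'weights.bin'), 'wb') as fh:
        fh.write(b'\0' * 16)

meta = save_sampler_checkpoint(tmp, 'ckpt-0001', fake_weights)
```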

Comment on lines +142 to 143
    adapter_config: Dict[str, Any] = {},
    **kwargs):

Copilot AI Feb 10, 2026


adapter_config uses a mutable default ({}), which can be shared across requests/processes and lead to cross-request state leakage if mutated. Use adapter_config: Optional[Dict[str, Any]] = None (or Mapping) and normalize to {} inside the function.

Suggested change
    adapter_config: Dict[str, Any] = {},
    **kwargs):
→
    adapter_config: Optional[Dict[str, Any]] = None,
    **kwargs):
    # then, inside the function body:
    if adapter_config is None:
        adapter_config = {}
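The hazard behind this suggestion is easy to reproduce. The buggy/fixed pair below is a generic illustration of Python's mutable-default behavior, not the project's code:

```python
def build_app_buggy(adapter_config={}):
    # The SAME dict object is reused for every call that omits the
    # argument, so state from one call leaks into the next.
    adapter_config['calls'] = adapter_config.get('calls', 0) + 1
    return adapter_config

def build_app_fixed(adapter_config=None):
    # None default plus normalization gives each call a fresh dict.
    if adapter_config is None:
        adapter_config = {}
    adapter_config['calls'] = adapter_config.get('calls', 0) + 1
    return adapter_config

buggy_first = build_app_buggy()['calls']
buggy_second = build_app_buggy()['calls']   # leaked state accumulates
fixed_first = build_app_fixed()['calls']
fixed_second = build_app_fixed()['calls']   # fresh dict each call
```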

Comment on lines 39 to +41
    use_megatron: bool = False,
    adapter_config: Dict[str, Any] = None,
    queue_config: Optional[Dict[str, Any]] = None,
→
    use_megatron: bool = False,
    adapter_config: Dict[str, Any] = {},
    queue_config: Optional[Dict[str, Any]] = {},

Copilot AI Feb 10, 2026


adapter_config / queue_config use mutable defaults ({}). If any code mutates these dicts (directly or indirectly), the mutated value will be reused for subsequent app builds. Use None defaults and normalize inside build_model_app / ModelManagement.__init__.

Comment on lines +98 to +99
            super().set_loss('CrossEntropyLoss',
                             adapter_name=adapter_name)

Copilot AI Feb 10, 2026


For unknown loss_fn, this now silently falls back to CrossEntropyLoss. That can mask client/server mismatches and make debugging hard (and can produce incorrect training behavior). Prefer validating loss_fn and raising a clear error (or explicitly mapping supported values).

Suggested change
            super().set_loss('CrossEntropyLoss',
                             adapter_name=adapter_name)
→
            raise ValueError(
                f"Unsupported loss_fn '{loss_fn}'. Expected one of: 'cross_entropy', 'importance_sampling'."
            )

Comment on lines 85 to +89
        'logprobs' and 'elementwise_loss', and loss is a scalar.
        """
        if loss_fn == 'importance_sampling':
            super().set_loss('GRPOLoss',
                             adapter_name=adapter_name,

Copilot AI Feb 10, 2026


forward_backward only sets GRPOLoss when loss_fn == 'importance_sampling', but never resets the loss for other values. If a prior call set GRPOLoss, subsequent cross-entropy calls will keep using it and likely miscompute/require missing RL fields. Set CrossEntropyLoss (or the Megatron default loss) explicitly when loss_fn is not RL, and validate/raise on unsupported values.

        types.AdamParams(learning_rate=LEARNING_RATE))

    fwdbwd_result = fwdbwd_future.result()
    optim_result = optim_future.result()

Copilot AI Feb 10, 2026


Variable optim_result is not used.

Suggested change
    optim_result = optim_future.result()
→
    optim_future.result()

from twinkle.dataloader import DataLoader
from twinkle.preprocessor import Preprocessor
from twinkle.reward.base import Reward
from twinkle.data_format import Trajectory, InputFeature, Message

Copilot AI Feb 10, 2026


Import of 'InputFeature' is not used.

Suggested change
from twinkle.data_format import Trajectory, InputFeature, Message
→
from twinkle.data_format import Trajectory, Message

@Yunnglin Yunnglin merged commit c98fee9 into dev Feb 10, 2026
0 of 4 checks passed
@tastelikefeet tastelikefeet deleted the add_rl_example branch February 13, 2026 09:37