Conversation
Summary of Changes

Hello @Yunnglin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces comprehensive reinforcement learning examples, significantly refactoring the Twinkle sampler server to enhance its API, improve LoRA adapter management, and support diverse sampler backends. It also extends client-side sampling capabilities and refines dataset serialization for greater flexibility. Additionally, the server setup documentation has been updated to provide clearer guidance on Ray cluster configurations.
Activity
Code Review
This pull request introduces significant enhancements for reinforcement learning workflows, including new GRPO training examples for both Tinker and Twinkle clients, a major refactoring of the sampler server for improved robustness and functionality, and updated client-side APIs. The changes are extensive and well-structured. I've identified a few critical issues in the new example code and server implementation that need to be addressed, along with some suggestions for improving the documentation and code quality. Overall, this is a great step forward for the project.
```python
old_logps_list = []
completion_lengths = []

sequences = sample_response.get('sequences', [])
for seq in sequences:
    input_features.append(seq.get('new_input_feature', seq))
    old_logps_list.append(seq.get('logprobs', []))
    completion_lengths.append(len(seq.get('tokens', [])))

if not input_features:
    logger.warning(f"Step {step}: No valid samples, skipping")
    step += 1
    continue

# ========== 3. Compute rewards ==========
total_rewards, format_rewards, accuracy_rewards = compute_rewards(
    input_features)
metrics.accumulate(
```
There are a couple of issues in this block that will prevent the example from running correctly:

1. Incorrect `input_features` for `compute_rewards`: The `compute_rewards` function expects a list of trajectories (dictionaries with a `messages` key), but `input_features` is populated with dictionaries from the sampler's response (`{'tokens': ..., 'logprobs': ...}`). You need to decode the generated tokens into text and construct trajectory dictionaries, similar to the `tinker` example.
2. Incorrect `inputs` for `model.forward_backward`: The `model.forward_backward` method expects a list of `InputFeature` objects representing the full prompt + completion sequence. The current `input_features` list does not have the correct structure.
3. Missing tokenizer: To decode the tokens for reward calculation, a tokenizer is needed, but it's not initialized in this script.

I suggest restructuring this part of the training loop to correctly process the sampler's output. You'll need to initialize a tokenizer, use it to decode completions for reward calculation, and then construct new `InputFeature` objects for the training step by combining the prompt features with the generated sequences.
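To make the suggestion concrete, here is a minimal sketch of the decode-then-build-trajectories step. The helper name, the `decode` callable, and the trajectory shape (`{'messages': [...]}`) are illustrative assumptions based on the `compute_rewards` contract described above, not the repo's actual API:

```python
# Sketch: turn raw sampler sequences into trajectory dicts before reward
# computation. `decode` stands in for a real tokenizer's `tokenizer.decode`;
# the {'messages': [...]} shape follows the compute_rewards contract above.

def build_trajectories(sequences, prompts, decode):
    trajectories = []
    for prompt, seq in zip(prompts, sequences):
        completion_text = decode(seq.get('tokens', []))
        trajectories.append({
            'messages': [
                {'role': 'user', 'content': prompt},
                {'role': 'assistant', 'content': completion_text},
            ]
        })
    return trajectories
```

With a Hugging Face tokenizer this would be called as `build_trajectories(sequences, prompts, tokenizer.decode)` before the reward computation; the `InputFeature` objects for `forward_backward` would still need to be built separately by concatenating prompt and completion tokens.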
```python
if body.adapter_uri:
    from .common.io_utils import create_checkpoint_manager
    token = get_token_from_request(request)
```
The import path for `create_checkpoint_manager` appears to be incorrect. Based on the file structure, `io_utils.py` is in the `utils` directory, not a `common` subdirectory. This relative import will likely cause a `ModuleNotFoundError`.
Suggested change:

```diff
- from .common.io_utils import create_checkpoint_manager
+ from twinkle.server.utils.io_utils import create_checkpoint_manager
  token = get_token_from_request(request)
```
```python
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.metric import CompletionRewardMetric
from twinkle.server.tinker.common import input_feature_to_datum
```
```python
# or None to use the base model
# ADAPTER_URI = None
# Example:
ADAPTER_URI = "twinkle://20260208_224851-fa3cdd11-default/weights/twinkle-epoch-2"
```
```yaml
    max_replicas: 1
    target_ongoing_requests: 16
  ray_actor_options:
    num_cpus: 0.1
```
```bash
# Second GPU node, using GPUs 4-7 (4 GPUs in total)
CUDA_VISIBLE_DEVICES=4,5,6,7 ray start --address=10.28.252.9:6379 --num-gpus=4
```
```yaml
# The sampler service occupies Node 1 (worker node, GPUs 4-7)
- name: sampler-Qwen2.5-7B-Instruct
  route_prefix: /sampler/Qwen/Qwen2.5-7B-Instruct
```
The `route_prefix` here is `/sampler/...` (singular), but in `server_config.yaml` and the client implementations it's `/samplers/...` (plural). This should be corrected to `/samplers/...` for consistency.
```diff
- route_prefix: /sampler/Qwen/Qwen2.5-7B-Instruct
+ route_prefix: /samplers/Qwen/Qwen2.5-7B-Instruct
```
```python
def __del__(self):
    try:
        heartbeat_manager.unregister_processor(self.processor_id)
    except:
```
Using a bare `except:` clause is generally discouraged, as it can catch unexpected system-exiting exceptions (like `SystemExit` or `KeyboardInterrupt`), making it harder to debug or interrupt the program. It's better to catch a more specific exception, like `except Exception:`.
```diff
- except:
+ except Exception:
```
Pull request overview
This PR adds a client/server RL (GRPO) cookbook example and extends the sampler HTTP API and client wrappers to support LoRA adapter loading via `adapter_uri` and multi-sample generation (`num_samples`). It also updates checkpoint/serialization utilities to better support the HTTP-mode workflows used by the new examples.
Changes:

- Update the Twinkle sampler server and client to use `/samplers/...` routes and accept `adapter_uri` + `num_samples`.
- Add GRPO training and sampler cookbook examples; add a GRPO processor client wrapper.
- Add `data_slice` serialization support for `DatasetMeta` in HTTP mode and centralize adapter-URI parsing in the checkpoint manager.
Reviewed changes
Copilot reviewed 27 out of 27 changed files in this pull request and generated 24 comments.
| File | Description |
|---|---|
| src/twinkle_client/sampler/vllm_sampler.py | Update sampler client URL routing and request payload to include adapter_uri/num_samples. |
| client_tools/client_generator.py | Update sampler client code generation to match the new sampler route and request fields. |
| src/twinkle_client/processor/grpo.py | Add a new GRPO processor client wrapper for server-side preprocessing. |
| src/twinkle/server/twinkle/sampler.py | Major rewrite of the sampler service API using Pydantic request/response models + adapter lifecycle handling. |
| src/twinkle/server/utils/io_utils.py | Add parse_adapter_uri() helper to checkpoint manager to resolve LoRA adapter paths. |
| src/twinkle/server/twinkle/common/serialize.py | Support serializing/deserializing DatasetMeta.data_slice (e.g., range(...)) for HTTP mode. |
| src/twinkle/server/tinker/sampler.py | Switch to centralized parse_adapter_uri() implementation. |
| src/twinkle/model/transformers/multi_lora_transformers.py | Add active_group = None compatibility field. |
| src/twinkle/model/megatron/multi_lora_megatron.py | Add active_group = None compatibility field. |
| src/twinkle/infra/_ray/ray_helper.py | Adjust CPU worker env-var handling for Ray worker creation. |
| cookbook/client/twinkle/transformer/server_config.yaml | Update example server config and add a sampler service definition. |
| cookbook/client/twinkle/transformer/sampler.py | Add a Twinkle HTTP sampler inference example. |
| cookbook/client/twinkle/transformer/grpo.py | Add a Twinkle HTTP GRPO training example using model.save() + adapter_uri. |
| cookbook/client/tinker/transformer/grpo.py | Add a Tinker-compatible GRPO training example. |
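As an aside on the `data_slice` serialization change listed above: since `range(...)` is not JSON-serializable, the round-trip presumably looks something like the sketch below (the `__type__` tag and field names are illustrative, not the repo's actual wire format in `serialize.py`):

```python
# Sketch: round-tripping a range() through a JSON-friendly dict for HTTP mode.
# The '__type__' tag and key names are illustrative assumptions.

def serialize_data_slice(data_slice):
    if isinstance(data_slice, range):
        return {'__type__': 'range',
                'start': data_slice.start,
                'stop': data_slice.stop,
                'step': data_slice.step}
    return data_slice  # plain lists / None pass through unchanged


def deserialize_data_slice(obj):
    if isinstance(obj, dict) and obj.get('__type__') == 'range':
        return range(obj['start'], obj['stop'], obj['step'])
    return obj
```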
Comments suppressed due to low confidence (3)
client_tools/client_generator.py:760
- The generated `VLLMSampler.__init__` returns `response.json()`. Returning a non-`None` value from `__init__` raises a `TypeError` on instantiation, so any generated client will fail at runtime. Drop the `return response.json()` and instead just keep the response for validation / store needed values on `self`.
```python
        model_id = model_id.split('://')[1]
        self.server_url = f'{self.server_url}/samplers/{model_id}'
        response = http_post(
            url=f'{self.server_url}/create',
            json_data=kwargs
        )
        response.raise_for_status()

    def _send_adapter_heartbeat(self):
        """Internal method to send adapter heartbeat."""
        if not self.adapter_name:
            return
        response = http_post(
            url=f'{self.server_url}/heartbeat',
            json_data={'adapter_name': self.adapter_name}
        )
```
src/twinkle/server/twinkle/sampler.py:325
- `deploy_options` defaults to `None`, but is expanded with `**deploy_options` when calling `SamplerManagement.options(...)`. If `deploy_options` is omitted by the caller, this will raise a `TypeError`. Consider defaulting to an empty dict (e.g., `deploy_options = deploy_options or {}`) before using it.
```python
            nproc_per_node, device_group, device_mesh, sampler_type, engine_args, adapter_config, **kwargs)
```
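The suggested guard is a one-liner; sketched here with a hypothetical wrapper function (the real call site, `SamplerManagement.options(...)`, has a different signature that is not shown in the diff):

```python
# Sketch of the suggested guard: normalize None to {} before ** expansion.
# `build_deployment` is a hypothetical stand-in for the real call site.

def build_deployment(deploy_options=None, **kwargs):
    deploy_options = deploy_options or {}  # avoids TypeError on **None
    return dict(**deploy_options, **kwargs)
```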
src/twinkle_client/sampler/vllm_sampler.py:45
- `__init__` returns `response.json()`. In Python, `__init__` must return `None`; returning a dict will raise `TypeError: __init__() should return None, not 'dict'` when constructing `VLLMSampler`. This will break the new cookbook examples that instantiate `VLLMSampler`. Remove the return value and store any needed fields on `self` instead.
```python
        self.adapter_name = None
        if '://' in model_id:
            model_id = model_id.split('://')[1]
        self.server_url = f'{self.server_url}/samplers/{model_id}'
        response = http_post(
            url=f'{self.server_url}/create',
            json_data=kwargs
        )
        response.raise_for_status()

    def _send_adapter_heartbeat(self):
        """Internal method to send adapter heartbeat."""
        if not self.adapter_name:
            return
        response = http_post(
            url=f'{self.server_url}/heartbeat',
```
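A sketch of the corrected constructor, with `http_post` injected as a parameter so the example stands alone (the real client imports it; the class name, parameters, and the `create_info` attribute are illustrative):

```python
# Sketch: keep __init__ returning None and store the create-response on self.
# `http_post` is passed in here so the sketch is self-contained; the real
# client imports it instead.

class VLLMSamplerClient:
    def __init__(self, server_url, model_id, http_post, **kwargs):
        self.adapter_name = None
        if '://' in model_id:
            model_id = model_id.split('://')[1]
        self.server_url = f'{server_url}/samplers/{model_id}'
        response = http_post(url=f'{self.server_url}/create', json_data=kwargs)
        response.raise_for_status()
        # Store instead of `return response.json()` -- __init__ must return None.
        self.create_info = response.json()
```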
```diff
 from twinkle import DeviceGroup, DeviceMesh
-from twinkle.data_format import Trajectory, InputFeature
 from twinkle.sampler import vLLMSampler
-from twinkle.server.utils.validation import verify_request_token
+from twinkle.data_format import Trajectory, InputFeature, SamplingParams
+from twinkle.server.utils.adapter_manager import AdapterManagerMixin
+from twinkle.server.utils.validation import verify_request_token, get_token_from_request
```
Import path looks incorrect: `twinkle.sampler.types` does not exist in this repo (the sampling dataclasses live under `twinkle.data_format.sampling`). As-is, the sampler service will fail to import at startup. Update these imports to the correct module (and drop unused ones if needed).
```python
sequences.append({
    'stop_reason': seq.stop_reason,
    'tokens': list(seq.tokens),
    'logprobs': list(seq.logprobs) if seq.logprobs is not None else None,
```
The HTTP response currently strips `decoded` and `new_input_feature` from `SampledSequence`. The new GRPO example (and existing sampler code) expects `new_input_feature` to be present to feed sampled continuations back into training. Either include these fields in the response (and in `SampleResponseModel`), or update the cookbook/examples to not rely on them.
```diff
 'logprobs': list(seq.logprobs) if seq.logprobs is not None else None,
+'decoded': getattr(seq, 'decoded', None),
+'new_input_feature': getattr(seq, 'new_input_feature', None),
```
```diff
     num_samples: int = 1,
 ) -> Dict[str, Any]:
     """Sample from the model.

     Args:
         inputs: List of Trajectory or InputFeature to sample from.
         sampling_params: Sampling parameters dict.
-        adapter_name: Adapter name.
+        adapter_name: Adapter name for LoRA inference.
```
This client class subclasses `twinkle.sampler.base.Sampler`, whose abstract `sample()` contract returns `SampleResponse`. Changing the override to return `Dict[str, Any]` breaks the base-class type contract and can confuse users and type checkers. Consider either (a) not inheriting from `Sampler` for HTTP clients, or (b) returning/constructing a `SampleResponse` object on the client side to preserve the API.
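A sketch of option (b), rebuilding a typed response client-side. The dataclasses below are illustrative stand-ins; the repo's real `SampleResponse`/`SampledSequence` definitions may have different fields:

```python
# Sketch: convert the sampler's HTTP JSON payload back into typed objects,
# so the client-side sample() can keep the base-class return contract.
# These dataclasses are stand-ins, not the repo's actual definitions.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class SampledSequence:
    stop_reason: str
    tokens: List[int]
    logprobs: Optional[List[float]]


@dataclass
class SampleResponse:
    sequences: List[SampledSequence]


def parse_sample_response(payload: Dict[str, Any]) -> SampleResponse:
    """Rebuild a SampleResponse from the /sample endpoint's JSON body."""
    return SampleResponse(sequences=[
        SampledSequence(stop_reason=s.get('stop_reason', ''),
                        tokens=list(s.get('tokens', [])),
                        logprobs=s.get('logprobs'))
        for s in payload.get('sequences', [])
    ])
```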
```python
MODEL_ID = 'ms://Qwen/Qwen2.5-0.5B-Instruct'
NUM_GENERATIONS = 8
```
This GRPO example uses `MODEL_ID = 'ms://Qwen/Qwen2.5-3B-Instruct'`, but the provided `server_config.yaml` in the same directory deploys Qwen2.5-0.5B-Instruct. Unless the server config is updated accordingly, the model/sampler routes won't exist for 3B. Consider aligning the example with the shipped config, or add a clear note that the config must be changed to match `MODEL_ID`.
```python
self.sampler.add_adapter_to_sampler(full_adapter_name, config)

self.register_adapter(full_adapter_name, token)
allowed, reason = self.check_adapter_limit(token, True)
if not allowed:
    raise RuntimeError(reason)

return AddAdapterResponse(adapter_name=full_adapter_name)
```
Per-token adapter limit enforcement happens after the adapter is added to the sampler and registered. If the limit is exceeded, this raises but leaves the adapter loaded/registered (a resource leak and a limit bypass). Call `check_adapter_limit(token, True)` before adding/registering, or roll back (remove the adapter and unregister) on failure.
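The reordering could look like the following sketch; the sampler and limit objects are hypothetical stand-ins whose method names follow the snippet above:

```python
# Sketch: check the per-token adapter limit *before* mutating any state,
# so a rejected request leaves nothing loaded or registered.
# `sampler` and `limits` are hypothetical stand-ins.

def add_adapter(sampler, limits, adapter_name, token, config):
    allowed, reason = limits.check_adapter_limit(token)
    if not allowed:
        raise RuntimeError(reason)  # nothing to roll back yet
    sampler.add_adapter_to_sampler(adapter_name, config)
    try:
        limits.register_adapter(adapter_name, token)
    except Exception:
        # Roll back the sampler-side load if registration fails.
        sampler.remove_adapter_from_sampler(adapter_name)
        raise
    return adapter_name
```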
```python
from twinkle.sampler import vLLMSampler
from twinkle.server.utils.validation import verify_request_token
from twinkle.data_format import Trajectory, InputFeature, SamplingParams
from twinkle.server.utils.adapter_manager import AdapterManagerMixin
```
Import of `SampleResponse` is not used.

Import of `SampledSequence` is not used.
```python
from twinkle.server.utils.adapter_manager import AdapterManagerMixin
```
```python
from twinkle.data_format import InputFeature
from .base import InputProcessor


class GRPOLossProcessor(InputProcessor):
```
This class does not call `InputProcessor.__init__` during initialization. (`GRPOLossProcessor.__init__` may be missing a call to the base class `__init__`.)
```python
from twinkle.data_format import InputFeature
from .base import InputProcessor


class GRPOLossProcessor(InputProcessor):
```
This class does not call `InputProcessor.__del__` during finalization. (`GRPOLossProcessor.__del__` may be missing a call to the base class `__del__`.)
```python
def __del__(self):
    try:
        heartbeat_manager.unregister_processor(self.processor_id)
    except:
```
This bare `except` block directly handles `BaseException`.
```diff
- except:
+ except Exception:
```
```python
def __del__(self):
    try:
        heartbeat_manager.unregister_processor(self.processor_id)
    except:
```
The `except` clause does nothing but `pass`, and there is no explanatory comment.