In [1]:
%load_ext autoreload
%autoreload 2
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '4'

In [None]:

import itertools
import random
from copy import deepcopy
from datasets import Dataset
from pprint import pprint
combination_length = 3
max_attempts = 12

vocab = "qawsedrftgyhujik" # 16 chars.
combinations = list(itertools.permutations(list(vocab), combination_length))
rand = random.Random(42)
rand.shuffle(combinations)
test_combinations = deepcopy(combinations[:10])
rand.shuffle(combinations)
def create_data(split, idx, solution, format="tool", is_instruct_model=True):
    if "interaction" in format:
        position_str = ", ".join(f'Position {i + 1}' for i in range(combination_length))
        example_str = ", ".join(f"'char {i + 1}'" for i in range(combination_length))
        system_prompt = (f"You will determine the combination for [{position_str}] in a {combination_length}-character combination lock through iterative reasoning and queries.\n"
                            f"The lock requires a {combination_length}-character combination. All {combination_length} characters are unique.\n"
                            f"The set of valid characters are as follows: {list(vocab)}\n"
                            "You get feedback from the user on the presence of a character. Either not in the lock, in the lock but in a different position, or in the lock and in the right position.\n"
                            f"You have {max_attempts} queries.\n"
                            "Your goal is to find the correct combination in the least number of queries.\n"
                            f"To query the lock, you will first think step by step, and then generate a query formatted as a list of {combination_length} characters inside <action> ... </action>, e.g., <action>[{example_str}]</action>.\n")
                            # "Ensure you don't put '?' in your queries. Make sure you submit vocabulary characters in your query.\n"
                            # "At each step, you can see the instructions, the current belief state in <belief> ... </belief>, the latest action in <action> ... </action>, and the returned feedback in <feedback> ... </feedback>.\n"
                            # "First think step by step then"
                            #f"Each query should be formatted as a python list: [{example_str}].\n")
        if format == "interaction_base":
            system_prompt = system_prompt.strip()
        else:
            raise Exception(f"invalid format {format}")
        
        prompt_list = [ {
                        "role": "system",
                        "content": system_prompt,
                    } ]
        if is_instruct_model:
            prompt_list.append({
                "role": "user",
                "content": "Give me your first guess.",
            })
        return {
                "data_source": "multi_turn_combo_lock",
                "prompt": prompt_list,
                "ability": "math",
                "reward_model": {"style": "rule", "ground_truth": solution},
                "extra_info": {
                    "split": split,
                    "index": idx,
                    "answer": solution,
                    "interaction_kwargs": {
                        "combination_length": combination_length,
                        "max_attempts": max_attempts,
                        "vocab": vocab,
                        "ground_truth": solution,
                        "format": format,
                    },
                },
            }
  
format = "interaction_base"
is_instruct_model = False
train_data = [create_data('train', idx, combo_sol, format, is_instruct_model=is_instruct_model) for idx, combo_sol in enumerate(combinations)]
test_data = [create_data('test', idx, combo_sol, format, is_instruct_model=is_instruct_model) for idx, combo_sol in enumerate(test_combinations)]

Dataset.from_list(train_data).to_parquet(f"../data/multi_turn_combo_lock_{format+('_base' if not is_instruct_model else '')}/train.parquet")
Dataset.from_list(test_data).to_parquet(f"../data/multi_turn_combo_lock_{format+('_base' if not is_instruct_model else '')}/test.parquet")
pprint(test_data[0])
# print(test_data[0]['prompt'][0]['content'])

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

{'ability': 'math',
 'data_source': 'multi_turn_combo_lock',
 'extra_info': {'answer': ('d', 'h', 'y'),
                'index': 0,
                'interaction_kwargs': {'combination_length': 3,
                                       'format': 'interaction_base',
                                       'ground_truth': ('d', 'h', 'y'),
                                       'max_attempts': 12,
                                       'vocab': 'qawsedrftgyhujik'},
                'split': 'test'},
 'prompt': [{'content': 'You will determine the combination for [Position 1, '
                        'Position 2, Position 3] in a 3-character combination '
                        'lock through iterative reasoning and queries.\n'
                        'The lock requires a 3-character combination. All 3 '
                        'characters are unique.\n'
                        "The set of valid characters are as follows: ['q', "
                        "'a', 'w', 's', 'e', 'd', 'r', 'f', 't

In [2]:
from hydra import initialize, compose
from omegaconf import OmegaConf

# Use this when you want to use the config outside @hydra.main
def get_config():
    overrides = [
        "algorithm.adv_estimator=grpo",
        "data.train_batch_size=512",
        "data.max_prompt_length=1024",
        "data.max_response_length=1024",  # MAX_RESP
        "data.filter_overlong_prompts=True",
        "data.truncation=error",
        "data.return_raw_chat=True",
        "actor_rollout_ref.model.path=Qwen/Qwen2.5-3B",
        "actor_rollout_ref.model.use_remove_padding=True",
        "+actor_rollout_ref.model.enable_activation_offloading=True",
        "actor_rollout_ref.actor.optim.lr=1e-6",
        "actor_rollout_ref.actor.ppo_mini_batch_size=512",
        "actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8",
        "actor_rollout_ref.actor.use_kl_loss=True",
        "actor_rollout_ref.actor.kl_loss_coef=0.001",
        "actor_rollout_ref.actor.kl_loss_type=low_var_kl",
        "actor_rollout_ref.actor.entropy_coeff=0",
        "actor_rollout_ref.actor.fsdp_config.param_offload=False",
        "actor_rollout_ref.actor.fsdp_config.optimizer_offload=False",
        "+actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16",
        "actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8",
        "actor_rollout_ref.rollout.tensor_model_parallel_size=2",
        "actor_rollout_ref.rollout.name=sglang",
        "actor_rollout_ref.rollout.gpu_memory_utilization=0.5",
        "actor_rollout_ref.rollout.n=1",
        "actor_rollout_ref.rollout.temperature=0.5",
        "actor_rollout_ref.rollout.is_instruct_model=False",
        "actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8",
        "actor_rollout_ref.ref.fsdp_config.param_offload=False",
        "algorithm.use_kl_in_reward=False",
        "trainer.critic_warmup=0",
        "trainer.logger=['console','wandb']",
        "trainer.project_name=verl-tests",
        "trainer.experiment_name=qwen2.5-3b_function_rm-combolock-sgl-multi-w-interaction-n8-2-debug",
        "trainer.n_gpus_per_node=1",
        "trainer.nnodes=1",
        "trainer.save_freq=20",
        "data.train_files=../data/multi_turn_combo_lock_interaction_base_base/train.parquet",
        "data.val_files=../data/multi_turn_combo_lock_interaction_base_base/test.parquet",
        "actor_rollout_ref.rollout.multi_turn.interaction_config_path=./examples/sglang_multiturn/config/interaction_config/combolock_interaction_config.yaml",
        "actor_rollout_ref.rollout.multi_turn.multi_context=True",
        "reward_model.reward_manager=multiturn",
        "trainer.log_val_generations=10",
        "+custom_reward_function.path=verl/utils/reward_score/multiturn.py",
        "trainer.total_epochs=15",
        # optionally:
        # "trainer.test_freq=5",
    ]
    with initialize(config_path="examples/sglang_multiturn/config", version_base=None): 
        cfg = compose(config_name="combolock_multiturn_grpo_w_interaction", overrides=overrides)
        return cfg

cfg = get_config()
#print(OmegaConf.to_yaml(cfg))

In [3]:
import nest_asyncio

nest_asyncio.apply()
import sglang as sgl

from sglang.test.test_utils import is_in_ci

if is_in_ci():
    import patch
else:
    import nest_asyncio

    nest_asyncio.apply()


llm = sgl.Engine(model_path=cfg.actor_rollout_ref.model.path)

# prompts = [
#     "Hello, my name is",
#     "The president of the United States is",
# ]
# sampling_params = {"temperature": 0.8, "top_p": 0.95}
# outputs = llm.async_generate(prompts, sampling_params)
# for prompt, output in zip(prompts, outputs):
#     print("===============================")
#     print(f"Prompt: {prompt}\nGenerated text: {output['text']}")

Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:00<00:00,  2.01it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:01<00:00,  1.86it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:01<00:00,  1.88it/s]

2025-07-28 18:43:43,042 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend
Capturing batches (avail_mem=4.45 GB):   0%|          | 0/23 [00:00<?, ?it/s]2025-07-28 18:43:43,652 - INFO - flashinfer.jit: Loading JIT ops: batch_prefill_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False
2025-07-28 18:43:43,716 - INFO - flashinfer.jit: Finished loading JIT ops: batch_prefill_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False
Capturing batc

In [18]:
import asyncio
import logging
import os
from copy import deepcopy
from json import JSONDecodeError
from uuid import uuid4

import numpy as np
import torch
from omegaconf import DictConfig

from sglang.srt.openai_api.protocol import Tool
from sglang.srt.sampling.sampling_params import SamplingParams

from tensordict import TensorDict
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedTokenizer

from verl import DataProto
from verl.interactions.base import BaseInteraction
from verl.tools.base_tool import BaseTool
from verl.tools.schemas import OpenAIFunctionCallSchema, OpenAIFunctionParsedSchema, OpenAIFunctionToolCall 
from verl.tools.utils.tool_registry import initialize_tools_from_config
from verl.utils.debug import GPUMemoryLogger
from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length
from verl.workers.rollout.base import BaseRollout
from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, FinishReasonTypeEnum, AsyncRolloutRequestInterface, AsyncRolloutRequestMultiContext


from types import SimpleNamespace

from verl.utils import hf_tokenizer, hf_processor
from verl.utils.fs import copy_to_local
from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
from verl import DataProto
from torchdata.stateful_dataloader import StatefulDataLoader

try:
    from sglang.srt.function_call.function_call_parser import FunctionCallParser
except ImportError:
    from sglang.srt.function_call_parser import FunctionCallParser


logger = logging.getLogger()
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


# NOTE(sgm): add for verl. We can optimize it by making
#  the dataloader yield List[int] without padding.
def _pre_process_inputs(
    pad_token_id,
    prompt_token_ids: torch.Tensor,
) -> list[int]:
    # remove the left padding in the prompt token_id
    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
    token_ids = prompt_token_ids[non_pad_index:].tolist()
    return token_ids


# NOTE(linjunrong): adhoc
def _post_process_outputs(tokenizer, output):
    def _map_each_response(resp):
        output_token_logprobs = resp["meta_info"]["output_token_logprobs"]
        log_probs, output_token_ids = zip(*[(log_prob, token_ids) for log_prob, token_ids, _ in output_token_logprobs])
        return torch.tensor(output_token_ids), torch.tensor(log_probs)

    out_map = map(lambda x: _map_each_response(x), output)
    batched_output_token_ids = []
    batched_logprobs = []
    for output_token_ids, log_probs in out_map:
        batched_output_token_ids.append(output_token_ids)
        batched_logprobs.append(log_probs)
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    batched_output_token_ids = pad_sequence(batched_output_token_ids, batch_first=True, padding_value=pad_token_id)
    if len(batched_logprobs) > 0:
        batched_logprobs = pad_sequence(batched_logprobs, batch_first=True, padding_value=pad_token_id)
    return batched_output_token_ids, batched_logprobs


def get_tool_call_parser_type(tokenizer: PreTrainedTokenizer) -> str:
    items = FunctionCallParser.ToolCallParserEnum.items()
    for parser_type, parser_cls in items:
        parser = parser_cls()
        if parser.bot_token.strip() in tokenizer.get_vocab() and (parser.eot_token == "" or parser.eot_token.strip() in tokenizer.get_vocab()):
            return parser_type
    else:
        raise ValueError(f"No tool call parser found for tokenizer {tokenizer}")


class SGLangRollout(BaseRollout):
    def __init__(
        self,
        config: DictConfig,
        tokenizer,
        model_hf_config,
        **kwargs,
    ):
        """Synchronized SGLang rollout engine.

        Args:
            actor_module: Huggingface model name or path to the model. The
                model should be supported by SGLang.
            config: A DictConfig object containing SGLang-specific operational
                parameters and rollout settings.
                Refer to https://docs.sglang.ai/backend/server_arguments.html
            tokenizer: The tokenizer instance compatible with the actor_module.
            model_hf_config: The Hugging Face model's configuration (e.g.,
                `transformers.PretrainedConfig`). It provides architectural
                details and hyperparameters like `max_position_embeddings`,
                used by SGLang for correct model initialization. This is
                the model's inherent design, not SGLang's runtime behavior.
            port: Optional port for multi-node initialization when nnodes > 1.
            trust_remote_code: Whether or not to allow for custom models
                defined on the Hub in their own modeling files.
            device_mesh: Optional `DeviceMesh` object for distributed setup.
            **kwargs: Additional keyword arguments, primarily `train_tp` for
                Megatron Backend integration to initialize hybrid engine
                process groups.
        """
        super().__init__()

        self.config = config
        
        os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true")

        (
            self._tool_schemas,
            self._tool_map,
            self._tool_call_parser_type,
            self._sgl_tools,
            self._function_call_parser,
        ) = self._initialize_tools(config, tokenizer)
        self.interaction: dict[str, BaseInteraction] = self._intitalize_interaction(config)
        # If turn on `free_cache_engine`, SGLang engine's KV cache
        # will be freed after each `generate_sequences` call.
        assert not (not config.enforce_eager and config.free_cache_engine), "disable CUDA graph (enforce_eager = False) if free cache engine"

        logger.info(f"tool_schemas: {self._tool_schemas}, tool_map: {self._tool_map}, tool_call_parser_type: {self._tool_call_parser_type}, sgl_tools: {self._sgl_tools}, function_call_parser: {self._function_call_parser}")

        self._verify_config(model_hf_config)
        self._init_sampling_params(**kwargs)

        self.tokenizer = tokenizer
        self.pad_token_id = tokenizer.pad_token_id


    def _verify_config(self, model_hf_config):
        if not self.config.get("max_model_len", None):
            self.config.max_model_len = self.config.prompt_length + self.config.response_length
        assert self.config.max_model_len >= self.config.prompt_length + self.config.response_length, f"""max_model_len should be greater than total sequence length (prompt_length + response_length): 
            {self.config.max_model_len} >= {self.config.prompt_length} + {self.config.response_length}"""
        # assert model_hf_config.max_position_embeddings >= self.config.max_model_len, "model context length should be greater than total sequence length"
        # currently max_assistant_turns stand for max number of tool calls
        if self.config.multi_turn.max_assistant_turns is None:
            self.config.multi_turn.max_assistant_turns = self.config.max_model_len // 3
        if self.config.multi_turn.max_user_turns is None:
            self.config.multi_turn.max_user_turns = self.config.max_model_len // 3


    def _init_sampling_params(self, **kwargs):
        kwargs = dict(
            n=1,
            max_new_tokens=self.config.response_length,
            presence_penalty=0.0,
            frequency_penalty=0.0,
            repetition_penalty=1.0,
        )
        # supporting adding any sampling params from the config file
        for k in self.config.keys():
            if hasattr(SamplingParams(), str(k)):
                kwargs[k] = self.config.get(k)
        # Jakob special sampling param codes for ease
        if not self.config.get("is_instruct_model", True):
            kwargs["no_stop_trim"] = True
            kwargs["stop"] = ["</action>","</belief>"] # "</answer>" looks potentially unnecessary?
        self.sampling_params = kwargs

    def _initialize_tools(self, config, tokenizer):
        """Initialize tools from configuration.
        Args:
            config: Configuration object containing tool-related settings,
                    specifically `config.multi_turn.tool_config_path`.
            tokenizer: The tokenizer instance used for parsing tool calls from
                       the model's generated text.

        Returns:
            tuple: A tuple containing:
                - tool_schemas (list[dict]): OpenAI-formatted JSON schemas
                  defining each tool's capabilities.
                - tool_map (dict[str, BaseTool]): A dictionary mapping tool
                  names to their executable `BaseTool` objects.
                - tool_call_parser_type (str): The identifier for the specific
                  parser type (e.g., 'json_mode', 'tool_code') used to extract
                  tool calls.
                - sgl_tools (list[sglang.srt.openai_api.protocol.Tool]): Tool
                  definitions optimized for SGLang's internal engine.
                - function_call_parser (sglang.srt.function_call_parser.FunctionCallParser):
                  The active parser instance responsible for extracting
                  structured tool calls from model outputs.
        """
        if config.multi_turn.tool_config_path is None:
            return [], {}, None, [], None

        tools_config_file = config.multi_turn.tool_config_path
        tool_list = initialize_tools_from_config(tools_config_file)

        tool_schemas = [tool.get_openai_tool_schema().model_dump() for tool in tool_list]
        tool_map = {tool.name: tool for tool in tool_list}
        tool_call_parser_type = get_tool_call_parser_type(tokenizer)
        sgl_tools = [Tool.model_validate(tool_schema) for tool_schema in tool_schemas]
        function_call_parser = FunctionCallParser(
            sgl_tools,
            tool_call_parser_type,
        )

        return (
            tool_schemas,
            tool_map,
            tool_call_parser_type,
            sgl_tools,
            function_call_parser,
        )

    def _intitalize_interaction(self, config):
        import importlib.util
        import sys

        from omegaconf import OmegaConf

        if config.multi_turn.interaction_config_path is None:
            return None
        interaction_config_file = config.multi_turn.interaction_config_path
        interaction_config = OmegaConf.load(interaction_config_file).interaction[0]
        cls_name = interaction_config.class_name
        module_name, class_name = cls_name.rsplit(".", 1)
        if module_name not in sys.modules:
            spec = importlib.util.find_spec(module_name)
            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)
        else:
            module = sys.modules[module_name]

        interaction_cls = getattr(module, class_name)

        interaction = interaction_cls(config=OmegaConf.to_container(interaction_config.config, resolve=True))
        return interaction

    async def _async_rollout_a_request(
        self,
        req: AsyncRolloutRequestInterface,
        do_sample: bool = True,
        is_validate: bool = False,
        **kwargs,
    ) -> AsyncRolloutRequestInterface:
        _req = deepcopy(req)

        finish_reason_type = None
        output = None

        current_turns = 0
        user_turns = 0
        user_turn_rewards = []

        # Create request-level sampling parameters
        request_sampling_params = self.sampling_params.copy()
        if not do_sample:
            request_sampling_params.update(
                {
                    "n": 1,
                    "presence_penalty": 0.0,
                    "frequency_penalty": 0.0,
                    "repetition_penalty": 1.0,
                    "temperature": 0,
                    "top_p": 1,
                    "top_k": -1,
                    "ignore_eos": False,
                    "min_new_tokens": 0,
                    "max_new_tokens": self.config.response_length,
                    "skip_special_tokens": True,
                    "spaces_between_special_tokens": True,
                }
            )
        elif is_validate:
            request_sampling_params.update(
                {
                    "top_k": self.config.val_kwargs.top_k,
                    "top_p": self.config.val_kwargs.top_p,
                    "temperature": self.config.val_kwargs.temperature,
                    "n": 1,  # if validate, already repeat in ray_trainer
                }
            )

        # Update with any additional kwargs
        request_sampling_params.update(kwargs)
        # run success, completion, and # attempts
        # number of tokens in the assistant messages
        tokens_per_assistant_message = []
        run_completion = False
        
        while current_turns < self.config.multi_turn.max_assistant_turns:
            if _req.state == AsyncRolloutRequestStateEnum.PENDING:
                await self._handle_pending_state(_req)
                _req.state = AsyncRolloutRequestStateEnum.RUNNING
            elif _req.state == AsyncRolloutRequestStateEnum.TOOL_CALLING:
                if _req.get_last_msg().tool_calls is not None:
                    parsed_tool_calls = _req.get_last_msg().tool_calls
                    tool_call_results = await asyncio.gather(
                        *[
                            self._tool_map[tool_call.function.name].execute(
                                _req.request_id,
                                tool_call.function.arguments,
                                **_req.tools_kwargs[tool_call.function.name].get("execute_kwargs", {}),
                            )
                            for tool_call in parsed_tool_calls
                        ]
                    )
                    _req.add_tool_response_messages(self.tokenizer, [resp for resp, _, _ in tool_call_results])
                    for tool_call, (resp, reward, metrics) in zip(parsed_tool_calls, tool_call_results):
                        _req.update_metrics(metrics, tool_call.function.name)
                    if _req.get_next_input_ids_len() >= self.config.max_model_len:
                        finish_reason_type = FinishReasonTypeEnum.STOP
                        break
                    _req.state = AsyncRolloutRequestStateEnum.RUNNING
                else:
                    raise ValueError(f"Unexpected tool calling last message state: {_req.get_last_msg()}")
            elif _req.state == AsyncRolloutRequestStateEnum.RUNNING:
                # Only continue the conversation if the prompt length is not greater than max_model_len - 1,
                # since SGLang raises an error when max_new_tokens + 1 is greater to max_model_len (the extra token accounts for the EOS token).
                if len(_req.get_generation_prompt_ids(self.tokenizer)) + 1 >= self.config.max_model_len:
                    finish_reason_type = FinishReasonTypeEnum.LENGTH
                    break
                 # every time before generating an action we check if belief state should be calculated.
                # this will have to call handle engine call for belief generation,
                if _req.should_generate_belief(): 
                    _req.is_gen_belief = True
                    output = await self._handle_engine_call(_req, request_sampling_params) # belief generation.
                    content = output["text"]
                    _req.add_assistant_message(self.tokenizer, content) 
                    # it is possible to have a belief state call which requires tools, we will not support this for now. 
                    _req.is_gen_belief = False
                    # and store the belief in one context with gradient info, (belief generation context)
                    # and then separately in another context without gradient info (action context)
                output = await self._handle_engine_call(_req, request_sampling_params) # action
                _req.increment_turn()
                # print('output', output) # this for seeing if I can compute the number of generated tokens easily.
                tokens_per_assistant_message.append(output['meta_info']['completion_tokens']) # this is characters right now, but want to make it tokens eventually.
                content = output["text"]
                finish_reason_type = FinishReasonTypeEnum.from_str(output["meta_info"]["finish_reason"]["type"])
                current_turns += 1
                if finish_reason_type == FinishReasonTypeEnum.LENGTH:
                    _req.add_assistant_message(self.tokenizer, content)
                    break
                else:
                    if self._function_call_parser and self._function_call_parser.has_tool_call(content):
                        finish_reason_type = FinishReasonTypeEnum.TOOL_CALL
                        _req.state = AsyncRolloutRequestStateEnum.TOOL_CALLING
                        try:
                            normed_content, tool_calls = self._function_call_parser.parse_non_stream(content)
                        except JSONDecodeError:
                            normed_content = content
                            tool_calls = []
                        except AttributeError:
                            normed_content = content
                            tool_calls = []
                        parsed_tool_calls = []
                        for tool_call in tool_calls:
                            function, has_decode_error = OpenAIFunctionCallSchema.from_openai_function_parsed_schema(
                                OpenAIFunctionParsedSchema(
                                    name=tool_call.name,
                                    arguments=tool_call.parameters,
                                )
                            )
                            # Drop the tool call if its arguments has decode error
                            if has_decode_error:
                                continue
                            parsed_tool_calls.append(
                                OpenAIFunctionToolCall(
                                    id=str(tool_call.tool_index),
                                    function=function,
                                )
                            )
                        if len(parsed_tool_calls) > 0:
                            _req.add_assistant_message(self.tokenizer, normed_content, tool_calls=parsed_tool_calls)
                        else:
                            _req.add_assistant_message(self.tokenizer, content)
                            finish_reason_type = FinishReasonTypeEnum.STOP
                            _req.state = AsyncRolloutRequestStateEnum.COMPLETED
                            break
                    else:
                        _req.add_assistant_message(
                            self.tokenizer,
                            content,
                        )
                        if _req.interaction_kwargs and user_turns < self.config.multi_turn.max_user_turns and current_turns < self.config.multi_turn.max_assistant_turns:
                            _req.state = AsyncRolloutRequestStateEnum.INTERACTING
                        else:
                            break
            elif _req.state == AsyncRolloutRequestStateEnum.INTERACTING:
                user_turns += 1
                messages = [{"role": x.role, "content": x.content} for x in _req.get_last_msgs()]
                should_terminate_sequence, content, reward, metrics = await self.interaction.generate_response(_req.request_id, messages, **_req.interaction_kwargs)
                user_turn_rewards.append(reward)
                if should_terminate_sequence:
                    run_completion = True
                    finish_reason_type = FinishReasonTypeEnum.STOP
                    _req.state = AsyncRolloutRequestStateEnum.COMPLETED
                    break
                else:
                    _req.add_user_message(self.tokenizer, content)
                    if _req.get_next_input_ids_len() >= self.config.max_model_len:
                        finish_reason_type = FinishReasonTypeEnum.STOP
                        break
                    else:
                        _req.state = AsyncRolloutRequestStateEnum.RUNNING

        if current_turns >= self.config.multi_turn.max_assistant_turns:
            finish_reason_type = FinishReasonTypeEnum.STOP

        # Calculate the reward for each tool
        async def calc_reward_and_release_fn(name: str, tool: BaseTool):
            reward = await tool.calc_reward(_req.request_id, **_req.tools_kwargs[name].get("calc_reward_kwargs", {}))
            await tool.release(_req.request_id, **_req.tools_kwargs[name].get("release_kwargs", {}))
            return name, reward

        tool_reward_tasks = []
        for name in _req.tools_kwargs.keys():
            tool = self._tool_map[name]
            tool_reward_tasks.append(calc_reward_and_release_fn(name, tool))
        tool_reward_scores = await asyncio.gather(*tool_reward_tasks)
        tool_reward_scores = dict(tool_reward_scores)
        if self.interaction is not None:
            all_rewards = {"interaction_reward": [await self.interaction.calculate_score(_req.request_id)]} 
        else:
            all_rewards = {**tool_reward_scores, **{"user_turn_rewards": user_turn_rewards}}
        # all rewards other than this don't matter anyway for combo lock environment, with GRPO especially. Outcome rewards are it.
        # this is very specific to combo lock with the defined rewards of anything above zero meaning that the correct thing was eventually guessed.
        _req.to_log_stats = {"trajectory_efficiency_info": self.interaction.get_trajectory_info(_req.request_id), "tokens_per_assistant_message": tokens_per_assistant_message, "run_success": all_rewards['interaction_reward'][0] > 0, "run_attempts": self.interaction.get_attempts(_req.request_id), "run_completion": run_completion}

        _req.finalize(self.tokenizer, all_rewards, finish_reason_type)

        return _req

    async def _handle_engine_call(self, _req: AsyncRolloutRequestInterface, sampling_params: dict) -> dict:
        generation_prompt_ids = _req.get_generation_prompt_ids(self.tokenizer)
        max_new_tokens = min(self.config.response_length, self.config.max_model_len - len(generation_prompt_ids) - 1)
        kwargs = sampling_params.copy()
        kwargs["max_new_tokens"] = max_new_tokens
        kwargs["n"] = 1  # group size is supported in preprocess

        output = await llm.async_generate(input_ids=generation_prompt_ids,
            sampling_params=kwargs,
            return_logprob=False)
        return output
    @GPUMemoryLogger(role="sglang rollout", logger=logger)    
    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
        if self.config.multi_turn.enable:
            return self._req_level_generate_sequences(prompts, **kwargs)
        return None
    
    @GPUMemoryLogger(role="sglang rollout", logger=logger)
    @torch.no_grad()
    def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
        # Async rollout with tools support
        do_sample = prompts.meta_info.get("do_sample", True)
        is_validate = prompts.meta_info.get("validate", False)
        tgt_device = prompts.batch["input_ids"].device
        
        req_list = self._preprocess_prompt_to_async_rollout_requests(
            prompts,
            n=1 if is_validate else self.config.n,
        )
        loop = asyncio.get_event_loop()
        output_req_list = loop.run_until_complete(
            asyncio.gather(
                *[self._async_rollout_a_request(req, do_sample, is_validate, **kwargs) for req in req_list],
            )
        )
        sorted_output_req_list = sorted(output_req_list, key=lambda x: (x.batch_data_id, x.rollout_offset))
        
        # Construct the batch data
        prompt_ids, response_ids = [], []
        prompt_attention_mask, response_attention_mask = [], []
        prompt_position_ids, response_position_ids = [], []
        prompt_loss_mask, response_loss_mask = [], []
        messages = []
        reward_scores = []
        request_ids = []
        context_indices = []
        to_log_stats = []

         # breakpoint() # checking tensor dims.

        # the code is messy, but eventually we plan to remove the AsyncRolloutRequest and replace with AsyncRolloutRequestMultiContext with single context.

        new_sorted_output_req_list = []
        if self.config.multi_turn.multi_context:
            for req in sorted_output_req_list:
                for idx in range(len(req.input_ids)): # number of contexts
                    new_sorted_output_req_list.append(SimpleNamespace(**dict(
                                                                    request_id = req.request_id,
                                                                    context_index = idx,
                                                                    state = req.state,
                                                                    reward_scores = req.reward_scores,
                                                                    to_log_stats = req.to_log_stats,
                                                                    max_model_len = req.max_model_len,
                                                                    input_ids = req.input_ids[idx],
                                                                    attention_mask = req.attention_mask[idx],
                                                                    position_ids = req.position_ids[idx],
                                                                    prompt_ids = req.prompt_ids[idx],
                                                                    loss_mask = req.loss_mask[idx],
                                                                    response_ids = req.response_ids[idx],
                                                                    prompt_attention_mask = req.prompt_attention_mask[idx],
                                                                    response_attention_mask = req.response_attention_mask[idx],
                                                                    prompt_position_ids = req.prompt_position_ids[idx],
                                                                    response_position_ids = req.response_position_ids[idx],
                                                                    prompt_loss_mask = req.prompt_loss_mask[idx],
                                                                    response_loss_mask = req.response_loss_mask[idx],
                                                                    messages = req.messages[idx],
                                                                    )))
            sorted_output_req_list = new_sorted_output_req_list

        for req in sorted_output_req_list:
            assert req.state == AsyncRolloutRequestStateEnum.COMPLETED, f"Request {req.request_id} is not completed"
            assert len(req.input_ids) == len(req.attention_mask) == len(req.position_ids) == len(req.loss_mask), f"""Request {req.request_id} has different length of 
                {len(req.input_ids)=}, {len(req.attention_mask)=}, {len(req.position_ids)=}, {len(req.loss_mask)=}"""
            error_message_lines = [
                f"""Request {req.request_id} has input_ids length {len(req.input_ids)}
                    greater than max_model_len {self.config.max_model_len}""",
                f"Decoded input_ids: {self.tokenizer.decode(req.input_ids)}",
                f"Decoded prompt_ids: {self.tokenizer.decode(req.prompt_ids)}",
                f"Decoded response_ids: {self.tokenizer.decode(req.response_ids)}",
                f"Messages: {req.messages}",
                f"Max model length: {req.max_model_len}",
            ]
            error_message = "\n".join(error_message_lines)
            assert len(req.input_ids) <= self.config.max_model_len, error_message

            prompt_ids.append(torch.tensor(req.prompt_ids, dtype=torch.int, device=tgt_device))
            response_ids.append(torch.tensor(req.response_ids, dtype=torch.int, device=tgt_device))
            if len(req.response_ids) > self.config.response_length:
                logger.warning(
                    f"""{req.request_id=} has response_ids length {len(req.response_ids)} 
                    greater than max_response_len {self.config.response_length},\n{req=}"""
                )
            prompt_attention_mask.append(torch.tensor(req.prompt_attention_mask, dtype=torch.int, device=tgt_device))
            response_attention_mask.append(torch.tensor(req.response_attention_mask, dtype=torch.int, device=tgt_device))
            prompt_position_ids.append(torch.tensor(req.prompt_position_ids, dtype=torch.int, device=tgt_device))
            response_position_ids.append(torch.tensor(req.response_position_ids, dtype=torch.int, device=tgt_device))
            prompt_loss_mask.append(torch.tensor(req.prompt_loss_mask, dtype=torch.int, device=tgt_device))
            response_loss_mask.append(torch.tensor(req.response_loss_mask, dtype=torch.int, device=tgt_device))
            messages.append({"messages": req.messages})
            reward_scores.append(req.reward_scores)
            request_ids.append(req.request_id)
            context_indices.append(req.context_index)
            to_log_stats.append(req.to_log_stats)

        prompt_ids = pad_sequence(
            prompt_ids,
            batch_first=True,
            padding_value=self.pad_token_id,
            padding_side="left",
        )
        if prompt_ids.shape[1] < self.config.prompt_length:
            prompt_ids = pad_sequence_to_length(prompt_ids, self.config.prompt_length, self.pad_token_id, left_pad=True)
        response_ids = pad_sequence(response_ids, batch_first=True, padding_value=self.pad_token_id)
        if response_ids.shape[1] < self.config.response_length:
            response_ids = pad_sequence_to_length(response_ids, self.config.response_length, self.pad_token_id)
        prompt_attention_mask = pad_sequence(
            prompt_attention_mask,
            batch_first=True,
            padding_value=0,
            padding_side="left",
        )
        if prompt_attention_mask.shape[1] < self.config.prompt_length:
            prompt_attention_mask = pad_sequence_to_length(prompt_attention_mask, self.config.prompt_length, 0, left_pad=True)
        response_attention_mask = pad_sequence(response_attention_mask, batch_first=True, padding_value=0)
        if response_attention_mask.shape[1] < self.config.response_length:
            response_attention_mask = pad_sequence_to_length(response_attention_mask, self.config.response_length, 0)
        prompt_position_ids = pad_sequence(prompt_position_ids, batch_first=True, padding_value=0, padding_side="left")
        if prompt_position_ids.shape[1] < self.config.prompt_length:
            prompt_position_ids = pad_sequence_to_length(prompt_position_ids, self.config.prompt_length, 0, left_pad=True)
        response_length = response_ids.size(1)
        delta_position_id = torch.arange(1, response_length + 1, device=response_ids.device)
        delta_position_id = delta_position_id.unsqueeze(0).repeat(len(sorted_output_req_list), 1)
        response_position_ids = prompt_position_ids[:, -1:] + delta_position_id
        prompt_loss_mask = pad_sequence(prompt_loss_mask, batch_first=True, padding_value=0, padding_side="left")
        if prompt_loss_mask.shape[1] < self.config.prompt_length:
            prompt_loss_mask = pad_sequence_to_length(prompt_loss_mask, self.config.prompt_length, 0, left_pad=True)
        response_loss_mask = pad_sequence(response_loss_mask, batch_first=True, padding_value=0)
        if response_loss_mask.shape[1] < self.config.response_length:
            response_loss_mask = pad_sequence_to_length(response_loss_mask, self.config.response_length, 0)

        input_ids = torch.cat((prompt_ids, response_ids), dim=-1)
        attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)
        position_ids = torch.cat((prompt_position_ids, response_position_ids), dim=-1)
        loss_mask = torch.cat((prompt_loss_mask, response_loss_mask), dim=-1)

        # Construct the batch data
        batch = TensorDict(
            {
                "prompts": prompt_ids,
                "responses": response_ids,
                "input_ids": input_ids,  # here input_ids become the whole sentences
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "loss_mask": loss_mask,
            },
            batch_size=len(sorted_output_req_list),
        )

        return DataProto(
            batch=batch,
            non_tensor_batch={
                "messages": np.array(messages),
                "reward_scores": np.array(reward_scores),
                "request_ids": np.array(request_ids),
                "context_indices": np.array(context_indices),
                "to_log_stats": np.array(to_log_stats),
            },
            # add item here for trajectory index?
        )

    def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: int) -> list[AsyncRolloutRequest]:
        assert "raw_prompt" in prompts.non_tensor_batch, "need data.return_raw_chat=True, due to no official way do parse_messages"
        req_list = []
        for data_idx, raw_prompt in enumerate(prompts.non_tensor_batch["raw_prompt"]):
            for rollout_offset in range(n):
                if self._tool_schemas:
                    _tools_kwargs = prompts.non_tensor_batch["tools_kwargs"][data_idx]
                    _tool_schemas = [self._tool_map[k].get_openai_tool_schema() for k in _tools_kwargs.keys()]
                    _input_ids = None
                    _attention_mask = None
                else:
                    _input_ids = _pre_process_inputs(self.pad_token_id, prompts.batch["input_ids"][data_idx])
                    _attention_mask = _pre_process_inputs(0, prompts.batch["attention_mask"][data_idx])
                    _tools_kwargs = {}
                    _tool_schemas = None

                if self.interaction is not None:
                    _interaction_kwargs = prompts.non_tensor_batch["interaction_kwargs"][data_idx]
                else:
                    _interaction_kwargs = {}
                # breakpoint()
                if self.config.multi_turn.multi_context:
                    req = AsyncRolloutRequestMultiContext(
                        batch_data_id=data_idx,
                        rollout_offset=rollout_offset,
                        request_id=str(uuid4()),
                        state=AsyncRolloutRequestStateEnum.PENDING,
                        messages=[raw_prompt.tolist()],
                        tool_schemas=_tool_schemas,
                        tools_kwargs=_tools_kwargs,
                        interaction_kwargs=_interaction_kwargs,
                        input_ids=[_input_ids],
                        response_ids=[],
                        attention_mask=[_attention_mask],
                        response_attention_mask=[],
                        response_position_ids=[],
                        response_loss_mask=[],
                        reward_scores={},
                        to_log_stats={},
                        max_prompt_len=self.config.prompt_length,
                        max_response_len=self.config.response_length,
                        max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length),
                        use_inference_chat_template=self.config.multi_turn.use_inference_chat_template,
                        force_thinking=self.config.multi_turn.force_thinking,
                        enable_tokenization_sanity_check=self.config.multi_turn.enable_tokenization_sanity_check,
                        tokenizer=self.tokenizer,
                    )
                else:
                    req = AsyncRolloutRequest(
                        batch_data_id=data_idx,
                        rollout_offset=rollout_offset,
                        request_id=str(uuid4()),
                        state=AsyncRolloutRequestStateEnum.PENDING,
                        messages=raw_prompt.tolist(),
                        tool_schemas=_tool_schemas,
                        tools_kwargs=_tools_kwargs,
                        interaction_kwargs=_interaction_kwargs,
                        input_ids=_input_ids,
                        response_ids=[],
                        attention_mask=_attention_mask,
                        response_attention_mask=[],
                        response_position_ids=[],
                        response_loss_mask=[],
                        reward_scores={},
                        to_log_stats={},
                        max_prompt_len=self.config.prompt_length,
                        max_response_len=self.config.response_length,
                        max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length),
                        use_inference_chat_template=self.config.multi_turn.use_inference_chat_template,
                        force_thinking=self.config.multi_turn.force_thinking,
                        enable_tokenization_sanity_check=self.config.multi_turn.enable_tokenization_sanity_check,
                        tokenizer=self.tokenizer,
                    )
                error_message = f"Request {req.request_id} has mismatched lengths: input_ids={len(req.input_ids)}, attention_mask={len(req.attention_mask)}, position_ids={len(req.position_ids)}, loss_mask={len(req.loss_mask)}"
                assert len(req.input_ids) == len(req.attention_mask) == len(req.position_ids) == len(req.loss_mask), error_message

                req_list.append(req)

        return req_list

    async def _handle_pending_state(self, _req: AsyncRolloutRequestInterface) -> None:
        if _req.tool_schemas is not None:
            tool_creation_coroutines = []
            for tool_schema in _req.tool_schemas:
                tool = self._tool_map[tool_schema.function.name]
                create_kwargs = _req.tools_kwargs[tool.name].get("create_kwargs", {})
                tool_creation_coroutines.append(tool.create(_req.request_id, **create_kwargs))
            await asyncio.gather(*tool_creation_coroutines)
        if _req.interaction_kwargs:
            interaction_kwargs = _req.interaction_kwargs
            await self.interaction.start_interaction(_req.request_id, **interaction_kwargs)

    async def wake_up(self):
        if not self.is_sleep:
            return
        await self.sharding_manager.wake_up()  # pylint: disable=C2801
        self.is_sleep = False

    # this function is left for uniform train-inference resharding
    async def sleep(self):
        if self.is_sleep:
            return
        await self.sharding_manager.sleep()
        self.is_sleep = True

local_path = copy_to_local(cfg.actor_rollout_ref.model.path, use_shm=cfg.actor_rollout_ref.model.get("use_shm", False))

tokenizer = hf_tokenizer(local_path, is_instruct_model=cfg.actor_rollout_ref.rollout.is_instruct_model, trust_remote_code=True)

rollout = SGLangRollout(cfg.actor_rollout_ref.rollout, tokenizer, None)

using base chat template


In [19]:
cfg.actor_rollout_ref.rollout.do_sample
rollout.config.temperature

0.7

In [20]:
processor = hf_processor(local_path, trust_remote_code=True, use_fast=True)

dataset = create_rl_dataset(cfg.data.train_files, cfg.data, tokenizer, processor)

from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn

collate_fn = default_collate_fn


val_dataloader = StatefulDataLoader(
    dataset=dataset,
    batch_size=100,
    num_workers=0,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_fn,
)

for batch in val_dataloader:
    print(batch)
    test_batch = DataProto.from_single_dict(batch)
    break

Using dataset class: RLHFDataset
dataset len: 3360
filter dataset len: 3360
{'input_ids': tensor([[151643, 151643, 151643,  ...,  94367,  77091,   1447],
        [151643, 151643, 151643,  ...,  94367,  77091,   1447],
        [151643, 151643, 151643,  ...,  94367,  77091,   1447],
        ...,
        [151643, 151643, 151643,  ...,  94367,  77091,   1447],
        [151643, 151643, 151643,  ...,  94367,  77091,   1447],
        [151643, 151643, 151643,  ...,  94367,  77091,   1447]]), 'attention_mask': tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]]), 'position_ids': tensor([[  0,   0,   0,  ..., 225, 226, 227],
        [  0,   0,   0,  ..., 225, 226, 227],
        [  0,   0,   0,  ..., 225, 226, 227],
        ...,
        [  0,   0,   0,  ..., 225, 226, 227],
        [  0,   0,   0,  ..., 225, 226, 227],
        [  0,  

In [21]:
from verl.protocol import pad_dataproto_to_divisor
for batch in val_dataloader:
    test_batch = DataProto.from_single_dict(batch)
    break
batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
if "multi_modal_data" in test_batch.non_tensor_batch:
    non_tensor_batch_keys_to_pop.append("multi_modal_data")
if "raw_prompt" in test_batch.non_tensor_batch:
    non_tensor_batch_keys_to_pop.append("raw_prompt")
if "tools_kwargs" in test_batch.non_tensor_batch:
    non_tensor_batch_keys_to_pop.append("tools_kwargs")
if "interaction_kwargs" in test_batch.non_tensor_batch:
    non_tensor_batch_keys_to_pop.append("interaction_kwargs")
print(f"non_tensor_batch_keys_to_pop: {non_tensor_batch_keys_to_pop}")
print(f"batch_keys_to_pop: {batch_keys_to_pop}")
test_gen_batch = test_batch.pop(
    batch_keys=batch_keys_to_pop,
    non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
)

test_gen_batch.meta_info = {
    "eos_token_id": rollout.tokenizer.eos_token_id,
    "pad_token_id": rollout.tokenizer.pad_token_id,
    "recompute_log_prob": False,
    "do_sample": True,
    "validate": False,
}

print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")

# pad to be divisible by dp_size
test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, 1)

test_output_gen_batch_padded = rollout.generate_sequences(test_gen_batch_padded)
# prompts = test_gen_batch_padded
test_output_gen_batch_padded
# do_sample = prompts.meta_info.get("do_sample", True)
# is_validate = prompts.meta_info.get("validate", False)
# tgt_device = prompts.batch["input_ids"].device
# req_list = rollout._preprocess_prompt_to_async_rollout_requests(
#     prompts,
#     n=1 if is_validate else rollout.config.n,
# )
# # loop = asyncio.get_event_loop()
# output_req_list = await asyncio.gather(
#         *[rollout._async_rollout_a_request(req, do_sample, is_validate) for req in req_list],
#     )
# sorted_output_req_list = sorted(output_req_list, key=lambda x: (x.batch_data_id, x.rollout_offset))



non_tensor_batch_keys_to_pop: ['raw_prompt_ids', 'raw_prompt', 'tools_kwargs', 'interaction_kwargs']
batch_keys_to_pop: ['input_ids', 'attention_mask', 'position_ids']
test_gen_batch meta info: {'eos_token_id': 151643, 'pad_token_id': 151643, 'recompute_log_prob': False, 'do_sample': True, 'validate': False}


[2025-07-28 19:05:34] Inconsistent training and inference tokenization detected. This may lead to unexpected behavior during training. Please review your chat template to determine if this is intentional. For more information, refer to the multiturn README.md.
[2025-07-28 19:05:40] Inconsistent training and inference tokenization detected. This may lead to unexpected behavior during training. Please review your chat template to determine if this is intentional. For more information, refer to the multiturn README.md.
[2025-07-28 19:05:43] Inconsistent training and inference tokenization detected. This may lead to unexpected behavior during training. Please review your chat template to determine if this is intentional. For more information, refer to the multiturn README.md.
[2025-07-28 19:05:43] Inconsistent training and inference tokenization detected. This may lead to unexpected behavior during training. Please review your chat template to determine if this is intentional. For more inf

DataProto(batch=TensorDict(
    fields={
        attention_mask: Tensor(shape=torch.Size([2258, 3072]), device=cpu, dtype=torch.int32, is_shared=False),
        input_ids: Tensor(shape=torch.Size([2258, 3072]), device=cpu, dtype=torch.int32, is_shared=False),
        loss_mask: Tensor(shape=torch.Size([2258, 3072]), device=cpu, dtype=torch.int32, is_shared=False),
        position_ids: Tensor(shape=torch.Size([2258, 3072]), device=cpu, dtype=torch.int64, is_shared=False),
        prompts: Tensor(shape=torch.Size([2258, 2048]), device=cpu, dtype=torch.int32, is_shared=False),
        responses: Tensor(shape=torch.Size([2258, 1024]), device=cpu, dtype=torch.int32, is_shared=False)},
    batch_size=torch.Size([2258]),
    device=None,
    is_shared=False), non_tensor_batch={'messages': array([{'messages': [Message(role='system', content="You will determine the combination for [Position 1, Position 2, Position 3] in a 3-character combination lock through iterative reasoning and queries.\nT

In [9]:
test_output_gen_batch_padded.non_tensor_batch['to_log_stats'][0]

{'trajectory_efficiency_info': {'repeated_guesses': 2},
 'tokens_per_assistant_message': [15, 20, 67, 56, 30, 350, 15, 1024],
 'run_success': False,
 'run_attempts': 6,
 'run_completion': False}

In [22]:
[i for i, stat in enumerate(test_output_gen_batch_padded.non_tensor_batch['to_log_stats']) if stat['run_success']]

[2004,
 2005,
 2006,
 2007,
 2008,
 2009,
 2010,
 2011,
 2012,
 2013,
 2014,
 2015,
 2016,
 2017,
 2018,
 2019,
 2020]

In [23]:
for i in range(2004,2021):
    print(tokenizer.decode(test_output_gen_batch_padded.batch['input_ids'][i], skip_special_tokens=True))
    print('-'*100)

system prompt:

You will determine the combination for [Position 1, Position 2, Position 3] in a 3-character combination lock through iterative reasoning and queries.
The lock requires a 3-character combination. All 3 characters are unique.
The set of valid characters are as follows: ['q', 'a', 'w', 's', 'e', 'd', 'r', 'f', 't', 'g', 'y', 'h', 'u', 'j', 'i', 'k']
You get feedback from the user on the presence of a character. Either not in the lock, in the lock but in a different position, or in the lock and in the right position.
You have 12 queries.
Your goal is to find the correct combination in the least number of queries.
To query the lock, you will first think step by step, and then generate a query formatted as a list of 3 characters inside <action> ... </action>, e.g., <action>['char 1', 'char 2', 'char 3']</action>.

assistant:

<action>['q', 'q', 'q']</action>


user:

Could not parse valid guess from: '['q', 'q', 'q']'. Please ensure the guess is contained in the final charac

In [118]:
for j in range(3):
    print("request id: ", test_output_gen_batch_padded.non_tensor_batch["request_ids"][j], " context index: ", test_output_gen_batch_padded.non_tensor_batch["context_indices"][j])
    for i in range(len(test_output_gen_batch_padded.non_tensor_batch["messages"][j]['messages'])):
        pprint(test_output_gen_batch_padded.non_tensor_batch["messages"][j]['messages'][i].content)#.keys()#['request_ids']
        print('-'*50)
    print('-'*100)

request id:  0e2d2149-b73f-485f-9e25-552ec5972896  context index:  0
('You will determine the combination for [Position 1, Position 2, Position 3] '
 'in a 3-character combination lock through iterative reasoning and queries.\n'
 'The lock requires a 3-character combination. All 3 characters are unique.\n'
 "The set of valid characters are as follows: ['q', 'a', 'w', 's', 'e', 'd', "
 "'r', 'f', 't', 'g', 'y', 'h', 'u', 'j', 'i', 'k']\n"
 'You get feedback from the user on the presence of a character. Either not in '
 'the lock, in the lock but in a different position, or in the lock and in the '
 'right position.\n'
 'You have 12 queries.\n'
 'Your goal is to find the correct combination in the least number of '
 'queries.\n'
 'To query the lock, you will first think step by step, and then generate a '
 'query formatted as a list of 3 characters inside <action> ... </action>, '
 "e.g., <action>['char 1', 'char 2', 'char 3']</action>.")
-------------------------------------------------

In [92]:
from verl.interactions.utils import process_msg_content
contents, valid, error_msg = process_msg_content("Next query: <action>[a, e, b]</action>\n### Step-by-Step Reasoning:", tag_list=['action'])
print(contents[0])

[a, e, b]
