In [1]:
%load_ext autoreload
%autoreload 2
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'

In [2]:
from optimal_explorer.mdps.combination_lock import CombinationLock
mdp = CombinationLock()
mdp.reset()
mdp._is_valid_guess("224")
# mdp.step("224")

False

In [3]:
import itertools
import random
from copy import deepcopy
from datasets import Dataset
from pprint import pprint
combination_length = 3
max_attempts = 12

vocab = "0123456789" # "qawsedrftgyhujik" # 16 chars.
combinations = list(itertools.permutations(list(vocab), combination_length))
rand = random.Random(42)
rand.shuffle(combinations)
test_combinations = deepcopy(combinations[:1])
rand.shuffle(combinations)
# combinations = combinations[:1]
# test_combinations = deepcopy(combinations)
def create_data(split, idx, solution, format="tool"):
    if "interaction" in format:
        position_str = ", ".join(f'Position {i + 1}' for i in range(combination_length))
        example_str = ", ".join(f"'char {i + 1}'" for i in range(combination_length))
        system_prompt = (f"You will determine the correct combination of characters at [{position_str}] in a {combination_length}-character combination lock through iterative reasoning and queries.\n"
                            f"All {combination_length} characters are unique.\n"
                            f"The set of valid characters are as follows: {list(vocab)}\n"
                            f"Each action is a query of the form [{example_str}].\n"
                            "Each time you query a combination, you will get feedback from the user about each character: either not in the combination, in the combination but in a different position, or in the combination and in the right position.\n"
                            f"You can make up to {max_attempts} queries.\n"
                            "Your goal is to find the correct combination in the least number of queries.\n"
                            )
                            # "Ensure you don't put '?' in your queries. Make sure you submit vocabulary characters in your query.\n"
                            # "At each step, you can see the instructions, the current belief state in <belief> ... </belief>, the latest action in <action> ... </action>, and the returned feedback in <feedback> ... </feedback>.\n"
                            # "First think step by step then"
                            #f"Each query should be formatted as a python list: [{example_str}].\n")
        if format == "interaction_base":
            system_prompt = system_prompt.strip()
        else:
            raise Exception(f"invalid format {format}")
        
        prompt_list = [ {
                        "role": "system",
                        "content": system_prompt,
                    } ]
       
        prompt_list.append({
            "role": "user",
            # "content": f"Think extensively inside <think> tags, then give me your query formatted as a list of {combination_length} characters inside <action>[{example_str}]</action>.\n",
            "content": f"Give me your first query formatted as a list of {combination_length} characters inside <action> ... </action> after thinking inside <think> ... </think>, e.g., <think> Let's think step by step before giving the query [your extensive thinking] </think> <action>[{example_str}]</action>.\n",
        })
        return {
                "data_source": "multi_turn_combo_lock",
                "prompt": prompt_list,
                "ability": "math",
                "reward_model": {"style": "rule", "ground_truth": solution},
                "extra_info": {
                    "split": split,
                    "index": idx,
                    "answer": solution,
                    "interaction_kwargs": {
                        "combination_length": combination_length,
                        "max_attempts": max_attempts,
                        "vocab": vocab,
                        "ground_truth": solution,
                        "format": format,
                    },
                },
            }
  
format = "interaction_base"
is_instruct_model = False
train_data = [create_data('train', idx, combo_sol, format) for idx, combo_sol in enumerate(combinations)]
test_data = [create_data('test', idx, combo_sol, format) for idx, combo_sol in enumerate(test_combinations)]

Dataset.from_list(train_data).to_parquet(f"../data/multi_turn_combo_lock_{format+('_base' if not is_instruct_model else '')}/train.parquet")
Dataset.from_list(test_data).to_parquet(f"../data/multi_turn_combo_lock_{format+('_base' if not is_instruct_model else '')}/test.parquet")
pprint(test_data[0])
# print(test_data[0]['prompt'][0]['content'])

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

{'ability': 'math',
 'data_source': 'multi_turn_combo_lock',
 'extra_info': {'answer': ('8', '0', '6'),
                'index': 0,
                'interaction_kwargs': {'combination_length': 3,
                                       'format': 'interaction_base',
                                       'ground_truth': ('8', '0', '6'),
                                       'max_attempts': 12,
                                       'vocab': '0123456789'},
                'split': 'test'},
 'prompt': [{'content': 'You will determine the correct combination of '
                        'characters at [Position 1, Position 2, Position 3] in '
                        'a 3-character combination lock through iterative '
                        'reasoning and queries.\n'
                        'All 3 characters are unique.\n'
                        "The set of valid characters are as follows: ['0', "
                        "'1', '2', '3', '4', '5', '6', '7', '8', '9']\n"
                   

In [3]:
Dataset.from_parquet(f"../data/multi_turn_combo_lock_{format+('_base' if not is_instruct_model else '')}_single/train.parquet")

Dataset({
    features: ['data_source', 'prompt', 'ability', 'reward_model', 'extra_info'],
    num_rows: 8
})

In [4]:
train_data

[{'data_source': 'multi_turn_combo_lock',
  'prompt': [{'role': 'system',
    'content': "You will determine the correct combination of characters at [Position 1, Position 2, Position 3] in a 3-character combination lock through iterative reasoning and queries.\nAll 3 characters are unique.\nThe set of valid characters are as follows: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']\nEach action is a query of the form ['char 1', 'char 2', 'char 3'].\nEach time you query a combination, you will get feedback from the user about each character: either not in the combination, in the combination but in a different position, or in the combination and in the right position.\nYou can make up to 12 queries.\nYour goal is to find the correct combination in the least number of queries."},
   {'role': 'user',
    'content': "Give me your first query formatted as a list of 3 characters inside <action> ... </action> after thinking inside <think> ... </think>, e.g., <think> Let's think step by step

In [4]:
from hydra import initialize, compose
from omegaconf import OmegaConf

# Use this when you want to use the config outside @hydra.main
def get_config():
    overrides = [
        "algorithm.adv_estimator=grpo",
        "data.train_batch_size=512",
        "data.max_prompt_length=1024",
        "data.max_response_length=2048", 
        "data.filter_overlong_prompts=True",
        "data.truncation=error",
        "data.return_raw_chat=True",
        "actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-instruct", # Jakob: important
        "actor_rollout_ref.rollout.is_instruct_model=True", # Jakob: important
        "actor_rollout_ref.rollout.max_new_tokens=512",
        "actor_rollout_ref.model.use_remove_padding=True",
        "+actor_rollout_ref.model.enable_activation_offloading=True",
        "actor_rollout_ref.actor.optim.lr=1e-6",
        "actor_rollout_ref.actor.ppo_mini_batch_size=512",
        "actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8",
        "actor_rollout_ref.actor.use_kl_loss=True",
        "actor_rollout_ref.actor.kl_loss_coef=0.001",
        "actor_rollout_ref.actor.kl_loss_type=low_var_kl",
        "actor_rollout_ref.actor.entropy_coeff=0",
        "actor_rollout_ref.actor.fsdp_config.param_offload=False",
        "actor_rollout_ref.actor.fsdp_config.optimizer_offload=False",
        "+actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16",
        "actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8",
        "actor_rollout_ref.rollout.tensor_model_parallel_size=1",
        "actor_rollout_ref.rollout.name=sglang",
        "actor_rollout_ref.rollout.gpu_memory_utilization=0.5",
        "actor_rollout_ref.rollout.n=1",
        "actor_rollout_ref.rollout.temperature=1", # Jakob: important sampling param we are changing
        "actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8",
        "actor_rollout_ref.ref.fsdp_config.param_offload=False",
        "algorithm.use_kl_in_reward=False",
        "trainer.critic_warmup=0",
        "trainer.logger=['console','wandb']",
        "trainer.project_name=verl-tests",
        "trainer.experiment_name=qwen2.5-3b_function_rm-combolock-sgl-multi-w-interaction-n8-2-debug",
        "trainer.n_gpus_per_node=1",
        "trainer.nnodes=1",
        "trainer.save_freq=20",
        "data.train_files=../data/multi_turn_combo_lock_interaction_base_base/train.parquet",
        "data.val_files=../data/multi_turn_combo_lock_interaction_base_base/test.parquet",
        "actor_rollout_ref.rollout.multi_turn.interaction_config_path=./examples/sglang_multiturn/config/interaction_config/combolock_interaction_config.yaml",
        "actor_rollout_ref.rollout.multi_turn.multi_context.enable=True",
        "actor_rollout_ref.rollout.multi_turn.lax_format=False",
        "reward_model.reward_manager=multiturn",
        "trainer.log_val_generations=10",
        "+custom_reward_function.path=verl/utils/reward_score/multiturn.py",
        "trainer.total_epochs=15",
        # optionally:
        # "trainer.test_freq=5",
    ]
    with initialize(config_path="examples/sglang_multiturn/config", version_base=None): 
        cfg = compose(config_name="combolock_multiturn_grpo_w_interaction", overrides=overrides)
        return cfg

cfg = get_config()
#print(OmegaConf.to_yaml(cfg))

In [18]:
import nest_asyncio

nest_asyncio.apply()
# import sglang as sgl

# from sglang.test.test_utils import is_in_ci

# if is_in_ci():
#     import patch
# else:
#     import nest_asyncio

#     nest_asyncio.apply()

# import torch
# llm = sgl.Engine(model_path=cfg.actor_rollout_ref.model.path, dp_size=torch.cuda.device_count())

# from vllm import LLM, AsyncLLMEngine, SamplingParams
# # llm = LLM(model=cfg.actor_rollout_ref.model.path)
# from vllm.engine.async_llm_engine import AsyncLLMEngine
# from vllm.engine.arg_utils import AsyncEngineArgs
# from vllm.sampling_params import SamplingParams
# engine_args = AsyncEngineArgs(cfg.actor_rollout_ref.model.path)
# engine = AsyncLLMEngine.from_engine_args(engine_args)
from openai import AsyncOpenAI

client = AsyncOpenAI(base_url="http://0.0.0.0:8000/v1", api_key="NONE")

out = await client.completions.create(model="qwen/qwen2.5-7b-instruct", prompt=[14990, 1052], logprobs=1)
out.choices[0]

CompletionChoice(finish_reason='length', index=0, logprobs=Logprobs(text_offset=[0, 11, 23, 36, 48, 61, 75, 87, 99, 111, 123, 135, 147, 157, 169, 181], token_logprobs=[-1.098605751991272, -1.3268921375274658, -1.4334216117858887, -2.7917659282684326, -3.2859456539154053, -1.036252737045288, -0.940233051776886, -0.9396407604217529, -3.058459520339966, -0.017765210941433907, -1.8452544212341309, -4.111019134521484, -0.30675244331359863, -0.6755289435386658, -0.01313144899904728, -1.4898881912231445], tokens=['token_id:11', 'token_id:358', 'token_id:1079', 'token_id:264', 'token_id:2699', 'token_id:21815', 'token_id:911', 'token_id:279', 'token_id:990', 'token_id:315', 'token_id:330', 'token_id:333', 'token_id:1', 'token_id:323', 'token_id:330', 'token_id:1503'], top_logprobs=[{'token_id:11': -1.098605751991272}, {'token_id:358': -1.3268921375274658, 'token_id:600': -1.2643921375274658}, {'token_id:1079': -1.4334216117858887, 'token_id:614': -1.4334216117858887}, {'token_id:264': -2.79176

In [23]:
[int(t.replace('token_id:',"")) for t in out.choices[0].logprobs.tokens]

[11,
 358,
 1079,
 264,
 2699,
 21815,
 911,
 279,
 990,
 315,
 330,
 333,
 1,
 323,
 330,
 1503]

In [3]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('qwen/qwen2.5-7b-instruct')


In [8]:
tokenizer.apply_chat_template([
    {'role': 'user', 'content': "hello there"}
],)

[151644,
 8948,
 198,
 2610,
 525,
 1207,
 16948,
 11,
 3465,
 553,
 54364,
 14817,
 13,
 1446,
 525,
 264,
 10950,
 17847,
 13,
 151645,
 198,
 151644,
 872,
 198,
 14990,
 1052,
 151645,
 198]

In [8]:
tokenizer.encode("hello there")

[14990, 1052]

In [9]:
# async def generate_text(prompt_token_ids: list[int], sampling_params=None):
#     if sampling_params is None:
#         sampling_params = SamplingParams()
#     request_id = uuid.uuid1()
#     # print(prompt_token_ids)
#     async for output in engine.generate(dict(prompt_token_ids=prompt_token_ids), sampling_params=sampling_params, request_id=request_id):
#         # Process the streaming output
#         last_output = output
#         # print(output.outputs)
    
#     #, request_id=str(request_id)
#     # return llm.generate(dict(prompt_token_ids=prompt_token_ids), sampling_params=sampling_params)[0].outputs[0]
#     return last_output.outputs[0]
# # import asyncio
# out = await generate_text([14990, 1052])
# # , generate_text("hello there"), generate_text("hello there"), )


In [63]:
import asyncio
import logging
import os
from copy import deepcopy
from json import JSONDecodeError
from uuid import uuid4

import numpy as np
import torch
from omegaconf import DictConfig

from sglang.srt.openai_api.protocol import Tool
from sglang.srt.sampling.sampling_params import SamplingParams

from tensordict import TensorDict
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedTokenizer

from verl import DataProto
from verl.interactions.base import BaseInteraction
from verl.tools.base_tool import BaseTool
from verl.tools.schemas import OpenAIFunctionCallSchema, OpenAIFunctionParsedSchema, OpenAIFunctionToolCall 
from verl.tools.utils.tool_registry import initialize_tools_from_config
from verl.utils.debug import GPUMemoryLogger
from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length
from verl.workers.rollout.base import BaseRollout
from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, FinishReasonTypeEnum, AsyncRolloutRequestInterface, AsyncRolloutRequestMultiContext


from types import SimpleNamespace

from verl.utils import hf_tokenizer, hf_processor
from verl.utils.fs import copy_to_local
from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
from verl import DataProto
from torchdata.stateful_dataloader import StatefulDataLoader

try:
    from sglang.srt.function_call.function_call_parser import FunctionCallParser
except ImportError:
    from sglang.srt.function_call_parser import FunctionCallParser


logger = logging.getLogger()
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


# NOTE(sgm): add for verl. We can optimize it by making
#  the dataloader yield List[int] without padding.
def _pre_process_inputs(
    pad_token_id,
    prompt_token_ids: torch.Tensor,
) -> list[int]:
    # remove the left padding in the prompt token_id
    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
    token_ids = prompt_token_ids[non_pad_index:].tolist()
    return token_ids


# NOTE(linjunrong): adhoc
def _post_process_outputs(tokenizer, output):
    def _map_each_response(resp):
        output_token_logprobs = resp["meta_info"]["output_token_logprobs"]
        log_probs, output_token_ids = zip(*[(log_prob, token_ids) for log_prob, token_ids, _ in output_token_logprobs])
        return torch.tensor(output_token_ids), torch.tensor(log_probs)

    out_map = map(lambda x: _map_each_response(x), output)
    batched_output_token_ids = []
    batched_logprobs = []
    for output_token_ids, log_probs in out_map:
        batched_output_token_ids.append(output_token_ids)
        batched_logprobs.append(log_probs)
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    batched_output_token_ids = pad_sequence(batched_output_token_ids, batch_first=True, padding_value=pad_token_id)
    if len(batched_logprobs) > 0:
        batched_logprobs = pad_sequence(batched_logprobs, batch_first=True, padding_value=pad_token_id)
    return batched_output_token_ids, batched_logprobs


def get_tool_call_parser_type(tokenizer: PreTrainedTokenizer) -> str:
    items = FunctionCallParser.ToolCallParserEnum.items()
    for parser_type, parser_cls in items:
        parser = parser_cls()
        if parser.bot_token.strip() in tokenizer.get_vocab() and (parser.eot_token == "" or parser.eot_token.strip() in tokenizer.get_vocab()):
            return parser_type
    else:
        raise ValueError(f"No tool call parser found for tokenizer {tokenizer}")


class SGLangRollout(BaseRollout):
    def __init__(
        self,
        config: DictConfig,
        tokenizer,
        model_hf_config,
        **kwargs,
    ):
        """Synchronized SGLang rollout engine.

        Args:
            actor_module: Huggingface model name or path to the model. The
                model should be supported by SGLang.
            config: A DictConfig object containing SGLang-specific operational
                parameters and rollout settings.
                Refer to https://docs.sglang.ai/backend/server_arguments.html
            tokenizer: The tokenizer instance compatible with the actor_module.
            model_hf_config: The Hugging Face model's configuration (e.g.,
                `transformers.PretrainedConfig`). It provides architectural
                details and hyperparameters like `max_position_embeddings`,
                used by SGLang for correct model initialization. This is
                the model's inherent design, not SGLang's runtime behavior.
            port: Optional port for multi-node initialization when nnodes > 1.
            trust_remote_code: Whether or not to allow for custom models
                defined on the Hub in their own modeling files.
            device_mesh: Optional `DeviceMesh` object for distributed setup.
            **kwargs: Additional keyword arguments, primarily `train_tp` for
                Megatron Backend integration to initialize hybrid engine
                process groups.
        """
        super().__init__()

        self.config = config
        
        os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true")

        (
            self._tool_schemas,
            self._tool_map,
            self._tool_call_parser_type,
            self._sgl_tools,
            self._function_call_parser,
        ) = self._initialize_tools(config, tokenizer)
        self.interaction: dict[str, BaseInteraction] = self._intitalize_interaction(config)
        # If turn on `free_cache_engine`, SGLang engine's KV cache
        # will be freed after each `generate_sequences` call.
        assert not (not config.enforce_eager and config.free_cache_engine), "disable CUDA graph (enforce_eager = False) if free cache engine"

        logger.info(f"tool_schemas: {self._tool_schemas}, tool_map: {self._tool_map}, tool_call_parser_type: {self._tool_call_parser_type}, sgl_tools: {self._sgl_tools}, function_call_parser: {self._function_call_parser}")

        self._verify_config(model_hf_config)
        self._init_sampling_params(**kwargs)

        self.tokenizer = tokenizer
        self.pad_token_id = tokenizer.pad_token_id


    def _verify_config(self, model_hf_config):
        if not self.config.get("max_model_len", None):
            self.config.max_model_len = self.config.prompt_length + self.config.response_length
        # jakob this doesn't seem necessary, and we actually want a setting where (max_model_len == prompt_length == response_length). So we can restrict the length of the context with max model len.
        # assert self.config.max_model_len >= self.config.prompt_length + self.config.response_length, f"""max_model_len should be greater than total sequence length (prompt_length + response_length): 
        #     {self.config.max_model_len} >= {self.config.prompt_length} + {self.config.response_length}"""
        # assert model_hf_config.max_position_embeddings >= self.config.max_model_len, "model context length should be greater than total sequence length"
        # currently max_assistant_turns stand for max number of tool calls
        if self.config.multi_turn.max_assistant_turns is None:
            self.config.multi_turn.max_assistant_turns = self.config.max_model_len // 3
        if self.config.multi_turn.max_user_turns is None:
            self.config.multi_turn.max_user_turns = self.config.max_model_len // 3


    def _init_sampling_params(self, **kwargs):
        kwargs = dict(
            n=1,
            max_new_tokens=self.config.response_length,
            presence_penalty=0.0,
            frequency_penalty=0.0,
            repetition_penalty=1.0,
            skip_special_tokens=False, # keep tokenization consistent btwn txt and ids
        )
        # supporting adding any sampling params from the config file
        for k in self.config.keys():
            if hasattr(SamplingParams(), str(k)):
                kwargs[k] = self.config.get(k)
        # Jakob special sampling param codes for ease
        if not self.config.get("is_instruct_model", True):
            kwargs["no_stop_trim"] = True
            kwargs["stop"] = ["</action>","</belief>"] # "</answer>" looks potentially unnecessary?
        self.sampling_params = kwargs

    def _initialize_tools(self, config, tokenizer):
        """Initialize tools from configuration.
        Args:
            config: Configuration object containing tool-related settings,
                    specifically `config.multi_turn.tool_config_path`.
            tokenizer: The tokenizer instance used for parsing tool calls from
                       the model's generated text.

        Returns:
            tuple: A tuple containing:
                - tool_schemas (list[dict]): OpenAI-formatted JSON schemas
                  defining each tool's capabilities.
                - tool_map (dict[str, BaseTool]): A dictionary mapping tool
                  names to their executable `BaseTool` objects.
                - tool_call_parser_type (str): The identifier for the specific
                  parser type (e.g., 'json_mode', 'tool_code') used to extract
                  tool calls.
                - sgl_tools (list[sglang.srt.openai_api.protocol.Tool]): Tool
                  definitions optimized for SGLang's internal engine.
                - function_call_parser (sglang.srt.function_call_parser.FunctionCallParser):
                  The active parser instance responsible for extracting
                  structured tool calls from model outputs.
        """
        if config.multi_turn.tool_config_path is None:
            return [], {}, None, [], None

        tools_config_file = config.multi_turn.tool_config_path
        tool_list = initialize_tools_from_config(tools_config_file)

        tool_schemas = [tool.get_openai_tool_schema().model_dump() for tool in tool_list]
        tool_map = {tool.name: tool for tool in tool_list}
        tool_call_parser_type = get_tool_call_parser_type(tokenizer)
        sgl_tools = [Tool.model_validate(tool_schema) for tool_schema in tool_schemas]
        function_call_parser = FunctionCallParser(
            sgl_tools,
            tool_call_parser_type,
        )

        return (
            tool_schemas,
            tool_map,
            tool_call_parser_type,
            sgl_tools,
            function_call_parser,
        )

    def _intitalize_interaction(self, config):
        import importlib.util
        import sys

        from omegaconf import OmegaConf

        if config.multi_turn.interaction_config_path is None:
            return None
        interaction_config_file = config.multi_turn.interaction_config_path
        interaction_config = OmegaConf.load(interaction_config_file).interaction[0]
        cls_name = interaction_config.class_name
        module_name, class_name = cls_name.rsplit(".", 1)
        if module_name not in sys.modules:
            spec = importlib.util.find_spec(module_name)
            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)
        else:
            module = sys.modules[module_name]

        interaction_cls = getattr(module, class_name)

        interaction = interaction_cls(config=OmegaConf.to_container(interaction_config.config, resolve=True))
        return interaction

    async def _async_rollout_a_request(
        self,
        req: AsyncRolloutRequestInterface,
        do_sample: bool = True,
        is_validate: bool = False,
        **kwargs,
    ) -> AsyncRolloutRequestInterface:
        _req = deepcopy(req)

        finish_reason_type = None
        output = None

        current_turns = 0
        user_turns = 0
        user_turn_rewards = []

        # Create request-level sampling parameters
        request_sampling_params = self.sampling_params.copy()
        if not do_sample:
            request_sampling_params.update(
                {
                    "n": 1,
                    "presence_penalty": 0.0,
                    "frequency_penalty": 0.0,
                    "repetition_penalty": 1.0,
                    "temperature": 0,
                    "top_p": 1,
                    "top_k": -1,
                    "ignore_eos": False,
                    "min_new_tokens": 0,
                    "max_new_tokens": self.config.max_new_tokens,
                    "skip_special_tokens": True,
                    "spaces_between_special_tokens": True,
                }
            )
        elif is_validate:
            request_sampling_params.update(
                {
                    "top_k": self.config.val_kwargs.top_k,
                    "top_p": self.config.val_kwargs.top_p,
                    "temperature": self.config.val_kwargs.temperature,
                    "n": 1,  # if validate, already repeat in ray_trainer
                }
            )

        # Update with any additional kwargs
        request_sampling_params.update(kwargs)
        # run success, completion, and # attempts
        # number of tokens in the assistant messages
        tokens_per_action_message = []
        prompt_tokens_per_action_message = []
        prompt_tokens_per_belief_message = []
        tokens_per_belief_generation_message = []
        tokens_per_belief_state_message = []

        run_completion = False
        belief_gen_failures = 0
        
        
        while current_turns < self.config.multi_turn.max_assistant_turns:
            if _req.state == AsyncRolloutRequestStateEnum.PENDING:
                await self._handle_pending_state(_req)
                _req.state = AsyncRolloutRequestStateEnum.RUNNING
            elif _req.state == AsyncRolloutRequestStateEnum.TOOL_CALLING:
                if _req.get_last_msg().tool_calls is not None:
                    parsed_tool_calls = _req.get_last_msg().tool_calls
                    tool_call_results = await asyncio.gather(
                        *[
                            self._tool_map[tool_call.function.name].execute(
                                _req.request_id,
                                tool_call.function.arguments,
                                **_req.tools_kwargs[tool_call.function.name].get("execute_kwargs", {}),
                            )
                            for tool_call in parsed_tool_calls
                        ]
                    )
                    _req.add_tool_response_messages(self.tokenizer, [resp for resp, _, _ in tool_call_results])
                    for tool_call, (resp, reward, metrics) in zip(parsed_tool_calls, tool_call_results):
                        _req.update_metrics(metrics, tool_call.function.name)
                    if _req.get_next_input_ids_len() >= self.config.max_model_len:
                        finish_reason_type = FinishReasonTypeEnum.STOP
                        break
                    _req.state = AsyncRolloutRequestStateEnum.RUNNING
                else:
                    raise ValueError(f"Unexpected tool calling last message state: {_req.get_last_msg()}")
            elif _req.state == AsyncRolloutRequestStateEnum.RUNNING:
                # Only continue the conversation if the prompt length is not greater than max_model_len - 1,
                # since SGLang raises an error when max_new_tokens + 1 is greater to max_model_len (the extra token accounts for the EOS token).
                # every time before generating an action we check if belief state should be calculated.
                # this will have to call handle engine call for belief generation,
                if _req.should_generate_belief():
                    successful_context_gen = _req.pre_generate_belief_call(self.tokenizer)
                    if not successful_context_gen: # doesn't create a new context
                        finish_reason_type = FinishReasonTypeEnum.LENGTH
                        break
                    belief_generated = False # you get as many as will fit in context
                    while not belief_generated:
                        if len(_req.get_generation_prompt_ids(self.tokenizer)) + 1 >= self.config.max_model_len:
                            finish_reason_type = FinishReasonTypeEnum.LENGTH
                            break
                        if _req.should_engine_call_belief_generation():
                            output = await self._handle_engine_call(_req, request_sampling_params) # belief generation.
                            content = output["text"]
                            output_ids = [t[1] for t in output['meta_info']['output_token_logprobs']]
                            _req.add_assistant_message(self.tokenizer, content, output_ids)
                        else:
                            content = "<belief>" + self.interaction.get_mdp(_req.request_id).generate_posterior_str() + "</belief>"
                            output = {"text": content,
                                      # this will refer to the tokens used for a forward pass when generating the belief state, which here shouldn't be anything
                                      "meta_info": {"prompt_tokens": 0,
                                                    "completion_tokens": 0},
                                      }
                            output_ids = self.tokenizer.encode()
                            # changed this so that no policy gradient is executed on these tokens, because the model didn't generate these.
                            _req.add_assistant_message(self.tokenizer, content, output_ids, loss_mask=False) 

                        if 151644 in output_ids: # big issue with generation should terminate and -1 this trajectory.
                            finish_reason_type = FinishReasonTypeEnum.STOP # don't know if I want to count these failures yet, but they def need to be stopped.
                            break
                        if _req.get_next_input_ids_len() >= self.config.max_model_len:
                            finish_reason_type = FinishReasonTypeEnum.LENGTH
                            break
                        if _req.is_belief_valid(content):
                            belief_generated = True
                            prompt_tokens_per_belief_message += [output['meta_info']['prompt_tokens']]
                            tokens_per_belief_generation_message += [output['meta_info']['completion_tokens']] # only record the last one. we dont care redo cost, it seems to ollow directions eventually anyway.
                            tokens_per_belief_state_message += [len(self.tokenizer.encode(_req.extract_belief(content), add_special_tokens=False))]
                            break
                        else:
                            belief_gen_failures += 1
                            if belief_gen_failures > 4:
                                finish_reason_type = FinishReasonTypeEnum.STOP
                                break
                            _req.add_user_message(self.tokenizer, _req.get_belief_generation_failure_msg())
                    if not belief_generated:
                        break
                    # it is possible to have a belief state call which requires tools, we will not support this for now. 
                    # and store the belief in one context with gradient info, (belief generation context)
                    # and then separately in another context without gradient info (action context)
                    successful_context_gen = _req.post_generate_belief_call(self.tokenizer)
                    if not successful_context_gen: # doesn't create a new context
                        finish_reason_type = FinishReasonTypeEnum.LENGTH
                        break
                if len(_req.get_generation_prompt_ids(self.tokenizer)) + 1 >= self.config.max_model_len:
                    finish_reason_type = FinishReasonTypeEnum.LENGTH
                    break
                output = await self._handle_engine_call(_req, request_sampling_params) # action
                # print('output', output) # this for seeing if I can compute the number of generated tokens easily.
                prompt_tokens_per_action_message += [output['meta_info']['prompt_tokens']]
                tokens_per_action_message += [output['meta_info']['completion_tokens']] # this is characters right now, but want to make it tokens eventually.
                content = output["text"]
                output_ids = [t[1] for t in output['meta_info']['output_token_logprobs']]
                finish_reason_type = FinishReasonTypeEnum.from_str(output["meta_info"]["finish_reason"]["type"])
                _req.increment_turn()
                current_turns += 1
                if 151644 in output_ids:
                    _req.add_assistant_message(self.tokenizer, content, output_ids)
                    finish_reason_type = FinishReasonTypeEnum.STOP
                    break
                if finish_reason_type == FinishReasonTypeEnum.LENGTH:
                    _req.add_assistant_message(self.tokenizer, content, output_ids)
                    break
                else:
                    if self._function_call_parser and self._function_call_parser.has_tool_call(content):
                        assert False, "Tools no longer supported fix the add assistant message function to reintroduce tool support. -jakob"
                        finish_reason_type = FinishReasonTypeEnum.TOOL_CALL
                        _req.state = AsyncRolloutRequestStateEnum.TOOL_CALLING
                        try:
                            normed_content, tool_calls = self._function_call_parser.parse_non_stream(content)
                        except JSONDecodeError:
                            normed_content = content
                            tool_calls = []
                        except AttributeError:
                            normed_content = content
                            tool_calls = []
                        parsed_tool_calls = []
                        for tool_call in tool_calls:
                            function, has_decode_error = OpenAIFunctionCallSchema.from_openai_function_parsed_schema(
                                OpenAIFunctionParsedSchema(
                                    name=tool_call.name,
                                    arguments=tool_call.parameters,
                                )
                            )
                            # Drop the tool call if its arguments has decode error
                            if has_decode_error:
                                continue
                            parsed_tool_calls.append(
                                OpenAIFunctionToolCall(
                                    id=str(tool_call.tool_index),
                                    function=function,
                                )
                            )
                        if len(parsed_tool_calls) > 0:
                            _req.add_assistant_message(self.tokenizer, normed_content, tool_calls=parsed_tool_calls)
                        else:
                            _req.add_assistant_message(self.tokenizer, content)
                            finish_reason_type = FinishReasonTypeEnum.STOP
                            _req.state = AsyncRolloutRequestStateEnum.COMPLETED
                            break
                    else:
                        _req.add_assistant_message(
                            self.tokenizer,
                            content,
                            output_ids,
                        )
                        if _req.interaction_kwargs and user_turns < self.config.multi_turn.max_user_turns and current_turns < self.config.multi_turn.max_assistant_turns:
                            _req.state = AsyncRolloutRequestStateEnum.INTERACTING
                        else:
                            break
            elif _req.state == AsyncRolloutRequestStateEnum.INTERACTING:
                user_turns += 1
                messages = [{"role": x.role, "content": x.content} for x in _req.get_last_msgs()]
                should_terminate_sequence, content, reward, metrics = await self.interaction.generate_response(_req.request_id, messages, **_req.interaction_kwargs)
                user_turn_rewards.append(reward)
                if should_terminate_sequence:
                    run_completion = True
                    finish_reason_type = FinishReasonTypeEnum.STOP
                    _req.state = AsyncRolloutRequestStateEnum.COMPLETED
                    break
                else:
                    _req.add_user_message(self.tokenizer, content)
                    if _req.get_next_input_ids_len() >= self.config.max_model_len:
                        finish_reason_type = FinishReasonTypeEnum.STOP
                        break
                    else:
                        _req.state = AsyncRolloutRequestStateEnum.RUNNING

        if current_turns >= self.config.multi_turn.max_assistant_turns:
            finish_reason_type = FinishReasonTypeEnum.STOP

        # Calculate the reward for each tool
        async def calc_reward_and_release_fn(name: str, tool: BaseTool):
            reward = await tool.calc_reward(_req.request_id, **_req.tools_kwargs[name].get("calc_reward_kwargs", {}))
            await tool.release(_req.request_id, **_req.tools_kwargs[name].get("release_kwargs", {}))
            return name, reward

        tool_reward_tasks = []
        for name in _req.tools_kwargs.keys():
            tool = self._tool_map[name]
            tool_reward_tasks.append(calc_reward_and_release_fn(name, tool))
        tool_reward_scores = await asyncio.gather(*tool_reward_tasks)
        tool_reward_scores = dict(tool_reward_scores)
        if self.interaction is not None:
            mdp_reward = await self.interaction.calculate_score(_req.request_id)

            all_rewards = {"interaction_reward": [(mdp_reward - self.interaction.get_format_penalty_coef(_req.request_id) * (belief_gen_failures + self.interaction.get_trajectory_info(_req.request_id)['invalid_format_errors']))]} 
            run_success_flag = mdp_reward > 0
        else:
            run_success_flag = 0.42 # this isn't properly defined, so make a weird number in case I see it.
            all_rewards = {**tool_reward_scores, **{"user_turn_rewards": user_turn_rewards}}
        # all rewards other than this don't matter anyway for combo lock environment, with GRPO especially. Outcome rewards are it.
        # this is very specific to combo lock with the defined rewards of anything above zero meaning that the correct thing was eventually guessed.
        _req.to_log_stats = {"trajectory_info": self.interaction.get_trajectory_info(_req.request_id), 
                             "prompt_tokens_per_belief_message": prompt_tokens_per_belief_message,
                             "prompt_tokens_per_action_message": prompt_tokens_per_action_message,
                             "tokens_per_action_message": tokens_per_action_message, 
                             "tokens_per_belief_generation_message": tokens_per_belief_generation_message,
                             "tokens_per_belief_state_message": tokens_per_belief_state_message,
                             "run_success": run_success_flag, 
                             "run_attempts": self.interaction.get_attempts(_req.request_id), 
                             "run_completion": run_completion,
                             "belief_gen_failures": belief_gen_failures}
        _req.finalize(self.tokenizer, all_rewards, finish_reason_type)

        return _req

    async def _handle_engine_call(self, _req: AsyncRolloutRequestInterface, sampling_params: dict) -> dict:
        generation_prompt_ids = _req.get_generation_prompt_ids(self.tokenizer)
        max_new_tokens = min(self.config.max_new_tokens, self.config.response_length, self.config.max_model_len - len(generation_prompt_ids) - 1)
        kwargs = sampling_params.copy()
        kwargs["max_new_tokens"] = max_new_tokens
        kwargs["n"] = 1  # group size is supported in preprocess
        # for sglang
        # output = await llm.async_generate(input_ids=generation_prompt_ids,
        #     sampling_params=kwargs,
        #     return_logprob=True) 
        
        # for not vllm
        # output = await generate_text(prompt_token_ids=generation_prompt_ids)
        # kwargs['max_tokens'] = len(generation_prompt_ids) + kwargs.pop('max_new_tokens')
        output = await client.completions.create(model="qwen/qwen2.5-7b-instruct", prompt=generation_prompt_ids, max_tokens=len(generation_prompt_ids) + kwargs.pop('max_new_tokens'),logprobs=1, extra_body=kwargs)
        output = output.choices[0]
        try:
            output = {"text": output.text,
                    "meta_info": {"prompt_tokens": 0,
                                    "completion_tokens":0,
                                    "output_token_logprobs":[(0, int(t.replace('token_id:',"")), 0) for t in output.logprobs.tokens],
                                    "finish_reason": {'type':output.finish_reason}
                                    }
                    }
        except:
            print(output.logprobs.tokens)
        return output
    @GPUMemoryLogger(role="sglang rollout", logger=logger)    
    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
        if self.config.multi_turn.enable:
            return self._req_level_generate_sequences(prompts, **kwargs)
        return None
    
    @GPUMemoryLogger(role="sglang rollout", logger=logger)
    @torch.no_grad()
    def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
        # Async rollout with tools support
        do_sample = prompts.meta_info.get("do_sample", True)
        is_validate = prompts.meta_info.get("validate", False)
        tgt_device = prompts.batch["input_ids"].device
        
        req_list = self._preprocess_prompt_to_async_rollout_requests(
            prompts,
            n=1 if is_validate else self.config.n,
        )
        loop = asyncio.get_event_loop()
        output_req_list = loop.run_until_complete(
            asyncio.gather(
                *[self._async_rollout_a_request(req, do_sample, is_validate, **kwargs) for req in req_list],
            )
        )
        sorted_output_req_list = sorted(output_req_list, key=lambda x: (x.batch_data_id, x.rollout_offset))
        
        # Construct the batch data
        prompt_ids, response_ids = [], []
        prompt_attention_mask, response_attention_mask = [], []
        prompt_position_ids, response_position_ids = [], []
        prompt_loss_mask, response_loss_mask = [], []
        messages = []
        reward_scores = []
        request_ids = []
        context_indices = []
        to_log_stats = []



        # the code is messy, but eventually we plan to remove the AsyncRolloutRequest and replace with AsyncRolloutRequestMultiContext with single context.

        new_sorted_output_req_list = []
        if self.config.multi_turn.multi_context.enable:
            for req in sorted_output_req_list:
                for idx in range(len(req.input_ids)): # number of contexts
                    new_sorted_output_req_list.append(SimpleNamespace(**dict(
                                                                    request_id = req.request_id,
                                                                    context_index = idx,
                                                                    state = req.state,
                                                                    reward_scores = req.reward_scores,
                                                                    to_log_stats = req.to_log_stats,
                                                                    max_model_len = req.max_model_len,
                                                                    input_ids = req.input_ids[idx],
                                                                    attention_mask = req.attention_mask[idx],
                                                                    position_ids = req.position_ids[idx],
                                                                    prompt_ids = req.prompt_ids[idx],
                                                                    loss_mask = req.loss_mask[idx],
                                                                    response_ids = req.response_ids[idx],
                                                                    prompt_attention_mask = req.prompt_attention_mask[idx],
                                                                    response_attention_mask = req.response_attention_mask[idx],
                                                                    prompt_position_ids = req.prompt_position_ids[idx],
                                                                    response_position_ids = req.response_position_ids[idx],
                                                                    prompt_loss_mask = req.prompt_loss_mask[idx],
                                                                    response_loss_mask = req.response_loss_mask[idx],
                                                                    messages = req.messages[idx],
                                                                    )))
            sorted_output_req_list = new_sorted_output_req_list

        for req in sorted_output_req_list:
            assert req.state == AsyncRolloutRequestStateEnum.COMPLETED, f"Request {req.request_id} is not completed"
            assert len(req.input_ids) == len(req.attention_mask) == len(req.position_ids) == len(req.loss_mask), f"""Request {req.request_id} has different length of 
                {len(req.input_ids)=}, {len(req.attention_mask)=}, {len(req.position_ids)=}, {len(req.loss_mask)=}"""
            error_message_lines = [
                f"""Request {req.request_id} has input_ids length {len(req.input_ids)}
                    greater than max_model_len {self.config.max_model_len}""",
                f"Decoded input_ids: {self.tokenizer.decode(req.input_ids)}",
                f"Decoded prompt_ids: {self.tokenizer.decode(req.prompt_ids)}",
                f"Decoded response_ids: {self.tokenizer.decode(req.response_ids)}",
                f"Messages: {req.messages}",
                f"Max model length: {req.max_model_len}",
            ]
            error_message = "\n".join(error_message_lines)
            assert len(req.input_ids) <= self.config.max_model_len, error_message

            prompt_ids.append(torch.tensor(req.prompt_ids, dtype=torch.int, device=tgt_device))
            response_ids.append(torch.tensor(req.response_ids, dtype=torch.int, device=tgt_device))
            if len(req.response_ids) > self.config.response_length:
                logger.warning(
                    f"""{req.request_id=} has response_ids length {len(req.response_ids)} 
                    greater than max_response_len {self.config.response_length},\n{req=}"""
                )
            prompt_attention_mask.append(torch.tensor(req.prompt_attention_mask, dtype=torch.int, device=tgt_device))
            response_attention_mask.append(torch.tensor(req.response_attention_mask, dtype=torch.int, device=tgt_device))
            prompt_position_ids.append(torch.tensor(req.prompt_position_ids, dtype=torch.int, device=tgt_device))
            response_position_ids.append(torch.tensor(req.response_position_ids, dtype=torch.int, device=tgt_device))
            prompt_loss_mask.append(torch.tensor(req.prompt_loss_mask, dtype=torch.int, device=tgt_device))
            response_loss_mask.append(torch.tensor(req.response_loss_mask, dtype=torch.int, device=tgt_device))
            messages.append({"messages": req.messages})
            reward_scores.append(req.reward_scores)
            request_ids.append(req.request_id)
            context_indices.append(req.context_index)
            to_log_stats.append(req.to_log_stats)

        prompt_ids = pad_sequence(
            prompt_ids,
            batch_first=True,
            padding_value=self.pad_token_id,
            padding_side="left",
        )
        if prompt_ids.shape[1] < self.config.prompt_length:
            prompt_ids = pad_sequence_to_length(prompt_ids, self.config.prompt_length, self.pad_token_id, left_pad=True)
        response_ids = pad_sequence(response_ids, batch_first=True, padding_value=self.pad_token_id)
        if response_ids.shape[1] < self.config.response_length:
            response_ids = pad_sequence_to_length(response_ids, self.config.response_length, self.pad_token_id)
        prompt_attention_mask = pad_sequence(
            prompt_attention_mask,
            batch_first=True,
            padding_value=0,
            padding_side="left",
        )
        if prompt_attention_mask.shape[1] < self.config.prompt_length:
            prompt_attention_mask = pad_sequence_to_length(prompt_attention_mask, self.config.prompt_length, 0, left_pad=True)
        response_attention_mask = pad_sequence(response_attention_mask, batch_first=True, padding_value=0)
        if response_attention_mask.shape[1] < self.config.response_length:
            response_attention_mask = pad_sequence_to_length(response_attention_mask, self.config.response_length, 0)
        prompt_position_ids = pad_sequence(prompt_position_ids, batch_first=True, padding_value=0, padding_side="left")
        if prompt_position_ids.shape[1] < self.config.prompt_length:
            prompt_position_ids = pad_sequence_to_length(prompt_position_ids, self.config.prompt_length, 0, left_pad=True)
        response_length = response_ids.size(1)
        delta_position_id = torch.arange(1, response_length + 1, device=response_ids.device)
        delta_position_id = delta_position_id.unsqueeze(0).repeat(len(sorted_output_req_list), 1)
        response_position_ids = prompt_position_ids[:, -1:] + delta_position_id
        prompt_loss_mask = pad_sequence(prompt_loss_mask, batch_first=True, padding_value=0, padding_side="left")
        if prompt_loss_mask.shape[1] < self.config.prompt_length:
            prompt_loss_mask = pad_sequence_to_length(prompt_loss_mask, self.config.prompt_length, 0, left_pad=True)
        response_loss_mask = pad_sequence(response_loss_mask, batch_first=True, padding_value=0)
        if response_loss_mask.shape[1] < self.config.response_length:
            response_loss_mask = pad_sequence_to_length(response_loss_mask, self.config.response_length, 0)

        input_ids = torch.cat((prompt_ids, response_ids), dim=-1)
        attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)
        position_ids = torch.cat((prompt_position_ids, response_position_ids), dim=-1)
        loss_mask = torch.cat((prompt_loss_mask, response_loss_mask), dim=-1)

        # Construct the batch data
        batch = TensorDict(
            {
                "prompts": prompt_ids,
                "responses": response_ids,
                "input_ids": input_ids,  # here input_ids become the whole sentences
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "loss_mask": loss_mask,
            },
            batch_size=len(sorted_output_req_list),
        )

        return DataProto(
            batch=batch,
            non_tensor_batch={
                "messages": np.array(messages),
                "reward_scores": np.array(reward_scores),
                "request_ids": np.array(request_ids),
                "context_indices": np.array(context_indices),
                "to_log_stats": np.array(to_log_stats),
            },
            # add item here for trajectory index?
        )

    def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: int) -> list[AsyncRolloutRequest]:
        assert "raw_prompt" in prompts.non_tensor_batch, "need data.return_raw_chat=True, due to no official way do parse_messages"
        req_list = []
        for data_idx, raw_prompt in enumerate(prompts.non_tensor_batch["raw_prompt"]):
            for rollout_offset in range(n):
                if self._tool_schemas:
                    _tools_kwargs = prompts.non_tensor_batch["tools_kwargs"][data_idx]
                    _tool_schemas = [self._tool_map[k].get_openai_tool_schema() for k in _tools_kwargs.keys()]
                    _input_ids = None
                    _attention_mask = None
                else:
                    _input_ids = _pre_process_inputs(self.pad_token_id, prompts.batch["input_ids"][data_idx])
                    _attention_mask = _pre_process_inputs(0, prompts.batch["attention_mask"][data_idx])
                    _tools_kwargs = {}
                    _tool_schemas = None

                if self.interaction is not None:
                    _interaction_kwargs = prompts.non_tensor_batch["interaction_kwargs"][data_idx]
                else:
                    _interaction_kwargs = {}

                if self.config.multi_turn.multi_context.enable:
                    req = AsyncRolloutRequestMultiContext(
                        batch_data_id=data_idx,
                        rollout_offset=rollout_offset,
                        request_id=str(uuid4()),
                        state=AsyncRolloutRequestStateEnum.PENDING,
                        messages=[raw_prompt.tolist()],
                        tool_schemas=_tool_schemas,
                        tools_kwargs=_tools_kwargs,
                        interaction_kwargs=_interaction_kwargs,
                        input_ids=[_input_ids],
                        response_ids=[],
                        attention_mask=[_attention_mask],
                        response_attention_mask=[],
                        response_position_ids=[],
                        response_loss_mask=[],
                        reward_scores={},
                        to_log_stats={},
                        max_prompt_len=self.config.prompt_length,
                        max_response_len=self.config.response_length,
                        max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length),
                        use_inference_chat_template=self.config.multi_turn.use_inference_chat_template,
                        force_thinking=self.config.multi_turn.force_thinking,
                        enable_tokenization_sanity_check=self.config.multi_turn.enable_tokenization_sanity_check,
                        belief_state_construction_style=self.config.multi_turn.multi_context.belief_state_construction_style,
                        single_context_belief_generation=self.config.multi_turn.multi_context.single_context_belief_generation,
                        lax_format_belief=self.config.multi_turn.lax_format,
                        tokenizer=self.tokenizer,
                    )

                else:
                    req = AsyncRolloutRequest(
                        batch_data_id=data_idx,
                        rollout_offset=rollout_offset,
                        request_id=str(uuid4()),
                        state=AsyncRolloutRequestStateEnum.PENDING,
                        messages=raw_prompt.tolist(),
                        tool_schemas=_tool_schemas,
                        tools_kwargs=_tools_kwargs,
                        interaction_kwargs=_interaction_kwargs,
                        input_ids=_input_ids,
                        response_ids=[],
                        attention_mask=_attention_mask,
                        response_attention_mask=[],
                        response_position_ids=[],
                        response_loss_mask=[],
                        reward_scores={},
                        to_log_stats={},
                        max_prompt_len=self.config.prompt_length,
                        max_response_len=self.config.response_length,
                        max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length),
                        use_inference_chat_template=self.config.multi_turn.use_inference_chat_template,
                        force_thinking=self.config.multi_turn.force_thinking,
                        enable_tokenization_sanity_check=self.config.multi_turn.enable_tokenization_sanity_check,
                        belief_state_construction_style=self.config.multi_turn.multi_context.belief_state_construction_style,
                        lax_format_belief=self.config.multi_turn.lax_format,
                        single_context_belief_generation=self.config.multi_turn.multi_context.single_context_belief_generation,
                        tokenizer=self.tokenizer,
                    )
                error_message = f"Request {req.request_id} has mismatched lengths: input_ids={len(req.input_ids)}, attention_mask={len(req.attention_mask)}, position_ids={len(req.position_ids)}, loss_mask={len(req.loss_mask)}"
                assert len(req.input_ids) == len(req.attention_mask) == len(req.position_ids) == len(req.loss_mask), error_message

                req_list.append(req)

        return req_list

    async def _handle_pending_state(self, _req: AsyncRolloutRequestInterface) -> None:
        if _req.tool_schemas is not None:
            tool_creation_coroutines = []
            for tool_schema in _req.tool_schemas:
                tool = self._tool_map[tool_schema.function.name]
                create_kwargs = _req.tools_kwargs[tool.name].get("create_kwargs", {})
                tool_creation_coroutines.append(tool.create(_req.request_id, **create_kwargs))
            await asyncio.gather(*tool_creation_coroutines)
        if _req.interaction_kwargs:
            interaction_kwargs = _req.interaction_kwargs | {"lax_format": self.config.multi_turn.lax_format, "format_penalty_coef": self.config.multi_turn.format_penalty_coef}
            await self.interaction.start_interaction(_req.request_id, **interaction_kwargs)

    async def wake_up(self):
        if not self.is_sleep:
            return
        await self.sharding_manager.wake_up()  # pylint: disable=C2801
        self.is_sleep = False

    # this function is left for uniform train-inference resharding
    async def sleep(self):
        if self.is_sleep:
            return
        await self.sharding_manager.sleep()
        self.is_sleep = True

local_path = copy_to_local(cfg.actor_rollout_ref.model.path, use_shm=cfg.actor_rollout_ref.model.get("use_shm", False))

tokenizer = hf_tokenizer(local_path, is_instruct_model=cfg.actor_rollout_ref.rollout.is_instruct_model, trust_remote_code=True)

rollout = SGLangRollout(cfg.actor_rollout_ref.rollout, tokenizer, None)

In [64]:
cfg.actor_rollout_ref.rollout.do_sample, rollout.config.temperature

(True, 1)

In [65]:
processor = hf_processor(local_path, trust_remote_code=True, use_fast=True)

dataset = create_rl_dataset(cfg.data.train_files, cfg.data, tokenizer, processor)

from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn

collate_fn = default_collate_fn


val_dataloader = StatefulDataLoader(
    dataset=dataset,
    batch_size=100,
    num_workers=0,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_fn,
)

for batch in val_dataloader:
    print(batch)
    test_batch = DataProto.from_single_dict(batch)
    break


Using dataset class: RLHFDataset
dataset len: 720
filter dataset len: 720
{'input_ids': tensor([[151643, 151643, 151643,  ..., 151644,  77091,    198],
        [151643, 151643, 151643,  ..., 151644,  77091,    198],
        [151643, 151643, 151643,  ..., 151644,  77091,    198],
        ...,
        [151643, 151643, 151643,  ..., 151644,  77091,    198],
        [151643, 151643, 151643,  ..., 151644,  77091,    198],
        [151643, 151643, 151643,  ..., 151644,  77091,    198]]), 'attention_mask': tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]]), 'position_ids': tensor([[  0,   0,   0,  ..., 259, 260, 261],
        [  0,   0,   0,  ..., 259, 260, 261],
        [  0,   0,   0,  ..., 259, 260, 261],
        ...,
        [  0,   0,   0,  ..., 259, 260, 261],
        [  0,   0,   0,  ..., 259, 260, 261],
        [  0,   0

In [68]:
from verl.protocol import pad_dataproto_to_divisor
for batch in val_dataloader:
    test_batch = DataProto.from_single_dict(batch)
    break
batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
if "multi_modal_data" in test_batch.non_tensor_batch:
    non_tensor_batch_keys_to_pop.append("multi_modal_data")
if "raw_prompt" in test_batch.non_tensor_batch:
    non_tensor_batch_keys_to_pop.append("raw_prompt")
if "tools_kwargs" in test_batch.non_tensor_batch:
    non_tensor_batch_keys_to_pop.append("tools_kwargs")
if "interaction_kwargs" in test_batch.non_tensor_batch:
    non_tensor_batch_keys_to_pop.append("interaction_kwargs")
print(f"non_tensor_batch_keys_to_pop: {non_tensor_batch_keys_to_pop}")
print(f"batch_keys_to_pop: {batch_keys_to_pop}")
test_gen_batch = test_batch.pop(
    batch_keys=batch_keys_to_pop,
    non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
)

test_gen_batch.meta_info = {
    "eos_token_id": rollout.tokenizer.eos_token_id,
    "pad_token_id": rollout.tokenizer.pad_token_id,
    "recompute_log_prob": False,
    "do_sample": True,
    "validate": False,
}

print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")

# pad to be divisible by dp_size
test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, 1)

test_output_gen_batch_padded = rollout.generate_sequences(test_gen_batch_padded)
# prompts = test_gen_batch_padded
test_output_gen_batch_padded


# do_sample = prompts.meta_info.get("do_sample", True)
# is_validate = prompts.meta_info.get("validate", False)
# tgt_device = prompts.batch["input_ids"].device
# req_list = rollout._preprocess_prompt_to_async_rollout_requests(
#     prompts,
#     n=1 if is_validate else rollout.config.n,
# )
# # loop = asyncio.get_event_loop()
# output_req_list = await asyncio.gather(
#         *[rollout._async_rollout_a_request(req, do_sample, is_validate) for req in req_list],
#     )
# sorted_output_req_list = sorted(output_req_list, key=lambda x: (x.batch_data_id, x.rollout_offset))

non_tensor_batch_keys_to_pop: ['raw_prompt_ids', 'raw_prompt', 'tools_kwargs', 'interaction_kwargs']
batch_keys_to_pop: ['input_ids', 'attention_mask', 'position_ids']
test_gen_batch meta info: {'eos_token_id': 151645, 'pad_token_id': 151643, 'recompute_log_prob': False, 'do_sample': True, 'validate': False}


DataProto(batch=TensorDict(
    fields={
        attention_mask: Tensor(shape=torch.Size([2074, 3072]), device=cpu, dtype=torch.int32, is_shared=False),
        input_ids: Tensor(shape=torch.Size([2074, 3072]), device=cpu, dtype=torch.int32, is_shared=False),
        loss_mask: Tensor(shape=torch.Size([2074, 3072]), device=cpu, dtype=torch.int32, is_shared=False),
        position_ids: Tensor(shape=torch.Size([2074, 3072]), device=cpu, dtype=torch.int64, is_shared=False),
        prompts: Tensor(shape=torch.Size([2074, 1024]), device=cpu, dtype=torch.int32, is_shared=False),
        responses: Tensor(shape=torch.Size([2074, 2048]), device=cpu, dtype=torch.int32, is_shared=False)},
    batch_size=torch.Size([2074]),
    device=None,
    is_shared=False), non_tensor_batch={'messages': array([{'messages': [Message(role='system', content="You will determine the correct combination of characters at [Position 1, Position 2, Position 3] in a 3-character combination lock through iterative reas

In [69]:
full_traj_responses = [""] * len(test_output_gen_batch_padded)
full_traj_response_text_map = dict()
for i, response_id in enumerate(test_output_gen_batch_padded.non_tensor_batch['request_ids']):
    if response_id in full_traj_response_text_map:
        full_traj_responses[i] = full_traj_response_text_map[response_id]
        continue
    indicies_of_traj = [j for j, res_id in enumerate(test_output_gen_batch_padded.non_tensor_batch['request_ids']) if res_id == response_id]
    indicies_of_traj.sort(key=lambda el: test_output_gen_batch_padded.non_tensor_batch['context_indices'][el])
    # print(indicies_of_traj)
    full_traj_responses[i] = "\n==========================\n".join((tokenizer.decode(test_output_gen_batch_padded.batch['input_ids'][index_of_traj][test_output_gen_batch_padded.batch['attention_mask'][index_of_traj] == 1], skip_special_tokens=False) for j, index_of_traj in enumerate(indicies_of_traj)))
    full_traj_response_text_map[response_id] = full_traj_responses[i]
test_output_gen_batch_padded.non_tensor_batch['full_traj_response'] = np.array(full_traj_responses)

starting_index_mapping = dict()
for i, response_id in enumerate(test_output_gen_batch_padded.non_tensor_batch['request_ids']):
    if response_id in starting_index_mapping:
        continue
    starting_index_mapping[response_id] = sorted([j for j, res_id in enumerate(test_output_gen_batch_padded.non_tensor_batch['request_ids']) if res_id == response_id], key=lambda el: test_output_gen_batch_padded.non_tensor_batch['context_indices'][el])[0]

starting_index_of_each_traj = list(starting_index_mapping.values())

In [70]:
test_output_gen_batch_padded.non_tensor_batch['to_log_stats'][0]

{'trajectory_info': {'repeated_guesses': 7,
  'feedback_hist': [('012', [2, 0, 0]),
   ('034', [2, 0, 0]),
   ('055', [2, 0, 0]),
   ('067', [2, 0, 1]),
   ('078', [2, 2, 2])],
  'invalid_format_errors': 0},
 'prompt_tokens_per_belief_message': [0, 0, 0, 0],
 'prompt_tokens_per_action_message': [0, 0, 0, 0, 0],
 'tokens_per_action_message': [0, 0, 0, 0, 0],
 'tokens_per_belief_generation_message': [0, 0, 0, 0],
 'tokens_per_belief_state_message': [63, 51, 39, 33],
 'run_success': True,
 'run_attempts': 5,
 'run_completion': True,
 'belief_gen_failures': 0}

In [71]:
print(np.mean([stats['run_success'] for stats in test_output_gen_batch_padded.non_tensor_batch['to_log_stats'][starting_index_of_each_traj]]))
print(np.mean([stats['trajectory_info']['invalid_format_errors'] for stats in test_output_gen_batch_padded.non_tensor_batch['to_log_stats'][starting_index_of_each_traj]]))
print(np.mean([stats['belief_gen_failures'] for stats in test_output_gen_batch_padded.non_tensor_batch['to_log_stats'][starting_index_of_each_traj]]))

0.58
2.36
0.14


In [None]:
# with defaults enabled, (but I over wrote a lot) I don't think this happens with verl-agent.
# 0.69
# 2.12
# 0.11
# with no defaults, These are slightly higher than with sglang, but not as high as we are getting with the verl-agent implementation. I wonder why???
# could the difference cause the performance gap we are seeing? 
# Like potentially some overlooked environment difference?
# 0.58
# 2.36
# 0.14

In [26]:
np.array([res.count('<action>invalid action</action>\nEnvironment feedback:\nCould not parse') for res in set(full_traj_responses)]).mean()

np.float64(0.0)

In [25]:
print(list(set(full_traj_responses))[0])

<|im_start|>system
You will determine the correct combination of characters at [Position 1, Position 2, Position 3] in a 3-character combination lock through iterative reasoning and queries.
All 3 characters are unique.
The set of valid characters are as follows: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
Each action is a query of the form ['char 1', 'char 2', 'char 3'].
Each time you query a combination, you will get feedback from the user about each character: either not in the combination, in the combination but in a different position, or in the combination and in the right position.
You can make up to 12 queries.
Your goal is to find the correct combination in the least number of queries.<|im_end|>
<|im_start|>user
Give me your first query formatted as a list of 3 characters inside <action> ... </action> after thinking inside <think> ... </think>, e.g., <think> Let's think step by step before giving the query [your extensive thinking] </think> <action>['char 1', 'char 2', 

In [None]:
import pandas as pd

# --- 1)  Build a tidy DataFrame --------------------------------------------
ntb = test_output_gen_batch_padded.non_tensor_batch     # shorthand
df = pd.DataFrame({
    "request_id":  ntb["request_ids"],
    # keep reward ready to use; pull the element you need right now
    "reward_score": [r["interaction_reward"][0] for r in ntb["reward_scores"]],
})
df = df.join(pd.DataFrame(ntb['to_log_stats'].tolist()))
df = df.join(df.apply(lambda x: (x['trajectory_info']), axis=1, result_type="expand"))
del df['trajectory_info']
# --- 2)  How many rows belong to each request_id? ---------------------------
request_counts = (
    df.groupby("request_id", sort=False)      # keep original order
      .size()                                 # count rows per id
      .rename("num_rows")
)

# --- 3)  Indices of the *last* row for every trajectory ---------------------
last_row_idx = (
    df.groupby("request_id", sort=False)
      .tail(1)                                # keep the last row of each group
      .index
)
traj_indices = last_row_idx.to_numpy()

display(df.loc[traj_indices])#.feedback_hist.values
df.loc[traj_indices].feedback_hist.values.tolist()

Unnamed: 0,request_id,reward_score,prompt_tokens_per_belief_message,prompt_tokens_per_action_message,tokens_per_action_message,tokens_per_belief_generation_message,tokens_per_belief_state_message,run_success,run_attempts,run_completion,belief_gen_failures,repeated_guesses,feedback_hist,invalid_format_errors
26,b0a62d6d-b101-436a-988f-227b0ae55a38,-1.0,"[319, 367, 444, 458, 571, 456, 375, 381, 359, ...","[262, 302, 411, 411, 506, 416, 335, 341, 319, ...","[119, 68, 83, 127, 116, 107, 109, 95, 81, 113,...","[118, 544, 134, 310, 140, 59, 171, 292, 37, 73...","[18, 127, 128, 223, 133, 52, 58, 36, 30, 66, 6...",False,12,True,0,9,"[(013, [2, 0, 0]), (018, [2, 0, 2]), (804, [1,...",2
57,2a3f84cf-281c-4a18-b0f6-2c9d528bc4c2,-1.0,"[326, 405, 405, 371, 737, 353, 426, 612, 371, ...","[262, 365, 372, 331, 356, 320, 379, 349, 331, ...","[96, 74, 91, 99, 111, 111, 101, 73, 73, 63, 11...","[226, 265, 308, 80, 44, 398, 439, 243, 308, 27...","[81, 88, 47, 73, 37, 97, 66, 48, 75, 60, 115, ...",False,11,False,2,1,"[(601, [0, 0, 1]), (251, [0, 0, 1]), (429, [0,...",4
80,d94264b8-426f-4571-96c3-738ba2a96815,-1.0,"[319, 372, 1116, 1114, 623, 542, 538, 513, 560...","[262, 339, 1076, 1067, 590, 509, 491, 473, 527...","[111, 126, 93, 154, 93, 68, 86, 74, 81, 114, 6...","[151, 1476, 791, 909, 854, 360, 330, 340, 479,...","[57, 792, 783, 307, 225, 207, 189, 243, 397, 5...",False,12,True,0,2,"[(012, [0, 0, 2]), (190, [0, 0, 0]), (023, [0,...",0
103,6de5a79f-62df-4794-8ca0-11f1c347743e,-1.0,"[319, 408, 379, 370, 407, 420, 364, 349, 349, ...","[262, 368, 339, 330, 360, 380, 324, 316, 316, ...","[111, 154, 69, 64, 72, 153, 104, 104, 149, 119...","[160, 274, 53, 83, 103, 228, 40, 250, 260, 69,...","[84, 55, 46, 77, 97, 41, 33, 33, 33, 62, 212]",False,12,True,1,4,"[(208, [0, 0, 0]), (134, [0, 1, 0]), (156, [0,...",0
130,f0f2c244-8968-421b-9cfa-71c5e761777b,-1.0,"[319, 504, 488, 362, 342, 344, 376, 1072, 497,...","[262, 464, 455, 329, 302, 311, 311, 1040, 464,...","[94, 96, 96, 71, 132, 135, 93, 76, 95, 56, 48,...","[296, 178, 88, 196, 167, 176, 957, 251, 163, 2...","[180, 171, 45, 18, 27, 27, 756, 180, 95, 95, 9...",False,12,True,0,0,"[(048, [0, 0, 0]), (135, [0, 0, 1]), (124, [0,...",2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2373,3718f83e-0886-42d6-9790-d78a2fddbd7f,-1.0,"[319, 2234, 711, 578, 612, 668, 394, 344, 369,...","[262, 2201, 671, 545, 572, 635, 347, 311, 304,...","[112, 57, 92, 102, 68, 70, 93, 151, 104, 76, 1...","[2033, 394, 268, 436, 358, 199, 349, 341, 132,...","[1917, 387, 261, 288, 351, 65, 29, 22, 22, 23,...",False,11,False,0,7,"[(159, [0, 0, 0]), (075, [2, 0, 0]), (074, [2,...",4
2400,98ad4e3f-475e-405a-90d2-2115b22c0db6,-1.0,"[319, 408, 351, 418, 407, 368, 353, 356, 362, ...","[262, 343, 318, 353, 374, 335, 320, 323, 329, ...","[99, 103, 103, 87, 134, 90, 31, 55, 108, 87, 9...","[168, 155, 254, 198, 149, 43, 46, 144, 46, 34,...","[61, 35, 70, 91, 51, 36, 39, 45, 39, 27, 24, 2...",False,12,True,0,8,"[(012, [0, 0, 0]), (456, [2, 0, 0]), (419, [2,...",2
2429,68b3b0a2-44f6-4201-aa4c-c11e0a096143,-1.0,"[319, 528, 454, 434, 448, 401, 413, 409, 401, ...","[262, 463, 390, 390, 383, 366, 366, 374, 368, ...","[60, 94, 73, 121, 66, 116, 113, 109, 175, 164,...","[314, 194, 205, 106, 259, 170, 359, 92, 217, 4...","[181, 106, 106, 99, 82, 82, 90, 84, 44, 36, 21...",False,12,True,2,7,"[(147, [0, 0, 2]), (173, [0, 1, 0]), (107, [0,...",3
2452,c6ac4828-c57b-488d-93d5-d7ae2c2d89a9,-1.0,"[319, 338, 338, 360, 397, 434, 427, 427, 388, ...","[262, 305, 305, 327, 357, 387, 387, 387, 355, ...","[114, 64, 107, 137, 222, 119, 122, 130, 83, 13...","[73, 117, 50, 80, 110, 468, 399, 296, 89, 223,...","[21, 21, 43, 73, 103, 103, 103, 71, 82, 41, 19]",False,12,True,0,2,"[(012, [0, 0, 0]), (012, [0, 0, 0]), (652, [2,...",0


[[('013', [2, 0, 0]),
  ('018', [2, 0, 2]),
  ('804', [1, 1, 0]),
  ('208', [0, 1, 2]),
  ('208', [0, 1, 2]),
  ('108', [0, 1, 2]),
  ('409', [0, 1, 0]),
  ('108', [0, 1, 2]),
  ('518', [0, 0, 2]),
  ('698', [0, 0, 2]),
  ('205', [0, 1, 0]),
  ('125', [0, 0, 0])],
 [('601', [0, 0, 1]),
  ('251', [0, 0, 1]),
  ('429', [0, 0, 0]),
  ('537', [0, 1, 0]),
  ('725', [0, 0, 0]),
  ('358', [1, 0, 1]),
  ('091', [0, 0, 1]),
  ('146', [1, 0, 0]),
  ('278', [0, 0, 1]),
  ('259', [0, 0, 0]),
  ('017', [0, 2, 0])],
 [('012', [0, 0, 2]),
  ('190', [0, 0, 0]),
  ('023', [0, 1, 0]),
  ('254', [1, 1, 0]),
  ('342', [0, 0, 2]),
  ('034', [0, 0, 0]),
  ('253', [1, 1, 0]),
  ('123', [0, 1, 0]),
  ('460', [0, 0, 0]),
  ('174', [0, 0, 0]),
  ('074', [0, 0, 0]),
  ('158', [0, 1, 1])],
 [('208', [0, 0, 0]),
  ('134', [0, 1, 0]),
  ('156', [0, 1, 0]),
  ('635', [0, 1, 2]),
  ('935', [1, 1, 2]),
  ('035', [0, 1, 2]),
  ('135', [0, 1, 2]),
  ('027', [0, 0, 0]),
  ('024', [0, 0, 0]),
  ('024', [0, 0, 0]),
  ('024

In [None]:
(df.loc[traj_indices].reward_score !=-1).sum() # 7b instruct 52/100, 3b instruct 5/100

np.int64(5)

In [17]:
req_id_to_print = '7ed63cb1-938e-4ea0-8c08-3750bc261a8a'
display(df.loc[traj_indices][df.loc[traj_indices].request_id == req_id_to_print].feedback_hist.values.tolist())
to_print_indices = [i for i, req_id in enumerate(test_output_gen_batch_padded.non_tensor_batch["request_ids"]) if req_id == req_id_to_print]
for p in tokenizer.batch_decode(test_output_gen_batch_padded.batch['input_ids'][np.array(to_print_indices)], skip_special_tokens=True):
    print("-"*50)
    print(p)


[[('012', [0, 0, 0]),
  ('345', [2, 0, 2]),
  ('365', [2, 0, 2]),
  ('785', [0, 0, 2]),
  ('095', [0, 2, 2]),
  ('012', [0, 0, 0]),
  ('345', [2, 0, 2]),
  ('346', [2, 0, 0]),
  ('301', [2, 0, 0])]]

--------------------------------------------------
system
You will determine the correct combination of characters at [Position 1, Position 2, Position 3] in a 3-character combination lock through iterative reasoning and queries.
All 3 characters are unique.
The set of valid characters are as follows: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
Each action is a query of the form ['char 1', 'char 2', 'char 3'].
Each time you query a combination, you will get feedback from the user about each character: either not in the combination, in the combination but in a different position, or in the combination and in the right position.
You can make up to 12 queries.
Your goal is to find the correct combination in the least number of queries.
user
Think extensively inside <think> tags, then give me your query formatted as a list of 3 characters inside <action>['char 1', 'char 2', 'char 3']</action>.

assistant
<think>
To start, we should make a query that tests a variety of characters in 

In [11]:
sorted_stats = sorted(zip(test_output_gen_batch_padded.non_tensor_batch["to_log_stats"], test_output_gen_batch_padded.non_tensor_batch["request_ids"]), key=lambda x: x[1])

[i[0]["trajectory_info"]["invalid_format_errors"] for i in sorted_stats]

[0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0]

In [6]:
[i for i, stat in enumerate(test_output_gen_batch_padded.non_tensor_batch['to_log_stats']) if stat['run_success']]

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 60,
 61,
 62,
 63,
 64,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133,
 134,
 135,
 136,
 137,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 146,
 147,
 148,
 149,
 150,
 151,
 152,
 153,
 154,
 155,
 156,
 157]

In [7]:
for i in range(10):
    print(tokenizer.decode(test_output_gen_batch_padded.batch['input_ids'][i], skip_special_tokens=True))
    print('-'*100)

system prompt:

You will determine the correct combination of characters at [Position 1, Position 2, Position 3] in a 3-character combination lock through iterative reasoning and queries.
All 3 characters are unique.
The set of valid characters are as follows: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
Each time you try a combination, you will get feedback from the user about each character: either not in the combination, in the combination but in a different position, or in the combination and in the right position.
You can make up to 12 queries.
Your goal is to find the correct combination in the least number of queries.

user:

Give me your first query formatted as a list of 3 characters inside <action> ... </action>, e.g., <action>['char 1', 'char 2', 'char 3']</action>.


assistant:

<action>['0', '1', '2']</action>




user:

0 is in Position 1!
1 is not in the lock
2 is not in the lock


------------------------------------------------------------------------------------

In [11]:
for j in range(2004,2006):
    print("request id: ", test_output_gen_batch_padded.non_tensor_batch["request_ids"][j], " context index: ", test_output_gen_batch_padded.non_tensor_batch["context_indices"][j])
    for i in range(len(test_output_gen_batch_padded.non_tensor_batch["messages"][j]['messages'])):
        pprint(test_output_gen_batch_padded.non_tensor_batch["messages"][j]['messages'][i].content)#.keys()#['request_ids']
        print('-'*50)
    print('-'*100)

request id:  fb04a8aa-05b4-46ab-9b40-762057e0bcff  context index:  15
('You will determine the combination for [Position 1, Position 2, Position 3] '
 'in a 3-character combination lock through iterative reasoning and queries.\n'
 'The lock requires a 3-character combination. All 3 characters are unique.\n'
 "The set of valid characters are as follows: ['0', '1', '2', '3', '4', '5', "
 "'6', '7', '8', '9']\n"
 'You get feedback from the user on the presence of a character. Either not in '
 'the lock, in the lock but in a different position, or in the lock and in the '
 'right position.\n'
 'You have 12 queries.\n'
 'Your goal is to find the correct combination in the least number of '
 'queries.\n'
 'To query the lock, you will first think step by step, and then generate a '
 'query formatted as a list of 3 characters inside <action> ... </action>, '
 "e.g., <action>['char 1', 'char 2', 'char 3']</action>.\n"
 'Now update your beliefs based on the last action and environment feedback. 