## Pre steps to setup the CPU node pool and get k8s credential
1. Create a CPU node pool in GKE (update the env var based on your setup)

```
export PROJECT_ID=cloud-tpu-multipod-dev
export CLUSTER_NAME=mlperf-v5p
export ZONE=europe-west4
export CPU_POOL_NAME="tsbao-cpu-pool"
export MACHINE_TYPE="n2-standard-8"
export NUM_NODES=1

gcloud container node-pools create ${CPU_POOL_NAME}   --cluster=${CLUSTER_NAME}   --zone=${ZONE}   --project=${PROJECT_ID}    --machine-type=${MACHINE_TYPE}   --num-nodes=${NUM_NODES}   --enable-autoscaling --min-nodes=1 --max-nodes=5  --node-labels="cloud.google.com/gke-nodepool=${CPU_POOL_NAME}"
```

2. Create k8s credential (this will add credential to your local ~/.kube/config)

```
 gcloud container clusters get-credentials ${CLUSTER_NAME} --zone ${ZONE} --project ${PROJECT_ID}
```

3. Checkout R2E-Gym and patch this change (I ended up creating a fork due to no writer permission on the original repo): https://github.com/R2E-Gym/R2E-Gym/commit/046275291d34773657dbe170c96266b9736c938f

In [2]:
%load_ext autoreload
%autoreload 2

In [1]:
import sys
import os
import logging
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from flax import nnx
import optax
from orbax import checkpoint as ocp
from kubernetes import client, config as k8s_config
from transformers import AutoTokenizer
from tunix.cli.utils import data as data_lib
import datasets as datasets_lib 
import qwix
from tunix.utils import compat
Dataset = datasets_lib.Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import sys

pathways_root = os.path.expanduser('~/pathways-utils')
r2egym_root = os.path.expanduser('~/r2egym')

for root in [pathways_root, r2egym_root]:
    if root not in sys.path:
        sys.path.insert(0, root)

try:
    import pathwaysutils 
    import r2egym
    print("✅ pathways-utils, r2egym are successfully mapped.")
except ImportError as e:
    print(f"❌ Still missing a module: {e}")

✅ pathways-utils, r2egym are successfully mapped.


In [5]:
# ==========================================
# 3. Environment Configuration
# ==========================================

os.environ["KUBECONFIG"] = "~/.kube/config"
os.environ["NODE_SELECTOR_KEY"] = "cloud.google.com/gke-nodepool"
os.environ["NODE_SELECTOR_VAL"] = "deepswe-worker-pool" # NB: change based on your node pool name

# Kubernetes Setup
try:
    k8s_config.load_kube_config()
    k8s_client = client.CoreV1Api()
    # k8s_client.list_namespace(timeout_seconds=5)
except Exception as e:
    print(f"Warning: Kubernetes config loading failed: {e}")

In [6]:
import logging

# Remove existing handlers to prevent duplicate logs or conflicts
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

logging.basicConfig(
    stream=sys.stdout,  # Direct logs to standard output (notebook cell)
    level=logging.INFO, # Set the minimum level to INFO
    format="%(asctime)s - %(levelname)s - %(message)s", # Optional: customize the format
    datefmt="%Y-%m-%d %H:%M:%S" # Optional: customize the date format
)

In [7]:
import os

# os.environ["TOKENIZERS_PARALLELISM"] = "false"

import jax
from datasets import load_dataset
from tunix.cli.utils import data as data_lib

devices = jax.devices()
print(f"Available JAX devices: {devices}")


Available JAX devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,2,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,2,0), core_on_chip=0), TpuDevice(id=6, process_index=0, coords=(0,3,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,3,0), core_on_chip=0)]


In [8]:
# ==========================================
# 5. Model & Training Hyperparameters
# ==========================================
# MODEL_PATH = "/scratch/models/DeepSeek-R1-Distill-Qwen-1.5B/"
# MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
MODEL_PATH = os.path.expanduser("~/models/Qwen3-1.7B/")

# ====== Data ======
TRAIN_FRACTION = 1.0

# ====== Reproducibility ======
SEED = 42

# ====== LoRA ======
RANK = 64
ALPHA = 64.0
TRAIN_WITH_LORA = False

# ====== Sharding ======
# MESH = [(4, 2), ("fsdp", "tp")]


# ====== GRPO ======
# === Generation during GRPO training ===
# MAX_PROMPT_LENGTH = 32768
MAX_PROMPT_LENGTH = 4096
MAX_RESPONSE_LENGTH = 512
TEMPERATURE = 0.6
TOP_P = 0.95
TOP_K = 50
NUM_GENERATIONS = 2 # This corresponds to `G` in Algorithm 1

# === other GRPO configs ===
NUM_ITERATIONS = 1
BETA = 0.001
EPSILON = 0.2

# ====== Training ======
BATCH_SIZE = 16
MINI_BATCH_SIZE = 16
# ROLLOUT_MICRO_BATCH_SIZE = 8
# LOGPS_MICRO_BATCH_SIZE = 8
NUM_BATCHES = 1
NUM_TEST_BATCHES = 50

EVAL_EVERY_N_STEPS = 10
NUM_EPOCHS = 100 

# Number of training steps.
MAX_STEPS = 10

# Max turns in mult-agent interaction (set to 1 for single-turn)
MAX_TURNS = 3

# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = 1e-6
B1 = 0.9  
B2 = 0.99 
WEIGHT_DECAY = 0.1
WARMUP_STEPS = int(0.1 * MAX_STEPS)
MAX_GRAD_NORM = 0.1

# ====== Checkpoint saving ======
SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 4
DO_MEM_PROFILING = False

# ====== Inference ======
GENERATION_CONFIGS = {
    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
}

# ====== Rollout ======
ROLLOUT_ENGINE = "vanilla" # one of "vanilla", "vllm" or "sglang_jax"
CKPT_DIR = os.path.join("/tmp/cp", "deepswe_ckpt/00")


In [9]:
# ==========================================
# 6. JAX Device & Mesh Setup
# ==========================================
import jax
import jax.numpy as jnp
devices = jax.devices()
split = int(len(devices) / 2)
rollout_devices = np.array(devices[:split]).reshape(2,2)
train_devices = np.array(devices[split:]).reshape(2,2)

rollout_mesh = Mesh(rollout_devices, axis_names=('fsdp', 'tp'))
train_mesh = Mesh(train_devices, axis_names=('fsdp', 'tp'))


In [10]:
# ==========================================
# 2. Imports from Custom Modules
# ==========================================
from tunix.models.qwen3 import params as params_lib
from tunix.models.qwen3 import model as model_lib
from tunix.sft import utils as sft_utils
from tunix.sft import metrics_logger
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.rollout import base_rollout
from tunix.rl.experimental import agentic_grpo_learner
from tunix.rl.agentic.parser.chat_template_parser import parser
from tunix.rl.agentic.rewards.reward_types import RewardOutput
from system_prompts import (
    SWE_SYSTEM_PROMPT, 
    SWE_SYSTEM_PROMPT_FN_CALL, 
    SWE_USER_PROMPT, 
    SWE_USER_PROMPT_FN_CALL, 
    SWEAGENT_SYSTEM_PROMPT, 
    SWEAGENT_USER_PROMPT
)

# Assumed custom imports based on usage
from swe_agent import SWEAgent
from swe_env import SWEEnv

print("Initializing Model...")
config = model_lib.ModelConfig.qwen3_1p7b()


qwen_reference = params_lib.create_model_from_safe_tensors(MODEL_PATH, config, mesh=train_mesh, dtype=jnp.bfloat16)
def get_lora_model(base_model, model_mesh):
  lora_provider = qwix.LoraProvider(
      module_path=(
          ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
          ".*attn_vec_einsum"
      ),
      rank=RANK,
      alpha=ALPHA,
  )

  model_input = base_model.get_model_input()
  lora_model = qwix.apply_lora_to_model(
      base_model, lora_provider, **model_input
  )

  with compat.set_mesh(model_mesh):
    state = nnx.state(lora_model)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(lora_model, sharded_state)

  return lora_model
qwen_actor = get_lora_model(qwen_reference, train_mesh)
sft_utils.show_hbm_usage()


Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.


Initializing Model...


2026-02-23 01:36:24,798 - absl - INFO - [QWIX] module='layers/0/attn/q_proj' op=einsum0 rule=None
2026-02-23 01:36:25,504 - absl - INFO - [QWIX] module='layers/0/attn/k_proj' op=einsum0 rule=None
2026-02-23 01:36:26,093 - absl - INFO - [QWIX] module='layers/0/attn/v_proj' op=einsum0 rule=None
2026-02-23 01:36:27,019 - absl - INFO - [QWIX] module='layers/0/attn' op=einsum0 rule=None
2026-02-23 01:36:27,956 - absl - INFO - [QWIX] module='layers/0/attn' op=einsum1 rule=None
2026-02-23 01:36:28,098 - absl - INFO - [QWIX] module='layers/0/attn/o_proj' op=einsum0 rule=None
2026-02-23 01:36:28,304 - absl - INFO - [QWIX] module='layers/0/mlp/gate_proj' op=dot_general0 rule=0
2026-02-23 01:36:30,323 - absl - INFO - [QWIX] module='layers/0/mlp/up_proj' op=dot_general0 rule=0
2026-02-23 01:36:30,337 - absl - INFO - [QWIX] module='layers/0/mlp/down_proj' op=dot_general0 rule=0
2026-02-23 01:36:32,184 - absl - INFO - [QWIX] module='layers/1/attn/q_proj' op=einsum0 rule=None
2026-02-23 01:36:32,188 

In [11]:
# ==========================================
# 8. Tokenizer & Parser
# ==========================================
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH, 
    local_files_only=True,
    trust_remote_code=True
)

chat_parser = parser.QwenChatTemplateParser(tokenizer)

In [11]:
from datasets import load_dataset
import json
print("Loading Dataset...")

DATASET_CACHE = os.getenv('DATASET_CACHE', '/home/sizhi_google_com/dataset_cache')
os.makedirs(DATASET_CACHE, exist_ok=True)
dataset = load_dataset("R2E-Gym/R2E-Gym-V1", split="train", cache_dir=DATASET_CACHE)

Loading Dataset...


In [4]:
dataset[0].keys()

dict_keys(['repo_name', 'docker_image', 'commit_hash', 'parsed_commit_content', 'execution_result_content', 'modified_files', 'modified_entity_summaries', 'relevant_files', 'num_non_test_files', 'num_non_test_func_methods', 'num_non_test_lines', 'prompt', 'problem_statement', 'expected_output_json'])

In [5]:
[type(v) for v in dataset[0].values()]

[str, str, str, str, str, list, list, list, int, int, int, str, str, str]

In [12]:
dataset[0]['modified_entity_summaries']

[{'ast_type_str': 'ClassDef',
  'end_lineno': 132,
  'file_name': 'Orange/widgets/tests/test_context_handler.py',
  'name': 'TestContextHandler.test_migrates_settings_removes_incompatible',
  'start_lineno': 114,
  'type': 'method'},
 {'ast_type_str': 'ClassDef',
  'end_lineno': 628,
  'file_name': 'Orange/widgets/settings.py',
  'name': 'ContextHandler._migrate_contexts',
  'start_lineno': 626,
  'type': 'method'},
 {'ast_type_str': 'ClassDef',
  'end_lineno': 214,
  'file_name': 'Orange/widgets/tests/test_context_handler.py',
  'name': 'TestContextHandler',
  'start_lineno': 54,
  'type': 'class'},
 {'ast_type_str': 'ClassDef',
  'end_lineno': 862,
  'file_name': 'Orange/widgets/settings.py',
  'name': 'ContextHandler',
  'start_lineno': 585,
  'type': 'class'}]

In [6]:
dataset[0]['parsed_commit_content']



In [7]:
dataset[0]['execution_result_content']



In [9]:
dataset[0]['problem_statement']

'[ISSUE]\n**Title:** Context migration fails to remove incompatible contexts, causing initialization errors\n\n**Description:**\nWhen initializing the `ContextHandler` with a mix of compatible and incompatible contexts, the migration process does not remove the incompatible contexts as expected. Instead, it raises an `IncompatibleContext` error, preventing successful initialization.\n\n**Example Code:**\n```python\nhandler = ContextHandler()\nhandler.bind(SimpleWidget)\n\nwidget = SimpleWidget()\ncontexts = [Context(foo=i) for i in (13, 13, 0, 1, 13, 2, 13)]\n\ndef migrate_context(context, _):\n    if context.foo == 13:\n        raise IncompatibleContext()\n\nhandler.initialize(widget, dict(context_settings=contexts))\n# Expected: Incompatible contexts with foo=13 should be removed\n# Actual: IncompatibleContext error is raised, and contexts are not removed\n```\n\n**Expected Behavior:**\nDuring initialization, contexts that are incompatible (e.g., those that cause `IncompatibleContext

In [10]:
dataset[0]['expected_output_json']

'{\n    "TestContextHandler.test_close_context": "PASSED",\n    "TestContextHandler.test_fast_save": "PASSED",\n    "TestContextHandler.test_find_or_create_context": "PASSED",\n    "TestContextHandler.test_initialize": "PASSED",\n    "TestContextHandler.test_initialize_migrates_contexts": "PASSED",\n    "TestContextHandler.test_migrates_settings_removes_incompatible": "PASSED",\n    "TestContextHandler.test_pack_settings_stores_version": "PASSED",\n    "TestContextHandler.test_read_defaults": "PASSED",\n    "TestContextHandler.test_write_defaults_stores_version": "PASSED",\n    "TestSettingsPrinter.test_formats_contexts": "PASSED"\n}'

In [12]:
from datasets import load_dataset
import json
print("Loading Dataset...")

DATASET_CACHE = os.getenv('DATASET_CACHE', '/home/sizhi_google_com/dataset_cache')
os.makedirs(DATASET_CACHE, exist_ok=True)
dataset = load_dataset("R2E-Gym/R2E-Gym-V1", split="train", cache_dir=DATASET_CACHE)


def transform_and_tokenize(entry):
    # Rename 'prompt' to 'prompts'
    entry['prompts'] = [] # agentic rl learner require this field to calculate size of batch 
    
    # JSON encode lists (excluding the new 'prompts')
    for k, v in entry.items():
        if isinstance(v, list) and k != 'prompts':
            entry[k] = json.dumps(v)
    
    # Pre-calculate token length for filtering later
    # This prevents redundant tokenization during the training loop
    tokens = tokenizer.encode(entry["problem_statement"], add_special_tokens=False)
    entry["prompt_length"] = len(tokens)
    
    return entry

dataset = dataset.map(
    transform_and_tokenize,
    num_proc=8,
    keep_in_memory=True,
    desc="Transforming and Tokenizing"
)

# entries = []
# unique_images = set()
# for i, entry in enumerate(dataset):
#   if "docker_image" in entry:
#     unique_images.add(entry["docker_image"])
#     entries.append(entry)
#   if i >= TASKS_TO_PROCESS - 1:
#     break
# unique_images = list(unique_images)
# print(f"Found {len(unique_images)} unique Docker images to download")
# IDS = [f"task-{i}" for i in range(len(entries))]

Loading Dataset...


  self.pid = os.fork()
  self.pid = os.fork()
Transforming and Tokenizing (num_proc=8): 100%|██████████| 8101/8101 [00:30<00:00, 261.83 examples/s]


In [13]:
# ==========================================
# 9. Optimizer & Checkpointing
# ==========================================
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
)
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/tensorboard/grpo", flush_every_n_steps=2
)

optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        decay_steps=MAX_STEPS,
        end_value=0.0,
    ),
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)

# ==========================================
# 10. RL Cluster Setup
# ==========================================
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: train_mesh,
        rl_cluster_lib.Role.REFERENCE: train_mesh,
        rl_cluster_lib.Role.ROLLOUT: rollout_mesh,
    },
    rollout_engine=ROLLOUT_ENGINE,
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=MAX_STEPS,
        mini_batch_size=MINI_BATCH_SIZE,
        train_micro_batch_size=1,
        metrics_logging_options=metrics_logging_options,
        checkpoint_root_directory=CKPT_DIR,
        checkpointing_options=checkpointing_options,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH + 256,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
        eos_tokens=[tokenizer.encode("<|im_end|>")[0]],
    ),
)

rl_cluster = rl_cluster_lib.RLCluster(
    actor=qwen_actor,
    reference=qwen_reference,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)

# ==========================================
# 11. Learner & Agent Setup
# ==========================================
grpo_config = agentic_grpo_learner.GRPOConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    max_response_length=MAX_RESPONSE_LENGTH,
    beta=BETA,
    epsilon=EPSILON,
    system_prompt=SWE_SYSTEM_PROMPT,
    max_concurrency=1,
    epsilon_high=0.28,
    off_policy_steps=0,
)

# Helper for dummy reward function (placeholder)
def dummy_reward_fn(prompts, completions, **kwargs):
    return 0

# with jax.default_device(train_mesh.local_devices[0]):
agentic_grpo_learner = agentic_grpo_learner.GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=dummy_reward_fn,
    agent_class=SWEAgent,
    agent_kwargs={},
    env_class=SWEEnv,
    env_kwargs={"max_steps": MAX_TURNS}, 
    algo_config=grpo_config,
)

2026-02-23 01:37:09,506 - absl - INFO - Reshard finished in 0.47s
2026-02-23 01:37:09,635 - absl - INFO - WandbBackend skipped: 'wandb' library not installed.
2026-02-23 01:37:10,160 - absl - INFO - save_device_host_concurrent_bytes=None
2026-02-23 01:37:10,162 - absl - INFO - Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7f8fbb3fa3c0>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
2026-02-23 01:37:10,162 - absl - INFO - save_device_host_concurrent_bytes=None
2026-02-23 01:37:10,162 - absl - INFO - Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store

In [14]:
import grain
grain_dataset = grain.MapDataset.source(dataset)

# def transform_entry(entry):
#     processed_entry = {}
#     for k, v in entry.items():
#         new_key = 'prompts' if k == 'prompt' else k
        
#         # If it's a list (and not the prompts), JSON encode it
#         if isinstance(v, list) and new_key != 'prompts':
#             # This turns [2 items] and [3 items] into simple strings
#             processed_entry[new_key] = json.dumps(v)
#         else:
#             processed_entry[new_key] = v
#     return processed_entry


# grain_dataset = grain.MapDataset.source(dataset).map(transform_entry)

train_dataset, _ = data_lib.post_init_dataset(
    grain_dataset, 
    tokenizer, 
    batch_size=BATCH_SIZE,
    num_batches=NUM_BATCHES,
    max_prompt_length=MAX_PROMPT_LENGTH,
    fraction=TRAIN_FRACTION,
    num_epochs=NUM_EPOCHS,
    # worker_count=8,
)


In [17]:
entry = next(iter(train_dataset))

In [None]:
entry['modified_files']

np.str_('["Orange/widgets/settings.py", "Orange/widgets/tests/test_context_handler.py", "doc/development/source/tutorial-settings.rst"]')

In [23]:
res=np.array(['["Orange/widgets/settings.py", "Orange/widgets/tests/test_context_handler.py", "doc/development/source/tutorial-settings.rst"]',
       '["Orange/data/tests/test_variable.py", "Orange/data/variable.py"]',
       '["Orange/data/pandas_compat.py", "Orange/data/tests/test_pandas.py"]',
       '["Orange/preprocess/discretize.py", "Orange/preprocess/tests/test_discretize.py"]',
       '["Orange/data/tests/test_util.py", "Orange/data/util.py"]',
       '["Orange/misc/distmatrix.py", "Orange/tests/test_distances.py"]',
       '["Orange/data/tests/test_io_util.py", "Orange/data/variable.py"]',
       '["Orange/classification/logistic_regression.py", "Orange/tests/test_logistic_regression.py"]',
       '["Orange/data/tests/test_variable.py", "Orange/data/variable.py"]',
       '["Orange/widgets/utils/state_summary.py", "Orange/widgets/utils/tests/test_state_summary.py"]',
       '["Orange/preprocess/discretize.py", "Orange/tests/test_discretize.py"]',
       '["Orange/preprocess/discretize.py", "Orange/preprocess/tests/test_discretize.py"]',
       '["Orange/data/pandas_compat.py", "Orange/data/tests/test_pandas.py"]',
       '["Orange/widgets/data/owconcatenate.py", "Orange/widgets/data/tests/test_owconcatenate.py"]',
       '["Orange/preprocess/fss.py", "Orange/preprocess/tests/test_fss.py"]',
       '["Orange/widgets/data/owpurgedomain.py", "Orange/widgets/data/tests/test_owpurgedomain.py"]'])

In [None]:
res[0]

np.str_('["Orange/widgets/settings.py", "Orange/widgets/tests/test_context_handler.py", "doc/development/source/tutorial-settings.rst"]')

: 

In [None]:
print("Starting training...")
agentic_grpo_learner.train(train_dataset=train_dataset)

Starting training...


2026-02-23 00:29:12,529 - absl - INFO - Training with full_batch_size=0, mini_batch_size=16, train_micro_batch_size=1, self._rollout_micro_batch_size=1, self._compute_logps_micro_batch_size=1, grad_acc_steps=16
2026-02-23 00:29:12,529 - absl - INFO - Starting AgenticRLLearner training loop.
2026-02-23 00:29:12,530 - absl - INFO - Prefilling prompt queue with 1 batches.
2026-02-23 00:29:12,531 - absl - INFO - Starting run_producers_from_stream with 1 concurrency


SWEEnv is initialized with: 3
SWEEnv is initialized with: 3
SWEEnv step impl called
action string: <function=view>
  <parameter=path>/testbed</parameter>
</function>
calling r2e env
SWEEnv step impl called
action string: <function=view>
</function>
calling r2e env


In [9]:
print(entries[0]['problem_statement'])

[ISSUE]
**Title:** Context migration fails to remove incompatible contexts, causing initialization errors

**Description:**
When initializing the `ContextHandler` with a mix of compatible and incompatible contexts, the migration process does not remove the incompatible contexts as expected. Instead, it raises an `IncompatibleContext` error, preventing successful initialization.

**Example Code:**
```python
handler = ContextHandler()
handler.bind(SimpleWidget)

widget = SimpleWidget()
contexts = [Context(foo=i) for i in (13, 13, 0, 1, 13, 2, 13)]

def migrate_context(context, _):
    if context.foo == 13:
        raise IncompatibleContext()

handler.initialize(widget, dict(context_settings=contexts))
# Expected: Incompatible contexts with foo=13 should be removed
# Actual: IncompatibleContext error is raised, and contexts are not removed
```

**Expected Behavior:**
During initialization, contexts that are incompatible (e.g., those that cause `IncompatibleContext` to be raised) should be

In [10]:
print(entries[0]['prompt'])

You are an expert software engineer tasked with creating informative GitHub issues based on commit details and test results. These issues will be used to help junior developers and machine learning systems understand the motivation behind commits. Your goal is to create concise, clear, and realistic issues that highlight bugs without revealing solutions.
    
The commit hash is 2d9617bd0cb1f0ba61771258410ab8fae8e7e24d. 
The commit message is: Settings migration: Allow rejecting a context.

The commit patch is:
```diff
diff --git a/Orange/widgets/settings.py b/Orange/widgets/settings.py
index 8be8bf0ae..75ebe4129 100644
--- a/Orange/widgets/settings.py
+++ b/Orange/widgets/settings.py
@@ -49,7 +49,8 @@ log = logging.getLogger(__name__)
 __all__ = ["Setting", "SettingsHandler", "SettingProvider",
            "ContextSetting", "ContextHandler",
            "DomainContextHandler", "PerfectDomainContextHandler",
-           "ClassValuesContextHandler", "widget_settings_dir"]
+           "Cl

In [None]:
import os

os.environ["KUBECONFIG"] = "~/.kube/config"
os.environ["NODE_SELECTOR_KEY"] = "cloud.google.com/gke-nodepool"
os.environ["NODE_SELECTOR_VAL"] = "deepswe-worker-pool" # NB: change based on your node pool name

from kubernetes import client, config
try:
    k8s_config.load_kube_config()
    k8s_client = client.CoreV1Api()
except Exception as e:
    print(f"Warning: Kubernetes config loading failed: {e}")


In [None]:
# ====== Data ======
TRAIN_FRACTION = 1.0

# ====== Reproducibility ======
SEED = 42

# ====== LoRA ======
RANK = 64
ALPHA = 64.0
TRAIN_WITH_LORA = True

# ====== Sharding ======
# MESH = [(4, 2), ("fsdp", "tp")]


# ====== GRPO ======
# === Generation during GRPO training ===
# MAX_PROMPT_LENGTH = 32768
MAX_PROMPT_LENGTH = 2048
MAX_RESPONSE_LENGTH = 8192
TEMPERATURE = 0.6
TOP_P = 0.95
TOP_K = 50
NUM_GENERATIONS = 2 # This corresponds to `G` in Algorithm 1

# === other GRPO configs ===
NUM_ITERATIONS = 1
BETA = 0.001
EPSILON = 0.2

# ====== Training ======
BATCH_SIZE = 16
MINI_BATCH_SIZE = 16
# ROLLOUT_MICRO_BATCH_SIZE = 8
# LOGPS_MICRO_BATCH_SIZE = 8
NUM_BATCHES = 1
NUM_TEST_BATCHES = 50

EVAL_EVERY_N_STEPS = 10
NUM_EPOCHS = 100 

# Number of training steps.
MAX_STEPS = 10

# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = 1e-6
B1 = 0.9  
B2 = 0.99 
WEIGHT_DECAY = 0.1
WARMUP_STEPS = int(0.1 * MAX_STEPS)
MAX_GRAD_NORM = 0.1

# ====== Checkpoint saving ======
SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 4
DO_MEM_PROFILING = False

# ====== Inference ======
GENERATION_CONFIGS = {
    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
}

# ====== Rollout ======
ROLLOUT_ENGINE = "vanilla" # one of "vanilla", "vllm" or "sglang_jax"
CKPT_DIR = os.path.join("/tmp/cp", "deepswe_ckpt/00")


In [1]:
# NOTE: download to local dir for faster future access
# >> hf download Qwen/Qwen3-4B-Instruct-2507 --local-dir ./models

# MODEL_PATH = "/scratch/models/DeepSeek-R1-Distill-Qwen-1.5B/"
# MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
MODEL_PATH = os.path.expanduser("~/models/Qwen3-1.7B/")

from transformers import AutoTokenizer
from tunix.rl.agentic.parser.chat_template_parser import parser

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH, 
    local_files_only=True,
    trust_remote_code=True
)

chat_parser = parser.QwenChatTemplateParser(tokenizer)

NameError: name 'os' is not defined

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from tunix.models.qwen3 import params as params_lib
from tunix.models.qwen3 import model as model_lib
from tunix.sft import utils as sft_utils

devices = jax.devices()
split = int(len(devices) / 2)
rollout_devices = np.array(devices[:split]).reshape(2,2)
train_devices = np.array(devices[split:]).reshape(2,2)

rollout_mesh = Mesh(rollout_devices, axis_names=('fsdp', 'tp'))
train_mesh = Mesh(train_devices, axis_names=('fsdp', 'tp'))


config = model_lib.ModelConfig.qwen3_1p7b()


qwen_reference = params_lib.create_model_from_safe_tensors(MODEL_PATH, config, mesh=train_mesh, dtype=jnp.bfloat16)
def get_lora_model(base_model, model_mesh):
  lora_provider = qwix.LoraProvider(
      module_path=(
          ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
          ".*attn_vec_einsum"
      ),
      rank=RANK,
      alpha=ALPHA,
  )

  model_input = base_model.get_model_input()
  lora_model = qwix.apply_lora_to_model(
      base_model, lora_provider, **model_input
  )

  with compat.set_mesh(model_mesh):
    state = nnx.state(lora_model)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(lora_model, sharded_state)

  return lora_model
qwen_actor = get_lora_model(qwen_reference, train_mesh)
sft_utils.show_hbm_usage()


2026-02-04 02:15:25 - INFO -  - Pathways not available. Using default HBM stats collector
2026-02-04 02:15:25 - INFO - Using 1.9 GiB / 31.2 GiB (0.05997285121734592) on TPU_0(process=0,(0,0,0,0))
2026-02-04 02:15:25 - INFO - Using 1.9 GiB / 31.2 GiB (0.05997285121734592) on TPU_1(process=0,(1,0,0,0))
2026-02-04 02:15:25 - INFO - Using 1.9 GiB / 31.2 GiB (0.05997285121734592) on TPU_2(process=0,(0,1,0,0))
2026-02-04 02:15:25 - INFO - Using 1.9 GiB / 31.2 GiB (0.05997285121734592) on TPU_3(process=0,(1,1,0,0))
2026-02-04 02:15:25 - INFO - Using 1.9 GiB / 31.2 GiB (0.05997285121734592) on TPU_4(process=0,(0,2,0,0))
2026-02-04 02:15:25 - INFO - Using 1.9 GiB / 31.2 GiB (0.05997285121734592) on TPU_5(process=0,(1,2,0,0))
2026-02-04 02:15:25 - INFO - Using 1.9 GiB / 31.2 GiB (0.05997285121734592) on TPU_6(process=0,(0,3,0,0))
2026-02-04 02:15:25 - INFO - Using 1.9 GiB / 31.2 GiB (0.05997285121734592) on TPU_7(process=0,(1,3,0,0))


In [None]:
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
)
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/tensorboard/grpo", flush_every_n_steps=2
)

optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        decay_steps=MAX_STEPS,
        end_value=0.0,
    ),
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)

In [None]:
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: train_mesh,
        rl_cluster_lib.Role.REFERENCE: train_mesh,
        rl_cluster_lib.Role.ROLLOUT: rollout_mesh,
    },
    rollout_engine=ROLLOUT_ENGINE,
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=10, # Note: Overridden locally to 20 in config vs MAX_STEPS above
        mini_batch_size=MINI_BATCH_SIZE,
        train_micro_batch_size=1,
        metrics_logging_options=metrics_logging_options,
        checkpoint_root_directory=CKPT_DIR,
        checkpointing_options=checkpointing_options,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH + 256,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
        eos_tokens=[tokenizer.encode("<|im_end|>")[0]],
    ),
)

rl_cluster = rl_cluster_lib.RLCluster(
    actor=qwen_actor,
    reference=qwen_reference,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)


In [None]:
grpo_config = agentic_grpo_learner.GRPOConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    max_response_length=MAX_RESPONSE_LENGTH,
    beta=BETA,
    epsilon=EPSILON,
    system_prompt=SWE_SYSTEM_PROMPT,
    max_concurrency=1,
    epsilon_high=0.28,
    off_policy_steps=0,
)

# Helper for dummy reward function (placeholder)
def dummy_reward_fn(prompts, completions, **kwargs):
    return 0

# with jax.default_device(train_mesh.local_devices[0]):
agentic_grpo_learner = agentic_grpo_learner.GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=dummy_reward_fn,
    agent_class=SWEAgent,
    agent_kwargs={},
    env_class=SWEEnv,
    env_kwargs={"max_steps": 3}, 
    algo_config=grpo_config,
)

In [1]:
import json
def transform_entry(entry):
    processed_entry = {}
    for k, v in entry.items():
        new_key = 'prompts' if k == 'prompt' else k
        
        # If it's a list (and not the prompts), JSON encode it
        if isinstance(v, list) and new_key != 'prompts':
            # This turns [2 items] and [3 items] into simple strings
            processed_entry[new_key] = json.dumps(v)
        else:
            processed_entry[new_key] = v
    return processed_entry


grain_dataset = grain.MapDataset.source(entries).map(transform_entry)

train_dataset, _ = data_lib.post_init_dataset(
    grain_dataset, 
    tokenizer, 
    batch_size=BATCH_SIZE,
    num_batches=NUM_BATCHES,
    max_prompt_length=None, #TODO(sizhi):  Max prompt length filtering is applied but also used to calculate kv cache size 
    fraction=TRAIN_FRACTION,
    num_epochs=NUM_EPOCHS,
)

print("Starting training...")
agentic_grpo_learner.train(train_dataset=train_dataset)

NameError: name 'grain' is not defined

In [8]:
from tunix.generate import sampler

sampler = sampler.Sampler(qwen_actor, tokenizer, sampler.CacheConfig(cache_size=16384, num_layers=36, num_kv_heads=8, head_dim=128))

In [9]:
# from tunix.generate.vllm_sampler import VllmSampler, VllmConfig
# from tunix.generate import mappings

# mapping_config = mappings.MappingConfig.build(
#     mapping_obj=None,
#     model=qwen_actor,
#     backend="vllm_jax",
# )

# vllm_config = VllmConfig(
#     model_path=MODEL_PATH,
#     max_model_len=8192,
#     mesh=train_mesh,
#     hbm_utilization_target=0.5,
#     init_with_random_weights=True,
#     tpu_backend_type="jax",
#     mapping_config=mapping_config
# )
# vllm_sampler = VllmSampler(tokenizer=tokenizer, config=vllm_config)

In [12]:
from swe_agent import SWEAgent
from swe_env import SWEEnv
from tunix.rl.agentic.trajectory import trajectory_collect_engine
from tunix.rl.agentic.parser.chat_template_parser.parser import QwenChatTemplateParser
from tunix.rl.agentic.rewards.reward_types import RewardOutput

chat_parser = QwenChatTemplateParser(tokenizer)

# def model_call(chat_lists, rl_cluster):
#     result = rl_cluster.generate(
#         prompts=chat_lists,
#         apply_chat_template=True,
#         mode=rl_cluster_lib.Mode.TRAIN,
#     )
#     return result.text[0]

def model_call(chat_completions, _):
    p = chat_parser.parse(chat_completions)
    out = sampler(p, max_generation_steps=512, echo=False)
    return out.text[0]

MAX_STEPS = 10
agent = SWEAgent()
env = SWEEnv(entry=entries[0], max_steps=MAX_STEPS)

print(chat_parser.parse(agent.chat_completions))

engine = trajectory_collect_engine.TrajectoryCollectEngine(
    agent=agent,
    env=env,
    model_call=model_call,
    final_reward_fn=lambda x, y: RewardOutput(reward=0, metadata={}),
    max_steps=MAX_STEPS,
    gamma=0.9,
    timeout=120,
)


# res = await engine.collect(mode="Trajectory")

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.


<|im_start|>system
You are a programming agent who is provided a github issue and repository bash environment and is tasked to solve certain tasks (e.g., file localization, testcase generation, code repair and editing etc) to resolve the issue.

We have access to the following functions:

–– BEGIN FUNCTION #1: file_editor ––
Description:
Custom editing tool for viewing, creating and editing files
  •	State is persistent across command calls and discussions with the user
  •	If path is a file, view displays the result of applying cat -n. If path is a directory, view lists non-hidden files and directories up to 2 levels deep
  •	The create command cannot be used if the specified path already exists as a file
  •	If a command generates a long output, it will be truncated and marked with <response clipped>
  •	The undo_edit command will revert the last edit made to the file at path

Notes for using the str_replace command:
  •	The old_str parameter should match EXACTLY one or more consecut

In [13]:
res = await engine.collect(mode='Trajectory')

calling r2e env
didn't find any funciton to call
calling r2e env
calling r2e env
calling r2e env
didn't find any funciton to call
calling r2e env
calling r2e env
calling r2e env
calling r2e env


In [14]:
# print(env.total_steps)

STEP = 9
print(f"Step {STEP} ###################")
print(f"Observation ###################")
print(res.steps[STEP].observation)
print(f"Model Response ###################")
print(res.steps[STEP].model_response)

agent._messages



Step 9 ###################
Observation ###################
Execution output of [search]:
Directory '/testbed/Orange/widgets/settings' not found or not a directory.

Model Response ###################
<|im_start|><|im_start|>
Let me try searching in a different directory that might contain context handling code.

<function=search>
<parameter=search_term>ContextHandler</parameter>
<parameter=path>/testbed/Orange/widgets</parameter>
</function>


[{'role': 'system',
  'content': 'You are a programming agent who is provided a github issue and repository bash environment and is tasked to solve certain tasks (e.g., file localization, testcase generation, code repair and editing etc) to resolve the issue.\n\nWe have access to the following functions:\n\n–– BEGIN FUNCTION #1: file_editor ––\nDescription:\nCustom editing tool for viewing, creating and editing files\n  •\tState is persistent across command calls and discussions with the user\n  •\tIf path is a file, view displays the result of applying cat -n. If path is a directory, view lists non-hidden files and directories up to 2 levels deep\n  •\tThe create command cannot be used if the specified path already exists as a file\n  •\tIf a command generates a long output, it will be truncated and marked with <response clipped>\n  •\tThe undo_edit command will revert the last edit made to the file at path\n\nNotes for using the str_replace command:\n  •\tThe old_str parameter should

In [None]:
# # ====== Data ======
# TRAIN_FRACTION = 1.0

# # ====== Reproducibility ======
# SEED = 42

# # ====== LoRA ======
# RANK = 64
# ALPHA = 64.0
# TRAIN_WITH_LORA = False

# # ====== Sharding ======
# MESH = [(2, 4), ("fsdp", "tp")]

# # ====== GRPO ======
# # === Generation during GRPO training ===
# MAX_PROMPT_LENGTH = 2048
# TOTAL_GENERATION_STEPS = 512
# # Important to keep a high-ish temperature for varied, diverse responses during
# # training.
# TEMPERATURE = 0.6
# TOP_P = 0.95
# TOP_K = 50
# # The number of times the policy generates multiple responses for a given prompt
# # within a single training step. This corresponds to `G` in Algorithm 1 in the
# # paper. The "group" in GRPO comes from here.
# NUM_GENERATIONS = 2

# # === other GRPO configs ===
# # The number of iterations per batch (𝜇 in GRPO algo 1).
# NUM_ITERATIONS = 1
# # The coefficient for the KL divergence penalty (𝛽) in the GRPO loss function.
# # Important to keep a high enough value for this, otherwise, the KL divergence
# # can increase unchecked.
# BETA = 0.001
# # Epsilon value for clipping (𝜀 in GRPO loss in paper). Similar to PPO, for
# # stable updates.
# EPSILON = 0.2

# # ====== Training ======
# BATCH_SIZE = 16
# MINI_BATCH_SIZE = 16
# # ROLLOUT_MICRO_BATCH_SIZE = 8
# # LOGPS_MICRO_BATCH_SIZE = 8
# NUM_BATCHES = 100
# # Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
# # increased to a max. of 330 (if batch size is 4).
# NUM_TEST_BATCHES = 50

# EVAL_EVERY_N_STEPS = 1000  # this doesn't matter if `TRAIN_FRACTION = 1.0`.
# NUM_EPOCHS = 100 # can potentially train for more epochs

# # Number of training steps.
# MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)

# # === AdamW, warmup, cosine scheduler ===
# LEARNING_RATE = 1e-6
# B1 = 0.9  # Adam beta1
# B2 = 0.99  # Adam beta2
# WEIGHT_DECAY = 0.1
# # == Cosine decay with warmup scheduler ==
# # Linearly increase learning rate from 0. to 5e-6 in the first 10% training
# # steps, and then gradually decrease the learning rate to 0 using cosine
# # scheduler.
# WARMUP_STEPS = int(0.1 * MAX_STEPS)
# # == Grad clipping ==
# # Grad clipping to prevent large gradients. Found this
# # important to keep KL divergence in check.
# MAX_GRAD_NORM = 0.1

# # ====== Checkpoint saving ======
# SAVE_INTERVAL_STEPS = 500
# MAX_TO_KEEP = 4
# DO_MEM_PROFILING = False

# # ====== Inference ======
# GENERATION_CONFIGS = {
#     # greedy search
#     "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
#     # some randomness
#     "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
#     # liberal
#     "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
# }
# # ====== Rollout ======
# ROLLOUT_ENGINE = "sglang_jax" # one of "vanilla", "vllm" or "sglang_jax"

# CKPT_DIR = os.path.join("/tmp/cp", "deepscaler_ckpt/01")

In [None]:
# from tunix.rl import rl_cluster as rl_cluster_lib
# import optax
# from tunix.sft import metrics_logger
# from orbax import checkpoint as ocp
# from tunix.rl.rollout import base_rollout

# checkpointing_options = ocp.CheckpointManagerOptions(
#     save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
# )
# metrics_logging_options = metrics_logger.MetricsLoggerOptions(
#     log_dir="/tmp/tensorboard/grpo", flush_every_n_steps=20
# )

# optimizer = optax.adamw(
#     learning_rate=optax.schedules.warmup_cosine_decay_schedule(
#         init_value=0.0,
#         peak_value=LEARNING_RATE,
#         warmup_steps=WARMUP_STEPS,
#         decay_steps=MAX_STEPS,
#         end_value=0.0,
#     ),
#     b1=B1,
#     b2=B2,
#     weight_decay=WEIGHT_DECAY,
# )

# cluster_config = rl_cluster_lib.ClusterConfig(
#     role_to_mesh={
#         rl_cluster_lib.Role.ACTOR: train_mesh,
#         rl_cluster_lib.Role.REFERENCE: train_mesh,
#         rl_cluster_lib.Role.ROLLOUT: rollout_mesh,
#     },
#     rollout_engine=ROLLOUT_ENGINE,
#     offload_to_cpu=False,
#     training_config=rl_cluster_lib.RLTrainingConfig(
#         actor_optimizer=optimizer,
#         eval_every_n_steps=EVAL_EVERY_N_STEPS,
#         max_steps=20,
#         mini_batch_size=MINI_BATCH_SIZE,
#         train_micro_batch_size = 1,  # larger than 1 will cause OOM on HBM
#         # metrics logging
#         metrics_logging_options=metrics_logging_options,
#         # checkpoint saving
#         checkpoint_root_directory=CKPT_DIR,
#         checkpointing_options=checkpointing_options,
#     ),
#     rollout_config=base_rollout.RolloutConfig(
#         max_tokens_to_generate=TOTAL_GENERATION_STEPS,
#         max_prompt_length=MAX_PROMPT_LENGTH,
#         kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
#         temperature=TEMPERATURE,
#         top_p=TOP_P,
#         top_k=TOP_K,
#         eos_tokens=[tokenizer.encode("<|im_end|>")[0]],
#         # sglang-jax specific configs
#         rollout_sglang_jax_model_version="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
#         rollout_sglang_jax_mem_fraction_static=0.2,
#         rollout_sglang_jax_init_with_random_weights=True,
#         rollout_sglang_jax_disable_radix_cache=True,
#         rollout_sglang_jax_enable_deterministic_sampling=False,
#         rollout_sglang_jax_precompile_bs_paddings=[1, 2],
#         rollout_sglang_jax_precompile_token_paddings=[2048, 4096, 8192],
#         rollout_sglang_jax_chunked_prefill_size=2048,
#         rollout_sglang_jax_page_size=64,
#     ),
# )

# rl_cluster = rl_cluster_lib.RLCluster(
#     actor=qwen2_actor,
#     reference=qwen2_ref,
#     tokenizer=tokenizer,
#     cluster_config=cluster_config,
# )

# Random stuff for debugging

In [None]:
# from rllm.environments.swe.swe import R2EGYM_COMMAND_FILES
# import r2egym

# print(r2egym.__file__)
# from r2egym.agenthub.runtime.docker import DockerRuntime
# from r2egym.agenthub.utils.log import get_logger
# from r2egym.agenthub.environment.env import EnvArgs, RepoEnv

# env_args = EnvArgs(ds=entries[0])
# env = RepoEnv(env_args, backend="kubernetes")

# env.add_commands(cmd_files=R2EGYM_COMMAND_FILES)

In [None]:
# runtime = DockerRuntime(ds=entries[0], command=["/bin/bash", "-l"], logger=get_logger(), backend="kubernetes", id=IDS[0])
# runtime.get_task_instruction()

In [None]:
# runtime.run(code="ls -l")
# runtime.stop_container()

In [None]:
# DOCKER_PATH = "/root/.venv/bin:/root/.local/bin:/root/.cargo/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
# pod_name = "tsbao-test-cpu-pod"
# docker_image = entries[0]["docker_image"]
# command = "/bin/bash"

# env_vars = {"PATH": DOCKER_PATH}
# env_spec = [{"name": k, "value": str(v)} for k, v in env_vars.items()]
# pod_body = {
#     "apiVersion": "v1",
#     "kind": "Pod",
#     "metadata": {"name": pod_name},
#     "spec": {
#         "restartPolicy": "Never",
#         "containers": [
#             {
#                 "name": pod_name,
#                 "image": docker_image,
#                 "command": ["/bin/sh", "-c"],
#                 "args": [command] if isinstance(command, str) else command,
#                 "stdin": True,
#                 "tty": True,
#                 "env": env_spec,
#                 "resources": {
#                     "requests": {"cpu": "1", "memory": "1Gi"},
#                 },
#             }
#         ],
#         "imagePullSecrets": [{"name": "dockerhub-pro"}],
#         "nodeSelector": {"cloud.google.com/gke-nodepool": "tsbao-cpu-pool"},
#         "tolerations": [
#             {
#                 "key": "node.kubernetes.io/disk-pressure",
#                 "operator": "Exists",
#                 "effect": "NoExecute",
#                 "tolerationSeconds": 10800
#             }
#         ],
#     },
# }

pod = k8s_client.create_namespaced_pod(
    namespace="default", body=pod_body, _request_timeout=60,
)

In [None]:
# k8s_client.list_namespaced_pod(namespace="default")
pod_name = "tsbao-test-pod"
pod = k8s_client.read_namespaced_pod(name=pod_name, namespace="default")
pod.status.phase



'Running'

In [None]:
# from kubernetes.stream import stream

# full_command = ["/bin/sh", "-c", "ls -l"]
# resp = stream(
#     k8s_client.connect_get_namespaced_pod_exec,
#     name=pod_name,
#     namespace="default",
#     command=full_command,
#     stderr=True,
#     stdin=False,
#     stdout=True,
#     tty=False,  # Match docker exec_run settings
#     _preload_content=False,  # Important for streaming
# )
# resp

<kubernetes.stream.ws_client.WSClient at 0x78b19a7390f0>

In [None]:
# combined_chunks = []
# stdout_chunks = []
# stderr_chunks = []
# while resp.is_open():
#     resp.update(timeout=1)  # wait for data
#     if resp.peek_stdout():
#         chunk = resp.read_stdout()
#         stdout_chunks.append(chunk)
#         combined_chunks.append(chunk)
#     if resp.peek_stderr():
#         chunk = resp.read_stderr()
#         stderr_chunks.append(chunk)
#         combined_chunks.append(chunk)
# resp.close()
# exit_code = resp.returncode
# combined_output = "".join(combined_chunks)

In [None]:
# from r2egym.agenthub.agent.commands import ParseCommandBash

# cmd_parser = ParseCommandBash()
# cmds = cmd_parser.parse_command_file("/scratch/git/R2E-Gym/src/r2egym/agenthub/tools/r2egym/file_editor.py")
# cmds[0]

