In [1]:
# Hugging Face repo ids of the three models used in this notebook.
DRAFT_MODEL = "Qwen/Qwen2.5-Math-1.5B-Instruct"
TARGET_MODEL = "Qwen/Qwen2.5-Math-7B-Instruct"
PRM = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"

# Base URLs of the OpenAI-compatible endpoints, one local port per model.
# (Despite the *_IP_ADDRESS names these are full base URLs, not bare IPs.)
DRAFT_IP_ADDRESS = "http://localhost:12340/v1"
TARGET_IP_ADDRESS = "http://localhost:12341/v1"
PRM_IP_ADDRESS = "http://localhost:12342/v1"

In [None]:
import openai
from openai import OpenAI
from transformers import AutoTokenizer
import numpy as np

# Placeholder key handed to every OpenAI client below; all endpoints are on
# localhost (presumably servers that do not check the key — confirm).
openai_api_key = "EMPTY"

# Single lookup table for model paths and endpoint URLs used by later cells.
args = dict(
    draft_model_path_rsd=DRAFT_MODEL,
    draft_model_ip_address=DRAFT_IP_ADDRESS,
    target_model_path_rsd=TARGET_MODEL,
    target_model_ip_address=TARGET_IP_ADDRESS,
    prm_ip_address_rsd=PRM_IP_ADDRESS,
    prm_path_rsd=PRM,
)


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def _connect(url_key, model_key):
    """Build the (client, tokenizer) pair for one entry of ``args``.

    Args:
        url_key: key in ``args`` holding the endpoint base URL.
        model_key: key in ``args`` holding the model repo id / path.

    Returns:
        Tuple ``(OpenAI client, tokenizer)``.
    """
    client = OpenAI(api_key=openai_api_key, base_url=args[url_key])
    # NOTE: trust_remote_code=True lets the tokenizer repo run its own Python
    # code on load; only use it with repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained(args[model_key], trust_remote_code=True)
    return client, tokenizer


draft_client, draft_tokenizer = _connect("draft_model_ip_address", "draft_model_path_rsd")
target_client, target_tokenizer = _connect("target_model_ip_address", "target_model_path_rsd")
prm_client, prm_tokenizer = _connect("prm_ip_address_rsd", "prm_path_rsd")

In [4]:
# Sampling settings for the draft-model completion request made below.
config = dict(
    draft_model_name_or_path=DRAFT_MODEL,
    temperature=1.0,
    top_p=0.9,
    max_tokens=2048,
)

In [5]:
# Two toy prompts, sent as one batched completion request.
gen_prompts = ["Say hello", "Say goodbye"]

# The draft model's client is the generator used in the next cell.
llm = draft_client

In [6]:
# The request uses the last path component of the repo id as the model name.
draft_model_name = config["draft_model_name_or_path"].split("/")[-1]

# Generate one reasoning step per prompt: stop at the first blank line
# ("\n\n") and keep that stop string in the returned text, since
# prepare_input later splits the text on the same "\n\n" token.
draft_completion = llm.completions.create(
    model=draft_model_name,
    prompt=gen_prompts,
    temperature=config["temperature"],
    top_p=config["top_p"],
    max_tokens=config["max_tokens"],
    stop=["\n\n"],
    n=1,
    stream=False,
    extra_body={"include_stop_str_in_output": True},
)
draft_responses = draft_completion.choices

# Order the choices by their index so they line up with gen_prompts.
llm_outputs = sorted(draft_responses, key=lambda choice: int(choice.index))

In [7]:
# Dump the index-sorted completion choices for a quick visual check.
print(
    "llm_outputs: \n",
    llm_outputs,
)

llm_outputs: 
 [CompletionChoice(finish_reason='stop', index=0, logprobs=None, text='. The language model uses a simplified attention mechanism to weigh theimportance of each word in a sentence. The attention weights are calculated using a scoring function that takes into account the context and the context itself. For a given sentence, the attention weight for each word is given by the formula:\n\n', stop_reason='\n\n', prompt_logprobs=None), CompletionChoice(finish_reason='stop', index=1, logprobs=None, text=', Sumit, and submergedCities. The number of people in these three cities is given as follows:\n- Sumit has 12,345 people.\n- Subha has 10,987 people.\n- Submarines have 8,765 people.\n\n', stop_reason='\n\n', prompt_logprobs=None)]


In [8]:
# Show the fields of each completion choice, one per line.
for output in llm_outputs:
    print(
        f"output.text: {output.text}",
        f"output.index: {output.index}",
        f"output.finish_reason: {output.finish_reason}",
        f"output.logprobs: {output.logprobs}",
        sep="\n",
    )

output.text: . The language model uses a simplified attention mechanism to weigh theimportance of each word in a sentence. The attention weights are calculated using a scoring function that takes into account the context and the context itself. For a given sentence, the attention weight for each word is given by the formula:


output.index: 0
output.finish_reason: stop
output.logprobs: None
output.text: , Sumit, and submergedCities. The number of people in these three cities is given as follows:
- Sumit has 12,345 people.
- Subha has 10,987 people.
- Submarines have 8,765 people.


output.index: 1
output.finish_reason: stop
output.logprobs: None


In [9]:
def prepare_input(problem, response, tokenizer, step_token):
    """Tokenize a problem/response pair and mark where each step ends.

    The prompt is ``tokenizer.bos_token + problem + "\\n"``. The response is
    split on ``step_token``; every piece is encoded and followed by the id of
    the step token (the last id of ``tokenizer.encode(step_token)``).

    Note: if ``response`` ends with ``step_token``, ``str.split`` yields a
    trailing empty piece, which still contributes one step made of only the
    step-token id (and therefore one extra flagged position).

    Args:
        problem: problem / prompt text.
        response: model response text to be split into steps.
        tokenizer: object exposing ``bos_token`` and ``encode(text)``.
        step_token: separator string that delimits steps.

    Returns:
        ``(input_ids, steps, reward_flags)`` where ``input_ids`` is the prompt
        ids followed by all step ids, ``steps`` holds each piece with
        ``step_token`` re-appended, and ``reward_flags`` has the same length
        as ``input_ids`` with 1 at each step's final token and 0 elsewhere
        (all prompt positions are 0).
    """
    prompt_ids = tokenizer.encode(tokenizer.bos_token + problem + "\n")
    reward_flags = [0] * len(prompt_ids)
    step_token_id = tokenizer.encode(step_token)[-1]

    steps = []
    response_ids = []
    for piece in response.split(step_token):
        # Empty pieces are not sent through the tokenizer at all.
        piece_ids = tokenizer.encode(piece) if piece != "" else []
        piece_ids += [step_token_id]

        # Only the closing step-token position of the piece is flagged.
        piece_flags = [0] * (len(piece_ids) - 1) + [1]

        steps.append(piece + step_token)
        response_ids.extend(piece_ids)
        reward_flags.extend(piece_flags)

    return prompt_ids + response_ids, steps, reward_flags

In [11]:
# Build PRM inputs: one (input_ids, steps, reward_flags) triple per
# prompt/completion pair, using the PRM tokenizer and "\n\n" as step token.
processed_data = [
    prepare_input(
        prompt_text,
        choice.text,
        tokenizer=prm_tokenizer,
        step_token="\n\n",
    )
    for prompt_text, choice in zip(gen_prompts, llm_outputs)
]

# Transpose the list of triples into three parallel tuples.
input_ids, steps, reward_flags = zip(*processed_data)

In [31]:
# Send the token-id sequences to the PRM endpoint through the embeddings API;
# the request uses the last path component of the PRM repo id as model name.
prm_model_name = args["prm_path_rsd"].rsplit("/", 1)[-1]
rewards = prm_client.embeddings.create(model=prm_model_name, input=input_ids)

In [33]:
# Iterating the response object yields (field_name, value) pairs, e.g.
# ('data', [Embedding(...), ...]); print each pair as-is.
for field_item in rewards:
    print(field_item)

('data', [Embedding(embedding=[-0.59375, -0.4609375, 1.6015625, -1.515625, -0.515625, 1.125, -1.1796875, 1.1015625, 1.9609375, -0.11376953125, 0.13671875, 0.21484375, 0.7265625, -0.36328125, 0.306640625, -0.00640869140625, -0.423828125, -0.90625, 1.265625, 3.8125, 0.10595703125, -0.193359375, 0.97265625, -0.291015625, 0.1494140625, -1.09375, -0.53515625, -0.71875, -0.52734375, -0.12353515625, 0.404296875, -0.392578125, -0.50390625, -1.1484375, 0.3984375, -0.380859375, 0.6328125, 0.75390625, 0.8203125, -0.265625, 1.34375, 0.10107421875, -2.03125, -0.88671875, -0.33203125, 1.3515625, 0.5859375, 1.453125, 1.125, 0.2373046875, -1.3125, -1.2421875, -1.234375, 1.15625, 2.1875, -0.97265625, -2.6875, -0.1630859375, -2.015625, -2.15625, -0.765625, 0.216796875, -0.087890625, -0.337890625], index=0, object='embedding'), Embedding(embedding=[-0.59375, -0.4609375, -0.6640625, -0.984375, -0.2578125, -0.921875, 1.1171875, 2.078125, 3.8125, -1.515625, 0.205078125, -0.390625, -0.3984375, -1.4765625, 0.

In [35]:
def sigmoid(x):
    """Numerically stable logistic function ``1 / (1 + exp(-x))``.

    Evaluated as ``exp(-log(1 + exp(-x)))`` with ``np.logaddexp`` so that
    large negative inputs no longer overflow inside ``exp`` (the naive form
    emits a RuntimeWarning and passes through ``inf`` there). Results agree
    with the naive formula up to floating-point rounding.

    Args:
        x: a real scalar or a NumPy array of raw scores. As before, plain
            Python lists are not supported (unary minus is applied to ``x``).

    Returns:
        The sigmoid of ``x`` as a NumPy float scalar, or an array of the same
        shape as ``x``, with every value in ``[0, 1]``.
    """
    # logaddexp(0, -x) == log(exp(0) + exp(-x)) == log(1 + exp(-x)),
    # computed without forming exp(-x) explicitly when it would overflow.
    return np.exp(-np.logaddexp(0.0, -x))
    
def derive_step_rewards_vllm(raw_rewards, batch_reward_flags):
    """Pick the per-step scores out of a batched PRM response.

    For the item at position ``i`` of ``raw_rewards.data``, its ``embedding``
    values are paired with ``batch_reward_flags[i]``; each value whose flag
    equals 1 is passed through ``sigmoid`` and kept. Pairing uses ``zip``, so
    it stops at the shorter of the two sequences.

    Args:
        raw_rewards: response object exposing ``.data``, an iterable of items
            that each have an ``.embedding`` sequence of raw scores.
        batch_reward_flags: indexable collection holding one 0/1 flag
            sequence per item in ``raw_rewards.data``.

    Returns:
        A list with one inner list per item, holding the sigmoid-transformed
        scores at the flagged positions, in order.
    """
    batch_step_rewards = []
    for position, item in enumerate(raw_rewards.data):
        raw_scores = item.embedding
        flags = batch_reward_flags[position]

        selected = []
        for raw_score, flag in zip(raw_scores, flags):
            if flag == 1:
                selected.append(sigmoid(raw_score))
        batch_step_rewards.append(selected)
    return batch_step_rewards

In [37]:
# Turn the raw PRM outputs into per-step scores, then display them.
step_rewards = derive_step_rewards_vllm(
    raw_rewards=rewards,
    batch_reward_flags=reward_flags,
)
step_rewards

[[0.4780414772938761, 0.4163219602930174],
 [0.4134771498315425, 0.3798383564508293]]