In [None]:
####################################################################################################
############################ Try SGLang and Test Llama3 on MATH  ###################################
####################################################################################################

In [1]:
from sglang.utils import (
    execute_shell_command,
    wait_for_server,
    terminate_process,
    print_highlight,
)

import requests

In [2]:
url = "http://localhost:10086/v1/chat/completions"

data = {
    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "messages": [
        {"role": "user", "content": "List 3 countries and their capitals."}
    ]
}

response = requests.post(url, json=data)
print(response.json())

{'id': 'c6a1022ade8740d4aca47c978ab03380', 'object': 'chat.completion', 'created': 1733650120, 'model': 'meta-llama/Meta-Llama-3.1-8B-Instruct', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': 'Here are 3 countries and their capitals:\n\n1. Country: Japan\nCapital: Tokyo\n\n2. Country: Australia\nCapital: Canberra\n\n3. Country: Brazil\nCapital: Brasília'}, 'logprobs': None, 'finish_reason': 'stop', 'matched_stop': 128009}], 'usage': {'prompt_tokens': 43, 'total_tokens': 83, 'completion_tokens': 40, 'prompt_tokens_details': None}}


In [3]:
import openai

client = openai.Client(base_url="http://127.0.0.1:10086/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=128,
)

print(response)

ChatCompletion(id='56c754badbc749a4a99ed74e45ce987c', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Here are 3 countries and their capitals:\n\n1. Country: Japan\n   Capital: Tokyo\n\n2. Country: Australia\n   Capital: Canberra\n\n3. Country: Brazil\n   Capital: Brasília', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None), matched_stop=128009)], created=1733650125, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=43, prompt_tokens=48, total_tokens=91, completion_tokens_details=None, prompt_tokens_details=None))


In [4]:
# load math500-test.jsonl
import json

data = []
with open("math500-test.jsonl", "r") as f:
    raw_data = f.readlines()

for line in raw_data:
    data.append(json.loads(line))

In [5]:
# Need to import an evaluator from dart_math
# to compare math expressions with ground truth answers

from dart_math.eval import EvaluatorMath
math_evaluator = EvaluatorMath()

  from .autonotebook import tqdm as notebook_tqdm
2024-12-08 01:28:55,884	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [6]:
# from simple_eval by OpenAI

import random
import re
from typing import Literal

import blobfile as bf
import pandas

import common
from common import ANSWER_PATTERN, HTML_JINJA, check_equality
from eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult

QUERY_TEMPLATE = """
Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.

{Question}

Remember to put your answer on its own line after "Answer:", and you do not need to use a \\boxed command.
""".strip()


class MathEval(Eval):
    def __init__(
        self,
        equality_checker: SamplerBase,
        num_examples: int | None = None,
        n_repeats: int = 1,
        split: Literal["math_test", "math_500_test"] = "my_math_500_test",  # see readme.md
    ):
        df = pandas.read_csv(
            # bf.BlobFile(f"https://openaipublic.blob.core.windows.net/simple-evals/{split}.csv")
            bf.BlobFile(f"{split}.csv")
        )
        examples = [row.to_dict() for _, row in df.iterrows()]
        if num_examples:
            assert n_repeats == 1, "n_repeats only supported for num_examples = None"
            rng = random.Random(0)
            examples = rng.sample(examples, num_examples)
        self.examples = examples * n_repeats
        self.equality_checker = equality_checker

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        def fn(row: dict):
            prompt_messages = [
                sampler._pack_message(content=QUERY_TEMPLATE.format(**row), role="user")
            ]
            response_text = sampler(prompt_messages)
            match = re.search(ANSWER_PATTERN, response_text)
            extracted_answer = match.group(1) if match else None
            # score = 0 if extracted_answer is None else \
            #     float(check_equality(self.equality_checker, row["Answer"], extracted_answer))
            score = 0 if extracted_answer is None else math_evaluator.eq(row["Answer"], extracted_answer)
            
            # my change: none -> error
            html = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=dict(content=response_text, role="assistant"),
                score=score,
                correct_answer=row["Answer"],
                extracted_answer=extracted_answer,
            )
            convo = prompt_messages + [dict(content=response_text, role="assistant")]
            return SingleEvalResult(html=html, score=score, convo=convo)

        results = common.map_with_progress(fn, self.examples)
        return common.aggregate_results(results)

In [7]:
# also from simple_eval by OpenAI

import base64
import time
from typing import Any

import openai
from openai import OpenAI

from eval_types import MessageList, SamplerBase

OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
OPENAI_SYSTEM_MESSAGE_CHATGPT = (
    "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
    + "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01"
)


class ChatCompletionSampler(SamplerBase):
    """
    Sample from OpenAI's chat completion API
    """

    def __init__(
        self,
        model: str = "gpt-3.5-turbo",
        system_message: str | None = None,
        temperature: float = 0.5,
        max_tokens: int = 1024,
        client = None, 
        return_full_response: bool = False,
    ):
        self.api_key_name = "OPENAI_API_KEY"
        self.client = client or OpenAI()
        # using api_key=os.environ.get("OPENAI_API_KEY")  # please set your API_KEY
        self.model = model
        self.system_message = system_message
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.image_format = "url"
        self.return_full_response = return_full_response

    def _handle_image(
        self, image: str, encoding: str = "base64", format: str = "png", fovea: int = 768
    ):
        new_image = {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/{format};{encoding},{image}",
            },
        }
        return new_image

    def _handle_text(self, text: str):
        return {"type": "text", "text": text}

    def _pack_message(self, role: str, content: Any):
        return {"role": str(role), "content": content}

    def __call__(self, message_list: MessageList) -> str:
        if self.system_message:
            message_list = [self._pack_message("system", self.system_message)] + message_list
        trial = 0
        # print(message_list)
        while True:
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=message_list,
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
                if self.return_full_response:
                    print("message", message_list, "response", response.choices[0].message.content)
                    return response.choices[0].message.content
                else:
                    return response.choices[0].message.content

            # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are reruning MMMU
            except openai.BadRequestError as e:
                print("Bad Request Error", e)
                return ""
            except Exception as e:
                exception_backoff = 2**trial  # expontial back off
                print(
                    f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
                    e,
                )
                time.sleep(exception_backoff)
                trial += 1
            # unknown error shall throw exception

In [8]:
# Use our sglang client as OpenAI client
client = openai.Client(base_url="http://127.0.0.1:10086/v1", api_key="None")

sampler = ChatCompletionSampler(model="meta-llama/Meta-Llama-3.1-8B-Instruct",
                                temperature=0.0,
                                max_tokens=2048,
                                client=client)

In [9]:
debug = False
# equality_checker = sampler
matheval = MathEval(
     equality_checker=sampler, num_examples=10 if debug else 500
)

In [10]:
samplers = {"llama3": sampler}

In [14]:
# also from simple_eval by OpenAI

import pandas as pd

evals = {
    "matheval": matheval
}
print(evals)
debug_suffix = "_DEBUG" if debug else ""
print(debug_suffix)
mergekey2resultpath = {}
for sampler_name, sampler in samplers.items():
    for eval_name, eval_obj in evals.items():
        result = eval_obj(sampler)
        # ^^^ how to use a sampler
        file_stem = f"{eval_name}_{sampler_name}"
        report_filename = f"/tmp/{file_stem}{debug_suffix}.html"
        print(f"Writing report to {report_filename}")
        with open(report_filename, "w") as fh:
            fh.write(common.make_report(result))
        metrics = result.metrics | {"score": result.score}
        print(metrics)
        result_filename = f"/tmp/{file_stem}{debug_suffix}.json"
        with open(result_filename, "w") as f:
            f.write(json.dumps(metrics, indent=2))
        print(f"Writing results to {result_filename}")
        mergekey2resultpath[f"{file_stem}"] = result_filename
merge_metrics = []
for eval_sampler_name, result_filename in mergekey2resultpath.items():
    try:
        result = json.load(open(result_filename, "r+"))
    except Exception as e:
        print(e, result_filename)
        continue
    result = result.get("f1_score", result.get("score", None))
    eval_name = eval_sampler_name[: eval_sampler_name.find("_")]
    sampler_name = eval_sampler_name[eval_sampler_name.find("_") + 1 :]
    merge_metrics.append(
        {"eval_name": eval_name, "sampler_name": sampler_name, "metric": result}
    )
merge_metrics_df = pd.DataFrame(merge_metrics).pivot(
    index=["sampler_name"], columns="eval_name"
)
print("\nAll results: ")
print(merge_metrics_df.to_markdown())

{'matheval': <__main__.MathEval object at 0x7f46d1e51c00>}



100%|██████████| 500/500 [02:59<00:00,  2.79it/s]

Writing report to /tmp/matheval_llama3.html
{'score:std': 0.4976906669810074, 'score': 0.452}
Writing results to /tmp/matheval_llama3.json

All results: 
| sampler_name   |   ('metric', 'matheval') |
|:---------------|-------------------------:|
| llama3         |                    0.452 |





In [None]:
####################################################################################################
###################################### SGLang for Tree Search ######################################
####################################################################################################

In [11]:
import random
import re
from typing import Literal

import blobfile as bf
import pandas

import common
from common import ANSWER_PATTERN, HTML_JINJA, check_equality
from eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult

split = "my_math_500_test"
df = pandas.read_csv(
    # bf.BlobFile(f"https://openaipublic.blob.core.windows.net/simple-evals/{split}.csv")
    bf.BlobFile(f"{split}.csv")
)
examples = [row.to_dict() for _, row in df.iterrows()]

In [12]:
## try to use sglang to sequentially generate next steps.

import sglang as sgl
import argparse
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
import time


max_steps = 30


@sgl.function
def search_try(s, question):
    s += sgl.user(
        f"""Solve the following math problem step by step. Steps should be separated with two new lines. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.
        
{question}

Remember to separate steps with two new lines, and finally put your answer on its own line after "Answer:", and you do not need to use a \\boxed command.
"""
    )
    
    s += sgl.assistant_begin()
    s += sgl.gen(max_tokens=256, stop=["\n\n"])
    
    # print(f"''{s.text()}''")
    
    for _ in range(max_steps):
        # s += new_line
        s += "\n\n"
        s += sgl.gen(max_tokens=256, stop=["\n\n"])
        # print(f"''{s.text()}''")  
        if "Answer:" in s.text().split("\n")[-1]:
            break
    
    # s += sgl.assistant(sgl.gen("step", max_tokens=256, temperature=0.3, stop=["\n\n"]))

In [13]:
args = argparse.Namespace(
    # data_path="sglang/benchmark/tree_of_thought_v0/test.jsonl",
    # num_questions=2,
    port=10086,
    parallel=16,
    backend='srt',
    host="http://127.0.0.1",
    result_file="results.txt"
)

# q = """If $x^3$ is a positive factor of $10!,$ how many possible integer values of $x$ are there?  (Reminder: For a positive integer $n$, the expression $n!$ stands for the product of the integers from 1 up to (and including) $n$.)"""

arguments = [{"question": d['Question']} for d in examples[:100]]

# Select backend
backend = select_sglang_backend(args)

# Run requests
tic = time.time()
states = search_try.run_batch(
    arguments,
    temperature=0,
    backend=backend,
    num_threads=args.parallel,
    progress_bar=True,
)
latency = time.time() - tic

100%|██████████| 100/100 [01:13<00:00,  1.36it/s]


In [14]:
# for i, state in enumerate(states):
#     print(f"Question: {arguments[i]['question']}")
#     print(f"Answer: {state.text()}")

scores = 0
for i, state in enumerate(states):
    response_text = state.text()
    match = re.search(ANSWER_PATTERN, response_text.split("\n")[-1])
    extracted_answer = match.group(1) if match else None
    # score = 0 if extracted_answer is None else \
    #     float(check_equality(self.equality_checker, row["Answer"], extracted_answer))
    answer = examples[i]["Answer"]
    score = 0 if extracted_answer is None else math_evaluator.eq(answer, extracted_answer)
    
    print(f"{extracted_answer} || {answer} || {score}")
    scores += score
    
print(scores/len(states))

$\left(3, \frac{\pi}{2}\right)$ || \left( 3, \frac{\pi}{2} \right) || True
None || p - q || 0
$\frac{14}{3}$ || \frac{14}{3} || True
9 || 9 || True
Angela || \text{Evelyn} || False
126 || 42 || False
None || 27 || 0
101.5 || 90^\circ || False
3√13 || 3\sqrt{13} || False
None || 4 || 0
None || 2220 || 0
$\frac{13}{56}$ || \frac{3}{56} || False
284 || 284 || True
$5$ || 5 || True
$10$ || \sqrt{51} || False
None || 6 - 5i || 0
-50 || -50 || True
$\pi$ || \pi || True
112 || 28 || False
None || 3 || 0
6 + 9i || 6+9i || True
None || 13535 || 0
None || 5 || 0
None || x=5 || 0
10 || 10 || True
$\boxed{\text{There are no solutions.}}$ || 1,-2 || False
12 || 144 || False
$78 || 78 || True
-2 + 7i || -2 + 7i || True
None || 225 || 0
$2_8$ || 52_8 || False
11$\sqrt{2}$ || 11\sqrt2 || True
None || 720 || 0
None || \frac{243}{625} || 0
-125 || -125 || True
None || 3 || 0
$2, 5$ || 3, 5, 7 || False
360 || 72 || False
2000 || 2000 || True
23 || 23 || True
12 || 12 || True
-4.5 || 17 || False
4 || 4 ||

In [15]:
import random

In [None]:
import requests

url = "http://localhost:10087/judge"
data = {"model": "Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", "text": prompts}

responses = requests.post(url, json=data).json()
for response in responses:
    # print(type(response))
    print(f"reward: {response['embedding'][0]}")

In [28]:
## Try beam search

import sglang as sgl
max_steps = 30

BEAM_SIZE = 4
BEAM_WIDTH = 2
assert BEAM_SIZE % BEAM_WIDTH == 0

@sgl.function
def beam_search(s, question):
    s += sgl.user(
        f"""Solve the following math problem step by step. Steps should be separated with two new lines. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.
        
{question}

Remember to separate steps with two new lines, and finally put your answer on its own line after "Answer:", and you do not need to use a \\boxed command."""
    )

    s += sgl.assistant_begin()
    forks = s.fork(BEAM_SIZE)
    forks += sgl.gen(max_tokens=256, stop=["\n\n"], temperature=0.5)
    step = s.text().split("\n\n")[-1]

    # print(f"''{s.text()}''")
    cur_states = list(forks)
    
    answer_states = []

    for _ in range(max_steps):
        
        # s += new_line
        # randomly select BEAM_WIDTH states
        # print("--A--")
        
        cur_beam_width = min(BEAM_WIDTH, len(cur_states))
        
        cur_states = random.sample(cur_states, cur_beam_width)

        # expand to BEAM_SIZE states
        new_states = []

        for state in cur_states:

            # print("--B--")

            if "Answer:" in state.text().split("\n")[-1]:
                answer_states.append(state)
                continue
            
            # print("--C--")
            forked_states = state.fork((BEAM_SIZE - 1) // cur_beam_width + 1)
            forked_states += "\n\n" + sgl.gen(max_tokens=256, stop=["\n\n"], temperature=0.5)
            new_states.extend(forked_states)
        
        # print("--D--")
        # print(len(new_states))
        cur_states = new_states
        
        if len(answer_states) > 0:
            break
            
        
    return answer_states


In [29]:
bs_states = beam_search.run_batch(
    arguments[:1],
    temperature=0,
    backend=backend,
    num_threads=args.parallel,
    progress_bar=True,
)

100%|██████████| 1/1 [00:03<00:00,  3.11s/it]


In [30]:
scores = 0

for i, states in enumerate(bs_states):
    if len(states.ret_value) == 0:
        continue

    response_text = states.ret_value[0].text()
    match = re.search(ANSWER_PATTERN, response_text.split("\n")[-1])
    extracted_answer = match.group(1) if match else None
    # score = 0 if extracted_answer is None else \
    #     float(check_equality(self.equality_checker, row["Answer"], extracted_answer))
    answer = examples[i]["Answer"]
    score = 0 if extracted_answer is None else math_evaluator.eq(answer, extracted_answer)
    
    print(f"{extracted_answer} || {answer} || {score}")
    scores += score

print(scores/len(bs_states))

(3, \frac{\pi}{2}) || \left( 3, \frac{\pi}{2} \right) || True
1.0


In [21]:
bs_states[0]

ProgramState(<|start_header_id|>user<|end_header_id|>

Solve the following math problem step by step. Steps should be separated with two new lines. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.
        
Convert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$

Remember to separate steps with two new lines, and finally put your answer on its own line after "Answer:", and you do not need to use a \boxed command.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

)

In [37]:
rm_url = "http://localhost:10087/judge"

In [None]:
## Try beam search

import sglang as sgl
max_steps = 30

BEAM_SIZE = 4
BEAM_WIDTH = 2
assert BEAM_SIZE % BEAM_WIDTH == 0

@sgl.function
def beam_search_with_rm(s, question):
    s += sgl.user(
        f"""Solve the following math problem step by step. Steps should be separated with two new lines. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.
        
{question}

Remember to separate steps with two new lines, and finally put your answer on its own line after "Answer:", and you do not need to use a \\boxed command."""
    )

    s += sgl.assistant_begin()
    forks = s.fork(BEAM_SIZE)
    forks += sgl.gen(max_tokens=256, stop=["\n\n"], temperature=0.5)
    step = s.text().split("\n\n")[-1]

    # print(f"''{s.text()}''")
    cur_states = list(forks)
    
    answer_states = []

    for _ in range(max_steps):
        
        # s += new_line
        # randomly select BEAM_WIDTH states
        # print("--A--")
        
        cur_beam_width = min(BEAM_WIDTH, len(cur_states))
        
        print("----A----")
        
        texts = [cur_state.text() for cur_state in cur_states]
        print(texts)
        data = {"model": "Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", "text": texts}
        responses = requests.post(rm_url, json=data).json()
        for response in responses:
            # print(type(response))
            print(f"reward: {response['embedding'][0]}")

        rewards = [response['embedding'][0] for response in responses]
        # cur_states = random.sample(cur_states, cur_beam_width)
        # select the best BEAM_WIDTH states based on reward
        cur_states = [cur_states[i] for i in sorted(range(len(cur_states)), key=lambda i: rewards[i], reverse=True)[:cur_beam_width]]

        # expand to BEAM_SIZE states
        new_states = []

        for state in cur_states:

            # print("--B--")

            if "Answer:" in state.text().split("\n")[-1]:
                answer_states.append(state)
                continue
            
            # print("--C--")
            forked_states = state.fork((BEAM_SIZE - 1) // cur_beam_width + 1)
            forked_states += "\n\n" + sgl.gen(max_tokens=256, stop=["\n\n"], temperature=0.5)
            new_states.extend(forked_states)
        
        # print("--D--")
        # print(len(new_states))
        cur_states = new_states
        
        if len(answer_states) > 0:
            break
            
        
    return answer_states


In [41]:
bs_states = beam_search_with_rm.run_batch(
    arguments[:1],
    temperature=0,
    backend=backend,
    num_threads=args.parallel,
    progress_bar=True,
)

  0%|          | 0/1 [00:00<?, ?it/s]

----A----
['<|start_header_id|>user<|end_header_id|>\n\nSolve the following math problem step by step. Steps should be separated with two new lines. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.\n        \nConvert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\\theta),$ where $r > 0$ and $0 \\le \\theta < 2 \\pi.$\n\nRemember to separate steps with two new lines, and finally put your answer on its own line after "Answer:", and you do not need to use a \\boxed command.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nThe rectangular coordinates $(x,y)=(0,3)$ correspond to the polar coordinates $(r,\\theta)$, where $r$ is the distance from the origin to the point, and $\\theta$ is the angle from the positive $x$-axis to the line connecting the origin to the point.', '<|start_header_id|>user<|end_header_id|>\n\nSolve the following math proble

100%|██████████| 1/1 [00:05<00:00,  5.49s/it]

['<|start_header_id|>user<|end_header_id|>\n\nSolve the following math problem step by step. Steps should be separated with two new lines. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.\n        \nConvert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\\theta),$ where $r > 0$ and $0 \\le \\theta < 2 \\pi.$\n\nRemember to separate steps with two new lines, and finally put your answer on its own line after "Answer:", and you do not need to use a \\boxed command.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nTo convert the point $(0,3)$ from rectangular coordinates to polar coordinates, we need to find the radius $r$ and the angle $\\theta$.\n\nThe radius $r$ can be found using the formula $r = \\sqrt{x^2 + y^2}$, where $x$ and $y$ are the coordinates of the point.\n\nSince the point is $(0,3)$, we have $x = 0$ and $y = 3$.\n\nPlugging these




In [33]:
texts = ['<|start_header_id|>user<|end_header_id|>\n\nSolve the following math problem step by step. Steps should be separated with two new lines. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.\n        \nConvert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\\theta),$ where $r > 0$ and $0 \\le \\theta < 2 \\pi.$\n\nRemember to separate steps with two new lines, and finally put your answer on its own line after "Answer:", and you do not need to use a \\boxed command.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nTo convert the point $(0,3)$ from rectangular coordinates to polar coordinates, we first need to find the value of $r$ and $\\theta$.', '<|start_header_id|>user<|end_header_id|>\n\nSolve the following math problem step by step. Steps should be separated with two new lines. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.\n        \nConvert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\\theta),$ where $r > 0$ and $0 \\le \\theta < 2 \\pi.$\n\nRemember to separate steps with two new lines, and finally put your answer on its own line after "Answer:", and you do not need to use a \\boxed command.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nTo convert the point $(0,3)$ from rectangular coordinates to polar coordinates, we use the formulas $r = \\sqrt{x^2 + y^2}$ and $\\theta = \\tan^{-1}\\left(\\frac{y}{x}\\right)$.']

In [35]:
data = {"model": "Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", "text": texts}
responses = requests.post(url, json=data)

In [36]:
responses

<Response [500]>

In [None]:
for response in responses:
    # print(type(response))
    print(f"reward: {response['embedding'][0]}")

In [None]:
## Try beam search

import sglang as sgl
max_steps = 30

BEAM_SIZE = 4
BEAM_WIDTH = 2
assert BEAM_SIZE % BEAM_WIDTH == 0

@sgl.function
def beam_search_with_critic(s, question):
    s += sgl.user(
        f"""Solve the following math problem step by step. Steps should be separated with two new lines. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.
        
{question}

Remember to separate steps with two new lines, and finally put your answer on its own line after "Answer:", and you do not need to use a \\boxed command."""
    )

    s += sgl.assistant_begin()
    forks = s.fork(BEAM_SIZE)
    forks += sgl.gen(max_tokens=256, stop=["\n\n"], temperature=0.5)
    step = s.text().split("\n\n")[-1]

    # print(f"''{s.text()}''")
    cur_states = list(forks)
    
    answer_states = []

    for _ in range(max_steps):
        
        # s += new_line
        # randomly select BEAM_WIDTH states
        # print("--A--")
        
        cur_beam_width = min(BEAM_WIDTH, len(cur_states))
        
        print("----A----")
        
        texts = [cur_state.text() for cur_state in cur_states]
        print(texts)
        data = {"model": "Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", "text": texts}
        responses = requests.post(rm_url, json=data).json()
        for response in responses:
            # print(type(response))
            print(f"reward: {response['embedding'][0]}")

        rewards = [response['embedding'][0] for response in responses]


        # cur_states = random.sample(cur_states, cur_beam_width)
        # select the best BEAM_WIDTH states based on reward
        cur_states = [cur_states[i] for i in sorted(range(len(cur_states)), key=lambda i: rewards[i], reverse=True)[:cur_beam_width]]

        # expand to BEAM_SIZE states
        new_states = []

        for state in cur_states:

            # print("--B--")

            if "Answer:" in state.text().split("\n")[-1]:
                answer_states.append(state)
                continue
            
            # print("--C--")
            forked_states = state.fork((BEAM_SIZE - 1) // cur_beam_width + 1)
            forked_states += "\n\n" + sgl.gen(max_tokens=256, stop=["\n\n"], temperature=0.5)
            new_states.extend(forked_states)
        
        # print("--D--")
        # print(len(new_states))
        cur_states = new_states
        
        if len(answer_states) > 0:
            break
            
        
    return answer_states
