[WIP] Creates Judges as a wrapper on Policy #202
base: main
Changes from all commits
249a6ab
be26c39
e88f58f
634fe59
336c997
6a01bd7
8c87d42
f80ff68
53607fd
00bbffa
@@ -0,0 +1,71 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""To run:
export HF_HUB_DISABLE_XET=1
python -m apps.vllm.judge --config apps/vllm/llama3_8b.yaml
"""

import asyncio
import os

from forge.actors.judge import EvaluationMode, Judge
from forge.cli.config import parse
from forge.controller.provisioner import shutdown
from forge.observability.metric_actors import get_or_create_metric_logger
from omegaconf import DictConfig

os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600"
os.environ["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824"


async def run(cfg: DictConfig):
    metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
    mlogger = await get_or_create_metric_logger()
    await mlogger.init_backends.call_one(metric_logging_cfg)

    prompt = "What is the capital of Japan?"
    responses = ["Aardvark", "Durian", "Tokyo"]

    print("Spawning service...")
    judge = await Judge.options(**cfg.services.policy).as_service(**cfg.policy)

    print(f"Prompt: {prompt}")
    print(f"Responses: {responses}\n")
    print("Evaluating responses...")
    best_response_evaluations: list[str] = await judge.evaluate.route(
        prompt=prompt, responses=responses, evaluation_mode=EvaluationMode.BEST_RESPONSE
    )
    response_check_evaluations: list[str] = await judge.evaluate.route(
        prompt=prompt,
        responses=responses,
        evaluation_mode=EvaluationMode.RESPONSE_CHECK,
    )

    print("\nGeneration Results:")
    print("=" * 80)
    for batch, (best, fact) in enumerate(
        zip(best_response_evaluations, response_check_evaluations)
    ):
        print(f"Sample {batch + 1}")
        print(f"Evaluation (BEST_RESPONSE): {best}")
        print(f"Evaluation (RESPONSE_CHECK): {fact}")
        print("-" * 80)

    print("\nShutting down...")
    await judge.shutdown()
    await shutdown()


@parse
def recipe_main(cfg: DictConfig) -> None:
    asyncio.run(run(cfg))


if __name__ == "__main__":
    recipe_main()
```
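The demo above only exercises BEST_RESPONSE and RESPONSE_CHECK. For reference, here is a minimal sketch of how the third mode, MATH_CHECK, could be invoked from the same script; the prompt, responses, and ground truth are invented for illustration and are not part of this PR.

```python
# Hypothetical extension of run() above: MATH_CHECK compares each response
# against a caller-supplied ground truth (see _math_check in the Judge class below).
math_prompt = "What is 12 * 12?"
math_responses = ["144", "It is 24", "12^2 = 144"]
math_evaluations: list[str] = await judge.evaluate.route(
    prompt=math_prompt,
    responses=math_responses,
    ground_truth="144",
    evaluation_mode=EvaluationMode.MATH_CHECK,
)
print(f"Evaluations (MATH_CHECK): {math_evaluations}")
```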
@@ -0,0 +1,214 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from enum import auto, Enum

from monarch.actor import endpoint

from forge.actors.policy import Policy
from forge.data_models.completion import Completion


class EvaluationMode(Enum):
    """Enum for selecting how a judge should evaluate the provided args."""

    BEST_RESPONSE = auto()
    RESPONSE_CHECK = auto()
    MATH_CHECK = auto()


@dataclass
class Judge(Policy):
    """
    `LLM-based Judges` are typically generative models that are prompted to
    evaluate responses. These models need prompt engineering to evaluate
    and may require more postprocessing.
    """

    def _math_check(
        self,
        prompt: str,
        responses: list[str],
        ground_truth: None | str = None,
```

Review comment: Why does the judge need access to the ground truth for math check?
Reply: If you're using the judge as a grader, the ground truth is convenient.
```python
    ) -> str:
        """
        Construct the generator input. Formats the request so that the generator
        will return a comma-separated list with a [[GOOD]] or [[BAD]] evaluation
        for each response, corresponding to whether the model thinks the response
        matches the provided ground_truth. Specifically, the generator is prompted
        to check for mathematical equivalence.

        Note: This is not a "good" prompt, it just demonstrates how to make one.
        """
        if ground_truth is None:
            raise ValueError("MATH_CHECK requires a ground_truth to compare against")

        system_prompt = f"""
        You are a math professor. Given the prompt and ground truth solution, evaluate
        each of the provided attempts and return whether the final solution is
        numerically equivalent to the ground truth.

        Each response is formatted as [Response #<N>], where <N> represents the
        attempt.

        Your answer should be a comma separated list of "[[GOOD]]" or "[[BAD]]",
        corresponding to the same order as the responses provided.

        - If the answer is irrelevant to the prompt, return "[[BAD]]".
        - If you are not confident that solution and attempt are equivalent, return "[[BAD]]"
        - Only return "[[GOOD]]" if the attempt is numerically equivalent

        Do not explain your reasoning, just provide your evaluations.
        ---
        Here is the prompt that generated the responses: {prompt}.
        ---
        Here is the ground truth: {ground_truth}
        """
        response_str = "\n".join(
            [f"[Response #{i+1}] {resp}" for i, resp in enumerate(responses)]
        )
        as_chat = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": response_str},
        ]
        tokenizer = self.processor.tokenizer.tokenizer
        formatted_request = tokenizer.apply_chat_template(
            as_chat, tokenize=False, add_generation_prompt=True
        )
        return formatted_request

    def _response_check(
        self,
        prompt: str,
        responses: list[str],
        ground_truth: None | str = None,
    ) -> str:
        """
        Construct the generator input. Formats the request so that the generator
        will return a comma-separated list with a [[GOOD]] or [[BAD]] evaluation
        for each response, corresponding to whether the model thinks it correctly
        answers the prompt.

        Note: This is not a "good" prompt, it just demonstrates how to make one.
        """
        system_prompt = f"""
        You are an expert fact checker. Given a prompt and response attempts, evaluate
        each attempt and return whether it accurately answers the prompt.
        Each response is formatted as [Response #<N>], where <N> represents the
        attempt.

        Your answer should be a comma separated list of "[[GOOD]]" or "[[BAD]]",
        corresponding to the same order as the responses provided.

        - If the answer is irrelevant to the prompt, return "[[BAD]]".
        - If you are not confident that the answer accurately answers the prompt, return "[[BAD]]"
        - Only return "[[GOOD]]" if the attempt accurately answers the prompt

        Do not explain your reasoning, just provide your evaluations.
        Here is the prompt that generated the responses: {prompt}.
        """
        response_str = "\n".join(
            [f"[Response #{i+1}] {resp}" for i, resp in enumerate(responses)]
        )
        as_chat = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": response_str},
        ]
        tokenizer = self.processor.tokenizer.tokenizer
        formatted_request = tokenizer.apply_chat_template(
            as_chat, tokenize=False, add_generation_prompt=True
        )
        return formatted_request

    def _best_check(
        self,
        prompt: str,
        responses: list[str],
        ground_truth: None | str = None,
    ) -> str:
        """
        Construct the generator input. Formats the request so that the generator
        will respond with a single integer corresponding to the response the model
        thinks is most factually correct.

        Note: This is not a "good" prompt, it just demonstrates how to make one.
        """
        system_prompt = f"""
        You are an expert evaluator. Evaluate the responses provided and return
        a single integer indicating which response is the most factually correct.
        Each response is formatted as [Response #<N>], where <N> represents the
        selection. Do not explain your reasoning, just provide a number.

        Here is the prompt that generated the responses: {prompt}.
        """
        response_str = "\n".join(
            [f"[Response #{i+1}] {resp}" for i, resp in enumerate(responses)]
        )
        as_chat = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": response_str},
        ]
        tokenizer = self.processor.tokenizer.tokenizer
        formatted_request = tokenizer.apply_chat_template(
            as_chat, tokenize=False, add_generation_prompt=True
        )
        return formatted_request

    def _postprocess_output(self, outputs: list[Completion]) -> list[str]:
        return [output.text for output in outputs]

    @endpoint
    async def evaluate(
        self,
        prompt: str,
        responses: None | list[str] = None,
        ground_truth: None | str = None,
        evaluation_mode: EvaluationMode = EvaluationMode.BEST_RESPONSE,
    ) -> list[str]:
        _prompting: dict = {
            EvaluationMode.BEST_RESPONSE: self._best_check,
            EvaluationMode.RESPONSE_CHECK: self._response_check,
            EvaluationMode.MATH_CHECK: self._math_check,
        }

        wrapped_prompt: str = _prompting[evaluation_mode](
            prompt, responses, ground_truth
        )
        response: list[Completion] = await self.generate._method(self, wrapped_prompt)
        return self._postprocess_output(response)


@dataclass
class RewardModelJudge(Policy):
    """
    `RewardModels` are typically discriminative models, post-trained to
    evaluate responses without further prompting required.
    """

    # TODO: Add reward model formatting
    def _wrap_prompt(
        self, prompt: str, responses: list[str], ground_truth: None | str = None
    ) -> str:
        return prompt

    def _postprocess_output(
        self, outputs: list[Completion], ground_truth: None | str = None
    ) -> list[str]:
        return [output.text for output in outputs]

    @endpoint
    async def evaluate(
        self,
        prompt: str,
        responses: list[str],
    ) -> list[str]:
        wrapped_prompt: str = self._wrap_prompt(prompt, responses)
        response: list[Completion] = await self.generate._method(self, wrapped_prompt)
        return self._postprocess_output(response)
```

Review comment: Should we just change Policy to have `_generate` with the core logic and a `generate` endpoint?
Reply: Do Service/Replica keep the non-endpoint attributes? No strong preference if they do.
Reply: Yes, it'll probably be cleaner than accessing the `_method` field, which could change at any time.
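For context, a rough sketch of the refactor floated in the comment above: factor the core generation logic into a plain `_generate` coroutine and keep the `generate` endpoint as a thin wrapper, so subclasses like Judge can call `self._generate(...)` instead of reaching into `self.generate._method`. The names and signatures below are assumptions, not the current Policy API.

```python
# Sketch only (assumed names/signature, not the actual forge Policy API).
from monarch.actor import endpoint

from forge.data_models.completion import Completion


class Policy:
    async def _generate(self, prompt: str) -> list[Completion]:
        """Core generation logic, callable directly by subclasses."""
        ...

    @endpoint
    async def generate(self, prompt: str) -> list[Completion]:
        # Thin endpoint wrapper around the core logic.
        return await self._generate(prompt)


# Judge.evaluate would then end with:
#     response = await self._generate(wrapped_prompt)
#     return self._postprocess_output(response)
```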
Review comment: General nit voicing a preference for injection vs. inheritance.
Reply: Oh boi, this is a fun one. Had a quick chat with @ebsmothers about me opening the flood gates a few days ago. cc: @pbontrager too.
Reply: Another option is like [...] and have everything be kept within Generator?
Reply: IMO, out of all of the options we've seen, inheritance is the least of all evils. But we should do our best to keep the level of inheritance low; we certainly shouldn't add another layer on top of this.
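For comparison, a rough sketch of the injection (composition) alternative mentioned above, where the judge holds a generator handle instead of subclassing Policy. Every name and call shape here is an assumption for illustration, not an existing forge interface.

```python
# Sketch only: composition over inheritance. Assumes an injected generator
# handle exposing an async generate(prompt) call that returns completion-like
# objects with a .text field; the real service interface may differ.
from dataclasses import dataclass
from typing import Any


@dataclass
class ComposedJudge:
    generator: Any  # injected dependency, e.g. a Policy service handle (hypothetical)

    def _wrap_prompt(self, prompt: str, responses: list[str]) -> str:
        numbered = "\n".join(f"[Response #{i + 1}] {r}" for i, r in enumerate(responses))
        return f"{prompt}\n{numbered}"

    async def evaluate(self, prompt: str, responses: list[str]) -> list[str]:
        wrapped_prompt = self._wrap_prompt(prompt, responses)
        completions = await self.generator.generate(wrapped_prompt)
        return [c.text for c in completions]
```

The trade-off raised in the thread: composition keeps Policy untouched and avoids deepening the class hierarchy, while the inheritance approach in this PR lets Judge reuse Policy's service machinery (`options`/`as_service`) directly.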