2 changes: 2 additions & 0 deletions src/lighteval/metrics/metrics_sample.py
@@ -947,6 +947,7 @@ def __init__(
url: str | None = None,
hf_provider: str | None = None,
max_tokens: int | None = None,
backend_options: dict | None = None,
) -> None:
logger.debug(f"Initializing JudgeLLM with backend: {judge_backend}, model: {judge_model_name}")

@@ -993,6 +994,7 @@ def __init__(
url=url,
hf_provider=hf_provider,
max_tokens=max_tokens,
backend_options=backend_options,
)

def compute(self, responses: list[ModelResponse], docs: list[Doc], **kwargs) -> list:
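Taken together, these two hunks simply thread a new backend_options dict from JudgeLLM through to the underlying JudgeLM instance. A minimal sketch of what a caller might now pass for the litellm backend (the option values are illustrative only, and the remaining JudgeLLM arguments are omitted):

backend_options = {
    "caching": False,           # skip LiteLLM's disk cache for this run
    "concurrent_requests": 4,   # throttle parallel judge calls
}
# JudgeLLM(..., judge_backend="litellm", backend_options=backend_options)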
66 changes: 48 additions & 18 deletions src/lighteval/metrics/utils/llm_as_judge.py
@@ -25,6 +25,7 @@
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Callable, Literal, Optional

from huggingface_hub import AsyncInferenceClient, InferenceTimeoutError
@@ -45,28 +46,40 @@
DEFAULT_FORMAT = {"type": "text"}


class JudgeLM:
"""A class representing a judge for evaluating answers using either the OpenAI or Transformers library.
@dataclass
class LitellmBackendOptions:
"""Options for the LiteLLM judge backend with default values.

Args:
model (str): The name of the model.
templates (Callable): A function taking into account the question, options, answer, and gold and returning the judge prompt.
process_judge_response (Callable): A function for processing the judge's response.
judge_backend (Literal["openai", "transformers", "tgi", "vllm"]): The backend for the judge.
url (str | None): The URL for the OpenAI API.
api_key (str | None): The API key for the OpenAI API (either OpenAI or HF key).
Attributes:
caching (bool): Whether to enable caching for the API responses. Defaults to True.
concurrent_requests (int): The maximum number of concurrent requests to the API. Defaults to 10.
increase_max_tokens_for_reasoning (bool): Whether to increase the max tokens for certain reasoning
models. Defaults to True.
"""

caching: bool = True
concurrent_requests: int = 10

# Increases max_tokens depending on the model used, see implementation below
increase_max_tokens_for_reasoning: bool = True


class JudgeLM:
"""A class representing a judge for evaluating answers using either the chosen backend.

Attributes:
model (str): The name of the model.
template (Callable): A function taking into account the question, options, answer, and gold and returning the judge prompt.
API_MAX_RETRY (int): The maximum number of retries for the API.
API_RETRY_SLEEP (int): The time to sleep between retries.
client (OpenAI | None): The OpenAI client.
pipe (LLM | AutoModel | None): The Transformers or vllm pipeline.
templates (Callable): A function taking into account the question, options, answer, and gold and returning the judge prompt.
process_judge_response (Callable): A function for processing the judge's response.
judge_backend (Literal["litellm", "openai", "transformers", "tgi", "vllm", "inference-providers"]): The backend for the judge.
url (str | None): The URL for the OpenAI API.
api_key (str | None): The API key for the OpenAI API (either OpenAI or HF key).
backend (Literal["openai", "transformers", "tgi", "vllm"]): The backend for the judge
max_tokens (int): The maximum number of tokens to generate. Defaults to 512.
response_format (BaseModel | None): The format of the response from the API, used for the OpenAI and TGI backend.
hf_provider (Literal["black-forest-labs", "cerebras", "cohere", "fal-ai", "fireworks-ai",
"inference-providers", "hyperbolic", "nebius", "novita", "openai", "replicate", "sambanova", "together"] | None):
The HuggingFace provider when using the inference-providers backend.
backend_options (dict | None): Options for the backend. Currently only supported for litellm.

Methods:
evaluate_answer: Evaluates an answer using the OpenAI API or Transformers library.
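The new LitellmBackendOptions dataclass is now the single place where the litellm-specific knobs and their defaults live. A short sketch of how it behaves, assuming it can be imported from the module shown in this diff:

from dataclasses import asdict

from lighteval.metrics.utils.llm_as_judge import LitellmBackendOptions

defaults = LitellmBackendOptions()
print(asdict(defaults))  # {'caching': True, 'concurrent_requests': 10, 'increase_max_tokens_for_reasoning': True}

custom = LitellmBackendOptions(concurrent_requests=25, increase_max_tokens_for_reasoning=False)
print(custom.concurrent_requests)  # 25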
@@ -103,6 +116,7 @@ def __init__(
"together",
]
] = None,
backend_options: dict | None = None,
):
self.model = model
self.template = templates
@@ -122,6 +136,12 @@ def __init__(

self.response_format = response_format if response_format is not None else DEFAULT_FORMAT

self.backend_options = backend_options or {}

# Override backend options dictionary with the corresponding dataclass to ensure all specified options are valid
if judge_backend == "litellm":
self.backend_options = LitellmBackendOptions(**self.backend_options)

# Validate that hf_provider is specified when using inference-providers backend
if self.backend == "inference-providers" and self.hf_provider is None:
raise ValueError("When using 'inference-providers' as backend, you must specify an 'hf_provider'")
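Because the user-supplied dict is unpacked into the dataclass, a misspelled or unsupported option fails fast instead of being silently ignored. A small sketch of that behavior (the misspelled key is hypothetical):

try:
    LitellmBackendOptions(**{"cachin": True})  # typo for "caching"
except TypeError as err:
    print(err)  # ... got an unexpected keyword argument 'cachin'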
@@ -286,12 +306,22 @@ def __call_vllm(self, prompt):
def __call_litellm(self, prompts):
import litellm

if self.backend_options.caching:
from litellm.caching.caching import Cache, LiteLLMCacheType

litellm.cache = Cache(type=LiteLLMCacheType.DISK)

# Automatically drop parameters that are not supported by the currently used inference API
litellm.drop_params = True

def __call_api(prompt):
error_message = "ERROR: Failed to get response from the API."
for _ in range(self.API_MAX_RETRY):
try:
max_new_tokens = 512
if "o1" in self.model or "o3" in self.model or "R1" in self.model:
max_new_tokens = self.max_tokens

is_reasoning_model = "o1" in self.model or "o3" in self.model or "R1" in self.model
if is_reasoning_model and self.backend_options.increase_max_tokens_for_reasoning:
max_new_tokens = min(max_new_tokens * 10, 32000)

kwargs = {
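The token-budget rule above can be read in isolation: the base budget (512 or self.max_tokens, per the hunk above) is only multiplied for models whose names suggest a reasoning model (o1, o3, R1), and only while increase_max_tokens_for_reasoning is left enabled. A standalone sketch of that rule (the function name, default base, and model names are illustrative):

def judge_max_tokens(model: str, base: int = 512, increase_for_reasoning: bool = True) -> int:
    # Mirrors the branch above: enlarge the budget for reasoning models, capped at 32k tokens.
    is_reasoning_model = "o1" in model or "o3" in model or "R1" in model
    if is_reasoning_model and increase_for_reasoning:
        return min(base * 10, 32000)
    return base

print(judge_max_tokens("o3-mini"))                                # 5120
print(judge_max_tokens("o3-mini", increase_for_reasoning=False))  # 512
print(judge_max_tokens("gpt-4o"))                                 # 512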
@@ -319,7 +349,7 @@ def __call_api(prompt):
return error_message

results = []
with ThreadPoolExecutor(100) as executor:
with ThreadPoolExecutor(self.backend_options.concurrent_requests) as executor:
for entry in tqdm(executor.map(__call_api, prompts), total=len(prompts)):
results.append(entry)

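Finally, the worker count of the thread pool is no longer hard-coded to 100 but comes from concurrent_requests. A self-contained sketch of the same pattern, with a stand-in for the API call (assuming LitellmBackendOptions is importable as in this diff):

from concurrent.futures import ThreadPoolExecutor

from lighteval.metrics.utils.llm_as_judge import LitellmBackendOptions

opts = LitellmBackendOptions(concurrent_requests=10)
prompts = [f"prompt {i}" for i in range(25)]  # illustrative payloads

def call_api(prompt: str) -> str:
    # Stand-in for __call_api: a real run would call the litellm API here.
    return prompt.upper()

with ThreadPoolExecutor(opts.concurrent_requests) as executor:
    results = list(executor.map(call_api, prompts))

print(len(results))  # 25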