diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py
index 38a0e2f52..6935f9321 100644
--- a/src/lighteval/metrics/metrics_sample.py
+++ b/src/lighteval/metrics/metrics_sample.py
@@ -947,6 +947,7 @@ def __init__(
         url: str | None = None,
         hf_provider: str | None = None,
         max_tokens: int | None = None,
+        backend_options: dict | None = None,
     ) -> None:
         logger.debug(f"Initializing JudgeLLM with backend: {judge_backend}, model: {judge_model_name}")
 
@@ -993,6 +994,7 @@ def __init__(
             url=url,
             hf_provider=hf_provider,
             max_tokens=max_tokens,
+            backend_options=backend_options,
         )
 
     def compute(self, responses: list[ModelResponse], docs: list[Doc], **kwargs) -> list:
diff --git a/src/lighteval/metrics/utils/llm_as_judge.py b/src/lighteval/metrics/utils/llm_as_judge.py
index dcf0a5a88..22da4b3e3 100644
--- a/src/lighteval/metrics/utils/llm_as_judge.py
+++ b/src/lighteval/metrics/utils/llm_as_judge.py
@@ -25,6 +25,7 @@
 import logging
 import time
 from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
 from typing import Callable, Literal, Optional
 
 from huggingface_hub import AsyncInferenceClient, InferenceTimeoutError
@@ -45,28 +46,40 @@
 DEFAULT_FORMAT = {"type": "text"}
 
 
-class JudgeLM:
-    """A class representing a judge for evaluating answers using either the OpenAI or Transformers library.
+@dataclass
+class LitellmBackendOptions:
+    """Options for the LiteLLM judge backend with default values.
 
-    Args:
-        model (str): The name of the model.
-        templates (Callable): A function taking into account the question, options, answer, and gold and returning the judge prompt.
-        process_judge_response (Callable): A function for processing the judge's response.
-        judge_backend (Literal["openai", "transformers", "tgi", "vllm"]): The backend for the judge.
-        url (str | None): The URL for the OpenAI API.
-        api_key (str | None): The API key for the OpenAI API (either OpenAI or HF key).
+    Attributes:
+        caching (bool): Whether to enable caching for the API responses. Defaults to True.
+        concurrent_requests (int): The maximum number of concurrent requests to the API. Defaults to 10.
+        increase_max_tokens_for_reasoning (bool): Whether to increase the max tokens for certain reasoning
+            models. Defaults to True.
+    """
+
+    caching: bool = True
+    concurrent_requests: int = 10
+
+    # Increases max_tokens depending on the model used, see implementation below
+    increase_max_tokens_for_reasoning: bool = True
+
+
+class JudgeLM:
+    """A class representing a judge for evaluating answers using the chosen backend.
 
     Attributes:
         model (str): The name of the model.
-        template (Callable): A function taking into account the question, options, answer, and gold and returning the judge prompt.
-        API_MAX_RETRY (int): The maximum number of retries for the API.
-        API_RETRY_SLEEP (int): The time to sleep between retries.
-        client (OpenAI | None): The OpenAI client.
-        pipe (LLM | AutoModel | None): The Transformers or vllm pipeline.
+        templates (Callable): A function taking into account the question, options, answer, and gold and returning the judge prompt.
         process_judge_response (Callable): A function for processing the judge's response.
+        judge_backend (Literal["litellm", "openai", "transformers", "tgi", "vllm", "inference-providers"]): The backend for the judge.
         url (str | None): The URL for the OpenAI API.
         api_key (str | None): The API key for the OpenAI API (either OpenAI or HF key).
-        backend (Literal["openai", "transformers", "tgi", "vllm"]): The backend for the judge
+        max_tokens (int): The maximum number of tokens to generate. Defaults to 512.
+        response_format (BaseModel | None): The format of the response from the API, used for the OpenAI and TGI backends.
+        hf_provider (Literal["black-forest-labs", "cerebras", "cohere", "fal-ai", "fireworks-ai",
+            "inference-providers", "hyperbolic", "nebius", "novita", "openai", "replicate", "sambanova", "together"] | None):
+            The HuggingFace provider when using the inference-providers backend.
+        backend_options (dict | None): Options for the backend. Currently only supported for litellm.
 
     Methods:
         evaluate_answer: Evaluates an answer using the OpenAI API or Transformers library.
@@ -103,6 +116,7 @@ def __init__(
                 "together",
             ]
         ] = None,
+        backend_options: dict | None = None,
    ):
         self.model = model
         self.template = templates
@@ -122,6 +136,12 @@ def __init__(
 
         self.response_format = response_format if not None else DEFAULT_FORMAT
 
+        self.backend_options = backend_options or {}
+
+        # Override backend options dictionary with the corresponding dataclass to ensure all specified options are valid
+        if judge_backend == "litellm":
+            self.backend_options = LitellmBackendOptions(**self.backend_options)
+
         # Validate that hf_provider is specified when using inference-providers backend
         if self.backend == "inference-providers" and self.hf_provider is None:
             raise ValueError("When using 'inference-providers' as backend, you must specify an 'hf_provider'")
@@ -286,12 +306,22 @@ def __call_vllm(self, prompt):
 
     def __call_litellm(self, prompts):
         import litellm
 
+        if self.backend_options.caching:
+            from litellm.caching.caching import Cache, LiteLLMCacheType
+
+            litellm.cache = Cache(type=LiteLLMCacheType.DISK)
+
+        # Automatically drop parameters that are not supported by the currently used inference API
+        litellm.drop_params = True
+
         def __call_api(prompt):
             error_message = "ERROR: Failed to get response from the API."
             for _ in range(self.API_MAX_RETRY):
                 try:
-                    max_new_tokens = 512
-                    if "o1" in self.model or "o3" in self.model or "R1" in self.model:
+                    max_new_tokens = self.max_tokens
+
+                    is_reasoning_model = "o1" in self.model or "o3" in self.model or "R1" in self.model
+                    if is_reasoning_model and self.backend_options.increase_max_tokens_for_reasoning:
                         max_new_tokens = min(max_new_tokens * 10, 32000)
 
                     kwargs = {
@@ -319,7 +349,7 @@ def __call_api(prompt):
                     return error_message
 
         results = []
-        with ThreadPoolExecutor(100) as executor:
+        with ThreadPoolExecutor(self.backend_options.concurrent_requests) as executor:
             for entry in tqdm(executor.map(__call_api, prompts), total=len(prompts)):
                 results.append(entry)
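A minimal usage sketch of the new option, relying only on what the diff shows: `JudgeLM` accepts a `backend_options` dict, and when `judge_backend` is `"litellm"` the dict is converted into `LitellmBackendOptions`, so misspelled keys fail fast. The judge model name and the prompt/parsing callables below are placeholders, not part of this PR.

```python
from lighteval.metrics.utils.llm_as_judge import JudgeLM, LitellmBackendOptions


def build_judge_prompt(question, options, answer, gold):
    # Placeholder template function: builds the judge prompt from a sample.
    return f"Question: {question}\nAnswer: {answer}\nGold: {gold}\nGrade from 0 to 10:"


def parse_judge_response(response):
    # Placeholder parser: turns the judge's raw text into a score.
    return response


judge = JudgeLM(
    model="gpt-4o-mini",  # placeholder judge model
    templates=build_judge_prompt,
    process_judge_response=parse_judge_response,
    judge_backend="litellm",
    backend_options={
        "caching": False,            # skip LiteLLM's disk cache
        "concurrent_requests": 4,    # smaller ThreadPoolExecutor pool
        "increase_max_tokens_for_reasoning": False,
    },
)

# Options are validated eagerly because they are passed to the dataclass constructor:
# LitellmBackendOptions(cacheing=True)  # TypeError: unexpected keyword argument 'cacheing'
```

Note that the defaults change runtime behavior relative to the old hard-coded values: disk caching is now enabled, and the executor uses 10 workers instead of 100.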