From 7acea0d631ee2d8b88855e0d95a6c648407fd26e Mon Sep 17 00:00:00 2001 From: Francesco Bertolotti Date: Mon, 17 Nov 2025 13:17:06 +0100 Subject: [PATCH 1/6] graceful shutdown of vllm async Co-authored-by: Francesco Bertolotti --- docs/source/evaluating-a-custom-model.mdx | 8 +- pyproject.toml | 5 +- src/lighteval/models/dummy/dummy_model.py | 5 +- .../models/endpoints/endpoint_model.py | 5 +- .../endpoints/inference_providers_model.py | 5 +- .../models/endpoints/litellm_model.py | 5 +- src/lighteval/models/endpoints/tgi_model.py | 4 - .../models/nanotron/nanotron_model.py | 5 +- src/lighteval/models/sglang/sglang_model.py | 5 +- .../models/transformers/transformers_model.py | 8 +- .../transformers/vlm_transformers_model.py | 5 +- src/lighteval/models/vllm/vllm_model.py | 36 +- src/lighteval/pipeline.py | 9 +- src/lighteval/utils/cache_management.py | 490 ++++-------------- tests/unit/utils/test_caching.py | 88 +--- tests/utils.py | 2 - 16 files changed, 167 insertions(+), 518 deletions(-) diff --git a/docs/source/evaluating-a-custom-model.mdx b/docs/source/evaluating-a-custom-model.mdx index df9b3ea17..b35506572 100644 --- a/docs/source/evaluating-a-custom-model.mdx +++ b/docs/source/evaluating-a-custom-model.mdx @@ -16,16 +16,13 @@ Here's a basic example: from lighteval.models.abstract_model import LightevalModel from lighteval.models.model_output import ModelResponse from lighteval.tasks.requests import Doc, SamplingMethod -from lighteval.utils.cache_management import SampleCache, cached +from lighteval.utils.cache_management import cached class MyCustomModel(LightevalModel): def __init__(self, config): super().__init__(config) # Initialize your model here... - # Enable caching (recommended) - self._cache = SampleCache(config) - @cached(SamplingMethod.GENERATIVE) def greedy_until(self, docs: List[Doc]) -> List[ModelResponse]: # Implement generation logic @@ -168,7 +165,7 @@ To enable caching in your custom model: ### Step 1: Import Caching Components ```python -from lighteval.utils.cache_management import SampleCache, cached +from lighteval.utils.cache_management import cached ``` ### Step 2: Initialize Cache in Constructor @@ -176,7 +173,6 @@ from lighteval.utils.cache_management import SampleCache, cached def __init__(self, config): super().__init__(config) # Your initialization code... - self._cache = SampleCache(config) ``` 3. 
Add cache decorators to your prediction methods: diff --git a/pyproject.toml b/pyproject.toml index d1cc2c00c..af903f6c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,8 +65,8 @@ dependencies = [ "GitPython>=3.1.41", # for logging "datasets>=4.0.0", "pydantic", - "numpy>=2", # pinned to avoid incompatibilities - "hf-xet>=1.1.8", # pinned to avoid failing test suite + "numpy>=2", # pinned to avoid incompatibilities + "hf-xet>=1.1.8", # pinned to avoid failing test suite # Prettiness "typer>=0.20.0", "termcolor==2.3.0", @@ -87,6 +87,7 @@ dependencies = [ "httpx>=0.27.2", "latex2sympy2_extended==1.0.6", "langcodes", + "diskcache>=5.6.3", ] [project.optional-dependencies] diff --git a/src/lighteval/models/dummy/dummy_model.py b/src/lighteval/models/dummy/dummy_model.py index e0a13b589..3a0fa148e 100644 --- a/src/lighteval/models/dummy/dummy_model.py +++ b/src/lighteval/models/dummy/dummy_model.py @@ -29,7 +29,7 @@ from lighteval.models.abstract_model import LightevalModel, ModelConfig from lighteval.models.model_output import ModelResponse from lighteval.tasks.requests import Doc, SamplingMethod -from lighteval.utils.cache_management import SampleCache, cached +from lighteval.utils.cache_management import cached class DummyModelConfig(ModelConfig): @@ -70,9 +70,6 @@ def __init__( self._random = random.Random(self.config.seed) self._tokenizer = None - # Initialize cache for tokenization and predictions - self._cache = SampleCache(config) - @property def tokenizer(self): if not self._tokenizer: diff --git a/src/lighteval/models/endpoints/endpoint_model.py b/src/lighteval/models/endpoints/endpoint_model.py index 6b08be575..8371ea8cd 100644 --- a/src/lighteval/models/endpoints/endpoint_model.py +++ b/src/lighteval/models/endpoints/endpoint_model.py @@ -49,7 +49,7 @@ from lighteval.models.model_output import ModelResponse from lighteval.tasks.prompt_manager import PromptManager from lighteval.tasks.requests import Doc, SamplingMethod -from lighteval.utils.cache_management import SampleCache, cached +from lighteval.utils.cache_management import cached logger = logging.getLogger(__name__) @@ -268,9 +268,6 @@ def __init__(self, config: Union[InferenceEndpointModelConfig, ServerlessEndpoin self.generation_parameters = config.generation_parameters self.generation_config = self.generation_parameters.to_tgi_ie_dict() - # Initialize cache for tokenization and predictions - self._cache = SampleCache(config) - def _create_endpoint( # noqa: C901 self, config: InferenceEndpointModelConfig | ServerlessEndpointModelConfig ) -> Tuple[Union[InferenceEndpoint | None], AsyncInferenceClient, InferenceClient]: # noqa: C901 diff --git a/src/lighteval/models/endpoints/inference_providers_model.py b/src/lighteval/models/endpoints/inference_providers_model.py index 54790e45b..980bf07aa 100644 --- a/src/lighteval/models/endpoints/inference_providers_model.py +++ b/src/lighteval/models/endpoints/inference_providers_model.py @@ -36,7 +36,7 @@ from lighteval.models.model_output import ModelResponse from lighteval.tasks.prompt_manager import PromptManager from lighteval.tasks.requests import Doc, SamplingMethod -from lighteval.utils.cache_management import SampleCache, cached +from lighteval.utils.cache_management import cached logger = logging.getLogger(__name__) @@ -134,9 +134,6 @@ def __init__(self, config: InferenceProvidersModelConfig) -> None: use_chat_template=True, tokenizer=self.tokenizer, system_prompt=config.system_prompt ) - # Initialize cache for tokenization and predictions - self._cache = 
SampleCache(config) - async def __call_api(self, prompt: List[dict], num_samples: int) -> Optional[ChatCompletionOutput]: """Make API call with exponential backoff retry logic. diff --git a/src/lighteval/models/endpoints/litellm_model.py b/src/lighteval/models/endpoints/litellm_model.py index 87332d1d7..5430e0de0 100644 --- a/src/lighteval/models/endpoints/litellm_model.py +++ b/src/lighteval/models/endpoints/litellm_model.py @@ -33,7 +33,7 @@ from lighteval.models.model_output import ModelResponse from lighteval.tasks.prompt_manager import PromptManager from lighteval.tasks.requests import Doc, SamplingMethod -from lighteval.utils.cache_management import SampleCache, cached +from lighteval.utils.cache_management import cached from lighteval.utils.imports import is_package_available, requires @@ -162,9 +162,6 @@ def __init__(self, config: LiteLLMModelConfig) -> None: use_chat_template=True, tokenizer=self.tokenizer, system_prompt=config.system_prompt ) - # Initialize cache for tokenization and predictions - self._cache = SampleCache(config) - def _prepare_stop_sequence(self, stop_sequence): """Prepare and validate stop sequence.""" if self.provider == "anthropic": diff --git a/src/lighteval/models/endpoints/tgi_model.py b/src/lighteval/models/endpoints/tgi_model.py index 4fd765b8d..51f7421e3 100644 --- a/src/lighteval/models/endpoints/tgi_model.py +++ b/src/lighteval/models/endpoints/tgi_model.py @@ -31,7 +31,6 @@ from lighteval.models.abstract_model import ModelConfig from lighteval.models.endpoints.endpoint_model import InferenceEndpointModel from lighteval.tasks.prompt_manager import PromptManager -from lighteval.utils.cache_management import SampleCache from lighteval.utils.imports import Extra, is_package_available, requires @@ -130,9 +129,6 @@ def __init__(self, config: TGIModelConfig) -> None: use_chat_template=True, tokenizer=self.tokenizer, system_prompt=config.system_prompt ) - # Initialize cache for tokenization and predictions - self._cache = SampleCache(config) - @requires(Extra.TGI) def _async_process_request( self, diff --git a/src/lighteval/models/nanotron/nanotron_model.py b/src/lighteval/models/nanotron/nanotron_model.py index 7ed6d35eb..3436e00e8 100644 --- a/src/lighteval/models/nanotron/nanotron_model.py +++ b/src/lighteval/models/nanotron/nanotron_model.py @@ -50,7 +50,7 @@ Doc, SamplingMethod, ) -from lighteval.utils.cache_management import SampleCache, cached +from lighteval.utils.cache_management import cached from lighteval.utils.imports import is_package_available from lighteval.utils.parallelism import find_executable_batch_size from lighteval.utils.utils import as_list @@ -304,9 +304,6 @@ def __init__( self.pairwise_tokenization = nanotron_config.lighteval_config.tasks.pairwise_tokenization self.batch_size = nanotron_config.lighteval_config.batch_size - # Initialize cache for tokenization and predictions - self._cache = SampleCache(nanotron_config) - @property def tokenizer(self): return self._tokenizer diff --git a/src/lighteval/models/sglang/sglang_model.py b/src/lighteval/models/sglang/sglang_model.py index e5c0f4d87..ae60691dd 100644 --- a/src/lighteval/models/sglang/sglang_model.py +++ b/src/lighteval/models/sglang/sglang_model.py @@ -34,7 +34,7 @@ from lighteval.models.utils import _simplify_name, uses_chat_template from lighteval.tasks.prompt_manager import PromptManager from lighteval.tasks.requests import Doc, SamplingMethod -from lighteval.utils.cache_management import SampleCache, cached +from lighteval.utils.cache_management import cached from 
lighteval.utils.imports import is_package_available, requires @@ -163,9 +163,6 @@ def __init__( self.pairwise_tokenization = config.pairwise_tokenization self.prompt_manager = PromptManager(self.use_chat_template, self.tokenizer, config.system_prompt) - # Initialize cache for tokenization and predictions - self._cache = SampleCache(config) - @property def tokenizer(self): return self._tokenizer diff --git a/src/lighteval/models/transformers/transformers_model.py b/src/lighteval/models/transformers/transformers_model.py index fc69bc5de..b6a750450 100644 --- a/src/lighteval/models/transformers/transformers_model.py +++ b/src/lighteval/models/transformers/transformers_model.py @@ -53,7 +53,7 @@ from lighteval.models.utils import _get_dtype, _get_model_sha, _simplify_name, uses_chat_template from lighteval.tasks.prompt_manager import PromptManager from lighteval.tasks.requests import Doc, SamplingMethod -from lighteval.utils.cache_management import SampleCache, cached +from lighteval.utils.cache_management import cached from lighteval.utils.imports import ( is_package_available, ) @@ -237,9 +237,6 @@ def __init__( use_chat_template=self.use_chat_template, tokenizer=self.tokenizer, system_prompt=config.system_prompt ) - # Initialize cache for tokenization and predictions - self._cache = SampleCache(config) - def cleanup(self): """Clean up operations if needed, such as closing an endpoint.""" del self.model @@ -301,9 +298,6 @@ def from_model( system_prompt=config.system_prompt if config else None, ) - # Initialize cache for tokenization and predictions - self._cache = SampleCache(config) if config else None - return self @property diff --git a/src/lighteval/models/transformers/vlm_transformers_model.py b/src/lighteval/models/transformers/vlm_transformers_model.py index 0697ab729..92690da33 100644 --- a/src/lighteval/models/transformers/vlm_transformers_model.py +++ b/src/lighteval/models/transformers/vlm_transformers_model.py @@ -45,7 +45,7 @@ from lighteval.models.utils import _get_dtype, _get_model_sha, _simplify_name from lighteval.tasks.prompt_manager import PromptManager from lighteval.tasks.requests import Doc, SamplingMethod -from lighteval.utils.cache_management import SampleCache, cached +from lighteval.utils.cache_management import cached from lighteval.utils.imports import ( is_package_available, ) @@ -177,9 +177,6 @@ def __init__( use_chat_template=True, tokenizer=self.tokenizer, system_prompt=config.system_prompt ) - # Initialize cache for tokenization and predictions - self._cache = SampleCache(config) - @property def tokenizer(self): return self.processor.tokenizer diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py index 969caf8fa..630891f4c 100644 --- a/src/lighteval/models/vllm/vllm_model.py +++ b/src/lighteval/models/vllm/vllm_model.py @@ -27,6 +27,7 @@ import os from typing import Coroutine, Optional +import rich import torch from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt from tqdm import tqdm @@ -37,7 +38,7 @@ from lighteval.models.utils import _simplify_name, uses_chat_template from lighteval.tasks.prompt_manager import PromptManager from lighteval.tasks.requests import Doc, SamplingMethod -from lighteval.utils.cache_management import SampleCache, cached +from lighteval.utils.cache_management import cached from lighteval.utils.imports import is_package_available, requires @@ -216,9 +217,6 @@ def __init__( self.prompt_manager = PromptManager(self.use_chat_template, self.tokenizer, config.system_prompt) - # 
Initialize cache for tokenization and predictions - self._cache = SampleCache(config) - @property def tokenizer(self): return self._tokenizer @@ -544,6 +542,7 @@ class AsyncVLLMModel(VLLMModel): is_async = True def cleanup(self): + self.model.shutdown() gc.collect() destroy_distributed_environment() torch.cuda.empty_cache() @@ -621,11 +620,30 @@ async def _async_one_item( return output async def _async_batch(self, docs: list[Doc], generative: bool) -> list: - processed_requests = [ - self._async_one_item(index=index, doc=doc, generative=generative) for index, doc in enumerate(docs) - ] - results = await asyncio.gather(*processed_requests) - return results + with rich.progress.Progress( + "[progress.description]{task.description}", + rich.progress.BarColumn(), + "[progress.completed]{task.completed}/{task.total}", + "•", + rich.progress.TimeElapsedColumn(), + "•", + rich.progress.TimeRemainingColumn(), + ) as pbar: + task_id = pbar.add_task("[green]Sending Requests...", total=len(docs)) + + async def track(coro): + """Wraps a coroutine to update progress bar when done.""" + result = await coro + pbar.update(task_id, advance=1) + return result + + wrapped = [ + track(self._async_one_item(index=index, doc=doc, generative=generative)) + for index, doc in enumerate(docs) + ] + + result = await asyncio.gather(*wrapped) + return result @cached(SamplingMethod.GENERATIVE) async def greedy_until( diff --git a/src/lighteval/pipeline.py b/src/lighteval/pipeline.py index 1f5da9c14..2edf29b04 100644 --- a/src/lighteval/pipeline.py +++ b/src/lighteval/pipeline.py @@ -144,8 +144,6 @@ def __init__( self.model_config = model_config self.accelerator, self.parallel_context = self._init_parallelism_manager() self.model = self._init_model(model_config, model) - # Must occur after model and task init - self.model._cache._init_registry(self.registry) # Must occur after model init self._init_accelerator_seeds() @@ -308,6 +306,8 @@ async def _run_model_async(self): model_outputs = await self.model.loglikelihood(docs) outputs[sampling_method] = model_outputs + self.model.cleanup() + return outputs def _run_model_sync(self): @@ -327,6 +327,8 @@ def _run_model_sync(self): model_outputs = self.model.loglikelihood_rolling(docs) outputs[sampling_method] = model_outputs + self.model.cleanup() + return outputs def _run_model(self): @@ -339,9 +341,6 @@ def _run_model(self): else: outputs = self._run_model_sync() - # Cleaning up the model before running metrics - self.model.cleanup() - return outputs def _post_process_outputs(self, sampling_method_responses: dict[str, list[ModelResponse]]): diff --git a/src/lighteval/utils/cache_management.py b/src/lighteval/utils/cache_management.py index 962f8b083..65d318383 100644 --- a/src/lighteval/utils/cache_management.py +++ b/src/lighteval/utils/cache_management.py @@ -20,348 +20,41 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
+import asyncio +import dataclasses import functools import hashlib import json import logging -import os -from dataclasses import asdict, dataclass -from pathlib import Path -from typing import Callable, List, Set, Tuple, Union - -import pandas as pd -from datasets import Dataset, load_dataset - -from lighteval.models.abstract_model import ModelConfig -from lighteval.models.model_output import ModelResponse -from lighteval.tasks.lighteval_task import LightevalTaskConfig -from lighteval.tasks.registry import Registry -from lighteval.tasks.requests import Doc, SamplingMethod -from lighteval.utils.utils import as_list - - -logger = logging.getLogger(__name__) - - -@dataclass -class TaskID: - """A unique ID for a grouping of task samples. It relies on the task name, - the task config (which gives the task_hash), and the sampling method (linked to - the metric type) - """ - - task_name: str - task_hash: str - sampling_method: SamplingMethod - - def __str__(self): - return f"{self.task_name} ({self.task_hash}, {self.sampling_method.name})" - - def __hash__(self): - return int.from_bytes(hashlib.sha256(str(self).encode()).digest(), byteorder="big") +from dataclasses import asdict +from typing import Callable, List +import diskcache -class SampleCache: - """Disk-based cache for sample evaluation results using HuggingFace datasets. - The model hash is a hash of the model config, to make sure we rerun the eval if any parameter changes - (generation param, model version, etc). +from lighteval.tasks.requests import Doc, SamplingMethod - Cache Structure: - - {cache_dir}/ - - {model_name}/ - - {model_hash}/ - - {task_name}/ - - {task_hash}/ dataset dict, where splits are SamplingMethod - """ - def __init__(self, model_config: ModelConfig): - """Initialize the sample cache. - - Args: - model_config: Configuration for the model being cached - """ - self.model_config = model_config - self.model_hash = self.get_model_hash(model_config) - - self.cache_dir = ( - Path(os.path.expanduser(self.model_config.cache_dir)) / self.model_config.model_name / self.model_hash - ) - self.cache_dir.mkdir(parents=True, exist_ok=True) - - self.registry = None - - self.existing_indices = self._load_cached_indices() - # Caching the task_hashes to avoid grabbing the registry all the time - self._task_hashes = {} - - def _init_registry(self, registry: Registry): - self.registry = registry - - def _load_cached_indices(self) -> dict: - """Loads all indices for samples which are properly cached. We recursively search for all available tasks and files. 
- - Returns: - dict: Dictionary mapping task names to lists of cached sample indices - """ - logger.info("[CACHING] Initializing data cache") - cached_indices = {} - cache_dir = self.cache_dir - - if not cache_dir.exists(): - return cached_indices - - for cache_file in cache_dir.rglob("*.parquet"): - try: - # cache_file.parts gives all the subfolders of the url, up to the file name - # last 3 are task_name/task_hash/file_name.parquet, so we take -3 and -2 - task_name, task_hash = cache_file.parts[-3:-1] - sampling_method = SamplingMethod[cache_file.stem] # removes the file extension - task_id = TaskID(task_name, task_hash, sampling_method) - - full_dataset = load_dataset("parquet", data_files=str(cache_file), split="train") - sample_ids = [] - for row in full_dataset: - try: - # We only save indices of correctly formatted samples, though this means we need to load each at least once - self._load_sample(row) - cur_sample = row["sample_id"] - sample_ids.append(cur_sample) - except Exception: - continue - - cached_indices[task_id] = sample_ids - logger.info( - f"[CACHING] Loaded {len(sample_ids)} cached indices for task '{str(task_id)} from {cache_file}" - ) - except Exception as e: - logger.warning(f"Error loading cached indices from {cache_file}: {e}") - - return cached_indices - - def get_model_hash(self, model_config: ModelConfig) -> str: - """Create a hash for model configuration. - - Returns: - str: A 16-character hexadecimal hash of the model configuration - """ - # Use Pydantic's model_dump instead of asdict for BaseModel - config_dict = model_config.model_dump() - config_str = json.dumps(config_dict, sort_keys=True, default=str) - return hashlib.sha256(config_str.encode()).hexdigest()[:16] - - def _get_task_hash(self, full_task_name: str) -> str: - """Builds a task_hash from the LightevalTaskConfig loaded from the task name and the registry. - - Args: - full_task_name (str): task name as provided to the registry (supports both task|few_shot and legacy suite|task|few_shot) - - Returns: - str: a hash of the task config in its current state in the registry, or the NO_HASH string if the - registry has not been preloaded - """ - if self.registry is None: - logger.warning( - "The task registry was not provided to the cache config. We can't test if the current task has the same hash as the saved tasks." - ) - return "NO_HASH" - if full_task_name not in self._task_hashes: - parts = full_task_name.split("|") - if len(parts) == 3: - # Legacy: suite|task|few_shot -> ignore suite - _, task_name, _ = parts - elif len(parts) == 2: - task_name, _ = parts - else: - task_name = parts[0] - - task_configs: list[LightevalTaskConfig] = self.registry.task_to_configs[task_name] - # Use deterministic ordering based on string repr - config_strs = sorted([cfg.__str__(lite=True) for cfg in task_configs]) - config_str = "|".join(config_strs) - task_hash = hashlib.sha256(config_str.encode()).hexdigest()[:16] - self._task_hashes[full_task_name] = task_hash - return self._task_hashes[full_task_name] - - def get_cache_path(self, task_id: TaskID) -> Path: - """Get the file path for a specific task's cache file. - - Args: - task_id: TaskID of the task - - Returns: - Path: Path to the cache file for the given task and sample type - """ - return self.cache_dir / task_id.task_name / task_id.task_hash / f"{task_id.sampling_method.name}.parquet" - - def get_task_id(self, task_name: str, sampling_method: SamplingMethod) -> TaskID: - """Returns a unique task indentifier. 
Depends on the task name, - task version and parameters (from which a hash is derived), and - current sampling method (current metric we look at). - - Args: - task_name (str): Name of the task - sampling_method (SamplingMethod): Sampling used for the current metric - - Returns: - TaskID: A unique identifier for the task - """ - task_hash = self._get_task_hash(task_name) - return TaskID(task_name, task_hash, sampling_method) - - def get_sampling_method(self, sample: dict) -> str: - if len(sample.get("logprobs", [])) > 0: - return SamplingMethod.LOGPROBS - if len(sample.get("text", [])) > 0: - return SamplingMethod.GENERATIVE - return None - - def _load_sample(self, sample: pd.core.series.Series | dict) -> Union[dict, ModelResponse]: - """Load a sample from cached data based on sample type. - - Args: - sample: Raw sample data from cache, arrives as a dataframe row - - Returns: - Union[dict, ModelResponse]: Loaded sample in appropriate format for processing - """ - # If we just use the pandas dict, lists are converted to np arrays which we don't want - if isinstance(sample, pd.core.series.Series): - sample = json.loads(sample.to_json()) - return ModelResponse(**sample["sample"]) - - def _dump_sample(self, result: Union[dict, ModelResponse]) -> dict: - """Dumps the sample in the correct format for file saving - - Args: - result (Union[dict, ModelResponse]): Processed sample to save - - Returns: - dict - """ - return asdict(result) - - def get_samples_to_process_and_cache( - self, docs: List[Doc], sampling_method: SamplingMethod - ) -> Tuple[List[Doc], Set[TaskID]]: - """ - Identify which docs need processing because they are not cached yet, based on cached doc and task indices. - - Returns: - Tuple of (docs_not_cached, tasks_with_cached_samples) where - - docs_not_cached contains docs that need processing - - tasks_with_cached_samples are the tasks that have some cached samples - """ - cached_indices = self.existing_indices - - docs_not_cached = [] - tasks_with_cached_samples = set() - - for doc in docs: - task_id = self.get_task_id(doc.task_name, sampling_method) - try: - if doc.id in cached_indices[task_id]: - tasks_with_cached_samples.add(task_id) - else: - docs_not_cached.append(doc) - except KeyError: # task id or sampling method not yet there - docs_not_cached.append(doc) - - return docs_not_cached, set(tasks_with_cached_samples) - - def get_samples_from_cache( - self, docs: List[Doc], task_ids: List[TaskID] | set[TaskID], sampling_method: SamplingMethod - ) -> List[dict | ModelResponse]: - """Get cached samples for the given docs. - Warning: Assumes all docs and task_names provided are stored in cache, will fail otherwise. 
- - Returns: - List of cached items - """ - # Load datasets for tasks that have cached docs - task_datasets = {} - - for task_id in task_ids: - if task_id.sampling_method != sampling_method: - continue - cache_file = self.get_cache_path(task_id) - try: - dataset = load_dataset("parquet", data_files=str(cache_file), split="train") - dataset_df = dataset.to_pandas().set_index("sample_id") - task_datasets[task_id] = dataset_df - except Exception as e: - logger.warning(f"Error loading prediction cache for {str(task_id)}: {e}") - - # Build results list - results = [] - - for doc in docs: - task_id = self.get_task_id(doc.task_name, sampling_method) - row = task_datasets[task_id].loc[doc.id] - results.append(self._load_sample(row)) - - return results - - def cache_samples( # noqa C901 - self, - docs: List[Doc], - results: List[dict] | List[ModelResponse], - task_ids: list[TaskID], - sampling_method: SamplingMethod, - ): - """Store new results for samples in docs""" - if not results: - return - - # Prepare newly processed data for dataset - processed_data = {task_id: [] for task_id in task_ids} - for doc, result in zip(docs, results): - task_id = self.get_task_id(doc.task_name, sampling_method) - sample = self._dump_sample(result) - - processed_data[task_id].append({"sample_id": doc.id, "sample": sample}) - processed_data = {task_id: task_data for task_id, task_data in processed_data.items() if task_data} - - # Concatenate it with existing data and save to file - for task_id, task_data in processed_data.items(): - if task_id not in self.existing_indices.keys(): - self.existing_indices[task_id] = {} - - cache_file = self.get_cache_path(task_id) - - # Load existing data if present - existing_data = [] - existing_samples = {} - if cache_file.exists(): - try: - existing_dataset = load_dataset("parquet", data_files=str(cache_file), split="train") - existing_data = existing_dataset.to_list() - except KeyError: - logger.info(f"No data was cached for {str(task_id)}") - except Exception as e: - logger.error(f"Error loading existing prediction cache for {str(task_id)}: {e}") - - existing_samples = {(row["sample_id"], sampling_method) for row in existing_data} - if any((row["sample_id"], sampling_method) in existing_samples for row in task_data): - logger.warning( - "Unexpected behavior: You have reprocessed already cached items - we will ignore the new version." 
- ) +logger = logging.getLogger(__name__) - # Merge with new data (new data overwrites existing) - # We look at id + sampling method - new_data = [row for row in task_data if (row["sample_id"], sampling_method) not in existing_samples] - all_samples = existing_data + new_data - # Save updated dataset - dataset = Dataset.from_list(all_samples) - dataset.to_parquet(str(cache_file)) +def default_json_encoder(obj): + """returns a string representation for objects not serializable by default json code""" + if dataclasses.is_dataclass(obj): # is dataclass instance + return dataclasses.asdict(obj) + elif hasattr(obj, "model_dump"): # is pydantic BaseModel + return obj.model_dump() + else: + raise TypeError(f"Object of type {type(obj)} is not JSON serializable") - logger.info(f"Cached {len(all_samples)} samples of {str(task_id)} at {str(cache_file)}.") - # Refresh cached indices after storing new samples - self.existing_indices[task_id] = [sample["sample_id"] for sample in all_samples] +def hash_request(doc: Doc, **kwargs) -> str: + """Create a hash for a request based on the doc and additional parameters.""" + return hashlib.sha256( + json.dumps({"doc": doc, "kwargs": kwargs}, sort_keys=True, default=default_json_encoder).encode() + ).hexdigest() -def cached(sampling_method: SamplingMethod = None): # noqa C901 +def cached(sampling_method: None | SamplingMethod = None): # noqa: C901 """ Decorator to cache method results based on Doc inputs. @@ -377,64 +70,87 @@ def greedy_until(self, docs: List[Doc], ...): Callable: A decorator function that wraps the original function with caching functionality """ - def decorator(func: Callable): # noqa C901 - @functools.wraps(func) - def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901 - docs = as_list(docs) - - # Check if caching is enabled for the model - if not hasattr(self, "_cache") or self._cache is None: - return func(self, docs, *args, **kwargs) - - cache: SampleCache = self._cache - - # Extract task names - task_ids = {cache.get_task_id(doc.task_name, sampling_method) for doc in docs} - - # 1) Identify which samples must be processed because they are not cached - docs_not_cached: List[Doc] - tasks_with_cached_samples: Set[TaskID] - docs_not_cached, tasks_with_cached_samples = cache.get_samples_to_process_and_cache(docs, sampling_method) - - # Log cache statistics - cached_count = len(docs) - len(docs_not_cached) - if cached_count > 0: - logger.info( - f"Cache: {cached_count}/{len(docs)} samples are cached for tasks {', '.join(t_id.task_name for t_id in tasks_with_cached_samples)}" - ) - - # 2) Process not cached docs and save to file - new_results = [] - if docs_not_cached: - tasks_needing_sample_processing = { - cache.get_task_id(doc.task_name, sampling_method) for doc in docs_not_cached - } - logger.info( - f"Cache: Starting to process {len(docs_not_cached)}/{len(docs)} samples (not found in cache) for tasks {','.join(str(t) for t in tasks_needing_sample_processing)}" - ) - new_results = func(self, docs_not_cached, *args, **kwargs) - - # Store new results in file cache - cache.cache_samples( - docs=docs_not_cached, - results=new_results, - task_ids=task_ids, - sampling_method=sampling_method, - ) - - # 3) Create final results by pulling from newly saved file cache - final_cached_results = cache.get_samples_from_cache(docs, task_ids, sampling_method) - - # 4) We only keep samples with the correct sampling method - final_results = [ - s for s in final_cached_results if cache.get_sampling_method(cache._dump_sample(s)) == 
sampling_method - ] - - if any(r is None for r in final_results): - raise ValueError("Problem while loading and aggregating items from cache.") - - return final_results - - return wrapper + def decorator(sampler: Callable): # noqa: C901 + @functools.wraps(sampler) + def sync_wrapper(self, docs: Doc | List[Doc], *args, **kwargs): + if isinstance(docs, Doc): + docs = [docs] + + results = [None] * len(docs) + with diskcache.Cache(self.config.cache_dir) as cache: + uncached_docs = [] + uncached_idxs = [] + + for idx, doc in enumerate(docs): + key = hash_request( + doc, sampling_method=sampling_method, config=self.config, args=args, kwargs=kwargs + ) + if key in cache: + logger.info("Cache hit") + results[idx] = cache[key]["response"] + else: + logger.info("Cache miss") + uncached_docs.append(doc) + uncached_idxs.append(idx) + + uncached_responses = sampler(self, uncached_docs, *args, **kwargs) + + for idx, doc, res in zip(uncached_idxs, uncached_docs, uncached_responses): + key = hash_request( + doc, sampling_method=sampling_method, config=self.config, args=args, kwargs=kwargs + ) + cache[key] = { + "response": res, + "config": self.config.dict(), + "doc": asdict(doc), + "sampling_method": sampling_method, + "args": args, + "kwargs": kwargs, + } + results[idx] = res + + return results + + @functools.wraps(sampler) + async def async_wrapper(self, docs: Doc | List[Doc], *args, **kwargs): + if isinstance(docs, Doc): + docs = [docs] + + results = [None] * len(docs) + with diskcache.Cache(self.config.cache_dir) as cache: + uncached_docs = [] + uncached_idxs = [] + + for idx, doc in enumerate(docs): + key = hash_request( + doc, sampling_method=sampling_method, config=self.config, args=args, kwargs=kwargs + ) + if key in cache: + logger.info("Cache hit") + results[idx] = cache[key]["response"] + else: + logger.info("Cache miss") + uncached_docs.append(doc) + uncached_idxs.append(idx) + + uncached_responses = await sampler(self, uncached_docs, *args, **kwargs) + + for idx, doc, res in zip(uncached_idxs, uncached_docs, uncached_responses): + key = hash_request( + doc, sampling_method=sampling_method, config=self.config, args=args, kwargs=kwargs + ) + cache[key] = { + "response": res, + "config": self.config.dict(), + "doc": asdict(doc), + "sampling_method": sampling_method, + "args": args, + "kwargs": kwargs, + } + results[idx] = res + + return results + + return sync_wrapper if not asyncio.iscoroutinefunction(sampler) else async_wrapper return decorator diff --git a/tests/unit/utils/test_caching.py b/tests/unit/utils/test_caching.py index 7ab8644be..36d0069db 100644 --- a/tests/unit/utils/test_caching.py +++ b/tests/unit/utils/test_caching.py @@ -25,13 +25,14 @@ from dataclasses import asdict from unittest.mock import Mock, patch +import diskcache import pytest import torch from lighteval.models.abstract_model import LightevalModel from lighteval.models.model_output import ModelResponse from lighteval.tasks.requests import Doc, SamplingMethod -from lighteval.utils.cache_management import SampleCache +from lighteval.utils.cache_management import hash_request from lighteval.utils.imports import Extra, is_package_available @@ -61,46 +62,6 @@ def setUp(self): self.docs.append(doc) self.model_responses.append(model_resp) - def test_cache_directory_structure(self): - """Test that cache directories are created correctly.""" - from lighteval.models.dummy.dummy_model import DummyModelConfig - from lighteval.models.endpoints.endpoint_model import InferenceEndpointModelConfig - from 
lighteval.models.endpoints.tgi_model import TGIModelConfig - from lighteval.models.sglang.sglang_model import SGLangModelConfig - from lighteval.models.transformers.transformers_model import TransformersModelConfig - from lighteval.models.transformers.vlm_transformers_model import VLMTransformersModelConfig - from lighteval.models.vllm.vllm_model import VLLMModelConfig - - # We skip AdapterModelConfig, DeltaModelConfig because of imports - # We skip FullNanotronConfig as it's not standardized with our other configs, will need to be homogeneized - model_configs = [ - TransformersModelConfig, - VLMTransformersModelConfig, - VLLMModelConfig, - InferenceEndpointModelConfig, - TGIModelConfig, - SGLangModelConfig, - DummyModelConfig, - ] - - for model_config in model_configs: - with self.subTest(model_config=model_config): - with tempfile.TemporaryDirectory() as temp_dir: - model_name = f"test_model_{model_config.__name__}" - # if model_config in [AdapterModelConfig, DeltaModelConfig]: - # config = model_config(model_name=model_name, base_model=model_name + "2", cache_dir=temp_dir) - # else: - config = model_config(model_name=model_name, cache_dir=temp_dir) - - # Create cache with custom directory - cache = SampleCache(config) - - # Check directory structure - folder = cache.cache_dir - self.assertTrue(folder.exists()) - self.assertIn(str(temp_dir), str(folder)) - self.assertIn(model_name, str(folder)) - def test_cache_decorator_presence(self): """Test that @cached decorators are present on the right methods.""" from lighteval.models.dummy.dummy_model import DummyModel @@ -148,33 +109,24 @@ def _test_cache(self, model: LightevalModel, test_cases): process_inputs = getattr(model, function_name) process_inputs(self.docs) - cache: SampleCache = model._cache - - # Check task_id - task_id = cache.get_task_id(self.task_name, sampling_method) - self.assertEqual(task_id.task_name, self.task_name) - self.assertEqual(task_id.sampling_method, sampling_method) - - # Verify cache files were created - cache_file = cache.get_cache_path(task_id) - self.assertTrue(cache_file.exists(), "Cache file not created") - - # Test retrieving from cache - self.assertEqual(cache._load_cached_indices()[task_id], [doc.id for doc in self.docs]) - uncached_docs, tasks_with_cached_samples = cache.get_samples_to_process_and_cache( - docs=self.docs, sampling_method=sampling_method - ) - self.assertEqual(tasks_with_cached_samples, {task_id}) - self.assertEqual( - len(uncached_docs), 0, f"{len(uncached_docs)} documents not found in cache when it should be 0" - ) - - # Verify cached results match original - cached_responses = cache.get_samples_from_cache( - docs=self.docs, task_ids=[task_id], sampling_method=sampling_method - ) - for cached_response, response in zip(cached_responses, self.model_responses): - self.assertEqual(asdict(cached_response), asdict(response)) + with diskcache.Cache(model.config.cache_dir) as cache: + for doc in self.docs: + key = hash_request( + doc, sampling_method=sampling_method, config=model.config, args=[], kwargs={} + ) + self.assertIn(key, cache, f"Document {doc.id} not found in cache after processing") + self.assertEqual( + cache[key]["doc"], asdict(doc), f"Cached doc does not match original for {doc.id}" + ) + self.assertEqual( + cache[key]["sampling_method"], sampling_method, f"Sampling method mismatch for {doc.id}" + ) + self.assertEqual(cache[key]["config"], model.config.dict(), f"Config mismatch for {doc.id}") + self.assertEqual( + asdict(cache[key]["response"]), + 
asdict(self.model_responses[self.docs.index(doc)]), + f"Response mismatch for {doc.id}", + ) @patch("lighteval.models.transformers.transformers_model.TransformersModel._loglikelihood_tokens") @patch("lighteval.models.transformers.transformers_model.TransformersModel._padded_greedy_until") diff --git a/tests/utils.py b/tests/utils.py index b8c71c76c..9ff3461d7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -32,7 +32,6 @@ from lighteval.tasks.lighteval_task import LightevalTask from lighteval.tasks.registry import Registry from lighteval.tasks.requests import Doc -from lighteval.utils.cache_management import SampleCache from lighteval.utils.imports import is_package_available @@ -54,7 +53,6 @@ def __init__( self.greedy_until_responses = greedy_until_responses self.loglikelihood_responses = loglikelihood_responses self.loglikelihood_rolling_responses = loglikelihood_rolling_responses - self._cache = SampleCache(self.config) @property def tokenizer(self): From c3125a0f118c1a3cd6f3501eafd09ff84f52235c Mon Sep 17 00:00:00 2001 From: Francesco Bertolotti Date: Thu, 20 Nov 2025 14:35:56 +0100 Subject: [PATCH 2/6] added model_dump() to EnhancedJSONEncoder --- src/lighteval/logging/evaluation_tracker.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lighteval/logging/evaluation_tracker.py b/src/lighteval/logging/evaluation_tracker.py index 976b21c86..399d080d5 100644 --- a/src/lighteval/logging/evaluation_tracker.py +++ b/src/lighteval/logging/evaluation_tracker.py @@ -72,6 +72,8 @@ def default(self, o): # noqa : C901 return o.__dict__ except Exception: return str(o) + if hasattr(o, "model_dump"): # is pydantic BaseModel + return o.model_dump() if callable(o): if hasattr(o, "__name__"): return o.__name__ From ef980821974324ffbac17d6b75569154298bd4fe Mon Sep 17 00:00:00 2001 From: Francesco Bertolotti Date: Thu, 20 Nov 2025 14:37:23 +0100 Subject: [PATCH 3/6] using EnhancedJSONEncoder instead of custom encoder --- src/lighteval/utils/cache_management.py | 45 ++++++++++--------------- 1 file changed, 18 insertions(+), 27 deletions(-) diff --git a/src/lighteval/utils/cache_management.py b/src/lighteval/utils/cache_management.py index 65d318383..5a1285e6a 100644 --- a/src/lighteval/utils/cache_management.py +++ b/src/lighteval/utils/cache_management.py @@ -21,7 +21,6 @@ # SOFTWARE. 
import asyncio -import dataclasses import functools import hashlib import json @@ -31,26 +30,17 @@ import diskcache +from lighteval.logging.evaluation_tracker import EnhancedJSONEncoder from lighteval.tasks.requests import Doc, SamplingMethod logger = logging.getLogger(__name__) -def default_json_encoder(obj): - """returns a string representation for objects not serializable by default json code""" - if dataclasses.is_dataclass(obj): # is dataclass instance - return dataclasses.asdict(obj) - elif hasattr(obj, "model_dump"): # is pydantic BaseModel - return obj.model_dump() - else: - raise TypeError(f"Object of type {type(obj)} is not JSON serializable") - - def hash_request(doc: Doc, **kwargs) -> str: """Create a hash for a request based on the doc and additional parameters.""" return hashlib.sha256( - json.dumps({"doc": doc, "kwargs": kwargs}, sort_keys=True, default=default_json_encoder).encode() + json.dumps({"doc": doc, "kwargs": kwargs}, sort_keys=True, cls=EnhancedJSONEncoder).encode() ).hexdigest() @@ -70,30 +60,29 @@ def greedy_until(self, docs: List[Doc], ...): Callable: A decorator function that wraps the original function with caching functionality """ - def decorator(sampler: Callable): # noqa: C901 - @functools.wraps(sampler) - def sync_wrapper(self, docs: Doc | List[Doc], *args, **kwargs): - if isinstance(docs, Doc): - docs = [docs] - + def decorator(model_call: Callable): # noqa: C901 + @functools.wraps(model_call) + def sync_wrapper(self, docs: List[Doc], *args, **kwargs): results = [None] * len(docs) with diskcache.Cache(self.config.cache_dir) as cache: uncached_docs = [] uncached_idxs = [] + cache_hits, cache_misses = 0, 0 for idx, doc in enumerate(docs): key = hash_request( doc, sampling_method=sampling_method, config=self.config, args=args, kwargs=kwargs ) if key in cache: - logger.info("Cache hit") + cache_hits += 1 results[idx] = cache[key]["response"] else: - logger.info("Cache miss") + cache_misses += 1 uncached_docs.append(doc) uncached_idxs.append(idx) - uncached_responses = sampler(self, uncached_docs, *args, **kwargs) + logger.info(f"Cache hits: {cache_hits}, Cache misses: {cache_misses}") + uncached_responses = model_call(self, uncached_docs, *args, **kwargs) for idx, doc, res in zip(uncached_idxs, uncached_docs, uncached_responses): key = hash_request( @@ -111,8 +100,8 @@ def sync_wrapper(self, docs: Doc | List[Doc], *args, **kwargs): return results - @functools.wraps(sampler) - async def async_wrapper(self, docs: Doc | List[Doc], *args, **kwargs): + @functools.wraps(model_call) + async def async_wrapper(self, docs: List[Doc], *args, **kwargs): if isinstance(docs, Doc): docs = [docs] @@ -120,20 +109,22 @@ async def async_wrapper(self, docs: Doc | List[Doc], *args, **kwargs): with diskcache.Cache(self.config.cache_dir) as cache: uncached_docs = [] uncached_idxs = [] + cache_hits, cache_misses = 0, 0 for idx, doc in enumerate(docs): key = hash_request( doc, sampling_method=sampling_method, config=self.config, args=args, kwargs=kwargs ) if key in cache: - logger.info("Cache hit") + cache_hits += 1 results[idx] = cache[key]["response"] else: - logger.info("Cache miss") + cache_misses += 1 uncached_docs.append(doc) uncached_idxs.append(idx) - uncached_responses = await sampler(self, uncached_docs, *args, **kwargs) + logger.info(f"Cache hits: {cache_hits}, Cache misses: {cache_misses}") + uncached_responses = await model_call(self, uncached_docs, *args, **kwargs) for idx, doc, res in zip(uncached_idxs, uncached_docs, uncached_responses): key = hash_request( @@ 
-151,6 +142,6 @@ async def async_wrapper(self, docs: Doc | List[Doc], *args, **kwargs): return results - return sync_wrapper if not asyncio.iscoroutinefunction(sampler) else async_wrapper + return sync_wrapper if not asyncio.iscoroutinefunction(model_call) else async_wrapper return decorator From e73adee37e289d6e4c2e5ce224cf57ef93a5c9f1 Mon Sep 17 00:00:00 2001 From: Francesco Bertolotti Date: Thu, 20 Nov 2025 14:44:59 +0100 Subject: [PATCH 4/6] reverted progress bar in async vllm --- src/lighteval/models/vllm/vllm_model.py | 31 ++++--------------------- 1 file changed, 5 insertions(+), 26 deletions(-) diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py index 630891f4c..82744c741 100644 --- a/src/lighteval/models/vllm/vllm_model.py +++ b/src/lighteval/models/vllm/vllm_model.py @@ -27,7 +27,6 @@ import os from typing import Coroutine, Optional -import rich import torch from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt from tqdm import tqdm @@ -542,7 +541,6 @@ class AsyncVLLMModel(VLLMModel): is_async = True def cleanup(self): - self.model.shutdown() gc.collect() destroy_distributed_environment() torch.cuda.empty_cache() @@ -620,30 +618,11 @@ async def _async_one_item( return output async def _async_batch(self, docs: list[Doc], generative: bool) -> list: - with rich.progress.Progress( - "[progress.description]{task.description}", - rich.progress.BarColumn(), - "[progress.completed]{task.completed}/{task.total}", - "•", - rich.progress.TimeElapsedColumn(), - "•", - rich.progress.TimeRemainingColumn(), - ) as pbar: - task_id = pbar.add_task("[green]Sending Requests...", total=len(docs)) - - async def track(coro): - """Wraps a coroutine to update progress bar when done.""" - result = await coro - pbar.update(task_id, advance=1) - return result - - wrapped = [ - track(self._async_one_item(index=index, doc=doc, generative=generative)) - for index, doc in enumerate(docs) - ] - - result = await asyncio.gather(*wrapped) - return result + processed_requests = [ + self._async_one_item(index=index, doc=doc, generative=generative) for index, doc in enumerate(docs) + ] + results = await asyncio.gather(*processed_requests) + return results @cached(SamplingMethod.GENERATIVE) async def greedy_until( From a42babfa2589b3a86d61feeb8f6e0831b3b5d7f9 Mon Sep 17 00:00:00 2001 From: Francesco Bertolotti Date: Thu, 20 Nov 2025 17:00:05 +0100 Subject: [PATCH 5/6] forgot removing two lines --- src/lighteval/utils/cache_management.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/lighteval/utils/cache_management.py b/src/lighteval/utils/cache_management.py index 5a1285e6a..911c5d005 100644 --- a/src/lighteval/utils/cache_management.py +++ b/src/lighteval/utils/cache_management.py @@ -102,9 +102,6 @@ def sync_wrapper(self, docs: List[Doc], *args, **kwargs): @functools.wraps(model_call) async def async_wrapper(self, docs: List[Doc], *args, **kwargs): - if isinstance(docs, Doc): - docs = [docs] - results = [None] * len(docs) with diskcache.Cache(self.config.cache_dir) as cache: uncached_docs = [] From 8639936c410d5a251d267b36a1084ce397814bb5 Mon Sep 17 00:00:00 2001 From: Francesco Bertolotti Date: Thu, 20 Nov 2025 17:44:56 +0100 Subject: [PATCH 6/6] reverted list checking with as_list --- src/lighteval/utils/cache_management.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lighteval/utils/cache_management.py b/src/lighteval/utils/cache_management.py index 911c5d005..e4ff370f4 100644 --- a/src/lighteval/utils/cache_management.py +++ 
b/src/lighteval/utils/cache_management.py @@ -32,6 +32,7 @@ from lighteval.logging.evaluation_tracker import EnhancedJSONEncoder from lighteval.tasks.requests import Doc, SamplingMethod +from lighteval.utils.utils import as_list logger = logging.getLogger(__name__) @@ -63,6 +64,7 @@ def greedy_until(self, docs: List[Doc], ...): def decorator(model_call: Callable): # noqa: C901 @functools.wraps(model_call) def sync_wrapper(self, docs: List[Doc], *args, **kwargs): + docs = as_list(docs) results = [None] * len(docs) with diskcache.Cache(self.config.cache_dir) as cache: uncached_docs = [] @@ -102,6 +104,7 @@ def sync_wrapper(self, docs: List[Doc], *args, **kwargs): @functools.wraps(model_call) async def async_wrapper(self, docs: List[Doc], *args, **kwargs): + docs = as_list(docs) results = [None] * len(docs) with diskcache.Cache(self.config.cache_dir) as cache: uncached_docs = []
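
---

For reference, a minimal sketch (not part of the patch) of how a custom model consumes the reworked decorator after this series. The `MyModel` class and the fabricated `ModelResponse` payload are illustrative assumptions; only the imports, the `@cached(SamplingMethod.GENERATIVE)` usage, and the reliance on `self.config.cache_dir` are taken from the patch itself.

```python
# Sketch assuming a custom LightevalModel subclass; response contents are fake.
from typing import List

from lighteval.models.abstract_model import LightevalModel
from lighteval.models.model_output import ModelResponse
from lighteval.tasks.requests import Doc, SamplingMethod
from lighteval.utils.cache_management import cached


class MyModel(LightevalModel):
    def __init__(self, config):
        super().__init__(config)
        # No SampleCache to set up anymore: the decorator opens a diskcache.Cache
        # at self.config.cache_dir on each call.
        self.config = config

    @cached(SamplingMethod.GENERATIVE)
    def greedy_until(self, docs: List[Doc]) -> List[ModelResponse]:
        # Only cache misses reach this body; the decorator writes each result
        # back under a hash of (doc, sampling_method, config, args, kwargs).
        return [ModelResponse(text=["fake output"]) for _ in docs]
```

Because the key hashes the doc together with the sampling method, the full model config, and the call arguments, any change to generation parameters or model settings invalidates prior entries automatically, which is the behavior the removed model-hash directory layout used to provide.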