8 changes: 2 additions & 6 deletions docs/source/evaluating-a-custom-model.mdx
@@ -16,16 +16,13 @@ Here's a basic example:
from lighteval.models.abstract_model import LightevalModel
from lighteval.models.model_output import ModelResponse
from lighteval.tasks.requests import Doc, SamplingMethod
-from lighteval.utils.cache_management import SampleCache, cached
+from lighteval.utils.cache_management import cached

class MyCustomModel(LightevalModel):
    def __init__(self, config):
        super().__init__(config)
        # Initialize your model here...

-        # Enable caching (recommended)
-        self._cache = SampleCache(config)

    @cached(SamplingMethod.GENERATIVE)
    def greedy_until(self, docs: List[Doc]) -> List[ModelResponse]:
        # Implement generation logic
@@ -168,15 +165,14 @@ To enable caching in your custom model:
### Step 1: Import Caching Components
```python
-from lighteval.utils.cache_management import SampleCache, cached
+from lighteval.utils.cache_management import cached
```
### Step 2: Initialize Cache in Constructor
```python
def __init__(self, config):
    super().__init__(config)
    # Your initialization code...
-    self._cache = SampleCache(config)
```
3. Add cache decorators to your prediction methods:
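For reference, a custom model after this change only needs the `cached` decorator; it no longer builds a `SampleCache` itself. A minimal sketch under that assumption (the `ModelResponse` fields and the generation body are illustrative placeholders, not lighteval's documented example):

```python
from typing import List

from lighteval.models.abstract_model import LightevalModel
from lighteval.models.model_output import ModelResponse
from lighteval.tasks.requests import Doc, SamplingMethod
from lighteval.utils.cache_management import cached


class MyCustomModel(LightevalModel):
    def __init__(self, config):
        super().__init__(config)
        # Initialize your model here; no SampleCache construction is needed anymore.

    @cached(SamplingMethod.GENERATIVE)
    def greedy_until(self, docs: List[Doc]) -> List[ModelResponse]:
        # Illustrative generation logic: return one response per input doc.
        return [ModelResponse(text=["generated text"]) for _ in docs]
```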
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -65,8 +65,8 @@ dependencies = [
"GitPython>=3.1.41", # for logging
"datasets>=4.0.0",
"pydantic",
"numpy>=2", # pinned to avoid incompatibilities
"hf-xet>=1.1.8", # pinned to avoid failing test suite
"numpy>=2", # pinned to avoid incompatibilities
"hf-xet>=1.1.8", # pinned to avoid failing test suite
# Prettiness
"typer>=0.20.0",
"termcolor==2.3.0",
@@ -87,6 +87,7 @@ dependencies = [
"httpx>=0.27.2",
"latex2sympy2_extended==1.0.6",
"langcodes",
"diskcache>=5.6.3",
]

[project.optional-dependencies]
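The new `diskcache` dependency suggests the sample cache is now persisted on disk via `diskcache`. A small, self-contained illustration of the memoization pattern that library provides (not lighteval's actual cache code; the directory and key layout are made up):

```python
from diskcache import Cache


def expensive_prediction(doc_id: str) -> str:
    # Stand-in for a real model call.
    return f"prediction for {doc_id}"


# Entries written here survive across processes and runs.
cache = Cache("./.sample_cache")

key = ("greedy_until", "doc-42")
if key not in cache:
    cache[key] = expensive_prediction("doc-42")
print(cache[key])
```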
2 changes: 2 additions & 0 deletions src/lighteval/logging/evaluation_tracker.py
@@ -72,6 +72,8 @@ def default(self, o):  # noqa : C901
                return o.__dict__
            except Exception:
                return str(o)
+        if hasattr(o, "model_dump"):  # is pydantic BaseModel
+            return o.model_dump()
        if callable(o):
            if hasattr(o, "__name__"):
                return o.__name__
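For context, the hunk above makes the tracker's JSON encoder serialize pydantic models through `model_dump()`. A standalone sketch of that pattern (the encoder and model class names here are invented for illustration, not lighteval classes):

```python
import json

from pydantic import BaseModel


class SketchEncoder(json.JSONEncoder):
    def default(self, o):
        if hasattr(o, "model_dump"):  # pydantic v2 BaseModel
            return o.model_dump()
        if callable(o):
            return getattr(o, "__name__", str(o))
        return super().default(o)


class MetricConfig(BaseModel):  # illustrative model
    name: str
    higher_is_better: bool = True


print(json.dumps({"metric": MetricConfig(name="acc")}, cls=SketchEncoder))
# {"metric": {"name": "acc", "higher_is_better": true}}
```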
5 changes: 1 addition & 4 deletions src/lighteval/models/dummy/dummy_model.py
@@ -29,7 +29,7 @@
from lighteval.models.abstract_model import LightevalModel, ModelConfig
from lighteval.models.model_output import ModelResponse
from lighteval.tasks.requests import Doc, SamplingMethod
-from lighteval.utils.cache_management import SampleCache, cached
+from lighteval.utils.cache_management import cached


class DummyModelConfig(ModelConfig):
@@ -70,9 +70,6 @@ def __init__(
        self._random = random.Random(self.config.seed)
        self._tokenizer = None

-        # Initialize cache for tokenization and predictions
-        self._cache = SampleCache(config)

    @property
    def tokenizer(self):
        if not self._tokenizer:
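Since every backend stops constructing `SampleCache(config)` in `__init__`, the `cached` decorator presumably provisions the cache itself on first use. A purely illustrative sketch of that lazy-initialization idea, not the actual lighteval implementation (the cache location, key scheme, and `doc.query` usage are assumptions):

```python
import functools

from diskcache import Cache


def cached(sampling_method):
    """Illustrative decorator: memoize per-doc model responses in a lazily created disk cache."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(self, docs):
            # Create the cache on first use instead of in every model's __init__.
            if getattr(self, "_cache", None) is None:
                self._cache = Cache("./.cache/samples")  # assumed location
            keys = [(str(sampling_method), doc.query) for doc in docs]
            missing = [doc for doc, key in zip(docs, keys) if key not in self._cache]
            if missing:
                # Only run the wrapped prediction method for uncached docs.
                for doc, response in zip(missing, fn(self, missing)):
                    self._cache[(str(sampling_method), doc.query)] = response
            return [self._cache[key] for key in keys]

        return wrapper

    return decorator
```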
5 changes: 1 addition & 4 deletions src/lighteval/models/endpoints/endpoint_model.py
@@ -49,7 +49,7 @@
from lighteval.models.model_output import ModelResponse
from lighteval.tasks.prompt_manager import PromptManager
from lighteval.tasks.requests import Doc, SamplingMethod
-from lighteval.utils.cache_management import SampleCache, cached
+from lighteval.utils.cache_management import cached


logger = logging.getLogger(__name__)
@@ -268,9 +268,6 @@ def __init__(self, config: Union[InferenceEndpointModelConfig, ServerlessEndpoin
        self.generation_parameters = config.generation_parameters
        self.generation_config = self.generation_parameters.to_tgi_ie_dict()

-        # Initialize cache for tokenization and predictions
-        self._cache = SampleCache(config)

    def _create_endpoint(  # noqa: C901
        self, config: InferenceEndpointModelConfig | ServerlessEndpointModelConfig
    ) -> Tuple[Union[InferenceEndpoint | None], AsyncInferenceClient, InferenceClient]:  # noqa: C901
5 changes: 1 addition & 4 deletions src/lighteval/models/endpoints/inference_providers_model.py
@@ -36,7 +36,7 @@
from lighteval.models.model_output import ModelResponse
from lighteval.tasks.prompt_manager import PromptManager
from lighteval.tasks.requests import Doc, SamplingMethod
-from lighteval.utils.cache_management import SampleCache, cached
+from lighteval.utils.cache_management import cached


logger = logging.getLogger(__name__)
@@ -134,9 +134,6 @@ def __init__(self, config: InferenceProvidersModelConfig) -> None:
            use_chat_template=True, tokenizer=self.tokenizer, system_prompt=config.system_prompt
        )

-        # Initialize cache for tokenization and predictions
-        self._cache = SampleCache(config)

    async def __call_api(self, prompt: List[dict], num_samples: int) -> Optional[ChatCompletionOutput]:
        """Make API call with exponential backoff retry logic.

5 changes: 1 addition & 4 deletions src/lighteval/models/endpoints/litellm_model.py
@@ -33,7 +33,7 @@
from lighteval.models.model_output import ModelResponse
from lighteval.tasks.prompt_manager import PromptManager
from lighteval.tasks.requests import Doc, SamplingMethod
-from lighteval.utils.cache_management import SampleCache, cached
+from lighteval.utils.cache_management import cached
from lighteval.utils.imports import is_package_available, requires


@@ -162,9 +162,6 @@ def __init__(self, config: LiteLLMModelConfig) -> None:
            use_chat_template=True, tokenizer=self.tokenizer, system_prompt=config.system_prompt
        )

-        # Initialize cache for tokenization and predictions
-        self._cache = SampleCache(config)

    def _prepare_stop_sequence(self, stop_sequence):
        """Prepare and validate stop sequence."""
        if self.provider == "anthropic":
4 changes: 0 additions & 4 deletions src/lighteval/models/endpoints/tgi_model.py
@@ -31,7 +31,6 @@
from lighteval.models.abstract_model import ModelConfig
from lighteval.models.endpoints.endpoint_model import InferenceEndpointModel
from lighteval.tasks.prompt_manager import PromptManager
-from lighteval.utils.cache_management import SampleCache
from lighteval.utils.imports import Extra, is_package_available, requires


@@ -130,9 +129,6 @@ def __init__(self, config: TGIModelConfig) -> None:
            use_chat_template=True, tokenizer=self.tokenizer, system_prompt=config.system_prompt
        )

-        # Initialize cache for tokenization and predictions
-        self._cache = SampleCache(config)

    @requires(Extra.TGI)
    def _async_process_request(
        self,
5 changes: 1 addition & 4 deletions src/lighteval/models/nanotron/nanotron_model.py
@@ -50,7 +50,7 @@
    Doc,
    SamplingMethod,
)
-from lighteval.utils.cache_management import SampleCache, cached
+from lighteval.utils.cache_management import cached
from lighteval.utils.imports import is_package_available
from lighteval.utils.parallelism import find_executable_batch_size
from lighteval.utils.utils import as_list
@@ -304,9 +304,6 @@ def __init__(
        self.pairwise_tokenization = nanotron_config.lighteval_config.tasks.pairwise_tokenization
        self.batch_size = nanotron_config.lighteval_config.batch_size

-        # Initialize cache for tokenization and predictions
-        self._cache = SampleCache(nanotron_config)

    @property
    def tokenizer(self):
        return self._tokenizer
5 changes: 1 addition & 4 deletions src/lighteval/models/sglang/sglang_model.py
@@ -34,7 +34,7 @@
from lighteval.models.utils import _simplify_name, uses_chat_template
from lighteval.tasks.prompt_manager import PromptManager
from lighteval.tasks.requests import Doc, SamplingMethod
-from lighteval.utils.cache_management import SampleCache, cached
+from lighteval.utils.cache_management import cached
from lighteval.utils.imports import is_package_available, requires


@@ -163,9 +163,6 @@ def __init__(
        self.pairwise_tokenization = config.pairwise_tokenization
        self.prompt_manager = PromptManager(self.use_chat_template, self.tokenizer, config.system_prompt)

-        # Initialize cache for tokenization and predictions
-        self._cache = SampleCache(config)

    @property
    def tokenizer(self):
        return self._tokenizer
8 changes: 1 addition & 7 deletions src/lighteval/models/transformers/transformers_model.py
@@ -53,7 +53,7 @@
from lighteval.models.utils import _get_dtype, _get_model_sha, _simplify_name, uses_chat_template
from lighteval.tasks.prompt_manager import PromptManager
from lighteval.tasks.requests import Doc, SamplingMethod
-from lighteval.utils.cache_management import SampleCache, cached
+from lighteval.utils.cache_management import cached
from lighteval.utils.imports import (
is_package_available,
)
@@ -237,9 +237,6 @@ def __init__(
            use_chat_template=self.use_chat_template, tokenizer=self.tokenizer, system_prompt=config.system_prompt
        )

-        # Initialize cache for tokenization and predictions
-        self._cache = SampleCache(config)

    def cleanup(self):
        """Clean up operations if needed, such as closing an endpoint."""
        del self.model
@@ -301,9 +298,6 @@ def from_model(
            system_prompt=config.system_prompt if config else None,
        )

-        # Initialize cache for tokenization and predictions
-        self._cache = SampleCache(config) if config else None

        return self

    @property
5 changes: 1 addition & 4 deletions src/lighteval/models/transformers/vlm_transformers_model.py
@@ -45,7 +45,7 @@
from lighteval.models.utils import _get_dtype, _get_model_sha, _simplify_name
from lighteval.tasks.prompt_manager import PromptManager
from lighteval.tasks.requests import Doc, SamplingMethod
-from lighteval.utils.cache_management import SampleCache, cached
+from lighteval.utils.cache_management import cached
from lighteval.utils.imports import (
is_package_available,
)
@@ -177,9 +177,6 @@ def __init__(
            use_chat_template=True, tokenizer=self.tokenizer, system_prompt=config.system_prompt
        )

-        # Initialize cache for tokenization and predictions
-        self._cache = SampleCache(config)

    @property
    def tokenizer(self):
        return self.processor.tokenizer
5 changes: 1 addition & 4 deletions src/lighteval/models/vllm/vllm_model.py
@@ -37,7 +37,7 @@
from lighteval.models.utils import _simplify_name, uses_chat_template
from lighteval.tasks.prompt_manager import PromptManager
from lighteval.tasks.requests import Doc, SamplingMethod
-from lighteval.utils.cache_management import SampleCache, cached
+from lighteval.utils.cache_management import cached
from lighteval.utils.imports import is_package_available, requires


@@ -216,9 +216,6 @@ def __init__(

        self.prompt_manager = PromptManager(self.use_chat_template, self.tokenizer, config.system_prompt)

-        # Initialize cache for tokenization and predictions
-        self._cache = SampleCache(config)

    @property
    def tokenizer(self):
        return self._tokenizer
9 changes: 4 additions & 5 deletions src/lighteval/pipeline.py
@@ -144,8 +144,6 @@ def __init__(
        self.model_config = model_config
        self.accelerator, self.parallel_context = self._init_parallelism_manager()
        self.model = self._init_model(model_config, model)
-        # Must occur after model and task init
-        self.model._cache._init_registry(self.registry)
        # Must occur after model init
        self._init_accelerator_seeds()

@@ -308,6 +306,8 @@ async def _run_model_async(self):
                model_outputs = await self.model.loglikelihood(docs)
            outputs[sampling_method] = model_outputs

+        self.model.cleanup()

Review comment (Member): nice catch

        return outputs

    def _run_model_sync(self):
@@ -327,6 +327,8 @@ def _run_model_sync(self):
                model_outputs = self.model.loglikelihood_rolling(docs)
            outputs[sampling_method] = model_outputs

+        self.model.cleanup()
+
        return outputs

    def _run_model(self):
@@ -339,9 +341,6 @@ def _run_model(self):
        else:
            outputs = self._run_model_sync()

-        # Cleaning up the model before running metrics
-        self.model.cleanup()

        return outputs

    def _post_process_outputs(self, sampling_method_responses: dict[str, list[ModelResponse]]):
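The net effect of the pipeline changes: the `_cache._init_registry` hookup is gone from `__init__`, and `cleanup()` is now called inside each execution path right after inference rather than once in `_run_model()`. A simplified sketch of the resulting flow (not the actual `Pipeline` code; the dispatch on sampling method is condensed to a single model call):

```python
def _run_model_sync(model, docs_by_sampling_method):
    outputs = {}
    for sampling_method, docs in docs_by_sampling_method.items():
        outputs[sampling_method] = model.greedy_until(docs)
    # Release model resources as soon as inference finishes, before metrics run.
    model.cleanup()
    return outputs


async def _run_model_async(model, docs_by_sampling_method):
    outputs = {}
    for sampling_method, docs in docs_by_sampling_method.items():
        outputs[sampling_method] = await model.greedy_until(docs)
    model.cleanup()  # cleanup now lives inside the branch, right after inference
    return outputs
```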