diskcache for caching #1068
Conversation
Co-authored-by: Francesco Bertolotti <f14.bertolotti@gmail.com>
Oh great!! Will test it today and come back to you for review :)

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Looking nice overall, a couple of nits - could you also try testing it with the transformers and inference provider backends?
```python
with rich.progress.Progress(
    "[progress.description]{task.description}",
    rich.progress.BarColumn(),
    "[progress.completed]{task.completed}/{task.total}",
    "•",
    rich.progress.TimeElapsedColumn(),
    "•",
    rich.progress.TimeRemainingColumn(),
) as pbar:
    task_id = pbar.add_task("[green]Sending Requests...", total=len(docs))

    async def track(coro):
        """Wraps a coroutine to update the progress bar when done."""
        result = await coro
        pbar.update(task_id, advance=1)
        return result

    wrapped = [
        track(self._async_one_item(index=index, doc=doc, generative=generative))
        for index, doc in enumerate(docs)
    ]

    result = await asyncio.gather(*wrapped)
    return result
```
Unrelated to the current PR - it would make sense to have it in its own standalone PR.
I will make a separate PR for it
```python
# Save updated dataset
dataset = Dataset.from_list(all_samples)
dataset.to_parquet(str(cache_file))


def default_json_encoder(obj):
```
We already have a JSON encoder in the utils, iirc.
I am having trouble finding it. Could you point it out for me?
`EnhancedJSONEncoder` in `src/lighteval/logging/evaluation_tracker.py`
@NathanHB wdyt of moving it to the utils? It's not logging-specific.
That encoder should work perfectly. I only need to add `.model_dump()` for Pydantic objects.
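For concreteness, that extension might look like the following - a minimal sketch assuming Pydantic v2 (where `model_dump()` exists), not the actual lighteval encoder:

```python
import dataclasses
import json

from pydantic import BaseModel


class EnhancedJSONEncoder(json.JSONEncoder):
    """Sketch: json.dumps calls default() only for objects it cannot serialize natively."""

    def default(self, o):
        if isinstance(o, BaseModel):
            return o.model_dump()  # Pydantic v2 models -> plain dicts
        if dataclasses.is_dataclass(o):
            return dataclasses.asdict(o)
        return super().default(o)
```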
perfect, ty
A quick note about the encoder: I noticed that the method is defined as an instance method, which means it expects a `self` parameter. I can work around this by forcing `self` to `None`: `functools.partial(EnhancedJSONEncoder.default, None)`. However, it might be cleaner to define it as a `@staticmethod` within `EnhancedJSONEncoder` instead.
What do you think?
I don't think you need this - it's supposed to be used with the `cls` arg in `json.dumps`, so instead of doing `default=your_func` you do `cls=the_class`.
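For reference, this is the standard-library pattern being described - pass the encoder class via `cls=`, not a bound method via `default=` (the encoder here is illustrative):

```python
import json


class MyEncoder(json.JSONEncoder):
    # Illustrative encoder: serialize sets as sorted lists.
    def default(self, o):
        if isinstance(o, set):
            return sorted(o)
        return super().default(o)


# json.dumps instantiates the class itself, so no `self` workaround is needed:
print(json.dumps({"choices": {"b", "a"}}, cls=MyEncoder))  # {"choices": ["a", "b"]}
```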
Unless I'm misunderstanding your issue - feel free to expand a bit on what you need, if relevant :)
mb, I didn't even know you could pass a class to `json.dumps`.
```python
if isinstance(docs, Doc):
    docs = [docs]
```
You could keep `as_list` (a sketch of its behavior is below).
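For context, a minimal sketch of the `as_list` idiom being suggested - assumed to match lighteval's helper, which normalizes a single item into a list:

```python
def as_list(item):
    # Pass lists through unchanged; wrap anything else in a one-element list.
    return item if isinstance(item, list) else [item]


assert as_list("doc") == ["doc"]
assert as_list(["a", "b"]) == ["a", "b"]
```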
```python
    doc, sampling_method=sampling_method, config=self.config, args=args, kwargs=kwargs
)
if key in cache:
    logger.info("Cache hit")
```
You're going to get too many logs with this, especially on datasets like MMLU with 10K+ samples - maybe just store how many times the cache was hit and log it once at the end?
"Cache was hit for x documents out of y."
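A minimal sketch of that suggestion - the `keys` list, `cache` mapping, and the `{"response": ...}` entry shape are assumptions taken from the diff above, not lighteval's exact code:

```python
import logging

logger = logging.getLogger(__name__)


def lookup_cached(docs, keys, cache):
    # Tally cache hits and log a single summary line instead of one log
    # line per document.
    results = [None] * len(docs)
    hits = 0
    for idx, key in enumerate(keys):
        if key in cache:
            hits += 1
            results[idx] = cache[key]["response"]
    logger.info("Cache was hit for %d documents out of %d.", hits, len(docs))
    return results
```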
```python
    logger.info("Cache hit")
    results[idx] = cache[key]["response"]
else:
    logger.info("Cache miss")
```
idem
```python
    return wrapper


def decorator(sampler: Callable):  # noqa: C901
    @functools.wraps(sampler)
```
`sampler` is a bit misleading here - maybe `model_call` would be clearer?
```python
model_outputs = await self.model.loglikelihood(docs)
outputs[sampling_method] = model_outputs

self.model.cleanup()
```
nice catch
Just gave it a go locally, it's great! Thanks for the contrib.
Will merge once the nits above are addressed and tests are fixed :)
Ok, I think I have addressed all the issues. I am opening another PR for the progress bar.

I have added a last-minute commit. I forgot to remove two lines related to this comment: #1068 (comment)
```python
if isinstance(docs, Doc):
    docs = [docs]
```
You should use `as_list`, not remove it entirely ^^"
mb, I misunderstood your comment. The new commit uses `as_list`.
Overview
This PR migrates the custom `SampleCache` system to the more robust and well-maintained `diskcache` library. It follows up on issue #1053.
What’s Changed
The legacy `SampleCache` class and its associated `cached` decorator have been fully removed and replaced with a new `cached` decorator built on top of `diskcache`. The new decorator uses `model.config.cache_dir` to instantiate a shared cache via `diskcache.Cache(model.config.cache_dir)`.
Most of the logic lives in `src/lighteval/utils/cache_management.py`, where the decorator was rewritten from scratch (a sketch of the idea follows below). All remaining changes simply remove references to the previous caching implementation.
Currently, cache writes occur only after a full sampling method completes (which can take hours for certain tasks). In the future, we can decorate inner per-request methods to enable more granular incremental caching.
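To illustrate the design described above, here is a minimal sketch of a `diskcache`-backed decorator. The key construction, the `model_call` signature, and the `self.config.cache_dir` attribute are assumptions for illustration, not the actual `cache_management.py` code:

```python
import functools
import json

from diskcache import Cache


def cached(sampling_method: str):
    def decorator(model_call):
        @functools.wraps(model_call)
        def wrapper(self, docs, *args, **kwargs):
            # Shared on-disk cache rooted at the model's configured directory.
            cache = Cache(self.config.cache_dir)
            # Hypothetical key: deterministic JSON over the inputs.
            key = json.dumps(
                {"method": sampling_method, "docs": [repr(d) for d in docs]},
                sort_keys=True,
            )
            if key in cache:
                return cache[key]
            result = model_call(self, docs, *args, **kwargs)
            # Written only once the full sampling method completes,
            # matching the current behavior described above.
            cache[key] = result
            return result

        return wrapper

    return decorator
```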
Additional Improvements
Testing
The implementation has been tested with:
```shell
uv run lighteval vllm examples/model_configs/vllm_model_config.yaml aime25  # async + sync vLLM
uv run lighteval vllm examples/model_configs/transformers_model.yaml aime25
uv run pytest
```