From 5c71d8913f91cf3ae8fb5e312830d058e2a68de3 Mon Sep 17 00:00:00 2001 From: William Pietri Date: Thu, 18 Jul 2024 14:29:14 -0700 Subject: [PATCH] Run prompts cache (#479) * Adds an optional cache to run-prompts. * Not even sure how this broke. * Extracting CachingPipe so that others can easily make a pipeline with cache support. * Oops, forgot to check this in. Making use of caching superclass in the prompt flow. * Adds an optional cache to run-prompts. * Extracting CachingPipe so that others can easily make a pipeline with cache support. * Oops, forgot to check this in. Making use of caching superclass in the prompt flow. --- modelgauge/prompt_pipeline.py | 32 ++++++++++---------------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/modelgauge/prompt_pipeline.py b/modelgauge/prompt_pipeline.py index 68819bfd..beeb30bb 100644 --- a/modelgauge/prompt_pipeline.py +++ b/modelgauge/prompt_pipeline.py @@ -5,7 +5,7 @@ import diskcache # type: ignore -from modelgauge.pipeline import Source, Pipe, Sink +from modelgauge.pipeline import Source, Pipe, Sink, CachingPipe from modelgauge.prompt import TextPrompt from modelgauge.single_turn_prompt_response import PromptWithContext from modelgauge.sut import PromptResponseSUT, SUT @@ -109,34 +109,22 @@ def handle_item(self, item): self.downstream_put((item, sut_uid)) -class NullCache(dict): - def __setitem__(self, __key, __value): - pass - - def __getitem__(self, __key): - pass - - -class PromptSutWorkers(Pipe): +class PromptSutWorkers(CachingPipe): def __init__(self, suts: dict[str, SUT], workers=None, cache_path=None): if workers is None: workers = 8 - super().__init__(thread_count=workers) + super().__init__(thread_count=workers, cache_path=cache_path) self.suts = suts - if cache_path: - self.cache = diskcache.Cache(cache_path) - else: - self.cache = NullCache() - def handle_item(self, item): + def key(self, item): + prompt_item: PromptWithContext + prompt_item, sut_uid = item + return 
(prompt_item.source_id, prompt_item.prompt.text, sut_uid) + + def handle_uncached_item(self, item): prompt_item: PromptWithContext prompt_item, sut_uid = item - cache_key = (prompt_item.prompt.text, sut_uid) - if cache_key in self.cache: - response_text = self.cache[cache_key] - else: - response_text = self.call_sut(prompt_item.prompt, self.suts[sut_uid]) - self.cache[cache_key] = response_text + response_text = self.call_sut(prompt_item.prompt, self.suts[sut_uid]) return prompt_item, sut_uid, response_text def call_sut(self, prompt_text: TextPrompt, sut: PromptResponseSUT) -> str: