Skip to content

Commit

Permalink
Run prompts cache (#479)
Browse files Browse the repository at this point in the history
* Adds an optional cache to run-prompts.

* Not even sure how this broke.

* Extracting CachingPipe so that others can easily make a pipeline with cache support.

* Oops, forgot to check this in. Making use of caching superclass in the prompt flow.

* Adds an optional cache to run-prompts.

* Extracting CachingPipe so that others can easily make a pipeline with cache support.

* Oops, forgot to check this in. Making use of caching superclass in the prompt flow.
  • Loading branch information
wpietri committed Jul 18, 2024
1 parent 5c94eea commit 5c71d89
Showing 1 changed file with 10 additions and 22 deletions.
32 changes: 10 additions & 22 deletions modelgauge/prompt_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import diskcache # type: ignore

from modelgauge.pipeline import Source, Pipe, Sink
from modelgauge.pipeline import Source, Pipe, Sink, CachingPipe
from modelgauge.prompt import TextPrompt
from modelgauge.single_turn_prompt_response import PromptWithContext
from modelgauge.sut import PromptResponseSUT, SUT
Expand Down Expand Up @@ -109,34 +109,22 @@ def handle_item(self, item):
self.downstream_put((item, sut_uid))


class NullCache(dict):
def __setitem__(self, __key, __value):
pass

def __getitem__(self, __key):
pass


class PromptSutWorkers(Pipe):
class PromptSutWorkers(CachingPipe):
def __init__(self, suts: dict[str, SUT], workers=None, cache_path=None):
if workers is None:
workers = 8
super().__init__(thread_count=workers)
super().__init__(thread_count=workers, cache_path=cache_path)
self.suts = suts
if cache_path:
self.cache = diskcache.Cache(cache_path)
else:
self.cache = NullCache()

def handle_item(self, item):
def key(self, item):
prompt_item: PromptWithContext
prompt_item, sut_uid = item
return (prompt_item.source_id, prompt_item.prompt.text, sut_uid)

def handle_uncached_item(self, item):
prompt_item: PromptWithContext
prompt_item, sut_uid = item
cache_key = (prompt_item.prompt.text, sut_uid)
if cache_key in self.cache:
response_text = self.cache[cache_key]
else:
response_text = self.call_sut(prompt_item.prompt, self.suts[sut_uid])
self.cache[cache_key] = response_text
response_text = self.call_sut(prompt_item.prompt, self.suts[sut_uid])
return prompt_item, sut_uid, response_text

def call_sut(self, prompt_text: TextPrompt, sut: PromptResponseSUT) -> str:
Expand Down

0 comments on commit 5c71d89

Please sign in to comment.