Adding streaming support under guidance v0.1 #568

Merged
merged 8 commits into guidance-ai:main on Jan 5, 2024

Conversation

@hodlen (Contributor) commented Jan 3, 2024

Streaming updates of variables are crucial for my application, yet they are not supported after the v0.1 update. I found that the key is capturing over the syntax tree, and I have implemented that in a fairly naive manner. As a result, we can use CaptureEvents to receive the latest Model reference every time a token is generated.

Here is how I use it in my application. Basically, I consume the event queue of a model running in a separate thread and turn it into an iterator.

import threading
import time
from typing import Callable, Dict, Generator

# Model and CaptureEvents come from guidance v0.1 with this change applied;
# the exact import paths may differ.
from guidance.models import Model
from guidance import CaptureEvents

LMFunc = Callable[[Model], Model]

def get_vars_recursively(lm: Model) -> Dict[str, str]:
    # Merge this model's captured variables with those of its event parents,
    # letting the child's values take precedence.
    vars = lm._variables
    if not lm._event_parent:
        return vars
    return {**get_vars_recursively(lm._event_parent), **vars}

def iter_guidance_vars(
    lm: Model, lm_func: LMFunc, timeout_secs: float = 5
) -> Generator[Dict[str, str], None, None]:
    with CaptureEvents(lm) as events:
        # Run the guidance program in a worker thread so we can consume
        # its events here as they arrive.
        worker_thread = threading.Thread(target=lambda: lm_func(lm))
        worker_thread.start()

        # Wait up to timeout_secs for the first event to arrive.
        deadline = time.time() + timeout_secs
        while time.time() < deadline and events.empty():
            time.sleep(0.01)  # avoid a busy spin

        while True:
            if not events.empty():
                yield get_vars_recursively(events.get())
            elif not worker_thread.is_alive():
                return  # worker finished and the queue is drained
            else:
                time.sleep(0.01)
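
For illustration, here is a hypothetical way to drive this helper; the model, prompt, and capture name below are just for the example, not part of this change, and it assumes the Transformers backend and gen from guidance:

from guidance import models, gen

lm = models.Transformers('gpt2')

def tell_story(lm):
    # gen(name=...) captures the generated text under that variable name
    return lm + "Tell me a story about fish. " + gen(name="story", max_tokens=20)

for snapshot in iter_guidance_vars(lm, tell_story):
    # Each snapshot is the merged variable dict at that point in generation
    print(snapshot.get("story", ""))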

I am not sure whether this breaks anything in other respects. I look forward to your feedback and suggestions for any improvements, as well as any concerns regarding potential impacts.

@hodlen (Contributor, Author) commented Jan 3, 2024

It's worth noting that my implementation has not dealt with event bubbling very well. For nested LM calls, like lm += gen_fn(lm), we can receive the latest model reference when the subroutine generates a token, but it is the model inside gen_fn and apparently not a child of lm.

So (1) it does not include any variables generated earlier; and (2) we cannot trace back to lm from such a reference if we want to merge in variables from its parent.

@slundberg (Contributor)

Thanks @hodlen! Streaming is indeed a key feature we need to bring to v0.1. The trickiest part is that during a grammar parse there can be multiple valid ways to parse the same text, so if we just yield the current state at any given moment, it may be ambiguous what the captures should be. We might therefore need to yield only when the parse is unambiguous. The other complexity is that the captures have to be computed based on the parse, and if we stream the results we recompute those captures far more often, so there may be a performance consideration there.
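
To make the ambiguity concrete, consider a sketch like the following (the options and capture name are illustrative only): after the model has produced just "cat", both alternatives are still viable parses, so a mid-parse yield could not report a single value for the capture.

from guidance import models, select

lm = models.Transformers('gpt2')

# Both options share the prefix "cat": until the parse commits to one of
# them, the "animal" capture is undetermined.
lm += "My favorite animal is a " + select(["cat", "caterpillar"], name="animal")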

I look forward to digging into this more once I get a moment to do so :)

@slundberg (Contributor)

Okay. So I spent some time on this and pushed some new updates that now support the following API:

import guidance
from guidance import models, gen

lm = models.Transformers('gpt2')

@guidance
def f(lm):
    lm += "Tell me a story about fish." + gen(max_tokens=20)
    return lm

for part in lm.stream() + f():
    print(part)

When we call lm.stream() we get back a ModelStream object that essentially delays all the execution that comes after it until it is iterated over; at that point it fires up a thread, runs the execution in that thread, and returns partial results as they come in.
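
In other words, building the stream does no work by itself; roughly (reusing lm and gen from the example above, with an illustrative prompt):

stream = lm.stream() + "Count to five: " + gen(max_tokens=10)  # nothing has run yet

for part in stream:   # iteration starts the worker thread
    print(str(part))  # partial results arrive as generation proceeds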

I think this is close to being ready to merge, but there seem to be some unit test failures coming from your refactor of the captures into the parser. I like your idea of moving the _record_captures function into the parser, but @hodlen can you look into what is causing the errors? For example, test_save_stop in test_gen.py now fails.

thanks!

@slundberg (Contributor)

@hodlen I went ahead and wrapped up the debugging here and got it ready to merge. In order to make things work for now, I needed to disable the computation of capture groups during the middle of a grammar parse. If you want to work on adding that support in a way that plays nicely with list_append, that would be great. For now, though, I think this is a great first step for programmatic streaming, so I will merge what we have. Thanks again!

@slundberg merged commit 0afe462 into guidance-ai:main on Jan 5, 2024. 4 checks passed.