Move get_logits and EngineCallResponse out of the Engine.__call__ function so that the remaining parts can be lowered to Rust or C++ in the future and for simplification of LLM servers that operate in batches #647

Merged
merged 6 commits into guidance-ai:main on Mar 8, 2024

Conversation

paulbkoch
Collaborator

@paulbkoch paulbkoch commented Feb 21, 2024

This PR has two purposes:

  1. We'd like to lower the contents of the Engine.__call__() function to either Rust or C++; however, the calls to get_logits and the yields of EngineCallResponse require Python. This change moves them outside of the Engine.__call__(...) function, leaving the rest of that function lowerable.
  2. The current Engine class works well for servers that respond to a single request at a time; however, batched servers need to maintain state for multiple connections at once and benefit from synchronizing the calls to get_logits into batched GPU calls. The current architecture blocks on the call to get_logits inside the stack of the Engine.__call__ function. This change moves the call to get_logits outside of that function, to a location where a batched server can batch the calls together. (A sketch of the resulting caller-side loop follows this list.)
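
A minimal sketch of the caller-side loop this enables. It is illustrative only: start, next, and get_logits appear in this PR, but the exact signatures and the compute_logits/make_response parameters below are assumptions, not code from the PR.

def run_single(engine, parser, grammar, compute_logits, make_response):
    # illustrative: compute_logits and make_response stand in for the model
    # forward pass and the EngineCallResponse construction that used to live
    # inside Engine.__call__
    engine.start(parser, grammar)
    logits = None
    while True:
        is_done, logits_state, response_state = engine.next(logits)
        if response_state is not None:
            yield make_response(response_state)   # streamed back to the caller
        if is_done:
            break
        logits = compute_logits(logits_state)     # GPU call now happens outside the engine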

…tion of separating the grammar processing that could be lowered to Rust or C++ into its own separate function
…at the contents of the next(...) function can be lowered to Rust or C++ in the future and for simplification of LLM servers that operate in batches
…essing loop of the Engine class so that the sample ordering can be done in python before lowering into C++ or Rust
@slundberg
Contributor

slundberg commented Feb 22, 2024

Thanks @paulbkoch! I am just starting to dig through this, but one high-level question first. Since many Model objects share the same engine, does this change prevent those model objects from being async-friendly or thread-safe (not sure if they are thread-safe anyway)? Just noting that the engine now has more state than just a cache. (Perhaps this means we need to create a cheap sub-object that gets created at each call.)

# self._cache_state["new_token_ids"].append(sampled_token_ind)
logits = None
while True:
    is_done, logits_state, response_state = self.next(logits)
Contributor

I think currently we return each token that is forced as a separate chunk; this in theory allows us to report to the client the probability of each token. It looks like this might turn each forced region of bytes into a single chunk, is that right? (Not a huge issue, but one to note.)

Collaborator Author

@paulbkoch paulbkoch Feb 22, 2024

Unless I introduced a bug, it should return the same number of token responses as the original.

@slundberg
Contributor

Do you have an example function showing how batching might work? I was trying to imagine that but figured it would be faster to see what you had in mind :)

@slundberg
Contributor

One other thought here: we should consider how this integrates with a speculative decoder while we are refactoring...

@paulbkoch
Collaborator Author

paulbkoch commented Feb 22, 2024

> Thanks @paulbkoch! I am just starting to dig through this, but one high-level question first. Since many Model objects share the same engine, does this change prevent those model objects from being async-friendly or thread-safe (not sure if they are thread-safe anyway)? Just noting that the engine now has more state than just a cache. (Perhaps this means we need to create a cheap sub-object that gets created at each call.)

It’s definitely not thread-safe as written, but the existing trie isn’t either. The way to do this would probably be to have a separate engine object per model object and deepcopy them when copies are made of the model object. It does work currently as a shared object held by multiple Model objects since the state is only valid between the call to start and the last call to next.
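
A minimal sketch of the per-Model-engine idea described above; the simplified Model class shown here is hypothetical, not the one in this repository:

import copy

class Model:
    def __init__(self, engine):
        # hypothetical: each Model owns its own Engine instead of sharing one
        self.engine = engine

    def copy(self):
        # deep-copying the engine gives the copy independent parser/grammar state,
        # so concurrent use of the two Model objects cannot interfere
        return Model(copy.deepcopy(self.engine))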

@paulbkoch
Collaborator Author

paulbkoch commented Feb 22, 2024

> Do you have an example function showing how batching might work? I was trying to imagine that but figured it would be faster to see what you had in mind :)

Here's an example of how it would work:

engines = […]  # imagine this contains 10 engine objects and each has its own prompt
for i in range(len(engines)):
    # each engine has its own parser and grammar
    engines[i].start(parsers[i], grammars[i])

# no logits exist yet, so every engine receives None on the first pass
batched_logits = [None] * len(engines)
done = [False] * len(engines)
logits_state = [None] * len(engines)
response_state = [None] * len(engines)

while True:
    for i in range(len(engines)):
        # For better performance use joblib on the next function
        done[i], logits_state[i], response_state[i] = engines[i].next(batched_logits[i])

    # GPU computes all 10 arrays of logits in a single batch
    batched_logits = GPU.get_batched_logits(engines)

    # Do some complicated state management to handle completed grammars
    # by swapping in new grammars waiting on queues
    # and issuing streaming responses through queues.
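
A hypothetical sketch of that state-management step, assuming a queue.Queue of pending (parser, grammar) work items and one response queue per batch slot (none of these names come from the PR). The queues would be created before the loop, and the for-loop below would sit at the end of the while-loop body above:

import queue

pending_work = queue.Queue()                         # (parser, grammar) pairs waiting to run
response_queues = [queue.Queue() for _ in engines]   # one stream of response chunks per slot

for i in range(len(engines)):
    response_queues[i].put(response_state[i])        # stream the latest chunk to the client
    if done[i] and not pending_work.empty():
        parser, grammar = pending_work.get_nowait()  # pull the next waiting request
        engines[i].start(parser, grammar)            # reuse this batch slot for the new work
        batched_logits[i] = None                     # a freshly started grammar has no logits yet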

@paulbkoch
Collaborator Author

> One other thought here: we should consider how this integrates with a speculative decoder while we are refactoring...

I'm not clear on what you have in mind here, but I'm happy to discuss it further if you think it should impact the design.

@slundberg
Contributor

Thanks again Paul! I looked through everything again and it all looks good for now. Merging :)

@slundberg slundberg merged commit 9fcd78b into guidance-ai:main Mar 8, 2024
5 checks passed