Improve chatbot streaming performance with diffs #7102
Conversation
🪼 branch checks and previews
The demo notebooks don't match the run.py files. Please run this command from the root of the repo and then commit the changes:

pip install nbformat && cd demo && python generate_notebooks.py

Install Gradio from this PR:

pip install https://gradio-builds.s3.amazonaws.com/4d67341431274ab05ed6015d034afb22d489fafa/gradio-4.16.0-py3-none-any.whl

Install Gradio Python Client from this PR:

pip install "gradio-client @ git+https://github.com/gradio-app/gradio@4d67341431274ab05ed6015d034afb22d489fafa#subdirectory=client/python"
Nice @aliabid94! My only concern with this, as alluded to by @pngwn, is that this makes the updates stateful and thus could cause issues if an intermediate update is dropped, e.g. due to network issues. I don't think this is a huge issue, but we may want to send the "final" update from a generator as a regular update instead of a diff, to ensure that the final state of the UI is accurate. We should also make sure the clients are robust and don't crash if they receive a diff that isn't compatible with the current value of the component. But overall, this looks like a good approach to me.
Is this something we actually ever have to worry about? Why would we lose a message? Agree that the "process_complete" message should send the final data point and not just the diff.
Not-great internet connection? Pretty sure it happens, but maybe @pngwn would know if it's a real concern.
As long as the SSE implementation is spec compliant, things shouldn't ever be out of order, even on a poor connection (or even a dropped connection), unless there is a proxy somewhere doing something weird. Even in the case of dropped connections, SSE + EventSource have built-in reconnection: the browser will reconnect with a Last-Event-ID header, so the server can resume from the last event the client actually received.

So, it is possible there could be issues, but the chances seem slim, and we could resolve them by other means. I'm not convinced we even need the last message to contain everything if we have everything else in place, really, but I'm not against it.
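For context, a minimal sketch of that reconnection mechanism (a hypothetical server-side helper, not Gradio's implementation): per the SSE spec, EventSource automatically resends the id of the last event it received in a Last-Event-ID request header, so a server that keeps a short event log can replay whatever the client missed.

def resume_events(event_log, last_event_id):
    """Yield SSE-formatted events the client hasn't seen yet.

    event_log: list of (event_id, data) tuples retained by the server.
    last_event_id: value of the Last-Event-ID request header, or None.
    """
    seen = -1 if last_event_id is None else int(last_event_id)
    for event_id, data in event_log:
        if event_id > seen:
            yield f"id: {event_id}\ndata: {data}\n\n"

# After a drop, the browser reconnects with "Last-Event-ID: 1",
# and only the events it missed are replayed:
log = [(0, "hello"), (1, "wor"), (2, "world")]
assert list(resume_events(log, "1")) == ["id: 2\ndata: world\n\n"]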
Also I think we should make [...] (h/t @oobabooga)
Ready for review!
@@ -1040,3 +1040,49 @@ def __setitem__(self, key: K, value: V) -> None:


def get_cache_folder() -> Path:
    return Path(os.environ.get("GRADIO_EXAMPLES_CACHE", "gradio_cached_examples"))


def diff(old, new):
We should write some tests for this, to make sure that diff and apply_diff undo each other (e.g. a round-trip check like the sketch below).
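A minimal sketch of such a round-trip test, assuming diff and apply_diff both end up importable from gradio.utils (adjust the import to wherever the helpers actually land):

import copy

import pytest

from gradio.utils import apply_diff, diff  # assumed location of both helpers


@pytest.mark.parametrize(
    "old, new",
    [
        ([["hi", "hel"]], [["hi", "hello"]]),           # pure string append
        ([["hi", "hello"]], [["CHANGED", "hello"]]),    # in-place replace
        ({"a": [1, 2]}, {"a": [1, 2, 3], "b": "x"}),    # nested add/extend
        ([1, 2, 3], [1]),                               # deletion
    ],
)
def test_diff_and_apply_diff_undo_each_other(old, new):
    # Applying the computed patch to the old value must reproduce the new one.
    patch = diff(copy.deepcopy(old), copy.deepcopy(new))
    assert apply_diff(copy.deepcopy(old), patch) == new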
Agreed!
Nice PR @aliabid94! Code looks good and I've tested a bunch of demos and they work great. It's nice that we didn't have to change any of the tests for this PR (testing behavior and not implementation ftw!). I left a few nits above. Additionally, I was wondering if you had done any performance comparisons?
I'll go ahead and approve since I don't see any issues, but let's confirm the performance improvements quantitatively.
Really nice PR @aliabid94!
Look at the decrease in data transfer for a really long generation:
Main: [screenshot]
PR: [screenshot]
I agree with @abidlabs' comments about testing the diff and apply_diff methods - I think it's very important, since we will now compute diffs for all generations, which includes things like every, etc.

I suspect this may be slower in some cases, because computing and sending the diff takes more time/data than just sending the whole value. Would be good to identify that in the benchmarks @abidlabs proposed.
Again, nice PR 🔥
So I did some performance benchmarks. Turns out that when running locally, there's pretty much no difference, which makes sense because a large payload costs almost nothing over localhost. There actually seems to be a small cost to using the diffs method on short generations - going to look into this. When running over a share link, at a rate of 200/s: [results elided]; at a rate of 50/s: [results elided]. So this PR will likely only show improvements once all the following conditions are met: [list elided]

Demo:
import gradio as gr
import time

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()

    def user(user_message, history):
        user_message = " ".join([str(i) for i in range(10000)])
        return "", history + [[user_message, None]]

    def bot(history):
        count_to = 1000
        history[-1][1] = ""
        for num in range(count_to):
            history[-1][1] += " " + str(num)
            time.sleep(0.005)
            yield history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot], chatbot
    )

demo.launch(share=True)
Benchmarks time!!! Demo code is at the bottom. First we'll look at speed. The tables are a little hard to parse; basically, we have 4 dimensions: [tables elided]

Takeaways: [list elided]

When we look at the amount of data transferred, we see a huge difference between main and this PR:
Yielding 20k tokens takes 1.03 GB on main vs 5.5 MB on this PR! That's a huge difference (a quick sanity check of these numbers follows the demo code below). We could still make considerable linear gains here, because for every token we send lots of extra info, e.g. yield time, avg duration, etc. For gains in speed, we should now focus on improving UI performance rather than network changes.

Demo:

import gradio as gr
import time

with gr.Blocks() as demo:
    with gr.Row():
        messages_per_second = gr.Number(100, label="Messages per second")
        number_of_messages = gr.Number(200, label="Number of messages")
    chatbot = gr.Chatbot()
    start = gr.Button("Start")

    def bot(messages_per_second, number_of_messages):
        history = [[f"Streaming {number_of_messages} msgs at {messages_per_second} msg/s", ""]]
        start_time = time.time()
        for num in range(number_of_messages):
            history[-1][1] += " " + str(num)
            time.sleep(1 / messages_per_second)
            yield history, None
        end_time = time.time()
        yield history, end_time - start_time

    with gr.Row():
        expected_duration = gr.Number(label="Expected Duration (w/o Overhead)")
        backend_duration = gr.Number(label="Duration (for Backend)")
        frontend_duration = gr.Number(label="Duration (for Frontend)")

    start.click(fn=None, js="""() => {
        window.start_time = performance.now();
    }""")
    start.click(
        lambda messages_per_second, number_of_messages: number_of_messages / messages_per_second,
        [messages_per_second, number_of_messages],
        [expected_duration],
    )
    start.click(
        bot, [messages_per_second, number_of_messages], [chatbot, backend_duration]
    ).then(
        fn=None,
        outputs=frontend_duration,
        js="""() => {
            let now = performance.now();
            return (now - window.start_time) / 1000;
        }""",
    )

demo.launch()
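A back-of-the-envelope check of those data-transfer numbers (the per-token byte count and per-diff envelope size below are assumptions, not measured values):

# main resends the entire growing history on every yield, so total bytes
# grow quadratically with token count; the diff path grows linearly.
tokens = 20_000
bytes_per_token = 5                        # assumed average, including separators
full_history = sum(range(1, tokens + 1)) * bytes_per_token
diffs = tokens * (bytes_per_token + 270)   # assumed ~270 B of message envelope
print(f"main: {full_history / 1e9:.2f} GB, diffs: {diffs / 1e6:.1f} MB")
# -> main: 1.00 GB, diffs: 5.5 MB, the same order as the measured 1.03 GB vs 5.5 MB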
Thank you @aliabid94! Very interesting findings and nice way to benchmark as well. Let's go ahead with this once we have some unit tests in place.
cc @pseudotensor @oobabooga for visibility
Yes, it's a good idea I recommended, but it won't work for beam search or other cases where the past tokens are modified. So it needs to be optional.
The talk about "dropped messages" piqued my interest. I've dealt with unexplained dropping of SSE messages in a few places, and it always turned out to be some flaw in the Event Stream parser. Since SSE is over TCP, you're not going to have random packets get dropped. But a flawed parser can fail to parse certain TCP chunks properly, leading to some events falling through the cracks.

I've checked the SSE parser (I pedantically call these Event Stream parsers instead) in this repo, and it amazingly follows exactly the same approach I took when implementing our own -- by copy-pasting the spec verbatim and filling between the lines with code! :) The parser here looks rock solid! But it may not be handling a rare-but-possible case that could lead to mishandling of a TCP chunk, therefore creating the illusion of a dropped event (since the parser won't fire it): when an event chunk happens to be fragmented exactly after the CR of a CRLF sequence, the next chunk may fail to parse, because the junk LF character at the beginning of that chunk will throw the parser off.

I have a unit test case that you may want to repeat in this repo to see if your parser fails at this edge case: [link]. Here's how I explicitly handle it: [link].

Again, this is rare, and your parser might already be implicitly handling it, but it's still worth throwing in an additional unit test to check for this. Looking forward to this massive performance improvement landing! Gradio is such a joy!
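To make the edge case concrete, here is a self-contained illustration (not Gradio's actual parser) of a buffered line splitter that tolerates a chunk boundary falling between the CR and LF of a CRLF:

def split_lines(chunks):
    """Split a stream of text chunks into Event Stream lines.

    Per the SSE spec a line ends with CRLF, a lone CR, or a lone LF. A lone
    CR at the end of a chunk is remembered, so an LF arriving at the start
    of the next chunk is recognized as the second half of a CRLF instead of
    being treated as a spurious empty line (i.e. an event boundary).
    """
    buffer = ""
    pending_cr = False  # previous chunk ended with a bare "\r"
    for chunk in chunks:
        if pending_cr and chunk.startswith("\n"):
            chunk = chunk[1:]  # swallow the LF completing the split CRLF
        pending_cr = chunk.endswith("\r")
        buffer += chunk
        # Normalize all three line endings, then emit every complete line.
        normalized = buffer.replace("\r\n", "\n").replace("\r", "\n")
        *lines, buffer = normalized.split("\n")
        yield from lines

# "data: hi\r\n\r\n" fragmented right after the first CR: a naive parser
# would see three lines here ("data: hi", "", ""), firing an extra event
# boundary and corrupting the stream that follows.
assert list(split_lines(["data: hi\r", "\n\r\n"])) == ["data: hi", ""]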
Some small nits but looks good to me! Tested it and it seems to be working well.
Just want to bring it up again, since my other message maybe wasn't noticed. This approach won't work for beam search with an LLM, or any time the generation might roll back the temporary generation. So it needs to be optional, and only possible to choose when not doing beam search or similar things.
@pseudotensor I don't think this is a concern, because the patches that are generated can update anything in the current history, not just the current stream or future tokens. At the beginning of the request we send the 'current' history, and at the end of the request we also send the full results. In between, we send patches. If anything about the history gets modified, we can patch that too. Here is an example:

[Screen.Recording.2024-01-30.at.14.40.17.mov]

This behaviour is consistent even if the message being modified is in the past; it doesn't need to be part of the current generation. In this recording, you can see that everything is generating with these patches. This is the code I used to test with:

import gradio as gr
import random
import time

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
        return "", history + [[user_message, None]]

    def bot(history):
        bot_message = random.choice(
            ["How are you?", "I love you", "I'm very hungry", "I can change things"]
        )
        history[-1][1] = ""
        i = 0
        for character in bot_message:
            if bot_message == "I can change things" and i == 15:
                print("I have changed")
                history[0][0] = "THIS HAS CHANGED"
            i += 1
            history[-1][1] += character
            time.sleep(0.05)
            yield history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue()
if __name__ == "__main__":
    demo.launch()
Exactly, the diffs list will use a "replace" operation instead of "append" if a string is changed in a manner different from appending. The diffs should be able to transform any JSON structure into any other JSON structure, and "append" is just a shorthand to compress the size of the diffs.
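To illustrate the idea, a hypothetical sketch of how such operations could be applied on the client side (the operation names are taken from the comment above, but this is illustrative only, not Gradio's actual wire format or implementation):

def apply_diff(value, operations):
    """Apply (op, path, payload) operations to a nested JSON-like value."""
    for op, path, payload in operations:
        parent = value
        for key in path[:-1]:
            parent = parent[key]
        last = path[-1]
        if op == "append":
            parent[last] = parent[last] + payload  # extend a string in place
        elif op == "replace":
            parent[last] = payload  # swap the value wholesale
        else:
            raise ValueError(f"unknown op: {op}")
    return value

history = [["Hi", "Hel"]]
# Streaming a new token is an append; rewriting an earlier message is a replace.
history = apply_diff(history, [("append", [0, 1], "lo!")])
history = apply_diff(history, [("replace", [0, 0], "THIS HAS CHANGED")])
assert history == [["THIS HAS CHANGED", "Hello!"]]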
@atesgoral thanks for the write-up! I'll do a deeper dive into dropped-message handling in another PR and go through the points you mentioned, but I did some preliminary tests and didn't find any dropped messages, so hopefully we won't run into any issues there.
Ok thanks. I am not fully aware of the implementation. Thanks for considering the issue.
Another piece in fixing the poor performance of chatbot streaming: previously, we would return the entire history of the chatbot over the network instead of just the next token. As the chat history gets longer and longer, this causes the performance of chat to degrade, especially over slow networks.

Instead of sending the entire updated postprocessed value, we now create the JSON diff between the last generation and the new generation, and yield only the diff. The client merges the diff with its stored copy of the last generation to reconstruct the new one.
To view the difference in the internal API, open the browser inspector's Network tab and inspect the traffic of the queue/data request while running: [demo code elided]