Improve chatbot streaming performance with diffs #7102

Merged: 36 commits merged into main on Jan 31, 2024

Conversation

@aliabid94 (Collaborator) commented Jan 22, 2024

Another piece in fixing the poor performance of chatbot streaming: we used to send the entire chatbot history over the network on every update, instead of just the next token. As the chat history gets longer, this degrades streaming performance, especially over a network.

Instead of sending the entire updated postprocessed value, we now compute the JSON diff between the last generation and the new generation and yield only the diff. The client merges the diff with its stored copy of the last generation to produce the new one.
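
For intuition, here is a minimal sketch of the scheme (illustration only -- not the diff/apply_diff implementation this PR actually adds to gradio/utils.py):

# Rough sketch of the idea only -- not the actual implementation added to
# gradio/utils.py in this PR. A diff is a list of [op, path, value] steps:
# "append" covers the common token-streaming case, "replace" everything else.

def make_diff(old, new):
    if old == new:
        return []
    if isinstance(old, str) and isinstance(new, str) and new.startswith(old):
        return [["append", [], new[len(old):]]]  # only ship the new suffix
    if isinstance(old, list) and isinstance(new, list) and len(old) == len(new):
        return [
            [op, [i] + path, val]
            for i, (o, n) in enumerate(zip(old, new))
            for op, path, val in make_diff(o, n)
        ]
    return [["replace", [], new]]  # fallback: ship the whole new value

def apply_diff(value, steps):
    # Note: mutates nested structures in place; fine for a sketch.
    for op, path, val in steps:
        if not path:  # operation on the root value itself
            value = value + val if op == "append" else val
            continue
        target = value
        for key in path[:-1]:
            target = target[key]
        if op == "append":
            target[path[-1]] += val
        else:
            target[path[-1]] = val
    return value

history = [["hi", "I love yo"]]
assert apply_diff(history, make_diff(history, [["hi", "I love you"]])) == [["hi", "I love you"]]

So when the chatbot yields one more character, the payload is a single small append op pointing at the last message, instead of the entire history.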

To view the difference in the internal API, open the browser inspector's Network tab and watch the traffic of the queue/data request while running:

import gradio as gr
import random
import time

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    tokens = gr.Number()
    clear = gr.Button("Clear")

    def user(user_message, history):
        return "", history + [[user_message, None]]

    def bot(history):
        bot_message = random.choice(["How are you?", "I love you", "I'm very hungry"])
        history[-1][1] = ""
        for i, character in enumerate(bot_message):
            print(character)
            history[-1][1] += character
            time.sleep(0.05)
            yield i, history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, [tokens, chatbot], api_name="bot"
    )
    clear.click(lambda: None, None, chatbot, queue=False)

    with gr.Row():
        append_btn = gr.Button("Append")
        replace_btn = gr.Button("Replace")
    output_text = gr.Textbox()

    def append():
        output = "Hello world"
        for i in range(len(output)):
            time.sleep(0.1)
            yield output[:i+1]
    append_btn.click(append, None, output_text)
    def replace():
        output = "Hello world"
        for i in range(len(output)):
            time.sleep(0.1)
            yield output[i]
    replace_btn.click(replace, None, output_text)

    img_btn = gr.Button("Image")
    img = gr.Image()
    def random_color():
        import numpy as np
        for i in range(10):
            time.sleep(0.5)
            color = np.random.randint(0, 255, 3)
            single_color_image = np.ones((100, 100, 3)) * color
            yield single_color_image.astype(np.uint8)

    img_btn.click(random_color, None, img)

demo.launch()

@gradio-pr-bot (Contributor) commented Jan 22, 2024

🪼 branch checks and previews

Name           Status            URL
Spaces         ready!            Spaces preview
Website        ready!            Website preview
Storybook      ready!            Storybook preview
Visual tests   2 failing tests   Build review
🦄 Changes detected! Details
📓 Notebooks not matching! Details

The demo notebooks don't match the run.py files. Please run this command from the root of the repo and then commit the changes:

pip install nbformat && cd demo && python generate_notebooks.py

Install Gradio from this PR

pip install https://gradio-builds.s3.amazonaws.com/4d67341431274ab05ed6015d034afb22d489fafa/gradio-4.16.0-py3-none-any.whl

Install Gradio Python Client from this PR

pip install "gradio-client @ git+https://github.com/gradio-app/gradio@4d67341431274ab05ed6015d034afb22d489fafa#subdirectory=client/python"

@gradio-pr-bot (Contributor) commented Jan 22, 2024

🦄 change detected

This Pull Request includes changes to the following packages.

Package          Version
@gradio/client   minor
gradio           minor
gradio_client    minor

With the following changelog entry.

⚠️ Warning: invalid changelog entry.

Changelog entry must be either a paragraph or a paragraph followed by a list:

<type>: <description>

Or

<type>:<description>

- <change-one>
- <change-two>
- <change-three>

If you wish to add a more detailed description, please create a highlight entry instead.

⚠️ The changeset file for this pull request has been modified manually, so the changeset generation bot has been disabled. To go back into automatic mode, delete the changeset file.

Something isn't right?

  • Maintainers can change the version label to modify the version bump.
  • If the bot has failed to detect any changes, or if this pull request needs to update multiple packages to different versions or requires a more comprehensive changelog entry, maintainers can update the changelog file directly.

@abidlabs (Member) commented Jan 22, 2024

Nice @aliabid94! My only concern with this, as alluded to by @pngwn, is that this makes the updates stateful and thus could cause issues if an intermediate update is dropped, e.g. due to network issues. I don't think this is a huge issue, but we may want to send the "final" update from a generator as a regular update instead of a diff, to ensure that the final state of the UI is accurate. We should also make sure the clients are robust and don't crash if they receive a diff that isn't compatible with the current value of the component. But overall, this looks like a good approach to me.

@aliabid94 (Collaborator, Author) replied:

cause issues if an intermediate is dropped e.g. due to network issues.

Is this something we actually ever have to worry about? Why would we lose a message?

Agree that the "process_complete" message should send the final data point and not just the diff.

@abidlabs (Member) replied:

Is this something we actually ever have to worry about? Why would we lose a message?

Not great internet connection? Pretty sure it happens, but maybe @pngwn would know if it's a real concern.

@pngwn (Member) commented Jan 22, 2024

As long as the SSE implementation is spec compliant, things shouldn't ever be out of order, even on a poor connection (or even a dropped connection), unless there is a proxy somewhere doing something weird.

Even in the case of dropped connections, SSE + EventSource have built-in reconnection. In this case, the browser will reconnect with a Last-Event-ID header, allowing the server to re-send any dropped messages. I don't know if FastAPI handles this or we need to do it manually.
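
If we did need to handle it manually, the shape of it would be roughly this (a hedged sketch only -- the endpoint path and event_buffer are placeholders for illustration, not gradio's actual server code):

# Hedged sketch of manual Last-Event-ID handling, not gradio's actual server code.
# `event_buffer` and the endpoint path are placeholders for illustration.
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

app = FastAPI()
event_buffer = []  # list of (event_id, payload) kept while a session is live

@app.get("/queue/data")
async def stream(request: Request):
    last_id = request.headers.get("last-event-id")  # sent by EventSource on reconnect

    async def replay_missed_events():
        # Re-send anything after the last id the client saw; a real endpoint
        # would then continue yielding live events.
        for event_id, payload in event_buffer:
            if last_id is None or event_id > int(last_id):
                yield f"id: {event_id}\ndata: {payload}\n\n"

    return StreamingResponse(replay_missed_events(), media_type="text/event-stream")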

So, it is possible there could be issues, but the chances seem slim, and we could resolve them by other means. I'm not convinced we even need the last message to contain everything if we have everything else in place, really, but I'm not against it.

@abidlabs (Member) commented Jan 22, 2024

Also I think we should make gr.Textbox, gr.Code, and gr.HTML instances of StreamingDiff. If anyone is yielding into those components, they are likely appending text strings

(h/t @oobabooga)

@aliabid94 marked this pull request as ready for review on January 24, 2024, 23:13

@aliabid94 (Collaborator, Author):

Ready for review!

gradio/utils.py

@@ -1040,3 +1040,49 @@ def __setitem__(self, key: K, value: V) -> None:

def get_cache_folder() -> Path:
    return Path(os.environ.get("GRADIO_EXAMPLES_CACHE", "gradio_cached_examples"))


def diff(old, new):
Member (review comment):

We should write some tests for this, and to make sure that diff and apply_diff undo each other.
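
Something like a round-trip test would cover the important cases (assuming the helpers end up as diff and apply_diff in gradio.utils, per the names in this thread -- signatures assumed):

# Round-trip sketch, assuming the helpers land as gradio.utils.diff and
# gradio.utils.apply_diff (names taken from this thread, signatures assumed).
from gradio import utils

def test_diff_and_apply_diff_are_inverses():
    cases = [
        ("hello", "hello world"),                 # plain string append
        ([["hi", ""]], [["hi", "t"]]),            # chatbot-style token append
        ([["hi", "old"]], [["CHANGED", "old"]]),  # a past message gets replaced
        ({"a": 1}, {"a": 1, "b": 2}),             # dict key added
        ([1, 2], [1, 2, 3]),                      # list grows
    ]
    for old, new in cases:
        assert utils.apply_diff(old, utils.diff(old, new)) == new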

Collaborator:

Agreed!

gradio/utils.py (outdated review thread, resolved)
@abidlabs (Member) commented Jan 26, 2024

Nice PR @aliabid94! Code looks good and I've tested a bunch of demos and they work great. It's nice that we didn't have to change any of the tests for this PR (testing behavior and not implementation ftw!).

I left a few nits above. I was also wondering if you had done any performance comparisons? The chatinterface_streaming_echo demo feels a little slower to me, but I could be wrong. If we had graphs (plotting latency as a function of the number of characters streamed), that would make for some nice comms.

@abidlabs (Member) left a review:

I'll go ahead and approve since I don't see any issues but let's confirm the performance improvements quantitatively

@freddyaboulton (Collaborator) left a review:

Really nice PR @aliabid94!

Look at the decrease in data transfer for a really long generation:

Main: [screenshot of the network panel showing total data transferred]

PR: [screenshot of the network panel showing total data transferred]

I agree with @abidlabs' comments about testing the diff and apply_diff methods - I think it's very important since we will now compute diffs for all generations, which includes things like every etc.

I suspect this may be slower in some cases, because computing and sending the diff takes more time/data than just sending the full value. Would be good to identify that in the benchmarks @abidlabs proposed.

Again, nice PR 🔥

gradio/utils.py (review thread, resolved)
@aliabid94 (Collaborator, Author) commented Jan 27, 2024

So I did some performance benchmarks - it turns out that when running locally, there's pretty much no difference. This makes sense, because a large payload costs essentially nothing when running locally. There actually seems to be a small cost to using the diffs method on short generations - going to look into this.

When running with share=True, we see a difference once a large amount of text has been generated. In this demo, the history already contains the numbers 1 - 9999, and we then stream the numbers 1 - 1000:

At a rate of 200 msg/s:
main: 17.25s
PR: 9.88s

At a rate of 50 msg/s:
main: ~28s
PR: ~28s

So this PR will likely only show improvements once all the following conditions are met:

  • over a network (including colab)
  • streaming at a very high rate
  • generating lots of text
import gradio as gr
import time

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()

    def user(user_message, history):
        user_message = " ".join([str(i) for i in range(10000)])
        return "", history + [[user_message, None]]

    def bot(history):
        count_to = 1000
        history[-1][1] = ""
        for num in range(count_to):
            history[-1][1] += " " + str(num)
            time.sleep(0.005)
            yield history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot], chatbot
    )
    
demo.launch(share=True)

@aliabid94 (Collaborator, Author) commented Jan 30, 2024

Benchmarks time!!! Demo code is at the bottom.

First we'll look at speed. The table is a little hard to parse; basically, we have 4 dimensions:

  • whether we are running locally or on Spaces
  • whether we are running on main or on this PR
  • the speed at which the demo yields tokens (msg/s)
  • how many tokens are generated overall (msgs)
                  |       20 msg/s        |        100 msg/s        |           200 msg/s
                  | 20     200    2000    | 20     200    10000     | 20     200    10000    20000
local    main     | 1.17   11.3   112.8   | 0.336  2.87   143.01    | 0.23   1.69   79.74    160.4
local    PR       | 1.17   11.4   112.8   | 0.324  2.85   144.01    | 0.23   1.59   81.1     156
spaces   main     | 1.4    10.39  102.3   | 0.38   2.32   106.79    | 0.31   1.33   57       147
spaces   PR       | 1.2    10.45  103.1   | 0.36   2.35   106.94    | 0.25   1.48   56.77    114

(column headers are the total number of messages generated; all durations in seconds)

Takeaways:

  • Running locally, this PR makes no difference compared to main.
  • Somehow, the demo runs faster on Spaces than locally when there are lots of messages. I believe this is because much of the slowdown comes from browser repaint rates, and for some reason (perhaps being inside an iframe?) the "frame rate" on Spaces seems slower.
  • It is only at 20000 messages generated at 200 msg/s that we see a large difference between main and this PR. This makes sense because I was on a pretty fast internet connection (30 mbps): yielding a new token on main with a 10k-20k pre-existing message history roughly requires sending a 15 KB payload, i.e. 15 KB * 200 msg/s = 24 mbps, so at that level we were hitting the limits of the connection (quick sanity check below).
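
A quick back-of-the-envelope check of that figure (the 15 KB payload size is the rough measurement quoted above):

# Sanity check of the bandwidth estimate above; 15 KB is the rough payload size quoted.
payload_kb = 15          # approximate size of one full-history update at 10k-20k messages
msgs_per_sec = 200
mbps = payload_kb * 8 * msgs_per_sec / 1000  # kilobytes -> kilobits -> megabits per second
print(mbps)              # 24.0 -- close to the ceiling of a ~30 mbps connection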

When we look at the amount of data transferred, we see a huge difference between main and this PR:

msgs     data (main)   data (PR)
20       8 kB          8 kB
200      123 kB        55 kB
10000    242 MB        2.6 MB
20000    1.03 GB       5.5 MB

Yielding 20k tokens takes 1.03 GB on main vs 5.5 MB on this PR - that's a huge difference! (We could still make considerable linear gains here, because for every token we also send lots of extra info, e.g. yield time, average duration, etc.)

For further gains in speed, we should now focus on improving UI performance rather than network changes.

Demo:

import gradio as gr
import time

with gr.Blocks() as demo:
    with gr.Row():
        messages_per_second = gr.Number(100, label="Messages per second")
        number_of_messages = gr.Number(200, label="Number of messages")

    chatbot = gr.Chatbot()
    start = gr.Button("Start")

    def bot(messages_per_second, number_of_messages):
        history = [[f"Streaming {number_of_messages} msgs at {messages_per_second} msg/s ", ""]]
        start_time = time.time()
        for num in range(number_of_messages):
            history[-1][1] += " " + str(num)
            time.sleep(1 / messages_per_second)
            yield history, None
        end_time = time.time()
        yield history, end_time - start_time
        
    
    with gr.Row():
        expected_duration = gr.Number(label="Expected Duration (w/o Overhead)")
        backend_duration = gr.Number(label="Duration (for Backend)")
        frontend_duration = gr.Number(label="Duration (for Frontend)")

    start.click(fn=None, js="""() => {
        window.start_time = performance.now();
    }""")
    start.click(lambda messages_per_second, number_of_messages: number_of_messages / messages_per_second, [messages_per_second, number_of_messages], [expected_duration])
    start.click(bot, [messages_per_second, number_of_messages], [chatbot, backend_duration]).then(
        fn=None,
        outputs=frontend_duration,
        js="""() => {
        let now = performance.now();
        return (now - window.start_time) / 1000;                                                                                    
    }""")

    
demo.launch()

@abidlabs (Member):

Thank you @aliabid94! Very interesting findings and nice way to benchmark as well. Let's go ahead with this once we have some unit tests in place.

@abidlabs (Member):

cc @pseudotensor @oobabooga for visibility

@pseudotensor (Contributor) commented Jan 30, 2024

Yes, it's a good idea I recommended, but it won't work for beam search or other cases where past tokens are modified. So it needs to be optional.

@atesgoral (Contributor) commented Jan 30, 2024

The talk about "dropped messages" piqued my interest. I've dealt with unexplained dropping of SSE messages in a few places and it always turned out to be some flaw in the Event Stream parser. Since SSE is over TCP, you're not going to have random packets get dropped. But a flawed parser can fail to parse certain TCP chunks properly, leading to some events falling through the cracks.

I've checked the SSE parser (I pedantically call these Event Stream parsers instead) in this repo and it amazingly follows exactly the same approach I took when implementing our own -- by copy-pasting the spec verbatim and filling between the lines with code! :)

The parser here looks rock solid! But it may not be handling a rare-but-possible case that could lead to mishandling of a TCP chunk, therefore creating the illusion of a dropped event (since the parser won't fire it):

When an event chunk happens to be fragmented exactly after the CR of a CRLF sequence, the next chunk may fail to parse, because the junk LF character at the beginning of that chunk will throw the parser off. I have a unit test case that you may want to repeat in this repo to see if your parser fails at this edge case:

https://github.com/Shopify/event_stream_parser/blob/bdaaa26784084ddf16a699ede534a5546824922e/test/event_stream_parser_test.rb#L189

Here's how I explicitly handle it:

https://github.com/Shopify/event_stream_parser/blob/bdaaa26784084ddf16a699ede534a5546824922e/lib/event_stream_parser.rb#L41

Again, this is rare, and your parser might already be implicitly handling it, but it's still worth throwing in an additional unit test to check for this.
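
In Python terms, the edge case (and the kind of fix the linked parser uses) looks roughly like this -- a toy line splitter for illustration, not the parser in this repo:

# Toy illustration of the CR/LF-split edge case and the fix -- not the parser
# used in this repo. A chunk ending in a bare CR must remember to swallow the
# LF that may arrive at the start of the next chunk.
class LineSplitter:
    def __init__(self):
        self.buffer = ""
        self.pending_cr = False  # last chunk ended with a bare CR

    def feed(self, chunk):
        if self.pending_cr and chunk.startswith("\n"):
            chunk = chunk[1:]  # this LF belongs to the CR from the previous chunk
        self.pending_cr = chunk.endswith("\r")
        self.buffer += chunk
        normalized = self.buffer.replace("\r\n", "\n").replace("\r", "\n")
        *lines, self.buffer = normalized.split("\n")
        return lines

splitter = LineSplitter()
assert splitter.feed("data: hello\r") == ["data: hello"]      # chunk ends right after the CR
assert splitter.feed("\ndata: world\r\n") == ["data: world"]  # leading LF is not a blank line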

Looking forward to this massive performance improvement landing! Gradio is such a joy!

@pngwn (Member) left a review:

Some small nits but looks good to me! Tested it and it seems to be working well.

client/js/src/utils.ts (two outdated review threads, resolved)
client/js/src/client.ts (review thread, resolved)
@pseudotensor (Contributor):

(quoting @pngwn's earlier comment above about SSE ordering and built-in reconnection)

Just want to bring this up again, since my other message may not have been noticed.

This approach won't work for beam search with an LLM, or any time the generation might roll back the temporary generation.

So it needs to be optional, and only chosen when not doing beam search or similar things.

@pngwn (Member) commented Jan 30, 2024

@pseudotensor I don't think this is a concern, because the patches that are generated can update anything in the current history, not just the current stream or future tokens. At the beginning of the request we send the 'current' history, and at the end of the request we also send the full result. In between, we send patches. If anything about the history gets modified, we can patch that too.

Here is an example:

[Screen recording: Screen.Recording.2024-01-30.at.14.40.17.mov]

This behaviour is consistent even if the message being modified is in the past; it doesn't need to be part of the current generation.

In this screenshot, you can see that everything is generated with these append patches until we modify the history, at which point we both replace that history element and append the next token.

[Screenshot: 2024-01-30 at 14 42 08]

This is the code I used to test with:

import gradio as gr
import random
import time

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
        return "", history + [[user_message, None]]

    def bot(history):
        bot_message = random.choice(
            ["How are you?", "I love you", "I'm very hungry", "I can change things"]
        )

        history[-1][1] = ""
        i = 0
        for character in bot_message:
            if bot_message == "I can change things" and i == 15:
                print("I have changed")
                history[0][0] = "THIS HAS CHANGED"

            i += 1
            history[-1][1] += character
            time.sleep(0.05)
            yield history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue()
if __name__ == "__main__":
    demo.launch()

@aliabid94 (Collaborator, Author), replying to @pngwn's point above that the patches can update anything in the current history:

Exactly, the diffs list will use a "replace" operation instead of "append" if a string is changed in a manner different from appending. The diffs should be able to transform any JSON structure into any other JSON structure, and "append" is just a shorthand to compress the size of the diffs.
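
Roughly, for the history-modifying demo above, the two kinds of operations would look something like this (illustrative shapes only -- the exact encoding produced by gradio.utils.diff may differ):

# Illustrative shapes only; the exact operation encoding may differ from
# what gradio.utils.diff actually emits. Paths index into the chatbot history.
old = [["I can change things", "I can ch"]]

# A token streamed onto the end of the last bot message -> a cheap "append" op
new_token = [["I can change things", "I can cha"]]
diff_token = [["append", [0, 1], "a"]]

# A past message rewritten mid-generation -> "replace" at that path, plus the append
new_rewrite = [["THIS HAS CHANGED", "I can cha"]]
diff_rewrite = [["replace", [0, 0], "THIS HAS CHANGED"], ["append", [0, 1], "a"]]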

@aliabid94 (Collaborator, Author):

@atesgoral thanks for the write-up! I'll do a deeper dive into dropped-message handling in another PR and go through the points you mentioned, but I did some preliminary tests and didn't find any dropped messages, so hopefully we won't run into any issues there.

@pseudotensor (Contributor), quoting the two replies above:

Ok thanks. I am not fully aware of the implementation. Thanks for considering the issue.

@aliabid94 merged commit 68a54a7 into main on Jan 31, 2024
13 checks passed
@aliabid94 deleted the diff_chatbot_streaming branch on January 31, 2024, 18:39
@pngwn mentioned this pull request on Jan 31, 2024