
if user exits browser or tab, gradio not cleaning up process/threads #4016

Closed
1 task done
pseudotensor opened this issue Apr 29, 2023 · 18 comments · Fixed by h2oai/h2ogpt#98
Assignees
Labels
bug Something isn't working

Comments

@pseudotensor
Contributor

pseudotensor commented Apr 29, 2023

Describe the bug

https://huggingface.co/docs/transformers/internal/generation_utils#transformers.TextIteratorStreamer

When using this, if user exits tab or closes browser, generation continues in background indefinitely.

Is there an existing issue for this?

  • I have searched the existing issues

Reproduction

import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    streamer = TextIteratorStreamer(tok)


    def respond(message, chat_history):
        from threading import Thread

        # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
        inputs = tok([message], return_tensors="pt")
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1000)
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        bot_message = ""
        chat_history.append([message, bot_message])
        for new_text in streamer:
            bot_message += new_text
            chat_history[-1][1] = bot_message
            yield chat_history
        return

    msg.submit(respond, [msg, chatbot], chatbot, queue=True)
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue(concurrency_count=1)
demo.launch()

Enter "The sky is" and soon after close the tab. The generation will continue. For larger models this is a problem: gradio thinks the user is gone and the queue slot is free, but the generation threads will now overlap.

Also, note that adding a raise StopIteration() has no effect on model.generate(); it only terminates the respond generator, so GPU usage continues in the background. i.e.

import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    streamer = TextIteratorStreamer(tok)


    def respond(message, chat_history):
        from threading import Thread

        # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
        inputs = tok([message], return_tensors="pt")
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1000)
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        bot_message = ""
        chat_history.append([message, bot_message])
        for new_text in streamer:
            bot_message += new_text
            chat_history[-1][1] = bot_message
            if len(bot_message) > 50:
                raise StopIteration()
            yield chat_history
        return

    msg.submit(respond, [msg, chatbot], chatbot, queue=True)
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue(concurrency_count=1)
demo.launch()

This terminates output to the gradio chatbot, but the other thread continues to use the GPU. Joining the thread doesn't help; that just blocks until generation finishes. Methods for raising an exception inside a thread exist, but they don't solve the first issue and are not standard practice.

Screenshot

(screenshot attached to the original issue)

Logs

None required.

System Info

Latest gradio; Chrome browser.

Severity

blocking all usage of gradio

@freddyaboulton
Collaborator

Hi @pseudotensor ! Thanks for filing the issue. This is a bit tricky.

There isn't an official API for killing a running thread in Python AFAIK, so I don't think there's anything gradio can do (plus we don't really know when a dev starts a thread from their event handler).

The best option would be to somehow tell generate that it should stop running, which would stop its parent thread. However, that would require TextIteratorStreamer to provide an API for stopping, which it currently doesn't.

Tagging @gante here who implemented TextIteratorStreamer - he's planning a rewrite but that won't land for at least a couple of weeks.
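A workaround on the app side (not something gradio does for you) is cooperative cancellation: the worker checks a shared flag on every iteration and exits early when it is set. transformers' generate does accept a stopping_criteria argument that can poll such a flag; the thread-level pattern itself, sketched here with the stdlib only and illustrative names (generate_tokens stands in for model.generate):

```python
import threading
import time

def generate_tokens(stop_event, out):
    """Stand-in for model.generate: emit tokens until done or told to stop."""
    for i in range(1000):
        if stop_event.is_set():   # cooperative cancellation point
            break
        out.append(f"tok{i}")
        time.sleep(0.001)

stop_event = threading.Event()
tokens = []
worker = threading.Thread(target=generate_tokens, args=(stop_event, tokens))
worker.start()

time.sleep(0.02)      # pretend we just noticed the client is gone
stop_event.set()      # ask the worker to stop at its next check
worker.join(timeout=5)
```

The key design point is that cancellation is voluntary: the generation loop must contain a check, which is exactly the API TextIteratorStreamer lacked at the time.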

@pseudotensor
Contributor Author

Is it not related to https://github.com/gradio-app/gradio/pull/2783/files ?

@freddyaboulton
Collaborator

Hi @pseudotensor. Yes, it's related. The main difference is that in #2783 the function is not spawning a new thread; it yields each iteration of a Python iterator. In that case, gradio can check whether the client is still alive at each iteration and, if not, terminate the function and clean up. With TextIteratorStreamer, a new thread is spawned that gradio has no idea about, and there's no API to tell it to stop running.
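To illustrate why the yield-based path is easier to clean up: a framework that holds a plain generator can call .close() on it, which raises GeneratorExit inside the function so its finally block runs; there is no analogous hook for a raw thread. A minimal stdlib sketch:

```python
cleaned_up = []

def respond():
    """Yield-based handler: the framework holds the generator object."""
    try:
        i = 0
        while True:
            yield f"chunk {i}"
            i += 1
    finally:
        # runs when the consumer calls .close(), i.e. on client disconnect
        cleaned_up.append(True)

gen = respond()
first = next(gen)   # consume one chunk
gen.close()         # what a framework can do when the client disappears
```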

@pseudotensor pseudotensor reopened this Jun 16, 2023
@pseudotensor
Contributor Author

Didn't mean to close; it was closed by the other repo.

@pseudotensor
Contributor Author

I basically need to know in gradio when user closes tab. How can I do that?

@abidlabs
Member

As far as I know, there is no way to detect if a user closes a browser tab. @aliabid94 is that correct?

@NarakuLite

I basically need to know in gradio when user closes tab. How can I do that?

+1, I really need that too. I have a gradio interface that allows users to upload their files to storage, but the files remain even after the users exit the interface, and the storage fills up. We need something like a change event for gr.State() that triggers when the user leaves the page. It would help a lot of users, not only me. Am I right?

@osbm

osbm commented Aug 9, 2023

This would help a lot. My gradio space can only serve 1 user at a time, and if a process keeps running too long after the client is gone, I would like to terminate the thread.

@Wwilcz2

Wwilcz2 commented Jan 4, 2024

I have the same problem. I have a small testing script:

import ctypes
import threading
import time

import gradio as gr


def thread_function():
    while True:
        print("Thread is running")
        time.sleep(2)

def terminate_thread(thread):
    if not thread.is_alive():
        return

    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(thread.ident), ctypes.py_object(SystemExit))
    if res > 1:
        ctypes.pythonapi.PyThreadState_SetAsyncExc(thread.ident, 0)
        print('Exception raise failure')

def create_thread():
    try:
        # Creating a thread
        thread = threading.Thread(target=thread_function)
        thread.start()

        # Function's main task
        while True:
            print("Function running...")
            yield "function running"
            time.sleep(1)

    except Exception as e:
        print(f"Function interrupted with exception: {e}")

    finally:
        # Ensure the thread is terminated when the function ends
        terminate_thread(thread)
        print("Thread terminated.")

with gr.Blocks(title="Gradio Lab") as demo:
  button = gr.Button("Button")
  box = gr.Textbox(label="Text", interactive = False)
  button.click(create_thread, outputs=box)

# Run the interface
demo.launch(debug=True, share=True)

When a user closes the tab, the create_thread() function stops running, but thread_function() continues. I tried to see whether create_thread() raises an exception, but it doesn't.
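The ctypes trick in the script above does work, with an important caveat: the asynchronous exception is only delivered while the target thread is executing Python bytecode, so a thread blocked inside a C call (time.sleep, or a CUDA kernel inside model.generate) won't see it until that call returns. A stdlib-only sketch of the mechanism against a busy pure-Python loop (function names here are illustrative):

```python
import ctypes
import threading
import time

def raise_in_thread(thread, exc_type=SystemExit):
    """Asynchronously raise exc_type in `thread` (CPython-specific hack).

    The exception is only delivered while the target is executing Python
    bytecode; a thread blocked inside a C call will not see it until that
    call returns to the interpreter.
    """
    if not thread.is_alive():
        return
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
        ctypes.c_long(thread.ident), ctypes.py_object(exc_type))
    if res > 1:  # should not happen; undo to be safe
        ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(thread.ident), None)

progress = []

def busy_worker():
    while True:            # pure-Python loop, so the exception can land here
        progress.append(1)

t = threading.Thread(target=busy_worker, daemon=True)
t.start()
time.sleep(0.05)
raise_in_thread(t)         # SystemExit ends the thread at its next bytecode check
t.join(timeout=5)
```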

@pseudotensor
Contributor Author

pseudotensor commented Feb 14, 2024

Hi @abidlabs ,

If I store something like a CUDA model in gr.State(), how am I supposed to clear it after the user exits? Any ideas how to work around this gradio issue?

i.e. it's a simple scenario:

  1. user logs in
  2. user selects a model and it loads. Or maybe it's a chroma database or whatever.
  3. user closes tab
    Then there is no way to clear that model, database, etc.

Why does gradio/fastapi preserve gr.State forever? Do other app builders like streamlit have some mechanism? Or is there a way to inspect all memory across all users and manage it from gradio/fastapi?

The reason this is important is that the opposite, a global rather than per-session state, also doesn't work: if one makes a model or database global, then due to how Python works, I cannot free that memory no matter what. There's always some reference somewhere in gradio, in some thread, that points to it.

So this leaves me in a wedged state w.r.t. gradio.
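That "some reference somewhere" failure mode can be demonstrated with the stdlib alone: as long as any strong reference survives, garbage collection cannot reclaim the object, and a weakref shows whether it was actually freed. A minimal sketch (CPython behavior; Model is a stand-in name):

```python
import gc
import weakref

class Model:
    """Stand-in for a large model held in memory."""
    pass

m = Model()
probe = weakref.ref(m)   # a weakref does not keep the object alive
hidden = [m]             # simulate a stray reference held by a framework

del m
gc.collect()
alive_with_hidden_ref = probe() is not None   # still alive: `hidden` refers to it

hidden.clear()
gc.collect()
alive_after_clear = probe() is not None       # now actually freed
```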

@pseudotensor
Contributor Author

pseudotensor commented Feb 14, 2024

GPT-4 suggests something like a ping keyed on the session hash. The ping detects the lifetime of the connection: if the UI doesn't send a ping within some developer-set interval, then the session-hash-based objects should all be cleaned up.

With code like this:

JavaScript in the Blocks head:

<script>
function sendPing() {
    // Use Gradio's provided session_hash directly
    const sessionHash = session_hash; // Gradio automatically sets this variable

    fetch("/ping", {
        method: "POST",
        headers: {
            "Content-Type": "application/json",
        },
        body: JSON.stringify({session_id: sessionHash}),
    })
    .then(response => response.json())
    .then(data => console.log("Ping response:", data))
    .catch(error => console.error('Error sending ping:', error));
}

// Send a ping every 10 seconds
setInterval(sendPing, 10000);
</script>

and something like this in the App.create_app():

from pydantic import BaseModel
class PingRequest(BaseModel):
    session_id: str


# later in create_app:

        # Dictionary to keep track of the last ping time for each session
        active_sessions = {}
        from datetime import datetime, timedelta
        @app.post("/ping")
        async def receive_ping(request_body: PingRequest):
            session_id = request_body.session_id
            # record the last ping time for this session
            active_sessions[session_id] = datetime.now()
            return {"message": "Ping received"}

        # Background task to clean up inactive sessions
        import asyncio
        async def cleanup_sessions():
            while True:
                now = datetime.now()
                for session_id, last_ping in list(active_sessions.items()):
                    if now - last_ping > timedelta(seconds=20):  # adjust timeout as needed
                        # Logic to clean up the gr.State() object associated with session_id
                        print(f"Cleaning up session: {session_id}")
                        del active_sessions[session_id]
                await asyncio.sleep(10)  # check every 10 seconds

and

        @app.on_event("startup")
        @app.get("/startup-events")
        async def startup_events():
            if not app.startup_events_triggered:
                app.get_blocks().startup_events()
                # schedule the coroutine; calling cleanup_sessions() bare would never run it
                asyncio.create_task(cleanup_sessions())
                app.startup_events_triggered = True
                return True
            return False

but I don't know how to pass the session hash through in the JavaScript, even though I can see in other logging that a session hash is used.

Then one would use that clean-up to clear any gr.State objects for the given session hash.

I presume already gr.State objects are indexed by the session hash.
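The server-side bookkeeping for the ping idea can be sketched without FastAPI at all; SessionTracker and its method names are illustrative only, not part of gradio:

```python
from datetime import datetime, timedelta

class SessionTracker:
    """Illustrative last-ping bookkeeping, keyed by session hash."""

    def __init__(self, timeout_seconds=20):
        self.timeout = timedelta(seconds=timeout_seconds)
        self.last_ping = {}

    def ping(self, session_hash, now=None):
        self.last_ping[session_hash] = now or datetime.now()

    def expire(self, now=None):
        """Drop and return sessions whose last ping is older than the timeout."""
        now = now or datetime.now()
        dead = [h for h, t in self.last_ping.items() if now - t > self.timeout]
        for h in dead:
            del self.last_ping[h]
        return dead

tracker = SessionTracker(timeout_seconds=20)
t0 = datetime(2024, 2, 14, 12, 0, 0)
tracker.ping("abc", now=t0)
tracker.ping("xyz", now=t0 + timedelta(seconds=15))
expired = tracker.expire(now=t0 + timedelta(seconds=25))  # "abc" is 25s stale
```

The expired hashes would then drive the cleanup of the corresponding gr.State entries.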

@pseudotensor
Contributor Author

pseudotensor commented Feb 14, 2024

The problem right now is that no matter what, even a small use of gr.State() leads to a memory leak, and over a long enough time, or with heavy use, one eventually gets an OOM. I'm encountering this all the time with gradio.

@pseudotensor
Contributor Author

pseudotensor commented Feb 14, 2024

Another GPT-4 idea is a websocket that the client initiates. When the client goes down (UI or API), the fastapi/gradio app can clean up.

I tried to hack something together, but it's not working:

    head_js = """
<script type="text/javascript">
        document.addEventListener("DOMContentLoaded", function() {
            // Determine the correct WebSocket scheme based on the page's protocol
            let wsScheme = window.location.protocol === "https:" ? "wss" : "ws";
            // Build the WebSocket URL
            let wsUrl = `${wsScheme}://${window.location.host}/ws`;

            // Create a new WebSocket connection
            const ws = new WebSocket(wsUrl);

            ws.onopen = function() {
                console.log("WebSocket connection established");
                // Send a message to the server if needed
                ws.send("Hello, server!");
            };

            ws.onmessage = function(event) {
                console.log("Message from server: ", event.data);
            };

            ws.onclose = function() {
                console.log("WebSocket connection closed");
            };

            ws.onerror = function(error) {
                console.log("WebSocket error: ", error);
            };
        });
</script>
"""

and:

        from starlette.websockets import WebSocketDisconnect
        @app.websocket("/ws")
        async def websocket_endpoint(websocket: WebSocket):
            await websocket.accept()
            try:
                while True:
                    # You can also process incoming messages here
                    data = await websocket.receive_text()
                    print(f"Message from client: {data}")
            except WebSocketDisconnect:
                print("Client disconnected")
                # Handle cleanup or update status here

I can't even get this part to work, let alone the cleanup. But websockets seem like a good way to go.

And everything in the docs example is hardcoded to localhost, which doesn't work here:

https://fastapi.tiangolo.com/fa/advanced/websockets/

@pseudotensor
Contributor Author

pseudotensor commented Feb 14, 2024

Seems streamlit does this properly with websockets: streamlit/streamlit#6166.

@pseudotensor
Contributor Author

pseudotensor commented Feb 14, 2024

At least this JavaScript connects from the UI to the app:

    js = """
<script>
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
const host = window.location.host;
const wsUrl = `${protocol}//${host}/ws`;

var ws = new WebSocket(wsUrl);

ws.onopen = function() {
  console.log("Connected to the WebSocket at " + wsUrl);
  ws.send("Hello from Gradio!");
};

ws.onmessage = function(event) {
  console.log("Received message: " + event.data);
};

ws.onerror = function(error) {
  console.log("WebSocket error: ", error);
};

ws.onclose = function(event) {
  console.log("WebSocket connection closed: ", event);
};
</script>
"""
    demo = gr.Blocks(theme=theme, css=css_code, title=page_title, analytics_enabled=False, head=js)

with:

        from starlette.websockets import WebSocketDisconnect
        @app.websocket("/ws")
        async def websocket_endpoint(websocket: WebSocket):
            await websocket.accept()
            try:
                while True:
                    # You can also process incoming messages here
                    data = await websocket.receive_text()
                    print(f"Message from client: {data}")
            except WebSocketDisconnect:
                print("Client disconnected")
                # Handle cleanup or update status here
                # fake clean-up: clears *all* session hashes, which is wrong.
                hashes_cleared = list(app.state_holder.session_data.keys())
                for session_hash in hashes_cleared:
                    app.state_holder.session_data.pop(session_hash)

                print("Client disconnected: %s" % hashes_cleared, flush=True)

and required imports.

@pseudotensor
Contributor Author

pseudotensor commented Feb 14, 2024

The above works perfectly to detect when the user closes the browser.

@abidlabs Please, can you guys incorporate this in order to clear the gr.State for that session?

Just need to:

  1. from the client (UI or API), pass session_hash along with the websocket -- not sure how to do this
  2. clear gr.State for that session_hash on reaching except WebSocketDisconnect (I can see state info in the app during websocket disconnect, so it must be possible)
  3. add asyncio.sleep(1) or similar in that while loop
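For step 1, one hypothetical route (not an existing gradio mechanism) is to append the hash to the websocket URL as a query parameter in the JavaScript, e.g. `${wsUrl}?session_hash=${sessionHash}`, and read it server-side. Starlette exposes websocket.query_params for that; the parsing itself is plain stdlib:

```python
from urllib.parse import parse_qs, urlparse

def session_hash_from_ws_url(url):
    """Pull a session_hash query parameter out of a websocket URL.

    In a FastAPI/Starlette endpoint the same value is available via
    websocket.query_params.get("session_hash").
    """
    values = parse_qs(urlparse(url).query).get("session_hash")
    return values[0] if values else None

h = session_hash_from_ws_url("ws://localhost:7860/ws?session_hash=abc123")
missing = session_hash_from_ws_url("ws://localhost:7860/ws")
```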

@pseudotensor
Contributor Author

pseudotensor commented Feb 14, 2024

What's really needed is a per-state callback on disconnect, like streamlit has. That way one can properly handle things like CUDA models, moving the model off the GPU and clearing the torch cache, but only for that specific state.

I confirmed at least manually that this would work with the above approach. During disconnect, I found my model state and did:

                app.state_holder.session_data[session_hash][1]['model'].cpu()
                import torch
                import gc
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    torch.cuda.ipc_collect()
                    gc.collect()

before popping the item, because CUDA memory would remain cached unless the model is moved to CPU first and the cache is then cleared.

i.e.

        def clear_torch_cache():
            import torch
            import gc
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()
                gc.collect()

        from starlette.websockets import WebSocketDisconnect
        @app.websocket("/ws")
        async def websocket_endpoint(websocket: WebSocket):
            a = app.auth
            await websocket.accept()
            try:
                while True:
                    # You can also process incoming messages here
                    data = await websocket.receive_text()
                    print(f"Message from client: {data}")
            except WebSocketDisconnect:
                hashes_cleared = list(app.state_holder.session_data.keys())
                for session_hash in hashes_cleared:
                    states = app.state_holder.session_data[session_hash]
                    if isinstance(states[1], dict) and 'model' in states[1] and hasattr(states[1]['model'], 'cpu'):
                        states[1]['model'].cpu()
                        clear_torch_cache()
                    app.state_holder.session_data.pop(session_hash)
                print("Client disconnected: %s" % hashes_cleared, flush=True)
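The per-session callback mechanism requested above could be sketched like this; CleanupRegistry and its methods are entirely hypothetical, not a gradio API:

```python
class CleanupRegistry:
    """Hypothetical per-session cleanup callbacks (not a gradio API)."""

    def __init__(self):
        self._callbacks = {}

    def register(self, session_hash, callback):
        """Attach a cleanup callback to a session (e.g. move a model off GPU)."""
        self._callbacks.setdefault(session_hash, []).append(callback)

    def disconnect(self, session_hash):
        """Run and forget every callback registered for this session."""
        for cb in self._callbacks.pop(session_hash, []):
            cb()

registry = CleanupRegistry()
freed = []
registry.register("abc", lambda: freed.append("model-off-gpu"))
registry.register("abc", lambda: freed.append("db-closed"))
registry.disconnect("abc")       # runs both callbacks, in registration order
registry.disconnect("abc")       # second disconnect is a no-op
```

The websocket handler would call disconnect(session_hash) from the except WebSocketDisconnect branch instead of hardcoding model-specific cleanup there.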

@pseudotensor
Contributor Author

For actual gradio changes, it would be best to edit client.ts directly and always establish that websocket.
