Skip to content

Commit

Permalink
Use ContextVar instead of threading.local() (#5625)
Browse files Browse the repository at this point in the history
* Use ContextVar instead of threading.local()

* Test

* add changeset

* Revert "Use ContextVar instead of threading.local()"

This reverts commit 2a76eb4.

* delete changeset

* Un-revert "Use ContextVar instead of threading.local()"

This reverts commit 1481ce0.

* add changeset

* Add Request in LocalContext

* Sync + Async test

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
cbensimon and gradio-pr-bot committed Sep 21, 2023
1 parent fb5964f commit 9ccc479
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 19 deletions.
5 changes: 5 additions & 0 deletions .changeset/chatty-adults-reply.md
@@ -0,0 +1,5 @@
---
"gradio": patch
---

fix:Use ContextVar instead of threading.local()
10 changes: 7 additions & 3 deletions gradio/blocks.py
Expand Up @@ -91,9 +91,9 @@


def in_event_listener():
from gradio import context
from gradio.context import LocalContext

return getattr(context.thread_data, "in_event_listener", False)
return LocalContext.in_event_listener.get()


def updateable(fn):
Expand Down Expand Up @@ -1135,7 +1135,11 @@ async def call_function(
start = time.time()

fn = utils.get_function_with_locals(
block_fn.fn, self, event_id, in_event_listener
fn=block_fn.fn,
blocks=self,
event_id=event_id,
in_event_listener=in_event_listener,
request=request,
)

if iterator is None: # If not a generator function that has already run
Expand Down
9 changes: 7 additions & 2 deletions gradio/context.py
Expand Up @@ -2,11 +2,12 @@

from __future__ import annotations

import threading
from contextvars import ContextVar
from typing import TYPE_CHECKING

if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from gradio.blocks import BlockContext, Blocks
from gradio.routes import Request


class Context:
Expand All @@ -17,4 +18,8 @@ class Context:
hf_token: str | None = None # The token provided when loading private HF repos


thread_data = threading.local()
class LocalContext:
blocks: ContextVar[Blocks | None] = ContextVar("blocks", default=None)
in_event_listener: ContextVar[bool] = ContextVar("in_event_listener", default=False)
event_id: ContextVar[str | None] = ContextVar("event_id", default=None)
request: ContextVar[Request | None] = ContextVar("request", default=None)
13 changes: 7 additions & 6 deletions gradio/helpers.py
Expand Up @@ -1095,22 +1095,23 @@ def __init__(self, target: Block | None, _data: Any):


def log_message(message: str, level: Literal["info", "warning"] = "info"):
from gradio import context
from gradio.context import LocalContext

if not hasattr(context.thread_data, "blocks"): # Function called outside of Gradio
blocks = LocalContext.blocks.get()
if blocks is None: # Function called outside of Gradio
if level == "info":
print(message)
elif level == "warning":
warnings.warn(message)
return
if not context.thread_data.blocks.enable_queue:
if not blocks.enable_queue:
warnings.warn(
f"Queueing must be enabled to issue {level.capitalize()}: '{message}'."
)
return
context.thread_data.blocks._queue.log_message(
event_id=context.thread_data.event_id, log=message, level=level
)
event_id = LocalContext.event_id.get()
assert event_id
blocks._queue.log_message(event_id=event_id, log=message, level=level)


@document()
Expand Down
22 changes: 14 additions & 8 deletions gradio/utils.py
Expand Up @@ -46,7 +46,7 @@
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from gradio.blocks import Block, BlockContext, Blocks
from gradio.components import Component
from gradio.routes import App
from gradio.routes import App, Request

JSON_PATH = os.path.join(os.path.dirname(gradio.__file__), "launches.json")

Expand Down Expand Up @@ -660,19 +660,25 @@ def wrapper(*args, **kwargs):


def get_function_with_locals(
fn: Callable, blocks: Blocks, event_id: str | None, in_event_listener: bool
fn: Callable,
blocks: Blocks,
event_id: str | None,
in_event_listener: bool,
request: Request | None,
):
def before_fn(blocks, event_id):
from gradio.context import thread_data
from gradio.context import LocalContext

thread_data.blocks = blocks
thread_data.in_event_listener = in_event_listener
thread_data.event_id = event_id
LocalContext.blocks.set(blocks)
LocalContext.in_event_listener.set(in_event_listener)
LocalContext.event_id.set(event_id)
LocalContext.request.set(request)

def after_fn():
from gradio.context import thread_data
from gradio.context import LocalContext

thread_data.in_event_listener = False
LocalContext.in_event_listener.set(False)
LocalContext.request.set(None)

return function_wrapper(
fn, before_fn=before_fn, before_args=(blocks, event_id), after_fn=after_fn
Expand Down
44 changes: 44 additions & 0 deletions test/test_helpers.py
@@ -1,3 +1,4 @@
import asyncio
import json
import os
import shutil
Expand Down Expand Up @@ -826,3 +827,46 @@ def greet(s):
["Letter c", "info"],
["Too short!", "warning"],
]


@pytest.mark.asyncio
@pytest.mark.parametrize("async_handler", [True, False])
async def test_info_isolation(async_handler: bool):
async def greet_async(name):
await asyncio.sleep(2)
gr.Info(f"Hello {name}")
return name

def greet_sync(name):
time.sleep(2)
gr.Info(f"Hello {name}")
return name

demo = gr.Interface(greet_async if async_handler else greet_sync, "text", "text")
demo.queue(concurrency_count=2).launch(prevent_thread_lock=True)

async def session_interaction(name, delay=0):
await asyncio.sleep(delay)
async with websockets.connect(
f"{demo.local_url.replace('http', 'ws')}queue/join"
) as ws:
log_messages = []
while True:
msg = json.loads(await ws.recv())
if msg["msg"] == "send_data":
await ws.send(json.dumps({"data": [name], "fn_index": 0}))
if msg["msg"] == "send_hash":
await ws.send(json.dumps({"fn_index": 0, "session_hash": name}))
if msg["msg"] == "log":
log_messages.append(msg["log"])
if msg["msg"] == "process_completed":
break
return log_messages

alice_logs, bob_logs = await asyncio.gather(
session_interaction("Alice"),
session_interaction("Bob", delay=1),
)

assert alice_logs == ["Hello Alice"]
assert bob_logs == ["Hello Bob"]

0 comments on commit 9ccc479

Please sign in to comment.