Skip to content

Commit

Permalink
Revert "Use ContextVar instead of threading.local()"
Browse files Browse the repository at this point in the history
This reverts commit 2a76eb4.
  • Loading branch information
cbensimon committed Sep 20, 2023
1 parent e858899 commit 1481ce0
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 20 deletions.
4 changes: 2 additions & 2 deletions gradio/blocks.py
Expand Up @@ -91,9 +91,9 @@


def in_event_listener():
from gradio.context import LocalContext
from gradio import context

return LocalContext.in_event_listener.get()
return getattr(context.thread_data, "in_event_listener", False)


def updateable(fn):
Expand Down
7 changes: 2 additions & 5 deletions gradio/context.py
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from contextvars import ContextVar
import threading
from typing import TYPE_CHECKING

if TYPE_CHECKING: # Only import for type checking (is False at runtime).
Expand All @@ -17,7 +17,4 @@ class Context:
hf_token: str | None = None # The token provided when loading private HF repos


class LocalContext:
blocks: ContextVar[Blocks | None] = ContextVar("blocks", default=None)
in_event_listener: ContextVar[bool] = ContextVar("in_event_listener", default=False)
event_id: ContextVar[str | None] = ContextVar("event_id", default=None)
thread_data = threading.local()
13 changes: 6 additions & 7 deletions gradio/helpers.py
Expand Up @@ -1098,23 +1098,22 @@ def __init__(self, target: Block | None, _data: Any):


def log_message(message: str, level: Literal["info", "warning"] = "info"):
from gradio.context import LocalContext
from gradio import context

blocks = LocalContext.blocks.get()
if blocks is None: # Function called outside of Gradio
if not hasattr(context.thread_data, "blocks"): # Function called outside of Gradio
if level == "info":
print(message)
elif level == "warning":
warnings.warn(message)
return
if not blocks.enable_queue:
if not context.thread_data.blocks.enable_queue:
warnings.warn(
f"Queueing must be enabled to issue {level.capitalize()}: '{message}'."
)
return
event_id = LocalContext.event_id.get()
assert event_id
blocks._queue.log_message(event_id=event_id, log=message, level=level)
context.thread_data.blocks._queue.log_message(
event_id=context.thread_data.event_id, log=message, level=level
)


@document()
Expand Down
12 changes: 6 additions & 6 deletions gradio/utils.py
Expand Up @@ -663,16 +663,16 @@ def get_function_with_locals(
fn: Callable, blocks: Blocks, event_id: str | None, in_event_listener: bool
):
def before_fn(blocks, event_id):
from gradio.context import LocalContext
from gradio.context import thread_data

LocalContext.blocks.set(blocks)
LocalContext.in_event_listener.set(in_event_listener)
LocalContext.event_id.set(event_id)
thread_data.blocks = blocks
thread_data.in_event_listener = in_event_listener
thread_data.event_id = event_id

def after_fn():
from gradio.context import LocalContext
from gradio.context import thread_data

LocalContext.in_event_listener.set(False)
thread_data.in_event_listener = False

return function_wrapper(
fn, before_fn=before_fn, before_args=(blocks, event_id), after_fn=after_fn
Expand Down

0 comments on commit 1481ce0

Please sign in to comment.