Skip to content

Commit

Permalink
[Spaces] ZeroGPU Queue fix (#5129)
Browse files Browse the repository at this point in the history
* Properly set max_size and concurrency_count for ZeroGPU

* ZeroGPU concurrency_count warning

* Test

* add changeset

* Lint

* Sure, typing_extensions is a third-party ..

* Fix and refactor

* Black

* Empty commit

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
cbensimon and gradio-pr-bot committed Aug 8, 2023
1 parent b84a35b commit 97d804c
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 2 deletions.
5 changes: 5 additions & 0 deletions .changeset/true-camels-shave.md
@@ -0,0 +1,5 @@
---
"gradio": minor
---

fix:[Spaces] ZeroGPU Queue fix
8 changes: 6 additions & 2 deletions gradio/blocks.py
Expand Up @@ -58,6 +58,7 @@
TupleNoPrint,
check_function_inputs_match,
component_or_layout_class,
concurrency_count_warning,
delete_none,
get_cancel_function,
get_continuous_fn,
Expand Down Expand Up @@ -1583,6 +1584,7 @@ def clear(self):
self.children = []
return self

@concurrency_count_warning
@document()
def queue(
self,
Expand Down Expand Up @@ -1624,12 +1626,14 @@ def queue(
warn_deprecation(
"The client_position_to_load_data parameter is deprecated."
)
max_size_default = self.max_threads if utils.is_zero_gpu_space() else None
if utils.is_zero_gpu_space():
concurrency_count = self.max_threads
max_size = 1 if max_size is None else max_size
self._queue = queueing.Queue(
live_updates=status_update_rate == "auto",
concurrency_count=concurrency_count,
update_intervals=status_update_rate if status_update_rate != "auto" else 1,
max_size=max_size_default if max_size is None else max_size,
max_size=max_size,
blocks_dependencies=self.dependencies,
)
self.config = self.get_config_file()
Expand Down
19 changes: 19 additions & 0 deletions gradio/utils.py
Expand Up @@ -39,6 +39,7 @@
from mdit_py_plugins.dollarmath.index import dollarmath_plugin
from mdit_py_plugins.footnote.index import footnote_plugin
from pydantic import BaseModel, parse_obj_as
from typing_extensions import ParamSpec

import gradio
from gradio.context import Context
Expand All @@ -53,6 +54,7 @@
(pkgutil.get_data(__name__, "version.txt") or b"").decode("ascii").strip()
)

P = ParamSpec("P")
T = TypeVar("T")


Expand Down Expand Up @@ -843,6 +845,23 @@ def check_function_inputs_match(fn: Callable, inputs: list, inputs_as_dict: bool
)


def concurrency_count_warning(queue: Callable[P, T]) -> Callable[P, T]:
    """Decorator for ``Blocks.queue`` that warns when ``concurrency_count``
    is supplied on a ZeroGPU Space.

    On ZeroGPU Spaces the concurrency count cannot be overridden by callers
    (it is always taken from the Block's ``max_threads``), so any attempt to
    pass it — positionally or by keyword — triggers a ``UserWarning``. The
    wrapped ``queue`` is always invoked with the original arguments.
    """

    @functools.wraps(queue)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        # args[0] is the Blocks instance (`self`); any further positional
        # argument is an attempted concurrency_count override.
        _self, *positional = args
        if is_zero_gpu_space() and (positional or "concurrency_count" in kwargs):
            warnings.warn(
                "Queue concurrency_count on ZeroGPU Spaces cannot be overridden "
                "and is always equal to Block's max_threads. "
                "Consider setting max_threads value on the Block instead"
            )
        return queue(*args, **kwargs)

    return wrapper


class TupleNoPrint(tuple):
# To remove printing function return in notebook
def __repr__(self):
Expand Down
10 changes: 10 additions & 0 deletions test/test_blocks.py
Expand Up @@ -434,6 +434,16 @@ def test_raise_error_if_event_queued_but_queue_not_enabled(self):

demo.close()

def test_concurrency_count_zero_gpu(self):
    """queue() must warn when concurrency_count is passed (positionally or by
    keyword) on a ZeroGPU Space, and the queue's thread count must fall back
    to the Block's max_threads."""
    os.environ["SPACES_ZERO_GPU"] = "true"
    try:
        demo = gr.Blocks()
        with pytest.warns():
            demo.queue(concurrency_count=42)
        with pytest.warns():
            demo.queue(42)
        assert demo._queue.max_thread_count == demo.max_threads
    finally:
        # Always restore the environment, even when an assertion above
        # fails — otherwise later tests run in simulated ZeroGPU mode.
        del os.environ["SPACES_ZERO_GPU"]


class TestTempFile:
def test_pil_images_hashed(self, connect, gradio_temp_dir):
Expand Down

1 comment on commit 97d804c

@vercel
Copy link

@vercel vercel bot commented on 97d804c Aug 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.