Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Blocks.load() so that it is in the same style as the other listeners #6126

Merged
merged 14 commits into from
Oct 28, 2023
6 changes: 6 additions & 0 deletions .changeset/pretty-apples-juggle.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@gradio/app": minor
"gradio": minor
---

feat:Refactor `Blocks.load()` so that it is in the same style as the other listeners
77 changes: 7 additions & 70 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,14 @@
utils,
wasm_utils,
)
from gradio.blocks_events import BlocksEvents, BlocksMeta
from gradio.context import Context
from gradio.data_classes import FileData
from gradio.events import EventData, EventListener, EventListenerMethod
from gradio.events import (
EventData,
EventListener,
EventListenerMethod,
)
from gradio.exceptions import (
DuplicateBlockError,
InvalidApiNameError,
Expand Down Expand Up @@ -77,7 +82,6 @@
from fastapi.applications import FastAPI

from gradio.components.base import Component
from gradio.events import Dependency

BUILT_IN_THEMES: dict[str, Theme] = {
t.name: t
Expand Down Expand Up @@ -424,7 +428,7 @@ def convert_component_dict_to_list(


@document("launch", "queue", "integrate", "load")
class Blocks(BlockContext):
class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
"""
Blocks is Gradio's low-level API that allows you to create more custom web
applications and demos than Interfaces (yet still entirely in Python).
Expand Down Expand Up @@ -1568,73 +1572,6 @@ def __exit__(self, exc_type: type[BaseException] | None = None, *args):
self.progress_tracking = any(block_fn.tracks_progress for block_fn in self.fns)
self.exited = True

def load(
self: Blocks | None = None,
fn: Callable | None = None,
inputs: Component | list[Component] | None = None,
outputs: Component | list[Component] | None = None,
api_name: str | None | Literal[False] = None,
scroll_to_output: bool = False,
show_progress: Literal["full", "hidden", "minimal"] | None = "full",
queue=None,
batch: bool = False,
max_batch_size: int = 4,
preprocess: bool = True,
postprocess: bool = True,
every: float | None = None,
_js: str | None = None,
) -> Dependency:
"""
Adds an event that runs as soon as the demo loads in the browser. Example usage below.
Parameters:
fn: The function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
outputs: List of gradio.components to use as inputs. If the function returns no outputs, this should be an empty list.
api_name: Defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, the endpoint will be exposed in the api docs as an unnamed endpoint, although this behavior will be changed in Gradio 4.0. If set to a string, the endpoint will be exposed in the api docs with the given name.
scroll_to_output: If True, will scroll to output component on completion
show_progress: If True, will show progress animation while pending
queue: If True, will place the request on the queue, if the queue exists
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
every: Run this event 'every' number of seconds. Interpreted in seconds. Queue must be enabled.
Example:
import gradio as gr
import datetime
with gr.Blocks() as demo:
def get_time():
return datetime.datetime.now().time()
dt = gr.Textbox(label="Current time")
demo.load(get_time, inputs=None, outputs=dt)
demo.launch()
"""
from gradio.events import Dependency, EventListenerMethod

if Context.root_block is None:
raise AttributeError(
"Cannot call load() outside of a gradio.Blocks context."
)

dep, dep_index = Context.root_block.set_event_trigger(
targets=[EventListenerMethod(self, "load")],
fn=fn,
inputs=inputs,
outputs=outputs,
api_name=api_name,
preprocess=preprocess,
postprocess=postprocess,
scroll_to_output=scroll_to_output,
show_progress=show_progress,
js=_js,
queue=queue,
batch=batch,
max_batch_size=max_batch_size,
every=every,
no_target=True,
)
return Dependency(None, dep, dep_index, fn)

def clear(self):
"""Resets the layout of the Blocks object."""
self.blocks = {}
Expand Down
30 changes: 30 additions & 0 deletions gradio/blocks_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from __future__ import annotations

from gradio.component_meta import create_or_modify_pyi
from gradio.events import EventListener, Events

BLOCKS_EVENTS: list[EventListener | str] = [Events.load]


class BlocksMeta(type):
def __new__(cls, name, bases, attrs):
for event in BLOCKS_EVENTS:
trigger = (
event
if isinstance(event, EventListener)
else EventListener(event_name=event)
).copy()
trigger.set_doc(component=name)
attrs[event] = trigger.listener
component_class = super().__new__(cls, name, bases, attrs)
create_or_modify_pyi(BlocksEvents, "BlocksEvents", BLOCKS_EVENTS)
return component_class


class BlocksEvents:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add a docstring here explaining why we have this even though it's empty

"""
This class is used to hold the events for the Blocks component. It is populated dynamically
by the BlocksMeta metaclass.
"""

pass
6 changes: 5 additions & 1 deletion gradio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def _setup(
):
def event_trigger(
block: Block | None,
fn: Callable | None,
fn: Callable | None | Literal["decorator"] = "decorator",
inputs: Component | list[Component] | set[Component] | None = None,
outputs: Component | list[Component] | None = None,
api_name: str | None | Literal[False] = None,
Expand Down Expand Up @@ -506,6 +506,10 @@ class Events:
callback=lambda block: setattr(block, "likeable", True),
doc="This listener is triggered when the user likes/dislikes from within the {{ component }}. This event has EventData of type gradio.LikeData that carries information, accessible through LikeData.index and LikeData.value. See EventData documentation on how to use this event data.",
)
load = EventListener(
"load",
doc="This listener is triggered when the {{ component }} initially loads in the browser.",
)


class LikeData(EventData):
Expand Down
2 changes: 1 addition & 1 deletion js/app/src/Blocks.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@

// handle load triggers
dependencies.forEach((dep, i) => {
if (dep.targets.length === 1 && dep.targets[0][1] === "load") {
if (dep.targets[0][1] === "load") {
trigger_api_call(i);
}
});
Expand Down
16 changes: 16 additions & 0 deletions test/test_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,6 +853,22 @@ async def test_call_multiple_functions(self):
output = demo("World", fn_index=1) # fn_index must be a keyword argument
assert output == "Hi, World"

@pytest.mark.asyncio
async def test_call_decorated_functions(self):
with gr.Blocks() as demo:
name = gr.Textbox(value="Abubakar")
output = gr.Textbox(label="Output Box")

@name.submit(inputs=name, outputs=output)
@demo.load(inputs=name, outputs=output)
def test(x):
return "Hello " + x

output = await demo.call_function(0, ["Adam"])
assert output["prediction"] == "Hello Adam"
output = await demo.call_function(1, ["Adam"])
assert output["prediction"] == "Hello Adam"

@pytest.mark.asyncio
async def test_call_generator(self):
def generator(x):
Expand Down
3 changes: 2 additions & 1 deletion test/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def greet(name):
return "Hello " + name + "!"

gr.on(
triggers=[name.submit, greet_btn.click],
triggers=[name.submit, greet_btn.click, demo.load],
fn=greet,
inputs=name,
outputs=output,
Expand All @@ -98,6 +98,7 @@ def sum(a, b, c):
assert demo.config["dependencies"][0]["targets"] == [
(name._id, "submit"),
(greet_btn._id, "click"),
(demo._id, "load"),
]
assert demo.config["dependencies"][1]["targets"] == [
(num1._id, "change"),
Expand Down
Loading