Skip to content

Commit

Permalink
Improve rendering (#8398)
Browse files Browse the repository at this point in the history
* changes

* add changeset

* changes

* changes

* changes

* changes

* changes

* changes

* changeas

* changes

* add changeset

* changes

* add changeset

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* add changeset

* changes

* cganges

* changes

* changes

* changes

* changes

* add changeset

* changes

* chagnes

* changes

* changes

* changes

* changes

* changes

* js

* remove console log

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* add changeset

* changes

* chnages

* changes

* cnages

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* add changeset

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* add changeset

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* add changeset

* Add `state.change` listener (#8297)

* state changes

* changes

---------

Co-authored-by: Ali Abid <aliabid94@gmail.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* changes

* changes

* add changeset

* changes

* changes

* changes

* changes

* changes

* changes

* updates

* changes

* add changeset

* changes

* changes

* add changeset

* fix

* changes

* changes

* changes

* changes

* changes

* changes

* add changeset

* changes

---------

Co-authored-by: Ali Abid <aliabid94@gmail.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
4 people committed May 29, 2024
1 parent d078621 commit 945ac83
Show file tree
Hide file tree
Showing 17 changed files with 158 additions and 51 deletions.
10 changes: 10 additions & 0 deletions .changeset/fine-pillows-open.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
"@gradio/app": patch
"@gradio/client": patch
"@gradio/column": patch
"@gradio/row": patch
"@gradio/statustracker": patch
"gradio": patch
---

feat:Improve rendering
37 changes: 23 additions & 14 deletions client/js/src/utils/submit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,25 @@ export function submit(
}
}

function handle_render_config(render_config: any): void {
if (!config) return;
let render_id: number = render_config.render_id;
config.components = [
...config.components.filter((c) => c.rendered_in !== render_id),
...render_config.components
];
config.dependencies = [
...config.dependencies.filter((d) => d.rendered_in !== render_id),
...render_config.dependencies
];
fire_event({
type: "render",
data: render_config,
endpoint: _endpoint,
fn_index
});
}

this.handle_blob(config.root, resolved_data, endpoint_info).then(
async (_payload) => {
payload = {
Expand Down Expand Up @@ -201,6 +220,9 @@ export function submit(
event_data,
trigger_id
});
if (output.render_config) {
handle_render_config(output.render_config);
}

fire_event({
type: "status",
Expand Down Expand Up @@ -606,20 +628,7 @@ export function submit(
fn_index
});
if (data.render_config) {
config.components = [
...config.components,
...data.render_config.components
];
config.dependencies = [
...config.dependencies,
...data.render_config.dependencies
];
fire_event({
type: "render",
data: data.render_config,
endpoint: _endpoint,
fn_index
});
handle_render_config(data.render_config);
}

if (complete) {
Expand Down
15 changes: 11 additions & 4 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
import dataclasses
import hashlib
import inspect
import json
Expand Down Expand Up @@ -123,6 +124,7 @@ def __init__(
self.proxy_url = proxy_url
self.share_token = secrets.token_urlsafe(32)
self.parent: BlockContext | None = None
self.rendered_in: Renderable | None = None
self.is_rendered: bool = False
self._constructor_args: list[dict]
self.state_session_capacity = 10000
Expand Down Expand Up @@ -166,6 +168,7 @@ def render(self):
"""
root_context = get_blocks_context()
render_context = get_render_context()
self.rendered_in = LocalContext.renderable.get()
if root_context is not None and self._id in root_context.blocks:
raise DuplicateBlockError(
f"A block with id: {self._id} has already been rendered in the current Blocks."
Expand Down Expand Up @@ -233,13 +236,17 @@ def get_config(self):
for parameter in signature.parameters.values():
if hasattr(self, parameter.name):
value = getattr(self, parameter.name)
config[parameter.name] = utils.convert_to_dict_if_dataclass(value)
if dataclasses.is_dataclass(value):
value = dataclasses.asdict(value)
config[parameter.name] = value
for e in self.events:
to_add = e.config_data()
if to_add:
config = {**to_add, **config}
config.pop("render", None)
config = {**config, "proxy_url": self.proxy_url, "name": self.get_block_class()}
if self.rendered_in is not None:
config["rendered_in"] = self.rendered_in._id
if (_selectable := getattr(self, "_selectable", None)) is not None:
config["_selectable"] = _selectable
return config
Expand Down Expand Up @@ -468,7 +475,7 @@ def __init__(
self,
fn: Callable | None,
inputs: list[Component],
outputs: list[Component],
outputs: list[Block] | list[Component],
preprocess: bool,
postprocess: bool,
inputs_as_dict: bool,
Expand Down Expand Up @@ -658,7 +665,7 @@ def set_event_trigger(
targets: Sequence[EventListenerMethod],
fn: Callable | None,
inputs: Component | list[Component] | set[Component] | None,
outputs: Component | list[Component] | None,
outputs: Block | list[Block] | list[Component] | None,
preprocess: bool = True,
postprocess: bool = True,
scroll_to_output: bool = False,
Expand Down Expand Up @@ -860,7 +867,7 @@ def get_layout(block: Block):
return {"id": block._id, "children": children_layout}

if renderable:
root_block = self.blocks[renderable.column_id]
root_block = self.blocks[renderable.container_id]
else:
root_block = self.root_block
config["layout"] = get_layout(root_block)
Expand Down
4 changes: 2 additions & 2 deletions gradio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def event_trigger(
block: Block | None,
fn: Callable | None | Literal["decorator"] = "decorator",
inputs: Component | list[Component] | set[Component] | None = None,
outputs: Component | list[Component] | None = None,
outputs: Block | list[Block] | list[Component] | None = None,
api_name: str | None | Literal[False] = None,
scroll_to_output: bool = False,
show_progress: Literal["full", "minimal", "hidden"] = _show_progress,
Expand Down Expand Up @@ -334,7 +334,7 @@ def on(
triggers: Sequence[Any] | Any | None = None,
fn: Callable | None | Literal["decorator"] = "decorator",
inputs: Component | list[Component] | set[Component] | None = None,
outputs: Component | list[Component] | None = None,
outputs: Block | list[Block] | list[Component] | None = None,
*,
api_name: str | None | Literal[False] = None,
scroll_to_output: bool = False,
Expand Down
3 changes: 3 additions & 0 deletions gradio/layouts/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
elem_id: str | None = None,
elem_classes: list[str] | str | None = None,
render: bool = True,
show_progress: bool = False,
):
"""
Parameters:
Expand All @@ -48,6 +49,7 @@ def __init__(
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
elem_classes: An optional string or list of strings that are assigned as the class of this component in the HTML DOM. Can be used for targeting CSS styles.
render: If False, component will not render be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later.
show_progress: If True, shows progress animation when being updated.
"""
if scale != round(scale):
warnings.warn(
Expand All @@ -59,6 +61,7 @@ def __init__(
self.variant = variant
if variant == "compact":
self.allow_expected_parents = False
self.show_progress = show_progress
BlockContext.__init__(
self,
visible=visible,
Expand Down
3 changes: 3 additions & 0 deletions gradio/layouts/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
elem_classes: list[str] | str | None = None,
render: bool = True,
equal_height: bool = True,
show_progress: bool = False,
):
"""
Parameters:
Expand All @@ -41,11 +42,13 @@ def __init__(
elem_classes: An optional string or list of strings that are assigned as the class of this component in the HTML DOM. Can be used for targeting CSS styles.
render: If False, this layout will not be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later.
equal_height: If True, makes every child element have equal height
show_progress: If True, shows progress animation when being updated.
"""
self.variant = variant
self.equal_height = equal_height
if variant == "compact":
self.allow_expected_parents = False
self.show_progress = show_progress
BlockContext.__init__(
self,
visible=visible,
Expand Down
34 changes: 25 additions & 9 deletions gradio/renderable.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from gradio.components import Component
from gradio.context import Context, LocalContext
from gradio.events import EventListener, EventListenerMethod
from gradio.layouts import Column
from gradio.layouts import Column, Row


class Renderable:
Expand All @@ -17,14 +17,17 @@ def __init__(
triggers: list[tuple[Block | None, str]],
concurrency_limit: int | None | Literal["default"],
concurrency_id: str | None,
trigger_mode: Literal["once", "multiple", "always_last"] | None,
queue: bool,
):
if Context.root_block is None:
raise ValueError("Reactive render must be inside a Blocks context.")

self._id = len(Context.root_block.renderables)
Context.root_block.renderables.append(self)
self.column = Column(render=False)
self.column_id = Column()._id
self.ContainerClass = Row if isinstance(Context.block, Row) else Column
self.container = self.ContainerClass(show_progress=True)
self.container_id = self.container._id

self.fn = fn
self.inputs = inputs
Expand All @@ -35,11 +38,14 @@ def __init__(
self.triggers,
self.apply,
self.inputs,
None,
self.container,
show_api=False,
concurrency_limit=concurrency_limit,
concurrency_id=concurrency_id,
renderable=self,
trigger_mode=trigger_mode,
postprocess=False,
queue=queue,
)

def apply(self, *args, **kwargs):
Expand All @@ -54,14 +60,14 @@ def apply(self, *args, **kwargs):
for _id in fn_ids_to_remove_from_last_render:
del blocks_config.fns[_id]

column_copy = Column(render=False)
column_copy._id = self.column_id
container_copy = self.ContainerClass(render=False, show_progress=True)
container_copy._id = self.container_id
LocalContext.renderable.set(self)

try:
with column_copy:
with container_copy:
self.fn(*args, **kwargs)
blocks_config.blocks[self.column_id] = column_copy
blocks_config.blocks[self.container_id] = container_copy
finally:
LocalContext.renderable.set(None)

Expand All @@ -71,6 +77,8 @@ def render(
triggers: list[EventListener] | EventListener | None = None,
concurrency_limit: int | None | Literal["default"] = None,
concurrency_id: str | None = None,
trigger_mode: Literal["once", "multiple", "always_last"] | None = "always_last",
queue: bool = True,
):
if Context.root_block is None:
raise ValueError("Reactive render must be inside a Blocks context.")
Expand All @@ -92,7 +100,15 @@ def render(
]

def wrapper_function(fn):
Renderable(fn, inputs, _triggers, concurrency_limit, concurrency_id)
Renderable(
fn,
inputs,
_triggers,
concurrency_limit,
concurrency_id,
trigger_mode,
queue,
)
return fn

return wrapper_function
2 changes: 1 addition & 1 deletion gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,7 @@ async def queue_data(
session_hash: str,
):
def process_msg(message: EventMessage) -> str:
return f"data: {orjson.dumps(message.model_dump()).decode('utf-8')}\n\n"
return f"data: {orjson.dumps(message.model_dump(), default=str).decode('utf-8')}\n\n"

return await queue_data_helper(request, session_hash, process_msg)

Expand Down
7 changes: 0 additions & 7 deletions gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import ast
import asyncio
import copy
import dataclasses
import functools
import importlib
import importlib.util
Expand Down Expand Up @@ -1228,12 +1227,6 @@ def get_extension_from_file_path_or_url(file_path_or_url: str) -> str:
return file_extension[1:] if file_extension else ""


def convert_to_dict_if_dataclass(value):
if dataclasses.is_dataclass(value):
return dataclasses.asdict(value)
return value


K = TypeVar("K")
V = TypeVar("V")

Expand Down
21 changes: 10 additions & 11 deletions js/app/src/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ export function create_components(): {
let _components: ComponentMeta[] = [];
let app: client_return;
let keyed_component_values: Record<string | number, any> = {};
let rendered_fns_per_render_id: Record<number, number[]> = {};
let _rootNode: ComponentMeta;

function create_layout({
Expand Down Expand Up @@ -157,16 +156,6 @@ export function create_components(): {
constructor_map.set(k, v);
});

let previous_rendered_fn_ids = rendered_fns_per_render_id[render_id] || [];
Object.values(_target_map).forEach((event_fn_ids_map) => {
Object.values(event_fn_ids_map).forEach((fn_ids) => {
previous_rendered_fn_ids.forEach((fn_id) => {
if (fn_ids.includes(fn_id)) {
fn_ids.splice(fn_ids.indexOf(fn_id), 1);
}
});
});
});
_target_map = {};

dependencies.forEach((dep) => {
Expand Down Expand Up @@ -196,6 +185,16 @@ export function create_components(): {
add_to_current_children(current_element);
store_keyed_values(all_current_children);

Object.entries(instance_map).forEach(([id, component]) => {
let _id = Number(id);
if (component.rendered_in === render_id) {
delete instance_map[_id];
if (_component_map.has(_id)) {
_component_map.delete(_id);
}
}
});

components.forEach((c) => {
instance_map[c.id] = c;
_component_map.set(c.id, c);
Expand Down
1 change: 1 addition & 0 deletions js/app/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ export interface ComponentMeta {
value?: any;
component_class_id: string;
key: string | number | null;
rendered_in?: number;
}

/** Dictates whether a dependency is continous and/or a generator */
Expand Down
Loading

0 comments on commit 945ac83

Please sign in to comment.