Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

V4 fix typing #5686

Merged
merged 19 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changeset/lovely-news-speak.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"gradio": minor
"gradio_client": minor
---

feat:V4 fix typing
1 change: 1 addition & 0 deletions .github/workflows/backend.yml
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ jobs:
. venv/bin/activate
python -m pip install -e client/python
python -m pip install .
python -c "import gradio"
- name: Install Test Dependencies (Linux)
if: steps.cache.outputs.cache-hit != 'true' && runner.os == 'Linux'
run: |
Expand Down
20 changes: 10 additions & 10 deletions client/python/gradio_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import uuid
import warnings
from concurrent.futures import Future
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from threading import Lock
Expand Down Expand Up @@ -91,7 +92,9 @@ def __init__(
library_version=utils.__version__,
)
self.space_id = None
self.output_dir = output_dir
self.output_dir = (
str(output_dir) if isinstance(output_dir, Path) else output_dir
)

if src.startswith("http://") or src.startswith("https://"):
_src = src if src.endswith("/") else src + "/"
Expand Down Expand Up @@ -776,17 +779,14 @@ def deploy_discord(
)
if is_private:
huggingface_hub.add_space_secret(
space_id, "HF_TOKEN", hf_token, token=hf_token
space_id, "HF_TOKEN", hf_token, token=hf_token # type: ignore
)

url = f"https://huggingface.co/spaces/{space_id}"
print(f"See your discord bot here! {url}")
return url


from dataclasses import dataclass


@dataclass
class ComponentApiType:
skip: bool
Expand Down Expand Up @@ -1013,7 +1013,7 @@ def get_file(d):
new_data.append(d)
return file_list, new_data

def _add_uploaded_files_to_data(self, data: list[Any], files: list[dict]):
def _add_uploaded_files_to_data(self, data: list[Any], files: list[Any]):
def replace(d: ReplaceMe) -> dict:
return files[d.index]

Expand All @@ -1039,7 +1039,7 @@ def _download_file(
save_dir: str | None = None,
root_url: str | None = None,
hf_token: str | None = None,
):
) -> str | None:
if x is None:
return None
if isinstance(x, str):
Expand Down Expand Up @@ -1067,7 +1067,7 @@ def _download_file(
return file_name

def deserialize(self, *data) -> tuple:
data = list(data)
data_ = list(data)

def is_file(d):
return (
Expand All @@ -1080,8 +1080,8 @@ def is_file(d):
and "mime_type" in d
)

data = utils.traverse(data, self.download_file, is_file)
return data
data_: list[Any] = utils.traverse(data_, self.download_file, is_file)
return tuple(data_)

def process_predictions(self, *predictions):
predictions = self.deserialize(*predictions)
Expand Down
2 changes: 1 addition & 1 deletion client/python/gradio_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ def get_desc(v):
raise APIInfoParseError(f"Cannot parse schema {schema}")


def traverse(json_obj, func, is_root):
def traverse(json_obj: Any, func: Callable, is_root: Callable) -> Any:
if is_root(json_obj):
return func(json_obj)
elif isinstance(json_obj, dict):
Expand Down
2 changes: 1 addition & 1 deletion client/python/test/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ black==23.3.0
pytest-asyncio
pytest==7.1.2
ruff==0.0.264
pyright==1.1.305
pyright==1.1.327
gradio
pydub==0.25.1
2 changes: 1 addition & 1 deletion gradio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
HighlightedText,
Highlightedtext,
Image,
Interpretation,
Json,
Label,
LinePlot,
Expand All @@ -57,6 +56,7 @@
Video,
component,
)
from gradio.data_classes import FileData
from gradio.events import EventData, LikeData, SelectData
from gradio.exceptions import Error
from gradio.external import load
Expand Down
37 changes: 22 additions & 15 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
InvalidApiNameError,
InvalidBlockError,
)
from gradio.helpers import EventData, create_tracker, skip, special_args
from gradio.helpers import create_tracker, skip, special_args
from gradio.state_holder import SessionState
from gradio.themes import Default as DefaultTheme
from gradio.themes import ThemeClass as Theme
Expand Down Expand Up @@ -73,7 +73,7 @@
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from fastapi.applications import FastAPI

from gradio.components import Component
from gradio.components.base import Component

BUILT_IN_THEMES: dict[str, Theme] = {
t.name: t
Expand Down Expand Up @@ -189,7 +189,7 @@ def set_event_trigger(
preprocess: bool = True,
postprocess: bool = True,
scroll_to_output: bool = False,
show_progress: str = "full",
show_progress: Literal["full", "hidden", "minimal"] | None = "full",
api_name: str | None | Literal[False] = None,
js: str | None = None,
no_target: bool = False,
Expand Down Expand Up @@ -458,7 +458,9 @@ def __repr__(self):
return str(self)


def postprocess_update_dict(block: Block, update_dict: dict, postprocess: bool = True):
def postprocess_update_dict(
block: Component | BlockContext, update_dict: dict, postprocess: bool = True
):
"""
Converts a dictionary of updates into a format that can be sent to the frontend.
E.g. {"__type__": "generic_update", "value": "2", "interactive": False}
Expand Down Expand Up @@ -616,7 +618,7 @@ def __init__(
else:
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "True"
super().__init__(render=False, **kwargs)
self.blocks: dict[int, Block] = {}
self.blocks: dict[int, Component | Block] = {}
self.fns: list[BlockFunction] = []
self.dependencies = []
self.mode = mode
Expand Down Expand Up @@ -665,6 +667,11 @@ def __init__(
}
analytics.initiated_analytics(data)

def get_component(self, id: int) -> Component:
comp = self.blocks[id]
assert isinstance(comp, Component)
return comp

@property
def _is_running_in_reload_thread(self):
from gradio.cli.commands.reload import reload_thread
Expand Down Expand Up @@ -1052,7 +1059,7 @@ def serialize_data(self, fn_index: int, inputs: list[Any]) -> list[Any]:
assert isinstance(
block, components.Component
), f"{block.__class__} Component with id {input_id} not a valid input component."
serialized_input = block.serialize(inputs[i])
serialized_input = block.serialize(inputs[i]) # type: ignore
processed_input.append(serialized_input)

return processed_input
Expand All @@ -1071,9 +1078,9 @@ def deserialize_data(self, fn_index: int, outputs: list[Any]) -> list[Any]:
assert isinstance(
block, components.Component
), f"{block.__class__} Component with id {output_id} not a valid output component."
deserialized = block.deserialize(
deserialized = block.deserialize( # type: ignore
outputs[o],
save_dir=block.DEFAULT_TEMP_DIR,
save_dir=block.DEFAULT_TEMP_DIR, # type: ignore
root_url=block.root_url,
hf_token=Context.hf_token,
)
Expand Down Expand Up @@ -1480,7 +1487,7 @@ def load(
outputs: list[Component] | None = None,
api_name: str | None | Literal[False] = None,
scroll_to_output: bool = False,
show_progress: str = "full",
show_progress: Literal["full", "hidden", "minimal"] | None = "full",
queue=None,
batch: bool = False,
max_batch_size: int = 4,
Expand Down Expand Up @@ -2063,7 +2070,7 @@ def reverse(text):
):
self.block_thread()

return TupleNoPrint((self.server_app, self.local_url, self.share_url))
return TupleNoPrint((self.server_app, self.local_url, self.share_url)) # type: ignore

def integrate(
self,
Expand Down Expand Up @@ -2219,8 +2226,8 @@ def get_api_info(self):
# The config has the most specific API info (taking into account the parameters
# of the component), so we use that if it exists. Otherwise, we fallback to the
# Serializer's API info.
info = self.blocks[component["id"]].api_info()
example = self.blocks[component["id"]].example_inputs()
info = self.get_component(component["id"]).api_info()
example = self.get_component(component["id"]).example_inputs
python_type = client_utils.json_schema_to_python_type(info)
dependency_info["parameters"].append(
{
Expand All @@ -2244,11 +2251,11 @@ def get_api_info(self):
skip_endpoint = True # if component not found, skip endpoint
break
type = component["type"]
if self.blocks[component["id"]].skip_api:
if self.get_component(component["id"]).skip_api:
continue
label = component["props"].get("label", f"value_{o}")
info = self.blocks[component["id"]].api_info()
example = self.blocks[component["id"]].example_inputs()
info = self.get_component(component["id"]).api_info()
example = self.get_component(component["id"]).example_inputs()
python_type = client_utils.json_schema_to_python_type(info)
dependency_info["returns"].append(
{
Expand Down
10 changes: 6 additions & 4 deletions gradio/chat_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
self.cache_examples = True
else:
self.cache_examples = cache_examples or False
self.buttons: list[Button] = []
self.buttons: list[Button | None] = []

if additional_inputs:
if not isinstance(additional_inputs, list):
Expand Down Expand Up @@ -144,7 +144,9 @@ def __init__(
if textbox:
textbox.container = False
textbox.show_label = False
self.textbox = textbox.render()
textbox_ = textbox.render()
assert isinstance(textbox_, Textbox)
self.textbox = textbox_
else:
self.textbox = Textbox(
container=False,
Expand Down Expand Up @@ -184,7 +186,7 @@ def __init__(
raise ValueError(
f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
)
self.buttons.extend([submit_btn, stop_btn])
self.buttons.extend([submit_btn, stop_btn]) # type: ignore

with Row():
for btn in [retry_btn, undo_btn, clear_btn]:
Expand All @@ -197,7 +199,7 @@ def __init__(
raise ValueError(
f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
)
self.buttons.append(btn)
self.buttons.append(btn) # type: ignore

self.fake_api_btn = Button("Fake API", visible=False)
self.fake_response_textbox = Textbox(
Expand Down
9 changes: 4 additions & 5 deletions gradio/cli/commands/components/_create_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ def _get_component_code(template: str | None) -> ComponentFiles:


def _get_js_dependency_version(name: str, local_js_dir: Path) -> str:
package_json = json.load(
open(str(local_js_dir / name.split("/")[1] / "package.json"))
package_json = json.loads(
Path(local_js_dir / name.split("/")[1] / "package.json").read_text()
)
return package_json["version"]

Expand Down Expand Up @@ -161,12 +161,11 @@ def ignore(s, names):
dirs_exist_ok=True,
ignore=ignore,
)
source_package_json = json.load(open(str(frontend / "package.json")))
source_package_json = json.loads(Path(frontend / "package.json").read_text())
source_package_json["name"] = name.lower()
source_package_json = _modify_js_deps(source_package_json, "dependencies", p)
source_package_json = _modify_js_deps(source_package_json, "devDependencies", p)

json.dump(source_package_json, open(str(frontend / "package.json"), "w"), indent=2)
(frontend / "package.json").write_text(json.dumps(source_package_json, indent=2))


def _replace_old_class_name(old_class_name: str, new_class_name: str, content: str):
Expand Down
2 changes: 1 addition & 1 deletion gradio/cli/commands/components/build.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import subprocess
from pathlib import Path
from typing import Annotated

import typer
from typing_extensions import Annotated

from gradio.cli.commands.display import LivePanelDisplay

Expand Down
2 changes: 1 addition & 1 deletion gradio/cli/commands/components/dev.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import subprocess
from pathlib import Path
from typing import Annotated

import typer
from rich import print
from typing_extensions import Annotated

import gradio

Expand Down
4 changes: 1 addition & 3 deletions gradio/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from gradio.components.bar_plot import BarPlot
from gradio.components.base import (
Component,
Form,
FormComponent,
StreamingInput,
StreamingOutput,
Expand All @@ -29,7 +28,6 @@
from gradio.components.highlighted_text import HighlightedText
from gradio.components.html import HTML
from gradio.components.image import Image
from gradio.components.interpretation import Interpretation
from gradio.components.json_component import JSON
from gradio.components.label import Label
from gradio.components.line_plot import LinePlot
Expand All @@ -47,6 +45,7 @@
from gradio.components.textbox import Textbox
from gradio.components.upload_button import UploadButton
from gradio.components.video import Video
from gradio.layouts import Form

Text = Textbox
DataFrame = Dataframe
Expand Down Expand Up @@ -81,7 +80,6 @@
"Gallery",
"HTML",
"Image",
"Interpretation",
"JSON",
"Json",
"Label",
Expand Down
4 changes: 2 additions & 2 deletions gradio/components/annotated_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import warnings
from typing import Any, Literal
from typing import Any, List, Literal

import numpy as np
from gradio_client.documentation import document, set_documentation_group
Expand All @@ -26,7 +26,7 @@ class Annotation(GradioModel):

class AnnotatedImageData(GradioModel):
image: FileData
annotations: list[Annotation]
annotations: List[Annotation]


@document()
Expand Down