Skip to content

Commit

Permalink
V4 fix typing (#5686)
Browse files Browse the repository at this point in the history
* Add examples for series and parallel

* v4 fix typing

* add changeset

* Fix

* Fix

* Fix 3.8

* Fix typing 3.8

* Lint

* Add code

* Add key

* Fix typing

* Add code

* Fix deps

* Fix fastapi

* Fix version
'
:

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
freddyaboulton and gradio-pr-bot committed Sep 26, 2023
1 parent b4b3865 commit e19d333
Show file tree
Hide file tree
Showing 48 changed files with 176 additions and 278 deletions.
6 changes: 6 additions & 0 deletions .changeset/lovely-news-speak.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"gradio": minor
"gradio_client": minor
---

feat:V4 fix typing
1 change: 1 addition & 0 deletions .github/workflows/backend.yml
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ jobs:
. venv/bin/activate
python -m pip install -e client/python
python -m pip install .
python -c "import gradio"
- name: Install Test Dependencies (Linux)
if: steps.cache.outputs.cache-hit != 'true' && runner.os == 'Linux'
run: |
Expand Down
20 changes: 10 additions & 10 deletions client/python/gradio_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import uuid
import warnings
from concurrent.futures import Future
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from threading import Lock
Expand Down Expand Up @@ -91,7 +92,9 @@ def __init__(
library_version=utils.__version__,
)
self.space_id = None
self.output_dir = output_dir
self.output_dir = (
str(output_dir) if isinstance(output_dir, Path) else output_dir
)

if src.startswith("http://") or src.startswith("https://"):
_src = src if src.endswith("/") else src + "/"
Expand Down Expand Up @@ -776,17 +779,14 @@ def deploy_discord(
)
if is_private:
huggingface_hub.add_space_secret(
space_id, "HF_TOKEN", hf_token, token=hf_token
space_id, "HF_TOKEN", hf_token, token=hf_token # type: ignore
)

url = f"https://huggingface.co/spaces/{space_id}"
print(f"See your discord bot here! {url}")
return url


from dataclasses import dataclass


@dataclass
class ComponentApiType:
skip: bool
Expand Down Expand Up @@ -1013,7 +1013,7 @@ def get_file(d):
new_data.append(d)
return file_list, new_data

def _add_uploaded_files_to_data(self, data: list[Any], files: list[dict]):
def _add_uploaded_files_to_data(self, data: list[Any], files: list[Any]):
def replace(d: ReplaceMe) -> dict:
return files[d.index]

Expand All @@ -1039,7 +1039,7 @@ def _download_file(
save_dir: str | None = None,
root_url: str | None = None,
hf_token: str | None = None,
):
) -> str | None:
if x is None:
return None
if isinstance(x, str):
Expand Down Expand Up @@ -1067,7 +1067,7 @@ def _download_file(
return file_name

def deserialize(self, *data) -> tuple:
data = list(data)
data_ = list(data)

def is_file(d):
return (
Expand All @@ -1080,8 +1080,8 @@ def is_file(d):
and "mime_type" in d
)

data = utils.traverse(data, self.download_file, is_file)
return data
data_: list[Any] = utils.traverse(data_, self.download_file, is_file)
return tuple(data_)

def process_predictions(self, *predictions):
predictions = self.deserialize(*predictions)
Expand Down
2 changes: 1 addition & 1 deletion client/python/gradio_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ def get_desc(v):
raise APIInfoParseError(f"Cannot parse schema {schema}")


def traverse(json_obj, func, is_root):
def traverse(json_obj: Any, func: Callable, is_root: Callable) -> Any:
if is_root(json_obj):
return func(json_obj)
elif isinstance(json_obj, dict):
Expand Down
2 changes: 1 addition & 1 deletion client/python/test/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ black==23.3.0
pytest-asyncio
pytest==7.1.2
ruff==0.0.264
pyright==1.1.305
pyright==1.1.327
gradio
pydub==0.25.1
2 changes: 1 addition & 1 deletion gradio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
HighlightedText,
Highlightedtext,
Image,
Interpretation,
Json,
Label,
LinePlot,
Expand All @@ -57,6 +56,7 @@
Video,
component,
)
from gradio.data_classes import FileData
from gradio.events import EventData, LikeData, SelectData
from gradio.exceptions import Error
from gradio.external import load
Expand Down
37 changes: 22 additions & 15 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
InvalidApiNameError,
InvalidBlockError,
)
from gradio.helpers import EventData, create_tracker, skip, special_args
from gradio.helpers import create_tracker, skip, special_args
from gradio.state_holder import SessionState
from gradio.themes import Default as DefaultTheme
from gradio.themes import ThemeClass as Theme
Expand Down Expand Up @@ -73,7 +73,7 @@
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from fastapi.applications import FastAPI

from gradio.components import Component
from gradio.components.base import Component

BUILT_IN_THEMES: dict[str, Theme] = {
t.name: t
Expand Down Expand Up @@ -189,7 +189,7 @@ def set_event_trigger(
preprocess: bool = True,
postprocess: bool = True,
scroll_to_output: bool = False,
show_progress: str = "full",
show_progress: Literal["full", "hidden", "minimal"] | None = "full",
api_name: str | None | Literal[False] = None,
js: str | None = None,
no_target: bool = False,
Expand Down Expand Up @@ -458,7 +458,9 @@ def __repr__(self):
return str(self)


def postprocess_update_dict(block: Block, update_dict: dict, postprocess: bool = True):
def postprocess_update_dict(
block: Component | BlockContext, update_dict: dict, postprocess: bool = True
):
"""
Converts a dictionary of updates into a format that can be sent to the frontend.
E.g. {"__type__": "generic_update", "value": "2", "interactive": False}
Expand Down Expand Up @@ -616,7 +618,7 @@ def __init__(
else:
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "True"
super().__init__(render=False, **kwargs)
self.blocks: dict[int, Block] = {}
self.blocks: dict[int, Component | Block] = {}
self.fns: list[BlockFunction] = []
self.dependencies = []
self.mode = mode
Expand Down Expand Up @@ -665,6 +667,11 @@ def __init__(
}
analytics.initiated_analytics(data)

def get_component(self, id: int) -> Component:
comp = self.blocks[id]
assert isinstance(comp, Component)
return comp

@property
def _is_running_in_reload_thread(self):
from gradio.cli.commands.reload import reload_thread
Expand Down Expand Up @@ -1052,7 +1059,7 @@ def serialize_data(self, fn_index: int, inputs: list[Any]) -> list[Any]:
assert isinstance(
block, components.Component
), f"{block.__class__} Component with id {input_id} not a valid input component."
serialized_input = block.serialize(inputs[i])
serialized_input = block.serialize(inputs[i]) # type: ignore
processed_input.append(serialized_input)

return processed_input
Expand All @@ -1071,9 +1078,9 @@ def deserialize_data(self, fn_index: int, outputs: list[Any]) -> list[Any]:
assert isinstance(
block, components.Component
), f"{block.__class__} Component with id {output_id} not a valid output component."
deserialized = block.deserialize(
deserialized = block.deserialize( # type: ignore
outputs[o],
save_dir=block.DEFAULT_TEMP_DIR,
save_dir=block.DEFAULT_TEMP_DIR, # type: ignore
root_url=block.root_url,
hf_token=Context.hf_token,
)
Expand Down Expand Up @@ -1480,7 +1487,7 @@ def load(
outputs: list[Component] | None = None,
api_name: str | None | Literal[False] = None,
scroll_to_output: bool = False,
show_progress: str = "full",
show_progress: Literal["full", "hidden", "minimal"] | None = "full",
queue=None,
batch: bool = False,
max_batch_size: int = 4,
Expand Down Expand Up @@ -2063,7 +2070,7 @@ def reverse(text):
):
self.block_thread()

return TupleNoPrint((self.server_app, self.local_url, self.share_url))
return TupleNoPrint((self.server_app, self.local_url, self.share_url)) # type: ignore

def integrate(
self,
Expand Down Expand Up @@ -2219,8 +2226,8 @@ def get_api_info(self):
# The config has the most specific API info (taking into account the parameters
# of the component), so we use that if it exists. Otherwise, we fallback to the
# Serializer's API info.
info = self.blocks[component["id"]].api_info()
example = self.blocks[component["id"]].example_inputs()
info = self.get_component(component["id"]).api_info()
example = self.get_component(component["id"]).example_inputs
python_type = client_utils.json_schema_to_python_type(info)
dependency_info["parameters"].append(
{
Expand All @@ -2244,11 +2251,11 @@ def get_api_info(self):
skip_endpoint = True # if component not found, skip endpoint
break
type = component["type"]
if self.blocks[component["id"]].skip_api:
if self.get_component(component["id"]).skip_api:
continue
label = component["props"].get("label", f"value_{o}")
info = self.blocks[component["id"]].api_info()
example = self.blocks[component["id"]].example_inputs()
info = self.get_component(component["id"]).api_info()
example = self.get_component(component["id"]).example_inputs()
python_type = client_utils.json_schema_to_python_type(info)
dependency_info["returns"].append(
{
Expand Down
10 changes: 6 additions & 4 deletions gradio/chat_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
self.cache_examples = True
else:
self.cache_examples = cache_examples or False
self.buttons: list[Button] = []
self.buttons: list[Button | None] = []

if additional_inputs:
if not isinstance(additional_inputs, list):
Expand Down Expand Up @@ -144,7 +144,9 @@ def __init__(
if textbox:
textbox.container = False
textbox.show_label = False
self.textbox = textbox.render()
textbox_ = textbox.render()
assert isinstance(textbox_, Textbox)
self.textbox = textbox_
else:
self.textbox = Textbox(
container=False,
Expand Down Expand Up @@ -184,7 +186,7 @@ def __init__(
raise ValueError(
f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
)
self.buttons.extend([submit_btn, stop_btn])
self.buttons.extend([submit_btn, stop_btn]) # type: ignore

with Row():
for btn in [retry_btn, undo_btn, clear_btn]:
Expand All @@ -197,7 +199,7 @@ def __init__(
raise ValueError(
f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
)
self.buttons.append(btn)
self.buttons.append(btn) # type: ignore

self.fake_api_btn = Button("Fake API", visible=False)
self.fake_response_textbox = Textbox(
Expand Down
9 changes: 4 additions & 5 deletions gradio/cli/commands/components/_create_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ def _get_component_code(template: str | None) -> ComponentFiles:


def _get_js_dependency_version(name: str, local_js_dir: Path) -> str:
package_json = json.load(
open(str(local_js_dir / name.split("/")[1] / "package.json"))
package_json = json.loads(
Path(local_js_dir / name.split("/")[1] / "package.json").read_text()
)
return package_json["version"]

Expand Down Expand Up @@ -161,12 +161,11 @@ def ignore(s, names):
dirs_exist_ok=True,
ignore=ignore,
)
source_package_json = json.load(open(str(frontend / "package.json")))
source_package_json = json.loads(Path(frontend / "package.json").read_text())
source_package_json["name"] = name.lower()
source_package_json = _modify_js_deps(source_package_json, "dependencies", p)
source_package_json = _modify_js_deps(source_package_json, "devDependencies", p)

json.dump(source_package_json, open(str(frontend / "package.json"), "w"), indent=2)
(frontend / "package.json").write_text(json.dumps(source_package_json, indent=2))


def _replace_old_class_name(old_class_name: str, new_class_name: str, content: str):
Expand Down
2 changes: 1 addition & 1 deletion gradio/cli/commands/components/build.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import subprocess
from pathlib import Path
from typing import Annotated

import typer
from typing_extensions import Annotated

from gradio.cli.commands.display import LivePanelDisplay

Expand Down
2 changes: 1 addition & 1 deletion gradio/cli/commands/components/dev.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import subprocess
from pathlib import Path
from typing import Annotated

import typer
from rich import print
from typing_extensions import Annotated

import gradio

Expand Down
4 changes: 1 addition & 3 deletions gradio/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from gradio.components.bar_plot import BarPlot
from gradio.components.base import (
Component,
Form,
FormComponent,
StreamingInput,
StreamingOutput,
Expand All @@ -29,7 +28,6 @@
from gradio.components.highlighted_text import HighlightedText
from gradio.components.html import HTML
from gradio.components.image import Image
from gradio.components.interpretation import Interpretation
from gradio.components.json_component import JSON
from gradio.components.label import Label
from gradio.components.line_plot import LinePlot
Expand All @@ -47,6 +45,7 @@
from gradio.components.textbox import Textbox
from gradio.components.upload_button import UploadButton
from gradio.components.video import Video
from gradio.layouts import Form

Text = Textbox
DataFrame = Dataframe
Expand Down Expand Up @@ -81,7 +80,6 @@
"Gallery",
"HTML",
"Image",
"Interpretation",
"JSON",
"Json",
"Label",
Expand Down
4 changes: 2 additions & 2 deletions gradio/components/annotated_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import warnings
from typing import Any, Literal
from typing import Any, List, Literal

import numpy as np
from gradio_client.documentation import document, set_documentation_group
Expand All @@ -26,7 +26,7 @@ class Annotation(GradioModel):

class AnnotatedImageData(GradioModel):
image: FileData
annotations: list[Annotation]
annotations: List[Annotation]


@document()
Expand Down

0 comments on commit e19d333

Please sign in to comment.