Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix remaining xfail tests in backend #6073

Merged
merged 9 commits into from Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changeset/rare-hornets-take.md
@@ -0,0 +1,6 @@
---
"gradio": minor
"gradio_client": minor
---

feat:Fix remaining xfail tests in backend
331 changes: 324 additions & 7 deletions client/python/gradio_client/client.py
Expand Up @@ -30,8 +30,9 @@
)
from packaging import version

from gradio_client import utils
from gradio_client import serializing, utils
from gradio_client.documentation import document, set_documentation_group
from gradio_client.exceptions import SerializationSetupError
from gradio_client.utils import (
Communicator,
JobStatus,
Expand Down Expand Up @@ -128,11 +129,15 @@ def __init__(
self.upload_url = urllib.parse.urljoin(self.src, utils.UPLOAD_URL)
self.reset_url = urllib.parse.urljoin(self.src, utils.RESET_URL)
self.config = self._get_config()
self.app_version = version.parse(self.config.get("version", "2.0"))
self._info = self._get_api_info()
self.session_hash = str(uuid.uuid4())

endpoint_class = Endpoint
if self.app_version < version.Version("4.0.0"):
endpoint_class = EndpointV3Compatibility
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice

self.endpoints = [
Endpoint(self, fn_index, dependency)
endpoint_class(self, fn_index, dependency)
for fn_index, dependency in enumerate(self.config["dependencies"])
]

Expand Down Expand Up @@ -360,9 +365,7 @@ def _get_api_info(self):
else:
api_info_url = urllib.parse.urljoin(self.src, utils.RAW_API_INFO_URL)

# Versions of Gradio older than 3.29.0 returned format of the API info
# from the /info endpoint
if version.parse(self.config.get("version", "2.0")) > version.Version("3.36.1"):
if self.app_version > version.Version("3.36.1"):
r = requests.get(api_info_url, headers=self.headers)
if r.ok:
info = r.json()
Expand Down Expand Up @@ -968,13 +971,21 @@ def _gather_files(self, *data):
file_list = []

def get_file(d):
file_list.append(d)
if utils.is_file_obj(d):
file_list.append(d["name"])
else:
file_list.append(d)
return ReplaceMe(len(file_list) - 1)

new_data = []
for i, d in enumerate(data):
if self.input_component_types[i].value_is_file:
d = utils.traverse(d, get_file, utils.is_filepath)
# Check file dicts and filepaths to upload
# file dict is a corner case but still needed for completeness
# most users should be using filepaths
d = utils.traverse(
d, get_file, lambda s: utils.is_file_obj(s) or utils.is_filepath(s)
)
new_data.append(d)
return file_list, new_data

Expand Down Expand Up @@ -1063,6 +1074,312 @@ async def _ws_fn(self, data, hash_data, helper: Communicator):
return await utils.get_pred_from_ws(websocket, data, hash_data, helper)


class EndpointV3Compatibility:
"""Endpoint class for connecting to v3 endpoints. Backwards compatibility."""

def __init__(self, client: Client, fn_index: int, dependency: dict):
self.client: Client = client
self.fn_index = fn_index
self.dependency = dependency
api_name = dependency.get("api_name")
self.api_name: str | Literal[False] | None = (
"/" + api_name if isinstance(api_name, str) else api_name
)
self.use_ws = self._use_websocket(self.dependency)
self.input_component_types = []
self.output_component_types = []
self.root_url = client.src + "/" if not client.src.endswith("/") else client.src
self.is_continuous = dependency.get("types", {}).get("continuous", False)
try:
# Only a real API endpoint if backend_fn is True (so not just a frontend function), serializers are valid,
# and api_name is not False (meaning that the developer has explicitly disabled the API endpoint)
self.serializers, self.deserializers = self._setup_serializers()
self.is_valid = self.dependency["backend_fn"] and self.api_name is not False
except SerializationSetupError:
self.is_valid = False

def __repr__(self):
return f"Endpoint src: {self.client.src}, api_name: {self.api_name}, fn_index: {self.fn_index}"

def __str__(self):
return self.__repr__()

def make_end_to_end_fn(self, helper: Communicator | None = None):
_predict = self.make_predict(helper)

def _inner(*data):
if not self.is_valid:
raise utils.InvalidAPIEndpointError()
data = self.insert_state(*data)
if self.client.serialize:
data = self.serialize(*data)
predictions = _predict(*data)
predictions = self.process_predictions(*predictions)
# Append final output only if not already present
# for consistency between generators and not generators
if helper:
with helper.lock:
if not helper.job.outputs:
helper.job.outputs.append(predictions)
return predictions

return _inner

def make_predict(self, helper: Communicator | None = None):
def _predict(*data) -> tuple:
data = json.dumps(
{
"data": data,
"fn_index": self.fn_index,
"session_hash": self.client.session_hash,
}
)
hash_data = json.dumps(
{
"fn_index": self.fn_index,
"session_hash": self.client.session_hash,
}
)
if self.use_ws:
result = utils.synchronize_async(self._ws_fn, data, hash_data, helper)
if "error" in result:
raise ValueError(result["error"])
else:
response = requests.post(
self.client.api_url, headers=self.client.headers, data=data
)
result = json.loads(response.content.decode("utf-8"))
try:
output = result["data"]
except KeyError as ke:
is_public_space = (
self.client.space_id
and not huggingface_hub.space_info(self.client.space_id).private
)
if "error" in result and "429" in result["error"] and is_public_space:
raise utils.TooManyRequestsError(
f"Too many requests to the API, please try again later. To avoid being rate-limited, "
f"please duplicate the Space using Client.duplicate({self.client.space_id}) "
f"and pass in your Hugging Face token."
) from None
elif "error" in result:
raise ValueError(result["error"]) from None
raise KeyError(
f"Could not find 'data' key in response. Response received: {result}"
) from ke
return tuple(output)

return _predict

def _predict_resolve(self, *data) -> Any:
"""Needed for gradio.load(), which has a slightly different signature for serializing/deserializing"""
outputs = self.make_predict()(*data)
if len(self.dependency["outputs"]) == 1:
return outputs[0]
return outputs

def _upload(
self, file_paths: list[str | list[str]]
) -> list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]]:
if not file_paths:
return []
# Put all the filepaths in one file
# but then keep track of which index in the
# original list they came from so we can recreate
# the original structure
files = []
indices = []
for i, fs in enumerate(file_paths):
if not isinstance(fs, list):
fs = [fs]
for f in fs:
files.append(("files", (Path(f).name, open(f, "rb")))) # noqa: SIM115
indices.append(i)
r = requests.post(
self.client.upload_url, headers=self.client.headers, files=files
)
if r.status_code != 200:
uploaded = file_paths
else:
uploaded = []
result = r.json()
for i, fs in enumerate(file_paths):
if isinstance(fs, list):
output = [o for ix, o in enumerate(result) if indices[ix] == i]
res = [
{
"is_file": True,
"name": o,
"orig_name": Path(f).name,
"data": None,
}
for f, o in zip(fs, output)
]
else:
o = next(o for ix, o in enumerate(result) if indices[ix] == i)
res = {
"is_file": True,
"name": o,
"orig_name": Path(fs).name,
"data": None,
}
uploaded.append(res)
return uploaded

def _add_uploaded_files_to_data(
self,
files: list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]],
data: list[Any],
) -> None:
"""Helper function to modify the input data with the uploaded files."""
file_counter = 0
for i, t in enumerate(self.input_component_types):
if t in ["file", "uploadbutton"]:
data[i] = files[file_counter]
file_counter += 1

def insert_state(self, *data) -> tuple:
data = list(data)
for i, input_component_type in enumerate(self.input_component_types):
if input_component_type == utils.STATE_COMPONENT:
data.insert(i, None)
return tuple(data)

def remove_skipped_components(self, *data) -> tuple:
data = [
d
for d, oct in zip(data, self.output_component_types)
if oct not in utils.SKIP_COMPONENTS
]
return tuple(data)

def reduce_singleton_output(self, *data) -> Any:
if (
len(
[
oct
for oct in self.output_component_types
if oct not in utils.SKIP_COMPONENTS
]
)
== 1
):
return data[0]
else:
return data

def serialize(self, *data) -> tuple:
if len(data) != len(self.serializers):
raise ValueError(
f"Expected {len(self.serializers)} arguments, got {len(data)}"
)

files = [
f
for f, t in zip(data, self.input_component_types)
if t in ["file", "uploadbutton"]
]
uploaded_files = self._upload(files)
data = list(data)
self._add_uploaded_files_to_data(uploaded_files, data)
o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)])
return o

def deserialize(self, *data) -> tuple:
if len(data) != len(self.deserializers):
raise ValueError(
f"Expected {len(self.deserializers)} outputs, got {len(data)}"
)
outputs = tuple(
[
s.deserialize(
d,
save_dir=self.client.output_dir,
hf_token=self.client.hf_token,
root_url=self.root_url,
)
for s, d in zip(self.deserializers, data)
]
)
return outputs

def process_predictions(self, *predictions):
if self.client.serialize:
predictions = self.deserialize(*predictions)
predictions = self.remove_skipped_components(*predictions)
predictions = self.reduce_singleton_output(*predictions)
return predictions

def _setup_serializers(
self,
) -> tuple[list[serializing.Serializable], list[serializing.Serializable]]:
inputs = self.dependency["inputs"]
serializers = []

for i in inputs:
for component in self.client.config["components"]:
if component["id"] == i:
component_name = component["type"]
self.input_component_types.append(component_name)
if component.get("serializer"):
serializer_name = component["serializer"]
if serializer_name not in serializing.SERIALIZER_MAPPING:
raise SerializationSetupError(
f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
)
serializer = serializing.SERIALIZER_MAPPING[serializer_name]
elif component_name in serializing.COMPONENT_MAPPING:
serializer = serializing.COMPONENT_MAPPING[component_name]
else:
raise SerializationSetupError(
f"Unknown component: {component_name}, you may need to update your gradio_client version."
)
serializers.append(serializer()) # type: ignore

outputs = self.dependency["outputs"]
deserializers = []
for i in outputs:
for component in self.client.config["components"]:
if component["id"] == i:
component_name = component["type"]
self.output_component_types.append(component_name)
if component.get("serializer"):
serializer_name = component["serializer"]
if serializer_name not in serializing.SERIALIZER_MAPPING:
raise SerializationSetupError(
f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
)
deserializer = serializing.SERIALIZER_MAPPING[serializer_name]
elif component_name in utils.SKIP_COMPONENTS:
deserializer = serializing.SimpleSerializable
elif component_name in serializing.COMPONENT_MAPPING:
deserializer = serializing.COMPONENT_MAPPING[component_name]
else:
raise SerializationSetupError(
f"Unknown component: {component_name}, you may need to update your gradio_client version."
)
deserializers.append(deserializer()) # type: ignore

return serializers, deserializers

def _use_websocket(self, dependency: dict) -> bool:
queue_enabled = self.client.config.get("enable_queue", False)
queue_uses_websocket = version.parse(
self.client.config.get("version", "2.0")
) >= version.Version("3.2")
dependency_uses_queue = dependency.get("queue", False) is not False
return queue_enabled and queue_uses_websocket and dependency_uses_queue

async def _ws_fn(self, data, hash_data, helper: Communicator):
async with websockets.connect( # type: ignore
self.client.ws_url,
open_timeout=10,
extra_headers=self.client.headers,
max_size=1024 * 1024 * 1024,
) as websocket:
return await utils.get_pred_from_ws(websocket, data, hash_data, helper)


@document("result", "outputs", "status")
class Job(Future):
"""
Expand Down