Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache view_api info in server and python client #7888

Merged
merged 3 commits into from Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changeset/legal-teams-camp.md
@@ -0,0 +1,6 @@
---
"gradio": patch
"gradio_client": patch
---

fix:Cache view_api info in server and python client
7 changes: 2 additions & 5 deletions client/python/gradio_client/client.py
Expand Up @@ -169,7 +169,7 @@ def __init__(
self.upload_url = urllib.parse.urljoin(self.src, utils.UPLOAD_URL)
self.reset_url = urllib.parse.urljoin(self.src, utils.RESET_URL)
self.app_version = version.parse(self.config.get("version", "2.0"))
self._info = None
self._info = self._get_api_info()
self.session_hash = str(uuid.uuid4())

endpoint_class = (
Expand Down Expand Up @@ -611,8 +611,6 @@ def view_api(
}

"""
if not self._info:
self._info = self._get_api_info()
num_named_endpoints = len(self._info["named_endpoints"])
num_unnamed_endpoints = len(self._info["unnamed_endpoints"])
if num_named_endpoints == 0 and all_endpoints is None:
Expand Down Expand Up @@ -1000,6 +998,7 @@ def __init__(
self.api_name: str | Literal[False] | None = (
"/" + api_name if isinstance(api_name, str) else api_name
)
self._info = self.client._info
self.protocol = protocol
self.input_component_types = [
self._get_component_type(id_) for id_ in dependency["inputs"]
Expand Down Expand Up @@ -1029,8 +1028,6 @@ def _get_component_type(self, component_id: int):
)

def _get_parameters_info(self) -> list[ParameterInfo] | None:
if not self.client._info:
self._info = self.client._get_api_info()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops 馃槶

if self.api_name in self._info["named_endpoints"]:
return self._info["named_endpoints"][self.api_name]["parameters"]
return None
Expand Down
17 changes: 17 additions & 0 deletions client/python/test/conftest.py
Expand Up @@ -428,3 +428,20 @@ def long_response(_):
None,
gr.Textbox(label="Output"),
)


@pytest.fixture
def many_endpoint_demo():
with gr.Blocks() as demo:

def noop(x):
return x

n_elements = 1000
for _ in range(n_elements):
msg2 = gr.Textbox()
msg2.submit(noop, msg2, msg2)
butn2 = gr.Button()
butn2.click(noop, msg2, msg2)

return demo
8 changes: 8 additions & 0 deletions client/python/test/test_client.py
Expand Up @@ -74,6 +74,14 @@ def test_headers_constructed_correctly(self):
)
assert {"authorization": "Bearer abcde"}.items() <= client.headers.items()

def test_many_endpoint_demo_loads_quickly(self, many_endpoint_demo):
import datetime

start = datetime.datetime.now()
with connect(many_endpoint_demo):
pass
assert (datetime.datetime.now() - start).seconds < 5


class TestClientPredictions:
@pytest.mark.flaky
Expand Down
6 changes: 5 additions & 1 deletion gradio/routes.py
Expand Up @@ -163,6 +163,7 @@ def __init__(
self.change_event: None | threading.Event = None
self._asyncio_tasks: list[asyncio.Task] = []
self.auth_dependency = auth_dependency
self.api_info = None
# Allow user to manually set `docs_url` and `redoc_url`
# when instantiating an App; when they're not set, disable docs and redoc.
kwargs.setdefault("docs_url", None)
Expand Down Expand Up @@ -391,7 +392,10 @@ def main(request: fastapi.Request, user: str = Depends(get_current_user)):
@app.get("/info/", dependencies=[Depends(login_check)])
@app.get("/info", dependencies=[Depends(login_check)])
def api_info():
return app.get_blocks().get_api_info() # type: ignore
# The api info is set in create_app
if not app.api_info:
app.api_info = app.get_blocks().get_api_info()
return app.api_info

@app.get("/config/", dependencies=[Depends(login_check)])
@app.get("/config", dependencies=[Depends(login_check)])
Expand Down