From 5f1cbc4363b09302334e9bc864587f8ef398550d Mon Sep 17 00:00:00 2001 From: D V <77478658+DarhkVoyd@users.noreply.github.com> Date: Wed, 11 Oct 2023 02:18:20 +0530 Subject: [PATCH] Add support for gr.Request to gr.ChatInterface (#5819) * Add support for gr.Request to gr.ChatInterface * add changeset * gr.ChatInterface: loose check for gr.Request * add request test * update test and chat_interface * chat interface * fix test * formatting * fixes * fix examples and add test * remove .update usage * revert interface changes --------- Co-authored-by: gradio-pr-bot Co-authored-by: Abubakar Abid --- .changeset/giant-beans-clap.md | 5 +++ gradio/chat_interface.py | 63 ++++++++++++++++++++++++---------- gradio/helpers.py | 2 +- test/test_routes.py | 37 +++++++++++++++++++- 4 files changed, 86 insertions(+), 21 deletions(-) create mode 100644 .changeset/giant-beans-clap.md diff --git a/.changeset/giant-beans-clap.md b/.changeset/giant-beans-clap.md new file mode 100644 index 000000000000..e12a99679f40 --- /dev/null +++ b/.changeset/giant-beans-clap.md @@ -0,0 +1,5 @@ +--- +"gradio": minor +--- + +feat:Add support for gr.Request to gr.ChatInterface diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py index a27d92c92d5c..740d7f50e5d1 100644 --- a/gradio/chat_interface.py +++ b/gradio/chat_interface.py @@ -24,7 +24,9 @@ ) from gradio.events import Dependency, EventListenerMethod, on from gradio.helpers import create_examples as Examples # noqa: N812 +from gradio.helpers import special_args from gradio.layouts import Accordion, Column, Group, Row +from gradio.routes import Request from gradio.themes import ThemeClass as Theme from gradio.utils import SyncToAsyncIterator, async_iteration @@ -332,8 +334,8 @@ def _setup_stop_events( for event_trigger in event_triggers: event_trigger( lambda: ( - Button.update(visible=False), - Button.update(visible=True), + Button(visible=False), + Button(visible=True), ), None, [self.submit_btn, self.stop_btn], @@ -341,7 +343,7 @@ def _setup_stop_events( queue=False, ) event_to_cancel.then( - lambda: (Button.update(visible=True), Button.update(visible=False)), + lambda: (Button(visible=True), Button(visible=False)), None, [self.submit_btn, self.stop_btn], api_name=False, @@ -350,14 +352,14 @@ def _setup_stop_events( else: for event_trigger in event_triggers: event_trigger( - lambda: Button.update(visible=True), + lambda: Button(visible=True), None, [self.stop_btn], api_name=False, queue=False, ) event_to_cancel.then( - lambda: Button.update(visible=False), + lambda: Button(visible=False), None, [self.stop_btn], api_name=False, @@ -394,15 +396,21 @@ async def _submit_fn( self, message: str, history_with_input: list[list[str | None]], + request: Request, *args, ) -> tuple[list[list[str | None]], list[list[str | None]]]: history = history_with_input[:-1] + inputs, _, _ = special_args( + self.fn, inputs=[message, history, *args], request=request + ) + if self.is_async: - response = await self.fn(message, history, *args) + response = await self.fn(*inputs) else: response = await anyio.to_thread.run_sync( - self.fn, message, history, *args, limiter=self.limiter + self.fn, *inputs, limiter=self.limiter ) + history.append([message, response]) return history, history @@ -410,14 +418,19 @@ async def _stream_fn( self, message: str, history_with_input: list[list[str | None]], + request: Request, *args, ) -> AsyncGenerator: history = history_with_input[:-1] + inputs, _, _ = special_args( + self.fn, inputs=[message, history, *args], request=request + ) + if self.is_async: - generator = self.fn(message, history, *args) + generator = self.fn(*inputs) else: generator = await anyio.to_thread.run_sync( - self.fn, message, history, *args, limiter=self.limiter + self.fn, *inputs, limiter=self.limiter ) generator = SyncToAsyncIterator(generator, self.limiter) try: @@ -432,25 +445,33 @@ async def _stream_fn( yield update, update async def _api_submit_fn( - self, message: str, history: list[list[str | None]], *args + self, message: str, history: list[list[str | None]], request: Request, *args ) -> tuple[str, list[list[str | None]]]: + inputs, _, _ = special_args( + self.fn, inputs=[message, history, *args], request=request + ) + if self.is_async: - response = await self.fn(message, history, *args) + response = await self.fn(*inputs) else: response = await anyio.to_thread.run_sync( - self.fn, message, history, *args, limiter=self.limiter + self.fn, *inputs, limiter=self.limiter ) history.append([message, response]) return response, history async def _api_stream_fn( - self, message: str, history: list[list[str | None]], *args + self, message: str, history: list[list[str | None]], request: Request, *args ) -> AsyncGenerator: + inputs, _, _ = special_args( + self.fn, inputs=[message, history, *args], request=request + ) + if self.is_async: - generator = self.fn(message, history, *args) + generator = self.fn(*inputs) else: generator = await anyio.to_thread.run_sync( - self.fn, message, history, *args, limiter=self.limiter + self.fn, *inputs, limiter=self.limiter ) generator = SyncToAsyncIterator(generator, self.limiter) try: @@ -462,11 +483,13 @@ async def _api_stream_fn( yield response, history + [[message, response]] async def _examples_fn(self, message: str, *args) -> list[list[str | None]]: + inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None) + if self.is_async: - response = await self.fn(message, [], *args) + response = await self.fn(*inputs) else: response = await anyio.to_thread.run_sync( - self.fn, message, [], *args, limiter=self.limiter + self.fn, *inputs, limiter=self.limiter ) return [[message, response]] @@ -475,11 +498,13 @@ async def _examples_stream_fn( message: str, *args, ) -> AsyncGenerator: + inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None) + if self.is_async: - generator = self.fn(message, [], *args) + generator = self.fn(*inputs) else: generator = await anyio.to_thread.run_sync( - self.fn, message, [], *args, limiter=self.limiter + self.fn, *inputs, limiter=self.limiter ) generator = SyncToAsyncIterator(generator, self.limiter) async for response in generator: diff --git a/gradio/helpers.py b/gradio/helpers.py index 70df46d9f95d..3e6e73638180 100644 --- a/gradio/helpers.py +++ b/gradio/helpers.py @@ -710,7 +710,7 @@ def special_args( inputs: list[Any] | None = None, request: routes.Request | None = None, event_data: EventData | None = None, -): +) -> tuple[list, int | None, int | None]: """ Checks if function has special arguments Request or EventData (via annotation) or Progress (via default value). If inputs is provided, these values will be loaded into the inputs array. diff --git a/test/test_routes.py b/test/test_routes.py index fc0a38fc290f..39773420ab9b 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -516,7 +516,7 @@ def read_main(): class TestPassingRequest: - def test_request_included_with_regular_function(self): + def test_request_included_with_interface(self): def identity(name, request: gr.Request): assert isinstance(request.client.host, str) return name @@ -531,6 +531,41 @@ def identity(name, request: gr.Request): output = dict(response.json()) assert output["data"] == ["test"] + def test_request_included_with_chat_interface(self): + def identity(x, y, request: gr.Request): + assert isinstance(request.client.host, str) + return x + + app, _, _ = gr.ChatInterface(identity).launch( + prevent_thread_lock=True, + ) + client = TestClient(app) + + response = client.post("/api/chat/", json={"data": ["test", None]}) + assert response.status_code == 200 + output = dict(response.json()) + assert output["data"] == ["test", None] + + def test_request_included_with_chat_interface_when_streaming(self): + def identity(x, y, request: gr.Request): + assert isinstance(request.client.host, str) + for i in range(len(x)): + yield x[: i + 1] + + app, _, _ = ( + gr.ChatInterface(identity) + .queue() + .launch( + prevent_thread_lock=True, + ) + ) + client = TestClient(app) + + response = client.post("/api/chat/", json={"data": ["test", None]}) + assert response.status_code == 200 + output = dict(response.json()) + assert output["data"] == ["t", None] + def test_request_get_headers(self): def identity(name, request: gr.Request): assert isinstance(request.headers["user-agent"], str)