Add support for gr.Request to gr.ChatInterface (#5819)
* Add support for gr.Request to gr.ChatInterface

* add changeset

* gr.ChatInterface: loose check for gr.Request

* add request test

* update test and chat_interface

* chat interface

* fix test

* formatting

* fixes

* fix examples and add test

* remove .update usage

* revert interface changes

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
3 people committed Oct 10, 2023
1 parent 4633478 commit 5f1cbc4
Showing 4 changed files with 86 additions and 21 deletions.
5 changes: 5 additions & 0 deletions .changeset/giant-beans-clap.md
@@ -0,0 +1,5 @@
+---
+"gradio": minor
+---
+
+feat:Add support for gr.Request to gr.ChatInterface
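
In practice, the feature lets a chat function opt into the request simply by annotating a parameter, the same way `gr.Interface` functions already could. A minimal sketch of the resulting usage (the function name and reply format are illustrative, not from this commit):

```python
import gradio as gr

def echo_with_host(message, history, request: gr.Request):
    # ChatInterface now detects the `gr.Request` annotation and injects
    # the live request; request.client.host is the caller's address,
    # as exercised by the tests below.
    return f"{request.client.host} said: {message}"

gr.ChatInterface(echo_with_host).launch()
```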
63 changes: 44 additions & 19 deletions gradio/chat_interface.py
@@ -24,7 +24,9 @@
 )
 from gradio.events import Dependency, EventListenerMethod, on
 from gradio.helpers import create_examples as Examples  # noqa: N812
+from gradio.helpers import special_args
 from gradio.layouts import Accordion, Column, Group, Row
+from gradio.routes import Request
 from gradio.themes import ThemeClass as Theme
 from gradio.utils import SyncToAsyncIterator, async_iteration
 
@@ -332,16 +334,16 @@ def _setup_stop_events(
             for event_trigger in event_triggers:
                 event_trigger(
                     lambda: (
-                        Button.update(visible=False),
-                        Button.update(visible=True),
+                        Button(visible=False),
+                        Button(visible=True),
                     ),
                     None,
                     [self.submit_btn, self.stop_btn],
                     api_name=False,
                     queue=False,
                 )
             event_to_cancel.then(
-                lambda: (Button.update(visible=True), Button.update(visible=False)),
+                lambda: (Button(visible=True), Button(visible=False)),
                 None,
                 [self.submit_btn, self.stop_btn],
                 api_name=False,
@@ -350,14 +352,14 @@
         else:
             for event_trigger in event_triggers:
                 event_trigger(
-                    lambda: Button.update(visible=True),
+                    lambda: Button(visible=True),
                     None,
                     [self.stop_btn],
                     api_name=False,
                     queue=False,
                 )
             event_to_cancel.then(
-                lambda: Button.update(visible=False),
+                lambda: Button(visible=False),
                 None,
                 [self.stop_btn],
                 api_name=False,
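
The `Button.update(...)` → `Button(...)` edits ride along with this commit: in newer Gradio, returning a component instance from a callback updates the existing component's props rather than constructing a new widget. A small sketch of the idiom (assuming a current Gradio install; the demo itself is not part of this commit):

```python
import gradio as gr

with gr.Blocks() as demo:
    btn = gr.Button("Hide me")
    # Returning a Button instance updates the existing button's props
    # (visibility here), replacing the deprecated Button.update(...).
    btn.click(lambda: gr.Button(visible=False), None, [btn])

demo.launch()
```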
@@ -394,30 +396,41 @@ async def _submit_fn(
         self,
         message: str,
         history_with_input: list[list[str | None]],
+        request: Request,
         *args,
     ) -> tuple[list[list[str | None]], list[list[str | None]]]:
         history = history_with_input[:-1]
+        inputs, _, _ = special_args(
+            self.fn, inputs=[message, history, *args], request=request
+        )
+
         if self.is_async:
-            response = await self.fn(message, history, *args)
+            response = await self.fn(*inputs)
         else:
             response = await anyio.to_thread.run_sync(
-                self.fn, message, history, *args, limiter=self.limiter
+                self.fn, *inputs, limiter=self.limiter
             )
 
         history.append([message, response])
         return history, history
 
     async def _stream_fn(
         self,
         message: str,
         history_with_input: list[list[str | None]],
+        request: Request,
         *args,
     ) -> AsyncGenerator:
         history = history_with_input[:-1]
+        inputs, _, _ = special_args(
+            self.fn, inputs=[message, history, *args], request=request
+        )
+
         if self.is_async:
-            generator = self.fn(message, history, *args)
+            generator = self.fn(*inputs)
         else:
             generator = await anyio.to_thread.run_sync(
-                self.fn, message, history, *args, limiter=self.limiter
+                self.fn, *inputs, limiter=self.limiter
             )
         generator = SyncToAsyncIterator(generator, self.limiter)
         try:
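
Both runners now route their arguments through `special_args`, which inspects the user function's signature and splices the `Request` into the positional inputs wherever a `gr.Request`-annotated parameter sits. A deliberately simplified sketch of that splicing (the real helper lives in `gradio/helpers.py`, uses a looser annotation check per the commit notes, and also handles Progress and EventData):

```python
import inspect

class Request:
    """Stand-in for gr.Request, purely for this sketch."""

def splice_request(fn, inputs, request):
    # Walk fn's signature; wherever a parameter is annotated with Request,
    # insert the live request object at that position in the inputs list.
    spliced = list(inputs)
    for i, param in enumerate(inspect.signature(fn).parameters.values()):
        if param.annotation is Request:
            spliced.insert(i, request)
    return spliced

def chat_fn(message, history, request: Request):
    return message

req = Request()
assert splice_request(chat_fn, ["hi", []], req) == ["hi", [], req]
```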
@@ -432,25 +445,33 @@ async def _stream_fn
             yield update, update
 
     async def _api_submit_fn(
-        self, message: str, history: list[list[str | None]], *args
+        self, message: str, history: list[list[str | None]], request: Request, *args
     ) -> tuple[str, list[list[str | None]]]:
+        inputs, _, _ = special_args(
+            self.fn, inputs=[message, history, *args], request=request
+        )
+
         if self.is_async:
-            response = await self.fn(message, history, *args)
+            response = await self.fn(*inputs)
         else:
             response = await anyio.to_thread.run_sync(
-                self.fn, message, history, *args, limiter=self.limiter
+                self.fn, *inputs, limiter=self.limiter
             )
         history.append([message, response])
         return response, history
 
     async def _api_stream_fn(
-        self, message: str, history: list[list[str | None]], *args
+        self, message: str, history: list[list[str | None]], request: Request, *args
     ) -> AsyncGenerator:
+        inputs, _, _ = special_args(
+            self.fn, inputs=[message, history, *args], request=request
+        )
+
         if self.is_async:
-            generator = self.fn(message, history, *args)
+            generator = self.fn(*inputs)
         else:
             generator = await anyio.to_thread.run_sync(
-                self.fn, message, history, *args, limiter=self.limiter
+                self.fn, *inputs, limiter=self.limiter
             )
         generator = SyncToAsyncIterator(generator, self.limiter)
         try:
@@ -462,11 +483,13 @@ async def _api_stream_fn
             yield response, history + [[message, response]]
 
     async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
+        inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)
+
         if self.is_async:
-            response = await self.fn(message, [], *args)
+            response = await self.fn(*inputs)
         else:
             response = await anyio.to_thread.run_sync(
-                self.fn, message, [], *args, limiter=self.limiter
+                self.fn, *inputs, limiter=self.limiter
             )
         return [[message, response]]
 
@@ -475,11 +498,13 @@ async def _examples_stream_fn
         message: str,
         *args,
     ) -> AsyncGenerator:
+        inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)
+
         if self.is_async:
-            generator = self.fn(message, [], *args)
+            generator = self.fn(*inputs)
         else:
             generator = await anyio.to_thread.run_sync(
-                self.fn, message, [], *args, limiter=self.limiter
+                self.fn, *inputs, limiter=self.limiter
             )
         generator = SyncToAsyncIterator(generator, self.limiter)
         async for response in generator:
2 changes: 1 addition & 1 deletion gradio/helpers.py
@@ -710,7 +710,7 @@ def special_args(
     inputs: list[Any] | None = None,
     request: routes.Request | None = None,
    event_data: EventData | None = None,
-):
+) -> tuple[list, int | None, int | None]:
     """
     Checks if function has special arguments Request or EventData (via annotation) or Progress (via default value).
     If inputs is provided, these values will be loaded into the inputs array.
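
The sharpened return annotation documents why ChatInterface unpacks three values: the helper hands back the spliced inputs plus the indices (or `None`) of any Progress and EventData parameters. A quick, hedged illustration of that contract, inferred from the diff's usage rather than the full helper source:

```python
import gradio as gr
from gradio.helpers import special_args

def chat_fn(message, history, request: gr.Request):
    return message

# No Progress/EventData parameters here, so both indices come back as
# None, and request=None is spliced into the annotated position (as the
# examples runners above rely on).
inputs, progress_index, event_data_index = special_args(
    chat_fn, inputs=["hi", []], request=None
)
assert inputs == ["hi", [], None]
assert progress_index is None and event_data_index is None
```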
37 changes: 36 additions & 1 deletion test/test_routes.py
@@ -516,7 +516,7 @@ def read_main():
 
 
 class TestPassingRequest:
-    def test_request_included_with_regular_function(self):
+    def test_request_included_with_interface(self):
         def identity(name, request: gr.Request):
             assert isinstance(request.client.host, str)
             return name
@@ -531,6 +531,41 @@ def identity(name, request: gr.Request):
         output = dict(response.json())
         assert output["data"] == ["test"]
 
+    def test_request_included_with_chat_interface(self):
+        def identity(x, y, request: gr.Request):
+            assert isinstance(request.client.host, str)
+            return x
+
+        app, _, _ = gr.ChatInterface(identity).launch(
+            prevent_thread_lock=True,
+        )
+        client = TestClient(app)
+
+        response = client.post("/api/chat/", json={"data": ["test", None]})
+        assert response.status_code == 200
+        output = dict(response.json())
+        assert output["data"] == ["test", None]
+
+    def test_request_included_with_chat_interface_when_streaming(self):
+        def identity(x, y, request: gr.Request):
+            assert isinstance(request.client.host, str)
+            for i in range(len(x)):
+                yield x[: i + 1]
+
+        app, _, _ = (
+            gr.ChatInterface(identity)
+            .queue()
+            .launch(
+                prevent_thread_lock=True,
+            )
+        )
+        client = TestClient(app)
+
+        response = client.post("/api/chat/", json={"data": ["test", None]})
+        assert response.status_code == 200
+        output = dict(response.json())
+        assert output["data"] == ["t", None]
+
     def test_request_get_headers(self):
         def identity(name, request: gr.Request):
             assert isinstance(request.headers["user-agent"], str)
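
Note what the streaming test asserts: a plain POST to `/api/chat/` returns only the first chunk the generator yields ("t" out of "test"), which is enough to verify that the request object reaches a streaming chat function without consuming the whole stream through the queue.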
