Skip to content

Commit

Permalink
Convert sse calls in client from async to sync (#8182)
Browse files Browse the repository at this point in the history
* convert sse calls in client from async to sync

* add changeset

* more sync

* lint

* more sync

* fix threadpool

* fix timeouts

* reuse executor

* lint

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
abidlabs and gradio-pr-bot committed May 2, 2024
1 parent d0a759f commit 39791eb
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 102 deletions.
6 changes: 6 additions & 0 deletions .changeset/great-poets-visit.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"gradio": patch
"gradio_client": patch
---

fix:Convert sse calls in client from async to sync
49 changes: 22 additions & 27 deletions client/python/gradio_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,21 +225,21 @@ def _stream_heartbeat(self):
except httpx.TransportError:
return

async def stream_messages(
def stream_messages(
self, protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"]
) -> None:
try:
async with httpx.AsyncClient(
with httpx.Client(
timeout=httpx.Timeout(timeout=None), verify=self.ssl_verify
) as client:
async with client.stream(
with client.stream(
"GET",
self.sse_url,
params={"session_hash": self.session_hash},
headers=self.headers,
cookies=self.cookies,
) as response:
async for line in response.aiter_lines():
for line in response.iter_lines():
line = line.rstrip("\n")
if not len(line):
continue
Expand Down Expand Up @@ -276,14 +276,13 @@ async def stream_messages(
traceback.print_exc()
raise e

async def send_data(self, data, hash_data, protocol):
async with httpx.AsyncClient(verify=self.ssl_verify) as client:
req = await client.post(
self.sse_data_url,
json={**data, **hash_data},
headers=self.headers,
cookies=self.cookies,
)
def send_data(self, data, hash_data, protocol):
req = httpx.post(
self.sse_data_url,
json={**data, **hash_data},
headers=self.headers,
cookies=self.cookies,
)
if req.status_code == 503:
raise QueueError("Queue is full! Please try again.")
req.raise_for_status()
Expand All @@ -294,7 +293,7 @@ async def send_data(self, data, hash_data, protocol):
self.stream_open = True

def open_stream():
return utils.synchronize_async(self.stream_messages, protocol)
return self.stream_messages(protocol)

def close_stream(_):
self.stream_open = False
Expand Down Expand Up @@ -1119,18 +1118,12 @@ def _predict(*data) -> tuple:
}

if self.protocol == "sse":
result = utils.synchronize_async(
self._sse_fn_v0, data, hash_data, helper
)
result = self._sse_fn_v0(data, hash_data, helper) # type: ignore
elif self.protocol in ("sse_v1", "sse_v2", "sse_v2.1", "sse_v3"):
event_id = utils.synchronize_async(
self.client.send_data, data, hash_data, self.protocol
)
event_id = self.client.send_data(data, hash_data, self.protocol)
self.client.pending_event_ids.add(event_id)
self.client.pending_messages_per_event[event_id] = []
result = utils.synchronize_async(
self._sse_fn_v1plus, helper, event_id, self.protocol
)
result = self._sse_fn_v1plus(helper, event_id, self.protocol)
else:
raise ValueError(f"Unsupported protocol: {self.protocol}")

Expand Down Expand Up @@ -1290,11 +1283,11 @@ def _download_file(self, x: dict) -> str:
shutil.move(temp_dir / Path(url_path).name, dest)
return str(dest.resolve())

async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
async with httpx.AsyncClient(
def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
with httpx.Client(
timeout=httpx.Timeout(timeout=None), verify=self.client.ssl_verify
) as client:
return await utils.get_pred_from_sse_v0(
return utils.get_pred_from_sse_v0(
client,
data,
hash_data,
Expand All @@ -1304,22 +1297,24 @@ async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
self.client.headers,
self.client.cookies,
self.client.ssl_verify,
self.client.executor,
)

async def _sse_fn_v1plus(
def _sse_fn_v1plus(
self,
helper: Communicator,
event_id: str,
protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"],
):
return await utils.get_pred_from_sse_v1plus(
return utils.get_pred_from_sse_v1plus(
helper,
self.client.headers,
self.client.cookies,
self.client.pending_messages_per_event,
event_id,
protocol,
self.client.ssl_verify,
self.client.executor,
)


Expand Down

0 comments on commit 39791eb

Please sign in to comment.