Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable streaming audio in python client #5248

Merged
merged 12 commits into from Aug 21, 2023
Merged
30 changes: 30 additions & 0 deletions .changeset/chilly-fans-make.md
@@ -0,0 +1,30 @@
---
"gradio": minor
"gradio_client": minor
---

highlight:

#### Enable streaming audio in python client

The `gradio_client` now supports streaming file outputs 🌊

No new syntax! Connect to a gradio demo that supports streaming file outputs and call `predict` or `submit` as you normally would.

```python
import gradio_client as grc
client = grc.Client("gradio/stream_audio_out")

# Get the entire generated audio as a local file
client.predict("/Users/freddy/Pictures/bark_demo.mp4", api_name="/predict")

job = client.submit("/Users/freddy/Pictures/bark_demo.mp4", api_name="/predict")

# Get the entire generated audio as a local file
job.result()

# Each individual chunk
job.outputs()
```


2 changes: 2 additions & 0 deletions .github/workflows/backend.yml
Expand Up @@ -98,6 +98,8 @@ jobs:
run: |
. venv/bin/activate
python -m pip install -r client/python/test/requirements.txt
- name: Install ffmpeg
uses: FedericoCarboni/setup-ffmpeg@v2
- name: Install Gradio and Client Libraries Locally (Linux)
if: runner.os == 'Linux'
run: |
Expand Down
2 changes: 2 additions & 0 deletions client/python/gradio_client/data_classes.py
Expand Up @@ -13,3 +13,5 @@ class FileData(TypedDict):
bool
] # whether the data corresponds to a file or base64 encoded data
orig_name: NotRequired[str] # original filename
mime_type: NotRequired[str]
is_stream: NotRequired[bool]
23 changes: 23 additions & 0 deletions client/python/gradio_client/serializing.py
Expand Up @@ -2,6 +2,8 @@

import json
import os
import secrets
import tempfile
import uuid
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -204,6 +206,11 @@ def deserialize(
class FileSerializable(Serializable):
"""Expects a dict with base64 representation of object as input/output which is serialized to a filepath."""

def __init__(self) -> None:
    """Start with no active byte stream, then run the parent initializer."""
    # No stream is open yet, so both the stream and the name it belongs to are unset.
    self.stream = self.stream_name = None
    super().__init__()

def serialized_info(self):
return self._single_file_serialized_info()

Expand Down Expand Up @@ -268,6 +275,9 @@ def _serialize_single(
"size": size,
}

def _setup_stream(self, url, hf_token):
    """Return the byte stream that ``utils.download_byte_stream`` opens for ``url``.

    Parameters:
        url: address of the stream to download.
        hf_token: token forwarded unchanged to ``utils.download_byte_stream``.
    """
    byte_stream = utils.download_byte_stream(url, hf_token)
    return byte_stream

def _deserialize_single(
self,
x: str | FileData | None,
Expand All @@ -291,6 +301,19 @@ def _deserialize_single(
)
else:
file_name = utils.create_tmp_copy_of_file(filepath, dir=save_dir)
elif x.get("is_stream"):
assert x["name"] and root_url and save_dir
if not self.stream or self.stream_name != x["name"]:
self.stream = self._setup_stream(
root_url + "stream/" + x["name"], hf_token=hf_token
)
self.stream_name = x["name"]
chunk = next(self.stream)
path = Path(save_dir or tempfile.gettempdir()) / secrets.token_hex(20)
path.mkdir(parents=True, exist_ok=True)
path = path / x.get("orig_name", "output")
path.write_bytes(chunk)
file_name = str(path)
else:
data = x.get("data")
assert data is not None, f"The 'data' field is missing in {x}"
Expand Down
10 changes: 10 additions & 0 deletions client/python/gradio_client/utils.py
Expand Up @@ -387,6 +387,16 @@ def encode_url_or_file_to_base64(path: str | Path):
return encode_file_to_base64(path)


def download_byte_stream(url: str, hf_token=None):
    """Yield the bytes served at ``url`` chunk by chunk, then the whole payload.

    Every chunk read from the response is yielded as soon as it arrives. Once
    the response is exhausted, one extra item is yielded: a ``bytearray``
    holding all of the chunks concatenated in arrival order.

    Parameters:
        url: address the GET request is sent to.
        hf_token: optional token; when truthy it is sent as a Bearer
            ``Authorization`` header, otherwise no extra headers are sent.
    """
    headers = {}
    if hf_token:
        headers["Authorization"] = "Bearer " + hf_token

    received = bytearray()
    with httpx.stream("GET", url, headers=headers) as response:
        for chunk in response.iter_bytes():
            # Keep a running copy so the complete payload can be yielded last.
            received.extend(chunk)
            yield chunk
    yield received
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I understanding correctly that what's happening is that we are iterating chunk by chunk and then finally yielding the full array so that .predict() returns the full data array? 2 questions:

(1) How come outputs() doesn't include this final full data array?
(2) Will r.iter_bytes() work if the streaming response returns binary data (as in #5238) instead of file data?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(1) outputs only includes the values returned from the server during the process_generating event
(2) yea the stream endpoint only serves bytes so it should work when #5238 is merged



def decode_base64_to_binary(encoding: str) -> tuple[bytes, str | None]:
extension = get_extension(encoding)
data = encoding.rsplit(",", 1)[-1]
Expand Down
26 changes: 26 additions & 0 deletions client/python/test/conftest.py
Expand Up @@ -4,6 +4,7 @@

import gradio as gr
import pytest
from pydub import AudioSegment


def pytest_configure(config):
Expand Down Expand Up @@ -297,6 +298,31 @@ def greeting(name, state):
return demo


@pytest.fixture
def stream_audio():
    """Build a queued Interface whose function streams an audio file back in chunks."""
    import pathlib
    import tempfile

    def _stream_audio(audio_file):
        # Cut the loaded audio into consecutive fixed-size slices and yield
        # each non-empty slice as its own wav file in the temp directory.
        audio = AudioSegment.from_mp3(audio_file)
        chunk_size = 3000
        slice_starts = range(0, len(audio), chunk_size)

        # Numbering starts at 1 because the file name uses the 1-based slice count.
        for number, start in enumerate(slice_starts, start=1):
            chunk = audio[start : start + chunk_size]
            if not chunk:
                continue
            file = str(pathlib.Path(tempfile.gettempdir()) / f"{number}.wav")
            chunk.export(file, format="wav")
            yield file

    audio_input = gr.Audio(type="filepath", label="Audio file to stream")
    audio_output = gr.Audio(autoplay=True, streaming=True)
    interface = gr.Interface(
        fn=_stream_audio,
        inputs=audio_input,
        outputs=audio_output,
    )
    return interface.queue()


@pytest.fixture
def all_components():
classes_to_check = gr.components.Component.__subclasses__()
Expand Down
1 change: 1 addition & 0 deletions client/python/test/requirements.txt
Expand Up @@ -4,3 +4,4 @@ pytest==7.1.2
ruff==0.0.264
pyright==1.1.305
gradio
pydub==0.25.1
15 changes: 15 additions & 0 deletions client/python/test/test_client.py
Expand Up @@ -268,6 +268,21 @@ def test_cancel_subsequent_jobs_state_reset(self, yield_demo):
assert job2.status().code == Status.FINISHED
assert len(job2.outputs()) == 4

def test_stream_audio(self, stream_audio):
    """Streamed audio jobs should resolve to files that exist locally.

    Two jobs are run one after the other against the ``stream_audio`` demo;
    for the second one every intermediate output is checked as well.
    """
    first_url = "https://gradio-builds.s3.amazonaws.com/demo-files/bark_demo.mp4"
    second_url = "https://gradio-builds.s3.amazonaws.com/demo-files/audio_sample.wav"

    with connect(stream_audio) as client:
        first_job = client.submit(first_url, api_name="/predict")
        assert Path(first_job.result()).exists()

        second_job = client.submit(second_url, api_name="/predict")
        final_file = Path(second_job.result())
        assert final_file.exists()
        # Each intermediate output must also point at a real file.
        for chunk_path in second_job.outputs():
            assert Path(chunk_path).exists()

@pytest.mark.flaky
def test_upload_file_private_space(self):
client = Client(
Expand Down
8 changes: 4 additions & 4 deletions gradio/blocks.py
Expand Up @@ -1360,10 +1360,10 @@ def handle_streaming_outputs(
if run not in self.pending_streams[session_hash]:
self.pending_streams[session_hash][run] = defaultdict(list)
self.pending_streams[session_hash][run][output_id].append(stream)
data[i] = {
"name": f"{session_hash}/{run}/{output_id}",
"is_stream": True,
}
if data[i]:
data[i]["is_file"] = False
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Preserve the orig_name so that clients have an easier time downloading the byte stream

Copy link
Member

@abidlabs abidlabs Aug 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But isn't ["name"] assigned to the same value as before in the next line?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea we're using "name" to mean the url from which to download the stream so I'm keeping that the same. Just preserving orig_name (if it's there) so that we can give a meaningful name when we go to download the byte stream to a file.

data[i]["name"] = f"{session_hash}/{run}/{output_id}"
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aliabid94 Thoughts on including the stream endpoint name in the route so that it doesn't have to be hardcoded in the client?

data[i]["is_stream"] = True
return data

async def process_api(
Expand Down
7 changes: 6 additions & 1 deletion gradio/components/audio.py
Expand Up @@ -357,7 +357,12 @@ def postprocess(
self.temp_files.add(file_path)
else:
file_path = self.make_temp_copy_if_needed(y)
return {"name": file_path, "data": None, "is_file": True}
return {
"name": file_path,
"data": None,
"is_file": True,
"orig_name": Path(file_path).name,
}

def stream_output(self, y):
if y is None:
Expand Down