Skip to content

Commit

Permalink
Provide status updates on file uploads (#6307)
Browse files Browse the repository at this point in the history
* Backend

* Backend

* add changeset

* Clean up + close connection

* Lint

* Fix tests

* Apply opacity transition

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
freddyaboulton and gradio-pr-bot committed Nov 8, 2023
1 parent e9bb445 commit f1409f9
Show file tree
Hide file tree
Showing 9 changed files with 281 additions and 41 deletions.
7 changes: 7 additions & 0 deletions .changeset/tired-berries-tease.md
@@ -0,0 +1,7 @@
---
"@gradio/client": minor
"@gradio/upload": minor
"gradio": minor
---

feat:Provide status updates on file uploads
11 changes: 8 additions & 3 deletions client/js/src/client.ts
Expand Up @@ -156,7 +156,8 @@ interface Client {
upload_files: (
root: string,
files: File[],
token?: `hf_${string}`
token?: `hf_${string}`,
upload_id?: string
) => Promise<UploadResponse>;
client: (
app_reference: string,
Expand Down Expand Up @@ -208,7 +209,8 @@ export function api_factory(
async function upload_files(
root: string,
files: (Blob | File)[],
token?: `hf_${string}`
token?: `hf_${string}`,
upload_id?: string
): Promise<UploadResponse> {
const headers: {
Authorization?: string;
Expand All @@ -225,7 +227,10 @@ export function api_factory(
formData.append("files", file);
});
try {
var response = await fetch_implementation(`${root}/upload`, {
const upload_url = upload_id
? `${root}/upload?upload_id=${upload_id}`
: `${root}/upload`;
var response = await fetch_implementation(upload_url, {
method: "POST",
body: formData,
headers
Expand Down
3 changes: 2 additions & 1 deletion client/js/src/upload.ts
Expand Up @@ -88,14 +88,15 @@ export function get_fetchable_url_or_file(
export async function upload(
file_data: FileData[],
root: string,
upload_id?: string,
upload_fn: typeof upload_files = upload_files
): Promise<(FileData | null)[] | null> {
let files = (Array.isArray(file_data) ? file_data : [file_data]).map(
(file_data) => file_data.blob!
);

return await Promise.all(
await upload_fn(root, files).then(
await upload_fn(root, files, undefined, upload_id).then(
async (response: { files?: string[]; error?: string }) => {
if (response.error) {
throw new Error(response.error);
Expand Down
50 changes: 50 additions & 0 deletions gradio/route_utils.py
Expand Up @@ -2,6 +2,8 @@

import hashlib
import json
from collections import deque
from dataclasses import dataclass as python_dataclass
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import TYPE_CHECKING, AsyncGenerator, BinaryIO, List, Optional, Tuple, Union

Expand Down Expand Up @@ -294,6 +296,44 @@ def __init__(
self.sha = hashlib.sha1()


@python_dataclass(frozen=True)
class FileUploadProgressUnit:
filename: str
chunk_size: int
is_done: bool


class FileUploadProgress:
def __init__(self) -> None:
self._statuses: dict[str, deque[FileUploadProgressUnit]] = {}

def track(self, upload_id: str):
if upload_id not in self._statuses:
self._statuses[upload_id] = deque()

def update(self, upload_id: str, filename: str, message_bytes: bytes):
if upload_id not in self._statuses:
self._statuses[upload_id] = deque()
self._statuses[upload_id].append(
FileUploadProgressUnit(filename, len(message_bytes), is_done=False)
)

def set_done(self, upload_id: str):
self._statuses[upload_id].append(FileUploadProgressUnit("", 0, is_done=True))

def stop_tracking(self, upload_id: str):
if upload_id in self._statuses:
del self._statuses[upload_id]

def status(self, upload_id: str) -> deque[FileUploadProgressUnit]:
if upload_id not in self._statuses:
return deque()
return self._statuses[upload_id]

def is_tracked(self, upload_id: str):
return upload_id in self._statuses


class GradioMultiPartParser:
"""Vendored from starlette.MultipartParser.
Expand All @@ -315,6 +355,8 @@ def __init__(
*,
max_files: Union[int, float] = 1000,
max_fields: Union[int, float] = 1000,
upload_id: str | None = None,
upload_progress: FileUploadProgress | None = None,
) -> None:
assert (
multipart is not None
Expand All @@ -324,6 +366,8 @@ def __init__(
self.max_files = max_files
self.max_fields = max_fields
self.items: List[Tuple[str, Union[str, UploadFile]]] = []
self.upload_id = upload_id
self.upload_progress = upload_progress
self._current_files = 0
self._current_fields = 0
self._current_partial_header_name: bytes = b""
Expand All @@ -339,6 +383,10 @@ def on_part_begin(self) -> None:

def on_part_data(self, data: bytes, start: int, end: int) -> None:
message_bytes = data[start:end]
if self.upload_progress is not None:
self.upload_progress.update(
self.upload_id, self._current_part.file.filename, message_bytes # type: ignore
)
if self._current_part.file is None:
self._current_part.data += message_bytes
else:
Expand Down Expand Up @@ -464,4 +512,6 @@ async def parse(self) -> FormData:
raise exc

parser.finalize()
if self.upload_progress is not None:
self.upload_progress.set_done(self.upload_id) # type: ignore
return FormData(self.items)
59 changes: 58 additions & 1 deletion gradio/routes.py
Expand Up @@ -56,6 +56,7 @@
from gradio.oauth import attach_oauth
from gradio.queueing import Estimation, Event
from gradio.route_utils import ( # noqa: F401
FileUploadProgress,
GradioMultiPartParser,
GradioUploadFile,
MultiPartException,
Expand Down Expand Up @@ -121,6 +122,9 @@ def move_uploaded_files_to_cache(files: list[str], destinations: list[str]) -> N
shutil.move(file, dest)


file_upload_statuses = FileUploadProgress()


class App(FastAPI):
"""
FastAPI App Wrapper
Expand Down Expand Up @@ -681,20 +685,73 @@ def component_server(body: ComponentServerBody):
async def get_queue_status():
return app.get_blocks()._queue.get_estimation()

@app.get("/upload_progress")
def get_upload_progress(upload_id: str, request: fastapi.Request):
async def sse_stream(request: fastapi.Request):
last_heartbeat = time.perf_counter()
is_done = False
while True:
if await request.is_disconnected():
file_upload_statuses.stop_tracking(upload_id)
return
if is_done:
file_upload_statuses.stop_tracking(upload_id)
return

heartbeat_rate = 15
check_rate = 0.05
message = None
try:
if update := file_upload_statuses.status(upload_id).popleft():
if update.is_done:
message = {"msg": "done"}
is_done = True
else:
message = {
"msg": "update",
"orig_name": update.filename,
"chunk_size": update.chunk_size,
}
else:
await asyncio.sleep(check_rate)
if time.perf_counter() - last_heartbeat > heartbeat_rate:
message = {"msg": "heartbeat"}
last_heartbeat = time.perf_counter()
if message:
yield f"data: {json.dumps(message)}\n\n"
except IndexError:
if not file_upload_statuses.is_tracked(upload_id):
return
# pop from empty queue
continue

return StreamingResponse(
sse_stream(request),
media_type="text/event-stream",
)

@app.post("/upload", dependencies=[Depends(login_check)])
async def upload_file(request: fastapi.Request, bg_tasks: BackgroundTasks):
async def upload_file(
request: fastapi.Request,
bg_tasks: BackgroundTasks,
upload_id: Optional[str] = None,
):
content_type_header = request.headers.get("Content-Type")
content_type: bytes
content_type, _ = parse_options_header(content_type_header)
if content_type != b"multipart/form-data":
raise HTTPException(status_code=400, detail="Invalid content type.")

try:
if upload_id:
file_upload_statuses.track(upload_id)
multipart_parser = GradioMultiPartParser(
request.headers,
request.stream(),
max_files=1000,
max_fields=1000,
upload_id=upload_id if upload_id else None,
upload_progress=file_upload_statuses if upload_id else None,
)
form = await multipart_parser.parse()
except MultiPartException as exc:
Expand Down
6 changes: 3 additions & 3 deletions js/app/test/audio_component_events.spec.ts
Expand Up @@ -7,7 +7,7 @@ test("Audio click-to-upload uploads audio successfuly.", async ({ page }) => {
const uploader = await page.locator("input[type=file]");
await Promise.all([
uploader.setInputFiles(["../../test/test_files/audio_sample.wav"]),
page.waitForResponse("**/upload")
page.waitForResponse("**/upload?*")
]);

await expect(page.getByLabel("# Change Events")).toHaveValue("1");
Expand All @@ -21,7 +21,7 @@ test("Audio click-to-upload uploads audio successfuly.", async ({ page }) => {

await Promise.all([
uploader.setInputFiles(["../../test/test_files/audio_sample.wav"]),
page.waitForResponse("**/upload")
page.waitForResponse("**/upload?*")
]);

await expect(page.getByLabel("# Change Events")).toHaveValue("3");
Expand All @@ -39,7 +39,7 @@ test("Audio drag-and-drop uploads a file to the server correctly.", async ({
"audio_sample.wav",
"audio/wav"
),
page.waitForResponse("**/upload")
page.waitForResponse("**/upload?*")
]);
await expect(page.getByLabel("# Change Events")).toHaveValue("1");
await expect(page.getByLabel("# Upload Events")).toHaveValue("1");
Expand Down
6 changes: 3 additions & 3 deletions js/app/test/video_component_events.spec.ts
Expand Up @@ -9,7 +9,7 @@ test("Video click-to-upload uploads video successfuly. Clear, play, and pause bu
const uploader = await page.locator("input[type=file]");
await Promise.all([
uploader.setInputFiles(["./test/files/file_test.ogg"]),
page.waitForResponse("**/upload")
page.waitForResponse("**/upload?*?*")
]);

await expect(page.getByLabel("# Change Events")).toHaveValue("1");
Expand All @@ -28,7 +28,7 @@ test("Video click-to-upload uploads video successfuly. Clear, play, and pause bu

await Promise.all([
uploader.setInputFiles(["./test/files/file_test.ogg"]),
page.waitForResponse("**/upload")
page.waitForResponse("**/upload?*")
]);

await expect(page.getByLabel("# Change Events")).toHaveValue("3");
Expand All @@ -50,7 +50,7 @@ test("Video drag-and-drop uploads a file to the server correctly.", async ({
"file_test.ogg",
"video/*"
);
await page.waitForResponse("**/upload");
await page.waitForResponse("**/upload?*");
await expect(page.getByLabel("# Change Events")).toHaveValue("1");
await expect(page.getByLabel("# Upload Events")).toHaveValue("1");
});
Expand Down

0 comments on commit f1409f9

Please sign in to comment.