Skip to content

Commit

Permalink
File upload optimization (#5961)
Browse files Browse the repository at this point in the history
* Use custom multipart parser

* add changeset

* remove print

* Add comment

* Lint

* fix code

* remove print

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
freddyaboulton and gradio-pr-bot committed Oct 18, 2023
1 parent 85056de commit be2ed5e
Show file tree
Hide file tree
Showing 5 changed files with 254 additions and 56 deletions.
5 changes: 5 additions & 0 deletions .changeset/sixty-bags-mix.md
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:File upload optimization
43 changes: 0 additions & 43 deletions gradio/processing_utils.py
Expand Up @@ -5,7 +5,6 @@
import json
import logging
import os
import secrets
import shutil
import subprocess
import tempfile
Expand All @@ -15,12 +14,8 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

import aiofiles
import anyio
import numpy as np
import requests
from anyio import CapacityLimiter
from fastapi import UploadFile
from gradio_client import utils as client_utils
from PIL import Image, ImageOps, PngImagePlugin

Expand Down Expand Up @@ -207,44 +202,6 @@ def save_file_to_cache(file_path: str | Path, cache_dir: str) -> str:
return full_temp_file_path


async def save_uploaded_file(
    file: UploadFile, upload_dir: str, limiter: CapacityLimiter | None = None
) -> str:
    """Stream an uploaded file to disk and file it under a content-hash directory.

    The upload is first written into a randomly named staging directory while a
    SHA-1 digest of the bytes is accumulated, then moved into a directory named
    after that digest so identical uploads end up at the same path.

    Returns the absolute destination path as a string.
    """
    # Since the full file is being uploaded anyways, there is no benefit to
    # hashing the file to pick the staging directory name.
    staging_dir = Path(upload_dir) / secrets.token_hex(20)
    staging_dir.mkdir(exist_ok=True, parents=True)

    digest = hashlib.sha1()

    if file.filename:
        base_name = Path(file.filename).name
        name = client_utils.strip_invalid_filename_characters(base_name)
    else:
        name = f"tmp{secrets.token_hex(5)}"

    staging_path = str(abspath(staging_dir / name))

    async with aiofiles.open(staging_path, "wb") as output_file:
        # Stream in 100 MB chunks so large uploads never sit fully in memory.
        while chunk := await file.read(100 * 1024 * 1024):
            digest.update(chunk)
            await output_file.write(chunk)

    final_dir = Path(upload_dir) / digest.hexdigest()
    final_dir.mkdir(exist_ok=True, parents=True)
    dest = (final_dir / name).resolve()

    # shutil.move may copy across filesystems; run it off the event loop.
    await anyio.to_thread.run_sync(
        shutil.move, staging_path, dest, limiter=limiter
    )

    return str(dest)


def save_url_to_cache(url: str, cache_dir: str) -> str:
"""Downloads a file and makes a temporary file path for a copy if does not already
exist. Otherwise returns the path to the existing temp file."""
Expand Down
203 changes: 202 additions & 1 deletion gradio/route_utils.py
@@ -1,11 +1,17 @@
from __future__ import annotations

import hashlib
import json
from typing import TYPE_CHECKING, Optional, Union
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import TYPE_CHECKING, AsyncGenerator, BinaryIO, List, Optional, Tuple, Union

import fastapi
import httpx
import multipart
from gradio_client.documentation import document, set_documentation_group
from multipart.multipart import parse_options_header
from starlette.datastructures import FormData, Headers, UploadFile
from starlette.formparsers import MultiPartException, MultipartPart

from gradio import utils
from gradio.data_classes import PredictBody
Expand Down Expand Up @@ -264,3 +270,198 @@ def strip_url(orig_url: str) -> str:
stripped_url = parsed_url.copy_with(query=None)
stripped_url = str(stripped_url)
return stripped_url.rstrip("/")


def _user_safe_decode(src: bytes, codec: str) -> str:
try:
return src.decode(codec)
except (UnicodeDecodeError, LookupError):
return src.decode("latin-1")


class GradioUploadFile(UploadFile):
    """Starlette ``UploadFile`` that also carries a running SHA-1 digest.

    The multipart parser updates ``self.sha`` with every chunk it writes, so
    the content hash is available without re-reading the file afterwards.
    """

    def __init__(
        self,
        file: BinaryIO,
        *,
        size: int | None = None,
        filename: str | None = None,
        headers: Headers | None = None,
    ) -> None:
        super().__init__(file, size=size, filename=filename, headers=headers)
        # Accumulates the hash of every chunk written during streaming.
        self.sha = hashlib.sha1()


class GradioMultiPartParser:
    """Vendored from starlette.MultipartParser.
    Thanks starlette!
    Made the following modifications
    - Use GradioUploadFile instead of UploadFile, so a SHA-1 digest of each
      uploaded file is accumulated while the request is streamed
    - Use NamedTemporaryFile instead of SpooledTemporaryFile, so each upload
      lands on disk immediately under a real file name
    - Compute hash of data as the request is streamed
    """

    max_file_size = 1024 * 1024

    def __init__(
        self,
        headers: Headers,
        stream: AsyncGenerator[bytes, None],
        *,
        max_files: Union[int, float] = 1000,
        max_fields: Union[int, float] = 1000,
    ) -> None:
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        self.max_files = max_files
        self.max_fields = max_fields
        # Ordered (field_name, value) pairs; file values are GradioUploadFile.
        self.items: List[Tuple[str, Union[str, UploadFile]]] = []
        self._current_files = 0
        self._current_fields = 0
        # Header names/values may arrive split across chunks; buffer partials.
        self._current_partial_header_name: bytes = b""
        self._current_partial_header_value: bytes = b""
        self._current_part = MultipartPart()
        self._charset = ""
        # File writes are deferred to parse() so they can go through the async
        # UploadFile API instead of blocking the event loop in a sync callback.
        self._file_parts_to_write: List[Tuple[MultipartPart, bytes]] = []
        self._file_parts_to_finish: List[MultipartPart] = []
        self._files_to_close_on_error: List[_TemporaryFileWrapper] = []

    def on_part_begin(self) -> None:
        # A new multipart section starts: reset the per-part state.
        self._current_part = MultipartPart()

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        message_bytes = data[start:end]
        if self._current_part.file is None:
            # Plain form field: accumulate raw bytes for decoding at part end.
            self._current_part.data += message_bytes
        else:
            # File field: queue the chunk; parse() writes it asynchronously.
            self._file_parts_to_write.append((self._current_part, message_bytes))

    def on_part_end(self) -> None:
        if self._current_part.file is None:
            self.items.append(
                (
                    self._current_part.field_name,
                    _user_safe_decode(self._current_part.data, self._charset),
                )
            )
        else:
            self._file_parts_to_finish.append(self._current_part)
            # The file can be added to the items right now even though it's not
            # finished yet, because it will be finished in the `parse()` method, before
            # self.items is used in the return value.
            self.items.append((self._current_part.field_name, self._current_part.file))

    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        self._current_partial_header_name += data[start:end]

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        self._current_partial_header_value += data[start:end]

    def on_header_end(self) -> None:
        field = self._current_partial_header_name.lower()
        if field == b"content-disposition":
            # Remember the disposition so on_headers_finished can extract the
            # field name and the (optional) filename.
            self._current_part.content_disposition = self._current_partial_header_value
        self._current_part.item_headers.append(
            (field, self._current_partial_header_value)
        )
        self._current_partial_header_name = b""
        self._current_partial_header_value = b""

    def on_headers_finished(self) -> None:
        _, options = parse_options_header(
            self._current_part.content_disposition
        )
        try:
            self._current_part.field_name = _user_safe_decode(
                options[b"name"], self._charset
            )
        except KeyError as e:
            raise MultiPartException(
                'The Content-Disposition header field "name" must be ' "provided."
            ) from e
        if b"filename" in options:
            # A "filename" option marks this part as a file upload.
            self._current_files += 1
            if self._current_files > self.max_files:
                raise MultiPartException(
                    f"Too many files. Maximum number of files is {self.max_files}."
                )
            filename = _user_safe_decode(options[b"filename"], self._charset)
            # delete=False: the file must outlive the request handler, which
            # later moves it into the upload cache by name.
            tempfile = NamedTemporaryFile(delete=False)
            self._files_to_close_on_error.append(tempfile)
            self._current_part.file = GradioUploadFile(
                file=tempfile,  # type: ignore[arg-type]
                size=0,
                filename=filename,
                headers=Headers(raw=self._current_part.item_headers),
            )
        else:
            self._current_fields += 1
            if self._current_fields > self.max_fields:
                raise MultiPartException(
                    f"Too many fields. Maximum number of fields is {self.max_fields}."
                )
            self._current_part.file = None

    def on_end(self) -> None:
        pass

    async def parse(self) -> FormData:
        """Consume the request stream and return the parsed form data.

        Raises MultiPartException on malformed input (missing boundary,
        missing "name", or too many files/fields).
        """
        # Parse the Content-Type header to get the multipart boundary.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        if isinstance(charset, bytes):
            charset = charset.decode("latin-1")
        self._charset = charset
        try:
            boundary = params[b"boundary"]
        except KeyError as e:
            raise MultiPartException("Missing boundary in multipart.") from e

        # Callbacks dictionary.
        callbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.MultipartParser(boundary, callbacks)
        try:
            # Feed the parser with data from the request.
            async for chunk in self.stream:
                parser.write(chunk)
                # Write file data, it needs to use await with the UploadFile methods
                # that call the corresponding file methods *in a threadpool*,
                # otherwise, if they were called directly in the callback methods above
                # (regular, non-async functions), that would block the event loop in
                # the main thread.
                for part, data in self._file_parts_to_write:
                    assert part.file  # for type checkers
                    await part.file.write(data)
                    part.file.sha.update(data)  # type: ignore
                for part in self._file_parts_to_finish:
                    assert part.file  # for type checkers
                    await part.file.seek(0)
                self._file_parts_to_write.clear()
                self._file_parts_to_finish.clear()
        except Exception:
            # Close the temp files on *any* failure, not just MultiPartException:
            # they were created with delete=False and would otherwise leak when
            # the underlying multipart parser raises its own errors.
            for file in self._files_to_close_on_error:
                file.close()
            raise

        parser.finalize()
        return FormData(self.items)
57 changes: 46 additions & 11 deletions gradio/routes.py
Expand Up @@ -15,6 +15,7 @@
import os
import posixpath
import secrets
import shutil
import tempfile
import threading
import time
Expand All @@ -24,11 +25,12 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

import anyio
import fastapi
import httpx
import markupsafe
import orjson
from fastapi import Depends, FastAPI, File, HTTPException, UploadFile, WebSocket, status
from fastapi import Depends, FastAPI, HTTPException, WebSocket, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (
FileResponse,
Expand All @@ -38,22 +40,29 @@
)
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.templating import Jinja2Templates
from gradio_client import utils as client_utils
from gradio_client.documentation import document, set_documentation_group
from jinja2.exceptions import TemplateNotFound
from multipart.multipart import parse_options_header
from starlette.background import BackgroundTask
from starlette.responses import RedirectResponse, StreamingResponse
from starlette.websockets import WebSocketState

import gradio
import gradio.ranged_response as ranged_response
from gradio import processing_utils, route_utils, utils, wasm_utils
from gradio import route_utils, utils, wasm_utils
from gradio.context import Context
from gradio.data_classes import ComponentServerBody, PredictBody, ResetBody
from gradio.deprecation import warn_deprecation
from gradio.exceptions import Error
from gradio.oauth import attach_oauth
from gradio.queueing import Estimation, Event
from gradio.route_utils import Request # noqa: F401
from gradio.route_utils import ( # noqa: F401
GradioMultiPartParser,
GradioUploadFile,
MultiPartException,
Request,
)
from gradio.state_holder import StateHolder
from gradio.utils import (
cancel_tasks,
Expand Down Expand Up @@ -654,16 +663,42 @@ async def get_queue_status():
return app.get_blocks()._queue.get_estimation()

@app.post("/upload", dependencies=[Depends(login_check)])
async def upload_file(request: fastapi.Request):
    """Receive a multipart upload, hash each file while it streams in, and
    store every file under a cache directory named after its SHA-1 digest.

    Returns the list of destination paths, one per uploaded file.
    Raises HTTP 400 for a non-multipart body or a malformed multipart stream.
    """
    content_type_header = request.headers.get("Content-Type")
    content_type: bytes
    content_type, _ = parse_options_header(content_type_header)
    if content_type != b"multipart/form-data":
        raise HTTPException(status_code=400, detail="Invalid content type.")

    try:
        multipart_parser = GradioMultiPartParser(
            request.headers,
            request.stream(),
            max_files=1000,
            max_fields=1000,
        )
        form = await multipart_parser.parse()
    except MultiPartException as exc:
        raise HTTPException(status_code=400, detail=exc.message) from exc

    output_files = []
    for temp_file in form.getlist("files"):
        assert isinstance(temp_file, GradioUploadFile)
        if temp_file.filename:
            file_name = Path(temp_file.filename).name
            name = client_utils.strip_invalid_filename_characters(file_name)
        else:
            name = f"tmp{secrets.token_hex(5)}"
        # The parser computed the SHA-1 while streaming; use it as the
        # directory name so identical uploads share a single path.
        directory = Path(app.uploaded_file_dir) / temp_file.sha.hexdigest()
        directory.mkdir(exist_ok=True, parents=True)
        dest = (directory / name).resolve()
        # shutil.move may copy across filesystems; run it off the event loop.
        await anyio.to_thread.run_sync(
            shutil.move,
            temp_file.file.name,
            dest,
            limiter=app.get_blocks().limiter,
        )
        output_files.append(dest)
    return output_files

@app.on_event("startup")
Expand Down

0 comments on commit be2ed5e

Please sign in to comment.