Skip to content

Commit

Permalink
Fix file upload on windows (#6289)
Browse files Browse the repository at this point in the history
* use background task

* add changeset

* Add code

* finish comment lol

* lint

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
freddyaboulton and gradio-pr-bot committed Nov 3, 2023
1 parent 36f1972 commit 5668036
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 9 deletions.
5 changes: 5 additions & 0 deletions .changeset/quiet-walls-float.md
@@ -0,0 +1,5 @@
---
"gradio": patch
---

fix:Fix file upload on windows
33 changes: 24 additions & 9 deletions gradio/routes.py
Expand Up @@ -25,12 +25,11 @@
from queue import Empty as EmptyQueue
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Type

import anyio
import fastapi
import httpx
import markupsafe
import orjson
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (
FileResponse,
Expand Down Expand Up @@ -117,6 +116,11 @@ def toorjson(value):
client = httpx.AsyncClient()


def move_uploaded_files_to_cache(files: list[str], destinations: list[str]) -> None:
for file, dest in zip(files, destinations):
shutil.move(file, dest)


class App(FastAPI):
"""
FastAPI App Wrapper
Expand Down Expand Up @@ -678,7 +682,7 @@ async def get_queue_status():
return app.get_blocks()._queue.get_estimation()

@app.post("/upload", dependencies=[Depends(login_check)])
async def upload_file(request: fastapi.Request):
async def upload_file(request: fastapi.Request, bg_tasks: BackgroundTasks):
content_type_header = request.headers.get("Content-Type")
content_type: bytes
content_type, _ = parse_options_header(content_type_header)
Expand All @@ -697,6 +701,8 @@ async def upload_file(request: fastapi.Request):
raise HTTPException(status_code=400, detail=exc.message) from exc

output_files = []
files_to_copy = []
locations: list[str] = []
for temp_file in form.getlist("files"):
assert isinstance(temp_file, GradioUploadFile)
if temp_file.filename:
Expand All @@ -707,13 +713,22 @@ async def upload_file(request: fastapi.Request):
directory = Path(app.uploaded_file_dir) / temp_file.sha.hexdigest()
directory.mkdir(exist_ok=True, parents=True)
dest = (directory / name).resolve()
await anyio.to_thread.run_sync(
shutil.move,
temp_file.file.name,
dest,
limiter=app.get_blocks().limiter,
)
temp_file.file.close()
# we need to move the temp file to the cache directory
# but that's possibly blocking and we're in an async function
# so we try to rename (this is what shutil.move tries first)
# which should be super fast.
# if that fails, we move in the background.
try:
os.rename(temp_file.file.name, dest)
except OSError:
files_to_copy.append(temp_file.file.name)
locations.append(temp_file.file.name)
output_files.append(dest)
if files_to_copy:
bg_tasks.add_task(
move_uploaded_files_to_cache, files_to_copy, locations
)
return output_files

@app.on_event("startup")
Expand Down

0 comments on commit 5668036

Please sign in to comment.