Skip to content

Commit

Permalink
[WIP] Refactor file normalization to be in the backend and remove it …
Browse files Browse the repository at this point in the history
…from the frontend of each component (#7183)

* processing

* add changeset

* changes

* add changeset

* add changeset

* changes

* changes

* clean

* changes

* add changeset

* add changeset

* root url

* refactor

* testing

* testing

* log

* logs

* fix

* format

* add changeset

* remove

* add root

* format

* apply to everything

* annoying fix

* fixes

* lint

* fixes

* fixes

* fixes

* fix tests

* fix js tests

* format

* fix python tests

* clean guides

* add changeset

* add changeset

* simplify

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: pngwn <hello@pngwn.io>
  • Loading branch information
3 people committed Feb 7, 2024
1 parent 7e9b206 commit 49d9c48
Show file tree
Hide file tree
Showing 35 changed files with 211 additions and 283 deletions.
18 changes: 18 additions & 0 deletions .changeset/strong-chefs-study.md
@@ -0,0 +1,18 @@
---
"@gradio/annotatedimage": minor
"@gradio/app": minor
"@gradio/audio": minor
"@gradio/chatbot": minor
"@gradio/client": minor
"@gradio/file": minor
"@gradio/gallery": minor
"@gradio/image": minor
"@gradio/imageeditor": minor
"@gradio/model3d": minor
"@gradio/simpleimage": minor
"@gradio/video": minor
"gradio": minor
"gradio_client": minor
---

feat: [WIP] Refactor file normalization to be in the backend and remove it from the frontend of each component
67 changes: 7 additions & 60 deletions client/js/src/client.ts
Expand Up @@ -28,7 +28,7 @@ import type {
SpaceStatusCallback
} from "./types.js";

import { FileData, normalise_file } from "./upload";
import { FileData } from "./upload";

import type { Config } from "./types.js";

Expand Down Expand Up @@ -166,7 +166,6 @@ interface Client {
options: {
hf_token?: `hf_${string}`;
status_callback?: SpaceStatusCallback;
normalise_files?: boolean;
}
) => Promise<client_return>;
handle_blob: (
Expand Down Expand Up @@ -259,19 +258,17 @@ export function api_factory(
options: {
hf_token?: `hf_${string}`;
status_callback?: SpaceStatusCallback;
normalise_files?: boolean;
} = { normalise_files: true }
} = {}
): Promise<client_return> {
return new Promise(async (res) => {
const { status_callback, hf_token, normalise_files } = options;
const { status_callback, hf_token } = options;
const return_obj = {
predict,
submit,
view_api,
component_server
};

const transform_files = normalise_files ?? true;
if (
(typeof window === "undefined" || !("WebSocket" in window)) &&
!global.Websocket
Expand Down Expand Up @@ -493,14 +490,7 @@ export function api_factory(
hf_token
)
.then(([output, status_code]) => {
const data = transform_files
? transform_output(
output.data,
api_info,
config.root,
config.root_url
)
: output.data;
const data = output.data;
if (status_code == 200) {
fire_event({
type: "data",
Expand Down Expand Up @@ -628,14 +618,7 @@ export function api_factory(
fire_event({
type: "data",
time: new Date(),
data: transform_files
? transform_output(
data.data,
api_info,
config.root,
config.root_url
)
: data.data,
data: data.data,
endpoint: _endpoint,
fn_index
});
Expand Down Expand Up @@ -750,14 +733,7 @@ export function api_factory(
fire_event({
type: "data",
time: new Date(),
data: transform_files
? transform_output(
data.data,
api_info,
config.root,
config.root_url
)
: data.data,
data: data.data,
endpoint: _endpoint,
fn_index
});
Expand Down Expand Up @@ -878,14 +854,7 @@ export function api_factory(
fire_event({
type: "data",
time: new Date(),
data: transform_files
? transform_output(
data.data,
api_info,
config.root,
config.root_url
)
: data.data,
data: data.data,
endpoint: _endpoint,
fn_index
});
Expand Down Expand Up @@ -1244,28 +1213,6 @@ export const { post_data, upload_files, client, handle_blob } = api_factory(
(...args) => new EventSource(...args)
);

/**
 * Walks an endpoint's output payload and normalises every file reference so
 * that it carries a fetchable URL.
 *
 * Per-index handling, driven by the endpoint's return metadata:
 * - "File" components: the value is normalised directly.
 * - "Gallery" components: each entry is either `[file, caption]` or a bare
 *   file; both forms are normalised to a `[normalised_file, caption|null]` pair.
 * - Any other object exposing a `path` property is treated as file-like and
 *   normalised as well; everything else passes through untouched.
 *
 * @param data - raw output values, one entry per return slot
 * @param api_info - endpoint metadata; `returns[i].component` selects the handling
 * @param root_url - base URL of the serving app (callers pass `config.root`)
 * @param remote_url - proxy/root URL for remotely-hosted apps, if any
 * @returns the payload with file references rewritten in place of the originals
 */
function transform_output(
	data: any[],
	api_info: any,
	root_url: string,
	remote_url?: string
): unknown[] {
	return data.map((value, index) => {
		const component = api_info?.returns?.[index]?.component;
		if (component === "File") {
			return normalise_file(value, root_url, remote_url);
		}
		if (component === "Gallery") {
			// Gallery entries may be [file, caption] tuples or bare files.
			return value.map((img) =>
				Array.isArray(img)
					? [normalise_file(img[0], root_url, remote_url), img[1]]
					: [normalise_file(img, root_url, remote_url), null]
			);
		}
		if (typeof value === "object" && value.path) {
			return normalise_file(value, root_url, remote_url);
		}
		return value;
	});
}

interface ApiData {
label: string;
type: {
Expand Down
1 change: 0 additions & 1 deletion client/js/src/index.ts
Expand Up @@ -7,7 +7,6 @@ export {
} from "./client.js";
export type { SpaceStatus } from "./types.js";
export {
normalise_file,
FileData,
upload,
get_fetchable_url_or_file,
Expand Down
69 changes: 6 additions & 63 deletions client/js/src/upload.ts
@@ -1,65 +1,5 @@
import { upload_files } from "./client";

export function normalise_file(
	file: FileData | null,
	server_url: string,
	proxy_url: string | null
): FileData | null;

export function normalise_file(
	file: FileData[] | null,
	server_url: string,
	proxy_url: string | null
): FileData[] | null;

export function normalise_file(
	file: FileData[] | FileData | null,
	server_url: string,
	proxy_url: string | null
): FileData[] | FileData | null;

/**
 * Attaches a usable `url` to a FileData object (or to every element of an
 * array of them), returning new FileData instances.
 *
 * - `null` input (or `null` array elements) is preserved as `null`.
 * - Streaming files (`is_stream`) get a `/stream/`-based URL, routed through
 *   `/proxy=<proxy_url>` when a proxy URL is provided.
 * - All other files get whatever `get_fetchable_url_or_file` resolves for
 *   their path against `server_url` / `proxy_url`.
 *
 * @param file - a single file, an array of files, or null
 * @param server_url - root URL of the serving app
 * @param proxy_url - root URL of the proxied (remote) app, or null when local
 */
export function normalise_file(
	file: FileData[] | FileData | null,
	server_url: string,
	proxy_url: string | null
): FileData[] | FileData | null {
	if (file == null) {
		return null;
	}

	if (Array.isArray(file)) {
		return file.map((entry) =>
			entry == null ? null : normalise_file(entry, server_url, proxy_url)
		) as FileData[];
	}

	if (file.is_stream) {
		const url =
			proxy_url == null
				? server_url + "/stream/" + file.path
				: "/proxy=" + proxy_url + "stream/" + file.path;
		return new FileData({ ...file, url });
	}

	return new FileData({
		...file,
		url: get_fetchable_url_or_file(file.path, server_url, proxy_url)
	});
}

function is_url(str: string): boolean {
try {
const url = new URL(str);
Expand Down Expand Up @@ -103,9 +43,12 @@ export async function upload(
} else {
if (response.files) {
return response.files.map((f, i) => {
const file = new FileData({ ...file_data[i], path: f });

return normalise_file(file, root, null);
const file = new FileData({
...file_data[i],
path: f,
url: root + "/file=" + f
});
return file;
});
}

Expand Down
8 changes: 7 additions & 1 deletion client/python/gradio_client/utils.py
Expand Up @@ -895,7 +895,7 @@ def get_type(schema: dict):
raise APIInfoParseError(f"Cannot parse type for {schema}")


FILE_DATA = "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None)"
FILE_DATA = "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None, is_stream: bool)"


def json_schema_to_python_type(schema: Any) -> str:
Expand Down Expand Up @@ -1010,6 +1010,12 @@ def is_file_obj(d):
return isinstance(d, dict) and "path" in d


def is_file_obj_with_url(d) -> bool:
    """Return True if `d` is a FileData-style dict that already carries a URL.

    A match requires a dict containing a "path" key and a "url" key whose
    value is a string (so e.g. url=None does not count).
    """
    if not isinstance(d, dict):
        return False
    return "path" in d and isinstance(d.get("url"), str)


SKIP_COMPONENTS = {
"state",
"row",
Expand Down
3 changes: 2 additions & 1 deletion gradio/blocks.py
Expand Up @@ -1325,7 +1325,7 @@ def preprocess_data(
if input_id in state:
block = state[input_id]
inputs_cached = processing_utils.move_files_to_cache(
inputs[i], block
inputs[i], block, add_urls=True
)
if getattr(block, "data_model", None) and inputs_cached is not None:
if issubclass(block.data_model, GradioModel): # type: ignore
Expand Down Expand Up @@ -1454,6 +1454,7 @@ def postprocess_data(
prediction_value,
block, # type: ignore
postprocess=True,
add_urls=True,
)
output.append(outputs_cached)

Expand Down
7 changes: 6 additions & 1 deletion gradio/components/base.py
Expand Up @@ -195,7 +195,12 @@ def __init__(
self.load_event_to_attach: None | tuple[Callable, float | None] = None
load_fn, initial_value = self.get_load_fn_and_initial_value(value)
initial_value = self.postprocess(initial_value)
self.value = move_files_to_cache(initial_value, self, postprocess=True) # type: ignore
self.value = move_files_to_cache(
initial_value,
self, # type: ignore
postprocess=True,
add_urls=True,
)

if callable(load_fn):
self.attach_load_event(load_fn, every)
Expand Down
2 changes: 1 addition & 1 deletion gradio/components/login_button.py
Expand Up @@ -88,7 +88,7 @@ def _check_login_status(self, request: Request) -> LoginButton:
request.request, "session", None
)
if session is None or "oauth_info" not in session:
return LoginButton(value=self.value, interactive=True)
return LoginButton(value=self.value, interactive=True) # type: ignore
else:
username = session["oauth_info"]["userinfo"]["preferred_username"]
logout_text = self.logout_value.format(username)
Expand Down
1 change: 1 addition & 0 deletions gradio/data_classes.py
Expand Up @@ -168,6 +168,7 @@ class FileData(GradioModel):
size: Optional[int] = None # size in bytes
orig_name: Optional[str] = None # original filename
mime_type: Optional[str] = None
is_stream: bool = False

@property
def is_none(self):
Expand Down
37 changes: 33 additions & 4 deletions gradio/processing_utils.py
Expand Up @@ -236,15 +236,22 @@ def move_resource_to_block_cache(
return block.move_resource_to_block_cache(url_or_file_path)


def move_files_to_cache(data: Any, block: Component, postprocess: bool = False):
"""Move files to cache and replace the file path with the cache path.
def move_files_to_cache(
data: Any,
block: Component,
postprocess: bool = False,
add_urls=False,
) -> dict:
"""Move any files in `data` to cache and (optionally), adds URL prefixes (/file=...) needed to access the cached file.
Also handles the case where the file is on an external Gradio app (/proxy=...).
Runs after .postprocess(), after .process_example(), and before .preprocess().
Runs after .postprocess() and before .preprocess().
Args:
data: The input or output data for a component. Can be a dictionary or a dataclass
block: The component
block: The component whose data is being processed
postprocess: Whether it's running from postprocessing
root_url: The root URL of the local server, if applicable
"""

def _move_to_cache(d: dict):
Expand All @@ -259,6 +266,19 @@ def _move_to_cache(d: dict):
temp_file_path = move_resource_to_block_cache(payload.path, block)
assert temp_file_path is not None
payload.path = temp_file_path

if add_urls:
url_prefix = "/stream/" if payload.is_stream else "/file="
if block.proxy_url:
url = f"/proxy={block.proxy_url}{url_prefix}{temp_file_path}"
elif client_utils.is_http_url_like(
temp_file_path
) or temp_file_path.startswith(f"{url_prefix}"):
url = temp_file_path
else:
url = f"{url_prefix}{temp_file_path}"
payload.url = url

return payload.model_dump()

if isinstance(data, (GradioRootModel, GradioModel)):
Expand All @@ -267,6 +287,15 @@ def _move_to_cache(d: dict):
return client_utils.traverse(data, _move_to_cache, client_utils.is_file_obj)


def add_root_url(data, root_url) -> dict:
    """Prefix every relative file URL found inside `data` with `root_url`.

    Traverses `data` and rewrites, in place, the "url" field of every dict
    matching `client_utils.is_file_obj_with_url`. URLs that already look
    absolute (per `client_utils.is_http_url_like`) are left untouched.
    """

    def _prefix_url(file_dict: dict) -> dict:
        url = file_dict["url"]
        if not client_utils.is_http_url_like(url):
            file_dict["url"] = f"{root_url}{url}"
        return file_dict

    return client_utils.traverse(data, _prefix_url, client_utils.is_file_obj_with_url)


def resize_and_crop(img, size, crop_type="center"):
"""
Resize and crop an image to fit the specified size.
Expand Down
6 changes: 6 additions & 0 deletions gradio/route_utils.py
Expand Up @@ -2,6 +2,7 @@

import hashlib
import json
import shutil
from collections import deque
from dataclasses import dataclass as python_dataclass
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
Expand Down Expand Up @@ -545,3 +546,8 @@ async def parse(self) -> FormData:
if self.upload_progress is not None:
self.upload_progress.set_done(self.upload_id) # type: ignore
return FormData(self.items)


def move_uploaded_files_to_cache(files: list[str], destinations: list[str]) -> None:
    """Relocate each uploaded temp file to its corresponding cache destination.

    Files and destinations are paired positionally; if the lists differ in
    length, the extra entries of the longer list are ignored (zip semantics).
    """
    pair_count = min(len(files), len(destinations))
    for i in range(pair_count):
        shutil.move(files[i], destinations[i])

0 comments on commit 49d9c48

Please sign in to comment.