Skip to content

Commit

Permalink
Image v4 (#6094)
Browse files Browse the repository at this point in the history
* simplify image interface

* changes

* asd

* asd

* more

* add code (#6095)

* more

* fix tests

* add changeset

* fix client build

* fix linting

* fix test

* lint

* Fix tests + lint

* asd

* finish

* webcam selection

* fix backend

* address comments

* fix static checks

* fix everything

* add changeset

* Apply suggestions from code review

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* fix examples

* fix tests

* fix tests

---------

Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
4 people committed Oct 30, 2023
1 parent d2dfc1b commit c476bd5
Show file tree
Hide file tree
Showing 91 changed files with 1,425 additions and 2,334 deletions.
20 changes: 20 additions & 0 deletions .changeset/shaky-rings-relate.md
@@ -0,0 +1,20 @@
---
"@gradio/annotatedimage": minor
"@gradio/atoms": minor
"@gradio/audio": minor
"@gradio/chatbot": minor
"@gradio/client": minor
"@gradio/file": minor
"@gradio/gallery": minor
"@gradio/icons": minor
"@gradio/image": minor
"@gradio/model3d": minor
"@gradio/preview": minor
"@gradio/upload": minor
"@gradio/uploadbutton": minor
"@gradio/video": minor
"gradio": minor
"gradio_client": minor
---

feat:Image v4
1 change: 1 addition & 0 deletions client/js/package.json
Expand Up @@ -13,6 +13,7 @@
"./package.json": "./package.json"
},
"dependencies": {
"@gradio/upload": "workspace:^",
"bufferutil": "^4.0.7",
"semiver": "^1.1.0",
"ws": "^8.13.0"
Expand Down
93 changes: 11 additions & 82 deletions client/js/src/client.ts
Expand Up @@ -22,10 +22,11 @@ import type {
UploadResponse,
Status,
SpaceStatus,
SpaceStatusCallback,
FileData
SpaceStatusCallback
} from "./types.js";

import { FileData, normalise_file } from "@gradio/upload/utils";

import type { Config } from "./types.js";

type event = <K extends EventType>(
Expand Down Expand Up @@ -251,7 +252,6 @@ export function api_factory(
submit,
view_api,
component_server
// duplicate
};

const transform_files = normalise_files ?? true;
Expand Down Expand Up @@ -843,28 +843,21 @@ export function api_factory(
);

return Promise.all(
blob_refs.map(async ({ path, blob, data, type }) => {
blob_refs.map(async ({ path, blob, type }) => {
if (blob) {
const file_url = (await upload_files(endpoint, [blob], token))
.files[0];
return { path, file_url, type };
return { path, file_url, type, name: blob?.name };
}
return { path, base64: data, type };
return { path, type };
})
).then((r) => {
r.forEach(({ path, file_url, base64, type }) => {
if (base64) {
update_object(data, base64, path);
} else if (type === "Gallery") {
r.forEach(({ path, file_url, type, name }) => {
if (type === "Gallery") {
update_object(data, file_url, path);
} else if (file_url) {
const o = {
is_file: true,
name: `${file_url}`,
data: null
// orig_name: "file.csv"
};
update_object(data, o, path);
const file = new FileData({ path: file_url, orig_name: name });
update_object(data, file, path);
}
});

Expand Down Expand Up @@ -893,57 +886,13 @@ function transform_output(
? [normalise_file(img[0], root_url, remote_url), img[1]]
: [normalise_file(img, root_url, remote_url), null];
});
} else if (typeof d === "object" && d?.is_file) {
} else if (typeof d === "object" && d.path) {
return normalise_file(d, root_url, remote_url);
}
return d;
});
}

function normalise_file(
file: FileData[],
root: string,
root_url: string | null
): FileData[];
function normalise_file(
file: FileData | string,
root: string,
root_url: string | null
): FileData;
function normalise_file(
file: null,
root: string,
root_url: string | null
): null;
function normalise_file(file, root, root_url): FileData[] | FileData | null {
if (file == null) return null;
if (typeof file === "string") {
return {
name: "file_data",
data: file
};
} else if (Array.isArray(file)) {
const normalized_file: (FileData | null)[] = [];

for (const x of file) {
if (x === null) {
normalized_file.push(null);
} else {
normalized_file.push(normalise_file(x, root, root_url));
}
}

return normalized_file as FileData[];
} else if (file.is_file) {
if (!root_url) {
file.data = root + "/file=" + file.name;
} else {
file.data = "/proxy=" + root_url + "file=" + file.name;
}
}
return file;
}

interface ApiData {
label: string;
type: {
Expand Down Expand Up @@ -1110,7 +1059,6 @@ export async function walk_and_store_blobs(
): Promise<
{
path: string[];
data: string | false;
type: string;
blob: Blob | false;
}[]
Expand Down Expand Up @@ -1142,28 +1090,9 @@ export async function walk_and_store_blobs(
{
path: path,
blob: is_image ? false : new NodeBlob([param]),
data: is_image ? `${param.toString("base64")}` : false,
type
}
];
} else if (
param instanceof Blob ||
(typeof window !== "undefined" && param instanceof File)
) {
if (type === "Image") {
let data;

if (typeof window !== "undefined") {
// browser
data = await image_to_data_uri(param);
} else {
const buffer = await param.arrayBuffer();
data = Buffer.from(buffer).toString("base64");
}

return [{ path, data, type, blob: false }];
}
return [{ path: path, blob: param, type, data: false }];
} else if (typeof param === "object") {
let blob_refs = [];
for (let key in param) {
Expand Down
18 changes: 17 additions & 1 deletion client/js/vite.config.js
@@ -1,4 +1,8 @@
import { defineConfig } from "vite";
import { svelte } from "@sveltejs/vite-plugin-svelte";
import { fileURLToPath } from "url";
import path from "path";
const __dirname = fileURLToPath(new URL(".", import.meta.url));

export default defineConfig({
build: {
Expand All @@ -13,10 +17,22 @@ export default defineConfig({
}
}
},
plugins: [
svelte(),
{
name: "resolve-gradio-client",
enforce: "pre",
resolveId(id) {
if (id === "@gradio/client") {
return path.join(__dirname, "src", "index.ts");
}
}
}
],

ssr: {
target: "node",
format: "esm",
noExternal: ["ws", "semiver"]
noExternal: ["ws", "semiver", "@gradio/upload"]
}
});
32 changes: 11 additions & 21 deletions client/python/gradio_client/client.py
Expand Up @@ -804,7 +804,6 @@ def __init__(self, client: Client, fn_index: int, dependency: dict):
# Only a real API endpoint if backend_fn is True (so not just a frontend function), serializers are valid,
# and api_name is not False (meaning that the developer has explicitly disabled the API endpoint)
self.is_valid = self.dependency["backend_fn"] and self.api_name is not False
print(self.is_valid)

def _get_component_type(self, component_id: int):
component = next(
Expand Down Expand Up @@ -935,20 +934,16 @@ def _upload(
output = [o for ix, o in enumerate(result) if indices[ix] == i]
res = [
{
"is_file": True,
"name": o,
"path": o,
"orig_name": Path(f).name,
"data": None,
}
for f, o in zip(fs, output)
]
else:
o = next(o for ix, o in enumerate(result) if indices[ix] == i)
res = {
"is_file": True,
"name": o,
"path": o,
"orig_name": Path(fs).name,
"data": None,
}
uploaded.append(res)
return uploaded
Expand Down Expand Up @@ -1011,7 +1006,7 @@ def serialize(self, *data) -> tuple:
data = self._add_uploaded_files_to_data(data, uploaded_files)
data = utils.traverse(
data,
lambda s: {"name": s, "is_file": True, "data": None},
lambda s: {"path": s},
utils.is_url,
)
o = tuple(data)
Expand All @@ -1029,18 +1024,14 @@ def _download_file(
if isinstance(x, str):
file_name = utils.decode_base64_to_file(x, dir=save_dir).name
elif isinstance(x, dict):
if x.get("is_file"):
filepath = x.get("name")
assert filepath is not None, f"The 'name' field is missing in {x}"
file_name = utils.download_file(
root_url + "file=" + filepath,
hf_token=hf_token,
dir=save_dir,
)
else:
data = x.get("data")
assert data is not None, f"The 'data' field is missing in {x}"
file_name = utils.decode_base64_to_file(data, dir=save_dir).name
filepath = x.get("path")
assert filepath is not None, f"The 'path' field is missing in {x}"
file_name = utils.download_file(
root_url + "file=" + filepath,
hf_token=hf_token,
dir=save_dir,
)

else:
raise ValueError(
f"A FileSerializable component can only deserialize a string or a dict, not a {type(x)}: {x}"
Expand Down Expand Up @@ -1100,7 +1091,6 @@ def __init__(self, client: Client, fn_index: int, dependency: dict):
self.is_valid = self.dependency["backend_fn"] and self.api_name is not False
except SerializationSetupError:
self.is_valid = False
print("v3", self.is_valid)

def __repr__(self):
return f"Endpoint src: {self.client.src}, api_name: {self.api_name}, fn_index: {self.fn_index}"
Expand Down
16 changes: 8 additions & 8 deletions client/python/gradio_client/media_data.py

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions client/python/gradio_client/serializing.py
Expand Up @@ -282,12 +282,11 @@ def _serialize_single(
filename = str(Path(load_dir) / x)
size = Path(filename).stat().st_size
return {
"name": filename,
"name": filename or None,
"data": None
if allow_links
else utils.encode_url_or_file_to_base64(filename),
"orig_name": Path(filename).name,
"is_file": allow_links,
"size": size,
}

Expand Down
4 changes: 2 additions & 2 deletions client/python/gradio_client/utils.py
Expand Up @@ -580,7 +580,7 @@ def get_type(schema: dict):
raise APIInfoParseError(f"Cannot parse type for {schema}")


FILE_DATA = "Dict(name: str | None, data: str | None, size: int | None, is_file: bool | None, orig_name: str | None, mime_type: str | None)"
FILE_DATA = "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None)"


def json_schema_to_python_type(schema: Any) -> str:
Expand Down Expand Up @@ -690,7 +690,7 @@ def is_url(s):


def is_file_obj(d):
return isinstance(d, dict) and "name" in d and "is_file" in d and "data" in d
return isinstance(d, dict) and "path" in d


SKIP_COMPONENTS = {
Expand Down
5 changes: 3 additions & 2 deletions client/python/test/test_client.py
Expand Up @@ -308,6 +308,7 @@ def test_stream_audio(self, stream_audio):
assert Path(job2.result()).exists()
assert all(Path(p).exists() for p in job2.outputs())

@pytest.mark.xfail
def test_upload_file_private_space_v4(self):
client = Client(
src="gradio-tests/not-actually-private-file-upload-v4", hf_token=HF_TOKEN
Expand Down Expand Up @@ -1081,9 +1082,9 @@ def test_upload_v4(self):
res = []
for re in results:
if isinstance(re, list):
res.append([r["name"] for r in re])
res.append([r["path"] for r in re])
else:
res.append(re["name"])
res.append(re["path"])

assert res == [
"file1",
Expand Down
2 changes: 1 addition & 1 deletion demo/cancel_events/run.ipynb
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: cancel_events"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import time\n", "import gradio as gr\n", "\n", "\n", "def fake_diffusion(steps):\n", " for i in range(steps):\n", " print(f\"Current step: {i}\")\n", " time.sleep(0.2)\n", " yield str(i)\n", "\n", "\n", "def long_prediction(*args, **kwargs):\n", " time.sleep(10)\n", " return 42\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Row():\n", " with gr.Column():\n", " n = gr.Slider(1, 10, value=9, step=1, label=\"Number Steps\")\n", " run = gr.Button(value=\"Start Iterating\")\n", " output = gr.Textbox(label=\"Iterative Output\")\n", " stop = gr.Button(value=\"Stop Iterating\")\n", " with gr.Column():\n", " textbox = gr.Textbox(label=\"Prompt\")\n", " prediction = gr.Number(label=\"Expensive Calculation\")\n", " run_pred = gr.Button(value=\"Run Expensive Calculation\")\n", " with gr.Column():\n", " cancel_on_change = gr.Textbox(label=\"Cancel Iteration and Expensive Calculation on Change\")\n", " cancel_on_submit = gr.Textbox(label=\"Cancel Iteration and Expensive Calculation on Submit\")\n", " echo = gr.Textbox(label=\"Echo\")\n", " with gr.Row():\n", " with gr.Column():\n", " image = gr.Image(source=\"webcam\", tool=\"editor\", label=\"Cancel on edit\", interactive=True)\n", " with gr.Column():\n", " video = gr.Video(source=\"webcam\", label=\"Cancel on play\", interactive=True)\n", "\n", " click_event = run.click(fake_diffusion, n, output)\n", " stop.click(fn=None, inputs=None, outputs=None, cancels=[click_event])\n", " pred_event = run_pred.click(fn=long_prediction, inputs=[textbox], outputs=prediction)\n", "\n", " cancel_on_change.change(None, None, None, cancels=[click_event, pred_event])\n", " cancel_on_submit.submit(lambda s: s, cancel_on_submit, echo, cancels=[click_event, pred_event])\n", " image.edit(None, None, None, cancels=[click_event, pred_event])\n", " video.play(None, None, None, cancels=[click_event, pred_event])\n", "\n", " demo.queue(concurrency_count=2, max_size=20)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: cancel_events"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import time\n", "import gradio as gr\n", "\n", "\n", "def fake_diffusion(steps):\n", " for i in range(steps):\n", " print(f\"Current step: {i}\")\n", " time.sleep(0.2)\n", " yield str(i)\n", "\n", "\n", "def long_prediction(*args, **kwargs):\n", " time.sleep(10)\n", " return 42\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Row():\n", " with gr.Column():\n", " n = gr.Slider(1, 10, value=9, step=1, label=\"Number Steps\")\n", " run = gr.Button(value=\"Start Iterating\")\n", " output = gr.Textbox(label=\"Iterative Output\")\n", " stop = gr.Button(value=\"Stop Iterating\")\n", " with gr.Column():\n", " textbox = gr.Textbox(label=\"Prompt\")\n", " prediction = gr.Number(label=\"Expensive Calculation\")\n", " run_pred = gr.Button(value=\"Run Expensive Calculation\")\n", " with gr.Column():\n", " cancel_on_change = gr.Textbox(label=\"Cancel Iteration and Expensive Calculation on Change\")\n", " cancel_on_submit = gr.Textbox(label=\"Cancel Iteration and Expensive Calculation on Submit\")\n", " echo = gr.Textbox(label=\"Echo\")\n", " with gr.Row():\n", " with gr.Column():\n", " image = gr.Image(sources=[\"webcam\"], tool=\"editor\", label=\"Cancel on edit\", interactive=True)\n", " with gr.Column():\n", " video = gr.Video(source=\"webcam\", label=\"Cancel on play\", interactive=True)\n", "\n", " click_event = run.click(fake_diffusion, n, output)\n", " stop.click(fn=None, inputs=None, outputs=None, cancels=[click_event])\n", " pred_event = run_pred.click(fn=long_prediction, inputs=[textbox], outputs=prediction)\n", "\n", " cancel_on_change.change(None, None, None, cancels=[click_event, pred_event])\n", " cancel_on_submit.submit(lambda s: s, cancel_on_submit, echo, cancels=[click_event, pred_event])\n", " image.edit(None, None, None, cancels=[click_event, pred_event])\n", " video.play(None, None, None, cancels=[click_event, pred_event])\n", "\n", " demo.queue(concurrency_count=2, max_size=20)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
2 changes: 1 addition & 1 deletion demo/cancel_events/run.py
Expand Up @@ -31,7 +31,7 @@ def long_prediction(*args, **kwargs):
echo = gr.Textbox(label="Echo")
with gr.Row():
with gr.Column():
image = gr.Image(source="webcam", tool="editor", label="Cancel on edit", interactive=True)
image = gr.Image(sources=["webcam"], tool="editor", label="Cancel on edit", interactive=True)
with gr.Column():
video = gr.Video(source="webcam", label="Cancel on play", interactive=True)

Expand Down
Binary file added demo/image-simple/cheetah.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions demo/image-simple/run.ipynb
@@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: image-simple"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/image-simple/cheetah.jpg"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "\n", "def image(im):\n", " return im\n", "\n", "\n", "with gr.Blocks() as demo:\n", " im = gr.Image()\n", " im2 = gr.Image()\n", " btn = gr.Button()\n", " btn.click(lambda x: x, outputs=im2, inputs=im)\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
16 changes: 16 additions & 0 deletions demo/image-simple/run.py
@@ -0,0 +1,16 @@
import gradio as gr


def image(im):
return im


with gr.Blocks() as demo:
im = gr.Image()
im2 = gr.Image()
btn = gr.Button()
btn.click(lambda x: x, outputs=im2, inputs=im)


if __name__ == "__main__":
demo.launch()

0 comments on commit c476bd5

Please sign in to comment.