Skip to content

Commit

Permalink
Client fixes (#8272)
Browse files Browse the repository at this point in the history
* fix param name

* fix hidden state variable

* pass jwt to heartbeat event

* notebooks

* format

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
pngwn and gradio-pr-bot committed May 13, 2024
1 parent 30463c5 commit fbf4edd
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 28 deletions.
7 changes: 7 additions & 0 deletions .changeset/weak-bugs-itch.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@gradio/app": minor
"@gradio/client": minor
"gradio": minor
---

feat:Client fixes
18 changes: 12 additions & 6 deletions client/js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,18 +139,24 @@ export class Client {
if (config) {
this.config = config;
if (this.config && this.config.connect_heartbeat) {
// connect to the heartbeat endpoint via GET request
const heartbeat_url = new URL(
`${this.config.root}/heartbeat/${this.session_hash}`
);
this.heartbeat_event = await this.stream(heartbeat_url); // Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540

if (this.config.space_id && this.options.hf_token) {
this.jwt = await get_jwt(
this.config.space_id,
this.options.hf_token
);
}

// connect to the heartbeat endpoint via GET request
const heartbeat_url = new URL(
`${this.config.root}/heartbeat/${this.session_hash}`
);

// if the jwt is available, add it to the query params
if (this.jwt) {
heartbeat_url.searchParams.set("jwt", this.jwt);
}

this.heartbeat_event = await this.stream(heartbeat_url); // Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540
}
}
});
Expand Down
46 changes: 39 additions & 7 deletions client/js/src/helpers/api_info.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,11 @@ export function transform_api_info(
Object.entries(api_info[category]).forEach(
([endpoint, { parameters, returns }]) => {
const dependencyIndex =
config.dependencies.findIndex((dep) => dep.api_name === endpoint) ||
config.dependencies.findIndex(
(dep) =>
dep.api_name === endpoint ||
dep.api_name === endpoint.replace("/", "")
) ||
api_map[endpoint.replace("/", "")] ||
-1;

Expand All @@ -86,24 +90,52 @@ export function transform_api_info(
? config.dependencies[dependencyIndex].types
: { continuous: false, generator: false };

if (
dependencyIndex !== -1 &&
config.dependencies[dependencyIndex]?.inputs?.length !==
parameters.length
) {
const components = config.dependencies[dependencyIndex].inputs.map(
(input) => config.components.find((c) => c.id === input)?.type
);

try {
components.forEach((comp, idx) => {
if (comp === "state") {
const new_param = {
component: "state",
example: null,
parameter_default: null,
parameter_has_default: true,
parameter_name: null,
hidden: true
};

// @ts-ignore
parameters.splice(idx, 0, new_param);
}
});
} catch (e) {}
}

const transform_type = (
data: ApiData,
component: string,
serializer: string,
signature_type: "return" | "parameter"
): JsApiData => ({
...data,
description: get_description(data.type, serializer),
description: get_description(data?.type, serializer),
type:
get_type(data.type, component, serializer, signature_type) || ""
get_type(data?.type, component, serializer, signature_type) || ""
});

transformed_info[category][endpoint] = {
parameters: parameters.map((p: ApiData) =>
transform_type(p, p.component, p.serializer, "parameter")
transform_type(p, p?.component, p?.serializer, "parameter")
),
returns: returns.map((r: ApiData) =>
transform_type(r, r.component, r.serializer, "return")
transform_type(r, r?.component, r?.serializer, "return")
),
type: dependencyTypes
};
Expand All @@ -121,7 +153,7 @@ export function get_type(
serializer: string,
signature_type: "return" | "parameter"
): string | undefined {
switch (type.type) {
switch (type?.type) {
case "string":
return "string";
case "boolean":
Expand Down Expand Up @@ -166,7 +198,7 @@ export function get_description(
} else if (serializer === "FileSerializable") {
return "array of files or single file";
}
return type.description;
return type?.description;
}

export function handle_message(
Expand Down
2 changes: 1 addition & 1 deletion demo/chatinterface_multimodal/run.ipynb
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "def echo(message, history):\n", " return message[\"text\"]\n", "\n", "demo = gr.ChatInterface(fn=echo, examples=[{\"text\": \"hello\"}, {\"text\": \"hola\"}, {\"text\": \"merhaba\"}], title=\"Echo Bot\", multimodal=True)\n", "demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "\n", "def echo(message, history):\n", " return message[\"text\"]\n", "\n", "\n", "demo = gr.ChatInterface(\n", " fn=echo,\n", " examples=[{\"text\": \"hello\"}, {\"text\": \"hola\"}, {\"text\": \"merhaba\"}],\n", " title=\"Echo Bot\",\n", " multimodal=True,\n", ")\n", "demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
9 changes: 8 additions & 1 deletion demo/chatinterface_multimodal/run.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import gradio as gr


def echo(message, history):
return message["text"]

demo = gr.ChatInterface(fn=echo, examples=[{"text": "hello"}, {"text": "hola"}, {"text": "merhaba"}], title="Echo Bot", multimodal=True)

demo = gr.ChatInterface(
fn=echo,
examples=[{"text": "hello"}, {"text": "hola"}, {"text": "merhaba"}],
title="Echo Bot",
multimodal=True,
)
demo.launch()
2 changes: 1 addition & 1 deletion demo/chatinterface_system_prompt/run.ipynb
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_system_prompt"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "def echo(message, history, system_prompt, tokens):\n", " response = f\"System prompt: {system_prompt}\\n Message: {message}.\"\n", " for i in range(min(len(response), int(tokens))):\n", " time.sleep(0.05)\n", " yield response[: i+1]\n", "\n", "demo = gr.ChatInterface(echo, \n", " additional_inputs=[\n", " gr.Textbox(\"You are helpful AI.\", label=\"System Prompt\"), \n", " gr.Slider(10, 100)\n", " ]\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_system_prompt"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "\n", "def echo(message, history, system_prompt, tokens):\n", " response = f\"System prompt: {system_prompt}\\n Message: {message}.\"\n", " for i in range(min(len(response), int(tokens))):\n", " time.sleep(0.05)\n", " yield response[: i + 1]\n", "\n", "\n", "demo = gr.ChatInterface(\n", " echo,\n", " additional_inputs=[\n", " gr.Textbox(\"You are helpful AI.\", label=\"System Prompt\"),\n", " gr.Slider(10, 100),\n", " ],\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
19 changes: 11 additions & 8 deletions demo/chatinterface_system_prompt/run.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import gradio as gr
import time


def echo(message, history, system_prompt, tokens):
response = f"System prompt: {system_prompt}\n Message: {message}."
for i in range(min(len(response), int(tokens))):
time.sleep(0.05)
yield response[: i+1]
yield response[: i + 1]


demo = gr.ChatInterface(echo,
additional_inputs=[
gr.Textbox("You are helpful AI.", label="System Prompt"),
gr.Slider(10, 100)
]
)
demo = gr.ChatInterface(
echo,
additional_inputs=[
gr.Textbox("You are helpful AI.", label="System Prompt"),
gr.Slider(10, 100),
],
)

if __name__ == "__main__":
demo.queue().launch()
demo.queue().launch()
9 changes: 5 additions & 4 deletions js/app/src/api_docs/CodeSnippet.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ result = client.<span class="highlight">predict</span
</div>
<div bind:this={js_code}>
<pre>import &lbrace; Client &rbrace; from "@gradio/client";
{#each blob_examples as { label, type, python_type, component, example_input, serializer }, i}<!--
{#each blob_examples as { component, example_input }, i}<!--
-->
const response_{i} = await fetch("{example_input.url}");
const example{component} = await response_{i}.blob();
Expand All @@ -88,11 +88,12 @@ const client = await Client.connect(<span class="token string">"{root}"</span>);
const result = await client.predict({#if named}<span class="api-name"
>"/{dependency.api_name}"</span
>{:else}{dependency_index}{/if}, &lbrace; <!--
-->{#each endpoint_parameters as { label, type, python_type, component, example_input, serializer }, i}<!--
-->{#each endpoint_parameters as { label, parameter_name, type, python_type, component, example_input, serializer }, i}<!--
-->{#if blob_components.includes(component)}<!--
-->
<span
class="example-inputs">{label}: example{component}</span
class="example-inputs"
>{parameter_name}: example{component}</span
>, <!--
--><span class="desc"
><!--
Expand All @@ -104,7 +105,7 @@ const result = await client.predict({#if named}<span class="api-name"
-->{:else}<!--
-->
<span class="example-inputs"
>{label}: {represent_value(
>{parameter_name}: {represent_value(
example_input,
python_type.type,
"js"
Expand Down

0 comments on commit fbf4edd

Please sign in to comment.