From aea14c44961be8b0e732e05799df8104d7b906ed Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Tue, 6 Feb 2024 12:27:44 -0800 Subject: [PATCH] Refactor Inference API and rename it to Serverless Inference Endpoints (#7295) * changes * changes * add changeset * add changeset * changes * all pipelines * format * clean * add examples * fix audio classification * format * format * fix all pipelines * fixes * fixes * fix tabular * add changeset * added future --------- Co-authored-by: gradio-pr-bot --- .changeset/tired-suns-judge.md | 5 + gradio/external.py | 589 +++++++----------- gradio/external_utils.py | 125 +++- gradio/utils.py | 18 - .../01_using-hugging-face-integrations.md | 12 +- .../01_using-hugging-face-integrations.md | 6 +- test/test_external.py | 11 +- test/test_utils.py | 2 +- 8 files changed, 371 insertions(+), 397 deletions(-) create mode 100644 .changeset/tired-suns-judge.md diff --git a/.changeset/tired-suns-judge.md b/.changeset/tired-suns-judge.md new file mode 100644 index 000000000000..c3e5cd8a1b0b --- /dev/null +++ b/.changeset/tired-suns-judge.md @@ -0,0 +1,5 @@ +--- +"gradio": minor +--- + +feat:Refactor Inference API and rename it to Serverless Inference Endpoints diff --git a/gradio/external.py b/gradio/external.py index 28b849de0d9b..c1b2f7544235 100644 --- a/gradio/external.py +++ b/gradio/external.py @@ -12,30 +12,21 @@ from typing import TYPE_CHECKING, Callable import httpx +import huggingface_hub from gradio_client import Client -from gradio_client import utils as client_utils from gradio_client.client import Endpoint from gradio_client.documentation import document from packaging import version import gradio -from gradio import components, utils +from gradio import components, external_utils, utils from gradio.context import Context from gradio.exceptions import ( - Error, GradioVersionIncompatibleError, ModelNotFoundError, TooManyRequestsError, ) -from gradio.external_utils import ( - cols_to_rows, - encode_to_base64, - get_tabular_examples, - postprocess_label, - rows_to_cols, - streamline_spaces_interface, -) -from gradio.processing_utils import extract_base64_data, save_base64_to_cache, to_binary +from gradio.processing_utils import save_base64_to_cache, to_binary if TYPE_CHECKING: from gradio.blocks import Blocks @@ -109,369 +100,271 @@ def load_blocks_from_repo( return blocks -def chatbot_preprocess(text, state): - payload = { - "inputs": {"generated_responses": None, "past_user_inputs": None, "text": text} - } - if state is not None: - payload["inputs"]["generated_responses"] = state["conversation"][ - "generated_responses" - ] - payload["inputs"]["past_user_inputs"] = state["conversation"][ - "past_user_inputs" - ] - - return payload - - -def chatbot_postprocess(response): - response_json = response.json() - chatbot_value = list( - zip( - response_json["conversation"]["past_user_inputs"], - response_json["conversation"]["generated_responses"], - ) - ) - return chatbot_value, response_json - - def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwargs): model_url = f"https://huggingface.co/{model_name}" api_url = f"https://api-inference.huggingface.co/models/{model_name}" print(f"Fetching model from: {model_url}") headers = {"Authorization": f"Bearer {hf_token}"} if hf_token is not None else {} - - # Checking if model exists, and if so, it gets the pipeline response = httpx.request("GET", api_url, headers=headers) if response.status_code != 200: raise ModelNotFoundError( f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter." ) p = response.json().get("pipeline_tag") + + headers["X-Wait-For-Model"] = "true" + client = huggingface_hub.InferenceClient( + model=model_name, headers=headers, token=hf_token + ) + + # For tasks that are not yet supported by the InferenceClient GRADIO_CACHE = os.environ.get("GRADIO_TEMP_DIR") or str( # noqa: N806 Path(tempfile.gettempdir()) / "gradio" ) - pipelines = { - "audio-classification": { - # example model: ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition - "inputs": components.Audio( - sources=["upload"], type="filepath", label="Input", render=False - ), - "outputs": components.Label(label="Class", render=False), - "preprocess": lambda _: to_binary, - "postprocess": lambda r: postprocess_label( - {i["label"].split(", ")[0]: i["score"] for i in r.json()} - ), - }, - "audio-to-audio": { - # example model: facebook/xm_transformer_sm_all-en - "inputs": components.Audio( - sources=["upload"], type="filepath", label="Input", render=False - ), - "outputs": components.Audio(label="Output", render=False), - "preprocess": to_binary, - "postprocess": lambda x: save_base64_to_cache( - encode_to_base64(x), cache_dir=GRADIO_CACHE, file_name="output.wav" - ), - }, - "automatic-speech-recognition": { - # example model: facebook/wav2vec2-base-960h - "inputs": components.Audio( - sources=["upload"], type="filepath", label="Input", render=False - ), - "outputs": components.Textbox(label="Output", render=False), - "preprocess": to_binary, - "postprocess": lambda r: r.json()["text"], - }, - "conversational": { - "inputs": [ - components.Textbox(render=False), - components.State(render=False), - ], # type: ignore - "outputs": [ - components.Chatbot(render=False), - components.State(render=False), - ], # type: ignore - "preprocess": chatbot_preprocess, - "postprocess": chatbot_postprocess, - }, - "feature-extraction": { - # example model: julien-c/distilbert-feature-extraction - "inputs": components.Textbox(label="Input", render=False), - "outputs": components.Dataframe(label="Output", render=False), - "preprocess": lambda x: {"inputs": x}, - "postprocess": lambda r: r.json()[0], - }, - "fill-mask": { - "inputs": components.Textbox(label="Input", render=False), - "outputs": components.Label(label="Classification", render=False), - "preprocess": lambda x: {"inputs": x}, - "postprocess": lambda r: postprocess_label( - {i["token_str"]: i["score"] for i in r.json()} - ), - }, - "image-classification": { - # Example: google/vit-base-patch16-224 - "inputs": components.Image( - type="filepath", label="Input Image", render=False - ), - "outputs": components.Label(label="Classification", render=False), - "preprocess": to_binary, - "postprocess": lambda r: postprocess_label( - {i["label"].split(", ")[0]: i["score"] for i in r.json()} - ), - }, - "question-answering": { - # Example: deepset/xlm-roberta-base-squad2 - "inputs": [ - components.Textbox(lines=7, label="Context", render=False), - components.Textbox(label="Question", render=False), - ], - "outputs": [ - components.Textbox(label="Answer", render=False), - components.Label(label="Score", render=False), - ], - "preprocess": lambda c, q: {"inputs": {"context": c, "question": q}}, - "postprocess": lambda r: (r.json()["answer"], {"label": r.json()["score"]}), - }, - "summarization": { - # Example: facebook/bart-large-cnn - "inputs": components.Textbox(label="Input", render=False), - "outputs": components.Textbox(label="Summary", render=False), - "preprocess": lambda x: {"inputs": x}, - "postprocess": lambda r: r.json()[0]["summary_text"], - }, - "text-classification": { - # Example: distilbert-base-uncased-finetuned-sst-2-english - "inputs": components.Textbox(label="Input", render=False), - "outputs": components.Label(label="Classification", render=False), - "preprocess": lambda x: {"inputs": x}, - "postprocess": lambda r: postprocess_label( - {i["label"].split(", ")[0]: i["score"] for i in r.json()[0]} - ), - }, - "text-generation": { - # Example: gpt2 - "inputs": components.Textbox(label="Input", render=False), - "outputs": components.Textbox(label="Output", render=False), - "preprocess": lambda x: {"inputs": x}, - "postprocess": lambda r: r.json()[0]["generated_text"], - }, - "text2text-generation": { - # Example: valhalla/t5-small-qa-qg-hl - "inputs": components.Textbox(label="Input", render=False), - "outputs": components.Textbox(label="Generated Text", render=False), - "preprocess": lambda x: {"inputs": x}, - "postprocess": lambda r: r.json()[0]["generated_text"], - }, - "translation": { - "inputs": components.Textbox(label="Input", render=False), - "outputs": components.Textbox(label="Translation", render=False), - "preprocess": lambda x: {"inputs": x}, - "postprocess": lambda r: r.json()[0]["translation_text"], - }, - "zero-shot-classification": { - # Example: facebook/bart-large-mnli - "inputs": [ - components.Textbox(label="Input", render=False), - components.Textbox( - label="Possible class names (" "comma-separated)", render=False - ), - components.Checkbox(label="Allow multiple true classes", render=False), - ], - "outputs": components.Label(label="Classification", render=False), - "preprocess": lambda i, c, m: { - "inputs": i, - "parameters": {"candidate_labels": c, "multi_class": m}, - }, - "postprocess": lambda r: postprocess_label( - { - r.json()["labels"][i]: r.json()["scores"][i] - for i in range(len(r.json()["labels"])) - } - ), - }, - "sentence-similarity": { - # Example: sentence-transformers/distilbert-base-nli-stsb-mean-tokens - "inputs": [ - components.Textbox( - value="That is a happy person", - label="Source Sentence", - render=False, - ), - components.Textbox( - lines=7, - placeholder="Separate each sentence by a newline", - label="Sentences to compare to", - render=False, - ), - ], - "outputs": components.Label(label="Classification", render=False), - "preprocess": lambda src, sentences: { - "inputs": { - "source_sentence": src, - "sentences": [s for s in sentences.splitlines() if s != ""], - } - }, - "postprocess": lambda r: postprocess_label( - {f"sentence {i}": v for i, v in enumerate(r.json())} - ), - }, - "text-to-speech": { - # Example: julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train - "inputs": components.Textbox(label="Input", render=False), - "outputs": components.Audio(label="Audio", render=False), - "preprocess": lambda x: {"inputs": x}, - "postprocess": lambda x: save_base64_to_cache( - encode_to_base64(x), cache_dir=GRADIO_CACHE, file_name="output.wav" - ), - }, - "text-to-image": { - # example model: osanseviero/BigGAN-deep-128 - "inputs": components.Textbox(label="Input", render=False), - "outputs": components.Image(label="Output", render=False), - "preprocess": lambda x: {"inputs": x}, - "postprocess": lambda x: save_base64_to_cache( - encode_to_base64(x), cache_dir=GRADIO_CACHE, file_name="output.jpg" - ), - }, - "token-classification": { - # example model: huggingface-course/bert-finetuned-ner - "inputs": components.Textbox(label="Input", render=False), - "outputs": components.HighlightedText(label="Output", render=False), - "preprocess": lambda x: {"inputs": x}, - "postprocess": lambda r: r, # Handled as a special case in query_huggingface_api() - }, - "document-question-answering": { - # example model: impira/layoutlm-document-qa - "inputs": [ - components.Image(type="filepath", label="Input Document", render=False), - components.Textbox(label="Question", render=False), - ], - "outputs": components.Label(label="Label", render=False), - "preprocess": lambda img, q: { - "inputs": { - "image": extract_base64_data( - client_utils.encode_url_or_file_to_base64(img["path"]) - ), # Extract base64 data - "question": q, - } - }, - "postprocess": lambda r: postprocess_label( - {i["answer"]: i["score"] for i in r.json()} - ), - }, - "visual-question-answering": { - # example model: dandelin/vilt-b32-finetuned-vqa - "inputs": [ - components.Image(type="filepath", label="Input Image", render=False), - components.Textbox(label="Question", render=False), - ], - "outputs": components.Label(label="Label", render=False), - "preprocess": lambda img, q: { - "inputs": { - "image": extract_base64_data( - client_utils.encode_url_or_file_to_base64(img["path"]) - ), - "question": q, - } - }, - "postprocess": lambda r: postprocess_label( - {i["answer"]: i["score"] for i in r.json()} - ), - }, - "image-to-text": { - # example model: Salesforce/blip-image-captioning-base - "inputs": components.Image( - type="filepath", label="Input Image", render=False - ), - "outputs": components.Textbox(label="Generated Text", render=False), - "preprocess": to_binary, - "postprocess": lambda r: r.json()[0]["generated_text"], - }, - } + def custom_post_binary(data): + data = to_binary({"path": data}) + response = httpx.request("POST", api_url, headers=headers, content=data) + return save_base64_to_cache( + external_utils.encode_to_base64(response), cache_dir=GRADIO_CACHE + ) - if p in ["tabular-classification", "tabular-regression"]: - example_data = get_tabular_examples(model_name) - col_names, example_data = cols_to_rows(example_data) - example_data = [[example_data]] if example_data else None - - pipelines[p] = { - "inputs": components.Dataframe( - label="Input Rows", - type="pandas", - headers=col_names, - col_count=(len(col_names), "fixed"), - render=False, + preprocess = None + postprocess = None + examples = None + + # example model: ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition + if p == "audio-classification": + inputs = components.Audio(type="filepath", label="Input") + outputs = components.Label(label="Class") + postprocess = external_utils.postprocess_label + examples = [ + "https://gradio-builds.s3.amazonaws.com/demo-files/audio_sample.wav" + ] + fn = client.audio_classification + # example model: facebook/xm_transformer_sm_all-en + elif p == "audio-to-audio": + inputs = components.Audio(type="filepath", label="Input") + outputs = components.Audio(label="Output") + examples = [ + "https://gradio-builds.s3.amazonaws.com/demo-files/audio_sample.wav" + ] + fn = custom_post_binary + # example model: facebook/wav2vec2-base-960h + elif p == "automatic-speech-recognition": + inputs = components.Audio(sources=["upload"], type="filepath", label="Input") + outputs = components.Textbox(label="Output") + examples = [ + "https://gradio-builds.s3.amazonaws.com/demo-files/audio_sample.wav" + ] + fn = client.automatic_speech_recognition + # example model: microsoft/DialoGPT-medium + elif p == "conversational": + inputs = [ + components.Textbox(render=False), + components.State(render=False), + ] + outputs = [ + components.Chatbot(render=False), + components.State(render=False), + ] + examples = [["Hello World"]] + preprocess = external_utils.chatbot_preprocess + postprocess = external_utils.chatbot_postprocess + fn = client.conversational + # example model: julien-c/distilbert-feature-extraction + elif p == "feature-extraction": + inputs = components.Textbox(label="Input") + outputs = components.Dataframe(label="Output") + fn = client.feature_extraction + postprocess = utils.resolve_singleton + # example model: distilbert/distilbert-base-uncased + elif p == "fill-mask": + inputs = components.Textbox(label="Input") + outputs = components.Label(label="Classification") + examples = [ + "Hugging Face is the AI community, working together, to [MASK] the future." + ] + postprocess = external_utils.postprocess_mask_tokens + fn = client.fill_mask + # Example: google/vit-base-patch16-224 + elif p == "image-classification": + inputs = components.Image(type="filepath", label="Input Image") + outputs = components.Label(label="Classification") + postprocess = external_utils.postprocess_label + examples = ["https://gradio-builds.s3.amazonaws.com/demo-files/cheetah-002.jpg"] + fn = client.image_classification + # Example: deepset/xlm-roberta-base-squad2 + elif p == "question-answering": + inputs = [ + components.Textbox(label="Question"), + components.Textbox(lines=7, label="Context"), + ] + outputs = [ + components.Textbox(label="Answer"), + components.Label(label="Score"), + ] + examples = [ + [ + "What entity was responsible for the Apollo program?", + "The Apollo program, also known as Project Apollo, was the third United States human spaceflight" + " program carried out by the National Aeronautics and Space Administration (NASA), which accomplished" + " landing the first humans on the Moon from 1969 to 1972.", + ] + ] + postprocess = external_utils.postprocess_question_answering + fn = client.question_answering + # Example: facebook/bart-large-cnn + elif p == "summarization": + inputs = components.Textbox(label="Input") + outputs = components.Textbox(label="Summary") + examples = [ + [ + "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct." + ] + ] + fn = client.summarization + # Example: distilbert-base-uncased-finetuned-sst-2-english + elif p == "text-classification": + inputs = components.Textbox(label="Input") + outputs = components.Label(label="Classification") + examples = ["I feel great"] + postprocess = external_utils.postprocess_label + fn = client.text_classification + # Example: gpt2 + elif p == "text-generation": + inputs = components.Textbox(label="Text") + outputs = inputs + examples = ["Once upon a time"] + fn = external_utils.text_generation_wrapper(client) + # Example: valhalla/t5-small-qa-qg-hl + elif p == "text2text-generation": + inputs = components.Textbox(label="Input") + outputs = components.Textbox(label="Generated Text") + examples = ["Translate English to Arabic: How are you?"] + fn = client.text_generation + # Example: Helsinki-NLP/opus-mt-en-ar + elif p == "translation": + inputs = components.Textbox(label="Input") + outputs = components.Textbox(label="Translation") + examples = ["Hello, how are you?"] + fn = client.translation + # Example: facebook/bart-large-mnli + elif p == "zero-shot-classification": + inputs = [ + components.Textbox(label="Input"), + components.Textbox(label="Possible class names (" "comma-separated)"), + components.Checkbox(label="Allow multiple true classes"), + ] + outputs = components.Label(label="Classification") + postprocess = external_utils.postprocess_label + examples = [["I feel great", "happy, sad", False]] + fn = external_utils.zero_shot_classification_wrapper(client) + # Example: sentence-transformers/distilbert-base-nli-stsb-mean-tokens + elif p == "sentence-similarity": + inputs = [ + components.Textbox( + label="Source Sentence", + placeholder="Enter an original sentence", ), - "outputs": components.Dataframe( - label="Predictions", type="array", headers=["prediction"], render=False + components.Textbox( + lines=7, + placeholder="Sentences to compare to -- separate each sentence by a newline", + label="Sentences to compare to", ), - "preprocess": rows_to_cols, - "postprocess": lambda r: { - "headers": ["prediction"], - "data": [[pred] for pred in json.loads(r.text)], - }, - "examples": example_data, - } - - if p is None or p not in pipelines: + ] + outputs = components.JSON(label="Similarity scores") + examples = [["That is a happy person", "That person is very happy"]] + fn = external_utils.sentence_similarity_wrapper(client) + # Example: julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train + elif p == "text-to-speech": + inputs = components.Textbox(label="Input") + outputs = components.Audio(label="Audio") + examples = ["Hello, how are you?"] + fn = client.text_to_speech + # example model: osanseviero/BigGAN-deep-128 + elif p == "text-to-image": + inputs = components.Textbox(label="Input") + outputs = components.Image(label="Output") + examples = ["A beautiful sunset"] + fn = client.text_to_image + # example model: huggingface-course/bert-finetuned-ner + elif p == "token-classification": + inputs = components.Textbox(label="Input") + outputs = components.HighlightedText(label="Output") + examples = [ + "Hugging Face is a company based in Paris and New York City that acquired Gradio in 2021." + ] + fn = external_utils.token_classification_wrapper(client) + # example model: impira/layoutlm-document-qa + elif p == "document-question-answering": + inputs = [ + components.Image(type="filepath", label="Input Document"), + components.Textbox(label="Question"), + ] + postprocess = external_utils.postprocess_label + outputs = components.Label(label="Label") + fn = client.document_question_answering + # example model: dandelin/vilt-b32-finetuned-vqa + elif p == "visual-question-answering": + inputs = [ + components.Image(type="filepath", label="Input Image"), + components.Textbox(label="Question"), + ] + outputs = components.Label(label="Label") + postprocess = external_utils.postprocess_visual_question_answering + examples = [ + [ + "https://gradio-builds.s3.amazonaws.com/demo-files/cheetah-002.jpg", + "What animal is in the image?", + ] + ] + fn = client.visual_question_answering + # example model: Salesforce/blip-image-captioning-base + elif p == "image-to-text": + inputs = components.Image(type="filepath", label="Input Image") + outputs = components.Textbox(label="Generated Text") + examples = ["https://gradio-builds.s3.amazonaws.com/demo-files/cheetah-002.jpg"] + fn = client.image_to_text + # example model: rajistics/autotrain-Adult-934630783 + elif p in ["tabular-classification", "tabular-regression"]: + examples = external_utils.get_tabular_examples(model_name) + col_names, examples = external_utils.cols_to_rows(examples) + examples = [[examples]] if examples else None + inputs = components.Dataframe( + label="Input Rows", + type="pandas", + headers=col_names, + col_count=(len(col_names), "fixed"), + render=False, + ) + outputs = components.Dataframe( + label="Predictions", type="array", headers=["prediction"] + ) + fn = external_utils.tabular_wrapper + else: raise ValueError(f"Unsupported pipeline type: {p}") - pipeline = pipelines[p] - - def query_huggingface_api(*params): - # Convert to a list of input components - data = pipeline["preprocess"](*params) - if isinstance( - data, dict - ): # HF doesn't allow additional parameters for binary files (e.g. images or audio files) - data.update({"options": {"wait_for_model": True}}) - data = json.dumps(data) - response = httpx.request("POST", api_url, headers=headers, data=data) # type: ignore - if response.status_code != 200: - errors_json = response.json() - errors, warns = "", "" - if errors_json.get("error"): - errors = f", Error: {errors_json.get('error')}" - if errors_json.get("warnings"): - warns = f", Warnings: {errors_json.get('warnings')}" - raise Error( - f"Could not complete request to HuggingFace API, Status Code: {response.status_code}" - + errors - + warns - ) - if ( - p == "token-classification" - ): # Handle as a special case since HF API only returns the named entities and we need the input as well - ner_groups = response.json() - input_string = params[0] - response = utils.format_ner_list(input_string, ner_groups) - output = pipeline["postprocess"](response) - return output + def query_huggingface_inference_endpoints(*data): + if preprocess is not None: + data = preprocess(*data) + data = fn(*data) # type: ignore + if postprocess is not None: + data = postprocess(data) # type: ignore + return data - if alias is None: - query_huggingface_api.__name__ = model_name - else: - query_huggingface_api.__name__ = alias + query_huggingface_inference_endpoints.__name__ = alias or model_name interface_info = { - "fn": query_huggingface_api, - "inputs": pipeline["inputs"], - "outputs": pipeline["outputs"], + "fn": query_huggingface_inference_endpoints, + "inputs": inputs, + "outputs": outputs, "title": model_name, - "examples": pipeline.get("examples"), + "examples": examples, } kwargs = dict(interface_info, **kwargs) - - # So interface doesn't run pre/postprocess - # except for conversational interfaces which - # are stateful - kwargs["_api_mode"] = p != "conversational" - interface = gradio.Interface(**kwargs) return interface @@ -551,7 +444,7 @@ def from_spaces_interface( iframe_url: str, **kwargs, ) -> Interface: - config = streamline_spaces_interface(config) + config = external_utils.streamline_spaces_interface(config) api_url = f"{iframe_url}/api/predict/" headers = {"Content-Type": "application/json"} if hf_token is not None: diff --git a/gradio/external_utils.py b/gradio/external_utils.py index b46ffcac5269..3d45b5dbe4a5 100644 --- a/gradio/external_utils.py +++ b/gradio/external_utils.py @@ -1,14 +1,15 @@ -"""Utility function for gradio/external.py""" +"""Utility function for gradio/external.py, designed for internal use.""" + +from __future__ import annotations import base64 import math -import operator import re import warnings -from typing import Dict, List, Tuple import httpx import yaml +from huggingface_hub import InferenceClient from gradio import components @@ -17,7 +18,7 @@ ################## -def get_tabular_examples(model_name: str) -> Dict[str, List[float]]: +def get_tabular_examples(model_name: str) -> dict[str, list[float]]: readme = httpx.get(f"https://huggingface.co/{model_name}/resolve/main/README.md") if readme.status_code != 200: warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning) @@ -39,7 +40,7 @@ def get_tabular_examples(model_name: str) -> Dict[str, List[float]]: "See the README.md here: https://huggingface.co/scikit-learn/tabular-playground/blob/main/README.md " "for a reference on how to provide example data to your model." ) - # replace nan with string NaN for inference API + # replace nan with string NaN for inference Endpoints for data in example_data.values(): for i, val in enumerate(data): if isinstance(val, float) and math.isnan(val): @@ -48,8 +49,8 @@ def get_tabular_examples(model_name: str) -> Dict[str, List[float]]: def cols_to_rows( - example_data: Dict[str, List[float]], -) -> Tuple[List[str], List[List[float]]]: + example_data: dict[str, list[float]], +) -> tuple[list[str], list[list[float]]]: headers = list(example_data.keys()) n_rows = max(len(example_data[header] or []) for header in headers) data = [] @@ -65,7 +66,7 @@ def cols_to_rows( return headers, data -def rows_to_cols(incoming_data: Dict) -> Dict[str, Dict[str, Dict[str, List[str]]]]: +def rows_to_cols(incoming_data: dict) -> dict[str, dict[str, dict[str, list[str]]]]: data_column_wise = {} for i, header in enumerate(incoming_data["headers"]): data_column_wise[header] = [str(row[i]) for row in incoming_data["data"]] @@ -77,14 +78,43 @@ def rows_to_cols(incoming_data: Dict) -> Dict[str, Dict[str, Dict[str, List[str] ################## -def postprocess_label(scores: Dict) -> Dict: - sorted_pred = sorted(scores.items(), key=operator.itemgetter(1), reverse=True) - return { - "label": sorted_pred[0][0], - "confidences": [ - {"label": pred[0], "confidence": pred[1]} for pred in sorted_pred - ], - } +def postprocess_label(scores: list[dict[str, str | float]]) -> dict: + return {c["label"]: c["score"] for c in scores} + + +def postprocess_mask_tokens(scores: list[dict[str, str | float]]) -> dict: + return {c["token_str"]: c["score"] for c in scores} + + +def postprocess_question_answering(answer: dict) -> tuple[str, dict]: + return answer["answer"], {answer["answer"]: answer["score"]} + + +def postprocess_visual_question_answering(scores: list[dict[str, str | float]]) -> dict: + return {c["answer"]: c["score"] for c in scores} + + +def zero_shot_classification_wrapper(client: InferenceClient): + def zero_shot_classification_inner(input: str, labels: str, multi_label: bool): + return client.zero_shot_classification( + input, labels.split(","), multi_label=multi_label + ) + + return zero_shot_classification_inner + + +def sentence_similarity_wrapper(client: InferenceClient): + def sentence_similarity_inner(input: str, sentences: str): + return client.sentence_similarity(input, sentences.split("\n")) + + return sentence_similarity_inner + + +def text_generation_wrapper(client: InferenceClient): + def text_generation_inner(input: str): + return input + client.text_generation(input) + + return text_generation_inner def encode_to_base64(r: httpx.Response) -> str: @@ -113,12 +143,73 @@ def encode_to_base64(r: httpx.Response) -> str: return new_base64 +def format_ner_list(input_string: str, ner_groups: list[dict[str, str | int]]): + if len(ner_groups) == 0: + return [(input_string, None)] + + output = [] + end = 0 + prev_end = 0 + + for group in ner_groups: + entity, start, end = group["entity_group"], group["start"], group["end"] + output.append((input_string[prev_end:start], None)) + output.append((input_string[start:end], entity)) + prev_end = end + + output.append((input_string[end:], None)) + return output + + +def token_classification_wrapper(client: InferenceClient): + def token_classification_inner(input: str): + ner_list = client.token_classification(input) + return format_ner_list(input, ner_list) # type: ignore + + return token_classification_inner + + +def chatbot_preprocess(text, state): + if not state: + return text, [], [] + return ( + text, + state["conversation"]["generated_responses"], + state["conversation"]["past_user_inputs"], + ) + + +def chatbot_postprocess(response): + chatbot_history = list( + zip( + response["conversation"]["past_user_inputs"], + response["conversation"]["generated_responses"], + ) + ) + return chatbot_history, response + + +def tabular_wrapper(client: InferenceClient, pipeline: str): + # This wrapper is needed to handle an issue in the InfereneClient where the model name is not + # automatically loaded when using the tabular_classification and tabular_regression methods. + # See: https://github.com/huggingface/huggingface_hub/issues/2015 + def tabular_inner(data): + assert pipeline in ["tabular_classification", "tabular_regression"] + assert client.model is not None + if pipeline == "tabular_classification": + return client.tabular_classification(data, model=client.model) + else: + return client.tabular_regression(data, model=client.model) + + return tabular_inner + + ################## # Helper function for cleaning up an Interface loaded from HF Spaces ################## -def streamline_spaces_interface(config: Dict) -> Dict: +def streamline_spaces_interface(config: dict) -> dict: """Streamlines the interface config dictionary to remove unnecessary keys.""" config["inputs"] = [ components.get_component_instance(component) diff --git a/gradio/utils.py b/gradio/utils.py index 44a6e20013eb..3418808fa0ee 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -386,24 +386,6 @@ def same_children_recursive(children1, chidren2): return True -def format_ner_list(input_string: str, ner_groups: list[dict[str, str | int]]): - if len(ner_groups) == 0: - return [(input_string, None)] - - output = [] - end = 0 - prev_end = 0 - - for group in ner_groups: - entity, start, end = group["entity_group"], group["start"], group["end"] - output.append((input_string[prev_end:start], None)) - output.append((input_string[start:end], entity)) - prev_end = end - - output.append((input_string[end:], None)) - return output - - def delete_none(_dict: dict, skip_value: bool = False) -> dict: """ Delete keys whose values are None from a dictionary diff --git a/guides/06_integrating-other-frameworks/01_using-hugging-face-integrations.md b/guides/06_integrating-other-frameworks/01_using-hugging-face-integrations.md index a2ea4bc3aace..3616b3dd8525 100644 --- a/guides/06_integrating-other-frameworks/01_using-hugging-face-integrations.md +++ b/guides/06_integrating-other-frameworks/01_using-hugging-face-integrations.md @@ -12,9 +12,9 @@ The Hugging Face Hub is a central platform that has hundreds of thousands of [mo Gradio has multiple features that make it extremely easy to leverage existing models and Spaces on the Hub. This guide walks through these features. -## Demos with the Hugging Face Inference API +## Demos with the Hugging Face Inference Endpoints -Hugging Face has a free service called the [Inference API](https://huggingface.co/inference-api), which allows you to send HTTP requests to models in the Hub. For transformers or diffusers-based models, the API can be 2 to 10 times faster than running the inference yourself. The API is free (rate limited), and you can switch to dedicated [Inference Endpoints](https://huggingface.co/pricing) when you want to use it in production. Gradio integrates directly with the Hugging Face Inference API so that you can create a demo simply by specifying a model's name (e.g. `Helsinki-NLP/opus-mt-en-es`), like this: +Hugging Face has a service called [Serverless Inference Endpoints](https://huggingface.co/docs/api-inference/index), which allows you to send HTTP requests to models on the Hub. The API includes a generous free tier, and you can switch to [dedicated Inference Endpoints](https://huggingface.co/inference-endpoints/dedicated) when you want to use it in production. Gradio integrates directly with Serverless Inference Endpoints so that you can create a demo simply by specifying a model's name (e.g. `Helsinki-NLP/opus-mt-en-es`), like this: ```python import gradio as gr @@ -24,11 +24,11 @@ demo = gr.load("Helsinki-NLP/opus-mt-en-es", src="models") demo.launch() ``` -For any Hugging Face model supported in the Inference API, Gradio automatically infers the expected input and output and make the underlying server calls, so you don't have to worry about defining the prediction function. +For any Hugging Face model supported in Inference Endpoints, Gradio automatically infers the expected input and output and make the underlying server calls, so you don't have to worry about defining the prediction function. Notice that we just put specify the model name and state that the `src` should be `models` (Hugging Face's Model Hub). There is no need to install any dependencies (except `gradio`) since you are not loading the model on your computer. -You might notice that the first inference takes about 20 seconds. This happens since the Inference API is loading the model in the server. You get some benefits afterward: +You might notice that the first inference takes a little bit longer. This happens since the Inference Endpoints is loading the model in the server. You get some benefits afterward: - The inference will be much faster. - The server caches your requests. @@ -78,7 +78,7 @@ with gr.Blocks() as demo: demo.launch() ``` -Notice that we use `gr.load()`, the same method we used to load models using the Inference API. However, here we specify that the `src` is `spaces` (Hugging Face Spaces). +Notice that we use `gr.load()`, the same method we used to load models using Inference Endpoints. However, here we specify that the `src` is `spaces` (Hugging Face Spaces). Note: loading a Space in this way may result in slight differences from the original Space. In particular, any attributes that apply to the entire Blocks, such as the theme or custom CSS/JS, will not be loaded. You can copy these properties from the Space you are loading into your own `Blocks` object. @@ -126,7 +126,7 @@ The previous code produces the following interface, which you can try right here That's it! Let's recap the various ways Gradio and Hugging Face work together: -1. You can build a demo around the Inference API without having to load the model easily using `gr.load()`. +1. You can build a demo around Inference Endpoints without having to load the model, by using `gr.load()`. 2. You host your Gradio demo on Hugging Face Spaces, either using the GUI or entirely in Python. 3. You can load demos from Hugging Face Spaces to remix and create new Gradio demos using `gr.load()`. 4. You can convert a `transformers` pipeline into a Gradio demo using `from_pipeline()`. diff --git a/guides/cn/04_integrating-other-frameworks/01_using-hugging-face-integrations.md b/guides/cn/04_integrating-other-frameworks/01_using-hugging-face-integrations.md index b928e92a3c44..d5dc2025285c 100644 --- a/guides/cn/04_integrating-other-frameworks/01_using-hugging-face-integrations.md +++ b/guides/cn/04_integrating-other-frameworks/01_using-hugging-face-integrations.md @@ -52,11 +52,11 @@ demo.launch() -## Using Hugging Face Inference API +## Using Hugging Face Inference Endpoints -Hugging Face 提供了一个名为[Inference API](https://huggingface.co/inference-api)的免费服务,允许您向 Hub 中的模型发送 HTTP 请求。对于基于 transformers 或 diffusers 的模型,API 的速度可以比自己运行推理快 2 到 10 倍。该 API 是免费的(受速率限制),您可以在想要在生产中使用时切换到专用的[推理端点](https://huggingface.co/pricing)。 +Hugging Face 提供了一个名为[Serverless Inference Endpoints](https://huggingface.co/inference-api)的免费服务,允许您向 Hub 中的模型发送 HTTP 请求。对于基于 transformers 或 diffusers 的模型,API 的速度可以比自己运行推理快 2 到 10 倍。该 API 是免费的(受速率限制),您可以在想要在生产中使用时切换到专用的[推理端点](https://huggingface.co/pricing)。 -让我们尝试使用推理 API 而不是自己加载模型的方式进行相同的演示。鉴于 Inference API 支持的 Hugging Face 模型,Gradio 可以自动推断出预期的输入和输出,并进行底层服务器调用,因此您不必担心定义预测函数。以下是代码示例! +让我们尝试使用推理 API 而不是自己加载模型的方式进行相同的演示。鉴于 Inference Endpoints 支持的 Hugging Face 模型,Gradio 可以自动推断出预期的输入和输出,并进行底层服务器调用,因此您不必担心定义预测函数。以下是代码示例! ```python import gradio as gr diff --git a/test/test_external.py b/test/test_external.py index 6965b0ff0698..9ccf172872e9 100644 --- a/test/test_external.py +++ b/test/test_external.py @@ -13,7 +13,8 @@ import gradio as gr from gradio.context import Context from gradio.exceptions import GradioVersionIncompatibleError, InvalidApiNameError -from gradio.external import TooManyRequestsError, cols_to_rows, get_tabular_examples +from gradio.external import TooManyRequestsError +from gradio.external_utils import cols_to_rows, get_tabular_examples """ WARNING: These tests have an external dependency: namely that Hugging Face's @@ -204,7 +205,7 @@ def test_sentiment_model(self): def test_image_classification_model(self): io = gr.load(name="models/google/vit-base-patch16-224") try: - assert io("gradio/test_data/lion.jpg")["label"] == "lion" + assert io("gradio/test_data/lion.jpg")["label"].startswith("lion") except TooManyRequestsError: pass @@ -291,7 +292,9 @@ def test_text_to_image_model(self): io = gr.load("models/osanseviero/BigGAN-deep-128") try: filename = io("chest") - assert filename.endswith(".jpg") or filename.endswith(".jpeg") + assert filename.lower().endswith(".jpg") or filename.lower().endswith( + ".jpeg" + ) except TooManyRequestsError: pass @@ -491,7 +494,7 @@ def test_load_blocks_with_default_values(): ) def test_can_load_tabular_model_with_different_widget_data(hypothetical_readme): with patch( - "gradio.external.get_tabular_examples", return_value=hypothetical_readme + "gradio.external_utils.get_tabular_examples", return_value=hypothetical_readme ): io = gr.load("models/scikit-learn/tabular-playground") check_dataframe(io.config) diff --git a/test/test_utils.py b/test/test_utils.py index 383a6f8e3cd1..f5a708b6b963 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -12,6 +12,7 @@ from typing_extensions import Literal from gradio import EventData, Request +from gradio.external_utils import format_ner_list from gradio.utils import ( abspath, append_unique_suffix, @@ -19,7 +20,6 @@ check_function_inputs_match, colab_check, delete_none, - format_ner_list, get_continuous_fn, get_extension_from_file_path_or_url, get_type_hints,