logger = logging.getLogger(__name__)


ENDPOINT = "https://api-inference.huggingface.co"

# Tasks the Inference API currently exposes; used to validate a user-supplied
# `task` override in `InferenceApi.__init__`.
ALL_TASKS = [
    # NLP
    "text-classification",
    "token-classification",
    "table-question-answering",
    "question-answering",
    "zero-shot-classification",
    "translation",
    "summarization",
    "conversational",
    "feature-extraction",
    "text-generation",
    "text2text-generation",
    "fill-mask",
    "sentence-similarity",
    # Audio
    "text-to-speech",
    "automatic-speech-recognition",
    "audio-to-audio",
    "audio-source-separation",
    "voice-activity-detection",
    # Computer vision
    "image-classification",
    "object-detection",
    "image-segmentation",
    # Others
    "structured-data-classification",
]


class InferenceApi:
    """Client to configure requests and make calls to the HuggingFace Inference API.

    Example:

    >>> from huggingface_hub.inference_api import InferenceApi

    >>> # Mask-fill example
    >>> api = InferenceApi("bert-base-uncased")
    >>> api(inputs="The goal of life is [MASK].")
    >>> >> [{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}]

    >>> # Question Answering example
    >>> api = InferenceApi("deepset/roberta-base-squad2")
    >>> inputs = {"question":"What's my name?", "context":"My name is Clara and I live in Berkeley."}
    >>> api(inputs)
    >>> >> {'score': 0.9326569437980652, 'start': 11, 'end': 16, 'answer': 'Clara'}

    >>> # Zero-shot example
    >>> api = InferenceApi("typeform/distilbert-base-uncased-mnli")
    >>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!"
    >>> params = {"candidate_labels":["refund", "legal", "faq"]}
    >>> api(inputs, params)
    >>> >> {'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]}

    >>> # Overriding configured task
    >>> api = InferenceApi("bert-base-uncased", task="feature-extraction")
    """

    def __init__(
        self,
        repo_id: str,
        task: Optional[str] = None,
        token: Optional[str] = None,
        gpu: bool = False,
    ):
        """Inits headers and API call information.

        Args:
            repo_id (``str``): Id of repository (e.g. `user/bert-base-uncased`).
            task (``str``, `optional`, defaults ``None``): Whether to force a task instead of using task specified in the repository.
            token (:obj:`str`, `optional`):
                The API token to use as HTTP bearer authorization. This is not the authentication token.
                You can find the token in https://huggingface.co/settings/token. Alternatively, you can
                find both your organizations and personal API tokens using `HfApi().whoami(token)`.
            gpu (``bool``, `optional`, defaults ``False``): Whether to use GPU instead of CPU for inference (requires Startup plan at least).

        Raises:
            ``ValueError``: If no ``task`` is given and the repository does not declare a
                ``pipeline_tag``, or if the forced ``task`` is not in ``ALL_TASKS``.

        .. note::
            Setting :obj:`token` is required when you want to use a private model.
        """
        # "wait_for_model" makes the API block until the model is loaded instead
        # of returning a 503 while it spins up.
        self.options = {"wait_for_model": True, "use_gpu": gpu}

        self.headers = {}
        if isinstance(token, str):
            self.headers["Authorization"] = f"Bearer {token}"

        # Configure task: a user-supplied `task` overrides the repository's
        # pipeline_tag, but must be a known task name.
        model_info = HfApi().model_info(repo_id=repo_id, token=token)
        if not model_info.pipeline_tag and not task:
            raise ValueError(
                "Task not specified in the repository. Please add it to the model card using pipeline_tag (https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined)"
            )

        if task and task != model_info.pipeline_tag:
            if task not in ALL_TASKS:
                raise ValueError(f"Invalid task {task}. Make sure it's valid.")

            logger.warning(
                "You're using a different task than the one specified in the repository. Be sure to know what you're doing :)"
            )
            self.task = task
        else:
            self.task = model_info.pipeline_tag

        self.api_url = f"{ENDPOINT}/pipeline/{self.task}/{repo_id}"

    def __repr__(self):
        # Render every instance attribute, e.g. InferenceApi(task='fill-mask', ...).
        items = (f"{k}='{v}'" for k, v in self.__dict__.items())
        return f"{self.__class__.__name__}({', '.join(items)})"

    def __call__(
        self,
        inputs: Union[str, Dict, List[str], List[List[str]]],
        params: Optional[Dict] = None,
    ):
        """POST `inputs` (and optional `params`) to the configured pipeline endpoint.

        Args:
            inputs: Payload for the task — a string, dict, or (nested) list of strings.
            params (``dict``, `optional`): Extra task parameters, sent as "parameters".

        Returns:
            The decoded JSON response. NOTE(review): API-side errors/warnings are
            returned as JSON rather than raised — see the TODO below.
        """
        payload = {
            "inputs": inputs,
            "options": self.options,
        }

        if params:
            payload["parameters"] = params

        # TODO: Decide if we should raise an error instead of
        # returning the json.
        response = requests.post(
            self.api_url, headers=self.headers, json=payload
        ).json()
        return response
class InferenceApiTest(unittest.TestCase):
    """Integration tests for `InferenceApi` against the production Inference API."""

    @with_production_testing
    def test_simple_inference(self):
        """A mask-fill call returns a list of candidate dicts with sequence/score."""
        api = InferenceApi("bert-base-uncased")
        outputs = api("Hi, I think [MASK] is cool")
        self.assertIsInstance(outputs, list)

        first = outputs[0]
        self.assertIsInstance(first, dict)
        self.assertIn("sequence", first)
        self.assertIn("score", first)

    @with_production_testing
    def test_inference_with_params(self):
        """Zero-shot classification accepts candidate labels via `params`."""
        api = InferenceApi("typeform/distilbert-base-uncased-mnli")
        text = "I bought a device but it is not working and I would like to get reimbursed!"
        labels = {"candidate_labels": ["refund", "legal", "faq"]}
        output = api(text, labels)
        self.assertIsInstance(output, dict)
        self.assertIn("sequence", output)
        self.assertIn("scores", output)

    @with_production_testing
    def test_inference_with_dict_inputs(self):
        """Question answering takes a question/context dict as inputs."""
        api = InferenceApi("deepset/roberta-base-squad2")
        output = api(
            {
                "question": "What's my name?",
                "context": "My name is Clara and I live in Berkeley.",
            }
        )
        self.assertIsInstance(output, dict)
        self.assertIn("score", output)
        self.assertIn("answer", output)

    @with_production_testing
    def test_inference_overriding_task(self):
        """A forced `task` overrides the repository's configured pipeline tag."""
        api = InferenceApi(
            "sentence-transformers/paraphrase-albert-small-v2",
            task="feature-extraction",
        )
        output = api("This is an example again")
        self.assertIsInstance(output, list)

    @with_production_testing
    def test_inference_overriding_invalid_task(self):
        """An unknown forced task raises ValueError at construction time."""
        with self.assertRaises(
            ValueError, msg="Invalid task invalid-task. Make sure it's valid."
        ):
            InferenceApi("bert-base-uncased", task="invalid-task")

    @with_production_testing
    def test_inference_missing_input(self):
        """A QA call missing `context` yields an error payload with warnings."""
        api = InferenceApi("deepset/roberta-base-squad2")
        output = api({"question": "What's my name?"})
        self.assertIsInstance(output, dict)
        self.assertIn("error", output)
        self.assertIn("warnings", output)
        self.assertGreater(len(output["warnings"]), 0)