Skip to content

Commit

Permalink
Inference API wrapper client (#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
osanseviero committed Jul 16, 2021
1 parent a88c772 commit 409f819
Show file tree
Hide file tree
Showing 3 changed files with 224 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,6 @@
from .file_download import cached_download, hf_hub_download, hf_hub_url
from .hf_api import HfApi, HfFolder, repo_type_and_id_from_hf_id
from .hub_mixin import ModelHubMixin
from .inference_api import InferenceApi
from .repository import Repository
from .snapshot_download import snapshot_download
141 changes: 141 additions & 0 deletions src/huggingface_hub/inference_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import logging
from typing import Dict, List, Optional, Union

import requests

from .hf_api import HfApi


# Module-level logger; configured by the consuming application, not here.
logger = logging.getLogger(__name__)


# Base URL of the hosted Inference API; "/pipeline/{task}/{repo_id}" is
# appended when building a model-specific endpoint.
ENDPOINT = "https://api-inference.huggingface.co"

# Tasks currently accepted by the Inference API. Used to validate a `task`
# argument that overrides the task configured in the model repository.
ALL_TASKS = [
    # NLP
    "text-classification",
    "token-classification",
    "table-question-answering",
    "question-answering",
    "zero-shot-classification",
    "translation",
    "summarization",
    "conversational",
    "feature-extraction",
    "text-generation",
    "text2text-generation",
    "fill-mask",
    "sentence-similarity",
    # Audio
    "text-to-speech",
    "automatic-speech-recognition",
    "audio-to-audio",
    "audio-source-separation",
    "voice-activity-detection",
    # Computer vision
    "image-classification",
    "object-detection",
    "image-segmentation",
    # Others
    "structured-data-classification",
]


class InferenceApi:
    """Client to configure requests and make calls to the HuggingFace Inference API.

    Example:

        >>> from huggingface_hub.inference_api import InferenceApi

        >>> # Mask-fill example
        >>> api = InferenceApi("bert-base-uncased")
        >>> api(inputs="The goal of life is [MASK].")
        [{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}]

        >>> # Question Answering example
        >>> api = InferenceApi("deepset/roberta-base-squad2")
        >>> inputs = {"question": "What's my name?", "context": "My name is Clara and I live in Berkeley."}
        >>> api(inputs)
        {'score': 0.9326569437980652, 'start': 11, 'end': 16, 'answer': 'Clara'}

        >>> # Zero-shot example
        >>> api = InferenceApi("typeform/distilbert-base-uncased-mnli")
        >>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!"
        >>> params = {"candidate_labels": ["refund", "legal", "faq"]}
        >>> api(inputs, params)
        {'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]}

        >>> # Overriding configured task
        >>> api = InferenceApi("bert-base-uncased", task="feature-extraction")
    """

    def __init__(
        self,
        repo_id: str,
        task: Optional[str] = None,
        token: Optional[str] = None,
        gpu: bool = False,
    ):
        """Inits headers and API call information.

        Args:
            repo_id (``str``): Id of repository (e.g. `user/bert-base-uncased`).
            task (``str``, `optional`, defaults ``None``): Whether to force a task instead of using task specified in the repository.
            token (:obj:`str`, `optional`):
                The API token to use as HTTP bearer authorization. This is not the authentication token.
                You can find the token in https://huggingface.co/settings/token. Alternatively, you can
                find both your organizations and personal API tokens using `HfApi().whoami(token)`.
            gpu (``bool``, `optional`, defaults ``False``): Whether to use GPU instead of CPU for inference (requires Startup plan at least).

        Raises:
            ``ValueError``: If no ``task`` is given and the repository declares no
                ``pipeline_tag``, or if an explicit ``task`` is not in ``ALL_TASKS``.

        .. note::
            Setting :obj:`token` is required when you want to use a private model.
        """
        self.options = {"wait_for_model": True, "use_gpu": gpu}

        self.headers = {}
        if isinstance(token, str):
            self.headers["Authorization"] = "Bearer {}".format(token)

        # Configure task: the repository's pipeline_tag is the default; an
        # explicit `task` argument (validated against ALL_TASKS) overrides it.
        model_info = HfApi().model_info(repo_id=repo_id, token=token)
        if not model_info.pipeline_tag and not task:
            raise ValueError(
                "Task not specified in the repository. Please add it to the model card using pipeline_tag (https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined)"
            )

        if task and task != model_info.pipeline_tag:
            if task not in ALL_TASKS:
                raise ValueError(f"Invalid task {task}. Make sure it's valid.")

            logger.warning(
                "You're using a different task than the one specified in the repository. Be sure to know what you're doing :)"
            )
            self.task = task
        else:
            self.task = model_info.pipeline_tag

        # Task is baked into the URL; the API routes by pipeline + repo id.
        self.api_url = f"{ENDPOINT}/pipeline/{self.task}/{repo_id}"

    def __repr__(self):
        # Skip `headers` so the bearer token is never leaked into logs or
        # tracebacks through repr().
        items = (
            f"{k}='{v}'" for k, v in self.__dict__.items() if k != "headers"
        )
        return f"{self.__class__.__name__}({', '.join(items)})"

    def __call__(
        self,
        inputs: Union[str, Dict, List[str], List[List[str]]],
        params: Optional[Dict] = None,
    ):
        """Call the configured pipeline with ``inputs`` and optional ``params``.

        Returns the decoded JSON response. API-side errors are returned as-is
        (typically a dict with an ``"error"`` key) rather than raised.
        """
        payload: Dict = {
            "inputs": inputs,
            "options": self.options,
        }

        if params:
            payload["parameters"] = params

        # TODO: Decide if we should raise an error instead of
        # returning the json.
        response = requests.post(
            self.api_url, headers=self.headers, json=payload
        ).json()
        return response
82 changes: 82 additions & 0 deletions tests/test_inference_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

from huggingface_hub.inference_api import InferenceApi

from .testing_utils import with_production_testing


class InferenceApiTest(unittest.TestCase):
    """Integration tests hitting the production Inference API."""

    @with_production_testing
    def test_simple_inference(self):
        # Mask-fill on a plain string yields a list of candidate dicts.
        api = InferenceApi("bert-base-uncased")
        results = api("Hi, I think [MASK] is cool")
        self.assertIsInstance(results, list)

        first = results[0]
        self.assertIsInstance(first, dict)
        self.assertIn("sequence", first)
        self.assertIn("score", first)

    @with_production_testing
    def test_inference_with_params(self):
        # Zero-shot classification takes candidate labels via `params`.
        api = InferenceApi("typeform/distilbert-base-uncased-mnli")
        labels = {"candidate_labels": ["refund", "legal", "faq"]}
        output = api(
            "I bought a device but it is not working and I would like to get reimbursed!",
            labels,
        )
        self.assertIsInstance(output, dict)
        self.assertIn("sequence", output)
        self.assertIn("scores", output)

    @with_production_testing
    def test_inference_with_dict_inputs(self):
        # Question answering takes a dict of question + context.
        api = InferenceApi("deepset/roberta-base-squad2")
        output = api(
            {
                "question": "What's my name?",
                "context": "My name is Clara and I live in Berkeley.",
            }
        )
        self.assertIsInstance(output, dict)
        self.assertIn("score", output)
        self.assertIn("answer", output)

    @with_production_testing
    def test_inference_overriding_task(self):
        # An explicit task overrides the one configured in the repository.
        api = InferenceApi(
            "sentence-transformers/paraphrase-albert-small-v2",
            task="feature-extraction",
        )
        output = api("This is an example again")
        self.assertIsInstance(output, list)

    @with_production_testing
    def test_inference_overriding_invalid_task(self):
        # Unknown task names are rejected at construction time.
        with self.assertRaises(
            ValueError, msg="Invalid task invalid-task. Make sure it's valid."
        ):
            InferenceApi("bert-base-uncased", task="invalid-task")

    @with_production_testing
    def test_inference_missing_input(self):
        # A QA call without `context` comes back as an error payload, not a raise.
        api = InferenceApi("deepset/roberta-base-squad2")
        output = api({"question": "What's my name?"})
        self.assertIsInstance(output, dict)
        self.assertIn("error", output)
        self.assertIn("warnings", output)
        self.assertTrue(len(output["warnings"]) > 0)

0 comments on commit 409f819

Please sign in to comment.