True OpenAI drop-in replacement by InferenceClient #2384

Merged 7 commits on Jul 11, 2024

Changes from 5 commits
61 changes: 60 additions & 1 deletion docs/source/en/guides/inference.md
@@ -43,7 +43,7 @@ Let's get started with a text-to-image task:

In the example above, we initialized an [`InferenceClient`] with the default parameters. The only thing you need to know is the [task](#supported-tasks) you want to perform. By default, the client will connect to the Inference API and select a model to complete the task. In our example, we generated an image from a text prompt. The returned value is a `PIL.Image` object that can be saved to a file. For more details, check out the [`~InferenceClient.text_to_image`] documentation.
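
For reference, here is a minimal sketch of such a call (the prompt and output file name are only examples; the model is selected automatically):

```python
>>> from huggingface_hub import InferenceClient

>>> client = InferenceClient()
>>> image = client.text_to_image("An astronaut riding a horse on the moon.")  # returns a `PIL.Image`
>>> image.save("astronaut.png")
```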

Let's now see an example using the `chat_completion` API. This task uses an LLM to generate a response from a list of messages:
Let's now see an example using the [`~InferenceClient.chat_completion`] API. This task uses an LLM to generate a response from a list of messages:

```python
>>> from huggingface_hub import InferenceClient
@@ -147,6 +147,65 @@ endpoints.

</Tip>

## OpenAI compatibility

The `chat_completion` task follows [OpenAI's Python client](https://github.com/openai/openai-python) syntax. What does this mean for you? It means that if you are used to working with `OpenAI`'s APIs, you can switch to `huggingface_hub.InferenceClient` to work with open-source models by updating just two lines of code!

```py
# instead of `from openai import OpenAI`
from huggingface_hub import InferenceClient

# instead of `client = OpenAI(...)`
client = InferenceClient(
    base_url=...,
    api_key=...,
)

output = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Count to 10"},
    ],
    stream=True,
    max_tokens=1024,
)

for chunk in output:
    print(chunk.choices[0].delta.content)
```

And that's it! The only required changes are to replace `from openai import OpenAI` with `from huggingface_hub import InferenceClient` and `client = OpenAI(...)` with `client = InferenceClient(...)`. You can choose any LLM model from the Hugging Face Hub by passing its model id as the `model` parameter. [Here is a list](https://huggingface.co/models?pipeline_tag=text-generation&other=conversational,text-generation-inference&sort=trending) of supported models. For authentication, you should pass a valid [User Access Token](https://huggingface.co/settings/tokens) as `api_key` or authenticate using `huggingface_hub` (see the [authentication guide](https://huggingface.co/docs/huggingface_hub/quick-start#authentication)).
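
For example, here is a minimal sketch of the second option, authenticating once with `huggingface_hub` and letting the client reuse the locally saved token (the model id below is only an example):

```py
from huggingface_hub import login, InferenceClient

# log in once (or set the HF_TOKEN environment variable); the saved token is then picked up automatically
login()

client = InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct")
response = client.chat_completion(
    messages=[{"role": "user", "content": "Say hello!"}],
    max_tokens=32,
)
print(response.choices[0].message.content)
```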

All input parameters and the output format are strictly the same. In particular, you can pass `stream=True` to receive tokens as they are generated. You can also use the [`AsyncInferenceClient`] to run inference using `asyncio`:

```py
import asyncio
# instead of `from openai import AsyncOpenAI`
from huggingface_hub import AsyncInferenceClient

# instead of `client = AsyncOpenAI()`
client = AsyncInferenceClient()

async def main():
    stream = await client.chat.completions.create(
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        messages=[{"role": "user", "content": "Say this is a test"}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(main())
```

<Tip>

`InferenceClient.chat.completions.create` is simply an alias for `InferenceClient.chat_completion`. Check out the package reference of [`~InferenceClient.chat_completion`] for more details. The `base_url` and `api_key` parameters used when instantiating the client are likewise aliases for `model` and `token`. These aliases have been defined to reduce friction when switching from `OpenAI` to `InferenceClient`.

</Tip>
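
As a quick illustration of this equivalence, the two calls below are interchangeable (minimal sketch; the model id is only an example):

```py
from huggingface_hub import InferenceClient

client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
messages = [{"role": "user", "content": "What is deep learning?"}]

# `chat.completions.create` is an alias for `chat_completion`: both calls return the same result type
output_a = client.chat_completion(messages=messages, max_tokens=64)
output_b = client.chat.completions.create(messages=messages, max_tokens=64)
```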

## Supported tasks

[`InferenceClient`]'s goal is to provide the easiest interface to run inference on Hugging Face models. It
140 changes: 116 additions & 24 deletions src/huggingface_hub/inference/_client.py
@@ -113,6 +113,7 @@
get_session,
hf_raise_for_status,
)
from huggingface_hub.utils._deprecation import _deprecate_positional_args


if TYPE_CHECKING:
@@ -134,12 +135,16 @@ class InferenceClient:

Args:
model (`str`, `optional`):
The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `bigcode/starcoder`
The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `meta-llama/Meta-Llama-3-8B-Instruct`
or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
automatically selected for the task.
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. These two
arguments are mutually exclusive and have the exact same behavior.
token (`str` or `bool`, *optional*):
Hugging Face token. Will default to the locally saved token if not provided.
Pass `token=False` if you don't want to send your token to the server.
Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. These two
arguments are mutually exclusive and have the exact same behavior.
timeout (`float`, `optional`):
The maximum number of seconds to wait for a response from the server. Loading a new model in Inference
API can take up to several minutes. Defaults to None, meaning it will loop until the server is available.
Expand All @@ -148,26 +153,53 @@ class InferenceClient:
Values in this dictionary will override the default values.
cookies (`Dict[str, str]`, `optional`):
Additional cookies to send to the server.
base_url (`str`, `optional`):
Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`]
follow the same pattern as the `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None.
api_key (`str`, `optional`):
Token to use for authentication. This is a duplicated argument from `token` to make [`InferenceClient`]
follow the same pattern as the `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None.
"""

@_deprecate_positional_args(version="0.26")
def __init__(
self,
model: Optional[str] = None,
*,
token: Union[str, bool, None] = None,
timeout: Optional[float] = None,
headers: Optional[Dict[str, str]] = None,
cookies: Optional[Dict[str, str]] = None,
proxies: Optional[Any] = None,
# OpenAI compatibility
base_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> None:
if model is not None and base_url is not None:
raise ValueError(
"Received both `model` and `base_url` arguments. Please provide only one of them."
" `base_url` is an alias for `model` to make the API compatible with OpenAI's client."
" It has the exact same behavior as `model`."
)
if token is not None and api_key is not None:
raise ValueError(
"Received both `token` and `api_key` arguments. Please provide only one of them."
" `api_key` is an alias for `token` to make the API compatible with OpenAI's client."
" It has the exact same behavior as `token`."
)

self.model: Optional[str] = model
self.token: Union[str, bool, None] = token
self.headers = CaseInsensitiveDict(build_hf_headers(token=token)) # contains 'authorization' + 'user-agent'
self.token: Union[str, bool, None] = token or api_key
self.headers = CaseInsensitiveDict(build_hf_headers(token=self.token)) # 'authorization' + 'user-agent'
if headers is not None:
self.headers.update(headers)
self.cookies = cookies
self.timeout = timeout
self.proxies = proxies

# OpenAI compatibility
self.base_url = base_url

def __repr__(self):
return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>"

@@ -441,7 +473,6 @@ def chat_completion( # type: ignore
tools: Optional[List[ChatCompletionInputTool]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
model_id: Optional[str] = None,
) -> ChatCompletionOutput: ...

@overload
@@ -465,7 +496,6 @@ def chat_completion( # type: ignore
tools: Optional[List[ChatCompletionInputTool]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
model_id: Optional[str] = None,
) -> Iterable[ChatCompletionStreamOutput]: ...

@overload
@@ -489,7 +519,6 @@ def chat_completion(
tools: Optional[List[ChatCompletionInputTool]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
model_id: Optional[str] = None,
) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]: ...

def chat_completion(
@@ -513,18 +542,29 @@ def chat_completion(
tools: Optional[List[ChatCompletionInputTool]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
model_id: Optional[str] = None,
) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]:
"""
A method for completing conversations using a specified language model.

<Tip>

The `client.chat_completion` method is aliased as `client.chat.completions.create` for compatibility with OpenAI's client.
Inputs and outputs are strictly the same and using either syntax will yield the same results.
Check out the [Inference guide](https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility)
for more details about OpenAI's compatibility.

</Tip>

Args:
messages (List[Union[`SystemMessage`, `UserMessage`, `AssistantMessage`]]):
Conversation history consisting of roles and content pairs.
model (`str`, *optional*):
The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used.
See https://huggingface.co/tasks/text-generation for more details.

If `model` is a model ID, it is passed to the server as the `model` parameter. If you want to define a
custom URL while setting `model` in the request payload, you must set `base_url` when initializing [`InferenceClient`].
frequency_penalty (`float`, *optional*):
Penalizes new tokens based on their existing frequency
in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
@@ -568,10 +608,6 @@ def chat_completion(
tools (List of [`ChatCompletionInputTool`], *optional*):
A list of tools the model may call. Currently, only functions are supported as a tool. Use this to
provide a list of functions the model may generate JSON inputs for.
model_id (`str`, *optional*):
The model ID to use for chat-completion. Only used when `model` is a URL to a deployed Text Generation Inference server.
It is passed to the server as the `model` parameter. This parameter has no impact on the URL that will be used to
send the request.

Returns:
[`ChatCompletionOutput`] or Iterable of [`ChatCompletionStreamOutput`]:
@@ -625,8 +661,35 @@ def chat_completion(
ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' capital', role='assistant'), index=0, finish_reason=None)], created=1710498504)
(...)
ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' may', role='assistant'), index=0, finish_reason=None)], created=1710498504)
```

Example using OpenAI's syntax:
```py
# instead of `from openai import OpenAI`
from huggingface_hub import InferenceClient

# instead of `client = OpenAI(...)`
client = InferenceClient(
    base_url=...,
    api_key=...,
)

# Streaming chat completion example
output = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Count to 10"},
    ],
    stream=True,
    max_tokens=1024,
)

for chunk in output:
    print(chunk.choices[0].delta.content)
```

Example using tools:
```py
>>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
>>> messages = [
... {
@@ -708,8 +771,11 @@ def chat_completion(
)
```
"""
# determine model
model = model or self.model or self.get_recommended_model("text-generation")
# Determine model
# `self.xxx` takes precedence over the method argument only in `chat_completion`:
# since `chat_completion(..., model=xxx)` is also a payload parameter for the
# server, it needs to be handled differently.
model = self.base_url or self.model or model or self.get_recommended_model("text-generation")

if _is_chat_completion_server(model):
# First, let's consider the server has a `/v1/chat/completions` endpoint.
@@ -718,14 +784,13 @@
if not model_url.endswith("/chat/completions"):
model_url += "/v1/chat/completions"

# `model_id` sent in the payload. Not used by the server but can be useful for debugging/routing.
if model_id is None:
if not model.startswith("http") and model.count("/") == 1:
# If it's a ID on the Hub => use it
model_id = model
else:
# Otherwise, we use a random string
model_id = "tgi"
# `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
if not model.startswith("http") and model.count("/") == 1:
# If it's an ID on the Hub => use it
model_id = model
else:
# Otherwise, fall back to a placeholder value
model_id = "tgi"

try:
data = self.post(
@@ -2562,7 +2627,7 @@ def zero_shot_image_classification(
return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
model = model or self.model
model = model or self.model or self.base_url

# If model is already a URL, ignore `task` and return directly
if model is not None and (model.startswith("http://") or model.startswith("https://")):
@@ -2730,7 +2795,7 @@ def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
```py
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient()
>>> client.get_model_status("bigcode/starcoder")
>>> client.get_model_status("meta-llama/Meta-Llama-3-8B-Instruct")
ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
```
"""
@@ -2754,3 +2819,30 @@ def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
compute_type=response_data["compute_type"],
framework=response_data["framework"],
)

@property
def chat(self) -> "ProxyClientChat":
return ProxyClientChat(self)


class _ProxyClient:
"""Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""

def __init__(self, client: InferenceClient):
self._client = client


class ProxyClientChat(_ProxyClient):
"""Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""

@property
def completions(self) -> "ProxyClientChatCompletions":
return ProxyClientChatCompletions(self._client)


class ProxyClientChatCompletions(_ProxyClient):
"""Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client."""

@property
def create(self):
return self._client.chat_completion