Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/nvidia.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ jobs:
if: matrix.python-version == '3.9' && runner.os == 'Linux'
run: hatch run lint:all

- name: Run tests
run: hatch run cov-retry

- name: Generate docs
if: matrix.python-version == '3.9' && runner.os == 'Linux'
run: hatch run docs

- name: Run tests
run: hatch run cov-retry

- name: Run unit tests with lowest direct dependencies
run: |
hatch run uv pip compile pyproject.toml --resolution lowest-direct --output-file requirements_lowest_direct.txt
Expand Down
5 changes: 4 additions & 1 deletion integrations/nvidia/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["haystack-ai", "requests>=2.25.0", "tqdm>=4.21.0"]
dependencies = ["haystack-ai>=2.13.0", "requests>=2.25.0", "tqdm>=4.21.0"]

[project.urls]
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/nvidia#readme"
Expand All @@ -46,6 +46,8 @@ installer = "uv"
dependencies = [
"coverage[toml]>=6.5",
"pytest",
"pytest-asyncio",
"pytz",
"pytest-rerunfailures",
"haystack-pydoc-tools",
"requests_mock",
Expand Down Expand Up @@ -160,6 +162,7 @@ module = [
"pytest.*",
"numpy.*",
"requests_mock.*",
"openai.*",
"pydantic.*",
]
ignore_missing_imports = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from .chat.chat_generator import NvidiaChatGenerator
from .generator import NvidiaGenerator

__all__ = ["NvidiaGenerator"]
__all__ = ["NvidiaChatGenerator", "NvidiaGenerator"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Any, Dict, List, Optional, Union

from haystack import component, default_to_dict, logging
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import StreamingCallbackT
from haystack.tools import Tool, Toolset, serialize_tools_or_toolset
from haystack.utils import serialize_callable
from haystack.utils.auth import Secret

from haystack_integrations.utils.nvidia import DEFAULT_API_URL

logger = logging.getLogger(__name__)


@component
class NvidiaChatGenerator(OpenAIChatGenerator):
    """
    Enables text generation using NVIDIA generative models.
    For supported models, see [NVIDIA Docs](https://build.nvidia.com/models).

    Users can pass any text generation parameters valid for the NVIDIA Chat Completion API
    directly to this component via the `generation_kwargs` parameter in `__init__` or the `generation_kwargs`
    parameter in `run` method.

    This component uses the ChatMessage format for structuring both input and output,
    ensuring coherent and contextually relevant responses in chat-based text generation scenarios.
    Details on the ChatMessage format can be found in the
    [Haystack docs](https://docs.haystack.deepset.ai/docs/data-classes#chatmessage)

    For more details on the parameters supported by the NVIDIA API, refer to the
    [NVIDIA Docs](https://build.nvidia.com/models).

    Usage example:
    ```python
    from haystack_integrations.components.generators.nvidia import NvidiaChatGenerator
    from haystack.dataclasses import ChatMessage

    messages = [ChatMessage.from_user("What's Natural Language Processing?")]

    client = NvidiaChatGenerator()
    response = client.run(messages)
    print(response)
    ```
    """

    def __init__(
        self,
        *,
        api_key: Secret = Secret.from_env_var("NVIDIA_API_KEY"),
        model: str = "meta/llama-3.1-8b-instruct",
        streaming_callback: Optional[StreamingCallbackT] = None,
        api_base_url: Optional[str] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
        tools: Optional[Union[List[Tool], Toolset]] = None,
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
        http_client_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Creates an instance of NvidiaChatGenerator.

        :param api_key:
            The NVIDIA API key.
        :param model:
            The name of the NVIDIA chat completion model to use.
        :param streaming_callback:
            A callback function that is called when a new token is received from the stream.
            The callback function accepts StreamingChunk as an argument.
        :param api_base_url:
            The NVIDIA API Base url. If not set, falls back to the `NVIDIA_API_URL`
            environment variable, and finally to the library's default NVIDIA API URL.
        :param generation_kwargs:
            Other parameters to use for the model. These parameters are all sent directly to
            the NVIDIA API endpoint. See [NVIDIA API docs](https://docs.nvcf.nvidia.com/ai/generative-models/)
            for more details.
            Some of the supported parameters:
            - `max_tokens`: The maximum number of tokens the output text can have.
            - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
                Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
                considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
                comprising the top 10% probability mass are considered.
            - `stream`: Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent
                events as they become available, with the stream terminated by a data: [DONE] message.
        :param tools:
            A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
            list of `Tool` objects or a `Toolset` instance.
        :param timeout:
            The timeout for the NVIDIA API call.
        :param max_retries:
            Maximum number of retries to contact NVIDIA after an internal error.
            If not set, it defaults to either the `NVIDIA_MAX_RETRIES` environment variable, or 5.
        :param http_client_kwargs:
            A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
            For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
        """
        # Resolve the base URL at call time rather than in the signature default:
        # a default of `os.getenv(...)` would be evaluated once at import, silently
        # ignoring later changes to NVIDIA_API_URL. This also guarantees the parent
        # class never receives `None` (which would make it target OpenAI's endpoint).
        if api_base_url is None:
            api_base_url = os.getenv("NVIDIA_API_URL", DEFAULT_API_URL)

        # NOTE(review): the explicit two-argument super() form is deliberately kept
        # (noqa: UP008 was in the original) — presumably to pin the MRO target for
        # subclasses; confirm before simplifying to zero-argument super().
        super(NvidiaChatGenerator, self).__init__(  # noqa: UP008
            api_key=api_key,
            model=model,
            streaming_callback=streaming_callback,
            api_base_url=api_base_url,
            generation_kwargs=generation_kwargs,
            tools=tools,
            timeout=timeout,
            max_retries=max_retries,
            http_client_kwargs=http_client_kwargs,
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            The serialized component as a dictionary. The API key is serialized via
            `Secret.to_dict()` (env-var reference, not the secret value), the streaming
            callback as its fully qualified import path, and tools via
            `serialize_tools_or_toolset`.
        """
        # Callbacks are not JSON-serializable; store the import path (or None).
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None

        return default_to_dict(
            self,
            model=self.model,
            streaming_callback=callback_name,
            api_base_url=self.api_base_url,
            generation_kwargs=self.generation_kwargs,
            api_key=self.api_key.to_dict(),
            tools=serialize_tools_or_toolset(self.tools),
            timeout=self.timeout,
            max_retries=self.max_retries,
            http_client_kwargs=self.http_client_kwargs,
        )
Loading
Loading