Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow passing API_BASE as optional parameter for openai provider #6820

Merged
merged 11 commits into from
Jan 30, 2024
17 changes: 13 additions & 4 deletions examples/getting_started.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,43 @@
# Disable pylint errors for logging basicConfig
# pylint: disable=no-logging-basicconfig
augchan42 marked this conversation as resolved.
Show resolved Hide resolved
import logging

from typing import Optional

from haystack.document_stores import InMemoryDocumentStore
from haystack.utils import build_pipeline, add_example_data, print_answers

logging.basicConfig(level=logging.DEBUG)
augchan42 marked this conversation as resolved.
Show resolved Hide resolved

def getting_started(provider, API_KEY):

def getting_started(provider, API_KEY, API_BASE: Optional[str] = None):
"""
This getting_started example shows you how to use LLMs with your data with a technique called Retrieval Augmented Generation - RAG.

:param provider: We are model agnostic :) Here, you can choose from: "anthropic", "cohere", "huggingface", and "openai".
:param API_KEY: The API key matching the provider.
:param API_BASE: The URL to use for a custom endpoint, e.g., if using LM Studio. Only openai provider supported. /v1 at the end is needed (e.g., http://localhost:1234/v1)

"""

# We support many different databases. Here we load a simple and lightweight in-memory database.
document_store = InMemoryDocumentStore(use_bm25=True)

# Pipelines are the main abstraction in Haystack, they connect components like LLMs and databases.
pipeline = build_pipeline(provider, API_KEY, document_store)
pipeline = build_pipeline(provider, API_KEY, document_store, API_BASE)

# Download and add Game of Thrones TXT articles to Haystack's database.
# You can also provide a folder with your local documents.
# You might need to install additional dependencies - look inside the function for more information.
add_example_data(document_store, "data/GoT_getting_started")

# Ask a question on the data you just added.
result = pipeline.run(query="Who is the father of Arya Stark?")
result = pipeline.run(query="Who is the father of Arya Stark?", debug=True)

# For details such as which documents were used to generate the answer, look into the <result> object.
print_answers(result, details="medium")
return result


if __name__ == "__main__":
# getting_started(provider="openai", API_KEY="NOT NEEDED", API_BASE="http://192.168.1.100:1234/v1")
getting_started(provider="openai", API_KEY="ADD KEY HERE")
4 changes: 2 additions & 2 deletions haystack/nodes/prompt/invocation_layer/open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
api_key: str,
model_name_or_path: str = "gpt-3.5-turbo-instruct",
max_length: Optional[int] = 100,
api_base: str = "https://api.openai.com/v1",
api_base: Optional[str] = "https://api.openai.com/v1",
augchan42 marked this conversation as resolved.
Show resolved Hide resolved
openai_organization: Optional[str] = None,
timeout: Optional[float] = None,
**kwargs,
Expand Down Expand Up @@ -65,7 +65,7 @@ def __init__(
f"api_key {api_key} must be a valid OpenAI key. Visit https://openai.com/api/ to get one."
)
self.api_key = api_key
self.api_base = api_base
self.api_base = "https://api.openai.com/v1" if api_base is None else api_base
augchan42 marked this conversation as resolved.
Show resolved Hide resolved
self.openai_organization = openai_organization
self.timeout = timeout

Expand Down
3 changes: 3 additions & 0 deletions haystack/nodes/prompt/prompt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
model_name_or_path: str = "google/flan-t5-base",
max_length: Optional[int] = 100,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
timeout: Optional[float] = None,
use_auth_token: Optional[Union[str, bool]] = None,
use_gpu: Optional[bool] = None,
Expand Down Expand Up @@ -65,6 +66,7 @@ def __init__(
self.model_name_or_path = model_name_or_path
self.max_length = max_length
self.api_key = api_key
self.api_base = api_base
self.timeout = timeout
self.use_auth_token = use_auth_token
self.use_gpu = use_gpu
Expand All @@ -78,6 +80,7 @@ def create_invocation_layer(
) -> PromptModelInvocationLayer:
kwargs = {
"api_key": self.api_key,
"api_base": self.api_base,
augchan42 marked this conversation as resolved.
Show resolved Hide resolved
"timeout": self.timeout,
"use_auth_token": self.use_auth_token,
"use_gpu": self.use_gpu,
Expand Down
2 changes: 2 additions & 0 deletions haystack/nodes/prompt/prompt_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
output_variable: Optional[str] = None,
max_length: Optional[int] = 100,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
timeout: Optional[float] = None,
use_auth_token: Optional[Union[str, bool]] = None,
use_gpu: Optional[bool] = None,
Expand Down Expand Up @@ -114,6 +115,7 @@ def __init__(
model_name_or_path=model_name_or_path,
max_length=max_length,
api_key=api_key,
api_base=api_base,
timeout=timeout,
use_auth_token=use_auth_token,
use_gpu=use_gpu,
Expand Down
3 changes: 2 additions & 1 deletion haystack/utils/getting_started.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
logger = logging.getLogger(__name__)


def build_pipeline(provider, API_KEY, document_store):
def build_pipeline(provider, API_KEY, document_store, API_BASE):
augchan42 marked this conversation as resolved.
Show resolved Hide resolved
# Importing top-level causes a circular import
from haystack.nodes import AnswerParser, PromptNode, PromptTemplate, BM25Retriever
from haystack.pipelines import Pipeline
Expand Down Expand Up @@ -42,6 +42,7 @@ def build_pipeline(provider, API_KEY, document_store):
prompt_node = PromptNode(
model_name_or_path="gpt-3.5-turbo-0301",
api_key=API_KEY,
api_base=API_BASE,
default_prompt_template=question_answering_with_references,
)
else:
Expand Down
8 changes: 8 additions & 0 deletions releasenotes/notes/override-api-base-67bc046a5cc5f46d.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
enhancements:
- |
API_BASE can now be passed as an optional parameter in the getting_started sample. Only openai provider is supported in this set of changes.
PromptNode and PromptModel were enhanced to allow passing of this parameter.
This allows RAG against a local endpoint (e.g, http://localhost:1234/v1), so long as it is OpenAI compatible (such as LM Studio)
Logging in the getting started sample was made more verbose, to make it easier for people to see what was happening under the covers.