feat: allow passing API_BASE as optional parameter for openai provider (#6820)

* allow passing API_BASE as optional parameter for openai provider

* Add release note

* fix linter whitespace issue

* Update examples/getting_started.py

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

* Update examples/getting_started.py

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

* revert optional api_base based on shadeMe comment

* Update haystack/utils/getting_started.py

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

* Recommendations from shadeMe comments

* fix param ordering to build_pipeline

* Update haystack/nodes/prompt/invocation_layer/open_ai.py

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

---------

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
augchan42 and shadeMe committed Jan 30, 2024
1 parent 71879b1 commit a0a54dc
Showing 5 changed files with 28 additions and 5 deletions.
15 changes: 11 additions & 4 deletions examples/getting_started.py
@@ -1,34 +1,41 @@
 import logging

+from typing import Optional
+
 from haystack.document_stores import InMemoryDocumentStore
 from haystack.utils import build_pipeline, add_example_data, print_answers

 logger = logging.getLogger(__name__)

-def getting_started(provider, API_KEY):
+
+def getting_started(provider, API_KEY, API_BASE: Optional[str] = None):
     """
     This getting_started example shows you how to use LLMs with your data with a technique called Retrieval Augmented Generation - RAG.

     :param provider: We are model agnostic :) Here, you can choose from: "anthropic", "cohere", "huggingface", and "openai".
     :param API_KEY: The API key matching the provider.
+    :param API_BASE: The URL of a custom endpoint, e.g., when using LM Studio. Only the openai provider is supported, and the URL must end with /v1 (e.g., http://localhost:1234/v1).
     """

     # We support many different databases. Here we load a simple and lightweight in-memory database.
     document_store = InMemoryDocumentStore(use_bm25=True)

     # Pipelines are the main abstraction in Haystack, they connect components like LLMs and databases.
-    pipeline = build_pipeline(provider, API_KEY, document_store)
+    pipeline = build_pipeline(provider, API_KEY, API_BASE, document_store)

     # Download and add Game of Thrones TXT articles to Haystack's database.
     # You can also provide a folder with your local documents.
     # You might need to install additional dependencies - look inside the function for more information.
     add_example_data(document_store, "data/GoT_getting_started")

     # Ask a question on the data you just added.
-    result = pipeline.run(query="Who is the father of Arya Stark?")
+    result = pipeline.run(query="Who is the father of Arya Stark?", debug=True)

     # For details such as which documents were used to generate the answer, look into the <result> object.
     print_answers(result, details="medium")
     return result


 if __name__ == "__main__":
+    # getting_started(provider="openai", API_KEY="NOT NEEDED", API_BASE="http://192.168.1.100:1234/v1")
     getting_started(provider="openai", API_KEY="ADD KEY HERE")
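
For illustration, a minimal sketch of running the updated example against a local OpenAI-compatible server. The host, port, and the presence of an LM Studio instance on localhost:1234 are assumptions; any endpoint ending in /v1 should behave the same way.

# Sketch: assumes this is run from the repository root and that a local
# OpenAI-compatible server (e.g., LM Studio) is listening on port 1234.
from examples.getting_started import getting_started

result = getting_started(
    provider="openai",
    API_KEY="not-needed",                 # local servers typically ignore the key, but one must be passed
    API_BASE="http://localhost:1234/v1",  # the trailing /v1 is required
)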
5 changes: 5 additions & 0 deletions haystack/nodes/prompt/prompt_model.py
@@ -37,6 +37,7 @@ def __init__(
         model_name_or_path: str = "google/flan-t5-base",
         max_length: Optional[int] = 100,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
         timeout: Optional[float] = None,
         use_auth_token: Optional[Union[str, bool]] = None,
         use_gpu: Optional[bool] = None,
@@ -65,6 +66,7 @@ def __init__(
         self.model_name_or_path = model_name_or_path
         self.max_length = max_length
         self.api_key = api_key
+        self.api_base = api_base
         self.timeout = timeout
         self.use_auth_token = use_auth_token
         self.use_gpu = use_gpu
@@ -83,6 +85,9 @@ def create_invocation_layer(
             "use_gpu": self.use_gpu,
             "devices": self.devices,
         }
+        if self.api_base is not None:
+            kwargs["api_base"] = self.api_base
+
         all_kwargs = {**self.model_kwargs, **kwargs}

         if isinstance(invocation_layer_class, str):
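
For context, a hedged sketch of using the new parameter on PromptModel directly; the endpoint URL and key are assumptions. Because api_base is merged into the kwargs only when it is not None, invocation layers that do not accept such an argument are unaffected.

from haystack.nodes import PromptModel

# api_base is forwarded to the invocation layer only when explicitly set;
# omitting it keeps the default OpenAI endpoint.
model = PromptModel(
    model_name_or_path="gpt-3.5-turbo-0301",
    api_key="sk-placeholder",             # hypothetical key
    api_base="http://localhost:1234/v1",  # hypothetical local OpenAI-compatible endpoint
)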
2 changes: 2 additions & 0 deletions haystack/nodes/prompt/prompt_node.py
@@ -57,6 +57,7 @@ def __init__(
         output_variable: Optional[str] = None,
         max_length: Optional[int] = 100,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
         timeout: Optional[float] = None,
         use_auth_token: Optional[Union[str, bool]] = None,
         use_gpu: Optional[bool] = None,
@@ -114,6 +115,7 @@ def __init__(
             model_name_or_path=model_name_or_path,
             max_length=max_length,
             api_key=api_key,
+            api_base=api_base,
             timeout=timeout,
             use_auth_token=use_auth_token,
             use_gpu=use_gpu,
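
PromptNode simply forwards the new argument to its underlying PromptModel; a minimal usage sketch, assuming the same hypothetical local endpoint:

from haystack.nodes import PromptNode

node = PromptNode(
    model_name_or_path="gpt-3.5-turbo-0301",
    api_key="sk-placeholder",             # hypothetical key
    api_base="http://localhost:1234/v1",  # omit to use the default OpenAI endpoint
)
print(node("Who is the father of Arya Stark?"))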
3 changes: 2 additions & 1 deletion haystack/utils/getting_started.py
@@ -7,7 +7,7 @@
 logger = logging.getLogger(__name__)


-def build_pipeline(provider, API_KEY, document_store):
+def build_pipeline(provider, API_KEY, API_BASE, document_store):
     # Importing top-level causes a circular import
     from haystack.nodes import AnswerParser, PromptNode, PromptTemplate, BM25Retriever
     from haystack.pipelines import Pipeline
@@ -42,6 +42,7 @@ def build_pipeline(provider, API_KEY, document_store)
         prompt_node = PromptNode(
             model_name_or_path="gpt-3.5-turbo-0301",
             api_key=API_KEY,
+            api_base=API_BASE,
             default_prompt_template=question_answering_with_references,
         )
     else:
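
Note that build_pipeline takes API_BASE positionally, between API_KEY and document_store. A sketch of the updated call, passing None to keep the default OpenAI endpoint:

from haystack.document_stores import InMemoryDocumentStore
from haystack.utils import build_pipeline

document_store = InMemoryDocumentStore(use_bm25=True)
# API_BASE=None falls through to the default api.openai.com endpoint.
pipeline = build_pipeline("openai", "ADD KEY HERE", None, document_store)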
8 changes: 8 additions & 0 deletions releasenotes/notes/override-api-base-67bc046a5cc5f46d.yaml
@@ -0,0 +1,8 @@
+---
+enhancements:
+  - |
+    API_BASE can now be passed as an optional parameter in the getting_started sample. Only the openai provider is supported by this set of changes.
+    PromptNode and PromptModel were enhanced to allow passing this parameter.
+    This enables RAG against a local endpoint (e.g., http://localhost:1234/v1), as long as it is OpenAI-compatible (such as LM Studio).
+    Logging in the getting started sample was made more verbose, to make it easier to see what is happening under the covers.
+