refactor: refactor LLMOptions #28
Merged · 11 commits · May 13, 2024
Changes from 8 commits
26 changes: 22 additions & 4 deletions src/dbally/collection.py
@@ -8,7 +8,7 @@
from dbally.audit.event_tracker import EventTracker
from dbally.data_models.audit import RequestEnd, RequestStart
from dbally.data_models.execution_result import ExecutionResult
from dbally.llm_client.base import LLMClient
from dbally.llm_client.base import LLMClient, LLMOptions
from dbally.nl_responder.nl_responder import NLResponder
from dbally.similarity.index import AbstractSimilarityIndex
from dbally.utils.errors import NoViewFoundError
@@ -157,7 +157,13 @@ def list(self) -> Dict[str, str]:
name: (textwrap.dedent(view.__doc__).strip() if view.__doc__ else "") for name, view in self._views.items()
}

async def ask(self, question: str, dry_run: bool = False, return_natural_response: bool = False) -> ExecutionResult:
async def ask(
self,
question: str,
dry_run: bool = False,
return_natural_response: bool = False,
llm_options: Optional[LLMOptions] = None,
) -> ExecutionResult:
"""
Ask question in a text form and retrieve the answer based on the available views.

@@ -174,6 +180,7 @@ async def ask(self, question: str, dry_run: bool = False, return_natural_respons
dry_run: if True, only generate the query without executing it
return_natural_response: if True (and dry_run is False as natural response requires query results),
the natural response will be included in the answer
llm_options: options to use for the LLM client.

Returns:
ExecutionResult object representing the result of the query execution.
@@ -197,7 +204,12 @@ async def ask(self, question: str, dry_run: bool = False, return_natural_respons
if len(views) == 1:
selected_view = next(iter(views))
else:
selected_view = await self._view_selector.select_view(question, views, event_tracker)
selected_view = await self._view_selector.select_view(
question=question,
views=views,
event_tracker=event_tracker,
llm_options=llm_options,
)

view = self.get(selected_view)

@@ -208,12 +220,18 @@ async def ask(self, question: str, dry_run: bool = False, return_natural_respons
event_tracker=event_tracker,
n_retries=self.n_retries,
dry_run=dry_run,
llm_options=llm_options,
)
end_time_view = time.monotonic()

textual_response = None
if not dry_run and return_natural_response:
textual_response = await self._nl_responder.generate_response(view_result, question, event_tracker)
textual_response = await self._nl_responder.generate_response(
result=view_result,
question=question,
event_tracker=event_tracker,
llm_options=llm_options,
)

result = ExecutionResult(
results=view_result.results,
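
Taken together, these changes let a caller tune generation settings per request. A minimal sketch of the new call site, assuming a collection already built on an OpenAIClient (the collection variable and question are illustrative):

from dbally.llm_client.openai_client import OpenAIOptions

# Per-call options replace the client's default_options for this single request.
result = await my_collection.ask(
    question="Do we have any data scientists?",
    return_natural_response=True,
    llm_options=OpenAIOptions(temperature=0.0, max_tokens=128),
)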
21 changes: 0 additions & 21 deletions src/dbally/data_models/llm_options.py

This file was deleted.
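
The deletion moves LLMOptions out of dbally.data_models.llm_options; assuming no compatibility re-export is kept, downstream imports migrate like this:

# before this PR
from dbally.data_models.llm_options import LLMOptions

# after this PR
from dbally.llm_client.base import LLMOptions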

5 changes: 4 additions & 1 deletion src/dbally/iql_generator/iql_generator.py
@@ -3,7 +3,7 @@

from dbally.audit.event_tracker import EventTracker
from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template
from dbally.llm_client.base import LLMClient
from dbally.llm_client.base import LLMClient, LLMOptions
from dbally.prompts.prompt_builder import PromptBuilder
from dbally.views.exposed_functions import ExposedFunction

@@ -49,6 +49,7 @@ async def generate_iql(
question: str,
event_tracker: EventTracker,
conversation: Optional[IQLPromptTemplate] = None,
llm_options: Optional[LLMOptions] = None,
) -> Tuple[str, IQLPromptTemplate]:
"""
Uses LLM to generate IQL in text form
@@ -58,6 +59,7 @@
filters: list of filters exposed by the view
event_tracker: event store used to audit the generation process
conversation: conversation to be continued
llm_options: options to use for the LLM client

Returns:
IQL - iql generated based on the user question
@@ -70,6 +72,7 @@
template=template,
fmt={"filters": filters_for_prompt, "question": question},
event_tracker=event_tracker,
options=llm_options,
)

iql_filters = self._prompt_template.llm_response_parser(llm_response)
60 changes: 33 additions & 27 deletions src/dbally/llm_client/base.py
@@ -2,26 +2,38 @@
# pylint: disable=W9015,R0914

import abc
from typing import Dict, List, Optional, Union
from abc import ABC
from dataclasses import asdict, dataclass
from typing import Dict, Generic, Optional, Type, TypeVar, Union

from dbally.audit.event_tracker import EventTracker
from dbally.data_models.audit import LLMEvent
from dbally.data_models.llm_options import LLMOptions
from dbally.prompts import ChatFormat, PromptBuilder, PromptTemplate

LLMClientOptions = TypeVar("LLMClientOptions")

class LLMClient(abc.ABC):

@dataclass
class LLMOptions(ABC):
"""
Abstract dataclass that represents all available LLM call options.
"""

dict = asdict


class LLMClient(Generic[LLMClientOptions], ABC):
"""
Abstract client for interaction with LLM.

It accepts parameters including the template, format, event tracker,
and optional generation parameters like frequency_penalty, max_tokens, and temperature
(the full list of options is provided by the [`LLMOptions` class][dbally.data_models.llm_options.LLMOptions]).
It constructs a prompt using the `PromptBuilder` instance and generates text using the `self.call` method.
"""

def __init__(self, model_name: str):
_options_cls: Type[LLMClientOptions]

def __init__(self, model_name: str, default_options: Optional[LLMClientOptions] = None) -> None:
self.model_name = model_name
self.default_options = default_options or self._options_cls()
self._prompt_builder = PromptBuilder(self.model_name)

async def text_generation( # pylint: disable=R0913
@@ -30,41 +42,35 @@ async def text_generation( # pylint: disable=R0913
fmt: dict,
*,
event_tracker: Optional[EventTracker] = None,
frequency_penalty: Optional[float] = 0.0,
max_tokens: Optional[int] = 128,
n: Optional[int] = 1,
presence_penalty: Optional[float] = 0.0,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = 1.0,
top_p: Optional[float] = 1.0,
options: Optional[LLMClientOptions] = None,
) -> str:
"""
For a given a PromptType and format dict creates a prompt and
returns the response from LLM.

Args:
template: Prompt template in system/user/assistant openAI format.
fmt: Dictionary with formatting.
event_tracker: Event store used to audit the generation process.
options: options to use for the LLM client.

Returns:
Text response from LLM.
"""

options = LLMOptions(
frequency_penalty=frequency_penalty,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
temperature=temperature,
top_p=top_p,
)
options = options if options else self.default_options

prompt = self._prompt_builder.build(template, fmt)

event = LLMEvent(prompt=prompt, type=type(template).__name__)

event_tracker = event_tracker or EventTracker()
async with event_tracker.track_event(event) as span:
event.response = await self.call(prompt, template.response_format, options, event)
event.response = await self.call(
prompt=prompt,
response_format=template.response_format,
options=options,
event=event,
)
span(event)

return event.response
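
The dict = asdict assignment binds dataclasses.asdict as an instance method, so any concrete options dataclass can be expanded into keyword arguments for the provider SDK, while options passed per call fully replace default_options rather than being merged. A small sketch of a provider-specific subclass (MyProviderOptions is hypothetical, not part of this PR):

from dataclasses import dataclass
from typing import Optional

from dbally.llm_client.base import LLMOptions


@dataclass
class MyProviderOptions(LLMOptions):
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None


opts = MyProviderOptions(temperature=0.2)
print(opts.dict())  # {'temperature': 0.2, 'max_tokens': None}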
48 changes: 40 additions & 8 deletions src/dbally/llm_client/openai_client.py
@@ -1,36 +1,65 @@
from typing import Dict, Optional, Union
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from openai import NOT_GIVEN, NotGiven

from dbally.data_models.audit import LLMEvent
from dbally.data_models.llm_options import LLMOptions
from dbally.llm_client.base import LLMClient
from dbally.prompts import ChatFormat

from .base import LLMOptions


class OpenAIClient(LLMClient):
@dataclass
class OpenAIOptions(LLMOptions):
"""
Dataclass that represents all available LLM call options for the OpenAI API. Each of them is
described in the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/chat/create.)
"""

frequency_penalty: Union[Optional[float], NotGiven] = NOT_GIVEN
max_tokens: Union[Optional[int], NotGiven] = NOT_GIVEN
n: Union[Optional[int], NotGiven] = NOT_GIVEN
presence_penalty: Union[Optional[float], NotGiven] = NOT_GIVEN
seed: Union[Optional[int], NotGiven] = NOT_GIVEN
stop: Union[Optional[Union[str, List[str]]], NotGiven] = NOT_GIVEN
temperature: Union[Optional[float], NotGiven] = NOT_GIVEN
top_p: Union[Optional[float], NotGiven] = NOT_GIVEN


class OpenAIClient(LLMClient[OpenAIOptions]):
"""
`OpenAIClient` is a class designed to interact with OpenAI's language model (LLM) endpoints,
particularly for the GPT models.

Args:
model_name: Name of the [OpenAI's model](https://platform.openai.com/docs/models) to be used,
model_name: Name of the [OpenAI's model](https://platform.openai.com/docs/models) to be used,\
default is "gpt-3.5-turbo".
api_key: OpenAI's API key. If None OPENAI_API_KEY environment variable will be used
default_options: Default options to be used in the LLM calls.
"""

def __init__(self, model_name: str = "gpt-3.5-turbo", api_key: Optional[str] = None) -> None:
_options_cls = OpenAIOptions

def __init__(
self,
model_name: str = "gpt-3.5-turbo",
api_key: Optional[str] = None,
default_options: Optional[OpenAIOptions] = None,
) -> None:
try:
from openai import AsyncOpenAI # pylint: disable=import-outside-toplevel
except ImportError as exc:
raise ImportError("You need to install openai package to use GPT models") from exc

super().__init__(model_name)
super().__init__(model_name=model_name, default_options=default_options)
self._client = AsyncOpenAI(api_key=api_key)

async def call(
self,
prompt: Union[str, ChatFormat],
response_format: Optional[Dict[str, str]],
options: LLMOptions,
options: OpenAIOptions,
event: LLMEvent,
) -> str:
"""
@@ -52,7 +81,10 @@ async def call(
response_format = None

response = await self._client.chat.completions.create(
messages=prompt, model=self.model_name, response_format=response_format, **options.dict() # type: ignore
messages=prompt,
model=self.model_name,
response_format=response_format,
**options.dict(), # type: ignore
)

event.completion_tokens = response.usage.completion_tokens
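
With default_options, generation settings can be fixed once at construction time and overridden per call through text_generation's options argument. A hedged sketch (assumes the openai package is installed and OPENAI_API_KEY is set; NOT_GIVEN fields are forwarded to the SDK, which treats them as omitted):

from dbally.llm_client.openai_client import OpenAIClient, OpenAIOptions

llm_client = OpenAIClient(
    model_name="gpt-3.5-turbo",
    default_options=OpenAIOptions(temperature=0.0, max_tokens=256),
)

# A per-call OpenAIOptions instance replaces the defaults entirely for that request, e.g.:
# response = await llm_client.text_generation(template, fmt, options=OpenAIOptions(temperature=0.7))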
13 changes: 11 additions & 2 deletions src/dbally/nl_responder/nl_responder.py
@@ -5,7 +5,7 @@

from dbally.audit.event_tracker import EventTracker
from dbally.data_models.execution_result import ViewExecutionResult
from dbally.llm_client.base import LLMClient
from dbally.llm_client.base import LLMClient, LLMOptions
from dbally.nl_responder.nl_responder_prompt_template import NLResponderPromptTemplate, default_nl_responder_template
from dbally.nl_responder.query_explainer_prompt_template import (
QueryExplainerPromptTemplate,
@@ -46,14 +46,21 @@ def __init__(
)
self._max_tokens_count = max_tokens_count

async def generate_response(self, result: ViewExecutionResult, question: str, event_tracker: EventTracker) -> str:
async def generate_response(
self,
result: ViewExecutionResult,
question: str,
event_tracker: EventTracker,
llm_options: Optional[LLMOptions] = None,
) -> str:
"""
Uses LLM to generate a response in natural language form.

Args:
result: object representing the result of the query execution
question: user question
event_tracker: event store used to audit the generation process
llm_options: options to use for the LLM client.

Returns:
Natural language response to the user question.
@@ -82,6 +89,7 @@ async def generate_response(self, result: ViewExecutionResult, question: str, ev
template=self._query_explainer_prompt_template,
fmt={"question": question, "query": query, "number_of_results": len(result.results)},
event_tracker=event_tracker,
options=llm_options,
)

return llm_response
@@ -90,6 +98,7 @@ async def generate_response(self, result: ViewExecutionResult, question: str, ev
template=self._nl_responder_prompt_template,
fmt={"rows": _promptify_rows(result.results), "question": question},
event_tracker=event_tracker,
options=llm_options,
)
return llm_response

12 changes: 10 additions & 2 deletions src/dbally/view_selection/base.py
@@ -1,21 +1,29 @@
import abc
from typing import Dict
from typing import Dict, Optional

from dbally.audit.event_tracker import EventTracker
from dbally.llm_client.base import LLMOptions


class ViewSelector(abc.ABC):
"""Base class for view selectors."""

@abc.abstractmethod
async def select_view(self, question: str, views: Dict[str, str], event_tracker: EventTracker) -> str:
async def select_view(
self,
question: str,
views: Dict[str, str],
event_tracker: EventTracker,
llm_options: Optional[LLMOptions] = None,
) -> str:
"""
Based on user question and list of available views select the most relevant one.

Args:
question: user question asked in the natural language e.g "Do we have any data scientists?"
views: dictionary of available view names with corresponding descriptions.
event_tracker: event tracker used to audit the selection process.
llm_options: options to use for the LLM client.

Returns:
The most relevant view name.
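
A concrete selector is expected to forward llm_options to its LLM client so that per-call settings reach the model. A hypothetical sketch; apart from ViewSelector, LLMClient, LLMOptions, EventTracker, and text_generation, every name here is illustrative and not part of this PR:

from typing import Dict, Optional

from dbally.audit.event_tracker import EventTracker
from dbally.llm_client.base import LLMClient, LLMOptions
from dbally.view_selection.base import ViewSelector


class PromptedViewSelector(ViewSelector):
    """Illustrative selector that asks the LLM to pick a view by name."""

    def __init__(self, llm_client: LLMClient, prompt_template) -> None:
        self._llm_client = llm_client
        self._prompt_template = prompt_template

    async def select_view(
        self,
        question: str,
        views: Dict[str, str],
        event_tracker: EventTracker,
        llm_options: Optional[LLMOptions] = None,
    ) -> str:
        views_listing = "\n".join(f"{name}: {description}" for name, description in views.items())
        return await self._llm_client.text_generation(
            template=self._prompt_template,
            fmt={"views": views_listing, "question": question},
            event_tracker=event_tracker,
            options=llm_options,  # falls back to the client's default_options when None
        )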