refactor: refactor LLMOptions (#28)
micpst committed May 13, 2024
1 parent e0f17a7 commit 3f5c4bf
Showing 14 changed files with 288 additions and 83 deletions.
26 changes: 22 additions & 4 deletions src/dbally/collection.py
@@ -8,7 +8,7 @@
from dbally.audit.event_tracker import EventTracker
from dbally.data_models.audit import RequestEnd, RequestStart
from dbally.data_models.execution_result import ExecutionResult
from dbally.llm_client.base import LLMClient
from dbally.llm_client.base import LLMClient, LLMOptions
from dbally.nl_responder.nl_responder import NLResponder
from dbally.similarity.index import AbstractSimilarityIndex
from dbally.utils.errors import NoViewFoundError
@@ -157,7 +157,13 @@ def list(self) -> Dict[str, str]:
name: (textwrap.dedent(view.__doc__).strip() if view.__doc__ else "") for name, view in self._views.items()
}

async def ask(self, question: str, dry_run: bool = False, return_natural_response: bool = False) -> ExecutionResult:
async def ask(
self,
question: str,
dry_run: bool = False,
return_natural_response: bool = False,
llm_options: Optional[LLMOptions] = None,
) -> ExecutionResult:
"""
Ask question in a text form and retrieve the answer based on the available views.
@@ -174,6 +180,7 @@ async def ask(self, question: str, dry_run: bool = False, return_natural_respons
dry_run: if True, only generate the query without executing it
return_natural_response: if True (and dry_run is False as natural response requires query results),
the natural response will be included in the answer
llm_options: options to use for the LLM client.
Returns:
ExecutionResult object representing the result of the query execution.
@@ -197,7 +204,12 @@ async def ask(self, question: str, dry_run: bool = False, return_natural_respons
if len(views) == 1:
selected_view = next(iter(views))
else:
selected_view = await self._view_selector.select_view(question, views, event_tracker)
selected_view = await self._view_selector.select_view(
question=question,
views=views,
event_tracker=event_tracker,
llm_options=llm_options,
)

view = self.get(selected_view)

@@ -208,12 +220,18 @@ async def ask(self, question: str, dry_run: bool = False, return_natural_respons
event_tracker=event_tracker,
n_retries=self.n_retries,
dry_run=dry_run,
llm_options=llm_options,
)
end_time_view = time.monotonic()

textual_response = None
if not dry_run and return_natural_response:
textual_response = await self._nl_responder.generate_response(view_result, question, event_tracker)
textual_response = await self._nl_responder.generate_response(
result=view_result,
question=question,
event_tracker=event_tracker,
llm_options=llm_options,
)

result = ExecutionResult(
results=view_result.results,
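With this change, callers can pass per-request options straight through `Collection.ask`, and they are forwarded to view selection, IQL generation and the natural-language responder. A minimal usage sketch, assuming the usual `dbally.create_collection` factory, at least one registered view, and the `OpenAIOptions` class introduced further down in this diff (collection name and question are illustrative):

```python
import asyncio

import dbally
from dbally.llm_client.openai_client import OpenAIClient, OpenAIOptions


async def main() -> None:
    # Illustrative setup: an OpenAI-backed collection with views already registered.
    collection = dbally.create_collection("recruitment", llm_client=OpenAIClient())

    # llm_options now flows through ask() into every LLM call made for this request.
    result = await collection.ask(
        "Do we have any data scientists?",
        llm_options=OpenAIOptions(temperature=0.0, max_tokens=128),
    )
    print(result.results)


asyncio.run(main())
```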
21 changes: 0 additions & 21 deletions src/dbally/data_models/llm_options.py

This file was deleted.

5 changes: 4 additions & 1 deletion src/dbally/iql_generator/iql_generator.py
@@ -3,7 +3,7 @@

from dbally.audit.event_tracker import EventTracker
from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template
from dbally.llm_client.base import LLMClient
from dbally.llm_client.base import LLMClient, LLMOptions
from dbally.prompts.prompt_builder import PromptBuilder
from dbally.views.exposed_functions import ExposedFunction

@@ -49,6 +49,7 @@ async def generate_iql(
question: str,
event_tracker: EventTracker,
conversation: Optional[IQLPromptTemplate] = None,
llm_options: Optional[LLMOptions] = None,
) -> Tuple[str, IQLPromptTemplate]:
"""
Uses LLM to generate IQL in text form
@@ -58,6 +59,7 @@
filters: list of filters exposed by the view
event_tracker: event store used to audit the generation process
conversation: conversation to be continued
llm_options: options to use for the LLM client
Returns:
IQL - iql generated based on the user question
@@ -70,6 +72,7 @@
template=template,
fmt={"filters": filters_for_prompt, "question": question},
event_tracker=event_tracker,
options=llm_options,
)

iql_filters = self._prompt_template.llm_response_parser(llm_response)
64 changes: 37 additions & 27 deletions src/dbally/llm_client/base.py
@@ -2,69 +2,79 @@
# pylint: disable=W9015,R0914

import abc
from typing import Dict, List, Optional, Union
from abc import ABC
from dataclasses import asdict, dataclass
from typing import Dict, Generic, Optional, Type, TypeVar, Union

from dbally.audit.event_tracker import EventTracker
from dbally.data_models.audit import LLMEvent
from dbally.data_models.llm_options import LLMOptions
from dbally.prompts import ChatFormat, PromptBuilder, PromptTemplate

LLMClientOptions = TypeVar("LLMClientOptions")

class LLMClient(abc.ABC):

@dataclass
class LLMOptions(ABC):
"""
Abstract dataclass that represents all available LLM call options.
"""

dict = asdict


class LLMClient(Generic[LLMClientOptions], ABC):
"""
Abstract client for interaction with LLM.
It accepts parameters including the template, format, event tracker,
and optional generation parameters like frequency_penalty, max_tokens, and temperature
(the full list of options is provided by the [`LLMOptions` class][dbally.data_models.llm_options.LLMOptions]).
It constructs a prompt using the `PromptBuilder` instance and generates text using the `self.call` method.
"""

def __init__(self, model_name: str):
_options_cls: Type[LLMClientOptions]

def __init__(self, model_name: str, default_options: Optional[LLMClientOptions] = None) -> None:
self.model_name = model_name
self.default_options = default_options or self._options_cls()
self._prompt_builder = PromptBuilder(self.model_name)

def __init_subclass__(cls) -> None:
if not hasattr(cls, "_options_cls"):
raise TypeError(f"Class {cls.__name__} is missing the '_options_cls' attribute")

async def text_generation( # pylint: disable=R0913
self,
template: PromptTemplate,
fmt: dict,
*,
event_tracker: Optional[EventTracker] = None,
frequency_penalty: Optional[float] = 0.0,
max_tokens: Optional[int] = 128,
n: Optional[int] = 1,
presence_penalty: Optional[float] = 0.0,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = 1.0,
top_p: Optional[float] = 1.0,
options: Optional[LLMClientOptions] = None,
) -> str:
"""
For a given a PromptType and format dict creates a prompt and
returns the response from LLM.
Args:
template: Prompt template in system/user/assistant openAI format.
fmt: Dictionary with formatting.
event_tracker: Event store used to audit the generation process.
options: options to use for the LLM client.
Returns:
Text response from LLM.
"""

options = LLMOptions(
frequency_penalty=frequency_penalty,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
temperature=temperature,
top_p=top_p,
)
options = options if options else self.default_options

prompt = self._prompt_builder.build(template, fmt)

event = LLMEvent(prompt=prompt, type=type(template).__name__)

event_tracker = event_tracker or EventTracker()
async with event_tracker.track_event(event) as span:
event.response = await self.call(prompt, template.response_format, options, event)
event.response = await self.call(
prompt=prompt,
response_format=template.response_format,
options=options,
event=event,
)
span(event)

return event.response
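The new `__init_subclass__` hook makes `_options_cls` mandatory for every concrete client, and `default_options` falls back to an instance of that class. A sketch of a hypothetical custom client under this contract (`EchoClient` and `EchoOptions` are invented names, not part of db-ally; the `call` signature mirrors the OpenAI client below):

```python
from dataclasses import dataclass
from typing import Dict, Optional, Union

from dbally.data_models.audit import LLMEvent
from dbally.llm_client.base import LLMClient, LLMOptions
from dbally.prompts import ChatFormat


@dataclass
class EchoOptions(LLMOptions):
    """Options for the toy client below; the field is illustrative only."""

    temperature: Optional[float] = None


class EchoClient(LLMClient[EchoOptions]):
    """Toy client: omitting `_options_cls` here would make
    `__init_subclass__` raise a TypeError at class-definition time."""

    _options_cls = EchoOptions

    async def call(
        self,
        prompt: Union[str, ChatFormat],
        response_format: Optional[Dict[str, str]],
        options: EchoOptions,
        event: LLMEvent,
    ) -> str:
        # A real client would call its model API here, using `options`.
        return str(prompt)
```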
48 changes: 40 additions & 8 deletions src/dbally/llm_client/openai_client.py
@@ -1,36 +1,65 @@
from typing import Dict, Optional, Union
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from openai import NOT_GIVEN, NotGiven

from dbally.data_models.audit import LLMEvent
from dbally.data_models.llm_options import LLMOptions
from dbally.llm_client.base import LLMClient
from dbally.prompts import ChatFormat

from .base import LLMOptions


class OpenAIClient(LLMClient):
@dataclass
class OpenAIOptions(LLMOptions):
"""
Dataclass that represents all available LLM call options for the OpenAI API. Each of them is
described in the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/chat/create).
"""

frequency_penalty: Union[Optional[float], NotGiven] = NOT_GIVEN
max_tokens: Union[Optional[int], NotGiven] = NOT_GIVEN
n: Union[Optional[int], NotGiven] = NOT_GIVEN
presence_penalty: Union[Optional[float], NotGiven] = NOT_GIVEN
seed: Union[Optional[int], NotGiven] = NOT_GIVEN
stop: Union[Optional[Union[str, List[str]]], NotGiven] = NOT_GIVEN
temperature: Union[Optional[float], NotGiven] = NOT_GIVEN
top_p: Union[Optional[float], NotGiven] = NOT_GIVEN


class OpenAIClient(LLMClient[OpenAIOptions]):
"""
`OpenAIClient` is a class designed to interact with OpenAI's language model (LLM) endpoints,
particularly for the GPT models.
Args:
model_name: Name of the [OpenAI's model](https://platform.openai.com/docs/models) to be used,
model_name: Name of the [OpenAI's model](https://platform.openai.com/docs/models) to be used,\
default is "gpt-3.5-turbo".
api_key: OpenAI's API key. If None, the OPENAI_API_KEY environment variable will be used
default_options: Default options to be used in the LLM calls.
"""

def __init__(self, model_name: str = "gpt-3.5-turbo", api_key: Optional[str] = None) -> None:
_options_cls = OpenAIOptions

def __init__(
self,
model_name: str = "gpt-3.5-turbo",
api_key: Optional[str] = None,
default_options: Optional[OpenAIOptions] = None,
) -> None:
try:
from openai import AsyncOpenAI # pylint: disable=import-outside-toplevel
except ImportError as exc:
raise ImportError("You need to install openai package to use GPT models") from exc

super().__init__(model_name)
super().__init__(model_name=model_name, default_options=default_options)
self._client = AsyncOpenAI(api_key=api_key)

async def call(
self,
prompt: Union[str, ChatFormat],
response_format: Optional[Dict[str, str]],
options: LLMOptions,
options: OpenAIOptions,
event: LLMEvent,
) -> str:
"""
@@ -52,7 +81,10 @@ async def call(
response_format = None

response = await self._client.chat.completions.create(
messages=prompt, model=self.model_name, response_format=response_format, **options.dict() # type: ignore
messages=prompt,
model=self.model_name,
response_format=response_format,
**options.dict(), # type: ignore
)

event.completion_tokens = response.usage.completion_tokens
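Every field of `OpenAIOptions` defaults to the SDK's `NOT_GIVEN` sentinel, so only explicitly set values carry meaning; when the dict is unpacked into `chat.completions.create`, the OpenAI client library drops the `NOT_GIVEN` entries from the request. A quick illustration (printed values are indicative):

```python
from dbally.llm_client.openai_client import OpenAIOptions

options = OpenAIOptions(temperature=0.0, max_tokens=256)

# dict() comes from the LLMOptions base class (dict = asdict).
payload = options.dict()
print(payload["temperature"])  # 0.0
print(payload["seed"])         # NOT_GIVEN, left for AsyncOpenAI to omit
```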
13 changes: 11 additions & 2 deletions src/dbally/nl_responder/nl_responder.py
@@ -5,7 +5,7 @@

from dbally.audit.event_tracker import EventTracker
from dbally.data_models.execution_result import ViewExecutionResult
from dbally.llm_client.base import LLMClient
from dbally.llm_client.base import LLMClient, LLMOptions
from dbally.nl_responder.nl_responder_prompt_template import NLResponderPromptTemplate, default_nl_responder_template
from dbally.nl_responder.query_explainer_prompt_template import (
QueryExplainerPromptTemplate,
@@ -46,14 +46,21 @@ def __init__(
)
self._max_tokens_count = max_tokens_count

async def generate_response(self, result: ViewExecutionResult, question: str, event_tracker: EventTracker) -> str:
async def generate_response(
self,
result: ViewExecutionResult,
question: str,
event_tracker: EventTracker,
llm_options: Optional[LLMOptions] = None,
) -> str:
"""
Uses LLM to generate a response in natural language form.
Args:
result: object representing the result of the query execution
question: user question
event_tracker: event store used to audit the generation process
llm_options: options to use for the LLM client.
Returns:
Natural language response to the user question.
@@ -82,6 +89,7 @@ async def generate_response(self, result: ViewExecutionResult, question: str, ev
template=self._query_explainer_prompt_template,
fmt={"question": question, "query": query, "number_of_results": len(result.results)},
event_tracker=event_tracker,
options=llm_options,
)

return llm_response
@@ -90,6 +98,7 @@ async def generate_response(self, result: ViewExecutionResult, question: str, ev
template=self._nl_responder_prompt_template,
fmt={"rows": _promptify_rows(result.results), "question": question},
event_tracker=event_tracker,
options=llm_options,
)
return llm_response

12 changes: 10 additions & 2 deletions src/dbally/view_selection/base.py
@@ -1,21 +1,29 @@
import abc
from typing import Dict
from typing import Dict, Optional

from dbally.audit.event_tracker import EventTracker
from dbally.llm_client.base import LLMOptions


class ViewSelector(abc.ABC):
"""Base class for view selectors."""

@abc.abstractmethod
async def select_view(self, question: str, views: Dict[str, str], event_tracker: EventTracker) -> str:
async def select_view(
self,
question: str,
views: Dict[str, str],
event_tracker: EventTracker,
llm_options: Optional[LLMOptions] = None,
) -> str:
"""
Based on user question and list of available views select the most relevant one.
Args:
question: user question asked in the natural language e.g "Do we have any data scientists?"
views: dictionary of available view names with corresponding descriptions.
event_tracker: event tracker used to audit the selection process.
llm_options: options to use for the LLM client.
Returns:
The most relevant view name.
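Implementations of `ViewSelector` now receive the same per-call options through the widened signature. A toy subclass sketch (invented name, not part of db-ally); a real selector would forward `llm_options` to its `LLMClient.text_generation` call rather than ignore it:

```python
from typing import Dict, Optional

from dbally.audit.event_tracker import EventTracker
from dbally.llm_client.base import LLMOptions
from dbally.view_selection.base import ViewSelector


class FirstViewSelector(ViewSelector):
    """Toy selector: picks the first registered view without calling an LLM."""

    async def select_view(
        self,
        question: str,
        views: Dict[str, str],
        event_tracker: EventTracker,
        llm_options: Optional[LLMOptions] = None,
    ) -> str:
        # A real implementation would pass llm_options through to its LLM client.
        return next(iter(views))
```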