Skip to content

Commit

Permalink
refactor: refactor prompt builder (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed May 14, 2024
1 parent 7191c3f commit 61f1066
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 44 deletions.
4 changes: 0 additions & 4 deletions src/dbally/iql_generator/iql_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from dbally.audit.event_tracker import EventTracker
from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template
from dbally.llm_client.base import LLMClient, LLMOptions
from dbally.prompts.prompt_builder import PromptBuilder
from dbally.views.exposed_functions import ExposedFunction


Expand All @@ -28,19 +27,16 @@ def __init__(
self,
llm_client: LLMClient,
prompt_template: Optional[IQLPromptTemplate] = None,
prompt_builder: Optional[PromptBuilder] = None,
promptify_view: Optional[Callable] = None,
) -> None:
"""
Args:
llm_client: LLM client used to generate IQL
prompt_template: If not provided by the users is set to `default_iql_template`
prompt_builder: PromptBuilder used to insert arguments into the prompt and adjust style per model.
promptify_view: Function formatting filters for prompt
"""
self._llm_client = llm_client
self._prompt_template = prompt_template or copy.deepcopy(default_iql_template)
self._prompt_builder = prompt_builder or PromptBuilder()
self._promptify_view = promptify_view or _promptify_filters

async def generate_iql(
Expand Down
9 changes: 8 additions & 1 deletion src/dbally/llm_client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import abc
from abc import ABC
from dataclasses import asdict, dataclass
from functools import cached_property
from typing import Any, ClassVar, Dict, Generic, Optional, Type, TypeVar, Union

from dbally.audit.event_tracker import EventTracker
Expand Down Expand Up @@ -67,12 +68,18 @@ class LLMClient(Generic[LLMClientOptions], ABC):
def __init__(self, model_name: str, default_options: Optional[LLMClientOptions] = None) -> None:
self.model_name = model_name
self.default_options = default_options or self._options_cls()
self._prompt_builder = PromptBuilder(self.model_name)

def __init_subclass__(cls) -> None:
if not hasattr(cls, "_options_cls"):
raise TypeError(f"Class {cls.__name__} is missing the '_options_cls' attribute")

@cached_property
def _prompt_builder(self) -> PromptBuilder:
    """
    Prompt builder used to construct final prompts for the LLM.

    Built lazily on first access and cached for the lifetime of the
    client instance, so repeated calls reuse a single PromptBuilder.

    NOTE(review): the previous constructor-based setup passed
    ``self.model_name`` to PromptBuilder; this version constructs it
    with no model name — confirm that model-specific tokenizer
    formatting is intentionally no longer applied here.

    Returns:
        PromptBuilder: shared builder for this client's prompts.
    """
    return PromptBuilder()

async def text_generation( # pylint: disable=R0913
self,
template: PromptTemplate,
Expand Down
8 changes: 4 additions & 4 deletions src/dbally/prompts/prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@ class PromptBuilder:
def __init__(self, model_name: Optional[str] = None) -> None:
"""
Args:
model_name: Name of the model to load a tokenizer for.
Tokenizer is used to append special tokens to the prompt. If empty, no tokens will be added.
model_name: name of the tokenizer model to use. If provided, the tokenizer will convert the prompt to the
format expected by the model. The model_name should be a model available on huggingface.co/models.
Raises:
OSError: If model_name is not found in huggingface.co/models
"""
self._tokenizer: Optional["PreTrainedTokenizer"] = None

if model_name is not None and not model_name.startswith("gpt"):
if model_name is not None:
try:
from transformers import AutoTokenizer # pylint: disable=import-outside-toplevel
except ImportError as exc:
raise ImportError("You need to install transformers package to use huggingface models.") from exc
raise ImportError("You need to install transformers package to use huggingface tokenizers") from exc

self._tokenizer = AutoTokenizer.from_pretrained(model_name)

Expand Down
4 changes: 0 additions & 4 deletions src/dbally/view_selection/llm_view_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from dbally.audit.event_tracker import EventTracker
from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate
from dbally.llm_client.base import LLMClient, LLMOptions
from dbally.prompts import PromptBuilder
from dbally.view_selection.base import ViewSelector
from dbally.view_selection.view_selector_prompt_template import default_view_selector_template

Expand All @@ -24,20 +23,17 @@ def __init__(
self,
llm_client: LLMClient,
prompt_template: Optional[IQLPromptTemplate] = None,
prompt_builder: Optional[PromptBuilder] = None,
promptify_views: Optional[Callable[[Dict[str, str]], str]] = None,
) -> None:
"""
Args:
llm_client: LLM client used to generate IQL
prompt_template: template for the prompt used for the view selection
prompt_builder: PromptBuilder used to insert arguments into the prompt and adjust style per model
promptify_views: Function formatting filters for prompt. By default names and descriptions of\
all views are concatenated
"""
self._llm_client = llm_client
self._prompt_template = prompt_template or copy.deepcopy(default_view_selector_template)
self._prompt_builder = prompt_builder or PromptBuilder()
self._promptify_views = promptify_views or _promptify_views

async def select_view(
Expand Down
22 changes: 2 additions & 20 deletions tests/integration/test_llm_options.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from unittest.mock import ANY, AsyncMock, Mock, call
from unittest.mock import ANY, AsyncMock, call

import pytest

from dbally import create_collection
from dbally.llm_client.base import LLMClient
from tests.unit.mocks import MockLLMOptions, MockViewBase
from tests.unit.mocks import MockLLMClient, MockLLMOptions, MockViewBase


class MockView1(MockViewBase):
Expand All @@ -15,23 +14,6 @@ class MockView2(MockViewBase):
...


class MockLLMClient(LLMClient[MockLLMOptions]):
_options_cls = MockLLMOptions

# TODO: Start calling super().__init__ and remove the pylint comment below
# as soon as the base class is refactored to not have PromptBuilder initialization
# hardcoded in its constructor.
# See: DBALLY-105
# pylint: disable=super-init-not-called
def __init__(self, default_options: MockLLMOptions) -> None:
self.model_name = "gpt-4"
self.default_options = default_options
self._prompt_builder = Mock()

async def call(self, *_, **__) -> str:
...


@pytest.mark.asyncio
async def test_llm_options_propagation():
default_options = MockLLMOptions(mock_property1=1, mock_property2="default mock")
Expand Down
18 changes: 7 additions & 11 deletions tests/unit/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

from dataclasses import dataclass
from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union
from unittest.mock import create_autospec

from dbally import NOT_GIVEN, NotGiven
Expand Down Expand Up @@ -71,16 +71,12 @@ class MockLLMOptions(LLMOptions):
class MockLLMClient(LLMClient[MockLLMOptions]):
_options_cls = MockLLMOptions

# TODO: Start calling super().__init__ and remove the pylint comment below
# as soon as the base class is refactored to not have PromptBuilder initialization
# hardcoded in its constructor.
# See: DBALLY-105
# pylint: disable=super-init-not-called
def __init__(self, *_, **__) -> None:
self.model_name = "mock model"

async def text_generation(self, *_, **__) -> str:
return "mock response"
def __init__(
    self,
    model_name: str = "gpt-4-mock",
    default_options: Optional[MockLLMOptions] = None,
) -> None:
    """
    Create a mock LLM client for tests.

    Args:
        model_name: model name reported by the client; defaults to
            "gpt-4-mock".
        default_options: default options for calls; when None, the base
            class falls back to constructing ``_options_cls()``.
    """
    super().__init__(model_name, default_options)

async def call(self, *_, **__) -> str:
    """Return a canned response, ignoring all arguments (test stub)."""
    return "mock response"

0 comments on commit 61f1066

Please sign in to comment.