Skip to content

Commit

Permalink
refactor: move prompt templates to modules which use them (#26)
Browse files Browse the repository at this point in the history
* move prompt templates to modules which use them
* reorganize imports
  • Loading branch information
karllu3 committed May 10, 2024
1 parent 9b31299 commit e0f17a7
Show file tree
Hide file tree
Showing 27 changed files with 39 additions and 55 deletions.
4 changes: 2 additions & 2 deletions benchmark/dbally_benchmark/e2e_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@

import dbally
from dbally.collection import Collection
from dbally.data_models.prompts.iql_prompt_template import default_iql_template
from dbally.data_models.prompts.view_selector_prompt_template import default_view_selector_template
from dbally.iql_generator.iql_prompt_template import default_iql_template
from dbally.llm_client.openai_client import OpenAIClient
from dbally.utils.errors import NoViewFoundError, UnsupportedQueryError
from dbally.view_selection.view_selector_prompt_template import default_view_selector_template


async def _run_dbally_for_single_example(example: BIRDExample, collection: Collection) -> Text2SQLResult:
Expand Down
2 changes: 1 addition & 1 deletion benchmark/dbally_benchmark/iql_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from sqlalchemy import create_engine

from dbally.audit.event_tracker import EventTracker
from dbally.data_models.prompts.iql_prompt_template import default_iql_template
from dbally.iql_generator.iql_generator import IQLGenerator
from dbally.iql_generator.iql_prompt_template import default_iql_template
from dbally.llm_client.openai_client import OpenAIClient
from dbally.utils.errors import UnsupportedQueryError
from dbally.views.structured import BaseStructuredView
Expand Down
2 changes: 1 addition & 1 deletion benchmark/dbally_benchmark/text2sql/prompt_template.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dbally.prompts.prompt_builder import PromptTemplate
from dbally.prompts import PromptTemplate

TEXT2SQL_PROMPT_TEMPLATE = PromptTemplate(
(
Expand Down
2 changes: 1 addition & 1 deletion examples/recruiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
from dbally.audit.event_tracker import EventTracker
from dbally.llm_client.openai_client import OpenAIClient
from dbally.prompts.prompt_builder import PromptTemplate
from dbally.prompts import PromptTemplate

TEXT2SQL_PROMPT_TEMPLATE = PromptTemplate(
(
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ testpaths = ['tests']

[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
pythonpath = ["."]

[tool.mypy]
warn_unused_configs = true
Expand Down
1 change: 1 addition & 0 deletions src/dbally/__version__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
"""Version information."""

__version__ = "0.1.0"
2 changes: 1 addition & 1 deletion src/dbally/data_models/audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional, Union

from dbally.data_models.execution_result import ExecutionResult
from dbally.data_models.prompts.prompt_template import ChatFormat
from dbally.prompts import ChatFormat


class EventType(Enum):
Expand Down
15 changes: 0 additions & 15 deletions src/dbally/data_models/prompts/__init__.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/dbally/iql_generator/iql_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Callable, List, Optional, Tuple, TypeVar

from dbally.audit.event_tracker import EventTracker
from dbally.data_models.prompts.iql_prompt_template import IQLPromptTemplate, default_iql_template
from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template
from dbally.llm_client.base import LLMClient
from dbally.prompts.prompt_builder import PromptBuilder
from dbally.views.exposed_functions import ExposedFunction
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Callable, Dict, Optional

from dbally.data_models.prompts.common_validation_utils import _check_prompt_variables
from dbally.data_models.prompts.prompt_template import ChatFormat, PromptTemplate
from dbally.prompts import ChatFormat, PromptTemplate, check_prompt_variables
from dbally.utils.errors import UnsupportedQueryError


Expand All @@ -17,7 +16,7 @@ def __init__(
llm_response_parser: Callable = lambda x: x,
):
super().__init__(chat, response_format, llm_response_parser)
self.chat = _check_prompt_variables(chat, {"filters", "question"})
self.chat = check_prompt_variables(chat, {"filters", "question"})


def _validate_iql_response(llm_response: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion src/dbally/llm_client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dbally.audit.event_tracker import EventTracker
from dbally.data_models.audit import LLMEvent
from dbally.data_models.llm_options import LLMOptions
from dbally.prompts.prompt_builder import ChatFormat, PromptBuilder, PromptTemplate
from dbally.prompts import ChatFormat, PromptBuilder, PromptTemplate


class LLMClient(abc.ABC):
Expand Down
2 changes: 1 addition & 1 deletion src/dbally/llm_client/openai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dbally.data_models.audit import LLMEvent
from dbally.data_models.llm_options import LLMOptions
from dbally.llm_client.base import LLMClient
from dbally.prompts.prompt_builder import ChatFormat
from dbally.prompts import ChatFormat


class OpenAIClient(LLMClient):
Expand Down
9 changes: 3 additions & 6 deletions src/dbally/nl_responder/nl_responder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,12 @@

from dbally.audit.event_tracker import EventTracker
from dbally.data_models.execution_result import ViewExecutionResult
from dbally.data_models.prompts.nl_responder_prompt_template import (
NLResponderPromptTemplate,
default_nl_responder_template,
)
from dbally.data_models.prompts.query_explainer_prompt_template import (
from dbally.llm_client.base import LLMClient
from dbally.nl_responder.nl_responder_prompt_template import NLResponderPromptTemplate, default_nl_responder_template
from dbally.nl_responder.query_explainer_prompt_template import (
QueryExplainerPromptTemplate,
default_query_explainer_template,
)
from dbally.llm_client.base import LLMClient
from dbally.nl_responder.token_counters import count_tokens_for_huggingface, count_tokens_for_openai


Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Callable, Dict, Optional

from dbally.data_models.prompts.common_validation_utils import _check_prompt_variables
from dbally.data_models.prompts.prompt_template import ChatFormat, PromptTemplate
from dbally.prompts import ChatFormat, PromptTemplate, check_prompt_variables


class NLResponderPromptTemplate(PromptTemplate):
Expand All @@ -25,7 +24,7 @@ def __init__(
"""

super().__init__(chat, response_format, llm_response_parser)
self.chat = _check_prompt_variables(chat, {"rows", "question"})
self.chat = check_prompt_variables(chat, {"rows", "question"})


default_nl_responder_template = NLResponderPromptTemplate(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Callable, Dict, Optional

from dbally.data_models.prompts.common_validation_utils import _check_prompt_variables
from dbally.data_models.prompts.prompt_template import ChatFormat, PromptTemplate
from dbally.prompts import ChatFormat, PromptTemplate, check_prompt_variables


class QueryExplainerPromptTemplate(PromptTemplate):
Expand All @@ -22,7 +21,7 @@ def __init__(
llm_response_parser: Callable = lambda x: x,
) -> None:
super().__init__(chat, response_format, llm_response_parser)
self.chat = _check_prompt_variables(chat, {"question", "query", "number_of_results"})
self.chat = check_prompt_variables(chat, {"question", "query", "number_of_results"})


default_query_explainer_template = QueryExplainerPromptTemplate(
Expand Down
2 changes: 1 addition & 1 deletion src/dbally/nl_responder/token_counters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict

from dbally.data_models.prompts.common_validation_utils import ChatFormat
from dbally.prompts import ChatFormat


def count_tokens_for_openai(messages: ChatFormat, fmt: Dict[str, str], model: str) -> int:
Expand Down
4 changes: 3 additions & 1 deletion src/dbally/prompts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .common_validation_utils import ChatFormat, PromptTemplateError, check_prompt_variables
from .prompt_builder import PromptBuilder
from .prompt_template import PromptTemplate

__all__ = ["PromptBuilder"]
__all__ = ["PromptBuilder", "PromptTemplate", "PromptTemplateError", "check_prompt_variables", "ChatFormat"]
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def _extract_variables(text: str) -> List[str]:
return re.findall(pattern, text)


def _check_prompt_variables(chat: ChatFormat, variables_to_check: Set[str]) -> ChatFormat:
def check_prompt_variables(chat: ChatFormat, variables_to_check: Set[str]) -> ChatFormat:
"""
Function validates a given chat to make sure it contains variables required.
Expand Down
3 changes: 2 additions & 1 deletion src/dbally/prompts/prompt_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING, Dict, Optional, Union

from dbally.data_models.prompts.prompt_template import ChatFormat, PromptTemplate
from .common_validation_utils import ChatFormat
from .prompt_template import PromptTemplate

if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing_extensions import Self

from dbally.data_models.prompts.common_validation_utils import ChatFormat, PromptTemplateError
from .common_validation_utils import ChatFormat, PromptTemplateError


def _check_chat_order(chat: ChatFormat) -> ChatFormat:
Expand Down
3 changes: 2 additions & 1 deletion src/dbally/view_selection/llm_view_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from typing import Callable, Dict, Optional

from dbally.audit.event_tracker import EventTracker
from dbally.data_models.prompts import IQLPromptTemplate, default_view_selector_template
from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate
from dbally.llm_client.base import LLMClient
from dbally.prompts import PromptBuilder
from dbally.view_selection.base import ViewSelector
from dbally.view_selection.view_selector_prompt_template import default_view_selector_template


class LLMViewSelector(ViewSelector):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import json
from typing import Callable, Dict, Optional

from dbally.data_models.prompts.common_validation_utils import _check_prompt_variables
from dbally.data_models.prompts.prompt_template import ChatFormat, PromptTemplate
from dbally.prompts import ChatFormat, PromptTemplate, check_prompt_variables


class ViewSelectorPromptTemplate(PromptTemplate):
Expand All @@ -17,7 +16,7 @@ def __init__(
llm_response_parser: Callable = lambda x: x,
):
super().__init__(chat, response_format, llm_response_parser)
self.chat = _check_prompt_variables(chat, {"views"})
self.chat = check_prompt_variables(chat, {"views"})


def _convert_llm_json_response_to_selected_view(llm_response_json: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion src/dbally/views/freeform/text2sql/_autodiscovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from sqlalchemy.sql.ddl import CreateTable
from typing_extensions import Self

from dbally.data_models.prompts import PromptTemplate
from dbally.llm_client.base import LLMClient
from dbally.prompts import PromptTemplate

from ._config import Text2SQLConfig, Text2SQLTableConfig

Expand Down
2 changes: 1 addition & 1 deletion src/dbally/views/freeform/text2sql/_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

from dbally.audit.event_tracker import EventTracker
from dbally.data_models.execution_result import ViewExecutionResult
from dbally.data_models.prompts import PromptTemplate
from dbally.llm_client.base import LLMClient
from dbally.prompts import PromptTemplate
from dbally.views.base import BaseView

from ._config import Text2SQLConfig
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from typing import List, Tuple
from unittest.mock import create_autospec

from dbally.data_models.prompts.iql_prompt_template import IQLPromptTemplate, default_iql_template
from dbally.iql import IQLQuery
from dbally.iql_generator.iql_generator import IQLGenerator
from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template
from dbally.llm_client.base import LLMClient
from dbally.similarity.index import AbstractSimilarityIndex
from dbally.view_selection.base import ViewSelector
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_iql_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

from dbally import decorators
from dbally.audit.event_tracker import EventTracker
from dbally.data_models.prompts.iql_prompt_template import default_iql_template
from dbally.iql import IQLQuery
from dbally.iql_generator.iql_generator import IQLGenerator
from dbally.iql_generator.iql_prompt_template import default_iql_template
from dbally.views.methods_base import MethodsBaseView


Expand Down
5 changes: 2 additions & 3 deletions tests/unit/test_prompt_builder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import pytest

from dbally.data_models.prompts.iql_prompt_template import IQLPromptTemplate
from dbally.data_models.prompts.prompt_template import ChatFormat, PromptTemplate, PromptTemplateError
from dbally.prompts.prompt_builder import PromptBuilder
from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate
from dbally.prompts import ChatFormat, PromptBuilder, PromptTemplate, PromptTemplateError


@pytest.fixture()
Expand Down

0 comments on commit e0f17a7

Please sign in to comment.