Skip to content

Commit

Permalink
feat(codegen): add text2sql views code generation (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Jun 12, 2024
1 parent e0305ff commit 4284344
Show file tree
Hide file tree
Showing 5 changed files with 897 additions and 6 deletions.
5 changes: 5 additions & 0 deletions src/dbally_cli/main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import click

from dbally_cli.text2sql import generate_text2sql_view


@click.group()
def cli() -> None:
"""
Command line tool for interacting with dbally.
"""


cli.add_command(generate_text2sql_view)
248 changes: 248 additions & 0 deletions src/dbally_cli/text2sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
import asyncio
import importlib
import os
from typing import Callable, Optional, Union

import click
import sqlalchemy
from click import BadParameter, Context, Option
from sqlalchemy import Column, Engine, Table
from sqlalchemy.exc import ArgumentError

from dbally.embeddings.litellm import LiteLLMEmbeddingClient
from dbally.llms.base import LLM
from dbally.llms.litellm import LiteLLM
from dbally.similarity.faiss_store import FaissStore
from dbally.similarity.index import SimilarityIndex
from dbally.similarity.sqlalchemy_base import SimpleSqlAlchemyFetcher
from dbally_codegen.autodiscovery import configure_text2sql_auto_discovery
from dbally_codegen.generator import Text2SQLViewGenerator


def faiss_builder(engine: sqlalchemy.Engine, table: sqlalchemy.Table, column: sqlalchemy.Column) -> SimilarityIndex:
"""
Build a Faiss store.
Args:
engine: The SQLAlchemy engine.
table: The table.
column: The column.
Returns:
The Faiss store.
"""
return SimilarityIndex(
fetcher=SimpleSqlAlchemyFetcher(
sqlalchemy_engine=engine,
column=column,
table=table,
),
store=FaissStore(
index_dir=".",
index_name=f"{table.name}_{column.name}_index",
embedding_client=LiteLLMEmbeddingClient(),
),
)


def validate_file_path(_ctx: Context, _param: Option, value: str) -> str:
"""
Validate the file path.
Args:
value: The value of the option.
Returns:
The validated file path.
"""
root, ext = os.path.splitext(value)
if not ext:
ext = ".py"
elif ext != ".py":
raise BadParameter("file extension must be '.py'.")
return f"{root}{ext}"


def validate_db_url(_ctx: Context, _param: Option, value: Union[str, Engine]) -> Engine:
"""
Validate the database connection string.
Args:
value: The value of the option.
Returns:
The validated database connection string.
"""
if isinstance(value, Engine):
return value
if not value:
raise BadParameter("database connection string is required.")

try:
return sqlalchemy.create_engine(value)
except ArgumentError as exc:
raise BadParameter("invalid database connection string.") from exc


def validate_llm_object(_ctx: Context, _param: Option, value: Union[str, LLM]) -> Optional[LLM]:
""" "
Validate the LLM object.
Args:
value: The value of the option.
Returns:
The validated LLM object.
"""
if isinstance(value, LLM):
return value
if value == "None" or value is None:
return None
if value.startswith("litellm:"):
return LiteLLM(value.split(":")[1])

llm = load_object(value)
if not isinstance(llm, LLM):
raise BadParameter("The LLM object must be an instance of the LLM class.")
return llm


def validate_similarity_index_factory(
_ctx: Context, _param: Option, value: Union[str, Callable[[Engine, Table, Column], SimilarityIndex]]
) -> Optional[Callable[[Engine, Table, Column], SimilarityIndex]]:
"""
Validate the similarity index factory.
Args:
value: The value of the option.
Returns:
The validated similarity index factory.
"""
if callable(value):
return value
if value == "None" or value is None:
return None
if value == "faiss":
return faiss_builder

index_builder = load_object(value) if value else None
if not callable(index_builder):
raise BadParameter("The similarity index factory must be a callable object.")
return index_builder


@click.command(help="Generate a Text2SQL view definition file.")
@click.option(
"--file_path",
default="text2sql_view.py",
show_default=True,
prompt="File path",
help="The path to the file where the view will be generated.",
callback=validate_file_path,
)
@click.option(
"--db",
default="sqlite://",
show_default=True,
prompt="Database URL",
help="The database connection string.",
callback=validate_db_url,
type=click.UNPROCESSED,
)
@click.option(
"--llm",
default="None",
show_default=True,
prompt="LLM object",
help="The path to the LLM object.",
callback=validate_llm_object,
type=click.UNPROCESSED,
)
@click.option(
"--llm_description",
is_flag=True,
default=False,
show_default=True,
prompt="LLM table description?",
help="Generate tables description using LLM.",
)
@click.option(
"--similarity_index_factory",
default="None",
show_default=True,
prompt="Similarity index factory",
help="The path to the similarity index factory.",
callback=validate_similarity_index_factory,
type=click.UNPROCESSED,
)
def generate_text2sql_view(
file_path: str,
db: Engine,
llm: Optional[LLM],
llm_description: bool,
similarity_index_factory: Optional[Callable[[Engine, sqlalchemy.Table, sqlalchemy.Column], SimilarityIndex]],
) -> None:
"""
Generate a Text2SQL view definition file.
Args:
file_path: The path to the file where the view will be generated.
db: The database connection string.
llm: The path to the LLM object.
llm_description: Generate a description using the LLM object.
similarity_index_factory: The path to the similarity index factory.
"""
builder = configure_text2sql_auto_discovery(db)
if llm:
builder = builder.use_llm(llm)
if llm_description:
builder = builder.generate_description_by_llm()
if similarity_index_factory:
builder = builder.suggest_similarity_indexes(similarity_index_factory)

click.echo("Discovering tables...")
tables = asyncio.run(builder.discover())
click.echo(f"Discovered {len(tables)} tables.")

click.echo("Generating Text2SQL view...")
generator = Text2SQLViewGenerator(tables)
code = generator.generate()

dirs, _ = os.path.split(file_path)
if dirs:
os.makedirs(dirs, exist_ok=True)

with open(file_path, "w", encoding="utf-8") as file:
file.write(code)

click.echo(f"Generated Text2SQL view in {file_path}.")


def load_object(path: str) -> object:
"""
Load an object from a module.
Args:
path: The path to the object in the format 'module:object'.
Returns:
The object.
Raises:
BadParameter: If the object is not found.
"""
try:
module_name, object_name = path.split(":")
except ValueError as exc:
raise BadParameter("The object must be in the format 'module:object'.") from exc

try:
module = importlib.import_module(module_name)
except ModuleNotFoundError as exc:
raise BadParameter(f"Could not find the module '{module_name}'.") from exc

try:
return getattr(module, object_name)
except AttributeError as exc:
raise BadParameter(f"Could not find the '{object_name}' object in the '{module_name}' module.") from exc
11 changes: 5 additions & 6 deletions src/dbally_codegen/autodiscovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from dbally.llms.base import LLM
from dbally.prompts import PromptTemplate
from dbally.similarity.index import SimilarityIndex
from dbally.similarity.store import SimilarityStore
from dbally.views.freeform.text2sql import ColumnConfig, TableConfig

DISCOVERY_TEMPLATE = PromptTemplate(
Expand Down Expand Up @@ -141,7 +140,7 @@ async def select_index(
column: Column,
description: str,
connection: Connection,
) -> Optional[SimilarityStore]:
) -> Optional[SimilarityIndex]:
"""
Select the similarity index for the column.
Expand All @@ -167,7 +166,7 @@ async def select_index(
column: Column,
description: str,
connection: Connection,
) -> Optional[SimilarityStore]:
) -> Optional[SimilarityIndex]:
"""
Select the similarity index for the column.
Expand Down Expand Up @@ -201,7 +200,7 @@ async def select_index(
column: Column,
description: str,
connection: Connection,
) -> Optional[SimilarityStore]:
) -> Optional[SimilarityIndex]:
"""
Select the similarity index for the column using LLM.
Expand Down Expand Up @@ -333,14 +332,14 @@ async def discover(self) -> List[TableConfig]:
)
columns.append(
ColumnConfig(
name=column.name,
name=str(column.name),
data_type=str(column.type),
similarity_index=similarity_index,
)
)
tables.append(
TableConfig(
name=table.name,
name=str(table.name),
description=description,
columns=columns,
)
Expand Down
Loading

0 comments on commit 4284344

Please sign in to comment.