Skip to content

Commit

Permalink
feat(llm): integrate LLMClient with litellm (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst authored May 20, 2024
1 parent 59e08ac commit c06b11d
Show file tree
Hide file tree
Showing 60 changed files with 666 additions and 633 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ This is a basic implementation of a db-ally view for an example HR application,

```python
from dbally import decorators, SqlAlchemyBaseView, create_collection
from dbally.llm_client.openai_client import OpenAIClient
from dbally.llms.litellm import LiteLLM
from sqlalchemy import create_engine

class CandidateView(SqlAlchemyBaseView):
Expand All @@ -53,7 +53,7 @@ class CandidateView(SqlAlchemyBaseView):
return Candidate.country == country

engine = create_engine('sqlite:///candidates.db')
llm = OpenAIClient(model_name="gpt-3.5-turbo")
llm = LiteLLM(model_name="gpt-3.5-turbo")
my_collection = create_collection("collection_name", llm)
my_collection.add(CandidateView, lambda: CandidateView(engine))

Expand Down Expand Up @@ -82,12 +82,12 @@ pip install dbally

Additionally, you can install one of our extensions to use specific features.

* `dbally[openai]`: Use [OpenAI's models](https://platform.openai.com/docs/models)
* `dbally[litellm]`: Use [100+ LLMs](https://docs.litellm.ai/docs/providers)
* `dbally[faiss]`: Use [Faiss](https://github.com/facebookresearch/faiss) indexes for similarity search
* `dbally[langsmith]`: Use [LangSmith](https://www.langchain.com/langsmith) for query tracking

```bash
pip install dbally[openai,faiss,langsmith]
pip install dbally[litellm,faiss,langsmith]
```

## License
Expand Down
6 changes: 3 additions & 3 deletions benchmark/dbally_benchmark/e2e_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import dbally
from dbally.collection import Collection
from dbally.iql_generator.iql_prompt_template import default_iql_template
from dbally.llm_client.openai_client import OpenAIClient
from dbally.llms.litellm import LiteLLM
from dbally.utils.errors import NoViewFoundError, UnsupportedQueryError
from dbally.view_selection.view_selector_prompt_template import default_view_selector_template

Expand Down Expand Up @@ -82,12 +82,12 @@ async def evaluate(cfg: DictConfig) -> Any:

engine = create_engine(benchmark_cfg.pg_connection_string + f"/{cfg.db_name}")

llm_client = OpenAIClient(
llm = LiteLLM(
model_name="gpt-4",
api_key=benchmark_cfg.openai_api_key,
)

db = dbally.create_collection(cfg.db_name, llm_client)
db = dbally.create_collection(cfg.db_name, llm)

for view_name in cfg.view_names:
view = VIEW_REGISTRY[ViewName(view_name)]
Expand Down
6 changes: 3 additions & 3 deletions benchmark/dbally_benchmark/iql_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from dbally.audit.event_tracker import EventTracker
from dbally.iql_generator.iql_generator import IQLGenerator
from dbally.iql_generator.iql_prompt_template import default_iql_template
from dbally.llm_client.openai_client import OpenAIClient
from dbally.llms.litellm import LiteLLM
from dbally.utils.errors import UnsupportedQueryError
from dbally.views.structured import BaseStructuredView

Expand Down Expand Up @@ -96,13 +96,13 @@ async def evaluate(cfg: DictConfig) -> Any:
view = VIEW_REGISTRY[ViewName(view_name)](engine)

if "gpt" in cfg.model_name:
llm_client = OpenAIClient(
llm = LiteLLM(
model_name=cfg.model_name,
api_key=benchmark_cfg.openai_api_key,
)
else:
raise ValueError("Only OpenAI's GPT models are supported for now.")
iql_generator = IQLGenerator(llm_client=llm_client)
iql_generator = IQLGenerator(llm=llm)

run = None
if cfg.neptune.log:
Expand Down
19 changes: 8 additions & 11 deletions benchmark/dbally_benchmark/text2sql_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
from sqlalchemy import create_engine

from dbally.audit.event_tracker import EventTracker
from dbally.llm_client.base import LLMClient
from dbally.llm_client.openai_client import OpenAIClient
from dbally.llms.litellm import LiteLLM


def _load_db_schema(db_name: str, encoding: Optional[str] = None) -> str:
Expand All @@ -35,12 +34,12 @@ def _load_db_schema(db_name: str, encoding: Optional[str] = None) -> str:
return db_schema


async def _run_text2sql_for_single_example(example: BIRDExample, llm_client: LLMClient) -> Text2SQLResult:
async def _run_text2sql_for_single_example(example: BIRDExample, llm: LiteLLM) -> Text2SQLResult:
event_tracker = EventTracker()

db_schema = _load_db_schema(example.db_id)

response = await llm_client.text_generation(
response = await llm.generate_text(
TEXT2SQL_PROMPT_TEMPLATE, {"schema": db_schema, "question": example.question}, event_tracker=event_tracker
)

Expand All @@ -49,13 +48,13 @@ async def _run_text2sql_for_single_example(example: BIRDExample, llm_client: LLM
)


async def run_text2sql_for_dataset(dataset: BIRDDataset, llm_client: LLMClient) -> List[Text2SQLResult]:
async def run_text2sql_for_dataset(dataset: BIRDDataset, llm: LiteLLM) -> List[Text2SQLResult]:
"""
Transforms questions into SQL queries using a Text2SQL model.
Args:
dataset: The dataset containing questions to be transformed into SQL queries.
llm_client: LLM client.
llm: LLM client.
Returns:
A list of Text2SQLResult objects representing the predictions.
Expand All @@ -64,9 +63,7 @@ async def run_text2sql_for_dataset(dataset: BIRDDataset, llm_client: LLMClient)
results: List[Text2SQLResult] = []

for group in batch(dataset, 5):
current_results = await asyncio.gather(
*[_run_text2sql_for_single_example(example, llm_client) for example in group]
)
current_results = await asyncio.gather(*[_run_text2sql_for_single_example(example, llm) for example in group])
results = [*current_results, *results]

return results
Expand All @@ -88,7 +85,7 @@ async def evaluate(cfg: DictConfig) -> Any:
engine = create_engine(benchmark_cfg.pg_connection_string + f"/{cfg.db_name}")

if "gpt" in cfg.model_name:
llm_client = OpenAIClient(
llm = LiteLLM(
model_name=cfg.model_name,
api_key=benchmark_cfg.openai_api_key,
)
Expand All @@ -112,7 +109,7 @@ async def evaluate(cfg: DictConfig) -> Any:
evaluation_dataset = BIRDDataset.from_json_file(
Path(cfg.dataset_path), difficulty_levels=cfg.get("difficulty_levels")
)
text2sql_results = await run_text2sql_for_dataset(dataset=evaluation_dataset, llm_client=llm_client)
text2sql_results = await run_text2sql_for_dataset(dataset=evaluation_dataset, llm=llm)

with open(output_dir / results_file_name, "w", encoding="utf-8") as outfile:
json.dump([result.model_dump() for result in text2sql_results], outfile, indent=4)
Expand Down
4 changes: 2 additions & 2 deletions docs/concepts/collections.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
At its core, a collection groups together multiple [views](views.md). Once you've defined your views, the next step is to register them within a collection. Here's how you might do it:

```python
my_collection = dbally.create_collection("collection_name", llm_client=OpenAIClient())
my_collection = dbally.create_collection("collection_name", llm=LiteLLM())
my_collection.add(ExampleView)
my_collection.add(RecipesView)
```

Sometimes, view classes might need certain arguments when they're instantiated. In these instances, you'll want to register your view with a builder function that takes care of supplying these arguments. For instance, with views that rely on SQLAlchemy, you'll typically need to pass a database engine object like so:

```python
my_collection = dbally.create_collection("collection_name", llm_client=OpenAIClient())
my_collection = dbally.create_collection("collection_name", llm=LiteLLM())
engine = sqlalchemy.create_engine("sqlite://")
my_collection.add(ExampleView, lambda: ExampleView(engine))
my_collection.add(RecipesView, lambda: RecipesView(engine))
Expand Down
2 changes: 1 addition & 1 deletion docs/concepts/freeform_views.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

Freeform views are a type of [view](views.md) that provides a way for developers using db-ally to define what they need from the LLM without requiring a fixed response structure. This flexibility is beneficial when the data structure is unknown beforehand or when potential queries are too diverse to be covered by a structured view. Though freeform views offer more flexibility than structured views, they are less predictable, efficient, and secure, and may be more challenging to integrate with other systems. For these reasons, we recommend using [structured views](./structured_views.md) when possible.

Unlike structured views, which define a response format and a set of operations the LLM may use in response to natural language queries, freeform views only have one task - to respond directly to natural language queries with data from the datasource. They accomplish this by implementing the [`ask`][dbally.views.base.BaseView] method. This method takes a natural language query as input and returns a response. The method also has access to the LLM model (via the `llm_client` attribute), which is typically used to retrieve the correct data from the source (for example, by generating a source-specific query string). To learn more about implementing freeform views, refer to the [How to: Custom Freeform Views](../how-to/custom_freeform_views.md) guide.
Unlike structured views, which define a response format and a set of operations the LLM may use in response to natural language queries, freeform views only have one task - to respond directly to natural language queries with data from the datasource. They accomplish this by implementing the [`ask`][dbally.views.base.BaseView] method. This method takes a natural language query as input and returns a response. The method also has access to the LLM model (via the `llm` attribute), which is typically used to retrieve the correct data from the source (for example, by generating a source-specific query string). To learn more about implementing freeform views, refer to the [How to: Custom Freeform Views](../how-to/custom_freeform_views.md) guide.

## Security

Expand Down
4 changes: 2 additions & 2 deletions docs/how-to/create_custom_event_handler.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,11 @@ To use our event handler, we need to pass it to the collection when creating it:

```python
import dbally
from dbally.llm_client.openai_client import OpenAIClient
from dbally.llms.litellm import LiteLLM

my_collection = bally.create_collection(
"collection_name",
llm_client=OpenAIClient(),
llm=LiteLLM(),
event_handlers=[FileEventHandler()],
)
```
Expand Down
4 changes: 2 additions & 2 deletions docs/how-to/custom_views.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,10 @@ Finally, we can use the `CandidatesView` just like any other view in db-ally. We
```python
import asyncio
import dbally
from dbally.llm_client.openai_client import OpenAIClient
from dbally.llms.litellm import LiteLLM

async def main():
llm = OpenAIClient(model_name="gpt-3.5-turbo")
llm = LiteLLM(model_name="gpt-3.5-turbo")
collection = dbally.create_collection("recruitment", llm)
collection.add(CandidateView)

Expand Down
4 changes: 2 additions & 2 deletions docs/how-to/custom_views_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
from dbally.iql import IQLQuery, syntax
from dbally.data_models.execution_result import ViewExecutionResult
from dbally.llm_client.openai_client import OpenAIClient
from dbally.llms.litellm import LiteLLM

@dataclass
class Candidate:
Expand Down Expand Up @@ -99,7 +99,7 @@ def from_country(self, country: str) -> Callable[[Candidate], bool]:
return lambda x: x.country == country

async def main():
llm = OpenAIClient(model_name="gpt-3.5-turbo")
llm = LiteLLM(model_name="gpt-3.5-turbo")
event_handlers = [CLIEventHandler()]
collection = dbally.create_collection("recruitment", llm, event_handlers=event_handlers)
collection.add(CandidateView)
Expand Down
2 changes: 1 addition & 1 deletion docs/how-to/log_runs_to_langsmith.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ from dbally.audit.event_handlers.langsmith_event_handler import LangSmithEventHa
my_collection = dbally.create_collection(
"collection_name",
llm_client=OpenAIClient(),
llm=LiteLLM(),
event_handlers=[LangSmithEventHandler(api_key="your_api_key")],
)
```
Expand Down
4 changes: 2 additions & 2 deletions docs/how-to/pandas_views.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ To use the view, you need to create a [Collection](../concepts/collections.md) a

```python
import dbally
from dbally.llm_client.openai_client import OpenAIClient
from dbally.llms.litellm import LiteLLM

llm = OpenAIClient(model_name="gpt-3.5-turbo")
llm = LiteLLM(model_name="gpt-3.5-turbo")
collection = dbally.create_collection("recruitment", llm)
collection.add(CandidateView, lambda: CandidateView(CANDIDATE_DATA))

Expand Down
4 changes: 2 additions & 2 deletions docs/how-to/pandas_views_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from dbally import decorators, DataFrameBaseView
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
from dbally.llm_client.openai_client import OpenAIClient
from dbally.llms.clients.litellm import LiteLLM


class CandidateView(DataFrameBaseView):
Expand Down Expand Up @@ -46,7 +46,7 @@ def senior_data_scientist_position(self) -> pd.Series:
])

async def main():
llm = OpenAIClient(model_name="gpt-3.5-turbo")
llm = LiteLLM(model_name="gpt-3.5-turbo")
collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()])
collection.add(CandidateView, lambda: CandidateView(CANDIDATE_DATA))

Expand Down
4 changes: 2 additions & 2 deletions docs/how-to/sql_views.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ engine = create_engine('sqlite:///candidates.db')
Once you have defined your view and created an engine, you can register the view with db-ally. You do this by creating a collection and adding the view to it:

```python
from dbally.llm_client.openai_client import OpenAIClient
from dbally.llms.litellm import LiteLLM

my_collection = dbally.create_collection("collection_name", llm_client=OpenAIClient())
my_collection = dbally.create_collection("collection_name", llm=LiteLLM())
my_collection.add(CandidateView, lambda: CandidateView(engine))
```

Expand Down
4 changes: 2 additions & 2 deletions docs/how-to/update_similarity_indexes.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ If you have a [collection](../concepts/collections.md) and want to update Simila

```python
from db_ally import create_collection
from db_ally.llm_client.openai_client import OpenAIClient
from db_ally.llms.litellm import LiteLLM

my_collection = create_collection("collection_name", llm_client=OpenAIClient())
my_collection = create_collection("collection_name", llm=LiteLLM())

# ... add views to the collection

Expand Down
8 changes: 4 additions & 4 deletions docs/quickstart/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ To install db-ally, execute the following command:
pip install dbally
```

Since we will be using OpenAI's GPT, you also need to install the `openai` extension:
Since we will be using OpenAI's GPT, you also need to install the `litellm` extension:

```bash
pip install dbally[openai]
pip install dbally[litellm]
```

## Database Configuration
Expand Down Expand Up @@ -104,9 +104,9 @@ By setting up these filters, you enable the LLM to fetch candidates while option
To use OpenAI's GPT, configure db-ally and provide your OpenAI API key:

```python
from dbally.llm_client.openai_client import OpenAIClient
from dbally.llms.litellm import LiteLLM

llm = OpenAIClient(model_name="gpt-3.5-turbo", api_key="...")
llm = LiteLLM(model_name="gpt-3.5-turbo", api_key="...")
```

Replace `...` with your OpenAI API key. Alternatively, you can set the `OPENAI_API_KEY` environment variable with your API key and omit the `api_key` parameter altogether.
Expand Down
4 changes: 2 additions & 2 deletions docs/quickstart/quickstart2_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex
from dbally.embedding_client.openai import OpenAiEmbeddingClient
from dbally.llm_client.openai_client import OpenAIClient
from dbally.llms.litellm import LiteLLM

engine = create_engine('sqlite:///candidates.db')

Expand Down Expand Up @@ -73,7 +73,7 @@ def from_country(self, country: Annotated[str, country_similarity]) -> sqlalchem
async def main():
await country_similarity.update()

llm = OpenAIClient(model_name="gpt-3.5-turbo")
llm = LiteLLM(model_name="gpt-3.5-turbo")
collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()])
collection.add(CandidateView, lambda: CandidateView(engine))

Expand Down
4 changes: 2 additions & 2 deletions docs/quickstart/quickstart3_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from dbally import decorators, SqlAlchemyBaseView, DataFrameBaseView, ExecutionResult
from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex
from dbally.embedding_client.openai import OpenAiEmbeddingClient
from dbally.llm_client.openai_client import OpenAIClient
from dbally.llms.litellm import LiteLLM

engine = create_engine('sqlite:///candidates.db')

Expand Down Expand Up @@ -122,7 +122,7 @@ def display_results(result: ExecutionResult):
async def main():
await country_similarity.update()

llm = OpenAIClient(model_name="gpt-3.5-turbo")
llm = LiteLLM(model_name="gpt-3.5-turbo")
collection = dbally.create_collection("recruitment", llm)
collection.add(CandidateView, lambda: CandidateView(engine))
collection.add(JobView, lambda: JobView(jobs_data))
Expand Down
4 changes: 2 additions & 2 deletions docs/quickstart/quickstart_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from dbally import decorators, SqlAlchemyBaseView
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
from dbally.llm_client.openai_client import OpenAIClient
from dbally.llms.litellm import LiteLLM


engine = create_engine('sqlite:///candidates.db')
Expand Down Expand Up @@ -54,7 +54,7 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement:
return Candidate.country == country

async def main():
llm = OpenAIClient(model_name="gpt-3.5-turbo")
llm = LiteLLM(model_name="gpt-3.5-turbo")

collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()])
collection.add(CandidateView, lambda: CandidateView(engine))
Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials/LangGraphXdbally.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
},
"outputs": [],
"source": [
"!pip install -U dbally[openai] langgraph langchain langchain_openai langchain_experimental dbally[langsmith]"
"!pip install -U dbally[litellm,langsmith] langgraph langchain langchain_openai langchain_experimental"
]
},
{
Expand Down Expand Up @@ -203,9 +203,9 @@
"outputs": [],
"source": [
"import dbally\n",
"from dbally.llm_client.openai_client import OpenAIClient\n",
"from dbally.llms.litellm import LiteLLM\n",
"\n",
"recruitment_db = dbally.create_collection(\"recruitment\", llm_client=OpenAIClient())\n",
"recruitment_db = dbally.create_collection(\"recruitment\", llm=LiteLLM())\n",
"recruitment_db.add(JobOfferView, lambda: JobOfferView(ENGINE))\n",
"recruitment_db.add(CandidateView, lambda: CandidateView(ENGINE))"
]
Expand Down
Loading

0 comments on commit c06b11d

Please sign in to comment.