diff --git a/examples/pipelines/rag/text_to_sql_pipeline.py b/examples/pipelines/rag/text_to_sql_pipeline.py
index 7baa4cb6..ab30e72b 100644
--- a/examples/pipelines/rag/text_to_sql_pipeline.py
+++ b/examples/pipelines/rag/text_to_sql_pipeline.py
@@ -10,6 +10,7 @@
 
 from typing import List, Union, Generator, Iterator
 import os
+from pydantic import BaseModel
 from llama_index.llms.ollama import Ollama
 from llama_index.core.query_engine import NLSQLTableQueryEngine
 from llama_index.core import SQLDatabase, PromptTemplate
@@ -17,23 +18,43 @@
 
 
 class Pipeline:
+    class Valves(BaseModel):
+        DB_HOST: str
+        DB_PORT: str
+        DB_USER: str
+        DB_PASSWORD: str
+        DB_DATABASE: str
+        DB_TABLES: list[str]
+        OLLAMA_HOST: str
+        TEXT_TO_SQL_MODEL: str
+
+
+    # Update the valves / environment variables based on your selected database
     def __init__(self):
-        self.PG_HOST = os.environ["PG_HOST"]
-        self.PG_PORT = os.environ["PG_PORT"]
-        self.PG_USER = os.environ["PG_USER"]
-        self.PG_PASSWORD = os.environ["PG_PASSWORD"]
-        self.PG_DB = os.environ["PG_DB"]
-        self.ollama_host = "http://host.docker.internal:11434"  # Make sure to update with the URL of your Ollama host, such at http://localhost:11434 or remote server address
-        self.model = "phi3:medium-128k"  # Model to use for text-to-SQL generation
+        self.name = "Database RAG Pipeline"
         self.engine = None
         self.nlsql_response = ""
-        self.tables = ["db_table"]  # Update to the name of the database table you want to get data from
+
+        # Initialize valves
+        self.valves = self.Valves(
+            **{
+                "pipelines": ["*"],  # Connect to all pipelines
+                "DB_HOST": os.environ["PG_HOST"],  # Database hostname
+                "DB_PORT": os.environ["PG_PORT"],  # Database port
+                "DB_USER": os.environ["PG_USER"],  # User to connect to the database with
+                "DB_PASSWORD": os.environ["PG_PASSWORD"],  # Password to connect to the database with
+                "DB_DATABASE": os.environ["PG_DB"],  # Database to select on the DB instance
+                "DB_TABLES": ["albums"],  # Table(s) to run queries against
+                "OLLAMA_HOST": "http://host.docker.internal:11434",  # Make sure to update with the URL of your Ollama host, such as http://localhost:11434 or a remote server address
+                "TEXT_TO_SQL_MODEL": "phi3:latest"  # Model to use for text-to-SQL generation
+            }
+        )
 
     def init_db_connection(self):
-        self.engine = create_engine(f"postgresql+psycopg2://{self.PG_USER}:{self.PG_PASSWORD}@{self.PG_HOST}:{self.PG_PORT}/{self.PG_DB}")
+        # Update your DB connection string based on your selected DB engine - the current connection string is for Postgres
+        self.engine = create_engine(f"postgresql+psycopg2://{self.valves.DB_USER}:{self.valves.DB_PASSWORD}@{self.valves.DB_HOST}:{self.valves.DB_PORT}/{self.valves.DB_DATABASE}")
         return self.engine
 
-
     async def on_startup(self):
         # This function is called when the server is started.
         self.init_db_connection()
@@ -48,10 +69,10 @@ def pipe(
         # Debug logging is required to see what SQL query is generated by the LlamaIndex library; enable on Pipelines server if needed
 
         # Create database reader for Postgres
-        sql_database = SQLDatabase(self.engine, include_tables=self.tables)
+        sql_database = SQLDatabase(self.engine, include_tables=self.valves.DB_TABLES)
 
         # Set up LLM connection; uses phi3 model with 128k context limit since some queries have returned 20k+ tokens
-        llm = Ollama(model=self.model, base_url=self.ollama_host, request_timeout=180.0, context_window=30000)
+        llm = Ollama(model=self.valves.TEXT_TO_SQL_MODEL, base_url=self.valves.OLLAMA_HOST, request_timeout=180.0, context_window=30000)
 
         # Set up the custom prompt used when generating SQL queries from text
         text_to_sql_prompt = """
@@ -78,7 +99,7 @@ def pipe(
 
         query_engine = NLSQLTableQueryEngine(
             sql_database=sql_database,
-            tables=self.tables,
+            tables=self.valves.DB_TABLES,
             llm=llm,
             embed_model="local",
             text_to_sql_prompt=text_to_sql_template,
@@ -88,4 +109,3 @@ def pipe(
         response = query_engine.query(user_message)
 
         return response.response_gen
-
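
Note for anyone adapting this example to another database engine: the only engine-specific piece is the SQLAlchemy URL built in init_db_connection(). Below is a minimal sketch of the same connection pointed at MySQL instead of Postgres; the pymysql driver, the hard-coded credentials, and the SELECT 1 smoke test are illustrative assumptions, not part of this PR.

from sqlalchemy import create_engine, text

# Stand-ins for the pipeline's valves; in the pipeline these come from
# self.valves (hypothetical values shown here for illustration only)
DB_USER, DB_PASSWORD = "demo", "demo"
DB_HOST, DB_PORT, DB_DATABASE = "localhost", "3306", "chinook"

# The pipeline ships with the Postgres form of the URL:
#   postgresql+psycopg2://user:password@host:port/database
# The MySQL equivalent only swaps the dialect+driver prefix (requires pymysql):
engine = create_engine(
    f"mysql+pymysql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_DATABASE}"
)

# Quick connectivity check before wiring the engine into the pipeline
with engine.connect() as conn:
    print(conn.execute(text("SELECT 1")).scalar())

Any other SQLAlchemy-supported engine works the same way, provided the matching driver package is added to the pipeline's requirements.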