# Text-to-SQL Demo

In [None]:
!git clone https://github.com/iuliabanu/BDNSV.git

In [None]:
%cd /content/BDNSV/TextToSql

In [None]:
!pip install -q -r requirements.txt

### Test databases

Load the test databases from the BIRD benchmark [dev set](https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip).

We will use the formula_1 and student_club databases. You can inspect the database descriptions for [formula_1](data/formula_1/) and
for [student_club](data/student_club).

Test some basic interactions with the databases using [sqlalchemy](https://www.sqlalchemy.org/) or `langchain_community.utilities.SQLDatabase`.
You can install all the requirements by running `pip install -r requirements.txt`.

In [None]:
from sqlalchemy import create_engine, inspect, text
from sqlalchemy.engine import Engine


# The path to SQLite file
db_path = "data/formula_1/formula_1.sqlite"

# SQLAlchemy engine to connect to the database file
engine: Engine = create_engine(f"sqlite:///{db_path}")

# Run a query using the engine
query = "SELECT * FROM drivers LIMIT 10;"

with engine.connect() as conn:
    result = conn.execute(text(query))
    rows = result.fetchall()

# Print results
for row in rows:
    print(row)

# Inspect schema
inspector = inspect(engine)
print(inspector.get_table_names())
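
Beyond table names, the inspector can also list the columns of each table, which is the kind of schema information a text-to-SQL prompt relies on. The sketch below is only an illustration: it uses a throwaway in-memory SQLite database with a hypothetical `demo_drivers` table (not the real formula_1 schema), so it runs without the BIRD files.

```python
from sqlalchemy import create_engine, inspect, text
from sqlalchemy.pool import StaticPool

# In-memory SQLite database. StaticPool keeps one shared connection, so the
# table created below stays visible to the inspector.
demo_engine = create_engine("sqlite://", poolclass=StaticPool)

# Hypothetical table used only for this illustration.
with demo_engine.begin() as conn:
    conn.execute(
        text(
            "CREATE TABLE demo_drivers ("
            "driverId INTEGER PRIMARY KEY, surname TEXT, nationality TEXT)"
        )
    )

# Map every table name to the list of its column names.
demo_inspector = inspect(demo_engine)
schema = {
    table: [column["name"] for column in demo_inspector.get_columns(table)]
    for table in demo_inspector.get_table_names()
}
print(schema)
```

The same dictionary comprehension should work with the `engine` created above for `formula_1.sqlite`.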



In [None]:
from langchain_community.utilities import SQLDatabase
db = SQLDatabase(engine=engine)

### Pydantic BaseModels

Pydantic **BaseModel**s are Python classes that define a structured way to handle, validate, and parse data. They are fundamental to the Pydantic library, a data validation and settings management tool that relies on Python type hints.

Essentially, a BaseModel allows you to define a schema for data, and Pydantic ensures that data that is passed to that model conforms to the specified types and constraints.

Key Features
- Type validation: Ensures that data matches the expected types.
- Automatic parsing: Converts input types (e.g., strings to integers).
- Serialization: Easily convert models to and from JSON/dictionaries.
- Error reporting: Gives detailed error messages when validation fails.
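
The features listed above can be sketched in a few lines. This is a minimal illustration that assumes Pydantic v2 (`model_dump`, `model_dump_json`); the `Driver` model and its fields are hypothetical and are not part of the pipeline.

```python
from pydantic import BaseModel, ValidationError


class Driver(BaseModel):
    """Hypothetical model, used only for this illustration."""

    driver_id: int
    surname: str
    nationality: str = "unknown"


# Automatic parsing: the numeric string "44" is coerced to the integer 44.
driver = Driver(driver_id="44", surname="Hamilton")
print(driver.driver_id, type(driver.driver_id).__name__)

# Serialization: convert the model to a dictionary and to a JSON string.
as_dict = driver.model_dump()
as_json = driver.model_dump_json()
print(as_dict)

# Type validation and error reporting: a non-numeric id is rejected, and the
# error names the offending field.
try:
    Driver(driver_id="not-a-number", surname="Hamilton")
except ValidationError as exc:
    first_error = exc.errors()[0]
    print(first_error["loc"], first_error["msg"])
```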

Use cases:
- FastAPI request/response models
- Configuration files
- Data validation in ETL pipelines
- Typed data structures in ML projects

We will define a `State` model that covers the basic functionality of a text-to-SQL pipeline: the natural-language question, the query generated by the LLM, the result of executing the query, and the final answer returned to the user.

In [None]:
from pydantic import BaseModel

class State(BaseModel):
    question: str
    query: str = ""
    result: str = ""
    answer: str = ""

### Groq

[Groq](https://groq.com/) focuses on AI inference acceleration. Its solutions are based on the Language Processing Unit (LPU), a new category of processor built specifically for running AI models fast and efficiently. The GroqCard accelerator is used in both cloud and on-premise setups. Groq Cloud is a cloud-hosted AI inference platform powered by LPUs. It is designed to serve generative AI models across text, audio, and vision with minimal latency and high throughput.

### LangChain

[LangChain](https://www.langchain.com/) is an open-source framework designed to build applications using Large Language Models (LLMs) in a modular and scalable way. It simplifies the creation of chains—structured sequences of operations involving LLMs, tools, and data sources.

[Core concepts](https://python.langchain.com/docs/concepts/)

- **Chains**: Sequences of steps (e.g., prompt → LLM → output → tool) that define how data flows through an application.
- **Agents**: Use a language model to choose a sequence of actions to take. Agents can interact with external resources via tools.
- **Tools**: A function (e.g., search engines, calculators, databases) with an associated schema defining the function's name, description, and the arguments it accepts.
- **Retriever**: A component that returns relevant documents from a knowledge base in response to a query.
- **Memory**: Information about a conversation that is persisted so that it can be used in future conversations.

Use cases
- Text-to-SQL and database querying
- Chatbots and virtual assistants
- Document Q&A and summarization
- Code generation and debugging
- Autonomous agents and workflows

In [None]:
import os
from dotenv import load_dotenv
load_dotenv()

if not os.environ.get("GROQ_API_KEY"):
    print("API key for Groq is not set. Please check your .env file.")
else:
    print("API key loaded successfully.")


from langchain.chat_models import init_chat_model

llm = init_chat_model("llama-3.1-8b-instant", model_provider="groq")


**Prompt templates** help to translate user input and parameters into instructions for a language model. This can be used to guide a model's response, helping it understand the context and generate relevant and coherent language-based output.

**Tool calling** Many AI applications interact directly with humans. In these cases, it is appropriate for models to respond in natural language. But what about cases where we want a model to also interact directly with systems, such as databases or an API? These systems often have a particular input schema; for example, APIs frequently have a required payload structure. This need motivates the concept of tool calling. You can use tool calling to request model responses that match a particular schema.
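
To make "a particular schema" concrete, the sketch below prints the JSON schema that a Pydantic model exposes and validates a set of arguments against it. It assumes Pydantic v2 and mirrors the `QueryOutput` model used in this notebook; the `tool_args` dictionary is a hand-written stand-in for what a model might return, not real LLM output.

```python
from pydantic import BaseModel, Field, ValidationError


class QueryOutput(BaseModel):
    """Generated SQL query."""

    query: str = Field(description="Syntactically valid SQL query.")


# The JSON schema describes the arguments the model is asked to produce.
schema = QueryOutput.model_json_schema()
print(schema["properties"]["query"])

# Hand-written stand-in for the arguments of a tool call.
tool_args = {"query": "SELECT COUNT(*) FROM drivers;"}
parsed = QueryOutput.model_validate(tool_args)
print(parsed.query)

# Arguments that do not match the schema are rejected: the required
# "query" field is missing here.
try:
    QueryOutput.model_validate({"sql": "SELECT 1;"})
    schema_enforced = False
except ValidationError:
    schema_enforced = True
print(schema_enforced)
```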

In [None]:
from langchain_core.prompts import ChatPromptTemplate

system_message = """
Given an input question, create a syntactically correct {dialect} query to
run to help find the answer. Unless the user specifies in their question a
specific number of examples they wish to obtain, always limit your query to
at most {top_k} results. You can order the results by a relevant column to
return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for the
few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema
description. Be careful to not query for columns that do not exist. Also,
pay attention to which column is in which table.

Only use the following tables:
{table_info}
"""

user_prompt = "Question: {input}"

query_prompt_template = ChatPromptTemplate(
    [("system", system_message), ("user", user_prompt)]
)

for message in query_prompt_template.messages:
    message.pretty_print()

In [None]:
from pydantic import Field

class QueryOutput(BaseModel):
    """Generated SQL query."""

    query: str = Field(description="Syntactically valid SQL query.")


def write_query(state: State):
    """Generate SQL query to fetch information."""

    prompt = query_prompt_template.invoke(
        {
            "dialect": db.dialect,
            "top_k": 1,
            "table_info": db.get_table_info(),
            "input": state.question
        }
    )

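    # To inspect the raw response instead, call llm.invoke(prompt) and print it.

    # Bind QueryOutput as a tool so the model returns its answer as
    # structured arguments instead of free text.
    llm_with_tools = llm.bind_tools([QueryOutput])
    msg = llm_with_tools.invoke(prompt)
    if not msg.tool_calls:
        raise ValueError("The model did not return a tool call with a SQL query.")

    args = msg.tool_calls[0]["args"]
    # Also update the state in place so the function can be tested on its
    # own, outside the graph.
    state.query = args["query"]
    return args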



In [None]:
state = State(question="How many drivers are there?")
write_query(state)

In [None]:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool, QuerySQLCheckerTool

def execute_query(state: State):
    """Execute SQL query."""
    checker_tool = QuerySQLCheckerTool(db=db, llm=llm)
    check_result = checker_tool.invoke(state.query)
    if "error" in check_result.lower():
        return {"result": f"Invalid SQL query: {check_result}"}

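    execute_query_tool = QuerySQLDatabaseTool(db=db)
    exec_result = execute_query_tool.invoke(state.query)
    # Update the state in place for standalone calls, and reuse the same
    # result for the returned update instead of running the query twice.
    state.result = exec_result
    return {"result": exec_result}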

In [None]:
execute_query(state)

In [None]:
def generate_answer(state: State):
    """Answer question using retrieved information as context."""
    prompt = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
        f"Question: {state.question}\n"
        f"SQL Query: {state.query}\n"
        f"SQL Result: {state.result}"
    )
    print(f"prompt: {prompt}")

    response = llm.invoke(prompt)
    return {"answer": response.content}

In [None]:
generate_answer(state)

A StateGraph object defines the structure of an agent as a "state machine". Each node can receive the current State as input and output an update to the state.
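
The idea that each node receives the state and returns an update can be sketched in plain Python. This is not how LangGraph is implemented; `SketchState`, the `fake_*` nodes, and `run_sequence` are hypothetical names, and the values the nodes return are canned placeholders rather than real LLM or database output.

```python
from dataclasses import dataclass, replace
from typing import Callable, Dict, List


@dataclass(frozen=True)
class SketchState:
    """Hypothetical stand-in for the notebook's State model."""

    question: str
    query: str = ""
    result: str = ""
    answer: str = ""


def fake_write_query(state: SketchState) -> Dict[str, str]:
    return {"query": "SELECT COUNT(*) FROM drivers;"}


def fake_execute_query(state: SketchState) -> Dict[str, str]:
    # Canned placeholder value, not a real database result.
    return {"result": "[(3,)]"}


def fake_generate_answer(state: SketchState) -> Dict[str, str]:
    return {"answer": f"Result of '{state.query}' is {state.result}"}


def run_sequence(
    state: SketchState, nodes: List[Callable[[SketchState], Dict[str, str]]]
) -> SketchState:
    """Run the nodes in order, merging each partial update into the state."""
    for node in nodes:
        update = node(state)
        state = replace(state, **update)
    return state


final_state = run_sequence(
    SketchState(question="How many drivers are there?"),
    [fake_write_query, fake_execute_query, fake_generate_answer],
)
print(final_state)
```

Each node returns only the fields it changes, which is why the functions in this notebook return small dictionaries such as `{"result": ...}`.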

In [None]:
from langgraph.graph import START, StateGraph

graph_builder = StateGraph(State).add_sequence(
    [write_query, execute_query, generate_answer]
)
graph_builder.add_edge(START, "write_query")
graph = graph_builder.compile()

In [None]:
from IPython.display import Image, display

display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
initial_state = State(question="How many drivers are there?")

# Stream the per-node state updates for the question stored in initial_state.
for step in graph.stream({"question": initial_state.question}, stream_mode="updates"):
    print(step)


In [None]:
import json

# Path to the JSON file with the BIRD dev questions
questions_path = "data/dev.json"

with open(questions_path, "r", encoding="utf-8") as f:
    data = json.load(f)  # this should be a list of dicts


def filter_questions(data, db_id, difficulty):
    return [
        (item["question"], item["SQL"])
        for item in data
        if item.get("db_id") == db_id and item.get("difficulty") == difficulty
    ]


# Example usage:
questions = filter_questions(data, db_id="formula_1", difficulty="simple")
print(f"questions: {questions}")

In [None]:
class Eval(BaseModel):
    question: str
    query_ground_truth: str
    query_llm: str
    score_ea: int
    score_cm: int

In [None]:
def make_eval_list(filtered_list, max_questions, graph):
    evals = []
    skip = []

    for i, (question, ground_truth) in enumerate(filtered_list):
        if i >= max_questions:
            break
        try:
            state = State(question=question)

            result = graph.invoke(state)
            llm_query = result.get("query", "")

            evals.append(
                Eval(
                    question=question,
                    query_ground_truth=ground_truth,
                    query_llm=llm_query,
                    score_ea=0,
                    score_cm=0
                )
            )
        except Exception as e:
            print(f"Skipping question due to error: {e}")
            skip.append(
                Eval(question=question,
                     query_ground_truth=ground_truth,
                     query_llm="",
                     score_ea=0,
                     score_cm=0
                     )
            )
            continue
    return evals
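
The scores in `Eval` are left at 0 above. As one possible way to fill them in, the sketch below assumes that `score_ea` stands for execution accuracy and scores a pair of queries as 1 when both return the same set of rows. The helper `execution_match` and the toy `drivers` table are hypothetical; it uses the standard-library `sqlite3` module on an in-memory database, so it runs without the BIRD files.

```python
import sqlite3


def execution_match(
    conn: sqlite3.Connection, query_ground_truth: str, query_llm: str
) -> int:
    """Return 1 if both queries yield the same set of rows, else 0."""
    try:
        expected = conn.execute(query_ground_truth).fetchall()
        predicted = conn.execute(query_llm).fetchall()
    except sqlite3.Error:
        # A query that fails to execute counts as a mismatch.
        return 0
    # Compare as sets, so row order and duplicate rows are ignored.
    return int(set(expected) == set(predicted))


# Toy database used only for this illustration.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE drivers (driverId INTEGER PRIMARY KEY, surname TEXT)")
conn.executemany(
    "INSERT INTO drivers (surname) VALUES (?)", [("A",), ("B",), ("C",)]
)

ground_truth = "SELECT COUNT(*) FROM drivers"
same = execution_match(conn, ground_truth, "SELECT COUNT(driverId) FROM drivers")
different = execution_match(conn, ground_truth, "SELECT surname FROM drivers")
broken = execution_match(conn, ground_truth, "SELECT nope FROM drivers")
print(same, different, broken)
```

For the real data, the connection would point at the SQLite file of the database under test, and the returned value could be stored in the `score_ea` field of each `Eval`.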