# Natural Language SQL Database Query

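Used LangChain to implement RAG over a SQL database: an agent turns a natural-language question into SQL, runs it through a tool, and answers from the rows returned.
Picked a model (`ollama:qwen3-coder:30b`) that handles code and tool calling well and runs locally.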
([Model Info](../../model_info.html))

## Test Program
`sql_db_test_2.py`

### Steps
* Download `Chinook.db` (a sample music-store database).
* Write a system prompt.
* Create an agent backed by a tool-capable model.
* Pass a runtime context that holds a reference to the database.
* Keep conversation state in an in-memory checkpointer (`InMemorySaver`).


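In [None]:
# Imports assumed for LangChain v1 / LangGraph; adjust to the installed versions.
import pathlib
from dataclasses import dataclass

import requests
from langchain.agents import create_agent
from langchain_community.utilities import SQLDatabase
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.runtime import get_runtime

In [None]: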
def download_chinook_db():
    """Download the Chinook sample database if it doesn't already exist.""" 
    url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"
    local_path = pathlib.Path("Chinook.db")

    if local_path.exists():
        print(f"{local_path} already exists, skipping download.")
    else:
        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            local_path.write_bytes(response.content)
            print(f"File downloaded and saved as {local_path}")
        else:
            print(f"Failed to download the file. Status code: {response.status_code}")
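
The download step writes whatever bytes the server returns, so an error page served with status 200 would still be saved as `Chinook.db`. A minimal sketch of a sanity check, assuming it is enough to compare the first 16 bytes with the SQLite 3 header string; `looks_like_sqlite` and the temporary file names are hypothetical helpers, not part of the original program:

```python
import pathlib
import sqlite3
import tempfile

SQLITE_MAGIC = b"SQLite format 3\x00"  # first 16 bytes of every SQLite 3 file


def looks_like_sqlite(path: pathlib.Path) -> bool:
    """Return True if the file exists and starts with the SQLite 3 header."""
    if not path.is_file():
        return False
    with path.open("rb") as fh:
        return fh.read(len(SQLITE_MAGIC)) == SQLITE_MAGIC


# Usage on throwaway files (hypothetical names, not the real Chinook.db).
tmp_dir = pathlib.Path(tempfile.mkdtemp())

good = tmp_dir / "good.db"
conn = sqlite3.connect(good)
conn.execute("CREATE TABLE t (x INTEGER)")
conn.commit()
conn.close()

bad = tmp_dir / "bad.db"
bad.write_bytes(b"<html>Not Found</html>")

print(looks_like_sqlite(good))  # True
print(looks_like_sqlite(bad))   # False
```

Calling a check like this right after `local_path.write_bytes(...)` would catch a bad download before the agent ever tries to query it.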

In [None]:
def get_chinook_db() -> SQLDatabase:
    """Connect to the Chinook database and return the SQLDatabase object."""
    db = SQLDatabase.from_uri("sqlite:///Chinook.db")

    print(f"Dialect: {db.dialect}")
    print(f"Available tables: {db.get_usable_table_names()}")
    print(f'Sample output: {db.run("SELECT * FROM Artist LIMIT 5;")}')
    return db

In [None]:
SYS_PROMPT_2 = """
You are a careful SQLite analyst.

Rules:
- Think step by step.
- When you need data, call the tool 'exe_query' with one SELECT statement.
- Read-only. No INSERT/UPDATE/DELETE/ALTER/DROP/CREATE/REPLACE/TRUNCATE statements allowed.
- Limit to 5 rows of output unless explicitly asked for more.
- If the tool returns a message starting with 'Error', revise the query and try again.
- Prefer explicit column lists; avoid SELECT *.
"""

In [None]:
@dataclass
class RuntimeContext:
    db: SQLDatabase
    is_employee: bool = True

In [None]:
@tool
def exe_query(query: str) -> str:
    """Execute a SQL query against the Chinook database."""
    runtime = get_runtime(RuntimeContext)
    db = runtime.context.db
    try:
        return db.run(query)
    except Exception as e:
        return f"Error executing query: {e}"
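
The read-only rule lives only in the system prompt, so the tool itself will run whatever SQL the model sends. A minimal sketch of a guard that could run before `db.run(query)`, assuming a keyword heuristic is acceptable; `is_read_only` is a hypothetical helper, and it is deliberately strict (it also rejects the `REPLACE()` string function and any semicolon inside a string literal):

```python
import re

# Keywords the system prompt forbids; matched as whole words, case-insensitively.
FORBIDDEN = re.compile(
    r"\b(INSERT|UPDATE|DELETE|ALTER|DROP|CREATE|REPLACE|TRUNCATE)\b",
    re.IGNORECASE,
)


def is_read_only(query: str) -> bool:
    """Heuristic: one statement, starting with SELECT or WITH, no write keywords."""
    stripped = query.strip().rstrip(";").strip()
    if not stripped or ";" in stripped:
        return False  # empty, or more than one statement
    if not re.match(r"(?i)(SELECT|WITH)\b", stripped):
        return False
    return FORBIDDEN.search(stripped) is None


print(is_read_only("SELECT Name FROM Artist LIMIT 5;"))  # True
print(is_read_only("SELECT 1; DROP TABLE Artist"))       # False
```

Inside the tool, a rejected query could return a string starting with `Error` so the agent revises it, the same way it reacts to a failed query. Opening the SQLite file in read-only mode would be a stronger guarantee than any keyword check.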

In [None]:
LLM = "ollama:qwen3-coder:30b"


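def query_agent(question: str) -> str:
    """Query the agent with the provided question and return its final answer."""
    agent = create_agent(
        model=LLM,
        tools=[exe_query],
        system_prompt=SYS_PROMPT_2,
        context_schema=RuntimeContext,
        # Saves agent state per thread_id. A new saver is created on every call,
        # so history does not carry over between query_agent() calls.
        checkpointer=InMemorySaver(),
    )
    steps = []
    for step in agent.stream(
        {"messages": question},
        {"configurable": {"thread_id": "thread_1"}},
        # `db` is the module-level SQLDatabase returned by get_chinook_db().
        context=RuntimeContext(db=db),
        stream_mode="values",
    ):
        # Print each intermediate message (human, AI, tool) as it arrives.
        step["messages"][-1].pretty_print()
        steps.append(step)
    # The last streamed state holds the final AI message.
    return steps[-1]["messages"][-1].content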


In [None]:
download_chinook_db()
db = get_chinook_db()  # module-level handle read by query_agent()
query_agent("Which genre on average has the longest tracks?")
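
A question like this typically reduces to a join plus a grouped average. The sketch below shows that query shape on a toy in-memory database using only the standard-library `sqlite3` module; the table contents are made up for illustration and are not Chinook data:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE Genre (GenreId INTEGER PRIMARY KEY, Name TEXT);
    CREATE TABLE Track (TrackId INTEGER PRIMARY KEY, GenreId INTEGER, Milliseconds INTEGER);
    INSERT INTO Genre VALUES (1, 'Rock'), (2, 'Jazz');
    INSERT INTO Track VALUES (1, 1, 200000), (2, 1, 300000), (3, 2, 400000);
    """
)

# Average track length per genre, longest first, top row only.
QUERY = """
SELECT g.Name AS Genre, AVG(t.Milliseconds) AS AverageDuration
FROM Genre g
JOIN Track t ON g.GenreId = t.GenreId
GROUP BY g.Name
ORDER BY AverageDuration DESC
LIMIT 1;
"""

rows = conn.execute(QUERY).fetchall()
print(rows)  # [('Jazz', 400000.0)]
conn.close()
```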

## Output

File downloaded and saved as Chinook.db
Dialect: sqlite
Available tables: ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
Sample output: [(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains')]
================================ Human Message =================================

Which genre on average has the longest tracks?
================================== Ai Message ==================================

To determine which genre has the longest average track duration, I will follow these steps:

1. Join the Genre table with the Track table to associate each track with its genre.
2. Calculate the average duration (in milliseconds) for each genre.
3. Order the results by average duration in descending order.
4. Limit the output to the top result to identify the genre with the longest average track.

I will now execute the SQL query to find this information.
Tool Calls:
  exe_query (916b92b1-7427-480d-b6f0-890efb42fd5f)
 Call ID: 916b92b1-7427-480d-b6f0-890efb42fd5f
  Args:
    query: SELECT
    g.Name AS Genre,
    AVG(t.Milliseconds) AS AverageDuration
FROM
    Genre g
JOIN
    Track t ON g.GenreId = t.GenreId
GROUP BY
    g.Name
ORDER BY
    AverageDuration DESC
LIMIT 1;
================================= Tool Message =================================
Name: exe_query

[('Sci Fi & Fantasy', 2911783.0384615385)]
================================== Ai Message ==================================

The genre with the longest average track duration is "Sci Fi & Fantasy," with an average track length of approximately 2,911,783 milliseconds.
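
The answer is reported in milliseconds, which is hard to read at a glance. A small conversion of the value returned by the tool call, shown here as a worked example:

```python
avg_ms = 2911783.0384615385  # value returned by the tool call above

total_seconds = avg_ms / 1000
minutes, seconds = divmod(total_seconds, 60)
print(f"{int(minutes)} min {seconds:.1f} s")  # 48 min 31.8 s
```

So the average "Sci Fi & Fantasy" track runs about 48.5 minutes, which is plausible if the genre consists mostly of TV episodes rather than songs.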