- [How to connect LLM to SQL database with LlamaIndex | by Dishen Wang | Dataherald | Medium](https://medium.com/dataherald/how-to-connect-llm-to-sql-database-with-llamaindex-fae0e54de97c)

SQLite

- [sqlite3 — DB-API 2.0 interface for SQLite databases — Python 3.12.1 documentation](https://docs.python.org/3/library/sqlite3.html)

SQLAlchemy

- [Working with Engines and Connections — SQLAlchemy 2.0 Documentation](https://docs.sqlalchemy.org/en/20/core/connections.html)
- [Connections / Engines — SQLAlchemy 2.0 Documentation](https://docs.sqlalchemy.org/en/20/faq/connections.html)
- [pandas.DataFrame.to_sql — pandas 2.1.4 documentation](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_sql.html)

Pandas

- [pandas.DataFrame.to_sql — pandas 2.1.4 documentation](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_sql.html)

In [22]:
from sqlalchemy import create_engine, MetaData
import sqlite3
import pandas as pd
import duckdb

In [2]:
(sqlite_connect := sqlite3.connect(":memory:"))

<sqlite3.Connection at 0x233b47f45e0>

In [3]:
(db_engine := create_engine("sqlite:///:memory:"))

Engine(sqlite:///:memory:)

In [4]:
(df := pd.read_csv('../demo/demo.csv'))

Unnamed: 0,Trade #,Type,Signal,Date/Time,Price USD,Contracts,Profit USD,Profit %,Cum. Profit USD,Cum. Profit %,Run-up USD,Run-up %,Drawdown USD,Drawdown %
0,23,Exit Long,Open,,,,,,,,,,,
1,23,Entry Long,"{'take_profit': 1.0703532, 'stop_loss': 1.0684...",2023-11-13 16:30,1.06932,,,,,,,,,
2,22,Exit Long,exit,2023-11-10 22:50,1.06728,1852000.0,-2000.16,-0.1,9729.93,-0.2,203.72,0.01,2000.16,0.1
3,22,Entry Long,"{'take_profit': 1.0696884, 'stop_loss': 1.0672...",2023-11-10 22:45,1.06836,1852000.0,-2000.16,-0.1,9729.93,-0.2,203.72,0.01,2000.16,0.1
4,21,Exit Long,exit,2023-11-10 17:45,1.06769,3390000.0,2474.7,0.07,11730.09,0.25,2474.7,0.07,271.2,0.01
5,21,Entry Long,"{'take_profit': 1.0676857, 'stop_loss': 1.0663...",2023-11-10 17:40,1.06696,3390000.0,2474.7,0.07,11730.09,0.25,2474.7,0.07,271.2,0.01
6,20,Exit Short,exit,2023-11-08 22:30,1.06803,2312000.0,-2011.44,-0.08,9255.39,-0.2,138.72,0.01,2011.44,0.08
7,20,Entry Short,"{'take_profit': 1.06609605, 'stop_loss': 1.068...",2023-11-08 22:25,1.06716,2312000.0,-2011.44,-0.08,9255.39,-0.2,138.72,0.01,2011.44,0.08
8,19,Exit Long,exit,2023-11-08 21:25,1.06646,3030000.0,-1999.8,-0.06,11266.83,-0.2,424.2,0.01,1999.8,0.06
9,19,Entry Long,"{'take_profit': 1.0679318, 'stop_loss': 1.0664...",2023-11-08 21:15,1.06712,3030000.0,-1999.8,-0.06,11266.83,-0.2,424.2,0.01,1999.8,0.06


In [5]:
# TODO: this is not adding data to SQLite
df.to_sql('tbl2', db_engine)

46

In [10]:
# LegacyCursorResult
db_engine.execute("SELECT * FROM tbl2").fetchall()[:3]

[(0, 23, 'Exit Long', 'Open', None, None, None, None, None, None, None, None, None, None, None),
 (1, 23, 'Entry Long', "{'take_profit': 1.0703532, 'stop_loss': 1.06848, 'lots': 23.81}", '2023-11-13 16:30', 1.06932, None, None, None, None, None, None, None, None, None),
 (2, 22, 'Exit Long', 'exit', '2023-11-10 22:50', 1.06728, 1852000.0, -2000.16, -0.1, 9729.93, -0.2, 203.72, 0.01, 2000.16, 0.1)]

In [11]:
# NOTE: don't use preseved keyword as table name like "table"
df.to_sql('tbl', sqlite_connect)

46

In [12]:
sqlite_connect.total_changes

46

In [14]:
sqlite_connect.execute('SELECT * FROM tbl').fetchall()[:3]

[(0,
  23,
  'Exit Long',
  'Open',
  None,
  None,
  None,
  None,
  None,
  None,
  None,
  None,
  None,
  None,
  None),
 (1,
  23,
  'Entry Long',
  "{'take_profit': 1.0703532, 'stop_loss': 1.06848, 'lots': 23.81}",
  '2023-11-13 16:30',
  1.06932,
  None,
  None,
  None,
  None,
  None,
  None,
  None,
  None,
  None),
 (2,
  22,
  'Exit Long',
  'exit',
  '2023-11-10 22:50',
  1.06728,
  1852000.0,
  -2000.16,
  -0.1,
  9729.93,
  -0.2,
  203.72,
  0.01,
  2000.16,
  0.1)]

In [44]:
# (cursor := sqlite_connect.cursor())

<sqlite3.Cursor at 0x204d2cc4ac0>

In [15]:
# cursor.execute('SELECT * FROM tbl').fetchall()

## LlamaIndex

In [16]:
from llama_index import LLMPredictor, ServiceContext, SQLDatabase, VectorStoreIndex
from llama_index.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema

In [23]:
metadata_obj = MetaData()
metadata_obj.reflect(db_engine)

metadata_obj

MetaData()

In [18]:
(sql_database := SQLDatabase(db_engine))

<llama_index.utilities.sql_wrapper.SQLDatabase at 0x233d9eb9510>

In [19]:
(table_node_mapping := SQLTableNodeMapping(sql_database))

<llama_index.objects.table_node_mapping.SQLTableNodeMapping at 0x233d9eb8450>

In [24]:
table_schema_objs = []
for table_name in metadata_obj.tables.keys():
    table_schema_objs.append(SQLTableSchema(table_name=table_name))

table_schema_objs

[SQLTableSchema(table_name='tbl2', context_str=None)]

In [26]:
from dotenv import load_dotenv
import os
load_dotenv(os.path.join("../.env"))

True

In [39]:
# LLM
# https://docs.llamaindex.ai/en/stable/examples/customization/llms/AzureOpenAI.html
from llama_index.llms import AzureOpenAI

import logging
import sys

logging.basicConfig(
    stream=sys.stdout, level=logging.INFO
)  # logging.DEBUG for more verbose output
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

llm = AzureOpenAI(
    model="gpt-35-turbo",
    deployment_name=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"),
    api_key=os.getenv("AZURE_OPENAI_KEY"),
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
    api_version=os.getenv("AZURE_OPENAI_VERSION"),
)

llm

AzureOpenAI(callback_manager=<llama_index.callbacks.base.CallbackManager object at 0x00000233DCD2F610>, system_prompt=None, messages_to_prompt=<function messages_to_prompt at 0x00000233C2A38D60>, completion_to_prompt=<function default_completion_to_prompt at 0x00000233C2A80360>, output_parser=None, pydantic_program_mode=<PydanticProgramMode.DEFAULT: 'default'>, query_wrapper_prompt=None, model='gpt-35-turbo', temperature=0.1, max_tokens=None, additional_kwargs={}, max_retries=3, timeout=60.0, default_headers=None, reuse_client=True, api_key='2cef94e8d02545359f36dd45947308a3', api_base='https://api.openai.com/v1', api_version='2023-06-01-preview', engine='Streamlit', azure_endpoint='https://openaihzhou.openai.azure.com/', azure_deployment=None, use_azure_ad=False)

In [34]:
from llama_index import set_global_service_context

service_context = ServiceContext.from_defaults(
    llm=llm,
    # TODO: build azure embed_model
    embed_model=None
)

set_global_service_context(service_context)

Embeddings have been explicitly disabled. Using MockEmbedding.


In [38]:
# We dump the table schema information into a vector index. The vector index is stored within the context builder for future use.
obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)

obj_index

<llama_index.objects.base.ObjectIndex at 0x233dbc12a90>

In [41]:
(llm_predictor := LLMPredictor(llm=llm))

LLMPredictor(system_prompt=None, query_wrapper_prompt=None, pydantic_program_mode=<PydanticProgramMode.DEFAULT: 'default'>)

In [42]:
(service_context := ServiceContext.from_defaults(llm_predictor=llm_predictor))

ServiceContext(llm_predictor=LLMPredictor(system_prompt=None, query_wrapper_prompt=None, pydantic_program_mode=<PydanticProgramMode.DEFAULT: 'default'>), prompt_helper=PromptHelper(context_window=4096, num_output=256, chunk_overlap_ratio=0.1, chunk_size_limit=None, separator=' '), embed_model=MockEmbedding(model_name='unknown', embed_batch_size=10, callback_manager=<llama_index.callbacks.base.CallbackManager object at 0x00000233DCD0D010>, embed_dim=1), transformations=[SentenceSplitter(include_metadata=True, include_prev_next_rel=True, callback_manager=<llama_index.callbacks.base.CallbackManager object at 0x00000233DCD0D010>, id_func=<function default_id_func at 0x00000233D7EEC360>, chunk_size=1024, chunk_overlap=200, separator=' ', paragraph_separator='\n\n\n', secondary_chunking_regex='[^,.;。？！]+[,.;。？！]?')], llama_logger=<llama_index.logger.base.LlamaLogger object at 0x00000233C53E9490>, callback_manager=<llama_index.callbacks.base.CallbackManager object at 0x00000233DCD0D010>)

In [56]:
# We construct a SQLTableRetrieverQueryEngine. 
# Note that we pass in the ObjectRetriever so that we can dynamically retrieve the table during query-time.
# ObjectRetriever: A retriever that retrieves a set of query engine tools.
(query_engine := SQLTableRetrieverQueryEngine(
    sql_database,
    obj_index.as_retriever(similarity_top_k=1),
    service_context=service_context,
))

<llama_index.indices.struct_store.sql_query.SQLTableRetrieverQueryEngine at 0x233e21ff950>

In [49]:
(response := query_engine.query("What is the top 3 profit of long position?")) # This will failed

INFO:llama_index.indices.struct_store.sql_retriever:> Table desc str: Table 'tbl2' has columns: index (BIGINT), Trade # (BIGINT), Type (TEXT), Signal (TEXT), Date/Time (TEXT), Price USD (FLOAT), Contracts (FLOAT), Profit USD (FLOAT), Profit % (FLOAT), Cum. Profit USD (FLOAT), Cum. Profit % (FLOAT), Run-up USD (FLOAT), Run-up % (FLOAT), Drawdown USD (FLOAT), Drawdown % (FLOAT), and foreign keys: .
> Table desc str: Table 'tbl2' has columns: index (BIGINT), Trade # (BIGINT), Type (TEXT), Signal (TEXT), Date/Time (TEXT), Price USD (FLOAT), Contracts (FLOAT), Profit USD (FLOAT), Profit % (FLOAT), Cum. Profit USD (FLOAT), Cum. Profit % (FLOAT), Run-up USD (FLOAT), Run-up % (FLOAT), Drawdown USD (FLOAT), Drawdown % (FLOAT), and foreign keys: .
> Table desc str: Table 'tbl2' has columns: index (BIGINT), Trade # (BIGINT), Type (TEXT), Signal (TEXT), Date/Time (TEXT), Price USD (FLOAT), Contracts (FLOAT), Profit USD (FLOAT), Profit % (FLOAT), Cum. Profit USD (FLOAT), Cum. Profit % (FLOAT), Run-

Response(response='There are no results for the top 3 profit of long positions.', source_nodes=[NodeWithScore(node=TextNode(id_='ce5ef578-20b6-4da1-9ae3-83dc3e42b267', embedding=None, metadata={}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, hash='b584bc881e7d8c69f5a470d5ea86521f389f3c0c7d43c6ca08d050bda8c56891', text='[]', start_char_idx=None, end_char_idx=None, text_template='{metadata_str}\n\n{content}', metadata_template='{key}: {value}', metadata_seperator='\n'), score=None)], metadata={'ce5ef578-20b6-4da1-9ae3-83dc3e42b267': {}, 'sql_query': 'SELECT "Profit USD" \nFROM tbl2 \nWHERE Type = \'long\' \nORDER BY "Profit USD" DESC \nLIMIT 3', 'result': [], 'col_keys': ['Profit USD']})

In [50]:
print(response)
print(response.metadata['sql_query'])
print(response.metadata['result'])

There are no results for the top 3 profit of long positions.
SELECT "Profit USD" 
FROM tbl2 
WHERE Type = 'long' 
ORDER BY "Profit USD" DESC 
LIMIT 3
[]


In [None]:
(response := query_engine.query("What is the top 3 profit of long position (Exit Long)?"))

In [48]:
print(response)
print(response.metadata['sql_query'])
print(response.metadata['result'])

The top 3 profits for the "Exit Long" position are $2490.84, $2486.76, and $2485.59.
SELECT "Profit USD" 
FROM tbl2 
WHERE Type = "Exit Long" 
ORDER BY "Profit USD" DESC 
LIMIT 3
[(2490.84,), (2486.76,), (2485.59,)]


In [54]:
df[df["Type"] == "Exit Long"].sort_values('Profit USD', ascending=False).iloc[:3]

Unnamed: 0,Trade #,Type,Signal,Date/Time,Price USD,Contracts,Profit USD,Profit %,Cum. Profit USD,Cum. Profit %,Run-up USD,Run-up %,Drawdown USD,Drawdown %
24,11,Exit Long,exit,2023-10-16 19:55,1.05373,3774000.0,2490.84,0.06,4844.74,0.25,2490.84,0.06,679.32,0.02
26,10,Exit Long,exit,2023-10-16 16:45,1.05394,2703000.0,2486.76,0.09,2353.9,0.25,2486.76,0.09,1811.01,0.06
10,18,Exit Long,exit,2023-11-08 02:15,1.06874,2857000.0,2485.59,0.08,13266.63,0.25,2485.59,0.08,885.67,0.03


In [59]:
from llama_index.indices.struct_store import NLSQLTableQueryEngine

(nl_query_engine := NLSQLTableQueryEngine(sql_database,
    service_context=service_context,
))

<llama_index.indices.struct_store.sql_query.NLSQLTableQueryEngine at 0x233dcd27b50>

In [61]:
(response := nl_query_engine.query("What is the top 3 profit of long position (Exit Long)?"))

INFO:llama_index.indices.struct_store.sql_retriever:> Table desc str: Table 'tbl2' has columns: index (BIGINT), Trade # (BIGINT), Type (TEXT), Signal (TEXT), Date/Time (TEXT), Price USD (FLOAT), Contracts (FLOAT), Profit USD (FLOAT), Profit % (FLOAT), Cum. Profit USD (FLOAT), Cum. Profit % (FLOAT), Run-up USD (FLOAT), Run-up % (FLOAT), Drawdown USD (FLOAT), Drawdown % (FLOAT), and foreign keys: .
> Table desc str: Table 'tbl2' has columns: index (BIGINT), Trade # (BIGINT), Type (TEXT), Signal (TEXT), Date/Time (TEXT), Price USD (FLOAT), Contracts (FLOAT), Profit USD (FLOAT), Profit % (FLOAT), Cum. Profit USD (FLOAT), Cum. Profit % (FLOAT), Run-up USD (FLOAT), Run-up % (FLOAT), Drawdown USD (FLOAT), Drawdown % (FLOAT), and foreign keys: .
> Table desc str: Table 'tbl2' has columns: index (BIGINT), Trade # (BIGINT), Type (TEXT), Signal (TEXT), Date/Time (TEXT), Price USD (FLOAT), Contracts (FLOAT), Profit USD (FLOAT), Profit % (FLOAT), Cum. Profit USD (FLOAT), Cum. Profit % (FLOAT), Run-

Response(response='The top 3 profits for the "Exit Long" position are $2490.84, $2486.76, and $2485.59.', source_nodes=[NodeWithScore(node=TextNode(id_='050e006a-be10-404b-b2ff-0c9b7fc492cc', embedding=None, metadata={}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, hash='908b5ffc1671954e1dd931c7a164430eddb7a7e0f41ba8e3edcfca7369b0cf58', text='[(2490.84,), (2486.76,), (2485.59,)]', start_char_idx=None, end_char_idx=None, text_template='{metadata_str}\n\n{content}', metadata_template='{key}: {value}', metadata_seperator='\n'), score=None)], metadata={'050e006a-be10-404b-b2ff-0c9b7fc492cc': {}, 'sql_query': 'SELECT "Profit USD" \nFROM tbl2 \nWHERE Type = "Exit Long" \nORDER BY "Profit USD" DESC \nLIMIT 3', 'result': [(2490.84,), (2486.76,), (2485.59,)], 'col_keys': ['Profit USD']})

In [64]:
print(response)
print(response.metadata['sql_query'])

The top 3 profits for the "Exit Long" position are $2490.84, $2486.76, and $2485.59.
SELECT "Profit USD" 
FROM tbl2 
WHERE Type = "Exit Long" 
ORDER BY "Profit USD" DESC 
LIMIT 3
