In [1]:
!pip install llama-index pymysql -q

In [19]:
import logging
import sys

logging.basicConfig(stream=sys.stdout, level=logging.INFO, force=True)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

In [20]:
db_user = "root"
db_password = "abc"
db_host = "mysql"
db_name = "classicmodels" #sampleDB

In [21]:
from sqlalchemy import create_engine, text

'''
Let us create a new SQLAlchemy engine to communicate with our database 
SQLAlchemy is a data access library.
'''

# Construct the connection string
connection_string = f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}"

# Create an engine instance
engine = create_engine(connection_string)

# Test the connection using raw SQL
with engine.connect() as connection:
    result = connection.execute(text("select * from customers limit 3"))
    for row in result:
        print(row)

(103, 'Atelier graphique', 'Schmitt', 'Carine ', '40.32.2555', '54, rue Royale', None, 'Nantes', None, '44000', 'France', 1370, Decimal('21000.00'))
(112, 'Signal Gift Stores', 'King', 'Jean', '7025551838', '8489 Strong St.', None, 'Las Vegas', 'NV', '83030', 'USA', 1166, Decimal('71800.00'))
(114, 'Australian Collectors, Co.', 'Ferguson', 'Peter', '03 9520 4555', '636 St Kilda Road', 'Level 3', 'Melbourne', 'Victoria', '3004', 'Australia', 1611, Decimal('117300.00'))


In [22]:
dict
#Note: tabele Description"stores customer's data."(optional)
table_details = {
    "customers": "stores customer’s data.",
    "products": "stores a list of scale model cars.",
    "productlines": "stores a list of product line categories.",
    "orders": "stores sales orders placed by customers.",
    "orderdetails": "stores sales order line items for each sales order.",
    "payments": "stores payments made by customers based on their accounts.",
    "employees": "stores all employee information as well as the organization structure such as who reports to whom.",
    "offices": "stores sales office data."
}

In [23]:
from llama_index.objects import ObjectIndex
from llama_index.objects import SQLTableNodeMapping, SQLTableSchema
import pandas as pd

'''
SQLTableNodeMapping object takes in a SQLDatabase and produces a Node object for each 
SQLTableSchema object passed into the ObjectIndex constructor.
'''
tables = list(sql_database._all_tables)
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = []
for table in tables:
    table_schema_objs.append((SQLTableSchema(table_name = table, context_str = table_details[table])))

In [24]:
table_schema_objs

[SQLTableSchema(table_name='offices', context_str='stores sales office data.'),
 SQLTableSchema(table_name='products', context_str='stores a list of scale model cars.'),
 SQLTableSchema(table_name='payments', context_str='stores payments made by customers based on their accounts.'),
 SQLTableSchema(table_name='employees', context_str='stores all employee information as well as the organization structure such as who reports to whom.'),
 SQLTableSchema(table_name='productlines', context_str='stores a list of product line categories.'),
 SQLTableSchema(table_name='orderdetails', context_str='stores sales order line items for each sales order.'),
 SQLTableSchema(table_name='customers', context_str='stores customer’s data.'),
 SQLTableSchema(table_name='orders', context_str='stores sales orders placed by customers.')]

In [26]:
import tiktoken
from llama_index.callbacks import CallbackManager, TokenCountingHandler
token_counter = TokenCountingHandler(
    tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode
)

callback_manager = CallbackManager([token_counter])

In [28]:
#llm _openai under the hood
import os
import openai
os.environ["OPENAI_API_KEY"] = "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"
openai.api_key = os.environ["OPENAI_API_KEY"]

In [29]:
from llama_index import ServiceContext, LLMPredictor, OpenAIEmbedding, PromptHelper
from llama_index.llms import OpenAI
from llama_index import ServiceContext
llm = OpenAI(temperature=0.1, model="gpt-3.5-turbo")

service_context = ServiceContext.from_defaults(
  llm=llm,callback_manager=callback_manager
)

In [30]:
from llama_index.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index import VectorStoreIndex

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
    service_context=service_context
)

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


In [31]:
'''
NLSQLTableQueryEngine puts the entire schema as prompt.
Putting all the table schemas into the prompt may overflow the text-to-SQL prompt.
We first index the schemas with our ObjectIndex, and then use our SQLTableRetrieverQueryEngine abstraction on top.
'''


query_engine = SQLTableRetrieverQueryEngine(
    sql_database, obj_index.as_retriever(similarity_top_k=3), service_context=service_context
)
response = query_engine.query("How many customers we have?")

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:llama_index.indices.struct_store.sql_retriever:> Table desc str: Table 'customers' has columns: customerNumber (INTEGER), customerName (VARCHAR(50)), contactLastName (VARCHAR(50)), contactFirstName (VARCHAR(50)), phone (VARCHAR(50)), addressLine1 (VARCHAR(50)), addressLine2 (VARCHAR(50)), city (VARCHAR(50)), state (VARCHAR(50)), postalCode (VARCHAR(15)), country (VARCHAR(50)), salesRepEmployeeNumber (INTEGER), creditLimit (DECIMAL(10, 2)), and foreign keys: ['salesRepEmployeeNumber'] -> employees.['employeeNumber']. The table description is: stores customer’s data.

Table 'orders' has columns: orderNumber (INTEGER), orderDate (DATE), requiredDate (DATE), shippedDate (DATE), status (VARCHAR(15)), comments (TEXT), customerNumber (INTEGER), and foreign keys: ['customerNumber'] -> customers.['customerNumber']. The table descript

In [32]:
print(response)

We have a total of 122 customers.


In [33]:
print(token_counter.total_llm_token_count)

523


In [34]:
print(response.metadata['sql_query'])

SELECT COUNT(*) FROM customers
