In [1]:
# Optional: run this cell only when the chains below should be traced in LangSmith.
import os

from helper import get_api_key

# Switch tracing on first, then fetch and export the LangSmith API key.
# NOTE(review): the meaning of the argument `1` is defined in helper.get_api_key
# (presumably which stored key to return) — confirm there.
os.environ.update(LANGCHAIN_TRACING_V2="true")
os.environ.update(LANGCHAIN_API_KEY=get_api_key(1))

In [2]:
# Option A: run the model in-process through the locally installed llama-cpp-python.
# Both this cell and the ChatOpenAI (llama server) cell bind `llm`; every later chain
# uses whichever of the two was executed last, so run only one of them.
import multiprocessing

from langchain_community.chat_models import ChatLlamaCpp

local_model = "./models/xLAM-7b-fc-r.Q4_K_S.gguf"

llm = ChatLlamaCpp(
    temperature=0.2,
    model_path=local_model,
    # NOTE(review): the GGUF metadata logged when this cell runs reports
    # llama.context_length = 4096. An n_ctx above the trained context may degrade
    # output unless RoPE scaling is configured — confirm this value is intended.
    n_ctx=10000,
    # NOTE(review): the model reports 30 blocks, so 80 presumably means "offload every
    # layer" — lower it if VRAM is short.
    n_gpu_layers=80,
    n_batch=600,  # Should be between 1 and n_ctx, consider the amount of VRAM in your GPU.
    max_tokens=4096,
    # Leave one core free for the notebook, but never request fewer than one thread:
    # cpu_count() - 1 would be 0 on a single-core machine.
    n_threads=max(1, multiprocessing.cpu_count() - 1),
    # NOTE(review): 1.5 is a strong repetition penalty. It can also discourage
    # re-emitting tokens seen in the prompt (product names, numbers) that the answer
    # step needs to copy verbatim — consider ~1.1 if answers come back misspelled.
    repeat_penalty=1.5,
    top_p=0.5,
    verbose=True,
)



llama_model_loader: loaded meta data with 25 key-value pairs and 273 tensors from ./models/xLAM-7b-fc-r.Q4_K_S.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = xlam-FC-7b
llama_model_loader: - kv   2:                          llama.block_count u32              = 30
llama_model_loader: - kv   3:                       llama.context_length u32              = 4096
llama_model_loader: - kv   4:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 11008
llama_model_loader: - kv   6:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv   7:              llama.attention.head_count_kv u32

In [1]:
# Option B: talk to a llama.cpp server through its OpenAI-compatible HTTP API.
# This rebinds `llm`, replacing the in-process ChatLlamaCpp model if that cell ran;
# every later chain uses whichever binding was made last.
import os

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(
    openai_api_base="http://llama_server:8080/",
    # The OpenAI client refuses to construct without an API key. A local server started
    # without its own key presumably ignores it, so fall back to a placeholder only when
    # OPENAI_API_KEY is unset; export the variable if the server requires a real key.
    openai_api_key=os.environ.get("OPENAI_API_KEY", "sk-no-key-required"),
    temperature=0.3,
)

In [3]:
import os

from langchain_community.utilities import SQLDatabase

# Connection string for the classicmodels sample database. It is read from the
# environment so real credentials are not committed with the notebook.
# TODO: remove the hardcoded fallback once CLASSICMODELS_DB_URI is set wherever this runs.
DB_URI = os.environ.get(
    "CLASSICMODELS_DB_URI",
    "mysql://user:password@mysql_db:3306/classicmodels",
)

db = SQLDatabase.from_uri(DB_URI)
print(db.dialect)
print(db.get_usable_table_names())
# Sanity check: peek at a few product rows.
db.run("SELECT * FROM products LIMIT 5;")

mysql
['customers', 'employees', 'offices', 'orderdetails', 'orders', 'payments', 'productlines', 'products']


"[('S10_1678', '1969 Harley Davidson Ultimate Chopper', 'Motorcycles', '1:10', 'Min Lin Diecast', 'This replica features working kickstand, front suspension, gear-shift lever, footbrake lever, drive chain, wheels and steering. All parts are particularly delicate due to their precise scale and require special care and attention.', 7933, Decimal('48.81'), Decimal('95.70')), ('S10_1949', '1952 Alpine Renault 1300', 'Classic Cars', '1:10', 'Classic Metal Creations', 'Turnable front wheels; steering function; detailed interior; detailed engine; opening hood; opening trunk; opening doors; and detailed chassis.', 7305, Decimal('98.58'), Decimal('214.30')), ('S10_2016', '1996 Moto Guzzi 1100i', 'Motorcycles', '1:10', 'Highway 66 Mini Classics', 'Official Moto Guzzi logos and insignias, saddle bags located on side of motorcycle, detailed engine, working steering, working suspension, two leather seats, luggage rack, dual exhaust pipes, small saddle bag located on handle bars, two-tone paint with

In [10]:
# Ground truth for the Harley question asked of the chains below.
db.run(
    "SELECT buyPrice from products "
    "where productName = '1969 Harley Davidson Ultimate Chopper';"
)

"[(Decimal('48.81'),)]"

In [25]:
# Which employees have a marketing-related job title?
db.run(
    "SELECT firstName, lastName, jobTitle FROM employees "
    "WHERE jobTitle LIKE '%Marketing%';"
)

"[('Jeff', 'Firrelli', 'VP Marketing')]"

In [5]:
from langchain.chains import create_sql_query_chain

# Chain that turns a natural-language question into SQL text for `db`, using `llm`.
generate_query = create_sql_query_chain(llm=llm, db=db)

# Inspect the SQL the model produces before executing anything.
query = generate_query.invoke(
    {"question": "What is the price of the '1969 Harley Davidson Ultimate Chopper'?"}
)
print(query)


In [6]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

# Tool that runs a SQL string against `db`.
# NOTE(review): `query` is LLM-generated; prefer a read-only database user for this.
execute_query = QuerySQLDataBaseTool(db=db)

execute_query.invoke(query)

In [29]:
# Compose the two steps: the generated SQL text is fed straight into the executor.
test_chain = generate_query.pipe(execute_query)

test_chain.invoke(
    {"question": "What is the price of the '1969 Harley Davidson Ultimate Chopper'?"}
)

Llama.generate: prefix-match hit

llama_print_timings:        load time =     340.28 ms
llama_print_timings:      sample time =      18.07 ms /    24 runs   (    0.75 ms per token,  1328.39 tokens per second)
llama_print_timings: prompt eval time =      45.79 ms /    26 tokens (    1.76 ms per token,   567.86 tokens per second)
llama_print_timings:        eval time =     357.54 ms /    23 runs   (   15.55 ms per token,    64.33 tokens per second)
llama_print_timings:       total time =     435.84 ms /    49 tokens


"[(Decimal('48.81'),)]"

In [30]:
# Show the SQL-generation prompt used by test_chain.
# pretty_print() writes to stdout itself and returns None, so it is not wrapped in
# print() (that only appended a stray "None" line).
test_chain.get_prompts()[0].pretty_print()

You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the S

In [11]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

# Prompt for the final step: the model sees the question, the SQL that was run and the
# raw result, and phrases an answer.
answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

# Answer sub-chain: fill the prompt, call the LLM, and unwrap the reply into a plain str.
rephrase_answer = answer_prompt.pipe(llm, StrOutputParser())

# Full pipeline over an input dict {"question": ...}:
#   1. add "query"  - the SQL text produced by generate_query
#   2. add "result" - the output of running that SQL with execute_query
#   3. pass the enriched dict to rephrase_answer
chain = (
    RunnablePassthrough
    .assign(query=generate_query)
    .assign(result=itemgetter("query") | execute_query)
    | rephrase_answer
)

chain.invoke({"question": "What is the price of the '1969 Harley Davidson Ultimate Chopper'?"})

Llama.generate: prefix-match hit

llama_print_timings:        load time =     320.63 ms
llama_print_timings:      sample time =      18.21 ms /    24 runs   (    0.76 ms per token,  1318.25 tokens per second)
llama_print_timings: prompt eval time =    1417.89 ms /  2855 tokens (    0.50 ms per token,  2013.55 tokens per second)
llama_print_timings:        eval time =     360.97 ms /    23 runs   (   15.69 ms per token,    63.72 tokens per second)
llama_print_timings:       total time =    1812.61 ms /  2878 tokens
Llama.generate: prefix-match hit

llama_print_timings:        load time =     320.63 ms
llama_print_timings:      sample time =      21.63 ms /    26 runs   (    0.83 ms per token,  1201.87 tokens per second)
llama_print_timings: prompt eval time =      42.34 ms /    88 tokens (    0.48 ms per token,  2078.36 tokens per second)
llama_print_timings:        eval time =     265.11 ms /    25 runs   (   10.60 ms per token,    94.30 tokens per second)
llama_print_timings:       to

"The buy Price for a new '1969 Harlely Davison ultimate choppers' is $50."

In [12]:
# Show the first prompt in `chain` (the SQL-generation prompt).
# pretty_print() writes to stdout itself and returns None, so it is not wrapped in
# print() (that only appended a stray "None" line).
chain.get_prompts()[0].pretty_print()

You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the S

In [13]:
# Show the second prompt in `chain` (the answer-rephrasing prompt).
# pretty_print() writes to stdout itself and returns None, so it is not wrapped in
# print() (that only appended a stray "None" line).
chain.get_prompts()[1].pretty_print()

Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: [33;1m[1;3m{question}[0m
SQL Query: [33;1m[1;3m{query}[0m
SQL Result: [33;1m[1;3m{result}[0m
Answer: 
None


In [26]:
# Load the list of example prompts stored in sql_examples.json.
# JSON text is UTF-8 by specification (RFC 8259); pass the encoding explicitly so the
# result does not depend on the platform's locale default (e.g. cp1252 on Windows).
import json

with open("sql_examples.json", encoding="utf-8") as read_file:
    examples = json.load(read_file)

In [21]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit

# The toolkit bundles the SQL tools for `db`; `llm` is passed in, presumably for the
# query-checker tool.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()

# Index tools by name so a missing tool fails with a KeyError naming it, instead of a
# bare StopIteration from next().
tools_by_name = {tool.name: tool for tool in tools}
list_tables_tool = tools_by_name["sql_db_list_tables"]
get_schema_tool = tools_by_name["sql_db_schema"]

# sql_db_list_tables returns a single comma-separated string.
tables_csv = list_tables_tool.invoke("")
print(tables_csv)

# Split on "," and strip each name so spacing differences cannot leak whitespace into
# table names, and an empty listing yields [] instead of [""].
tables = [name.strip() for name in tables_csv.split(",") if name.strip()]

for table in tables:
    print(get_schema_tool.invoke(table))




customers, employees, offices, orderdetails, orders, payments, productlines, products

CREATE TABLE customers (
	`customerNumber` INTEGER NOT NULL, 
	`customerName` VARCHAR(50) NOT NULL, 
	`contactLastName` VARCHAR(50) NOT NULL, 
	`contactFirstName` VARCHAR(50) NOT NULL, 
	phone VARCHAR(50) NOT NULL, 
	`addressLine1` VARCHAR(50) NOT NULL, 
	`addressLine2` VARCHAR(50), 
	city VARCHAR(50) NOT NULL, 
	state VARCHAR(50), 
	`postalCode` VARCHAR(15), 
	country VARCHAR(50) NOT NULL, 
	`salesRepEmployeeNumber` INTEGER, 
	`creditLimit` DECIMAL(10, 2), 
	PRIMARY KEY (`customerNumber`), 
	CONSTRAINT customers_ibfk_1 FOREIGN KEY(`salesRepEmployeeNumber`) REFERENCES employees (`employeeNumber`)
)ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci

/*
3 rows from customers table:
customerNumber	customerName	contactLastName	contactFirstName	phone	addressLine1	addressLine2	city	state	postalCode	country	salesRepEmployeeNumber	creditLimit
103	Atelier graphique	Schmitt	Carine 	40.32.2555	54,

In [47]:
# List every tool the toolkit exposes: its name, then the description an agent sees.
for tool in tools:
    print(f"{tool.name} {tool.description}")

sql_db_query Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.
sql_db_schema Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3
sql_db_list_tables Input is an empty string, output is a comma-separated list of tables in the database.
sql_db_query_checker Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!
